package protocol import ( "net" "sync" "sync/atomic" ) // AcceptFunc receives a direction and a detected protocol. type AcceptFunc func(Direction, *Protocol) error // Limit the connection protocol, by running a detection after either side sends // a banner within timeout. // // If no protocol could be detected, the accept function is called with a nil // argument to check if we should proceed. // // If the accept function returns false, the connection will be closed. func Limit(conn net.Conn, accept AcceptFunc) net.Conn { if accept == nil { // Nothing to do here. return conn } return &connLimiter{ Conn: conn, accept: accept, } } type connLimiter struct { net.Conn accept AcceptFunc acceptOnce sync.Once acceptError atomic.Value } func (limiter *connLimiter) init(readData, writeData []byte) { limiter.acceptOnce.Do(func() { var ( dir Direction data []byte srcPort, dstPort int ) if readData != nil { // init called by initial read dir, data, srcPort, dstPort = Server, readData, getPortFromAddr(limiter.LocalAddr()), getPortFromAddr(limiter.RemoteAddr()) } else { // init called by initial write dir, data, srcPort, dstPort = Client, writeData, getPortFromAddr(limiter.RemoteAddr()), getPortFromAddr(limiter.LocalAddr()) } protocol, _, _ := Detect(dir, data, srcPort, dstPort) if err := limiter.accept(dir, protocol); err != nil { limiter.acceptError.Store(err) } }) } func (limiter *connLimiter) Read(p []byte) (n int, err error) { var ok bool if err, ok = limiter.acceptError.Load().(error); ok && err != nil { return } if n, err = limiter.Conn.Read(p); n > 0 { limiter.init(p[:n], nil) } return } func (limiter *connLimiter) Write(p []byte) (n int, err error) { limiter.init(nil, p) var ok bool if err, ok = limiter.acceptError.Load().(error); ok && err != nil { return } return limiter.Conn.Write(p) }