package protocol

import (
	"net"
	"sync"
	"sync/atomic"
)

// AcceptFunc decides whether a connection may proceed. It receives the
// direction of the first payload seen on the connection and the protocol
// that Detect reported for it. A non-nil error rejects the connection.
type AcceptFunc func(Direction, *Protocol) error

// Limit wraps conn so that its protocol is checked exactly once, on the first
// data that crosses it in either direction.
//
// The first chunk returned by a successful Read is classified with direction
// Server. The first non-empty Write is classified with direction Client.
// That payload is passed to Detect, and the result is handed to accept.
// Detect's error is discarded, so accept should be prepared to receive a nil
// *Protocol when nothing was recognised.
//
// If accept returns a non-nil error, the triggering Read or Write returns that
// error, and so does every later Read and Write. The rejected data is neither
// returned to the caller nor written to conn. Limit does not close conn on
// rejection; the caller remains responsible for closing it.
//
// If accept is nil, conn is returned unchanged.
func Limit(conn net.Conn, accept AcceptFunc) net.Conn {
	if accept == nil {
		// Nothing to enforce.
		return conn
	}
	return &connLimiter{Conn: conn, accept: accept}
}

// connLimiter is the net.Conn returned by Limit. It embeds the wrapped
// connection and intercepts Read and Write to run the accept check once.
type connLimiter struct {
	net.Conn

	accept     AcceptFunc
	acceptOnce sync.Once
	// acceptError holds the error returned by accept, if any. It is loaded
	// outside acceptOnce (before the check has necessarily run), so it is
	// accessed atomically.
	acceptError atomic.Value
}

// init runs protocol detection and the accept check exactly once. Callers pass
// exactly one non-nil slice: readData from Read (direction Server) or
// writeData from Write (direction Client). sync.Once guarantees that
// concurrent callers do not return until the first call has finished, so the
// outcome in acceptError is settled when init returns.
func (l *connLimiter) init(readData, writeData []byte) {
	l.acceptOnce.Do(func() {
		var (
			dir  Direction
			data []byte
		)
		if readData != nil {
			// Called by the first successful Read.
			dir, data = Server, readData
		} else {
			// Called by the first non-empty Write.
			dir, data = Client, writeData
		}
		// The detection error is deliberately ignored: an unrecognised
		// protocol is left to accept to allow or reject.
		protocol, _ := Detect(dir, data)
		if err := l.accept(dir, protocol); err != nil {
			l.acceptError.Store(err)
		}
	})
}

// rejection returns the error recorded by a failed accept check, or nil if
// the check has not run yet or passed.
func (l *connLimiter) rejection() error {
	err, _ := l.acceptError.Load().(error)
	return err
}

// Read reads from the wrapped connection. The first chunk of data read
// triggers the accept check. If the check rejects the connection, that chunk
// is discarded and the accept error is returned instead, so rejected data
// never reaches the caller. Once rejected, Read returns the accept error
// without touching the connection.
func (l *connLimiter) Read(p []byte) (int, error) {
	if err := l.rejection(); err != nil {
		return 0, err
	}
	n, err := l.Conn.Read(p)
	if n > 0 {
		l.init(p[:n], nil)
		if rerr := l.rejection(); rerr != nil {
			// Don't hand data from a rejected connection to the caller.
			return 0, rerr
		}
	}
	return n, err
}

// Write writes p to the wrapped connection. The first non-empty write
// triggers the accept check before anything is sent. If the check has
// rejected the connection (now or earlier), nothing is written and the
// accept error is returned.
func (l *connLimiter) Write(p []byte) (int, error) {
	if len(p) > 0 {
		// Only real payload can be classified; an empty write must not lock in
		// a one-shot decision based on zero bytes.
		l.init(nil, p)
	}
	if err := l.rejection(); err != nil {
		return 0, err
	}
	return l.Conn.Write(p)
}