Fixed up the documentation and refactored Limit
This commit is contained in:
		@@ -15,7 +15,8 @@ type AcceptFunc func(Direction, *Protocol) error
 | 
			
		||||
// If no protocol could be detected, the accept function is called with a nil
 | 
			
		||||
// argument to check if we should proceed.
 | 
			
		||||
//
 | 
			
		||||
// If the accept function returns false, the connection will be closed.
 | 
			
		||||
// If the accept function returns an error, all future reads and writes on the
 | 
			
		||||
// returned connection will return that error.
 | 
			
		||||
func Limit(conn net.Conn, accept AcceptFunc) net.Conn {
 | 
			
		||||
	if accept == nil {
 | 
			
		||||
		// Nothing to do here.
 | 
			
		||||
@@ -38,24 +39,33 @@ type connLimiter struct {
 | 
			
		||||
func (limiter *connLimiter) init(readData, writeData []byte) {
 | 
			
		||||
	limiter.acceptOnce.Do(func() {
 | 
			
		||||
		var (
 | 
			
		||||
			dir              Direction
 | 
			
		||||
			data             []byte
 | 
			
		||||
			srcPort, dstPort int
 | 
			
		||||
			dir        Direction
 | 
			
		||||
			localPort  = getPortFromAddr(limiter.LocalAddr())
 | 
			
		||||
			remotePort = getPortFromAddr(limiter.RemoteAddr())
 | 
			
		||||
			protocol   *Protocol
 | 
			
		||||
		)
 | 
			
		||||
		if readData != nil {
 | 
			
		||||
			// init called by initial read
 | 
			
		||||
			dir, data, srcPort, dstPort = Server, readData, getPortFromAddr(limiter.LocalAddr()), getPortFromAddr(limiter.RemoteAddr())
 | 
			
		||||
			// init called by initial read (from the server)
 | 
			
		||||
			dir = Server
 | 
			
		||||
			protocol, _, _ = Detect(dir, readData, localPort, remotePort)
 | 
			
		||||
		} else {
 | 
			
		||||
			// init called by initial write
 | 
			
		||||
			dir, data, srcPort, dstPort = Client, writeData, getPortFromAddr(limiter.RemoteAddr()), getPortFromAddr(limiter.LocalAddr())
 | 
			
		||||
			// init called by initial write (from the client)
 | 
			
		||||
			dir = Client
 | 
			
		||||
			protocol, _, _ = Detect(dir, writeData, remotePort, localPort)
 | 
			
		||||
		}
 | 
			
		||||
		protocol, _, _ := Detect(dir, data, srcPort, dstPort)
 | 
			
		||||
		if err := limiter.accept(dir, protocol); err != nil {
 | 
			
		||||
			limiter.acceptError.Store(err)
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NetConn implements the same method as [net/tls.Conn] to obtain the underlying [net.Conn].
 | 
			
		||||
func (limiter *connLimiter) NetConn() net.Conn {
 | 
			
		||||
	return limiter.Conn
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Read from the connection, if this is the first read then the data returned by the underlying
 | 
			
		||||
// connection is used for protocol detection.
 | 
			
		||||
func (limiter *connLimiter) Read(p []byte) (n int, err error) {
 | 
			
		||||
	var ok bool
 | 
			
		||||
	if err, ok = limiter.acceptError.Load().(error); ok && err != nil {
 | 
			
		||||
@@ -67,6 +77,8 @@ func (limiter *connLimiter) Read(p []byte) (n int, err error) {
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Write to the connection, if this is the first write then the data is used for protocol detection
 | 
			
		||||
// before it gets sent to the underlying connection.
 | 
			
		||||
func (limiter *connLimiter) Write(p []byte) (n int, err error) {
 | 
			
		||||
	limiter.init(nil, p)
 | 
			
		||||
	var ok bool
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user