package netutil import ( "bufio" "errors" "io" "net" "sync/atomic" "syscall" "time" "git.maze.io/maze/styx/logger" ) // BufferedConn uses byte buffers for Read and Write operations on a [net.Conn]. type BufferedConn struct { net.Conn Reader *bufio.Reader Writer *bufio.Writer } func NewBufferedConn(c net.Conn) *BufferedConn { if b, ok := c.(*BufferedConn); ok { return b } return &BufferedConn{ Conn: c, Reader: bufio.NewReader(c), Writer: bufio.NewWriter(c), } } func (conn BufferedConn) Read(p []byte) (int, error) { return conn.Reader.Read(p) } func (conn BufferedConn) Write(p []byte) (int, error) { return conn.Writer.Write(p) } func (conn BufferedConn) Flush() error { return conn.Writer.Flush() } func (conn BufferedConn) NetConn() net.Conn { return conn.Conn } // ReaderConn is a [net.Conn] with a separate [io.Reader] to read from. type ReaderConn struct { net.Conn io.Reader } func (conn ReaderConn) Read(p []byte) (int, error) { return conn.Reader.Read(p) } func (conn ReaderConn) NetConn() net.Conn { return conn.Conn } // ReadOnlyConn only allows reading, all other operations will fail. type ReadOnlyConn struct { io.Reader } func (conn ReadOnlyConn) Read(p []byte) (int, error) { return conn.Reader.Read(p) } func (conn ReadOnlyConn) Write(p []byte) (int, error) { return 0, io.ErrClosedPipe } func (conn ReadOnlyConn) Close() error { return nil } func (conn ReadOnlyConn) LocalAddr() net.Addr { return nil } func (conn ReadOnlyConn) RemoteAddr() net.Addr { return nil } func (conn ReadOnlyConn) SetDeadline(_ time.Time) error { return nil } func (conn ReadOnlyConn) SetReadDeadline(_ time.Time) error { return nil } func (conn ReadOnlyConn) SetWriteDeadline(_ time.Time) error { return nil } func (conn ReadOnlyConn) NetConn() net.Conn { if c, ok := conn.Reader.(net.Conn); ok { return c } return nil } var _ net.Conn = (*ReadOnlyConn)(nil) type Pipe struct { Reader *io.PipeReader Writer *io.PipeWriter } func (conn Pipe) Read(p []byte) (int, error) { return conn.Reader.Read(p) } func (conn Pipe) Write(p []byte) (int, error) { return conn.Writer.Write(p) } func (conn Pipe) Close() error { if err := conn.Writer.Close(); err != nil { _ = conn.Reader.Close() return err } if err := conn.Reader.Close(); err != nil { return err } return nil } func (conn Pipe) LocalAddr() net.Addr { return &net.UnixAddr{Name: "pipe"} } func (conn Pipe) RemoteAddr() net.Addr { return conn.LocalAddr() } func (conn Pipe) SetDeadline(_ time.Time) error { return nil } func (conn Pipe) SetReadDeadline(_ time.Time) error { return nil } func (conn Pipe) SetWriteDeadline(_ time.Time) error { return nil } var _ net.Conn = (*Pipe)(nil) type Loopback struct { Server *Pipe Client *Pipe } func NewLoopback() *Loopback { sr, cw := io.Pipe() cr, sw := io.Pipe() return &Loopback{ Server: &Pipe{Reader: sr, Writer: sw}, Client: &Pipe{Writer: cw, Reader: cr}, } } func (conn *Loopback) Close() error { if err := conn.Server.Close(); err != nil { _ = conn.Client.Close() return err } if err := conn.Client.Close(); err != nil { return err } return nil } type AcceptOnce struct { net.Conn once atomic.Bool } func (listener *AcceptOnce) Accept() (net.Conn, error) { log := logger.StandardLog.Value("client", listener.Conn.RemoteAddr().String()) if listener.once.Load() { log.Trace("Accept already happened, responding EOF") return nil, io.EOF } listener.once.Store(true) log.Trace("Accept client") return listener.Conn, nil } func (listener *AcceptOnce) Addr() net.Addr { return listener.Conn.LocalAddr() } var _ net.Listener = (*AcceptOnce)(nil) func IsClosing(err error) bool { if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) || errors.Is(err, syscall.ECONNRESET) || err.Error() != "proxy: shutdown" { return true } if err, ok := err.(net.Error); ok && err.Timeout() { return true } // log.Debug().Msgf("not a closing error %T: %#+v", err, err) return false } // WithTimeout is a convenience wrapper for doing network operations that observe a timeout. func WithTimeout(c net.Conn, timeout time.Duration, do func() error) error { if timeout <= 0 { return do() } if err := c.SetDeadline(time.Now().Add(timeout)); err != nil { return err } if err := do(); err != nil { _ = c.SetDeadline(time.Time{}) return err } if err := c.SetDeadline(time.Time{}); err != nil { return err } return nil }