Files
styx/internal/netutil/conn.go
2025-10-06 22:25:23 +02:00

173 lines
4.5 KiB
Go

package netutil
import (
"bufio"
"errors"
"io"
"net"
"sync/atomic"
"syscall"
"time"
"git.maze.io/maze/styx/logger"
)
// BufferedConn uses byte buffers for Read and Write operations on a [net.Conn].
//
// Reads are served from Reader and writes are collected in Writer; written
// data only reaches the underlying connection once Flush is called. All other
// [net.Conn] methods (including Close) are promoted from the embedded Conn, so
// Close does not flush: any data still held in Writer is not sent.
type BufferedConn struct {
	net.Conn
	Reader *bufio.Reader
	Writer *bufio.Writer
}

// NewBufferedConn wraps c with a buffered Reader and Writer.
//
// If c already is a BufferedConn, either by pointer or by value, its existing
// buffers are reused instead of stacking a second buffering layer on top. With
// a nested bufio.Writer, calling Flush on the outer layer would only move data
// into the inner buffer, never onto the wire.
func NewBufferedConn(c net.Conn) *BufferedConn {
	switch b := c.(type) {
	case *BufferedConn:
		return b
	case BufferedConn:
		// b is a copy; it shares Conn, Reader and Writer with the caller's value.
		return &b
	}
	return &BufferedConn{
		Conn:   c,
		Reader: bufio.NewReader(c),
		Writer: bufio.NewWriter(c),
	}
}

// Read reads from the buffered Reader.
func (conn BufferedConn) Read(p []byte) (int, error) { return conn.Reader.Read(p) }

// Write appends p to the buffered Writer; call Flush to send it.
func (conn BufferedConn) Write(p []byte) (int, error) { return conn.Writer.Write(p) }

// Flush writes any buffered data to the underlying connection.
func (conn BufferedConn) Flush() error { return conn.Writer.Flush() }

// NetConn returns the underlying, unbuffered connection.
func (conn BufferedConn) NetConn() net.Conn { return conn.Conn }
// ReaderConn is a [net.Conn] that takes its incoming bytes from a separate
// [io.Reader] rather than from the connection itself. Every method other than
// Read (Write, Close, deadlines, addresses) is promoted from the embedded Conn.
type ReaderConn struct {
	net.Conn
	io.Reader
}

// Read reads from the embedded Reader, bypassing the Conn's own Read.
func (rc ReaderConn) Read(buf []byte) (int, error) {
	src := rc.Reader
	return src.Read(buf)
}

// NetConn returns the wrapped connection.
func (rc ReaderConn) NetConn() net.Conn {
	underlying := rc.Conn
	return underlying
}
// ReadOnlyConn adapts an [io.Reader] into a [net.Conn] that can only be read
// from. Write always fails with [io.ErrClosedPipe]. Close and the Set*Deadline
// methods are no-ops that report success, and LocalAddr/RemoteAddr return nil.
type ReadOnlyConn struct {
	io.Reader
}

// Read reads from the wrapped Reader.
func (roc ReadOnlyConn) Read(buf []byte) (int, error) {
	src := roc.Reader
	return src.Read(buf)
}

// Write never writes anything and always reports io.ErrClosedPipe.
func (ReadOnlyConn) Write([]byte) (int, error) { return 0, io.ErrClosedPipe }

// Close is a no-op.
func (ReadOnlyConn) Close() error { return nil }

// LocalAddr always returns nil.
func (ReadOnlyConn) LocalAddr() net.Addr { return nil }

// RemoteAddr always returns nil.
func (ReadOnlyConn) RemoteAddr() net.Addr { return nil }

// SetDeadline is a no-op.
func (ReadOnlyConn) SetDeadline(time.Time) error { return nil }

// SetReadDeadline is a no-op.
func (ReadOnlyConn) SetReadDeadline(time.Time) error { return nil }

// SetWriteDeadline is a no-op.
func (ReadOnlyConn) SetWriteDeadline(time.Time) error { return nil }

// NetConn returns the wrapped Reader when it is itself a net.Conn, else nil.
func (roc ReadOnlyConn) NetConn() net.Conn {
	c, ok := roc.Reader.(net.Conn)
	if !ok {
		return nil
	}
	return c
}

var _ net.Conn = (*ReadOnlyConn)(nil)
// Pipe is an in-memory [net.Conn] built from an [io.PipeReader] (the read side)
// and an [io.PipeWriter] (the write side).
//
// Deadlines are not supported: the Set*Deadline methods do nothing and report
// success. Both LocalAddr and RemoteAddr report a [net.UnixAddr] named "pipe".
type Pipe struct {
	Reader *io.PipeReader
	Writer *io.PipeWriter
}

// Read reads from the Reader side of the pipe.
func (p Pipe) Read(buf []byte) (int, error) {
	r := p.Reader
	return r.Read(buf)
}

// Write writes to the Writer side of the pipe.
func (p Pipe) Write(buf []byte) (int, error) {
	w := p.Writer
	return w.Write(buf)
}

// Close closes the Writer and then the Reader. Both are always closed; if
// both fail, the Writer's error is the one reported.
func (p Pipe) Close() error {
	writeErr := p.Writer.Close()
	readErr := p.Reader.Close()
	if writeErr != nil {
		return writeErr
	}
	return readErr
}

// LocalAddr returns a Unix address named "pipe".
func (Pipe) LocalAddr() net.Addr { return &net.UnixAddr{Name: "pipe"} }

// RemoteAddr returns the same address as LocalAddr.
func (p Pipe) RemoteAddr() net.Addr { return p.LocalAddr() }

// SetDeadline is a no-op.
func (Pipe) SetDeadline(time.Time) error { return nil }

// SetReadDeadline is a no-op.
func (Pipe) SetReadDeadline(time.Time) error { return nil }

// SetWriteDeadline is a no-op.
func (Pipe) SetWriteDeadline(time.Time) error { return nil }

var _ net.Conn = (*Pipe)(nil)
// Loopback is a pair of in-memory Pipes wired back to back: bytes written to
// Client are read from Server, and bytes written to Server are read from
// Client. Create one with NewLoopback.
type Loopback struct {
	Server *Pipe
	Client *Pipe
}

// NewLoopback returns a Loopback whose Server and Client ends are connected.
func NewLoopback() *Loopback {
	// Client -> Server direction.
	serverReader, clientWriter := io.Pipe()
	// Server -> Client direction.
	clientReader, serverWriter := io.Pipe()

	lb := new(Loopback)
	lb.Server = &Pipe{Reader: serverReader, Writer: serverWriter}
	lb.Client = &Pipe{Reader: clientReader, Writer: clientWriter}
	return lb
}

// Close closes the Server end and then the Client end. Both are always
// closed; if both fail, the Server's error is the one reported.
func (lb *Loopback) Close() error {
	serverErr := lb.Server.Close()
	clientErr := lb.Client.Close()
	if serverErr != nil {
		return serverErr
	}
	return clientErr
}
// AcceptOnce is a [net.Listener] that hands out its embedded Conn exactly once.
// Every Accept after the first returns io.EOF. Close is promoted from the
// embedded Conn, so closing the listener closes the connection.
type AcceptOnce struct {
	net.Conn
	once atomic.Bool
}

// Accept returns the embedded Conn on the first call and io.EOF afterwards.
// It is safe for concurrent use: exactly one caller receives the connection.
func (listener *AcceptOnce) Accept() (net.Conn, error) {
	// Some conns (e.g. ReadOnlyConn) report a nil RemoteAddr; calling String
	// on it would panic.
	client := "<unknown>"
	if addr := listener.Conn.RemoteAddr(); addr != nil {
		client = addr.String()
	}
	log := logger.StandardLog.Value("client", client)

	// CompareAndSwap makes the check-and-set a single atomic step, so two
	// concurrent Accept calls cannot both receive the connection.
	if !listener.once.CompareAndSwap(false, true) {
		log.Trace("Accept already happened, responding EOF")
		return nil, io.EOF
	}
	log.Trace("Accept client")
	return listener.Conn, nil
}

// Addr returns the local address of the embedded Conn.
func (listener *AcceptOnce) Addr() net.Addr {
	return listener.Conn.LocalAddr()
}

var _ net.Listener = (*AcceptOnce)(nil)
// IsClosing reports whether err signals that a connection is shutting down or
// was closed, as opposed to a genuine failure.
//
// It returns true for io.EOF, io.ErrClosedPipe, net.ErrClosed and ECONNRESET
// (also when wrapped), for an error whose message is exactly
// "proxy: shutdown", and for any net.Error in the chain that reports a
// timeout. A nil error is not a closing error.
func IsClosing(err error) bool {
	if err == nil {
		return false
	}
	if errors.Is(err, io.EOF) ||
		errors.Is(err, io.ErrClosedPipe) ||
		errors.Is(err, net.ErrClosed) ||
		errors.Is(err, syscall.ECONNRESET) ||
		err.Error() == "proxy: shutdown" {
		return true
	}
	// errors.As also finds timeouts wrapped by fmt.Errorf("...: %w", err).
	var netErr net.Error
	return errors.As(err, &netErr) && netErr.Timeout()
}
// WithTimeout is a convenience wrapper for doing network operations that observe a timeout.
//
// When timeout is positive, a deadline of now+timeout is set on c before do is
// called. The deadline is cleared (reset to the zero time) afterwards, also when
// do returns an error or panics. When timeout is zero or negative, do is called
// directly and c's deadlines are left untouched.
//
// An error from do takes precedence. If do succeeds, an error from clearing
// the deadline is returned instead.
func WithTimeout(c net.Conn, timeout time.Duration, do func() error) (err error) {
	if timeout <= 0 {
		return do()
	}
	if err = c.SetDeadline(time.Now().Add(timeout)); err != nil {
		return err
	}
	defer func() {
		// Clear the deadline in a defer so a panicking do cannot leave it
		// armed for later operations on c.
		resetErr := c.SetDeadline(time.Time{})
		if err == nil {
			err = resetErr
		}
	}()
	return do()
}