Checkpoint
This commit is contained in:
227
proxy/context.go
Normal file
227
proxy/context.go
Normal file
@@ -0,0 +1,227 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"git.maze.io/maze/styx/internal/netutil/arp"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// Context provides convenience functions for the current ongoing HTTP proxy transaction (request).
|
||||
type Context interface {
|
||||
// Conn is the backing connection for this context.
|
||||
net.Conn
|
||||
|
||||
// ID is a unique connection identifier.
|
||||
ID() uint64
|
||||
|
||||
// Reader returns a buffered reader on top of the [net.Conn].
|
||||
Reader() *bufio.Reader
|
||||
|
||||
// BytesRead returns the number of bytes read.
|
||||
BytesRead() int64
|
||||
|
||||
// BytesSent returns the number of bytes written.
|
||||
BytesSent() int64
|
||||
|
||||
// TLSState returns the TLS connection state, it returns nil if the connection is not a TLS connection.
|
||||
TLSState() *tls.ConnectionState
|
||||
|
||||
// Request is the request made to the proxy.
|
||||
Request() *http.Request
|
||||
|
||||
// Response is the response that will be sent back to the client.
|
||||
Response() *http.Response
|
||||
}
|
||||
|
||||
type countingReader struct {
|
||||
reader io.Reader
|
||||
bytes int64
|
||||
}
|
||||
|
||||
func (r *countingReader) Read(p []byte) (n int, err error) {
|
||||
if n, err = r.reader.Read(p); n > 0 {
|
||||
atomic.AddInt64(&r.bytes, int64(n))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
type countingWriter struct {
|
||||
writer io.Writer
|
||||
bytes int64
|
||||
}
|
||||
|
||||
func (w *countingWriter) Write(p []byte) (n int, err error) {
|
||||
if n, err = w.writer.Write(p); n > 0 {
|
||||
atomic.AddInt64(&w.bytes, int64(n))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
type proxyContext struct {
|
||||
net.Conn
|
||||
id uint64
|
||||
mac net.HardwareAddr
|
||||
cr *countingReader
|
||||
br *bufio.Reader
|
||||
cw *countingWriter
|
||||
isTransparent bool
|
||||
isTransparentTLS bool
|
||||
serverName string
|
||||
req *http.Request
|
||||
res *http.Response
|
||||
idleTimeout time.Duration
|
||||
}
|
||||
|
||||
// NewContext returns an initialized context for the provided [net.Conn].
|
||||
func NewContext(c net.Conn) Context {
|
||||
if c, ok := c.(*proxyContext); ok {
|
||||
return c
|
||||
}
|
||||
|
||||
b := make([]byte, 8)
|
||||
io.ReadFull(rand.Reader, b)
|
||||
|
||||
cr := &countingReader{reader: c}
|
||||
cw := &countingWriter{writer: c}
|
||||
return &proxyContext{
|
||||
Conn: c,
|
||||
id: binary.BigEndian.Uint64(b),
|
||||
mac: arp.Get(c.RemoteAddr()),
|
||||
cr: cr,
|
||||
br: bufio.NewReader(cr),
|
||||
cw: cw,
|
||||
res: &http.Response{StatusCode: 200},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *proxyContext) AccessLogEntry() *logrus.Entry {
|
||||
var id [8]byte
|
||||
binary.BigEndian.PutUint64(id[:], c.id)
|
||||
entry := AccessLog.WithFields(logrus.Fields{
|
||||
"client": c.RemoteAddr().String(),
|
||||
"server": c.LocalAddr().String(),
|
||||
"id": hex.EncodeToString(id[:]),
|
||||
"bytes_rx": c.BytesRead(),
|
||||
"bytes_tx": c.BytesSent(),
|
||||
})
|
||||
if c.mac != nil {
|
||||
return entry.WithField("client_mac", c.mac.String())
|
||||
}
|
||||
return entry
|
||||
}
|
||||
|
||||
func (c *proxyContext) LogEntry() *logrus.Entry {
|
||||
var id [8]byte
|
||||
binary.BigEndian.PutUint64(id[:], c.id)
|
||||
return ServerLog.WithFields(logrus.Fields{
|
||||
"client": c.RemoteAddr().String(),
|
||||
"server": c.LocalAddr().String(),
|
||||
"id": hex.EncodeToString(id[:]),
|
||||
})
|
||||
}
|
||||
|
||||
func (c *proxyContext) String() string {
|
||||
return fmt.Sprintf("client=%s server=%s id=%#08x",
|
||||
c.RemoteAddr().String(),
|
||||
c.LocalAddr().String(),
|
||||
c.id)
|
||||
}
|
||||
|
||||
func (c *proxyContext) ID() uint64 {
|
||||
return c.id
|
||||
}
|
||||
|
||||
func (c *proxyContext) BytesRead() int64 {
|
||||
return atomic.LoadInt64(&c.cr.bytes)
|
||||
}
|
||||
|
||||
func (c *proxyContext) BytesSent() int64 {
|
||||
return atomic.LoadInt64(&c.cw.bytes)
|
||||
}
|
||||
|
||||
func (c *proxyContext) Read(p []byte) (n int, err error) {
|
||||
if c.idleTimeout > 0 {
|
||||
if err = c.SetReadDeadline(time.Now().Add(c.idleTimeout)); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
return c.br.Read(p)
|
||||
}
|
||||
|
||||
func (c *proxyContext) Write(p []byte) (n int, err error) {
|
||||
if c.idleTimeout > 0 {
|
||||
if err = c.SetWriteDeadline(time.Now().Add(c.idleTimeout)); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
return c.cw.Write(p)
|
||||
}
|
||||
|
||||
func (c *proxyContext) Reader() *bufio.Reader {
|
||||
return c.br
|
||||
}
|
||||
|
||||
func (c *proxyContext) Request() *http.Request {
|
||||
return c.req
|
||||
}
|
||||
|
||||
func (c *proxyContext) SetRequest(req *http.Request) {
|
||||
c.req = req
|
||||
}
|
||||
|
||||
func (c *proxyContext) Response() *http.Response {
|
||||
return c.res
|
||||
}
|
||||
|
||||
func (c *proxyContext) SetIdleTimeout(t time.Duration) {
|
||||
c.idleTimeout = t
|
||||
}
|
||||
|
||||
type connectionStater interface {
|
||||
ConnectionState() tls.ConnectionState
|
||||
}
|
||||
|
||||
func (c *proxyContext) TLSState() *tls.ConnectionState {
|
||||
if s, ok := c.Conn.(connectionStater); ok {
|
||||
state := s.ConnectionState()
|
||||
return &state
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// http.ResponseWriter interface:
|
||||
|
||||
func (c *proxyContext) Header() http.Header {
|
||||
if c.res == nil {
|
||||
c.res = NewResponse(http.StatusOK, nil, c.req)
|
||||
}
|
||||
return c.res.Header
|
||||
}
|
||||
|
||||
func (c *proxyContext) WriteHeader(code int) {
|
||||
if c.res == nil {
|
||||
c.res = NewResponse(code, nil, c.req)
|
||||
} else {
|
||||
if text := http.StatusText(code); text != "" {
|
||||
c.res.Status = strconv.Itoa(code) + " " + text
|
||||
} else {
|
||||
c.res.Status = strconv.Itoa(code)
|
||||
}
|
||||
c.res.StatusCode = code
|
||||
}
|
||||
//return c.res.Header.Write(c)
|
||||
}
|
||||
|
||||
var _ Context = (*proxyContext)(nil)
|
Reference in New Issue
Block a user