diff --git a/cmd/dpi-protocol-probe/main.go b/cmd/dpi-protocol-probe/main.go new file mode 100644 index 0000000..f7feaaf --- /dev/null +++ b/cmd/dpi-protocol-probe/main.go @@ -0,0 +1,44 @@ +package main + +import ( + "flag" + "fmt" + "net" + "os" + "time" + + "git.maze.io/go/dpi/protocol" +) + +func main() { + networkFlag := flag.String("network", "tcp", "Network type to use for probing") + timeoutFlag := flag.Duration("timeout", 30*time.Second, "Timeout") + flag.Usage = func() { + fmt.Fprintf(os.Stderr, "Usage of %s []
:\n\n", os.Args[0]) + fmt.Fprintln(os.Stderr, "Available flags:") + flag.PrintDefaults() + fmt.Fprintln(os.Stderr, "\nRequired arguments:") + fmt.Fprintln(os.Stderr, " address string") + fmt.Fprintln(os.Stderr, "\tNetwork address to connect to (ie localhost:22)") + os.Exit(0) + } + flag.Parse() + + if flag.NArg() != 1 { + flag.Usage() + } + + address := flag.Arg(0) + if _, _, err := net.SplitHostPort(address); err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + os.Exit(2) + } + + protocol, confidence, err := protocol.ProbeTimeout(*networkFlag, address, *timeoutFlag) + if err != nil { + fmt.Fprintf(os.Stderr, "Probing address %q failed: %v\n", address, err) + os.Exit(3) + } + + fmt.Printf("Protocol at address %q is %s version %s (confidence %g%%)\n", address, protocol.Name, protocol.Version, confidence*100) +} diff --git a/protocol/probe.go b/protocol/probe.go new file mode 100644 index 0000000..afb6d5a --- /dev/null +++ b/protocol/probe.go @@ -0,0 +1,70 @@ +package protocol + +import ( + "context" + "crypto/tls" + "net" + "time" +) + +// Dialer used by probes to establish a connection. +var Dialer net.Dialer + +// Probe a network service by reading its banner and running protocol detection. +func Probe(network, address string) (proto *Protocol, confidence float64, err error) { + return ProbeContext(context.Background(), network, address) +} + +// ProbeContext is like [Probe] with a [context.Context]. +func ProbeContext(ctx context.Context, network, address string) (proto *Protocol, confidence float64, err error) { + var conn net.Conn + if conn, err = Dialer.DialContext(ctx, network, address); err != nil { + return + } + defer func() { _ = conn.Close() }() + return probeConn(conn) +} + +// ProbeTimeout is like [Probe] but with a fixed timeout. +func ProbeTimeout(network, address string, timeout time.Duration) (proto *Protocol, confidence float64, err error) { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + return ProbeContext(ctx, network, address) +} + +// ProbeTLS is like [Probe] but first establishes a TLS connection. +func ProbeTLS(network, address string, tlsConfig *tls.Config) (proto *Protocol, confidence float64, err error) { + return ProbeTLSContext(context.Background(), network, address, tlsConfig) +} + +// ProbeTLSContext is like [ProbeTLS] with a [context.Context]. +func ProbeTLSContext(ctx context.Context, network, address string, tlsConfig *tls.Config) (proto *Protocol, confidence float64, err error) { + var conn net.Conn + if conn, err = Dialer.DialContext(ctx, network, address); err != nil { + return + } + defer func() { _ = conn.Close() }() + secure := tls.Client(conn, tlsConfig) + if err = secure.Handshake(); err != nil { + return + } + return probeConn(secure) +} + +// ProbeTLSTimeout is like [ProbeTLS] but with a fixed timeout. +func ProbeTLSTimeout(network, address string, tlsConfig *tls.Config, timeout time.Duration) (proto *Protocol, confidence float64, err error) { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + return ProbeTLSContext(ctx, network, address, tlsConfig) +} + +func probeConn(conn net.Conn) (proto *Protocol, confidence float64, err error) { + var ( + data = make([]byte, 2048) + n int + ) + if n, err = conn.Read(data); err != nil { + return + } + return Detect(Client, data[:n], getPortFromAddr(conn.LocalAddr()), getPortFromAddr(conn.RemoteAddr())) +}