65 lines
1.4 KiB
Go
65 lines
1.4 KiB
Go
package main
|
|
|
|
import (
|
|
"errors"
|
|
"flag"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"net"
|
|
"os"
|
|
"strings"
|
|
"sync"
|
|
|
|
"git.maze.io/go/dpi/protocol"
|
|
)
|
|
|
|
func main() {
|
|
acceptFlag := flag.String("accept", "", "comma separated list of accepted protocols")
|
|
flag.Parse()
|
|
|
|
if flag.NArg() != 2 {
|
|
fmt.Fprintf(os.Stderr, "Usage: %s <host> <port>\n", os.Args[0])
|
|
os.Exit(1)
|
|
}
|
|
|
|
accept := make(map[string]bool)
|
|
acceptFlags := strings.Split(*acceptFlag, ",")
|
|
if len(acceptFlags) == 0 {
|
|
fmt.Fprintln(os.Stderr, "No -accept was provided, refusing all protocols!")
|
|
} else {
|
|
for _, proto := range acceptFlags {
|
|
accept[proto] = true
|
|
}
|
|
}
|
|
|
|
c, err := net.Dial("tcp", net.JoinHostPort(flag.Arg(0), flag.Arg(1)))
|
|
if err != nil {
|
|
log.Fatalln(err)
|
|
}
|
|
|
|
c = protocol.Limit(c, func(dir protocol.Direction, p *protocol.Protocol) error {
|
|
if p == nil {
|
|
return errors.New("No protocol detected")
|
|
}
|
|
if !accept[p.Name] {
|
|
return fmt.Errorf("Protocol %s is not accepted", p.Name)
|
|
}
|
|
fmt.Fprintf(os.Stderr, "Accepting protocol %s version %s initiated by %s\n",
|
|
p.Name, p.Version, dir)
|
|
return nil
|
|
})
|
|
defer c.Close()
|
|
|
|
var wait sync.WaitGroup
|
|
wait.Go(func() { multiplex(c, os.Stdin) })
|
|
wait.Go(func() { multiplex(os.Stdout, c) })
|
|
wait.Wait()
|
|
}
|
|
|
|
// multiplex copies everything from r into w until r is exhausted.
//
// Running out of input is a normal end of stream. io.Copy itself reports a
// plain io.EOF as success. The errors.Is check additionally tolerates a reader
// that surfaces io.EOF wrapped inside another error. Any other copy failure
// terminates the program via log.Fatalln.
func multiplex(w io.Writer, r io.Reader) {
	_, copyErr := io.Copy(w, r)
	if copyErr == nil || errors.Is(copyErr, io.EOF) {
		return // clean end of stream
	}
	log.Fatalln("Copy terminated:", copyErr)
}
|