package main import ( "errors" "flag" "fmt" "io" "log" "net" "os" "strings" "sync" "git.maze.io/go/dpi/protocol" ) func main() { acceptFlag := flag.String("accept", "", "comma separated list of accepted protocols") flag.Parse() if flag.NArg() != 2 { fmt.Fprintf(os.Stderr, "Usage: %s \n", os.Args[0]) os.Exit(1) } accept := make(map[string]bool) acceptFlags := strings.Split(*acceptFlag, ",") if len(acceptFlags) == 0 { fmt.Fprintln(os.Stderr, "No -accept was provided, refusing all protocols!") } else { for _, proto := range acceptFlags { accept[proto] = true } } c, err := net.Dial("tcp", net.JoinHostPort(flag.Arg(0), flag.Arg(1))) if err != nil { log.Fatalln(err) } c = protocol.Limit(c, func(dir protocol.Direction, p *protocol.Protocol) error { if p == nil { return errors.New("no protocol detected") } if !accept[p.Type] { return fmt.Errorf("protocol %s is not accepted", p.Type) } fmt.Fprintf(os.Stderr, "Accepting protocol %s version %s initiated by %s\n", p.Type, p.Version, dir) return nil }) defer func() { _ = c.Close() }() var wait sync.WaitGroup wait.Go(func() { multiplex(c, os.Stdin) }) wait.Go(func() { multiplex(os.Stdout, c) }) wait.Wait() } func multiplex(w io.Writer, r io.Reader) { if _, err := io.Copy(w, r); err != nil && !errors.Is(err, io.EOF) { log.Fatalln("Copy terminated:", err) } }