Initial import
This commit is contained in:
64
cmd/protodial/main.go
Normal file
64
cmd/protodial/main.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"git.maze.io/go/dpi/protocol"
|
||||
)
|
||||
|
||||
func main() {
|
||||
acceptFlag := flag.String("accept", "", "comma separated list of accepted protocols")
|
||||
flag.Parse()
|
||||
|
||||
if flag.NArg() != 2 {
|
||||
fmt.Fprintf(os.Stderr, "Usage: %s <host> <port>\n", os.Args[0])
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
accept := make(map[string]bool)
|
||||
acceptFlags := strings.Split(*acceptFlag, ",")
|
||||
if len(acceptFlags) == 0 {
|
||||
fmt.Fprintln(os.Stderr, "No -accept was provided, refusing all protocols!")
|
||||
} else {
|
||||
for _, proto := range acceptFlags {
|
||||
accept[proto] = true
|
||||
}
|
||||
}
|
||||
|
||||
c, err := net.Dial("tcp", net.JoinHostPort(flag.Arg(0), flag.Arg(1)))
|
||||
if err != nil {
|
||||
log.Fatalln(err)
|
||||
}
|
||||
|
||||
c = protocol.Limit(c, func(dir protocol.Direction, p *protocol.Protocol) error {
|
||||
if p == nil {
|
||||
return errors.New("No protocol detected")
|
||||
}
|
||||
if !accept[p.Name] {
|
||||
return fmt.Errorf("Protocol %s is not accepted", p.Name)
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "Accepting protocol %s version %s initiated by %s\n",
|
||||
p.Name, p.Version, dir)
|
||||
return nil
|
||||
})
|
||||
defer c.Close()
|
||||
|
||||
var wait sync.WaitGroup
|
||||
wait.Go(func() { multiplex(c, os.Stdin) })
|
||||
wait.Go(func() { multiplex(os.Stdout, c) })
|
||||
wait.Wait()
|
||||
}
|
||||
|
||||
func multiplex(w io.Writer, r io.Reader) {
|
||||
if _, err := io.Copy(w, r); err != nil && !errors.Is(err, io.EOF) {
|
||||
log.Fatalln("Copy terminated:", err)
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user