Browse Source

Upgrade to API version 52

maze 6 months ago
parent
commit
d3c39ef977
7 changed files with 298 additions and 258 deletions
  1. 0
    161
      cmd/filter-rbl/main.go
  2. 172
    0
      cmd/table-rbl/main.go
  3. 13
    13
      filter.go
  4. 2
    2
      filter_test.go
  5. 53
    51
      imsg.go
  6. 16
    12
      smtpd.go
  7. 42
    19
      table.go

+ 0
- 161
cmd/filter-rbl/main.go View File

@@ -1,161 +0,0 @@
1
-package main
2
-
3
-import (
4
-	"flag"
5
-	"fmt"
6
-	"log"
7
-	"net"
8
-	"os"
9
-	"strings"
10
-
11
-	lru "github.com/hashicorp/golang-lru"
12
-
13
-	"gopkg.in/opensmtpd.v0"
14
-)
15
-
16
-var (
17
-	prog = os.Args[0]
18
-	skip = []*net.IPNet{}
19
-	rbls = []string{
20
-		"b.barracudacentral.org",
21
-		"bl.spamcop.net",
22
-		"virbl.bit.nl",
23
-		"xbl.spamhaus.org",
24
-	}
25
-	debug bool
26
-	masq  bool
27
-	cache *lru.Cache
28
-)
29
-
30
-func debugf(fmt string, args ...interface{}) {
31
-	if !debug {
32
-		return
33
-	}
34
-	log.Printf("debug: "+fmt, args...)
35
-}
36
-
37
-func reverse(ip net.IP) string {
38
-	if ip.To4() == nil {
39
-		return ""
40
-	}
41
-
42
-	splitAddress := strings.Split(ip.String(), ".")
43
-
44
-	for i, j := 0, len(splitAddress)-1; i < len(splitAddress)/2; i, j = i+1, j-1 {
45
-		splitAddress[i], splitAddress[j] = splitAddress[j], splitAddress[i]
46
-	}
47
-
48
-	return strings.Join(splitAddress, ".")
49
-}
50
-
51
-func lookup(rbl string, host string) (result string, listed bool, err error) {
52
-	host = fmt.Sprintf("%s.%s", host, rbl)
53
-
54
-	var res []string
55
-	res, err = net.LookupHost(host)
56
-	if listed = len(res) > 0; listed {
57
-		txt, _ := net.LookupTXT(host)
58
-		if len(txt) > 0 {
59
-			result = txt[0]
60
-		}
61
-	}
62
-
63
-	// Expected error
64
-	if err != nil && strings.HasSuffix(err.Error(), ": no such host") {
65
-		err = nil
66
-	}
67
-
68
-	return
69
-}
70
-
71
-func onConnect(s *opensmtpd.Session, query *opensmtpd.ConnectQuery) error {
72
-	ip := query.Remote.(opensmtpd.Sockaddr).IP()
73
-	if ip == nil {
74
-		return nil
75
-	}
76
-
77
-	debugf("%s: connect from %s\n", prog, ip)
78
-
79
-	for _, ipnet := range skip {
80
-		if ipnet.Contains(ip) {
81
-			debugf("%s: skip %s, IP ignored", prog, ip)
82
-			return s.Accept()
83
-		}
84
-	}
85
-
86
-	var (
87
-		result string
88
-		listed bool
89
-		host   = reverse(ip)
90
-		err    error
91
-	)
92
-	for _, rbl := range rbls {
93
-		if result, listed, err = lookup(rbl, host); err != nil {
94
-			log.Printf("%s: %s failed %s: %v\n", prog, rbl, ip, err)
95
-		} else if listed {
96
-			log.Printf("%s: %s listed %s: %v\n", prog, rbl, ip, result)
97
-			cache.Add(s.ID, result)
98
-			break
99
-		}
100
-	}
101
-
102
-	debugf("%s: pass: %s\n", prog, ip)
103
-
104
-	if !listed {
105
-		// Add negative hit
106
-		cache.Add(s.ID, "")
107
-	}
108
-
109
-	return s.Accept()
110
-}
111
-
112
-func onDATA(s *opensmtpd.Session) error {
113
-	debugf("%s: %s DATA\n", prog, s)
114
-
115
-	if result, block := cache.Get(s.ID); block && result.(string) != "" {
116
-		return s.RejectCode(opensmtpd.FilterClose, 421, result.(string))
117
-	}
118
-
119
-	return s.Accept()
120
-}
121
-
122
-func main() {
123
-	cacheSize := flag.Int("cache-size", 1024, "LRU cache size")
124
-	rblServer := flag.String("servers", strings.Join(rbls, ","), "RBL servers")
125
-	ignoreIPs := flag.String("ignore", "127.0.0.0/8,10.0.0.0/8,172.16.0.0/12,192.168.0.0/16,fe80::/64", "ignore IPs")
126
-	debugging := flag.Bool("d", false, "be verbose")
127
-	verbosity := flag.Bool("v", false, "be verbose")
128
-	flag.BoolVar(&masq, "masq", true, "masquerade SMTP banner")
129
-	flag.Parse()
130
-
131
-	debug = *debugging || *verbosity
132
-
133
-	var err error
134
-	if cache, err = lru.New(*cacheSize); err != nil {
135
-		log.Fatalln(err)
136
-	}
137
-
138
-	rbls = strings.Split(*rblServer, ",")
139
-
140
-	for _, prefix := range strings.Split(*ignoreIPs, ",") {
141
-		var ipnet *net.IPNet
142
-		if _, ipnet, err = net.ParseCIDR(prefix); err != nil {
143
-			log.Fatalln(err)
144
-		}
145
-		skip = append(skip, ipnet)
146
-		debugf("ignore: %s\n", ipnet)
147
-	}
148
-
149
-	filter := &opensmtpd.Filter{
150
-		Connect: onConnect,
151
-		DATA:    onDATA,
152
-	}
153
-
154
-	if err = filter.Register(); err != nil {
155
-		log.Fatalln(err)
156
-	}
157
-
158
-	if err = filter.Serve(); err != nil {
159
-		log.Fatalln(err)
160
-	}
161
-}

+ 172
- 0
cmd/table-rbl/main.go View File

@@ -0,0 +1,172 @@
1
+package main
2
+
3
+import (
4
+	"flag"
5
+	"fmt"
6
+	"io/ioutil"
7
+	"log"
8
+	"net"
9
+	"os"
10
+	"strings"
11
+
12
+	lru "github.com/hashicorp/golang-lru"
13
+	"github.com/hashicorp/hcl"
14
+	opensmtpd "gopkg.in/opensmtpd.v0"
15
+)
16
+
17
+var (
18
+	cache   *lru.Cache
19
+	ignored []*net.IPNet
20
+	config  struct {
21
+		Cache  int
22
+		Ignore []string
23
+		Accept []string
24
+		Reject []string
25
+	}
26
+)
27
+
28
+func debugf(format string, args ...interface{}) {
29
+	log.Printf("debug: "+format, args...)
30
+}
31
+
32
+func update() (int, error) {
33
+	log.Println("table-rbl: update")
34
+	return 1, nil
35
+}
36
+
37
+func reverse(ip net.IP) net.IP {
38
+	log.Printf("ip: %#+v", ip)
39
+	return net.IP{ip[3], ip[2], ip[1], ip[0]}
40
+}
41
+
42
+func lookup(rbl string, host net.IP) (result string, listed bool, err error) {
43
+	var (
44
+		query   = fmt.Sprintf("%s.%s", host, rbl)
45
+		results []string
46
+	)
47
+	log.Printf("table-rbl: lookup %q", query)
48
+	if results, err = net.LookupHost(query); err != nil {
49
+		if strings.HasSuffix(err.Error(), ": no such host") {
50
+			err = nil
51
+		}
52
+		return
53
+	}
54
+
55
+	if listed = len(results) > 0; listed {
56
+		txts, _ := net.LookupTXT(query)
57
+		if len(txts) > 0 {
58
+			result = txts[0]
59
+		}
60
+	}
61
+	return
62
+}
63
+
64
+func check(service int, params opensmtpd.Dict, key string) (int, error) {
65
+	log.Printf("table-rbl: check key=%q", key)
66
+	if key == "local" {
67
+		return 1, nil
68
+	}
69
+
70
+	ips, err := net.LookupIP(key)
71
+	if err != nil {
72
+		log.Printf("table-rbl: error looking up %q: %v", key, err)
73
+		return -1, err
74
+	}
75
+
76
+	for _, ip := range ips {
77
+		if ip = ip.To4(); ip == nil {
78
+			continue
79
+		}
80
+		log.Printf("table-rbl: %q resolved to %s (%s)", key, ip, reverse(ip))
81
+		for _, network := range ignored {
82
+			if network.Contains(ip) {
83
+				log.Printf("table-rbl: %s is ignored", ip)
84
+				return 1, nil
85
+			}
86
+		}
87
+
88
+		if result, block := cache.Get(key); block && result.(string) != "" {
89
+			log.Printf("table-rbl: reject %s (reason %q)", ip, result)
90
+			return 0, nil
91
+		}
92
+
93
+		var (
94
+			result string
95
+			listed bool
96
+			host   = reverse(ip)
97
+			err    error
98
+		)
99
+		for _, rbl := range config.Accept {
100
+			if result, listed, err = lookup(rbl, host); err != nil {
101
+				log.Printf("table-rbl: error looking up %q in %q: %v", host, rbl, err)
102
+				return -1, nil
103
+			} else if listed {
104
+				log.Printf("table-rbl: accept %q (reason %q)", ip, result)
105
+				return 1, nil
106
+			}
107
+		}
108
+		for _, rbl := range config.Reject {
109
+			if result, listed, err = lookup(rbl, host); err != nil {
110
+				log.Printf("table-rbl: error looking up %q in %q: %v", host, rbl, err)
111
+				return -1, nil
112
+			} else if listed {
113
+				log.Printf("table-rbl: reject %q (reason %q)", ip, result)
114
+				return 0, nil
115
+			}
116
+		}
117
+	}
118
+
119
+	log.Printf("table-rbl: accept %q (not rejected)", key)
120
+	return 1, nil
121
+}
122
+
123
+func main() {
124
+	flag.Parse()
125
+	if flag.NArg() != 1 {
126
+		panic(fmt.Sprintf("%s <config>\n", os.Args[0]))
127
+	}
128
+	log.Printf("table-rbl: args=%v", flag.Args())
129
+
130
+	b, err := ioutil.ReadFile(flag.Arg(0))
131
+	if err != nil {
132
+		log.Fatalln("table-rbl", err)
133
+	}
134
+	if err = hcl.Unmarshal(b, &config); err != nil {
135
+		log.Fatalln("table-rbl", err)
136
+	}
137
+	if len(config.Reject) == 0 {
138
+		log.Fatalln("table-rbl: no reject rules configured")
139
+	}
140
+
141
+	// Setup cache
142
+	if config.Cache == 0 {
143
+		cache, err = lru.New(1024)
144
+	} else {
145
+		cache, err = lru.New(config.Cache)
146
+	}
147
+	if err != nil {
148
+		log.Fatalln("table-rbl", err)
149
+	}
150
+
151
+	// Parse ignore rules
152
+	for _, prefix := range config.Ignore {
153
+		var ipnet *net.IPNet
154
+		if _, ipnet, err = net.ParseCIDR(prefix); err != nil {
155
+			panic(err)
156
+		}
157
+		ignored = append(ignored, ipnet)
158
+		debugf("ignore %s", ipnet)
159
+	}
160
+
161
+	opensmtpd.Debug = true
162
+
163
+	table := &opensmtpd.Table{
164
+		Update: update,
165
+		Check:  check,
166
+		Close: func() error {
167
+			log.Println("table-rbl: close")
168
+			return nil
169
+		},
170
+	}
171
+	log.Fatalln(table.Serve())
172
+}

+ 13
- 13
filter.go View File

@@ -178,7 +178,7 @@ type Filter struct {
178 178
 	Version uint32
179 179
 
180 180
 	c net.Conn
181
-	m *message
181
+	m *Message
182 182
 
183 183
 	hooks   int
184 184
 	flags   int
@@ -190,7 +190,7 @@ type Filter struct {
190 190
 func (f *Filter) Register() error {
191 191
 	var err error
192 192
 	if f.m == nil {
193
-		f.m = new(message)
193
+		f.m = new(Message)
194 194
 	}
195 195
 	if f.c == nil {
196 196
 		if f.c, err = newConn(0); err != nil {
@@ -230,13 +230,13 @@ func (f *Filter) Register() error {
230 230
 		f.hooks |= hookCommit
231 231
 	}
232 232
 
233
-	if t, ok := filterTypeName[f.m.Type]; ok {
233
+	if t, ok := filterTypeName[f.m.Header.Type]; ok {
234 234
 		log.Printf("filter: imsg %s\n", t)
235 235
 	} else {
236
-		log.Printf("filter: imsg UNKNOWN %d\n", f.m.Type)
236
+		log.Printf("filter: imsg UNKNOWN %d\n", f.m.Header.Type)
237 237
 	}
238 238
 
239
-	switch f.m.Type {
239
+	switch f.m.Header.Type {
240 240
 	case typeFilterRegister:
241 241
 		var err error
242 242
 		if f.Version, err = f.m.GetTypeUint32(); err != nil {
@@ -248,14 +248,14 @@ func (f *Filter) Register() error {
248 248
 		log.Printf("register version=%d,name=%q\n", f.Version, f.Name)
249 249
 
250 250
 		f.m.reset()
251
-		f.m.Type = typeFilterRegister
251
+		f.m.Header.Type = typeFilterRegister
252 252
 		f.m.PutTypeInt(f.hooks)
253 253
 		f.m.PutTypeInt(f.flags)
254 254
 		if err = f.m.WriteTo(f.c); err != nil {
255 255
 			return err
256 256
 		}
257 257
 	default:
258
-		return fmt.Errorf("filter: unexpected imsg type=%s\n", filterTypeName[f.m.Type])
258
+		return fmt.Errorf("filter: unexpected imsg type=%s\n", filterTypeName[f.m.Header.Type])
259 259
 	}
260 260
 
261 261
 	f.ready = true
@@ -274,7 +274,7 @@ func (f *Filter) Serve() error {
274 274
 	}
275 275
 
276 276
 	if f.m == nil {
277
-		f.m = new(message)
277
+		f.m = new(Message)
278 278
 	}
279 279
 	if f.session == nil {
280 280
 		if f.session, err = lru.New(1024); err != nil {
@@ -300,13 +300,13 @@ func (f *Filter) Serve() error {
300 300
 }
301 301
 
302 302
 func (f *Filter) handle() (err error) {
303
-	if t, ok := filterTypeName[f.m.Type]; ok {
303
+	if t, ok := filterTypeName[f.m.Header.Type]; ok {
304 304
 		log.Printf("filter: imsg %s\n", t)
305 305
 	} else {
306
-		log.Printf("filter: imsg UNKNOWN %d\n", f.m.Type)
306
+		log.Printf("filter: imsg UNKNOWN %d\n", f.m.Header.Type)
307 307
 	}
308 308
 
309
-	switch f.m.Type {
309
+	switch f.m.Header.Type {
310 310
 	case typeFilterEvent:
311 311
 		if err = f.handleEvent(); err != nil {
312 312
 			return
@@ -488,8 +488,8 @@ func (f *Filter) respond(s *Session, status, code int, line string) error {
488 488
 		return nil
489 489
 	}
490 490
 
491
-	m := new(message)
492
-	m.Type = typeFilterResponse
491
+	m := new(Message)
492
+	m.Header.Type = typeFilterResponse
493 493
 	m.PutTypeID(s.qid)
494 494
 	m.PutTypeInt(s.qtype)
495 495
 	if s.qtype == queryEOM {

+ 2
- 2
filter_test.go View File

@@ -7,7 +7,7 @@ func ExampleFilter() {
7 7
 	filter := &Filter{
8 8
 		HELO: func(session *Session, helo string) error {
9 9
 			if helo == "test" {
10
-				return session.Reject()
10
+				return session.Reject(FilterOK, 0)
11 11
 			}
12 12
 			return session.Accept()
13 13
 		},
@@ -16,7 +16,7 @@ func ExampleFilter() {
16 16
 	// Add another hook
17 17
 	filter.MAIL = func(session *Session, user, domain string) error {
18 18
 		if strings.ToLower(domain) == "example.org" {
19
-			return session.Reject()
19
+			return session.Reject(FilterOK, 0)
20 20
 		}
21 21
 		return session.Accept()
22 22
 	}

+ 53
- 51
imsg.go View File

@@ -20,8 +20,8 @@ const (
20 20
 	maxDomainPartSize = (255 + 1)
21 21
 )
22 22
 
23
-// messageHeader is the header of an imsg frame (struct imsg_hdr)
24
-type messageHeader struct {
23
+// MessageHeader is the header of an imsg frame (struct imsg_hdr)
24
+type MessageHeader struct {
25 25
 	Type   uint32
26 26
 	Len    uint16
27 27
 	Flags  uint16
@@ -30,8 +30,10 @@ type messageHeader struct {
30 30
 }
31 31
 
32 32
 // message implements OpenBSD imsg
33
-type message struct {
34
-	messageHeader
33
+type Message struct {
34
+	Header MessageHeader
35
+
36
+	// Data is the Message payload.
35 37
 	Data []byte
36 38
 
37 39
 	// rpos is the read position in the current Data
@@ -41,12 +43,12 @@ type message struct {
41 43
 	buf []byte
42 44
 }
43 45
 
44
-func (m *message) reset() {
45
-	m.Type = 0
46
-	m.Len = 0
47
-	m.Flags = 0
48
-	m.PeerID = imsgVersion
49
-	m.PID = uint32(os.Getpid())
46
+func (m *Message) reset() {
47
+	m.Header.Type = 0
48
+	m.Header.Len = 0
49
+	m.Header.Flags = 0
50
+	m.Header.PeerID = imsgVersion
51
+	m.Header.PID = uint32(os.Getpid())
50 52
 	m.Data = m.Data[:0]
51 53
 	m.rpos = 0
52 54
 	m.buf = m.buf[:0]
@@ -54,22 +56,22 @@ func (m *message) reset() {
54 56
 
55 57
 // ReadFrom reads a message from the specified net.Conn, parses the header and
56 58
 // reads the data payload.
57
-func (m *message) ReadFrom(c net.Conn) error {
59
+func (m *Message) ReadFrom(r io.Reader) error {
58 60
 	m.reset()
59 61
 
60 62
 	head := make([]byte, imsgHeaderSize)
61
-	if _, err := c.Read(head); err != nil {
63
+	if _, err := r.Read(head); err != nil {
62 64
 		return err
63 65
 	}
64 66
 
65
-	r := bytes.NewBuffer(head)
66
-	if err := binary.Read(r, binary.LittleEndian, &m.messageHeader); err != nil {
67
+	buf := bytes.NewBuffer(head)
68
+	if err := binary.Read(buf, binary.LittleEndian, &m.Header); err != nil {
67 69
 		return err
68 70
 	}
69
-	debugf("imsg header: %+v\n", m.messageHeader)
71
+	debugf("imsg header: %+v\n", m.Header)
70 72
 
71
-	data := make([]byte, m.messageHeader.Len-imsgHeaderSize)
72
-	if _, err := c.Read(data); err != nil {
73
+	data := make([]byte, m.Header.Len-imsgHeaderSize)
74
+	if _, err := r.Read(data); err != nil {
73 75
 		return err
74 76
 	}
75 77
 	m.Data = data
@@ -79,22 +81,22 @@ func (m *message) ReadFrom(c net.Conn) error {
79 81
 }
80 82
 
81 83
 // WriteTo marshals the message to wire format and sends it to the net.Conn
82
-func (m *message) WriteTo(c net.Conn) error {
83
-	m.Len = uint16(len(m.Data)) + imsgHeaderSize
84
+func (m *Message) WriteTo(w io.Writer) error {
85
+	m.Header.Len = uint16(len(m.Data)) + imsgHeaderSize
84 86
 
85 87
 	buf := new(bytes.Buffer)
86
-	debugf("imsg header: %+v\n", m.messageHeader)
87
-	if err := binary.Write(buf, binary.LittleEndian, &m.messageHeader); err != nil {
88
+	debugf("imsg header: %+v\n", m.Header)
89
+	if err := binary.Write(buf, binary.LittleEndian, &m.Header); err != nil {
88 90
 		return err
89 91
 	}
90 92
 	buf.Write(m.Data)
91 93
 	debugf("imsg send: %d / %q\n", buf.Len(), buf.Bytes())
92 94
 
93
-	_, err := c.Write(buf.Bytes())
95
+	_, err := w.Write(buf.Bytes())
94 96
 	return err
95 97
 }
96 98
 
97
-func (m *message) GetInt() (int, error) {
99
+func (m *Message) GetInt() (int, error) {
98 100
 	if m.rpos+4 > len(m.Data) {
99 101
 		return 0, io.ErrShortBuffer
100 102
 	}
@@ -103,7 +105,7 @@ func (m *message) GetInt() (int, error) {
103 105
 	return int(i), nil
104 106
 }
105 107
 
106
-func (m *message) GetUint32() (uint32, error) {
108
+func (m *Message) GetUint32() (uint32, error) {
107 109
 	if m.rpos+4 > len(m.Data) {
108 110
 		return 0, io.ErrShortBuffer
109 111
 	}
@@ -112,7 +114,7 @@ func (m *message) GetUint32() (uint32, error) {
112 114
 	return u, nil
113 115
 }
114 116
 
115
-func (m *message) GetSize() (uint64, error) {
117
+func (m *Message) GetSize() (uint64, error) {
116 118
 	if m.rpos+8 > len(m.Data) {
117 119
 		return 0, io.ErrShortBuffer
118 120
 	}
@@ -121,7 +123,7 @@ func (m *message) GetSize() (uint64, error) {
121 123
 	return u, nil
122 124
 }
123 125
 
124
-func (m *message) GetString() (string, error) {
126
+func (m *Message) GetString() (string, error) {
125 127
 	o := bytes.IndexByte(m.Data[m.rpos:], 0)
126 128
 	if o < 0 {
127 129
 		return "", errors.New("imsg: string not NULL-terminated")
@@ -132,7 +134,7 @@ func (m *message) GetString() (string, error) {
132 134
 	return s, nil
133 135
 }
134 136
 
135
-func (m *message) GetID() (uint64, error) {
137
+func (m *Message) GetID() (uint64, error) {
136 138
 	if m.rpos+8 > len(m.Data) {
137 139
 		return 0, io.ErrShortBuffer
138 140
 	}
@@ -174,7 +176,7 @@ func (sa Sockaddr) String() string {
174 176
 	return fmt.Sprintf("%s:%d", sa.IP(), sa.Port())
175 177
 }
176 178
 
177
-func (m *message) GetSockaddr() (net.Addr, error) {
179
+func (m *Message) GetSockaddr() (net.Addr, error) {
178 180
 	s, err := m.GetSize()
179 181
 	if err != nil {
180 182
 		return nil, err
@@ -190,7 +192,7 @@ func (m *message) GetSockaddr() (net.Addr, error) {
190 192
 	return a, nil
191 193
 }
192 194
 
193
-func (m *message) GetMailaddr() (user, domain string, err error) {
195
+func (m *Message) GetMailaddr() (user, domain string, err error) {
194 196
 	var buf [maxLocalPartSize + maxDomainPartSize]byte
195 197
 	if maxLocalPartSize+maxDomainPartSize > len(m.Data[m.rpos:]) {
196 198
 		return "", "", io.ErrShortBuffer
@@ -202,7 +204,7 @@ func (m *message) GetMailaddr() (user, domain string, err error) {
202 204
 	return
203 205
 }
204 206
 
205
-func (m *message) GetType(t uint8) error {
207
+func (m *Message) GetType(t uint8) error {
206 208
 	if m.rpos >= len(m.Data) {
207 209
 		return io.ErrShortBuffer
208 210
 	}
@@ -215,102 +217,102 @@ func (m *message) GetType(t uint8) error {
215 217
 	return nil
216 218
 }
217 219
 
218
-func (m *message) GetTypeInt() (int, error) {
220
+func (m *Message) GetTypeInt() (int, error) {
219 221
 	if err := m.GetType(mINT); err != nil {
220 222
 		return 0, err
221 223
 	}
222 224
 	return m.GetInt()
223 225
 }
224 226
 
225
-func (m *message) GetTypeUint32() (uint32, error) {
227
+func (m *Message) GetTypeUint32() (uint32, error) {
226 228
 	if err := m.GetType(mUINT32); err != nil {
227 229
 		return 0, err
228 230
 	}
229 231
 	return m.GetUint32()
230 232
 }
231 233
 
232
-func (m *message) GetTypeSize() (uint64, error) {
234
+func (m *Message) GetTypeSize() (uint64, error) {
233 235
 	if err := m.GetType(mSIZET); err != nil {
234 236
 		return 0, err
235 237
 	}
236 238
 	return m.GetSize()
237 239
 }
238 240
 
239
-func (m *message) GetTypeString() (string, error) {
241
+func (m *Message) GetTypeString() (string, error) {
240 242
 	if err := m.GetType(mSTRING); err != nil {
241 243
 		return "", err
242 244
 	}
243 245
 	return m.GetString()
244 246
 }
245 247
 
246
-func (m *message) GetTypeID() (uint64, error) {
248
+func (m *Message) GetTypeID() (uint64, error) {
247 249
 	if err := m.GetType(mID); err != nil {
248 250
 		return 0, err
249 251
 	}
250 252
 	return m.GetID()
251 253
 }
252 254
 
253
-func (m *message) GetTypeSockaddr() (net.Addr, error) {
255
+func (m *Message) GetTypeSockaddr() (net.Addr, error) {
254 256
 	if err := m.GetType(mSOCKADDR); err != nil {
255 257
 		return nil, err
256 258
 	}
257 259
 	return m.GetSockaddr()
258 260
 }
259 261
 
260
-func (m *message) GetTypeMailaddr() (user, domain string, err error) {
262
+func (m *Message) GetTypeMailaddr() (user, domain string, err error) {
261 263
 	if err = m.GetType(mMAILADDR); err != nil {
262 264
 		return
263 265
 	}
264 266
 	return m.GetMailaddr()
265 267
 }
266 268
 
267
-func (m *message) PutInt(v int) {
269
+func (m *Message) PutInt(v int) {
268 270
 	var b [4]byte
269 271
 	binary.LittleEndian.PutUint32(b[:], uint32(v))
270 272
 	m.Data = append(m.Data, b[:]...)
271
-	m.Len += 4
273
+	m.Header.Len += 4
272 274
 }
273 275
 
274
-func (m *message) PutUint32(v uint32) {
276
+func (m *Message) PutUint32(v uint32) {
275 277
 	var b [4]byte
276 278
 	binary.LittleEndian.PutUint32(b[:], v)
277 279
 	m.Data = append(m.Data, b[:]...)
278
-	m.Len += 4
280
+	m.Header.Len += 4
279 281
 }
280 282
 
281
-func (m *message) PutString(s string) {
283
+func (m *Message) PutString(s string) {
282 284
 	m.Data = append(m.Data, append([]byte(s), 0)...)
283
-	m.Len += uint16(len(s)) + 1
285
+	m.Header.Len += uint16(len(s)) + 1
284 286
 }
285 287
 
286
-func (m *message) PutID(id uint64) {
288
+func (m *Message) PutID(id uint64) {
287 289
 	var b [8]byte
288 290
 	binary.LittleEndian.PutUint64(b[:], id)
289 291
 	m.Data = append(m.Data, b[:]...)
290
-	m.Len += 8
292
+	m.Header.Len += 8
291 293
 }
292 294
 
293
-func (m *message) PutType(t uint8) {
295
+func (m *Message) PutType(t uint8) {
294 296
 	m.Data = append(m.Data, t)
295
-	m.Len += 1
297
+	m.Header.Len += 1
296 298
 }
297 299
 
298
-func (m *message) PutTypeInt(v int) {
300
+func (m *Message) PutTypeInt(v int) {
299 301
 	m.PutType(mINT)
300 302
 	m.PutInt(v)
301 303
 }
302 304
 
303
-func (m *message) PutTypeUint32(v uint32) {
305
+func (m *Message) PutTypeUint32(v uint32) {
304 306
 	m.PutType(mUINT32)
305 307
 	m.PutUint32(v)
306 308
 }
307 309
 
308
-func (m *message) PutTypeString(s string) {
310
+func (m *Message) PutTypeString(s string) {
309 311
 	m.PutType(mSTRING)
310 312
 	m.PutString(s)
311 313
 }
312 314
 
313
-func (m *message) PutTypeID(id uint64) {
315
+func (m *Message) PutTypeID(id uint64) {
314 316
 	m.PutType(mID)
315 317
 	m.PutID(id)
316 318
 }

+ 16
- 12
smtpd.go View File

@@ -8,13 +8,13 @@ import (
8 8
 
9 9
 const (
10 10
 	// FilterVersion is the supported filter API version
11
-	FilterVersion = 51
11
+	FilterVersion = 52
12 12
 
13 13
 	// QueueVersion is the supported queue API version
14
-	QueueVersion = 1
14
+	QueueVersion = 2
15 15
 
16 16
 	// TableVersion is the supported table API version
17
-	TableVersion = 1
17
+	TableVersion = 2
18 18
 )
19 19
 
20 20
 var (
@@ -36,19 +36,23 @@ const (
36 36
 	ServiceMailaddr    = 0x040
37 37
 	ServiceAddrname    = 0x080
38 38
 	ServiceMailaddrMap = 0x100
39
+	ServiceRelayHost   = 0x200
40
+	ServiceString      = 0x400
39 41
 	ServiceAny         = 0xfff
40 42
 )
41 43
 
42 44
 var serviceTypeName = map[int]string{
43
-	ServiceAlias:       "ALIAS",
44
-	ServiceDomain:      "DOMAIN",
45
-	ServiceCredentials: "CREDENTIALS",
46
-	ServiceNetaddr:     "NETADDR",
47
-	ServiceUserinfo:    "USERINFO",
48
-	ServiceSource:      "SOURCE",
49
-	ServiceMailaddr:    "MAILADDR",
50
-	ServiceAddrname:    "ADDRNAME",
51
-	ServiceMailaddrMap: "MAILADDRMAP",
45
+	ServiceAlias:       "alias",
46
+	ServiceDomain:      "domain",
47
+	ServiceCredentials: "credentials",
48
+	ServiceNetaddr:     "netaddr",
49
+	ServiceUserinfo:    "userinfo",
50
+	ServiceSource:      "source",
51
+	ServiceMailaddr:    "mailaddr",
52
+	ServiceAddrname:    "addrname",
53
+	ServiceMailaddrMap: "maddrmap",
54
+	ServiceRelayHost:   "relayhost",
55
+	ServiceString:      "string",
52 56
 }
53 57
 
54 58
 func serviceName(service int) string {

+ 42
- 19
table.go View File

@@ -39,6 +39,21 @@ func procTableName(t uint32) string {
39 39
 	return fmt.Sprintf("UNKNOWN %d", t)
40 40
 }
41 41
 
42
+// Table services.
43
+const (
44
+	TableAlias       = 0x001 /* returns struct expand	*/
45
+	TableDomain      = 0x002 /* returns struct destination	*/
46
+	TableCredentials = 0x004 /* returns struct credentials	*/
47
+	TableNetAddr     = 0x008 /* returns struct netaddr	*/
48
+	TableUserInfo    = 0x010 /* returns struct userinfo	*/
49
+	TableSource      = 0x020 /* returns struct source	*/
50
+	TableMailAddr    = 0x040 /* returns struct mailaddr	*/
51
+	TableAddrName    = 0x080 /* returns struct addrname	*/
52
+	TableMailAddrMap = 0x100 /* returns struct maddrmap	*/
53
+	TableRelayHost   = 0x200 /* returns struct relayhost	*/
54
+	TableString      = 0x400
55
+)
56
+
42 57
 // Table implements the OpenSMTPD table API
43 58
 type Table struct {
44 59
 	// Update callback
@@ -57,7 +72,7 @@ type Table struct {
57 72
 	Close func() error
58 73
 
59 74
 	c      net.Conn
60
-	m      *message
75
+	m      *Message
61 76
 	closed bool
62 77
 }
63 78
 
@@ -68,17 +83,17 @@ func (t *Table) Serve() error {
68 83
 		return err
69 84
 	}
70 85
 
71
-	t.m = new(message)
86
+	t.m = new(Message)
72 87
 
73 88
 	for !t.closed {
74 89
 		if err = t.m.ReadFrom(t.c); err != nil {
75 90
 			if err.Error() != "resource temporarily unavailable" {
76
-				break
91
+				return fmt.Errorf("read error: %v", err)
77 92
 			}
78 93
 		}
79
-		debugf("table: %s", procTableName(t.m.Type))
94
+		debugf("table: %s", procTableName(t.m.Header.Type))
80 95
 		if err = t.dispatch(); err != nil {
81
-			break
96
+			return fmt.Errorf("dispatch error: %v", err)
82 97
 		}
83 98
 	}
84 99
 
@@ -91,7 +106,7 @@ type tableOpenParams struct {
91 106
 }
92 107
 
93 108
 func (t *Table) dispatch() (err error) {
94
-	switch t.m.Type {
109
+	switch t.m.Header.Type {
95 110
 	case procTableOpen:
96 111
 		/*
97 112
 			var op tableOpenParams
@@ -120,12 +135,12 @@ func (t *Table) dispatch() (err error) {
120 135
 			fatal("table: no name supplied by smtpd!?")
121 136
 		}
122 137
 
123
-		debugf("table: version=%d,name=%q\n", version, name)
138
+		debugf("table: version=%d name=%q\n", version, name)
124 139
 
125
-		m := new(message)
126
-		m.Type = procTableOK
127
-		m.Len = imsgHeaderSize
128
-		m.PID = uint32(os.Getpid())
140
+		m := new(Message)
141
+		m.Header.Type = procTableOK
142
+		m.Header.Len = imsgHeaderSize
143
+		m.Header.PID = uint32(os.Getpid())
129 144
 		if err = m.WriteTo(t.c); err != nil {
130 145
 			return
131 146
 		}
@@ -139,8 +154,8 @@ func (t *Table) dispatch() (err error) {
139 154
 			}
140 155
 		}
141 156
 
142
-		m := new(message)
143
-		m.Type = procTableOK
157
+		m := new(Message)
158
+		m.Header.Type = procTableOK
144 159
 		m.PutInt(r)
145 160
 		if err = m.WriteTo(t.c); err != nil {
146 161
 			return
@@ -184,6 +199,14 @@ func (t *Table) dispatch() (err error) {
184 199
 
185 200
 		log.Printf("table_check: result=%d\n", r)
186 201
 
202
+		m := new(Message)
203
+		m.Header.Type = procTableOK
204
+		m.Header.PID = uint32(os.Getpid())
205
+		m.PutInt(r)
206
+		if err = m.WriteTo(t.c); err != nil {
207
+			return
208
+		}
209
+
187 210
 	case procTableLookup:
188 211
 		var service int
189 212
 		if service, err = t.m.GetInt(); err != nil {
@@ -210,9 +233,9 @@ func (t *Table) dispatch() (err error) {
210 233
 			}
211 234
 		}
212 235
 
213
-		m := new(message)
214
-		m.Type = procTableOK
215
-		m.PID = uint32(os.Getpid())
236
+		m := new(Message)
237
+		m.Header.Type = procTableOK
238
+		m.Header.PID = uint32(os.Getpid())
216 239
 		if val == "" {
217 240
 			m.PutInt(-1)
218 241
 		} else {
@@ -244,9 +267,9 @@ func (t *Table) dispatch() (err error) {
244 267
 			}
245 268
 		}
246 269
 
247
-		m := new(message)
248
-		m.Type = procTableOK
249
-		m.PID = uint32(os.Getpid())
270
+		m := new(Message)
271
+		m.Header.Type = procTableOK
272
+		m.Header.PID = uint32(os.Getpid())
250 273
 		if val == "" {
251 274
 			m.PutInt(-1)
252 275
 		} else {

Loading…
Cancel
Save