Browse Source

Hide less interesting bits of the internal API

tags/v52
Wijnand 2 years ago
parent
commit
75ff870fa8
7 changed files with 235 additions and 224 deletions
  1. 2
    2
      conn.go
  2. 105
    102
      filter.go
  3. 4
    3
      filter_test.go
  4. 50
    50
      imsg.go
  5. 26
    26
      mproc.go
  6. 11
    0
      smtpd.go
  7. 37
    41
      table.go

+ 2
- 2
conn.go View File

@@ -5,8 +5,8 @@ import (
"os"
)

// NewConn wraps a file descriptor to a net.FileConn
func NewConn(fd int) (net.Conn, error) {
// newConn wraps a file descriptor to a net.FileConn
func newConn(fd int) (net.Conn, error) {
f := os.NewFile(uintptr(fd), "")
return net.FileConn(f)
}

+ 105
- 102
filter.go View File

@@ -11,23 +11,19 @@ import (
)

const (
FilterVersion = 51
)

const (
TypeFilterRegister uint32 = iota
TypeFilterEvent
TypeFilterQuery
TypeFilterPipe
TypeFilterResponse
typeFilterRegister uint32 = iota
typeFilterEvent
typeFilterquery
typeFilterPipe
typeFilterResponse
)

var filterTypeName = map[uint32]string{
TypeFilterRegister: "IMSG_FILTER_REGISTER",
TypeFilterEvent: "IMSG_FILTER_EVENT",
TypeFilterQuery: "IMSG_FILTER_QUERY",
TypeFilterPipe: "IMSG_FILTER_PIPE",
TypeFilterResponse: "IMSG_FILTER_RESPONSE",
typeFilterRegister: "IMSG_FILTER_REGISTER",
typeFilterEvent: "IMSG_FILTER_EVENT",
typeFilterquery: "IMSG_FILTER_QUERY",
typeFilterPipe: "IMSG_FILTER_PIPE",
typeFilterResponse: "IMSG_FILTER_RESPONSE",
}

func filterName(t uint32) string {
@@ -38,31 +34,31 @@ func filterName(t uint32) string {
}

const (
HookConnect = 1 << iota
HookHELO
HookMAIL
HookRCPT
HookDATA
HookEOM
HookReset
HookDisconnect
HookCommit
HookRollback
HookDataLine
hookConnect = 1 << iota
hookHELO
hookMAIL
hookRCPT
hookDATA
hookEOM
hookReset
hookDisconnect
hookCommit
hookRollback
hookDataLine
)

var hookTypeName = map[uint16]string{
HookConnect: "HOOK_CONNECT",
HookHELO: "HOOK_HELO",
HookMAIL: "HOOK_MAIL",
HookRCPT: "HOOK_RCPT",
HookDATA: "HOOK_DATA",
HookEOM: "HOOK_EOM",
HookReset: "HOOK_RESET",
HookDisconnect: "HOOK_DISCONNECT",
HookCommit: "HOOK_COMMIT",
HookRollback: "HOOK_ROLLBACK",
HookDataLine: "HOOK_DATALINE",
hookConnect: "HOOK_CONNECT",
hookHELO: "HOOK_HELO",
hookMAIL: "HOOK_MAIL",
hookRCPT: "HOOK_RCPT",
hookDATA: "HOOK_DATA",
hookEOM: "HOOK_EOM",
hookReset: "HOOK_RESET",
hookDisconnect: "HOOK_DISCONNECT",
hookCommit: "HOOK_COMMIT",
hookRollback: "HOOK_ROLLBACK",
hookDataLine: "HOOK_DATALINE",
}

func hookName(h uint16) string {
@@ -76,21 +72,21 @@ func hookName(h uint16) string {
}

const (
EventConnect = iota
EventReset
EventDisconnect
EventTXBegin
EventTXCommit
EventTXRollback
eventConnect = iota
eventReset
eventDisconnect
eventTXBegin
eventTXCommit
eventTXRollback
)

var eventTypeName = map[int]string{
EventConnect: "EVENT_CONNECT",
EventReset: "EVENT_RESET",
EventDisconnect: "EVENT_DISCONNECT",
EventTXBegin: "EVENT_TX_BEGIN",
EventTXCommit: "EVENT_TX_COMMIT",
EventTXRollback: "EVENT_TX_ROLLBACK",
eventConnect: "EVENT_CONNECT",
eventReset: "EVENT_RESET",
eventDisconnect: "EVENT_DISCONNECT",
eventTXBegin: "EVENT_TX_BEGIN",
eventTXCommit: "EVENT_TX_COMMIT",
eventTXRollback: "EVENT_TX_ROLLBACK",
}

func eventName(t int) string {
@@ -101,23 +97,23 @@ func eventName(t int) string {
}

const (
QueryConnect = iota
QueryHELO
QueryMAIL
QueryRCPT
QueryDATA
QueryEOM
QueryDataLine
queryConnect = iota
queryHELO
queryMAIL
queryRCPT
queryDATA
queryEOM
queryDataLine
)

var queryTypeName = map[int]string{
QueryConnect: "QUERY_CONNECT",
QueryHELO: "QUERY_HELO",
QueryMAIL: "QUERY_MAIL",
QueryRCPT: "QUERY_RCPT",
QueryDATA: "QUERY_DATA",
QueryEOM: "QUERY_EOM",
QueryDataLine: "QUERY_DATALINE",
queryConnect: "QUERY_CONNECT",
queryHELO: "QUERY_HELO",
queryMAIL: "QUERY_MAIL",
queryRCPT: "QUERY_RCPT",
queryDATA: "QUERY_DATA",
queryEOM: "QUERY_EOM",
queryDataLine: "QUERY_DATALINE",
}

func queryName(t int) string {
@@ -182,7 +178,7 @@ type Filter struct {
Version uint32

c net.Conn
m *Message
m *message

hooks int
flags int
@@ -192,42 +188,42 @@ type Filter struct {

func (f *Filter) OnConnect(fn func(*Session, *ConnectQuery) error) {
f.Connect = fn
f.hooks |= HookConnect
f.hooks |= hookConnect
}

func (f *Filter) OnHELO(fn func(*Session, string) error) {
f.HELO = fn
f.hooks |= HookHELO
f.hooks |= hookHELO
}

func (f *Filter) OnMAIL(fn func(*Session, string, string) error) {
f.MAIL = fn
f.hooks |= HookMAIL
f.hooks |= hookMAIL
}

func (f *Filter) OnRCPT(fn func(*Session, string, string) error) {
f.RCPT = fn
f.hooks |= HookRCPT
f.hooks |= hookRCPT
}

func (f *Filter) OnDATA(fn func(*Session) error) {
f.DATA = fn
f.hooks |= HookDATA
f.hooks |= hookDATA
}

func (f *Filter) OnDataLine(fn func(*Session, string) error) {
f.DataLine = fn
f.hooks |= HookDataLine
f.hooks |= hookDataLine
}

// Register our filter with OpenSMTPD
func (f *Filter) Register() error {
var err error
if f.m == nil {
f.m = new(Message)
f.m = new(message)
}
if f.c == nil {
if f.c, err = NewConn(0); err != nil {
if f.c, err = newConn(0); err != nil {
return err
}
}
@@ -237,31 +233,31 @@ func (f *Filter) Register() error {

// Fill hooks mask
if f.Connect != nil {
f.hooks |= HookConnect
f.hooks |= hookConnect
}
if f.HELO != nil {
f.hooks |= HookHELO
f.hooks |= hookHELO
}
if f.MAIL != nil {
f.hooks |= HookMAIL
f.hooks |= hookMAIL
}
if f.RCPT != nil {
f.hooks |= HookRCPT
f.hooks |= hookRCPT
}
if f.DATA != nil {
f.hooks |= HookDATA
f.hooks |= hookDATA
}
if f.DataLine != nil {
f.hooks |= HookDataLine
f.hooks |= hookDataLine
}
if f.EOM != nil {
f.hooks |= HookEOM
f.hooks |= hookEOM
}
if f.Disconnect != nil {
f.hooks |= HookDisconnect
f.hooks |= hookDisconnect
}
if f.Commit != nil {
f.hooks |= HookCommit
f.hooks |= hookCommit
}

if t, ok := filterTypeName[f.m.Type]; ok {
@@ -271,7 +267,7 @@ func (f *Filter) Register() error {
}

switch f.m.Type {
case TypeFilterRegister:
case typeFilterRegister:
var err error
if f.Version, err = f.m.GetTypeUint32(); err != nil {
return err
@@ -282,7 +278,7 @@ func (f *Filter) Register() error {
log.Printf("register version=%d,name=%q\n", f.Version, f.Name)

f.m.reset()
f.m.Type = TypeFilterRegister
f.m.Type = typeFilterRegister
f.m.PutTypeInt(f.hooks)
f.m.PutTypeInt(f.flags)
if err = f.m.WriteTo(f.c); err != nil {
@@ -292,6 +288,7 @@ func (f *Filter) Register() error {
return fmt.Errorf("filter: unexpected imsg type=%s\n", filterTypeName[f.m.Type])
}

f.ready = true
return nil
}

@@ -299,8 +296,15 @@ func (f *Filter) Register() error {
// parties closes stdin.
func (f *Filter) Serve() error {
var err error

if !f.ready {
if err = f.Register(); err != nil {
return err
}
}

if f.m == nil {
f.m = new(Message)
f.m = new(message)
}
if f.session == nil {
if f.session, err = lru.New(1024); err != nil {
@@ -308,13 +312,12 @@ func (f *Filter) Serve() error {
}
}
if f.c == nil {
if f.c, err = NewConn(0); err != nil {
if f.c, err = newConn(0); err != nil {
return err
}
}

for {
//log.Printf("fdcount: %d [pid=%d]\n", fdCount(), os.Getpid())
if err := f.m.ReadFrom(f.c); err != nil {
if err.Error() != "resource temporarily unavailable" {
return err
@@ -334,13 +337,13 @@ func (f *Filter) handle() (err error) {
}

switch f.m.Type {
case TypeFilterEvent:
case typeFilterEvent:
if err = f.handleEvent(); err != nil {
return
}

case TypeFilterQuery:
if err = f.handleQuery(); err != nil {
case typeFilterquery:
if err = f.handlequery(); err != nil {
return
}
}
@@ -381,16 +384,16 @@ func (f *Filter) handleEvent() (err error) {
log.Printf("fdcount: %d [pid=%d]\n", fdCount(), os.Getpid())

switch t {
case EventConnect:
case eventConnect:
f.session.Add(id, NewSession(f, id))
case EventDisconnect:
case eventDisconnect:
f.session.Remove(id)
}

return
}

func (f *Filter) handleQuery() (err error) {
func (f *Filter) handlequery() (err error) {
var (
id, qid uint64
t int
@@ -421,7 +424,7 @@ func (f *Filter) handleQuery() (err error) {
s.qid = qid

switch t {
case QueryConnect:
case queryConnect:
var query ConnectQuery
if query.Local, err = f.m.GetTypeSockaddr(); err != nil {
return
@@ -440,7 +443,7 @@ func (f *Filter) handleQuery() (err error) {

log.Printf("filter: WARNING: no connect callback\n")

case QueryHELO:
case queryHELO:
var line string
if line, err = f.m.GetTypeString(); err != nil {
return
@@ -454,7 +457,7 @@ func (f *Filter) handleQuery() (err error) {
log.Printf("filter: WARNING: no HELO callback\n")
return f.respond(s, FilterOK, 0, "")

case QueryMAIL:
case queryMAIL:
var user, domain string
if user, domain, err = f.m.GetTypeMailaddr(); err != nil {
return
@@ -468,7 +471,7 @@ func (f *Filter) handleQuery() (err error) {
log.Printf("filter: WARNING: no MAIL callback\n")
return f.respond(s, FilterOK, 0, "")

case QueryRCPT:
case queryRCPT:
var user, domain string
if user, domain, err = f.m.GetTypeMailaddr(); err != nil {
return
@@ -482,7 +485,7 @@ func (f *Filter) handleQuery() (err error) {
log.Printf("filter: WARNING: no RCPT callback\n")
return f.respond(s, FilterOK, 0, "")

case QueryDATA:
case queryDATA:
if f.DATA != nil {
return f.DATA(s)
}
@@ -490,7 +493,7 @@ func (f *Filter) handleQuery() (err error) {
log.Printf("filter: WARNING: no DATA callback\n")
return f.respond(s, FilterOK, 0, "")

case QueryEOM:
case queryEOM:
var dataLen uint32
if dataLen, err = f.m.GetTypeUint32(); err != nil {
return
@@ -508,18 +511,18 @@ func (f *Filter) handleQuery() (err error) {
}

func (f *Filter) respond(s *Session, status, code int, line string) error {
log.Printf("filter: %s %s [code=%d,line=%q]\n", filterName(TypeFilterResponse), responseName(status), code, line)
log.Printf("filter: %s %s [code=%d,line=%q]\n", filterName(typeFilterResponse), responseName(status), code, line)

if s.qtype == QueryEOM {
if s.qtype == queryEOM {
// Not implemented
return nil
}

m := new(Message)
m.Type = TypeFilterResponse
m := new(message)
m.Type = typeFilterResponse
m.PutTypeID(s.qid)
m.PutTypeInt(s.qtype)
if s.qtype == QueryEOM {
if s.qtype == queryEOM {
// Not imlemented
return nil
}

+ 4
- 3
filter_test.go View File

@@ -14,14 +14,15 @@ func ExampleFilter() {
}

// Add another hook
filter.OnMAIL(func(session *Session, user, domain string) error {
filter.MAIL = func(session *Session, user, domain string) error {
if strings.ToLower(domain) == "example.org" {
return session.Reject()
}
return session.Accept()
})
}

// Register our filter with smtpd
// Register our filter with smtpd. This step is optional and will
// be performed by Serve() if omitted.
if err := filter.Register(); err != nil {
panic(err)
}

+ 50
- 50
imsg.go View File

@@ -20,8 +20,8 @@ const (
maxDomainPartSize = (255 + 1)
)

// MessageHeader is the header of an imsg frame (struct imsg_hdr)
type MessageHeader struct {
// messageHeader is the header of an imsg frame (struct imsg_hdr)
type messageHeader struct {
Type uint32
Len uint16
Flags uint16
@@ -29,9 +29,9 @@ type MessageHeader struct {
PID uint32
}

// Message implements OpenBSD imsg
type Message struct {
MessageHeader
// message implements OpenBSD imsg
type message struct {
messageHeader
Data []byte

// rpos is the read position in the current Data
@@ -41,7 +41,7 @@ type Message struct {
buf []byte
}

func (m *Message) reset() {
func (m *message) reset() {
m.Type = 0
m.Len = 0
m.Flags = 0
@@ -54,7 +54,7 @@ func (m *Message) reset() {

// ReadFrom reads a message from the specified net.Conn, parses the header and
// reads the data payload.
func (m *Message) ReadFrom(c net.Conn) error {
func (m *message) ReadFrom(c net.Conn) error {
m.reset()

head := make([]byte, imsgHeaderSize)
@@ -63,12 +63,12 @@ func (m *Message) ReadFrom(c net.Conn) error {
}

r := bytes.NewBuffer(head)
if err := binary.Read(r, binary.LittleEndian, &m.MessageHeader); err != nil {
if err := binary.Read(r, binary.LittleEndian, &m.messageHeader); err != nil {
return err
}
debugf("imsg header: %+v\n", m.MessageHeader)
debugf("imsg header: %+v\n", m.messageHeader)

data := make([]byte, m.MessageHeader.Len-imsgHeaderSize)
data := make([]byte, m.messageHeader.Len-imsgHeaderSize)
if _, err := c.Read(data); err != nil {
return err
}
@@ -78,13 +78,13 @@ func (m *Message) ReadFrom(c net.Conn) error {
return nil
}

// WriteTo marshals the Message to wire format and sends it to the net.Conn
func (m *Message) WriteTo(c net.Conn) error {
// WriteTo marshals the message to wire format and sends it to the net.Conn
func (m *message) WriteTo(c net.Conn) error {
m.Len = uint16(len(m.Data)) + imsgHeaderSize

buf := new(bytes.Buffer)
debugf("imsg header: %+v\n", m.MessageHeader)
if err := binary.Write(buf, binary.LittleEndian, &m.MessageHeader); err != nil {
debugf("imsg header: %+v\n", m.messageHeader)
if err := binary.Write(buf, binary.LittleEndian, &m.messageHeader); err != nil {
return err
}
buf.Write(m.Data)
@@ -94,7 +94,7 @@ func (m *Message) WriteTo(c net.Conn) error {
return err
}

func (m *Message) GetInt() (int, error) {
func (m *message) GetInt() (int, error) {
if m.rpos+4 > len(m.Data) {
return 0, io.ErrShortBuffer
}
@@ -103,7 +103,7 @@ func (m *Message) GetInt() (int, error) {
return int(i), nil
}

func (m *Message) GetUint32() (uint32, error) {
func (m *message) GetUint32() (uint32, error) {
if m.rpos+4 > len(m.Data) {
return 0, io.ErrShortBuffer
}
@@ -112,7 +112,7 @@ func (m *Message) GetUint32() (uint32, error) {
return u, nil
}

func (m *Message) GetSize() (uint64, error) {
func (m *message) GetSize() (uint64, error) {
if m.rpos+8 > len(m.Data) {
return 0, io.ErrShortBuffer
}
@@ -121,7 +121,7 @@ func (m *Message) GetSize() (uint64, error) {
return u, nil
}

func (m *Message) GetString() (string, error) {
func (m *message) GetString() (string, error) {
o := bytes.IndexByte(m.Data[m.rpos:], 0)
if o < 0 {
return "", errors.New("imsg: string not NULL-terminated")
@@ -132,7 +132,7 @@ func (m *Message) GetString() (string, error) {
return s, nil
}

func (m *Message) GetID() (uint64, error) {
func (m *message) GetID() (uint64, error) {
if m.rpos+8 > len(m.Data) {
return 0, io.ErrShortBuffer
}
@@ -174,7 +174,7 @@ func (sa Sockaddr) String() string {
return fmt.Sprintf("%s:%d", sa.IP(), sa.Port())
}

func (m *Message) GetSockaddr() (net.Addr, error) {
func (m *message) GetSockaddr() (net.Addr, error) {
s, err := m.GetSize()
if err != nil {
return nil, err
@@ -190,7 +190,7 @@ func (m *Message) GetSockaddr() (net.Addr, error) {
return a, nil
}

func (m *Message) GetMailaddr() (user, domain string, err error) {
func (m *message) GetMailaddr() (user, domain string, err error) {
var buf [maxLocalPartSize + maxDomainPartSize]byte
if maxLocalPartSize+maxDomainPartSize > len(m.Data[m.rpos:]) {
return "", "", io.ErrShortBuffer
@@ -202,7 +202,7 @@ func (m *Message) GetMailaddr() (user, domain string, err error) {
return
}

func (m *Message) GetType(t uint8) error {
func (m *message) GetType(t uint8) error {
if m.rpos >= len(m.Data) {
return io.ErrShortBuffer
}
@@ -210,107 +210,107 @@ func (m *Message) GetType(t uint8) error {
b := m.Data[m.rpos]
m.rpos++
if b != t {
return MProcTypeErr{t, b}
return mprocTypeErr{t, b}
}
return nil
}

func (m *Message) GetTypeInt() (int, error) {
if err := m.GetType(M_INT); err != nil {
func (m *message) GetTypeInt() (int, error) {
if err := m.GetType(mINT); err != nil {
return 0, err
}
return m.GetInt()
}

func (m *Message) GetTypeUint32() (uint32, error) {
if err := m.GetType(M_UINT32); err != nil {
func (m *message) GetTypeUint32() (uint32, error) {
if err := m.GetType(mUINT32); err != nil {
return 0, err
}
return m.GetUint32()
}

func (m *Message) GetTypeSize() (uint64, error) {
if err := m.GetType(M_SIZET); err != nil {
func (m *message) GetTypeSize() (uint64, error) {
if err := m.GetType(mSIZET); err != nil {
return 0, err
}
return m.GetSize()
}

func (m *Message) GetTypeString() (string, error) {
if err := m.GetType(M_STRING); err != nil {
func (m *message) GetTypeString() (string, error) {
if err := m.GetType(mSTRING); err != nil {
return "", err
}
return m.GetString()
}

func (m *Message) GetTypeID() (uint64, error) {
if err := m.GetType(M_ID); err != nil {
func (m *message) GetTypeID() (uint64, error) {
if err := m.GetType(mID); err != nil {
return 0, err
}
return m.GetID()
}

func (m *Message) GetTypeSockaddr() (net.Addr, error) {
if err := m.GetType(M_SOCKADDR); err != nil {
func (m *message) GetTypeSockaddr() (net.Addr, error) {
if err := m.GetType(mSOCKADDR); err != nil {
return nil, err
}
return m.GetSockaddr()
}

func (m *Message) GetTypeMailaddr() (user, domain string, err error) {
if err = m.GetType(M_MAILADDR); err != nil {
func (m *message) GetTypeMailaddr() (user, domain string, err error) {
if err = m.GetType(mMAILADDR); err != nil {
return
}
return m.GetMailaddr()
}

func (m *Message) PutInt(v int) {
func (m *message) PutInt(v int) {
var b [4]byte
binary.LittleEndian.PutUint32(b[:], uint32(v))
m.Data = append(m.Data, b[:]...)
m.Len += 4
}

func (m *Message) PutUint32(v uint32) {
func (m *message) PutUint32(v uint32) {
var b [4]byte
binary.LittleEndian.PutUint32(b[:], v)
m.Data = append(m.Data, b[:]...)
m.Len += 4
}

func (m *Message) PutString(s string) {
func (m *message) PutString(s string) {
m.Data = append(m.Data, append([]byte(s), 0)...)
m.Len += uint16(len(s)) + 1
}

func (m *Message) PutID(id uint64) {
func (m *message) PutID(id uint64) {
var b [8]byte
binary.LittleEndian.PutUint64(b[:], id)
m.Data = append(m.Data, b[:]...)
m.Len += 8
}

func (m *Message) PutType(t uint8) {
func (m *message) PutType(t uint8) {
m.Data = append(m.Data, t)
m.Len += 1
}

func (m *Message) PutTypeInt(v int) {
m.PutType(M_INT)
func (m *message) PutTypeInt(v int) {
m.PutType(mINT)
m.PutInt(v)
}

func (m *Message) PutTypeUint32(v uint32) {
m.PutType(M_UINT32)
func (m *message) PutTypeUint32(v uint32) {
m.PutType(mUINT32)
m.PutUint32(v)
}

func (m *Message) PutTypeString(s string) {
m.PutType(M_STRING)
func (m *message) PutTypeString(s string) {
m.PutType(mSTRING)
m.PutString(s)
}

func (m *Message) PutTypeID(id uint64) {
m.PutType(M_ID)
func (m *message) PutTypeID(id uint64) {
m.PutType(mID)
m.PutID(id)
}

+ 26
- 26
mproc.go View File

@@ -5,33 +5,33 @@ import (
)

const (
M_INT = iota
M_UINT32
M_SIZET
M_TIME
M_STRING
M_DATA
M_ID
M_EVPID
M_MSGID
M_SOCKADDR
M_MAILADDR
M_ENVELOPE
mINT = iota
mUINT32
mSIZET
mTIME
mSTRING
mDATA
mID
mEVPID
mMSGID
mSOCKADDR
mMAILADDR
mENVELOPE
)

var mprocTypeName = map[uint8]string{
M_INT: "M_INT",
M_UINT32: "M_UINT32",
M_SIZET: "M_SIZET",
M_TIME: "M_TIME",
M_STRING: "M_STRING",
M_DATA: "M_DATA",
M_ID: "M_ID",
M_EVPID: "M_EVPID",
M_MSGID: "M_MSGID",
M_SOCKADDR: "M_SOCKADDR",
M_MAILADDR: "M_MAILADDR",
M_ENVELOPE: "M_ENVELOPE",
mINT: "M_INT",
mUINT32: "M_UINT32",
mSIZET: "M_SIZET",
mTIME: "M_TIME",
mSTRING: "M_STRING",
mDATA: "M_DATA",
mID: "M_ID",
mEVPID: "M_EVPID",
mMSGID: "M_MSGID",
mSOCKADDR: "M_SOCKADDR",
mMAILADDR: "M_MAILADDR",
mENVELOPE: "M_ENVELOPE",
}

func mprocType(t uint8) string {
@@ -41,11 +41,11 @@ func mprocType(t uint8) string {
return fmt.Sprintf("UNKNOWN %d", t)
}

type MProcTypeErr struct {
type mprocTypeErr struct {
want, got uint8
}

func (err MProcTypeErr) Error() string {
func (err mprocTypeErr) Error() string {
return fmt.Sprintf("mproc: expected type %s, got %s",
mprocType(err.want), mprocType(err.got))
}

+ 11
- 0
smtpd.go View File

@@ -6,6 +6,17 @@ import (
"strings"
)

const (
// FilterVersion is the supported filter API version
FilterVersion = 51

// QueueVersion is the supported queue API version
QueueVersion = 1

// TableVersion is the supported table API version
TableVersion = 1
)

var (
// Debug flag
Debug bool

+ 37
- 41
table.go View File

@@ -11,29 +11,25 @@ import (
)

const (
TableAPIVersion = 1
)

const (
ProcTableOK = iota
ProcTableFail
ProcTableOpen
ProcTableClose
ProcTableUpdate
ProcTableCheck
ProcTableLookup
ProcTableFetch
procTableOK = iota
procTableFail
procTableOpen
procTableClose
procTableUpdate
procTableCheck
procTableLookup
procTableFetch
)

var procTableTypeName = map[uint32]string{
ProcTableOK: "PROC_TABLE_OK",
ProcTableFail: "PROC_TABLE_FAIL",
ProcTableOpen: "PROC_TABLE_OPEN",
ProcTableClose: "PROC_TABLE_CLOSE",
ProcTableUpdate: "PROC_TABLE_UPDATE",
ProcTableCheck: "PROC_TABLE_CHECK",
ProcTableLookup: "PROC_TABLE_LOOKUP",
ProcTableFetch: "PROC_TABLE_FETCH",
procTableOK: "PROC_TABLE_OK",
procTableFail: "PROC_TABLE_FAIL",
procTableOpen: "PROC_TABLE_OPEN",
procTableClose: "PROC_TABLE_CLOSE",
procTableUpdate: "PROC_TABLE_UPDATE",
procTableCheck: "PROC_TABLE_CHECK",
procTableLookup: "PROC_TABLE_LOOKUP",
procTableFetch: "PROC_TABLE_FETCH",
}

func procTableName(t uint32) string {
@@ -61,18 +57,18 @@ type Table struct {
Close func() error

c net.Conn
m *Message
m *message
closed bool
}

func (t *Table) Serve() error {
var err error

if t.c, err = NewConn(0); err != nil {
if t.c, err = newConn(0); err != nil {
return err
}

t.m = new(Message)
t.m = new(message)

for !t.closed {
if err = t.m.ReadFrom(t.c); err != nil {
@@ -96,15 +92,15 @@ type tableOpenParams struct {

func (t *Table) dispatch() (err error) {
switch t.m.Type {
case ProcTableOpen:
case procTableOpen:
/*
var op tableOpenParams
if err = t.getMessage(&op, maxLineSize+4); err != nil {
return
}

if op.Version != TableAPIVersion {
fatalf("table: bad API version %d (we support %d)", op.Version, TableAPIVersion)
if op.Version != TableVersion {
fatalf("table: bad API version %d (we support %d)", op.Version, TableVersion)
}
if bytes.IndexByte(op.Name[:], 0) <= 0 {
fatal("table: no name supplied")
@@ -113,8 +109,8 @@ func (t *Table) dispatch() (err error) {
var version uint32
if version, err = t.m.GetUint32(); err != nil {
return
} else if version != TableAPIVersion {
fatalf("table: expected API version %d, got %d", TableAPIVersion, version)
} else if version != TableVersion {
fatalf("table: expected API version %d, got %d", TableVersion, version)
}

var name string
@@ -126,15 +122,15 @@ func (t *Table) dispatch() (err error) {

debugf("table: version=%d,name=%q\n", version, name)

m := new(Message)
m.Type = ProcTableOK
m := new(message)
m.Type = procTableOK
m.Len = imsgHeaderSize
m.PID = uint32(os.Getpid())
if err = m.WriteTo(t.c); err != nil {
return
}

case ProcTableUpdate:
case procTableUpdate:
var r = 1

if t.Update != nil {
@@ -143,14 +139,14 @@ func (t *Table) dispatch() (err error) {
}
}

m := new(Message)
m.Type = ProcTableOK
m := new(message)
m.Type = procTableOK
m.PutInt(r)
if err = m.WriteTo(t.c); err != nil {
return
}

case ProcTableClose:
case procTableClose:
if t.Close != nil {
if err = t.Close(); err != nil {
return
@@ -160,7 +156,7 @@ func (t *Table) dispatch() (err error) {
t.closed = true
return

case ProcTableCheck:
case procTableCheck:
var service int
if service, err = t.m.GetInt(); err != nil {
return
@@ -188,7 +184,7 @@ func (t *Table) dispatch() (err error) {

log.Printf("table_check: result=%d\n", r)

case ProcTableLookup:
case procTableLookup:
var service int
if service, err = t.m.GetInt(); err != nil {
return
@@ -214,8 +210,8 @@ func (t *Table) dispatch() (err error) {
}
}

m := new(Message)
m.Type = ProcTableOK
m := new(message)
m.Type = procTableOK
m.PID = uint32(os.Getpid())
if val == "" {
m.PutInt(-1)
@@ -227,7 +223,7 @@ func (t *Table) dispatch() (err error) {
return
}

case ProcTableFetch:
case procTableFetch:
var service int
if service, err = t.m.GetInt(); err != nil {
return
@@ -248,8 +244,8 @@ func (t *Table) dispatch() (err error) {
}
}

m := new(Message)
m.Type = ProcTableOK
m := new(message)
m.Type = procTableOK
m.PID = uint32(os.Getpid())
if val == "" {
m.PutInt(-1)

Loading…
Cancel
Save