From 6b523e16a226aad629c7a8372370742e8aab1f31 Mon Sep 17 00:00:00 2001 From: Wijnand Modderman-Lenstra Date: Sat, 14 Feb 2026 15:51:55 +0100 Subject: [PATCH] Initial import --- cmd/generate-token/main.go | 60 ++++++++ compat.go | 29 ++++ lib.go | 156 +++++++++++++++++++++ plugin.go | 279 +++++++++++++++++++++++++++++++++++++ topic.go | 46 ++++++ 5 files changed, 570 insertions(+) create mode 100644 cmd/generate-token/main.go create mode 100644 compat.go create mode 100644 lib.go create mode 100644 plugin.go create mode 100644 topic.go diff --git a/cmd/generate-token/main.go b/cmd/generate-token/main.go new file mode 100644 index 0000000..0f80c3a --- /dev/null +++ b/cmd/generate-token/main.go @@ -0,0 +1,60 @@ +package main + +import ( + "crypto" + "crypto/ed25519" + "encoding/hex" + "fmt" + "flag" + "log" + "os" + "time" + + "github.com/golang-jwt/jwt/v5" +) + +func main() { + flagKey := flag.String("key", "jwt.key", "key file") + flag.Parse() + + pub, key, err := loadKey(*flagKey) + if err != nil { + log.Fatalln(err) + } + + token := jwt.NewWithClaims(jwt.SigningMethodEdDSA, jwt.MapClaims{ + "publickey": hex.EncodeToString(pub), + "iat": time.Now().UTC().Unix(), + }) + s, err := token.SignedString(key) + if err != nil { + log.Fatalln(err) + } + fmt.Println(s) +} + +func loadKey(name string) (ed25519.PublicKey, crypto.Signer, error) { + b, err := os.ReadFile(name) + if err == nil { + key := ed25519.NewKeyFromSeed(b) + pub := key.Public().(ed25519.PublicKey) + return pub, key, nil + } + if err != nil && !os.IsNotExist(err) { + return nil, nil, err + } + + var ( + key ed25519.PrivateKey + pub ed25519.PublicKey + ) + if pub, key, err = ed25519.GenerateKey(nil); err != nil { + return nil, nil, err + } + + if err = os.WriteFile(name, key.Seed(), 0600); err != nil { + return nil, nil, err + } + + return pub, key, nil +} diff --git a/compat.go b/compat.go new file mode 100644 index 0000000..90f0861 --- /dev/null +++ b/compat.go @@ -0,0 +1,29 @@ +package main + +/* +#include +typedef const char const_char; +struct mosquitto_opt* accessArray( struct mosquitto_opt* arrptr, int i) +{ + return arrptr + i; +} +*/ +import "C" + +// Using //export in a file places a restriction on the preamble: it must not contain any definitions, only declarations. + +// extractOptions coverts a C mosquitto option array into a GO map +func extractOptions(arrptr *C.struct_mosquitto_opt, length C.int) map[string]string { + opts := make(map[string]string, length) + var i C.int + for i = 0; i < length; i++ { + c_opt := C.accessArray(arrptr, i) + opts[C.GoString(c_opt.key)] = C.GoString(c_opt.value) + } + return opts +} + +// goStringFromConstant converts a constant C string into a GO string +func goStringFromConstant(cstr *C.const_char) string { + return C.GoString((*C.char)(cstr)) +} diff --git a/lib.go b/lib.go new file mode 100644 index 0000000..15ff560 --- /dev/null +++ b/lib.go @@ -0,0 +1,156 @@ +package main + +/* +#include +#include +typedef struct mosquitto mosquitto; +typedef const struct mosquitto_acl_msg const_mosquitto_acl_msg; +typedef const char const_char; +*/ +import "C" +import ( + "fmt" + "log" + "net/http" + "os" + "unsafe" +) + +var ( + logger *log.Logger + file *os.File = nil +) + +//export mosquitto_auth_plugin_version +/* + * Returns the value of MOSQ_AUTH_PLUGIN_VERSION defined in the mosquitto header file that the plugin was compiled + * against. + */ +func mosquitto_auth_plugin_version() C.int { + return C.MOSQ_AUTH_PLUGIN_VERSION +} + +//export mosquitto_auth_plugin_init +/* + * Initialises the plugin. + */ +func mosquitto_auth_plugin_init(cUserData *unsafe.Pointer, cOpts *C.struct_mosquitto_opt, cOptCount C.int) C.int { + var err error + // copy opts from the C world into Go + optMap := extractOptions(cOpts, cOptCount) + + // initialise logger + if logger, file, err = initialiseLogger(optMap[optLogDest]); err != nil { + fmt.Printf("error initialising logger, %s", err) + return C.MOSQ_ERR_AUTH + } + logger.Println("plugin initializing") + + // initialise the user data that will be used in subsequent plugin calls + userData, err := initialiseUserData(optMap) + if err != nil { + logger.Println("initialiseUserData failed with err:", err) + return C.MOSQ_ERR_AUTH + } + *cUserData = unsafe.Pointer(&userData) + + return C.MOSQ_ERR_SUCCESS +} + +//export mosquitto_auth_plugin_cleanup +/* + * Cleans up the plugin before the server shuts down. + */ +func mosquitto_auth_plugin_cleanup(cUserData unsafe.Pointer, cOpts *C.struct_mosquitto_opt, cOptCount C.int) C.int { + //logger.Println("enter - plugin cleanup") + // close logfile + if file != nil { + file.Sync() + file.Close() + file = nil + } + // set the client cache to nil so it can be garage collected + clearUserData((*userData)(cUserData)) + + //logger.Println("leave - plugin cleanup") + logger = nil + return C.MOSQ_ERR_SUCCESS +} + +//export mosquitto_auth_acl_check +/* + * Checks whether a client is authorised to read from or write to a topic. + */ +func mosquitto_auth_acl_check(cUserData unsafe.Pointer, cAccess C.int, cClient *C.mosquitto, cMsg *C.const_mosquitto_acl_msg) C.int { + if cUserData == nil { + logger.Printf("auth_acl_check[%p]: missing user data", cClient) + return C.MOSQ_ERR_AUTH + } + + access := access(cAccess) + allow, err := authorise(http.DefaultClient, (*userData)(cUserData), access, unsafe.Pointer(cClient), + C.GoString(cMsg.topic)) + if err != nil { + logger.Printf("auth_acl_check[%p]: error: %v", cClient, err) + return C.MOSQ_ERR_AUTH + } + if !allow { + logger.Printf("auth_acl_check[%p]: acl %q denied", cClient, access) + return C.MOSQ_ERR_PLUGIN_DEFER + } + logger.Printf("auth_acl_check[%p]: acl %q granted", cClient, access) + return C.MOSQ_ERR_SUCCESS +} + +//export mosquitto_auth_unpwd_check +/* + * Authenticates the client by checking the supplied username and password. + */ +func mosquitto_auth_unpwd_check(cUserData unsafe.Pointer, cClient *C.mosquitto, cUsername, cPassword *C.const_char) C.int { + if cUsername == nil || cPassword == nil { + return C.MOSQ_ERR_AUTH + } + + username := goStringFromConstant(cUsername) + password := goStringFromConstant(cPassword) + //logger.Printf("u: %s, p: %s\n", username, password) + + authorised, err := authenticate((*userData)(cUserData), unsafe.Pointer(cClient), username, password) + if err != nil { + logger.Printf("auth_unpwd_check[%p]: user %q error: %v", cClient, username, err) + return C.MOSQ_ERR_AUTH + } + + if !authorised { + logger.Printf("auth_unpwd_check[%p]: user %q unauthorized", cClient, username) + return C.MOSQ_ERR_PLUGIN_DEFER + } + logger.Printf("auth_unpwd_check[%p]: user %q authorized", cClient, username) + return C.MOSQ_ERR_SUCCESS +} + +//export mosquitto_auth_security_init +/* + * No-op function. Included to satisfy the plugin contract to Mosquitto. + */ +func mosquitto_auth_security_init(cUserData unsafe.Pointer, cOpts *C.struct_mosquitto_opt, cOptCount C.int, cReload C.bool) C.int { + return C.MOSQ_ERR_SUCCESS +} + +//export mosquitto_auth_security_cleanup +/* + * No-op function. Included to satisfy the plugin contract to Mosquitto. + */ +func mosquitto_auth_security_cleanup(cUserData unsafe.Pointer, cOpts *C.struct_mosquitto_opt, cOptCount C.int, cReload C.bool) C.int { + return C.MOSQ_ERR_SUCCESS +} + +//export mosquitto_auth_psk_key_get +/* + * No-op function. Included to satisfy the plugin contract to Mosquitto. + */ +func mosquitto_auth_psk_key_get(cUserData unsafe.Pointer, cClient *C.mosquitto, cHint, cIdentity *C.const_char, cKey *C.char, cMaxKeyLen C.int) C.int { + return C.MOSQ_ERR_SUCCESS +} + +func main() {} diff --git a/plugin.go b/plugin.go new file mode 100644 index 0000000..4fa6cc5 --- /dev/null +++ b/plugin.go @@ -0,0 +1,279 @@ +package main + +import ( + "crypto/ed25519" + "encoding/hex" + "errors" + "fmt" + "io/ioutil" + "log" + "net/http" + "net/http/httputil" + "os" + "strings" + "time" + "unsafe" + "regexp" + + "github.com/golang-jwt/jwt/v5" +) + +var ( + readRE = regexp.MustCompile(`mqtt:read:([\\p{L}/#\+]+)`) + writeRE = regexp.MustCompile(`mqtt:write:([\\p{L}/#\+]+)`) +) + +// access describes the type of access to a topic that the client is requesting +type access int + +func (a access) String() string { + switch a { + case read: + return "read" + case write: + return "write" + case subscribe: + return "subscribe" + default: + return "unknown" + } +} + +const ( + read access = 0x01 // read from a topic + write access = 0x02 // write to a topic + subscribe access = 0x04 // subscribe to a topic +) + +// clientAuthorisation contains the authorisation granted to the client +type clientAuthorisation struct { + write string + read string + expiration time.Time +} + +// userData contains the persistent data that is kept between plugin calls +type userData struct { + endpoint string + clientID string + clientSecret string + // clientCache to store client data between API calls. The client pointer value is used as the key. + clientCache map[unsafe.Pointer]clientAuthorisation +} + +// Introspect creates a request to introspect the given OAuth2 token +func (u userData) Introspect(token string) (*http.Request, error) { + req, err := http.NewRequest(http.MethodPost, + u.endpoint, + strings.NewReader(fmt.Sprintf("token=%s", token))) + + if err != nil { + return nil, err + } + + req.SetBasicAuth(u.clientID, u.clientSecret) + req.Header.Set("Cache-Control", "no-cache") + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + return req, nil +} + +const ( + optPrefix = "maj_" + + optEndpoint = optPrefix + "endpoint" + optClientID = optPrefix + "client_id" + optClientSecret = optPrefix + "client_secret" + // optional + optLogDest = optPrefix + "log_dest" +) + +var requiredOpts = [...]string{ + optEndpoint, + optClientID, + optClientSecret, +} + +// initialiseUserData initialises the data shared between plugin calls +func initialiseUserData(opts map[string]string) (userData, error) { + var data userData + // check all the required options have been supplied + for _, o := range requiredOpts { + if _, ok := opts[o]; !ok { + return data, fmt.Errorf("missing field %s", o) + } + } + + // copy over user data values + data.endpoint = opts[optEndpoint] + data.clientID = opts[optClientID] + data.clientSecret = opts[optClientSecret] + + // make client cache + data.clientCache = make(map[unsafe.Pointer]clientAuthorisation) + return data, nil +} + +const ( + // constants used by the config file to switch log destination + destNone = "none" + destFile = "file" + destStdout = "stdout" +) + +// initialiseLogger initialises the logger depending on the fields in the supplied configuration string +// Defaults to stdout if the input string is empty or unrecognised. +// Returns an error if logging to a file is requested but fails. +func initialiseLogger(s string) (l *log.Logger, f *os.File, err error) { + settings := strings.Fields(s) + var w = ioutil.Discard + if len(settings) > 0 { + switch settings[0] { + case destFile: + if len(settings) < 2 { + return l, f, fmt.Errorf("file path missing") + } + var err error + f, err = os.OpenFile(settings[1], os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + if err != nil { + return l, f, err + } + w = f + case destStdout: + w = os.Stdout + default: + fmt.Printf("WARNING: unknown debug setting, %s", settings) + } + } + return log.New(w, "[mosq-auth-jwt] ", log.LstdFlags|log.Lmsgprefix), f, nil +} + +// clearUserData clears the userData struct so that memory can be garbage collected +func clearUserData(user *userData) { + user.clientCache = nil +} + +// doer is an interface that represents a http client +type doer interface { + Do(req *http.Request) (*http.Response, error) +} + +// httpResponseError indicates that an unexpected response has been returned by the server +type httpResponseError struct { + response *http.Response +} + +func (e httpResponseError) Error() string { + statusCode := e.response.StatusCode + if b, err := httputil.DumpResponse(e.response, true); err == nil { + return string(b) + } + return fmt.Sprintf("received status code %d", statusCode) +} + +const ( + retryLimit = 4 +) + +// withBackOff retries the do function with back off until the max retry limit has been reached +func withBackOff(maxRetry int, do func() (bool, *http.Response, error)) (response *http.Response, err error) { + const backOff = 100 * time.Millisecond + retry := true + for i, b := 0, time.Duration(0); retry && i < maxRetry; i, b = i+1, b+backOff { + time.Sleep(b) // a zero duration will return immediately + retry, response, err = do() + } + return +} + +// checkResponseStatusCode checks the status code of the response and decides whether a retry is required +func checkResponseStatusCode(response *http.Response) (bool, error) { + switch response.StatusCode { + case http.StatusOK: + return false, nil + case http.StatusInternalServerError, http.StatusServiceUnavailable: + return true, httpResponseError{response} + default: + return false, httpResponseError{response} + } +} + +// Checks whether a client is authorised to write or read to a topic. +func authorise(httpDo doer, user *userData, access access, client unsafe.Pointer, topic string) (bool, error) { + // get cache data + authData, ok := user.clientCache[client] + if !ok { + // the user will not be in the cache if it was authenticated by mosquitto or another plugin + return false, nil + } + + // check whether the token has expired + //logger.Println(authData) + if time.Now().After(authData.expiration) { + logger.Printf("authorise[%p]: token has expired", client) + return false, nil + } + + allow := false + switch access { + case subscribe, read: + allow = matchTopic(authData.read, topic) + case write: + allow = matchTopic(authData.write, topic) + default: + return false, fmt.Errorf("Unexpected access request %d\n", access) + } + return allow, nil +} + +// parseFilter parses the MQTT topic filter from the scopes +// Assumes that the filter will be found in the the first capturing group of the regexp if the entire expression matches +func parseFilter(re *regexp.Regexp, scope string) string { + m := re.FindStringSubmatch(scope) + if len(m) < 2 { + return "" + } + return m[1] +} + +/* + * authenticate the client by checking the supplied username and password. + * an OAuth2 Access Token is passed in as the password. + */ +func authenticate(user *userData, client unsafe.Pointer, username, password string) (bool, error) { + _, err := jwt.Parse(password, func(token *jwt.Token) (any, error) { + if claims, ok := token.Claims.(jwt.MapClaims); ok { + keyHex, ok := claims["publickey"].(string) + if !ok { + return nil, errors.New("no publickey in claims") + } + + keyBytes, err := hex.DecodeString(keyHex) + if err != nil { + return nil, err + } + if len(keyBytes) != ed25519.PublicKeySize { + return nil, fmt.Errorf("invalid public key size %d", len(keyBytes)) + } + return ed25519.PublicKey(keyBytes), nil + } + return nil, errors.New("no map claims") + }, + jwt.WithValidMethods([]string{jwt.SigningMethodEdDSA.Alg()}), + jwt.WithIssuedAt(), + jwt.WithLeeway(time.Hour * 12), + ) + if err != nil { + logger.Printf("authenticate[%s]: invalid token: %v", username, err) + return false, err + } + + logger.Printf("authenticate[%s]: token valid", username) + user.clientCache[client] = clientAuthorisation{ + read: "#", + write: "#", + expiration: time.Now().Add(24 * time.Hour), + } + + return true, nil +} diff --git a/topic.go b/topic.go new file mode 100644 index 0000000..244c266 --- /dev/null +++ b/topic.go @@ -0,0 +1,46 @@ +package main + +import ( + "strings" +) + +const ( + levelSep = "/" + multiLevel = "#" + singleLevel = "+" +) + +// returns true if the MQTT topic filter matches the topic name +func matchTopic(filter, name string) bool { + // split both strings into a root level and a remainder + f := strings.SplitN(filter, levelSep, 2) + n := strings.SplitN(name, levelSep, 2) + + // check whether the root levels match + switch f[0] { + case multiLevel: + return true + case n[0], singleLevel: + // roots match, continue + break + default: + // roots do not match + return false + } + + // check whether the filter or name consists of a single level + switch { + case len(f) == 1 && len(n) == 1: + return true + case len(f) == 1: + return false + case len(n) == 1: + if f[1] == multiLevel { + // special case is when filter has 1 extra level which is the '#' character + return true + } + return false + + } + return matchTopic(f[1], n[1]) +}