Initial import
This commit is contained in:
279
plugin.go
Normal file
279
plugin.go
Normal file
@@ -0,0 +1,279 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"crypto/ed25519"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
"unsafe"
|
||||
"regexp"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
var (
|
||||
readRE = regexp.MustCompile(`mqtt:read:([\\p{L}/#\+]+)`)
|
||||
writeRE = regexp.MustCompile(`mqtt:write:([\\p{L}/#\+]+)`)
|
||||
)
|
||||
|
||||
// access describes the type of access to a topic that the client is requesting
|
||||
type access int
|
||||
|
||||
func (a access) String() string {
|
||||
switch a {
|
||||
case read:
|
||||
return "read"
|
||||
case write:
|
||||
return "write"
|
||||
case subscribe:
|
||||
return "subscribe"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
read access = 0x01 // read from a topic
|
||||
write access = 0x02 // write to a topic
|
||||
subscribe access = 0x04 // subscribe to a topic
|
||||
)
|
||||
|
||||
// clientAuthorisation contains the authorisation granted to the client
|
||||
type clientAuthorisation struct {
|
||||
write string
|
||||
read string
|
||||
expiration time.Time
|
||||
}
|
||||
|
||||
// userData contains the persistent data that is kept between plugin calls
|
||||
type userData struct {
|
||||
endpoint string
|
||||
clientID string
|
||||
clientSecret string
|
||||
// clientCache to store client data between API calls. The client pointer value is used as the key.
|
||||
clientCache map[unsafe.Pointer]clientAuthorisation
|
||||
}
|
||||
|
||||
// Introspect creates a request to introspect the given OAuth2 token
|
||||
func (u userData) Introspect(token string) (*http.Request, error) {
|
||||
req, err := http.NewRequest(http.MethodPost,
|
||||
u.endpoint,
|
||||
strings.NewReader(fmt.Sprintf("token=%s", token)))
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req.SetBasicAuth(u.clientID, u.clientSecret)
|
||||
req.Header.Set("Cache-Control", "no-cache")
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
const (
|
||||
optPrefix = "maj_"
|
||||
|
||||
optEndpoint = optPrefix + "endpoint"
|
||||
optClientID = optPrefix + "client_id"
|
||||
optClientSecret = optPrefix + "client_secret"
|
||||
// optional
|
||||
optLogDest = optPrefix + "log_dest"
|
||||
)
|
||||
|
||||
var requiredOpts = [...]string{
|
||||
optEndpoint,
|
||||
optClientID,
|
||||
optClientSecret,
|
||||
}
|
||||
|
||||
// initialiseUserData initialises the data shared between plugin calls
|
||||
func initialiseUserData(opts map[string]string) (userData, error) {
|
||||
var data userData
|
||||
// check all the required options have been supplied
|
||||
for _, o := range requiredOpts {
|
||||
if _, ok := opts[o]; !ok {
|
||||
return data, fmt.Errorf("missing field %s", o)
|
||||
}
|
||||
}
|
||||
|
||||
// copy over user data values
|
||||
data.endpoint = opts[optEndpoint]
|
||||
data.clientID = opts[optClientID]
|
||||
data.clientSecret = opts[optClientSecret]
|
||||
|
||||
// make client cache
|
||||
data.clientCache = make(map[unsafe.Pointer]clientAuthorisation)
|
||||
return data, nil
|
||||
}
|
||||
|
||||
const (
|
||||
// constants used by the config file to switch log destination
|
||||
destNone = "none"
|
||||
destFile = "file"
|
||||
destStdout = "stdout"
|
||||
)
|
||||
|
||||
// initialiseLogger initialises the logger depending on the fields in the supplied configuration string
|
||||
// Defaults to stdout if the input string is empty or unrecognised.
|
||||
// Returns an error if logging to a file is requested but fails.
|
||||
func initialiseLogger(s string) (l *log.Logger, f *os.File, err error) {
|
||||
settings := strings.Fields(s)
|
||||
var w = ioutil.Discard
|
||||
if len(settings) > 0 {
|
||||
switch settings[0] {
|
||||
case destFile:
|
||||
if len(settings) < 2 {
|
||||
return l, f, fmt.Errorf("file path missing")
|
||||
}
|
||||
var err error
|
||||
f, err = os.OpenFile(settings[1], os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
|
||||
if err != nil {
|
||||
return l, f, err
|
||||
}
|
||||
w = f
|
||||
case destStdout:
|
||||
w = os.Stdout
|
||||
default:
|
||||
fmt.Printf("WARNING: unknown debug setting, %s", settings)
|
||||
}
|
||||
}
|
||||
return log.New(w, "[mosq-auth-jwt] ", log.LstdFlags|log.Lmsgprefix), f, nil
|
||||
}
|
||||
|
||||
// clearUserData clears the userData struct so that memory can be garbage collected
|
||||
func clearUserData(user *userData) {
|
||||
user.clientCache = nil
|
||||
}
|
||||
|
||||
// doer is an interface that represents a http client
|
||||
type doer interface {
|
||||
Do(req *http.Request) (*http.Response, error)
|
||||
}
|
||||
|
||||
// httpResponseError indicates that an unexpected response has been returned by the server
|
||||
type httpResponseError struct {
|
||||
response *http.Response
|
||||
}
|
||||
|
||||
func (e httpResponseError) Error() string {
|
||||
statusCode := e.response.StatusCode
|
||||
if b, err := httputil.DumpResponse(e.response, true); err == nil {
|
||||
return string(b)
|
||||
}
|
||||
return fmt.Sprintf("received status code %d", statusCode)
|
||||
}
|
||||
|
||||
const (
|
||||
retryLimit = 4
|
||||
)
|
||||
|
||||
// withBackOff retries the do function with back off until the max retry limit has been reached
|
||||
func withBackOff(maxRetry int, do func() (bool, *http.Response, error)) (response *http.Response, err error) {
|
||||
const backOff = 100 * time.Millisecond
|
||||
retry := true
|
||||
for i, b := 0, time.Duration(0); retry && i < maxRetry; i, b = i+1, b+backOff {
|
||||
time.Sleep(b) // a zero duration will return immediately
|
||||
retry, response, err = do()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// checkResponseStatusCode checks the status code of the response and decides whether a retry is required
|
||||
func checkResponseStatusCode(response *http.Response) (bool, error) {
|
||||
switch response.StatusCode {
|
||||
case http.StatusOK:
|
||||
return false, nil
|
||||
case http.StatusInternalServerError, http.StatusServiceUnavailable:
|
||||
return true, httpResponseError{response}
|
||||
default:
|
||||
return false, httpResponseError{response}
|
||||
}
|
||||
}
|
||||
|
||||
// Checks whether a client is authorised to write or read to a topic.
|
||||
func authorise(httpDo doer, user *userData, access access, client unsafe.Pointer, topic string) (bool, error) {
|
||||
// get cache data
|
||||
authData, ok := user.clientCache[client]
|
||||
if !ok {
|
||||
// the user will not be in the cache if it was authenticated by mosquitto or another plugin
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// check whether the token has expired
|
||||
//logger.Println(authData)
|
||||
if time.Now().After(authData.expiration) {
|
||||
logger.Printf("authorise[%p]: token has expired", client)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
allow := false
|
||||
switch access {
|
||||
case subscribe, read:
|
||||
allow = matchTopic(authData.read, topic)
|
||||
case write:
|
||||
allow = matchTopic(authData.write, topic)
|
||||
default:
|
||||
return false, fmt.Errorf("Unexpected access request %d\n", access)
|
||||
}
|
||||
return allow, nil
|
||||
}
|
||||
|
||||
// parseFilter parses the MQTT topic filter from the scopes
|
||||
// Assumes that the filter will be found in the the first capturing group of the regexp if the entire expression matches
|
||||
func parseFilter(re *regexp.Regexp, scope string) string {
|
||||
m := re.FindStringSubmatch(scope)
|
||||
if len(m) < 2 {
|
||||
return ""
|
||||
}
|
||||
return m[1]
|
||||
}
|
||||
|
||||
/*
|
||||
* authenticate the client by checking the supplied username and password.
|
||||
* an OAuth2 Access Token is passed in as the password.
|
||||
*/
|
||||
func authenticate(user *userData, client unsafe.Pointer, username, password string) (bool, error) {
|
||||
_, err := jwt.Parse(password, func(token *jwt.Token) (any, error) {
|
||||
if claims, ok := token.Claims.(jwt.MapClaims); ok {
|
||||
keyHex, ok := claims["publickey"].(string)
|
||||
if !ok {
|
||||
return nil, errors.New("no publickey in claims")
|
||||
}
|
||||
|
||||
keyBytes, err := hex.DecodeString(keyHex)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(keyBytes) != ed25519.PublicKeySize {
|
||||
return nil, fmt.Errorf("invalid public key size %d", len(keyBytes))
|
||||
}
|
||||
return ed25519.PublicKey(keyBytes), nil
|
||||
}
|
||||
return nil, errors.New("no map claims")
|
||||
},
|
||||
jwt.WithValidMethods([]string{jwt.SigningMethodEdDSA.Alg()}),
|
||||
jwt.WithIssuedAt(),
|
||||
jwt.WithLeeway(time.Hour * 12),
|
||||
)
|
||||
if err != nil {
|
||||
logger.Printf("authenticate[%s]: invalid token: %v", username, err)
|
||||
return false, err
|
||||
}
|
||||
|
||||
logger.Printf("authenticate[%s]: token valid", username)
|
||||
user.clientCache[client] = clientAuthorisation{
|
||||
read: "#",
|
||||
write: "#",
|
||||
expiration: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
Reference in New Issue
Block a user