modules-durpify/cmd/middleware/auth.go

156 lines
3.1 KiB
Go
Raw Normal View History

2024-09-02 13:38:46 -05:00
package middleware
import (
"context"
"errors"
"gitlab.com/developerdurp/stdmodels"
"net/http"
"strings"
"time"
"github.com/MicahParks/keyfunc"
"github.com/golang-jwt/jwt/v4"
"gitlab.com/developerdurp/logger"
)
// InitAuthMiddleware returns an AuthConfig whose AuthMiddleware admits tokens
// whose "groups" claim contains at least one of allowedGroups, verifying token
// signatures with keys obtained from jwks (passed to keyfunc.Get).
//
// allowedGroups is copied so that later modifications of the caller's slice
// cannot change, or race with, the checks performed while serving requests.
func InitAuthMiddleware(allowedGroups []string, jwks string) *AuthConfig {
	return &AuthConfig{
		allowedGroups: append([]string(nil), allowedGroups...),
		jwks:          jwks,
	}
}
type (
	// AuthConfig holds the settings consulted by AuthMiddleware.
	AuthConfig struct {
		// allowedGroups lists group names; a request is admitted when the
		// token's "groups" claim contains at least one of them.
		allowedGroups []string
		// jwks is handed to keyfunc.Get to obtain the signature-verification
		// keys (the JWKS URL).
		jwks string
	}

	// StandardMessage is a payload with a single field serialized to JSON as
	// "message". NOTE(review): not referenced in this file; presumably used by
	// other packages or API docs — confirm before removing.
	StandardMessage struct {
		Message string `json:"message" example:"message"`
	}
)
// AuthMiddleware returns a handler that invokes next only for requests that
// carry a valid JWT whose "groups" claim shares at least one entry with
// cfg.allowedGroups. A rejected request receives exactly one failure response
// built by stdmodels.NewFailureResponse, and next is not called:
//   - missing token, failed validation, unexpected claims type, or a missing
//     or non-array "groups" claim -> 401 Unauthorized;
//   - valid token that lacks every allowed group -> 403 Forbidden.
//
// The token is taken from the incoming request's Authorization header; a
// "Bearer " scheme prefix (matched case-insensitively) is stripped.
func (cfg *AuthConfig) AuthMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		tokenString, err := authTokenFromRequest(r)
		if err != nil {
			resp := stdmodels.NewFailureResponse(
				err.Error(),
				http.StatusUnauthorized,
				[]string{},
			)
			resp.SendReponse(w)
			return
		}

		token, err := cfg.validateToken(tokenString)
		if err != nil {
			// validateToken returns a nil token on error; stop here rather than
			// dereferencing it below.
			resp := stdmodels.NewFailureResponse(
				"Failed to Validate Token",
				http.StatusUnauthorized,
				[]string{err.Error()},
			)
			resp.SendReponse(w)
			return
		}

		claims, ok := token.Claims.(jwt.MapClaims)
		if !ok {
			resp := stdmodels.NewFailureResponse(
				"Invalid Authorization token claim",
				http.StatusUnauthorized,
				[]string{},
			)
			resp.SendReponse(w)
			return
		}

		groupsClaim, ok := claims["groups"].([]interface{})
		if !ok {
			resp := stdmodels.NewFailureResponse(
				"Missing or invalid groups in the token",
				http.StatusUnauthorized,
				[]string{},
			)
			resp.SendReponse(w)
			return
		}

		if !containsAllowedGroup(groupsClaim, cfg.allowedGroups) {
			// The caller is authenticated but lacks permission, which is a
			// 403 rather than a 401 (re-authenticating would not help).
			resp := stdmodels.NewFailureResponse(
				"Unauthorized to use this endpoint",
				http.StatusForbidden,
				[]string{},
			)
			resp.SendReponse(w)
			return
		}

		next.ServeHTTP(w, r)
	})
}

// authTokenFromRequest returns the token in r's Authorization header with an
// optional "Bearer " scheme (any letter case) and surrounding whitespace
// removed. It fails with "No Token Detected" when no token remains.
func authTokenFromRequest(r *http.Request) (string, error) {
	const scheme = "Bearer "
	token := r.Header.Get("Authorization")
	if len(token) >= len(scheme) && strings.EqualFold(token[:len(scheme)], scheme) {
		token = token[len(scheme):]
	}
	token = strings.TrimSpace(token)
	if token == "" {
		return "", errors.New("No Token Detected")
	}
	return token, nil
}

// containsAllowedGroup reports whether any string element of groupsClaim
// equals an entry of allowed. Non-string claim elements are ignored.
func containsAllowedGroup(groupsClaim []interface{}, allowed []string) bool {
	for _, raw := range groupsClaim {
		name, ok := raw.(string)
		if !ok {
			continue
		}
		for _, candidate := range allowed {
			if name == candidate {
				return true
			}
		}
	}
	return false
}
// validateToken fetches the JWKS from cfg.jwks and uses it to parse and verify
// tokenString, returning the token only when jwt reports it valid.
//
// The returned errors carry fixed messages ("Failed to get JWKS",
// "Failed to Parse JWT", "Invalid Token") because AuthMiddleware forwards
// them to the client. The underlying cause of a JWKS fetch failure, which is
// a server-side problem, is logged via logger.LogError instead.
//
// NOTE(review): the JWKS is fetched, and its background refresh torn down, on
// every call. The Refresh* options therefore have little effect, and each
// request costs a network round trip. Consider creating the keyfunc.JWKS once
// and reusing it across requests.
func (cfg *AuthConfig) validateToken(tokenString string) (*jwt.Token, error) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	options := keyfunc.Options{
		Ctx: ctx,
		RefreshErrorHandler: func(err error) {
			logger.LogError("There was an error with the jwt.Keyfunc: " + err.Error())
		},
		RefreshInterval:   time.Hour,
		RefreshRateLimit:  time.Minute * 5,
		RefreshTimeout:    time.Second * 10,
		RefreshUnknownKID: true,
	}

	jwks, err := keyfunc.Get(cfg.jwks, options)
	if err != nil {
		// jwks is not usable when Get fails, so EndBackground may only be
		// deferred after this check.
		logger.LogError("Failed to get JWKS: " + err.Error())
		return nil, errors.New("Failed to get JWKS")
	}
	defer jwks.EndBackground()

	token, err := jwt.Parse(tokenString, jwks.Keyfunc)
	if err != nil {
		return nil, errors.New("Failed to Parse JWT")
	}
	if !token.Valid {
		return nil, errors.New("Invalid Token")
	}
	return token, nil
}
// getToken returns the token held in the "Authorization" header of w. It
// strips a "Bearer " scheme prefix (matched case-insensitively, as HTTP auth
// schemes are) and any surrounding whitespace. A value without the scheme is
// returned as-is apart from trimming.
//
// It fails with "No Token Detected" when the header is absent or contains no
// token after the prefix (e.g. just "Bearer ").
//
// NOTE(review): w.Header() is the *response* header map, so an Authorization
// header sent by a client is not visible here; incoming credentials are in
// (*http.Request).Header. The signature is kept for existing callers.
func getToken(w http.ResponseWriter) (string, error) {
	const scheme = "Bearer "
	token := w.Header().Get("Authorization")
	if len(token) >= len(scheme) && strings.EqualFold(token[:len(scheme)], scheme) {
		token = token[len(scheme):]
	}
	token = strings.TrimSpace(token)
	if token == "" {
		return "", errors.New("No Token Detected")
	}
	return token, nil
}