Files
futur-web-app/server/internal/middleware/auth.go
2025-11-03 12:24:01 +02:00

206 lines
4.6 KiB
Go

package middleware
import (
"context"
"crypto/rsa"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"net/http"
"sync"
"time"
"go-server/internal/config"
"go-server/internal/database"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
)
// AuthRequired returns a Gin middleware that authenticates each request using
// the JWT carried in the "x-jwt" request header.
//
// A request without the header, or whose token is rejected by validateToken,
// is aborted with 401 Unauthorized and a JSON body of the form
// {"error": "<reason>"}. On success the token's user ID is stored in the Gin
// context under the key "userID" and the handler chain continues.
func AuthRequired() gin.HandlerFunc {
	// reject aborts the chain with a 401 and the given reason.
	reject := func(c *gin.Context, reason string) {
		c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": reason})
	}

	return func(c *gin.Context) {
		rawToken := c.GetHeader("x-jwt")
		if rawToken == "" {
			reject(c, "authorization header required")
			return
		}

		claims, err := validateToken(rawToken)
		if err != nil {
			reject(c, err.Error())
			return
		}

		c.Set("userID", claims.UserID)
		c.Next()
	}
}
// JWTClaims represents the JWT claims accepted by this middleware.
//
// UserID is decoded from the custom "user_id" claim. The standard registered
// claims come from the embedded jwt.RegisteredClaims; validateToken reads the
// expiry (ExpiresAt) from them.
type JWTClaims struct {
	UserID uuid.UUID `json:"user_id"`
	jwt.RegisteredClaims
}
// publicKey caches the RSA key used to verify token signatures.
//
// It is written by init and by the config change listener that init
// registers, and it is read by validateToken. All access goes through
// publicKeyMu. A nil value means no valid key has been loaded yet.
var (
	publicKeyMu sync.RWMutex
	publicKey   *rsa.PublicKey
)
// init primes the cached JWT public key from the current configuration and
// subscribes to configuration changes so the cached key follows updates.
//
// Failures are deliberately ignored here. If no key is ever loaded,
// validateToken reports "JWT public key not initialized". If an update
// carries an unparsable key, the previously cached key stays in effect.
func init() {
	if cfg, err := config.LoadConfig(); err == nil {
		storePublicKey(cfg.JWT.PublicKey)
	}

	config.RegisterChangeListener(func(updated *config.Config) {
		if updated == nil {
			return
		}
		storePublicKey(updated.JWT.PublicKey)
	})
}

// storePublicKey parses an encoded public key from configuration and, only if
// parsing succeeds, swaps it into the cache under the write lock.
func storePublicKey(encoded string) {
	parsed, err := parsePublicKeyFromConfig(encoded)
	if err != nil {
		return
	}
	publicKeyMu.Lock()
	publicKey = parsed
	publicKeyMu.Unlock()
}
// parsePublicKeyFromConfig turns the configured JWT public key into an
// *rsa.PublicKey.
//
// The configured value is expected to be PEM text encoded with standard
// base64. It returns an error if the value is empty, is not valid base64, or
// does not decode to a PEM RSA public key that golang-jwt can parse.
func parsePublicKeyFromConfig(publicKeyB64 string) (*rsa.PublicKey, error) {
	if len(publicKeyB64) == 0 {
		return nil, errors.New("missing JWT public key")
	}

	pemBytes, decodeErr := base64.StdEncoding.DecodeString(publicKeyB64)
	if decodeErr != nil {
		return nil, fmt.Errorf("failed to decode public key: %w", decodeErr)
	}

	key, parseErr := jwt.ParseRSAPublicKeyFromPEM(pemBytes)
	if parseErr != nil {
		return nil, fmt.Errorf("failed to parse public key: %w", parseErr)
	}
	return key, nil
}
// validateToken validates and parses the JWT token in tokenStr.
//
// Lookup order:
//  1. Redis cache, keyed by the SHA-256 of the raw token. A cache hit skips
//     signature verification and returns only the cached UserID and
//     ExpiresAt. NOTE(review): tokens validated under a previous public key
//     therefore remain accepted until they expire.
//  2. Signature and claim verification against the cached RSA public key.
//     Only *jwt.SigningMethodRSA algorithms are accepted, and the token must
//     carry an "exp" claim.
//  3. On success, the claims are cached with a TTL equal to the token's
//     remaining lifetime.
//
// Possible errors:
//   - "JWT public key not initialized" when no key has been loaded.
//   - Errors from the golang-jwt parser, including a missing "exp" claim.
//   - "invalid token".
//   - "token expired".
//   - "token missing expiration".
func validateToken(tokenStr string) (*JWTClaims, error) {
	// 1) Try Redis cache first.
	if claims := tryGetCachedClaims(tokenStr); claims != nil {
		// Double-check expiration as a guard rail (e.g. the entry was read
		// right at its expiry).
		if claims.ExpiresAt != nil && claims.ExpiresAt.Before(time.Now()) {
			return nil, errors.New("token expired")
		}
		return claims, nil
	}

	// 2) Validate signature using cached public key.
	publicKeyMu.RLock()
	pk := publicKey
	publicKeyMu.RUnlock()
	if pk == nil {
		return nil, errors.New("JWT public key not initialized")
	}

	token, err := jwt.ParseWithClaims(
		tokenStr,
		&JWTClaims{},
		func(t *jwt.Token) (any, error) {
			// Only RSA methods are accepted. Anything else is rejected here
			// (e.g. HMAC, which would misuse the public key as a shared
			// secret).
			if _, ok := t.Method.(*jwt.SigningMethodRSA); !ok {
				return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
			}
			return pk, nil
		},
		// golang-jwt accepts tokens without "exp" by default. Such a token
		// would never expire, and previously it caused a nil dereference on
		// ExpiresAt below.
		jwt.WithExpirationRequired(),
	)
	if err != nil {
		return nil, err
	}

	claims, ok := token.Claims.(*JWTClaims)
	if !ok || !token.Valid {
		return nil, errors.New("invalid token")
	}

	// Check token expiration. The nil guard is defensive: the parser option
	// above should already have rejected a token without "exp".
	if claims.ExpiresAt == nil {
		return nil, errors.New("token missing expiration")
	}
	if claims.ExpiresAt.Before(time.Now()) {
		return nil, errors.New("token expired")
	}

	// 3) Store in Redis cache with TTL until expiry.
	cacheClaims(tokenStr, claims)
	return claims, nil
}
// --- Redis cache helpers ---

// jwtCachePrefix namespaces Redis keys for cached token validations. The full
// key is this prefix followed by the hex SHA-256 digest of the raw token.
const jwtCachePrefix = "jwt:cache:"

// cachedClaims is the JSON payload stored in Redis for a validated token. It
// holds the user ID in string form and the token expiry as Unix seconds.
type cachedClaims struct {
	UserID string `json:"user_id"`
	Exp    int64  `json:"exp"`
}
// tryGetCachedClaims looks up a previously validated token in Redis.
//
// It returns nil, meaning "not cached, verify normally", in any of these cases:
//   - no Redis client is available;
//   - the key is missing or the read fails;
//   - the payload is not valid JSON;
//   - the cached expiry has already passed;
//   - the stored user ID is not a valid UUID.
//
// On a hit, the returned claims carry only UserID and ExpiresAt.
//
// NOTE(review): the lookup uses context.Background(), so its latency is bounded
// only by the Redis client's own timeouts. Confirm those are short enough for
// the request path.
func tryGetCachedClaims(tokenStr string) *JWTClaims {
	rdb := database.GetRedis()
	if rdb == nil {
		return nil
	}

	cacheKey := jwtCachePrefix + sha256Hex(tokenStr)
	payload, err := rdb.Get(context.Background(), cacheKey).Bytes()
	if err != nil {
		return nil
	}

	var entry cachedClaims
	if json.Unmarshal(payload, &entry) != nil || entry.Exp <= time.Now().Unix() {
		return nil
	}

	userID, err := uuid.Parse(entry.UserID)
	if err != nil {
		return nil
	}

	claims := &JWTClaims{UserID: userID}
	claims.ExpiresAt = jwt.NewNumericDate(time.Unix(entry.Exp, 0))
	return claims
}
// cacheClaims stores the user ID and expiry of a freshly validated token in
// Redis under jwtCachePrefix plus the token's SHA-256 hex digest. The TTL
// equals the token's remaining lifetime.
//
// It is best-effort and nothing is cached, with no error reported, when:
//   - claims is nil or has no expiry;
//   - the token has already expired;
//   - no Redis client is available;
//   - marshalling the payload fails;
//   - the SET fails.
func cacheClaims(tokenStr string, claims *JWTClaims) {
	if claims == nil || claims.ExpiresAt == nil {
		return
	}

	remaining := time.Until(claims.ExpiresAt.Time)
	if remaining <= 0 {
		return
	}

	rdb := database.GetRedis()
	if rdb == nil {
		return
	}

	cacheKey := jwtCachePrefix + sha256Hex(tokenStr)
	entry := cachedClaims{
		UserID: claims.UserID.String(),
		Exp:    claims.ExpiresAt.Unix(),
	}
	payload, err := json.Marshal(entry)
	if err != nil {
		return
	}

	// Ignored on purpose: a failed SET only means the next request for this
	// token performs full signature verification again.
	_ = rdb.Set(context.Background(), cacheKey, payload, remaining).Err()
}
// sha256Hex returns the lowercase hexadecimal SHA-256 digest of s. The token
// cache uses it to derive fixed-length Redis keys instead of using raw tokens.
func sha256Hex(s string) string {
	hasher := sha256.New()
	hasher.Write([]byte(s)) // hash.Hash.Write never returns an error
	return hex.EncodeToString(hasher.Sum(nil))
}