package middleware import ( "context" "crypto/rsa" "crypto/sha256" "encoding/base64" "encoding/hex" "encoding/json" "errors" "fmt" "net/http" "sync" "time" "go-server/internal/config" "go-server/internal/database" "github.com/gin-gonic/gin" "github.com/golang-jwt/jwt/v5" "github.com/google/uuid" ) // AuthRequired is a middleware to validate JWT token func AuthRequired() gin.HandlerFunc { return func(c *gin.Context) { jwtHeader := c.GetHeader("x-jwt") if jwtHeader == "" { c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "authorization header required"}) return } // Validate the token claims, err := validateToken(jwtHeader) if err != nil { c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": err.Error()}) return } // Attach user ID to the context c.Set("userID", claims.UserID) c.Next() } } // JWTClaims represents the JWT claims type JWTClaims struct { UserID uuid.UUID `json:"user_id"` jwt.RegisteredClaims } var ( publicKeyMu sync.RWMutex publicKey *rsa.PublicKey ) func init() { // Initialize cached public key at startup if cfg, err := config.LoadConfig(); err == nil { if pk, err := parsePublicKeyFromConfig(cfg.JWT.PublicKey); err == nil { publicKeyMu.Lock() publicKey = pk publicKeyMu.Unlock() } } // Update cached key on config changes config.RegisterChangeListener(func(c *config.Config) { if c == nil { return } if pk, err := parsePublicKeyFromConfig(c.JWT.PublicKey); err == nil { publicKeyMu.Lock() publicKey = pk publicKeyMu.Unlock() } }) } func parsePublicKeyFromConfig(publicKeyB64 string) (*rsa.PublicKey, error) { if publicKeyB64 == "" { return nil, errors.New("missing JWT public key") } publicKeyBytes, err := base64.StdEncoding.DecodeString(publicKeyB64) if err != nil { return nil, fmt.Errorf("failed to decode public key: %w", err) } pk, err := jwt.ParseRSAPublicKeyFromPEM(publicKeyBytes) if err != nil { return nil, fmt.Errorf("failed to parse public key: %w", err) } return pk, nil } // validateToken validates and parses the JWT token func validateToken(tokenStr string) (*JWTClaims, error) { // 1) Try Redis cache first if claims := tryGetCachedClaims(tokenStr); claims != nil { // Double-check expiration in case clock skew and guard rails if claims.ExpiresAt != nil && claims.ExpiresAt.Before(time.Now()) { return nil, errors.New("token expired") } return claims, nil } // 2) Validate signature using cached public key publicKeyMu.RLock() pk := publicKey publicKeyMu.RUnlock() if pk == nil { return nil, errors.New("JWT public key not initialized") } token, err := jwt.ParseWithClaims(tokenStr, &JWTClaims{}, func(token *jwt.Token) (interface{}, error) { if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok { return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) } return pk, nil }) if err != nil { return nil, err } claims, ok := token.Claims.(*JWTClaims) if !ok || !token.Valid { return nil, errors.New("invalid token") } // Check token expiration if claims.ExpiresAt.Before(time.Now()) { return nil, errors.New("token expired") } // 3) Store in Redis cache with TTL until expiry cacheClaims(tokenStr, claims) return claims, nil } // --- Redis cache helpers --- const jwtCachePrefix = "jwt:cache:" type cachedClaims struct { UserID string `json:"user_id"` Exp int64 `json:"exp"` } func tryGetCachedClaims(tokenStr string) *JWTClaims { r := database.GetRedis() if r == nil { return nil } ctx := context.Background() key := jwtCachePrefix + sha256Hex(tokenStr) raw, err := r.Get(ctx, key).Bytes() if err != nil { return nil } var cc cachedClaims if err := json.Unmarshal(raw, &cc); err != nil { return nil } if cc.Exp <= time.Now().Unix() { return nil } uid, err := uuid.Parse(cc.UserID) if err != nil { return nil } return &JWTClaims{ UserID: uid, RegisteredClaims: jwt.RegisteredClaims{ ExpiresAt: jwt.NewNumericDate(time.Unix(cc.Exp, 0)), }, } } func cacheClaims(tokenStr string, claims *JWTClaims) { if claims == nil || claims.ExpiresAt == nil { return } ttl := time.Until(claims.ExpiresAt.Time) if ttl <= 0 { return } r := database.GetRedis() if r == nil { return } ctx := context.Background() key := jwtCachePrefix + sha256Hex(tokenStr) payload, err := json.Marshal(cachedClaims{ UserID: claims.UserID.String(), Exp: claims.ExpiresAt.Unix(), }) if err != nil { return } _ = r.Set(ctx, key, payload, ttl).Err() } func sha256Hex(s string) string { sum := sha256.Sum256([]byte(s)) return hex.EncodeToString(sum[:]) }