206 lines
4.6 KiB
Go
206 lines
4.6 KiB
Go
package middleware
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rsa"
|
|
"crypto/sha256"
|
|
"encoding/base64"
|
|
"encoding/hex"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"net/http"
|
|
"sync"
|
|
"time"
|
|
|
|
"go-server/internal/config"
|
|
"go-server/internal/database"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/golang-jwt/jwt/v5"
|
|
"github.com/google/uuid"
|
|
)
|
|
|
|
// AuthRequired is a middleware to validate JWT token
|
|
func AuthRequired() gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
jwtHeader := c.GetHeader("x-jwt")
|
|
if jwtHeader == "" {
|
|
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "authorization header required"})
|
|
return
|
|
}
|
|
|
|
// Validate the token
|
|
claims, err := validateToken(jwtHeader)
|
|
if err != nil {
|
|
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
|
|
// Attach user ID to the context
|
|
c.Set("userID", claims.UserID)
|
|
c.Next()
|
|
}
|
|
}
|
|
|
|
// JWTClaims represents the JWT claims
|
|
type JWTClaims struct {
|
|
UserID uuid.UUID `json:"user_id"`
|
|
jwt.RegisteredClaims
|
|
}
|
|
|
|
var (
|
|
publicKeyMu sync.RWMutex
|
|
publicKey *rsa.PublicKey
|
|
)
|
|
|
|
func init() {
|
|
// Initialize cached public key at startup
|
|
if cfg, err := config.LoadConfig(); err == nil {
|
|
if pk, err := parsePublicKeyFromConfig(cfg.JWT.PublicKey); err == nil {
|
|
publicKeyMu.Lock()
|
|
publicKey = pk
|
|
publicKeyMu.Unlock()
|
|
}
|
|
}
|
|
// Update cached key on config changes
|
|
config.RegisterChangeListener(func(c *config.Config) {
|
|
if c == nil {
|
|
return
|
|
}
|
|
if pk, err := parsePublicKeyFromConfig(c.JWT.PublicKey); err == nil {
|
|
publicKeyMu.Lock()
|
|
publicKey = pk
|
|
publicKeyMu.Unlock()
|
|
}
|
|
})
|
|
}
|
|
|
|
func parsePublicKeyFromConfig(publicKeyB64 string) (*rsa.PublicKey, error) {
|
|
if publicKeyB64 == "" {
|
|
return nil, errors.New("missing JWT public key")
|
|
}
|
|
publicKeyBytes, err := base64.StdEncoding.DecodeString(publicKeyB64)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to decode public key: %w", err)
|
|
}
|
|
pk, err := jwt.ParseRSAPublicKeyFromPEM(publicKeyBytes)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to parse public key: %w", err)
|
|
}
|
|
return pk, nil
|
|
}
|
|
|
|
// validateToken validates and parses the JWT token
|
|
func validateToken(tokenStr string) (*JWTClaims, error) {
|
|
// 1) Try Redis cache first
|
|
if claims := tryGetCachedClaims(tokenStr); claims != nil {
|
|
// Double-check expiration in case clock skew and guard rails
|
|
if claims.ExpiresAt != nil && claims.ExpiresAt.Before(time.Now()) {
|
|
return nil, errors.New("token expired")
|
|
}
|
|
return claims, nil
|
|
}
|
|
|
|
// 2) Validate signature using cached public key
|
|
publicKeyMu.RLock()
|
|
pk := publicKey
|
|
publicKeyMu.RUnlock()
|
|
if pk == nil {
|
|
return nil, errors.New("JWT public key not initialized")
|
|
}
|
|
|
|
token, err := jwt.ParseWithClaims(tokenStr, &JWTClaims{}, func(token *jwt.Token) (interface{}, error) {
|
|
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
|
|
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
|
}
|
|
return pk, nil
|
|
})
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
claims, ok := token.Claims.(*JWTClaims)
|
|
if !ok || !token.Valid {
|
|
return nil, errors.New("invalid token")
|
|
}
|
|
|
|
// Check token expiration
|
|
if claims.ExpiresAt.Before(time.Now()) {
|
|
return nil, errors.New("token expired")
|
|
}
|
|
|
|
// 3) Store in Redis cache with TTL until expiry
|
|
cacheClaims(tokenStr, claims)
|
|
return claims, nil
|
|
}
|
|
|
|
// --- Redis cache helpers ---
|
|
|
|
const jwtCachePrefix = "jwt:cache:"
|
|
|
|
type cachedClaims struct {
|
|
UserID string `json:"user_id"`
|
|
Exp int64 `json:"exp"`
|
|
}
|
|
|
|
func tryGetCachedClaims(tokenStr string) *JWTClaims {
|
|
r := database.GetRedis()
|
|
if r == nil {
|
|
return nil
|
|
}
|
|
ctx := context.Background()
|
|
key := jwtCachePrefix + sha256Hex(tokenStr)
|
|
raw, err := r.Get(ctx, key).Bytes()
|
|
if err != nil {
|
|
return nil
|
|
}
|
|
var cc cachedClaims
|
|
if err := json.Unmarshal(raw, &cc); err != nil {
|
|
return nil
|
|
}
|
|
if cc.Exp <= time.Now().Unix() {
|
|
return nil
|
|
}
|
|
uid, err := uuid.Parse(cc.UserID)
|
|
if err != nil {
|
|
return nil
|
|
}
|
|
return &JWTClaims{
|
|
UserID: uid,
|
|
RegisteredClaims: jwt.RegisteredClaims{
|
|
ExpiresAt: jwt.NewNumericDate(time.Unix(cc.Exp, 0)),
|
|
},
|
|
}
|
|
}
|
|
|
|
func cacheClaims(tokenStr string, claims *JWTClaims) {
|
|
if claims == nil || claims.ExpiresAt == nil {
|
|
return
|
|
}
|
|
ttl := time.Until(claims.ExpiresAt.Time)
|
|
if ttl <= 0 {
|
|
return
|
|
}
|
|
r := database.GetRedis()
|
|
if r == nil {
|
|
return
|
|
}
|
|
ctx := context.Background()
|
|
key := jwtCachePrefix + sha256Hex(tokenStr)
|
|
payload, err := json.Marshal(cachedClaims{
|
|
UserID: claims.UserID.String(),
|
|
Exp: claims.ExpiresAt.Unix(),
|
|
})
|
|
if err != nil {
|
|
return
|
|
}
|
|
_ = r.Set(ctx, key, payload, ttl).Err()
|
|
}
|
|
|
|
// sha256Hex returns the SHA-256 digest of s as a 64-character lowercase
// hexadecimal string.
func sha256Hex(s string) string {
	h := sha256.New()
	h.Write([]byte(s)) // hash.Hash.Write never returns an error
	return hex.EncodeToString(h.Sum(nil))
}
|