updates
This commit is contained in:
205
server/internal/middleware/auth.go
Normal file
205
server/internal/middleware/auth.go
Normal file
@@ -0,0 +1,205 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rsa"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"go-server/internal/config"
|
||||
"go-server/internal/database"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// AuthRequired is a middleware to validate JWT token
|
||||
func AuthRequired() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
jwtHeader := c.GetHeader("x-jwt")
|
||||
if jwtHeader == "" {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "authorization header required"})
|
||||
return
|
||||
}
|
||||
|
||||
// Validate the token
|
||||
claims, err := validateToken(jwtHeader)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// Attach user ID to the context
|
||||
c.Set("userID", claims.UserID)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// JWTClaims represents the JWT claims
|
||||
type JWTClaims struct {
|
||||
UserID uuid.UUID `json:"user_id"`
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
var (
|
||||
publicKeyMu sync.RWMutex
|
||||
publicKey *rsa.PublicKey
|
||||
)
|
||||
|
||||
func init() {
|
||||
// Initialize cached public key at startup
|
||||
if cfg, err := config.LoadConfig(); err == nil {
|
||||
if pk, err := parsePublicKeyFromConfig(cfg.JWT.PublicKey); err == nil {
|
||||
publicKeyMu.Lock()
|
||||
publicKey = pk
|
||||
publicKeyMu.Unlock()
|
||||
}
|
||||
}
|
||||
// Update cached key on config changes
|
||||
config.RegisterChangeListener(func(c *config.Config) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
if pk, err := parsePublicKeyFromConfig(c.JWT.PublicKey); err == nil {
|
||||
publicKeyMu.Lock()
|
||||
publicKey = pk
|
||||
publicKeyMu.Unlock()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func parsePublicKeyFromConfig(publicKeyB64 string) (*rsa.PublicKey, error) {
|
||||
if publicKeyB64 == "" {
|
||||
return nil, errors.New("missing JWT public key")
|
||||
}
|
||||
publicKeyBytes, err := base64.StdEncoding.DecodeString(publicKeyB64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode public key: %w", err)
|
||||
}
|
||||
pk, err := jwt.ParseRSAPublicKeyFromPEM(publicKeyBytes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse public key: %w", err)
|
||||
}
|
||||
return pk, nil
|
||||
}
|
||||
|
||||
// validateToken validates and parses the JWT token
|
||||
func validateToken(tokenStr string) (*JWTClaims, error) {
|
||||
// 1) Try Redis cache first
|
||||
if claims := tryGetCachedClaims(tokenStr); claims != nil {
|
||||
// Double-check expiration in case clock skew and guard rails
|
||||
if claims.ExpiresAt != nil && claims.ExpiresAt.Before(time.Now()) {
|
||||
return nil, errors.New("token expired")
|
||||
}
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// 2) Validate signature using cached public key
|
||||
publicKeyMu.RLock()
|
||||
pk := publicKey
|
||||
publicKeyMu.RUnlock()
|
||||
if pk == nil {
|
||||
return nil, errors.New("JWT public key not initialized")
|
||||
}
|
||||
|
||||
token, err := jwt.ParseWithClaims(tokenStr, &JWTClaims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
return pk, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(*JWTClaims)
|
||||
if !ok || !token.Valid {
|
||||
return nil, errors.New("invalid token")
|
||||
}
|
||||
|
||||
// Check token expiration
|
||||
if claims.ExpiresAt.Before(time.Now()) {
|
||||
return nil, errors.New("token expired")
|
||||
}
|
||||
|
||||
// 3) Store in Redis cache with TTL until expiry
|
||||
cacheClaims(tokenStr, claims)
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// --- Redis cache helpers ---
|
||||
|
||||
const jwtCachePrefix = "jwt:cache:"
|
||||
|
||||
type cachedClaims struct {
|
||||
UserID string `json:"user_id"`
|
||||
Exp int64 `json:"exp"`
|
||||
}
|
||||
|
||||
func tryGetCachedClaims(tokenStr string) *JWTClaims {
|
||||
r := database.GetRedis()
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
ctx := context.Background()
|
||||
key := jwtCachePrefix + sha256Hex(tokenStr)
|
||||
raw, err := r.Get(ctx, key).Bytes()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
var cc cachedClaims
|
||||
if err := json.Unmarshal(raw, &cc); err != nil {
|
||||
return nil
|
||||
}
|
||||
if cc.Exp <= time.Now().Unix() {
|
||||
return nil
|
||||
}
|
||||
uid, err := uuid.Parse(cc.UserID)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return &JWTClaims{
|
||||
UserID: uid,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(time.Unix(cc.Exp, 0)),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func cacheClaims(tokenStr string, claims *JWTClaims) {
|
||||
if claims == nil || claims.ExpiresAt == nil {
|
||||
return
|
||||
}
|
||||
ttl := time.Until(claims.ExpiresAt.Time)
|
||||
if ttl <= 0 {
|
||||
return
|
||||
}
|
||||
r := database.GetRedis()
|
||||
if r == nil {
|
||||
return
|
||||
}
|
||||
ctx := context.Background()
|
||||
key := jwtCachePrefix + sha256Hex(tokenStr)
|
||||
payload, err := json.Marshal(cachedClaims{
|
||||
UserID: claims.UserID.String(),
|
||||
Exp: claims.ExpiresAt.Unix(),
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_ = r.Set(ctx, key, payload, ttl).Err()
|
||||
}
|
||||
|
||||
func sha256Hex(s string) string {
|
||||
sum := sha256.Sum256([]byte(s))
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
Reference in New Issue
Block a user