updates
This commit is contained in:
205
server/internal/middleware/auth.go
Normal file
205
server/internal/middleware/auth.go
Normal file
@@ -0,0 +1,205 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rsa"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"go-server/internal/config"
|
||||
"go-server/internal/database"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// AuthRequired is a middleware to validate JWT token
|
||||
func AuthRequired() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
jwtHeader := c.GetHeader("x-jwt")
|
||||
if jwtHeader == "" {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "authorization header required"})
|
||||
return
|
||||
}
|
||||
|
||||
// Validate the token
|
||||
claims, err := validateToken(jwtHeader)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// Attach user ID to the context
|
||||
c.Set("userID", claims.UserID)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// JWTClaims represents the JWT claims
|
||||
type JWTClaims struct {
|
||||
UserID uuid.UUID `json:"user_id"`
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
var (
|
||||
publicKeyMu sync.RWMutex
|
||||
publicKey *rsa.PublicKey
|
||||
)
|
||||
|
||||
func init() {
|
||||
// Initialize cached public key at startup
|
||||
if cfg, err := config.LoadConfig(); err == nil {
|
||||
if pk, err := parsePublicKeyFromConfig(cfg.JWT.PublicKey); err == nil {
|
||||
publicKeyMu.Lock()
|
||||
publicKey = pk
|
||||
publicKeyMu.Unlock()
|
||||
}
|
||||
}
|
||||
// Update cached key on config changes
|
||||
config.RegisterChangeListener(func(c *config.Config) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
if pk, err := parsePublicKeyFromConfig(c.JWT.PublicKey); err == nil {
|
||||
publicKeyMu.Lock()
|
||||
publicKey = pk
|
||||
publicKeyMu.Unlock()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func parsePublicKeyFromConfig(publicKeyB64 string) (*rsa.PublicKey, error) {
|
||||
if publicKeyB64 == "" {
|
||||
return nil, errors.New("missing JWT public key")
|
||||
}
|
||||
publicKeyBytes, err := base64.StdEncoding.DecodeString(publicKeyB64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode public key: %w", err)
|
||||
}
|
||||
pk, err := jwt.ParseRSAPublicKeyFromPEM(publicKeyBytes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse public key: %w", err)
|
||||
}
|
||||
return pk, nil
|
||||
}
|
||||
|
||||
// validateToken validates and parses the JWT token
|
||||
func validateToken(tokenStr string) (*JWTClaims, error) {
|
||||
// 1) Try Redis cache first
|
||||
if claims := tryGetCachedClaims(tokenStr); claims != nil {
|
||||
// Double-check expiration in case clock skew and guard rails
|
||||
if claims.ExpiresAt != nil && claims.ExpiresAt.Before(time.Now()) {
|
||||
return nil, errors.New("token expired")
|
||||
}
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// 2) Validate signature using cached public key
|
||||
publicKeyMu.RLock()
|
||||
pk := publicKey
|
||||
publicKeyMu.RUnlock()
|
||||
if pk == nil {
|
||||
return nil, errors.New("JWT public key not initialized")
|
||||
}
|
||||
|
||||
token, err := jwt.ParseWithClaims(tokenStr, &JWTClaims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
return pk, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(*JWTClaims)
|
||||
if !ok || !token.Valid {
|
||||
return nil, errors.New("invalid token")
|
||||
}
|
||||
|
||||
// Check token expiration
|
||||
if claims.ExpiresAt.Before(time.Now()) {
|
||||
return nil, errors.New("token expired")
|
||||
}
|
||||
|
||||
// 3) Store in Redis cache with TTL until expiry
|
||||
cacheClaims(tokenStr, claims)
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// --- Redis cache helpers ---
|
||||
|
||||
const jwtCachePrefix = "jwt:cache:"
|
||||
|
||||
type cachedClaims struct {
|
||||
UserID string `json:"user_id"`
|
||||
Exp int64 `json:"exp"`
|
||||
}
|
||||
|
||||
func tryGetCachedClaims(tokenStr string) *JWTClaims {
|
||||
r := database.GetRedis()
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
ctx := context.Background()
|
||||
key := jwtCachePrefix + sha256Hex(tokenStr)
|
||||
raw, err := r.Get(ctx, key).Bytes()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
var cc cachedClaims
|
||||
if err := json.Unmarshal(raw, &cc); err != nil {
|
||||
return nil
|
||||
}
|
||||
if cc.Exp <= time.Now().Unix() {
|
||||
return nil
|
||||
}
|
||||
uid, err := uuid.Parse(cc.UserID)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return &JWTClaims{
|
||||
UserID: uid,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(time.Unix(cc.Exp, 0)),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func cacheClaims(tokenStr string, claims *JWTClaims) {
|
||||
if claims == nil || claims.ExpiresAt == nil {
|
||||
return
|
||||
}
|
||||
ttl := time.Until(claims.ExpiresAt.Time)
|
||||
if ttl <= 0 {
|
||||
return
|
||||
}
|
||||
r := database.GetRedis()
|
||||
if r == nil {
|
||||
return
|
||||
}
|
||||
ctx := context.Background()
|
||||
key := jwtCachePrefix + sha256Hex(tokenStr)
|
||||
payload, err := json.Marshal(cachedClaims{
|
||||
UserID: claims.UserID.String(),
|
||||
Exp: claims.ExpiresAt.Unix(),
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_ = r.Set(ctx, key, payload, ttl).Err()
|
||||
}
|
||||
|
||||
func sha256Hex(s string) string {
|
||||
sum := sha256.Sum256([]byte(s))
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
20
server/internal/middleware/cors.go
Normal file
20
server/internal/middleware/cors.go
Normal file
@@ -0,0 +1,20 @@
|
||||
// internal/middleware/cors.go
|
||||
package middleware
|
||||
|
||||
import "github.com/gin-gonic/gin"
|
||||
|
||||
func CORS() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With, x-jwt")
|
||||
c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE")
|
||||
|
||||
if c.Request.Method == "OPTIONS" {
|
||||
c.AbortWithStatus(204)
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
41
server/internal/middleware/logger.go
Normal file
41
server/internal/middleware/logger.go
Normal file
@@ -0,0 +1,41 @@
|
||||
// internal/middleware/logger.go
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"go-server/pkg/logger"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Logger provides lightweight structured logging and response time header
|
||||
func Logger() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
start := time.Now()
|
||||
|
||||
c.Next()
|
||||
|
||||
latency := time.Since(start)
|
||||
status := c.Writer.Status()
|
||||
method := c.Request.Method
|
||||
path := c.Request.URL.Path
|
||||
ip := c.ClientIP()
|
||||
ua := c.Request.UserAgent()
|
||||
|
||||
// Expose response time for clients/benchmarks
|
||||
c.Header("X-Response-Time", strconv.FormatInt(latency.Microseconds(), 10)+"us")
|
||||
|
||||
// Structured log (zap is very fast and minimally blocking)
|
||||
logger.Info("http_request",
|
||||
zap.Int("status", status),
|
||||
zap.String("method", method),
|
||||
zap.String("path", path),
|
||||
zap.Int64("latency_us", latency.Microseconds()),
|
||||
zap.String("ip", ip),
|
||||
zap.String("user_agent", ua),
|
||||
)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user