aurganize-backend/backend/internal/services/auth_services.go

1214 lines
42 KiB
Go

package services
import (
"context"
"crypto/rand"
"encoding/base64"
"errors"
"strings"
"time"
"github.com/creativenoz/aurganize-v62/backend/internal/config"
"github.com/creativenoz/aurganize-v62/backend/internal/models"
"github.com/creativenoz/aurganize-v62/backend/internal/repositories"
"github.com/creativenoz/aurganize-v62/backend/pkg/auth"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
"github.com/rs/zerolog/log"
)
// Sentinel errors returned by AuthService.
//
// They are declared once at package level so that callers can:
//   - compare against them with errors.Is, even after they have been wrapped
//     with context, e.g. fmt.Errorf("refreshing session: %w", ErrInvalidToken);
//   - rely on a single shared instance in tests;
//   - see every authentication failure mode documented in one place.
//
// Plain errors.New values are used rather than custom error types because
// none of these failures needs to carry extra data.
var (
	// ErrInvalidToken means the token cannot be trusted: it is malformed, its
	// signature does not verify (tampering, or a different signing secret),
	// its claims cannot be decoded, or the state behind it (such as its
	// session) could not be found.
	ErrInvalidToken = errors.New("invalid token")

	// ErrExpiredToken means the token's lifetime is over. This is expected:
	// a client whose access token expired should obtain a new one using its
	// refresh token.
	ErrExpiredToken = errors.New("token has been expired")

	// ErrRevokedToken means the token was deliberately invalidated before its
	// expiry, for example on logout, an administrator revoking the session,
	// a password change, or a suspected security breach.
	ErrRevokedToken = errors.New("token has been revoked")

	// ErrInvalidTokenType means a token of the wrong kind was presented. The
	// service issues "access" and "refresh" tokens, and a refresh token must
	// never be accepted where an access token is required.
	ErrInvalidTokenType = errors.New("invalid token type")
)
// AuthService holds the authentication business logic. It sits between the
// HTTP handlers and the repositories: handlers call the service, and the
// service talks to the data layer.
//
// It is responsible for:
//   - issuing access and refresh tokens;
//   - validating tokens (signature, expiry, type and, for refresh tokens,
//     the backing session);
//   - managing sessions (creation, last-used tracking, revocation).
//
// Keeping this logic out of the handlers lets several handlers share it and
// allows it to be tested without an HTTP layer.
type AuthService struct {
	// config supplies the JWT signing secrets and token lifetimes.
	config *config.Config

	// sessionRepo persists refresh-token sessions (create, look up, touch,
	// revoke).
	sessionRepo *repositories.SessionRepository

	// userRepo loads user records, e.g. when a refresh token is rotated.
	userRepo *repositories.UserRepository
}
// NewAuthService wires an AuthService to its dependencies.
//
// All dependencies are passed in explicitly, so callers decide which
// configuration and repository instances the service uses. The configured
// token lifetimes, and whether each repository was supplied, are logged once
// at start-up to make misconfiguration easy to spot.
//
// Parameters:
//   - config: application configuration; its JWT section supplies the signing
//     secrets and token lifetimes.
//   - sessionRepo: persistence for refresh-token sessions.
//   - userRepo: persistence for users.
//
// No argument is validated; a nil config panics when the lifetimes are logged.
func NewAuthService(config *config.Config, sessionRepo *repositories.SessionRepository, userRepo *repositories.UserRepository) *AuthService {
	svc := &AuthService{
		config:      config,
		sessionRepo: sessionRepo,
		userRepo:    userRepo,
	}

	hasSessionRepo := sessionRepo != nil
	hasUserRepo := userRepo != nil

	log.Info().
		Str("service", "auth").
		Str("component", "service_init").
		Dur("access_expiry", config.JWT.AccessExpiry).
		Dur("refresh_expiry", config.JWT.RefreshExpiry).
		Bool("has_session_repo", hasSessionRepo).
		Bool("has_user_repo", hasUserRepo).
		Msg("auth service initialized with JWT configuration")

	return svc
}
// GenerateAccessToken creates a signed, short-lived JWT access token for user.
//
// Access tokens authenticate individual API requests. They are stateless:
// ValidateAccessToken verifies them cryptographically without a database
// lookup, which also means they cannot be revoked early. Their lifetime
// (config.JWT.AccessExpiry) should therefore be kept short, so that a stolen
// token is only useful briefly.
//
// The token is signed with HS256 (HMAC-SHA256) using config.JWT.AccessSecret
// and carries these claims:
//   - UserID, TenantID, Email and Role: identity and authorization data.
//   - TokenType "access": lets validators reject a refresh token that is
//     presented as an access token.
//   - exp: now + AccessExpiry; iat and nbf: now.
//   - iss: "aurganize-v62-api"; sub: the user ID as a string.
//
// The payload is only base64url-encoded, not encrypted, so anyone holding the
// token can read it. Never add secrets to these claims.
//
// It returns the compact "header.payload.signature" string. If signing fails,
// it returns an empty string and the signing error.
func (a *AuthService) GenerateAccessToken(user *models.User) (string, error) {
	log.Debug().
		Str("service", "auth").
		Str("action", "generate_access_token_started").
		Str("user_id", user.ID.String()).
		Str("tenant_id", user.TenantID.String()).
		Str("role", user.Role).
		Str("token_type", "access").
		Dur("expiry_duration", a.config.JWT.AccessExpiry).
		Msg("generating access token")

	now := time.Now()
	expiresAt := now.Add(a.config.JWT.AccessExpiry)

	claims := auth.AccessTokenClaims{
		// Application-specific claims.
		UserID:    user.ID,
		TenantID:  user.TenantID,
		Email:     user.Email,
		Role:      user.Role,
		TokenType: "access",
		// Registered JWT claims.
		RegisteredClaims: jwt.RegisteredClaims{
			ExpiresAt: jwt.NewNumericDate(expiresAt),
			IssuedAt:  jwt.NewNumericDate(now),
			NotBefore: jwt.NewNumericDate(now),
			Issuer:    "aurganize-v62-api",
			Subject:   user.ID.String(),
		},
	}

	// Only a holder of AccessSecret can produce a signature that
	// ValidateAccessToken will accept.
	token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
	signedToken, err := token.SignedString([]byte(a.config.JWT.AccessSecret))
	if err != nil {
		log.Error().
			Str("service", "auth").
			Str("action", "generate_access_token_failed").
			Str("user_id", user.ID.String()).
			Str("tenant_id", user.TenantID.String()).
			Err(err).
			Msg("CRITICAL: failed to sign access token - authentication broken")
		// Propagate the error. Returning ("", nil) would let callers hand out
		// an empty token as if generation had succeeded.
		return "", err
	}

	log.Debug().
		Str("service", "auth").
		Str("action", "generate_access_token_success").
		Str("user_id", user.ID.String()).
		Str("tenant_id", user.TenantID.String()).
		Str("role", user.Role).
		Time("expires_at", expiresAt).
		Int("token_length", len(signedToken)).
		Msg("access token generated successfully")

	return signedToken, nil
}
// GenerateRefreshToken creates a new session for user and returns a signed
// refresh token bound to it.
//
// Refresh tokens are long-lived (config.JWT.RefreshExpiry) and are used only to
// obtain new access tokens. Unlike access tokens, they are backed by a session
// row in the database, so they can be revoked (logout, rotation, revoking all
// of a user's sessions).
//
// Steps:
//  1. Generate 32 cryptographically random bytes and base64url-encode them.
//     This opaque value is the token ID.
//  2. Create a session row (user, token ID, user agent, IP address, detected
//     device type, expiry) through sessionRepo.Create.
//  3. Put the user ID, session ID and token ID into a JWT of type "refresh",
//     signed with HS256 and config.JWT.RefreshSecret.
//
// The token ID travels inside the JWT so that ValidateRefreshToken can look up
// the matching session. NOTE(review): the repository is expected to store only
// a hash of the token ID; confirm this in SessionRepository.Create.
//
// Parameters:
//   - ctx: context for the database calls.
//   - user: the user being logged in.
//   - userAgent, ipAddress: optional request metadata recorded on the session;
//     either may be nil.
//
// It returns the signed JWT and the created session. On any failure it returns
// ("", nil, err). If signing fails after the session was created, that session
// is revoked (best effort) so it does not linger unused.
func (a *AuthService) GenerateRefreshToken(ctx context.Context, user *models.User, userAgent *string, ipAddress *string) (string, *models.Session, error) {
	// Plain-string copies of the optional metadata, used only for logging.
	var ipStr, uaStr string
	if ipAddress != nil {
		ipStr = *ipAddress
	}
	if userAgent != nil {
		uaStr = *userAgent
	}
	// Detected once and reused for both the logs and the session record.
	deviceType := detectDeviceType(userAgent)

	log.Info().
		Str("service", "auth").
		Str("action", "generate_refresh_token_started").
		Str("user_id", user.ID.String()).
		Str("tenant_id", user.TenantID.String()).
		Str("email", user.Email).
		Str("device_type", deviceType).
		Str("ip", ipStr).
		Str("user_agent", uaStr).
		Dur("expiry_duration", a.config.JWT.RefreshExpiry).
		Msg("generating refresh token and creating session")

	log.Debug().
		Str("service", "auth").
		Str("action", "generating_random_token_id").
		Int("token_bytes", 32).
		Msg("generating cryptographically secure random token id")

	// Step 1: opaque token ID from the OS CSPRNG (crypto/rand, not math/rand),
	// so the value cannot be predicted.
	tokenBytes := make([]byte, 32)
	if _, err := rand.Read(tokenBytes); err != nil {
		log.Error().
			Str("service", "auth").
			Str("action", "random_token_generation_failed").
			Str("user_id", user.ID.String()).
			Err(err).
			Msg("CRITICAL: failed to generate random token bytes - crypto RNG issue")
		return "", nil, err
	}
	// URLEncoding uses '-' and '_' instead of '+' and '/', so the value is safe
	// in URLs and cookies. 32 bytes encode to 44 characters, including one '='
	// padding character.
	refreshToken := base64.URLEncoding.EncodeToString(tokenBytes)

	// The session and the JWT share the same expiry.
	now := time.Now()
	expiresAt := now.Add(a.config.JWT.RefreshExpiry)

	log.Debug().
		Str("service", "auth").
		Str("action", "creating_session_record").
		Str("user_id", user.ID.String()).
		Str("device_type", deviceType).
		Msg("creating session record in database")

	// Step 2: persist the session that makes this refresh token revocable.
	session, err := a.sessionRepo.Create(ctx, &models.CreateSessionInput{
		UserID:       user.ID,
		RefreshToken: refreshToken,
		UserAgent:    userAgent,
		IPAddress:    ipAddress,
		DeviceType:   deviceType,
		ExpiresAt:    expiresAt,
	})
	if err != nil {
		log.Error().
			Str("service", "auth").
			Str("action", "session_creation_failed").
			Str("user_id", user.ID.String()).
			Str("device_type", deviceType).
			Err(err).
			Msg("failed to create session in database - login will fail")
		return "", nil, err
	}

	// Step 3: bind the JWT to the session via its ID and the token ID.
	claims := auth.RefreshTokenClaims{
		UserID:    user.ID,
		SessionID: session.ID,
		TokenID:   refreshToken,
		TokenType: "refresh",
		RegisteredClaims: jwt.RegisteredClaims{
			ExpiresAt: jwt.NewNumericDate(expiresAt),
			IssuedAt:  jwt.NewNumericDate(now),
			NotBefore: jwt.NewNumericDate(now),
			Issuer:    "aurganize-v62-api",
			Subject:   user.ID.String(),
		},
	}

	token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
	signedToken, err := token.SignedString([]byte(a.config.JWT.RefreshSecret))
	if err != nil {
		log.Error().
			Str("service", "auth").
			Str("action", "refresh_token_signing_failed").
			Str("user_id", user.ID.String()).
			Str("session_id", session.ID.String()).
			Err(err).
			Msg("CRITICAL: failed to sign refresh token JWT - revoking unusable session")
		// The session row already exists but no client will ever hold a token
		// for it. Revoke it so it does not linger as an active session. This is
		// best effort: the signing error is what the caller needs to see.
		if revokeErr := a.sessionRepo.Revoke(ctx, refreshToken, "token_signing_failed"); revokeErr != nil {
			log.Warn().
				Str("service", "auth").
				Str("action", "orphan_session_revoke_failed").
				Str("session_id", session.ID.String()).
				Err(revokeErr).
				Msg("failed to revoke session left unusable by signing failure")
		}
		return "", nil, err
	}

	log.Info().
		Str("service", "auth").
		Str("action", "generate_refresh_token_success").
		Str("user_id", user.ID.String()).
		Str("tenant_id", user.TenantID.String()).
		Str("email", user.Email).
		Str("session_id", session.ID.String()).
		Str("device_type", deviceType).
		Str("ip", ipStr).
		Time("expires_at", expiresAt).
		Int("token_length", len(signedToken)).
		Msg("refresh token generated successfully - user logged in")

	return signedToken, session, nil
}
// ValidateAccessToken verifies an access token and returns its claims.
//
// It is meant to run on every authenticated request, so it performs no
// database lookup: validity is established purely from the signature and the
// claims. As a consequence, an access token stays usable until it expires,
// even after its session has been revoked.
//
// Checks, in order:
//  1. The signing method is HMAC. Anything else, including "none", is rejected
//     before a key is returned, which prevents algorithm-substitution attacks.
//  2. The signature verifies against config.JWT.AccessSecret.
//  3. exp is present and in the future. iat, when present, is not in the
//     future, and nbf, when present, has passed.
//  4. iss is "aurganize-v62-api", the issuer GenerateAccessToken sets.
//  5. TokenType is "access", so a refresh token cannot be used here.
//
// Parameters:
//   - tokenString: the compact JWT, e.g. from the Authorization header or a
//     cookie.
//
// Returns:
//   - (claims, nil) when the token is valid.
//   - (nil, ErrExpiredToken) when the token has expired; the client should
//     refresh it.
//   - (nil, ErrInvalidTokenType) when the token is not an access token.
//   - (nil, ErrInvalidToken) for any other failure (malformed token, bad
//     signature, or wrong algorithm or issuer).
func (a *AuthService) ValidateAccessToken(tokenString string) (*auth.AccessTokenClaims, error) {
	log.Debug().
		Str("service", "auth").
		Str("action", "validate_access_token_started").
		Int("token_length", len(tokenString)).
		Msg("validating access token")

	token, err := jwt.ParseWithClaims(
		tokenString,
		&auth.AccessTokenClaims{},
		func(token *jwt.Token) (interface{}, error) {
			// Refuse to hand out the key for non-HMAC algorithms ("none",
			// RSA/ECDSA confusion attacks).
			if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
				log.Warn().
					Str("service", "auth").
					Str("action", "invalid_signing_algorithm").
					Str("algorithm", token.Method.Alg()).
					Msg("token validation failed - invalid signing algorithm (possible attack)")
				return nil, ErrInvalidToken
			}
			return []byte(a.config.JWT.AccessSecret), nil
		},
		jwt.WithExpirationRequired(),
		jwt.WithIssuedAt(),
		// Only accept tokens minted by GenerateAccessToken.
		jwt.WithIssuer("aurganize-v62-api"),
		jwt.WithTimeFunc(time.Now),
	)
	if err != nil {
		// Expiry is reported separately so the client knows to refresh.
		if errors.Is(err, jwt.ErrTokenExpired) {
			log.Warn().
				Str("service", "auth").
				Str("action", "access_token_expired").
				Msg("access token validation failed - token expired")
			return nil, ErrExpiredToken
		}
		log.Warn().
			Str("service", "auth").
			Str("action", "access_token_validation_failed").
			Err(err).
			Msg("access token validation failed - invalid token")
		return nil, ErrInvalidToken
	}

	claims, ok := token.Claims.(*auth.AccessTokenClaims)
	if !ok || !token.Valid {
		log.Warn().
			Str("service", "auth").
			Str("action", "access_token_claims_invalid").
			Msg("access token validation failed - claims invalid or token not valid")
		return nil, ErrInvalidToken
	}

	if claims.TokenType != "access" {
		log.Warn().
			Str("service", "auth").
			Str("action", "wrong_token_type_for_access").
			Str("provided_type", claims.TokenType).
			Str("expected_type", "access").
			Msg("token validation failed - non-access token used as access token")
		return nil, ErrInvalidTokenType
	}

	// iat is optional at the parser level. Guard the pointer so a validly
	// signed token without iat cannot crash the request.
	var tokenAge time.Duration
	if claims.IssuedAt != nil {
		tokenAge = time.Since(claims.IssuedAt.Time)
	}

	log.Debug().
		Str("service", "auth").
		Str("action", "validate_access_token_success").
		Str("user_id", claims.UserID.String()).
		Str("tenant_id", claims.TenantID.String()).
		Str("role", claims.Role).
		Dur("token_age", tokenAge).
		Msg("access token validated successfully")

	return claims, nil
}
// ValidateRefreshToken verifies a refresh token and returns its claims together
// with the session it belongs to.
//
// Unlike access tokens, refresh tokens are checked against the database, so
// logout, revoking all sessions, and rotation take effect immediately. The
// extra query is acceptable because refreshes are rare compared with ordinary
// API requests.
//
// Checks, in order:
//  1. HMAC signing method and a valid signature under config.JWT.RefreshSecret
//     (a different secret from the one used for access tokens).
//  2. exp is present and in the future; iat and nbf are not in the future when
//     present.
//  3. iss is "aurganize-v62-api", the issuer GenerateRefreshToken sets.
//  4. TokenType is "refresh".
//  5. A session matching the claims' session ID and token ID exists.
//  6. That session is not revoked and has not expired.
//
// On success, the session's last-used timestamp is updated. A failure to
// update it is logged but does not fail validation.
//
// Returns:
//   - (claims, session, nil) when the token is valid.
//   - (nil, nil, ErrExpiredToken) when the token or its session has expired.
//   - (nil, nil, ErrRevokedToken) when the session has been revoked.
//   - (nil, nil, ErrInvalidToken) when the token is malformed, has a bad
//     signature, or the wrong algorithm, issuer or type, or when no matching
//     session exists.
//   - (nil, nil, err) with the repository error if the session lookup fails.
func (a *AuthService) ValidateRefreshToken(ctx context.Context, tokenString string) (*auth.RefreshTokenClaims, *models.Session, error) {
	log.Info().
		Str("service", "auth").
		Str("action", "validate_refresh_token_started").
		Int("token_length", len(tokenString)).
		Msg("validating refresh token")

	// Step 1: parse and verify the JWT.
	token, err := jwt.ParseWithClaims(
		tokenString,
		&auth.RefreshTokenClaims{},
		func(token *jwt.Token) (interface{}, error) {
			if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
				log.Warn().
					Str("service", "auth").
					Str("action", "invalid_signing_algorithm_refresh").
					Str("algorithm", token.Method.Alg()).
					Msg("refresh token validation failed - invalid signing algorithm")
				return nil, ErrInvalidToken
			}
			// Refresh tokens use their own secret, distinct from access tokens.
			return []byte(a.config.JWT.RefreshSecret), nil
		},
		jwt.WithExpirationRequired(),
		jwt.WithIssuedAt(),
		// Only accept tokens minted by GenerateRefreshToken.
		jwt.WithIssuer("aurganize-v62-api"),
		jwt.WithTimeFunc(time.Now),
	)
	if err != nil {
		if errors.Is(err, jwt.ErrTokenExpired) {
			log.Warn().
				Str("service", "auth").
				Str("action", "refresh_token_expired").
				Msg("refresh token validation failed - token expired, user must re-login")
			return nil, nil, ErrExpiredToken
		}
		log.Warn().
			Str("service", "auth").
			Str("action", "refresh_token_validation_failed").
			Err(err).
			Msg("refresh token validation failed - invalid token")
		return nil, nil, ErrInvalidToken
	}

	// Step 2: extract the claims.
	claims, ok := token.Claims.(*auth.RefreshTokenClaims)
	if !ok || !token.Valid {
		log.Warn().
			Str("service", "auth").
			Str("action", "refresh_token_claims_invalid").
			Msg("refresh token validation failed - claims invalid")
		return nil, nil, ErrInvalidToken
	}

	// Step 3: only refresh tokens may be used here.
	if claims.TokenType != "refresh" {
		log.Warn().
			Str("service", "auth").
			Str("action", "wrong_token_type_for_refresh").
			Str("provided_type", claims.TokenType).
			Msg("token validation failed - access token used as refresh token")
		return nil, nil, ErrInvalidToken
	}

	log.Debug().
		Str("service", "auth").
		Str("action", "looking_up_session").
		Str("session_id", claims.SessionID.String()).
		Str("user_id", claims.UserID.String()).
		Msg("looking up session in database for refresh token")

	// Step 4: find the session the token is bound to.
	session, err := a.sessionRepo.FindBySessionIDAndToken(ctx, claims.SessionID, claims.TokenID)
	if err != nil {
		log.Error().
			Str("service", "auth").
			Str("action", "session_lookup_error").
			Str("session_id", claims.SessionID.String()).
			Err(err).
			Msg("database error during session lookup")
		return nil, nil, err
	}
	if session == nil {
		log.Warn().
			Str("service", "auth").
			Str("action", "session_not_found").
			Str("session_id", claims.SessionID.String()).
			Str("user_id", claims.UserID.String()).
			Msg("session not found - may be revoked, expired, or invalid")
		return nil, nil, ErrInvalidToken
	}

	// Step 5: explicit state checks, independent of any filtering the
	// repository query may already do.
	if session.IsRevoked {
		revokedReason := "unknown"
		if session.RevokedReason != nil {
			revokedReason = *session.RevokedReason
		}
		log.Warn().
			Str("service", "auth").
			Str("action", "session_revoked").
			Str("session_id", session.ID.String()).
			Str("user_id", session.UserID.String()).
			Str("revoked_reason", revokedReason).
			Msg("session is revoked - refresh token invalid")
		return nil, nil, ErrRevokedToken
	}
	if session.ExpiresAt.Before(time.Now()) {
		log.Warn().
			Str("service", "auth").
			Str("action", "session_expired").
			Str("session_id", session.ID.String()).
			Time("expired_at", session.ExpiresAt).
			Msg("session has expired")
		// An expired session is an expiry, not a revocation. This matches
		// ValidateRefreshTokenWithRotationCheck.
		return nil, nil, ErrExpiredToken
	}

	// Step 6: record activity. This is not critical for validation, so a
	// failure is logged rather than returned.
	if err := a.sessionRepo.UpdateLastUsed(ctx, session.ID); err != nil {
		log.Warn().
			Str("service", "auth").
			Str("action", "session_touch_failed").
			Str("session_id", session.ID.String()).
			Err(err).
			Msg("failed to update session last-used timestamp")
	}

	log.Info().
		Str("service", "auth").
		Str("action", "validate_refresh_token_success").
		Str("user_id", claims.UserID.String()).
		Str("session_id", claims.SessionID.String()).
		Time("session_last_used", session.LastUsedAt).
		Msg("refresh token validated successfully")

	return claims, session, nil
}
// RotateRefreshToken exchanges a valid refresh token for a new access token and
// a new refresh token, and invalidates the old refresh token.
//
// Steps:
//  1. Validate the old refresh token (signature, claims and session) with
//     ValidateRefreshToken.
//  2. Load the token's user.
//  3. Revoke the old session with reason "token_rotation". This happens
//     synchronously, before any new credential is issued, so a failed
//     revocation can never leave the old and new refresh tokens valid at the
//     same time.
//  4. Issue a new access token and a new refresh token (a new session).
//
// Rotation limits how long a stolen refresh token stays useful. The
// "token_rotation" revocation reason also lets
// ValidateRefreshTokenWithRotationCheck recognize a rotated token that is
// replayed later.
//
// If issuing the new tokens fails after step 3, the user has to log in again.
// That outcome is preferred over leaving the old token usable.
//
// Parameters:
//   - ctx: context for the database calls.
//   - oldTokenString: the refresh JWT presented by the client.
//   - userAgent, ipAddress: optional metadata for the new session.
//
// It returns the new access token, the new refresh token and the new session.
// On failure it returns the validation error from step 1, a repository or
// signing error, or ErrInvalidToken if the user no longer exists.
func (a *AuthService) RotateRefreshToken(ctx context.Context, oldTokenString string, userAgent *string, ipAddress *string) (string, string, *models.Session, error) {
	log.Info().
		Str("service", "auth").
		Str("action", "rotate_refresh_token_started").
		Msg("starting refresh token rotation")

	// Step 1: the old token must be fully valid to be rotated.
	claims, _, err := a.ValidateRefreshToken(ctx, oldTokenString)
	if err != nil {
		log.Warn().
			Str("service", "auth").
			Str("action", "rotation_validation_failed").
			Err(err).
			Msg("token rotation failed - old token invalid")
		return "", "", nil, err
	}

	// Step 2: load the user the new tokens will describe.
	user, err := a.userRepo.FindByID(ctx, claims.UserID)
	if err != nil {
		log.Error().
			Str("service", "auth").
			Str("action", "rotation_user_lookup_failed").
			Str("user_id", claims.UserID.String()).
			Err(err).
			Msg("token rotation failed - database error")
		return "", "", nil, err
	}
	if user == nil {
		log.Error().
			Str("service", "auth").
			Str("action", "rotation_user_not_found").
			Str("user_id", claims.UserID.String()).
			Msg("token rotation failed - user not found")
		return "", "", nil, ErrInvalidToken
	}

	// Step 3: revoke the old session before issuing anything new. Using the
	// request context is safe here: if it is cancelled, revocation fails and
	// no new tokens are issued.
	log.Info().
		Str("service", "auth").
		Str("action", "revoking_old_session_for_rotation").
		Str("old_session_id", claims.SessionID.String()).
		Str("user_id", user.ID.String()).
		Msg("revoking old session before issuing rotated tokens")
	if err := a.sessionRepo.Revoke(ctx, claims.TokenID, "token_rotation"); err != nil {
		log.Error().
			Str("service", "auth").
			Str("action", "rotation_revoke_failed").
			Str("old_session_id", claims.SessionID.String()).
			Str("user_id", user.ID.String()).
			Err(err).
			Msg("token rotation failed - could not revoke old session")
		return "", "", nil, err
	}

	// Step 4: issue the replacement credentials.
	newAccessToken, err := a.GenerateAccessToken(user)
	if err != nil {
		return "", "", nil, err
	}
	newRefreshToken, newSession, err := a.GenerateRefreshToken(ctx, user, userAgent, ipAddress)
	if err != nil {
		return "", "", nil, err
	}

	log.Info().
		Str("service", "auth").
		Str("action", "rotate_refresh_token_success").
		Str("user_id", user.ID.String()).
		Str("old_session_id", claims.SessionID.String()).
		Str("new_session_id", newSession.ID.String()).
		Msg("refresh token rotated successfully")

	return newAccessToken, newRefreshToken, newSession, nil
}
// ValidateRefreshTokenWithRotationCheck validates a refresh token like
// ValidateRefreshToken and, in addition, treats reuse of a rotated token as a
// sign of theft.
//
// Attack scenario:
//  1. The legitimate client rotates its token, and the old session is revoked
//     with reason "token_rotation".
//  2. Someone presents the old (stolen) token again.
//  3. This function sees a revoked session whose reason is "token_rotation"
//     and revokes every session of that user ("potential_token_theft"). Both
//     the attacker and the legitimate user must then log in again.
//
// The mass revocation runs synchronously on a context detached from ctx, so a
// client that drops the connection cannot abort it. If it fails, the failure
// is logged.
//
// NOTE(review): detection only fires if sessionRepo.FindBySessionIDAndToken
// returns revoked sessions rather than filtering them out. Confirm this against
// the repository.
//
// Returns:
//   - (claims, session, nil) when the token is valid.
//   - (nil, nil, ErrExpiredToken) when the token or its session has expired.
//   - (nil, nil, ErrRevokedToken) when the session has been revoked, whether
//     or not theft was suspected.
//   - (nil, nil, ErrInvalidToken) when the token is malformed, has a bad
//     signature, or the wrong algorithm, issuer or type, or when no matching
//     session exists.
//   - (nil, nil, err) with the repository error if the session lookup fails.
func (a *AuthService) ValidateRefreshTokenWithRotationCheck(ctx context.Context, tokenString string) (*auth.RefreshTokenClaims, *models.Session, error) {
	log.Info().
		Str("service", "auth").
		Str("action", "validate_with_theft_detection_started").
		Msg("validating refresh token with rotation theft detection")

	token, err := jwt.ParseWithClaims(
		tokenString,
		&auth.RefreshTokenClaims{},
		func(token *jwt.Token) (interface{}, error) {
			// Only HMAC-signed tokens can be ours.
			if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
				return nil, ErrInvalidToken
			}
			return []byte(a.config.JWT.RefreshSecret), nil
		},
		jwt.WithExpirationRequired(),
		jwt.WithIssuedAt(),
		// Only accept tokens minted by GenerateRefreshToken.
		jwt.WithIssuer("aurganize-v62-api"),
		jwt.WithTimeFunc(time.Now),
	)
	if err != nil {
		if errors.Is(err, jwt.ErrTokenExpired) {
			return nil, nil, ErrExpiredToken
		}
		return nil, nil, ErrInvalidToken
	}

	claims, ok := token.Claims.(*auth.RefreshTokenClaims)
	if !ok || !token.Valid {
		return nil, nil, ErrInvalidToken
	}
	if claims.TokenType != "refresh" {
		return nil, nil, ErrInvalidToken
	}

	session, err := a.sessionRepo.FindBySessionIDAndToken(ctx, claims.SessionID, claims.TokenID)
	if err != nil {
		return nil, nil, err
	}
	if session == nil {
		return nil, nil, ErrInvalidToken
	}

	// THEFT DETECTION: the JWT itself is still valid (not expired), but its
	// session was revoked by rotation. Whoever presents it holds a token that
	// should no longer be in circulation.
	if session.IsRevoked {
		if session.RevokedReason != nil && *session.RevokedReason == "token_rotation" {
			ip := "unknown"
			if session.IPAddress != nil {
				ip = *session.IPAddress
			}
			log.Error().
				Str("service", "auth").
				Str("action", "token_theft_detected").
				Str("user_id", session.UserID.String()).
				Str("session_id", session.ID.String()).
				Str("ip", ip).
				Msg("SECURITY ALERT: Revoked token reused after rotation - possible token theft! Revoking all user sessions")

			// Detached from ctx so the requester cannot cancel the revocation
			// by dropping the connection. Synchronous so a failure is observed.
			if err := a.sessionRepo.RevokeByUserId(context.Background(), session.UserID, "potential_token_theft"); err != nil {
				log.Error().
					Str("service", "auth").
					Str("action", "theft_revocation_failed").
					Str("user_id", session.UserID.String()).
					Err(err).
					Msg("failed to revoke user sessions after suspected token theft")
			}
			// TODO: send a security alert email/notification to the user.
		}
		return nil, nil, ErrRevokedToken
	}

	if session.ExpiresAt.Before(time.Now()) {
		return nil, nil, ErrExpiredToken
	}

	// Record activity. This is not critical for validation, so a failure is
	// logged rather than returned.
	if err := a.sessionRepo.UpdateLastUsed(ctx, session.ID); err != nil {
		log.Warn().
			Str("service", "auth").
			Str("action", "session_touch_failed").
			Str("session_id", session.ID.String()).
			Err(err).
			Msg("failed to update session last-used timestamp")
	}

	return claims, session, nil
}
// RevokeRefreshToken revokes the session behind a refresh token (logout).
// It is called when a user logs out from the current device.
//
// What happens:
//  1. Parse the JWT and verify its HMAC signature with the refresh secret
//  2. Check the claims are refresh-token claims (token type "refresh"),
//     the same check ValidateRefreshTokenWithRotationCheck performs
//  3. Revoke the session identified by the token ID, with reason "user_logout"
//
// Why parse the JWT if we're revoking?
//   - The token ID needed to locate the session lives in the claims
//   - Verifying the signature ensures we only act on tokens we issued
//
// Expired tokens:
// The parser enforces expiry (jwt.WithExpirationRequired plus the default exp
// validation). An expired refresh token therefore fails parsing and is
// reported as an error. It is not revoked.
//
// After revocation:
//   - The token's session is marked revoked with reason "user_logout"
//   - Access tokens that were already issued are not touched by this call
//
// Error handling:
//   - The jwt parse/validation error is returned unchanged (after a warning log)
//   - ErrInvalidToken if the claims are not refresh-token claims
//   - The repository error is returned unchanged (after an error log)
//   - Re-revoking an already-revoked session: presumably tolerated by
//     sessionRepo.Revoke (confirm in the repository)
//
// Parameters:
//   - ctx: Context for database operations
//   - tokenJWT: Refresh token JWT string to revoke
//
// Return values:
//   - nil: Session successfully revoked
//   - error: Failed to parse/validate the token or update the database
func (a *AuthService) RevokeRefreshToken(ctx context.Context, tokenJWT string) error {
	log.Info().
		Str("service", "auth").
		Str("action", "revoke_refresh_token_started").
		Msg("revoking refresh token (logout)")

	// Step 1: Parse and verify the JWT. Even though we only revoke, the signature
	// must be checked so we never act on a forged token ID.
	token, err := jwt.ParseWithClaims(
		tokenJWT,
		&auth.RefreshTokenClaims{},
		func(token *jwt.Token) (interface{}, error) {
			// Only accept HMAC-signed tokens (guards against algorithm substitution).
			if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
				return nil, ErrInvalidToken
			}
			return []byte(a.config.JWT.RefreshSecret), nil
		},
		// Same validation options as token validation, so exp is required and
		// enforced (expired tokens are rejected here) and iat is checked.
		jwt.WithExpirationRequired(),
		jwt.WithIssuedAt(),
		jwt.WithTimeFunc(time.Now),
	)
	if err != nil {
		log.Warn().
			Str("service", "auth").
			Str("action", "revoke_token_parse_failed").
			Err(err).
			Msg("failed to parse token for revocation")
		// This includes expired tokens: they cannot be revoked through this path.
		return err
	}

	// Step 2: Extract the claims and make sure this is a refresh token.
	claims, ok := token.Claims.(*auth.RefreshTokenClaims)
	if !ok || claims.TokenType != "refresh" {
		return ErrInvalidToken
	}

	// Step 3: Revoke the session in the database, located by token ID, with
	// reason "user_logout" for the audit trail.
	if err := a.sessionRepo.Revoke(ctx, claims.TokenID, "user_logout"); err != nil {
		log.Error().
			Str("service", "auth").
			Str("action", "revoke_refresh_token_failed").
			Str("session_id", claims.SessionID.String()).
			Err(err).
			Msg("failed to revoke refresh token")
		return err
	}

	log.Info().
		Str("service", "auth").
		Str("action", "revoke_refresh_token_success").
		Str("user_id", claims.UserID.String()).
		Str("session_id", claims.SessionID.String()).
		Msg("refresh token revoked successfully - user logged out")
	return nil
}
// RevokeAllUserToken revokes every session belonging to a user, i.e. "log out
// of all devices". Typical triggers are a password change, a reported account
// compromise, a lost device, or an administrative force-logout.
//
// The work is a single call to sessionRepo.RevokeByUserId with the reason
// "revoke_all". That reason distinguishes this bulk revocation from a
// single-device "user_logout". Choosing and updating the affected sessions is
// left to the repository (presumably only non-revoked sessions are touched;
// confirm in the repository implementation).
//
// Only sessions are affected. Access tokens that were already issued are not
// touched by this call.
//
// Parameters:
//   - ctx: Context for the repository call
//   - userId: The user whose sessions are revoked
//
// Returns nil on success, or the repository error unchanged (after logging it).
func (a *AuthService) RevokeAllUserToken(ctx context.Context, userId uuid.UUID) error {
	uid := userId.String()

	log.Info().
		Str("service", "auth").
		Str("action", "revoke_all_user_tokens_started").
		Str("user_id", uid).
		Msg("revoking all refresh tokens for user (logout all devices)")

	// One repository call performs the bulk revocation for this user.
	if err := a.sessionRepo.RevokeByUserId(ctx, userId, "revoke_all"); err != nil {
		log.Error().
			Str("service", "auth").
			Str("action", "revoke_all_tokens_failed").
			Str("user_id", uid).
			Err(err).
			Msg("CRITICAL: failed to revoke all user tokens")
		return err
	}

	log.Info().
		Str("service", "auth").
		Str("action", "revoke_all_user_tokens_success").
		Str("user_id", uid).
		Msg("all user tokens revoked successfully - logged out from all devices")
	return nil
}
// detectDeviceType makes a best-effort guess at the client's device class from
// its User-Agent header. The result is meant for display and analytics, such
// as labelling sessions. User agents are trivially spoofable, so the result
// must not drive security decisions.
//
// Matching is case-insensitive and evaluated in this order:
//  1. nil, empty, or whitespace-only user agent -> "unknown"
//  2. contains "mobile", "android" or "iphone"  -> "mobile"
//  3. contains "electron"                        -> "desktop"
//  4. anything else                              -> "web"
//
// Limitations:
//   - Simple substring matching, not full user-agent parsing
//   - Phones and tablets are not distinguished. A tablet is only classified
//     as "mobile" if its user agent contains one of the mobile tokens above
//   - Browser and OS versions are not detected
//
// Parameters:
//   - userAgent: User-Agent header value from the HTTP request (may be nil)
//
// Returns one of "mobile", "desktop", "web", or "unknown".
func detectDeviceType(userAgent *string) string {
	if userAgent == nil {
		return "unknown"
	}

	// A present-but-blank header carries no information. Treat it like a
	// missing one instead of defaulting it to "web".
	// Lower-case once so every keyword check below is case-insensitive.
	ua := strings.ToLower(strings.TrimSpace(*userAgent))
	if ua == "" {
		return "unknown"
	}

	switch {
	case strings.Contains(ua, "mobile"),
		strings.Contains(ua, "android"),
		strings.Contains(ua, "iphone"):
		return "mobile"
	case strings.Contains(ua, "electron"):
		// Electron-based desktop application.
		return "desktop"
	default:
		// Desktop browsers (Chrome, Firefox, Safari, Edge, ...).
		return "web"
	}
}
// contains reports whether substring occurs anywhere within s, ignoring case.
// It is a small helper for user-agent matching.
//
// Both arguments are lower-cased with strings.ToLower before searching, so
// "iPhone", "iphone" and "IPHONE" all match one another. An empty substring
// always matches.
//
// Parameters:
//   - s: The string to search in
//   - substring: The string to search for
//
// Returns true if substring is found in s (case-insensitively), false otherwise.
func contains(s string, substring string) bool {
	haystack := strings.ToLower(s)
	needle := strings.ToLower(substring)
	return strings.Index(haystack, needle) != -1
}