aurganize-backend/backend/internal/middleware/rate_limiter.go

package middleware

import (
    "net/http"
    "sync"
    "time"

    "github.com/labstack/echo/v4"
    "github.com/rs/zerolog/log"
)
// RateLimiter implements a sliding window rate limiting algorithm to prevent abuse.
// Rate limiting is a critical security measure that protects your API from:
// 1. Brute force attacks (login attempts, password guessing)
// 2. Denial of Service (DoS) attacks (overwhelming the server)
// 3. API abuse (scraping, excessive requests)
// 4. Resource exhaustion (database connections, memory)
//
// How sliding window works:
// - Tracks request timestamps for each IP address
// - Keeps only requests within the time window
// - Blocks requests when limit is exceeded
// - Old requests automatically expire and don't count
//
// Example: limit=5, window=1 minute
// - 10:00:00: Request 1 ✅ (1/5)
// - 10:00:10: Request 2 ✅ (2/5)
// - 10:00:20: Request 3 ✅ (3/5)
// - 10:00:30: Request 4 ✅ (4/5)
// - 10:00:40: Request 5 ✅ (5/5)
// - 10:00:50: Request 6 ❌ (already 5/5 - BLOCKED, and blocked requests are not recorded)
// - 10:01:05: Request 7 ✅ (5/5 - Request 1 expired, so there was room again)
//
// Why sliding window vs fixed window?
// - Fixed window: All counters reset at fixed intervals (e.g., every minute at :00)
// Problem: Can allow 2x limit (5 at 10:00:59, 5 at 10:01:00 = 10 in 1 second)
// - Sliding window: Counts requests in the last N seconds from now
// Benefit: Smoother rate limiting, no burst at window boundaries
//
// Memory consideration:
// - Stores timestamps for each IP address
// - Memory grows with the number of unique IPs
// - Expired timestamps are pruned lazily, the next time the same IP sends a request
// - Entries for IPs that stop sending requests are never removed, so the map can grow
//   without bound over time (see the hedged cleanup sketch just below)
// - For high-traffic applications, consider periodic cleanup or Redis-based rate limiting
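//
// A minimal sketch of one way to bound that growth (hypothetical: StartCleanup is
// not defined in this file, and you would call it once right after NewRateLimiter).
// It periodically prunes expired timestamps and drops idle IP entries entirely:
//
//     func (rl *RateLimiter) StartCleanup(interval time.Duration) {
//         go func() {
//             for range time.Tick(interval) {
//                 rl.mu.Lock()
//                 cutoff := time.Now().Add(-rl.window)
//                 for ip, reqs := range rl.requests {
//                     valid := reqs[:0]
//                     for _, t := range reqs {
//                         if t.After(cutoff) {
//                             valid = append(valid, t)
//                         }
//                     }
//                     if len(valid) == 0 {
//                         delete(rl.requests, ip) // drop idle IPs so the map stays bounded
//                     } else {
//                         rl.requests[ip] = valid
//                     }
//                 }
//                 rl.mu.Unlock()
//             }
//         }()
//     }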
//
// Thread safety:
// - Uses mutex (sync.Mutex) for concurrent access
// - Multiple requests can arrive simultaneously
// - Mutex ensures only one goroutine modifies the map at a time
type RateLimiter struct {
    // requests maps IP addresses to their recent request timestamps.
    // Key: IP address (e.g., "192.168.1.1")
    // Value: Slice of timestamps when requests were made
    // Example: {"192.168.1.1": [10:00:00, 10:00:10, 10:00:20]}
    requests map[string][]time.Time

    // mu (mutex) ensures thread-safe access to the requests map.
    // Why needed: multiple HTTP requests arrive concurrently (different goroutines).
    // Without the mutex: race conditions (data corruption, incorrect counts).
    // With the mutex: only one goroutine can read/write the map at a time.
    mu sync.Mutex

    // limit is the maximum number of requests allowed within the time window.
    // Example: limit=5 means 5 requests per window.
    // Common values:
    // - Login: 5-10 per minute (prevent brute force)
    // - API: 100-1000 per minute (prevent abuse)
    // - Registration: 3-5 per hour (prevent spam)
    limit int

    // window is the time duration for counting requests.
    // Example: window=1 minute means count requests in the last 60 seconds.
    // Common values:
    // - 1 minute: standard rate limiting
    // - 1 hour: aggressive rate limiting (password reset)
    // - 1 second: burst protection
    // Format: time.Second, time.Minute, time.Hour
    window time.Duration
}
// NewRateLimiter creates a new rate limiter with the specified limit and time window.
// This constructor initializes the rate limiter with empty request tracking.
//
// Parameters:
//
//   - limit: Maximum number of requests allowed in the time window.
//     Example: 5 means "allow 5 requests per window".
//     Too low: blocks legitimate users.
//     Too high: doesn't prevent abuse.
//     Recommendation: start conservative, increase if needed.
//
//   - window: Time duration for the sliding window.
//     Example: limit=5 with window=time.Minute means "5 requests per minute".
//     Common patterns:
//       - Login: NewRateLimiter(5, time.Minute) = 5 attempts per minute
//       - API: NewRateLimiter(100, time.Minute) = 100 calls per minute
//       - Registration: NewRateLimiter(3, time.Hour) = 3 signups per hour
//
// Returns:
//   - Fully initialized RateLimiter ready to use as middleware
//
// Usage examples:
//
//     // Protect login endpoint
//     loginLimiter := NewRateLimiter(5, time.Minute)
//     auth.POST("/login", handler, loginLimiter.Limit)
//
//     // Protect API endpoints
//     apiLimiter := NewRateLimiter(100, time.Minute)
//     api.GET("/data", handler, apiLimiter.Limit)
//
//     // Protect registration
//     registerLimiter := NewRateLimiter(3, time.Hour)
//     auth.POST("/register", handler, registerLimiter.Limit)
//
// Memory note:
//   - Starts with an empty map and grows with the number of unique IPs
//   - Expired timestamps are pruned lazily, the next time the same IP sends a request
//   - Entries for idle IPs are never removed; add a periodic cleanup (see the sketch
//     above) or switch to Redis if that matters for your deployment
func NewRateLimiter(limit int, window time.Duration) *RateLimiter {
    log.Info().
        Str("middleware", "rate_limiter").
        Str("component", "middleware_init").
        Int("limit", limit).
        Dur("window", window).
        Float64("requests_per_second", float64(limit)/window.Seconds()).
        Msg("rate limiter initialized with security limits")

    if limit <= 0 {
        log.Error().
            Str("middleware", "rate_limiter").
            Str("action", "invalid_limit_config").
            Int("limit", limit).
            Msg("CRITICAL: rate limiter configured with zero or negative limit - all requests will be blocked!")
    } else if limit > 1000 {
        log.Warn().
            Str("middleware", "rate_limiter").
            Str("action", "very_high_limit").
            Int("limit", limit).
            Dur("window", window).
            Msg("rate limiter configured with very high limit - may not prevent abuse effectively")
    } else if limit < 3 && window < time.Minute {
        log.Warn().
            Str("middleware", "rate_limiter").
            Str("action", "very_strict_limit").
            Int("limit", limit).
            Dur("window", window).
            Msg("rate limiter configured with very strict limit - may impact legitimate users")
    }

    return &RateLimiter{
        requests: make(map[string][]time.Time),
        limit:    limit,
        window:   window,
    }
}
// Limit is a middleware function that enforces rate limiting per IP address.
// This wraps your route handler with rate limiting logic.
//
// How it works:
// 1. Extract client's IP address
// 2. Lock mutex (prevent concurrent access)
// 3. Get current time and calculate window start
// 4. Filter out expired requests (older than window)
// 5. Check if limit exceeded
// 6. If exceeded: Return 429 Too Many Requests
// 7. If allowed: Add current request and proceed
// 8. Unlock mutex (allow next request to be processed)
//
// Example timeline (limit=3, window=1 minute):
// 10:00:00 - Request 1 ✅ Count: 1
// 10:00:20 - Request 2 ✅ Count: 2
// 10:00:40 - Request 3 ✅ Count: 3
// 10:00:50 - Request 4 ❌ Count: 3, already at limit (BLOCKED - returns 429; blocked requests are not recorded)
// 10:01:05 - Request 5 ✅ Count: 3 (Request 1 expired)
// 10:01:25 - Request 6 ✅ Count: 3 (Request 2 expired)
//
// IP address tracking:
// - Uses c.RealIP() to get the actual client IP
// - Handles proxies (X-Forwarded-For header)
// - Handles load balancers (X-Real-IP header)
// - Caution: only trust these headers when the app sits behind a proxy you control,
//   otherwise clients can spoof them to dodge the limit
//
// Why track by IP?
// - Simple and effective for most use cases
// - No user authentication required
// - Works for public endpoints
// - Alternative: track by authenticated user ID (see the hedged sketch just below)
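//
// A minimal sketch of keying by user instead of IP (hypothetical: it assumes an
// upstream auth middleware has stored a "user_id" string in the Echo context,
// which nothing in this file does):
//
//     key := c.RealIP()
//     if uid, ok := c.Get("user_id").(string); ok && uid != "" {
//         key = "user:" + uid // authenticated request: limit per user
//     }
//     // ...then index rl.requests by key instead of ip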
//
// Thread safety:
// - Mutex locks ensure safe concurrent access
// - Multiple requests from different users are processed correctly
// - No race conditions or data corruption
//
// Response codes:
// - Allowed: the request passes to the next handler (the final status depends on that handler)
// - 429 Too Many Requests: rate limit exceeded (request blocked)
//
// Important notes:
// - The mutex is locked during the entire rate limit check
// - Keep processing fast (simple operations only)
// - Don't do expensive operations while locked
// - Unlock is called explicitly on both paths (blocked and allowed) before returning
//   or calling the next handler; defer is not used here
//
// Parameters:
// - next: The actual route handler to execute if rate limit allows
//
// Returns:
// - HandlerFunc that wraps the next handler with rate limiting
//
// Usage:
//
//     rateLimiter := NewRateLimiter(5, time.Minute)
//
//     // Single route
//     e.POST("/login", loginHandler, rateLimiter.Limit)
//
//     // Route group
//     auth := e.Group("/auth")
//     auth.Use(rateLimiter.Limit) // Apply to all routes in the group
//     auth.POST("/login", loginHandler)
//     auth.POST("/register", registerHandler)
//
// Production considerations:
// - For high traffic, consider Redis-based rate limiting
// - Monitor 429 responses (legitimate users vs attackers)
// - Adjust limits based on actual usage patterns
// - Consider different limits for different endpoints
// - Add rate limit headers (X-RateLimit-Limit, X-RateLimit-Remaining); a hedged sketch follows below
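//
// A minimal sketch of those headers (not implemented in this middleware; the
// X-RateLimit-* names are a common convention rather than a standard, and the
// snippet assumes it runs inside Limit where rl and validRequests are in scope,
// with strconv imported):
//
//     h := c.Response().Header()
//     h.Set("X-RateLimit-Limit", strconv.Itoa(rl.limit))
//     h.Set("X-RateLimit-Remaining", strconv.Itoa(rl.limit-len(validRequests)))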
func (rl *RateLimiter) Limit(next echo.HandlerFunc) echo.HandlerFunc {
    return func(c echo.Context) error {
        // Step 1: Get the client's IP address
        // RealIP() handles:
        // - X-Forwarded-For header (proxies)
        // - X-Real-IP header (load balancers)
        // - RemoteAddr (direct connections)
        // Example: "192.168.1.100" or "2001:db8::1" (IPv6)
        ip := c.RealIP()

        log.Debug().
            Str("middleware", "rate_limiter").
            Str("action", "rate_check_started").
            Str("ip", ip).
            Str("path", c.Request().URL.Path).
            Str("method", c.Request().Method).
            Msg("checking rate limit for request")

        // Step 2: Lock the mutex for thread-safe map access
        // CRITICAL: This prevents race conditions when multiple requests arrive simultaneously
        // What happens without the lock:
        // - Goroutine 1 reads count: 4
        // - Goroutine 2 reads count: 4 (at the same time)
        // - Both think they're under the limit (5)
        // - Both proceed (6 requests allowed instead of 5!)
        // With the lock:
        // - Only one goroutine can read/write at a time
        // - Accurate counting guaranteed
        rl.mu.Lock()
        // Note: Unlock is called explicitly on both paths below (blocked and allowed).
        // defer is deliberately not used so the lock is released before next(c) runs,
        // which means every return path between Lock and Unlock must unlock the mutex.

        // Step 3: Get the current time for the window calculation
        // Used to determine which requests are still within the time window
        now := time.Now()

        // Step 4: Calculate the start of the time window
        // Example: If window=1 minute and now=10:05:30
        // windowStart = 10:05:30 - 1 minute = 10:04:30
        // We only count requests between 10:04:30 and 10:05:30
        windowStart := now.Add(-rl.window)

        // Step 5: Get the existing requests for this IP
        // If the IP never made a request, this will be a nil/empty slice
        // Example: ["10:04:35", "10:05:10", "10:05:20"]
        request := rl.requests[ip]

        // Step 6: Filter requests to keep only those within the window
        // Create a new slice to store valid (recent) requests
        validRequests := []time.Time{}
        expiredCount := 0

        // Iterate through all previous requests from this IP
        for _, req := range request {
            // Check if the request timestamp is after the window start
            // If yes: request is recent (within window), keep it
            // If no: request is old (outside window), discard it
            //
            // Example: windowStart=10:04:30, now=10:05:30
            // Request at 10:04:35 ✅ After window start (keep)
            // Request at 10:04:20 ❌ Before window start (discard)
            if req.After(windowStart) {
                validRequests = append(validRequests, req)
            } else {
                expiredCount++
            }
        }
        // Discarded timestamps become eligible for garbage collection once the map
        // entry is overwritten below. Note that this pruning only runs when the same
        // IP sends another request, so entries for idle IPs stay in the map.

        if expiredCount > 0 {
            log.Debug().
                Str("middleware", "rate_limiter").
                Str("action", "expired_requests_cleaned").
                Str("ip", ip).
                Int("expired_count", expiredCount).
                Int("remaining_count", len(validRequests)).
                Msg("cleaned up expired requests from sliding window")
        }
        log.Debug().
            Str("middleware", "rate_limiter").
            Str("action", "rate_limit_decision").
            Str("ip", ip).
            Int("current_count", len(validRequests)).
            Int("limit", rl.limit).
            Dur("window", rl.window).
            Bool("will_allow", len(validRequests) < rl.limit).
            Msg("evaluating rate limit threshold")

        // Step 7: Check if the rate limit is exceeded
        // Count how many valid requests exist
        // If count >= limit, block the request
        //
        // Example: limit=5
        // validRequests length=4 ✅ Allow (4 < 5)
        // validRequests length=5 ❌ Block (5 >= 5)
        // validRequests length=6 ❌ Block (6 >= 5)
        if len(validRequests) >= rl.limit {
            log.Warn().
                Str("middleware", "rate_limiter").
                Str("action", "rate_limit_exceeded").
                Str("ip", ip).
                Str("path", c.Request().URL.Path).
                Str("method", c.Request().Method).
                Int("current_count", len(validRequests)).
                Int("limit", rl.limit).
                Dur("window", rl.window).
                Str("user_agent", c.Request().UserAgent()).
                Msg("rate limit exceeded - request blocked")

            path := c.Request().URL.Path
            if path == "/auth/login" || path == "/api/auth/login" {
                log.Warn().
                    Str("middleware", "rate_limiter").
                    Str("action", "login_rate_limit_hit").
                    Str("ip", ip).
                    Int("attempt_count", len(validRequests)).
                    Msg("SECURITY: rapid repeated login attempts - possible brute force attack")
            } else if path == "/auth/register" || path == "/api/auth/register" {
                log.Warn().
                    Str("middleware", "rate_limiter").
                    Str("action", "register_rate_limit_hit").
                    Str("ip", ip).
                    Msg("SECURITY: multiple registration attempts - possible spam or abuse")
            }

            // IMPORTANT: Unlock the mutex before returning the error
            // Without this, the mutex stays locked forever (deadlock!)
            rl.mu.Unlock()

            // Return 429 Too Many Requests
            // This is the standard HTTP status code for rate limiting
            // The client should wait before retrying
            //
            // Best practice: Include a Retry-After header (not implemented here)
            // Example: Retry-After: 60 (wait 60 seconds)
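            //
            // A minimal sketch of that header (hypothetical, not applied here; it would
            // require importing strconv). A conservative value is the full window in seconds:
            //
            //     c.Response().Header().Set("Retry-After", strconv.Itoa(int(rl.window.Seconds())))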
            return echo.NewHTTPError(http.StatusTooManyRequests, "too many requests")
        }

        // Step 8: The request is allowed - add it to the tracking
        // Append the current timestamp to the valid requests
        // This request will count against future rate limit checks
        validRequests = append(validRequests, now)

        // Step 9: Update the requests map with the cleaned list plus the new request
        // Replace the old request list (which contained expired requests) with the new list
        // The new list contains:
        // - Recent requests (within the window)
        // - The current request (just made)
        // The old slice, including timestamps outside the window, can now be garbage collected
        rl.requests[ip] = validRequests

        uniqueIpCount := len(rl.requests)
        if uniqueIpCount%100 == 0 {
            // Note: because this check runs on every request, it logs repeatedly while
            // the unique IP count sits exactly on a multiple of 100.
            log.Info().
                Str("middleware", "rate_limiter").
                Str("action", "unique_ip_milestone").
                Int("unique_ip_count", uniqueIpCount).
                Int("limit", rl.limit).
                Dur("window", rl.window).
                Msg("rate limiter tracking milestone reached")
        }

        log.Debug().
            Str("middleware", "rate_limiter").
            Str("action", "rate_limit_allowed").
            Str("ip", ip).
            Str("path", c.Request().URL.Path).
            Int("current_count", len(validRequests)).
            Int("remaining", rl.limit-len(validRequests)).
            Msg("request allowed through rate limiter")

        // Step 10: Unlock the mutex
        // CRITICAL: Must unlock before calling the next handler
        // Why: the next handler might take time (database query, etc.)
        // If we didn't unlock, other requests would have to wait for this handler
        // to finish before they could even be rate-limit checked
        rl.mu.Unlock()

        // Step 11: Proceed to the actual route handler
        // The rate limit check passed, so execute the requested operation
        // This could be login, an API call, registration, etc.
        return next(c)
    }
}