package middleware

import (
	"sync"
	"time"

	"github.com/labstack/echo/v4"
	"github.com/rs/zerolog/log"
)

// RateLimiter implements a sliding window rate limiting algorithm to prevent abuse.
// Rate limiting is a critical security measure that protects your API from:
//  1. Brute force attacks (login attempts, password guessing)
//  2. Denial of Service (DoS) attacks (overwhelming the server)
//  3. API abuse (scraping, excessive requests)
//  4. Resource exhaustion (database connections, memory)
//
// How the sliding window works:
//   - Tracks request timestamps for each IP address
//   - Keeps only requests within the time window
//   - Blocks requests when the limit is exceeded
//   - Old requests automatically expire and don't count
//
// Example: limit=5, window=1 minute
//   - 10:00:00: Request 1 ✅ (1/5)
//   - 10:00:10: Request 2 ✅ (2/5)
//   - 10:00:20: Request 3 ✅ (3/5)
//   - 10:00:30: Request 4 ✅ (4/5)
//   - 10:00:40: Request 5 ✅ (5/5)
//   - 10:00:50: Request 6 ❌ (blocked - 5/5 already in the window)
//   - 10:01:05: Request 7 ✅ (5/5 - Request 1 expired; the blocked Request 6 was never recorded)
//
// Why sliding window vs fixed window?
//   - Fixed window: All counters reset at fixed intervals (e.g., every minute at :00)
//     Problem: Can allow 2x the limit (5 at 10:00:59, 5 at 10:01:00 = 10 in 1 second)
//   - Sliding window: Counts requests in the last N seconds from now
//     Benefit: Smoother rate limiting, no burst at window boundaries
//
// Memory consideration:
//   - Stores timestamps for each IP address
//   - Memory grows with the number of unique IPs
//   - Old timestamps are cleaned up automatically (see the optional cleanup sketch below)
//   - For high-traffic applications, consider Redis-based rate limiting
//
// Thread safety:
//   - Uses a mutex (sync.Mutex) for concurrent access
//   - Multiple requests can arrive simultaneously
//   - The mutex ensures only one goroutine modifies the map at a time
type RateLimiter struct {
	// requests maps IP addresses to their recent request timestamps.
	// Key: IP address (e.g., "192.168.1.1")
	// Value: Slice of timestamps when requests were made
	// Example: {"192.168.1.1": [10:00:00, 10:00:10, 10:00:20]}
	requests map[string][]time.Time

	// mu (mutex) ensures thread-safe access to the requests map.
	// Why needed: Multiple HTTP requests arrive concurrently (different goroutines)
	// Without mutex: Race conditions (data corruption, incorrect counts)
	// With mutex: Only one goroutine can read/write the map at a time
	mu sync.Mutex

	// limit is the maximum number of requests allowed within the time window.
	// Example: limit=5 means 5 requests per window
	// Common values:
	//   - Login: 5-10 per minute (prevent brute force)
	//   - API: 100-1000 per minute (prevent abuse)
	//   - Registration: 3-5 per hour (prevent spam)
	limit int

	// window is the time duration for counting requests.
	// Example: window=1 minute means count requests in the last 60 seconds
	// Common values:
	//   - 1 minute: Standard rate limiting
	//   - 1 hour: Aggressive rate limiting (password reset)
	//   - 1 second: Burst protection
	// Format: time.Second, time.Minute, time.Hour
	window time.Duration
}
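// cleanupIdleIPs is an optional, illustrative sketch and not part of the
// original design: Limit only prunes an IP's timestamps when that same IP
// sends another request, so map entries for IPs that never return linger
// until process restart. A caller could run this periodically (for example,
// from a goroutine with a time.Ticker) to bound memory on long-lived servers.
func (rl *RateLimiter) cleanupIdleIPs() {
	cutoff := time.Now().Add(-rl.window)

	rl.mu.Lock()
	defer rl.mu.Unlock()

	for ip, timestamps := range rl.requests {
		// Keep only timestamps that are still inside the current window,
		// reusing the existing backing array to avoid extra allocations.
		kept := timestamps[:0]
		for _, t := range timestamps {
			if t.After(cutoff) {
				kept = append(kept, t)
			}
		}
		if len(kept) == 0 {
			// No recent requests from this IP: drop the map entry entirely.
			delete(rl.requests, ip)
		} else {
			rl.requests[ip] = kept
		}
	}
}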
// NewRateLimiter creates a new rate limiter with the specified limit and time window.
// This constructor initializes the rate limiter with empty request tracking.
//
// Parameters:
//
//   - limit: Maximum number of requests allowed in the time window
//     Example: 5 means "allow 5 requests"
//     Too low: Blocks legitimate users
//     Too high: Doesn't prevent abuse
//     Recommendation: Start conservative, increase if needed
//
//   - window: Time duration for the sliding window
//     Example: with limit=5, time.Minute means "5 requests per minute"
//     Common patterns:
//       - Login: NewRateLimiter(5, time.Minute) = 5 attempts per minute
//       - API: NewRateLimiter(100, time.Minute) = 100 calls per minute
//       - Registration: NewRateLimiter(3, time.Hour) = 3 signups per hour
//
// Returns:
//   - Fully initialized RateLimiter ready to use as middleware
//
// Usage examples:
//
//	// Protect login endpoint
//	loginLimiter := NewRateLimiter(5, time.Minute)
//	auth.POST("/login", handler, loginLimiter.Limit)
//
//	// Protect API endpoints
//	apiLimiter := NewRateLimiter(100, time.Minute)
//	api.GET("/data", handler, apiLimiter.Limit)
//
//	// Protect registration
//	registerLimiter := NewRateLimiter(3, time.Hour)
//	auth.POST("/register", handler, registerLimiter.Limit)
//
// Memory note:
//   - Starts with an empty map, grows as IPs make requests
//   - Old timestamps are cleaned automatically
//   - No manual cleanup needed
func NewRateLimiter(limit int, window time.Duration) *RateLimiter {
	log.Info().
		Str("middleware", "rate_limiter").
		Str("component", "middleware_init").
		Int("limit", limit).
		Dur("window", window).
		Float64("requests_per_second", float64(limit)/window.Seconds()).
		Msg("rate limiter initialized with security limits")

	if limit <= 0 {
		log.Error().
			Str("middleware", "rate_limiter").
			Str("action", "invalid_limit_config").
			Int("limit", limit).
			Msg("CRITICAL: rate limiter configured with zero or negative limit - all requests will be blocked!")
	} else if limit > 1000 {
		log.Warn().
			Str("middleware", "rate_limiter").
			Str("action", "very_high_limit").
			Int("limit", limit).
			Dur("window", window).
			Msg("rate limiter configured with very high limit - may not prevent abuse effectively")
	} else if limit < 3 && window < time.Minute {
		log.Warn().
			Str("middleware", "rate_limiter").
			Str("action", "very_strict_limit").
			Int("limit", limit).
			Dur("window", window).
			Msg("rate limiter configured with very strict limit - may impact legitimate users")
	}

	return &RateLimiter{
		requests: make(map[string][]time.Time),
		limit:    limit,
		window:   window,
	}
}
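// registerRateLimitedRoutes is an illustrative sketch showing how the
// commented usage examples above look as real wiring code. It is not part of
// the original middleware; the route paths and handler parameters are
// assumptions chosen for the example.
func registerRateLimitedRoutes(e *echo.Echo, loginHandler, registerHandler, dataHandler echo.HandlerFunc) {
	// Strict limit for authentication endpoints (brute-force protection).
	loginLimiter := NewRateLimiter(5, time.Minute)
	// Very strict limit for account creation (spam protection).
	registerLimiter := NewRateLimiter(3, time.Hour)
	// Generous limit for general API traffic (abuse protection).
	apiLimiter := NewRateLimiter(100, time.Minute)

	auth := e.Group("/auth")
	auth.POST("/login", loginHandler, loginLimiter.Limit)
	auth.POST("/register", registerHandler, registerLimiter.Limit)

	// Limiter passed to Group applies to every route registered on it.
	api := e.Group("/api", apiLimiter.Limit)
	api.GET("/data", dataHandler)
}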
// Limit is a middleware function that enforces rate limiting per IP address.
// It wraps your route handler with rate limiting logic.
//
// How it works:
//  1. Extract the client's IP address
//  2. Lock the mutex (prevent concurrent access)
//  3. Get the current time and calculate the window start
//  4. Filter out expired requests (older than the window)
//  5. Check if the limit is exceeded
//  6. If exceeded: Return 429 Too Many Requests
//  7. If allowed: Add the current request and proceed
//  8. Unlock the mutex (allow the next request to be processed)
//
// Example timeline (limit=3, window=1 minute):
//
//	10:00:00 - Request 1 ✅ Count: 1
//	10:00:20 - Request 2 ✅ Count: 2
//	10:00:40 - Request 3 ✅ Count: 3
//	10:00:50 - Request 4 ❌ 3/3 already in the window (BLOCKED - returns 429)
//	10:01:05 - Request 5 ✅ Count: 3 (Request 1 expired)
//	10:01:25 - Request 6 ✅ Count: 3 (Request 2 expired)
//
// IP address tracking:
//   - Uses c.RealIP() to get the actual client IP
//   - Handles proxies (X-Forwarded-For header)
//   - Handles load balancers (X-Real-IP header)
//
// Why track by IP?
//   - Simple and effective for most use cases
//   - No user authentication required
//   - Works for public endpoints
//   - Alternative: Track by user ID (requires authentication)
//
// Thread safety:
//   - Mutex locks ensure safe concurrent access
//   - Multiple requests from different users are processed correctly
//   - No race conditions or data corruption
//
// Response codes:
//   - 200 OK: Request allowed (passes to the next handler)
//   - 429 Too Many Requests: Rate limit exceeded (request blocked)
//
// Important notes:
//   - The mutex is locked during the entire rate limit check
//   - Keep processing fast (simple operations only)
//   - Don't do expensive operations while locked
//   - The mutex is unlocked explicitly before every return
//
// Parameters:
//   - next: The actual route handler to execute if the rate limit allows
//
// Returns:
//   - HandlerFunc that wraps the next handler with rate limiting
//
// Usage:
//
//	rateLimiter := NewRateLimiter(5, time.Minute)
//
//	// Single route
//	e.POST("/login", loginHandler, rateLimiter.Limit)
//
//	// Route group
//	auth := e.Group("/auth")
//	auth.Use(rateLimiter.Limit) // Apply to all routes in the group
//	auth.POST("/login", loginHandler)
//	auth.POST("/register", registerHandler)
//
// Production considerations:
//   - For high traffic, consider Redis-based rate limiting (see the sketch at the end of this file)
//   - Monitor 429 responses (legitimate users vs attackers)
//   - Adjust limits based on actual usage patterns
//   - Consider different limits for different endpoints
//   - Add rate limit headers (X-RateLimit-Limit, X-RateLimit-Remaining)
func (rl *RateLimiter) Limit(next echo.HandlerFunc) echo.HandlerFunc {
	return func(c echo.Context) error {
		// Step 1: Get the client's IP address.
		// RealIP() handles:
		//   - X-Forwarded-For header (proxies)
		//   - X-Real-IP header (load balancers)
		//   - RemoteAddr (direct connections)
		// Example: "192.168.1.100" or "2001:db8::1" (IPv6)
		ip := c.RealIP()

		log.Debug().
			Str("middleware", "rate_limiter").
			Str("action", "rate_check_started").
			Str("ip", ip).
			Str("path", c.Request().URL.Path).
			Str("method", c.Request().Method).
			Msg("checking rate limit for request")
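		// Note: RealIP trusts the X-Forwarded-For / X-Real-IP headers, which
		// clients can forge when the server is not behind a trusted proxy.
		// Echo lets the application pick the extraction strategy at startup;
		// an illustrative configuration (an assumption, not shown elsewhere
		// in this file) would be:
		//
		//	e.IPExtractor = echo.ExtractIPFromXFFHeader() // behind a reverse proxy
		//	e.IPExtractor = echo.ExtractIPDirect()        // direct client connections
		//
		// Whichever extractor is configured is what c.RealIP() returns here.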
		// Step 2: Lock the mutex for thread-safe map access.
		// CRITICAL: This prevents race conditions when multiple requests arrive simultaneously.
		// What happens without the lock:
		//   - Goroutine 1 reads count: 4
		//   - Goroutine 2 reads count: 4 (at the same time)
		//   - Both think they're under the limit (5)
		//   - Both proceed (6 requests allowed instead of 5!)
		// With the lock:
		//   - Only one goroutine can read/write at a time
		//   - Accurate counting guaranteed
		rl.mu.Lock()
		// Note: defer is not used here; the mutex is unlocked explicitly on
		// every return path below so the lock is never held while the next
		// handler runs.

		// Step 3: Get the current time for the window calculation.
		// Used to determine which requests are still within the time window.
		now := time.Now()

		// Step 4: Calculate the start of the time window.
		// Example: If window=1 minute and now=10:05:30,
		// windowStart = 10:05:30 - 1 minute = 10:04:30.
		// We only count requests between 10:04:30 and 10:05:30.
		windowStart := now.Add(-rl.window)

		// Step 5: Get the existing requests for this IP.
		// If the IP has never made a request, this is a nil slice.
		// Example: ["10:04:35", "10:05:10", "10:05:20"]
		request := rl.requests[ip]

		// Step 6: Filter requests to keep only those within the window.
		// Create a new slice to store valid (recent) requests.
		validRequests := []time.Time{}
		expiredCount := 0

		// Iterate through all previous requests from this IP.
		for _, req := range request {
			// Check if the request timestamp is after the window start.
			// If yes: the request is recent (within the window), keep it.
			// If no: the request is old (outside the window), discard it.
			//
			// Example: windowStart=10:04:30, now=10:05:30
			//   Request at 10:04:35 ✅ After window start (keep)
			//   Request at 10:04:20 ❌ Before window start (discard)
			if req.After(windowStart) {
				validRequests = append(validRequests, req)
			} else {
				expiredCount++
			}
			// Old requests are automatically garbage collected.
			// This keeps memory usage bounded.
		}

		if expiredCount > 0 {
			log.Debug().
				Str("middleware", "rate_limiter").
				Str("action", "expired_requests_cleaned").
				Str("ip", ip).
				Int("expired_count", expiredCount).
				Int("remaining_count", len(validRequests)).
				Msg("cleaned up expired requests from sliding window")
		}

		log.Debug().
			Str("middleware", "rate_limiter").
			Str("action", "rate_limit_decision").
			Str("ip", ip).
			Int("current_count", len(validRequests)).
			Int("limit", rl.limit).
			Dur("window", rl.window).
			Bool("will_allow", len(validRequests) < rl.limit).
			Msg("evaluating rate limit threshold")

		// Step 7: Check if the rate limit is exceeded.
		// Count how many valid requests exist.
		// If count >= limit, block the request.
		//
		// Example: limit=5
		//   validRequests length=4 ✅ Allow (4 < 5)
		//   validRequests length=5 ❌ Block (5 >= 5)
		//   validRequests length=6 ❌ Block (6 >= 5)
		if len(validRequests) >= rl.limit {
			log.Warn().
				Str("middleware", "rate_limiter").
				Str("action", "rate_limit_exceeded").
				Str("ip", ip).
				Str("path", c.Request().URL.Path).
				Str("method", c.Request().Method).
				Int("current_count", len(validRequests)).
				Int("limit", rl.limit).
				Dur("window", rl.window).
				Str("user_agent", c.Request().UserAgent()).
				Msg("rate limit exceeded - request blocked")

			path := c.Request().URL.Path
			if path == "/auth/login" || path == "/api/auth/login" {
				log.Warn().
					Str("middleware", "rate_limiter").
					Str("action", "login_rate_limit_hit").
					Str("ip", ip).
					Int("attempt_count", len(validRequests)).
					Msg("SECURITY: multiple failed login attempts - possible brute force attack")
			} else if path == "/auth/register" || path == "/api/auth/register" {
				log.Warn().
					Str("middleware", "rate_limiter").
					Str("action", "register_rate_limit_hit").
					Str("ip", ip).
					Msg("SECURITY: multiple registration attempts - possible spam or abuse")
			}
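			// One way to follow the Retry-After best practice noted just
			// below would be to report when the oldest counted request falls
			// out of the window. Illustrative sketch only (assumes strconv is
			// imported and limit > 0, so validRequests is non-empty and
			// oldest-first because timestamps are appended in arrival order):
			//
			//	retryAfter := rl.window - now.Sub(validRequests[0])
			//	c.Response().Header().Set("Retry-After",
			//		strconv.Itoa(int(retryAfter.Seconds())+1))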
			// IMPORTANT: Unlock the mutex before returning the error.
			// Without this, the mutex stays locked forever (deadlock!)
			rl.mu.Unlock()

			// Return 429 Too Many Requests.
			// This is the standard HTTP status code for rate limiting;
			// the client should wait before retrying.
			//
			// Best practice: Include a Retry-After header (not implemented here).
			// Example: Retry-After: 60 (wait 60 seconds)
			return echo.NewHTTPError(429, "too many requests")
		}

		// Step 8: Request is allowed - add the current request to tracking.
		// Append the current timestamp to the valid requests.
		// This request will count against future rate limit checks.
		validRequests = append(validRequests, now)

		// Step 9: Update the requests map with the cleaned list plus the new request.
		// Replace the old request list (which had expired requests) with the new list.
		// The new list contains:
		//   - Recent requests (within the window)
		//   - The current request (just made)
		// Old requests outside the window are now garbage collected.
		rl.requests[ip] = validRequests

		uniqueIPCount := len(rl.requests)
		if uniqueIPCount%100 == 0 {
			log.Info().
				Str("middleware", "rate_limiter").
				Str("action", "unique_ip_milestone").
				Int("unique_ip_count", uniqueIPCount).
				Int("limit", rl.limit).
				Dur("window", rl.window).
				Msg("rate limiter tracking milestone reached")
		}

		log.Debug().
			Str("middleware", "rate_limiter").
			Str("action", "rate_limit_allowed").
			Str("ip", ip).
			Str("path", c.Request().URL.Path).
			Int("current_count", len(validRequests)).
			Int("remaining", rl.limit-len(validRequests)).
			Msg("request allowed through rate limiter")

		// Step 10: Unlock the mutex.
		// CRITICAL: Must unlock before calling the next handler.
		// Why: The next handler might take time (database query, etc.).
		// If we don't unlock, other requests will wait unnecessarily.
		// Unlocking here allows other IPs to be rate-limited concurrently.
		rl.mu.Unlock()

		// Step 11: Proceed to the actual route handler.
		// The rate limit check passed, so execute the requested operation.
		// This could be login, an API call, registration, etc.
		return next(c)
	}
}
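// Redis-based rate limiting (sketch).
//
// The in-memory limiter above is per-process: limits are not shared between
// replicas and state is lost on restart. The "Redis-based rate limiting"
// mentioned in the comments is the usual answer for high-traffic,
// multi-instance deployments. A minimal fixed-window variant, assuming the
// github.com/redis/go-redis/v9 client (not a dependency of this file), could
// look like:
//
//	key := "ratelimit:" + ip
//	count, err := rdb.Incr(ctx, key).Result()
//	if err == nil && count == 1 {
//		rdb.Expire(ctx, key, window) // first request starts the window
//	}
//	if err == nil && count > int64(limit) {
//		return echo.NewHTTPError(429, "too many requests")
//	}
//
// A true sliding window can be built on a Redis sorted set per IP
// (ZADD now / ZREMRANGEBYSCORE older-than-window / ZCARD), trading extra
// round trips for the smoother behaviour described at the top of this file.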