package middleware

import (
	"net/http"
	"strconv"
	"sync"
	"time"

	"github.com/labstack/echo/v4"
)

// RateLimiter is an in-memory, per-client sliding-window rate limiter for Echo
// handlers. It caps how many requests a single client address may make within
// any rolling time window, which helps against brute-force attempts (login,
// password guessing), request floods, scraping, and resource exhaustion.
//
// How the sliding window works:
//   - The timestamp of every allowed request is recorded under the client key.
//   - On each new request, timestamps older than the window are discarded.
//   - If the remaining count has already reached the limit, the request is
//     rejected and is NOT recorded (rejected requests do not extend a block).
//   - Otherwise the current timestamp is recorded and the request proceeds.
//
// Example: limit=5, window=1 minute
//   - 10:00:00 request 1 allowed  (1/5)
//   - 10:00:10 request 2 allowed  (2/5)
//   - 10:00:20 request 3 allowed  (3/5)
//   - 10:00:30 request 4 allowed  (4/5)
//   - 10:00:40 request 5 allowed  (5/5)
//   - 10:00:50 request 6 rejected (5/5 already counted; not recorded)
//   - 10:01:05 request 7 allowed  (5/5: request 1 expired, requests 2-5 remain)
//
// Why a sliding rather than a fixed window? A fixed window resets all
// counters at interval boundaries, so a client can send limit requests just
// before a boundary and limit more just after it (2x the limit in an instant).
// The sliding window always counts the trailing window ending at "now", so at
// most limit requests are accepted in any window-length interval.
//
// Memory: at most limit timestamps are kept per client key. Keys whose
// timestamps have all expired are purged by a sweep that runs at most once per
// window, so memory is bounded by the clients active within roughly the last
// two windows. State is per process: replicas behind a load balancer each
// enforce their own limit; use a shared store (e.g. Redis) for a global one.
//
// Concurrency: a single mutex guards all state, so one RateLimiter may be
// shared by all request goroutines. It must not be copied after first use,
// and the zero value is not usable — construct it with NewRateLimiter.
type RateLimiter struct {
	// requests maps a client key (the value of c.RealIP(), e.g. "192.168.1.1")
	// to the timestamps of its allowed requests, oldest first.
	requests map[string][]time.Time

	// mu guards requests and lastSweep.
	mu sync.Mutex

	// limit is the maximum number of requests accepted per window.
	// Typical values: login 5-10/minute, API 100-1000/minute,
	// registration 3-5/hour.
	limit int

	// window is the length of the trailing interval requests are counted in,
	// e.g. time.Second (burst protection), time.Minute, time.Hour.
	window time.Duration

	// lastSweep is when stale keys were last purged from requests.
	lastSweep time.Time
}

// NewRateLimiter returns a RateLimiter that accepts at most limit requests per
// client within any trailing window.
//
// Parameters:
//   - limit: maximum requests allowed per window. Too low blocks legitimate
//     users; too high fails to deter abuse. A limit <= 0 rejects every request.
//   - window: length of the sliding window. A window <= 0 expires every
//     timestamp immediately, which effectively disables limiting.
//
// Common configurations:
//
//	loginLimiter := NewRateLimiter(5, time.Minute)    // 5 attempts per minute
//	auth.POST("/login", handler, loginLimiter.Limit)
//
//	apiLimiter := NewRateLimiter(100, time.Minute)    // 100 calls per minute
//	api.GET("/data", handler, apiLimiter.Limit)
//
//	registerLimiter := NewRateLimiter(3, time.Hour)   // 3 signups per hour
//	auth.POST("/register", handler, registerLimiter.Limit)
func NewRateLimiter(limit int, window time.Duration) *RateLimiter {
	return &RateLimiter{
		requests: make(map[string][]time.Time),
		limit:    limit,
		window:   window,
	}
}

// Limit is Echo middleware that enforces the rate limit per client address.
//
// The client is identified by c.RealIP(). Per Echo's documentation, RealIP
// uses the Echo instance's IPExtractor when one is configured and otherwise
// consults the X-Forwarded-For / X-Real-IP headers before falling back to the
// connection's remote address. Those headers are client-controlled unless a
// trusted proxy overwrites them. Without a suitable e.IPExtractor (e.g.
// echo.ExtractIPDirect() or echo.ExtractIPFromXFFHeader with trust options),
// a client can evade the limit by varying them.
//
// Behavior:
//   - Allowed: the request is recorded and next(c) is called; its result
//     (status code and error) is returned unchanged.
//   - Rejected: next is not called. A Retry-After header is set to the number
//     of whole seconds (rounded up) until the oldest counted request leaves
//     the window, and an *echo.HTTPError with status 429 Too Many Requests is
//     returned.
//
// The internal lock is held only for the bookkeeping and is released before
// next runs, so slow handlers do not serialize other requests.
//
// Usage:
//
//	rateLimiter := NewRateLimiter(5, time.Minute)
//
//	// Single route
//	e.POST("/login", loginHandler, rateLimiter.Limit)
//
//	// Route group
//	auth := e.Group("/auth")
//	auth.Use(rateLimiter.Limit)
//	auth.POST("/login", loginHandler)
//	auth.POST("/register", registerHandler)
//
// Production considerations: monitor 429 responses to tell legitimate users
// from attackers, tune limits per endpoint, and consider exposing
// X-RateLimit-Limit / X-RateLimit-Remaining headers.
func (rl *RateLimiter) Limit(next echo.HandlerFunc) echo.HandlerFunc {
	return func(c echo.Context) error {
		allowed, wait := rl.allow(c.RealIP())
		if !allowed {
			// Round up to whole seconds: a client that waits exactly this
			// long will find the oldest counted request expired.
			secs := int64(wait / time.Second)
			if wait%time.Second != 0 {
				secs++
			}
			if secs < 1 {
				secs = 1
			}
			c.Response().Header().Set("Retry-After", strconv.FormatInt(secs, 10))
			return echo.NewHTTPError(http.StatusTooManyRequests, "too many requests")
		}
		return next(c)
	}
}

// allow runs the sliding-window check for key and records the request if it
// fits. It reports whether the request is allowed and, when it is not, how
// long until the oldest counted request leaves the window.
func (rl *RateLimiter) allow(key string) (bool, time.Duration) {
	rl.mu.Lock()
	defer rl.mu.Unlock() // released on every path, including panics

	// Read the clock under the lock so each key's timestamps are appended in
	// non-decreasing order; the first retained entry is then the oldest.
	now := time.Now()
	windowStart := now.Add(-rl.window)

	rl.sweep(now, windowStart)

	// Drop expired timestamps in place, reusing the backing array instead of
	// allocating a new slice per request.
	history := rl.requests[key]
	recent := history[:0]
	for _, t := range history {
		if t.After(windowStart) {
			recent = append(recent, t)
		}
	}

	if len(recent) >= rl.limit {
		// The in-place filter rewrote the stored array, so the pruned slice
		// must be written back even though this request is rejected.
		if len(recent) == 0 {
			// Only reachable when limit <= 0: nothing will ever be allowed.
			delete(rl.requests, key)
			return false, rl.window
		}
		rl.requests[key] = recent
		return false, recent[0].Add(rl.window).Sub(now)
	}

	rl.requests[key] = append(recent, now)
	return true, 0
}

// sweep deletes every key whose newest timestamp is no longer inside the
// window, so clients that stop sending requests do not stay in memory
// forever. It does real work at most once per window, costing one pass over
// all tracked keys. The caller must hold rl.mu.
func (rl *RateLimiter) sweep(now, windowStart time.Time) {
	if now.Sub(rl.lastSweep) < rl.window {
		return
	}
	rl.lastSweep = now
	for key, history := range rl.requests {
		// Deleting during range is permitted by the Go spec.
		if len(history) == 0 || !history[len(history)-1].After(windowStart) {
			delete(rl.requests, key)
		}
	}
}