package middleware

import (
	"sync"
	"time"

	"github.com/labstack/echo/v4"
)

// RateLimiter implements a sliding window rate limiting algorithm to prevent abuse.
// Rate limiting is a critical security measure that protects your API from:
// 1. Brute force attacks (login attempts, password guessing)
// 2. Denial of Service (DoS) attacks (overwhelming the server)
// 3. API abuse (scraping, excessive requests)
// 4. Resource exhaustion (database connections, memory)
//
// How sliding window works:
// - Tracks request timestamps for each IP address
// - Keeps only the requests made within the time window
// - Blocks a request once the limit is reached
// - Old requests automatically expire and no longer count
//
// Example: limit=5, window=1 minute
// - 10:00:00: Request 1 ✅ (1/5)
// - 10:00:10: Request 2 ✅ (2/5)
// - 10:00:20: Request 3 ✅ (3/5)
// - 10:00:30: Request 4 ✅ (4/5)
// - 10:00:40: Request 5 ✅ (5/5)
// - 10:00:50: Request 6 ❌ (5 already in window - BLOCKED)
// - 10:01:05: Request 7 ✅ (5/5 - Request 1 has expired)
//
// Why sliding window vs fixed window?
// - Fixed window: all counters reset at fixed intervals (e.g., every minute at :00)
//   Problem: can allow 2x the limit (5 at 10:00:59 + 5 at 10:01:00 = 10 within a second)
// - Sliding window: counts requests in the last N seconds from now
//   Benefit: smoother rate limiting, no burst at window boundaries
//   (For contrast, see the illustrative fixedWindowCounter sketch below the struct.)
//
// Memory consideration:
// - Stores timestamps for each IP address
// - Memory grows with the number of unique IPs
// - Old timestamps are pruned the next time the same IP makes a request;
//   entries for IPs that go silent are never removed
// - For high-traffic applications, consider Redis-based rate limiting
//
// Thread safety:
// - Uses a mutex (sync.Mutex) for concurrent access
// - Multiple requests can arrive simultaneously
// - The mutex ensures only one goroutine reads or modifies the map at a time
type RateLimiter struct {
	// requests maps IP addresses to their recent request timestamps.
	// Key: IP address (e.g., "192.168.1.1")
	// Value: slice of timestamps when requests were made
	// Example: {"192.168.1.1": [10:00:00, 10:00:10, 10:00:20]}
	requests map[string][]time.Time

	// mu (mutex) ensures thread-safe access to the requests map.
	// Why needed: multiple HTTP requests arrive concurrently (different goroutines)
	// Without mutex: race conditions (data corruption, incorrect counts)
	// With mutex: only one goroutine can read/write the map at a time
	mu sync.Mutex

	// limit is the maximum number of requests allowed within the time window.
	// Example: limit=5 means 5 requests per window
	// Common values:
	// - Login: 5-10 per minute (prevent brute force)
	// - API: 100-1000 per minute (prevent abuse)
	// - Registration: 3-5 per hour (prevent spam)
	limit int

	// window is the time duration for counting requests.
	// Example: window=1 minute means count requests in the last 60 seconds
	// Common values:
	// - 1 minute: standard rate limiting
	// - 1 hour: aggressive rate limiting (password reset)
	// - 1 second: burst protection
	// Format: time.Second, time.Minute, time.Hour
	window time.Duration
}

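// fixedWindowCounter is a minimal, illustrative contrast to the sliding window
// approach used by RateLimiter (see "Why sliding window vs fixed window?" above).
// It is NOT used by this middleware; it is only a sketch showing how a fixed-window
// counter resets at hard interval boundaries, which is exactly what permits the
// "5 at 10:00:59 plus 5 at 10:01:00" burst described in that comment.
type fixedWindowCounter struct {
	mu          sync.Mutex
	windowStart time.Time
	count       int
	limit       int
	window      time.Duration
}

// allow reports whether one more request fits into the current fixed interval.
func (fw *fixedWindowCounter) allow(now time.Time) bool {
	fw.mu.Lock()
	defer fw.mu.Unlock()

	// Truncate to the start of the current interval (e.g., the top of the minute).
	// When the interval changes, the count resets to zero, so a burst of up to
	// 2x the limit can straddle the boundary.
	interval := now.Truncate(fw.window)
	if interval.After(fw.windowStart) {
		fw.windowStart = interval
		fw.count = 0
	}
	if fw.count >= fw.limit {
		return false
	}
	fw.count++
	return true
}
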
// NewRateLimiter creates a new rate limiter with the specified limit and time window.
// The constructor initializes the rate limiter with empty request tracking.
//
// Parameters:
//
// - limit: Maximum number of requests allowed in the time window
//   Example: 5 means "allow 5 requests"
//   Too low: blocks legitimate users
//   Too high: doesn't prevent abuse
//   Recommendation: start conservative, increase if needed
//
// - window: Time duration for the sliding window
//   Example: time.Minute means "count requests over the last 60 seconds"
//
// Common patterns:
// - Login: NewRateLimiter(5, time.Minute) = 5 attempts per minute
// - API: NewRateLimiter(100, time.Minute) = 100 calls per minute
// - Registration: NewRateLimiter(3, time.Hour) = 3 signups per hour
//
// Returns:
// - A fully initialized RateLimiter ready to use as middleware
//
// Usage examples:
//
//	// Protect login endpoint
//	loginLimiter := NewRateLimiter(5, time.Minute)
//	auth.POST("/login", handler, loginLimiter.Limit)
//
//	// Protect API endpoints
//	apiLimiter := NewRateLimiter(100, time.Minute)
//	api.GET("/data", handler, apiLimiter.Limit)
//
//	// Protect registration
//	registerLimiter := NewRateLimiter(3, time.Hour)
//	auth.POST("/register", handler, registerLimiter.Limit)
//
// Memory note:
// - Starts with an empty map and grows as IPs make requests
// - Timestamps outside the window are pruned the next time that IP makes a request
// - Entries for IPs that stop sending requests are never removed; for long-running,
//   high-traffic services consider periodic cleanup or a Redis-backed limiter
func NewRateLimiter(limit int, window time.Duration) *RateLimiter {
	return &RateLimiter{
		requests: make(map[string][]time.Time),
		limit:    limit,
		window:   window,
	}
}

// Limit is a middleware function that enforces rate limiting per IP address.
// It wraps your route handler with rate limiting logic.
//
// How it works:
// 1. Extract the client's IP address
// 2. Lock the mutex (prevent concurrent access)
// 3. Get the current time and calculate the window start
// 4. Filter out expired requests (older than the window)
// 5. Check whether the limit has been reached
// 6. If reached: unlock and return 429 Too Many Requests
// 7. If allowed: record the current request and unlock the mutex
// 8. Call the next handler
//
// Sliding window algorithm:
//
//	Window:       [---------- 1 minute ----------]
//	              ^                               ^
//	              window start                    now
//
// Only requests between the window start and now are counted.
//
// Example timeline (limit=3, window=1 minute):
// 10:00:00 - Request 1 ✅ Count: 1
// 10:00:20 - Request 2 ✅ Count: 2
// 10:00:40 - Request 3 ✅ Count: 3
// 10:00:50 - Request 4 ❌ 3 already in window (BLOCKED - returns 429)
// 10:01:05 - Request 5 ✅ Count: 3 (Request 1 expired)
// 10:01:25 - Request 6 ✅ Count: 3 (Request 2 expired)
//
// IP address tracking:
// - Uses c.RealIP() to get the actual client IP
// - Handles proxies (X-Forwarded-For header)
// - Handles load balancers (X-Real-IP header)
//
// Why track by IP?
// - Simple and effective for most use cases
// - No user authentication required
// - Works for public endpoints
// - Alternative: track by user ID (requires authentication)
//
// Thread safety:
// - Mutex locks ensure safe concurrent access
// - Multiple requests from different users are processed correctly
// - No race conditions or data corruption
//
// Responses:
// - Allowed: the request is passed to the next handler, which sets its own status code
// - Blocked: 429 Too Many Requests is returned before the handler runs
//
// Important notes:
// - The mutex is held for the entire rate limit check
// - Keep the locked section fast (simple operations only)
// - Don't do expensive operations while locked
// - The mutex is unlocked explicitly on both paths (before returning 429 and before
//   calling the next handler); defer is not used, so the handler never runs under the lock
//
// Parameters:
// - next: The actual route handler to execute if the rate limit allows
//
// Returns:
// - A HandlerFunc that wraps the next handler with rate limiting
//
// Usage:
//
//	rateLimiter := NewRateLimiter(5, time.Minute)
//
//	// Single route
//	e.POST("/login", loginHandler, rateLimiter.Limit)
//
//	// Route group
//	auth := e.Group("/auth")
//	auth.Use(rateLimiter.Limit) // Apply to all routes in the group
//	auth.POST("/login", loginHandler)
//	auth.POST("/register", registerHandler)
//
// Production considerations:
// - For high traffic, consider Redis-based rate limiting
// - Monitor 429 responses (legitimate users vs attackers)
// - Adjust limits based on actual usage patterns
// - Consider different limits for different endpoints
// - Add rate limit headers (X-RateLimit-Limit, X-RateLimit-Remaining);
//   a sketch appears just before the handler call below
func (rl *RateLimiter) Limit(next echo.HandlerFunc) echo.HandlerFunc {
	return func(c echo.Context) error {
		// Step 1: Get the client's IP address
		// RealIP() handles:
		// - X-Forwarded-For header (proxies)
		// - X-Real-IP header (load balancers)
		// - RemoteAddr (direct connections)
		// Example: "192.168.1.100" or "2001:db8::1" (IPv6)
		ip := c.RealIP()

		// Step 2: Lock the mutex for thread-safe map access
		// CRITICAL: this prevents race conditions when multiple requests arrive simultaneously
		// What happens without the lock:
		// - Goroutine 1 reads count: 4
		// - Goroutine 2 reads count: 4 (at the same time)
		// - Both think they're under the limit (5)
		// - Both proceed (6 requests allowed instead of 5!)
		// With the lock:
		// - Only one goroutine can read/write at a time
		// - Accurate counting is guaranteed
		rl.mu.Lock()

		// Note: defer is deliberately not used here. The mutex is unlocked explicitly
		// on both return paths below, so the lock is never held while the next handler runs.

		// Step 3: Get the current time for the window calculation
		// Used to determine which requests are still within the time window
		now := time.Now()

		// Step 4: Calculate the start of the time window
		// Example: if window=1 minute and now=10:05:30
		// windowStart = 10:05:30 - 1 minute = 10:04:30
		// We only count requests between 10:04:30 and 10:05:30
		windowStart := now.Add(-rl.window)

		// Step 5: Get the existing requests for this IP
		// If the IP has never made a request, this is a nil slice (which ranges as empty)
		// Example: ["10:04:35", "10:05:10", "10:05:20"]
		prevRequests := rl.requests[ip]

		// Step 6: Filter requests, keeping only those within the window
		// Create a new slice to store valid (recent) requests
		validRequests := []time.Time{}

		// Iterate through all previous requests from this IP
		for _, req := range prevRequests {
			// Check whether the request timestamp is after the window start
			// If yes: the request is recent (within the window), keep it
			// If no: the request is old (outside the window), discard it
			//
			// Example: windowStart=10:04:30, now=10:05:30
			// Request at 10:04:35 ✅ after window start (keep)
			// Request at 10:04:20 ❌ before window start (discard)
			if req.After(windowStart) {
				validRequests = append(validRequests, req)
			}
			// Expired requests are simply not copied over,
			// which keeps memory usage per IP bounded.
		}

		// Step 7: Check whether the rate limit has been reached
		// Count how many valid requests exist
		// If count >= limit, block the request
		//
		// Example: limit=5
		// validRequests length=4 ✅ allow (4 < 5)
		// validRequests length=5 ❌ block (5 >= 5)
		// validRequests length=6 ❌ block (6 >= 5)
		if len(validRequests) >= rl.limit {
			// IMPORTANT: unlock the mutex before returning the error
			// Without this, the mutex would stay locked forever (blocking all future requests)
			rl.mu.Unlock()

			// Return 429 Too Many Requests
			// This is the standard HTTP status code for rate limiting
			// The client should wait before retrying
			//
			// Best practice: include a Retry-After header (not implemented here;
			// see the sketch just below)
			// Example: Retry-After: 60 (wait 60 seconds)
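			// A minimal sketch (an assumption, not enabled in this middleware): the window
			// length could be advertised as a conservative Retry-After value before
			// returning. Enabling the line below would also require importing "strconv".
			//
			//	c.Response().Header().Set("Retry-After", strconv.Itoa(int(rl.window.Seconds())))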
			return echo.NewHTTPError(429, "too many requests")
		}

		// Step 8: The request is allowed - add the current request to tracking
		// Append the current timestamp to the valid requests
		// This request will count against future rate limit checks
		validRequests = append(validRequests, now)

		// Step 9: Update the requests map with the cleaned list plus the new request
		// Replace the old request list (which contained expired requests) with the new one:
		// - Recent requests (within the window)
		// - The current request (just made)
		// The old slice, including requests outside the window, becomes eligible for garbage collection
		rl.requests[ip] = validRequests

		// Step 10: Unlock the mutex
		// CRITICAL: must unlock before calling the next handler
		// Why: the next handler might take time (database query, etc.)
		// If we didn't unlock, other requests would wait unnecessarily
		// Unlocking here lets other requests be rate-limited concurrently
		rl.mu.Unlock()

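		// A minimal sketch (an assumption, not enabled in this middleware): the rate limit
		// headers suggested in the doc comment above could be set here, before the next
		// handler writes the response. Enabling these lines would also require importing "strconv".
		//
		//	c.Response().Header().Set("X-RateLimit-Limit", strconv.Itoa(rl.limit))
		//	c.Response().Header().Set("X-RateLimit-Remaining", strconv.Itoa(rl.limit-len(validRequests)))
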
		// Step 11: Proceed to the actual route handler
		// The rate limit check passed, so execute the requested operation
		// This could be a login, API call, registration, etc.
		return next(c)
	}
}
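
// The behaviour documented above (requests 1-5 allowed, request 6 blocked with 429)
// can be exercised end to end. A minimal sketch, assuming a hypothetical
// rate_limiter_test.go in this package with "net/http", "net/http/httptest",
// "testing" and "time" imported:
//
//	func TestRateLimiterBlocksSixthRequest(t *testing.T) {
//		e := echo.New()
//		limiter := NewRateLimiter(5, time.Minute)
//		h := limiter.Limit(func(c echo.Context) error { return c.String(http.StatusOK, "ok") })
//
//		for i := 0; i < 6; i++ {
//			req := httptest.NewRequest(http.MethodGet, "/", nil)
//			rec := httptest.NewRecorder()
//			err := h(e.NewContext(req, rec)) // every request shares the same RemoteAddr, hence the same IP
//
//			if i < 5 && err != nil {
//				t.Fatalf("request %d unexpectedly blocked: %v", i+1, err)
//			}
//			if i == 5 {
//				he, ok := err.(*echo.HTTPError)
//				if !ok || he.Code != 429 {
//					t.Fatalf("expected 429 on request 6, got %v", err)
//				}
//			}
//		}
//	}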