Source code for apifrom.middleware.rate_limit

"""
Rate limiting middleware for APIFromAnything.

This module provides middleware for rate limiting API requests to prevent abuse.
"""

import time
from collections import defaultdict, deque
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import json

from apifrom.core.request import Request
from apifrom.core.response import Response, JSONResponse
from apifrom.middleware.base import BaseMiddleware


[docs] class RateLimitBackend: """ Base class for rate limit backends. Rate limit backends are responsible for storing and retrieving rate limit data. """ def get(self, key: str) -> Any: """ Get rate limit data for a key. Args: key: The key to get data for Returns: The rate limit data """ raise NotImplementedError("Subclasses must implement get") def set(self, key: str, value: Any, ttl: Optional[int] = None) -> None: """ Set rate limit data for a key. Args: key: The key to set data for value: The data to set ttl: Time to live in seconds """ raise NotImplementedError("Subclasses must implement set") def increment(self, key: str, amount: int = 1, ttl: Optional[int] = None) -> int: """ Increment rate limit counter for a key. Args: key: The key to increment amount: The amount to increment by ttl: Time to live in seconds Returns: The new counter value """ raise NotImplementedError("Subclasses must implement increment") def reset(self, key: str) -> None: """ Reset rate limit data for a key. Args: key: The key to reset """ raise NotImplementedError("Subclasses must implement reset")
class InMemoryRateLimitBackend(RateLimitBackend): """ In-memory implementation of RateLimitBackend. This backend stores rate limit data in memory. """ def __init__(self): """ Initialize a new InMemoryRateLimitBackend instance. """ self._data = {} self._expiry = {} def get(self, key: str) -> Any: """ Get rate limit data for a key. Args: key: The key to get data for Returns: The rate limit data, or None if not found """ self._cleanup() return self._data.get(key) def set(self, key: str, value: Any, ttl: Optional[int] = None) -> None: """ Set rate limit data for a key. Args: key: The key to set data for value: The data to set ttl: Time to live in seconds """ self._data[key] = value if ttl is not None: self._expiry[key] = time.time() + ttl def increment(self, key: str, amount: int = 1, ttl: Optional[int] = None) -> int: """ Increment rate limit counter for a key. Args: key: The key to increment amount: The amount to increment by ttl: Time to live in seconds Returns: The new counter value """ self._cleanup() if key not in self._data: self._data[key] = 0 self._data[key] += amount if ttl is not None: self._expiry[key] = time.time() + ttl return self._data[key] def reset(self, key: str) -> None: """ Reset rate limit data for a key. Args: key: The key to reset """ if key in self._data: del self._data[key] if key in self._expiry: del self._expiry[key] def _cleanup(self) -> None: """ Clean up expired keys. """ now = time.time() expired_keys = [k for k, v in self._expiry.items() if v <= now] for key in expired_keys: self.reset(key)
[docs] class RateLimiter: """ Base rate limiter interface. """ def check_limit(self, key: str) -> Tuple[bool, Dict[str, Any]]: """ Check if a key has exceeded its rate limit. Args: key: The key to check Returns: A tuple containing (allowed, limit_info) """ raise NotImplementedError("Subclasses must implement check_limit") def update(self, key: str) -> None: """ Update the rate limit counter for a key. Args: key: The key to update """ raise NotImplementedError("Subclasses must implement update")
class FixedWindowRateLimiter(RateLimiter): """ Fixed window rate limiter implementation. This rate limiter uses a fixed time window to limit requests. """ def __init__(self, limit: int, window: int = 60): """ Initialize the fixed window rate limiter. Args: limit: Maximum number of requests allowed in the window window: Time window in seconds """ self.limit = limit self.window = window self.counters: Dict[str, Dict[int, int]] = defaultdict(lambda: defaultdict(int)) def _get_current_window(self) -> int: """ Get the current time window. Returns: The current time window as an integer """ return int(time.time() / self.window) def check_limit(self, key: str) -> Tuple[bool, Dict[str, Any]]: """ Check if a key has exceeded its rate limit. Args: key: The key to check Returns: A tuple containing (allowed, limit_info) """ current_window = self._get_current_window() current_count = self.counters[key][current_window] allowed = current_count < self.limit reset_time = (current_window + 1) * self.window limit_info = { "limit": self.limit, "remaining": max(0, self.limit - current_count), "reset": reset_time, "window": self.window, } return allowed, limit_info def update(self, key: str) -> None: """ Update the rate limit counter for a key. Args: key: The key to update """ current_window = self._get_current_window() self.counters[key][current_window] += 1 # Clean up old windows for k in list(self.counters.keys()): for w in list(self.counters[k].keys()): if w < current_window: del self.counters[k][w] if not self.counters[k]: del self.counters[k] class SlidingWindowRateLimiter(RateLimiter): """ Sliding window rate limiter implementation. This rate limiter uses a sliding time window to limit requests. """ def __init__(self, limit: int, window: int = 60): """ Initialize the sliding window rate limiter. Args: limit: Maximum number of requests allowed in the window window: Time window in seconds """ self.limit = limit self.window = window self.requests: Dict[str, deque] = defaultdict(deque) def check_limit(self, key: str) -> Tuple[bool, Dict[str, Any]]: """ Check if a key has exceeded its rate limit. Args: key: The key to check Returns: A tuple containing (allowed, limit_info) """ # Remove expired timestamps self._clean_old_requests(key) # Count current requests in window current_count = len(self.requests[key]) allowed = current_count < self.limit # Calculate reset time reset_time = time.time() + self.window if not self.requests[key] else self.requests[key][0] + self.window limit_info = { "limit": self.limit, "remaining": max(0, self.limit - current_count), "reset": reset_time, "window": self.window, } return allowed, limit_info def update(self, key: str) -> None: """ Update the rate limit counter for a key. Args: key: The key to update """ # Remove expired timestamps self._clean_old_requests(key) # Add current timestamp self.requests[key].append(time.time()) def _clean_old_requests(self, key: str) -> None: """ Remove expired timestamps for a key. Args: key: The key to clean """ if key not in self.requests: return current_time = time.time() cutoff_time = current_time - self.window while self.requests[key] and self.requests[key][0] <= cutoff_time: self.requests[key].popleft() if not self.requests[key]: del self.requests[key] class TokenBucketRateLimiter(RateLimiter): """ Token bucket rate limiter implementation. This rate limiter uses a token bucket algorithm to limit requests. """ def __init__(self, rate: float, capacity: int): """ Initialize the token bucket rate limiter. Args: rate: Token refill rate per second capacity: Maximum number of tokens in the bucket """ self.rate = rate self.capacity = capacity self.buckets: Dict[str, Dict[str, float]] = {} def check_limit(self, key: str) -> Tuple[bool, Dict[str, Any]]: """ Check if a key has exceeded its rate limit. Args: key: The key to check Returns: A tuple containing (allowed, limit_info) """ if key not in self.buckets: self.buckets[key] = { "tokens": self.capacity, "last_refill": time.time(), } # Refill tokens self._refill(key) # Check if we have enough tokens allowed = self.buckets[key]["tokens"] >= 1 # Calculate time until next token time_until_next_token = 0 if allowed else (1 - self.buckets[key]["tokens"]) / self.rate limit_info = { "limit": self.capacity, "remaining": int(self.buckets[key]["tokens"]), "reset": time.time() + time_until_next_token, "rate": self.rate, } return allowed, limit_info def update(self, key: str) -> None: """ Update the rate limit counter for a key. Args: key: The key to update """ if key not in self.buckets: self.buckets[key] = { "tokens": self.capacity - 1, # Consume one token "last_refill": time.time(), } else: # Refill tokens self._refill(key) # Consume one token self.buckets[key]["tokens"] -= 1 def _refill(self, key: str) -> None: """ Refill tokens for a key. Args: key: The key to refill """ now = time.time() time_passed = now - self.buckets[key]["last_refill"] new_tokens = time_passed * self.rate self.buckets[key]["tokens"] = min(self.capacity, self.buckets[key]["tokens"] + new_tokens) self.buckets[key]["last_refill"] = now class RateLimitMiddleware(BaseMiddleware): """ Middleware for rate limiting API requests. """ def __init__( self, limiter: RateLimiter, key_func: Optional[Callable[[Request], str]] = None, exclude_routes: Optional[List[str]] = None, headers_enabled: bool = True, ): """ Initialize the rate limit middleware. Args: limiter: The rate limiter to use key_func: Function to extract the rate limit key from a request exclude_routes: Routes to exclude from rate limiting headers_enabled: Whether to include rate limit headers in responses """ super().__init__( key_func=key_func, exclude_routes=exclude_routes, headers_enabled=headers_enabled ) self.limiter = limiter self.key_func = key_func or self._default_key_func self.exclude_routes = exclude_routes or [] self.headers_enabled = headers_enabled def _default_key_func(self, request: Request) -> str: """ Default function to extract the rate limit key from a request. Args: request: The request object Returns: The rate limit key """ # Use client IP as the default key return request.client_ip or "unknown" def _should_limit(self, request: Request) -> bool: """ Determine if a request should be rate limited. Args: request: The request object Returns: True if the request should be rate limited, False otherwise """ # Don't limit excluded routes for route in self.exclude_routes: if request.path.startswith(route): return False return True def _add_rate_limit_headers(self, response: Response, limit_info: Dict[str, Any]) -> None: """ Add rate limit headers to a response. Args: response: The response object limit_info: Rate limit information """ if not self.headers_enabled: return response.headers["X-RateLimit-Limit"] = str(limit_info["limit"]) response.headers["X-RateLimit-Remaining"] = str(limit_info["remaining"]) response.headers["X-RateLimit-Reset"] = str(int(limit_info["reset"])) if "window" in limit_info: response.headers["X-RateLimit-Window"] = str(limit_info["window"]) if "rate" in limit_info: response.headers["X-RateLimit-Rate"] = str(limit_info["rate"]) async def process_request(self, request: Request) -> Request: """ Process a request through the rate limit middleware. Args: request: The request object Returns: The request object """ # Check if we should rate limit this request if not self._should_limit(request): return request # Get the rate limit key key = self.key_func(request) # Check if the request is allowed allowed, limit_info = self.limiter.check_limit(key) if not allowed: # Request is rate limited response = JSONResponse( {"error": "Rate limit exceeded"}, status_code=429, ) self._add_rate_limit_headers(response, limit_info) # Store the rate limited response in the request state request.state.rate_limited_response = response return request # Update the rate limiter self.limiter.update(key) # Store the limit info in the request state for use in process_response request.state.rate_limit_info = limit_info return request async def process_response(self, response: Response) -> Response: """ Process a response through the rate limit middleware. Args: response: The response object Returns: The response object """ # If we have a rate limited response in the request state, return it if hasattr(response.request.state, 'rate_limited_response'): return response.request.state.rate_limited_response # Add rate limit headers to the response if we have limit info if hasattr(response.request.state, 'rate_limit_info'): self._add_rate_limit_headers(response, response.request.state.rate_limit_info) return response async def __call__(self, scope, receive, send): """ ASGI callable. Args: scope: The ASGI scope. receive: The ASGI receive function. send: The ASGI send function. """ # Only process HTTP requests if scope["type"] != "http": await self.app(scope, receive, send) return # Extract request information from scope path = scope["path"] # Check if we should rate limit this request should_rate_limit = True for route in self.exclude_routes: if path.startswith(route): should_rate_limit = False break if not should_rate_limit: await self.app(scope, receive, send) return # Generate rate limit key client_ip = scope.get("client", ("127.0.0.1", 0))[0] key = f"{client_ip}:{path}" # Check rate limit allowed, limit_info = self.limiter.check_limit(key) # Update the rate limit counter self.limiter.update(key) # If rate limit is exceeded, return 429 response if not allowed: headers = [ (b"content-type", b"application/json"), ] if self.headers_enabled: headers.extend([ (b"X-RateLimit-Limit", str(limit_info["limit"]).encode()), (b"X-RateLimit-Remaining", b"0"), (b"X-RateLimit-Reset", str(limit_info["reset"]).encode()), ]) await send({ "type": "http.response.start", "status": 429, "headers": headers, }) await send({ "type": "http.response.body", "body": json.dumps({"error": "Rate limit exceeded"}).encode(), }) return # Not rate limited, process request # Add rate limit headers to response original_send = send async def send_wrapper(message): if message["type"] == "http.response.start" and self.headers_enabled: # Add rate limit headers headers = list(message["headers"]) headers.extend([ (b"X-RateLimit-Limit", str(limit_info["limit"]).encode()), (b"X-RateLimit-Remaining", str(limit_info["remaining"]).encode()), (b"X-RateLimit-Reset", str(limit_info["reset"]).encode()), ]) message["headers"] = headers await original_send(message) await self.app(scope, receive, send_wrapper)
[docs] class RateLimit: """ Decorator for controlling rate limiting on specific endpoints. """ @staticmethod def limit(limit: int, window: int = 60, key_func: Optional[Callable] = None): """ Apply a rate limit to an endpoint. Args: limit: Maximum number of requests allowed in the window window: Time window in seconds key_func: Function to extract the rate limit key from a request Returns: A decorator function """ def decorator(func): func._rate_limit = { "limit": limit, "window": window, "key_func": key_func, } return func return decorator @staticmethod def exempt(func): """ Exempt an endpoint from rate limiting. Args: func: The function to decorate Returns: The decorated function """ func._rate_limit_exempt = True return func