# Source code for apifrom.utils.rate_limit

from typing import Optional, Callable, Dict, List, Any, Union, TypeVar, Generic, cast
import time
import asyncio
from datetime import datetime, timedelta

class RateLimiter:
    """
    A rate limiter implementation that limits the number of requests in a specified time period.

    For every key this class keeps the timestamps of accepted requests and
    rejects a request once ``rate`` of them are still within the last
    ``period`` seconds (a sliding log of requests).
    """

    def __init__(
        self,
        rate: int,
        period: int,
        burst: Optional[int] = None,
        key_func: Optional[Callable[[Any], str]] = None,
        error_message: Optional[str] = None
    ):
        """
        Initialize a rate limiter.

        Args:
            rate: Maximum number of requests allowed in the time period
            period: Time period in seconds
            burst: Optional burst limit; falls back to ``rate`` when not given
                (or falsy). The sliding-log check in this class does not use it.
            key_func: Optional function to extract a key from the request;
                defaults to ``str(request)``
            error_message: Optional custom error message when rate limit is exceeded
        """
        self.rate = rate
        self.period = period
        self.burst = burst or rate
        self.key_func = key_func or (lambda r: str(r))
        self.error_message = error_message or f"Rate limit exceeded: {rate} requests per {period} seconds allowed"
        # key -> time.time() timestamps of accepted requests. Expired entries are
        # only purged lazily by is_rate_limited, so readers must filter them.
        self._tokens: Dict[str, List[float]] = {}

    def _valid_tokens(self, key: str, now: float) -> List[float]:
        """Return the timestamps for ``key`` that are still inside the window at ``now``."""
        return [t for t in self._tokens.get(key, []) if now - t <= self.period]

    async def is_rate_limited(self, request: Any) -> bool:
        """
        Check if a request exceeds the rate limit.

        When the request is NOT limited, it is recorded (counted against the
        limit) as a side effect.

        Args:
            request: The request to check

        Returns:
            bool: True if the request is rate limited, False otherwise
        """
        key = self.key_func(request)
        now = time.time()

        # Purge expired timestamps; this also creates the entry for new keys.
        tokens = self._valid_tokens(key, now)
        self._tokens[key] = tokens

        if len(tokens) >= self.rate:
            return True

        # Record this accepted request.
        tokens.append(now)
        return False

    async def acquire(self, request: Any) -> bool:
        """
        Try to acquire a token for the request.

        Args:
            request: The request to acquire a token for

        Returns:
            bool: True if a token was acquired, False if rate limited
        """
        is_limited = await self.is_rate_limited(request)
        return not is_limited

    def get_remaining(self, request: Any) -> int:
        """
        Get the number of remaining requests allowed.

        Does not modify any state.

        Args:
            request: The request to check

        Returns:
            int: Number of remaining requests allowed (never negative)
        """
        key = self.key_func(request)
        return max(0, self.rate - len(self._valid_tokens(key, time.time())))

    def get_reset_time(self, request: Any) -> float:
        """
        Get the time until the rate limit resets.

        This is the number of seconds until the oldest request still inside
        the window expires, freeing one slot. Timestamps that have already
        expired (but were not yet purged) are ignored.

        Args:
            request: The request to check

        Returns:
            float: Time in seconds until the rate limit resets (0.0 if no
            request is currently counted for this key)
        """
        key = self.key_func(request)
        now = time.time()
        valid = self._valid_tokens(key, now)
        if not valid:
            return 0.0
        oldest_token = min(valid)
        return max(0.0, self.period - (now - oldest_token))
class TokenBucketRateLimiter(RateLimiter):
    """
    A token bucket implementation of rate limiting.

    Each key owns a bucket holding at most ``burst`` tokens. Tokens are refilled
    continuously at ``rate / period`` tokens per second, and every accepted
    request consumes one token.
    """

    def __init__(
        self,
        rate: int,
        period: int,
        burst: Optional[int] = None,
        key_func: Optional[Callable[[Any], str]] = None,
        error_message: Optional[str] = None
    ):
        """
        Initialize a token bucket rate limiter.

        Args:
            rate: Number of tokens added to the bucket per ``period`` (the
                refill is continuous, at ``rate / period`` tokens per second)
            period: Time span in seconds over which ``rate`` tokens are added
            burst: Maximum bucket size (defaults to rate)
            key_func: Function to extract a key from the request
            error_message: Custom error message when rate limit is exceeded
        """
        super().__init__(rate, period, burst, key_func, error_message)
        # key -> time.time() of the last refill computation for that bucket.
        self._last_refill: Dict[str, float] = {}
        # Override the parent's storage: here each key maps to a (possibly
        # fractional) token count instead of a list of timestamps.
        self._tokens = {}  # type: ignore[assignment]

    def _available(self, key: str, now: float) -> float:
        """Return the refilled token count for ``key`` at ``now`` without mutating state."""
        if key not in self._tokens:
            # Unseen keys start with a full bucket.
            return self.burst
        elapsed = now - self._last_refill.get(key, now)
        return min(self.burst, self._tokens[key] + elapsed * (self.rate / self.period))

    async def is_rate_limited(self, request: Any) -> bool:
        """
        Check if a request exceeds the rate limit using the token bucket algorithm.

        Refills the bucket for the elapsed time and consumes one token if at
        least one whole token is available.

        Args:
            request: The request to check

        Returns:
            bool: True if the request is rate limited, False otherwise
        """
        key = self.key_func(request)
        now = time.time()

        available = self._available(key, now)
        self._last_refill[key] = now

        if available < 1:
            # Keep the fractional refill so it keeps accumulating.
            self._tokens[key] = available
            return True

        # Consume a token.
        self._tokens[key] = available - 1
        return False

    def get_remaining(self, request: Any) -> int:
        """
        Get the number of whole tokens currently available.

        Overrides the parent, whose timestamp-list logic does not apply to
        token counts. Does not modify any state.

        Args:
            request: The request to check

        Returns:
            int: Number of requests that would currently be allowed
        """
        key = self.key_func(request)
        return max(0, int(self._available(key, time.time())))

    def get_reset_time(self, request: Any) -> float:
        """
        Get the time until the next request would be allowed.

        Args:
            request: The request to check

        Returns:
            float: 0.0 if a whole token is available now; otherwise the
            seconds until one accumulates (``inf`` if the refill rate is not
            positive).
        """
        key = self.key_func(request)
        available = self._available(key, time.time())
        if available >= 1:
            return 0.0
        refill_per_second = self.rate / self.period
        if refill_per_second <= 0:
            return float("inf")
        return (1 - available) / refill_per_second


class FixedWindowRateLimiter(RateLimiter):
    """
    A fixed window implementation of rate limiting.

    Time is divided into consecutive windows of ``period`` seconds, and at
    most ``rate`` requests per key are accepted within each window.
    """

    def __init__(
        self,
        rate: int,
        period: int,
        key_func: Optional[Callable[[Any], str]] = None,
        error_message: Optional[str] = None
    ):
        """
        Initialize a fixed window rate limiter.

        Args:
            rate: Maximum number of requests allowed in the window
            period: Window size in seconds
            key_func: Function to extract a key from the request
            error_message: Custom error message when rate limit is exceeded
        """
        super().__init__(rate, period, None, key_func, error_message)
        # key -> {window index (int(now / period)): accepted request count}
        self._windows: Dict[str, Dict[int, int]] = {}

    def _current_count(self, key: str, window: int) -> int:
        """Return how many requests for ``key`` were accepted in ``window``."""
        return self._windows.get(key, {}).get(window, 0)

    async def is_rate_limited(self, request: Any) -> bool:
        """
        Check if a request exceeds the rate limit using a fixed window algorithm.

        Args:
            request: The request to check

        Returns:
            bool: True if the request is rate limited, False otherwise
        """
        key = self.key_func(request)
        now = time.time()
        current_window = int(now / self.period)

        # Drop counters older than the previous window.
        windows = {
            w: count
            for w, count in self._windows.get(key, {}).items()
            if w >= current_window - 1
        }
        self._windows[key] = windows

        current_count = windows.get(current_window, 0)
        if current_count >= self.rate:
            return True

        windows[current_window] = current_count + 1
        return False

    def get_remaining(self, request: Any) -> int:
        """
        Get the number of requests still allowed in the current window.

        Overrides the parent, which reads ``_tokens`` (never populated by this
        class). Does not modify any state.

        Args:
            request: The request to check

        Returns:
            int: Number of remaining requests allowed (never negative)
        """
        key = self.key_func(request)
        current_window = int(time.time() / self.period)
        return max(0, self.rate - self._current_count(key, current_window))

    def get_reset_time(self, request: Any) -> float:
        """
        Get the time until the current window ends.

        Args:
            request: The request to check

        Returns:
            float: Seconds until the current window closes, or 0.0 if no
            request has been counted for this key in the current window.
        """
        key = self.key_func(request)
        now = time.time()
        current_window = int(now / self.period)
        if not self._current_count(key, current_window):
            return 0.0
        return max(0.0, (current_window + 1) * self.period - now)