#!/usr/bin/env python3 """ InsightFlow Rate Limiter - Phase 6 API 限流中间件 支持基于内存的滑动窗口限流 """ import asyncio import time from collections import defaultdict from collections.abc import Callable from dataclasses import dataclass from functools import wraps @dataclass class RateLimitConfig: """限流配置""" requests_per_minute: int = 60 burst_size: int = 10 # 突发请求数 window_size: int = 60 # 窗口大小(秒) @dataclass class RateLimitInfo: """限流信息""" allowed: bool remaining: int reset_time: int # 重置时间戳 retry_after: int # 需要等待的秒数 class SlidingWindowCounter: """滑动窗口计数器""" def __init__(self, window_size: int = 60) -> None: self.window_size = window_size self.requests: dict[int, int] = defaultdict(int) # 秒级计数 self._lock = asyncio.Lock() self._cleanup_lock = asyncio.Lock() async def add_request(self) -> int: """添加请求,返回当前窗口内的请求数""" async with self._lock: now = int(time.time()) self.requests[now] += 1 self._cleanup_old(now) return sum(self.requests.values()) async def get_count(self) -> int: """获取当前窗口内的请求数""" async with self._lock: now = int(time.time()) self._cleanup_old(now) return sum(self.requests.values()) def _cleanup_old(self, now: int) -> None: """清理过期的请求记录 - 使用独立锁避免竞态条件""" cutoff = now - self.window_size old_keys = [k for k in list(self.requests.keys()) if k < cutoff] for k in old_keys: self.requests.pop(k, None) class RateLimiter: """API 限流器""" def __init__(self) -> None: # key -> SlidingWindowCounter self.counters: dict[str, SlidingWindowCounter] = {} # key -> RateLimitConfig self.configs: dict[str, RateLimitConfig] = {} self._lock = asyncio.Lock() self._cleanup_lock = asyncio.Lock() async def is_allowed(self, key: str, config: RateLimitConfig | None = None) -> RateLimitInfo: """ 检查是否允许请求 Args: key: 限流键(如 API Key ID) config: 限流配置,如果为 None 则使用默认配置 Returns: RateLimitInfo """ if config is None: config = RateLimitConfig() async with self._lock: if key not in self.counters: self.counters[key] = SlidingWindowCounter(config.window_size) self.configs[key] = config counter = self.counters[key] stored_config = self.configs.get(key, config) # 获取当前计数 current_count = await counter.get_count() # 计算剩余配额 remaining = max(0, stored_config.requests_per_minute - current_count) # 计算重置时间 now = int(time.time()) reset_time = now + stored_config.window_size # 检查是否超过限制 if current_count >= stored_config.requests_per_minute: return RateLimitInfo( allowed=False, remaining=0, reset_time=reset_time, retry_after=stored_config.window_size, ) # 允许请求,增加计数 await counter.add_request() return RateLimitInfo( allowed=True, remaining=remaining - 1, reset_time=reset_time, retry_after=0 ) async def get_limit_info(self, key: str) -> RateLimitInfo: """获取限流信息(不增加计数)""" if key not in self.counters: config = RateLimitConfig() return RateLimitInfo( allowed=True, remaining=config.requests_per_minute, reset_time=int(time.time()) + config.window_size, retry_after=0, ) counter = self.counters[key] config = self.configs.get(key, RateLimitConfig()) current_count = await counter.get_count() remaining = max(0, config.requests_per_minute - current_count) reset_time = int(time.time()) + config.window_size return RateLimitInfo( allowed=current_count < config.requests_per_minute, remaining=remaining, reset_time=reset_time, retry_after=max(0, config.window_size) if current_count >= config.requests_per_minute else 0, ) def reset(self, key: str | None = None) -> None: """重置限流计数器""" if key: self.counters.pop(key, None) self.configs.pop(key, None) else: self.counters.clear() self.configs.clear() # 全局限流器实例 _rate_limiter: RateLimiter | None = None def get_rate_limiter() -> RateLimiter: """获取限流器实例""" global _rate_limiter if _rate_limiter is None: _rate_limiter = RateLimiter() return _rate_limiter # 限流装饰器(用于函数级别限流) def rate_limit(requests_per_minute: int = 60, key_func: Callable | None = None) -> None: """ 限流装饰器 Args: requests_per_minute: 每分钟请求数限制 key_func: 生成限流键的函数,默认为 None(使用函数名) """ def decorator(func) -> None: limiter = get_rate_limiter() config = RateLimitConfig(requests_per_minute=requests_per_minute) @wraps(func) async def async_wrapper(*args, **kwargs) -> None: key = key_func(*args, **kwargs) if key_func else func.__name__ info = await limiter.is_allowed(key, config) if not info.allowed: raise RateLimitExceeded( f"Rate limit exceeded. Try again in {info.retry_after} seconds." ) return await func(*args, **kwargs) @wraps(func) def sync_wrapper(*args, **kwargs) -> None: key = key_func(*args, **kwargs) if key_func else func.__name__ # 同步版本使用 asyncio.run info = asyncio.run(limiter.is_allowed(key, config)) if not info.allowed: raise RateLimitExceeded( f"Rate limit exceeded. Try again in {info.retry_after} seconds." ) return func(*args, **kwargs) return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper return decorator class RateLimitExceeded(Exception): """限流异常"""