223 lines
6.6 KiB
Python
223 lines
6.6 KiB
Python
#!/usr/bin/env python3
|
||
"""
InsightFlow Rate Limiter - Phase 6

API 限流中间件 (API rate-limiting middleware)

支持基于内存的滑动窗口限流 (in-memory, sliding-window rate limiting)
"""
|
||
|
||
import asyncio
|
||
import time
|
||
from collections import defaultdict
|
||
from collections.abc import Callable
|
||
from dataclasses import dataclass
|
||
from functools import wraps
|
||
|
||
|
||
@dataclass
class RateLimitConfig:
    """Limits applied to a single rate-limit key.

    ``requests_per_minute`` is compared against the number of requests counted
    in the key's sliding window of ``window_size`` seconds, so it is effectively
    "requests per window"; the two only coincide with the default 60 s window.
    """

    # Maximum number of requests allowed within one window.
    requests_per_minute: int = 60
    # Burst allowance; not read anywhere in this module.
    burst_size: int = 10
    # Sliding-window length, in seconds.
    window_size: int = 60
|
||
|
||
|
||
@dataclass
class RateLimitInfo:
    """Result of a rate-limit check for one key."""

    # True when the request may proceed.
    allowed: bool
    # Requests still available in the current window.
    remaining: int
    # Unix timestamp (whole seconds) reported as the window reset time.
    reset_time: int
    # Seconds the caller is told to wait before retrying; 0 when allowed.
    retry_after: int
|
||
|
||
|
||
class SlidingWindowCounter:
|
||
"""滑动窗口计数器"""
|
||
|
||
def __init__(self, window_size: int = 60) -> None:
|
||
self.window_size = window_size
|
||
self.requests: dict[int, int] = defaultdict(int) # 秒级计数
|
||
self._lock = asyncio.Lock()
|
||
self._cleanup_lock = asyncio.Lock()
|
||
|
||
async def add_request(self) -> int:
|
||
"""添加请求,返回当前窗口内的请求数"""
|
||
async with self._lock:
|
||
now = int(time.time())
|
||
self.requests[now] += 1
|
||
self._cleanup_old(now)
|
||
return sum(self.requests.values())
|
||
|
||
async def get_count(self) -> int:
|
||
"""获取当前窗口内的请求数"""
|
||
async with self._lock:
|
||
now = int(time.time())
|
||
self._cleanup_old(now)
|
||
return sum(self.requests.values())
|
||
|
||
def _cleanup_old(self, now: int) -> None:
|
||
"""清理过期的请求记录 - 使用独立锁避免竞态条件"""
|
||
cutoff = now - self.window_size
|
||
old_keys = [k for k in list(self.requests.keys()) if k < cutoff]
|
||
for k in old_keys:
|
||
self.requests.pop(k, None)
|
||
|
||
|
||
class RateLimiter:
    """Per-key API rate limiter backed by one SlidingWindowCounter per key.

    The config passed on the first ``is_allowed`` call for a key is stored and
    reused for every later call with that key; configs passed on later calls
    are ignored until ``reset`` forgets the key. Keys are only removed by
    ``reset``.
    """

    def __init__(self) -> None:
        # key -> SlidingWindowCounter
        self.counters: dict[str, SlidingWindowCounter] = {}
        # key -> RateLimitConfig captured when the key was first seen
        self.configs: dict[str, RateLimitConfig] = {}
        # Held for the whole of is_allowed so that reading the count and
        # recording the request form one critical section.
        self._lock = asyncio.Lock()

    async def is_allowed(self, key: str, config: RateLimitConfig | None = None) -> RateLimitInfo:
        """Check whether a request for ``key`` may proceed and, if so, record it.

        Args:
            key: Rate-limit key (e.g. an API key ID).
            config: Limits to register if ``key`` has not been seen yet. When
                None, a default ``RateLimitConfig()`` is used. Ignored for keys
                that already have a stored config.

        Returns:
            RateLimitInfo. When denied, ``remaining`` is 0 and ``retry_after`` is
            the key's window size. When allowed, the request has been recorded,
            ``remaining`` excludes it, and ``retry_after`` is 0.
        """
        if config is None:
            config = RateLimitConfig()

        async with self._lock:
            counter = self.counters.get(key)
            if counter is None:
                counter = SlidingWindowCounter(config.window_size)
                self.counters[key] = counter
                self.configs[key] = config
            stored_config = self.configs.get(key, config)
            limit = stored_config.requests_per_minute

            # Count and increment under the same lock so two concurrent callers
            # cannot both pass the check for the last free slot.
            current_count = await counter.get_count()
            # Reported reset time is conservative: a full window from now.
            reset_time = int(time.time()) + stored_config.window_size

            if current_count >= limit:
                return RateLimitInfo(
                    allowed=False,
                    remaining=0,
                    reset_time=reset_time,
                    retry_after=stored_config.window_size,
                )

            await counter.add_request()
            # current_count < limit here, so this is never negative.
            return RateLimitInfo(
                allowed=True,
                remaining=limit - current_count - 1,
                reset_time=reset_time,
                retry_after=0,
            )

    async def get_limit_info(self, key: str) -> RateLimitInfo:
        """Report the quota state for ``key`` without recording a request.

        Unknown keys are reported against a default ``RateLimitConfig()``.
        """
        counter = self.counters.get(key)
        if counter is None:
            config = RateLimitConfig()
            return RateLimitInfo(
                allowed=True,
                remaining=config.requests_per_minute,
                reset_time=int(time.time()) + config.window_size,
                retry_after=0,
            )

        config = self.configs.get(key, RateLimitConfig())
        current_count = await counter.get_count()
        limit = config.requests_per_minute
        exhausted = current_count >= limit

        return RateLimitInfo(
            allowed=not exhausted,
            remaining=max(0, limit - current_count),
            reset_time=int(time.time()) + config.window_size,
            retry_after=max(0, config.window_size) if exhausted else 0,
        )

    def reset(self, key: str | None = None) -> None:
        """Forget the counter and config for ``key``, or for every key when falsy."""
        if key:
            self.counters.pop(key, None)
            self.configs.pop(key, None)
        else:
            self.counters.clear()
            self.configs.clear()
|
||
|
||
|
||
# Process-wide limiter instance, created lazily by get_rate_limiter().
_rate_limiter: RateLimiter | None = None


def get_rate_limiter() -> RateLimiter:
    """Return the module-wide RateLimiter, creating it on the first call."""
    global _rate_limiter
    if _rate_limiter is not None:
        return _rate_limiter
    _rate_limiter = RateLimiter()
    return _rate_limiter
|
||
|
||
|
||
# 限流装饰器(用于函数级别限流)
|
||
|
||
|
||
def rate_limit(requests_per_minute: int = 60, key_func: Callable | None = None) -> None:
|
||
"""
|
||
限流装饰器
|
||
|
||
Args:
|
||
requests_per_minute: 每分钟请求数限制
|
||
key_func: 生成限流键的函数,默认为 None(使用函数名)
|
||
"""
|
||
|
||
def decorator(func) -> None:
|
||
limiter = get_rate_limiter()
|
||
config = RateLimitConfig(requests_per_minute=requests_per_minute)
|
||
|
||
@wraps(func)
|
||
async def async_wrapper(*args, **kwargs) -> None:
|
||
key = key_func(*args, **kwargs) if key_func else func.__name__
|
||
info = await limiter.is_allowed(key, config)
|
||
|
||
if not info.allowed:
|
||
raise RateLimitExceeded(
|
||
f"Rate limit exceeded. Try again in {info.retry_after} seconds."
|
||
)
|
||
|
||
return await func(*args, **kwargs)
|
||
|
||
@wraps(func)
|
||
def sync_wrapper(*args, **kwargs) -> None:
|
||
key = key_func(*args, **kwargs) if key_func else func.__name__
|
||
# 同步版本使用 asyncio.run
|
||
info = asyncio.run(limiter.is_allowed(key, config))
|
||
|
||
if not info.allowed:
|
||
raise RateLimitExceeded(
|
||
f"Rate limit exceeded. Try again in {info.retry_after} seconds."
|
||
)
|
||
|
||
return func(*args, **kwargs)
|
||
|
||
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
|
||
|
||
return decorator
|
||
|
||
|
||
class RateLimitExceeded(Exception):
    """Raised by a ``rate_limit``-wrapped callable when its key is over quota."""
|