Files
insightflow/backend/rate_limiter.py
OpenClaw Bot 17bda3dbce fix: auto-fix code issues (cron)
- 修复重复导入/字段
- 修复异常处理
- 修复PEP8格式问题
- 添加类型注解
2026-02-27 18:09:24 +08:00

210 lines
6.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
InsightFlow Rate Limiter - Phase 6
API 限流中间件
支持基于内存的滑动窗口限流
"""
import asyncio
import time
from collections import defaultdict
from collections.abc import Callable
from dataclasses import dataclass
from functools import wraps
@dataclass
class RateLimitConfig:
    """Rate-limit settings."""

    # Request budget; the field name expresses it per minute.
    requests_per_minute: int = 60
    # Number of burst requests.
    burst_size: int = 10
    # Window size, in seconds.
    window_size: int = 60
@dataclass
class RateLimitInfo:
    """Outcome of a rate-limit check."""

    # Whether the request is permitted.
    allowed: bool
    # Requests still available.
    remaining: int
    # Reset time, as a timestamp.
    reset_time: int
    # Number of seconds to wait before retrying.
    retry_after: int
class SlidingWindowCounter:
"""滑动窗口计数器"""
def __init__(self, window_size: int = 60):
self.window_size = window_size
self.requests: dict[int, int] = defaultdict(int) # 秒级计数
self._lock = asyncio.Lock()
self._cleanup_lock = asyncio.Lock()
async def add_request(self) -> int:
"""添加请求,返回当前窗口内的请求数"""
async with self._lock:
now = int(time.time())
self.requests[now] += 1
self._cleanup_old(now)
return sum(self.requests.values())
async def get_count(self) -> int:
"""获取当前窗口内的请求数"""
async with self._lock:
now = int(time.time())
self._cleanup_old(now)
return sum(self.requests.values())
def _cleanup_old(self, now: int):
"""清理过期的请求记录 - 使用独立锁避免竞态条件"""
cutoff = now - self.window_size
old_keys = [k for k in list(self.requests.keys()) if k < cutoff]
for k in old_keys:
self.requests.pop(k, None)
class RateLimiter:
    """In-memory API rate limiter keyed by an arbitrary string.

    Every key owns a ``SlidingWindowCounter`` plus the ``RateLimitConfig``
    that was supplied the first time the key was seen.
    """

    def __init__(self):
        # key -> SlidingWindowCounter
        self.counters: dict[str, SlidingWindowCounter] = {}
        # key -> RateLimitConfig
        self.configs: dict[str, RateLimitConfig] = {}
        self._lock = asyncio.Lock()
        # Kept so the attribute still exists; no method of this class acquires it.
        self._cleanup_lock = asyncio.Lock()

    async def is_allowed(self, key: str, config: RateLimitConfig | None = None) -> RateLimitInfo:
        """Check whether a request for ``key`` may proceed and, if so, count it.

        Args:
            key: Rate-limit key (for example an API key ID).
            config: Limits to apply. It is stored only the first time ``key``
                is seen; on later calls the stored config is used until the
                key is reset. ``None`` means a default ``RateLimitConfig()``.

        Returns:
            A ``RateLimitInfo``. When the request is rejected, ``remaining``
            is 0 and ``retry_after`` is the full window size; when accepted,
            ``remaining`` already accounts for this request.
        """
        if config is None:
            config = RateLimitConfig()

        async with self._lock:
            if key not in self.counters:
                self.counters[key] = SlidingWindowCounter(config.window_size)
                self.configs[key] = config
            counter = self.counters[key]
            stored_config = self.configs.get(key, config)

        # Requests currently inside the window.
        current_count = await counter.get_count()
        # Quota left before this request is counted.
        remaining = max(0, stored_config.requests_per_minute - current_count)
        # Reported reset time is always one full window from now.
        reset_time = int(time.time()) + stored_config.window_size

        if current_count >= stored_config.requests_per_minute:
            return RateLimitInfo(
                allowed=False,
                remaining=0,
                reset_time=reset_time,
                retry_after=stored_config.window_size,
            )

        # Accepted: record the request.
        await counter.add_request()
        return RateLimitInfo(allowed=True, remaining=remaining - 1, reset_time=reset_time, retry_after=0)

    async def get_limit_info(self, key: str) -> RateLimitInfo:
        """Report the limit state for ``key`` without counting a request."""
        if key not in self.counters:
            # Unknown key: report a full budget under the default config.
            config = RateLimitConfig()
            return RateLimitInfo(
                allowed=True,
                remaining=config.requests_per_minute,
                reset_time=int(time.time()) + config.window_size,
                retry_after=0,
            )

        counter = self.counters[key]
        config = self.configs.get(key, RateLimitConfig())
        current_count = await counter.get_count()
        remaining = max(0, config.requests_per_minute - current_count)
        reset_time = int(time.time()) + config.window_size
        exhausted = current_count >= config.requests_per_minute
        return RateLimitInfo(
            allowed=not exhausted,
            remaining=remaining,
            reset_time=reset_time,
            retry_after=max(0, config.window_size) if exhausted else 0,
        )

    def reset(self, key: str | None = None):
        """Forget the counter and config of ``key``, or of every key.

        Args:
            key: Key to reset. ``None`` (the default) resets all keys.
        """
        # Compare against None rather than truthiness so that an empty-string
        # key resets only that key instead of wiping every counter.
        if key is not None:
            self.counters.pop(key, None)
            self.configs.pop(key, None)
            return
        self.counters.clear()
        self.configs.clear()
# Module-wide limiter instance, created on first access.
_rate_limiter: RateLimiter | None = None


def get_rate_limiter() -> RateLimiter:
    """Return the shared ``RateLimiter``, building it on the first call."""
    global _rate_limiter
    if _rate_limiter is not None:
        return _rate_limiter
    _rate_limiter = RateLimiter()
    return _rate_limiter
# Rate-limit decorator (function-level limiting)
def rate_limit(requests_per_minute: int = 60, key_func: Callable | None = None):
    """Decorate a sync or async function so that its calls are rate limited.

    Args:
        requests_per_minute: Request budget passed to ``RateLimitConfig``.
        key_func: Callable that receives the wrapped function's arguments and
            returns the rate-limit key. When not given, the wrapped
            function's ``__name__`` is the key.

    Returns:
        A decorator. The decorated callable raises ``RateLimitExceeded``
        when the limiter rejects the call, otherwise it calls through.

    Note:
        The synchronous wrapper drives the limiter with ``asyncio.run``, which
        cannot be called while an event loop is already running in the same
        thread.
    """
    # Local import: asyncio.iscoroutinefunction is deprecated since Python 3.14
    # in favour of inspect.iscoroutinefunction.
    import inspect

    def decorator(func):
        limiter = get_rate_limiter()
        config = RateLimitConfig(requests_per_minute=requests_per_minute)

        def resolve_key(args, kwargs):
            """Build the rate-limit key for one call."""
            return key_func(*args, **kwargs) if key_func else func.__name__

        def ensure_allowed(info):
            """Raise ``RateLimitExceeded`` when the limiter rejected the call."""
            if not info.allowed:
                raise RateLimitExceeded(f"Rate limit exceeded. Try again in {info.retry_after} seconds.")

        @wraps(func)
        async def async_wrapper(*args, **kwargs):
            ensure_allowed(await limiter.is_allowed(resolve_key(args, kwargs), config))
            return await func(*args, **kwargs)

        @wraps(func)
        def sync_wrapper(*args, **kwargs):
            key = resolve_key(args, kwargs)
            # The sync path runs the async limiter check with asyncio.run.
            ensure_allowed(asyncio.run(limiter.is_allowed(key, config)))
            return func(*args, **kwargs)

        return async_wrapper if inspect.iscoroutinefunction(func) else sync_wrapper

    return decorator
class RateLimitExceeded(Exception):
    """Rate-limit error: raised when a call is rejected by the limiter."""