Phase 6: API Platform - Add authentication to existing endpoints and frontend API Key management UI
This commit is contained in:
223
backend/rate_limiter.py
Normal file
223
backend/rate_limiter.py
Normal file
@@ -0,0 +1,223 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
InsightFlow Rate Limiter - Phase 6
|
||||
API 限流中间件
|
||||
支持基于内存的滑动窗口限流
|
||||
"""
|
||||
|
||||
import time
|
||||
import asyncio
|
||||
from typing import Dict, Optional, Tuple, Callable
|
||||
from dataclasses import dataclass, field
|
||||
from collections import defaultdict
|
||||
from functools import wraps
|
||||
|
||||
|
||||
@dataclass
|
||||
class RateLimitConfig:
|
||||
"""限流配置"""
|
||||
requests_per_minute: int = 60
|
||||
burst_size: int = 10 # 突发请求数
|
||||
window_size: int = 60 # 窗口大小(秒)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RateLimitInfo:
|
||||
"""限流信息"""
|
||||
allowed: bool
|
||||
remaining: int
|
||||
reset_time: int # 重置时间戳
|
||||
retry_after: int # 需要等待的秒数
|
||||
|
||||
|
||||
class SlidingWindowCounter:
|
||||
"""滑动窗口计数器"""
|
||||
|
||||
def __init__(self, window_size: int = 60):
|
||||
self.window_size = window_size
|
||||
self.requests: Dict[int, int] = defaultdict(int) # 秒级计数
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def add_request(self) -> int:
|
||||
"""添加请求,返回当前窗口内的请求数"""
|
||||
async with self._lock:
|
||||
now = int(time.time())
|
||||
self.requests[now] += 1
|
||||
self._cleanup_old(now)
|
||||
return sum(self.requests.values())
|
||||
|
||||
async def get_count(self) -> int:
|
||||
"""获取当前窗口内的请求数"""
|
||||
async with self._lock:
|
||||
now = int(time.time())
|
||||
self._cleanup_old(now)
|
||||
return sum(self.requests.values())
|
||||
|
||||
def _cleanup_old(self, now: int):
|
||||
"""清理过期的请求记录"""
|
||||
cutoff = now - self.window_size
|
||||
old_keys = [k for k in self.requests.keys() if k < cutoff]
|
||||
for k in old_keys:
|
||||
del self.requests[k]
|
||||
|
||||
|
||||
class RateLimiter:
|
||||
"""API 限流器"""
|
||||
|
||||
def __init__(self):
|
||||
# key -> SlidingWindowCounter
|
||||
self.counters: Dict[str, SlidingWindowCounter] = {}
|
||||
# key -> RateLimitConfig
|
||||
self.configs: Dict[str, RateLimitConfig] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def is_allowed(
|
||||
self,
|
||||
key: str,
|
||||
config: Optional[RateLimitConfig] = None
|
||||
) -> RateLimitInfo:
|
||||
"""
|
||||
检查是否允许请求
|
||||
|
||||
Args:
|
||||
key: 限流键(如 API Key ID)
|
||||
config: 限流配置,如果为 None 则使用默认配置
|
||||
|
||||
Returns:
|
||||
RateLimitInfo
|
||||
"""
|
||||
if config is None:
|
||||
config = RateLimitConfig()
|
||||
|
||||
async with self._lock:
|
||||
if key not in self.counters:
|
||||
self.counters[key] = SlidingWindowCounter(config.window_size)
|
||||
self.configs[key] = config
|
||||
|
||||
counter = self.counters[key]
|
||||
stored_config = self.configs.get(key, config)
|
||||
|
||||
# 获取当前计数
|
||||
current_count = await counter.get_count()
|
||||
|
||||
# 计算剩余配额
|
||||
remaining = max(0, stored_config.requests_per_minute - current_count)
|
||||
|
||||
# 计算重置时间
|
||||
now = int(time.time())
|
||||
reset_time = now + stored_config.window_size
|
||||
|
||||
# 检查是否超过限制
|
||||
if current_count >= stored_config.requests_per_minute:
|
||||
return RateLimitInfo(
|
||||
allowed=False,
|
||||
remaining=0,
|
||||
reset_time=reset_time,
|
||||
retry_after=stored_config.window_size
|
||||
)
|
||||
|
||||
# 允许请求,增加计数
|
||||
await counter.add_request()
|
||||
|
||||
return RateLimitInfo(
|
||||
allowed=True,
|
||||
remaining=remaining - 1,
|
||||
reset_time=reset_time,
|
||||
retry_after=0
|
||||
)
|
||||
|
||||
async def get_limit_info(self, key: str) -> RateLimitInfo:
|
||||
"""获取限流信息(不增加计数)"""
|
||||
if key not in self.counters:
|
||||
config = RateLimitConfig()
|
||||
return RateLimitInfo(
|
||||
allowed=True,
|
||||
remaining=config.requests_per_minute,
|
||||
reset_time=int(time.time()) + config.window_size,
|
||||
retry_after=0
|
||||
)
|
||||
|
||||
counter = self.counters[key]
|
||||
config = self.configs.get(key, RateLimitConfig())
|
||||
|
||||
current_count = await counter.get_count()
|
||||
remaining = max(0, config.requests_per_minute - current_count)
|
||||
reset_time = int(time.time()) + config.window_size
|
||||
|
||||
return RateLimitInfo(
|
||||
allowed=current_count < config.requests_per_minute,
|
||||
remaining=remaining,
|
||||
reset_time=reset_time,
|
||||
retry_after=max(0, config.window_size) if current_count >= config.requests_per_minute else 0
|
||||
)
|
||||
|
||||
def reset(self, key: Optional[str] = None):
|
||||
"""重置限流计数器"""
|
||||
if key:
|
||||
self.counters.pop(key, None)
|
||||
self.configs.pop(key, None)
|
||||
else:
|
||||
self.counters.clear()
|
||||
self.configs.clear()
|
||||
|
||||
|
||||
# 全局限流器实例
|
||||
_rate_limiter: Optional[RateLimiter] = None
|
||||
|
||||
|
||||
def get_rate_limiter() -> RateLimiter:
|
||||
"""获取限流器实例"""
|
||||
global _rate_limiter
|
||||
if _rate_limiter is None:
|
||||
_rate_limiter = RateLimiter()
|
||||
return _rate_limiter
|
||||
|
||||
|
||||
# 限流装饰器(用于函数级别限流)
|
||||
def rate_limit(
|
||||
requests_per_minute: int = 60,
|
||||
key_func: Optional[Callable] = None
|
||||
):
|
||||
"""
|
||||
限流装饰器
|
||||
|
||||
Args:
|
||||
requests_per_minute: 每分钟请求数限制
|
||||
key_func: 生成限流键的函数,默认为 None(使用函数名)
|
||||
"""
|
||||
def decorator(func):
|
||||
limiter = get_rate_limiter()
|
||||
config = RateLimitConfig(requests_per_minute=requests_per_minute)
|
||||
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
key = key_func(*args, **kwargs) if key_func else func.__name__
|
||||
info = await limiter.is_allowed(key, config)
|
||||
|
||||
if not info.allowed:
|
||||
raise RateLimitExceeded(
|
||||
f"Rate limit exceeded. Try again in {info.retry_after} seconds."
|
||||
)
|
||||
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
@wraps(func)
|
||||
def sync_wrapper(*args, **kwargs):
|
||||
key = key_func(*args, **kwargs) if key_func else func.__name__
|
||||
# 同步版本使用 asyncio.run
|
||||
info = asyncio.run(limiter.is_allowed(key, config))
|
||||
|
||||
if not info.allowed:
|
||||
raise RateLimitExceeded(
|
||||
f"Rate limit exceeded. Try again in {info.retry_after} seconds."
|
||||
)
|
||||
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
|
||||
return decorator
|
||||
|
||||
|
||||
class RateLimitExceeded(Exception):
|
||||
"""限流异常"""
|
||||
pass
|
||||
Reference in New Issue
Block a user