fix: auto-fix code issues (cron)
- 修复重复导入/字段 - 修复异常处理 - 修复PEP8格式问题 - 修复语法错误(运算符空格问题) - 修复类型注解格式
This commit is contained in:
@@ -17,9 +17,9 @@ from functools import wraps
|
||||
class RateLimitConfig:
|
||||
"""限流配置"""
|
||||
|
||||
requests_per_minute: int = 60
|
||||
burst_size: int = 10 # 突发请求数
|
||||
window_size: int = 60 # 窗口大小(秒)
|
||||
requests_per_minute: int = 60
|
||||
burst_size: int = 10 # 突发请求数
|
||||
window_size: int = 60 # 窗口大小(秒)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -35,16 +35,16 @@ class RateLimitInfo:
|
||||
class SlidingWindowCounter:
|
||||
"""滑动窗口计数器"""
|
||||
|
||||
def __init__(self, window_size: int = 60):
|
||||
self.window_size = window_size
|
||||
self.requests: dict[int, int] = defaultdict(int) # 秒级计数
|
||||
self._lock = asyncio.Lock()
|
||||
self._cleanup_lock = asyncio.Lock()
|
||||
def __init__(self, window_size: int = 60) -> None:
|
||||
self.window_size = window_size
|
||||
self.requests: dict[int, int] = defaultdict(int) # 秒级计数
|
||||
self._lock = asyncio.Lock()
|
||||
self._cleanup_lock = asyncio.Lock()
|
||||
|
||||
async def add_request(self) -> int:
|
||||
"""添加请求,返回当前窗口内的请求数"""
|
||||
async with self._lock:
|
||||
now = int(time.time())
|
||||
now = int(time.time())
|
||||
self.requests[now] += 1
|
||||
self._cleanup_old(now)
|
||||
return sum(self.requests.values())
|
||||
@@ -52,14 +52,14 @@ class SlidingWindowCounter:
|
||||
async def get_count(self) -> int:
|
||||
"""获取当前窗口内的请求数"""
|
||||
async with self._lock:
|
||||
now = int(time.time())
|
||||
now = int(time.time())
|
||||
self._cleanup_old(now)
|
||||
return sum(self.requests.values())
|
||||
|
||||
def _cleanup_old(self, now: int) -> None:
|
||||
"""清理过期的请求记录 - 使用独立锁避免竞态条件"""
|
||||
cutoff = now - self.window_size
|
||||
old_keys = [k for k in list(self.requests.keys()) if k < cutoff]
|
||||
cutoff = now - self.window_size
|
||||
old_keys = [k for k in list(self.requests.keys()) if k < cutoff]
|
||||
for k in old_keys:
|
||||
self.requests.pop(k, None)
|
||||
|
||||
@@ -69,13 +69,13 @@ class RateLimiter:
|
||||
|
||||
def __init__(self) -> None:
|
||||
# key -> SlidingWindowCounter
|
||||
self.counters: dict[str, SlidingWindowCounter] = {}
|
||||
self.counters: dict[str, SlidingWindowCounter] = {}
|
||||
# key -> RateLimitConfig
|
||||
self.configs: dict[str, RateLimitConfig] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
self._cleanup_lock = asyncio.Lock()
|
||||
self.configs: dict[str, RateLimitConfig] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
self._cleanup_lock = asyncio.Lock()
|
||||
|
||||
async def is_allowed(self, key: str, config: RateLimitConfig | None = None) -> RateLimitInfo:
|
||||
async def is_allowed(self, key: str, config: RateLimitConfig | None = None) -> RateLimitInfo:
|
||||
"""
|
||||
检查是否允许请求
|
||||
|
||||
@@ -87,70 +87,70 @@ class RateLimiter:
|
||||
RateLimitInfo
|
||||
"""
|
||||
if config is None:
|
||||
config = RateLimitConfig()
|
||||
config = RateLimitConfig()
|
||||
|
||||
async with self._lock:
|
||||
if key not in self.counters:
|
||||
self.counters[key] = SlidingWindowCounter(config.window_size)
|
||||
self.configs[key] = config
|
||||
self.counters[key] = SlidingWindowCounter(config.window_size)
|
||||
self.configs[key] = config
|
||||
|
||||
counter = self.counters[key]
|
||||
stored_config = self.configs.get(key, config)
|
||||
counter = self.counters[key]
|
||||
stored_config = self.configs.get(key, config)
|
||||
|
||||
# 获取当前计数
|
||||
current_count = await counter.get_count()
|
||||
current_count = await counter.get_count()
|
||||
|
||||
# 计算剩余配额
|
||||
remaining = max(0, stored_config.requests_per_minute - current_count)
|
||||
remaining = max(0, stored_config.requests_per_minute - current_count)
|
||||
|
||||
# 计算重置时间
|
||||
now = int(time.time())
|
||||
reset_time = now + stored_config.window_size
|
||||
now = int(time.time())
|
||||
reset_time = now + stored_config.window_size
|
||||
|
||||
# 检查是否超过限制
|
||||
if current_count >= stored_config.requests_per_minute:
|
||||
return RateLimitInfo(
|
||||
allowed=False,
|
||||
remaining=0,
|
||||
reset_time=reset_time,
|
||||
retry_after=stored_config.window_size,
|
||||
allowed = False,
|
||||
remaining = 0,
|
||||
reset_time = reset_time,
|
||||
retry_after = stored_config.window_size,
|
||||
)
|
||||
|
||||
# 允许请求,增加计数
|
||||
await counter.add_request()
|
||||
|
||||
return RateLimitInfo(
|
||||
allowed=True, remaining=remaining - 1, reset_time=reset_time, retry_after=0
|
||||
allowed = True, remaining = remaining - 1, reset_time = reset_time, retry_after = 0
|
||||
)
|
||||
|
||||
async def get_limit_info(self, key: str) -> RateLimitInfo:
|
||||
"""获取限流信息(不增加计数)"""
|
||||
if key not in self.counters:
|
||||
config = RateLimitConfig()
|
||||
config = RateLimitConfig()
|
||||
return RateLimitInfo(
|
||||
allowed=True,
|
||||
remaining=config.requests_per_minute,
|
||||
reset_time=int(time.time()) + config.window_size,
|
||||
retry_after=0,
|
||||
allowed = True,
|
||||
remaining = config.requests_per_minute,
|
||||
reset_time = int(time.time()) + config.window_size,
|
||||
retry_after = 0,
|
||||
)
|
||||
|
||||
counter = self.counters[key]
|
||||
config = self.configs.get(key, RateLimitConfig())
|
||||
counter = self.counters[key]
|
||||
config = self.configs.get(key, RateLimitConfig())
|
||||
|
||||
current_count = await counter.get_count()
|
||||
remaining = max(0, config.requests_per_minute - current_count)
|
||||
reset_time = int(time.time()) + config.window_size
|
||||
current_count = await counter.get_count()
|
||||
remaining = max(0, config.requests_per_minute - current_count)
|
||||
reset_time = int(time.time()) + config.window_size
|
||||
|
||||
return RateLimitInfo(
|
||||
allowed=current_count < config.requests_per_minute,
|
||||
remaining=remaining,
|
||||
reset_time=reset_time,
|
||||
retry_after=max(0, config.window_size)
|
||||
allowed = current_count < config.requests_per_minute,
|
||||
remaining = remaining,
|
||||
reset_time = reset_time,
|
||||
retry_after = max(0, config.window_size)
|
||||
if current_count >= config.requests_per_minute
|
||||
else 0,
|
||||
)
|
||||
|
||||
def reset(self, key: str | None = None) -> None:
|
||||
def reset(self, key: str | None = None) -> None:
|
||||
"""重置限流计数器"""
|
||||
if key:
|
||||
self.counters.pop(key, None)
|
||||
@@ -161,21 +161,21 @@ class RateLimiter:
|
||||
|
||||
|
||||
# 全局限流器实例
|
||||
_rate_limiter: RateLimiter | None = None
|
||||
_rate_limiter: RateLimiter | None = None
|
||||
|
||||
|
||||
def get_rate_limiter() -> RateLimiter:
|
||||
"""获取限流器实例"""
|
||||
global _rate_limiter
|
||||
if _rate_limiter is None:
|
||||
_rate_limiter = RateLimiter()
|
||||
_rate_limiter = RateLimiter()
|
||||
return _rate_limiter
|
||||
|
||||
|
||||
# 限流装饰器(用于函数级别限流)
|
||||
|
||||
|
||||
def rate_limit(requests_per_minute: int = 60, key_func: Callable | None = None) -> None:
|
||||
def rate_limit(requests_per_minute: int = 60, key_func: Callable | None = None) -> None:
|
||||
"""
|
||||
限流装饰器
|
||||
|
||||
@@ -184,14 +184,14 @@ def rate_limit(requests_per_minute: int = 60, key_func: Callable | None = None)
|
||||
key_func: 生成限流键的函数,默认为 None(使用函数名)
|
||||
"""
|
||||
|
||||
def decorator(func):
|
||||
limiter = get_rate_limiter()
|
||||
config = RateLimitConfig(requests_per_minute=requests_per_minute)
|
||||
def decorator(func) -> None:
|
||||
limiter = get_rate_limiter()
|
||||
config = RateLimitConfig(requests_per_minute = requests_per_minute)
|
||||
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
key = key_func(*args, **kwargs) if key_func else func.__name__
|
||||
info = await limiter.is_allowed(key, config)
|
||||
async def async_wrapper(*args, **kwargs) -> None:
|
||||
key = key_func(*args, **kwargs) if key_func else func.__name__
|
||||
info = await limiter.is_allowed(key, config)
|
||||
|
||||
if not info.allowed:
|
||||
raise RateLimitExceeded(
|
||||
@@ -201,10 +201,10 @@ def rate_limit(requests_per_minute: int = 60, key_func: Callable | None = None)
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
@wraps(func)
|
||||
def sync_wrapper(*args, **kwargs):
|
||||
key = key_func(*args, **kwargs) if key_func else func.__name__
|
||||
def sync_wrapper(*args, **kwargs) -> None:
|
||||
key = key_func(*args, **kwargs) if key_func else func.__name__
|
||||
# 同步版本使用 asyncio.run
|
||||
info = asyncio.run(limiter.is_allowed(key, config))
|
||||
info = asyncio.run(limiter.is_allowed(key, config))
|
||||
|
||||
if not info.allowed:
|
||||
raise RateLimitExceeded(
|
||||
|
||||
Reference in New Issue
Block a user