Phase 6: API Platform - Add authentication to existing endpoints and frontend API Key management UI
This commit is contained in:
BIN
backend/__pycache__/api_key_manager.cpython-312.pyc
Normal file
BIN
backend/__pycache__/api_key_manager.cpython-312.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
backend/__pycache__/export_manager.cpython-312.pyc
Normal file
BIN
backend/__pycache__/export_manager.cpython-312.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
backend/__pycache__/neo4j_manager.cpython-312.pyc
Normal file
BIN
backend/__pycache__/neo4j_manager.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/__pycache__/rate_limiter.cpython-312.pyc
Normal file
BIN
backend/__pycache__/rate_limiter.cpython-312.pyc
Normal file
Binary file not shown.
529
backend/api_key_manager.py
Normal file
529
backend/api_key_manager.py
Normal file
@@ -0,0 +1,529 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
InsightFlow API Key Manager - Phase 6
|
||||
API Key 管理模块:生成、验证、撤销
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import hashlib
|
||||
import secrets
|
||||
import sqlite3
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, List, Dict
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
DB_PATH = os.getenv("DB_PATH", "/app/data/insightflow.db")
|
||||
|
||||
|
||||
class ApiKeyStatus(Enum):
|
||||
ACTIVE = "active"
|
||||
REVOKED = "revoked"
|
||||
EXPIRED = "expired"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ApiKey:
|
||||
id: str
|
||||
key_hash: str # 存储哈希值,不存储原始 key
|
||||
key_preview: str # 前8位预览,如 "ak_live_abc..."
|
||||
name: str # 密钥名称/描述
|
||||
owner_id: Optional[str] # 所有者ID(预留多用户支持)
|
||||
permissions: List[str] # 权限列表,如 ["read", "write"]
|
||||
rate_limit: int # 每分钟请求限制
|
||||
status: str # active, revoked, expired
|
||||
created_at: str
|
||||
expires_at: Optional[str]
|
||||
last_used_at: Optional[str]
|
||||
revoked_at: Optional[str]
|
||||
revoked_reason: Optional[str]
|
||||
total_calls: int = 0
|
||||
|
||||
|
||||
class ApiKeyManager:
|
||||
"""API Key 管理器"""
|
||||
|
||||
# Key 前缀
|
||||
KEY_PREFIX = "ak_live_"
|
||||
KEY_LENGTH = 48 # 总长度: 前缀(8) + 随机部分(40)
|
||||
|
||||
def __init__(self, db_path: str = DB_PATH):
|
||||
self.db_path = db_path
|
||||
self._init_db()
|
||||
|
||||
def _init_db(self):
|
||||
"""初始化数据库表"""
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
conn.executescript("""
|
||||
-- API Keys 表
|
||||
CREATE TABLE IF NOT EXISTS api_keys (
|
||||
id TEXT PRIMARY KEY,
|
||||
key_hash TEXT UNIQUE NOT NULL,
|
||||
key_preview TEXT NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
owner_id TEXT,
|
||||
permissions TEXT NOT NULL DEFAULT '["read"]',
|
||||
rate_limit INTEGER DEFAULT 60,
|
||||
status TEXT DEFAULT 'active',
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
expires_at TIMESTAMP,
|
||||
last_used_at TIMESTAMP,
|
||||
revoked_at TIMESTAMP,
|
||||
revoked_reason TEXT,
|
||||
total_calls INTEGER DEFAULT 0
|
||||
);
|
||||
|
||||
-- API 调用日志表
|
||||
CREATE TABLE IF NOT EXISTS api_call_logs (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
api_key_id TEXT NOT NULL,
|
||||
endpoint TEXT NOT NULL,
|
||||
method TEXT NOT NULL,
|
||||
status_code INTEGER,
|
||||
response_time_ms INTEGER,
|
||||
ip_address TEXT,
|
||||
user_agent TEXT,
|
||||
error_message TEXT,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
FOREIGN KEY (api_key_id) REFERENCES api_keys(id)
|
||||
);
|
||||
|
||||
-- API 调用统计表(按天汇总)
|
||||
CREATE TABLE IF NOT EXISTS api_call_stats (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
api_key_id TEXT NOT NULL,
|
||||
date TEXT NOT NULL,
|
||||
endpoint TEXT NOT NULL,
|
||||
method TEXT NOT NULL,
|
||||
total_calls INTEGER DEFAULT 0,
|
||||
success_calls INTEGER DEFAULT 0,
|
||||
error_calls INTEGER DEFAULT 0,
|
||||
avg_response_time_ms INTEGER DEFAULT 0,
|
||||
FOREIGN KEY (api_key_id) REFERENCES api_keys(id),
|
||||
UNIQUE(api_key_id, date, endpoint, method)
|
||||
);
|
||||
|
||||
-- 创建索引
|
||||
CREATE INDEX IF NOT EXISTS idx_api_keys_hash ON api_keys(key_hash);
|
||||
CREATE INDEX IF NOT EXISTS idx_api_keys_status ON api_keys(status);
|
||||
CREATE INDEX IF NOT EXISTS idx_api_keys_owner ON api_keys(owner_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_api_logs_key_id ON api_call_logs(api_key_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_api_logs_created ON api_call_logs(created_at);
|
||||
CREATE INDEX IF NOT EXISTS idx_api_stats_key_date ON api_call_stats(api_key_id, date);
|
||||
""")
|
||||
conn.commit()
|
||||
|
||||
def _generate_key(self) -> str:
|
||||
"""生成新的 API Key"""
|
||||
# 生成 40 字符的随机字符串
|
||||
random_part = secrets.token_urlsafe(30)[:40]
|
||||
return f"{self.KEY_PREFIX}{random_part}"
|
||||
|
||||
def _hash_key(self, key: str) -> str:
|
||||
"""对 API Key 进行哈希"""
|
||||
return hashlib.sha256(key.encode()).hexdigest()
|
||||
|
||||
def _get_preview(self, key: str) -> str:
|
||||
"""获取 Key 的预览(前16位)"""
|
||||
return f"{key[:16]}..."
|
||||
|
||||
def create_key(
|
||||
self,
|
||||
name: str,
|
||||
owner_id: Optional[str] = None,
|
||||
permissions: List[str] = None,
|
||||
rate_limit: int = 60,
|
||||
expires_days: Optional[int] = None
|
||||
) -> tuple[str, ApiKey]:
|
||||
"""
|
||||
创建新的 API Key
|
||||
|
||||
Returns:
|
||||
tuple: (原始key(仅返回一次), ApiKey对象)
|
||||
"""
|
||||
if permissions is None:
|
||||
permissions = ["read"]
|
||||
|
||||
key_id = secrets.token_hex(16)
|
||||
raw_key = self._generate_key()
|
||||
key_hash = self._hash_key(raw_key)
|
||||
key_preview = self._get_preview(raw_key)
|
||||
|
||||
expires_at = None
|
||||
if expires_days:
|
||||
expires_at = (datetime.now() + timedelta(days=expires_days)).isoformat()
|
||||
|
||||
api_key = ApiKey(
|
||||
id=key_id,
|
||||
key_hash=key_hash,
|
||||
key_preview=key_preview,
|
||||
name=name,
|
||||
owner_id=owner_id,
|
||||
permissions=permissions,
|
||||
rate_limit=rate_limit,
|
||||
status=ApiKeyStatus.ACTIVE.value,
|
||||
created_at=datetime.now().isoformat(),
|
||||
expires_at=expires_at,
|
||||
last_used_at=None,
|
||||
revoked_at=None,
|
||||
revoked_reason=None,
|
||||
total_calls=0
|
||||
)
|
||||
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
conn.execute("""
|
||||
INSERT INTO api_keys (
|
||||
id, key_hash, key_preview, name, owner_id, permissions,
|
||||
rate_limit, status, created_at, expires_at
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""", (
|
||||
api_key.id, api_key.key_hash, api_key.key_preview,
|
||||
api_key.name, api_key.owner_id, json.dumps(api_key.permissions),
|
||||
api_key.rate_limit, api_key.status, api_key.created_at,
|
||||
api_key.expires_at
|
||||
))
|
||||
conn.commit()
|
||||
|
||||
return raw_key, api_key
|
||||
|
||||
def validate_key(self, key: str) -> Optional[ApiKey]:
|
||||
"""
|
||||
验证 API Key
|
||||
|
||||
Returns:
|
||||
ApiKey if valid, None otherwise
|
||||
"""
|
||||
key_hash = self._hash_key(key)
|
||||
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
conn.row_factory = sqlite3.Row
|
||||
row = conn.execute(
|
||||
"SELECT * FROM api_keys WHERE key_hash = ?",
|
||||
(key_hash,)
|
||||
).fetchone()
|
||||
|
||||
if not row:
|
||||
return None
|
||||
|
||||
api_key = self._row_to_api_key(row)
|
||||
|
||||
# 检查状态
|
||||
if api_key.status != ApiKeyStatus.ACTIVE.value:
|
||||
return None
|
||||
|
||||
# 检查是否过期
|
||||
if api_key.expires_at:
|
||||
expires = datetime.fromisoformat(api_key.expires_at)
|
||||
if datetime.now() > expires:
|
||||
# 更新状态为过期
|
||||
conn.execute(
|
||||
"UPDATE api_keys SET status = ? WHERE id = ?",
|
||||
(ApiKeyStatus.EXPIRED.value, api_key.id)
|
||||
)
|
||||
conn.commit()
|
||||
return None
|
||||
|
||||
return api_key
|
||||
|
||||
def revoke_key(
|
||||
self,
|
||||
key_id: str,
|
||||
reason: str = "",
|
||||
owner_id: Optional[str] = None
|
||||
) -> bool:
|
||||
"""撤销 API Key"""
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
# 验证所有权(如果提供了 owner_id)
|
||||
if owner_id:
|
||||
row = conn.execute(
|
||||
"SELECT owner_id FROM api_keys WHERE id = ?",
|
||||
(key_id,)
|
||||
).fetchone()
|
||||
if not row or row[0] != owner_id:
|
||||
return False
|
||||
|
||||
cursor = conn.execute("""
|
||||
UPDATE api_keys
|
||||
SET status = ?, revoked_at = ?, revoked_reason = ?
|
||||
WHERE id = ? AND status = ?
|
||||
""", (
|
||||
ApiKeyStatus.REVOKED.value,
|
||||
datetime.now().isoformat(),
|
||||
reason,
|
||||
key_id,
|
||||
ApiKeyStatus.ACTIVE.value
|
||||
))
|
||||
conn.commit()
|
||||
return cursor.rowcount > 0
|
||||
|
||||
def get_key_by_id(self, key_id: str, owner_id: Optional[str] = None) -> Optional[ApiKey]:
|
||||
"""通过 ID 获取 API Key(不包含敏感信息)"""
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
conn.row_factory = sqlite3.Row
|
||||
|
||||
if owner_id:
|
||||
row = conn.execute(
|
||||
"SELECT * FROM api_keys WHERE id = ? AND owner_id = ?",
|
||||
(key_id, owner_id)
|
||||
).fetchone()
|
||||
else:
|
||||
row = conn.execute(
|
||||
"SELECT * FROM api_keys WHERE id = ?",
|
||||
(key_id,)
|
||||
).fetchone()
|
||||
|
||||
if row:
|
||||
return self._row_to_api_key(row)
|
||||
return None
|
||||
|
||||
def list_keys(
|
||||
self,
|
||||
owner_id: Optional[str] = None,
|
||||
status: Optional[str] = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0
|
||||
) -> List[ApiKey]:
|
||||
"""列出 API Keys"""
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
conn.row_factory = sqlite3.Row
|
||||
|
||||
query = "SELECT * FROM api_keys WHERE 1=1"
|
||||
params = []
|
||||
|
||||
if owner_id:
|
||||
query += " AND owner_id = ?"
|
||||
params.append(owner_id)
|
||||
|
||||
if status:
|
||||
query += " AND status = ?"
|
||||
params.append(status)
|
||||
|
||||
query += " ORDER BY created_at DESC LIMIT ? OFFSET ?"
|
||||
params.extend([limit, offset])
|
||||
|
||||
rows = conn.execute(query, params).fetchall()
|
||||
return [self._row_to_api_key(row) for row in rows]
|
||||
|
||||
def update_key(
|
||||
self,
|
||||
key_id: str,
|
||||
name: Optional[str] = None,
|
||||
permissions: Optional[List[str]] = None,
|
||||
rate_limit: Optional[int] = None,
|
||||
owner_id: Optional[str] = None
|
||||
) -> bool:
|
||||
"""更新 API Key 信息"""
|
||||
updates = []
|
||||
params = []
|
||||
|
||||
if name is not None:
|
||||
updates.append("name = ?")
|
||||
params.append(name)
|
||||
|
||||
if permissions is not None:
|
||||
updates.append("permissions = ?")
|
||||
params.append(json.dumps(permissions))
|
||||
|
||||
if rate_limit is not None:
|
||||
updates.append("rate_limit = ?")
|
||||
params.append(rate_limit)
|
||||
|
||||
if not updates:
|
||||
return False
|
||||
|
||||
params.append(key_id)
|
||||
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
# 验证所有权
|
||||
if owner_id:
|
||||
row = conn.execute(
|
||||
"SELECT owner_id FROM api_keys WHERE id = ?",
|
||||
(key_id,)
|
||||
).fetchone()
|
||||
if not row or row[0] != owner_id:
|
||||
return False
|
||||
|
||||
query = f"UPDATE api_keys SET {', '.join(updates)} WHERE id = ?"
|
||||
cursor = conn.execute(query, params)
|
||||
conn.commit()
|
||||
return cursor.rowcount > 0
|
||||
|
||||
def update_last_used(self, key_id: str):
|
||||
"""更新最后使用时间"""
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
conn.execute("""
|
||||
UPDATE api_keys
|
||||
SET last_used_at = ?, total_calls = total_calls + 1
|
||||
WHERE id = ?
|
||||
""", (datetime.now().isoformat(), key_id))
|
||||
conn.commit()
|
||||
|
||||
def log_api_call(
|
||||
self,
|
||||
api_key_id: str,
|
||||
endpoint: str,
|
||||
method: str,
|
||||
status_code: int = 200,
|
||||
response_time_ms: int = 0,
|
||||
ip_address: str = "",
|
||||
user_agent: str = "",
|
||||
error_message: str = ""
|
||||
):
|
||||
"""记录 API 调用日志"""
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
conn.execute("""
|
||||
INSERT INTO api_call_logs
|
||||
(api_key_id, endpoint, method, status_code, response_time_ms,
|
||||
ip_address, user_agent, error_message)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""", (
|
||||
api_key_id, endpoint, method, status_code, response_time_ms,
|
||||
ip_address, user_agent, error_message
|
||||
))
|
||||
conn.commit()
|
||||
|
||||
def get_call_logs(
|
||||
self,
|
||||
api_key_id: Optional[str] = None,
|
||||
start_date: Optional[str] = None,
|
||||
end_date: Optional[str] = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0
|
||||
) -> List[Dict]:
|
||||
"""获取 API 调用日志"""
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
conn.row_factory = sqlite3.Row
|
||||
|
||||
query = "SELECT * FROM api_call_logs WHERE 1=1"
|
||||
params = []
|
||||
|
||||
if api_key_id:
|
||||
query += " AND api_key_id = ?"
|
||||
params.append(api_key_id)
|
||||
|
||||
if start_date:
|
||||
query += " AND created_at >= ?"
|
||||
params.append(start_date)
|
||||
|
||||
if end_date:
|
||||
query += " AND created_at <= ?"
|
||||
params.append(end_date)
|
||||
|
||||
query += " ORDER BY created_at DESC LIMIT ? OFFSET ?"
|
||||
params.extend([limit, offset])
|
||||
|
||||
rows = conn.execute(query, params).fetchall()
|
||||
return [dict(row) for row in rows]
|
||||
|
||||
def get_call_stats(
|
||||
self,
|
||||
api_key_id: Optional[str] = None,
|
||||
days: int = 30
|
||||
) -> Dict:
|
||||
"""获取 API 调用统计"""
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
conn.row_factory = sqlite3.Row
|
||||
|
||||
# 总体统计
|
||||
query = """
|
||||
SELECT
|
||||
COUNT(*) as total_calls,
|
||||
COUNT(CASE WHEN status_code < 400 THEN 1 END) as success_calls,
|
||||
COUNT(CASE WHEN status_code >= 400 THEN 1 END) as error_calls,
|
||||
AVG(response_time_ms) as avg_response_time,
|
||||
MAX(response_time_ms) as max_response_time,
|
||||
MIN(response_time_ms) as min_response_time
|
||||
FROM api_call_logs
|
||||
WHERE created_at >= date('now', '-{} days')
|
||||
""".format(days)
|
||||
|
||||
params = []
|
||||
if api_key_id:
|
||||
query = query.replace("WHERE created_at", "WHERE api_key_id = ? AND created_at")
|
||||
params.insert(0, api_key_id)
|
||||
|
||||
row = conn.execute(query, params).fetchone()
|
||||
|
||||
# 按端点统计
|
||||
endpoint_query = """
|
||||
SELECT
|
||||
endpoint,
|
||||
method,
|
||||
COUNT(*) as calls,
|
||||
AVG(response_time_ms) as avg_time
|
||||
FROM api_call_logs
|
||||
WHERE created_at >= date('now', '-{} days')
|
||||
""".format(days)
|
||||
|
||||
endpoint_params = []
|
||||
if api_key_id:
|
||||
endpoint_query = endpoint_query.replace("WHERE created_at", "WHERE api_key_id = ? AND created_at")
|
||||
endpoint_params.insert(0, api_key_id)
|
||||
|
||||
endpoint_query += " GROUP BY endpoint, method ORDER BY calls DESC"
|
||||
|
||||
endpoint_rows = conn.execute(endpoint_query, endpoint_params).fetchall()
|
||||
|
||||
# 按天统计
|
||||
daily_query = """
|
||||
SELECT
|
||||
date(created_at) as date,
|
||||
COUNT(*) as calls,
|
||||
COUNT(CASE WHEN status_code < 400 THEN 1 END) as success
|
||||
FROM api_call_logs
|
||||
WHERE created_at >= date('now', '-{} days')
|
||||
""".format(days)
|
||||
|
||||
daily_params = []
|
||||
if api_key_id:
|
||||
daily_query = daily_query.replace("WHERE created_at", "WHERE api_key_id = ? AND created_at")
|
||||
daily_params.insert(0, api_key_id)
|
||||
|
||||
daily_query += " GROUP BY date(created_at) ORDER BY date"
|
||||
|
||||
daily_rows = conn.execute(daily_query, daily_params).fetchall()
|
||||
|
||||
return {
|
||||
"summary": {
|
||||
"total_calls": row["total_calls"] or 0,
|
||||
"success_calls": row["success_calls"] or 0,
|
||||
"error_calls": row["error_calls"] or 0,
|
||||
"avg_response_time_ms": round(row["avg_response_time"] or 0, 2),
|
||||
"max_response_time_ms": row["max_response_time"] or 0,
|
||||
"min_response_time_ms": row["min_response_time"] or 0,
|
||||
},
|
||||
"endpoints": [dict(r) for r in endpoint_rows],
|
||||
"daily": [dict(r) for r in daily_rows]
|
||||
}
|
||||
|
||||
def _row_to_api_key(self, row: sqlite3.Row) -> ApiKey:
|
||||
"""将数据库行转换为 ApiKey 对象"""
|
||||
return ApiKey(
|
||||
id=row["id"],
|
||||
key_hash=row["key_hash"],
|
||||
key_preview=row["key_preview"],
|
||||
name=row["name"],
|
||||
owner_id=row["owner_id"],
|
||||
permissions=json.loads(row["permissions"]),
|
||||
rate_limit=row["rate_limit"],
|
||||
status=row["status"],
|
||||
created_at=row["created_at"],
|
||||
expires_at=row["expires_at"],
|
||||
last_used_at=row["last_used_at"],
|
||||
revoked_at=row["revoked_at"],
|
||||
revoked_reason=row["revoked_reason"],
|
||||
total_calls=row["total_calls"]
|
||||
)
|
||||
|
||||
|
||||
# 全局实例
|
||||
_api_key_manager: Optional[ApiKeyManager] = None
|
||||
|
||||
|
||||
def get_api_key_manager() -> ApiKeyManager:
|
||||
"""获取 API Key 管理器实例"""
|
||||
global _api_key_manager
|
||||
if _api_key_manager is None:
|
||||
_api_key_manager = ApiKeyManager()
|
||||
return _api_key_manager
|
||||
804
backend/main.py
804
backend/main.py
File diff suppressed because it is too large
Load Diff
223
backend/rate_limiter.py
Normal file
223
backend/rate_limiter.py
Normal file
@@ -0,0 +1,223 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
InsightFlow Rate Limiter - Phase 6
|
||||
API 限流中间件
|
||||
支持基于内存的滑动窗口限流
|
||||
"""
|
||||
|
||||
import time
|
||||
import asyncio
|
||||
from typing import Dict, Optional, Tuple, Callable
|
||||
from dataclasses import dataclass, field
|
||||
from collections import defaultdict
|
||||
from functools import wraps
|
||||
|
||||
|
||||
@dataclass
|
||||
class RateLimitConfig:
|
||||
"""限流配置"""
|
||||
requests_per_minute: int = 60
|
||||
burst_size: int = 10 # 突发请求数
|
||||
window_size: int = 60 # 窗口大小(秒)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RateLimitInfo:
|
||||
"""限流信息"""
|
||||
allowed: bool
|
||||
remaining: int
|
||||
reset_time: int # 重置时间戳
|
||||
retry_after: int # 需要等待的秒数
|
||||
|
||||
|
||||
class SlidingWindowCounter:
|
||||
"""滑动窗口计数器"""
|
||||
|
||||
def __init__(self, window_size: int = 60):
|
||||
self.window_size = window_size
|
||||
self.requests: Dict[int, int] = defaultdict(int) # 秒级计数
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def add_request(self) -> int:
|
||||
"""添加请求,返回当前窗口内的请求数"""
|
||||
async with self._lock:
|
||||
now = int(time.time())
|
||||
self.requests[now] += 1
|
||||
self._cleanup_old(now)
|
||||
return sum(self.requests.values())
|
||||
|
||||
async def get_count(self) -> int:
|
||||
"""获取当前窗口内的请求数"""
|
||||
async with self._lock:
|
||||
now = int(time.time())
|
||||
self._cleanup_old(now)
|
||||
return sum(self.requests.values())
|
||||
|
||||
def _cleanup_old(self, now: int):
|
||||
"""清理过期的请求记录"""
|
||||
cutoff = now - self.window_size
|
||||
old_keys = [k for k in self.requests.keys() if k < cutoff]
|
||||
for k in old_keys:
|
||||
del self.requests[k]
|
||||
|
||||
|
||||
class RateLimiter:
|
||||
"""API 限流器"""
|
||||
|
||||
def __init__(self):
|
||||
# key -> SlidingWindowCounter
|
||||
self.counters: Dict[str, SlidingWindowCounter] = {}
|
||||
# key -> RateLimitConfig
|
||||
self.configs: Dict[str, RateLimitConfig] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def is_allowed(
|
||||
self,
|
||||
key: str,
|
||||
config: Optional[RateLimitConfig] = None
|
||||
) -> RateLimitInfo:
|
||||
"""
|
||||
检查是否允许请求
|
||||
|
||||
Args:
|
||||
key: 限流键(如 API Key ID)
|
||||
config: 限流配置,如果为 None 则使用默认配置
|
||||
|
||||
Returns:
|
||||
RateLimitInfo
|
||||
"""
|
||||
if config is None:
|
||||
config = RateLimitConfig()
|
||||
|
||||
async with self._lock:
|
||||
if key not in self.counters:
|
||||
self.counters[key] = SlidingWindowCounter(config.window_size)
|
||||
self.configs[key] = config
|
||||
|
||||
counter = self.counters[key]
|
||||
stored_config = self.configs.get(key, config)
|
||||
|
||||
# 获取当前计数
|
||||
current_count = await counter.get_count()
|
||||
|
||||
# 计算剩余配额
|
||||
remaining = max(0, stored_config.requests_per_minute - current_count)
|
||||
|
||||
# 计算重置时间
|
||||
now = int(time.time())
|
||||
reset_time = now + stored_config.window_size
|
||||
|
||||
# 检查是否超过限制
|
||||
if current_count >= stored_config.requests_per_minute:
|
||||
return RateLimitInfo(
|
||||
allowed=False,
|
||||
remaining=0,
|
||||
reset_time=reset_time,
|
||||
retry_after=stored_config.window_size
|
||||
)
|
||||
|
||||
# 允许请求,增加计数
|
||||
await counter.add_request()
|
||||
|
||||
return RateLimitInfo(
|
||||
allowed=True,
|
||||
remaining=remaining - 1,
|
||||
reset_time=reset_time,
|
||||
retry_after=0
|
||||
)
|
||||
|
||||
async def get_limit_info(self, key: str) -> RateLimitInfo:
|
||||
"""获取限流信息(不增加计数)"""
|
||||
if key not in self.counters:
|
||||
config = RateLimitConfig()
|
||||
return RateLimitInfo(
|
||||
allowed=True,
|
||||
remaining=config.requests_per_minute,
|
||||
reset_time=int(time.time()) + config.window_size,
|
||||
retry_after=0
|
||||
)
|
||||
|
||||
counter = self.counters[key]
|
||||
config = self.configs.get(key, RateLimitConfig())
|
||||
|
||||
current_count = await counter.get_count()
|
||||
remaining = max(0, config.requests_per_minute - current_count)
|
||||
reset_time = int(time.time()) + config.window_size
|
||||
|
||||
return RateLimitInfo(
|
||||
allowed=current_count < config.requests_per_minute,
|
||||
remaining=remaining,
|
||||
reset_time=reset_time,
|
||||
retry_after=max(0, config.window_size) if current_count >= config.requests_per_minute else 0
|
||||
)
|
||||
|
||||
def reset(self, key: Optional[str] = None):
|
||||
"""重置限流计数器"""
|
||||
if key:
|
||||
self.counters.pop(key, None)
|
||||
self.configs.pop(key, None)
|
||||
else:
|
||||
self.counters.clear()
|
||||
self.configs.clear()
|
||||
|
||||
|
||||
# 全局限流器实例
|
||||
_rate_limiter: Optional[RateLimiter] = None
|
||||
|
||||
|
||||
def get_rate_limiter() -> RateLimiter:
|
||||
"""获取限流器实例"""
|
||||
global _rate_limiter
|
||||
if _rate_limiter is None:
|
||||
_rate_limiter = RateLimiter()
|
||||
return _rate_limiter
|
||||
|
||||
|
||||
# 限流装饰器(用于函数级别限流)
|
||||
def rate_limit(
|
||||
requests_per_minute: int = 60,
|
||||
key_func: Optional[Callable] = None
|
||||
):
|
||||
"""
|
||||
限流装饰器
|
||||
|
||||
Args:
|
||||
requests_per_minute: 每分钟请求数限制
|
||||
key_func: 生成限流键的函数,默认为 None(使用函数名)
|
||||
"""
|
||||
def decorator(func):
|
||||
limiter = get_rate_limiter()
|
||||
config = RateLimitConfig(requests_per_minute=requests_per_minute)
|
||||
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
key = key_func(*args, **kwargs) if key_func else func.__name__
|
||||
info = await limiter.is_allowed(key, config)
|
||||
|
||||
if not info.allowed:
|
||||
raise RateLimitExceeded(
|
||||
f"Rate limit exceeded. Try again in {info.retry_after} seconds."
|
||||
)
|
||||
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
@wraps(func)
|
||||
def sync_wrapper(*args, **kwargs):
|
||||
key = key_func(*args, **kwargs) if key_func else func.__name__
|
||||
# 同步版本使用 asyncio.run
|
||||
info = asyncio.run(limiter.is_allowed(key, config))
|
||||
|
||||
if not info.allowed:
|
||||
raise RateLimitExceeded(
|
||||
f"Rate limit exceeded. Try again in {info.retry_after} seconds."
|
||||
)
|
||||
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
|
||||
return decorator
|
||||
|
||||
|
||||
class RateLimitExceeded(Exception):
|
||||
"""限流异常"""
|
||||
pass
|
||||
@@ -30,3 +30,6 @@ cairosvg==2.7.1
|
||||
|
||||
# Neo4j Graph Database
|
||||
neo4j==5.15.0
|
||||
|
||||
# API Documentation (Swagger/OpenAPI)
|
||||
fastapi-offline-swagger==0.1.0
|
||||
|
||||
Reference in New Issue
Block a user