Files
insightflow/backend/api_key_manager.py
AutoFix Bot e46c938b40 fix: auto-fix code issues (cron)
- 修复重复导入/字段
- 修复异常处理
- 修复PEP8格式问题 (E302, E305, E501)
- 修复行长度超过100字符的问题
- 修复F821未定义名称错误
2026-03-01 18:19:06 +08:00

540 lines
18 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
InsightFlow API Key Manager - Phase 6
API Key 管理模块:生成、验证、撤销
"""
import hashlib
import json
import os
import secrets
import sqlite3
from collections.abc import Iterator
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import datetime, timedelta
from enum import Enum
# Path of the SQLite database file; the DB_PATH environment variable overrides the default.
DB_PATH = os.environ.get("DB_PATH", "/app/data/insightflow.db")
class ApiKeyStatus(Enum):
    """Lifecycle states of an API key; the values are the strings stored in the database."""

    ACTIVE = "active"
    REVOKED = "revoked"
    EXPIRED = "expired"
@dataclass
class ApiKey:
id: str
key_hash: str # 存储哈希值,不存储原始 key
key_preview: str # 前8位预览如 "ak_live_abc..."
name: str # 密钥名称/描述
owner_id: str | None # 所有者ID预留多用户支持
permissions: list[str] # 权限列表,如 ["read", "write"]
rate_limit: int # 每分钟请求限制
status: str # active, revoked, expired
created_at: str
expires_at: str | None
last_used_at: str | None
revoked_at: str | None
revoked_reason: str | None
total_calls: int = 0
class ApiKeyManager:
    """Create, validate, revoke and audit API keys stored in a SQLite database.

    Only the SHA-256 digest of a key is persisted; the raw key is handed out
    exactly once, as the first element of the tuple returned by ``create_key``.
    """

    # Prefix shared by every generated key.
    KEY_PREFIX = "ak_live_"
    # Total key length: prefix (8) + random part (40).
    KEY_LENGTH = 48

    def __init__(self, db_path: str = DB_PATH) -> None:
        """Remember the database path and make sure the schema exists.

        Args:
            db_path: Path of the SQLite database file.
        """
        self.db_path = db_path
        self._init_db()

    @contextmanager
    def _connect(self, *, rows: bool = False) -> Iterator[sqlite3.Connection]:
        """Yield a connection that is committed (or rolled back) and then closed.

        ``with sqlite3.connect(...)`` on its own only manages the transaction
        and leaves the connection open, so the close is done here explicitly.

        Args:
            rows: When true, rows are returned as ``sqlite3.Row`` (name access).
        """
        conn = sqlite3.connect(self.db_path)
        try:
            if rows:
                conn.row_factory = sqlite3.Row
            # Commits on normal exit, rolls back if the body raises.
            with conn:
                yield conn
        finally:
            conn.close()

    def _init_db(self) -> None:
        """Create the tables and indexes if they do not exist yet."""
        with self._connect() as conn:
            conn.executescript("""
                -- API Keys 表
                CREATE TABLE IF NOT EXISTS api_keys (
                    id TEXT PRIMARY KEY,
                    key_hash TEXT UNIQUE NOT NULL,
                    key_preview TEXT NOT NULL,
                    name TEXT NOT NULL,
                    owner_id TEXT,
                    permissions TEXT NOT NULL DEFAULT '["read"]',
                    rate_limit INTEGER DEFAULT 60,
                    status TEXT DEFAULT 'active',
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    expires_at TIMESTAMP,
                    last_used_at TIMESTAMP,
                    revoked_at TIMESTAMP,
                    revoked_reason TEXT,
                    total_calls INTEGER DEFAULT 0
                );
                -- API 调用日志表
                CREATE TABLE IF NOT EXISTS api_call_logs (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    api_key_id TEXT NOT NULL,
                    endpoint TEXT NOT NULL,
                    method TEXT NOT NULL,
                    status_code INTEGER,
                    response_time_ms INTEGER,
                    ip_address TEXT,
                    user_agent TEXT,
                    error_message TEXT,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    FOREIGN KEY (api_key_id) REFERENCES api_keys(id)
                );
                -- API 调用统计表(按天汇总)
                CREATE TABLE IF NOT EXISTS api_call_stats (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    api_key_id TEXT NOT NULL,
                    date TEXT NOT NULL,
                    endpoint TEXT NOT NULL,
                    method TEXT NOT NULL,
                    total_calls INTEGER DEFAULT 0,
                    success_calls INTEGER DEFAULT 0,
                    error_calls INTEGER DEFAULT 0,
                    avg_response_time_ms INTEGER DEFAULT 0,
                    FOREIGN KEY (api_key_id) REFERENCES api_keys(id),
                    UNIQUE(api_key_id, date, endpoint, method)
                );
                -- 创建索引
                CREATE INDEX IF NOT EXISTS idx_api_keys_hash ON api_keys(key_hash);
                CREATE INDEX IF NOT EXISTS idx_api_keys_status ON api_keys(status);
                CREATE INDEX IF NOT EXISTS idx_api_keys_owner ON api_keys(owner_id);
                CREATE INDEX IF NOT EXISTS idx_api_logs_key_id ON api_call_logs(api_key_id);
                CREATE INDEX IF NOT EXISTS idx_api_logs_created ON api_call_logs(created_at);
                CREATE INDEX IF NOT EXISTS idx_api_stats_key_date
                    ON api_call_stats(api_key_id, date);
            """)

    def _generate_key(self) -> str:
        """Return a new raw key: ``KEY_PREFIX`` plus URL-safe random characters."""
        random_len = self.KEY_LENGTH - len(self.KEY_PREFIX)
        # token_urlsafe(n) yields more than n characters, so the slice is always full.
        random_part = secrets.token_urlsafe(random_len)[:random_len]
        return f"{self.KEY_PREFIX}{random_part}"

    def _hash_key(self, key: str) -> str:
        """Return the SHA-256 hex digest of ``key``."""
        return hashlib.sha256(key.encode()).hexdigest()

    def _get_preview(self, key: str) -> str:
        """Return the first 16 characters of ``key`` followed by ``...``."""
        return f"{key[:16]}..."

    def create_key(
        self,
        name: str,
        owner_id: str | None = None,
        permissions: list[str] | None = None,
        rate_limit: int = 60,
        expires_days: int | None = None,
    ) -> tuple[str, ApiKey]:
        """Create and store a new API key.

        Args:
            name: Key name / description.
            owner_id: Optional owner of the key.
            permissions: Permission list; defaults to ``["read"]``.
            rate_limit: Request limit per minute.
            expires_days: Lifetime in days. ``None`` or ``0`` means the key
                never expires.

        Returns:
            ``(raw_key, api_key)``. The raw key is not stored and cannot be
            retrieved again.
        """
        # Copy so that later mutation of the caller's list does not leak in.
        permissions = ["read"] if permissions is None else list(permissions)

        raw_key = self._generate_key()
        # One timestamp for both fields keeps the lifetime exactly expires_days.
        now = datetime.now()
        expires_at = None
        if expires_days:
            expires_at = (now + timedelta(days=expires_days)).isoformat()

        api_key = ApiKey(
            id=secrets.token_hex(16),
            key_hash=self._hash_key(raw_key),
            key_preview=self._get_preview(raw_key),
            name=name,
            owner_id=owner_id,
            permissions=permissions,
            rate_limit=rate_limit,
            status=ApiKeyStatus.ACTIVE.value,
            created_at=now.isoformat(),
            expires_at=expires_at,
            last_used_at=None,
            revoked_at=None,
            revoked_reason=None,
            total_calls=0,
        )

        with self._connect() as conn:
            conn.execute(
                """
                INSERT INTO api_keys (
                    id, key_hash, key_preview, name, owner_id, permissions,
                    rate_limit, status, created_at, expires_at
                ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    api_key.id,
                    api_key.key_hash,
                    api_key.key_preview,
                    api_key.name,
                    api_key.owner_id,
                    json.dumps(api_key.permissions),
                    api_key.rate_limit,
                    api_key.status,
                    api_key.created_at,
                    api_key.expires_at,
                ),
            )
        return raw_key, api_key

    def validate_key(self, key: str) -> ApiKey | None:
        """Look up a raw key and return its record if it is usable.

        A key whose expiry time has passed is marked ``expired`` in the
        database as a side effect.

        Returns:
            The ``ApiKey`` when the key exists, is active and has not expired;
            ``None`` otherwise (including for empty or non-string input).
        """
        if not isinstance(key, str) or not key:
            return None

        key_hash = self._hash_key(key)
        with self._connect(rows=True) as conn:
            row = conn.execute(
                "SELECT * FROM api_keys WHERE key_hash = ?", (key_hash,)
            ).fetchone()
            if row is None:
                return None

            api_key = self._row_to_api_key(row)
            if api_key.status != ApiKeyStatus.ACTIVE.value:
                return None

            if api_key.expires_at:
                if datetime.now() > datetime.fromisoformat(api_key.expires_at):
                    # Persist the expiry so later lookups fail on the status check.
                    conn.execute(
                        "UPDATE api_keys SET status = ? WHERE id = ?",
                        (ApiKeyStatus.EXPIRED.value, api_key.id),
                    )
                    return None
            return api_key

    def revoke_key(self, key_id: str, reason: str = "", owner_id: str | None = None) -> bool:
        """Revoke an active key.

        Args:
            key_id: ID of the key to revoke.
            reason: Free-text reason stored with the key.
            owner_id: When given, the key is only revoked if it belongs to
                this owner.

        Returns:
            True if a key was revoked; False if it does not exist, is not
            active, or belongs to a different owner.
        """
        sql = """
            UPDATE api_keys
            SET status = ?, revoked_at = ?, revoked_reason = ?
            WHERE id = ? AND status = ?
        """
        params: list[object] = [
            ApiKeyStatus.REVOKED.value,
            datetime.now().isoformat(),
            reason,
            key_id,
            ApiKeyStatus.ACTIVE.value,
        ]
        if owner_id:
            # Ownership is checked by the UPDATE itself, so check and write are atomic.
            sql += " AND owner_id = ?"
            params.append(owner_id)

        with self._connect() as conn:
            return conn.execute(sql, params).rowcount > 0

    def get_key_by_id(self, key_id: str, owner_id: str | None = None) -> ApiKey | None:
        """Fetch a key record by ID, optionally restricted to one owner.

        The record never contains the raw key, but it does include
        ``key_hash``; callers should not expose that field.
        """
        sql = "SELECT * FROM api_keys WHERE id = ?"
        params: list[object] = [key_id]
        if owner_id:
            sql += " AND owner_id = ?"
            params.append(owner_id)

        with self._connect(rows=True) as conn:
            row = conn.execute(sql, params).fetchone()
        return self._row_to_api_key(row) if row else None

    def list_keys(
        self,
        owner_id: str | None = None,
        status: str | None = None,
        limit: int = 100,
        offset: int = 0,
    ) -> list[ApiKey]:
        """List key records, newest first.

        Args:
            owner_id: Only keys of this owner, when given.
            status: Only keys with this status, when given.
            limit: Maximum number of records returned.
            offset: Number of records to skip.
        """
        sql = "SELECT * FROM api_keys WHERE 1=1"
        params: list[object] = []
        if owner_id:
            sql += " AND owner_id = ?"
            params.append(owner_id)
        if status:
            sql += " AND status = ?"
            params.append(status)
        sql += " ORDER BY created_at DESC LIMIT ? OFFSET ?"
        params.extend([limit, offset])

        with self._connect(rows=True) as conn:
            rows = conn.execute(sql, params).fetchall()
        return [self._row_to_api_key(row) for row in rows]

    def update_key(
        self,
        key_id: str,
        name: str | None = None,
        permissions: list[str] | None = None,
        rate_limit: int | None = None,
        owner_id: str | None = None,
    ) -> bool:
        """Update the name, permissions and/or rate limit of a key.

        Fields left as ``None`` are not touched. ``owner_id`` is not a field
        to update: when given, the key is only changed if it belongs to that
        owner.

        Returns:
            True if a row was updated; False if there was nothing to update,
            the key does not exist, or it belongs to a different owner.
        """
        assignments: list[str] = []
        params: list[object] = []
        if name is not None:
            assignments.append("name = ?")
            params.append(name)
        if permissions is not None:
            assignments.append("permissions = ?")
            params.append(json.dumps(permissions))
        if rate_limit is not None:
            assignments.append("rate_limit = ?")
            params.append(rate_limit)
        if not assignments:
            return False

        # Only fixed column fragments are interpolated; all values are bound.
        sql = f"UPDATE api_keys SET {', '.join(assignments)} WHERE id = ?"
        params.append(key_id)
        if owner_id:
            # Ownership is checked by the UPDATE itself, so check and write are atomic.
            sql += " AND owner_id = ?"
            params.append(owner_id)

        with self._connect() as conn:
            return conn.execute(sql, params).rowcount > 0

    def update_last_used(self, key_id: str) -> None:
        """Set ``last_used_at`` to now and increment ``total_calls`` for a key."""
        with self._connect() as conn:
            conn.execute(
                """
                UPDATE api_keys
                SET last_used_at = ?, total_calls = total_calls + 1
                WHERE id = ?
                """,
                (datetime.now().isoformat(), key_id),
            )

    def log_api_call(
        self,
        api_key_id: str,
        endpoint: str,
        method: str,
        status_code: int = 200,
        response_time_ms: int = 0,
        ip_address: str = "",
        user_agent: str = "",
        error_message: str = "",
    ) -> None:
        """Insert one row into ``api_call_logs`` describing a single API call."""
        with self._connect() as conn:
            conn.execute(
                """
                INSERT INTO api_call_logs
                (api_key_id, endpoint, method, status_code, response_time_ms,
                ip_address, user_agent, error_message)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    api_key_id,
                    endpoint,
                    method,
                    status_code,
                    response_time_ms,
                    ip_address,
                    user_agent,
                    error_message,
                ),
            )

    def get_call_logs(
        self,
        api_key_id: str | None = None,
        start_date: str | None = None,
        end_date: str | None = None,
        limit: int = 100,
        offset: int = 0,
    ) -> list[dict]:
        """Return call-log rows as dicts, newest first.

        Args:
            api_key_id: Only logs of this key, when given.
            start_date: Inclusive lower bound compared against ``created_at``.
            end_date: Inclusive upper bound compared against ``created_at``.
            limit: Maximum number of rows returned.
            offset: Number of rows to skip.
        """
        sql = "SELECT * FROM api_call_logs WHERE 1=1"
        params: list[object] = []
        if api_key_id:
            sql += " AND api_key_id = ?"
            params.append(api_key_id)
        if start_date:
            sql += " AND created_at >= ?"
            params.append(start_date)
        if end_date:
            sql += " AND created_at <= ?"
            params.append(end_date)
        sql += " ORDER BY created_at DESC LIMIT ? OFFSET ?"
        params.extend([limit, offset])

        with self._connect(rows=True) as conn:
            rows = conn.execute(sql, params).fetchall()
        return [dict(row) for row in rows]

    def get_call_stats(self, api_key_id: str | None = None, days: int = 30) -> dict:
        """Aggregate call logs of the last ``days`` days.

        Args:
            api_key_id: Only logs of this key, when given.
            days: Size of the look-back window in days.

        Returns:
            A dict with ``summary`` (totals and response-time figures),
            ``endpoints`` (per endpoint/method counts, busiest first) and
            ``daily`` (per-day counts in date order).

        Raises:
            ValueError, TypeError: If ``days`` cannot be converted to an int.
        """
        # All three queries share one filter. The window is passed as a bound
        # parameter, so no caller-supplied value ends up in the SQL text.
        filters = ["created_at >= date('now', ?)"]
        params: list[object] = [f"-{int(days)} days"]
        if api_key_id:
            filters.insert(0, "api_key_id = ?")
            params.insert(0, api_key_id)
        where = " AND ".join(filters)

        summary_sql = f"""
            SELECT
                COUNT(*) as total_calls,
                COUNT(CASE WHEN status_code < 400 THEN 1 END) as success_calls,
                COUNT(CASE WHEN status_code >= 400 THEN 1 END) as error_calls,
                AVG(response_time_ms) as avg_response_time,
                MAX(response_time_ms) as max_response_time,
                MIN(response_time_ms) as min_response_time
            FROM api_call_logs
            WHERE {where}
        """
        endpoint_sql = f"""
            SELECT
                endpoint,
                method,
                COUNT(*) as calls,
                AVG(response_time_ms) as avg_time
            FROM api_call_logs
            WHERE {where}
            GROUP BY endpoint, method ORDER BY calls DESC
        """
        daily_sql = f"""
            SELECT
                date(created_at) as date,
                COUNT(*) as calls,
                COUNT(CASE WHEN status_code < 400 THEN 1 END) as success
            FROM api_call_logs
            WHERE {where}
            GROUP BY date(created_at) ORDER BY date
        """

        with self._connect(rows=True) as conn:
            summary = conn.execute(summary_sql, params).fetchone()
            endpoint_rows = conn.execute(endpoint_sql, params).fetchall()
            daily_rows = conn.execute(daily_sql, params).fetchall()

        # Aggregates over zero rows are NULL; report them as 0.
        return {
            "summary": {
                "total_calls": summary["total_calls"] or 0,
                "success_calls": summary["success_calls"] or 0,
                "error_calls": summary["error_calls"] or 0,
                "avg_response_time_ms": round(summary["avg_response_time"] or 0, 2),
                "max_response_time_ms": summary["max_response_time"] or 0,
                "min_response_time_ms": summary["min_response_time"] or 0,
            },
            "endpoints": [dict(r) for r in endpoint_rows],
            "daily": [dict(r) for r in daily_rows],
        }

    def _row_to_api_key(self, row: sqlite3.Row) -> ApiKey:
        """Convert an ``api_keys`` row into an ``ApiKey``, decoding the JSON permissions."""
        return ApiKey(
            id=row["id"],
            key_hash=row["key_hash"],
            key_preview=row["key_preview"],
            name=row["name"],
            owner_id=row["owner_id"],
            permissions=json.loads(row["permissions"]),
            rate_limit=row["rate_limit"],
            status=row["status"],
            created_at=row["created_at"],
            expires_at=row["expires_at"],
            last_used_at=row["last_used_at"],
            revoked_at=row["revoked_at"],
            revoked_reason=row["revoked_reason"],
            total_calls=row["total_calls"],
        )
# Module-wide instance, created lazily by get_api_key_manager().
_api_key_manager: ApiKeyManager | None = None


def get_api_key_manager() -> ApiKeyManager:
    """Return the shared ``ApiKeyManager``, creating it on the first call.

    The instance is built with the default arguments of ``ApiKeyManager``.
    """
    global _api_key_manager
    if _api_key_manager is None:
        _api_key_manager = ApiKeyManager()
    return _api_key_manager