#!/usr/bin/env python3
"""
InsightFlow API Key Manager - Phase 6

API key management module: generation, validation and revocation of API
keys, plus per-call logging and usage statistics, backed by SQLite.
"""

import os
import json
import hashlib
import secrets
import sqlite3
from contextlib import contextmanager
from datetime import datetime, timedelta
from typing import Iterator, Optional, List, Dict
from dataclasses import dataclass
from enum import Enum

DB_PATH = os.getenv("DB_PATH", "/app/data/insightflow.db")


class ApiKeyStatus(Enum):
    """Values stored in the ``api_keys.status`` column."""

    ACTIVE = "active"
    REVOKED = "revoked"
    EXPIRED = "expired"


@dataclass
class ApiKey:
    """In-memory representation of one row of the ``api_keys`` table."""

    id: str
    key_hash: str  # SHA-256 hex digest of the key; the raw key is never stored
    key_preview: str  # first 16 characters plus "...", e.g. "ak_live_abcdefgh..."
    name: str  # key name / description
    owner_id: Optional[str]  # owner ID (reserved for multi-user support)
    permissions: List[str]  # permission list, e.g. ["read", "write"]
    rate_limit: int  # requests-per-minute limit
    status: str  # one of the ApiKeyStatus values: active, revoked, expired
    created_at: str
    expires_at: Optional[str]
    last_used_at: Optional[str]
    revoked_at: Optional[str]
    revoked_reason: Optional[str]
    total_calls: int = 0


class ApiKeyManager:
    """Create, validate, revoke and audit API keys stored in SQLite.

    Every public method opens its own short-lived connection to ``db_path``
    and closes it before returning.
    """

    # Prefix of every generated key
    KEY_PREFIX = "ak_live_"
    KEY_LENGTH = 48  # total length: prefix (8) + random part (40)

    def __init__(self, db_path: str = DB_PATH):
        self.db_path = db_path
        self._init_db()

    @contextmanager
    def _connect(self, row_factory: bool = False) -> Iterator[sqlite3.Connection]:
        """Yield a connection that commits on success, rolls back on error, and is always closed.

        ``sqlite3.Connection`` used as a context manager only manages the
        transaction; it does NOT close the connection, so it is wrapped here
        in an explicit ``close()``.

        Args:
            row_factory: When true, rows are returned as ``sqlite3.Row`` so
                columns can be read by name.
        """
        conn = sqlite3.connect(self.db_path)
        try:
            if row_factory:
                conn.row_factory = sqlite3.Row
            with conn:  # commit on normal exit, rollback on exception
                yield conn
        finally:
            conn.close()

    def _init_db(self):
        """Create the API key, call-log and call-stats tables and their indexes if missing."""
        with self._connect() as conn:
            conn.executescript("""
                -- API Keys 表
                CREATE TABLE IF NOT EXISTS api_keys (
                    id TEXT PRIMARY KEY,
                    key_hash TEXT UNIQUE NOT NULL,
                    key_preview TEXT NOT NULL,
                    name TEXT NOT NULL,
                    owner_id TEXT,
                    permissions TEXT NOT NULL DEFAULT '["read"]',
                    rate_limit INTEGER DEFAULT 60,
                    status TEXT DEFAULT 'active',
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    expires_at TIMESTAMP,
                    last_used_at TIMESTAMP,
                    revoked_at TIMESTAMP,
                    revoked_reason TEXT,
                    total_calls INTEGER DEFAULT 0
                );

                -- API 调用日志表
                CREATE TABLE IF NOT EXISTS api_call_logs (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    api_key_id TEXT NOT NULL,
                    endpoint TEXT NOT NULL,
                    method TEXT NOT NULL,
                    status_code INTEGER,
                    response_time_ms INTEGER,
                    ip_address TEXT,
                    user_agent TEXT,
                    error_message TEXT,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    FOREIGN KEY (api_key_id) REFERENCES api_keys(id)
                );

                -- API 调用统计表(按天汇总)
                CREATE TABLE IF NOT EXISTS api_call_stats (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    api_key_id TEXT NOT NULL,
                    date TEXT NOT NULL,
                    endpoint TEXT NOT NULL,
                    method TEXT NOT NULL,
                    total_calls INTEGER DEFAULT 0,
                    success_calls INTEGER DEFAULT 0,
                    error_calls INTEGER DEFAULT 0,
                    avg_response_time_ms INTEGER DEFAULT 0,
                    FOREIGN KEY (api_key_id) REFERENCES api_keys(id),
                    UNIQUE(api_key_id, date, endpoint, method)
                );

                -- 创建索引
                CREATE INDEX IF NOT EXISTS idx_api_keys_hash ON api_keys(key_hash);
                CREATE INDEX IF NOT EXISTS idx_api_keys_status ON api_keys(status);
                CREATE INDEX IF NOT EXISTS idx_api_keys_owner ON api_keys(owner_id);
                CREATE INDEX IF NOT EXISTS idx_api_logs_key_id ON api_call_logs(api_key_id);
                CREATE INDEX IF NOT EXISTS idx_api_logs_created ON api_call_logs(created_at);
                CREATE INDEX IF NOT EXISTS idx_api_stats_key_date ON api_call_stats(api_key_id, date);
            """)

    def _generate_key(self) -> str:
        """Generate a new raw API key: ``KEY_PREFIX`` followed by 40 random URL-safe characters."""
        # 30 random bytes encode to exactly 40 URL-safe base64 characters.
        random_part = secrets.token_urlsafe(30)[:40]
        return f"{self.KEY_PREFIX}{random_part}"

    def _hash_key(self, key: str) -> str:
        """Return the SHA-256 hex digest of ``key`` (what is stored and looked up)."""
        return hashlib.sha256(key.encode()).hexdigest()

    def _get_preview(self, key: str) -> str:
        """Return a displayable preview: the first 16 characters of the key plus ``...``."""
        return f"{key[:16]}..."

    def create_key(
        self,
        name: str,
        owner_id: Optional[str] = None,
        permissions: Optional[List[str]] = None,
        rate_limit: int = 60,
        expires_days: Optional[int] = None
    ) -> tuple[str, ApiKey]:
        """Create and persist a new API key.

        Args:
            name: Key name / description.
            owner_id: Optional owner ID.
            permissions: Permission list, defaulting to ``["read"]``. The list
                is copied, so later changes to the caller's list do not leak
                into the returned object.
            rate_limit: Requests-per-minute limit stored with the key.
            expires_days: Days until expiry; ``None`` or ``0`` means the key
                never expires.

        Returns:
            ``(raw_key, api_key)``. Only the hash of the key is stored, so
            this is the only time the raw key is available.
        """
        permissions = ["read"] if permissions is None else list(permissions)

        raw_key = self._generate_key()
        now = datetime.now()
        expires_at = None
        if expires_days:
            expires_at = (now + timedelta(days=expires_days)).isoformat()

        api_key = ApiKey(
            id=secrets.token_hex(16),
            key_hash=self._hash_key(raw_key),
            key_preview=self._get_preview(raw_key),
            name=name,
            owner_id=owner_id,
            permissions=permissions,
            rate_limit=rate_limit,
            status=ApiKeyStatus.ACTIVE.value,
            created_at=now.isoformat(),
            expires_at=expires_at,
            last_used_at=None,
            revoked_at=None,
            revoked_reason=None,
            total_calls=0
        )

        with self._connect() as conn:
            conn.execute("""
                INSERT INTO api_keys (
                    id, key_hash, key_preview, name, owner_id, permissions,
                    rate_limit, status, created_at, expires_at
                ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """, (
                api_key.id, api_key.key_hash, api_key.key_preview,
                api_key.name, api_key.owner_id, json.dumps(api_key.permissions),
                api_key.rate_limit, api_key.status, api_key.created_at,
                api_key.expires_at
            ))

        return raw_key, api_key

    def validate_key(self, key: str) -> Optional[ApiKey]:
        """Validate a raw API key.

        Side effect: an active key whose ``expires_at`` has passed is marked
        ``expired`` in the database.

        Returns:
            The ``ApiKey`` if the key exists, is active and has not expired;
            otherwise ``None`` (also for an empty or ``None`` key).
        """
        if not key:
            return None

        key_hash = self._hash_key(key)

        with self._connect(row_factory=True) as conn:
            row = conn.execute(
                "SELECT * FROM api_keys WHERE key_hash = ?",
                (key_hash,)
            ).fetchone()
            if not row:
                return None

            api_key = self._row_to_api_key(row)

            if api_key.status != ApiKeyStatus.ACTIVE.value:
                return None

            if api_key.expires_at:
                expires = datetime.fromisoformat(api_key.expires_at)
                if datetime.now() > expires:
                    # Persist the expiry so later lookups and listings see it.
                    conn.execute(
                        "UPDATE api_keys SET status = ? WHERE id = ?",
                        (ApiKeyStatus.EXPIRED.value, api_key.id)
                    )
                    return None

            return api_key

    def revoke_key(
        self,
        key_id: str,
        reason: str = "",
        owner_id: Optional[str] = None
    ) -> bool:
        """Revoke an active API key.

        Args:
            key_id: ID of the key to revoke.
            reason: Free-text reason stored in ``revoked_reason``.
            owner_id: If given (truthy), the key is only revoked when it
                belongs to this owner.

        Returns:
            True if a key was revoked; False if it does not exist, is not
            active, or belongs to another owner.
        """
        query = """
            UPDATE api_keys
            SET status = ?, revoked_at = ?, revoked_reason = ?
            WHERE id = ? AND status = ?
        """
        params = [
            ApiKeyStatus.REVOKED.value,
            datetime.now().isoformat(),
            reason,
            key_id,
            ApiKeyStatus.ACTIVE.value,
        ]
        # Ownership is checked in the same statement (no separate SELECT),
        # so the check and the update cannot race.
        if owner_id:
            query += " AND owner_id = ?"
            params.append(owner_id)

        with self._connect() as conn:
            cursor = conn.execute(query, params)
            return cursor.rowcount > 0

    def get_key_by_id(self, key_id: str, owner_id: Optional[str] = None) -> Optional[ApiKey]:
        """Fetch an API key by ID (the raw key is never available, only its hash).

        If ``owner_id`` is given (truthy), the key must also belong to that owner.
        """
        query = "SELECT * FROM api_keys WHERE id = ?"
        params = [key_id]
        if owner_id:
            query += " AND owner_id = ?"
            params.append(owner_id)

        with self._connect(row_factory=True) as conn:
            row = conn.execute(query, params).fetchone()
            return self._row_to_api_key(row) if row else None

    def list_keys(
        self,
        owner_id: Optional[str] = None,
        status: Optional[str] = None,
        limit: int = 100,
        offset: int = 0
    ) -> List[ApiKey]:
        """List API keys, newest first, optionally filtered by owner and/or status."""
        query = "SELECT * FROM api_keys WHERE 1=1"
        params = []

        if owner_id:
            query += " AND owner_id = ?"
            params.append(owner_id)
        if status:
            query += " AND status = ?"
            params.append(status)

        query += " ORDER BY created_at DESC LIMIT ? OFFSET ?"
        params.extend([limit, offset])

        with self._connect(row_factory=True) as conn:
            rows = conn.execute(query, params).fetchall()
            return [self._row_to_api_key(row) for row in rows]

    def update_key(
        self,
        key_id: str,
        name: Optional[str] = None,
        permissions: Optional[List[str]] = None,
        rate_limit: Optional[int] = None,
        owner_id: Optional[str] = None
    ) -> bool:
        """Update the name, permissions and/or rate limit of an API key.

        Only arguments that are not ``None`` are written. If ``owner_id`` is
        given (truthy), the key must belong to that owner.

        Returns:
            True if a row was updated; False if nothing was requested, or the
            key does not exist or belongs to another owner.
        """
        updates = []
        params = []

        if name is not None:
            updates.append("name = ?")
            params.append(name)
        if permissions is not None:
            updates.append("permissions = ?")
            params.append(json.dumps(permissions))
        if rate_limit is not None:
            updates.append("rate_limit = ?")
            params.append(rate_limit)

        if not updates:
            return False

        # Only fixed column fragments are interpolated; all values are bound.
        query = f"UPDATE api_keys SET {', '.join(updates)} WHERE id = ?"
        params.append(key_id)
        # Ownership is part of the UPDATE itself rather than a prior SELECT.
        if owner_id:
            query += " AND owner_id = ?"
            params.append(owner_id)

        with self._connect() as conn:
            cursor = conn.execute(query, params)
            return cursor.rowcount > 0

    def update_last_used(self, key_id: str):
        """Set ``last_used_at`` to now and increment ``total_calls`` for the key."""
        with self._connect() as conn:
            conn.execute("""
                UPDATE api_keys
                SET last_used_at = ?, total_calls = total_calls + 1
                WHERE id = ?
            """, (datetime.now().isoformat(), key_id))

    def log_api_call(
        self,
        api_key_id: str,
        endpoint: str,
        method: str,
        status_code: int = 200,
        response_time_ms: int = 0,
        ip_address: str = "",
        user_agent: str = "",
        error_message: str = ""
    ):
        """Insert one row into ``api_call_logs`` (``created_at`` defaults to CURRENT_TIMESTAMP)."""
        with self._connect() as conn:
            conn.execute("""
                INSERT INTO api_call_logs
                (api_key_id, endpoint, method, status_code, response_time_ms,
                 ip_address, user_agent, error_message)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?)
            """, (
                api_key_id, endpoint, method, status_code, response_time_ms,
                ip_address, user_agent, error_message
            ))

    def get_call_logs(
        self,
        api_key_id: Optional[str] = None,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
        limit: int = 100,
        offset: int = 0
    ) -> List[Dict]:
        """Return call-log rows as dicts, newest first.

        ``start_date``/``end_date`` are compared as strings against
        ``created_at`` (inclusive on both ends).
        """
        query = "SELECT * FROM api_call_logs WHERE 1=1"
        params = []

        if api_key_id:
            query += " AND api_key_id = ?"
            params.append(api_key_id)
        if start_date:
            query += " AND created_at >= ?"
            params.append(start_date)
        if end_date:
            query += " AND created_at <= ?"
            params.append(end_date)

        query += " ORDER BY created_at DESC LIMIT ? OFFSET ?"
        params.extend([limit, offset])

        with self._connect(row_factory=True) as conn:
            rows = conn.execute(query, params).fetchall()
            return [dict(row) for row in rows]

    @staticmethod
    def _stats_filter(api_key_id: Optional[str], days: int) -> tuple[str, list]:
        """Build the WHERE clause and bound parameters shared by the statistics queries."""
        # The look-back window is bound as a parameter instead of being
        # formatted into the SQL text; int() rejects non-numeric input
        # (and truncates fractional days).
        clauses = ["created_at >= date('now', ?)"]
        params: list = [f"-{int(days)} days"]
        if api_key_id:
            clauses.insert(0, "api_key_id = ?")
            params.insert(0, api_key_id)
        return "WHERE " + " AND ".join(clauses), params

    def get_call_stats(
        self,
        api_key_id: Optional[str] = None,
        days: int = 30
    ) -> Dict:
        """Aggregate call statistics over the last ``days`` days.

        Args:
            api_key_id: Restrict to one key when given (truthy).
            days: Look-back window in days, relative to SQLite's ``date('now')``.

        Returns:
            A dict with ``summary`` (totals, success/error counts, response
            time avg/max/min), ``endpoints`` (per endpoint+method, most-called
            first) and ``daily`` (per day, ascending).

        Raises:
            ValueError: If ``days`` cannot be converted to an int.
        """
        where, params = self._stats_filter(api_key_id, days)

        # ``where`` only contains fixed SQL fragments; all values are in ``params``.
        with self._connect(row_factory=True) as conn:
            row = conn.execute(f"""
                SELECT
                    COUNT(*) as total_calls,
                    COUNT(CASE WHEN status_code < 400 THEN 1 END) as success_calls,
                    COUNT(CASE WHEN status_code >= 400 THEN 1 END) as error_calls,
                    AVG(response_time_ms) as avg_response_time,
                    MAX(response_time_ms) as max_response_time,
                    MIN(response_time_ms) as min_response_time
                FROM api_call_logs
                {where}
            """, params).fetchone()

            endpoint_rows = conn.execute(f"""
                SELECT
                    endpoint,
                    method,
                    COUNT(*) as calls,
                    AVG(response_time_ms) as avg_time
                FROM api_call_logs
                {where}
                GROUP BY endpoint, method ORDER BY calls DESC
            """, params).fetchall()

            daily_rows = conn.execute(f"""
                SELECT
                    date(created_at) as date,
                    COUNT(*) as calls,
                    COUNT(CASE WHEN status_code < 400 THEN 1 END) as success
                FROM api_call_logs
                {where}
                GROUP BY date(created_at) ORDER BY date
            """, params).fetchall()

            return {
                "summary": {
                    "total_calls": row["total_calls"] or 0,
                    "success_calls": row["success_calls"] or 0,
                    "error_calls": row["error_calls"] or 0,
                    "avg_response_time_ms": round(row["avg_response_time"] or 0, 2),
                    "max_response_time_ms": row["max_response_time"] or 0,
                    "min_response_time_ms": row["min_response_time"] or 0,
                },
                "endpoints": [dict(r) for r in endpoint_rows],
                "daily": [dict(r) for r in daily_rows]
            }

    def _row_to_api_key(self, row: sqlite3.Row) -> ApiKey:
        """Convert an ``api_keys`` row into an ``ApiKey`` (permissions are JSON-decoded)."""
        return ApiKey(
            id=row["id"],
            key_hash=row["key_hash"],
            key_preview=row["key_preview"],
            name=row["name"],
            owner_id=row["owner_id"],
            permissions=json.loads(row["permissions"]),
            rate_limit=row["rate_limit"],
            status=row["status"],
            created_at=row["created_at"],
            expires_at=row["expires_at"],
            last_used_at=row["last_used_at"],
            revoked_at=row["revoked_at"],
            revoked_reason=row["revoked_reason"],
            total_calls=row["total_calls"]
        )


# Lazily created module-level instance
_api_key_manager: Optional[ApiKeyManager] = None


def get_api_key_manager() -> ApiKeyManager:
    """Return the shared ``ApiKeyManager``, creating it with the default ``DB_PATH`` on first use."""
    global _api_key_manager
    if _api_key_manager is None:
        _api_key_manager = ApiKeyManager()
    return _api_key_manager