From d767f0dddca67182532a23d907aa90e285152db0 Mon Sep 17 00:00:00 2001 From: OpenClaw Bot Date: Fri, 27 Feb 2026 21:12:04 +0800 Subject: [PATCH] fix: auto-fix code issues (cron) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 修复重复导入/字段 - 修复异常处理 - 修复PEP8格式问题 - 添加类型注解 --- backend/ai_manager.py | 5 +- backend/api_key_manager.py | 4 +- backend/collaboration_manager.py | 14 +- backend/db_manager.py | 4 +- backend/developer_ecosystem_manager.py | 18 +- backend/enterprise_manager.py | 4 +- backend/export_manager.py | 22 +- backend/growth_manager.py | 2 +- backend/image_processor.py | 8 +- backend/knowledge_reasoner.py | 10 - backend/llm_client.py | 4 +- backend/localization_manager.py | 4 +- backend/main.py | 3741 ++++++++++-------------- backend/multimodal_entity_linker.py | 2 +- backend/multimodal_processor.py | 4 +- backend/neo4j_manager.py | 24 +- backend/ops_manager.py | 16 +- backend/oss_uploader.py | 2 +- backend/performance_manager.py | 44 +- backend/plugin_manager.py | 12 +- backend/rate_limiter.py | 8 +- backend/search_manager.py | 899 +++--- backend/security_manager.py | 458 ++- backend/subscription_manager.py | 1146 +++++--- backend/tenant_manager.py | 763 +++-- backend/tingwu_client.py | 52 +- backend/workflow_manager.py | 524 ++-- 27 files changed, 3636 insertions(+), 4158 deletions(-) diff --git a/backend/ai_manager.py b/backend/ai_manager.py index 8aab05e..9468613 100644 --- a/backend/ai_manager.py +++ b/backend/ai_manager.py @@ -209,7 +209,7 @@ class AIManager: self.openai_api_key = os.getenv("OPENAI_API_KEY", "") self.anthropic_api_key = os.getenv("ANTHROPIC_API_KEY", "") - def _get_db(self): + def _get_db(self) -> None: """获取数据库连接""" conn = sqlite3.connect(self.db_path) conn.row_factory = sqlite3.Row @@ -921,7 +921,6 @@ class AIManager: # 解析关键要点 key_points = [] - import re # 尝试从 JSON 中提取 json_match = re.search(r"\{.*?\}", content, re.DOTALL) @@ -1312,7 +1311,7 @@ class AIManager: return [self._row_to_prediction_result(row) for row in rows] - def update_prediction_feedback(self, prediction_id: str, actual_value: str, is_correct: bool): + def update_prediction_feedback(self, prediction_id: str, actual_value: str, is_correct: bool) -> None: """更新预测反馈(用于模型改进)""" with self._get_db() as conn: conn.execute( diff --git a/backend/api_key_manager.py b/backend/api_key_manager.py index 65cae67..fd7846e 100644 --- a/backend/api_key_manager.py +++ b/backend/api_key_manager.py @@ -51,7 +51,7 @@ class ApiKeyManager: self.db_path = db_path self._init_db() - def _init_db(self): + def _init_db(self) -> None: """初始化数据库表""" with sqlite3.connect(self.db_path) as conn: conn.executescript(""" @@ -331,7 +331,7 @@ class ApiKeyManager: conn.commit() return cursor.rowcount > 0 - def update_last_used(self, key_id: str): + def update_last_used(self, key_id: str) -> None: """更新最后使用时间""" with sqlite3.connect(self.db_path) as conn: conn.execute( diff --git a/backend/collaboration_manager.py b/backend/collaboration_manager.py index d7e0f8a..ba8cdcb 100644 --- a/backend/collaboration_manager.py +++ b/backend/collaboration_manager.py @@ -195,7 +195,7 @@ class CollaborationManager: data = f"{project_id}:{datetime.now().timestamp()}:{uuid.uuid4()}" return hashlib.sha256(data.encode()).hexdigest()[:32] - def _save_share_to_db(self, share: ProjectShare): + def _save_share_to_db(self, share: ProjectShare) -> None: """保存分享记录到数据库""" cursor = self.db.conn.cursor() cursor.execute( @@ -286,7 +286,7 @@ class CollaborationManager: allow_export=bool(row[12]), ) - def increment_share_usage(self, token: str): + def increment_share_usage(self, token: str) -> None: """增加分享链接使用次数""" share = self._shares_cache.get(token) if share: @@ -403,7 +403,7 @@ class CollaborationManager: return comment - def _save_comment_to_db(self, comment: Comment): + def _save_comment_to_db(self, comment: Comment) -> None: """保存评论到数据库""" cursor = self.db.conn.cursor() cursor.execute( @@ -616,7 +616,7 @@ class CollaborationManager: return record - def _save_change_to_db(self, record: ChangeRecord): + def _save_change_to_db(self, record: ChangeRecord) -> None: """保存变更记录到数据库""" cursor = self.db.conn.cursor() cursor.execute( @@ -862,7 +862,7 @@ class CollaborationManager: } return permissions_map.get(role, ["read"]) - def _save_member_to_db(self, member: TeamMember): + def _save_member_to_db(self, member: TeamMember) -> None: """保存成员到数据库""" cursor = self.db.conn.cursor() cursor.execute( @@ -970,7 +970,7 @@ class CollaborationManager: permissions = json.loads(row[0]) if row[0] else [] return permission in permissions or "admin" in permissions - def update_last_active(self, project_id: str, user_id: str): + def update_last_active(self, project_id: str, user_id: str) -> None: """更新用户最后活跃时间""" if not self.db: return @@ -992,7 +992,7 @@ class CollaborationManager: _collaboration_manager = None -def get_collaboration_manager(db_manager=None): +def get_collaboration_manager(db_manager=None) -> None: """获取协作管理器单例""" global _collaboration_manager if _collaboration_manager is None: diff --git a/backend/db_manager.py b/backend/db_manager.py index 667f5cc..03afbaf 100644 --- a/backend/db_manager.py +++ b/backend/db_manager.py @@ -123,7 +123,7 @@ class DatabaseManager: conn.row_factory = sqlite3.Row return conn - def init_db(self): + def init_db(self) -> None: """初始化数据库表""" with open(os.path.join(os.path.dirname(__file__), "schema.sql")) as f: schema = f.read() @@ -299,7 +299,7 @@ class DatabaseManager: conn.close() return self.get_entity(entity_id) - def delete_entity(self, entity_id: str): + def delete_entity(self, entity_id: str) -> None: """删除实体及其关联数据""" conn = self.get_conn() conn.execute("DELETE FROM entity_mentions WHERE entity_id = ?", (entity_id,)) diff --git a/backend/developer_ecosystem_manager.py b/backend/developer_ecosystem_manager.py index c0bbc89..a1b393e 100644 --- a/backend/developer_ecosystem_manager.py +++ b/backend/developer_ecosystem_manager.py @@ -352,7 +352,7 @@ class DeveloperEcosystemManager: self.db_path = db_path self.platform_fee_rate = 0.30 # 平台抽成比例 30% - def _get_db(self): + def _get_db(self) -> None: """获取数据库连接""" conn = sqlite3.connect(self.db_path) conn.row_factory = sqlite3.Row @@ -515,7 +515,7 @@ class DeveloperEcosystemManager: return self.get_sdk_release(sdk_id) - def increment_sdk_download(self, sdk_id: str): + def increment_sdk_download(self, sdk_id: str) -> None: """增加 SDK 下载计数""" with self._get_db() as conn: conn.execute( @@ -785,7 +785,7 @@ class DeveloperEcosystemManager: return self.get_template(template_id) - def increment_template_install(self, template_id: str): + def increment_template_install(self, template_id: str) -> None: """增加模板安装计数""" with self._get_db() as conn: conn.execute( @@ -852,7 +852,7 @@ class DeveloperEcosystemManager: return review - def _update_template_rating(self, conn, template_id: str): + def _update_template_rating(self, conn, template_id: str) -> None: """更新模板评分""" row = conn.execute( """ @@ -1084,7 +1084,7 @@ class DeveloperEcosystemManager: return self.get_plugin(plugin_id) - def increment_plugin_install(self, plugin_id: str, active: bool = True): + def increment_plugin_install(self, plugin_id: str, active: bool = True) -> None: """增加插件安装计数""" with self._get_db() as conn: conn.execute( @@ -1160,7 +1160,7 @@ class DeveloperEcosystemManager: return review - def _update_plugin_rating(self, conn, plugin_id: str): + def _update_plugin_rating(self, conn, plugin_id: str) -> None: """更新插件评分""" row = conn.execute( """ @@ -1421,7 +1421,7 @@ class DeveloperEcosystemManager: return self.get_developer_profile(developer_id) - def update_developer_stats(self, developer_id: str): + def update_developer_stats(self, developer_id: str) -> None: """更新开发者统计信息""" with self._get_db() as conn: # 统计插件数量 @@ -1574,7 +1574,7 @@ class DeveloperEcosystemManager: rows = conn.execute(query, params).fetchall() return [self._row_to_code_example(row) for row in rows] - def increment_example_view(self, example_id: str): + def increment_example_view(self, example_id: str) -> None: """增加代码示例查看计数""" with self._get_db() as conn: conn.execute( @@ -1587,7 +1587,7 @@ class DeveloperEcosystemManager: ) conn.commit() - def increment_example_copy(self, example_id: str): + def increment_example_copy(self, example_id: str) -> None: """增加代码示例复制计数""" with self._get_db() as conn: conn.execute( diff --git a/backend/enterprise_manager.py b/backend/enterprise_manager.py index 6d86c1b..aab1d23 100644 --- a/backend/enterprise_manager.py +++ b/backend/enterprise_manager.py @@ -329,7 +329,7 @@ class EnterpriseManager: conn.row_factory = sqlite3.Row return conn - def _init_db(self): + def _init_db(self) -> None: """初始化数据库表""" conn = self._get_connection() try: @@ -1190,7 +1190,7 @@ class EnterpriseManager: # GET {scim_base_url}/Users return [] - def _upsert_scim_user(self, conn: sqlite3.Connection, tenant_id: str, user_data: dict[str, Any]): + def _upsert_scim_user(self, conn: sqlite3.Connection, tenant_id: str, user_data: dict[str, Any]) -> None: """插入或更新 SCIM 用户""" cursor = conn.cursor() diff --git a/backend/export_manager.py b/backend/export_manager.py index da174cf..f908927 100644 --- a/backend/export_manager.py +++ b/backend/export_manager.py @@ -22,14 +22,7 @@ try: from reportlab.lib.pagesizes import A4 from reportlab.lib.styles import ParagraphStyle, getSampleStyleSheet from reportlab.lib.units import inch - from reportlab.platypus import ( - PageBreak, - Paragraph, - SimpleDocTemplate, - Spacer, - Table, - TableStyle, - ) + from reportlab.platypus import PageBreak, Paragraph, SimpleDocTemplate, Spacer, Table, TableStyle REPORTLAB_AVAILABLE = True except ImportError: @@ -194,20 +187,16 @@ class ExportManager: f'fill="white" stroke="#bdc3c7" rx="5"/>' ) svg_parts.append( - f'实体类型' + f'实体类型' ) for i, (etype, color) in enumerate(type_colors.items()): if etype != "default": y_pos = legend_y + 25 + i * 20 - svg_parts.append( - f'' - ) + svg_parts.append(f'') text_y = y_pos + 4 svg_parts.append( - f'{etype}' + f'{etype}' ) svg_parts.append("") @@ -320,7 +309,6 @@ class ExportManager: Returns: CSV 字符串 """ - import csv output = io.StringIO() writer = csv.writer(output) @@ -586,7 +574,7 @@ class ExportManager: _export_manager = None -def get_export_manager(db_manager=None): +def get_export_manager(db_manager=None) -> None: """获取导出管理器实例""" global _export_manager if _export_manager is None: diff --git a/backend/growth_manager.py b/backend/growth_manager.py index a630e34..a988c88 100644 --- a/backend/growth_manager.py +++ b/backend/growth_manager.py @@ -369,7 +369,7 @@ class GrowthManager: self.segment_write_key = os.getenv("SEGMENT_WRITE_KEY", "") self.sendgrid_api_key = os.getenv("SENDGRID_API_KEY", "") - def _get_db(self): + def _get_db(self) -> None: """获取数据库连接""" conn = sqlite3.connect(self.db_path) conn.row_factory = sqlite3.Row diff --git a/backend/image_processor.py b/backend/image_processor.py index 58e514f..513410f 100644 --- a/backend/image_processor.py +++ b/backend/image_processor.py @@ -93,7 +93,7 @@ class ImageProcessor: "other": "其他", } - def __init__(self, temp_dir: str = None): + def __init__(self, temp_dir: str = None) -> None: """ 初始化图片处理器 @@ -103,7 +103,7 @@ class ImageProcessor: self.temp_dir = temp_dir or os.path.join(os.getcwd(), "temp", "images") os.makedirs(self.temp_dir, exist_ok=True) - def preprocess_image(self, image, image_type: str = None): + def preprocess_image(self, image, image_type: str = None) -> None: """ 预处理图片以提高OCR质量 @@ -145,7 +145,7 @@ class ImageProcessor: print(f"Image preprocessing error: {e}") return image - def _enhance_whiteboard(self, image): + def _enhance_whiteboard(self, image) -> None: """增强白板图片""" # 转换为灰度 gray = image.convert("L") @@ -160,7 +160,7 @@ class ImageProcessor: return binary.convert("L") - def _enhance_handwritten(self, image): + def _enhance_handwritten(self, image) -> None: """增强手写笔记图片""" # 转换为灰度 gray = image.convert("L") diff --git a/backend/knowledge_reasoner.py b/backend/knowledge_reasoner.py index 0cac3b6..1810ef6 100644 --- a/backend/knowledge_reasoner.py +++ b/backend/knowledge_reasoner.py @@ -163,8 +163,6 @@ class KnowledgeReasoner: content = await self._call_llm(prompt, temperature=0.3) - import re - json_match = re.search(r"\{{.*?\}}", content, re.DOTALL) if json_match: @@ -217,8 +215,6 @@ class KnowledgeReasoner: content = await self._call_llm(prompt, temperature=0.3) - import re - json_match = re.search(r"\{{.*?\}}", content, re.DOTALL) if json_match: @@ -271,8 +267,6 @@ class KnowledgeReasoner: content = await self._call_llm(prompt, temperature=0.3) - import re - json_match = re.search(r"\{{.*?\}}", content, re.DOTALL) if json_match: @@ -325,8 +319,6 @@ class KnowledgeReasoner: content = await self._call_llm(prompt, temperature=0.4) - import re - json_match = re.search(r"\{{.*?\}}", content, re.DOTALL) if json_match: @@ -474,8 +466,6 @@ class KnowledgeReasoner: content = await self._call_llm(prompt, temperature=0.3) - import re - json_match = re.search(r"\{{.*?\}}", content, re.DOTALL) if json_match: diff --git a/backend/llm_client.py b/backend/llm_client.py index 8e5da38..2068177 100644 --- a/backend/llm_client.py +++ b/backend/llm_client.py @@ -204,15 +204,13 @@ class LLMClient: messages = [ChatMessage(role="user", content=prompt)] content = await self.chat(messages, temperature=0.1) - import re - json_match = re.search(r"\{{.*?\}}", content, re.DOTALL) if not json_match: return {"intent": "unknown", "explanation": "无法解析指令"} try: return json.loads(json_match.group()) - except BaseException: + except (json.JSONDecodeError, KeyError, TypeError): return {"intent": "unknown", "explanation": "解析失败"} async def analyze_entity_evolution(self, entity_name: str, mentions: list[dict]) -> str: diff --git a/backend/localization_manager.py b/backend/localization_manager.py index 96c05c1..b4fe000 100644 --- a/backend/localization_manager.py +++ b/backend/localization_manager.py @@ -938,9 +938,7 @@ class LocalizationManager: finally: self._close_if_file_db(conn) - def get_translation( - self, key: str, language: str, namespace: str = "common", fallback: bool = True - ) -> str | None: + def get_translation(self, key: str, language: str, namespace: str = "common", fallback: bool = True) -> str | None: conn = self._get_connection() try: cursor = conn.cursor() diff --git a/backend/main.py b/backend/main.py index af752dc..33a5944 100644 --- a/backend/main.py +++ b/backend/main.py @@ -17,18 +17,7 @@ from datetime import datetime, timedelta from typing import Any, Optional import httpx -from fastapi import ( - Body, - Depends, - FastAPI, - File, - Form, - Header, - HTTPException, - Query, - Request, - UploadFile, -) +from fastapi import Body, Depends, FastAPI, File, Form, Header, HTTPException, Query, Request, UploadFile from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, StreamingResponse from fastapi.staticfiles import StaticFiles @@ -45,18 +34,21 @@ if backend_dir not in sys.path: # Import clients try: from oss_uploader import get_oss_uploader + OSS_AVAILABLE = True except ImportError: OSS_AVAILABLE = False try: from tingwu_client import TingwuClient + TINGWU_AVAILABLE = True except ImportError: TINGWU_AVAILABLE = False try: from db_manager import Entity, EntityMention, get_db_manager + DB_AVAILABLE = True except ImportError as e: print(f"DB import error: {e}") @@ -64,30 +56,35 @@ except ImportError as e: try: from document_processor import DocumentProcessor + DOC_PROCESSOR_AVAILABLE = True except ImportError: DOC_PROCESSOR_AVAILABLE = False try: from entity_aligner import EntityAligner + ALIGNER_AVAILABLE = True except ImportError: ALIGNER_AVAILABLE = False try: from llm_client import ChatMessage, get_llm_client + LLM_CLIENT_AVAILABLE = True except ImportError: LLM_CLIENT_AVAILABLE = False try: from knowledge_reasoner import get_knowledge_reasoner + REASONER_AVAILABLE = True except ImportError: REASONER_AVAILABLE = False try: from export_manager import ExportEntity, ExportRelation, ExportTranscript, get_export_manager + EXPORT_AVAILABLE = True except ImportError: EXPORT_AVAILABLE = False @@ -100,6 +97,7 @@ except ImportError: # Phase 6: API Key Manager try: from api_key_manager import get_api_key_manager + API_KEY_AVAILABLE = True except ImportError as e: print(f"API Key Manager import error: {e}") @@ -108,6 +106,7 @@ except ImportError as e: # Phase 6: Rate Limiter try: from rate_limiter import RateLimitConfig, get_rate_limiter + RATE_LIMITER_AVAILABLE = True except ImportError as e: print(f"Rate Limiter import error: {e}") @@ -116,6 +115,7 @@ except ImportError as e: # Phase 7: Workflow Manager try: from workflow_manager import WebhookConfig, Workflow + WORKFLOW_AVAILABLE = True except ImportError as e: print(f"Workflow Manager import error: {e}") @@ -124,6 +124,7 @@ except ImportError as e: # Phase 7: Multimodal Support try: from multimodal_processor import get_multimodal_processor + MULTIMODAL_AVAILABLE = True except ImportError as e: print(f"Multimodal Processor import error: {e}") @@ -131,6 +132,7 @@ except ImportError as e: try: from image_processor import get_image_processor + IMAGE_PROCESSOR_AVAILABLE = True except ImportError as e: print(f"Image Processor import error: {e}") @@ -138,6 +140,7 @@ except ImportError as e: try: from multimodal_entity_linker import EntityLink, get_multimodal_entity_linker + MULTIMODAL_LINKER_AVAILABLE = True except ImportError as e: print(f"Multimodal Entity Linker import error: {e}") @@ -145,14 +148,8 @@ except ImportError as e: # Phase 7 Task 7: Plugin Manager try: - from plugin_manager import ( - BotHandler, - Plugin, - PluginStatus, - PluginType, - WebhookIntegration, - get_plugin_manager, - ) + from plugin_manager import BotHandler, Plugin, PluginStatus, PluginType, WebhookIntegration, get_plugin_manager + PLUGIN_MANAGER_AVAILABLE = True except ImportError as e: print(f"Plugin Manager import error: {e}") @@ -161,6 +158,7 @@ except ImportError as e: # Phase 7 Task 3: Security Manager try: from security_manager import MaskingRuleType, get_security_manager + SECURITY_MANAGER_AVAILABLE = True except ImportError as e: print(f"Security Manager import error: {e}") @@ -169,6 +167,7 @@ except ImportError as e: # Phase 7 Task 4: Collaboration Manager try: from collaboration_manager import get_collaboration_manager + COLLABORATION_AVAILABLE = True except ImportError as e: print(f"Collaboration Manager import error: {e}") @@ -184,6 +183,7 @@ except ImportError as e: # Phase 7 Task 6: Search Manager try: from search_manager import SearchOperator, get_search_manager + SEARCH_MANAGER_AVAILABLE = True except ImportError as e: print(f"Search Manager import error: {e}") @@ -192,6 +192,7 @@ except ImportError as e: # Phase 7 Task 8: Performance Manager try: from performance_manager import get_performance_manager + PERFORMANCE_MANAGER_AVAILABLE = True except ImportError as e: print(f"Performance Manager import error: {e}") @@ -200,6 +201,7 @@ except ImportError as e: # Phase 8: Tenant Manager (Multi-Tenant SaaS) try: from tenant_manager import TenantRole, TenantStatus, TenantTier, get_tenant_manager + TENANT_MANAGER_AVAILABLE = True except ImportError as e: print(f"Tenant Manager import error: {e}") @@ -208,6 +210,7 @@ except ImportError as e: # Phase 8: Subscription Manager try: from subscription_manager import get_subscription_manager + SUBSCRIPTION_MANAGER_AVAILABLE = True except ImportError as e: print(f"Subscription Manager import error: {e}") @@ -216,6 +219,7 @@ except ImportError as e: # Phase 8: Enterprise Manager try: from enterprise_manager import get_enterprise_manager + ENTERPRISE_MANAGER_AVAILABLE = True except ImportError as e: print(f"Enterprise Manager import error: {e}") @@ -224,6 +228,7 @@ except ImportError as e: # Phase 8: Localization Manager try: from localization_manager import get_localization_manager + LOCALIZATION_MANAGER_AVAILABLE = True except ImportError as e: print(f"Localization Manager import error: {e}") @@ -231,13 +236,8 @@ except ImportError as e: # Phase 8 Task 4: AI Manager try: - from ai_manager import ( - ModelStatus, - ModelType, - MultimodalProvider, - PredictionType, - get_ai_manager, - ) + from ai_manager import ModelStatus, ModelType, MultimodalProvider, PredictionType, get_ai_manager + AI_MANAGER_AVAILABLE = True except ImportError as e: print(f"AI Manager import error: {e}") @@ -253,6 +253,7 @@ try: TrafficAllocationType, WorkflowTriggerType, ) + GROWTH_MANAGER_AVAILABLE = True except ImportError as e: print(f"Growth Manager import error: {e}") @@ -260,14 +261,8 @@ except ImportError as e: # Phase 8 Task 8: Operations & Monitoring Manager try: - from ops_manager import ( - AlertChannelType, - AlertRuleType, - AlertSeverity, - AlertStatus, - ResourceType, - get_ops_manager, - ) + from ops_manager import AlertChannelType, AlertRuleType, AlertSeverity, AlertStatus, ResourceType, get_ops_manager + OPS_MANAGER_AVAILABLE = True except ImportError as e: print(f"Ops Manager import error: {e}") @@ -330,9 +325,12 @@ app = FastAPI( {"name": "Localization", "description": "全球化与本地化(多语言、数据中心、支付方式、时区日历)"}, {"name": "AI Enhancement", "description": "AI 能力增强(自定义模型、多模态分析、智能摘要、预测分析)"}, {"name": "Growth & Analytics", "description": "运营与增长工具(用户行为分析、A/B 测试、邮件营销、推荐系统)"}, - {"name": "Operations & Monitoring", "description": "运维与监控(实时告警、容量规划、自动扩缩容、灾备故障转移、成本优化)"}, + { + "name": "Operations & Monitoring", + "description": "运维与监控(实时告警、容量规划、自动扩缩容、灾备故障转移、成本优化)", + }, {"name": "System", "description": "系统信息"}, - ] + ], ) app.add_middleware( @@ -347,8 +345,12 @@ app.add_middleware( # 公开访问的路径(不需要 API Key) PUBLIC_PATHS = { - "/", "/docs", "/openapi.json", "/redoc", - "/api/v1/health", "/api/v1/status", + "/", + "/docs", + "/openapi.json", + "/redoc", + "/api/v1/health", + "/api/v1/status", "/api/v1/api-keys", # POST 创建 API Key 不需要认证 } @@ -384,8 +386,7 @@ async def verify_api_key(request: Request, x_api_key: str | None = Header(None, if any(path.startswith(p) for p in ADMIN_PATHS): if not x_api_key or x_api_key != MASTER_KEY: raise HTTPException( - status_code=403, - detail="Admin access required. Provide valid master key in X-API-Key header." + status_code=403, detail="Admin access required. Provide valid master key in X-API-Key header." ) return {"type": "admin", "key": x_api_key} @@ -398,7 +399,7 @@ async def verify_api_key(request: Request, x_api_key: str | None = Header(None, raise HTTPException( status_code=401, detail="API Key required. Provide your key in X-API-Key header.", - headers={"WWW-Authenticate": "ApiKey"} + headers={"WWW-Authenticate": "ApiKey"}, ) # 验证 API Key @@ -406,10 +407,7 @@ async def verify_api_key(request: Request, x_api_key: str | None = Header(None, api_key = key_manager.validate_key(x_api_key) if not api_key: - raise HTTPException( - status_code=401, - detail="Invalid or expired API Key" - ) + raise HTTPException(status_code=401, detail="Invalid or expired API Key") # 更新最后使用时间 key_manager.update_last_used(api_key.id) @@ -445,7 +443,7 @@ async def rate_limit_middleware(request: Request, call_next): # Master key 有更高的限流 config = RateLimitConfig(requests_per_minute=1000) limit_key = f"master:{x_api_key[:16]}" - elif hasattr(request.state, 'api_key') and request.state.api_key: + elif hasattr(request.state, "api_key") and request.state.api_key: # 使用 API Key 的限流配置 api_key = request.state.api_key config = RateLimitConfig(requests_per_minute=api_key.rate_limit) @@ -466,14 +464,14 @@ async def rate_limit_middleware(request: Request, call_next): "error": "Rate limit exceeded", "retry_after": info.retry_after, "limit": config.requests_per_minute, - "window": "minute" + "window": "minute", }, headers={ "X-RateLimit-Limit": str(config.requests_per_minute), "X-RateLimit-Remaining": "0", "X-RateLimit-Reset": str(info.reset_time), - "Retry-After": str(info.retry_after) - } + "Retry-After": str(info.retry_after), + }, ) # 继续处理请求 @@ -487,7 +485,7 @@ async def rate_limit_middleware(request: Request, call_next): # 记录 API 调用日志 try: - if hasattr(request.state, 'api_key') and request.state.api_key: + if hasattr(request.state, "api_key") and request.state.api_key: api_key = request.state.api_key response_time = int((time.time() - start_time) * 1000) key_manager = get_api_key_manager() @@ -498,7 +496,7 @@ async def rate_limit_middleware(request: Request, call_next): status_code=response.status_code, response_time_ms=response_time, ip_address=request.client.host if request.client else "", - user_agent=request.headers.get("User-Agent", "") + user_agent=request.headers.get("User-Agent", ""), ) except Exception as e: # 日志记录失败不应影响主流程 @@ -659,11 +657,13 @@ class GlossaryTermCreate(BaseModel): # ==================== Phase 7: Workflow Pydantic Models ==================== + class WorkflowCreate(BaseModel): name: str = Field(..., description="工作流名称") description: str = Field(default="", description="工作流描述") - workflow_type: str = Field(..., - description="工作流类型: auto_analyze, auto_align, auto_relation, scheduled_report, custom") + workflow_type: str = Field( + ..., description="工作流类型: auto_analyze, auto_align, auto_relation, scheduled_report, custom" + ) project_id: str = Field(..., description="所属项目ID") schedule: str | None = Field(default=None, description="调度表达式(cron或分钟数)") schedule_type: str = Field(default="manual", description="调度类型: manual, cron, interval") @@ -861,6 +861,7 @@ def get_collab_manager(): _collaboration_manager = get_collaboration_manager(db) return _collaboration_manager + # Phase 2: Entity Edit API @@ -884,7 +885,7 @@ async def update_entity(entity_id: str, update: EntityUpdate, _=Depends(verify_a "name": updated.name, "type": updated.type, "definition": updated.definition, - "aliases": updated.aliases + "aliases": updated.aliases, } @@ -926,10 +927,11 @@ async def merge_entities_endpoint(entity_id: str, merge_req: EntityMergeRequest, "name": result.name, "type": result.type, "definition": result.definition, - "aliases": result.aliases - } + "aliases": result.aliases, + }, } + # Phase 2: Relation Edit API @@ -953,7 +955,7 @@ async def create_relation_endpoint(project_id: str, relation: RelationCreate, _= source_entity_id=relation.source_entity_id, target_entity_id=relation.target_entity_id, relation_type=relation.relation_type, - evidence=relation.evidence + evidence=relation.evidence, ) return { @@ -961,7 +963,7 @@ async def create_relation_endpoint(project_id: str, relation: RelationCreate, _= "source_id": relation.source_entity_id, "target_id": relation.target_entity_id, "type": relation.relation_type, - "success": True + "success": True, } @@ -984,17 +986,11 @@ async def update_relation(relation_id: str, relation: RelationCreate, _=Depends( db = get_db_manager() updated = db.update_relation( - relation_id=relation_id, - relation_type=relation.relation_type, - evidence=relation.evidence + relation_id=relation_id, relation_type=relation.relation_type, evidence=relation.evidence ) - return { - "id": relation_id, - "type": updated["relation_type"], - "evidence": updated["evidence"], - "success": True - } + return {"id": relation_id, "type": updated["relation_type"], "evidence": updated["evidence"], "success": True} + # Phase 2: Transcript Edit API @@ -1031,9 +1027,10 @@ async def update_transcript(transcript_id: str, update: TranscriptUpdate, _=Depe "id": transcript_id, "full_text": updated["full_text"], "updated_at": updated["updated_at"], - "success": True + "success": True, } + # Phase 2: Manual Entity Creation @@ -1057,21 +1054,12 @@ async def create_manual_entity(project_id: str, entity: ManualEntityCreate, _=De # 检查是否已存在 existing = db.get_entity_by_name(project_id, entity.name) if existing: - return { - "id": existing.id, - "name": existing.name, - "type": existing.type, - "existed": True - } + return {"id": existing.id, "name": existing.name, "type": existing.type, "existed": True} entity_id = str(uuid.uuid4())[:8] - new_entity = db.create_entity(Entity( - id=entity_id, - project_id=project_id, - name=entity.name, - type=entity.type, - definition=entity.definition - )) + new_entity = db.create_entity( + Entity(id=entity_id, project_id=project_id, name=entity.name, type=entity.type, definition=entity.definition) + ) # 如果有提及位置信息,保存提及 if entity.transcript_id and entity.start_pos is not None and entity.end_pos is not None: @@ -1084,8 +1072,8 @@ async def create_manual_entity(project_id: str, entity: ManualEntityCreate, _=De transcript_id=entity.transcript_id, start_pos=entity.start_pos, end_pos=entity.end_pos, - text_snippet=text[max(0, entity.start_pos - 20):min(len(text), entity.end_pos + 20)], - confidence=1.0 + text_snippet=text[max(0, entity.start_pos - 20) : min(len(text), entity.end_pos + 20)], + confidence=1.0, ) db.add_mention(mention) @@ -1094,7 +1082,7 @@ async def create_manual_entity(project_id: str, entity: ManualEntityCreate, _=De "name": new_entity.name, "type": new_entity.type, "definition": new_entity.definition, - "success": True + "success": True, } @@ -1134,8 +1122,13 @@ def mock_transcribe() -> dict: return { "full_text": "我们今天讨论 Project Alpha 的进度,K8s 集群已经部署完成。", "segments": [ - {"start": 0.0, "end": 5.0, "text": "我们今天讨论 Project Alpha 的进度,K8s 集群已经部署完成。", "speaker": "Speaker A"} - ] + { + "start": 0.0, + "end": 5.0, + "text": "我们今天讨论 Project Alpha 的进度,K8s 集群已经部署完成。", + "speaker": "Speaker A", + } + ], } @@ -1174,14 +1167,15 @@ def extract_entities_with_llm(text: str) -> tuple[list[dict], list[dict]]: f"{KIMI_BASE_URL}/v1/chat/completions", headers={"Authorization": f"Bearer {KIMI_API_KEY}", "Content-Type": "application/json"}, json={"model": "k2p5", "messages": [{"role": "user", "content": prompt}], "temperature": 0.1}, - timeout=60.0 + timeout=60.0, ) response.raise_for_status() result = response.json() content = result["choices"][0]["message"]["content"] import re - json_match = re.search(r'\{{.*?\}}', content, re.DOTALL) + + json_match = re.search(r"\{{.*?\}}", content, re.DOTALL) if json_match: data = json.loads(json_match.group()) return data.get("entities", []), data.get("relations", []) @@ -1191,7 +1185,7 @@ def extract_entities_with_llm(text: str) -> tuple[list[dict], list[dict]]: return [], [] -def align_entity(project_id: str, name: str, db, definition: str = "") -> Optional['Entity']: +def align_entity(project_id: str, name: str, db, definition: str = "") -> Optional["Entity"]: """实体对齐 - Phase 3: 使用 embedding 对齐""" # 1. 首先尝试精确匹配 existing = db.get_entity_by_name(project_id, name) @@ -1212,6 +1206,7 @@ def align_entity(project_id: str, name: str, db, definition: str = "") -> Option return None + # API Endpoints @@ -1262,10 +1257,7 @@ async def upload_audio(project_id: str, file: UploadFile = File(...), _=Depends( # 保存转录记录 transcript_id = str(uuid.uuid4())[:8] db.save_transcript( - transcript_id=transcript_id, - project_id=project_id, - filename=file.filename, - full_text=tw_result["full_text"] + transcript_id=transcript_id, project_id=project_id, filename=file.filename, full_text=tw_result["full_text"] ) # 实体对齐并保存 - Phase 3: 使用增强对齐 @@ -1281,23 +1273,20 @@ async def upload_audio(project_id: str, file: UploadFile = File(...), _=Depends( name=existing.name, type=existing.type, definition=existing.definition, - aliases=existing.aliases + aliases=existing.aliases, ) entity_name_to_id[raw_ent["name"]] = existing.id else: - new_ent = db.create_entity(Entity( - id=str(uuid.uuid4())[:8], - project_id=project_id, - name=raw_ent["name"], - type=raw_ent.get("type", "OTHER"), - definition=raw_ent.get("definition", "") - )) - ent_model = EntityModel( - id=new_ent.id, - name=new_ent.name, - type=new_ent.type, - definition=new_ent.definition + new_ent = db.create_entity( + Entity( + id=str(uuid.uuid4())[:8], + project_id=project_id, + name=raw_ent["name"], + type=raw_ent.get("type", "OTHER"), + definition=raw_ent.get("definition", ""), + ) ) + ent_model = EntityModel(id=new_ent.id, name=new_ent.name, type=new_ent.type, definition=new_ent.definition) entity_name_to_id[raw_ent["name"]] = new_ent.id aligned_entities.append(ent_model) @@ -1316,8 +1305,8 @@ async def upload_audio(project_id: str, file: UploadFile = File(...), _=Depends( transcript_id=transcript_id, start_pos=pos, end_pos=pos + len(name), - text_snippet=full_text[max(0, pos - 20):min(len(full_text), pos + len(name) + 20)], - confidence=1.0 + text_snippet=full_text[max(0, pos - 20) : min(len(full_text), pos + len(name) + 20)], + confidence=1.0, ) db.add_mention(mention) start_pos = pos + 1 @@ -1333,7 +1322,7 @@ async def upload_audio(project_id: str, file: UploadFile = File(...), _=Depends( target_entity_id=target_id, relation_type=rel.get("type", "related"), evidence=tw_result["full_text"][:200], - transcript_id=transcript_id + transcript_id=transcript_id, ) # 构建片段 @@ -1345,9 +1334,10 @@ async def upload_audio(project_id: str, file: UploadFile = File(...), _=Depends( segments=segments, entities=aligned_entities, full_text=tw_result["full_text"], - created_at=datetime.now().isoformat() + created_at=datetime.now().isoformat(), ) + # Phase 3: Document Upload API @@ -1381,7 +1371,7 @@ async def upload_document(project_id: str, file: UploadFile = File(...), _=Depen project_id=project_id, filename=file.filename, full_text=result["text"], - transcript_type="document" + transcript_type="document", ) # 提取实体和关系 @@ -1396,28 +1386,29 @@ async def upload_document(project_id: str, file: UploadFile = File(...), _=Depen if existing: entity_name_to_id[raw_ent["name"]] = existing.id - aligned_entities.append(EntityModel( - id=existing.id, - name=existing.name, - type=existing.type, - definition=existing.definition, - aliases=existing.aliases - )) + aligned_entities.append( + EntityModel( + id=existing.id, + name=existing.name, + type=existing.type, + definition=existing.definition, + aliases=existing.aliases, + ) + ) else: - new_ent = db.create_entity(Entity( - id=str(uuid.uuid4())[:8], - project_id=project_id, - name=raw_ent["name"], - type=raw_ent.get("type", "OTHER"), - definition=raw_ent.get("definition", "") - )) + new_ent = db.create_entity( + Entity( + id=str(uuid.uuid4())[:8], + project_id=project_id, + name=raw_ent["name"], + type=raw_ent.get("type", "OTHER"), + definition=raw_ent.get("definition", ""), + ) + ) entity_name_to_id[raw_ent["name"]] = new_ent.id - aligned_entities.append(EntityModel( - id=new_ent.id, - name=new_ent.name, - type=new_ent.type, - definition=new_ent.definition - )) + aligned_entities.append( + EntityModel(id=new_ent.id, name=new_ent.name, type=new_ent.type, definition=new_ent.definition) + ) # 保存实体提及位置 full_text = result["text"] @@ -1433,8 +1424,8 @@ async def upload_document(project_id: str, file: UploadFile = File(...), _=Depen transcript_id=transcript_id, start_pos=pos, end_pos=pos + len(name), - text_snippet=full_text[max(0, pos - 20):min(len(full_text), pos + len(name) + 20)], - confidence=1.0 + text_snippet=full_text[max(0, pos - 20) : min(len(full_text), pos + len(name) + 20)], + confidence=1.0, ) db.add_mention(mention) start_pos = pos + 1 @@ -1450,7 +1441,7 @@ async def upload_document(project_id: str, file: UploadFile = File(...), _=Depen target_entity_id=target_id, relation_type=rel.get("type", "related"), evidence=result["text"][:200], - transcript_id=transcript_id + transcript_id=transcript_id, ) return { @@ -1459,9 +1450,10 @@ async def upload_document(project_id: str, file: UploadFile = File(...), _=Depen "filename": file.filename, "text_length": len(result["text"]), "entities": [e.dict() for e in aligned_entities], - "created_at": datetime.now().isoformat() + "created_at": datetime.now().isoformat(), } + # Phase 3: Knowledge Base API @@ -1495,7 +1487,7 @@ async def get_knowledge_base(project_id: str, _=Depends(verify_api_key)): mentions = db.get_entity_mentions(ent.id) entity_stats[ent.id] = { "mention_count": len(mentions), - "transcript_ids": list(set([m.transcript_id for m in mentions])) + "transcript_ids": list(set([m.transcript_id for m in mentions])), } # Phase 5: 获取实体属性 attrs = db.get_entity_attributes(ent.id) @@ -1505,16 +1497,12 @@ async def get_knowledge_base(project_id: str, _=Depends(verify_api_key)): entity_map = {e.id: e.name for e in entities} return { - "project": { - "id": project.id, - "name": project.name, - "description": project.description - }, + "project": {"id": project.id, "name": project.name, "description": project.description}, "stats": { "entity_count": len(entities), "relation_count": len(relations), "transcript_count": len(transcripts), - "glossary_count": len(glossary) + "glossary_count": len(glossary), }, "entities": [ { @@ -1525,7 +1513,7 @@ async def get_knowledge_base(project_id: str, _=Depends(verify_api_key)): "aliases": e.aliases, "mention_count": entity_stats.get(e.id, {}).get("mention_count", 0), "appears_in": entity_stats.get(e.id, {}).get("transcript_ids", []), - "attributes": entity_attributes.get(e.id, []) # Phase 5: 包含属性 + "attributes": entity_attributes.get(e.id, []), # Phase 5: 包含属性 } for e in entities ], @@ -1537,30 +1525,21 @@ async def get_knowledge_base(project_id: str, _=Depends(verify_api_key)): "target_id": r["target_entity_id"], "target_name": entity_map.get(r["target_entity_id"], "Unknown"), "type": r["relation_type"], - "evidence": r["evidence"] + "evidence": r["evidence"], } for r in relations ], "glossary": [ - { - "id": g["id"], - "term": g["term"], - "pronunciation": g["pronunciation"], - "frequency": g["frequency"] - } + {"id": g["id"], "term": g["term"], "pronunciation": g["pronunciation"], "frequency": g["frequency"]} for g in glossary ], "transcripts": [ - { - "id": t["id"], - "filename": t["filename"], - "type": t.get("type", "audio"), - "created_at": t["created_at"] - } + {"id": t["id"], "filename": t["filename"], "type": t.get("type", "audio"), "created_at": t["created_at"]} for t in transcripts - ] + ], } + # Phase 3: Glossary API @@ -1575,18 +1554,9 @@ async def add_glossary_term(project_id: str, term: GlossaryTermCreate, _=Depends if not project: raise HTTPException(status_code=404, detail="Project not found") - term_id = db.add_glossary_term( - project_id=project_id, - term=term.term, - pronunciation=term.pronunciation - ) + term_id = db.add_glossary_term(project_id=project_id, term=term.term, pronunciation=term.pronunciation) - return { - "id": term_id, - "term": term.term, - "pronunciation": term.pronunciation, - "success": True - } + return {"id": term_id, "term": term.term, "pronunciation": term.pronunciation, "success": True} @app.get("/api/v1/projects/{project_id}/glossary") @@ -1610,6 +1580,7 @@ async def delete_glossary_term(term_id: str, _=Depends(verify_api_key)): db.delete_glossary_term(term_id) return {"success": True} + # Phase 3: Entity Alignment API @@ -1637,27 +1608,16 @@ async def align_project_entities(project_id: str, threshold: float = 0.85, _=Dep continue similar = aligner.find_similar_entity( - project_id, - entity.name, - entity.definition, - exclude_id=entity.id, - threshold=threshold + project_id, entity.name, entity.definition, exclude_id=entity.id, threshold=threshold ) if similar: # 合并实体 db.merge_entities(similar.id, entity.id) merged_count += 1 - merged_pairs.append({ - "source": entity.name, - "target": similar.name - }) + merged_pairs.append({"source": entity.name, "target": similar.name}) - return { - "success": True, - "merged_count": merged_count, - "merged_pairs": merged_pairs - } + return {"success": True, "merged_count": merged_count, "merged_pairs": merged_pairs} @app.get("/api/v1/projects/{project_id}/entities") @@ -1668,8 +1628,9 @@ async def get_project_entities(project_id: str, _=Depends(verify_api_key)): db = get_db_manager() entities = db.list_project_entities(project_id) - return [{"id": e.id, "name": e.name, "type": e.type, "definition": e.definition, "aliases": e.aliases} - for e in entities] + return [ + {"id": e.id, "name": e.name, "type": e.type, "definition": e.definition, "aliases": e.aliases} for e in entities + ] @app.get("/api/v1/projects/{project_id}/relations") @@ -1685,15 +1646,18 @@ async def get_project_relations(project_id: str, _=Depends(verify_api_key)): entities = db.list_project_entities(project_id) entity_map = {e.id: e.name for e in entities} - return [{ - "id": r["id"], - "source_id": r["source_entity_id"], - "source_name": entity_map.get(r["source_entity_id"], "Unknown"), - "target_id": r["target_entity_id"], - "target_name": entity_map.get(r["target_entity_id"], "Unknown"), - "type": r["relation_type"], - "evidence": r["evidence"] - } for r in relations] + return [ + { + "id": r["id"], + "source_id": r["source_entity_id"], + "source_name": entity_map.get(r["source_entity_id"], "Unknown"), + "target_id": r["target_entity_id"], + "target_name": entity_map.get(r["target_entity_id"], "Unknown"), + "type": r["relation_type"], + "evidence": r["evidence"], + } + for r in relations + ] @app.get("/api/v1/projects/{project_id}/transcripts") @@ -1704,13 +1668,16 @@ async def get_project_transcripts(project_id: str, _=Depends(verify_api_key)): db = get_db_manager() transcripts = db.list_project_transcripts(project_id) - return [{ - "id": t["id"], - "filename": t["filename"], - "type": t.get("type", "audio"), - "created_at": t["created_at"], - "preview": t["full_text"][:100] + "..." if len(t["full_text"]) > 100 else t["full_text"] - } for t in transcripts] + return [ + { + "id": t["id"], + "filename": t["filename"], + "type": t.get("type", "audio"), + "created_at": t["created_at"], + "preview": t["full_text"][:100] + "..." if len(t["full_text"]) > 100 else t["full_text"], + } + for t in transcripts + ] @app.get("/api/v1/entities/{entity_id}/mentions") @@ -1721,14 +1688,18 @@ async def get_entity_mentions(entity_id: str, _=Depends(verify_api_key)): db = get_db_manager() mentions = db.get_entity_mentions(entity_id) - return [{ - "id": m.id, - "transcript_id": m.transcript_id, - "start_pos": m.start_pos, - "end_pos": m.end_pos, - "text_snippet": m.text_snippet, - "confidence": m.confidence - } for m in mentions] + return [ + { + "id": m.id, + "transcript_id": m.transcript_id, + "start_pos": m.start_pos, + "end_pos": m.end_pos, + "text_snippet": m.text_snippet, + "confidence": m.confidence, + } + for m in mentions + ] + # Health check - Legacy endpoint (deprecated, use /api/v1/health) @@ -1749,12 +1720,13 @@ async def legacy_health_check(): "multimodal_available": MULTIMODAL_AVAILABLE, "image_processor_available": IMAGE_PROCESSOR_AVAILABLE, "multimodal_linker_available": MULTIMODAL_LINKER_AVAILABLE, - "plugin_manager_available": PLUGIN_MANAGER_AVAILABLE + "plugin_manager_available": PLUGIN_MANAGER_AVAILABLE, } # ==================== Phase 4: Agent 助手 API ==================== + @app.post("/api/v1/projects/{project_id}/agent/query") async def agent_query(project_id: str, query: AgentQuery, _=Depends(verify_api_key)): """Agent RAG 问答""" @@ -1773,20 +1745,21 @@ async def agent_query(project_id: str, query: AgentQuery, _=Depends(verify_api_k # 构建上下文 context_parts = [] - for t in project_context.get('recent_transcripts', []): + for t in project_context.get("recent_transcripts", []): context_parts.append(f"【{t['filename']}】\n{t['full_text'][:1000]}") context = "\n\n".join(context_parts) if query.stream: - import json from fastapi.responses import StreamingResponse async def stream_response(): messages = [ ChatMessage(role="system", content="你是一个专业的项目分析助手,擅长从会议记录中提取洞察。"), - ChatMessage(role="user", content=f"""基于以下项目信息回答问题: + ChatMessage( + role="user", + content=f"""基于以下项目信息回答问题: ## 项目信息 {json.dumps(project_context, ensure_ascii=False, indent=2)} @@ -1797,7 +1770,8 @@ async def agent_query(project_id: str, query: AgentQuery, _=Depends(verify_api_k ## 用户问题 {query.query} -请用中文回答,保持简洁专业。如果信息不足,请明确说明。""") +请用中文回答,保持简洁专业。如果信息不足,请明确说明。""", + ), ] async for chunk in llm.chat_stream(messages): @@ -1896,7 +1870,9 @@ async def agent_command(project_id: str, command: AgentCommand, _=Depends(verify else: result["action"] = "none" - result["message"] = "无法理解的指令,请尝试:\n- 合并实体:把所有'客户端'合并到'App'\n- 提问:张总对项目的态度如何?\n- 编辑:修改'K8s'的定义为..." + result["message"] = ( + "无法理解的指令,请尝试:\n- 合并实体:把所有'客户端'合并到'App'\n- 提问:张总对项目的态度如何?\n- 编辑:修改'K8s'的定义为..." + ) return result @@ -1927,8 +1903,7 @@ async def agent_suggest(project_id: str, _=Depends(verify_api_key)): messages = [ChatMessage(role="user", content=prompt)] content = await llm.chat(messages, temperature=0.3) - import re - json_match = re.search(r'\{{.*?\}}', content, re.DOTALL) + json_match = re.search(r"\{{.*?\}}", content, re.DOTALL) if json_match: try: data = json.loads(json_match.group()) @@ -1941,6 +1916,7 @@ async def agent_suggest(project_id: str, _=Depends(verify_api_key)): # ==================== Phase 4: 知识溯源 API ==================== + @app.get("/api/v1/relations/{relation_id}/provenance") async def get_relation_provenance(relation_id: str, _=Depends(verify_api_key)): """获取关系的知识溯源信息""" @@ -1959,10 +1935,14 @@ async def get_relation_provenance(relation_id: str, _=Depends(verify_api_key)): "target": relation.get("target_name"), "type": relation.get("relation_type"), "evidence": relation.get("evidence"), - "transcript": { - "id": relation.get("transcript_id"), - "filename": relation.get("transcript_filename"), - } if relation.get("transcript_id") else None + "transcript": ( + { + "id": relation.get("transcript_id"), + "filename": relation.get("transcript_filename"), + } + if relation.get("transcript_id") + else None + ), } @@ -2007,15 +1987,16 @@ async def get_entity_evolution(entity_id: str, _=Depends(verify_api_key)): "date": m.get("transcript_date"), "snippet": m.get("text_snippet"), "transcript_id": m.get("transcript_id"), - "filename": m.get("filename") + "filename": m.get("filename"), } for m in entity.get("mentions", []) - ] + ], } # ==================== Phase 4: 实体管理增强 API ==================== + @app.get("/api/v1/projects/{project_id}/entities/search") async def search_entities(project_id: str, q: str, _=Depends(verify_api_key)): """搜索实体""" @@ -2029,13 +2010,10 @@ async def search_entities(project_id: str, q: str, _=Depends(verify_api_key)): # ==================== Phase 5: 时间线视图 API ==================== + @app.get("/api/v1/projects/{project_id}/timeline") async def get_project_timeline( - project_id: str, - entity_id: str = None, - start_date: str = None, - end_date: str = None, - _=Depends(verify_api_key) + project_id: str, entity_id: str = None, start_date: str = None, end_date: str = None, _=Depends(verify_api_key) ): """获取项目时间线 - 按时间顺序的实体提及和关系事件""" if not DB_AVAILABLE: @@ -2048,11 +2026,7 @@ async def get_project_timeline( timeline = db.get_project_timeline(project_id, entity_id, start_date, end_date) - return { - "project_id": project_id, - "events": timeline, - "total_count": len(timeline) - } + return {"project_id": project_id, "events": timeline, "total_count": len(timeline)} @app.get("/api/v1/projects/{project_id}/timeline/summary") @@ -2068,11 +2042,7 @@ async def get_timeline_summary(project_id: str, _=Depends(verify_api_key)): summary = db.get_entity_timeline_summary(project_id) - return { - "project_id": project_id, - "project_name": project.name, - **summary - } + return {"project_id": project_id, "project_name": project.name, **summary} @app.get("/api/v1/entities/{entity_id}/timeline") @@ -2093,12 +2063,13 @@ async def get_entity_timeline(entity_id: str, _=Depends(verify_api_key)): "entity_name": entity.name, "entity_type": entity.type, "events": timeline, - "total_count": len(timeline) + "total_count": len(timeline), } # ==================== Phase 5: 知识推理与问答增强 API ==================== + class ReasoningQuery(BaseModel): query: str reasoning_depth: str = "medium" # shallow/medium/deep @@ -2135,15 +2106,12 @@ async def reasoning_query(project_id: str, query: ReasoningQuery, _=Depends(veri graph_data = { "entities": [{"id": e.id, "name": e.name, "type": e.type, "definition": e.definition} for e in entities], - "relations": relations + "relations": relations, } # 执行增强问答 result = await reasoner.enhanced_qa( - query=query.query, - project_context=project_context, - graph_data=graph_data, - reasoning_depth=query.reasoning_depth + query=query.query, project_context=project_context, graph_data=graph_data, reasoning_depth=query.reasoning_depth ) return { @@ -2152,17 +2120,12 @@ async def reasoning_query(project_id: str, query: ReasoningQuery, _=Depends(veri "confidence": result.confidence, "evidence": result.evidence, "knowledge_gaps": result.gaps, - "project_id": project_id + "project_id": project_id, } @app.post("/api/v1/projects/{project_id}/reasoning/inference-path") -async def find_inference_path( - project_id: str, - start_entity: str, - end_entity: str, - _=Depends(verify_api_key) -): +async def find_inference_path(project_id: str, start_entity: str, end_entity: str, _=Depends(verify_api_key)): """ 发现两个实体之间的推理路径 @@ -2182,10 +2145,7 @@ async def find_inference_path( entities = db.list_project_entities(project_id) relations = db.list_project_relations(project_id) - graph_data = { - "entities": [{"id": e.id, "name": e.name, "type": e.type} for e in entities], - "relations": relations - } + graph_data = {"entities": [{"id": e.id, "name": e.name, "type": e.type} for e in entities], "relations": relations} # 查找推理路径 paths = reasoner.find_inference_paths(start_entity, end_entity, graph_data) @@ -2197,11 +2157,11 @@ async def find_inference_path( { "path": path.path, "strength": path.strength, - "path_description": " -> ".join([p["entity"] for p in path.path]) + "path_description": " -> ".join([p["entity"] for p in path.path]), } for path in paths[:5] # 最多返回5条路径 ], - "total_paths": len(paths) + "total_paths": len(paths), } @@ -2237,28 +2197,19 @@ async def project_summary(project_id: str, req: SummaryRequest, _=Depends(verify entities = db.list_project_entities(project_id) relations = db.list_project_relations(project_id) - graph_data = { - "entities": [{"id": e.id, "name": e.name, "type": e.type} for e in entities], - "relations": relations - } + graph_data = {"entities": [{"id": e.id, "name": e.name, "type": e.type} for e in entities], "relations": relations} # 生成总结 summary = await reasoner.summarize_project( - project_context=project_context, - graph_data=graph_data, - summary_type=req.summary_type + project_context=project_context, graph_data=graph_data, summary_type=req.summary_type ) - return { - "project_id": project_id, - "summary_type": req.summary_type, - **summary - ** summary - } + return {"project_id": project_id, "summary_type": req.summary_type, **summary**summary} # ==================== Phase 5: 实体属性扩展 API ==================== + class AttributeTemplateCreate(BaseModel): name: str type: str # text, number, date, select, multiselect, boolean @@ -2296,9 +2247,8 @@ class EntityAttributeBatchSet(BaseModel): # 属性模板管理 API @app.post("/api/v1/projects/{project_id}/attribute-templates") async def create_attribute_template_endpoint( - project_id: str, - template: AttributeTemplateCreate, - _=Depends(verify_api_key)): + project_id: str, template: AttributeTemplateCreate, _=Depends(verify_api_key) +): """创建属性模板""" if not DB_AVAILABLE: raise HTTPException(status_code=500, detail="Database not available") @@ -2319,17 +2269,12 @@ async def create_attribute_template_endpoint( default_value=template.default_value or "", description=template.description or "", is_required=template.is_required, - sort_order=template.sort_order + sort_order=template.sort_order, ) db.create_attribute_template(new_template) - return { - "id": new_template.id, - "name": new_template.name, - "type": new_template.type, - "success": True - } + return {"id": new_template.id, "name": new_template.name, "type": new_template.type, "success": True} @app.get("/api/v1/projects/{project_id}/attribute-templates") @@ -2350,7 +2295,7 @@ async def list_attribute_templates_endpoint(project_id: str, _=Depends(verify_ap "default_value": t.default_value, "description": t.description, "is_required": t.is_required, - "sort_order": t.sort_order + "sort_order": t.sort_order, } for t in templates ] @@ -2376,15 +2321,14 @@ async def get_attribute_template_endpoint(template_id: str, _=Depends(verify_api "default_value": template.default_value, "description": template.description, "is_required": template.is_required, - "sort_order": template.sort_order + "sort_order": template.sort_order, } @app.put("/api/v1/attribute-templates/{template_id}") async def update_attribute_template_endpoint( - template_id: str, - update: AttributeTemplateUpdate, - _=Depends(verify_api_key)): + template_id: str, update: AttributeTemplateUpdate, _=Depends(verify_api_key) +): """更新属性模板""" if not DB_AVAILABLE: raise HTTPException(status_code=500, detail="Database not available") @@ -2397,12 +2341,7 @@ async def update_attribute_template_endpoint( update_data = {k: v for k, v in update.dict().items() if v is not None} updated = db.update_attribute_template(template_id, **update_data) - return { - "id": updated.id, - "name": updated.name, - "type": updated.type, - "success": True - } + return {"id": updated.id, "name": updated.name, "type": updated.type, "success": True} @app.delete("/api/v1/attribute-templates/{template_id}") @@ -2430,13 +2369,13 @@ async def set_entity_attribute_endpoint(entity_id: str, attr: EntityAttributeSet raise HTTPException(status_code=404, detail="Entity not found") # 验证类型 - valid_types = ['text', 'number', 'date', 'select', 'multiselect'] + valid_types = ["text", "number", "date", "select", "multiselect"] if attr.type not in valid_types: raise HTTPException(status_code=400, detail=f"Invalid type. Must be one of: {valid_types}") # 处理 value value = attr.value - if attr.type == 'multiselect' and isinstance(value, list): + if attr.type == "multiselect" and isinstance(value, list): value = json.dumps(value) elif value is not None: value = str(value) @@ -2449,8 +2388,7 @@ async def set_entity_attribute_endpoint(entity_id: str, attr: EntityAttributeSet # 检查是否已存在 conn = db.get_conn() existing = conn.execute( - "SELECT * FROM entity_attributes WHERE entity_id = ? AND name = ?", - (entity_id, attr.name) + "SELECT * FROM entity_attributes WHERE entity_id = ? AND name = ?", (entity_id, attr.name) ).fetchone() now = datetime.now().isoformat() @@ -2461,8 +2399,16 @@ async def set_entity_attribute_endpoint(entity_id: str, attr: EntityAttributeSet """INSERT INTO attribute_history (id, entity_id, attribute_name, old_value, new_value, changed_by, changed_at, change_reason) VALUES (?, ?, ?, ?, ?, ?, ?, ?)""", - (str(uuid.uuid4())[:8], entity_id, attr.name, existing['value'], value, - "user", now, attr.change_reason or "") + ( + str(uuid.uuid4())[:8], + entity_id, + attr.name, + existing["value"], + value, + "user", + now, + attr.change_reason or "", + ), ) # 更新 @@ -2470,9 +2416,9 @@ async def set_entity_attribute_endpoint(entity_id: str, attr: EntityAttributeSet """UPDATE entity_attributes SET value = ?, type = ?, options = ?, updated_at = ? WHERE id = ?""", - (value, attr.type, options, now, existing['id']) + (value, attr.type, options, now, existing["id"]), ) - attr_id = existing['id'] + attr_id = existing["id"] else: # 创建 attr_id = str(uuid.uuid4())[:8] @@ -2480,7 +2426,7 @@ async def set_entity_attribute_endpoint(entity_id: str, attr: EntityAttributeSet """INSERT INTO entity_attributes (id, entity_id, template_id, name, type, value, options, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)""", - (attr_id, entity_id, attr.template_id, attr.name, attr.type, value, options, now, now) + (attr_id, entity_id, attr.template_id, attr.name, attr.type, value, options, now, now), ) # 记录历史 @@ -2488,8 +2434,7 @@ async def set_entity_attribute_endpoint(entity_id: str, attr: EntityAttributeSet """INSERT INTO attribute_history (id, entity_id, attribute_name, old_value, new_value, changed_by, changed_at, change_reason) VALUES (?, ?, ?, ?, ?, ?, ?, ?)""", - (str(uuid.uuid4())[:8], entity_id, attr.name, None, value, - "user", now, attr.change_reason or "创建属性") + (str(uuid.uuid4())[:8], entity_id, attr.name, None, value, "user", now, attr.change_reason or "创建属性"), ) conn.commit() @@ -2501,15 +2446,14 @@ async def set_entity_attribute_endpoint(entity_id: str, attr: EntityAttributeSet "name": attr.name, "type": attr.type, "value": attr.value, - "success": True + "success": True, } @app.post("/api/v1/entities/{entity_id}/attributes/batch") async def batch_set_entity_attributes_endpoint( - entity_id: str, - batch: EntityAttributeBatchSet, - _=Depends(verify_api_key)): + entity_id: str, batch: EntityAttributeBatchSet, _=Depends(verify_api_key) +): """批量设置实体属性值""" if not DB_AVAILABLE: raise HTTPException(status_code=500, detail="Database not available") @@ -2526,25 +2470,14 @@ async def batch_set_entity_attributes_endpoint( template = db.get_attribute_template(attr_data.template_id) if template: new_attr = EntityAttribute( - id=str(uuid.uuid4())[:8], - entity_id=entity_id, - template_id=attr_data.template_id, - value=attr_data.value + id=str(uuid.uuid4())[:8], entity_id=entity_id, template_id=attr_data.template_id, value=attr_data.value + ) + db.set_entity_attribute(new_attr, changed_by="user", change_reason=batch.change_reason or "批量更新") + results.append( + {"template_id": attr_data.template_id, "template_name": template.name, "value": attr_data.value} ) - db.set_entity_attribute(new_attr, changed_by="user", - change_reason=batch.change_reason or "批量更新") - results.append({ - "template_id": attr_data.template_id, - "template_name": template.name, - "value": attr_data.value - }) - return { - "entity_id": entity_id, - "updated_count": len(results), - "attributes": results, - "success": True - } + return {"entity_id": entity_id, "updated_count": len(results), "attributes": results, "success": True} @app.get("/api/v1/entities/{entity_id}/attributes") @@ -2566,22 +2499,22 @@ async def get_entity_attributes_endpoint(entity_id: str, _=Depends(verify_api_ke "template_id": a.template_id, "template_name": a.template_name, "template_type": a.template_type, - "value": a.value + "value": a.value, } for a in attrs ] @app.delete("/api/v1/entities/{entity_id}/attributes/{template_id}") -async def delete_entity_attribute_endpoint(entity_id: str, template_id: str, - reason: str | None = "", _=Depends(verify_api_key)): +async def delete_entity_attribute_endpoint( + entity_id: str, template_id: str, reason: str | None = "", _=Depends(verify_api_key) +): """删除实体属性值""" if not DB_AVAILABLE: raise HTTPException(status_code=500, detail="Database not available") db = get_db_manager() - db.delete_entity_attribute(entity_id, template_id, - changed_by="user", change_reason=reason) + db.delete_entity_attribute(entity_id, template_id, changed_by="user", change_reason=reason) return {"success": True, "message": "Attribute deleted"} @@ -2604,7 +2537,7 @@ async def get_entity_attribute_history_endpoint(entity_id: str, limit: int = 50, "new_value": h.new_value, "changed_by": h.changed_by, "changed_at": h.changed_at, - "change_reason": h.change_reason + "change_reason": h.change_reason, } for h in history ] @@ -2628,7 +2561,7 @@ async def get_template_history_endpoint(template_id: str, limit: int = 50, _=Dep "new_value": h.new_value, "changed_by": h.changed_by, "changed_at": h.changed_at, - "change_reason": h.change_reason + "change_reason": h.change_reason, } for h in history ] @@ -2639,7 +2572,7 @@ async def get_template_history_endpoint(template_id: str, limit: int = 50, _=Dep async def search_entities_by_attributes_endpoint( project_id: str, attribute_filter: str | None = None, # JSON 格式: {"职位": "经理", "部门": "技术部"} - _=Depends(verify_api_key) + _=Depends(verify_api_key), ): """根据属性筛选搜索实体""" if not DB_AVAILABLE: @@ -2660,13 +2593,7 @@ async def search_entities_by_attributes_endpoint( entities = db.search_entities_by_attributes(project_id, filters) return [ - { - "id": e.id, - "name": e.name, - "type": e.type, - "definition": e.definition, - "attributes": e.attributes - } + {"id": e.id, "name": e.name, "type": e.type, "definition": e.definition, "attributes": e.attributes} for e in entities ] @@ -2693,34 +2620,38 @@ async def export_graph_svg_endpoint(project_id: str, _=Depends(verify_api_key)): entities = [] for e in entities_data: attrs = db.get_entity_attributes(e.id) - entities.append(ExportEntity( - id=e.id, - name=e.name, - type=e.type, - definition=e.definition or "", - aliases=json.loads(e.aliases) if e.aliases else [], - mention_count=e.mention_count, - attributes={a.template_name: a.value for a in attrs} - )) + entities.append( + ExportEntity( + id=e.id, + name=e.name, + type=e.type, + definition=e.definition or "", + aliases=json.loads(e.aliases) if e.aliases else [], + mention_count=e.mention_count, + attributes={a.template_name: a.value for a in attrs}, + ) + ) relations = [] for r in relations_data: - relations.append(ExportRelation( - id=r.id, - source=r.source_name, - target=r.target_name, - relation_type=r.relation_type, - confidence=r.confidence, - evidence=r.evidence or "" - )) + relations.append( + ExportRelation( + id=r.id, + source=r.source_name, + target=r.target_name, + relation_type=r.relation_type, + confidence=r.confidence, + evidence=r.evidence or "", + ) + ) export_mgr = get_export_manager() svg_content = export_mgr.export_knowledge_graph_svg(project_id, entities, relations) return StreamingResponse( - io.BytesIO(svg_content.encode('utf-8')), + io.BytesIO(svg_content.encode("utf-8")), media_type="image/svg+xml", - headers={"Content-Disposition": f"attachment; filename=insightflow-graph-{project_id}.svg"} + headers={"Content-Disposition": f"attachment; filename=insightflow-graph-{project_id}.svg"}, ) @@ -2743,26 +2674,30 @@ async def export_graph_png_endpoint(project_id: str, _=Depends(verify_api_key)): entities = [] for e in entities_data: attrs = db.get_entity_attributes(e.id) - entities.append(ExportEntity( - id=e.id, - name=e.name, - type=e.type, - definition=e.definition or "", - aliases=json.loads(e.aliases) if e.aliases else [], - mention_count=e.mention_count, - attributes={a.template_name: a.value for a in attrs} - )) + entities.append( + ExportEntity( + id=e.id, + name=e.name, + type=e.type, + definition=e.definition or "", + aliases=json.loads(e.aliases) if e.aliases else [], + mention_count=e.mention_count, + attributes={a.template_name: a.value for a in attrs}, + ) + ) relations = [] for r in relations_data: - relations.append(ExportRelation( - id=r.id, - source=r.source_name, - target=r.target_name, - relation_type=r.relation_type, - confidence=r.confidence, - evidence=r.evidence or "" - )) + relations.append( + ExportRelation( + id=r.id, + source=r.source_name, + target=r.target_name, + relation_type=r.relation_type, + confidence=r.confidence, + evidence=r.evidence or "", + ) + ) export_mgr = get_export_manager() png_bytes = export_mgr.export_knowledge_graph_png(project_id, entities, relations) @@ -2770,7 +2705,7 @@ async def export_graph_png_endpoint(project_id: str, _=Depends(verify_api_key)): return StreamingResponse( io.BytesIO(png_bytes), media_type="image/png", - headers={"Content-Disposition": f"attachment; filename=insightflow-graph-{project_id}.png"} + headers={"Content-Disposition": f"attachment; filename=insightflow-graph-{project_id}.png"}, ) @@ -2791,15 +2726,17 @@ async def export_entities_excel_endpoint(project_id: str, _=Depends(verify_api_k entities = [] for e in entities_data: attrs = db.get_entity_attributes(e.id) - entities.append(ExportEntity( - id=e.id, - name=e.name, - type=e.type, - definition=e.definition or "", - aliases=json.loads(e.aliases) if e.aliases else [], - mention_count=e.mention_count, - attributes={a.template_name: a.value for a in attrs} - )) + entities.append( + ExportEntity( + id=e.id, + name=e.name, + type=e.type, + definition=e.definition or "", + aliases=json.loads(e.aliases) if e.aliases else [], + mention_count=e.mention_count, + attributes={a.template_name: a.value for a in attrs}, + ) + ) export_mgr = get_export_manager() excel_bytes = export_mgr.export_entities_excel(entities) @@ -2807,7 +2744,7 @@ async def export_entities_excel_endpoint(project_id: str, _=Depends(verify_api_k return StreamingResponse( io.BytesIO(excel_bytes), media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", - headers={"Content-Disposition": f"attachment; filename=insightflow-entities-{project_id}.xlsx"} + headers={"Content-Disposition": f"attachment; filename=insightflow-entities-{project_id}.xlsx"}, ) @@ -2828,23 +2765,25 @@ async def export_entities_csv_endpoint(project_id: str, _=Depends(verify_api_key entities = [] for e in entities_data: attrs = db.get_entity_attributes(e.id) - entities.append(ExportEntity( - id=e.id, - name=e.name, - type=e.type, - definition=e.definition or "", - aliases=json.loads(e.aliases) if e.aliases else [], - mention_count=e.mention_count, - attributes={a.template_name: a.value for a in attrs} - )) + entities.append( + ExportEntity( + id=e.id, + name=e.name, + type=e.type, + definition=e.definition or "", + aliases=json.loads(e.aliases) if e.aliases else [], + mention_count=e.mention_count, + attributes={a.template_name: a.value for a in attrs}, + ) + ) export_mgr = get_export_manager() csv_content = export_mgr.export_entities_csv(entities) return StreamingResponse( - io.BytesIO(csv_content.encode('utf-8')), + io.BytesIO(csv_content.encode("utf-8")), media_type="text/csv", - headers={"Content-Disposition": f"attachment; filename=insightflow-entities-{project_id}.csv"} + headers={"Content-Disposition": f"attachment; filename=insightflow-entities-{project_id}.csv"}, ) @@ -2864,22 +2803,24 @@ async def export_relations_csv_endpoint(project_id: str, _=Depends(verify_api_ke relations = [] for r in relations_data: - relations.append(ExportRelation( - id=r.id, - source=r.source_name, - target=r.target_name, - relation_type=r.relation_type, - confidence=r.confidence, - evidence=r.evidence or "" - )) + relations.append( + ExportRelation( + id=r.id, + source=r.source_name, + target=r.target_name, + relation_type=r.relation_type, + confidence=r.confidence, + evidence=r.evidence or "", + ) + ) export_mgr = get_export_manager() csv_content = export_mgr.export_relations_csv(relations) return StreamingResponse( - io.BytesIO(csv_content.encode('utf-8')), + io.BytesIO(csv_content.encode("utf-8")), media_type="text/csv", - headers={"Content-Disposition": f"attachment; filename=insightflow-relations-{project_id}.csv"} + headers={"Content-Disposition": f"attachment; filename=insightflow-relations-{project_id}.csv"}, ) @@ -2903,38 +2844,39 @@ async def export_report_pdf_endpoint(project_id: str, _=Depends(verify_api_key)) entities = [] for e in entities_data: attrs = db.get_entity_attributes(e.id) - entities.append(ExportEntity( - id=e.id, - name=e.name, - type=e.type, - definition=e.definition or "", - aliases=json.loads(e.aliases) if e.aliases else [], - mention_count=e.mention_count, - attributes={a.template_name: a.value for a in attrs} - )) + entities.append( + ExportEntity( + id=e.id, + name=e.name, + type=e.type, + definition=e.definition or "", + aliases=json.loads(e.aliases) if e.aliases else [], + mention_count=e.mention_count, + attributes={a.template_name: a.value for a in attrs}, + ) + ) relations = [] for r in relations_data: - relations.append(ExportRelation( - id=r.id, - source=r.source_name, - target=r.target_name, - relation_type=r.relation_type, - confidence=r.confidence, - evidence=r.evidence or "" - )) + relations.append( + ExportRelation( + id=r.id, + source=r.source_name, + target=r.target_name, + relation_type=r.relation_type, + confidence=r.confidence, + evidence=r.evidence or "", + ) + ) transcripts = [] for t in transcripts_data: segments = json.loads(t.segments) if t.segments else [] - transcripts.append(ExportTranscript( - id=t.id, - name=t.name, - type=t.type, - content=t.full_text or "", - segments=segments, - entity_mentions=[] - )) + transcripts.append( + ExportTranscript( + id=t.id, name=t.name, type=t.type, content=t.full_text or "", segments=segments, entity_mentions=[] + ) + ) # 获取项目总结 summary = "" @@ -2954,7 +2896,7 @@ async def export_report_pdf_endpoint(project_id: str, _=Depends(verify_api_key)) return StreamingResponse( io.BytesIO(pdf_bytes), media_type="application/pdf", - headers={"Content-Disposition": f"attachment; filename=insightflow-report-{project_id}.pdf"} + headers={"Content-Disposition": f"attachment; filename=insightflow-report-{project_id}.pdf"}, ) @@ -2978,48 +2920,47 @@ async def export_project_json_endpoint(project_id: str, _=Depends(verify_api_key entities = [] for e in entities_data: attrs = db.get_entity_attributes(e.id) - entities.append(ExportEntity( - id=e.id, - name=e.name, - type=e.type, - definition=e.definition or "", - aliases=json.loads(e.aliases) if e.aliases else [], - mention_count=e.mention_count, - attributes={a.template_name: a.value for a in attrs} - )) + entities.append( + ExportEntity( + id=e.id, + name=e.name, + type=e.type, + definition=e.definition or "", + aliases=json.loads(e.aliases) if e.aliases else [], + mention_count=e.mention_count, + attributes={a.template_name: a.value for a in attrs}, + ) + ) relations = [] for r in relations_data: - relations.append(ExportRelation( - id=r.id, - source=r.source_name, - target=r.target_name, - relation_type=r.relation_type, - confidence=r.confidence, - evidence=r.evidence or "" - )) + relations.append( + ExportRelation( + id=r.id, + source=r.source_name, + target=r.target_name, + relation_type=r.relation_type, + confidence=r.confidence, + evidence=r.evidence or "", + ) + ) transcripts = [] for t in transcripts_data: segments = json.loads(t.segments) if t.segments else [] - transcripts.append(ExportTranscript( - id=t.id, - name=t.name, - type=t.type, - content=t.full_text or "", - segments=segments, - entity_mentions=[] - )) + transcripts.append( + ExportTranscript( + id=t.id, name=t.name, type=t.type, content=t.full_text or "", segments=segments, entity_mentions=[] + ) + ) export_mgr = get_export_manager() - json_content = export_mgr.export_project_json( - project_id, project.name, entities, relations, transcripts - ) + json_content = export_mgr.export_project_json(project_id, project.name, entities, relations, transcripts) return StreamingResponse( - io.BytesIO(json_content.encode('utf-8')), + io.BytesIO(json_content.encode("utf-8")), media_type="application/json", - headers={"Content-Disposition": f"attachment; filename=insightflow-project-{project_id}.json"} + headers={"Content-Disposition": f"attachment; filename=insightflow-project-{project_id}.json"}, ) @@ -3039,15 +2980,18 @@ async def export_transcript_markdown_endpoint(transcript_id: str, _=Depends(veri # 获取项目实体用于映射 entities_data = db.get_project_entities(transcript.project_id) - entities_map = {e.id: ExportEntity( - id=e.id, - name=e.name, - type=e.type, - definition=e.definition or "", - aliases=json.loads(e.aliases) if e.aliases else [], - mention_count=e.mention_count, - attributes={} - ) for e in entities_data} + entities_map = { + e.id: ExportEntity( + id=e.id, + name=e.name, + type=e.type, + definition=e.definition or "", + aliases=json.loads(e.aliases) if e.aliases else [], + mention_count=e.mention_count, + attributes={}, + ) + for e in entities_data + } segments = json.loads(transcript.segments) if transcript.segments else [] @@ -3057,26 +3001,25 @@ async def export_transcript_markdown_endpoint(transcript_id: str, _=Depends(veri type=transcript.type, content=transcript.full_text or "", segments=segments, - entity_mentions=[{ - "entity_id": m.entity_id, - "entity_name": m.entity_name, - "position": m.position, - "context": m.context - } for m in mentions] + entity_mentions=[ + {"entity_id": m.entity_id, "entity_name": m.entity_name, "position": m.position, "context": m.context} + for m in mentions + ], ) export_mgr = get_export_manager() markdown_content = export_mgr.export_transcript_markdown(export_transcript, entities_map) return StreamingResponse( - io.BytesIO(markdown_content.encode('utf-8')), + io.BytesIO(markdown_content.encode("utf-8")), media_type="text/markdown", - headers={"Content-Disposition": f"attachment; filename=insightflow-transcript-{transcript_id}.md"} + headers={"Content-Disposition": f"attachment; filename=insightflow-transcript-{transcript_id}.md"}, ) # ==================== Neo4j Graph Database API ==================== + class Neo4jSyncRequest(BaseModel): project_id: str @@ -3096,11 +3039,7 @@ class GraphQueryRequest(BaseModel): async def neo4j_status(_=Depends(verify_api_key)): """获取 Neo4j 连接状态""" if not NEO4J_AVAILABLE: - return { - "available": False, - "connected": False, - "message": "Neo4j driver not installed" - } + return {"available": False, "connected": False, "message": "Neo4j driver not installed"} try: manager = get_neo4j_manager() @@ -3109,14 +3048,10 @@ async def neo4j_status(_=Depends(verify_api_key)): "available": True, "connected": connected, "uri": manager.uri if connected else None, - "message": "Connected" if connected else "Not connected" + "message": "Connected" if connected else "Not connected", } except Exception as e: - return { - "available": True, - "connected": False, - "message": str(e) - } + return {"available": True, "connected": False, "message": str(e)} @app.post("/api/v1/neo4j/sync") @@ -3141,34 +3076,35 @@ async def neo4j_sync_project(request: Neo4jSyncRequest, _=Depends(verify_api_key entities = db.get_project_entities(request.project_id) entities_data = [] for e in entities: - entities_data.append({ - "id": e.id, - "name": e.name, - "type": e.type, - "definition": e.definition, - "aliases": json.loads(e.aliases) if e.aliases else [], - "properties": e.attributes if hasattr(e, 'attributes') else {} - }) + entities_data.append( + { + "id": e.id, + "name": e.name, + "type": e.type, + "definition": e.definition, + "aliases": json.loads(e.aliases) if e.aliases else [], + "properties": e.attributes if hasattr(e, "attributes") else {}, + } + ) # 获取项目所有关系 relations = db.get_project_relations(request.project_id) relations_data = [] for r in relations: - relations_data.append({ - "id": r.id, - "source_entity_id": r.source_entity_id, - "target_entity_id": r.target_entity_id, - "relation_type": r.relation_type, - "evidence": r.evidence, - "properties": {} - }) + relations_data.append( + { + "id": r.id, + "source_entity_id": r.source_entity_id, + "target_entity_id": r.target_entity_id, + "relation_type": r.relation_type, + "evidence": r.evidence, + "properties": {}, + } + ) # 同步到 Neo4j sync_project_to_neo4j( - project_id=request.project_id, - project_name=project.name, - entities=entities_data, - relations=relations_data + project_id=request.project_id, project_name=project.name, entities=entities_data, relations=relations_data ) return { @@ -3176,7 +3112,7 @@ async def neo4j_sync_project(request: Neo4jSyncRequest, _=Depends(verify_api_key "project_id": request.project_id, "entities_synced": len(entities_data), "relations_synced": len(relations_data), - "message": f"Synced {len(entities_data)} entities and {len(relations_data)} relations to Neo4j" + "message": f"Synced {len(entities_data)} entities and {len(relations_data)} relations to Neo4j", } @@ -3204,26 +3140,12 @@ async def find_shortest_path(request: PathQueryRequest, _=Depends(verify_api_key if not manager.is_connected(): raise HTTPException(status_code=503, detail="Neo4j not connected") - path = manager.find_shortest_path( - request.source_entity_id, - request.target_entity_id, - request.max_depth - ) + path = manager.find_shortest_path(request.source_entity_id, request.target_entity_id, request.max_depth) if not path: - return { - "found": False, - "message": "No path found between entities" - } + return {"found": False, "message": "No path found between entities"} - return { - "found": True, - "path": { - "nodes": path.nodes, - "relationships": path.relationships, - "length": path.length - } - } + return {"found": True, "path": {"nodes": path.nodes, "relationships": path.relationships, "length": path.length}} @app.post("/api/v1/graph/paths") @@ -3236,32 +3158,16 @@ async def find_all_paths(request: PathQueryRequest, _=Depends(verify_api_key)): if not manager.is_connected(): raise HTTPException(status_code=503, detail="Neo4j not connected") - paths = manager.find_all_paths( - request.source_entity_id, - request.target_entity_id, - request.max_depth - ) + paths = manager.find_all_paths(request.source_entity_id, request.target_entity_id, request.max_depth) return { "count": len(paths), - "paths": [ - { - "nodes": p.nodes, - "relationships": p.relationships, - "length": p.length - } - for p in paths - ] + "paths": [{"nodes": p.nodes, "relationships": p.relationships, "length": p.length} for p in paths], } @app.get("/api/v1/entities/{entity_id}/neighbors") -async def get_entity_neighbors( - entity_id: str, - relation_type: str = None, - limit: int = 50, - _=Depends(verify_api_key) -): +async def get_entity_neighbors(entity_id: str, relation_type: str = None, limit: int = 50, _=Depends(verify_api_key)): """获取实体的邻居节点""" if not NEO4J_AVAILABLE: raise HTTPException(status_code=503, detail="Neo4j not available") @@ -3271,11 +3177,7 @@ async def get_entity_neighbors( raise HTTPException(status_code=503, detail="Neo4j not connected") neighbors = manager.find_neighbors(entity_id, relation_type, limit) - return { - "entity_id": entity_id, - "count": len(neighbors), - "neighbors": neighbors - } + return {"entity_id": entity_id, "count": len(neighbors), "neighbors": neighbors} @app.get("/api/v1/entities/{entity_id1}/common-neighbors/{entity_id2}") @@ -3289,20 +3191,11 @@ async def get_common_neighbors(entity_id1: str, entity_id2: str, _=Depends(verif raise HTTPException(status_code=503, detail="Neo4j not connected") common = manager.find_common_neighbors(entity_id1, entity_id2) - return { - "entity_id1": entity_id1, - "entity_id2": entity_id2, - "count": len(common), - "common_neighbors": common - } + return {"entity_id1": entity_id1, "entity_id2": entity_id2, "count": len(common), "common_neighbors": common} @app.get("/api/v1/projects/{project_id}/graph/centrality") -async def get_centrality_analysis( - project_id: str, - metric: str = "degree", - _=Depends(verify_api_key) -): +async def get_centrality_analysis(project_id: str, metric: str = "degree", _=Depends(verify_api_key)): """获取中心性分析结果""" if not NEO4J_AVAILABLE: raise HTTPException(status_code=503, detail="Neo4j not available") @@ -3316,14 +3209,8 @@ async def get_centrality_analysis( "metric": metric, "count": len(rankings), "rankings": [ - { - "entity_id": r.entity_id, - "entity_name": r.entity_name, - "score": r.score, - "rank": r.rank - } - for r in rankings - ] + {"entity_id": r.entity_id, "entity_name": r.entity_name, "score": r.score, "rank": r.rank} for r in rankings + ], } @@ -3341,14 +3228,9 @@ async def get_communities(project_id: str, _=Depends(verify_api_key)): return { "count": len(communities), "communities": [ - { - "community_id": c.community_id, - "size": c.size, - "density": c.density, - "nodes": c.nodes - } + {"community_id": c.community_id, "size": c.size, "density": c.density, "nodes": c.nodes} for c in communities - ] + ], } @@ -3368,6 +3250,7 @@ async def get_subgraph(request: GraphQueryRequest, _=Depends(verify_api_key)): # ==================== Phase 6: API Key Management Endpoints ==================== + @app.post("/api/v1/api-keys", response_model=ApiKeyCreateResponse, tags=["API Keys"]) async def create_api_key(request: ApiKeyCreate, _=Depends(verify_api_key)): """ @@ -3386,7 +3269,7 @@ async def create_api_key(request: ApiKeyCreate, _=Depends(verify_api_key)): name=request.name, permissions=request.permissions, rate_limit=request.rate_limit, - expires_days=request.expires_days + expires_days=request.expires_days, ) return ApiKeyCreateResponse( @@ -3401,18 +3284,13 @@ async def create_api_key(request: ApiKeyCreate, _=Depends(verify_api_key)): created_at=api_key.created_at, expires_at=api_key.expires_at, last_used_at=api_key.last_used_at, - total_calls=api_key.total_calls - ) + total_calls=api_key.total_calls, + ), ) @app.get("/api/v1/api-keys", response_model=ApiKeyListResponse, tags=["API Keys"]) -async def list_api_keys( - status: str | None = None, - limit: int = 100, - offset: int = 0, - _=Depends(verify_api_key) -): +async def list_api_keys(status: str | None = None, limit: int = 100, offset: int = 0, _=Depends(verify_api_key)): """ 列出所有 API Keys @@ -3438,11 +3316,11 @@ async def list_api_keys( created_at=k.created_at, expires_at=k.expires_at, last_used_at=k.last_used_at, - total_calls=k.total_calls + total_calls=k.total_calls, ) for k in keys ], - total=len(keys) + total=len(keys), ) @@ -3468,7 +3346,7 @@ async def get_api_key(key_id: str, _=Depends(verify_api_key)): created_at=key.created_at, expires_at=key.expires_at, last_used_at=key.last_used_at, - total_calls=key.total_calls + total_calls=key.total_calls, ) @@ -3513,7 +3391,7 @@ async def update_api_key(key_id: str, request: ApiKeyUpdate, _=Depends(verify_ap created_at=key.created_at, expires_at=key.expires_at, last_used_at=key.last_used_at, - total_calls=key.total_calls + total_calls=key.total_calls, ) @@ -3556,19 +3434,12 @@ async def get_api_key_stats(key_id: str, days: int = 30, _=Depends(verify_api_ke stats = key_manager.get_call_stats(key_id, days=days) return ApiStatsResponse( - summary=ApiCallStats(**stats["summary"]), - endpoints=stats["endpoints"], - daily=stats["daily"] + summary=ApiCallStats(**stats["summary"]), endpoints=stats["endpoints"], daily=stats["daily"] ) @app.get("/api/v1/api-keys/{key_id}/logs", response_model=ApiLogsResponse, tags=["API Keys"]) -async def get_api_key_logs( - key_id: str, - limit: int = 100, - offset: int = 0, - _=Depends(verify_api_key) -): +async def get_api_key_logs(key_id: str, limit: int = 100, offset: int = 0, _=Depends(verify_api_key)): """ 获取 API Key 的调用日志 @@ -3598,11 +3469,11 @@ async def get_api_key_logs( ip_address=log["ip_address"], user_agent=log["user_agent"], error_message=log["error_message"], - created_at=log["created_at"] + created_at=log["created_at"], ) for log in logs ], - total=len(logs) + total=len(logs), ) @@ -3610,17 +3481,12 @@ async def get_api_key_logs( async def get_rate_limit_status(request: Request, _=Depends(verify_api_key)): """获取当前请求的限流状态""" if not RATE_LIMITER_AVAILABLE: - return RateLimitStatus( - limit=60, - remaining=60, - reset_time=int(time.time()) + 60, - window="minute" - ) + return RateLimitStatus(limit=60, remaining=60, reset_time=int(time.time()) + 60, window="minute") limiter = get_rate_limiter() # 获取限流键 - if hasattr(request.state, 'api_key') and request.state.api_key: + if hasattr(request.state, "api_key") and request.state.api_key: api_key = request.state.api_key limit_key = f"api_key:{api_key.id}" limit = api_key.rate_limit @@ -3631,24 +3497,16 @@ async def get_rate_limit_status(request: Request, _=Depends(verify_api_key)): info = await limiter.get_limit_info(limit_key) - return RateLimitStatus( - limit=limit, - remaining=info.remaining, - reset_time=info.reset_time, - window="minute" - ) + return RateLimitStatus(limit=limit, remaining=info.remaining, reset_time=info.reset_time, window="minute") # ==================== Phase 6: System Endpoints ==================== + @app.get("/api/v1/health", tags=["System"]) async def api_health_check(): """健康检查端点""" - return { - "status": "healthy", - "version": "0.7.0", - "timestamp": datetime.now().isoformat() - } + return {"status": "healthy", "version": "0.7.0", "timestamp": datetime.now().isoformat()} @app.get("/api/v1/status", tags=["System"]) @@ -3675,7 +3533,7 @@ async def system_status(): "documentation": "/docs", "openapi": "/openapi.json", }, - "timestamp": datetime.now().isoformat() + "timestamp": datetime.now().isoformat(), } return status @@ -3691,6 +3549,7 @@ def get_workflow_manager_instance(): global _workflow_manager if _workflow_manager is None and WORKFLOW_AVAILABLE and DB_AVAILABLE: from workflow_manager import WorkflowManager + db = get_db_manager() _workflow_manager = WorkflowManager(db) _workflow_manager.start() @@ -3733,7 +3592,7 @@ async def create_workflow_endpoint(request: WorkflowCreate, _=Depends(verify_api schedule=request.schedule, schedule_type=request.schedule_type, config=request.config, - webhook_ids=request.webhook_ids + webhook_ids=request.webhook_ids, ) created = manager.create_workflow(workflow) @@ -3756,7 +3615,7 @@ async def create_workflow_endpoint(request: WorkflowCreate, _=Depends(verify_api next_run_at=created.next_run_at, run_count=created.run_count, success_count=created.success_count, - fail_count=created.fail_count + fail_count=created.fail_count, ) except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) @@ -3767,7 +3626,7 @@ async def list_workflows_endpoint( project_id: str | None = None, status: str | None = None, workflow_type: str | None = None, - _=Depends(verify_api_key) + _=Depends(verify_api_key), ): """获取工作流列表""" if not WORKFLOW_AVAILABLE: @@ -3796,11 +3655,11 @@ async def list_workflows_endpoint( next_run_at=w.next_run_at, run_count=w.run_count, success_count=w.success_count, - fail_count=w.fail_count + fail_count=w.fail_count, ) for w in workflows ], - total=len(workflows) + total=len(workflows), ) @@ -3834,7 +3693,7 @@ async def get_workflow_endpoint(workflow_id: str, _=Depends(verify_api_key)): next_run_at=workflow.next_run_at, run_count=workflow.run_count, success_count=workflow.success_count, - fail_count=workflow.fail_count + fail_count=workflow.fail_count, ) @@ -3870,7 +3729,7 @@ async def update_workflow_endpoint(workflow_id: str, request: WorkflowUpdate, _= next_run_at=updated.next_run_at, run_count=updated.run_count, success_count=updated.success_count, - fail_count=updated.fail_count + fail_count=updated.fail_count, ) @@ -3891,9 +3750,8 @@ async def delete_workflow_endpoint(workflow_id: str, _=Depends(verify_api_key)): @app.post("/api/v1/workflows/{workflow_id}/trigger", response_model=WorkflowTriggerResponse, tags=["Workflows"]) async def trigger_workflow_endpoint( - workflow_id: str, - request: WorkflowTriggerRequest = None, - _=Depends(verify_api_key)): + workflow_id: str, request: WorkflowTriggerRequest = None, _=Depends(verify_api_key) +): """手动触发工作流""" if not WORKFLOW_AVAILABLE: raise HTTPException(status_code=503, detail="Workflow automation not available") @@ -3901,17 +3759,14 @@ async def trigger_workflow_endpoint( manager = get_workflow_manager_instance() try: - result = await manager.execute_workflow( - workflow_id, - input_data=request.input_data if request else {} - ) + result = await manager.execute_workflow(workflow_id, input_data=request.input_data if request else {}) return WorkflowTriggerResponse( success=result["success"], workflow_id=result["workflow_id"], log_id=result["log_id"], results=result["results"], - duration_ms=result["duration_ms"] + duration_ms=result["duration_ms"], ) except ValueError as e: raise HTTPException(status_code=404, detail=str(e)) @@ -3921,11 +3776,7 @@ async def trigger_workflow_endpoint( @app.get("/api/v1/workflows/{workflow_id}/logs", response_model=WorkflowLogListResponse, tags=["Workflows"]) async def get_workflow_logs_endpoint( - workflow_id: str, - status: str | None = None, - limit: int = 100, - offset: int = 0, - _=Depends(verify_api_key) + workflow_id: str, status: str | None = None, limit: int = 100, offset: int = 0, _=Depends(verify_api_key) ): """获取工作流执行日志""" if not WORKFLOW_AVAILABLE: @@ -3947,11 +3798,11 @@ async def get_workflow_logs_endpoint( input_data=log.input_data, output_data=log.output_data, error_message=log.error_message, - created_at=log.created_at + created_at=log.created_at, ) for log in logs ], - total=len(logs) + total=len(logs), ) @@ -3969,6 +3820,7 @@ async def get_workflow_stats_endpoint(workflow_id: str, days: int = 30, _=Depend # ==================== Phase 7: Webhook Endpoints ==================== + @app.post("/api/v1/webhooks", response_model=WebhookResponse, tags=["Webhooks"]) async def create_webhook_endpoint(request: WebhookCreate, _=Depends(verify_api_key)): """ @@ -3993,7 +3845,7 @@ async def create_webhook_endpoint(request: WebhookCreate, _=Depends(verify_api_k url=request.url, secret=request.secret, headers=request.headers, - template=request.template + template=request.template, ) created = manager.create_webhook(webhook) @@ -4010,7 +3862,7 @@ async def create_webhook_endpoint(request: WebhookCreate, _=Depends(verify_api_k updated_at=created.updated_at, last_used_at=created.last_used_at, success_count=created.success_count, - fail_count=created.fail_count + fail_count=created.fail_count, ) except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) @@ -4039,11 +3891,11 @@ async def list_webhooks_endpoint(_=Depends(verify_api_key)): updated_at=w.updated_at, last_used_at=w.last_used_at, success_count=w.success_count, - fail_count=w.fail_count + fail_count=w.fail_count, ) for w in webhooks ], - total=len(webhooks) + total=len(webhooks), ) @@ -4071,7 +3923,7 @@ async def get_webhook_endpoint(webhook_id: str, _=Depends(verify_api_key)): updated_at=webhook.updated_at, last_used_at=webhook.last_used_at, success_count=webhook.success_count, - fail_count=webhook.fail_count + fail_count=webhook.fail_count, ) @@ -4101,7 +3953,7 @@ async def update_webhook_endpoint(webhook_id: str, request: WebhookUpdate, _=Dep updated_at=updated.updated_at, last_used_at=updated.last_used_at, success_count=updated.success_count, - fail_count=updated.fail_count + fail_count=updated.fail_count, ) @@ -4138,7 +3990,9 @@ async def test_webhook_endpoint(webhook_id: str, _=Depends(verify_api_key)): } if webhook.webhook_type == "slack": - test_message = {"text": "🔔 这是来自 InsightFlow 的 Webhook 测试消息\n\n如果您收到这条消息,说明 Webhook 配置正确!"} + test_message = { + "text": "🔔 这是来自 InsightFlow 的 Webhook 测试消息\n\n如果您收到这条消息,说明 Webhook 配置正确!" + } success = await manager.notifier.send(webhook, test_message) manager.update_webhook_stats(webhook_id, success) @@ -4151,6 +4005,7 @@ async def test_webhook_endpoint(webhook_id: str, _=Depends(verify_api_key)): # ==================== Phase 7: Multimodal Support Endpoints ==================== + # Pydantic Models for Multimodal API class VideoUploadResponse(BaseModel): video_id: str @@ -4208,10 +4063,7 @@ class MultimodalStatsResponse(BaseModel): @app.post("/api/v1/projects/{project_id}/upload-video", response_model=VideoUploadResponse, tags=["Multimodal"]) async def upload_video_endpoint( - project_id: str, - file: UploadFile = File(...), - extract_interval: int = Form(5), - _=Depends(verify_api_key) + project_id: str, file: UploadFile = File(...), extract_interval: int = Form(5), _=Depends(verify_api_key) ): """ 上传视频文件进行处理 @@ -4261,10 +4113,21 @@ async def upload_video_endpoint( audio_transcript_id, full_ocr_text, extracted_entities, extracted_relations, status, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", - (video_id, project_id, file.filename, video_info.get('duration', 0), - video_info.get('fps', 0), - json.dumps({'width': video_info.get('width', 0), 'height': video_info.get('height', 0)}), - None, result.full_text, '[]', '[]', 'completed', now, now) + ( + video_id, + project_id, + file.filename, + video_info.get("duration", 0), + video_info.get("fps", 0), + json.dumps({"width": video_info.get("width", 0), "height": video_info.get("height", 0)}), + None, + result.full_text, + "[]", + "[]", + "completed", + now, + now, + ), ) # 保存关键帧信息 @@ -4273,8 +4136,16 @@ async def upload_video_endpoint( """INSERT INTO video_frames (id, video_id, frame_number, timestamp, image_url, ocr_text, extracted_entities, created_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?)""", - (frame.id, frame.video_id, frame.frame_number, frame.timestamp, - frame.frame_path, frame.ocr_text, json.dumps(frame.entities_detected), now) + ( + frame.id, + frame.video_id, + frame.frame_number, + frame.timestamp, + frame.frame_path, + frame.ocr_text, + json.dumps(frame.entities_detected), + now, + ), ) conn.commit() @@ -4292,13 +4163,15 @@ async def upload_video_endpoint( if existing: entity_name_to_id[raw_ent["name"]] = existing.id else: - new_ent = db.create_entity(Entity( - id=str(uuid.uuid4())[:8], - project_id=project_id, - name=raw_ent["name"], - type=raw_ent.get("type", "OTHER"), - definition=raw_ent.get("definition", "") - )) + new_ent = db.create_entity( + Entity( + id=str(uuid.uuid4())[:8], + project_id=project_id, + name=raw_ent["name"], + type=raw_ent.get("type", "OTHER"), + definition=raw_ent.get("definition", ""), + ) + ) entity_name_to_id[raw_ent["name"]] = new_ent.id # 保存多模态实体提及 @@ -4307,8 +4180,17 @@ async def upload_video_endpoint( """INSERT OR REPLACE INTO multimodal_mentions (id, project_id, entity_id, modality, source_id, source_type, text_snippet, confidence, created_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)""", - (str(uuid.uuid4())[:8], project_id, entity_name_to_id[raw_ent["name"]], - 'video', video_id, 'video_frame', raw_ent.get("name", ""), 1.0, now) + ( + str(uuid.uuid4())[:8], + project_id, + entity_name_to_id[raw_ent["name"]], + "video", + video_id, + "video_frame", + raw_ent.get("name", ""), + 1.0, + now, + ), ) conn.commit() conn.close() @@ -4323,14 +4205,14 @@ async def upload_video_endpoint( source_entity_id=source_id, target_entity_id=target_id, relation_type=rel.get("type", "related"), - evidence=result.full_text[:200] + evidence=result.full_text[:200], ) # 更新视频的实体和关系信息 conn = db.get_conn() conn.execute( "UPDATE videos SET extracted_entities = ?, extracted_relations = ? WHERE id = ?", - (json.dumps(raw_entities), json.dumps(raw_relations), video_id) + (json.dumps(raw_entities), json.dumps(raw_relations), video_id), ) conn.commit() conn.close() @@ -4343,16 +4225,13 @@ async def upload_video_endpoint( audio_extracted=bool(result.audio_path), frame_count=len(result.frames), ocr_text_preview=result.full_text[:200] + "..." if len(result.full_text) > 200 else result.full_text, - message="Video processed successfully" + message="Video processed successfully", ) @app.post("/api/v1/projects/{project_id}/upload-image", response_model=ImageUploadResponse, tags=["Multimodal"]) async def upload_image_endpoint( - project_id: str, - file: UploadFile = File(...), - detect_type: bool = Form(True), - _=Depends(verify_api_key) + project_id: str, file: UploadFile = File(...), detect_type: bool = Form(True), _=Depends(verify_api_key) ): """ 上传图片文件进行处理 @@ -4397,10 +4276,18 @@ async def upload_image_endpoint( (id, project_id, filename, ocr_text, description, extracted_entities, extracted_relations, status, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", - (image_id, project_id, file.filename, result.ocr_text, result.description, - json.dumps([{"name": e.name, "type": e.type, "confidence": e.confidence} for e in result.entities]), - json.dumps([{"source": r.source, "target": r.target, "type": r.relation_type} for r in result.relations]), - 'completed', now, now) + ( + image_id, + project_id, + file.filename, + result.ocr_text, + result.description, + json.dumps([{"name": e.name, "type": e.type, "confidence": e.confidence} for e in result.entities]), + json.dumps([{"source": r.source, "target": r.target, "type": r.relation_type} for r in result.relations]), + "completed", + now, + now, + ), ) conn.commit() conn.close() @@ -4410,13 +4297,11 @@ async def upload_image_endpoint( existing = align_entity(project_id, entity.name, db, "") if not existing: - new_ent = db.create_entity(Entity( - id=str(uuid.uuid4())[:8], - project_id=project_id, - name=entity.name, - type=entity.type, - definition="" - )) + new_ent = db.create_entity( + Entity( + id=str(uuid.uuid4())[:8], project_id=project_id, name=entity.name, type=entity.type, definition="" + ) + ) entity_id = new_ent.id else: entity_id = existing.id @@ -4427,8 +4312,17 @@ async def upload_image_endpoint( """INSERT OR REPLACE INTO multimodal_mentions (id, project_id, entity_id, modality, source_id, source_type, text_snippet, confidence, created_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)""", - (str(uuid.uuid4())[:8], project_id, entity_id, - 'image', image_id, result.image_type, entity.name, entity.confidence, now) + ( + str(uuid.uuid4())[:8], + project_id, + entity_id, + "image", + image_id, + result.image_type, + entity.name, + entity.confidence, + now, + ), ) conn.commit() conn.close() @@ -4444,7 +4338,7 @@ async def upload_image_endpoint( source_entity_id=source_entity.id, target_entity_id=target_entity.id, relation_type=relation.relation_type, - evidence=result.ocr_text[:200] + evidence=result.ocr_text[:200], ) return ImageUploadResponse( @@ -4455,16 +4349,12 @@ async def upload_image_endpoint( ocr_text_preview=result.ocr_text[:200] + "..." if len(result.ocr_text) > 200 else result.ocr_text, description=result.description, entity_count=len(result.entities), - status="completed" + status="completed", ) @app.post("/api/v1/projects/{project_id}/upload-images-batch", tags=["Multimodal"]) -async def upload_images_batch_endpoint( - project_id: str, - files: list[UploadFile] = File(...), - _=Depends(verify_api_key) -): +async def upload_images_batch_endpoint(project_id: str, files: list[UploadFile] = File(...), _=Depends(verify_api_key)): """ 批量上传图片文件进行处理 @@ -4506,43 +4396,46 @@ async def upload_images_batch_endpoint( (id, project_id, filename, ocr_text, description, extracted_entities, extracted_relations, status, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", - (image_id, project_id, "batch_image", result.ocr_text, result.description, - json.dumps([{"name": e.name, "type": e.type} for e in result.entities]), - json.dumps([{"source": r.source, "target": r.target} for r in result.relations]), - 'completed', now, now) + ( + image_id, + project_id, + "batch_image", + result.ocr_text, + result.description, + json.dumps([{"name": e.name, "type": e.type} for e in result.entities]), + json.dumps([{"source": r.source, "target": r.target} for r in result.relations]), + "completed", + now, + now, + ), ) conn.commit() conn.close() - results.append({ - "image_id": image_id, - "status": "success", - "image_type": result.image_type, - "entity_count": len(result.entities) - }) + results.append( + { + "image_id": image_id, + "status": "success", + "image_type": result.image_type, + "entity_count": len(result.entities), + } + ) else: - results.append({ - "image_id": result.image_id, - "status": "failed", - "error": result.error_message - }) + results.append({"image_id": result.image_id, "status": "failed", "error": result.error_message}) return { "project_id": project_id, "total_count": batch_result.total_count, "success_count": batch_result.success_count, "failed_count": batch_result.failed_count, - "results": results + "results": results, } -@app.post("/api/v1/projects/{project_id}/multimodal/align", - response_model=MultimodalAlignmentResponse, tags=["Multimodal"]) -async def align_multimodal_entities_endpoint( - project_id: str, - threshold: float = 0.85, - _=Depends(verify_api_key) -): +@app.post( + "/api/v1/projects/{project_id}/multimodal/align", response_model=MultimodalAlignmentResponse, tags=["Multimodal"] +) +async def align_multimodal_entities_endpoint(project_id: str, threshold: float = 0.85, _=Depends(verify_api_key)): """ 跨模态实体对齐 @@ -4567,35 +4460,34 @@ async def align_multimodal_entities_endpoint( # 获取多模态提及 conn = db.get_conn() - mentions = conn.execute( - """SELECT * FROM multimodal_mentions WHERE project_id = ?""", - (project_id,) - ).fetchall() + mentions = conn.execute("""SELECT * FROM multimodal_mentions WHERE project_id = ?""", (project_id,)).fetchall() conn.close() # 按模态分组实体 modality_entities = {"audio": [], "video": [], "image": [], "document": []} for mention in mentions: - modality = mention['modality'] - entity = db.get_entity(mention['entity_id']) - if entity and entity.id not in [e.get('id') for e in modality_entities[modality]]: - modality_entities[modality].append({ - 'id': entity.id, - 'name': entity.name, - 'type': entity.type, - 'definition': entity.definition, - 'aliases': entity.aliases - }) + modality = mention["modality"] + entity = db.get_entity(mention["entity_id"]) + if entity and entity.id not in [e.get("id") for e in modality_entities[modality]]: + modality_entities[modality].append( + { + "id": entity.id, + "name": entity.name, + "type": entity.type, + "definition": entity.definition, + "aliases": entity.aliases, + } + ) # 跨模态对齐 linker = get_multimodal_entity_linker(similarity_threshold=threshold) links = linker.align_cross_modal_entities( project_id=project_id, - audio_entities=modality_entities['audio'], - video_entities=modality_entities['video'], - image_entities=modality_entities['image'], - document_entities=modality_entities['document'] + audio_entities=modality_entities["audio"], + video_entities=modality_entities["video"], + image_entities=modality_entities["image"], + document_entities=modality_entities["document"], ) # 保存关联到数据库 @@ -4608,20 +4500,29 @@ async def align_multimodal_entities_endpoint( """INSERT OR REPLACE INTO multimodal_entity_links (id, entity_id, linked_entity_id, link_type, confidence, evidence, modalities, created_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?)""", - (link.id, link.source_entity_id, link.target_entity_id, link.link_type, - link.confidence, link.evidence, - json.dumps([link.source_modality, link.target_modality]), now) + ( + link.id, + link.source_entity_id, + link.target_entity_id, + link.link_type, + link.confidence, + link.evidence, + json.dumps([link.source_modality, link.target_modality]), + now, + ), + ) + saved_links.append( + MultimodalEntityLinkResponse( + link_id=link.id, + source_entity_id=link.source_entity_id, + target_entity_id=link.target_entity_id, + source_modality=link.source_modality, + target_modality=link.target_modality, + link_type=link.link_type, + confidence=link.confidence, + evidence=link.evidence, + ) ) - saved_links.append(MultimodalEntityLinkResponse( - link_id=link.id, - source_entity_id=link.source_entity_id, - target_entity_id=link.target_entity_id, - source_modality=link.source_modality, - target_modality=link.target_modality, - link_type=link.link_type, - confidence=link.confidence, - evidence=link.evidence - )) conn.commit() conn.close() @@ -4630,7 +4531,7 @@ async def align_multimodal_entities_endpoint( project_id=project_id, aligned_count=len(saved_links), links=saved_links, - message=f"Successfully aligned {len(saved_links)} cross-modal entity pairs" + message=f"Successfully aligned {len(saved_links)} cross-modal entity pairs", ) @@ -4652,36 +4553,33 @@ async def get_multimodal_stats_endpoint(project_id: str, _=Depends(verify_api_ke conn = db.get_conn() # 统计视频数量 - video_count = conn.execute( - "SELECT COUNT(*) as count FROM videos WHERE project_id = ?", - (project_id,) - ).fetchone()['count'] + video_count = conn.execute("SELECT COUNT(*) as count FROM videos WHERE project_id = ?", (project_id,)).fetchone()[ + "count" + ] # 统计图片数量 - image_count = conn.execute( - "SELECT COUNT(*) as count FROM images WHERE project_id = ?", - (project_id,) - ).fetchone()['count'] + image_count = conn.execute("SELECT COUNT(*) as count FROM images WHERE project_id = ?", (project_id,)).fetchone()[ + "count" + ] # 统计多模态实体提及 multimodal_count = conn.execute( - "SELECT COUNT(DISTINCT entity_id) as count FROM multimodal_mentions WHERE project_id = ?", - (project_id,) - ).fetchone()['count'] + "SELECT COUNT(DISTINCT entity_id) as count FROM multimodal_mentions WHERE project_id = ?", (project_id,) + ).fetchone()["count"] # 统计跨模态关联 cross_modal_count = conn.execute( "SELECT COUNT(*) as count FROM multimodal_entity_links WHERE entity_id IN (SELECT id FROM entities WHERE project_id = ?)", - (project_id,) - ).fetchone()['count'] + (project_id,), + ).fetchone()["count"] # 模态分布 modality_dist = {} - for modality in ['audio', 'video', 'image', 'document']: + for modality in ["audio", "video", "image", "document"]: count = conn.execute( "SELECT COUNT(*) as count FROM multimodal_mentions WHERE project_id = ? AND modality = ?", - (project_id, modality) - ).fetchone()['count'] + (project_id, modality), + ).fetchone()["count"] modality_dist[modality] = count conn.close() @@ -4692,7 +4590,7 @@ async def get_multimodal_stats_endpoint(project_id: str, _=Depends(verify_api_ke image_count=image_count, multimodal_entity_count=multimodal_count, cross_modal_links=cross_modal_count, - modality_distribution=modality_dist + modality_distribution=modality_dist, ) @@ -4709,21 +4607,28 @@ async def list_project_videos_endpoint(project_id: str, _=Depends(verify_api_key """SELECT id, filename, duration, fps, resolution, full_ocr_text, status, created_at FROM videos WHERE project_id = ? ORDER BY created_at DESC""", - (project_id,) + (project_id,), ).fetchall() conn.close() - return [{ - "id": v['id'], - "filename": v['filename'], - "duration": v['duration'], - "fps": v['fps'], - "resolution": json.loads(v['resolution']) if v['resolution'] else None, - "ocr_preview": v['full_ocr_text'][:200] + "..." if v['full_ocr_text'] and len(v['full_ocr_text']) > 200 else v['full_ocr_text'], - "status": v['status'], - "created_at": v['created_at'] - } for v in videos] + return [ + { + "id": v["id"], + "filename": v["filename"], + "duration": v["duration"], + "fps": v["fps"], + "resolution": json.loads(v["resolution"]) if v["resolution"] else None, + "ocr_preview": ( + v["full_ocr_text"][:200] + "..." + if v["full_ocr_text"] and len(v["full_ocr_text"]) > 200 + else v["full_ocr_text"] + ), + "status": v["status"], + "created_at": v["created_at"], + } + for v in videos + ] @app.get("/api/v1/projects/{project_id}/images", tags=["Multimodal"]) @@ -4739,20 +4644,25 @@ async def list_project_images_endpoint(project_id: str, _=Depends(verify_api_key """SELECT id, filename, ocr_text, description, extracted_entities, status, created_at FROM images WHERE project_id = ? ORDER BY created_at DESC""", - (project_id,) + (project_id,), ).fetchall() conn.close() - return [{ - "id": img['id'], - "filename": img['filename'], - "ocr_preview": img['ocr_text'][:200] + "..." if img['ocr_text'] and len(img['ocr_text']) > 200 else img['ocr_text'], - "description": img['description'], - "entity_count": len(json.loads(img['extracted_entities'])) if img['extracted_entities'] else 0, - "status": img['status'], - "created_at": img['created_at'] - } for img in images] + return [ + { + "id": img["id"], + "filename": img["filename"], + "ocr_preview": ( + img["ocr_text"][:200] + "..." if img["ocr_text"] and len(img["ocr_text"]) > 200 else img["ocr_text"] + ), + "description": img["description"], + "entity_count": len(json.loads(img["extracted_entities"])) if img["extracted_entities"] else 0, + "status": img["status"], + "created_at": img["created_at"], + } + for img in images + ] @app.get("/api/v1/videos/{video_id}/frames", tags=["Multimodal"]) @@ -4767,19 +4677,22 @@ async def get_video_frames_endpoint(video_id: str, _=Depends(verify_api_key)): frames = conn.execute( """SELECT id, frame_number, timestamp, image_url, ocr_text, extracted_entities FROM video_frames WHERE video_id = ? ORDER BY timestamp""", - (video_id,) + (video_id,), ).fetchall() conn.close() - return [{ - "id": f['id'], - "frame_number": f['frame_number'], - "timestamp": f['timestamp'], - "image_url": f['image_url'], - "ocr_text": f['ocr_text'], - "entities": json.loads(f['extracted_entities']) if f['extracted_entities'] else [] - } for f in frames] + return [ + { + "id": f["id"], + "frame_number": f["frame_number"], + "timestamp": f["timestamp"], + "image_url": f["image_url"], + "ocr_text": f["ocr_text"], + "entities": json.loads(f["extracted_entities"]) if f["extracted_entities"] else [], + } + for f in frames + ] @app.get("/api/v1/entities/{entity_id}/multimodal-mentions", tags=["Multimodal"]) @@ -4796,22 +4709,25 @@ async def get_entity_multimodal_mentions_endpoint(entity_id: str, _=Depends(veri FROM multimodal_mentions m JOIN entities e ON m.entity_id = e.id WHERE m.entity_id = ? ORDER BY m.created_at DESC""", - (entity_id,) + (entity_id,), ).fetchall() conn.close() - return [{ - "id": m['id'], - "entity_id": m['entity_id'], - "entity_name": m['entity_name'], - "modality": m['modality'], - "source_id": m['source_id'], - "source_type": m['source_type'], - "text_snippet": m['text_snippet'], - "confidence": m['confidence'], - "created_at": m['created_at'] - } for m in mentions] + return [ + { + "id": m["id"], + "entity_id": m["entity_id"], + "entity_name": m["entity_name"], + "modality": m["modality"], + "source_id": m["source_id"], + "source_type": m["source_type"], + "text_snippet": m["text_snippet"], + "confidence": m["confidence"], + "created_at": m["created_at"], + } + for m in mentions + ] @app.get("/api/v1/projects/{project_id}/multimodal/suggest-merges", tags=["Multimodal"]) @@ -4834,36 +4750,34 @@ async def suggest_multimodal_merges_endpoint(project_id: str, _=Depends(verify_a # 获取所有实体 entities = db.list_project_entities(project_id) - entity_dicts = [{ - 'id': e.id, - 'name': e.name, - 'type': e.type, - 'definition': e.definition, - 'aliases': e.aliases - } for e in entities] + entity_dicts = [ + {"id": e.id, "name": e.name, "type": e.type, "definition": e.definition, "aliases": e.aliases} for e in entities + ] # 获取现有链接 conn = db.get_conn() existing_links = conn.execute( """SELECT * FROM multimodal_entity_links WHERE entity_id IN (SELECT id FROM entities WHERE project_id = ?)""", - (project_id,) + (project_id,), ).fetchall() conn.close() existing_link_objects = [] for row in existing_links: - existing_link_objects.append(EntityLink( - id=row['id'], - project_id=project_id, - source_entity_id=row['entity_id'], - target_entity_id=row['linked_entity_id'], - link_type=row['link_type'], - source_modality='unknown', - target_modality='unknown', - confidence=row['confidence'], - evidence=row['evidence'] or "" - )) + existing_link_objects.append( + EntityLink( + id=row["id"], + project_id=project_id, + source_entity_id=row["entity_id"], + target_entity_id=row["linked_entity_id"], + link_type=row["link_type"], + source_modality="unknown", + target_modality="unknown", + confidence=row["confidence"], + evidence=row["evidence"] or "", + ) + ) # 获取建议 linker = get_multimodal_entity_linker() @@ -4875,26 +4789,27 @@ async def suggest_multimodal_merges_endpoint(project_id: str, _=Depends(verify_a "suggestions": [ { "entity1": { - "id": s['entity1'].get('id'), - "name": s['entity1'].get('name'), - "type": s['entity1'].get('type') + "id": s["entity1"].get("id"), + "name": s["entity1"].get("name"), + "type": s["entity1"].get("type"), }, "entity2": { - "id": s['entity2'].get('id'), - "name": s['entity2'].get('name'), - "type": s['entity2'].get('type') + "id": s["entity2"].get("id"), + "name": s["entity2"].get("name"), + "type": s["entity2"].get("type"), }, - "similarity": s['similarity'], - "match_type": s['match_type'], - "suggested_action": s['suggested_action'] + "similarity": s["similarity"], + "match_type": s["match_type"], + "suggested_action": s["suggested_action"], } for s in suggestions[:20] # 最多返回20个建议 - ] + ], } # ==================== Phase 7: Multimodal Support API ==================== + class VideoUploadResponse(BaseModel): video_id: str filename: str @@ -4934,10 +4849,12 @@ class MultimodalProfileResponse(BaseModel): # ==================== Phase 7 Task 7: Plugin Management Pydantic Models ==================== + class PluginCreate(BaseModel): name: str = Field(..., description="插件名称") - plugin_type: str = Field(..., - description="插件类型: chrome_extension, feishu_bot, dingtalk_bot, zapier, make, webdav, custom") + plugin_type: str = Field( + ..., description="插件类型: chrome_extension, feishu_bot, dingtalk_bot, zapier, make, webdav, custom" + ) project_id: str = Field(..., description="关联项目ID") config: dict = Field(default_factory=dict, description="插件配置") @@ -5109,6 +5026,7 @@ def get_plugin_manager_instance(): # ==================== Phase 7 Task 7: Plugin Management Endpoints ==================== + @app.post("/api/v1/plugins", response_model=PluginResponse, tags=["Plugins"]) async def create_plugin_endpoint(request: PluginCreate, _=Depends(verify_api_key)): """ @@ -5133,7 +5051,7 @@ async def create_plugin_endpoint(request: PluginCreate, _=Depends(verify_api_key name=request.name, plugin_type=request.plugin_type, project_id=request.project_id, - config=request.config + config=request.config, ) created = manager.create_plugin(plugin) @@ -5148,16 +5066,13 @@ async def create_plugin_endpoint(request: PluginCreate, _=Depends(verify_api_key created_at=created.created_at, updated_at=created.updated_at, last_used_at=created.last_used_at, - use_count=created.use_count + use_count=created.use_count, ) @app.get("/api/v1/plugins", response_model=PluginListResponse, tags=["Plugins"]) async def list_plugins_endpoint( - project_id: str | None = None, - plugin_type: str | None = None, - status: str | None = None, - _=Depends(verify_api_key) + project_id: str | None = None, plugin_type: str | None = None, status: str | None = None, _=Depends(verify_api_key) ): """获取插件列表""" if not PLUGIN_MANAGER_AVAILABLE: @@ -5178,11 +5093,11 @@ async def list_plugins_endpoint( created_at=p.created_at, updated_at=p.updated_at, last_used_at=p.last_used_at, - use_count=p.use_count + use_count=p.use_count, ) for p in plugins ], - total=len(plugins) + total=len(plugins), ) @@ -5208,7 +5123,7 @@ async def get_plugin_endpoint(plugin_id: str, _=Depends(verify_api_key)): created_at=plugin.created_at, updated_at=plugin.updated_at, last_used_at=plugin.last_used_at, - use_count=plugin.use_count + use_count=plugin.use_count, ) @@ -5236,7 +5151,7 @@ async def update_plugin_endpoint(plugin_id: str, request: PluginUpdate, _=Depend created_at=updated.created_at, updated_at=updated.updated_at, last_used_at=updated.last_used_at, - use_count=updated.use_count + use_count=updated.use_count, ) @@ -5257,6 +5172,7 @@ async def delete_plugin_endpoint(plugin_id: str, _=Depends(verify_api_key)): # ==================== Phase 7 Task 7: Chrome Extension Endpoints ==================== + @app.post("/api/v1/plugins/chrome/tokens", response_model=ChromeExtensionTokenResponse, tags=["Chrome Extension"]) async def create_chrome_token_endpoint(request: ChromeExtensionTokenCreate, _=Depends(verify_api_key)): """ @@ -5277,7 +5193,7 @@ async def create_chrome_token_endpoint(request: ChromeExtensionTokenCreate, _=De name=request.name, project_id=request.project_id, permissions=request.permissions, - expires_days=request.expires_days + expires_days=request.expires_days, ) return ChromeExtensionTokenResponse( @@ -5287,15 +5203,12 @@ async def create_chrome_token_endpoint(request: ChromeExtensionTokenCreate, _=De project_id=token.project_id, permissions=token.permissions, expires_at=token.expires_at, - created_at=token.created_at + created_at=token.created_at, ) @app.get("/api/v1/plugins/chrome/tokens", tags=["Chrome Extension"]) -async def list_chrome_tokens_endpoint( - project_id: str | None = None, - _=Depends(verify_api_key) -): +async def list_chrome_tokens_endpoint(project_id: str | None = None, _=Depends(verify_api_key)): """列出 Chrome 扩展令牌""" if not PLUGIN_MANAGER_AVAILABLE: raise HTTPException(status_code=503, detail="Plugin manager not available") @@ -5319,11 +5232,11 @@ async def list_chrome_tokens_endpoint( "created_at": t.created_at, "last_used_at": t.last_used_at, "use_count": t.use_count, - "is_revoked": t.is_revoked + "is_revoked": t.is_revoked, } for t in tokens ], - "total": len(tokens) + "total": len(tokens), } @@ -5370,11 +5283,7 @@ async def chrome_import_webpage_endpoint(request: ChromeExtensionImportRequest): # 导入网页 result = await handler.import_webpage( - token=token, - url=request.url, - title=request.title, - content=request.content, - html_content=request.html_content + token=token, url=request.url, title=request.title, content=request.content, html_content=request.html_content ) if not result["success"]: @@ -5385,6 +5294,7 @@ async def chrome_import_webpage_endpoint(request: ChromeExtensionImportRequest): # ==================== Phase 7 Task 7: Bot Endpoints ==================== + @app.post("/api/v1/plugins/bot/feishu/sessions", response_model=BotSessionResponse, tags=["Bot"]) async def create_feishu_session_endpoint(request: BotSessionCreate, _=Depends(verify_api_key)): """创建飞书机器人会话""" @@ -5402,7 +5312,7 @@ async def create_feishu_session_endpoint(request: BotSessionCreate, _=Depends(ve session_name=request.session_name, project_id=request.project_id, webhook_url=request.webhook_url, - secret=request.secret + secret=request.secret, ) return BotSessionResponse( @@ -5415,7 +5325,7 @@ async def create_feishu_session_endpoint(request: BotSessionCreate, _=Depends(ve is_active=session.is_active, created_at=session.created_at, last_message_at=session.last_message_at, - message_count=session.message_count + message_count=session.message_count, ) @@ -5436,7 +5346,7 @@ async def create_dingtalk_session_endpoint(request: BotSessionCreate, _=Depends( session_name=request.session_name, project_id=request.project_id, webhook_url=request.webhook_url, - secret=request.secret + secret=request.secret, ) return BotSessionResponse( @@ -5449,16 +5359,12 @@ async def create_dingtalk_session_endpoint(request: BotSessionCreate, _=Depends( is_active=session.is_active, created_at=session.created_at, last_message_at=session.last_message_at, - message_count=session.message_count + message_count=session.message_count, ) @app.get("/api/v1/plugins/bot/{bot_type}/sessions", tags=["Bot"]) -async def list_bot_sessions_endpoint( - bot_type: str, - project_id: str | None = None, - _=Depends(verify_api_key) -): +async def list_bot_sessions_endpoint(bot_type: str, project_id: str | None = None, _=Depends(verify_api_key)): """列出机器人会话""" if not PLUGIN_MANAGER_AVAILABLE: raise HTTPException(status_code=503, detail="Plugin manager not available") @@ -5488,11 +5394,11 @@ async def list_bot_sessions_endpoint( "is_active": s.is_active, "created_at": s.created_at, "last_message_at": s.last_message_at, - "message_count": s.message_count + "message_count": s.message_count, } for s in sessions ], - "total": len(sessions) + "total": len(sessions), } @@ -5523,9 +5429,9 @@ async def bot_webhook_endpoint(bot_type: str, request: Request): # 获取会话ID(飞书和钉钉的格式不同) if bot_type == "feishu": - session_id = message.get('chat_id') or message.get('open_chat_id') + session_id = message.get("chat_id") or message.get("open_chat_id") else: # dingtalk - session_id = message.get('conversationId') or message.get('senderStaffId') + session_id = message.get("conversationId") or message.get("senderStaffId") if not session_id: raise HTTPException(status_code=400, detail="Cannot identify session") @@ -5534,11 +5440,7 @@ async def bot_webhook_endpoint(bot_type: str, request: Request): session = handler.get_session(session_id) if not session: # 自动创建会话 - session = handler.create_session( - session_id=session_id, - session_name=f"Auto-{session_id[:8]}", - webhook_url="" - ) + session = handler.create_session(session_id=session_id, session_name=f"Auto-{session_id[:8]}", webhook_url="") # 处理消息 result = await handler.handle_message(session, message) @@ -5551,12 +5453,7 @@ async def bot_webhook_endpoint(bot_type: str, request: Request): @app.post("/api/v1/plugins/bot/{bot_type}/sessions/{session_id}/send", tags=["Bot"]) -async def send_bot_message_endpoint( - bot_type: str, - session_id: str, - message: str, - _=Depends(verify_api_key) -): +async def send_bot_message_endpoint(bot_type: str, session_id: str, message: str, _=Depends(verify_api_key)): """发送消息到机器人会话""" if not PLUGIN_MANAGER_AVAILABLE: raise HTTPException(status_code=503, detail="Plugin manager not available") @@ -5584,6 +5481,7 @@ async def send_bot_message_endpoint( # ==================== Phase 7 Task 7: Integration Endpoints ==================== + @app.post("/api/v1/plugins/integrations/zapier", response_model=WebhookEndpointResponse, tags=["Integrations"]) async def create_zapier_endpoint(request: WebhookEndpointCreate, _=Depends(verify_api_key)): """创建 Zapier Webhook 端点""" @@ -5602,7 +5500,7 @@ async def create_zapier_endpoint(request: WebhookEndpointCreate, _=Depends(verif project_id=request.project_id, auth_type=request.auth_type, auth_config=request.auth_config, - trigger_events=request.trigger_events + trigger_events=request.trigger_events, ) return WebhookEndpointResponse( @@ -5616,7 +5514,7 @@ async def create_zapier_endpoint(request: WebhookEndpointCreate, _=Depends(verif is_active=endpoint.is_active, created_at=endpoint.created_at, last_triggered_at=endpoint.last_triggered_at, - trigger_count=endpoint.trigger_count + trigger_count=endpoint.trigger_count, ) @@ -5638,7 +5536,7 @@ async def create_make_endpoint(request: WebhookEndpointCreate, _=Depends(verify_ project_id=request.project_id, auth_type=request.auth_type, auth_config=request.auth_config, - trigger_events=request.trigger_events + trigger_events=request.trigger_events, ) return WebhookEndpointResponse( @@ -5652,15 +5550,13 @@ async def create_make_endpoint(request: WebhookEndpointCreate, _=Depends(verify_ is_active=endpoint.is_active, created_at=endpoint.created_at, last_triggered_at=endpoint.last_triggered_at, - trigger_count=endpoint.trigger_count + trigger_count=endpoint.trigger_count, ) @app.get("/api/v1/plugins/integrations/{endpoint_type}", tags=["Integrations"]) async def list_integration_endpoints_endpoint( - endpoint_type: str, - project_id: str | None = None, - _=Depends(verify_api_key) + endpoint_type: str, project_id: str | None = None, _=Depends(verify_api_key) ): """列出集成端点""" if not PLUGIN_MANAGER_AVAILABLE: @@ -5693,11 +5589,11 @@ async def list_integration_endpoints_endpoint( "is_active": e.is_active, "created_at": e.created_at, "last_triggered_at": e.last_triggered_at, - "trigger_count": e.trigger_count + "trigger_count": e.trigger_count, } for e in endpoints ], - "total": len(endpoints) + "total": len(endpoints), } @@ -5722,20 +5618,11 @@ async def test_integration_endpoint(endpoint_id: str, _=Depends(verify_api_key)) result = await handler.test_endpoint(endpoint) - return WebhookTestResponse( - success=result["success"], - endpoint_id=endpoint_id, - message=result["message"] - ) + return WebhookTestResponse(success=result["success"], endpoint_id=endpoint_id, message=result["message"]) @app.post("/api/v1/plugins/integrations/{endpoint_id}/trigger", tags=["Integrations"]) -async def trigger_integration_endpoint( - endpoint_id: str, - event_type: str, - data: dict, - _=Depends(verify_api_key) -): +async def trigger_integration_endpoint(endpoint_id: str, event_type: str, data: dict, _=Depends(verify_api_key)): """手动触发集成端点""" if not PLUGIN_MANAGER_AVAILABLE: raise HTTPException(status_code=503, detail="Plugin manager not available") @@ -5760,6 +5647,7 @@ async def trigger_integration_endpoint( # ==================== Phase 7 Task 7: WebDAV Endpoints ==================== + @app.post("/api/v1/plugins/webdav", response_model=WebDAVSyncResponse, tags=["WebDAV"]) async def create_webdav_sync_endpoint(request: WebDAVSyncCreate, _=Depends(verify_api_key)): """ @@ -5784,7 +5672,7 @@ async def create_webdav_sync_endpoint(request: WebDAVSyncCreate, _=Depends(verif password=request.password, remote_path=request.remote_path, sync_mode=request.sync_mode, - sync_interval=request.sync_interval + sync_interval=request.sync_interval, ) return WebDAVSyncResponse( @@ -5800,15 +5688,12 @@ async def create_webdav_sync_endpoint(request: WebDAVSyncCreate, _=Depends(verif last_sync_status=sync.last_sync_status, is_active=sync.is_active, created_at=sync.created_at, - sync_count=sync.sync_count + sync_count=sync.sync_count, ) @app.get("/api/v1/plugins/webdav", tags=["WebDAV"]) -async def list_webdav_syncs_endpoint( - project_id: str | None = None, - _=Depends(verify_api_key) -): +async def list_webdav_syncs_endpoint(project_id: str | None = None, _=Depends(verify_api_key)): """列出 WebDAV 同步配置""" if not PLUGIN_MANAGER_AVAILABLE: raise HTTPException(status_code=503, detail="Plugin manager not available") @@ -5836,11 +5721,11 @@ async def list_webdav_syncs_endpoint( "last_sync_status": s.last_sync_status, "is_active": s.is_active, "created_at": s.created_at, - "sync_count": s.sync_count + "sync_count": s.sync_count, } for s in syncs ], - "total": len(syncs) + "total": len(syncs), } @@ -5863,8 +5748,7 @@ async def test_webdav_connection_endpoint(sync_id: str, _=Depends(verify_api_key result = await handler.test_connection(sync) return WebDAVTestResponse( - success=result["success"], - message=result.get("message") or result.get("error", "Unknown result") + success=result["success"], message=result.get("message") or result.get("error", "Unknown result") ) @@ -5892,7 +5776,7 @@ async def sync_webdav_endpoint(sync_id: str, _=Depends(verify_api_key)): entities_count=result.get("entities_count"), relations_count=result.get("relations_count"), remote_path=result.get("remote_path"), - error=result.get("error") + error=result.get("error"), ) @@ -5920,12 +5804,9 @@ async def delete_webdav_sync_endpoint(sync_id: str, _=Depends(verify_api_key)): async def get_openapi(): """获取 OpenAPI 规范""" from fastapi.openapi.utils import get_openapi + return get_openapi( - title=app.title, - version=app.version, - description=app.description, - routes=app.routes, - tags=app.openapi_tags + title=app.title, version=app.version, description=app.description, routes=app.routes, tags=app.openapi_tags ) @@ -5934,6 +5815,7 @@ app.mount("/", StaticFiles(directory="frontend", html=True), name="frontend") if __name__ == "__main__": import uvicorn + uvicorn.run(app, host="0.0.0.0", port=8000) @@ -6036,20 +5918,14 @@ class WebhookPayload(BaseModel): @app.post("/api/v1/plugins", response_model=PluginResponse, tags=["Plugins"]) -async def create_plugin( - request: PluginCreateRequest, - api_key: str = Depends(verify_api_key) -): +async def create_plugin(request: PluginCreateRequest, api_key: str = Depends(verify_api_key)): """创建插件""" if not PLUGIN_MANAGER_AVAILABLE: raise HTTPException(status_code=503, detail="Plugin manager not available") manager = get_plugin_manager() plugin = manager.create_plugin( - name=request.name, - plugin_type=request.plugin_type, - project_id=request.project_id, - config=request.config + name=request.name, plugin_type=request.plugin_type, project_id=request.project_id, config=request.config ) return PluginResponse( @@ -6059,15 +5935,13 @@ async def create_plugin( project_id=plugin.project_id, status=plugin.status, api_key=plugin.api_key, - created_at=plugin.created_at + created_at=plugin.created_at, ) @app.get("/api/v1/plugins", tags=["Plugins"]) async def list_plugins( - project_id: str | None = None, - plugin_type: str | None = None, - api_key: str = Depends(verify_api_key) + project_id: str | None = None, plugin_type: str | None = None, api_key: str = Depends(verify_api_key) ): """列出插件""" if not PLUGIN_MANAGER_AVAILABLE: @@ -6085,7 +5959,7 @@ async def list_plugins( "project_id": p.project_id, "status": p.status, "use_count": p.use_count, - "created_at": p.created_at + "created_at": p.created_at, } for p in plugins ] @@ -6093,10 +5967,7 @@ async def list_plugins( @app.get("/api/v1/plugins/{plugin_id}", response_model=PluginResponse, tags=["Plugins"]) -async def get_plugin( - plugin_id: str, - api_key: str = Depends(verify_api_key) -): +async def get_plugin(plugin_id: str, api_key: str = Depends(verify_api_key)): """获取插件详情""" if not PLUGIN_MANAGER_AVAILABLE: raise HTTPException(status_code=503, detail="Plugin manager not available") @@ -6114,15 +5985,12 @@ async def get_plugin( project_id=plugin.project_id, status=plugin.status, api_key=plugin.api_key, - created_at=plugin.created_at + created_at=plugin.created_at, ) @app.delete("/api/v1/plugins/{plugin_id}", tags=["Plugins"]) -async def delete_plugin( - plugin_id: str, - api_key: str = Depends(verify_api_key) -): +async def delete_plugin(plugin_id: str, api_key: str = Depends(verify_api_key)): """删除插件""" if not PLUGIN_MANAGER_AVAILABLE: raise HTTPException(status_code=503, detail="Plugin manager not available") @@ -6134,10 +6002,7 @@ async def delete_plugin( @app.post("/api/v1/plugins/{plugin_id}/regenerate-key", tags=["Plugins"]) -async def regenerate_plugin_key( - plugin_id: str, - api_key: str = Depends(verify_api_key) -): +async def regenerate_plugin_key(plugin_id: str, api_key: str = Depends(verify_api_key)): """重新生成插件 API Key""" if not PLUGIN_MANAGER_AVAILABLE: raise HTTPException(status_code=503, detail="Plugin manager not available") @@ -6150,11 +6015,9 @@ async def regenerate_plugin_key( # ==================== Chrome Extension API ==================== + @app.post("/api/v1/plugins/chrome/clip", response_model=ChromeClipResponse, tags=["Chrome Extension"]) -async def chrome_clip( - request: ChromeClipRequest, - x_api_key: str | None = Header(None, alias="X-API-Key") -): +async def chrome_clip(request: ChromeClipRequest, x_api_key: str | None = Header(None, alias="X-API-Key")): """Chrome 插件保存网页内容""" if not PLUGIN_MANAGER_AVAILABLE: raise HTTPException(status_code=503, detail="Plugin manager not available") @@ -6195,7 +6058,7 @@ URL: {request.url} project_id=project_id, filename=f"clip_{request.title[:50]}.md", full_text=doc_content, - transcript_type="document" + transcript_type="document", ) # 记录活动 @@ -6203,12 +6066,7 @@ URL: {request.url} plugin_id=plugin.id, activity_type="clip", source="chrome_extension", - details={ - "url": request.url, - "title": request.title, - "project_id": project_id, - "transcript_id": transcript_id - } + details={"url": request.url, "title": request.title, "project_id": project_id, "transcript_id": transcript_id}, ) return ChromeClipResponse( @@ -6217,18 +6075,15 @@ URL: {request.url} url=request.url, title=request.title, status="success", - message="Content saved successfully" + message="Content saved successfully", ) # ==================== Bot API ==================== + @app.post("/api/v1/bots/webhook/{platform}", response_model=BotMessageResponse, tags=["Bot"]) -async def bot_webhook( - platform: str, - request: Request, - x_signature: str | None = Header(None, alias="X-Signature") -): +async def bot_webhook(platform: str, request: Request, x_signature: str | None = Header(None, alias="X-Signature")): """接收机器人 Webhook 消息(飞书/钉钉/Slack)""" if not PLUGIN_MANAGER_AVAILABLE: raise HTTPException(status_code=503, detail="Plugin manager not available") @@ -6257,15 +6112,13 @@ async def bot_webhook( success=True, reply="收到消息!请使用 InsightFlow 控制台查看更多功能。", session_id=message.get("session_id", ""), - action="reply" + action="reply", ) @app.get("/api/v1/bots/sessions", response_model=list[BotSessionResponse], tags=["Bot"]) async def list_bot_sessions( - plugin_id: str | None = None, - project_id: str | None = None, - api_key: str = Depends(verify_api_key) + plugin_id: str | None = None, project_id: str | None = None, api_key: str = Depends(verify_api_key) ): """列出机器人会话""" if not PLUGIN_MANAGER_AVAILABLE: @@ -6285,7 +6138,7 @@ async def list_bot_sessions( project_id=s.project_id, message_count=s.message_count, created_at=s.created_at, - last_message_at=s.last_message_at + last_message_at=s.last_message_at, ) for s in sessions ] @@ -6293,6 +6146,7 @@ async def list_bot_sessions( # ==================== Webhook Integration API ==================== + @app.post("/api/v1/webhook-endpoints", response_model=WebhookEndpointResponse, tags=["Integrations"]) async def create_integration_webhook_endpoint( plugin_id: str, @@ -6300,7 +6154,7 @@ async def create_integration_webhook_endpoint( endpoint_type: str, target_project_id: str | None = None, allowed_events: list[str] | None = None, - api_key: str = Depends(verify_api_key) + api_key: str = Depends(verify_api_key), ): """创建 Webhook 端点(用于 Zapier/Make 集成)""" if not PLUGIN_MANAGER_AVAILABLE: @@ -6312,7 +6166,7 @@ async def create_integration_webhook_endpoint( name=name, endpoint_type=endpoint_type, target_project_id=target_project_id, - allowed_events=allowed_events + allowed_events=allowed_events, ) return WebhookEndpointResponse( @@ -6324,15 +6178,12 @@ async def create_integration_webhook_endpoint( target_project_id=endpoint.target_project_id, is_active=endpoint.is_active, trigger_count=endpoint.trigger_count, - created_at=endpoint.created_at + created_at=endpoint.created_at, ) @app.get("/api/v1/webhook-endpoints", response_model=list[WebhookEndpointResponse], tags=["Integrations"]) -async def list_webhook_endpoints( - plugin_id: str | None = None, - api_key: str = Depends(verify_api_key) -): +async def list_webhook_endpoints(plugin_id: str | None = None, api_key: str = Depends(verify_api_key)): """列出 Webhook 端点""" if not PLUGIN_MANAGER_AVAILABLE: raise HTTPException(status_code=503, detail="Plugin manager not available") @@ -6350,7 +6201,7 @@ async def list_webhook_endpoints( target_project_id=e.target_project_id, is_active=e.is_active, trigger_count=e.trigger_count, - created_at=e.created_at + created_at=e.created_at, ) for e in endpoints ] @@ -6358,10 +6209,7 @@ async def list_webhook_endpoints( @app.post("/webhook/{endpoint_type}/{token}", tags=["Integrations"]) async def receive_webhook( - endpoint_type: str, - token: str, - request: Request, - x_signature: str | None = Header(None, alias="X-Signature") + endpoint_type: str, token: str, request: Request, x_signature: str | None = Header(None, alias="X-Signature") ): """接收外部 Webhook 调用(Zapier/Make/Custom)""" if not PLUGIN_MANAGER_AVAILABLE: @@ -6397,22 +6245,19 @@ async def receive_webhook( details={ "endpoint_id": endpoint.id, "event": body.get("event"), - "data_keys": list(body.get("data", {}).keys()) - } + "data_keys": list(body.get("data", {}).keys()), + }, ) # 处理数据(简化版本) # 实际应该根据 endpoint.target_project_id 和 body 内容创建文档/实体等 - return { - "success": True, - "endpoint_id": endpoint.id, - "received_at": datetime.now().isoformat() - } + return {"success": True, "endpoint_id": endpoint.id, "received_at": datetime.now().isoformat()} # ==================== WebDAV API ==================== + @app.post("/api/v1/webdav-syncs", response_model=WebDAVSyncResponse, tags=["WebDAV"]) async def create_webdav_sync( plugin_id: str, @@ -6425,7 +6270,7 @@ async def create_webdav_sync( sync_direction: str = "bidirectional", sync_mode: str = "manual", auto_analyze: bool = True, - api_key: str = Depends(verify_api_key) + api_key: str = Depends(verify_api_key), ): """创建 WebDAV 同步配置""" if not PLUGIN_MANAGER_AVAILABLE: @@ -6442,7 +6287,7 @@ async def create_webdav_sync( local_path=local_path, sync_direction=sync_direction, sync_mode=sync_mode, - auto_analyze=auto_analyze + auto_analyze=auto_analyze, ) return WebDAVSyncResponse( @@ -6458,15 +6303,12 @@ async def create_webdav_sync( auto_analyze=sync.auto_analyze, is_active=sync.is_active, last_sync_at=sync.last_sync_at, - created_at=sync.created_at + created_at=sync.created_at, ) @app.get("/api/v1/webdav-syncs", response_model=list[WebDAVSyncResponse], tags=["WebDAV"]) -async def list_webdav_syncs( - plugin_id: str | None = None, - api_key: str = Depends(verify_api_key) -): +async def list_webdav_syncs(plugin_id: str | None = None, api_key: str = Depends(verify_api_key)): """列出 WebDAV 同步配置""" if not PLUGIN_MANAGER_AVAILABLE: raise HTTPException(status_code=503, detail="Plugin manager not available") @@ -6488,17 +6330,14 @@ async def list_webdav_syncs( auto_analyze=s.auto_analyze, is_active=s.is_active, last_sync_at=s.last_sync_at, - created_at=s.created_at + created_at=s.created_at, ) for s in syncs ] @app.post("/api/v1/webdav-syncs/{sync_id}/test", tags=["WebDAV"]) -async def test_webdav_connection( - sync_id: str, - api_key: str = Depends(verify_api_key) -): +async def test_webdav_connection(sync_id: str, api_key: str = Depends(verify_api_key)): """测试 WebDAV 连接""" if not PLUGIN_MANAGER_AVAILABLE: raise HTTPException(status_code=503, detail="Plugin manager not available") @@ -6510,22 +6349,16 @@ async def test_webdav_connection( raise HTTPException(status_code=404, detail="WebDAV sync not found") from plugin_manager import WebDAVSync as WebDAVSyncHandler + handler = WebDAVSyncHandler(manager) - success, message = await handler.test_connection( - sync.server_url, - sync.username, - sync.password - ) + success, message = await handler.test_connection(sync.server_url, sync.username, sync.password) return {"success": success, "message": message} @app.post("/api/v1/webdav-syncs/{sync_id}/sync", tags=["WebDAV"]) -async def trigger_webdav_sync( - sync_id: str, - api_key: str = Depends(verify_api_key) -): +async def trigger_webdav_sync(sync_id: str, api_key: str = Depends(verify_api_key)): """手动触发 WebDAV 同步""" if not PLUGIN_MANAGER_AVAILABLE: raise HTTPException(status_code=503, detail="Plugin manager not available") @@ -6539,39 +6372,24 @@ async def trigger_webdav_sync( # 这里应该启动异步同步任务 # 简化版本,仅返回成功 - manager.update_webdav_sync( - sync_id, - last_sync_at=datetime.now().isoformat(), - last_sync_status="running" - ) + manager.update_webdav_sync(sync_id, last_sync_at=datetime.now().isoformat(), last_sync_status="running") - return { - "success": True, - "sync_id": sync_id, - "status": "running", - "message": "Sync started" - } + return {"success": True, "sync_id": sync_id, "status": "running", "message": "Sync started"} # ==================== Plugin Activity Logs ==================== + @app.get("/api/v1/plugins/{plugin_id}/logs", tags=["Plugins"]) async def get_plugin_logs( - plugin_id: str, - activity_type: str | None = None, - limit: int = 100, - api_key: str = Depends(verify_api_key) + plugin_id: str, activity_type: str | None = None, limit: int = 100, api_key: str = Depends(verify_api_key) ): """获取插件活动日志""" if not PLUGIN_MANAGER_AVAILABLE: raise HTTPException(status_code=503, detail="Plugin manager not available") manager = get_plugin_manager() - logs = manager.get_activity_logs( - plugin_id=plugin_id, - activity_type=activity_type, - limit=limit - ) + logs = manager.get_activity_logs(plugin_id=plugin_id, activity_type=activity_type, limit=limit) return { "logs": [ @@ -6580,7 +6398,7 @@ async def get_plugin_logs( "activity_type": log.activity_type, "source": log.source, "details": log.details, - "created_at": log.created_at + "created_at": log.created_at, } for log in logs ] @@ -6589,6 +6407,7 @@ async def get_plugin_logs( # ==================== Phase 7 Task 3: Security & Compliance API ==================== + # Pydantic models for security API class AuditLogResponse(BaseModel): id: str @@ -6704,6 +6523,7 @@ class AccessRequestResponse(BaseModel): # ==================== Audit Logs API ==================== + @app.get("/api/v1/audit-logs", response_model=list[AuditLogResponse], tags=["Security"]) async def get_audit_logs( user_id: str | None = None, @@ -6715,7 +6535,7 @@ async def get_audit_logs( success: bool | None = None, limit: int = 100, offset: int = 0, - api_key: str = Depends(verify_api_key) + api_key: str = Depends(verify_api_key), ): """查询审计日志""" if not SECURITY_MANAGER_AVAILABLE: @@ -6731,7 +6551,7 @@ async def get_audit_logs( end_time=end_time, success=success, limit=limit, - offset=offset + offset=offset, ) return [ @@ -6745,7 +6565,7 @@ async def get_audit_logs( action_details=log.action_details, success=log.success, error_message=log.error_message, - created_at=log.created_at + created_at=log.created_at, ) for log in logs ] @@ -6753,9 +6573,7 @@ async def get_audit_logs( @app.get("/api/v1/audit-logs/stats", response_model=AuditStatsResponse, tags=["Security"]) async def get_audit_stats( - start_time: str | None = None, - end_time: str | None = None, - api_key: str = Depends(verify_api_key) + start_time: str | None = None, end_time: str | None = None, api_key: str = Depends(verify_api_key) ): """获取审计统计""" if not SECURITY_MANAGER_AVAILABLE: @@ -6769,11 +6587,10 @@ async def get_audit_stats( # ==================== Encryption API ==================== + @app.post("/api/v1/projects/{project_id}/encryption/enable", response_model=EncryptionConfigResponse, tags=["Security"]) async def enable_project_encryption( - project_id: str, - request: EncryptionEnableRequest, - api_key: str = Depends(verify_api_key) + project_id: str, request: EncryptionEnableRequest, api_key: str = Depends(verify_api_key) ): """启用项目端到端加密""" if not SECURITY_MANAGER_AVAILABLE: @@ -6789,7 +6606,7 @@ async def enable_project_encryption( is_enabled=config.is_enabled, encryption_type=config.encryption_type, created_at=config.created_at, - updated_at=config.updated_at + updated_at=config.updated_at, ) except RuntimeError as e: raise HTTPException(status_code=400, detail=str(e)) @@ -6797,9 +6614,7 @@ async def enable_project_encryption( @app.post("/api/v1/projects/{project_id}/encryption/disable", tags=["Security"]) async def disable_project_encryption( - project_id: str, - request: EncryptionEnableRequest, - api_key: str = Depends(verify_api_key) + project_id: str, request: EncryptionEnableRequest, api_key: str = Depends(verify_api_key) ): """禁用项目加密""" if not SECURITY_MANAGER_AVAILABLE: @@ -6816,9 +6631,7 @@ async def disable_project_encryption( @app.post("/api/v1/projects/{project_id}/encryption/verify", tags=["Security"]) async def verify_encryption_password( - project_id: str, - request: EncryptionEnableRequest, - api_key: str = Depends(verify_api_key) + project_id: str, request: EncryptionEnableRequest, api_key: str = Depends(verify_api_key) ): """验证加密密码""" if not SECURITY_MANAGER_AVAILABLE: @@ -6830,12 +6643,10 @@ async def verify_encryption_password( return {"valid": is_valid} -@app.get("/api/v1/projects/{project_id}/encryption", - response_model=Optional[EncryptionConfigResponse], tags=["Security"]) -async def get_encryption_config( - project_id: str, - api_key: str = Depends(verify_api_key) -): +@app.get( + "/api/v1/projects/{project_id}/encryption", response_model=Optional[EncryptionConfigResponse], tags=["Security"] +) +async def get_encryption_config(project_id: str, api_key: str = Depends(verify_api_key)): """获取项目加密配置""" if not SECURITY_MANAGER_AVAILABLE: raise HTTPException(status_code=503, detail="Security manager not available") @@ -6852,17 +6663,16 @@ async def get_encryption_config( is_enabled=config.is_enabled, encryption_type=config.encryption_type, created_at=config.created_at, - updated_at=config.updated_at + updated_at=config.updated_at, ) # ==================== Data Masking API ==================== + @app.post("/api/v1/projects/{project_id}/masking-rules", response_model=MaskingRuleResponse, tags=["Security"]) async def create_masking_rule( - project_id: str, - request: MaskingRuleCreateRequest, - api_key: str = Depends(verify_api_key) + project_id: str, request: MaskingRuleCreateRequest, api_key: str = Depends(verify_api_key) ): """创建数据脱敏规则""" if not SECURITY_MANAGER_AVAILABLE: @@ -6882,7 +6692,7 @@ async def create_masking_rule( pattern=request.pattern, replacement=request.replacement, description=request.description, - priority=request.priority + priority=request.priority, ) return MaskingRuleResponse( @@ -6896,16 +6706,12 @@ async def create_masking_rule( priority=rule.priority, description=rule.description, created_at=rule.created_at, - updated_at=rule.updated_at + updated_at=rule.updated_at, ) @app.get("/api/v1/projects/{project_id}/masking-rules", response_model=list[MaskingRuleResponse], tags=["Security"]) -async def get_masking_rules( - project_id: str, - active_only: bool = True, - api_key: str = Depends(verify_api_key) -): +async def get_masking_rules(project_id: str, active_only: bool = True, api_key: str = Depends(verify_api_key)): """获取项目脱敏规则""" if not SECURITY_MANAGER_AVAILABLE: raise HTTPException(status_code=503, detail="Security manager not available") @@ -6925,7 +6731,7 @@ async def get_masking_rules( priority=rule.priority, description=rule.description, created_at=rule.created_at, - updated_at=rule.updated_at + updated_at=rule.updated_at, ) for rule in rules ] @@ -6940,7 +6746,7 @@ async def update_masking_rule( is_active: bool | None = None, priority: int | None = None, description: str | None = None, - api_key: str = Depends(verify_api_key) + api_key: str = Depends(verify_api_key), ): """更新脱敏规则""" if not SECURITY_MANAGER_AVAILABLE: @@ -6978,15 +6784,12 @@ async def update_masking_rule( priority=rule.priority, description=rule.description, created_at=rule.created_at, - updated_at=rule.updated_at + updated_at=rule.updated_at, ) @app.delete("/api/v1/masking-rules/{rule_id}", tags=["Security"]) -async def delete_masking_rule( - rule_id: str, - api_key: str = Depends(verify_api_key) -): +async def delete_masking_rule(rule_id: str, api_key: str = Depends(verify_api_key)): """删除脱敏规则""" if not SECURITY_MANAGER_AVAILABLE: raise HTTPException(status_code=503, detail="Security manager not available") @@ -7001,11 +6804,7 @@ async def delete_masking_rule( @app.post("/api/v1/projects/{project_id}/masking/apply", response_model=MaskingApplyResponse, tags=["Security"]) -async def apply_masking( - project_id: str, - request: MaskingApplyRequest, - api_key: str = Depends(verify_api_key) -): +async def apply_masking(project_id: str, request: MaskingApplyRequest, api_key: str = Depends(verify_api_key)): """应用脱敏规则到文本""" if not SECURITY_MANAGER_AVAILABLE: raise HTTPException(status_code=503, detail="Security manager not available") @@ -7023,20 +6822,15 @@ async def apply_masking( rules = manager.get_masking_rules(project_id) applied_rules = [r.name for r in rules if r.is_active] - return MaskingApplyResponse( - original_text=request.text, - masked_text=masked_text, - applied_rules=applied_rules - ) + return MaskingApplyResponse(original_text=request.text, masked_text=masked_text, applied_rules=applied_rules) # ==================== Data Access Policy API ==================== + @app.post("/api/v1/projects/{project_id}/access-policies", response_model=AccessPolicyResponse, tags=["Security"]) async def create_access_policy( - project_id: str, - request: AccessPolicyCreateRequest, - api_key: str = Depends(verify_api_key) + project_id: str, request: AccessPolicyCreateRequest, api_key: str = Depends(verify_api_key) ): """创建数据访问策略""" if not SECURITY_MANAGER_AVAILABLE: @@ -7053,7 +6847,7 @@ async def create_access_policy( allowed_ips=request.allowed_ips, time_restrictions=request.time_restrictions, max_access_count=request.max_access_count, - require_approval=request.require_approval + require_approval=request.require_approval, ) return AccessPolicyResponse( @@ -7069,16 +6863,12 @@ async def create_access_policy( require_approval=policy.require_approval, is_active=policy.is_active, created_at=policy.created_at, - updated_at=policy.updated_at + updated_at=policy.updated_at, ) @app.get("/api/v1/projects/{project_id}/access-policies", response_model=list[AccessPolicyResponse], tags=["Security"]) -async def get_access_policies( - project_id: str, - active_only: bool = True, - api_key: str = Depends(verify_api_key) -): +async def get_access_policies(project_id: str, active_only: bool = True, api_key: str = Depends(verify_api_key)): """获取项目访问策略""" if not SECURITY_MANAGER_AVAILABLE: raise HTTPException(status_code=503, detail="Security manager not available") @@ -7100,7 +6890,7 @@ async def get_access_policies( require_approval=policy.require_approval, is_active=policy.is_active, created_at=policy.created_at, - updated_at=policy.updated_at + updated_at=policy.updated_at, ) for policy in policies ] @@ -7108,10 +6898,7 @@ async def get_access_policies( @app.post("/api/v1/access-policies/{policy_id}/check", tags=["Security"]) async def check_access_permission( - policy_id: str, - user_id: str, - user_ip: str | None = None, - api_key: str = Depends(verify_api_key) + policy_id: str, user_id: str, user_ip: str | None = None, api_key: str = Depends(verify_api_key) ): """检查访问权限""" if not SECURITY_MANAGER_AVAILABLE: @@ -7120,19 +6907,17 @@ async def check_access_permission( manager = get_security_manager() allowed, reason = manager.check_access_permission(policy_id, user_id, user_ip) - return { - "allowed": allowed, - "reason": reason if not allowed else None - } + return {"allowed": allowed, "reason": reason if not allowed else None} # ==================== Access Request API ==================== + @app.post("/api/v1/access-requests", response_model=AccessRequestResponse, tags=["Security"]) async def create_access_request( request: AccessRequestCreateRequest, user_id: str, # 实际应该从认证信息中获取 - api_key: str = Depends(verify_api_key) + api_key: str = Depends(verify_api_key), ): """创建访问请求""" if not SECURITY_MANAGER_AVAILABLE: @@ -7144,7 +6929,7 @@ async def create_access_request( policy_id=request.policy_id, user_id=user_id, request_reason=request.request_reason, - expires_hours=request.expires_hours + expires_hours=request.expires_hours, ) return AccessRequestResponse( @@ -7156,16 +6941,13 @@ async def create_access_request( approved_by=access_request.approved_by, approved_at=access_request.approved_at, expires_at=access_request.expires_at, - created_at=access_request.created_at + created_at=access_request.created_at, ) @app.post("/api/v1/access-requests/{request_id}/approve", response_model=AccessRequestResponse, tags=["Security"]) async def approve_access_request( - request_id: str, - approved_by: str, - expires_hours: int = 24, - api_key: str = Depends(verify_api_key) + request_id: str, approved_by: str, expires_hours: int = 24, api_key: str = Depends(verify_api_key) ): """批准访问请求""" if not SECURITY_MANAGER_AVAILABLE: @@ -7186,16 +6968,12 @@ async def approve_access_request( approved_by=access_request.approved_by, approved_at=access_request.approved_at, expires_at=access_request.expires_at, - created_at=access_request.created_at + created_at=access_request.created_at, ) @app.post("/api/v1/access-requests/{request_id}/reject", response_model=AccessRequestResponse, tags=["Security"]) -async def reject_access_request( - request_id: str, - rejected_by: str, - api_key: str = Depends(verify_api_key) -): +async def reject_access_request(request_id: str, rejected_by: str, api_key: str = Depends(verify_api_key)): """拒绝访问请求""" if not SECURITY_MANAGER_AVAILABLE: raise HTTPException(status_code=503, detail="Security manager not available") @@ -7215,7 +6993,7 @@ async def reject_access_request( approved_by=access_request.approved_by, approved_at=access_request.approved_at, expires_at=access_request.expires_at, - created_at=access_request.created_at + created_at=access_request.created_at, ) @@ -7225,6 +7003,7 @@ async def reject_access_request( # ----- 请求模型 ----- + class ShareLinkCreate(BaseModel): permission: str = "read_only" # read_only, comment, edit, admin expires_in_days: int | None = None @@ -7268,6 +7047,7 @@ class TeamMemberRoleUpdate(BaseModel): # ----- 项目分享 ----- + @app.post("/api/v1/projects/{project_id}/shares") async def create_share_link(project_id: str, request: ShareLinkCreate, created_by: str = "current_user"): """创建项目分享链接""" @@ -7283,7 +7063,7 @@ async def create_share_link(project_id: str, request: ShareLinkCreate, created_b max_uses=request.max_uses, password=request.password, allow_download=request.allow_download, - allow_export=request.allow_export + allow_export=request.allow_export, ) return { @@ -7293,7 +7073,7 @@ async def create_share_link(project_id: str, request: ShareLinkCreate, created_b "created_at": share.created_at, "expires_at": share.expires_at, "max_uses": share.max_uses, - "share_url": f"/share/{share.token}" + "share_url": f"/share/{share.token}", } @@ -7319,7 +7099,7 @@ async def list_project_shares(project_id: str): "is_active": s.is_active, "has_password": s.password_hash is not None, "allow_download": s.allow_download, - "allow_export": s.allow_export + "allow_export": s.allow_export, } for s in shares ] @@ -7346,7 +7126,7 @@ async def verify_share_link(request: ShareLinkVerify): "project_id": share.project_id, "permission": share.permission, "allow_download": share.allow_download, - "allow_export": share.allow_export + "allow_export": share.allow_export, } @@ -7380,11 +7160,11 @@ async def access_shared_project(token: str, password: str | None = None): "id": project.id, "name": project.name, "description": project.description, - "created_at": project.created_at + "created_at": project.created_at, }, "permission": share.permission, "allow_download": share.allow_download, - "allow_export": share.allow_export + "allow_export": share.allow_export, } @@ -7402,6 +7182,7 @@ async def revoke_share_link(share_id: str, revoked_by: str = "current_user"): return {"success": True, "message": "Share link revoked"} + # ----- 评论和批注 ----- @@ -7420,7 +7201,7 @@ async def add_comment(project_id: str, request: CommentCreate, author: str = "cu author_name=author_name, content=request.content, parent_id=request.parent_id, - mentions=request.mentions + mentions=request.mentions, ) return { @@ -7432,7 +7213,7 @@ async def add_comment(project_id: str, request: CommentCreate, author: str = "cu "author_name": comment.author_name, "content": comment.content, "created_at": comment.created_at, - "resolved": comment.resolved + "resolved": comment.resolved, } @@ -7458,10 +7239,10 @@ async def get_comments(target_type: str, target_id: str, include_resolved: bool "updated_at": c.updated_at, "resolved": c.resolved, "resolved_by": c.resolved_by, - "resolved_at": c.resolved_at + "resolved_at": c.resolved_at, } for c in comments - ] + ], } @@ -7486,10 +7267,10 @@ async def get_project_comments(project_id: str, limit: int = 50, offset: int = 0 "author_name": c.author_name, "content": c.content, "created_at": c.created_at, - "resolved": c.resolved + "resolved": c.resolved, } for c in comments - ] + ], } @@ -7505,11 +7286,7 @@ async def update_comment(comment_id: str, request: CommentUpdate, updated_by: st if not comment: raise HTTPException(status_code=404, detail="Comment not found or not authorized") - return { - "id": comment.id, - "content": comment.content, - "updated_at": comment.updated_at - } + return {"id": comment.id, "content": comment.content, "updated_at": comment.updated_at} @app.post("/api/v1/comments/{comment_id}/resolve") @@ -7541,16 +7318,13 @@ async def delete_comment(comment_id: str, deleted_by: str = "current_user"): return {"success": True, "message": "Comment deleted"} + # ----- 变更历史 ----- @app.get("/api/v1/projects/{project_id}/history") async def get_change_history( - project_id: str, - entity_type: str | None = None, - entity_id: str | None = None, - limit: int = 50, - offset: int = 0 + project_id: str, entity_type: str | None = None, entity_id: str | None = None, limit: int = 50, offset: int = 0 ): """获取变更历史""" if not COLLABORATION_AVAILABLE: @@ -7574,10 +7348,10 @@ async def get_change_history( "old_value": r.old_value, "new_value": r.new_value, "description": r.description, - "reverted": r.reverted + "reverted": r.reverted, } for r in records - ] + ], } @@ -7613,10 +7387,10 @@ async def get_entity_versions(entity_type: str, entity_id: str): "changed_at": r.changed_at, "old_value": r.old_value, "new_value": r.new_value, - "description": r.description + "description": r.description, } for r in records - ] + ], } @@ -7634,6 +7408,7 @@ async def revert_change(record_id: str, reverted_by: str = "current_user"): return {"success": True, "message": "Change reverted"} + # ----- 团队成员 ----- @@ -7650,7 +7425,7 @@ async def invite_team_member(project_id: str, request: TeamMemberInvite, invited user_name=request.user_name, user_email=request.user_email, role=request.role, - invited_by=invited_by + invited_by=invited_by, ) return { @@ -7660,7 +7435,7 @@ async def invite_team_member(project_id: str, request: TeamMemberInvite, invited "user_email": member.user_email, "role": member.role, "joined_at": member.joined_at, - "permissions": member.permissions + "permissions": member.permissions, } @@ -7684,10 +7459,10 @@ async def list_team_members(project_id: str): "role": m.role, "joined_at": m.joined_at, "last_active_at": m.last_active_at, - "permissions": m.permissions + "permissions": m.permissions, } for m in members - ] + ], } @@ -7737,23 +7512,17 @@ async def check_project_permissions(project_id: str, user_id: str = "current_use break if not user_member: - return { - "has_access": False, - "role": None, - "permissions": [] - } + return {"has_access": False, "role": None, "permissions": []} - return { - "has_access": True, - "role": user_member.role, - "permissions": user_member.permissions - } + return {"has_access": True, "role": user_member.role, "permissions": user_member.permissions} # ==================== Phase 7 Task 6: Advanced Search & Discovery ==================== + class FullTextSearchRequest(BaseModel): """全文搜索请求""" + query: str content_types: list[str] | None = None operator: str = "AND" # AND, OR, NOT @@ -7762,6 +7531,7 @@ class FullTextSearchRequest(BaseModel): class SemanticSearchRequest(BaseModel): """语义搜索请求""" + query: str content_types: list[str] | None = None threshold: float = 0.7 @@ -7769,11 +7539,7 @@ class SemanticSearchRequest(BaseModel): @app.post("/api/v1/search/fulltext", tags=["Search"]) -async def fulltext_search( - project_id: str, - request: FullTextSearchRequest, - _=Depends(verify_api_key) -): +async def fulltext_search(project_id: str, request: FullTextSearchRequest, _=Depends(verify_api_key)): """全文搜索""" if not SEARCH_MANAGER_AVAILABLE: raise HTTPException(status_code=500, detail="Search manager not available") @@ -7790,30 +7556,29 @@ async def fulltext_search( project_id=project_id, content_types=request.content_types, operator=operator, - limit=request.limit + limit=request.limit, ) return { "query": request.query, "operator": request.operator, "total": len(results), - "results": [{ - "id": r.id, - "type": r.type, - "title": r.title, - "content": r.content, - "highlights": r.highlights, - "score": r.score - } for r in results] + "results": [ + { + "id": r.id, + "type": r.type, + "title": r.title, + "content": r.content, + "highlights": r.highlights, + "score": r.score, + } + for r in results + ], } @app.post("/api/v1/search/semantic", tags=["Search"]) -async def semantic_search( - project_id: str, - request: SemanticSearchRequest, - _=Depends(verify_api_key) -): +async def semantic_search(project_id: str, request: SemanticSearchRequest, _=Depends(verify_api_key)): """语义搜索""" if not SEARCH_MANAGER_AVAILABLE: raise HTTPException(status_code=500, detail="Search manager not available") @@ -7825,29 +7590,20 @@ async def semantic_search( project_id=project_id, content_types=request.content_types, threshold=request.threshold, - limit=request.limit + limit=request.limit, ) return { "query": request.query, "threshold": request.threshold, "total": len(results), - "results": [{ - "id": r.id, - "type": r.type, - "text": r.text, - "similarity": r.similarity - } for r in results] + "results": [{"id": r.id, "type": r.type, "text": r.text, "similarity": r.similarity} for r in results], } @app.get("/api/v1/entities/{entity_id}/paths/{target_entity_id}", tags=["Search"]) async def find_entity_paths( - entity_id: str, - target_entity_id: str, - max_depth: int = 5, - find_all: bool = False, - _=Depends(verify_api_key) + entity_id: str, target_entity_id: str, max_depth: int = 5, find_all: bool = False, _=Depends(verify_api_key) ): """查找实体关系路径""" if not SEARCH_MANAGER_AVAILABLE: @@ -7857,15 +7613,11 @@ async def find_entity_paths( if find_all: paths = search_manager.path_discovery.find_all_paths( - source_entity_id=entity_id, - target_entity_id=target_entity_id, - max_depth=max_depth + source_entity_id=entity_id, target_entity_id=target_entity_id, max_depth=max_depth ) else: path = search_manager.path_discovery.find_shortest_path( - source_entity_id=entity_id, - target_entity_id=target_entity_id, - max_depth=max_depth + source_entity_id=entity_id, target_entity_id=target_entity_id, max_depth=max_depth ) paths = [path] if path else [] @@ -7873,22 +7625,21 @@ async def find_entity_paths( "source_entity_id": entity_id, "target_entity_id": target_entity_id, "path_count": len(paths), - "paths": [{ - "path_id": p.path_id, - "path_length": p.path_length, - "nodes": p.nodes, - "edges": p.edges, - "confidence": p.confidence - } for p in paths] + "paths": [ + { + "path_id": p.path_id, + "path_length": p.path_length, + "nodes": p.nodes, + "edges": p.edges, + "confidence": p.confidence, + } + for p in paths + ], } @app.get("/api/v1/entities/{entity_id}/network", tags=["Search"]) -async def get_entity_network( - entity_id: str, - depth: int = 2, - _=Depends(verify_api_key) -): +async def get_entity_network(entity_id: str, depth: int = 2, _=Depends(verify_api_key)): """获取实体关系网络""" if not SEARCH_MANAGER_AVAILABLE: raise HTTPException(status_code=500, detail="Search manager not available") @@ -7900,10 +7651,7 @@ async def get_entity_network( @app.get("/api/v1/projects/{project_id}/knowledge-gaps", tags=["Search"]) -async def detect_knowledge_gaps( - project_id: str, - _=Depends(verify_api_key) -): +async def detect_knowledge_gaps(project_id: str, _=Depends(verify_api_key)): """检测知识缺口""" if not SEARCH_MANAGER_AVAILABLE: raise HTTPException(status_code=500, detail="Search manager not available") @@ -7917,23 +7665,23 @@ async def detect_knowledge_gaps( "project_id": project_id, "completeness": completeness, "gap_count": len(gaps), - "gaps": [{ - "gap_id": g.gap_id, - "gap_type": g.gap_type, - "entity_id": g.entity_id, - "entity_name": g.entity_name, - "description": g.description, - "severity": g.severity, - "suggestion": g.suggestion - } for g in gaps] + "gaps": [ + { + "gap_id": g.gap_id, + "gap_type": g.gap_type, + "entity_id": g.entity_id, + "entity_name": g.entity_name, + "description": g.description, + "severity": g.severity, + "suggestion": g.suggestion, + } + for g in gaps + ], } @app.post("/api/v1/projects/{project_id}/search/index", tags=["Search"]) -async def index_project_for_search( - project_id: str, - _=Depends(verify_api_key) -): +async def index_project_for_search(project_id: str, _=Depends(verify_api_key)): """为项目创建搜索索引""" if not SEARCH_MANAGER_AVAILABLE: raise HTTPException(status_code=500, detail="Search manager not available") @@ -7949,10 +7697,9 @@ async def index_project_for_search( # ==================== Phase 7 Task 8: Performance & Scaling ==================== + @app.get("/api/v1/cache/stats", tags=["Performance"]) -async def get_cache_stats( - _=Depends(verify_api_key) -): +async def get_cache_stats(_=Depends(verify_api_key)): """获取缓存统计""" if not PERFORMANCE_MANAGER_AVAILABLE: raise HTTPException(status_code=500, detail="Performance manager not available") @@ -7967,15 +7714,12 @@ async def get_cache_stats( "miss_count": stats.miss_count, "hit_rate": stats.hit_rate, "evicted_count": stats.evicted_count, - "expired_count": stats.expired_count + "expired_count": stats.expired_count, } @app.post("/api/v1/cache/clear", tags=["Performance"]) -async def clear_cache( - pattern: str | None = None, - _=Depends(verify_api_key) -): +async def clear_cache(pattern: str | None = None, _=Depends(verify_api_key)): """清除缓存""" if not PERFORMANCE_MANAGER_AVAILABLE: raise HTTPException(status_code=500, detail="Performance manager not available") @@ -7995,7 +7739,7 @@ async def get_performance_metrics( endpoint: str | None = None, hours: int = 24, limit: int = 1000, - _=Depends(verify_api_key) + _=Depends(verify_api_key), ): """获取性能指标""" if not PERFORMANCE_MANAGER_AVAILABLE: @@ -8006,31 +7750,28 @@ async def get_performance_metrics( start_time = (datetime.now() - timedelta(hours=hours)).isoformat() metrics = perf_manager.monitor.get_metrics( - metric_type=metric_type, - endpoint=endpoint, - start_time=start_time, - limit=limit + metric_type=metric_type, endpoint=endpoint, start_time=start_time, limit=limit ) return { "period_hours": hours, "total": len(metrics), - "metrics": [{ - "id": m.id, - "metric_type": m.metric_type, - "endpoint": m.endpoint, - "duration_ms": m.duration_ms, - "status_code": m.status_code, - "timestamp": m.timestamp - } for m in metrics] + "metrics": [ + { + "id": m.id, + "metric_type": m.metric_type, + "endpoint": m.endpoint, + "duration_ms": m.duration_ms, + "status_code": m.status_code, + "timestamp": m.timestamp, + } + for m in metrics + ], } @app.get("/api/v1/performance/summary", tags=["Performance"]) -async def get_performance_summary( - hours: int = 24, - _=Depends(verify_api_key) -): +async def get_performance_summary(hours: int = 24, _=Depends(verify_api_key)): """获取性能汇总统计""" if not PERFORMANCE_MANAGER_AVAILABLE: raise HTTPException(status_code=500, detail="Performance manager not available") @@ -8042,10 +7783,7 @@ async def get_performance_summary( @app.get("/api/v1/tasks/{task_id}/status", tags=["Performance"]) -async def get_task_status( - task_id: str, - _=Depends(verify_api_key) -): +async def get_task_status(task_id: str, _=Depends(verify_api_key)): """获取任务状态""" if not PERFORMANCE_MANAGER_AVAILABLE: raise HTTPException(status_code=500, detail="Performance manager not available") @@ -8068,16 +7806,13 @@ async def get_task_status( "started_at": task.started_at, "completed_at": task.completed_at, "retry_count": task.retry_count, - "priority": task.priority + "priority": task.priority, } @app.get("/api/v1/tasks", tags=["Performance"]) async def list_tasks( - project_id: str | None = None, - status: str | None = None, - limit: int = 50, - _=Depends(verify_api_key) + project_id: str | None = None, status: str | None = None, limit: int = 50, _=Depends(verify_api_key) ): """列出任务""" if not PERFORMANCE_MANAGER_AVAILABLE: @@ -8088,23 +7823,23 @@ async def list_tasks( return { "total": len(tasks), - "tasks": [{ - "task_id": t.task_id, - "task_type": t.task_type, - "status": t.status, - "project_id": t.project_id, - "created_at": t.created_at, - "retry_count": t.retry_count, - "priority": t.priority - } for t in tasks] + "tasks": [ + { + "task_id": t.task_id, + "task_type": t.task_type, + "status": t.status, + "project_id": t.project_id, + "created_at": t.created_at, + "retry_count": t.retry_count, + "priority": t.priority, + } + for t in tasks + ], } @app.post("/api/v1/tasks/{task_id}/cancel", tags=["Performance"]) -async def cancel_task( - task_id: str, - _=Depends(verify_api_key) -): +async def cancel_task(task_id: str, _=Depends(verify_api_key)): """取消任务""" if not PERFORMANCE_MANAGER_AVAILABLE: raise HTTPException(status_code=500, detail="Performance manager not available") @@ -8119,9 +7854,7 @@ async def cancel_task( @app.get("/api/v1/shards", tags=["Performance"]) -async def list_shards( - _=Depends(verify_api_key) -): +async def list_shards(_=Depends(verify_api_key)): """列出数据库分片""" if not PERFORMANCE_MANAGER_AVAILABLE: raise HTTPException(status_code=500, detail="Performance manager not available") @@ -8131,12 +7864,10 @@ async def list_shards( return { "shard_count": len(shards), - "shards": [{ - "shard_id": s.shard_id, - "entity_count": s.entity_count, - "db_path": s.db_path, - "created_at": s.created_at - } for s in shards] + "shards": [ + {"shard_id": s.shard_id, "entity_count": s.entity_count, "db_path": s.db_path, "created_at": s.created_at} + for s in shards + ], } @@ -8144,6 +7875,7 @@ async def list_shards( # Phase 8: Multi-Tenant SaaS APIs # ============================================ + class CreateTenantRequest(BaseModel): name: str description: str | None = None @@ -8184,9 +7916,7 @@ class UpdateMemberRequest(BaseModel): # Tenant Management APIs @app.post("/api/v1/tenants", tags=["Tenants"]) async def create_tenant( - request: CreateTenantRequest, - user_id: str = Header(..., description="当前用户ID"), - _=Depends(verify_api_key) + request: CreateTenantRequest, user_id: str = Header(..., description="当前用户ID"), _=Depends(verify_api_key) ): """创建新租户""" if not TENANT_MANAGER_AVAILABLE: @@ -8195,10 +7925,7 @@ async def create_tenant( manager = get_tenant_manager() try: tenant = manager.create_tenant( - name=request.name, - owner_id=user_id, - tier=request.tier, - description=request.description + name=request.name, owner_id=user_id, tier=request.tier, description=request.description ) return { "id": tenant.id, @@ -8206,17 +7933,14 @@ async def create_tenant( "slug": tenant.slug, "tier": tenant.tier, "status": tenant.status, - "created_at": tenant.created_at.isoformat() + "created_at": tenant.created_at.isoformat(), } except Exception as e: raise HTTPException(status_code=400, detail=str(e)) @app.get("/api/v1/tenants", tags=["Tenants"]) -async def list_my_tenants( - user_id: str = Header(..., description="当前用户ID"), - _=Depends(verify_api_key) -): +async def list_my_tenants(user_id: str = Header(..., description="当前用户ID"), _=Depends(verify_api_key)): """获取当前用户的所有租户""" if not TENANT_MANAGER_AVAILABLE: raise HTTPException(status_code=500, detail="Tenant manager not available") @@ -8227,10 +7951,7 @@ async def list_my_tenants( @app.get("/api/v1/tenants/{tenant_id}", tags=["Tenants"]) -async def get_tenant( - tenant_id: str, - _=Depends(verify_api_key) -): +async def get_tenant(tenant_id: str, _=Depends(verify_api_key)): """获取租户详情""" if not TENANT_MANAGER_AVAILABLE: raise HTTPException(status_code=500, detail="Tenant manager not available") @@ -8251,16 +7972,12 @@ async def get_tenant( "owner_id": tenant.owner_id, "created_at": tenant.created_at.isoformat(), "settings": tenant.settings, - "resource_limits": tenant.resource_limits + "resource_limits": tenant.resource_limits, } @app.put("/api/v1/tenants/{tenant_id}", tags=["Tenants"]) -async def update_tenant( - tenant_id: str, - request: UpdateTenantRequest, - _=Depends(verify_api_key) -): +async def update_tenant(tenant_id: str, request: UpdateTenantRequest, _=Depends(verify_api_key)): """更新租户信息""" if not TENANT_MANAGER_AVAILABLE: raise HTTPException(status_code=500, detail="Tenant manager not available") @@ -8271,7 +7988,7 @@ async def update_tenant( name=request.name, description=request.description, tier=request.tier, - status=request.status + status=request.status, ) if not tenant: @@ -8283,15 +8000,12 @@ async def update_tenant( "slug": tenant.slug, "tier": tenant.tier, "status": tenant.status, - "updated_at": tenant.updated_at.isoformat() + "updated_at": tenant.updated_at.isoformat(), } @app.delete("/api/v1/tenants/{tenant_id}", tags=["Tenants"]) -async def delete_tenant( - tenant_id: str, - _=Depends(verify_api_key) -): +async def delete_tenant(tenant_id: str, _=Depends(verify_api_key)): """删除租户""" if not TENANT_MANAGER_AVAILABLE: raise HTTPException(status_code=500, detail="Tenant manager not available") @@ -8307,22 +8021,14 @@ async def delete_tenant( # Domain Management APIs @app.post("/api/v1/tenants/{tenant_id}/domains", tags=["Tenants"]) -async def add_domain( - tenant_id: str, - request: AddDomainRequest, - _=Depends(verify_api_key) -): +async def add_domain(tenant_id: str, request: AddDomainRequest, _=Depends(verify_api_key)): """为租户添加自定义域名""" if not TENANT_MANAGER_AVAILABLE: raise HTTPException(status_code=500, detail="Tenant manager not available") manager = get_tenant_manager() try: - domain = manager.add_domain( - tenant_id=tenant_id, - domain=request.domain, - is_primary=request.is_primary - ) + domain = manager.add_domain(tenant_id=tenant_id, domain=request.domain, is_primary=request.is_primary) # 获取验证指导 instructions = manager.get_domain_verification_instructions(domain.id) @@ -8334,17 +8040,14 @@ async def add_domain( "is_primary": domain.is_primary, "verification_token": domain.verification_token, "verification_instructions": instructions, - "created_at": domain.created_at.isoformat() + "created_at": domain.created_at.isoformat(), } except Exception as e: raise HTTPException(status_code=400, detail=str(e)) @app.get("/api/v1/tenants/{tenant_id}/domains", tags=["Tenants"]) -async def list_domains( - tenant_id: str, - _=Depends(verify_api_key) -): +async def list_domains(tenant_id: str, _=Depends(verify_api_key)): """列出租户的所有域名""" if not TENANT_MANAGER_AVAILABLE: raise HTTPException(status_code=500, detail="Tenant manager not available") @@ -8353,24 +8056,23 @@ async def list_domains( domains = manager.list_domains(tenant_id) return { - "domains": [{ - "id": d.id, - "domain": d.domain, - "status": d.status, - "is_primary": d.is_primary, - "ssl_enabled": d.ssl_enabled, - "verified_at": d.verified_at.isoformat() if d.verified_at else None, - "created_at": d.created_at.isoformat() - } for d in domains] + "domains": [ + { + "id": d.id, + "domain": d.domain, + "status": d.status, + "is_primary": d.is_primary, + "ssl_enabled": d.ssl_enabled, + "verified_at": d.verified_at.isoformat() if d.verified_at else None, + "created_at": d.created_at.isoformat(), + } + for d in domains + ] } @app.post("/api/v1/tenants/{tenant_id}/domains/{domain_id}/verify", tags=["Tenants"]) -async def verify_domain( - tenant_id: str, - domain_id: str, - _=Depends(verify_api_key) -): +async def verify_domain(tenant_id: str, domain_id: str, _=Depends(verify_api_key)): """验证域名所有权""" if not TENANT_MANAGER_AVAILABLE: raise HTTPException(status_code=500, detail="Tenant manager not available") @@ -8378,18 +8080,11 @@ async def verify_domain( manager = get_tenant_manager() success = manager.verify_domain(tenant_id, domain_id) - return { - "success": success, - "message": "Domain verified successfully" if success else "Domain verification failed" - } + return {"success": success, "message": "Domain verified successfully" if success else "Domain verification failed"} @app.delete("/api/v1/tenants/{tenant_id}/domains/{domain_id}", tags=["Tenants"]) -async def remove_domain( - tenant_id: str, - domain_id: str, - _=Depends(verify_api_key) -): +async def remove_domain(tenant_id: str, domain_id: str, _=Depends(verify_api_key)): """移除域名绑定""" if not TENANT_MANAGER_AVAILABLE: raise HTTPException(status_code=500, detail="Tenant manager not available") @@ -8405,10 +8100,7 @@ async def remove_domain( # Branding APIs @app.get("/api/v1/tenants/{tenant_id}/branding", tags=["Tenants"]) -async def get_branding( - tenant_id: str, - _=Depends(verify_api_key) -): +async def get_branding(tenant_id: str, _=Depends(verify_api_key)): """获取租户品牌配置""" if not TENANT_MANAGER_AVAILABLE: raise HTTPException(status_code=500, detail="Tenant manager not available") @@ -8423,7 +8115,7 @@ async def get_branding( "favicon_url": None, "primary_color": None, "secondary_color": None, - "custom_css": None + "custom_css": None, } return { @@ -8434,16 +8126,12 @@ async def get_branding( "secondary_color": branding.secondary_color, "custom_css": branding.custom_css, "custom_js": branding.custom_js, - "login_page_bg": branding.login_page_bg + "login_page_bg": branding.login_page_bg, } @app.put("/api/v1/tenants/{tenant_id}/branding", tags=["Tenants"]) -async def update_branding( - tenant_id: str, - request: UpdateBrandingRequest, - _=Depends(verify_api_key) -): +async def update_branding(tenant_id: str, request: UpdateBrandingRequest, _=Depends(verify_api_key)): """更新租户品牌配置""" if not TENANT_MANAGER_AVAILABLE: raise HTTPException(status_code=500, detail="Tenant manager not available") @@ -8457,7 +8145,7 @@ async def update_branding( secondary_color=request.secondary_color, custom_css=request.custom_css, custom_js=request.custom_js, - login_page_bg=request.login_page_bg + login_page_bg=request.login_page_bg, ) return { @@ -8466,7 +8154,7 @@ async def update_branding( "favicon_url": branding.favicon_url, "primary_color": branding.primary_color, "secondary_color": branding.secondary_color, - "updated_at": branding.updated_at.isoformat() + "updated_at": branding.updated_at.isoformat(), } @@ -8480,6 +8168,7 @@ async def get_branding_css(tenant_id: str): css = manager.get_branding_css(tenant_id) from fastapi.responses import PlainTextResponse + return PlainTextResponse(content=css, media_type="text/css") @@ -8489,7 +8178,7 @@ async def invite_member( tenant_id: str, request: InviteMemberRequest, user_id: str = Header(..., description="邀请者用户ID"), - _=Depends(verify_api_key) + _=Depends(verify_api_key), ): """邀请成员加入租户""" if not TENANT_MANAGER_AVAILABLE: @@ -8497,30 +8186,21 @@ async def invite_member( manager = get_tenant_manager() try: - member = manager.invite_member( - tenant_id=tenant_id, - email=request.email, - role=request.role, - invited_by=user_id - ) + member = manager.invite_member(tenant_id=tenant_id, email=request.email, role=request.role, invited_by=user_id) return { "id": member.id, "email": member.email, "role": member.role, "status": member.status, - "invited_at": member.invited_at.isoformat() + "invited_at": member.invited_at.isoformat(), } except Exception as e: raise HTTPException(status_code=400, detail=str(e)) @app.get("/api/v1/tenants/{tenant_id}/members", tags=["Tenants"]) -async def list_members( - tenant_id: str, - status: str | None = None, - _=Depends(verify_api_key) -): +async def list_members(tenant_id: str, status: str | None = None, _=Depends(verify_api_key)): """列出租户成员""" if not TENANT_MANAGER_AVAILABLE: raise HTTPException(status_code=500, detail="Tenant manager not available") @@ -8529,27 +8209,25 @@ async def list_members( members = manager.list_members(tenant_id, status) return { - "members": [{ - "id": m.id, - "user_id": m.user_id, - "email": m.email, - "role": m.role, - "status": m.status, - "permissions": m.permissions, - "invited_at": m.invited_at.isoformat(), - "joined_at": m.joined_at.isoformat() if m.joined_at else None, - "last_active_at": m.last_active_at.isoformat() if m.last_active_at else None - } for m in members] + "members": [ + { + "id": m.id, + "user_id": m.user_id, + "email": m.email, + "role": m.role, + "status": m.status, + "permissions": m.permissions, + "invited_at": m.invited_at.isoformat(), + "joined_at": m.joined_at.isoformat() if m.joined_at else None, + "last_active_at": m.last_active_at.isoformat() if m.last_active_at else None, + } + for m in members + ] } @app.put("/api/v1/tenants/{tenant_id}/members/{member_id}", tags=["Tenants"]) -async def update_member( - tenant_id: str, - member_id: str, - request: UpdateMemberRequest, - _=Depends(verify_api_key) -): +async def update_member(tenant_id: str, member_id: str, request: UpdateMemberRequest, _=Depends(verify_api_key)): """更新成员角色""" if not TENANT_MANAGER_AVAILABLE: raise HTTPException(status_code=500, detail="Tenant manager not available") @@ -8564,11 +8242,7 @@ async def update_member( @app.delete("/api/v1/tenants/{tenant_id}/members/{member_id}", tags=["Tenants"]) -async def remove_member( - tenant_id: str, - member_id: str, - _=Depends(verify_api_key) -): +async def remove_member(tenant_id: str, member_id: str, _=Depends(verify_api_key)): """移除成员""" if not TENANT_MANAGER_AVAILABLE: raise HTTPException(status_code=500, detail="Tenant manager not available") @@ -8584,10 +8258,7 @@ async def remove_member( # Usage & Limits APIs @app.get("/api/v1/tenants/{tenant_id}/usage", tags=["Tenants"]) -async def get_tenant_usage( - tenant_id: str, - _=Depends(verify_api_key) -): +async def get_tenant_usage(tenant_id: str, _=Depends(verify_api_key)): """获取租户资源使用统计""" if not TENANT_MANAGER_AVAILABLE: raise HTTPException(status_code=500, detail="Tenant manager not available") @@ -8599,11 +8270,7 @@ async def get_tenant_usage( @app.get("/api/v1/tenants/{tenant_id}/limits/{resource_type}", tags=["Tenants"]) -async def check_resource_limit( - tenant_id: str, - resource_type: str, - _=Depends(verify_api_key) -): +async def check_resource_limit(tenant_id: str, resource_type: str, _=Depends(verify_api_key)): """检查特定资源是否超限""" if not TENANT_MANAGER_AVAILABLE: raise HTTPException(status_code=500, detail="Tenant manager not available") @@ -8616,7 +8283,7 @@ async def check_resource_limit( "allowed": allowed, "current": current, "limit": limit, - "usage_percentage": round(current / limit * 100, 2) if limit > 0 else 0 + "usage_percentage": round(current / limit * 100, 2) if limit > 0 else 0, } @@ -8643,19 +8310,15 @@ async def resolve_tenant_by_domain(domain: str): "branding": { "logo_url": branding.logo_url if branding else None, "primary_color": branding.primary_color if branding else None, - "favicon_url": branding.favicon_url if branding else None - } + "favicon_url": branding.favicon_url if branding else None, + }, } @app.get("/api/v1/health", tags=["System"]) async def detailed_health_check(): """健康检查""" - health = { - "status": "healthy", - "timestamp": datetime.now().isoformat(), - "components": {} - } + health = {"status": "healthy", "timestamp": datetime.now().isoformat(), "components": {}} # 数据库检查 if DB_AVAILABLE: @@ -8700,6 +8363,7 @@ async def detailed_health_check(): # ==================== Phase 8: Multi-Tenant SaaS API ==================== + # Pydantic Models for Tenant API class TenantCreate(BaseModel): name: str = Field(..., description="租户名称") @@ -8829,7 +8493,7 @@ async def create_tenant_endpoint(tenant: TenantCreate, request: Request, _=Depen # 获取当前用户ID(从请求状态或API Key) user_id = "" - if hasattr(request.state, 'api_key') and request.state.api_key: + if hasattr(request.state, "api_key") and request.state.api_key: user_id = request.state.api_key.created_by or "" try: @@ -8839,7 +8503,7 @@ async def create_tenant_endpoint(tenant: TenantCreate, request: Request, _=Depen created_by=user_id, description=tenant.description, plan=TenantTier(tenant.plan), - billing_email=tenant.billing_email + billing_email=tenant.billing_email, ) return new_tenant.to_dict() except ValueError as e: @@ -8848,11 +8512,7 @@ async def create_tenant_endpoint(tenant: TenantCreate, request: Request, _=Depen @app.get("/api/v1/tenants", response_model=list[TenantResponse], tags=["Tenants"]) async def list_tenants_endpoint( - status: str | None = None, - plan: str | None = None, - limit: int = 100, - offset: int = 0, - _=Depends(verify_api_key) + status: str | None = None, plan: str | None = None, limit: int = 100, offset: int = 0, _=Depends(verify_api_key) ): """列出租户""" if not TENANT_MANAGER_AVAILABLE: @@ -8863,12 +8523,7 @@ async def list_tenants_endpoint( status_enum = TenantStatus(status) if status else None plan_enum = TenantTier(plan) if plan else None - tenants = tenant_manager.list_tenants( - status=status_enum, - plan=plan_enum, - limit=limit, - offset=offset - ) + tenants = tenant_manager.list_tenants(status=status_enum, plan=plan_enum, limit=limit, offset=offset) return [t.to_dict() for t in tenants] @@ -9031,11 +8686,7 @@ async def get_tenant_branding_endpoint(tenant_id: str, _=Depends(verify_api_key) @app.put("/api/v1/tenants/{tenant_id}/branding", tags=["Tenants"]) -async def update_tenant_branding_endpoint( - tenant_id: str, - branding: TenantBrandingUpdate, - _=Depends(verify_api_key) -): +async def update_tenant_branding_endpoint(tenant_id: str, branding: TenantBrandingUpdate, _=Depends(verify_api_key)): """更新租户品牌配置""" if not TENANT_MANAGER_AVAILABLE: raise HTTPException(status_code=500, detail="Tenant manager not available") @@ -9064,17 +8715,13 @@ async def get_tenant_theme_css_endpoint(tenant_id: str): if not branding: raise HTTPException(status_code=404, detail="Branding not found") - from fastapi.responses import PlainTextResponse return PlainTextResponse(content=branding.get_theme_css(), media_type="text/css") # Tenant Member API @app.post("/api/v1/tenants/{tenant_id}/members/invite", response_model=TenantMemberResponse, tags=["Tenants"]) async def invite_tenant_member_endpoint( - tenant_id: str, - invite: TenantMemberInvite, - request: Request, - _=Depends(verify_api_key) + tenant_id: str, invite: TenantMemberInvite, request: Request, _=Depends(verify_api_key) ): """邀请成员加入租户""" if not TENANT_MANAGER_AVAILABLE: @@ -9084,7 +8731,7 @@ async def invite_tenant_member_endpoint( # 获取当前用户ID invited_by = "" - if hasattr(request.state, 'api_key') and request.state.api_key: + if hasattr(request.state, "api_key") and request.state.api_key: invited_by = request.state.api_key.created_by or "" try: @@ -9093,7 +8740,7 @@ async def invite_tenant_member_endpoint( email=invite.email, role=TenantRole(invite.role), invited_by=invited_by, - name=invite.name + name=invite.name, ) return member.to_dict() except ValueError as e: @@ -9117,10 +8764,7 @@ async def accept_invitation_endpoint(token: str, user_id: str): @app.get("/api/v1/tenants/{tenant_id}/members", response_model=list[TenantMemberResponse], tags=["Tenants"]) async def list_tenant_members_endpoint( - tenant_id: str, - status: str | None = None, - role: str | None = None, - _=Depends(verify_api_key) + tenant_id: str, status: str | None = None, role: str | None = None, _=Depends(verify_api_key) ): """列出租户成员""" if not TENANT_MANAGER_AVAILABLE: @@ -9137,11 +8781,7 @@ async def list_tenant_members_endpoint( @app.put("/api/v1/tenants/{tenant_id}/members/{member_id}/role", tags=["Tenants"]) async def update_member_role_endpoint( - tenant_id: str, - member_id: str, - role: str, - request: Request, - _=Depends(verify_api_key) + tenant_id: str, member_id: str, role: str, request: Request, _=Depends(verify_api_key) ): """更新成员角色""" if not TENANT_MANAGER_AVAILABLE: @@ -9151,15 +8791,12 @@ async def update_member_role_endpoint( # 获取当前用户ID updated_by = "" - if hasattr(request.state, 'api_key') and request.state.api_key: + if hasattr(request.state, "api_key") and request.state.api_key: updated_by = request.state.api_key.created_by or "" try: updated = tenant_manager.update_member_role( - tenant_id=tenant_id, - member_id=member_id, - new_role=TenantRole(role), - updated_by=updated_by + tenant_id=tenant_id, member_id=member_id, new_role=TenantRole(role), updated_by=updated_by ) if not updated: raise HTTPException(status_code=404, detail="Member not found") @@ -9169,12 +8806,7 @@ async def update_member_role_endpoint( @app.delete("/api/v1/tenants/{tenant_id}/members/{member_id}", tags=["Tenants"]) -async def remove_tenant_member_endpoint( - tenant_id: str, - member_id: str, - request: Request, - _=Depends(verify_api_key) -): +async def remove_tenant_member_endpoint(tenant_id: str, member_id: str, request: Request, _=Depends(verify_api_key)): """移除租户成员""" if not TENANT_MANAGER_AVAILABLE: raise HTTPException(status_code=500, detail="Tenant manager not available") @@ -9183,7 +8815,7 @@ async def remove_tenant_member_endpoint( # 获取当前用户ID removed_by = "" - if hasattr(request.state, 'api_key') and request.state.api_key: + if hasattr(request.state, "api_key") and request.state.api_key: removed_by = request.state.api_key.created_by or "" try: @@ -9208,11 +8840,7 @@ async def list_tenant_roles_endpoint(tenant_id: str, _=Depends(verify_api_key)): @app.post("/api/v1/tenants/{tenant_id}/roles", response_model=TenantRoleResponse, tags=["Tenants"]) -async def create_tenant_role_endpoint( - tenant_id: str, - role: TenantRoleCreate, - _=Depends(verify_api_key) -): +async def create_tenant_role_endpoint(tenant_id: str, role: TenantRoleCreate, _=Depends(verify_api_key)): """创建自定义角色""" if not TENANT_MANAGER_AVAILABLE: raise HTTPException(status_code=500, detail="Tenant manager not available") @@ -9221,10 +8849,7 @@ async def create_tenant_role_endpoint( try: new_role = tenant_manager.create_custom_role( - tenant_id=tenant_id, - name=role.name, - description=role.description, - permissions=role.permissions + tenant_id=tenant_id, name=role.name, description=role.description, permissions=role.permissions ) return new_role.to_dict() except ValueError as e: @@ -9233,10 +8858,7 @@ async def create_tenant_role_endpoint( @app.put("/api/v1/tenants/{tenant_id}/roles/{role_id}/permissions", tags=["Tenants"]) async def update_role_permissions_endpoint( - tenant_id: str, - role_id: str, - permissions: list[str], - _=Depends(verify_api_key) + tenant_id: str, role_id: str, permissions: list[str], _=Depends(verify_api_key) ): """更新角色权限""" if not TENANT_MANAGER_AVAILABLE: @@ -9277,32 +8899,20 @@ async def list_tenant_permissions_endpoint(_=Depends(verify_api_key)): raise HTTPException(status_code=500, detail="Tenant manager not available") tenant_manager = get_tenant_manager() - return { - "permissions": [ - {"id": k, "name": v} - for k, v in tenant_manager.PERMISSION_NAMES.items() - ] - } + return {"permissions": [{"id": k, "name": v} for k, v in tenant_manager.PERMISSION_NAMES.items()]} # Tenant Resolution API @app.get("/api/v1/tenants/resolve", tags=["Tenants"]) async def resolve_tenant_endpoint( - host: str | None = None, - slug: str | None = None, - tenant_id: str | None = None, - _=Depends(verify_api_key) + host: str | None = None, slug: str | None = None, tenant_id: str | None = None, _=Depends(verify_api_key) ): """从请求信息解析租户""" if not TENANT_MANAGER_AVAILABLE: raise HTTPException(status_code=500, detail="Tenant manager not available") tenant_manager = get_tenant_manager() - tenant = tenant_manager.resolve_tenant_from_request( - host=host, - slug=slug, - tenant_id=tenant_id - ) + tenant = tenant_manager.resolve_tenant_from_request(host=host, slug=slug, tenant_id=tenant_id) if not tenant: raise HTTPException(status_code=404, detail="Tenant not found") @@ -9329,6 +8939,7 @@ async def get_tenant_context_endpoint(tenant_id: str, _=Depends(verify_api_key)) # Phase 8 Task 2: Subscription & Billing APIs # ============================================ + # Pydantic Models for Subscription API class CreateSubscriptionRequest(BaseModel): plan_id: str = Field(..., description="订阅计划ID") @@ -9381,8 +8992,7 @@ class CreateCheckoutSessionRequest(BaseModel): # Subscription Plan APIs @app.get("/api/v1/subscription-plans", tags=["Subscriptions"]) async def list_subscription_plans( - include_inactive: bool = Query(default=False, description="包含已停用计划"), - _=Depends(verify_api_key) + include_inactive: bool = Query(default=False, description="包含已停用计划"), _=Depends(verify_api_key) ): """获取所有订阅计划""" if not SUBSCRIPTION_MANAGER_AVAILABLE: @@ -9403,7 +9013,7 @@ async def list_subscription_plans( "currency": p.currency, "features": p.features, "limits": p.limits, - "is_active": p.is_active + "is_active": p.is_active, } for p in plans ] @@ -9411,10 +9021,7 @@ async def list_subscription_plans( @app.get("/api/v1/subscription-plans/{plan_id}", tags=["Subscriptions"]) -async def get_subscription_plan( - plan_id: str, - _=Depends(verify_api_key) -): +async def get_subscription_plan(plan_id: str, _=Depends(verify_api_key)): """获取订阅计划详情""" if not SUBSCRIPTION_MANAGER_AVAILABLE: raise HTTPException(status_code=500, detail="Subscription manager not available") @@ -9436,7 +9043,7 @@ async def get_subscription_plan( "features": plan.features, "limits": plan.limits, "is_active": plan.is_active, - "created_at": plan.created_at.isoformat() + "created_at": plan.created_at.isoformat(), } @@ -9446,7 +9053,7 @@ async def create_subscription( tenant_id: str, request: CreateSubscriptionRequest, user_id: str = Header(..., description="当前用户ID"), - _=Depends(verify_api_key) + _=Depends(verify_api_key), ): """创建新订阅""" if not SUBSCRIPTION_MANAGER_AVAILABLE: @@ -9459,7 +9066,7 @@ async def create_subscription( plan_id=request.plan_id, payment_provider=request.payment_provider, trial_days=request.trial_days, - billing_cycle=request.billing_cycle + billing_cycle=request.billing_cycle, ) return { @@ -9471,17 +9078,14 @@ async def create_subscription( "current_period_end": subscription.current_period_end.isoformat(), "trial_start": subscription.trial_start.isoformat() if subscription.trial_start else None, "trial_end": subscription.trial_end.isoformat() if subscription.trial_end else None, - "created_at": subscription.created_at.isoformat() + "created_at": subscription.created_at.isoformat(), } except Exception as e: raise HTTPException(status_code=400, detail=str(e)) @app.get("/api/v1/tenants/{tenant_id}/subscription", tags=["Subscriptions"]) -async def get_tenant_subscription( - tenant_id: str, - _=Depends(verify_api_key) -): +async def get_tenant_subscription(tenant_id: str, _=Depends(verify_api_key)): """获取租户当前订阅""" if not SUBSCRIPTION_MANAGER_AVAILABLE: raise HTTPException(status_code=500, detail="Subscription manager not available") @@ -9508,17 +9112,13 @@ async def get_tenant_subscription( "canceled_at": subscription.canceled_at.isoformat() if subscription.canceled_at else None, "trial_start": subscription.trial_start.isoformat() if subscription.trial_start else None, "trial_end": subscription.trial_end.isoformat() if subscription.trial_end else None, - "created_at": subscription.created_at.isoformat() + "created_at": subscription.created_at.isoformat(), } } @app.put("/api/v1/tenants/{tenant_id}/subscription/change-plan", tags=["Subscriptions"]) -async def change_subscription_plan( - tenant_id: str, - request: ChangePlanRequest, - _=Depends(verify_api_key) -): +async def change_subscription_plan(tenant_id: str, request: ChangePlanRequest, _=Depends(verify_api_key)): """更改订阅计划""" if not SUBSCRIPTION_MANAGER_AVAILABLE: raise HTTPException(status_code=500, detail="Subscription manager not available") @@ -9531,27 +9131,21 @@ async def change_subscription_plan( try: updated = manager.change_plan( - subscription_id=subscription.id, - new_plan_id=request.new_plan_id, - prorate=request.prorate + subscription_id=subscription.id, new_plan_id=request.new_plan_id, prorate=request.prorate ) return { "id": updated.id, "plan_id": updated.plan_id, "status": updated.status, - "message": "Plan changed successfully" + "message": "Plan changed successfully", } except Exception as e: raise HTTPException(status_code=400, detail=str(e)) @app.post("/api/v1/tenants/{tenant_id}/subscription/cancel", tags=["Subscriptions"]) -async def cancel_subscription( - tenant_id: str, - request: CancelSubscriptionRequest, - _=Depends(verify_api_key) -): +async def cancel_subscription(tenant_id: str, request: CancelSubscriptionRequest, _=Depends(verify_api_key)): """取消订阅""" if not SUBSCRIPTION_MANAGER_AVAILABLE: raise HTTPException(status_code=500, detail="Subscription manager not available") @@ -9563,17 +9157,14 @@ async def cancel_subscription( raise HTTPException(status_code=404, detail="No active subscription found") try: - updated = manager.cancel_subscription( - subscription_id=subscription.id, - at_period_end=request.at_period_end - ) + updated = manager.cancel_subscription(subscription_id=subscription.id, at_period_end=request.at_period_end) return { "id": updated.id, "status": updated.status, "cancel_at_period_end": updated.cancel_at_period_end, "canceled_at": updated.canceled_at.isoformat() if updated.canceled_at else None, - "message": "Subscription cancelled" + "message": "Subscription cancelled", } except Exception as e: raise HTTPException(status_code=400, detail=str(e)) @@ -9581,11 +9172,7 @@ async def cancel_subscription( # Usage APIs @app.post("/api/v1/tenants/{tenant_id}/usage", tags=["Subscriptions"]) -async def record_usage( - tenant_id: str, - request: RecordUsageRequest, - _=Depends(verify_api_key) -): +async def record_usage(tenant_id: str, request: RecordUsageRequest, _=Depends(verify_api_key)): """记录用量""" if not SUBSCRIPTION_MANAGER_AVAILABLE: raise HTTPException(status_code=500, detail="Subscription manager not available") @@ -9596,7 +9183,7 @@ async def record_usage( resource_type=request.resource_type, quantity=request.quantity, unit=request.unit, - description=request.description + description=request.description, ) return { @@ -9606,7 +9193,7 @@ async def record_usage( "quantity": record.quantity, "unit": record.unit, "cost": record.cost, - "recorded_at": record.recorded_at.isoformat() + "recorded_at": record.recorded_at.isoformat(), } @@ -9615,7 +9202,7 @@ async def get_usage_summary( tenant_id: str, start_date: str | None = Query(default=None, description="开始日期 (ISO格式)"), end_date: str | None = Query(default=None, description="结束日期 (ISO格式)"), - _=Depends(verify_api_key) + _=Depends(verify_api_key), ): """获取用量汇总""" if not SUBSCRIPTION_MANAGER_AVAILABLE: @@ -9638,7 +9225,7 @@ async def list_payments( status: str | None = Query(default=None, description="支付状态过滤"), limit: int = Query(default=100, description="返回数量限制"), offset: int = Query(default=0, description="偏移量"), - _=Depends(verify_api_key) + _=Depends(verify_api_key), ): """获取支付记录列表""" if not SUBSCRIPTION_MANAGER_AVAILABLE: @@ -9658,20 +9245,16 @@ async def list_payments( "payment_method": p.payment_method, "paid_at": p.paid_at.isoformat() if p.paid_at else None, "failed_at": p.failed_at.isoformat() if p.failed_at else None, - "created_at": p.created_at.isoformat() + "created_at": p.created_at.isoformat(), } for p in payments ], - "total": len(payments) + "total": len(payments), } @app.get("/api/v1/tenants/{tenant_id}/payments/{payment_id}", tags=["Subscriptions"]) -async def get_payment( - tenant_id: str, - payment_id: str, - _=Depends(verify_api_key) -): +async def get_payment(tenant_id: str, payment_id: str, _=Depends(verify_api_key)): """获取支付记录详情""" if not SUBSCRIPTION_MANAGER_AVAILABLE: raise HTTPException(status_code=500, detail="Subscription manager not available") @@ -9696,7 +9279,7 @@ async def get_payment( "paid_at": payment.paid_at.isoformat() if payment.paid_at else None, "failed_at": payment.failed_at.isoformat() if payment.failed_at else None, "failure_reason": payment.failure_reason, - "created_at": payment.created_at.isoformat() + "created_at": payment.created_at.isoformat(), } @@ -9707,7 +9290,7 @@ async def list_invoices( status: str | None = Query(default=None, description="发票状态过滤"), limit: int = Query(default=100, description="返回数量限制"), offset: int = Query(default=0, description="偏移量"), - _=Depends(verify_api_key) + _=Depends(verify_api_key), ): """获取发票列表""" if not SUBSCRIPTION_MANAGER_AVAILABLE: @@ -9730,20 +9313,16 @@ async def list_invoices( "description": inv.description, "due_date": inv.due_date.isoformat() if inv.due_date else None, "paid_at": inv.paid_at.isoformat() if inv.paid_at else None, - "created_at": inv.created_at.isoformat() + "created_at": inv.created_at.isoformat(), } for inv in invoices ], - "total": len(invoices) + "total": len(invoices), } @app.get("/api/v1/tenants/{tenant_id}/invoices/{invoice_id}", tags=["Subscriptions"]) -async def get_invoice( - tenant_id: str, - invoice_id: str, - _=Depends(verify_api_key) -): +async def get_invoice(tenant_id: str, invoice_id: str, _=Depends(verify_api_key)): """获取发票详情""" if not SUBSCRIPTION_MANAGER_AVAILABLE: raise HTTPException(status_code=500, detail="Subscription manager not available") @@ -9769,7 +9348,7 @@ async def get_invoice( "paid_at": invoice.paid_at.isoformat() if invoice.paid_at else None, "voided_at": invoice.voided_at.isoformat() if invoice.voided_at else None, "void_reason": invoice.void_reason, - "created_at": invoice.created_at.isoformat() + "created_at": invoice.created_at.isoformat(), } @@ -9779,7 +9358,7 @@ async def request_refund( tenant_id: str, request: RequestRefundRequest, user_id: str = Header(..., description="当前用户ID"), - _=Depends(verify_api_key) + _=Depends(verify_api_key), ): """申请退款""" if not SUBSCRIPTION_MANAGER_AVAILABLE: @@ -9792,7 +9371,7 @@ async def request_refund( payment_id=request.payment_id, amount=request.amount, reason=request.reason, - requested_by=user_id + requested_by=user_id, ) return { @@ -9802,7 +9381,7 @@ async def request_refund( "currency": refund.currency, "reason": refund.reason, "status": refund.status, - "requested_at": refund.requested_at.isoformat() + "requested_at": refund.requested_at.isoformat(), } except Exception as e: raise HTTPException(status_code=400, detail=str(e)) @@ -9814,7 +9393,7 @@ async def list_refunds( status: str | None = Query(default=None, description="退款状态过滤"), limit: int = Query(default=100, description="返回数量限制"), offset: int = Query(default=0, description="偏移量"), - _=Depends(verify_api_key) + _=Depends(verify_api_key), ): """获取退款记录列表""" if not SUBSCRIPTION_MANAGER_AVAILABLE: @@ -9836,11 +9415,11 @@ async def list_refunds( "requested_at": r.requested_at.isoformat(), "approved_by": r.approved_by, "approved_at": r.approved_at.isoformat() if r.approved_at else None, - "completed_at": r.completed_at.isoformat() if r.completed_at else None + "completed_at": r.completed_at.isoformat() if r.completed_at else None, } for r in refunds ], - "total": len(refunds) + "total": len(refunds), } @@ -9850,7 +9429,7 @@ async def process_refund( refund_id: str, request: ProcessRefundRequest, user_id: str = Header(..., description="当前用户ID"), - _=Depends(verify_api_key) + _=Depends(verify_api_key), ): """处理退款申请(管理员)""" if not SUBSCRIPTION_MANAGER_AVAILABLE: @@ -9866,11 +9445,7 @@ async def process_refund( # 自动完成退款(简化实现) refund = manager.complete_refund(refund_id) - return { - "id": refund.id, - "status": refund.status, - "message": "Refund approved and processed" - } + return {"id": refund.id, "status": refund.status, "message": "Refund approved and processed"} elif request.action == "reject": if not request.reason: @@ -9880,11 +9455,7 @@ async def process_refund( if not refund: raise HTTPException(status_code=404, detail="Refund not found") - return { - "id": refund.id, - "status": refund.status, - "message": "Refund rejected" - } + return {"id": refund.id, "status": refund.status, "message": "Refund rejected"} else: raise HTTPException(status_code=400, detail="Invalid action") @@ -9898,7 +9469,7 @@ async def get_billing_history( end_date: str | None = Query(default=None, description="结束日期 (ISO格式)"), limit: int = Query(default=100, description="返回数量限制"), offset: int = Query(default=0, description="偏移量"), - _=Depends(verify_api_key) + _=Depends(verify_api_key), ): """获取账单历史""" if not SUBSCRIPTION_MANAGER_AVAILABLE: @@ -9921,21 +9492,17 @@ async def get_billing_history( "description": h.description, "reference_id": h.reference_id, "balance_after": h.balance_after, - "created_at": h.created_at.isoformat() + "created_at": h.created_at.isoformat(), } for h in history ], - "total": len(history) + "total": len(history), } # Payment Provider Integration APIs @app.post("/api/v1/tenants/{tenant_id}/checkout/stripe", tags=["Subscriptions"]) -async def create_stripe_checkout( - tenant_id: str, - request: CreateCheckoutSessionRequest, - _=Depends(verify_api_key) -): +async def create_stripe_checkout(tenant_id: str, request: CreateCheckoutSessionRequest, _=Depends(verify_api_key)): """创建 Stripe Checkout 会话""" if not SUBSCRIPTION_MANAGER_AVAILABLE: raise HTTPException(status_code=500, detail="Subscription manager not available") @@ -9948,7 +9515,7 @@ async def create_stripe_checkout( plan_id=request.plan_id, success_url=request.success_url, cancel_url=request.cancel_url, - billing_cycle=request.billing_cycle + billing_cycle=request.billing_cycle, ) return session @@ -9961,7 +9528,7 @@ async def create_alipay_order( tenant_id: str, plan_id: str, billing_cycle: str = Query(default="monthly", description="计费周期"), - _=Depends(verify_api_key) + _=Depends(verify_api_key), ): """创建支付宝订单""" if not SUBSCRIPTION_MANAGER_AVAILABLE: @@ -9970,11 +9537,7 @@ async def create_alipay_order( manager = get_subscription_manager() try: - order = manager.create_alipay_order( - tenant_id=tenant_id, - plan_id=plan_id, - billing_cycle=billing_cycle - ) + order = manager.create_alipay_order(tenant_id=tenant_id, plan_id=plan_id, billing_cycle=billing_cycle) return order except Exception as e: @@ -9986,7 +9549,7 @@ async def create_wechat_order( tenant_id: str, plan_id: str, billing_cycle: str = Query(default="monthly", description="计费周期"), - _=Depends(verify_api_key) + _=Depends(verify_api_key), ): """创建微信支付订单""" if not SUBSCRIPTION_MANAGER_AVAILABLE: @@ -9995,11 +9558,7 @@ async def create_wechat_order( manager = get_subscription_manager() try: - order = manager.create_wechat_order( - tenant_id=tenant_id, - plan_id=plan_id, - billing_cycle=billing_cycle - ) + order = manager.create_wechat_order(tenant_id=tenant_id, plan_id=plan_id, billing_cycle=billing_cycle) return order except Exception as e: @@ -10062,6 +9621,7 @@ async def wechat_webhook(request: Request): # Pydantic Models for Enterprise + class SSOConfigCreate(BaseModel): provider: str = Field(..., description="SSO 提供商: wechat_work/dingtalk/feishu/okta/azure_ad/google/custom_saml") entity_id: str | None = Field(default=None, description="SAML Entity ID") @@ -10158,12 +9718,9 @@ class RetentionPolicyUpdate(BaseModel): # SSO/SAML APIs + @app.post("/api/v1/tenants/{tenant_id}/sso-configs", tags=["Enterprise"]) -async def create_sso_config_endpoint( - tenant_id: str, - config: SSOConfigCreate, - _=Depends(verify_api_key) -): +async def create_sso_config_endpoint(tenant_id: str, config: SSOConfigCreate, _=Depends(verify_api_key)): """创建 SSO 配置""" if not ENTERPRISE_MANAGER_AVAILABLE: raise HTTPException(status_code=500, detail="Enterprise manager not available") @@ -10189,7 +9746,7 @@ async def create_sso_config_endpoint( attribute_mapping=config.attribute_mapping, auto_provision=config.auto_provision, default_role=config.default_role, - domain_restriction=config.domain_restriction + domain_restriction=config.domain_restriction, ) return { @@ -10203,17 +9760,14 @@ async def create_sso_config_endpoint( "scopes": sso_config.scopes, "auto_provision": sso_config.auto_provision, "default_role": sso_config.default_role, - "created_at": sso_config.created_at.isoformat() + "created_at": sso_config.created_at.isoformat(), } except Exception as e: raise HTTPException(status_code=400, detail=str(e)) @app.get("/api/v1/tenants/{tenant_id}/sso-configs", tags=["Enterprise"]) -async def list_sso_configs_endpoint( - tenant_id: str, - _=Depends(verify_api_key) -): +async def list_sso_configs_endpoint(tenant_id: str, _=Depends(verify_api_key)): """列出租户的所有 SSO 配置""" if not ENTERPRISE_MANAGER_AVAILABLE: raise HTTPException(status_code=500, detail="Enterprise manager not available") @@ -10232,20 +9786,16 @@ async def list_sso_configs_endpoint( "authorization_url": c.authorization_url, "auto_provision": c.auto_provision, "default_role": c.default_role, - "created_at": c.created_at.isoformat() + "created_at": c.created_at.isoformat(), } for c in configs ], - "total": len(configs) + "total": len(configs), } @app.get("/api/v1/tenants/{tenant_id}/sso-configs/{config_id}", tags=["Enterprise"]) -async def get_sso_config_endpoint( - tenant_id: str, - config_id: str, - _=Depends(verify_api_key) -): +async def get_sso_config_endpoint(tenant_id: str, config_id: str, _=Depends(verify_api_key)): """获取 SSO 配置详情""" if not ENTERPRISE_MANAGER_AVAILABLE: raise HTTPException(status_code=500, detail="Enterprise manager not available") @@ -10274,16 +9824,13 @@ async def get_sso_config_endpoint( "default_role": config.default_role, "domain_restriction": config.domain_restriction, "created_at": config.created_at.isoformat(), - "updated_at": config.updated_at.isoformat() + "updated_at": config.updated_at.isoformat(), } @app.put("/api/v1/tenants/{tenant_id}/sso-configs/{config_id}", tags=["Enterprise"]) async def update_sso_config_endpoint( - tenant_id: str, - config_id: str, - update: SSOConfigUpdate, - _=Depends(verify_api_key) + tenant_id: str, config_id: str, update: SSOConfigUpdate, _=Depends(verify_api_key) ): """更新 SSO 配置""" if not ENTERPRISE_MANAGER_AVAILABLE: @@ -10296,23 +9843,14 @@ async def update_sso_config_endpoint( raise HTTPException(status_code=404, detail="SSO config not found") updated = manager.update_sso_config( - config_id=config_id, - **{k: v for k, v in update.dict().items() if v is not None} + config_id=config_id, **{k: v for k, v in update.dict().items() if v is not None} ) - return { - "id": updated.id, - "status": updated.status, - "updated_at": updated.updated_at.isoformat() - } + return {"id": updated.id, "status": updated.status, "updated_at": updated.updated_at.isoformat()} @app.delete("/api/v1/tenants/{tenant_id}/sso-configs/{config_id}", tags=["Enterprise"]) -async def delete_sso_config_endpoint( - tenant_id: str, - config_id: str, - _=Depends(verify_api_key) -): +async def delete_sso_config_endpoint(tenant_id: str, config_id: str, _=Depends(verify_api_key)): """删除 SSO 配置""" if not ENTERPRISE_MANAGER_AVAILABLE: raise HTTPException(status_code=500, detail="Enterprise manager not available") @@ -10329,10 +9867,7 @@ async def delete_sso_config_endpoint( @app.get("/api/v1/tenants/{tenant_id}/sso-configs/{config_id}/metadata", tags=["Enterprise"]) async def get_sso_metadata_endpoint( - tenant_id: str, - config_id: str, - base_url: str = Query(..., description="服务基础 URL"), - _=Depends(verify_api_key) + tenant_id: str, config_id: str, base_url: str = Query(..., description="服务基础 URL"), _=Depends(verify_api_key) ): """获取 SAML Service Provider 元数据""" if not ENTERPRISE_MANAGER_AVAILABLE: @@ -10350,18 +9885,15 @@ async def get_sso_metadata_endpoint( "metadata_xml": metadata, "entity_id": f"{base_url}/api/v1/sso/saml/{tenant_id}", "acs_url": f"{base_url}/api/v1/sso/saml/{tenant_id}/acs", - "slo_url": f"{base_url}/api/v1/sso/saml/{tenant_id}/slo" + "slo_url": f"{base_url}/api/v1/sso/saml/{tenant_id}/slo", } # SCIM APIs + @app.post("/api/v1/tenants/{tenant_id}/scim-configs", tags=["Enterprise"]) -async def create_scim_config_endpoint( - tenant_id: str, - config: SCIMConfigCreate, - _=Depends(verify_api_key) -): +async def create_scim_config_endpoint(tenant_id: str, config: SCIMConfigCreate, _=Depends(verify_api_key)): """创建 SCIM 配置""" if not ENTERPRISE_MANAGER_AVAILABLE: raise HTTPException(status_code=500, detail="Enterprise manager not available") @@ -10376,7 +9908,7 @@ async def create_scim_config_endpoint( scim_token=config.scim_token, sync_interval_minutes=config.sync_interval_minutes, attribute_mapping=config.attribute_mapping, - sync_rules=config.sync_rules + sync_rules=config.sync_rules, ) return { @@ -10386,17 +9918,14 @@ async def create_scim_config_endpoint( "status": scim_config.status, "scim_base_url": scim_config.scim_base_url, "sync_interval_minutes": scim_config.sync_interval_minutes, - "created_at": scim_config.created_at.isoformat() + "created_at": scim_config.created_at.isoformat(), } except Exception as e: raise HTTPException(status_code=400, detail=str(e)) @app.get("/api/v1/tenants/{tenant_id}/scim-configs", tags=["Enterprise"]) -async def get_scim_config_endpoint( - tenant_id: str, - _=Depends(verify_api_key) -): +async def get_scim_config_endpoint(tenant_id: str, _=Depends(verify_api_key)): """获取租户的 SCIM 配置""" if not ENTERPRISE_MANAGER_AVAILABLE: raise HTTPException(status_code=500, detail="Enterprise manager not available") @@ -10417,16 +9946,13 @@ async def get_scim_config_endpoint( "last_sync_at": config.last_sync_at.isoformat() if config.last_sync_at else None, "last_sync_status": config.last_sync_status, "last_sync_users_count": config.last_sync_users_count, - "created_at": config.created_at.isoformat() + "created_at": config.created_at.isoformat(), } @app.put("/api/v1/tenants/{tenant_id}/scim-configs/{config_id}", tags=["Enterprise"]) async def update_scim_config_endpoint( - tenant_id: str, - config_id: str, - update: SCIMConfigUpdate, - _=Depends(verify_api_key) + tenant_id: str, config_id: str, update: SCIMConfigUpdate, _=Depends(verify_api_key) ): """更新 SCIM 配置""" if not ENTERPRISE_MANAGER_AVAILABLE: @@ -10439,23 +9965,14 @@ async def update_scim_config_endpoint( raise HTTPException(status_code=404, detail="SCIM config not found") updated = manager.update_scim_config( - config_id=config_id, - **{k: v for k, v in update.dict().items() if v is not None} + config_id=config_id, **{k: v for k, v in update.dict().items() if v is not None} ) - return { - "id": updated.id, - "status": updated.status, - "updated_at": updated.updated_at.isoformat() - } + return {"id": updated.id, "status": updated.status, "updated_at": updated.updated_at.isoformat()} @app.post("/api/v1/tenants/{tenant_id}/scim-configs/{config_id}/sync", tags=["Enterprise"]) -async def sync_scim_users_endpoint( - tenant_id: str, - config_id: str, - _=Depends(verify_api_key) -): +async def sync_scim_users_endpoint(tenant_id: str, config_id: str, _=Depends(verify_api_key)): """执行 SCIM 用户同步""" if not ENTERPRISE_MANAGER_AVAILABLE: raise HTTPException(status_code=500, detail="Enterprise manager not available") @@ -10473,9 +9990,7 @@ async def sync_scim_users_endpoint( @app.get("/api/v1/tenants/{tenant_id}/scim-users", tags=["Enterprise"]) async def list_scim_users_endpoint( - tenant_id: str, - active_only: bool = Query(default=True, description="仅显示活跃用户"), - _=Depends(verify_api_key) + tenant_id: str, active_only: bool = Query(default=True, description="仅显示活跃用户"), _=Depends(verify_api_key) ): """列出 SCIM 用户""" if not ENTERPRISE_MANAGER_AVAILABLE: @@ -10494,22 +10009,23 @@ async def list_scim_users_endpoint( "display_name": u.display_name, "active": u.active, "groups": u.groups, - "synced_at": u.synced_at.isoformat() + "synced_at": u.synced_at.isoformat(), } for u in users ], - "total": len(users) + "total": len(users), } # Audit Log Export APIs + @app.post("/api/v1/tenants/{tenant_id}/audit-exports", tags=["Enterprise"]) async def create_audit_export_endpoint( tenant_id: str, request: AuditExportCreate, current_user: str = Header(default="user", description="当前用户ID"), - _=Depends(verify_api_key) + _=Depends(verify_api_key), ): """创建审计日志导出任务""" if not ENTERPRISE_MANAGER_AVAILABLE: @@ -10528,7 +10044,7 @@ async def create_audit_export_endpoint( end_date=end_date, created_by=current_user, filters=request.filters, - compliance_standard=request.compliance_standard + compliance_standard=request.compliance_standard, ) return { @@ -10540,7 +10056,7 @@ async def create_audit_export_endpoint( "compliance_standard": export.compliance_standard, "status": export.status, "expires_at": export.expires_at.isoformat() if export.expires_at else None, - "created_at": export.created_at.isoformat() + "created_at": export.created_at.isoformat(), } except Exception as e: raise HTTPException(status_code=400, detail=str(e)) @@ -10548,9 +10064,7 @@ async def create_audit_export_endpoint( @app.get("/api/v1/tenants/{tenant_id}/audit-exports", tags=["Enterprise"]) async def list_audit_exports_endpoint( - tenant_id: str, - limit: int = Query(default=100, description="返回数量限制"), - _=Depends(verify_api_key) + tenant_id: str, limit: int = Query(default=100, description="返回数量限制"), _=Depends(verify_api_key) ): """列出审计日志导出记录""" if not ENTERPRISE_MANAGER_AVAILABLE: @@ -10572,20 +10086,16 @@ async def list_audit_exports_endpoint( "record_count": e.record_count, "downloaded_by": e.downloaded_by, "expires_at": e.expires_at.isoformat() if e.expires_at else None, - "created_at": e.created_at.isoformat() + "created_at": e.created_at.isoformat(), } for e in exports ], - "total": len(exports) + "total": len(exports), } @app.get("/api/v1/tenants/{tenant_id}/audit-exports/{export_id}", tags=["Enterprise"]) -async def get_audit_export_endpoint( - tenant_id: str, - export_id: str, - _=Depends(verify_api_key) -): +async def get_audit_export_endpoint(tenant_id: str, export_id: str, _=Depends(verify_api_key)): """获取审计日志导出详情""" if not ENTERPRISE_MANAGER_AVAILABLE: raise HTTPException(status_code=500, detail="Enterprise manager not available") @@ -10612,7 +10122,7 @@ async def get_audit_export_endpoint( "expires_at": export.expires_at.isoformat() if export.expires_at else None, "created_at": export.created_at.isoformat(), "completed_at": export.completed_at.isoformat() if export.completed_at else None, - "error_message": export.error_message + "error_message": export.error_message, } @@ -10621,7 +10131,7 @@ async def download_audit_export_endpoint( tenant_id: str, export_id: str, current_user: str = Header(default="user", description="当前用户ID"), - _=Depends(verify_api_key) + _=Depends(verify_api_key), ): """下载审计日志导出文件""" if not ENTERPRISE_MANAGER_AVAILABLE: @@ -10642,18 +10152,15 @@ async def download_audit_export_endpoint( # 返回文件下载信息 return { "download_url": f"/api/v1/tenants/{tenant_id}/audit-exports/{export_id}/file", - "expires_at": export.expires_at.isoformat() if export.expires_at else None + "expires_at": export.expires_at.isoformat() if export.expires_at else None, } # Data Retention Policy APIs + @app.post("/api/v1/tenants/{tenant_id}/retention-policies", tags=["Enterprise"]) -async def create_retention_policy_endpoint( - tenant_id: str, - policy: RetentionPolicyCreate, - _=Depends(verify_api_key) -): +async def create_retention_policy_endpoint(tenant_id: str, policy: RetentionPolicyCreate, _=Depends(verify_api_key)): """创建数据保留策略""" if not ENTERPRISE_MANAGER_AVAILABLE: raise HTTPException(status_code=500, detail="Enterprise manager not available") @@ -10673,7 +10180,7 @@ async def create_retention_policy_endpoint( execute_at=policy.execute_at, notify_before_days=policy.notify_before_days, archive_location=policy.archive_location, - archive_encryption=policy.archive_encryption + archive_encryption=policy.archive_encryption, ) return { @@ -10685,7 +10192,7 @@ async def create_retention_policy_endpoint( "action": new_policy.action, "auto_execute": new_policy.auto_execute, "is_active": new_policy.is_active, - "created_at": new_policy.created_at.isoformat() + "created_at": new_policy.created_at.isoformat(), } except Exception as e: raise HTTPException(status_code=400, detail=str(e)) @@ -10695,7 +10202,7 @@ async def create_retention_policy_endpoint( async def list_retention_policies_endpoint( tenant_id: str, resource_type: str | None = Query(default=None, description="资源类型过滤"), - _=Depends(verify_api_key) + _=Depends(verify_api_key), ): """列出数据保留策略""" if not ENTERPRISE_MANAGER_AVAILABLE: @@ -10714,20 +10221,16 @@ async def list_retention_policies_endpoint( "action": p.action, "auto_execute": p.auto_execute, "is_active": p.is_active, - "last_executed_at": p.last_executed_at.isoformat() if p.last_executed_at else None + "last_executed_at": p.last_executed_at.isoformat() if p.last_executed_at else None, } for p in policies ], - "total": len(policies) + "total": len(policies), } @app.get("/api/v1/tenants/{tenant_id}/retention-policies/{policy_id}", tags=["Enterprise"]) -async def get_retention_policy_endpoint( - tenant_id: str, - policy_id: str, - _=Depends(verify_api_key) -): +async def get_retention_policy_endpoint(tenant_id: str, policy_id: str, _=Depends(verify_api_key)): """获取数据保留策略详情""" if not ENTERPRISE_MANAGER_AVAILABLE: raise HTTPException(status_code=500, detail="Enterprise manager not available") @@ -10755,16 +10258,13 @@ async def get_retention_policy_endpoint( "is_active": policy.is_active, "last_executed_at": policy.last_executed_at.isoformat() if policy.last_executed_at else None, "last_execution_result": policy.last_execution_result, - "created_at": policy.created_at.isoformat() + "created_at": policy.created_at.isoformat(), } @app.put("/api/v1/tenants/{tenant_id}/retention-policies/{policy_id}", tags=["Enterprise"]) async def update_retention_policy_endpoint( - tenant_id: str, - policy_id: str, - update: RetentionPolicyUpdate, - _=Depends(verify_api_key) + tenant_id: str, policy_id: str, update: RetentionPolicyUpdate, _=Depends(verify_api_key) ): """更新数据保留策略""" if not ENTERPRISE_MANAGER_AVAILABLE: @@ -10777,22 +10277,14 @@ async def update_retention_policy_endpoint( raise HTTPException(status_code=404, detail="Policy not found") updated = manager.update_retention_policy( - policy_id=policy_id, - **{k: v for k, v in update.dict().items() if v is not None} + policy_id=policy_id, **{k: v for k, v in update.dict().items() if v is not None} ) - return { - "id": updated.id, - "updated_at": updated.updated_at.isoformat() - } + return {"id": updated.id, "updated_at": updated.updated_at.isoformat()} @app.delete("/api/v1/tenants/{tenant_id}/retention-policies/{policy_id}", tags=["Enterprise"]) -async def delete_retention_policy_endpoint( - tenant_id: str, - policy_id: str, - _=Depends(verify_api_key) -): +async def delete_retention_policy_endpoint(tenant_id: str, policy_id: str, _=Depends(verify_api_key)): """删除数据保留策略""" if not ENTERPRISE_MANAGER_AVAILABLE: raise HTTPException(status_code=500, detail="Enterprise manager not available") @@ -10808,11 +10300,7 @@ async def delete_retention_policy_endpoint( @app.post("/api/v1/tenants/{tenant_id}/retention-policies/{policy_id}/execute", tags=["Enterprise"]) -async def execute_retention_policy_endpoint( - tenant_id: str, - policy_id: str, - _=Depends(verify_api_key) -): +async def execute_retention_policy_endpoint(tenant_id: str, policy_id: str, _=Depends(verify_api_key)): """执行数据保留策略""" if not ENTERPRISE_MANAGER_AVAILABLE: raise HTTPException(status_code=500, detail="Enterprise manager not available") @@ -10830,7 +10318,7 @@ async def execute_retention_policy_endpoint( "policy_id": job.policy_id, "status": job.status, "started_at": job.started_at.isoformat() if job.started_at else None, - "created_at": job.created_at.isoformat() + "created_at": job.created_at.isoformat(), } @@ -10839,7 +10327,7 @@ async def list_retention_jobs_endpoint( tenant_id: str, policy_id: str, limit: int = Query(default=100, description="返回数量限制"), - _=Depends(verify_api_key) + _=Depends(verify_api_key), ): """列出数据保留任务""" if not ENTERPRISE_MANAGER_AVAILABLE: @@ -10863,11 +10351,11 @@ async def list_retention_jobs_endpoint( "affected_records": j.affected_records, "archived_records": j.archived_records, "deleted_records": j.deleted_records, - "error_count": j.error_count + "error_count": j.error_count, } for j in jobs ], - "total": len(jobs) + "total": len(jobs), } @@ -10875,6 +10363,7 @@ async def list_retention_jobs_endpoint( # Phase 8 Task 7: Globalization & Localization API # ============================================ + # Pydantic Models for Localization API class TranslationCreate(BaseModel): key: str = Field(..., description="翻译键") @@ -10938,10 +10427,7 @@ class ConvertTimezoneRequest(BaseModel): # Translation APIs @app.get("/api/v1/translations/{language}/{key}", tags=["Localization"]) async def get_translation( - language: str, - key: str, - namespace: str = Query(default="common", description="命名空间"), - _=Depends(verify_api_key) + language: str, key: str, namespace: str = Query(default="common", description="命名空间"), _=Depends(verify_api_key) ): """获取翻译""" if not LOCALIZATION_MANAGER_AVAILABLE: @@ -10953,31 +10439,18 @@ async def get_translation( if value is None: raise HTTPException(status_code=404, detail="Translation not found") - return { - "key": key, - "language": language, - "namespace": namespace, - "value": value - } + return {"key": key, "language": language, "namespace": namespace, "value": value} @app.post("/api/v1/translations/{language}", tags=["Localization"]) -async def create_translation( - language: str, - request: TranslationCreate, - _=Depends(verify_api_key) -): +async def create_translation(language: str, request: TranslationCreate, _=Depends(verify_api_key)): """创建/更新翻译""" if not LOCALIZATION_MANAGER_AVAILABLE: raise HTTPException(status_code=500, detail="Localization manager not available") manager = get_localization_manager() translation = manager.set_translation( - key=request.key, - language=language, - value=request.value, - namespace=request.namespace, - context=request.context + key=request.key, language=language, value=request.value, namespace=request.namespace, context=request.context ) return { @@ -10986,7 +10459,7 @@ async def create_translation( "language": translation.language, "namespace": translation.namespace, "value": translation.value, - "created_at": translation.created_at.isoformat() + "created_at": translation.created_at.isoformat(), } @@ -10996,7 +10469,7 @@ async def update_translation( key: str, request: TranslationUpdate, namespace: str = Query(default="common", description="命名空间"), - _=Depends(verify_api_key) + _=Depends(verify_api_key), ): """更新翻译""" if not LOCALIZATION_MANAGER_AVAILABLE: @@ -11004,11 +10477,7 @@ async def update_translation( manager = get_localization_manager() translation = manager.set_translation( - key=key, - language=language, - value=request.value, - namespace=namespace, - context=request.context + key=key, language=language, value=request.value, namespace=namespace, context=request.context ) return { @@ -11017,16 +10486,13 @@ async def update_translation( "language": translation.language, "namespace": translation.namespace, "value": translation.value, - "updated_at": translation.updated_at.isoformat() + "updated_at": translation.updated_at.isoformat(), } @app.delete("/api/v1/translations/{language}/{key}", tags=["Localization"]) async def delete_translation( - language: str, - key: str, - namespace: str = Query(default="common", description="命名空间"), - _=Depends(verify_api_key) + language: str, key: str, namespace: str = Query(default="common", description="命名空间"), _=Depends(verify_api_key) ): """删除翻译""" if not LOCALIZATION_MANAGER_AVAILABLE: @@ -11047,7 +10513,7 @@ async def list_translations( namespace: str | None = Query(default=None, description="命名空间"), limit: int = Query(default=1000, description="返回数量限制"), offset: int = Query(default=0, description="偏移量"), - _=Depends(verify_api_key) + _=Depends(verify_api_key), ): """列出翻译""" if not LOCALIZATION_MANAGER_AVAILABLE: @@ -11065,19 +10531,17 @@ async def list_translations( "namespace": t.namespace, "value": t.value, "is_reviewed": t.is_reviewed, - "updated_at": t.updated_at.isoformat() + "updated_at": t.updated_at.isoformat(), } for t in translations ], - "total": len(translations) + "total": len(translations), } # Language APIs @app.get("/api/v1/languages", tags=["Localization"]) -async def list_languages( - active_only: bool = Query(default=True, description="仅返回激活的语言") -): +async def list_languages(active_only: bool = Query(default=True, description="仅返回激活的语言")): """列出支持的语言""" if not LOCALIZATION_MANAGER_AVAILABLE: raise HTTPException(status_code=500, detail="Localization manager not available") @@ -11096,11 +10560,11 @@ async def list_languages( "is_default": lang.is_default, "date_format": lang.date_format, "time_format": lang.time_format, - "calendar_type": lang.calendar_type + "calendar_type": lang.calendar_type, } for lang in languages ], - "total": len(languages) + "total": len(languages), } @@ -11130,7 +10594,7 @@ async def get_language(code: str): "number_format": lang.number_format, "currency_format": lang.currency_format, "first_day_of_week": lang.first_day_of_week, - "calendar_type": lang.calendar_type + "calendar_type": lang.calendar_type, } @@ -11138,7 +10602,7 @@ async def get_language(code: str): @app.get("/api/v1/data-centers", tags=["Localization"]) async def list_data_centers( status: str | None = Query(default=None, description="状态过滤"), - region: str | None = Query(default=None, description="区域过滤") + region: str | None = Query(default=None, description="区域过滤"), ): """列出数据中心""" if not LOCALIZATION_MANAGER_AVAILABLE: @@ -11157,11 +10621,11 @@ async def list_data_centers( "endpoint": dc.endpoint, "status": dc.status, "priority": dc.priority, - "supported_regions": dc.supported_regions + "supported_regions": dc.supported_regions, } for dc in data_centers ], - "total": len(data_centers) + "total": len(data_centers), } @@ -11186,15 +10650,12 @@ async def get_data_center(dc_id: str): "status": dc.status, "priority": dc.priority, "supported_regions": dc.supported_regions, - "capabilities": dc.capabilities + "capabilities": dc.capabilities, } @app.get("/api/v1/tenants/{tenant_id}/data-center", tags=["Localization"]) -async def get_tenant_data_center( - tenant_id: str, - _=Depends(verify_api_key) -): +async def get_tenant_data_center(tenant_id: str, _=Depends(verify_api_key)): """获取租户数据中心配置""" if not LOCALIZATION_MANAGER_AVAILABLE: raise HTTPException(status_code=500, detail="Localization manager not available") @@ -11214,37 +10675,39 @@ async def get_tenant_data_center( "tenant_id": mapping.tenant_id, "region_code": mapping.region_code, "data_residency": mapping.data_residency, - "primary_dc": { - "id": primary_dc.id, - "region_code": primary_dc.region_code, - "name": primary_dc.name, - "endpoint": primary_dc.endpoint - } if primary_dc else None, - "secondary_dc": { - "id": secondary_dc.id, - "region_code": secondary_dc.region_code, - "name": secondary_dc.name, - "endpoint": secondary_dc.endpoint - } if secondary_dc else None, - "created_at": mapping.created_at.isoformat() + "primary_dc": ( + { + "id": primary_dc.id, + "region_code": primary_dc.region_code, + "name": primary_dc.name, + "endpoint": primary_dc.endpoint, + } + if primary_dc + else None + ), + "secondary_dc": ( + { + "id": secondary_dc.id, + "region_code": secondary_dc.region_code, + "name": secondary_dc.name, + "endpoint": secondary_dc.endpoint, + } + if secondary_dc + else None + ), + "created_at": mapping.created_at.isoformat(), } @app.post("/api/v1/tenants/{tenant_id}/data-center", tags=["Localization"]) -async def set_tenant_data_center( - tenant_id: str, - request: DataCenterMappingRequest, - _=Depends(verify_api_key) -): +async def set_tenant_data_center(tenant_id: str, request: DataCenterMappingRequest, _=Depends(verify_api_key)): """设置租户数据中心""" if not LOCALIZATION_MANAGER_AVAILABLE: raise HTTPException(status_code=500, detail="Localization manager not available") manager = get_localization_manager() mapping = manager.set_tenant_data_center( - tenant_id=tenant_id, - region_code=request.region_code, - data_residency=request.data_residency + tenant_id=tenant_id, region_code=request.region_code, data_residency=request.data_residency ) return { @@ -11252,7 +10715,7 @@ async def set_tenant_data_center( "tenant_id": mapping.tenant_id, "region_code": mapping.region_code, "data_residency": mapping.data_residency, - "created_at": mapping.created_at.isoformat() + "created_at": mapping.created_at.isoformat(), } @@ -11261,7 +10724,7 @@ async def set_tenant_data_center( async def list_payment_methods( country_code: str | None = Query(default=None, description="国家代码"), currency: str | None = Query(default=None, description="货币代码"), - active_only: bool = Query(default=True, description="仅返回激活的支付方式") + active_only: bool = Query(default=True, description="仅返回激活的支付方式"), ): """列出支付方式""" if not LOCALIZATION_MANAGER_AVAILABLE: @@ -11282,18 +10745,17 @@ async def list_payment_methods( "is_active": m.is_active, "display_order": m.display_order, "min_amount": m.min_amount, - "max_amount": m.max_amount + "max_amount": m.max_amount, } for m in methods ], - "total": len(methods) + "total": len(methods), } @app.get("/api/v1/payment-methods/localized", tags=["Localization"]) async def get_localized_payment_methods( - country_code: str = Query(..., description="国家代码"), - language: str = Query(default="en", description="语言代码") + country_code: str = Query(..., description="国家代码"), language: str = Query(default="en", description="语言代码") ): """获取本地化的支付方式列表""" if not LOCALIZATION_MANAGER_AVAILABLE: @@ -11302,18 +10764,14 @@ async def get_localized_payment_methods( manager = get_localization_manager() methods = manager.get_localized_payment_methods(country_code, language) - return { - "country_code": country_code, - "language": language, - "payment_methods": methods - } + return {"country_code": country_code, "language": language, "payment_methods": methods} # Country APIs @app.get("/api/v1/countries", tags=["Localization"]) async def list_countries( region: str | None = Query(default=None, description="区域过滤"), - active_only: bool = Query(default=True, description="仅返回激活的国家") + active_only: bool = Query(default=True, description="仅返回激活的国家"), ): """列出国家/地区""" if not LOCALIZATION_MANAGER_AVAILABLE: @@ -11333,11 +10791,11 @@ async def list_countries( "default_currency": c.default_currency, "timezone": c.timezone, "calendar_type": c.calendar_type, - "vat_rate": c.vat_rate + "vat_rate": c.vat_rate, } for c in countries ], - "total": len(countries) + "total": len(countries), } @@ -11365,16 +10823,13 @@ async def get_country(code: str): "supported_currencies": country.supported_currencies, "timezone": country.timezone, "calendar_type": country.calendar_type, - "vat_rate": country.vat_rate + "vat_rate": country.vat_rate, } # Localization Settings APIs @app.get("/api/v1/tenants/{tenant_id}/localization", tags=["Localization"]) -async def get_localization_settings( - tenant_id: str, - _=Depends(verify_api_key) -): +async def get_localization_settings(tenant_id: str, _=Depends(verify_api_key)): """获取租户本地化设置""" if not LOCALIZATION_MANAGER_AVAILABLE: raise HTTPException(status_code=500, detail="Localization manager not available") @@ -11399,16 +10854,12 @@ async def get_localization_settings( "first_day_of_week": settings.first_day_of_week, "region_code": settings.region_code, "data_residency": settings.data_residency, - "updated_at": settings.updated_at.isoformat() + "updated_at": settings.updated_at.isoformat(), } @app.post("/api/v1/tenants/{tenant_id}/localization", tags=["Localization"]) -async def create_localization_settings( - tenant_id: str, - request: LocalizationSettingsCreate, - _=Depends(verify_api_key) -): +async def create_localization_settings(tenant_id: str, request: LocalizationSettingsCreate, _=Depends(verify_api_key)): """创建租户本地化设置""" if not LOCALIZATION_MANAGER_AVAILABLE: raise HTTPException(status_code=500, detail="Localization manager not available") @@ -11422,7 +10873,7 @@ async def create_localization_settings( supported_currencies=request.supported_currencies, default_timezone=request.default_timezone, region_code=request.region_code, - data_residency=request.data_residency + data_residency=request.data_residency, ) return { @@ -11435,16 +10886,12 @@ async def create_localization_settings( "default_timezone": settings.default_timezone, "region_code": settings.region_code, "data_residency": settings.data_residency, - "created_at": settings.created_at.isoformat() + "created_at": settings.created_at.isoformat(), } @app.put("/api/v1/tenants/{tenant_id}/localization", tags=["Localization"]) -async def update_localization_settings( - tenant_id: str, - request: LocalizationSettingsUpdate, - _=Depends(verify_api_key) -): +async def update_localization_settings(tenant_id: str, request: LocalizationSettingsUpdate, _=Depends(verify_api_key)): """更新租户本地化设置""" if not LOCALIZATION_MANAGER_AVAILABLE: raise HTTPException(status_code=500, detail="Localization manager not available") @@ -11467,15 +10914,14 @@ async def update_localization_settings( "default_timezone": settings.default_timezone, "region_code": settings.region_code, "data_residency": settings.data_residency, - "updated_at": settings.updated_at.isoformat() + "updated_at": settings.updated_at.isoformat(), } # Formatting APIs @app.post("/api/v1/format/datetime", tags=["Localization"]) async def format_datetime_endpoint( - request: FormatDateTimeRequest, - language: str = Query(default="en", description="语言代码") + request: FormatDateTimeRequest, language: str = Query(default="en", description="语言代码") ): """格式化日期时间""" if not LOCALIZATION_MANAGER_AVAILABLE: @@ -11484,15 +10930,12 @@ async def format_datetime_endpoint( manager = get_localization_manager() try: - dt = datetime.fromisoformat(request.timestamp.replace('Z', '+00:00')) + dt = datetime.fromisoformat(request.timestamp.replace("Z", "+00:00")) except ValueError: raise HTTPException(status_code=400, detail="Invalid timestamp format") formatted = manager.format_datetime( - dt=dt, - language=language, - timezone=request.timezone, - format_type=request.format_type + dt=dt, language=language, timezone=request.timezone, format_type=request.format_type ) return { @@ -11500,61 +10943,40 @@ async def format_datetime_endpoint( "formatted": formatted, "language": language, "timezone": request.timezone, - "format_type": request.format_type + "format_type": request.format_type, } @app.post("/api/v1/format/number", tags=["Localization"]) async def format_number_endpoint( - request: FormatNumberRequest, - language: str = Query(default="en", description="语言代码") + request: FormatNumberRequest, language: str = Query(default="en", description="语言代码") ): """格式化数字""" if not LOCALIZATION_MANAGER_AVAILABLE: raise HTTPException(status_code=500, detail="Localization manager not available") manager = get_localization_manager() - formatted = manager.format_number( - number=request.number, - language=language, - decimal_places=request.decimal_places - ) + formatted = manager.format_number(number=request.number, language=language, decimal_places=request.decimal_places) - return { - "original": request.number, - "formatted": formatted, - "language": language - } + return {"original": request.number, "formatted": formatted, "language": language} @app.post("/api/v1/format/currency", tags=["Localization"]) async def format_currency_endpoint( - request: FormatCurrencyRequest, - language: str = Query(default="en", description="语言代码") + request: FormatCurrencyRequest, language: str = Query(default="en", description="语言代码") ): """格式化货币""" if not LOCALIZATION_MANAGER_AVAILABLE: raise HTTPException(status_code=500, detail="Localization manager not available") manager = get_localization_manager() - formatted = manager.format_currency( - amount=request.amount, - currency=request.currency, - language=language - ) + formatted = manager.format_currency(amount=request.amount, currency=request.currency, language=language) - return { - "original": request.amount, - "currency": request.currency, - "formatted": formatted, - "language": language - } + return {"original": request.amount, "currency": request.currency, "formatted": formatted, "language": language} @app.post("/api/v1/convert/timezone", tags=["Localization"]) -async def convert_timezone_endpoint( - request: ConvertTimezoneRequest -): +async def convert_timezone_endpoint(request: ConvertTimezoneRequest): """转换时区""" if not LOCALIZATION_MANAGER_AVAILABLE: raise HTTPException(status_code=500, detail="Localization manager not available") @@ -11562,47 +10984,38 @@ async def convert_timezone_endpoint( manager = get_localization_manager() try: - dt = datetime.fromisoformat(request.timestamp.replace('Z', '+00:00')) + dt = datetime.fromisoformat(request.timestamp.replace("Z", "+00:00")) except ValueError: raise HTTPException(status_code=400, detail="Invalid timestamp format") - converted = manager.convert_timezone( - dt=dt, - from_tz=request.from_tz, - to_tz=request.to_tz - ) + converted = manager.convert_timezone(dt=dt, from_tz=request.from_tz, to_tz=request.to_tz) return { "original": request.timestamp, "from_timezone": request.from_tz, "to_timezone": request.to_tz, - "converted": converted.isoformat() + "converted": converted.isoformat(), } @app.get("/api/v1/detect/locale", tags=["Localization"]) async def detect_locale( accept_language: str | None = Header(default=None, description="Accept-Language 头"), - ip_country: str | None = Query(default=None, description="IP国家代码") + ip_country: str | None = Query(default=None, description="IP国家代码"), ): """检测用户本地化偏好""" if not LOCALIZATION_MANAGER_AVAILABLE: raise HTTPException(status_code=500, detail="Localization manager not available") manager = get_localization_manager() - preferences = manager.detect_user_preferences( - accept_language=accept_language, - ip_country=ip_country - ) + preferences = manager.detect_user_preferences(accept_language=accept_language, ip_country=ip_country) return preferences @app.get("/api/v1/calendar/{calendar_type}", tags=["Localization"]) async def get_calendar_info( - calendar_type: str, - year: int = Query(..., description="年份"), - month: int = Query(..., description="月份") + calendar_type: str, year: int = Query(..., description="年份"), month: int = Query(..., description="月份") ): """获取日历信息""" if not LOCALIZATION_MANAGER_AVAILABLE: @@ -11618,6 +11031,7 @@ async def get_calendar_info( # Phase 8 Task 4: AI 能力增强 API # ============================================ + class CreateCustomModelRequest(BaseModel): name: str description: str @@ -11690,9 +11104,7 @@ class PredictionFeedbackRequest(BaseModel): # 自定义模型管理 API @app.post("/api/v1/tenants/{tenant_id}/ai/custom-models", tags=["AI Enhancement"]) async def create_custom_model( - tenant_id: str, - request: CreateCustomModelRequest, - created_by: str = Query(..., description="创建者ID") + tenant_id: str, request: CreateCustomModelRequest, created_by: str = Query(..., description="创建者ID") ): """创建自定义模型""" if not AI_MANAGER_AVAILABLE: @@ -11708,14 +11120,14 @@ async def create_custom_model( model_type=ModelType(request.model_type), training_data=request.training_data, hyperparameters=request.hyperparameters, - created_by=created_by + created_by=created_by, ) return { "id": model.id, "name": model.name, "model_type": model.model_type.value, "status": model.status.value, - "created_at": model.created_at + "created_at": model.created_at, } except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) @@ -11725,7 +11137,7 @@ async def create_custom_model( async def list_custom_models( tenant_id: str, model_type: str | None = Query(default=None, description="模型类型过滤"), - status: str | None = Query(default=None, description="状态过滤") + status: str | None = Query(default=None, description="状态过滤"), ): """列出自定义模型""" if not AI_MANAGER_AVAILABLE: @@ -11746,7 +11158,7 @@ async def list_custom_models( "model_type": m.model_type.value, "status": m.status.value, "metrics": m.metrics, - "created_at": m.created_at + "created_at": m.created_at, } for m in models ] @@ -11778,15 +11190,12 @@ async def get_custom_model(model_id: str): "model_path": model.model_path, "created_at": model.created_at, "trained_at": model.trained_at, - "created_by": model.created_by + "created_by": model.created_by, } @app.post("/api/v1/ai/custom-models/{model_id}/samples", tags=["AI Enhancement"]) -async def add_training_sample( - model_id: str, - request: AddTrainingSampleRequest -): +async def add_training_sample(model_id: str, request: AddTrainingSampleRequest): """添加训练样本""" if not AI_MANAGER_AVAILABLE: raise HTTPException(status_code=500, detail="AI manager not available") @@ -11794,10 +11203,7 @@ async def add_training_sample( manager = get_ai_manager() sample = manager.add_training_sample( - model_id=model_id, - text=request.text, - entities=request.entities, - metadata=request.metadata + model_id=model_id, text=request.text, entities=request.entities, metadata=request.metadata ) return { @@ -11805,7 +11211,7 @@ async def add_training_sample( "model_id": sample.model_id, "text": sample.text, "entities": sample.entities, - "created_at": sample.created_at + "created_at": sample.created_at, } @@ -11820,13 +11226,7 @@ async def get_training_samples(model_id: str): return { "samples": [ - { - "id": s.id, - "text": s.text, - "entities": s.entities, - "metadata": s.metadata, - "created_at": s.created_at - } + {"id": s.id, "text": s.text, "entities": s.entities, "metadata": s.metadata, "created_at": s.created_at} for s in samples ] } @@ -11842,12 +11242,7 @@ async def train_custom_model(model_id: str): try: model = await manager.train_custom_model(model_id) - return { - "id": model.id, - "status": model.status.value, - "metrics": model.metrics, - "trained_at": model.trained_at - } + return {"id": model.id, "status": model.status.value, "metrics": model.metrics, "trained_at": model.trained_at} except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) @@ -11862,22 +11257,14 @@ async def predict_with_custom_model(request: PredictRequest): try: entities = await manager.predict_with_custom_model(request.model_id, request.text) - return { - "model_id": request.model_id, - "text": request.text, - "entities": entities - } + return {"model_id": request.model_id, "text": request.text, "entities": entities} except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) # 多模态分析 API @app.post("/api/v1/tenants/{tenant_id}/projects/{project_id}/ai/multimodal", tags=["AI Enhancement"]) -async def analyze_multimodal( - tenant_id: str, - project_id: str, - request: MultimodalAnalysisRequest -): +async def analyze_multimodal(tenant_id: str, project_id: str, request: MultimodalAnalysisRequest): """多模态分析""" if not AI_MANAGER_AVAILABLE: raise HTTPException(status_code=500, detail="AI manager not available") @@ -11891,7 +11278,7 @@ async def analyze_multimodal( provider=MultimodalProvider(request.provider), input_type=request.input_type, input_urls=request.input_urls, - prompt=request.prompt + prompt=request.prompt, ) return { @@ -11901,7 +11288,7 @@ async def analyze_multimodal( "result": analysis.result, "tokens_used": analysis.tokens_used, "cost": analysis.cost, - "created_at": analysis.created_at + "created_at": analysis.created_at, } except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) @@ -11909,8 +11296,7 @@ async def analyze_multimodal( @app.get("/api/v1/tenants/{tenant_id}/ai/multimodal", tags=["AI Enhancement"]) async def list_multimodal_analyses( - tenant_id: str, - project_id: str | None = Query(default=None, description="项目ID过滤") + tenant_id: str, project_id: str | None = Query(default=None, description="项目ID过滤") ): """获取多模态分析历史""" if not AI_MANAGER_AVAILABLE: @@ -11930,7 +11316,7 @@ async def list_multimodal_analyses( "result": a.result, "tokens_used": a.tokens_used, "cost": a.cost, - "created_at": a.created_at + "created_at": a.created_at, } for a in analyses ] @@ -11939,11 +11325,7 @@ async def list_multimodal_analyses( # 知识图谱 RAG API @app.post("/api/v1/tenants/{tenant_id}/projects/{project_id}/ai/kg-rag", tags=["AI Enhancement"]) -async def create_kg_rag( - tenant_id: str, - project_id: str, - request: CreateKGRAGRequest -): +async def create_kg_rag(tenant_id: str, project_id: str, request: CreateKGRAGRequest): """创建知识图谱 RAG 配置""" if not AI_MANAGER_AVAILABLE: raise HTTPException(status_code=500, detail="AI manager not available") @@ -11957,7 +11339,7 @@ async def create_kg_rag( description=request.description, kg_config=request.kg_config, retrieval_config=request.retrieval_config, - generation_config=request.generation_config + generation_config=request.generation_config, ) return { @@ -11965,15 +11347,12 @@ async def create_kg_rag( "name": rag.name, "description": rag.description, "is_active": rag.is_active, - "created_at": rag.created_at + "created_at": rag.created_at, } @app.get("/api/v1/tenants/{tenant_id}/ai/kg-rag", tags=["AI Enhancement"]) -async def list_kg_rags( - tenant_id: str, - project_id: str | None = Query(default=None, description="项目ID过滤") -): +async def list_kg_rags(tenant_id: str, project_id: str | None = Query(default=None, description="项目ID过滤")): """列出知识图谱 RAG 配置""" if not AI_MANAGER_AVAILABLE: raise HTTPException(status_code=500, detail="AI manager not available") @@ -11989,7 +11368,7 @@ async def list_kg_rags( "name": r.name, "description": r.description, "is_active": r.is_active, - "created_at": r.created_at + "created_at": r.created_at, } for r in rags ] @@ -12000,7 +11379,7 @@ async def list_kg_rags( async def query_kg_rag( request: KGRAGQueryRequest, project_entities: list[dict] = Body(default=[], description="项目实体列表"), - project_relations: list[dict] = Body(default=[], description="项目关系列表") + project_relations: list[dict] = Body(default=[], description="项目关系列表"), ): """基于知识图谱的 RAG 查询""" if not AI_MANAGER_AVAILABLE: @@ -12013,7 +11392,7 @@ async def query_kg_rag( rag_id=request.rag_id, query=request.query, project_entities=project_entities, - project_relations=project_relations + project_relations=project_relations, ) return { @@ -12025,7 +11404,7 @@ async def query_kg_rag( "confidence": result.confidence, "tokens_used": result.tokens_used, "latency_ms": result.latency_ms, - "created_at": result.created_at + "created_at": result.created_at, } except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) @@ -12033,11 +11412,7 @@ async def query_kg_rag( # 智能摘要 API @app.post("/api/v1/tenants/{tenant_id}/projects/{project_id}/ai/summarize", tags=["AI Enhancement"]) -async def generate_smart_summary( - tenant_id: str, - project_id: str, - request: SmartSummaryRequest -): +async def generate_smart_summary(tenant_id: str, project_id: str, request: SmartSummaryRequest): """生成智能摘要""" if not AI_MANAGER_AVAILABLE: raise HTTPException(status_code=500, detail="AI manager not available") @@ -12050,7 +11425,7 @@ async def generate_smart_summary( source_type=request.source_type, source_id=request.source_id, summary_type=request.summary_type, - content_data=request.content_data + content_data=request.content_data, ) return { @@ -12063,7 +11438,7 @@ async def generate_smart_summary( "entities_mentioned": summary.entities_mentioned, "confidence": summary.confidence, "tokens_used": summary.tokens_used, - "created_at": summary.created_at + "created_at": summary.created_at, } @@ -12072,7 +11447,7 @@ async def list_smart_summaries( tenant_id: str, project_id: str, source_type: str | None = Query(default=None, description="来源类型过滤"), - source_id: str | None = Query(default=None, description="来源ID过滤") + source_id: str | None = Query(default=None, description="来源ID过滤"), ): """获取智能摘要列表""" if not AI_MANAGER_AVAILABLE: @@ -12086,11 +11461,7 @@ async def list_smart_summaries( # 预测模型 API @app.post("/api/v1/tenants/{tenant_id}/projects/{project_id}/ai/prediction-models", tags=["AI Enhancement"]) -async def create_prediction_model( - tenant_id: str, - project_id: str, - request: CreatePredictionModelRequest -): +async def create_prediction_model(tenant_id: str, project_id: str, request: CreatePredictionModelRequest): """创建预测模型""" if not AI_MANAGER_AVAILABLE: raise HTTPException(status_code=500, detail="AI manager not available") @@ -12105,7 +11476,7 @@ async def create_prediction_model( prediction_type=PredictionType(request.prediction_type), target_entity_type=request.target_entity_type, features=request.features, - model_config=request.model_config + model_config=request.model_config, ) return { @@ -12115,7 +11486,7 @@ async def create_prediction_model( "target_entity_type": model.target_entity_type, "features": model.features, "is_active": model.is_active, - "created_at": model.created_at + "created_at": model.created_at, } except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) @@ -12123,8 +11494,7 @@ async def create_prediction_model( @app.get("/api/v1/tenants/{tenant_id}/ai/prediction-models", tags=["AI Enhancement"]) async def list_prediction_models( - tenant_id: str, - project_id: str | None = Query(default=None, description="项目ID过滤") + tenant_id: str, project_id: str | None = Query(default=None, description="项目ID过滤") ): """列出预测模型""" if not AI_MANAGER_AVAILABLE: @@ -12145,7 +11515,7 @@ async def list_prediction_models( "accuracy": m.accuracy, "last_trained_at": m.last_trained_at, "prediction_count": m.prediction_count, - "is_active": m.is_active + "is_active": m.is_active, } for m in models ] @@ -12177,15 +11547,12 @@ async def get_prediction_model(model_id: str): "last_trained_at": model.last_trained_at, "prediction_count": model.prediction_count, "is_active": model.is_active, - "created_at": model.created_at + "created_at": model.created_at, } @app.post("/api/v1/ai/prediction-models/{model_id}/train", tags=["AI Enhancement"]) -async def train_prediction_model( - model_id: str, - historical_data: list[dict] = Body(..., description="历史训练数据") -): +async def train_prediction_model(model_id: str, historical_data: list[dict] = Body(..., description="历史训练数据")): """训练预测模型""" if not AI_MANAGER_AVAILABLE: raise HTTPException(status_code=500, detail="AI manager not available") @@ -12194,11 +11561,7 @@ async def train_prediction_model( try: model = await manager.train_prediction_model(model_id, historical_data) - return { - "id": model.id, - "accuracy": model.accuracy, - "last_trained_at": model.last_trained_at - } + return {"id": model.id, "accuracy": model.accuracy, "last_trained_at": model.last_trained_at} except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) @@ -12222,17 +11585,14 @@ async def predict(request: PredictDataRequest): "prediction_data": result.prediction_data, "confidence": result.confidence, "explanation": result.explanation, - "created_at": result.created_at + "created_at": result.created_at, } except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) @app.get("/api/v1/ai/prediction-models/{model_id}/results", tags=["AI Enhancement"]) -async def get_prediction_results( - model_id: str, - limit: int = Query(default=100, description="返回结果数量限制") -): +async def get_prediction_results(model_id: str, limit: int = Query(default=100, description="返回结果数量限制")): """获取预测结果历史""" if not AI_MANAGER_AVAILABLE: raise HTTPException(status_code=500, detail="AI manager not available") @@ -12251,7 +11611,7 @@ async def get_prediction_results( "explanation": r.explanation, "actual_value": r.actual_value, "is_correct": r.is_correct, - "created_at": r.created_at + "created_at": r.created_at, } for r in results ] @@ -12266,9 +11626,7 @@ async def update_prediction_feedback(request: PredictionFeedbackRequest): manager = get_ai_manager() manager.update_prediction_feedback( - prediction_id=request.prediction_id, - actual_value=request.actual_value, - is_correct=request.is_correct + prediction_id=request.prediction_id, actual_value=request.actual_value, is_correct=request.is_correct ) return {"status": "success", "message": "Feedback updated"} @@ -12276,6 +11634,7 @@ async def update_prediction_feedback(request: PredictionFeedbackRequest): # ==================== Phase 8 Task 5: Growth & Analytics Endpoints ==================== + # Pydantic Models for Growth API class TrackEventRequest(BaseModel): tenant_id: str @@ -12391,6 +11750,7 @@ def get_growth_manager_instance(): # ==================== 用户行为分析 API ==================== + @app.post("/api/v1/analytics/track", tags=["Growth & Analytics"]) async def track_event_endpoint(request: TrackEventRequest): """ @@ -12413,18 +11773,14 @@ async def track_event_endpoint(request: TrackEventRequest): session_id=request.session_id, device_info=request.device_info, referrer=request.referrer, - utm_params={ - "source": request.utm_source, - "medium": request.utm_medium, - "campaign": request.utm_campaign - } if any([request.utm_source, request.utm_medium, request.utm_campaign]) else None + utm_params=( + {"source": request.utm_source, "medium": request.utm_medium, "campaign": request.utm_campaign} + if any([request.utm_source, request.utm_medium, request.utm_campaign]) + else None + ), ) - return { - "success": True, - "event_id": event.id, - "timestamp": event.timestamp.isoformat() - } + return {"success": True, "event_id": event.id, "timestamp": event.timestamp.isoformat()} except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @@ -12442,11 +11798,7 @@ async def get_analytics_dashboard(tenant_id: str): @app.get("/api/v1/analytics/summary/{tenant_id}", tags=["Growth & Analytics"]) -async def get_analytics_summary( - tenant_id: str, - start_date: str | None = None, - end_date: str | None = None -): +async def get_analytics_summary(tenant_id: str, start_date: str | None = None, end_date: str | None = None): """获取用户分析汇总""" if not GROWTH_MANAGER_AVAILABLE: raise HTTPException(status_code=503, detail="Growth manager not available") @@ -12483,12 +11835,13 @@ async def get_user_profile(tenant_id: str, user_id: str): "feature_usage": profile.feature_usage, "ltv": profile.ltv, "churn_risk_score": profile.churn_risk_score, - "engagement_score": profile.engagement_score + "engagement_score": profile.engagement_score, } # ==================== 转化漏斗 API ==================== + @app.post("/api/v1/analytics/funnels", tags=["Growth & Analytics"]) async def create_funnel_endpoint(request: CreateFunnelRequest, created_by: str = "system"): """创建转化漏斗""" @@ -12505,23 +11858,14 @@ async def create_funnel_endpoint(request: CreateFunnelRequest, created_by: str = name=request.name, description=request.description, steps=request.steps, - created_by=created_by + created_by=created_by, ) - return { - "id": funnel.id, - "name": funnel.name, - "steps": funnel.steps, - "created_at": funnel.created_at - } + return {"id": funnel.id, "name": funnel.name, "steps": funnel.steps, "created_at": funnel.created_at} @app.get("/api/v1/analytics/funnels/{funnel_id}/analyze", tags=["Growth & Analytics"]) -async def analyze_funnel_endpoint( - funnel_id: str, - period_start: str | None = None, - period_end: str | None = None -): +async def analyze_funnel_endpoint(funnel_id: str, period_start: str | None = None, period_end: str | None = None): """分析漏斗转化率""" if not GROWTH_MANAGER_AVAILABLE: raise HTTPException(status_code=503, detail="Growth manager not available") @@ -12543,15 +11887,13 @@ async def analyze_funnel_endpoint( "total_users": analysis.total_users, "step_conversions": analysis.step_conversions, "overall_conversion": analysis.overall_conversion, - "drop_off_points": analysis.drop_off_points + "drop_off_points": analysis.drop_off_points, } @app.get("/api/v1/analytics/retention/{tenant_id}", tags=["Growth & Analytics"]) async def calculate_retention( - tenant_id: str, - cohort_date: str, - periods: str | None = None # JSON array: [1, 3, 7, 14, 30] + tenant_id: str, cohort_date: str, periods: str | None = None # JSON array: [1, 3, 7, 14, 30] ): """计算留存率""" if not GROWTH_MANAGER_AVAILABLE: @@ -12569,6 +11911,7 @@ async def calculate_retention( # ==================== A/B 测试 API ==================== + @app.post("/api/v1/experiments", tags=["Growth & Analytics"]) async def create_experiment_endpoint(request: CreateExperimentRequest, created_by: str = "system"): """创建 A/B 测试实验""" @@ -12593,7 +11936,7 @@ async def create_experiment_endpoint(request: CreateExperimentRequest, created_b secondary_metrics=request.secondary_metrics, min_sample_size=request.min_sample_size, confidence_level=request.confidence_level, - created_by=created_by + created_by=created_by, ) return { @@ -12601,7 +11944,7 @@ async def create_experiment_endpoint(request: CreateExperimentRequest, created_b "name": experiment.name, "status": experiment.status.value, "variants": experiment.variants, - "created_at": experiment.created_at + "created_at": experiment.created_at, } except Exception as e: raise HTTPException(status_code=400, detail=str(e)) @@ -12628,7 +11971,7 @@ async def list_experiments(status: str | None = None): "hypothesis": e.hypothesis, "primary_metric": e.primary_metric, "start_date": e.start_date.isoformat() if e.start_date else None, - "end_date": e.end_date.isoformat() if e.end_date else None + "end_date": e.end_date.isoformat() if e.end_date else None, } for e in experiments ] @@ -12658,7 +12001,7 @@ async def get_experiment_endpoint(experiment_id: str): "primary_metric": experiment.primary_metric, "secondary_metrics": experiment.secondary_metrics, "start_date": experiment.start_date.isoformat() if experiment.start_date else None, - "end_date": experiment.end_date.isoformat() if experiment.end_date else None + "end_date": experiment.end_date.isoformat() if experiment.end_date else None, } @@ -12671,19 +12014,13 @@ async def assign_variant_endpoint(experiment_id: str, request: AssignVariantRequ manager = get_growth_manager_instance() variant_id = manager.assign_variant( - experiment_id=experiment_id, - user_id=request.user_id, - user_attributes=request.user_attributes + experiment_id=experiment_id, user_id=request.user_id, user_attributes=request.user_attributes ) if not variant_id: raise HTTPException(status_code=400, detail="Failed to assign variant") - return { - "experiment_id": experiment_id, - "user_id": request.user_id, - "variant_id": variant_id - } + return {"experiment_id": experiment_id, "user_id": request.user_id, "variant_id": variant_id} @app.post("/api/v1/experiments/{experiment_id}/metrics", tags=["Growth & Analytics"]) @@ -12699,7 +12036,7 @@ async def record_experiment_metric_endpoint(experiment_id: str, request: RecordM variant_id=request.variant_id, user_id=request.user_id, metric_name=request.metric_name, - metric_value=request.metric_value + metric_value=request.metric_value, ) return {"success": True} @@ -12737,7 +12074,7 @@ async def start_experiment_endpoint(experiment_id: str): return { "id": experiment.id, "status": experiment.status.value, - "start_date": experiment.start_date.isoformat() if experiment.start_date else None + "start_date": experiment.start_date.isoformat() if experiment.start_date else None, } @@ -12757,12 +12094,13 @@ async def stop_experiment_endpoint(experiment_id: str): return { "id": experiment.id, "status": experiment.status.value, - "end_date": experiment.end_date.isoformat() if experiment.end_date else None + "end_date": experiment.end_date.isoformat() if experiment.end_date else None, } # ==================== 邮件营销 API ==================== + @app.post("/api/v1/email/templates", tags=["Growth & Analytics"]) async def create_email_template_endpoint(request: CreateEmailTemplateRequest): """创建邮件模板""" @@ -12783,7 +12121,7 @@ async def create_email_template_endpoint(request: CreateEmailTemplateRequest): variables=request.variables, from_name=request.from_name, from_email=request.from_email, - reply_to=request.reply_to + reply_to=request.reply_to, ) return { @@ -12792,7 +12130,7 @@ async def create_email_template_endpoint(request: CreateEmailTemplateRequest): "template_type": template.template_type.value, "subject": template.subject, "variables": template.variables, - "created_at": template.created_at + "created_at": template.created_at, } except Exception as e: raise HTTPException(status_code=400, detail=str(e)) @@ -12818,7 +12156,7 @@ async def list_email_templates(template_type: str | None = None): "template_type": t.template_type.value, "subject": t.subject, "variables": t.variables, - "is_active": t.is_active + "is_active": t.is_active, } for t in templates ] @@ -12846,7 +12184,7 @@ async def get_email_template_endpoint(template_id: str): "text_content": template.text_content, "variables": template.variables, "from_name": template.from_name, - "from_email": template.from_email + "from_email": template.from_email, } @@ -12882,7 +12220,7 @@ async def create_email_campaign_endpoint(request: CreateCampaignRequest): name=request.name, template_id=request.template_id, recipient_list=request.recipients, - scheduled_at=scheduled_at + scheduled_at=scheduled_at, ) return { @@ -12891,7 +12229,7 @@ async def create_email_campaign_endpoint(request: CreateCampaignRequest): "template_id": campaign.template_id, "status": campaign.status, "recipient_count": campaign.recipient_count, - "scheduled_at": campaign.scheduled_at + "scheduled_at": campaign.scheduled_at, } @@ -12926,7 +12264,7 @@ async def create_automation_workflow_endpoint(request: CreateAutomationWorkflowR description=request.description, trigger_type=WorkflowTriggerType(request.trigger_type), trigger_conditions=request.trigger_conditions, - actions=request.actions + actions=request.actions, ) return { @@ -12934,12 +12272,13 @@ async def create_automation_workflow_endpoint(request: CreateAutomationWorkflowR "name": workflow.name, "trigger_type": workflow.trigger_type.value, "is_active": workflow.is_active, - "created_at": workflow.created_at + "created_at": workflow.created_at, } # ==================== 推荐系统 API ==================== + @app.post("/api/v1/referral/programs", tags=["Growth & Analytics"]) async def create_referral_program_endpoint(request: CreateReferralProgramRequest): """创建推荐计划""" @@ -12959,7 +12298,7 @@ async def create_referral_program_endpoint(request: CreateReferralProgramRequest referee_reward_value=request.referee_reward_value, max_referrals_per_user=request.max_referrals_per_user, referral_code_length=request.referral_code_length, - expiry_days=request.expiry_days + expiry_days=request.expiry_days, ) return { @@ -12969,7 +12308,7 @@ async def create_referral_program_endpoint(request: CreateReferralProgramRequest "referrer_reward_value": program.referrer_reward_value, "referee_reward_type": program.referee_reward_type, "referee_reward_value": program.referee_reward_value, - "is_active": program.is_active + "is_active": program.is_active, } @@ -12991,7 +12330,7 @@ async def generate_referral_code_endpoint(program_id: str, referrer_id: str): "referral_code": referral.referral_code, "referrer_id": referral.referrer_id, "status": referral.status.value, - "expires_at": referral.expires_at.isoformat() + "expires_at": referral.expires_at.isoformat(), } @@ -13042,7 +12381,7 @@ async def create_team_incentive_endpoint(request: CreateTeamIncentiveRequest): incentive_type=request.incentive_type, incentive_value=request.incentive_value, valid_from=datetime.fromisoformat(request.valid_from), - valid_until=datetime.fromisoformat(request.valid_until) + valid_until=datetime.fromisoformat(request.valid_until), ) return { @@ -13053,16 +12392,12 @@ async def create_team_incentive_endpoint(request: CreateTeamIncentiveRequest): "incentive_type": incentive.incentive_type, "incentive_value": incentive.incentive_value, "valid_from": incentive.valid_from.isoformat(), - "valid_until": incentive.valid_until.isoformat() + "valid_until": incentive.valid_until.isoformat(), } @app.get("/api/v1/team-incentives/check", tags=["Growth & Analytics"]) -async def check_team_incentive_eligibility( - tenant_id: str, - current_tier: str, - team_size: int -): +async def check_team_incentive_eligibility(tenant_id: str, current_tier: str, team_size: int): """检查团队激励资格""" if not GROWTH_MANAGER_AVAILABLE: raise HTTPException(status_code=503, detail="Growth manager not available") @@ -13073,12 +12408,7 @@ async def check_team_incentive_eligibility( return { "eligible_incentives": [ - { - "id": i.id, - "name": i.name, - "incentive_type": i.incentive_type, - "incentive_value": i.incentive_value - } + {"id": i.id, "name": i.name, "incentive_type": i.incentive_type, "incentive_value": i.incentive_value} for i in incentives ] } @@ -13101,6 +12431,7 @@ try: TemplateCategory, TemplateStatus, ) + DEVELOPER_ECOSYSTEM_AVAILABLE = True except ImportError as e: print(f"Developer Ecosystem Manager import error: {e}") @@ -13253,10 +12584,10 @@ def get_developer_ecosystem_manager_instance(): # ==================== SDK Release & Management API ==================== + @app.post("/api/v1/developer/sdks", tags=["Developer Ecosystem"]) async def create_sdk_release_endpoint( - request: SDKReleaseCreate, - created_by: str = Header(default="system", description="创建者ID") + request: SDKReleaseCreate, created_by: str = Header(default="system", description="创建者ID") ): """创建 SDK 发布""" if not DEVELOPER_ECOSYSTEM_AVAILABLE: @@ -13279,7 +12610,7 @@ async def create_sdk_release_endpoint( dependencies=request.dependencies, file_size=request.file_size, checksum=request.checksum, - created_by=created_by + created_by=created_by, ) return { @@ -13289,7 +12620,7 @@ async def create_sdk_release_endpoint( "version": sdk.version, "status": sdk.status.value, "package_name": sdk.package_name, - "created_at": sdk.created_at + "created_at": sdk.created_at, } except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) @@ -13299,7 +12630,7 @@ async def create_sdk_release_endpoint( async def list_sdk_releases_endpoint( language: str | None = Query(default=None, description="SDK语言过滤"), status: str | None = Query(default=None, description="状态过滤"), - search: str | None = Query(default=None, description="搜索关键词") + search: str | None = Query(default=None, description="搜索关键词"), ): """列出 SDK 发布""" if not DEVELOPER_ECOSYSTEM_AVAILABLE: @@ -13323,7 +12654,7 @@ async def list_sdk_releases_endpoint( "package_name": s.package_name, "status": s.status.value, "download_count": s.download_count, - "created_at": s.created_at + "created_at": s.created_at, } for s in sdks ] @@ -13360,7 +12691,7 @@ async def get_sdk_release_endpoint(sdk_id: str): "checksum": sdk.checksum, "download_count": sdk.download_count, "created_at": sdk.created_at, - "published_at": sdk.published_at + "published_at": sdk.published_at, } @@ -13378,12 +12709,7 @@ async def update_sdk_release_endpoint(sdk_id: str, request: SDKReleaseUpdate): if not sdk: raise HTTPException(status_code=404, detail="SDK not found") - return { - "id": sdk.id, - "name": sdk.name, - "status": sdk.status.value, - "updated_at": sdk.updated_at - } + return {"id": sdk.id, "name": sdk.name, "status": sdk.status.value, "updated_at": sdk.updated_at} @app.post("/api/v1/developer/sdks/{sdk_id}/publish", tags=["Developer Ecosystem"]) @@ -13398,11 +12724,7 @@ async def publish_sdk_release_endpoint(sdk_id: str): if not sdk: raise HTTPException(status_code=404, detail="SDK not found") - return { - "id": sdk.id, - "status": sdk.status.value, - "published_at": sdk.published_at - } + return {"id": sdk.id, "status": sdk.status.value, "published_at": sdk.published_at} @app.post("/api/v1/developer/sdks/{sdk_id}/download", tags=["Developer Ecosystem"]) @@ -13434,7 +12756,7 @@ async def get_sdk_versions_endpoint(sdk_id: str): "is_latest": v.is_latest, "is_lts": v.is_lts, "download_count": v.download_count, - "created_at": v.created_at + "created_at": v.created_at, } for v in versions ] @@ -13456,7 +12778,7 @@ async def add_sdk_version_endpoint(sdk_id: str, request: SDKVersionCreate): release_notes=request.release_notes, download_url=request.download_url, checksum=request.checksum, - file_size=request.file_size + file_size=request.file_size, ) return { @@ -13464,17 +12786,18 @@ async def add_sdk_version_endpoint(sdk_id: str, request: SDKVersionCreate): "version": version.version, "is_latest": version.is_latest, "is_lts": version.is_lts, - "created_at": version.created_at + "created_at": version.created_at, } # ==================== Template Market API ==================== + @app.post("/api/v1/developer/templates", tags=["Developer Ecosystem"]) async def create_template_endpoint( request: TemplateCreate, author_id: str = Header(default="system", description="作者ID"), - author_name: str = Header(default="System", description="作者名称") + author_name: str = Header(default="System", description="作者名称"), ): """创建模板""" if not DEVELOPER_ECOSYSTEM_AVAILABLE: @@ -13500,7 +12823,7 @@ async def create_template_endpoint( version=request.version, min_platform_version=request.min_platform_version, file_size=request.file_size, - checksum=request.checksum + checksum=request.checksum, ) return { @@ -13509,7 +12832,7 @@ async def create_template_endpoint( "category": template.category.value, "status": template.status.value, "price": template.price, - "created_at": template.created_at + "created_at": template.created_at, } except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) @@ -13523,7 +12846,7 @@ async def list_templates_endpoint( author_id: str | None = Query(default=None, description="作者ID过滤"), min_price: float | None = Query(default=None, description="最低价格"), max_price: float | None = Query(default=None, description="最高价格"), - sort_by: str = Query(default="created_at", description="排序方式") + sort_by: str = Query(default="created_at", description="排序方式"), ): """列出模板""" if not DEVELOPER_ECOSYSTEM_AVAILABLE: @@ -13541,7 +12864,7 @@ async def list_templates_endpoint( author_id=author_id, min_price=min_price, max_price=max_price, - sort_by=sort_by + sort_by=sort_by, ) return { @@ -13558,7 +12881,7 @@ async def list_templates_endpoint( "rating": t.rating, "install_count": t.install_count, "version": t.version, - "created_at": t.created_at + "created_at": t.created_at, } for t in templates ] @@ -13598,7 +12921,7 @@ async def get_template_endpoint(template_id: str): "rating_count": template.rating_count, "review_count": template.review_count, "version": template.version, - "created_at": template.created_at + "created_at": template.created_at, } @@ -13614,10 +12937,7 @@ async def approve_template_endpoint(template_id: str, reviewed_by: str = Header( if not template: raise HTTPException(status_code=404, detail="Template not found") - return { - "id": template.id, - "status": template.status.value - } + return {"id": template.id, "status": template.status.value} @app.post("/api/v1/developer/templates/{template_id}/publish", tags=["Developer Ecosystem"]) @@ -13632,11 +12952,7 @@ async def publish_template_endpoint(template_id: str): if not template: raise HTTPException(status_code=404, detail="Template not found") - return { - "id": template.id, - "status": template.status.value, - "published_at": template.published_at - } + return {"id": template.id, "status": template.status.value, "published_at": template.published_at} @app.post("/api/v1/developer/templates/{template_id}/reject", tags=["Developer Ecosystem"]) @@ -13651,10 +12967,7 @@ async def reject_template_endpoint(template_id: str, reason: str = ""): if not template: raise HTTPException(status_code=404, detail="Template not found") - return { - "id": template.id, - "status": template.status.value - } + return {"id": template.id, "status": template.status.value} @app.post("/api/v1/developer/templates/{template_id}/install", tags=["Developer Ecosystem"]) @@ -13674,7 +12987,7 @@ async def add_template_review_endpoint( template_id: str, request: TemplateReviewCreate, user_id: str = Header(default="user", description="用户ID"), - user_name: str = Header(default="User", description="用户名称") + user_name: str = Header(default="User", description="用户名称"), ): """添加模板评价""" if not DEVELOPER_ECOSYSTEM_AVAILABLE: @@ -13688,22 +13001,14 @@ async def add_template_review_endpoint( user_name=user_name, rating=request.rating, comment=request.comment, - is_verified_purchase=request.is_verified_purchase + is_verified_purchase=request.is_verified_purchase, ) - return { - "id": review.id, - "rating": review.rating, - "comment": review.comment, - "created_at": review.created_at - } + return {"id": review.id, "rating": review.rating, "comment": review.comment, "created_at": review.created_at} @app.get("/api/v1/developer/templates/{template_id}/reviews", tags=["Developer Ecosystem"]) -async def get_template_reviews_endpoint( - template_id: str, - limit: int = Query(default=50, description="返回数量限制") -): +async def get_template_reviews_endpoint(template_id: str, limit: int = Query(default=50, description="返回数量限制")): """获取模板评价""" if not DEVELOPER_ECOSYSTEM_AVAILABLE: raise HTTPException(status_code=503, detail="Developer ecosystem manager not available") @@ -13720,7 +13025,7 @@ async def get_template_reviews_endpoint( "comment": r.comment, "is_verified_purchase": r.is_verified_purchase, "helpful_count": r.helpful_count, - "created_at": r.created_at + "created_at": r.created_at, } for r in reviews ] @@ -13729,11 +13034,12 @@ async def get_template_reviews_endpoint( # ==================== Plugin Market API ==================== + @app.post("/api/v1/developer/plugins", tags=["Developer Ecosystem"]) async def create_developer_plugin_endpoint( request: PluginCreate, author_id: str = Header(default="system", description="作者ID"), - author_name: str = Header(default="System", description="作者名称") + author_name: str = Header(default="System", description="作者名称"), ): """创建插件""" if not DEVELOPER_ECOSYSTEM_AVAILABLE: @@ -13762,7 +13068,7 @@ async def create_developer_plugin_endpoint( version=request.version, min_platform_version=request.min_platform_version, file_size=request.file_size, - checksum=request.checksum + checksum=request.checksum, ) return { @@ -13772,7 +13078,7 @@ async def create_developer_plugin_endpoint( "status": plugin.status.value, "price": plugin.price, "pricing_model": plugin.pricing_model, - "created_at": plugin.created_at + "created_at": plugin.created_at, } except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) @@ -13784,7 +13090,7 @@ async def list_developer_plugins_endpoint( status: str | None = Query(default=None, description="状态过滤"), search: str | None = Query(default=None, description="搜索关键词"), author_id: str | None = Query(default=None, description="作者ID过滤"), - sort_by: str = Query(default="created_at", description="排序方式") + sort_by: str = Query(default="created_at", description="排序方式"), ): """列出插件""" if not DEVELOPER_ECOSYSTEM_AVAILABLE: @@ -13796,11 +13102,7 @@ async def list_developer_plugins_endpoint( status_enum = PluginStatus(status) if status else None plugins = manager.list_plugins( - category=category_enum, - status=status_enum, - search=search, - author_id=author_id, - sort_by=sort_by + category=category_enum, status=status_enum, search=search, author_id=author_id, sort_by=sort_by ) return { @@ -13818,7 +13120,7 @@ async def list_developer_plugins_endpoint( "install_count": p.install_count, "active_install_count": p.active_install_count, "version": p.version, - "created_at": p.created_at + "created_at": p.created_at, } for p in plugins ] @@ -13861,7 +13163,7 @@ async def get_developer_plugin_endpoint(plugin_id: str): "version": plugin.version, "reviewed_by": plugin.reviewed_by, "reviewed_at": plugin.reviewed_at, - "created_at": plugin.created_at + "created_at": plugin.created_at, } @@ -13870,7 +13172,7 @@ async def review_plugin_endpoint( plugin_id: str, status: str = Query(..., description="审核状态: approved/rejected"), reviewed_by: str = Header(default="system", description="审核人ID"), - notes: str = "" + notes: str = "", ): """审核插件""" if not DEVELOPER_ECOSYSTEM_AVAILABLE: @@ -13889,7 +13191,7 @@ async def review_plugin_endpoint( "id": plugin.id, "status": plugin.status.value, "reviewed_by": plugin.reviewed_by, - "reviewed_at": plugin.reviewed_at + "reviewed_at": plugin.reviewed_at, } except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) @@ -13907,11 +13209,7 @@ async def publish_plugin_endpoint(plugin_id: str): if not plugin: raise HTTPException(status_code=404, detail="Plugin not found") - return { - "id": plugin.id, - "status": plugin.status.value, - "published_at": plugin.published_at - } + return {"id": plugin.id, "status": plugin.status.value, "published_at": plugin.published_at} @app.post("/api/v1/developer/plugins/{plugin_id}/install", tags=["Developer Ecosystem"]) @@ -13931,7 +13229,7 @@ async def add_plugin_review_endpoint( plugin_id: str, request: PluginReviewCreate, user_id: str = Header(default="user", description="用户ID"), - user_name: str = Header(default="User", description="用户名称") + user_name: str = Header(default="User", description="用户名称"), ): """添加插件评价""" if not DEVELOPER_ECOSYSTEM_AVAILABLE: @@ -13945,22 +13243,14 @@ async def add_plugin_review_endpoint( user_name=user_name, rating=request.rating, comment=request.comment, - is_verified_purchase=request.is_verified_purchase + is_verified_purchase=request.is_verified_purchase, ) - return { - "id": review.id, - "rating": review.rating, - "comment": review.comment, - "created_at": review.created_at - } + return {"id": review.id, "rating": review.rating, "comment": review.comment, "created_at": review.created_at} @app.get("/api/v1/developer/plugins/{plugin_id}/reviews", tags=["Developer Ecosystem"]) -async def get_plugin_reviews_endpoint( - plugin_id: str, - limit: int = Query(default=50, description="返回数量限制") -): +async def get_plugin_reviews_endpoint(plugin_id: str, limit: int = Query(default=50, description="返回数量限制")): """获取插件评价""" if not DEVELOPER_ECOSYSTEM_AVAILABLE: raise HTTPException(status_code=503, detail="Developer ecosystem manager not available") @@ -13977,7 +13267,7 @@ async def get_plugin_reviews_endpoint( "comment": r.comment, "is_verified_purchase": r.is_verified_purchase, "helpful_count": r.helpful_count, - "created_at": r.created_at + "created_at": r.created_at, } for r in reviews ] @@ -13986,11 +13276,12 @@ async def get_plugin_reviews_endpoint( # ==================== Developer Revenue Sharing API ==================== + @app.get("/api/v1/developer/revenues/{developer_id}", tags=["Developer Ecosystem"]) async def get_developer_revenues_endpoint( developer_id: str, start_date: str | None = Query(default=None, description="开始日期 (ISO格式)"), - end_date: str | None = Query(default=None, description="结束日期 (ISO格式)") + end_date: str | None = Query(default=None, description="结束日期 (ISO格式)"), ): """获取开发者收益记录""" if not DEVELOPER_ECOSYSTEM_AVAILABLE: @@ -14013,7 +13304,7 @@ async def get_developer_revenues_endpoint( "platform_fee": r.platform_fee, "developer_earnings": r.developer_earnings, "currency": r.currency, - "created_at": r.created_at + "created_at": r.created_at, } for r in revenues ] @@ -14034,6 +13325,7 @@ async def get_developer_revenue_summary_endpoint(developer_id: str): # ==================== Developer Profile & Management API ==================== + @app.post("/api/v1/developer/profiles", tags=["Developer Ecosystem"]) async def create_developer_profile_endpoint(request: DeveloperProfileCreate): """创建开发者档案""" @@ -14051,7 +13343,7 @@ async def create_developer_profile_endpoint(request: DeveloperProfileCreate): bio=request.bio, website=request.website, github_url=request.github_url, - avatar_url=request.avatar_url + avatar_url=request.avatar_url, ) return { @@ -14060,7 +13352,7 @@ async def create_developer_profile_endpoint(request: DeveloperProfileCreate): "display_name": profile.display_name, "email": profile.email, "status": profile.status.value, - "created_at": profile.created_at + "created_at": profile.created_at, } @@ -14092,7 +13384,7 @@ async def get_developer_profile_endpoint(developer_id: str): "template_count": profile.template_count, "rating_average": profile.rating_average, "created_at": profile.created_at, - "verified_at": profile.verified_at + "verified_at": profile.verified_at, } @@ -14114,7 +13406,7 @@ async def get_developer_profile_by_user_endpoint(user_id: str): "display_name": profile.display_name, "status": profile.status.value, "total_sales": profile.total_sales, - "total_downloads": profile.total_downloads + "total_downloads": profile.total_downloads, } @@ -14129,8 +13421,7 @@ async def update_developer_profile_endpoint(developer_id: str, request: Develope @app.post("/api/v1/developer/profiles/{developer_id}/verify", tags=["Developer Ecosystem"]) async def verify_developer_endpoint( - developer_id: str, - status: str = Query(..., description="认证状态: verified/certified/suspended") + developer_id: str, status: str = Query(..., description="认证状态: verified/certified/suspended") ): """验证开发者""" if not DEVELOPER_ECOSYSTEM_AVAILABLE: @@ -14145,11 +13436,7 @@ async def verify_developer_endpoint( if not profile: raise HTTPException(status_code=404, detail="Developer profile not found") - return { - "id": profile.id, - "status": profile.status.value, - "verified_at": profile.verified_at - } + return {"id": profile.id, "status": profile.status.value, "verified_at": profile.verified_at} except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) @@ -14168,11 +13455,12 @@ async def update_developer_stats_endpoint(developer_id: str): # ==================== Code Examples API ==================== + @app.post("/api/v1/developer/code-examples", tags=["Developer Ecosystem"]) async def create_code_example_endpoint( request: CodeExampleCreate, author_id: str = Header(default="system", description="作者ID"), - author_name: str = Header(default="System", description="作者名称") + author_name: str = Header(default="System", description="作者名称"), ): """创建代码示例""" if not DEVELOPER_ECOSYSTEM_AVAILABLE: @@ -14191,7 +13479,7 @@ async def create_code_example_endpoint( author_id=author_id, author_name=author_name, sdk_id=request.sdk_id, - api_endpoints=request.api_endpoints + api_endpoints=request.api_endpoints, ) return { @@ -14200,7 +13488,7 @@ async def create_code_example_endpoint( "language": example.language, "category": example.category, "tags": example.tags, - "created_at": example.created_at + "created_at": example.created_at, } @@ -14209,7 +13497,7 @@ async def list_code_examples_endpoint( language: str | None = Query(default=None, description="编程语言过滤"), category: str | None = Query(default=None, description="分类过滤"), sdk_id: str | None = Query(default=None, description="SDK ID过滤"), - search: str | None = Query(default=None, description="搜索关键词") + search: str | None = Query(default=None, description="搜索关键词"), ): """列出代码示例""" if not DEVELOPER_ECOSYSTEM_AVAILABLE: @@ -14231,7 +13519,7 @@ async def list_code_examples_endpoint( "view_count": e.view_count, "copy_count": e.copy_count, "rating": e.rating, - "created_at": e.created_at + "created_at": e.created_at, } for e in examples ] @@ -14267,7 +13555,7 @@ async def get_code_example_endpoint(example_id: str): "view_count": example.view_count, "copy_count": example.copy_count, "rating": example.rating, - "created_at": example.created_at + "created_at": example.created_at, } @@ -14285,6 +13573,7 @@ async def copy_code_example_endpoint(example_id: str): # ==================== API Documentation API ==================== + @app.get("/api/v1/developer/api-docs", tags=["Developer Ecosystem"]) async def get_latest_api_documentation_endpoint(): """获取最新 API 文档""" @@ -14302,7 +13591,7 @@ async def get_latest_api_documentation_endpoint(): "version": doc.version, "changelog": doc.changelog, "generated_at": doc.generated_at, - "generated_by": doc.generated_by + "generated_by": doc.generated_by, } @@ -14326,12 +13615,13 @@ async def get_api_documentation_endpoint(doc_id: str): "html_content": doc.html_content, "changelog": doc.changelog, "generated_at": doc.generated_at, - "generated_by": doc.generated_by + "generated_by": doc.generated_by, } # ==================== Developer Portal API ==================== + @app.post("/api/v1/developer/portal-configs", tags=["Developer Ecosystem"]) async def create_portal_config_endpoint(request: PortalConfigCreate): """创建开发者门户配置""" @@ -14354,7 +13644,7 @@ async def create_portal_config_endpoint(request: PortalConfigCreate): support_url=request.support_url, github_url=request.github_url, discord_url=request.discord_url, - api_base_url=request.api_base_url + api_base_url=request.api_base_url, ) return { @@ -14362,7 +13652,7 @@ async def create_portal_config_endpoint(request: PortalConfigCreate): "name": config.name, "theme": config.theme, "is_active": config.is_active, - "created_at": config.created_at + "created_at": config.created_at, } @@ -14392,7 +13682,7 @@ async def get_active_portal_config_endpoint(): "github_url": config.github_url, "discord_url": config.discord_url, "api_base_url": config.api_base_url, - "is_active": config.is_active + "is_active": config.is_active, } @@ -14417,7 +13707,7 @@ async def get_portal_config_endpoint(config_id: str): "secondary_color": config.secondary_color, "support_email": config.support_email, "api_base_url": config.api_base_url, - "is_active": config.is_active + "is_active": config.is_active, } @@ -14471,8 +13761,9 @@ class AlertRuleResponse(BaseModel): class AlertChannelCreate(BaseModel): name: str = Field(..., description="渠道名称") - channel_type: str = Field(..., - description="渠道类型: pagerduty, opsgenie, feishu, dingtalk, slack, email, sms, webhook") + channel_type: str = Field( + ..., description="渠道类型: pagerduty, opsgenie, feishu, dingtalk, slack, email, sms, webhook" + ) config: dict = Field(default_factory=dict, description="渠道特定配置") severity_filter: list[str] = Field(default_factory=lambda: ["p0", "p1", "p2", "p3"], description="过滤的告警级别") @@ -14558,10 +13849,7 @@ class BackupJobCreate(BaseModel): # Alert Rules API @app.post("/api/v1/ops/alert-rules", response_model=AlertRuleResponse, tags=["Operations & Monitoring"]) async def create_alert_rule_endpoint( - tenant_id: str, - request: AlertRuleCreate, - user_id: str = "system", - _=Depends(verify_api_key) + tenant_id: str, request: AlertRuleCreate, user_id: str = "system", _=Depends(verify_api_key) ): """创建告警规则""" if not OPS_MANAGER_AVAILABLE: @@ -14584,7 +13872,7 @@ async def create_alert_rule_endpoint( channels=request.channels, labels=request.labels, annotations=request.annotations, - created_by=user_id + created_by=user_id, ) return AlertRuleResponse( @@ -14603,18 +13891,14 @@ async def create_alert_rule_endpoint( annotations=rule.annotations, is_enabled=rule.is_enabled, created_at=rule.created_at, - updated_at=rule.updated_at + updated_at=rule.updated_at, ) except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) @app.get("/api/v1/ops/alert-rules", tags=["Operations & Monitoring"]) -async def list_alert_rules_endpoint( - tenant_id: str, - is_enabled: bool | None = None, - _=Depends(verify_api_key) -): +async def list_alert_rules_endpoint(tenant_id: str, is_enabled: bool | None = None, _=Depends(verify_api_key)): """列出租户的告警规则""" if not OPS_MANAGER_AVAILABLE: raise HTTPException(status_code=503, detail="Operations manager not available") @@ -14639,7 +13923,7 @@ async def list_alert_rules_endpoint( annotations=rule.annotations, is_enabled=rule.is_enabled, created_at=rule.created_at, - updated_at=rule.updated_at + updated_at=rule.updated_at, ) for rule in rules ] @@ -14673,16 +13957,12 @@ async def get_alert_rule_endpoint(rule_id: str, _=Depends(verify_api_key)): annotations=rule.annotations, is_enabled=rule.is_enabled, created_at=rule.created_at, - updated_at=rule.updated_at + updated_at=rule.updated_at, ) @app.patch("/api/v1/ops/alert-rules/{rule_id}", response_model=AlertRuleResponse, tags=["Operations & Monitoring"]) -async def update_alert_rule_endpoint( - rule_id: str, - updates: dict, - _=Depends(verify_api_key) -): +async def update_alert_rule_endpoint(rule_id: str, updates: dict, _=Depends(verify_api_key)): """更新告警规则""" if not OPS_MANAGER_AVAILABLE: raise HTTPException(status_code=503, detail="Operations manager not available") @@ -14709,7 +13989,7 @@ async def update_alert_rule_endpoint( annotations=rule.annotations, is_enabled=rule.is_enabled, created_at=rule.created_at, - updated_at=rule.updated_at + updated_at=rule.updated_at, ) @@ -14730,11 +14010,7 @@ async def delete_alert_rule_endpoint(rule_id: str, _=Depends(verify_api_key)): # Alert Channels API @app.post("/api/v1/ops/alert-channels", response_model=AlertChannelResponse, tags=["Operations & Monitoring"]) -async def create_alert_channel_endpoint( - tenant_id: str, - request: AlertChannelCreate, - _=Depends(verify_api_key) -): +async def create_alert_channel_endpoint(tenant_id: str, request: AlertChannelCreate, _=Depends(verify_api_key)): """创建告警渠道""" if not OPS_MANAGER_AVAILABLE: raise HTTPException(status_code=503, detail="Operations manager not available") @@ -14747,7 +14023,7 @@ async def create_alert_channel_endpoint( name=request.name, channel_type=AlertChannelType(request.channel_type), config=request.config, - severity_filter=request.severity_filter + severity_filter=request.severity_filter, ) return AlertChannelResponse( @@ -14760,7 +14036,7 @@ async def create_alert_channel_endpoint( success_count=channel.success_count, fail_count=channel.fail_count, last_used_at=channel.last_used_at, - created_at=channel.created_at + created_at=channel.created_at, ) except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) @@ -14786,7 +14062,7 @@ async def list_alert_channels_endpoint(tenant_id: str, _=Depends(verify_api_key) success_count=channel.success_count, fail_count=channel.fail_count, last_used_at=channel.last_used_at, - created_at=channel.created_at + created_at=channel.created_at, ) for channel in channels ] @@ -14810,11 +14086,7 @@ async def test_alert_channel_endpoint(channel_id: str, _=Depends(verify_api_key) # Alerts API @app.get("/api/v1/ops/alerts", tags=["Operations & Monitoring"]) async def list_alerts_endpoint( - tenant_id: str, - status: str | None = None, - severity: str | None = None, - limit: int = 100, - _=Depends(verify_api_key) + tenant_id: str, status: str | None = None, severity: str | None = None, limit: int = 100, _=Depends(verify_api_key) ): """列出租户的告警""" if not OPS_MANAGER_AVAILABLE: @@ -14842,18 +14114,14 @@ async def list_alerts_endpoint( started_at=alert.started_at, resolved_at=alert.resolved_at, acknowledged_by=alert.acknowledged_by, - suppression_count=alert.suppression_count + suppression_count=alert.suppression_count, ) for alert in alerts ] @app.post("/api/v1/ops/alerts/{alert_id}/acknowledge", tags=["Operations & Monitoring"]) -async def acknowledge_alert_endpoint( - alert_id: str, - user_id: str = "system", - _=Depends(verify_api_key) -): +async def acknowledge_alert_endpoint(alert_id: str, user_id: str = "system", _=Depends(verify_api_key)): """确认告警""" if not OPS_MANAGER_AVAILABLE: raise HTTPException(status_code=503, detail="Operations manager not available") @@ -14892,7 +14160,7 @@ async def record_resource_metric_endpoint( metric_value: float, unit: str, metadata: dict = None, - _=Depends(verify_api_key) + _=Depends(verify_api_key), ): """记录资源指标""" if not OPS_MANAGER_AVAILABLE: @@ -14908,7 +14176,7 @@ async def record_resource_metric_endpoint( metric_name=metric_name, metric_value=metric_value, unit=unit, - metadata=metadata + metadata=metadata, ) return { @@ -14917,7 +14185,7 @@ async def record_resource_metric_endpoint( "metric_name": metric.metric_name, "metric_value": metric.metric_value, "unit": metric.unit, - "timestamp": metric.timestamp + "timestamp": metric.timestamp, } except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) @@ -14925,10 +14193,7 @@ async def record_resource_metric_endpoint( @app.get("/api/v1/ops/resource-metrics", tags=["Operations & Monitoring"]) async def get_resource_metrics_endpoint( - tenant_id: str, - metric_name: str, - seconds: int = 3600, - _=Depends(verify_api_key) + tenant_id: str, metric_name: str, seconds: int = 3600, _=Depends(verify_api_key) ): """获取资源指标数据""" if not OPS_MANAGER_AVAILABLE: @@ -14945,7 +14210,7 @@ async def get_resource_metrics_endpoint( "metric_name": m.metric_name, "metric_value": m.metric_value, "unit": m.unit, - "timestamp": m.timestamp + "timestamp": m.timestamp, } for m in metrics ] @@ -14959,7 +14224,7 @@ async def create_capacity_plan_endpoint( current_capacity: float, prediction_date: str, confidence: float = 0.8, - _=Depends(verify_api_key) + _=Depends(verify_api_key), ): """创建容量规划""" if not OPS_MANAGER_AVAILABLE: @@ -14973,7 +14238,7 @@ async def create_capacity_plan_endpoint( resource_type=ResourceType(resource_type), current_capacity=current_capacity, prediction_date=prediction_date, - confidence=confidence + confidence=confidence, ) return { @@ -14985,7 +14250,7 @@ async def create_capacity_plan_endpoint( "confidence": plan.confidence, "recommended_action": plan.recommended_action, "estimated_cost": plan.estimated_cost, - "created_at": plan.created_at + "created_at": plan.created_at, } except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) @@ -15010,7 +14275,7 @@ async def list_capacity_plans_endpoint(tenant_id: str, _=Depends(verify_api_key) "confidence": plan.confidence, "recommended_action": plan.recommended_action, "estimated_cost": plan.estimated_cost, - "created_at": plan.created_at + "created_at": plan.created_at, } for plan in plans ] @@ -15019,9 +14284,7 @@ async def list_capacity_plans_endpoint(tenant_id: str, _=Depends(verify_api_key) # Auto Scaling API @app.post("/api/v1/ops/auto-scaling-policies", tags=["Operations & Monitoring"]) async def create_auto_scaling_policy_endpoint( - tenant_id: str, - request: AutoScalingPolicyCreate, - _=Depends(verify_api_key) + tenant_id: str, request: AutoScalingPolicyCreate, _=Depends(verify_api_key) ): """创建自动扩缩容策略""" if not OPS_MANAGER_AVAILABLE: @@ -15041,7 +14304,7 @@ async def create_auto_scaling_policy_endpoint( scale_down_threshold=request.scale_down_threshold, scale_up_step=request.scale_up_step, scale_down_step=request.scale_down_step, - cooldown_period=request.cooldown_period + cooldown_period=request.cooldown_period, ) return { @@ -15054,7 +14317,7 @@ async def create_auto_scaling_policy_endpoint( "scale_up_threshold": policy.scale_up_threshold, "scale_down_threshold": policy.scale_down_threshold, "is_enabled": policy.is_enabled, - "created_at": policy.created_at + "created_at": policy.created_at, } except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) @@ -15078,7 +14341,7 @@ async def list_auto_scaling_policies_endpoint(tenant_id: str, _=Depends(verify_a "max_instances": policy.max_instances, "target_utilization": policy.target_utilization, "is_enabled": policy.is_enabled, - "created_at": policy.created_at + "created_at": policy.created_at, } for policy in policies ] @@ -15086,10 +14349,7 @@ async def list_auto_scaling_policies_endpoint(tenant_id: str, _=Depends(verify_a @app.get("/api/v1/ops/scaling-events", tags=["Operations & Monitoring"]) async def list_scaling_events_endpoint( - tenant_id: str, - policy_id: str | None = None, - limit: int = 100, - _=Depends(verify_api_key) + tenant_id: str, policy_id: str | None = None, limit: int = 100, _=Depends(verify_api_key) ): """获取扩缩容事件列表""" if not OPS_MANAGER_AVAILABLE: @@ -15108,7 +14368,7 @@ async def list_scaling_events_endpoint( "reason": event.reason, "status": event.status, "started_at": event.started_at, - "completed_at": event.completed_at + "completed_at": event.completed_at, } for event in events ] @@ -15116,11 +14376,7 @@ async def list_scaling_events_endpoint( # Health Check API @app.post("/api/v1/ops/health-checks", response_model=HealthCheckResponse, tags=["Operations & Monitoring"]) -async def create_health_check_endpoint( - tenant_id: str, - request: HealthCheckCreate, - _=Depends(verify_api_key) -): +async def create_health_check_endpoint(tenant_id: str, request: HealthCheckCreate, _=Depends(verify_api_key)): """创建健康检查""" if not OPS_MANAGER_AVAILABLE: raise HTTPException(status_code=503, detail="Operations manager not available") @@ -15136,7 +14392,7 @@ async def create_health_check_endpoint( check_config=request.check_config, interval=request.interval, timeout=request.timeout, - retry_count=request.retry_count + retry_count=request.retry_count, ) return HealthCheckResponse( @@ -15148,7 +14404,7 @@ async def create_health_check_endpoint( interval=check.interval, timeout=check.timeout, is_enabled=check.is_enabled, - created_at=check.created_at + created_at=check.created_at, ) @@ -15171,7 +14427,7 @@ async def list_health_checks_endpoint(tenant_id: str, _=Depends(verify_api_key)) "interval": check.interval, "timeout": check.timeout, "is_enabled": check.is_enabled, - "created_at": check.created_at + "created_at": check.created_at, } for check in checks ] @@ -15192,17 +14448,13 @@ async def execute_health_check_endpoint(check_id: str, _=Depends(verify_api_key) "status": result.status.value, "response_time": result.response_time, "message": result.message, - "checked_at": result.checked_at + "checked_at": result.checked_at, } # Backup API @app.post("/api/v1/ops/backup-jobs", tags=["Operations & Monitoring"]) -async def create_backup_job_endpoint( - tenant_id: str, - request: BackupJobCreate, - _=Depends(verify_api_key) -): +async def create_backup_job_endpoint(tenant_id: str, request: BackupJobCreate, _=Depends(verify_api_key)): """创建备份任务""" if not OPS_MANAGER_AVAILABLE: raise HTTPException(status_code=503, detail="Operations manager not available") @@ -15219,7 +14471,7 @@ async def create_backup_job_endpoint( retention_days=request.retention_days, encryption_enabled=request.encryption_enabled, compression_enabled=request.compression_enabled, - storage_location=request.storage_location + storage_location=request.storage_location, ) return { @@ -15229,7 +14481,7 @@ async def create_backup_job_endpoint( "target_type": job.target_type, "schedule": job.schedule, "is_enabled": job.is_enabled, - "created_at": job.created_at + "created_at": job.created_at, } @@ -15250,7 +14502,7 @@ async def list_backup_jobs_endpoint(tenant_id: str, _=Depends(verify_api_key)): "target_type": job.target_type, "schedule": job.schedule, "is_enabled": job.is_enabled, - "created_at": job.created_at + "created_at": job.created_at, } for job in jobs ] @@ -15273,16 +14525,13 @@ async def execute_backup_endpoint(job_id: str, _=Depends(verify_api_key)): "job_id": record.job_id, "status": record.status.value, "started_at": record.started_at, - "storage_path": record.storage_path + "storage_path": record.storage_path, } @app.get("/api/v1/ops/backup-records", tags=["Operations & Monitoring"]) async def list_backup_records_endpoint( - tenant_id: str, - job_id: str | None = None, - limit: int = 100, - _=Depends(verify_api_key) + tenant_id: str, job_id: str | None = None, limit: int = 100, _=Depends(verify_api_key) ): """获取备份记录列表""" if not OPS_MANAGER_AVAILABLE: @@ -15300,7 +14549,7 @@ async def list_backup_records_endpoint( "checksum": record.checksum, "started_at": record.started_at, "completed_at": record.completed_at, - "storage_path": record.storage_path + "storage_path": record.storage_path, } for record in records ] @@ -15308,12 +14557,7 @@ async def list_backup_records_endpoint( # Cost Optimization API @app.post("/api/v1/ops/cost-reports", tags=["Operations & Monitoring"]) -async def generate_cost_report_endpoint( - tenant_id: str, - year: int, - month: int, - _=Depends(verify_api_key) -): +async def generate_cost_report_endpoint(tenant_id: str, year: int, month: int, _=Depends(verify_api_key)): """生成成本报告""" if not OPS_MANAGER_AVAILABLE: raise HTTPException(status_code=503, detail="Operations manager not available") @@ -15329,7 +14573,7 @@ async def generate_cost_report_endpoint( "breakdown": report.breakdown, "trends": report.trends, "anomalies": report.anomalies, - "created_at": report.created_at + "created_at": report.created_at, } @@ -15352,17 +14596,14 @@ async def get_idle_resources_endpoint(tenant_id: str, _=Depends(verify_api_key)) "estimated_monthly_cost": resource.estimated_monthly_cost, "currency": resource.currency, "reason": resource.reason, - "recommendation": resource.recommendation + "recommendation": resource.recommendation, } for resource in idle_resources ] @app.post("/api/v1/ops/cost-optimization-suggestions", tags=["Operations & Monitoring"]) -async def generate_cost_optimization_suggestions_endpoint( - tenant_id: str, - _=Depends(verify_api_key) -): +async def generate_cost_optimization_suggestions_endpoint(tenant_id: str, _=Depends(verify_api_key)): """生成成本优化建议""" if not OPS_MANAGER_AVAILABLE: raise HTTPException(status_code=503, detail="Operations manager not available") @@ -15382,7 +14623,7 @@ async def generate_cost_optimization_suggestions_endpoint( "difficulty": suggestion.difficulty, "risk_level": suggestion.risk_level, "is_applied": suggestion.is_applied, - "created_at": suggestion.created_at + "created_at": suggestion.created_at, } for suggestion in suggestions ] @@ -15390,9 +14631,7 @@ async def generate_cost_optimization_suggestions_endpoint( @app.get("/api/v1/ops/cost-optimization-suggestions", tags=["Operations & Monitoring"]) async def list_cost_optimization_suggestions_endpoint( - tenant_id: str, - is_applied: bool | None = None, - _=Depends(verify_api_key) + tenant_id: str, is_applied: bool | None = None, _=Depends(verify_api_key) ): """获取成本优化建议列表""" if not OPS_MANAGER_AVAILABLE: @@ -15412,17 +14651,14 @@ async def list_cost_optimization_suggestions_endpoint( "difficulty": suggestion.difficulty, "risk_level": suggestion.risk_level, "is_applied": suggestion.is_applied, - "created_at": suggestion.created_at + "created_at": suggestion.created_at, } for suggestion in suggestions ] @app.post("/api/v1/ops/cost-optimization-suggestions/{suggestion_id}/apply", tags=["Operations & Monitoring"]) -async def apply_cost_optimization_suggestion_endpoint( - suggestion_id: str, - _=Depends(verify_api_key) -): +async def apply_cost_optimization_suggestion_endpoint(suggestion_id: str, _=Depends(verify_api_key)): """应用成本优化建议""" if not OPS_MANAGER_AVAILABLE: raise HTTPException(status_code=503, detail="Operations manager not available") @@ -15440,11 +14676,10 @@ async def apply_cost_optimization_suggestion_endpoint( "id": suggestion.id, "title": suggestion.title, "is_applied": suggestion.is_applied, - "applied_at": suggestion.applied_at - } + "applied_at": suggestion.applied_at, + }, } if __name__ == "__main__": - import uvicorn uvicorn.run(app, host="0.0.0.0", port=8000) diff --git a/backend/multimodal_entity_linker.py b/backend/multimodal_entity_linker.py index 141f779..1ce7f43 100644 --- a/backend/multimodal_entity_linker.py +++ b/backend/multimodal_entity_linker.py @@ -80,7 +80,7 @@ class MultimodalEntityLinker: # 模态类型 MODALITIES = ["audio", "video", "image", "document"] - def __init__(self, similarity_threshold: float = 0.85): + def __init__(self, similarity_threshold: float = 0.85) -> None: """ 初始化多模态实体关联器 diff --git a/backend/multimodal_processor.py b/backend/multimodal_processor.py index 868198e..acd5a03 100644 --- a/backend/multimodal_processor.py +++ b/backend/multimodal_processor.py @@ -94,7 +94,7 @@ class VideoProcessingResult: class MultimodalProcessor: """多模态处理器 - 处理视频文件""" - def __init__(self, temp_dir: str = None, frame_interval: int = 5): + def __init__(self, temp_dir: str = None, frame_interval: int = 5) -> None: """ 初始化多模态处理器 @@ -401,7 +401,7 @@ class MultimodalProcessor: error_message=str(e), ) - def cleanup(self, video_id: str = None): + def cleanup(self, video_id: str = None) -> None: """ 清理临时文件 diff --git a/backend/neo4j_manager.py b/backend/neo4j_manager.py index c39fa99..d0888be 100644 --- a/backend/neo4j_manager.py +++ b/backend/neo4j_manager.py @@ -107,7 +107,7 @@ class Neo4jManager: self._connect() - def _connect(self): + def _connect(self) -> None: """建立 Neo4j 连接""" if not NEO4J_AVAILABLE: return @@ -121,7 +121,7 @@ class Neo4jManager: logger.error(f"Failed to connect to Neo4j: {e}") self._driver = None - def close(self): + def close(self) -> None: """关闭连接""" if self._driver: self._driver.close() @@ -137,7 +137,7 @@ class Neo4jManager: except BaseException: return False - def init_schema(self): + def init_schema(self) -> None: """初始化图数据库 Schema(约束和索引)""" if not self._driver: logger.error("Neo4j not connected") @@ -178,7 +178,7 @@ class Neo4jManager: # ==================== 数据同步 ==================== - def sync_project(self, project_id: str, project_name: str, project_description: str = ""): + def sync_project(self, project_id: str, project_name: str, project_description: str = "") -> None: """同步项目节点到 Neo4j""" if not self._driver: return @@ -196,7 +196,7 @@ class Neo4jManager: description=project_description, ) - def sync_entity(self, entity: GraphEntity): + def sync_entity(self, entity: GraphEntity) -> None: """同步单个实体到 Neo4j""" if not self._driver: return @@ -225,7 +225,7 @@ class Neo4jManager: properties=json.dumps(entity.properties), ) - def sync_entities_batch(self, entities: list[GraphEntity]): + def sync_entities_batch(self, entities: list[GraphEntity]) -> None: """批量同步实体到 Neo4j""" if not self._driver or not entities: return @@ -262,7 +262,7 @@ class Neo4jManager: entities=entities_data, ) - def sync_relation(self, relation: GraphRelation): + def sync_relation(self, relation: GraphRelation) -> None: """同步单个关系到 Neo4j""" if not self._driver: return @@ -286,7 +286,7 @@ class Neo4jManager: properties=json.dumps(relation.properties), ) - def sync_relations_batch(self, relations: list[GraphRelation]): + def sync_relations_batch(self, relations: list[GraphRelation]) -> None: """批量同步关系到 Neo4j""" if not self._driver or not relations: return @@ -318,7 +318,7 @@ class Neo4jManager: relations=relations_data, ) - def delete_entity(self, entity_id: str): + def delete_entity(self, entity_id: str) -> None: """从 Neo4j 删除实体及其关系""" if not self._driver: return @@ -332,7 +332,7 @@ class Neo4jManager: id=entity_id, ) - def delete_project(self, project_id: str): + def delete_project(self, project_id: str) -> None: """从 Neo4j 删除项目及其所有实体和关系""" if not self._driver: return @@ -949,7 +949,7 @@ def get_neo4j_manager() -> Neo4jManager: return _neo4j_manager -def close_neo4j_manager(): +def close_neo4j_manager() -> None: """关闭 Neo4j 连接""" global _neo4j_manager if _neo4j_manager: @@ -958,7 +958,7 @@ def close_neo4j_manager(): # 便捷函数 -def sync_project_to_neo4j(project_id: str, project_name: str, entities: list[dict], relations: list[dict]): +def sync_project_to_neo4j(project_id: str, project_name: str, entities: list[dict], relations: list[dict]) -> None: """ 同步整个项目到 Neo4j diff --git a/backend/ops_manager.py b/backend/ops_manager.py index b444b45..064e7c5 100644 --- a/backend/ops_manager.py +++ b/backend/ops_manager.py @@ -456,13 +456,13 @@ class OpsManager: self._evaluator_thread = None self._register_default_evaluators() - def _get_db(self): + def _get_db(self) -> None: """获取数据库连接""" conn = sqlite3.connect(self.db_path) conn.row_factory = sqlite3.Row return conn - def _register_default_evaluators(self): + def _register_default_evaluators(self) -> None: """注册默认的告警评估器""" self._alert_evaluators[AlertRuleType.THRESHOLD.value] = self._evaluate_threshold_rule self._alert_evaluators[AlertRuleType.ANOMALY.value] = self._evaluate_anomaly_rule @@ -1249,7 +1249,7 @@ class OpsManager: return self.get_alert(alert_id) - def _increment_suppression_count(self, alert_id: str): + def _increment_suppression_count(self, alert_id: str) -> None: """增加告警抑制计数""" with self._get_db() as conn: conn.execute( @@ -1262,7 +1262,7 @@ class OpsManager: ) conn.commit() - def _update_alert_notification_status(self, alert_id: str, channel_id: str, success: bool): + def _update_alert_notification_status(self, alert_id: str, channel_id: str, success: bool) -> None: """更新告警通知状态""" with self._get_db() as conn: row = conn.execute("SELECT notification_sent FROM alerts WHERE id = ?", (alert_id,)).fetchone() @@ -1276,7 +1276,7 @@ class OpsManager: ) conn.commit() - def _update_channel_stats(self, channel_id: str, success: bool): + def _update_channel_stats(self, channel_id: str, success: bool) -> None: """更新渠道统计""" now = datetime.now().isoformat() @@ -1769,9 +1769,7 @@ class OpsManager: return self._row_to_scaling_event(row) return None - def update_scaling_event_status( - self, event_id: str, status: str, error_message: str = None - ) -> ScalingEvent | None: + def update_scaling_event_status(self, event_id: str, status: str, error_message: str = None) -> ScalingEvent | None: """更新扩缩容事件状态""" now = datetime.now().isoformat() @@ -2339,7 +2337,7 @@ class OpsManager: return record - def _complete_backup(self, record_id: str, size_bytes: int, checksum: str = None): + def _complete_backup(self, record_id: str, size_bytes: int, checksum: str = None) -> None: """完成备份""" now = datetime.now().isoformat() checksum = checksum or hashlib.sha256(str(time.time()).encode()).hexdigest()[:16] diff --git a/backend/oss_uploader.py b/backend/oss_uploader.py index af72403..83de463 100644 --- a/backend/oss_uploader.py +++ b/backend/oss_uploader.py @@ -37,7 +37,7 @@ class OSSUploader: url = self.bucket.sign_url("GET", object_name, 3600) return url, object_name - def delete_object(self, object_name: str): + def delete_object(self, object_name: str) -> None: """删除 OSS 对象""" self.bucket.delete_object(object_name) diff --git a/backend/performance_manager.py b/backend/performance_manager.py index 3a177ac..ed617d5 100644 --- a/backend/performance_manager.py +++ b/backend/performance_manager.py @@ -55,7 +55,7 @@ class CacheStats: expired: int = 0 hit_rate: float = 0.0 - def update_hit_rate(self): + def update_hit_rate(self) -> None: """更新命中率""" if self.total_requests > 0: self.hit_rate = round(self.hits / self.total_requests, 4) @@ -194,7 +194,7 @@ class CacheManager: # 初始化缓存统计表 self._init_cache_tables() - def _init_cache_tables(self): + def _init_cache_tables(self) -> None: """初始化缓存统计表""" conn = sqlite3.connect(self.db_path) @@ -234,7 +234,7 @@ class CacheManager: except BaseException: return 1024 # 默认估算 - def _evict_lru(self, required_space: int = 0): + def _evict_lru(self, required_space: int = 0) -> None: """LRU 淘汰策略""" with self.cache_lock: while self.current_memory_size + required_space > self.max_memory_size and self.memory_cache: @@ -444,7 +444,7 @@ class CacheManager: return stats - def save_stats(self): + def save_stats(self) -> None: """保存缓存统计到数据库""" conn = sqlite3.connect(self.db_path) @@ -618,7 +618,7 @@ class DatabaseSharding: # 初始化分片 self._init_shards() - def _init_shards(self): + def _init_shards(self) -> None: """初始化分片""" # 计算每个分片的 key 范围 chars = "0123456789abcdef" @@ -645,7 +645,7 @@ class DatabaseSharding: if not os.path.exists(db_path): self._create_shard_db(db_path) - def _create_shard_db(self, db_path: str): + def _create_shard_db(self, db_path: str) -> None: """创建分片数据库""" conn = sqlite3.connect(db_path) @@ -792,7 +792,7 @@ class DatabaseSharding: print(f"迁移失败: {e}") return False - def _update_shard_stats(self, shard_id: str): + def _update_shard_stats(self, shard_id: str) -> None: """更新分片统计""" shard_info = self.shard_map.get(shard_id) if not shard_info: @@ -923,7 +923,7 @@ class TaskQueue: except Exception as e: print(f"Celery 初始化失败,使用内存任务队列: {e}") - def _init_task_tables(self): + def _init_task_tables(self) -> None: """初始化任务队列表""" conn = sqlite3.connect(self.db_path) @@ -953,7 +953,7 @@ class TaskQueue: """检查任务队列是否可用""" return self.use_celery or True # 内存模式也可用 - def register_handler(self, task_type: str, handler: Callable): + def register_handler(self, task_type: str, handler: Callable) -> None: """注册任务处理器""" self.task_handlers[task_type] = handler @@ -1014,7 +1014,7 @@ class TaskQueue: return task_id - def _execute_task(self, task_id: str): + def _execute_task(self, task_id: str) -> None: """执行任务(内存模式)""" with self.task_lock: task = self.tasks.get(task_id) @@ -1055,7 +1055,7 @@ class TaskQueue: self._update_task_status(task) - def _save_task(self, task: TaskInfo): + def _save_task(self, task: TaskInfo) -> None: """保存任务到数据库""" conn = sqlite3.connect(self.db_path) @@ -1084,7 +1084,7 @@ class TaskQueue: conn.commit() conn.close() - def _update_task_status(self, task: TaskInfo): + def _update_task_status(self, task: TaskInfo) -> None: """更新任务状态""" conn = sqlite3.connect(self.db_path) @@ -1143,9 +1143,7 @@ class TaskQueue: with self.task_lock: return self.tasks.get(task_id) - def list_tasks( - self, status: str | None = None, task_type: str | None = None, limit: int = 100 - ) -> list[TaskInfo]: + def list_tasks(self, status: str | None = None, task_type: str | None = None, limit: int = 100) -> list[TaskInfo]: """列出任务""" conn = sqlite3.connect(self.db_path) conn.row_factory = sqlite3.Row @@ -1333,7 +1331,7 @@ class PerformanceMonitor: if metric_type == "db_query" and duration_ms > self.slow_query_threshold: self._record_slow_query(metric) - def _flush_metrics(self): + def _flush_metrics(self) -> None: """将缓冲区指标写入数据库""" if not self.metrics_buffer: return @@ -1362,12 +1360,12 @@ class PerformanceMonitor: self.metrics_buffer = [] - def _record_slow_query(self, metric: PerformanceMetric): + def _record_slow_query(self, metric: PerformanceMetric) -> None: """记录慢查询""" # 可以发送到专门的慢查询日志或监控系统 print(f"[SLOW QUERY] {metric.endpoint}: {metric.duration_ms}ms") - def _trigger_alert(self, metric: PerformanceMetric): + def _trigger_alert(self, metric: PerformanceMetric) -> None: """触发告警""" alert_data = { "type": "performance_alert", @@ -1382,7 +1380,7 @@ class PerformanceMonitor: except Exception as e: print(f"告警处理失败: {e}") - def register_alert_handler(self, handler: Callable): + def register_alert_handler(self, handler: Callable) -> None: """注册告警处理器""" self.alert_handlers.append(handler) @@ -1585,7 +1583,9 @@ class PerformanceMonitor: # ==================== 性能装饰器 ==================== -def cached(cache_manager: CacheManager, key_prefix: str = "", ttl: int = 3600, key_func: Callable | None = None): +def cached( + cache_manager: CacheManager, key_prefix: str = "", ttl: int = 3600, key_func: Callable | None = None +) -> None: """ 缓存装饰器 @@ -1598,7 +1598,7 @@ def cached(cache_manager: CacheManager, key_prefix: str = "", ttl: int = 3600, k def decorator(func: Callable) -> Callable: @wraps(func) - def wrapper(*args, **kwargs): + def wrapper(*args, **kwargs) -> None: # 生成缓存键 if key_func: cache_key = key_func(*args, **kwargs) @@ -1625,7 +1625,7 @@ def cached(cache_manager: CacheManager, key_prefix: str = "", ttl: int = 3600, k return decorator -def monitored(monitor: PerformanceMonitor, metric_type: str, endpoint: str | None = None): +def monitored(monitor: PerformanceMonitor, metric_type: str, endpoint: str | None = None) -> None: """ 性能监控装饰器 diff --git a/backend/plugin_manager.py b/backend/plugin_manager.py index d330716..a58637f 100644 --- a/backend/plugin_manager.py +++ b/backend/plugin_manager.py @@ -163,7 +163,7 @@ class PluginManager: self._handlers = {} self._register_default_handlers() - def _register_default_handlers(self): + def _register_default_handlers(self) -> None: """注册默认处理器""" self._handlers[PluginType.CHROME_EXTENSION] = ChromeExtensionHandler(self) self._handlers[PluginType.FEISHU_BOT] = BotHandler(self, "feishu") @@ -371,7 +371,7 @@ class PluginManager: return cursor.rowcount > 0 - def record_plugin_usage(self, plugin_id: str): + def record_plugin_usage(self, plugin_id: str) -> None: """记录插件使用""" conn = self.db.get_conn() now = datetime.now().isoformat() @@ -826,9 +826,6 @@ class BotHandler: async def _send_feishu_message(self, session: BotSession, message: str, msg_type: str) -> bool: """发送飞书消息""" - import base64 - import hashlib - timestamp = str(int(time.time())) # 生成签名 @@ -851,9 +848,6 @@ class BotHandler: async def _send_dingtalk_message(self, session: BotSession, message: str, msg_type: str) -> bool: """发送钉钉消息""" - import base64 - import hashlib - timestamp = str(round(time.time() * 1000)) # 生成签名 @@ -1358,7 +1352,7 @@ class WebDAVSyncManager: _plugin_manager = None -def get_plugin_manager(db_manager=None): +def get_plugin_manager(db_manager=None) -> None: """获取 PluginManager 单例""" global _plugin_manager if _plugin_manager is None: diff --git a/backend/rate_limiter.py b/backend/rate_limiter.py index 5e01b0b..db9142d 100644 --- a/backend/rate_limiter.py +++ b/backend/rate_limiter.py @@ -56,7 +56,7 @@ class SlidingWindowCounter: self._cleanup_old(now) return sum(self.requests.values()) - def _cleanup_old(self, now: int): + def _cleanup_old(self, now: int) -> None: """清理过期的请求记录 - 使用独立锁避免竞态条件""" cutoff = now - self.window_size old_keys = [k for k in list(self.requests.keys()) if k < cutoff] @@ -67,7 +67,7 @@ class SlidingWindowCounter: class RateLimiter: """API 限流器""" - def __init__(self): + def __init__(self) -> None: # key -> SlidingWindowCounter self.counters: dict[str, SlidingWindowCounter] = {} # key -> RateLimitConfig @@ -143,7 +143,7 @@ class RateLimiter: retry_after=max(0, config.window_size) if current_count >= config.requests_per_minute else 0, ) - def reset(self, key: str | None = None): + def reset(self, key: str | None = None) -> None: """重置限流计数器""" if key: self.counters.pop(key, None) @@ -166,7 +166,7 @@ def get_rate_limiter() -> RateLimiter: # 限流装饰器(用于函数级别限流) -def rate_limit(requests_per_minute: int = 60, key_func: Callable | None = None): +def rate_limit(requests_per_minute: int = 60, key_func: Callable | None = None) -> None: """ 限流装饰器 diff --git a/backend/search_manager.py b/backend/search_manager.py index ce25840..7c0a8e0 100644 --- a/backend/search_manager.py +++ b/backend/search_manager.py @@ -22,6 +22,7 @@ from enum import Enum class SearchOperator(Enum): """搜索操作符""" + AND = "AND" OR = "OR" NOT = "NOT" @@ -31,15 +32,18 @@ class SearchOperator(Enum): try: from sentence_transformers import SentenceTransformer from sklearn.metrics.pairwise import cosine_similarity + SENTENCE_TRANSFORMERS_AVAILABLE = True except ImportError: SENTENCE_TRANSFORMERS_AVAILABLE = False # ==================== 数据模型 ==================== + @dataclass class SearchResult: """搜索结果数据模型""" + id: str content: str content_type: str # transcript, entity, relation @@ -56,13 +60,14 @@ class SearchResult: "project_id": self.project_id, "score": self.score, "highlights": self.highlights, - "metadata": self.metadata + "metadata": self.metadata, } @dataclass class SemanticSearchResult: """语义搜索结果数据模型""" + id: str content: str content_type: str @@ -78,7 +83,7 @@ class SemanticSearchResult: "content_type": self.content_type, "project_id": self.project_id, "similarity": round(self.similarity, 4), - "metadata": self.metadata + "metadata": self.metadata, } if self.embedding: result["embedding_dim"] = len(self.embedding) @@ -88,6 +93,7 @@ class SemanticSearchResult: @dataclass class EntityPath: """实体关系路径数据模型""" + path_id: str source_entity_id: str source_entity_name: str @@ -110,13 +116,14 @@ class EntityPath: "nodes": self.nodes, "edges": self.edges, "confidence": self.confidence, - "path_description": self.path_description + "path_description": self.path_description, } @dataclass class KnowledgeGap: """知识缺口数据模型""" + gap_id: str gap_type: str # missing_attribute, sparse_relation, isolated_entity, incomplete_entity entity_id: str | None @@ -137,13 +144,14 @@ class KnowledgeGap: "severity": self.severity, "suggestions": self.suggestions, "related_entities": self.related_entities, - "metadata": self.metadata + "metadata": self.metadata, } @dataclass class SearchIndex: """搜索索引数据模型""" + id: str content_id: str content_type: str @@ -157,6 +165,7 @@ class SearchIndex: @dataclass class TextEmbedding: """文本 Embedding 数据模型""" + id: str content_id: str content_type: str @@ -168,6 +177,7 @@ class TextEmbedding: # ==================== 全文搜索 ==================== + class FullTextSearch: """ 全文搜索模块 @@ -189,7 +199,7 @@ class FullTextSearch: conn.row_factory = sqlite3.Row return conn - def _init_search_tables(self): + def _init_search_tables(self) -> None: """初始化搜索相关表""" conn = self._get_conn() @@ -239,7 +249,7 @@ class FullTextSearch: # 清理文本 text = text.lower() # 提取中文字符、英文单词和数字 - tokens = re.findall(r'[\u4e00-\u9fa5]+|[a-z]+|\d+', text) + tokens = re.findall(r"[\u4e00-\u9fa5]+|[a-z]+|\d+", text) return tokens def _extract_positions(self, text: str, tokens: list[str]) -> dict[str, list[int]]: @@ -259,8 +269,7 @@ class FullTextSearch: return dict(positions) - def index_content(self, content_id: str, content_type: str, - project_id: str, text: str) -> bool: + def index_content(self, content_id: str, content_type: str, project_id: str, text: str) -> bool: """ 为内容创建搜索索引 @@ -294,28 +303,35 @@ class FullTextSearch: now = datetime.now().isoformat() # 保存索引 - conn.execute(""" + conn.execute( + """ INSERT OR REPLACE INTO search_indexes (id, content_id, content_type, project_id, tokens, token_positions, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?) - """, ( - index_id, content_id, content_type, project_id, - json.dumps(tokens, ensure_ascii=False), - json.dumps(token_positions, ensure_ascii=False), - now, now - )) + """, + ( + index_id, + content_id, + content_type, + project_id, + json.dumps(tokens, ensure_ascii=False), + json.dumps(token_positions, ensure_ascii=False), + now, + now, + ), + ) # 保存词频统计 for token, freq in token_freq.items(): positions = token_positions.get(token, []) - conn.execute(""" + conn.execute( + """ INSERT OR REPLACE INTO search_term_freq (term, content_id, content_type, project_id, frequency, positions) VALUES (?, ?, ?, ?, ?, ?) - """, ( - token, content_id, content_type, project_id, freq, - json.dumps(positions, ensure_ascii=False) - )) + """, + (token, content_id, content_type, project_id, freq, json.dumps(positions, ensure_ascii=False)), + ) conn.commit() conn.close() @@ -325,9 +341,14 @@ class FullTextSearch: print(f"索引创建失败: {e}") return False - def search(self, query: str, project_id: str | None = None, - content_types: list[str] | None = None, - limit: int = 20, offset: int = 0) -> list[SearchResult]: + def search( + self, + query: str, + project_id: str | None = None, + content_types: list[str] | None = None, + limit: int = 20, + offset: int = 0, + ) -> list[SearchResult]: """ 全文搜索 @@ -345,9 +366,7 @@ class FullTextSearch: parsed_query = self._parse_boolean_query(query) # 执行搜索 - results = self._execute_boolean_search( - parsed_query, project_id, content_types - ) + results = self._execute_boolean_search(parsed_query, project_id, content_types) # 计算相关性分数 scored_results = self._score_results(results, parsed_query) @@ -355,7 +374,7 @@ class FullTextSearch: # 排序和分页 scored_results.sort(key=lambda x: x.score, reverse=True) - return scored_results[offset:offset + limit] + return scored_results[offset : offset + limit] def _parse_boolean_query(self, query: str) -> dict: """ @@ -371,7 +390,7 @@ class FullTextSearch: # 提取短语(引号内的内容) phrases = re.findall(r'"([^"]+)"', query) - query_without_phrases = re.sub(r'"[^"]+"', '', query) + query_without_phrases = re.sub(r'"[^"]+"', "", query) # 解析布尔操作 and_terms = [] @@ -379,13 +398,13 @@ class FullTextSearch: not_terms = [] # 处理 NOT - not_pattern = r'(?:NOT\s+|\-)(\w+)' + not_pattern = r"(?:NOT\s+|\-)(\w+)" not_matches = re.findall(not_pattern, query_without_phrases, re.IGNORECASE) not_terms.extend(not_matches) - query_without_phrases = re.sub(not_pattern, '', query_without_phrases, flags=re.IGNORECASE) + query_without_phrases = re.sub(not_pattern, "", query_without_phrases, flags=re.IGNORECASE) # 处理 OR - or_parts = re.split(r'\s+OR\s+', query_without_phrases, flags=re.IGNORECASE) + or_parts = re.split(r"\s+OR\s+", query_without_phrases, flags=re.IGNORECASE) if len(or_parts) > 1: or_terms = [p.strip() for p in or_parts[1:] if p.strip()] query_without_phrases = or_parts[0] @@ -393,16 +412,11 @@ class FullTextSearch: # 剩余的作为 AND 条件 and_terms = [t.strip() for t in query_without_phrases.split() if t.strip()] - return { - "and": and_terms + phrases, - "or": or_terms, - "not": not_terms, - "phrases": phrases - } + return {"and": and_terms + phrases, "or": or_terms, "not": not_terms, "phrases": phrases} - def _execute_boolean_search(self, parsed_query: dict, - project_id: str | None = None, - content_types: list[str] | None = None) -> list[dict]: + def _execute_boolean_search( + self, parsed_query: dict, project_id: str | None = None, content_types: list[str] | None = None + ) -> list[dict]: """执行布尔搜索""" conn = self._get_conn() @@ -415,7 +429,7 @@ class FullTextSearch: params.append(project_id) if content_types: - placeholders = ','.join(['?' for _ in content_types]) + placeholders = ",".join(["?" for _ in content_types]) base_where.append(f"content_type IN ({placeholders})") params.extend(content_types) @@ -427,13 +441,16 @@ class FullTextSearch: # 处理 AND 条件 if parsed_query["and"]: for term in parsed_query["and"]: - term_results = conn.execute(f""" + term_results = conn.execute( + f""" SELECT content_id, content_type, project_id, frequency, positions FROM search_term_freq WHERE term = ? AND {base_where_str} - """, [term] + params).fetchall() + """, + [term] + params, + ).fetchall() - term_contents = {(r['content_id'], r['content_type']) for r in term_results} + term_contents = {(r["content_id"], r["content_type"]) for r in term_results} if not candidates: candidates = term_contents @@ -443,13 +460,16 @@ class FullTextSearch: # 处理 OR 条件 if parsed_query["or"]: for term in parsed_query["or"]: - term_results = conn.execute(f""" + term_results = conn.execute( + f""" SELECT content_id, content_type, project_id, frequency, positions FROM search_term_freq WHERE term = ? AND {base_where_str} - """, [term] + params).fetchall() + """, + [term] + params, + ).fetchall() - term_contents = {(r['content_id'], r['content_type']) for r in term_results} + term_contents = {(r["content_id"], r["content_type"]) for r in term_results} candidates |= term_contents # 并集 # 如果没有 AND 和 OR,但有 phrases,使用 phrases @@ -459,13 +479,16 @@ class FullTextSearch: if phrase_tokens: # 查找包含所有短语的文档 for token in phrase_tokens: - term_results = conn.execute(f""" + term_results = conn.execute( + f""" SELECT content_id, content_type, project_id, frequency, positions FROM search_term_freq WHERE term = ? AND {base_where_str} - """, [token] + params).fetchall() + """, + [token] + params, + ).fetchall() - term_contents = {(r['content_id'], r['content_type']) for r in term_results} + term_contents = {(r["content_id"], r["content_type"]) for r in term_results} if not candidates: candidates = term_contents @@ -475,13 +498,16 @@ class FullTextSearch: # 处理 NOT 条件(排除) if parsed_query["not"]: for term in parsed_query["not"]: - term_results = conn.execute(f""" + term_results = conn.execute( + f""" SELECT content_id, content_type FROM search_term_freq WHERE term = ? AND {base_where_str} - """, [term] + params).fetchall() + """, + [term] + params, + ).fetchall() - term_contents = {(r['content_id'], r['content_type']) for r in term_results} + term_contents = {(r["content_id"], r["content_type"]) for r in term_results} candidates -= term_contents # 差集 # 获取完整内容 @@ -490,33 +516,28 @@ class FullTextSearch: # 获取原始内容 content = self._get_content_by_id(conn, content_id, content_type) if content: - results.append({ - "id": content_id, - "content_type": content_type, - "project_id": project_id or self._get_project_id(conn, content_id, content_type), - "content": content, - "terms": parsed_query["and"] + parsed_query["or"] + parsed_query["phrases"] - }) + results.append( + { + "id": content_id, + "content_type": content_type, + "project_id": project_id or self._get_project_id(conn, content_id, content_type), + "content": content, + "terms": parsed_query["and"] + parsed_query["or"] + parsed_query["phrases"], + } + ) conn.close() return results - def _get_content_by_id(self, conn: sqlite3.Connection, - content_id: str, content_type: str) -> str | None: + def _get_content_by_id(self, conn: sqlite3.Connection, content_id: str, content_type: str) -> str | None: """根据ID获取内容""" try: if content_type == "transcript": - row = conn.execute( - "SELECT full_text FROM transcripts WHERE id = ?", - (content_id,) - ).fetchone() - return row['full_text'] if row else None + row = conn.execute("SELECT full_text FROM transcripts WHERE id = ?", (content_id,)).fetchone() + return row["full_text"] if row else None elif content_type == "entity": - row = conn.execute( - "SELECT name, definition FROM entities WHERE id = ?", - (content_id,) - ).fetchone() + row = conn.execute("SELECT name, definition FROM entities WHERE id = ?", (content_id,)).fetchone() if row: return f"{row['name']} {row['definition'] or ''}" return None @@ -529,7 +550,7 @@ class FullTextSearch: JOIN entities e1 ON r.source_entity_id = e1.id JOIN entities e2 ON r.target_entity_id = e2.id WHERE r.id = ?""", - (content_id,) + (content_id,), ).fetchone() if row: return f"{row['source_name']} {row['relation_type']} {row['target_name']} {row['evidence'] or ''}" @@ -540,29 +561,19 @@ class FullTextSearch: print(f"获取内容失败: {e}") return None - def _get_project_id(self, conn: sqlite3.Connection, - content_id: str, content_type: str) -> str | None: + def _get_project_id(self, conn: sqlite3.Connection, content_id: str, content_type: str) -> str | None: """获取内容所属的项目ID""" try: if content_type == "transcript": - row = conn.execute( - "SELECT project_id FROM transcripts WHERE id = ?", - (content_id,) - ).fetchone() + row = conn.execute("SELECT project_id FROM transcripts WHERE id = ?", (content_id,)).fetchone() elif content_type == "entity": - row = conn.execute( - "SELECT project_id FROM entities WHERE id = ?", - (content_id,) - ).fetchone() + row = conn.execute("SELECT project_id FROM entities WHERE id = ?", (content_id,)).fetchone() elif content_type == "relation": - row = conn.execute( - "SELECT project_id FROM entity_relations WHERE id = ?", - (content_id,) - ).fetchone() + row = conn.execute("SELECT project_id FROM entity_relations WHERE id = ?", (content_id,)).fetchone() else: return None - return row['project_id'] if row else None + return row["project_id"] if row else None except Exception: return None @@ -612,20 +623,21 @@ class FullTextSearch: # 归一化分数 score = min(score / max(len(all_terms), 1), 10.0) - scored.append(SearchResult( - id=result["id"], - content=result["content"], - content_type=result["content_type"], - project_id=result["project_id"], - score=round(score, 4), - highlights=highlights[:10], # 限制高亮数量 - metadata={} - )) + scored.append( + SearchResult( + id=result["id"], + content=result["content"], + content_type=result["content_type"], + project_id=result["project_id"], + score=round(score, 4), + highlights=highlights[:10], # 限制高亮数量 + metadata={}, + ) + ) return scored - def highlight_text(self, text: str, query: str, - max_length: int = 300) -> str: + def highlight_text(self, text: str, query: str, max_length: int = 300) -> str: """ 高亮文本中的关键词 @@ -671,14 +683,12 @@ class FullTextSearch: # 删除索引 conn.execute( - "DELETE FROM search_indexes WHERE content_id = ? AND content_type = ?", - (content_id, content_type) + "DELETE FROM search_indexes WHERE content_id = ? AND content_type = ?", (content_id, content_type) ) # 删除词频统计 conn.execute( - "DELETE FROM search_term_freq WHERE content_id = ? AND content_type = ?", - (content_id, content_type) + "DELETE FROM search_term_freq WHERE content_id = ? AND content_type = ?", (content_id, content_type) ) conn.commit() @@ -696,26 +706,24 @@ class FullTextSearch: try: # 索引转录文本 transcripts = conn.execute( - "SELECT id, project_id, full_text FROM transcripts WHERE project_id = ?", - (project_id,) + "SELECT id, project_id, full_text FROM transcripts WHERE project_id = ?", (project_id,) ).fetchall() for t in transcripts: - if t['full_text']: - if self.index_content(t['id'], 'transcript', t['project_id'], t['full_text']): + if t["full_text"]: + if self.index_content(t["id"], "transcript", t["project_id"], t["full_text"]): stats["transcripts"] += 1 else: stats["errors"] += 1 # 索引实体 entities = conn.execute( - "SELECT id, project_id, name, definition FROM entities WHERE project_id = ?", - (project_id,) + "SELECT id, project_id, name, definition FROM entities WHERE project_id = ?", (project_id,) ).fetchall() for e in entities: text = f"{e['name']} {e['definition'] or ''}" - if self.index_content(e['id'], 'entity', e['project_id'], text): + if self.index_content(e["id"], "entity", e["project_id"], text): stats["entities"] += 1 else: stats["errors"] += 1 @@ -728,12 +736,12 @@ class FullTextSearch: JOIN entities e1 ON r.source_entity_id = e1.id JOIN entities e2 ON r.target_entity_id = e2.id WHERE r.project_id = ?""", - (project_id,) + (project_id,), ).fetchall() for r in relations: text = f"{r['source_name']} {r['relation_type']} {r['target_name']} {r['evidence'] or ''}" - if self.index_content(r['id'], 'relation', r['project_id'], text): + if self.index_content(r["id"], "relation", r["project_id"], text): stats["relations"] += 1 else: stats["errors"] += 1 @@ -748,6 +756,7 @@ class FullTextSearch: # ==================== 语义搜索 ==================== + class SemanticSearch: """ 语义搜索模块 @@ -759,8 +768,7 @@ class SemanticSearch: - 语义相似内容推荐 """ - def __init__(self, db_path: str = "insightflow.db", - model_name: str = "paraphrase-multilingual-MiniLM-L12-v2"): + def __init__(self, db_path: str = "insightflow.db", model_name: str = "paraphrase-multilingual-MiniLM-L12-v2"): self.db_path = db_path self.model_name = model_name self.model = None @@ -780,7 +788,7 @@ class SemanticSearch: conn.row_factory = sqlite3.Row return conn - def _init_embedding_tables(self): + def _init_embedding_tables(self) -> None: """初始化 embedding 相关表""" conn = self._get_conn() @@ -832,8 +840,7 @@ class SemanticSearch: print(f"生成 embedding 失败: {e}") return None - def index_embedding(self, content_id: str, content_type: str, - project_id: str, text: str) -> bool: + def index_embedding(self, content_id: str, content_type: str, project_id: str, text: str) -> bool: """ 为内容生成并保存 embedding @@ -858,16 +865,22 @@ class SemanticSearch: embedding_id = hashlib.md5(f"{content_id}:{content_type}".encode()).hexdigest()[:16] - conn.execute(""" + conn.execute( + """ INSERT OR REPLACE INTO embeddings (id, content_id, content_type, project_id, embedding, model_name, created_at) VALUES (?, ?, ?, ?, ?, ?, ?) - """, ( - embedding_id, content_id, content_type, project_id, - json.dumps(embedding), - self.model_name, - datetime.now().isoformat() - )) + """, + ( + embedding_id, + content_id, + content_type, + project_id, + json.dumps(embedding), + self.model_name, + datetime.now().isoformat(), + ), + ) conn.commit() conn.close() @@ -877,9 +890,14 @@ class SemanticSearch: print(f"索引 embedding 失败: {e}") return False - def search(self, query: str, project_id: str | None = None, - content_types: list[str] | None = None, - top_k: int = 10, threshold: float = 0.5) -> list[SemanticSearchResult]: + def search( + self, + query: str, + project_id: str | None = None, + content_types: list[str] | None = None, + top_k: int = 10, + threshold: float = 0.5, + ) -> list[SemanticSearchResult]: """ 语义搜索 @@ -912,17 +930,20 @@ class SemanticSearch: params.append(project_id) if content_types: - placeholders = ','.join(['?' for _ in content_types]) + placeholders = ",".join(["?" for _ in content_types]) where_clauses.append(f"content_type IN ({placeholders})") params.extend(content_types) where_str = " AND ".join(where_clauses) if where_clauses else "1=1" - rows = conn.execute(f""" + rows = conn.execute( + f""" SELECT content_id, content_type, project_id, embedding FROM embeddings WHERE {where_str} - """, params).fetchall() + """, + params, + ).fetchall() conn.close() @@ -932,24 +953,26 @@ class SemanticSearch: for row in rows: try: - content_embedding = json.loads(row['embedding']) + content_embedding = json.loads(row["embedding"]) # 计算余弦相似度 similarity = cosine_similarity(query_vec, [content_embedding])[0][0] if similarity >= threshold: # 获取原始内容 - content = self._get_content_text(row['content_id'], row['content_type']) + content = self._get_content_text(row["content_id"], row["content_type"]) - results.append(SemanticSearchResult( - id=row['content_id'], - content=content or "", - content_type=row['content_type'], - project_id=row['project_id'], - similarity=float(similarity), - embedding=None, # 不返回 embedding 以节省带宽 - metadata={} - )) + results.append( + SemanticSearchResult( + id=row["content_id"], + content=content or "", + content_type=row["content_type"], + project_id=row["project_id"], + similarity=float(similarity), + embedding=None, # 不返回 embedding 以节省带宽 + metadata={}, + ) + ) except Exception as e: print(f"计算相似度失败: {e}") continue @@ -964,17 +987,11 @@ class SemanticSearch: try: if content_type == "transcript": - row = conn.execute( - "SELECT full_text FROM transcripts WHERE id = ?", - (content_id,) - ).fetchone() - result = row['full_text'] if row else None + row = conn.execute("SELECT full_text FROM transcripts WHERE id = ?", (content_id,)).fetchone() + result = row["full_text"] if row else None elif content_type == "entity": - row = conn.execute( - "SELECT name, definition FROM entities WHERE id = ?", - (content_id,) - ).fetchone() + row = conn.execute("SELECT name, definition FROM entities WHERE id = ?", (content_id,)).fetchone() result = f"{row['name']}: {row['definition']}" if row else None elif content_type == "relation": @@ -985,7 +1002,7 @@ class SemanticSearch: JOIN entities e1 ON r.source_entity_id = e1.id JOIN entities e2 ON r.target_entity_id = e2.id WHERE r.id = ?""", - (content_id,) + (content_id,), ).fetchone() result = f"{row['source_name']} {row['relation_type']} {row['target_name']}" if row else None @@ -1000,8 +1017,7 @@ class SemanticSearch: print(f"获取内容失败: {e}") return None - def find_similar_content(self, content_id: str, content_type: str, - top_k: int = 5) -> list[SemanticSearchResult]: + def find_similar_content(self, content_id: str, content_type: str, top_k: int = 5) -> list[SemanticSearchResult]: """ 查找与指定内容相似的内容 @@ -1021,22 +1037,22 @@ class SemanticSearch: row = conn.execute( "SELECT embedding, project_id FROM embeddings WHERE content_id = ? AND content_type = ?", - (content_id, content_type) + (content_id, content_type), ).fetchone() if not row: conn.close() return [] - source_embedding = json.loads(row['embedding']) - project_id = row['project_id'] + source_embedding = json.loads(row["embedding"]) + project_id = row["project_id"] # 获取其他内容的 embedding rows = conn.execute( """SELECT content_id, content_type, project_id, embedding FROM embeddings WHERE project_id = ? AND (content_id != ? OR content_type != ?)""", - (project_id, content_id, content_type) + (project_id, content_id, content_type), ).fetchall() conn.close() @@ -1047,19 +1063,21 @@ class SemanticSearch: for row in rows: try: - content_embedding = json.loads(row['embedding']) + content_embedding = json.loads(row["embedding"]) similarity = cosine_similarity(source_vec, [content_embedding])[0][0] - content = self._get_content_text(row['content_id'], row['content_type']) + content = self._get_content_text(row["content_id"], row["content_type"]) - results.append(SemanticSearchResult( - id=row['content_id'], - content=content or "", - content_type=row['content_type'], - project_id=row['project_id'], - similarity=float(similarity), - metadata={} - )) + results.append( + SemanticSearchResult( + id=row["content_id"], + content=content or "", + content_type=row["content_type"], + project_id=row["project_id"], + similarity=float(similarity), + metadata={}, + ) + ) except Exception: continue @@ -1070,10 +1088,7 @@ class SemanticSearch: """删除内容的 embedding""" try: conn = self._get_conn() - conn.execute( - "DELETE FROM embeddings WHERE content_id = ? AND content_type = ?", - (content_id, content_type) - ) + conn.execute("DELETE FROM embeddings WHERE content_id = ? AND content_type = ?", (content_id, content_type)) conn.commit() conn.close() return True @@ -1084,6 +1099,7 @@ class SemanticSearch: # ==================== 实体关系路径发现 ==================== + class EntityPathDiscovery: """ 实体关系路径发现模块 @@ -1104,9 +1120,7 @@ class EntityPathDiscovery: conn.row_factory = sqlite3.Row return conn - def find_shortest_path(self, source_entity_id: str, - target_entity_id: str, - max_depth: int = 5) -> EntityPath | None: + def find_shortest_path(self, source_entity_id: str, target_entity_id: str, max_depth: int = 5) -> EntityPath | None: """ 查找两个实体之间的最短路径(BFS算法) @@ -1121,21 +1135,17 @@ class EntityPathDiscovery: conn = self._get_conn() # 获取项目ID - row = conn.execute( - "SELECT project_id FROM entities WHERE id = ?", - (source_entity_id,) - ).fetchone() + row = conn.execute("SELECT project_id FROM entities WHERE id = ?", (source_entity_id,)).fetchone() if not row: conn.close() return None - project_id = row['project_id'] + project_id = row["project_id"] # 验证目标实体也在同一项目 row = conn.execute( - "SELECT 1 FROM entities WHERE id = ? AND project_id = ?", - (target_entity_id, project_id) + "SELECT 1 FROM entities WHERE id = ? AND project_id = ?", (target_entity_id, project_id) ).fetchone() if not row: @@ -1158,7 +1168,8 @@ class EntityPathDiscovery: return self._build_path_object(path, project_id) # 获取邻居 - neighbors = conn.execute(""" + neighbors = conn.execute( + """ SELECT target_entity_id as neighbor_id, relation_type, evidence FROM entity_relations WHERE source_entity_id = ? AND project_id = ? @@ -1166,10 +1177,12 @@ class EntityPathDiscovery: SELECT source_entity_id as neighbor_id, relation_type, evidence FROM entity_relations WHERE target_entity_id = ? AND project_id = ? - """, (current_id, project_id, current_id, project_id)).fetchall() + """, + (current_id, project_id, current_id, project_id), + ).fetchall() for neighbor in neighbors: - neighbor_id = neighbor['neighbor_id'] + neighbor_id = neighbor["neighbor_id"] if neighbor_id not in visited: visited.add(neighbor_id) queue.append((neighbor_id, path + [neighbor_id])) @@ -1177,10 +1190,9 @@ class EntityPathDiscovery: conn.close() return None - def find_all_paths(self, source_entity_id: str, - target_entity_id: str, - max_depth: int = 4, - max_paths: int = 10) -> list[EntityPath]: + def find_all_paths( + self, source_entity_id: str, target_entity_id: str, max_depth: int = 4, max_paths: int = 10 + ) -> list[EntityPath]: """ 查找两个实体之间的所有路径(限制数量和深度) @@ -1196,21 +1208,17 @@ class EntityPathDiscovery: conn = self._get_conn() # 获取项目ID - row = conn.execute( - "SELECT project_id FROM entities WHERE id = ?", - (source_entity_id,) - ).fetchone() + row = conn.execute("SELECT project_id FROM entities WHERE id = ?", (source_entity_id,)).fetchone() if not row: conn.close() return [] - project_id = row['project_id'] + project_id = row["project_id"] paths = [] - def dfs(current_id: str, target_id: str, - path: list[str], visited: set[str], depth: int): + def dfs(current_id: str, target_id: str, path: list[str], visited: set[str], depth: int): if depth > max_depth: return @@ -1219,7 +1227,8 @@ class EntityPathDiscovery: return # 获取邻居 - neighbors = conn.execute(""" + neighbors = conn.execute( + """ SELECT target_entity_id as neighbor_id FROM entity_relations WHERE source_entity_id = ? AND project_id = ? @@ -1227,10 +1236,12 @@ class EntityPathDiscovery: SELECT source_entity_id as neighbor_id FROM entity_relations WHERE target_entity_id = ? AND project_id = ? - """, (current_id, project_id, current_id, project_id)).fetchall() + """, + (current_id, project_id, current_id, project_id), + ).fetchall() for neighbor in neighbors: - neighbor_id = neighbor['neighbor_id'] + neighbor_id = neighbor["neighbor_id"] if neighbor_id not in visited and len(paths) < max_paths: visited.add(neighbor_id) path.append(neighbor_id) @@ -1246,24 +1257,16 @@ class EntityPathDiscovery: # 构建路径对象 return [self._build_path_object(path, project_id) for path in paths] - def _build_path_object(self, entity_ids: list[str], - project_id: str) -> EntityPath: + def _build_path_object(self, entity_ids: list[str], project_id: str) -> EntityPath: """构建路径对象""" conn = self._get_conn() # 获取实体信息 nodes = [] for entity_id in entity_ids: - row = conn.execute( - "SELECT id, name, type FROM entities WHERE id = ?", - (entity_id,) - ).fetchone() + row = conn.execute("SELECT id, name, type FROM entities WHERE id = ?", (entity_id,)).fetchone() if row: - nodes.append({ - "id": row['id'], - "name": row['name'], - "type": row['type'] - }) + nodes.append({"id": row["id"], "name": row["name"], "type": row["type"]}) # 获取边信息 edges = [] @@ -1271,27 +1274,32 @@ class EntityPathDiscovery: source_id = entity_ids[i] target_id = entity_ids[i + 1] - row = conn.execute(""" + row = conn.execute( + """ SELECT id, relation_type, evidence FROM entity_relations WHERE ((source_entity_id = ? AND target_entity_id = ?) OR (source_entity_id = ? AND target_entity_id = ?)) AND project_id = ? - """, (source_id, target_id, target_id, source_id, project_id)).fetchone() + """, + (source_id, target_id, target_id, source_id, project_id), + ).fetchone() if row: - edges.append({ - "id": row['id'], - "source": source_id, - "target": target_id, - "relation_type": row['relation_type'], - "evidence": row['evidence'] - }) + edges.append( + { + "id": row["id"], + "source": source_id, + "target": target_id, + "relation_type": row["relation_type"], + "evidence": row["evidence"], + } + ) conn.close() # 生成路径描述 - node_names = [n['name'] for n in nodes] + node_names = [n["name"] for n in nodes] path_desc = " → ".join(node_names) # 计算置信度(基于路径长度和关系数量) @@ -1300,18 +1308,17 @@ class EntityPathDiscovery: return EntityPath( path_id=f"path_{entity_ids[0]}_{entity_ids[-1]}_{hash(tuple(entity_ids))}", source_entity_id=entity_ids[0], - source_entity_name=nodes[0]['name'] if nodes else "", + source_entity_name=nodes[0]["name"] if nodes else "", target_entity_id=entity_ids[-1], - target_entity_name=nodes[-1]['name'] if nodes else "", + target_entity_name=nodes[-1]["name"] if nodes else "", path_length=len(entity_ids) - 1, nodes=nodes, edges=edges, confidence=round(confidence, 4), - path_description=path_desc + path_description=path_desc, ) - def find_multi_hop_relations(self, entity_id: str, - max_hops: int = 3) -> list[dict]: + def find_multi_hop_relations(self, entity_id: str, max_hops: int = 3) -> list[dict]: """ 查找实体的多跳关系 @@ -1325,17 +1332,14 @@ class EntityPathDiscovery: conn = self._get_conn() # 获取项目ID - row = conn.execute( - "SELECT project_id, name FROM entities WHERE id = ?", - (entity_id,) - ).fetchone() + row = conn.execute("SELECT project_id, name FROM entities WHERE id = ?", (entity_id,)).fetchone() if not row: conn.close() return [] - project_id = row['project_id'] - row['name'] + project_id = row["project_id"] + row["name"] # BFS 收集多跳关系 visited = {entity_id: 0} @@ -1349,7 +1353,8 @@ class EntityPathDiscovery: continue # 获取邻居 - neighbors = conn.execute(""" + neighbors = conn.execute( + """ SELECT CASE WHEN source_entity_id = ? THEN target_entity_id @@ -1360,10 +1365,12 @@ class EntityPathDiscovery: FROM entity_relations WHERE (source_entity_id = ? OR target_entity_id = ?) AND project_id = ? - """, (current_id, current_id, current_id, project_id)).fetchall() + """, + (current_id, current_id, current_id, project_id), + ).fetchall() for neighbor in neighbors: - neighbor_id = neighbor['neighbor_id'] + neighbor_id = neighbor["neighbor_id"] if neighbor_id not in visited: visited[neighbor_id] = depth + 1 @@ -1371,29 +1378,31 @@ class EntityPathDiscovery: # 获取邻居信息 neighbor_info = conn.execute( - "SELECT name, type FROM entities WHERE id = ?", - (neighbor_id,) + "SELECT name, type FROM entities WHERE id = ?", (neighbor_id,) ).fetchone() if neighbor_info: - relations.append({ - "entity_id": neighbor_id, - "entity_name": neighbor_info['name'], - "entity_type": neighbor_info['type'], - "hops": depth + 1, - "relation_type": neighbor['relation_type'], - "evidence": neighbor['evidence'], - "path": self._get_path_to_entity(entity_id, neighbor_id, project_id, conn) - }) + relations.append( + { + "entity_id": neighbor_id, + "entity_name": neighbor_info["name"], + "entity_type": neighbor_info["type"], + "hops": depth + 1, + "relation_type": neighbor["relation_type"], + "evidence": neighbor["evidence"], + "path": self._get_path_to_entity(entity_id, neighbor_id, project_id, conn), + } + ) conn.close() # 按跳数排序 - relations.sort(key=lambda x: x['hops']) + relations.sort(key=lambda x: x["hops"]) return relations - def _get_path_to_entity(self, source_id: str, target_id: str, - project_id: str, conn: sqlite3.Connection) -> list[str]: + def _get_path_to_entity( + self, source_id: str, target_id: str, project_id: str, conn: sqlite3.Connection + ) -> list[str]: """获取从源实体到目标实体的路径(简化版)""" # BFS 找路径 visited = {source_id} @@ -1408,7 +1417,8 @@ class EntityPathDiscovery: if len(path) > 5: # 限制路径长度 continue - neighbors = conn.execute(""" + neighbors = conn.execute( + """ SELECT CASE WHEN source_entity_id = ? THEN target_entity_id @@ -1417,10 +1427,12 @@ class EntityPathDiscovery: FROM entity_relations WHERE (source_entity_id = ? OR target_entity_id = ?) AND project_id = ? - """, (current, current, current, project_id)).fetchall() + """, + (current, current, current, project_id), + ).fetchall() for neighbor in neighbors: - neighbor_id = neighbor['neighbor_id'] + neighbor_id = neighbor["neighbor_id"] if neighbor_id not in visited: visited.add(neighbor_id) queue.append((neighbor_id, path + [neighbor_id])) @@ -1440,30 +1452,34 @@ class EntityPathDiscovery: # 节点数据 nodes = [] for node in path.nodes: - nodes.append({ - "id": node["id"], - "name": node["name"], - "type": node["type"], - "is_source": node["id"] == path.source_entity_id, - "is_target": node["id"] == path.target_entity_id - }) + nodes.append( + { + "id": node["id"], + "name": node["name"], + "type": node["type"], + "is_source": node["id"] == path.source_entity_id, + "is_target": node["id"] == path.target_entity_id, + } + ) # 边数据 links = [] for edge in path.edges: - links.append({ - "source": edge["source"], - "target": edge["target"], - "relation_type": edge["relation_type"], - "evidence": edge["evidence"] - }) + links.append( + { + "source": edge["source"], + "target": edge["target"], + "relation_type": edge["relation_type"], + "evidence": edge["evidence"], + } + ) return { "nodes": nodes, "links": links, "path_description": path.path_description, "path_length": path.path_length, - "confidence": path.confidence + "confidence": path.confidence, } def analyze_path_centrality(self, project_id: str) -> list[dict]: @@ -1479,19 +1495,17 @@ class EntityPathDiscovery: conn = self._get_conn() # 获取所有实体 - entities = conn.execute( - "SELECT id, name FROM entities WHERE project_id = ?", - (project_id,) - ).fetchall() + entities = conn.execute("SELECT id, name FROM entities WHERE project_id = ?", (project_id,)).fetchall() # 计算每个实体作为桥梁的次数 bridge_scores = [] for entity in entities: - entity_id = entity['id'] + entity_id = entity["id"] # 计算该实体连接的不同群组数量 - neighbors = conn.execute(""" + neighbors = conn.execute( + """ SELECT CASE WHEN source_entity_id = ? THEN target_entity_id @@ -1500,13 +1514,16 @@ class EntityPathDiscovery: FROM entity_relations WHERE (source_entity_id = ? OR target_entity_id = ?) AND project_id = ? - """, (entity_id, entity_id, entity_id, project_id)).fetchall() + """, + (entity_id, entity_id, entity_id, project_id), + ).fetchall() - neighbor_ids = {n['neighbor_id'] for n in neighbors} + neighbor_ids = {n["neighbor_id"] for n in neighbors} # 计算邻居之间的连接数(用于评估桥接程度) if len(neighbor_ids) > 1: - connections = conn.execute(f""" + connections = conn.execute( + f""" SELECT COUNT(*) as count FROM entity_relations WHERE ((source_entity_id IN ({','.join(['?' for _ in neighbor_ids])}) @@ -1514,29 +1531,34 @@ class EntityPathDiscovery: OR (target_entity_id IN ({','.join(['?' for _ in neighbor_ids])}) AND source_entity_id IN ({','.join(['?' for _ in neighbor_ids])}))) AND project_id = ? - """, list(neighbor_ids) * 4 + [project_id]).fetchone() + """, + list(neighbor_ids) * 4 + [project_id], + ).fetchone() # 桥接分数 = 邻居数量 / (邻居间连接数 + 1) - bridge_score = len(neighbor_ids) / (connections['count'] + 1) + bridge_score = len(neighbor_ids) / (connections["count"] + 1) else: bridge_score = 0 - bridge_scores.append({ - "entity_id": entity_id, - "entity_name": entity['name'], - "neighbor_count": len(neighbor_ids), - "bridge_score": round(bridge_score, 4) - }) + bridge_scores.append( + { + "entity_id": entity_id, + "entity_name": entity["name"], + "neighbor_count": len(neighbor_ids), + "bridge_score": round(bridge_score, 4), + } + ) conn.close() # 按桥接分数排序 - bridge_scores.sort(key=lambda x: x['bridge_score'], reverse=True) + bridge_scores.sort(key=lambda x: x["bridge_score"], reverse=True) return bridge_scores[:20] # 返回前20 # ==================== 知识缺口识别 ==================== + class KnowledgeGapDetection: """ 知识缺口识别模块 @@ -1597,36 +1619,31 @@ class KnowledgeGapDetection: # 获取项目的属性模板 templates = conn.execute( - "SELECT id, name, type, is_required FROM attribute_templates WHERE project_id = ?", - (project_id,) + "SELECT id, name, type, is_required FROM attribute_templates WHERE project_id = ?", (project_id,) ).fetchall() if not templates: conn.close() return [] - required_template_ids = {t['id'] for t in templates if t['is_required']} + required_template_ids = {t["id"] for t in templates if t["is_required"]} if not required_template_ids: conn.close() return [] # 检查每个实体的属性完整性 - entities = conn.execute( - "SELECT id, name FROM entities WHERE project_id = ?", - (project_id,) - ).fetchall() + entities = conn.execute("SELECT id, name FROM entities WHERE project_id = ?", (project_id,)).fetchall() for entity in entities: - entity_id = entity['id'] + entity_id = entity["id"] # 获取实体已有的属性 existing_attrs = conn.execute( - "SELECT template_id FROM entity_attributes WHERE entity_id = ?", - (entity_id,) + "SELECT template_id FROM entity_attributes WHERE entity_id = ?", (entity_id,) ).fetchall() - existing_template_ids = {a['template_id'] for a in existing_attrs} + existing_template_ids = {a["template_id"] for a in existing_attrs} # 找出缺失的必需属性 missing_templates = required_template_ids - existing_template_ids @@ -1635,27 +1652,28 @@ class KnowledgeGapDetection: missing_names = [] for template_id in missing_templates: template = conn.execute( - "SELECT name FROM attribute_templates WHERE id = ?", - (template_id,) + "SELECT name FROM attribute_templates WHERE id = ?", (template_id,) ).fetchone() if template: - missing_names.append(template['name']) + missing_names.append(template["name"]) if missing_names: - gaps.append(KnowledgeGap( - gap_id=f"gap_attr_{entity_id}", - gap_type="missing_attribute", - entity_id=entity_id, - entity_name=entity['name'], - description=f"实体 '{entity['name']}' 缺少必需属性: {', '.join(missing_names)}", - severity="medium", - suggestions=[ - f"为实体 '{entity['name']}' 补充以下属性: {', '.join(missing_names)}", - "检查属性模板定义是否合理" - ], - related_entities=[], - metadata={"missing_attributes": missing_names} - )) + gaps.append( + KnowledgeGap( + gap_id=f"gap_attr_{entity_id}", + gap_type="missing_attribute", + entity_id=entity_id, + entity_name=entity["name"], + description=f"实体 '{entity['name']}' 缺少必需属性: {', '.join(missing_names)}", + severity="medium", + suggestions=[ + f"为实体 '{entity['name']}' 补充以下属性: {', '.join(missing_names)}", + "检查属性模板定义是否合理", + ], + related_entities=[], + metadata={"missing_attributes": missing_names}, + ) + ) conn.close() return gaps @@ -1666,28 +1684,29 @@ class KnowledgeGapDetection: gaps = [] # 获取所有实体及其关系数量 - entities = conn.execute( - "SELECT id, name, type FROM entities WHERE project_id = ?", - (project_id,) - ).fetchall() + entities = conn.execute("SELECT id, name, type FROM entities WHERE project_id = ?", (project_id,)).fetchall() for entity in entities: - entity_id = entity['id'] + entity_id = entity["id"] # 计算关系数量 - relation_count = conn.execute(""" + relation_count = conn.execute( + """ SELECT COUNT(*) as count FROM entity_relations WHERE (source_entity_id = ? OR target_entity_id = ?) AND project_id = ? - """, (entity_id, entity_id, project_id)).fetchone()['count'] + """, + (entity_id, entity_id, project_id), + ).fetchone()["count"] # 根据实体类型判断阈值 - threshold = 1 if entity['type'] in ['PERSON', 'ORG'] else 0 + threshold = 1 if entity["type"] in ["PERSON", "ORG"] else 0 if relation_count <= threshold: # 查找潜在的相关实体 - potential_related = conn.execute(""" + potential_related = conn.execute( + """ SELECT e.id, e.name FROM entities e JOIN transcripts t ON t.project_id = e.project_id @@ -1695,26 +1714,30 @@ class KnowledgeGapDetection: AND e.id != ? AND t.full_text LIKE ? LIMIT 5 - """, (project_id, entity_id, f"%{entity['name']}%")).fetchall() + """, + (project_id, entity_id, f"%{entity['name']}%"), + ).fetchall() - gaps.append(KnowledgeGap( - gap_id=f"gap_sparse_{entity_id}", - gap_type="sparse_relation", - entity_id=entity_id, - entity_name=entity['name'], - description=f"实体 '{entity['name']}' 关系稀疏(仅有 {relation_count} 个关系)", - severity="medium" if relation_count == 0 else "low", - suggestions=[ - f"检查转录文本中提及 '{entity['name']}' 的其他实体", - f"手动添加 '{entity['name']}' 与其他实体的关系", - "使用实体对齐功能合并相似实体" - ], - related_entities=[r['id'] for r in potential_related], - metadata={ - "relation_count": relation_count, - "potential_related": [r['name'] for r in potential_related] - } - )) + gaps.append( + KnowledgeGap( + gap_id=f"gap_sparse_{entity_id}", + gap_type="sparse_relation", + entity_id=entity_id, + entity_name=entity["name"], + description=f"实体 '{entity['name']}' 关系稀疏(仅有 {relation_count} 个关系)", + severity="medium" if relation_count == 0 else "low", + suggestions=[ + f"检查转录文本中提及 '{entity['name']}' 的其他实体", + f"手动添加 '{entity['name']}' 与其他实体的关系", + "使用实体对齐功能合并相似实体", + ], + related_entities=[r["id"] for r in potential_related], + metadata={ + "relation_count": relation_count, + "potential_related": [r["name"] for r in potential_related], + }, + ) + ) conn.close() return gaps @@ -1725,7 +1748,8 @@ class KnowledgeGapDetection: gaps = [] # 查找没有关系的实体 - isolated = conn.execute(""" + isolated = conn.execute( + """ SELECT e.id, e.name, e.type FROM entities e LEFT JOIN entity_relations r1 ON e.id = r1.source_entity_id @@ -1733,24 +1757,28 @@ class KnowledgeGapDetection: WHERE e.project_id = ? AND r1.id IS NULL AND r2.id IS NULL - """, (project_id,)).fetchall() + """, + (project_id,), + ).fetchall() for entity in isolated: - gaps.append(KnowledgeGap( - gap_id=f"gap_iso_{entity['id']}", - gap_type="isolated_entity", - entity_id=entity['id'], - entity_name=entity['name'], - description=f"实体 '{entity['name']}' 是孤立实体(没有任何关系)", - severity="high", - suggestions=[ - f"检查 '{entity['name']}' 是否应该与其他实体建立关系", - f"考虑删除不相关的实体 '{entity['name']}'", - "运行关系发现算法自动识别潜在关系" - ], - related_entities=[], - metadata={"entity_type": entity['type']} - )) + gaps.append( + KnowledgeGap( + gap_id=f"gap_iso_{entity['id']}", + gap_type="isolated_entity", + entity_id=entity["id"], + entity_name=entity["name"], + description=f"实体 '{entity['name']}' 是孤立实体(没有任何关系)", + severity="high", + suggestions=[ + f"检查 '{entity['name']}' 是否应该与其他实体建立关系", + f"考虑删除不相关的实体 '{entity['name']}'", + "运行关系发现算法自动识别潜在关系", + ], + related_entities=[], + metadata={"entity_type": entity["type"]}, + ) + ) conn.close() return gaps @@ -1761,28 +1789,30 @@ class KnowledgeGapDetection: gaps = [] # 查找缺少定义的实体 - incomplete = conn.execute(""" + incomplete = conn.execute( + """ SELECT id, name, type, definition FROM entities WHERE project_id = ? AND (definition IS NULL OR definition = '') - """, (project_id,)).fetchall() + """, + (project_id,), + ).fetchall() for entity in incomplete: - gaps.append(KnowledgeGap( - gap_id=f"gap_inc_{entity['id']}", - gap_type="incomplete_entity", - entity_id=entity['id'], - entity_name=entity['name'], - description=f"实体 '{entity['name']}' 缺少定义", - severity="low", - suggestions=[ - f"为 '{entity['name']}' 添加定义", - "从转录文本中提取定义信息" - ], - related_entities=[], - metadata={"entity_type": entity['type']} - )) + gaps.append( + KnowledgeGap( + gap_id=f"gap_inc_{entity['id']}", + gap_type="incomplete_entity", + entity_id=entity["id"], + entity_name=entity["name"], + description=f"实体 '{entity['name']}' 缺少定义", + severity="low", + suggestions=[f"为 '{entity['name']}' 添加定义", "从转录文本中提取定义信息"], + related_entities=[], + metadata={"entity_type": entity["type"]}, + ) + ) conn.close() return gaps @@ -1793,25 +1823,19 @@ class KnowledgeGapDetection: gaps = [] # 分析转录文本中频繁提及但未提取为实体的词 - transcripts = conn.execute( - "SELECT full_text FROM transcripts WHERE project_id = ?", - (project_id,) - ).fetchall() + transcripts = conn.execute("SELECT full_text FROM transcripts WHERE project_id = ?", (project_id,)).fetchall() # 合并所有文本 - all_text = " ".join([t['full_text'] or "" for t in transcripts]) + all_text = " ".join([t["full_text"] or "" for t in transcripts]) # 获取现有实体名称 - existing_entities = conn.execute( - "SELECT name FROM entities WHERE project_id = ?", - (project_id,) - ).fetchall() + existing_entities = conn.execute("SELECT name FROM entities WHERE project_id = ?", (project_id,)).fetchall() - existing_names = {e['name'].lower() for e in existing_entities} + existing_names = {e["name"].lower() for e in existing_entities} # 简单的关键词提取(实际可以使用更复杂的 NLP 方法) # 查找大写的词组(可能是专有名词) - potential_entities = re.findall(r'[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*', all_text) + potential_entities = re.findall(r"[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*", all_text) # 统计频率 freq = defaultdict(int) @@ -1822,20 +1846,19 @@ class KnowledgeGapDetection: # 找出高频但未提取的词 for entity, count in freq.items(): if count >= 3: # 出现3次以上 - gaps.append(KnowledgeGap( - gap_id=f"gap_missing_{hash(entity) % 10000}", - gap_type="missing_key_entity", - entity_id=None, - entity_name=None, - description=f"文本中频繁提及 '{entity}' 但未提取为实体(出现 {count} 次)", - severity="low", - suggestions=[ - f"考虑将 '{entity}' 添加为实体", - "检查实体提取算法是否需要优化" - ], - related_entities=[], - metadata={"mention_count": count} - )) + gaps.append( + KnowledgeGap( + gap_id=f"gap_missing_{hash(entity) % 10000}", + gap_type="missing_key_entity", + entity_id=None, + entity_name=None, + description=f"文本中频繁提及 '{entity}' 但未提取为实体(出现 {count} 次)", + severity="low", + suggestions=[f"考虑将 '{entity}' 添加为实体", "检查实体提取算法是否需要优化"], + related_entities=[], + metadata={"mention_count": count}, + ) + ) conn.close() return gaps[:10] # 限制数量 @@ -1853,12 +1876,15 @@ class KnowledgeGapDetection: conn = self._get_conn() # 基础统计 - stats = conn.execute(""" + stats = conn.execute( + """ SELECT (SELECT COUNT(*) FROM entities WHERE project_id = ?) as entity_count, (SELECT COUNT(*) FROM entity_relations WHERE project_id = ?) as relation_count, (SELECT COUNT(*) FROM transcripts WHERE project_id = ?) as transcript_count - """, (project_id, project_id, project_id)).fetchone() + """, + (project_id, project_id, project_id), + ).fetchone() # 计算完整性分数 gaps = self.analyze_project(project_id) @@ -1884,17 +1910,13 @@ class KnowledgeGapDetection: "project_id": project_id, "completeness_score": score, "statistics": { - "entity_count": stats['entity_count'], - "relation_count": stats['relation_count'], - "transcript_count": stats['transcript_count'] - }, - "gap_summary": { - "total": len(gaps), - "by_type": dict(gap_by_type), - "by_severity": severity_count + "entity_count": stats["entity_count"], + "relation_count": stats["relation_count"], + "transcript_count": stats["transcript_count"], }, + "gap_summary": {"total": len(gaps), "by_type": dict(gap_by_type), "by_severity": severity_count}, "top_gaps": [g.to_dict() for g in gaps[:10]], - "recommendations": self._generate_recommendations(gaps) + "recommendations": self._generate_recommendations(gaps), } def _generate_recommendations(self, gaps: list[KnowledgeGap]) -> list[str]: @@ -1926,6 +1948,7 @@ class KnowledgeGapDetection: # ==================== 搜索管理器 ==================== + class SearchManager: """ 搜索管理器 - 统一入口 @@ -1940,8 +1963,7 @@ class SearchManager: self.path_discovery = EntityPathDiscovery(db_path) self.gap_detection = KnowledgeGapDetection(db_path) - def hybrid_search(self, query: str, project_id: str | None = None, - limit: int = 20) -> dict: + def hybrid_search(self, query: str, project_id: str | None = None, limit: int = 20) -> dict: """ 混合搜索(全文 + 语义) @@ -1954,16 +1976,12 @@ class SearchManager: Dict: 混合搜索结果 """ # 全文搜索 - fulltext_results = self.fulltext_search.search( - query, project_id, limit=limit - ) + fulltext_results = self.fulltext_search.search(query, project_id, limit=limit) # 语义搜索 semantic_results = [] if self.semantic_search.is_available(): - semantic_results = self.semantic_search.search( - query, project_id, top_k=limit - ) + semantic_results = self.semantic_search.search(query, project_id, top_k=limit) # 合并结果(去重并加权) combined = {} @@ -1979,7 +1997,7 @@ class SearchManager: "fulltext_score": r.score, "semantic_score": 0, "combined_score": r.score * 0.6, # 全文权重 60% - "highlights": r.highlights + "highlights": r.highlights, } # 添加语义搜索结果 @@ -1997,7 +2015,7 @@ class SearchManager: "fulltext_score": 0, "semantic_score": r.similarity, "combined_score": r.similarity * 0.4, - "highlights": [] + "highlights": [], } # 排序 @@ -2010,7 +2028,7 @@ class SearchManager: "total": len(results), "fulltext_count": len(fulltext_results), "semantic_count": len(semantic_results), - "results": results[:limit] + "results": results[:limit], } def index_project(self, project_id: str) -> dict: @@ -2035,13 +2053,12 @@ class SearchManager: # 索引转录文本 transcripts = conn.execute( - "SELECT id, project_id, full_text FROM transcripts WHERE project_id = ?", - (project_id,) + "SELECT id, project_id, full_text FROM transcripts WHERE project_id = ?", (project_id,) ).fetchall() for t in transcripts: - if t['full_text'] and self.semantic_search.index_embedding( - t['id'], 'transcript', t['project_id'], t['full_text'] + if t["full_text"] and self.semantic_search.index_embedding( + t["id"], "transcript", t["project_id"], t["full_text"] ): semantic_stats["indexed"] += 1 else: @@ -2049,26 +2066,19 @@ class SearchManager: # 索引实体 entities = conn.execute( - "SELECT id, project_id, name, definition FROM entities WHERE project_id = ?", - (project_id,) + "SELECT id, project_id, name, definition FROM entities WHERE project_id = ?", (project_id,) ).fetchall() for e in entities: text = f"{e['name']} {e['definition'] or ''}" - if self.semantic_search.index_embedding( - e['id'], 'entity', e['project_id'], text - ): + if self.semantic_search.index_embedding(e["id"], "entity", e["project_id"], text): semantic_stats["indexed"] += 1 else: semantic_stats["errors"] += 1 conn.close() - return { - "project_id": project_id, - "fulltext": fulltext_stats, - "semantic": semantic_stats - } + return {"project_id": project_id, "fulltext": fulltext_stats, "semantic": semantic_stats} def get_search_stats(self, project_id: str | None = None) -> dict: """获取搜索统计信息""" @@ -2080,15 +2090,13 @@ class SearchManager: # 全文索引统计 fulltext_count = conn.execute( - f"SELECT COUNT(*) as count FROM search_indexes {where_clause}", - params - ).fetchone()['count'] + f"SELECT COUNT(*) as count FROM search_indexes {where_clause}", params + ).fetchone()["count"] # 语义索引统计 - semantic_count = conn.execute( - f"SELECT COUNT(*) as count FROM embeddings {where_clause}", - params - ).fetchone()['count'] + semantic_count = conn.execute(f"SELECT COUNT(*) as count FROM embeddings {where_clause}", params).fetchone()[ + "count" + ] # 按类型统计 type_stats = {} @@ -2097,9 +2105,9 @@ class SearchManager: """SELECT content_type, COUNT(*) as count FROM search_indexes WHERE project_id = ? GROUP BY content_type""", - (project_id,) + (project_id,), ).fetchall() - type_stats = {r['content_type']: r['count'] for r in rows} + type_stats = {r["content_type"]: r["count"] for r in rows} conn.close() @@ -2108,7 +2116,7 @@ class SearchManager: "fulltext_indexed": fulltext_count, "semantic_indexed": semantic_count, "by_content_type": type_stats, - "semantic_search_available": self.semantic_search.is_available() + "semantic_search_available": self.semantic_search.is_available(), } @@ -2125,22 +2133,19 @@ def get_search_manager(db_path: str = "insightflow.db") -> SearchManager: # 便捷函数 -def fulltext_search(query: str, project_id: str | None = None, - limit: int = 20) -> list[SearchResult]: +def fulltext_search(query: str, project_id: str | None = None, limit: int = 20) -> list[SearchResult]: """全文搜索便捷函数""" manager = get_search_manager() return manager.fulltext_search.search(query, project_id, limit=limit) -def semantic_search(query: str, project_id: str | None = None, - top_k: int = 10) -> list[SemanticSearchResult]: +def semantic_search(query: str, project_id: str | None = None, top_k: int = 10) -> list[SemanticSearchResult]: """语义搜索便捷函数""" manager = get_search_manager() return manager.semantic_search.search(query, project_id, top_k=top_k) -def find_entity_path(source_id: str, target_id: str, - max_depth: int = 5) -> EntityPath | None: +def find_entity_path(source_id: str, target_id: str, max_depth: int = 5) -> EntityPath | None: """查找实体路径便捷函数""" manager = get_search_manager() return manager.path_discovery.find_shortest_path(source_id, target_id, max_depth) diff --git a/backend/security_manager.py b/backend/security_manager.py index 777bb5b..940a60b 100644 --- a/backend/security_manager.py +++ b/backend/security_manager.py @@ -19,6 +19,7 @@ try: from cryptography.fernet import Fernet from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC + CRYPTO_AVAILABLE = True except ImportError: CRYPTO_AVAILABLE = False @@ -27,6 +28,7 @@ except ImportError: class AuditActionType(Enum): """审计动作类型""" + CREATE = "create" READ = "read" UPDATE = "update" @@ -49,26 +51,29 @@ class AuditActionType(Enum): class DataSensitivityLevel(Enum): """数据敏感度级别""" - PUBLIC = "public" # 公开 - INTERNAL = "internal" # 内部 + + PUBLIC = "public" # 公开 + INTERNAL = "internal" # 内部 CONFIDENTIAL = "confidential" # 机密 - SECRET = "secret" # 绝密 + SECRET = "secret" # 绝密 class MaskingRuleType(Enum): """脱敏规则类型""" - PHONE = "phone" # 手机号 - EMAIL = "email" # 邮箱 - ID_CARD = "id_card" # 身份证号 - BANK_CARD = "bank_card" # 银行卡号 - NAME = "name" # 姓名 - ADDRESS = "address" # 地址 - CUSTOM = "custom" # 自定义 + + PHONE = "phone" # 手机号 + EMAIL = "email" # 邮箱 + ID_CARD = "id_card" # 身份证号 + BANK_CARD = "bank_card" # 银行卡号 + NAME = "name" # 姓名 + ADDRESS = "address" # 地址 + CUSTOM = "custom" # 自定义 @dataclass class AuditLog: """审计日志条目""" + id: str action_type: str user_id: str | None = None @@ -90,6 +95,7 @@ class AuditLog: @dataclass class EncryptionConfig: """加密配置""" + id: str project_id: str is_enabled: bool = False @@ -107,6 +113,7 @@ class EncryptionConfig: @dataclass class MaskingRule: """脱敏规则""" + id: str project_id: str name: str @@ -126,6 +133,7 @@ class MaskingRule: @dataclass class DataAccessPolicy: """数据访问策略""" + id: str project_id: str name: str @@ -147,6 +155,7 @@ class DataAccessPolicy: @dataclass class AccessRequest: """访问请求(用于需要审批的访问)""" + id: str policy_id: str user_id: str @@ -166,30 +175,15 @@ class SecurityManager: # 预定义脱敏规则 DEFAULT_MASKING_RULES = { - MaskingRuleType.PHONE: { - "pattern": r"(\d{3})\d{4}(\d{4})", - "replacement": r"\1****\2" - }, - MaskingRuleType.EMAIL: { - "pattern": r"(\w{1,3})\w+(@\w+\.\w+)", - "replacement": r"\1***\2" - }, - MaskingRuleType.ID_CARD: { - "pattern": r"(\d{6})\d{8}(\d{4})", - "replacement": r"\1********\2" - }, - MaskingRuleType.BANK_CARD: { - "pattern": r"(\d{4})\d+(\d{4})", - "replacement": r"\1 **** **** \2" - }, - MaskingRuleType.NAME: { - "pattern": r"([\u4e00-\u9fa5])[\u4e00-\u9fa5]+", - "replacement": r"\1**" - }, + MaskingRuleType.PHONE: {"pattern": r"(\d{3})\d{4}(\d{4})", "replacement": r"\1****\2"}, + MaskingRuleType.EMAIL: {"pattern": r"(\w{1,3})\w+(@\w+\.\w+)", "replacement": r"\1***\2"}, + MaskingRuleType.ID_CARD: {"pattern": r"(\d{6})\d{8}(\d{4})", "replacement": r"\1********\2"}, + MaskingRuleType.BANK_CARD: {"pattern": r"(\d{4})\d+(\d{4})", "replacement": r"\1 **** **** \2"}, + MaskingRuleType.NAME: {"pattern": r"([\u4e00-\u9fa5])[\u4e00-\u9fa5]+", "replacement": r"\1**"}, MaskingRuleType.ADDRESS: { "pattern": r"([\u4e00-\u9fa5]{2,})([\u4e00-\u9fa5]+路|街|巷|号)(.+)", - "replacement": r"\1\2***" - } + "replacement": r"\1\2***", + }, } def __init__(self, db_path: str = "insightflow.db"): @@ -200,7 +194,7 @@ class SecurityManager: self._local = {} self._init_db() - def _init_db(self): + def _init_db(self) -> None: """初始化数据库表""" conn = sqlite3.connect(self.db_path) cursor = conn.cursor() @@ -308,9 +302,7 @@ class SecurityManager: def _generate_id(self) -> str: """生成唯一ID""" - return hashlib.sha256( - f"{datetime.now().isoformat()}{secrets.token_hex(16)}".encode() - ).hexdigest()[:32] + return hashlib.sha256(f"{datetime.now().isoformat()}{secrets.token_hex(16)}".encode()).hexdigest()[:32] # ==================== 审计日志 ==================== @@ -326,7 +318,7 @@ class SecurityManager: before_value: str | None = None, after_value: str | None = None, success: bool = True, - error_message: str | None = None + error_message: str | None = None, ) -> AuditLog: """记录审计日志""" log = AuditLog( @@ -341,22 +333,34 @@ class SecurityManager: before_value=before_value, after_value=after_value, success=success, - error_message=error_message + error_message=error_message, ) conn = sqlite3.connect(self.db_path) cursor = conn.cursor() - cursor.execute(""" + cursor.execute( + """ INSERT INTO audit_logs (id, action_type, user_id, user_ip, user_agent, resource_type, resource_id, action_details, before_value, after_value, success, error_message, created_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - """, ( - log.id, log.action_type, log.user_id, log.user_ip, log.user_agent, - log.resource_type, log.resource_id, log.action_details, - log.before_value, log.after_value, int(log.success), - log.error_message, log.created_at - )) + """, + ( + log.id, + log.action_type, + log.user_id, + log.user_ip, + log.user_agent, + log.resource_type, + log.resource_id, + log.action_details, + log.before_value, + log.after_value, + int(log.success), + log.error_message, + log.created_at, + ), + ) conn.commit() conn.close() @@ -372,7 +376,7 @@ class SecurityManager: end_time: str | None = None, success: bool | None = None, limit: int = 100, - offset: int = 0 + offset: int = 0, ) -> list[AuditLog]: """查询审计日志""" conn = sqlite3.connect(self.db_path) @@ -429,18 +433,14 @@ class SecurityManager: after_value=row[9], success=bool(row[10]), error_message=row[11], - created_at=row[12] + created_at=row[12], ) logs.append(log) conn.close() return logs - def get_audit_stats( - self, - start_time: str | None = None, - end_time: str | None = None - ) -> dict[str, Any]: + def get_audit_stats(self, start_time: str | None = None, end_time: str | None = None) -> dict[str, Any]: """获取审计统计""" conn = sqlite3.connect(self.db_path) cursor = conn.cursor() @@ -460,12 +460,7 @@ class SecurityManager: cursor.execute(query, params) rows = cursor.fetchall() - stats = { - "total_actions": 0, - "success_count": 0, - "failure_count": 0, - "action_breakdown": {} - } + stats = {"total_actions": 0, "success_count": 0, "failure_count": 0, "action_breakdown": {}} for action_type, success, count in rows: stats["total_actions"] += count @@ -500,11 +495,7 @@ class SecurityManager: ) return base64.urlsafe_b64encode(kdf.derive(password.encode())) - def enable_encryption( - self, - project_id: str, - master_password: str - ) -> EncryptionConfig: + def enable_encryption(self, project_id: str, master_password: str) -> EncryptionConfig: """启用项目加密""" if not CRYPTO_AVAILABLE: raise RuntimeError("cryptography library not available") @@ -523,43 +514,54 @@ class SecurityManager: encryption_type="aes-256-gcm", key_derivation="pbkdf2", master_key_hash=key_hash, - salt=salt + salt=salt, ) conn = sqlite3.connect(self.db_path) cursor = conn.cursor() # 检查是否已存在配置 - cursor.execute( - "SELECT id FROM encryption_configs WHERE project_id = ?", - (project_id,) - ) + cursor.execute("SELECT id FROM encryption_configs WHERE project_id = ?", (project_id,)) existing = cursor.fetchone() if existing: - cursor.execute(""" + cursor.execute( + """ UPDATE encryption_configs SET is_enabled = 1, encryption_type = ?, key_derivation = ?, master_key_hash = ?, salt = ?, updated_at = ? WHERE project_id = ? - """, ( - config.encryption_type, config.key_derivation, - config.master_key_hash, config.salt, - config.updated_at, project_id - )) + """, + ( + config.encryption_type, + config.key_derivation, + config.master_key_hash, + config.salt, + config.updated_at, + project_id, + ), + ) config.id = existing[0] else: - cursor.execute(""" + cursor.execute( + """ INSERT INTO encryption_configs (id, project_id, is_enabled, encryption_type, key_derivation, master_key_hash, salt, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) - """, ( - config.id, config.project_id, int(config.is_enabled), - config.encryption_type, config.key_derivation, - config.master_key_hash, config.salt, - config.created_at, config.updated_at - )) + """, + ( + config.id, + config.project_id, + int(config.is_enabled), + config.encryption_type, + config.key_derivation, + config.master_key_hash, + config.salt, + config.created_at, + config.updated_at, + ), + ) conn.commit() conn.close() @@ -569,16 +571,12 @@ class SecurityManager: action_type=AuditActionType.ENCRYPTION_ENABLE, resource_type="project", resource_id=project_id, - action_details={"encryption_type": config.encryption_type} + action_details={"encryption_type": config.encryption_type}, ) return config - def disable_encryption( - self, - project_id: str, - master_password: str - ) -> bool: + def disable_encryption(self, project_id: str, master_password: str) -> bool: """禁用项目加密""" # 验证密码 if not self.verify_encryption_password(project_id, master_password): @@ -587,29 +585,24 @@ class SecurityManager: conn = sqlite3.connect(self.db_path) cursor = conn.cursor() - cursor.execute(""" + cursor.execute( + """ UPDATE encryption_configs SET is_enabled = 0, updated_at = ? WHERE project_id = ? - """, (datetime.now().isoformat(), project_id)) + """, + (datetime.now().isoformat(), project_id), + ) conn.commit() conn.close() # 记录审计日志 - self.log_audit( - action_type=AuditActionType.ENCRYPTION_DISABLE, - resource_type="project", - resource_id=project_id - ) + self.log_audit(action_type=AuditActionType.ENCRYPTION_DISABLE, resource_type="project", resource_id=project_id) return True - def verify_encryption_password( - self, - project_id: str, - password: str - ) -> bool: + def verify_encryption_password(self, project_id: str, password: str) -> bool: """验证加密密码""" if not CRYPTO_AVAILABLE: return False @@ -617,10 +610,7 @@ class SecurityManager: conn = sqlite3.connect(self.db_path) cursor = conn.cursor() - cursor.execute( - "SELECT master_key_hash, salt FROM encryption_configs WHERE project_id = ?", - (project_id,) - ) + cursor.execute("SELECT master_key_hash, salt FROM encryption_configs WHERE project_id = ?", (project_id,)) row = cursor.fetchone() conn.close() @@ -638,10 +628,7 @@ class SecurityManager: conn = sqlite3.connect(self.db_path) cursor = conn.cursor() - cursor.execute( - "SELECT * FROM encryption_configs WHERE project_id = ?", - (project_id,) - ) + cursor.execute("SELECT * FROM encryption_configs WHERE project_id = ?", (project_id,)) row = cursor.fetchone() conn.close() @@ -657,15 +644,10 @@ class SecurityManager: master_key_hash=row[5], salt=row[6], created_at=row[7], - updated_at=row[8] + updated_at=row[8], ) - def encrypt_data( - self, - data: str, - password: str, - salt: str | None = None - ) -> tuple[str, str]: + def encrypt_data(self, data: str, password: str, salt: str | None = None) -> tuple[str, str]: """加密数据""" if not CRYPTO_AVAILABLE: raise RuntimeError("cryptography library not available") @@ -679,12 +661,7 @@ class SecurityManager: return base64.b64encode(encrypted).decode(), salt - def decrypt_data( - self, - encrypted_data: str, - password: str, - salt: str - ) -> str: + def decrypt_data(self, encrypted_data: str, password: str, salt: str) -> str: """解密数据""" if not CRYPTO_AVAILABLE: raise RuntimeError("cryptography library not available") @@ -705,7 +682,7 @@ class SecurityManager: pattern: str | None = None, replacement: str | None = None, description: str | None = None, - priority: int = 0 + priority: int = 0, ) -> MaskingRule: """创建脱敏规则""" # 使用预定义规则或自定义规则 @@ -722,22 +699,33 @@ class SecurityManager: pattern=pattern or "", replacement=replacement or "****", description=description, - priority=priority + priority=priority, ) conn = sqlite3.connect(self.db_path) cursor = conn.cursor() - cursor.execute(""" + cursor.execute( + """ INSERT INTO masking_rules (id, project_id, name, rule_type, pattern, replacement, is_active, priority, description, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - """, ( - rule.id, rule.project_id, rule.name, rule.rule_type, - rule.pattern, rule.replacement, int(rule.is_active), - rule.priority, rule.description, rule.created_at, rule.updated_at - )) + """, + ( + rule.id, + rule.project_id, + rule.name, + rule.rule_type, + rule.pattern, + rule.replacement, + int(rule.is_active), + rule.priority, + rule.description, + rule.created_at, + rule.updated_at, + ), + ) conn.commit() conn.close() @@ -747,16 +735,12 @@ class SecurityManager: action_type=AuditActionType.DATA_MASKING, resource_type="project", resource_id=project_id, - action_details={"action": "create_rule", "rule_name": name} + action_details={"action": "create_rule", "rule_name": name}, ) return rule - def get_masking_rules( - self, - project_id: str, - active_only: bool = True - ) -> list[MaskingRule]: + def get_masking_rules(self, project_id: str, active_only: bool = True) -> list[MaskingRule]: """获取脱敏规则""" conn = sqlite3.connect(self.db_path) cursor = conn.cursor() @@ -775,27 +759,25 @@ class SecurityManager: rules = [] for row in rows: - rules.append(MaskingRule( - id=row[0], - project_id=row[1], - name=row[2], - rule_type=row[3], - pattern=row[4], - replacement=row[5], - is_active=bool(row[6]), - priority=row[7], - description=row[8], - created_at=row[9], - updated_at=row[10] - )) + rules.append( + MaskingRule( + id=row[0], + project_id=row[1], + name=row[2], + rule_type=row[3], + pattern=row[4], + replacement=row[5], + is_active=bool(row[6]), + priority=row[7], + description=row[8], + created_at=row[9], + updated_at=row[10], + ) + ) return rules - def update_masking_rule( - self, - rule_id: str, - **kwargs - ) -> MaskingRule | None: + def update_masking_rule(self, rule_id: str, **kwargs) -> MaskingRule | None: """更新脱敏规则""" allowed_fields = ["name", "pattern", "replacement", "is_active", "priority", "description"] @@ -818,11 +800,14 @@ class SecurityManager: params.append(datetime.now().isoformat()) params.append(rule_id) - cursor.execute(f""" + cursor.execute( + f""" UPDATE masking_rules SET {', '.join(set_clauses)} WHERE id = ? - """, params) + """, + params, + ) conn.commit() conn.close() @@ -848,7 +833,7 @@ class SecurityManager: priority=row[7], description=row[8], created_at=row[9], - updated_at=row[10] + updated_at=row[10], ) def delete_masking_rule(self, rule_id: str) -> bool: @@ -864,12 +849,7 @@ class SecurityManager: return success - def apply_masking( - self, - text: str, - project_id: str, - rule_types: list[MaskingRuleType] | None = None - ) -> str: + def apply_masking(self, text: str, project_id: str, rule_types: list[MaskingRuleType] | None = None) -> str: """应用脱敏规则到文本""" rules = self.get_masking_rules(project_id) @@ -884,22 +864,14 @@ class SecurityManager: continue try: - masked_text = re.sub( - rule.pattern, - rule.replacement, - masked_text - ) + masked_text = re.sub(rule.pattern, rule.replacement, masked_text) except re.error: # 忽略无效的正则表达式 continue return masked_text - def apply_masking_to_entity( - self, - entity_data: dict[str, Any], - project_id: str - ) -> dict[str, Any]: + def apply_masking_to_entity(self, entity_data: dict[str, Any], project_id: str) -> dict[str, Any]: """对实体数据应用脱敏""" masked_data = entity_data.copy() @@ -924,7 +896,7 @@ class SecurityManager: allowed_ips: list[str] | None = None, time_restrictions: dict | None = None, max_access_count: int | None = None, - require_approval: bool = False + require_approval: bool = False, ) -> DataAccessPolicy: """创建数据访问策略""" policy = DataAccessPolicy( @@ -937,36 +909,43 @@ class SecurityManager: allowed_ips=json.dumps(allowed_ips) if allowed_ips else None, time_restrictions=json.dumps(time_restrictions) if time_restrictions else None, max_access_count=max_access_count, - require_approval=require_approval + require_approval=require_approval, ) conn = sqlite3.connect(self.db_path) cursor = conn.cursor() - cursor.execute(""" + cursor.execute( + """ INSERT INTO data_access_policies (id, project_id, name, description, allowed_users, allowed_roles, allowed_ips, time_restrictions, max_access_count, require_approval, is_active, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - """, ( - policy.id, policy.project_id, policy.name, policy.description, - policy.allowed_users, policy.allowed_roles, policy.allowed_ips, - policy.time_restrictions, policy.max_access_count, - int(policy.require_approval), int(policy.is_active), - policy.created_at, policy.updated_at - )) + """, + ( + policy.id, + policy.project_id, + policy.name, + policy.description, + policy.allowed_users, + policy.allowed_roles, + policy.allowed_ips, + policy.time_restrictions, + policy.max_access_count, + int(policy.require_approval), + int(policy.is_active), + policy.created_at, + policy.updated_at, + ), + ) conn.commit() conn.close() return policy - def get_access_policies( - self, - project_id: str, - active_only: bool = True - ) -> list[DataAccessPolicy]: + def get_access_policies(self, project_id: str, active_only: bool = True) -> list[DataAccessPolicy]: """获取数据访问策略""" conn = sqlite3.connect(self.db_path) cursor = conn.cursor() @@ -983,38 +962,34 @@ class SecurityManager: policies = [] for row in rows: - policies.append(DataAccessPolicy( - id=row[0], - project_id=row[1], - name=row[2], - description=row[3], - allowed_users=row[4], - allowed_roles=row[5], - allowed_ips=row[6], - time_restrictions=row[7], - max_access_count=row[8], - require_approval=bool(row[9]), - is_active=bool(row[10]), - created_at=row[11], - updated_at=row[12] - )) + policies.append( + DataAccessPolicy( + id=row[0], + project_id=row[1], + name=row[2], + description=row[3], + allowed_users=row[4], + allowed_roles=row[5], + allowed_ips=row[6], + time_restrictions=row[7], + max_access_count=row[8], + require_approval=bool(row[9]), + is_active=bool(row[10]), + created_at=row[11], + updated_at=row[12], + ) + ) return policies def check_access_permission( - self, - policy_id: str, - user_id: str, - user_ip: str | None = None + self, policy_id: str, user_id: str, user_ip: str | None = None ) -> tuple[bool, str | None]: """检查访问权限""" conn = sqlite3.connect(self.db_path) cursor = conn.cursor() - cursor.execute( - "SELECT * FROM data_access_policies WHERE id = ? AND is_active = 1", - (policy_id,) - ) + cursor.execute("SELECT * FROM data_access_policies WHERE id = ? AND is_active = 1", (policy_id,)) row = cursor.fetchone() conn.close() @@ -1034,7 +1009,7 @@ class SecurityManager: require_approval=bool(row[9]), is_active=bool(row[10]), created_at=row[11], - updated_at=row[12] + updated_at=row[12], ) # 检查用户白名单 @@ -1074,11 +1049,14 @@ class SecurityManager: conn = sqlite3.connect(self.db_path) cursor = conn.cursor() - cursor.execute(""" + cursor.execute( + """ SELECT * FROM access_requests WHERE policy_id = ? AND user_id = ? AND status = 'approved' AND (expires_at IS NULL OR expires_at > ?) - """, (policy_id, user_id, datetime.now().isoformat())) + """, + (policy_id, user_id, datetime.now().isoformat()), + ) request = cursor.fetchone() conn.close() @@ -1104,11 +1082,7 @@ class SecurityManager: return ip == pattern def create_access_request( - self, - policy_id: str, - user_id: str, - request_reason: str | None = None, - expires_hours: int = 24 + self, policy_id: str, user_id: str, request_reason: str | None = None, expires_hours: int = 24 ) -> AccessRequest: """创建访问请求""" request = AccessRequest( @@ -1116,21 +1090,28 @@ class SecurityManager: policy_id=policy_id, user_id=user_id, request_reason=request_reason, - expires_at=(datetime.now() + timedelta(hours=expires_hours)).isoformat() + expires_at=(datetime.now() + timedelta(hours=expires_hours)).isoformat(), ) conn = sqlite3.connect(self.db_path) cursor = conn.cursor() - cursor.execute(""" + cursor.execute( + """ INSERT INTO access_requests (id, policy_id, user_id, request_reason, status, expires_at, created_at) VALUES (?, ?, ?, ?, ?, ?, ?) - """, ( - request.id, request.policy_id, request.user_id, - request.request_reason, request.status, request.expires_at, - request.created_at - )) + """, + ( + request.id, + request.policy_id, + request.user_id, + request.request_reason, + request.status, + request.expires_at, + request.created_at, + ), + ) conn.commit() conn.close() @@ -1138,10 +1119,7 @@ class SecurityManager: return request def approve_access_request( - self, - request_id: str, - approved_by: str, - expires_hours: int = 24 + self, request_id: str, approved_by: str, expires_hours: int = 24 ) -> AccessRequest | None: """批准访问请求""" conn = sqlite3.connect(self.db_path) @@ -1150,11 +1128,14 @@ class SecurityManager: expires_at = (datetime.now() + timedelta(hours=expires_hours)).isoformat() approved_at = datetime.now().isoformat() - cursor.execute(""" + cursor.execute( + """ UPDATE access_requests SET status = 'approved', approved_by = ?, approved_at = ?, expires_at = ? WHERE id = ? - """, (approved_by, approved_at, expires_at, request_id)) + """, + (approved_by, approved_at, expires_at, request_id), + ) conn.commit() @@ -1175,23 +1156,22 @@ class SecurityManager: approved_by=row[5], approved_at=row[6], expires_at=row[7], - created_at=row[8] + created_at=row[8], ) - def reject_access_request( - self, - request_id: str, - rejected_by: str - ) -> AccessRequest | None: + def reject_access_request(self, request_id: str, rejected_by: str) -> AccessRequest | None: """拒绝访问请求""" conn = sqlite3.connect(self.db_path) cursor = conn.cursor() - cursor.execute(""" + cursor.execute( + """ UPDATE access_requests SET status = 'rejected', approved_by = ? WHERE id = ? - """, (rejected_by, request_id)) + """, + (rejected_by, request_id), + ) conn.commit() @@ -1211,7 +1191,7 @@ class SecurityManager: approved_by=row[5], approved_at=row[6], expires_at=row[7], - created_at=row[8] + created_at=row[8], ) diff --git a/backend/subscription_manager.py b/backend/subscription_manager.py index 1f0a80f..2aef34d 100644 --- a/backend/subscription_manager.py +++ b/backend/subscription_manager.py @@ -24,63 +24,69 @@ logger = logging.getLogger(__name__) class SubscriptionStatus(StrEnum): """订阅状态""" - ACTIVE = "active" # 活跃 - CANCELLED = "cancelled" # 已取消 - EXPIRED = "expired" # 已过期 - PAST_DUE = "past_due" # 逾期 - TRIAL = "trial" # 试用中 - PENDING = "pending" # 待支付 + + ACTIVE = "active" # 活跃 + CANCELLED = "cancelled" # 已取消 + EXPIRED = "expired" # 已过期 + PAST_DUE = "past_due" # 逾期 + TRIAL = "trial" # 试用中 + PENDING = "pending" # 待支付 class PaymentProvider(StrEnum): """支付提供商""" - STRIPE = "stripe" # Stripe - ALIPAY = "alipay" # 支付宝 - WECHAT = "wechat" # 微信支付 + + STRIPE = "stripe" # Stripe + ALIPAY = "alipay" # 支付宝 + WECHAT = "wechat" # 微信支付 BANK_TRANSFER = "bank_transfer" # 银行转账 class PaymentStatus(StrEnum): """支付状态""" - PENDING = "pending" # 待支付 - PROCESSING = "processing" # 处理中 - COMPLETED = "completed" # 已完成 - FAILED = "failed" # 失败 - REFUNDED = "refunded" # 已退款 + + PENDING = "pending" # 待支付 + PROCESSING = "processing" # 处理中 + COMPLETED = "completed" # 已完成 + FAILED = "failed" # 失败 + REFUNDED = "refunded" # 已退款 PARTIAL_REFUNDED = "partial_refunded" # 部分退款 class InvoiceStatus(StrEnum): """发票状态""" - DRAFT = "draft" # 草稿 - ISSUED = "issued" # 已开具 - PAID = "paid" # 已支付 - OVERDUE = "overdue" # 逾期 - VOID = "void" # 作废 + + DRAFT = "draft" # 草稿 + ISSUED = "issued" # 已开具 + PAID = "paid" # 已支付 + OVERDUE = "overdue" # 逾期 + VOID = "void" # 作废 CREDIT_NOTE = "credit_note" # 贷项通知单 class RefundStatus(StrEnum): """退款状态""" - PENDING = "pending" # 待处理 - APPROVED = "approved" # 已批准 - REJECTED = "rejected" # 已拒绝 - COMPLETED = "completed" # 已完成 - FAILED = "failed" # 失败 + + PENDING = "pending" # 待处理 + APPROVED = "approved" # 已批准 + REJECTED = "rejected" # 已拒绝 + COMPLETED = "completed" # 已完成 + FAILED = "failed" # 失败 @dataclass class SubscriptionPlan: """订阅计划数据类""" + id: str name: str - tier: str # free/pro/enterprise + tier: str # free/pro/enterprise description: str - price_monthly: float # 月付价格 - price_yearly: float # 年付价格 - currency: str # CNY/USD - features: list[str] # 功能列表 - limits: dict[str, Any] # 资源限制 + price_monthly: float # 月付价格 + price_yearly: float # 年付价格 + currency: str # CNY/USD + features: list[str] # 功能列表 + limits: dict[str, Any] # 资源限制 is_active: bool created_at: datetime updated_at: datetime @@ -90,6 +96,7 @@ class SubscriptionPlan: @dataclass class Subscription: """订阅数据类""" + id: str tenant_id: str plan_id: str @@ -110,13 +117,14 @@ class Subscription: @dataclass class UsageRecord: """用量记录数据类""" + id: str tenant_id: str - resource_type: str # transcription/storage/api_call - quantity: float # 使用量 - unit: str # minutes/mb/count + resource_type: str # transcription/storage/api_call + quantity: float # 使用量 + unit: str # minutes/mb/count recorded_at: datetime - cost: float # 费用 + cost: float # 费用 description: str | None metadata: dict[str, Any] @@ -124,6 +132,7 @@ class UsageRecord: @dataclass class Payment: """支付记录数据类""" + id: str tenant_id: str subscription_id: str | None @@ -145,6 +154,7 @@ class Payment: @dataclass class Invoice: """发票数据类""" + id: str tenant_id: str subscription_id: str | None @@ -168,6 +178,7 @@ class Invoice: @dataclass class Refund: """退款数据类""" + id: str tenant_id: str payment_id: str @@ -190,14 +201,15 @@ class Refund: @dataclass class BillingHistory: """账单历史数据类""" + id: str tenant_id: str - type: str # subscription/usage/payment/refund + type: str # subscription/usage/payment/refund amount: float currency: str description: str - reference_id: str # 关联的订阅/支付/退款ID - balance_after: float # 操作后余额 + reference_id: str # 关联的订阅/支付/退款ID + balance_after: float # 操作后余额 created_at: datetime metadata: dict[str, Any] @@ -214,21 +226,15 @@ class SubscriptionManager: "price_monthly": 0.0, "price_yearly": 0.0, "currency": "CNY", - "features": [ - "basic_analysis", - "export_png", - "3_projects", - "100_mb_storage", - "60_min_transcription" - ], + "features": ["basic_analysis", "export_png", "3_projects", "100_mb_storage", "60_min_transcription"], "limits": { "max_projects": 3, "max_storage_mb": 100, "max_transcription_minutes": 60, "max_api_calls_per_day": 100, "max_team_members": 2, - "max_entities": 100 - } + "max_entities": 100, + }, }, "pro": { "name": "Pro", @@ -246,7 +252,7 @@ class SubscriptionManager: "collaboration", "20_projects", "10_gb_storage", - "600_min_transcription" + "600_min_transcription", ], "limits": { "max_projects": 20, @@ -254,8 +260,8 @@ class SubscriptionManager: "max_transcription_minutes": 600, "max_api_calls_per_day": 10000, "max_team_members": 10, - "max_entities": 1000 - } + "max_entities": 1000, + }, }, "enterprise": { "name": "Enterprise", @@ -272,7 +278,7 @@ class SubscriptionManager: "priority_support", "custom_integration", "sla_guarantee", - "dedicated_manager" + "dedicated_manager", ], "limits": { "max_projects": -1, @@ -280,33 +286,17 @@ class SubscriptionManager: "max_transcription_minutes": -1, "max_api_calls_per_day": -1, "max_team_members": -1, - "max_entities": -1 - } - } + "max_entities": -1, + }, + }, } # 按量计费单价(CNY) USAGE_PRICING = { - "transcription": { - "unit": "minute", - "price": 0.5, # 0.5元/分钟 - "free_quota": 60 # 每月免费额度 - }, - "storage": { - "unit": "gb", - "price": 10.0, # 10元/GB/月 - "free_quota": 0.1 # 100MB免费 - }, - "api_call": { - "unit": "1000_calls", - "price": 5.0, # 5元/1000次 - "free_quota": 1000 # 每月免费1000次 - }, - "export": { - "unit": "page", - "price": 0.1, # 0.1元/页(PDF导出) - "free_quota": 100 - } + "transcription": {"unit": "minute", "price": 0.5, "free_quota": 60}, # 0.5元/分钟 # 每月免费额度 + "storage": {"unit": "gb", "price": 10.0, "free_quota": 0.1}, # 10元/GB/月 # 100MB免费 + "api_call": {"unit": "1000_calls", "price": 5.0, "free_quota": 1000}, # 5元/1000次 # 每月免费1000次 + "export": {"unit": "page", "price": 0.1, "free_quota": 100}, # 0.1元/页(PDF导出) } def __init__(self, db_path: str = "insightflow.db"): @@ -320,7 +310,7 @@ class SubscriptionManager: conn.row_factory = sqlite3.Row return conn - def _init_db(self): + def _init_db(self) -> None: """初始化数据库表""" conn = self._get_connection() try: @@ -504,33 +494,36 @@ class SubscriptionManager: finally: conn.close() - def _init_default_plans(self): + def _init_default_plans(self) -> None: """初始化默认订阅计划""" conn = self._get_connection() try: cursor = conn.cursor() for tier, plan_data in self.DEFAULT_PLANS.items(): - cursor.execute(""" + cursor.execute( + """ INSERT OR IGNORE INTO subscription_plans (id, name, tier, description, price_monthly, price_yearly, currency, features, limits, is_active, created_at, updated_at, metadata) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - """, ( - str(uuid.uuid4()), - plan_data["name"], - plan_data["tier"], - plan_data["description"], - plan_data["price_monthly"], - plan_data["price_yearly"], - plan_data["currency"], - json.dumps(plan_data["features"]), - json.dumps(plan_data["limits"]), - 1, - datetime.now(), - datetime.now(), - json.dumps({}) - )) + """, + ( + str(uuid.uuid4()), + plan_data["name"], + plan_data["tier"], + plan_data["description"], + plan_data["price_monthly"], + plan_data["price_yearly"], + plan_data["currency"], + json.dumps(plan_data["features"]), + json.dumps(plan_data["limits"]), + 1, + datetime.now(), + datetime.now(), + json.dumps({}), + ), + ) conn.commit() logger.info("Default subscription plans initialized") @@ -589,10 +582,17 @@ class SubscriptionManager: finally: conn.close() - def create_plan(self, name: str, tier: str, description: str, - price_monthly: float, price_yearly: float, - currency: str = "CNY", features: list[str] = None, - limits: dict[str, Any] = None) -> SubscriptionPlan: + def create_plan( + self, + name: str, + tier: str, + description: str, + price_monthly: float, + price_yearly: float, + currency: str = "CNY", + features: list[str] = None, + limits: dict[str, Any] = None, + ) -> SubscriptionPlan: """创建新订阅计划""" conn = self._get_connection() try: @@ -611,22 +611,33 @@ class SubscriptionManager: is_active=True, created_at=datetime.now(), updated_at=datetime.now(), - metadata={} + metadata={}, ) cursor = conn.cursor() - cursor.execute(""" + cursor.execute( + """ INSERT INTO subscription_plans (id, name, tier, description, price_monthly, price_yearly, currency, features, limits, is_active, created_at, updated_at, metadata) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - """, ( - plan.id, plan.name, plan.tier, plan.description, - plan.price_monthly, plan.price_yearly, plan.currency, - json.dumps(plan.features), json.dumps(plan.limits), - int(plan.is_active), plan.created_at, plan.updated_at, - json.dumps(plan.metadata) - )) + """, + ( + plan.id, + plan.name, + plan.tier, + plan.description, + plan.price_monthly, + plan.price_yearly, + plan.currency, + json.dumps(plan.features), + json.dumps(plan.limits), + int(plan.is_active), + plan.created_at, + plan.updated_at, + json.dumps(plan.metadata), + ), + ) conn.commit() logger.info(f"Subscription plan created: {plan_id} ({name})") @@ -650,15 +661,23 @@ class SubscriptionManager: updates = [] params = [] - allowed_fields = ['name', 'description', 'price_monthly', 'price_yearly', - 'currency', 'features', 'limits', 'is_active'] + allowed_fields = [ + "name", + "description", + "price_monthly", + "price_yearly", + "currency", + "features", + "limits", + "is_active", + ] for key, value in kwargs.items(): if key in allowed_fields: updates.append(f"{key} = ?") - if key in ['features', 'limits']: - params.append(json.dumps(value) if value else '{}') - elif key == 'is_active': + if key in ["features", "limits"]: + params.append(json.dumps(value) if value else "{}") + elif key == "is_active": params.append(int(value)) else: params.append(value) @@ -671,10 +690,13 @@ class SubscriptionManager: params.append(plan_id) cursor = conn.cursor() - cursor.execute(f""" + cursor.execute( + f""" UPDATE subscription_plans SET {', '.join(updates)} WHERE id = ? - """, params) + """, + params, + ) conn.commit() return self.get_plan(plan_id) @@ -684,19 +706,26 @@ class SubscriptionManager: # ==================== 订阅管理 ==================== - def create_subscription(self, tenant_id: str, plan_id: str, - payment_provider: str | None = None, - trial_days: int = 0, - billing_cycle: str = "monthly") -> Subscription: + def create_subscription( + self, + tenant_id: str, + plan_id: str, + payment_provider: str | None = None, + trial_days: int = 0, + billing_cycle: str = "monthly", + ) -> Subscription: """创建新订阅""" conn = self._get_connection() try: # 检查是否已有活跃订阅 cursor = conn.cursor() - cursor.execute(""" + cursor.execute( + """ SELECT * FROM subscriptions WHERE tenant_id = ? AND status IN ('active', 'trial', 'pending') - """, (tenant_id,)) + """, + (tenant_id,), + ) existing = cursor.fetchone() if existing: @@ -741,37 +770,60 @@ class SubscriptionManager: provider_subscription_id=None, created_at=now, updated_at=now, - metadata={"billing_cycle": billing_cycle} + metadata={"billing_cycle": billing_cycle}, ) - cursor.execute(""" + cursor.execute( + """ INSERT INTO subscriptions (id, tenant_id, plan_id, status, current_period_start, current_period_end, cancel_at_period_end, canceled_at, trial_start, trial_end, payment_provider, provider_subscription_id, created_at, updated_at, metadata) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - """, ( - subscription.id, subscription.tenant_id, subscription.plan_id, - subscription.status, subscription.current_period_start, - subscription.current_period_end, int(subscription.cancel_at_period_end), - subscription.canceled_at, subscription.trial_start, subscription.trial_end, - subscription.payment_provider, subscription.provider_subscription_id, - subscription.created_at, subscription.updated_at, - json.dumps(subscription.metadata) - )) + """, + ( + subscription.id, + subscription.tenant_id, + subscription.plan_id, + subscription.status, + subscription.current_period_start, + subscription.current_period_end, + int(subscription.cancel_at_period_end), + subscription.canceled_at, + subscription.trial_start, + subscription.trial_end, + subscription.payment_provider, + subscription.provider_subscription_id, + subscription.created_at, + subscription.updated_at, + json.dumps(subscription.metadata), + ), + ) # 创建发票 amount = plan.price_yearly if billing_cycle == "yearly" else plan.price_monthly if amount > 0 and trial_days == 0: self._create_invoice_internal( - conn, tenant_id, subscription_id, amount, plan.currency, - now, period_end, f"{plan.name} Subscription ({billing_cycle})" + conn, + tenant_id, + subscription_id, + amount, + plan.currency, + now, + period_end, + f"{plan.name} Subscription ({billing_cycle})", ) # 记录账单历史 self._add_billing_history_internal( - conn, tenant_id, "subscription", 0, plan.currency, - f"Subscription created: {plan.name}", subscription_id, 0 + conn, + tenant_id, + "subscription", + 0, + plan.currency, + f"Subscription created: {plan.name}", + subscription_id, + 0, ) conn.commit() @@ -805,11 +857,14 @@ class SubscriptionManager: conn = self._get_connection() try: cursor = conn.cursor() - cursor.execute(""" + cursor.execute( + """ SELECT * FROM subscriptions WHERE tenant_id = ? AND status IN ('active', 'trial', 'past_due', 'pending') ORDER BY created_at DESC LIMIT 1 - """, (tenant_id,)) + """, + (tenant_id,), + ) row = cursor.fetchone() if row: @@ -830,14 +885,21 @@ class SubscriptionManager: updates = [] params = [] - allowed_fields = ['status', 'current_period_start', 'current_period_end', - 'cancel_at_period_end', 'canceled_at', 'trial_end', - 'payment_provider', 'provider_subscription_id'] + allowed_fields = [ + "status", + "current_period_start", + "current_period_end", + "cancel_at_period_end", + "canceled_at", + "trial_end", + "payment_provider", + "provider_subscription_id", + ] for key, value in kwargs.items(): if key in allowed_fields: updates.append(f"{key} = ?") - if key == 'cancel_at_period_end': + if key == "cancel_at_period_end": params.append(int(value)) else: params.append(value) @@ -850,10 +912,13 @@ class SubscriptionManager: params.append(subscription_id) cursor = conn.cursor() - cursor.execute(f""" + cursor.execute( + f""" UPDATE subscriptions SET {', '.join(updates)} WHERE id = ? - """, params) + """, + params, + ) conn.commit() return self.get_subscription(subscription_id) @@ -861,8 +926,7 @@ class SubscriptionManager: finally: conn.close() - def cancel_subscription(self, subscription_id: str, - at_period_end: bool = True) -> Subscription | None: + def cancel_subscription(self, subscription_id: str, at_period_end: bool = True) -> Subscription | None: """取消订阅""" conn = self._get_connection() try: @@ -875,25 +939,36 @@ class SubscriptionManager: if at_period_end: # 在周期结束时取消 cursor = conn.cursor() - cursor.execute(""" + cursor.execute( + """ UPDATE subscriptions SET cancel_at_period_end = 1, canceled_at = ?, updated_at = ? WHERE id = ? - """, (now, now, subscription_id)) + """, + (now, now, subscription_id), + ) else: # 立即取消 cursor = conn.cursor() - cursor.execute(""" + cursor.execute( + """ UPDATE subscriptions SET status = 'cancelled', canceled_at = ?, updated_at = ? WHERE id = ? - """, (now, now, subscription_id)) + """, + (now, now, subscription_id), + ) # 记录账单历史 self._add_billing_history_internal( - conn, subscription.tenant_id, "subscription", 0, "CNY", + conn, + subscription.tenant_id, + "subscription", + 0, + "CNY", f"Subscription cancelled{' (at period end)' if at_period_end else ''}", - subscription_id, 0 + subscription_id, + 0, ) conn.commit() @@ -903,8 +978,7 @@ class SubscriptionManager: finally: conn.close() - def change_plan(self, subscription_id: str, new_plan_id: str, - prorate: bool = True) -> Subscription | None: + def change_plan(self, subscription_id: str, new_plan_id: str, prorate: bool = True) -> Subscription | None: """更改订阅计划""" conn = self._get_connection() try: @@ -926,17 +1000,25 @@ class SubscriptionManager: pass cursor = conn.cursor() - cursor.execute(""" + cursor.execute( + """ UPDATE subscriptions SET plan_id = ?, updated_at = ? WHERE id = ? - """, (new_plan_id, now, subscription_id)) + """, + (new_plan_id, now, subscription_id), + ) # 记录账单历史 self._add_billing_history_internal( - conn, subscription.tenant_id, "subscription", 0, new_plan.currency, + conn, + subscription.tenant_id, + "subscription", + 0, + new_plan.currency, f"Plan changed from {old_plan.name if old_plan else 'unknown'} to {new_plan.name}", - subscription_id, 0 + subscription_id, + 0, ) conn.commit() @@ -948,10 +1030,15 @@ class SubscriptionManager: # ==================== 用量计费 ==================== - def record_usage(self, tenant_id: str, resource_type: str, - quantity: float, unit: str, - description: str | None = None, - metadata: dict | None = None) -> UsageRecord: + def record_usage( + self, + tenant_id: str, + resource_type: str, + quantity: float, + unit: str, + description: str | None = None, + metadata: dict | None = None, + ) -> UsageRecord: """记录用量""" conn = self._get_connection() try: @@ -968,19 +1055,28 @@ class SubscriptionManager: recorded_at=datetime.now(), cost=cost, description=description, - metadata=metadata or {} + metadata=metadata or {}, ) cursor = conn.cursor() - cursor.execute(""" + cursor.execute( + """ INSERT INTO usage_records (id, tenant_id, resource_type, quantity, unit, recorded_at, cost, description, metadata) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) - """, ( - record.id, record.tenant_id, record.resource_type, - record.quantity, record.unit, record.recorded_at, - record.cost, record.description, json.dumps(record.metadata) - )) + """, + ( + record.id, + record.tenant_id, + record.resource_type, + record.quantity, + record.unit, + record.recorded_at, + record.cost, + record.description, + json.dumps(record.metadata), + ), + ) conn.commit() return record @@ -988,9 +1084,9 @@ class SubscriptionManager: finally: conn.close() - def get_usage_summary(self, tenant_id: str, - start_date: datetime | None = None, - end_date: datetime | None = None) -> dict[str, Any]: + def get_usage_summary( + self, tenant_id: str, start_date: datetime | None = None, end_date: datetime | None = None + ) -> dict[str, Any]: """获取用量汇总""" conn = self._get_connection() try: @@ -1023,21 +1119,21 @@ class SubscriptionManager: total_cost = 0 for row in rows: - summary[row['resource_type']] = { - "quantity": row['total_quantity'], - "cost": row['total_cost'], - "records": row['record_count'] + summary[row["resource_type"]] = { + "quantity": row["total_quantity"], + "cost": row["total_cost"], + "records": row["record_count"], } - total_cost += row['total_cost'] + total_cost += row["total_cost"] return { "tenant_id": tenant_id, "period": { "start": start_date.isoformat() if start_date else None, - "end": end_date.isoformat() if end_date else None + "end": end_date.isoformat() if end_date else None, }, "breakdown": summary, - "total_cost": total_cost + "total_cost": total_cost, } finally: @@ -1060,11 +1156,17 @@ class SubscriptionManager: # ==================== 支付管理 ==================== - def create_payment(self, tenant_id: str, amount: float, currency: str, - provider: str, subscription_id: str | None = None, - invoice_id: str | None = None, - payment_method: str | None = None, - payment_details: dict | None = None) -> Payment: + def create_payment( + self, + tenant_id: str, + amount: float, + currency: str, + provider: str, + subscription_id: str | None = None, + invoice_id: str | None = None, + payment_method: str | None = None, + payment_details: dict | None = None, + ) -> Payment: """创建支付记录""" conn = self._get_connection() try: @@ -1087,24 +1189,37 @@ class SubscriptionManager: failed_at=None, failure_reason=None, created_at=now, - updated_at=now + updated_at=now, ) cursor = conn.cursor() - cursor.execute(""" + cursor.execute( + """ INSERT INTO payments (id, tenant_id, subscription_id, invoice_id, amount, currency, provider, provider_payment_id, status, payment_method, payment_details, paid_at, failed_at, failure_reason, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - """, ( - payment.id, payment.tenant_id, payment.subscription_id, - payment.invoice_id, payment.amount, payment.currency, - payment.provider, payment.provider_payment_id, payment.status, - payment.payment_method, json.dumps(payment.payment_details), - payment.paid_at, payment.failed_at, payment.failure_reason, - payment.created_at, payment.updated_at - )) + """, + ( + payment.id, + payment.tenant_id, + payment.subscription_id, + payment.invoice_id, + payment.amount, + payment.currency, + payment.provider, + payment.provider_payment_id, + payment.status, + payment.payment_method, + json.dumps(payment.payment_details), + payment.paid_at, + payment.failed_at, + payment.failure_reason, + payment.created_at, + payment.updated_at, + ), + ) conn.commit() return payment @@ -1112,8 +1227,7 @@ class SubscriptionManager: finally: conn.close() - def confirm_payment(self, payment_id: str, - provider_payment_id: str | None = None) -> Payment | None: + def confirm_payment(self, payment_id: str, provider_payment_id: str | None = None) -> Payment | None: """确认支付完成""" conn = self._get_connection() try: @@ -1124,33 +1238,47 @@ class SubscriptionManager: now = datetime.now() cursor = conn.cursor() - cursor.execute(""" + cursor.execute( + """ UPDATE payments SET status = 'completed', provider_payment_id = ?, paid_at = ?, updated_at = ? WHERE id = ? - """, (provider_payment_id, now, now, payment_id)) + """, + (provider_payment_id, now, now, payment_id), + ) # 如果有关联发票,更新发票状态 if payment.invoice_id: - cursor.execute(""" + cursor.execute( + """ UPDATE invoices SET status = 'paid', amount_paid = amount_due, paid_at = ? WHERE id = ? - """, (now, payment.invoice_id)) + """, + (now, payment.invoice_id), + ) # 如果有关联订阅,激活订阅 if payment.subscription_id: - cursor.execute(""" + cursor.execute( + """ UPDATE subscriptions SET status = 'active', updated_at = ? WHERE id = ? AND status = 'pending' - """, (now, payment.subscription_id)) + """, + (now, payment.subscription_id), + ) # 记录账单历史 self._add_billing_history_internal( - conn, payment.tenant_id, "payment", payment.amount, - payment.currency, f"Payment completed via {payment.provider}", - payment_id, 0 # 余额更新应该在账户管理中处理 + conn, + payment.tenant_id, + "payment", + payment.amount, + payment.currency, + f"Payment completed via {payment.provider}", + payment_id, + 0, # 余额更新应该在账户管理中处理 ) conn.commit() @@ -1167,11 +1295,14 @@ class SubscriptionManager: now = datetime.now() cursor = conn.cursor() - cursor.execute(""" + cursor.execute( + """ UPDATE payments SET status = 'failed', failure_reason = ?, failed_at = ?, updated_at = ? WHERE id = ? - """, (reason, now, now, payment_id)) + """, + (reason, now, now, payment_id), + ) conn.commit() return self._get_payment_internal(conn, payment_id) @@ -1187,8 +1318,9 @@ class SubscriptionManager: finally: conn.close() - def list_payments(self, tenant_id: str, status: str | None = None, - limit: int = 100, offset: int = 0) -> list[Payment]: + def list_payments( + self, tenant_id: str, status: str | None = None, limit: int = 100, offset: int = 0 + ) -> list[Payment]: """列出支付记录""" conn = self._get_connection() try: @@ -1224,11 +1356,18 @@ class SubscriptionManager: # ==================== 发票管理 ==================== - def _create_invoice_internal(self, conn: sqlite3.Connection, tenant_id: str, - subscription_id: str | None, amount: float, - currency: str, period_start: datetime, - period_end: datetime, description: str, - line_items: list[dict] | None = None) -> Invoice: + def _create_invoice_internal( + self, + conn: sqlite3.Connection, + tenant_id: str, + subscription_id: str | None, + amount: float, + currency: str, + period_start: datetime, + period_end: datetime, + description: str, + line_items: list[dict] | None = None, + ) -> Invoice: """内部方法:创建发票""" invoice_id = str(uuid.uuid4()) invoice_number = self._generate_invoice_number() @@ -1253,25 +1392,39 @@ class SubscriptionManager: voided_at=None, void_reason=None, created_at=now, - updated_at=now + updated_at=now, ) cursor = conn.cursor() - cursor.execute(""" + cursor.execute( + """ INSERT INTO invoices (id, tenant_id, subscription_id, invoice_number, status, amount_due, amount_paid, currency, period_start, period_end, description, line_items, due_date, paid_at, voided_at, void_reason, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - """, ( - invoice.id, invoice.tenant_id, invoice.subscription_id, - invoice.invoice_number, invoice.status, invoice.amount_due, - invoice.amount_paid, invoice.currency, invoice.period_start, - invoice.period_end, invoice.description, - json.dumps(invoice.line_items), invoice.due_date, - invoice.paid_at, invoice.voided_at, invoice.void_reason, - invoice.created_at, invoice.updated_at - )) + """, + ( + invoice.id, + invoice.tenant_id, + invoice.subscription_id, + invoice.invoice_number, + invoice.status, + invoice.amount_due, + invoice.amount_paid, + invoice.currency, + invoice.period_start, + invoice.period_end, + invoice.description, + json.dumps(invoice.line_items), + invoice.due_date, + invoice.paid_at, + invoice.voided_at, + invoice.void_reason, + invoice.created_at, + invoice.updated_at, + ), + ) return invoice @@ -1305,8 +1458,9 @@ class SubscriptionManager: finally: conn.close() - def list_invoices(self, tenant_id: str, status: str | None = None, - limit: int = 100, offset: int = 0) -> list[Invoice]: + def list_invoices( + self, tenant_id: str, status: str | None = None, limit: int = 100, offset: int = 0 + ) -> list[Invoice]: """列出发票""" conn = self._get_connection() try: @@ -1344,11 +1498,14 @@ class SubscriptionManager: now = datetime.now() cursor = conn.cursor() - cursor.execute(""" + cursor.execute( + """ UPDATE invoices SET status = 'void', voided_at = ?, void_reason = ?, updated_at = ? WHERE id = ? - """, (now, reason, now, invoice_id)) + """, + (now, reason, now, invoice_id), + ) conn.commit() return self.get_invoice(invoice_id) @@ -1364,12 +1521,15 @@ class SubscriptionManager: conn = self._get_connection() try: cursor = conn.cursor() - cursor.execute(""" + cursor.execute( + """ SELECT COUNT(*) as count FROM invoices WHERE invoice_number LIKE ? - """, (f"{prefix}%",)) + """, + (f"{prefix}%",), + ) row = cursor.fetchone() - count = row['count'] + 1 + count = row["count"] + 1 return f"{prefix}-{count:06d}" @@ -1378,8 +1538,7 @@ class SubscriptionManager: # ==================== 退款管理 ==================== - def request_refund(self, tenant_id: str, payment_id: str, amount: float, - reason: str, requested_by: str) -> Refund: + def request_refund(self, tenant_id: str, payment_id: str, amount: float, reason: str, requested_by: str) -> Refund: """申请退款""" conn = self._get_connection() try: @@ -1417,23 +1576,38 @@ class SubscriptionManager: provider_refund_id=None, metadata={}, created_at=now, - updated_at=now + updated_at=now, ) cursor = conn.cursor() - cursor.execute(""" + cursor.execute( + """ INSERT INTO refunds (id, tenant_id, payment_id, invoice_id, amount, currency, reason, status, requested_by, requested_at, approved_by, approved_at, completed_at, provider_refund_id, metadata, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - """, ( - refund.id, refund.tenant_id, refund.payment_id, refund.invoice_id, - refund.amount, refund.currency, refund.reason, refund.status, - refund.requested_by, refund.requested_at, refund.approved_by, - refund.approved_at, refund.completed_at, refund.provider_refund_id, - json.dumps(refund.metadata), refund.created_at, refund.updated_at - )) + """, + ( + refund.id, + refund.tenant_id, + refund.payment_id, + refund.invoice_id, + refund.amount, + refund.currency, + refund.reason, + refund.status, + refund.requested_by, + refund.requested_at, + refund.approved_by, + refund.approved_at, + refund.completed_at, + refund.provider_refund_id, + json.dumps(refund.metadata), + refund.created_at, + refund.updated_at, + ), + ) conn.commit() logger.info(f"Refund requested: {refund_id} for payment {payment_id}") @@ -1456,11 +1630,14 @@ class SubscriptionManager: now = datetime.now() cursor = conn.cursor() - cursor.execute(""" + cursor.execute( + """ UPDATE refunds SET status = 'approved', approved_by = ?, approved_at = ?, updated_at = ? WHERE id = ? - """, (approved_by, now, now, refund_id)) + """, + (approved_by, now, now, refund_id), + ) conn.commit() return self._get_refund_internal(conn, refund_id) @@ -1468,8 +1645,7 @@ class SubscriptionManager: finally: conn.close() - def complete_refund(self, refund_id: str, - provider_refund_id: str | None = None) -> Refund | None: + def complete_refund(self, refund_id: str, provider_refund_id: str | None = None) -> Refund | None: """完成退款""" conn = self._get_connection() try: @@ -1480,24 +1656,35 @@ class SubscriptionManager: now = datetime.now() cursor = conn.cursor() - cursor.execute(""" + cursor.execute( + """ UPDATE refunds SET status = 'completed', provider_refund_id = ?, completed_at = ?, updated_at = ? WHERE id = ? - """, (provider_refund_id, now, now, refund_id)) + """, + (provider_refund_id, now, now, refund_id), + ) # 更新原支付记录状态 - cursor.execute(""" + cursor.execute( + """ UPDATE payments SET status = 'refunded', updated_at = ? WHERE id = ? - """, (now, refund.payment_id)) + """, + (now, refund.payment_id), + ) # 记录账单历史 self._add_billing_history_internal( - conn, refund.tenant_id, "refund", -refund.amount, - refund.currency, f"Refund processed: {refund.reason}", - refund_id, 0 + conn, + refund.tenant_id, + "refund", + -refund.amount, + refund.currency, + f"Refund processed: {refund.reason}", + refund_id, + 0, ) conn.commit() @@ -1518,11 +1705,14 @@ class SubscriptionManager: now = datetime.now() cursor = conn.cursor() - cursor.execute(""" + cursor.execute( + """ UPDATE refunds SET status = 'rejected', metadata = json_set(metadata, '$.rejection_reason', ?), updated_at = ? WHERE id = ? - """, (reason, now, refund_id)) + """, + (reason, now, refund_id), + ) conn.commit() return self._get_refund_internal(conn, refund_id) @@ -1538,8 +1728,9 @@ class SubscriptionManager: finally: conn.close() - def list_refunds(self, tenant_id: str, status: str | None = None, - limit: int = 100, offset: int = 0) -> list[Refund]: + def list_refunds( + self, tenant_id: str, status: str | None = None, limit: int = 100, offset: int = 0 + ) -> list[Refund]: """列出退款记录""" conn = self._get_connection() try: @@ -1575,27 +1766,49 @@ class SubscriptionManager: # ==================== 账单历史 ==================== - def _add_billing_history_internal(self, conn: sqlite3.Connection, - tenant_id: str, type: str, amount: float, - currency: str, description: str, - reference_id: str, balance_after: float): + def _add_billing_history_internal( + self, + conn: sqlite3.Connection, + tenant_id: str, + type: str, + amount: float, + currency: str, + description: str, + reference_id: str, + balance_after: float, + ): """内部方法:添加账单历史""" history_id = str(uuid.uuid4()) cursor = conn.cursor() - cursor.execute(""" + cursor.execute( + """ INSERT INTO billing_history (id, tenant_id, type, amount, currency, description, reference_id, balance_after, created_at, metadata) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - """, ( - history_id, tenant_id, type, amount, currency, - description, reference_id, balance_after, datetime.now(), json.dumps({}) - )) + """, + ( + history_id, + tenant_id, + type, + amount, + currency, + description, + reference_id, + balance_after, + datetime.now(), + json.dumps({}), + ), + ) - def get_billing_history(self, tenant_id: str, - start_date: datetime | None = None, - end_date: datetime | None = None, - limit: int = 100, offset: int = 0) -> list[BillingHistory]: + def get_billing_history( + self, + tenant_id: str, + start_date: datetime | None = None, + end_date: datetime | None = None, + limit: int = 100, + offset: int = 0, + ) -> list[BillingHistory]: """获取账单历史""" conn = self._get_connection() try: @@ -1624,9 +1837,9 @@ class SubscriptionManager: # ==================== 支付提供商集成 ==================== - def create_stripe_checkout_session(self, tenant_id: str, plan_id: str, - success_url: str, cancel_url: str, - billing_cycle: str = "monthly") -> dict[str, Any]: + def create_stripe_checkout_session( + self, tenant_id: str, plan_id: str, success_url: str, cancel_url: str, billing_cycle: str = "monthly" + ) -> dict[str, Any]: """创建 Stripe Checkout 会话(占位实现)""" # 这里应该集成 Stripe SDK # 简化实现,返回模拟数据 @@ -1634,11 +1847,10 @@ class SubscriptionManager: "session_id": f"cs_{uuid.uuid4().hex[:24]}", "url": f"https://checkout.stripe.com/mock/{uuid.uuid4().hex[:24]}", "status": "created", - "provider": "stripe" + "provider": "stripe", } - def create_alipay_order(self, tenant_id: str, plan_id: str, - billing_cycle: str = "monthly") -> dict[str, Any]: + def create_alipay_order(self, tenant_id: str, plan_id: str, billing_cycle: str = "monthly") -> dict[str, Any]: """创建支付宝订单(占位实现)""" # 这里应该集成支付宝 SDK plan = self.get_plan(plan_id) @@ -1650,11 +1862,10 @@ class SubscriptionManager: "currency": plan.currency, "qr_code_url": f"https://qr.alipay.com/mock/{uuid.uuid4().hex[:16]}", "status": "pending", - "provider": "alipay" + "provider": "alipay", } - def create_wechat_order(self, tenant_id: str, plan_id: str, - billing_cycle: str = "monthly") -> dict[str, Any]: + def create_wechat_order(self, tenant_id: str, plan_id: str, billing_cycle: str = "monthly") -> dict[str, Any]: """创建微信支付订单(占位实现)""" # 这里应该集成微信支付 SDK plan = self.get_plan(plan_id) @@ -1666,7 +1877,7 @@ class SubscriptionManager: "currency": plan.currency, "prepay_id": f"wx{uuid.uuid4().hex[:32]}", "status": "pending", - "provider": "wechat" + "provider": "wechat", } def handle_webhook(self, provider: str, payload: dict[str, Any]) -> bool: @@ -1696,220 +1907,221 @@ class SubscriptionManager: def _row_to_plan(self, row: sqlite3.Row) -> SubscriptionPlan: """数据库行转换为 SubscriptionPlan 对象""" return SubscriptionPlan( - id=row['id'], - name=row['name'], - tier=row['tier'], - description=row['description'] or "", - price_monthly=row['price_monthly'], - price_yearly=row['price_yearly'], - currency=row['currency'], - features=json.loads( - row['features'] or '[]'), - limits=json.loads( - row['limits'] or '{}'), - is_active=bool( - row['is_active']), - created_at=datetime.fromisoformat( - row['created_at']) if isinstance( - row['created_at'], - str) else row['created_at'], - updated_at=datetime.fromisoformat( - row['updated_at']) if isinstance( - row['updated_at'], - str) else row['updated_at'], - metadata=json.loads( - row['metadata'] or '{}')) + id=row["id"], + name=row["name"], + tier=row["tier"], + description=row["description"] or "", + price_monthly=row["price_monthly"], + price_yearly=row["price_yearly"], + currency=row["currency"], + features=json.loads(row["features"] or "[]"), + limits=json.loads(row["limits"] or "{}"), + is_active=bool(row["is_active"]), + created_at=( + datetime.fromisoformat(row["created_at"]) if isinstance(row["created_at"], str) else row["created_at"] + ), + updated_at=( + datetime.fromisoformat(row["updated_at"]) if isinstance(row["updated_at"], str) else row["updated_at"] + ), + metadata=json.loads(row["metadata"] or "{}"), + ) def _row_to_subscription(self, row: sqlite3.Row) -> Subscription: """数据库行转换为 Subscription 对象""" return Subscription( - id=row['id'], - tenant_id=row['tenant_id'], - plan_id=row['plan_id'], - status=row['status'], - current_period_start=datetime.fromisoformat( - row['current_period_start']) if row['current_period_start'] and isinstance( - row['current_period_start'], - str) else row['current_period_start'], - current_period_end=datetime.fromisoformat( - row['current_period_end']) if row['current_period_end'] and isinstance( - row['current_period_end'], - str) else row['current_period_end'], - cancel_at_period_end=bool( - row['cancel_at_period_end']), - canceled_at=datetime.fromisoformat( - row['canceled_at']) if row['canceled_at'] and isinstance( - row['canceled_at'], - str) else row['canceled_at'], - trial_start=datetime.fromisoformat( - row['trial_start']) if row['trial_start'] and isinstance( - row['trial_start'], - str) else row['trial_start'], - trial_end=datetime.fromisoformat( - row['trial_end']) if row['trial_end'] and isinstance( - row['trial_end'], - str) else row['trial_end'], - payment_provider=row['payment_provider'], - provider_subscription_id=row['provider_subscription_id'], - created_at=datetime.fromisoformat( - row['created_at']) if isinstance( - row['created_at'], - str) else row['created_at'], - updated_at=datetime.fromisoformat( - row['updated_at']) if isinstance( - row['updated_at'], - str) else row['updated_at'], - metadata=json.loads( - row['metadata'] or '{}')) + id=row["id"], + tenant_id=row["tenant_id"], + plan_id=row["plan_id"], + status=row["status"], + current_period_start=( + datetime.fromisoformat(row["current_period_start"]) + if row["current_period_start"] and isinstance(row["current_period_start"], str) + else row["current_period_start"] + ), + current_period_end=( + datetime.fromisoformat(row["current_period_end"]) + if row["current_period_end"] and isinstance(row["current_period_end"], str) + else row["current_period_end"] + ), + cancel_at_period_end=bool(row["cancel_at_period_end"]), + canceled_at=( + datetime.fromisoformat(row["canceled_at"]) + if row["canceled_at"] and isinstance(row["canceled_at"], str) + else row["canceled_at"] + ), + trial_start=( + datetime.fromisoformat(row["trial_start"]) + if row["trial_start"] and isinstance(row["trial_start"], str) + else row["trial_start"] + ), + trial_end=( + datetime.fromisoformat(row["trial_end"]) + if row["trial_end"] and isinstance(row["trial_end"], str) + else row["trial_end"] + ), + payment_provider=row["payment_provider"], + provider_subscription_id=row["provider_subscription_id"], + created_at=( + datetime.fromisoformat(row["created_at"]) if isinstance(row["created_at"], str) else row["created_at"] + ), + updated_at=( + datetime.fromisoformat(row["updated_at"]) if isinstance(row["updated_at"], str) else row["updated_at"] + ), + metadata=json.loads(row["metadata"] or "{}"), + ) def _row_to_usage(self, row: sqlite3.Row) -> UsageRecord: """数据库行转换为 UsageRecord 对象""" return UsageRecord( - id=row['id'], - tenant_id=row['tenant_id'], - resource_type=row['resource_type'], - quantity=row['quantity'], - unit=row['unit'], - recorded_at=datetime.fromisoformat( - row['recorded_at']) if isinstance( - row['recorded_at'], - str) else row['recorded_at'], - cost=row['cost'], - description=row['description'], - metadata=json.loads( - row['metadata'] or '{}')) + id=row["id"], + tenant_id=row["tenant_id"], + resource_type=row["resource_type"], + quantity=row["quantity"], + unit=row["unit"], + recorded_at=( + datetime.fromisoformat(row["recorded_at"]) + if isinstance(row["recorded_at"], str) + else row["recorded_at"] + ), + cost=row["cost"], + description=row["description"], + metadata=json.loads(row["metadata"] or "{}"), + ) def _row_to_payment(self, row: sqlite3.Row) -> Payment: """数据库行转换为 Payment 对象""" return Payment( - id=row['id'], - tenant_id=row['tenant_id'], - subscription_id=row['subscription_id'], - invoice_id=row['invoice_id'], - amount=row['amount'], - currency=row['currency'], - provider=row['provider'], - provider_payment_id=row['provider_payment_id'], - status=row['status'], - payment_method=row['payment_method'], - payment_details=json.loads( - row['payment_details'] or '{}'), - paid_at=datetime.fromisoformat( - row['paid_at']) if row['paid_at'] and isinstance( - row['paid_at'], - str) else row['paid_at'], - failed_at=datetime.fromisoformat( - row['failed_at']) if row['failed_at'] and isinstance( - row['failed_at'], - str) else row['failed_at'], - failure_reason=row['failure_reason'], - created_at=datetime.fromisoformat( - row['created_at']) if isinstance( - row['created_at'], - str) else row['created_at'], - updated_at=datetime.fromisoformat( - row['updated_at']) if isinstance( - row['updated_at'], - str) else row['updated_at']) + id=row["id"], + tenant_id=row["tenant_id"], + subscription_id=row["subscription_id"], + invoice_id=row["invoice_id"], + amount=row["amount"], + currency=row["currency"], + provider=row["provider"], + provider_payment_id=row["provider_payment_id"], + status=row["status"], + payment_method=row["payment_method"], + payment_details=json.loads(row["payment_details"] or "{}"), + paid_at=( + datetime.fromisoformat(row["paid_at"]) + if row["paid_at"] and isinstance(row["paid_at"], str) + else row["paid_at"] + ), + failed_at=( + datetime.fromisoformat(row["failed_at"]) + if row["failed_at"] and isinstance(row["failed_at"], str) + else row["failed_at"] + ), + failure_reason=row["failure_reason"], + created_at=( + datetime.fromisoformat(row["created_at"]) if isinstance(row["created_at"], str) else row["created_at"] + ), + updated_at=( + datetime.fromisoformat(row["updated_at"]) if isinstance(row["updated_at"], str) else row["updated_at"] + ), + ) def _row_to_invoice(self, row: sqlite3.Row) -> Invoice: """数据库行转换为 Invoice 对象""" return Invoice( - id=row['id'], - tenant_id=row['tenant_id'], - subscription_id=row['subscription_id'], - invoice_number=row['invoice_number'], - status=row['status'], - amount_due=row['amount_due'], - amount_paid=row['amount_paid'], - currency=row['currency'], - period_start=datetime.fromisoformat( - row['period_start']) if row['period_start'] and isinstance( - row['period_start'], - str) else row['period_start'], - period_end=datetime.fromisoformat( - row['period_end']) if row['period_end'] and isinstance( - row['period_end'], - str) else row['period_end'], - description=row['description'], - line_items=json.loads( - row['line_items'] or '[]'), - due_date=datetime.fromisoformat( - row['due_date']) if row['due_date'] and isinstance( - row['due_date'], - str) else row['due_date'], - paid_at=datetime.fromisoformat( - row['paid_at']) if row['paid_at'] and isinstance( - row['paid_at'], - str) else row['paid_at'], - voided_at=datetime.fromisoformat( - row['voided_at']) if row['voided_at'] and isinstance( - row['voided_at'], - str) else row['voided_at'], - void_reason=row['void_reason'], - created_at=datetime.fromisoformat( - row['created_at']) if isinstance( - row['created_at'], - str) else row['created_at'], - updated_at=datetime.fromisoformat( - row['updated_at']) if isinstance( - row['updated_at'], - str) else row['updated_at']) + id=row["id"], + tenant_id=row["tenant_id"], + subscription_id=row["subscription_id"], + invoice_number=row["invoice_number"], + status=row["status"], + amount_due=row["amount_due"], + amount_paid=row["amount_paid"], + currency=row["currency"], + period_start=( + datetime.fromisoformat(row["period_start"]) + if row["period_start"] and isinstance(row["period_start"], str) + else row["period_start"] + ), + period_end=( + datetime.fromisoformat(row["period_end"]) + if row["period_end"] and isinstance(row["period_end"], str) + else row["period_end"] + ), + description=row["description"], + line_items=json.loads(row["line_items"] or "[]"), + due_date=( + datetime.fromisoformat(row["due_date"]) + if row["due_date"] and isinstance(row["due_date"], str) + else row["due_date"] + ), + paid_at=( + datetime.fromisoformat(row["paid_at"]) + if row["paid_at"] and isinstance(row["paid_at"], str) + else row["paid_at"] + ), + voided_at=( + datetime.fromisoformat(row["voided_at"]) + if row["voided_at"] and isinstance(row["voided_at"], str) + else row["voided_at"] + ), + void_reason=row["void_reason"], + created_at=( + datetime.fromisoformat(row["created_at"]) if isinstance(row["created_at"], str) else row["created_at"] + ), + updated_at=( + datetime.fromisoformat(row["updated_at"]) if isinstance(row["updated_at"], str) else row["updated_at"] + ), + ) def _row_to_refund(self, row: sqlite3.Row) -> Refund: """数据库行转换为 Refund 对象""" return Refund( - id=row['id'], - tenant_id=row['tenant_id'], - payment_id=row['payment_id'], - invoice_id=row['invoice_id'], - amount=row['amount'], - currency=row['currency'], - reason=row['reason'], - status=row['status'], - requested_by=row['requested_by'], - requested_at=datetime.fromisoformat( - row['requested_at']) if isinstance( - row['requested_at'], - str) else row['requested_at'], - approved_by=row['approved_by'], - approved_at=datetime.fromisoformat( - row['approved_at']) if row['approved_at'] and isinstance( - row['approved_at'], - str) else row['approved_at'], - completed_at=datetime.fromisoformat( - row['completed_at']) if row['completed_at'] and isinstance( - row['completed_at'], - str) else row['completed_at'], - provider_refund_id=row['provider_refund_id'], - metadata=json.loads( - row['metadata'] or '{}'), - created_at=datetime.fromisoformat( - row['created_at']) if isinstance( - row['created_at'], - str) else row['created_at'], - updated_at=datetime.fromisoformat( - row['updated_at']) if isinstance( - row['updated_at'], - str) else row['updated_at']) + id=row["id"], + tenant_id=row["tenant_id"], + payment_id=row["payment_id"], + invoice_id=row["invoice_id"], + amount=row["amount"], + currency=row["currency"], + reason=row["reason"], + status=row["status"], + requested_by=row["requested_by"], + requested_at=( + datetime.fromisoformat(row["requested_at"]) + if isinstance(row["requested_at"], str) + else row["requested_at"] + ), + approved_by=row["approved_by"], + approved_at=( + datetime.fromisoformat(row["approved_at"]) + if row["approved_at"] and isinstance(row["approved_at"], str) + else row["approved_at"] + ), + completed_at=( + datetime.fromisoformat(row["completed_at"]) + if row["completed_at"] and isinstance(row["completed_at"], str) + else row["completed_at"] + ), + provider_refund_id=row["provider_refund_id"], + metadata=json.loads(row["metadata"] or "{}"), + created_at=( + datetime.fromisoformat(row["created_at"]) if isinstance(row["created_at"], str) else row["created_at"] + ), + updated_at=( + datetime.fromisoformat(row["updated_at"]) if isinstance(row["updated_at"], str) else row["updated_at"] + ), + ) def _row_to_billing_history(self, row: sqlite3.Row) -> BillingHistory: """数据库行转换为 BillingHistory 对象""" return BillingHistory( - id=row['id'], - tenant_id=row['tenant_id'], - type=row['type'], - amount=row['amount'], - currency=row['currency'], - description=row['description'], - reference_id=row['reference_id'], - balance_after=row['balance_after'], - created_at=datetime.fromisoformat( - row['created_at']) if isinstance( - row['created_at'], - str) else row['created_at'], - metadata=json.loads( - row['metadata'] or '{}')) + id=row["id"], + tenant_id=row["tenant_id"], + type=row["type"], + amount=row["amount"], + currency=row["currency"], + description=row["description"], + reference_id=row["reference_id"], + balance_after=row["balance_after"], + created_at=( + datetime.fromisoformat(row["created_at"]) if isinstance(row["created_at"], str) else row["created_at"] + ), + metadata=json.loads(row["metadata"] or "{}"), + ) # 全局订阅管理器实例 diff --git a/backend/tenant_manager.py b/backend/tenant_manager.py index 5b71a25..5a38250 100644 --- a/backend/tenant_manager.py +++ b/backend/tenant_manager.py @@ -1,4 +1,3 @@ - """ InsightFlow Phase 8 - 多租户 SaaS 架构管理模块 @@ -27,6 +26,7 @@ logger = logging.getLogger(__name__) class TenantLimits: """租户资源限制常量""" + FREE_MAX_PROJECTS = 3 FREE_MAX_STORAGE_MB = 100 FREE_MAX_TRANSCRIPTION_MINUTES = 60 @@ -46,83 +46,90 @@ class TenantLimits: class TenantStatus(StrEnum): """租户状态""" - ACTIVE = "active" # 活跃 - SUSPENDED = "suspended" # 暂停 - TRIAL = "trial" # 试用 - EXPIRED = "expired" # 过期 - PENDING = "pending" # 待激活 + + ACTIVE = "active" # 活跃 + SUSPENDED = "suspended" # 暂停 + TRIAL = "trial" # 试用 + EXPIRED = "expired" # 过期 + PENDING = "pending" # 待激活 class TenantTier(StrEnum): """租户订阅层级""" - FREE = "free" # 免费版 - PRO = "pro" # 专业版 - ENTERPRISE = "enterprise" # 企业版 + + FREE = "free" # 免费版 + PRO = "pro" # 专业版 + ENTERPRISE = "enterprise" # 企业版 class TenantRole(StrEnum): """租户角色""" - OWNER = "owner" # 所有者 - ADMIN = "admin" # 管理员 - MEMBER = "member" # 成员 - VIEWER = "viewer" # 查看者 + + OWNER = "owner" # 所有者 + ADMIN = "admin" # 管理员 + MEMBER = "member" # 成员 + VIEWER = "viewer" # 查看者 class DomainStatus(StrEnum): """域名状态""" - PENDING = "pending" # 待验证 - VERIFIED = "verified" # 已验证 - FAILED = "failed" # 验证失败 - EXPIRED = "expired" # 已过期 + + PENDING = "pending" # 待验证 + VERIFIED = "verified" # 已验证 + FAILED = "failed" # 验证失败 + EXPIRED = "expired" # 已过期 @dataclass class Tenant: """租户数据类""" + id: str name: str - slug: str # URL 友好的唯一标识 + slug: str # URL 友好的唯一标识 description: str | None - tier: str # free/pro/enterprise - status: str # active/suspended/trial/expired/pending - owner_id: str # 所有者用户ID + tier: str # free/pro/enterprise + status: str # active/suspended/trial/expired/pending + owner_id: str # 所有者用户ID created_at: datetime updated_at: datetime expires_at: datetime | None # 订阅过期时间 - settings: dict[str, Any] # 租户级设置 + settings: dict[str, Any] # 租户级设置 resource_limits: dict[str, Any] # 资源限制 - metadata: dict[str, Any] # 元数据 + metadata: dict[str, Any] # 元数据 @dataclass class TenantDomain: """租户域名数据类""" + id: str tenant_id: str - domain: str # 自定义域名 - status: str # pending/verified/failed/expired - verification_token: str # 验证令牌 - verification_method: str # dns/file + domain: str # 自定义域名 + status: str # pending/verified/failed/expired + verification_token: str # 验证令牌 + verification_method: str # dns/file verified_at: datetime | None created_at: datetime updated_at: datetime - is_primary: bool # 是否主域名 - ssl_enabled: bool # SSL 是否启用 + is_primary: bool # 是否主域名 + ssl_enabled: bool # SSL 是否启用 ssl_expires_at: datetime | None @dataclass class TenantBranding: """租户品牌配置数据类""" + id: str tenant_id: str - logo_url: str | None # Logo URL - favicon_url: str | None # Favicon URL - primary_color: str | None # 主题主色 + logo_url: str | None # Logo URL + favicon_url: str | None # Favicon URL + primary_color: str | None # 主题主色 secondary_color: str | None # 主题次色 - custom_css: str | None # 自定义 CSS - custom_js: str | None # 自定义 JS - login_page_bg: str | None # 登录页背景 + custom_css: str | None # 自定义 CSS + custom_js: str | None # 自定义 JS + login_page_bg: str | None # 登录页背景 email_template: str | None # 邮件模板 created_at: datetime updated_at: datetime @@ -131,30 +138,32 @@ class TenantBranding: @dataclass class TenantMember: """租户成员数据类""" + id: str tenant_id: str user_id: str email: str - role: str # owner/admin/member/viewer - permissions: list[str] # 具体权限列表 - invited_by: str | None # 邀请者 + role: str # owner/admin/member/viewer + permissions: list[str] # 具体权限列表 + invited_by: str | None # 邀请者 invited_at: datetime joined_at: datetime | None last_active_at: datetime | None - status: str # active/pending/suspended + status: str # active/pending/suspended @dataclass class TenantPermission: """租户权限定义数据类""" + id: str tenant_id: str - name: str # 权限名称 - code: str # 权限代码 + name: str # 权限名称 + code: str # 权限代码 description: str | None - resource_type: str # project/entity/api/etc - actions: list[str] # create/read/update/delete/etc - conditions: dict | None # 条件限制 + resource_type: str # project/entity/api/etc + actions: list[str] # create/read/update/delete/etc + conditions: dict | None # 条件限制 created_at: datetime @@ -170,7 +179,7 @@ class TenantManager: "max_api_calls_per_day": TenantLimits.FREE_MAX_API_CALLS_PER_DAY, "max_team_members": TenantLimits.FREE_MAX_TEAM_MEMBERS, "max_entities": TenantLimits.FREE_MAX_ENTITIES, - "features": ["basic_analysis", "export_png"] + "features": ["basic_analysis", "export_png"], }, TenantTier.PRO: { "max_projects": TenantLimits.PRO_MAX_PROJECTS, @@ -179,8 +188,14 @@ class TenantManager: "max_api_calls_per_day": TenantLimits.PRO_MAX_API_CALLS_PER_DAY, "max_team_members": TenantLimits.PRO_MAX_TEAM_MEMBERS, "max_entities": TenantLimits.PRO_MAX_ENTITIES, - "features": ["basic_analysis", "advanced_analysis", "export_all", - "api_access", "webhooks", "collaboration"] + "features": [ + "basic_analysis", + "advanced_analysis", + "export_all", + "api_access", + "webhooks", + "collaboration", + ], }, TenantTier.ENTERPRISE: { "max_projects": TenantLimits.UNLIMITED, # 无限制 @@ -189,27 +204,23 @@ class TenantManager: "max_api_calls_per_day": TenantLimits.UNLIMITED, "max_team_members": TenantLimits.UNLIMITED, "max_entities": TenantLimits.UNLIMITED, - "features": ["all"] # 所有功能 - } + "features": ["all"], # 所有功能 + }, } # 角色权限映射 ROLE_PERMISSIONS = { - TenantRole.OWNER: [ - "tenant:*", "project:*", "member:*", "billing:*", - "settings:*", "api:*", "export:*" - ], - TenantRole.ADMIN: [ - "tenant:read", "project:*", "member:*", "billing:read", - "settings:*", "api:*", "export:*" - ], + TenantRole.OWNER: ["tenant:*", "project:*", "member:*", "billing:*", "settings:*", "api:*", "export:*"], + TenantRole.ADMIN: ["tenant:read", "project:*", "member:*", "billing:read", "settings:*", "api:*", "export:*"], TenantRole.MEMBER: [ - "tenant:read", "project:create", "project:read", "project:update", - "member:read", "export:basic" + "tenant:read", + "project:create", + "project:read", + "project:update", + "member:read", + "export:basic", ], - TenantRole.VIEWER: [ - "tenant:read", "project:read", "member:read" - ] + TenantRole.VIEWER: ["tenant:read", "project:read", "member:read"], } # 权限名称映射 @@ -227,7 +238,7 @@ class TenantManager: "settings:*": "设置完全控制", "api:*": "API完全控制", "export:*": "导出完全控制", - "export:basic": "基础导出" + "export:basic": "基础导出", } def __init__(self, db_path: str = "insightflow.db"): @@ -240,7 +251,7 @@ class TenantManager: conn.row_factory = sqlite3.Row return conn - def _init_db(self): + def _init_db(self) -> None: """初始化数据库表""" conn = self._get_connection() try: @@ -379,10 +390,9 @@ class TenantManager: # ==================== 租户管理 ==================== - def create_tenant(self, name: str, owner_id: str, - tier: str = "free", - description: str | None = None, - settings: dict | None = None) -> Tenant: + def create_tenant( + self, name: str, owner_id: str, tier: str = "free", description: str | None = None, settings: dict | None = None + ) -> Tenant: """创建新租户""" conn = self._get_connection() try: @@ -406,21 +416,32 @@ class TenantManager: expires_at=None, settings=settings or {}, resource_limits=resource_limits, - metadata={} + metadata={}, ) cursor = conn.cursor() - cursor.execute(""" + cursor.execute( + """ INSERT INTO tenants (id, name, slug, description, tier, status, owner_id, created_at, updated_at, expires_at, settings, resource_limits, metadata) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - """, ( - tenant.id, tenant.name, tenant.slug, tenant.description, - tenant.tier, tenant.status, tenant.owner_id, - tenant.created_at, tenant.updated_at, tenant.expires_at, - json.dumps(tenant.settings), json.dumps(tenant.resource_limits), - json.dumps(tenant.metadata) - )) + """, + ( + tenant.id, + tenant.name, + tenant.slug, + tenant.description, + tenant.tier, + tenant.status, + tenant.owner_id, + tenant.created_at, + tenant.updated_at, + tenant.expires_at, + json.dumps(tenant.settings), + json.dumps(tenant.resource_limits), + json.dumps(tenant.metadata), + ), + ) # 自动将所有者添加为成员 self._add_member_internal(conn, tenant_id, owner_id, "", TenantRole.OWNER, None) @@ -471,11 +492,14 @@ class TenantManager: conn = self._get_connection() try: cursor = conn.cursor() - cursor.execute(""" + cursor.execute( + """ SELECT t.* FROM tenants t JOIN tenant_domains d ON t.id = d.tenant_id WHERE d.domain = ? AND d.status = 'verified' - """, (domain,)) + """, + (domain,), + ) row = cursor.fetchone() if row: @@ -485,12 +509,15 @@ class TenantManager: finally: conn.close() - def update_tenant(self, tenant_id: str, - name: str | None = None, - description: str | None = None, - tier: str | None = None, - status: str | None = None, - settings: dict | None = None) -> Tenant | None: + def update_tenant( + self, + tenant_id: str, + name: str | None = None, + description: str | None = None, + tier: str | None = None, + status: str | None = None, + settings: dict | None = None, + ) -> Tenant | None: """更新租户信息""" conn = self._get_connection() try: @@ -526,10 +553,13 @@ class TenantManager: params.append(tenant_id) cursor = conn.cursor() - cursor.execute(f""" + cursor.execute( + f""" UPDATE tenants SET {', '.join(updates)} WHERE id = ? - """, params) + """, + params, + ) conn.commit() return self.get_tenant(tenant_id) @@ -548,9 +578,9 @@ class TenantManager: finally: conn.close() - def list_tenants(self, status: str | None = None, - tier: str | None = None, - limit: int = 100, offset: int = 0) -> list[Tenant]: + def list_tenants( + self, status: str | None = None, tier: str | None = None, limit: int = 100, offset: int = 0 + ) -> list[Tenant]: """列出租户""" conn = self._get_connection() try: @@ -579,9 +609,9 @@ class TenantManager: # ==================== 域名管理 ==================== - def add_domain(self, tenant_id: str, domain: str, - is_primary: bool = False, - verification_method: str = "dns") -> TenantDomain: + def add_domain( + self, tenant_id: str, domain: str, is_primary: bool = False, verification_method: str = "dns" + ) -> TenantDomain: """为租户添加自定义域名""" conn = self._get_connection() try: @@ -605,31 +635,43 @@ class TenantManager: updated_at=datetime.now(), is_primary=is_primary, ssl_enabled=False, - ssl_expires_at=None + ssl_expires_at=None, ) cursor = conn.cursor() # 如果设为主域名,取消其他主域名 if is_primary: - cursor.execute(""" + cursor.execute( + """ UPDATE tenant_domains SET is_primary = 0 WHERE tenant_id = ? - """, (tenant_id,)) + """, + (tenant_id,), + ) - cursor.execute(""" + cursor.execute( + """ INSERT INTO tenant_domains (id, tenant_id, domain, status, verification_token, verification_method, verified_at, created_at, updated_at, is_primary, ssl_enabled, ssl_expires_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - """, ( - tenant_domain.id, tenant_domain.tenant_id, tenant_domain.domain, - tenant_domain.status, tenant_domain.verification_token, - tenant_domain.verification_method, tenant_domain.verified_at, - tenant_domain.created_at, tenant_domain.updated_at, - int(tenant_domain.is_primary), int(tenant_domain.ssl_enabled), - tenant_domain.ssl_expires_at - )) + """, + ( + tenant_domain.id, + tenant_domain.tenant_id, + tenant_domain.domain, + tenant_domain.status, + tenant_domain.verification_token, + tenant_domain.verification_method, + tenant_domain.verified_at, + tenant_domain.created_at, + tenant_domain.updated_at, + int(tenant_domain.is_primary), + int(tenant_domain.ssl_enabled), + tenant_domain.ssl_expires_at, + ), + ) conn.commit() logger.info(f"Domain added: {domain} for tenant {tenant_id}") @@ -649,36 +691,45 @@ class TenantManager: cursor = conn.cursor() # 获取域名信息 - cursor.execute(""" + cursor.execute( + """ SELECT * FROM tenant_domains WHERE id = ? AND tenant_id = ? - """, (domain_id, tenant_id)) + """, + (domain_id, tenant_id), + ) row = cursor.fetchone() if not row: return False - domain = row['domain'] - token = row['verification_token'] - method = row['verification_method'] + domain = row["domain"] + token = row["verification_token"] + method = row["verification_method"] # 执行验证 is_verified = self._check_domain_verification(domain, token, method) if is_verified: - cursor.execute(""" + cursor.execute( + """ UPDATE tenant_domains SET status = 'verified', verified_at = ?, updated_at = ? WHERE id = ? - """, (datetime.now(), datetime.now(), domain_id)) + """, + (datetime.now(), datetime.now(), domain_id), + ) conn.commit() logger.info(f"Domain verified: {domain}") else: - cursor.execute(""" + cursor.execute( + """ UPDATE tenant_domains SET status = 'failed', updated_at = ? WHERE id = ? - """, (datetime.now(), domain_id)) + """, + (datetime.now(), domain_id), + ) conn.commit() return is_verified @@ -700,26 +751,23 @@ class TenantManager: if not row: return None - domain = row['domain'] - token = row['verification_token'] + domain = row["domain"] + token = row["verification_token"] return { "domain": domain, - "verification_method": row['verification_method'], + "verification_method": row["verification_method"], "dns_record": { "type": "TXT", "name": "_insightflow", "value": f"insightflow-verify={token}", - "ttl": 3600 - }, - "file_verification": { - "url": f"http://{domain}/.well-known/insightflow-verify.txt", - "content": token + "ttl": 3600, }, + "file_verification": {"url": f"http://{domain}/.well-known/insightflow-verify.txt", "content": token}, "instructions": [ f"DNS 验证: 添加 TXT 记录 _insightflow.{domain},值为 insightflow-verify={token}", - f"文件验证: 在网站根目录创建 .well-known/insightflow-verify.txt,内容为 {token}" - ] + f"文件验证: 在网站根目录创建 .well-known/insightflow-verify.txt,内容为 {token}", + ], } finally: @@ -730,10 +778,13 @@ class TenantManager: conn = self._get_connection() try: cursor = conn.cursor() - cursor.execute(""" + cursor.execute( + """ DELETE FROM tenant_domains WHERE id = ? AND tenant_id = ? - """, (domain_id, tenant_id)) + """, + (domain_id, tenant_id), + ) conn.commit() return cursor.rowcount > 0 finally: @@ -744,11 +795,14 @@ class TenantManager: conn = self._get_connection() try: cursor = conn.cursor() - cursor.execute(""" + cursor.execute( + """ SELECT * FROM tenant_domains WHERE tenant_id = ? ORDER BY is_primary DESC, created_at DESC - """, (tenant_id,)) + """, + (tenant_id,), + ) rows = cursor.fetchall() return [self._row_to_domain(row) for row in rows] @@ -773,15 +827,18 @@ class TenantManager: finally: conn.close() - def update_branding(self, tenant_id: str, - logo_url: str | None = None, - favicon_url: str | None = None, - primary_color: str | None = None, - secondary_color: str | None = None, - custom_css: str | None = None, - custom_js: str | None = None, - login_page_bg: str | None = None, - email_template: str | None = None) -> TenantBranding: + def update_branding( + self, + tenant_id: str, + logo_url: str | None = None, + favicon_url: str | None = None, + primary_color: str | None = None, + secondary_color: str | None = None, + custom_css: str | None = None, + custom_js: str | None = None, + login_page_bg: str | None = None, + email_template: str | None = None, + ) -> TenantBranding: """更新租户品牌配置""" conn = self._get_connection() try: @@ -825,23 +882,38 @@ class TenantManager: params.append(datetime.now()) params.append(tenant_id) - cursor.execute(f""" + cursor.execute( + f""" UPDATE tenant_branding SET {', '.join(updates)} WHERE tenant_id = ? - """, params) + """, + params, + ) else: # 创建 branding_id = str(uuid.uuid4()) - cursor.execute(""" + cursor.execute( + """ INSERT INTO tenant_branding (id, tenant_id, logo_url, favicon_url, primary_color, secondary_color, custom_css, custom_js, login_page_bg, email_template, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - """, ( - branding_id, tenant_id, logo_url, favicon_url, primary_color, - secondary_color, custom_css, custom_js, login_page_bg, email_template, - datetime.now(), datetime.now() - )) + """, + ( + branding_id, + tenant_id, + logo_url, + favicon_url, + primary_color, + secondary_color, + custom_css, + custom_js, + login_page_bg, + email_template, + datetime.now(), + datetime.now(), + ), + ) conn.commit() return self.get_branding(tenant_id) @@ -889,8 +961,9 @@ class TenantManager: # ==================== 成员与权限管理 ==================== - def invite_member(self, tenant_id: str, email: str, role: str, - invited_by: str, permissions: list[str] | None = None) -> TenantMember: + def invite_member( + self, tenant_id: str, email: str, role: str, invited_by: str, permissions: list[str] | None = None + ) -> TenantMember: """邀请成员加入租户""" conn = self._get_connection() try: @@ -912,21 +985,31 @@ class TenantManager: invited_at=datetime.now(), joined_at=None, last_active_at=None, - status="pending" + status="pending", ) cursor = conn.cursor() - cursor.execute(""" + cursor.execute( + """ INSERT INTO tenant_members (id, tenant_id, user_id, email, role, permissions, invited_by, invited_at, joined_at, last_active_at, status) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - """, ( - member.id, member.tenant_id, member.user_id, member.email, - member.role, json.dumps(member.permissions), member.invited_by, - member.invited_at, member.joined_at, member.last_active_at, - member.status - )) + """, + ( + member.id, + member.tenant_id, + member.user_id, + member.email, + member.role, + json.dumps(member.permissions), + member.invited_by, + member.invited_at, + member.joined_at, + member.last_active_at, + member.status, + ), + ) conn.commit() logger.info(f"Member invited: {email} to tenant {tenant_id}") @@ -940,11 +1023,14 @@ class TenantManager: conn = self._get_connection() try: cursor = conn.cursor() - cursor.execute(""" + cursor.execute( + """ UPDATE tenant_members SET user_id = ?, status = 'active', joined_at = ? WHERE id = ? AND status = 'pending' - """, (user_id, datetime.now(), invitation_id)) + """, + (user_id, datetime.now(), invitation_id), + ) conn.commit() return cursor.rowcount > 0 @@ -957,17 +1043,21 @@ class TenantManager: conn = self._get_connection() try: cursor = conn.cursor() - cursor.execute(""" + cursor.execute( + """ DELETE FROM tenant_members WHERE id = ? AND tenant_id = ? - """, (member_id, tenant_id)) + """, + (member_id, tenant_id), + ) conn.commit() return cursor.rowcount > 0 finally: conn.close() - def update_member_role(self, tenant_id: str, member_id: str, - role: str, permissions: list[str] | None = None) -> bool: + def update_member_role( + self, tenant_id: str, member_id: str, role: str, permissions: list[str] | None = None + ) -> bool: """更新成员角色""" conn = self._get_connection() try: @@ -976,11 +1066,14 @@ class TenantManager: final_permissions = permissions or default_permissions cursor = conn.cursor() - cursor.execute(""" + cursor.execute( + """ UPDATE tenant_members SET role = ?, permissions = ?, updated_at = ? WHERE id = ? AND tenant_id = ? - """, (role, json.dumps(final_permissions), datetime.now(), member_id, tenant_id)) + """, + (role, json.dumps(final_permissions), datetime.now(), member_id, tenant_id), + ) conn.commit() return cursor.rowcount > 0 @@ -1011,23 +1104,25 @@ class TenantManager: finally: conn.close() - def check_permission(self, tenant_id: str, user_id: str, - resource: str, action: str) -> bool: + def check_permission(self, tenant_id: str, user_id: str, resource: str, action: str) -> bool: """检查用户是否有特定权限""" conn = self._get_connection() try: cursor = conn.cursor() - cursor.execute(""" + cursor.execute( + """ SELECT role, permissions FROM tenant_members WHERE tenant_id = ? AND user_id = ? AND status = 'active' - """, (tenant_id, user_id)) + """, + (tenant_id, user_id), + ) row = cursor.fetchone() if not row: return False - role = row['role'] - permissions = json.loads(row['permissions'] or '[]') + role = row["role"] + permissions = json.loads(row["permissions"] or "[]") # 所有者拥有所有权限 if role == TenantRole.OWNER.value: @@ -1047,23 +1142,22 @@ class TenantManager: conn = self._get_connection() try: cursor = conn.cursor() - cursor.execute(""" + cursor.execute( + """ SELECT t.*, m.role, m.status as member_status FROM tenants t JOIN tenant_members m ON t.id = m.tenant_id WHERE m.user_id = ? AND m.status = 'active' ORDER BY t.created_at DESC - """, (user_id,)) + """, + (user_id,), + ) rows = cursor.fetchall() result = [] for row in rows: tenant = self._row_to_tenant(row) - result.append({ - **asdict(tenant), - "member_role": row['role'], - "member_status": row['member_status'] - }) + result.append({**asdict(tenant), "member_role": row["role"], "member_status": row["member_status"]}) return result finally: @@ -1071,13 +1165,16 @@ class TenantManager: # ==================== 资源使用统计 ==================== - def record_usage(self, tenant_id: str, - storage_bytes: int = 0, - transcription_seconds: int = 0, - api_calls: int = 0, - projects_count: int = 0, - entities_count: int = 0, - members_count: int = 0): + def record_usage( + self, + tenant_id: str, + storage_bytes: int = 0, + transcription_seconds: int = 0, + api_calls: int = 0, + projects_count: int = 0, + entities_count: int = 0, + members_count: int = 0, + ): """记录资源使用""" conn = self._get_connection() try: @@ -1085,7 +1182,8 @@ class TenantManager: usage_id = str(uuid.uuid4()) cursor = conn.cursor() - cursor.execute(""" + cursor.execute( + """ INSERT INTO tenant_usage (id, tenant_id, date, storage_bytes, transcription_seconds, api_calls, projects_count, entities_count, members_count) @@ -1097,19 +1195,28 @@ class TenantManager: projects_count = MAX(projects_count, excluded.projects_count), entities_count = MAX(entities_count, excluded.entities_count), members_count = MAX(members_count, excluded.members_count) - """, ( - usage_id, tenant_id, today, storage_bytes, transcription_seconds, - api_calls, projects_count, entities_count, members_count - )) + """, + ( + usage_id, + tenant_id, + today, + storage_bytes, + transcription_seconds, + api_calls, + projects_count, + entities_count, + members_count, + ), + ) conn.commit() finally: conn.close() - def get_usage_stats(self, tenant_id: str, - start_date: datetime | None = None, - end_date: datetime | None = None) -> dict[str, Any]: + def get_usage_stats( + self, tenant_id: str, start_date: datetime | None = None, end_date: datetime | None = None + ) -> dict[str, Any]: """获取使用统计""" conn = self._get_connection() try: @@ -1143,23 +1250,29 @@ class TenantManager: limits = tenant.resource_limits if tenant else {} return { - "storage_bytes": row['total_storage'] or 0, - "storage_mb": (row['total_storage'] or 0) / (1024 * 1024), - "transcription_seconds": row['total_transcription'] or 0, - "transcription_minutes": (row['total_transcription'] or 0) / 60, - "api_calls": row['total_api_calls'] or 0, - "projects_count": row['max_projects'] or 0, - "entities_count": row['max_entities'] or 0, - "members_count": row['max_members'] or 0, + "storage_bytes": row["total_storage"] or 0, + "storage_mb": (row["total_storage"] or 0) / (1024 * 1024), + "transcription_seconds": row["total_transcription"] or 0, + "transcription_minutes": (row["total_transcription"] or 0) / 60, + "api_calls": row["total_api_calls"] or 0, + "projects_count": row["max_projects"] or 0, + "entities_count": row["max_entities"] or 0, + "members_count": row["max_members"] or 0, "limits": limits, "usage_percentages": { - "storage": self._calc_percentage(row['total_storage'] or 0, limits.get('max_storage_mb', 0) * 1024 * 1024), - "transcription": self._calc_percentage(row['total_transcription'] or 0, limits.get('max_transcription_minutes', 0) * 60), - "api_calls": self._calc_percentage(row['total_api_calls'] or 0, limits.get('max_api_calls_per_day', 0)), - "projects": self._calc_percentage(row['max_projects'] or 0, limits.get('max_projects', 0)), - "entities": self._calc_percentage(row['max_entities'] or 0, limits.get('max_entities', 0)), - "members": self._calc_percentage(row['max_members'] or 0, limits.get('max_team_members', 0)) - } + "storage": self._calc_percentage( + row["total_storage"] or 0, limits.get("max_storage_mb", 0) * 1024 * 1024 + ), + "transcription": self._calc_percentage( + row["total_transcription"] or 0, limits.get("max_transcription_minutes", 0) * 60 + ), + "api_calls": self._calc_percentage( + row["total_api_calls"] or 0, limits.get("max_api_calls_per_day", 0) + ), + "projects": self._calc_percentage(row["max_projects"] or 0, limits.get("max_projects", 0)), + "entities": self._calc_percentage(row["max_entities"] or 0, limits.get("max_entities", 0)), + "members": self._calc_percentage(row["max_members"] or 0, limits.get("max_team_members", 0)), + }, } finally: @@ -1179,12 +1292,12 @@ class TenantManager: stats = self.get_usage_stats(tenant_id) resource_map = { - "storage": ("storage_mb", stats['storage_mb']), - "transcription": ("max_transcription_minutes", stats['transcription_minutes']), - "api_calls": ("max_api_calls_per_day", stats['api_calls']), - "projects": ("max_projects", stats['projects_count']), - "entities": ("max_entities", stats['entities_count']), - "members": ("max_team_members", stats['members_count']) + "storage": ("storage_mb", stats["storage_mb"]), + "transcription": ("max_transcription_minutes", stats["transcription_minutes"]), + "api_calls": ("max_api_calls_per_day", stats["api_calls"]), + "projects": ("max_projects", stats["projects_count"]), + "entities": ("max_entities", stats["entities_count"]), + "members": ("max_team_members", stats["members_count"]), } if resource_type not in resource_map: @@ -1204,8 +1317,8 @@ class TenantManager: def _generate_slug(self, name: str) -> str: """生成 URL 友好的 slug""" # 转换为小写,替换空格为连字符 - slug = re.sub(r'[^\w\s-]', '', name.lower()) - slug = re.sub(r'[-\s]+', '-', slug) + slug = re.sub(r"[^\w\s-]", "", name.lower()) + slug = re.sub(r"[-\s]+", "-", slug) # 检查是否已存在 conn = self._get_connection() @@ -1233,7 +1346,7 @@ class TenantManager: def _validate_domain(self, domain: str) -> bool: """验证域名格式""" - pattern = r'^(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)*[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])$' + pattern = r"^(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)*[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])$" return bool(re.match(pattern, domain)) def _check_domain_verification(self, domain: str, token: str, method: str) -> bool: @@ -1269,7 +1382,7 @@ class TenantManager: def _darken_color(self, hex_color: str, percent: int) -> str: """加深颜色""" - hex_color = hex_color.lstrip('#') + hex_color = hex_color.lstrip("#") r = int(hex_color[0:2], 16) g = int(hex_color[2:4], 16) b = int(hex_color[4:6], 16) @@ -1286,133 +1399,147 @@ class TenantManager: return 0.0 if limit == 0 else 100.0 return min(100.0, round(current / limit * 100, 2)) - def _add_member_internal(self, conn: sqlite3.Connection, tenant_id: str, - user_id: str, email: str, role: TenantRole, - invited_by: str | None): + def _add_member_internal( + self, + conn: sqlite3.Connection, + tenant_id: str, + user_id: str, + email: str, + role: TenantRole, + invited_by: str | None, + ): """内部方法:添加成员""" cursor = conn.cursor() member_id = str(uuid.uuid4()) - cursor.execute(""" + cursor.execute( + """ INSERT OR IGNORE INTO tenant_members (id, tenant_id, user_id, email, role, permissions, invited_by, invited_at, joined_at, last_active_at, status) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - """, ( - member_id, tenant_id, user_id, email, role.value, - json.dumps(self.ROLE_PERMISSIONS.get(role, [])), - invited_by, datetime.now(), datetime.now(), datetime.now(), "active" - )) + """, + ( + member_id, + tenant_id, + user_id, + email, + role.value, + json.dumps(self.ROLE_PERMISSIONS.get(role, [])), + invited_by, + datetime.now(), + datetime.now(), + datetime.now(), + "active", + ), + ) def _row_to_tenant(self, row: sqlite3.Row) -> Tenant: """数据库行转换为 Tenant 对象""" return Tenant( - id=row['id'], - name=row['name'], - slug=row['slug'], - description=row['description'], - tier=row['tier'], - status=row['status'], - owner_id=row['owner_id'], - created_at=datetime.fromisoformat( - row['created_at']) if isinstance( - row['created_at'], - str) else row['created_at'], - updated_at=datetime.fromisoformat( - row['updated_at']) if isinstance( - row['updated_at'], - str) else row['updated_at'], - expires_at=datetime.fromisoformat( - row['expires_at']) if row['expires_at'] and isinstance( - row['expires_at'], - str) else row['expires_at'], - settings=json.loads( - row['settings'] or '{}'), - resource_limits=json.loads( - row['resource_limits'] or '{}'), - metadata=json.loads( - row['metadata'] or '{}')) + id=row["id"], + name=row["name"], + slug=row["slug"], + description=row["description"], + tier=row["tier"], + status=row["status"], + owner_id=row["owner_id"], + created_at=( + datetime.fromisoformat(row["created_at"]) if isinstance(row["created_at"], str) else row["created_at"] + ), + updated_at=( + datetime.fromisoformat(row["updated_at"]) if isinstance(row["updated_at"], str) else row["updated_at"] + ), + expires_at=( + datetime.fromisoformat(row["expires_at"]) + if row["expires_at"] and isinstance(row["expires_at"], str) + else row["expires_at"] + ), + settings=json.loads(row["settings"] or "{}"), + resource_limits=json.loads(row["resource_limits"] or "{}"), + metadata=json.loads(row["metadata"] or "{}"), + ) def _row_to_domain(self, row: sqlite3.Row) -> TenantDomain: """数据库行转换为 TenantDomain 对象""" return TenantDomain( - id=row['id'], - tenant_id=row['tenant_id'], - domain=row['domain'], - status=row['status'], - verification_token=row['verification_token'], - verification_method=row['verification_method'], - verified_at=datetime.fromisoformat( - row['verified_at']) if row['verified_at'] and isinstance( - row['verified_at'], - str) else row['verified_at'], - created_at=datetime.fromisoformat( - row['created_at']) if isinstance( - row['created_at'], - str) else row['created_at'], - updated_at=datetime.fromisoformat( - row['updated_at']) if isinstance( - row['updated_at'], - str) else row['updated_at'], - is_primary=bool( - row['is_primary']), - ssl_enabled=bool( - row['ssl_enabled']), - ssl_expires_at=datetime.fromisoformat( - row['ssl_expires_at']) if row['ssl_expires_at'] and isinstance( - row['ssl_expires_at'], - str) else row['ssl_expires_at']) + id=row["id"], + tenant_id=row["tenant_id"], + domain=row["domain"], + status=row["status"], + verification_token=row["verification_token"], + verification_method=row["verification_method"], + verified_at=( + datetime.fromisoformat(row["verified_at"]) + if row["verified_at"] and isinstance(row["verified_at"], str) + else row["verified_at"] + ), + created_at=( + datetime.fromisoformat(row["created_at"]) if isinstance(row["created_at"], str) else row["created_at"] + ), + updated_at=( + datetime.fromisoformat(row["updated_at"]) if isinstance(row["updated_at"], str) else row["updated_at"] + ), + is_primary=bool(row["is_primary"]), + ssl_enabled=bool(row["ssl_enabled"]), + ssl_expires_at=( + datetime.fromisoformat(row["ssl_expires_at"]) + if row["ssl_expires_at"] and isinstance(row["ssl_expires_at"], str) + else row["ssl_expires_at"] + ), + ) def _row_to_branding(self, row: sqlite3.Row) -> TenantBranding: """数据库行转换为 TenantBranding 对象""" return TenantBranding( - id=row['id'], - tenant_id=row['tenant_id'], - logo_url=row['logo_url'], - favicon_url=row['favicon_url'], - primary_color=row['primary_color'], - secondary_color=row['secondary_color'], - custom_css=row['custom_css'], - custom_js=row['custom_js'], - login_page_bg=row['login_page_bg'], - email_template=row['email_template'], - created_at=datetime.fromisoformat( - row['created_at']) if isinstance( - row['created_at'], - str) else row['created_at'], - updated_at=datetime.fromisoformat( - row['updated_at']) if isinstance( - row['updated_at'], - str) else row['updated_at']) + id=row["id"], + tenant_id=row["tenant_id"], + logo_url=row["logo_url"], + favicon_url=row["favicon_url"], + primary_color=row["primary_color"], + secondary_color=row["secondary_color"], + custom_css=row["custom_css"], + custom_js=row["custom_js"], + login_page_bg=row["login_page_bg"], + email_template=row["email_template"], + created_at=( + datetime.fromisoformat(row["created_at"]) if isinstance(row["created_at"], str) else row["created_at"] + ), + updated_at=( + datetime.fromisoformat(row["updated_at"]) if isinstance(row["updated_at"], str) else row["updated_at"] + ), + ) def _row_to_member(self, row: sqlite3.Row) -> TenantMember: """数据库行转换为 TenantMember 对象""" return TenantMember( - id=row['id'], - tenant_id=row['tenant_id'], - user_id=row['user_id'], - email=row['email'], - role=row['role'], - permissions=json.loads( - row['permissions'] or '[]'), - invited_by=row['invited_by'], - invited_at=datetime.fromisoformat( - row['invited_at']) if isinstance( - row['invited_at'], - str) else row['invited_at'], - joined_at=datetime.fromisoformat( - row['joined_at']) if row['joined_at'] and isinstance( - row['joined_at'], - str) else row['joined_at'], - last_active_at=datetime.fromisoformat( - row['last_active_at']) if row['last_active_at'] and isinstance( - row['last_active_at'], - str) else row['last_active_at'], - status=row['status']) + id=row["id"], + tenant_id=row["tenant_id"], + user_id=row["user_id"], + email=row["email"], + role=row["role"], + permissions=json.loads(row["permissions"] or "[]"), + invited_by=row["invited_by"], + invited_at=( + datetime.fromisoformat(row["invited_at"]) if isinstance(row["invited_at"], str) else row["invited_at"] + ), + joined_at=( + datetime.fromisoformat(row["joined_at"]) + if row["joined_at"] and isinstance(row["joined_at"], str) + else row["joined_at"] + ), + last_active_at=( + datetime.fromisoformat(row["last_active_at"]) + if row["last_active_at"] and isinstance(row["last_active_at"], str) + else row["last_active_at"] + ), + status=row["status"], + ) # ==================== 租户上下文管理 ==================== + class TenantContext: """租户上下文管理器 - 用于请求级别的租户隔离""" @@ -1420,7 +1547,7 @@ class TenantContext: _current_user_id: str | None = None @classmethod - def set_current_tenant(cls, tenant_id: str): + def set_current_tenant(cls, tenant_id: str) -> None: """设置当前租户上下文""" cls._current_tenant_id = tenant_id @@ -1430,7 +1557,7 @@ class TenantContext: return cls._current_tenant_id @classmethod - def set_current_user(cls, user_id: str): + def set_current_user(cls, user_id: str) -> None: """设置当前用户""" cls._current_user_id = user_id @@ -1440,7 +1567,7 @@ class TenantContext: return cls._current_user_id @classmethod - def clear(cls): + def clear(cls) -> None: """清除上下文""" cls._current_tenant_id = None cls._current_user_id = None diff --git a/backend/tingwu_client.py b/backend/tingwu_client.py index fab65c7..9ff3862 100644 --- a/backend/tingwu_client.py +++ b/backend/tingwu_client.py @@ -20,7 +20,7 @@ class TingwuClient: def _sign_request(self, method: str, uri: str, query: str = "", body: str = "") -> dict[str, str]: """阿里云签名 V3""" - timestamp = datetime.utcnow().strftime('%Y-%m-%dT%H:%M:%SZ') + timestamp = datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%SZ") # 简化签名,实际生产需要完整实现 # 这里使用基础认证头 @@ -39,25 +39,16 @@ class TingwuClient: from alibabacloud_tingwu20230930 import models as tingwu_models from alibabacloud_tingwu20230930.client import Client as TingwuSDKClient - config = open_api_models.Config( - access_key_id=self.access_key, - access_key_secret=self.secret_key - ) + config = open_api_models.Config(access_key_id=self.access_key, access_key_secret=self.secret_key) config.endpoint = "tingwu.cn-beijing.aliyuncs.com" client = TingwuSDKClient(config) request = tingwu_models.CreateTaskRequest( type="offline", - input=tingwu_models.Input( - source="OSS", - file_url=audio_url - ), + input=tingwu_models.Input(source="OSS", file_url=audio_url), parameters=tingwu_models.Parameters( - transcription=tingwu_models.Transcription( - diarization_enabled=True, - sentence_max_length=20 - ) - ) + transcription=tingwu_models.Transcription(diarization_enabled=True, sentence_max_length=20) + ), ) response = client.create_task(request) @@ -81,10 +72,7 @@ class TingwuClient: from alibabacloud_tingwu20230930 import models as tingwu_models from alibabacloud_tingwu20230930.client import Client as TingwuSDKClient - config = open_api_models.Config( - access_key_id=self.access_key, - access_key_secret=self.secret_key - ) + config = open_api_models.Config(access_key_id=self.access_key, access_key_secret=self.secret_key) config.endpoint = "tingwu.cn-beijing.aliyuncs.com" client = TingwuSDKClient(config) @@ -128,25 +116,29 @@ class TingwuClient: if transcription.sentences: for sent in transcription.sentences: - segments.append({ - "start": sent.begin_time / 1000, - "end": sent.end_time / 1000, - "text": sent.text, - "speaker": f"Speaker {sent.speaker_id}" - }) + segments.append( + { + "start": sent.begin_time / 1000, + "end": sent.end_time / 1000, + "text": sent.text, + "speaker": f"Speaker {sent.speaker_id}", + } + ) - return { - "full_text": full_text.strip(), - "segments": segments - } + return {"full_text": full_text.strip(), "segments": segments} def _mock_result(self) -> dict[str, Any]: """Mock 结果""" return { "full_text": "这是一个示例转录文本,包含 Project Alpha 和 K8s 等术语。", "segments": [ - {"start": 0.0, "end": 5.0, "text": "这是一个示例转录文本,包含 Project Alpha 和 K8s 等术语。", "speaker": "Speaker A"} - ] + { + "start": 0.0, + "end": 5.0, + "text": "这是一个示例转录文本,包含 Project Alpha 和 K8s 等术语。", + "speaker": "Speaker A", + } + ], } def transcribe(self, audio_url: str, language: str = "zh") -> dict[str, Any]: diff --git a/backend/workflow_manager.py b/backend/workflow_manager.py index 2404140..cb87aa4 100644 --- a/backend/workflow_manager.py +++ b/backend/workflow_manager.py @@ -10,8 +10,12 @@ InsightFlow Workflow Manager - Phase 7 """ import asyncio +import base64 +import hashlib +import hmac import json import logging +import urllib.parse import uuid from collections.abc import Callable from dataclasses import dataclass, field @@ -32,6 +36,7 @@ logger = logging.getLogger(__name__) class WorkflowStatus(Enum): """工作流状态""" + ACTIVE = "active" PAUSED = "paused" ERROR = "error" @@ -40,15 +45,17 @@ class WorkflowStatus(Enum): class WorkflowType(Enum): """工作流类型""" + AUTO_ANALYZE = "auto_analyze" # 自动分析新文件 - AUTO_ALIGN = "auto_align" # 自动实体对齐 + AUTO_ALIGN = "auto_align" # 自动实体对齐 AUTO_RELATION = "auto_relation" # 自动关系发现 SCHEDULED_REPORT = "scheduled_report" # 定时报告 - CUSTOM = "custom" # 自定义工作流 + CUSTOM = "custom" # 自定义工作流 class WebhookType(Enum): """Webhook 类型""" + FEISHU = "feishu" DINGTALK = "dingtalk" SLACK = "slack" @@ -57,6 +64,7 @@ class WebhookType(Enum): class TaskStatus(Enum): """任务执行状态""" + PENDING = "pending" RUNNING = "running" SUCCESS = "success" @@ -67,6 +75,7 @@ class TaskStatus(Enum): @dataclass class WorkflowTask: """工作流任务定义""" + id: str workflow_id: str name: str @@ -90,6 +99,7 @@ class WorkflowTask: @dataclass class WebhookConfig: """Webhook 配置""" + id: str name: str webhook_type: str # feishu, dingtalk, slack, custom @@ -114,6 +124,7 @@ class WebhookConfig: @dataclass class Workflow: """工作流定义""" + id: str name: str description: str @@ -143,6 +154,7 @@ class Workflow: @dataclass class WorkflowLog: """工作流执行日志""" + id: str workflow_id: str task_id: str | None = None @@ -186,20 +198,13 @@ class WebhookNotifier: async def _send_feishu(self, config: WebhookConfig, message: dict) -> bool: """发送飞书通知""" - import base64 - import hashlib - import hmac - timestamp = str(int(datetime.now().timestamp())) # 签名计算 if config.secret: string_to_sign = f"{timestamp}\n{config.secret}" - hmac_code = hmac.new( - string_to_sign.encode('utf-8'), - digestmod=hashlib.sha256 - ).digest() - sign = base64.b64encode(hmac_code).decode('utf-8') + hmac_code = hmac.new(string_to_sign.encode("utf-8"), digestmod=hashlib.sha256).digest() + sign = base64.b64encode(hmac_code).decode("utf-8") else: sign = "" @@ -210,9 +215,7 @@ class WebhookNotifier: "timestamp": timestamp, "sign": sign, "msg_type": "text", - "content": { - "text": message["content"] - } + "content": {"text": message["content"]}, } elif "title" in message: # 富文本消息 @@ -220,34 +223,15 @@ class WebhookNotifier: "timestamp": timestamp, "sign": sign, "msg_type": "post", - "content": { - "post": { - "zh_cn": { - "title": message.get("title", ""), - "content": message.get("body", []) - } - } - } + "content": {"post": {"zh_cn": {"title": message.get("title", ""), "content": message.get("body", [])}}}, } else: # 卡片消息 - payload = { - "timestamp": timestamp, - "sign": sign, - "msg_type": "interactive", - "card": message.get("card", {}) - } + payload = {"timestamp": timestamp, "sign": sign, "msg_type": "interactive", "card": message.get("card", {})} - headers = { - "Content-Type": "application/json", - **config.headers - } + headers = {"Content-Type": "application/json", **config.headers} - response = await self.http_client.post( - config.url, - json=payload, - headers=headers - ) + response = await self.http_client.post(config.url, json=payload, headers=headers) response.raise_for_status() result = response.json() @@ -255,18 +239,13 @@ class WebhookNotifier: async def _send_dingtalk(self, config: WebhookConfig, message: dict) -> bool: """发送钉钉通知""" - import base64 - import hashlib - import hmac - import urllib.parse - timestamp = str(round(datetime.now().timestamp() * 1000)) # 签名计算 if config.secret: - secret_enc = config.secret.encode('utf-8') + secret_enc = config.secret.encode("utf-8") string_to_sign = f"{timestamp}\n{config.secret}" - hmac_code = hmac.new(secret_enc, string_to_sign.encode('utf-8'), digestmod=hashlib.sha256).digest() + hmac_code = hmac.new(secret_enc, string_to_sign.encode("utf-8"), digestmod=hashlib.sha256).digest() sign = urllib.parse.quote_plus(base64.b64encode(hmac_code)) url = f"{config.url}×tamp={timestamp}&sign={sign}" else: @@ -274,19 +253,11 @@ class WebhookNotifier: # 构建消息体 if "content" in message: - payload = { - "msgtype": "text", - "text": { - "content": message["content"] - } - } + payload = {"msgtype": "text", "text": {"content": message["content"]}} elif "title" in message: payload = { "msgtype": "markdown", - "markdown": { - "title": message["title"], - "text": message.get("markdown", "") - } + "markdown": {"title": message["title"], "text": message.get("markdown", "")}, } elif "link" in message: payload = { @@ -295,19 +266,13 @@ class WebhookNotifier: "text": message.get("text", ""), "title": message["title"], "picUrl": message.get("pic_url", ""), - "messageUrl": message["link"] - } + "messageUrl": message["link"], + }, } else: - payload = { - "msgtype": "action_card", - "action_card": message.get("action_card", {}) - } + payload = {"msgtype": "action_card", "action_card": message.get("action_card", {})} - headers = { - "Content-Type": "application/json", - **config.headers - } + headers = {"Content-Type": "application/json", **config.headers} response = await self.http_client.post(url, json=payload, headers=headers) response.raise_for_status() @@ -328,32 +293,18 @@ class WebhookNotifier: if "attachments" in message: payload["attachments"] = message["attachments"] - headers = { - "Content-Type": "application/json", - **config.headers - } + headers = {"Content-Type": "application/json", **config.headers} - response = await self.http_client.post( - config.url, - json=payload, - headers=headers - ) + response = await self.http_client.post(config.url, json=payload, headers=headers) response.raise_for_status() return response.text == "ok" async def _send_custom(self, config: WebhookConfig, message: dict) -> bool: """发送自定义 Webhook 通知""" - headers = { - "Content-Type": "application/json", - **config.headers - } + headers = {"Content-Type": "application/json", **config.headers} - response = await self.http_client.post( - config.url, - json=message, - headers=headers - ) + response = await self.http_client.post(config.url, json=message, headers=headers) response.raise_for_status() return True @@ -380,12 +331,9 @@ class WorkflowManager: self._setup_default_handlers() # 添加调度器事件监听 - self.scheduler.add_listener( - self._on_job_executed, - EVENT_JOB_EXECUTED | EVENT_JOB_ERROR - ) + self.scheduler.add_listener(self._on_job_executed, EVENT_JOB_EXECUTED | EVENT_JOB_ERROR) - def _setup_default_handlers(self): + def _setup_default_handlers(self) -> None: """设置默认的任务处理器""" self._task_handlers = { "analyze": self._handle_analyze_task, @@ -395,11 +343,11 @@ class WorkflowManager: "custom": self._handle_custom_task, } - def register_task_handler(self, task_type: str, handler: Callable): + def register_task_handler(self, task_type: str, handler: Callable) -> None: """注册自定义任务处理器""" self._task_handlers[task_type] = handler - def start(self): + def start(self) -> None: """启动工作流管理器""" if not self.scheduler.running: self.scheduler.start() @@ -409,7 +357,7 @@ class WorkflowManager: if self.db: asyncio.create_task(self._load_and_schedule_workflows()) - def stop(self): + def stop(self) -> None: """停止工作流管理器""" if self.scheduler.running: self.scheduler.shutdown(wait=True) @@ -425,7 +373,7 @@ class WorkflowManager: except (TimeoutError, httpx.HTTPError) as e: logger.error(f"Failed to load workflows: {e}") - def _schedule_workflow(self, workflow: Workflow): + def _schedule_workflow(self, workflow: Workflow) -> None: """调度工作流""" job_id = f"workflow_{workflow.id}" @@ -450,7 +398,7 @@ class WorkflowManager: args=[workflow.id], replace_existing=True, max_instances=1, - coalesce=True + coalesce=True, ) logger.info(f"Scheduled workflow {workflow.id} ({workflow.name}) with {workflow.schedule_type}") @@ -462,7 +410,7 @@ class WorkflowManager: except (TimeoutError, httpx.HTTPError) as e: logger.error(f"Scheduled workflow execution failed: {e}") - def _on_job_executed(self, event): + def _on_job_executed(self, event) -> None: """调度器事件处理""" if event.exception: logger.error(f"Job {event.job_id} failed: {event.exception}") @@ -482,11 +430,26 @@ class WorkflowManager: created_at, updated_at, last_run_at, next_run_at, run_count, success_count, fail_count) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", - (workflow.id, workflow.name, workflow.description, workflow.workflow_type, - workflow.project_id, workflow.status, workflow.schedule, workflow.schedule_type, - json.dumps(workflow.config), json.dumps(workflow.webhook_ids), workflow.is_active, - workflow.created_at, workflow.updated_at, workflow.last_run_at, workflow.next_run_at, - workflow.run_count, workflow.success_count, workflow.fail_count) + ( + workflow.id, + workflow.name, + workflow.description, + workflow.workflow_type, + workflow.project_id, + workflow.status, + workflow.schedule, + workflow.schedule_type, + json.dumps(workflow.config), + json.dumps(workflow.webhook_ids), + workflow.is_active, + workflow.created_at, + workflow.updated_at, + workflow.last_run_at, + workflow.next_run_at, + workflow.run_count, + workflow.success_count, + workflow.fail_count, + ), ) conn.commit() @@ -502,10 +465,7 @@ class WorkflowManager: """获取工作流""" conn = self.db.get_conn() try: - row = conn.execute( - "SELECT * FROM workflows WHERE id = ?", - (workflow_id,) - ).fetchone() + row = conn.execute("SELECT * FROM workflows WHERE id = ?", (workflow_id,)).fetchone() if not row: return None @@ -514,8 +474,7 @@ class WorkflowManager: finally: conn.close() - def list_workflows(self, project_id: str = None, status: str = None, - workflow_type: str = None) -> list[Workflow]: + def list_workflows(self, project_id: str = None, status: str = None, workflow_type: str = None) -> list[Workflow]: """列出工作流""" conn = self.db.get_conn() try: @@ -535,8 +494,7 @@ class WorkflowManager: where_clause = " AND ".join(conditions) if conditions else "1=1" rows = conn.execute( - f"SELECT * FROM workflows WHERE {where_clause} ORDER BY created_at DESC", - params + f"SELECT * FROM workflows WHERE {where_clause} ORDER BY created_at DESC", params ).fetchall() return [self._row_to_workflow(row) for row in rows] @@ -547,15 +505,23 @@ class WorkflowManager: """更新工作流""" conn = self.db.get_conn() try: - allowed_fields = ['name', 'description', 'status', 'schedule', - 'schedule_type', 'is_active', 'config', 'webhook_ids'] + allowed_fields = [ + "name", + "description", + "status", + "schedule", + "schedule_type", + "is_active", + "config", + "webhook_ids", + ] updates = [] values = [] for f in allowed_fields: if f in kwargs: updates.append(f"{f} = ?") - if f in ['config', 'webhook_ids']: + if f in ["config", "webhook_ids"]: values.append(json.dumps(kwargs[f])) else: values.append(kwargs[f]) @@ -607,24 +573,24 @@ class WorkflowManager: def _row_to_workflow(self, row) -> Workflow: """将数据库行转换为 Workflow 对象""" return Workflow( - id=row['id'], - name=row['name'], - description=row['description'] or "", - workflow_type=row['workflow_type'], - project_id=row['project_id'], - status=row['status'], - schedule=row['schedule'], - schedule_type=row['schedule_type'], - config=json.loads(row['config']) if row['config'] else {}, - webhook_ids=json.loads(row['webhook_ids']) if row['webhook_ids'] else [], - is_active=bool(row['is_active']), - created_at=row['created_at'], - updated_at=row['updated_at'], - last_run_at=row['last_run_at'], - next_run_at=row['next_run_at'], - run_count=row['run_count'] or 0, - success_count=row['success_count'] or 0, - fail_count=row['fail_count'] or 0 + id=row["id"], + name=row["name"], + description=row["description"] or "", + workflow_type=row["workflow_type"], + project_id=row["project_id"], + status=row["status"], + schedule=row["schedule"], + schedule_type=row["schedule_type"], + config=json.loads(row["config"]) if row["config"] else {}, + webhook_ids=json.loads(row["webhook_ids"]) if row["webhook_ids"] else [], + is_active=bool(row["is_active"]), + created_at=row["created_at"], + updated_at=row["updated_at"], + last_run_at=row["last_run_at"], + next_run_at=row["next_run_at"], + run_count=row["run_count"] or 0, + success_count=row["success_count"] or 0, + fail_count=row["fail_count"] or 0, ) # ==================== Workflow Task CRUD ==================== @@ -639,10 +605,20 @@ class WorkflowManager: depends_on, timeout_seconds, retry_count, retry_delay, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", - (task.id, task.workflow_id, task.name, task.task_type, - json.dumps(task.config), task.order, json.dumps(task.depends_on), - task.timeout_seconds, task.retry_count, task.retry_delay, - task.created_at, task.updated_at) + ( + task.id, + task.workflow_id, + task.name, + task.task_type, + json.dumps(task.config), + task.order, + json.dumps(task.depends_on), + task.timeout_seconds, + task.retry_count, + task.retry_delay, + task.created_at, + task.updated_at, + ), ) conn.commit() return task @@ -653,10 +629,7 @@ class WorkflowManager: """获取任务""" conn = self.db.get_conn() try: - row = conn.execute( - "SELECT * FROM workflow_tasks WHERE id = ?", - (task_id,) - ).fetchone() + row = conn.execute("SELECT * FROM workflow_tasks WHERE id = ?", (task_id,)).fetchone() if not row: return None @@ -670,8 +643,7 @@ class WorkflowManager: conn = self.db.get_conn() try: rows = conn.execute( - "SELECT * FROM workflow_tasks WHERE workflow_id = ? ORDER BY task_order", - (workflow_id,) + "SELECT * FROM workflow_tasks WHERE workflow_id = ? ORDER BY task_order", (workflow_id,) ).fetchall() return [self._row_to_task(row) for row in rows] @@ -682,15 +654,23 @@ class WorkflowManager: """更新任务""" conn = self.db.get_conn() try: - allowed_fields = ['name', 'task_type', 'config', 'task_order', - 'depends_on', 'timeout_seconds', 'retry_count', 'retry_delay'] + allowed_fields = [ + "name", + "task_type", + "config", + "task_order", + "depends_on", + "timeout_seconds", + "retry_count", + "retry_delay", + ] updates = [] values = [] for f in allowed_fields: if f in kwargs: updates.append(f"{f} = ?") - if f in ['config', 'depends_on']: + if f in ["config", "depends_on"]: values.append(json.dumps(kwargs[f])) else: values.append(kwargs[f]) @@ -723,18 +703,18 @@ class WorkflowManager: def _row_to_task(self, row) -> WorkflowTask: """将数据库行转换为 WorkflowTask 对象""" return WorkflowTask( - id=row['id'], - workflow_id=row['workflow_id'], - name=row['name'], - task_type=row['task_type'], - config=json.loads(row['config']) if row['config'] else {}, - order=row['task_order'] or 0, - depends_on=json.loads(row['depends_on']) if row['depends_on'] else [], - timeout_seconds=row['timeout_seconds'] or 300, - retry_count=row['retry_count'] or 3, - retry_delay=row['retry_delay'] or 5, - created_at=row['created_at'], - updated_at=row['updated_at'] + id=row["id"], + workflow_id=row["workflow_id"], + name=row["name"], + task_type=row["task_type"], + config=json.loads(row["config"]) if row["config"] else {}, + order=row["task_order"] or 0, + depends_on=json.loads(row["depends_on"]) if row["depends_on"] else [], + timeout_seconds=row["timeout_seconds"] or 300, + retry_count=row["retry_count"] or 3, + retry_delay=row["retry_delay"] or 5, + created_at=row["created_at"], + updated_at=row["updated_at"], ) # ==================== Webhook Config CRUD ==================== @@ -749,10 +729,21 @@ class WorkflowManager: is_active, created_at, updated_at, last_used_at, success_count, fail_count) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", - (webhook.id, webhook.name, webhook.webhook_type, webhook.url, - webhook.secret, json.dumps(webhook.headers), webhook.template, - webhook.is_active, webhook.created_at, webhook.updated_at, - webhook.last_used_at, webhook.success_count, webhook.fail_count) + ( + webhook.id, + webhook.name, + webhook.webhook_type, + webhook.url, + webhook.secret, + json.dumps(webhook.headers), + webhook.template, + webhook.is_active, + webhook.created_at, + webhook.updated_at, + webhook.last_used_at, + webhook.success_count, + webhook.fail_count, + ), ) conn.commit() return webhook @@ -763,10 +754,7 @@ class WorkflowManager: """获取 Webhook 配置""" conn = self.db.get_conn() try: - row = conn.execute( - "SELECT * FROM webhook_configs WHERE id = ?", - (webhook_id,) - ).fetchone() + row = conn.execute("SELECT * FROM webhook_configs WHERE id = ?", (webhook_id,)).fetchone() if not row: return None @@ -779,9 +767,7 @@ class WorkflowManager: """列出所有 Webhook 配置""" conn = self.db.get_conn() try: - rows = conn.execute( - "SELECT * FROM webhook_configs ORDER BY created_at DESC" - ).fetchall() + rows = conn.execute("SELECT * FROM webhook_configs ORDER BY created_at DESC").fetchall() return [self._row_to_webhook(row) for row in rows] finally: @@ -791,15 +777,14 @@ class WorkflowManager: """更新 Webhook 配置""" conn = self.db.get_conn() try: - allowed_fields = ['name', 'webhook_type', 'url', 'secret', - 'headers', 'template', 'is_active'] + allowed_fields = ["name", "webhook_type", "url", "secret", "headers", "template", "is_active"] updates = [] values = [] for f in allowed_fields: if f in kwargs: updates.append(f"{f} = ?") - if f == 'headers': + if f == "headers": values.append(json.dumps(kwargs[f])) else: values.append(kwargs[f]) @@ -829,7 +814,7 @@ class WorkflowManager: finally: conn.close() - def update_webhook_stats(self, webhook_id: str, success: bool): + def update_webhook_stats(self, webhook_id: str, success: bool) -> None: """更新 Webhook 统计""" conn = self.db.get_conn() try: @@ -838,14 +823,14 @@ class WorkflowManager: """UPDATE webhook_configs SET success_count = success_count + 1, last_used_at = ? WHERE id = ?""", - (datetime.now().isoformat(), webhook_id) + (datetime.now().isoformat(), webhook_id), ) else: conn.execute( """UPDATE webhook_configs SET fail_count = fail_count + 1, last_used_at = ? WHERE id = ?""", - (datetime.now().isoformat(), webhook_id) + (datetime.now().isoformat(), webhook_id), ) conn.commit() finally: @@ -854,19 +839,19 @@ class WorkflowManager: def _row_to_webhook(self, row) -> WebhookConfig: """将数据库行转换为 WebhookConfig 对象""" return WebhookConfig( - id=row['id'], - name=row['name'], - webhook_type=row['webhook_type'], - url=row['url'], - secret=row['secret'] or "", - headers=json.loads(row['headers']) if row['headers'] else {}, - template=row['template'] or "", - is_active=bool(row['is_active']), - created_at=row['created_at'], - updated_at=row['updated_at'], - last_used_at=row['last_used_at'], - success_count=row['success_count'] or 0, - fail_count=row['fail_count'] or 0 + id=row["id"], + name=row["name"], + webhook_type=row["webhook_type"], + url=row["url"], + secret=row["secret"] or "", + headers=json.loads(row["headers"]) if row["headers"] else {}, + template=row["template"] or "", + is_active=bool(row["is_active"]), + created_at=row["created_at"], + updated_at=row["updated_at"], + last_used_at=row["last_used_at"], + success_count=row["success_count"] or 0, + fail_count=row["fail_count"] or 0, ) # ==================== Workflow Log ==================== @@ -880,10 +865,19 @@ class WorkflowManager: (id, workflow_id, task_id, status, start_time, end_time, duration_ms, input_data, output_data, error_message, created_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", - (log.id, log.workflow_id, log.task_id, log.status, - log.start_time, log.end_time, log.duration_ms, - json.dumps(log.input_data), json.dumps(log.output_data), - log.error_message, log.created_at) + ( + log.id, + log.workflow_id, + log.task_id, + log.status, + log.start_time, + log.end_time, + log.duration_ms, + json.dumps(log.input_data), + json.dumps(log.output_data), + log.error_message, + log.created_at, + ), ) conn.commit() return log @@ -894,15 +888,14 @@ class WorkflowManager: """更新工作流日志""" conn = self.db.get_conn() try: - allowed_fields = ['status', 'end_time', 'duration_ms', - 'output_data', 'error_message'] + allowed_fields = ["status", "end_time", "duration_ms", "output_data", "error_message"] updates = [] values = [] for f in allowed_fields: if f in kwargs: updates.append(f"{f} = ?") - if f == 'output_data': + if f == "output_data": values.append(json.dumps(kwargs[f])) else: values.append(kwargs[f]) @@ -923,10 +916,7 @@ class WorkflowManager: """获取日志""" conn = self.db.get_conn() try: - row = conn.execute( - "SELECT * FROM workflow_logs WHERE id = ?", - (log_id,) - ).fetchone() + row = conn.execute("SELECT * FROM workflow_logs WHERE id = ?", (log_id,)).fetchone() if not row: return None @@ -935,8 +925,9 @@ class WorkflowManager: finally: conn.close() - def list_logs(self, workflow_id: str = None, task_id: str = None, - status: str = None, limit: int = 100, offset: int = 0) -> list[WorkflowLog]: + def list_logs( + self, workflow_id: str = None, task_id: str = None, status: str = None, limit: int = 100, offset: int = 0 + ) -> list[WorkflowLog]: """列出工作流日志""" conn = self.db.get_conn() try: @@ -960,7 +951,7 @@ class WorkflowManager: WHERE {where_clause} ORDER BY created_at DESC LIMIT ? OFFSET ?""", - params + [limit, offset] + params + [limit, offset], ).fetchall() return [self._row_to_log(row) for row in rows] @@ -975,27 +966,29 @@ class WorkflowManager: # 总执行次数 total = conn.execute( - "SELECT COUNT(*) FROM workflow_logs WHERE workflow_id = ? AND created_at > ?", - (workflow_id, since) + "SELECT COUNT(*) FROM workflow_logs WHERE workflow_id = ? AND created_at > ?", (workflow_id, since) ).fetchone()[0] # 成功次数 success = conn.execute( "SELECT COUNT(*) FROM workflow_logs WHERE workflow_id = ? AND status = 'success' AND created_at > ?", - (workflow_id, since) + (workflow_id, since), ).fetchone()[0] # 失败次数 failed = conn.execute( "SELECT COUNT(*) FROM workflow_logs WHERE workflow_id = ? AND status = 'failed' AND created_at > ?", - (workflow_id, since) + (workflow_id, since), ).fetchone()[0] # 平均执行时间 - avg_duration = conn.execute( - "SELECT AVG(duration_ms) FROM workflow_logs WHERE workflow_id = ? AND created_at > ?", - (workflow_id, since) - ).fetchone()[0] or 0 + avg_duration = ( + conn.execute( + "SELECT AVG(duration_ms) FROM workflow_logs WHERE workflow_id = ? AND created_at > ?", + (workflow_id, since), + ).fetchone()[0] + or 0 + ) # 每日统计 daily = conn.execute( @@ -1006,7 +999,7 @@ class WorkflowManager: WHERE workflow_id = ? AND created_at > ? GROUP BY DATE(created_at) ORDER BY date""", - (workflow_id, since) + (workflow_id, since), ).fetchall() return { @@ -1015,7 +1008,7 @@ class WorkflowManager: "failed": failed, "success_rate": round(success / total * 100, 2) if total > 0 else 0, "avg_duration_ms": round(avg_duration, 2), - "daily": [{"date": r["date"], "count": r["count"], "success": r["success"]} for r in daily] + "daily": [{"date": r["date"], "count": r["count"], "success": r["success"]} for r in daily], } finally: conn.close() @@ -1023,17 +1016,17 @@ class WorkflowManager: def _row_to_log(self, row) -> WorkflowLog: """将数据库行转换为 WorkflowLog 对象""" return WorkflowLog( - id=row['id'], - workflow_id=row['workflow_id'], - task_id=row['task_id'], - status=row['status'], - start_time=row['start_time'], - end_time=row['end_time'], - duration_ms=row['duration_ms'] or 0, - input_data=json.loads(row['input_data']) if row['input_data'] else {}, - output_data=json.loads(row['output_data']) if row['output_data'] else {}, - error_message=row['error_message'] or "", - created_at=row['created_at'] + id=row["id"], + workflow_id=row["workflow_id"], + task_id=row["task_id"], + status=row["status"], + start_time=row["start_time"], + end_time=row["end_time"], + duration_ms=row["duration_ms"] or 0, + input_data=json.loads(row["input_data"]) if row["input_data"] else {}, + output_data=json.loads(row["output_data"]) if row["output_data"] else {}, + error_message=row["error_message"] or "", + created_at=row["created_at"], ) # ==================== Workflow Execution ==================== @@ -1049,8 +1042,7 @@ class WorkflowManager: # 更新最后运行时间 now = datetime.now().isoformat() - self.update_workflow(workflow_id, last_run_at=now, - run_count=workflow.run_count + 1) + self.update_workflow(workflow_id, last_run_at=now, run_count=workflow.run_count + 1) # 创建工作流执行日志 log = WorkflowLog( @@ -1058,7 +1050,7 @@ class WorkflowManager: workflow_id=workflow_id, status=TaskStatus.RUNNING.value, start_time=now, - input_data=input_data or {} + input_data=input_data or {}, ) self.create_log(log) @@ -1087,7 +1079,7 @@ class WorkflowManager: status=TaskStatus.SUCCESS.value, end_time=end_time.isoformat(), duration_ms=duration, - output_data=results + output_data=results, ) # 更新成功计数 @@ -1098,7 +1090,7 @@ class WorkflowManager: "workflow_id": workflow_id, "log_id": log.id, "results": results, - "duration_ms": duration + "duration_ms": duration, } except (TimeoutError, httpx.HTTPError) as e: @@ -1112,7 +1104,7 @@ class WorkflowManager: status=TaskStatus.FAILED.value, end_time=end_time.isoformat(), duration_ms=duration, - error_message=str(e) + error_message=str(e), ) # 更新失败计数 @@ -1123,8 +1115,7 @@ class WorkflowManager: raise - async def _execute_tasks_with_deps(self, tasks: list[WorkflowTask], - input_data: dict, log_id: str) -> dict: + async def _execute_tasks_with_deps(self, tasks: list[WorkflowTask], input_data: dict, log_id: str) -> dict: """按依赖顺序执行任务""" results = {} completed_tasks = set() @@ -1132,9 +1123,7 @@ class WorkflowManager: while len(completed_tasks) < len(tasks): # 找到可以执行的任务(依赖已完成) ready_tasks = [ - t for t in tasks - if t.id not in completed_tasks - and all(dep in completed_tasks for dep in t.depends_on) + t for t in tasks if t.id not in completed_tasks and all(dep in completed_tasks for dep in t.depends_on) ] if not ready_tasks: @@ -1171,8 +1160,7 @@ class WorkflowManager: return results - async def _execute_single_task(self, task: WorkflowTask, - input_data: dict, log_id: str) -> Any: + async def _execute_single_task(self, task: WorkflowTask, input_data: dict, log_id: str) -> Any: """执行单个任务""" handler = self._task_handlers.get(task.task_type) if not handler: @@ -1185,23 +1173,20 @@ class WorkflowManager: task_id=task.id, status=TaskStatus.RUNNING.value, start_time=datetime.now().isoformat(), - input_data=input_data + input_data=input_data, ) self.create_log(task_log) try: # 设置超时 - result = await asyncio.wait_for( - handler(task, input_data), - timeout=task.timeout_seconds - ) + result = await asyncio.wait_for(handler(task, input_data), timeout=task.timeout_seconds) # 更新任务日志为成功 self.update_log( task_log.id, status=TaskStatus.SUCCESS.value, end_time=datetime.now().isoformat(), - output_data={"result": result} if not isinstance(result, dict) else result + output_data={"result": result} if not isinstance(result, dict) else result, ) return result @@ -1211,21 +1196,17 @@ class WorkflowManager: task_log.id, status=TaskStatus.FAILED.value, end_time=datetime.now().isoformat(), - error_message="Task timeout" + error_message="Task timeout", ) raise TimeoutError(f"Task {task.id} timed out after {task.timeout_seconds}s") except Exception as e: self.update_log( - task_log.id, - status=TaskStatus.FAILED.value, - end_time=datetime.now().isoformat(), - error_message=str(e) + task_log.id, status=TaskStatus.FAILED.value, end_time=datetime.now().isoformat(), error_message=str(e) ) raise - async def _execute_default_workflow(self, workflow: Workflow, - input_data: dict) -> dict: + async def _execute_default_workflow(self, workflow: Workflow, input_data: dict) -> dict: """执行默认工作流(根据类型)""" workflow_type = WorkflowType(workflow.workflow_type) @@ -1252,12 +1233,7 @@ class WorkflowManager: # 这里调用现有的文件分析逻辑 # 实际实现需要与 main.py 中的 upload_audio 逻辑集成 - return { - "task": "analyze", - "project_id": project_id, - "files_processed": len(file_ids), - "status": "completed" - } + return {"task": "analyze", "project_id": project_id, "files_processed": len(file_ids), "status": "completed"} async def _handle_align_task(self, task: WorkflowTask, input_data: dict) -> dict: """处理实体对齐任务""" @@ -1273,11 +1249,10 @@ class WorkflowManager: "project_id": project_id, "threshold": threshold, "entities_merged": 0, # 实际实现需要调用对齐逻辑 - "status": "completed" + "status": "completed", } - async def _handle_discover_relations_task(self, task: WorkflowTask, - input_data: dict) -> dict: + async def _handle_discover_relations_task(self, task: WorkflowTask, input_data: dict) -> dict: """处理关系发现任务""" project_id = input_data.get("project_id") @@ -1289,7 +1264,7 @@ class WorkflowManager: "task": "discover_relations", "project_id": project_id, "relations_found": 0, # 实际实现需要调用关系发现逻辑 - "status": "completed" + "status": "completed", } async def _handle_notify_task(self, task: WorkflowTask, input_data: dict) -> dict: @@ -1314,21 +1289,12 @@ class WorkflowManager: success = await self.notifier.send(webhook, message) self.update_webhook_stats(webhook_id, success) - return { - "task": "notify", - "webhook_id": webhook_id, - "success": success - } + return {"task": "notify", "webhook_id": webhook_id, "success": success} async def _handle_custom_task(self, task: WorkflowTask, input_data: dict) -> dict: """处理自定义任务""" # 自定义任务的具体逻辑由外部处理器实现 - return { - "task": "custom", - "task_name": task.name, - "config": task.config, - "status": "completed" - } + return {"task": "custom", "task_name": task.name, "config": task.config, "status": "completed"} # ==================== Default Workflow Implementations ==================== @@ -1344,7 +1310,7 @@ class WorkflowManager: "files_analyzed": 0, "entities_extracted": 0, "relations_extracted": 0, - "status": "completed" + "status": "completed", } async def _auto_align_entities(self, workflow: Workflow, input_data: dict) -> dict: @@ -1357,7 +1323,7 @@ class WorkflowManager: "project_id": project_id, "threshold": threshold, "entities_merged": 0, - "status": "completed" + "status": "completed", } async def _auto_discover_relations(self, workflow: Workflow, input_data: dict) -> dict: @@ -1368,7 +1334,7 @@ class WorkflowManager: "workflow_type": "auto_relation", "project_id": project_id, "relations_discovered": 0, - "status": "completed" + "status": "completed", } async def _generate_scheduled_report(self, workflow: Workflow, input_data: dict) -> dict: @@ -1380,13 +1346,12 @@ class WorkflowManager: "workflow_type": "scheduled_report", "project_id": project_id, "report_type": report_type, - "status": "completed" + "status": "completed", } # ==================== Notification ==================== - async def _send_workflow_notification(self, workflow: Workflow, - results: dict, success: bool = True): + async def _send_workflow_notification(self, workflow: Workflow, results: dict, success: bool = True): """发送工作流执行通知""" if not workflow.webhook_ids: return @@ -1409,7 +1374,7 @@ class WorkflowManager: "workflow_name": workflow.name, "status": "success" if success else "failed", "results": results, - "timestamp": datetime.now().isoformat() + "timestamp": datetime.now().isoformat(), } try: @@ -1418,8 +1383,7 @@ class WorkflowManager: except (TimeoutError, httpx.HTTPError) as e: logger.error(f"Failed to send notification to {webhook_id}: {e}") - def _build_feishu_message(self, workflow: Workflow, results: dict, - success: bool) -> dict: + def _build_feishu_message(self, workflow: Workflow, results: dict, success: bool) -> dict: """构建飞书消息""" status_text = "✅ 成功" if success else "❌ 失败" @@ -1429,11 +1393,10 @@ class WorkflowManager: [{"tag": "text", "text": f"工作流: {workflow.name}"}], [{"tag": "text", "text": f"状态: {status_text}"}], [{"tag": "text", "text": f"时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"}], - ] + ], } - def _build_dingtalk_message(self, workflow: Workflow, results: dict, - success: bool) -> dict: + def _build_dingtalk_message(self, workflow: Workflow, results: dict, success: bool) -> dict: """构建钉钉消息""" status_text = "✅ 成功" if success else "❌ 失败" @@ -1451,11 +1414,10 @@ class WorkflowManager: ```json {json.dumps(results, ensure_ascii=False, indent=2)} ``` -""" +""", } - def _build_slack_message(self, workflow: Workflow, results: dict, - success: bool) -> dict: + def _build_slack_message(self, workflow: Workflow, results: dict, success: bool) -> dict: """构建 Slack 消息""" color = "#36a64f" if success else "#ff0000" status_text = "Success" if success else "Failed" @@ -1467,10 +1429,10 @@ class WorkflowManager: "title": f"Workflow Execution: {workflow.name}", "fields": [ {"title": "Status", "value": status_text, "short": True}, - {"title": "Time", "value": datetime.now().strftime('%Y-%m-%d %H:%M:%S'), "short": True} + {"title": "Time", "value": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), "short": True}, ], "footer": "InsightFlow", - "ts": int(datetime.now().timestamp()) + "ts": int(datetime.now().timestamp()), } ] }