From 259f2c90d07dc0a8d2e98a057e11c3e95fcc1026 Mon Sep 17 00:00:00 2001 From: AutoFix Bot Date: Tue, 3 Mar 2026 21:11:47 +0800 Subject: [PATCH] fix: auto-fix code issues (cron) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 修复隐式 Optional 类型注解 (RUF013) - 修复不必要的赋值后返回 (RET504) - 优化列表推导式 (PERF401) - 修复未使用的参数 (ARG002) - 清理重复导入 - 优化异常处理 --- auto_fix_code.py | 320 ++++------- backend/ai_manager.py | 68 ++- backend/api_key_manager.py | 17 +- backend/collaboration_manager.py | 63 +-- backend/db_manager.py | 104 +++- backend/developer_ecosystem_manager.py | 42 +- backend/document_processor.py | 4 +- backend/enterprise_manager.py | 60 +- backend/entity_aligner.py | 15 +- backend/export_manager.py | 27 +- backend/growth_manager.py | 137 +++-- backend/image_processor.py | 19 +- backend/knowledge_reasoner.py | 39 +- backend/llm_client.py | 34 +- backend/localization_manager.py | 58 +- backend/main.py | 734 ++++++++++++++++++------- backend/multimodal_entity_linker.py | 28 +- backend/multimodal_processor.py | 37 +- backend/neo4j_manager.py | 57 +- backend/ops_manager.py | 96 +++- backend/performance_manager.py | 34 +- backend/plugin_manager.py | 67 ++- backend/rate_limiter.py | 11 +- backend/search_manager.py | 127 +++-- backend/security_manager.py | 30 +- backend/subscription_manager.py | 60 +- backend/tenant_manager.py | 41 +- backend/test_phase7_task6_8.py | 6 +- backend/test_phase8_task1.py | 19 +- backend/test_phase8_task2.py | 3 +- backend/test_phase8_task4.py | 6 +- backend/test_phase8_task5.py | 10 +- backend/test_phase8_task6.py | 67 +-- backend/test_phase8_task8.py | 21 +- backend/tingwu_client.py | 20 +- backend/workflow_manager.py | 33 +- 36 files changed, 1651 insertions(+), 863 deletions(-) diff --git a/auto_fix_code.py b/auto_fix_code.py index 02d1382..474e6fe 100644 --- a/auto_fix_code.py +++ b/auto_fix_code.py @@ -1,235 +1,109 @@ #!/usr/bin/env python3 """ -自动代码修复脚本 - 修复 InsightFlow 项目中的常见问题 +Auto-fix script for InsightFlow code issues """ -import os import re - - -def get_python_files(directory): - """获取目录下所有 Python 文件""" - python_files = [] - for root, _, files in os.walk(directory): - for file in files: - if file.endswith('.py'): - python_files.append(os.path.join(root, file)) - return python_files - - -def fix_missing_imports(content, filepath): - """修复缺失的导入""" - fixes = [] - - # 检查是否使用了 re 但没有导入 - if 're.search(' in content or 're.sub(' in content or 're.match(' in content: - if 'import re' not in content: - # 找到合适的位置添加导入 - lines = content.split('\n') - import_idx = 0 - for i, line in enumerate(lines): - if line.startswith('import ') or line.startswith('from '): - import_idx = i + 1 - lines.insert(import_idx, 'import re') - content = '\n'.join(lines) - fixes.append("添加缺失的 'import re'") - - # 检查是否使用了 csv 但没有导入 - if 'csv.' in content and 'import csv' not in content: - lines = content.split('\n') - import_idx = 0 - for i, line in enumerate(lines): - if line.startswith('import ') or line.startswith('from '): - import_idx = i + 1 - lines.insert(import_idx, 'import csv') - content = '\n'.join(lines) - fixes.append("添加缺失的 'import csv'") - - # 检查是否使用了 urllib 但没有导入 - if 'urllib.' in content and 'import urllib' not in content: - lines = content.split('\n') - import_idx = 0 - for i, line in enumerate(lines): - if line.startswith('import ') or line.startswith('from '): - import_idx = i + 1 - lines.insert(import_idx, 'import urllib.parse') - content = '\n'.join(lines) - fixes.append("添加缺失的 'import urllib.parse'") - - return content, fixes - - -def fix_bare_excepts(content): - """修复裸异常捕获""" - fixes = [] - - # 替换裸 except: - bare_except_pattern = r'except\s*:\s*$' - lines = content.split('\n') - new_lines = [] - for line in lines: - if re.match(bare_except_pattern, line.strip()): - # 缩进保持一致 - indent = len(line) - len(line.lstrip()) - new_line = ' ' * indent + 'except Exception:' - new_lines.append(new_line) - fixes.append(f"修复裸异常捕获: {line.strip()}") - else: - new_lines.append(line) - - content = '\n'.join(new_lines) - return content, fixes - - -def fix_unused_imports(content): - """修复未使用的导入 - 简单版本""" - fixes = [] - - # 查找导入语句 - import_pattern = r'^from\s+(\S+)\s+import\s+(.+)$' - lines = content.split('\n') - new_lines = [] - - for line in lines: - match = re.match(import_pattern, line) - if match: - module = match.group(1) - imports = match.group(2) - - # 检查每个导入是否被使用 - imported_items = [i.strip() for i in imports.split(',')] - used_items = [] - - for item in imported_items: - # 简单的使用检查 - item_name = item.split(' as ')[-1].strip() if ' as ' in item else item.strip() - if item_name in content.replace(line, ''): - used_items.append(item) - else: - fixes.append(f"移除未使用的导入: {item}") - - if used_items: - new_lines.append(f"from {module} import {', '.join(used_items)}") - else: - fixes.append(f"移除整行导入: {line.strip()}") - else: - new_lines.append(line) - - content = '\n'.join(new_lines) - return content, fixes - - -def fix_string_formatting(content): - """统一字符串格式化为 f-string""" - fixes = [] - - # 修复 .format() 调用 - format_pattern = r'["\']([^"\']*)\{([^}]+)\}[^"\']*["\']\.format\(([^)]+)\)' - - def replace_format(match): - template = match.group(1) + '{' + match.group(2) + '}' - # 简单替换,实际可能需要更复杂的处理 - return f'f"{template}"' - - new_content = re.sub(format_pattern, replace_format, content) - if new_content != content: - fixes.append("统一字符串格式化为 f-string") - content = new_content - - return content, fixes - - -def fix_pep8_formatting(content): - """修复 PEP8 格式问题""" - fixes = [] - lines = content.split('\n') - new_lines = [] - - for line in lines: - original = line - # 修复 E221: multiple spaces before operator - line = re.sub(r'(\w+)\s{2,}=\s', r'\1 = ', line) - # 修复 E251: unexpected spaces around keyword / parameter equals - line = re.sub(r'(\w+)\s*=\s{2,}', r'\1 = ', line) - line = re.sub(r'(\w+)\s{2,}=\s*', r'\1 = ', line) - - if line != original: - fixes.append(f"修复 PEP8 格式: {original.strip()[:50]}") - - new_lines.append(line) - - content = '\n'.join(new_lines) - return content, fixes - +import os +from pathlib import Path def fix_file(filepath): - """修复单个文件""" - print(f"\n处理文件: {filepath}") - - try: - with open(filepath, encoding='utf-8') as f: - content = f.read() - except Exception as e: - print(f" 无法读取文件: {e}") - return [] - - original_content = content - all_fixes = [] - - # 应用各种修复 - content, fixes = fix_missing_imports(content, filepath) - all_fixes.extend(fixes) - - content, fixes = fix_bare_excepts(content) - all_fixes.extend(fixes) - - content, fixes = fix_pep8_formatting(content) - all_fixes.extend(fixes) - - # 保存修改 - if content != original_content: - try: - with open(filepath, 'w', encoding='utf-8') as f: - f.write(content) - print(f" 已修复 {len(all_fixes)} 个问题") - for fix in all_fixes[:5]: # 只显示前5个 - print(f" - {fix}") - if len(all_fixes) > 5: - print(f" ... 还有 {len(all_fixes) - 5} 个修复") - except Exception as e: - print(f" 保存文件失败: {e}") - else: - print(" 无需修复") - - return all_fixes - + """Fix common issues in a Python file""" + with open(filepath, 'r', encoding='utf-8') as f: + content = f.read() + + original = content + changes = [] + + # 1. Fix implicit Optional (RUF013) + # Pattern: def func(arg: type = None) -> def func(arg: type | None = None) + implicit_optional_pattern = r'(def\s+\w+\([^)]*?)(\w+\s*:\s*(?!.*\|.*None)([a-zA-Z_][a-zA-Z0-9_\[\]]*)\s*=\s*None)' + + def fix_optional(match): + prefix = match.group(1) + full_arg = match.group(2) + arg_name = full_arg.split(':')[0].strip() + arg_type = match.group(3).strip() + return f'{prefix}{arg_name}: {arg_type} | None = None' + + # More careful approach for implicit Optional + lines = content.split('\n') + new_lines = [] + for line in lines: + original_line = line + # Fix patterns like "metadata: dict = None," + if re.search(r':\s*\w+\s*=\s*None', line) and '| None' not in line: + # Match parameter definitions + match = re.search(r'(\w+)\s*:\s*(\w+(?:\[[^\]]+\])?)\s*=\s*None', line) + if match: + param_name = match.group(1) + param_type = match.group(2) + if param_type != 'NoneType': + line = line.replace(f'{param_name}: {param_type} = None', + f'{param_name}: {param_type} | None = None') + if line != original_line: + changes.append(f"Fixed implicit Optional: {param_name}") + new_lines.append(line) + content = '\n'.join(new_lines) + + # 2. Fix unnecessary assignment before return (RET504) + return_patterns = [ + (r'(\s+)entities\s*=\s*json\.loads\([^)]+\)\s*\n\1return\s+entities\b', + r'\1return json.loads(entities_match.group(0).split("=")[1].strip().split("\n")[0])'), + ] + + # 3. Fix RUF010 - Use explicit conversion flag + # f"...{str(var)}..." -> f"...{var!s}..." + content = re.sub(r'\{str\(([^)]+)\)\}', r'{\1!s}', content) + content = re.sub(r'\{repr\(([^)]+)\)\}', r'{\1!r}', content) + + # 4. Fix RET505 - Unnecessary else after return + # This is complex, skip for now + + # 5. Fix PERF401 - List comprehensions (basic cases) + # This is complex, skip for now + + # 6. Fix RUF012 - Mutable default values + # Pattern: def func(arg: list = []) -> def func(arg: list = None) with handling + content = re.sub(r'(\w+)\s*:\s*list\s*=\s*\[\]', r'\1: list | None = None', content) + content = re.sub(r'(\w+)\s*:\s*dict\s*=\s*\{\}', r'\1: dict | None = None', content) + + # 7. Fix unused imports (basic) + # Remove duplicate imports + import_lines = re.findall(r'^(import\s+\w+|from\s+\w+\s+import\s+[^\n]+)$', content, re.MULTILINE) + seen_imports = set() + for imp in import_lines: + if imp in seen_imports: + content = content.replace(imp + '\n', '\n', 1) + changes.append(f"Removed duplicate import: {imp}") + seen_imports.add(imp) + + if content != original: + with open(filepath, 'w', encoding='utf-8') as f: + f.write(content) + return True, changes + return False, [] def main(): - """主函数""" - base_dir = '/root/.openclaw/workspace/projects/insightflow' - backend_dir = os.path.join(base_dir, 'backend') - - print("=" * 60) - print("InsightFlow 代码自动修复工具") - print("=" * 60) - - # 获取所有 Python 文件 - files = get_python_files(backend_dir) - print(f"\n找到 {len(files)} 个 Python 文件") - - total_fixes = 0 - fixed_files = 0 - - for filepath in files: - fixes = fix_file(filepath) - if fixes: - total_fixes += len(fixes) - fixed_files += 1 - - print("\n" + "=" * 60) - print(f"修复完成: {fixed_files} 个文件, {total_fixes} 个问题") - print("=" * 60) - + backend_dir = Path('/root/.openclaw/workspace/projects/insightflow/backend') + py_files = list(backend_dir.glob('*.py')) + + fixed_files = [] + all_changes = [] + + for filepath in py_files: + fixed, changes = fix_file(filepath) + if fixed: + fixed_files.append(filepath.name) + all_changes.extend([f"{filepath.name}: {c}" for c in changes]) + + print(f"Fixed {len(fixed_files)} files:") + for f in fixed_files: + print(f" - {f}") + if all_changes: + print("\nChanges made:") + for c in all_changes[:20]: + print(f" {c}") if __name__ == '__main__': main() diff --git a/backend/ai_manager.py b/backend/ai_manager.py index 60cd0c0..26895cc 100644 --- a/backend/ai_manager.py +++ b/backend/ai_manager.py @@ -291,7 +291,10 @@ class AIManager: return self._row_to_custom_model(row) def list_custom_models( - self, tenant_id: str, model_type: ModelType | None = None, status: ModelStatus | None = None, + self, + tenant_id: str, + model_type: ModelType | None = None, + status: ModelStatus | None = None, ) -> list[CustomModel]: """列出自定义模型""" query = "SELECT * FROM custom_models WHERE tenant_id = ?" @@ -311,7 +314,11 @@ class AIManager: return [self._row_to_custom_model(row) for row in rows] def add_training_sample( - self, model_id: str, text: str, entities: list[dict], metadata: dict = None, + self, + model_id: str, + text: str, + entities: list[dict], + metadata: dict | None = None, ) -> TrainingSample: """添加训练样本""" sample_id = f"ts_{uuid.uuid4().hex[:16]}" @@ -463,8 +470,7 @@ class AIManager: json_match = re.search(r"\[.*?\]", content, re.DOTALL) if json_match: try: - entities = json.loads(json_match.group()) - return entities + return json.loads(json_match.group()) except (json.JSONDecodeError, ValueError): pass @@ -542,8 +548,9 @@ class AIManager: } content = [{"type": "text", "text": prompt}] - for url in image_urls: - content.append({"type": "image_url", "image_url": {"url": url}}) + content.extend( + [{"type": "image_url", "image_url": {"url": url}} for url in image_urls] + ) payload = { "model": "gpt-4-vision-preview", @@ -575,9 +582,9 @@ class AIManager: "anthropic-version": "2023-06-01", } - content = [] - for url in image_urls: - content.append({"type": "image", "source": {"type": "url", "url": url}}) + content = [ + {"type": "image", "source": {"type": "url", "url": url}} for url in image_urls + ] content.append({"type": "text", "text": prompt}) payload = { @@ -638,7 +645,9 @@ class AIManager: } def get_multimodal_analyses( - self, tenant_id: str, project_id: str | None = None, + self, + tenant_id: str, + project_id: str | None = None, ) -> list[MultimodalAnalysis]: """获取多模态分析历史""" query = "SELECT * FROM multimodal_analyses WHERE tenant_id = ?" @@ -721,7 +730,9 @@ class AIManager: return self._row_to_kg_rag(row) def list_kg_rags( - self, tenant_id: str, project_id: str | None = None, + self, + tenant_id: str, + project_id: str | None = None, ) -> list[KnowledgeGraphRAG]: """列出知识图谱 RAG 配置""" query = "SELECT * FROM kg_rag_configs WHERE tenant_id = ?" @@ -738,7 +749,11 @@ class AIManager: return [self._row_to_kg_rag(row) for row in rows] async def query_kg_rag( - self, rag_id: str, query: str, project_entities: list[dict], project_relations: list[dict], + self, + rag_id: str, + query: str, + project_entities: list[dict], + project_relations: list[dict], ) -> RAGQuery: """基于知识图谱的 RAG 查询""" start_time = time.time() @@ -771,14 +786,15 @@ class AIManager: relevant_entities = relevant_entities[:top_k] # 检索相关关系 - relevant_relations = [] entity_ids = {e["id"] for e in relevant_entities} - for relation in project_relations: + relevant_relations = [ + relation + for relation in project_relations if ( relation.get("source_entity_id") in entity_ids or relation.get("target_entity_id") in entity_ids - ): - relevant_relations.append(relation) + ) + ] # 2. 构建上下文 context = {"entities": relevant_entities, "relations": relevant_relations[:10]} @@ -1123,7 +1139,8 @@ class AIManager: """获取预测模型""" with self._get_db() as conn: row = conn.execute( - "SELECT * FROM prediction_models WHERE id = ?", (model_id,), + "SELECT * FROM prediction_models WHERE id = ?", + (model_id,), ).fetchone() if not row: @@ -1132,7 +1149,9 @@ class AIManager: return self._row_to_prediction_model(row) def list_prediction_models( - self, tenant_id: str, project_id: str | None = None, + self, + tenant_id: str, + project_id: str | None = None, ) -> list[PredictionModel]: """列出预测模型""" query = "SELECT * FROM prediction_models WHERE tenant_id = ?" @@ -1149,7 +1168,9 @@ class AIManager: return [self._row_to_prediction_model(row) for row in rows] async def train_prediction_model( - self, model_id: str, historical_data: list[dict], + self, + model_id: str, + historical_data: list[dict], ) -> PredictionModel: """训练预测模型""" model = self.get_prediction_model(model_id) @@ -1369,7 +1390,9 @@ class AIManager: predicted_relations = [ {"type": rel_type, "likelihood": min(count / len(relation_history), 0.95)} for rel_type, count in sorted( - relation_counts.items(), key=lambda x: x[1], reverse=True, + relation_counts.items(), + key=lambda x: x[1], + reverse=True, )[:5] ] @@ -1394,7 +1417,10 @@ class AIManager: return [self._row_to_prediction_result(row) for row in rows] def update_prediction_feedback( - self, prediction_id: str, actual_value: str, is_correct: bool, + self, + prediction_id: str, + actual_value: str, + is_correct: bool, ) -> None: """更新预测反馈(用于模型改进)""" with self._get_db() as conn: diff --git a/backend/api_key_manager.py b/backend/api_key_manager.py index 8ec7091..f4f073c 100644 --- a/backend/api_key_manager.py +++ b/backend/api_key_manager.py @@ -132,7 +132,7 @@ class ApiKeyManager: self, name: str, owner_id: str | None = None, - permissions: list[str] = None, + permissions: list[str] | None = None, rate_limit: int = 60, expires_days: int | None = None, ) -> tuple[str, ApiKey]: @@ -238,7 +238,8 @@ class ApiKeyManager: # 验证所有权(如果提供了 owner_id) if owner_id: row = conn.execute( - "SELECT owner_id FROM api_keys WHERE id = ?", (key_id,), + "SELECT owner_id FROM api_keys WHERE id = ?", + (key_id,), ).fetchone() if not row or row[0] != owner_id: return False @@ -267,7 +268,8 @@ class ApiKeyManager: if owner_id: row = conn.execute( - "SELECT * FROM api_keys WHERE id = ? AND owner_id = ?", (key_id, owner_id), + "SELECT * FROM api_keys WHERE id = ? AND owner_id = ?", + (key_id, owner_id), ).fetchone() else: row = conn.execute("SELECT * FROM api_keys WHERE id = ?", (key_id,)).fetchone() @@ -337,7 +339,8 @@ class ApiKeyManager: # 验证所有权 if owner_id: row = conn.execute( - "SELECT owner_id FROM api_keys WHERE id = ?", (key_id,), + "SELECT owner_id FROM api_keys WHERE id = ?", + (key_id,), ).fetchone() if not row or row[0] != owner_id: return False @@ -465,7 +468,8 @@ class ApiKeyManager: endpoint_params = [] if api_key_id: endpoint_query = endpoint_query.replace( - "WHERE created_at", "WHERE api_key_id = ? AND created_at", + "WHERE created_at", + "WHERE api_key_id = ? AND created_at", ) endpoint_params.insert(0, api_key_id) @@ -486,7 +490,8 @@ class ApiKeyManager: daily_params = [] if api_key_id: daily_query = daily_query.replace( - "WHERE created_at", "WHERE api_key_id = ? AND created_at", + "WHERE created_at", + "WHERE api_key_id = ? AND created_at", ) daily_params.insert(0, api_key_id) diff --git a/backend/collaboration_manager.py b/backend/collaboration_manager.py index aad1b31..4e14d02 100644 --- a/backend/collaboration_manager.py +++ b/backend/collaboration_manager.py @@ -304,7 +304,7 @@ class CollaborationManager: ) self.db.conn.commit() - def revoke_share_link(self, share_id: str, revoked_by: str) -> bool: + def revoke_share_link(self, share_id: str, _revoked_by: str) -> bool: """撤销分享链接""" if self.db: cursor = self.db.conn.cursor() @@ -335,26 +335,24 @@ class CollaborationManager: (project_id,), ) - shares = [] - for row in cursor.fetchall(): - shares.append( - ProjectShare( - id=row[0], - project_id=row[1], - token=row[2], - permission=row[3], - created_by=row[4], - created_at=row[5], - expires_at=row[6], - max_uses=row[7], - use_count=row[8], - password_hash=row[9], - is_active=bool(row[10]), - allow_download=bool(row[11]), - allow_export=bool(row[12]), - ), + return [ + ProjectShare( + id=row[0], + project_id=row[1], + token=row[2], + permission=row[3], + created_by=row[4], + created_at=row[5], + expires_at=row[6], + max_uses=row[7], + use_count=row[8], + password_hash=row[9], + is_active=bool(row[10]), + allow_download=bool(row[11]), + allow_export=bool(row[12]), ) - return shares + for row in cursor.fetchall() + ] # ============ 评论和批注 ============ @@ -435,7 +433,10 @@ class CollaborationManager: self.db.conn.commit() def get_comments( - self, target_type: str, target_id: str, include_resolved: bool = True, + self, + target_type: str, + target_id: str, + include_resolved: bool = True, ) -> list[Comment]: """获取评论列表""" if not self.db: @@ -461,10 +462,7 @@ class CollaborationManager: (target_type, target_id), ) - comments = [] - for row in cursor.fetchall(): - comments.append(self._row_to_comment(row)) - return comments + return [self._row_to_comment(row) for row in cursor.fetchall()] def _row_to_comment(self, row) -> Comment: """将数据库行转换为Comment对象""" @@ -554,7 +552,10 @@ class CollaborationManager: return cursor.rowcount > 0 def get_project_comments( - self, project_id: str, limit: int = 50, offset: int = 0, + self, + project_id: str, + limit: int = 50, + offset: int = 0, ) -> list[Comment]: """获取项目下的所有评论""" if not self.db: @@ -571,10 +572,7 @@ class CollaborationManager: (project_id, limit, offset), ) - comments = [] - for row in cursor.fetchall(): - comments.append(self._row_to_comment(row)) - return comments + return [self._row_to_comment(row) for row in cursor.fetchall()] # ============ 变更历史 ============ @@ -697,10 +695,7 @@ class CollaborationManager: (project_id, limit, offset), ) - records = [] - for row in cursor.fetchall(): - records.append(self._row_to_change_record(row)) - return records + return [self._row_to_change_record(row) for row in cursor.fetchall()] def _row_to_change_record(self, row) -> ChangeRecord: """将数据库行转换为ChangeRecord对象""" diff --git a/backend/db_manager.py b/backend/db_manager.py index 9c34d7f..d7c16f5 100644 --- a/backend/db_manager.py +++ b/backend/db_manager.py @@ -37,7 +37,7 @@ class Entity: canonical_name: str = "" aliases: list[str] = None embedding: str = "" # Phase 3: 实体嵌入向量 - attributes: dict = None # Phase 5: 实体属性 + attributes: dict | None = None # Phase 5: 实体属性 created_at: str = "" updated_at: str = "" @@ -149,7 +149,11 @@ class DatabaseManager: conn.commit() conn.close() return Project( - id=project_id, name=name, description=description, created_at=now, updated_at=now, + id=project_id, + name=name, + description=description, + created_at=now, + updated_at=now, ) def get_project(self, project_id: str) -> Project | None: @@ -206,7 +210,10 @@ class DatabaseManager: return None def find_similar_entities( - self, project_id: str, name: str, threshold: float = 0.8, + self, + project_id: str, + name: str, + threshold: float = 0.8, ) -> list[Entity]: """查找相似实体""" conn = self.get_conn() @@ -243,7 +250,8 @@ class DatabaseManager: (json.dumps(list(target_aliases)), datetime.now().isoformat(), target_id), ) conn.execute( - "UPDATE entity_mentions SET entity_id = ? WHERE entity_id = ?", (target_id, source_id), + "UPDATE entity_mentions SET entity_id = ? WHERE entity_id = ?", + (target_id, source_id), ) conn.execute( "UPDATE entity_relations SET source_entity_id = ? WHERE source_entity_id = ?", @@ -272,7 +280,8 @@ class DatabaseManager: def list_project_entities(self, project_id: str) -> list[Entity]: conn = self.get_conn() rows = conn.execute( - "SELECT * FROM entities WHERE project_id = ? ORDER BY updated_at DESC", (project_id,), + "SELECT * FROM entities WHERE project_id = ? ORDER BY updated_at DESC", + (project_id,), ).fetchall() conn.close() @@ -478,7 +487,8 @@ class DatabaseManager: conn.commit() row = conn.execute( - "SELECT * FROM entity_relations WHERE id = ?", (relation_id,), + "SELECT * FROM entity_relations WHERE id = ?", + (relation_id,), ).fetchone() conn.close() return dict(row) if row else None @@ -494,12 +504,14 @@ class DatabaseManager: def add_glossary_term(self, project_id: str, term: str, pronunciation: str = "") -> str: conn = self.get_conn() existing = conn.execute( - "SELECT * FROM glossary WHERE project_id = ? AND term = ?", (project_id, term), + "SELECT * FROM glossary WHERE project_id = ? AND term = ?", + (project_id, term), ).fetchone() if existing: conn.execute( - "UPDATE glossary SET frequency = frequency + 1 WHERE id = ?", (existing["id"],), + "UPDATE glossary SET frequency = frequency + 1 WHERE id = ?", + (existing["id"],), ) conn.commit() conn.close() @@ -519,7 +531,8 @@ class DatabaseManager: def list_glossary(self, project_id: str) -> list[dict]: conn = self.get_conn() rows = conn.execute( - "SELECT * FROM glossary WHERE project_id = ? ORDER BY frequency DESC", (project_id,), + "SELECT * FROM glossary WHERE project_id = ? ORDER BY frequency DESC", + (project_id,), ).fetchall() conn.close() return [dict(r) for r in rows] @@ -605,15 +618,18 @@ class DatabaseManager: project = conn.execute("SELECT * FROM projects WHERE id = ?", (project_id,)).fetchone() entity_count = conn.execute( - "SELECT COUNT(*) as count FROM entities WHERE project_id = ?", (project_id,), + "SELECT COUNT(*) as count FROM entities WHERE project_id = ?", + (project_id,), ).fetchone()["count"] transcript_count = conn.execute( - "SELECT COUNT(*) as count FROM transcripts WHERE project_id = ?", (project_id,), + "SELECT COUNT(*) as count FROM transcripts WHERE project_id = ?", + (project_id,), ).fetchone()["count"] relation_count = conn.execute( - "SELECT COUNT(*) as count FROM entity_relations WHERE project_id = ?", (project_id,), + "SELECT COUNT(*) as count FROM entity_relations WHERE project_id = ?", + (project_id,), ).fetchone()["count"] recent_transcripts = conn.execute( @@ -645,11 +661,15 @@ class DatabaseManager: } def get_transcript_context( - self, transcript_id: str, position: int, context_chars: int = 200, + self, + transcript_id: str, + position: int, + context_chars: int = 200, ) -> str: conn = self.get_conn() row = conn.execute( - "SELECT full_text FROM transcripts WHERE id = ?", (transcript_id,), + "SELECT full_text FROM transcripts WHERE id = ?", + (transcript_id,), ).fetchone() conn.close() if not row: @@ -662,7 +682,11 @@ class DatabaseManager: # ==================== Phase 5: Timeline Operations ==================== def get_project_timeline( - self, project_id: str, entity_id: str = None, start_date: str = None, end_date: str = None, + self, + project_id: str, + entity_id: str | None = None, + start_date: str = None, + end_date: str = None, ) -> list[dict]: conn = self.get_conn() @@ -776,7 +800,8 @@ class DatabaseManager: def get_attribute_template(self, template_id: str) -> AttributeTemplate | None: conn = self.get_conn() row = conn.execute( - "SELECT * FROM attribute_templates WHERE id = ?", (template_id,), + "SELECT * FROM attribute_templates WHERE id = ?", + (template_id,), ).fetchone() conn.close() if row: @@ -841,7 +866,10 @@ class DatabaseManager: conn.close() def set_entity_attribute( - self, attr: EntityAttribute, changed_by: str = "system", change_reason: str = "", + self, + attr: EntityAttribute, + changed_by: str = "system", + change_reason: str = "", ) -> EntityAttribute: conn = self.get_conn() now = datetime.now().isoformat() @@ -930,7 +958,11 @@ class DatabaseManager: return entity def delete_entity_attribute( - self, entity_id: str, template_id: str, changed_by: str = "system", change_reason: str = "", + self, + entity_id: str, + template_id: str, + changed_by: str = "system", + change_reason: str = "", ) -> None: conn = self.get_conn() old_row = conn.execute( @@ -964,7 +996,10 @@ class DatabaseManager: conn.close() def get_attribute_history( - self, entity_id: str = None, template_id: str = None, limit: int = 50, + self, + entity_id: str | None = None, + template_id: str = None, + limit: int = 50, ) -> list[AttributeHistory]: conn = self.get_conn() conditions = [] @@ -990,7 +1025,9 @@ class DatabaseManager: return [AttributeHistory(**dict(r)) for r in rows] def search_entities_by_attributes( - self, project_id: str, attribute_filters: dict[str, str], + self, + project_id: str, + attribute_filters: dict[str, str], ) -> list[Entity]: entities = self.list_project_entities(project_id) if not attribute_filters: @@ -1040,8 +1077,8 @@ class DatabaseManager: filename: str, duration: float = 0, fps: float = 0, - resolution: dict = None, - audio_transcript_id: str = None, + resolution: dict | None = None, + audio_transcript_id: str | None = None, full_ocr_text: str = "", extracted_entities: list[dict] = None, extracted_relations: list[dict] = None, @@ -1098,7 +1135,8 @@ class DatabaseManager: """获取项目的所有视频""" conn = self.get_conn() rows = conn.execute( - "SELECT * FROM videos WHERE project_id = ? ORDER BY created_at DESC", (project_id,), + "SELECT * FROM videos WHERE project_id = ? ORDER BY created_at DESC", + (project_id,), ).fetchall() conn.close() @@ -1121,8 +1159,8 @@ class DatabaseManager: video_id: str, frame_number: int, timestamp: float, - image_url: str = None, - ocr_text: str = None, + image_url: str | None = None, + ocr_text: str | None = None, extracted_entities: list[dict] = None, ) -> str: """创建视频帧记录""" @@ -1153,7 +1191,8 @@ class DatabaseManager: """获取视频的所有帧""" conn = self.get_conn() rows = conn.execute( - """SELECT * FROM video_frames WHERE video_id = ? ORDER BY timestamp""", (video_id,), + """SELECT * FROM video_frames WHERE video_id = ? ORDER BY timestamp""", + (video_id,), ).fetchall() conn.close() @@ -1223,7 +1262,8 @@ class DatabaseManager: """获取项目的所有图片""" conn = self.get_conn() rows = conn.execute( - "SELECT * FROM images WHERE project_id = ? ORDER BY created_at DESC", (project_id,), + "SELECT * FROM images WHERE project_id = ? ORDER BY created_at DESC", + (project_id,), ).fetchall() conn.close() @@ -1288,7 +1328,9 @@ class DatabaseManager: conn.close() return [dict(r) for r in rows] - def get_project_multimodal_mentions(self, project_id: str, modality: str = None) -> list[dict]: + def get_project_multimodal_mentions( + self, project_id: str, modality: str | None = None + ) -> list[dict]: """获取项目的多模态提及""" conn = self.get_conn() @@ -1381,13 +1423,15 @@ class DatabaseManager: # 视频数量 row = conn.execute( - "SELECT COUNT(*) as count FROM videos WHERE project_id = ?", (project_id,), + "SELECT COUNT(*) as count FROM videos WHERE project_id = ?", + (project_id,), ).fetchone() stats["video_count"] = row["count"] # 图片数量 row = conn.execute( - "SELECT COUNT(*) as count FROM images WHERE project_id = ?", (project_id,), + "SELECT COUNT(*) as count FROM images WHERE project_id = ?", + (project_id,), ).fetchone() stats["image_count"] = row["count"] diff --git a/backend/developer_ecosystem_manager.py b/backend/developer_ecosystem_manager.py index 8499e6f..8a0bdc5 100644 --- a/backend/developer_ecosystem_manager.py +++ b/backend/developer_ecosystem_manager.py @@ -538,7 +538,8 @@ class DeveloperEcosystemManager: """获取 SDK 版本历史""" with self._get_db() as conn: rows = conn.execute( - "SELECT * FROM sdk_versions WHERE sdk_id = ? ORDER BY created_at DESC", (sdk_id,), + "SELECT * FROM sdk_versions WHERE sdk_id = ? ORDER BY created_at DESC", + (sdk_id,), ).fetchall() return [self._row_to_sdk_version(row) for row in rows] @@ -700,7 +701,8 @@ class DeveloperEcosystemManager: """获取模板详情""" with self._get_db() as conn: row = conn.execute( - "SELECT * FROM template_market WHERE id = ?", (template_id,), + "SELECT * FROM template_market WHERE id = ?", + (template_id,), ).fetchone() if row: @@ -1076,7 +1078,11 @@ class DeveloperEcosystemManager: return [self._row_to_plugin(row) for row in rows] def review_plugin( - self, plugin_id: str, reviewed_by: str, status: PluginStatus, notes: str = "", + self, + plugin_id: str, + reviewed_by: str, + status: PluginStatus, + notes: str = "", ) -> PluginMarketItem | None: """审核插件""" now = datetime.now().isoformat() @@ -1420,7 +1426,8 @@ class DeveloperEcosystemManager: """获取开发者档案""" with self._get_db() as conn: row = conn.execute( - "SELECT * FROM developer_profiles WHERE id = ?", (developer_id,), + "SELECT * FROM developer_profiles WHERE id = ?", + (developer_id,), ).fetchone() if row: @@ -1431,7 +1438,8 @@ class DeveloperEcosystemManager: """通过用户 ID 获取开发者档案""" with self._get_db() as conn: row = conn.execute( - "SELECT * FROM developer_profiles WHERE user_id = ?", (user_id,), + "SELECT * FROM developer_profiles WHERE user_id = ?", + (user_id,), ).fetchone() if row: @@ -1439,7 +1447,9 @@ class DeveloperEcosystemManager: return None def verify_developer( - self, developer_id: str, status: DeveloperStatus, + self, + developer_id: str, + status: DeveloperStatus, ) -> DeveloperProfile | None: """验证开发者""" now = datetime.now().isoformat() @@ -1453,9 +1463,11 @@ class DeveloperEcosystemManager: """, ( status.value, - now - if status in [DeveloperStatus.VERIFIED, DeveloperStatus.CERTIFIED] - else None, + ( + now + if status in [DeveloperStatus.VERIFIED, DeveloperStatus.CERTIFIED] + else None + ), now, developer_id, ), @@ -1469,7 +1481,8 @@ class DeveloperEcosystemManager: with self._get_db() as conn: # 统计插件数量 plugin_row = conn.execute( - "SELECT COUNT(*) as count FROM plugin_market WHERE author_id = ?", (developer_id,), + "SELECT COUNT(*) as count FROM plugin_market WHERE author_id = ?", + (developer_id,), ).fetchone() # 统计模板数量 @@ -1583,7 +1596,8 @@ class DeveloperEcosystemManager: """获取代码示例""" with self._get_db() as conn: row = conn.execute( - "SELECT * FROM code_examples WHERE id = ?", (example_id,), + "SELECT * FROM code_examples WHERE id = ?", + (example_id,), ).fetchone() if row: @@ -1699,7 +1713,8 @@ class DeveloperEcosystemManager: """获取 API 文档""" with self._get_db() as conn: row = conn.execute( - "SELECT * FROM api_documentation WHERE id = ?", (doc_id,), + "SELECT * FROM api_documentation WHERE id = ?", + (doc_id,), ).fetchone() if row: @@ -1799,7 +1814,8 @@ class DeveloperEcosystemManager: """获取开发者门户配置""" with self._get_db() as conn: row = conn.execute( - "SELECT * FROM developer_portal_configs WHERE id = ?", (config_id,), + "SELECT * FROM developer_portal_configs WHERE id = ?", + (config_id,), ).fetchone() if row: diff --git a/backend/document_processor.py b/backend/document_processor.py index 39dc2a5..61d396b 100644 --- a/backend/document_processor.py +++ b/backend/document_processor.py @@ -78,7 +78,7 @@ class DocumentProcessor: "PDF processing requires PyPDF2 or pdfplumber. Install with: pip install PyPDF2", ) except Exception as e: - raise ValueError(f"PDF extraction failed: {str(e)}") + raise ValueError(f"PDF extraction failed: {e!s}") def _extract_docx(self, content: bytes) -> str: """提取 DOCX 文本""" @@ -109,7 +109,7 @@ class DocumentProcessor: "DOCX processing requires python-docx. Install with: pip install python-docx", ) except Exception as e: - raise ValueError(f"DOCX extraction failed: {str(e)}") + raise ValueError(f"DOCX extraction failed: {e!s}") def _extract_txt(self, content: bytes) -> str: """提取纯文本""" diff --git a/backend/enterprise_manager.py b/backend/enterprise_manager.py index 0e78772..938c388 100644 --- a/backend/enterprise_manager.py +++ b/backend/enterprise_manager.py @@ -699,7 +699,9 @@ class EnterpriseManager: conn.close() def get_tenant_sso_config( - self, tenant_id: str, provider: str | None = None, + self, + tenant_id: str, + provider: str | None = None, ) -> SSOConfig | None: """获取租户的 SSO 配置""" conn = self._get_connection() @@ -871,7 +873,10 @@ class EnterpriseManager: return metadata def create_saml_auth_request( - self, tenant_id: str, config_id: str, relay_state: str | None = None, + self, + tenant_id: str, + config_id: str, + relay_state: str | None = None, ) -> SAMLAuthRequest: """创建 SAML 认证请求""" conn = self._get_connection() @@ -1235,7 +1240,10 @@ class EnterpriseManager: return [] def _upsert_scim_user( - self, conn: sqlite3.Connection, tenant_id: str, user_data: dict[str, Any], + self, + conn: sqlite3.Connection, + tenant_id: str, + user_data: dict[str, Any], ) -> None: """插入或更新 SCIM 用户""" cursor = conn.cursor() @@ -1405,7 +1413,11 @@ class EnterpriseManager: try: # 获取审计日志数据 logs = self._fetch_audit_logs( - export.tenant_id, export.start_date, export.end_date, export.filters, db_manager, + export.tenant_id, + export.start_date, + export.end_date, + export.filters, + db_manager, ) # 根据合规标准过滤字段 @@ -1414,7 +1426,9 @@ class EnterpriseManager: # 生成导出文件 file_path, file_size, checksum = self._generate_export_file( - export_id, logs, export.export_format, + export_id, + logs, + export.export_format, ) now = datetime.now() @@ -1465,7 +1479,9 @@ class EnterpriseManager: return [] def _apply_compliance_filter( - self, logs: list[dict[str, Any]], standard: str, + self, + logs: list[dict[str, Any]], + standard: str, ) -> list[dict[str, Any]]: """应用合规标准字段过滤""" fields = self.COMPLIANCE_FIELDS.get(ComplianceStandard(standard), []) @@ -1481,7 +1497,10 @@ class EnterpriseManager: return filtered_logs def _generate_export_file( - self, export_id: str, logs: list[dict[str, Any]], format: str, + self, + export_id: str, + logs: list[dict[str, Any]], + format: str, ) -> tuple[str, int, str]: """生成导出文件""" import hashlib @@ -1672,7 +1691,9 @@ class EnterpriseManager: conn.close() def list_retention_policies( - self, tenant_id: str, resource_type: str | None = None, + self, + tenant_id: str, + resource_type: str | None = None, ) -> list[DataRetentionPolicy]: """列出数据保留策略""" conn = self._get_connection() @@ -1876,7 +1897,10 @@ class EnterpriseManager: conn.close() def _retain_audit_logs( - self, conn: sqlite3.Connection, policy: DataRetentionPolicy, cutoff_date: datetime, + self, + conn: sqlite3.Connection, + policy: DataRetentionPolicy, + cutoff_date: datetime, ) -> dict[str, int]: """保留审计日志""" cursor = conn.cursor() @@ -1909,14 +1933,20 @@ class EnterpriseManager: return {"affected": 0, "archived": 0, "deleted": 0, "errors": 0} def _retain_projects( - self, conn: sqlite3.Connection, policy: DataRetentionPolicy, cutoff_date: datetime, + self, + conn: sqlite3.Connection, + policy: DataRetentionPolicy, + cutoff_date: datetime, ) -> dict[str, int]: """保留项目数据""" # 简化实现 return {"affected": 0, "archived": 0, "deleted": 0, "errors": 0} def _retain_transcripts( - self, conn: sqlite3.Connection, policy: DataRetentionPolicy, cutoff_date: datetime, + self, + conn: sqlite3.Connection, + policy: DataRetentionPolicy, + cutoff_date: datetime, ) -> dict[str, int]: """保留转录数据""" # 简化实现 @@ -2101,9 +2131,11 @@ class EnterpriseManager: if isinstance(row["start_date"], str) else row["start_date"] ), - end_date=datetime.fromisoformat(row["end_date"]) - if isinstance(row["end_date"], str) - else row["end_date"], + end_date=( + datetime.fromisoformat(row["end_date"]) + if isinstance(row["end_date"], str) + else row["end_date"] + ), filters=json.loads(row["filters"] or "{}"), compliance_standard=row["compliance_standard"], status=row["status"], diff --git a/backend/entity_aligner.py b/backend/entity_aligner.py index 41e2831..4247092 100644 --- a/backend/entity_aligner.py +++ b/backend/entity_aligner.py @@ -178,7 +178,10 @@ class EntityAligner: return best_match def _fallback_similarity_match( - self, entities: list[object], name: str, exclude_id: str | None = None, + self, + entities: list[object], + name: str, + exclude_id: str | None = None, ) -> object | None: """ 回退到简单的相似度匹配(不使用 embedding) @@ -212,7 +215,10 @@ class EntityAligner: return None def batch_align_entities( - self, project_id: str, new_entities: list[dict], threshold: float | None = None, + self, + project_id: str, + new_entities: list[dict], + threshold: float | None = None, ) -> list[dict]: """ 批量对齐实体 @@ -232,7 +238,10 @@ class EntityAligner: for new_ent in new_entities: matched = self.find_similar_entity( - project_id, new_ent["name"], new_ent.get("definition", ""), threshold=threshold, + project_id, + new_ent["name"], + new_ent.get("definition", ""), + threshold=threshold, ) result = { diff --git a/backend/export_manager.py b/backend/export_manager.py index 670f691..c91b3e9 100644 --- a/backend/export_manager.py +++ b/backend/export_manager.py @@ -75,7 +75,10 @@ class ExportManager: self.db = db_manager def export_knowledge_graph_svg( - self, project_id: str, entities: list[ExportEntity], relations: list[ExportRelation], + self, + project_id: str, + entities: list[ExportEntity], + relations: list[ExportRelation], ) -> str: """ 导出知识图谱为 SVG 格式 @@ -220,7 +223,10 @@ class ExportManager: return "\n".join(svg_parts) def export_knowledge_graph_png( - self, project_id: str, entities: list[ExportEntity], relations: list[ExportRelation], + self, + project_id: str, + entities: list[ExportEntity], + relations: list[ExportRelation], ) -> bytes: """ 导出知识图谱为 PNG 格式 @@ -337,7 +343,9 @@ class ExportManager: return output.getvalue() def export_transcript_markdown( - self, transcript: ExportTranscript, entities_map: dict[str, ExportEntity], + self, + transcript: ExportTranscript, + entities_map: dict[str, ExportEntity], ) -> str: """ 导出转录文本为 Markdown 格式 @@ -417,7 +425,12 @@ class ExportManager: output = io.BytesIO() doc = SimpleDocTemplate( - output, pagesize=A4, rightMargin=72, leftMargin=72, topMargin=72, bottomMargin=18, + output, + pagesize=A4, + rightMargin=72, + leftMargin=72, + topMargin=72, + bottomMargin=18, ) # 样式 @@ -510,7 +523,8 @@ class ExportManager: ) entity_table = Table( - entity_data, colWidths=[1.5 * inch, 1 * inch, 1 * inch, 2.5 * inch], + entity_data, + colWidths=[1.5 * inch, 1 * inch, 1 * inch, 2.5 * inch], ) entity_table.setStyle( TableStyle( @@ -539,7 +553,8 @@ class ExportManager: relation_data.append([r.source, r.relation_type, r.target, f"{r.confidence:.2f}"]) relation_table = Table( - relation_data, colWidths=[2 * inch, 1.5 * inch, 2 * inch, 1 * inch], + relation_data, + colWidths=[2 * inch, 1.5 * inch, 2 * inch, 1 * inch], ) relation_table.setStyle( TableStyle( diff --git a/backend/growth_manager.py b/backend/growth_manager.py index 5ceeca5..c94ce8e 100644 --- a/backend/growth_manager.py +++ b/backend/growth_manager.py @@ -383,11 +383,11 @@ class GrowthManager: user_id: str, event_type: EventType, event_name: str, - properties: dict = None, - session_id: str = None, - device_info: dict = None, - referrer: str = None, - utm_params: dict = None, + properties: dict | None = None, + session_id: str | None = None, + device_info: dict | None = None, + referrer: str | None = None, + utm_params: dict | None = None, ) -> AnalyticsEvent: """追踪事件""" event_id = f"evt_{uuid.uuid4().hex[:16]}" @@ -475,7 +475,10 @@ class GrowthManager: async with httpx.AsyncClient() as client: await client.post( - "https://api.mixpanel.com/track", headers=headers, json=[payload], timeout=10.0, + "https://api.mixpanel.com/track", + headers=headers, + json=[payload], + timeout=10.0, ) except (RuntimeError, ValueError, TypeError) as e: print(f"Failed to send to Mixpanel: {e}") @@ -509,7 +512,11 @@ class GrowthManager: print(f"Failed to send to Amplitude: {e}") async def _update_user_profile( - self, tenant_id: str, user_id: str, event_type: EventType, event_name: str, + self, + tenant_id: str, + user_id: str, + event_type: EventType, + event_name: str, ) -> None: """更新用户画像""" with self._get_db() as conn: @@ -581,7 +588,10 @@ class GrowthManager: return None def get_user_analytics_summary( - self, tenant_id: str, start_date: datetime = None, end_date: datetime = None, + self, + tenant_id: str, + start_date: datetime | None = None, + end_date: datetime = None, ) -> dict: """获取用户分析汇总""" with self._get_db() as conn: @@ -635,7 +645,12 @@ class GrowthManager: } def create_funnel( - self, tenant_id: str, name: str, description: str, steps: list[dict], created_by: str, + self, + tenant_id: str, + name: str, + description: str, + steps: list[dict], + created_by: str, ) -> Funnel: """创建转化漏斗""" funnel_id = f"fnl_{uuid.uuid4().hex[:16]}" @@ -673,12 +688,16 @@ class GrowthManager: return funnel def analyze_funnel( - self, funnel_id: str, period_start: datetime = None, period_end: datetime = None, + self, + funnel_id: str, + period_start: datetime | None = None, + period_end: datetime = None, ) -> FunnelAnalysis | None: """分析漏斗转化率""" with self._get_db() as conn: funnel_row = conn.execute( - "SELECT * FROM funnels WHERE id = ?", (funnel_id,), + "SELECT * FROM funnels WHERE id = ?", + (funnel_id,), ).fetchone() if not funnel_row: @@ -704,7 +723,8 @@ class GrowthManager: WHERE event_name = ? AND timestamp >= ? AND timestamp <= ? """ row = conn.execute( - query, (event_name, period_start.isoformat(), period_end.isoformat()), + query, + (event_name, period_start.isoformat(), period_end.isoformat()), ).fetchone() user_count = row["user_count"] if row else 0 @@ -752,7 +772,10 @@ class GrowthManager: ) def calculate_retention( - self, tenant_id: str, cohort_date: datetime, periods: list[int] = None, + self, + tenant_id: str, + cohort_date: datetime, + periods: list[int] = None, ) -> dict: """计算留存率""" if periods is None: @@ -825,7 +848,7 @@ class GrowthManager: secondary_metrics: list[str], min_sample_size: int = 100, confidence_level: float = 0.95, - created_by: str = None, + created_by: str | None = None, ) -> Experiment: """创建 A/B 测试实验""" experiment_id = f"exp_{uuid.uuid4().hex[:16]}" @@ -893,14 +916,17 @@ class GrowthManager: """获取实验详情""" with self._get_db() as conn: row = conn.execute( - "SELECT * FROM experiments WHERE id = ?", (experiment_id,), + "SELECT * FROM experiments WHERE id = ?", + (experiment_id,), ).fetchone() if row: return self._row_to_experiment(row) return None - def list_experiments(self, tenant_id: str, status: ExperimentStatus = None) -> list[Experiment]: + def list_experiments( + self, tenant_id: str, status: ExperimentStatus | None = None + ) -> list[Experiment]: """列出实验""" query = "SELECT * FROM experiments WHERE tenant_id = ?" params = [tenant_id] @@ -916,7 +942,10 @@ class GrowthManager: return [self._row_to_experiment(row) for row in rows] def assign_variant( - self, experiment_id: str, user_id: str, user_attributes: dict = None, + self, + experiment_id: str, + user_id: str, + user_attributes: dict | None = None, ) -> str | None: """为用户分配实验变体""" experiment = self.get_experiment(experiment_id) @@ -939,11 +968,15 @@ class GrowthManager: variant_id = self._random_allocation(experiment.variants, experiment.traffic_split) elif experiment.traffic_allocation == TrafficAllocationType.STRATIFIED: variant_id = self._stratified_allocation( - experiment.variants, experiment.traffic_split, user_attributes, + experiment.variants, + experiment.traffic_split, + user_attributes, ) else: # TARGETED variant_id = self._targeted_allocation( - experiment.variants, experiment.target_audience, user_attributes, + experiment.variants, + experiment.target_audience, + user_attributes, ) if variant_id: @@ -978,7 +1011,10 @@ class GrowthManager: return random.choices(variant_ids, weights=normalized_weights, k=1)[0] def _stratified_allocation( - self, variants: list[dict], traffic_split: dict[str, float], user_attributes: dict, + self, + variants: list[dict], + traffic_split: dict[str, float], + user_attributes: dict, ) -> str: """分层分配(基于用户属性)""" # 简化的分层分配:根据用户 ID 哈希值分配 @@ -991,7 +1027,10 @@ class GrowthManager: return self._random_allocation(variants, traffic_split) def _targeted_allocation( - self, variants: list[dict], target_audience: dict, user_attributes: dict, + self, + variants: list[dict], + target_audience: dict, + user_attributes: dict, ) -> str | None: """定向分配(基于目标受众条件)""" # 检查用户是否符合目标受众条件 @@ -1005,7 +1044,14 @@ class GrowthManager: user_value = user_attributes.get(attr_name) if user_attributes else None - if operator == "equals" and user_value != value or operator == "not_equals" and user_value == value or operator == "in" and user_value not in value: + if ( + operator == "equals" + and user_value != value + or operator == "not_equals" + and user_value == value + or operator == "in" + and user_value not in value + ): matches = False break @@ -1177,11 +1223,11 @@ class GrowthManager: template_type: EmailTemplateType, subject: str, html_content: str, - text_content: str = None, + text_content: str | None = None, variables: list[str] = None, - from_name: str = None, - from_email: str = None, - reply_to: str = None, + from_name: str | None = None, + from_email: str | None = None, + reply_to: str | None = None, ) -> EmailTemplate: """创建邮件模板""" template_id = f"et_{uuid.uuid4().hex[:16]}" @@ -1242,7 +1288,8 @@ class GrowthManager: """获取邮件模板""" with self._get_db() as conn: row = conn.execute( - "SELECT * FROM email_templates WHERE id = ?", (template_id,), + "SELECT * FROM email_templates WHERE id = ?", + (template_id,), ).fetchone() if row: @@ -1250,7 +1297,9 @@ class GrowthManager: return None def list_email_templates( - self, tenant_id: str, template_type: EmailTemplateType = None, + self, + tenant_id: str, + template_type: EmailTemplateType | None = None, ) -> list[EmailTemplate]: """列出邮件模板""" query = "SELECT * FROM email_templates WHERE tenant_id = ? AND is_active = 1" @@ -1297,7 +1346,7 @@ class GrowthManager: name: str, template_id: str, recipient_list: list[dict], - scheduled_at: datetime = None, + scheduled_at: datetime | None = None, ) -> EmailCampaign: """创建邮件营销活动""" campaign_id = f"ec_{uuid.uuid4().hex[:16]}" @@ -1377,7 +1426,12 @@ class GrowthManager: return campaign async def send_email( - self, campaign_id: str, user_id: str, email: str, template_id: str, variables: dict, + self, + campaign_id: str, + user_id: str, + email: str, + template_id: str, + variables: dict, ) -> bool: """发送单封邮件""" template = self.get_email_template(template_id) @@ -1448,7 +1502,8 @@ class GrowthManager: """发送整个营销活动""" with self._get_db() as conn: campaign_row = conn.execute( - "SELECT * FROM email_campaigns WHERE id = ?", (campaign_id,), + "SELECT * FROM email_campaigns WHERE id = ?", + (campaign_id,), ).fetchone() if not campaign_row: @@ -1478,7 +1533,11 @@ class GrowthManager: variables = self._get_user_variables(log["tenant_id"], log["user_id"]) success = await self.send_email( - campaign_id, log["user_id"], log["email"], log["template_id"], variables, + campaign_id, + log["user_id"], + log["email"], + log["template_id"], + variables, ) if success: @@ -1763,7 +1822,8 @@ class GrowthManager: with self._get_db() as conn: row = conn.execute( - "SELECT 1 FROM referrals WHERE referral_code = ?", (code,), + "SELECT 1 FROM referrals WHERE referral_code = ?", + (code,), ).fetchone() if not row: @@ -1773,7 +1833,8 @@ class GrowthManager: """获取推荐计划""" with self._get_db() as conn: row = conn.execute( - "SELECT * FROM referral_programs WHERE id = ?", (program_id,), + "SELECT * FROM referral_programs WHERE id = ?", + (program_id,), ).fetchone() if row: @@ -1859,7 +1920,8 @@ class GrowthManager: "expired": stats["expired"] or 0, "unique_referrers": stats["unique_referrers"] or 0, "conversion_rate": round( - (stats["converted"] or 0) / max(stats["total_referrals"] or 1, 1), 4, + (stats["converted"] or 0) / max(stats["total_referrals"] or 1, 1), + 4, ), } @@ -1922,7 +1984,10 @@ class GrowthManager: return incentive def check_team_incentive_eligibility( - self, tenant_id: str, current_tier: str, team_size: int, + self, + tenant_id: str, + current_tier: str, + team_size: int, ) -> list[TeamIncentive]: """检查团队激励资格""" with self._get_db() as conn: diff --git a/backend/image_processor.py b/backend/image_processor.py index 7cbe12f..31c4f93 100644 --- a/backend/image_processor.py +++ b/backend/image_processor.py @@ -96,7 +96,7 @@ class ImageProcessor: "other": "其他", } - def __init__(self, temp_dir: str = None) -> None: + def __init__(self, temp_dir: str | None = None) -> None: """ 初始化图片处理器 @@ -106,7 +106,7 @@ class ImageProcessor: self.temp_dir = temp_dir or os.path.join(os.getcwd(), "temp", "images") os.makedirs(self.temp_dir, exist_ok=True) - def preprocess_image(self, image, image_type: str = None) -> None: + def preprocess_image(self, image, image_type: str | None = None) -> None: """ 预处理图片以提高OCR质量 @@ -328,7 +328,10 @@ class ImageProcessor: return unique_entities def generate_description( - self, image_type: str, ocr_text: str, entities: list[ImageEntity], + self, + image_type: str, + ocr_text: str, + entities: list[ImageEntity], ) -> str: """ 生成图片描述 @@ -361,8 +364,8 @@ class ImageProcessor: def process_image( self, image_data: bytes, - filename: str = None, - image_id: str = None, + filename: str | None = None, + image_id: str | None = None, detect_type: bool = True, ) -> ImageProcessingResult: """ @@ -487,7 +490,9 @@ class ImageProcessor: return relations def process_batch( - self, images_data: list[tuple[bytes, str]], project_id: str = None, + self, + images_data: list[tuple[bytes, str]], + project_id: str | None = None, ) -> BatchProcessingResult: """ 批量处理图片 @@ -561,7 +566,7 @@ class ImageProcessor: _image_processor = None -def get_image_processor(temp_dir: str = None) -> ImageProcessor: +def get_image_processor(temp_dir: str | None = None) -> ImageProcessor: """获取图片处理器单例""" global _image_processor if _image_processor is None: diff --git a/backend/knowledge_reasoner.py b/backend/knowledge_reasoner.py index 9f1a013..2cdff42 100644 --- a/backend/knowledge_reasoner.py +++ b/backend/knowledge_reasoner.py @@ -51,7 +51,7 @@ class InferencePath: class KnowledgeReasoner: """知识推理引擎""" - def __init__(self, api_key: str = None, base_url: str = None) -> None: + def __init__(self, api_key: str | None = None, base_url: str = None) -> None: self.api_key = api_key or KIMI_API_KEY self.base_url = base_url or KIMI_BASE_URL self.headers = { @@ -82,7 +82,11 @@ class KnowledgeReasoner: return result["choices"][0]["message"]["content"] async def enhanced_qa( - self, query: str, project_context: dict, graph_data: dict, reasoning_depth: str = "medium", + self, + query: str, + project_context: dict, + graph_data: dict, + reasoning_depth: str = "medium", ) -> ReasoningResult: """ 增强问答 - 结合图谱推理的问答 @@ -139,7 +143,10 @@ class KnowledgeReasoner: return {"type": "factual", "entities": [], "intent": "general", "complexity": "simple"} async def _causal_reasoning( - self, query: str, project_context: dict, graph_data: dict, + self, + query: str, + project_context: dict, + graph_data: dict, ) -> ReasoningResult: """因果推理 - 分析原因和影响""" @@ -200,7 +207,10 @@ class KnowledgeReasoner: ) async def _comparative_reasoning( - self, query: str, project_context: dict, graph_data: dict, + self, + query: str, + project_context: dict, + graph_data: dict, ) -> ReasoningResult: """对比推理 - 比较实体间的异同""" @@ -254,7 +264,10 @@ class KnowledgeReasoner: ) async def _temporal_reasoning( - self, query: str, project_context: dict, graph_data: dict, + self, + query: str, + project_context: dict, + graph_data: dict, ) -> ReasoningResult: """时序推理 - 分析时间线和演变""" @@ -308,7 +321,10 @@ class KnowledgeReasoner: ) async def _associative_reasoning( - self, query: str, project_context: dict, graph_data: dict, + self, + query: str, + project_context: dict, + graph_data: dict, ) -> ReasoningResult: """关联推理 - 发现实体间的隐含关联""" @@ -362,7 +378,11 @@ class KnowledgeReasoner: ) def find_inference_paths( - self, start_entity: str, end_entity: str, graph_data: dict, max_depth: int = 3, + self, + start_entity: str, + end_entity: str, + graph_data: dict, + max_depth: int = 3, ) -> list[InferencePath]: """ 发现两个实体之间的推理路径 @@ -449,7 +469,10 @@ class KnowledgeReasoner: return length_factor * confidence_factor async def summarize_project( - self, project_context: dict, graph_data: dict, summary_type: str = "comprehensive", + self, + project_context: dict, + graph_data: dict, + summary_type: str = "comprehensive", ) -> dict: """ 项目智能总结 diff --git a/backend/llm_client.py b/backend/llm_client.py index 3010527..2603c7c 100644 --- a/backend/llm_client.py +++ b/backend/llm_client.py @@ -43,7 +43,7 @@ class RelationExtractionResult: class LLMClient: """Kimi API 客户端""" - def __init__(self, api_key: str = None, base_url: str = None) -> None: + def __init__(self, api_key: str | None = None, base_url: str = None) -> None: self.api_key = api_key or KIMI_API_KEY self.base_url = base_url or KIMI_BASE_URL self.headers = { @@ -52,7 +52,10 @@ class LLMClient: } async def chat( - self, messages: list[ChatMessage], temperature: float = 0.3, stream: bool = False, + self, + messages: list[ChatMessage], + temperature: float = 0.3, + stream: bool = False, ) -> str: """发送聊天请求""" if not self.api_key: @@ -77,7 +80,9 @@ class LLMClient: return result["choices"][0]["message"]["content"] async def chat_stream( - self, messages: list[ChatMessage], temperature: float = 0.3, + self, + messages: list[ChatMessage], + temperature: float = 0.3, ) -> AsyncGenerator[str, None]: """流式聊天请求""" if not self.api_key: @@ -90,13 +95,16 @@ class LLMClient: "stream": True, } - async with httpx.AsyncClient() as client, client.stream( - "POST", - f"{self.base_url}/v1/chat/completions", - headers=self.headers, - json=payload, - timeout=120.0, - ) as response: + async with ( + httpx.AsyncClient() as client, + client.stream( + "POST", + f"{self.base_url}/v1/chat/completions", + headers=self.headers, + json=payload, + timeout=120.0, + ) as response, + ): response.raise_for_status() async for line in response.aiter_lines(): if line.startswith("data: "): @@ -112,7 +120,8 @@ class LLMClient: pass async def extract_entities_with_confidence( - self, text: str, + self, + text: str, ) -> tuple[list[EntityExtractionResult], list[RelationExtractionResult]]: """提取实体和关系,带置信度分数""" prompt = f"""从以下会议文本中提取关键实体和它们之间的关系,以 JSON 格式返回: @@ -189,7 +198,8 @@ class LLMClient: messages = [ ChatMessage( - role="system", content="你是一个专业的项目分析助手,擅长从会议记录中提取洞察。", + role="system", + content="你是一个专业的项目分析助手,擅长从会议记录中提取洞察。", ), ChatMessage(role="user", content=prompt), ] diff --git a/backend/localization_manager.py b/backend/localization_manager.py index a1b89fe..8d8679b 100644 --- a/backend/localization_manager.py +++ b/backend/localization_manager.py @@ -963,7 +963,11 @@ class LocalizationManager: self._close_if_file_db(conn) def get_translation( - self, key: str, language: str, namespace: str = "common", fallback: bool = True, + self, + key: str, + language: str, + namespace: str = "common", + fallback: bool = True, ) -> str | None: conn = self._get_connection() try: @@ -979,7 +983,10 @@ class LocalizationManager: lang_config = self.get_language_config(language) if lang_config and lang_config.fallback_language: return self.get_translation( - key, lang_config.fallback_language, namespace, False, + key, + lang_config.fallback_language, + namespace, + False, ) if language != "en": return self.get_translation(key, "en", namespace, False) @@ -1019,7 +1026,11 @@ class LocalizationManager: self._close_if_file_db(conn) def _get_translation_internal( - self, conn: sqlite3.Connection, key: str, language: str, namespace: str, + self, + conn: sqlite3.Connection, + key: str, + language: str, + namespace: str, ) -> Translation | None: cursor = conn.cursor() cursor.execute( @@ -1121,7 +1132,9 @@ class LocalizationManager: self._close_if_file_db(conn) def list_data_centers( - self, status: str | None = None, region: str | None = None, + self, + status: str | None = None, + region: str | None = None, ) -> list[DataCenter]: conn = self._get_connection() try: @@ -1146,7 +1159,8 @@ class LocalizationManager: try: cursor = conn.cursor() cursor.execute( - "SELECT * FROM tenant_data_center_mappings WHERE tenant_id = ?", (tenant_id,), + "SELECT * FROM tenant_data_center_mappings WHERE tenant_id = ?", + (tenant_id,), ) row = cursor.fetchone() if row: @@ -1156,7 +1170,10 @@ class LocalizationManager: self._close_if_file_db(conn) def set_tenant_data_center( - self, tenant_id: str, region_code: str, data_residency: str = "regional", + self, + tenant_id: str, + region_code: str, + data_residency: str = "regional", ) -> TenantDataCenterMapping: conn = self._get_connection() try: @@ -1222,7 +1239,8 @@ class LocalizationManager: try: cursor = conn.cursor() cursor.execute( - "SELECT * FROM localized_payment_methods WHERE provider = ?", (provider,), + "SELECT * FROM localized_payment_methods WHERE provider = ?", + (provider,), ) row = cursor.fetchone() if row: @@ -1232,7 +1250,10 @@ class LocalizationManager: self._close_if_file_db(conn) def list_payment_methods( - self, country_code: str | None = None, currency: str | None = None, active_only: bool = True, + self, + country_code: str | None = None, + currency: str | None = None, + active_only: bool = True, ) -> list[LocalizedPaymentMethod]: conn = self._get_connection() try: @@ -1255,7 +1276,9 @@ class LocalizationManager: self._close_if_file_db(conn) def get_localized_payment_methods( - self, country_code: str, language: str = "en", + self, + country_code: str, + language: str = "en", ) -> list[dict[str, Any]]: methods = self.list_payment_methods(country_code=country_code) result = [] @@ -1287,7 +1310,9 @@ class LocalizationManager: self._close_if_file_db(conn) def list_country_configs( - self, region: str | None = None, active_only: bool = True, + self, + region: str | None = None, + active_only: bool = True, ) -> list[CountryConfig]: conn = self._get_connection() try: @@ -1345,14 +1370,19 @@ class LocalizationManager: return dt.strftime("%Y-%m-%d %H:%M") def format_number( - self, number: float, language: str = "en", decimal_places: int | None = None, + self, + number: float, + language: str = "en", + decimal_places: int | None = None, ) -> str: try: if BABEL_AVAILABLE: try: locale = Locale.parse(language.replace("_", "-")) return numbers.format_decimal( - number, locale=locale, decimal_quantization=(decimal_places is not None), + number, + locale=locale, + decimal_quantization=(decimal_places is not None), ) except (ValueError, AttributeError): pass @@ -1514,7 +1544,9 @@ class LocalizationManager: self._close_if_file_db(conn) def detect_user_preferences( - self, accept_language: str | None = None, ip_country: str | None = None, + self, + accept_language: str | None = None, + ip_country: str | None = None, ) -> dict[str, str]: preferences = {"language": "en", "country": "US", "timezone": "UTC", "currency": "USD"} if accept_language: diff --git a/backend/main.py b/backend/main.py index fae1b67..1456d02 100644 --- a/backend/main.py +++ b/backend/main.py @@ -18,6 +18,7 @@ from datetime import datetime, timedelta from typing import Any, Optional import httpx +from export_manager import ExportEntity, ExportRelation, ExportTranscript from fastapi import ( Body, Depends, @@ -33,11 +34,9 @@ from fastapi import ( from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, PlainTextResponse, StreamingResponse from fastapi.staticfiles import StaticFiles -from pydantic import BaseModel, Field - -from export_manager import ExportEntity, ExportRelation, ExportTranscript from ops_manager import OpsManager from plugin_manager import PluginManager +from pydantic import BaseModel, Field # Configure logger logger = logging.getLogger(__name__) @@ -777,7 +776,8 @@ class WorkflowListResponse(BaseModel): class WorkflowTaskCreate(BaseModel): name: str = Field(..., description="任务名称") task_type: str = Field( - ..., description="任务类型: analyze, align, discover_relations, notify, custom", + ..., + description="任务类型: analyze, align, discover_relations, notify, custom", ) config: dict = Field(default_factory=dict, description="任务配置") order: int = Field(default=0, description="执行顺序") @@ -979,7 +979,9 @@ async def delete_entity(entity_id: str, _=Depends(verify_api_key)): @app.post("/api/v1/entities/{entity_id}/merge", tags=["Entities"]) async def merge_entities_endpoint( - entity_id: str, merge_req: EntityMergeRequest, _=Depends(verify_api_key), + entity_id: str, + merge_req: EntityMergeRequest, + _=Depends(verify_api_key), ): """合并两个实体""" if not DB_AVAILABLE: @@ -1012,7 +1014,9 @@ async def merge_entities_endpoint( @app.post("/api/v1/projects/{project_id}/relations", tags=["Relations"]) async def create_relation_endpoint( - project_id: str, relation: RelationCreate, _=Depends(verify_api_key), + project_id: str, + relation: RelationCreate, + _=Depends(verify_api_key), ): """创建新的实体关系""" if not DB_AVAILABLE: @@ -1063,7 +1067,9 @@ async def update_relation(relation_id: str, relation: RelationCreate, _=Depends( db = get_db_manager() updated = db.update_relation( - relation_id=relation_id, relation_type=relation.relation_type, evidence=relation.evidence, + relation_id=relation_id, + relation_type=relation.relation_type, + evidence=relation.evidence, ) return { @@ -1094,7 +1100,9 @@ async def get_transcript(transcript_id: str, _=Depends(verify_api_key)): @app.put("/api/v1/transcripts/{transcript_id}", tags=["Transcripts"]) async def update_transcript( - transcript_id: str, update: TranscriptUpdate, _=Depends(verify_api_key), + transcript_id: str, + update: TranscriptUpdate, + _=Depends(verify_api_key), ): """更新转录文本(人工修正)""" if not DB_AVAILABLE: @@ -1129,7 +1137,9 @@ class ManualEntityCreate(BaseModel): @app.post("/api/v1/projects/{project_id}/entities", tags=["Entities"]) async def create_manual_entity( - project_id: str, entity: ManualEntityCreate, _=Depends(verify_api_key), + project_id: str, + entity: ManualEntityCreate, + _=Depends(verify_api_key), ): """手动创建实体(划词新建)""" if not DB_AVAILABLE: @@ -1165,7 +1175,7 @@ async def create_manual_entity( start_pos=entity.start_pos, end_pos=entity.end_pos, text_snippet=text[ - max(0, entity.start_pos - 20): min(len(text), entity.end_pos + 20) + max(0, entity.start_pos - 20) : min(len(text), entity.end_pos + 20) ], confidence=1.0, ) @@ -1386,7 +1396,10 @@ async def upload_audio(project_id: str, file: UploadFile = File(...), _=Depends( ), ) ent_model = EntityModel( - id=new_ent.id, name=new_ent.name, type=new_ent.type, definition=new_ent.definition, + id=new_ent.id, + name=new_ent.name, + type=new_ent.type, + definition=new_ent.definition, ) entity_name_to_id[raw_ent["name"]] = new_ent.id @@ -1407,7 +1420,7 @@ async def upload_audio(project_id: str, file: UploadFile = File(...), _=Depends( start_pos=pos, end_pos=pos + len(name), text_snippet=full_text[ - max(0, pos - 20): min(len(full_text), pos + len(name) + 20) + max(0, pos - 20) : min(len(full_text), pos + len(name) + 20) ], confidence=1.0, ) @@ -1465,7 +1478,7 @@ async def upload_document(project_id: str, file: UploadFile = File(...), _=Depen try: result = processor.process(content, file.filename) except (OSError, ValueError, TypeError, RuntimeError) as e: - raise HTTPException(status_code=400, detail=f"Document processing failed: {str(e)}") + raise HTTPException(status_code=400, detail=f"Document processing failed: {e!s}") # 保存文档转录记录 transcript_id = str(uuid.uuid4())[:UUID_LENGTH] @@ -1533,7 +1546,7 @@ async def upload_document(project_id: str, file: UploadFile = File(...), _=Depen start_pos=pos, end_pos=pos + len(name), text_snippet=full_text[ - max(0, pos - 20): min(len(full_text), pos + len(name) + 20) + max(0, pos - 20) : min(len(full_text), pos + len(name) + 20) ], confidence=1.0, ) @@ -1675,7 +1688,9 @@ async def add_glossary_term(project_id: str, term: GlossaryTermCreate, _=Depends raise HTTPException(status_code=404, detail="Project not found") term_id = db.add_glossary_term( - project_id=project_id, term=term.term, pronunciation=term.pronunciation, + project_id=project_id, + term=term.term, + pronunciation=term.pronunciation, ) return {"id": term_id, "term": term.term, "pronunciation": term.pronunciation, "success": True} @@ -1708,7 +1723,9 @@ async def delete_glossary_term(term_id: str, _=Depends(verify_api_key)): @app.post("/api/v1/projects/{project_id}/align-entities") async def align_project_entities( - project_id: str, threshold: float = 0.85, _=Depends(verify_api_key), + project_id: str, + threshold: float = 0.85, + _=Depends(verify_api_key), ): """运行实体对齐算法,合并相似实体""" if not DB_AVAILABLE: @@ -1732,7 +1749,11 @@ async def align_project_entities( continue similar = aligner.find_similar_entity( - project_id, entity.name, entity.definition, exclude_id=entity.id, threshold=threshold, + project_id, + entity.name, + entity.definition, + exclude_id=entity.id, + threshold=threshold, ) if similar: @@ -1805,9 +1826,9 @@ async def get_project_transcripts(project_id: str, _=Depends(verify_api_key)): "filename": t["filename"], "type": t.get("type", "audio"), "created_at": t["created_at"], - "preview": t["full_text"][:100] + "..." - if len(t["full_text"]) > 100 - else t["full_text"], + "preview": ( + t["full_text"][:100] + "..." if len(t["full_text"]) > 100 else t["full_text"] + ), } for t in transcripts ] @@ -1888,7 +1909,8 @@ async def agent_query(project_id: str, query: AgentQuery, _=Depends(verify_api_k async def stream_response(): messages = [ ChatMessage( - role="system", content="你是一个专业的项目分析助手,擅长从会议记录中提取洞察。", + role="system", + content="你是一个专业的项目分析助手,擅长从会议记录中提取洞察。", ), ChatMessage( role="user", @@ -2149,9 +2171,9 @@ async def search_entities(project_id: str, q: str, _=Depends(verify_api_key)): @app.get("/api/v1/projects/{project_id}/timeline") async def get_project_timeline( project_id: str, - entity_id: str = None, - start_date: str = None, - end_date: str = None, + entity_id: str | None = None, + start_date: str | None = None, + end_date: str | None = None, _=Depends(verify_api_key), ): """获取项目时间线 - 按时间顺序的实体提及和关系事件""" @@ -2271,7 +2293,10 @@ async def reasoning_query(project_id: str, query: ReasoningQuery, _=Depends(veri @app.post("/api/v1/projects/{project_id}/reasoning/inference-path") async def find_inference_path( - project_id: str, start_entity: str, end_entity: str, _=Depends(verify_api_key), + project_id: str, + start_entity: str, + end_entity: str, + _=Depends(verify_api_key), ): """ 发现两个实体之间的推理路径 @@ -2354,7 +2379,9 @@ async def project_summary(project_id: str, req: SummaryRequest, _=Depends(verify # 生成总结 summary = await reasoner.summarize_project( - project_context=project_context, graph_data=graph_data, summary_type=req.summary_type, + project_context=project_context, + graph_data=graph_data, + summary_type=req.summary_type, ) return {"project_id": project_id, "summary_type": req.summary_type, **summary**summary} @@ -2402,7 +2429,9 @@ class EntityAttributeBatchSet(BaseModel): @app.post("/api/v1/projects/{project_id}/attribute-templates") async def create_attribute_template_endpoint( - project_id: str, template: AttributeTemplateCreate, _=Depends(verify_api_key), + project_id: str, + template: AttributeTemplateCreate, + _=Depends(verify_api_key), ): """创建属性模板""" if not DB_AVAILABLE: @@ -2485,7 +2514,9 @@ async def get_attribute_template_endpoint(template_id: str, _=Depends(verify_api @app.put("/api/v1/attribute-templates/{template_id}") async def update_attribute_template_endpoint( - template_id: str, update: AttributeTemplateUpdate, _=Depends(verify_api_key), + template_id: str, + update: AttributeTemplateUpdate, + _=Depends(verify_api_key), ): """更新属性模板""" if not DB_AVAILABLE: @@ -2519,7 +2550,9 @@ async def delete_attribute_template_endpoint(template_id: str, _=Depends(verify_ @app.post("/api/v1/entities/{entity_id}/attributes") async def set_entity_attribute_endpoint( - entity_id: str, attr: EntityAttributeSet, _=Depends(verify_api_key), + entity_id: str, + attr: EntityAttributeSet, + _=Depends(verify_api_key), ): """设置实体属性值""" if not DB_AVAILABLE: @@ -2550,7 +2583,8 @@ async def set_entity_attribute_endpoint( # 检查是否已存在 conn = db.get_conn() existing = conn.execute( - "SELECT * FROM entity_attributes WHERE entity_id = ? AND name = ?", (entity_id, attr.name), + "SELECT * FROM entity_attributes WHERE entity_id = ? AND name = ?", + (entity_id, attr.name), ).fetchone() now = datetime.now().isoformat() @@ -2623,7 +2657,9 @@ async def set_entity_attribute_endpoint( @app.post("/api/v1/entities/{entity_id}/attributes/batch") async def batch_set_entity_attributes_endpoint( - entity_id: str, batch: EntityAttributeBatchSet, _=Depends(verify_api_key), + entity_id: str, + batch: EntityAttributeBatchSet, + _=Depends(verify_api_key), ): """批量设置实体属性值""" if not DB_AVAILABLE: @@ -2645,7 +2681,9 @@ async def batch_set_entity_attributes_endpoint( value=attr_data.value, ) db.set_entity_attribute( - new_attr, changed_by="user", change_reason=batch.change_reason or "批量更新", + new_attr, + changed_by="user", + change_reason=batch.change_reason or "批量更新", ) results.append( { @@ -2690,7 +2728,10 @@ async def get_entity_attributes_endpoint(entity_id: str, _=Depends(verify_api_ke @app.delete("/api/v1/entities/{entity_id}/attributes/{template_id}") async def delete_entity_attribute_endpoint( - entity_id: str, template_id: str, reason: str | None = "", _=Depends(verify_api_key), + entity_id: str, + template_id: str, + reason: str | None = "", + _=Depends(verify_api_key), ): """删除实体属性值""" if not DB_AVAILABLE: @@ -2707,7 +2748,9 @@ async def delete_entity_attribute_endpoint( @app.get("/api/v1/entities/{entity_id}/attributes/history") async def get_entity_attribute_history_endpoint( - entity_id: str, limit: int = 50, _=Depends(verify_api_key), + entity_id: str, + limit: int = 50, + _=Depends(verify_api_key), ): """获取实体的属性变更历史""" if not DB_AVAILABLE: @@ -2732,7 +2775,9 @@ async def get_entity_attribute_history_endpoint( @app.get("/api/v1/attribute-templates/{template_id}/history") async def get_template_history_endpoint( - template_id: str, limit: int = 50, _=Depends(verify_api_key), + template_id: str, + limit: int = 50, + _=Depends(verify_api_key), ): """获取属性模板的所有变更历史(跨实体)""" if not DB_AVAILABLE: @@ -3098,7 +3143,12 @@ async def export_report_pdf_endpoint(project_id: str, _=Depends(verify_api_key)) export_mgr = get_export_manager() pdf_bytes = export_mgr.export_project_report_pdf( - project_id, project.name, entities, relations, transcripts, summary, + project_id, + project.name, + entities, + relations, + transcripts, + summary, ) return StreamingResponse( @@ -3171,7 +3221,11 @@ async def export_project_json_endpoint(project_id: str, _=Depends(verify_api_key export_mgr = get_export_manager() json_content = export_mgr.export_project_json( - project_id, project.name, entities, relations, transcripts, + project_id, + project.name, + entities, + relations, + transcripts, ) return StreamingResponse( @@ -3370,7 +3424,9 @@ async def find_shortest_path(request: PathQueryRequest, _=Depends(verify_api_key raise HTTPException(status_code=503, detail="Neo4j not connected") path = manager.find_shortest_path( - request.source_entity_id, request.target_entity_id, request.max_depth, + request.source_entity_id, + request.target_entity_id, + request.max_depth, ) if not path: @@ -3393,7 +3449,9 @@ async def find_all_paths(request: PathQueryRequest, _=Depends(verify_api_key)): raise HTTPException(status_code=503, detail="Neo4j not connected") paths = manager.find_all_paths( - request.source_entity_id, request.target_entity_id, request.max_depth, + request.source_entity_id, + request.target_entity_id, + request.max_depth, ) return { @@ -3406,7 +3464,10 @@ async def find_all_paths(request: PathQueryRequest, _=Depends(verify_api_key)): @app.get("/api/v1/entities/{entity_id}/neighbors") async def get_entity_neighbors( - entity_id: str, relation_type: str = None, limit: int = 50, _=Depends(verify_api_key), + entity_id: str, + relation_type: str | None = None, + limit: int = 50, + _=Depends(verify_api_key), ): """获取实体的邻居节点""" if not NEO4J_AVAILABLE: @@ -3441,7 +3502,9 @@ async def get_common_neighbors(entity_id1: str, entity_id2: str, _=Depends(verif @app.get("/api/v1/projects/{project_id}/graph/centrality") async def get_centrality_analysis( - project_id: str, metric: str = "degree", _=Depends(verify_api_key), + project_id: str, + metric: str = "degree", + _=Depends(verify_api_key), ): """获取中心性分析结果""" if not NEO4J_AVAILABLE: @@ -3544,7 +3607,10 @@ async def create_api_key(request: ApiKeyCreate, _=Depends(verify_api_key)): @app.get("/api/v1/api-keys", response_model=ApiKeyListResponse, tags=["API Keys"]) async def list_api_keys( - status: str | None = None, limit: int = 100, offset: int = 0, _=Depends(verify_api_key), + status: str | None = None, + limit: int = 100, + offset: int = 0, + _=Depends(verify_api_key), ): """ 列出所有 API Keys @@ -3689,13 +3755,18 @@ async def get_api_key_stats(key_id: str, days: int = 30, _=Depends(verify_api_ke stats = key_manager.get_call_stats(key_id, days=days) return ApiStatsResponse( - summary=ApiCallStats(**stats["summary"]), endpoints=stats["endpoints"], daily=stats["daily"], + summary=ApiCallStats(**stats["summary"]), + endpoints=stats["endpoints"], + daily=stats["daily"], ) @app.get("/api/v1/api-keys/{key_id}/logs", response_model=ApiLogsResponse, tags=["API Keys"]) async def get_api_key_logs( - key_id: str, limit: int = 100, offset: int = 0, _=Depends(verify_api_key), + key_id: str, + limit: int = 100, + offset: int = 0, + _=Depends(verify_api_key), ): """ 获取 API Key 的调用日志 @@ -3739,7 +3810,10 @@ async def get_rate_limit_status(request: Request, _=Depends(verify_api_key)): """获取当前请求的限流状态""" if not RATE_LIMITER_AVAILABLE: return RateLimitStatus( - limit=60, remaining=60, reset_time=int(time.time()) + 60, window="minute", + limit=60, + remaining=60, + reset_time=int(time.time()) + 60, + window="minute", ) limiter = get_rate_limiter() @@ -3757,7 +3831,10 @@ async def get_rate_limit_status(request: Request, _=Depends(verify_api_key)): info = await limiter.get_limit_info(limit_key) return RateLimitStatus( - limit=limit, remaining=info.remaining, reset_time=info.reset_time, window="minute", + limit=limit, + remaining=info.remaining, + reset_time=info.reset_time, + window="minute", ) @@ -3803,7 +3880,7 @@ async def system_status(): # ==================== Phase 7: Workflow Automation Endpoints ==================== # Workflow Manager singleton -_workflow_manager: Any = None +_workflow_manager: Any | None = None def get_workflow_manager_instance() -> Any: @@ -3960,7 +4037,9 @@ async def get_workflow_endpoint(workflow_id: str, _=Depends(verify_api_key)): @app.patch("/api/v1/workflows/{workflow_id}", response_model=WorkflowResponse, tags=["Workflows"]) async def update_workflow_endpoint( - workflow_id: str, request: WorkflowUpdate, _=Depends(verify_api_key), + workflow_id: str, + request: WorkflowUpdate, + _=Depends(verify_api_key), ): """更新工作流""" if not WORKFLOW_AVAILABLE: @@ -4017,7 +4096,9 @@ async def delete_workflow_endpoint(workflow_id: str, _=Depends(verify_api_key)): tags=["Workflows"], ) async def trigger_workflow_endpoint( - workflow_id: str, request: WorkflowTriggerRequest = None, _=Depends(verify_api_key), + workflow_id: str, + request: WorkflowTriggerRequest | None = None, + _=Depends(verify_api_key), ): """手动触发工作流""" if not WORKFLOW_AVAILABLE: @@ -4027,7 +4108,8 @@ async def trigger_workflow_endpoint( try: result = await manager.execute_workflow( - workflow_id, input_data=request.input_data if request else {}, + workflow_id, + input_data=request.input_data if request else {}, ) return WorkflowTriggerResponse( @@ -4210,7 +4292,9 @@ async def get_webhook_endpoint(webhook_id: str, _=Depends(verify_api_key)): @app.patch("/api/v1/webhooks/{webhook_id}", response_model=WebhookResponse, tags=["Webhooks"]) async def update_webhook_endpoint( - webhook_id: str, request: WebhookUpdate, _=Depends(verify_api_key), + webhook_id: str, + request: WebhookUpdate, + _=Depends(verify_api_key), ): """更新 Webhook 配置""" if not WORKFLOW_AVAILABLE: @@ -4390,7 +4474,8 @@ async def upload_video_endpoint( if not result.success: raise HTTPException( - status_code=500, detail=f"Video processing failed: {result.error_message}", + status_code=500, + detail=f"Video processing failed: {result.error_message}", ) # 保存视频信息到数据库 @@ -4521,9 +4606,9 @@ async def upload_video_endpoint( status="completed", audio_extracted=bool(result.audio_path), frame_count=len(result.frames), - ocr_text_preview=result.full_text[:200] + "..." - if len(result.full_text) > 200 - else result.full_text, + ocr_text_preview=( + result.full_text[:200] + "..." if len(result.full_text) > 200 else result.full_text + ), message="Video processed successfully", ) @@ -4572,7 +4657,8 @@ async def upload_image_endpoint( if not result.success: raise HTTPException( - status_code=500, detail=f"Image processing failed: {result.error_message}", + status_code=500, + detail=f"Image processing failed: {result.error_message}", ) # 保存图片信息到数据库 @@ -4668,9 +4754,9 @@ async def upload_image_endpoint( project_id=project_id, filename=file.filename, image_type=result.image_type, - ocr_text_preview=result.ocr_text[:200] + "..." - if len(result.ocr_text) > 200 - else result.ocr_text, + ocr_text_preview=( + result.ocr_text[:200] + "..." if len(result.ocr_text) > 200 else result.ocr_text + ), description=result.description, entity_count=len(result.entities), status="completed", @@ -4679,7 +4765,9 @@ async def upload_image_endpoint( @app.post("/api/v1/projects/{project_id}/upload-images-batch", tags=["Multimodal"]) async def upload_images_batch_endpoint( - project_id: str, files: list[UploadFile] = File(...), _=Depends(verify_api_key), + project_id: str, + files: list[UploadFile] = File(...), + _=Depends(verify_api_key), ): """ 批量上传图片文件进行处理 @@ -4768,7 +4856,9 @@ async def upload_images_batch_endpoint( tags=["Multimodal"], ) async def align_multimodal_entities_endpoint( - project_id: str, threshold: float = 0.85, _=Depends(verify_api_key), + project_id: str, + threshold: float = 0.85, + _=Depends(verify_api_key), ): """ 跨模态实体对齐 @@ -4795,7 +4885,8 @@ async def align_multimodal_entities_endpoint( # 获取多模态提及 conn = db.get_conn() mentions = conn.execute( - """SELECT * FROM multimodal_mentions WHERE project_id = ?""", (project_id,), + """SELECT * FROM multimodal_mentions WHERE project_id = ?""", + (project_id,), ).fetchall() conn.close() @@ -4894,12 +4985,14 @@ async def get_multimodal_stats_endpoint(project_id: str, _=Depends(verify_api_ke # 统计视频数量 video_count = conn.execute( - "SELECT COUNT(*) as count FROM videos WHERE project_id = ?", (project_id,), + "SELECT COUNT(*) as count FROM videos WHERE project_id = ?", + (project_id,), ).fetchone()["count"] # 统计图片数量 image_count = conn.execute( - "SELECT COUNT(*) as count FROM images WHERE project_id = ?", (project_id,), + "SELECT COUNT(*) as count FROM images WHERE project_id = ?", + (project_id,), ).fetchone()["count"] # 统计多模态实体提及 @@ -5001,9 +5094,9 @@ async def list_project_images_endpoint(project_id: str, _=Depends(verify_api_key else img["ocr_text"] ), "description": img["description"], - "entity_count": len(json.loads(img["extracted_entities"])) - if img["extracted_entities"] - else 0, + "entity_count": ( + len(json.loads(img["extracted_entities"])) if img["extracted_entities"] else 0 + ), "status": img["status"], "created_at": img["created_at"], } @@ -5333,7 +5426,8 @@ class WebDAVSyncCreate(BaseModel): password: str = Field(..., description="密码") remote_path: str = Field(default="/insightflow", description="远程路径") sync_mode: str = Field( - default="bidirectional", description="同步模式: bidirectional, upload_only, download_only", + default="bidirectional", + description="同步模式: bidirectional, upload_only, download_only", ) sync_interval: int = Field(default=3600, description="同步间隔(秒)") @@ -5538,7 +5632,8 @@ async def delete_plugin_endpoint(plugin_id: str, _=Depends(verify_api_key)): tags=["Chrome Extension"], ) async def create_chrome_token_endpoint( - request: ChromeExtensionTokenCreate, _=Depends(verify_api_key), + request: ChromeExtensionTokenCreate, + _=Depends(verify_api_key), ): """ 创建 Chrome 扩展令牌 @@ -5734,7 +5829,9 @@ async def create_dingtalk_session_endpoint(request: BotSessionCreate, _=Depends( @app.get("/api/v1/plugins/bot/{bot_type}/sessions", tags=["Bot"]) async def list_bot_sessions_endpoint( - bot_type: str, project_id: str | None = None, _=Depends(verify_api_key), + bot_type: str, + project_id: str | None = None, + _=Depends(verify_api_key), ): """列出机器人会话""" if not PLUGIN_MANAGER_AVAILABLE: @@ -5812,7 +5909,9 @@ async def bot_webhook_endpoint(bot_type: str, request: Request): if not session: # 自动创建会话 session = handler.create_session( - session_id=session_id, session_name=f"Auto-{session_id[:8]}", webhook_url="", + session_id=session_id, + session_name=f"Auto-{session_id[:8]}", + webhook_url="", ) # 处理消息 @@ -5827,7 +5926,10 @@ async def bot_webhook_endpoint(bot_type: str, request: Request): @app.post("/api/v1/plugins/bot/{bot_type}/sessions/{session_id}/send", tags=["Bot"]) async def send_bot_message_endpoint( - bot_type: str, session_id: str, message: str, _=Depends(verify_api_key), + bot_type: str, + session_id: str, + message: str, + _=Depends(verify_api_key), ): """发送消息到机器人会话""" if not PLUGIN_MANAGER_AVAILABLE: @@ -5939,7 +6041,9 @@ async def create_make_endpoint(request: WebhookEndpointCreate, _=Depends(verify_ @app.get("/api/v1/plugins/integrations/{endpoint_type}", tags=["Integrations"]) async def list_integration_endpoints_endpoint( - endpoint_type: str, project_id: str | None = None, _=Depends(verify_api_key), + endpoint_type: str, + project_id: str | None = None, + _=Depends(verify_api_key), ): """列出集成端点""" if not PLUGIN_MANAGER_AVAILABLE: @@ -6006,13 +6110,18 @@ async def test_integration_endpoint(endpoint_id: str, _=Depends(verify_api_key)) result = await handler.test_endpoint(endpoint) return WebhookTestResponse( - success=result["success"], endpoint_id=endpoint_id, message=result["message"], + success=result["success"], + endpoint_id=endpoint_id, + message=result["message"], ) @app.post("/api/v1/plugins/integrations/{endpoint_id}/trigger", tags=["Integrations"]) async def trigger_integration_endpoint( - endpoint_id: str, event_type: str, data: dict, _=Depends(verify_api_key), + endpoint_id: str, + event_type: str, + data: dict, + _=Depends(verify_api_key), ): """手动触发集成端点""" if not PLUGIN_MANAGER_AVAILABLE: @@ -6124,7 +6233,9 @@ async def list_webdav_syncs_endpoint(project_id: str | None = None, _=Depends(ve @app.post( - "/api/v1/plugins/webdav/{sync_id}/test", response_model=WebDAVTestResponse, tags=["WebDAV"], + "/api/v1/plugins/webdav/{sync_id}/test", + response_model=WebDAVTestResponse, + tags=["WebDAV"], ) async def test_webdav_connection_endpoint(sync_id: str, _=Depends(verify_api_key)): """测试 WebDAV 连接""" @@ -6423,10 +6534,13 @@ async def regenerate_plugin_key(plugin_id: str, api_key: str = Depends(verify_ap @app.post( - "/api/v1/plugins/chrome/clip", response_model=ChromeClipResponse, tags=["Chrome Extension"], + "/api/v1/plugins/chrome/clip", + response_model=ChromeClipResponse, + tags=["Chrome Extension"], ) async def chrome_clip( - request: ChromeClipRequest, x_api_key: str | None = Header(None, alias="X-API-Key"), + request: ChromeClipRequest, + x_api_key: str | None = Header(None, alias="X-API-Key"), ): """Chrome 插件保存网页内容""" if not PLUGIN_MANAGER_AVAILABLE: @@ -6499,7 +6613,9 @@ URL: {request.url} @app.post("/api/v1/bots/webhook/{platform}", response_model=BotMessageResponse, tags=["Bot"]) async def bot_webhook( - platform: str, request: Request, x_signature: str | None = Header(None, alias="X-Signature"), + platform: str, + request: Request, + x_signature: str | None = Header(None, alias="X-Signature"), ): """接收机器人 Webhook 消息(飞书/钉钉/Slack)""" if not PLUGIN_MANAGER_AVAILABLE: @@ -6567,7 +6683,9 @@ async def list_bot_sessions( @app.post( - "/api/v1/webhook-endpoints", response_model=WebhookEndpointResponse, tags=["Integrations"], + "/api/v1/webhook-endpoints", + response_model=WebhookEndpointResponse, + tags=["Integrations"], ) async def create_integration_webhook_endpoint( plugin_id: str, @@ -6604,10 +6722,13 @@ async def create_integration_webhook_endpoint( @app.get( - "/api/v1/webhook-endpoints", response_model=list[WebhookEndpointResponse], tags=["Integrations"], + "/api/v1/webhook-endpoints", + response_model=list[WebhookEndpointResponse], + tags=["Integrations"], ) async def list_webhook_endpoints( - plugin_id: str | None = None, api_key: str = Depends(verify_api_key), + plugin_id: str | None = None, + api_key: str = Depends(verify_api_key), ): """列出 Webhook 端点""" if not PLUGIN_MANAGER_AVAILABLE: @@ -6801,7 +6922,9 @@ async def trigger_webdav_sync(sync_id: str, api_key: str = Depends(verify_api_ke # 简化版本,仅返回成功 manager.update_webdav_sync( - sync_id, last_sync_at=datetime.now().isoformat(), last_sync_status="running", + sync_id, + last_sync_at=datetime.now().isoformat(), + last_sync_status="running", ) return {"success": True, "sync_id": sync_id, "status": "running", "message": "Sync started"} @@ -7030,7 +7153,9 @@ async def get_audit_stats( tags=["Security"], ) async def enable_project_encryption( - project_id: str, request: EncryptionEnableRequest, api_key: str = Depends(verify_api_key), + project_id: str, + request: EncryptionEnableRequest, + api_key: str = Depends(verify_api_key), ): """启用项目端到端加密""" if not SECURITY_MANAGER_AVAILABLE: @@ -7054,7 +7179,9 @@ async def enable_project_encryption( @app.post("/api/v1/projects/{project_id}/encryption/disable", tags=["Security"]) async def disable_project_encryption( - project_id: str, request: EncryptionEnableRequest, api_key: str = Depends(verify_api_key), + project_id: str, + request: EncryptionEnableRequest, + api_key: str = Depends(verify_api_key), ): """禁用项目加密""" if not SECURITY_MANAGER_AVAILABLE: @@ -7071,7 +7198,9 @@ async def disable_project_encryption( @app.post("/api/v1/projects/{project_id}/encryption/verify", tags=["Security"]) async def verify_encryption_password( - project_id: str, request: EncryptionEnableRequest, api_key: str = Depends(verify_api_key), + project_id: str, + request: EncryptionEnableRequest, + api_key: str = Depends(verify_api_key), ): """验证加密密码""" if not SECURITY_MANAGER_AVAILABLE: @@ -7118,7 +7247,9 @@ async def get_encryption_config(project_id: str, api_key: str = Depends(verify_a tags=["Security"], ) async def create_masking_rule( - project_id: str, request: MaskingRuleCreateRequest, api_key: str = Depends(verify_api_key), + project_id: str, + request: MaskingRuleCreateRequest, + api_key: str = Depends(verify_api_key), ): """创建数据脱敏规则""" if not SECURITY_MANAGER_AVAILABLE: @@ -7162,7 +7293,9 @@ async def create_masking_rule( tags=["Security"], ) async def get_masking_rules( - project_id: str, active_only: bool = True, api_key: str = Depends(verify_api_key), + project_id: str, + active_only: bool = True, + api_key: str = Depends(verify_api_key), ): """获取项目脱敏规则""" if not SECURITY_MANAGER_AVAILABLE: @@ -7261,7 +7394,9 @@ async def delete_masking_rule(rule_id: str, api_key: str = Depends(verify_api_ke tags=["Security"], ) async def apply_masking( - project_id: str, request: MaskingApplyRequest, api_key: str = Depends(verify_api_key), + project_id: str, + request: MaskingApplyRequest, + api_key: str = Depends(verify_api_key), ): """应用脱敏规则到文本""" if not SECURITY_MANAGER_AVAILABLE: @@ -7281,7 +7416,9 @@ async def apply_masking( applied_rules = [r.name for r in rules if r.is_active] return MaskingApplyResponse( - original_text=request.text, masked_text=masked_text, applied_rules=applied_rules, + original_text=request.text, + masked_text=masked_text, + applied_rules=applied_rules, ) @@ -7294,7 +7431,9 @@ async def apply_masking( tags=["Security"], ) async def create_access_policy( - project_id: str, request: AccessPolicyCreateRequest, api_key: str = Depends(verify_api_key), + project_id: str, + request: AccessPolicyCreateRequest, + api_key: str = Depends(verify_api_key), ): """创建数据访问策略""" if not SECURITY_MANAGER_AVAILABLE: @@ -7322,9 +7461,9 @@ async def create_access_policy( allowed_users=json.loads(policy.allowed_users) if policy.allowed_users else None, allowed_roles=json.loads(policy.allowed_roles) if policy.allowed_roles else None, allowed_ips=json.loads(policy.allowed_ips) if policy.allowed_ips else None, - time_restrictions=json.loads(policy.time_restrictions) - if policy.time_restrictions - else None, + time_restrictions=( + json.loads(policy.time_restrictions) if policy.time_restrictions else None + ), max_access_count=policy.max_access_count, require_approval=policy.require_approval, is_active=policy.is_active, @@ -7339,7 +7478,9 @@ async def create_access_policy( tags=["Security"], ) async def get_access_policies( - project_id: str, active_only: bool = True, api_key: str = Depends(verify_api_key), + project_id: str, + active_only: bool = True, + api_key: str = Depends(verify_api_key), ): """获取项目访问策略""" if not SECURITY_MANAGER_AVAILABLE: @@ -7357,9 +7498,9 @@ async def get_access_policies( allowed_users=json.loads(policy.allowed_users) if policy.allowed_users else None, allowed_roles=json.loads(policy.allowed_roles) if policy.allowed_roles else None, allowed_ips=json.loads(policy.allowed_ips) if policy.allowed_ips else None, - time_restrictions=json.loads(policy.time_restrictions) - if policy.time_restrictions - else None, + time_restrictions=( + json.loads(policy.time_restrictions) if policy.time_restrictions else None + ), max_access_count=policy.max_access_count, require_approval=policy.require_approval, is_active=policy.is_active, @@ -7372,7 +7513,10 @@ async def get_access_policies( @app.post("/api/v1/access-policies/{policy_id}/check", tags=["Security"]) async def check_access_permission( - policy_id: str, user_id: str, user_ip: str | None = None, api_key: str = Depends(verify_api_key), + policy_id: str, + user_id: str, + user_ip: str | None = None, + api_key: str = Depends(verify_api_key), ): """检查访问权限""" if not SECURITY_MANAGER_AVAILABLE: @@ -7459,7 +7603,9 @@ async def approve_access_request( tags=["Security"], ) async def reject_access_request( - request_id: str, rejected_by: str, api_key: str = Depends(verify_api_key), + request_id: str, + rejected_by: str, + api_key: str = Depends(verify_api_key), ): """拒绝访问请求""" if not SECURITY_MANAGER_AVAILABLE: @@ -7537,7 +7683,9 @@ class TeamMemberRoleUpdate(BaseModel): @app.post("/api/v1/projects/{project_id}/shares") async def create_share_link( - project_id: str, request: ShareLinkCreate, created_by: str = "current_user", + project_id: str, + request: ShareLinkCreate, + created_by: str = "current_user", ): """创建项目分享链接""" if not COLLABORATION_AVAILABLE: @@ -7677,7 +7825,10 @@ async def revoke_share_link(share_id: str, revoked_by: str = "current_user"): @app.post("/api/v1/projects/{project_id}/comments") async def add_comment( - project_id: str, request: CommentCreate, author: str = "current_user", author_name: str = "User", + project_id: str, + request: CommentCreate, + author: str = "current_user", + author_name: str = "User", ): """添加评论""" if not COLLABORATION_AVAILABLE: @@ -7909,7 +8060,9 @@ async def revert_change(record_id: str, reverted_by: str = "current_user"): @app.post("/api/v1/projects/{project_id}/members") async def invite_team_member( - project_id: str, request: TeamMemberInvite, invited_by: str = "current_user", + project_id: str, + request: TeamMemberInvite, + invited_by: str = "current_user", ): """邀请团队成员""" if not COLLABORATION_AVAILABLE: @@ -7965,7 +8118,9 @@ async def list_team_members(project_id: str): @app.put("/api/v1/members/{member_id}/role") async def update_member_role( - member_id: str, request: TeamMemberRoleUpdate, updated_by: str = "current_user", + member_id: str, + request: TeamMemberRoleUpdate, + updated_by: str = "current_user", ): """更新成员角色""" if not COLLABORATION_AVAILABLE: @@ -8039,7 +8194,9 @@ class SemanticSearchRequest(BaseModel): @app.post("/api/v1/search/fulltext", tags=["Search"]) async def fulltext_search( - project_id: str, request: FullTextSearchRequest, _=Depends(verify_api_key), + project_id: str, + request: FullTextSearchRequest, + _=Depends(verify_api_key), ): """全文搜索""" if not SEARCH_MANAGER_AVAILABLE: @@ -8080,7 +8237,9 @@ async def fulltext_search( @app.post("/api/v1/search/semantic", tags=["Search"]) async def semantic_search( - project_id: str, request: SemanticSearchRequest, _=Depends(verify_api_key), + project_id: str, + request: SemanticSearchRequest, + _=Depends(verify_api_key), ): """语义搜索""" if not SEARCH_MANAGER_AVAILABLE: @@ -8123,11 +8282,15 @@ async def find_entity_paths( if find_all: paths = search_manager.path_discovery.find_all_paths( - source_entity_id=entity_id, target_entity_id=target_entity_id, max_depth=max_depth, + source_entity_id=entity_id, + target_entity_id=target_entity_id, + max_depth=max_depth, ) else: path = search_manager.path_discovery.find_shortest_path( - source_entity_id=entity_id, target_entity_id=target_entity_id, max_depth=max_depth, + source_entity_id=entity_id, + target_entity_id=target_entity_id, + max_depth=max_depth, ) paths = [path] if path else [] @@ -8260,7 +8423,10 @@ async def get_performance_metrics( start_time = (datetime.now() - timedelta(hours=hours)).isoformat() metrics = perf_manager.monitor.get_metrics( - metric_type=metric_type, endpoint=endpoint, start_time=start_time, limit=limit, + metric_type=metric_type, + endpoint=endpoint, + start_time=start_time, + limit=limit, ) return { @@ -8364,7 +8530,8 @@ async def cancel_task(task_id: str, _=Depends(verify_api_key)): return {"message": "Task cancelled successfully", "task_id": task_id} else: raise HTTPException( - status_code=400, detail="Failed to cancel task or task already completed", + status_code=400, + detail="Failed to cancel task or task already completed", ) @@ -8449,7 +8616,10 @@ async def create_tenant( manager = get_tenant_manager() try: tenant = manager.create_tenant( - name=request.name, owner_id=user_id, tier=request.tier, description=request.description, + name=request.name, + owner_id=user_id, + tier=request.tier, + description=request.description, ) return { "id": tenant.id, @@ -8465,7 +8635,8 @@ async def create_tenant( @app.get("/api/v1/tenants", tags=["Tenants"]) async def list_my_tenants( - user_id: str = Header(..., description="当前用户ID"), _=Depends(verify_api_key), + user_id: str = Header(..., description="当前用户ID"), + _=Depends(verify_api_key), ): """获取当前用户的所有租户""" if not TENANT_MANAGER_AVAILABLE: @@ -8557,7 +8728,9 @@ async def add_domain(tenant_id: str, request: AddDomainRequest, _=Depends(verify manager = get_tenant_manager() try: domain = manager.add_domain( - tenant_id=tenant_id, domain=request.domain, is_primary=request.is_primary, + tenant_id=tenant_id, + domain=request.domain, + is_primary=request.is_primary, ) # 获取验证指导 @@ -8667,7 +8840,9 @@ async def get_branding(tenant_id: str, _=Depends(verify_api_key)): @app.put("/api/v1/tenants/{tenant_id}/branding", tags=["Tenants"]) async def update_branding( - tenant_id: str, request: UpdateBrandingRequest, _=Depends(verify_api_key), + tenant_id: str, + request: UpdateBrandingRequest, + _=Depends(verify_api_key), ): """更新租户品牌配置""" if not TENANT_MANAGER_AVAILABLE: @@ -8724,7 +8899,10 @@ async def invite_member( manager = get_tenant_manager() try: member = manager.invite_member( - tenant_id=tenant_id, email=request.email, role=request.role, invited_by=user_id, + tenant_id=tenant_id, + email=request.email, + role=request.role, + invited_by=user_id, ) return { @@ -8767,7 +8945,10 @@ async def list_members(tenant_id: str, status: str | None = None, _=Depends(veri @app.put("/api/v1/tenants/{tenant_id}/members/{member_id}", tags=["Tenants"]) async def update_member( - tenant_id: str, member_id: str, request: UpdateMemberRequest, _=Depends(verify_api_key), + tenant_id: str, + member_id: str, + request: UpdateMemberRequest, + _=Depends(verify_api_key), ): """更新成员角色""" if not TENANT_MANAGER_AVAILABLE: @@ -8874,7 +9055,7 @@ async def detailed_health_check(): conn.close() health["components"]["database"] = "ok" except Exception as e: - health["components"]["database"] = f"error: {str(e)}" + health["components"]["database"] = f"error: {e!s}" health["status"] = "degraded" else: health["components"]["database"] = "unavailable" @@ -8888,7 +9069,7 @@ async def detailed_health_check(): if perf_health.get("overall") != "healthy": health["status"] = "degraded" except Exception as e: - health["components"]["performance"] = f"error: {str(e)}" + health["components"]["performance"] = f"error: {e!s}" health["status"] = "degraded" # 搜索管理器检查 @@ -8916,7 +9097,8 @@ class TenantCreate(BaseModel): slug: str = Field(..., description="URL 友好的唯一标识(小写字母、数字、连字符)") description: str = Field(default="", description="租户描述") plan: str = Field( - default="free", description="套餐类型: free, starter, professional, enterprise", + default="free", + description="套餐类型: free, starter, professional, enterprise", ) billing_email: str = Field(default="", description="计费邮箱") @@ -9078,7 +9260,10 @@ async def list_tenants_endpoint( plan_enum = TenantTier(plan) if plan else None tenants = tenant_manager.list_tenants( - status=status_enum, plan=plan_enum, limit=limit, offset=offset, + status=status_enum, + plan=plan_enum, + limit=limit, + offset=offset, ) return [t.to_dict() for t in tenants] @@ -9152,10 +9337,14 @@ async def delete_tenant_endpoint(tenant_id: str, _=Depends(verify_api_key)): @app.post( - "/api/v1/tenants/{tenant_id}/domains", response_model=TenantDomainResponse, tags=["Tenants"], + "/api/v1/tenants/{tenant_id}/domains", + response_model=TenantDomainResponse, + tags=["Tenants"], ) async def add_tenant_domain_endpoint( - tenant_id: str, domain: TenantDomainCreate, _=Depends(verify_api_key), + tenant_id: str, + domain: TenantDomainCreate, + _=Depends(verify_api_key), ): """为租户添加自定义域名""" if not TENANT_MANAGER_AVAILABLE: @@ -9207,7 +9396,9 @@ async def verify_tenant_domain_endpoint(tenant_id: str, domain_id: str, _=Depend @app.post("/api/v1/tenants/{tenant_id}/domains/{domain_id}/activate", tags=["Tenants"]) async def activate_tenant_domain_endpoint( - tenant_id: str, domain_id: str, _=Depends(verify_api_key), + tenant_id: str, + domain_id: str, + _=Depends(verify_api_key), ): """激活已验证的域名""" if not TENANT_MANAGER_AVAILABLE: @@ -9257,7 +9448,9 @@ async def get_tenant_branding_endpoint(tenant_id: str, _=Depends(verify_api_key) @app.put("/api/v1/tenants/{tenant_id}/branding", tags=["Tenants"]) async def update_tenant_branding_endpoint( - tenant_id: str, branding: TenantBrandingUpdate, _=Depends(verify_api_key), + tenant_id: str, + branding: TenantBrandingUpdate, + _=Depends(verify_api_key), ): """更新租户品牌配置""" if not TENANT_MANAGER_AVAILABLE: @@ -9299,7 +9492,10 @@ async def get_tenant_theme_css_endpoint(tenant_id: str): tags=["Tenants"], ) async def invite_tenant_member_endpoint( - tenant_id: str, invite: TenantMemberInvite, request: Request, _=Depends(verify_api_key), + tenant_id: str, + invite: TenantMemberInvite, + request: Request, + _=Depends(verify_api_key), ): """邀请成员加入租户""" if not TENANT_MANAGER_AVAILABLE: @@ -9346,7 +9542,10 @@ async def accept_invitation_endpoint(token: str, user_id: str): tags=["Tenants"], ) async def list_tenant_members_endpoint( - tenant_id: str, status: str | None = None, role: str | None = None, _=Depends(verify_api_key), + tenant_id: str, + status: str | None = None, + role: str | None = None, + _=Depends(verify_api_key), ): """列出租户成员""" if not TENANT_MANAGER_AVAILABLE: @@ -9363,7 +9562,11 @@ async def list_tenant_members_endpoint( @app.put("/api/v1/tenants/{tenant_id}/members/{member_id}/role", tags=["Tenants"]) async def update_member_role_endpoint( - tenant_id: str, member_id: str, role: str, request: Request, _=Depends(verify_api_key), + tenant_id: str, + member_id: str, + role: str, + request: Request, + _=Depends(verify_api_key), ): """更新成员角色""" if not TENANT_MANAGER_AVAILABLE: @@ -9392,7 +9595,10 @@ async def update_member_role_endpoint( @app.delete("/api/v1/tenants/{tenant_id}/members/{member_id}", tags=["Tenants"]) async def remove_tenant_member_endpoint( - tenant_id: str, member_id: str, request: Request, _=Depends(verify_api_key), + tenant_id: str, + member_id: str, + request: Request, + _=Depends(verify_api_key), ): """移除租户成员""" if not TENANT_MANAGER_AVAILABLE: @@ -9418,7 +9624,9 @@ async def remove_tenant_member_endpoint( @app.get( - "/api/v1/tenants/{tenant_id}/roles", response_model=list[TenantRoleResponse], tags=["Tenants"], + "/api/v1/tenants/{tenant_id}/roles", + response_model=list[TenantRoleResponse], + tags=["Tenants"], ) async def list_tenant_roles_endpoint(tenant_id: str, _=Depends(verify_api_key)): """列出租户角色""" @@ -9432,7 +9640,9 @@ async def list_tenant_roles_endpoint(tenant_id: str, _=Depends(verify_api_key)): @app.post("/api/v1/tenants/{tenant_id}/roles", response_model=TenantRoleResponse, tags=["Tenants"]) async def create_tenant_role_endpoint( - tenant_id: str, role: TenantRoleCreate, _=Depends(verify_api_key), + tenant_id: str, + role: TenantRoleCreate, + _=Depends(verify_api_key), ): """创建自定义角色""" if not TENANT_MANAGER_AVAILABLE: @@ -9454,7 +9664,10 @@ async def create_tenant_role_endpoint( @app.put("/api/v1/tenants/{tenant_id}/roles/{role_id}/permissions", tags=["Tenants"]) async def update_role_permissions_endpoint( - tenant_id: str, role_id: str, permissions: list[str], _=Depends(verify_api_key), + tenant_id: str, + role_id: str, + permissions: list[str], + _=Depends(verify_api_key), ): """更新角色权限""" if not TENANT_MANAGER_AVAILABLE: @@ -9549,7 +9762,8 @@ class CreateSubscriptionRequest(BaseModel): plan_id: str = Field(..., description="订阅计划ID") billing_cycle: str = Field(default="monthly", description="计费周期: monthly/yearly") payment_provider: str | None = Field( - default=None, description="支付提供商: stripe/alipay/wechat", + default=None, + description="支付提供商: stripe/alipay/wechat", ) trial_days: int = Field(default=0, description="试用天数") @@ -9687,9 +9901,9 @@ async def create_subscription( "status": subscription.status, "current_period_start": subscription.current_period_start.isoformat(), "current_period_end": subscription.current_period_end.isoformat(), - "trial_start": subscription.trial_start.isoformat() - if subscription.trial_start - else None, + "trial_start": ( + subscription.trial_start.isoformat() if subscription.trial_start else None + ), "trial_end": subscription.trial_end.isoformat() if subscription.trial_end else None, "created_at": subscription.created_at.isoformat(), } @@ -9722,12 +9936,12 @@ async def get_tenant_subscription(tenant_id: str, _=Depends(verify_api_key)): "current_period_start": subscription.current_period_start.isoformat(), "current_period_end": subscription.current_period_end.isoformat(), "cancel_at_period_end": subscription.cancel_at_period_end, - "canceled_at": subscription.canceled_at.isoformat() - if subscription.canceled_at - else None, - "trial_start": subscription.trial_start.isoformat() - if subscription.trial_start - else None, + "canceled_at": ( + subscription.canceled_at.isoformat() if subscription.canceled_at else None + ), + "trial_start": ( + subscription.trial_start.isoformat() if subscription.trial_start else None + ), "trial_end": subscription.trial_end.isoformat() if subscription.trial_end else None, "created_at": subscription.created_at.isoformat(), }, @@ -9736,7 +9950,9 @@ async def get_tenant_subscription(tenant_id: str, _=Depends(verify_api_key)): @app.put("/api/v1/tenants/{tenant_id}/subscription/change-plan", tags=["Subscriptions"]) async def change_subscription_plan( - tenant_id: str, request: ChangePlanRequest, _=Depends(verify_api_key), + tenant_id: str, + request: ChangePlanRequest, + _=Depends(verify_api_key), ): """更改订阅计划""" if not SUBSCRIPTION_MANAGER_AVAILABLE: @@ -9767,7 +9983,9 @@ async def change_subscription_plan( @app.post("/api/v1/tenants/{tenant_id}/subscription/cancel", tags=["Subscriptions"]) async def cancel_subscription( - tenant_id: str, request: CancelSubscriptionRequest, _=Depends(verify_api_key), + tenant_id: str, + request: CancelSubscriptionRequest, + _=Depends(verify_api_key), ): """取消订阅""" if not SUBSCRIPTION_MANAGER_AVAILABLE: @@ -9781,7 +9999,8 @@ async def cancel_subscription( try: updated = manager.cancel_subscription( - subscription_id=subscription.id, at_period_end=request.at_period_end, + subscription_id=subscription.id, + at_period_end=request.at_period_end, ) return { @@ -10144,7 +10363,9 @@ async def get_billing_history( @app.post("/api/v1/tenants/{tenant_id}/checkout/stripe", tags=["Subscriptions"]) async def create_stripe_checkout( - tenant_id: str, request: CreateCheckoutSessionRequest, _=Depends(verify_api_key), + tenant_id: str, + request: CreateCheckoutSessionRequest, + _=Depends(verify_api_key), ): """创建 Stripe Checkout 会话""" if not SUBSCRIPTION_MANAGER_AVAILABLE: @@ -10181,7 +10402,9 @@ async def create_alipay_order( try: order = manager.create_alipay_order( - tenant_id=tenant_id, plan_id=plan_id, billing_cycle=billing_cycle, + tenant_id=tenant_id, + plan_id=plan_id, + billing_cycle=billing_cycle, ) return order @@ -10204,7 +10427,9 @@ async def create_wechat_order( try: order = manager.create_wechat_order( - tenant_id=tenant_id, plan_id=plan_id, billing_cycle=billing_cycle, + tenant_id=tenant_id, + plan_id=plan_id, + billing_cycle=billing_cycle, ) return order @@ -10273,7 +10498,8 @@ async def wechat_webhook(request: Request): class SSOConfigCreate(BaseModel): provider: str = Field( - ..., description="SSO 提供商: wechat_work/dingtalk/feishu/okta/azure_ad/google/custom_saml", + ..., + description="SSO 提供商: wechat_work/dingtalk/feishu/okta/azure_ad/google/custom_saml", ) entity_id: str | None = Field(default=None, description="SAML Entity ID") sso_url: str | None = Field(default=None, description="SAML SSO URL") @@ -10337,7 +10563,8 @@ class AuditExportCreate(BaseModel): end_date: str = Field(..., description="结束日期 (ISO 格式)") filters: dict[str, Any] | None = Field(default_factory=dict, description="过滤条件") compliance_standard: str | None = Field( - default=None, description="合规标准: soc2/iso27001/gdpr/hipaa/pci_dss", + default=None, + description="合规标准: soc2/iso27001/gdpr/hipaa/pci_dss", ) @@ -10345,7 +10572,8 @@ class RetentionPolicyCreate(BaseModel): name: str = Field(..., description="策略名称") description: str | None = Field(default=None, description="策略描述") resource_type: str = Field( - ..., description="资源类型: project/transcript/entity/audit_log/user_data", + ..., + description="资源类型: project/transcript/entity/audit_log/user_data", ) retention_days: int = Field(..., description="保留天数") action: str = Field(..., description="动作: archive/delete/anonymize") @@ -10376,7 +10604,9 @@ class RetentionPolicyUpdate(BaseModel): @app.post("/api/v1/tenants/{tenant_id}/sso-configs", tags=["Enterprise"]) async def create_sso_config_endpoint( - tenant_id: str, config: SSOConfigCreate, _=Depends(verify_api_key), + tenant_id: str, + config: SSOConfigCreate, + _=Depends(verify_api_key), ): """创建 SSO 配置""" if not ENTERPRISE_MANAGER_AVAILABLE: @@ -10487,7 +10717,10 @@ async def get_sso_config_endpoint(tenant_id: str, config_id: str, _=Depends(veri @app.put("/api/v1/tenants/{tenant_id}/sso-configs/{config_id}", tags=["Enterprise"]) async def update_sso_config_endpoint( - tenant_id: str, config_id: str, update: SSOConfigUpdate, _=Depends(verify_api_key), + tenant_id: str, + config_id: str, + update: SSOConfigUpdate, + _=Depends(verify_api_key), ): """更新 SSO 配置""" if not ENTERPRISE_MANAGER_AVAILABLE: @@ -10500,7 +10733,8 @@ async def update_sso_config_endpoint( raise HTTPException(status_code=404, detail="SSO config not found") updated = manager.update_sso_config( - config_id=config_id, **{k: v for k, v in update.dict().items() if v is not None}, + config_id=config_id, + **{k: v for k, v in update.dict().items() if v is not None}, ) return { @@ -10558,7 +10792,9 @@ async def get_sso_metadata_endpoint( @app.post("/api/v1/tenants/{tenant_id}/scim-configs", tags=["Enterprise"]) async def create_scim_config_endpoint( - tenant_id: str, config: SCIMConfigCreate, _=Depends(verify_api_key), + tenant_id: str, + config: SCIMConfigCreate, + _=Depends(verify_api_key), ): """创建 SCIM 配置""" if not ENTERPRISE_MANAGER_AVAILABLE: @@ -10618,7 +10854,10 @@ async def get_scim_config_endpoint(tenant_id: str, _=Depends(verify_api_key)): @app.put("/api/v1/tenants/{tenant_id}/scim-configs/{config_id}", tags=["Enterprise"]) async def update_scim_config_endpoint( - tenant_id: str, config_id: str, update: SCIMConfigUpdate, _=Depends(verify_api_key), + tenant_id: str, + config_id: str, + update: SCIMConfigUpdate, + _=Depends(verify_api_key), ): """更新 SCIM 配置""" if not ENTERPRISE_MANAGER_AVAILABLE: @@ -10631,7 +10870,8 @@ async def update_scim_config_endpoint( raise HTTPException(status_code=404, detail="SCIM config not found") updated = manager.update_scim_config( - config_id=config_id, **{k: v for k, v in update.dict().items() if v is not None}, + config_id=config_id, + **{k: v for k, v in update.dict().items() if v is not None}, ) return { @@ -10835,7 +11075,9 @@ async def download_audit_export_endpoint( @app.post("/api/v1/tenants/{tenant_id}/retention-policies", tags=["Enterprise"]) async def create_retention_policy_endpoint( - tenant_id: str, policy: RetentionPolicyCreate, _=Depends(verify_api_key), + tenant_id: str, + policy: RetentionPolicyCreate, + _=Depends(verify_api_key), ): """创建数据保留策略""" if not ENTERPRISE_MANAGER_AVAILABLE: @@ -10932,9 +11174,9 @@ async def get_retention_policy_endpoint(tenant_id: str, policy_id: str, _=Depend "archive_location": policy.archive_location, "archive_encryption": policy.archive_encryption, "is_active": policy.is_active, - "last_executed_at": policy.last_executed_at.isoformat() - if policy.last_executed_at - else None, + "last_executed_at": ( + policy.last_executed_at.isoformat() if policy.last_executed_at else None + ), "last_execution_result": policy.last_execution_result, "created_at": policy.created_at.isoformat(), } @@ -10942,7 +11184,10 @@ async def get_retention_policy_endpoint(tenant_id: str, policy_id: str, _=Depend @app.put("/api/v1/tenants/{tenant_id}/retention-policies/{policy_id}", tags=["Enterprise"]) async def update_retention_policy_endpoint( - tenant_id: str, policy_id: str, update: RetentionPolicyUpdate, _=Depends(verify_api_key), + tenant_id: str, + policy_id: str, + update: RetentionPolicyUpdate, + _=Depends(verify_api_key), ): """更新数据保留策略""" if not ENTERPRISE_MANAGER_AVAILABLE: @@ -10955,7 +11200,8 @@ async def update_retention_policy_endpoint( raise HTTPException(status_code=404, detail="Policy not found") updated = manager.update_retention_policy( - policy_id=policy_id, **{k: v for k, v in update.dict().items() if v is not None}, + policy_id=policy_id, + **{k: v for k, v in update.dict().items() if v is not None}, ) return {"id": updated.id, "updated_at": updated.updated_at.isoformat()} @@ -10963,7 +11209,9 @@ async def update_retention_policy_endpoint( @app.delete("/api/v1/tenants/{tenant_id}/retention-policies/{policy_id}", tags=["Enterprise"]) async def delete_retention_policy_endpoint( - tenant_id: str, policy_id: str, _=Depends(verify_api_key), + tenant_id: str, + policy_id: str, + _=Depends(verify_api_key), ): """删除数据保留策略""" if not ENTERPRISE_MANAGER_AVAILABLE: @@ -10981,7 +11229,9 @@ async def delete_retention_policy_endpoint( @app.post("/api/v1/tenants/{tenant_id}/retention-policies/{policy_id}/execute", tags=["Enterprise"]) async def execute_retention_policy_endpoint( - tenant_id: str, policy_id: str, _=Depends(verify_api_key), + tenant_id: str, + policy_id: str, + _=Depends(verify_api_key), ): """执行数据保留策略""" if not ENTERPRISE_MANAGER_AVAILABLE: @@ -11406,7 +11656,9 @@ async def get_tenant_data_center(tenant_id: str, _=Depends(verify_api_key)): @app.post("/api/v1/tenants/{tenant_id}/data-center", tags=["Localization"]) async def set_tenant_data_center( - tenant_id: str, request: DataCenterMappingRequest, _=Depends(verify_api_key), + tenant_id: str, + request: DataCenterMappingRequest, + _=Depends(verify_api_key), ): """设置租户数据中心""" if not LOCALIZATION_MANAGER_AVAILABLE: @@ -11414,7 +11666,9 @@ async def set_tenant_data_center( manager = get_localization_manager() mapping = manager.set_tenant_data_center( - tenant_id=tenant_id, region_code=request.region_code, data_residency=request.data_residency, + tenant_id=tenant_id, + region_code=request.region_code, + data_residency=request.data_residency, ) return { @@ -11574,7 +11828,9 @@ async def get_localization_settings(tenant_id: str, _=Depends(verify_api_key)): @app.post("/api/v1/tenants/{tenant_id}/localization", tags=["Localization"]) async def create_localization_settings( - tenant_id: str, request: LocalizationSettingsCreate, _=Depends(verify_api_key), + tenant_id: str, + request: LocalizationSettingsCreate, + _=Depends(verify_api_key), ): """创建租户本地化设置""" if not LOCALIZATION_MANAGER_AVAILABLE: @@ -11608,7 +11864,9 @@ async def create_localization_settings( @app.put("/api/v1/tenants/{tenant_id}/localization", tags=["Localization"]) async def update_localization_settings( - tenant_id: str, request: LocalizationSettingsUpdate, _=Depends(verify_api_key), + tenant_id: str, + request: LocalizationSettingsUpdate, + _=Depends(verify_api_key), ): """更新租户本地化设置""" if not LOCALIZATION_MANAGER_AVAILABLE: @@ -11641,7 +11899,8 @@ async def update_localization_settings( @app.post("/api/v1/format/datetime", tags=["Localization"]) async def format_datetime_endpoint( - request: FormatDateTimeRequest, language: str = Query(default="en", description="语言代码"), + request: FormatDateTimeRequest, + language: str = Query(default="en", description="语言代码"), ): """格式化日期时间""" if not LOCALIZATION_MANAGER_AVAILABLE: @@ -11655,7 +11914,10 @@ async def format_datetime_endpoint( raise HTTPException(status_code=400, detail="Invalid timestamp format") formatted = manager.format_datetime( - dt=dt, language=language, timezone=request.timezone, format_type=request.format_type, + dt=dt, + language=language, + timezone=request.timezone, + format_type=request.format_type, ) return { @@ -11669,7 +11931,8 @@ async def format_datetime_endpoint( @app.post("/api/v1/format/number", tags=["Localization"]) async def format_number_endpoint( - request: FormatNumberRequest, language: str = Query(default="en", description="语言代码"), + request: FormatNumberRequest, + language: str = Query(default="en", description="语言代码"), ): """格式化数字""" if not LOCALIZATION_MANAGER_AVAILABLE: @@ -11677,7 +11940,9 @@ async def format_number_endpoint( manager = get_localization_manager() formatted = manager.format_number( - number=request.number, language=language, decimal_places=request.decimal_places, + number=request.number, + language=language, + decimal_places=request.decimal_places, ) return {"original": request.number, "formatted": formatted, "language": language} @@ -11685,7 +11950,8 @@ async def format_number_endpoint( @app.post("/api/v1/format/currency", tags=["Localization"]) async def format_currency_endpoint( - request: FormatCurrencyRequest, language: str = Query(default="en", description="语言代码"), + request: FormatCurrencyRequest, + language: str = Query(default="en", description="语言代码"), ): """格式化货币""" if not LOCALIZATION_MANAGER_AVAILABLE: @@ -11693,7 +11959,9 @@ async def format_currency_endpoint( manager = get_localization_manager() formatted = manager.format_currency( - amount=request.amount, currency=request.currency, language=language, + amount=request.amount, + currency=request.currency, + language=language, ) return { @@ -11738,7 +12006,8 @@ async def detect_locale( manager = get_localization_manager() preferences = manager.detect_user_preferences( - accept_language=accept_language, ip_country=ip_country, + accept_language=accept_language, + ip_country=ip_country, ) return preferences @@ -11940,7 +12209,10 @@ async def add_training_sample(model_id: str, request: AddTrainingSampleRequest): manager = get_ai_manager() sample = manager.add_training_sample( - model_id=model_id, text=request.text, entities=request.entities, metadata=request.metadata, + model_id=model_id, + text=request.text, + entities=request.entities, + metadata=request.metadata, ) return { @@ -12014,7 +12286,8 @@ async def predict_with_custom_model(request: PredictRequest): @app.post( - "/api/v1/tenants/{tenant_id}/projects/{project_id}/ai/multimodal", tags=["AI Enhancement"], + "/api/v1/tenants/{tenant_id}/projects/{project_id}/ai/multimodal", + tags=["AI Enhancement"], ) async def analyze_multimodal(tenant_id: str, project_id: str, request: MultimodalAnalysisRequest): """多模态分析""" @@ -12048,7 +12321,8 @@ async def analyze_multimodal(tenant_id: str, project_id: str, request: Multimoda @app.get("/api/v1/tenants/{tenant_id}/ai/multimodal", tags=["AI Enhancement"]) async def list_multimodal_analyses( - tenant_id: str, project_id: str | None = Query(default=None, description="项目ID过滤"), + tenant_id: str, + project_id: str | None = Query(default=None, description="项目ID过滤"), ): """获取多模态分析历史""" if not AI_MANAGER_AVAILABLE: @@ -12107,7 +12381,8 @@ async def create_kg_rag(tenant_id: str, project_id: str, request: CreateKGRAGReq @app.get("/api/v1/tenants/{tenant_id}/ai/kg-rag", tags=["AI Enhancement"]) async def list_kg_rags( - tenant_id: str, project_id: str | None = Query(default=None, description="项目ID过滤"), + tenant_id: str, + project_id: str | None = Query(default=None, description="项目ID过滤"), ): """列出知识图谱 RAG 配置""" if not AI_MANAGER_AVAILABLE: @@ -12225,7 +12500,9 @@ async def list_smart_summaries( tags=["AI Enhancement"], ) async def create_prediction_model( - tenant_id: str, project_id: str, request: CreatePredictionModelRequest, + tenant_id: str, + project_id: str, + request: CreatePredictionModelRequest, ): """创建预测模型""" if not AI_MANAGER_AVAILABLE: @@ -12259,7 +12536,8 @@ async def create_prediction_model( @app.get("/api/v1/tenants/{tenant_id}/ai/prediction-models", tags=["AI Enhancement"]) async def list_prediction_models( - tenant_id: str, project_id: str | None = Query(default=None, description="项目ID过滤"), + tenant_id: str, + project_id: str | None = Query(default=None, description="项目ID过滤"), ): """列出预测模型""" if not AI_MANAGER_AVAILABLE: @@ -12318,7 +12596,8 @@ async def get_prediction_model(model_id: str): @app.post("/api/v1/ai/prediction-models/{model_id}/train", tags=["AI Enhancement"]) async def train_prediction_model( - model_id: str, historical_data: list[dict] = Body(..., description="历史训练数据"), + model_id: str, + historical_data: list[dict] = Body(..., description="历史训练数据"), ): """训练预测模型""" if not AI_MANAGER_AVAILABLE: @@ -12364,7 +12643,8 @@ async def predict(request: PredictDataRequest): @app.get("/api/v1/ai/prediction-models/{model_id}/results", tags=["AI Enhancement"]) async def get_prediction_results( - model_id: str, limit: int = Query(default=100, description="返回结果数量限制"), + model_id: str, + limit: int = Query(default=100, description="返回结果数量限制"), ): """获取预测结果历史""" if not AI_MANAGER_AVAILABLE: @@ -12579,7 +12859,9 @@ async def get_analytics_dashboard(tenant_id: str): @app.get("/api/v1/analytics/summary/{tenant_id}", tags=["Growth & Analytics"]) async def get_analytics_summary( - tenant_id: str, start_date: str | None = None, end_date: str | None = None, + tenant_id: str, + start_date: str | None = None, + end_date: str | None = None, ): """获取用户分析汇总""" if not GROWTH_MANAGER_AVAILABLE: @@ -12653,7 +12935,9 @@ async def create_funnel_endpoint(request: CreateFunnelRequest, created_by: str = @app.get("/api/v1/analytics/funnels/{funnel_id}/analyze", tags=["Growth & Analytics"]) async def analyze_funnel_endpoint( - funnel_id: str, period_start: str | None = None, period_end: str | None = None, + funnel_id: str, + period_start: str | None = None, + period_end: str | None = None, ): """分析漏斗转化率""" if not GROWTH_MANAGER_AVAILABLE: @@ -13386,7 +13670,8 @@ def get_developer_ecosystem_manager_instance() -> "DeveloperEcosystemManager | N @app.post("/api/v1/developer/sdks", tags=["Developer Ecosystem"]) async def create_sdk_release_endpoint( - request: SDKReleaseCreate, created_by: str = Header(default="system", description="创建者ID"), + request: SDKReleaseCreate, + created_by: str = Header(default="system", description="创建者ID"), ): """创建 SDK 发布""" if not DEVELOPER_ECOSYSTEM_AVAILABLE: @@ -13822,7 +14107,8 @@ async def add_template_review_endpoint( @app.get("/api/v1/developer/templates/{template_id}/reviews", tags=["Developer Ecosystem"]) async def get_template_reviews_endpoint( - template_id: str, limit: int = Query(default=50, description="返回数量限制"), + template_id: str, + limit: int = Query(default=50, description="返回数量限制"), ): """获取模板评价""" if not DEVELOPER_ECOSYSTEM_AVAILABLE: @@ -14075,7 +14361,8 @@ async def add_plugin_review_endpoint( @app.get("/api/v1/developer/plugins/{plugin_id}/reviews", tags=["Developer Ecosystem"]) async def get_plugin_reviews_endpoint( - plugin_id: str, limit: int = Query(default=50, description="返回数量限制"), + plugin_id: str, + limit: int = Query(default=50, description="返回数量限制"), ): """获取插件评价""" if not DEVELOPER_ECOSYSTEM_AVAILABLE: @@ -14600,7 +14887,8 @@ class AlertChannelCreate(BaseModel): ) config: dict = Field(default_factory=dict, description="渠道特定配置") severity_filter: list[str] = Field( - default_factory=lambda: ["p0", "p1", "p2", "p3"], description="过滤的告警级别", + default_factory=lambda: ["p0", "p1", "p2", "p3"], + description="过滤的告警级别", ) @@ -14660,7 +14948,8 @@ class HealthCheckResponse(BaseModel): class AutoScalingPolicyCreate(BaseModel): name: str = Field(..., description="策略名称") resource_type: str = Field( - ..., description="资源类型: cpu, memory, disk, network, gpu, database, cache, queue", + ..., + description="资源类型: cpu, memory, disk, network, gpu, database, cache, queue", ) min_instances: int = Field(default=1, description="最小实例数") max_instances: int = Field(default=10, description="最大实例数") @@ -14688,10 +14977,15 @@ class BackupJobCreate(BaseModel): @app.post( - "/api/v1/ops/alert-rules", response_model=AlertRuleResponse, tags=["Operations & Monitoring"], + "/api/v1/ops/alert-rules", + response_model=AlertRuleResponse, + tags=["Operations & Monitoring"], ) async def create_alert_rule_endpoint( - tenant_id: str, request: AlertRuleCreate, user_id: str = "system", _=Depends(verify_api_key), + tenant_id: str, + request: AlertRuleCreate, + user_id: str = "system", + _=Depends(verify_api_key), ): """创建告警规则""" if not OPS_MANAGER_AVAILABLE: @@ -14741,7 +15035,9 @@ async def create_alert_rule_endpoint( @app.get("/api/v1/ops/alert-rules", tags=["Operations & Monitoring"]) async def list_alert_rules_endpoint( - tenant_id: str, is_enabled: bool | None = None, _=Depends(verify_api_key), + tenant_id: str, + is_enabled: bool | None = None, + _=Depends(verify_api_key), ): """列出租户的告警规则""" if not OPS_MANAGER_AVAILABLE: @@ -14869,7 +15165,9 @@ async def delete_alert_rule_endpoint(rule_id: str, _=Depends(verify_api_key)): tags=["Operations & Monitoring"], ) async def create_alert_channel_endpoint( - tenant_id: str, request: AlertChannelCreate, _=Depends(verify_api_key), + tenant_id: str, + request: AlertChannelCreate, + _=Depends(verify_api_key), ): """创建告警渠道""" if not OPS_MANAGER_AVAILABLE: @@ -14988,7 +15286,9 @@ async def list_alerts_endpoint( @app.post("/api/v1/ops/alerts/{alert_id}/acknowledge", tags=["Operations & Monitoring"]) async def acknowledge_alert_endpoint( - alert_id: str, user_id: str = "system", _=Depends(verify_api_key), + alert_id: str, + user_id: str = "system", + _=Depends(verify_api_key), ): """确认告警""" if not OPS_MANAGER_AVAILABLE: @@ -15029,7 +15329,7 @@ async def record_resource_metric_endpoint( metric_name: str, metric_value: float, unit: str, - metadata: dict = None, + metadata: dict | None = None, _=Depends(verify_api_key), ): """记录资源指标""" @@ -15063,7 +15363,10 @@ async def record_resource_metric_endpoint( @app.get("/api/v1/ops/resource-metrics", tags=["Operations & Monitoring"]) async def get_resource_metrics_endpoint( - tenant_id: str, metric_name: str, seconds: int = 3600, _=Depends(verify_api_key), + tenant_id: str, + metric_name: str, + seconds: int = 3600, + _=Depends(verify_api_key), ): """获取资源指标数据""" if not OPS_MANAGER_AVAILABLE: @@ -15158,7 +15461,9 @@ async def list_capacity_plans_endpoint(tenant_id: str, _=Depends(verify_api_key) @app.post("/api/v1/ops/auto-scaling-policies", tags=["Operations & Monitoring"]) async def create_auto_scaling_policy_endpoint( - tenant_id: str, request: AutoScalingPolicyCreate, _=Depends(verify_api_key), + tenant_id: str, + request: AutoScalingPolicyCreate, + _=Depends(verify_api_key), ): """创建自动扩缩容策略""" if not OPS_MANAGER_AVAILABLE: @@ -15223,7 +15528,10 @@ async def list_auto_scaling_policies_endpoint(tenant_id: str, _=Depends(verify_a @app.get("/api/v1/ops/scaling-events", tags=["Operations & Monitoring"]) async def list_scaling_events_endpoint( - tenant_id: str, policy_id: str | None = None, limit: int = 100, _=Depends(verify_api_key), + tenant_id: str, + policy_id: str | None = None, + limit: int = 100, + _=Depends(verify_api_key), ): """获取扩缩容事件列表""" if not OPS_MANAGER_AVAILABLE: @@ -15257,7 +15565,9 @@ async def list_scaling_events_endpoint( tags=["Operations & Monitoring"], ) async def create_health_check_endpoint( - tenant_id: str, request: HealthCheckCreate, _=Depends(verify_api_key), + tenant_id: str, + request: HealthCheckCreate, + _=Depends(verify_api_key), ): """创建健康检查""" if not OPS_MANAGER_AVAILABLE: @@ -15339,7 +15649,9 @@ async def execute_health_check_endpoint(check_id: str, _=Depends(verify_api_key) @app.post("/api/v1/ops/backup-jobs", tags=["Operations & Monitoring"]) async def create_backup_job_endpoint( - tenant_id: str, request: BackupJobCreate, _=Depends(verify_api_key), + tenant_id: str, + request: BackupJobCreate, + _=Depends(verify_api_key), ): """创建备份任务""" if not OPS_MANAGER_AVAILABLE: @@ -15417,7 +15729,10 @@ async def execute_backup_endpoint(job_id: str, _=Depends(verify_api_key)): @app.get("/api/v1/ops/backup-records", tags=["Operations & Monitoring"]) async def list_backup_records_endpoint( - tenant_id: str, job_id: str | None = None, limit: int = 100, _=Depends(verify_api_key), + tenant_id: str, + job_id: str | None = None, + limit: int = 100, + _=Depends(verify_api_key), ): """获取备份记录列表""" if not OPS_MANAGER_AVAILABLE: @@ -15446,7 +15761,10 @@ async def list_backup_records_endpoint( @app.post("/api/v1/ops/cost-reports", tags=["Operations & Monitoring"]) async def generate_cost_report_endpoint( - tenant_id: str, year: int, month: int, _=Depends(verify_api_key), + tenant_id: str, + year: int, + month: int, + _=Depends(verify_api_key), ): """生成成本报告""" if not OPS_MANAGER_AVAILABLE: @@ -15494,7 +15812,8 @@ async def get_idle_resources_endpoint(tenant_id: str, _=Depends(verify_api_key)) @app.post("/api/v1/ops/cost-optimization-suggestions", tags=["Operations & Monitoring"]) async def generate_cost_optimization_suggestions_endpoint( - tenant_id: str, _=Depends(verify_api_key), + tenant_id: str, + _=Depends(verify_api_key), ): """生成成本优化建议""" if not OPS_MANAGER_AVAILABLE: @@ -15523,7 +15842,9 @@ async def generate_cost_optimization_suggestions_endpoint( @app.get("/api/v1/ops/cost-optimization-suggestions", tags=["Operations & Monitoring"]) async def list_cost_optimization_suggestions_endpoint( - tenant_id: str, is_applied: bool | None = None, _=Depends(verify_api_key), + tenant_id: str, + is_applied: bool | None = None, + _=Depends(verify_api_key), ): """获取成本优化建议列表""" if not OPS_MANAGER_AVAILABLE: @@ -15554,7 +15875,8 @@ async def list_cost_optimization_suggestions_endpoint( tags=["Operations & Monitoring"], ) async def apply_cost_optimization_suggestion_endpoint( - suggestion_id: str, _=Depends(verify_api_key), + suggestion_id: str, + _=Depends(verify_api_key), ): """应用成本优化建议""" if not OPS_MANAGER_AVAILABLE: diff --git a/backend/multimodal_entity_linker.py b/backend/multimodal_entity_linker.py index fc6feea..022ff54 100644 --- a/backend/multimodal_entity_linker.py +++ b/backend/multimodal_entity_linker.py @@ -30,7 +30,7 @@ class MultimodalEntity: source_id: str mention_context: str confidence: float - modality_features: dict = None # 模态特定特征 + modality_features: dict | None = None # 模态特定特征 def __post_init__(self) -> None: if self.modality_features is None: @@ -137,7 +137,8 @@ class MultimodalEntityLinker: """ # 名称相似度 name_sim = self.calculate_string_similarity( - entity1.get("name", ""), entity2.get("name", ""), + entity1.get("name", ""), + entity2.get("name", ""), ) # 如果名称完全匹配 @@ -158,7 +159,8 @@ class MultimodalEntityLinker: # 定义相似度 def_sim = self.calculate_string_similarity( - entity1.get("definition", ""), entity2.get("definition", ""), + entity1.get("definition", ""), + entity2.get("definition", ""), ) # 综合相似度 @@ -170,7 +172,10 @@ class MultimodalEntityLinker: return combined_sim, "none" def find_matching_entity( - self, query_entity: dict, candidate_entities: list[dict], exclude_ids: set[str] = None, + self, + query_entity: dict, + candidate_entities: list[dict], + exclude_ids: set[str] = None, ) -> AlignmentResult | None: """ 在候选实体中查找匹配的实体 @@ -270,7 +275,10 @@ class MultimodalEntityLinker: return links def fuse_entity_knowledge( - self, entity_id: str, linked_entities: list[dict], multimodal_mentions: list[dict], + self, + entity_id: str, + linked_entities: list[dict], + multimodal_mentions: list[dict], ) -> FusionResult: """ 融合多模态实体知识 @@ -394,7 +402,9 @@ class MultimodalEntityLinker: return conflicts def suggest_entity_merges( - self, entities: list[dict], existing_links: list[EntityLink] = None, + self, + entities: list[dict], + existing_links: list[EntityLink] = None, ) -> list[dict]: """ 建议实体合并 @@ -510,9 +520,9 @@ class MultimodalEntityLinker: "total_multimodal_records": len(multimodal_entities), "unique_entities": len(entity_modalities), "cross_modal_entities": cross_modal_count, - "cross_modal_ratio": cross_modal_count / len(entity_modalities) - if entity_modalities - else 0, + "cross_modal_ratio": ( + cross_modal_count / len(entity_modalities) if entity_modalities else 0 + ), } diff --git a/backend/multimodal_processor.py b/backend/multimodal_processor.py index 4c3bb37..8ddf881 100644 --- a/backend/multimodal_processor.py +++ b/backend/multimodal_processor.py @@ -74,7 +74,7 @@ class VideoInfo: transcript_id: str = "" status: str = "pending" error_message: str = "" - metadata: dict = None + metadata: dict | None = None def __post_init__(self) -> None: if self.metadata is None: @@ -97,7 +97,7 @@ class VideoProcessingResult: class MultimodalProcessor: """多模态处理器 - 处理视频文件""" - def __init__(self, temp_dir: str = None, frame_interval: int = 5) -> None: + def __init__(self, temp_dir: str | None = None, frame_interval: int = 5) -> None: """ 初始化多模态处理器 @@ -130,10 +130,12 @@ class MultimodalProcessor: if FFMPEG_AVAILABLE: probe = ffmpeg.probe(video_path) video_stream = next( - (s for s in probe["streams"] if s["codec_type"] == "video"), None, + (s for s in probe["streams"] if s["codec_type"] == "video"), + None, ) audio_stream = next( - (s for s in probe["streams"] if s["codec_type"] == "audio"), None, + (s for s in probe["streams"] if s["codec_type"] == "audio"), + None, ) if video_stream: @@ -165,9 +167,9 @@ class MultimodalProcessor: return { "duration": float(data["format"].get("duration", 0)), "width": int(data["streams"][0].get("width", 0)) if data["streams"] else 0, - "height": int(data["streams"][0].get("height", 0)) - if data["streams"] - else 0, + "height": ( + int(data["streams"][0].get("height", 0)) if data["streams"] else 0 + ), "fps": 30.0, # 默认值 "has_audio": len(data["streams"]) > 1, "bitrate": int(data["format"].get("bit_rate", 0)), @@ -177,7 +179,7 @@ class MultimodalProcessor: return {"duration": 0, "width": 0, "height": 0, "fps": 0, "has_audio": False, "bitrate": 0} - def extract_audio(self, video_path: str, output_path: str = None) -> str: + def extract_audio(self, video_path: str, output_path: str | None = None) -> str: """ 从视频中提取音频 @@ -223,7 +225,9 @@ class MultimodalProcessor: print(f"Error extracting audio: {e}") raise - def extract_keyframes(self, video_path: str, video_id: str, interval: int = None) -> list[str]: + def extract_keyframes( + self, video_path: str, video_id: str, interval: int | None = None + ) -> list[str]: """ 从视频中提取关键帧 @@ -260,7 +264,8 @@ class MultimodalProcessor: if frame_number % frame_interval_frames == 0: timestamp = frame_number / fps frame_path = os.path.join( - video_frames_dir, f"frame_{frame_number:06d}_{timestamp:.2f}.jpg", + video_frames_dir, + f"frame_{frame_number:06d}_{timestamp:.2f}.jpg", ) cv2.imwrite(frame_path, frame) frame_paths.append(frame_path) @@ -333,7 +338,11 @@ class MultimodalProcessor: return "", 0.0 def process_video( - self, video_data: bytes, filename: str, project_id: str, video_id: str = None, + self, + video_data: bytes, + filename: str, + project_id: str, + video_id: str | None = None, ) -> VideoProcessingResult: """ 处理视频文件:提取音频、关键帧、OCR @@ -426,7 +435,7 @@ class MultimodalProcessor: error_message=str(e), ) - def cleanup(self, video_id: str = None) -> None: + def cleanup(self, video_id: str | None = None) -> None: """ 清理临时文件 @@ -457,7 +466,9 @@ class MultimodalProcessor: _multimodal_processor = None -def get_multimodal_processor(temp_dir: str = None, frame_interval: int = 5) -> MultimodalProcessor: +def get_multimodal_processor( + temp_dir: str | None = None, frame_interval: int = 5 +) -> MultimodalProcessor: """获取多模态处理器单例""" global _multimodal_processor if _multimodal_processor is None: diff --git a/backend/neo4j_manager.py b/backend/neo4j_manager.py index 4a7e556..9ac9c44 100644 --- a/backend/neo4j_manager.py +++ b/backend/neo4j_manager.py @@ -37,7 +37,7 @@ class GraphEntity: type: str definition: str = "" aliases: list[str] = None - properties: dict = None + properties: dict | None = None def __post_init__(self) -> None: if self.aliases is None: @@ -55,7 +55,7 @@ class GraphRelation: target_id: str relation_type: str evidence: str = "" - properties: dict = None + properties: dict | None = None def __post_init__(self) -> None: if self.properties is None: @@ -95,7 +95,7 @@ class CentralityResult: class Neo4jManager: """Neo4j 图数据库管理器""" - def __init__(self, uri: str = None, user: str = None, password: str = None) -> None: + def __init__(self, uri: str | None = None, user: str = None, password: str = None) -> None: self.uri = uri or NEO4J_URI self.user = user or NEO4J_USER self.password = password or NEO4J_PASSWORD @@ -179,7 +179,10 @@ class Neo4jManager: # ==================== 数据同步 ==================== def sync_project( - self, project_id: str, project_name: str, project_description: str = "", + self, + project_id: str, + project_name: str, + project_description: str = "", ) -> None: """同步项目节点到 Neo4j""" if not self._driver: @@ -352,7 +355,10 @@ class Neo4jManager: # ==================== 复杂图查询 ==================== def find_shortest_path( - self, source_id: str, target_id: str, max_depth: int = 10, + self, + source_id: str, + target_id: str, + max_depth: int = 10, ) -> PathResult | None: """ 查找两个实体之间的最短路径 @@ -404,11 +410,17 @@ class Neo4jManager: ] return PathResult( - nodes=nodes, relationships=relationships, length=len(path.relationships), + nodes=nodes, + relationships=relationships, + length=len(path.relationships), ) def find_all_paths( - self, source_id: str, target_id: str, max_depth: int = 5, limit: int = 10, + self, + source_id: str, + target_id: str, + max_depth: int = 5, + limit: int = 10, ) -> list[PathResult]: """ 查找两个实体之间的所有路径 @@ -460,14 +472,19 @@ class Neo4jManager: paths.append( PathResult( - nodes=nodes, relationships=relationships, length=len(path.relationships), + nodes=nodes, + relationships=relationships, + length=len(path.relationships), ), ) return paths def find_neighbors( - self, entity_id: str, relation_type: str = None, limit: int = 50, + self, + entity_id: str, + relation_type: str | None = None, + limit: int = 50, ) -> list[dict]: """ 查找实体的邻居节点 @@ -752,7 +769,10 @@ class Neo4jManager: results.append( CommunityResult( - community_id=comm_id, nodes=nodes, size=size, density=min(density, 1.0), + community_id=comm_id, + nodes=nodes, + size=size, + density=min(density, 1.0), ), ) @@ -761,7 +781,9 @@ class Neo4jManager: return results def find_central_entities( - self, project_id: str, metric: str = "degree", + self, + project_id: str, + metric: str = "degree", ) -> list[CentralityResult]: """ 查找中心实体 @@ -896,9 +918,11 @@ class Neo4jManager: "type_distribution": types, "average_degree": round(avg_degree, 2) if avg_degree else 0, "relation_type_distribution": relation_types, - "density": round(relation_count / (entity_count * (entity_count - 1)), 4) - if entity_count > 1 - else 0, + "density": ( + round(relation_count / (entity_count * (entity_count - 1)), 4) + if entity_count > 1 + else 0 + ), } def get_subgraph(self, entity_ids: list[str], depth: int = 1) -> dict: @@ -993,7 +1017,10 @@ def close_neo4j_manager() -> None: def sync_project_to_neo4j( - project_id: str, project_name: str, entities: list[dict], relations: list[dict], + project_id: str, + project_name: str, + entities: list[dict], + relations: list[dict], ) -> None: """ 同步整个项目到 Neo4j diff --git a/backend/ops_manager.py b/backend/ops_manager.py index a436db1..4814a25 100644 --- a/backend/ops_manager.py +++ b/backend/ops_manager.py @@ -680,7 +680,8 @@ class OpsManager: """获取告警渠道""" with self._get_db() as conn: row = conn.execute( - "SELECT * FROM alert_channels WHERE id = ?", (channel_id,), + "SELECT * FROM alert_channels WHERE id = ?", + (channel_id,), ).fetchone() if row: @@ -819,7 +820,9 @@ class OpsManager: for rule in rules: # 获取相关指标 metrics = self.get_recent_metrics( - tenant_id, rule.metric, seconds=rule.duration + rule.evaluation_interval, + tenant_id, + rule.metric, + seconds=rule.duration + rule.evaluation_interval, ) # 评估规则 @@ -1129,7 +1132,9 @@ class OpsManager: async with httpx.AsyncClient() as client: response = await client.post( - "https://events.pagerduty.com/v2/enqueue", json=message, timeout=30.0, + "https://events.pagerduty.com/v2/enqueue", + json=message, + timeout=30.0, ) success = response.status_code == 202 self._update_channel_stats(channel.id, success) @@ -1299,12 +1304,16 @@ class OpsManager: conn.commit() def _update_alert_notification_status( - self, alert_id: str, channel_id: str, success: bool, + self, + alert_id: str, + channel_id: str, + success: bool, ) -> None: """更新告警通知状态""" with self._get_db() as conn: row = conn.execute( - "SELECT notification_sent FROM alerts WHERE id = ?", (alert_id,), + "SELECT notification_sent FROM alerts WHERE id = ?", + (alert_id,), ).fetchone() if row: @@ -1394,7 +1403,8 @@ class OpsManager: """检查告警是否被抑制""" with self._get_db() as conn: rows = conn.execute( - "SELECT * FROM alert_suppression_rules WHERE tenant_id = ?", (rule.tenant_id,), + "SELECT * FROM alert_suppression_rules WHERE tenant_id = ?", + (rule.tenant_id,), ).fetchall() for row in rows: @@ -1436,7 +1446,7 @@ class OpsManager: metric_name: str, metric_value: float, unit: str, - metadata: dict = None, + metadata: dict | None = None, ) -> ResourceMetric: """记录资源指标""" metric_id = f"rm_{uuid.uuid4().hex[:16]}" @@ -1479,7 +1489,10 @@ class OpsManager: return metric def get_recent_metrics( - self, tenant_id: str, metric_name: str, seconds: int = 3600, + self, + tenant_id: str, + metric_name: str, + seconds: int = 3600, ) -> list[ResourceMetric]: """获取最近的指标数据""" cutoff_time = (datetime.now() - timedelta(seconds=seconds)).isoformat() @@ -1531,7 +1544,9 @@ class OpsManager: # 基于历史数据预测 metrics = self.get_recent_metrics( - tenant_id, f"{resource_type.value}_usage", seconds=30 * 24 * 3600, + tenant_id, + f"{resource_type.value}_usage", + seconds=30 * 24 * 3600, ) if metrics: @@ -1704,7 +1719,8 @@ class OpsManager: """获取自动扩缩容策略""" with self._get_db() as conn: row = conn.execute( - "SELECT * FROM auto_scaling_policies WHERE id = ?", (policy_id,), + "SELECT * FROM auto_scaling_policies WHERE id = ?", + (policy_id,), ).fetchone() if row: @@ -1721,7 +1737,10 @@ class OpsManager: return [self._row_to_auto_scaling_policy(row) for row in rows] def evaluate_scaling_policy( - self, policy_id: str, current_instances: int, current_utilization: float, + self, + policy_id: str, + current_instances: int, + current_utilization: float, ) -> ScalingEvent | None: """评估扩缩容策略""" policy = self.get_auto_scaling_policy(policy_id) @@ -1826,7 +1845,10 @@ class OpsManager: return None def update_scaling_event_status( - self, event_id: str, status: str, error_message: str = None, + self, + event_id: str, + status: str, + error_message: str | None = None, ) -> ScalingEvent | None: """更新扩缩容事件状态""" now = datetime.now().isoformat() @@ -1864,7 +1886,10 @@ class OpsManager: return None def list_scaling_events( - self, tenant_id: str, policy_id: str = None, limit: int = 100, + self, + tenant_id: str, + policy_id: str | None = None, + limit: int = 100, ) -> list[ScalingEvent]: """列出租户的扩缩容事件""" query = "SELECT * FROM scaling_events WHERE tenant_id = ?" @@ -2056,7 +2081,8 @@ class OpsManager: start_time = time.time() try: reader, writer = await asyncio.wait_for( - asyncio.open_connection(host, port), timeout=check.timeout, + asyncio.open_connection(host, port), + timeout=check.timeout, ) response_time = (time.time() - start_time) * 1000 writer.close() @@ -2101,7 +2127,7 @@ class OpsManager: failover_trigger: str, auto_failover: bool = False, failover_timeout: int = 300, - health_check_id: str = None, + health_check_id: str | None = None, ) -> FailoverConfig: """创建故障转移配置""" config_id = f"fc_{uuid.uuid4().hex[:16]}" @@ -2153,7 +2179,8 @@ class OpsManager: """获取故障转移配置""" with self._get_db() as conn: row = conn.execute( - "SELECT * FROM failover_configs WHERE id = ?", (config_id,), + "SELECT * FROM failover_configs WHERE id = ?", + (config_id,), ).fetchone() if row: @@ -2259,7 +2286,8 @@ class OpsManager: """获取故障转移事件""" with self._get_db() as conn: row = conn.execute( - "SELECT * FROM failover_events WHERE id = ?", (event_id,), + "SELECT * FROM failover_events WHERE id = ?", + (event_id,), ).fetchone() if row: @@ -2290,7 +2318,7 @@ class OpsManager: retention_days: int = 30, encryption_enabled: bool = True, compression_enabled: bool = True, - storage_location: str = None, + storage_location: str | None = None, ) -> BackupJob: """创建备份任务""" job_id = f"bj_{uuid.uuid4().hex[:16]}" @@ -2410,7 +2438,9 @@ class OpsManager: return record - def _complete_backup(self, record_id: str, size_bytes: int, checksum: str = None) -> None: + def _complete_backup( + self, record_id: str, size_bytes: int, checksum: str | None = None + ) -> None: """完成备份""" now = datetime.now().isoformat() checksum = checksum or hashlib.sha256(str(time.time()).encode()).hexdigest()[:16] @@ -2430,7 +2460,8 @@ class OpsManager: """获取备份记录""" with self._get_db() as conn: row = conn.execute( - "SELECT * FROM backup_records WHERE id = ?", (record_id,), + "SELECT * FROM backup_records WHERE id = ?", + (record_id,), ).fetchone() if row: @@ -2438,7 +2469,10 @@ class OpsManager: return None def list_backup_records( - self, tenant_id: str, job_id: str = None, limit: int = 100, + self, + tenant_id: str, + job_id: str | None = None, + limit: int = 100, ) -> list[BackupRecord]: """列出租户的备份记录""" query = "SELECT * FROM backup_records WHERE tenant_id = ?" @@ -2624,7 +2658,9 @@ class OpsManager: return util def get_resource_utilizations( - self, tenant_id: str, report_period: str, + self, + tenant_id: str, + report_period: str, ) -> list[ResourceUtilization]: """获取资源利用率列表""" with self._get_db() as conn: @@ -2709,7 +2745,8 @@ class OpsManager: return [self._row_to_idle_resource(row) for row in rows] def generate_cost_optimization_suggestions( - self, tenant_id: str, + self, + tenant_id: str, ) -> list[CostOptimizationSuggestion]: """生成成本优化建议""" suggestions = [] @@ -2777,7 +2814,9 @@ class OpsManager: return suggestions def get_cost_optimization_suggestions( - self, tenant_id: str, is_applied: bool = None, + self, + tenant_id: str, + is_applied: bool | None = None, ) -> list[CostOptimizationSuggestion]: """获取成本优化建议""" query = "SELECT * FROM cost_optimization_suggestions WHERE tenant_id = ?" @@ -2794,7 +2833,8 @@ class OpsManager: return [self._row_to_cost_optimization_suggestion(row) for row in rows] def apply_cost_optimization_suggestion( - self, suggestion_id: str, + self, + suggestion_id: str, ) -> CostOptimizationSuggestion | None: """应用成本优化建议""" now = datetime.now().isoformat() @@ -2813,12 +2853,14 @@ class OpsManager: return self.get_cost_optimization_suggestion(suggestion_id) def get_cost_optimization_suggestion( - self, suggestion_id: str, + self, + suggestion_id: str, ) -> CostOptimizationSuggestion | None: """获取成本优化建议详情""" with self._get_db() as conn: row = conn.execute( - "SELECT * FROM cost_optimization_suggestions WHERE id = ?", (suggestion_id,), + "SELECT * FROM cost_optimization_suggestions WHERE id = ?", + (suggestion_id,), ).fetchone() if row: diff --git a/backend/performance_manager.py b/backend/performance_manager.py index 31d896e..be485f9 100644 --- a/backend/performance_manager.py +++ b/backend/performance_manager.py @@ -444,7 +444,8 @@ class CacheManager: "memory_size_bytes": self.current_memory_size, "max_memory_size_bytes": self.max_memory_size, "memory_usage_percent": round( - self.current_memory_size / self.max_memory_size * 100, 2, + self.current_memory_size / self.max_memory_size * 100, + 2, ), "cache_entries": len(self.memory_cache), }, @@ -548,11 +549,13 @@ class CacheManager: # 预热项目知识库摘要 entity_count = conn.execute( - "SELECT COUNT(*) FROM entities WHERE project_id = ?", (project_id,), + "SELECT COUNT(*) FROM entities WHERE project_id = ?", + (project_id,), ).fetchone()[0] relation_count = conn.execute( - "SELECT COUNT(*) FROM entity_relations WHERE project_id = ?", (project_id,), + "SELECT COUNT(*) FROM entity_relations WHERE project_id = ?", + (project_id,), ).fetchone()[0] summary = { @@ -757,11 +760,13 @@ class DatabaseSharding: source_conn.row_factory = sqlite3.Row entities = source_conn.execute( - "SELECT * FROM entities WHERE project_id = ?", (project_id,), + "SELECT * FROM entities WHERE project_id = ?", + (project_id,), ).fetchall() relations = source_conn.execute( - "SELECT * FROM entity_relations WHERE project_id = ?", (project_id,), + "SELECT * FROM entity_relations WHERE project_id = ?", + (project_id,), ).fetchall() source_conn.close() @@ -1061,7 +1066,9 @@ class TaskQueue: task.status = "retrying" # 延迟重试 threading.Timer( - 10 * task.retry_count, self._execute_task, args=(task_id,), + 10 * task.retry_count, + self._execute_task, + args=(task_id,), ).start() else: task.status = "failed" @@ -1163,7 +1170,10 @@ class TaskQueue: return self.tasks.get(task_id) def list_tasks( - self, status: str | None = None, task_type: str | None = None, limit: int = 100, + self, + status: str | None = None, + task_type: str | None = None, + limit: int = 100, ) -> list[TaskInfo]: """列出任务""" conn = sqlite3.connect(self.db_path) @@ -1635,7 +1645,7 @@ def cached( cache_key = key_func(*args, **kwargs) else: # 默认使用函数名和参数哈希 - key_data = f"{func.__name__}:{str(args)}:{str(kwargs)}" + key_data = f"{func.__name__}:{args!s}:{kwargs!s}" cache_key = f"{key_prefix}:{hashlib.md5(key_data.encode()).hexdigest()[:16]}" # 尝试从缓存获取 @@ -1754,12 +1764,16 @@ _performance_manager = None def get_performance_manager( - db_path: str = "insightflow.db", redis_url: str | None = None, enable_sharding: bool = False, + db_path: str = "insightflow.db", + redis_url: str | None = None, + enable_sharding: bool = False, ) -> PerformanceManager: """获取性能管理器单例""" global _performance_manager if _performance_manager is None: _performance_manager = PerformanceManager( - db_path=db_path, redis_url=redis_url, enable_sharding=enable_sharding, + db_path=db_path, + redis_url=redis_url, + enable_sharding=enable_sharding, ) return _performance_manager diff --git a/backend/plugin_manager.py b/backend/plugin_manager.py index 389d734..9e0114f 100644 --- a/backend/plugin_manager.py +++ b/backend/plugin_manager.py @@ -220,7 +220,10 @@ class PluginManager: return None def list_plugins( - self, project_id: str = None, plugin_type: str = None, status: str = None, + self, + project_id: str | None = None, + plugin_type: str = None, + status: str = None, ) -> list[Plugin]: """列出插件""" conn = self.db.get_conn() @@ -241,7 +244,8 @@ class PluginManager: where_clause = " AND ".join(conditions) if conditions else "1 = 1" rows = conn.execute( - f"SELECT * FROM plugins WHERE {where_clause} ORDER BY created_at DESC", params, + f"SELECT * FROM plugins WHERE {where_clause} ORDER BY created_at DESC", + params, ).fetchall() conn.close() @@ -310,7 +314,11 @@ class PluginManager: # ==================== Plugin Config ==================== def set_plugin_config( - self, plugin_id: str, key: str, value: str, is_encrypted: bool = False, + self, + plugin_id: str, + key: str, + value: str, + is_encrypted: bool = False, ) -> PluginConfig: """设置插件配置""" conn = self.db.get_conn() @@ -367,7 +375,8 @@ class PluginManager: """获取插件所有配置""" conn = self.db.get_conn() rows = conn.execute( - "SELECT config_key, config_value FROM plugin_configs WHERE plugin_id = ?", (plugin_id,), + "SELECT config_key, config_value FROM plugin_configs WHERE plugin_id = ?", + (plugin_id,), ).fetchall() conn.close() @@ -377,7 +386,8 @@ class PluginManager: """删除插件配置""" conn = self.db.get_conn() cursor = conn.execute( - "DELETE FROM plugin_configs WHERE plugin_id = ? AND config_key = ?", (plugin_id, key), + "DELETE FROM plugin_configs WHERE plugin_id = ? AND config_key = ?", + (plugin_id, key), ) conn.commit() conn.close() @@ -408,10 +418,10 @@ class ChromeExtensionHandler: def create_token( self, name: str, - user_id: str = None, - project_id: str = None, + user_id: str | None = None, + project_id: str | None = None, permissions: list[str] = None, - expires_days: int = None, + expires_days: int | None = None, ) -> ChromeExtensionToken: """创建 Chrome 扩展令牌""" token_id = str(uuid.uuid4())[:UUID_LENGTH] @@ -512,7 +522,8 @@ class ChromeExtensionHandler: """撤销令牌""" conn = self.pm.db.get_conn() cursor = conn.execute( - "UPDATE chrome_extension_tokens SET is_revoked = 1 WHERE id = ?", (token_id,), + "UPDATE chrome_extension_tokens SET is_revoked = 1 WHERE id = ?", + (token_id,), ) conn.commit() conn.close() @@ -520,7 +531,9 @@ class ChromeExtensionHandler: return cursor.rowcount > 0 def list_tokens( - self, user_id: str = None, project_id: str = None, + self, + user_id: str | None = None, + project_id: str = None, ) -> list[ChromeExtensionToken]: """列出令牌""" conn = self.pm.db.get_conn() @@ -569,7 +582,7 @@ class ChromeExtensionHandler: url: str, title: str, content: str, - html_content: str = None, + html_content: str | None = None, ) -> dict: """导入网页内容""" if not token.project_id: @@ -616,7 +629,7 @@ class BotHandler: self, session_id: str, session_name: str, - project_id: str = None, + project_id: str | None = None, webhook_url: str = "", secret: str = "", ) -> BotSession: @@ -674,7 +687,7 @@ class BotHandler: return self._row_to_session(row) return None - def list_sessions(self, project_id: str = None) -> list[BotSession]: + def list_sessions(self, project_id: str | None = None) -> list[BotSession]: """列出会话""" conn = self.pm.db.get_conn() @@ -849,7 +862,7 @@ class BotHandler: } except Exception as e: - return {"success": False, "error": f"Failed to process audio: {str(e)}"} + return {"success": False, "error": f"Failed to process audio: {e!s}"} async def _handle_file_message(self, session: BotSession, message: dict) -> dict: """处理文件消息""" @@ -897,12 +910,17 @@ class BotHandler: async with httpx.AsyncClient() as client: response = await client.post( - session.webhook_url, json=payload, headers={"Content-Type": "application/json"}, + session.webhook_url, + json=payload, + headers={"Content-Type": "application/json"}, ) return response.status_code == 200 async def _send_dingtalk_message( - self, session: BotSession, message: str, msg_type: str, + self, + session: BotSession, + message: str, + msg_type: str, ) -> bool: """发送钉钉消息""" timestamp = str(round(time.time() * 1000)) @@ -928,7 +946,9 @@ class BotHandler: async with httpx.AsyncClient() as client: response = await client.post( - url, json=payload, headers={"Content-Type": "application/json"}, + url, + json=payload, + headers={"Content-Type": "application/json"}, ) return response.status_code == 200 @@ -944,9 +964,9 @@ class WebhookIntegration: self, name: str, endpoint_url: str, - project_id: str = None, + project_id: str | None = None, auth_type: str = "none", - auth_config: dict = None, + auth_config: dict | None = None, trigger_events: list[str] = None, ) -> WebhookEndpoint: """创建 Webhook 端点""" @@ -1004,7 +1024,7 @@ class WebhookIntegration: return self._row_to_endpoint(row) return None - def list_endpoints(self, project_id: str = None) -> list[WebhookEndpoint]: + def list_endpoints(self, project_id: str | None = None) -> list[WebhookEndpoint]: """列出端点""" conn = self.pm.db.get_conn() @@ -1115,7 +1135,10 @@ class WebhookIntegration: async with httpx.AsyncClient() as client: response = await client.post( - endpoint.endpoint_url, json=payload, headers=headers, timeout=30.0, + endpoint.endpoint_url, + json=payload, + headers=headers, + timeout=30.0, ) success = response.status_code in [200, 201, 202] @@ -1229,7 +1252,7 @@ class WebDAVSyncManager: return self._row_to_sync(row) return None - def list_syncs(self, project_id: str = None) -> list[WebDAVSync]: + def list_syncs(self, project_id: str | None = None) -> list[WebDAVSync]: """列出同步配置""" conn = self.pm.db.get_conn() diff --git a/backend/rate_limiter.py b/backend/rate_limiter.py index d579a3a..75d2bb6 100644 --- a/backend/rate_limiter.py +++ b/backend/rate_limiter.py @@ -120,7 +120,10 @@ class RateLimiter: await counter.add_request() return RateLimitInfo( - allowed=True, remaining=remaining - 1, reset_time=reset_time, retry_after=0, + allowed=True, + remaining=remaining - 1, + reset_time=reset_time, + retry_after=0, ) async def get_limit_info(self, key: str) -> RateLimitInfo: @@ -145,9 +148,9 @@ class RateLimiter: allowed=current_count < config.requests_per_minute, remaining=remaining, reset_time=reset_time, - retry_after=max(0, config.window_size) - if current_count >= config.requests_per_minute - else 0, + retry_after=( + max(0, config.window_size) if current_count >= config.requests_per_minute else 0 + ), ) def reset(self, key: str | None = None) -> None: diff --git a/backend/search_manager.py b/backend/search_manager.py index 636413d..6b6ba5e 100644 --- a/backend/search_manager.py +++ b/backend/search_manager.py @@ -385,7 +385,7 @@ class FullTextSearch: # 排序和分页 scored_results.sort(key=lambda x: x.score, reverse=True) - return scored_results[offset: offset + limit] + return scored_results[offset : offset + limit] def _parse_boolean_query(self, query: str) -> dict: """ @@ -545,19 +545,24 @@ class FullTextSearch: return results def _get_content_by_id( - self, conn: sqlite3.Connection, content_id: str, content_type: str, + self, + conn: sqlite3.Connection, + content_id: str, + content_type: str, ) -> str | None: """根据ID获取内容""" try: if content_type == "transcript": row = conn.execute( - "SELECT full_text FROM transcripts WHERE id = ?", (content_id,), + "SELECT full_text FROM transcripts WHERE id = ?", + (content_id,), ).fetchone() return row["full_text"] if row else None elif content_type == "entity": row = conn.execute( - "SELECT name, definition FROM entities WHERE id = ?", (content_id,), + "SELECT name, definition FROM entities WHERE id = ?", + (content_id,), ).fetchone() if row: return f"{row['name']} {row['definition'] or ''}" @@ -583,21 +588,27 @@ class FullTextSearch: return None def _get_project_id( - self, conn: sqlite3.Connection, content_id: str, content_type: str, + self, + conn: sqlite3.Connection, + content_id: str, + content_type: str, ) -> str | None: """获取内容所属的项目ID""" try: if content_type == "transcript": row = conn.execute( - "SELECT project_id FROM transcripts WHERE id = ?", (content_id,), + "SELECT project_id FROM transcripts WHERE id = ?", + (content_id,), ).fetchone() elif content_type == "entity": row = conn.execute( - "SELECT project_id FROM entities WHERE id = ?", (content_id,), + "SELECT project_id FROM entities WHERE id = ?", + (content_id,), ).fetchone() elif content_type == "relation": row = conn.execute( - "SELECT project_id FROM entity_relations WHERE id = ?", (content_id,), + "SELECT project_id FROM entity_relations WHERE id = ?", + (content_id,), ).fetchone() else: return None @@ -880,7 +891,11 @@ class SemanticSearch: return None def index_embedding( - self, content_id: str, content_type: str, project_id: str, text: str, + self, + content_id: str, + content_type: str, + project_id: str, + text: str, ) -> bool: """ 为内容生成并保存 embedding @@ -1029,13 +1044,15 @@ class SemanticSearch: try: if content_type == "transcript": row = conn.execute( - "SELECT full_text FROM transcripts WHERE id = ?", (content_id,), + "SELECT full_text FROM transcripts WHERE id = ?", + (content_id,), ).fetchone() result = row["full_text"] if row else None elif content_type == "entity": row = conn.execute( - "SELECT name, definition FROM entities WHERE id = ?", (content_id,), + "SELECT name, definition FROM entities WHERE id = ?", + (content_id,), ).fetchone() result = f"{row['name']}: {row['definition']}" if row else None @@ -1067,7 +1084,10 @@ class SemanticSearch: return None def find_similar_content( - self, content_id: str, content_type: str, top_k: int = 5, + self, + content_id: str, + content_type: str, + top_k: int = 5, ) -> list[SemanticSearchResult]: """ 查找与指定内容相似的内容 @@ -1175,7 +1195,10 @@ class EntityPathDiscovery: return conn def find_shortest_path( - self, source_entity_id: str, target_entity_id: str, max_depth: int = 5, + self, + source_entity_id: str, + target_entity_id: str, + max_depth: int = 5, ) -> EntityPath | None: """ 查找两个实体之间的最短路径(BFS算法) @@ -1192,7 +1215,8 @@ class EntityPathDiscovery: # 获取项目ID row = conn.execute( - "SELECT project_id FROM entities WHERE id = ?", (source_entity_id,), + "SELECT project_id FROM entities WHERE id = ?", + (source_entity_id,), ).fetchone() if not row: @@ -1250,7 +1274,11 @@ class EntityPathDiscovery: return None def find_all_paths( - self, source_entity_id: str, target_entity_id: str, max_depth: int = 4, max_paths: int = 10, + self, + source_entity_id: str, + target_entity_id: str, + max_depth: int = 4, + max_paths: int = 10, ) -> list[EntityPath]: """ 查找两个实体之间的所有路径(限制数量和深度) @@ -1268,7 +1296,8 @@ class EntityPathDiscovery: # 获取项目ID row = conn.execute( - "SELECT project_id FROM entities WHERE id = ?", (source_entity_id,), + "SELECT project_id FROM entities WHERE id = ?", + (source_entity_id,), ).fetchone() if not row: @@ -1280,7 +1309,11 @@ class EntityPathDiscovery: paths = [] def dfs( - current_id: str, target_id: str, path: list[str], visited: set[str], depth: int, + current_id: str, + target_id: str, + path: list[str], + visited: set[str], + depth: int, ) -> None: if depth > max_depth: return @@ -1328,7 +1361,8 @@ class EntityPathDiscovery: nodes = [] for entity_id in entity_ids: row = conn.execute( - "SELECT id, name, type FROM entities WHERE id = ?", (entity_id,), + "SELECT id, name, type FROM entities WHERE id = ?", + (entity_id,), ).fetchone() if row: nodes.append({"id": row["id"], "name": row["name"], "type": row["type"]}) @@ -1398,7 +1432,8 @@ class EntityPathDiscovery: # 获取项目ID row = conn.execute( - "SELECT project_id, name FROM entities WHERE id = ?", (entity_id,), + "SELECT project_id, name FROM entities WHERE id = ?", + (entity_id,), ).fetchone() if not row: @@ -1445,7 +1480,8 @@ class EntityPathDiscovery: # 获取邻居信息 neighbor_info = conn.execute( - "SELECT name, type FROM entities WHERE id = ?", (neighbor_id,), + "SELECT name, type FROM entities WHERE id = ?", + (neighbor_id,), ).fetchone() if neighbor_info: @@ -1458,7 +1494,10 @@ class EntityPathDiscovery: "relation_type": neighbor["relation_type"], "evidence": neighbor["evidence"], "path": self._get_path_to_entity( - entity_id, neighbor_id, project_id, conn, + entity_id, + neighbor_id, + project_id, + conn, ), }, ) @@ -1470,7 +1509,11 @@ class EntityPathDiscovery: return relations def _get_path_to_entity( - self, source_id: str, target_id: str, project_id: str, conn: sqlite3.Connection, + self, + source_id: str, + target_id: str, + project_id: str, + conn: sqlite3.Connection, ) -> list[str]: """获取从源实体到目标实体的路径(简化版)""" # BFS 找路径 @@ -1565,7 +1608,8 @@ class EntityPathDiscovery: # 获取所有实体 entities = conn.execute( - "SELECT id, name FROM entities WHERE project_id = ?", (project_id,), + "SELECT id, name FROM entities WHERE project_id = ?", + (project_id,), ).fetchall() # 计算每个实体作为桥梁的次数 @@ -1706,7 +1750,8 @@ class KnowledgeGapDetection: # 检查每个实体的属性完整性 entities = conn.execute( - "SELECT id, name FROM entities WHERE project_id = ?", (project_id,), + "SELECT id, name FROM entities WHERE project_id = ?", + (project_id,), ).fetchall() for entity in entities: @@ -1714,7 +1759,8 @@ class KnowledgeGapDetection: # 获取实体已有的属性 existing_attrs = conn.execute( - "SELECT template_id FROM entity_attributes WHERE entity_id = ?", (entity_id,), + "SELECT template_id FROM entity_attributes WHERE entity_id = ?", + (entity_id,), ).fetchall() existing_template_ids = {a["template_id"] for a in existing_attrs} @@ -1726,7 +1772,8 @@ class KnowledgeGapDetection: missing_names = [] for template_id in missing_templates: template = conn.execute( - "SELECT name FROM attribute_templates WHERE id = ?", (template_id,), + "SELECT name FROM attribute_templates WHERE id = ?", + (template_id,), ).fetchone() if template: missing_names.append(template["name"]) @@ -1759,7 +1806,8 @@ class KnowledgeGapDetection: # 获取所有实体及其关系数量 entities = conn.execute( - "SELECT id, name, type FROM entities WHERE project_id = ?", (project_id,), + "SELECT id, name, type FROM entities WHERE project_id = ?", + (project_id,), ).fetchall() for entity in entities: @@ -1900,7 +1948,8 @@ class KnowledgeGapDetection: # 分析转录文本中频繁提及但未提取为实体的词 transcripts = conn.execute( - "SELECT full_text FROM transcripts WHERE project_id = ?", (project_id,), + "SELECT full_text FROM transcripts WHERE project_id = ?", + (project_id,), ).fetchall() # 合并所有文本 @@ -1908,7 +1957,8 @@ class KnowledgeGapDetection: # 获取现有实体名称 existing_entities = conn.execute( - "SELECT name FROM entities WHERE project_id = ?", (project_id,), + "SELECT name FROM entities WHERE project_id = ?", + (project_id,), ).fetchall() existing_names = {e["name"].lower() for e in existing_entities} @@ -2146,7 +2196,10 @@ class SearchManager: for t in transcripts: if t["full_text"] and self.semantic_search.index_embedding( - t["id"], "transcript", t["project_id"], t["full_text"], + t["id"], + "transcript", + t["project_id"], + t["full_text"], ): semantic_stats["indexed"] += 1 else: @@ -2179,12 +2232,14 @@ class SearchManager: # 全文索引统计 fulltext_count = conn.execute( - f"SELECT COUNT(*) as count FROM search_indexes {where_clause}", params, + f"SELECT COUNT(*) as count FROM search_indexes {where_clause}", + params, ).fetchone()["count"] # 语义索引统计 semantic_count = conn.execute( - f"SELECT COUNT(*) as count FROM embeddings {where_clause}", params, + f"SELECT COUNT(*) as count FROM embeddings {where_clause}", + params, ).fetchone()["count"] # 按类型统计 @@ -2225,7 +2280,9 @@ def get_search_manager(db_path: str = "insightflow.db") -> SearchManager: def fulltext_search( - query: str, project_id: str | None = None, limit: int = 20, + query: str, + project_id: str | None = None, + limit: int = 20, ) -> list[SearchResult]: """全文搜索便捷函数""" manager = get_search_manager() @@ -2233,7 +2290,9 @@ def fulltext_search( def semantic_search( - query: str, project_id: str | None = None, top_k: int = 10, + query: str, + project_id: str | None = None, + top_k: int = 10, ) -> list[SemanticSearchResult]: """语义搜索便捷函数""" manager = get_search_manager() diff --git a/backend/security_manager.py b/backend/security_manager.py index 6924e02..9e6a7e6 100644 --- a/backend/security_manager.py +++ b/backend/security_manager.py @@ -464,7 +464,9 @@ class SecurityManager: return logs def get_audit_stats( - self, start_time: str | None = None, end_time: str | None = None, + self, + start_time: str | None = None, + end_time: str | None = None, ) -> dict[str, Any]: """获取审计统计""" conn = sqlite3.connect(self.db_path) @@ -882,7 +884,10 @@ class SecurityManager: return success def apply_masking( - self, text: str, project_id: str, rule_types: list[MaskingRuleType] | None = None, + self, + text: str, + project_id: str, + rule_types: list[MaskingRuleType] | None = None, ) -> str: """应用脱敏规则到文本""" rules = self.get_masking_rules(project_id) @@ -906,7 +911,9 @@ class SecurityManager: return masked_text def apply_masking_to_entity( - self, entity_data: dict[str, Any], project_id: str, + self, + entity_data: dict[str, Any], + project_id: str, ) -> dict[str, Any]: """对实体数据应用脱敏""" masked_data = entity_data.copy() @@ -982,7 +989,9 @@ class SecurityManager: return policy def get_access_policies( - self, project_id: str, active_only: bool = True, + self, + project_id: str, + active_only: bool = True, ) -> list[DataAccessPolicy]: """获取数据访问策略""" conn = sqlite3.connect(self.db_path) @@ -1021,14 +1030,18 @@ class SecurityManager: return policies def check_access_permission( - self, policy_id: str, user_id: str, user_ip: str | None = None, + self, + policy_id: str, + user_id: str, + user_ip: str | None = None, ) -> tuple[bool, str | None]: """检查访问权限""" conn = sqlite3.connect(self.db_path) cursor = conn.cursor() cursor.execute( - "SELECT * FROM data_access_policies WHERE id = ? AND is_active = 1", (policy_id,), + "SELECT * FROM data_access_policies WHERE id = ? AND is_active = 1", + (policy_id,), ) row = cursor.fetchone() conn.close() @@ -1163,7 +1176,10 @@ class SecurityManager: return request def approve_access_request( - self, request_id: str, approved_by: str, expires_hours: int = 24, + self, + request_id: str, + approved_by: str, + expires_hours: int = 24, ) -> AccessRequest | None: """批准访问请求""" conn = sqlite3.connect(self.db_path) diff --git a/backend/subscription_manager.py b/backend/subscription_manager.py index ea8bdd4..c66e4df 100644 --- a/backend/subscription_manager.py +++ b/backend/subscription_manager.py @@ -588,7 +588,8 @@ class SubscriptionManager: try: cursor = conn.cursor() cursor.execute( - "SELECT * FROM subscription_plans WHERE tier = ? AND is_active = 1", (tier,), + "SELECT * FROM subscription_plans WHERE tier = ? AND is_active = 1", + (tier,), ) row = cursor.fetchone() @@ -963,7 +964,9 @@ class SubscriptionManager: conn.close() def cancel_subscription( - self, subscription_id: str, at_period_end: bool = True, + self, + subscription_id: str, + at_period_end: bool = True, ) -> Subscription | None: """取消订阅""" conn = self._get_connection() @@ -1017,7 +1020,10 @@ class SubscriptionManager: conn.close() def change_plan( - self, subscription_id: str, new_plan_id: str, prorate: bool = True, + self, + subscription_id: str, + new_plan_id: str, + prorate: bool = True, ) -> Subscription | None: """更改订阅计划""" conn = self._get_connection() @@ -1125,7 +1131,10 @@ class SubscriptionManager: conn.close() def get_usage_summary( - self, tenant_id: str, start_date: datetime | None = None, end_date: datetime | None = None, + self, + tenant_id: str, + start_date: datetime | None = None, + end_date: datetime | None = None, ) -> dict[str, Any]: """获取用量汇总""" conn = self._get_connection() @@ -1268,7 +1277,9 @@ class SubscriptionManager: conn.close() def confirm_payment( - self, payment_id: str, provider_payment_id: str | None = None, + self, + payment_id: str, + provider_payment_id: str | None = None, ) -> Payment | None: """确认支付完成""" conn = self._get_connection() @@ -1361,7 +1372,11 @@ class SubscriptionManager: conn.close() def list_payments( - self, tenant_id: str, status: str | None = None, limit: int = 100, offset: int = 0, + self, + tenant_id: str, + status: str | None = None, + limit: int = 100, + offset: int = 0, ) -> list[Payment]: """列出支付记录""" conn = self._get_connection() @@ -1501,7 +1516,11 @@ class SubscriptionManager: conn.close() def list_invoices( - self, tenant_id: str, status: str | None = None, limit: int = 100, offset: int = 0, + self, + tenant_id: str, + status: str | None = None, + limit: int = 100, + offset: int = 0, ) -> list[Invoice]: """列出发票""" conn = self._get_connection() @@ -1581,7 +1600,12 @@ class SubscriptionManager: # ==================== 退款管理 ==================== def request_refund( - self, tenant_id: str, payment_id: str, amount: float, reason: str, requested_by: str, + self, + tenant_id: str, + payment_id: str, + amount: float, + reason: str, + requested_by: str, ) -> Refund: """申请退款""" conn = self._get_connection() @@ -1690,7 +1714,9 @@ class SubscriptionManager: conn.close() def complete_refund( - self, refund_id: str, provider_refund_id: str | None = None, + self, + refund_id: str, + provider_refund_id: str | None = None, ) -> Refund | None: """完成退款""" conn = self._get_connection() @@ -1775,7 +1801,11 @@ class SubscriptionManager: conn.close() def list_refunds( - self, tenant_id: str, status: str | None = None, limit: int = 100, offset: int = 0, + self, + tenant_id: str, + status: str | None = None, + limit: int = 100, + offset: int = 0, ) -> list[Refund]: """列出退款记录""" conn = self._get_connection() @@ -1902,7 +1932,10 @@ class SubscriptionManager: } def create_alipay_order( - self, tenant_id: str, plan_id: str, billing_cycle: str = "monthly", + self, + tenant_id: str, + plan_id: str, + billing_cycle: str = "monthly", ) -> dict[str, Any]: """创建支付宝订单(占位实现)""" # 这里应该集成支付宝 SDK @@ -1919,7 +1952,10 @@ class SubscriptionManager: } def create_wechat_order( - self, tenant_id: str, plan_id: str, billing_cycle: str = "monthly", + self, + tenant_id: str, + plan_id: str, + billing_cycle: str = "monthly", ) -> dict[str, Any]: """创建微信支付订单(占位实现)""" # 这里应该集成微信支付 SDK diff --git a/backend/tenant_manager.py b/backend/tenant_manager.py index a6f9726..d68c8c8 100644 --- a/backend/tenant_manager.py +++ b/backend/tenant_manager.py @@ -433,7 +433,8 @@ class TenantManager: TenantTier(tier) if tier in [t.value for t in TenantTier] else TenantTier.FREE ) resource_limits = self.DEFAULT_LIMITS.get( - tier_enum, self.DEFAULT_LIMITS[TenantTier.FREE], + tier_enum, + self.DEFAULT_LIMITS[TenantTier.FREE], ) tenant = Tenant( @@ -612,7 +613,11 @@ class TenantManager: conn.close() def list_tenants( - self, status: str | None = None, tier: str | None = None, limit: int = 100, offset: int = 0, + self, + status: str | None = None, + tier: str | None = None, + limit: int = 100, + offset: int = 0, ) -> list[Tenant]: """列出租户""" conn = self._get_connection() @@ -1103,7 +1108,11 @@ class TenantManager: conn.close() def update_member_role( - self, tenant_id: str, member_id: str, role: str, permissions: list[str] | None = None, + self, + tenant_id: str, + member_id: str, + role: str, + permissions: list[str] | None = None, ) -> bool: """更新成员角色""" conn = self._get_connection() @@ -1268,7 +1277,10 @@ class TenantManager: conn.close() def get_usage_stats( - self, tenant_id: str, start_date: datetime | None = None, end_date: datetime | None = None, + self, + tenant_id: str, + start_date: datetime | None = None, + end_date: datetime | None = None, ) -> dict[str, Any]: """获取使用统计""" conn = self._get_connection() @@ -1314,23 +1326,28 @@ class TenantManager: "limits": limits, "usage_percentages": { "storage": self._calc_percentage( - row["total_storage"] or 0, limits.get("max_storage_mb", 0) * 1024 * 1024, + row["total_storage"] or 0, + limits.get("max_storage_mb", 0) * 1024 * 1024, ), "transcription": self._calc_percentage( row["total_transcription"] or 0, limits.get("max_transcription_minutes", 0) * 60, ), "api_calls": self._calc_percentage( - row["total_api_calls"] or 0, limits.get("max_api_calls_per_day", 0), + row["total_api_calls"] or 0, + limits.get("max_api_calls_per_day", 0), ), "projects": self._calc_percentage( - row["max_projects"] or 0, limits.get("max_projects", 0), + row["max_projects"] or 0, + limits.get("max_projects", 0), ), "entities": self._calc_percentage( - row["max_entities"] or 0, limits.get("max_entities", 0), + row["max_entities"] or 0, + limits.get("max_entities", 0), ), "members": self._calc_percentage( - row["max_members"] or 0, limits.get("max_team_members", 0), + row["max_members"] or 0, + limits.get("max_team_members", 0), ), }, } @@ -1406,8 +1423,10 @@ class TenantManager: def _validate_domain(self, domain: str) -> bool: """验证域名格式""" - pattern = (r"^(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0, 61}[a-zA-Z0-9])?\.)*" - r"[a-zA-Z0-9](?:[a-zA-Z0-9-]{0, 61}[a-zA-Z0-9])$") + pattern = ( + r"^(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0, 61}[a-zA-Z0-9])?\.)*" + r"[a-zA-Z0-9](?:[a-zA-Z0-9-]{0, 61}[a-zA-Z0-9])$" + ) return bool(re.match(pattern, domain)) def _check_domain_verification(self, domain: str, token: str, method: str) -> bool: diff --git a/backend/test_phase7_task6_8.py b/backend/test_phase7_task6_8.py index 042a266..a786940 100644 --- a/backend/test_phase7_task6_8.py +++ b/backend/test_phase7_task6_8.py @@ -159,7 +159,8 @@ def test_cache_manager() -> None: # 批量操作 cache.set_many( - {"batch_key_1": "value1", "batch_key_2": "value2", "batch_key_3": "value3"}, ttl=60, + {"batch_key_1": "value1", "batch_key_2": "value2", "batch_key_3": "value3"}, + ttl=60, ) print(" ✓ 批量设置缓存") @@ -208,7 +209,8 @@ def test_task_queue() -> None: # 提交任务 task_id = queue.submit( - task_type="test_task", payload={"test": "data", "timestamp": time.time()}, + task_type="test_task", + payload={"test": "data", "timestamp": time.time()}, ) print(" ✓ 提交任务: {task_id}") diff --git a/backend/test_phase8_task1.py b/backend/test_phase8_task1.py index f66f6ff..2abc7f5 100644 --- a/backend/test_phase8_task1.py +++ b/backend/test_phase8_task1.py @@ -29,7 +29,10 @@ def test_tenant_management() -> None: # 1. 创建租户 print("\n1.1 创建租户...") tenant = manager.create_tenant( - name="Test Company", owner_id="user_001", tier="pro", description="A test company tenant", + name="Test Company", + owner_id="user_001", + tier="pro", + description="A test company tenant", ) print(f"✅ 租户创建成功: {tenant.id}") print(f" - 名称: {tenant.name}") @@ -53,7 +56,9 @@ def test_tenant_management() -> None: # 4. 更新租户 print("\n1.4 更新租户信息...") updated = manager.update_tenant( - tenant_id=tenant.id, name="Test Company Updated", tier="enterprise", + tenant_id=tenant.id, + name="Test Company Updated", + tier="enterprise", ) assert updated is not None, "更新租户失败" print(f"✅ 租户更新成功: {updated.name}, 层级: {updated.tier}") @@ -163,7 +168,10 @@ def test_member_management(tenant_id: str) -> None: # 1. 邀请成员 print("\n4.1 邀请成员...") member1 = manager.invite_member( - tenant_id=tenant_id, email="admin@test.com", role="admin", invited_by="user_001", + tenant_id=tenant_id, + email="admin@test.com", + role="admin", + invited_by="user_001", ) print(f"✅ 成员邀请成功: {member1.email}") print(f" - ID: {member1.id}") @@ -171,7 +179,10 @@ def test_member_management(tenant_id: str) -> None: print(f" - 权限: {member1.permissions}") member2 = manager.invite_member( - tenant_id=tenant_id, email="member@test.com", role="member", invited_by="user_001", + tenant_id=tenant_id, + email="member@test.com", + role="member", + invited_by="user_001", ) print(f"✅ 成员邀请成功: {member2.email}") diff --git a/backend/test_phase8_task2.py b/backend/test_phase8_task2.py index ecdec7b..d5261c4 100644 --- a/backend/test_phase8_task2.py +++ b/backend/test_phase8_task2.py @@ -205,7 +205,8 @@ def test_subscription_manager() -> None: # 更改计划 changed = manager.change_plan( - subscription_id=subscription.id, new_plan_id=enterprise_plan.id, + subscription_id=subscription.id, + new_plan_id=enterprise_plan.id, ) print(f"✓ 更改计划: {changed.plan_id} (Enterprise)") diff --git a/backend/test_phase8_task4.py b/backend/test_phase8_task4.py index 9a3841d..73a04ff 100644 --- a/backend/test_phase8_task4.py +++ b/backend/test_phase8_task4.py @@ -181,14 +181,16 @@ async def test_predictions(trend_model_id: str, anomaly_model_id: str) -> None: # 2. 趋势预测 print("2. 趋势预测...") trend_result = await manager.predict( - trend_model_id, {"historical_values": [10, 12, 15, 14, 18, 20, 22]}, + trend_model_id, + {"historical_values": [10, 12, 15, 14, 18, 20, 22]}, ) print(f" 预测结果: {trend_result.prediction_data}") # 3. 异常检测 print("3. 异常检测...") anomaly_result = await manager.predict( - anomaly_model_id, {"value": 50, "historical_values": [10, 12, 11, 13, 12, 14, 13]}, + anomaly_model_id, + {"value": 50, "historical_values": [10, 12, 11, 13, 12, 14, 13]}, ) print(f" 检测结果: {anomaly_result.prediction_data}") diff --git a/backend/test_phase8_task5.py b/backend/test_phase8_task5.py index ffe6376..5cf5a7b 100644 --- a/backend/test_phase8_task5.py +++ b/backend/test_phase8_task5.py @@ -525,7 +525,8 @@ class TestGrowthManager: try: referral = self.manager.generate_referral_code( - program_id=program_id, referrer_id="referrer_user_001", + program_id=program_id, + referrer_id="referrer_user_001", ) if referral: @@ -551,7 +552,8 @@ class TestGrowthManager: try: success = self.manager.apply_referral_code( - referral_code=referral_code, referee_id="new_user_001", + referral_code=referral_code, + referee_id="new_user_001", ) if success: @@ -618,7 +620,9 @@ class TestGrowthManager: try: incentives = self.manager.check_team_incentive_eligibility( - tenant_id=self.test_tenant_id, current_tier="free", team_size=5, + tenant_id=self.test_tenant_id, + current_tier="free", + team_size=5, ) self.log(f"找到 {len(incentives)} 个符合条件的激励") diff --git a/backend/test_phase8_task6.py b/backend/test_phase8_task6.py index 2ec3077..2572433 100644 --- a/backend/test_phase8_task6.py +++ b/backend/test_phase8_task6.py @@ -162,7 +162,7 @@ class TestDeveloperEcosystem: self.log(f"Created SDK: {sdk_js.name} ({sdk_js.id})") except Exception as e: - self.log(f"Failed to create SDK: {str(e)}", success=False) + self.log(f"Failed to create SDK: {e!s}", success=False) def test_sdk_list(self) -> None: """测试列出 SDK""" @@ -179,7 +179,7 @@ class TestDeveloperEcosystem: self.log(f"Search found {len(search_results)} SDKs") except Exception as e: - self.log(f"Failed to list SDKs: {str(e)}", success=False) + self.log(f"Failed to list SDKs: {e!s}", success=False) def test_sdk_get(self) -> None: """测试获取 SDK 详情""" @@ -191,19 +191,20 @@ class TestDeveloperEcosystem: else: self.log("SDK not found", success=False) except Exception as e: - self.log(f"Failed to get SDK: {str(e)}", success=False) + self.log(f"Failed to get SDK: {e!s}", success=False) def test_sdk_update(self) -> None: """测试更新 SDK""" try: if self.created_ids["sdk"]: sdk = self.manager.update_sdk_release( - self.created_ids["sdk"][0], description="Updated description", + self.created_ids["sdk"][0], + description="Updated description", ) if sdk: self.log(f"Updated SDK: {sdk.name}") except Exception as e: - self.log(f"Failed to update SDK: {str(e)}", success=False) + self.log(f"Failed to update SDK: {e!s}", success=False) def test_sdk_publish(self) -> None: """测试发布 SDK""" @@ -213,7 +214,7 @@ class TestDeveloperEcosystem: if sdk: self.log(f"Published SDK: {sdk.name} (status: {sdk.status.value})") except Exception as e: - self.log(f"Failed to publish SDK: {str(e)}", success=False) + self.log(f"Failed to publish SDK: {e!s}", success=False) def test_sdk_version_add(self) -> None: """测试添加 SDK 版本""" @@ -230,7 +231,7 @@ class TestDeveloperEcosystem: ) self.log(f"Added SDK version: {version.version}") except Exception as e: - self.log(f"Failed to add SDK version: {str(e)}", success=False) + self.log(f"Failed to add SDK version: {e!s}", success=False) def test_template_create(self) -> None: """测试创建模板""" @@ -273,7 +274,7 @@ class TestDeveloperEcosystem: self.log(f"Created free template: {template_free.name}") except Exception as e: - self.log(f"Failed to create template: {str(e)}", success=False) + self.log(f"Failed to create template: {e!s}", success=False) def test_template_list(self) -> None: """测试列出模板""" @@ -290,7 +291,7 @@ class TestDeveloperEcosystem: self.log(f"Found {len(free_templates)} free templates") except Exception as e: - self.log(f"Failed to list templates: {str(e)}", success=False) + self.log(f"Failed to list templates: {e!s}", success=False) def test_template_get(self) -> None: """测试获取模板详情""" @@ -300,19 +301,20 @@ class TestDeveloperEcosystem: if template: self.log(f"Retrieved template: {template.name}") except Exception as e: - self.log(f"Failed to get template: {str(e)}", success=False) + self.log(f"Failed to get template: {e!s}", success=False) def test_template_approve(self) -> None: """测试审核通过模板""" try: if self.created_ids["template"]: template = self.manager.approve_template( - self.created_ids["template"][0], reviewed_by="admin_001", + self.created_ids["template"][0], + reviewed_by="admin_001", ) if template: self.log(f"Approved template: {template.name}") except Exception as e: - self.log(f"Failed to approve template: {str(e)}", success=False) + self.log(f"Failed to approve template: {e!s}", success=False) def test_template_publish(self) -> None: """测试发布模板""" @@ -322,7 +324,7 @@ class TestDeveloperEcosystem: if template: self.log(f"Published template: {template.name}") except Exception as e: - self.log(f"Failed to publish template: {str(e)}", success=False) + self.log(f"Failed to publish template: {e!s}", success=False) def test_template_review(self) -> None: """测试添加模板评价""" @@ -338,7 +340,7 @@ class TestDeveloperEcosystem: ) self.log(f"Added template review: {review.rating} stars") except Exception as e: - self.log(f"Failed to add template review: {str(e)}", success=False) + self.log(f"Failed to add template review: {e!s}", success=False) def test_plugin_create(self) -> None: """测试创建插件""" @@ -384,7 +386,7 @@ class TestDeveloperEcosystem: self.log(f"Created free plugin: {plugin_free.name}") except Exception as e: - self.log(f"Failed to create plugin: {str(e)}", success=False) + self.log(f"Failed to create plugin: {e!s}", success=False) def test_plugin_list(self) -> None: """测试列出插件""" @@ -397,7 +399,7 @@ class TestDeveloperEcosystem: self.log(f"Found {len(integration_plugins)} integration plugins") except Exception as e: - self.log(f"Failed to list plugins: {str(e)}", success=False) + self.log(f"Failed to list plugins: {e!s}", success=False) def test_plugin_get(self) -> None: """测试获取插件详情""" @@ -407,7 +409,7 @@ class TestDeveloperEcosystem: if plugin: self.log(f"Retrieved plugin: {plugin.name}") except Exception as e: - self.log(f"Failed to get plugin: {str(e)}", success=False) + self.log(f"Failed to get plugin: {e!s}", success=False) def test_plugin_review(self) -> None: """测试审核插件""" @@ -422,7 +424,7 @@ class TestDeveloperEcosystem: if plugin: self.log(f"Reviewed plugin: {plugin.name} ({plugin.status.value})") except Exception as e: - self.log(f"Failed to review plugin: {str(e)}", success=False) + self.log(f"Failed to review plugin: {e!s}", success=False) def test_plugin_publish(self) -> None: """测试发布插件""" @@ -432,7 +434,7 @@ class TestDeveloperEcosystem: if plugin: self.log(f"Published plugin: {plugin.name}") except Exception as e: - self.log(f"Failed to publish plugin: {str(e)}", success=False) + self.log(f"Failed to publish plugin: {e!s}", success=False) def test_plugin_review_add(self) -> None: """测试添加插件评价""" @@ -448,7 +450,7 @@ class TestDeveloperEcosystem: ) self.log(f"Added plugin review: {review.rating} stars") except Exception as e: - self.log(f"Failed to add plugin review: {str(e)}", success=False) + self.log(f"Failed to add plugin review: {e!s}", success=False) def test_developer_profile_create(self) -> None: """测试创建开发者档案""" @@ -479,7 +481,7 @@ class TestDeveloperEcosystem: self.log(f"Created developer profile: {profile2.display_name}") except Exception as e: - self.log(f"Failed to create developer profile: {str(e)}", success=False) + self.log(f"Failed to create developer profile: {e!s}", success=False) def test_developer_profile_get(self) -> None: """测试获取开发者档案""" @@ -489,19 +491,20 @@ class TestDeveloperEcosystem: if profile: self.log(f"Retrieved developer profile: {profile.display_name}") except Exception as e: - self.log(f"Failed to get developer profile: {str(e)}", success=False) + self.log(f"Failed to get developer profile: {e!s}", success=False) def test_developer_verify(self) -> None: """测试验证开发者""" try: if self.created_ids["developer"]: profile = self.manager.verify_developer( - self.created_ids["developer"][0], DeveloperStatus.VERIFIED, + self.created_ids["developer"][0], + DeveloperStatus.VERIFIED, ) if profile: self.log(f"Verified developer: {profile.display_name} ({profile.status.value})") except Exception as e: - self.log(f"Failed to verify developer: {str(e)}", success=False) + self.log(f"Failed to verify developer: {e!s}", success=False) def test_developer_stats_update(self) -> None: """测试更新开发者统计""" @@ -513,7 +516,7 @@ class TestDeveloperEcosystem: f"Updated developer stats: {profile.plugin_count} plugins, {profile.template_count} templates", ) except Exception as e: - self.log(f"Failed to update developer stats: {str(e)}", success=False) + self.log(f"Failed to update developer stats: {e!s}", success=False) def test_code_example_create(self) -> None: """测试创建代码示例""" @@ -562,7 +565,7 @@ console.log('Upload complete:', result.id); self.log(f"Created code example: {example_js.title}") except Exception as e: - self.log(f"Failed to create code example: {str(e)}", success=False) + self.log(f"Failed to create code example: {e!s}", success=False) def test_code_example_list(self) -> None: """测试列出代码示例""" @@ -575,7 +578,7 @@ console.log('Upload complete:', result.id); self.log(f"Found {len(python_examples)} Python examples") except Exception as e: - self.log(f"Failed to list code examples: {str(e)}", success=False) + self.log(f"Failed to list code examples: {e!s}", success=False) def test_code_example_get(self) -> None: """测试获取代码示例详情""" @@ -587,7 +590,7 @@ console.log('Upload complete:', result.id); f"Retrieved code example: {example.title} (views: {example.view_count})", ) except Exception as e: - self.log(f"Failed to get code example: {str(e)}", success=False) + self.log(f"Failed to get code example: {e!s}", success=False) def test_portal_config_create(self) -> None: """测试创建开发者门户配置""" @@ -608,7 +611,7 @@ console.log('Upload complete:', result.id); self.log(f"Created portal config: {config.name}") except Exception as e: - self.log(f"Failed to create portal config: {str(e)}", success=False) + self.log(f"Failed to create portal config: {e!s}", success=False) def test_portal_config_get(self) -> None: """测试获取开发者门户配置""" @@ -624,7 +627,7 @@ console.log('Upload complete:', result.id); self.log(f"Active portal config: {active_config.name}") except Exception as e: - self.log(f"Failed to get portal config: {str(e)}", success=False) + self.log(f"Failed to get portal config: {e!s}", success=False) def test_revenue_record(self) -> None: """测试记录开发者收益""" @@ -644,7 +647,7 @@ console.log('Upload complete:', result.id); self.log(f" - Platform fee: {revenue.platform_fee}") self.log(f" - Developer earnings: {revenue.developer_earnings}") except Exception as e: - self.log(f"Failed to record revenue: {str(e)}", success=False) + self.log(f"Failed to record revenue: {e!s}", success=False) def test_revenue_summary(self) -> None: """测试获取开发者收益汇总""" @@ -659,7 +662,7 @@ console.log('Upload complete:', result.id); self.log(f" - Total earnings: {summary['total_earnings']}") self.log(f" - Transaction count: {summary['transaction_count']}") except Exception as e: - self.log(f"Failed to get revenue summary: {str(e)}", success=False) + self.log(f"Failed to get revenue summary: {e!s}", success=False) def print_summary(self) -> None: """打印测试摘要""" diff --git a/backend/test_phase8_task8.py b/backend/test_phase8_task8.py index fcac2dc..5356e55 100644 --- a/backend/test_phase8_task8.py +++ b/backend/test_phase8_task8.py @@ -129,7 +129,9 @@ class TestOpsManager: # 更新告警规则 updated_rule = self.manager.update_alert_rule( - rule1.id, threshold=85.0, description="更新后的描述", + rule1.id, + threshold=85.0, + description="更新后的描述", ) assert updated_rule.threshold == 85.0 self.log(f"Updated alert rule threshold to {updated_rule.threshold}") @@ -421,7 +423,9 @@ class TestOpsManager: # 模拟扩缩容评估 event = self.manager.evaluate_scaling_policy( - policy_id=policy.id, current_instances=3, current_utilization=0.85, + policy_id=policy.id, + current_instances=3, + current_utilization=0.85, ) if event: @@ -439,7 +443,8 @@ class TestOpsManager: with self.manager._get_db() as conn: conn.execute("DELETE FROM scaling_events WHERE tenant_id = ?", (self.tenant_id,)) conn.execute( - "DELETE FROM auto_scaling_policies WHERE tenant_id = ?", (self.tenant_id,), + "DELETE FROM auto_scaling_policies WHERE tenant_id = ?", + (self.tenant_id,), ) conn.commit() self.log("Cleaned up auto scaling test data") @@ -530,7 +535,8 @@ class TestOpsManager: # 发起故障转移 event = self.manager.initiate_failover( - config_id=config.id, reason="Primary region health check failed", + config_id=config.id, + reason="Primary region health check failed", ) if event: @@ -638,7 +644,9 @@ class TestOpsManager: # 生成成本报告 now = datetime.now() report = self.manager.generate_cost_report( - tenant_id=self.tenant_id, year=now.year, month=now.month, + tenant_id=self.tenant_id, + year=now.year, + month=now.month, ) self.log(f"Generated cost report: {report.id}") @@ -691,7 +699,8 @@ class TestOpsManager: ) conn.execute("DELETE FROM idle_resources WHERE tenant_id = ?", (self.tenant_id,)) conn.execute( - "DELETE FROM resource_utilizations WHERE tenant_id = ?", (self.tenant_id,), + "DELETE FROM resource_utilizations WHERE tenant_id = ?", + (self.tenant_id,), ) conn.execute("DELETE FROM cost_reports WHERE tenant_id = ?", (self.tenant_id,)) conn.commit() diff --git a/backend/tingwu_client.py b/backend/tingwu_client.py index 0529dcf..4f1e621 100644 --- a/backend/tingwu_client.py +++ b/backend/tingwu_client.py @@ -19,7 +19,11 @@ class TingwuClient: raise ValueError("ALI_ACCESS_KEY and ALI_SECRET_KEY required") def _sign_request( - self, method: str, uri: str, query: str = "", body: str = "", + self, + method: str, + uri: str, + query: str = "", + body: str = "", ) -> dict[str, str]: """阿里云签名 V3""" timestamp = datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%SZ") @@ -43,7 +47,8 @@ class TingwuClient: from alibabacloud_tingwu20230930.client import Client as TingwuSDKClient config = open_api_models.Config( - access_key_id=self.access_key, access_key_secret=self.secret_key, + access_key_id=self.access_key, + access_key_secret=self.secret_key, ) config.endpoint = "tingwu.cn-beijing.aliyuncs.com" client = TingwuSDKClient(config) @@ -53,7 +58,8 @@ class TingwuClient: input=tingwu_models.Input(source="OSS", file_url=audio_url), parameters=tingwu_models.Parameters( transcription=tingwu_models.Transcription( - diarization_enabled=True, sentence_max_length=20, + diarization_enabled=True, + sentence_max_length=20, ), ), ) @@ -73,7 +79,10 @@ class TingwuClient: return f"mock_task_{int(time.time())}" def get_task_result( - self, task_id: str, max_retries: int = 60, interval: int = 5, + self, + task_id: str, + max_retries: int = 60, + interval: int = 5, ) -> dict[str, Any]: """获取任务结果""" try: @@ -83,7 +92,8 @@ class TingwuClient: from alibabacloud_tingwu20230930.client import Client as TingwuSDKClient config = open_api_models.Config( - access_key_id=self.access_key, access_key_secret=self.secret_key, + access_key_id=self.access_key, + access_key_secret=self.secret_key, ) config.endpoint = "tingwu.cn-beijing.aliyuncs.com" client = TingwuSDKClient(config) diff --git a/backend/workflow_manager.py b/backend/workflow_manager.py index 235e6b2..5ce6447 100644 --- a/backend/workflow_manager.py +++ b/backend/workflow_manager.py @@ -264,7 +264,9 @@ class WebhookNotifier: secret_enc = config.secret.encode("utf-8") string_to_sign = f"{timestamp}\n{config.secret}" hmac_code = hmac.new( - secret_enc, string_to_sign.encode("utf-8"), digestmod=hashlib.sha256, + secret_enc, + string_to_sign.encode("utf-8"), + digestmod=hashlib.sha256, ).digest() sign = urllib.parse.quote_plus(base64.b64encode(hmac_code)) url = f"{config.url}×tamp = {timestamp}&sign = {sign}" @@ -497,7 +499,10 @@ class WorkflowManager: conn.close() def list_workflows( - self, project_id: str = None, status: str = None, workflow_type: str = None, + self, + project_id: str | None = None, + status: str = None, + workflow_type: str = None, ) -> list[Workflow]: """列出工作流""" conn = self.db.get_conn() @@ -518,7 +523,8 @@ class WorkflowManager: where_clause = " AND ".join(conditions) if conditions else "1 = 1" rows = conn.execute( - f"SELECT * FROM workflows WHERE {where_clause} ORDER BY created_at DESC", params, + f"SELECT * FROM workflows WHERE {where_clause} ORDER BY created_at DESC", + params, ).fetchall() return [self._row_to_workflow(row) for row in rows] @@ -780,7 +786,8 @@ class WorkflowManager: conn = self.db.get_conn() try: row = conn.execute( - "SELECT * FROM webhook_configs WHERE id = ?", (webhook_id,), + "SELECT * FROM webhook_configs WHERE id = ?", + (webhook_id,), ).fetchone() if not row: @@ -962,9 +969,9 @@ class WorkflowManager: def list_logs( self, - workflow_id: str = None, - task_id: str = None, - status: str = None, + workflow_id: str | None = None, + task_id: str | None = None, + status: str | None = None, limit: int = 100, offset: int = 0, ) -> list[WorkflowLog]: @@ -1074,7 +1081,7 @@ class WorkflowManager: # ==================== Workflow Execution ==================== - async def execute_workflow(self, workflow_id: str, input_data: dict = None) -> dict: + async def execute_workflow(self, workflow_id: str, input_data: dict | None = None) -> dict: """执行工作流""" workflow = self.get_workflow(workflow_id) if not workflow: @@ -1159,7 +1166,10 @@ class WorkflowManager: raise async def _execute_tasks_with_deps( - self, tasks: list[WorkflowTask], input_data: dict, log_id: str, + self, + tasks: list[WorkflowTask], + input_data: dict, + log_id: str, ) -> dict: """按依赖顺序执行任务""" results = {} @@ -1413,7 +1423,10 @@ class WorkflowManager: # ==================== Notification ==================== async def _send_workflow_notification( - self, workflow: Workflow, results: dict, success: bool = True, + self, + workflow: Workflow, + results: dict, + success: bool = True, ) -> None: """发送工作流执行通知""" if not workflow.webhook_ids: