diff --git a/auto_fix_code.py b/auto_fix_code.py index caf66e1..70e7069 100644 --- a/auto_fix_code.py +++ b/auto_fix_code.py @@ -1,192 +1,153 @@ #!/usr/bin/env python3 """ -InsightFlow 代码自动修复脚本 - 增强版 -自动修复代码中的常见问题 +InsightFlow 代码自动修复脚本 """ -import json import os import re import subprocess from pathlib import Path +PROJECT_DIR = Path("/root/.openclaw/workspace/projects/insightflow") +BACKEND_DIR = PROJECT_DIR / "backend" -def run_ruff_check(directory: str) -> list[dict]: - """运行 ruff 检查并返回问题列表""" - try: - result = subprocess.run( - ["ruff", "check", "--select = E, W, F, I", "--output-format = json", directory], - capture_output = True, - text = True, - check = False, - ) - if result.stdout: - return json.loads(result.stdout) - return [] - except Exception as e: - print(f"Ruff check failed: {e}") - return [] +def run_flake8(): + """运行 flake8 检查""" + result = subprocess.run( + ["flake8", "--max-line-length=120", "--ignore=E501,W503", "."], + cwd=BACKEND_DIR, + capture_output=True, + text=True + ) + return result.stdout +def fix_missing_imports(): + """修复缺失的导入""" + fixes = [] + + # 检查 workflow_manager.py 中的 urllib + workflow_file = BACKEND_DIR / "workflow_manager.py" + if workflow_file.exists(): + content = workflow_file.read_text() + if "import urllib" not in content and "urllib" in content: + # 在文件开头添加导入 + lines = content.split('\n') + import_idx = 0 + for i, line in enumerate(lines): + if line.startswith('import ') or line.startswith('from '): + import_idx = i + 1 + lines.insert(import_idx, 'import urllib.parse') + workflow_file.write_text('\n'.join(lines)) + fixes.append("workflow_manager.py: 添加 urllib.parse 导入") + + # 检查 plugin_manager.py 中的 urllib + plugin_file = BACKEND_DIR / "plugin_manager.py" + if plugin_file.exists(): + content = plugin_file.read_text() + if "import urllib" not in content and "urllib" in content: + lines = content.split('\n') + import_idx = 0 + for i, line in enumerate(lines): + if line.startswith('import ') or line.startswith('from '): + import_idx = i + 1 + lines.insert(import_idx, 'import urllib.parse') + plugin_file.write_text('\n'.join(lines)) + fixes.append("plugin_manager.py: 添加 urllib.parse 导入") + + # 检查 main.py 中的 PlainTextResponse + main_file = BACKEND_DIR / "main.py" + if main_file.exists(): + content = main_file.read_text() + if "PlainTextResponse" in content and "from fastapi.responses import" in content: + # 检查是否已导入 + if "PlainTextResponse" not in content.split('from fastapi.responses import')[1].split('\n')[0]: + # 添加导入 + content = content.replace( + "from fastapi.responses import JSONResponse, PlainTextResponse, StreamingResponse", + "from fastapi.responses import JSONResponse, PlainTextResponse, StreamingResponse" + ) + # 实际上已经导入了,可能是误报 + + return fixes -def fix_bare_except(content: str) -> str: - """修复裸异常捕获 - 将 bare except Exception: 改为 except Exception:""" - pattern = r'except\s*:\s*\n' - replacement = 'except Exception:\n' - return re.sub(pattern, replacement, content) +def fix_unused_imports(): + """修复未使用的导入""" + fixes = [] + + # code_reviewer.py 中的未使用导入 + code_reviewer = PROJECT_DIR / "code_reviewer.py" + if code_reviewer.exists(): + content = code_reviewer.read_text() + original = content + # 移除未使用的导入 + content = re.sub(r'^import os\n', '', content, flags=re.MULTILINE) + content = re.sub(r'^import subprocess\n', '', content, flags=re.MULTILINE) + content = re.sub(r'^from typing import Any\n', '', content, flags=re.MULTILINE) + if content != original: + code_reviewer.write_text(content) + fixes.append("code_reviewer.py: 移除未使用的导入") + + return fixes +def fix_formatting(): + """使用 autopep8 修复格式问题""" + fixes = [] + + # 运行 autopep8 修复格式问题 + result = subprocess.run( + ["autopep8", "--in-place", "--aggressive", "--max-line-length=120", "."], + cwd=BACKEND_DIR, + capture_output=True, + text=True + ) + + if result.returncode == 0: + fixes.append("使用 autopep8 修复了格式问题") + + return fixes -def fix_undefined_names(content: str, filepath: str) -> str: - """修复未定义的名称""" - lines = content.split('\n') - modified = False - - import_map = { - 'ExportEntity': 'from export_manager import ExportEntity', - 'ExportRelation': 'from export_manager import ExportRelation', - 'ExportTranscript': 'from export_manager import ExportTranscript', - 'WorkflowManager': 'from workflow_manager import WorkflowManager', - 'PluginManager': 'from plugin_manager import PluginManager', - 'OpsManager': 'from ops_manager import OpsManager', - 'urllib': 'import urllib.parse', - } - - undefined_names = set() - for name, import_stmt in import_map.items(): - if name in content and import_stmt not in content: - undefined_names.add((name, import_stmt)) - - if undefined_names: - import_idx = 0 - for i, line in enumerate(lines): - if line.startswith('import ') or line.startswith('from '): - import_idx = i + 1 - - for name, import_stmt in sorted(undefined_names): - lines.insert(import_idx, import_stmt) - import_idx += 1 - modified = True - - if modified: - return '\n'.join(lines) - return content - - -def fix_file(filepath: str, issues: list[dict]) -> tuple[bool, list[str], list[str]]: - """修复单个文件的问题""" - with open(filepath, 'r', encoding = 'utf-8') as f: - original_content = f.read() - - content = original_content - fixed_issues = [] - manual_fix_needed = [] - - for issue in issues: - code = issue.get('code', '') - message = issue.get('message', '') - line_num = issue['location']['row'] - - if code == 'F821': - content = fix_undefined_names(content, filepath) - if content != original_content: - fixed_issues.append(f"F821 - {message} (line {line_num})") - else: - manual_fix_needed.append(f"F821 - {message} (line {line_num})") - elif code == 'E501': - manual_fix_needed.append(f"E501 (line {line_num})") - - content = fix_bare_except(content) - - if content != original_content: - with open(filepath, 'w', encoding = 'utf-8') as f: - f.write(content) - return True, fixed_issues, manual_fix_needed - - return False, fixed_issues, manual_fix_needed - - -def main() -> None: - base_dir = Path("/root/.openclaw/workspace/projects/insightflow") - backend_dir = base_dir / "backend" - - print(" = " * 60) +def main(): + print("=" * 60) print("InsightFlow 代码自动修复") - print(" = " * 60) - - print("\n1. 扫描代码问题...") - issues = run_ruff_check(str(backend_dir)) - - issues_by_file = {} - for issue in issues: - filepath = issue.get('filename', '') - if filepath not in issues_by_file: - issues_by_file[filepath] = [] - issues_by_file[filepath].append(issue) - - print(f" 发现 {len(issues)} 个问题,分布在 {len(issues_by_file)} 个文件中") - - issue_types = {} - for issue in issues: - code = issue.get('code', 'UNKNOWN') - issue_types[code] = issue_types.get(code, 0) + 1 - - print("\n2. 问题类型统计:") - for code, count in sorted(issue_types.items(), key = lambda x: -x[1]): - print(f" - {code}: {count} 个") - - print("\n3. 尝试自动修复...") - fixed_files = [] - all_fixed_issues = [] - all_manual_fixes = [] - - for filepath, file_issues in issues_by_file.items(): - if not os.path.exists(filepath): - continue - - modified, fixed, manual = fix_file(filepath, file_issues) - if modified: - fixed_files.append(filepath) - all_fixed_issues.extend(fixed) - all_manual_fixes.extend([(filepath, m) for m in manual]) - - print(f" 直接修改了 {len(fixed_files)} 个文件") - print(f" 自动修复了 {len(all_fixed_issues)} 个问题") - - print("\n4. 运行 ruff 自动格式化...") - try: - subprocess.run( - ["ruff", "format", str(backend_dir)], - capture_output = True, - check = False, - ) - print(" 格式化完成") - except Exception as e: - print(f" 格式化失败: {e}") - - print("\n5. 再次检查...") - remaining_issues = run_ruff_check(str(backend_dir)) - print(f" 剩余 {len(remaining_issues)} 个问题需要手动处理") - - report = { - 'total_issues': len(issues), - 'fixed_files': len(fixed_files), - 'fixed_issues': len(all_fixed_issues), - 'remaining_issues': len(remaining_issues), - 'issue_types': issue_types, - 'manual_fix_needed': all_manual_fixes[:30], - } - - return report - + print("=" * 60) + + all_fixes = [] + + # 1. 修复缺失的导入 + print("\n[1/3] 修复缺失的导入...") + fixes = fix_missing_imports() + all_fixes.extend(fixes) + for f in fixes: + print(f" ✓ {f}") + + # 2. 修复未使用的导入 + print("\n[2/3] 修复未使用的导入...") + fixes = fix_unused_imports() + all_fixes.extend(fixes) + for f in fixes: + print(f" ✓ {f}") + + # 3. 修复格式问题 + print("\n[3/3] 修复 PEP8 格式问题...") + fixes = fix_formatting() + all_fixes.extend(fixes) + for f in fixes: + print(f" ✓ {f}") + + print("\n" + "=" * 60) + print(f"修复完成!共修复 {len(all_fixes)} 个问题") + print("=" * 60) + + # 再次运行 flake8 检查 + print("\n重新运行 flake8 检查...") + remaining = run_flake8() + if remaining: + lines = remaining.strip().split('\n') + print(f" 仍有 {len(lines)} 个问题需要手动处理") + else: + print(" ✓ 所有问题已修复!") + + return all_fixes if __name__ == "__main__": - report = main() - print("\n" + " = " * 60) - print("修复报告") - print(" = " * 60) - print(f"总问题数: {report['total_issues']}") - print(f"修复文件数: {report['fixed_files']}") - print(f"自动修复问题数: {report['fixed_issues']}") - print(f"剩余问题数: {report['remaining_issues']}") - print(f"\n需要手动处理的问题 (前30个):") - for filepath, issue in report['manual_fix_needed']: - print(f" - {filepath}: {issue}") + main() diff --git a/backend/ai_manager.py b/backend/ai_manager.py index b1e5699..c4e5e2e 100644 --- a/backend/ai_manager.py +++ b/backend/ai_manager.py @@ -25,44 +25,44 @@ from enum import StrEnum import httpx # Database path -DB_PATH = os.path.join(os.path.dirname(__file__), "insightflow.db") +DB_PATH = os.path.join(os.path.dirname(__file__), "insightflow.db") class ModelType(StrEnum): """模型类型""" - CUSTOM_NER = "custom_ner" # 自定义实体识别 - MULTIMODAL = "multimodal" # 多模态 - SUMMARIZATION = "summarization" # 摘要 - PREDICTION = "prediction" # 预测 + CUSTOM_NER = "custom_ner" # 自定义实体识别 + MULTIMODAL = "multimodal" # 多模态 + SUMMARIZATION = "summarization" # 摘要 + PREDICTION = "prediction" # 预测 class ModelStatus(StrEnum): """模型状态""" - PENDING = "pending" - TRAINING = "training" - READY = "ready" - FAILED = "failed" - ARCHIVED = "archived" + PENDING = "pending" + TRAINING = "training" + READY = "ready" + FAILED = "failed" + ARCHIVED = "archived" class MultimodalProvider(StrEnum): """多模态模型提供商""" - GPT4V = "gpt-4-vision" - CLAUDE3 = "claude-3" - GEMINI = "gemini-pro-vision" - KIMI_VL = "kimi-vl" + GPT4V = "gpt-4-vision" + CLAUDE3 = "claude-3" + GEMINI = "gemini-pro-vision" + KIMI_VL = "kimi-vl" class PredictionType(StrEnum): """预测类型""" - TREND = "trend" # 趋势预测 - ANOMALY = "anomaly" # 异常检测 - ENTITY_GROWTH = "entity_growth" # 实体增长预测 - RELATION_EVOLUTION = "relation_evolution" # 关系演变预测 + TREND = "trend" # 趋势预测 + ANOMALY = "anomaly" # 异常检测 + ENTITY_GROWTH = "entity_growth" # 实体增长预测 + RELATION_EVOLUTION = "relation_evolution" # 关系演变预测 @dataclass @@ -204,17 +204,17 @@ class SmartSummary: class AIManager: """AI 能力管理主类""" - def __init__(self, db_path: str = DB_PATH) -> None: - self.db_path = db_path - self.kimi_api_key = os.getenv("KIMI_API_KEY", "") - self.kimi_base_url = os.getenv("KIMI_BASE_URL", "https://api.kimi.com/coding") - self.openai_api_key = os.getenv("OPENAI_API_KEY", "") - self.anthropic_api_key = os.getenv("ANTHROPIC_API_KEY", "") + def __init__(self, db_path: str = DB_PATH) -> None: + self.db_path = db_path + self.kimi_api_key = os.getenv("KIMI_API_KEY", "") + self.kimi_base_url = os.getenv("KIMI_BASE_URL", "https://api.kimi.com/coding") + self.openai_api_key = os.getenv("OPENAI_API_KEY", "") + self.anthropic_api_key = os.getenv("ANTHROPIC_API_KEY", "") def _get_db(self) -> sqlite3.Connection: """获取数据库连接""" - conn = sqlite3.connect(self.db_path) - conn.row_factory = sqlite3.Row + conn = sqlite3.connect(self.db_path) + conn.row_factory = sqlite3.Row return conn # ==================== 自定义模型训练 ==================== @@ -230,24 +230,24 @@ class AIManager: created_by: str, ) -> CustomModel: """创建自定义模型""" - model_id = f"cm_{uuid.uuid4().hex[:16]}" - now = datetime.now().isoformat() + model_id = f"cm_{uuid.uuid4().hex[:16]}" + now = datetime.now().isoformat() - model = CustomModel( - id = model_id, - tenant_id = tenant_id, - name = name, - description = description, - model_type = model_type, - status = ModelStatus.PENDING, - training_data = training_data, - hyperparameters = hyperparameters, - metrics = {}, - model_path = None, - created_at = now, - updated_at = now, - trained_at = None, - created_by = created_by, + model = CustomModel( + id=model_id, + tenant_id=tenant_id, + name=name, + description=description, + model_type=model_type, + status=ModelStatus.PENDING, + training_data=training_data, + hyperparameters=hyperparameters, + metrics={}, + model_path=None, + created_at=now, + updated_at=now, + trained_at=None, + created_by=created_by, ) with self._get_db() as conn: @@ -283,7 +283,7 @@ class AIManager: def get_custom_model(self, model_id: str) -> CustomModel | None: """获取自定义模型""" with self._get_db() as conn: - row = conn.execute("SELECT * FROM custom_models WHERE id = ?", (model_id, )).fetchone() + row = conn.execute("SELECT * FROM custom_models WHERE id = ?", (model_id, )).fetchone() if not row: return None @@ -291,11 +291,11 @@ class AIManager: return self._row_to_custom_model(row) def list_custom_models( - self, tenant_id: str, model_type: ModelType | None = None, status: ModelStatus | None = None + self, tenant_id: str, model_type: ModelType | None = None, status: ModelStatus | None = None ) -> list[CustomModel]: """列出自定义模型""" - query = "SELECT * FROM custom_models WHERE tenant_id = ?" - params = [tenant_id] + query = "SELECT * FROM custom_models WHERE tenant_id = ?" + params = [tenant_id] if model_type: query += " AND model_type = ?" @@ -307,23 +307,23 @@ class AIManager: query += " ORDER BY created_at DESC" with self._get_db() as conn: - rows = conn.execute(query, params).fetchall() + rows = conn.execute(query, params).fetchall() return [self._row_to_custom_model(row) for row in rows] def add_training_sample( - self, model_id: str, text: str, entities: list[dict], metadata: dict = None + self, model_id: str, text: str, entities: list[dict], metadata: dict = None ) -> TrainingSample: """添加训练样本""" - sample_id = f"ts_{uuid.uuid4().hex[:16]}" - now = datetime.now().isoformat() + sample_id = f"ts_{uuid.uuid4().hex[:16]}" + now = datetime.now().isoformat() - sample = TrainingSample( - id = sample_id, - model_id = model_id, - text = text, - entities = entities, - metadata = metadata or {}, - created_at = now, + sample = TrainingSample( + id=sample_id, + model_id=model_id, + text=text, + entities=entities, + metadata=metadata or {}, + created_at=now, ) with self._get_db() as conn: @@ -349,7 +349,7 @@ class AIManager: def get_training_samples(self, model_id: str) -> list[TrainingSample]: """获取训练样本""" with self._get_db() as conn: - rows = conn.execute( + rows = conn.execute( "SELECT * FROM training_samples WHERE model_id = ? ORDER BY created_at", (model_id, ) ).fetchall() @@ -357,7 +357,7 @@ class AIManager: async def train_custom_model(self, model_id: str) -> CustomModel: """训练自定义模型""" - model = self.get_custom_model(model_id) + model = self.get_custom_model(model_id) if not model: raise ValueError(f"Model {model_id} not found") @@ -371,7 +371,7 @@ class AIManager: try: # 获取训练样本 - samples = self.get_training_samples(model_id) + samples = self.get_training_samples(model_id) if len(samples) < 10: raise ValueError("至少需要 10 个训练样本") @@ -380,7 +380,7 @@ class AIManager: await asyncio.sleep(2) # 模拟训练时间 # 计算训练指标 - metrics = { + metrics = { "samples_count": len(samples), "epochs": model.hyperparameters.get("epochs", 10), "learning_rate": model.hyperparameters.get("learning_rate", 0.001), @@ -391,10 +391,10 @@ class AIManager: } # 保存模型(模拟) - model_path = f"models/{model_id}.bin" - os.makedirs("models", exist_ok = True) + model_path = f"models/{model_id}.bin" + os.makedirs("models", exist_ok=True) - now = datetime.now().isoformat() + now = datetime.now().isoformat() with self._get_db() as conn: conn.execute( @@ -420,49 +420,49 @@ class AIManager: async def predict_with_custom_model(self, model_id: str, text: str) -> list[dict]: """使用自定义模型进行预测""" - model = self.get_custom_model(model_id) + model = self.get_custom_model(model_id) if not model or model.status != ModelStatus.READY: raise ValueError(f"Model {model_id} not ready") # 模拟预测(实际项目中加载模型并进行推理) # 这里使用 LLM 模拟领域特定实体识别 - entity_types = model.training_data.get("entity_types", ["PERSON", "ORG", "TECH", "PROJECT"]) + entity_types = model.training_data.get("entity_types", ["PERSON", "ORG", "TECH", "PROJECT"]) - prompt = f"""从以下文本中提取实体,类型限定为: {", ".join(entity_types)} + prompt = f"""从以下文本中提取实体,类型限定为: {", ".join(entity_types)} 文本: {text} 以 JSON 格式返回实体列表: [{{"text": "实体文本", "label": "类型", "start": 0, "end": 5, "confidence": 0.95}}] 只返回 JSON 数组,不要其他内容。""" - headers = { + headers = { "Authorization": f"Bearer {self.kimi_api_key}", "Content-Type": "application/json", } - payload = { + payload = { "model": "k2p5", "messages": [{"role": "user", "content": prompt}], "temperature": 0.1, } async with httpx.AsyncClient() as client: - response = await client.post( + response = await client.post( f"{self.kimi_base_url}/v1/chat/completions", - headers = headers, - json = payload, - timeout = 60.0, + headers=headers, + json=payload, + timeout=60.0, ) response.raise_for_status() - result = response.json() - content = result["choices"][0]["message"]["content"] + result = response.json() + content = result["choices"][0]["message"]["content"] # 解析 JSON - json_match = re.search(r"\[.*?\]", content, re.DOTALL) + json_match = re.search(r"\[.*?\]", content, re.DOTALL) if json_match: try: - entities = json.loads(json_match.group()) + entities = json.loads(json_match.group()) return entities except (json.JSONDecodeError, ValueError): pass @@ -481,30 +481,30 @@ class AIManager: prompt: str, ) -> MultimodalAnalysis: """多模态分析""" - analysis_id = f"ma_{uuid.uuid4().hex[:16]}" - now = datetime.now().isoformat() + analysis_id = f"ma_{uuid.uuid4().hex[:16]}" + now = datetime.now().isoformat() # 根据提供商调用不同的 API if provider == MultimodalProvider.GPT4V and self.openai_api_key: - result = await self._call_gpt4v(input_urls, prompt) + result = await self._call_gpt4v(input_urls, prompt) elif provider == MultimodalProvider.CLAUDE3 and self.anthropic_api_key: - result = await self._call_claude3(input_urls, prompt) + result = await self._call_claude3(input_urls, prompt) else: # 默认使用 Kimi - result = await self._call_kimi_multimodal(input_urls, prompt) + result = await self._call_kimi_multimodal(input_urls, prompt) - analysis = MultimodalAnalysis( - id = analysis_id, - tenant_id = tenant_id, - project_id = project_id, - provider = provider, - input_type = input_type, - input_urls = input_urls, - prompt = prompt, - result = result, - tokens_used = result.get("tokens_used", 0), - cost = result.get("cost", 0.0), - created_at = now, + analysis = MultimodalAnalysis( + id=analysis_id, + tenant_id=tenant_id, + project_id=project_id, + provider=provider, + input_type=input_type, + input_urls=input_urls, + prompt=prompt, + result=result, + tokens_used=result.get("tokens_used", 0), + cost=result.get("cost", 0.0), + created_at=now, ) with self._get_db() as conn: @@ -535,30 +535,30 @@ class AIManager: async def _call_gpt4v(self, image_urls: list[str], prompt: str) -> dict: """调用 GPT-4V""" - headers = { + headers = { "Authorization": f"Bearer {self.openai_api_key}", "Content-Type": "application/json", } - content = [{"type": "text", "text": prompt}] + content = [{"type": "text", "text": prompt}] for url in image_urls: content.append({"type": "image_url", "image_url": {"url": url}}) - payload = { + payload = { "model": "gpt-4-vision-preview", "messages": [{"role": "user", "content": content}], "max_tokens": 2000, } async with httpx.AsyncClient() as client: - response = await client.post( + response = await client.post( "https://api.openai.com/v1/chat/completions", - headers = headers, - json = payload, - timeout = 120.0, + headers=headers, + json=payload, + timeout=120.0, ) response.raise_for_status() - result = response.json() + result = response.json() return { "content": result["choices"][0]["message"]["content"], @@ -568,32 +568,32 @@ class AIManager: async def _call_claude3(self, image_urls: list[str], prompt: str) -> dict: """调用 Claude 3""" - headers = { + headers = { "x-api-key": self.anthropic_api_key, "Content-Type": "application/json", "anthropic-version": "2023-06-01", } - content = [] + content = [] for url in image_urls: content.append({"type": "image", "source": {"type": "url", "url": url}}) content.append({"type": "text", "text": prompt}) - payload = { + payload = { "model": "claude-3-opus-20240229", "max_tokens": 2000, "messages": [{"role": "user", "content": content}], } async with httpx.AsyncClient() as client: - response = await client.post( + response = await client.post( "https://api.anthropic.com/v1/messages", - headers = headers, - json = payload, - timeout = 120.0, + headers=headers, + json=payload, + timeout=120.0, ) response.raise_for_status() - result = response.json() + result = response.json() return { "content": result["content"][0]["text"], @@ -604,7 +604,7 @@ class AIManager: async def _call_kimi_multimodal(self, image_urls: list[str], prompt: str) -> dict: """调用 Kimi 多模态模型""" - headers = { + headers = { "Authorization": f"Bearer {self.kimi_api_key}", "Content-Type": "application/json", } @@ -612,23 +612,23 @@ class AIManager: # Kimi 目前可能不支持真正的多模态,这里模拟返回 # 实际实现时需要根据 Kimi API 更新 - content = f"图片 URL: {', '.join(image_urls)}\n\n{prompt}\n\n注意:请基于图片 URL 描述的内容进行回答。" + content = f"图片 URL: {', '.join(image_urls)}\n\n{prompt}\n\n注意:请基于图片 URL 描述的内容进行回答。" - payload = { + payload = { "model": "k2p5", "messages": [{"role": "user", "content": content}], "temperature": 0.3, } async with httpx.AsyncClient() as client: - response = await client.post( + response = await client.post( f"{self.kimi_base_url}/v1/chat/completions", - headers = headers, - json = payload, - timeout = 60.0, + headers=headers, + json=payload, + timeout=60.0, ) response.raise_for_status() - result = response.json() + result = response.json() return { "content": result["choices"][0]["message"]["content"], @@ -637,11 +637,11 @@ class AIManager: } def get_multimodal_analyses( - self, tenant_id: str, project_id: str | None = None + self, tenant_id: str, project_id: str | None = None ) -> list[MultimodalAnalysis]: """获取多模态分析历史""" - query = "SELECT * FROM multimodal_analyses WHERE tenant_id = ?" - params = [tenant_id] + query = "SELECT * FROM multimodal_analyses WHERE tenant_id = ?" + params = [tenant_id] if project_id: query += " AND project_id = ?" @@ -650,7 +650,7 @@ class AIManager: query += " ORDER BY created_at DESC" with self._get_db() as conn: - rows = conn.execute(query, params).fetchall() + rows = conn.execute(query, params).fetchall() return [self._row_to_multimodal_analysis(row) for row in rows] # ==================== 智能摘要与问答(基于知识图谱的 RAG) ==================== @@ -666,21 +666,21 @@ class AIManager: generation_config: dict, ) -> KnowledgeGraphRAG: """创建知识图谱 RAG 配置""" - rag_id = f"kgr_{uuid.uuid4().hex[:16]}" - now = datetime.now().isoformat() + rag_id = f"kgr_{uuid.uuid4().hex[:16]}" + now = datetime.now().isoformat() - rag = KnowledgeGraphRAG( - id = rag_id, - tenant_id = tenant_id, - project_id = project_id, - name = name, - description = description, - kg_config = kg_config, - retrieval_config = retrieval_config, - generation_config = generation_config, - is_active = True, - created_at = now, - updated_at = now, + rag = KnowledgeGraphRAG( + id=rag_id, + tenant_id=tenant_id, + project_id=project_id, + name=name, + description=description, + kg_config=kg_config, + retrieval_config=retrieval_config, + generation_config=generation_config, + is_active=True, + created_at=now, + updated_at=now, ) with self._get_db() as conn: @@ -712,7 +712,7 @@ class AIManager: def get_kg_rag(self, rag_id: str) -> KnowledgeGraphRAG | None: """获取知识图谱 RAG 配置""" with self._get_db() as conn: - row = conn.execute("SELECT * FROM kg_rag_configs WHERE id = ?", (rag_id, )).fetchone() + row = conn.execute("SELECT * FROM kg_rag_configs WHERE id = ?", (rag_id, )).fetchone() if not row: return None @@ -720,11 +720,11 @@ class AIManager: return self._row_to_kg_rag(row) def list_kg_rags( - self, tenant_id: str, project_id: str | None = None + self, tenant_id: str, project_id: str | None = None ) -> list[KnowledgeGraphRAG]: """列出知识图谱 RAG 配置""" - query = "SELECT * FROM kg_rag_configs WHERE tenant_id = ?" - params = [tenant_id] + query = "SELECT * FROM kg_rag_configs WHERE tenant_id = ?" + params = [tenant_id] if project_id: query += " AND project_id = ?" @@ -733,30 +733,30 @@ class AIManager: query += " ORDER BY created_at DESC" with self._get_db() as conn: - rows = conn.execute(query, params).fetchall() + rows = conn.execute(query, params).fetchall() return [self._row_to_kg_rag(row) for row in rows] async def query_kg_rag( self, rag_id: str, query: str, project_entities: list[dict], project_relations: list[dict] ) -> RAGQuery: """基于知识图谱的 RAG 查询""" - start_time = time.time() + start_time = time.time() - rag = self.get_kg_rag(rag_id) + rag = self.get_kg_rag(rag_id) if not rag: raise ValueError(f"RAG config {rag_id} not found") # 1. 检索相关实体和关系 - retrieval_config = rag.retrieval_config - top_k = retrieval_config.get("top_k", 5) + retrieval_config = rag.retrieval_config + top_k = retrieval_config.get("top_k", 5) # 简单的语义检索(基于实体名称匹配) - query_lower = query.lower() - relevant_entities = [] + query_lower = query.lower() + relevant_entities = [] for entity in project_entities: - score = 0 - name = entity.get("name", "").lower() - definition = entity.get("definition", "").lower() + score = 0 + name = entity.get("name", "").lower() + definition = entity.get("definition", "").lower() if name in query_lower or any(word in name for word in query_lower.split()): score += 0.5 @@ -766,12 +766,12 @@ class AIManager: if score > 0: relevant_entities.append({**entity, "relevance_score": score}) - relevant_entities.sort(key = lambda x: x["relevance_score"], reverse = True) - relevant_entities = relevant_entities[:top_k] + relevant_entities.sort(key=lambda x: x["relevance_score"], reverse=True) + relevant_entities = relevant_entities[:top_k] # 检索相关关系 - relevant_relations = [] - entity_ids = {e["id"] for e in relevant_entities} + relevant_relations = [] + entity_ids = {e["id"] for e in relevant_entities} for relation in project_relations: if ( relation.get("source_entity_id") in entity_ids @@ -780,16 +780,16 @@ class AIManager: relevant_relations.append(relation) # 2. 构建上下文 - context = {"entities": relevant_entities, "relations": relevant_relations[:10]} + context = {"entities": relevant_entities, "relations": relevant_relations[:10]} - context_text = self._build_kg_context(relevant_entities, relevant_relations) + context_text = self._build_kg_context(relevant_entities, relevant_relations) # 3. 生成回答 - generation_config = rag.generation_config - temperature = generation_config.get("temperature", 0.3) - max_tokens = generation_config.get("max_tokens", 1000) + generation_config = rag.generation_config + temperature = generation_config.get("temperature", 0.3) + max_tokens = generation_config.get("max_tokens", 1000) - prompt = f"""基于以下知识图谱信息回答问题: + prompt = f"""基于以下知识图谱信息回答问题: ## 知识图谱上下文 {context_text} @@ -803,12 +803,12 @@ class AIManager: 2. 如果涉及多个实体,说明它们之间的关联 3. 保持简洁专业""" - headers = { + headers = { "Authorization": f"Bearer {self.kimi_api_key}", "Content-Type": "application/json", } - payload = { + payload = { "model": "k2p5", "messages": [{"role": "user", "content": prompt}], "temperature": temperature, @@ -816,44 +816,44 @@ class AIManager: } async with httpx.AsyncClient() as client: - response = await client.post( + response = await client.post( f"{self.kimi_base_url}/v1/chat/completions", - headers = headers, - json = payload, - timeout = 60.0, + headers=headers, + json=payload, + timeout=60.0, ) response.raise_for_status() - result = response.json() + result = response.json() - answer = result["choices"][0]["message"]["content"] - tokens_used = result["usage"]["total_tokens"] + answer = result["choices"][0]["message"]["content"] + tokens_used = result["usage"]["total_tokens"] - latency_ms = int((time.time() - start_time) * 1000) + latency_ms = int((time.time() - start_time) * 1000) # 4. 保存查询记录 - query_id = f"rq_{uuid.uuid4().hex[:16]}" - now = datetime.now().isoformat() + query_id = f"rq_{uuid.uuid4().hex[:16]}" + now = datetime.now().isoformat() - sources = [ + sources = [ {"entity_id": e["id"], "entity_name": e["name"], "score": e["relevance_score"]} for e in relevant_entities ] - rag_query = RAGQuery( - id = query_id, - rag_id = rag_id, - query = query, - context = context, - answer = answer, - sources = sources, - confidence = ( + rag_query = RAGQuery( + id=query_id, + rag_id=rag_id, + query=query, + context=context, + answer=answer, + sources=sources, + confidence=( sum(e["relevance_score"] for e in relevant_entities) / len(relevant_entities) if relevant_entities else 0 ), - tokens_used = tokens_used, - latency_ms = latency_ms, - created_at = now, + tokens_used=tokens_used, + latency_ms=latency_ms, + created_at=now, ) with self._get_db() as conn: @@ -883,23 +883,23 @@ class AIManager: def _build_kg_context(self, entities: list[dict], relations: list[dict]) -> str: """构建知识图谱上下文文本""" - context = [] + context = [] if entities: context.append("### 相关实体") for entity in entities: - name = entity.get("name", "") - entity_type = entity.get("type", "") - definition = entity.get("definition", "") + name = entity.get("name", "") + entity_type = entity.get("type", "") + definition = entity.get("definition", "") context.append(f"- **{name}** ({entity_type}): {definition}") if relations: context.append("\n### 相关关系") for relation in relations: - source = relation.get("source_name", "") - target = relation.get("target_name", "") - rel_type = relation.get("relation_type", "") - evidence = relation.get("evidence", "") + source = relation.get("source_name", "") + target = relation.get("target_name", "") + rel_type = relation.get("relation_type", "") + evidence = relation.get("evidence", "") context.append(f"- {source} --[{rel_type}]--> {target}") if evidence: context.append(f" - 依据: {evidence[:100]}...") @@ -916,12 +916,12 @@ class AIManager: content_data: dict, ) -> SmartSummary: """生成智能摘要""" - summary_id = f"ss_{uuid.uuid4().hex[:16]}" - now = datetime.now().isoformat() + summary_id = f"ss_{uuid.uuid4().hex[:16]}" + now = datetime.now().isoformat() # 根据摘要类型生成不同的提示 if summary_type == "extractive": - prompt = f"""从以下内容中提取关键句子作为摘要: + prompt = f"""从以下内容中提取关键句子作为摘要: {content_data.get("text", "")[:5000]} @@ -931,7 +931,7 @@ class AIManager: 3. 以 JSON 格式返回: {{"summary": "摘要内容", "key_points": ["要点1", "要点2"]}}""" elif summary_type == "abstractive": - prompt = f"""对以下内容生成简洁的摘要: + prompt = f"""对以下内容生成简洁的摘要: {content_data.get("text", "")[:5000]} @@ -941,7 +941,7 @@ class AIManager: 3. 包含关键实体和概念""" elif summary_type == "key_points": - prompt = f"""从以下内容中提取关键要点: + prompt = f"""从以下内容中提取关键要点: {content_data.get("text", "")[:5000]} @@ -951,7 +951,7 @@ class AIManager: 3. 以 JSON 格式返回: {{"key_points": ["要点1", "要点2", ...]}}""" else: # timeline - prompt = f"""基于以下内容生成时间线摘要: + prompt = f"""基于以下内容生成时间线摘要: {content_data.get("text", "")[:5000]} @@ -960,72 +960,72 @@ class AIManager: 2. 标注时间节点(如果有) 3. 突出里程碑事件""" - headers = { + headers = { "Authorization": f"Bearer {self.kimi_api_key}", "Content-Type": "application/json", } - payload = { + payload = { "model": "k2p5", "messages": [{"role": "user", "content": prompt}], "temperature": 0.3, } async with httpx.AsyncClient() as client: - response = await client.post( + response = await client.post( f"{self.kimi_base_url}/v1/chat/completions", - headers = headers, - json = payload, - timeout = 60.0, + headers=headers, + json=payload, + timeout=60.0, ) response.raise_for_status() - result = response.json() + result = response.json() - content = result["choices"][0]["message"]["content"] - tokens_used = result["usage"]["total_tokens"] + content = result["choices"][0]["message"]["content"] + tokens_used = result["usage"]["total_tokens"] # 解析关键要点 - key_points = [] + key_points = [] # 尝试从 JSON 中提取 - json_match = re.search(r"\{.*?\}", content, re.DOTALL) + json_match = re.search(r"\{.*?\}", content, re.DOTALL) if json_match: try: - data = json.loads(json_match.group()) - key_points = data.get("key_points", []) + data = json.loads(json_match.group()) + key_points = data.get("key_points", []) if "summary" in data: - content = data["summary"] + content = data["summary"] except (json.JSONDecodeError, ValueError): pass # 如果没有提取到关键要点,从文本中提取 if not key_points: - lines = content.split("\n") - key_points = [ + lines = content.split("\n") + key_points = [ line.strip("- ").strip() for line in lines if line.strip().startswith("-") or line.strip().startswith("•") ] if not key_points: - key_points = [content[:200] + "..."] if len(content) > 200 else [content] + key_points = [content[:200] + "..."] if len(content) > 200 else [content] # 提取提及的实体 - entities_mentioned = content_data.get("entities", []) - entity_names = [e.get("name", "") for e in entities_mentioned[:10]] + entities_mentioned = content_data.get("entities", []) + entity_names = [e.get("name", "") for e in entities_mentioned[:10]] - summary = SmartSummary( - id = summary_id, - tenant_id = tenant_id, - project_id = project_id, - source_type = source_type, - source_id = source_id, - summary_type = summary_type, - content = content, - key_points = key_points[:8], - entities_mentioned = entity_names, - confidence = 0.85, - tokens_used = tokens_used, - created_at = now, + summary = SmartSummary( + id=summary_id, + tenant_id=tenant_id, + project_id=project_id, + source_type=source_type, + source_id=source_id, + summary_type=summary_type, + content=content, + key_points=key_points[:8], + entities_mentioned=entity_names, + confidence=0.85, + tokens_used=tokens_used, + created_at=now, ) with self._get_db() as conn: @@ -1068,24 +1068,24 @@ class AIManager: model_config: dict, ) -> PredictionModel: """创建预测模型""" - model_id = f"pm_{uuid.uuid4().hex[:16]}" - now = datetime.now().isoformat() + model_id = f"pm_{uuid.uuid4().hex[:16]}" + now = datetime.now().isoformat() - model = PredictionModel( - id = model_id, - tenant_id = tenant_id, - project_id = project_id, - name = name, - prediction_type = prediction_type, - target_entity_type = target_entity_type, - features = features, - model_config = model_config, - accuracy = None, - last_trained_at = None, - prediction_count = 0, - is_active = True, - created_at = now, - updated_at = now, + model = PredictionModel( + id=model_id, + tenant_id=tenant_id, + project_id=project_id, + name=name, + prediction_type=prediction_type, + target_entity_type=target_entity_type, + features=features, + model_config=model_config, + accuracy=None, + last_trained_at=None, + prediction_count=0, + is_active=True, + created_at=now, + updated_at=now, ) with self._get_db() as conn: @@ -1121,7 +1121,7 @@ class AIManager: def get_prediction_model(self, model_id: str) -> PredictionModel | None: """获取预测模型""" with self._get_db() as conn: - row = conn.execute( + row = conn.execute( "SELECT * FROM prediction_models WHERE id = ?", (model_id, ) ).fetchone() @@ -1131,11 +1131,11 @@ class AIManager: return self._row_to_prediction_model(row) def list_prediction_models( - self, tenant_id: str, project_id: str | None = None + self, tenant_id: str, project_id: str | None = None ) -> list[PredictionModel]: """列出预测模型""" - query = "SELECT * FROM prediction_models WHERE tenant_id = ?" - params = [tenant_id] + query = "SELECT * FROM prediction_models WHERE tenant_id = ?" + params = [tenant_id] if project_id: query += " AND project_id = ?" @@ -1144,14 +1144,14 @@ class AIManager: query += " ORDER BY created_at DESC" with self._get_db() as conn: - rows = conn.execute(query, params).fetchall() + rows = conn.execute(query, params).fetchall() return [self._row_to_prediction_model(row) for row in rows] async def train_prediction_model( self, model_id: str, historical_data: list[dict] ) -> PredictionModel: """训练预测模型""" - model = self.get_prediction_model(model_id) + model = self.get_prediction_model(model_id) if not model: raise ValueError(f"Prediction model {model_id} not found") @@ -1159,9 +1159,9 @@ class AIManager: await asyncio.sleep(1) # 计算准确率(模拟) - accuracy = round(0.75 + random.random() * 0.2, 4) + accuracy = round(0.75 + random.random() * 0.2, 4) - now = datetime.now().isoformat() + now = datetime.now().isoformat() with self._get_db() as conn: conn.execute( @@ -1178,39 +1178,39 @@ class AIManager: async def predict(self, model_id: str, input_data: dict) -> PredictionResult: """进行预测""" - model = self.get_prediction_model(model_id) + model = self.get_prediction_model(model_id) if not model or not model.is_active: raise ValueError(f"Prediction model {model_id} not available") - prediction_id = f"pr_{uuid.uuid4().hex[:16]}" - now = datetime.now().isoformat() + prediction_id = f"pr_{uuid.uuid4().hex[:16]}" + now = datetime.now().isoformat() # 根据预测类型进行不同的预测逻辑 if model.prediction_type == PredictionType.TREND: - prediction_data = self._predict_trend(input_data, model) + prediction_data = self._predict_trend(input_data, model) elif model.prediction_type == PredictionType.ANOMALY: - prediction_data = self._detect_anomaly(input_data, model) + prediction_data = self._detect_anomaly(input_data, model) elif model.prediction_type == PredictionType.ENTITY_GROWTH: - prediction_data = self._predict_entity_growth(input_data, model) + prediction_data = self._predict_entity_growth(input_data, model) elif model.prediction_type == PredictionType.RELATION_EVOLUTION: - prediction_data = self._predict_relation_evolution(input_data, model) + prediction_data = self._predict_relation_evolution(input_data, model) else: - prediction_data = {"value": "unknown", "confidence": 0} + prediction_data = {"value": "unknown", "confidence": 0} - confidence = prediction_data.get("confidence", 0.8) - explanation = prediction_data.get("explanation", "基于历史数据模式预测") + confidence = prediction_data.get("confidence", 0.8) + explanation = prediction_data.get("explanation", "基于历史数据模式预测") - result = PredictionResult( - id = prediction_id, - model_id = model_id, - prediction_type = model.prediction_type, - target_id = input_data.get("target_id"), - prediction_data = prediction_data, - confidence = confidence, - explanation = explanation, - actual_value = None, - is_correct = None, - created_at = now, + result = PredictionResult( + id=prediction_id, + model_id=model_id, + prediction_type=model.prediction_type, + target_id=input_data.get("target_id"), + prediction_data=prediction_data, + confidence=confidence, + explanation=explanation, + actual_value=None, + is_correct=None, + created_at=now, ) with self._get_db() as conn: @@ -1246,7 +1246,7 @@ class AIManager: def _predict_trend(self, input_data: dict, model: PredictionModel) -> dict: """趋势预测""" - historical_values = input_data.get("historical_values", []) + historical_values = input_data.get("historical_values", []) if len(historical_values) < 2: return { @@ -1257,23 +1257,23 @@ class AIManager: } # 简单线性趋势预测 - 使用最小二乘法计算斜率 - n = len(historical_values) - x = list(range(n)) - y = historical_values + n = len(historical_values) + x = list(range(n)) + y = historical_values # 计算均值 - mean_x = sum(x) / n - mean_y = sum(y) / n + mean_x = sum(x) / n + mean_y = sum(y) / n # 计算斜率 (最小二乘法) - numerator = sum((x[i] - mean_x) * (y[i] - mean_y) for i in range(n)) - denominator = sum((x[i] - mean_x) ** 2 for i in range(n)) - slope = numerator / denominator if denominator != 0 else 0 + numerator = sum((x[i] - mean_x) * (y[i] - mean_y) for i in range(n)) + denominator = sum((x[i] - mean_x) ** 2 for i in range(n)) + slope = numerator / denominator if denominator != 0 else 0 # 预测下一个值 - next_value = y[-1] + slope + next_value = y[-1] + slope - trend = "increasing" if slope > 0.01 else "decreasing" if slope < -0.01 else "stable" + trend = "increasing" if slope > 0.01 else "decreasing" if slope < -0.01 else "stable" return { "predicted_value": round(next_value, 2), @@ -1285,8 +1285,8 @@ class AIManager: def _detect_anomaly(self, input_data: dict, model: PredictionModel) -> dict: """异常检测""" - value = input_data.get("value") - historical_values = input_data.get("historical_values", []) + value = input_data.get("value") + historical_values = input_data.get("historical_values", []) if not historical_values or value is None: return { @@ -1297,15 +1297,15 @@ class AIManager: } # 计算均值和标准差 - mean = statistics.mean(historical_values) - std = statistics.stdev(historical_values) if len(historical_values) > 1 else 0 + mean = statistics.mean(historical_values) + std = statistics.stdev(historical_values) if len(historical_values) > 1 else 0 if std == 0: - is_anomaly = value != mean - z_score = 0 if value == mean else 3 + is_anomaly = value != mean + z_score = 0 if value == mean else 3 else: - z_score = abs(value - mean) / std - is_anomaly = z_score > 2.5 # 2.5 个标准差视为异常 + z_score = abs(value - mean) / std + is_anomaly = z_score > 2.5 # 2.5 个标准差视为异常 return { "is_anomaly": is_anomaly, @@ -1319,7 +1319,7 @@ class AIManager: def _predict_entity_growth(self, input_data: dict, model: PredictionModel) -> dict: """实体增长预测""" - entity_history = input_data.get("entity_history", []) + entity_history = input_data.get("entity_history", []) if len(entity_history) < 3: return { @@ -1330,14 +1330,14 @@ class AIManager: } # 计算增长率 - counts = [h.get("count", 0) for h in entity_history] - growth_rates = [ + counts = [h.get("count", 0) for h in entity_history] + growth_rates = [ (counts[i] - counts[i - 1]) / max(counts[i - 1], 1) for i in range(1, len(counts)) ] - avg_growth_rate = statistics.mean(growth_rates) if growth_rates else 0 + avg_growth_rate = statistics.mean(growth_rates) if growth_rates else 0 # 预测下一个周期的实体数量 - predicted_count = counts[-1] * (1 + avg_growth_rate) + predicted_count = counts[-1] * (1 + avg_growth_rate) return { "predicted_count": round(predicted_count), @@ -1349,7 +1349,7 @@ class AIManager: def _predict_relation_evolution(self, input_data: dict, model: PredictionModel) -> dict: """关系演变预测""" - relation_history = input_data.get("relation_history", []) + relation_history = input_data.get("relation_history", []) if len(relation_history) < 2: return { @@ -1359,16 +1359,16 @@ class AIManager: } # 分析关系变化趋势 - relation_counts = defaultdict(int) + relation_counts = defaultdict(int) for snapshot in relation_history: for rel in snapshot.get("relations", []): relation_counts[rel.get("type", "unknown")] += 1 # 预测可能出现的新关系类型 - predicted_relations = [ + predicted_relations = [ {"type": rel_type, "likelihood": min(count / len(relation_history), 0.95)} for rel_type, count in sorted( - relation_counts.items(), key = lambda x: x[1], reverse = True + relation_counts.items(), key=lambda x: x[1], reverse=True )[:5] ] @@ -1379,10 +1379,10 @@ class AIManager: "explanation": f"基于{len(relation_history)}个历史快照分析关系演变趋势", } - def get_prediction_results(self, model_id: str, limit: int = 100) -> list[PredictionResult]: + def get_prediction_results(self, model_id: str, limit: int = 100) -> list[PredictionResult]: """获取预测结果历史""" with self._get_db() as conn: - rows = conn.execute( + rows = conn.execute( """SELECT * FROM prediction_results WHERE model_id = ? ORDER BY created_at DESC @@ -1410,106 +1410,106 @@ class AIManager: def _row_to_custom_model(self, row) -> CustomModel: """将数据库行转换为 CustomModel""" return CustomModel( - id = row["id"], - tenant_id = row["tenant_id"], - name = row["name"], - description = row["description"], - model_type = ModelType(row["model_type"]), - status = ModelStatus(row["status"]), - training_data = json.loads(row["training_data"]), - hyperparameters = json.loads(row["hyperparameters"]), - metrics = json.loads(row["metrics"]), - model_path = row["model_path"], - created_at = row["created_at"], - updated_at = row["updated_at"], - trained_at = row["trained_at"], - created_by = row["created_by"], + id=row["id"], + tenant_id=row["tenant_id"], + name=row["name"], + description=row["description"], + model_type=ModelType(row["model_type"]), + status=ModelStatus(row["status"]), + training_data=json.loads(row["training_data"]), + hyperparameters=json.loads(row["hyperparameters"]), + metrics=json.loads(row["metrics"]), + model_path=row["model_path"], + created_at=row["created_at"], + updated_at=row["updated_at"], + trained_at=row["trained_at"], + created_by=row["created_by"], ) def _row_to_training_sample(self, row) -> TrainingSample: """将数据库行转换为 TrainingSample""" return TrainingSample( - id = row["id"], - model_id = row["model_id"], - text = row["text"], - entities = json.loads(row["entities"]), - metadata = json.loads(row["metadata"]), - created_at = row["created_at"], + id=row["id"], + model_id=row["model_id"], + text=row["text"], + entities=json.loads(row["entities"]), + metadata=json.loads(row["metadata"]), + created_at=row["created_at"], ) def _row_to_multimodal_analysis(self, row) -> MultimodalAnalysis: """将数据库行转换为 MultimodalAnalysis""" return MultimodalAnalysis( - id = row["id"], - tenant_id = row["tenant_id"], - project_id = row["project_id"], - provider = MultimodalProvider(row["provider"]), - input_type = row["input_type"], - input_urls = json.loads(row["input_urls"]), - prompt = row["prompt"], - result = json.loads(row["result"]), - tokens_used = row["tokens_used"], - cost = row["cost"], - created_at = row["created_at"], + id=row["id"], + tenant_id=row["tenant_id"], + project_id=row["project_id"], + provider=MultimodalProvider(row["provider"]), + input_type=row["input_type"], + input_urls=json.loads(row["input_urls"]), + prompt=row["prompt"], + result=json.loads(row["result"]), + tokens_used=row["tokens_used"], + cost=row["cost"], + created_at=row["created_at"], ) def _row_to_kg_rag(self, row) -> KnowledgeGraphRAG: """将数据库行转换为 KnowledgeGraphRAG""" return KnowledgeGraphRAG( - id = row["id"], - tenant_id = row["tenant_id"], - project_id = row["project_id"], - name = row["name"], - description = row["description"], - kg_config = json.loads(row["kg_config"]), - retrieval_config = json.loads(row["retrieval_config"]), - generation_config = json.loads(row["generation_config"]), - is_active = bool(row["is_active"]), - created_at = row["created_at"], - updated_at = row["updated_at"], + id=row["id"], + tenant_id=row["tenant_id"], + project_id=row["project_id"], + name=row["name"], + description=row["description"], + kg_config=json.loads(row["kg_config"]), + retrieval_config=json.loads(row["retrieval_config"]), + generation_config=json.loads(row["generation_config"]), + is_active=bool(row["is_active"]), + created_at=row["created_at"], + updated_at=row["updated_at"], ) def _row_to_prediction_model(self, row) -> PredictionModel: """将数据库行转换为 PredictionModel""" return PredictionModel( - id = row["id"], - tenant_id = row["tenant_id"], - project_id = row["project_id"], - name = row["name"], - prediction_type = PredictionType(row["prediction_type"]), - target_entity_type = row["target_entity_type"], - features = json.loads(row["features"]), - model_config = json.loads(row["model_config"]), - accuracy = row["accuracy"], - last_trained_at = row["last_trained_at"], - prediction_count = row["prediction_count"], - is_active = bool(row["is_active"]), - created_at = row["created_at"], - updated_at = row["updated_at"], + id=row["id"], + tenant_id=row["tenant_id"], + project_id=row["project_id"], + name=row["name"], + prediction_type=PredictionType(row["prediction_type"]), + target_entity_type=row["target_entity_type"], + features=json.loads(row["features"]), + model_config=json.loads(row["model_config"]), + accuracy=row["accuracy"], + last_trained_at=row["last_trained_at"], + prediction_count=row["prediction_count"], + is_active=bool(row["is_active"]), + created_at=row["created_at"], + updated_at=row["updated_at"], ) def _row_to_prediction_result(self, row) -> PredictionResult: """将数据库行转换为 PredictionResult""" return PredictionResult( - id = row["id"], - model_id = row["model_id"], - prediction_type = PredictionType(row["prediction_type"]), - target_id = row["target_id"], - prediction_data = json.loads(row["prediction_data"]), - confidence = row["confidence"], - explanation = row["explanation"], - actual_value = row["actual_value"], - is_correct = row["is_correct"], - created_at = row["created_at"], + id=row["id"], + model_id=row["model_id"], + prediction_type=PredictionType(row["prediction_type"]), + target_id=row["target_id"], + prediction_data=json.loads(row["prediction_data"]), + confidence=row["confidence"], + explanation=row["explanation"], + actual_value=row["actual_value"], + is_correct=row["is_correct"], + created_at=row["created_at"], ) # Singleton instance -_ai_manager = None +_ai_manager = None def get_ai_manager() -> AIManager: global _ai_manager if _ai_manager is None: - _ai_manager = AIManager() + _ai_manager = AIManager() return _ai_manager diff --git a/backend/api_key_manager.py b/backend/api_key_manager.py index 6a81461..9e16478 100644 --- a/backend/api_key_manager.py +++ b/backend/api_key_manager.py @@ -13,13 +13,13 @@ from dataclasses import dataclass from datetime import datetime, timedelta from enum import Enum -DB_PATH = os.getenv("DB_PATH", "/app/data/insightflow.db") +DB_PATH = os.getenv("DB_PATH", "/app/data/insightflow.db") class ApiKeyStatus(Enum): - ACTIVE = "active" - REVOKED = "revoked" - EXPIRED = "expired" + ACTIVE = "active" + REVOKED = "revoked" + EXPIRED = "expired" @dataclass @@ -37,18 +37,18 @@ class ApiKey: last_used_at: str | None revoked_at: str | None revoked_reason: str | None - total_calls: int = 0 + total_calls: int = 0 class ApiKeyManager: """API Key 管理器""" # Key 前缀 - KEY_PREFIX = "ak_live_" - KEY_LENGTH = 48 # 总长度: 前缀(8) + 随机部分(40) + KEY_PREFIX = "ak_live_" + KEY_LENGTH = 48 # 总长度: 前缀(8) + 随机部分(40) - def __init__(self, db_path: str = DB_PATH) -> None: - self.db_path = db_path + def __init__(self, db_path: str = DB_PATH) -> None: + self.db_path = db_path self._init_db() def _init_db(self) -> None: @@ -117,7 +117,7 @@ class ApiKeyManager: def _generate_key(self) -> str: """生成新的 API Key""" # 生成 40 字符的随机字符串 - random_part = secrets.token_urlsafe(30)[:40] + random_part = secrets.token_urlsafe(30)[:40] return f"{self.KEY_PREFIX}{random_part}" def _hash_key(self, key: str) -> str: @@ -131,10 +131,10 @@ class ApiKeyManager: def create_key( self, name: str, - owner_id: str | None = None, - permissions: list[str] = None, - rate_limit: int = 60, - expires_days: int | None = None, + owner_id: str | None = None, + permissions: list[str] = None, + rate_limit: int = 60, + expires_days: int | None = None, ) -> tuple[str, ApiKey]: """ 创建新的 API Key @@ -143,32 +143,32 @@ class ApiKeyManager: tuple: (原始key(仅返回一次), ApiKey对象) """ if permissions is None: - permissions = ["read"] + permissions = ["read"] - key_id = secrets.token_hex(16) - raw_key = self._generate_key() - key_hash = self._hash_key(raw_key) - key_preview = self._get_preview(raw_key) + key_id = secrets.token_hex(16) + raw_key = self._generate_key() + key_hash = self._hash_key(raw_key) + key_preview = self._get_preview(raw_key) - expires_at = None + expires_at = None if expires_days: - expires_at = (datetime.now() + timedelta(days = expires_days)).isoformat() + expires_at = (datetime.now() + timedelta(days=expires_days)).isoformat() - api_key = ApiKey( - id = key_id, - key_hash = key_hash, - key_preview = key_preview, - name = name, - owner_id = owner_id, - permissions = permissions, - rate_limit = rate_limit, - status = ApiKeyStatus.ACTIVE.value, - created_at = datetime.now().isoformat(), - expires_at = expires_at, - last_used_at = None, - revoked_at = None, - revoked_reason = None, - total_calls = 0, + api_key = ApiKey( + id=key_id, + key_hash=key_hash, + key_preview=key_preview, + name=name, + owner_id=owner_id, + permissions=permissions, + rate_limit=rate_limit, + status=ApiKeyStatus.ACTIVE.value, + created_at=datetime.now().isoformat(), + expires_at=expires_at, + last_used_at=None, + revoked_at=None, + revoked_reason=None, + total_calls=0, ) with sqlite3.connect(self.db_path) as conn: @@ -203,16 +203,16 @@ class ApiKeyManager: Returns: ApiKey if valid, None otherwise """ - key_hash = self._hash_key(key) + key_hash = self._hash_key(key) with sqlite3.connect(self.db_path) as conn: - conn.row_factory = sqlite3.Row - row = conn.execute("SELECT * FROM api_keys WHERE key_hash = ?", (key_hash, )).fetchone() + conn.row_factory = sqlite3.Row + row = conn.execute("SELECT * FROM api_keys WHERE key_hash = ?", (key_hash, )).fetchone() if not row: return None - api_key = self._row_to_api_key(row) + api_key = self._row_to_api_key(row) # 检查状态 if api_key.status != ApiKeyStatus.ACTIVE.value: @@ -220,7 +220,7 @@ class ApiKeyManager: # 检查是否过期 if api_key.expires_at: - expires = datetime.fromisoformat(api_key.expires_at) + expires = datetime.fromisoformat(api_key.expires_at) if datetime.now() > expires: # 更新状态为过期 conn.execute( @@ -232,18 +232,18 @@ class ApiKeyManager: return api_key - def revoke_key(self, key_id: str, reason: str = "", owner_id: str | None = None) -> bool: + def revoke_key(self, key_id: str, reason: str = "", owner_id: str | None = None) -> bool: """撤销 API Key""" with sqlite3.connect(self.db_path) as conn: # 验证所有权(如果提供了 owner_id) if owner_id: - row = conn.execute( + row = conn.execute( "SELECT owner_id FROM api_keys WHERE id = ?", (key_id, ) ).fetchone() if not row or row[0] != owner_id: return False - cursor = conn.execute( + cursor = conn.execute( """ UPDATE api_keys SET status = ?, revoked_at = ?, revoked_reason = ? @@ -260,17 +260,17 @@ class ApiKeyManager: conn.commit() return cursor.rowcount > 0 - def get_key_by_id(self, key_id: str, owner_id: str | None = None) -> ApiKey | None: + def get_key_by_id(self, key_id: str, owner_id: str | None = None) -> ApiKey | None: """通过 ID 获取 API Key(不包含敏感信息)""" with sqlite3.connect(self.db_path) as conn: - conn.row_factory = sqlite3.Row + conn.row_factory = sqlite3.Row if owner_id: - row = conn.execute( + row = conn.execute( "SELECT * FROM api_keys WHERE id = ? AND owner_id = ?", (key_id, owner_id) ).fetchone() else: - row = conn.execute("SELECT * FROM api_keys WHERE id = ?", (key_id, )).fetchone() + row = conn.execute("SELECT * FROM api_keys WHERE id = ?", (key_id, )).fetchone() if row: return self._row_to_api_key(row) @@ -278,17 +278,17 @@ class ApiKeyManager: def list_keys( self, - owner_id: str | None = None, - status: str | None = None, - limit: int = 100, - offset: int = 0, + owner_id: str | None = None, + status: str | None = None, + limit: int = 100, + offset: int = 0, ) -> list[ApiKey]: """列出 API Keys""" with sqlite3.connect(self.db_path) as conn: - conn.row_factory = sqlite3.Row + conn.row_factory = sqlite3.Row - query = "SELECT * FROM api_keys WHERE 1 = 1" - params = [] + query = "SELECT * FROM api_keys WHERE 1 = 1" + params = [] if owner_id: query += " AND owner_id = ?" @@ -301,20 +301,20 @@ class ApiKeyManager: query += " ORDER BY created_at DESC LIMIT ? OFFSET ?" params.extend([limit, offset]) - rows = conn.execute(query, params).fetchall() + rows = conn.execute(query, params).fetchall() return [self._row_to_api_key(row) for row in rows] def update_key( self, key_id: str, - name: str | None = None, - permissions: list[str] | None = None, - rate_limit: int | None = None, - owner_id: str | None = None, + name: str | None = None, + permissions: list[str] | None = None, + rate_limit: int | None = None, + owner_id: str | None = None, ) -> bool: """更新 API Key 信息""" - updates = [] - params = [] + updates = [] + params = [] if name is not None: updates.append("name = ?") @@ -336,14 +336,14 @@ class ApiKeyManager: with sqlite3.connect(self.db_path) as conn: # 验证所有权 if owner_id: - row = conn.execute( + row = conn.execute( "SELECT owner_id FROM api_keys WHERE id = ?", (key_id, ) ).fetchone() if not row or row[0] != owner_id: return False - query = f"UPDATE api_keys SET {', '.join(updates)} WHERE id = ?" - cursor = conn.execute(query, params) + query = f"UPDATE api_keys SET {', '.join(updates)} WHERE id = ?" + cursor = conn.execute(query, params) conn.commit() return cursor.rowcount > 0 @@ -365,11 +365,11 @@ class ApiKeyManager: api_key_id: str, endpoint: str, method: str, - status_code: int = 200, - response_time_ms: int = 0, - ip_address: str = "", - user_agent: str = "", - error_message: str = "", + status_code: int = 200, + response_time_ms: int = 0, + ip_address: str = "", + user_agent: str = "", + error_message: str = "", ) -> None: """记录 API 调用日志""" with sqlite3.connect(self.db_path) as conn: @@ -395,18 +395,18 @@ class ApiKeyManager: def get_call_logs( self, - api_key_id: str | None = None, - start_date: str | None = None, - end_date: str | None = None, - limit: int = 100, - offset: int = 0, + api_key_id: str | None = None, + start_date: str | None = None, + end_date: str | None = None, + limit: int = 100, + offset: int = 0, ) -> list[dict]: """获取 API 调用日志""" with sqlite3.connect(self.db_path) as conn: - conn.row_factory = sqlite3.Row + conn.row_factory = sqlite3.Row - query = "SELECT * FROM api_call_logs WHERE 1 = 1" - params = [] + query = "SELECT * FROM api_call_logs WHERE 1 = 1" + params = [] if api_key_id: query += " AND api_key_id = ?" @@ -423,16 +423,16 @@ class ApiKeyManager: query += " ORDER BY created_at DESC LIMIT ? OFFSET ?" params.extend([limit, offset]) - rows = conn.execute(query, params).fetchall() + rows = conn.execute(query, params).fetchall() return [dict(row) for row in rows] - def get_call_stats(self, api_key_id: str | None = None, days: int = 30) -> dict: + def get_call_stats(self, api_key_id: str | None = None, days: int = 30) -> dict: """获取 API 调用统计""" with sqlite3.connect(self.db_path) as conn: - conn.row_factory = sqlite3.Row + conn.row_factory = sqlite3.Row # 总体统计 - query = f""" + query = f""" SELECT COUNT(*) as total_calls, COUNT(CASE WHEN status_code < 400 THEN 1 END) as success_calls, @@ -444,15 +444,15 @@ class ApiKeyManager: WHERE created_at >= date('now', '-{days} days') """ - params = [] + params = [] if api_key_id: - query = query.replace("WHERE created_at", "WHERE api_key_id = ? AND created_at") + query = query.replace("WHERE created_at", "WHERE api_key_id = ? AND created_at") params.insert(0, api_key_id) - row = conn.execute(query, params).fetchone() + row = conn.execute(query, params).fetchone() # 按端点统计 - endpoint_query = f""" + endpoint_query = f""" SELECT endpoint, method, @@ -462,19 +462,19 @@ class ApiKeyManager: WHERE created_at >= date('now', '-{days} days') """ - endpoint_params = [] + endpoint_params = [] if api_key_id: - endpoint_query = endpoint_query.replace( + endpoint_query = endpoint_query.replace( "WHERE created_at", "WHERE api_key_id = ? AND created_at" ) endpoint_params.insert(0, api_key_id) endpoint_query += " GROUP BY endpoint, method ORDER BY calls DESC" - endpoint_rows = conn.execute(endpoint_query, endpoint_params).fetchall() + endpoint_rows = conn.execute(endpoint_query, endpoint_params).fetchall() # 按天统计 - daily_query = f""" + daily_query = f""" SELECT date(created_at) as date, COUNT(*) as calls, @@ -483,16 +483,16 @@ class ApiKeyManager: WHERE created_at >= date('now', '-{days} days') """ - daily_params = [] + daily_params = [] if api_key_id: - daily_query = daily_query.replace( + daily_query = daily_query.replace( "WHERE created_at", "WHERE api_key_id = ? AND created_at" ) daily_params.insert(0, api_key_id) daily_query += " GROUP BY date(created_at) ORDER BY date" - daily_rows = conn.execute(daily_query, daily_params).fetchall() + daily_rows = conn.execute(daily_query, daily_params).fetchall() return { "summary": { @@ -510,30 +510,30 @@ class ApiKeyManager: def _row_to_api_key(self, row: sqlite3.Row) -> ApiKey: """将数据库行转换为 ApiKey 对象""" return ApiKey( - id = row["id"], - key_hash = row["key_hash"], - key_preview = row["key_preview"], - name = row["name"], - owner_id = row["owner_id"], - permissions = json.loads(row["permissions"]), - rate_limit = row["rate_limit"], - status = row["status"], - created_at = row["created_at"], - expires_at = row["expires_at"], - last_used_at = row["last_used_at"], - revoked_at = row["revoked_at"], - revoked_reason = row["revoked_reason"], - total_calls = row["total_calls"], + id=row["id"], + key_hash=row["key_hash"], + key_preview=row["key_preview"], + name=row["name"], + owner_id=row["owner_id"], + permissions=json.loads(row["permissions"]), + rate_limit=row["rate_limit"], + status=row["status"], + created_at=row["created_at"], + expires_at=row["expires_at"], + last_used_at=row["last_used_at"], + revoked_at=row["revoked_at"], + revoked_reason=row["revoked_reason"], + total_calls=row["total_calls"], ) # 全局实例 -_api_key_manager: ApiKeyManager | None = None +_api_key_manager: ApiKeyManager | None = None def get_api_key_manager() -> ApiKeyManager: """获取 API Key 管理器实例""" global _api_key_manager if _api_key_manager is None: - _api_key_manager = ApiKeyManager() + _api_key_manager = ApiKeyManager() return _api_key_manager diff --git a/backend/collaboration_manager.py b/backend/collaboration_manager.py index 0c33ad5..0c316ec 100644 --- a/backend/collaboration_manager.py +++ b/backend/collaboration_manager.py @@ -15,29 +15,29 @@ from typing import Any class SharePermission(Enum): """分享权限级别""" - READ_ONLY = "read_only" # 只读 - COMMENT = "comment" # 可评论 - EDIT = "edit" # 可编辑 - ADMIN = "admin" # 管理员 + READ_ONLY = "read_only" # 只读 + COMMENT = "comment" # 可评论 + EDIT = "edit" # 可编辑 + ADMIN = "admin" # 管理员 class CommentTargetType(Enum): """评论目标类型""" - ENTITY = "entity" # 实体评论 - RELATION = "relation" # 关系评论 - TRANSCRIPT = "transcript" # 转录文本评论 - PROJECT = "project" # 项目级评论 + ENTITY = "entity" # 实体评论 + RELATION = "relation" # 关系评论 + TRANSCRIPT = "transcript" # 转录文本评论 + PROJECT = "project" # 项目级评论 class ChangeType(Enum): """变更类型""" - CREATE = "create" # 创建 - UPDATE = "update" # 更新 - DELETE = "delete" # 删除 - MERGE = "merge" # 合并 - SPLIT = "split" # 拆分 + CREATE = "create" # 创建 + UPDATE = "update" # 更新 + DELETE = "delete" # 删除 + MERGE = "merge" # 合并 + SPLIT = "split" # 拆分 @dataclass @@ -136,10 +136,10 @@ class TeamSpace: class CollaborationManager: """协作管理主类""" - def __init__(self, db_manager = None) -> None: - self.db = db_manager - self._shares_cache: dict[str, ProjectShare] = {} - self._comments_cache: dict[str, list[Comment]] = {} + def __init__(self, db_manager=None) -> None: + self.db = db_manager + self._shares_cache: dict[str, ProjectShare] = {} + self._comments_cache: dict[str, list[Comment]] = {} # ============ 项目分享 ============ @@ -147,57 +147,57 @@ class CollaborationManager: self, project_id: str, created_by: str, - permission: str = "read_only", - expires_in_days: int | None = None, - max_uses: int | None = None, - password: str | None = None, - allow_download: bool = False, - allow_export: bool = False, + permission: str = "read_only", + expires_in_days: int | None = None, + max_uses: int | None = None, + password: str | None = None, + allow_download: bool = False, + allow_export: bool = False, ) -> ProjectShare: """创建项目分享链接""" - share_id = str(uuid.uuid4()) - token = self._generate_share_token(project_id) + share_id = str(uuid.uuid4()) + token = self._generate_share_token(project_id) - now = datetime.now().isoformat() - expires_at = None + now = datetime.now().isoformat() + expires_at = None if expires_in_days: - expires_at = (datetime.now() + timedelta(days = expires_in_days)).isoformat() + expires_at = (datetime.now() + timedelta(days=expires_in_days)).isoformat() - password_hash = None + password_hash = None if password: - password_hash = hashlib.sha256(password.encode()).hexdigest() + password_hash = hashlib.sha256(password.encode()).hexdigest() - share = ProjectShare( - id = share_id, - project_id = project_id, - token = token, - permission = permission, - created_by = created_by, - created_at = now, - expires_at = expires_at, - max_uses = max_uses, - use_count = 0, - password_hash = password_hash, - is_active = True, - allow_download = allow_download, - allow_export = allow_export, + share = ProjectShare( + id=share_id, + project_id=project_id, + token=token, + permission=permission, + created_by=created_by, + created_at=now, + expires_at=expires_at, + max_uses=max_uses, + use_count=0, + password_hash=password_hash, + is_active=True, + allow_download=allow_download, + allow_export=allow_export, ) # 保存到数据库 if self.db: self._save_share_to_db(share) - self._shares_cache[token] = share + self._shares_cache[token] = share return share def _generate_share_token(self, project_id: str) -> str: """生成分享令牌""" - data = f"{project_id}:{datetime.now().timestamp()}:{uuid.uuid4()}" + data = f"{project_id}:{datetime.now().timestamp()}:{uuid.uuid4()}" return hashlib.sha256(data.encode()).hexdigest()[:32] def _save_share_to_db(self, share: ProjectShare) -> None: """保存分享记录到数据库""" - cursor = self.db.conn.cursor() + cursor = self.db.conn.cursor() cursor.execute( """ INSERT INTO project_shares @@ -224,12 +224,12 @@ class CollaborationManager: ) self.db.conn.commit() - def validate_share_token(self, token: str, password: str | None = None) -> ProjectShare | None: + def validate_share_token(self, token: str, password: str | None = None) -> ProjectShare | None: """验证分享令牌""" # 从缓存或数据库获取 - share = self._shares_cache.get(token) + share = self._shares_cache.get(token) if not share and self.db: - share = self._get_share_from_db(token) + share = self._get_share_from_db(token) if not share: return None @@ -250,7 +250,7 @@ class CollaborationManager: if share.password_hash: if not password: return None - password_hash = hashlib.sha256(password.encode()).hexdigest() + password_hash = hashlib.sha256(password.encode()).hexdigest() if password_hash != share.password_hash: return None @@ -258,42 +258,42 @@ class CollaborationManager: def _get_share_from_db(self, token: str) -> ProjectShare | None: """从数据库获取分享记录""" - cursor = self.db.conn.cursor() + cursor = self.db.conn.cursor() cursor.execute( """ SELECT * FROM project_shares WHERE token = ? """, (token, ), ) - row = cursor.fetchone() + row = cursor.fetchone() if not row: return None return ProjectShare( - id = row[0], - project_id = row[1], - token = row[2], - permission = row[3], - created_by = row[4], - created_at = row[5], - expires_at = row[6], - max_uses = row[7], - use_count = row[8], - password_hash = row[9], - is_active = bool(row[10]), - allow_download = bool(row[11]), - allow_export = bool(row[12]), + id=row[0], + project_id=row[1], + token=row[2], + permission=row[3], + created_by=row[4], + created_at=row[5], + expires_at=row[6], + max_uses=row[7], + use_count=row[8], + password_hash=row[9], + is_active=bool(row[10]), + allow_download=bool(row[11]), + allow_export=bool(row[12]), ) def increment_share_usage(self, token: str) -> None: """增加分享链接使用次数""" - share = self._shares_cache.get(token) + share = self._shares_cache.get(token) if share: share.use_count += 1 if self.db: - cursor = self.db.conn.cursor() + cursor = self.db.conn.cursor() cursor.execute( """ UPDATE project_shares @@ -307,7 +307,7 @@ class CollaborationManager: def revoke_share_link(self, share_id: str, revoked_by: str) -> bool: """撤销分享链接""" if self.db: - cursor = self.db.conn.cursor() + cursor = self.db.conn.cursor() cursor.execute( """ UPDATE project_shares @@ -325,7 +325,7 @@ class CollaborationManager: if not self.db: return [] - cursor = self.db.conn.cursor() + cursor = self.db.conn.cursor() cursor.execute( """ SELECT * FROM project_shares @@ -335,23 +335,23 @@ class CollaborationManager: (project_id, ), ) - shares = [] + shares = [] for row in cursor.fetchall(): shares.append( ProjectShare( - id = row[0], - project_id = row[1], - token = row[2], - permission = row[3], - created_by = row[4], - created_at = row[5], - expires_at = row[6], - max_uses = row[7], - use_count = row[8], - password_hash = row[9], - is_active = bool(row[10]), - allow_download = bool(row[11]), - allow_export = bool(row[12]), + id=row[0], + project_id=row[1], + token=row[2], + permission=row[3], + created_by=row[4], + created_at=row[5], + expires_at=row[6], + max_uses=row[7], + use_count=row[8], + password_hash=row[9], + is_active=bool(row[10]), + allow_download=bool(row[11]), + allow_export=bool(row[12]), ) ) return shares @@ -366,46 +366,46 @@ class CollaborationManager: author: str, author_name: str, content: str, - parent_id: str | None = None, - mentions: list[str] | None = None, - attachments: list[dict] | None = None, + parent_id: str | None = None, + mentions: list[str] | None = None, + attachments: list[dict] | None = None, ) -> Comment: """添加评论""" - comment_id = str(uuid.uuid4()) - now = datetime.now().isoformat() + comment_id = str(uuid.uuid4()) + now = datetime.now().isoformat() - comment = Comment( - id = comment_id, - project_id = project_id, - target_type = target_type, - target_id = target_id, - parent_id = parent_id, - author = author, - author_name = author_name, - content = content, - created_at = now, - updated_at = now, - resolved = False, - resolved_by = None, - resolved_at = None, - mentions = mentions or [], - attachments = attachments or [], + comment = Comment( + id=comment_id, + project_id=project_id, + target_type=target_type, + target_id=target_id, + parent_id=parent_id, + author=author, + author_name=author_name, + content=content, + created_at=now, + updated_at=now, + resolved=False, + resolved_by=None, + resolved_at=None, + mentions=mentions or [], + attachments=attachments or [], ) if self.db: self._save_comment_to_db(comment) # 更新缓存 - key = f"{target_type}:{target_id}" + key = f"{target_type}:{target_id}" if key not in self._comments_cache: - self._comments_cache[key] = [] + self._comments_cache[key] = [] self._comments_cache[key].append(comment) return comment def _save_comment_to_db(self, comment: Comment) -> None: """保存评论到数据库""" - cursor = self.db.conn.cursor() + cursor = self.db.conn.cursor() cursor.execute( """ INSERT INTO comments @@ -435,13 +435,13 @@ class CollaborationManager: self.db.conn.commit() def get_comments( - self, target_type: str, target_id: str, include_resolved: bool = True + self, target_type: str, target_id: str, include_resolved: bool = True ) -> list[Comment]: """获取评论列表""" if not self.db: return [] - cursor = self.db.conn.cursor() + cursor = self.db.conn.cursor() if include_resolved: cursor.execute( """ @@ -461,7 +461,7 @@ class CollaborationManager: (target_type, target_id), ) - comments = [] + comments = [] for row in cursor.fetchall(): comments.append(self._row_to_comment(row)) return comments @@ -469,21 +469,21 @@ class CollaborationManager: def _row_to_comment(self, row) -> Comment: """将数据库行转换为Comment对象""" return Comment( - id = row[0], - project_id = row[1], - target_type = row[2], - target_id = row[3], - parent_id = row[4], - author = row[5], - author_name = row[6], - content = row[7], - created_at = row[8], - updated_at = row[9], - resolved = bool(row[10]), - resolved_by = row[11], - resolved_at = row[12], - mentions = json.loads(row[13]) if row[13] else [], - attachments = json.loads(row[14]) if row[14] else [], + id=row[0], + project_id=row[1], + target_type=row[2], + target_id=row[3], + parent_id=row[4], + author=row[5], + author_name=row[6], + content=row[7], + created_at=row[8], + updated_at=row[9], + resolved=bool(row[10]), + resolved_by=row[11], + resolved_at=row[12], + mentions=json.loads(row[13]) if row[13] else [], + attachments=json.loads(row[14]) if row[14] else [], ) def update_comment(self, comment_id: str, content: str, updated_by: str) -> Comment | None: @@ -491,8 +491,8 @@ class CollaborationManager: if not self.db: return None - now = datetime.now().isoformat() - cursor = self.db.conn.cursor() + now = datetime.now().isoformat() + cursor = self.db.conn.cursor() cursor.execute( """ UPDATE comments @@ -509,9 +509,9 @@ class CollaborationManager: def _get_comment_by_id(self, comment_id: str) -> Comment | None: """根据ID获取评论""" - cursor = self.db.conn.cursor() + cursor = self.db.conn.cursor() cursor.execute("SELECT * FROM comments WHERE id = ?", (comment_id, )) - row = cursor.fetchone() + row = cursor.fetchone() if row: return self._row_to_comment(row) return None @@ -521,8 +521,8 @@ class CollaborationManager: if not self.db: return False - now = datetime.now().isoformat() - cursor = self.db.conn.cursor() + now = datetime.now().isoformat() + cursor = self.db.conn.cursor() cursor.execute( """ UPDATE comments @@ -539,7 +539,7 @@ class CollaborationManager: if not self.db: return False - cursor = self.db.conn.cursor() + cursor = self.db.conn.cursor() # 只允许作者或管理员删除 cursor.execute( """ @@ -554,13 +554,13 @@ class CollaborationManager: return cursor.rowcount > 0 def get_project_comments( - self, project_id: str, limit: int = 50, offset: int = 0 + self, project_id: str, limit: int = 50, offset: int = 0 ) -> list[Comment]: """获取项目下的所有评论""" if not self.db: return [] - cursor = self.db.conn.cursor() + cursor = self.db.conn.cursor() cursor.execute( """ SELECT * FROM comments @@ -571,7 +571,7 @@ class CollaborationManager: (project_id, limit, offset), ) - comments = [] + comments = [] for row in cursor.fetchall(): comments.append(self._row_to_comment(row)) return comments @@ -587,32 +587,32 @@ class CollaborationManager: entity_name: str, changed_by: str, changed_by_name: str, - old_value: dict | None = None, - new_value: dict | None = None, - description: str = "", - session_id: str | None = None, + old_value: dict | None = None, + new_value: dict | None = None, + description: str = "", + session_id: str | None = None, ) -> ChangeRecord: """记录变更""" - record_id = str(uuid.uuid4()) - now = datetime.now().isoformat() + record_id = str(uuid.uuid4()) + now = datetime.now().isoformat() - record = ChangeRecord( - id = record_id, - project_id = project_id, - change_type = change_type, - entity_type = entity_type, - entity_id = entity_id, - entity_name = entity_name, - changed_by = changed_by, - changed_by_name = changed_by_name, - changed_at = now, - old_value = old_value, - new_value = new_value, - description = description, - session_id = session_id, - reverted = False, - reverted_at = None, - reverted_by = None, + record = ChangeRecord( + id=record_id, + project_id=project_id, + change_type=change_type, + entity_type=entity_type, + entity_id=entity_id, + entity_name=entity_name, + changed_by=changed_by, + changed_by_name=changed_by_name, + changed_at=now, + old_value=old_value, + new_value=new_value, + description=description, + session_id=session_id, + reverted=False, + reverted_at=None, + reverted_by=None, ) if self.db: @@ -622,7 +622,7 @@ class CollaborationManager: def _save_change_to_db(self, record: ChangeRecord) -> None: """保存变更记录到数据库""" - cursor = self.db.conn.cursor() + cursor = self.db.conn.cursor() cursor.execute( """ INSERT INTO change_history @@ -655,16 +655,16 @@ class CollaborationManager: def get_change_history( self, project_id: str, - entity_type: str | None = None, - entity_id: str | None = None, - limit: int = 50, - offset: int = 0, + entity_type: str | None = None, + entity_id: str | None = None, + limit: int = 50, + offset: int = 0, ) -> list[ChangeRecord]: """获取变更历史""" if not self.db: return [] - cursor = self.db.conn.cursor() + cursor = self.db.conn.cursor() if entity_type and entity_id: cursor.execute( @@ -697,7 +697,7 @@ class CollaborationManager: (project_id, limit, offset), ) - records = [] + records = [] for row in cursor.fetchall(): records.append(self._row_to_change_record(row)) return records @@ -705,22 +705,22 @@ class CollaborationManager: def _row_to_change_record(self, row) -> ChangeRecord: """将数据库行转换为ChangeRecord对象""" return ChangeRecord( - id = row[0], - project_id = row[1], - change_type = row[2], - entity_type = row[3], - entity_id = row[4], - entity_name = row[5], - changed_by = row[6], - changed_by_name = row[7], - changed_at = row[8], - old_value = json.loads(row[9]) if row[9] else None, - new_value = json.loads(row[10]) if row[10] else None, - description = row[11], - session_id = row[12], - reverted = bool(row[13]), - reverted_at = row[14], - reverted_by = row[15], + id=row[0], + project_id=row[1], + change_type=row[2], + entity_type=row[3], + entity_id=row[4], + entity_name=row[5], + changed_by=row[6], + changed_by_name=row[7], + changed_at=row[8], + old_value=json.loads(row[9]) if row[9] else None, + new_value=json.loads(row[10]) if row[10] else None, + description=row[11], + session_id=row[12], + reverted=bool(row[13]), + reverted_at=row[14], + reverted_by=row[15], ) def get_entity_version_history(self, entity_type: str, entity_id: str) -> list[ChangeRecord]: @@ -728,7 +728,7 @@ class CollaborationManager: if not self.db: return [] - cursor = self.db.conn.cursor() + cursor = self.db.conn.cursor() cursor.execute( """ SELECT * FROM change_history @@ -738,7 +738,7 @@ class CollaborationManager: (entity_type, entity_id), ) - records = [] + records = [] for row in cursor.fetchall(): records.append(self._row_to_change_record(row)) return records @@ -748,8 +748,8 @@ class CollaborationManager: if not self.db: return False - now = datetime.now().isoformat() - cursor = self.db.conn.cursor() + now = datetime.now().isoformat() + cursor = self.db.conn.cursor() cursor.execute( """ UPDATE change_history @@ -766,7 +766,7 @@ class CollaborationManager: if not self.db: return {} - cursor = self.db.conn.cursor() + cursor = self.db.conn.cursor() # 总变更数 cursor.execute( @@ -775,7 +775,7 @@ class CollaborationManager: """, (project_id, ), ) - total_changes = cursor.fetchone()[0] + total_changes = cursor.fetchone()[0] # 按类型统计 cursor.execute( @@ -785,7 +785,7 @@ class CollaborationManager: """, (project_id, ), ) - type_counts = {row[0]: row[1] for row in cursor.fetchall()} + type_counts = {row[0]: row[1] for row in cursor.fetchall()} # 按实体类型统计 cursor.execute( @@ -795,7 +795,7 @@ class CollaborationManager: """, (project_id, ), ) - entity_type_counts = {row[0]: row[1] for row in cursor.fetchall()} + entity_type_counts = {row[0]: row[1] for row in cursor.fetchall()} # 最近活跃的用户 cursor.execute( @@ -808,7 +808,7 @@ class CollaborationManager: """, (project_id, ), ) - top_contributors = [{"name": row[0], "changes": row[1]} for row in cursor.fetchall()] + top_contributors = [{"name": row[0], "changes": row[1]} for row in cursor.fetchall()] return { "total_changes": total_changes, @@ -827,27 +827,27 @@ class CollaborationManager: user_email: str, role: str, invited_by: str, - permissions: list[str] | None = None, + permissions: list[str] | None = None, ) -> TeamMember: """添加团队成员""" - member_id = str(uuid.uuid4()) - now = datetime.now().isoformat() + member_id = str(uuid.uuid4()) + now = datetime.now().isoformat() # 根据角色设置默认权限 if permissions is None: - permissions = self._get_default_permissions(role) + permissions = self._get_default_permissions(role) - member = TeamMember( - id = member_id, - project_id = project_id, - user_id = user_id, - user_name = user_name, - user_email = user_email, - role = role, - joined_at = now, - invited_by = invited_by, - last_active_at = None, - permissions = permissions, + member = TeamMember( + id=member_id, + project_id=project_id, + user_id=user_id, + user_name=user_name, + user_email=user_email, + role=role, + joined_at=now, + invited_by=invited_by, + last_active_at=None, + permissions=permissions, ) if self.db: @@ -857,7 +857,7 @@ class CollaborationManager: def _get_default_permissions(self, role: str) -> list[str]: """获取角色的默认权限""" - permissions_map = { + permissions_map = { "owner": ["read", "write", "delete", "share", "admin", "export"], "admin": ["read", "write", "delete", "share", "export"], "editor": ["read", "write", "export"], @@ -868,7 +868,7 @@ class CollaborationManager: def _save_member_to_db(self, member: TeamMember) -> None: """保存成员到数据库""" - cursor = self.db.conn.cursor() + cursor = self.db.conn.cursor() cursor.execute( """ INSERT INTO team_members @@ -896,7 +896,7 @@ class CollaborationManager: if not self.db: return [] - cursor = self.db.conn.cursor() + cursor = self.db.conn.cursor() cursor.execute( """ SELECT * FROM team_members WHERE project_id = ? @@ -905,7 +905,7 @@ class CollaborationManager: (project_id, ), ) - members = [] + members = [] for row in cursor.fetchall(): members.append(self._row_to_team_member(row)) return members @@ -913,16 +913,16 @@ class CollaborationManager: def _row_to_team_member(self, row) -> TeamMember: """将数据库行转换为TeamMember对象""" return TeamMember( - id = row[0], - project_id = row[1], - user_id = row[2], - user_name = row[3], - user_email = row[4], - role = row[5], - joined_at = row[6], - invited_by = row[7], - last_active_at = row[8], - permissions = json.loads(row[9]) if row[9] else [], + id=row[0], + project_id=row[1], + user_id=row[2], + user_name=row[3], + user_email=row[4], + role=row[5], + joined_at=row[6], + invited_by=row[7], + last_active_at=row[8], + permissions=json.loads(row[9]) if row[9] else [], ) def update_member_role(self, member_id: str, new_role: str, updated_by: str) -> bool: @@ -930,8 +930,8 @@ class CollaborationManager: if not self.db: return False - permissions = self._get_default_permissions(new_role) - cursor = self.db.conn.cursor() + permissions = self._get_default_permissions(new_role) + cursor = self.db.conn.cursor() cursor.execute( """ UPDATE team_members @@ -948,7 +948,7 @@ class CollaborationManager: if not self.db: return False - cursor = self.db.conn.cursor() + cursor = self.db.conn.cursor() cursor.execute("DELETE FROM team_members WHERE id = ?", (member_id, )) self.db.conn.commit() return cursor.rowcount > 0 @@ -958,7 +958,7 @@ class CollaborationManager: if not self.db: return False - cursor = self.db.conn.cursor() + cursor = self.db.conn.cursor() cursor.execute( """ SELECT permissions FROM team_members @@ -967,11 +967,11 @@ class CollaborationManager: (project_id, user_id), ) - row = cursor.fetchone() + row = cursor.fetchone() if not row: return False - permissions = json.loads(row[0]) if row[0] else [] + permissions = json.loads(row[0]) if row[0] else [] return permission in permissions or "admin" in permissions def update_last_active(self, project_id: str, user_id: str) -> None: @@ -979,8 +979,8 @@ class CollaborationManager: if not self.db: return - now = datetime.now().isoformat() - cursor = self.db.conn.cursor() + now = datetime.now().isoformat() + cursor = self.db.conn.cursor() cursor.execute( """ UPDATE team_members @@ -993,12 +993,12 @@ class CollaborationManager: # 全局协作管理器实例 -_collaboration_manager = None +_collaboration_manager = None -def get_collaboration_manager(db_manager = None) -> None: +def get_collaboration_manager(db_manager=None) -> None: """获取协作管理器单例""" global _collaboration_manager if _collaboration_manager is None: - _collaboration_manager = CollaborationManager(db_manager) + _collaboration_manager = CollaborationManager(db_manager) return _collaboration_manager diff --git a/backend/db_manager.py b/backend/db_manager.py index 37c1e0f..ab47b79 100644 --- a/backend/db_manager.py +++ b/backend/db_manager.py @@ -12,19 +12,19 @@ import uuid from dataclasses import dataclass from datetime import datetime -DB_PATH = os.getenv("DB_PATH", "/app/data/insightflow.db") +DB_PATH = os.getenv("DB_PATH", "/app/data/insightflow.db") # Constants -UUID_LENGTH = 8 # UUID 截断长度 +UUID_LENGTH = 8 # UUID 截断长度 @dataclass class Project: id: str name: str - description: str = "" - created_at: str = "" - updated_at: str = "" + description: str = "" + created_at: str = "" + updated_at: str = "" @dataclass @@ -33,19 +33,19 @@ class Entity: project_id: str name: str type: str - definition: str = "" - canonical_name: str = "" - aliases: list[str] = None - embedding: str = "" # Phase 3: 实体嵌入向量 - attributes: dict = None # Phase 5: 实体属性 - created_at: str = "" - updated_at: str = "" + definition: str = "" + canonical_name: str = "" + aliases: list[str] = None + embedding: str = "" # Phase 3: 实体嵌入向量 + attributes: dict = None # Phase 5: 实体属性 + created_at: str = "" + updated_at: str = "" def __post_init__(self) -> None: if self.aliases is None: - self.aliases = [] + self.aliases = [] if self.attributes is None: - self.attributes = {} + self.attributes = {} @dataclass @@ -56,17 +56,17 @@ class AttributeTemplate: project_id: str name: str type: str # text, number, date, select, multiselect, boolean - options: list[str] = None # 用于 select/multiselect - default_value: str = "" - description: str = "" - is_required: bool = False - sort_order: int = 0 - created_at: str = "" - updated_at: str = "" + options: list[str] = None # 用于 select/multiselect + default_value: str = "" + description: str = "" + is_required: bool = False + sort_order: int = 0 + created_at: str = "" + updated_at: str = "" def __post_init__(self) -> None: if self.options is None: - self.options = [] + self.options = [] @dataclass @@ -75,19 +75,19 @@ class EntityAttribute: id: str entity_id: str - template_id: str | None = None - name: str = "" # 属性名称 - type: str = "text" # 属性类型 - value: str = "" - options: list[str] = None # 选项列表 - template_name: str = "" # 关联查询时填充 - template_type: str = "" # 关联查询时填充 - created_at: str = "" - updated_at: str = "" + template_id: str | None = None + name: str = "" # 属性名称 + type: str = "text" # 属性类型 + value: str = "" + options: list[str] = None # 选项列表 + template_name: str = "" # 关联查询时填充 + template_type: str = "" # 关联查询时填充 + created_at: str = "" + updated_at: str = "" def __post_init__(self) -> None: if self.options is None: - self.options = [] + self.options = [] @dataclass @@ -96,12 +96,12 @@ class AttributeHistory: id: str entity_id: str - attribute_name: str = "" # 属性名称 - old_value: str = "" - new_value: str = "" - changed_by: str = "" - changed_at: str = "" - change_reason: str = "" + attribute_name: str = "" # 属性名称 + old_value: str = "" + new_value: str = "" + changed_by: str = "" + changed_at: str = "" + change_reason: str = "" @dataclass @@ -112,35 +112,35 @@ class EntityMention: start_pos: int end_pos: int text_snippet: str - confidence: float = 1.0 + confidence: float = 1.0 class DatabaseManager: - def __init__(self, db_path: str = DB_PATH) -> None: - self.db_path = db_path - os.makedirs(os.path.dirname(db_path), exist_ok = True) + def __init__(self, db_path: str = DB_PATH) -> None: + self.db_path = db_path + os.makedirs(os.path.dirname(db_path), exist_ok=True) self.init_db() def get_conn(self) -> None: - conn = sqlite3.connect(self.db_path) - conn.row_factory = sqlite3.Row + conn = sqlite3.connect(self.db_path) + conn.row_factory = sqlite3.Row return conn def init_db(self) -> None: """初始化数据库表""" with open(os.path.join(os.path.dirname(__file__), "schema.sql")) as f: - schema = f.read() + schema = f.read() - conn = self.get_conn() + conn = self.get_conn() conn.executescript(schema) conn.commit() conn.close() # ==================== Project Operations ==================== - def create_project(self, project_id: str, name: str, description: str = "") -> Project: - conn = self.get_conn() - now = datetime.now().isoformat() + def create_project(self, project_id: str, name: str, description: str = "") -> Project: + conn = self.get_conn() + now = datetime.now().isoformat() conn.execute( """INSERT INTO projects (id, name, description, created_at, updated_at) VALUES (?, ?, ?, ?, ?)""", @@ -149,27 +149,27 @@ class DatabaseManager: conn.commit() conn.close() return Project( - id = project_id, name = name, description = description, created_at = now, updated_at = now + id=project_id, name=name, description=description, created_at=now, updated_at=now ) def get_project(self, project_id: str) -> Project | None: - conn = self.get_conn() - row = conn.execute("SELECT * FROM projects WHERE id = ?", (project_id, )).fetchone() + conn = self.get_conn() + row = conn.execute("SELECT * FROM projects WHERE id = ?", (project_id, )).fetchone() conn.close() if row: return Project(**dict(row)) return None def list_projects(self) -> list[Project]: - conn = self.get_conn() - rows = conn.execute("SELECT * FROM projects ORDER BY updated_at DESC").fetchall() + conn = self.get_conn() + rows = conn.execute("SELECT * FROM projects ORDER BY updated_at DESC").fetchall() conn.close() return [Project(**dict(r)) for r in rows] # ==================== Entity Operations ==================== def create_entity(self, entity: Entity) -> Entity: - conn = self.get_conn() + conn = self.get_conn() conn.execute( """INSERT INTO entities (id, project_id, name, canonical_name, type, definition, aliases, created_at, updated_at) @@ -192,48 +192,48 @@ class DatabaseManager: def get_entity_by_name(self, project_id: str, name: str) -> Entity | None: """通过名称查找实体(用于对齐)""" - conn = self.get_conn() - row = conn.execute( + conn = self.get_conn() + row = conn.execute( """SELECT * FROM entities WHERE project_id = ? AND (name = ? OR canonical_name = ? OR aliases LIKE ?)""", (project_id, name, name, f'%"{name}"%'), ).fetchone() conn.close() if row: - data = dict(row) - data["aliases"] = json.loads(data["aliases"]) if data["aliases"] else [] + data = dict(row) + data["aliases"] = json.loads(data["aliases"]) if data["aliases"] else [] return Entity(**data) return None def find_similar_entities( - self, project_id: str, name: str, threshold: float = 0.8 + self, project_id: str, name: str, threshold: float = 0.8 ) -> list[Entity]: """查找相似实体""" - conn = self.get_conn() - rows = conn.execute( + conn = self.get_conn() + rows = conn.execute( "SELECT * FROM entities WHERE project_id = ? AND name LIKE ?", (project_id, f"%{name}%") ).fetchall() conn.close() - entities = [] + entities = [] for row in rows: - data = dict(row) - data["aliases"] = json.loads(data["aliases"]) if data["aliases"] else [] + data = dict(row) + data["aliases"] = json.loads(data["aliases"]) if data["aliases"] else [] entities.append(Entity(**data)) return entities def merge_entities(self, target_id: str, source_id: str) -> Entity: """合并两个实体""" - conn = self.get_conn() + conn = self.get_conn() - target = conn.execute("SELECT * FROM entities WHERE id = ?", (target_id, )).fetchone() - source = conn.execute("SELECT * FROM entities WHERE id = ?", (source_id, )).fetchone() + target = conn.execute("SELECT * FROM entities WHERE id = ?", (target_id, )).fetchone() + source = conn.execute("SELECT * FROM entities WHERE id = ?", (source_id, )).fetchone() if not target or not source: conn.close() raise ValueError("Entity not found") - target_aliases = set(json.loads(target["aliases"]) if target["aliases"] else []) + target_aliases = set(json.loads(target["aliases"]) if target["aliases"] else []) target_aliases.add(source["name"]) target_aliases.update(json.loads(source["aliases"]) if source["aliases"] else []) @@ -259,36 +259,36 @@ class DatabaseManager: return self.get_entity(target_id) def get_entity(self, entity_id: str) -> Entity | None: - conn = self.get_conn() - row = conn.execute("SELECT * FROM entities WHERE id = ?", (entity_id, )).fetchone() + conn = self.get_conn() + row = conn.execute("SELECT * FROM entities WHERE id = ?", (entity_id, )).fetchone() conn.close() if row: - data = dict(row) - data["aliases"] = json.loads(data["aliases"]) if data["aliases"] else [] + data = dict(row) + data["aliases"] = json.loads(data["aliases"]) if data["aliases"] else [] return Entity(**data) return None def list_project_entities(self, project_id: str) -> list[Entity]: - conn = self.get_conn() - rows = conn.execute( + conn = self.get_conn() + rows = conn.execute( "SELECT * FROM entities WHERE project_id = ? ORDER BY updated_at DESC", (project_id, ) ).fetchall() conn.close() - entities = [] + entities = [] for row in rows: - data = dict(row) - data["aliases"] = json.loads(data["aliases"]) if data["aliases"] else [] + data = dict(row) + data["aliases"] = json.loads(data["aliases"]) if data["aliases"] else [] entities.append(Entity(**data)) return entities def update_entity(self, entity_id: str, **kwargs) -> Entity: """更新实体信息""" - conn = self.get_conn() + conn = self.get_conn() - allowed_fields = ["name", "type", "definition", "canonical_name"] - updates = [] - values = [] + allowed_fields = ["name", "type", "definition", "canonical_name"] + updates = [] + values = [] for field in allowed_fields: if field in kwargs: @@ -307,7 +307,7 @@ class DatabaseManager: values.append(datetime.now().isoformat()) values.append(entity_id) - query = f"UPDATE entities SET {', '.join(updates)} WHERE id = ?" + query = f"UPDATE entities SET {', '.join(updates)} WHERE id = ?" conn.execute(query, values) conn.commit() conn.close() @@ -315,7 +315,7 @@ class DatabaseManager: def delete_entity(self, entity_id: str) -> None: """删除实体及其关联数据""" - conn = self.get_conn() + conn = self.get_conn() conn.execute("DELETE FROM entity_mentions WHERE entity_id = ?", (entity_id, )) conn.execute( "DELETE FROM entity_relations WHERE source_entity_id = ? OR target_entity_id = ?", @@ -329,7 +329,7 @@ class DatabaseManager: # ==================== Mention Operations ==================== def add_mention(self, mention: EntityMention) -> EntityMention: - conn = self.get_conn() + conn = self.get_conn() conn.execute( """INSERT INTO entity_mentions (id, entity_id, transcript_id, start_pos, end_pos, text_snippet, confidence) @@ -349,8 +349,8 @@ class DatabaseManager: return mention def get_entity_mentions(self, entity_id: str) -> list[EntityMention]: - conn = self.get_conn() - rows = conn.execute( + conn = self.get_conn() + rows = conn.execute( "SELECT * FROM entity_mentions WHERE entity_id = ? ORDER BY transcript_id, start_pos", (entity_id, ), ).fetchall() @@ -365,10 +365,10 @@ class DatabaseManager: project_id: str, filename: str, full_text: str, - transcript_type: str = "audio", + transcript_type: str = "audio", ) -> None: - conn = self.get_conn() - now = datetime.now().isoformat() + conn = self.get_conn() + now = datetime.now().isoformat() conn.execute( """INSERT INTO transcripts (id, project_id, filename, full_text, type, created_at) @@ -379,28 +379,28 @@ class DatabaseManager: conn.close() def get_transcript(self, transcript_id: str) -> dict | None: - conn = self.get_conn() - row = conn.execute("SELECT * FROM transcripts WHERE id = ?", (transcript_id, )).fetchone() + conn = self.get_conn() + row = conn.execute("SELECT * FROM transcripts WHERE id = ?", (transcript_id, )).fetchone() conn.close() return dict(row) if row else None def list_project_transcripts(self, project_id: str) -> list[dict]: - conn = self.get_conn() - rows = conn.execute( + conn = self.get_conn() + rows = conn.execute( "SELECT * FROM transcripts WHERE project_id = ? ORDER BY created_at DESC", (project_id, ) ).fetchall() conn.close() return [dict(r) for r in rows] def update_transcript(self, transcript_id: str, full_text: str) -> dict: - conn = self.get_conn() - now = datetime.now().isoformat() + conn = self.get_conn() + now = datetime.now().isoformat() conn.execute( "UPDATE transcripts SET full_text = ?, updated_at = ? WHERE id = ?", (full_text, now, transcript_id), ) conn.commit() - row = conn.execute("SELECT * FROM transcripts WHERE id = ?", (transcript_id, )).fetchone() + row = conn.execute("SELECT * FROM transcripts WHERE id = ?", (transcript_id, )).fetchone() conn.close() return dict(row) if row else None @@ -411,13 +411,13 @@ class DatabaseManager: project_id: str, source_entity_id: str, target_entity_id: str, - relation_type: str = "related", - evidence: str = "", - transcript_id: str = "", + relation_type: str = "related", + evidence: str = "", + transcript_id: str = "", ) -> None: - conn = self.get_conn() - relation_id = str(uuid.uuid4())[:UUID_LENGTH] - now = datetime.now().isoformat() + conn = self.get_conn() + relation_id = str(uuid.uuid4())[:UUID_LENGTH] + now = datetime.now().isoformat() conn.execute( """INSERT INTO entity_relations (id, project_id, source_entity_id, target_entity_id, relation_type, @@ -439,8 +439,8 @@ class DatabaseManager: return relation_id def get_entity_relations(self, entity_id: str) -> list[dict]: - conn = self.get_conn() - rows = conn.execute( + conn = self.get_conn() + rows = conn.execute( """SELECT * FROM entity_relations WHERE source_entity_id = ? OR target_entity_id = ? ORDER BY created_at DESC""", @@ -450,8 +450,8 @@ class DatabaseManager: return [dict(r) for r in rows] def list_project_relations(self, project_id: str) -> list[dict]: - conn = self.get_conn() - rows = conn.execute( + conn = self.get_conn() + rows = conn.execute( "SELECT * FROM entity_relations WHERE project_id = ? ORDER BY created_at DESC", (project_id, ), ).fetchall() @@ -459,10 +459,10 @@ class DatabaseManager: return [dict(r) for r in rows] def update_relation(self, relation_id: str, **kwargs) -> dict: - conn = self.get_conn() - allowed_fields = ["relation_type", "evidence"] - updates = [] - values = [] + conn = self.get_conn() + allowed_fields = ["relation_type", "evidence"] + updates = [] + values = [] for field in allowed_fields: if field in kwargs: @@ -470,26 +470,26 @@ class DatabaseManager: values.append(kwargs[field]) if updates: - query = f"UPDATE entity_relations SET {', '.join(updates)} WHERE id = ?" + query = f"UPDATE entity_relations SET {', '.join(updates)} WHERE id = ?" values.append(relation_id) conn.execute(query, values) conn.commit() - row = conn.execute("SELECT * FROM entity_relations WHERE id = ?", (relation_id, )).fetchone() + row = conn.execute("SELECT * FROM entity_relations WHERE id = ?", (relation_id, )).fetchone() conn.close() return dict(row) if row else None def delete_relation(self, relation_id: str) -> None: - conn = self.get_conn() + conn = self.get_conn() conn.execute("DELETE FROM entity_relations WHERE id = ?", (relation_id, )) conn.commit() conn.close() # ==================== Glossary Operations ==================== - def add_glossary_term(self, project_id: str, term: str, pronunciation: str = "") -> str: - conn = self.get_conn() - existing = conn.execute( + def add_glossary_term(self, project_id: str, term: str, pronunciation: str = "") -> str: + conn = self.get_conn() + existing = conn.execute( "SELECT * FROM glossary WHERE project_id = ? AND term = ?", (project_id, term) ).fetchone() @@ -501,7 +501,7 @@ class DatabaseManager: conn.close() return existing["id"] - term_id = str(uuid.uuid4())[:UUID_LENGTH] + term_id = str(uuid.uuid4())[:UUID_LENGTH] conn.execute( """INSERT INTO glossary (id, project_id, term, pronunciation, frequency) @@ -513,15 +513,15 @@ class DatabaseManager: return term_id def list_glossary(self, project_id: str) -> list[dict]: - conn = self.get_conn() - rows = conn.execute( + conn = self.get_conn() + rows = conn.execute( "SELECT * FROM glossary WHERE project_id = ? ORDER BY frequency DESC", (project_id, ) ).fetchall() conn.close() return [dict(r) for r in rows] def delete_glossary_term(self, term_id: str) -> None: - conn = self.get_conn() + conn = self.get_conn() conn.execute("DELETE FROM glossary WHERE id = ?", (term_id, )) conn.commit() conn.close() @@ -529,8 +529,8 @@ class DatabaseManager: # ==================== Phase 4: Agent & Provenance ==================== def get_relation_with_details(self, relation_id: str) -> dict | None: - conn = self.get_conn() - row = conn.execute( + conn = self.get_conn() + row = conn.execute( """SELECT r.*, s.name as source_name, t.name as target_name, tr.filename as transcript_filename, tr.full_text as transcript_text @@ -545,26 +545,26 @@ class DatabaseManager: return dict(row) if row else None def get_entity_with_mentions(self, entity_id: str) -> dict | None: - conn = self.get_conn() - entity_row = conn.execute("SELECT * FROM entities WHERE id = ?", (entity_id, )).fetchone() + conn = self.get_conn() + entity_row = conn.execute("SELECT * FROM entities WHERE id = ?", (entity_id, )).fetchone() if not entity_row: conn.close() return None - entity = dict(entity_row) - entity["aliases"] = json.loads(entity["aliases"]) if entity["aliases"] else [] + entity = dict(entity_row) + entity["aliases"] = json.loads(entity["aliases"]) if entity["aliases"] else [] - mentions = conn.execute( + mentions = conn.execute( """SELECT m.*, t.filename, t.created_at as transcript_date FROM entity_mentions m JOIN transcripts t ON m.transcript_id = t.id WHERE m.entity_id = ? ORDER BY t.created_at, m.start_pos""", (entity_id, ), ).fetchall() - entity["mentions"] = [dict(m) for m in mentions] - entity["mention_count"] = len(mentions) + entity["mentions"] = [dict(m) for m in mentions] + entity["mention_count"] = len(mentions) - relations = conn.execute( + relations = conn.execute( """SELECT r.*, s.name as source_name, t.name as target_name FROM entity_relations r JOIN entities s ON r.source_entity_id = s.id @@ -573,14 +573,14 @@ class DatabaseManager: ORDER BY r.created_at DESC""", (entity_id, entity_id), ).fetchall() - entity["relations"] = [dict(r) for r in relations] + entity["relations"] = [dict(r) for r in relations] conn.close() return entity def search_entities(self, project_id: str, query: str) -> list[Entity]: - conn = self.get_conn() - rows = conn.execute( + conn = self.get_conn() + rows = conn.execute( """SELECT * FROM entities WHERE project_id = ? AND (name LIKE ? OR definition LIKE ? OR aliases LIKE ?) @@ -589,36 +589,36 @@ class DatabaseManager: ).fetchall() conn.close() - entities = [] + entities = [] for row in rows: - data = dict(row) - data["aliases"] = json.loads(data["aliases"]) if data["aliases"] else [] + data = dict(row) + data["aliases"] = json.loads(data["aliases"]) if data["aliases"] else [] entities.append(Entity(**data)) return entities def get_project_summary(self, project_id: str) -> dict: - conn = self.get_conn() - project = conn.execute("SELECT * FROM projects WHERE id = ?", (project_id, )).fetchone() + conn = self.get_conn() + project = conn.execute("SELECT * FROM projects WHERE id = ?", (project_id, )).fetchone() - entity_count = conn.execute( + entity_count = conn.execute( "SELECT COUNT(*) as count FROM entities WHERE project_id = ?", (project_id, ) ).fetchone()["count"] - transcript_count = conn.execute( + transcript_count = conn.execute( "SELECT COUNT(*) as count FROM transcripts WHERE project_id = ?", (project_id, ) ).fetchone()["count"] - relation_count = conn.execute( + relation_count = conn.execute( "SELECT COUNT(*) as count FROM entity_relations WHERE project_id = ?", (project_id, ) ).fetchone()["count"] - recent_transcripts = conn.execute( + recent_transcripts = conn.execute( """SELECT filename, full_text, created_at FROM transcripts WHERE project_id = ? ORDER BY created_at DESC LIMIT 5""", (project_id, ), ).fetchall() - top_entities = conn.execute( + top_entities = conn.execute( """SELECT e.name, e.type, e.definition, COUNT(m.id) as mention_count FROM entities e LEFT JOIN entity_mentions m ON e.id = m.entity_id @@ -641,29 +641,29 @@ class DatabaseManager: } def get_transcript_context( - self, transcript_id: str, position: int, context_chars: int = 200 + self, transcript_id: str, position: int, context_chars: int = 200 ) -> str: - conn = self.get_conn() - row = conn.execute( + conn = self.get_conn() + row = conn.execute( "SELECT full_text FROM transcripts WHERE id = ?", (transcript_id, ) ).fetchone() conn.close() if not row: return "" - text = row["full_text"] - start = max(0, position - context_chars) - end = min(len(text), position + context_chars) + text = row["full_text"] + start = max(0, position - context_chars) + end = min(len(text), position + context_chars) return text[start:end] # ==================== Phase 5: Timeline Operations ==================== def get_project_timeline( - self, project_id: str, entity_id: str = None, start_date: str = None, end_date: str = None + self, project_id: str, entity_id: str = None, start_date: str = None, end_date: str = None ) -> list[dict]: - conn = self.get_conn() + conn = self.get_conn() - conditions = ["t.project_id = ?"] - params = [project_id] + conditions = ["t.project_id = ?"] + params = [project_id] if entity_id: conditions.append("m.entity_id = ?") @@ -675,9 +675,9 @@ class DatabaseManager: conditions.append("t.created_at <= ?") params.append(end_date) - where_clause = " AND ".join(conditions) + where_clause = " AND ".join(conditions) - mentions = conn.execute( + mentions = conn.execute( f"""SELECT m.*, e.name as entity_name, e.type as entity_type, e.definition, t.filename, t.created_at as event_date, t.type as source_type FROM entity_mentions m @@ -687,7 +687,7 @@ class DatabaseManager: params, ).fetchall() - timeline_events = [] + timeline_events = [] for m in mentions: timeline_events.append( { @@ -708,13 +708,13 @@ class DatabaseManager: ) conn.close() - timeline_events.sort(key = lambda x: x["event_date"]) + timeline_events.sort(key=lambda x: x["event_date"]) return timeline_events def get_entity_timeline_summary(self, project_id: str) -> dict: - conn = self.get_conn() + conn = self.get_conn() - daily_stats = conn.execute( + daily_stats = conn.execute( """SELECT DATE(t.created_at) as date, COUNT(*) as count FROM entity_mentions m JOIN transcripts t ON m.transcript_id = t.id @@ -722,7 +722,7 @@ class DatabaseManager: (project_id, ), ).fetchall() - entity_stats = conn.execute( + entity_stats = conn.execute( """SELECT e.name, e.type, COUNT(m.id) as mention_count, MIN(t.created_at) as first_mentioned, MAX(t.created_at) as last_mentioned @@ -744,8 +744,8 @@ class DatabaseManager: # ==================== Phase 5: Entity Attributes ==================== def create_attribute_template(self, template: AttributeTemplate) -> AttributeTemplate: - conn = self.get_conn() - now = datetime.now().isoformat() + conn = self.get_conn() + now = datetime.now().isoformat() conn.execute( """INSERT INTO attribute_templates (id, project_id, name, type, options, default_value, description, @@ -770,36 +770,36 @@ class DatabaseManager: return template def get_attribute_template(self, template_id: str) -> AttributeTemplate | None: - conn = self.get_conn() - row = conn.execute( + conn = self.get_conn() + row = conn.execute( "SELECT * FROM attribute_templates WHERE id = ?", (template_id, ) ).fetchone() conn.close() if row: - data = dict(row) - data["options"] = json.loads(data["options"]) if data["options"] else [] + data = dict(row) + data["options"] = json.loads(data["options"]) if data["options"] else [] return AttributeTemplate(**data) return None def list_attribute_templates(self, project_id: str) -> list[AttributeTemplate]: - conn = self.get_conn() - rows = conn.execute( + conn = self.get_conn() + rows = conn.execute( """SELECT * FROM attribute_templates WHERE project_id = ? ORDER BY sort_order, created_at""", (project_id, ), ).fetchall() conn.close() - templates = [] + templates = [] for row in rows: - data = dict(row) - data["options"] = json.loads(data["options"]) if data["options"] else [] + data = dict(row) + data["options"] = json.loads(data["options"]) if data["options"] else [] templates.append(AttributeTemplate(**data)) return templates def update_attribute_template(self, template_id: str, **kwargs) -> AttributeTemplate | None: - conn = self.get_conn() - allowed_fields = [ + conn = self.get_conn() + allowed_fields = [ "name", "type", "options", @@ -808,8 +808,8 @@ class DatabaseManager: "is_required", "sort_order", ] - updates = [] - values = [] + updates = [] + values = [] for field in allowed_fields: if field in kwargs: @@ -823,7 +823,7 @@ class DatabaseManager: updates.append("updated_at = ?") values.append(datetime.now().isoformat()) values.append(template_id) - query = f"UPDATE attribute_templates SET {', '.join(updates)} WHERE id = ?" + query = f"UPDATE attribute_templates SET {', '.join(updates)} WHERE id = ?" conn.execute(query, values) conn.commit() @@ -831,22 +831,22 @@ class DatabaseManager: return self.get_attribute_template(template_id) def delete_attribute_template(self, template_id: str) -> None: - conn = self.get_conn() + conn = self.get_conn() conn.execute("DELETE FROM attribute_templates WHERE id = ?", (template_id, )) conn.commit() conn.close() def set_entity_attribute( - self, attr: EntityAttribute, changed_by: str = "system", change_reason: str = "" + self, attr: EntityAttribute, changed_by: str = "system", change_reason: str = "" ) -> EntityAttribute: - conn = self.get_conn() - now = datetime.now().isoformat() + conn = self.get_conn() + now = datetime.now().isoformat() - old_row = conn.execute( + old_row = conn.execute( "SELECT value FROM entity_attributes WHERE entity_id = ? AND template_id = ?", (attr.entity_id, attr.template_id), ).fetchone() - old_value = old_row["value"] if old_row else None + old_value = old_row["value"] if old_row else None if old_value != attr.value: conn.execute( @@ -899,8 +899,8 @@ class DatabaseManager: return attr def get_entity_attributes(self, entity_id: str) -> list[EntityAttribute]: - conn = self.get_conn() - rows = conn.execute( + conn = self.get_conn() + rows = conn.execute( """SELECT ea.*, at.name as template_name, at.type as template_type FROM entity_attributes ea LEFT JOIN attribute_templates at ON ea.template_id = at.id @@ -911,11 +911,11 @@ class DatabaseManager: return [EntityAttribute(**dict(r)) for r in rows] def get_entity_with_attributes(self, entity_id: str) -> Entity | None: - entity = self.get_entity(entity_id) + entity = self.get_entity(entity_id) if not entity: return None - attrs = self.get_entity_attributes(entity_id) - entity.attributes = { + attrs = self.get_entity_attributes(entity_id) + entity.attributes = { attr.template_name: { "value": attr.value, "type": attr.template_type, @@ -926,10 +926,10 @@ class DatabaseManager: return entity def delete_entity_attribute( - self, entity_id: str, template_id: str, changed_by: str = "system", change_reason: str = "" + self, entity_id: str, template_id: str, changed_by: str = "system", change_reason: str = "" ) -> None: - conn = self.get_conn() - old_row = conn.execute( + conn = self.get_conn() + old_row = conn.execute( """SELECT value FROM entity_attributes WHERE entity_id = ? AND template_id = ?""", (entity_id, template_id), @@ -960,11 +960,11 @@ class DatabaseManager: conn.close() def get_attribute_history( - self, entity_id: str = None, template_id: str = None, limit: int = 50 + self, entity_id: str = None, template_id: str = None, limit: int = 50 ) -> list[AttributeHistory]: - conn = self.get_conn() - conditions = [] - params = [] + conn = self.get_conn() + conditions = [] + params = [] if entity_id: conditions.append("ah.entity_id = ?") @@ -973,9 +973,9 @@ class DatabaseManager: conditions.append("ah.template_id = ?") params.append(template_id) - where_clause = " AND ".join(conditions) if conditions else "1 = 1" + where_clause = " AND ".join(conditions) if conditions else "1 = 1" - rows = conn.execute( + rows = conn.execute( f"""SELECT ah.* FROM attribute_history ah WHERE {where_clause} @@ -988,17 +988,17 @@ class DatabaseManager: def search_entities_by_attributes( self, project_id: str, attribute_filters: dict[str, str] ) -> list[Entity]: - entities = self.list_project_entities(project_id) + entities = self.list_project_entities(project_id) if not attribute_filters: return entities - entity_ids = [e.id for e in entities] + entity_ids = [e.id for e in entities] if not entity_ids: return [] - conn = self.get_conn() - placeholders = ", ".join(["?" for _ in entity_ids]) - rows = conn.execute( + conn = self.get_conn() + placeholders = ", ".join(["?" for _ in entity_ids]) + rows = conn.execute( f"""SELECT ea.*, at.name as template_name FROM entity_attributes ea JOIN attribute_templates at ON ea.template_id = at.id @@ -1007,23 +1007,23 @@ class DatabaseManager: ).fetchall() conn.close() - entity_attrs = {} + entity_attrs = {} for row in rows: - eid = row["entity_id"] + eid = row["entity_id"] if eid not in entity_attrs: - entity_attrs[eid] = {} - entity_attrs[eid][row["template_name"]] = row["value"] + entity_attrs[eid] = {} + entity_attrs[eid][row["template_name"]] = row["value"] - filtered = [] + filtered = [] for entity in entities: - attrs = entity_attrs.get(entity.id, {}) - match = True + attrs = entity_attrs.get(entity.id, {}) + match = True for attr_name, attr_value in attribute_filters.items(): if attrs.get(attr_name) != attr_value: - match = False + match = False break if match: - entity.attributes = attrs + entity.attributes = attrs filtered.append(entity) return filtered @@ -1034,17 +1034,17 @@ class DatabaseManager: video_id: str, project_id: str, filename: str, - duration: float = 0, - fps: float = 0, - resolution: dict = None, - audio_transcript_id: str = None, - full_ocr_text: str = "", - extracted_entities: list[dict] = None, - extracted_relations: list[dict] = None, + duration: float = 0, + fps: float = 0, + resolution: dict = None, + audio_transcript_id: str = None, + full_ocr_text: str = "", + extracted_entities: list[dict] = None, + extracted_relations: list[dict] = None, ) -> str: """创建视频记录""" - conn = self.get_conn() - now = datetime.now().isoformat() + conn = self.get_conn() + now = datetime.now().isoformat() conn.execute( """INSERT INTO videos @@ -1074,17 +1074,17 @@ class DatabaseManager: def get_video(self, video_id: str) -> dict | None: """获取视频信息""" - conn = self.get_conn() - row = conn.execute("SELECT * FROM videos WHERE id = ?", (video_id, )).fetchone() + conn = self.get_conn() + row = conn.execute("SELECT * FROM videos WHERE id = ?", (video_id, )).fetchone() conn.close() if row: - data = dict(row) - data["resolution"] = json.loads(data["resolution"]) if data["resolution"] else None - data["extracted_entities"] = ( + data = dict(row) + data["resolution"] = json.loads(data["resolution"]) if data["resolution"] else None + data["extracted_entities"] = ( json.loads(data["extracted_entities"]) if data["extracted_entities"] else [] ) - data["extracted_relations"] = ( + data["extracted_relations"] = ( json.loads(data["extracted_relations"]) if data["extracted_relations"] else [] ) return data @@ -1092,20 +1092,20 @@ class DatabaseManager: def list_project_videos(self, project_id: str) -> list[dict]: """获取项目的所有视频""" - conn = self.get_conn() - rows = conn.execute( + conn = self.get_conn() + rows = conn.execute( "SELECT * FROM videos WHERE project_id = ? ORDER BY created_at DESC", (project_id, ) ).fetchall() conn.close() - videos = [] + videos = [] for row in rows: - data = dict(row) - data["resolution"] = json.loads(data["resolution"]) if data["resolution"] else None - data["extracted_entities"] = ( + data = dict(row) + data["resolution"] = json.loads(data["resolution"]) if data["resolution"] else None + data["extracted_entities"] = ( json.loads(data["extracted_entities"]) if data["extracted_entities"] else [] ) - data["extracted_relations"] = ( + data["extracted_relations"] = ( json.loads(data["extracted_relations"]) if data["extracted_relations"] else [] ) videos.append(data) @@ -1117,13 +1117,13 @@ class DatabaseManager: video_id: str, frame_number: int, timestamp: float, - image_url: str = None, - ocr_text: str = None, - extracted_entities: list[dict] = None, + image_url: str = None, + ocr_text: str = None, + extracted_entities: list[dict] = None, ) -> str: """创建视频帧记录""" - conn = self.get_conn() - now = datetime.now().isoformat() + conn = self.get_conn() + now = datetime.now().isoformat() conn.execute( """INSERT INTO video_frames @@ -1147,16 +1147,16 @@ class DatabaseManager: def get_video_frames(self, video_id: str) -> list[dict]: """获取视频的所有帧""" - conn = self.get_conn() - rows = conn.execute( + conn = self.get_conn() + rows = conn.execute( """SELECT * FROM video_frames WHERE video_id = ? ORDER BY timestamp""", (video_id, ) ).fetchall() conn.close() - frames = [] + frames = [] for row in rows: - data = dict(row) - data["extracted_entities"] = ( + data = dict(row) + data["extracted_entities"] = ( json.loads(data["extracted_entities"]) if data["extracted_entities"] else [] ) frames.append(data) @@ -1167,14 +1167,14 @@ class DatabaseManager: image_id: str, project_id: str, filename: str, - ocr_text: str = "", - description: str = "", - extracted_entities: list[dict] = None, - extracted_relations: list[dict] = None, + ocr_text: str = "", + description: str = "", + extracted_entities: list[dict] = None, + extracted_relations: list[dict] = None, ) -> str: """创建图片记录""" - conn = self.get_conn() - now = datetime.now().isoformat() + conn = self.get_conn() + now = datetime.now().isoformat() conn.execute( """INSERT INTO images @@ -1200,16 +1200,16 @@ class DatabaseManager: def get_image(self, image_id: str) -> dict | None: """获取图片信息""" - conn = self.get_conn() - row = conn.execute("SELECT * FROM images WHERE id = ?", (image_id, )).fetchone() + conn = self.get_conn() + row = conn.execute("SELECT * FROM images WHERE id = ?", (image_id, )).fetchone() conn.close() if row: - data = dict(row) - data["extracted_entities"] = ( + data = dict(row) + data["extracted_entities"] = ( json.loads(data["extracted_entities"]) if data["extracted_entities"] else [] ) - data["extracted_relations"] = ( + data["extracted_relations"] = ( json.loads(data["extracted_relations"]) if data["extracted_relations"] else [] ) return data @@ -1217,19 +1217,19 @@ class DatabaseManager: def list_project_images(self, project_id: str) -> list[dict]: """获取项目的所有图片""" - conn = self.get_conn() - rows = conn.execute( + conn = self.get_conn() + rows = conn.execute( "SELECT * FROM images WHERE project_id = ? ORDER BY created_at DESC", (project_id, ) ).fetchall() conn.close() - images = [] + images = [] for row in rows: - data = dict(row) - data["extracted_entities"] = ( + data = dict(row) + data["extracted_entities"] = ( json.loads(data["extracted_entities"]) if data["extracted_entities"] else [] ) - data["extracted_relations"] = ( + data["extracted_relations"] = ( json.loads(data["extracted_relations"]) if data["extracted_relations"] else [] ) images.append(data) @@ -1243,12 +1243,12 @@ class DatabaseManager: modality: str, source_id: str, source_type: str, - text_snippet: str = "", - confidence: float = 1.0, + text_snippet: str = "", + confidence: float = 1.0, ) -> str: """创建多模态实体提及记录""" - conn = self.get_conn() - now = datetime.now().isoformat() + conn = self.get_conn() + now = datetime.now().isoformat() conn.execute( """INSERT OR REPLACE INTO multimodal_mentions @@ -1273,8 +1273,8 @@ class DatabaseManager: def get_entity_multimodal_mentions(self, entity_id: str) -> list[dict]: """获取实体的多模态提及""" - conn = self.get_conn() - rows = conn.execute( + conn = self.get_conn() + rows = conn.execute( """SELECT m.*, e.name as entity_name FROM multimodal_mentions m JOIN entities e ON m.entity_id = e.id @@ -1284,12 +1284,12 @@ class DatabaseManager: conn.close() return [dict(r) for r in rows] - def get_project_multimodal_mentions(self, project_id: str, modality: str = None) -> list[dict]: + def get_project_multimodal_mentions(self, project_id: str, modality: str = None) -> list[dict]: """获取项目的多模态提及""" - conn = self.get_conn() + conn = self.get_conn() if modality: - rows = conn.execute( + rows = conn.execute( """SELECT m.*, e.name as entity_name FROM multimodal_mentions m JOIN entities e ON m.entity_id = e.id @@ -1298,7 +1298,7 @@ class DatabaseManager: (project_id, modality), ).fetchall() else: - rows = conn.execute( + rows = conn.execute( """SELECT m.*, e.name as entity_name FROM multimodal_mentions m JOIN entities e ON m.entity_id = e.id @@ -1315,13 +1315,13 @@ class DatabaseManager: entity_id: str, linked_entity_id: str, link_type: str, - confidence: float = 1.0, - evidence: str = "", - modalities: list[str] = None, + confidence: float = 1.0, + evidence: str = "", + modalities: list[str] = None, ) -> str: """创建多模态实体关联""" - conn = self.get_conn() - now = datetime.now().isoformat() + conn = self.get_conn() + now = datetime.now().isoformat() conn.execute( """INSERT OR REPLACE INTO multimodal_entity_links @@ -1345,8 +1345,8 @@ class DatabaseManager: def get_entity_multimodal_links(self, entity_id: str) -> list[dict]: """获取实体的多模态关联""" - conn = self.get_conn() - rows = conn.execute( + conn = self.get_conn() + rows = conn.execute( """SELECT l.*, e1.name as entity_name, e2.name as linked_entity_name FROM multimodal_entity_links l JOIN entities e1 ON l.entity_id = e1.id @@ -1356,18 +1356,18 @@ class DatabaseManager: ).fetchall() conn.close() - links = [] + links = [] for row in rows: - data = dict(row) - data["modalities"] = json.loads(data["modalities"]) if data["modalities"] else [] + data = dict(row) + data["modalities"] = json.loads(data["modalities"]) if data["modalities"] else [] links.append(data) return links def get_project_multimodal_stats(self, project_id: str) -> dict: """获取项目多模态统计信息""" - conn = self.get_conn() + conn = self.get_conn() - stats = { + stats = { "video_count": 0, "image_count": 0, "multimodal_entity_count": 0, @@ -1376,52 +1376,52 @@ class DatabaseManager: } # 视频数量 - row = conn.execute( + row = conn.execute( "SELECT COUNT(*) as count FROM videos WHERE project_id = ?", (project_id, ) ).fetchone() - stats["video_count"] = row["count"] + stats["video_count"] = row["count"] # 图片数量 - row = conn.execute( + row = conn.execute( "SELECT COUNT(*) as count FROM images WHERE project_id = ?", (project_id, ) ).fetchone() - stats["image_count"] = row["count"] + stats["image_count"] = row["count"] # 多模态实体数量 - row = conn.execute( + row = conn.execute( """SELECT COUNT(DISTINCT entity_id) as count FROM multimodal_mentions WHERE project_id = ?""", (project_id, ), ).fetchone() - stats["multimodal_entity_count"] = row["count"] + stats["multimodal_entity_count"] = row["count"] # 跨模态关联数量 - row = conn.execute( + row = conn.execute( """SELECT COUNT(*) as count FROM multimodal_entity_links WHERE entity_id IN (SELECT id FROM entities WHERE project_id = ?)""", (project_id, ), ).fetchone() - stats["cross_modal_links"] = row["count"] + stats["cross_modal_links"] = row["count"] # 模态分布 for modality in ["audio", "video", "image", "document"]: - row = conn.execute( + row = conn.execute( """SELECT COUNT(*) as count FROM multimodal_mentions WHERE project_id = ? AND modality = ?""", (project_id, modality), ).fetchone() - stats["modality_distribution"][modality] = row["count"] + stats["modality_distribution"][modality] = row["count"] conn.close() return stats # Singleton instance -_db_manager = None +_db_manager = None def get_db_manager() -> DatabaseManager: global _db_manager if _db_manager is None: - _db_manager = DatabaseManager() + _db_manager = DatabaseManager() return _db_manager diff --git a/backend/developer_ecosystem_manager.py b/backend/developer_ecosystem_manager.py index fc48c95..fece738 100644 --- a/backend/developer_ecosystem_manager.py +++ b/backend/developer_ecosystem_manager.py @@ -19,81 +19,81 @@ from datetime import datetime from enum import StrEnum # Database path -DB_PATH = os.path.join(os.path.dirname(__file__), "insightflow.db") +DB_PATH = os.path.join(os.path.dirname(__file__), "insightflow.db") class SDKLanguage(StrEnum): """SDK 语言类型""" - PYTHON = "python" - JAVASCRIPT = "javascript" - TYPESCRIPT = "typescript" - GO = "go" - JAVA = "java" - RUST = "rust" + PYTHON = "python" + JAVASCRIPT = "javascript" + TYPESCRIPT = "typescript" + GO = "go" + JAVA = "java" + RUST = "rust" class SDKStatus(StrEnum): """SDK 状态""" - DRAFT = "draft" # 草稿 - BETA = "beta" # 测试版 - STABLE = "stable" # 稳定版 - DEPRECATED = "deprecated" # 已弃用 - ARCHIVED = "archived" # 已归档 + DRAFT = "draft" # 草稿 + BETA = "beta" # 测试版 + STABLE = "stable" # 稳定版 + DEPRECATED = "deprecated" # 已弃用 + ARCHIVED = "archived" # 已归档 class TemplateCategory(StrEnum): """模板分类""" - MEDICAL = "medical" # 医疗 - LEGAL = "legal" # 法律 - FINANCE = "finance" # 金融 - EDUCATION = "education" # 教育 - TECH = "tech" # 科技 - GENERAL = "general" # 通用 + MEDICAL = "medical" # 医疗 + LEGAL = "legal" # 法律 + FINANCE = "finance" # 金融 + EDUCATION = "education" # 教育 + TECH = "tech" # 科技 + GENERAL = "general" # 通用 class TemplateStatus(StrEnum): """模板状态""" - PENDING = "pending" # 待审核 - APPROVED = "approved" # 已通过 - REJECTED = "rejected" # 已拒绝 - PUBLISHED = "published" # 已发布 - UNLISTED = "unlisted" # 未列出 + PENDING = "pending" # 待审核 + APPROVED = "approved" # 已通过 + REJECTED = "rejected" # 已拒绝 + PUBLISHED = "published" # 已发布 + UNLISTED = "unlisted" # 未列出 class PluginStatus(StrEnum): """插件状态""" - PENDING = "pending" # 待审核 - REVIEWING = "reviewing" # 审核中 - APPROVED = "approved" # 已通过 - REJECTED = "rejected" # 已拒绝 - PUBLISHED = "published" # 已发布 - SUSPENDED = "suspended" # 已暂停 + PENDING = "pending" # 待审核 + REVIEWING = "reviewing" # 审核中 + APPROVED = "approved" # 已通过 + REJECTED = "rejected" # 已拒绝 + PUBLISHED = "published" # 已发布 + SUSPENDED = "suspended" # 已暂停 class PluginCategory(StrEnum): """插件分类""" - INTEGRATION = "integration" # 集成 - ANALYSIS = "analysis" # 分析 - VISUALIZATION = "visualization" # 可视化 - AUTOMATION = "automation" # 自动化 - SECURITY = "security" # 安全 - CUSTOM = "custom" # 自定义 + INTEGRATION = "integration" # 集成 + ANALYSIS = "analysis" # 分析 + VISUALIZATION = "visualization" # 可视化 + AUTOMATION = "automation" # 自动化 + SECURITY = "security" # 安全 + CUSTOM = "custom" # 自定义 class DeveloperStatus(StrEnum): """开发者认证状态""" - UNVERIFIED = "unverified" # 未认证 - PENDING = "pending" # 审核中 - VERIFIED = "verified" # 已认证 - CERTIFIED = "certified" # 已认证(高级) - SUSPENDED = "suspended" # 已暂停 + UNVERIFIED = "unverified" # 未认证 + PENDING = "pending" # 审核中 + VERIFIED = "verified" # 已认证 + CERTIFIED = "certified" # 已认证(高级) + SUSPENDED = "suspended" # 已暂停 @dataclass @@ -348,14 +348,14 @@ class DeveloperPortalConfig: class DeveloperEcosystemManager: """开发者生态系统管理主类""" - def __init__(self, db_path: str = DB_PATH) -> None: - self.db_path = db_path - self.platform_fee_rate = 0.30 # 平台抽成比例 30% + def __init__(self, db_path: str = DB_PATH) -> None: + self.db_path = db_path + self.platform_fee_rate = 0.30 # 平台抽成比例 30% def _get_db(self) -> None: """获取数据库连接""" - conn = sqlite3.connect(self.db_path) - conn.row_factory = sqlite3.Row + conn = sqlite3.connect(self.db_path) + conn.row_factory = sqlite3.Row return conn # ==================== SDK 发布与管理 ==================== @@ -378,30 +378,30 @@ class DeveloperEcosystemManager: created_by: str, ) -> SDKRelease: """创建 SDK 发布""" - sdk_id = f"sdk_{uuid.uuid4().hex[:16]}" - now = datetime.now().isoformat() + sdk_id = f"sdk_{uuid.uuid4().hex[:16]}" + now = datetime.now().isoformat() - sdk = SDKRelease( - id = sdk_id, - name = name, - language = language, - version = version, - description = description, - changelog = changelog, - download_url = download_url, - documentation_url = documentation_url, - repository_url = repository_url, - package_name = package_name, - status = SDKStatus.DRAFT, - min_platform_version = min_platform_version, - dependencies = dependencies, - file_size = file_size, - checksum = checksum, - download_count = 0, - created_at = now, - updated_at = now, - published_at = None, - created_by = created_by, + sdk = SDKRelease( + id=sdk_id, + name=name, + language=language, + version=version, + description=description, + changelog=changelog, + download_url=download_url, + documentation_url=documentation_url, + repository_url=repository_url, + package_name=package_name, + status=SDKStatus.DRAFT, + min_platform_version=min_platform_version, + dependencies=dependencies, + file_size=file_size, + checksum=checksum, + download_count=0, + created_at=now, + updated_at=now, + published_at=None, + created_by=created_by, ) with self._get_db() as conn: @@ -444,7 +444,7 @@ class DeveloperEcosystemManager: def get_sdk_release(self, sdk_id: str) -> SDKRelease | None: """获取 SDK 发布详情""" with self._get_db() as conn: - row = conn.execute("SELECT * FROM sdk_releases WHERE id = ?", (sdk_id, )).fetchone() + row = conn.execute("SELECT * FROM sdk_releases WHERE id = ?", (sdk_id, )).fetchone() if row: return self._row_to_sdk_release(row) @@ -452,13 +452,13 @@ class DeveloperEcosystemManager: def list_sdk_releases( self, - language: SDKLanguage | None = None, - status: SDKStatus | None = None, - search: str | None = None, + language: SDKLanguage | None = None, + status: SDKStatus | None = None, + search: str | None = None, ) -> list[SDKRelease]: """列出 SDK 发布""" - query = "SELECT * FROM sdk_releases WHERE 1 = 1" - params = [] + query = "SELECT * FROM sdk_releases WHERE 1 = 1" + params = [] if language: query += " AND language = ?" @@ -473,12 +473,12 @@ class DeveloperEcosystemManager: query += " ORDER BY created_at DESC" with self._get_db() as conn: - rows = conn.execute(query, params).fetchall() + rows = conn.execute(query, params).fetchall() return [self._row_to_sdk_release(row) for row in rows] def update_sdk_release(self, sdk_id: str, **kwargs) -> SDKRelease | None: """更新 SDK 发布""" - allowed_fields = [ + allowed_fields = [ "name", "description", "changelog", @@ -488,14 +488,14 @@ class DeveloperEcosystemManager: "status", ] - updates = {k: v for k, v in kwargs.items() if k in allowed_fields} + updates = {k: v for k, v in kwargs.items() if k in allowed_fields} if not updates: return self.get_sdk_release(sdk_id) - updates["updated_at"] = datetime.now().isoformat() + updates["updated_at"] = datetime.now().isoformat() with self._get_db() as conn: - set_clause = ", ".join([f"{k} = ?" for k in updates.keys()]) + set_clause = ", ".join([f"{k} = ?" for k in updates.keys()]) conn.execute( f"UPDATE sdk_releases SET {set_clause} WHERE id = ?", list(updates.values()) + [sdk_id], @@ -506,7 +506,7 @@ class DeveloperEcosystemManager: def publish_sdk_release(self, sdk_id: str) -> SDKRelease | None: """发布 SDK""" - now = datetime.now().isoformat() + now = datetime.now().isoformat() with self._get_db() as conn: conn.execute( @@ -537,7 +537,7 @@ class DeveloperEcosystemManager: def get_sdk_versions(self, sdk_id: str) -> list[SDKVersion]: """获取 SDK 版本历史""" with self._get_db() as conn: - rows = conn.execute( + rows = conn.execute( "SELECT * FROM sdk_versions WHERE sdk_id = ? ORDER BY created_at DESC", (sdk_id, ) ).fetchall() return [self._row_to_sdk_version(row) for row in rows] @@ -553,8 +553,8 @@ class DeveloperEcosystemManager: file_size: int, ) -> SDKVersion: """添加 SDK 版本""" - version_id = f"sv_{uuid.uuid4().hex[:16]}" - now = datetime.now().isoformat() + version_id = f"sv_{uuid.uuid4().hex[:16]}" + now = datetime.now().isoformat() with self._get_db() as conn: # 如果设置为最新版本,取消其他版本的最新标记 @@ -585,17 +585,17 @@ class DeveloperEcosystemManager: conn.commit() return SDKVersion( - id = version_id, - sdk_id = sdk_id, - version = version, - is_latest = True, - is_lts = is_lts, - release_notes = release_notes, - download_url = download_url, - checksum = checksum, - file_size = file_size, - download_count = 0, - created_at = now, + id=version_id, + sdk_id=sdk_id, + version=version, + is_latest=True, + is_lts=is_lts, + release_notes=release_notes, + download_url=download_url, + checksum=checksum, + file_size=file_size, + download_count=0, + created_at=now, ) # ==================== 模板市场 ==================== @@ -609,48 +609,48 @@ class DeveloperEcosystemManager: tags: list[str], author_id: str, author_name: str, - price: float = 0.0, - currency: str = "CNY", - preview_image_url: str | None = None, - demo_url: str | None = None, - documentation_url: str | None = None, - download_url: str | None = None, - version: str = "1.0.0", - min_platform_version: str = "1.0.0", - file_size: int = 0, - checksum: str = "", + price: float = 0.0, + currency: str = "CNY", + preview_image_url: str | None = None, + demo_url: str | None = None, + documentation_url: str | None = None, + download_url: str | None = None, + version: str = "1.0.0", + min_platform_version: str = "1.0.0", + file_size: int = 0, + checksum: str = "", ) -> TemplateMarketItem: """创建模板""" - template_id = f"tpl_{uuid.uuid4().hex[:16]}" - now = datetime.now().isoformat() + template_id = f"tpl_{uuid.uuid4().hex[:16]}" + now = datetime.now().isoformat() - template = TemplateMarketItem( - id = template_id, - name = name, - description = description, - category = category, - subcategory = subcategory, - tags = tags, - author_id = author_id, - author_name = author_name, - status = TemplateStatus.PENDING, - price = price, - currency = currency, - preview_image_url = preview_image_url, - demo_url = demo_url, - documentation_url = documentation_url, - download_url = download_url, - install_count = 0, - rating = 0.0, - rating_count = 0, - review_count = 0, - version = version, - min_platform_version = min_platform_version, - file_size = file_size, - checksum = checksum, - created_at = now, - updated_at = now, - published_at = None, + template = TemplateMarketItem( + id=template_id, + name=name, + description=description, + category=category, + subcategory=subcategory, + tags=tags, + author_id=author_id, + author_name=author_name, + status=TemplateStatus.PENDING, + price=price, + currency=currency, + preview_image_url=preview_image_url, + demo_url=demo_url, + documentation_url=documentation_url, + download_url=download_url, + install_count=0, + rating=0.0, + rating_count=0, + review_count=0, + version=version, + min_platform_version=min_platform_version, + file_size=file_size, + checksum=checksum, + created_at=now, + updated_at=now, + published_at=None, ) with self._get_db() as conn: @@ -699,7 +699,7 @@ class DeveloperEcosystemManager: def get_template(self, template_id: str) -> TemplateMarketItem | None: """获取模板详情""" with self._get_db() as conn: - row = conn.execute( + row = conn.execute( "SELECT * FROM template_market WHERE id = ?", (template_id, ) ).fetchone() @@ -709,17 +709,17 @@ class DeveloperEcosystemManager: def list_templates( self, - category: TemplateCategory | None = None, - status: TemplateStatus | None = None, - search: str | None = None, - author_id: str | None = None, - min_price: float | None = None, - max_price: float | None = None, - sort_by: str = "created_at", + category: TemplateCategory | None = None, + status: TemplateStatus | None = None, + search: str | None = None, + author_id: str | None = None, + min_price: float | None = None, + max_price: float | None = None, + sort_by: str = "created_at", ) -> list[TemplateMarketItem]: """列出模板""" - query = "SELECT * FROM template_market WHERE 1 = 1" - params = [] + query = "SELECT * FROM template_market WHERE 1 = 1" + params = [] if category: query += " AND category = ?" @@ -741,7 +741,7 @@ class DeveloperEcosystemManager: params.append(max_price) # 排序 - sort_mapping = { + sort_mapping = { "created_at": "created_at DESC", "rating": "rating DESC", "install_count": "install_count DESC", @@ -751,12 +751,12 @@ class DeveloperEcosystemManager: query += f" ORDER BY {sort_mapping.get(sort_by, 'created_at DESC')}" with self._get_db() as conn: - rows = conn.execute(query, params).fetchall() + rows = conn.execute(query, params).fetchall() return [self._row_to_template(row) for row in rows] def approve_template(self, template_id: str, reviewed_by: str) -> TemplateMarketItem | None: """审核通过模板""" - now = datetime.now().isoformat() + now = datetime.now().isoformat() with self._get_db() as conn: conn.execute( @@ -773,7 +773,7 @@ class DeveloperEcosystemManager: def publish_template(self, template_id: str) -> TemplateMarketItem | None: """发布模板""" - now = datetime.now().isoformat() + now = datetime.now().isoformat() with self._get_db() as conn: conn.execute( @@ -790,7 +790,7 @@ class DeveloperEcosystemManager: def reject_template(self, template_id: str, reason: str) -> TemplateMarketItem | None: """拒绝模板""" - now = datetime.now().isoformat() + now = datetime.now().isoformat() with self._get_db() as conn: conn.execute( @@ -825,23 +825,23 @@ class DeveloperEcosystemManager: user_name: str, rating: int, comment: str, - is_verified_purchase: bool = False, + is_verified_purchase: bool = False, ) -> TemplateReview: """添加模板评价""" - review_id = f"tr_{uuid.uuid4().hex[:16]}" - now = datetime.now().isoformat() + review_id = f"tr_{uuid.uuid4().hex[:16]}" + now = datetime.now().isoformat() - review = TemplateReview( - id = review_id, - template_id = template_id, - user_id = user_id, - user_name = user_name, - rating = rating, - comment = comment, - is_verified_purchase = is_verified_purchase, - helpful_count = 0, - created_at = now, - updated_at = now, + review = TemplateReview( + id=review_id, + template_id=template_id, + user_id=user_id, + user_name=user_name, + rating=rating, + comment=comment, + is_verified_purchase=is_verified_purchase, + helpful_count=0, + created_at=now, + updated_at=now, ) with self._get_db() as conn: @@ -874,7 +874,7 @@ class DeveloperEcosystemManager: def _update_template_rating(self, conn, template_id: str) -> None: """更新模板评分""" - row = conn.execute( + row = conn.execute( """ SELECT AVG(rating) as avg_rating, COUNT(*) as count FROM template_reviews @@ -898,10 +898,10 @@ class DeveloperEcosystemManager: ), ) - def get_template_reviews(self, template_id: str, limit: int = 50) -> list[TemplateReview]: + def get_template_reviews(self, template_id: str, limit: int = 50) -> list[TemplateReview]: """获取模板评价""" with self._get_db() as conn: - rows = conn.execute( + rows = conn.execute( """SELECT * FROM template_reviews WHERE template_id = ? ORDER BY created_at DESC @@ -920,59 +920,59 @@ class DeveloperEcosystemManager: tags: list[str], author_id: str, author_name: str, - price: float = 0.0, - currency: str = "CNY", - pricing_model: str = "free", - preview_image_url: str | None = None, - demo_url: str | None = None, - documentation_url: str | None = None, - repository_url: str | None = None, - download_url: str | None = None, - webhook_url: str | None = None, - permissions: list[str] = None, - version: str = "1.0.0", - min_platform_version: str = "1.0.0", - file_size: int = 0, - checksum: str = "", + price: float = 0.0, + currency: str = "CNY", + pricing_model: str = "free", + preview_image_url: str | None = None, + demo_url: str | None = None, + documentation_url: str | None = None, + repository_url: str | None = None, + download_url: str | None = None, + webhook_url: str | None = None, + permissions: list[str] = None, + version: str = "1.0.0", + min_platform_version: str = "1.0.0", + file_size: int = 0, + checksum: str = "", ) -> PluginMarketItem: """创建插件""" - plugin_id = f"plg_{uuid.uuid4().hex[:16]}" - now = datetime.now().isoformat() + plugin_id = f"plg_{uuid.uuid4().hex[:16]}" + now = datetime.now().isoformat() - plugin = PluginMarketItem( - id = plugin_id, - name = name, - description = description, - category = category, - tags = tags, - author_id = author_id, - author_name = author_name, - status = PluginStatus.PENDING, - price = price, - currency = currency, - pricing_model = pricing_model, - preview_image_url = preview_image_url, - demo_url = demo_url, - documentation_url = documentation_url, - repository_url = repository_url, - download_url = download_url, - webhook_url = webhook_url, - permissions = permissions or [], - install_count = 0, - active_install_count = 0, - rating = 0.0, - rating_count = 0, - review_count = 0, - version = version, - min_platform_version = min_platform_version, - file_size = file_size, - checksum = checksum, - created_at = now, - updated_at = now, - published_at = None, - reviewed_by = None, - reviewed_at = None, - review_notes = None, + plugin = PluginMarketItem( + id=plugin_id, + name=name, + description=description, + category=category, + tags=tags, + author_id=author_id, + author_name=author_name, + status=PluginStatus.PENDING, + price=price, + currency=currency, + pricing_model=pricing_model, + preview_image_url=preview_image_url, + demo_url=demo_url, + documentation_url=documentation_url, + repository_url=repository_url, + download_url=download_url, + webhook_url=webhook_url, + permissions=permissions or [], + install_count=0, + active_install_count=0, + rating=0.0, + rating_count=0, + review_count=0, + version=version, + min_platform_version=min_platform_version, + file_size=file_size, + checksum=checksum, + created_at=now, + updated_at=now, + published_at=None, + reviewed_by=None, + reviewed_at=None, + review_notes=None, ) with self._get_db() as conn: @@ -1032,7 +1032,7 @@ class DeveloperEcosystemManager: def get_plugin(self, plugin_id: str) -> PluginMarketItem | None: """获取插件详情""" with self._get_db() as conn: - row = conn.execute("SELECT * FROM plugin_market WHERE id = ?", (plugin_id, )).fetchone() + row = conn.execute("SELECT * FROM plugin_market WHERE id = ?", (plugin_id, )).fetchone() if row: return self._row_to_plugin(row) @@ -1040,15 +1040,15 @@ class DeveloperEcosystemManager: def list_plugins( self, - category: PluginCategory | None = None, - status: PluginStatus | None = None, - search: str | None = None, - author_id: str | None = None, - sort_by: str = "created_at", + category: PluginCategory | None = None, + status: PluginStatus | None = None, + search: str | None = None, + author_id: str | None = None, + sort_by: str = "created_at", ) -> list[PluginMarketItem]: """列出插件""" - query = "SELECT * FROM plugin_market WHERE 1 = 1" - params = [] + query = "SELECT * FROM plugin_market WHERE 1 = 1" + params = [] if category: query += " AND category = ?" @@ -1063,7 +1063,7 @@ class DeveloperEcosystemManager: query += " AND (name LIKE ? OR description LIKE ? OR tags LIKE ?)" params.extend([f"%{search}%", f"%{search}%", f"%{search}%"]) - sort_mapping = { + sort_mapping = { "created_at": "created_at DESC", "rating": "rating DESC", "install_count": "install_count DESC", @@ -1072,14 +1072,14 @@ class DeveloperEcosystemManager: query += f" ORDER BY {sort_mapping.get(sort_by, 'created_at DESC')}" with self._get_db() as conn: - rows = conn.execute(query, params).fetchall() + rows = conn.execute(query, params).fetchall() return [self._row_to_plugin(row) for row in rows] def review_plugin( - self, plugin_id: str, reviewed_by: str, status: PluginStatus, notes: str = "" + self, plugin_id: str, reviewed_by: str, status: PluginStatus, notes: str = "" ) -> PluginMarketItem | None: """审核插件""" - now = datetime.now().isoformat() + now = datetime.now().isoformat() with self._get_db() as conn: conn.execute( @@ -1096,7 +1096,7 @@ class DeveloperEcosystemManager: def publish_plugin(self, plugin_id: str) -> PluginMarketItem | None: """发布插件""" - now = datetime.now().isoformat() + now = datetime.now().isoformat() with self._get_db() as conn: conn.execute( @@ -1111,7 +1111,7 @@ class DeveloperEcosystemManager: return self.get_plugin(plugin_id) - def increment_plugin_install(self, plugin_id: str, active: bool = True) -> None: + def increment_plugin_install(self, plugin_id: str, active: bool = True) -> None: """增加插件安装计数""" with self._get_db() as conn: conn.execute( @@ -1141,23 +1141,23 @@ class DeveloperEcosystemManager: user_name: str, rating: int, comment: str, - is_verified_purchase: bool = False, + is_verified_purchase: bool = False, ) -> PluginReview: """添加插件评价""" - review_id = f"pr_{uuid.uuid4().hex[:16]}" - now = datetime.now().isoformat() + review_id = f"pr_{uuid.uuid4().hex[:16]}" + now = datetime.now().isoformat() - review = PluginReview( - id = review_id, - plugin_id = plugin_id, - user_id = user_id, - user_name = user_name, - rating = rating, - comment = comment, - is_verified_purchase = is_verified_purchase, - helpful_count = 0, - created_at = now, - updated_at = now, + review = PluginReview( + id=review_id, + plugin_id=plugin_id, + user_id=user_id, + user_name=user_name, + rating=rating, + comment=comment, + is_verified_purchase=is_verified_purchase, + helpful_count=0, + created_at=now, + updated_at=now, ) with self._get_db() as conn: @@ -1189,7 +1189,7 @@ class DeveloperEcosystemManager: def _update_plugin_rating(self, conn, plugin_id: str) -> None: """更新插件评分""" - row = conn.execute( + row = conn.execute( """ SELECT AVG(rating) as avg_rating, COUNT(*) as count FROM plugin_reviews @@ -1213,10 +1213,10 @@ class DeveloperEcosystemManager: ), ) - def get_plugin_reviews(self, plugin_id: str, limit: int = 50) -> list[PluginReview]: + def get_plugin_reviews(self, plugin_id: str, limit: int = 50) -> list[PluginReview]: """获取插件评价""" with self._get_db() as conn: - rows = conn.execute( + rows = conn.execute( """SELECT * FROM plugin_reviews WHERE plugin_id = ? ORDER BY created_at DESC @@ -1239,25 +1239,25 @@ class DeveloperEcosystemManager: transaction_id: str, ) -> DeveloperRevenue: """记录开发者收益""" - revenue_id = f"rev_{uuid.uuid4().hex[:16]}" - now = datetime.now().isoformat() + revenue_id = f"rev_{uuid.uuid4().hex[:16]}" + now = datetime.now().isoformat() - platform_fee = sale_amount * self.platform_fee_rate - developer_earnings = sale_amount - platform_fee + platform_fee = sale_amount * self.platform_fee_rate + developer_earnings = sale_amount - platform_fee - revenue = DeveloperRevenue( - id = revenue_id, - developer_id = developer_id, - item_type = item_type, - item_id = item_id, - item_name = item_name, - sale_amount = sale_amount, - platform_fee = platform_fee, - developer_earnings = developer_earnings, - currency = currency, - buyer_id = buyer_id, - transaction_id = transaction_id, - created_at = now, + revenue = DeveloperRevenue( + id=revenue_id, + developer_id=developer_id, + item_type=item_type, + item_id=item_id, + item_name=item_name, + sale_amount=sale_amount, + platform_fee=platform_fee, + developer_earnings=developer_earnings, + currency=currency, + buyer_id=buyer_id, + transaction_id=transaction_id, + created_at=now, ) with self._get_db() as conn: @@ -1301,12 +1301,12 @@ class DeveloperEcosystemManager: def get_developer_revenues( self, developer_id: str, - start_date: datetime | None = None, - end_date: datetime | None = None, + start_date: datetime | None = None, + end_date: datetime | None = None, ) -> list[DeveloperRevenue]: """获取开发者收益记录""" - query = "SELECT * FROM developer_revenues WHERE developer_id = ?" - params = [developer_id] + query = "SELECT * FROM developer_revenues WHERE developer_id = ?" + params = [developer_id] if start_date: query += " AND created_at >= ?" @@ -1318,13 +1318,13 @@ class DeveloperEcosystemManager: query += " ORDER BY created_at DESC" with self._get_db() as conn: - rows = conn.execute(query, params).fetchall() + rows = conn.execute(query, params).fetchall() return [self._row_to_developer_revenue(row) for row in rows] def get_developer_revenue_summary(self, developer_id: str) -> dict: """获取开发者收益汇总""" with self._get_db() as conn: - row = conn.execute( + row = conn.execute( """ SELECT SUM(sale_amount) as total_sales, @@ -1352,34 +1352,34 @@ class DeveloperEcosystemManager: user_id: str, display_name: str, email: str, - bio: str | None = None, - website: str | None = None, - github_url: str | None = None, - avatar_url: str | None = None, + bio: str | None = None, + website: str | None = None, + github_url: str | None = None, + avatar_url: str | None = None, ) -> DeveloperProfile: """创建开发者档案""" - profile_id = f"dev_{uuid.uuid4().hex[:16]}" - now = datetime.now().isoformat() + profile_id = f"dev_{uuid.uuid4().hex[:16]}" + now = datetime.now().isoformat() - profile = DeveloperProfile( - id = profile_id, - user_id = user_id, - display_name = display_name, - email = email, - bio = bio, - website = website, - github_url = github_url, - avatar_url = avatar_url, - status = DeveloperStatus.UNVERIFIED, - verification_documents = {}, - total_sales = 0.0, - total_downloads = 0, - plugin_count = 0, - template_count = 0, - rating_average = 0.0, - created_at = now, - updated_at = now, - verified_at = None, + profile = DeveloperProfile( + id=profile_id, + user_id=user_id, + display_name=display_name, + email=email, + bio=bio, + website=website, + github_url=github_url, + avatar_url=avatar_url, + status=DeveloperStatus.UNVERIFIED, + verification_documents={}, + total_sales=0.0, + total_downloads=0, + plugin_count=0, + template_count=0, + rating_average=0.0, + created_at=now, + updated_at=now, + verified_at=None, ) with self._get_db() as conn: @@ -1419,7 +1419,7 @@ class DeveloperEcosystemManager: def get_developer_profile(self, developer_id: str) -> DeveloperProfile | None: """获取开发者档案""" with self._get_db() as conn: - row = conn.execute( + row = conn.execute( "SELECT * FROM developer_profiles WHERE id = ?", (developer_id, ) ).fetchone() @@ -1430,7 +1430,7 @@ class DeveloperEcosystemManager: def get_developer_profile_by_user(self, user_id: str) -> DeveloperProfile | None: """通过用户 ID 获取开发者档案""" with self._get_db() as conn: - row = conn.execute( + row = conn.execute( "SELECT * FROM developer_profiles WHERE user_id = ?", (user_id, ) ).fetchone() @@ -1442,7 +1442,7 @@ class DeveloperEcosystemManager: self, developer_id: str, status: DeveloperStatus ) -> DeveloperProfile | None: """验证开发者""" - now = datetime.now().isoformat() + now = datetime.now().isoformat() with self._get_db() as conn: conn.execute( @@ -1468,17 +1468,17 @@ class DeveloperEcosystemManager: """更新开发者统计信息""" with self._get_db() as conn: # 统计插件数量 - plugin_row = conn.execute( + plugin_row = conn.execute( "SELECT COUNT(*) as count FROM plugin_market WHERE author_id = ?", (developer_id, ) ).fetchone() # 统计模板数量 - template_row = conn.execute( + template_row = conn.execute( "SELECT COUNT(*) as count FROM template_market WHERE author_id = ?", (developer_id, ) ).fetchone() # 统计总下载量 - download_row = conn.execute( + download_row = conn.execute( """ SELECT SUM(install_count) as total FROM ( SELECT install_count FROM plugin_market WHERE author_id = ? @@ -1518,31 +1518,31 @@ class DeveloperEcosystemManager: tags: list[str], author_id: str, author_name: str, - sdk_id: str | None = None, - api_endpoints: list[str] = None, + sdk_id: str | None = None, + api_endpoints: list[str] = None, ) -> CodeExample: """创建代码示例""" - example_id = f"ex_{uuid.uuid4().hex[:16]}" - now = datetime.now().isoformat() + example_id = f"ex_{uuid.uuid4().hex[:16]}" + now = datetime.now().isoformat() - example = CodeExample( - id = example_id, - title = title, - description = description, - language = language, - category = category, - code = code, - explanation = explanation, - tags = tags, - author_id = author_id, - author_name = author_name, - sdk_id = sdk_id, - api_endpoints = api_endpoints or [], - view_count = 0, - copy_count = 0, - rating = 0.0, - created_at = now, - updated_at = now, + example = CodeExample( + id=example_id, + title=title, + description=description, + language=language, + category=category, + code=code, + explanation=explanation, + tags=tags, + author_id=author_id, + author_name=author_name, + sdk_id=sdk_id, + api_endpoints=api_endpoints or [], + view_count=0, + copy_count=0, + rating=0.0, + created_at=now, + updated_at=now, ) with self._get_db() as conn: @@ -1581,7 +1581,7 @@ class DeveloperEcosystemManager: def get_code_example(self, example_id: str) -> CodeExample | None: """获取代码示例""" with self._get_db() as conn: - row = conn.execute("SELECT * FROM code_examples WHERE id = ?", (example_id, )).fetchone() + row = conn.execute("SELECT * FROM code_examples WHERE id = ?", (example_id, )).fetchone() if row: return self._row_to_code_example(row) @@ -1589,14 +1589,14 @@ class DeveloperEcosystemManager: def list_code_examples( self, - language: str | None = None, - category: str | None = None, - sdk_id: str | None = None, - search: str | None = None, + language: str | None = None, + category: str | None = None, + sdk_id: str | None = None, + search: str | None = None, ) -> list[CodeExample]: """列出代码示例""" - query = "SELECT * FROM code_examples WHERE 1 = 1" - params = [] + query = "SELECT * FROM code_examples WHERE 1 = 1" + params = [] if language: query += " AND language = ?" @@ -1614,7 +1614,7 @@ class DeveloperEcosystemManager: query += " ORDER BY created_at DESC" with self._get_db() as conn: - rows = conn.execute(query, params).fetchall() + rows = conn.execute(query, params).fetchall() return [self._row_to_code_example(row) for row in rows] def increment_example_view(self, example_id: str) -> None: @@ -1655,18 +1655,18 @@ class DeveloperEcosystemManager: generated_by: str, ) -> APIDocumentation: """创建 API 文档""" - doc_id = f"api_{uuid.uuid4().hex[:16]}" - now = datetime.now().isoformat() + doc_id = f"api_{uuid.uuid4().hex[:16]}" + now = datetime.now().isoformat() - doc = APIDocumentation( - id = doc_id, - version = version, - openapi_spec = openapi_spec, - markdown_content = markdown_content, - html_content = html_content, - changelog = changelog, - generated_at = now, - generated_by = generated_by, + doc = APIDocumentation( + id=doc_id, + version=version, + openapi_spec=openapi_spec, + markdown_content=markdown_content, + html_content=html_content, + changelog=changelog, + generated_at=now, + generated_by=generated_by, ) with self._get_db() as conn: @@ -1695,7 +1695,7 @@ class DeveloperEcosystemManager: def get_api_documentation(self, doc_id: str) -> APIDocumentation | None: """获取 API 文档""" with self._get_db() as conn: - row = conn.execute("SELECT * FROM api_documentation WHERE id = ?", (doc_id, )).fetchone() + row = conn.execute("SELECT * FROM api_documentation WHERE id = ?", (doc_id, )).fetchone() if row: return self._row_to_api_documentation(row) @@ -1704,7 +1704,7 @@ class DeveloperEcosystemManager: def get_latest_api_documentation(self) -> APIDocumentation | None: """获取最新 API 文档""" with self._get_db() as conn: - row = conn.execute( + row = conn.execute( "SELECT * FROM api_documentation ORDER BY generated_at DESC LIMIT 1" ).fetchone() @@ -1718,42 +1718,42 @@ class DeveloperEcosystemManager: self, name: str, description: str, - theme: str = "default", - custom_css: str | None = None, - custom_js: str | None = None, - logo_url: str | None = None, - favicon_url: str | None = None, - primary_color: str = "#1890ff", - secondary_color: str = "#52c41a", - support_email: str = "support@insightflow.io", - support_url: str | None = None, - github_url: str | None = None, - discord_url: str | None = None, - api_base_url: str = "https://api.insightflow.io", + theme: str = "default", + custom_css: str | None = None, + custom_js: str | None = None, + logo_url: str | None = None, + favicon_url: str | None = None, + primary_color: str = "#1890ff", + secondary_color: str = "#52c41a", + support_email: str = "support@insightflow.io", + support_url: str | None = None, + github_url: str | None = None, + discord_url: str | None = None, + api_base_url: str = "https://api.insightflow.io", ) -> DeveloperPortalConfig: """创建开发者门户配置""" - config_id = f"portal_{uuid.uuid4().hex[:16]}" - now = datetime.now().isoformat() + config_id = f"portal_{uuid.uuid4().hex[:16]}" + now = datetime.now().isoformat() - config = DeveloperPortalConfig( - id = config_id, - name = name, - description = description, - theme = theme, - custom_css = custom_css, - custom_js = custom_js, - logo_url = logo_url, - favicon_url = favicon_url, - primary_color = primary_color, - secondary_color = secondary_color, - support_email = support_email, - support_url = support_url, - github_url = github_url, - discord_url = discord_url, - api_base_url = api_base_url, - is_active = True, - created_at = now, - updated_at = now, + config = DeveloperPortalConfig( + id=config_id, + name=name, + description=description, + theme=theme, + custom_css=custom_css, + custom_js=custom_js, + logo_url=logo_url, + favicon_url=favicon_url, + primary_color=primary_color, + secondary_color=secondary_color, + support_email=support_email, + support_url=support_url, + github_url=github_url, + discord_url=discord_url, + api_base_url=api_base_url, + is_active=True, + created_at=now, + updated_at=now, ) with self._get_db() as conn: @@ -1793,7 +1793,7 @@ class DeveloperEcosystemManager: def get_portal_config(self, config_id: str) -> DeveloperPortalConfig | None: """获取开发者门户配置""" with self._get_db() as conn: - row = conn.execute( + row = conn.execute( "SELECT * FROM developer_portal_configs WHERE id = ?", (config_id, ) ).fetchone() @@ -1804,7 +1804,7 @@ class DeveloperEcosystemManager: def get_active_portal_config(self) -> DeveloperPortalConfig | None: """获取活跃的开发者门户配置""" with self._get_db() as conn: - row = conn.execute( + row = conn.execute( "SELECT * FROM developer_portal_configs WHERE is_active = 1 LIMIT 1" ).fetchone() @@ -1817,249 +1817,249 @@ class DeveloperEcosystemManager: def _row_to_sdk_release(self, row) -> SDKRelease: """将数据库行转换为 SDKRelease""" return SDKRelease( - id = row["id"], - name = row["name"], - language = SDKLanguage(row["language"]), - version = row["version"], - description = row["description"], - changelog = row["changelog"], - download_url = row["download_url"], - documentation_url = row["documentation_url"], - repository_url = row["repository_url"], - package_name = row["package_name"], - status = SDKStatus(row["status"]), - min_platform_version = row["min_platform_version"], - dependencies = json.loads(row["dependencies"]), - file_size = row["file_size"], - checksum = row["checksum"], - download_count = row["download_count"], - created_at = row["created_at"], - updated_at = row["updated_at"], - published_at = row["published_at"], - created_by = row["created_by"], + id=row["id"], + name=row["name"], + language=SDKLanguage(row["language"]), + version=row["version"], + description=row["description"], + changelog=row["changelog"], + download_url=row["download_url"], + documentation_url=row["documentation_url"], + repository_url=row["repository_url"], + package_name=row["package_name"], + status=SDKStatus(row["status"]), + min_platform_version=row["min_platform_version"], + dependencies=json.loads(row["dependencies"]), + file_size=row["file_size"], + checksum=row["checksum"], + download_count=row["download_count"], + created_at=row["created_at"], + updated_at=row["updated_at"], + published_at=row["published_at"], + created_by=row["created_by"], ) def _row_to_sdk_version(self, row) -> SDKVersion: """将数据库行转换为 SDKVersion""" return SDKVersion( - id = row["id"], - sdk_id = row["sdk_id"], - version = row["version"], - is_latest = bool(row["is_latest"]), - is_lts = bool(row["is_lts"]), - release_notes = row["release_notes"], - download_url = row["download_url"], - checksum = row["checksum"], - file_size = row["file_size"], - download_count = row["download_count"], - created_at = row["created_at"], + id=row["id"], + sdk_id=row["sdk_id"], + version=row["version"], + is_latest=bool(row["is_latest"]), + is_lts=bool(row["is_lts"]), + release_notes=row["release_notes"], + download_url=row["download_url"], + checksum=row["checksum"], + file_size=row["file_size"], + download_count=row["download_count"], + created_at=row["created_at"], ) def _row_to_template(self, row) -> TemplateMarketItem: """将数据库行转换为 TemplateMarketItem""" return TemplateMarketItem( - id = row["id"], - name = row["name"], - description = row["description"], - category = TemplateCategory(row["category"]), - subcategory = row["subcategory"], - tags = json.loads(row["tags"]), - author_id = row["author_id"], - author_name = row["author_name"], - status = TemplateStatus(row["status"]), - price = row["price"], - currency = row["currency"], - preview_image_url = row["preview_image_url"], - demo_url = row["demo_url"], - documentation_url = row["documentation_url"], - download_url = row["download_url"], - install_count = row["install_count"], - rating = row["rating"], - rating_count = row["rating_count"], - review_count = row["review_count"], - version = row["version"], - min_platform_version = row["min_platform_version"], - file_size = row["file_size"], - checksum = row["checksum"], - created_at = row["created_at"], - updated_at = row["updated_at"], - published_at = row["published_at"], + id=row["id"], + name=row["name"], + description=row["description"], + category=TemplateCategory(row["category"]), + subcategory=row["subcategory"], + tags=json.loads(row["tags"]), + author_id=row["author_id"], + author_name=row["author_name"], + status=TemplateStatus(row["status"]), + price=row["price"], + currency=row["currency"], + preview_image_url=row["preview_image_url"], + demo_url=row["demo_url"], + documentation_url=row["documentation_url"], + download_url=row["download_url"], + install_count=row["install_count"], + rating=row["rating"], + rating_count=row["rating_count"], + review_count=row["review_count"], + version=row["version"], + min_platform_version=row["min_platform_version"], + file_size=row["file_size"], + checksum=row["checksum"], + created_at=row["created_at"], + updated_at=row["updated_at"], + published_at=row["published_at"], ) def _row_to_template_review(self, row) -> TemplateReview: """将数据库行转换为 TemplateReview""" return TemplateReview( - id = row["id"], - template_id = row["template_id"], - user_id = row["user_id"], - user_name = row["user_name"], - rating = row["rating"], - comment = row["comment"], - is_verified_purchase = bool(row["is_verified_purchase"]), - helpful_count = row["helpful_count"], - created_at = row["created_at"], - updated_at = row["updated_at"], + id=row["id"], + template_id=row["template_id"], + user_id=row["user_id"], + user_name=row["user_name"], + rating=row["rating"], + comment=row["comment"], + is_verified_purchase=bool(row["is_verified_purchase"]), + helpful_count=row["helpful_count"], + created_at=row["created_at"], + updated_at=row["updated_at"], ) def _row_to_plugin(self, row) -> PluginMarketItem: """将数据库行转换为 PluginMarketItem""" return PluginMarketItem( - id = row["id"], - name = row["name"], - description = row["description"], - category = PluginCategory(row["category"]), - tags = json.loads(row["tags"]), - author_id = row["author_id"], - author_name = row["author_name"], - status = PluginStatus(row["status"]), - price = row["price"], - currency = row["currency"], - pricing_model = row["pricing_model"], - preview_image_url = row["preview_image_url"], - demo_url = row["demo_url"], - documentation_url = row["documentation_url"], - repository_url = row["repository_url"], - download_url = row["download_url"], - webhook_url = row["webhook_url"], - permissions = json.loads(row["permissions"]), - install_count = row["install_count"], - active_install_count = row["active_install_count"], - rating = row["rating"], - rating_count = row["rating_count"], - review_count = row["review_count"], - version = row["version"], - min_platform_version = row["min_platform_version"], - file_size = row["file_size"], - checksum = row["checksum"], - created_at = row["created_at"], - updated_at = row["updated_at"], - published_at = row["published_at"], - reviewed_by = row["reviewed_by"], - reviewed_at = row["reviewed_at"], - review_notes = row["review_notes"], + id=row["id"], + name=row["name"], + description=row["description"], + category=PluginCategory(row["category"]), + tags=json.loads(row["tags"]), + author_id=row["author_id"], + author_name=row["author_name"], + status=PluginStatus(row["status"]), + price=row["price"], + currency=row["currency"], + pricing_model=row["pricing_model"], + preview_image_url=row["preview_image_url"], + demo_url=row["demo_url"], + documentation_url=row["documentation_url"], + repository_url=row["repository_url"], + download_url=row["download_url"], + webhook_url=row["webhook_url"], + permissions=json.loads(row["permissions"]), + install_count=row["install_count"], + active_install_count=row["active_install_count"], + rating=row["rating"], + rating_count=row["rating_count"], + review_count=row["review_count"], + version=row["version"], + min_platform_version=row["min_platform_version"], + file_size=row["file_size"], + checksum=row["checksum"], + created_at=row["created_at"], + updated_at=row["updated_at"], + published_at=row["published_at"], + reviewed_by=row["reviewed_by"], + reviewed_at=row["reviewed_at"], + review_notes=row["review_notes"], ) def _row_to_plugin_review(self, row) -> PluginReview: """将数据库行转换为 PluginReview""" return PluginReview( - id = row["id"], - plugin_id = row["plugin_id"], - user_id = row["user_id"], - user_name = row["user_name"], - rating = row["rating"], - comment = row["comment"], - is_verified_purchase = bool(row["is_verified_purchase"]), - helpful_count = row["helpful_count"], - created_at = row["created_at"], - updated_at = row["updated_at"], + id=row["id"], + plugin_id=row["plugin_id"], + user_id=row["user_id"], + user_name=row["user_name"], + rating=row["rating"], + comment=row["comment"], + is_verified_purchase=bool(row["is_verified_purchase"]), + helpful_count=row["helpful_count"], + created_at=row["created_at"], + updated_at=row["updated_at"], ) def _row_to_developer_profile(self, row) -> DeveloperProfile: """将数据库行转换为 DeveloperProfile""" return DeveloperProfile( - id = row["id"], - user_id = row["user_id"], - display_name = row["display_name"], - email = row["email"], - bio = row["bio"], - website = row["website"], - github_url = row["github_url"], - avatar_url = row["avatar_url"], - status = DeveloperStatus(row["status"]), - verification_documents = json.loads(row["verification_documents"]), - total_sales = row["total_sales"], - total_downloads = row["total_downloads"], - plugin_count = row["plugin_count"], - template_count = row["template_count"], - rating_average = row["rating_average"], - created_at = row["created_at"], - updated_at = row["updated_at"], - verified_at = row["verified_at"], + id=row["id"], + user_id=row["user_id"], + display_name=row["display_name"], + email=row["email"], + bio=row["bio"], + website=row["website"], + github_url=row["github_url"], + avatar_url=row["avatar_url"], + status=DeveloperStatus(row["status"]), + verification_documents=json.loads(row["verification_documents"]), + total_sales=row["total_sales"], + total_downloads=row["total_downloads"], + plugin_count=row["plugin_count"], + template_count=row["template_count"], + rating_average=row["rating_average"], + created_at=row["created_at"], + updated_at=row["updated_at"], + verified_at=row["verified_at"], ) def _row_to_developer_revenue(self, row) -> DeveloperRevenue: """将数据库行转换为 DeveloperRevenue""" return DeveloperRevenue( - id = row["id"], - developer_id = row["developer_id"], - item_type = row["item_type"], - item_id = row["item_id"], - item_name = row["item_name"], - sale_amount = row["sale_amount"], - platform_fee = row["platform_fee"], - developer_earnings = row["developer_earnings"], - currency = row["currency"], - buyer_id = row["buyer_id"], - transaction_id = row["transaction_id"], - created_at = row["created_at"], + id=row["id"], + developer_id=row["developer_id"], + item_type=row["item_type"], + item_id=row["item_id"], + item_name=row["item_name"], + sale_amount=row["sale_amount"], + platform_fee=row["platform_fee"], + developer_earnings=row["developer_earnings"], + currency=row["currency"], + buyer_id=row["buyer_id"], + transaction_id=row["transaction_id"], + created_at=row["created_at"], ) def _row_to_code_example(self, row) -> CodeExample: """将数据库行转换为 CodeExample""" return CodeExample( - id = row["id"], - title = row["title"], - description = row["description"], - language = row["language"], - category = row["category"], - code = row["code"], - explanation = row["explanation"], - tags = json.loads(row["tags"]), - author_id = row["author_id"], - author_name = row["author_name"], - sdk_id = row["sdk_id"], - api_endpoints = json.loads(row["api_endpoints"]), - view_count = row["view_count"], - copy_count = row["copy_count"], - rating = row["rating"], - created_at = row["created_at"], - updated_at = row["updated_at"], + id=row["id"], + title=row["title"], + description=row["description"], + language=row["language"], + category=row["category"], + code=row["code"], + explanation=row["explanation"], + tags=json.loads(row["tags"]), + author_id=row["author_id"], + author_name=row["author_name"], + sdk_id=row["sdk_id"], + api_endpoints=json.loads(row["api_endpoints"]), + view_count=row["view_count"], + copy_count=row["copy_count"], + rating=row["rating"], + created_at=row["created_at"], + updated_at=row["updated_at"], ) def _row_to_api_documentation(self, row) -> APIDocumentation: """将数据库行转换为 APIDocumentation""" return APIDocumentation( - id = row["id"], - version = row["version"], - openapi_spec = row["openapi_spec"], - markdown_content = row["markdown_content"], - html_content = row["html_content"], - changelog = row["changelog"], - generated_at = row["generated_at"], - generated_by = row["generated_by"], + id=row["id"], + version=row["version"], + openapi_spec=row["openapi_spec"], + markdown_content=row["markdown_content"], + html_content=row["html_content"], + changelog=row["changelog"], + generated_at=row["generated_at"], + generated_by=row["generated_by"], ) def _row_to_portal_config(self, row) -> DeveloperPortalConfig: """将数据库行转换为 DeveloperPortalConfig""" return DeveloperPortalConfig( - id = row["id"], - name = row["name"], - description = row["description"], - theme = row["theme"], - custom_css = row["custom_css"], - custom_js = row["custom_js"], - logo_url = row["logo_url"], - favicon_url = row["favicon_url"], - primary_color = row["primary_color"], - secondary_color = row["secondary_color"], - support_email = row["support_email"], - support_url = row["support_url"], - github_url = row["github_url"], - discord_url = row["discord_url"], - api_base_url = row["api_base_url"], - is_active = bool(row["is_active"]), - created_at = row["created_at"], - updated_at = row["updated_at"], + id=row["id"], + name=row["name"], + description=row["description"], + theme=row["theme"], + custom_css=row["custom_css"], + custom_js=row["custom_js"], + logo_url=row["logo_url"], + favicon_url=row["favicon_url"], + primary_color=row["primary_color"], + secondary_color=row["secondary_color"], + support_email=row["support_email"], + support_url=row["support_url"], + github_url=row["github_url"], + discord_url=row["discord_url"], + api_base_url=row["api_base_url"], + is_active=bool(row["is_active"]), + created_at=row["created_at"], + updated_at=row["updated_at"], ) # Singleton instance -_developer_ecosystem_manager = None +_developer_ecosystem_manager = None def get_developer_ecosystem_manager() -> DeveloperEcosystemManager: """获取开发者生态系统管理器单例""" global _developer_ecosystem_manager if _developer_ecosystem_manager is None: - _developer_ecosystem_manager = DeveloperEcosystemManager() + _developer_ecosystem_manager = DeveloperEcosystemManager() return _developer_ecosystem_manager diff --git a/backend/document_processor.py b/backend/document_processor.py index a73913b..fc20405 100644 --- a/backend/document_processor.py +++ b/backend/document_processor.py @@ -12,7 +12,7 @@ class DocumentProcessor: """文档处理器 - 提取 PDF/DOCX 文本""" def __init__(self) -> None: - self.supported_formats = { + self.supported_formats = { ".pdf": self._extract_pdf, ".docx": self._extract_docx, ".doc": self._extract_docx, @@ -31,18 +31,18 @@ class DocumentProcessor: Returns: {"text": "提取的文本内容", "format": "文件格式"} """ - ext = os.path.splitext(filename.lower())[1] + ext = os.path.splitext(filename.lower())[1] if ext not in self.supported_formats: raise ValueError( f"Unsupported file format: {ext}. Supported: {list(self.supported_formats.keys())}" ) - extractor = self.supported_formats[ext] - text = extractor(content) + extractor = self.supported_formats[ext] + text = extractor(content) # 清理文本 - text = self._clean_text(text) + text = self._clean_text(text) return {"text": text, "format": ext, "filename": filename} @@ -51,12 +51,12 @@ class DocumentProcessor: try: import PyPDF2 - pdf_file = io.BytesIO(content) - reader = PyPDF2.PdfReader(pdf_file) + pdf_file = io.BytesIO(content) + reader = PyPDF2.PdfReader(pdf_file) - text_parts = [] + text_parts = [] for page in reader.pages: - page_text = page.extract_text() + page_text = page.extract_text() if page_text: text_parts.append(page_text) @@ -66,10 +66,10 @@ class DocumentProcessor: try: import pdfplumber - text_parts = [] + text_parts = [] with pdfplumber.open(io.BytesIO(content)) as pdf: for page in pdf.pages: - page_text = page.extract_text() + page_text = page.extract_text() if page_text: text_parts.append(page_text) return "\n\n".join(text_parts) @@ -85,10 +85,10 @@ class DocumentProcessor: try: import docx - doc_file = io.BytesIO(content) - doc = docx.Document(doc_file) + doc_file = io.BytesIO(content) + doc = docx.Document(doc_file) - text_parts = [] + text_parts = [] for para in doc.paragraphs: if para.text.strip(): text_parts.append(para.text) @@ -96,7 +96,7 @@ class DocumentProcessor: # 提取表格中的文本 for table in doc.tables: for row in table.rows: - row_text = [] + row_text = [] for cell in row.cells: if cell.text.strip(): row_text.append(cell.text.strip()) @@ -114,7 +114,7 @@ class DocumentProcessor: def _extract_txt(self, content: bytes) -> str: """提取纯文本""" # 尝试多种编码 - encodings = ["utf-8", "gbk", "gb2312", "latin-1"] + encodings = ["utf-8", "gbk", "gb2312", "latin-1"] for encoding in encodings: try: @@ -123,7 +123,7 @@ class DocumentProcessor: continue # 如果都失败了,使用 latin-1 并忽略错误 - return content.decode("latin-1", errors = "ignore") + return content.decode("latin-1", errors="ignore") def _clean_text(self, text: str) -> str: """清理提取的文本""" @@ -131,29 +131,29 @@ class DocumentProcessor: return "" # 移除多余的空白字符 - lines = text.split("\n") - cleaned_lines = [] + lines = text.split("\n") + cleaned_lines = [] for line in lines: - line = line.strip() + line = line.strip() # 移除空行,但保留段落分隔 if line: cleaned_lines.append(line) # 合并行,保留段落结构 - text = "\n\n".join(cleaned_lines) + text = "\n\n".join(cleaned_lines) # 移除多余的空格 - text = " ".join(text.split()) + text = " ".join(text.split()) # 移除控制字符 - text = "".join(char for char in text if ord(char) >= 32 or char in "\n\r\t") + text = "".join(char for char in text if ord(char) >= 32 or char in "\n\r\t") return text.strip() def is_supported(self, filename: str) -> bool: """检查文件格式是否支持""" - ext = os.path.splitext(filename.lower())[1] + ext = os.path.splitext(filename.lower())[1] return ext in self.supported_formats @@ -165,7 +165,7 @@ class SimpleTextExtractor: def extract(self, content: bytes, filename: str) -> str: """尝试提取文本""" - encodings = ["utf-8", "gbk", "latin-1"] + encodings = ["utf-8", "gbk", "latin-1"] for encoding in encodings: try: @@ -173,15 +173,15 @@ class SimpleTextExtractor: except UnicodeDecodeError: continue - return content.decode("latin-1", errors = "ignore") + return content.decode("latin-1", errors="ignore") if __name__ == "__main__": # 测试 - processor = DocumentProcessor() + processor = DocumentProcessor() # 测试文本提取 - test_text = "Hello World\n\nThis is a test document.\n\nMultiple paragraphs." - result = processor.process(test_text.encode("utf-8"), "test.txt") + test_text = "Hello World\n\nThis is a test document.\n\nMultiple paragraphs." + result = processor.process(test_text.encode("utf-8"), "test.txt") print(f"Text extraction test: {len(result['text'])} chars") print(result["text"][:100]) diff --git a/backend/enterprise_manager.py b/backend/enterprise_manager.py index 44b61e8..2ffe5e8 100644 --- a/backend/enterprise_manager.py +++ b/backend/enterprise_manager.py @@ -19,64 +19,64 @@ from datetime import datetime, timedelta from enum import StrEnum from typing import Any -logger = logging.getLogger(__name__) +logger = logging.getLogger(__name__) class SSOProvider(StrEnum): """SSO 提供商类型""" - WECHAT_WORK = "wechat_work" # 企业微信 - DINGTALK = "dingtalk" # 钉钉 - FEISHU = "feishu" # 飞书 - OKTA = "okta" # Okta - AZURE_AD = "azure_ad" # Azure AD - GOOGLE = "google" # Google Workspace - CUSTOM_SAML = "custom_saml" # 自定义 SAML + WECHAT_WORK = "wechat_work" # 企业微信 + DINGTALK = "dingtalk" # 钉钉 + FEISHU = "feishu" # 飞书 + OKTA = "okta" # Okta + AZURE_AD = "azure_ad" # Azure AD + GOOGLE = "google" # Google Workspace + CUSTOM_SAML = "custom_saml" # 自定义 SAML class SSOStatus(StrEnum): """SSO 配置状态""" - DISABLED = "disabled" # 未启用 - PENDING = "pending" # 待配置 - ACTIVE = "active" # 已启用 - ERROR = "error" # 配置错误 + DISABLED = "disabled" # 未启用 + PENDING = "pending" # 待配置 + ACTIVE = "active" # 已启用 + ERROR = "error" # 配置错误 class SCIMSyncStatus(StrEnum): """SCIM 同步状态""" - IDLE = "idle" # 空闲 - SYNCING = "syncing" # 同步中 - SUCCESS = "success" # 同步成功 - FAILED = "failed" # 同步失败 + IDLE = "idle" # 空闲 + SYNCING = "syncing" # 同步中 + SUCCESS = "success" # 同步成功 + FAILED = "failed" # 同步失败 class AuditLogExportFormat(StrEnum): """审计日志导出格式""" - JSON = "json" - CSV = "csv" - PDF = "pdf" - XLSX = "xlsx" + JSON = "json" + CSV = "csv" + PDF = "pdf" + XLSX = "xlsx" class DataRetentionAction(StrEnum): """数据保留策略动作""" - ARCHIVE = "archive" # 归档 - DELETE = "delete" # 删除 - ANONYMIZE = "anonymize" # 匿名化 + ARCHIVE = "archive" # 归档 + DELETE = "delete" # 删除 + ANONYMIZE = "anonymize" # 匿名化 class ComplianceStandard(StrEnum): """合规标准""" - SOC2 = "soc2" - ISO27001 = "iso27001" - GDPR = "gdpr" - HIPAA = "hipaa" - PCI_DSS = "pci_dss" + SOC2 = "soc2" + ISO27001 = "iso27001" + GDPR = "gdpr" + HIPAA = "hipaa" + PCI_DSS = "pci_dss" @dataclass @@ -264,7 +264,7 @@ class EnterpriseManager: """企业级功能管理器""" # 默认属性映射 - DEFAULT_ATTRIBUTE_MAPPING = { + DEFAULT_ATTRIBUTE_MAPPING = { SSOProvider.WECHAT_WORK: { "email": "email", "name": "name", @@ -293,7 +293,7 @@ class EnterpriseManager: } # 合规标准字段映射 - COMPLIANCE_FIELDS = { + COMPLIANCE_FIELDS = { ComplianceStandard.SOC2: [ "timestamp", "user_id", @@ -329,21 +329,21 @@ class EnterpriseManager: ], } - def __init__(self, db_path: str = "insightflow.db") -> None: - self.db_path = db_path + def __init__(self, db_path: str = "insightflow.db") -> None: + self.db_path = db_path self._init_db() def _get_connection(self) -> sqlite3.Connection: """获取数据库连接""" - conn = sqlite3.connect(self.db_path) - conn.row_factory = sqlite3.Row + conn = sqlite3.connect(self.db_path) + conn.row_factory = sqlite3.Row return conn def _init_db(self) -> None: """初始化数据库表""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() # SSO 配置表 cursor.execute(""" @@ -582,61 +582,61 @@ class EnterpriseManager: self, tenant_id: str, provider: str, - entity_id: str | None = None, - sso_url: str | None = None, - slo_url: str | None = None, - certificate: str | None = None, - metadata_url: str | None = None, - metadata_xml: str | None = None, - client_id: str | None = None, - client_secret: str | None = None, - authorization_url: str | None = None, - token_url: str | None = None, - userinfo_url: str | None = None, - scopes: list[str] | None = None, - attribute_mapping: dict[str, str] | None = None, - auto_provision: bool = True, - default_role: str = "member", - domain_restriction: list[str] | None = None, + entity_id: str | None = None, + sso_url: str | None = None, + slo_url: str | None = None, + certificate: str | None = None, + metadata_url: str | None = None, + metadata_xml: str | None = None, + client_id: str | None = None, + client_secret: str | None = None, + authorization_url: str | None = None, + token_url: str | None = None, + userinfo_url: str | None = None, + scopes: list[str] | None = None, + attribute_mapping: dict[str, str] | None = None, + auto_provision: bool = True, + default_role: str = "member", + domain_restriction: list[str] | None = None, ) -> SSOConfig: """创建 SSO 配置""" - conn = self._get_connection() + conn = self._get_connection() try: - config_id = str(uuid.uuid4()) - now = datetime.now() + config_id = str(uuid.uuid4()) + now = datetime.now() # 使用默认属性映射 if attribute_mapping is None and provider in self.DEFAULT_ATTRIBUTE_MAPPING: - attribute_mapping = self.DEFAULT_ATTRIBUTE_MAPPING[SSOProvider(provider)] + attribute_mapping = self.DEFAULT_ATTRIBUTE_MAPPING[SSOProvider(provider)] - config = SSOConfig( - id = config_id, - tenant_id = tenant_id, - provider = provider, - status = SSOStatus.PENDING.value, - entity_id = entity_id, - sso_url = sso_url, - slo_url = slo_url, - certificate = certificate, - metadata_url = metadata_url, - metadata_xml = metadata_xml, - client_id = client_id, - client_secret = client_secret, - authorization_url = authorization_url, - token_url = token_url, - userinfo_url = userinfo_url, - scopes = scopes or ["openid", "email", "profile"], - attribute_mapping = attribute_mapping or {}, - auto_provision = auto_provision, - default_role = default_role, - domain_restriction = domain_restriction or [], - created_at = now, - updated_at = now, - last_tested_at = None, - last_error = None, + config = SSOConfig( + id=config_id, + tenant_id=tenant_id, + provider=provider, + status=SSOStatus.PENDING.value, + entity_id=entity_id, + sso_url=sso_url, + slo_url=slo_url, + certificate=certificate, + metadata_url=metadata_url, + metadata_xml=metadata_xml, + client_id=client_id, + client_secret=client_secret, + authorization_url=authorization_url, + token_url=token_url, + userinfo_url=userinfo_url, + scopes=scopes or ["openid", "email", "profile"], + attribute_mapping=attribute_mapping or {}, + auto_provision=auto_provision, + default_role=default_role, + domain_restriction=domain_restriction or [], + created_at=now, + updated_at=now, + last_tested_at=None, + last_error=None, ) - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ INSERT INTO sso_configs @@ -685,11 +685,11 @@ class EnterpriseManager: def get_sso_config(self, config_id: str) -> SSOConfig | None: """获取 SSO 配置""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute("SELECT * FROM sso_configs WHERE id = ?", (config_id, )) - row = cursor.fetchone() + row = cursor.fetchone() if row: return self._row_to_sso_config(row) @@ -699,12 +699,12 @@ class EnterpriseManager: conn.close() def get_tenant_sso_config( - self, tenant_id: str, provider: str | None = None + self, tenant_id: str, provider: str | None = None ) -> SSOConfig | None: """获取租户的 SSO 配置""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() if provider: cursor.execute( @@ -725,7 +725,7 @@ class EnterpriseManager: (tenant_id, ), ) - row = cursor.fetchone() + row = cursor.fetchone() if row: return self._row_to_sso_config(row) @@ -736,16 +736,16 @@ class EnterpriseManager: def update_sso_config(self, config_id: str, **kwargs) -> SSOConfig | None: """更新 SSO 配置""" - conn = self._get_connection() + conn = self._get_connection() try: - config = self.get_sso_config(config_id) + config = self.get_sso_config(config_id) if not config: return None - updates = [] - params = [] + updates = [] + params = [] - allowed_fields = [ + allowed_fields = [ "entity_id", "sso_url", "slo_url", @@ -782,7 +782,7 @@ class EnterpriseManager: params.append(datetime.now()) params.append(config_id) - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( f""" UPDATE sso_configs SET {", ".join(updates)} @@ -799,9 +799,9 @@ class EnterpriseManager: def delete_sso_config(self, config_id: str) -> bool: """删除 SSO 配置""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute("DELETE FROM sso_configs WHERE id = ?", (config_id, )) conn.commit() return cursor.rowcount > 0 @@ -810,9 +810,9 @@ class EnterpriseManager: def list_sso_configs(self, tenant_id: str) -> list[SSOConfig]: """列出租户的所有 SSO 配置""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ SELECT * FROM sso_configs WHERE tenant_id = ? @@ -820,7 +820,7 @@ class EnterpriseManager: """, (tenant_id, ), ) - rows = cursor.fetchall() + rows = cursor.fetchall() return [self._row_to_sso_config(row) for row in rows] @@ -829,19 +829,19 @@ class EnterpriseManager: def generate_saml_metadata(self, config_id: str, base_url: str) -> str: """生成 SAML Service Provider 元数据""" - config = self.get_sso_config(config_id) + config = self.get_sso_config(config_id) if not config: raise ValueError(f"SSO config {config_id} not found") # 生成 SP 实体 ID - sp_entity_id = f"{base_url}/api/v1/sso/saml/{config.tenant_id}" - acs_url = f"{base_url}/api/v1/sso/saml/{config.tenant_id}/acs" - slo_url = f"{base_url}/api/v1/sso/saml/{config.tenant_id}/slo" + sp_entity_id = f"{base_url}/api/v1/sso/saml/{config.tenant_id}" + acs_url = f"{base_url}/api/v1/sso/saml/{config.tenant_id}/acs" + slo_url = f"{base_url}/api/v1/sso/saml/{config.tenant_id}/slo" # 生成 X.509 证书(简化实现,实际应该生成真实的密钥对) - cert = config.certificate or self._generate_self_signed_cert() + cert = config.certificate or self._generate_self_signed_cert() - metadata = f""" + metadata = f""" SAMLAuthRequest: """创建 SAML 认证请求""" - conn = self._get_connection() + conn = self._get_connection() try: - request_id = f"_{uuid.uuid4().hex}" - now = datetime.now() - expires = now + timedelta(minutes = 10) + request_id = f"_{uuid.uuid4().hex}" + now = datetime.now() + expires = now + timedelta(minutes=10) - auth_request = SAMLAuthRequest( - id = str(uuid.uuid4()), - tenant_id = tenant_id, - sso_config_id = config_id, - request_id = request_id, - relay_state = relay_state, - created_at = now, - expires_at = expires, - used = False, - used_at = None, + auth_request = SAMLAuthRequest( + id=str(uuid.uuid4()), + tenant_id=tenant_id, + sso_config_id=config_id, + request_id=request_id, + relay_state=relay_state, + created_at=now, + expires_at=expires, + used=False, + used_at=None, ) - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ INSERT INTO saml_auth_requests @@ -919,16 +919,16 @@ class EnterpriseManager: def get_saml_auth_request(self, request_id: str) -> SAMLAuthRequest | None: """获取 SAML 认证请求""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ SELECT * FROM saml_auth_requests WHERE request_id = ? """, (request_id, ), ) - row = cursor.fetchone() + row = cursor.fetchone() if row: return self._row_to_saml_request(row) @@ -942,27 +942,27 @@ class EnterpriseManager: # 这里应该实现实际的 SAML 响应解析 # 简化实现:假设响应已经验证并解析 - conn = self._get_connection() + conn = self._get_connection() try: # 解析 SAML Response(简化) # 实际应该使用 python-saml 或类似库 - attributes = self._parse_saml_response(saml_response) + attributes = self._parse_saml_response(saml_response) - auth_response = SAMLAuthResponse( - id = str(uuid.uuid4()), - request_id = request_id, - tenant_id = "", # 从 request 获取 - user_id = None, - email = attributes.get("email"), - name = attributes.get("name"), - attributes = attributes, - session_index = attributes.get("session_index"), - processed = False, - processed_at = None, - created_at = datetime.now(), + auth_response = SAMLAuthResponse( + id=str(uuid.uuid4()), + request_id=request_id, + tenant_id="", # 从 request 获取 + user_id=None, + email=attributes.get("email"), + name=attributes.get("name"), + attributes=attributes, + session_index=attributes.get("session_index"), + processed=False, + processed_at=None, + created_at=datetime.now(), ) - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ INSERT INTO saml_auth_responses @@ -1017,35 +1017,35 @@ class EnterpriseManager: provider: str, scim_base_url: str, scim_token: str, - sync_interval_minutes: int = 60, - attribute_mapping: dict[str, str] | None = None, - sync_rules: dict[str, Any] | None = None, + sync_interval_minutes: int = 60, + attribute_mapping: dict[str, str] | None = None, + sync_rules: dict[str, Any] | None = None, ) -> SCIMConfig: """创建 SCIM 配置""" - conn = self._get_connection() + conn = self._get_connection() try: - config_id = str(uuid.uuid4()) - now = datetime.now() + config_id = str(uuid.uuid4()) + now = datetime.now() - config = SCIMConfig( - id = config_id, - tenant_id = tenant_id, - provider = provider, - status = "disabled", - scim_base_url = scim_base_url, - scim_token = scim_token, - sync_interval_minutes = sync_interval_minutes, - last_sync_at = None, - last_sync_status = None, - last_sync_error = None, - last_sync_users_count = 0, - attribute_mapping = attribute_mapping or {}, - sync_rules = sync_rules or {}, - created_at = now, - updated_at = now, + config = SCIMConfig( + id=config_id, + tenant_id=tenant_id, + provider=provider, + status="disabled", + scim_base_url=scim_base_url, + scim_token=scim_token, + sync_interval_minutes=sync_interval_minutes, + last_sync_at=None, + last_sync_status=None, + last_sync_error=None, + last_sync_users_count=0, + attribute_mapping=attribute_mapping or {}, + sync_rules=sync_rules or {}, + created_at=now, + updated_at=now, ) - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ INSERT INTO scim_configs @@ -1081,11 +1081,11 @@ class EnterpriseManager: def get_scim_config(self, config_id: str) -> SCIMConfig | None: """获取 SCIM 配置""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute("SELECT * FROM scim_configs WHERE id = ?", (config_id, )) - row = cursor.fetchone() + row = cursor.fetchone() if row: return self._row_to_scim_config(row) @@ -1096,9 +1096,9 @@ class EnterpriseManager: def get_tenant_scim_config(self, tenant_id: str) -> SCIMConfig | None: """获取租户的 SCIM 配置""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ SELECT * FROM scim_configs WHERE tenant_id = ? @@ -1106,7 +1106,7 @@ class EnterpriseManager: """, (tenant_id, ), ) - row = cursor.fetchone() + row = cursor.fetchone() if row: return self._row_to_scim_config(row) @@ -1117,16 +1117,16 @@ class EnterpriseManager: def update_scim_config(self, config_id: str, **kwargs) -> SCIMConfig | None: """更新 SCIM 配置""" - conn = self._get_connection() + conn = self._get_connection() try: - config = self.get_scim_config(config_id) + config = self.get_scim_config(config_id) if not config: return None - updates = [] - params = [] + updates = [] + params = [] - allowed_fields = [ + allowed_fields = [ "scim_base_url", "scim_token", "sync_interval_minutes", @@ -1150,7 +1150,7 @@ class EnterpriseManager: params.append(datetime.now()) params.append(config_id) - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( f""" UPDATE scim_configs SET {", ".join(updates)} @@ -1167,16 +1167,16 @@ class EnterpriseManager: def sync_scim_users(self, config_id: str) -> dict[str, Any]: """执行 SCIM 用户同步""" - config = self.get_scim_config(config_id) + config = self.get_scim_config(config_id) if not config: raise ValueError(f"SCIM config {config_id} not found") - conn = self._get_connection() + conn = self._get_connection() try: - now = datetime.now() + now = datetime.now() # 更新同步状态 - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ UPDATE scim_configs @@ -1190,9 +1190,9 @@ class EnterpriseManager: try: # 模拟从 SCIM 服务端获取用户 # 实际应该使用 HTTP 请求获取 - users = self._fetch_scim_users(config) + users = self._fetch_scim_users(config) - synced_count = 0 + synced_count = 0 for user_data in users: self._upsert_scim_user(conn, config.tenant_id, user_data) synced_count += 1 @@ -1238,17 +1238,17 @@ class EnterpriseManager: self, conn: sqlite3.Connection, tenant_id: str, user_data: dict[str, Any] ) -> None: """插入或更新 SCIM 用户""" - cursor = conn.cursor() + cursor = conn.cursor() - external_id = user_data.get("id") - user_name = user_data.get("userName", "") - email = user_data.get("emails", [{}])[0].get("value", "") - display_name = user_data.get("displayName") - name = user_data.get("name", {}) - given_name = name.get("givenName") - family_name = name.get("familyName") - active = user_data.get("active", True) - groups = [g.get("value") for g in user_data.get("groups", [])] + external_id = user_data.get("id") + user_name = user_data.get("userName", "") + email = user_data.get("emails", [{}])[0].get("value", "") + display_name = user_data.get("displayName") + name = user_data.get("name", {}) + given_name = name.get("givenName") + family_name = name.get("familyName") + active = user_data.get("active", True) + groups = [g.get("value") for g in user_data.get("groups", [])] cursor.execute( """ @@ -1284,14 +1284,14 @@ class EnterpriseManager: ), ) - def list_scim_users(self, tenant_id: str, active_only: bool = True) -> list[SCIMUser]: + def list_scim_users(self, tenant_id: str, active_only: bool = True) -> list[SCIMUser]: """列出 SCIM 用户""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() - query = "SELECT * FROM scim_users WHERE tenant_id = ?" - params = [tenant_id] + query = "SELECT * FROM scim_users WHERE tenant_id = ?" + params = [tenant_id] if active_only: query += " AND active = 1" @@ -1299,7 +1299,7 @@ class EnterpriseManager: query += " ORDER BY synced_at DESC" cursor.execute(query, params) - rows = cursor.fetchall() + rows = cursor.fetchall() return [self._row_to_scim_user(row) for row in rows] @@ -1315,41 +1315,41 @@ class EnterpriseManager: start_date: datetime, end_date: datetime, created_by: str, - filters: dict[str, Any] | None = None, - compliance_standard: str | None = None, + filters: dict[str, Any] | None = None, + compliance_standard: str | None = None, ) -> AuditLogExport: """创建审计日志导出任务""" - conn = self._get_connection() + conn = self._get_connection() try: - export_id = str(uuid.uuid4()) - now = datetime.now() + export_id = str(uuid.uuid4()) + now = datetime.now() # 默认7天后过期 - expires_at = now + timedelta(days = 7) + expires_at = now + timedelta(days=7) - export = AuditLogExport( - id = export_id, - tenant_id = tenant_id, - export_format = export_format, - start_date = start_date, - end_date = end_date, - filters = filters or {}, - compliance_standard = compliance_standard, - status = "pending", - file_path = None, - file_size = None, - record_count = None, - checksum = None, - downloaded_by = None, - downloaded_at = None, - expires_at = expires_at, - created_by = created_by, - created_at = now, - completed_at = None, - error_message = None, + export = AuditLogExport( + id=export_id, + tenant_id=tenant_id, + export_format=export_format, + start_date=start_date, + end_date=end_date, + filters=filters or {}, + compliance_standard=compliance_standard, + status="pending", + file_path=None, + file_size=None, + record_count=None, + checksum=None, + downloaded_by=None, + downloaded_at=None, + expires_at=expires_at, + created_by=created_by, + created_at=now, + completed_at=None, + error_message=None, ) - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ INSERT INTO audit_log_exports @@ -1383,16 +1383,16 @@ class EnterpriseManager: finally: conn.close() - def process_audit_export(self, export_id: str, db_manager = None) -> AuditLogExport | None: + def process_audit_export(self, export_id: str, db_manager=None) -> AuditLogExport | None: """处理审计日志导出任务""" - export = self.get_audit_export(export_id) + export = self.get_audit_export(export_id) if not export: return None - conn = self._get_connection() + conn = self._get_connection() try: # 更新状态为处理中 - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ UPDATE audit_log_exports SET status = 'processing' @@ -1404,20 +1404,20 @@ class EnterpriseManager: try: # 获取审计日志数据 - logs = self._fetch_audit_logs( + logs = self._fetch_audit_logs( export.tenant_id, export.start_date, export.end_date, export.filters, db_manager ) # 根据合规标准过滤字段 if export.compliance_standard: - logs = self._apply_compliance_filter(logs, export.compliance_standard) + logs = self._apply_compliance_filter(logs, export.compliance_standard) # 生成导出文件 - file_path, file_size, checksum = self._generate_export_file( + file_path, file_size, checksum = self._generate_export_file( export_id, logs, export.export_format ) - now = datetime.now() + now = datetime.now() # 更新导出记录 cursor.execute( @@ -1454,7 +1454,7 @@ class EnterpriseManager: start_date: datetime, end_date: datetime, filters: dict[str, Any], - db_manager = None, + db_manager=None, ) -> list[dict[str, Any]]: """获取审计日志数据""" if db_manager is None: @@ -1468,14 +1468,14 @@ class EnterpriseManager: self, logs: list[dict[str, Any]], standard: str ) -> list[dict[str, Any]]: """应用合规标准字段过滤""" - fields = self.COMPLIANCE_FIELDS.get(ComplianceStandard(standard), []) + fields = self.COMPLIANCE_FIELDS.get(ComplianceStandard(standard), []) if not fields: return logs - filtered_logs = [] + filtered_logs = [] for log in logs: - filtered_log = {k: v for k, v in log.items() if k in fields} + filtered_log = {k: v for k, v in log.items() if k in fields} filtered_logs.append(filtered_log) return filtered_logs @@ -1487,44 +1487,44 @@ class EnterpriseManager: import hashlib import os - export_dir = "/tmp/insightflow/exports" - os.makedirs(export_dir, exist_ok = True) + export_dir = "/tmp/insightflow/exports" + os.makedirs(export_dir, exist_ok=True) - file_path = f"{export_dir}/audit_export_{export_id}.{format}" + file_path = f"{export_dir}/audit_export_{export_id}.{format}" if format == "json": - content = json.dumps(logs, ensure_ascii = False, indent = 2) - with open(file_path, "w", encoding = "utf-8") as f: + content = json.dumps(logs, ensure_ascii=False, indent=2) + with open(file_path, "w", encoding="utf-8") as f: f.write(content) elif format == "csv": import csv if logs: - with open(file_path, "w", newline = "", encoding = "utf-8") as f: - writer = csv.DictWriter(f, fieldnames = logs[0].keys()) + with open(file_path, "w", newline="", encoding="utf-8") as f: + writer = csv.DictWriter(f, fieldnames=logs[0].keys()) writer.writeheader() writer.writerows(logs) else: # 其他格式暂不支持 - content = json.dumps(logs, ensure_ascii = False) - with open(file_path, "w", encoding = "utf-8") as f: + content = json.dumps(logs, ensure_ascii=False) + with open(file_path, "w", encoding="utf-8") as f: f.write(content) - file_size = os.path.getsize(file_path) + file_size = os.path.getsize(file_path) # 计算校验和 with open(file_path, "rb") as f: - checksum = hashlib.sha256(f.read()).hexdigest() + checksum = hashlib.sha256(f.read()).hexdigest() return file_path, file_size, checksum def get_audit_export(self, export_id: str) -> AuditLogExport | None: """获取审计日志导出记录""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute("SELECT * FROM audit_log_exports WHERE id = ?", (export_id, )) - row = cursor.fetchone() + row = cursor.fetchone() if row: return self._row_to_audit_export(row) @@ -1533,11 +1533,11 @@ class EnterpriseManager: finally: conn.close() - def list_audit_exports(self, tenant_id: str, limit: int = 100) -> list[AuditLogExport]: + def list_audit_exports(self, tenant_id: str, limit: int = 100) -> list[AuditLogExport]: """列出审计日志导出记录""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ SELECT * FROM audit_log_exports @@ -1547,7 +1547,7 @@ class EnterpriseManager: """, (tenant_id, limit), ) - rows = cursor.fetchall() + rows = cursor.fetchall() return [self._row_to_audit_export(row) for row in rows] @@ -1556,9 +1556,9 @@ class EnterpriseManager: def mark_export_downloaded(self, export_id: str, downloaded_by: str) -> bool: """标记导出文件已下载""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ UPDATE audit_log_exports @@ -1581,42 +1581,42 @@ class EnterpriseManager: resource_type: str, retention_days: int, action: str, - description: str | None = None, - conditions: dict[str, Any] | None = None, - auto_execute: bool = False, - execute_at: str | None = None, - notify_before_days: int = 7, - archive_location: str | None = None, - archive_encryption: bool = True, + description: str | None = None, + conditions: dict[str, Any] | None = None, + auto_execute: bool = False, + execute_at: str | None = None, + notify_before_days: int = 7, + archive_location: str | None = None, + archive_encryption: bool = True, ) -> DataRetentionPolicy: """创建数据保留策略""" - conn = self._get_connection() + conn = self._get_connection() try: - policy_id = str(uuid.uuid4()) - now = datetime.now() + policy_id = str(uuid.uuid4()) + now = datetime.now() - policy = DataRetentionPolicy( - id = policy_id, - tenant_id = tenant_id, - name = name, - description = description, - resource_type = resource_type, - retention_days = retention_days, - action = action, - conditions = conditions or {}, - auto_execute = auto_execute, - execute_at = execute_at, - notify_before_days = notify_before_days, - archive_location = archive_location, - archive_encryption = archive_encryption, - is_active = True, - last_executed_at = None, - last_execution_result = None, - created_at = now, - updated_at = now, + policy = DataRetentionPolicy( + id=policy_id, + tenant_id=tenant_id, + name=name, + description=description, + resource_type=resource_type, + retention_days=retention_days, + action=action, + conditions=conditions or {}, + auto_execute=auto_execute, + execute_at=execute_at, + notify_before_days=notify_before_days, + archive_location=archive_location, + archive_encryption=archive_encryption, + is_active=True, + last_executed_at=None, + last_execution_result=None, + created_at=now, + updated_at=now, ) - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ INSERT INTO data_retention_policies @@ -1658,11 +1658,11 @@ class EnterpriseManager: def get_retention_policy(self, policy_id: str) -> DataRetentionPolicy | None: """获取数据保留策略""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute("SELECT * FROM data_retention_policies WHERE id = ?", (policy_id, )) - row = cursor.fetchone() + row = cursor.fetchone() if row: return self._row_to_retention_policy(row) @@ -1672,15 +1672,15 @@ class EnterpriseManager: conn.close() def list_retention_policies( - self, tenant_id: str, resource_type: str | None = None + self, tenant_id: str, resource_type: str | None = None ) -> list[DataRetentionPolicy]: """列出数据保留策略""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() - query = "SELECT * FROM data_retention_policies WHERE tenant_id = ?" - params = [tenant_id] + query = "SELECT * FROM data_retention_policies WHERE tenant_id = ?" + params = [tenant_id] if resource_type: query += " AND resource_type = ?" @@ -1689,7 +1689,7 @@ class EnterpriseManager: query += " ORDER BY created_at DESC" cursor.execute(query, params) - rows = cursor.fetchall() + rows = cursor.fetchall() return [self._row_to_retention_policy(row) for row in rows] @@ -1698,16 +1698,16 @@ class EnterpriseManager: def update_retention_policy(self, policy_id: str, **kwargs) -> DataRetentionPolicy | None: """更新数据保留策略""" - conn = self._get_connection() + conn = self._get_connection() try: - policy = self.get_retention_policy(policy_id) + policy = self.get_retention_policy(policy_id) if not policy: return None - updates = [] - params = [] + updates = [] + params = [] - allowed_fields = [ + allowed_fields = [ "name", "description", "retention_days", @@ -1738,7 +1738,7 @@ class EnterpriseManager: params.append(datetime.now()) params.append(policy_id) - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( f""" UPDATE data_retention_policies SET {", ".join(updates)} @@ -1755,9 +1755,9 @@ class EnterpriseManager: def delete_retention_policy(self, policy_id: str) -> bool: """删除数据保留策略""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute("DELETE FROM data_retention_policies WHERE id = ?", (policy_id, )) conn.commit() return cursor.rowcount > 0 @@ -1766,31 +1766,31 @@ class EnterpriseManager: def execute_retention_policy(self, policy_id: str) -> DataRetentionJob: """执行数据保留策略""" - policy = self.get_retention_policy(policy_id) + policy = self.get_retention_policy(policy_id) if not policy: raise ValueError(f"Retention policy {policy_id} not found") - conn = self._get_connection() + conn = self._get_connection() try: - job_id = str(uuid.uuid4()) - now = datetime.now() + job_id = str(uuid.uuid4()) + now = datetime.now() - job = DataRetentionJob( - id = job_id, - policy_id = policy_id, - tenant_id = policy.tenant_id, - status = "running", - started_at = now, - completed_at = None, - affected_records = 0, - archived_records = 0, - deleted_records = 0, - error_count = 0, - details = {}, - created_at = now, + job = DataRetentionJob( + id=job_id, + policy_id=policy_id, + tenant_id=policy.tenant_id, + status="running", + started_at=now, + completed_at=None, + affected_records=0, + archived_records=0, + deleted_records=0, + error_count=0, + details={}, + created_at=now, ) - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ INSERT INTO data_retention_jobs @@ -1804,17 +1804,17 @@ class EnterpriseManager: try: # 计算截止日期 - cutoff_date = now - timedelta(days = policy.retention_days) + cutoff_date = now - timedelta(days=policy.retention_days) # 根据资源类型执行不同的处理 if policy.resource_type == "audit_log": - result = self._retain_audit_logs(conn, policy, cutoff_date) + result = self._retain_audit_logs(conn, policy, cutoff_date) elif policy.resource_type == "project": - result = self._retain_projects(conn, policy, cutoff_date) + result = self._retain_projects(conn, policy, cutoff_date) elif policy.resource_type == "transcript": - result = self._retain_transcripts(conn, policy, cutoff_date) + result = self._retain_transcripts(conn, policy, cutoff_date) else: - result = {"affected": 0, "archived": 0, "deleted": 0, "errors": 0} + result = {"affected": 0, "archived": 0, "deleted": 0, "errors": 0} # 更新任务状态 cursor.execute( @@ -1879,7 +1879,7 @@ class EnterpriseManager: self, conn: sqlite3.Connection, policy: DataRetentionPolicy, cutoff_date: datetime ) -> dict[str, int]: """保留审计日志""" - cursor = conn.cursor() + cursor = conn.cursor() # 获取符合条件的记录数 cursor.execute( @@ -1889,7 +1889,7 @@ class EnterpriseManager: """, (cutoff_date, ), ) - count = cursor.fetchone()["count"] + count = cursor.fetchone()["count"] if policy.action == DataRetentionAction.DELETE.value: cursor.execute( @@ -1898,12 +1898,12 @@ class EnterpriseManager: """, (cutoff_date, ), ) - deleted = cursor.rowcount + deleted = cursor.rowcount return {"affected": count, "archived": 0, "deleted": deleted, "errors": 0} elif policy.action == DataRetentionAction.ARCHIVE.value: # 归档逻辑(简化实现) - archived = count + archived = count return {"affected": count, "archived": archived, "deleted": 0, "errors": 0} return {"affected": 0, "archived": 0, "deleted": 0, "errors": 0} @@ -1924,11 +1924,11 @@ class EnterpriseManager: def get_retention_job(self, job_id: str) -> DataRetentionJob | None: """获取数据保留任务""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute("SELECT * FROM data_retention_jobs WHERE id = ?", (job_id, )) - row = cursor.fetchone() + row = cursor.fetchone() if row: return self._row_to_retention_job(row) @@ -1937,11 +1937,11 @@ class EnterpriseManager: finally: conn.close() - def list_retention_jobs(self, policy_id: str, limit: int = 100) -> list[DataRetentionJob]: + def list_retention_jobs(self, policy_id: str, limit: int = 100) -> list[DataRetentionJob]: """列出数据保留任务""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ SELECT * FROM data_retention_jobs @@ -1951,7 +1951,7 @@ class EnterpriseManager: """, (policy_id, limit), ) - rows = cursor.fetchall() + rows = cursor.fetchall() return [self._row_to_retention_job(row) for row in rows] @@ -1963,64 +1963,64 @@ class EnterpriseManager: def _row_to_sso_config(self, row: sqlite3.Row) -> SSOConfig: """数据库行转换为 SSOConfig 对象""" return SSOConfig( - id = row["id"], - tenant_id = row["tenant_id"], - provider = row["provider"], - status = row["status"], - entity_id = row["entity_id"], - sso_url = row["sso_url"], - slo_url = row["slo_url"], - certificate = row["certificate"], - metadata_url = row["metadata_url"], - metadata_xml = row["metadata_xml"], - client_id = row["client_id"], - client_secret = row["client_secret"], - authorization_url = row["authorization_url"], - token_url = row["token_url"], - userinfo_url = row["userinfo_url"], - scopes = json.loads(row["scopes"] or '["openid", "email", "profile"]'), - attribute_mapping = json.loads(row["attribute_mapping"] or "{}"), - auto_provision = bool(row["auto_provision"]), - default_role = row["default_role"], - domain_restriction = json.loads(row["domain_restriction"] or "[]"), - created_at = ( + id=row["id"], + tenant_id=row["tenant_id"], + provider=row["provider"], + status=row["status"], + entity_id=row["entity_id"], + sso_url=row["sso_url"], + slo_url=row["slo_url"], + certificate=row["certificate"], + metadata_url=row["metadata_url"], + metadata_xml=row["metadata_xml"], + client_id=row["client_id"], + client_secret=row["client_secret"], + authorization_url=row["authorization_url"], + token_url=row["token_url"], + userinfo_url=row["userinfo_url"], + scopes=json.loads(row["scopes"] or '["openid", "email", "profile"]'), + attribute_mapping=json.loads(row["attribute_mapping"] or "{}"), + auto_provision=bool(row["auto_provision"]), + default_role=row["default_role"], + domain_restriction=json.loads(row["domain_restriction"] or "[]"), + created_at=( datetime.fromisoformat(row["created_at"]) if isinstance(row["created_at"], str) else row["created_at"] ), - updated_at = ( + updated_at=( datetime.fromisoformat(row["updated_at"]) if isinstance(row["updated_at"], str) else row["updated_at"] ), - last_tested_at = ( + last_tested_at=( datetime.fromisoformat(row["last_tested_at"]) if row["last_tested_at"] and isinstance(row["last_tested_at"], str) else row["last_tested_at"] ), - last_error = row["last_error"], + last_error=row["last_error"], ) def _row_to_saml_request(self, row: sqlite3.Row) -> SAMLAuthRequest: """数据库行转换为 SAMLAuthRequest 对象""" return SAMLAuthRequest( - id = row["id"], - tenant_id = row["tenant_id"], - sso_config_id = row["sso_config_id"], - request_id = row["request_id"], - relay_state = row["relay_state"], - created_at = ( + id=row["id"], + tenant_id=row["tenant_id"], + sso_config_id=row["sso_config_id"], + request_id=row["request_id"], + relay_state=row["relay_state"], + created_at=( datetime.fromisoformat(row["created_at"]) if isinstance(row["created_at"], str) else row["created_at"] ), - expires_at = ( + expires_at=( datetime.fromisoformat(row["expires_at"]) if isinstance(row["expires_at"], str) else row["expires_at"] ), - used = bool(row["used"]), - used_at = ( + used=bool(row["used"]), + used_at=( datetime.fromisoformat(row["used_at"]) if row["used_at"] and isinstance(row["used_at"], str) else row["used_at"] @@ -2030,29 +2030,29 @@ class EnterpriseManager: def _row_to_scim_config(self, row: sqlite3.Row) -> SCIMConfig: """数据库行转换为 SCIMConfig 对象""" return SCIMConfig( - id = row["id"], - tenant_id = row["tenant_id"], - provider = row["provider"], - status = row["status"], - scim_base_url = row["scim_base_url"], - scim_token = row["scim_token"], - sync_interval_minutes = row["sync_interval_minutes"], - last_sync_at = ( + id=row["id"], + tenant_id=row["tenant_id"], + provider=row["provider"], + status=row["status"], + scim_base_url=row["scim_base_url"], + scim_token=row["scim_token"], + sync_interval_minutes=row["sync_interval_minutes"], + last_sync_at=( datetime.fromisoformat(row["last_sync_at"]) if row["last_sync_at"] and isinstance(row["last_sync_at"], str) else row["last_sync_at"] ), - last_sync_status = row["last_sync_status"], - last_sync_error = row["last_sync_error"], - last_sync_users_count = row["last_sync_users_count"], - attribute_mapping = json.loads(row["attribute_mapping"] or "{}"), - sync_rules = json.loads(row["sync_rules"] or "{}"), - created_at = ( + last_sync_status=row["last_sync_status"], + last_sync_error=row["last_sync_error"], + last_sync_users_count=row["last_sync_users_count"], + attribute_mapping=json.loads(row["attribute_mapping"] or "{}"), + sync_rules=json.loads(row["sync_rules"] or "{}"), + created_at=( datetime.fromisoformat(row["created_at"]) if isinstance(row["created_at"], str) else row["created_at"] ), - updated_at = ( + updated_at=( datetime.fromisoformat(row["updated_at"]) if isinstance(row["updated_at"], str) else row["updated_at"] @@ -2062,28 +2062,28 @@ class EnterpriseManager: def _row_to_scim_user(self, row: sqlite3.Row) -> SCIMUser: """数据库行转换为 SCIMUser 对象""" return SCIMUser( - id = row["id"], - tenant_id = row["tenant_id"], - external_id = row["external_id"], - user_name = row["user_name"], - email = row["email"], - display_name = row["display_name"], - given_name = row["given_name"], - family_name = row["family_name"], - active = bool(row["active"]), - groups = json.loads(row["groups"] or "[]"), - raw_data = json.loads(row["raw_data"] or "{}"), - synced_at = ( + id=row["id"], + tenant_id=row["tenant_id"], + external_id=row["external_id"], + user_name=row["user_name"], + email=row["email"], + display_name=row["display_name"], + given_name=row["given_name"], + family_name=row["family_name"], + active=bool(row["active"]), + groups=json.loads(row["groups"] or "[]"), + raw_data=json.loads(row["raw_data"] or "{}"), + synced_at=( datetime.fromisoformat(row["synced_at"]) if isinstance(row["synced_at"], str) else row["synced_at"] ), - created_at = ( + created_at=( datetime.fromisoformat(row["created_at"]) if isinstance(row["created_at"], str) else row["created_at"] ), - updated_at = ( + updated_at=( datetime.fromisoformat(row["updated_at"]) if isinstance(row["updated_at"], str) else row["updated_at"] @@ -2093,78 +2093,78 @@ class EnterpriseManager: def _row_to_audit_export(self, row: sqlite3.Row) -> AuditLogExport: """数据库行转换为 AuditLogExport 对象""" return AuditLogExport( - id = row["id"], - tenant_id = row["tenant_id"], - export_format = row["export_format"], - start_date = ( + id=row["id"], + tenant_id=row["tenant_id"], + export_format=row["export_format"], + start_date=( datetime.fromisoformat(row["start_date"]) if isinstance(row["start_date"], str) else row["start_date"] ), - end_date = datetime.fromisoformat(row["end_date"]) + end_date=datetime.fromisoformat(row["end_date"]) if isinstance(row["end_date"], str) else row["end_date"], - filters = json.loads(row["filters"] or "{}"), - compliance_standard = row["compliance_standard"], - status = row["status"], - file_path = row["file_path"], - file_size = row["file_size"], - record_count = row["record_count"], - checksum = row["checksum"], - downloaded_by = row["downloaded_by"], - downloaded_at = ( + filters=json.loads(row["filters"] or "{}"), + compliance_standard=row["compliance_standard"], + status=row["status"], + file_path=row["file_path"], + file_size=row["file_size"], + record_count=row["record_count"], + checksum=row["checksum"], + downloaded_by=row["downloaded_by"], + downloaded_at=( datetime.fromisoformat(row["downloaded_at"]) if row["downloaded_at"] and isinstance(row["downloaded_at"], str) else row["downloaded_at"] ), - expires_at = ( + expires_at=( datetime.fromisoformat(row["expires_at"]) if isinstance(row["expires_at"], str) else row["expires_at"] ), - created_by = row["created_by"], - created_at = ( + created_by=row["created_by"], + created_at=( datetime.fromisoformat(row["created_at"]) if isinstance(row["created_at"], str) else row["created_at"] ), - completed_at = ( + completed_at=( datetime.fromisoformat(row["completed_at"]) if row["completed_at"] and isinstance(row["completed_at"], str) else row["completed_at"] ), - error_message = row["error_message"], + error_message=row["error_message"], ) def _row_to_retention_policy(self, row: sqlite3.Row) -> DataRetentionPolicy: """数据库行转换为 DataRetentionPolicy 对象""" return DataRetentionPolicy( - id = row["id"], - tenant_id = row["tenant_id"], - name = row["name"], - description = row["description"], - resource_type = row["resource_type"], - retention_days = row["retention_days"], - action = row["action"], - conditions = json.loads(row["conditions"] or "{}"), - auto_execute = bool(row["auto_execute"]), - execute_at = row["execute_at"], - notify_before_days = row["notify_before_days"], - archive_location = row["archive_location"], - archive_encryption = bool(row["archive_encryption"]), - is_active = bool(row["is_active"]), - last_executed_at = ( + id=row["id"], + tenant_id=row["tenant_id"], + name=row["name"], + description=row["description"], + resource_type=row["resource_type"], + retention_days=row["retention_days"], + action=row["action"], + conditions=json.loads(row["conditions"] or "{}"), + auto_execute=bool(row["auto_execute"]), + execute_at=row["execute_at"], + notify_before_days=row["notify_before_days"], + archive_location=row["archive_location"], + archive_encryption=bool(row["archive_encryption"]), + is_active=bool(row["is_active"]), + last_executed_at=( datetime.fromisoformat(row["last_executed_at"]) if row["last_executed_at"] and isinstance(row["last_executed_at"], str) else row["last_executed_at"] ), - last_execution_result = row["last_execution_result"], - created_at = ( + last_execution_result=row["last_execution_result"], + created_at=( datetime.fromisoformat(row["created_at"]) if isinstance(row["created_at"], str) else row["created_at"] ), - updated_at = ( + updated_at=( datetime.fromisoformat(row["updated_at"]) if isinstance(row["updated_at"], str) else row["updated_at"] @@ -2174,26 +2174,26 @@ class EnterpriseManager: def _row_to_retention_job(self, row: sqlite3.Row) -> DataRetentionJob: """数据库行转换为 DataRetentionJob 对象""" return DataRetentionJob( - id = row["id"], - policy_id = row["policy_id"], - tenant_id = row["tenant_id"], - status = row["status"], - started_at = ( + id=row["id"], + policy_id=row["policy_id"], + tenant_id=row["tenant_id"], + status=row["status"], + started_at=( datetime.fromisoformat(row["started_at"]) if row["started_at"] and isinstance(row["started_at"], str) else row["started_at"] ), - completed_at = ( + completed_at=( datetime.fromisoformat(row["completed_at"]) if row["completed_at"] and isinstance(row["completed_at"], str) else row["completed_at"] ), - affected_records = row["affected_records"], - archived_records = row["archived_records"], - deleted_records = row["deleted_records"], - error_count = row["error_count"], - details = json.loads(row["details"] or "{}"), - created_at = ( + affected_records=row["affected_records"], + archived_records=row["archived_records"], + deleted_records=row["deleted_records"], + error_count=row["error_count"], + details=json.loads(row["details"] or "{}"), + created_at=( datetime.fromisoformat(row["created_at"]) if isinstance(row["created_at"], str) else row["created_at"] @@ -2202,12 +2202,12 @@ class EnterpriseManager: # 全局实例 -_enterprise_manager = None +_enterprise_manager = None -def get_enterprise_manager(db_path: str = "insightflow.db") -> EnterpriseManager: +def get_enterprise_manager(db_path: str = "insightflow.db") -> EnterpriseManager: """获取 EnterpriseManager 单例""" global _enterprise_manager if _enterprise_manager is None: - _enterprise_manager = EnterpriseManager(db_path) + _enterprise_manager = EnterpriseManager(db_path) return _enterprise_manager diff --git a/backend/entity_aligner.py b/backend/entity_aligner.py index e1999d1..b43294e 100644 --- a/backend/entity_aligner.py +++ b/backend/entity_aligner.py @@ -12,8 +12,8 @@ import httpx import numpy as np # API Keys -KIMI_API_KEY = os.getenv("KIMI_API_KEY", "") -KIMI_BASE_URL = os.getenv("KIMI_BASE_URL", "https://api.kimi.com/coding") +KIMI_API_KEY = os.getenv("KIMI_API_KEY", "") +KIMI_BASE_URL = os.getenv("KIMI_BASE_URL", "https://api.kimi.com/coding") @dataclass @@ -27,9 +27,9 @@ class EntityEmbedding: class EntityAligner: """实体对齐器 - 使用 embedding 进行相似度匹配""" - def __init__(self, similarity_threshold: float = 0.85) -> None: - self.similarity_threshold = similarity_threshold - self.embedding_cache: dict[str, list[float]] = {} + def __init__(self, similarity_threshold: float = 0.85) -> None: + self.similarity_threshold = similarity_threshold + self.embedding_cache: dict[str, list[float]] = {} def get_embedding(self, text: str) -> list[float] | None: """ @@ -45,25 +45,25 @@ class EntityAligner: return None # 检查缓存 - cache_key = hash(text) + cache_key = hash(text) if cache_key in self.embedding_cache: return self.embedding_cache[cache_key] try: - response = httpx.post( + response = httpx.post( f"{KIMI_BASE_URL}/v1/embeddings", - headers = { + headers={ "Authorization": f"Bearer {KIMI_API_KEY}", "Content-Type": "application/json", }, - json = {"model": "k2p5", "input": text[:500]}, # 限制长度 - timeout = 30.0, + json={"model": "k2p5", "input": text[:500]}, # 限制长度 + timeout=30.0, ) response.raise_for_status() - result = response.json() + result = response.json() - embedding = result["data"][0]["embedding"] - self.embedding_cache[cache_key] = embedding + embedding = result["data"][0]["embedding"] + self.embedding_cache[cache_key] = embedding return embedding except (httpx.HTTPError, json.JSONDecodeError, KeyError) as e: @@ -81,20 +81,20 @@ class EntityAligner: Returns: 相似度分数 (0-1) """ - vec1 = np.array(embedding1) - vec2 = np.array(embedding2) + vec1 = np.array(embedding1) + vec2 = np.array(embedding2) # 余弦相似度 - dot_product = np.dot(vec1, vec2) - norm1 = np.linalg.norm(vec1) - norm2 = np.linalg.norm(vec2) + dot_product = np.dot(vec1, vec2) + norm1 = np.linalg.norm(vec1) + norm2 = np.linalg.norm(vec2) if norm1 == 0 or norm2 == 0: return 0.0 return float(dot_product / (norm1 * norm2)) - def get_entity_text(self, name: str, definition: str = "") -> str: + def get_entity_text(self, name: str, definition: str = "") -> str: """ 构建用于 embedding 的实体文本 @@ -113,9 +113,9 @@ class EntityAligner: self, project_id: str, name: str, - definition: str = "", - exclude_id: str | None = None, - threshold: float | None = None, + definition: str = "", + exclude_id: str | None = None, + threshold: float | None = None, ) -> object | None: """ 查找相似的实体 @@ -131,54 +131,54 @@ class EntityAligner: 相似的实体或 None """ if threshold is None: - threshold = self.similarity_threshold + threshold = self.similarity_threshold try: from db_manager import get_db_manager - db = get_db_manager() + db = get_db_manager() except ImportError: return None # 获取项目的所有实体 - entities = db.get_all_entities_for_embedding(project_id) + entities = db.get_all_entities_for_embedding(project_id) if not entities: return None # 获取查询实体的 embedding - query_text = self.get_entity_text(name, definition) - query_embedding = self.get_embedding(query_text) + query_text = self.get_entity_text(name, definition) + query_embedding = self.get_embedding(query_text) if query_embedding is None: # 如果 embedding API 失败,回退到简单匹配 return self._fallback_similarity_match(entities, name, exclude_id) - best_match = None - best_score = threshold + best_match = None + best_score = threshold for entity in entities: if exclude_id and entity.id == exclude_id: continue # 获取实体的 embedding - entity_text = self.get_entity_text(entity.name, entity.definition) - entity_embedding = self.get_embedding(entity_text) + entity_text = self.get_entity_text(entity.name, entity.definition) + entity_embedding = self.get_embedding(entity_text) if entity_embedding is None: continue # 计算相似度 - similarity = self.compute_similarity(query_embedding, entity_embedding) + similarity = self.compute_similarity(query_embedding, entity_embedding) if similarity > best_score: - best_score = similarity - best_match = entity + best_score = similarity + best_match = entity return best_match def _fallback_similarity_match( - self, entities: list[object], name: str, exclude_id: str | None = None + self, entities: list[object], name: str, exclude_id: str | None = None ) -> object | None: """ 回退到简单的相似度匹配(不使用 embedding) @@ -191,7 +191,7 @@ class EntityAligner: Returns: 最相似的实体或 None """ - name_lower = name.lower() + name_lower = name.lower() # 1. 精确匹配 for entity in entities: @@ -212,7 +212,7 @@ class EntityAligner: return None def batch_align_entities( - self, project_id: str, new_entities: list[dict], threshold: float | None = None + self, project_id: str, new_entities: list[dict], threshold: float | None = None ) -> list[dict]: """ 批量对齐实体 @@ -226,16 +226,16 @@ class EntityAligner: 对齐结果列表 [{"new_entity": {...}, "matched_entity": {...}, "similarity": 0.9}] """ if threshold is None: - threshold = self.similarity_threshold + threshold = self.similarity_threshold - results = [] + results = [] for new_ent in new_entities: - matched = self.find_similar_entity( - project_id, new_ent["name"], new_ent.get("definition", ""), threshold = threshold + matched = self.find_similar_entity( + project_id, new_ent["name"], new_ent.get("definition", ""), threshold=threshold ) - result = { + result = { "new_entity": new_ent, "matched_entity": None, "similarity": 0.0, @@ -244,28 +244,28 @@ class EntityAligner: if matched: # 计算相似度 - query_text = self.get_entity_text(new_ent["name"], new_ent.get("definition", "")) - matched_text = self.get_entity_text(matched.name, matched.definition) + query_text = self.get_entity_text(new_ent["name"], new_ent.get("definition", "")) + matched_text = self.get_entity_text(matched.name, matched.definition) - query_emb = self.get_embedding(query_text) - matched_emb = self.get_embedding(matched_text) + query_emb = self.get_embedding(query_text) + matched_emb = self.get_embedding(matched_text) if query_emb and matched_emb: - similarity = self.compute_similarity(query_emb, matched_emb) - result["matched_entity"] = { + similarity = self.compute_similarity(query_emb, matched_emb) + result["matched_entity"] = { "id": matched.id, "name": matched.name, "type": matched.type, "definition": matched.definition, } - result["similarity"] = similarity - result["should_merge"] = similarity >= threshold + result["similarity"] = similarity + result["should_merge"] = similarity >= threshold results.append(result) return results - def suggest_entity_aliases(self, entity_name: str, entity_definition: str = "") -> list[str]: + def suggest_entity_aliases(self, entity_name: str, entity_definition: str = "") -> list[str]: """ 使用 LLM 建议实体的别名 @@ -279,7 +279,7 @@ class EntityAligner: if not KIMI_API_KEY: return [] - prompt = f"""为以下实体生成可能的别名或简称: + prompt = f"""为以下实体生成可能的别名或简称: 实体名称:{entity_name} 定义:{entity_definition} @@ -290,28 +290,28 @@ class EntityAligner: 只返回 JSON,不要其他内容。""" try: - response = httpx.post( + response = httpx.post( f"{KIMI_BASE_URL}/v1/chat/completions", - headers = { + headers={ "Authorization": f"Bearer {KIMI_API_KEY}", "Content-Type": "application/json", }, - json = { + json={ "model": "k2p5", "messages": [{"role": "user", "content": prompt}], "temperature": 0.3, }, - timeout = 30.0, + timeout=30.0, ) response.raise_for_status() - result = response.json() - content = result["choices"][0]["message"]["content"] + result = response.json() + content = result["choices"][0]["message"]["content"] import re - json_match = re.search(r"\{{.*?\}}", content, re.DOTALL) + json_match = re.search(r"\{{.*?\}}", content, re.DOTALL) if json_match: - data = json.loads(json_match.group()) + data = json.loads(json_match.group()) return data.get("aliases", []) except (httpx.HTTPError, json.JSONDecodeError, KeyError) as e: print(f"Alias suggestion failed: {e}") @@ -340,8 +340,8 @@ def simple_similarity(str1: str, str2: str) -> float: return 0.0 # 转换为小写 - s1 = str1.lower() - s2 = str2.lower() + s1 = str1.lower() + s2 = str2.lower() # 包含关系 if s1 in s2 or s2 in s1: @@ -355,11 +355,11 @@ def simple_similarity(str1: str, str2: str) -> float: if __name__ == "__main__": # 测试 - aligner = EntityAligner() + aligner = EntityAligner() # 测试 embedding - test_text = "Kubernetes 容器编排平台" - embedding = aligner.get_embedding(test_text) + test_text = "Kubernetes 容器编排平台" + embedding = aligner.get_embedding(test_text) if embedding: print(f"Embedding dimension: {len(embedding)}") print(f"First 5 values: {embedding[:5]}") @@ -367,7 +367,7 @@ if __name__ == "__main__": print("Embedding API not available") # 测试相似度计算 - emb1 = [1.0, 0.0, 0.0] - emb2 = [0.9, 0.1, 0.0] - sim = aligner.compute_similarity(emb1, emb2) + emb1 = [1.0, 0.0, 0.0] + emb2 = [0.9, 0.1, 0.0] + sim = aligner.compute_similarity(emb1, emb2) print(f"Similarity: {sim:.4f}") diff --git a/backend/export_manager.py b/backend/export_manager.py index 8d1547a..362b1b6 100644 --- a/backend/export_manager.py +++ b/backend/export_manager.py @@ -14,9 +14,9 @@ from typing import Any try: import pandas as pd - PANDAS_AVAILABLE = True + PANDAS_AVAILABLE = True except ImportError: - PANDAS_AVAILABLE = False + PANDAS_AVAILABLE = False try: from reportlab.lib import colors @@ -32,9 +32,9 @@ try: TableStyle, ) - REPORTLAB_AVAILABLE = True + REPORTLAB_AVAILABLE = True except ImportError: - REPORTLAB_AVAILABLE = False + REPORTLAB_AVAILABLE = False @dataclass @@ -71,8 +71,8 @@ class ExportTranscript: class ExportManager: """导出管理器 - 处理各种导出需求""" - def __init__(self, db_manager = None) -> None: - self.db = db_manager + def __init__(self, db_manager=None) -> None: + self.db = db_manager def export_knowledge_graph_svg( self, project_id: str, entities: list[ExportEntity], relations: list[ExportRelation] @@ -84,21 +84,21 @@ class ExportManager: SVG 字符串 """ # 计算布局参数 - width = 1200 - height = 800 - center_x = width / 2 - center_y = height / 2 - radius = 300 + width = 1200 + height = 800 + center_x = width / 2 + center_y = height / 2 + radius = 300 # 按类型分组实体 - entities_by_type = {} + entities_by_type = {} for e in entities: if e.type not in entities_by_type: - entities_by_type[e.type] = [] + entities_by_type[e.type] = [] entities_by_type[e.type].append(e) # 颜色映射 - type_colors = { + type_colors = { "PERSON": "#FF6B6B", "ORGANIZATION": "#4ECDC4", "LOCATION": "#45B7D1", @@ -110,17 +110,17 @@ class ExportManager: } # 计算实体位置 - entity_positions = {} - angle_step = 2 * 3.14159 / max(len(entities), 1) + entity_positions = {} + angle_step = 2 * 3.14159 / max(len(entities), 1) for i, entity in enumerate(entities): i * angle_step - x = center_x + radius * 0.8 * (i % 3 - 1) * 150 + (i // 3) * 50 - y = center_y + radius * 0.6 * ((i % 6) - 3) * 80 - entity_positions[entity.id] = (x, y) + x = center_x + radius * 0.8 * (i % 3 - 1) * 150 + (i // 3) * 50 + y = center_y + radius * 0.6 * ((i % 6) - 3) * 80 + entity_positions[entity.id] = (x, y) # 生成 SVG - svg_parts = [ + svg_parts = [ f'', "", @@ -137,17 +137,17 @@ class ExportManager: # 绘制关系连线 for rel in relations: if rel.source in entity_positions and rel.target in entity_positions: - x1, y1 = entity_positions[rel.source] - x2, y2 = entity_positions[rel.target] + x1, y1 = entity_positions[rel.source] + x2, y2 = entity_positions[rel.target] # 计算箭头终点(避免覆盖节点) - dx = x2 - x1 - dy = y2 - y1 - dist = (dx**2 + dy**2) ** 0.5 + dx = x2 - x1 + dy = y2 - y1 + dist = (dx**2 + dy**2) ** 0.5 if dist > 0: - offset = 40 - x2 = x2 - dx * offset / dist - y2 = y2 - dy * offset / dist + offset = 40 + x2 = x2 - dx * offset / dist + y2 = y2 - dy * offset / dist svg_parts.append( f'' @@ -169,8 +169,8 @@ class ExportManager: # 绘制实体节点 for entity in entities: if entity.id in entity_positions: - x, y = entity_positions[entity.id] - color = type_colors.get(entity.type, type_colors["default"]) + x, y = entity_positions[entity.id] + color = type_colors.get(entity.type, type_colors["default"]) # 节点圆圈 svg_parts.append( @@ -190,11 +190,11 @@ class ExportManager: ) # 图例 - legend_x = width - 150 - legend_y = 80 - rect_x = legend_x - 10 - rect_y = legend_y - 20 - rect_height = len(type_colors) * 25 + 10 + legend_x = width - 150 + legend_y = 80 + rect_x = legend_x - 10 + rect_y = legend_y - 20 + rect_height = len(type_colors) * 25 + 10 svg_parts.append( f'' @@ -206,11 +206,11 @@ class ExportManager: for i, (etype, color) in enumerate(type_colors.items()): if etype != "default": - y_pos = legend_y + 25 + i * 20 + y_pos = legend_y + 25 + i * 20 svg_parts.append( f'' ) - text_y = y_pos + 4 + text_y = y_pos + 4 svg_parts.append( f'{etype}' @@ -231,12 +231,12 @@ class ExportManager: try: import cairosvg - svg_content = self.export_knowledge_graph_svg(project_id, entities, relations) - png_bytes = cairosvg.svg2png(bytestring = svg_content.encode("utf-8")) + svg_content = self.export_knowledge_graph_svg(project_id, entities, relations) + png_bytes = cairosvg.svg2png(bytestring=svg_content.encode("utf-8")) return png_bytes except ImportError: # 如果没有 cairosvg,返回 SVG 的 base64 - svg_content = self.export_knowledge_graph_svg(project_id, entities, relations) + svg_content = self.export_knowledge_graph_svg(project_id, entities, relations) return base64.b64encode(svg_content.encode("utf-8")) def export_entities_excel(self, entities: list[ExportEntity]) -> bytes: @@ -250,9 +250,9 @@ class ExportManager: raise ImportError("pandas is required for Excel export") # 准备数据 - data = [] + data = [] for e in entities: - row = { + row = { "ID": e.id, "名称": e.name, "类型": e.type, @@ -262,29 +262,29 @@ class ExportManager: } # 添加属性 for attr_name, attr_value in e.attributes.items(): - row[f"属性:{attr_name}"] = attr_value + row[f"属性:{attr_name}"] = attr_value data.append(row) - df = pd.DataFrame(data) + df = pd.DataFrame(data) # 写入 Excel - output = io.BytesIO() - with pd.ExcelWriter(output, engine = "openpyxl") as writer: - df.to_excel(writer, sheet_name = "实体列表", index = False) + output = io.BytesIO() + with pd.ExcelWriter(output, engine="openpyxl") as writer: + df.to_excel(writer, sheet_name="实体列表", index=False) # 调整列宽 - worksheet = writer.sheets["实体列表"] + worksheet = writer.sheets["实体列表"] for column in worksheet.columns: - max_length = 0 - column_letter = column[0].column_letter + max_length = 0 + column_letter = column[0].column_letter for cell in column: try: if len(str(cell.value)) > max_length: - max_length = len(str(cell.value)) + max_length = len(str(cell.value)) except (AttributeError, TypeError, ValueError): pass - adjusted_width = min(max_length + 2, 50) - worksheet.column_dimensions[column_letter].width = adjusted_width + adjusted_width = min(max_length + 2, 50) + worksheet.column_dimensions[column_letter].width = adjusted_width return output.getvalue() @@ -295,24 +295,24 @@ class ExportManager: Returns: CSV 字符串 """ - output = io.StringIO() + output = io.StringIO() # 收集所有可能的属性列 - all_attrs = set() + all_attrs = set() for e in entities: all_attrs.update(e.attributes.keys()) # 表头 - headers = ["ID", "名称", "类型", "定义", "别名", "提及次数"] + [ + headers = ["ID", "名称", "类型", "定义", "别名", "提及次数"] + [ f"属性:{a}" for a in sorted(all_attrs) ] - writer = csv.writer(output) + writer = csv.writer(output) writer.writerow(headers) # 数据行 for e in entities: - row = [e.id, e.name, e.type, e.definition, ", ".join(e.aliases), e.mention_count] + row = [e.id, e.name, e.type, e.definition, ", ".join(e.aliases), e.mention_count] for attr in sorted(all_attrs): row.append(e.attributes.get(attr, "")) writer.writerow(row) @@ -327,8 +327,8 @@ class ExportManager: CSV 字符串 """ - output = io.StringIO() - writer = csv.writer(output) + output = io.StringIO() + writer = csv.writer(output) writer.writerow(["ID", "源实体", "目标实体", "关系类型", "置信度", "证据"]) for r in relations: @@ -345,7 +345,7 @@ class ExportManager: Returns: Markdown 字符串 """ - lines = [ + lines = [ f"# {transcript.name}", "", f"**类型**: {transcript.type}", @@ -369,10 +369,10 @@ class ExportManager: ] ) for seg in transcript.segments: - speaker = seg.get("speaker", "Unknown") - start = seg.get("start", 0) - end = seg.get("end", 0) - text = seg.get("text", "") + speaker = seg.get("speaker", "Unknown") + start = seg.get("start", 0) + end = seg.get("end", 0) + text = seg.get("text", "") lines.append(f"**[{start:.1f}s - {end:.1f}s] {speaker}**: {text}") lines.append("") @@ -387,12 +387,12 @@ class ExportManager: ] ) for mention in transcript.entity_mentions: - entity_id = mention.get("entity_id", "") - entity = entities_map.get(entity_id) - entity_name = entity.name if entity else mention.get("entity_name", "Unknown") - entity_type = entity.type if entity else "Unknown" - position = mention.get("position", "") - context = mention.get("context", "")[:50] + "..." if mention.get("context") else "" + entity_id = mention.get("entity_id", "") + entity = entities_map.get(entity_id) + entity_name = entity.name if entity else mention.get("entity_name", "Unknown") + entity_type = entity.type if entity else "Unknown" + position = mention.get("position", "") + context = mention.get("context", "")[:50] + "..." if mention.get("context") else "" lines.append(f"| {entity_name} | {entity_type} | {position} | {context} |") return "\n".join(lines) @@ -404,7 +404,7 @@ class ExportManager: entities: list[ExportEntity], relations: list[ExportRelation], transcripts: list[ExportTranscript], - summary: str = "", + summary: str = "", ) -> bytes: """ 导出项目报告为 PDF 格式 @@ -415,29 +415,29 @@ class ExportManager: if not REPORTLAB_AVAILABLE: raise ImportError("reportlab is required for PDF export") - output = io.BytesIO() - doc = SimpleDocTemplate( - output, pagesize = A4, rightMargin = 72, leftMargin = 72, topMargin = 72, bottomMargin = 18 + output = io.BytesIO() + doc = SimpleDocTemplate( + output, pagesize=A4, rightMargin=72, leftMargin=72, topMargin=72, bottomMargin=18 ) # 样式 - styles = getSampleStyleSheet() - title_style = ParagraphStyle( + styles = getSampleStyleSheet() + title_style = ParagraphStyle( "CustomTitle", - parent = styles["Heading1"], - fontSize = 24, - spaceAfter = 30, - textColor = colors.HexColor("#2c3e50"), + parent=styles["Heading1"], + fontSize=24, + spaceAfter=30, + textColor=colors.HexColor("#2c3e50"), ) - heading_style = ParagraphStyle( + heading_style = ParagraphStyle( "CustomHeading", - parent = styles["Heading2"], - fontSize = 16, - spaceAfter = 12, - textColor = colors.HexColor("#34495e"), + parent=styles["Heading2"], + fontSize=16, + spaceAfter=12, + textColor=colors.HexColor("#34495e"), ) - story = [] + story = [] # 标题页 story.append(Paragraph("InsightFlow 项目报告", title_style)) @@ -452,7 +452,7 @@ class ExportManager: # 统计概览 story.append(Paragraph("项目概览", heading_style)) - stats_data = [ + stats_data = [ ["指标", "数值"], ["实体数量", str(len(entities))], ["关系数量", str(len(relations))], @@ -460,14 +460,14 @@ class ExportManager: ] # 按类型统计实体 - type_counts = {} + type_counts = {} for e in entities: - type_counts[e.type] = type_counts.get(e.type, 0) + 1 + type_counts[e.type] = type_counts.get(e.type, 0) + 1 for etype, count in sorted(type_counts.items()): stats_data.append([f"{etype} 实体", str(count)]) - stats_table = Table(stats_data, colWidths = [3 * inch, 2 * inch]) + stats_table = Table(stats_data, colWidths=[3 * inch, 2 * inch]) stats_table.setStyle( TableStyle( [ @@ -496,8 +496,8 @@ class ExportManager: story.append(PageBreak()) story.append(Paragraph("实体列表", heading_style)) - entity_data = [["名称", "类型", "提及次数", "定义"]] - for e in sorted(entities, key = lambda x: x.mention_count, reverse = True)[ + entity_data = [["名称", "类型", "提及次数", "定义"]] + for e in sorted(entities, key=lambda x: x.mention_count, reverse=True)[ :50 ]: # 限制前50个 entity_data.append( @@ -509,8 +509,8 @@ class ExportManager: ] ) - entity_table = Table( - entity_data, colWidths = [1.5 * inch, 1 * inch, 1 * inch, 2.5 * inch] + entity_table = Table( + entity_data, colWidths=[1.5 * inch, 1 * inch, 1 * inch, 2.5 * inch] ) entity_table.setStyle( TableStyle( @@ -534,12 +534,12 @@ class ExportManager: story.append(PageBreak()) story.append(Paragraph("关系列表", heading_style)) - relation_data = [["源实体", "关系", "目标实体", "置信度"]] + relation_data = [["源实体", "关系", "目标实体", "置信度"]] for r in relations[:100]: # 限制前100个 relation_data.append([r.source, r.relation_type, r.target, f"{r.confidence:.2f}"]) - relation_table = Table( - relation_data, colWidths = [2 * inch, 1.5 * inch, 2 * inch, 1 * inch] + relation_table = Table( + relation_data, colWidths=[2 * inch, 1.5 * inch, 2 * inch, 1 * inch] ) relation_table.setStyle( TableStyle( @@ -574,7 +574,7 @@ class ExportManager: Returns: JSON 字符串 """ - data = { + data = { "project_id": project_id, "project_name": project_name, "export_time": datetime.now().isoformat(), @@ -613,16 +613,16 @@ class ExportManager: ], } - return json.dumps(data, ensure_ascii = False, indent = 2) + return json.dumps(data, ensure_ascii=False, indent=2) # 全局导出管理器实例 -_export_manager = None +_export_manager = None -def get_export_manager(db_manager = None) -> None: +def get_export_manager(db_manager=None) -> None: """获取导出管理器实例""" global _export_manager if _export_manager is None: - _export_manager = ExportManager(db_manager) + _export_manager = ExportManager(db_manager) return _export_manager diff --git a/backend/growth_manager.py b/backend/growth_manager.py index ffcae8f..ee51790 100644 --- a/backend/growth_manager.py +++ b/backend/growth_manager.py @@ -26,88 +26,88 @@ from typing import Any import httpx # Database path -DB_PATH = os.path.join(os.path.dirname(__file__), "insightflow.db") +DB_PATH = os.path.join(os.path.dirname(__file__), "insightflow.db") class EventType(StrEnum): """事件类型""" - PAGE_VIEW = "page_view" # 页面浏览 - FEATURE_USE = "feature_use" # 功能使用 - CONVERSION = "conversion" # 转化 - SIGNUP = "signup" # 注册 - LOGIN = "login" # 登录 - UPGRADE = "upgrade" # 升级 - DOWNGRADE = "downgrade" # 降级 - CANCEL = "cancel" # 取消订阅 - INVITE_SENT = "invite_sent" # 发送邀请 - INVITE_ACCEPTED = "invite_accepted" # 接受邀请 - REFERRAL_REWARD = "referral_reward" # 推荐奖励 + PAGE_VIEW = "page_view" # 页面浏览 + FEATURE_USE = "feature_use" # 功能使用 + CONVERSION = "conversion" # 转化 + SIGNUP = "signup" # 注册 + LOGIN = "login" # 登录 + UPGRADE = "upgrade" # 升级 + DOWNGRADE = "downgrade" # 降级 + CANCEL = "cancel" # 取消订阅 + INVITE_SENT = "invite_sent" # 发送邀请 + INVITE_ACCEPTED = "invite_accepted" # 接受邀请 + REFERRAL_REWARD = "referral_reward" # 推荐奖励 class ExperimentStatus(StrEnum): """实验状态""" - DRAFT = "draft" # 草稿 - RUNNING = "running" # 运行中 - PAUSED = "paused" # 暂停 - COMPLETED = "completed" # 已完成 - ARCHIVED = "archived" # 已归档 + DRAFT = "draft" # 草稿 + RUNNING = "running" # 运行中 + PAUSED = "paused" # 暂停 + COMPLETED = "completed" # 已完成 + ARCHIVED = "archived" # 已归档 class TrafficAllocationType(StrEnum): """流量分配类型""" - RANDOM = "random" # 随机分配 - STRATIFIED = "stratified" # 分层分配 - TARGETED = "targeted" # 定向分配 + RANDOM = "random" # 随机分配 + STRATIFIED = "stratified" # 分层分配 + TARGETED = "targeted" # 定向分配 class EmailTemplateType(StrEnum): """邮件模板类型""" - WELCOME = "welcome" # 欢迎邮件 - ONBOARDING = "onboarding" # 引导邮件 - FEATURE_ANNOUNCEMENT = "feature_announcement" # 功能公告 - CHURN_RECOVERY = "churn_recovery" # 流失挽回 - UPGRADE_PROMPT = "upgrade_prompt" # 升级提示 - REFERRAL = "referral" # 推荐邀请 - NEWSLETTER = "newsletter" # 新闻通讯 + WELCOME = "welcome" # 欢迎邮件 + ONBOARDING = "onboarding" # 引导邮件 + FEATURE_ANNOUNCEMENT = "feature_announcement" # 功能公告 + CHURN_RECOVERY = "churn_recovery" # 流失挽回 + UPGRADE_PROMPT = "upgrade_prompt" # 升级提示 + REFERRAL = "referral" # 推荐邀请 + NEWSLETTER = "newsletter" # 新闻通讯 class EmailStatus(StrEnum): """邮件状态""" - DRAFT = "draft" # 草稿 - SCHEDULED = "scheduled" # 已计划 - SENDING = "sending" # 发送中 - SENT = "sent" # 已发送 - DELIVERED = "delivered" # 已送达 - OPENED = "opened" # 已打开 - CLICKED = "clicked" # 已点击 - BOUNCED = "bounced" # 退信 - FAILED = "failed" # 失败 + DRAFT = "draft" # 草稿 + SCHEDULED = "scheduled" # 已计划 + SENDING = "sending" # 发送中 + SENT = "sent" # 已发送 + DELIVERED = "delivered" # 已送达 + OPENED = "opened" # 已打开 + CLICKED = "clicked" # 已点击 + BOUNCED = "bounced" # 退信 + FAILED = "failed" # 失败 class WorkflowTriggerType(StrEnum): """工作流触发类型""" - USER_SIGNUP = "user_signup" # 用户注册 - USER_LOGIN = "user_login" # 用户登录 - SUBSCRIPTION_CREATED = "subscription_created" # 创建订阅 - SUBSCRIPTION_CANCELLED = "subscription_cancelled" # 取消订阅 - INACTIVITY = "inactivity" # 不活跃 - MILESTONE = "milestone" # 里程碑 - CUSTOM_EVENT = "custom_event" # 自定义事件 + USER_SIGNUP = "user_signup" # 用户注册 + USER_LOGIN = "user_login" # 用户登录 + SUBSCRIPTION_CREATED = "subscription_created" # 创建订阅 + SUBSCRIPTION_CANCELLED = "subscription_cancelled" # 取消订阅 + INACTIVITY = "inactivity" # 不活跃 + MILESTONE = "milestone" # 里程碑 + CUSTOM_EVENT = "custom_event" # 自定义事件 class ReferralStatus(StrEnum): """推荐状态""" - PENDING = "pending" # 待处理 - CONVERTED = "converted" # 已转化 - REWARDED = "rewarded" # 已奖励 - EXPIRED = "expired" # 已过期 + PENDING = "pending" # 待处理 + CONVERTED = "converted" # 已转化 + REWARDED = "rewarded" # 已奖励 + EXPIRED = "expired" # 已过期 @dataclass @@ -362,17 +362,17 @@ class TeamIncentive: class GrowthManager: """运营与增长管理主类""" - def __init__(self, db_path: str = DB_PATH) -> None: - self.db_path = db_path - self.mixpanel_token = os.getenv("MIXPANEL_TOKEN", "") - self.amplitude_api_key = os.getenv("AMPLITUDE_API_KEY", "") - self.segment_write_key = os.getenv("SEGMENT_WRITE_KEY", "") - self.sendgrid_api_key = os.getenv("SENDGRID_API_KEY", "") + def __init__(self, db_path: str = DB_PATH) -> None: + self.db_path = db_path + self.mixpanel_token = os.getenv("MIXPANEL_TOKEN", "") + self.amplitude_api_key = os.getenv("AMPLITUDE_API_KEY", "") + self.segment_write_key = os.getenv("SEGMENT_WRITE_KEY", "") + self.sendgrid_api_key = os.getenv("SENDGRID_API_KEY", "") def _get_db(self) -> None: """获取数据库连接""" - conn = sqlite3.connect(self.db_path) - conn.row_factory = sqlite3.Row + conn = sqlite3.connect(self.db_path) + conn.row_factory = sqlite3.Row return conn # ==================== 用户行为分析 ==================== @@ -383,30 +383,30 @@ class GrowthManager: user_id: str, event_type: EventType, event_name: str, - properties: dict = None, - session_id: str = None, - device_info: dict = None, - referrer: str = None, - utm_params: dict = None, + properties: dict = None, + session_id: str = None, + device_info: dict = None, + referrer: str = None, + utm_params: dict = None, ) -> AnalyticsEvent: """追踪事件""" - event_id = f"evt_{uuid.uuid4().hex[:16]}" - now = datetime.now() + event_id = f"evt_{uuid.uuid4().hex[:16]}" + now = datetime.now() - event = AnalyticsEvent( - id = event_id, - tenant_id = tenant_id, - user_id = user_id, - event_type = event_type, - event_name = event_name, - properties = properties or {}, - timestamp = now, - session_id = session_id, - device_info = device_info or {}, - referrer = referrer, - utm_source = utm_params.get("source") if utm_params else None, - utm_medium = utm_params.get("medium") if utm_params else None, - utm_campaign = utm_params.get("campaign") if utm_params else None, + event = AnalyticsEvent( + id=event_id, + tenant_id=tenant_id, + user_id=user_id, + event_type=event_type, + event_name=event_name, + properties=properties or {}, + timestamp=now, + session_id=session_id, + device_info=device_info or {}, + referrer=referrer, + utm_source=utm_params.get("source") if utm_params else None, + utm_medium=utm_params.get("medium") if utm_params else None, + utm_campaign=utm_params.get("campaign") if utm_params else None, ) with self._get_db() as conn: @@ -445,7 +445,7 @@ class GrowthManager: async def _send_to_analytics_platforms(self, event: AnalyticsEvent) -> None: """发送事件到第三方分析平台""" - tasks = [] + tasks = [] if self.mixpanel_token: tasks.append(self._send_to_mixpanel(event)) @@ -453,17 +453,17 @@ class GrowthManager: tasks.append(self._send_to_amplitude(event)) if tasks: - await asyncio.gather(*tasks, return_exceptions = True) + await asyncio.gather(*tasks, return_exceptions=True) async def _send_to_mixpanel(self, event: AnalyticsEvent) -> None: """发送事件到 Mixpanel""" try: - headers = { + headers = { "Content-Type": "application/json", "Authorization": f"Basic {self.mixpanel_token}", } - payload = { + payload = { "event": event.event_name, "properties": { "distinct_id": event.user_id, @@ -475,7 +475,7 @@ class GrowthManager: async with httpx.AsyncClient() as client: await client.post( - "https://api.mixpanel.com/track", headers = headers, json = [payload], timeout = 10.0 + "https://api.mixpanel.com/track", headers=headers, json=[payload], timeout=10.0 ) except (RuntimeError, ValueError, TypeError) as e: print(f"Failed to send to Mixpanel: {e}") @@ -483,9 +483,9 @@ class GrowthManager: async def _send_to_amplitude(self, event: AnalyticsEvent) -> None: """发送事件到 Amplitude""" try: - headers = {"Content-Type": "application/json"} + headers = {"Content-Type": "application/json"} - payload = { + payload = { "api_key": self.amplitude_api_key, "events": [ { @@ -501,9 +501,9 @@ class GrowthManager: async with httpx.AsyncClient() as client: await client.post( "https://api.amplitude.com/2/httpapi", - headers = headers, - json = payload, - timeout = 10.0, + headers=headers, + json=payload, + timeout=10.0, ) except (RuntimeError, ValueError, TypeError) as e: print(f"Failed to send to Amplitude: {e}") @@ -514,18 +514,18 @@ class GrowthManager: """更新用户画像""" with self._get_db() as conn: # 检查用户画像是否存在 - row = conn.execute( + row = conn.execute( "SELECT * FROM user_profiles WHERE tenant_id = ? AND user_id = ?", (tenant_id, user_id), ).fetchone() - now = datetime.now().isoformat() + now = datetime.now().isoformat() if row: # 更新现有画像 - feature_usage = json.loads(row["feature_usage"]) + feature_usage = json.loads(row["feature_usage"]) if event_name not in feature_usage: - feature_usage[event_name] = 0 + feature_usage[event_name] = 0 feature_usage[event_name] += 1 conn.execute( @@ -539,7 +539,7 @@ class GrowthManager: ) else: # 创建新画像 - profile_id = f"up_{uuid.uuid4().hex[:16]}" + profile_id = f"up_{uuid.uuid4().hex[:16]}" conn.execute( """ INSERT INTO user_profiles @@ -571,7 +571,7 @@ class GrowthManager: def get_user_profile(self, tenant_id: str, user_id: str) -> UserProfile | None: """获取用户画像""" with self._get_db() as conn: - row = conn.execute( + row = conn.execute( "SELECT * FROM user_profiles WHERE tenant_id = ? AND user_id = ?", (tenant_id, user_id), ).fetchone() @@ -581,11 +581,11 @@ class GrowthManager: return None def get_user_analytics_summary( - self, tenant_id: str, start_date: datetime = None, end_date: datetime = None + self, tenant_id: str, start_date: datetime = None, end_date: datetime = None ) -> dict: """获取用户分析汇总""" with self._get_db() as conn: - query = """ + query = """ SELECT COUNT(DISTINCT user_id) as unique_users, COUNT(*) as total_events, @@ -594,7 +594,7 @@ class GrowthManager: FROM analytics_events WHERE tenant_id = ? """ - params = [tenant_id] + params = [tenant_id] if start_date: query += " AND timestamp >= ?" @@ -603,15 +603,15 @@ class GrowthManager: query += " AND timestamp <= ?" params.append(end_date.isoformat()) - row = conn.execute(query, params).fetchone() + row = conn.execute(query, params).fetchone() # 获取事件类型分布 - type_query = """ + type_query = """ SELECT event_type, COUNT(*) as count FROM analytics_events WHERE tenant_id = ? """ - type_params = [tenant_id] + type_params = [tenant_id] if start_date: type_query += " AND timestamp >= ?" @@ -622,7 +622,7 @@ class GrowthManager: type_query += " GROUP BY event_type" - type_rows = conn.execute(type_query, type_params).fetchall() + type_rows = conn.execute(type_query, type_params).fetchall() return { "unique_users": row["unique_users"], @@ -638,17 +638,17 @@ class GrowthManager: self, tenant_id: str, name: str, description: str, steps: list[dict], created_by: str ) -> Funnel: """创建转化漏斗""" - funnel_id = f"fnl_{uuid.uuid4().hex[:16]}" - now = datetime.now().isoformat() + funnel_id = f"fnl_{uuid.uuid4().hex[:16]}" + now = datetime.now().isoformat() - funnel = Funnel( - id = funnel_id, - tenant_id = tenant_id, - name = name, - description = description, - steps = steps, - created_at = now, - updated_at = now, + funnel = Funnel( + id=funnel_id, + tenant_id=tenant_id, + name=name, + description=description, + steps=steps, + created_at=now, + updated_at=now, ) with self._get_db() as conn: @@ -673,46 +673,46 @@ class GrowthManager: return funnel def analyze_funnel( - self, funnel_id: str, period_start: datetime = None, period_end: datetime = None + self, funnel_id: str, period_start: datetime = None, period_end: datetime = None ) -> FunnelAnalysis | None: """分析漏斗转化率""" with self._get_db() as conn: - funnel_row = conn.execute("SELECT * FROM funnels WHERE id = ?", (funnel_id, )).fetchone() + funnel_row = conn.execute("SELECT * FROM funnels WHERE id = ?", (funnel_id, )).fetchone() if not funnel_row: return None - steps = json.loads(funnel_row["steps"]) + steps = json.loads(funnel_row["steps"]) if not period_start: - period_start = datetime.now() - timedelta(days = 30) + period_start = datetime.now() - timedelta(days=30) if not period_end: - period_end = datetime.now() + period_end = datetime.now() # 计算每步转化 - step_conversions = [] - previous_count = None + step_conversions = [] + previous_count = None for step in steps: - event_name = step.get("event_name") + event_name = step.get("event_name") - query = """ + query = """ SELECT COUNT(DISTINCT user_id) as user_count FROM analytics_events WHERE event_name = ? AND timestamp >= ? AND timestamp <= ? """ - row = conn.execute( + row = conn.execute( query, (event_name, period_start.isoformat(), period_end.isoformat()) ).fetchone() - user_count = row["user_count"] if row else 0 + user_count = row["user_count"] if row else 0 - conversion_rate = 0.0 - drop_off_rate = 0.0 + conversion_rate = 0.0 + drop_off_rate = 0.0 if previous_count and previous_count > 0: - conversion_rate = user_count / previous_count - drop_off_rate = 1 - conversion_rate + conversion_rate = user_count / previous_count + drop_off_rate = 1 - conversion_rate step_conversions.append( { @@ -724,41 +724,41 @@ class GrowthManager: } ) - previous_count = user_count + previous_count = user_count # 计算总体转化率 if steps and step_conversions: - first_step_count = step_conversions[0]["user_count"] - last_step_count = step_conversions[-1]["user_count"] - overall_conversion = last_step_count / max(first_step_count, 1) + first_step_count = step_conversions[0]["user_count"] + last_step_count = step_conversions[-1]["user_count"] + overall_conversion = last_step_count / max(first_step_count, 1) else: - overall_conversion = 0.0 + overall_conversion = 0.0 # 找出主要流失点 - drop_off_points = [ + drop_off_points = [ s for s in step_conversions if s["drop_off_rate"] > 0.2 and s != step_conversions[0] ] return FunnelAnalysis( - funnel_id = funnel_id, - period_start = period_start, - period_end = period_end, - total_users = step_conversions[0]["user_count"] if step_conversions else 0, - step_conversions = step_conversions, - overall_conversion = round(overall_conversion, 4), - drop_off_points = drop_off_points, + funnel_id=funnel_id, + period_start=period_start, + period_end=period_end, + total_users=step_conversions[0]["user_count"] if step_conversions else 0, + step_conversions=step_conversions, + overall_conversion=round(overall_conversion, 4), + drop_off_points=drop_off_points, ) def calculate_retention( - self, tenant_id: str, cohort_date: datetime, periods: list[int] = None + self, tenant_id: str, cohort_date: datetime, periods: list[int] = None ) -> dict: """计算留存率""" if periods is None: - periods = [1, 3, 7, 14, 30] + periods = [1, 3, 7, 14, 30] with self._get_db() as conn: # 获取同期群用户(在 cohort_date 当天首次活跃的用户) - cohort_query = """ + cohort_query = """ SELECT DISTINCT user_id FROM analytics_events WHERE tenant_id = ? AND date(timestamp) = date(?) @@ -767,36 +767,36 @@ class GrowthManager: WHERE tenant_id = ? AND date(first_seen) = date(?) ) """ - cohort_rows = conn.execute( + cohort_rows = conn.execute( cohort_query, (tenant_id, cohort_date.isoformat(), tenant_id, cohort_date.isoformat()), ).fetchall() - cohort_users = {r["user_id"] for r in cohort_rows} - cohort_size = len(cohort_users) + cohort_users = {r["user_id"] for r in cohort_rows} + cohort_size = len(cohort_users) if cohort_size == 0: return {"cohort_date": cohort_date.isoformat(), "cohort_size": 0, "retention": {}} - retention_rates = {} + retention_rates = {} for period in periods: - period_date = cohort_date + timedelta(days = period) + period_date = cohort_date + timedelta(days=period) - active_query = """ + active_query = """ SELECT COUNT(DISTINCT user_id) as active_count FROM analytics_events WHERE tenant_id = ? AND date(timestamp) = date(?) AND user_id IN ({}) """.format(", ".join(["?" for _ in cohort_users])) - params = [tenant_id, period_date.isoformat()] + list(cohort_users) - row = conn.execute(active_query, params).fetchone() + params = [tenant_id, period_date.isoformat()] + list(cohort_users) + row = conn.execute(active_query, params).fetchone() - active_count = row["active_count"] if row else 0 - retention_rate = active_count / cohort_size + active_count = row["active_count"] if row else 0 + retention_rate = active_count / cohort_size - retention_rates[f"day_{period}"] = { + retention_rates[f"day_{period}"] = { "active_users": active_count, "retention_rate": round(retention_rate, 4), } @@ -821,34 +821,34 @@ class GrowthManager: target_audience: dict, primary_metric: str, secondary_metrics: list[str], - min_sample_size: int = 100, - confidence_level: float = 0.95, - created_by: str = None, + min_sample_size: int = 100, + confidence_level: float = 0.95, + created_by: str = None, ) -> Experiment: """创建 A/B 测试实验""" - experiment_id = f"exp_{uuid.uuid4().hex[:16]}" - now = datetime.now().isoformat() + experiment_id = f"exp_{uuid.uuid4().hex[:16]}" + now = datetime.now().isoformat() - experiment = Experiment( - id = experiment_id, - tenant_id = tenant_id, - name = name, - description = description, - hypothesis = hypothesis, - status = ExperimentStatus.DRAFT, - variants = variants, - traffic_allocation = traffic_allocation, - traffic_split = traffic_split, - target_audience = target_audience, - primary_metric = primary_metric, - secondary_metrics = secondary_metrics, - start_date = None, - end_date = None, - min_sample_size = min_sample_size, - confidence_level = confidence_level, - created_at = now, - updated_at = now, - created_by = created_by or "system", + experiment = Experiment( + id=experiment_id, + tenant_id=tenant_id, + name=name, + description=description, + hypothesis=hypothesis, + status=ExperimentStatus.DRAFT, + variants=variants, + traffic_allocation=traffic_allocation, + traffic_split=traffic_split, + target_audience=target_audience, + primary_metric=primary_metric, + secondary_metrics=secondary_metrics, + start_date=None, + end_date=None, + min_sample_size=min_sample_size, + confidence_level=confidence_level, + created_at=now, + updated_at=now, + created_by=created_by or "system", ) with self._get_db() as conn: @@ -890,7 +890,7 @@ class GrowthManager: def get_experiment(self, experiment_id: str) -> Experiment | None: """获取实验详情""" with self._get_db() as conn: - row = conn.execute( + row = conn.execute( "SELECT * FROM experiments WHERE id = ?", (experiment_id, ) ).fetchone() @@ -898,10 +898,10 @@ class GrowthManager: return self._row_to_experiment(row) return None - def list_experiments(self, tenant_id: str, status: ExperimentStatus = None) -> list[Experiment]: + def list_experiments(self, tenant_id: str, status: ExperimentStatus = None) -> list[Experiment]: """列出实验""" - query = "SELECT * FROM experiments WHERE tenant_id = ?" - params = [tenant_id] + query = "SELECT * FROM experiments WHERE tenant_id = ?" + params = [tenant_id] if status: query += " AND status = ?" @@ -910,20 +910,20 @@ class GrowthManager: query += " ORDER BY created_at DESC" with self._get_db() as conn: - rows = conn.execute(query, params).fetchall() + rows = conn.execute(query, params).fetchall() return [self._row_to_experiment(row) for row in rows] def assign_variant( - self, experiment_id: str, user_id: str, user_attributes: dict = None + self, experiment_id: str, user_id: str, user_attributes: dict = None ) -> str | None: """为用户分配实验变体""" - experiment = self.get_experiment(experiment_id) + experiment = self.get_experiment(experiment_id) if not experiment or experiment.status != ExperimentStatus.RUNNING: return None # 检查用户是否已分配 with self._get_db() as conn: - row = conn.execute( + row = conn.execute( """SELECT variant_id FROM experiment_assignments WHERE experiment_id = ? AND user_id = ?""", (experiment_id, user_id), @@ -934,18 +934,18 @@ class GrowthManager: # 根据分配策略选择变体 if experiment.traffic_allocation == TrafficAllocationType.RANDOM: - variant_id = self._random_allocation(experiment.variants, experiment.traffic_split) + variant_id = self._random_allocation(experiment.variants, experiment.traffic_split) elif experiment.traffic_allocation == TrafficAllocationType.STRATIFIED: - variant_id = self._stratified_allocation( + variant_id = self._stratified_allocation( experiment.variants, experiment.traffic_split, user_attributes ) else: # TARGETED - variant_id = self._targeted_allocation( + variant_id = self._targeted_allocation( experiment.variants, experiment.target_audience, user_attributes ) if variant_id: - now = datetime.now().isoformat() + now = datetime.now().isoformat() conn.execute( """ INSERT INTO experiment_assignments @@ -967,13 +967,13 @@ class GrowthManager: def _random_allocation(self, variants: list[dict], traffic_split: dict[str, float]) -> str: """随机分配""" - variant_ids = [v["id"] for v in variants] - weights = [traffic_split.get(v_id, 1.0 / len(variants)) for v_id in variant_ids] + variant_ids = [v["id"] for v in variants] + weights = [traffic_split.get(v_id, 1.0 / len(variants)) for v_id in variant_ids] - total = sum(weights) - normalized_weights = [w / total for w in weights] + total = sum(weights) + normalized_weights = [w / total for w in weights] - return random.choices(variant_ids, weights = normalized_weights, k = 1)[0] + return random.choices(variant_ids, weights=normalized_weights, k=1)[0] def _stratified_allocation( self, variants: list[dict], traffic_split: dict[str, float], user_attributes: dict @@ -981,9 +981,9 @@ class GrowthManager: """分层分配(基于用户属性)""" # 简化的分层分配:根据用户 ID 哈希值分配 if user_attributes and "user_id" in user_attributes: - hash_value = int(hashlib.md5(user_attributes["user_id"].encode()).hexdigest(), 16) - variant_ids = [v["id"] for v in variants] - index = hash_value % len(variant_ids) + hash_value = int(hashlib.md5(user_attributes["user_id"].encode()).hexdigest(), 16) + variant_ids = [v["id"] for v in variants] + index = hash_value % len(variant_ids) return variant_ids[index] return self._random_allocation(variants, traffic_split) @@ -993,29 +993,29 @@ class GrowthManager: ) -> str | None: """定向分配(基于目标受众条件)""" # 检查用户是否符合目标受众条件 - conditions = target_audience.get("conditions", []) + conditions = target_audience.get("conditions", []) - matches = True + matches = True for condition in conditions: - attr_name = condition.get("attribute") - operator = condition.get("operator") - value = condition.get("value") + attr_name = condition.get("attribute") + operator = condition.get("operator") + value = condition.get("value") - user_value = user_attributes.get(attr_name) if user_attributes else None + user_value = user_attributes.get(attr_name) if user_attributes else None if operator == "equals" and user_value != value: - matches = False + matches = False break elif operator == "not_equals" and user_value == value: - matches = False + matches = False break elif operator == "in" and user_value not in value: - matches = False + matches = False break if not matches: # 用户不符合条件,返回对照组 - control_variant = next((v for v in variants if v.get("is_control")), variants[0]) + control_variant = next((v for v in variants if v.get("is_control")), variants[0]) return control_variant["id"] if control_variant else None return self._random_allocation(variants, target_audience.get("traffic_split", {})) @@ -1050,18 +1050,18 @@ class GrowthManager: def analyze_experiment(self, experiment_id: str) -> dict: """分析实验结果""" - experiment = self.get_experiment(experiment_id) + experiment = self.get_experiment(experiment_id) if not experiment: return {"error": "Experiment not found"} with self._get_db() as conn: - results = {} + results = {} for variant in experiment.variants: - variant_id = variant["id"] + variant_id = variant["id"] # 获取样本量 - sample_row = conn.execute( + sample_row = conn.execute( """ SELECT COUNT(DISTINCT user_id) as sample_size FROM experiment_assignments @@ -1070,10 +1070,10 @@ class GrowthManager: (experiment_id, variant_id), ).fetchone() - sample_size = sample_row["sample_size"] if sample_row else 0 + sample_size = sample_row["sample_size"] if sample_row else 0 # 获取主要指标统计 - metric_row = conn.execute( + metric_row = conn.execute( """ SELECT AVG(metric_value) as mean_value, @@ -1085,11 +1085,11 @@ class GrowthManager: (experiment_id, variant_id, experiment.primary_metric), ).fetchone() - mean_value = ( + mean_value = ( metric_row["mean_value"] if metric_row and metric_row["mean_value"] else 0 ) - results[variant_id] = { + results[variant_id] = { "variant_name": variant.get("name", variant_id), "is_control": variant.get("is_control", False), "sample_size": sample_size, @@ -1098,27 +1098,27 @@ class GrowthManager: } # 计算统计显著性(简化版) - control_variant = next((v for v in experiment.variants if v.get("is_control")), None) + control_variant = next((v for v in experiment.variants if v.get("is_control")), None) if control_variant: - control_id = control_variant["id"] - control_result = results.get(control_id, {}) + control_id = control_variant["id"] + control_result = results.get(control_id, {}) for variant_id, result in results.items(): if variant_id != control_id: - control_mean = control_result.get("mean_value", 0) - variant_mean = result.get("mean_value", 0) + control_mean = control_result.get("mean_value", 0) + variant_mean = result.get("mean_value", 0) if control_mean > 0: - uplift = (variant_mean - control_mean) / control_mean + uplift = (variant_mean - control_mean) / control_mean else: - uplift = 0 + uplift = 0 # 简化的显著性判断 - is_significant = abs(uplift) > 0.05 and result["sample_size"] > 100 + is_significant = abs(uplift) > 0.05 and result["sample_size"] > 100 - result["uplift"] = round(uplift, 4) - result["is_significant"] = is_significant - result["p_value"] = 0.05 if is_significant else 0.5 + result["uplift"] = round(uplift, 4) + result["is_significant"] = is_significant + result["p_value"] = 0.05 if is_significant else 0.5 return { "experiment_id": experiment_id, @@ -1131,7 +1131,7 @@ class GrowthManager: def start_experiment(self, experiment_id: str) -> Experiment | None: """启动实验""" with self._get_db() as conn: - now = datetime.now().isoformat() + now = datetime.now().isoformat() conn.execute( """ UPDATE experiments @@ -1153,7 +1153,7 @@ class GrowthManager: def stop_experiment(self, experiment_id: str) -> Experiment | None: """停止实验""" with self._get_db() as conn: - now = datetime.now().isoformat() + now = datetime.now().isoformat() conn.execute( """ UPDATE experiments @@ -1181,36 +1181,36 @@ class GrowthManager: template_type: EmailTemplateType, subject: str, html_content: str, - text_content: str = None, - variables: list[str] = None, - from_name: str = None, - from_email: str = None, - reply_to: str = None, + text_content: str = None, + variables: list[str] = None, + from_name: str = None, + from_email: str = None, + reply_to: str = None, ) -> EmailTemplate: """创建邮件模板""" - template_id = f"et_{uuid.uuid4().hex[:16]}" - now = datetime.now().isoformat() + template_id = f"et_{uuid.uuid4().hex[:16]}" + now = datetime.now().isoformat() # 自动提取变量 if variables is None: - variables = re.findall(r"\{\{(\w+)\}\}", html_content) + variables = re.findall(r"\{\{(\w+)\}\}", html_content) - template = EmailTemplate( - id = template_id, - tenant_id = tenant_id, - name = name, - template_type = template_type, - subject = subject, - html_content = html_content, - text_content = text_content or re.sub(r"<[^>]+>", "", html_content), - variables = variables, - preview_text = None, - from_name = from_name or "InsightFlow", - from_email = from_email or "noreply@insightflow.io", - reply_to = reply_to, - is_active = True, - created_at = now, - updated_at = now, + template = EmailTemplate( + id=template_id, + tenant_id=tenant_id, + name=name, + template_type=template_type, + subject=subject, + html_content=html_content, + text_content=text_content or re.sub(r"<[^>]+>", "", html_content), + variables=variables, + preview_text=None, + from_name=from_name or "InsightFlow", + from_email=from_email or "noreply@insightflow.io", + reply_to=reply_to, + is_active=True, + created_at=now, + updated_at=now, ) with self._get_db() as conn: @@ -1245,7 +1245,7 @@ class GrowthManager: def get_email_template(self, template_id: str) -> EmailTemplate | None: """获取邮件模板""" with self._get_db() as conn: - row = conn.execute( + row = conn.execute( "SELECT * FROM email_templates WHERE id = ?", (template_id, ) ).fetchone() @@ -1254,11 +1254,11 @@ class GrowthManager: return None def list_email_templates( - self, tenant_id: str, template_type: EmailTemplateType = None + self, tenant_id: str, template_type: EmailTemplateType = None ) -> list[EmailTemplate]: """列出邮件模板""" - query = "SELECT * FROM email_templates WHERE tenant_id = ? AND is_active = 1" - params = [tenant_id] + query = "SELECT * FROM email_templates WHERE tenant_id = ? AND is_active = 1" + params = [tenant_id] if template_type: query += " AND template_type = ?" @@ -1267,24 +1267,24 @@ class GrowthManager: query += " ORDER BY created_at DESC" with self._get_db() as conn: - rows = conn.execute(query, params).fetchall() + rows = conn.execute(query, params).fetchall() return [self._row_to_email_template(row) for row in rows] def render_template(self, template_id: str, variables: dict) -> dict[str, str]: """渲染邮件模板""" - template = self.get_email_template(template_id) + template = self.get_email_template(template_id) if not template: return None - subject = template.subject - html_content = template.html_content - text_content = template.text_content + subject = template.subject + html_content = template.html_content + text_content = template.text_content for key, value in variables.items(): - placeholder = f"{{{{{key}}}}}" - subject = subject.replace(placeholder, str(value)) - html_content = html_content.replace(placeholder, str(value)) - text_content = text_content.replace(placeholder, str(value)) + placeholder = f"{{{{{key}}}}}" + subject = subject.replace(placeholder, str(value)) + html_content = html_content.replace(placeholder, str(value)) + text_content = text_content.replace(placeholder, str(value)) return { "subject": subject, @@ -1301,29 +1301,29 @@ class GrowthManager: name: str, template_id: str, recipient_list: list[dict], - scheduled_at: datetime = None, + scheduled_at: datetime = None, ) -> EmailCampaign: """创建邮件营销活动""" - campaign_id = f"ec_{uuid.uuid4().hex[:16]}" - now = datetime.now().isoformat() + campaign_id = f"ec_{uuid.uuid4().hex[:16]}" + now = datetime.now().isoformat() - campaign = EmailCampaign( - id = campaign_id, - tenant_id = tenant_id, - name = name, - template_id = template_id, - status = "draft", - recipient_count = len(recipient_list), - sent_count = 0, - delivered_count = 0, - opened_count = 0, - clicked_count = 0, - bounced_count = 0, - failed_count = 0, - scheduled_at = scheduled_at.isoformat() if scheduled_at else None, - started_at = None, - completed_at = None, - created_at = now, + campaign = EmailCampaign( + id=campaign_id, + tenant_id=tenant_id, + name=name, + template_id=template_id, + status="draft", + recipient_count=len(recipient_list), + sent_count=0, + delivered_count=0, + opened_count=0, + clicked_count=0, + bounced_count=0, + failed_count=0, + scheduled_at=scheduled_at.isoformat() if scheduled_at else None, + started_at=None, + completed_at=None, + created_at=now, ) with self._get_db() as conn: @@ -1384,15 +1384,15 @@ class GrowthManager: self, campaign_id: str, user_id: str, email: str, template_id: str, variables: dict ) -> bool: """发送单封邮件""" - template = self.get_email_template(template_id) + template = self.get_email_template(template_id) if not template: return False - rendered = self.render_template(template_id, variables) + rendered = self.render_template(template_id, variables) # 更新状态为发送中 with self._get_db() as conn: - now = datetime.now().isoformat() + now = datetime.now().isoformat() conn.execute( """ UPDATE email_logs @@ -1408,11 +1408,11 @@ class GrowthManager: # 目前使用模拟发送 await asyncio.sleep(0.1) - success = True # 模拟成功 + success = True # 模拟成功 # 更新状态 with self._get_db() as conn: - now = datetime.now().isoformat() + now = datetime.now().isoformat() if success: conn.execute( """ @@ -1451,7 +1451,7 @@ class GrowthManager: async def send_campaign(self, campaign_id: str) -> dict: """发送整个营销活动""" with self._get_db() as conn: - campaign_row = conn.execute( + campaign_row = conn.execute( "SELECT * FROM email_campaigns WHERE id = ?", (campaign_id, ) ).fetchone() @@ -1459,14 +1459,14 @@ class GrowthManager: return {"error": "Campaign not found"} # 获取待发送的邮件 - logs = conn.execute( + logs = conn.execute( """SELECT * FROM email_logs WHERE campaign_id = ? AND status IN (?, ?)""", (campaign_id, EmailStatus.DRAFT.value, EmailStatus.SCHEDULED.value), ).fetchall() # 更新活动状态 - now = datetime.now().isoformat() + now = datetime.now().isoformat() conn.execute( "UPDATE email_campaigns SET status = ?, started_at = ? WHERE id = ?", ("sending", now, campaign_id), @@ -1474,14 +1474,14 @@ class GrowthManager: conn.commit() # 批量发送 - success_count = 0 - failed_count = 0 + success_count = 0 + failed_count = 0 for log in logs: # 获取用户变量 - variables = self._get_user_variables(log["tenant_id"], log["user_id"]) + variables = self._get_user_variables(log["tenant_id"], log["user_id"]) - success = await self.send_email( + success = await self.send_email( campaign_id, log["user_id"], log["email"], log["template_id"], variables ) @@ -1492,7 +1492,7 @@ class GrowthManager: # 更新活动状态 with self._get_db() as conn: - now = datetime.now().isoformat() + now = datetime.now().isoformat() conn.execute( """ UPDATE email_campaigns @@ -1526,21 +1526,21 @@ class GrowthManager: actions: list[dict], ) -> AutomationWorkflow: """创建自动化工作流""" - workflow_id = f"aw_{uuid.uuid4().hex[:16]}" - now = datetime.now().isoformat() + workflow_id = f"aw_{uuid.uuid4().hex[:16]}" + now = datetime.now().isoformat() - workflow = AutomationWorkflow( - id = workflow_id, - tenant_id = tenant_id, - name = name, - description = description, - trigger_type = trigger_type, - trigger_conditions = trigger_conditions, - actions = actions, - is_active = True, - execution_count = 0, - created_at = now, - updated_at = now, + workflow = AutomationWorkflow( + id=workflow_id, + tenant_id=tenant_id, + name=name, + description=description, + trigger_type=trigger_type, + trigger_conditions=trigger_conditions, + actions=actions, + is_active=True, + execution_count=0, + created_at=now, + updated_at=now, ) with self._get_db() as conn: @@ -1572,14 +1572,14 @@ class GrowthManager: async def trigger_workflow(self, workflow_id: str, event_data: dict) -> None: """触发自动化工作流""" with self._get_db() as conn: - row = conn.execute( + row = conn.execute( "SELECT * FROM automation_workflows WHERE id = ? AND is_active = 1", (workflow_id, ) ).fetchone() if not row: return False - workflow = self._row_to_automation_workflow(row) + workflow = self._row_to_automation_workflow(row) # 检查触发条件 if not self._check_trigger_conditions(workflow.trigger_conditions, event_data): @@ -1608,7 +1608,7 @@ class GrowthManager: async def _execute_action(self, action: dict, event_data: dict) -> None: """执行工作流动作""" - action_type = action.get("type") + action_type = action.get("type") if action_type == "send_email": action.get("template_id") @@ -1631,29 +1631,29 @@ class GrowthManager: referrer_reward_value: float, referee_reward_type: str, referee_reward_value: float, - max_referrals_per_user: int = 10, - referral_code_length: int = 8, - expiry_days: int = 30, + max_referrals_per_user: int = 10, + referral_code_length: int = 8, + expiry_days: int = 30, ) -> ReferralProgram: """创建推荐计划""" - program_id = f"rp_{uuid.uuid4().hex[:16]}" - now = datetime.now().isoformat() + program_id = f"rp_{uuid.uuid4().hex[:16]}" + now = datetime.now().isoformat() - program = ReferralProgram( - id = program_id, - tenant_id = tenant_id, - name = name, - description = description, - referrer_reward_type = referrer_reward_type, - referrer_reward_value = referrer_reward_value, - referee_reward_type = referee_reward_type, - referee_reward_value = referee_reward_value, - max_referrals_per_user = max_referrals_per_user, - referral_code_length = referral_code_length, - expiry_days = expiry_days, - is_active = True, - created_at = now, - updated_at = now, + program = ReferralProgram( + id=program_id, + tenant_id=tenant_id, + name=name, + description=description, + referrer_reward_type=referrer_reward_type, + referrer_reward_value=referrer_reward_value, + referee_reward_type=referee_reward_type, + referee_reward_value=referee_reward_value, + max_referrals_per_user=max_referrals_per_user, + referral_code_length=referral_code_length, + expiry_days=expiry_days, + is_active=True, + created_at=now, + updated_at=now, ) with self._get_db() as conn: @@ -1688,13 +1688,13 @@ class GrowthManager: def generate_referral_code(self, program_id: str, referrer_id: str) -> Referral: """生成推荐码""" - program = self._get_referral_program(program_id) + program = self._get_referral_program(program_id) if not program: return None # 检查推荐次数限制 with self._get_db() as conn: - count_row = conn.execute( + count_row = conn.execute( """SELECT COUNT(*) as count FROM referrals WHERE program_id = ? AND referrer_id = ? AND status != ?""", (program_id, referrer_id, ReferralStatus.EXPIRED.value), @@ -1704,28 +1704,28 @@ class GrowthManager: return None # 生成推荐码 - referral_code = self._generate_unique_code(program.referral_code_length) + referral_code = self._generate_unique_code(program.referral_code_length) - referral_id = f"ref_{uuid.uuid4().hex[:16]}" - now = datetime.now() - expires_at = now + timedelta(days = program.expiry_days) + referral_id = f"ref_{uuid.uuid4().hex[:16]}" + now = datetime.now() + expires_at = now + timedelta(days=program.expiry_days) - referral = Referral( - id = referral_id, - program_id = program_id, - tenant_id = program.tenant_id, - referrer_id = referrer_id, - referee_id = None, - referral_code = referral_code, - status = ReferralStatus.PENDING, - referrer_rewarded = False, - referee_rewarded = False, - referrer_reward_value = program.referrer_reward_value, - referee_reward_value = program.referee_reward_value, - converted_at = None, - rewarded_at = None, - expires_at = expires_at, - created_at = now, + referral = Referral( + id=referral_id, + program_id=program_id, + tenant_id=program.tenant_id, + referrer_id=referrer_id, + referee_id=None, + referral_code=referral_code, + status=ReferralStatus.PENDING, + referrer_rewarded=False, + referee_rewarded=False, + referrer_reward_value=program.referrer_reward_value, + referee_reward_value=program.referee_reward_value, + converted_at=None, + rewarded_at=None, + expires_at=expires_at, + created_at=now, ) conn.execute( @@ -1760,12 +1760,12 @@ class GrowthManager: def _generate_unique_code(self, length: int) -> str: """生成唯一推荐码""" - chars = "ABCDEFGHJKLMNPQRSTUVWXYZ23456789" # 排除易混淆字符 + chars = "ABCDEFGHJKLMNPQRSTUVWXYZ23456789" # 排除易混淆字符 while True: - code = "".join(random.choices(chars, k = length)) + code = "".join(random.choices(chars, k=length)) with self._get_db() as conn: - row = conn.execute( + row = conn.execute( "SELECT 1 FROM referrals WHERE referral_code = ?", (code, ) ).fetchone() @@ -1775,7 +1775,7 @@ class GrowthManager: def _get_referral_program(self, program_id: str) -> ReferralProgram | None: """获取推荐计划""" with self._get_db() as conn: - row = conn.execute( + row = conn.execute( "SELECT * FROM referral_programs WHERE id = ?", (program_id, ) ).fetchone() @@ -1786,7 +1786,7 @@ class GrowthManager: def apply_referral_code(self, referral_code: str, referee_id: str) -> bool: """应用推荐码""" with self._get_db() as conn: - row = conn.execute( + row = conn.execute( """SELECT * FROM referrals WHERE referral_code = ? AND status = ? AND expires_at > ?""", (referral_code, ReferralStatus.PENDING.value, datetime.now().isoformat()), @@ -1795,7 +1795,7 @@ class GrowthManager: if not row: return False - now = datetime.now().isoformat() + now = datetime.now().isoformat() conn.execute( """ UPDATE referrals @@ -1811,12 +1811,12 @@ class GrowthManager: def reward_referral(self, referral_id: str) -> bool: """发放推荐奖励""" with self._get_db() as conn: - row = conn.execute("SELECT * FROM referrals WHERE id = ?", (referral_id, )).fetchone() + row = conn.execute("SELECT * FROM referrals WHERE id = ?", (referral_id, )).fetchone() if not row or row["status"] != ReferralStatus.CONVERTED.value: return False - now = datetime.now().isoformat() + now = datetime.now().isoformat() conn.execute( """ UPDATE referrals @@ -1832,7 +1832,7 @@ class GrowthManager: def get_referral_stats(self, program_id: str) -> dict: """获取推荐统计""" with self._get_db() as conn: - stats = conn.execute( + stats = conn.execute( """ SELECT COUNT(*) as total_referrals, @@ -1879,22 +1879,22 @@ class GrowthManager: valid_until: datetime, ) -> TeamIncentive: """创建团队升级激励""" - incentive_id = f"ti_{uuid.uuid4().hex[:16]}" - now = datetime.now().isoformat() + incentive_id = f"ti_{uuid.uuid4().hex[:16]}" + now = datetime.now().isoformat() - incentive = TeamIncentive( - id = incentive_id, - tenant_id = tenant_id, - name = name, - description = description, - target_tier = target_tier, - min_team_size = min_team_size, - incentive_type = incentive_type, - incentive_value = incentive_value, - valid_from = valid_from.isoformat(), - valid_until = valid_until.isoformat(), - is_active = True, - created_at = now, + incentive = TeamIncentive( + id=incentive_id, + tenant_id=tenant_id, + name=name, + description=description, + target_tier=target_tier, + min_team_size=min_team_size, + incentive_type=incentive_type, + incentive_value=incentive_value, + valid_from=valid_from.isoformat(), + valid_until=valid_until.isoformat(), + is_active=True, + created_at=now, ) with self._get_db() as conn: @@ -1929,8 +1929,8 @@ class GrowthManager: ) -> list[TeamIncentive]: """检查团队激励资格""" with self._get_db() as conn: - now = datetime.now().isoformat() - rows = conn.execute( + now = datetime.now().isoformat() + rows = conn.execute( """ SELECT * FROM team_incentives WHERE tenant_id = ? AND is_active = 1 @@ -1946,12 +1946,12 @@ class GrowthManager: def get_realtime_dashboard(self, tenant_id: str) -> dict: """获取实时分析仪表板数据""" - now = datetime.now() - today_start = now.replace(hour = 0, minute = 0, second = 0, microsecond = 0) + now = datetime.now() + today_start = now.replace(hour=0, minute=0, second=0, microsecond=0) with self._get_db() as conn: # 今日统计 - today_stats = conn.execute( + today_stats = conn.execute( """ SELECT COUNT(DISTINCT user_id) as active_users, @@ -1964,7 +1964,7 @@ class GrowthManager: ).fetchone() # 最近事件 - recent_events = conn.execute( + recent_events = conn.execute( """ SELECT event_name, event_type, timestamp, user_id FROM analytics_events @@ -1976,7 +1976,7 @@ class GrowthManager: ).fetchall() # 热门功能 - top_features = conn.execute( + top_features = conn.execute( """ SELECT event_name, COUNT(*) as count FROM analytics_events @@ -1989,12 +1989,12 @@ class GrowthManager: ).fetchall() # 活跃用户趋势(最近24小时,每小时) - hourly_trend = [] + hourly_trend = [] for i in range(24): - hour_start = now - timedelta(hours = i + 1) - hour_end = now - timedelta(hours = i) + hour_start = now - timedelta(hours=i + 1) + hour_end = now - timedelta(hours=i) - row = conn.execute( + row = conn.execute( """ SELECT COUNT(DISTINCT user_id) as count FROM analytics_events @@ -2035,125 +2035,125 @@ class GrowthManager: def _row_to_user_profile(self, row) -> UserProfile: """将数据库行转换为 UserProfile""" return UserProfile( - id = row["id"], - tenant_id = row["tenant_id"], - user_id = row["user_id"], - first_seen = datetime.fromisoformat(row["first_seen"]), - last_seen = datetime.fromisoformat(row["last_seen"]), - total_sessions = row["total_sessions"], - total_events = row["total_events"], - feature_usage = json.loads(row["feature_usage"]), - subscription_history = json.loads(row["subscription_history"]), - ltv = row["ltv"], - churn_risk_score = row["churn_risk_score"], - engagement_score = row["engagement_score"], - created_at = datetime.fromisoformat(row["created_at"]), - updated_at = datetime.fromisoformat(row["updated_at"]), + id=row["id"], + tenant_id=row["tenant_id"], + user_id=row["user_id"], + first_seen=datetime.fromisoformat(row["first_seen"]), + last_seen=datetime.fromisoformat(row["last_seen"]), + total_sessions=row["total_sessions"], + total_events=row["total_events"], + feature_usage=json.loads(row["feature_usage"]), + subscription_history=json.loads(row["subscription_history"]), + ltv=row["ltv"], + churn_risk_score=row["churn_risk_score"], + engagement_score=row["engagement_score"], + created_at=datetime.fromisoformat(row["created_at"]), + updated_at=datetime.fromisoformat(row["updated_at"]), ) def _row_to_experiment(self, row) -> Experiment: """将数据库行转换为 Experiment""" return Experiment( - id = row["id"], - tenant_id = row["tenant_id"], - name = row["name"], - description = row["description"], - hypothesis = row["hypothesis"], - status = ExperimentStatus(row["status"]), - variants = json.loads(row["variants"]), - traffic_allocation = TrafficAllocationType(row["traffic_allocation"]), - traffic_split = json.loads(row["traffic_split"]), - target_audience = json.loads(row["target_audience"]), - primary_metric = row["primary_metric"], - secondary_metrics = json.loads(row["secondary_metrics"]), - start_date = datetime.fromisoformat(row["start_date"]) if row["start_date"] else None, - end_date = datetime.fromisoformat(row["end_date"]) if row["end_date"] else None, - min_sample_size = row["min_sample_size"], - confidence_level = row["confidence_level"], - created_at = row["created_at"], - updated_at = row["updated_at"], - created_by = row["created_by"], + id=row["id"], + tenant_id=row["tenant_id"], + name=row["name"], + description=row["description"], + hypothesis=row["hypothesis"], + status=ExperimentStatus(row["status"]), + variants=json.loads(row["variants"]), + traffic_allocation=TrafficAllocationType(row["traffic_allocation"]), + traffic_split=json.loads(row["traffic_split"]), + target_audience=json.loads(row["target_audience"]), + primary_metric=row["primary_metric"], + secondary_metrics=json.loads(row["secondary_metrics"]), + start_date=datetime.fromisoformat(row["start_date"]) if row["start_date"] else None, + end_date=datetime.fromisoformat(row["end_date"]) if row["end_date"] else None, + min_sample_size=row["min_sample_size"], + confidence_level=row["confidence_level"], + created_at=row["created_at"], + updated_at=row["updated_at"], + created_by=row["created_by"], ) def _row_to_email_template(self, row) -> EmailTemplate: """将数据库行转换为 EmailTemplate""" return EmailTemplate( - id = row["id"], - tenant_id = row["tenant_id"], - name = row["name"], - template_type = EmailTemplateType(row["template_type"]), - subject = row["subject"], - html_content = row["html_content"], - text_content = row["text_content"], - variables = json.loads(row["variables"]), - preview_text = row["preview_text"], - from_name = row["from_name"], - from_email = row["from_email"], - reply_to = row["reply_to"], - is_active = bool(row["is_active"]), - created_at = row["created_at"], - updated_at = row["updated_at"], + id=row["id"], + tenant_id=row["tenant_id"], + name=row["name"], + template_type=EmailTemplateType(row["template_type"]), + subject=row["subject"], + html_content=row["html_content"], + text_content=row["text_content"], + variables=json.loads(row["variables"]), + preview_text=row["preview_text"], + from_name=row["from_name"], + from_email=row["from_email"], + reply_to=row["reply_to"], + is_active=bool(row["is_active"]), + created_at=row["created_at"], + updated_at=row["updated_at"], ) def _row_to_automation_workflow(self, row) -> AutomationWorkflow: """将数据库行转换为 AutomationWorkflow""" return AutomationWorkflow( - id = row["id"], - tenant_id = row["tenant_id"], - name = row["name"], - description = row["description"], - trigger_type = WorkflowTriggerType(row["trigger_type"]), - trigger_conditions = json.loads(row["trigger_conditions"]), - actions = json.loads(row["actions"]), - is_active = bool(row["is_active"]), - execution_count = row["execution_count"], - created_at = row["created_at"], - updated_at = row["updated_at"], + id=row["id"], + tenant_id=row["tenant_id"], + name=row["name"], + description=row["description"], + trigger_type=WorkflowTriggerType(row["trigger_type"]), + trigger_conditions=json.loads(row["trigger_conditions"]), + actions=json.loads(row["actions"]), + is_active=bool(row["is_active"]), + execution_count=row["execution_count"], + created_at=row["created_at"], + updated_at=row["updated_at"], ) def _row_to_referral_program(self, row) -> ReferralProgram: """将数据库行转换为 ReferralProgram""" return ReferralProgram( - id = row["id"], - tenant_id = row["tenant_id"], - name = row["name"], - description = row["description"], - referrer_reward_type = row["referrer_reward_type"], - referrer_reward_value = row["referrer_reward_value"], - referee_reward_type = row["referee_reward_type"], - referee_reward_value = row["referee_reward_value"], - max_referrals_per_user = row["max_referrals_per_user"], - referral_code_length = row["referral_code_length"], - expiry_days = row["expiry_days"], - is_active = bool(row["is_active"]), - created_at = row["created_at"], - updated_at = row["updated_at"], + id=row["id"], + tenant_id=row["tenant_id"], + name=row["name"], + description=row["description"], + referrer_reward_type=row["referrer_reward_type"], + referrer_reward_value=row["referrer_reward_value"], + referee_reward_type=row["referee_reward_type"], + referee_reward_value=row["referee_reward_value"], + max_referrals_per_user=row["max_referrals_per_user"], + referral_code_length=row["referral_code_length"], + expiry_days=row["expiry_days"], + is_active=bool(row["is_active"]), + created_at=row["created_at"], + updated_at=row["updated_at"], ) def _row_to_team_incentive(self, row) -> TeamIncentive: """将数据库行转换为 TeamIncentive""" return TeamIncentive( - id = row["id"], - tenant_id = row["tenant_id"], - name = row["name"], - description = row["description"], - target_tier = row["target_tier"], - min_team_size = row["min_team_size"], - incentive_type = row["incentive_type"], - incentive_value = row["incentive_value"], - valid_from = datetime.fromisoformat(row["valid_from"]), - valid_until = datetime.fromisoformat(row["valid_until"]), - is_active = bool(row["is_active"]), - created_at = row["created_at"], + id=row["id"], + tenant_id=row["tenant_id"], + name=row["name"], + description=row["description"], + target_tier=row["target_tier"], + min_team_size=row["min_team_size"], + incentive_type=row["incentive_type"], + incentive_value=row["incentive_value"], + valid_from=datetime.fromisoformat(row["valid_from"]), + valid_until=datetime.fromisoformat(row["valid_until"]), + is_active=bool(row["is_active"]), + created_at=row["created_at"], ) # Singleton instance -_growth_manager = None +_growth_manager = None def get_growth_manager() -> GrowthManager: global _growth_manager if _growth_manager is None: - _growth_manager = GrowthManager() + _growth_manager = GrowthManager() return _growth_manager diff --git a/backend/image_processor.py b/backend/image_processor.py index c14950f..5e39931 100644 --- a/backend/image_processor.py +++ b/backend/image_processor.py @@ -11,30 +11,30 @@ import uuid from dataclasses import dataclass # Constants -UUID_LENGTH = 8 # UUID 截断长度 +UUID_LENGTH = 8 # UUID 截断长度 # 尝试导入图像处理库 try: from PIL import Image, ImageEnhance, ImageFilter - PIL_AVAILABLE = True + PIL_AVAILABLE = True except ImportError: - PIL_AVAILABLE = False + PIL_AVAILABLE = False try: import cv2 import numpy as np - CV2_AVAILABLE = True + CV2_AVAILABLE = True except ImportError: - CV2_AVAILABLE = False + CV2_AVAILABLE = False try: import pytesseract - PYTESSERACT_AVAILABLE = True + PYTESSERACT_AVAILABLE = True except ImportError: - PYTESSERACT_AVAILABLE = False + PYTESSERACT_AVAILABLE = False @dataclass @@ -44,7 +44,7 @@ class ImageEntity: name: str type: str confidence: float - bbox: tuple[int, int, int, int] | None = None # (x, y, width, height) + bbox: tuple[int, int, int, int] | None = None # (x, y, width, height) @dataclass @@ -70,7 +70,7 @@ class ImageProcessingResult: width: int height: int success: bool - error_message: str = "" + error_message: str = "" @dataclass @@ -87,7 +87,7 @@ class ImageProcessor: """图片处理器 - 处理各种类型图片""" # 图片类型定义 - IMAGE_TYPES = { + IMAGE_TYPES = { "whiteboard": "白板", "ppt": "PPT/演示文稿", "handwritten": "手写笔记", @@ -96,17 +96,17 @@ class ImageProcessor: "other": "其他", } - def __init__(self, temp_dir: str = None) -> None: + def __init__(self, temp_dir: str = None) -> None: """ 初始化图片处理器 Args: temp_dir: 临时文件目录 """ - self.temp_dir = temp_dir or os.path.join(os.getcwd(), "temp", "images") - os.makedirs(self.temp_dir, exist_ok = True) + self.temp_dir = temp_dir or os.path.join(os.getcwd(), "temp", "images") + os.makedirs(self.temp_dir, exist_ok=True) - def preprocess_image(self, image, image_type: str = None) -> None: + def preprocess_image(self, image, image_type: str = None) -> None: """ 预处理图片以提高OCR质量 @@ -123,25 +123,25 @@ class ImageProcessor: try: # 转换为RGB(如果是RGBA) if image.mode == "RGBA": - image = image.convert("RGB") + image = image.convert("RGB") # 根据图片类型进行针对性处理 if image_type == "whiteboard": # 白板:增强对比度,去除背景 - image = self._enhance_whiteboard(image) + image = self._enhance_whiteboard(image) elif image_type == "handwritten": # 手写笔记:降噪,增强对比度 - image = self._enhance_handwritten(image) + image = self._enhance_handwritten(image) elif image_type == "screenshot": # 截图:轻微锐化 - image = image.filter(ImageFilter.SHARPEN) + image = image.filter(ImageFilter.SHARPEN) # 通用处理:调整大小(如果太大) - max_size = 4096 + max_size = 4096 if max(image.size) > max_size: - ratio = max_size / max(image.size) - new_size = (int(image.size[0] * ratio), int(image.size[1] * ratio)) - image = image.resize(new_size, Image.Resampling.LANCZOS) + ratio = max_size / max(image.size) + new_size = (int(image.size[0] * ratio), int(image.size[1] * ratio)) + image = image.resize(new_size, Image.Resampling.LANCZOS) return image except Exception as e: @@ -151,33 +151,33 @@ class ImageProcessor: def _enhance_whiteboard(self, image) -> None: """增强白板图片""" # 转换为灰度 - gray = image.convert("L") + gray = image.convert("L") # 增强对比度 - enhancer = ImageEnhance.Contrast(gray) - enhanced = enhancer.enhance(2.0) + enhancer = ImageEnhance.Contrast(gray) + enhanced = enhancer.enhance(2.0) # 二值化 - threshold = 128 - binary = enhanced.point(lambda x: 0 if x < threshold else 255, "1") + threshold = 128 + binary = enhanced.point(lambda x: 0 if x < threshold else 255, "1") return binary.convert("L") def _enhance_handwritten(self, image) -> None: """增强手写笔记图片""" # 转换为灰度 - gray = image.convert("L") + gray = image.convert("L") # 轻微降噪 - blurred = gray.filter(ImageFilter.GaussianBlur(radius = 1)) + blurred = gray.filter(ImageFilter.GaussianBlur(radius=1)) # 增强对比度 - enhancer = ImageEnhance.Contrast(blurred) - enhanced = enhancer.enhance(1.5) + enhancer = ImageEnhance.Contrast(blurred) + enhanced = enhancer.enhance(1.5) return enhanced - def detect_image_type(self, image, ocr_text: str = "") -> str: + def detect_image_type(self, image, ocr_text: str = "") -> str: """ 自动检测图片类型 @@ -193,8 +193,8 @@ class ImageProcessor: try: # 基于图片特征和OCR内容判断类型 - width, height = image.size - aspect_ratio = width / height + width, height = image.size + aspect_ratio = width / height # 检测是否为PPT(通常是16:9或4:3) if 1.3 <= aspect_ratio <= 1.8: @@ -204,12 +204,12 @@ class ImageProcessor: # 检测是否为白板(大量手写文字,可能有箭头、框等) if CV2_AVAILABLE: - img_array = np.array(image.convert("RGB")) - gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY) + img_array = np.array(image.convert("RGB")) + gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY) # 检测边缘(白板通常有很多线条) - edges = cv2.Canny(gray, 50, 150) - edge_ratio = np.sum(edges > 0) / edges.size + edges = cv2.Canny(gray, 50, 150) + edge_ratio = np.sum(edges > 0) / edges.size # 如果边缘比例高,可能是白板 if edge_ratio > 0.05 and len(ocr_text) > 50: @@ -236,7 +236,7 @@ class ImageProcessor: print(f"Image type detection error: {e}") return "other" - def perform_ocr(self, image, lang: str = "chi_sim+eng") -> tuple[str, float]: + def perform_ocr(self, image, lang: str = "chi_sim+eng") -> tuple[str, float]: """ 对图片进行OCR识别 @@ -252,15 +252,15 @@ class ImageProcessor: try: # 预处理图片 - processed_image = self.preprocess_image(image) + processed_image = self.preprocess_image(image) # 执行OCR - text = pytesseract.image_to_string(processed_image, lang = lang) + text = pytesseract.image_to_string(processed_image, lang=lang) # 获取置信度 - data = pytesseract.image_to_data(processed_image, output_type = pytesseract.Output.DICT) - confidences = [int(c) for c in data["conf"] if int(c) > 0] - avg_confidence = sum(confidences) / len(confidences) if confidences else 0 + data = pytesseract.image_to_data(processed_image, output_type=pytesseract.Output.DICT) + confidences = [int(c) for c in data["conf"] if int(c) > 0] + avg_confidence = sum(confidences) / len(confidences) if confidences else 0 return text.strip(), avg_confidence / 100.0 except Exception as e: @@ -277,26 +277,26 @@ class ImageProcessor: Returns: 实体列表 """ - entities = [] + entities = [] # 简单的实体提取规则(可以替换为LLM调用) # 提取大写字母开头的词组(可能是专有名词) import re # 项目名称(通常是大写或带引号) - project_pattern = r'["\']([^"\']+)["\']|([A-Z][a-zA-Z0-9]*(?:\s+[A-Z][a-zA-Z0-9]*)+)' + project_pattern = r'["\']([^"\']+)["\']|([A-Z][a-zA-Z0-9]*(?:\s+[A-Z][a-zA-Z0-9]*)+)' for match in re.finditer(project_pattern, text): - name = match.group(1) or match.group(2) + name = match.group(1) or match.group(2) if name and len(name) > 2: - entities.append(ImageEntity(name = name.strip(), type = "PROJECT", confidence = 0.7)) + entities.append(ImageEntity(name=name.strip(), type="PROJECT", confidence=0.7)) # 人名(中文) - name_pattern = r"([\u4e00-\u9fa5]{2, 4})(?:先生|女士|总|经理|工程师|老师)" + name_pattern = r"([\u4e00-\u9fa5]{2, 4})(?:先生|女士|总|经理|工程师|老师)" for match in re.finditer(name_pattern, text): - entities.append(ImageEntity(name = match.group(1), type = "PERSON", confidence = 0.8)) + entities.append(ImageEntity(name=match.group(1), type="PERSON", confidence=0.8)) # 技术术语 - tech_keywords = [ + tech_keywords = [ "K8s", "Kubernetes", "Docker", @@ -314,13 +314,13 @@ class ImageProcessor: ] for keyword in tech_keywords: if keyword in text: - entities.append(ImageEntity(name = keyword, type = "TECH", confidence = 0.9)) + entities.append(ImageEntity(name=keyword, type="TECH", confidence=0.9)) # 去重 - seen = set() - unique_entities = [] + seen = set() + unique_entities = [] for e in entities: - key = (e.name.lower(), e.type) + key = (e.name.lower(), e.type) if key not in seen: seen.add(key) unique_entities.append(e) @@ -341,19 +341,19 @@ class ImageProcessor: Returns: 图片描述 """ - type_name = self.IMAGE_TYPES.get(image_type, "图片") + type_name = self.IMAGE_TYPES.get(image_type, "图片") - description_parts = [f"这是一张{type_name}图片。"] + description_parts = [f"这是一张{type_name}图片。"] if ocr_text: # 提取前200字符作为摘要 - text_preview = ocr_text[:200].replace("\n", " ") + text_preview = ocr_text[:200].replace("\n", " ") if len(ocr_text) > 200: text_preview += "..." description_parts.append(f"内容摘要:{text_preview}") if entities: - entity_names = [e.name for e in entities[:5]] # 最多显示5个实体 + entity_names = [e.name for e in entities[:5]] # 最多显示5个实体 description_parts.append(f"识别到的关键实体:{', '.join(entity_names)}") return " ".join(description_parts) @@ -361,9 +361,9 @@ class ImageProcessor: def process_image( self, image_data: bytes, - filename: str = None, - image_id: str = None, - detect_type: bool = True, + filename: str = None, + image_id: str = None, + detect_type: bool = True, ) -> ImageProcessingResult: """ 处理单张图片 @@ -377,73 +377,73 @@ class ImageProcessor: Returns: 图片处理结果 """ - image_id = image_id or str(uuid.uuid4())[:UUID_LENGTH] + image_id = image_id or str(uuid.uuid4())[:UUID_LENGTH] if not PIL_AVAILABLE: return ImageProcessingResult( - image_id = image_id, - image_type = "other", - ocr_text = "", - description = "PIL not available", - entities = [], - relations = [], - width = 0, - height = 0, - success = False, - error_message = "PIL library not available", + image_id=image_id, + image_type="other", + ocr_text="", + description="PIL not available", + entities=[], + relations=[], + width=0, + height=0, + success=False, + error_message="PIL library not available", ) try: # 加载图片 - image = Image.open(io.BytesIO(image_data)) - width, height = image.size + image = Image.open(io.BytesIO(image_data)) + width, height = image.size # 执行OCR - ocr_text, ocr_confidence = self.perform_ocr(image) + ocr_text, ocr_confidence = self.perform_ocr(image) # 检测图片类型 - image_type = "other" + image_type = "other" if detect_type: - image_type = self.detect_image_type(image, ocr_text) + image_type = self.detect_image_type(image, ocr_text) # 提取实体 - entities = self.extract_entities_from_text(ocr_text) + entities = self.extract_entities_from_text(ocr_text) # 生成描述 - description = self.generate_description(image_type, ocr_text, entities) + description = self.generate_description(image_type, ocr_text, entities) # 提取关系(基于实体共现) - relations = self._extract_relations(entities, ocr_text) + relations = self._extract_relations(entities, ocr_text) # 保存图片文件(可选) if filename: - save_path = os.path.join(self.temp_dir, f"{image_id}_{filename}") + save_path = os.path.join(self.temp_dir, f"{image_id}_{filename}") image.save(save_path) return ImageProcessingResult( - image_id = image_id, - image_type = image_type, - ocr_text = ocr_text, - description = description, - entities = entities, - relations = relations, - width = width, - height = height, - success = True, + image_id=image_id, + image_type=image_type, + ocr_text=ocr_text, + description=description, + entities=entities, + relations=relations, + width=width, + height=height, + success=True, ) except Exception as e: return ImageProcessingResult( - image_id = image_id, - image_type = "other", - ocr_text = "", - description = "", - entities = [], - relations = [], - width = 0, - height = 0, - success = False, - error_message = str(e), + image_id=image_id, + image_type="other", + ocr_text="", + description="", + entities=[], + relations=[], + width=0, + height=0, + success=False, + error_message=str(e), ) def _extract_relations(self, entities: list[ImageEntity], text: str) -> list[ImageRelation]: @@ -457,16 +457,16 @@ class ImageProcessor: Returns: 关系列表 """ - relations = [] + relations = [] if len(entities) < 2: return relations # 简单的关系提取:如果两个实体在同一句子中出现,则认为它们相关 - sentences = text.replace("。", ".").replace("!", "!").replace("?", "?").split(".") + sentences = text.replace("。", ".").replace("!", "!").replace("?", "?").split(".") for sentence in sentences: - sentence_entities = [] + sentence_entities = [] for entity in entities: if entity.name in sentence: sentence_entities.append(entity) @@ -477,17 +477,17 @@ class ImageProcessor: for j in range(i + 1, len(sentence_entities)): relations.append( ImageRelation( - source = sentence_entities[i].name, - target = sentence_entities[j].name, - relation_type = "related", - confidence = 0.5, + source=sentence_entities[i].name, + target=sentence_entities[j].name, + relation_type="related", + confidence=0.5, ) ) return relations def process_batch( - self, images_data: list[tuple[bytes, str]], project_id: str = None + self, images_data: list[tuple[bytes, str]], project_id: str = None ) -> BatchProcessingResult: """ 批量处理图片 @@ -499,12 +499,12 @@ class ImageProcessor: Returns: 批量处理结果 """ - results = [] - success_count = 0 - failed_count = 0 + results = [] + success_count = 0 + failed_count = 0 for image_data, filename in images_data: - result = self.process_image(image_data, filename) + result = self.process_image(image_data, filename) results.append(result) if result.success: @@ -513,10 +513,10 @@ class ImageProcessor: failed_count += 1 return BatchProcessingResult( - results = results, - total_count = len(results), - success_count = success_count, - failed_count = failed_count, + results=results, + total_count=len(results), + success_count=success_count, + failed_count=failed_count, ) def image_to_base64(self, image_data: bytes) -> str: @@ -531,7 +531,7 @@ class ImageProcessor: """ return base64.b64encode(image_data).decode("utf-8") - def get_image_thumbnail(self, image_data: bytes, size: tuple[int, int] = (200, 200)) -> bytes: + def get_image_thumbnail(self, image_data: bytes, size: tuple[int, int] = (200, 200)) -> bytes: """ 生成图片缩略图 @@ -546,11 +546,11 @@ class ImageProcessor: return image_data try: - image = Image.open(io.BytesIO(image_data)) + image = Image.open(io.BytesIO(image_data)) image.thumbnail(size, Image.Resampling.LANCZOS) - buffer = io.BytesIO() - image.save(buffer, format = "JPEG") + buffer = io.BytesIO() + image.save(buffer, format="JPEG") return buffer.getvalue() except Exception as e: print(f"Thumbnail generation error: {e}") @@ -558,12 +558,12 @@ class ImageProcessor: # Singleton instance -_image_processor = None +_image_processor = None -def get_image_processor(temp_dir: str = None) -> ImageProcessor: +def get_image_processor(temp_dir: str = None) -> ImageProcessor: """获取图片处理器单例""" global _image_processor if _image_processor is None: - _image_processor = ImageProcessor(temp_dir) + _image_processor = ImageProcessor(temp_dir) return _image_processor diff --git a/backend/init_db.py b/backend/init_db.py index db80146..7cd7778 100644 --- a/backend/init_db.py +++ b/backend/init_db.py @@ -4,27 +4,27 @@ import os import sqlite3 -db_path = os.path.join(os.path.dirname(__file__), "insightflow.db") -schema_path = os.path.join(os.path.dirname(__file__), "schema.sql") +db_path = os.path.join(os.path.dirname(__file__), "insightflow.db") +schema_path = os.path.join(os.path.dirname(__file__), "schema.sql") print(f"Database path: {db_path}") print(f"Schema path: {schema_path}") # Read schema with open(schema_path) as f: - schema = f.read() + schema = f.read() # Execute schema -conn = sqlite3.connect(db_path) -cursor = conn.cursor() +conn = sqlite3.connect(db_path) +cursor = conn.cursor() # Split schema by semicolons and execute each statement -statements = schema.split(";") -success_count = 0 -error_count = 0 +statements = schema.split(";") +success_count = 0 +error_count = 0 for stmt in statements: - stmt = stmt.strip() + stmt = stmt.strip() if stmt: try: cursor.execute(stmt) diff --git a/backend/knowledge_reasoner.py b/backend/knowledge_reasoner.py index beb7e2c..3b3a397 100644 --- a/backend/knowledge_reasoner.py +++ b/backend/knowledge_reasoner.py @@ -12,18 +12,18 @@ from enum import Enum import httpx -KIMI_API_KEY = os.getenv("KIMI_API_KEY", "") -KIMI_BASE_URL = os.getenv("KIMI_BASE_URL", "https://api.kimi.com/coding") +KIMI_API_KEY = os.getenv("KIMI_API_KEY", "") +KIMI_BASE_URL = os.getenv("KIMI_BASE_URL", "https://api.kimi.com/coding") class ReasoningType(Enum): """推理类型""" - CAUSAL = "causal" # 因果推理 - ASSOCIATIVE = "associative" # 关联推理 - TEMPORAL = "temporal" # 时序推理 - COMPARATIVE = "comparative" # 对比推理 - SUMMARY = "summary" # 总结推理 + CAUSAL = "causal" # 因果推理 + ASSOCIATIVE = "associative" # 关联推理 + TEMPORAL = "temporal" # 时序推理 + COMPARATIVE = "comparative" # 对比推理 + SUMMARY = "summary" # 总结推理 @dataclass @@ -51,38 +51,38 @@ class InferencePath: class KnowledgeReasoner: """知识推理引擎""" - def __init__(self, api_key: str = None, base_url: str = None) -> None: - self.api_key = api_key or KIMI_API_KEY - self.base_url = base_url or KIMI_BASE_URL - self.headers = { + def __init__(self, api_key: str = None, base_url: str = None) -> None: + self.api_key = api_key or KIMI_API_KEY + self.base_url = base_url or KIMI_BASE_URL + self.headers = { "Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json", } - async def _call_llm(self, prompt: str, temperature: float = 0.3) -> str: + async def _call_llm(self, prompt: str, temperature: float = 0.3) -> str: """调用 LLM""" if not self.api_key: raise ValueError("KIMI_API_KEY not set") - payload = { + payload = { "model": "k2p5", "messages": [{"role": "user", "content": prompt}], "temperature": temperature, } async with httpx.AsyncClient() as client: - response = await client.post( + response = await client.post( f"{self.base_url}/v1/chat/completions", - headers = self.headers, - json = payload, - timeout = 120.0, + headers=self.headers, + json=payload, + timeout=120.0, ) response.raise_for_status() - result = response.json() + result = response.json() return result["choices"][0]["message"]["content"] async def enhanced_qa( - self, query: str, project_context: dict, graph_data: dict, reasoning_depth: str = "medium" + self, query: str, project_context: dict, graph_data: dict, reasoning_depth: str = "medium" ) -> ReasoningResult: """ 增强问答 - 结合图谱推理的问答 @@ -94,7 +94,7 @@ class KnowledgeReasoner: reasoning_depth: 推理深度 (shallow/medium/deep) """ # 1. 分析问题类型 - analysis = await self._analyze_question(query) + analysis = await self._analyze_question(query) # 2. 根据问题类型选择推理策略 if analysis["type"] == "causal": @@ -108,7 +108,7 @@ class KnowledgeReasoner: async def _analyze_question(self, query: str) -> dict: """分析问题类型和意图""" - prompt = f"""分析以下问题的类型和意图: + prompt = f"""分析以下问题的类型和意图: 问题:{query} @@ -127,9 +127,9 @@ class KnowledgeReasoner: - factual: 事实类问题(是什么、有哪些) - opinion: 观点类问题(怎么看、态度、评价)""" - content = await self._call_llm(prompt, temperature = 0.1) + content = await self._call_llm(prompt, temperature=0.1) - json_match = re.search(r"\{{.*?\}}", content, re.DOTALL) + json_match = re.search(r"\{{.*?\}}", content, re.DOTALL) if json_match: try: return json.loads(json_match.group()) @@ -144,10 +144,10 @@ class KnowledgeReasoner: """因果推理 - 分析原因和影响""" # 构建因果分析提示 - entities_str = json.dumps(graph_data.get("entities", []), ensure_ascii = False, indent = 2) - relations_str = json.dumps(graph_data.get("relations", []), ensure_ascii = False, indent = 2) + entities_str = json.dumps(graph_data.get("entities", []), ensure_ascii=False, indent=2) + relations_str = json.dumps(graph_data.get("relations", []), ensure_ascii=False, indent=2) - prompt = f"""基于以下知识图谱进行因果推理分析: + prompt = f"""基于以下知识图谱进行因果推理分析: ## 问题 {query} @@ -159,7 +159,7 @@ class KnowledgeReasoner: {relations_str[:2000]} ## 项目上下文 -{json.dumps(project_context, ensure_ascii = False, indent = 2)[:1500]} +{json.dumps(project_context, ensure_ascii=False, indent=2)[:1500]} 请进行因果分析,返回 JSON 格式: {{ @@ -172,31 +172,31 @@ class KnowledgeReasoner: "knowledge_gaps": ["缺失信息1"] }}""" - content = await self._call_llm(prompt, temperature = 0.3) + content = await self._call_llm(prompt, temperature=0.3) - json_match = re.search(r"\{{.*?\}}", content, re.DOTALL) + json_match = re.search(r"\{{.*?\}}", content, re.DOTALL) if json_match: try: - data = json.loads(json_match.group()) + data = json.loads(json_match.group()) return ReasoningResult( - answer = data.get("answer", ""), - reasoning_type = ReasoningType.CAUSAL, - confidence = data.get("confidence", 0.7), - evidence = [{"text": e} for e in data.get("evidence", [])], - related_entities = [], - gaps = data.get("knowledge_gaps", []), + answer=data.get("answer", ""), + reasoning_type=ReasoningType.CAUSAL, + confidence=data.get("confidence", 0.7), + evidence=[{"text": e} for e in data.get("evidence", [])], + related_entities=[], + gaps=data.get("knowledge_gaps", []), ) except (json.JSONDecodeError, KeyError): pass return ReasoningResult( - answer = content, - reasoning_type = ReasoningType.CAUSAL, - confidence = 0.5, - evidence = [], - related_entities = [], - gaps = ["无法完成因果推理"], + answer=content, + reasoning_type=ReasoningType.CAUSAL, + confidence=0.5, + evidence=[], + related_entities=[], + gaps=["无法完成因果推理"], ) async def _comparative_reasoning( @@ -204,16 +204,16 @@ class KnowledgeReasoner: ) -> ReasoningResult: """对比推理 - 比较实体间的异同""" - prompt = f"""基于以下知识图谱进行对比分析: + prompt = f"""基于以下知识图谱进行对比分析: ## 问题 {query} ## 实体 -{json.dumps(graph_data.get("entities", []), ensure_ascii = False, indent = 2)[:2000]} +{json.dumps(graph_data.get("entities", []), ensure_ascii=False, indent=2)[:2000]} ## 关系 -{json.dumps(graph_data.get("relations", []), ensure_ascii = False, indent = 2)[:1500]} +{json.dumps(graph_data.get("relations", []), ensure_ascii=False, indent=2)[:1500]} 请进行对比分析,返回 JSON 格式: {{ @@ -226,31 +226,31 @@ class KnowledgeReasoner: "knowledge_gaps": [] }}""" - content = await self._call_llm(prompt, temperature = 0.3) + content = await self._call_llm(prompt, temperature=0.3) - json_match = re.search(r"\{{.*?\}}", content, re.DOTALL) + json_match = re.search(r"\{{.*?\}}", content, re.DOTALL) if json_match: try: - data = json.loads(json_match.group()) + data = json.loads(json_match.group()) return ReasoningResult( - answer = data.get("answer", ""), - reasoning_type = ReasoningType.COMPARATIVE, - confidence = data.get("confidence", 0.7), - evidence = [{"text": e} for e in data.get("evidence", [])], - related_entities = [], - gaps = data.get("knowledge_gaps", []), + answer=data.get("answer", ""), + reasoning_type=ReasoningType.COMPARATIVE, + confidence=data.get("confidence", 0.7), + evidence=[{"text": e} for e in data.get("evidence", [])], + related_entities=[], + gaps=data.get("knowledge_gaps", []), ) except (json.JSONDecodeError, KeyError): pass return ReasoningResult( - answer = content, - reasoning_type = ReasoningType.COMPARATIVE, - confidence = 0.5, - evidence = [], - related_entities = [], - gaps = [], + answer=content, + reasoning_type=ReasoningType.COMPARATIVE, + confidence=0.5, + evidence=[], + related_entities=[], + gaps=[], ) async def _temporal_reasoning( @@ -258,16 +258,16 @@ class KnowledgeReasoner: ) -> ReasoningResult: """时序推理 - 分析时间线和演变""" - prompt = f"""基于以下知识图谱进行时序分析: + prompt = f"""基于以下知识图谱进行时序分析: ## 问题 {query} ## 项目时间线 -{json.dumps(project_context.get("timeline", []), ensure_ascii = False, indent = 2)[:2000]} +{json.dumps(project_context.get("timeline", []), ensure_ascii=False, indent=2)[:2000]} ## 实体提及历史 -{json.dumps(graph_data.get("entities", []), ensure_ascii = False, indent = 2)[:1500]} +{json.dumps(graph_data.get("entities", []), ensure_ascii=False, indent=2)[:1500]} 请进行时序分析,返回 JSON 格式: {{ @@ -280,31 +280,31 @@ class KnowledgeReasoner: "knowledge_gaps": [] }}""" - content = await self._call_llm(prompt, temperature = 0.3) + content = await self._call_llm(prompt, temperature=0.3) - json_match = re.search(r"\{{.*?\}}", content, re.DOTALL) + json_match = re.search(r"\{{.*?\}}", content, re.DOTALL) if json_match: try: - data = json.loads(json_match.group()) + data = json.loads(json_match.group()) return ReasoningResult( - answer = data.get("answer", ""), - reasoning_type = ReasoningType.TEMPORAL, - confidence = data.get("confidence", 0.7), - evidence = [{"text": e} for e in data.get("evidence", [])], - related_entities = [], - gaps = data.get("knowledge_gaps", []), + answer=data.get("answer", ""), + reasoning_type=ReasoningType.TEMPORAL, + confidence=data.get("confidence", 0.7), + evidence=[{"text": e} for e in data.get("evidence", [])], + related_entities=[], + gaps=data.get("knowledge_gaps", []), ) except (json.JSONDecodeError, KeyError): pass return ReasoningResult( - answer = content, - reasoning_type = ReasoningType.TEMPORAL, - confidence = 0.5, - evidence = [], - related_entities = [], - gaps = [], + answer=content, + reasoning_type=ReasoningType.TEMPORAL, + confidence=0.5, + evidence=[], + related_entities=[], + gaps=[], ) async def _associative_reasoning( @@ -312,16 +312,16 @@ class KnowledgeReasoner: ) -> ReasoningResult: """关联推理 - 发现实体间的隐含关联""" - prompt = f"""基于以下知识图谱进行关联分析: + prompt = f"""基于以下知识图谱进行关联分析: ## 问题 {query} ## 实体 -{json.dumps(graph_data.get("entities", [])[:20], ensure_ascii = False, indent = 2)} +{json.dumps(graph_data.get("entities", [])[:20], ensure_ascii=False, indent=2)} ## 关系 -{json.dumps(graph_data.get("relations", [])[:30], ensure_ascii = False, indent = 2)} +{json.dumps(graph_data.get("relations", [])[:30], ensure_ascii=False, indent=2)} 请进行关联推理,发现隐含联系,返回 JSON 格式: {{ @@ -334,52 +334,52 @@ class KnowledgeReasoner: "knowledge_gaps": [] }}""" - content = await self._call_llm(prompt, temperature = 0.4) + content = await self._call_llm(prompt, temperature=0.4) - json_match = re.search(r"\{{.*?\}}", content, re.DOTALL) + json_match = re.search(r"\{{.*?\}}", content, re.DOTALL) if json_match: try: - data = json.loads(json_match.group()) + data = json.loads(json_match.group()) return ReasoningResult( - answer = data.get("answer", ""), - reasoning_type = ReasoningType.ASSOCIATIVE, - confidence = data.get("confidence", 0.7), - evidence = [{"text": e} for e in data.get("evidence", [])], - related_entities = [], - gaps = data.get("knowledge_gaps", []), + answer=data.get("answer", ""), + reasoning_type=ReasoningType.ASSOCIATIVE, + confidence=data.get("confidence", 0.7), + evidence=[{"text": e} for e in data.get("evidence", [])], + related_entities=[], + gaps=data.get("knowledge_gaps", []), ) except (json.JSONDecodeError, KeyError): pass return ReasoningResult( - answer = content, - reasoning_type = ReasoningType.ASSOCIATIVE, - confidence = 0.5, - evidence = [], - related_entities = [], - gaps = [], + answer=content, + reasoning_type=ReasoningType.ASSOCIATIVE, + confidence=0.5, + evidence=[], + related_entities=[], + gaps=[], ) def find_inference_paths( - self, start_entity: str, end_entity: str, graph_data: dict, max_depth: int = 3 + self, start_entity: str, end_entity: str, graph_data: dict, max_depth: int = 3 ) -> list[InferencePath]: """ 发现两个实体之间的推理路径 使用 BFS 在关系图中搜索路径 """ - relations = graph_data.get("relations", []) + relations = graph_data.get("relations", []) # 构建邻接表 - adj = {} + adj = {} for r in relations: - src = r.get("source_id") or r.get("source") - tgt = r.get("target_id") or r.get("target") + src = r.get("source_id") or r.get("source") + tgt = r.get("target_id") or r.get("target") if src not in adj: - adj[src] = [] + adj[src] = [] if tgt not in adj: - adj[tgt] = [] + adj[tgt] = [] adj[src].append({"target": tgt, "relation": r.get("type", "related"), "data": r}) # 无向图也添加反向 adj[tgt].append( @@ -389,21 +389,21 @@ class KnowledgeReasoner: # BFS 搜索路径 from collections import deque - paths = [] - queue = deque([(start_entity, [{"entity": start_entity, "relation": None}])]) + paths = [] + queue = deque([(start_entity, [{"entity": start_entity, "relation": None}])]) {start_entity} while queue and len(paths) < 5: - current, path = queue.popleft() + current, path = queue.popleft() if current == end_entity and len(path) > 1: # 找到一条路径 paths.append( InferencePath( - start_entity = start_entity, - end_entity = end_entity, - path = path, - strength = self._calculate_path_strength(path), + start_entity=start_entity, + end_entity=end_entity, + path=path, + strength=self._calculate_path_strength(path), ) ) continue @@ -412,9 +412,9 @@ class KnowledgeReasoner: continue for neighbor in adj.get(current, []): - next_entity = neighbor["target"] + next_entity = neighbor["target"] if next_entity not in [p["entity"] for p in path]: # 避免循环 - new_path = path + [ + new_path = path + [ { "entity": next_entity, "relation": neighbor["relation"], @@ -424,7 +424,7 @@ class KnowledgeReasoner: queue.append((next_entity, new_path)) # 按强度排序 - paths.sort(key = lambda p: p.strength, reverse = True) + paths.sort(key=lambda p: p.strength, reverse=True) return paths def _calculate_path_strength(self, path: list[dict]) -> float: @@ -433,23 +433,23 @@ class KnowledgeReasoner: return 0.0 # 路径越短越强 - length_factor = 1.0 / len(path) + length_factor = 1.0 / len(path) # 关系置信度 - confidence_sum = 0 - confidence_count = 0 + confidence_sum = 0 + confidence_count = 0 for node in path[1:]: # 跳过第一个节点 - rel_data = node.get("relation_data", {}) + rel_data = node.get("relation_data", {}) if "confidence" in rel_data: confidence_sum += rel_data["confidence"] confidence_count += 1 - confidence_factor = (confidence_sum / confidence_count) if confidence_count > 0 else 0.5 + confidence_factor = (confidence_sum / confidence_count) if confidence_count > 0 else 0.5 return length_factor * confidence_factor async def summarize_project( - self, project_context: dict, graph_data: dict, summary_type: str = "comprehensive" + self, project_context: dict, graph_data: dict, summary_type: str = "comprehensive" ) -> dict: """ 项目智能总结 @@ -457,17 +457,17 @@ class KnowledgeReasoner: Args: summary_type: comprehensive/executive/technical/risk """ - type_prompts = { + type_prompts = { "comprehensive": "全面总结项目的所有方面", "executive": "高管摘要,关注关键决策和风险", "technical": "技术总结,关注架构和技术栈", "risk": "风险分析,关注潜在问题和依赖", } - prompt = f"""请对以下项目进行{type_prompts.get(summary_type, "全面总结")}: + prompt = f"""请对以下项目进行{type_prompts.get(summary_type, "全面总结")}: ## 项目信息 -{json.dumps(project_context, ensure_ascii = False, indent = 2)[:3000]} +{json.dumps(project_context, ensure_ascii=False, indent=2)[:3000]} ## 知识图谱 实体数: {len(graph_data.get("entities", []))} @@ -483,9 +483,9 @@ class KnowledgeReasoner: "confidence": 0.85 }}""" - content = await self._call_llm(prompt, temperature = 0.3) + content = await self._call_llm(prompt, temperature=0.3) - json_match = re.search(r"\{{.*?\}}", content, re.DOTALL) + json_match = re.search(r"\{{.*?\}}", content, re.DOTALL) if json_match: try: @@ -504,11 +504,11 @@ class KnowledgeReasoner: # Singleton instance -_reasoner = None +_reasoner = None def get_knowledge_reasoner() -> KnowledgeReasoner: global _reasoner if _reasoner is None: - _reasoner = KnowledgeReasoner() + _reasoner = KnowledgeReasoner() return _reasoner diff --git a/backend/llm_client.py b/backend/llm_client.py index 368ffed..e3fec1d 100644 --- a/backend/llm_client.py +++ b/backend/llm_client.py @@ -12,8 +12,8 @@ from dataclasses import dataclass import httpx -KIMI_API_KEY = os.getenv("KIMI_API_KEY", "") -KIMI_BASE_URL = os.getenv("KIMI_BASE_URL", "https://api.kimi.com/coding") +KIMI_API_KEY = os.getenv("KIMI_API_KEY", "") +KIMI_BASE_URL = os.getenv("KIMI_BASE_URL", "https://api.kimi.com/coding") @dataclass @@ -41,22 +41,22 @@ class RelationExtractionResult: class LLMClient: """Kimi API 客户端""" - def __init__(self, api_key: str = None, base_url: str = None) -> None: - self.api_key = api_key or KIMI_API_KEY - self.base_url = base_url or KIMI_BASE_URL - self.headers = { + def __init__(self, api_key: str = None, base_url: str = None) -> None: + self.api_key = api_key or KIMI_API_KEY + self.base_url = base_url or KIMI_BASE_URL + self.headers = { "Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json", } async def chat( - self, messages: list[ChatMessage], temperature: float = 0.3, stream: bool = False + self, messages: list[ChatMessage], temperature: float = 0.3, stream: bool = False ) -> str: """发送聊天请求""" if not self.api_key: raise ValueError("KIMI_API_KEY not set") - payload = { + payload = { "model": "k2p5", "messages": [{"role": m.role, "content": m.content} for m in messages], "temperature": temperature, @@ -64,24 +64,24 @@ class LLMClient: } async with httpx.AsyncClient() as client: - response = await client.post( + response = await client.post( f"{self.base_url}/v1/chat/completions", - headers = self.headers, - json = payload, - timeout = 120.0, + headers=self.headers, + json=payload, + timeout=120.0, ) response.raise_for_status() - result = response.json() + result = response.json() return result["choices"][0]["message"]["content"] async def chat_stream( - self, messages: list[ChatMessage], temperature: float = 0.3 + self, messages: list[ChatMessage], temperature: float = 0.3 ) -> AsyncGenerator[str, None]: """流式聊天请求""" if not self.api_key: raise ValueError("KIMI_API_KEY not set") - payload = { + payload = { "model": "k2p5", "messages": [{"role": m.role, "content": m.content} for m in messages], "temperature": temperature, @@ -92,19 +92,19 @@ class LLMClient: async with client.stream( "POST", f"{self.base_url}/v1/chat/completions", - headers = self.headers, - json = payload, - timeout = 120.0, + headers=self.headers, + json=payload, + timeout=120.0, ) as response: response.raise_for_status() async for line in response.aiter_lines(): if line.startswith("data: "): - data = line[6:] + data = line[6:] if data == "[DONE]": break try: - chunk = json.loads(data) - delta = chunk["choices"][0]["delta"] + chunk = json.loads(data) + delta = chunk["choices"][0]["delta"] if "content" in delta: yield delta["content"] except (json.JSONDecodeError, KeyError, IndexError): @@ -114,7 +114,7 @@ class LLMClient: self, text: str ) -> tuple[list[EntityExtractionResult], list[RelationExtractionResult]]: """提取实体和关系,带置信度分数""" - prompt = f"""从以下会议文本中提取关键实体和它们之间的关系,以 JSON 格式返回: + prompt = f"""从以下会议文本中提取关键实体和它们之间的关系,以 JSON 格式返回: 文本:{text[:3000]} @@ -139,30 +139,30 @@ class LLMClient: ] }}""" - messages = [ChatMessage(role = "user", content = prompt)] - content = await self.chat(messages, temperature = 0.1) + messages = [ChatMessage(role="user", content=prompt)] + content = await self.chat(messages, temperature=0.1) - json_match = re.search(r"\{{.*?\}}", content, re.DOTALL) + json_match = re.search(r"\{{.*?\}}", content, re.DOTALL) if not json_match: return [], [] try: - data = json.loads(json_match.group()) - entities = [ + data = json.loads(json_match.group()) + entities = [ EntityExtractionResult( - name = e["name"], - type = e.get("type", "OTHER"), - definition = e.get("definition", ""), - confidence = e.get("confidence", 0.8), + name=e["name"], + type=e.get("type", "OTHER"), + definition=e.get("definition", ""), + confidence=e.get("confidence", 0.8), ) for e in data.get("entities", []) ] - relations = [ + relations = [ RelationExtractionResult( - source = r["source"], - target = r["target"], - type = r.get("type", "related"), - confidence = r.get("confidence", 0.8), + source=r["source"], + target=r["target"], + type=r.get("type", "related"), + confidence=r.get("confidence", 0.8), ) for r in data.get("relations", []) ] @@ -173,10 +173,10 @@ class LLMClient: async def rag_query(self, query: str, context: str, project_context: dict) -> str: """RAG 问答 - 基于项目上下文回答问题""" - prompt = f"""你是一个专业的项目分析助手。基于以下项目信息回答问题: + prompt = f"""你是一个专业的项目分析助手。基于以下项目信息回答问题: ## 项目信息 -{json.dumps(project_context, ensure_ascii = False, indent = 2)} +{json.dumps(project_context, ensure_ascii=False, indent=2)} ## 相关上下文 {context[:4000]} @@ -186,21 +186,21 @@ class LLMClient: 请用中文回答,保持简洁专业。如果信息不足,请明确说明。""" - messages = [ + messages = [ ChatMessage( - role = "system", content = "你是一个专业的项目分析助手,擅长从会议记录中提取洞察。" + role="system", content="你是一个专业的项目分析助手,擅长从会议记录中提取洞察。" ), - ChatMessage(role = "user", content = prompt), + ChatMessage(role="user", content=prompt), ] - return await self.chat(messages, temperature = 0.3) + return await self.chat(messages, temperature=0.3) async def agent_command(self, command: str, project_context: dict) -> dict: """Agent 指令解析 - 将自然语言指令转换为结构化操作""" - prompt = f"""解析以下用户指令,转换为结构化操作: + prompt = f"""解析以下用户指令,转换为结构化操作: ## 项目信息 -{json.dumps(project_context, ensure_ascii = False, indent = 2)} +{json.dumps(project_context, ensure_ascii=False, indent=2)} ## 用户指令 {command} @@ -221,10 +221,10 @@ class LLMClient: - create_relation: 创建关系,params 包含 source(源实体), target(目标实体), relation_type(关系类型) """ - messages = [ChatMessage(role = "user", content = prompt)] - content = await self.chat(messages, temperature = 0.1) + messages = [ChatMessage(role="user", content=prompt)] + content = await self.chat(messages, temperature=0.1) - json_match = re.search(r"\{{.*?\}}", content, re.DOTALL) + json_match = re.search(r"\{{.*?\}}", content, re.DOTALL) if not json_match: return {"intent": "unknown", "explanation": "无法解析指令"} @@ -235,14 +235,14 @@ class LLMClient: async def analyze_entity_evolution(self, entity_name: str, mentions: list[dict]) -> str: """分析实体在项目中的演变/态度变化""" - mentions_text = "\n".join( + mentions_text = "\n".join( [ f"[{m.get('created_at', '未知时间')}] {m.get('text_snippet', '')}" for m in mentions[:20] ] # 限制数量 ) - prompt = f"""分析实体 "{entity_name}" 在项目中的演变和态度变化: + prompt = f"""分析实体 "{entity_name}" 在项目中的演变和态度变化: ## 提及记录 {mentions_text} @@ -255,16 +255,16 @@ class LLMClient: 用中文回答,结构清晰。""" - messages = [ChatMessage(role = "user", content = prompt)] - return await self.chat(messages, temperature = 0.3) + messages = [ChatMessage(role="user", content=prompt)] + return await self.chat(messages, temperature=0.3) # Singleton instance -_llm_client = None +_llm_client = None def get_llm_client() -> LLMClient: global _llm_client if _llm_client is None: - _llm_client = LLMClient() + _llm_client = LLMClient() return _llm_client diff --git a/backend/localization_manager.py b/backend/localization_manager.py index 344a95f..f150ef9 100644 --- a/backend/localization_manager.py +++ b/backend/localization_manager.py @@ -22,90 +22,90 @@ from typing import Any try: import pytz - PYTZ_AVAILABLE = True + PYTZ_AVAILABLE = True except ImportError: - PYTZ_AVAILABLE = False + PYTZ_AVAILABLE = False try: from babel import Locale, dates, numbers - BABEL_AVAILABLE = True + BABEL_AVAILABLE = True except ImportError: - BABEL_AVAILABLE = False + BABEL_AVAILABLE = False -logger = logging.getLogger(__name__) +logger = logging.getLogger(__name__) class LanguageCode(StrEnum): """支持的语言代码""" - EN = "en" - ZH_CN = "zh_CN" - ZH_TW = "zh_TW" - JA = "ja" - KO = "ko" - DE = "de" - FR = "fr" - ES = "es" - PT = "pt" - RU = "ru" - AR = "ar" - HI = "hi" + EN = "en" + ZH_CN = "zh_CN" + ZH_TW = "zh_TW" + JA = "ja" + KO = "ko" + DE = "de" + FR = "fr" + ES = "es" + PT = "pt" + RU = "ru" + AR = "ar" + HI = "hi" class RegionCode(StrEnum): """区域代码""" - GLOBAL = "global" - NORTH_AMERICA = "na" - EUROPE = "eu" - ASIA_PACIFIC = "apac" - CHINA = "cn" - LATIN_AMERICA = "latam" - MIDDLE_EAST = "me" + GLOBAL = "global" + NORTH_AMERICA = "na" + EUROPE = "eu" + ASIA_PACIFIC = "apac" + CHINA = "cn" + LATIN_AMERICA = "latam" + MIDDLE_EAST = "me" class DataCenterRegion(StrEnum): """数据中心区域""" - US_EAST = "us-east" - US_WEST = "us-west" - EU_WEST = "eu-west" - EU_CENTRAL = "eu-central" - AP_SOUTHEAST = "ap-southeast" - AP_NORTHEAST = "ap-northeast" - AP_SOUTH = "ap-south" - CN_NORTH = "cn-north" - CN_EAST = "cn-east" + US_EAST = "us-east" + US_WEST = "us-west" + EU_WEST = "eu-west" + EU_CENTRAL = "eu-central" + AP_SOUTHEAST = "ap-southeast" + AP_NORTHEAST = "ap-northeast" + AP_SOUTH = "ap-south" + CN_NORTH = "cn-north" + CN_EAST = "cn-east" class PaymentProvider(StrEnum): """支付提供商""" - STRIPE = "stripe" - ALIPAY = "alipay" - WECHAT_PAY = "wechat_pay" - PAYPAL = "paypal" - APPLE_PAY = "apple_pay" - GOOGLE_PAY = "google_pay" - KLARNA = "klarna" - IDEAL = "ideal" - BANCONTACT = "bancontact" - GIROPAY = "giropay" - SEPA = "sepa" - UNIONPAY = "unionpay" + STRIPE = "stripe" + ALIPAY = "alipay" + WECHAT_PAY = "wechat_pay" + PAYPAL = "paypal" + APPLE_PAY = "apple_pay" + GOOGLE_PAY = "google_pay" + KLARNA = "klarna" + IDEAL = "ideal" + BANCONTACT = "bancontact" + GIROPAY = "giropay" + SEPA = "sepa" + UNIONPAY = "unionpay" class CalendarType(StrEnum): """日历类型""" - GREGORIAN = "gregorian" - CHINESE_LUNAR = "chinese_lunar" - ISLAMIC = "islamic" - HEBREW = "hebrew" - INDIAN = "indian" - PERSIAN = "persian" - BUDDHIST = "buddhist" + GREGORIAN = "gregorian" + CHINESE_LUNAR = "chinese_lunar" + ISLAMIC = "islamic" + HEBREW = "hebrew" + INDIAN = "indian" + PERSIAN = "persian" + BUDDHIST = "buddhist" @dataclass @@ -252,7 +252,7 @@ class LocalizationSettings: class LocalizationManager: - DEFAULT_LANGUAGES = { + DEFAULT_LANGUAGES = { LanguageCode.EN: { "name": "English", "name_local": "English", @@ -399,7 +399,7 @@ class LocalizationManager: }, } - DEFAULT_DATA_CENTERS = { + DEFAULT_DATA_CENTERS = { DataCenterRegion.US_EAST: { "name": "US East (Virginia)", "location": "Virginia, USA", @@ -474,7 +474,7 @@ class LocalizationManager: }, } - DEFAULT_PAYMENT_METHODS = { + DEFAULT_PAYMENT_METHODS = { PaymentProvider.STRIPE: { "name": "Credit Card", "name_local": { @@ -572,7 +572,7 @@ class LocalizationManager: }, } - DEFAULT_COUNTRIES = { + DEFAULT_COUNTRIES = { "US": { "name": "United States", "name_local": {"en": "United States"}, @@ -719,21 +719,21 @@ class LocalizationManager: }, } - def __init__(self, db_path: str = "insightflow.db") -> None: - self.db_path = db_path - self._is_memory_db = db_path == ":memory:" - self._conn = None + def __init__(self, db_path: str = "insightflow.db") -> None: + self.db_path = db_path + self._is_memory_db = db_path == ":memory:" + self._conn = None self._init_db() self._init_default_data() def _get_connection(self) -> sqlite3.Connection: if self._is_memory_db: if self._conn is None: - self._conn = sqlite3.connect(self.db_path) - self._conn.row_factory = sqlite3.Row + self._conn = sqlite3.connect(self.db_path) + self._conn.row_factory = sqlite3.Row return self._conn - conn = sqlite3.connect(self.db_path) - conn.row_factory = sqlite3.Row + conn = sqlite3.connect(self.db_path) + conn.row_factory = sqlite3.Row return conn def _close_if_file_db(self, conn) -> None: @@ -741,9 +741,9 @@ class LocalizationManager: conn.close() def _init_db(self) -> None: - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute(""" CREATE TABLE IF NOT EXISTS translations ( id TEXT PRIMARY KEY, key TEXT NOT NULL, language TEXT NOT NULL, @@ -864,9 +864,9 @@ class LocalizationManager: self._close_if_file_db(conn) def _init_default_data(self) -> None: - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() for code, config in self.DEFAULT_LANGUAGES.items(): cursor.execute( """ @@ -894,7 +894,7 @@ class LocalizationManager: ), ) for region_code, config in self.DEFAULT_DATA_CENTERS.items(): - dc_id = str(uuid.uuid4()) + dc_id = str(uuid.uuid4()) cursor.execute( """ INSERT OR IGNORE INTO data_centers @@ -913,7 +913,7 @@ class LocalizationManager: ), ) for provider, config in self.DEFAULT_PAYMENT_METHODS.items(): - pm_id = str(uuid.uuid4()) + pm_id = str(uuid.uuid4()) cursor.execute( """ INSERT OR IGNORE INTO localized_payment_methods @@ -963,20 +963,20 @@ class LocalizationManager: self._close_if_file_db(conn) def get_translation( - self, key: str, language: str, namespace: str = "common", fallback: bool = True + self, key: str, language: str, namespace: str = "common", fallback: bool = True ) -> str | None: - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( "SELECT value FROM translations WHERE key = ? AND language = ? AND namespace = ?", (key, language, namespace), ) - row = cursor.fetchone() + row = cursor.fetchone() if row: return row["value"] if fallback: - lang_config = self.get_language_config(language) + lang_config = self.get_language_config(language) if lang_config and lang_config.fallback_language: return self.get_translation( key, lang_config.fallback_language, namespace, False @@ -992,14 +992,14 @@ class LocalizationManager: key: str, language: str, value: str, - namespace: str = "common", - context: str | None = None, + namespace: str = "common", + context: str | None = None, ) -> Translation: - conn = self._get_connection() + conn = self._get_connection() try: - translation_id = str(uuid.uuid4()) - now = datetime.now() - cursor = conn.cursor() + translation_id = str(uuid.uuid4()) + now = datetime.now() + cursor = conn.cursor() cursor.execute( """ INSERT INTO translations @@ -1021,20 +1021,20 @@ class LocalizationManager: def _get_translation_internal( self, conn: sqlite3.Connection, key: str, language: str, namespace: str ) -> Translation | None: - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( "SELECT * FROM translations WHERE key = ? AND language = ? AND namespace = ?", (key, language, namespace), ) - row = cursor.fetchone() + row = cursor.fetchone() if row: return self._row_to_translation(row) return None - def delete_translation(self, key: str, language: str, namespace: str = "common") -> bool: - conn = self._get_connection() + def delete_translation(self, key: str, language: str, namespace: str = "common") -> bool: + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( "DELETE FROM translations WHERE key = ? AND language = ? AND namespace = ?", (key, language, namespace), @@ -1046,16 +1046,16 @@ class LocalizationManager: def list_translations( self, - language: str | None = None, - namespace: str | None = None, - limit: int = 1000, - offset: int = 0, + language: str | None = None, + namespace: str | None = None, + limit: int = 1000, + offset: int = 0, ) -> list[Translation]: - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() - query = "SELECT * FROM translations WHERE 1 = 1" - params = [] + cursor = conn.cursor() + query = "SELECT * FROM translations WHERE 1 = 1" + params = [] if language: query += " AND language = ?" params.append(language) @@ -1065,43 +1065,43 @@ class LocalizationManager: query += " ORDER BY namespace, key LIMIT ? OFFSET ?" params.extend([limit, offset]) cursor.execute(query, params) - rows = cursor.fetchall() + rows = cursor.fetchall() return [self._row_to_translation(row) for row in rows] finally: self._close_if_file_db(conn) def get_language_config(self, code: str) -> LanguageConfig | None: - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute("SELECT * FROM language_configs WHERE code = ?", (code, )) - row = cursor.fetchone() + row = cursor.fetchone() if row: return self._row_to_language_config(row) return None finally: self._close_if_file_db(conn) - def list_language_configs(self, active_only: bool = True) -> list[LanguageConfig]: - conn = self._get_connection() + def list_language_configs(self, active_only: bool = True) -> list[LanguageConfig]: + conn = self._get_connection() try: - cursor = conn.cursor() - query = "SELECT * FROM language_configs" + cursor = conn.cursor() + query = "SELECT * FROM language_configs" if active_only: query += " WHERE is_active = 1" query += " ORDER BY name" cursor.execute(query) - rows = cursor.fetchall() + rows = cursor.fetchall() return [self._row_to_language_config(row) for row in rows] finally: self._close_if_file_db(conn) def get_data_center(self, dc_id: str) -> DataCenter | None: - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute("SELECT * FROM data_centers WHERE id = ?", (dc_id, )) - row = cursor.fetchone() + row = cursor.fetchone() if row: return self._row_to_data_center(row) return None @@ -1109,11 +1109,11 @@ class LocalizationManager: self._close_if_file_db(conn) def get_data_center_by_region(self, region_code: str) -> DataCenter | None: - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute("SELECT * FROM data_centers WHERE region_code = ?", (region_code, )) - row = cursor.fetchone() + row = cursor.fetchone() if row: return self._row_to_data_center(row) return None @@ -1121,13 +1121,13 @@ class LocalizationManager: self._close_if_file_db(conn) def list_data_centers( - self, status: str | None = None, region: str | None = None + self, status: str | None = None, region: str | None = None ) -> list[DataCenter]: - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() - query = "SELECT * FROM data_centers WHERE 1 = 1" - params = [] + cursor = conn.cursor() + query = "SELECT * FROM data_centers WHERE 1 = 1" + params = [] if status: query += " AND status = ?" params.append(status) @@ -1136,19 +1136,19 @@ class LocalizationManager: params.append(f'%"{region}"%') query += " ORDER BY priority" cursor.execute(query, params) - rows = cursor.fetchall() + rows = cursor.fetchall() return [self._row_to_data_center(row) for row in rows] finally: self._close_if_file_db(conn) def get_tenant_data_center(self, tenant_id: str) -> TenantDataCenterMapping | None: - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( "SELECT * FROM tenant_data_center_mappings WHERE tenant_id = ?", (tenant_id, ) ) - row = cursor.fetchone() + row = cursor.fetchone() if row: return self._row_to_tenant_dc_mapping(row) return None @@ -1156,11 +1156,11 @@ class LocalizationManager: self._close_if_file_db(conn) def set_tenant_data_center( - self, tenant_id: str, region_code: str, data_residency: str = "regional" + self, tenant_id: str, region_code: str, data_residency: str = "regional" ) -> TenantDataCenterMapping: - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ SELECT * FROM data_centers WHERE supported_regions LIKE ? AND status = 'active' @@ -1168,26 +1168,26 @@ class LocalizationManager: """, (f'%"{region_code}"%', ), ) - row = cursor.fetchone() + row = cursor.fetchone() if not row: cursor.execute(""" SELECT * FROM data_centers WHERE supported_regions LIKE '%"global"%' AND status = 'active' ORDER BY priority LIMIT 1 """) - row = cursor.fetchone() + row = cursor.fetchone() if not row: raise ValueError(f"No data center available for region: {region_code}") - primary_dc_id = row["id"] + primary_dc_id = row["id"] cursor.execute( """ SELECT * FROM data_centers WHERE id != ? AND status = 'active' ORDER BY priority LIMIT 1 """, (primary_dc_id, ), ) - secondary_row = cursor.fetchone() - secondary_dc_id = secondary_row["id"] if secondary_row else None - mapping_id = str(uuid.uuid4()) - now = datetime.now() + secondary_row = cursor.fetchone() + secondary_dc_id = secondary_row["id"] if secondary_row else None + mapping_id = str(uuid.uuid4()) + now = datetime.now() cursor.execute( """ INSERT INTO tenant_data_center_mappings @@ -1218,13 +1218,13 @@ class LocalizationManager: self._close_if_file_db(conn) def get_payment_method(self, provider: str) -> LocalizedPaymentMethod | None: - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( "SELECT * FROM localized_payment_methods WHERE provider = ?", (provider, ) ) - row = cursor.fetchone() + row = cursor.fetchone() if row: return self._row_to_payment_method(row) return None @@ -1232,13 +1232,13 @@ class LocalizationManager: self._close_if_file_db(conn) def list_payment_methods( - self, country_code: str | None = None, currency: str | None = None, active_only: bool = True + self, country_code: str | None = None, currency: str | None = None, active_only: bool = True ) -> list[LocalizedPaymentMethod]: - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() - query = "SELECT * FROM localized_payment_methods WHERE 1 = 1" - params = [] + cursor = conn.cursor() + query = "SELECT * FROM localized_payment_methods WHERE 1 = 1" + params = [] if active_only: query += " AND is_active = 1" if country_code: @@ -1249,18 +1249,18 @@ class LocalizationManager: params.append(f'%"{currency}"%') query += " ORDER BY display_order" cursor.execute(query, params) - rows = cursor.fetchall() + rows = cursor.fetchall() return [self._row_to_payment_method(row) for row in rows] finally: self._close_if_file_db(conn) def get_localized_payment_methods( - self, country_code: str, language: str = "en" + self, country_code: str, language: str = "en" ) -> list[dict[str, Any]]: - methods = self.list_payment_methods(country_code = country_code) - result = [] + methods = self.list_payment_methods(country_code=country_code) + result = [] for method in methods: - name_local = method.name_local.get(language, method.name) + name_local = method.name_local.get(language, method.name) result.append( { "id": method.id, @@ -1275,11 +1275,11 @@ class LocalizationManager: return result def get_country_config(self, code: str) -> CountryConfig | None: - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute("SELECT * FROM country_configs WHERE code = ?", (code, )) - row = cursor.fetchone() + row = cursor.fetchone() if row: return self._row_to_country_config(row) return None @@ -1287,13 +1287,13 @@ class LocalizationManager: self._close_if_file_db(conn) def list_country_configs( - self, region: str | None = None, active_only: bool = True + self, region: str | None = None, active_only: bool = True ) -> list[CountryConfig]: - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() - query = "SELECT * FROM country_configs WHERE 1 = 1" - params = [] + cursor = conn.cursor() + query = "SELECT * FROM country_configs WHERE 1 = 1" + params = [] if active_only: query += " AND is_active = 1" if region: @@ -1301,7 +1301,7 @@ class LocalizationManager: params.append(region) query += " ORDER BY name" cursor.execute(query, params) - rows = cursor.fetchall() + rows = cursor.fetchall() return [self._row_to_country_config(row) for row in rows] finally: self._close_if_file_db(conn) @@ -1309,34 +1309,34 @@ class LocalizationManager: def format_datetime( self, dt: datetime, - language: str = "en", - timezone: str | None = None, - format_type: str = "datetime", + language: str = "en", + timezone: str | None = None, + format_type: str = "datetime", ) -> str: try: if timezone and PYTZ_AVAILABLE: - tz = pytz.timezone(timezone) + tz = pytz.timezone(timezone) if dt.tzinfo is None: - dt = pytz.UTC.localize(dt) - dt = dt.astimezone(tz) - lang_config = self.get_language_config(language) + dt = pytz.UTC.localize(dt) + dt = dt.astimezone(tz) + lang_config = self.get_language_config(language) if not lang_config: - lang_config = self.get_language_config("en") + lang_config = self.get_language_config("en") if format_type == "date": - fmt = lang_config.date_format if lang_config else "%Y-%m-%d" + fmt = lang_config.date_format if lang_config else "%Y-%m-%d" elif format_type == "time": - fmt = lang_config.time_format if lang_config else "%H:%M" + fmt = lang_config.time_format if lang_config else "%H:%M" else: - fmt = lang_config.datetime_format if lang_config else "%Y-%m-%d %H:%M" + fmt = lang_config.datetime_format if lang_config else "%Y-%m-%d %H:%M" if BABEL_AVAILABLE: try: - locale = Locale.parse(language.replace("_", "-")) + locale = Locale.parse(language.replace("_", "-")) if format_type == "date": - return dates.format_date(dt, locale = locale) + return dates.format_date(dt, locale=locale) elif format_type == "time": - return dates.format_time(dt, locale = locale) + return dates.format_time(dt, locale=locale) else: - return dates.format_datetime(dt, locale = locale) + return dates.format_datetime(dt, locale=locale) except (ValueError, AttributeError): pass return dt.strftime(fmt) @@ -1345,14 +1345,14 @@ class LocalizationManager: return dt.strftime("%Y-%m-%d %H:%M") def format_number( - self, number: float, language: str = "en", decimal_places: int | None = None + self, number: float, language: str = "en", decimal_places: int | None = None ) -> str: try: if BABEL_AVAILABLE: try: - locale = Locale.parse(language.replace("_", "-")) + locale = Locale.parse(language.replace("_", "-")) return numbers.format_decimal( - number, locale = locale, decimal_quantization = (decimal_places is not None) + number, locale=locale, decimal_quantization=(decimal_places is not None) ) except (ValueError, AttributeError): pass @@ -1363,12 +1363,12 @@ class LocalizationManager: logger.error(f"Error formatting number: {e}") return str(number) - def format_currency(self, amount: float, currency: str, language: str = "en") -> str: + def format_currency(self, amount: float, currency: str, language: str = "en") -> str: try: if BABEL_AVAILABLE: try: - locale = Locale.parse(language.replace("_", "-")) - return numbers.format_currency(amount, currency, locale = locale) + locale = Locale.parse(language.replace("_", "-")) + return numbers.format_currency(amount, currency, locale=locale) except (ValueError, AttributeError): pass return f"{currency} {amount:, .2f}" @@ -1379,10 +1379,10 @@ class LocalizationManager: def convert_timezone(self, dt: datetime, from_tz: str, to_tz: str) -> datetime: try: if PYTZ_AVAILABLE: - from_zone = pytz.timezone(from_tz) - to_zone = pytz.timezone(to_tz) + from_zone = pytz.timezone(from_tz) + to_zone = pytz.timezone(to_tz) if dt.tzinfo is None: - dt = from_zone.localize(dt) + dt = from_zone.localize(dt) return dt.astimezone(to_zone) return dt except Exception as e: @@ -1392,8 +1392,8 @@ class LocalizationManager: def get_calendar_info(self, calendar_type: str, year: int, month: int) -> dict[str, Any]: import calendar - cal = calendar.Calendar() - month_days = cal.monthdayscalendar(year, month) + cal = calendar.Calendar() + month_days = cal.monthdayscalendar(year, month) return { "calendar_type": calendar_type, "year": year, @@ -1405,11 +1405,11 @@ class LocalizationManager: } def get_localization_settings(self, tenant_id: str) -> LocalizationSettings | None: - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute("SELECT * FROM localization_settings WHERE tenant_id = ?", (tenant_id, )) - row = cursor.fetchone() + row = cursor.fetchone() if row: return self._row_to_localization_settings(row) return None @@ -1419,22 +1419,22 @@ class LocalizationManager: def create_localization_settings( self, tenant_id: str, - default_language: str = "en", - supported_languages: list[str] | None = None, - default_currency: str = "USD", - supported_currencies: list[str] | None = None, - default_timezone: str = "UTC", - region_code: str = "global", - data_residency: str = "regional", + default_language: str = "en", + supported_languages: list[str] | None = None, + default_currency: str = "USD", + supported_currencies: list[str] | None = None, + default_timezone: str = "UTC", + region_code: str = "global", + data_residency: str = "regional", ) -> LocalizationSettings: - conn = self._get_connection() + conn = self._get_connection() try: - settings_id = str(uuid.uuid4()) - now = datetime.now() - supported_languages = supported_languages or [default_language] - supported_currencies = supported_currencies or [default_currency] - lang_config = self.get_language_config(default_language) - cursor = conn.cursor() + settings_id = str(uuid.uuid4()) + now = datetime.now() + supported_languages = supported_languages or [default_language] + supported_currencies = supported_currencies or [default_currency] + lang_config = self.get_language_config(default_language) + cursor = conn.cursor() cursor.execute( """ INSERT INTO localization_settings @@ -1468,14 +1468,14 @@ class LocalizationManager: self._close_if_file_db(conn) def update_localization_settings(self, tenant_id: str, **kwargs) -> LocalizationSettings | None: - conn = self._get_connection() + conn = self._get_connection() try: - settings = self.get_localization_settings(tenant_id) + settings = self.get_localization_settings(tenant_id) if not settings: return None - updates = [] - params = [] - allowed_fields = [ + updates = [] + params = [] + allowed_fields = [ "default_language", "supported_languages", "default_currency", @@ -1503,7 +1503,7 @@ class LocalizationManager: updates.append("updated_at = ?") params.append(datetime.now()) params.append(tenant_id) - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( f"UPDATE localization_settings SET {', '.join(updates)} WHERE tenant_id = ?", params ) @@ -1513,48 +1513,48 @@ class LocalizationManager: self._close_if_file_db(conn) def detect_user_preferences( - self, accept_language: str | None = None, ip_country: str | None = None + self, accept_language: str | None = None, ip_country: str | None = None ) -> dict[str, str]: - preferences = {"language": "en", "country": "US", "timezone": "UTC", "currency": "USD"} + preferences = {"language": "en", "country": "US", "timezone": "UTC", "currency": "USD"} if accept_language: - langs = accept_language.split(", ") + langs = accept_language.split(", ") for lang in langs: - lang_code = lang.split(";")[0].strip().replace("-", "_") - lang_config = self.get_language_config(lang_code) + lang_code = lang.split(";")[0].strip().replace("-", "_") + lang_config = self.get_language_config(lang_code) if lang_config and lang_config.is_active: - preferences["language"] = lang_code + preferences["language"] = lang_code break if ip_country: - country = self.get_country_config(ip_country) + country = self.get_country_config(ip_country) if country: - preferences["country"] = ip_country - preferences["currency"] = country.default_currency - preferences["timezone"] = country.timezone + preferences["country"] = ip_country + preferences["currency"] = country.default_currency + preferences["timezone"] = country.timezone if country.default_language not in preferences["language"]: - preferences["language"] = country.default_language + preferences["language"] = country.default_language return preferences def _row_to_translation(self, row: sqlite3.Row) -> Translation: return Translation( - id = row["id"], - key = row["key"], - language = row["language"], - value = row["value"], - namespace = row["namespace"], - context = row["context"], - created_at = ( + id=row["id"], + key=row["key"], + language=row["language"], + value=row["value"], + namespace=row["namespace"], + context=row["context"], + created_at=( datetime.fromisoformat(row["created_at"]) if isinstance(row["created_at"], str) else row["created_at"] ), - updated_at = ( + updated_at=( datetime.fromisoformat(row["updated_at"]) if isinstance(row["updated_at"], str) else row["updated_at"] ), - is_reviewed = bool(row["is_reviewed"]), - reviewed_by = row["reviewed_by"], - reviewed_at = ( + is_reviewed=bool(row["is_reviewed"]), + reviewed_by=row["reviewed_by"], + reviewed_at=( datetime.fromisoformat(row["reviewed_at"]) if row["reviewed_at"] and isinstance(row["reviewed_at"], str) else row["reviewed_at"] @@ -1563,39 +1563,39 @@ class LocalizationManager: def _row_to_language_config(self, row: sqlite3.Row) -> LanguageConfig: return LanguageConfig( - code = row["code"], - name = row["name"], - name_local = row["name_local"], - is_rtl = bool(row["is_rtl"]), - is_active = bool(row["is_active"]), - is_default = bool(row["is_default"]), - fallback_language = row["fallback_language"], - date_format = row["date_format"], - time_format = row["time_format"], - datetime_format = row["datetime_format"], - number_format = row["number_format"], - currency_format = row["currency_format"], - first_day_of_week = row["first_day_of_week"], - calendar_type = row["calendar_type"], + code=row["code"], + name=row["name"], + name_local=row["name_local"], + is_rtl=bool(row["is_rtl"]), + is_active=bool(row["is_active"]), + is_default=bool(row["is_default"]), + fallback_language=row["fallback_language"], + date_format=row["date_format"], + time_format=row["time_format"], + datetime_format=row["datetime_format"], + number_format=row["number_format"], + currency_format=row["currency_format"], + first_day_of_week=row["first_day_of_week"], + calendar_type=row["calendar_type"], ) def _row_to_data_center(self, row: sqlite3.Row) -> DataCenter: return DataCenter( - id = row["id"], - region_code = row["region_code"], - name = row["name"], - location = row["location"], - endpoint = row["endpoint"], - status = row["status"], - priority = row["priority"], - supported_regions = json.loads(row["supported_regions"] or "[]"), - capabilities = json.loads(row["capabilities"] or "{}"), - created_at = ( + id=row["id"], + region_code=row["region_code"], + name=row["name"], + location=row["location"], + endpoint=row["endpoint"], + status=row["status"], + priority=row["priority"], + supported_regions=json.loads(row["supported_regions"] or "[]"), + capabilities=json.loads(row["capabilities"] or "{}"), + created_at=( datetime.fromisoformat(row["created_at"]) if isinstance(row["created_at"], str) else row["created_at"] ), - updated_at = ( + updated_at=( datetime.fromisoformat(row["updated_at"]) if isinstance(row["updated_at"], str) else row["updated_at"] @@ -1604,18 +1604,18 @@ class LocalizationManager: def _row_to_tenant_dc_mapping(self, row: sqlite3.Row) -> TenantDataCenterMapping: return TenantDataCenterMapping( - id = row["id"], - tenant_id = row["tenant_id"], - primary_dc_id = row["primary_dc_id"], - secondary_dc_id = row["secondary_dc_id"], - region_code = row["region_code"], - data_residency = row["data_residency"], - created_at = ( + id=row["id"], + tenant_id=row["tenant_id"], + primary_dc_id=row["primary_dc_id"], + secondary_dc_id=row["secondary_dc_id"], + region_code=row["region_code"], + data_residency=row["data_residency"], + created_at=( datetime.fromisoformat(row["created_at"]) if isinstance(row["created_at"], str) else row["created_at"] ), - updated_at = ( + updated_at=( datetime.fromisoformat(row["updated_at"]) if isinstance(row["updated_at"], str) else row["updated_at"] @@ -1624,24 +1624,24 @@ class LocalizationManager: def _row_to_payment_method(self, row: sqlite3.Row) -> LocalizedPaymentMethod: return LocalizedPaymentMethod( - id = row["id"], - provider = row["provider"], - name = row["name"], - name_local = json.loads(row["name_local"] or "{}"), - supported_countries = json.loads(row["supported_countries"] or "[]"), - supported_currencies = json.loads(row["supported_currencies"] or "[]"), - is_active = bool(row["is_active"]), - config = json.loads(row["config"] or "{}"), - icon_url = row["icon_url"], - display_order = row["display_order"], - min_amount = row["min_amount"], - max_amount = row["max_amount"], - created_at = ( + id=row["id"], + provider=row["provider"], + name=row["name"], + name_local=json.loads(row["name_local"] or "{}"), + supported_countries=json.loads(row["supported_countries"] or "[]"), + supported_currencies=json.loads(row["supported_currencies"] or "[]"), + is_active=bool(row["is_active"]), + config=json.loads(row["config"] or "{}"), + icon_url=row["icon_url"], + display_order=row["display_order"], + min_amount=row["min_amount"], + max_amount=row["max_amount"], + created_at=( datetime.fromisoformat(row["created_at"]) if isinstance(row["created_at"], str) else row["created_at"] ), - updated_at = ( + updated_at=( datetime.fromisoformat(row["updated_at"]) if isinstance(row["updated_at"], str) else row["updated_at"] @@ -1650,48 +1650,48 @@ class LocalizationManager: def _row_to_country_config(self, row: sqlite3.Row) -> CountryConfig: return CountryConfig( - code = row["code"], - code3 = row["code3"], - name = row["name"], - name_local = json.loads(row["name_local"] or "{}"), - region = row["region"], - default_language = row["default_language"], - supported_languages = json.loads(row["supported_languages"] or "[]"), - default_currency = row["default_currency"], - supported_currencies = json.loads(row["supported_currencies"] or "[]"), - timezone = row["timezone"], - calendar_type = row["calendar_type"], - date_format = row["date_format"], - time_format = row["time_format"], - number_format = row["number_format"], - address_format = row["address_format"], - phone_format = row["phone_format"], - vat_rate = row["vat_rate"], - is_active = bool(row["is_active"]), + code=row["code"], + code3=row["code3"], + name=row["name"], + name_local=json.loads(row["name_local"] or "{}"), + region=row["region"], + default_language=row["default_language"], + supported_languages=json.loads(row["supported_languages"] or "[]"), + default_currency=row["default_currency"], + supported_currencies=json.loads(row["supported_currencies"] or "[]"), + timezone=row["timezone"], + calendar_type=row["calendar_type"], + date_format=row["date_format"], + time_format=row["time_format"], + number_format=row["number_format"], + address_format=row["address_format"], + phone_format=row["phone_format"], + vat_rate=row["vat_rate"], + is_active=bool(row["is_active"]), ) def _row_to_localization_settings(self, row: sqlite3.Row) -> LocalizationSettings: return LocalizationSettings( - id = row["id"], - tenant_id = row["tenant_id"], - default_language = row["default_language"], - supported_languages = json.loads(row["supported_languages"] or '["en"]'), - default_currency = row["default_currency"], - supported_currencies = json.loads(row["supported_currencies"] or '["USD"]'), - default_timezone = row["default_timezone"], - default_date_format = row["default_date_format"], - default_time_format = row["default_time_format"], - default_number_format = row["default_number_format"], - calendar_type = row["calendar_type"], - first_day_of_week = row["first_day_of_week"], - region_code = row["region_code"], - data_residency = row["data_residency"], - created_at = ( + id=row["id"], + tenant_id=row["tenant_id"], + default_language=row["default_language"], + supported_languages=json.loads(row["supported_languages"] or '["en"]'), + default_currency=row["default_currency"], + supported_currencies=json.loads(row["supported_currencies"] or '["USD"]'), + default_timezone=row["default_timezone"], + default_date_format=row["default_date_format"], + default_time_format=row["default_time_format"], + default_number_format=row["default_number_format"], + calendar_type=row["calendar_type"], + first_day_of_week=row["first_day_of_week"], + region_code=row["region_code"], + data_residency=row["data_residency"], + created_at=( datetime.fromisoformat(row["created_at"]) if isinstance(row["created_at"], str) else row["created_at"] ), - updated_at = ( + updated_at=( datetime.fromisoformat(row["updated_at"]) if isinstance(row["updated_at"], str) else row["updated_at"] @@ -1699,11 +1699,11 @@ class LocalizationManager: ) -_localization_manager = None +_localization_manager = None -def get_localization_manager(db_path: str = "insightflow.db") -> LocalizationManager: +def get_localization_manager(db_path: str = "insightflow.db") -> LocalizationManager: global _localization_manager if _localization_manager is None: - _localization_manager = LocalizationManager(db_path) + _localization_manager = LocalizationManager(db_path) return _localization_manager diff --git a/backend/main.py b/backend/main.py index 71025dc..8138b71 100644 --- a/backend/main.py +++ b/backend/main.py @@ -1166,7 +1166,7 @@ async def create_manual_entity( start_pos=entity.start_pos, end_pos=entity.end_pos, text_snippet=text[ - max(0, entity.start_pos - 20) : min(len(text), entity.end_pos + 20) + max(0, entity.start_pos - 20): min(len(text), entity.end_pos + 20) ], confidence=1.0, ) @@ -1408,7 +1408,7 @@ async def upload_audio(project_id: str, file: UploadFile = File(...), _=Depends( start_pos=pos, end_pos=pos + len(name), text_snippet=full_text[ - max(0, pos - 20) : min(len(full_text), pos + len(name) + 20) + max(0, pos - 20): min(len(full_text), pos + len(name) + 20) ], confidence=1.0, ) @@ -1534,7 +1534,7 @@ async def upload_document(project_id: str, file: UploadFile = File(...), _=Depen start_pos=pos, end_pos=pos + len(name), text_snippet=full_text[ - max(0, pos - 20) : min(len(full_text), pos + len(name) + 20) + max(0, pos - 20): min(len(full_text), pos + len(name) + 20) ], confidence=1.0, ) @@ -3804,10 +3804,10 @@ async def system_status(): # ==================== Phase 7: Workflow Automation Endpoints ==================== # Workflow Manager singleton -_workflow_manager: "WorkflowManager | None" = None +_workflow_manager: Any = None -def get_workflow_manager_instance() -> "WorkflowManager | None": +def get_workflow_manager_instance() -> Any: global _workflow_manager if _workflow_manager is None and WORKFLOW_AVAILABLE and DB_AVAILABLE: from workflow_manager import WorkflowManager diff --git a/backend/multimodal_entity_linker.py b/backend/multimodal_entity_linker.py index 803f566..0dc411e 100644 --- a/backend/multimodal_entity_linker.py +++ b/backend/multimodal_entity_linker.py @@ -9,13 +9,13 @@ from dataclasses import dataclass from difflib import SequenceMatcher # Constants -UUID_LENGTH = 8 # UUID 截断长度 +UUID_LENGTH = 8 # UUID 截断长度 # 尝试导入embedding库 try: - NUMPY_AVAILABLE = True + NUMPY_AVAILABLE = True except ImportError: - NUMPY_AVAILABLE = False + NUMPY_AVAILABLE = False @dataclass @@ -30,11 +30,11 @@ class MultimodalEntity: source_id: str mention_context: str confidence: float - modality_features: dict = None # 模态特定特征 + modality_features: dict = None # 模态特定特征 def __post_init__(self) -> None: if self.modality_features is None: - self.modality_features = {} + self.modality_features = {} @dataclass @@ -78,7 +78,7 @@ class MultimodalEntityLinker: """多模态实体关联器 - 跨模态实体对齐和知识融合""" # 关联类型 - LINK_TYPES = { + LINK_TYPES = { "same_as": "同一实体", "related_to": "相关实体", "part_of": "组成部分", @@ -86,16 +86,16 @@ class MultimodalEntityLinker: } # 模态类型 - MODALITIES = ["audio", "video", "image", "document"] + MODALITIES = ["audio", "video", "image", "document"] - def __init__(self, similarity_threshold: float = 0.85) -> None: + def __init__(self, similarity_threshold: float = 0.85) -> None: """ 初始化多模态实体关联器 Args: similarity_threshold: 相似度阈值 """ - self.similarity_threshold = similarity_threshold + self.similarity_threshold = similarity_threshold def calculate_string_similarity(self, s1: str, s2: str) -> float: """ @@ -111,7 +111,7 @@ class MultimodalEntityLinker: if not s1 or not s2: return 0.0 - s1, s2 = s1.lower().strip(), s2.lower().strip() + s1, s2 = s1.lower().strip(), s2.lower().strip() # 完全匹配 if s1 == s2: @@ -136,7 +136,7 @@ class MultimodalEntityLinker: (相似度, 匹配类型) """ # 名称相似度 - name_sim = self.calculate_string_similarity( + name_sim = self.calculate_string_similarity( entity1.get("name", ""), entity2.get("name", "") ) @@ -145,8 +145,8 @@ class MultimodalEntityLinker: return 1.0, "exact" # 检查别名 - aliases1 = set(a.lower() for a in entity1.get("aliases", [])) - aliases2 = set(a.lower() for a in entity2.get("aliases", [])) + aliases1 = set(a.lower() for a in entity1.get("aliases", [])) + aliases2 = set(a.lower() for a in entity2.get("aliases", [])) if aliases1 & aliases2: # 有共同别名 return 0.95, "alias_match" @@ -157,12 +157,12 @@ class MultimodalEntityLinker: return 0.95, "alias_match" # 定义相似度 - def_sim = self.calculate_string_similarity( + def_sim = self.calculate_string_similarity( entity1.get("definition", ""), entity2.get("definition", "") ) # 综合相似度 - combined_sim = name_sim * 0.7 + def_sim * 0.3 + combined_sim = name_sim * 0.7 + def_sim * 0.3 if combined_sim >= self.similarity_threshold: return combined_sim, "fuzzy" @@ -170,7 +170,7 @@ class MultimodalEntityLinker: return combined_sim, "none" def find_matching_entity( - self, query_entity: dict, candidate_entities: list[dict], exclude_ids: set[str] = None + self, query_entity: dict, candidate_entities: list[dict], exclude_ids: set[str] = None ) -> AlignmentResult | None: """ 在候选实体中查找匹配的实体 @@ -183,28 +183,28 @@ class MultimodalEntityLinker: Returns: 对齐结果 """ - exclude_ids = exclude_ids or set() - best_match = None - best_similarity = 0.0 + exclude_ids = exclude_ids or set() + best_match = None + best_similarity = 0.0 for candidate in candidate_entities: if candidate.get("id") in exclude_ids: continue - similarity, match_type = self.calculate_entity_similarity(query_entity, candidate) + similarity, match_type = self.calculate_entity_similarity(query_entity, candidate) if similarity > best_similarity and similarity >= self.similarity_threshold: - best_similarity = similarity - best_match = candidate - best_match_type = match_type + best_similarity = similarity + best_match = candidate + best_match_type = match_type if best_match: return AlignmentResult( - entity_id = query_entity.get("id"), - matched_entity_id = best_match.get("id"), - similarity = best_similarity, - match_type = best_match_type, - confidence = best_similarity, + entity_id=query_entity.get("id"), + matched_entity_id=best_match.get("id"), + similarity=best_similarity, + match_type=best_match_type, + confidence=best_similarity, ) return None @@ -230,10 +230,10 @@ class MultimodalEntityLinker: Returns: 实体关联列表 """ - links = [] + links = [] # 合并所有实体 - all_entities = { + all_entities = { "audio": audio_entities, "video": video_entities, "image": image_entities, @@ -246,24 +246,24 @@ class MultimodalEntityLinker: if mod1 >= mod2: # 避免重复比较 continue - entities1 = all_entities.get(mod1, []) - entities2 = all_entities.get(mod2, []) + entities1 = all_entities.get(mod1, []) + entities2 = all_entities.get(mod2, []) for ent1 in entities1: # 在另一个模态中查找匹配 - result = self.find_matching_entity(ent1, entities2) + result = self.find_matching_entity(ent1, entities2) if result and result.matched_entity_id: - link = EntityLink( - id = str(uuid.uuid4())[:UUID_LENGTH], - project_id = project_id, - source_entity_id = ent1.get("id"), - target_entity_id = result.matched_entity_id, - link_type = "same_as" if result.similarity > 0.95 else "related_to", - source_modality = mod1, - target_modality = mod2, - confidence = result.confidence, - evidence = f"Cross-modal alignment: {result.match_type}", + link = EntityLink( + id=str(uuid.uuid4())[:UUID_LENGTH], + project_id=project_id, + source_entity_id=ent1.get("id"), + target_entity_id=result.matched_entity_id, + link_type="same_as" if result.similarity > 0.95 else "related_to", + source_modality=mod1, + target_modality=mod2, + confidence=result.confidence, + evidence=f"Cross-modal alignment: {result.match_type}", ) links.append(link) @@ -284,7 +284,7 @@ class MultimodalEntityLinker: 融合结果 """ # 收集所有属性 - fused_properties = { + fused_properties = { "names": set(), "definitions": [], "aliases": set(), @@ -293,7 +293,7 @@ class MultimodalEntityLinker: "contexts": [], } - merged_ids = [] + merged_ids = [] for entity in linked_entities: merged_ids.append(entity.get("id")) @@ -318,21 +318,21 @@ class MultimodalEntityLinker: fused_properties["contexts"].append(mention.get("mention_context")) # 选择最佳定义(最长的那个) - best_definition = ( - max(fused_properties["definitions"], key = len) if fused_properties["definitions"] else "" + best_definition = ( + max(fused_properties["definitions"], key=len) if fused_properties["definitions"] else "" ) # 选择最佳名称(最常见的那个) from collections import Counter - name_counts = Counter(fused_properties["names"]) - best_name = name_counts.most_common(1)[0][0] if name_counts else "" + name_counts = Counter(fused_properties["names"]) + best_name = name_counts.most_common(1)[0][0] if name_counts else "" # 构建融合结果 return FusionResult( - canonical_entity_id = entity_id, - merged_entity_ids = merged_ids, - fused_properties = { + canonical_entity_id=entity_id, + merged_entity_ids=merged_ids, + fused_properties={ "name": best_name, "definition": best_definition, "aliases": list(fused_properties["aliases"]), @@ -340,8 +340,8 @@ class MultimodalEntityLinker: "modalities": list(fused_properties["modalities"]), "contexts": fused_properties["contexts"][:10], # 最多10个上下文 }, - source_modalities = list(fused_properties["modalities"]), - confidence = min(1.0, len(linked_entities) * 0.2 + 0.5), + source_modalities=list(fused_properties["modalities"]), + confidence=min(1.0, len(linked_entities) * 0.2 + 0.5), ) def detect_entity_conflicts(self, entities: list[dict]) -> list[dict]: @@ -354,30 +354,30 @@ class MultimodalEntityLinker: Returns: 冲突列表 """ - conflicts = [] + conflicts = [] # 按名称分组 - name_groups = {} + name_groups = {} for entity in entities: - name = entity.get("name", "").lower() + name = entity.get("name", "").lower() if name: if name not in name_groups: - name_groups[name] = [] + name_groups[name] = [] name_groups[name].append(entity) # 检测同名但定义不同的实体 for name, group in name_groups.items(): if len(group) > 1: # 检查定义是否相似 - definitions = [e.get("definition", "") for e in group if e.get("definition")] + definitions = [e.get("definition", "") for e in group if e.get("definition")] if len(definitions) > 1: # 计算定义之间的相似度 - sim_matrix = [] + sim_matrix = [] for i, d1 in enumerate(definitions): for j, d2 in enumerate(definitions): if i < j: - sim = self.calculate_string_similarity(d1, d2) + sim = self.calculate_string_similarity(d1, d2) sim_matrix.append(sim) # 如果定义相似度都很低,可能是冲突 @@ -394,7 +394,7 @@ class MultimodalEntityLinker: return conflicts def suggest_entity_merges( - self, entities: list[dict], existing_links: list[EntityLink] = None + self, entities: list[dict], existing_links: list[EntityLink] = None ) -> list[dict]: """ 建议实体合并 @@ -406,13 +406,13 @@ class MultimodalEntityLinker: Returns: 合并建议列表 """ - suggestions = [] - existing_pairs = set() + suggestions = [] + existing_pairs = set() # 记录已有的关联 if existing_links: for link in existing_links: - pair = tuple(sorted([link.source_entity_id, link.target_entity_id])) + pair = tuple(sorted([link.source_entity_id, link.target_entity_id])) existing_pairs.add(pair) # 检查所有实体对 @@ -422,12 +422,12 @@ class MultimodalEntityLinker: continue # 检查是否已有关联 - pair = tuple(sorted([ent1.get("id"), ent2.get("id")])) + pair = tuple(sorted([ent1.get("id"), ent2.get("id")])) if pair in existing_pairs: continue # 计算相似度 - similarity, match_type = self.calculate_entity_similarity(ent1, ent2) + similarity, match_type = self.calculate_entity_similarity(ent1, ent2) if similarity >= self.similarity_threshold: suggestions.append( @@ -441,7 +441,7 @@ class MultimodalEntityLinker: ) # 按相似度排序 - suggestions.sort(key = lambda x: x["similarity"], reverse = True) + suggestions.sort(key=lambda x: x["similarity"], reverse=True) return suggestions @@ -451,8 +451,8 @@ class MultimodalEntityLinker: entity_id: str, source_type: str, source_id: str, - mention_context: str = "", - confidence: float = 1.0, + mention_context: str = "", + confidence: float = 1.0, ) -> MultimodalEntity: """ 创建多模态实体记录 @@ -469,14 +469,14 @@ class MultimodalEntityLinker: 多模态实体记录 """ return MultimodalEntity( - id = str(uuid.uuid4())[:UUID_LENGTH], - entity_id = entity_id, - project_id = project_id, - name = "", # 将在后续填充 - source_type = source_type, - source_id = source_id, - mention_context = mention_context, - confidence = confidence, + id=str(uuid.uuid4())[:UUID_LENGTH], + entity_id=entity_id, + project_id=project_id, + name="", # 将在后续填充 + source_type=source_type, + source_id=source_id, + mention_context=mention_context, + confidence=confidence, ) def analyze_modality_distribution(self, multimodal_entities: list[MultimodalEntity]) -> dict: @@ -489,7 +489,7 @@ class MultimodalEntityLinker: Returns: 模态分布统计 """ - distribution = {mod: 0 for mod in self.MODALITIES} + distribution = {mod: 0 for mod in self.MODALITIES} # 统计每个模态的实体数 for me in multimodal_entities: @@ -497,13 +497,13 @@ class MultimodalEntityLinker: distribution[me.source_type] += 1 # 统计跨模态实体 - entity_modalities = {} + entity_modalities = {} for me in multimodal_entities: if me.entity_id not in entity_modalities: - entity_modalities[me.entity_id] = set() + entity_modalities[me.entity_id] = set() entity_modalities[me.entity_id].add(me.source_type) - cross_modal_count = sum(1 for mods in entity_modalities.values() if len(mods) > 1) + cross_modal_count = sum(1 for mods in entity_modalities.values() if len(mods) > 1) return { "modality_distribution": distribution, @@ -517,12 +517,12 @@ class MultimodalEntityLinker: # Singleton instance -_multimodal_entity_linker = None +_multimodal_entity_linker = None -def get_multimodal_entity_linker(similarity_threshold: float = 0.85) -> MultimodalEntityLinker: +def get_multimodal_entity_linker(similarity_threshold: float = 0.85) -> MultimodalEntityLinker: """获取多模态实体关联器单例""" global _multimodal_entity_linker if _multimodal_entity_linker is None: - _multimodal_entity_linker = MultimodalEntityLinker(similarity_threshold) + _multimodal_entity_linker = MultimodalEntityLinker(similarity_threshold) return _multimodal_entity_linker diff --git a/backend/multimodal_processor.py b/backend/multimodal_processor.py index b13b4d9..9b564ab 100644 --- a/backend/multimodal_processor.py +++ b/backend/multimodal_processor.py @@ -13,30 +13,30 @@ from dataclasses import dataclass from pathlib import Path # Constants -UUID_LENGTH = 8 # UUID 截断长度 +UUID_LENGTH = 8 # UUID 截断长度 # 尝试导入OCR库 try: import pytesseract from PIL import Image - PYTESSERACT_AVAILABLE = True + PYTESSERACT_AVAILABLE = True except ImportError: - PYTESSERACT_AVAILABLE = False + PYTESSERACT_AVAILABLE = False try: import cv2 - CV2_AVAILABLE = True + CV2_AVAILABLE = True except ImportError: - CV2_AVAILABLE = False + CV2_AVAILABLE = False try: import ffmpeg - FFMPEG_AVAILABLE = True + FFMPEG_AVAILABLE = True except ImportError: - FFMPEG_AVAILABLE = False + FFMPEG_AVAILABLE = False @dataclass @@ -48,13 +48,13 @@ class VideoFrame: frame_number: int timestamp: float frame_path: str - ocr_text: str = "" - ocr_confidence: float = 0.0 - entities_detected: list[dict] = None + ocr_text: str = "" + ocr_confidence: float = 0.0 + entities_detected: list[dict] = None def __post_init__(self) -> None: if self.entities_detected is None: - self.entities_detected = [] + self.entities_detected = [] @dataclass @@ -65,20 +65,20 @@ class VideoInfo: project_id: str filename: str file_path: str - duration: float = 0.0 - width: int = 0 - height: int = 0 - fps: float = 0.0 - audio_extracted: bool = False - audio_path: str = "" - transcript_id: str = "" - status: str = "pending" - error_message: str = "" - metadata: dict = None + duration: float = 0.0 + width: int = 0 + height: int = 0 + fps: float = 0.0 + audio_extracted: bool = False + audio_path: str = "" + transcript_id: str = "" + status: str = "pending" + error_message: str = "" + metadata: dict = None def __post_init__(self) -> None: if self.metadata is None: - self.metadata = {} + self.metadata = {} @dataclass @@ -91,13 +91,13 @@ class VideoProcessingResult: ocr_results: list[dict] full_text: str # 整合的文本(音频转录 + OCR文本) success: bool - error_message: str = "" + error_message: str = "" class MultimodalProcessor: """多模态处理器 - 处理视频文件""" - def __init__(self, temp_dir: str = None, frame_interval: int = 5) -> None: + def __init__(self, temp_dir: str = None, frame_interval: int = 5) -> None: """ 初始化多模态处理器 @@ -105,16 +105,16 @@ class MultimodalProcessor: temp_dir: 临时文件目录 frame_interval: 关键帧提取间隔(秒) """ - self.temp_dir = temp_dir or tempfile.gettempdir() - self.frame_interval = frame_interval - self.video_dir = os.path.join(self.temp_dir, "videos") - self.frames_dir = os.path.join(self.temp_dir, "frames") - self.audio_dir = os.path.join(self.temp_dir, "audio") + self.temp_dir = temp_dir or tempfile.gettempdir() + self.frame_interval = frame_interval + self.video_dir = os.path.join(self.temp_dir, "videos") + self.frames_dir = os.path.join(self.temp_dir, "frames") + self.audio_dir = os.path.join(self.temp_dir, "audio") # 创建目录 - os.makedirs(self.video_dir, exist_ok = True) - os.makedirs(self.frames_dir, exist_ok = True) - os.makedirs(self.audio_dir, exist_ok = True) + os.makedirs(self.video_dir, exist_ok=True) + os.makedirs(self.frames_dir, exist_ok=True) + os.makedirs(self.audio_dir, exist_ok=True) def extract_video_info(self, video_path: str) -> dict: """ @@ -128,11 +128,11 @@ class MultimodalProcessor: """ try: if FFMPEG_AVAILABLE: - probe = ffmpeg.probe(video_path) - video_stream = next( + probe = ffmpeg.probe(video_path) + video_stream = next( (s for s in probe["streams"] if s["codec_type"] == "video"), None ) - audio_stream = next( + audio_stream = next( (s for s in probe["streams"] if s["codec_type"] == "audio"), None ) @@ -147,7 +147,7 @@ class MultimodalProcessor: } else: # 使用 ffprobe 命令行 - cmd = [ + cmd = [ "ffprobe", "-v", "error", @@ -159,9 +159,9 @@ class MultimodalProcessor: "json", video_path, ] - result = subprocess.run(cmd, capture_output = True, text = True) + result = subprocess.run(cmd, capture_output=True, text=True) if result.returncode == 0: - data = json.loads(result.stdout) + data = json.loads(result.stdout) return { "duration": float(data["format"].get("duration", 0)), "width": int(data["streams"][0].get("width", 0)) if data["streams"] else 0, @@ -177,7 +177,7 @@ class MultimodalProcessor: return {"duration": 0, "width": 0, "height": 0, "fps": 0, "has_audio": False, "bitrate": 0} - def extract_audio(self, video_path: str, output_path: str = None) -> str: + def extract_audio(self, video_path: str, output_path: str = None) -> str: """ 从视频中提取音频 @@ -189,20 +189,20 @@ class MultimodalProcessor: 提取的音频文件路径 """ if output_path is None: - video_name = Path(video_path).stem - output_path = os.path.join(self.audio_dir, f"{video_name}.wav") + video_name = Path(video_path).stem + output_path = os.path.join(self.audio_dir, f"{video_name}.wav") try: if FFMPEG_AVAILABLE: ( ffmpeg.input(video_path) - .output(output_path, ac = 1, ar = 16000, vn = None) + .output(output_path, ac=1, ar=16000, vn=None) .overwrite_output() - .run(quiet = True) + .run(quiet=True) ) else: # 使用命令行 ffmpeg - cmd = [ + cmd = [ "ffmpeg", "-i", video_path, @@ -216,14 +216,14 @@ class MultimodalProcessor: "-y", output_path, ] - subprocess.run(cmd, check = True, capture_output = True) + subprocess.run(cmd, check=True, capture_output=True) return output_path except Exception as e: print(f"Error extracting audio: {e}") raise - def extract_keyframes(self, video_path: str, video_id: str, interval: int = None) -> list[str]: + def extract_keyframes(self, video_path: str, video_id: str, interval: int = None) -> list[str]: """ 从视频中提取关键帧 @@ -235,31 +235,31 @@ class MultimodalProcessor: Returns: 提取的帧文件路径列表 """ - interval = interval or self.frame_interval - frame_paths = [] + interval = interval or self.frame_interval + frame_paths = [] # 创建帧存储目录 - video_frames_dir = os.path.join(self.frames_dir, video_id) - os.makedirs(video_frames_dir, exist_ok = True) + video_frames_dir = os.path.join(self.frames_dir, video_id) + os.makedirs(video_frames_dir, exist_ok=True) try: if CV2_AVAILABLE: # 使用 OpenCV 提取帧 - cap = cv2.VideoCapture(video_path) - fps = cap.get(cv2.CAP_PROP_FPS) + cap = cv2.VideoCapture(video_path) + fps = cap.get(cv2.CAP_PROP_FPS) int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) - frame_interval_frames = int(fps * interval) - frame_number = 0 + frame_interval_frames = int(fps * interval) + frame_number = 0 while True: - ret, frame = cap.read() + ret, frame = cap.read() if not ret: break if frame_number % frame_interval_frames == 0: - timestamp = frame_number / fps - frame_path = os.path.join( + timestamp = frame_number / fps + frame_path = os.path.join( video_frames_dir, f"frame_{frame_number:06d}_{timestamp:.2f}.jpg" ) cv2.imwrite(frame_path, frame) @@ -271,9 +271,9 @@ class MultimodalProcessor: else: # 使用 ffmpeg 命令行提取帧 Path(video_path).stem - output_pattern = os.path.join(video_frames_dir, "frame_%06d_%t.jpg") + output_pattern = os.path.join(video_frames_dir, "frame_%06d_%t.jpg") - cmd = [ + cmd = [ "ffmpeg", "-i", video_path, @@ -284,10 +284,10 @@ class MultimodalProcessor: "-y", output_pattern, ] - subprocess.run(cmd, check = True, capture_output = True) + subprocess.run(cmd, check=True, capture_output=True) # 获取生成的帧文件列表 - frame_paths = sorted( + frame_paths = sorted( [ os.path.join(video_frames_dir, f) for f in os.listdir(video_frames_dir) @@ -313,19 +313,19 @@ class MultimodalProcessor: return "", 0.0 try: - image = Image.open(image_path) + image = Image.open(image_path) # 预处理:转换为灰度图 if image.mode != "L": - image = image.convert("L") + image = image.convert("L") # 使用 pytesseract 进行 OCR - text = pytesseract.image_to_string(image, lang = "chi_sim+eng") + text = pytesseract.image_to_string(image, lang="chi_sim+eng") # 获取置信度数据 - data = pytesseract.image_to_data(image, output_type = pytesseract.Output.DICT) - confidences = [int(c) for c in data["conf"] if int(c) > 0] - avg_confidence = sum(confidences) / len(confidences) if confidences else 0 + data = pytesseract.image_to_data(image, output_type=pytesseract.Output.DICT) + confidences = [int(c) for c in data["conf"] if int(c) > 0] + avg_confidence = sum(confidences) / len(confidences) if confidences else 0 return text.strip(), avg_confidence / 100.0 except Exception as e: @@ -333,7 +333,7 @@ class MultimodalProcessor: return "", 0.0 def process_video( - self, video_data: bytes, filename: str, project_id: str, video_id: str = None + self, video_data: bytes, filename: str, project_id: str, video_id: str = None ) -> VideoProcessingResult: """ 处理视频文件:提取音频、关键帧、OCR @@ -347,48 +347,48 @@ class MultimodalProcessor: Returns: 视频处理结果 """ - video_id = video_id or str(uuid.uuid4())[:UUID_LENGTH] + video_id = video_id or str(uuid.uuid4())[:UUID_LENGTH] try: # 保存视频文件 - video_path = os.path.join(self.video_dir, f"{video_id}_{filename}") + video_path = os.path.join(self.video_dir, f"{video_id}_{filename}") with open(video_path, "wb") as f: f.write(video_data) # 提取视频信息 - video_info = self.extract_video_info(video_path) + video_info = self.extract_video_info(video_path) # 提取音频 - audio_path = "" + audio_path = "" if video_info["has_audio"]: - audio_path = self.extract_audio(video_path) + audio_path = self.extract_audio(video_path) # 提取关键帧 - frame_paths = self.extract_keyframes(video_path, video_id) + frame_paths = self.extract_keyframes(video_path, video_id) # 对关键帧进行 OCR - frames = [] - ocr_results = [] - all_ocr_text = [] + frames = [] + ocr_results = [] + all_ocr_text = [] for i, frame_path in enumerate(frame_paths): # 解析帧信息 - frame_name = os.path.basename(frame_path) - parts = frame_name.replace(".jpg", "").split("_") - frame_number = int(parts[1]) if len(parts) > 1 else i - timestamp = float(parts[2]) if len(parts) > 2 else i * self.frame_interval + frame_name = os.path.basename(frame_path) + parts = frame_name.replace(".jpg", "").split("_") + frame_number = int(parts[1]) if len(parts) > 1 else i + timestamp = float(parts[2]) if len(parts) > 2 else i * self.frame_interval # OCR 识别 - ocr_text, confidence = self.perform_ocr(frame_path) + ocr_text, confidence = self.perform_ocr(frame_path) - frame = VideoFrame( - id = str(uuid.uuid4())[:UUID_LENGTH], - video_id = video_id, - frame_number = frame_number, - timestamp = timestamp, - frame_path = frame_path, - ocr_text = ocr_text, - ocr_confidence = confidence, + frame = VideoFrame( + id=str(uuid.uuid4())[:UUID_LENGTH], + video_id=video_id, + frame_number=frame_number, + timestamp=timestamp, + frame_path=frame_path, + ocr_text=ocr_text, + ocr_confidence=confidence, ) frames.append(frame) @@ -404,29 +404,29 @@ class MultimodalProcessor: all_ocr_text.append(ocr_text) # 整合所有 OCR 文本 - full_ocr_text = "\n\n".join(all_ocr_text) + full_ocr_text = "\n\n".join(all_ocr_text) return VideoProcessingResult( - video_id = video_id, - audio_path = audio_path, - frames = frames, - ocr_results = ocr_results, - full_text = full_ocr_text, - success = True, + video_id=video_id, + audio_path=audio_path, + frames=frames, + ocr_results=ocr_results, + full_text=full_ocr_text, + success=True, ) except Exception as e: return VideoProcessingResult( - video_id = video_id, - audio_path = "", - frames = [], - ocr_results = [], - full_text = "", - success = False, - error_message = str(e), + video_id=video_id, + audio_path="", + frames=[], + ocr_results=[], + full_text="", + success=False, + error_message=str(e), ) - def cleanup(self, video_id: str = None) -> None: + def cleanup(self, video_id: str = None) -> None: """ 清理临时文件 @@ -438,7 +438,7 @@ class MultimodalProcessor: if video_id: # 清理特定视频的文件 for dir_path in [self.video_dir, self.frames_dir, self.audio_dir]: - target_dir = ( + target_dir = ( os.path.join(dir_path, video_id) if dir_path == self.frames_dir else dir_path ) if os.path.exists(target_dir): @@ -450,16 +450,16 @@ class MultimodalProcessor: for dir_path in [self.video_dir, self.frames_dir, self.audio_dir]: if os.path.exists(dir_path): shutil.rmtree(dir_path) - os.makedirs(dir_path, exist_ok = True) + os.makedirs(dir_path, exist_ok=True) # Singleton instance -_multimodal_processor = None +_multimodal_processor = None -def get_multimodal_processor(temp_dir: str = None, frame_interval: int = 5) -> MultimodalProcessor: +def get_multimodal_processor(temp_dir: str = None, frame_interval: int = 5) -> MultimodalProcessor: """获取多模态处理器单例""" global _multimodal_processor if _multimodal_processor is None: - _multimodal_processor = MultimodalProcessor(temp_dir, frame_interval) + _multimodal_processor = MultimodalProcessor(temp_dir, frame_interval) return _multimodal_processor diff --git a/backend/neo4j_manager.py b/backend/neo4j_manager.py index 229175d..167c87b 100644 --- a/backend/neo4j_manager.py +++ b/backend/neo4j_manager.py @@ -10,20 +10,20 @@ import logging import os from dataclasses import dataclass -logger = logging.getLogger(__name__) +logger = logging.getLogger(__name__) # Neo4j 连接配置 -NEO4J_URI = os.getenv("NEO4J_URI", "bolt://localhost:7687") -NEO4J_USER = os.getenv("NEO4J_USER", "neo4j") -NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD", "password") +NEO4J_URI = os.getenv("NEO4J_URI", "bolt://localhost:7687") +NEO4J_USER = os.getenv("NEO4J_USER", "neo4j") +NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD", "password") # 延迟导入,避免未安装时出错 try: from neo4j import Driver, GraphDatabase - NEO4J_AVAILABLE = True + NEO4J_AVAILABLE = True except ImportError: - NEO4J_AVAILABLE = False + NEO4J_AVAILABLE = False logger.warning("Neo4j driver not installed. Neo4j features will be disabled.") @@ -35,15 +35,15 @@ class GraphEntity: project_id: str name: str type: str - definition: str = "" - aliases: list[str] = None - properties: dict = None + definition: str = "" + aliases: list[str] = None + properties: dict = None def __post_init__(self) -> None: if self.aliases is None: - self.aliases = [] + self.aliases = [] if self.properties is None: - self.properties = {} + self.properties = {} @dataclass @@ -54,12 +54,12 @@ class GraphRelation: source_id: str target_id: str relation_type: str - evidence: str = "" - properties: dict = None + evidence: str = "" + properties: dict = None def __post_init__(self) -> None: if self.properties is None: - self.properties = {} + self.properties = {} @dataclass @@ -69,7 +69,7 @@ class PathResult: nodes: list[dict] relationships: list[dict] length: int - total_weight: float = 0.0 + total_weight: float = 0.0 @dataclass @@ -79,7 +79,7 @@ class CommunityResult: community_id: int nodes: list[dict] size: int - density: float = 0.0 + density: float = 0.0 @dataclass @@ -89,17 +89,17 @@ class CentralityResult: entity_id: str entity_name: str score: float - rank: int = 0 + rank: int = 0 class Neo4jManager: """Neo4j 图数据库管理器""" - def __init__(self, uri: str = None, user: str = None, password: str = None) -> None: - self.uri = uri or NEO4J_URI - self.user = user or NEO4J_USER - self.password = password or NEO4J_PASSWORD - self._driver: Driver | None = None + def __init__(self, uri: str = None, user: str = None, password: str = None) -> None: + self.uri = uri or NEO4J_URI + self.user = user or NEO4J_USER + self.password = password or NEO4J_PASSWORD + self._driver: Driver | None = None if not NEO4J_AVAILABLE: logger.error("Neo4j driver not available. Please install: pip install neo4j") @@ -113,13 +113,13 @@ class Neo4jManager: return try: - self._driver = GraphDatabase.driver(self.uri, auth = (self.user, self.password)) + self._driver = GraphDatabase.driver(self.uri, auth=(self.user, self.password)) # 验证连接 self._driver.verify_connectivity() logger.info(f"Connected to Neo4j at {self.uri}") except (RuntimeError, ValueError, TypeError) as e: logger.error(f"Failed to connect to Neo4j: {e}") - self._driver = None + self._driver = None def close(self) -> None: """关闭连接""" @@ -179,7 +179,7 @@ class Neo4jManager: # ==================== 数据同步 ==================== def sync_project( - self, project_id: str, project_name: str, project_description: str = "" + self, project_id: str, project_name: str, project_description: str = "" ) -> None: """同步项目节点到 Neo4j""" if not self._driver: @@ -193,9 +193,9 @@ class Neo4jManager: p.description = $description, p.updated_at = datetime() """, - project_id = project_id, - name = project_name, - description = project_description, + project_id=project_id, + name=project_name, + description=project_description, ) def sync_entity(self, entity: GraphEntity) -> None: @@ -218,13 +218,13 @@ class Neo4jManager: MATCH (p:Project {id: $project_id}) MERGE (e)-[:BELONGS_TO]->(p) """, - id = entity.id, - project_id = entity.project_id, - name = entity.name, - type = entity.type, - definition = entity.definition, - aliases = json.dumps(entity.aliases), - properties = json.dumps(entity.properties), + id=entity.id, + project_id=entity.project_id, + name=entity.name, + type=entity.type, + definition=entity.definition, + aliases=json.dumps(entity.aliases), + properties=json.dumps(entity.properties), ) def sync_entities_batch(self, entities: list[GraphEntity]) -> None: @@ -234,7 +234,7 @@ class Neo4jManager: with self._driver.session() as session: # 使用 UNWIND 批量处理 - entities_data = [ + entities_data = [ { "id": e.id, "project_id": e.project_id, @@ -261,7 +261,7 @@ class Neo4jManager: MATCH (p:Project {id: entity.project_id}) MERGE (e)-[:BELONGS_TO]->(p) """, - entities = entities_data, + entities=entities_data, ) def sync_relation(self, relation: GraphRelation) -> None: @@ -280,12 +280,12 @@ class Neo4jManager: r.properties = $properties, r.updated_at = datetime() """, - id = relation.id, - source_id = relation.source_id, - target_id = relation.target_id, - relation_type = relation.relation_type, - evidence = relation.evidence, - properties = json.dumps(relation.properties), + id=relation.id, + source_id=relation.source_id, + target_id=relation.target_id, + relation_type=relation.relation_type, + evidence=relation.evidence, + properties=json.dumps(relation.properties), ) def sync_relations_batch(self, relations: list[GraphRelation]) -> None: @@ -294,7 +294,7 @@ class Neo4jManager: return with self._driver.session() as session: - relations_data = [ + relations_data = [ { "id": r.id, "source_id": r.source_id, @@ -317,7 +317,7 @@ class Neo4jManager: r.properties = rel.properties, r.updated_at = datetime() """, - relations = relations_data, + relations=relations_data, ) def delete_entity(self, entity_id: str) -> None: @@ -331,7 +331,7 @@ class Neo4jManager: MATCH (e:Entity {id: $id}) DETACH DELETE e """, - id = entity_id, + id=entity_id, ) def delete_project(self, project_id: str) -> None: @@ -346,13 +346,13 @@ class Neo4jManager: OPTIONAL MATCH (e:Entity)-[:BELONGS_TO]->(p) DETACH DELETE e, p """, - id = project_id, + id=project_id, ) # ==================== 复杂图查询 ==================== def find_shortest_path( - self, source_id: str, target_id: str, max_depth: int = 10 + self, source_id: str, target_id: str, max_depth: int = 10 ) -> PathResult | None: """ 查找两个实体之间的最短路径 @@ -369,31 +369,31 @@ class Neo4jManager: return None with self._driver.session() as session: - result = session.run( + result = session.run( """ MATCH path = shortestPath( (source:Entity {id: $source_id})-[*1..$max_depth]-(target:Entity {id: $target_id}) ) RETURN path """, - source_id = source_id, - target_id = target_id, - max_depth = max_depth, + source_id=source_id, + target_id=target_id, + max_depth=max_depth, ) - record = result.single() + record = result.single() if not record: return None - path = record["path"] + path = record["path"] # 提取节点和关系 - nodes = [ + nodes = [ {"id": node["id"], "name": node["name"], "type": node["type"]} for node in path.nodes ] - relationships = [ + relationships = [ { "source": rel.start_node["id"], "target": rel.end_node["id"], @@ -404,11 +404,11 @@ class Neo4jManager: ] return PathResult( - nodes = nodes, relationships = relationships, length = len(path.relationships) + nodes=nodes, relationships=relationships, length=len(path.relationships) ) def find_all_paths( - self, source_id: str, target_id: str, max_depth: int = 5, limit: int = 10 + self, source_id: str, target_id: str, max_depth: int = 5, limit: int = 10 ) -> list[PathResult]: """ 查找两个实体之间的所有路径 @@ -426,29 +426,29 @@ class Neo4jManager: return [] with self._driver.session() as session: - result = session.run( + result = session.run( """ MATCH path = (source:Entity {id: $source_id})-[*1..$max_depth]-(target:Entity {id: $target_id}) WHERE source <> target RETURN path LIMIT $limit """, - source_id = source_id, - target_id = target_id, - max_depth = max_depth, - limit = limit, + source_id=source_id, + target_id=target_id, + max_depth=max_depth, + limit=limit, ) - paths = [] + paths = [] for record in result: - path = record["path"] + path = record["path"] - nodes = [ + nodes = [ {"id": node["id"], "name": node["name"], "type": node["type"]} for node in path.nodes ] - relationships = [ + relationships = [ { "source": rel.start_node["id"], "target": rel.end_node["id"], @@ -460,14 +460,14 @@ class Neo4jManager: paths.append( PathResult( - nodes = nodes, relationships = relationships, length = len(path.relationships) + nodes=nodes, relationships=relationships, length=len(path.relationships) ) ) return paths def find_neighbors( - self, entity_id: str, relation_type: str = None, limit: int = 50 + self, entity_id: str, relation_type: str = None, limit: int = 50 ) -> list[dict]: """ 查找实体的邻居节点 @@ -485,30 +485,30 @@ class Neo4jManager: with self._driver.session() as session: if relation_type: - result = session.run( + result = session.run( """ MATCH (e:Entity {id: $entity_id})-[r:RELATES_TO {relation_type: $relation_type}]-(neighbor:Entity) RETURN neighbor, r.relation_type as rel_type, r.evidence as evidence LIMIT $limit """, - entity_id = entity_id, - relation_type = relation_type, - limit = limit, + entity_id=entity_id, + relation_type=relation_type, + limit=limit, ) else: - result = session.run( + result = session.run( """ MATCH (e:Entity {id: $entity_id})-[r:RELATES_TO]-(neighbor:Entity) RETURN neighbor, r.relation_type as rel_type, r.evidence as evidence LIMIT $limit """, - entity_id = entity_id, - limit = limit, + entity_id=entity_id, + limit=limit, ) - neighbors = [] + neighbors = [] for record in result: - node = record["neighbor"] + node = record["neighbor"] neighbors.append( { "id": node["id"], @@ -536,13 +536,13 @@ class Neo4jManager: return [] with self._driver.session() as session: - result = session.run( + result = session.run( """ MATCH (e1:Entity {id: $id1})-[:RELATES_TO]-(common:Entity)-[:RELATES_TO]-(e2:Entity {id: $id2}) RETURN DISTINCT common """, - id1 = entity_id1, - id2 = entity_id2, + id1=entity_id1, + id2=entity_id2, ) return [ @@ -556,7 +556,7 @@ class Neo4jManager: # ==================== 图算法分析 ==================== - def calculate_pagerank(self, project_id: str, top_n: int = 20) -> list[CentralityResult]: + def calculate_pagerank(self, project_id: str, top_n: int = 20) -> list[CentralityResult]: """ 计算 PageRank 中心性 @@ -571,7 +571,7 @@ class Neo4jManager: return [] with self._driver.session() as session: - result = session.run( + result = session.run( """ CALL gds.graph.exists('project-graph-$project_id') YIELD exists WITH exists @@ -581,7 +581,7 @@ class Neo4jManager: {} ) YIELD value RETURN value """, - project_id = project_id, + project_id=project_id, ) # 创建临时图 @@ -601,11 +601,11 @@ class Neo4jManager: } ) """, - project_id = project_id, + project_id=project_id, ) # 运行 PageRank - result = session.run( + result = session.run( """ CALL gds.pageRank.stream('project-graph-$project_id') YIELD nodeId, score @@ -615,19 +615,19 @@ class Neo4jManager: ORDER BY score DESC LIMIT $top_n """, - project_id = project_id, - top_n = top_n, + project_id=project_id, + top_n=top_n, ) - rankings = [] - rank = 1 + rankings = [] + rank = 1 for record in result: rankings.append( CentralityResult( - entity_id = record["entity_id"], - entity_name = record["entity_name"], - score = record["score"], - rank = rank, + entity_id=record["entity_id"], + entity_name=record["entity_name"], + score=record["score"], + rank=rank, ) ) rank += 1 @@ -637,12 +637,12 @@ class Neo4jManager: """ CALL gds.graph.drop('project-graph-$project_id') """, - project_id = project_id, + project_id=project_id, ) return rankings - def calculate_betweenness(self, project_id: str, top_n: int = 20) -> list[CentralityResult]: + def calculate_betweenness(self, project_id: str, top_n: int = 20) -> list[CentralityResult]: """ 计算 Betweenness 中心性(桥梁作用) @@ -658,7 +658,7 @@ class Neo4jManager: with self._driver.session() as session: # 使用 APOC 的 betweenness 计算(如果没有 GDS) - result = session.run( + result = session.run( """ MATCH (e:Entity)-[:BELONGS_TO]->(p:Project {id: $project_id}) OPTIONAL MATCH (e)-[:RELATES_TO]-(other:Entity) @@ -667,19 +667,19 @@ class Neo4jManager: LIMIT $top_n RETURN e.id as entity_id, e.name as entity_name, degree as score """, - project_id = project_id, - top_n = top_n, + project_id=project_id, + top_n=top_n, ) - rankings = [] - rank = 1 + rankings = [] + rank = 1 for record in result: rankings.append( CentralityResult( - entity_id = record["entity_id"], - entity_name = record["entity_name"], - score = float(record["score"]), - rank = rank, + entity_id=record["entity_id"], + entity_name=record["entity_name"], + score=float(record["score"]), + rank=rank, ) ) rank += 1 @@ -701,7 +701,7 @@ class Neo4jManager: with self._driver.session() as session: # 简单的社区检测:基于连通分量 - result = session.run( + result = session.run( """ MATCH (e:Entity)-[:BELONGS_TO]->(p:Project {id: $project_id}) OPTIONAL MATCH (e)-[:RELATES_TO]-(other:Entity)-[:BELONGS_TO]->(p) @@ -710,25 +710,25 @@ class Neo4jManager: connections, size(connections) as connection_count ORDER BY connection_count DESC """, - project_id = project_id, + project_id=project_id, ) # 手动分组(基于连通性) - communities = {} + communities = {} for record in result: - entity_id = record["entity_id"] - connections = record["connections"] + entity_id = record["entity_id"] + connections = record["connections"] # 找到所属的社区 - found_community = None + found_community = None for comm_id, comm_data in communities.items(): if any(conn in comm_data["member_ids"] for conn in connections): - found_community = comm_id + found_community = comm_id break if found_community is None: - found_community = len(communities) - communities[found_community] = {"member_ids": set(), "nodes": []} + found_community = len(communities) + communities[found_community] = {"member_ids": set(), "nodes": []} communities[found_community]["member_ids"].add(entity_id) communities[found_community]["nodes"].append( @@ -741,27 +741,27 @@ class Neo4jManager: ) # 构建结果 - results = [] + results = [] for comm_id, comm_data in communities.items(): - nodes = comm_data["nodes"] - size = len(nodes) + nodes = comm_data["nodes"] + size = len(nodes) # 计算密度(简化版) - max_edges = size * (size - 1) / 2 if size > 1 else 1 - actual_edges = sum(n["connections"] for n in nodes) / 2 - density = actual_edges / max_edges if max_edges > 0 else 0 + max_edges = size * (size - 1) / 2 if size > 1 else 1 + actual_edges = sum(n["connections"] for n in nodes) / 2 + density = actual_edges / max_edges if max_edges > 0 else 0 results.append( CommunityResult( - community_id = comm_id, nodes = nodes, size = size, density = min(density, 1.0) + community_id=comm_id, nodes=nodes, size=size, density=min(density, 1.0) ) ) # 按大小排序 - results.sort(key = lambda x: x.size, reverse = True) + results.sort(key=lambda x: x.size, reverse=True) return results def find_central_entities( - self, project_id: str, metric: str = "degree" + self, project_id: str, metric: str = "degree" ) -> list[CentralityResult]: """ 查找中心实体 @@ -778,7 +778,7 @@ class Neo4jManager: with self._driver.session() as session: if metric == "degree": - result = session.run( + result = session.run( """ MATCH (e:Entity)-[:BELONGS_TO]->(p:Project {id: $project_id}) OPTIONAL MATCH (e)-[:RELATES_TO]-(other:Entity) @@ -787,11 +787,11 @@ class Neo4jManager: ORDER BY degree DESC LIMIT 20 """, - project_id = project_id, + project_id=project_id, ) else: # 默认使用度中心性 - result = session.run( + result = session.run( """ MATCH (e:Entity)-[:BELONGS_TO]->(p:Project {id: $project_id}) OPTIONAL MATCH (e)-[:RELATES_TO]-(other:Entity) @@ -800,18 +800,18 @@ class Neo4jManager: ORDER BY degree DESC LIMIT 20 """, - project_id = project_id, + project_id=project_id, ) - rankings = [] - rank = 1 + rankings = [] + rank = 1 for record in result: rankings.append( CentralityResult( - entity_id = record["entity_id"], - entity_name = record["entity_name"], - score = float(record["score"]), - rank = rank, + entity_id=record["entity_id"], + entity_name=record["entity_name"], + score=float(record["score"]), + rank=rank, ) ) rank += 1 @@ -835,49 +835,49 @@ class Neo4jManager: with self._driver.session() as session: # 实体数量 - entity_count = session.run( + entity_count = session.run( """ MATCH (e:Entity)-[:BELONGS_TO]->(p:Project {id: $project_id}) RETURN count(e) as count """, - project_id = project_id, + project_id=project_id, ).single()["count"] # 关系数量 - relation_count = session.run( + relation_count = session.run( """ MATCH (e:Entity)-[:BELONGS_TO]->(p:Project {id: $project_id}) MATCH (e)-[r:RELATES_TO]-() RETURN count(r) as count """, - project_id = project_id, + project_id=project_id, ).single()["count"] # 实体类型分布 - type_distribution = session.run( + type_distribution = session.run( """ MATCH (e:Entity)-[:BELONGS_TO]->(p:Project {id: $project_id}) RETURN e.type as type, count(e) as count ORDER BY count DESC """, - project_id = project_id, + project_id=project_id, ) - types = {record["type"]: record["count"] for record in type_distribution} + types = {record["type"]: record["count"] for record in type_distribution} # 平均度 - avg_degree = session.run( + avg_degree = session.run( """ MATCH (e:Entity)-[:BELONGS_TO]->(p:Project {id: $project_id}) OPTIONAL MATCH (e)-[:RELATES_TO]-(other) WITH e, count(other) as degree RETURN avg(degree) as avg_degree """, - project_id = project_id, + project_id=project_id, ).single()["avg_degree"] # 关系类型分布 - rel_types = session.run( + rel_types = session.run( """ MATCH (e:Entity)-[:BELONGS_TO]->(p:Project {id: $project_id}) MATCH (e)-[r:RELATES_TO]-() @@ -885,10 +885,10 @@ class Neo4jManager: ORDER BY count DESC LIMIT 10 """, - project_id = project_id, + project_id=project_id, ) - relation_types = {record["type"]: record["count"] for record in rel_types} + relation_types = {record["type"]: record["count"] for record in rel_types} return { "entity_count": entity_count, @@ -901,7 +901,7 @@ class Neo4jManager: else 0, } - def get_subgraph(self, entity_ids: list[str], depth: int = 1) -> dict: + def get_subgraph(self, entity_ids: list[str], depth: int = 1) -> dict: """ 获取指定实体的子图 @@ -916,7 +916,7 @@ class Neo4jManager: return {"nodes": [], "relationships": []} with self._driver.session() as session: - result = session.run( + result = session.run( """ MATCH (e:Entity) WHERE e.id IN $entity_ids @@ -927,14 +927,14 @@ class Neo4jManager: }) YIELD node RETURN DISTINCT node """, - entity_ids = entity_ids, - depth = depth, + entity_ids=entity_ids, + depth=depth, ) - nodes = [] - node_ids = set() + nodes = [] + node_ids = set() for record in result: - node = record["node"] + node = record["node"] node_ids.add(node["id"]) nodes.append( { @@ -946,17 +946,17 @@ class Neo4jManager: ) # 获取这些节点之间的关系 - result = session.run( + result = session.run( """ MATCH (source:Entity)-[r:RELATES_TO]->(target:Entity) WHERE source.id IN $node_ids AND target.id IN $node_ids RETURN source.id as source_id, target.id as target_id, r.relation_type as type, r.evidence as evidence """, - node_ids = list(node_ids), + node_ids=list(node_ids), ) - relationships = [ + relationships = [ { "source": record["source_id"], "target": record["target_id"], @@ -970,14 +970,14 @@ class Neo4jManager: # 全局单例 -_neo4j_manager = None +_neo4j_manager = None def get_neo4j_manager() -> Neo4jManager: """获取 Neo4j 管理器单例""" global _neo4j_manager if _neo4j_manager is None: - _neo4j_manager = Neo4jManager() + _neo4j_manager = Neo4jManager() return _neo4j_manager @@ -986,7 +986,7 @@ def close_neo4j_manager() -> None: global _neo4j_manager if _neo4j_manager: _neo4j_manager.close() - _neo4j_manager = None + _neo4j_manager = None # 便捷函数 @@ -1004,7 +1004,7 @@ def sync_project_to_neo4j( entities: 实体列表(字典格式) relations: 关系列表(字典格式) """ - manager = get_neo4j_manager() + manager = get_neo4j_manager() if not manager.is_connected(): logger.warning("Neo4j not connected, skipping sync") return @@ -1013,29 +1013,29 @@ def sync_project_to_neo4j( manager.sync_project(project_id, project_name) # 同步实体 - graph_entities = [ + graph_entities = [ GraphEntity( - id = e["id"], - project_id = project_id, - name = e["name"], - type = e.get("type", "unknown"), - definition = e.get("definition", ""), - aliases = e.get("aliases", []), - properties = e.get("properties", {}), + id=e["id"], + project_id=project_id, + name=e["name"], + type=e.get("type", "unknown"), + definition=e.get("definition", ""), + aliases=e.get("aliases", []), + properties=e.get("properties", {}), ) for e in entities ] manager.sync_entities_batch(graph_entities) # 同步关系 - graph_relations = [ + graph_relations = [ GraphRelation( - id = r["id"], - source_id = r["source_entity_id"], - target_id = r["target_entity_id"], - relation_type = r["relation_type"], - evidence = r.get("evidence", ""), - properties = r.get("properties", {}), + id=r["id"], + source_id=r["source_entity_id"], + target_id=r["target_entity_id"], + relation_type=r["relation_type"], + evidence=r.get("evidence", ""), + properties=r.get("properties", {}), ) for r in relations ] @@ -1048,9 +1048,9 @@ def sync_project_to_neo4j( if __name__ == "__main__": # 测试代码 - logging.basicConfig(level = logging.INFO) + logging.basicConfig(level=logging.INFO) - manager = Neo4jManager() + manager = Neo4jManager() if manager.is_connected(): print("✅ Connected to Neo4j") @@ -1064,18 +1064,18 @@ if __name__ == "__main__": print("✅ Project synced") # 测试实体 - test_entity = GraphEntity( - id = "test-entity-1", - project_id = "test-project", - name = "Test Entity", - type = "Person", - definition = "A test entity", + test_entity = GraphEntity( + id="test-entity-1", + project_id="test-project", + name="Test Entity", + type="Person", + definition="A test entity", ) manager.sync_entity(test_entity) print("✅ Entity synced") # 获取统计 - stats = manager.get_graph_stats("test-project") + stats = manager.get_graph_stats("test-project") print(f"📊 Graph stats: {stats}") else: diff --git a/backend/ops_manager.py b/backend/ops_manager.py index ba80dfc..da7a992 100644 --- a/backend/ops_manager.py +++ b/backend/ops_manager.py @@ -27,87 +27,87 @@ from enum import StrEnum import httpx # Database path -DB_PATH = os.path.join(os.path.dirname(__file__), "insightflow.db") +DB_PATH = os.path.join(os.path.dirname(__file__), "insightflow.db") class AlertSeverity(StrEnum): """告警严重级别 P0-P3""" - P0 = "p0" # 紧急 - 系统不可用,需要立即处理 - P1 = "p1" # 严重 - 核心功能受损,需要1小时内处理 - P2 = "p2" # 一般 - 部分功能受影响,需要4小时内处理 - P3 = "p3" # 轻微 - 非核心功能问题,24小时内处理 + P0 = "p0" # 紧急 - 系统不可用,需要立即处理 + P1 = "p1" # 严重 - 核心功能受损,需要1小时内处理 + P2 = "p2" # 一般 - 部分功能受影响,需要4小时内处理 + P3 = "p3" # 轻微 - 非核心功能问题,24小时内处理 class AlertStatus(StrEnum): """告警状态""" - FIRING = "firing" # 正在告警 - RESOLVED = "resolved" # 已恢复 - ACKNOWLEDGED = "acknowledged" # 已确认 - SUPPRESSED = "suppressed" # 已抑制 + FIRING = "firing" # 正在告警 + RESOLVED = "resolved" # 已恢复 + ACKNOWLEDGED = "acknowledged" # 已确认 + SUPPRESSED = "suppressed" # 已抑制 class AlertChannelType(StrEnum): """告警渠道类型""" - PAGERDUTY = "pagerduty" - OPSGENIE = "opsgenie" - FEISHU = "feishu" - DINGTALK = "dingtalk" - SLACK = "slack" - EMAIL = "email" - SMS = "sms" - WEBHOOK = "webhook" + PAGERDUTY = "pagerduty" + OPSGENIE = "opsgenie" + FEISHU = "feishu" + DINGTALK = "dingtalk" + SLACK = "slack" + EMAIL = "email" + SMS = "sms" + WEBHOOK = "webhook" class AlertRuleType(StrEnum): """告警规则类型""" - THRESHOLD = "threshold" # 阈值告警 - ANOMALY = "anomaly" # 异常检测 - PREDICTIVE = "predictive" # 预测性告警 - COMPOSITE = "composite" # 复合告警 + THRESHOLD = "threshold" # 阈值告警 + ANOMALY = "anomaly" # 异常检测 + PREDICTIVE = "predictive" # 预测性告警 + COMPOSITE = "composite" # 复合告警 class ResourceType(StrEnum): """资源类型""" - CPU = "cpu" - MEMORY = "memory" - DISK = "disk" - NETWORK = "network" - GPU = "gpu" - DATABASE = "database" - CACHE = "cache" - QUEUE = "queue" + CPU = "cpu" + MEMORY = "memory" + DISK = "disk" + NETWORK = "network" + GPU = "gpu" + DATABASE = "database" + CACHE = "cache" + QUEUE = "queue" class ScalingAction(StrEnum): """扩缩容动作""" - SCALE_UP = "scale_up" # 扩容 - SCALE_DOWN = "scale_down" # 缩容 - MAINTAIN = "maintain" # 保持 + SCALE_UP = "scale_up" # 扩容 + SCALE_DOWN = "scale_down" # 缩容 + MAINTAIN = "maintain" # 保持 class HealthStatus(StrEnum): """健康状态""" - HEALTHY = "healthy" - DEGRADED = "degraded" - UNHEALTHY = "unhealthy" - UNKNOWN = "unknown" + HEALTHY = "healthy" + DEGRADED = "degraded" + UNHEALTHY = "unhealthy" + UNKNOWN = "unknown" class BackupStatus(StrEnum): """备份状态""" - PENDING = "pending" - IN_PROGRESS = "in_progress" - COMPLETED = "completed" - FAILED = "failed" - VERIFIED = "verified" + PENDING = "pending" + IN_PROGRESS = "in_progress" + COMPLETED = "completed" + FAILED = "failed" + VERIFIED = "verified" @dataclass @@ -121,7 +121,7 @@ class AlertRule: rule_type: AlertRuleType severity: AlertSeverity metric: str # 监控指标 - condition: str # 条件: >, <, ==, >= , <= , != + condition: str # 条件: >, <, ==, >= , <= , != threshold: float duration: int # 持续时间(秒) evaluation_interval: int # 评估间隔(秒) @@ -450,23 +450,23 @@ class OpsManager: """运维与监控管理主类""" def __init__(self, db_path: str = DB_PATH) -> None: - self.db_path = db_path - self._alert_evaluators: dict[str, Callable] = {} - self._running = False - self._evaluator_thread = None + self.db_path = db_path + self._alert_evaluators: dict[str, Callable] = {} + self._running = False + self._evaluator_thread = None self._register_default_evaluators() def _get_db(self) -> None: """获取数据库连接""" - conn = sqlite3.connect(self.db_path) - conn.row_factory = sqlite3.Row + conn = sqlite3.connect(self.db_path) + conn.row_factory = sqlite3.Row return conn def _register_default_evaluators(self) -> None: """注册默认的告警评估器""" - self._alert_evaluators[AlertRuleType.THRESHOLD.value] = self._evaluate_threshold_rule - self._alert_evaluators[AlertRuleType.ANOMALY.value] = self._evaluate_anomaly_rule - self._alert_evaluators[AlertRuleType.PREDICTIVE.value] = self._evaluate_predictive_rule + self._alert_evaluators[AlertRuleType.THRESHOLD.value] = self._evaluate_threshold_rule + self._alert_evaluators[AlertRuleType.ANOMALY.value] = self._evaluate_anomaly_rule + self._alert_evaluators[AlertRuleType.PREDICTIVE.value] = self._evaluate_predictive_rule # ==================== 告警规则管理 ==================== @@ -488,28 +488,28 @@ class OpsManager: created_by: str, ) -> AlertRule: """创建告警规则""" - rule_id = f"ar_{uuid.uuid4().hex[:16]}" - now = datetime.now().isoformat() + rule_id = f"ar_{uuid.uuid4().hex[:16]}" + now = datetime.now().isoformat() - rule = AlertRule( - id = rule_id, - tenant_id = tenant_id, - name = name, - description = description, - rule_type = rule_type, - severity = severity, - metric = metric, - condition = condition, - threshold = threshold, - duration = duration, - evaluation_interval = evaluation_interval, - channels = channels, - labels = labels or {}, - annotations = annotations or {}, - is_enabled = True, - created_at = now, - updated_at = now, - created_by = created_by, + rule = AlertRule( + id=rule_id, + tenant_id=tenant_id, + name=name, + description=description, + rule_type=rule_type, + severity=severity, + metric=metric, + condition=condition, + threshold=threshold, + duration=duration, + evaluation_interval=evaluation_interval, + channels=channels, + labels=labels or {}, + annotations=annotations or {}, + is_enabled=True, + created_at=now, + updated_at=now, + created_by=created_by, ) with self._get_db() as conn: @@ -549,16 +549,16 @@ class OpsManager: def get_alert_rule(self, rule_id: str) -> AlertRule | None: """获取告警规则""" with self._get_db() as conn: - row = conn.execute("SELECT * FROM alert_rules WHERE id = ?", (rule_id, )).fetchone() + row = conn.execute("SELECT * FROM alert_rules WHERE id = ?", (rule_id, )).fetchone() if row: return self._row_to_alert_rule(row) return None - def list_alert_rules(self, tenant_id: str, is_enabled: bool | None = None) -> list[AlertRule]: + def list_alert_rules(self, tenant_id: str, is_enabled: bool | None = None) -> list[AlertRule]: """列出租户的所有告警规则""" - query = "SELECT * FROM alert_rules WHERE tenant_id = ?" - params = [tenant_id] + query = "SELECT * FROM alert_rules WHERE tenant_id = ?" + params = [tenant_id] if is_enabled is not None: query += " AND is_enabled = ?" @@ -567,12 +567,12 @@ class OpsManager: query += " ORDER BY created_at DESC" with self._get_db() as conn: - rows = conn.execute(query, params).fetchall() + rows = conn.execute(query, params).fetchall() return [self._row_to_alert_rule(row) for row in rows] def update_alert_rule(self, rule_id: str, **kwargs) -> AlertRule | None: """更新告警规则""" - allowed_fields = [ + allowed_fields = [ "name", "description", "severity", @@ -587,24 +587,24 @@ class OpsManager: "is_enabled", ] - updates = {k: v for k, v in kwargs.items() if k in allowed_fields} + updates = {k: v for k, v in kwargs.items() if k in allowed_fields} if not updates: return self.get_alert_rule(rule_id) # 处理列表和字典字段 if "channels" in updates: - updates["channels"] = json.dumps(updates["channels"]) + updates["channels"] = json.dumps(updates["channels"]) if "labels" in updates: - updates["labels"] = json.dumps(updates["labels"]) + updates["labels"] = json.dumps(updates["labels"]) if "annotations" in updates: - updates["annotations"] = json.dumps(updates["annotations"]) + updates["annotations"] = json.dumps(updates["annotations"]) if "severity" in updates and isinstance(updates["severity"], AlertSeverity): - updates["severity"] = updates["severity"].value + updates["severity"] = updates["severity"].value - updates["updated_at"] = datetime.now().isoformat() + updates["updated_at"] = datetime.now().isoformat() with self._get_db() as conn: - set_clause = ", ".join([f"{k} = ?" for k in updates.keys()]) + set_clause = ", ".join([f"{k} = ?" for k in updates.keys()]) conn.execute( f"UPDATE alert_rules SET {set_clause} WHERE id = ?", list(updates.values()) + [rule_id], @@ -628,25 +628,25 @@ class OpsManager: name: str, channel_type: AlertChannelType, config: dict, - severity_filter: list[str] = None, + severity_filter: list[str] = None, ) -> AlertChannel: """创建告警渠道""" - channel_id = f"ac_{uuid.uuid4().hex[:16]}" - now = datetime.now().isoformat() + channel_id = f"ac_{uuid.uuid4().hex[:16]}" + now = datetime.now().isoformat() - channel = AlertChannel( - id = channel_id, - tenant_id = tenant_id, - name = name, - channel_type = channel_type, - config = config, - severity_filter = severity_filter or [s.value for s in AlertSeverity], - is_enabled = True, - success_count = 0, - fail_count = 0, - last_used_at = None, - created_at = now, - updated_at = now, + channel = AlertChannel( + id=channel_id, + tenant_id=tenant_id, + name=name, + channel_type=channel_type, + config=config, + severity_filter=severity_filter or [s.value for s in AlertSeverity], + is_enabled=True, + success_count=0, + fail_count=0, + last_used_at=None, + created_at=now, + updated_at=now, ) with self._get_db() as conn: @@ -679,7 +679,7 @@ class OpsManager: def get_alert_channel(self, channel_id: str) -> AlertChannel | None: """获取告警渠道""" with self._get_db() as conn: - row = conn.execute( + row = conn.execute( "SELECT * FROM alert_channels WHERE id = ?", (channel_id, ) ).fetchone() @@ -690,7 +690,7 @@ class OpsManager: def list_alert_channels(self, tenant_id: str) -> list[AlertChannel]: """列出租户的所有告警渠道""" with self._get_db() as conn: - rows = conn.execute( + rows = conn.execute( "SELECT * FROM alert_channels WHERE tenant_id = ? ORDER BY created_at DESC", (tenant_id, ), ).fetchall() @@ -698,29 +698,29 @@ class OpsManager: def test_alert_channel(self, channel_id: str) -> bool: """测试告警渠道""" - channel = self.get_alert_channel(channel_id) + channel = self.get_alert_channel(channel_id) if not channel: return False - test_alert = Alert( - id = "test", - rule_id = "test", - tenant_id = channel.tenant_id, - severity = AlertSeverity.P3, - status = AlertStatus.FIRING, - title = "测试告警", - description = "这是一条测试告警消息,用于验证告警渠道配置。", - metric = "test_metric", - value = 0.0, - threshold = 0.0, - labels = {"test": "true"}, - annotations = {}, - started_at = datetime.now().isoformat(), - resolved_at = None, - acknowledged_by = None, - acknowledged_at = None, - notification_sent = {}, - suppression_count = 0, + test_alert = Alert( + id="test", + rule_id="test", + tenant_id=channel.tenant_id, + severity=AlertSeverity.P3, + status=AlertStatus.FIRING, + title="测试告警", + description="这是一条测试告警消息,用于验证告警渠道配置。", + metric="test_metric", + value=0.0, + threshold=0.0, + labels={"test": "true"}, + annotations={}, + started_at=datetime.now().isoformat(), + resolved_at=None, + acknowledged_by=None, + acknowledged_at=None, + notification_sent={}, + suppression_count=0, ) return asyncio.run(self._send_alert_to_channel(test_alert, channel)) @@ -733,14 +733,14 @@ class OpsManager: return False # 获取最近 duration 秒内的指标 - cutoff_time = datetime.now() - timedelta(seconds = rule.duration) - recent_metrics = [m for m in metrics if datetime.fromisoformat(m.timestamp) > cutoff_time] + cutoff_time = datetime.now() - timedelta(seconds=rule.duration) + recent_metrics = [m for m in metrics if datetime.fromisoformat(m.timestamp) > cutoff_time] if not recent_metrics: return False # 计算平均值 - avg_value = statistics.mean([m.metric_value for m in recent_metrics]) + avg_value = statistics.mean([m.metric_value for m in recent_metrics]) # 评估条件 condition_map = { @@ -752,7 +752,7 @@ class OpsManager: "!=": lambda x, y: x != y, } - evaluator = condition_map.get(rule.condition) + evaluator = condition_map.get(rule.condition) if evaluator: return evaluator(avg_value, rule.threshold) @@ -763,16 +763,16 @@ class OpsManager: if len(metrics) < 10: return False - values = [m.metric_value for m in metrics] - mean = statistics.mean(values) - std = statistics.stdev(values) if len(values) > 1 else 0 + values = [m.metric_value for m in metrics] + mean = statistics.mean(values) + std = statistics.stdev(values) if len(values) > 1 else 0 if std == 0: return False # 最近值偏离均值超过3个标准差视为异常 - latest_value = values[-1] - z_score = abs(latest_value - mean) / std + latest_value = values[-1] + z_score = abs(latest_value - mean) / std return z_score > 3.0 @@ -782,15 +782,15 @@ class OpsManager: return False # 简单的线性趋势预测 - values = [m.metric_value for m in metrics[-10:]] # 最近10个点 - n = len(values) + values = [m.metric_value for m in metrics[-10:]] # 最近10个点 + n = len(values) if n < 2: return False - x = list(range(n)) - mean_x = sum(x) / n - mean_y = sum(values) / n + x = list(range(n)) + mean_x = sum(x) / n + mean_y = sum(values) / n # 计算斜率 numerator = sum((x[i] - mean_x) * (values[i] - mean_y) for i in range(n)) @@ -801,12 +801,12 @@ class OpsManager: predicted = values[-1] + slope # 如果预测值超过阈值,触发告警 - condition_map = { + condition_map = { ">": lambda x, y: x > y, "<": lambda x, y: x < y, } - evaluator = condition_map.get(rule.condition) + evaluator = condition_map.get(rule.condition) if evaluator: return evaluator(predicted, rule.threshold) @@ -814,16 +814,16 @@ class OpsManager: async def evaluate_alert_rules(self, tenant_id: str) -> None: """评估所有告警规则""" - rules = self.list_alert_rules(tenant_id, is_enabled = True) + rules = self.list_alert_rules(tenant_id, is_enabled=True) for rule in rules: # 获取相关指标 - metrics = self.get_recent_metrics( - tenant_id, rule.metric, seconds = rule.duration + rule.evaluation_interval + metrics = self.get_recent_metrics( + tenant_id, rule.metric, seconds=rule.duration + rule.evaluation_interval ) # 评估规则 - evaluator = self._alert_evaluators.get(rule.rule_type.value) + evaluator = self._alert_evaluators.get(rule.rule_type.value) if evaluator and evaluator(rule, metrics): # 触发告警 await self._trigger_alert(rule, metrics[-1] if metrics else None) @@ -831,7 +831,7 @@ class OpsManager: async def _trigger_alert(self, rule: AlertRule, metric: ResourceMetric | None) -> None: """触发告警""" # 检查是否已有相同告警在触发中 - existing = self.get_active_alert_by_rule(rule.id) + existing = self.get_active_alert_by_rule(rule.id) if existing: # 更新抑制计数 self._increment_suppression_count(existing.id) @@ -841,28 +841,28 @@ class OpsManager: if self._is_alert_suppressed(rule): return - alert_id = f"al_{uuid.uuid4().hex[:16]}" - now = datetime.now().isoformat() + alert_id = f"al_{uuid.uuid4().hex[:16]}" + now = datetime.now().isoformat() - alert = Alert( - id = alert_id, - rule_id = rule.id, - tenant_id = rule.tenant_id, - severity = rule.severity, - status = AlertStatus.FIRING, - title = rule.annotations.get("summary", f"告警: {rule.name}"), - description = rule.annotations.get("description", rule.description), - metric = rule.metric, - value = metric.metric_value if metric else 0.0, - threshold = rule.threshold, - labels = rule.labels, - annotations = rule.annotations, - started_at = now, - resolved_at = None, - acknowledged_by = None, - acknowledged_at = None, - notification_sent = {}, - suppression_count = 0, + alert = Alert( + id=alert_id, + rule_id=rule.id, + tenant_id=rule.tenant_id, + severity=rule.severity, + status=AlertStatus.FIRING, + title=rule.annotations.get("summary", f"告警: {rule.name}"), + description=rule.annotations.get("description", rule.description), + metric=rule.metric, + value=metric.metric_value if metric else 0.0, + threshold=rule.threshold, + labels=rule.labels, + annotations=rule.annotations, + started_at=now, + resolved_at=None, + acknowledged_by=None, + acknowledged_at=None, + notification_sent={}, + suppression_count=0, ) # 保存告警 @@ -900,9 +900,9 @@ class OpsManager: async def _send_alert_notifications(self, alert: Alert, rule: AlertRule) -> None: """发送告警通知到所有配置的渠道""" - channels = [] + channels = [] for channel_id in rule.channels: - channel = self.get_alert_channel(channel_id) + channel = self.get_alert_channel(channel_id) if channel and channel.is_enabled: channels.append(channel) @@ -911,10 +911,10 @@ class OpsManager: if alert.severity.value not in channel.severity_filter: continue - success = await self._send_alert_to_channel(alert, channel) + success = await self._send_alert_to_channel(alert, channel) # 更新发送状态 - alert.notification_sent[channel.id] = success + alert.notification_sent[channel.id] = success self._update_alert_notification_status(alert.id, channel.id, success) async def _send_alert_to_channel(self, alert: Alert, channel: AlertChannel) -> bool: @@ -942,22 +942,22 @@ class OpsManager: async def _send_feishu_alert(self, alert: Alert, channel: AlertChannel) -> bool: """发送飞书告警""" - config = channel.config - webhook_url = config.get("webhook_url") + config = channel.config + webhook_url = config.get("webhook_url") config.get("secret", "") if not webhook_url: return False # 构建飞书消息 - severity_colors = { + severity_colors = { AlertSeverity.P0.value: "red", AlertSeverity.P1.value: "orange", AlertSeverity.P2.value: "yellow", AlertSeverity.P3.value: "blue", } - message = { + message = { "msg_type": "interactive", "card": { "config": {"wide_screen_mode": True}, @@ -990,27 +990,27 @@ class OpsManager: } async with httpx.AsyncClient() as client: - response = await client.post(webhook_url, json = message, timeout = 30.0) - success = response.status_code == 200 + response = await client.post(webhook_url, json=message, timeout=30.0) + success = response.status_code == 200 if success: - self._update_channel_stats(channel.id, success = True) + self._update_channel_stats(channel.id, success=True) else: - self._update_channel_stats(channel.id, success = False) + self._update_channel_stats(channel.id, success=False) return success async def _send_dingtalk_alert(self, alert: Alert, channel: AlertChannel) -> bool: """发送钉钉告警""" - config = channel.config - webhook_url = config.get("webhook_url") + config = channel.config + webhook_url = config.get("webhook_url") config.get("secret", "") if not webhook_url: return False # 构建钉钉消息 - message = { + message = { "msgtype": "markdown", "markdown": { "title": f"[{alert.severity.value.upper()}] {alert.title}", @@ -1024,29 +1024,29 @@ class OpsManager: } async with httpx.AsyncClient() as client: - response = await client.post(webhook_url, json = message, timeout = 30.0) - success = response.status_code == 200 + response = await client.post(webhook_url, json=message, timeout=30.0) + success = response.status_code == 200 self._update_channel_stats(channel.id, success) return success async def _send_slack_alert(self, alert: Alert, channel: AlertChannel) -> bool: """发送 Slack 告警""" - config = channel.config - webhook_url = config.get("webhook_url") + config = channel.config + webhook_url = config.get("webhook_url") if not webhook_url: return False - severity_emojis = { + severity_emojis = { AlertSeverity.P0.value: "🔴", AlertSeverity.P1.value: "🟠", AlertSeverity.P2.value: "🟡", AlertSeverity.P3.value: "🔵", } - emoji = severity_emojis.get(alert.severity.value, "⚪") + emoji = severity_emojis.get(alert.severity.value, "⚪") - message = { + message = { "text": f"{emoji} [{alert.severity.value.upper()}] {alert.title}", "blocks": [ { @@ -1073,20 +1073,20 @@ class OpsManager: } async with httpx.AsyncClient() as client: - response = await client.post(webhook_url, json = message, timeout = 30.0) - success = response.status_code == 200 + response = await client.post(webhook_url, json=message, timeout=30.0) + success = response.status_code == 200 self._update_channel_stats(channel.id, success) return success async def _send_email_alert(self, alert: Alert, channel: AlertChannel) -> bool: """发送邮件告警(模拟实现)""" # 实际实现需要集成邮件服务如 SendGrid、AWS SES 等 - config = channel.config - smtp_host = config.get("smtp_host") + config = channel.config + smtp_host = config.get("smtp_host") config.get("smtp_port", 587) - username = config.get("username") - password = config.get("password") - to_addresses = config.get("to_addresses", []) + username = config.get("username") + password = config.get("password") + to_addresses = config.get("to_addresses", []) if not all([smtp_host, username, password, to_addresses]): return False @@ -1097,20 +1097,20 @@ class OpsManager: async def _send_pagerduty_alert(self, alert: Alert, channel: AlertChannel) -> bool: """发送 PagerDuty 告警""" - config = channel.config - integration_key = config.get("integration_key") + config = channel.config + integration_key = config.get("integration_key") if not integration_key: return False - severity_map = { + severity_map = { AlertSeverity.P0.value: "critical", AlertSeverity.P1.value: "error", AlertSeverity.P2.value: "warning", AlertSeverity.P3.value: "info", } - message = { + message = { "routing_key": integration_key, "event_action": "trigger", "dedup_key": alert.id, @@ -1128,29 +1128,29 @@ class OpsManager: } async with httpx.AsyncClient() as client: - response = await client.post( - "https://events.pagerduty.com/v2/enqueue", json = message, timeout = 30.0 + response = await client.post( + "https://events.pagerduty.com/v2/enqueue", json=message, timeout=30.0 ) - success = response.status_code == 202 + success = response.status_code == 202 self._update_channel_stats(channel.id, success) return success async def _send_opsgenie_alert(self, alert: Alert, channel: AlertChannel) -> bool: """发送 Opsgenie 告警""" - config = channel.config - api_key = config.get("api_key") + config = channel.config + api_key = config.get("api_key") if not api_key: return False - priority_map = { + priority_map = { AlertSeverity.P0.value: "P1", AlertSeverity.P1.value: "P2", AlertSeverity.P2.value: "P3", AlertSeverity.P3.value: "P4", } - message = { + message = { "message": alert.title, "description": alert.description, "priority": priority_map.get(alert.severity.value, "P3"), @@ -1163,26 +1163,26 @@ class OpsManager: } async with httpx.AsyncClient() as client: - response = await client.post( + response = await client.post( "https://api.opsgenie.com/v2/alerts", - json = message, - headers = {"Authorization": f"GenieKey {api_key}"}, - timeout = 30.0, + json=message, + headers={"Authorization": f"GenieKey {api_key}"}, + timeout=30.0, ) - success = response.status_code in [200, 201, 202] + success = response.status_code in [200, 201, 202] self._update_channel_stats(channel.id, success) return success async def _send_webhook_alert(self, alert: Alert, channel: AlertChannel) -> bool: """发送 Webhook 告警""" - config = channel.config - webhook_url = config.get("webhook_url") - headers = config.get("headers", {}) + config = channel.config + webhook_url = config.get("webhook_url") + headers = config.get("headers", {}) if not webhook_url: return False - message = { + message = { "alert_id": alert.id, "severity": alert.severity.value, "status": alert.status.value, @@ -1196,8 +1196,8 @@ class OpsManager: } async with httpx.AsyncClient() as client: - response = await client.post(webhook_url, json = message, headers = headers, timeout = 30.0) - success = response.status_code in [200, 201, 202] + response = await client.post(webhook_url, json=message, headers=headers, timeout=30.0) + success = response.status_code in [200, 201, 202] self._update_channel_stats(channel.id, success) return success @@ -1206,7 +1206,7 @@ class OpsManager: def get_active_alert_by_rule(self, rule_id: str) -> Alert | None: """获取规则对应的活跃告警""" with self._get_db() as conn: - row = conn.execute( + row = conn.execute( """SELECT * FROM alerts WHERE rule_id = ? AND status = ? ORDER BY started_at DESC LIMIT 1""", @@ -1220,7 +1220,7 @@ class OpsManager: def get_alert(self, alert_id: str) -> Alert | None: """获取告警详情""" with self._get_db() as conn: - row = conn.execute("SELECT * FROM alerts WHERE id = ?", (alert_id, )).fetchone() + row = conn.execute("SELECT * FROM alerts WHERE id = ?", (alert_id, )).fetchone() if row: return self._row_to_alert(row) @@ -1229,13 +1229,13 @@ class OpsManager: def list_alerts( self, tenant_id: str, - status: AlertStatus | None = None, - severity: AlertSeverity | None = None, - limit: int = 100, + status: AlertStatus | None = None, + severity: AlertSeverity | None = None, + limit: int = 100, ) -> list[Alert]: """列出租户的告警""" - query = "SELECT * FROM alerts WHERE tenant_id = ?" - params = [tenant_id] + query = "SELECT * FROM alerts WHERE tenant_id = ?" + params = [tenant_id] if status: query += " AND status = ?" @@ -1248,12 +1248,12 @@ class OpsManager: params.append(limit) with self._get_db() as conn: - rows = conn.execute(query, params).fetchall() + rows = conn.execute(query, params).fetchall() return [self._row_to_alert(row) for row in rows] def acknowledge_alert(self, alert_id: str, user_id: str) -> Alert | None: """确认告警""" - now = datetime.now().isoformat() + now = datetime.now().isoformat() with self._get_db() as conn: conn.execute( @@ -1270,7 +1270,7 @@ class OpsManager: def resolve_alert(self, alert_id: str) -> Alert | None: """解决告警""" - now = datetime.now().isoformat() + now = datetime.now().isoformat() with self._get_db() as conn: conn.execute( @@ -1303,13 +1303,13 @@ class OpsManager: ) -> None: """更新告警通知状态""" with self._get_db() as conn: - row = conn.execute( + row = conn.execute( "SELECT notification_sent FROM alerts WHERE id = ?", (alert_id, ) ).fetchone() if row: - notification_sent = json.loads(row["notification_sent"]) - notification_sent[channel_id] = success + notification_sent = json.loads(row["notification_sent"]) + notification_sent[channel_id] = success conn.execute( "UPDATE alerts SET notification_sent = ? WHERE id = ?", @@ -1319,7 +1319,7 @@ class OpsManager: def _update_channel_stats(self, channel_id: str, success: bool) -> None: """更新渠道统计""" - now = datetime.now().isoformat() + now = datetime.now().isoformat() with self._get_db() as conn: if success: @@ -1350,22 +1350,22 @@ class OpsManager: name: str, matchers: dict[str, str], duration: int, - is_regex: bool = False, - expires_at: str | None = None, + is_regex: bool = False, + expires_at: str | None = None, ) -> AlertSuppressionRule: """创建告警抑制规则""" - rule_id = f"sr_{uuid.uuid4().hex[:16]}" - now = datetime.now().isoformat() + rule_id = f"sr_{uuid.uuid4().hex[:16]}" + now = datetime.now().isoformat() - rule = AlertSuppressionRule( - id = rule_id, - tenant_id = tenant_id, - name = name, - matchers = matchers, - duration = duration, - is_regex = is_regex, - created_at = now, - expires_at = expires_at, + rule = AlertSuppressionRule( + id=rule_id, + tenant_id=tenant_id, + name=name, + matchers=matchers, + duration=duration, + is_regex=is_regex, + created_at=now, + expires_at=expires_at, ) with self._get_db() as conn: @@ -1393,12 +1393,12 @@ class OpsManager: def _is_alert_suppressed(self, rule: AlertRule) -> bool: """检查告警是否被抑制""" with self._get_db() as conn: - rows = conn.execute( + rows = conn.execute( "SELECT * FROM alert_suppression_rules WHERE tenant_id = ?", (rule.tenant_id, ) ).fetchall() for row in rows: - suppression_rule = self._row_to_suppression_rule(row) + suppression_rule = self._row_to_suppression_rule(row) # 检查是否过期 if suppression_rule.expires_at: @@ -1406,19 +1406,19 @@ class OpsManager: continue # 检查匹配 - matchers = suppression_rule.matchers - match = True + matchers = suppression_rule.matchers + match = True for key, pattern in matchers.items(): - value = rule.labels.get(key, "") + value = rule.labels.get(key, "") if suppression_rule.is_regex: if not re.match(pattern, value): - match = False + match = False break else: if value != pattern: - match = False + match = False break if match: @@ -1436,22 +1436,22 @@ class OpsManager: metric_name: str, metric_value: float, unit: str, - metadata: dict = None, + metadata: dict = None, ) -> ResourceMetric: """记录资源指标""" - metric_id = f"rm_{uuid.uuid4().hex[:16]}" - now = datetime.now().isoformat() + metric_id = f"rm_{uuid.uuid4().hex[:16]}" + now = datetime.now().isoformat() - metric = ResourceMetric( - id = metric_id, - tenant_id = tenant_id, - resource_type = resource_type, - resource_id = resource_id, - metric_name = metric_name, - metric_value = metric_value, - unit = unit, - timestamp = now, - metadata = metadata or {}, + metric = ResourceMetric( + id=metric_id, + tenant_id=tenant_id, + resource_type=resource_type, + resource_id=resource_id, + metric_name=metric_name, + metric_value=metric_value, + unit=unit, + timestamp=now, + metadata=metadata or {}, ) with self._get_db() as conn: @@ -1479,13 +1479,13 @@ class OpsManager: return metric def get_recent_metrics( - self, tenant_id: str, metric_name: str, seconds: int = 3600 + self, tenant_id: str, metric_name: str, seconds: int = 3600 ) -> list[ResourceMetric]: """获取最近的指标数据""" - cutoff_time = (datetime.now() - timedelta(seconds = seconds)).isoformat() + cutoff_time = (datetime.now() - timedelta(seconds=seconds)).isoformat() with self._get_db() as conn: - rows = conn.execute( + rows = conn.execute( """SELECT * FROM resource_metrics WHERE tenant_id = ? AND metric_name = ? AND timestamp > ? ORDER BY timestamp DESC""", @@ -1505,7 +1505,7 @@ class OpsManager: ) -> list[ResourceMetric]: """获取指定资源的指标数据""" with self._get_db() as conn: - rows = conn.execute( + rows = conn.execute( """SELECT * FROM resource_metrics WHERE tenant_id = ? AND resource_type = ? AND resource_id = ? AND metric_name = ? AND timestamp BETWEEN ? AND ? @@ -1523,51 +1523,51 @@ class OpsManager: resource_type: ResourceType, current_capacity: float, prediction_date: str, - confidence: float = 0.8, + confidence: float = 0.8, ) -> CapacityPlan: """创建容量规划""" - plan_id = f"cp_{uuid.uuid4().hex[:16]}" - now = datetime.now().isoformat() + plan_id = f"cp_{uuid.uuid4().hex[:16]}" + now = datetime.now().isoformat() # 基于历史数据预测 - metrics = self.get_recent_metrics( - tenant_id, f"{resource_type.value}_usage", seconds = 30 * 24 * 3600 + metrics = self.get_recent_metrics( + tenant_id, f"{resource_type.value}_usage", seconds=30 * 24 * 3600 ) if metrics: - values = [m.metric_value for m in metrics] - trend = self._calculate_trend(values) + values = [m.metric_value for m in metrics] + trend = self._calculate_trend(values) # 预测未来容量需求 - days_ahead = (datetime.fromisoformat(prediction_date) - datetime.now()).days - predicted_capacity = current_capacity * (1 + trend * days_ahead / 30) + days_ahead = (datetime.fromisoformat(prediction_date) - datetime.now()).days + predicted_capacity = current_capacity * (1 + trend * days_ahead / 30) # 推荐操作 if predicted_capacity > current_capacity * 1.2: - recommended_action = "scale_up" - estimated_cost = (predicted_capacity - current_capacity) * 10 # 简化计算 + recommended_action = "scale_up" + estimated_cost = (predicted_capacity - current_capacity) * 10 # 简化计算 elif predicted_capacity < current_capacity * 0.5: - recommended_action = "scale_down" - estimated_cost = 0 + recommended_action = "scale_down" + estimated_cost = 0 else: - recommended_action = "maintain" - estimated_cost = 0 + recommended_action = "maintain" + estimated_cost = 0 else: - predicted_capacity = current_capacity - recommended_action = "insufficient_data" - estimated_cost = 0 + predicted_capacity = current_capacity + recommended_action = "insufficient_data" + estimated_cost = 0 - plan = CapacityPlan( - id = plan_id, - tenant_id = tenant_id, - resource_type = resource_type, - current_capacity = current_capacity, - predicted_capacity = predicted_capacity, - prediction_date = prediction_date, - confidence = confidence, - recommended_action = recommended_action, - estimated_cost = estimated_cost, - created_at = now, + plan = CapacityPlan( + id=plan_id, + tenant_id=tenant_id, + resource_type=resource_type, + current_capacity=current_capacity, + predicted_capacity=predicted_capacity, + prediction_date=prediction_date, + confidence=confidence, + recommended_action=recommended_action, + estimated_cost=estimated_cost, + created_at=now, ) with self._get_db() as conn: @@ -1601,21 +1601,21 @@ class OpsManager: return 0.0 # 使用最近的数据计算趋势 - recent = values[-10:] if len(values) > 10 else values - n = len(recent) + recent = values[-10:] if len(values) > 10 else values + n = len(recent) if n < 2: return 0.0 # 简单线性回归计算斜率 - x = list(range(n)) - mean_x = sum(x) / n - mean_y = sum(recent) / n + x = list(range(n)) + mean_x = sum(x) / n + mean_y = sum(recent) / n - numerator = sum((x[i] - mean_x) * (recent[i] - mean_y) for i in range(n)) - denominator = sum((x[i] - mean_x) ** 2 for i in range(n)) + numerator = sum((x[i] - mean_x) * (recent[i] - mean_y) for i in range(n)) + denominator = sum((x[i] - mean_x) ** 2 for i in range(n)) - slope = numerator / denominator if denominator != 0 else 0 + slope = numerator / denominator if denominator != 0 else 0 # 归一化为增长率 if mean_y != 0: @@ -1625,7 +1625,7 @@ class OpsManager: def get_capacity_plans(self, tenant_id: str) -> list[CapacityPlan]: """获取容量规划列表""" with self._get_db() as conn: - rows = conn.execute( + rows = conn.execute( "SELECT * FROM capacity_plans WHERE tenant_id = ? ORDER BY created_at DESC", (tenant_id, ), ).fetchall() @@ -1643,30 +1643,30 @@ class OpsManager: target_utilization: float, scale_up_threshold: float, scale_down_threshold: float, - scale_up_step: int = 1, - scale_down_step: int = 1, - cooldown_period: int = 300, + scale_up_step: int = 1, + scale_down_step: int = 1, + cooldown_period: int = 300, ) -> AutoScalingPolicy: """创建自动扩缩容策略""" - policy_id = f"asp_{uuid.uuid4().hex[:16]}" - now = datetime.now().isoformat() + policy_id = f"asp_{uuid.uuid4().hex[:16]}" + now = datetime.now().isoformat() - policy = AutoScalingPolicy( - id = policy_id, - tenant_id = tenant_id, - name = name, - resource_type = resource_type, - min_instances = min_instances, - max_instances = max_instances, - target_utilization = target_utilization, - scale_up_threshold = scale_up_threshold, - scale_down_threshold = scale_down_threshold, - scale_up_step = scale_up_step, - scale_down_step = scale_down_step, - cooldown_period = cooldown_period, - is_enabled = True, - created_at = now, - updated_at = now, + policy = AutoScalingPolicy( + id=policy_id, + tenant_id=tenant_id, + name=name, + resource_type=resource_type, + min_instances=min_instances, + max_instances=max_instances, + target_utilization=target_utilization, + scale_up_threshold=scale_up_threshold, + scale_down_threshold=scale_down_threshold, + scale_up_step=scale_up_step, + scale_down_step=scale_down_step, + cooldown_period=cooldown_period, + is_enabled=True, + created_at=now, + updated_at=now, ) with self._get_db() as conn: @@ -1703,7 +1703,7 @@ class OpsManager: def get_auto_scaling_policy(self, policy_id: str) -> AutoScalingPolicy | None: """获取自动扩缩容策略""" with self._get_db() as conn: - row = conn.execute( + row = conn.execute( "SELECT * FROM auto_scaling_policies WHERE id = ?", (policy_id, ) ).fetchone() @@ -1714,7 +1714,7 @@ class OpsManager: def list_auto_scaling_policies(self, tenant_id: str) -> list[AutoScalingPolicy]: """列出租户的自动扩缩容策略""" with self._get_db() as conn: - rows = conn.execute( + rows = conn.execute( "SELECT * FROM auto_scaling_policies WHERE tenant_id = ? ORDER BY created_at DESC", (tenant_id, ), ).fetchall() @@ -1724,36 +1724,36 @@ class OpsManager: self, policy_id: str, current_instances: int, current_utilization: float ) -> ScalingEvent | None: """评估扩缩容策略""" - policy = self.get_auto_scaling_policy(policy_id) + policy = self.get_auto_scaling_policy(policy_id) if not policy or not policy.is_enabled: return None # 检查是否在冷却期 - last_event = self.get_last_scaling_event(policy_id) + last_event = self.get_last_scaling_event(policy_id) if last_event: - last_time = datetime.fromisoformat(last_event.started_at) + last_time = datetime.fromisoformat(last_event.started_at) if (datetime.now() - last_time).total_seconds() < policy.cooldown_period: return None - action = None - reason = "" + action = None + reason = "" if current_utilization > policy.scale_up_threshold: if current_instances < policy.max_instances: - action = ScalingAction.SCALE_UP - reason = ( + action = ScalingAction.SCALE_UP + reason = ( f"利用率 {current_utilization:.1%} 超过扩容阈值 {policy.scale_up_threshold:.1%}" ) elif current_utilization < policy.scale_down_threshold: if current_instances > policy.min_instances: - action = ScalingAction.SCALE_DOWN - reason = f"利用率 {current_utilization:.1%} 低于缩容阈值 {policy.scale_down_threshold:.1%}" + action = ScalingAction.SCALE_DOWN + reason = f"利用率 {current_utilization:.1%} 低于缩容阈值 {policy.scale_down_threshold:.1%}" if action: if action == ScalingAction.SCALE_UP: - new_count = min(current_instances + policy.scale_up_step, policy.max_instances) + new_count = min(current_instances + policy.scale_up_step, policy.max_instances) else: - new_count = max(current_instances - policy.scale_down_step, policy.min_instances) + new_count = max(current_instances - policy.scale_down_step, policy.min_instances) return self._create_scaling_event(policy, action, current_instances, new_count, reason) @@ -1768,22 +1768,22 @@ class OpsManager: reason: str, ) -> ScalingEvent: """创建扩缩容事件""" - event_id = f"se_{uuid.uuid4().hex[:16]}" - now = datetime.now().isoformat() + event_id = f"se_{uuid.uuid4().hex[:16]}" + now = datetime.now().isoformat() - event = ScalingEvent( - id = event_id, - policy_id = policy.id, - tenant_id = policy.tenant_id, - action = action, - from_count = from_count, - to_count = to_count, - reason = reason, - triggered_by = "auto", - status = "pending", - started_at = now, - completed_at = None, - error_message = None, + event = ScalingEvent( + id=event_id, + policy_id=policy.id, + tenant_id=policy.tenant_id, + action=action, + from_count=from_count, + to_count=to_count, + reason=reason, + triggered_by="auto", + status="pending", + started_at=now, + completed_at=None, + error_message=None, ) with self._get_db() as conn: @@ -1814,7 +1814,7 @@ class OpsManager: def get_last_scaling_event(self, policy_id: str) -> ScalingEvent | None: """获取最近的扩缩容事件""" with self._get_db() as conn: - row = conn.execute( + row = conn.execute( """SELECT * FROM scaling_events WHERE policy_id = ? ORDER BY started_at DESC LIMIT 1""", @@ -1826,10 +1826,10 @@ class OpsManager: return None def update_scaling_event_status( - self, event_id: str, status: str, error_message: str = None + self, event_id: str, status: str, error_message: str = None ) -> ScalingEvent | None: """更新扩缩容事件状态""" - now = datetime.now().isoformat() + now = datetime.now().isoformat() with self._get_db() as conn: if status in ["completed", "failed"]: @@ -1857,18 +1857,18 @@ class OpsManager: def get_scaling_event(self, event_id: str) -> ScalingEvent | None: """获取扩缩容事件""" with self._get_db() as conn: - row = conn.execute("SELECT * FROM scaling_events WHERE id = ?", (event_id, )).fetchone() + row = conn.execute("SELECT * FROM scaling_events WHERE id = ?", (event_id, )).fetchone() if row: return self._row_to_scaling_event(row) return None def list_scaling_events( - self, tenant_id: str, policy_id: str = None, limit: int = 100 + self, tenant_id: str, policy_id: str = None, limit: int = 100 ) -> list[ScalingEvent]: """列出租户的扩缩容事件""" - query = "SELECT * FROM scaling_events WHERE tenant_id = ?" - params = [tenant_id] + query = "SELECT * FROM scaling_events WHERE tenant_id = ?" + params = [tenant_id] if policy_id: query += " AND policy_id = ?" @@ -1878,7 +1878,7 @@ class OpsManager: params.append(limit) with self._get_db() as conn: - rows = conn.execute(query, params).fetchall() + rows = conn.execute(query, params).fetchall() return [self._row_to_scaling_event(row) for row in rows] # ==================== 健康检查与故障转移 ==================== @@ -1891,30 +1891,30 @@ class OpsManager: target_id: str, check_type: str, check_config: dict, - interval: int = 60, - timeout: int = 10, - retry_count: int = 3, + interval: int = 60, + timeout: int = 10, + retry_count: int = 3, ) -> HealthCheck: """创建健康检查""" - check_id = f"hc_{uuid.uuid4().hex[:16]}" - now = datetime.now().isoformat() + check_id = f"hc_{uuid.uuid4().hex[:16]}" + now = datetime.now().isoformat() - check = HealthCheck( - id = check_id, - tenant_id = tenant_id, - name = name, - target_type = target_type, - target_id = target_id, - check_type = check_type, - check_config = check_config, - interval = interval, - timeout = timeout, - retry_count = retry_count, - healthy_threshold = 2, - unhealthy_threshold = 3, - is_enabled = True, - created_at = now, - updated_at = now, + check = HealthCheck( + id=check_id, + tenant_id=tenant_id, + name=name, + target_type=target_type, + target_id=target_id, + check_type=check_type, + check_config=check_config, + interval=interval, + timeout=timeout, + retry_count=retry_count, + healthy_threshold=2, + unhealthy_threshold=3, + is_enabled=True, + created_at=now, + updated_at=now, ) with self._get_db() as conn: @@ -1951,7 +1951,7 @@ class OpsManager: def get_health_check(self, check_id: str) -> HealthCheck | None: """获取健康检查配置""" with self._get_db() as conn: - row = conn.execute("SELECT * FROM health_checks WHERE id = ?", (check_id, )).fetchone() + row = conn.execute("SELECT * FROM health_checks WHERE id = ?", (check_id, )).fetchone() if row: return self._row_to_health_check(row) @@ -1960,7 +1960,7 @@ class OpsManager: def list_health_checks(self, tenant_id: str) -> list[HealthCheck]: """列出租户的健康检查""" with self._get_db() as conn: - rows = conn.execute( + rows = conn.execute( "SELECT * FROM health_checks WHERE tenant_id = ? ORDER BY created_at DESC", (tenant_id, ), ).fetchall() @@ -1968,32 +1968,32 @@ class OpsManager: async def execute_health_check(self, check_id: str) -> HealthCheckResult: """执行健康检查""" - check = self.get_health_check(check_id) + check = self.get_health_check(check_id) if not check: raise ValueError(f"Health check {check_id} not found") - result_id = f"hcr_{uuid.uuid4().hex[:16]}" - now = datetime.now().isoformat() + result_id = f"hcr_{uuid.uuid4().hex[:16]}" + now = datetime.now().isoformat() # 模拟健康检查(实际实现需要根据 check_type 执行具体检查) if check.check_type == "http": - status, response_time, message = await self._check_http_health(check) + status, response_time, message = await self._check_http_health(check) elif check.check_type == "tcp": - status, response_time, message = await self._check_tcp_health(check) + status, response_time, message = await self._check_tcp_health(check) elif check.check_type == "ping": - status, response_time, message = await self._check_ping_health(check) + status, response_time, message = await self._check_ping_health(check) else: - status, response_time, message = HealthStatus.UNKNOWN, 0, "Unknown check type" + status, response_time, message = HealthStatus.UNKNOWN, 0, "Unknown check type" - result = HealthCheckResult( - id = result_id, - check_id = check_id, - tenant_id = check.tenant_id, - status = status, - response_time = response_time, - message = message, - details = {}, - checked_at = now, + result = HealthCheckResult( + id=result_id, + check_id=check_id, + tenant_id=check.tenant_id, + status=status, + response_time=response_time, + message=message, + details={}, + checked_at=now, ) with self._get_db() as conn: @@ -2020,18 +2020,18 @@ class OpsManager: async def _check_http_health(self, check: HealthCheck) -> tuple[HealthStatus, float, str]: """HTTP 健康检查""" - config = check.check_config - url = config.get("url") - expected_status = config.get("expected_status", 200) + config = check.check_config + url = config.get("url") + expected_status = config.get("expected_status", 200) if not url: return HealthStatus.UNHEALTHY, 0, "URL not configured" - start_time = time.time() + start_time = time.time() try: async with httpx.AsyncClient() as client: - response = await client.get(url, timeout = check.timeout) - response_time = (time.time() - start_time) * 1000 + response = await client.get(url, timeout=check.timeout) + response_time = (time.time() - start_time) * 1000 if response.status_code == expected_status: return HealthStatus.HEALTHY, response_time, "OK" @@ -2046,19 +2046,19 @@ class OpsManager: async def _check_tcp_health(self, check: HealthCheck) -> tuple[HealthStatus, float, str]: """TCP 健康检查""" - config = check.check_config - host = config.get("host") - port = config.get("port") + config = check.check_config + host = config.get("host") + port = config.get("port") if not host or not port: return HealthStatus.UNHEALTHY, 0, "Host or port not configured" - start_time = time.time() + start_time = time.time() try: - reader, writer = await asyncio.wait_for( - asyncio.open_connection(host, port), timeout = check.timeout + reader, writer = await asyncio.wait_for( + asyncio.open_connection(host, port), timeout=check.timeout ) - response_time = (time.time() - start_time) * 1000 + response_time = (time.time() - start_time) * 1000 writer.close() await writer.wait_closed() return HealthStatus.HEALTHY, response_time, "TCP connection successful" @@ -2069,8 +2069,8 @@ class OpsManager: async def _check_ping_health(self, check: HealthCheck) -> tuple[HealthStatus, float, str]: """Ping 健康检查(模拟)""" - config = check.check_config - host = config.get("host") + config = check.check_config + host = config.get("host") if not host: return HealthStatus.UNHEALTHY, 0, "Host not configured" @@ -2079,10 +2079,10 @@ class OpsManager: # 这里模拟成功 return HealthStatus.HEALTHY, 10.0, "Ping successful" - def get_health_check_results(self, check_id: str, limit: int = 100) -> list[HealthCheckResult]: + def get_health_check_results(self, check_id: str, limit: int = 100) -> list[HealthCheckResult]: """获取健康检查历史结果""" with self._get_db() as conn: - rows = conn.execute( + rows = conn.execute( """SELECT * FROM health_check_results WHERE check_id = ? ORDER BY checked_at DESC LIMIT ?""", @@ -2099,27 +2099,27 @@ class OpsManager: primary_region: str, secondary_regions: list[str], failover_trigger: str, - auto_failover: bool = False, - failover_timeout: int = 300, - health_check_id: str = None, + auto_failover: bool = False, + failover_timeout: int = 300, + health_check_id: str = None, ) -> FailoverConfig: """创建故障转移配置""" - config_id = f"fc_{uuid.uuid4().hex[:16]}" - now = datetime.now().isoformat() + config_id = f"fc_{uuid.uuid4().hex[:16]}" + now = datetime.now().isoformat() - config = FailoverConfig( - id = config_id, - tenant_id = tenant_id, - name = name, - primary_region = primary_region, - secondary_regions = secondary_regions, - failover_trigger = failover_trigger, - auto_failover = auto_failover, - failover_timeout = failover_timeout, - health_check_id = health_check_id, - is_enabled = True, - created_at = now, - updated_at = now, + config = FailoverConfig( + id=config_id, + tenant_id=tenant_id, + name=name, + primary_region=primary_region, + secondary_regions=secondary_regions, + failover_trigger=failover_trigger, + auto_failover=auto_failover, + failover_timeout=failover_timeout, + health_check_id=health_check_id, + is_enabled=True, + created_at=now, + updated_at=now, ) with self._get_db() as conn: @@ -2152,7 +2152,7 @@ class OpsManager: def get_failover_config(self, config_id: str) -> FailoverConfig | None: """获取故障转移配置""" with self._get_db() as conn: - row = conn.execute( + row = conn.execute( "SELECT * FROM failover_configs WHERE id = ?", (config_id, ) ).fetchone() @@ -2163,7 +2163,7 @@ class OpsManager: def list_failover_configs(self, tenant_id: str) -> list[FailoverConfig]: """列出租户的故障转移配置""" with self._get_db() as conn: - rows = conn.execute( + rows = conn.execute( "SELECT * FROM failover_configs WHERE tenant_id = ? ORDER BY created_at DESC", (tenant_id, ), ).fetchall() @@ -2171,30 +2171,30 @@ class OpsManager: def initiate_failover(self, config_id: str, reason: str) -> FailoverEvent | None: """发起故障转移""" - config = self.get_failover_config(config_id) + config = self.get_failover_config(config_id) if not config or not config.is_enabled: return None - event_id = f"fe_{uuid.uuid4().hex[:16]}" - now = datetime.now().isoformat() + event_id = f"fe_{uuid.uuid4().hex[:16]}" + now = datetime.now().isoformat() # 选择备用区域 - to_region = config.secondary_regions[0] if config.secondary_regions else None + to_region = config.secondary_regions[0] if config.secondary_regions else None if not to_region: return None - event = FailoverEvent( - id = event_id, - config_id = config_id, - tenant_id = config.tenant_id, - from_region = config.primary_region, - to_region = to_region, - reason = reason, - status = "initiated", - started_at = now, - completed_at = None, - rolled_back_at = None, + event = FailoverEvent( + id=event_id, + config_id=config_id, + tenant_id=config.tenant_id, + from_region=config.primary_region, + to_region=to_region, + reason=reason, + status="initiated", + started_at=now, + completed_at=None, + rolled_back_at=None, ) with self._get_db() as conn: @@ -2221,7 +2221,7 @@ class OpsManager: def update_failover_status(self, event_id: str, status: str) -> FailoverEvent | None: """更新故障转移状态""" - now = datetime.now().isoformat() + now = datetime.now().isoformat() with self._get_db() as conn: if status == "completed": @@ -2258,16 +2258,16 @@ class OpsManager: def get_failover_event(self, event_id: str) -> FailoverEvent | None: """获取故障转移事件""" with self._get_db() as conn: - row = conn.execute("SELECT * FROM failover_events WHERE id = ?", (event_id, )).fetchone() + row = conn.execute("SELECT * FROM failover_events WHERE id = ?", (event_id, )).fetchone() if row: return self._row_to_failover_event(row) return None - def list_failover_events(self, tenant_id: str, limit: int = 100) -> list[FailoverEvent]: + def list_failover_events(self, tenant_id: str, limit: int = 100) -> list[FailoverEvent]: """列出租户的故障转移事件""" with self._get_db() as conn: - rows = conn.execute( + rows = conn.execute( """SELECT * FROM failover_events WHERE tenant_id = ? ORDER BY started_at DESC LIMIT ?""", @@ -2285,30 +2285,30 @@ class OpsManager: target_type: str, target_id: str, schedule: str, - retention_days: int = 30, - encryption_enabled: bool = True, - compression_enabled: bool = True, - storage_location: str = None, + retention_days: int = 30, + encryption_enabled: bool = True, + compression_enabled: bool = True, + storage_location: str = None, ) -> BackupJob: """创建备份任务""" - job_id = f"bj_{uuid.uuid4().hex[:16]}" - now = datetime.now().isoformat() + job_id = f"bj_{uuid.uuid4().hex[:16]}" + now = datetime.now().isoformat() - job = BackupJob( - id = job_id, - tenant_id = tenant_id, - name = name, - backup_type = backup_type, - target_type = target_type, - target_id = target_id, - schedule = schedule, - retention_days = retention_days, - encryption_enabled = encryption_enabled, - compression_enabled = compression_enabled, - storage_location = storage_location or f"backups/{tenant_id}", - is_enabled = True, - created_at = now, - updated_at = now, + job = BackupJob( + id=job_id, + tenant_id=tenant_id, + name=name, + backup_type=backup_type, + target_type=target_type, + target_id=target_id, + schedule=schedule, + retention_days=retention_days, + encryption_enabled=encryption_enabled, + compression_enabled=compression_enabled, + storage_location=storage_location or f"backups/{tenant_id}", + is_enabled=True, + created_at=now, + updated_at=now, ) with self._get_db() as conn: @@ -2344,7 +2344,7 @@ class OpsManager: def get_backup_job(self, job_id: str) -> BackupJob | None: """获取备份任务""" with self._get_db() as conn: - row = conn.execute("SELECT * FROM backup_jobs WHERE id = ?", (job_id, )).fetchone() + row = conn.execute("SELECT * FROM backup_jobs WHERE id = ?", (job_id, )).fetchone() if row: return self._row_to_backup_job(row) @@ -2353,7 +2353,7 @@ class OpsManager: def list_backup_jobs(self, tenant_id: str) -> list[BackupJob]: """列出租户的备份任务""" with self._get_db() as conn: - rows = conn.execute( + rows = conn.execute( "SELECT * FROM backup_jobs WHERE tenant_id = ? ORDER BY created_at DESC", (tenant_id, ), ).fetchall() @@ -2361,25 +2361,25 @@ class OpsManager: def execute_backup(self, job_id: str) -> BackupRecord | None: """执行备份""" - job = self.get_backup_job(job_id) + job = self.get_backup_job(job_id) if not job or not job.is_enabled: return None - record_id = f"br_{uuid.uuid4().hex[:16]}" - now = datetime.now().isoformat() + record_id = f"br_{uuid.uuid4().hex[:16]}" + now = datetime.now().isoformat() - record = BackupRecord( - id = record_id, - job_id = job_id, - tenant_id = job.tenant_id, - status = BackupStatus.IN_PROGRESS, - size_bytes = 0, - checksum = "", - started_at = now, - completed_at = None, - verified_at = None, - error_message = None, - storage_path = f"{job.storage_location}/{record_id}", + record = BackupRecord( + id=record_id, + job_id=job_id, + tenant_id=job.tenant_id, + status=BackupStatus.IN_PROGRESS, + size_bytes=0, + checksum="", + started_at=now, + completed_at=None, + verified_at=None, + error_message=None, + storage_path=f"{job.storage_location}/{record_id}", ) with self._get_db() as conn: @@ -2404,14 +2404,14 @@ class OpsManager: # 异步执行备份(实际实现中应该启动后台任务) # 这里模拟备份完成 - self._complete_backup(record_id, size_bytes = 1024 * 1024 * 100) # 模拟100MB + self._complete_backup(record_id, size_bytes=1024 * 1024 * 100) # 模拟100MB return record - def _complete_backup(self, record_id: str, size_bytes: int, checksum: str = None) -> None: + def _complete_backup(self, record_id: str, size_bytes: int, checksum: str = None) -> None: """完成备份""" - now = datetime.now().isoformat() - checksum = checksum or hashlib.sha256(str(time.time()).encode()).hexdigest()[:16] + now = datetime.now().isoformat() + checksum = checksum or hashlib.sha256(str(time.time()).encode()).hexdigest()[:16] with self._get_db() as conn: conn.execute( @@ -2427,18 +2427,18 @@ class OpsManager: def get_backup_record(self, record_id: str) -> BackupRecord | None: """获取备份记录""" with self._get_db() as conn: - row = conn.execute("SELECT * FROM backup_records WHERE id = ?", (record_id, )).fetchone() + row = conn.execute("SELECT * FROM backup_records WHERE id = ?", (record_id, )).fetchone() if row: return self._row_to_backup_record(row) return None def list_backup_records( - self, tenant_id: str, job_id: str = None, limit: int = 100 + self, tenant_id: str, job_id: str = None, limit: int = 100 ) -> list[BackupRecord]: """列出租户的备份记录""" - query = "SELECT * FROM backup_records WHERE tenant_id = ?" - params = [tenant_id] + query = "SELECT * FROM backup_records WHERE tenant_id = ?" + params = [tenant_id] if job_id: query += " AND job_id = ?" @@ -2448,12 +2448,12 @@ class OpsManager: params.append(limit) with self._get_db() as conn: - rows = conn.execute(query, params).fetchall() + rows = conn.execute(query, params).fetchall() return [self._row_to_backup_record(row) for row in rows] def restore_from_backup(self, record_id: str) -> bool: """从备份恢复""" - record = self.get_backup_record(record_id) + record = self.get_backup_record(record_id) if not record or record.status != BackupStatus.COMPLETED: return False @@ -2465,42 +2465,42 @@ class OpsManager: def generate_cost_report(self, tenant_id: str, year: int, month: int) -> CostReport: """生成成本报告""" - report_id = f"cr_{uuid.uuid4().hex[:16]}" - report_period = f"{year:04d}-{month:02d}" - now = datetime.now().isoformat() + report_id = f"cr_{uuid.uuid4().hex[:16]}" + report_period = f"{year:04d}-{month:02d}" + now = datetime.now().isoformat() # 获取资源利用率数据 - utilizations = self.get_resource_utilizations(tenant_id, report_period) + utilizations = self.get_resource_utilizations(tenant_id, report_period) # 计算成本分解 - breakdown = {} - total_cost = 0.0 + breakdown = {} + total_cost = 0.0 for util in utilizations: # 简化计算:假设每单位资源每月成本 - unit_cost = 10.0 - resource_cost = unit_cost * util.utilization_rate - breakdown[util.resource_type.value] = ( + unit_cost = 10.0 + resource_cost = unit_cost * util.utilization_rate + breakdown[util.resource_type.value] = ( breakdown.get(util.resource_type.value, 0) + resource_cost ) total_cost += resource_cost # 检测异常 - anomalies = self._detect_cost_anomalies(utilizations) + anomalies = self._detect_cost_anomalies(utilizations) # 计算趋势 - trends = self._calculate_cost_trends(tenant_id, year, month) + trends = self._calculate_cost_trends(tenant_id, year, month) - report = CostReport( - id = report_id, - tenant_id = tenant_id, - report_period = report_period, - total_cost = total_cost, - currency = "CNY", - breakdown = breakdown, - trends = trends, - anomalies = anomalies, - created_at = now, + report = CostReport( + id=report_id, + tenant_id=tenant_id, + report_period=report_period, + total_cost=total_cost, + currency="CNY", + breakdown=breakdown, + trends=trends, + anomalies=anomalies, + created_at=now, ) with self._get_db() as conn: @@ -2528,7 +2528,7 @@ class OpsManager: def _detect_cost_anomalies(self, utilizations: list[ResourceUtilization]) -> list[dict]: """检测成本异常""" - anomalies = [] + anomalies = [] for util in utilizations: # 检测低利用率 @@ -2576,22 +2576,22 @@ class OpsManager: avg_utilization: float, idle_time_percent: float, report_date: str, - recommendations: list[str] = None, + recommendations: list[str] = None, ) -> ResourceUtilization: """记录资源利用率""" - util_id = f"ru_{uuid.uuid4().hex[:16]}" + util_id = f"ru_{uuid.uuid4().hex[:16]}" - util = ResourceUtilization( - id = util_id, - tenant_id = tenant_id, - resource_type = resource_type, - resource_id = resource_id, - utilization_rate = utilization_rate, - peak_utilization = peak_utilization, - avg_utilization = avg_utilization, - idle_time_percent = idle_time_percent, - report_date = report_date, - recommendations = recommendations or [], + util = ResourceUtilization( + id=util_id, + tenant_id=tenant_id, + resource_type=resource_type, + resource_id=resource_id, + utilization_rate=utilization_rate, + peak_utilization=peak_utilization, + avg_utilization=avg_utilization, + idle_time_percent=idle_time_percent, + report_date=report_date, + recommendations=recommendations or [], ) with self._get_db() as conn: @@ -2624,7 +2624,7 @@ class OpsManager: ) -> list[ResourceUtilization]: """获取资源利用率列表""" with self._get_db() as conn: - rows = conn.execute( + rows = conn.execute( """SELECT * FROM resource_utilizations WHERE tenant_id = ? AND report_date LIKE ? ORDER BY report_date DESC""", @@ -2634,12 +2634,12 @@ class OpsManager: def detect_idle_resources(self, tenant_id: str) -> list[IdleResource]: """检测闲置资源""" - idle_resources = [] + idle_resources = [] # 获取最近30天的利用率数据 with self._get_db() as conn: - thirty_days_ago = (datetime.now() - timedelta(days = 30)).isoformat() - rows = conn.execute( + thirty_days_ago = (datetime.now() - timedelta(days=30)).isoformat() + rows = conn.execute( """SELECT resource_type, resource_id, AVG(utilization_rate) as avg_utilization, MAX(idle_time_percent) as max_idle_time FROM resource_utilizations @@ -2650,21 +2650,21 @@ class OpsManager: ).fetchall() for row in rows: - idle_id = f"ir_{uuid.uuid4().hex[:16]}" - now = datetime.now().isoformat() + idle_id = f"ir_{uuid.uuid4().hex[:16]}" + now = datetime.now().isoformat() - idle_resource = IdleResource( - id = idle_id, - tenant_id = tenant_id, - resource_type = ResourceType(row["resource_type"]), - resource_id = row["resource_id"], - resource_name = f"{row['resource_type']}-{row['resource_id']}", - idle_since = thirty_days_ago, - estimated_monthly_cost = 50.0, # 简化计算 - currency = "CNY", - reason = "Low utilization rate over 30 days", - recommendation = "Consider downsizing or terminating this resource", - detected_at = now, + idle_resource = IdleResource( + id=idle_id, + tenant_id=tenant_id, + resource_type=ResourceType(row["resource_type"]), + resource_id=row["resource_id"], + resource_name=f"{row['resource_type']}-{row['resource_id']}", + idle_since=thirty_days_ago, + estimated_monthly_cost=50.0, # 简化计算 + currency="CNY", + reason="Low utilization rate over 30 days", + recommendation="Consider downsizing or terminating this resource", + detected_at=now, ) conn.execute( @@ -2698,7 +2698,7 @@ class OpsManager: def get_idle_resources(self, tenant_id: str) -> list[IdleResource]: """获取闲置资源列表""" with self._get_db() as conn: - rows = conn.execute( + rows = conn.execute( "SELECT * FROM idle_resources WHERE tenant_id = ? ORDER BY detected_at DESC", (tenant_id, ), ).fetchall() @@ -2708,36 +2708,36 @@ class OpsManager: self, tenant_id: str ) -> list[CostOptimizationSuggestion]: """生成成本优化建议""" - suggestions = [] + suggestions = [] # 基于闲置资源生成建议 - idle_resources = self.detect_idle_resources(tenant_id) + idle_resources = self.detect_idle_resources(tenant_id) - total_potential_savings = sum(r.estimated_monthly_cost for r in idle_resources) + total_potential_savings = sum(r.estimated_monthly_cost for r in idle_resources) if total_potential_savings > 0: - suggestion_id = f"cos_{uuid.uuid4().hex[:16]}" - now = datetime.now().isoformat() + suggestion_id = f"cos_{uuid.uuid4().hex[:16]}" + now = datetime.now().isoformat() - suggestion = CostOptimizationSuggestion( - id = suggestion_id, - tenant_id = tenant_id, - category = "resource_rightsize", - title = "清理闲置资源", - description = f"检测到 {len(idle_resources)} 个闲置资源,建议清理以节省成本。", - potential_savings = total_potential_savings, - currency = "CNY", - confidence = 0.85, - difficulty = "easy", - implementation_steps = [ + suggestion = CostOptimizationSuggestion( + id=suggestion_id, + tenant_id=tenant_id, + category="resource_rightsize", + title="清理闲置资源", + description=f"检测到 {len(idle_resources)} 个闲置资源,建议清理以节省成本。", + potential_savings=total_potential_savings, + currency="CNY", + confidence=0.85, + difficulty="easy", + implementation_steps=[ "Review the list of idle resources", "Confirm resources are no longer needed", "Terminate or downsize unused resources", ], - risk_level = "low", - is_applied = False, - created_at = now, - applied_at = None, + risk_level="low", + is_applied=False, + created_at=now, + applied_at=None, ) with self._get_db() as conn: @@ -2773,11 +2773,11 @@ class OpsManager: return suggestions def get_cost_optimization_suggestions( - self, tenant_id: str, is_applied: bool = None + self, tenant_id: str, is_applied: bool = None ) -> list[CostOptimizationSuggestion]: """获取成本优化建议""" - query = "SELECT * FROM cost_optimization_suggestions WHERE tenant_id = ?" - params = [tenant_id] + query = "SELECT * FROM cost_optimization_suggestions WHERE tenant_id = ?" + params = [tenant_id] if is_applied is not None: query += " AND is_applied = ?" @@ -2786,14 +2786,14 @@ class OpsManager: query += " ORDER BY potential_savings DESC" with self._get_db() as conn: - rows = conn.execute(query, params).fetchall() + rows = conn.execute(query, params).fetchall() return [self._row_to_cost_optimization_suggestion(row) for row in rows] def apply_cost_optimization_suggestion( self, suggestion_id: str ) -> CostOptimizationSuggestion | None: """应用成本优化建议""" - now = datetime.now().isoformat() + now = datetime.now().isoformat() with self._get_db() as conn: conn.execute( @@ -2813,7 +2813,7 @@ class OpsManager: ) -> CostOptimizationSuggestion | None: """获取成本优化建议详情""" with self._get_db() as conn: - row = conn.execute( + row = conn.execute( "SELECT * FROM cost_optimization_suggestions WHERE id = ?", (suggestion_id, ) ).fetchone() @@ -2825,286 +2825,286 @@ class OpsManager: def _row_to_alert_rule(self, row) -> AlertRule: return AlertRule( - id = row["id"], - tenant_id = row["tenant_id"], - name = row["name"], - description = row["description"], - rule_type = AlertRuleType(row["rule_type"]), - severity = AlertSeverity(row["severity"]), - metric = row["metric"], - condition = row["condition"], - threshold = row["threshold"], - duration = row["duration"], - evaluation_interval = row["evaluation_interval"], - channels = json.loads(row["channels"]), - labels = json.loads(row["labels"]), - annotations = json.loads(row["annotations"]), - is_enabled = bool(row["is_enabled"]), - created_at = row["created_at"], - updated_at = row["updated_at"], - created_by = row["created_by"], + id=row["id"], + tenant_id=row["tenant_id"], + name=row["name"], + description=row["description"], + rule_type=AlertRuleType(row["rule_type"]), + severity=AlertSeverity(row["severity"]), + metric=row["metric"], + condition=row["condition"], + threshold=row["threshold"], + duration=row["duration"], + evaluation_interval=row["evaluation_interval"], + channels=json.loads(row["channels"]), + labels=json.loads(row["labels"]), + annotations=json.loads(row["annotations"]), + is_enabled=bool(row["is_enabled"]), + created_at=row["created_at"], + updated_at=row["updated_at"], + created_by=row["created_by"], ) def _row_to_alert_channel(self, row) -> AlertChannel: return AlertChannel( - id = row["id"], - tenant_id = row["tenant_id"], - name = row["name"], - channel_type = AlertChannelType(row["channel_type"]), - config = json.loads(row["config"]), - severity_filter = json.loads(row["severity_filter"]), - is_enabled = bool(row["is_enabled"]), - success_count = row["success_count"], - fail_count = row["fail_count"], - last_used_at = row["last_used_at"], - created_at = row["created_at"], - updated_at = row["updated_at"], + id=row["id"], + tenant_id=row["tenant_id"], + name=row["name"], + channel_type=AlertChannelType(row["channel_type"]), + config=json.loads(row["config"]), + severity_filter=json.loads(row["severity_filter"]), + is_enabled=bool(row["is_enabled"]), + success_count=row["success_count"], + fail_count=row["fail_count"], + last_used_at=row["last_used_at"], + created_at=row["created_at"], + updated_at=row["updated_at"], ) def _row_to_alert(self, row) -> Alert: return Alert( - id = row["id"], - rule_id = row["rule_id"], - tenant_id = row["tenant_id"], - severity = AlertSeverity(row["severity"]), - status = AlertStatus(row["status"]), - title = row["title"], - description = row["description"], - metric = row["metric"], - value = row["value"], - threshold = row["threshold"], - labels = json.loads(row["labels"]), - annotations = json.loads(row["annotations"]), - started_at = row["started_at"], - resolved_at = row["resolved_at"], - acknowledged_by = row["acknowledged_by"], - acknowledged_at = row["acknowledged_at"], - notification_sent = json.loads(row["notification_sent"]), - suppression_count = row["suppression_count"], + id=row["id"], + rule_id=row["rule_id"], + tenant_id=row["tenant_id"], + severity=AlertSeverity(row["severity"]), + status=AlertStatus(row["status"]), + title=row["title"], + description=row["description"], + metric=row["metric"], + value=row["value"], + threshold=row["threshold"], + labels=json.loads(row["labels"]), + annotations=json.loads(row["annotations"]), + started_at=row["started_at"], + resolved_at=row["resolved_at"], + acknowledged_by=row["acknowledged_by"], + acknowledged_at=row["acknowledged_at"], + notification_sent=json.loads(row["notification_sent"]), + suppression_count=row["suppression_count"], ) def _row_to_suppression_rule(self, row) -> AlertSuppressionRule: return AlertSuppressionRule( - id = row["id"], - tenant_id = row["tenant_id"], - name = row["name"], - matchers = json.loads(row["matchers"]), - duration = row["duration"], - is_regex = bool(row["is_regex"]), - created_at = row["created_at"], - expires_at = row["expires_at"], + id=row["id"], + tenant_id=row["tenant_id"], + name=row["name"], + matchers=json.loads(row["matchers"]), + duration=row["duration"], + is_regex=bool(row["is_regex"]), + created_at=row["created_at"], + expires_at=row["expires_at"], ) def _row_to_resource_metric(self, row) -> ResourceMetric: return ResourceMetric( - id = row["id"], - tenant_id = row["tenant_id"], - resource_type = ResourceType(row["resource_type"]), - resource_id = row["resource_id"], - metric_name = row["metric_name"], - metric_value = row["metric_value"], - unit = row["unit"], - timestamp = row["timestamp"], - metadata = json.loads(row["metadata"]), + id=row["id"], + tenant_id=row["tenant_id"], + resource_type=ResourceType(row["resource_type"]), + resource_id=row["resource_id"], + metric_name=row["metric_name"], + metric_value=row["metric_value"], + unit=row["unit"], + timestamp=row["timestamp"], + metadata=json.loads(row["metadata"]), ) def _row_to_capacity_plan(self, row) -> CapacityPlan: return CapacityPlan( - id = row["id"], - tenant_id = row["tenant_id"], - resource_type = ResourceType(row["resource_type"]), - current_capacity = row["current_capacity"], - predicted_capacity = row["predicted_capacity"], - prediction_date = row["prediction_date"], - confidence = row["confidence"], - recommended_action = row["recommended_action"], - estimated_cost = row["estimated_cost"], - created_at = row["created_at"], + id=row["id"], + tenant_id=row["tenant_id"], + resource_type=ResourceType(row["resource_type"]), + current_capacity=row["current_capacity"], + predicted_capacity=row["predicted_capacity"], + prediction_date=row["prediction_date"], + confidence=row["confidence"], + recommended_action=row["recommended_action"], + estimated_cost=row["estimated_cost"], + created_at=row["created_at"], ) def _row_to_auto_scaling_policy(self, row) -> AutoScalingPolicy: return AutoScalingPolicy( - id = row["id"], - tenant_id = row["tenant_id"], - name = row["name"], - resource_type = ResourceType(row["resource_type"]), - min_instances = row["min_instances"], - max_instances = row["max_instances"], - target_utilization = row["target_utilization"], - scale_up_threshold = row["scale_up_threshold"], - scale_down_threshold = row["scale_down_threshold"], - scale_up_step = row["scale_up_step"], - scale_down_step = row["scale_down_step"], - cooldown_period = row["cooldown_period"], - is_enabled = bool(row["is_enabled"]), - created_at = row["created_at"], - updated_at = row["updated_at"], + id=row["id"], + tenant_id=row["tenant_id"], + name=row["name"], + resource_type=ResourceType(row["resource_type"]), + min_instances=row["min_instances"], + max_instances=row["max_instances"], + target_utilization=row["target_utilization"], + scale_up_threshold=row["scale_up_threshold"], + scale_down_threshold=row["scale_down_threshold"], + scale_up_step=row["scale_up_step"], + scale_down_step=row["scale_down_step"], + cooldown_period=row["cooldown_period"], + is_enabled=bool(row["is_enabled"]), + created_at=row["created_at"], + updated_at=row["updated_at"], ) def _row_to_scaling_event(self, row) -> ScalingEvent: return ScalingEvent( - id = row["id"], - policy_id = row["policy_id"], - tenant_id = row["tenant_id"], - action = ScalingAction(row["action"]), - from_count = row["from_count"], - to_count = row["to_count"], - reason = row["reason"], - triggered_by = row["triggered_by"], - status = row["status"], - started_at = row["started_at"], - completed_at = row["completed_at"], - error_message = row["error_message"], + id=row["id"], + policy_id=row["policy_id"], + tenant_id=row["tenant_id"], + action=ScalingAction(row["action"]), + from_count=row["from_count"], + to_count=row["to_count"], + reason=row["reason"], + triggered_by=row["triggered_by"], + status=row["status"], + started_at=row["started_at"], + completed_at=row["completed_at"], + error_message=row["error_message"], ) def _row_to_health_check(self, row) -> HealthCheck: return HealthCheck( - id = row["id"], - tenant_id = row["tenant_id"], - name = row["name"], - target_type = row["target_type"], - target_id = row["target_id"], - check_type = row["check_type"], - check_config = json.loads(row["check_config"]), - interval = row["interval"], - timeout = row["timeout"], - retry_count = row["retry_count"], - healthy_threshold = row["healthy_threshold"], - unhealthy_threshold = row["unhealthy_threshold"], - is_enabled = bool(row["is_enabled"]), - created_at = row["created_at"], - updated_at = row["updated_at"], + id=row["id"], + tenant_id=row["tenant_id"], + name=row["name"], + target_type=row["target_type"], + target_id=row["target_id"], + check_type=row["check_type"], + check_config=json.loads(row["check_config"]), + interval=row["interval"], + timeout=row["timeout"], + retry_count=row["retry_count"], + healthy_threshold=row["healthy_threshold"], + unhealthy_threshold=row["unhealthy_threshold"], + is_enabled=bool(row["is_enabled"]), + created_at=row["created_at"], + updated_at=row["updated_at"], ) def _row_to_health_check_result(self, row) -> HealthCheckResult: return HealthCheckResult( - id = row["id"], - check_id = row["check_id"], - tenant_id = row["tenant_id"], - status = HealthStatus(row["status"]), - response_time = row["response_time"], - message = row["message"], - details = json.loads(row["details"]), - checked_at = row["checked_at"], + id=row["id"], + check_id=row["check_id"], + tenant_id=row["tenant_id"], + status=HealthStatus(row["status"]), + response_time=row["response_time"], + message=row["message"], + details=json.loads(row["details"]), + checked_at=row["checked_at"], ) def _row_to_failover_config(self, row) -> FailoverConfig: return FailoverConfig( - id = row["id"], - tenant_id = row["tenant_id"], - name = row["name"], - primary_region = row["primary_region"], - secondary_regions = json.loads(row["secondary_regions"]), - failover_trigger = row["failover_trigger"], - auto_failover = bool(row["auto_failover"]), - failover_timeout = row["failover_timeout"], - health_check_id = row["health_check_id"], - is_enabled = bool(row["is_enabled"]), - created_at = row["created_at"], - updated_at = row["updated_at"], + id=row["id"], + tenant_id=row["tenant_id"], + name=row["name"], + primary_region=row["primary_region"], + secondary_regions=json.loads(row["secondary_regions"]), + failover_trigger=row["failover_trigger"], + auto_failover=bool(row["auto_failover"]), + failover_timeout=row["failover_timeout"], + health_check_id=row["health_check_id"], + is_enabled=bool(row["is_enabled"]), + created_at=row["created_at"], + updated_at=row["updated_at"], ) def _row_to_failover_event(self, row) -> FailoverEvent: return FailoverEvent( - id = row["id"], - config_id = row["config_id"], - tenant_id = row["tenant_id"], - from_region = row["from_region"], - to_region = row["to_region"], - reason = row["reason"], - status = row["status"], - started_at = row["started_at"], - completed_at = row["completed_at"], - rolled_back_at = row["rolled_back_at"], + id=row["id"], + config_id=row["config_id"], + tenant_id=row["tenant_id"], + from_region=row["from_region"], + to_region=row["to_region"], + reason=row["reason"], + status=row["status"], + started_at=row["started_at"], + completed_at=row["completed_at"], + rolled_back_at=row["rolled_back_at"], ) def _row_to_backup_job(self, row) -> BackupJob: return BackupJob( - id = row["id"], - tenant_id = row["tenant_id"], - name = row["name"], - backup_type = row["backup_type"], - target_type = row["target_type"], - target_id = row["target_id"], - schedule = row["schedule"], - retention_days = row["retention_days"], - encryption_enabled = bool(row["encryption_enabled"]), - compression_enabled = bool(row["compression_enabled"]), - storage_location = row["storage_location"], - is_enabled = bool(row["is_enabled"]), - created_at = row["created_at"], - updated_at = row["updated_at"], + id=row["id"], + tenant_id=row["tenant_id"], + name=row["name"], + backup_type=row["backup_type"], + target_type=row["target_type"], + target_id=row["target_id"], + schedule=row["schedule"], + retention_days=row["retention_days"], + encryption_enabled=bool(row["encryption_enabled"]), + compression_enabled=bool(row["compression_enabled"]), + storage_location=row["storage_location"], + is_enabled=bool(row["is_enabled"]), + created_at=row["created_at"], + updated_at=row["updated_at"], ) def _row_to_backup_record(self, row) -> BackupRecord: return BackupRecord( - id = row["id"], - job_id = row["job_id"], - tenant_id = row["tenant_id"], - status = BackupStatus(row["status"]), - size_bytes = row["size_bytes"], - checksum = row["checksum"], - started_at = row["started_at"], - completed_at = row["completed_at"], - verified_at = row["verified_at"], - error_message = row["error_message"], - storage_path = row["storage_path"], + id=row["id"], + job_id=row["job_id"], + tenant_id=row["tenant_id"], + status=BackupStatus(row["status"]), + size_bytes=row["size_bytes"], + checksum=row["checksum"], + started_at=row["started_at"], + completed_at=row["completed_at"], + verified_at=row["verified_at"], + error_message=row["error_message"], + storage_path=row["storage_path"], ) def _row_to_resource_utilization(self, row) -> ResourceUtilization: return ResourceUtilization( - id = row["id"], - tenant_id = row["tenant_id"], - resource_type = ResourceType(row["resource_type"]), - resource_id = row["resource_id"], - utilization_rate = row["utilization_rate"], - peak_utilization = row["peak_utilization"], - avg_utilization = row["avg_utilization"], - idle_time_percent = row["idle_time_percent"], - report_date = row["report_date"], - recommendations = json.loads(row["recommendations"]), + id=row["id"], + tenant_id=row["tenant_id"], + resource_type=ResourceType(row["resource_type"]), + resource_id=row["resource_id"], + utilization_rate=row["utilization_rate"], + peak_utilization=row["peak_utilization"], + avg_utilization=row["avg_utilization"], + idle_time_percent=row["idle_time_percent"], + report_date=row["report_date"], + recommendations=json.loads(row["recommendations"]), ) def _row_to_idle_resource(self, row) -> IdleResource: return IdleResource( - id = row["id"], - tenant_id = row["tenant_id"], - resource_type = ResourceType(row["resource_type"]), - resource_id = row["resource_id"], - resource_name = row["resource_name"], - idle_since = row["idle_since"], - estimated_monthly_cost = row["estimated_monthly_cost"], - currency = row["currency"], - reason = row["reason"], - recommendation = row["recommendation"], - detected_at = row["detected_at"], + id=row["id"], + tenant_id=row["tenant_id"], + resource_type=ResourceType(row["resource_type"]), + resource_id=row["resource_id"], + resource_name=row["resource_name"], + idle_since=row["idle_since"], + estimated_monthly_cost=row["estimated_monthly_cost"], + currency=row["currency"], + reason=row["reason"], + recommendation=row["recommendation"], + detected_at=row["detected_at"], ) def _row_to_cost_optimization_suggestion(self, row) -> CostOptimizationSuggestion: return CostOptimizationSuggestion( - id = row["id"], - tenant_id = row["tenant_id"], - category = row["category"], - title = row["title"], - description = row["description"], - potential_savings = row["potential_savings"], - currency = row["currency"], - confidence = row["confidence"], - difficulty = row["difficulty"], - implementation_steps = json.loads(row["implementation_steps"]), - risk_level = row["risk_level"], - is_applied = bool(row["is_applied"]), - created_at = row["created_at"], - applied_at = row["applied_at"], + id=row["id"], + tenant_id=row["tenant_id"], + category=row["category"], + title=row["title"], + description=row["description"], + potential_savings=row["potential_savings"], + currency=row["currency"], + confidence=row["confidence"], + difficulty=row["difficulty"], + implementation_steps=json.loads(row["implementation_steps"]), + risk_level=row["risk_level"], + is_applied=bool(row["is_applied"]), + created_at=row["created_at"], + applied_at=row["applied_at"], ) # Singleton instance -_ops_manager = None +_ops_manager = None def get_ops_manager() -> OpsManager: global _ops_manager if _ops_manager is None: - _ops_manager = OpsManager() + _ops_manager = OpsManager() return _ops_manager diff --git a/backend/oss_uploader.py b/backend/oss_uploader.py index edbbf7d..743f82d 100644 --- a/backend/oss_uploader.py +++ b/backend/oss_uploader.py @@ -12,29 +12,29 @@ import oss2 class OSSUploader: def __init__(self) -> None: - self.access_key = os.getenv("ALI_ACCESS_KEY") - self.secret_key = os.getenv("ALI_SECRET_KEY") - self.bucket_name = os.getenv("OSS_BUCKET", "insightflow-audio") - self.region = os.getenv("OSS_REGION", "oss-cn-hangzhou.aliyuncs.com") - self.endpoint = f"https://{self.region}" + self.access_key = os.getenv("ALI_ACCESS_KEY") + self.secret_key = os.getenv("ALI_SECRET_KEY") + self.bucket_name = os.getenv("OSS_BUCKET", "insightflow-audio") + self.region = os.getenv("OSS_REGION", "oss-cn-hangzhou.aliyuncs.com") + self.endpoint = f"https://{self.region}" if not self.access_key or not self.secret_key: raise ValueError("ALI_ACCESS_KEY and ALI_SECRET_KEY must be set") - self.auth = oss2.Auth(self.access_key, self.secret_key) - self.bucket = oss2.Bucket(self.auth, self.endpoint, self.bucket_name) + self.auth = oss2.Auth(self.access_key, self.secret_key) + self.bucket = oss2.Bucket(self.auth, self.endpoint, self.bucket_name) def upload_audio(self, audio_data: bytes, filename: str) -> tuple: """上传音频到 OSS,返回 (URL, object_name)""" # 生成唯一文件名 - ext = os.path.splitext(filename)[1] or ".wav" - object_name = f"audio/{datetime.now().strftime('%Y%m%d')}/{uuid.uuid4().hex}{ext}" + ext = os.path.splitext(filename)[1] or ".wav" + object_name = f"audio/{datetime.now().strftime('%Y%m%d')}/{uuid.uuid4().hex}{ext}" # 上传文件 self.bucket.put_object(object_name, audio_data) # 生成临时访问 URL (1小时有效) - url = self.bucket.sign_url("GET", object_name, 3600) + url = self.bucket.sign_url("GET", object_name, 3600) return url, object_name def delete_object(self, object_name: str) -> None: @@ -43,11 +43,11 @@ class OSSUploader: # 单例 -_oss_uploader = None +_oss_uploader = None def get_oss_uploader() -> OSSUploader: global _oss_uploader if _oss_uploader is None: - _oss_uploader = OSSUploader() + _oss_uploader = OSSUploader() return _oss_uploader diff --git a/backend/performance_manager.py b/backend/performance_manager.py index 88a43d2..39c3fca 100644 --- a/backend/performance_manager.py +++ b/backend/performance_manager.py @@ -27,18 +27,18 @@ from typing import Any try: import redis - REDIS_AVAILABLE = True + REDIS_AVAILABLE = True except ImportError: - REDIS_AVAILABLE = False + REDIS_AVAILABLE = False # 尝试导入 Celery try: from celery import Celery from celery.result import AsyncResult - CELERY_AVAILABLE = True + CELERY_AVAILABLE = True except ImportError: - CELERY_AVAILABLE = False + CELERY_AVAILABLE = False # ==================== 数据模型 ==================== @@ -47,17 +47,17 @@ except ImportError: class CacheStats: """缓存统计数据模型""" - total_requests: int = 0 - hits: int = 0 - misses: int = 0 - evictions: int = 0 - expired: int = 0 - hit_rate: float = 0.0 + total_requests: int = 0 + hits: int = 0 + misses: int = 0 + evictions: int = 0 + expired: int = 0 + hit_rate: float = 0.0 def update_hit_rate(self) -> None: """更新命中率""" if self.total_requests > 0: - self.hit_rate = round(self.hits / self.total_requests, 4) + self.hit_rate = round(self.hits / self.total_requests, 4) @dataclass @@ -68,9 +68,9 @@ class CacheEntry: value: Any created_at: float expires_at: float | None - access_count: int = 0 - last_accessed: float = 0 - size_bytes: int = 0 + access_count: int = 0 + last_accessed: float = 0 + size_bytes: int = 0 @dataclass @@ -82,7 +82,7 @@ class PerformanceMetric: endpoint: str | None duration_ms: float timestamp: str - metadata: dict = field(default_factory = dict) + metadata: dict = field(default_factory=dict) def to_dict(self) -> dict: return { @@ -104,12 +104,12 @@ class TaskInfo: status: str # pending, running, success, failed, retrying payload: dict created_at: str - started_at: str | None = None - completed_at: str | None = None - result: Any | None = None - error_message: str | None = None - retry_count: int = 0 - max_retries: int = 3 + started_at: str | None = None + completed_at: str | None = None + result: Any | None = None + error_message: str | None = None + retry_count: int = 0 + max_retries: int = 3 def to_dict(self) -> dict: return { @@ -134,10 +134,10 @@ class ShardInfo: shard_id: str shard_key_range: tuple[str, str] # (start, end) db_path: str - entity_count: int = 0 - is_active: bool = True - created_at: str = "" - last_accessed: str = "" + entity_count: int = 0 + is_active: bool = True + created_at: str = "" + last_accessed: str = "" # ==================== Redis 缓存层 ==================== @@ -160,42 +160,42 @@ class CacheManager: def __init__( self, - redis_url: str | None = None, - max_memory_size: int = 100 * 1024 * 1024, # 100MB - default_ttl: int = 3600, # 1小时 - db_path: str = "insightflow.db", + redis_url: str | None = None, + max_memory_size: int = 100 * 1024 * 1024, # 100MB + default_ttl: int = 3600, # 1小时 + db_path: str = "insightflow.db", ) -> None: - self.db_path = db_path - self.default_ttl = default_ttl - self.max_memory_size = max_memory_size - self.current_memory_size = 0 + self.db_path = db_path + self.default_ttl = default_ttl + self.max_memory_size = max_memory_size + self.current_memory_size = 0 # Redis 客户端 - self.redis_client = None - self.use_redis = False + self.redis_client = None + self.use_redis = False if REDIS_AVAILABLE and redis_url: try: - self.redis_client = redis.from_url(redis_url, decode_responses = True) + self.redis_client = redis.from_url(redis_url, decode_responses=True) self.redis_client.ping() - self.use_redis = True + self.use_redis = True print(f"Redis 缓存已连接: {redis_url}") except Exception as e: print(f"Redis 连接失败,使用内存缓存: {e}") # 内存缓存(LRU) - self.memory_cache: OrderedDict[str, CacheEntry] = OrderedDict() - self.cache_lock = threading.RLock() + self.memory_cache: OrderedDict[str, CacheEntry] = OrderedDict() + self.cache_lock = threading.RLock() # 统计 - self.stats = CacheStats() + self.stats = CacheStats() # 初始化缓存统计表 self._init_cache_tables() def _init_cache_tables(self) -> None: """初始化缓存统计表""" - conn = sqlite3.connect(self.db_path) + conn = sqlite3.connect(self.db_path) conn.execute(""" CREATE TABLE IF NOT EXISTS cache_stats ( @@ -233,11 +233,11 @@ class CacheManager: def _get_entry_size(self, value: Any) -> int: """估算缓存条目大小""" try: - return len(json.dumps(value, ensure_ascii = False).encode("utf-8")) + return len(json.dumps(value, ensure_ascii=False).encode("utf-8")) except (TypeError, ValueError): return 1024 # 默认估算 - def _evict_lru(self, required_space: int = 0) -> None: + def _evict_lru(self, required_space: int = 0) -> None: """LRU 淘汰策略""" with self.cache_lock: while ( @@ -245,7 +245,7 @@ class CacheManager: and self.memory_cache ): # 移除最久未访问的 - oldest_key, oldest_entry = self.memory_cache.popitem(last = False) + oldest_key, oldest_entry = self.memory_cache.popitem(last=False) self.current_memory_size -= oldest_entry.size_bytes self.stats.evictions += 1 @@ -263,7 +263,7 @@ class CacheManager: if self.use_redis: try: - value = self.redis_client.get(key) + value = self.redis_client.get(key) if value: self.stats.hits += 1 return json.loads(value) @@ -276,7 +276,7 @@ class CacheManager: else: # 内存缓存 with self.cache_lock: - entry = self.memory_cache.get(key) + entry = self.memory_cache.get(key) if entry: # 检查是否过期 @@ -289,7 +289,7 @@ class CacheManager: # 更新访问信息 entry.access_count += 1 - entry.last_accessed = time.time() + entry.last_accessed = time.time() self.memory_cache.move_to_end(key) self.stats.hits += 1 @@ -298,7 +298,7 @@ class CacheManager: self.stats.misses += 1 return None - def set(self, key: str, value: Any, ttl: int | None = None) -> bool: + def set(self, key: str, value: Any, ttl: int | None = None) -> bool: """ 设置缓存值 @@ -310,11 +310,11 @@ class CacheManager: Returns: bool: 是否成功 """ - ttl = ttl or self.default_ttl + ttl = ttl or self.default_ttl if self.use_redis: try: - serialized = json.dumps(value, ensure_ascii = False) + serialized = json.dumps(value, ensure_ascii=False) self.redis_client.setex(key, ttl, serialized) return True except Exception as e: @@ -323,27 +323,27 @@ class CacheManager: else: # 内存缓存 with self.cache_lock: - size = self._get_entry_size(value) + size = self._get_entry_size(value) # 检查是否需要淘汰 if self.current_memory_size + size > self.max_memory_size: self._evict_lru(size) - now = time.time() - entry = CacheEntry( - key = key, - value = value, - created_at = now, - expires_at = now + ttl if ttl > 0 else None, - size_bytes = size, - last_accessed = now, + now = time.time() + entry = CacheEntry( + key=key, + value=value, + created_at=now, + expires_at=now + ttl if ttl > 0 else None, + size_bytes=size, + last_accessed=now, ) # 如果已存在,更新大小 if key in self.memory_cache: self.current_memory_size -= self.memory_cache[key].size_bytes - self.memory_cache[key] = entry + self.memory_cache[key] = entry self.memory_cache.move_to_end(key) self.current_memory_size += size @@ -360,7 +360,7 @@ class CacheManager: else: with self.cache_lock: if key in self.memory_cache: - entry = self.memory_cache.pop(key) + entry = self.memory_cache.pop(key) self.current_memory_size -= entry.size_bytes return True return False @@ -377,19 +377,19 @@ class CacheManager: else: with self.cache_lock: self.memory_cache.clear() - self.current_memory_size = 0 + self.current_memory_size = 0 return True def get_many(self, keys: list[str]) -> dict[str, Any]: """批量获取缓存""" - results = {} + results = {} if self.use_redis: try: - values = self.redis_client.mget(keys) + values = self.redis_client.mget(keys) for key, value in zip(keys, values): if value: - results[key] = json.loads(value) + results[key] = json.loads(value) self.stats.hits += 1 else: self.stats.misses += 1 @@ -398,21 +398,21 @@ class CacheManager: print(f"Redis mget 失败: {e}") else: for key in keys: - value = self.get(key) + value = self.get(key) if value is not None: - results[key] = value + results[key] = value return results - def set_many(self, mapping: dict[str, Any], ttl: int | None = None) -> bool: + def set_many(self, mapping: dict[str, Any], ttl: int | None = None) -> bool: """批量设置缓存""" - ttl = ttl or self.default_ttl + ttl = ttl or self.default_ttl if self.use_redis: try: - pipe = self.redis_client.pipeline() + pipe = self.redis_client.pipeline() for key, value in mapping.items(): - serialized = json.dumps(value, ensure_ascii = False) + serialized = json.dumps(value, ensure_ascii=False) pipe.setex(key, ttl, serialized) pipe.execute() return True @@ -428,7 +428,7 @@ class CacheManager: """获取缓存统计""" self.stats.update_hit_rate() - stats = { + stats = { "total_requests": self.stats.total_requests, "hits": self.stats.hits, "misses": self.stats.misses, @@ -454,7 +454,7 @@ class CacheManager: def save_stats(self) -> None: """保存缓存统计到数据库""" - conn = sqlite3.connect(self.db_path) + conn = sqlite3.connect(self.db_path) self.stats.update_hit_rate() @@ -487,13 +487,13 @@ class CacheManager: Returns: Dict: 预热统计 """ - conn = sqlite3.connect(self.db_path) - conn.row_factory = sqlite3.Row + conn = sqlite3.connect(self.db_path) + conn.row_factory = sqlite3.Row - stats = {"entities": 0, "relations": 0, "transcripts": 0} + stats = {"entities": 0, "relations": 0, "transcripts": 0} # 预热实体数据 - entities = conn.execute( + entities = conn.execute( """SELECT e.*, (SELECT COUNT(*) FROM entity_mentions m WHERE m.entity_id = e.id) as mention_count FROM entities e @@ -504,12 +504,12 @@ class CacheManager: ).fetchall() for entity in entities: - key = f"entity:{entity['id']}" - self.set(key, dict(entity), ttl = 7200) # 2小时 + key = f"entity:{entity['id']}" + self.set(key, dict(entity), ttl=7200) # 2小时 stats["entities"] += 1 # 预热关系数据 - relations = conn.execute( + relations = conn.execute( """SELECT r.*, e1.name as source_name, e2.name as target_name FROM entity_relations r @@ -521,12 +521,12 @@ class CacheManager: ).fetchall() for relation in relations: - key = f"relation:{relation['id']}" - self.set(key, dict(relation), ttl = 3600) + key = f"relation:{relation['id']}" + self.set(key, dict(relation), ttl=3600) stats["relations"] += 1 # 预热最近的转录 - transcripts = conn.execute( + transcripts = conn.execute( """SELECT * FROM transcripts WHERE project_id = ? ORDER BY created_at DESC @@ -535,33 +535,33 @@ class CacheManager: ).fetchall() for transcript in transcripts: - key = f"transcript:{transcript['id']}" + key = f"transcript:{transcript['id']}" # 只缓存元数据,不缓存完整文本 - meta = { + meta = { "id": transcript["id"], "filename": transcript["filename"], "type": transcript.get("type", "audio"), "created_at": transcript["created_at"], } - self.set(key, meta, ttl = 1800) # 30分钟 + self.set(key, meta, ttl=1800) # 30分钟 stats["transcripts"] += 1 # 预热项目知识库摘要 - entity_count = conn.execute( + entity_count = conn.execute( "SELECT COUNT(*) FROM entities WHERE project_id = ?", (project_id, ) ).fetchone()[0] - relation_count = conn.execute( + relation_count = conn.execute( "SELECT COUNT(*) FROM entity_relations WHERE project_id = ?", (project_id, ) ).fetchone()[0] - summary = { + summary = { "project_id": project_id, "entity_count": entity_count, "relation_count": relation_count, "cached_at": datetime.now().isoformat(), } - self.set(f"project_summary:{project_id}", summary, ttl = 3600) + self.set(f"project_summary:{project_id}", summary, ttl=3600) conn.close() @@ -577,13 +577,13 @@ class CacheManager: Returns: int: 清除的缓存数量 """ - count = 0 + count = 0 if self.use_redis: try: # 使用 Redis 的 scan 查找相关 key - pattern = f"*:{project_id}:*" - for key in self.redis_client.scan_iter(match = pattern): + pattern = f"*:{project_id}:*" + for key in self.redis_client.scan_iter(match=pattern): self.redis_client.delete(key) count += 1 except Exception as e: @@ -591,9 +591,9 @@ class CacheManager: else: # 内存缓存 - 查找并删除相关 key with self.cache_lock: - keys_to_delete = [key for key in self.memory_cache.keys() if project_id in key] + keys_to_delete = [key for key in self.memory_cache.keys() if project_id in key] for key in keys_to_delete: - entry = self.memory_cache.pop(key) + entry = self.memory_cache.pop(key) self.current_memory_size -= entry.size_bytes count += 1 @@ -616,19 +616,19 @@ class DatabaseSharding: def __init__( self, - base_db_path: str = "insightflow.db", - shard_db_dir: str = "./shards", - shards_count: int = 4, + base_db_path: str = "insightflow.db", + shard_db_dir: str = "./shards", + shards_count: int = 4, ) -> None: - self.base_db_path = base_db_path - self.shard_db_dir = shard_db_dir - self.shards_count = shards_count + self.base_db_path = base_db_path + self.shard_db_dir = shard_db_dir + self.shards_count = shards_count # 确保分片目录存在 - os.makedirs(shard_db_dir, exist_ok = True) + os.makedirs(shard_db_dir, exist_ok=True) # 分片映射 - self.shard_map: dict[str, ShardInfo] = {} + self.shard_map: dict[str, ShardInfo] = {} # 初始化分片 self._init_shards() @@ -636,24 +636,24 @@ class DatabaseSharding: def _init_shards(self) -> None: """初始化分片""" # 计算每个分片的 key 范围 - chars = "0123456789abcdef" - chars_per_shard = len(chars) // self.shards_count + chars = "0123456789abcdef" + chars_per_shard = len(chars) // self.shards_count for i in range(self.shards_count): - start_idx = i * chars_per_shard - end_idx = start_idx + chars_per_shard if i < self.shards_count - 1 else len(chars) + start_idx = i * chars_per_shard + end_idx = start_idx + chars_per_shard if i < self.shards_count - 1 else len(chars) - start_char = chars[start_idx] - end_char = chars[end_idx - 1] + start_char = chars[start_idx] + end_char = chars[end_idx - 1] - shard_id = f"shard_{i}" - db_path = os.path.join(self.shard_db_dir, f"{shard_id}.db") + shard_id = f"shard_{i}" + db_path = os.path.join(self.shard_db_dir, f"{shard_id}.db") - self.shard_map[shard_id] = ShardInfo( - shard_id = shard_id, - shard_key_range = (start_char, end_char), - db_path = db_path, - created_at = datetime.now().isoformat(), + self.shard_map[shard_id] = ShardInfo( + shard_id=shard_id, + shard_key_range=(start_char, end_char), + db_path=db_path, + created_at=datetime.now().isoformat(), ) # 确保分片数据库存在 @@ -662,7 +662,7 @@ class DatabaseSharding: def _create_shard_db(self, db_path: str) -> None: """创建分片数据库""" - conn = sqlite3.connect(db_path) + conn = sqlite3.connect(db_path) # 创建与主库相同的表结构(简化版) conn.executescript(""" @@ -702,10 +702,10 @@ class DatabaseSharding: if not project_id: return "shard_0" - first_char = project_id[0].lower() + first_char = project_id[0].lower() for shard_id, shard_info in self.shard_map.items(): - start, end = shard_info.shard_key_range + start, end = shard_info.shard_key_range if start <= first_char <= end: return shard_id @@ -713,14 +713,14 @@ class DatabaseSharding: def get_shard_connection(self, project_id: str) -> sqlite3.Connection: """获取项目对应的分片连接""" - shard_id = self._get_shard_id(project_id) - shard_info = self.shard_map[shard_id] + shard_id = self._get_shard_id(project_id) + shard_info = self.shard_map[shard_id] - conn = sqlite3.connect(shard_info.db_path) - conn.row_factory = sqlite3.Row + conn = sqlite3.connect(shard_info.db_path) + conn.row_factory = sqlite3.Row # 更新访问时间 - shard_info.last_accessed = datetime.now().isoformat() + shard_info.last_accessed = datetime.now().isoformat() return conn @@ -740,34 +740,34 @@ class DatabaseSharding: bool: 是否成功 """ # 获取源分片 - source_shard_id = self._get_shard_id(project_id) + source_shard_id = self._get_shard_id(project_id) if source_shard_id == target_shard_id: return True # 已经在目标分片 - source_info = self.shard_map.get(source_shard_id) - target_info = self.shard_map.get(target_shard_id) + source_info = self.shard_map.get(source_shard_id) + target_info = self.shard_map.get(target_shard_id) if not source_info or not target_info: return False try: # 从源分片读取数据 - source_conn = sqlite3.connect(source_info.db_path) - source_conn.row_factory = sqlite3.Row + source_conn = sqlite3.connect(source_info.db_path) + source_conn.row_factory = sqlite3.Row - entities = source_conn.execute( + entities = source_conn.execute( "SELECT * FROM entities WHERE project_id = ?", (project_id, ) ).fetchall() - relations = source_conn.execute( + relations = source_conn.execute( "SELECT * FROM entity_relations WHERE project_id = ?", (project_id, ) ).fetchall() source_conn.close() # 写入目标分片 - target_conn = sqlite3.connect(target_info.db_path) + target_conn = sqlite3.connect(target_info.db_path) for entity in entities: target_conn.execute( @@ -793,7 +793,7 @@ class DatabaseSharding: target_conn.close() # 从源分片删除数据 - source_conn = sqlite3.connect(source_info.db_path) + source_conn = sqlite3.connect(source_info.db_path) source_conn.execute("DELETE FROM entities WHERE project_id = ?", (project_id, )) source_conn.execute("DELETE FROM entity_relations WHERE project_id = ?", (project_id, )) source_conn.commit() @@ -811,15 +811,15 @@ class DatabaseSharding: def _update_shard_stats(self, shard_id: str) -> None: """更新分片统计""" - shard_info = self.shard_map.get(shard_id) + shard_info = self.shard_map.get(shard_id) if not shard_info: return - conn = sqlite3.connect(shard_info.db_path) + conn = sqlite3.connect(shard_info.db_path) - count = conn.execute("SELECT COUNT(DISTINCT project_id) FROM entities").fetchone()[0] + count = conn.execute("SELECT COUNT(DISTINCT project_id) FROM entities").fetchone()[0] - shard_info.entity_count = count + shard_info.entity_count = count conn.close() @@ -833,14 +833,14 @@ class DatabaseSharding: Returns: List[Dict]: 合并的查询结果 """ - results = [] + results = [] for shard_info in self.shard_map.values(): - conn = sqlite3.connect(shard_info.db_path) - conn.row_factory = sqlite3.Row + conn = sqlite3.connect(shard_info.db_path) + conn.row_factory = sqlite3.Row try: - shard_results = query_func(conn) + shard_results = query_func(conn) results.extend(shard_results) except Exception as e: print(f"分片 {shard_info.shard_id} 查询失败: {e}") @@ -851,7 +851,7 @@ class DatabaseSharding: def get_shard_stats(self) -> list[dict]: """获取所有分片的统计信息""" - stats = [] + stats = [] for shard_info in self.shard_map.values(): self._update_shard_stats(shard_info.shard_id) @@ -880,17 +880,17 @@ class DatabaseSharding: Dict: 重新平衡统计 """ # 获取各分片的负载 - stats = self.get_shard_stats() + stats = self.get_shard_stats() if not stats: return {"message": "No shards to rebalance"} # 计算平均负载 - avg_load = sum(s["entity_count"] for s in stats) / len(stats) + avg_load = sum(s["entity_count"] for s in stats) / len(stats) # 找出过载和欠载的分片 - overloaded = [s for s in stats if s["entity_count"] > avg_load * 1.5] - underloaded = [s for s in stats if s["entity_count"] < avg_load * 0.5] + overloaded = [s for s in stats if s["entity_count"] > avg_load * 1.5] + underloaded = [s for s in stats if s["entity_count"] < avg_load * 0.5] # 简化的重新平衡逻辑 # 实际生产环境需要更复杂的算法 @@ -917,16 +917,16 @@ class TaskQueue: - 任务状态追踪和重试机制 """ - def __init__(self, redis_url: str | None = None, db_path: str = "insightflow.db") -> None: - self.db_path = db_path - self.redis_url = redis_url - self.celery_app = None - self.use_celery = False + def __init__(self, redis_url: str | None = None, db_path: str = "insightflow.db") -> None: + self.db_path = db_path + self.redis_url = redis_url + self.celery_app = None + self.use_celery = False # 内存任务存储(非 Celery 模式) - self.tasks: dict[str, TaskInfo] = {} - self.task_handlers: dict[str, Callable] = {} - self.task_lock = threading.RLock() + self.tasks: dict[str, TaskInfo] = {} + self.task_handlers: dict[str, Callable] = {} + self.task_lock = threading.RLock() # 初始化任务队列表 self._init_task_tables() @@ -934,15 +934,15 @@ class TaskQueue: # 初始化 Celery if CELERY_AVAILABLE and redis_url: try: - self.celery_app = Celery("insightflow", broker = redis_url, backend = redis_url) - self.use_celery = True + self.celery_app = Celery("insightflow", broker=redis_url, backend=redis_url) + self.use_celery = True print("Celery 任务队列已初始化") except Exception as e: print(f"Celery 初始化失败,使用内存任务队列: {e}") def _init_task_tables(self) -> None: """初始化任务队列表""" - conn = sqlite3.connect(self.db_path) + conn = sqlite3.connect(self.db_path) conn.execute(""" CREATE TABLE IF NOT EXISTS task_queue ( @@ -972,9 +972,9 @@ class TaskQueue: def register_handler(self, task_type: str, handler: Callable) -> None: """注册任务处理器""" - self.task_handlers[task_type] = handler + self.task_handlers[task_type] = handler - def submit(self, task_type: str, payload: dict, max_retries: int = 3) -> str: + def submit(self, task_type: str, payload: dict, max_retries: int = 3) -> str: """ 提交任务 @@ -986,45 +986,45 @@ class TaskQueue: Returns: str: 任务ID """ - task_id = str(uuid.uuid4())[:16] + task_id = str(uuid.uuid4())[:16] - task = TaskInfo( - id = task_id, - task_type = task_type, - status = "pending", - payload = payload, - created_at = datetime.now().isoformat(), - max_retries = max_retries, + task = TaskInfo( + id=task_id, + task_type=task_type, + status="pending", + payload=payload, + created_at=datetime.now().isoformat(), + max_retries=max_retries, ) if self.use_celery: # 使用 Celery try: # 这里简化处理,实际应该定义具体的 Celery 任务 - result = self.celery_app.send_task( + result = self.celery_app.send_task( f"insightflow.tasks.{task_type}", - args = [payload], - task_id = task_id, - retry = True, - retry_policy = { + args=[payload], + task_id=task_id, + retry=True, + retry_policy={ "max_retries": max_retries, "interval_start": 10, "interval_step": 10, "interval_max": 60, }, ) - task.id = result.id + task.id = result.id except Exception as e: print(f"Celery 任务提交失败: {e}") # 回退到内存模式 - self.use_celery = False + self.use_celery = False if not self.use_celery: # 内存模式 with self.task_lock: - self.tasks[task_id] = task + self.tasks[task_id] = task # 异步执行 - threading.Thread(target = self._execute_task, args = (task_id, ), daemon = True).start() + threading.Thread(target=self._execute_task, args=(task_id, ), daemon=True).start() # 保存到数据库 self._save_task(task) @@ -1034,49 +1034,49 @@ class TaskQueue: def _execute_task(self, task_id: str) -> None: """执行任务(内存模式)""" with self.task_lock: - task = self.tasks.get(task_id) + task = self.tasks.get(task_id) if not task: return - task.status = "running" - task.started_at = datetime.now().isoformat() + task.status = "running" + task.started_at = datetime.now().isoformat() self._update_task_status(task) # 获取处理器 - handler = self.task_handlers.get(task.task_type) + handler = self.task_handlers.get(task.task_type) if not handler: - task.status = "failed" - task.error_message = f"No handler for task type: {task.task_type}" + task.status = "failed" + task.error_message = f"No handler for task type: {task.task_type}" else: try: - result = handler(task.payload) - task.status = "success" - task.result = result + result = handler(task.payload) + task.status = "success" + task.result = result except Exception as e: task.retry_count += 1 if task.retry_count <= task.max_retries: - task.status = "retrying" + task.status = "retrying" # 延迟重试 threading.Timer( - 10 * task.retry_count, self._execute_task, args = (task_id, ) + 10 * task.retry_count, self._execute_task, args=(task_id, ) ).start() else: - task.status = "failed" - task.error_message = str(e) + task.status = "failed" + task.error_message = str(e) - task.completed_at = datetime.now().isoformat() + task.completed_at = datetime.now().isoformat() with self.task_lock: - self.tasks[task_id] = task + self.tasks[task_id] = task self._update_task_status(task) def _save_task(self, task: TaskInfo) -> None: """保存任务到数据库""" - conn = sqlite3.connect(self.db_path) + conn = sqlite3.connect(self.db_path) conn.execute( """ @@ -1089,8 +1089,8 @@ class TaskQueue: task.id, task.task_type, task.status, - json.dumps(task.payload, ensure_ascii = False), - json.dumps(task.result, ensure_ascii = False) if task.result else None, + json.dumps(task.payload, ensure_ascii=False), + json.dumps(task.result, ensure_ascii=False) if task.result else None, task.error_message, task.retry_count, task.max_retries, @@ -1105,7 +1105,7 @@ class TaskQueue: def _update_task_status(self, task: TaskInfo) -> None: """更新任务状态""" - conn = sqlite3.connect(self.db_path) + conn = sqlite3.connect(self.db_path) conn.execute( """ @@ -1120,7 +1120,7 @@ class TaskQueue: """, ( task.status, - json.dumps(task.result, ensure_ascii = False) if task.result else None, + json.dumps(task.result, ensure_ascii=False) if task.result else None, task.error_message, task.retry_count, task.started_at, @@ -1136,9 +1136,9 @@ class TaskQueue: """获取任务状态""" if self.use_celery: try: - result = AsyncResult(task_id, app = self.celery_app) + result = AsyncResult(task_id, app=self.celery_app) - status_map = { + status_map = { "PENDING": "pending", "STARTED": "running", "SUCCESS": "success", @@ -1147,13 +1147,13 @@ class TaskQueue: } return TaskInfo( - id = task_id, - task_type = "celery_task", - status = status_map.get(result.status, "unknown"), - payload = {}, - created_at = "", - result = result.result if result.successful() else None, - error_message = str(result.result) if result.failed() else None, + id=task_id, + task_type="celery_task", + status=status_map.get(result.status, "unknown"), + payload={}, + created_at="", + result=result.result if result.successful() else None, + error_message=str(result.result) if result.failed() else None, ) except Exception as e: print(f"获取 Celery 任务状态失败: {e}") @@ -1163,14 +1163,14 @@ class TaskQueue: return self.tasks.get(task_id) def list_tasks( - self, status: str | None = None, task_type: str | None = None, limit: int = 100 + self, status: str | None = None, task_type: str | None = None, limit: int = 100 ) -> list[TaskInfo]: """列出任务""" - conn = sqlite3.connect(self.db_path) - conn.row_factory = sqlite3.Row + conn = sqlite3.connect(self.db_path) + conn.row_factory = sqlite3.Row - where_clauses = [] - params = [] + where_clauses = [] + params = [] if status: where_clauses.append("status = ?") @@ -1180,9 +1180,9 @@ class TaskQueue: where_clauses.append("task_type = ?") params.append(task_type) - where_str = " AND ".join(where_clauses) if where_clauses else "1 = 1" + where_str = " AND ".join(where_clauses) if where_clauses else "1 = 1" - rows = conn.execute( + rows = conn.execute( f""" SELECT * FROM task_queue WHERE {where_str} @@ -1194,21 +1194,21 @@ class TaskQueue: conn.close() - tasks = [] + tasks = [] for row in rows: tasks.append( TaskInfo( - id = row["id"], - task_type = row["task_type"], - status = row["status"], - payload = json.loads(row["payload"]) if row["payload"] else {}, - created_at = row["created_at"], - started_at = row["started_at"], - completed_at = row["completed_at"], - result = json.loads(row["result"]) if row["result"] else None, - error_message = row["error_message"], - retry_count = row["retry_count"], - max_retries = row["max_retries"], + id=row["id"], + task_type=row["task_type"], + status=row["status"], + payload=json.loads(row["payload"]) if row["payload"] else {}, + created_at=row["created_at"], + started_at=row["started_at"], + completed_at=row["completed_at"], + result=json.loads(row["result"]) if row["result"] else None, + error_message=row["error_message"], + retry_count=row["retry_count"], + max_retries=row["max_retries"], ) ) @@ -1218,16 +1218,16 @@ class TaskQueue: """取消任务""" if self.use_celery: try: - self.celery_app.control.revoke(task_id, terminate = True) + self.celery_app.control.revoke(task_id, terminate=True) return True except Exception as e: print(f"取消 Celery 任务失败: {e}") with self.task_lock: - task = self.tasks.get(task_id) + task = self.tasks.get(task_id) if task and task.status in ["pending", "running"]: - task.status = "cancelled" - task.completed_at = datetime.now().isoformat() + task.status = "cancelled" + task.completed_at = datetime.now().isoformat() self._update_task_status(task) return True @@ -1235,44 +1235,44 @@ class TaskQueue: def retry(self, task_id: str) -> bool: """重试失败的任务""" - task = self.get_status(task_id) + task = self.get_status(task_id) if not task or task.status != "failed": return False - task.status = "pending" - task.retry_count = 0 - task.error_message = None - task.completed_at = None + task.status = "pending" + task.retry_count = 0 + task.error_message = None + task.completed_at = None if not self.use_celery: with self.task_lock: - self.tasks[task_id] = task - threading.Thread(target = self._execute_task, args = (task_id, ), daemon = True).start() + self.tasks[task_id] = task + threading.Thread(target=self._execute_task, args=(task_id, ), daemon=True).start() self._update_task_status(task) return True def get_stats(self) -> dict: """获取任务队列统计""" - conn = sqlite3.connect(self.db_path) + conn = sqlite3.connect(self.db_path) # 各状态任务数量 - status_counts = conn.execute(""" + status_counts = conn.execute(""" SELECT status, COUNT(*) as count FROM task_queue GROUP BY status """).fetchall() # 各类型任务数量 - type_counts = conn.execute(""" + type_counts = conn.execute(""" SELECT task_type, COUNT(*) as count FROM task_queue GROUP BY task_type """).fetchall() # 最近24小时任务数 - recent_count = conn.execute(""" + recent_count = conn.execute(""" SELECT COUNT(*) as count FROM task_queue WHERE created_at > datetime('now', '-1 day') @@ -1304,28 +1304,28 @@ class PerformanceMonitor: def __init__( self, - db_path: str = "insightflow.db", - slow_query_threshold: int = 1000, - alert_threshold: int = 5000, # 毫秒 + db_path: str = "insightflow.db", + slow_query_threshold: int = 1000, + alert_threshold: int = 5000, # 毫秒 ) -> None: # 毫秒 - self.db_path = db_path - self.slow_query_threshold = slow_query_threshold - self.alert_threshold = alert_threshold + self.db_path = db_path + self.slow_query_threshold = slow_query_threshold + self.alert_threshold = alert_threshold # 内存中的指标缓存 - self.metrics_buffer: list[PerformanceMetric] = [] - self.buffer_lock = threading.RLock() - self.buffer_size = 100 + self.metrics_buffer: list[PerformanceMetric] = [] + self.buffer_lock = threading.RLock() + self.buffer_size = 100 # 告警回调 - self.alert_handlers: list[Callable] = [] + self.alert_handlers: list[Callable] = [] def record_metric( self, metric_type: str, duration_ms: float, - endpoint: str | None = None, - metadata: dict | None = None, + endpoint: str | None = None, + metadata: dict | None = None, ) -> None: """ 记录性能指标 @@ -1336,13 +1336,13 @@ class PerformanceMonitor: endpoint: 端点/查询标识 metadata: 额外元数据 """ - metric = PerformanceMetric( - id = str(uuid.uuid4())[:16], - metric_type = metric_type, - endpoint = endpoint, - duration_ms = duration_ms, - timestamp = datetime.now().isoformat(), - metadata = metadata or {}, + metric = PerformanceMetric( + id=str(uuid.uuid4())[:16], + metric_type=metric_type, + endpoint=endpoint, + duration_ms=duration_ms, + timestamp=datetime.now().isoformat(), + metadata=metadata or {}, ) # 添加到缓冲区 @@ -1364,7 +1364,7 @@ class PerformanceMonitor: if not self.metrics_buffer: return - conn = sqlite3.connect(self.db_path) + conn = sqlite3.connect(self.db_path) for metric in self.metrics_buffer: conn.execute( @@ -1379,14 +1379,14 @@ class PerformanceMonitor: metric.endpoint, metric.duration_ms, metric.timestamp, - json.dumps(metric.metadata, ensure_ascii = False), + json.dumps(metric.metadata, ensure_ascii=False), ), ) conn.commit() conn.close() - self.metrics_buffer = [] + self.metrics_buffer = [] def _record_slow_query(self, metric: PerformanceMetric) -> None: """记录慢查询""" @@ -1395,7 +1395,7 @@ class PerformanceMonitor: def _trigger_alert(self, metric: PerformanceMetric) -> None: """触发告警""" - alert_data = { + alert_data = { "type": "performance_alert", "metric": metric.to_dict(), "threshold": self.alert_threshold, @@ -1412,7 +1412,7 @@ class PerformanceMonitor: """注册告警处理器""" self.alert_handlers.append(handler) - def get_stats(self, hours: int = 24) -> dict: + def get_stats(self, hours: int = 24) -> dict: """ 获取性能统计 @@ -1425,11 +1425,11 @@ class PerformanceMonitor: # 先刷新缓冲区 self._flush_metrics() - conn = sqlite3.connect(self.db_path) - conn.row_factory = sqlite3.Row + conn = sqlite3.connect(self.db_path) + conn.row_factory = sqlite3.Row # 总体统计 - overall = conn.execute( + overall = conn.execute( """ SELECT COUNT(*) as total, @@ -1443,7 +1443,7 @@ class PerformanceMonitor: ).fetchone() # 按类型统计 - by_type = conn.execute( + by_type = conn.execute( """ SELECT metric_type, @@ -1458,7 +1458,7 @@ class PerformanceMonitor: ).fetchall() # 按端点统计(API) - by_endpoint = conn.execute( + by_endpoint = conn.execute( """ SELECT endpoint, @@ -1476,7 +1476,7 @@ class PerformanceMonitor: ).fetchall() # 慢查询统计 - slow_queries = conn.execute( + slow_queries = conn.execute( """ SELECT metric_type, @@ -1531,22 +1531,22 @@ class PerformanceMonitor: ], } - def get_api_performance(self, endpoint: str | None = None, hours: int = 24) -> dict: + def get_api_performance(self, endpoint: str | None = None, hours: int = 24) -> dict: """获取 API 性能详情""" self._flush_metrics() - conn = sqlite3.connect(self.db_path) - conn.row_factory = sqlite3.Row + conn = sqlite3.connect(self.db_path) + conn.row_factory = sqlite3.Row - where_clause = "metric_type = 'api_response'" - params = [f"-{hours} hours"] + where_clause = "metric_type = 'api_response'" + params = [f"-{hours} hours"] if endpoint: where_clause += " AND endpoint = ?" params.append(endpoint) # 百分位数统计 - percentiles = conn.execute( + percentiles = conn.execute( f""" SELECT endpoint, @@ -1580,7 +1580,7 @@ class PerformanceMonitor: ], } - def cleanup_old_metrics(self, days: int = 30) -> int: + def cleanup_old_metrics(self, days: int = 30) -> int: """ 清理旧的性能指标数据 @@ -1590,9 +1590,9 @@ class PerformanceMonitor: Returns: int: 删除的记录数 """ - conn = sqlite3.connect(self.db_path) + conn = sqlite3.connect(self.db_path) - cursor = conn.execute( + cursor = conn.execute( """ DELETE FROM performance_metrics WHERE timestamp < datetime('now', ?) @@ -1600,7 +1600,7 @@ class PerformanceMonitor: (f"-{days} days", ), ) - deleted = cursor.rowcount + deleted = cursor.rowcount conn.commit() conn.close() @@ -1613,9 +1613,9 @@ class PerformanceMonitor: def cached( cache_manager: CacheManager, - key_prefix: str = "", - ttl: int = 3600, - key_func: Callable | None = None, + key_prefix: str = "", + ttl: int = 3600, + key_func: Callable | None = None, ) -> None: """ 缓存装饰器 @@ -1632,19 +1632,19 @@ def cached( def wrapper(*args, **kwargs) -> None: # 生成缓存键 if key_func: - cache_key = key_func(*args, **kwargs) + cache_key = key_func(*args, **kwargs) else: # 默认使用函数名和参数哈希 - key_data = f"{func.__name__}:{str(args)}:{str(kwargs)}" - cache_key = f"{key_prefix}:{hashlib.md5(key_data.encode()).hexdigest()[:16]}" + key_data = f"{func.__name__}:{str(args)}:{str(kwargs)}" + cache_key = f"{key_prefix}:{hashlib.md5(key_data.encode()).hexdigest()[:16]}" # 尝试从缓存获取 - cached_value = cache_manager.get(cache_key) + cached_value = cache_manager.get(cache_key) if cached_value is not None: return cached_value # 执行函数 - result = func(*args, **kwargs) + result = func(*args, **kwargs) # 写入缓存 cache_manager.set(cache_key, result, ttl) @@ -1656,7 +1656,7 @@ def cached( return decorator -def monitored(monitor: PerformanceMonitor, metric_type: str, endpoint: str | None = None) -> None: +def monitored(monitor: PerformanceMonitor, metric_type: str, endpoint: str | None = None) -> None: """ 性能监控装饰器 @@ -1669,14 +1669,14 @@ def monitored(monitor: PerformanceMonitor, metric_type: str, endpoint: str | Non def decorator(func: Callable) -> Callable: @wraps(func) def wrapper(*args, **kwargs) -> None: - start_time = time.time() + start_time = time.time() try: - result = func(*args, **kwargs) + result = func(*args, **kwargs) return result finally: - duration_ms = (time.time() - start_time) * 1000 - ep = endpoint or func.__name__ + duration_ms = (time.time() - start_time) * 1000 + ep = endpoint or func.__name__ monitor.record_metric(metric_type, duration_ms, ep) return wrapper @@ -1696,20 +1696,20 @@ class PerformanceManager: def __init__( self, - db_path: str = "insightflow.db", - redis_url: str | None = None, - enable_sharding: bool = False, + db_path: str = "insightflow.db", + redis_url: str | None = None, + enable_sharding: bool = False, ) -> None: - self.db_path = db_path + self.db_path = db_path # 初始化各模块 - self.cache = CacheManager(redis_url = redis_url, db_path = db_path) + self.cache = CacheManager(redis_url=redis_url, db_path=db_path) - self.sharding = DatabaseSharding(base_db_path = db_path) if enable_sharding else None + self.sharding = DatabaseSharding(base_db_path=db_path) if enable_sharding else None - self.task_queue = TaskQueue(redis_url = redis_url, db_path = db_path) + self.task_queue = TaskQueue(redis_url=redis_url, db_path=db_path) - self.monitor = PerformanceMonitor(db_path = db_path) + self.monitor = PerformanceMonitor(db_path=db_path) def get_health_status(self) -> dict: """获取系统健康状态""" @@ -1737,29 +1737,29 @@ class PerformanceManager: def get_full_stats(self) -> dict: """获取完整统计信息""" - stats = { + stats = { "cache": self.cache.get_stats(), "task_queue": self.task_queue.get_stats(), "performance": self.monitor.get_stats(), } if self.sharding: - stats["sharding"] = self.sharding.get_shard_stats() + stats["sharding"] = self.sharding.get_shard_stats() return stats # 单例模式 -_performance_manager = None +_performance_manager = None def get_performance_manager( - db_path: str = "insightflow.db", redis_url: str | None = None, enable_sharding: bool = False + db_path: str = "insightflow.db", redis_url: str | None = None, enable_sharding: bool = False ) -> PerformanceManager: """获取性能管理器单例""" global _performance_manager if _performance_manager is None: - _performance_manager = PerformanceManager( - db_path = db_path, redis_url = redis_url, enable_sharding = enable_sharding + _performance_manager = PerformanceManager( + db_path=db_path, redis_url=redis_url, enable_sharding=enable_sharding ) return _performance_manager diff --git a/backend/plugin_manager.py b/backend/plugin_manager.py index 64e8375..c6da345 100644 --- a/backend/plugin_manager.py +++ b/backend/plugin_manager.py @@ -18,40 +18,39 @@ from enum import Enum from typing import Any import httpx -from plugin_manager import PluginManager import urllib.parse # Constants -UUID_LENGTH = 8 # UUID 截断长度 +UUID_LENGTH = 8 # UUID 截断长度 # WebDAV 支持 try: import webdav4.client as webdav_client - WEBDAV_AVAILABLE = True + WEBDAV_AVAILABLE = True except ImportError: - WEBDAV_AVAILABLE = False + WEBDAV_AVAILABLE = False class PluginType(Enum): """插件类型""" - CHROME_EXTENSION = "chrome_extension" - FEISHU_BOT = "feishu_bot" - DINGTALK_BOT = "dingtalk_bot" - ZAPIER = "zapier" - MAKE = "make" - WEBDAV = "webdav" - CUSTOM = "custom" + CHROME_EXTENSION = "chrome_extension" + FEISHU_BOT = "feishu_bot" + DINGTALK_BOT = "dingtalk_bot" + ZAPIER = "zapier" + MAKE = "make" + WEBDAV = "webdav" + CUSTOM = "custom" class PluginStatus(Enum): """插件状态""" - ACTIVE = "active" - INACTIVE = "inactive" - ERROR = "error" - PENDING = "pending" + ACTIVE = "active" + INACTIVE = "inactive" + ERROR = "error" + PENDING = "pending" @dataclass @@ -62,12 +61,12 @@ class Plugin: name: str plugin_type: str project_id: str - status: str = "active" - config: dict = field(default_factory = dict) - created_at: str = "" - updated_at: str = "" - last_used_at: str | None = None - use_count: int = 0 + status: str = "active" + config: dict = field(default_factory=dict) + created_at: str = "" + updated_at: str = "" + last_used_at: str | None = None + use_count: int = 0 @dataclass @@ -78,9 +77,9 @@ class PluginConfig: plugin_id: str config_key: str config_value: str - is_encrypted: bool = False - created_at: str = "" - updated_at: str = "" + is_encrypted: bool = False + created_at: str = "" + updated_at: str = "" @dataclass @@ -91,14 +90,14 @@ class BotSession: bot_type: str # feishu, dingtalk session_id: str # 群ID或会话ID session_name: str - project_id: str | None = None - webhook_url: str = "" - secret: str = "" - is_active: bool = True - created_at: str = "" - updated_at: str = "" - last_message_at: str | None = None - message_count: int = 0 + project_id: str | None = None + webhook_url: str = "" + secret: str = "" + is_active: bool = True + created_at: str = "" + updated_at: str = "" + last_message_at: str | None = None + message_count: int = 0 @dataclass @@ -109,15 +108,15 @@ class WebhookEndpoint: name: str endpoint_type: str # zapier, make, custom endpoint_url: str - project_id: str | None = None - auth_type: str = "none" # none, api_key, oauth, custom - auth_config: dict = field(default_factory = dict) - trigger_events: list[str] = field(default_factory = list) - is_active: bool = True - created_at: str = "" - updated_at: str = "" - last_triggered_at: str | None = None - trigger_count: int = 0 + project_id: str | None = None + auth_type: str = "none" # none, api_key, oauth, custom + auth_config: dict = field(default_factory=dict) + trigger_events: list[str] = field(default_factory=list) + is_active: bool = True + created_at: str = "" + updated_at: str = "" + last_triggered_at: str | None = None + trigger_count: int = 0 @dataclass @@ -129,17 +128,17 @@ class WebDAVSync: project_id: str server_url: str username: str - password: str = "" # 加密存储 - remote_path: str = "/insightflow" - sync_mode: str = "bidirectional" # bidirectional, upload_only, download_only - sync_interval: int = 3600 # 秒 - last_sync_at: str | None = None - last_sync_status: str = "pending" # pending, success, failed - last_sync_error: str = "" - is_active: bool = True - created_at: str = "" - updated_at: str = "" - sync_count: int = 0 + password: str = "" # 加密存储 + remote_path: str = "/insightflow" + sync_mode: str = "bidirectional" # bidirectional, upload_only, download_only + sync_interval: int = 3600 # 秒 + last_sync_at: str | None = None + last_sync_status: str = "pending" # pending, success, failed + last_sync_error: str = "" + is_active: bool = True + created_at: str = "" + updated_at: str = "" + sync_count: int = 0 @dataclass @@ -148,33 +147,33 @@ class ChromeExtensionToken: id: str token: str - user_id: str | None = None - project_id: str | None = None - name: str = "" - permissions: list[str] = field(default_factory = lambda: ["read", "write"]) - expires_at: str | None = None - created_at: str = "" - last_used_at: str | None = None - use_count: int = 0 - is_revoked: bool = False + user_id: str | None = None + project_id: str | None = None + name: str = "" + permissions: list[str] = field(default_factory=lambda: ["read", "write"]) + expires_at: str | None = None + created_at: str = "" + last_used_at: str | None = None + use_count: int = 0 + is_revoked: bool = False class PluginManager: """插件管理主类""" - def __init__(self, db_manager = None) -> None: - self.db = db_manager - self._handlers = {} + def __init__(self, db_manager=None) -> None: + self.db = db_manager + self._handlers = {} self._register_default_handlers() def _register_default_handlers(self) -> None: """注册默认处理器""" - self._handlers[PluginType.CHROME_EXTENSION] = ChromeExtensionHandler(self) - self._handlers[PluginType.FEISHU_BOT] = BotHandler(self, "feishu") - self._handlers[PluginType.DINGTALK_BOT] = BotHandler(self, "dingtalk") - self._handlers[PluginType.ZAPIER] = WebhookIntegration(self, "zapier") - self._handlers[PluginType.MAKE] = WebhookIntegration(self, "make") - self._handlers[PluginType.WEBDAV] = WebDAVSyncManager(self) + self._handlers[PluginType.CHROME_EXTENSION] = ChromeExtensionHandler(self) + self._handlers[PluginType.FEISHU_BOT] = BotHandler(self, "feishu") + self._handlers[PluginType.DINGTALK_BOT] = BotHandler(self, "dingtalk") + self._handlers[PluginType.ZAPIER] = WebhookIntegration(self, "zapier") + self._handlers[PluginType.MAKE] = WebhookIntegration(self, "make") + self._handlers[PluginType.WEBDAV] = WebDAVSyncManager(self) def get_handler(self, plugin_type: PluginType) -> Any | None: """获取插件处理器""" @@ -184,8 +183,8 @@ class PluginManager: def create_plugin(self, plugin: Plugin) -> Plugin: """创建插件""" - conn = self.db.get_conn() - now = datetime.now().isoformat() + conn = self.db.get_conn() + now = datetime.now().isoformat() conn.execute( """INSERT INTO plugins @@ -206,14 +205,14 @@ class PluginManager: conn.commit() conn.close() - plugin.created_at = now - plugin.updated_at = now + plugin.created_at = now + plugin.updated_at = now return plugin def get_plugin(self, plugin_id: str) -> Plugin | None: """获取插件""" - conn = self.db.get_conn() - row = conn.execute("SELECT * FROM plugins WHERE id = ?", (plugin_id, )).fetchone() + conn = self.db.get_conn() + row = conn.execute("SELECT * FROM plugins WHERE id = ?", (plugin_id, )).fetchone() conn.close() if row: @@ -221,13 +220,13 @@ class PluginManager: return None def list_plugins( - self, project_id: str = None, plugin_type: str = None, status: str = None + self, project_id: str = None, plugin_type: str = None, status: str = None ) -> list[Plugin]: """列出插件""" - conn = self.db.get_conn() + conn = self.db.get_conn() - conditions = [] - params = [] + conditions = [] + params = [] if project_id: conditions.append("project_id = ?") @@ -239,9 +238,9 @@ class PluginManager: conditions.append("status = ?") params.append(status) - where_clause = " AND ".join(conditions) if conditions else "1 = 1" + where_clause = " AND ".join(conditions) if conditions else "1 = 1" - rows = conn.execute( + rows = conn.execute( f"SELECT * FROM plugins WHERE {where_clause} ORDER BY created_at DESC", params ).fetchall() conn.close() @@ -250,11 +249,11 @@ class PluginManager: def update_plugin(self, plugin_id: str, **kwargs) -> Plugin | None: """更新插件""" - conn = self.db.get_conn() + conn = self.db.get_conn() - allowed_fields = ["name", "status", "config"] - updates = [] - values = [] + allowed_fields = ["name", "status", "config"] + updates = [] + values = [] for f in allowed_fields: if f in kwargs: @@ -272,7 +271,7 @@ class PluginManager: values.append(datetime.now().isoformat()) values.append(plugin_id) - query = f"UPDATE plugins SET {', '.join(updates)} WHERE id = ?" + query = f"UPDATE plugins SET {', '.join(updates)} WHERE id = ?" conn.execute(query, values) conn.commit() conn.close() @@ -281,13 +280,13 @@ class PluginManager: def delete_plugin(self, plugin_id: str) -> bool: """删除插件""" - conn = self.db.get_conn() + conn = self.db.get_conn() # 删除关联的配置 conn.execute("DELETE FROM plugin_configs WHERE plugin_id = ?", (plugin_id, )) # 删除插件 - cursor = conn.execute("DELETE FROM plugins WHERE id = ?", (plugin_id, )) + cursor = conn.execute("DELETE FROM plugins WHERE id = ?", (plugin_id, )) conn.commit() conn.close() @@ -296,29 +295,29 @@ class PluginManager: def _row_to_plugin(self, row: sqlite3.Row) -> Plugin: """将数据库行转换为 Plugin 对象""" return Plugin( - id = row["id"], - name = row["name"], - plugin_type = row["plugin_type"], - project_id = row["project_id"], - status = row["status"], - config = json.loads(row["config"]) if row["config"] else {}, - created_at = row["created_at"], - updated_at = row["updated_at"], - last_used_at = row["last_used_at"], - use_count = row["use_count"], + id=row["id"], + name=row["name"], + plugin_type=row["plugin_type"], + project_id=row["project_id"], + status=row["status"], + config=json.loads(row["config"]) if row["config"] else {}, + created_at=row["created_at"], + updated_at=row["updated_at"], + last_used_at=row["last_used_at"], + use_count=row["use_count"], ) # ==================== Plugin Config ==================== def set_plugin_config( - self, plugin_id: str, key: str, value: str, is_encrypted: bool = False + self, plugin_id: str, key: str, value: str, is_encrypted: bool = False ) -> PluginConfig: """设置插件配置""" - conn = self.db.get_conn() - now = datetime.now().isoformat() + conn = self.db.get_conn() + now = datetime.now().isoformat() # 检查是否已存在 - existing = conn.execute( + existing = conn.execute( "SELECT id FROM plugin_configs WHERE plugin_id = ? AND config_key = ?", (plugin_id, key) ).fetchone() @@ -329,9 +328,9 @@ class PluginManager: WHERE id = ?""", (value, is_encrypted, now, existing["id"]), ) - config_id = existing["id"] + config_id = existing["id"] else: - config_id = str(uuid.uuid4())[:UUID_LENGTH] + config_id = str(uuid.uuid4())[:UUID_LENGTH] conn.execute( """INSERT INTO plugin_configs (id, plugin_id, config_key, config_value, is_encrypted, created_at, updated_at) @@ -343,19 +342,19 @@ class PluginManager: conn.close() return PluginConfig( - id = config_id, - plugin_id = plugin_id, - config_key = key, - config_value = value, - is_encrypted = is_encrypted, - created_at = now, - updated_at = now, + id=config_id, + plugin_id=plugin_id, + config_key=key, + config_value=value, + is_encrypted=is_encrypted, + created_at=now, + updated_at=now, ) def get_plugin_config(self, plugin_id: str, key: str) -> str | None: """获取插件配置""" - conn = self.db.get_conn() - row = conn.execute( + conn = self.db.get_conn() + row = conn.execute( "SELECT config_value FROM plugin_configs WHERE plugin_id = ? AND config_key = ?", (plugin_id, key), ).fetchone() @@ -365,8 +364,8 @@ class PluginManager: def get_all_plugin_configs(self, plugin_id: str) -> dict[str, str]: """获取插件所有配置""" - conn = self.db.get_conn() - rows = conn.execute( + conn = self.db.get_conn() + rows = conn.execute( "SELECT config_key, config_value FROM plugin_configs WHERE plugin_id = ?", (plugin_id, ) ).fetchall() conn.close() @@ -375,8 +374,8 @@ class PluginManager: def delete_plugin_config(self, plugin_id: str, key: str) -> bool: """删除插件配置""" - conn = self.db.get_conn() - cursor = conn.execute( + conn = self.db.get_conn() + cursor = conn.execute( "DELETE FROM plugin_configs WHERE plugin_id = ? AND config_key = ?", (plugin_id, key) ) conn.commit() @@ -386,8 +385,8 @@ class PluginManager: def record_plugin_usage(self, plugin_id: str) -> None: """记录插件使用""" - conn = self.db.get_conn() - now = datetime.now().isoformat() + conn = self.db.get_conn() + now = datetime.now().isoformat() conn.execute( """UPDATE plugins @@ -403,33 +402,33 @@ class ChromeExtensionHandler: """Chrome 扩展处理器""" def __init__(self, plugin_manager: PluginManager) -> None: - self.pm = plugin_manager + self.pm = plugin_manager def create_token( self, name: str, - user_id: str = None, - project_id: str = None, - permissions: list[str] = None, - expires_days: int = None, + user_id: str = None, + project_id: str = None, + permissions: list[str] = None, + expires_days: int = None, ) -> ChromeExtensionToken: """创建 Chrome 扩展令牌""" - token_id = str(uuid.uuid4())[:UUID_LENGTH] + token_id = str(uuid.uuid4())[:UUID_LENGTH] # 生成随机令牌 - raw_token = f"if_ext_{base64.urlsafe_b64encode(os.urandom(32)).decode('utf-8').rstrip(' = ')}" + raw_token = f"if_ext_{base64.urlsafe_b64encode(os.urandom(32)).decode('utf-8').rstrip(' = ')}" # 哈希存储 - token_hash = hashlib.sha256(raw_token.encode()).hexdigest() + token_hash = hashlib.sha256(raw_token.encode()).hexdigest() - now = datetime.now().isoformat() - expires_at = None + now = datetime.now().isoformat() + expires_at = None if expires_days: from datetime import timedelta - expires_at = (datetime.now() + timedelta(days = expires_days)).isoformat() + expires_at = (datetime.now() + timedelta(days=expires_days)).isoformat() - conn = self.pm.db.get_conn() + conn = self.pm.db.get_conn() conn.execute( """INSERT INTO chrome_extension_tokens (id, token_hash, user_id, project_id, name, permissions, expires_at, @@ -452,22 +451,22 @@ class ChromeExtensionHandler: conn.close() return ChromeExtensionToken( - id = token_id, - token = raw_token, # 仅返回一次 - user_id = user_id, - project_id = project_id, - name = name, - permissions = permissions or ["read"], - expires_at = expires_at, - created_at = now, + id=token_id, + token=raw_token, # 仅返回一次 + user_id=user_id, + project_id=project_id, + name=name, + permissions=permissions or ["read"], + expires_at=expires_at, + created_at=now, ) def validate_token(self, token: str) -> ChromeExtensionToken | None: """验证 Chrome 扩展令牌""" - token_hash = hashlib.sha256(token.encode()).hexdigest() + token_hash = hashlib.sha256(token.encode()).hexdigest() - conn = self.pm.db.get_conn() - row = conn.execute( + conn = self.pm.db.get_conn() + row = conn.execute( """SELECT * FROM chrome_extension_tokens WHERE token_hash = ? AND is_revoked = 0""", (token_hash, ), @@ -482,8 +481,8 @@ class ChromeExtensionHandler: return None # 更新使用记录 - now = datetime.now().isoformat() - conn = self.pm.db.get_conn() + now = datetime.now().isoformat() + conn = self.pm.db.get_conn() conn.execute( """UPDATE chrome_extension_tokens SET use_count = use_count + 1, last_used_at = ? @@ -494,22 +493,22 @@ class ChromeExtensionHandler: conn.close() return ChromeExtensionToken( - id = row["id"], - token = "", # 不返回实际令牌 - user_id = row["user_id"], - project_id = row["project_id"], - name = row["name"], - permissions = json.loads(row["permissions"]), - expires_at = row["expires_at"], - created_at = row["created_at"], - last_used_at = now, - use_count = row["use_count"] + 1, + id=row["id"], + token="", # 不返回实际令牌 + user_id=row["user_id"], + project_id=row["project_id"], + name=row["name"], + permissions=json.loads(row["permissions"]), + expires_at=row["expires_at"], + created_at=row["created_at"], + last_used_at=now, + use_count=row["use_count"] + 1, ) def revoke_token(self, token_id: str) -> bool: """撤销令牌""" - conn = self.pm.db.get_conn() - cursor = conn.execute( + conn = self.pm.db.get_conn() + cursor = conn.execute( "UPDATE chrome_extension_tokens SET is_revoked = 1 WHERE id = ?", (token_id, ) ) conn.commit() @@ -518,13 +517,13 @@ class ChromeExtensionHandler: return cursor.rowcount > 0 def list_tokens( - self, user_id: str = None, project_id: str = None + self, user_id: str = None, project_id: str = None ) -> list[ChromeExtensionToken]: """列出令牌""" - conn = self.pm.db.get_conn() + conn = self.pm.db.get_conn() - conditions = ["is_revoked = 0"] - params = [] + conditions = ["is_revoked = 0"] + params = [] if user_id: conditions.append("user_id = ?") @@ -533,29 +532,29 @@ class ChromeExtensionHandler: conditions.append("project_id = ?") params.append(project_id) - where_clause = " AND ".join(conditions) + where_clause = " AND ".join(conditions) - rows = conn.execute( + rows = conn.execute( f"SELECT * FROM chrome_extension_tokens WHERE {where_clause} ORDER BY created_at DESC", params, ).fetchall() conn.close() - tokens = [] + tokens = [] for row in rows: tokens.append( ChromeExtensionToken( - id = row["id"], - token = "", # 不返回实际令牌 - user_id = row["user_id"], - project_id = row["project_id"], - name = row["name"], - permissions = json.loads(row["permissions"]), - expires_at = row["expires_at"], - created_at = row["created_at"], - last_used_at = row["last_used_at"], - use_count = row["use_count"], - is_revoked = bool(row["is_revoked"]), + id=row["id"], + token="", # 不返回实际令牌 + user_id=row["user_id"], + project_id=row["project_id"], + name=row["name"], + permissions=json.loads(row["permissions"]), + expires_at=row["expires_at"], + created_at=row["created_at"], + last_used_at=row["last_used_at"], + use_count=row["use_count"], + is_revoked=bool(row["is_revoked"]), ) ) @@ -567,7 +566,7 @@ class ChromeExtensionHandler: url: str, title: str, content: str, - html_content: str = None, + html_content: str = None, ) -> dict: """导入网页内容""" if not token.project_id: @@ -577,13 +576,13 @@ class ChromeExtensionHandler: return {"success": False, "error": "Insufficient permissions"} # 创建转录记录(将网页作为文档处理) - transcript_id = str(uuid.uuid4())[:UUID_LENGTH] - now = datetime.now().isoformat() + transcript_id = str(uuid.uuid4())[:UUID_LENGTH] + now = datetime.now().isoformat() # 构建完整文本 - full_text = f"# {title}\n\nURL: {url}\n\n{content}" + full_text = f"# {title}\n\nURL: {url}\n\n{content}" - conn = self.pm.db.get_conn() + conn = self.pm.db.get_conn() conn.execute( """INSERT INTO transcripts (id, project_id, filename, full_text, type, created_at) @@ -607,22 +606,22 @@ class BotHandler: """飞书/钉钉机器人处理器""" def __init__(self, plugin_manager: PluginManager, bot_type: str) -> None: - self.pm = plugin_manager - self.bot_type = bot_type + self.pm = plugin_manager + self.bot_type = bot_type def create_session( self, session_id: str, session_name: str, - project_id: str = None, - webhook_url: str = "", - secret: str = "", + project_id: str = None, + webhook_url: str = "", + secret: str = "", ) -> BotSession: """创建机器人会话""" - bot_id = str(uuid.uuid4())[:UUID_LENGTH] - now = datetime.now().isoformat() + bot_id = str(uuid.uuid4())[:UUID_LENGTH] + now = datetime.now().isoformat() - conn = self.pm.db.get_conn() + conn = self.pm.db.get_conn() conn.execute( """INSERT INTO bot_sessions (id, bot_type, session_id, session_name, project_id, webhook_url, secret, @@ -646,22 +645,22 @@ class BotHandler: conn.close() return BotSession( - id = bot_id, - bot_type = self.bot_type, - session_id = session_id, - session_name = session_name, - project_id = project_id, - webhook_url = webhook_url, - secret = secret, - is_active = True, - created_at = now, - updated_at = now, + id=bot_id, + bot_type=self.bot_type, + session_id=session_id, + session_name=session_name, + project_id=project_id, + webhook_url=webhook_url, + secret=secret, + is_active=True, + created_at=now, + updated_at=now, ) def get_session(self, session_id: str) -> BotSession | None: """获取会话""" - conn = self.pm.db.get_conn() - row = conn.execute( + conn = self.pm.db.get_conn() + row = conn.execute( """SELECT * FROM bot_sessions WHERE session_id = ? AND bot_type = ?""", (session_id, self.bot_type), @@ -672,18 +671,18 @@ class BotHandler: return self._row_to_session(row) return None - def list_sessions(self, project_id: str = None) -> list[BotSession]: + def list_sessions(self, project_id: str = None) -> list[BotSession]: """列出会话""" - conn = self.pm.db.get_conn() + conn = self.pm.db.get_conn() if project_id: - rows = conn.execute( + rows = conn.execute( """SELECT * FROM bot_sessions WHERE bot_type = ? AND project_id = ? ORDER BY created_at DESC""", (self.bot_type, project_id), ).fetchall() else: - rows = conn.execute( + rows = conn.execute( """SELECT * FROM bot_sessions WHERE bot_type = ? ORDER BY created_at DESC""", (self.bot_type, ), @@ -695,11 +694,11 @@ class BotHandler: def update_session(self, session_id: str, **kwargs) -> BotSession | None: """更新会话""" - conn = self.pm.db.get_conn() + conn = self.pm.db.get_conn() - allowed_fields = ["session_name", "project_id", "webhook_url", "secret", "is_active"] - updates = [] - values = [] + allowed_fields = ["session_name", "project_id", "webhook_url", "secret", "is_active"] + updates = [] + values = [] for f in allowed_fields: if f in kwargs: @@ -715,7 +714,7 @@ class BotHandler: values.append(session_id) values.append(self.bot_type) - query = ( + query = ( f"UPDATE bot_sessions SET {', '.join(updates)} WHERE session_id = ? AND bot_type = ?" ) conn.execute(query, values) @@ -726,8 +725,8 @@ class BotHandler: def delete_session(self, session_id: str) -> bool: """删除会话""" - conn = self.pm.db.get_conn() - cursor = conn.execute( + conn = self.pm.db.get_conn() + cursor = conn.execute( "DELETE FROM bot_sessions WHERE session_id = ? AND bot_type = ?", (session_id, self.bot_type), ) @@ -739,26 +738,26 @@ class BotHandler: def _row_to_session(self, row: sqlite3.Row) -> BotSession: """将数据库行转换为 BotSession 对象""" return BotSession( - id = row["id"], - bot_type = row["bot_type"], - session_id = row["session_id"], - session_name = row["session_name"], - project_id = row["project_id"], - webhook_url = row["webhook_url"], - secret = row["secret"], - is_active = bool(row["is_active"]), - created_at = row["created_at"], - updated_at = row["updated_at"], - last_message_at = row["last_message_at"], - message_count = row["message_count"], + id=row["id"], + bot_type=row["bot_type"], + session_id=row["session_id"], + session_name=row["session_name"], + project_id=row["project_id"], + webhook_url=row["webhook_url"], + secret=row["secret"], + is_active=bool(row["is_active"]), + created_at=row["created_at"], + updated_at=row["updated_at"], + last_message_at=row["last_message_at"], + message_count=row["message_count"], ) async def handle_message(self, session: BotSession, message: dict) -> dict: """处理收到的消息""" - now = datetime.now().isoformat() + now = datetime.now().isoformat() # 更新消息统计 - conn = self.pm.db.get_conn() + conn = self.pm.db.get_conn() conn.execute( """UPDATE bot_sessions SET message_count = message_count + 1, last_message_at = ? @@ -769,11 +768,11 @@ class BotHandler: conn.close() # 处理消息 - msg_type = message.get("msg_type", "text") - content = message.get("content", {}) + msg_type = message.get("msg_type", "text") + content = message.get("content", {}) if msg_type == "text": - text = content.get("text", "") + text = content.get("text", "") return await self._handle_text_message(session, text, message) elif msg_type == "audio": # 处理音频消息 @@ -802,8 +801,8 @@ class BotHandler: return {"success": True, "response": "⚠️ 当前会话未绑定项目"} # 获取项目状态 - summary = self.pm.db.get_project_summary(session.project_id) - stats = summary.get("statistics", {}) + summary = self.pm.db.get_project_summary(session.project_id) + stats = summary.get("statistics", {}) return { "success": True, @@ -825,17 +824,17 @@ class BotHandler: return {"success": False, "error": "Session not bound to any project"} # 下载音频文件 - audio_url = message.get("content", {}).get("download_url") + audio_url = message.get("content", {}).get("download_url") if not audio_url: return {"success": False, "error": "No audio URL provided"} try: async with httpx.AsyncClient() as client: - response = await client.get(audio_url) - audio_data = response.content + response = await client.get(audio_url) + audio_data = response.content # 保存音频文件 - filename = f"bot_audio_{datetime.now().strftime('%Y%m%d_%H%M%S')}.mp3" + filename = f"bot_audio_{datetime.now().strftime('%Y%m%d_%H%M%S')}.mp3" # 这里应该调用 ASR 服务进行转录 # 简化处理,返回提示 @@ -853,7 +852,7 @@ class BotHandler: """处理文件消息""" return {"success": True, "response": "📎 收到文件,正在处理中..."} - async def send_message(self, session: BotSession, message: str, msg_type: str = "text") -> bool: + async def send_message(self, session: BotSession, message: str, msg_type: str = "text") -> bool: """发送消息到群聊""" if not session.webhook_url: return False @@ -872,21 +871,21 @@ class BotHandler: async def _send_feishu_message(self, session: BotSession, message: str, msg_type: str) -> bool: """发送飞书消息""" - timestamp = str(int(time.time())) + timestamp = str(int(time.time())) # 生成签名 if session.secret: - string_to_sign = f"{timestamp}\n{session.secret}" - hmac_code = hmac.new( + string_to_sign = f"{timestamp}\n{session.secret}" + hmac_code = hmac.new( session.secret.encode("utf-8"), string_to_sign.encode("utf-8"), - digestmod = hashlib.sha256, + digestmod=hashlib.sha256, ).digest() - sign = base64.b64encode(hmac_code).decode("utf-8") + sign = base64.b64encode(hmac_code).decode("utf-8") else: - sign = "" + sign = "" - payload = { + payload = { "timestamp": timestamp, "sign": sign, "msg_type": "text", @@ -894,8 +893,8 @@ class BotHandler: } async with httpx.AsyncClient() as client: - response = await client.post( - session.webhook_url, json = payload, headers = {"Content-Type": "application/json"} + response = await client.post( + session.webhook_url, json=payload, headers={"Content-Type": "application/json"} ) return response.status_code == 200 @@ -903,30 +902,30 @@ class BotHandler: self, session: BotSession, message: str, msg_type: str ) -> bool: """发送钉钉消息""" - timestamp = str(round(time.time() * 1000)) + timestamp = str(round(time.time() * 1000)) # 生成签名 if session.secret: - string_to_sign = f"{timestamp}\n{session.secret}" - hmac_code = hmac.new( + string_to_sign = f"{timestamp}\n{session.secret}" + hmac_code = hmac.new( session.secret.encode("utf-8"), string_to_sign.encode("utf-8"), - digestmod = hashlib.sha256, + digestmod=hashlib.sha256, ).digest() - sign = base64.b64encode(hmac_code).decode("utf-8") - sign = urllib.parse.quote(sign) + sign = base64.b64encode(hmac_code).decode("utf-8") + sign = urllib.parse.quote(sign) else: - sign = "" + sign = "" - payload = {"msgtype": "text", "text": {"content": message}} + payload = {"msgtype": "text", "text": {"content": message}} - url = session.webhook_url + url = session.webhook_url if sign: - url = f"{url}×tamp = {timestamp}&sign = {sign}" + url = f"{url}×tamp = {timestamp}&sign = {sign}" async with httpx.AsyncClient() as client: - response = await client.post( - url, json = payload, headers = {"Content-Type": "application/json"} + response = await client.post( + url, json=payload, headers={"Content-Type": "application/json"} ) return response.status_code == 200 @@ -935,23 +934,23 @@ class WebhookIntegration: """Zapier/Make Webhook 集成""" def __init__(self, plugin_manager: PluginManager, endpoint_type: str) -> None: - self.pm = plugin_manager - self.endpoint_type = endpoint_type + self.pm = plugin_manager + self.endpoint_type = endpoint_type def create_endpoint( self, name: str, endpoint_url: str, - project_id: str = None, - auth_type: str = "none", - auth_config: dict = None, - trigger_events: list[str] = None, + project_id: str = None, + auth_type: str = "none", + auth_config: dict = None, + trigger_events: list[str] = None, ) -> WebhookEndpoint: """创建 Webhook 端点""" - endpoint_id = str(uuid.uuid4())[:UUID_LENGTH] - now = datetime.now().isoformat() + endpoint_id = str(uuid.uuid4())[:UUID_LENGTH] + now = datetime.now().isoformat() - conn = self.pm.db.get_conn() + conn = self.pm.db.get_conn() conn.execute( """INSERT INTO webhook_endpoints (id, name, endpoint_type, endpoint_url, project_id, auth_type, auth_config, @@ -976,23 +975,23 @@ class WebhookIntegration: conn.close() return WebhookEndpoint( - id = endpoint_id, - name = name, - endpoint_type = self.endpoint_type, - endpoint_url = endpoint_url, - project_id = project_id, - auth_type = auth_type, - auth_config = auth_config or {}, - trigger_events = trigger_events or [], - is_active = True, - created_at = now, - updated_at = now, + id=endpoint_id, + name=name, + endpoint_type=self.endpoint_type, + endpoint_url=endpoint_url, + project_id=project_id, + auth_type=auth_type, + auth_config=auth_config or {}, + trigger_events=trigger_events or [], + is_active=True, + created_at=now, + updated_at=now, ) def get_endpoint(self, endpoint_id: str) -> WebhookEndpoint | None: """获取端点""" - conn = self.pm.db.get_conn() - row = conn.execute( + conn = self.pm.db.get_conn() + row = conn.execute( "SELECT * FROM webhook_endpoints WHERE id = ? AND endpoint_type = ?", (endpoint_id, self.endpoint_type), ).fetchone() @@ -1002,18 +1001,18 @@ class WebhookIntegration: return self._row_to_endpoint(row) return None - def list_endpoints(self, project_id: str = None) -> list[WebhookEndpoint]: + def list_endpoints(self, project_id: str = None) -> list[WebhookEndpoint]: """列出端点""" - conn = self.pm.db.get_conn() + conn = self.pm.db.get_conn() if project_id: - rows = conn.execute( + rows = conn.execute( """SELECT * FROM webhook_endpoints WHERE endpoint_type = ? AND project_id = ? ORDER BY created_at DESC""", (self.endpoint_type, project_id), ).fetchall() else: - rows = conn.execute( + rows = conn.execute( """SELECT * FROM webhook_endpoints WHERE endpoint_type = ? ORDER BY created_at DESC""", (self.endpoint_type, ), @@ -1025,9 +1024,9 @@ class WebhookIntegration: def update_endpoint(self, endpoint_id: str, **kwargs) -> WebhookEndpoint | None: """更新端点""" - conn = self.pm.db.get_conn() + conn = self.pm.db.get_conn() - allowed_fields = [ + allowed_fields = [ "name", "endpoint_url", "project_id", @@ -1036,8 +1035,8 @@ class WebhookIntegration: "trigger_events", "is_active", ] - updates = [] - values = [] + updates = [] + values = [] for f in allowed_fields: if f in kwargs: @@ -1055,7 +1054,7 @@ class WebhookIntegration: values.append(datetime.now().isoformat()) values.append(endpoint_id) - query = f"UPDATE webhook_endpoints SET {', '.join(updates)} WHERE id = ?" + query = f"UPDATE webhook_endpoints SET {', '.join(updates)} WHERE id = ?" conn.execute(query, values) conn.commit() conn.close() @@ -1064,8 +1063,8 @@ class WebhookIntegration: def delete_endpoint(self, endpoint_id: str) -> bool: """删除端点""" - conn = self.pm.db.get_conn() - cursor = conn.execute("DELETE FROM webhook_endpoints WHERE id = ?", (endpoint_id, )) + conn = self.pm.db.get_conn() + cursor = conn.execute("DELETE FROM webhook_endpoints WHERE id = ?", (endpoint_id, )) conn.commit() conn.close() @@ -1074,19 +1073,19 @@ class WebhookIntegration: def _row_to_endpoint(self, row: sqlite3.Row) -> WebhookEndpoint: """将数据库行转换为 WebhookEndpoint 对象""" return WebhookEndpoint( - id = row["id"], - name = row["name"], - endpoint_type = row["endpoint_type"], - endpoint_url = row["endpoint_url"], - project_id = row["project_id"], - auth_type = row["auth_type"], - auth_config = json.loads(row["auth_config"]) if row["auth_config"] else {}, - trigger_events = json.loads(row["trigger_events"]) if row["trigger_events"] else [], - is_active = bool(row["is_active"]), - created_at = row["created_at"], - updated_at = row["updated_at"], - last_triggered_at = row["last_triggered_at"], - trigger_count = row["trigger_count"], + id=row["id"], + name=row["name"], + endpoint_type=row["endpoint_type"], + endpoint_url=row["endpoint_url"], + project_id=row["project_id"], + auth_type=row["auth_type"], + auth_config=json.loads(row["auth_config"]) if row["auth_config"] else {}, + trigger_events=json.loads(row["trigger_events"]) if row["trigger_events"] else [], + is_active=bool(row["is_active"]), + created_at=row["created_at"], + updated_at=row["updated_at"], + last_triggered_at=row["last_triggered_at"], + trigger_count=row["trigger_count"], ) async def trigger(self, endpoint: WebhookEndpoint, event_type: str, data: dict) -> bool: @@ -1098,29 +1097,29 @@ class WebhookIntegration: return False try: - headers = {"Content-Type": "application/json"} + headers = {"Content-Type": "application/json"} # 添加认证头 if endpoint.auth_type == "api_key": - api_key = endpoint.auth_config.get("api_key", "") - header_name = endpoint.auth_config.get("header_name", "X-API-Key") - headers[header_name] = api_key + api_key = endpoint.auth_config.get("api_key", "") + header_name = endpoint.auth_config.get("header_name", "X-API-Key") + headers[header_name] = api_key elif endpoint.auth_type == "bearer": - token = endpoint.auth_config.get("token", "") - headers["Authorization"] = f"Bearer {token}" + token = endpoint.auth_config.get("token", "") + headers["Authorization"] = f"Bearer {token}" - payload = {"event": event_type, "timestamp": datetime.now().isoformat(), "data": data} + payload = {"event": event_type, "timestamp": datetime.now().isoformat(), "data": data} async with httpx.AsyncClient() as client: - response = await client.post( - endpoint.endpoint_url, json = payload, headers = headers, timeout = 30.0 + response = await client.post( + endpoint.endpoint_url, json=payload, headers=headers, timeout=30.0 ) - success = response.status_code in [200, 201, 202] + success = response.status_code in [200, 201, 202] # 更新触发统计 - now = datetime.now().isoformat() - conn = self.pm.db.get_conn() + now = datetime.now().isoformat() + conn = self.pm.db.get_conn() conn.execute( """UPDATE webhook_endpoints SET trigger_count = trigger_count + 1, last_triggered_at = ? @@ -1138,13 +1137,13 @@ class WebhookIntegration: async def test_endpoint(self, endpoint: WebhookEndpoint) -> dict: """测试端点""" - test_data = { + test_data = { "message": "This is a test event from InsightFlow", "test": True, "timestamp": datetime.now().isoformat(), } - success = await self.trigger(endpoint, "test", test_data) + success = await self.trigger(endpoint, "test", test_data) return { "success": success, @@ -1158,7 +1157,7 @@ class WebDAVSyncManager: """WebDAV 同步管理""" def __init__(self, plugin_manager: PluginManager) -> None: - self.pm = plugin_manager + self.pm = plugin_manager def create_sync( self, @@ -1167,15 +1166,15 @@ class WebDAVSyncManager: server_url: str, username: str, password: str, - remote_path: str = "/insightflow", - sync_mode: str = "bidirectional", - sync_interval: int = 3600, + remote_path: str = "/insightflow", + sync_mode: str = "bidirectional", + sync_interval: int = 3600, ) -> WebDAVSync: """创建 WebDAV 同步配置""" - sync_id = str(uuid.uuid4())[:UUID_LENGTH] - now = datetime.now().isoformat() + sync_id = str(uuid.uuid4())[:UUID_LENGTH] + now = datetime.now().isoformat() - conn = self.pm.db.get_conn() + conn = self.pm.db.get_conn() conn.execute( """INSERT INTO webdav_syncs (id, name, project_id, server_url, username, password, remote_path, @@ -1202,42 +1201,42 @@ class WebDAVSyncManager: conn.close() return WebDAVSync( - id = sync_id, - name = name, - project_id = project_id, - server_url = server_url, - username = username, - password = password, - remote_path = remote_path, - sync_mode = sync_mode, - sync_interval = sync_interval, - last_sync_status = "pending", - is_active = True, - created_at = now, - updated_at = now, + id=sync_id, + name=name, + project_id=project_id, + server_url=server_url, + username=username, + password=password, + remote_path=remote_path, + sync_mode=sync_mode, + sync_interval=sync_interval, + last_sync_status="pending", + is_active=True, + created_at=now, + updated_at=now, ) def get_sync(self, sync_id: str) -> WebDAVSync | None: """获取同步配置""" - conn = self.pm.db.get_conn() - row = conn.execute("SELECT * FROM webdav_syncs WHERE id = ?", (sync_id, )).fetchone() + conn = self.pm.db.get_conn() + row = conn.execute("SELECT * FROM webdav_syncs WHERE id = ?", (sync_id, )).fetchone() conn.close() if row: return self._row_to_sync(row) return None - def list_syncs(self, project_id: str = None) -> list[WebDAVSync]: + def list_syncs(self, project_id: str = None) -> list[WebDAVSync]: """列出同步配置""" - conn = self.pm.db.get_conn() + conn = self.pm.db.get_conn() if project_id: - rows = conn.execute( + rows = conn.execute( "SELECT * FROM webdav_syncs WHERE project_id = ? ORDER BY created_at DESC", (project_id, ), ).fetchall() else: - rows = conn.execute("SELECT * FROM webdav_syncs ORDER BY created_at DESC").fetchall() + rows = conn.execute("SELECT * FROM webdav_syncs ORDER BY created_at DESC").fetchall() conn.close() @@ -1245,9 +1244,9 @@ class WebDAVSyncManager: def update_sync(self, sync_id: str, **kwargs) -> WebDAVSync | None: """更新同步配置""" - conn = self.pm.db.get_conn() + conn = self.pm.db.get_conn() - allowed_fields = [ + allowed_fields = [ "name", "server_url", "username", @@ -1257,8 +1256,8 @@ class WebDAVSyncManager: "sync_interval", "is_active", ] - updates = [] - values = [] + updates = [] + values = [] for f in allowed_fields: if f in kwargs: @@ -1273,7 +1272,7 @@ class WebDAVSyncManager: values.append(datetime.now().isoformat()) values.append(sync_id) - query = f"UPDATE webdav_syncs SET {', '.join(updates)} WHERE id = ?" + query = f"UPDATE webdav_syncs SET {', '.join(updates)} WHERE id = ?" conn.execute(query, values) conn.commit() conn.close() @@ -1282,8 +1281,8 @@ class WebDAVSyncManager: def delete_sync(self, sync_id: str) -> bool: """删除同步配置""" - conn = self.pm.db.get_conn() - cursor = conn.execute("DELETE FROM webdav_syncs WHERE id = ?", (sync_id, )) + conn = self.pm.db.get_conn() + cursor = conn.execute("DELETE FROM webdav_syncs WHERE id = ?", (sync_id, )) conn.commit() conn.close() @@ -1292,22 +1291,22 @@ class WebDAVSyncManager: def _row_to_sync(self, row: sqlite3.Row) -> WebDAVSync: """将数据库行转换为 WebDAVSync 对象""" return WebDAVSync( - id = row["id"], - name = row["name"], - project_id = row["project_id"], - server_url = row["server_url"], - username = row["username"], - password = row["password"], - remote_path = row["remote_path"], - sync_mode = row["sync_mode"], - sync_interval = row["sync_interval"], - last_sync_at = row["last_sync_at"], - last_sync_status = row["last_sync_status"], - last_sync_error = row["last_sync_error"] or "", - is_active = bool(row["is_active"]), - created_at = row["created_at"], - updated_at = row["updated_at"], - sync_count = row["sync_count"], + id=row["id"], + name=row["name"], + project_id=row["project_id"], + server_url=row["server_url"], + username=row["username"], + password=row["password"], + remote_path=row["remote_path"], + sync_mode=row["sync_mode"], + sync_interval=row["sync_interval"], + last_sync_at=row["last_sync_at"], + last_sync_status=row["last_sync_status"], + last_sync_error=row["last_sync_error"] or "", + is_active=bool(row["is_active"]), + created_at=row["created_at"], + updated_at=row["updated_at"], + sync_count=row["sync_count"], ) async def test_connection(self, sync: WebDAVSync) -> dict: @@ -1316,7 +1315,7 @@ class WebDAVSyncManager: return {"success": False, "error": "WebDAV library not available"} try: - client = webdav_client.Client(sync.server_url, auth = (sync.username, sync.password)) + client = webdav_client.Client(sync.server_url, auth=(sync.username, sync.password)) # 尝试列出根目录 client.list("/") @@ -1335,26 +1334,26 @@ class WebDAVSyncManager: return {"success": False, "error": "Sync is not active"} try: - client = webdav_client.Client(sync.server_url, auth = (sync.username, sync.password)) + client = webdav_client.Client(sync.server_url, auth=(sync.username, sync.password)) # 确保远程目录存在 - remote_project_path = f"{sync.remote_path}/{sync.project_id}" + remote_project_path = f"{sync.remote_path}/{sync.project_id}" try: client.mkdir(remote_project_path) except (OSError, IOError): pass # 目录可能已存在 # 获取项目数据 - project = self.pm.db.get_project(sync.project_id) + project = self.pm.db.get_project(sync.project_id) if not project: return {"success": False, "error": "Project not found"} # 导出项目数据为 JSON - entities = self.pm.db.list_project_entities(sync.project_id) - relations = self.pm.db.list_project_relations(sync.project_id) - transcripts = self.pm.db.list_project_transcripts(sync.project_id) + entities = self.pm.db.list_project_entities(sync.project_id) + relations = self.pm.db.list_project_relations(sync.project_id) + transcripts = self.pm.db.list_project_transcripts(sync.project_id) - export_data = { + export_data = { "project": { "id": project.id, "name": project.name, @@ -1367,22 +1366,22 @@ class WebDAVSyncManager: } # 上传 JSON 文件 - json_content = json.dumps(export_data, ensure_ascii = False, indent = 2) - json_path = f"{remote_project_path}/project_export.json" + json_content = json.dumps(export_data, ensure_ascii=False, indent=2) + json_path = f"{remote_project_path}/project_export.json" # 使用临时文件上传 import tempfile - with tempfile.NamedTemporaryFile(mode = "w", suffix = ".json", delete = False) as f: + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: f.write(json_content) - temp_path = f.name + temp_path = f.name client.upload_file(temp_path, json_path) os.unlink(temp_path) # 更新同步状态 - now = datetime.now().isoformat() - conn = self.pm.db.get_conn() + now = datetime.now().isoformat() + conn = self.pm.db.get_conn() conn.execute( """UPDATE webdav_syncs SET last_sync_at = ?, last_sync_status = ?, sync_count = sync_count + 1 @@ -1402,7 +1401,7 @@ class WebDAVSyncManager: except Exception as e: # 更新失败状态 - conn = self.pm.db.get_conn() + conn = self.pm.db.get_conn() conn.execute( """UPDATE webdav_syncs SET last_sync_status = ?, last_sync_error = ? @@ -1416,12 +1415,12 @@ class WebDAVSyncManager: # Singleton instance -_plugin_manager = None +_plugin_manager = None -def get_plugin_manager(db_manager = None) -> None: +def get_plugin_manager(db_manager=None) -> None: """获取 PluginManager 单例""" global _plugin_manager if _plugin_manager is None: - _plugin_manager = PluginManager(db_manager) + _plugin_manager = PluginManager(db_manager) return _plugin_manager diff --git a/backend/rate_limiter.py b/backend/rate_limiter.py index ea1ea8e..c32c69e 100644 --- a/backend/rate_limiter.py +++ b/backend/rate_limiter.py @@ -17,9 +17,9 @@ from functools import wraps class RateLimitConfig: """限流配置""" - requests_per_minute: int = 60 - burst_size: int = 10 # 突发请求数 - window_size: int = 60 # 窗口大小(秒) + requests_per_minute: int = 60 + burst_size: int = 10 # 突发请求数 + window_size: int = 60 # 窗口大小(秒) @dataclass @@ -35,16 +35,16 @@ class RateLimitInfo: class SlidingWindowCounter: """滑动窗口计数器""" - def __init__(self, window_size: int = 60) -> None: - self.window_size = window_size - self.requests: dict[int, int] = defaultdict(int) # 秒级计数 - self._lock = asyncio.Lock() - self._cleanup_lock = asyncio.Lock() + def __init__(self, window_size: int = 60) -> None: + self.window_size = window_size + self.requests: dict[int, int] = defaultdict(int) # 秒级计数 + self._lock = asyncio.Lock() + self._cleanup_lock = asyncio.Lock() async def add_request(self) -> int: """添加请求,返回当前窗口内的请求数""" async with self._lock: - now = int(time.time()) + now = int(time.time()) self.requests[now] += 1 self._cleanup_old(now) return sum(self.requests.values()) @@ -52,14 +52,14 @@ class SlidingWindowCounter: async def get_count(self) -> int: """获取当前窗口内的请求数""" async with self._lock: - now = int(time.time()) + now = int(time.time()) self._cleanup_old(now) return sum(self.requests.values()) def _cleanup_old(self, now: int) -> None: """清理过期的请求记录 - 使用独立锁避免竞态条件""" - cutoff = now - self.window_size - old_keys = [k for k in list(self.requests.keys()) if k < cutoff] + cutoff = now - self.window_size + old_keys = [k for k in list(self.requests.keys()) if k < cutoff] for k in old_keys: self.requests.pop(k, None) @@ -69,13 +69,13 @@ class RateLimiter: def __init__(self) -> None: # key -> SlidingWindowCounter - self.counters: dict[str, SlidingWindowCounter] = {} + self.counters: dict[str, SlidingWindowCounter] = {} # key -> RateLimitConfig - self.configs: dict[str, RateLimitConfig] = {} - self._lock = asyncio.Lock() - self._cleanup_lock = asyncio.Lock() + self.configs: dict[str, RateLimitConfig] = {} + self._lock = asyncio.Lock() + self._cleanup_lock = asyncio.Lock() - async def is_allowed(self, key: str, config: RateLimitConfig | None = None) -> RateLimitInfo: + async def is_allowed(self, key: str, config: RateLimitConfig | None = None) -> RateLimitInfo: """ 检查是否允许请求 @@ -87,70 +87,70 @@ class RateLimiter: RateLimitInfo """ if config is None: - config = RateLimitConfig() + config = RateLimitConfig() async with self._lock: if key not in self.counters: - self.counters[key] = SlidingWindowCounter(config.window_size) - self.configs[key] = config + self.counters[key] = SlidingWindowCounter(config.window_size) + self.configs[key] = config - counter = self.counters[key] - stored_config = self.configs.get(key, config) + counter = self.counters[key] + stored_config = self.configs.get(key, config) # 获取当前计数 - current_count = await counter.get_count() + current_count = await counter.get_count() # 计算剩余配额 - remaining = max(0, stored_config.requests_per_minute - current_count) + remaining = max(0, stored_config.requests_per_minute - current_count) # 计算重置时间 - now = int(time.time()) - reset_time = now + stored_config.window_size + now = int(time.time()) + reset_time = now + stored_config.window_size # 检查是否超过限制 if current_count >= stored_config.requests_per_minute: return RateLimitInfo( - allowed = False, - remaining = 0, - reset_time = reset_time, - retry_after = stored_config.window_size, + allowed=False, + remaining=0, + reset_time=reset_time, + retry_after=stored_config.window_size, ) # 允许请求,增加计数 await counter.add_request() return RateLimitInfo( - allowed = True, remaining = remaining - 1, reset_time = reset_time, retry_after = 0 + allowed=True, remaining=remaining - 1, reset_time=reset_time, retry_after=0 ) async def get_limit_info(self, key: str) -> RateLimitInfo: """获取限流信息(不增加计数)""" if key not in self.counters: - config = RateLimitConfig() + config = RateLimitConfig() return RateLimitInfo( - allowed = True, - remaining = config.requests_per_minute, - reset_time = int(time.time()) + config.window_size, - retry_after = 0, + allowed=True, + remaining=config.requests_per_minute, + reset_time=int(time.time()) + config.window_size, + retry_after=0, ) - counter = self.counters[key] - config = self.configs.get(key, RateLimitConfig()) + counter = self.counters[key] + config = self.configs.get(key, RateLimitConfig()) - current_count = await counter.get_count() - remaining = max(0, config.requests_per_minute - current_count) - reset_time = int(time.time()) + config.window_size + current_count = await counter.get_count() + remaining = max(0, config.requests_per_minute - current_count) + reset_time = int(time.time()) + config.window_size return RateLimitInfo( - allowed = current_count < config.requests_per_minute, - remaining = remaining, - reset_time = reset_time, - retry_after = max(0, config.window_size) + allowed=current_count < config.requests_per_minute, + remaining=remaining, + reset_time=reset_time, + retry_after=max(0, config.window_size) if current_count >= config.requests_per_minute else 0, ) - def reset(self, key: str | None = None) -> None: + def reset(self, key: str | None = None) -> None: """重置限流计数器""" if key: self.counters.pop(key, None) @@ -161,21 +161,21 @@ class RateLimiter: # 全局限流器实例 -_rate_limiter: RateLimiter | None = None +_rate_limiter: RateLimiter | None = None def get_rate_limiter() -> RateLimiter: """获取限流器实例""" global _rate_limiter if _rate_limiter is None: - _rate_limiter = RateLimiter() + _rate_limiter = RateLimiter() return _rate_limiter # 限流装饰器(用于函数级别限流) -def rate_limit(requests_per_minute: int = 60, key_func: Callable | None = None) -> None: +def rate_limit(requests_per_minute: int = 60, key_func: Callable | None = None) -> None: """ 限流装饰器 @@ -185,13 +185,13 @@ def rate_limit(requests_per_minute: int = 60, key_func: Callable | None = None """ def decorator(func) -> None: - limiter = get_rate_limiter() - config = RateLimitConfig(requests_per_minute = requests_per_minute) + limiter = get_rate_limiter() + config = RateLimitConfig(requests_per_minute=requests_per_minute) @wraps(func) async def async_wrapper(*args, **kwargs) -> None: - key = key_func(*args, **kwargs) if key_func else func.__name__ - info = await limiter.is_allowed(key, config) + key = key_func(*args, **kwargs) if key_func else func.__name__ + info = await limiter.is_allowed(key, config) if not info.allowed: raise RateLimitExceeded( @@ -202,9 +202,9 @@ def rate_limit(requests_per_minute: int = 60, key_func: Callable | None = None @wraps(func) def sync_wrapper(*args, **kwargs) -> None: - key = key_func(*args, **kwargs) if key_func else func.__name__ + key = key_func(*args, **kwargs) if key_func else func.__name__ # 同步版本使用 asyncio.run - info = asyncio.run(limiter.is_allowed(key, config)) + info = asyncio.run(limiter.is_allowed(key, config)) if not info.allowed: raise RateLimitExceeded( diff --git a/backend/search_manager.py b/backend/search_manager.py index 641bb6c..026b64d 100644 --- a/backend/search_manager.py +++ b/backend/search_manager.py @@ -23,9 +23,9 @@ from enum import Enum class SearchOperator(Enum): """搜索操作符""" - AND = "AND" - OR = "OR" - NOT = "NOT" + AND = "AND" + OR = "OR" + NOT = "NOT" # 尝试导入 sentence-transformers 用于语义搜索 @@ -33,9 +33,9 @@ try: from sentence_transformers import SentenceTransformer from sklearn.metrics.pairwise import cosine_similarity - SENTENCE_TRANSFORMERS_AVAILABLE = True + SENTENCE_TRANSFORMERS_AVAILABLE = True except ImportError: - SENTENCE_TRANSFORMERS_AVAILABLE = False + SENTENCE_TRANSFORMERS_AVAILABLE = False # ==================== 数据模型 ==================== @@ -49,8 +49,8 @@ class SearchResult: content_type: str # transcript, entity, relation project_id: str score: float - highlights: list[tuple[int, int]] = field(default_factory = list) # 高亮位置 - metadata: dict = field(default_factory = dict) + highlights: list[tuple[int, int]] = field(default_factory=list) # 高亮位置 + metadata: dict = field(default_factory=dict) def to_dict(self) -> dict: return { @@ -73,11 +73,11 @@ class SemanticSearchResult: content_type: str project_id: str similarity: float - embedding: list[float] | None = None - metadata: dict = field(default_factory = dict) + embedding: list[float] | None = None + metadata: dict = field(default_factory=dict) def to_dict(self) -> dict: - result = { + result = { "id": self.id, "content": self.content[:500] + "..." if len(self.content) > 500 else self.content, "content_type": self.content_type, @@ -86,7 +86,7 @@ class SemanticSearchResult: "metadata": self.metadata, } if self.embedding: - result["embedding_dim"] = len(self.embedding) + result["embedding_dim"] = len(self.embedding) return result @@ -132,7 +132,7 @@ class KnowledgeGap: severity: str # high, medium, low suggestions: list[str] related_entities: list[str] - metadata: dict = field(default_factory = dict) + metadata: dict = field(default_factory=dict) def to_dict(self) -> dict: return { @@ -189,19 +189,19 @@ class FullTextSearch: - 支持布尔搜索(AND/OR/NOT) """ - def __init__(self, db_path: str = "insightflow.db") -> None: - self.db_path = db_path + def __init__(self, db_path: str = "insightflow.db") -> None: + self.db_path = db_path self._init_search_tables() def _get_conn(self) -> sqlite3.Connection: """获取数据库连接""" - conn = sqlite3.connect(self.db_path) - conn.row_factory = sqlite3.Row + conn = sqlite3.connect(self.db_path) + conn.row_factory = sqlite3.Row return conn def _init_search_tables(self) -> None: """初始化搜索相关表""" - conn = self._get_conn() + conn = self._get_conn() # 搜索索引表 conn.execute(""" @@ -251,25 +251,25 @@ class FullTextSearch: 实际生产环境可以使用 jieba 等分词工具 """ # 清理文本 - text = text.lower() + text = text.lower() # 提取中文字符、英文单词和数字 - tokens = re.findall(r"[\u4e00-\u9fa5]+|[a-z]+|\d+", text) + tokens = re.findall(r"[\u4e00-\u9fa5]+|[a-z]+|\d+", text) return tokens def _extract_positions(self, text: str, tokens: list[str]) -> dict[str, list[int]]: """提取每个词在文本中的位置""" - positions = defaultdict(list) - text_lower = text.lower() + positions = defaultdict(list) + text_lower = text.lower() for token in tokens: # 查找所有出现位置 - start = 0 + start = 0 while True: - pos = text_lower.find(token, start) + pos = text_lower.find(token, start) if pos == -1: break positions[token].append(pos) - start = pos + 1 + start = pos + 1 return dict(positions) @@ -287,24 +287,24 @@ class FullTextSearch: bool: 是否成功 """ try: - conn = self._get_conn() + conn = self._get_conn() # 分词 - tokens = self._tokenize(text) + tokens = self._tokenize(text) if not tokens: conn.close() return False # 提取位置信息 - token_positions = self._extract_positions(text, tokens) + token_positions = self._extract_positions(text, tokens) # 计算词频 - token_freq = defaultdict(int) + token_freq = defaultdict(int) for token in tokens: token_freq[token] += 1 - index_id = hashlib.md5(f"{content_id}:{content_type}".encode()).hexdigest()[:16] - now = datetime.now().isoformat() + index_id = hashlib.md5(f"{content_id}:{content_type}".encode()).hexdigest()[:16] + now = datetime.now().isoformat() # 保存索引 conn.execute( @@ -318,8 +318,8 @@ class FullTextSearch: content_id, content_type, project_id, - json.dumps(tokens, ensure_ascii = False), - json.dumps(token_positions, ensure_ascii = False), + json.dumps(tokens, ensure_ascii=False), + json.dumps(token_positions, ensure_ascii=False), now, now, ), @@ -327,7 +327,7 @@ class FullTextSearch: # 保存词频统计 for token, freq in token_freq.items(): - positions = token_positions.get(token, []) + positions = token_positions.get(token, []) conn.execute( """ INSERT OR REPLACE INTO search_term_freq @@ -340,7 +340,7 @@ class FullTextSearch: content_type, project_id, freq, - json.dumps(positions, ensure_ascii = False), + json.dumps(positions, ensure_ascii=False), ), ) @@ -355,10 +355,10 @@ class FullTextSearch: def search( self, query: str, - project_id: str | None = None, - content_types: list[str] | None = None, - limit: int = 20, - offset: int = 0, + project_id: str | None = None, + content_types: list[str] | None = None, + limit: int = 20, + offset: int = 0, ) -> list[SearchResult]: """ 全文搜索 @@ -374,18 +374,18 @@ class FullTextSearch: List[SearchResult]: 搜索结果列表 """ # 解析布尔查询 - parsed_query = self._parse_boolean_query(query) + parsed_query = self._parse_boolean_query(query) # 执行搜索 - results = self._execute_boolean_search(parsed_query, project_id, content_types) + results = self._execute_boolean_search(parsed_query, project_id, content_types) # 计算相关性分数 - scored_results = self._score_results(results, parsed_query) + scored_results = self._score_results(results, parsed_query) # 排序和分页 - scored_results.sort(key = lambda x: x.score, reverse = True) + scored_results.sort(key=lambda x: x.score, reverse=True) - return scored_results[offset : offset + limit] + return scored_results[offset: offset + limit] def _parse_boolean_query(self, query: str) -> dict: """ @@ -397,65 +397,65 @@ class FullTextSearch: - NOT: NOT 词1 或 词1 -词2 - 短语: "精确短语" """ - query = query.strip() + query = query.strip() # 提取短语(引号内的内容) - phrases = re.findall(r'"([^"]+)"', query) - query_without_phrases = re.sub(r'"[^"]+"', "", query) + phrases = re.findall(r'"([^"]+)"', query) + query_without_phrases = re.sub(r'"[^"]+"', "", query) # 解析布尔操作 - and_terms = [] - or_terms = [] - not_terms = [] + and_terms = [] + or_terms = [] + not_terms = [] # 处理 NOT - not_pattern = r"(?:NOT\s+|\-)(\w+)" - not_matches = re.findall(not_pattern, query_without_phrases, re.IGNORECASE) + not_pattern = r"(?:NOT\s+|\-)(\w+)" + not_matches = re.findall(not_pattern, query_without_phrases, re.IGNORECASE) not_terms.extend(not_matches) - query_without_phrases = re.sub(not_pattern, "", query_without_phrases, flags = re.IGNORECASE) + query_without_phrases = re.sub(not_pattern, "", query_without_phrases, flags=re.IGNORECASE) # 处理 OR - or_parts = re.split(r"\s+OR\s+", query_without_phrases, flags = re.IGNORECASE) + or_parts = re.split(r"\s+OR\s+", query_without_phrases, flags=re.IGNORECASE) if len(or_parts) > 1: - or_terms = [p.strip() for p in or_parts[1:] if p.strip()] - query_without_phrases = or_parts[0] + or_terms = [p.strip() for p in or_parts[1:] if p.strip()] + query_without_phrases = or_parts[0] # 剩余的作为 AND 条件 - and_terms = [t.strip() for t in query_without_phrases.split() if t.strip()] + and_terms = [t.strip() for t in query_without_phrases.split() if t.strip()] return {"and": and_terms + phrases, "or": or_terms, "not": not_terms, "phrases": phrases} def _execute_boolean_search( self, parsed_query: dict, - project_id: str | None = None, - content_types: list[str] | None = None, + project_id: str | None = None, + content_types: list[str] | None = None, ) -> list[dict]: """执行布尔搜索""" - conn = self._get_conn() + conn = self._get_conn() # 构建基础查询 - base_where = [] - params = [] + base_where = [] + params = [] if project_id: base_where.append("project_id = ?") params.append(project_id) if content_types: - placeholders = ", ".join(["?" for _ in content_types]) + placeholders = ", ".join(["?" for _ in content_types]) base_where.append(f"content_type IN ({placeholders})") params.extend(content_types) - base_where_str = " AND ".join(base_where) if base_where else "1 = 1" + base_where_str = " AND ".join(base_where) if base_where else "1 = 1" # 获取候选结果 - candidates = set() + candidates = set() # 处理 AND 条件 if parsed_query["and"]: for term in parsed_query["and"]: - term_results = conn.execute( + term_results = conn.execute( f""" SELECT content_id, content_type, project_id, frequency, positions FROM search_term_freq @@ -464,17 +464,17 @@ class FullTextSearch: [term] + params, ).fetchall() - term_contents = {(r["content_id"], r["content_type"]) for r in term_results} + term_contents = {(r["content_id"], r["content_type"]) for r in term_results} if not candidates: - candidates = term_contents + candidates = term_contents else: candidates &= term_contents # 交集 # 处理 OR 条件 if parsed_query["or"]: for term in parsed_query["or"]: - term_results = conn.execute( + term_results = conn.execute( f""" SELECT content_id, content_type, project_id, frequency, positions FROM search_term_freq @@ -483,17 +483,17 @@ class FullTextSearch: [term] + params, ).fetchall() - term_contents = {(r["content_id"], r["content_type"]) for r in term_results} + term_contents = {(r["content_id"], r["content_type"]) for r in term_results} candidates |= term_contents # 并集 # 如果没有 AND 和 OR,但有 phrases,使用 phrases if not candidates and parsed_query["phrases"]: for phrase in parsed_query["phrases"]: - phrase_tokens = self._tokenize(phrase) + phrase_tokens = self._tokenize(phrase) if phrase_tokens: # 查找包含所有短语的文档 for token in phrase_tokens: - term_results = conn.execute( + term_results = conn.execute( f""" SELECT content_id, content_type, project_id, frequency, positions FROM search_term_freq @@ -502,17 +502,17 @@ class FullTextSearch: [token] + params, ).fetchall() - term_contents = {(r["content_id"], r["content_type"]) for r in term_results} + term_contents = {(r["content_id"], r["content_type"]) for r in term_results} if not candidates: - candidates = term_contents + candidates = term_contents else: candidates &= term_contents # 处理 NOT 条件(排除) if parsed_query["not"]: for term in parsed_query["not"]: - term_results = conn.execute( + term_results = conn.execute( f""" SELECT content_id, content_type FROM search_term_freq @@ -521,14 +521,14 @@ class FullTextSearch: [term] + params, ).fetchall() - term_contents = {(r["content_id"], r["content_type"]) for r in term_results} + term_contents = {(r["content_id"], r["content_type"]) for r in term_results} candidates -= term_contents # 差集 # 获取完整内容 - results = [] + results = [] for content_id, content_type in candidates: # 获取原始内容 - content = self._get_content_by_id(conn, content_id, content_type) + content = self._get_content_by_id(conn, content_id, content_type) if content: results.append( { @@ -550,13 +550,13 @@ class FullTextSearch: """根据ID获取内容""" try: if content_type == "transcript": - row = conn.execute( + row = conn.execute( "SELECT full_text FROM transcripts WHERE id = ?", (content_id, ) ).fetchone() return row["full_text"] if row else None elif content_type == "entity": - row = conn.execute( + row = conn.execute( "SELECT name, definition FROM entities WHERE id = ?", (content_id, ) ).fetchone() if row: @@ -564,7 +564,7 @@ class FullTextSearch: return None elif content_type == "relation": - row = conn.execute( + row = conn.execute( """SELECT r.relation_type, r.evidence, e1.name as source_name, e2.name as target_name FROM entity_relations r @@ -588,15 +588,15 @@ class FullTextSearch: """获取内容所属的项目ID""" try: if content_type == "transcript": - row = conn.execute( + row = conn.execute( "SELECT project_id FROM transcripts WHERE id = ?", (content_id, ) ).fetchone() elif content_type == "entity": - row = conn.execute( + row = conn.execute( "SELECT project_id FROM entities WHERE id = ?", (content_id, ) ).fetchone() elif content_type == "relation": - row = conn.execute( + row = conn.execute( "SELECT project_id FROM entity_relations WHERE id = ?", (content_id, ) ).fetchone() else: @@ -608,39 +608,39 @@ class FullTextSearch: def _score_results(self, results: list[dict], parsed_query: dict) -> list[SearchResult]: """计算搜索结果的相关性分数""" - scored = [] - all_terms = parsed_query["and"] + parsed_query["or"] + parsed_query["phrases"] + scored = [] + all_terms = parsed_query["and"] + parsed_query["or"] + parsed_query["phrases"] for result in results: - content = result["content"].lower() + content = result["content"].lower() # 基础分数 - score = 0.0 - highlights = [] + score = 0.0 + highlights = [] # 计算每个词的匹配分数 for term in all_terms: - term_lower = term.lower() - count = content.count(term_lower) + term_lower = term.lower() + count = content.count(term_lower) if count > 0: # TF 分数(词频) - tf_score = math.log(1 + count) + tf_score = math.log(1 + count) # 位置加分(标题/开头匹配分数更高) - position_bonus = 0 - first_pos = content.find(term_lower) + position_bonus = 0 + first_pos = content.find(term_lower) if first_pos != -1: if first_pos < 50: # 开头50个字符 - position_bonus = 2.0 + position_bonus = 2.0 elif first_pos < 200: # 开头200个字符 - position_bonus = 1.0 + position_bonus = 1.0 # 记录高亮位置 - start = first_pos + start = first_pos while start != -1: highlights.append((start, start + len(term))) - start = content.find(term_lower, start + 1) + start = content.find(term_lower, start + 1) score += tf_score + position_bonus @@ -650,23 +650,23 @@ class FullTextSearch: score *= 1.5 # 短语匹配加权 # 归一化分数 - score = min(score / max(len(all_terms), 1), 10.0) + score = min(score / max(len(all_terms), 1), 10.0) scored.append( SearchResult( - id = result["id"], - content = result["content"], - content_type = result["content_type"], - project_id = result["project_id"], - score = round(score, 4), - highlights = highlights[:10], # 限制高亮数量 - metadata = {}, + id=result["id"], + content=result["content"], + content_type=result["content_type"], + project_id=result["project_id"], + score=round(score, 4), + highlights=highlights[:10], # 限制高亮数量 + metadata={}, ) ) return scored - def highlight_text(self, text: str, query: str, max_length: int = 300) -> str: + def highlight_text(self, text: str, query: str, max_length: int = 300) -> str: """ 高亮文本中的关键词 @@ -678,37 +678,37 @@ class FullTextSearch: Returns: str: 带高亮标记的文本 """ - parsed = self._parse_boolean_query(query) - all_terms = parsed["and"] + parsed["or"] + parsed["phrases"] + parsed = self._parse_boolean_query(query) + all_terms = parsed["and"] + parsed["or"] + parsed["phrases"] # 找到第一个匹配位置 - first_match = len(text) + first_match = len(text) for term in all_terms: - pos = text.lower().find(term.lower()) + pos = text.lower().find(term.lower()) if pos != -1 and pos < first_match: - first_match = pos + first_match = pos # 截取上下文 - start = max(0, first_match - 100) - end = min(len(text), start + max_length) - snippet = text[start:end] + start = max(0, first_match - 100) + end = min(len(text), start + max_length) + snippet = text[start:end] if start > 0: - snippet = "..." + snippet + snippet = "..." + snippet if end < len(text): - snippet = snippet + "..." + snippet = snippet + "..." # 添加高亮标记 - for term in sorted(all_terms, key = len, reverse = True): # 长的先替换 - pattern = re.compile(re.escape(term), re.IGNORECASE) - snippet = pattern.sub(f"**{term}**", snippet) + for term in sorted(all_terms, key=len, reverse=True): # 长的先替换 + pattern = re.compile(re.escape(term), re.IGNORECASE) + snippet = pattern.sub(f"**{term}**", snippet) return snippet def delete_index(self, content_id: str, content_type: str) -> bool: """删除内容的搜索索引""" try: - conn = self._get_conn() + conn = self._get_conn() # 删除索引 conn.execute( @@ -731,12 +731,12 @@ class FullTextSearch: def reindex_project(self, project_id: str) -> dict: """重新索引整个项目""" - conn = self._get_conn() - stats = {"transcripts": 0, "entities": 0, "relations": 0, "errors": 0} + conn = self._get_conn() + stats = {"transcripts": 0, "entities": 0, "relations": 0, "errors": 0} try: # 索引转录文本 - transcripts = conn.execute( + transcripts = conn.execute( "SELECT id, project_id, full_text FROM transcripts WHERE project_id = ?", (project_id, ), ).fetchall() @@ -749,20 +749,20 @@ class FullTextSearch: stats["errors"] += 1 # 索引实体 - entities = conn.execute( + entities = conn.execute( "SELECT id, project_id, name, definition FROM entities WHERE project_id = ?", (project_id, ), ).fetchall() for e in entities: - text = f"{e['name']} {e['definition'] or ''}" + text = f"{e['name']} {e['definition'] or ''}" if self.index_content(e["id"], "entity", e["project_id"], text): stats["entities"] += 1 else: stats["errors"] += 1 # 索引关系 - relations = conn.execute( + relations = conn.execute( """SELECT r.id, r.project_id, r.relation_type, r.evidence, e1.name as source_name, e2.name as target_name FROM entity_relations r @@ -773,7 +773,7 @@ class FullTextSearch: ).fetchall() for r in relations: - text = f"{r['source_name']} {r['relation_type']} {r['target_name']} {r['evidence'] or ''}" + text = f"{r['source_name']} {r['relation_type']} {r['target_name']} {r['evidence'] or ''}" if self.index_content(r["id"], "relation", r["project_id"], text): stats["relations"] += 1 else: @@ -803,31 +803,31 @@ class SemanticSearch: def __init__( self, - db_path: str = "insightflow.db", - model_name: str = "paraphrase-multilingual-MiniLM-L12-v2", + db_path: str = "insightflow.db", + model_name: str = "paraphrase-multilingual-MiniLM-L12-v2", ) -> None: - self.db_path = db_path - self.model_name = model_name - self.model = None + self.db_path = db_path + self.model_name = model_name + self.model = None self._init_embedding_tables() # 延迟加载模型 if SENTENCE_TRANSFORMERS_AVAILABLE: try: - self.model = SentenceTransformer(model_name) + self.model = SentenceTransformer(model_name) print(f"语义搜索模型加载成功: {model_name}") except Exception as e: print(f"模型加载失败: {e}") def _get_conn(self) -> sqlite3.Connection: """获取数据库连接""" - conn = sqlite3.connect(self.db_path) - conn.row_factory = sqlite3.Row + conn = sqlite3.connect(self.db_path) + conn.row_factory = sqlite3.Row return conn def _init_embedding_tables(self) -> None: """初始化 embedding 相关表""" - conn = self._get_conn() + conn = self._get_conn() conn.execute(""" CREATE TABLE IF NOT EXISTS embeddings ( @@ -869,11 +869,11 @@ class SemanticSearch: try: # 截断长文本 - max_chars = 5000 + max_chars = 5000 if len(text) > max_chars: - text = text[:max_chars] + text = text[:max_chars] - embedding = self.model.encode(text, convert_to_list = True) + embedding = self.model.encode(text, convert_to_list=True) return embedding except Exception as e: print(f"生成 embedding 失败: {e}") @@ -898,13 +898,13 @@ class SemanticSearch: return False try: - embedding = self.generate_embedding(text) + embedding = self.generate_embedding(text) if not embedding: return False - conn = self._get_conn() + conn = self._get_conn() - embedding_id = hashlib.md5(f"{content_id}:{content_type}".encode()).hexdigest()[:16] + embedding_id = hashlib.md5(f"{content_id}:{content_type}".encode()).hexdigest()[:16] conn.execute( """ @@ -934,10 +934,10 @@ class SemanticSearch: def search( self, query: str, - project_id: str | None = None, - content_types: list[str] | None = None, - top_k: int = 10, - threshold: float = 0.5, + project_id: str | None = None, + content_types: list[str] | None = None, + top_k: int = 10, + threshold: float = 0.5, ) -> list[SemanticSearchResult]: """ 语义搜索 @@ -956,28 +956,28 @@ class SemanticSearch: return [] # 生成查询的 embedding - query_embedding = self.generate_embedding(query) + query_embedding = self.generate_embedding(query) if not query_embedding: return [] # 获取候选 embedding - conn = self._get_conn() + conn = self._get_conn() - where_clauses = [] - params = [] + where_clauses = [] + params = [] if project_id: where_clauses.append("project_id = ?") params.append(project_id) if content_types: - placeholders = ", ".join(["?" for _ in content_types]) + placeholders = ", ".join(["?" for _ in content_types]) where_clauses.append(f"content_type IN ({placeholders})") params.extend(content_types) - where_str = " AND ".join(where_clauses) if where_clauses else "1 = 1" + where_str = " AND ".join(where_clauses) if where_clauses else "1 = 1" - rows = conn.execute( + rows = conn.execute( f""" SELECT content_id, content_type, project_id, embedding FROM embeddings @@ -989,29 +989,29 @@ class SemanticSearch: conn.close() # 计算相似度 - results = [] - query_vec = [query_embedding] + results = [] + query_vec = [query_embedding] for row in rows: try: - content_embedding = json.loads(row["embedding"]) + content_embedding = json.loads(row["embedding"]) # 计算余弦相似度 - similarity = cosine_similarity(query_vec, [content_embedding])[0][0] + similarity = cosine_similarity(query_vec, [content_embedding])[0][0] if similarity >= threshold: # 获取原始内容 - content = self._get_content_text(row["content_id"], row["content_type"]) + content = self._get_content_text(row["content_id"], row["content_type"]) results.append( SemanticSearchResult( - id = row["content_id"], - content = content or "", - content_type = row["content_type"], - project_id = row["project_id"], - similarity = float(similarity), - embedding = None, # 不返回 embedding 以节省带宽 - metadata = {}, + id=row["content_id"], + content=content or "", + content_type=row["content_type"], + project_id=row["project_id"], + similarity=float(similarity), + embedding=None, # 不返回 embedding 以节省带宽 + metadata={}, ) ) except Exception as e: @@ -1019,28 +1019,28 @@ class SemanticSearch: continue # 排序并返回 top_k - results.sort(key = lambda x: x.similarity, reverse = True) + results.sort(key=lambda x: x.similarity, reverse=True) return results[:top_k] def _get_content_text(self, content_id: str, content_type: str) -> str | None: """获取内容文本""" - conn = self._get_conn() + conn = self._get_conn() try: if content_type == "transcript": - row = conn.execute( + row = conn.execute( "SELECT full_text FROM transcripts WHERE id = ?", (content_id, ) ).fetchone() - result = row["full_text"] if row else None + result = row["full_text"] if row else None elif content_type == "entity": - row = conn.execute( + row = conn.execute( "SELECT name, definition FROM entities WHERE id = ?", (content_id, ) ).fetchone() - result = f"{row['name']}: {row['definition']}" if row else None + result = f"{row['name']}: {row['definition']}" if row else None elif content_type == "relation": - row = conn.execute( + row = conn.execute( """SELECT r.relation_type, r.evidence, e1.name as source_name, e2.name as target_name FROM entity_relations r @@ -1049,14 +1049,14 @@ class SemanticSearch: WHERE r.id = ?""", (content_id, ), ).fetchone() - result = ( + result = ( f"{row['source_name']} {row['relation_type']} {row['target_name']}" if row else None ) else: - result = None + result = None conn.close() return result @@ -1067,7 +1067,7 @@ class SemanticSearch: return None def find_similar_content( - self, content_id: str, content_type: str, top_k: int = 5 + self, content_id: str, content_type: str, top_k: int = 5 ) -> list[SemanticSearchResult]: """ 查找与指定内容相似的内容 @@ -1084,9 +1084,9 @@ class SemanticSearch: return [] # 获取源内容的 embedding - conn = self._get_conn() + conn = self._get_conn() - row = conn.execute( + row = conn.execute( "SELECT embedding, project_id FROM embeddings WHERE content_id = ? AND content_type = ?", (content_id, content_type), ).fetchone() @@ -1095,11 +1095,11 @@ class SemanticSearch: conn.close() return [] - source_embedding = json.loads(row["embedding"]) - project_id = row["project_id"] + source_embedding = json.loads(row["embedding"]) + project_id = row["project_id"] # 获取其他内容的 embedding - rows = conn.execute( + rows = conn.execute( """SELECT content_id, content_type, project_id, embedding FROM embeddings WHERE project_id = ? AND (content_id != ? OR content_type != ?)""", @@ -1109,36 +1109,36 @@ class SemanticSearch: conn.close() # 计算相似度 - results = [] - source_vec = [source_embedding] + results = [] + source_vec = [source_embedding] for row in rows: try: - content_embedding = json.loads(row["embedding"]) - similarity = cosine_similarity(source_vec, [content_embedding])[0][0] + content_embedding = json.loads(row["embedding"]) + similarity = cosine_similarity(source_vec, [content_embedding])[0][0] - content = self._get_content_text(row["content_id"], row["content_type"]) + content = self._get_content_text(row["content_id"], row["content_type"]) results.append( SemanticSearchResult( - id = row["content_id"], - content = content or "", - content_type = row["content_type"], - project_id = row["project_id"], - similarity = float(similarity), - metadata = {}, + id=row["content_id"], + content=content or "", + content_type=row["content_type"], + project_id=row["project_id"], + similarity=float(similarity), + metadata={}, ) ) except (KeyError, ValueError): continue - results.sort(key = lambda x: x.similarity, reverse = True) + results.sort(key=lambda x: x.similarity, reverse=True) return results[:top_k] def delete_embedding(self, content_id: str, content_type: str) -> bool: """删除内容的 embedding""" try: - conn = self._get_conn() + conn = self._get_conn() conn.execute( "DELETE FROM embeddings WHERE content_id = ? AND content_type = ?", (content_id, content_type), @@ -1165,17 +1165,17 @@ class EntityPathDiscovery: - 路径可视化数据生成 """ - def __init__(self, db_path: str = "insightflow.db") -> None: - self.db_path = db_path + def __init__(self, db_path: str = "insightflow.db") -> None: + self.db_path = db_path def _get_conn(self) -> sqlite3.Connection: """获取数据库连接""" - conn = sqlite3.connect(self.db_path) - conn.row_factory = sqlite3.Row + conn = sqlite3.connect(self.db_path) + conn.row_factory = sqlite3.Row return conn def find_shortest_path( - self, source_entity_id: str, target_entity_id: str, max_depth: int = 5 + self, source_entity_id: str, target_entity_id: str, max_depth: int = 5 ) -> EntityPath | None: """ 查找两个实体之间的最短路径(BFS算法) @@ -1188,10 +1188,10 @@ class EntityPathDiscovery: Returns: Optional[EntityPath]: 最短路径 """ - conn = self._get_conn() + conn = self._get_conn() # 获取项目ID - row = conn.execute( + row = conn.execute( "SELECT project_id FROM entities WHERE id = ?", (source_entity_id, ) ).fetchone() @@ -1199,10 +1199,10 @@ class EntityPathDiscovery: conn.close() return None - project_id = row["project_id"] + project_id = row["project_id"] # 验证目标实体也在同一项目 - row = conn.execute( + row = conn.execute( "SELECT 1 FROM entities WHERE id = ? AND project_id = ?", (target_entity_id, project_id) ).fetchone() @@ -1211,11 +1211,11 @@ class EntityPathDiscovery: return None # BFS - visited = {source_entity_id} - queue = [(source_entity_id, [source_entity_id])] + visited = {source_entity_id} + queue = [(source_entity_id, [source_entity_id])] while queue: - current_id, path = queue.pop(0) + current_id, path = queue.pop(0) if len(path) > max_depth + 1: continue @@ -1226,7 +1226,7 @@ class EntityPathDiscovery: return self._build_path_object(path, project_id) # 获取邻居 - neighbors = conn.execute( + neighbors = conn.execute( """ SELECT target_entity_id as neighbor_id, relation_type, evidence FROM entity_relations @@ -1240,7 +1240,7 @@ class EntityPathDiscovery: ).fetchall() for neighbor in neighbors: - neighbor_id = neighbor["neighbor_id"] + neighbor_id = neighbor["neighbor_id"] if neighbor_id not in visited: visited.add(neighbor_id) queue.append((neighbor_id, path + [neighbor_id])) @@ -1249,7 +1249,7 @@ class EntityPathDiscovery: return None def find_all_paths( - self, source_entity_id: str, target_entity_id: str, max_depth: int = 4, max_paths: int = 10 + self, source_entity_id: str, target_entity_id: str, max_depth: int = 4, max_paths: int = 10 ) -> list[EntityPath]: """ 查找两个实体之间的所有路径(限制数量和深度) @@ -1263,10 +1263,10 @@ class EntityPathDiscovery: Returns: List[EntityPath]: 路径列表 """ - conn = self._get_conn() + conn = self._get_conn() # 获取项目ID - row = conn.execute( + row = conn.execute( "SELECT project_id FROM entities WHERE id = ?", (source_entity_id, ) ).fetchone() @@ -1274,9 +1274,9 @@ class EntityPathDiscovery: conn.close() return [] - project_id = row["project_id"] + project_id = row["project_id"] - paths = [] + paths = [] def dfs(current_id: str, target_id: str, path: list[str], visited: set[str], depth: int) -> None: if depth > max_depth: @@ -1287,7 +1287,7 @@ class EntityPathDiscovery: return # 获取邻居 - neighbors = conn.execute( + neighbors = conn.execute( """ SELECT target_entity_id as neighbor_id FROM entity_relations @@ -1301,7 +1301,7 @@ class EntityPathDiscovery: ).fetchall() for neighbor in neighbors: - neighbor_id = neighbor["neighbor_id"] + neighbor_id = neighbor["neighbor_id"] if neighbor_id not in visited and len(paths) < max_paths: visited.add(neighbor_id) path.append(neighbor_id) @@ -1309,7 +1309,7 @@ class EntityPathDiscovery: path.pop() visited.remove(neighbor_id) - visited = {source_entity_id} + visited = {source_entity_id} dfs(source_entity_id, target_entity_id, [source_entity_id], visited, 0) conn.close() @@ -1319,24 +1319,24 @@ class EntityPathDiscovery: def _build_path_object(self, entity_ids: list[str], project_id: str) -> EntityPath: """构建路径对象""" - conn = self._get_conn() + conn = self._get_conn() # 获取实体信息 - nodes = [] + nodes = [] for entity_id in entity_ids: - row = conn.execute( + row = conn.execute( "SELECT id, name, type FROM entities WHERE id = ?", (entity_id, ) ).fetchone() if row: nodes.append({"id": row["id"], "name": row["name"], "type": row["type"]}) # 获取边信息 - edges = [] + edges = [] for i in range(len(entity_ids) - 1): - source_id = entity_ids[i] - target_id = entity_ids[i + 1] + source_id = entity_ids[i] + target_id = entity_ids[i + 1] - row = conn.execute( + row = conn.execute( """ SELECT id, relation_type, evidence FROM entity_relations @@ -1361,26 +1361,26 @@ class EntityPathDiscovery: conn.close() # 生成路径描述 - node_names = [n["name"] for n in nodes] - path_desc = " → ".join(node_names) + node_names = [n["name"] for n in nodes] + path_desc = " → ".join(node_names) # 计算置信度(基于路径长度和关系数量) - confidence = 1.0 / (len(entity_ids) - 1) if len(entity_ids) > 1 else 1.0 + confidence = 1.0 / (len(entity_ids) - 1) if len(entity_ids) > 1 else 1.0 return EntityPath( - path_id = f"path_{entity_ids[0]}_{entity_ids[-1]}_{hash(tuple(entity_ids))}", - source_entity_id = entity_ids[0], - source_entity_name = nodes[0]["name"] if nodes else "", - target_entity_id = entity_ids[-1], - target_entity_name = nodes[-1]["name"] if nodes else "", - path_length = len(entity_ids) - 1, - nodes = nodes, - edges = edges, - confidence = round(confidence, 4), - path_description = path_desc, + path_id=f"path_{entity_ids[0]}_{entity_ids[-1]}_{hash(tuple(entity_ids))}", + source_entity_id=entity_ids[0], + source_entity_name=nodes[0]["name"] if nodes else "", + target_entity_id=entity_ids[-1], + target_entity_name=nodes[-1]["name"] if nodes else "", + path_length=len(entity_ids) - 1, + nodes=nodes, + edges=edges, + confidence=round(confidence, 4), + path_description=path_desc, ) - def find_multi_hop_relations(self, entity_id: str, max_hops: int = 3) -> list[dict]: + def find_multi_hop_relations(self, entity_id: str, max_hops: int = 3) -> list[dict]: """ 查找实体的多跳关系 @@ -1391,10 +1391,10 @@ class EntityPathDiscovery: Returns: List[Dict]: 多跳关系列表 """ - conn = self._get_conn() + conn = self._get_conn() # 获取项目ID - row = conn.execute( + row = conn.execute( "SELECT project_id, name FROM entities WHERE id = ?", (entity_id, ) ).fetchone() @@ -1402,22 +1402,22 @@ class EntityPathDiscovery: conn.close() return [] - project_id = row["project_id"] + project_id = row["project_id"] row["name"] # BFS 收集多跳关系 - visited = {entity_id: 0} - queue = [(entity_id, 0)] - relations = [] + visited = {entity_id: 0} + queue = [(entity_id, 0)] + relations = [] while queue: - current_id, depth = queue.pop(0) + current_id, depth = queue.pop(0) if depth >= max_hops: continue # 获取邻居 - neighbors = conn.execute( + neighbors = conn.execute( """ SELECT CASE @@ -1434,14 +1434,14 @@ class EntityPathDiscovery: ).fetchall() for neighbor in neighbors: - neighbor_id = neighbor["neighbor_id"] + neighbor_id = neighbor["neighbor_id"] if neighbor_id not in visited: - visited[neighbor_id] = depth + 1 + visited[neighbor_id] = depth + 1 queue.append((neighbor_id, depth + 1)) # 获取邻居信息 - neighbor_info = conn.execute( + neighbor_info = conn.execute( "SELECT name, type FROM entities WHERE id = ?", (neighbor_id, ) ).fetchone() @@ -1463,7 +1463,7 @@ class EntityPathDiscovery: conn.close() # 按跳数排序 - relations.sort(key = lambda x: x["hops"]) + relations.sort(key=lambda x: x["hops"]) return relations def _get_path_to_entity( @@ -1471,11 +1471,11 @@ class EntityPathDiscovery: ) -> list[str]: """获取从源实体到目标实体的路径(简化版)""" # BFS 找路径 - visited = {source_id} - queue = [(source_id, [source_id])] + visited = {source_id} + queue = [(source_id, [source_id])] while queue: - current, path = queue.pop(0) + current, path = queue.pop(0) if current == target_id: return path @@ -1483,7 +1483,7 @@ class EntityPathDiscovery: if len(path) > 5: # 限制路径长度 continue - neighbors = conn.execute( + neighbors = conn.execute( """ SELECT CASE @@ -1498,7 +1498,7 @@ class EntityPathDiscovery: ).fetchall() for neighbor in neighbors: - neighbor_id = neighbor["neighbor_id"] + neighbor_id = neighbor["neighbor_id"] if neighbor_id not in visited: visited.add(neighbor_id) queue.append((neighbor_id, path + [neighbor_id])) @@ -1516,7 +1516,7 @@ class EntityPathDiscovery: Dict: D3.js 可视化数据格式 """ # 节点数据 - nodes = [] + nodes = [] for node in path.nodes: nodes.append( { @@ -1529,7 +1529,7 @@ class EntityPathDiscovery: ) # 边数据 - links = [] + links = [] for edge in path.edges: links.append( { @@ -1558,21 +1558,21 @@ class EntityPathDiscovery: Returns: List[Dict]: 中心性分析结果 """ - conn = self._get_conn() + conn = self._get_conn() # 获取所有实体 - entities = conn.execute( + entities = conn.execute( "SELECT id, name FROM entities WHERE project_id = ?", (project_id, ) ).fetchall() # 计算每个实体作为桥梁的次数 - bridge_scores = [] + bridge_scores = [] for entity in entities: - entity_id = entity["id"] + entity_id = entity["id"] # 计算该实体连接的不同群组数量 - neighbors = conn.execute( + neighbors = conn.execute( """ SELECT CASE @@ -1586,11 +1586,11 @@ class EntityPathDiscovery: (entity_id, entity_id, entity_id, project_id), ).fetchall() - neighbor_ids = {n["neighbor_id"] for n in neighbors} + neighbor_ids = {n["neighbor_id"] for n in neighbors} # 计算邻居之间的连接数(用于评估桥接程度) if len(neighbor_ids) > 1: - connections = conn.execute( + connections = conn.execute( f""" SELECT COUNT(*) as count FROM entity_relations @@ -1604,9 +1604,9 @@ class EntityPathDiscovery: ).fetchone() # 桥接分数 = 邻居数量 / (邻居间连接数 + 1) - bridge_score = len(neighbor_ids) / (connections["count"] + 1) + bridge_score = len(neighbor_ids) / (connections["count"] + 1) else: - bridge_score = 0 + bridge_score = 0 bridge_scores.append( { @@ -1620,7 +1620,7 @@ class EntityPathDiscovery: conn.close() # 按桥接分数排序 - bridge_scores.sort(key = lambda x: x["bridge_score"], reverse = True) + bridge_scores.sort(key=lambda x: x["bridge_score"], reverse=True) return bridge_scores[:20] # 返回前20 @@ -1638,13 +1638,13 @@ class KnowledgeGapDetection: - 生成知识补全建议 """ - def __init__(self, db_path: str = "insightflow.db") -> None: - self.db_path = db_path + def __init__(self, db_path: str = "insightflow.db") -> None: + self.db_path = db_path def _get_conn(self) -> sqlite3.Connection: """获取数据库连接""" - conn = sqlite3.connect(self.db_path) - conn.row_factory = sqlite3.Row + conn = sqlite3.connect(self.db_path) + conn.row_factory = sqlite3.Row return conn def analyze_project(self, project_id: str) -> list[KnowledgeGap]: @@ -1657,7 +1657,7 @@ class KnowledgeGapDetection: Returns: List[KnowledgeGap]: 知识缺口列表 """ - gaps = [] + gaps = [] # 1. 检查实体属性完整性 gaps.extend(self._check_entity_attribute_completeness(project_id)) @@ -1675,18 +1675,18 @@ class KnowledgeGapDetection: gaps.extend(self._check_missing_key_entities(project_id)) # 按严重程度排序 - severity_order = {"high": 0, "medium": 1, "low": 2} - gaps.sort(key = lambda x: severity_order.get(x.severity, 3)) + severity_order = {"high": 0, "medium": 1, "low": 2} + gaps.sort(key=lambda x: severity_order.get(x.severity, 3)) return gaps def _check_entity_attribute_completeness(self, project_id: str) -> list[KnowledgeGap]: """检查实体属性完整性""" - conn = self._get_conn() - gaps = [] + conn = self._get_conn() + gaps = [] # 获取项目的属性模板 - templates = conn.execute( + templates = conn.execute( "SELECT id, name, type, is_required FROM attribute_templates WHERE project_id = ?", (project_id, ), ).fetchall() @@ -1695,34 +1695,34 @@ class KnowledgeGapDetection: conn.close() return [] - required_template_ids = {t["id"] for t in templates if t["is_required"]} + required_template_ids = {t["id"] for t in templates if t["is_required"]} if not required_template_ids: conn.close() return [] # 检查每个实体的属性完整性 - entities = conn.execute( + entities = conn.execute( "SELECT id, name FROM entities WHERE project_id = ?", (project_id, ) ).fetchall() for entity in entities: - entity_id = entity["id"] + entity_id = entity["id"] # 获取实体已有的属性 - existing_attrs = conn.execute( + existing_attrs = conn.execute( "SELECT template_id FROM entity_attributes WHERE entity_id = ?", (entity_id, ) ).fetchall() - existing_template_ids = {a["template_id"] for a in existing_attrs} + existing_template_ids = {a["template_id"] for a in existing_attrs} # 找出缺失的必需属性 - missing_templates = required_template_ids - existing_template_ids + missing_templates = required_template_ids - existing_template_ids if missing_templates: - missing_names = [] + missing_names = [] for template_id in missing_templates: - template = conn.execute( + template = conn.execute( "SELECT name FROM attribute_templates WHERE id = ?", (template_id, ) ).fetchone() if template: @@ -1731,18 +1731,18 @@ class KnowledgeGapDetection: if missing_names: gaps.append( KnowledgeGap( - gap_id = f"gap_attr_{entity_id}", - gap_type = "missing_attribute", - entity_id = entity_id, - entity_name = entity["name"], - description = f"实体 '{entity['name']}' 缺少必需属性: {', '.join(missing_names)}", - severity = "medium", - suggestions = [ + gap_id=f"gap_attr_{entity_id}", + gap_type="missing_attribute", + entity_id=entity_id, + entity_name=entity["name"], + description=f"实体 '{entity['name']}' 缺少必需属性: {', '.join(missing_names)}", + severity="medium", + suggestions=[ f"为实体 '{entity['name']}' 补充以下属性: {', '.join(missing_names)}", "检查属性模板定义是否合理", ], - related_entities = [], - metadata = {"missing_attributes": missing_names}, + related_entities=[], + metadata={"missing_attributes": missing_names}, ) ) @@ -1751,19 +1751,19 @@ class KnowledgeGapDetection: def _check_relation_sparsity(self, project_id: str) -> list[KnowledgeGap]: """检查关系稀疏度""" - conn = self._get_conn() - gaps = [] + conn = self._get_conn() + gaps = [] # 获取所有实体及其关系数量 - entities = conn.execute( + entities = conn.execute( "SELECT id, name, type FROM entities WHERE project_id = ?", (project_id, ) ).fetchall() for entity in entities: - entity_id = entity["id"] + entity_id = entity["id"] # 计算关系数量 - relation_count = conn.execute( + relation_count = conn.execute( """ SELECT COUNT(*) as count FROM entity_relations @@ -1774,11 +1774,11 @@ class KnowledgeGapDetection: ).fetchone()["count"] # 根据实体类型判断阈值 - threshold = 1 if entity["type"] in ["PERSON", "ORG"] else 0 + threshold = 1 if entity["type"] in ["PERSON", "ORG"] else 0 if relation_count <= threshold: # 查找潜在的相关实体 - potential_related = conn.execute( + potential_related = conn.execute( """ SELECT e.id, e.name FROM entities e @@ -1793,19 +1793,19 @@ class KnowledgeGapDetection: gaps.append( KnowledgeGap( - gap_id = f"gap_sparse_{entity_id}", - gap_type = "sparse_relation", - entity_id = entity_id, - entity_name = entity["name"], - description = f"实体 '{entity['name']}' 关系稀疏(仅有 {relation_count} 个关系)", - severity = "medium" if relation_count == 0 else "low", - suggestions = [ + gap_id=f"gap_sparse_{entity_id}", + gap_type="sparse_relation", + entity_id=entity_id, + entity_name=entity["name"], + description=f"实体 '{entity['name']}' 关系稀疏(仅有 {relation_count} 个关系)", + severity="medium" if relation_count == 0 else "low", + suggestions=[ f"检查转录文本中提及 '{entity['name']}' 的其他实体", f"手动添加 '{entity['name']}' 与其他实体的关系", "使用实体对齐功能合并相似实体", ], - related_entities = [r["id"] for r in potential_related], - metadata = { + related_entities=[r["id"] for r in potential_related], + metadata={ "relation_count": relation_count, "potential_related": [r["name"] for r in potential_related], }, @@ -1817,11 +1817,11 @@ class KnowledgeGapDetection: def _check_isolated_entities(self, project_id: str) -> list[KnowledgeGap]: """检查孤立实体(没有任何关系)""" - conn = self._get_conn() - gaps = [] + conn = self._get_conn() + gaps = [] # 查找没有关系的实体 - isolated = conn.execute( + isolated = conn.execute( """ SELECT e.id, e.name, e.type FROM entities e @@ -1837,19 +1837,19 @@ class KnowledgeGapDetection: for entity in isolated: gaps.append( KnowledgeGap( - gap_id = f"gap_iso_{entity['id']}", - gap_type = "isolated_entity", - entity_id = entity["id"], - entity_name = entity["name"], - description = f"实体 '{entity['name']}' 是孤立实体(没有任何关系)", - severity = "high", - suggestions = [ + gap_id=f"gap_iso_{entity['id']}", + gap_type="isolated_entity", + entity_id=entity["id"], + entity_name=entity["name"], + description=f"实体 '{entity['name']}' 是孤立实体(没有任何关系)", + severity="high", + suggestions=[ f"检查 '{entity['name']}' 是否应该与其他实体建立关系", f"考虑删除不相关的实体 '{entity['name']}'", "运行关系发现算法自动识别潜在关系", ], - related_entities = [], - metadata = {"entity_type": entity["type"]}, + related_entities=[], + metadata={"entity_type": entity["type"]}, ) ) @@ -1858,11 +1858,11 @@ class KnowledgeGapDetection: def _check_incomplete_entities(self, project_id: str) -> list[KnowledgeGap]: """检查不完整实体(缺少名称、类型或定义)""" - conn = self._get_conn() - gaps = [] + conn = self._get_conn() + gaps = [] # 查找缺少定义的实体 - incomplete = conn.execute( + incomplete = conn.execute( """ SELECT id, name, type, definition FROM entities @@ -1875,15 +1875,15 @@ class KnowledgeGapDetection: for entity in incomplete: gaps.append( KnowledgeGap( - gap_id = f"gap_inc_{entity['id']}", - gap_type = "incomplete_entity", - entity_id = entity["id"], - entity_name = entity["name"], - description = f"实体 '{entity['name']}' 缺少定义", - severity = "low", - suggestions = [f"为 '{entity['name']}' 添加定义", "从转录文本中提取定义信息"], - related_entities = [], - metadata = {"entity_type": entity["type"]}, + gap_id=f"gap_inc_{entity['id']}", + gap_type="incomplete_entity", + entity_id=entity["id"], + entity_name=entity["name"], + description=f"实体 '{entity['name']}' 缺少定义", + severity="low", + suggestions=[f"为 '{entity['name']}' 添加定义", "从转录文本中提取定义信息"], + related_entities=[], + metadata={"entity_type": entity["type"]}, ) ) @@ -1892,30 +1892,30 @@ class KnowledgeGapDetection: def _check_missing_key_entities(self, project_id: str) -> list[KnowledgeGap]: """检查可能缺失的关键实体""" - conn = self._get_conn() - gaps = [] + conn = self._get_conn() + gaps = [] # 分析转录文本中频繁提及但未提取为实体的词 - transcripts = conn.execute( + transcripts = conn.execute( "SELECT full_text FROM transcripts WHERE project_id = ?", (project_id, ) ).fetchall() # 合并所有文本 - all_text = " ".join([t["full_text"] or "" for t in transcripts]) + all_text = " ".join([t["full_text"] or "" for t in transcripts]) # 获取现有实体名称 - existing_entities = conn.execute( + existing_entities = conn.execute( "SELECT name FROM entities WHERE project_id = ?", (project_id, ) ).fetchall() - existing_names = {e["name"].lower() for e in existing_entities} + existing_names = {e["name"].lower() for e in existing_entities} # 简单的关键词提取(实际可以使用更复杂的 NLP 方法) # 查找大写的词组(可能是专有名词) - potential_entities = re.findall(r"[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*", all_text) + potential_entities = re.findall(r"[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*", all_text) # 统计频率 - freq = defaultdict(int) + freq = defaultdict(int) for entity in potential_entities: if len(entity) > 3 and entity.lower() not in existing_names: freq[entity] += 1 @@ -1925,18 +1925,18 @@ class KnowledgeGapDetection: if count >= 3: # 出现3次以上 gaps.append( KnowledgeGap( - gap_id = f"gap_missing_{hash(entity) % 10000}", - gap_type = "missing_key_entity", - entity_id = None, - entity_name = None, - description = f"文本中频繁提及 '{entity}' 但未提取为实体(出现 {count} 次)", - severity = "low", - suggestions = [ + gap_id=f"gap_missing_{hash(entity) % 10000}", + gap_type="missing_key_entity", + entity_id=None, + entity_name=None, + description=f"文本中频繁提及 '{entity}' 但未提取为实体(出现 {count} 次)", + severity="low", + suggestions=[ f"考虑将 '{entity}' 添加为实体", "检查实体提取算法是否需要优化", ], - related_entities = [], - metadata = {"mention_count": count}, + related_entities=[], + metadata={"mention_count": count}, ) ) @@ -1953,10 +1953,10 @@ class KnowledgeGapDetection: Returns: Dict: 完整性报告 """ - conn = self._get_conn() + conn = self._get_conn() # 基础统计 - stats = conn.execute( + stats = conn.execute( """ SELECT (SELECT COUNT(*) FROM entities WHERE project_id = ?) as entity_count, @@ -1967,22 +1967,22 @@ class KnowledgeGapDetection: ).fetchone() # 计算完整性分数 - gaps = self.analyze_project(project_id) + gaps = self.analyze_project(project_id) # 按类型统计 - gap_by_type = defaultdict(int) - severity_count = {"high": 0, "medium": 0, "low": 0} + gap_by_type = defaultdict(int) + severity_count = {"high": 0, "medium": 0, "low": 0} for gap in gaps: gap_by_type[gap.gap_type] += 1 severity_count[gap.severity] += 1 # 计算完整性分数(100 - 扣分) - score = 100 + score = 100 score -= severity_count["high"] * 10 score -= severity_count["medium"] * 5 score -= severity_count["low"] * 2 - score = max(0, score) + score = max(0, score) conn.close() @@ -2005,9 +2005,9 @@ class KnowledgeGapDetection: def _generate_recommendations(self, gaps: list[KnowledgeGap]) -> list[str]: """生成改进建议""" - recommendations = [] + recommendations = [] - gap_types = {g.gap_type for g in gaps} + gap_types = {g.gap_type for g in gaps} if "isolated_entity" in gap_types: recommendations.append("优先处理孤立实体,建立实体间的关系连接") @@ -2040,14 +2040,14 @@ class SearchManager: 整合全文搜索、语义搜索、实体路径发现和知识缺口识别功能 """ - def __init__(self, db_path: str = "insightflow.db") -> None: - self.db_path = db_path - self.fulltext_search = FullTextSearch(db_path) - self.semantic_search = SemanticSearch(db_path) - self.path_discovery = EntityPathDiscovery(db_path) - self.gap_detection = KnowledgeGapDetection(db_path) + def __init__(self, db_path: str = "insightflow.db") -> None: + self.db_path = db_path + self.fulltext_search = FullTextSearch(db_path) + self.semantic_search = SemanticSearch(db_path) + self.path_discovery = EntityPathDiscovery(db_path) + self.gap_detection = KnowledgeGapDetection(db_path) - def hybrid_search(self, query: str, project_id: str | None = None, limit: int = 20) -> dict: + def hybrid_search(self, query: str, project_id: str | None = None, limit: int = 20) -> dict: """ 混合搜索(全文 + 语义) @@ -2060,20 +2060,20 @@ class SearchManager: Dict: 混合搜索结果 """ # 全文搜索 - fulltext_results = self.fulltext_search.search(query, project_id, limit = limit) + fulltext_results = self.fulltext_search.search(query, project_id, limit=limit) # 语义搜索 - semantic_results = [] + semantic_results = [] if self.semantic_search.is_available(): - semantic_results = self.semantic_search.search(query, project_id, top_k = limit) + semantic_results = self.semantic_search.search(query, project_id, top_k=limit) # 合并结果(去重并加权) - combined = {} + combined = {} # 添加全文搜索结果 for r in fulltext_results: - key = (r.id, r.content_type) - combined[key] = { + key = (r.id, r.content_type) + combined[key] = { "id": r.id, "content": r.content, "content_type": r.content_type, @@ -2086,12 +2086,12 @@ class SearchManager: # 添加语义搜索结果 for r in semantic_results: - key = (r.id, r.content_type) + key = (r.id, r.content_type) if key in combined: - combined[key]["semantic_score"] = r.similarity + combined[key]["semantic_score"] = r.similarity combined[key]["combined_score"] += r.similarity * 0.4 # 语义权重 40% else: - combined[key] = { + combined[key] = { "id": r.id, "content": r.content, "content_type": r.content_type, @@ -2103,8 +2103,8 @@ class SearchManager: } # 排序 - results = list(combined.values()) - results.sort(key = lambda x: x["combined_score"], reverse = True) + results = list(combined.values()) + results.sort(key=lambda x: x["combined_score"], reverse=True) return { "query": query, @@ -2126,17 +2126,17 @@ class SearchManager: Dict: 索引统计 """ # 全文索引 - fulltext_stats = self.fulltext_search.reindex_project(project_id) + fulltext_stats = self.fulltext_search.reindex_project(project_id) # 语义索引 - semantic_stats = {"indexed": 0, "errors": 0} + semantic_stats = {"indexed": 0, "errors": 0} if self.semantic_search.is_available(): - conn = sqlite3.connect(self.db_path) - conn.row_factory = sqlite3.Row + conn = sqlite3.connect(self.db_path) + conn.row_factory = sqlite3.Row # 索引转录文本 - transcripts = conn.execute( + transcripts = conn.execute( "SELECT id, project_id, full_text FROM transcripts WHERE project_id = ?", (project_id, ), ).fetchall() @@ -2150,13 +2150,13 @@ class SearchManager: semantic_stats["errors"] += 1 # 索引实体 - entities = conn.execute( + entities = conn.execute( "SELECT id, project_id, name, definition FROM entities WHERE project_id = ?", (project_id, ), ).fetchall() for e in entities: - text = f"{e['name']} {e['definition'] or ''}" + text = f"{e['name']} {e['definition'] or ''}" if self.semantic_search.index_embedding(e["id"], "entity", e["project_id"], text): semantic_stats["indexed"] += 1 else: @@ -2166,34 +2166,34 @@ class SearchManager: return {"project_id": project_id, "fulltext": fulltext_stats, "semantic": semantic_stats} - def get_search_stats(self, project_id: str | None = None) -> dict: + def get_search_stats(self, project_id: str | None = None) -> dict: """获取搜索统计信息""" - conn = sqlite3.connect(self.db_path) - conn.row_factory = sqlite3.Row + conn = sqlite3.connect(self.db_path) + conn.row_factory = sqlite3.Row - where_clause = "WHERE project_id = ?" if project_id else "" - params = [project_id] if project_id else [] + where_clause = "WHERE project_id = ?" if project_id else "" + params = [project_id] if project_id else [] # 全文索引统计 - fulltext_count = conn.execute( + fulltext_count = conn.execute( f"SELECT COUNT(*) as count FROM search_indexes {where_clause}", params ).fetchone()["count"] # 语义索引统计 - semantic_count = conn.execute( + semantic_count = conn.execute( f"SELECT COUNT(*) as count FROM embeddings {where_clause}", params ).fetchone()["count"] # 按类型统计 - type_stats = {} + type_stats = {} if project_id: - rows = conn.execute( + rows = conn.execute( """SELECT content_type, COUNT(*) as count FROM search_indexes WHERE project_id = ? GROUP BY content_type""", (project_id, ), ).fetchall() - type_stats = {r["content_type"]: r["count"] for r in rows} + type_stats = {r["content_type"]: r["count"] for r in rows} conn.close() @@ -2207,14 +2207,14 @@ class SearchManager: # 单例模式 -_search_manager = None +_search_manager = None -def get_search_manager(db_path: str = "insightflow.db") -> SearchManager: +def get_search_manager(db_path: str = "insightflow.db") -> SearchManager: """获取搜索管理器单例""" global _search_manager if _search_manager is None: - _search_manager = SearchManager(db_path) + _search_manager = SearchManager(db_path) return _search_manager @@ -2222,28 +2222,28 @@ def get_search_manager(db_path: str = "insightflow.db") -> SearchManager: def fulltext_search( - query: str, project_id: str | None = None, limit: int = 20 + query: str, project_id: str | None = None, limit: int = 20 ) -> list[SearchResult]: """全文搜索便捷函数""" - manager = get_search_manager() - return manager.fulltext_search.search(query, project_id, limit = limit) + manager = get_search_manager() + return manager.fulltext_search.search(query, project_id, limit=limit) def semantic_search( - query: str, project_id: str | None = None, top_k: int = 10 + query: str, project_id: str | None = None, top_k: int = 10 ) -> list[SemanticSearchResult]: """语义搜索便捷函数""" - manager = get_search_manager() - return manager.semantic_search.search(query, project_id, top_k = top_k) + manager = get_search_manager() + return manager.semantic_search.search(query, project_id, top_k=top_k) -def find_entity_path(source_id: str, target_id: str, max_depth: int = 5) -> EntityPath | None: +def find_entity_path(source_id: str, target_id: str, max_depth: int = 5) -> EntityPath | None: """查找实体路径便捷函数""" - manager = get_search_manager() + manager = get_search_manager() return manager.path_discovery.find_shortest_path(source_id, target_id, max_depth) def detect_knowledge_gaps(project_id: str) -> list[KnowledgeGap]: """知识缺口检测便捷函数""" - manager = get_search_manager() + manager = get_search_manager() return manager.gap_detection.analyze_project(project_id) diff --git a/backend/security_manager.py b/backend/security_manager.py index 3f7161b..9ff1299 100644 --- a/backend/security_manager.py +++ b/backend/security_manager.py @@ -20,54 +20,54 @@ try: from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC - CRYPTO_AVAILABLE = True + CRYPTO_AVAILABLE = True except ImportError: - CRYPTO_AVAILABLE = False + CRYPTO_AVAILABLE = False print("Warning: cryptography not available, encryption features disabled") class AuditActionType(Enum): """审计动作类型""" - CREATE = "create" - READ = "read" - UPDATE = "update" - DELETE = "delete" - LOGIN = "login" - LOGOUT = "logout" - EXPORT = "export" - IMPORT = "import" - SHARE = "share" - PERMISSION_CHANGE = "permission_change" - ENCRYPTION_ENABLE = "encryption_enable" - ENCRYPTION_DISABLE = "encryption_disable" - DATA_MASKING = "data_masking" - API_KEY_CREATE = "api_key_create" - API_KEY_REVOKE = "api_key_revoke" - WORKFLOW_TRIGGER = "workflow_trigger" - WEBHOOK_SEND = "webhook_send" - BOT_MESSAGE = "bot_message" + CREATE = "create" + READ = "read" + UPDATE = "update" + DELETE = "delete" + LOGIN = "login" + LOGOUT = "logout" + EXPORT = "export" + IMPORT = "import" + SHARE = "share" + PERMISSION_CHANGE = "permission_change" + ENCRYPTION_ENABLE = "encryption_enable" + ENCRYPTION_DISABLE = "encryption_disable" + DATA_MASKING = "data_masking" + API_KEY_CREATE = "api_key_create" + API_KEY_REVOKE = "api_key_revoke" + WORKFLOW_TRIGGER = "workflow_trigger" + WEBHOOK_SEND = "webhook_send" + BOT_MESSAGE = "bot_message" class DataSensitivityLevel(Enum): """数据敏感度级别""" - PUBLIC = "public" # 公开 - INTERNAL = "internal" # 内部 - CONFIDENTIAL = "confidential" # 机密 - SECRET = "secret" # 绝密 + PUBLIC = "public" # 公开 + INTERNAL = "internal" # 内部 + CONFIDENTIAL = "confidential" # 机密 + SECRET = "secret" # 绝密 class MaskingRuleType(Enum): """脱敏规则类型""" - PHONE = "phone" # 手机号 - EMAIL = "email" # 邮箱 - ID_CARD = "id_card" # 身份证号 - BANK_CARD = "bank_card" # 银行卡号 - NAME = "name" # 姓名 - ADDRESS = "address" # 地址 - CUSTOM = "custom" # 自定义 + PHONE = "phone" # 手机号 + EMAIL = "email" # 邮箱 + ID_CARD = "id_card" # 身份证号 + BANK_CARD = "bank_card" # 银行卡号 + NAME = "name" # 姓名 + ADDRESS = "address" # 地址 + CUSTOM = "custom" # 自定义 @dataclass @@ -76,17 +76,17 @@ class AuditLog: id: str action_type: str - user_id: str | None = None - user_ip: str | None = None - user_agent: str | None = None - resource_type: str | None = None # project, entity, transcript, etc. - resource_id: str | None = None - action_details: str | None = None # JSON string - before_value: str | None = None - after_value: str | None = None - success: bool = True - error_message: str | None = None - created_at: str = field(default_factory = lambda: datetime.now().isoformat()) + user_id: str | None = None + user_ip: str | None = None + user_agent: str | None = None + resource_type: str | None = None # project, entity, transcript, etc. + resource_id: str | None = None + action_details: str | None = None # JSON string + before_value: str | None = None + after_value: str | None = None + success: bool = True + error_message: str | None = None + created_at: str = field(default_factory=lambda: datetime.now().isoformat()) def to_dict(self) -> dict[str, Any]: return asdict(self) @@ -98,13 +98,13 @@ class EncryptionConfig: id: str project_id: str - is_enabled: bool = False - encryption_type: str = "aes-256-gcm" # aes-256-gcm, chacha20-poly1305 - key_derivation: str = "pbkdf2" # pbkdf2, argon2 - master_key_hash: str | None = None # 主密钥哈希(用于验证) - salt: str | None = None - created_at: str = field(default_factory = lambda: datetime.now().isoformat()) - updated_at: str = field(default_factory = lambda: datetime.now().isoformat()) + is_enabled: bool = False + encryption_type: str = "aes-256-gcm" # aes-256-gcm, chacha20-poly1305 + key_derivation: str = "pbkdf2" # pbkdf2, argon2 + master_key_hash: str | None = None # 主密钥哈希(用于验证) + salt: str | None = None + created_at: str = field(default_factory=lambda: datetime.now().isoformat()) + updated_at: str = field(default_factory=lambda: datetime.now().isoformat()) def to_dict(self) -> dict[str, Any]: return asdict(self) @@ -120,11 +120,11 @@ class MaskingRule: rule_type: str # phone, email, id_card, bank_card, name, address, custom pattern: str # 正则表达式 replacement: str # 替换模板,如 "****" - is_active: bool = True - priority: int = 0 - description: str | None = None - created_at: str = field(default_factory = lambda: datetime.now().isoformat()) - updated_at: str = field(default_factory = lambda: datetime.now().isoformat()) + is_active: bool = True + priority: int = 0 + description: str | None = None + created_at: str = field(default_factory=lambda: datetime.now().isoformat()) + updated_at: str = field(default_factory=lambda: datetime.now().isoformat()) def to_dict(self) -> dict[str, Any]: return asdict(self) @@ -137,16 +137,16 @@ class DataAccessPolicy: id: str project_id: str name: str - description: str | None = None - allowed_users: str | None = None # JSON array of user IDs - allowed_roles: str | None = None # JSON array of roles - allowed_ips: str | None = None # JSON array of IP patterns - time_restrictions: str | None = None # JSON: {"start_time": "09:00", "end_time": "18:00"} - max_access_count: int | None = None # 最大访问次数 - require_approval: bool = False - is_active: bool = True - created_at: str = field(default_factory = lambda: datetime.now().isoformat()) - updated_at: str = field(default_factory = lambda: datetime.now().isoformat()) + description: str | None = None + allowed_users: str | None = None # JSON array of user IDs + allowed_roles: str | None = None # JSON array of roles + allowed_ips: str | None = None # JSON array of IP patterns + time_restrictions: str | None = None # JSON: {"start_time": "09:00", "end_time": "18:00"} + max_access_count: int | None = None # 最大访问次数 + require_approval: bool = False + is_active: bool = True + created_at: str = field(default_factory=lambda: datetime.now().isoformat()) + updated_at: str = field(default_factory=lambda: datetime.now().isoformat()) def to_dict(self) -> dict[str, Any]: return asdict(self) @@ -159,12 +159,12 @@ class AccessRequest: id: str policy_id: str user_id: str - request_reason: str | None = None - status: str = "pending" # pending, approved, rejected, expired - approved_by: str | None = None - approved_at: str | None = None - expires_at: str | None = None - created_at: str = field(default_factory = lambda: datetime.now().isoformat()) + request_reason: str | None = None + status: str = "pending" # pending, approved, rejected, expired + approved_by: str | None = None + approved_at: str | None = None + expires_at: str | None = None + created_at: str = field(default_factory=lambda: datetime.now().isoformat()) def to_dict(self) -> dict[str, Any]: return asdict(self) @@ -174,7 +174,7 @@ class SecurityManager: """安全管理器""" # 预定义脱敏规则 - DEFAULT_MASKING_RULES = { + DEFAULT_MASKING_RULES = { MaskingRuleType.PHONE: {"pattern": r"(\d{3})\d{4}(\d{4})", "replacement": r"\1****\2"}, MaskingRuleType.EMAIL: {"pattern": r"(\w{1, 3})\w+(@\w+\.\w+)", "replacement": r"\1***\2"}, MaskingRuleType.ID_CARD: { @@ -195,17 +195,17 @@ class SecurityManager: }, } - def __init__(self, db_path: str = "insightflow.db") -> None: - self.db_path = db_path + def __init__(self, db_path: str = "insightflow.db") -> None: + self.db_path = db_path # 预编译正则缓存 - self._compiled_patterns: dict[str, re.Pattern] = {} - self._local = {} + self._compiled_patterns: dict[str, re.Pattern] = {} + self._local = {} self._init_db() def _init_db(self) -> None: """初始化数据库表""" - conn = sqlite3.connect(self.db_path) - cursor = conn.cursor() + conn = sqlite3.connect(self.db_path) + cursor = conn.cursor() # 审计日志表 cursor.execute(""" @@ -332,35 +332,35 @@ class SecurityManager: def log_audit( self, action_type: AuditActionType, - user_id: str | None = None, - user_ip: str | None = None, - user_agent: str | None = None, - resource_type: str | None = None, - resource_id: str | None = None, - action_details: dict | None = None, - before_value: str | None = None, - after_value: str | None = None, - success: bool = True, - error_message: str | None = None, + user_id: str | None = None, + user_ip: str | None = None, + user_agent: str | None = None, + resource_type: str | None = None, + resource_id: str | None = None, + action_details: dict | None = None, + before_value: str | None = None, + after_value: str | None = None, + success: bool = True, + error_message: str | None = None, ) -> AuditLog: """记录审计日志""" - log = AuditLog( - id = self._generate_id(), - action_type = action_type.value, - user_id = user_id, - user_ip = user_ip, - user_agent = user_agent, - resource_type = resource_type, - resource_id = resource_id, - action_details = json.dumps(action_details) if action_details else None, - before_value = before_value, - after_value = after_value, - success = success, - error_message = error_message, + log = AuditLog( + id=self._generate_id(), + action_type=action_type.value, + user_id=user_id, + user_ip=user_ip, + user_agent=user_agent, + resource_type=resource_type, + resource_id=resource_id, + action_details=json.dumps(action_details) if action_details else None, + before_value=before_value, + after_value=after_value, + success=success, + error_message=error_message, ) - conn = sqlite3.connect(self.db_path) - cursor = conn.cursor() + conn = sqlite3.connect(self.db_path) + cursor = conn.cursor() cursor.execute( """ INSERT INTO audit_logs @@ -391,22 +391,22 @@ class SecurityManager: def get_audit_logs( self, - user_id: str | None = None, - resource_type: str | None = None, - resource_id: str | None = None, - action_type: str | None = None, - start_time: str | None = None, - end_time: str | None = None, - success: bool | None = None, - limit: int = 100, - offset: int = 0, + user_id: str | None = None, + resource_type: str | None = None, + resource_id: str | None = None, + action_type: str | None = None, + start_time: str | None = None, + end_time: str | None = None, + success: bool | None = None, + limit: int = 100, + offset: int = 0, ) -> list[AuditLog]: """查询审计日志""" - conn = sqlite3.connect(self.db_path) - cursor = conn.cursor() + conn = sqlite3.connect(self.db_path) + cursor = conn.cursor() - query = "SELECT * FROM audit_logs WHERE 1 = 1" - params = [] + query = "SELECT * FROM audit_logs WHERE 1 = 1" + params = [] if user_id: query += " AND user_id = ?" @@ -434,29 +434,29 @@ class SecurityManager: params.extend([limit, offset]) cursor.execute(query, params) - rows = cursor.fetchall() + rows = cursor.fetchall() conn.close() - logs = [] - col_names = [desc[0] for desc in cursor.description] if cursor.description else [] + logs = [] + col_names = [desc[0] for desc in cursor.description] if cursor.description else [] if not col_names: return logs for row in rows: - log = AuditLog( - id = row[0], - action_type = row[1], - user_id = row[2], - user_ip = row[3], - user_agent = row[4], - resource_type = row[5], - resource_id = row[6], - action_details = row[7], - before_value = row[8], - after_value = row[9], - success = bool(row[10]), - error_message = row[11], - created_at = row[12], + log = AuditLog( + id=row[0], + action_type=row[1], + user_id=row[2], + user_ip=row[3], + user_agent=row[4], + resource_type=row[5], + resource_id=row[6], + action_details=row[7], + before_value=row[8], + after_value=row[9], + success=bool(row[10]), + error_message=row[11], + created_at=row[12], ) logs.append(log) @@ -464,14 +464,14 @@ class SecurityManager: return logs def get_audit_stats( - self, start_time: str | None = None, end_time: str | None = None + self, start_time: str | None = None, end_time: str | None = None ) -> dict[str, Any]: """获取审计统计""" - conn = sqlite3.connect(self.db_path) - cursor = conn.cursor() + conn = sqlite3.connect(self.db_path) + cursor = conn.cursor() - query = "SELECT action_type, success, COUNT(*) FROM audit_logs WHERE 1 = 1" - params = [] + query = "SELECT action_type, success, COUNT(*) FROM audit_logs WHERE 1 = 1" + params = [] if start_time: query += " AND created_at >= ?" @@ -483,9 +483,9 @@ class SecurityManager: query += " GROUP BY action_type, success" cursor.execute(query, params) - rows = cursor.fetchall() + rows = cursor.fetchall() - stats = {"total_actions": 0, "success_count": 0, "failure_count": 0, "action_breakdown": {}} + stats = {"total_actions": 0, "success_count": 0, "failure_count": 0, "action_breakdown": {}} for action_type, success, count in rows: stats["total_actions"] += count @@ -495,7 +495,7 @@ class SecurityManager: stats["failure_count"] += count if action_type not in stats["action_breakdown"]: - stats["action_breakdown"][action_type] = {"success": 0, "failure": 0} + stats["action_breakdown"][action_type] = {"success": 0, "failure": 0} if success: stats["action_breakdown"][action_type]["success"] += count @@ -512,11 +512,11 @@ class SecurityManager: if not CRYPTO_AVAILABLE: raise RuntimeError("cryptography library not available") - kdf = PBKDF2HMAC( - algorithm = hashes.SHA256(), - length = 32, - salt = salt, - iterations = 100000, + kdf = PBKDF2HMAC( + algorithm=hashes.SHA256(), + length=32, + salt=salt, + iterations=100000, ) return base64.urlsafe_b64encode(kdf.derive(password.encode())) @@ -526,28 +526,28 @@ class SecurityManager: raise RuntimeError("cryptography library not available") # 生成盐值 - salt = secrets.token_hex(16) + salt = secrets.token_hex(16) # 派生密钥并哈希(用于验证) - key = self._derive_key(master_password, salt.encode()) - key_hash = hashlib.sha256(key).hexdigest() + key = self._derive_key(master_password, salt.encode()) + key_hash = hashlib.sha256(key).hexdigest() - config = EncryptionConfig( - id = self._generate_id(), - project_id = project_id, - is_enabled = True, - encryption_type = "aes-256-gcm", - key_derivation = "pbkdf2", - master_key_hash = key_hash, - salt = salt, + config = EncryptionConfig( + id=self._generate_id(), + project_id=project_id, + is_enabled=True, + encryption_type="aes-256-gcm", + key_derivation="pbkdf2", + master_key_hash=key_hash, + salt=salt, ) - conn = sqlite3.connect(self.db_path) - cursor = conn.cursor() + conn = sqlite3.connect(self.db_path) + cursor = conn.cursor() # 检查是否已存在配置 cursor.execute("SELECT id FROM encryption_configs WHERE project_id = ?", (project_id, )) - existing = cursor.fetchone() + existing = cursor.fetchone() if existing: cursor.execute( @@ -566,7 +566,7 @@ class SecurityManager: project_id, ), ) - config.id = existing[0] + config.id = existing[0] else: cursor.execute( """ @@ -593,10 +593,10 @@ class SecurityManager: # 记录审计日志 self.log_audit( - action_type = AuditActionType.ENCRYPTION_ENABLE, - resource_type = "project", - resource_id = project_id, - action_details = {"encryption_type": config.encryption_type}, + action_type=AuditActionType.ENCRYPTION_ENABLE, + resource_type="project", + resource_id=project_id, + action_details={"encryption_type": config.encryption_type}, ) return config @@ -607,8 +607,8 @@ class SecurityManager: if not self.verify_encryption_password(project_id, master_password): return False - conn = sqlite3.connect(self.db_path) - cursor = conn.cursor() + conn = sqlite3.connect(self.db_path) + cursor = conn.cursor() cursor.execute( """ @@ -624,9 +624,9 @@ class SecurityManager: # 记录审计日志 self.log_audit( - action_type = AuditActionType.ENCRYPTION_DISABLE, - resource_type = "project", - resource_id = project_id, + action_type=AuditActionType.ENCRYPTION_DISABLE, + resource_type="project", + resource_id=project_id, ) return True @@ -636,60 +636,60 @@ class SecurityManager: if not CRYPTO_AVAILABLE: return False - conn = sqlite3.connect(self.db_path) - cursor = conn.cursor() + conn = sqlite3.connect(self.db_path) + cursor = conn.cursor() cursor.execute( "SELECT master_key_hash, salt FROM encryption_configs WHERE project_id = ?", (project_id, ), ) - row = cursor.fetchone() + row = cursor.fetchone() conn.close() if not row: return False - stored_hash, salt = row - key = self._derive_key(password, salt.encode()) - key_hash = hashlib.sha256(key).hexdigest() + stored_hash, salt = row + key = self._derive_key(password, salt.encode()) + key_hash = hashlib.sha256(key).hexdigest() return key_hash == stored_hash def get_encryption_config(self, project_id: str) -> EncryptionConfig | None: """获取加密配置""" - conn = sqlite3.connect(self.db_path) - cursor = conn.cursor() + conn = sqlite3.connect(self.db_path) + cursor = conn.cursor() cursor.execute("SELECT * FROM encryption_configs WHERE project_id = ?", (project_id, )) - row = cursor.fetchone() + row = cursor.fetchone() conn.close() if not row: return None return EncryptionConfig( - id = row[0], - project_id = row[1], - is_enabled = bool(row[2]), - encryption_type = row[3], - key_derivation = row[4], - master_key_hash = row[5], - salt = row[6], - created_at = row[7], - updated_at = row[8], + id=row[0], + project_id=row[1], + is_enabled=bool(row[2]), + encryption_type=row[3], + key_derivation=row[4], + master_key_hash=row[5], + salt=row[6], + created_at=row[7], + updated_at=row[8], ) - def encrypt_data(self, data: str, password: str, salt: str | None = None) -> tuple[str, str]: + def encrypt_data(self, data: str, password: str, salt: str | None = None) -> tuple[str, str]: """加密数据""" if not CRYPTO_AVAILABLE: raise RuntimeError("cryptography library not available") if salt is None: - salt = secrets.token_hex(16) + salt = secrets.token_hex(16) - key = self._derive_key(password, salt.encode()) - f = Fernet(key) - encrypted = f.encrypt(data.encode()) + key = self._derive_key(password, salt.encode()) + f = Fernet(key) + encrypted = f.encrypt(data.encode()) return base64.b64encode(encrypted).decode(), salt @@ -698,9 +698,9 @@ class SecurityManager: if not CRYPTO_AVAILABLE: raise RuntimeError("cryptography library not available") - key = self._derive_key(password, salt.encode()) - f = Fernet(key) - decrypted = f.decrypt(base64.b64decode(encrypted_data)) + key = self._derive_key(password, salt.encode()) + f = Fernet(key) + decrypted = f.decrypt(base64.b64decode(encrypted_data)) return decrypted.decode() @@ -711,31 +711,31 @@ class SecurityManager: project_id: str, name: str, rule_type: MaskingRuleType, - pattern: str | None = None, - replacement: str | None = None, - description: str | None = None, - priority: int = 0, + pattern: str | None = None, + replacement: str | None = None, + description: str | None = None, + priority: int = 0, ) -> MaskingRule: """创建脱敏规则""" # 使用预定义规则或自定义规则 if rule_type in self.DEFAULT_MASKING_RULES and not pattern: - default = self.DEFAULT_MASKING_RULES[rule_type] - pattern = default["pattern"] - replacement = replacement or default["replacement"] + default = self.DEFAULT_MASKING_RULES[rule_type] + pattern = default["pattern"] + replacement = replacement or default["replacement"] - rule = MaskingRule( - id = self._generate_id(), - project_id = project_id, - name = name, - rule_type = rule_type.value, - pattern = pattern or "", - replacement = replacement or "****", - description = description, - priority = priority, + rule = MaskingRule( + id=self._generate_id(), + project_id=project_id, + name=name, + rule_type=rule_type.value, + pattern=pattern or "", + replacement=replacement or "****", + description=description, + priority=priority, ) - conn = sqlite3.connect(self.db_path) - cursor = conn.cursor() + conn = sqlite3.connect(self.db_path) + cursor = conn.cursor() cursor.execute( """ @@ -764,21 +764,21 @@ class SecurityManager: # 记录审计日志 self.log_audit( - action_type = AuditActionType.DATA_MASKING, - resource_type = "project", - resource_id = project_id, - action_details = {"action": "create_rule", "rule_name": name}, + action_type=AuditActionType.DATA_MASKING, + resource_type="project", + resource_id=project_id, + action_details={"action": "create_rule", "rule_name": name}, ) return rule - def get_masking_rules(self, project_id: str, active_only: bool = True) -> list[MaskingRule]: + def get_masking_rules(self, project_id: str, active_only: bool = True) -> list[MaskingRule]: """获取脱敏规则""" - conn = sqlite3.connect(self.db_path) - cursor = conn.cursor() + conn = sqlite3.connect(self.db_path) + cursor = conn.cursor() - query = "SELECT * FROM masking_rules WHERE project_id = ?" - params = [project_id] + query = "SELECT * FROM masking_rules WHERE project_id = ?" + params = [project_id] if active_only: query += " AND is_active = 1" @@ -786,24 +786,24 @@ class SecurityManager: query += " ORDER BY priority DESC" cursor.execute(query, params) - rows = cursor.fetchall() + rows = cursor.fetchall() conn.close() - rules = [] + rules = [] for row in rows: rules.append( MaskingRule( - id = row[0], - project_id = row[1], - name = row[2], - rule_type = row[3], - pattern = row[4], - replacement = row[5], - is_active = bool(row[6]), - priority = row[7], - description = row[8], - created_at = row[9], - updated_at = row[10], + id=row[0], + project_id=row[1], + name=row[2], + rule_type=row[3], + pattern=row[4], + replacement=row[5], + is_active=bool(row[6]), + priority=row[7], + description=row[8], + created_at=row[9], + updated_at=row[10], ) ) @@ -811,13 +811,13 @@ class SecurityManager: def update_masking_rule(self, rule_id: str, **kwargs) -> MaskingRule | None: """更新脱敏规则""" - allowed_fields = ["name", "pattern", "replacement", "is_active", "priority", "description"] + allowed_fields = ["name", "pattern", "replacement", "is_active", "priority", "description"] - conn = sqlite3.connect(self.db_path) - cursor = conn.cursor() + conn = sqlite3.connect(self.db_path) + cursor = conn.cursor() - set_clauses = [] - params = [] + set_clauses = [] + params = [] for key, value in kwargs.items(): if key in allowed_fields: @@ -845,52 +845,52 @@ class SecurityManager: conn.close() # 获取更新后的规则 - conn = sqlite3.connect(self.db_path) - cursor = conn.cursor() + conn = sqlite3.connect(self.db_path) + cursor = conn.cursor() cursor.execute("SELECT * FROM masking_rules WHERE id = ?", (rule_id, )) - row = cursor.fetchone() + row = cursor.fetchone() conn.close() if not row: return None return MaskingRule( - id = row[0], - project_id = row[1], - name = row[2], - rule_type = row[3], - pattern = row[4], - replacement = row[5], - is_active = bool(row[6]), - priority = row[7], - description = row[8], - created_at = row[9], - updated_at = row[10], + id=row[0], + project_id=row[1], + name=row[2], + rule_type=row[3], + pattern=row[4], + replacement=row[5], + is_active=bool(row[6]), + priority=row[7], + description=row[8], + created_at=row[9], + updated_at=row[10], ) def delete_masking_rule(self, rule_id: str) -> bool: """删除脱敏规则""" - conn = sqlite3.connect(self.db_path) - cursor = conn.cursor() + conn = sqlite3.connect(self.db_path) + cursor = conn.cursor() cursor.execute("DELETE FROM masking_rules WHERE id = ?", (rule_id, )) - success = cursor.rowcount > 0 + success = cursor.rowcount > 0 conn.commit() conn.close() return success def apply_masking( - self, text: str, project_id: str, rule_types: list[MaskingRuleType] | None = None + self, text: str, project_id: str, rule_types: list[MaskingRuleType] | None = None ) -> str: """应用脱敏规则到文本""" - rules = self.get_masking_rules(project_id) + rules = self.get_masking_rules(project_id) if not rules: return text - masked_text = text + masked_text = text for rule in rules: # 如果指定了规则类型,只应用指定类型的规则 @@ -898,7 +898,7 @@ class SecurityManager: continue try: - masked_text = re.sub(rule.pattern, rule.replacement, masked_text) + masked_text = re.sub(rule.pattern, rule.replacement, masked_text) except re.error: # 忽略无效的正则表达式 continue @@ -909,14 +909,14 @@ class SecurityManager: self, entity_data: dict[str, Any], project_id: str ) -> dict[str, Any]: """对实体数据应用脱敏""" - masked_data = entity_data.copy() + masked_data = entity_data.copy() # 对可能包含敏感信息的字段进行脱敏 - sensitive_fields = ["name", "definition", "description", "value"] + sensitive_fields = ["name", "definition", "description", "value"] for f in sensitive_fields: if f in masked_data and isinstance(masked_data[f], str): - masked_data[f] = self.apply_masking(masked_data[f], project_id) + masked_data[f] = self.apply_masking(masked_data[f], project_id) return masked_data @@ -926,30 +926,30 @@ class SecurityManager: self, project_id: str, name: str, - description: str | None = None, - allowed_users: list[str] | None = None, - allowed_roles: list[str] | None = None, - allowed_ips: list[str] | None = None, - time_restrictions: dict | None = None, - max_access_count: int | None = None, - require_approval: bool = False, + description: str | None = None, + allowed_users: list[str] | None = None, + allowed_roles: list[str] | None = None, + allowed_ips: list[str] | None = None, + time_restrictions: dict | None = None, + max_access_count: int | None = None, + require_approval: bool = False, ) -> DataAccessPolicy: """创建数据访问策略""" - policy = DataAccessPolicy( - id = self._generate_id(), - project_id = project_id, - name = name, - description = description, - allowed_users = json.dumps(allowed_users) if allowed_users else None, - allowed_roles = json.dumps(allowed_roles) if allowed_roles else None, - allowed_ips = json.dumps(allowed_ips) if allowed_ips else None, - time_restrictions = json.dumps(time_restrictions) if time_restrictions else None, - max_access_count = max_access_count, - require_approval = require_approval, + policy = DataAccessPolicy( + id=self._generate_id(), + project_id=project_id, + name=name, + description=description, + allowed_users=json.dumps(allowed_users) if allowed_users else None, + allowed_roles=json.dumps(allowed_roles) if allowed_roles else None, + allowed_ips=json.dumps(allowed_ips) if allowed_ips else None, + time_restrictions=json.dumps(time_restrictions) if time_restrictions else None, + max_access_count=max_access_count, + require_approval=require_approval, ) - conn = sqlite3.connect(self.db_path) - cursor = conn.cursor() + conn = sqlite3.connect(self.db_path) + cursor = conn.cursor() cursor.execute( """ @@ -982,100 +982,100 @@ class SecurityManager: return policy def get_access_policies( - self, project_id: str, active_only: bool = True + self, project_id: str, active_only: bool = True ) -> list[DataAccessPolicy]: """获取数据访问策略""" - conn = sqlite3.connect(self.db_path) - cursor = conn.cursor() + conn = sqlite3.connect(self.db_path) + cursor = conn.cursor() - query = "SELECT * FROM data_access_policies WHERE project_id = ?" - params = [project_id] + query = "SELECT * FROM data_access_policies WHERE project_id = ?" + params = [project_id] if active_only: query += " AND is_active = 1" cursor.execute(query, params) - rows = cursor.fetchall() + rows = cursor.fetchall() conn.close() - policies = [] + policies = [] for row in rows: policies.append( DataAccessPolicy( - id = row[0], - project_id = row[1], - name = row[2], - description = row[3], - allowed_users = row[4], - allowed_roles = row[5], - allowed_ips = row[6], - time_restrictions = row[7], - max_access_count = row[8], - require_approval = bool(row[9]), - is_active = bool(row[10]), - created_at = row[11], - updated_at = row[12], + id=row[0], + project_id=row[1], + name=row[2], + description=row[3], + allowed_users=row[4], + allowed_roles=row[5], + allowed_ips=row[6], + time_restrictions=row[7], + max_access_count=row[8], + require_approval=bool(row[9]), + is_active=bool(row[10]), + created_at=row[11], + updated_at=row[12], ) ) return policies def check_access_permission( - self, policy_id: str, user_id: str, user_ip: str | None = None + self, policy_id: str, user_id: str, user_ip: str | None = None ) -> tuple[bool, str | None]: """检查访问权限""" - conn = sqlite3.connect(self.db_path) - cursor = conn.cursor() + conn = sqlite3.connect(self.db_path) + cursor = conn.cursor() cursor.execute( "SELECT * FROM data_access_policies WHERE id = ? AND is_active = 1", (policy_id, ) ) - row = cursor.fetchone() + row = cursor.fetchone() conn.close() if not row: return False, "Policy not found or inactive" - policy = DataAccessPolicy( - id = row[0], - project_id = row[1], - name = row[2], - description = row[3], - allowed_users = row[4], - allowed_roles = row[5], - allowed_ips = row[6], - time_restrictions = row[7], - max_access_count = row[8], - require_approval = bool(row[9]), - is_active = bool(row[10]), - created_at = row[11], - updated_at = row[12], + policy = DataAccessPolicy( + id=row[0], + project_id=row[1], + name=row[2], + description=row[3], + allowed_users=row[4], + allowed_roles=row[5], + allowed_ips=row[6], + time_restrictions=row[7], + max_access_count=row[8], + require_approval=bool(row[9]), + is_active=bool(row[10]), + created_at=row[11], + updated_at=row[12], ) # 检查用户白名单 if policy.allowed_users: - allowed = json.loads(policy.allowed_users) + allowed = json.loads(policy.allowed_users) if user_id not in allowed: return False, "User not in allowed list" # 检查IP白名单 if policy.allowed_ips and user_ip: - allowed_ips = json.loads(policy.allowed_ips) - ip_allowed = False + allowed_ips = json.loads(policy.allowed_ips) + ip_allowed = False for ip_pattern in allowed_ips: if self._match_ip_pattern(user_ip, ip_pattern): - ip_allowed = True + ip_allowed = True break if not ip_allowed: return False, "IP not in allowed list" # 检查时间限制 if policy.time_restrictions: - restrictions = json.loads(policy.time_restrictions) - now = datetime.now() + restrictions = json.loads(policy.time_restrictions) + now = datetime.now() if "start_time" in restrictions and "end_time" in restrictions: - current_time = now.strftime("%H:%M") + current_time = now.strftime("%H:%M") if not (restrictions["start_time"] <= current_time <= restrictions["end_time"]): return False, "Access not allowed at this time" @@ -1086,8 +1086,8 @@ class SecurityManager: # 检查是否需要审批 if policy.require_approval: # 检查是否有有效的访问请求 - conn = sqlite3.connect(self.db_path) - cursor = conn.cursor() + conn = sqlite3.connect(self.db_path) + cursor = conn.cursor() cursor.execute( """ @@ -1098,7 +1098,7 @@ class SecurityManager: (policy_id, user_id, datetime.now().isoformat()), ) - request = cursor.fetchone() + request = cursor.fetchone() conn.close() if not request: @@ -1113,7 +1113,7 @@ class SecurityManager: try: if "/" in pattern: # CIDR 表示法 - network = ipaddress.ip_network(pattern, strict = False) + network = ipaddress.ip_network(pattern, strict=False) return ipaddress.ip_address(ip) in network else: # 精确匹配 @@ -1125,20 +1125,20 @@ class SecurityManager: self, policy_id: str, user_id: str, - request_reason: str | None = None, - expires_hours: int = 24, + request_reason: str | None = None, + expires_hours: int = 24, ) -> AccessRequest: """创建访问请求""" - request = AccessRequest( - id = self._generate_id(), - policy_id = policy_id, - user_id = user_id, - request_reason = request_reason, - expires_at = (datetime.now() + timedelta(hours = expires_hours)).isoformat(), + request = AccessRequest( + id=self._generate_id(), + policy_id=policy_id, + user_id=user_id, + request_reason=request_reason, + expires_at=(datetime.now() + timedelta(hours=expires_hours)).isoformat(), ) - conn = sqlite3.connect(self.db_path) - cursor = conn.cursor() + conn = sqlite3.connect(self.db_path) + cursor = conn.cursor() cursor.execute( """ @@ -1163,14 +1163,14 @@ class SecurityManager: return request def approve_access_request( - self, request_id: str, approved_by: str, expires_hours: int = 24 + self, request_id: str, approved_by: str, expires_hours: int = 24 ) -> AccessRequest | None: """批准访问请求""" - conn = sqlite3.connect(self.db_path) - cursor = conn.cursor() + conn = sqlite3.connect(self.db_path) + cursor = conn.cursor() - expires_at = (datetime.now() + timedelta(hours = expires_hours)).isoformat() - approved_at = datetime.now().isoformat() + expires_at = (datetime.now() + timedelta(hours=expires_hours)).isoformat() + approved_at = datetime.now().isoformat() cursor.execute( """ @@ -1185,28 +1185,28 @@ class SecurityManager: # 获取更新后的请求 cursor.execute("SELECT * FROM access_requests WHERE id = ?", (request_id, )) - row = cursor.fetchone() + row = cursor.fetchone() conn.close() if not row: return None return AccessRequest( - id = row[0], - policy_id = row[1], - user_id = row[2], - request_reason = row[3], - status = row[4], - approved_by = row[5], - approved_at = row[6], - expires_at = row[7], - created_at = row[8], + id=row[0], + policy_id=row[1], + user_id=row[2], + request_reason=row[3], + status=row[4], + approved_by=row[5], + approved_at=row[6], + expires_at=row[7], + created_at=row[8], ) def reject_access_request(self, request_id: str, rejected_by: str) -> AccessRequest | None: """拒绝访问请求""" - conn = sqlite3.connect(self.db_path) - cursor = conn.cursor() + conn = sqlite3.connect(self.db_path) + cursor = conn.cursor() cursor.execute( """ @@ -1220,32 +1220,32 @@ class SecurityManager: conn.commit() cursor.execute("SELECT * FROM access_requests WHERE id = ?", (request_id, )) - row = cursor.fetchone() + row = cursor.fetchone() conn.close() if not row: return None return AccessRequest( - id = row[0], - policy_id = row[1], - user_id = row[2], - request_reason = row[3], - status = row[4], - approved_by = row[5], - approved_at = row[6], - expires_at = row[7], - created_at = row[8], + id=row[0], + policy_id=row[1], + user_id=row[2], + request_reason=row[3], + status=row[4], + approved_by=row[5], + approved_at=row[6], + expires_at=row[7], + created_at=row[8], ) # 全局安全管理器实例 -_security_manager = None +_security_manager = None -def get_security_manager(db_path: str = "insightflow.db") -> SecurityManager: +def get_security_manager(db_path: str = "insightflow.db") -> SecurityManager: """获取安全管理器实例""" global _security_manager if _security_manager is None: - _security_manager = SecurityManager(db_path) + _security_manager = SecurityManager(db_path) return _security_manager diff --git a/backend/subscription_manager.py b/backend/subscription_manager.py index 87f0f89..29f074f 100644 --- a/backend/subscription_manager.py +++ b/backend/subscription_manager.py @@ -19,59 +19,59 @@ from datetime import datetime, timedelta from enum import StrEnum from typing import Any -logger = logging.getLogger(__name__) +logger = logging.getLogger(__name__) class SubscriptionStatus(StrEnum): """订阅状态""" - ACTIVE = "active" # 活跃 - CANCELLED = "cancelled" # 已取消 - EXPIRED = "expired" # 已过期 - PAST_DUE = "past_due" # 逾期 - TRIAL = "trial" # 试用中 - PENDING = "pending" # 待支付 + ACTIVE = "active" # 活跃 + CANCELLED = "cancelled" # 已取消 + EXPIRED = "expired" # 已过期 + PAST_DUE = "past_due" # 逾期 + TRIAL = "trial" # 试用中 + PENDING = "pending" # 待支付 class PaymentProvider(StrEnum): """支付提供商""" - STRIPE = "stripe" # Stripe - ALIPAY = "alipay" # 支付宝 - WECHAT = "wechat" # 微信支付 - BANK_TRANSFER = "bank_transfer" # 银行转账 + STRIPE = "stripe" # Stripe + ALIPAY = "alipay" # 支付宝 + WECHAT = "wechat" # 微信支付 + BANK_TRANSFER = "bank_transfer" # 银行转账 class PaymentStatus(StrEnum): """支付状态""" - PENDING = "pending" # 待支付 - PROCESSING = "processing" # 处理中 - COMPLETED = "completed" # 已完成 - FAILED = "failed" # 失败 - REFUNDED = "refunded" # 已退款 - PARTIAL_REFUNDED = "partial_refunded" # 部分退款 + PENDING = "pending" # 待支付 + PROCESSING = "processing" # 处理中 + COMPLETED = "completed" # 已完成 + FAILED = "failed" # 失败 + REFUNDED = "refunded" # 已退款 + PARTIAL_REFUNDED = "partial_refunded" # 部分退款 class InvoiceStatus(StrEnum): """发票状态""" - DRAFT = "draft" # 草稿 - ISSUED = "issued" # 已开具 - PAID = "paid" # 已支付 - OVERDUE = "overdue" # 逾期 - VOID = "void" # 作废 - CREDIT_NOTE = "credit_note" # 贷项通知单 + DRAFT = "draft" # 草稿 + ISSUED = "issued" # 已开具 + PAID = "paid" # 已支付 + OVERDUE = "overdue" # 逾期 + VOID = "void" # 作废 + CREDIT_NOTE = "credit_note" # 贷项通知单 class RefundStatus(StrEnum): """退款状态""" - PENDING = "pending" # 待处理 - APPROVED = "approved" # 已批准 - REJECTED = "rejected" # 已拒绝 - COMPLETED = "completed" # 已完成 - FAILED = "failed" # 失败 + PENDING = "pending" # 待处理 + APPROVED = "approved" # 已批准 + REJECTED = "rejected" # 已拒绝 + COMPLETED = "completed" # 已完成 + FAILED = "failed" # 失败 @dataclass @@ -218,7 +218,7 @@ class SubscriptionManager: """订阅与计费管理器""" # 默认订阅计划配置 - DEFAULT_PLANS = { + DEFAULT_PLANS = { "free": { "name": "Free", "tier": "free", @@ -298,7 +298,7 @@ class SubscriptionManager: } # 按量计费单价(CNY) - USAGE_PRICING = { + USAGE_PRICING = { "transcription": { "unit": "minute", "price": 0.5, @@ -313,22 +313,22 @@ class SubscriptionManager: "export": {"unit": "page", "price": 0.1, "free_quota": 100}, # 0.1元/页(PDF导出) } - def __init__(self, db_path: str = "insightflow.db") -> None: - self.db_path = db_path + def __init__(self, db_path: str = "insightflow.db") -> None: + self.db_path = db_path self._init_db() self._init_default_plans() def _get_connection(self) -> sqlite3.Connection: """获取数据库连接""" - conn = sqlite3.connect(self.db_path) - conn.row_factory = sqlite3.Row + conn = sqlite3.connect(self.db_path) + conn.row_factory = sqlite3.Row return conn def _init_db(self) -> None: """初始化数据库表""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() # 订阅计划表 cursor.execute(""" @@ -528,9 +528,9 @@ class SubscriptionManager: def _init_default_plans(self) -> None: """初始化默认订阅计划""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() for tier, plan_data in self.DEFAULT_PLANS.items(): cursor.execute( @@ -569,11 +569,11 @@ class SubscriptionManager: def get_plan(self, plan_id: str) -> SubscriptionPlan | None: """获取订阅计划""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute("SELECT * FROM subscription_plans WHERE id = ?", (plan_id, )) - row = cursor.fetchone() + row = cursor.fetchone() if row: return self._row_to_plan(row) @@ -584,13 +584,13 @@ class SubscriptionManager: def get_plan_by_tier(self, tier: str) -> SubscriptionPlan | None: """通过层级获取订阅计划""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( "SELECT * FROM subscription_plans WHERE tier = ? AND is_active = 1", (tier, ) ) - row = cursor.fetchone() + row = cursor.fetchone() if row: return self._row_to_plan(row) @@ -599,11 +599,11 @@ class SubscriptionManager: finally: conn.close() - def list_plans(self, include_inactive: bool = False) -> list[SubscriptionPlan]: + def list_plans(self, include_inactive: bool = False) -> list[SubscriptionPlan]: """列出所有订阅计划""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() if include_inactive: cursor.execute("SELECT * FROM subscription_plans ORDER BY price_monthly") @@ -612,7 +612,7 @@ class SubscriptionManager: "SELECT * FROM subscription_plans WHERE is_active = 1 ORDER BY price_monthly" ) - rows = cursor.fetchall() + rows = cursor.fetchall() return [self._row_to_plan(row) for row in rows] finally: @@ -625,32 +625,32 @@ class SubscriptionManager: description: str, price_monthly: float, price_yearly: float, - currency: str = "CNY", - features: list[str] = None, - limits: dict[str, Any] = None, + currency: str = "CNY", + features: list[str] = None, + limits: dict[str, Any] = None, ) -> SubscriptionPlan: """创建新订阅计划""" - conn = self._get_connection() + conn = self._get_connection() try: - plan_id = str(uuid.uuid4()) + plan_id = str(uuid.uuid4()) - plan = SubscriptionPlan( - id = plan_id, - name = name, - tier = tier, - description = description, - price_monthly = price_monthly, - price_yearly = price_yearly, - currency = currency, - features = features or [], - limits = limits or {}, - is_active = True, - created_at = datetime.now(), - updated_at = datetime.now(), - metadata = {}, + plan = SubscriptionPlan( + id=plan_id, + name=name, + tier=tier, + description=description, + price_monthly=price_monthly, + price_yearly=price_yearly, + currency=currency, + features=features or [], + limits=limits or {}, + is_active=True, + created_at=datetime.now(), + updated_at=datetime.now(), + metadata={}, ) - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ INSERT INTO subscription_plans @@ -688,16 +688,16 @@ class SubscriptionManager: def update_plan(self, plan_id: str, **kwargs) -> SubscriptionPlan | None: """更新订阅计划""" - conn = self._get_connection() + conn = self._get_connection() try: - plan = self.get_plan(plan_id) + plan = self.get_plan(plan_id) if not plan: return None - updates = [] - params = [] + updates = [] + params = [] - allowed_fields = [ + allowed_fields = [ "name", "description", "price_monthly", @@ -725,7 +725,7 @@ class SubscriptionManager: params.append(datetime.now()) params.append(plan_id) - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( f""" UPDATE subscription_plans SET {", ".join(updates)} @@ -746,15 +746,15 @@ class SubscriptionManager: self, tenant_id: str, plan_id: str, - payment_provider: str | None = None, - trial_days: int = 0, - billing_cycle: str = "monthly", + payment_provider: str | None = None, + trial_days: int = 0, + billing_cycle: str = "monthly", ) -> Subscription: """创建新订阅""" - conn = self._get_connection() + conn = self._get_connection() try: # 检查是否已有活跃订阅 - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ SELECT * FROM subscriptions @@ -763,50 +763,50 @@ class SubscriptionManager: (tenant_id, ), ) - existing = cursor.fetchone() + existing = cursor.fetchone() if existing: raise ValueError(f"Tenant {tenant_id} already has an active subscription") # 获取计划信息 - plan = self.get_plan(plan_id) + plan = self.get_plan(plan_id) if not plan: raise ValueError(f"Plan {plan_id} not found") - subscription_id = str(uuid.uuid4()) - now = datetime.now() + subscription_id = str(uuid.uuid4()) + now = datetime.now() # 计算周期 if billing_cycle == "yearly": - period_end = now + timedelta(days = 365) + period_end = now + timedelta(days=365) else: - period_end = now + timedelta(days = 30) + period_end = now + timedelta(days=30) # 试用处理 - trial_start = None - trial_end = None + trial_start = None + trial_end = None if trial_days > 0: - trial_start = now - trial_end = now + timedelta(days = trial_days) - status = SubscriptionStatus.TRIAL.value + trial_start = now + trial_end = now + timedelta(days=trial_days) + status = SubscriptionStatus.TRIAL.value else: - status = SubscriptionStatus.PENDING.value + status = SubscriptionStatus.PENDING.value - subscription = Subscription( - id = subscription_id, - tenant_id = tenant_id, - plan_id = plan_id, - status = status, - current_period_start = now, - current_period_end = period_end, - cancel_at_period_end = False, - canceled_at = None, - trial_start = trial_start, - trial_end = trial_end, - payment_provider = payment_provider, - provider_subscription_id = None, - created_at = now, - updated_at = now, - metadata = {"billing_cycle": billing_cycle}, + subscription = Subscription( + id=subscription_id, + tenant_id=tenant_id, + plan_id=plan_id, + status=status, + current_period_start=now, + current_period_end=period_end, + cancel_at_period_end=False, + canceled_at=None, + trial_start=trial_start, + trial_end=trial_end, + payment_provider=payment_provider, + provider_subscription_id=None, + created_at=now, + updated_at=now, + metadata={"billing_cycle": billing_cycle}, ) cursor.execute( @@ -837,7 +837,7 @@ class SubscriptionManager: ) # 创建发票 - amount = plan.price_yearly if billing_cycle == "yearly" else plan.price_monthly + amount = plan.price_yearly if billing_cycle == "yearly" else plan.price_monthly if amount > 0 and trial_days == 0: self._create_invoice_internal( conn, @@ -875,11 +875,11 @@ class SubscriptionManager: def get_subscription(self, subscription_id: str) -> Subscription | None: """获取订阅信息""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute("SELECT * FROM subscriptions WHERE id = ?", (subscription_id, )) - row = cursor.fetchone() + row = cursor.fetchone() if row: return self._row_to_subscription(row) @@ -890,9 +890,9 @@ class SubscriptionManager: def get_tenant_subscription(self, tenant_id: str) -> Subscription | None: """获取租户的当前订阅""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ SELECT * FROM subscriptions @@ -901,7 +901,7 @@ class SubscriptionManager: """, (tenant_id, ), ) - row = cursor.fetchone() + row = cursor.fetchone() if row: return self._row_to_subscription(row) @@ -912,16 +912,16 @@ class SubscriptionManager: def update_subscription(self, subscription_id: str, **kwargs) -> Subscription | None: """更新订阅""" - conn = self._get_connection() + conn = self._get_connection() try: - subscription = self.get_subscription(subscription_id) + subscription = self.get_subscription(subscription_id) if not subscription: return None - updates = [] - params = [] + updates = [] + params = [] - allowed_fields = [ + allowed_fields = [ "status", "current_period_start", "current_period_end", @@ -947,7 +947,7 @@ class SubscriptionManager: params.append(datetime.now()) params.append(subscription_id) - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( f""" UPDATE subscriptions SET {", ".join(updates)} @@ -963,20 +963,20 @@ class SubscriptionManager: conn.close() def cancel_subscription( - self, subscription_id: str, at_period_end: bool = True + self, subscription_id: str, at_period_end: bool = True ) -> Subscription | None: """取消订阅""" - conn = self._get_connection() + conn = self._get_connection() try: - subscription = self.get_subscription(subscription_id) + subscription = self.get_subscription(subscription_id) if not subscription: return None - now = datetime.now() + now = datetime.now() if at_period_end: # 在周期结束时取消 - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ UPDATE subscriptions @@ -987,7 +987,7 @@ class SubscriptionManager: ) else: # 立即取消 - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ UPDATE subscriptions @@ -1017,29 +1017,29 @@ class SubscriptionManager: conn.close() def change_plan( - self, subscription_id: str, new_plan_id: str, prorate: bool = True + self, subscription_id: str, new_plan_id: str, prorate: bool = True ) -> Subscription | None: """更改订阅计划""" - conn = self._get_connection() + conn = self._get_connection() try: - subscription = self.get_subscription(subscription_id) + subscription = self.get_subscription(subscription_id) if not subscription: return None - old_plan = self.get_plan(subscription.plan_id) - new_plan = self.get_plan(new_plan_id) + old_plan = self.get_plan(subscription.plan_id) + new_plan = self.get_plan(new_plan_id) if not new_plan: raise ValueError(f"Plan {new_plan_id} not found") - now = datetime.now() + now = datetime.now() # 按比例计算差价(简化实现) if prorate and old_plan: # 这里应该实现实际的按比例计算逻辑 pass - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ UPDATE subscriptions @@ -1076,29 +1076,29 @@ class SubscriptionManager: resource_type: str, quantity: float, unit: str, - description: str | None = None, - metadata: dict | None = None, + description: str | None = None, + metadata: dict | None = None, ) -> UsageRecord: """记录用量""" - conn = self._get_connection() + conn = self._get_connection() try: # 计算费用 - cost = self._calculate_usage_cost(resource_type, quantity) + cost = self._calculate_usage_cost(resource_type, quantity) - record_id = str(uuid.uuid4()) - record = UsageRecord( - id = record_id, - tenant_id = tenant_id, - resource_type = resource_type, - quantity = quantity, - unit = unit, - recorded_at = datetime.now(), - cost = cost, - description = description, - metadata = metadata or {}, + record_id = str(uuid.uuid4()) + record = UsageRecord( + id=record_id, + tenant_id=tenant_id, + resource_type=resource_type, + quantity=quantity, + unit=unit, + recorded_at=datetime.now(), + cost=cost, + description=description, + metadata=metadata or {}, ) - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ INSERT INTO usage_records @@ -1125,14 +1125,14 @@ class SubscriptionManager: conn.close() def get_usage_summary( - self, tenant_id: str, start_date: datetime | None = None, end_date: datetime | None = None + self, tenant_id: str, start_date: datetime | None = None, end_date: datetime | None = None ) -> dict[str, Any]: """获取用量汇总""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() - query = """ + query = """ SELECT resource_type, SUM(quantity) as total_quantity, @@ -1141,7 +1141,7 @@ class SubscriptionManager: FROM usage_records WHERE tenant_id = ? """ - params = [tenant_id] + params = [tenant_id] if start_date: query += " AND recorded_at >= ?" @@ -1153,13 +1153,13 @@ class SubscriptionManager: query += " GROUP BY resource_type" cursor.execute(query, params) - rows = cursor.fetchall() + rows = cursor.fetchall() - summary = {} - total_cost = 0 + summary = {} + total_cost = 0 for row in rows: - summary[row["resource_type"]] = { + summary[row["resource_type"]] = { "quantity": row["total_quantity"], "cost": row["total_cost"], "records": row["record_count"], @@ -1181,12 +1181,12 @@ class SubscriptionManager: def _calculate_usage_cost(self, resource_type: str, quantity: float) -> float: """计算用量费用""" - pricing = self.USAGE_PRICING.get(resource_type) + pricing = self.USAGE_PRICING.get(resource_type) if not pricing: return 0.0 # 扣除免费额度 - chargeable = max(0, quantity - pricing.get("free_quota", 0)) + chargeable = max(0, quantity - pricing.get("free_quota", 0)) # 计算费用 if pricing["unit"] == "1000_calls": @@ -1202,37 +1202,37 @@ class SubscriptionManager: amount: float, currency: str, provider: str, - subscription_id: str | None = None, - invoice_id: str | None = None, - payment_method: str | None = None, - payment_details: dict | None = None, + subscription_id: str | None = None, + invoice_id: str | None = None, + payment_method: str | None = None, + payment_details: dict | None = None, ) -> Payment: """创建支付记录""" - conn = self._get_connection() + conn = self._get_connection() try: - payment_id = str(uuid.uuid4()) - now = datetime.now() + payment_id = str(uuid.uuid4()) + now = datetime.now() - payment = Payment( - id = payment_id, - tenant_id = tenant_id, - subscription_id = subscription_id, - invoice_id = invoice_id, - amount = amount, - currency = currency, - provider = provider, - provider_payment_id = None, - status = PaymentStatus.PENDING.value, - payment_method = payment_method, - payment_details = payment_details or {}, - paid_at = None, - failed_at = None, - failure_reason = None, - created_at = now, - updated_at = now, + payment = Payment( + id=payment_id, + tenant_id=tenant_id, + subscription_id=subscription_id, + invoice_id=invoice_id, + amount=amount, + currency=currency, + provider=provider, + provider_payment_id=None, + status=PaymentStatus.PENDING.value, + payment_method=payment_method, + payment_details=payment_details or {}, + paid_at=None, + failed_at=None, + failure_reason=None, + created_at=now, + updated_at=now, ) - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ INSERT INTO payments @@ -1268,18 +1268,18 @@ class SubscriptionManager: conn.close() def confirm_payment( - self, payment_id: str, provider_payment_id: str | None = None + self, payment_id: str, provider_payment_id: str | None = None ) -> Payment | None: """确认支付完成""" - conn = self._get_connection() + conn = self._get_connection() try: - payment = self._get_payment_internal(conn, payment_id) + payment = self._get_payment_internal(conn, payment_id) if not payment: return None - now = datetime.now() + now = datetime.now() - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ UPDATE payments @@ -1332,11 +1332,11 @@ class SubscriptionManager: def fail_payment(self, payment_id: str, reason: str) -> Payment | None: """标记支付失败""" - conn = self._get_connection() + conn = self._get_connection() try: - now = datetime.now() + now = datetime.now() - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ UPDATE payments @@ -1354,22 +1354,22 @@ class SubscriptionManager: def get_payment(self, payment_id: str) -> Payment | None: """获取支付记录""" - conn = self._get_connection() + conn = self._get_connection() try: return self._get_payment_internal(conn, payment_id) finally: conn.close() def list_payments( - self, tenant_id: str, status: str | None = None, limit: int = 100, offset: int = 0 + self, tenant_id: str, status: str | None = None, limit: int = 100, offset: int = 0 ) -> list[Payment]: """列出支付记录""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() - query = "SELECT * FROM payments WHERE tenant_id = ?" - params = [tenant_id] + query = "SELECT * FROM payments WHERE tenant_id = ?" + params = [tenant_id] if status: query += " AND status = ?" @@ -1379,7 +1379,7 @@ class SubscriptionManager: params.extend([limit, offset]) cursor.execute(query, params) - rows = cursor.fetchall() + rows = cursor.fetchall() return [self._row_to_payment(row) for row in rows] @@ -1388,9 +1388,9 @@ class SubscriptionManager: def _get_payment_internal(self, conn: sqlite3.Connection, payment_id: str) -> Payment | None: """内部方法:获取支付记录""" - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute("SELECT * FROM payments WHERE id = ?", (payment_id, )) - row = cursor.fetchone() + row = cursor.fetchone() if row: return self._row_to_payment(row) @@ -1408,36 +1408,36 @@ class SubscriptionManager: period_start: datetime, period_end: datetime, description: str, - line_items: list[dict] | None = None, + line_items: list[dict] | None = None, ) -> Invoice: """内部方法:创建发票""" - invoice_id = str(uuid.uuid4()) - invoice_number = self._generate_invoice_number() - now = datetime.now() - due_date = now + timedelta(days = 7) # 7天付款期限 + invoice_id = str(uuid.uuid4()) + invoice_number = self._generate_invoice_number() + now = datetime.now() + due_date = now + timedelta(days=7) # 7天付款期限 - invoice = Invoice( - id = invoice_id, - tenant_id = tenant_id, - subscription_id = subscription_id, - invoice_number = invoice_number, - status = InvoiceStatus.DRAFT.value, - amount_due = amount, - amount_paid = 0, - currency = currency, - period_start = period_start, - period_end = period_end, - description = description, - line_items = line_items or [{"description": description, "amount": amount}], - due_date = due_date, - paid_at = None, - voided_at = None, - void_reason = None, - created_at = now, - updated_at = now, + invoice = Invoice( + id=invoice_id, + tenant_id=tenant_id, + subscription_id=subscription_id, + invoice_number=invoice_number, + status=InvoiceStatus.DRAFT.value, + amount_due=amount, + amount_paid=0, + currency=currency, + period_start=period_start, + period_end=period_end, + description=description, + line_items=line_items or [{"description": description, "amount": amount}], + due_date=due_date, + paid_at=None, + voided_at=None, + void_reason=None, + created_at=now, + updated_at=now, ) - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ INSERT INTO invoices @@ -1472,11 +1472,11 @@ class SubscriptionManager: def get_invoice(self, invoice_id: str) -> Invoice | None: """获取发票""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute("SELECT * FROM invoices WHERE id = ?", (invoice_id, )) - row = cursor.fetchone() + row = cursor.fetchone() if row: return self._row_to_invoice(row) @@ -1487,11 +1487,11 @@ class SubscriptionManager: def get_invoice_by_number(self, invoice_number: str) -> Invoice | None: """通过发票号获取发票""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute("SELECT * FROM invoices WHERE invoice_number = ?", (invoice_number, )) - row = cursor.fetchone() + row = cursor.fetchone() if row: return self._row_to_invoice(row) @@ -1501,15 +1501,15 @@ class SubscriptionManager: conn.close() def list_invoices( - self, tenant_id: str, status: str | None = None, limit: int = 100, offset: int = 0 + self, tenant_id: str, status: str | None = None, limit: int = 100, offset: int = 0 ) -> list[Invoice]: """列出发票""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() - query = "SELECT * FROM invoices WHERE tenant_id = ?" - params = [tenant_id] + query = "SELECT * FROM invoices WHERE tenant_id = ?" + params = [tenant_id] if status: query += " AND status = ?" @@ -1519,7 +1519,7 @@ class SubscriptionManager: params.extend([limit, offset]) cursor.execute(query, params) - rows = cursor.fetchall() + rows = cursor.fetchall() return [self._row_to_invoice(row) for row in rows] @@ -1528,18 +1528,18 @@ class SubscriptionManager: def void_invoice(self, invoice_id: str, reason: str) -> Invoice | None: """作废发票""" - conn = self._get_connection() + conn = self._get_connection() try: - invoice = self.get_invoice(invoice_id) + invoice = self.get_invoice(invoice_id) if not invoice: return None if invoice.status == InvoiceStatus.PAID.value: raise ValueError("Cannot void a paid invoice") - now = datetime.now() + now = datetime.now() - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ UPDATE invoices @@ -1557,12 +1557,12 @@ class SubscriptionManager: def _generate_invoice_number(self) -> str: """生成发票号""" - now = datetime.now() - prefix = f"INV-{now.strftime('%Y%m')}" + now = datetime.now() + prefix = f"INV-{now.strftime('%Y%m')}" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ SELECT COUNT(*) as count FROM invoices @@ -1570,8 +1570,8 @@ class SubscriptionManager: """, (f"{prefix}%", ), ) - row = cursor.fetchone() - count = row["count"] + 1 + row = cursor.fetchone() + count = row["count"] + 1 return f"{prefix}-{count:06d}" @@ -1584,10 +1584,10 @@ class SubscriptionManager: self, tenant_id: str, payment_id: str, amount: float, reason: str, requested_by: str ) -> Refund: """申请退款""" - conn = self._get_connection() + conn = self._get_connection() try: # 验证支付记录 - payment = self._get_payment_internal(conn, payment_id) + payment = self._get_payment_internal(conn, payment_id) if not payment: raise ValueError(f"Payment {payment_id} not found") @@ -1600,30 +1600,30 @@ class SubscriptionManager: if amount > payment.amount: raise ValueError("Refund amount cannot exceed payment amount") - refund_id = str(uuid.uuid4()) - now = datetime.now() + refund_id = str(uuid.uuid4()) + now = datetime.now() - refund = Refund( - id = refund_id, - tenant_id = tenant_id, - payment_id = payment_id, - invoice_id = payment.invoice_id, - amount = amount, - currency = payment.currency, - reason = reason, - status = RefundStatus.PENDING.value, - requested_by = requested_by, - requested_at = now, - approved_by = None, - approved_at = None, - completed_at = None, - provider_refund_id = None, - metadata = {}, - created_at = now, - updated_at = now, + refund = Refund( + id=refund_id, + tenant_id=tenant_id, + payment_id=payment_id, + invoice_id=payment.invoice_id, + amount=amount, + currency=payment.currency, + reason=reason, + status=RefundStatus.PENDING.value, + requested_by=requested_by, + requested_at=now, + approved_by=None, + approved_at=None, + completed_at=None, + provider_refund_id=None, + metadata={}, + created_at=now, + updated_at=now, ) - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ INSERT INTO refunds @@ -1662,18 +1662,18 @@ class SubscriptionManager: def approve_refund(self, refund_id: str, approved_by: str) -> Refund | None: """批准退款""" - conn = self._get_connection() + conn = self._get_connection() try: - refund = self._get_refund_internal(conn, refund_id) + refund = self._get_refund_internal(conn, refund_id) if not refund: return None if refund.status != RefundStatus.PENDING.value: raise ValueError("Can only approve pending refunds") - now = datetime.now() + now = datetime.now() - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ UPDATE refunds @@ -1690,18 +1690,18 @@ class SubscriptionManager: conn.close() def complete_refund( - self, refund_id: str, provider_refund_id: str | None = None + self, refund_id: str, provider_refund_id: str | None = None ) -> Refund | None: """完成退款""" - conn = self._get_connection() + conn = self._get_connection() try: - refund = self._get_refund_internal(conn, refund_id) + refund = self._get_refund_internal(conn, refund_id) if not refund: return None - now = datetime.now() + now = datetime.now() - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ UPDATE refunds @@ -1742,15 +1742,15 @@ class SubscriptionManager: def reject_refund(self, refund_id: str, reason: str) -> Refund | None: """拒绝退款""" - conn = self._get_connection() + conn = self._get_connection() try: - refund = self._get_refund_internal(conn, refund_id) + refund = self._get_refund_internal(conn, refund_id) if not refund: return None - now = datetime.now() + now = datetime.now() - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ UPDATE refunds @@ -1768,22 +1768,22 @@ class SubscriptionManager: def get_refund(self, refund_id: str) -> Refund | None: """获取退款记录""" - conn = self._get_connection() + conn = self._get_connection() try: return self._get_refund_internal(conn, refund_id) finally: conn.close() def list_refunds( - self, tenant_id: str, status: str | None = None, limit: int = 100, offset: int = 0 + self, tenant_id: str, status: str | None = None, limit: int = 100, offset: int = 0 ) -> list[Refund]: """列出退款记录""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() - query = "SELECT * FROM refunds WHERE tenant_id = ?" - params = [tenant_id] + query = "SELECT * FROM refunds WHERE tenant_id = ?" + params = [tenant_id] if status: query += " AND status = ?" @@ -1793,7 +1793,7 @@ class SubscriptionManager: params.extend([limit, offset]) cursor.execute(query, params) - rows = cursor.fetchall() + rows = cursor.fetchall() return [self._row_to_refund(row) for row in rows] @@ -1802,9 +1802,9 @@ class SubscriptionManager: def _get_refund_internal(self, conn: sqlite3.Connection, refund_id: str) -> Refund | None: """内部方法:获取退款记录""" - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute("SELECT * FROM refunds WHERE id = ?", (refund_id, )) - row = cursor.fetchone() + row = cursor.fetchone() if row: return self._row_to_refund(row) @@ -1824,9 +1824,9 @@ class SubscriptionManager: balance_after: float, ) -> None: """内部方法:添加账单历史""" - history_id = str(uuid.uuid4()) + history_id = str(uuid.uuid4()) - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ INSERT INTO billing_history @@ -1850,18 +1850,18 @@ class SubscriptionManager: def get_billing_history( self, tenant_id: str, - start_date: datetime | None = None, - end_date: datetime | None = None, - limit: int = 100, - offset: int = 0, + start_date: datetime | None = None, + end_date: datetime | None = None, + limit: int = 100, + offset: int = 0, ) -> list[BillingHistory]: """获取账单历史""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() - query = "SELECT * FROM billing_history WHERE tenant_id = ?" - params = [tenant_id] + query = "SELECT * FROM billing_history WHERE tenant_id = ?" + params = [tenant_id] if start_date: query += " AND created_at >= ?" @@ -1874,7 +1874,7 @@ class SubscriptionManager: params.extend([limit, offset]) cursor.execute(query, params) - rows = cursor.fetchall() + rows = cursor.fetchall() return [self._row_to_billing_history(row) for row in rows] @@ -1889,7 +1889,7 @@ class SubscriptionManager: plan_id: str, success_url: str, cancel_url: str, - billing_cycle: str = "monthly", + billing_cycle: str = "monthly", ) -> dict[str, Any]: """创建 Stripe Checkout 会话(占位实现)""" # 这里应该集成 Stripe SDK @@ -1902,12 +1902,12 @@ class SubscriptionManager: } def create_alipay_order( - self, tenant_id: str, plan_id: str, billing_cycle: str = "monthly" + self, tenant_id: str, plan_id: str, billing_cycle: str = "monthly" ) -> dict[str, Any]: """创建支付宝订单(占位实现)""" # 这里应该集成支付宝 SDK - plan = self.get_plan(plan_id) - amount = plan.price_yearly if billing_cycle == "yearly" else plan.price_monthly + plan = self.get_plan(plan_id) + amount = plan.price_yearly if billing_cycle == "yearly" else plan.price_monthly return { "order_id": f"ALI{datetime.now().strftime('%Y%m%d%H%M%S')}{uuid.uuid4().hex[:8].upper()}", @@ -1919,12 +1919,12 @@ class SubscriptionManager: } def create_wechat_order( - self, tenant_id: str, plan_id: str, billing_cycle: str = "monthly" + self, tenant_id: str, plan_id: str, billing_cycle: str = "monthly" ) -> dict[str, Any]: """创建微信支付订单(占位实现)""" # 这里应该集成微信支付 SDK - plan = self.get_plan(plan_id) - amount = plan.price_yearly if billing_cycle == "yearly" else plan.price_monthly + plan = self.get_plan(plan_id) + amount = plan.price_yearly if billing_cycle == "yearly" else plan.price_monthly return { "order_id": f"WX{datetime.now().strftime('%Y%m%d%H%M%S')}{uuid.uuid4().hex[:8].upper()}", @@ -1940,7 +1940,7 @@ class SubscriptionManager: # 这里应该实现实际的 Webhook 处理逻辑 logger.info(f"Received webhook from {provider}: {payload.get('event_type', 'unknown')}") - event_type = payload.get("event_type", "") + event_type = payload.get("event_type", "") if provider == "stripe": if event_type == "checkout.session.completed": @@ -1962,126 +1962,126 @@ class SubscriptionManager: def _row_to_plan(self, row: sqlite3.Row) -> SubscriptionPlan: """数据库行转换为 SubscriptionPlan 对象""" return SubscriptionPlan( - id = row["id"], - name = row["name"], - tier = row["tier"], - description = row["description"] or "", - price_monthly = row["price_monthly"], - price_yearly = row["price_yearly"], - currency = row["currency"], - features = json.loads(row["features"] or "[]"), - limits = json.loads(row["limits"] or "{}"), - is_active = bool(row["is_active"]), - created_at = ( + id=row["id"], + name=row["name"], + tier=row["tier"], + description=row["description"] or "", + price_monthly=row["price_monthly"], + price_yearly=row["price_yearly"], + currency=row["currency"], + features=json.loads(row["features"] or "[]"), + limits=json.loads(row["limits"] or "{}"), + is_active=bool(row["is_active"]), + created_at=( datetime.fromisoformat(row["created_at"]) if isinstance(row["created_at"], str) else row["created_at"] ), - updated_at = ( + updated_at=( datetime.fromisoformat(row["updated_at"]) if isinstance(row["updated_at"], str) else row["updated_at"] ), - metadata = json.loads(row["metadata"] or "{}"), + metadata=json.loads(row["metadata"] or "{}"), ) def _row_to_subscription(self, row: sqlite3.Row) -> Subscription: """数据库行转换为 Subscription 对象""" return Subscription( - id = row["id"], - tenant_id = row["tenant_id"], - plan_id = row["plan_id"], - status = row["status"], - current_period_start = ( + id=row["id"], + tenant_id=row["tenant_id"], + plan_id=row["plan_id"], + status=row["status"], + current_period_start=( datetime.fromisoformat(row["current_period_start"]) if row["current_period_start"] and isinstance(row["current_period_start"], str) else row["current_period_start"] ), - current_period_end = ( + current_period_end=( datetime.fromisoformat(row["current_period_end"]) if row["current_period_end"] and isinstance(row["current_period_end"], str) else row["current_period_end"] ), - cancel_at_period_end = bool(row["cancel_at_period_end"]), - canceled_at = ( + cancel_at_period_end=bool(row["cancel_at_period_end"]), + canceled_at=( datetime.fromisoformat(row["canceled_at"]) if row["canceled_at"] and isinstance(row["canceled_at"], str) else row["canceled_at"] ), - trial_start = ( + trial_start=( datetime.fromisoformat(row["trial_start"]) if row["trial_start"] and isinstance(row["trial_start"], str) else row["trial_start"] ), - trial_end = ( + trial_end=( datetime.fromisoformat(row["trial_end"]) if row["trial_end"] and isinstance(row["trial_end"], str) else row["trial_end"] ), - payment_provider = row["payment_provider"], - provider_subscription_id = row["provider_subscription_id"], - created_at = ( + payment_provider=row["payment_provider"], + provider_subscription_id=row["provider_subscription_id"], + created_at=( datetime.fromisoformat(row["created_at"]) if isinstance(row["created_at"], str) else row["created_at"] ), - updated_at = ( + updated_at=( datetime.fromisoformat(row["updated_at"]) if isinstance(row["updated_at"], str) else row["updated_at"] ), - metadata = json.loads(row["metadata"] or "{}"), + metadata=json.loads(row["metadata"] or "{}"), ) def _row_to_usage(self, row: sqlite3.Row) -> UsageRecord: """数据库行转换为 UsageRecord 对象""" return UsageRecord( - id = row["id"], - tenant_id = row["tenant_id"], - resource_type = row["resource_type"], - quantity = row["quantity"], - unit = row["unit"], - recorded_at = ( + id=row["id"], + tenant_id=row["tenant_id"], + resource_type=row["resource_type"], + quantity=row["quantity"], + unit=row["unit"], + recorded_at=( datetime.fromisoformat(row["recorded_at"]) if isinstance(row["recorded_at"], str) else row["recorded_at"] ), - cost = row["cost"], - description = row["description"], - metadata = json.loads(row["metadata"] or "{}"), + cost=row["cost"], + description=row["description"], + metadata=json.loads(row["metadata"] or "{}"), ) def _row_to_payment(self, row: sqlite3.Row) -> Payment: """数据库行转换为 Payment 对象""" return Payment( - id = row["id"], - tenant_id = row["tenant_id"], - subscription_id = row["subscription_id"], - invoice_id = row["invoice_id"], - amount = row["amount"], - currency = row["currency"], - provider = row["provider"], - provider_payment_id = row["provider_payment_id"], - status = row["status"], - payment_method = row["payment_method"], - payment_details = json.loads(row["payment_details"] or "{}"), - paid_at = ( + id=row["id"], + tenant_id=row["tenant_id"], + subscription_id=row["subscription_id"], + invoice_id=row["invoice_id"], + amount=row["amount"], + currency=row["currency"], + provider=row["provider"], + provider_payment_id=row["provider_payment_id"], + status=row["status"], + payment_method=row["payment_method"], + payment_details=json.loads(row["payment_details"] or "{}"), + paid_at=( datetime.fromisoformat(row["paid_at"]) if row["paid_at"] and isinstance(row["paid_at"], str) else row["paid_at"] ), - failed_at = ( + failed_at=( datetime.fromisoformat(row["failed_at"]) if row["failed_at"] and isinstance(row["failed_at"], str) else row["failed_at"] ), - failure_reason = row["failure_reason"], - created_at = ( + failure_reason=row["failure_reason"], + created_at=( datetime.fromisoformat(row["created_at"]) if isinstance(row["created_at"], str) else row["created_at"] ), - updated_at = ( + updated_at=( datetime.fromisoformat(row["updated_at"]) if isinstance(row["updated_at"], str) else row["updated_at"] @@ -2091,48 +2091,48 @@ class SubscriptionManager: def _row_to_invoice(self, row: sqlite3.Row) -> Invoice: """数据库行转换为 Invoice 对象""" return Invoice( - id = row["id"], - tenant_id = row["tenant_id"], - subscription_id = row["subscription_id"], - invoice_number = row["invoice_number"], - status = row["status"], - amount_due = row["amount_due"], - amount_paid = row["amount_paid"], - currency = row["currency"], - period_start = ( + id=row["id"], + tenant_id=row["tenant_id"], + subscription_id=row["subscription_id"], + invoice_number=row["invoice_number"], + status=row["status"], + amount_due=row["amount_due"], + amount_paid=row["amount_paid"], + currency=row["currency"], + period_start=( datetime.fromisoformat(row["period_start"]) if row["period_start"] and isinstance(row["period_start"], str) else row["period_start"] ), - period_end = ( + period_end=( datetime.fromisoformat(row["period_end"]) if row["period_end"] and isinstance(row["period_end"], str) else row["period_end"] ), - description = row["description"], - line_items = json.loads(row["line_items"] or "[]"), - due_date = ( + description=row["description"], + line_items=json.loads(row["line_items"] or "[]"), + due_date=( datetime.fromisoformat(row["due_date"]) if row["due_date"] and isinstance(row["due_date"], str) else row["due_date"] ), - paid_at = ( + paid_at=( datetime.fromisoformat(row["paid_at"]) if row["paid_at"] and isinstance(row["paid_at"], str) else row["paid_at"] ), - voided_at = ( + voided_at=( datetime.fromisoformat(row["voided_at"]) if row["voided_at"] and isinstance(row["voided_at"], str) else row["voided_at"] ), - void_reason = row["void_reason"], - created_at = ( + void_reason=row["void_reason"], + created_at=( datetime.fromisoformat(row["created_at"]) if isinstance(row["created_at"], str) else row["created_at"] ), - updated_at = ( + updated_at=( datetime.fromisoformat(row["updated_at"]) if isinstance(row["updated_at"], str) else row["updated_at"] @@ -2142,39 +2142,39 @@ class SubscriptionManager: def _row_to_refund(self, row: sqlite3.Row) -> Refund: """数据库行转换为 Refund 对象""" return Refund( - id = row["id"], - tenant_id = row["tenant_id"], - payment_id = row["payment_id"], - invoice_id = row["invoice_id"], - amount = row["amount"], - currency = row["currency"], - reason = row["reason"], - status = row["status"], - requested_by = row["requested_by"], - requested_at = ( + id=row["id"], + tenant_id=row["tenant_id"], + payment_id=row["payment_id"], + invoice_id=row["invoice_id"], + amount=row["amount"], + currency=row["currency"], + reason=row["reason"], + status=row["status"], + requested_by=row["requested_by"], + requested_at=( datetime.fromisoformat(row["requested_at"]) if isinstance(row["requested_at"], str) else row["requested_at"] ), - approved_by = row["approved_by"], - approved_at = ( + approved_by=row["approved_by"], + approved_at=( datetime.fromisoformat(row["approved_at"]) if row["approved_at"] and isinstance(row["approved_at"], str) else row["approved_at"] ), - completed_at = ( + completed_at=( datetime.fromisoformat(row["completed_at"]) if row["completed_at"] and isinstance(row["completed_at"], str) else row["completed_at"] ), - provider_refund_id = row["provider_refund_id"], - metadata = json.loads(row["metadata"] or "{}"), - created_at = ( + provider_refund_id=row["provider_refund_id"], + metadata=json.loads(row["metadata"] or "{}"), + created_at=( datetime.fromisoformat(row["created_at"]) if isinstance(row["created_at"], str) else row["created_at"] ), - updated_at = ( + updated_at=( datetime.fromisoformat(row["updated_at"]) if isinstance(row["updated_at"], str) else row["updated_at"] @@ -2184,30 +2184,30 @@ class SubscriptionManager: def _row_to_billing_history(self, row: sqlite3.Row) -> BillingHistory: """数据库行转换为 BillingHistory 对象""" return BillingHistory( - id = row["id"], - tenant_id = row["tenant_id"], - type = row["type"], - amount = row["amount"], - currency = row["currency"], - description = row["description"], - reference_id = row["reference_id"], - balance_after = row["balance_after"], - created_at = ( + id=row["id"], + tenant_id=row["tenant_id"], + type=row["type"], + amount=row["amount"], + currency=row["currency"], + description=row["description"], + reference_id=row["reference_id"], + balance_after=row["balance_after"], + created_at=( datetime.fromisoformat(row["created_at"]) if isinstance(row["created_at"], str) else row["created_at"] ), - metadata = json.loads(row["metadata"] or "{}"), + metadata=json.loads(row["metadata"] or "{}"), ) # 全局订阅管理器实例 -subscription_manager = None +subscription_manager = None -def get_subscription_manager(db_path: str = "insightflow.db") -> SubscriptionManager: +def get_subscription_manager(db_path: str = "insightflow.db") -> SubscriptionManager: """获取订阅管理器实例(单例模式)""" global subscription_manager if subscription_manager is None: - subscription_manager = SubscriptionManager(db_path) + subscription_manager = SubscriptionManager(db_path) return subscription_manager diff --git a/backend/tenant_manager.py b/backend/tenant_manager.py index 26c2034..66390b6 100644 --- a/backend/tenant_manager.py +++ b/backend/tenant_manager.py @@ -21,63 +21,63 @@ from datetime import datetime from enum import StrEnum from typing import Any -logger = logging.getLogger(__name__) +logger = logging.getLogger(__name__) class TenantLimits: """租户资源限制常量""" - FREE_MAX_PROJECTS = 3 - FREE_MAX_STORAGE_MB = 100 - FREE_MAX_TRANSCRIPTION_MINUTES = 60 - FREE_MAX_API_CALLS_PER_DAY = 100 - FREE_MAX_TEAM_MEMBERS = 2 - FREE_MAX_ENTITIES = 100 + FREE_MAX_PROJECTS = 3 + FREE_MAX_STORAGE_MB = 100 + FREE_MAX_TRANSCRIPTION_MINUTES = 60 + FREE_MAX_API_CALLS_PER_DAY = 100 + FREE_MAX_TEAM_MEMBERS = 2 + FREE_MAX_ENTITIES = 100 - PRO_MAX_PROJECTS = 20 - PRO_MAX_STORAGE_MB = 1000 - PRO_MAX_TRANSCRIPTION_MINUTES = 600 - PRO_MAX_API_CALLS_PER_DAY = 10000 - PRO_MAX_TEAM_MEMBERS = 10 - PRO_MAX_ENTITIES = 1000 + PRO_MAX_PROJECTS = 20 + PRO_MAX_STORAGE_MB = 1000 + PRO_MAX_TRANSCRIPTION_MINUTES = 600 + PRO_MAX_API_CALLS_PER_DAY = 10000 + PRO_MAX_TEAM_MEMBERS = 10 + PRO_MAX_ENTITIES = 1000 - UNLIMITED = -1 + UNLIMITED = -1 class TenantStatus(StrEnum): """租户状态""" - ACTIVE = "active" # 活跃 - SUSPENDED = "suspended" # 暂停 - TRIAL = "trial" # 试用 - EXPIRED = "expired" # 过期 - PENDING = "pending" # 待激活 + ACTIVE = "active" # 活跃 + SUSPENDED = "suspended" # 暂停 + TRIAL = "trial" # 试用 + EXPIRED = "expired" # 过期 + PENDING = "pending" # 待激活 class TenantTier(StrEnum): """租户订阅层级""" - FREE = "free" # 免费版 - PRO = "pro" # 专业版 - ENTERPRISE = "enterprise" # 企业版 + FREE = "free" # 免费版 + PRO = "pro" # 专业版 + ENTERPRISE = "enterprise" # 企业版 class TenantRole(StrEnum): """租户角色""" - OWNER = "owner" # 所有者 - ADMIN = "admin" # 管理员 - MEMBER = "member" # 成员 - VIEWER = "viewer" # 查看者 + OWNER = "owner" # 所有者 + ADMIN = "admin" # 管理员 + MEMBER = "member" # 成员 + VIEWER = "viewer" # 查看者 class DomainStatus(StrEnum): """域名状态""" - PENDING = "pending" # 待验证 - VERIFIED = "verified" # 已验证 - FAILED = "failed" # 验证失败 - EXPIRED = "expired" # 已过期 + PENDING = "pending" # 待验证 + VERIFIED = "verified" # 已验证 + FAILED = "failed" # 验证失败 + EXPIRED = "expired" # 已过期 @dataclass @@ -171,7 +171,7 @@ class TenantManager: """租户管理器 - 多租户 SaaS 架构核心""" # 默认资源限制配置 - 使用常量 - DEFAULT_LIMITS = { + DEFAULT_LIMITS = { TenantTier.FREE: { "max_projects": TenantLimits.FREE_MAX_PROJECTS, "max_storage_mb": TenantLimits.FREE_MAX_STORAGE_MB, @@ -209,7 +209,7 @@ class TenantManager: } # 角色权限映射 - ROLE_PERMISSIONS = { + ROLE_PERMISSIONS = { TenantRole.OWNER: [ "tenant:*", "project:*", @@ -240,7 +240,7 @@ class TenantManager: } # 权限名称映射 - PERMISSION_NAMES = { + PERMISSION_NAMES = { "tenant:*": "租户完全控制", "tenant:read": "查看租户信息", "project:*": "项目完全控制", @@ -257,21 +257,21 @@ class TenantManager: "export:basic": "基础导出", } - def __init__(self, db_path: str = "insightflow.db") -> None: - self.db_path = db_path + def __init__(self, db_path: str = "insightflow.db") -> None: + self.db_path = db_path self._init_db() def _get_connection(self) -> sqlite3.Connection: """获取数据库连接""" - conn = sqlite3.connect(self.db_path) - conn.row_factory = sqlite3.Row + conn = sqlite3.connect(self.db_path) + conn.row_factory = sqlite3.Row return conn def _init_db(self) -> None: """初始化数据库表""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() # 租户主表 cursor.execute(""" @@ -418,41 +418,41 @@ class TenantManager: self, name: str, owner_id: str, - tier: str = "free", - description: str | None = None, - settings: dict | None = None, + tier: str = "free", + description: str | None = None, + settings: dict | None = None, ) -> Tenant: """创建新租户""" - conn = self._get_connection() + conn = self._get_connection() try: - tenant_id = str(uuid.uuid4()) - slug = self._generate_slug(name) + tenant_id = str(uuid.uuid4()) + slug = self._generate_slug(name) # 获取对应层级的资源限制 - tier_enum = ( + tier_enum = ( TenantTier(tier) if tier in [t.value for t in TenantTier] else TenantTier.FREE ) - resource_limits = self.DEFAULT_LIMITS.get( + resource_limits = self.DEFAULT_LIMITS.get( tier_enum, self.DEFAULT_LIMITS[TenantTier.FREE] ) - tenant = Tenant( - id = tenant_id, - name = name, - slug = slug, - description = description, - tier = tier, - status = TenantStatus.PENDING.value, - owner_id = owner_id, - created_at = datetime.now(), - updated_at = datetime.now(), - expires_at = None, - settings = settings or {}, - resource_limits = resource_limits, - metadata = {}, + tenant = Tenant( + id=tenant_id, + name=name, + slug=slug, + description=description, + tier=tier, + status=TenantStatus.PENDING.value, + owner_id=owner_id, + created_at=datetime.now(), + updated_at=datetime.now(), + expires_at=None, + settings=settings or {}, + resource_limits=resource_limits, + metadata={}, ) - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ INSERT INTO tenants (id, name, slug, description, tier, status, owner_id, @@ -492,11 +492,11 @@ class TenantManager: def get_tenant(self, tenant_id: str) -> Tenant | None: """获取租户信息""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute("SELECT * FROM tenants WHERE id = ?", (tenant_id, )) - row = cursor.fetchone() + row = cursor.fetchone() if row: return self._row_to_tenant(row) @@ -507,11 +507,11 @@ class TenantManager: def get_tenant_by_slug(self, slug: str) -> Tenant | None: """通过 slug 获取租户""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute("SELECT * FROM tenants WHERE slug = ?", (slug, )) - row = cursor.fetchone() + row = cursor.fetchone() if row: return self._row_to_tenant(row) @@ -522,9 +522,9 @@ class TenantManager: def get_tenant_by_domain(self, domain: str) -> Tenant | None: """通过自定义域名获取租户""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ SELECT t.* FROM tenants t @@ -533,7 +533,7 @@ class TenantManager: """, (domain, ), ) - row = cursor.fetchone() + row = cursor.fetchone() if row: return self._row_to_tenant(row) @@ -545,21 +545,21 @@ class TenantManager: def update_tenant( self, tenant_id: str, - name: str | None = None, - description: str | None = None, - tier: str | None = None, - status: str | None = None, - settings: dict | None = None, + name: str | None = None, + description: str | None = None, + tier: str | None = None, + status: str | None = None, + settings: dict | None = None, ) -> Tenant | None: """更新租户信息""" - conn = self._get_connection() + conn = self._get_connection() try: - tenant = self.get_tenant(tenant_id) + tenant = self.get_tenant(tenant_id) if not tenant: return None - updates = [] - params = [] + updates = [] + params = [] if name is not None: updates.append("name = ?") @@ -571,7 +571,7 @@ class TenantManager: updates.append("tier = ?") params.append(tier) # 更新资源限制 - tier_enum = TenantTier(tier) + tier_enum = TenantTier(tier) updates.append("resource_limits = ?") params.append(json.dumps(self.DEFAULT_LIMITS.get(tier_enum, {}))) if status is not None: @@ -585,7 +585,7 @@ class TenantManager: params.append(datetime.now()) params.append(tenant_id) - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( f""" UPDATE tenants SET {", ".join(updates)} @@ -602,9 +602,9 @@ class TenantManager: def delete_tenant(self, tenant_id: str) -> bool: """删除租户(软删除或硬删除)""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute("DELETE FROM tenants WHERE id = ?", (tenant_id, )) conn.commit() return cursor.rowcount > 0 @@ -612,15 +612,15 @@ class TenantManager: conn.close() def list_tenants( - self, status: str | None = None, tier: str | None = None, limit: int = 100, offset: int = 0 + self, status: str | None = None, tier: str | None = None, limit: int = 100, offset: int = 0 ) -> list[Tenant]: """列出租户""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() - query = "SELECT * FROM tenants WHERE 1 = 1" - params = [] + query = "SELECT * FROM tenants WHERE 1 = 1" + params = [] if status: query += " AND status = ?" @@ -633,7 +633,7 @@ class TenantManager: params.extend([limit, offset]) cursor.execute(query, params) - rows = cursor.fetchall() + rows = cursor.fetchall() return [self._row_to_tenant(row) for row in rows] @@ -646,36 +646,36 @@ class TenantManager: self, tenant_id: str, domain: str, - is_primary: bool = False, - verification_method: str = "dns", + is_primary: bool = False, + verification_method: str = "dns", ) -> TenantDomain: """为租户添加自定义域名""" - conn = self._get_connection() + conn = self._get_connection() try: # 验证域名格式 if not self._validate_domain(domain): raise ValueError(f"Invalid domain format: {domain}") # 生成验证令牌 - verification_token = self._generate_verification_token(tenant_id, domain) + verification_token = self._generate_verification_token(tenant_id, domain) - domain_id = str(uuid.uuid4()) - tenant_domain = TenantDomain( - id = domain_id, - tenant_id = tenant_id, - domain = domain.lower(), - status = DomainStatus.PENDING.value, - verification_token = verification_token, - verification_method = verification_method, - verified_at = None, - created_at = datetime.now(), - updated_at = datetime.now(), - is_primary = is_primary, - ssl_enabled = False, - ssl_expires_at = None, + domain_id = str(uuid.uuid4()) + tenant_domain = TenantDomain( + id=domain_id, + tenant_id=tenant_id, + domain=domain.lower(), + status=DomainStatus.PENDING.value, + verification_token=verification_token, + verification_method=verification_method, + verified_at=None, + created_at=datetime.now(), + updated_at=datetime.now(), + is_primary=is_primary, + ssl_enabled=False, + ssl_expires_at=None, ) - cursor = conn.cursor() + cursor = conn.cursor() # 如果设为主域名,取消其他主域名 if is_primary: @@ -723,9 +723,9 @@ class TenantManager: def verify_domain(self, tenant_id: str, domain_id: str) -> bool: """验证域名所有权""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() # 获取域名信息 cursor.execute( @@ -735,17 +735,17 @@ class TenantManager: """, (domain_id, tenant_id), ) - row = cursor.fetchone() + row = cursor.fetchone() if not row: return False - domain = row["domain"] - token = row["verification_token"] - method = row["verification_method"] + domain = row["domain"] + token = row["verification_token"] + method = row["verification_method"] # 执行验证 - is_verified = self._check_domain_verification(domain, token, method) + is_verified = self._check_domain_verification(domain, token, method) if is_verified: cursor.execute( @@ -779,17 +779,17 @@ class TenantManager: def get_domain_verification_instructions(self, domain_id: str) -> dict[str, Any]: """获取域名验证指导""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute("SELECT * FROM tenant_domains WHERE id = ?", (domain_id, )) - row = cursor.fetchone() + row = cursor.fetchone() if not row: return None - domain = row["domain"] - token = row["verification_token"] + domain = row["domain"] + token = row["verification_token"] return { "domain": domain, @@ -815,9 +815,9 @@ class TenantManager: def remove_domain(self, tenant_id: str, domain_id: str) -> bool: """移除域名绑定""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ DELETE FROM tenant_domains @@ -832,9 +832,9 @@ class TenantManager: def list_domains(self, tenant_id: str) -> list[TenantDomain]: """列出租户的所有域名""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ SELECT * FROM tenant_domains @@ -843,7 +843,7 @@ class TenantManager: """, (tenant_id, ), ) - rows = cursor.fetchall() + rows = cursor.fetchall() return [self._row_to_domain(row) for row in rows] @@ -854,11 +854,11 @@ class TenantManager: def get_branding(self, tenant_id: str) -> TenantBranding | None: """获取租户品牌配置""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute("SELECT * FROM tenant_branding WHERE tenant_id = ?", (tenant_id, )) - row = cursor.fetchone() + row = cursor.fetchone() if row: return self._row_to_branding(row) @@ -870,28 +870,28 @@ class TenantManager: def update_branding( self, tenant_id: str, - logo_url: str | None = None, - favicon_url: str | None = None, - primary_color: str | None = None, - secondary_color: str | None = None, - custom_css: str | None = None, - custom_js: str | None = None, - login_page_bg: str | None = None, - email_template: str | None = None, + logo_url: str | None = None, + favicon_url: str | None = None, + primary_color: str | None = None, + secondary_color: str | None = None, + custom_css: str | None = None, + custom_js: str | None = None, + login_page_bg: str | None = None, + email_template: str | None = None, ) -> TenantBranding: """更新租户品牌配置""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() # 检查是否已存在 cursor.execute("SELECT id FROM tenant_branding WHERE tenant_id = ?", (tenant_id, )) - existing = cursor.fetchone() + existing = cursor.fetchone() if existing: # 更新 - updates = [] - params = [] + updates = [] + params = [] if logo_url is not None: updates.append("logo_url = ?") @@ -931,7 +931,7 @@ class TenantManager: ) else: # 创建 - branding_id = str(uuid.uuid4()) + branding_id = str(uuid.uuid4()) cursor.execute( """ INSERT INTO tenant_branding @@ -963,11 +963,11 @@ class TenantManager: def get_branding_css(self, tenant_id: str) -> str: """生成品牌 CSS""" - branding = self.get_branding(tenant_id) + branding = self.get_branding(tenant_id) if not branding: return "" - css = [] + css = [] if branding.primary_color: css.append(f""" @@ -1007,35 +1007,35 @@ class TenantManager: email: str, role: str, invited_by: str, - permissions: list[str] | None = None, + permissions: list[str] | None = None, ) -> TenantMember: """邀请成员加入租户""" - conn = self._get_connection() + conn = self._get_connection() try: - member_id = str(uuid.uuid4()) + member_id = str(uuid.uuid4()) # 使用角色默认权限 - role_enum = ( + role_enum = ( TenantRole(role) if role in [r.value for r in TenantRole] else TenantRole.MEMBER ) - default_permissions = self.ROLE_PERMISSIONS.get(role_enum, []) - final_permissions = permissions or default_permissions + default_permissions = self.ROLE_PERMISSIONS.get(role_enum, []) + final_permissions = permissions or default_permissions - member = TenantMember( - id = member_id, - tenant_id = tenant_id, - user_id = "pending", # 临时值,待用户接受邀请后更新 - email = email, - role = role, - permissions = final_permissions, - invited_by = invited_by, - invited_at = datetime.now(), - joined_at = None, - last_active_at = None, - status = "pending", + member = TenantMember( + id=member_id, + tenant_id=tenant_id, + user_id="pending", # 临时值,待用户接受邀请后更新 + email=email, + role=role, + permissions=final_permissions, + invited_by=invited_by, + invited_at=datetime.now(), + joined_at=None, + last_active_at=None, + status="pending", ) - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ INSERT INTO tenant_members @@ -1067,9 +1067,9 @@ class TenantManager: def accept_invitation(self, invitation_id: str, user_id: str) -> bool: """接受邀请""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ UPDATE tenant_members @@ -1087,9 +1087,9 @@ class TenantManager: def remove_member(self, tenant_id: str, member_id: str) -> bool: """移除成员""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ DELETE FROM tenant_members @@ -1103,16 +1103,16 @@ class TenantManager: conn.close() def update_member_role( - self, tenant_id: str, member_id: str, role: str, permissions: list[str] | None = None + self, tenant_id: str, member_id: str, role: str, permissions: list[str] | None = None ) -> bool: """更新成员角色""" - conn = self._get_connection() + conn = self._get_connection() try: - role_enum = TenantRole(role) - default_permissions = self.ROLE_PERMISSIONS.get(role_enum, []) - final_permissions = permissions or default_permissions + role_enum = TenantRole(role) + default_permissions = self.ROLE_PERMISSIONS.get(role_enum, []) + final_permissions = permissions or default_permissions - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ UPDATE tenant_members @@ -1128,14 +1128,14 @@ class TenantManager: finally: conn.close() - def list_members(self, tenant_id: str, status: str | None = None) -> list[TenantMember]: + def list_members(self, tenant_id: str, status: str | None = None) -> list[TenantMember]: """列出租户成员""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() - query = "SELECT * FROM tenant_members WHERE tenant_id = ?" - params = [tenant_id] + query = "SELECT * FROM tenant_members WHERE tenant_id = ?" + params = [tenant_id] if status: query += " AND status = ?" @@ -1144,7 +1144,7 @@ class TenantManager: query += " ORDER BY invited_at DESC" cursor.execute(query, params) - rows = cursor.fetchall() + rows = cursor.fetchall() return [self._row_to_member(row) for row in rows] @@ -1153,9 +1153,9 @@ class TenantManager: def check_permission(self, tenant_id: str, user_id: str, resource: str, action: str) -> bool: """检查用户是否有特定权限""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ SELECT role, permissions FROM tenant_members @@ -1163,21 +1163,21 @@ class TenantManager: """, (tenant_id, user_id), ) - row = cursor.fetchone() + row = cursor.fetchone() if not row: return False - role = row["role"] - permissions = json.loads(row["permissions"] or "[]") + role = row["role"] + permissions = json.loads(row["permissions"] or "[]") # 所有者拥有所有权限 if role == TenantRole.OWNER.value: return True # 检查具体权限 - required = f"{resource}:{action}" - wildcard = f"{resource}:*" + required = f"{resource}:{action}" + wildcard = f"{resource}:*" return required in permissions or wildcard in permissions or "*" in permissions @@ -1186,9 +1186,9 @@ class TenantManager: def get_user_tenants(self, user_id: str) -> list[dict[str, Any]]: """获取用户所属的所有租户""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ SELECT t.*, m.role, m.status as member_status @@ -1199,11 +1199,11 @@ class TenantManager: """, (user_id, ), ) - rows = cursor.fetchall() + rows = cursor.fetchall() - result = [] + result = [] for row in rows: - tenant = self._row_to_tenant(row) + tenant = self._row_to_tenant(row) result.append( { **asdict(tenant), @@ -1221,20 +1221,20 @@ class TenantManager: def record_usage( self, tenant_id: str, - storage_bytes: int = 0, - transcription_seconds: int = 0, - api_calls: int = 0, - projects_count: int = 0, - entities_count: int = 0, - members_count: int = 0, + storage_bytes: int = 0, + transcription_seconds: int = 0, + api_calls: int = 0, + projects_count: int = 0, + entities_count: int = 0, + members_count: int = 0, ) -> None: """记录资源使用""" - conn = self._get_connection() + conn = self._get_connection() try: - today = datetime.now().date() - usage_id = str(uuid.uuid4()) + today = datetime.now().date() + usage_id = str(uuid.uuid4()) - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ INSERT INTO tenant_usage @@ -1268,14 +1268,14 @@ class TenantManager: conn.close() def get_usage_stats( - self, tenant_id: str, start_date: datetime | None = None, end_date: datetime | None = None + self, tenant_id: str, start_date: datetime | None = None, end_date: datetime | None = None ) -> dict[str, Any]: """获取使用统计""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() - query = """ + query = """ SELECT SUM(storage_bytes) as total_storage, SUM(transcription_seconds) as total_transcription, @@ -1286,7 +1286,7 @@ class TenantManager: FROM tenant_usage WHERE tenant_id = ? """ - params = [tenant_id] + params = [tenant_id] if start_date: query += " AND date >= ?" @@ -1296,11 +1296,11 @@ class TenantManager: params.append(end_date.date()) cursor.execute(query, params) - row = cursor.fetchone() + row = cursor.fetchone() # 获取租户限制 - tenant = self.get_tenant(tenant_id) - limits = tenant.resource_limits if tenant else {} + tenant = self.get_tenant(tenant_id) + limits = tenant.resource_limits if tenant else {} return { "storage_bytes": row["total_storage"] or 0, @@ -1344,14 +1344,14 @@ class TenantManager: Returns: (是否允许, 当前使用量, 限制值) """ - tenant = self.get_tenant(tenant_id) + tenant = self.get_tenant(tenant_id) if not tenant: return False, 0, 0 - limits = tenant.resource_limits - stats = self.get_usage_stats(tenant_id) + limits = tenant.resource_limits + stats = self.get_usage_stats(tenant_id) - resource_map = { + resource_map = { "storage": ("storage_mb", stats["storage_mb"]), "transcription": ("max_transcription_minutes", stats["transcription_minutes"]), "api_calls": ("max_api_calls_per_day", stats["api_calls"]), @@ -1363,8 +1363,8 @@ class TenantManager: if resource_type not in resource_map: return True, 0, -1 - limit_key, current = resource_map[resource_type] - limit = limits.get(limit_key, 0) + limit_key, current = resource_map[resource_type] + limit = limits.get(limit_key, 0) # -1 表示无限制 if limit == -1: @@ -1377,21 +1377,21 @@ class TenantManager: def _generate_slug(self, name: str) -> str: """生成 URL 友好的 slug""" # 转换为小写,替换空格为连字符 - slug = re.sub(r"[^\w\s-]", "", name.lower()) - slug = re.sub(r"[-\s]+", "-", slug) + slug = re.sub(r"[^\w\s-]", "", name.lower()) + slug = re.sub(r"[-\s]+", "-", slug) # 检查是否已存在 - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() - base_slug = slug - counter = 1 + cursor = conn.cursor() + base_slug = slug + counter = 1 while True: cursor.execute("SELECT id FROM tenants WHERE slug = ?", (slug, )) if not cursor.fetchone(): break - slug = f"{base_slug}-{counter}" + slug = f"{base_slug}-{counter}" counter += 1 return slug @@ -1401,12 +1401,12 @@ class TenantManager: def _generate_verification_token(self, tenant_id: str, domain: str) -> str: """生成域名验证令牌""" - data = f"{tenant_id}:{domain}:{datetime.now().isoformat()}" + data = f"{tenant_id}:{domain}:{datetime.now().isoformat()}" return hashlib.sha256(data.encode()).hexdigest()[:32] def _validate_domain(self, domain: str) -> bool: """验证域名格式""" - pattern = r"^(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0, 61}[a-zA-Z0-9])?\.)*[a-zA-Z0-9](?:[a-zA-Z0-9-]{0, 61}[a-zA-Z0-9])$" + pattern = r"^(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0, 61}[a-zA-Z0-9])?\.)*[a-zA-Z0-9](?:[a-zA-Z0-9-]{0, 61}[a-zA-Z0-9])$" return bool(re.match(pattern, domain)) def _check_domain_verification(self, domain: str, token: str, method: str) -> bool: @@ -1442,14 +1442,14 @@ class TenantManager: def _darken_color(self, hex_color: str, percent: int) -> str: """加深颜色""" - hex_color = hex_color.lstrip("#") - r = int(hex_color[0:2], 16) - g = int(hex_color[2:4], 16) - b = int(hex_color[4:6], 16) + hex_color = hex_color.lstrip("#") + r = int(hex_color[0:2], 16) + g = int(hex_color[2:4], 16) + b = int(hex_color[4:6], 16) - r = int(r * (100 - percent) / 100) - g = int(g * (100 - percent) / 100) - b = int(b * (100 - percent) / 100) + r = int(r * (100 - percent) / 100) + g = int(g * (100 - percent) / 100) + b = int(b * (100 - percent) / 100) return f"#{r:02x}{g:02x}{b:02x}" @@ -1469,8 +1469,8 @@ class TenantManager: invited_by: str | None, ) -> None: """内部方法:添加成员""" - cursor = conn.cursor() - member_id = str(uuid.uuid4()) + cursor = conn.cursor() + member_id = str(uuid.uuid4()) cursor.execute( """ @@ -1497,60 +1497,60 @@ class TenantManager: def _row_to_tenant(self, row: sqlite3.Row) -> Tenant: """数据库行转换为 Tenant 对象""" return Tenant( - id = row["id"], - name = row["name"], - slug = row["slug"], - description = row["description"], - tier = row["tier"], - status = row["status"], - owner_id = row["owner_id"], - created_at = ( + id=row["id"], + name=row["name"], + slug=row["slug"], + description=row["description"], + tier=row["tier"], + status=row["status"], + owner_id=row["owner_id"], + created_at=( datetime.fromisoformat(row["created_at"]) if isinstance(row["created_at"], str) else row["created_at"] ), - updated_at = ( + updated_at=( datetime.fromisoformat(row["updated_at"]) if isinstance(row["updated_at"], str) else row["updated_at"] ), - expires_at = ( + expires_at=( datetime.fromisoformat(row["expires_at"]) if row["expires_at"] and isinstance(row["expires_at"], str) else row["expires_at"] ), - settings = json.loads(row["settings"] or "{}"), - resource_limits = json.loads(row["resource_limits"] or "{}"), - metadata = json.loads(row["metadata"] or "{}"), + settings=json.loads(row["settings"] or "{}"), + resource_limits=json.loads(row["resource_limits"] or "{}"), + metadata=json.loads(row["metadata"] or "{}"), ) def _row_to_domain(self, row: sqlite3.Row) -> TenantDomain: """数据库行转换为 TenantDomain 对象""" return TenantDomain( - id = row["id"], - tenant_id = row["tenant_id"], - domain = row["domain"], - status = row["status"], - verification_token = row["verification_token"], - verification_method = row["verification_method"], - verified_at = ( + id=row["id"], + tenant_id=row["tenant_id"], + domain=row["domain"], + status=row["status"], + verification_token=row["verification_token"], + verification_method=row["verification_method"], + verified_at=( datetime.fromisoformat(row["verified_at"]) if row["verified_at"] and isinstance(row["verified_at"], str) else row["verified_at"] ), - created_at = ( + created_at=( datetime.fromisoformat(row["created_at"]) if isinstance(row["created_at"], str) else row["created_at"] ), - updated_at = ( + updated_at=( datetime.fromisoformat(row["updated_at"]) if isinstance(row["updated_at"], str) else row["updated_at"] ), - is_primary = bool(row["is_primary"]), - ssl_enabled = bool(row["ssl_enabled"]), - ssl_expires_at = ( + is_primary=bool(row["is_primary"]), + ssl_enabled=bool(row["ssl_enabled"]), + ssl_expires_at=( datetime.fromisoformat(row["ssl_expires_at"]) if row["ssl_expires_at"] and isinstance(row["ssl_expires_at"], str) else row["ssl_expires_at"] @@ -1560,22 +1560,22 @@ class TenantManager: def _row_to_branding(self, row: sqlite3.Row) -> TenantBranding: """数据库行转换为 TenantBranding 对象""" return TenantBranding( - id = row["id"], - tenant_id = row["tenant_id"], - logo_url = row["logo_url"], - favicon_url = row["favicon_url"], - primary_color = row["primary_color"], - secondary_color = row["secondary_color"], - custom_css = row["custom_css"], - custom_js = row["custom_js"], - login_page_bg = row["login_page_bg"], - email_template = row["email_template"], - created_at = ( + id=row["id"], + tenant_id=row["tenant_id"], + logo_url=row["logo_url"], + favicon_url=row["favicon_url"], + primary_color=row["primary_color"], + secondary_color=row["secondary_color"], + custom_css=row["custom_css"], + custom_js=row["custom_js"], + login_page_bg=row["login_page_bg"], + email_template=row["email_template"], + created_at=( datetime.fromisoformat(row["created_at"]) if isinstance(row["created_at"], str) else row["created_at"] ), - updated_at = ( + updated_at=( datetime.fromisoformat(row["updated_at"]) if isinstance(row["updated_at"], str) else row["updated_at"] @@ -1585,29 +1585,29 @@ class TenantManager: def _row_to_member(self, row: sqlite3.Row) -> TenantMember: """数据库行转换为 TenantMember 对象""" return TenantMember( - id = row["id"], - tenant_id = row["tenant_id"], - user_id = row["user_id"], - email = row["email"], - role = row["role"], - permissions = json.loads(row["permissions"] or "[]"), - invited_by = row["invited_by"], - invited_at = ( + id=row["id"], + tenant_id=row["tenant_id"], + user_id=row["user_id"], + email=row["email"], + role=row["role"], + permissions=json.loads(row["permissions"] or "[]"), + invited_by=row["invited_by"], + invited_at=( datetime.fromisoformat(row["invited_at"]) if isinstance(row["invited_at"], str) else row["invited_at"] ), - joined_at = ( + joined_at=( datetime.fromisoformat(row["joined_at"]) if row["joined_at"] and isinstance(row["joined_at"], str) else row["joined_at"] ), - last_active_at = ( + last_active_at=( datetime.fromisoformat(row["last_active_at"]) if row["last_active_at"] and isinstance(row["last_active_at"], str) else row["last_active_at"] ), - status = row["status"], + status=row["status"], ) @@ -1617,13 +1617,13 @@ class TenantManager: class TenantContext: """租户上下文管理器 - 用于请求级别的租户隔离""" - _current_tenant_id: str | None = None - _current_user_id: str | None = None + _current_tenant_id: str | None = None + _current_user_id: str | None = None @classmethod def set_current_tenant(cls, tenant_id: str) -> None: """设置当前租户上下文""" - cls._current_tenant_id = tenant_id + cls._current_tenant_id = tenant_id @classmethod def get_current_tenant(cls) -> str | None: @@ -1633,7 +1633,7 @@ class TenantContext: @classmethod def set_current_user(cls, user_id: str) -> None: """设置当前用户""" - cls._current_user_id = user_id + cls._current_user_id = user_id @classmethod def get_current_user(cls) -> str | None: @@ -1643,17 +1643,17 @@ class TenantContext: @classmethod def clear(cls) -> None: """清除上下文""" - cls._current_tenant_id = None - cls._current_user_id = None + cls._current_tenant_id = None + cls._current_user_id = None # 全局租户管理器实例 -tenant_manager = None +tenant_manager = None -def get_tenant_manager(db_path: str = "insightflow.db") -> TenantManager: +def get_tenant_manager(db_path: str = "insightflow.db") -> TenantManager: """获取租户管理器实例(单例模式)""" global tenant_manager if tenant_manager is None: - tenant_manager = TenantManager(db_path) + tenant_manager = TenantManager(db_path) return tenant_manager diff --git a/backend/test_multimodal.py b/backend/test_multimodal.py index b7c26da..aac98cc 100644 --- a/backend/test_multimodal.py +++ b/backend/test_multimodal.py @@ -42,7 +42,7 @@ except ImportError as e: print("\n2. 测试模块初始化...") try: - processor = get_multimodal_processor() + processor = get_multimodal_processor() print(" ✓ MultimodalProcessor 初始化成功") print(f" - 临时目录: {processor.temp_dir}") print(f" - 帧提取间隔: {processor.frame_interval}秒") @@ -50,14 +50,14 @@ except Exception as e: print(f" ✗ MultimodalProcessor 初始化失败: {e}") try: - img_processor = get_image_processor() + img_processor = get_image_processor() print(" ✓ ImageProcessor 初始化成功") print(f" - 临时目录: {img_processor.temp_dir}") except Exception as e: print(f" ✗ ImageProcessor 初始化失败: {e}") try: - linker = get_multimodal_entity_linker() + linker = get_multimodal_entity_linker() print(" ✓ MultimodalEntityLinker 初始化成功") print(f" - 相似度阈值: {linker.similarity_threshold}") except Exception as e: @@ -67,20 +67,20 @@ except Exception as e: print("\n3. 测试实体关联功能...") try: - linker = get_multimodal_entity_linker() + linker = get_multimodal_entity_linker() # 测试字符串相似度 - sim = linker.calculate_string_similarity("Project Alpha", "Project Alpha") + sim = linker.calculate_string_similarity("Project Alpha", "Project Alpha") assert sim == 1.0, "完全匹配应该返回1.0" print(f" ✓ 字符串相似度计算正常 (完全匹配: {sim})") - sim = linker.calculate_string_similarity("K8s", "Kubernetes") + sim = linker.calculate_string_similarity("K8s", "Kubernetes") print(f" ✓ 字符串相似度计算正常 (不同字符串: {sim:.2f})") # 测试实体相似度 - entity1 = {"name": "Project Alpha", "type": "PROJECT", "definition": "核心项目"} - entity2 = {"name": "Project Alpha", "type": "PROJECT", "definition": "主要项目"} - sim, match_type = linker.calculate_entity_similarity(entity1, entity2) + entity1 = {"name": "Project Alpha", "type": "PROJECT", "definition": "核心项目"} + entity2 = {"name": "Project Alpha", "type": "PROJECT", "definition": "主要项目"} + sim, match_type = linker.calculate_entity_similarity(entity1, entity2) print(f" ✓ 实体相似度计算正常 (相似度: {sim:.2f}, 类型: {match_type})") except Exception as e: @@ -90,7 +90,7 @@ except Exception as e: print("\n4. 测试图片处理器功能...") try: - processor = get_image_processor() + processor = get_image_processor() # 测试图片类型检测(使用模拟数据) print(f" ✓ 支持的图片类型: {list(processor.IMAGE_TYPES.keys())}") @@ -103,7 +103,7 @@ except Exception as e: print("\n5. 测试视频处理器配置...") try: - processor = get_multimodal_processor() + processor = get_multimodal_processor() print(f" ✓ 视频目录: {processor.video_dir}") print(f" ✓ 帧目录: {processor.frames_dir}") @@ -129,11 +129,11 @@ print("\n6. 测试数据库多模态方法...") try: from db_manager import get_db_manager - db = get_db_manager() + db = get_db_manager() # 检查多模态表是否存在 - conn = db.get_conn() - tables = ["videos", "video_frames", "images", "multimodal_mentions", "multimodal_entity_links"] + conn = db.get_conn() + tables = ["videos", "video_frames", "images", "multimodal_mentions", "multimodal_entity_links"] for table in tables: try: diff --git a/backend/test_phase7_task6_8.py b/backend/test_phase7_task6_8.py index 40beaae..ff607d0 100644 --- a/backend/test_phase7_task6_8.py +++ b/backend/test_phase7_task6_8.py @@ -27,21 +27,21 @@ def test_fulltext_search() -> None: print("测试全文搜索 (FullTextSearch)") print(" = " * 60) - search = FullTextSearch() + search = FullTextSearch() # 测试索引创建 print("\n1. 测试索引创建...") - success = search.index_content( - content_id = "test_entity_1", - content_type = "entity", - project_id = "test_project", - text = "这是一个测试实体,用于验证全文搜索功能。支持关键词高亮显示。", + success = search.index_content( + content_id="test_entity_1", + content_type="entity", + project_id="test_project", + text="这是一个测试实体,用于验证全文搜索功能。支持关键词高亮显示。", ) print(f" 索引创建: {'✓ 成功' if success else '✗ 失败'}") # 测试搜索 print("\n2. 测试关键词搜索...") - results = search.search("测试", project_id = "test_project") + results = search.search("测试", project_id="test_project") print(f" 搜索结果数量: {len(results)}") if results: print(f" 第一个结果: {results[0].content[:50]}...") @@ -49,15 +49,15 @@ def test_fulltext_search() -> None: # 测试布尔搜索 print("\n3. 测试布尔搜索...") - results = search.search("测试 AND 全文", project_id = "test_project") + results = search.search("测试 AND 全文", project_id="test_project") print(f" AND 搜索结果: {len(results)}") - results = search.search("测试 OR 关键词", project_id = "test_project") + results = search.search("测试 OR 关键词", project_id="test_project") print(f" OR 搜索结果: {len(results)}") # 测试高亮 print("\n4. 测试文本高亮...") - highlighted = search.highlight_text("这是一个测试实体,用于验证全文搜索功能。", "测试 全文") + highlighted = search.highlight_text("这是一个测试实体,用于验证全文搜索功能。", "测试 全文") print(f" 高亮结果: {highlighted}") print("\n✓ 全文搜索测试完成") @@ -70,7 +70,7 @@ def test_semantic_search() -> None: print("测试语义搜索 (SemanticSearch)") print(" = " * 60) - semantic = SemanticSearch() + semantic = SemanticSearch() # 检查可用性 print(f"\n1. 语义搜索可用性: {'✓ 可用' if semantic.is_available() else '✗ 不可用'}") @@ -81,18 +81,18 @@ def test_semantic_search() -> None: # 测试 embedding 生成 print("\n2. 测试 embedding 生成...") - embedding = semantic.generate_embedding("这是一个测试句子") + embedding = semantic.generate_embedding("这是一个测试句子") if embedding: print(f" Embedding 维度: {len(embedding)}") print(f" 前5个值: {embedding[:5]}") # 测试索引 print("\n3. 测试语义索引...") - success = semantic.index_embedding( - content_id = "test_content_1", - content_type = "transcript", - project_id = "test_project", - text = "这是用于语义搜索测试的文本内容。", + success = semantic.index_embedding( + content_id="test_content_1", + content_type="transcript", + project_id="test_project", + text="这是用于语义搜索测试的文本内容。", ) print(f" 索引创建: {'✓ 成功' if success else '✗ 失败'}") @@ -106,7 +106,7 @@ def test_entity_path_discovery() -> None: print("测试实体路径发现 (EntityPathDiscovery)") print(" = " * 60) - discovery = EntityPathDiscovery() + discovery = EntityPathDiscovery() print("\n1. 测试路径发现初始化...") print(f" 数据库路径: {discovery.db_path}") @@ -125,7 +125,7 @@ def test_knowledge_gap_detection() -> None: print("测试知识缺口识别 (KnowledgeGapDetection)") print(" = " * 60) - detection = KnowledgeGapDetection() + detection = KnowledgeGapDetection() print("\n1. 测试缺口检测初始化...") print(f" 数据库路径: {detection.db_path}") @@ -144,26 +144,26 @@ def test_cache_manager() -> None: print("测试缓存管理器 (CacheManager)") print(" = " * 60) - cache = CacheManager() + cache = CacheManager() print(f"\n1. 缓存后端: {'Redis' if cache.use_redis else '内存 LRU'}") print("\n2. 测试缓存操作...") # 设置缓存 - cache.set("test_key_1", {"name": "测试数据", "value": 123}, ttl = 60) + cache.set("test_key_1", {"name": "测试数据", "value": 123}, ttl=60) print(" ✓ 设置缓存 test_key_1") # 获取缓存 - _ = cache.get("test_key_1") + _ = cache.get("test_key_1") print(" ✓ 获取缓存: {value}") # 批量操作 cache.set_many( - {"batch_key_1": "value1", "batch_key_2": "value2", "batch_key_3": "value3"}, ttl = 60 + {"batch_key_1": "value1", "batch_key_2": "value2", "batch_key_3": "value3"}, ttl=60 ) print(" ✓ 批量设置缓存") - _ = cache.get_many(["batch_key_1", "batch_key_2", "batch_key_3"]) + _ = cache.get_many(["batch_key_1", "batch_key_2", "batch_key_3"]) print(" ✓ 批量获取缓存: {len(values)} 个") # 删除缓存 @@ -171,7 +171,7 @@ def test_cache_manager() -> None: print(" ✓ 删除缓存 test_key_1") # 获取统计 - stats = cache.get_stats() + stats = cache.get_stats() print("\n3. 缓存统计:") print(f" 总请求数: {stats['total_requests']}") print(f" 命中数: {stats['hits']}") @@ -192,7 +192,7 @@ def test_task_queue() -> None: print("测试任务队列 (TaskQueue)") print(" = " * 60) - queue = TaskQueue() + queue = TaskQueue() print(f"\n1. 任务队列可用性: {'✓ 可用' if queue.is_available() else '✗ 不可用'}") print(f" 后端: {'Celery' if queue.use_celery else '内存'}") @@ -207,18 +207,18 @@ def test_task_queue() -> None: queue.register_handler("test_task", test_task_handler) # 提交任务 - task_id = queue.submit( - task_type = "test_task", payload = {"test": "data", "timestamp": time.time()} + task_id = queue.submit( + task_type="test_task", payload={"test": "data", "timestamp": time.time()} ) print(" ✓ 提交任务: {task_id}") # 获取任务状态 - task_info = queue.get_status(task_id) + task_info = queue.get_status(task_id) if task_info: print(" ✓ 任务状态: {task_info.status}") # 获取统计 - stats = queue.get_stats() + stats = queue.get_stats() print("\n3. 任务队列统计:") print(f" 后端: {stats['backend']}") print(f" 按状态统计: {stats.get('by_status', {})}") @@ -233,32 +233,32 @@ def test_performance_monitor() -> None: print("测试性能监控 (PerformanceMonitor)") print(" = " * 60) - monitor = PerformanceMonitor() + monitor = PerformanceMonitor() print("\n1. 测试指标记录...") # 记录一些测试指标 for i in range(5): monitor.record_metric( - metric_type = "api_response", - duration_ms = 50 + i * 10, - endpoint = "/api/v1/test", - metadata = {"test": True}, + metric_type="api_response", + duration_ms=50 + i * 10, + endpoint="/api/v1/test", + metadata={"test": True}, ) for i in range(3): monitor.record_metric( - metric_type = "db_query", - duration_ms = 20 + i * 5, - endpoint = "SELECT test", - metadata = {"test": True}, + metric_type="db_query", + duration_ms=20 + i * 5, + endpoint="SELECT test", + metadata={"test": True}, ) print(" ✓ 记录了 8 个测试指标") # 获取统计 print("\n2. 获取性能统计...") - stats = monitor.get_stats(hours = 1) + stats = monitor.get_stats(hours=1) print(f" 总请求数: {stats['overall']['total_requests']}") print(f" 平均响应时间: {stats['overall']['avg_duration_ms']} ms") print(f" 最大响应时间: {stats['overall']['max_duration_ms']} ms") @@ -280,13 +280,13 @@ def test_search_manager() -> None: print("测试搜索管理器 (SearchManager)") print(" = " * 60) - manager = get_search_manager() + manager = get_search_manager() print("\n1. 搜索管理器初始化...") print(" ✓ 搜索管理器已初始化") print("\n2. 获取搜索统计...") - stats = manager.get_search_stats() + stats = manager.get_search_stats() print(f" 全文索引数: {stats['fulltext_indexed']}") print(f" 语义索引数: {stats['semantic_indexed']}") print(f" 语义搜索可用: {stats['semantic_search_available']}") @@ -301,18 +301,18 @@ def test_performance_manager() -> None: print("测试性能管理器 (PerformanceManager)") print(" = " * 60) - manager = get_performance_manager() + manager = get_performance_manager() print("\n1. 性能管理器初始化...") print(" ✓ 性能管理器已初始化") print("\n2. 获取系统健康状态...") - health = manager.get_health_status() + health = manager.get_health_status() print(f" 缓存后端: {health['cache']['backend']}") print(f" 任务队列后端: {health['task_queue']['backend']}") print("\n3. 获取完整统计...") - stats = manager.get_full_stats() + stats = manager.get_full_stats() print(f" 缓存统计: {stats['cache']['total_requests']} 请求") print(f" 任务队列统计: {stats['task_queue']}") @@ -327,7 +327,7 @@ def run_all_tests() -> None: print("高级搜索与发现 + 性能优化与扩展") print(" = " * 60) - results = [] + results = [] # 搜索模块测试 try: @@ -390,11 +390,11 @@ def run_all_tests() -> None: print("测试汇总") print(" = " * 60) - passed = sum(1 for _, result in results if result) - total = len(results) + passed = sum(1 for _, result in results if result) + total = len(results) for name, result in results: - status = "✓ 通过" if result else "✗ 失败" + status = "✓ 通过" if result else "✗ 失败" print(f" {status} - {name}") print(f"\n总计: {passed}/{total} 测试通过") @@ -408,5 +408,5 @@ def run_all_tests() -> None: if __name__ == "__main__": - success = run_all_tests() + success = run_all_tests() sys.exit(0 if success else 1) diff --git a/backend/test_phase8_task1.py b/backend/test_phase8_task1.py index c7ddcd5..a5390cc 100644 --- a/backend/test_phase8_task1.py +++ b/backend/test_phase8_task1.py @@ -24,12 +24,12 @@ def test_tenant_management() -> None: print("测试 1: 租户管理") print(" = " * 60) - manager = get_tenant_manager() + manager = get_tenant_manager() # 1. 创建租户 print("\n1.1 创建租户...") - tenant = manager.create_tenant( - name = "Test Company", owner_id = "user_001", tier = "pro", description = "A test company tenant" + tenant = manager.create_tenant( + name="Test Company", owner_id="user_001", tier="pro", description="A test company tenant" ) print(f"✅ 租户创建成功: {tenant.id}") print(f" - 名称: {tenant.name}") @@ -40,27 +40,27 @@ def test_tenant_management() -> None: # 2. 获取租户 print("\n1.2 获取租户信息...") - fetched = manager.get_tenant(tenant.id) + fetched = manager.get_tenant(tenant.id) assert fetched is not None, "获取租户失败" print(f"✅ 获取租户成功: {fetched.name}") # 3. 通过 slug 获取 print("\n1.3 通过 slug 获取租户...") - by_slug = manager.get_tenant_by_slug(tenant.slug) + by_slug = manager.get_tenant_by_slug(tenant.slug) assert by_slug is not None, "通过 slug 获取失败" print(f"✅ 通过 slug 获取成功: {by_slug.name}") # 4. 更新租户 print("\n1.4 更新租户信息...") - updated = manager.update_tenant( - tenant_id = tenant.id, name = "Test Company Updated", tier = "enterprise" + updated = manager.update_tenant( + tenant_id=tenant.id, name="Test Company Updated", tier="enterprise" ) assert updated is not None, "更新租户失败" print(f"✅ 租户更新成功: {updated.name}, 层级: {updated.tier}") # 5. 列出租户 print("\n1.5 列出租户...") - tenants = manager.list_tenants(limit = 10) + tenants = manager.list_tenants(limit=10) print(f"✅ 找到 {len(tenants)} 个租户") return tenant.id @@ -72,11 +72,11 @@ def test_domain_management(tenant_id: str) -> None: print("测试 2: 域名管理") print(" = " * 60) - manager = get_tenant_manager() + manager = get_tenant_manager() # 1. 添加域名 print("\n2.1 添加自定义域名...") - domain = manager.add_domain(tenant_id = tenant_id, domain = "test.example.com", is_primary = True) + domain = manager.add_domain(tenant_id=tenant_id, domain="test.example.com", is_primary=True) print(f"✅ 域名添加成功: {domain.domain}") print(f" - ID: {domain.id}") print(f" - 状态: {domain.status}") @@ -84,19 +84,19 @@ def test_domain_management(tenant_id: str) -> None: # 2. 获取验证指导 print("\n2.2 获取域名验证指导...") - instructions = manager.get_domain_verification_instructions(domain.id) + instructions = manager.get_domain_verification_instructions(domain.id) print("✅ 验证指导:") print(f" - DNS 记录: {instructions['dns_record']}") print(f" - 文件验证: {instructions['file_verification']}") # 3. 验证域名 print("\n2.3 验证域名...") - verified = manager.verify_domain(tenant_id, domain.id) + verified = manager.verify_domain(tenant_id, domain.id) print(f"✅ 域名验证结果: {verified}") # 4. 通过域名获取租户 print("\n2.4 通过域名获取租户...") - by_domain = manager.get_tenant_by_domain("test.example.com") + by_domain = manager.get_tenant_by_domain("test.example.com") if by_domain: print(f"✅ 通过域名获取租户成功: {by_domain.name}") else: @@ -104,7 +104,7 @@ def test_domain_management(tenant_id: str) -> None: # 5. 列出域名 print("\n2.5 列出所有域名...") - domains = manager.list_domains(tenant_id) + domains = manager.list_domains(tenant_id) print(f"✅ 找到 {len(domains)} 个域名") for d in domains: print(f" - {d.domain} ({d.status})") @@ -118,19 +118,19 @@ def test_branding_management(tenant_id: str) -> None: print("测试 3: 品牌白标") print(" = " * 60) - manager = get_tenant_manager() + manager = get_tenant_manager() # 1. 更新品牌配置 print("\n3.1 更新品牌配置...") - branding = manager.update_branding( - tenant_id = tenant_id, - logo_url = "https://example.com/logo.png", - favicon_url = "https://example.com/favicon.ico", - primary_color = "#1890ff", - secondary_color = "#52c41a", - custom_css = ".header { background: #1890ff; }", - custom_js = "console.log('Custom JS loaded');", - login_page_bg = "https://example.com/bg.jpg", + branding = manager.update_branding( + tenant_id=tenant_id, + logo_url="https://example.com/logo.png", + favicon_url="https://example.com/favicon.ico", + primary_color="#1890ff", + secondary_color="#52c41a", + custom_css=".header { background: #1890ff; }", + custom_js="console.log('Custom JS loaded');", + login_page_bg="https://example.com/bg.jpg", ) print("✅ 品牌配置更新成功") print(f" - Logo: {branding.logo_url}") @@ -139,13 +139,13 @@ def test_branding_management(tenant_id: str) -> None: # 2. 获取品牌配置 print("\n3.2 获取品牌配置...") - fetched = manager.get_branding(tenant_id) + fetched = manager.get_branding(tenant_id) assert fetched is not None, "获取品牌配置失败" print("✅ 获取品牌配置成功") # 3. 生成品牌 CSS print("\n3.3 生成品牌 CSS...") - css = manager.get_branding_css(tenant_id) + css = manager.get_branding_css(tenant_id) print(f"✅ 生成 CSS 成功 ({len(css)} 字符)") print(f" CSS 预览:\n{css[:200]}...") @@ -158,48 +158,48 @@ def test_member_management(tenant_id: str) -> None: print("测试 4: 成员管理") print(" = " * 60) - manager = get_tenant_manager() + manager = get_tenant_manager() # 1. 邀请成员 print("\n4.1 邀请成员...") - member1 = manager.invite_member( - tenant_id = tenant_id, email = "admin@test.com", role = "admin", invited_by = "user_001" + member1 = manager.invite_member( + tenant_id=tenant_id, email="admin@test.com", role="admin", invited_by="user_001" ) print(f"✅ 成员邀请成功: {member1.email}") print(f" - ID: {member1.id}") print(f" - 角色: {member1.role}") print(f" - 权限: {member1.permissions}") - member2 = manager.invite_member( - tenant_id = tenant_id, email = "member@test.com", role = "member", invited_by = "user_001" + member2 = manager.invite_member( + tenant_id=tenant_id, email="member@test.com", role="member", invited_by="user_001" ) print(f"✅ 成员邀请成功: {member2.email}") # 2. 接受邀请 print("\n4.2 接受邀请...") - accepted = manager.accept_invitation(member1.id, "user_002") + accepted = manager.accept_invitation(member1.id, "user_002") print(f"✅ 邀请接受结果: {accepted}") # 3. 列出成员 print("\n4.3 列出所有成员...") - members = manager.list_members(tenant_id) + members = manager.list_members(tenant_id) print(f"✅ 找到 {len(members)} 个成员") for m in members: print(f" - {m.email} ({m.role}) - {m.status}") # 4. 检查权限 print("\n4.4 检查权限...") - can_manage = manager.check_permission(tenant_id, "user_002", "project", "create") + can_manage = manager.check_permission(tenant_id, "user_002", "project", "create") print(f"✅ user_002 可以创建项目: {can_manage}") # 5. 更新成员角色 print("\n4.5 更新成员角色...") - updated = manager.update_member_role(tenant_id, member2.id, "viewer") + updated = manager.update_member_role(tenant_id, member2.id, "viewer") print(f"✅ 角色更新结果: {updated}") # 6. 获取用户所属租户 print("\n4.6 获取用户所属租户...") - user_tenants = manager.get_user_tenants("user_002") + user_tenants = manager.get_user_tenants("user_002") print(f"✅ user_002 属于 {len(user_tenants)} 个租户") for t in user_tenants: print(f" - {t['name']} ({t['member_role']})") @@ -213,24 +213,24 @@ def test_usage_tracking(tenant_id: str) -> None: print("测试 5: 资源使用统计") print(" = " * 60) - manager = get_tenant_manager() + manager = get_tenant_manager() # 1. 记录使用 print("\n5.1 记录资源使用...") manager.record_usage( - tenant_id = tenant_id, - storage_bytes = 1024 * 1024 * 50, # 50MB - transcription_seconds = 600, # 10分钟 - api_calls = 100, - projects_count = 5, - entities_count = 50, - members_count = 3, + tenant_id=tenant_id, + storage_bytes=1024 * 1024 * 50, # 50MB + transcription_seconds=600, # 10分钟 + api_calls=100, + projects_count=5, + entities_count=50, + members_count=3, ) print("✅ 资源使用记录成功") # 2. 获取使用统计 print("\n5.2 获取使用统计...") - stats = manager.get_usage_stats(tenant_id) + stats = manager.get_usage_stats(tenant_id) print("✅ 使用统计:") print(f" - 存储: {stats['storage_mb']:.2f} MB") print(f" - 转录: {stats['transcription_minutes']:.2f} 分钟") @@ -243,7 +243,7 @@ def test_usage_tracking(tenant_id: str) -> None: # 3. 检查资源限制 print("\n5.3 检查资源限制...") for resource in ["storage", "transcription", "api_calls", "projects", "entities", "members"]: - allowed, current, limit = manager.check_resource_limit(tenant_id, resource) + allowed, current, limit = manager.check_resource_limit(tenant_id, resource) print(f" - {resource}: {current}/{limit} ({'✅' if allowed else '❌'})") return stats @@ -255,7 +255,7 @@ def cleanup(tenant_id: str, domain_id: str, member_ids: list) -> None: print("清理测试数据") print(" = " * 60) - manager = get_tenant_manager() + manager = get_tenant_manager() # 移除成员 for member_id in member_ids: @@ -279,17 +279,17 @@ def main() -> None: print("InsightFlow Phase 8 Task 1 - 多租户 SaaS 架构测试") print(" = " * 60) - tenant_id = None - domain_id = None - member_ids = [] + tenant_id = None + domain_id = None + member_ids = [] try: # 运行所有测试 - tenant_id = test_tenant_management() - domain_id = test_domain_management(tenant_id) + tenant_id = test_tenant_management() + domain_id = test_domain_management(tenant_id) test_branding_management(tenant_id) - m1, m2 = test_member_management(tenant_id) - member_ids = [m1, m2] + m1, m2 = test_member_management(tenant_id) + member_ids = [m1, m2] test_usage_tracking(tenant_id) print("\n" + " = " * 60) diff --git a/backend/test_phase8_task2.py b/backend/test_phase8_task2.py index bc2589c..fa3af2e 100644 --- a/backend/test_phase8_task2.py +++ b/backend/test_phase8_task2.py @@ -19,24 +19,24 @@ def test_subscription_manager() -> None: print(" = " * 60) # 使用临时文件数据库进行测试 - db_path = tempfile.mktemp(suffix = ".db") + db_path = tempfile.mktemp(suffix=".db") try: - manager = SubscriptionManager(db_path = db_path) + manager = SubscriptionManager(db_path=db_path) print("\n1. 测试订阅计划管理") print("-" * 40) # 获取默认计划 - plans = manager.list_plans() + plans = manager.list_plans() print(f"✓ 默认计划数量: {len(plans)}") for plan in plans: print(f" - {plan.name} ({plan.tier}): ¥{plan.price_monthly}/月") # 通过 tier 获取计划 - free_plan = manager.get_plan_by_tier("free") - pro_plan = manager.get_plan_by_tier("pro") - enterprise_plan = manager.get_plan_by_tier("enterprise") + free_plan = manager.get_plan_by_tier("free") + pro_plan = manager.get_plan_by_tier("pro") + enterprise_plan = manager.get_plan_by_tier("enterprise") assert free_plan is not None, "Free 计划应该存在" assert pro_plan is not None, "Pro 计划应该存在" @@ -49,14 +49,14 @@ def test_subscription_manager() -> None: print("\n2. 测试订阅管理") print("-" * 40) - tenant_id = "test-tenant-001" + tenant_id = "test-tenant-001" # 创建订阅 - subscription = manager.create_subscription( - tenant_id = tenant_id, - plan_id = pro_plan.id, - payment_provider = PaymentProvider.STRIPE.value, - trial_days = 14, + subscription = manager.create_subscription( + tenant_id=tenant_id, + plan_id=pro_plan.id, + payment_provider=PaymentProvider.STRIPE.value, + trial_days=14, ) print(f"✓ 创建订阅: {subscription.id}") @@ -66,7 +66,7 @@ def test_subscription_manager() -> None: print(f" - 试用结束: {subscription.trial_end}") # 获取租户订阅 - tenant_sub = manager.get_tenant_subscription(tenant_id) + tenant_sub = manager.get_tenant_subscription(tenant_id) assert tenant_sub is not None, "应该能获取到租户订阅" print(f"✓ 获取租户订阅: {tenant_sub.id}") @@ -74,27 +74,27 @@ def test_subscription_manager() -> None: print("-" * 40) # 记录转录用量 - usage1 = manager.record_usage( - tenant_id = tenant_id, - resource_type = "transcription", - quantity = 120, - unit = "minute", - description = "会议转录", + usage1 = manager.record_usage( + tenant_id=tenant_id, + resource_type="transcription", + quantity=120, + unit="minute", + description="会议转录", ) print(f"✓ 记录转录用量: {usage1.quantity} {usage1.unit}, 费用: ¥{usage1.cost:.2f}") # 记录存储用量 - usage2 = manager.record_usage( - tenant_id = tenant_id, - resource_type = "storage", - quantity = 2.5, - unit = "gb", - description = "文件存储", + usage2 = manager.record_usage( + tenant_id=tenant_id, + resource_type="storage", + quantity=2.5, + unit="gb", + description="文件存储", ) print(f"✓ 记录存储用量: {usage2.quantity} {usage2.unit}, 费用: ¥{usage2.cost:.2f}") # 获取用量汇总 - summary = manager.get_usage_summary(tenant_id) + summary = manager.get_usage_summary(tenant_id) print("✓ 用量汇总:") print(f" - 总费用: ¥{summary['total_cost']:.2f}") for resource, data in summary["breakdown"].items(): @@ -104,12 +104,12 @@ def test_subscription_manager() -> None: print("-" * 40) # 创建支付 - payment = manager.create_payment( - tenant_id = tenant_id, - amount = 99.0, - currency = "CNY", - provider = PaymentProvider.ALIPAY.value, - payment_method = "qrcode", + payment = manager.create_payment( + tenant_id=tenant_id, + amount=99.0, + currency="CNY", + provider=PaymentProvider.ALIPAY.value, + payment_method="qrcode", ) print(f"✓ 创建支付: {payment.id}") print(f" - 金额: ¥{payment.amount}") @@ -117,22 +117,22 @@ def test_subscription_manager() -> None: print(f" - 状态: {payment.status}") # 确认支付 - confirmed = manager.confirm_payment(payment.id, "alipay_123456") + confirmed = manager.confirm_payment(payment.id, "alipay_123456") print(f"✓ 确认支付完成: {confirmed.status}") # 列出支付记录 - payments = manager.list_payments(tenant_id) + payments = manager.list_payments(tenant_id) print(f"✓ 支付记录数量: {len(payments)}") print("\n5. 测试发票管理") print("-" * 40) # 列出发票 - invoices = manager.list_invoices(tenant_id) + invoices = manager.list_invoices(tenant_id) print(f"✓ 发票数量: {len(invoices)}") if invoices: - invoice = invoices[0] + invoice = invoices[0] print(f" - 发票号: {invoice.invoice_number}") print(f" - 金额: ¥{invoice.amount_due}") print(f" - 状态: {invoice.status}") @@ -141,12 +141,12 @@ def test_subscription_manager() -> None: print("-" * 40) # 申请退款 - refund = manager.request_refund( - tenant_id = tenant_id, - payment_id = payment.id, - amount = 50.0, - reason = "服务不满意", - requested_by = "user_001", + refund = manager.request_refund( + tenant_id=tenant_id, + payment_id=payment.id, + amount=50.0, + reason="服务不满意", + requested_by="user_001", ) print(f"✓ 申请退款: {refund.id}") print(f" - 金额: ¥{refund.amount}") @@ -154,21 +154,21 @@ def test_subscription_manager() -> None: print(f" - 状态: {refund.status}") # 批准退款 - approved = manager.approve_refund(refund.id, "admin_001") + approved = manager.approve_refund(refund.id, "admin_001") print(f"✓ 批准退款: {approved.status}") # 完成退款 - completed = manager.complete_refund(refund.id, "refund_123456") + completed = manager.complete_refund(refund.id, "refund_123456") print(f"✓ 完成退款: {completed.status}") # 列出退款记录 - refunds = manager.list_refunds(tenant_id) + refunds = manager.list_refunds(tenant_id) print(f"✓ 退款记录数量: {len(refunds)}") print("\n7. 测试账单历史") print("-" * 40) - history = manager.get_billing_history(tenant_id) + history = manager.get_billing_history(tenant_id) print(f"✓ 账单历史记录数量: {len(history)}") for h in history: print(f" - [{h.type}] {h.description}: ¥{h.amount}") @@ -177,24 +177,24 @@ def test_subscription_manager() -> None: print("-" * 40) # Stripe Checkout - stripe_session = manager.create_stripe_checkout_session( - tenant_id = tenant_id, - plan_id = enterprise_plan.id, - success_url = "https://example.com/success", - cancel_url = "https://example.com/cancel", + stripe_session = manager.create_stripe_checkout_session( + tenant_id=tenant_id, + plan_id=enterprise_plan.id, + success_url="https://example.com/success", + cancel_url="https://example.com/cancel", ) print(f"✓ Stripe Checkout 会话: {stripe_session['session_id']}") # 支付宝订单 - alipay_order = manager.create_alipay_order(tenant_id = tenant_id, plan_id = pro_plan.id) + alipay_order = manager.create_alipay_order(tenant_id=tenant_id, plan_id=pro_plan.id) print(f"✓ 支付宝订单: {alipay_order['order_id']}") # 微信支付订单 - wechat_order = manager.create_wechat_order(tenant_id = tenant_id, plan_id = pro_plan.id) + wechat_order = manager.create_wechat_order(tenant_id=tenant_id, plan_id=pro_plan.id) print(f"✓ 微信支付订单: {wechat_order['order_id']}") # Webhook 处理 - webhook_result = manager.handle_webhook( + webhook_result = manager.handle_webhook( "stripe", {"event_type": "checkout.session.completed", "data": {"object": {"id": "cs_test"}}}, ) @@ -204,13 +204,13 @@ def test_subscription_manager() -> None: print("-" * 40) # 更改计划 - changed = manager.change_plan( - subscription_id = subscription.id, new_plan_id = enterprise_plan.id + changed = manager.change_plan( + subscription_id=subscription.id, new_plan_id=enterprise_plan.id ) print(f"✓ 更改计划: {changed.plan_id} (Enterprise)") # 取消订阅 - cancelled = manager.cancel_subscription(subscription_id = subscription.id, at_period_end = True) + cancelled = manager.cancel_subscription(subscription_id=subscription.id, at_period_end=True) print(f"✓ 取消订阅: {cancelled.status}") print(f" - 周期结束时取消: {cancelled.cancel_at_period_end}") diff --git a/backend/test_phase8_task4.py b/backend/test_phase8_task4.py index 702363d..df4e187 100644 --- a/backend/test_phase8_task4.py +++ b/backend/test_phase8_task4.py @@ -18,27 +18,27 @@ def test_custom_model() -> None: """测试自定义模型功能""" print("\n=== 测试自定义模型 ===") - manager = get_ai_manager() + manager = get_ai_manager() # 1. 创建自定义模型 print("1. 创建自定义模型...") - model = manager.create_custom_model( - tenant_id = "tenant_001", - name = "领域实体识别模型", - description = "用于识别医疗领域实体的自定义模型", - model_type = ModelType.CUSTOM_NER, - training_data = { + model = manager.create_custom_model( + tenant_id="tenant_001", + name="领域实体识别模型", + description="用于识别医疗领域实体的自定义模型", + model_type=ModelType.CUSTOM_NER, + training_data={ "entity_types": ["DISEASE", "SYMPTOM", "DRUG", "TREATMENT"], "domain": "medical", }, - hyperparameters = {"epochs": 15, "learning_rate": 0.001, "batch_size": 32}, - created_by = "user_001", + hyperparameters={"epochs": 15, "learning_rate": 0.001, "batch_size": 32}, + created_by="user_001", ) print(f" 创建成功: {model.id}, 状态: {model.status.value}") # 2. 添加训练样本 print("2. 添加训练样本...") - samples = [ + samples = [ { "text": "患者张三患有高血压,正在服用降压药治疗。", "entities": [ @@ -66,22 +66,22 @@ def test_custom_model() -> None: ] for sample_data in samples: - sample = manager.add_training_sample( - model_id = model.id, - text = sample_data["text"], - entities = sample_data["entities"], - metadata = {"source": "manual"}, + sample = manager.add_training_sample( + model_id=model.id, + text=sample_data["text"], + entities=sample_data["entities"], + metadata={"source": "manual"}, ) print(f" 添加样本: {sample.id}") # 3. 获取训练样本 print("3. 获取训练样本...") - all_samples = manager.get_training_samples(model.id) + all_samples = manager.get_training_samples(model.id) print(f" 共有 {len(all_samples)} 个训练样本") # 4. 列出自定义模型 print("4. 列出自定义模型...") - models = manager.list_custom_models(tenant_id = "tenant_001") + models = manager.list_custom_models(tenant_id="tenant_001") print(f" 找到 {len(models)} 个模型") for m in models: print(f" - {m.name} ({m.model_type.value}): {m.status.value}") @@ -93,12 +93,12 @@ async def test_train_and_predict(model_id: str) -> None: """测试训练和预测""" print("\n=== 测试模型训练和预测 ===") - manager = get_ai_manager() + manager = get_ai_manager() # 1. 训练模型 print("1. 训练模型...") try: - trained_model = await manager.train_custom_model(model_id) + trained_model = await manager.train_custom_model(model_id) print(f" 训练完成: {trained_model.status.value}") print(f" 指标: {trained_model.metrics}") except Exception as e: @@ -107,9 +107,9 @@ async def test_train_and_predict(model_id: str) -> None: # 2. 使用模型预测 print("2. 使用模型预测...") - test_text = "赵六患有糖尿病,正在使用胰岛素治疗。" + test_text = "赵六患有糖尿病,正在使用胰岛素治疗。" try: - entities = await manager.predict_with_custom_model(model_id, test_text) + entities = await manager.predict_with_custom_model(model_id, test_text) print(f" 输入: {test_text}") print(f" 预测实体: {entities}") except Exception as e: @@ -120,37 +120,37 @@ def test_prediction_models() -> None: """测试预测模型""" print("\n=== 测试预测模型 ===") - manager = get_ai_manager() + manager = get_ai_manager() # 1. 创建趋势预测模型 print("1. 创建趋势预测模型...") - trend_model = manager.create_prediction_model( - tenant_id = "tenant_001", - project_id = "project_001", - name = "实体数量趋势预测", - prediction_type = PredictionType.TREND, - target_entity_type = "PERSON", - features = ["entity_count", "time_period", "document_count"], - model_config = {"algorithm": "linear_regression", "window_size": 7}, + trend_model = manager.create_prediction_model( + tenant_id="tenant_001", + project_id="project_001", + name="实体数量趋势预测", + prediction_type=PredictionType.TREND, + target_entity_type="PERSON", + features=["entity_count", "time_period", "document_count"], + model_config={"algorithm": "linear_regression", "window_size": 7}, ) print(f" 创建成功: {trend_model.id}") # 2. 创建异常检测模型 print("2. 创建异常检测模型...") - anomaly_model = manager.create_prediction_model( - tenant_id = "tenant_001", - project_id = "project_001", - name = "实体增长异常检测", - prediction_type = PredictionType.ANOMALY, - target_entity_type = None, - features = ["daily_growth", "weekly_growth"], - model_config = {"threshold": 2.5, "sensitivity": "medium"}, + anomaly_model = manager.create_prediction_model( + tenant_id="tenant_001", + project_id="project_001", + name="实体增长异常检测", + prediction_type=PredictionType.ANOMALY, + target_entity_type=None, + features=["daily_growth", "weekly_growth"], + model_config={"threshold": 2.5, "sensitivity": "medium"}, ) print(f" 创建成功: {anomaly_model.id}") # 3. 列出预测模型 print("3. 列出预测模型...") - models = manager.list_prediction_models(tenant_id = "tenant_001") + models = manager.list_prediction_models(tenant_id="tenant_001") print(f" 找到 {len(models)} 个预测模型") for m in models: print(f" - {m.name} ({m.prediction_type.value})") @@ -162,11 +162,11 @@ async def test_predictions(trend_model_id: str, anomaly_model_id: str) -> None: """测试预测功能""" print("\n=== 测试预测功能 ===") - manager = get_ai_manager() + manager = get_ai_manager() # 1. 训练趋势预测模型 print("1. 训练趋势预测模型...") - historical_data = [ + historical_data = [ {"date": "2024-01-01", "value": 10}, {"date": "2024-01-02", "value": 12}, {"date": "2024-01-03", "value": 15}, @@ -175,19 +175,19 @@ async def test_predictions(trend_model_id: str, anomaly_model_id: str) -> None: {"date": "2024-01-06", "value": 20}, {"date": "2024-01-07", "value": 22}, ] - trained = await manager.train_prediction_model(trend_model_id, historical_data) + trained = await manager.train_prediction_model(trend_model_id, historical_data) print(f" 训练完成,准确率: {trained.accuracy}") # 2. 趋势预测 print("2. 趋势预测...") - trend_result = await manager.predict( + trend_result = await manager.predict( trend_model_id, {"historical_values": [10, 12, 15, 14, 18, 20, 22]} ) print(f" 预测结果: {trend_result.prediction_data}") # 3. 异常检测 print("3. 异常检测...") - anomaly_result = await manager.predict( + anomaly_result = await manager.predict( anomaly_model_id, {"value": 50, "historical_values": [10, 12, 11, 13, 12, 14, 13]} ) print(f" 检测结果: {anomaly_result.prediction_data}") @@ -197,27 +197,27 @@ def test_kg_rag() -> None: """测试知识图谱 RAG""" print("\n=== 测试知识图谱 RAG ===") - manager = get_ai_manager() + manager = get_ai_manager() # 创建 RAG 配置 print("1. 创建知识图谱 RAG 配置...") - rag = manager.create_kg_rag( - tenant_id = "tenant_001", - project_id = "project_001", - name = "项目知识问答", - description = "基于项目知识图谱的智能问答", - kg_config = { + rag = manager.create_kg_rag( + tenant_id="tenant_001", + project_id="project_001", + name="项目知识问答", + description="基于项目知识图谱的智能问答", + kg_config={ "entity_types": ["PERSON", "ORG", "PROJECT", "TECH"], "relation_types": ["works_with", "belongs_to", "depends_on"], }, - retrieval_config = {"top_k": 5, "similarity_threshold": 0.7, "expand_relations": True}, - generation_config = {"temperature": 0.3, "max_tokens": 1000, "include_sources": True}, + retrieval_config={"top_k": 5, "similarity_threshold": 0.7, "expand_relations": True}, + generation_config={"temperature": 0.3, "max_tokens": 1000, "include_sources": True}, ) print(f" 创建成功: {rag.id}") # 列出 RAG 配置 print("2. 列出 RAG 配置...") - rags = manager.list_kg_rags(tenant_id = "tenant_001") + rags = manager.list_kg_rags(tenant_id="tenant_001") print(f" 找到 {len(rags)} 个配置") return rag.id @@ -227,10 +227,10 @@ async def test_kg_rag_query(rag_id: str) -> None: """测试 RAG 查询""" print("\n=== 测试知识图谱 RAG 查询 ===") - manager = get_ai_manager() + manager = get_ai_manager() # 模拟项目实体和关系 - project_entities = [ + project_entities = [ {"id": "e1", "name": "张三", "type": "PERSON", "definition": "项目经理"}, {"id": "e2", "name": "李四", "type": "PERSON", "definition": "技术负责人"}, {"id": "e3", "name": "Project Alpha", "type": "PROJECT", "definition": "核心产品项目"}, @@ -238,7 +238,7 @@ async def test_kg_rag_query(rag_id: str) -> None: {"id": "e5", "name": "TechCorp", "type": "ORG", "definition": "科技公司"}, ] - project_relations = [ + project_relations = [ { "source_entity_id": "e1", "target_entity_id": "e3", @@ -275,14 +275,14 @@ async def test_kg_rag_query(rag_id: str) -> None: # 执行查询 print("1. 执行 RAG 查询...") - query_text = "Project Alpha 项目有哪些人参与?使用了什么技术?" + query_text = "Project Alpha 项目有哪些人参与?使用了什么技术?" try: - result = await manager.query_kg_rag( - rag_id = rag_id, - query = query_text, - project_entities = project_entities, - project_relations = project_relations, + result = await manager.query_kg_rag( + rag_id=rag_id, + query=query_text, + project_entities=project_entities, + project_relations=project_relations, ) print(f" 查询: {result.query}") @@ -298,10 +298,10 @@ async def test_smart_summary() -> None: """测试智能摘要""" print("\n=== 测试智能摘要 ===") - manager = get_ai_manager() + manager = get_ai_manager() # 模拟转录文本 - transcript_text = """ + transcript_text = """ 今天的会议主要讨论了 Project Alpha 的进展情况。张三作为项目经理, 汇报了当前的项目进度,表示已经完成了 80% 的开发工作。李四提出了 一些关于 Kubernetes 部署的问题,建议我们采用新的部署策略。 @@ -309,7 +309,7 @@ async def test_smart_summary() -> None: 大家一致认为项目进展顺利,预计可以按时交付。 """ - content_data = { + content_data = { "text": transcript_text, "entities": [ {"name": "张三", "type": "PERSON"}, @@ -320,18 +320,18 @@ async def test_smart_summary() -> None: } # 生成不同类型的摘要 - summary_types = ["extractive", "abstractive", "key_points"] + summary_types = ["extractive", "abstractive", "key_points"] for summary_type in summary_types: print(f"1. 生成 {summary_type} 类型摘要...") try: - summary = await manager.generate_smart_summary( - tenant_id = "tenant_001", - project_id = "project_001", - source_type = "transcript", - source_id = "transcript_001", - summary_type = summary_type, - content_data = content_data, + summary = await manager.generate_smart_summary( + tenant_id="tenant_001", + project_id="project_001", + source_type="transcript", + source_id="transcript_001", + summary_type=summary_type, + content_data=content_data, ) print(f" 摘要类型: {summary.summary_type}") @@ -350,19 +350,19 @@ async def main() -> None: try: # 测试自定义模型 - model_id = test_custom_model() + model_id = test_custom_model() # 测试训练和预测 await test_train_and_predict(model_id) # 测试预测模型 - trend_model_id, anomaly_model_id = test_prediction_models() + trend_model_id, anomaly_model_id = test_prediction_models() # 测试预测功能 await test_predictions(trend_model_id, anomaly_model_id) # 测试知识图谱 RAG - rag_id = test_kg_rag() + rag_id = test_kg_rag() # 测试 RAG 查询 await test_kg_rag_query(rag_id) diff --git a/backend/test_phase8_task5.py b/backend/test_phase8_task5.py index ee10a8f..56cf44c 100644 --- a/backend/test_phase8_task5.py +++ b/backend/test_phase8_task5.py @@ -28,7 +28,7 @@ from growth_manager import ( ) # 添加 backend 目录到路径 -backend_dir = os.path.dirname(os.path.abspath(__file__)) +backend_dir = os.path.dirname(os.path.abspath(__file__)) if backend_dir not in sys.path: sys.path.insert(0, backend_dir) @@ -37,14 +37,14 @@ class TestGrowthManager: """测试 Growth Manager 功能""" def __init__(self) -> None: - self.manager = GrowthManager() - self.test_tenant_id = "test_tenant_001" - self.test_user_id = "test_user_001" - self.test_results = [] + self.manager = GrowthManager() + self.test_tenant_id = "test_tenant_001" + self.test_user_id = "test_user_001" + self.test_results = [] - def log(self, message: str, success: bool = True) -> None: + def log(self, message: str, success: bool = True) -> None: """记录测试结果""" - status = "✅" if success else "❌" + status = "✅" if success else "❌" print(f"{status} {message}") self.test_results.append((message, success)) @@ -55,16 +55,16 @@ class TestGrowthManager: print("\n📊 测试事件追踪...") try: - event = await self.manager.track_event( - tenant_id = self.test_tenant_id, - user_id = self.test_user_id, - event_type = EventType.PAGE_VIEW, - event_name = "dashboard_view", - properties = {"page": "/dashboard", "duration": 120}, - session_id = "session_001", - device_info = {"browser": "Chrome", "os": "MacOS"}, - referrer = "https://google.com", - utm_params = {"source": "google", "medium": "organic", "campaign": "summer"}, + event = await self.manager.track_event( + tenant_id=self.test_tenant_id, + user_id=self.test_user_id, + event_type=EventType.PAGE_VIEW, + event_name="dashboard_view", + properties={"page": "/dashboard", "duration": 120}, + session_id="session_001", + device_info={"browser": "Chrome", "os": "MacOS"}, + referrer="https://google.com", + utm_params={"source": "google", "medium": "organic", "campaign": "summer"}, ) assert event.id is not None @@ -74,7 +74,7 @@ class TestGrowthManager: self.log(f"事件追踪成功: {event.id}") return True except Exception as e: - self.log(f"事件追踪失败: {e}", success = False) + self.log(f"事件追踪失败: {e}", success=False) return False async def test_track_multiple_events(self) -> None: @@ -82,7 +82,7 @@ class TestGrowthManager: print("\n📊 测试追踪多个事件...") try: - events = [ + events = [ (EventType.FEATURE_USE, "entity_extraction", {"entity_count": 5}), (EventType.FEATURE_USE, "relation_discovery", {"relation_count": 3}), (EventType.CONVERSION, "upgrade_click", {"plan": "pro"}), @@ -91,17 +91,17 @@ class TestGrowthManager: for event_type, event_name, props in events: await self.manager.track_event( - tenant_id = self.test_tenant_id, - user_id = self.test_user_id, - event_type = event_type, - event_name = event_name, - properties = props, + tenant_id=self.test_tenant_id, + user_id=self.test_user_id, + event_type=event_type, + event_name=event_name, + properties=props, ) self.log(f"成功追踪 {len(events)} 个事件") return True except Exception as e: - self.log(f"批量事件追踪失败: {e}", success = False) + self.log(f"批量事件追踪失败: {e}", success=False) return False def test_get_user_profile(self) -> None: @@ -109,7 +109,7 @@ class TestGrowthManager: print("\n👤 测试用户画像...") try: - profile = self.manager.get_user_profile(self.test_tenant_id, self.test_user_id) + profile = self.manager.get_user_profile(self.test_tenant_id, self.test_user_id) if profile: assert profile.user_id == self.test_user_id @@ -120,7 +120,7 @@ class TestGrowthManager: return True except Exception as e: - self.log(f"获取用户画像失败: {e}", success = False) + self.log(f"获取用户画像失败: {e}", success=False) return False def test_get_analytics_summary(self) -> None: @@ -128,10 +128,10 @@ class TestGrowthManager: print("\n📈 测试分析汇总...") try: - summary = self.manager.get_user_analytics_summary( - tenant_id = self.test_tenant_id, - start_date = datetime.now() - timedelta(days = 7), - end_date = datetime.now(), + summary = self.manager.get_user_analytics_summary( + tenant_id=self.test_tenant_id, + start_date=datetime.now() - timedelta(days=7), + end_date=datetime.now(), ) assert "unique_users" in summary @@ -141,7 +141,7 @@ class TestGrowthManager: self.log(f"分析汇总: {summary['unique_users']} 用户, {summary['total_events']} 事件") return True except Exception as e: - self.log(f"获取分析汇总失败: {e}", success = False) + self.log(f"获取分析汇总失败: {e}", success=False) return False def test_create_funnel(self) -> None: @@ -149,17 +149,17 @@ class TestGrowthManager: print("\n🎯 测试创建转化漏斗...") try: - funnel = self.manager.create_funnel( - tenant_id = self.test_tenant_id, - name = "用户注册转化漏斗", - description = "从访问到完成注册的转化流程", - steps = [ + funnel = self.manager.create_funnel( + tenant_id=self.test_tenant_id, + name="用户注册转化漏斗", + description="从访问到完成注册的转化流程", + steps=[ {"name": "访问首页", "event_name": "page_view_home"}, {"name": "点击注册", "event_name": "signup_click"}, {"name": "填写信息", "event_name": "signup_form_fill"}, {"name": "完成注册", "event_name": "signup_complete"}, ], - created_by = "test", + created_by="test", ) assert funnel.id is not None @@ -168,7 +168,7 @@ class TestGrowthManager: self.log(f"漏斗创建成功: {funnel.id}") return funnel.id except Exception as e: - self.log(f"创建漏斗失败: {e}", success = False) + self.log(f"创建漏斗失败: {e}", success=False) return None def test_analyze_funnel(self, funnel_id: str) -> None: @@ -180,10 +180,10 @@ class TestGrowthManager: return False try: - analysis = self.manager.analyze_funnel( - funnel_id = funnel_id, - period_start = datetime.now() - timedelta(days = 30), - period_end = datetime.now(), + analysis = self.manager.analyze_funnel( + funnel_id=funnel_id, + period_start=datetime.now() - timedelta(days=30), + period_end=datetime.now(), ) if analysis: @@ -194,7 +194,7 @@ class TestGrowthManager: self.log("漏斗分析返回空结果") return False except Exception as e: - self.log(f"漏斗分析失败: {e}", success = False) + self.log(f"漏斗分析失败: {e}", success=False) return False def test_calculate_retention(self) -> None: @@ -202,10 +202,10 @@ class TestGrowthManager: print("\n🔄 测试留存率计算...") try: - retention = self.manager.calculate_retention( - tenant_id = self.test_tenant_id, - cohort_date = datetime.now() - timedelta(days = 7), - periods = [1, 3, 7], + retention = self.manager.calculate_retention( + tenant_id=self.test_tenant_id, + cohort_date=datetime.now() - timedelta(days=7), + periods=[1, 3, 7], ) assert "cohort_date" in retention @@ -214,7 +214,7 @@ class TestGrowthManager: self.log(f"留存率计算完成: 同期群 {retention['cohort_size']} 用户") return True except Exception as e: - self.log(f"留存率计算失败: {e}", success = False) + self.log(f"留存率计算失败: {e}", success=False) return False # ==================== 测试 A/B 测试框架 ==================== @@ -224,24 +224,24 @@ class TestGrowthManager: print("\n🧪 测试创建 A/B 测试实验...") try: - experiment = self.manager.create_experiment( - tenant_id = self.test_tenant_id, - name = "首页按钮颜色测试", - description = "测试不同按钮颜色对转化率的影响", - hypothesis = "蓝色按钮比红色按钮有更高的点击率", - variants = [ + experiment = self.manager.create_experiment( + tenant_id=self.test_tenant_id, + name="首页按钮颜色测试", + description="测试不同按钮颜色对转化率的影响", + hypothesis="蓝色按钮比红色按钮有更高的点击率", + variants=[ {"id": "control", "name": "红色按钮", "is_control": True}, {"id": "variant_a", "name": "蓝色按钮", "is_control": False}, {"id": "variant_b", "name": "绿色按钮", "is_control": False}, ], - traffic_allocation = TrafficAllocationType.RANDOM, - traffic_split = {"control": 0.34, "variant_a": 0.33, "variant_b": 0.33}, - target_audience = {"conditions": []}, - primary_metric = "button_click_rate", - secondary_metrics = ["conversion_rate", "bounce_rate"], - min_sample_size = 100, - confidence_level = 0.95, - created_by = "test", + traffic_allocation=TrafficAllocationType.RANDOM, + traffic_split={"control": 0.34, "variant_a": 0.33, "variant_b": 0.33}, + target_audience={"conditions": []}, + primary_metric="button_click_rate", + secondary_metrics=["conversion_rate", "bounce_rate"], + min_sample_size=100, + confidence_level=0.95, + created_by="test", ) assert experiment.id is not None @@ -250,7 +250,7 @@ class TestGrowthManager: self.log(f"实验创建成功: {experiment.id}") return experiment.id except Exception as e: - self.log(f"创建实验失败: {e}", success = False) + self.log(f"创建实验失败: {e}", success=False) return None def test_list_experiments(self) -> None: @@ -258,12 +258,12 @@ class TestGrowthManager: print("\n📋 测试列出实验...") try: - experiments = self.manager.list_experiments(self.test_tenant_id) + experiments = self.manager.list_experiments(self.test_tenant_id) self.log(f"列出 {len(experiments)} 个实验") return True except Exception as e: - self.log(f"列出实验失败: {e}", success = False) + self.log(f"列出实验失败: {e}", success=False) return False def test_assign_variant(self, experiment_id: str) -> None: @@ -279,23 +279,23 @@ class TestGrowthManager: self.manager.start_experiment(experiment_id) # 测试多个用户的变体分配 - test_users = ["user_001", "user_002", "user_003", "user_004", "user_005"] - assignments = {} + test_users = ["user_001", "user_002", "user_003", "user_004", "user_005"] + assignments = {} for user_id in test_users: - variant_id = self.manager.assign_variant( - experiment_id = experiment_id, - user_id = user_id, - user_attributes = {"user_id": user_id, "segment": "new"}, + variant_id = self.manager.assign_variant( + experiment_id=experiment_id, + user_id=user_id, + user_attributes={"user_id": user_id, "segment": "new"}, ) if variant_id: - assignments[user_id] = variant_id + assignments[user_id] = variant_id self.log(f"变体分配完成: {len(assignments)} 个用户") return True except Exception as e: - self.log(f"变体分配失败: {e}", success = False) + self.log(f"变体分配失败: {e}", success=False) return False def test_record_experiment_metric(self, experiment_id: str) -> None: @@ -308,7 +308,7 @@ class TestGrowthManager: try: # 模拟记录一些指标 - test_data = [ + test_data = [ ("user_001", "control", 1), ("user_002", "variant_a", 1), ("user_003", "variant_b", 0), @@ -318,17 +318,17 @@ class TestGrowthManager: for user_id, variant_id, value in test_data: self.manager.record_experiment_metric( - experiment_id = experiment_id, - variant_id = variant_id, - user_id = user_id, - metric_name = "button_click_rate", - metric_value = value, + experiment_id=experiment_id, + variant_id=variant_id, + user_id=user_id, + metric_name="button_click_rate", + metric_value=value, ) self.log(f"成功记录 {len(test_data)} 条指标") return True except Exception as e: - self.log(f"记录指标失败: {e}", success = False) + self.log(f"记录指标失败: {e}", success=False) return False def test_analyze_experiment(self, experiment_id: str) -> None: @@ -340,16 +340,16 @@ class TestGrowthManager: return False try: - result = self.manager.analyze_experiment(experiment_id) + result = self.manager.analyze_experiment(experiment_id) if "error" not in result: self.log(f"实验分析完成: {len(result.get('variant_results', {}))} 个变体") return True else: - self.log(f"实验分析返回错误: {result['error']}", success = False) + self.log(f"实验分析返回错误: {result['error']}", success=False) return False except Exception as e: - self.log(f"实验分析失败: {e}", success = False) + self.log(f"实验分析失败: {e}", success=False) return False # ==================== 测试邮件营销 ==================== @@ -359,12 +359,12 @@ class TestGrowthManager: print("\n📧 测试创建邮件模板...") try: - template = self.manager.create_email_template( - tenant_id = self.test_tenant_id, - name = "欢迎邮件", - template_type = EmailTemplateType.WELCOME, - subject = "欢迎加入 InsightFlow!", - html_content = """ + template = self.manager.create_email_template( + tenant_id=self.test_tenant_id, + name="欢迎邮件", + template_type=EmailTemplateType.WELCOME, + subject="欢迎加入 InsightFlow!", + html_content="""

欢迎,{{user_name}}!

感谢您注册 InsightFlow。我们很高兴您能加入我们!

您的账户已创建,可以开始使用以下功能:

@@ -375,8 +375,8 @@ class TestGrowthManager:

立即开始使用

""", - from_name = "InsightFlow 团队", - from_email = "welcome@insightflow.io", + from_name="InsightFlow 团队", + from_email="welcome@insightflow.io", ) assert template.id is not None @@ -385,7 +385,7 @@ class TestGrowthManager: self.log(f"邮件模板创建成功: {template.id}") return template.id except Exception as e: - self.log(f"创建邮件模板失败: {e}", success = False) + self.log(f"创建邮件模板失败: {e}", success=False) return None def test_list_email_templates(self) -> None: @@ -393,12 +393,12 @@ class TestGrowthManager: print("\n📧 测试列出邮件模板...") try: - templates = self.manager.list_email_templates(self.test_tenant_id) + templates = self.manager.list_email_templates(self.test_tenant_id) self.log(f"列出 {len(templates)} 个邮件模板") return True except Exception as e: - self.log(f"列出邮件模板失败: {e}", success = False) + self.log(f"列出邮件模板失败: {e}", success=False) return False def test_render_template(self, template_id: str) -> None: @@ -410,9 +410,9 @@ class TestGrowthManager: return False try: - rendered = self.manager.render_template( - template_id = template_id, - variables = { + rendered = self.manager.render_template( + template_id=template_id, + variables={ "user_name": "张三", "dashboard_url": "https://app.insightflow.io/dashboard", }, @@ -424,10 +424,10 @@ class TestGrowthManager: self.log(f"模板渲染成功: {rendered['subject']}") return True else: - self.log("模板渲染返回空结果", success = False) + self.log("模板渲染返回空结果", success=False) return False except Exception as e: - self.log(f"模板渲染失败: {e}", success = False) + self.log(f"模板渲染失败: {e}", success=False) return False def test_create_email_campaign(self, template_id: str) -> None: @@ -439,11 +439,11 @@ class TestGrowthManager: return None try: - campaign = self.manager.create_email_campaign( - tenant_id = self.test_tenant_id, - name = "新用户欢迎活动", - template_id = template_id, - recipient_list = [ + campaign = self.manager.create_email_campaign( + tenant_id=self.test_tenant_id, + name="新用户欢迎活动", + template_id=template_id, + recipient_list=[ {"user_id": "user_001", "email": "user1@example.com"}, {"user_id": "user_002", "email": "user2@example.com"}, {"user_id": "user_003", "email": "user3@example.com"}, @@ -456,7 +456,7 @@ class TestGrowthManager: self.log(f"营销活动创建成功: {campaign.id}, {campaign.recipient_count} 收件人") return campaign.id except Exception as e: - self.log(f"创建营销活动失败: {e}", success = False) + self.log(f"创建营销活动失败: {e}", success=False) return None def test_create_automation_workflow(self) -> None: @@ -464,13 +464,13 @@ class TestGrowthManager: print("\n🤖 测试创建自动化工作流...") try: - workflow = self.manager.create_automation_workflow( - tenant_id = self.test_tenant_id, - name = "新用户欢迎序列", - description = "用户注册后自动发送欢迎邮件序列", - trigger_type = WorkflowTriggerType.USER_SIGNUP, - trigger_conditions = {"event": "user_signup"}, - actions = [ + workflow = self.manager.create_automation_workflow( + tenant_id=self.test_tenant_id, + name="新用户欢迎序列", + description="用户注册后自动发送欢迎邮件序列", + trigger_type=WorkflowTriggerType.USER_SIGNUP, + trigger_conditions={"event": "user_signup"}, + actions=[ {"type": "send_email", "template_type": "welcome", "delay_hours": 0}, {"type": "send_email", "template_type": "onboarding", "delay_hours": 24}, {"type": "send_email", "template_type": "feature_tips", "delay_hours": 72}, @@ -483,7 +483,7 @@ class TestGrowthManager: self.log(f"自动化工作流创建成功: {workflow.id}") return True except Exception as e: - self.log(f"创建工作流失败: {e}", success = False) + self.log(f"创建工作流失败: {e}", success=False) return False # ==================== 测试推荐系统 ==================== @@ -493,17 +493,17 @@ class TestGrowthManager: print("\n🎁 测试创建推荐计划...") try: - program = self.manager.create_referral_program( - tenant_id = self.test_tenant_id, - name = "邀请好友奖励计划", - description = "邀请好友注册,双方获得积分奖励", - referrer_reward_type = "credit", - referrer_reward_value = 100.0, - referee_reward_type = "credit", - referee_reward_value = 50.0, - max_referrals_per_user = 10, - referral_code_length = 8, - expiry_days = 30, + program = self.manager.create_referral_program( + tenant_id=self.test_tenant_id, + name="邀请好友奖励计划", + description="邀请好友注册,双方获得积分奖励", + referrer_reward_type="credit", + referrer_reward_value=100.0, + referee_reward_type="credit", + referee_reward_value=50.0, + max_referrals_per_user=10, + referral_code_length=8, + expiry_days=30, ) assert program.id is not None @@ -512,7 +512,7 @@ class TestGrowthManager: self.log(f"推荐计划创建成功: {program.id}") return program.id except Exception as e: - self.log(f"创建推荐计划失败: {e}", success = False) + self.log(f"创建推荐计划失败: {e}", success=False) return None def test_generate_referral_code(self, program_id: str) -> None: @@ -524,8 +524,8 @@ class TestGrowthManager: return None try: - referral = self.manager.generate_referral_code( - program_id = program_id, referrer_id = "referrer_user_001" + referral = self.manager.generate_referral_code( + program_id=program_id, referrer_id="referrer_user_001" ) if referral: @@ -535,10 +535,10 @@ class TestGrowthManager: self.log(f"推荐码生成成功: {referral.referral_code}") return referral.referral_code else: - self.log("生成推荐码返回空结果", success = False) + self.log("生成推荐码返回空结果", success=False) return None except Exception as e: - self.log(f"生成推荐码失败: {e}", success = False) + self.log(f"生成推荐码失败: {e}", success=False) return None def test_apply_referral_code(self, referral_code: str) -> None: @@ -550,18 +550,18 @@ class TestGrowthManager: return False try: - success = self.manager.apply_referral_code( - referral_code = referral_code, referee_id = "new_user_001" + success = self.manager.apply_referral_code( + referral_code=referral_code, referee_id="new_user_001" ) if success: self.log(f"推荐码应用成功: {referral_code}") return True else: - self.log("推荐码应用失败", success = False) + self.log("推荐码应用失败", success=False) return False except Exception as e: - self.log(f"应用推荐码失败: {e}", success = False) + self.log(f"应用推荐码失败: {e}", success=False) return False def test_get_referral_stats(self, program_id: str) -> None: @@ -573,7 +573,7 @@ class TestGrowthManager: return False try: - stats = self.manager.get_referral_stats(program_id) + stats = self.manager.get_referral_stats(program_id) assert "total_referrals" in stats assert "conversion_rate" in stats @@ -583,7 +583,7 @@ class TestGrowthManager: ) return True except Exception as e: - self.log(f"获取推荐统计失败: {e}", success = False) + self.log(f"获取推荐统计失败: {e}", success=False) return False def test_create_team_incentive(self) -> None: @@ -591,16 +591,16 @@ class TestGrowthManager: print("\n🏆 测试创建团队升级激励...") try: - incentive = self.manager.create_team_incentive( - tenant_id = self.test_tenant_id, - name = "团队升级奖励", - description = "团队规模达到5人升级到 Pro 计划可获得折扣", - target_tier = "pro", - min_team_size = 5, - incentive_type = "discount", - incentive_value = 20.0, # 20% 折扣 - valid_from = datetime.now(), - valid_until = datetime.now() + timedelta(days = 90), + incentive = self.manager.create_team_incentive( + tenant_id=self.test_tenant_id, + name="团队升级奖励", + description="团队规模达到5人升级到 Pro 计划可获得折扣", + target_tier="pro", + min_team_size=5, + incentive_type="discount", + incentive_value=20.0, # 20% 折扣 + valid_from=datetime.now(), + valid_until=datetime.now() + timedelta(days=90), ) assert incentive.id is not None @@ -609,7 +609,7 @@ class TestGrowthManager: self.log(f"团队激励创建成功: {incentive.id}") return True except Exception as e: - self.log(f"创建团队激励失败: {e}", success = False) + self.log(f"创建团队激励失败: {e}", success=False) return False def test_check_team_incentive_eligibility(self) -> None: @@ -617,14 +617,14 @@ class TestGrowthManager: print("\n🔍 测试检查团队激励资格...") try: - incentives = self.manager.check_team_incentive_eligibility( - tenant_id = self.test_tenant_id, current_tier = "free", team_size = 5 + incentives = self.manager.check_team_incentive_eligibility( + tenant_id=self.test_tenant_id, current_tier="free", team_size=5 ) self.log(f"找到 {len(incentives)} 个符合条件的激励") return True except Exception as e: - self.log(f"检查激励资格失败: {e}", success = False) + self.log(f"检查激励资格失败: {e}", success=False) return False # ==================== 测试实时仪表板 ==================== @@ -634,19 +634,19 @@ class TestGrowthManager: print("\n📺 测试实时分析仪表板...") try: - dashboard = self.manager.get_realtime_dashboard(self.test_tenant_id) + dashboard = self.manager.get_realtime_dashboard(self.test_tenant_id) assert "today" in dashboard assert "recent_events" in dashboard assert "top_features" in dashboard - today = dashboard["today"] + today = dashboard["today"] self.log( f"实时仪表板: 今日 {today['active_users']} 活跃用户, {today['total_events']} 事件" ) return True except Exception as e: - self.log(f"获取实时仪表板失败: {e}", success = False) + self.log(f"获取实时仪表板失败: {e}", success=False) return False # ==================== 运行所有测试 ==================== @@ -666,7 +666,7 @@ class TestGrowthManager: await self.test_track_multiple_events() self.test_get_user_profile() self.test_get_analytics_summary() - funnel_id = self.test_create_funnel() + funnel_id = self.test_create_funnel() self.test_analyze_funnel(funnel_id) self.test_calculate_retention() @@ -675,7 +675,7 @@ class TestGrowthManager: print("🧪 模块 2: A/B 测试框架") print(" = " * 60) - experiment_id = self.test_create_experiment() + experiment_id = self.test_create_experiment() self.test_list_experiments() self.test_assign_variant(experiment_id) self.test_record_experiment_metric(experiment_id) @@ -686,7 +686,7 @@ class TestGrowthManager: print("📧 模块 3: 邮件营销自动化") print(" = " * 60) - template_id = self.test_create_email_template() + template_id = self.test_create_email_template() self.test_list_email_templates() self.test_render_template(template_id) self.test_create_email_campaign(template_id) @@ -697,8 +697,8 @@ class TestGrowthManager: print("🎁 模块 4: 推荐系统") print(" = " * 60) - program_id = self.test_create_referral_program() - referral_code = self.test_generate_referral_code(program_id) + program_id = self.test_create_referral_program() + referral_code = self.test_generate_referral_code(program_id) self.test_apply_referral_code(referral_code) self.test_get_referral_stats(program_id) self.test_create_team_incentive() @@ -716,9 +716,9 @@ class TestGrowthManager: print("📋 测试总结") print(" = " * 60) - total_tests = len(self.test_results) - passed_tests = sum(1 for _, success in self.test_results if success) - failed_tests = total_tests - passed_tests + total_tests = len(self.test_results) + passed_tests = sum(1 for _, success in self.test_results if success) + failed_tests = total_tests - passed_tests print(f"总测试数: {total_tests}") print(f"通过: {passed_tests} ✅") @@ -738,7 +738,7 @@ class TestGrowthManager: async def main() -> None: """主函数""" - tester = TestGrowthManager() + tester = TestGrowthManager() await tester.run_all_tests() diff --git a/backend/test_phase8_task6.py b/backend/test_phase8_task6.py index 6163551..b78c061 100644 --- a/backend/test_phase8_task6.py +++ b/backend/test_phase8_task6.py @@ -25,7 +25,7 @@ from developer_ecosystem_manager import ( ) # Add backend directory to path -backend_dir = os.path.dirname(os.path.abspath(__file__)) +backend_dir = os.path.dirname(os.path.abspath(__file__)) if backend_dir not in sys.path: sys.path.insert(0, backend_dir) @@ -34,9 +34,9 @@ class TestDeveloperEcosystem: """开发者生态系统测试类""" def __init__(self) -> None: - self.manager = DeveloperEcosystemManager() - self.test_results = [] - self.created_ids = { + self.manager = DeveloperEcosystemManager() + self.test_results = [] + self.created_ids = { "sdk": [], "template": [], "plugin": [], @@ -45,9 +45,9 @@ class TestDeveloperEcosystem: "portal_config": [], } - def log(self, message: str, success: bool = True) -> None: + def log(self, message: str, success: bool = True) -> None: """记录测试结果""" - status = "✅" if success else "❌" + status = "✅" if success else "❌" print(f"{status} {message}") self.test_results.append( {"message": message, "success": success, "timestamp": datetime.now().isoformat()} @@ -122,429 +122,429 @@ class TestDeveloperEcosystem: def test_sdk_create(self) -> None: """测试创建 SDK""" try: - sdk = self.manager.create_sdk_release( - name = "InsightFlow Python SDK", - language = SDKLanguage.PYTHON, - version = "1.0.0", - description = "Python SDK for InsightFlow API", - changelog = "Initial release", - download_url = "https://pypi.org/insightflow/1.0.0", - documentation_url = "https://docs.insightflow.io/python", - repository_url = "https://github.com/insightflow/python-sdk", - package_name = "insightflow", - min_platform_version = "1.0.0", - dependencies = [{"name": "requests", "version": ">= 2.0"}], - file_size = 1024000, - checksum = "abc123", - created_by = "test_user", + sdk = self.manager.create_sdk_release( + name="InsightFlow Python SDK", + language=SDKLanguage.PYTHON, + version="1.0.0", + description="Python SDK for InsightFlow API", + changelog="Initial release", + download_url="https://pypi.org/insightflow/1.0.0", + documentation_url="https://docs.insightflow.io/python", + repository_url="https://github.com/insightflow/python-sdk", + package_name="insightflow", + min_platform_version="1.0.0", + dependencies=[{"name": "requests", "version": ">= 2.0"}], + file_size=1024000, + checksum="abc123", + created_by="test_user", ) self.created_ids["sdk"].append(sdk.id) self.log(f"Created SDK: {sdk.name} ({sdk.id})") # Create JavaScript SDK - sdk_js = self.manager.create_sdk_release( - name = "InsightFlow JavaScript SDK", - language = SDKLanguage.JAVASCRIPT, - version = "1.0.0", - description = "JavaScript SDK for InsightFlow API", - changelog = "Initial release", - download_url = "https://npmjs.com/insightflow/1.0.0", - documentation_url = "https://docs.insightflow.io/js", - repository_url = "https://github.com/insightflow/js-sdk", - package_name = "@insightflow/sdk", - min_platform_version = "1.0.0", - dependencies = [{"name": "axios", "version": ">= 0.21"}], - file_size = 512000, - checksum = "def456", - created_by = "test_user", + sdk_js = self.manager.create_sdk_release( + name="InsightFlow JavaScript SDK", + language=SDKLanguage.JAVASCRIPT, + version="1.0.0", + description="JavaScript SDK for InsightFlow API", + changelog="Initial release", + download_url="https://npmjs.com/insightflow/1.0.0", + documentation_url="https://docs.insightflow.io/js", + repository_url="https://github.com/insightflow/js-sdk", + package_name="@insightflow/sdk", + min_platform_version="1.0.0", + dependencies=[{"name": "axios", "version": ">= 0.21"}], + file_size=512000, + checksum="def456", + created_by="test_user", ) self.created_ids["sdk"].append(sdk_js.id) self.log(f"Created SDK: {sdk_js.name} ({sdk_js.id})") except Exception as e: - self.log(f"Failed to create SDK: {str(e)}", success = False) + self.log(f"Failed to create SDK: {str(e)}", success=False) def test_sdk_list(self) -> None: """测试列出 SDK""" try: - sdks = self.manager.list_sdk_releases() + sdks = self.manager.list_sdk_releases() self.log(f"Listed {len(sdks)} SDKs") # Test filter by language - python_sdks = self.manager.list_sdk_releases(language = SDKLanguage.PYTHON) + python_sdks = self.manager.list_sdk_releases(language=SDKLanguage.PYTHON) self.log(f"Found {len(python_sdks)} Python SDKs") # Test search - search_results = self.manager.list_sdk_releases(search = "Python") + search_results = self.manager.list_sdk_releases(search="Python") self.log(f"Search found {len(search_results)} SDKs") except Exception as e: - self.log(f"Failed to list SDKs: {str(e)}", success = False) + self.log(f"Failed to list SDKs: {str(e)}", success=False) def test_sdk_get(self) -> None: """测试获取 SDK 详情""" try: if self.created_ids["sdk"]: - sdk = self.manager.get_sdk_release(self.created_ids["sdk"][0]) + sdk = self.manager.get_sdk_release(self.created_ids["sdk"][0]) if sdk: self.log(f"Retrieved SDK: {sdk.name}") else: - self.log("SDK not found", success = False) + self.log("SDK not found", success=False) except Exception as e: - self.log(f"Failed to get SDK: {str(e)}", success = False) + self.log(f"Failed to get SDK: {str(e)}", success=False) def test_sdk_update(self) -> None: """测试更新 SDK""" try: if self.created_ids["sdk"]: - sdk = self.manager.update_sdk_release( - self.created_ids["sdk"][0], description = "Updated description" + sdk = self.manager.update_sdk_release( + self.created_ids["sdk"][0], description="Updated description" ) if sdk: self.log(f"Updated SDK: {sdk.name}") except Exception as e: - self.log(f"Failed to update SDK: {str(e)}", success = False) + self.log(f"Failed to update SDK: {str(e)}", success=False) def test_sdk_publish(self) -> None: """测试发布 SDK""" try: if self.created_ids["sdk"]: - sdk = self.manager.publish_sdk_release(self.created_ids["sdk"][0]) + sdk = self.manager.publish_sdk_release(self.created_ids["sdk"][0]) if sdk: self.log(f"Published SDK: {sdk.name} (status: {sdk.status.value})") except Exception as e: - self.log(f"Failed to publish SDK: {str(e)}", success = False) + self.log(f"Failed to publish SDK: {str(e)}", success=False) def test_sdk_version_add(self) -> None: """测试添加 SDK 版本""" try: if self.created_ids["sdk"]: - version = self.manager.add_sdk_version( - sdk_id = self.created_ids["sdk"][0], - version = "1.1.0", - is_lts = True, - release_notes = "Bug fixes and improvements", - download_url = "https://pypi.org/insightflow/1.1.0", - checksum = "xyz789", - file_size = 1100000, + version = self.manager.add_sdk_version( + sdk_id=self.created_ids["sdk"][0], + version="1.1.0", + is_lts=True, + release_notes="Bug fixes and improvements", + download_url="https://pypi.org/insightflow/1.1.0", + checksum="xyz789", + file_size=1100000, ) self.log(f"Added SDK version: {version.version}") except Exception as e: - self.log(f"Failed to add SDK version: {str(e)}", success = False) + self.log(f"Failed to add SDK version: {str(e)}", success=False) def test_template_create(self) -> None: """测试创建模板""" try: - template = self.manager.create_template( - name = "医疗行业实体识别模板", - description = "专门针对医疗行业的实体识别模板,支持疾病、药物、症状等实体", - category = TemplateCategory.MEDICAL, - subcategory = "entity_recognition", - tags = ["medical", "healthcare", "ner"], - author_id = "dev_001", - author_name = "Medical AI Lab", - price = 99.0, - currency = "CNY", - preview_image_url = "https://cdn.insightflow.io/templates/medical.png", - demo_url = "https://demo.insightflow.io/medical", - documentation_url = "https://docs.insightflow.io/templates/medical", - download_url = "https://cdn.insightflow.io/templates/medical.zip", - version = "1.0.0", - min_platform_version = "2.0.0", - file_size = 5242880, - checksum = "tpl123", + template = self.manager.create_template( + name="医疗行业实体识别模板", + description="专门针对医疗行业的实体识别模板,支持疾病、药物、症状等实体", + category=TemplateCategory.MEDICAL, + subcategory="entity_recognition", + tags=["medical", "healthcare", "ner"], + author_id="dev_001", + author_name="Medical AI Lab", + price=99.0, + currency="CNY", + preview_image_url="https://cdn.insightflow.io/templates/medical.png", + demo_url="https://demo.insightflow.io/medical", + documentation_url="https://docs.insightflow.io/templates/medical", + download_url="https://cdn.insightflow.io/templates/medical.zip", + version="1.0.0", + min_platform_version="2.0.0", + file_size=5242880, + checksum="tpl123", ) self.created_ids["template"].append(template.id) self.log(f"Created template: {template.name} ({template.id})") # Create free template - template_free = self.manager.create_template( - name = "通用实体识别模板", - description = "适用于一般场景的实体识别模板", - category = TemplateCategory.GENERAL, - subcategory = None, - tags = ["general", "ner", "basic"], - author_id = "dev_002", - author_name = "InsightFlow Team", - price = 0.0, - currency = "CNY", + template_free = self.manager.create_template( + name="通用实体识别模板", + description="适用于一般场景的实体识别模板", + category=TemplateCategory.GENERAL, + subcategory=None, + tags=["general", "ner", "basic"], + author_id="dev_002", + author_name="InsightFlow Team", + price=0.0, + currency="CNY", ) self.created_ids["template"].append(template_free.id) self.log(f"Created free template: {template_free.name}") except Exception as e: - self.log(f"Failed to create template: {str(e)}", success = False) + self.log(f"Failed to create template: {str(e)}", success=False) def test_template_list(self) -> None: """测试列出模板""" try: - templates = self.manager.list_templates() + templates = self.manager.list_templates() self.log(f"Listed {len(templates)} templates") # Filter by category - medical_templates = self.manager.list_templates(category = TemplateCategory.MEDICAL) + medical_templates = self.manager.list_templates(category=TemplateCategory.MEDICAL) self.log(f"Found {len(medical_templates)} medical templates") # Filter by price - free_templates = self.manager.list_templates(max_price = 0) + free_templates = self.manager.list_templates(max_price=0) self.log(f"Found {len(free_templates)} free templates") except Exception as e: - self.log(f"Failed to list templates: {str(e)}", success = False) + self.log(f"Failed to list templates: {str(e)}", success=False) def test_template_get(self) -> None: """测试获取模板详情""" try: if self.created_ids["template"]: - template = self.manager.get_template(self.created_ids["template"][0]) + template = self.manager.get_template(self.created_ids["template"][0]) if template: self.log(f"Retrieved template: {template.name}") except Exception as e: - self.log(f"Failed to get template: {str(e)}", success = False) + self.log(f"Failed to get template: {str(e)}", success=False) def test_template_approve(self) -> None: """测试审核通过模板""" try: if self.created_ids["template"]: - template = self.manager.approve_template( - self.created_ids["template"][0], reviewed_by = "admin_001" + template = self.manager.approve_template( + self.created_ids["template"][0], reviewed_by="admin_001" ) if template: self.log(f"Approved template: {template.name}") except Exception as e: - self.log(f"Failed to approve template: {str(e)}", success = False) + self.log(f"Failed to approve template: {str(e)}", success=False) def test_template_publish(self) -> None: """测试发布模板""" try: if self.created_ids["template"]: - template = self.manager.publish_template(self.created_ids["template"][0]) + template = self.manager.publish_template(self.created_ids["template"][0]) if template: self.log(f"Published template: {template.name}") except Exception as e: - self.log(f"Failed to publish template: {str(e)}", success = False) + self.log(f"Failed to publish template: {str(e)}", success=False) def test_template_review(self) -> None: """测试添加模板评价""" try: if self.created_ids["template"]: - review = self.manager.add_template_review( - template_id = self.created_ids["template"][0], - user_id = "user_001", - user_name = "Test User", - rating = 5, - comment = "Great template! Very accurate for medical entities.", - is_verified_purchase = True, + review = self.manager.add_template_review( + template_id=self.created_ids["template"][0], + user_id="user_001", + user_name="Test User", + rating=5, + comment="Great template! Very accurate for medical entities.", + is_verified_purchase=True, ) self.log(f"Added template review: {review.rating} stars") except Exception as e: - self.log(f"Failed to add template review: {str(e)}", success = False) + self.log(f"Failed to add template review: {str(e)}", success=False) def test_plugin_create(self) -> None: """测试创建插件""" try: - plugin = self.manager.create_plugin( - name = "飞书机器人集成插件", - description = "将 InsightFlow 与飞书机器人集成,实现自动通知", - category = PluginCategory.INTEGRATION, - tags = ["feishu", "bot", "integration", "notification"], - author_id = "dev_003", - author_name = "Integration Team", - price = 49.0, - currency = "CNY", - pricing_model = "paid", - preview_image_url = "https://cdn.insightflow.io/plugins/feishu.png", - demo_url = "https://demo.insightflow.io/feishu", - documentation_url = "https://docs.insightflow.io/plugins/feishu", - repository_url = "https://github.com/insightflow/feishu-plugin", - download_url = "https://cdn.insightflow.io/plugins/feishu.zip", - webhook_url = "https://api.insightflow.io/webhooks/feishu", - permissions = ["read:projects", "write:notifications"], - version = "1.0.0", - min_platform_version = "2.0.0", - file_size = 1048576, - checksum = "plg123", + plugin = self.manager.create_plugin( + name="飞书机器人集成插件", + description="将 InsightFlow 与飞书机器人集成,实现自动通知", + category=PluginCategory.INTEGRATION, + tags=["feishu", "bot", "integration", "notification"], + author_id="dev_003", + author_name="Integration Team", + price=49.0, + currency="CNY", + pricing_model="paid", + preview_image_url="https://cdn.insightflow.io/plugins/feishu.png", + demo_url="https://demo.insightflow.io/feishu", + documentation_url="https://docs.insightflow.io/plugins/feishu", + repository_url="https://github.com/insightflow/feishu-plugin", + download_url="https://cdn.insightflow.io/plugins/feishu.zip", + webhook_url="https://api.insightflow.io/webhooks/feishu", + permissions=["read:projects", "write:notifications"], + version="1.0.0", + min_platform_version="2.0.0", + file_size=1048576, + checksum="plg123", ) self.created_ids["plugin"].append(plugin.id) self.log(f"Created plugin: {plugin.name} ({plugin.id})") # Create free plugin - plugin_free = self.manager.create_plugin( - name = "数据导出插件", - description = "支持多种格式的数据导出", - category = PluginCategory.ANALYSIS, - tags = ["export", "data", "csv", "json"], - author_id = "dev_004", - author_name = "Data Team", - price = 0.0, - currency = "CNY", - pricing_model = "free", + plugin_free = self.manager.create_plugin( + name="数据导出插件", + description="支持多种格式的数据导出", + category=PluginCategory.ANALYSIS, + tags=["export", "data", "csv", "json"], + author_id="dev_004", + author_name="Data Team", + price=0.0, + currency="CNY", + pricing_model="free", ) self.created_ids["plugin"].append(plugin_free.id) self.log(f"Created free plugin: {plugin_free.name}") except Exception as e: - self.log(f"Failed to create plugin: {str(e)}", success = False) + self.log(f"Failed to create plugin: {str(e)}", success=False) def test_plugin_list(self) -> None: """测试列出插件""" try: - plugins = self.manager.list_plugins() + plugins = self.manager.list_plugins() self.log(f"Listed {len(plugins)} plugins") # Filter by category - integration_plugins = self.manager.list_plugins(category = PluginCategory.INTEGRATION) + integration_plugins = self.manager.list_plugins(category=PluginCategory.INTEGRATION) self.log(f"Found {len(integration_plugins)} integration plugins") except Exception as e: - self.log(f"Failed to list plugins: {str(e)}", success = False) + self.log(f"Failed to list plugins: {str(e)}", success=False) def test_plugin_get(self) -> None: """测试获取插件详情""" try: if self.created_ids["plugin"]: - plugin = self.manager.get_plugin(self.created_ids["plugin"][0]) + plugin = self.manager.get_plugin(self.created_ids["plugin"][0]) if plugin: self.log(f"Retrieved plugin: {plugin.name}") except Exception as e: - self.log(f"Failed to get plugin: {str(e)}", success = False) + self.log(f"Failed to get plugin: {str(e)}", success=False) def test_plugin_review(self) -> None: """测试审核插件""" try: if self.created_ids["plugin"]: - plugin = self.manager.review_plugin( + plugin = self.manager.review_plugin( self.created_ids["plugin"][0], - reviewed_by = "admin_001", - status = PluginStatus.APPROVED, - notes = "Code review passed", + reviewed_by="admin_001", + status=PluginStatus.APPROVED, + notes="Code review passed", ) if plugin: self.log(f"Reviewed plugin: {plugin.name} ({plugin.status.value})") except Exception as e: - self.log(f"Failed to review plugin: {str(e)}", success = False) + self.log(f"Failed to review plugin: {str(e)}", success=False) def test_plugin_publish(self) -> None: """测试发布插件""" try: if self.created_ids["plugin"]: - plugin = self.manager.publish_plugin(self.created_ids["plugin"][0]) + plugin = self.manager.publish_plugin(self.created_ids["plugin"][0]) if plugin: self.log(f"Published plugin: {plugin.name}") except Exception as e: - self.log(f"Failed to publish plugin: {str(e)}", success = False) + self.log(f"Failed to publish plugin: {str(e)}", success=False) def test_plugin_review_add(self) -> None: """测试添加插件评价""" try: if self.created_ids["plugin"]: - review = self.manager.add_plugin_review( - plugin_id = self.created_ids["plugin"][0], - user_id = "user_002", - user_name = "Plugin User", - rating = 4, - comment = "Works great with Feishu!", - is_verified_purchase = True, + review = self.manager.add_plugin_review( + plugin_id=self.created_ids["plugin"][0], + user_id="user_002", + user_name="Plugin User", + rating=4, + comment="Works great with Feishu!", + is_verified_purchase=True, ) self.log(f"Added plugin review: {review.rating} stars") except Exception as e: - self.log(f"Failed to add plugin review: {str(e)}", success = False) + self.log(f"Failed to add plugin review: {str(e)}", success=False) def test_developer_profile_create(self) -> None: """测试创建开发者档案""" try: # Generate unique user IDs - unique_id = uuid.uuid4().hex[:8] + unique_id = uuid.uuid4().hex[:8] - profile = self.manager.create_developer_profile( - user_id = f"user_dev_{unique_id}_001", - display_name = "张三", - email = f"zhangsan_{unique_id}@example.com", - bio = "专注于医疗AI和自然语言处理", - website = "https://zhangsan.dev", - github_url = "https://github.com/zhangsan", - avatar_url = "https://cdn.example.com/avatars/zhangsan.png", + profile = self.manager.create_developer_profile( + user_id=f"user_dev_{unique_id}_001", + display_name="张三", + email=f"zhangsan_{unique_id}@example.com", + bio="专注于医疗AI和自然语言处理", + website="https://zhangsan.dev", + github_url="https://github.com/zhangsan", + avatar_url="https://cdn.example.com/avatars/zhangsan.png", ) self.created_ids["developer"].append(profile.id) self.log(f"Created developer profile: {profile.display_name} ({profile.id})") # Create another developer - profile2 = self.manager.create_developer_profile( - user_id = f"user_dev_{unique_id}_002", - display_name = "李四", - email = f"lisi_{unique_id}@example.com", - bio = "全栈开发者,热爱开源", + profile2 = self.manager.create_developer_profile( + user_id=f"user_dev_{unique_id}_002", + display_name="李四", + email=f"lisi_{unique_id}@example.com", + bio="全栈开发者,热爱开源", ) self.created_ids["developer"].append(profile2.id) self.log(f"Created developer profile: {profile2.display_name}") except Exception as e: - self.log(f"Failed to create developer profile: {str(e)}", success = False) + self.log(f"Failed to create developer profile: {str(e)}", success=False) def test_developer_profile_get(self) -> None: """测试获取开发者档案""" try: if self.created_ids["developer"]: - profile = self.manager.get_developer_profile(self.created_ids["developer"][0]) + profile = self.manager.get_developer_profile(self.created_ids["developer"][0]) if profile: self.log(f"Retrieved developer profile: {profile.display_name}") except Exception as e: - self.log(f"Failed to get developer profile: {str(e)}", success = False) + self.log(f"Failed to get developer profile: {str(e)}", success=False) def test_developer_verify(self) -> None: """测试验证开发者""" try: if self.created_ids["developer"]: - profile = self.manager.verify_developer( + profile = self.manager.verify_developer( self.created_ids["developer"][0], DeveloperStatus.VERIFIED ) if profile: self.log(f"Verified developer: {profile.display_name} ({profile.status.value})") except Exception as e: - self.log(f"Failed to verify developer: {str(e)}", success = False) + self.log(f"Failed to verify developer: {str(e)}", success=False) def test_developer_stats_update(self) -> None: """测试更新开发者统计""" try: if self.created_ids["developer"]: self.manager.update_developer_stats(self.created_ids["developer"][0]) - profile = self.manager.get_developer_profile(self.created_ids["developer"][0]) + profile = self.manager.get_developer_profile(self.created_ids["developer"][0]) self.log( f"Updated developer stats: {profile.plugin_count} plugins, {profile.template_count} templates" ) except Exception as e: - self.log(f"Failed to update developer stats: {str(e)}", success = False) + self.log(f"Failed to update developer stats: {str(e)}", success=False) def test_code_example_create(self) -> None: """测试创建代码示例""" try: - example = self.manager.create_code_example( - title = "使用 Python SDK 创建项目", - description = "演示如何使用 Python SDK 创建新项目", - language = "python", - category = "quickstart", - code = """from insightflow import Client + example = self.manager.create_code_example( + title="使用 Python SDK 创建项目", + description="演示如何使用 Python SDK 创建新项目", + language="python", + category="quickstart", + code="""from insightflow import Client client = Client(api_key = "your_api_key") project = client.projects.create(name = "My Project") print(f"Created project: {project.id}") """, - explanation = "首先导入 Client 类,然后使用 API Key 初始化客户端,最后调用 create 方法创建项目。", - tags = ["python", "quickstart", "projects"], - author_id = "dev_001", - author_name = "InsightFlow Team", - api_endpoints = ["/api/v1/projects"], + explanation="首先导入 Client 类,然后使用 API Key 初始化客户端,最后调用 create 方法创建项目。", + tags=["python", "quickstart", "projects"], + author_id="dev_001", + author_name="InsightFlow Team", + api_endpoints=["/api/v1/projects"], ) self.created_ids["code_example"].append(example.id) self.log(f"Created code example: {example.title}") # Create JavaScript example - example_js = self.manager.create_code_example( - title = "使用 JavaScript SDK 上传文件", - description = "演示如何使用 JavaScript SDK 上传音频文件", - language = "javascript", - category = "upload", - code = """const { Client } = require('insightflow'); + example_js = self.manager.create_code_example( + title="使用 JavaScript SDK 上传文件", + description="演示如何使用 JavaScript SDK 上传音频文件", + language="javascript", + category="upload", + code="""const { Client } = require('insightflow'); const client = new Client({ apiKey: 'your_api_key' }); const result = await client.uploads.create({ @@ -553,104 +553,104 @@ const result = await client.uploads.create({ }); console.log('Upload complete:', result.id); """, - explanation = "使用 JavaScript SDK 上传文件到 InsightFlow", - tags = ["javascript", "upload", "audio"], - author_id = "dev_002", - author_name = "JS Team", + explanation="使用 JavaScript SDK 上传文件到 InsightFlow", + tags=["javascript", "upload", "audio"], + author_id="dev_002", + author_name="JS Team", ) self.created_ids["code_example"].append(example_js.id) self.log(f"Created code example: {example_js.title}") except Exception as e: - self.log(f"Failed to create code example: {str(e)}", success = False) + self.log(f"Failed to create code example: {str(e)}", success=False) def test_code_example_list(self) -> None: """测试列出代码示例""" try: - examples = self.manager.list_code_examples() + examples = self.manager.list_code_examples() self.log(f"Listed {len(examples)} code examples") # Filter by language - python_examples = self.manager.list_code_examples(language = "python") + python_examples = self.manager.list_code_examples(language="python") self.log(f"Found {len(python_examples)} Python examples") except Exception as e: - self.log(f"Failed to list code examples: {str(e)}", success = False) + self.log(f"Failed to list code examples: {str(e)}", success=False) def test_code_example_get(self) -> None: """测试获取代码示例详情""" try: if self.created_ids["code_example"]: - example = self.manager.get_code_example(self.created_ids["code_example"][0]) + example = self.manager.get_code_example(self.created_ids["code_example"][0]) if example: self.log( f"Retrieved code example: {example.title} (views: {example.view_count})" ) except Exception as e: - self.log(f"Failed to get code example: {str(e)}", success = False) + self.log(f"Failed to get code example: {str(e)}", success=False) def test_portal_config_create(self) -> None: """测试创建开发者门户配置""" try: - config = self.manager.create_portal_config( - name = "InsightFlow Developer Portal", - description = "开发者门户 - SDK、API 文档和示例代码", - theme = "default", - primary_color = "#1890ff", - secondary_color = "#52c41a", - support_email = "developers@insightflow.io", - support_url = "https://support.insightflow.io", - github_url = "https://github.com/insightflow", - discord_url = "https://discord.gg/insightflow", - api_base_url = "https://api.insightflow.io/v1", + config = self.manager.create_portal_config( + name="InsightFlow Developer Portal", + description="开发者门户 - SDK、API 文档和示例代码", + theme="default", + primary_color="#1890ff", + secondary_color="#52c41a", + support_email="developers@insightflow.io", + support_url="https://support.insightflow.io", + github_url="https://github.com/insightflow", + discord_url="https://discord.gg/insightflow", + api_base_url="https://api.insightflow.io/v1", ) self.created_ids["portal_config"].append(config.id) self.log(f"Created portal config: {config.name}") except Exception as e: - self.log(f"Failed to create portal config: {str(e)}", success = False) + self.log(f"Failed to create portal config: {str(e)}", success=False) def test_portal_config_get(self) -> None: """测试获取开发者门户配置""" try: if self.created_ids["portal_config"]: - config = self.manager.get_portal_config(self.created_ids["portal_config"][0]) + config = self.manager.get_portal_config(self.created_ids["portal_config"][0]) if config: self.log(f"Retrieved portal config: {config.name}") # Test active config - active_config = self.manager.get_active_portal_config() + active_config = self.manager.get_active_portal_config() if active_config: self.log(f"Active portal config: {active_config.name}") except Exception as e: - self.log(f"Failed to get portal config: {str(e)}", success = False) + self.log(f"Failed to get portal config: {str(e)}", success=False) def test_revenue_record(self) -> None: """测试记录开发者收益""" try: if self.created_ids["developer"] and self.created_ids["plugin"]: - revenue = self.manager.record_revenue( - developer_id = self.created_ids["developer"][0], - item_type = "plugin", - item_id = self.created_ids["plugin"][0], - item_name = "飞书机器人集成插件", - sale_amount = 49.0, - currency = "CNY", - buyer_id = "user_buyer_001", - transaction_id = "txn_123456", + revenue = self.manager.record_revenue( + developer_id=self.created_ids["developer"][0], + item_type="plugin", + item_id=self.created_ids["plugin"][0], + item_name="飞书机器人集成插件", + sale_amount=49.0, + currency="CNY", + buyer_id="user_buyer_001", + transaction_id="txn_123456", ) self.log(f"Recorded revenue: {revenue.sale_amount} {revenue.currency}") self.log(f" - Platform fee: {revenue.platform_fee}") self.log(f" - Developer earnings: {revenue.developer_earnings}") except Exception as e: - self.log(f"Failed to record revenue: {str(e)}", success = False) + self.log(f"Failed to record revenue: {str(e)}", success=False) def test_revenue_summary(self) -> None: """测试获取开发者收益汇总""" try: if self.created_ids["developer"]: - summary = self.manager.get_developer_revenue_summary( + summary = self.manager.get_developer_revenue_summary( self.created_ids["developer"][0] ) self.log("Revenue summary for developer:") @@ -659,7 +659,7 @@ console.log('Upload complete:', result.id); self.log(f" - Total earnings: {summary['total_earnings']}") self.log(f" - Transaction count: {summary['transaction_count']}") except Exception as e: - self.log(f"Failed to get revenue summary: {str(e)}", success = False) + self.log(f"Failed to get revenue summary: {str(e)}", success=False) def print_summary(self) -> None: """打印测试摘要""" @@ -667,9 +667,9 @@ console.log('Upload complete:', result.id); print("Test Summary") print(" = " * 60) - total = len(self.test_results) - passed = sum(1 for r in self.test_results if r["success"]) - failed = total - passed + total = len(self.test_results) + passed = sum(1 for r in self.test_results if r["success"]) + failed = total - passed print(f"Total tests: {total}") print(f"Passed: {passed} ✅") @@ -691,7 +691,7 @@ console.log('Upload complete:', result.id); def main() -> None: """主函数""" - test = TestDeveloperEcosystem() + test = TestDeveloperEcosystem() test.run_all_tests() diff --git a/backend/test_phase8_task8.py b/backend/test_phase8_task8.py index eef988a..a24307d 100644 --- a/backend/test_phase8_task8.py +++ b/backend/test_phase8_task8.py @@ -26,7 +26,7 @@ from ops_manager import ( ) # Add backend directory to path -backend_dir = os.path.dirname(os.path.abspath(__file__)) +backend_dir = os.path.dirname(os.path.abspath(__file__)) if backend_dir not in sys.path: sys.path.insert(0, backend_dir) @@ -35,13 +35,13 @@ class TestOpsManager: """测试运维与监控管理器""" def __init__(self) -> None: - self.manager = get_ops_manager() - self.tenant_id = "test_tenant_001" - self.test_results = [] + self.manager = get_ops_manager() + self.tenant_id = "test_tenant_001" + self.test_results = [] - def log(self, message: str, success: bool = True) -> None: + def log(self, message: str, success: bool = True) -> None: """记录测试结果""" - status = "✅" if success else "❌" + status = "✅" if success else "❌" print(f"{status} {message}") self.test_results.append((message, success)) @@ -79,57 +79,57 @@ class TestOpsManager: try: # 创建阈值告警规则 - rule1 = self.manager.create_alert_rule( - tenant_id = self.tenant_id, - name = "CPU 使用率告警", - description = "当 CPU 使用率超过 80% 时触发告警", - rule_type = AlertRuleType.THRESHOLD, - severity = AlertSeverity.P1, - metric = "cpu_usage_percent", - condition = ">", - threshold = 80.0, - duration = 300, - evaluation_interval = 60, - channels = [], - labels = {"service": "api", "team": "platform"}, - annotations = {"summary": "CPU 使用率过高", "runbook": "https://wiki/runbooks/cpu"}, - created_by = "test_user", + rule1 = self.manager.create_alert_rule( + tenant_id=self.tenant_id, + name="CPU 使用率告警", + description="当 CPU 使用率超过 80% 时触发告警", + rule_type=AlertRuleType.THRESHOLD, + severity=AlertSeverity.P1, + metric="cpu_usage_percent", + condition=">", + threshold=80.0, + duration=300, + evaluation_interval=60, + channels=[], + labels={"service": "api", "team": "platform"}, + annotations={"summary": "CPU 使用率过高", "runbook": "https://wiki/runbooks/cpu"}, + created_by="test_user", ) self.log(f"Created alert rule: {rule1.name} (ID: {rule1.id})") # 创建异常检测告警规则 - rule2 = self.manager.create_alert_rule( - tenant_id = self.tenant_id, - name = "内存异常检测", - description = "检测内存使用异常", - rule_type = AlertRuleType.ANOMALY, - severity = AlertSeverity.P2, - metric = "memory_usage_percent", - condition = ">", - threshold = 0.0, - duration = 600, - evaluation_interval = 300, - channels = [], - labels = {"service": "database"}, - annotations = {}, - created_by = "test_user", + rule2 = self.manager.create_alert_rule( + tenant_id=self.tenant_id, + name="内存异常检测", + description="检测内存使用异常", + rule_type=AlertRuleType.ANOMALY, + severity=AlertSeverity.P2, + metric="memory_usage_percent", + condition=">", + threshold=0.0, + duration=600, + evaluation_interval=300, + channels=[], + labels={"service": "database"}, + annotations={}, + created_by="test_user", ) self.log(f"Created anomaly alert rule: {rule2.name} (ID: {rule2.id})") # 获取告警规则 - fetched_rule = self.manager.get_alert_rule(rule1.id) + fetched_rule = self.manager.get_alert_rule(rule1.id) assert fetched_rule is not None assert fetched_rule.name == rule1.name self.log(f"Fetched alert rule: {fetched_rule.name}") # 列出租户的所有告警规则 - rules = self.manager.list_alert_rules(self.tenant_id) + rules = self.manager.list_alert_rules(self.tenant_id) assert len(rules) >= 2 self.log(f"Listed {len(rules)} alert rules for tenant") # 更新告警规则 - updated_rule = self.manager.update_alert_rule( - rule1.id, threshold = 85.0, description = "更新后的描述" + updated_rule = self.manager.update_alert_rule( + rule1.id, threshold=85.0, description="更新后的描述" ) assert updated_rule.threshold == 85.0 self.log(f"Updated alert rule threshold to {updated_rule.threshold}") @@ -140,7 +140,7 @@ class TestOpsManager: self.log("Deleted test alert rules") except Exception as e: - self.log(f"Alert rules test failed: {e}", success = False) + self.log(f"Alert rules test failed: {e}", success=False) def test_alert_channels(self) -> None: """测试告警渠道管理""" @@ -148,49 +148,49 @@ class TestOpsManager: try: # 创建飞书告警渠道 - channel1 = self.manager.create_alert_channel( - tenant_id = self.tenant_id, - name = "飞书告警", - channel_type = AlertChannelType.FEISHU, - config = { + channel1 = self.manager.create_alert_channel( + tenant_id=self.tenant_id, + name="飞书告警", + channel_type=AlertChannelType.FEISHU, + config={ "webhook_url": "https://open.feishu.cn/open-apis/bot/v2/hook/test", "secret": "test_secret", }, - severity_filter = ["p0", "p1"], + severity_filter=["p0", "p1"], ) self.log(f"Created Feishu channel: {channel1.name} (ID: {channel1.id})") # 创建钉钉告警渠道 - channel2 = self.manager.create_alert_channel( - tenant_id = self.tenant_id, - name = "钉钉告警", - channel_type = AlertChannelType.DINGTALK, - config = { + channel2 = self.manager.create_alert_channel( + tenant_id=self.tenant_id, + name="钉钉告警", + channel_type=AlertChannelType.DINGTALK, + config={ "webhook_url": "https://oapi.dingtalk.com/robot/send?access_token = test", "secret": "test_secret", }, - severity_filter = ["p0", "p1", "p2"], + severity_filter=["p0", "p1", "p2"], ) self.log(f"Created DingTalk channel: {channel2.name} (ID: {channel2.id})") # 创建 Slack 告警渠道 - channel3 = self.manager.create_alert_channel( - tenant_id = self.tenant_id, - name = "Slack 告警", - channel_type = AlertChannelType.SLACK, - config = {"webhook_url": "https://hooks.slack.com/services/test"}, - severity_filter = ["p0", "p1", "p2", "p3"], + channel3 = self.manager.create_alert_channel( + tenant_id=self.tenant_id, + name="Slack 告警", + channel_type=AlertChannelType.SLACK, + config={"webhook_url": "https://hooks.slack.com/services/test"}, + severity_filter=["p0", "p1", "p2", "p3"], ) self.log(f"Created Slack channel: {channel3.name} (ID: {channel3.id})") # 获取告警渠道 - fetched_channel = self.manager.get_alert_channel(channel1.id) + fetched_channel = self.manager.get_alert_channel(channel1.id) assert fetched_channel is not None assert fetched_channel.name == channel1.name self.log(f"Fetched alert channel: {fetched_channel.name}") # 列出租户的所有告警渠道 - channels = self.manager.list_alert_channels(self.tenant_id) + channels = self.manager.list_alert_channels(self.tenant_id) assert len(channels) >= 3 self.log(f"Listed {len(channels)} alert channels for tenant") @@ -203,7 +203,7 @@ class TestOpsManager: self.log("Deleted test alert channels") except Exception as e: - self.log(f"Alert channels test failed: {e}", success = False) + self.log(f"Alert channels test failed: {e}", success=False) def test_alerts(self) -> None: """测试告警管理""" @@ -211,61 +211,61 @@ class TestOpsManager: try: # 创建告警规则 - rule = self.manager.create_alert_rule( - tenant_id = self.tenant_id, - name = "测试告警规则", - description = "用于测试的告警规则", - rule_type = AlertRuleType.THRESHOLD, - severity = AlertSeverity.P1, - metric = "test_metric", - condition = ">", - threshold = 100.0, - duration = 60, - evaluation_interval = 60, - channels = [], - labels = {}, - annotations = {}, - created_by = "test_user", + rule = self.manager.create_alert_rule( + tenant_id=self.tenant_id, + name="测试告警规则", + description="用于测试的告警规则", + rule_type=AlertRuleType.THRESHOLD, + severity=AlertSeverity.P1, + metric="test_metric", + condition=">", + threshold=100.0, + duration=60, + evaluation_interval=60, + channels=[], + labels={}, + annotations={}, + created_by="test_user", ) # 记录资源指标 for i in range(10): self.manager.record_resource_metric( - tenant_id = self.tenant_id, - resource_type = ResourceType.CPU, - resource_id = "server-001", - metric_name = "test_metric", - metric_value = 110.0 + i, - unit = "percent", - metadata = {"region": "cn-north-1"}, + tenant_id=self.tenant_id, + resource_type=ResourceType.CPU, + resource_id="server-001", + metric_name="test_metric", + metric_value=110.0 + i, + unit="percent", + metadata={"region": "cn-north-1"}, ) self.log("Recorded 10 resource metrics") # 手动创建告警 from ops_manager import Alert - alert_id = f"test_alert_{datetime.now().strftime('%Y%m%d%H%M%S')}" - now = datetime.now().isoformat() + alert_id = f"test_alert_{datetime.now().strftime('%Y%m%d%H%M%S')}" + now = datetime.now().isoformat() - alert = Alert( - id = alert_id, - rule_id = rule.id, - tenant_id = self.tenant_id, - severity = AlertSeverity.P1, - status = AlertStatus.FIRING, - title = "测试告警", - description = "这是一条测试告警", - metric = "test_metric", - value = 120.0, - threshold = 100.0, - labels = {"test": "true"}, - annotations = {}, - started_at = now, - resolved_at = None, - acknowledged_by = None, - acknowledged_at = None, - notification_sent = {}, - suppression_count = 0, + alert = Alert( + id=alert_id, + rule_id=rule.id, + tenant_id=self.tenant_id, + severity=AlertSeverity.P1, + status=AlertStatus.FIRING, + title="测试告警", + description="这是一条测试告警", + metric="test_metric", + value=120.0, + threshold=100.0, + labels={"test": "true"}, + annotations={}, + started_at=now, + resolved_at=None, + acknowledged_by=None, + acknowledged_at=None, + notification_sent={}, + suppression_count=0, ) with self.manager._get_db() as conn: @@ -299,20 +299,20 @@ class TestOpsManager: self.log(f"Created test alert: {alert.id}") # 列出租户的告警 - alerts = self.manager.list_alerts(self.tenant_id) + alerts = self.manager.list_alerts(self.tenant_id) assert len(alerts) >= 1 self.log(f"Listed {len(alerts)} alerts for tenant") # 确认告警 self.manager.acknowledge_alert(alert_id, "test_user") - fetched_alert = self.manager.get_alert(alert_id) + fetched_alert = self.manager.get_alert(alert_id) assert fetched_alert.status == AlertStatus.ACKNOWLEDGED assert fetched_alert.acknowledged_by == "test_user" self.log(f"Acknowledged alert: {alert_id}") # 解决告警 self.manager.resolve_alert(alert_id) - fetched_alert = self.manager.get_alert(alert_id) + fetched_alert = self.manager.get_alert(alert_id) assert fetched_alert.status == AlertStatus.RESOLVED assert fetched_alert.resolved_at is not None self.log(f"Resolved alert: {alert_id}") @@ -326,7 +326,7 @@ class TestOpsManager: self.log("Cleaned up test data") except Exception as e: - self.log(f"Alerts test failed: {e}", success = False) + self.log(f"Alerts test failed: {e}", success=False) def test_capacity_planning(self) -> None: """测试容量规划""" @@ -334,9 +334,9 @@ class TestOpsManager: try: # 记录历史指标数据 - base_time = datetime.now() - timedelta(days = 30) + base_time = datetime.now() - timedelta(days=30) for i in range(30): - timestamp = (base_time + timedelta(days = i)).isoformat() + timestamp = (base_time + timedelta(days=i)).isoformat() with self.manager._get_db() as conn: conn.execute( """ @@ -360,13 +360,13 @@ class TestOpsManager: self.log("Recorded 30 days of historical metrics") # 创建容量规划 - prediction_date = (datetime.now() + timedelta(days = 30)).strftime("%Y-%m-%d") - plan = self.manager.create_capacity_plan( - tenant_id = self.tenant_id, - resource_type = ResourceType.CPU, - current_capacity = 100.0, - prediction_date = prediction_date, - confidence = 0.85, + prediction_date = (datetime.now() + timedelta(days=30)).strftime("%Y-%m-%d") + plan = self.manager.create_capacity_plan( + tenant_id=self.tenant_id, + resource_type=ResourceType.CPU, + current_capacity=100.0, + prediction_date=prediction_date, + confidence=0.85, ) self.log(f"Created capacity plan: {plan.id}") @@ -375,7 +375,7 @@ class TestOpsManager: self.log(f" Recommended action: {plan.recommended_action}") # 获取容量规划列表 - plans = self.manager.get_capacity_plans(self.tenant_id) + plans = self.manager.get_capacity_plans(self.tenant_id) assert len(plans) >= 1 self.log(f"Listed {len(plans)} capacity plans") @@ -387,7 +387,7 @@ class TestOpsManager: self.log("Cleaned up capacity planning test data") except Exception as e: - self.log(f"Capacity planning test failed: {e}", success = False) + self.log(f"Capacity planning test failed: {e}", success=False) def test_auto_scaling(self) -> None: """测试自动扩缩容""" @@ -395,18 +395,18 @@ class TestOpsManager: try: # 创建自动扩缩容策略 - policy = self.manager.create_auto_scaling_policy( - tenant_id = self.tenant_id, - name = "API 服务自动扩缩容", - resource_type = ResourceType.CPU, - min_instances = 2, - max_instances = 10, - target_utilization = 0.7, - scale_up_threshold = 0.8, - scale_down_threshold = 0.3, - scale_up_step = 2, - scale_down_step = 1, - cooldown_period = 300, + policy = self.manager.create_auto_scaling_policy( + tenant_id=self.tenant_id, + name="API 服务自动扩缩容", + resource_type=ResourceType.CPU, + min_instances=2, + max_instances=10, + target_utilization=0.7, + scale_up_threshold=0.8, + scale_down_threshold=0.3, + scale_up_step=2, + scale_down_step=1, + cooldown_period=300, ) self.log(f"Created auto scaling policy: {policy.name} (ID: {policy.id})") @@ -415,13 +415,13 @@ class TestOpsManager: self.log(f" Target utilization: {policy.target_utilization}") # 获取策略列表 - policies = self.manager.list_auto_scaling_policies(self.tenant_id) + policies = self.manager.list_auto_scaling_policies(self.tenant_id) assert len(policies) >= 1 self.log(f"Listed {len(policies)} auto scaling policies") # 模拟扩缩容评估 - event = self.manager.evaluate_scaling_policy( - policy_id = policy.id, current_instances = 3, current_utilization = 0.85 + event = self.manager.evaluate_scaling_policy( + policy_id=policy.id, current_instances=3, current_utilization=0.85 ) if event: @@ -432,7 +432,7 @@ class TestOpsManager: self.log("No scaling action needed") # 获取扩缩容事件列表 - events = self.manager.list_scaling_events(self.tenant_id) + events = self.manager.list_scaling_events(self.tenant_id) self.log(f"Listed {len(events)} scaling events") # 清理 @@ -445,7 +445,7 @@ class TestOpsManager: self.log("Cleaned up auto scaling test data") except Exception as e: - self.log(f"Auto scaling test failed: {e}", success = False) + self.log(f"Auto scaling test failed: {e}", success=False) def test_health_checks(self) -> None: """测试健康检查""" @@ -453,41 +453,41 @@ class TestOpsManager: try: # 创建 HTTP 健康检查 - check1 = self.manager.create_health_check( - tenant_id = self.tenant_id, - name = "API 服务健康检查", - target_type = "service", - target_id = "api-service", - check_type = "http", - check_config = {"url": "https://api.insightflow.io/health", "expected_status": 200}, - interval = 60, - timeout = 10, - retry_count = 3, + check1 = self.manager.create_health_check( + tenant_id=self.tenant_id, + name="API 服务健康检查", + target_type="service", + target_id="api-service", + check_type="http", + check_config={"url": "https://api.insightflow.io/health", "expected_status": 200}, + interval=60, + timeout=10, + retry_count=3, ) self.log(f"Created HTTP health check: {check1.name} (ID: {check1.id})") # 创建 TCP 健康检查 - check2 = self.manager.create_health_check( - tenant_id = self.tenant_id, - name = "数据库健康检查", - target_type = "database", - target_id = "postgres-001", - check_type = "tcp", - check_config = {"host": "db.insightflow.io", "port": 5432}, - interval = 30, - timeout = 5, - retry_count = 2, + check2 = self.manager.create_health_check( + tenant_id=self.tenant_id, + name="数据库健康检查", + target_type="database", + target_id="postgres-001", + check_type="tcp", + check_config={"host": "db.insightflow.io", "port": 5432}, + interval=30, + timeout=5, + retry_count=2, ) self.log(f"Created TCP health check: {check2.name} (ID: {check2.id})") # 获取健康检查列表 - checks = self.manager.list_health_checks(self.tenant_id) + checks = self.manager.list_health_checks(self.tenant_id) assert len(checks) >= 2 self.log(f"Listed {len(checks)} health checks") # 执行健康检查(异步) async def run_health_check() -> None: - result = await self.manager.execute_health_check(check1.id) + result = await self.manager.execute_health_check(check1.id) return result # 由于健康检查需要网络,这里只验证方法存在 @@ -500,7 +500,7 @@ class TestOpsManager: self.log("Cleaned up health check test data") except Exception as e: - self.log(f"Health checks test failed: {e}", success = False) + self.log(f"Health checks test failed: {e}", success=False) def test_failover(self) -> None: """测试故障转移""" @@ -508,15 +508,15 @@ class TestOpsManager: try: # 创建故障转移配置 - config = self.manager.create_failover_config( - tenant_id = self.tenant_id, - name = "主备数据中心故障转移", - primary_region = "cn-north-1", - secondary_regions = ["cn-south-1", "cn-east-1"], - failover_trigger = "health_check_failed", - auto_failover = False, - failover_timeout = 300, - health_check_id = None, + config = self.manager.create_failover_config( + tenant_id=self.tenant_id, + name="主备数据中心故障转移", + primary_region="cn-north-1", + secondary_regions=["cn-south-1", "cn-east-1"], + failover_trigger="health_check_failed", + auto_failover=False, + failover_timeout=300, + health_check_id=None, ) self.log(f"Created failover config: {config.name} (ID: {config.id})") @@ -524,13 +524,13 @@ class TestOpsManager: self.log(f" Secondary regions: {config.secondary_regions}") # 获取故障转移配置列表 - configs = self.manager.list_failover_configs(self.tenant_id) + configs = self.manager.list_failover_configs(self.tenant_id) assert len(configs) >= 1 self.log(f"Listed {len(configs)} failover configs") # 发起故障转移 - event = self.manager.initiate_failover( - config_id = config.id, reason = "Primary region health check failed" + event = self.manager.initiate_failover( + config_id=config.id, reason="Primary region health check failed" ) if event: @@ -540,12 +540,12 @@ class TestOpsManager: # 更新故障转移状态 self.manager.update_failover_status(event.id, "completed") - updated_event = self.manager.get_failover_event(event.id) + updated_event = self.manager.get_failover_event(event.id) assert updated_event.status == "completed" self.log("Failover completed") # 获取故障转移事件列表 - events = self.manager.list_failover_events(self.tenant_id) + events = self.manager.list_failover_events(self.tenant_id) self.log(f"Listed {len(events)} failover events") # 清理 @@ -556,7 +556,7 @@ class TestOpsManager: self.log("Cleaned up failover test data") except Exception as e: - self.log(f"Failover test failed: {e}", success = False) + self.log(f"Failover test failed: {e}", success=False) def test_backup(self) -> None: """测试备份与恢复""" @@ -564,17 +564,17 @@ class TestOpsManager: try: # 创建备份任务 - job = self.manager.create_backup_job( - tenant_id = self.tenant_id, - name = "每日数据库备份", - backup_type = "full", - target_type = "database", - target_id = "postgres-main", - schedule = "0 2 * * *", # 每天凌晨2点 - retention_days = 30, - encryption_enabled = True, - compression_enabled = True, - storage_location = "s3://insightflow-backups/", + job = self.manager.create_backup_job( + tenant_id=self.tenant_id, + name="每日数据库备份", + backup_type="full", + target_type="database", + target_id="postgres-main", + schedule="0 2 * * *", # 每天凌晨2点 + retention_days=30, + encryption_enabled=True, + compression_enabled=True, + storage_location="s3://insightflow-backups/", ) self.log(f"Created backup job: {job.name} (ID: {job.id})") @@ -582,12 +582,12 @@ class TestOpsManager: self.log(f" Retention: {job.retention_days} days") # 获取备份任务列表 - jobs = self.manager.list_backup_jobs(self.tenant_id) + jobs = self.manager.list_backup_jobs(self.tenant_id) assert len(jobs) >= 1 self.log(f"Listed {len(jobs)} backup jobs") # 执行备份 - record = self.manager.execute_backup(job.id) + record = self.manager.execute_backup(job.id) if record: self.log(f"Executed backup: {record.id}") @@ -595,11 +595,11 @@ class TestOpsManager: self.log(f" Storage: {record.storage_path}") # 获取备份记录列表 - records = self.manager.list_backup_records(self.tenant_id) + records = self.manager.list_backup_records(self.tenant_id) self.log(f"Listed {len(records)} backup records") # 测试恢复(模拟) - restore_result = self.manager.restore_from_backup(record.id) + restore_result = self.manager.restore_from_backup(record.id) self.log(f"Restore test result: {restore_result}") # 清理 @@ -610,7 +610,7 @@ class TestOpsManager: self.log("Cleaned up backup test data") except Exception as e: - self.log(f"Backup test failed: {e}", success = False) + self.log(f"Backup test failed: {e}", success=False) def test_cost_optimization(self) -> None: """测试成本优化""" @@ -618,27 +618,27 @@ class TestOpsManager: try: # 记录资源利用率数据 - report_date = datetime.now().strftime("%Y-%m-%d") + report_date = datetime.now().strftime("%Y-%m-%d") for i in range(5): self.manager.record_resource_utilization( - tenant_id = self.tenant_id, - resource_type = ResourceType.CPU, - resource_id = f"server-{i:03d}", - utilization_rate = 0.05 + random.random() * 0.1, # 低利用率 - peak_utilization = 0.15, - avg_utilization = 0.08, - idle_time_percent = 0.85, - report_date = report_date, - recommendations = ["Consider downsizing this resource"], + tenant_id=self.tenant_id, + resource_type=ResourceType.CPU, + resource_id=f"server-{i:03d}", + utilization_rate=0.05 + random.random() * 0.1, # 低利用率 + peak_utilization=0.15, + avg_utilization=0.08, + idle_time_percent=0.85, + report_date=report_date, + recommendations=["Consider downsizing this resource"], ) self.log("Recorded 5 resource utilization records") # 生成成本报告 - now = datetime.now() - report = self.manager.generate_cost_report( - tenant_id = self.tenant_id, year = now.year, month = now.month + now = datetime.now() + report = self.manager.generate_cost_report( + tenant_id=self.tenant_id, year=now.year, month=now.month ) self.log(f"Generated cost report: {report.id}") @@ -647,11 +647,11 @@ class TestOpsManager: self.log(f" Anomalies detected: {len(report.anomalies)}") # 检测闲置资源 - idle_resources = self.manager.detect_idle_resources(self.tenant_id) + idle_resources = self.manager.detect_idle_resources(self.tenant_id) self.log(f"Detected {len(idle_resources)} idle resources") # 获取闲置资源列表 - idle_list = self.manager.get_idle_resources(self.tenant_id) + idle_list = self.manager.get_idle_resources(self.tenant_id) for resource in idle_list: self.log( f" Idle resource: {resource.resource_name} (est. cost: { @@ -660,7 +660,7 @@ class TestOpsManager: ) # 生成成本优化建议 - suggestions = self.manager.generate_cost_optimization_suggestions(self.tenant_id) + suggestions = self.manager.generate_cost_optimization_suggestions(self.tenant_id) self.log(f"Generated {len(suggestions)} cost optimization suggestions") for suggestion in suggestions: @@ -672,12 +672,12 @@ class TestOpsManager: self.log(f" Difficulty: {suggestion.difficulty}") # 获取优化建议列表 - all_suggestions = self.manager.get_cost_optimization_suggestions(self.tenant_id) + all_suggestions = self.manager.get_cost_optimization_suggestions(self.tenant_id) self.log(f"Listed {len(all_suggestions)} optimization suggestions") # 应用优化建议 if all_suggestions: - applied = self.manager.apply_cost_optimization_suggestion(all_suggestions[0].id) + applied = self.manager.apply_cost_optimization_suggestion(all_suggestions[0].id) if applied: self.log(f"Applied optimization suggestion: {applied.title}") assert applied.is_applied @@ -698,7 +698,7 @@ class TestOpsManager: self.log("Cleaned up cost optimization test data") except Exception as e: - self.log(f"Cost optimization test failed: {e}", success = False) + self.log(f"Cost optimization test failed: {e}", success=False) def print_summary(self) -> None: """打印测试总结""" @@ -706,9 +706,9 @@ class TestOpsManager: print("Test Summary") print(" = " * 60) - total = len(self.test_results) - passed = sum(1 for _, success in self.test_results if success) - failed = total - passed + total = len(self.test_results) + passed = sum(1 for _, success in self.test_results if success) + failed = total - passed print(f"Total tests: {total}") print(f"Passed: {passed} ✅") @@ -725,7 +725,7 @@ class TestOpsManager: def main() -> None: """主函数""" - test = TestOpsManager() + test = TestOpsManager() test.run_all_tests() diff --git a/backend/tingwu_client.py b/backend/tingwu_client.py index de330f5..4b35601 100644 --- a/backend/tingwu_client.py +++ b/backend/tingwu_client.py @@ -11,18 +11,18 @@ from typing import Any class TingwuClient: def __init__(self) -> None: - self.access_key = os.getenv("ALI_ACCESS_KEY", "") - self.secret_key = os.getenv("ALI_SECRET_KEY", "") - self.endpoint = "https://tingwu.cn-beijing.aliyuncs.com" + self.access_key = os.getenv("ALI_ACCESS_KEY", "") + self.secret_key = os.getenv("ALI_SECRET_KEY", "") + self.endpoint = "https://tingwu.cn-beijing.aliyuncs.com" if not self.access_key or not self.secret_key: raise ValueError("ALI_ACCESS_KEY and ALI_SECRET_KEY required") def _sign_request( - self, method: str, uri: str, query: str = "", body: str = "" + self, method: str, uri: str, query: str = "", body: str = "" ) -> dict[str, str]: """阿里云签名 V3""" - timestamp = datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%SZ") + timestamp = datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%SZ") # 简化签名,实际生产需要完整实现 # 这里使用基础认证头 @@ -34,7 +34,7 @@ class TingwuClient: "Authorization": f"ACS3-HMAC-SHA256 Credential = {self.access_key}/acs/tingwu/cn-beijing", } - def create_task(self, audio_url: str, language: str = "zh") -> str: + def create_task(self, audio_url: str, language: str = "zh") -> str: """创建听悟任务""" try: # 导入移到文件顶部会导致循环导入,保持在这里 @@ -42,23 +42,23 @@ class TingwuClient: from alibabacloud_tingwu20230930 import models as tingwu_models from alibabacloud_tingwu20230930.client import Client as TingwuSDKClient - config = open_api_models.Config( - access_key_id = self.access_key, access_key_secret = self.secret_key + config = open_api_models.Config( + access_key_id=self.access_key, access_key_secret=self.secret_key ) - config.endpoint = "tingwu.cn-beijing.aliyuncs.com" - client = TingwuSDKClient(config) + config.endpoint = "tingwu.cn-beijing.aliyuncs.com" + client = TingwuSDKClient(config) - request = tingwu_models.CreateTaskRequest( - type = "offline", - input = tingwu_models.Input(source = "OSS", file_url = audio_url), - parameters = tingwu_models.Parameters( - transcription = tingwu_models.Transcription( - diarization_enabled = True, sentence_max_length = 20 + request = tingwu_models.CreateTaskRequest( + type="offline", + input=tingwu_models.Input(source="OSS", file_url=audio_url), + parameters=tingwu_models.Parameters( + transcription=tingwu_models.Transcription( + diarization_enabled=True, sentence_max_length=20 ) ), ) - response = client.create_task(request) + response = client.create_task(request) if response.body.code == "0": return response.body.data.task_id else: @@ -73,26 +73,29 @@ class TingwuClient: return f"mock_task_{int(time.time())}" def get_task_result( - self, task_id: str, max_retries: int = 60, interval: int = 5 + self, task_id: str, max_retries: int = 60, interval: int = 5 ) -> dict[str, Any]: """获取任务结果""" try: # 导入移到文件顶部会导致循环导入,保持在这里 + from alibabacloud_tingwu20230930 import models as tingwu_models + from alibabacloud_tingwu20230930.client import Client as TingwuSDKClient + from alibabacloud_openapi_util import models as open_api_models - config = open_api_models.Config( - access_key_id = self.access_key, access_key_secret = self.secret_key + config = open_api_models.Config( + access_key_id=self.access_key, access_key_secret=self.secret_key ) - config.endpoint = "tingwu.cn-beijing.aliyuncs.com" - client = TingwuSDKClient(config) + config.endpoint = "tingwu.cn-beijing.aliyuncs.com" + client = TingwuSDKClient(config) for i in range(max_retries): - request = tingwu_models.GetTaskInfoRequest() - response = client.get_task_info(task_id, request) + request = tingwu_models.GetTaskInfoRequest() + response = client.get_task_info(task_id, request) if response.body.code != "0": raise RuntimeError(f"Query failed: {response.body.message}") - status = response.body.data.task_status + status = response.body.data.task_status if status == "SUCCESS": return self._parse_result(response.body.data) @@ -113,11 +116,11 @@ class TingwuClient: def _parse_result(self, data) -> dict[str, Any]: """解析结果""" - result = data.result - transcription = result.transcription + result = data.result + transcription = result.transcription - full_text = "" - segments = [] + full_text = "" + segments = [] if transcription.paragraphs: for para in transcription.paragraphs: @@ -150,8 +153,8 @@ class TingwuClient: ], } - def transcribe(self, audio_url: str, language: str = "zh") -> dict[str, Any]: + def transcribe(self, audio_url: str, language: str = "zh") -> dict[str, Any]: """一键转录""" - task_id = self.create_task(audio_url, language) + task_id = self.create_task(audio_url, language) print(f"Tingwu task: {task_id}") return self.get_task_result(task_id) diff --git a/backend/workflow_manager.py b/backend/workflow_manager.py index e6f29ce..033ff0a 100644 --- a/backend/workflow_manager.py +++ b/backend/workflow_manager.py @@ -27,56 +27,55 @@ from apscheduler.events import EVENT_JOB_ERROR, EVENT_JOB_EXECUTED from apscheduler.schedulers.asyncio import AsyncIOScheduler from apscheduler.triggers.cron import CronTrigger from apscheduler.triggers.interval import IntervalTrigger -from workflow_manager import WorkflowManager import urllib.parse # Constants -UUID_LENGTH = 8 # UUID 截断长度 -DEFAULT_TIMEOUT = 300 # 默认超时时间(秒) -DEFAULT_RETRY_COUNT = 3 # 默认重试次数 -DEFAULT_RETRY_DELAY = 5 # 默认重试延迟(秒) +UUID_LENGTH = 8 # UUID 截断长度 +DEFAULT_TIMEOUT = 300 # 默认超时时间(秒) +DEFAULT_RETRY_COUNT = 3 # 默认重试次数 +DEFAULT_RETRY_DELAY = 5 # 默认重试延迟(秒) # Configure logging -logging.basicConfig(level = logging.INFO) -logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) class WorkflowStatus(Enum): """工作流状态""" - ACTIVE = "active" - PAUSED = "paused" - ERROR = "error" - COMPLETED = "completed" + ACTIVE = "active" + PAUSED = "paused" + ERROR = "error" + COMPLETED = "completed" class WorkflowType(Enum): """工作流类型""" - AUTO_ANALYZE = "auto_analyze" # 自动分析新文件 - AUTO_ALIGN = "auto_align" # 自动实体对齐 - AUTO_RELATION = "auto_relation" # 自动关系发现 - SCHEDULED_REPORT = "scheduled_report" # 定时报告 - CUSTOM = "custom" # 自定义工作流 + AUTO_ANALYZE = "auto_analyze" # 自动分析新文件 + AUTO_ALIGN = "auto_align" # 自动实体对齐 + AUTO_RELATION = "auto_relation" # 自动关系发现 + SCHEDULED_REPORT = "scheduled_report" # 定时报告 + CUSTOM = "custom" # 自定义工作流 class WebhookType(Enum): """Webhook 类型""" - FEISHU = "feishu" - DINGTALK = "dingtalk" - SLACK = "slack" - CUSTOM = "custom" + FEISHU = "feishu" + DINGTALK = "dingtalk" + SLACK = "slack" + CUSTOM = "custom" class TaskStatus(Enum): """任务执行状态""" - PENDING = "pending" - RUNNING = "running" - SUCCESS = "success" - FAILED = "failed" - CANCELLED = "cancelled" + PENDING = "pending" + RUNNING = "running" + SUCCESS = "success" + FAILED = "failed" + CANCELLED = "cancelled" @dataclass @@ -87,20 +86,20 @@ class WorkflowTask: workflow_id: str name: str task_type: str # analyze, align, discover_relations, notify, custom - config: dict = field(default_factory = dict) - order: int = 0 - depends_on: list[str] = field(default_factory = list) - timeout_seconds: int = 300 - retry_count: int = 3 - retry_delay: int = 5 - created_at: str = "" - updated_at: str = "" + config: dict = field(default_factory=dict) + order: int = 0 + depends_on: list[str] = field(default_factory=list) + timeout_seconds: int = 300 + retry_count: int = 3 + retry_delay: int = 5 + created_at: str = "" + updated_at: str = "" def __post_init__(self) -> None: if not self.created_at: - self.created_at = datetime.now().isoformat() + self.created_at = datetime.now().isoformat() if not self.updated_at: - self.updated_at = self.created_at + self.updated_at = self.created_at @dataclass @@ -111,21 +110,21 @@ class WebhookConfig: name: str webhook_type: str # feishu, dingtalk, slack, custom url: str - secret: str = "" # 用于签名验证 - headers: dict = field(default_factory = dict) - template: str = "" # 消息模板 - is_active: bool = True - created_at: str = "" - updated_at: str = "" - last_used_at: str | None = None - success_count: int = 0 - fail_count: int = 0 + secret: str = "" # 用于签名验证 + headers: dict = field(default_factory=dict) + template: str = "" # 消息模板 + is_active: bool = True + created_at: str = "" + updated_at: str = "" + last_used_at: str | None = None + success_count: int = 0 + fail_count: int = 0 def __post_init__(self) -> None: if not self.created_at: - self.created_at = datetime.now().isoformat() + self.created_at = datetime.now().isoformat() if not self.updated_at: - self.updated_at = self.created_at + self.updated_at = self.created_at @dataclass @@ -137,25 +136,25 @@ class Workflow: description: str workflow_type: str project_id: str - status: str = "active" - schedule: str | None = None # cron expression or interval - schedule_type: str = "manual" # manual, cron, interval - config: dict = field(default_factory = dict) - webhook_ids: list[str] = field(default_factory = list) - is_active: bool = True - created_at: str = "" - updated_at: str = "" - last_run_at: str | None = None - next_run_at: str | None = None - run_count: int = 0 - success_count: int = 0 - fail_count: int = 0 + status: str = "active" + schedule: str | None = None # cron expression or interval + schedule_type: str = "manual" # manual, cron, interval + config: dict = field(default_factory=dict) + webhook_ids: list[str] = field(default_factory=list) + is_active: bool = True + created_at: str = "" + updated_at: str = "" + last_run_at: str | None = None + next_run_at: str | None = None + run_count: int = 0 + success_count: int = 0 + fail_count: int = 0 def __post_init__(self) -> None: if not self.created_at: - self.created_at = datetime.now().isoformat() + self.created_at = datetime.now().isoformat() if not self.updated_at: - self.updated_at = self.created_at + self.updated_at = self.created_at @dataclass @@ -164,31 +163,31 @@ class WorkflowLog: id: str workflow_id: str - task_id: str | None = None - status: str = "pending" # pending, running, success, failed, cancelled - start_time: str | None = None - end_time: str | None = None - duration_ms: int = 0 - input_data: dict = field(default_factory = dict) - output_data: dict = field(default_factory = dict) - error_message: str = "" - created_at: str = "" + task_id: str | None = None + status: str = "pending" # pending, running, success, failed, cancelled + start_time: str | None = None + end_time: str | None = None + duration_ms: int = 0 + input_data: dict = field(default_factory=dict) + output_data: dict = field(default_factory=dict) + error_message: str = "" + created_at: str = "" def __post_init__(self) -> None: if not self.created_at: - self.created_at = datetime.now().isoformat() + self.created_at = datetime.now().isoformat() class WebhookNotifier: """Webhook 通知器 - 支持飞书、钉钉、Slack""" def __init__(self) -> None: - self.http_client = httpx.AsyncClient(timeout = 30.0) + self.http_client = httpx.AsyncClient(timeout=30.0) async def send(self, config: WebhookConfig, message: dict) -> bool: """发送 Webhook 通知""" try: - webhook_type = WebhookType(config.webhook_type) + webhook_type = WebhookType(config.webhook_type) if webhook_type == WebhookType.FEISHU: return await self._send_feishu(config, message) @@ -205,20 +204,20 @@ class WebhookNotifier: async def _send_feishu(self, config: WebhookConfig, message: dict) -> bool: """发送飞书通知""" - timestamp = str(int(datetime.now().timestamp())) + timestamp = str(int(datetime.now().timestamp())) # 签名计算 if config.secret: - string_to_sign = f"{timestamp}\n{config.secret}" - hmac_code = hmac.new(string_to_sign.encode("utf-8"), digestmod = hashlib.sha256).digest() - sign = base64.b64encode(hmac_code).decode("utf-8") + string_to_sign = f"{timestamp}\n{config.secret}" + hmac_code = hmac.new(string_to_sign.encode("utf-8"), digestmod=hashlib.sha256).digest() + sign = base64.b64encode(hmac_code).decode("utf-8") else: - sign = "" + sign = "" # 构建消息体 if "content" in message: # 文本消息 - payload = { + payload = { "timestamp": timestamp, "sign": sign, "msg_type": "text", @@ -226,7 +225,7 @@ class WebhookNotifier: } elif "title" in message: # 富文本消息 - payload = { + payload = { "timestamp": timestamp, "sign": sign, "msg_type": "post", @@ -241,47 +240,47 @@ class WebhookNotifier: } else: # 卡片消息 - payload = { + payload = { "timestamp": timestamp, "sign": sign, "msg_type": "interactive", "card": message.get("card", {}), } - headers = {"Content-Type": "application/json", **config.headers} + headers = {"Content-Type": "application/json", **config.headers} - response = await self.http_client.post(config.url, json = payload, headers = headers) + response = await self.http_client.post(config.url, json=payload, headers=headers) response.raise_for_status() - result = response.json() + result = response.json() return result.get("code") == 0 async def _send_dingtalk(self, config: WebhookConfig, message: dict) -> bool: """发送钉钉通知""" - timestamp = str(round(datetime.now().timestamp() * 1000)) + timestamp = str(round(datetime.now().timestamp() * 1000)) # 签名计算 if config.secret: - secret_enc = config.secret.encode("utf-8") - string_to_sign = f"{timestamp}\n{config.secret}" - hmac_code = hmac.new( - secret_enc, string_to_sign.encode("utf-8"), digestmod = hashlib.sha256 + secret_enc = config.secret.encode("utf-8") + string_to_sign = f"{timestamp}\n{config.secret}" + hmac_code = hmac.new( + secret_enc, string_to_sign.encode("utf-8"), digestmod=hashlib.sha256 ).digest() - sign = urllib.parse.quote_plus(base64.b64encode(hmac_code)) - url = f"{config.url}×tamp = {timestamp}&sign = {sign}" + sign = urllib.parse.quote_plus(base64.b64encode(hmac_code)) + url = f"{config.url}×tamp = {timestamp}&sign = {sign}" else: - url = config.url + url = config.url # 构建消息体 if "content" in message: - payload = {"msgtype": "text", "text": {"content": message["content"]}} + payload = {"msgtype": "text", "text": {"content": message["content"]}} elif "title" in message: - payload = { + payload = { "msgtype": "markdown", "markdown": {"title": message["title"], "text": message.get("markdown", "")}, } elif "link" in message: - payload = { + payload = { "msgtype": "link", "link": { "text": message.get("text", ""), @@ -291,41 +290,41 @@ class WebhookNotifier: }, } else: - payload = {"msgtype": "action_card", "action_card": message.get("action_card", {})} + payload = {"msgtype": "action_card", "action_card": message.get("action_card", {})} - headers = {"Content-Type": "application/json", **config.headers} + headers = {"Content-Type": "application/json", **config.headers} - response = await self.http_client.post(url, json = payload, headers = headers) + response = await self.http_client.post(url, json=payload, headers=headers) response.raise_for_status() - result = response.json() + result = response.json() return result.get("errcode") == 0 async def _send_slack(self, config: WebhookConfig, message: dict) -> bool: """发送 Slack 通知""" # Slack 直接支持标准 webhook 格式 - payload = { + payload = { "text": message.get("content", message.get("text", "")), } if "blocks" in message: - payload["blocks"] = message["blocks"] + payload["blocks"] = message["blocks"] if "attachments" in message: - payload["attachments"] = message["attachments"] + payload["attachments"] = message["attachments"] - headers = {"Content-Type": "application/json", **config.headers} + headers = {"Content-Type": "application/json", **config.headers} - response = await self.http_client.post(config.url, json = payload, headers = headers) + response = await self.http_client.post(config.url, json=payload, headers=headers) response.raise_for_status() return response.text == "ok" async def _send_custom(self, config: WebhookConfig, message: dict) -> bool: """发送自定义 Webhook 通知""" - headers = {"Content-Type": "application/json", **config.headers} + headers = {"Content-Type": "application/json", **config.headers} - response = await self.http_client.post(config.url, json = message, headers = headers) + response = await self.http_client.post(config.url, json=message, headers=headers) response.raise_for_status() return True @@ -339,16 +338,16 @@ class WorkflowManager: """工作流管理器 - 核心管理类""" # 默认配置常量 - DEFAULT_TIMEOUT: int = 300 - DEFAULT_RETRY_COUNT: int = 3 - DEFAULT_RETRY_DELAY: int = 5 + DEFAULT_TIMEOUT: int = 300 + DEFAULT_RETRY_COUNT: int = 3 + DEFAULT_RETRY_DELAY: int = 5 - def __init__(self, db_manager = None) -> None: - self.db = db_manager - self.scheduler = AsyncIOScheduler() - self.notifier = WebhookNotifier() - self._task_handlers: dict[str, Callable] = {} - self._running_tasks: dict[str, asyncio.Task] = {} + def __init__(self, db_manager=None) -> None: + self.db = db_manager + self.scheduler = AsyncIOScheduler() + self.notifier = WebhookNotifier() + self._task_handlers: dict[str, Callable] = {} + self._running_tasks: dict[str, asyncio.Task] = {} self._setup_default_handlers() # 添加调度器事件监听 @@ -356,7 +355,7 @@ class WorkflowManager: def _setup_default_handlers(self) -> None: """设置默认的任务处理器""" - self._task_handlers = { + self._task_handlers = { "analyze": self._handle_analyze_task, "align": self._handle_align_task, "discover_relations": self._handle_discover_relations_task, @@ -366,7 +365,7 @@ class WorkflowManager: def register_task_handler(self, task_type: str, handler: Callable) -> None: """注册自定义任务处理器""" - self._task_handlers[task_type] = handler + self._task_handlers[task_type] = handler def start(self) -> None: """启动工作流管理器""" @@ -381,13 +380,13 @@ class WorkflowManager: def stop(self) -> None: """停止工作流管理器""" if self.scheduler.running: - self.scheduler.shutdown(wait = True) + self.scheduler.shutdown(wait=True) logger.info("Workflow scheduler stopped") async def _load_and_schedule_workflows(self) -> None: """从数据库加载并调度所有活跃工作流""" try: - workflows = self.list_workflows(status = "active") + workflows = self.list_workflows(status="active") for workflow in workflows: if workflow.schedule and workflow.is_active: self._schedule_workflow(workflow) @@ -396,7 +395,7 @@ class WorkflowManager: def _schedule_workflow(self, workflow: Workflow) -> None: """调度工作流""" - job_id = f"workflow_{workflow.id}" + job_id = f"workflow_{workflow.id}" # 移除已存在的任务 if self.scheduler.get_job(job_id): @@ -404,22 +403,22 @@ class WorkflowManager: if workflow.schedule_type == "cron": # Cron 表达式调度 - trigger = CronTrigger.from_crontab(workflow.schedule) + trigger = CronTrigger.from_crontab(workflow.schedule) elif workflow.schedule_type == "interval": # 间隔调度 - interval_minutes = int(workflow.schedule) - trigger = IntervalTrigger(minutes = interval_minutes) + interval_minutes = int(workflow.schedule) + trigger = IntervalTrigger(minutes=interval_minutes) else: return self.scheduler.add_job( - func = self._execute_workflow_job, - trigger = trigger, - id = job_id, - args = [workflow.id], - replace_existing = True, - max_instances = 1, - coalesce = True, + func=self._execute_workflow_job, + trigger=trigger, + id=job_id, + args=[workflow.id], + replace_existing=True, + max_instances=1, + coalesce=True, ) logger.info( @@ -444,7 +443,7 @@ class WorkflowManager: def create_workflow(self, workflow: Workflow) -> Workflow: """创建工作流""" - conn = self.db.get_conn() + conn = self.db.get_conn() try: conn.execute( """INSERT INTO workflows @@ -486,9 +485,9 @@ class WorkflowManager: def get_workflow(self, workflow_id: str) -> Workflow | None: """获取工作流""" - conn = self.db.get_conn() + conn = self.db.get_conn() try: - row = conn.execute("SELECT * FROM workflows WHERE id = ?", (workflow_id, )).fetchone() + row = conn.execute("SELECT * FROM workflows WHERE id = ?", (workflow_id, )).fetchone() if not row: return None @@ -498,13 +497,13 @@ class WorkflowManager: conn.close() def list_workflows( - self, project_id: str = None, status: str = None, workflow_type: str = None + self, project_id: str = None, status: str = None, workflow_type: str = None ) -> list[Workflow]: """列出工作流""" - conn = self.db.get_conn() + conn = self.db.get_conn() try: - conditions = [] - params = [] + conditions = [] + params = [] if project_id: conditions.append("project_id = ?") @@ -516,9 +515,9 @@ class WorkflowManager: conditions.append("workflow_type = ?") params.append(workflow_type) - where_clause = " AND ".join(conditions) if conditions else "1 = 1" + where_clause = " AND ".join(conditions) if conditions else "1 = 1" - rows = conn.execute( + rows = conn.execute( f"SELECT * FROM workflows WHERE {where_clause} ORDER BY created_at DESC", params ).fetchall() @@ -528,9 +527,9 @@ class WorkflowManager: def update_workflow(self, workflow_id: str, **kwargs) -> Workflow | None: """更新工作流""" - conn = self.db.get_conn() + conn = self.db.get_conn() try: - allowed_fields = [ + allowed_fields = [ "name", "description", "status", @@ -540,8 +539,8 @@ class WorkflowManager: "config", "webhook_ids", ] - updates = [] - values = [] + updates = [] + values = [] for f in allowed_fields: if f in kwargs: @@ -558,16 +557,16 @@ class WorkflowManager: values.append(datetime.now().isoformat()) values.append(workflow_id) - query = f"UPDATE workflows SET {', '.join(updates)} WHERE id = ?" + query = f"UPDATE workflows SET {', '.join(updates)} WHERE id = ?" conn.execute(query, values) conn.commit() # 重新调度 - workflow = self.get_workflow(workflow_id) + workflow = self.get_workflow(workflow_id) if workflow and workflow.schedule and workflow.is_active: self._schedule_workflow(workflow) elif workflow and not workflow.is_active: - job_id = f"workflow_{workflow_id}" + job_id = f"workflow_{workflow_id}" if self.scheduler.get_job(job_id): self.scheduler.remove_job(job_id) @@ -577,10 +576,10 @@ class WorkflowManager: def delete_workflow(self, workflow_id: str) -> bool: """删除工作流""" - conn = self.db.get_conn() + conn = self.db.get_conn() try: # 移除调度 - job_id = f"workflow_{workflow_id}" + job_id = f"workflow_{workflow_id}" if self.scheduler.get_job(job_id): self.scheduler.remove_job(job_id) @@ -598,31 +597,31 @@ class WorkflowManager: def _row_to_workflow(self, row) -> Workflow: """将数据库行转换为 Workflow 对象""" return Workflow( - id = row["id"], - name = row["name"], - description = row["description"] or "", - workflow_type = row["workflow_type"], - project_id = row["project_id"], - status = row["status"], - schedule = row["schedule"], - schedule_type = row["schedule_type"], - config = json.loads(row["config"]) if row["config"] else {}, - webhook_ids = json.loads(row["webhook_ids"]) if row["webhook_ids"] else [], - is_active = bool(row["is_active"]), - created_at = row["created_at"], - updated_at = row["updated_at"], - last_run_at = row["last_run_at"], - next_run_at = row["next_run_at"], - run_count = row["run_count"] or 0, - success_count = row["success_count"] or 0, - fail_count = row["fail_count"] or 0, + id=row["id"], + name=row["name"], + description=row["description"] or "", + workflow_type=row["workflow_type"], + project_id=row["project_id"], + status=row["status"], + schedule=row["schedule"], + schedule_type=row["schedule_type"], + config=json.loads(row["config"]) if row["config"] else {}, + webhook_ids=json.loads(row["webhook_ids"]) if row["webhook_ids"] else [], + is_active=bool(row["is_active"]), + created_at=row["created_at"], + updated_at=row["updated_at"], + last_run_at=row["last_run_at"], + next_run_at=row["next_run_at"], + run_count=row["run_count"] or 0, + success_count=row["success_count"] or 0, + fail_count=row["fail_count"] or 0, ) # ==================== Workflow Task CRUD ==================== def create_task(self, task: WorkflowTask) -> WorkflowTask: """创建工作流任务""" - conn = self.db.get_conn() + conn = self.db.get_conn() try: conn.execute( """INSERT INTO workflow_tasks @@ -652,9 +651,9 @@ class WorkflowManager: def get_task(self, task_id: str) -> WorkflowTask | None: """获取任务""" - conn = self.db.get_conn() + conn = self.db.get_conn() try: - row = conn.execute("SELECT * FROM workflow_tasks WHERE id = ?", (task_id, )).fetchone() + row = conn.execute("SELECT * FROM workflow_tasks WHERE id = ?", (task_id, )).fetchone() if not row: return None @@ -665,9 +664,9 @@ class WorkflowManager: def list_tasks(self, workflow_id: str) -> list[WorkflowTask]: """列出工作流的所有任务""" - conn = self.db.get_conn() + conn = self.db.get_conn() try: - rows = conn.execute( + rows = conn.execute( "SELECT * FROM workflow_tasks WHERE workflow_id = ? ORDER BY task_order", (workflow_id, ), ).fetchall() @@ -678,9 +677,9 @@ class WorkflowManager: def update_task(self, task_id: str, **kwargs) -> WorkflowTask | None: """更新任务""" - conn = self.db.get_conn() + conn = self.db.get_conn() try: - allowed_fields = [ + allowed_fields = [ "name", "task_type", "config", @@ -690,8 +689,8 @@ class WorkflowManager: "retry_count", "retry_delay", ] - updates = [] - values = [] + updates = [] + values = [] for f in allowed_fields: if f in kwargs: @@ -708,7 +707,7 @@ class WorkflowManager: values.append(datetime.now().isoformat()) values.append(task_id) - query = f"UPDATE workflow_tasks SET {', '.join(updates)} WHERE id = ?" + query = f"UPDATE workflow_tasks SET {', '.join(updates)} WHERE id = ?" conn.execute(query, values) conn.commit() @@ -718,7 +717,7 @@ class WorkflowManager: def delete_task(self, task_id: str) -> bool: """删除任务""" - conn = self.db.get_conn() + conn = self.db.get_conn() try: conn.execute("DELETE FROM workflow_tasks WHERE id = ?", (task_id, )) conn.commit() @@ -729,25 +728,25 @@ class WorkflowManager: def _row_to_task(self, row) -> WorkflowTask: """将数据库行转换为 WorkflowTask 对象""" return WorkflowTask( - id = row["id"], - workflow_id = row["workflow_id"], - name = row["name"], - task_type = row["task_type"], - config = json.loads(row["config"]) if row["config"] else {}, - order = row["task_order"] or 0, - depends_on = json.loads(row["depends_on"]) if row["depends_on"] else [], - timeout_seconds = row["timeout_seconds"] or 300, - retry_count = row["retry_count"] or 3, - retry_delay = row["retry_delay"] or 5, - created_at = row["created_at"], - updated_at = row["updated_at"], + id=row["id"], + workflow_id=row["workflow_id"], + name=row["name"], + task_type=row["task_type"], + config=json.loads(row["config"]) if row["config"] else {}, + order=row["task_order"] or 0, + depends_on=json.loads(row["depends_on"]) if row["depends_on"] else [], + timeout_seconds=row["timeout_seconds"] or 300, + retry_count=row["retry_count"] or 3, + retry_delay=row["retry_delay"] or 5, + created_at=row["created_at"], + updated_at=row["updated_at"], ) # ==================== Webhook Config CRUD ==================== def create_webhook(self, webhook: WebhookConfig) -> WebhookConfig: """创建 Webhook 配置""" - conn = self.db.get_conn() + conn = self.db.get_conn() try: conn.execute( """INSERT INTO webhook_configs @@ -778,9 +777,9 @@ class WorkflowManager: def get_webhook(self, webhook_id: str) -> WebhookConfig | None: """获取 Webhook 配置""" - conn = self.db.get_conn() + conn = self.db.get_conn() try: - row = conn.execute( + row = conn.execute( "SELECT * FROM webhook_configs WHERE id = ?", (webhook_id, ) ).fetchone() @@ -793,9 +792,9 @@ class WorkflowManager: def list_webhooks(self) -> list[WebhookConfig]: """列出所有 Webhook 配置""" - conn = self.db.get_conn() + conn = self.db.get_conn() try: - rows = conn.execute("SELECT * FROM webhook_configs ORDER BY created_at DESC").fetchall() + rows = conn.execute("SELECT * FROM webhook_configs ORDER BY created_at DESC").fetchall() return [self._row_to_webhook(row) for row in rows] finally: @@ -803,9 +802,9 @@ class WorkflowManager: def update_webhook(self, webhook_id: str, **kwargs) -> WebhookConfig | None: """更新 Webhook 配置""" - conn = self.db.get_conn() + conn = self.db.get_conn() try: - allowed_fields = [ + allowed_fields = [ "name", "webhook_type", "url", @@ -814,8 +813,8 @@ class WorkflowManager: "template", "is_active", ] - updates = [] - values = [] + updates = [] + values = [] for f in allowed_fields: if f in kwargs: @@ -832,7 +831,7 @@ class WorkflowManager: values.append(datetime.now().isoformat()) values.append(webhook_id) - query = f"UPDATE webhook_configs SET {', '.join(updates)} WHERE id = ?" + query = f"UPDATE webhook_configs SET {', '.join(updates)} WHERE id = ?" conn.execute(query, values) conn.commit() @@ -842,7 +841,7 @@ class WorkflowManager: def delete_webhook(self, webhook_id: str) -> bool: """删除 Webhook 配置""" - conn = self.db.get_conn() + conn = self.db.get_conn() try: conn.execute("DELETE FROM webhook_configs WHERE id = ?", (webhook_id, )) conn.commit() @@ -852,7 +851,7 @@ class WorkflowManager: def update_webhook_stats(self, webhook_id: str, success: bool) -> None: """更新 Webhook 统计""" - conn = self.db.get_conn() + conn = self.db.get_conn() try: if success: conn.execute( @@ -875,26 +874,26 @@ class WorkflowManager: def _row_to_webhook(self, row) -> WebhookConfig: """将数据库行转换为 WebhookConfig 对象""" return WebhookConfig( - id = row["id"], - name = row["name"], - webhook_type = row["webhook_type"], - url = row["url"], - secret = row["secret"] or "", - headers = json.loads(row["headers"]) if row["headers"] else {}, - template = row["template"] or "", - is_active = bool(row["is_active"]), - created_at = row["created_at"], - updated_at = row["updated_at"], - last_used_at = row["last_used_at"], - success_count = row["success_count"] or 0, - fail_count = row["fail_count"] or 0, + id=row["id"], + name=row["name"], + webhook_type=row["webhook_type"], + url=row["url"], + secret=row["secret"] or "", + headers=json.loads(row["headers"]) if row["headers"] else {}, + template=row["template"] or "", + is_active=bool(row["is_active"]), + created_at=row["created_at"], + updated_at=row["updated_at"], + last_used_at=row["last_used_at"], + success_count=row["success_count"] or 0, + fail_count=row["fail_count"] or 0, ) # ==================== Workflow Log ==================== def create_log(self, log: WorkflowLog) -> WorkflowLog: """创建工作流日志""" - conn = self.db.get_conn() + conn = self.db.get_conn() try: conn.execute( """INSERT INTO workflow_logs @@ -922,11 +921,11 @@ class WorkflowManager: def update_log(self, log_id: str, **kwargs) -> WorkflowLog | None: """更新工作流日志""" - conn = self.db.get_conn() + conn = self.db.get_conn() try: - allowed_fields = ["status", "end_time", "duration_ms", "output_data", "error_message"] - updates = [] - values = [] + allowed_fields = ["status", "end_time", "duration_ms", "output_data", "error_message"] + updates = [] + values = [] for f in allowed_fields: if f in kwargs: @@ -940,7 +939,7 @@ class WorkflowManager: return None values.append(log_id) - query = f"UPDATE workflow_logs SET {', '.join(updates)} WHERE id = ?" + query = f"UPDATE workflow_logs SET {', '.join(updates)} WHERE id = ?" conn.execute(query, values) conn.commit() @@ -950,9 +949,9 @@ class WorkflowManager: def get_log(self, log_id: str) -> WorkflowLog | None: """获取日志""" - conn = self.db.get_conn() + conn = self.db.get_conn() try: - row = conn.execute("SELECT * FROM workflow_logs WHERE id = ?", (log_id, )).fetchone() + row = conn.execute("SELECT * FROM workflow_logs WHERE id = ?", (log_id, )).fetchone() if not row: return None @@ -963,17 +962,17 @@ class WorkflowManager: def list_logs( self, - workflow_id: str = None, - task_id: str = None, - status: str = None, - limit: int = 100, - offset: int = 0, + workflow_id: str = None, + task_id: str = None, + status: str = None, + limit: int = 100, + offset: int = 0, ) -> list[WorkflowLog]: """列出工作流日志""" - conn = self.db.get_conn() + conn = self.db.get_conn() try: - conditions = [] - params = [] + conditions = [] + params = [] if workflow_id: conditions.append("workflow_id = ?") @@ -985,9 +984,9 @@ class WorkflowManager: conditions.append("status = ?") params.append(status) - where_clause = " AND ".join(conditions) if conditions else "1 = 1" + where_clause = " AND ".join(conditions) if conditions else "1 = 1" - rows = conn.execute( + rows = conn.execute( f"""SELECT * FROM workflow_logs WHERE {where_clause} ORDER BY created_at DESC @@ -999,32 +998,32 @@ class WorkflowManager: finally: conn.close() - def get_workflow_stats(self, workflow_id: str, days: int = 30) -> dict: + def get_workflow_stats(self, workflow_id: str, days: int = 30) -> dict: """获取工作流统计""" - conn = self.db.get_conn() + conn = self.db.get_conn() try: - since = (datetime.now() - timedelta(days = days)).isoformat() + since = (datetime.now() - timedelta(days=days)).isoformat() # 总执行次数 - total = conn.execute( + total = conn.execute( "SELECT COUNT(*) FROM workflow_logs WHERE workflow_id = ? AND created_at > ?", (workflow_id, since), ).fetchone()[0] # 成功次数 - success = conn.execute( + success = conn.execute( "SELECT COUNT(*) FROM workflow_logs WHERE workflow_id = ? AND status = 'success' AND created_at > ?", (workflow_id, since), ).fetchone()[0] # 失败次数 - failed = conn.execute( + failed = conn.execute( "SELECT COUNT(*) FROM workflow_logs WHERE workflow_id = ? AND status = 'failed' AND created_at > ?", (workflow_id, since), ).fetchone()[0] # 平均执行时间 - avg_duration = ( + avg_duration = ( conn.execute( "SELECT AVG(duration_ms) FROM workflow_logs WHERE workflow_id = ? AND created_at > ?", (workflow_id, since), @@ -1033,7 +1032,7 @@ class WorkflowManager: ) # 每日统计 - daily = conn.execute( + daily = conn.execute( """SELECT DATE(created_at) as date, COUNT(*) as count, SUM(CASE WHEN status = 'success' THEN 1 ELSE 0 END) as success @@ -1060,24 +1059,24 @@ class WorkflowManager: def _row_to_log(self, row) -> WorkflowLog: """将数据库行转换为 WorkflowLog 对象""" return WorkflowLog( - id = row["id"], - workflow_id = row["workflow_id"], - task_id = row["task_id"], - status = row["status"], - start_time = row["start_time"], - end_time = row["end_time"], - duration_ms = row["duration_ms"] or 0, - input_data = json.loads(row["input_data"]) if row["input_data"] else {}, - output_data = json.loads(row["output_data"]) if row["output_data"] else {}, - error_message = row["error_message"] or "", - created_at = row["created_at"], + id=row["id"], + workflow_id=row["workflow_id"], + task_id=row["task_id"], + status=row["status"], + start_time=row["start_time"], + end_time=row["end_time"], + duration_ms=row["duration_ms"] or 0, + input_data=json.loads(row["input_data"]) if row["input_data"] else {}, + output_data=json.loads(row["output_data"]) if row["output_data"] else {}, + error_message=row["error_message"] or "", + created_at=row["created_at"], ) # ==================== Workflow Execution ==================== - async def execute_workflow(self, workflow_id: str, input_data: dict = None) -> dict: + async def execute_workflow(self, workflow_id: str, input_data: dict = None) -> dict: """执行工作流""" - workflow = self.get_workflow(workflow_id) + workflow = self.get_workflow(workflow_id) if not workflow: raise ValueError(f"Workflow {workflow_id} not found") @@ -1085,49 +1084,49 @@ class WorkflowManager: raise ValueError(f"Workflow {workflow_id} is not active") # 更新最后运行时间 - now = datetime.now().isoformat() - self.update_workflow(workflow_id, last_run_at = now, run_count = workflow.run_count + 1) + now = datetime.now().isoformat() + self.update_workflow(workflow_id, last_run_at=now, run_count=workflow.run_count + 1) # 创建工作流执行日志 - log = WorkflowLog( - id = str(uuid.uuid4())[:UUID_LENGTH], - workflow_id = workflow_id, - status = TaskStatus.RUNNING.value, - start_time = now, - input_data = input_data or {}, + log = WorkflowLog( + id=str(uuid.uuid4())[:UUID_LENGTH], + workflow_id=workflow_id, + status=TaskStatus.RUNNING.value, + start_time=now, + input_data=input_data or {}, ) self.create_log(log) - start_time = datetime.now() - results = {} + start_time = datetime.now() + results = {} try: # 获取所有任务 - tasks = self.list_tasks(workflow_id) + tasks = self.list_tasks(workflow_id) if not tasks: # 没有任务时执行默认行为 - results = await self._execute_default_workflow(workflow, input_data) + results = await self._execute_default_workflow(workflow, input_data) else: # 按依赖顺序执行任务 - results = await self._execute_tasks_with_deps(tasks, input_data, log.id) + results = await self._execute_tasks_with_deps(tasks, input_data, log.id) # 发送通知 - await self._send_workflow_notification(workflow, results, success = True) + await self._send_workflow_notification(workflow, results, success=True) # 更新日志为成功 - end_time = datetime.now() - duration = int((end_time - start_time).total_seconds() * 1000) + end_time = datetime.now() + duration = int((end_time - start_time).total_seconds() * 1000) self.update_log( log.id, - status = TaskStatus.SUCCESS.value, - end_time = end_time.isoformat(), - duration_ms = duration, - output_data = results, + status=TaskStatus.SUCCESS.value, + end_time=end_time.isoformat(), + duration_ms=duration, + output_data=results, ) # 更新成功计数 - self.update_workflow(workflow_id, success_count = workflow.success_count + 1) + self.update_workflow(workflow_id, success_count=workflow.success_count + 1) return { "success": True, @@ -1141,21 +1140,21 @@ class WorkflowManager: logger.error(f"Workflow {workflow_id} execution failed: {e}") # 更新日志为失败 - end_time = datetime.now() - duration = int((end_time - start_time).total_seconds() * 1000) + end_time = datetime.now() + duration = int((end_time - start_time).total_seconds() * 1000) self.update_log( log.id, - status = TaskStatus.FAILED.value, - end_time = end_time.isoformat(), - duration_ms = duration, - error_message = str(e), + status=TaskStatus.FAILED.value, + end_time=end_time.isoformat(), + duration_ms=duration, + error_message=str(e), ) # 更新失败计数 - self.update_workflow(workflow_id, fail_count = workflow.fail_count + 1) + self.update_workflow(workflow_id, fail_count=workflow.fail_count + 1) # 发送失败通知 - await self._send_workflow_notification(workflow, {"error": str(e)}, success = False) + await self._send_workflow_notification(workflow, {"error": str(e)}, success=False) raise @@ -1163,12 +1162,12 @@ class WorkflowManager: self, tasks: list[WorkflowTask], input_data: dict, log_id: str ) -> dict: """按依赖顺序执行任务""" - results = {} - completed_tasks = set() + results = {} + completed_tasks = set() while len(completed_tasks) < len(tasks): # 找到可以执行的任务(依赖已完成) - ready_tasks = [ + ready_tasks = [ t for t in tasks if t.id not in completed_tasks @@ -1180,12 +1179,12 @@ class WorkflowManager: raise ValueError("Circular dependency detected or tasks cannot be resolved") # 并行执行就绪的任务 - task_coros = [] + task_coros = [] for task in ready_tasks: - task_input = {**input_data, **results} + task_input = {**input_data, **results} task_coros.append(self._execute_single_task(task, task_input, log_id)) - task_results = await asyncio.gather(*task_coros, return_exceptions = True) + task_results = await asyncio.gather(*task_coros, return_exceptions=True) for task, result in zip(ready_tasks, task_results): if isinstance(result, Exception): @@ -1195,7 +1194,7 @@ class WorkflowManager: for attempt in range(task.retry_count): await asyncio.sleep(task.retry_delay) try: - result = await self._execute_single_task(task, task_input, log_id) + result = await self._execute_single_task(task, task_input, log_id) break except (TimeoutError, httpx.HTTPError) as e: logger.error(f"Task {task.id} retry {attempt + 1} failed: {e}") @@ -1204,38 +1203,38 @@ class WorkflowManager: else: raise result - results[task.name] = result + results[task.name] = result completed_tasks.add(task.id) return results async def _execute_single_task(self, task: WorkflowTask, input_data: dict, log_id: str) -> Any: """执行单个任务""" - handler = self._task_handlers.get(task.task_type) + handler = self._task_handlers.get(task.task_type) if not handler: raise ValueError(f"No handler for task type: {task.task_type}") # 创建任务日志 - task_log = WorkflowLog( - id = str(uuid.uuid4())[:UUID_LENGTH], - workflow_id = task.workflow_id, - task_id = task.id, - status = TaskStatus.RUNNING.value, - start_time = datetime.now().isoformat(), - input_data = input_data, + task_log = WorkflowLog( + id=str(uuid.uuid4())[:UUID_LENGTH], + workflow_id=task.workflow_id, + task_id=task.id, + status=TaskStatus.RUNNING.value, + start_time=datetime.now().isoformat(), + input_data=input_data, ) self.create_log(task_log) try: # 设置超时 - result = await asyncio.wait_for(handler(task, input_data), timeout = task.timeout_seconds) + result = await asyncio.wait_for(handler(task, input_data), timeout=task.timeout_seconds) # 更新任务日志为成功 self.update_log( task_log.id, - status = TaskStatus.SUCCESS.value, - end_time = datetime.now().isoformat(), - output_data = {"result": result} if not isinstance(result, dict) else result, + status=TaskStatus.SUCCESS.value, + end_time=datetime.now().isoformat(), + output_data={"result": result} if not isinstance(result, dict) else result, ) return result @@ -1243,24 +1242,24 @@ class WorkflowManager: except TimeoutError: self.update_log( task_log.id, - status = TaskStatus.FAILED.value, - end_time = datetime.now().isoformat(), - error_message = "Task timeout", + status=TaskStatus.FAILED.value, + end_time=datetime.now().isoformat(), + error_message="Task timeout", ) raise TimeoutError(f"Task {task.id} timed out after {task.timeout_seconds}s") except Exception as e: self.update_log( task_log.id, - status = TaskStatus.FAILED.value, - end_time = datetime.now().isoformat(), - error_message = str(e), + status=TaskStatus.FAILED.value, + end_time=datetime.now().isoformat(), + error_message=str(e), ) raise async def _execute_default_workflow(self, workflow: Workflow, input_data: dict) -> dict: """执行默认工作流(根据类型)""" - workflow_type = WorkflowType(workflow.workflow_type) + workflow_type = WorkflowType(workflow.workflow_type) if workflow_type == WorkflowType.AUTO_ANALYZE: return await self._auto_analyze_files(workflow, input_data) @@ -1277,8 +1276,8 @@ class WorkflowManager: async def _handle_analyze_task(self, task: WorkflowTask, input_data: dict) -> dict: """处理分析任务""" - project_id = input_data.get("project_id") - file_ids = input_data.get("file_ids", []) + project_id = input_data.get("project_id") + file_ids = input_data.get("file_ids", []) if not project_id: raise ValueError("project_id required for analyze task") @@ -1294,8 +1293,8 @@ class WorkflowManager: async def _handle_align_task(self, task: WorkflowTask, input_data: dict) -> dict: """处理实体对齐任务""" - project_id = input_data.get("project_id") - threshold = task.config.get("threshold", 0.85) + project_id = input_data.get("project_id") + threshold = task.config.get("threshold", 0.85) if not project_id: raise ValueError("project_id required for align task") @@ -1311,7 +1310,7 @@ class WorkflowManager: async def _handle_discover_relations_task(self, task: WorkflowTask, input_data: dict) -> dict: """处理关系发现任务""" - project_id = input_data.get("project_id") + project_id = input_data.get("project_id") if not project_id: raise ValueError("project_id required for discover_relations task") @@ -1326,24 +1325,24 @@ class WorkflowManager: async def _handle_notify_task(self, task: WorkflowTask, input_data: dict) -> dict: """处理通知任务""" - webhook_id = task.config.get("webhook_id") - message = task.config.get("message", {}) + webhook_id = task.config.get("webhook_id") + message = task.config.get("message", {}) if not webhook_id: raise ValueError("webhook_id required for notify task") - webhook = self.get_webhook(webhook_id) + webhook = self.get_webhook(webhook_id) if not webhook: raise ValueError(f"Webhook {webhook_id} not found") # 替换模板变量 if webhook.template: try: - message = json.loads(webhook.template.format(**input_data)) + message = json.loads(webhook.template.format(**input_data)) except (json.JSONDecodeError, KeyError, ValueError): pass - success = await self.notifier.send(webhook, message) + success = await self.notifier.send(webhook, message) self.update_webhook_stats(webhook_id, success) return {"task": "notify", "webhook_id": webhook_id, "success": success} @@ -1362,7 +1361,7 @@ class WorkflowManager: async def _auto_analyze_files(self, workflow: Workflow, input_data: dict) -> dict: """自动分析新上传的文件""" - project_id = workflow.project_id + project_id = workflow.project_id # 获取未分析的文件(实际实现需要查询数据库) # 这里是一个示例实现 @@ -1377,8 +1376,8 @@ class WorkflowManager: async def _auto_align_entities(self, workflow: Workflow, input_data: dict) -> dict: """自动实体对齐""" - project_id = workflow.project_id - threshold = workflow.config.get("threshold", 0.85) + project_id = workflow.project_id + threshold = workflow.config.get("threshold", 0.85) return { "workflow_type": "auto_align", @@ -1390,7 +1389,7 @@ class WorkflowManager: async def _auto_discover_relations(self, workflow: Workflow, input_data: dict) -> dict: """自动关系发现""" - project_id = workflow.project_id + project_id = workflow.project_id return { "workflow_type": "auto_relation", @@ -1401,8 +1400,8 @@ class WorkflowManager: async def _generate_scheduled_report(self, workflow: Workflow, input_data: dict) -> dict: """生成定时报告""" - project_id = workflow.project_id - report_type = workflow.config.get("report_type", "summary") + project_id = workflow.project_id + report_type = workflow.config.get("report_type", "summary") return { "workflow_type": "scheduled_report", @@ -1414,26 +1413,26 @@ class WorkflowManager: # ==================== Notification ==================== async def _send_workflow_notification( - self, workflow: Workflow, results: dict, success: bool = True + self, workflow: Workflow, results: dict, success: bool = True ) -> None: """发送工作流执行通知""" if not workflow.webhook_ids: return for webhook_id in workflow.webhook_ids: - webhook = self.get_webhook(webhook_id) + webhook = self.get_webhook(webhook_id) if not webhook or not webhook.is_active: continue # 构建通知消息 if webhook.webhook_type == WebhookType.FEISHU.value: - message = self._build_feishu_message(workflow, results, success) + message = self._build_feishu_message(workflow, results, success) elif webhook.webhook_type == WebhookType.DINGTALK.value: - message = self._build_dingtalk_message(workflow, results, success) + message = self._build_dingtalk_message(workflow, results, success) elif webhook.webhook_type == WebhookType.SLACK.value: - message = self._build_slack_message(workflow, results, success) + message = self._build_slack_message(workflow, results, success) else: - message = { + message = { "workflow_id": workflow.id, "workflow_name": workflow.name, "status": "success" if success else "failed", @@ -1442,14 +1441,14 @@ class WorkflowManager: } try: - result = await self.notifier.send(webhook, message) + result = await self.notifier.send(webhook, message) self.update_webhook_stats(webhook_id, result) except (TimeoutError, httpx.HTTPError) as e: logger.error(f"Failed to send notification to {webhook_id}: {e}") def _build_feishu_message(self, workflow: Workflow, results: dict, success: bool) -> dict: """构建飞书消息""" - status_text = "✅ 成功" if success else "❌ 失败" + status_text = "✅ 成功" if success else "❌ 失败" return { "title": f"工作流执行通知: {workflow.name}", @@ -1462,7 +1461,7 @@ class WorkflowManager: def _build_dingtalk_message(self, workflow: Workflow, results: dict, success: bool) -> dict: """构建钉钉消息""" - status_text = "✅ 成功" if success else "❌ 失败" + status_text = "✅ 成功" if success else "❌ 失败" return { "title": f"工作流执行通知: {workflow.name}", @@ -1476,15 +1475,15 @@ class WorkflowManager: **结果:** ```json -{json.dumps(results, ensure_ascii = False, indent = 2)} +{json.dumps(results, ensure_ascii=False, indent=2)} ``` """, } def _build_slack_message(self, workflow: Workflow, results: dict, success: bool) -> dict: """构建 Slack 消息""" - color = "#36a64f" if success else "#ff0000" - status_text = "Success" if success else "Failed" + color = "#36a64f" if success else "#ff0000" + status_text = "Success" if success else "Failed" return { "attachments": [ @@ -1507,12 +1506,12 @@ class WorkflowManager: # Singleton instance -_workflow_manager = None +_workflow_manager = None -def get_workflow_manager(db_manager = None) -> WorkflowManager: +def get_workflow_manager(db_manager=None) -> WorkflowManager: """获取 WorkflowManager 单例""" global _workflow_manager if _workflow_manager is None: - _workflow_manager = WorkflowManager(db_manager) + _workflow_manager = WorkflowManager(db_manager) return _workflow_manager