diff --git a/AUTO_CODE_REVIEW_REPORT.md b/AUTO_CODE_REVIEW_REPORT.md index 6e68ecc..0e0ba5a 100644 --- a/AUTO_CODE_REVIEW_REPORT.md +++ b/AUTO_CODE_REVIEW_REPORT.md @@ -224,4 +224,8 @@ - 第 490 行: line_too_long - 第 541 行: line_too_long - 第 579 行: line_too_long -- ... 还有 2 个类似问题 \ No newline at end of file +- ... 还有 2 个类似问题 + +## Git 提交结果 + +✅ 提交并推送成功 diff --git a/__pycache__/auto_code_fixer.cpython-312.pyc b/__pycache__/auto_code_fixer.cpython-312.pyc new file mode 100644 index 0000000..1dd0708 Binary files /dev/null and b/__pycache__/auto_code_fixer.cpython-312.pyc differ diff --git a/__pycache__/auto_fix_code.cpython-312.pyc b/__pycache__/auto_fix_code.cpython-312.pyc new file mode 100644 index 0000000..5d36e6f Binary files /dev/null and b/__pycache__/auto_fix_code.cpython-312.pyc differ diff --git a/__pycache__/code_review_fixer.cpython-312.pyc b/__pycache__/code_review_fixer.cpython-312.pyc new file mode 100644 index 0000000..b1e60fb Binary files /dev/null and b/__pycache__/code_review_fixer.cpython-312.pyc differ diff --git a/__pycache__/code_reviewer.cpython-312.pyc b/__pycache__/code_reviewer.cpython-312.pyc new file mode 100644 index 0000000..c731960 Binary files /dev/null and b/__pycache__/code_reviewer.cpython-312.pyc differ diff --git a/auto_code_fixer.py b/auto_code_fixer.py index 858c810..1b62948 100644 --- a/auto_code_fixer.py +++ b/auto_code_fixer.py @@ -19,30 +19,30 @@ class CodeIssue: line_no: int, issue_type: str, message: str, - severity: str = "warning", - original_line: str = "", - ): - self.file_path = file_path - self.line_no = line_no - self.issue_type = issue_type - self.message = message - self.severity = severity - self.original_line = original_line - self.fixed = False + severity: str = "warning", + original_line: str = "", + ) -> None: + self.file_path = file_path + self.line_no = line_no + self.issue_type = issue_type + self.message = message + self.severity = severity + self.original_line = original_line + self.fixed = False - def __repr__(self): + def __repr__(self) -> None: return f"{self.file_path}:{self.line_no} [{self.severity}] {self.issue_type}: {self.message}" class CodeFixer: """代码自动修复器""" - def __init__(self, project_path: str): - self.project_path = Path(project_path) - self.issues: list[CodeIssue] = [] - self.fixed_issues: list[CodeIssue] = [] - self.manual_issues: list[CodeIssue] = [] - self.scanned_files: list[str] = [] + def __init__(self, project_path: str) -> None: + self.project_path = Path(project_path) + self.issues: list[CodeIssue] = [] + self.fixed_issues: list[CodeIssue] = [] + self.manual_issues: list[CodeIssue] = [] + self.scanned_files: list[str] = [] def scan_all_files(self) -> None: """扫描所有 Python 文件""" @@ -55,9 +55,9 @@ class CodeFixer: def _scan_file(self, file_path: Path) -> None: """扫描单个文件""" try: - with open(file_path, "r", encoding="utf-8") as f: - content = f.read() - lines = content.split("\n") + with open(file_path, "r", encoding = "utf-8") as f: + content = f.read() + lines = content.split("\n") except Exception as e: print(f"Error reading {file_path}: {e}") return @@ -85,7 +85,7 @@ class CodeFixer: ) -> None: """检查裸异常捕获""" for i, line in enumerate(lines, 1): - # 匹配 except: 但不匹配 except Exception: 或 except SpecificError: + # 匹配 except Exception: 但不匹配 except Exception: 或 except SpecificError: if re.search(r"except\s*:\s*$", line) or re.search(r"except\s*:\s*#", line): # 跳过注释说明的情况 if "# noqa" in line or "# intentional" in line.lower(): @@ -130,25 +130,25 @@ class CodeFixer: def _check_unused_imports(self, file_path: Path, content: str) -> None: """检查未使用的导入""" try: - tree = ast.parse(content) + tree = ast.parse(content) except SyntaxError: return - imports = {} + imports = {} for node in ast.walk(tree): if isinstance(node, ast.Import): for alias in node.names: - name = alias.asname if alias.asname else alias.name - imports[name] = node.lineno + name = alias.asname if alias.asname else alias.name + imports[name] = node.lineno elif isinstance(node, ast.ImportFrom): for alias in node.names: - name = alias.asname if alias.asname else alias.name + name = alias.asname if alias.asname else alias.name if alias.name == "*": continue - imports[name] = node.lineno + imports[name] = node.lineno # 检查使用 - used_names = set() + used_names = set() for node in ast.walk(tree): if isinstance(node, ast.Name): used_names.add(node.id) @@ -216,15 +216,15 @@ class CodeFixer: ) -> None: """检查敏感信息泄露""" # 排除的文件 - excluded_files = ["auto_code_fixer.py", "code_reviewer.py"] + excluded_files = ["auto_code_fixer.py", "code_reviewer.py"] if any(excluded in str(file_path) for excluded in excluded_files): return - patterns = [ - (r'password\s*=\s*["\'][^"\']{8,}["\']', "硬编码密码"), - (r'secret_key\s*=\s*["\'][^"\']{8,}["\']', "硬编码密钥"), - (r'api_key\s*=\s*["\'][^"\']{8,}["\']', "硬编码 API Key"), - (r'token\s*=\s*["\'][^"\']{8,}["\']', "硬编码 Token"), + patterns = [ + (r'password\s* = \s*["\'][^"\']{8, }["\']', "硬编码密码"), + (r'secret_key\s* = \s*["\'][^"\']{8, }["\']', "硬编码密钥"), + (r'api_key\s* = \s*["\'][^"\']{8, }["\']', "硬编码 API Key"), + (r'token\s* = \s*["\'][^"\']{8, }["\']', "硬编码 Token"), ] for i, line in enumerate(lines, 1): @@ -241,7 +241,7 @@ class CodeFixer: if any(x in line.lower() for x in ["your_", "example", "placeholder", "test", "demo"]): continue # 排除 Enum 定义 - if re.search(r'^\s*[A-Z_]+\s*=', line.strip()): + if re.search(r'^\s*[A-Z_]+\s* = ', line.strip()): continue self.manual_issues.append( CodeIssue( @@ -256,17 +256,17 @@ class CodeFixer: def fix_auto_fixable(self) -> None: """自动修复可修复的问题""" - auto_fix_types = { + auto_fix_types = { "trailing_whitespace", "bare_exception", } # 按文件分组 - files_to_fix = {} + files_to_fix = {} for issue in self.issues: if issue.issue_type in auto_fix_types: if issue.file_path not in files_to_fix: - files_to_fix[issue.file_path] = [] + files_to_fix[issue.file_path] = [] files_to_fix[issue.file_path].append(issue) for file_path, file_issues in files_to_fix.items(): @@ -275,43 +275,43 @@ class CodeFixer: continue try: - with open(file_path, "r", encoding="utf-8") as f: - content = f.read() - lines = content.split("\n") + with open(file_path, "r", encoding = "utf-8") as f: + content = f.read() + lines = content.split("\n") except Exception: continue - original_lines = lines.copy() - fixed_lines = set() + original_lines = lines.copy() + fixed_lines = set() # 修复行尾空格 for issue in file_issues: if issue.issue_type == "trailing_whitespace": - line_idx = issue.line_no - 1 + line_idx = issue.line_no - 1 if 0 <= line_idx < len(lines) and line_idx not in fixed_lines: if lines[line_idx].rstrip() != lines[line_idx]: - lines[line_idx] = lines[line_idx].rstrip() + lines[line_idx] = lines[line_idx].rstrip() fixed_lines.add(line_idx) - issue.fixed = True + issue.fixed = True self.fixed_issues.append(issue) # 修复裸异常 for issue in file_issues: if issue.issue_type == "bare_exception": - line_idx = issue.line_no - 1 + line_idx = issue.line_no - 1 if 0 <= line_idx < len(lines) and line_idx not in fixed_lines: - line = lines[line_idx] - # 将 except: 改为 except Exception: + line = lines[line_idx] + # 将 except Exception: 改为 except Exception: if re.search(r"except\s*:\s*$", line.strip()): - lines[line_idx] = line.replace("except:", "except Exception:") + lines[line_idx] = line.replace("except Exception:", "except Exception:") fixed_lines.add(line_idx) - issue.fixed = True + issue.fixed = True self.fixed_issues.append(issue) # 如果文件有修改,写回 if lines != original_lines: try: - with open(file_path, "w", encoding="utf-8") as f: + with open(file_path, "w", encoding = "utf-8") as f: f.write("\n".join(lines)) print(f"Fixed issues in {file_path}") except Exception as e: @@ -319,7 +319,7 @@ class CodeFixer: def categorize_issues(self) -> dict[str, list[CodeIssue]]: """分类问题""" - categories = { + categories = { "critical": [], "error": [], "warning": [], @@ -334,7 +334,7 @@ class CodeFixer: def generate_report(self) -> str: """生成修复报告""" - report = [] + report = [] report.append("# InsightFlow 代码审查报告") report.append("") report.append(f"扫描时间: {os.popen('date').read().strip()}") @@ -349,9 +349,9 @@ class CodeFixer: report.append("") # 问题统计 - categories = self.categorize_issues() - manual_critical = [i for i in self.manual_issues if i.severity == "critical"] - manual_warning = [i for i in self.manual_issues if i.severity == "warning"] + categories = self.categorize_issues() + manual_critical = [i for i in self.manual_issues if i.severity == "critical"] + manual_warning = [i for i in self.manual_issues if i.severity == "warning"] report.append("## 问题分类统计") report.append("") @@ -393,17 +393,17 @@ class CodeFixer: # 其他问题 report.append("## 📋 其他发现的问题") report.append("") - other_issues = [ + other_issues = [ i for i in self.issues if i not in self.fixed_issues ] # 按类型分组 - by_type = {} + by_type = {} for issue in other_issues: if issue.issue_type not in by_type: - by_type[issue.issue_type] = [] + by_type[issue.issue_type] = [] by_type[issue.issue_type].append(issue) for issue_type, issues in sorted(by_type.items()): @@ -424,21 +424,21 @@ def git_commit_and_push(project_path: str) -> tuple[bool, str]: """Git 提交和推送""" try: # 检查是否有变更 - result = subprocess.run( + result = subprocess.run( ["git", "status", "--porcelain"], - cwd=project_path, - capture_output=True, - text=True, + cwd = project_path, + capture_output = True, + text = True, ) if not result.stdout.strip(): return True, "没有需要提交的变更" # 添加所有变更 - subprocess.run(["git", "add", "-A"], cwd=project_path, check=True) + subprocess.run(["git", "add", "-A"], cwd = project_path, check = True) # 提交 - commit_msg = """fix: auto-fix code issues (cron) + commit_msg = """fix: auto-fix code issues (cron) - 修复重复导入/字段 - 修复异常处理 @@ -446,11 +446,11 @@ def git_commit_and_push(project_path: str) -> tuple[bool, str]: - 添加类型注解""" subprocess.run( - ["git", "commit", "-m", commit_msg], cwd=project_path, check=True + ["git", "commit", "-m", commit_msg], cwd = project_path, check = True ) # 推送 - subprocess.run(["git", "push"], cwd=project_path, check=True) + subprocess.run(["git", "push"], cwd = project_path, check = True) return True, "提交并推送成功" except subprocess.CalledProcessError as e: @@ -459,11 +459,11 @@ def git_commit_and_push(project_path: str) -> tuple[bool, str]: return False, f"Git 操作异常: {e}" -def main(): - project_path = "/root/.openclaw/workspace/projects/insightflow" +def main() -> None: + project_path = "/root/.openclaw/workspace/projects/insightflow" print("🔍 开始扫描代码...") - fixer = CodeFixer(project_path) + fixer = CodeFixer(project_path) fixer.scan_all_files() print(f"📊 发现 {len(fixer.issues)} 个可自动修复问题") @@ -475,30 +475,30 @@ def main(): print(f"✅ 已修复 {len(fixer.fixed_issues)} 个问题") # 生成报告 - report = fixer.generate_report() + report = fixer.generate_report() # 保存报告 - report_path = Path(project_path) / "AUTO_CODE_REVIEW_REPORT.md" - with open(report_path, "w", encoding="utf-8") as f: + report_path = Path(project_path) / "AUTO_CODE_REVIEW_REPORT.md" + with open(report_path, "w", encoding = "utf-8") as f: f.write(report) print(f"📝 报告已保存到: {report_path}") # Git 提交 print("📤 提交变更到 Git...") - success, msg = git_commit_and_push(project_path) + success, msg = git_commit_and_push(project_path) print(f"{'✅' if success else '❌'} {msg}") # 添加 Git 结果到报告 report += f"\n\n## Git 提交结果\n\n{'✅' if success else '❌'} {msg}\n" # 重新保存完整报告 - with open(report_path, "w", encoding="utf-8") as f: + with open(report_path, "w", encoding = "utf-8") as f: f.write(report) - print("\n" + "=" * 60) + print("\n" + " = " * 60) print(report) - print("=" * 60) + print(" = " * 60) return report diff --git a/auto_fix_code.py b/auto_fix_code.py index b0e6043..caf66e1 100644 --- a/auto_fix_code.py +++ b/auto_fix_code.py @@ -14,11 +14,11 @@ from pathlib import Path def run_ruff_check(directory: str) -> list[dict]: """运行 ruff 检查并返回问题列表""" try: - result = subprocess.run( - ["ruff", "check", "--select=E,W,F,I", "--output-format=json", directory], - capture_output=True, - text=True, - check=False, + result = subprocess.run( + ["ruff", "check", "--select = E, W, F, I", "--output-format = json", directory], + capture_output = True, + text = True, + check = False, ) if result.stdout: return json.loads(result.stdout) @@ -29,18 +29,18 @@ def run_ruff_check(directory: str) -> list[dict]: def fix_bare_except(content: str) -> str: - """修复裸异常捕获 - 将 bare except: 改为 except Exception:""" - pattern = r'except\s*:\s*\n' - replacement = 'except Exception:\n' + """修复裸异常捕获 - 将 bare except Exception: 改为 except Exception:""" + pattern = r'except\s*:\s*\n' + replacement = 'except Exception:\n' return re.sub(pattern, replacement, content) def fix_undefined_names(content: str, filepath: str) -> str: """修复未定义的名称""" - lines = content.split('\n') - modified = False - - import_map = { + lines = content.split('\n') + modified = False + + import_map = { 'ExportEntity': 'from export_manager import ExportEntity', 'ExportRelation': 'from export_manager import ExportRelation', 'ExportTranscript': 'from export_manager import ExportTranscript', @@ -49,23 +49,23 @@ def fix_undefined_names(content: str, filepath: str) -> str: 'OpsManager': 'from ops_manager import OpsManager', 'urllib': 'import urllib.parse', } - - undefined_names = set() + + undefined_names = set() for name, import_stmt in import_map.items(): if name in content and import_stmt not in content: undefined_names.add((name, import_stmt)) - + if undefined_names: - import_idx = 0 + import_idx = 0 for i, line in enumerate(lines): if line.startswith('import ') or line.startswith('from '): - import_idx = i + 1 - + import_idx = i + 1 + for name, import_stmt in sorted(undefined_names): lines.insert(import_idx, import_stmt) import_idx += 1 - modified = True - + modified = True + if modified: return '\n'.join(lines) return content @@ -73,100 +73,100 @@ def fix_undefined_names(content: str, filepath: str) -> str: def fix_file(filepath: str, issues: list[dict]) -> tuple[bool, list[str], list[str]]: """修复单个文件的问题""" - with open(filepath, 'r', encoding='utf-8') as f: - original_content = f.read() - - content = original_content - fixed_issues = [] - manual_fix_needed = [] - + with open(filepath, 'r', encoding = 'utf-8') as f: + original_content = f.read() + + content = original_content + fixed_issues = [] + manual_fix_needed = [] + for issue in issues: - code = issue.get('code', '') - message = issue.get('message', '') - line_num = issue['location']['row'] - + code = issue.get('code', '') + message = issue.get('message', '') + line_num = issue['location']['row'] + if code == 'F821': - content = fix_undefined_names(content, filepath) + content = fix_undefined_names(content, filepath) if content != original_content: fixed_issues.append(f"F821 - {message} (line {line_num})") else: manual_fix_needed.append(f"F821 - {message} (line {line_num})") elif code == 'E501': manual_fix_needed.append(f"E501 (line {line_num})") - - content = fix_bare_except(content) - + + content = fix_bare_except(content) + if content != original_content: - with open(filepath, 'w', encoding='utf-8') as f: + with open(filepath, 'w', encoding = 'utf-8') as f: f.write(content) return True, fixed_issues, manual_fix_needed - + return False, fixed_issues, manual_fix_needed -def main(): - base_dir = Path("/root/.openclaw/workspace/projects/insightflow") - backend_dir = base_dir / "backend" - - print("=" * 60) +def main() -> None: + base_dir = Path("/root/.openclaw/workspace/projects/insightflow") + backend_dir = base_dir / "backend" + + print(" = " * 60) print("InsightFlow 代码自动修复") - print("=" * 60) - + print(" = " * 60) + print("\n1. 扫描代码问题...") - issues = run_ruff_check(str(backend_dir)) - - issues_by_file = {} + issues = run_ruff_check(str(backend_dir)) + + issues_by_file = {} for issue in issues: - filepath = issue.get('filename', '') + filepath = issue.get('filename', '') if filepath not in issues_by_file: - issues_by_file[filepath] = [] + issues_by_file[filepath] = [] issues_by_file[filepath].append(issue) - + print(f" 发现 {len(issues)} 个问题,分布在 {len(issues_by_file)} 个文件中") - - issue_types = {} + + issue_types = {} for issue in issues: - code = issue.get('code', 'UNKNOWN') - issue_types[code] = issue_types.get(code, 0) + 1 - + code = issue.get('code', 'UNKNOWN') + issue_types[code] = issue_types.get(code, 0) + 1 + print("\n2. 问题类型统计:") - for code, count in sorted(issue_types.items(), key=lambda x: -x[1]): + for code, count in sorted(issue_types.items(), key = lambda x: -x[1]): print(f" - {code}: {count} 个") - + print("\n3. 尝试自动修复...") - fixed_files = [] - all_fixed_issues = [] - all_manual_fixes = [] - + fixed_files = [] + all_fixed_issues = [] + all_manual_fixes = [] + for filepath, file_issues in issues_by_file.items(): if not os.path.exists(filepath): continue - - modified, fixed, manual = fix_file(filepath, file_issues) + + modified, fixed, manual = fix_file(filepath, file_issues) if modified: fixed_files.append(filepath) all_fixed_issues.extend(fixed) all_manual_fixes.extend([(filepath, m) for m in manual]) - + print(f" 直接修改了 {len(fixed_files)} 个文件") print(f" 自动修复了 {len(all_fixed_issues)} 个问题") - + print("\n4. 运行 ruff 自动格式化...") try: subprocess.run( ["ruff", "format", str(backend_dir)], - capture_output=True, - check=False, + capture_output = True, + check = False, ) print(" 格式化完成") except Exception as e: print(f" 格式化失败: {e}") - + print("\n5. 再次检查...") - remaining_issues = run_ruff_check(str(backend_dir)) + remaining_issues = run_ruff_check(str(backend_dir)) print(f" 剩余 {len(remaining_issues)} 个问题需要手动处理") - - report = { + + report = { 'total_issues': len(issues), 'fixed_files': len(fixed_files), 'fixed_issues': len(all_fixed_issues), @@ -174,15 +174,15 @@ def main(): 'issue_types': issue_types, 'manual_fix_needed': all_manual_fixes[:30], } - + return report if __name__ == "__main__": - report = main() - print("\n" + "=" * 60) + report = main() + print("\n" + " = " * 60) print("修复报告") - print("=" * 60) + print(" = " * 60) print(f"总问题数: {report['total_issues']}") print(f"修复文件数: {report['fixed_files']}") print(f"自动修复问题数: {report['fixed_issues']}") diff --git a/backend/__pycache__/ai_manager.cpython-312.pyc b/backend/__pycache__/ai_manager.cpython-312.pyc new file mode 100644 index 0000000..0e33da3 Binary files /dev/null and b/backend/__pycache__/ai_manager.cpython-312.pyc differ diff --git a/backend/__pycache__/api_key_manager.cpython-312.pyc b/backend/__pycache__/api_key_manager.cpython-312.pyc new file mode 100644 index 0000000..8da9471 Binary files /dev/null and b/backend/__pycache__/api_key_manager.cpython-312.pyc differ diff --git a/backend/__pycache__/collaboration_manager.cpython-312.pyc b/backend/__pycache__/collaboration_manager.cpython-312.pyc new file mode 100644 index 0000000..23551f1 Binary files /dev/null and b/backend/__pycache__/collaboration_manager.cpython-312.pyc differ diff --git a/backend/__pycache__/db_manager.cpython-312.pyc b/backend/__pycache__/db_manager.cpython-312.pyc new file mode 100644 index 0000000..3b6828e Binary files /dev/null and b/backend/__pycache__/db_manager.cpython-312.pyc differ diff --git a/backend/__pycache__/developer_ecosystem_manager.cpython-312.pyc b/backend/__pycache__/developer_ecosystem_manager.cpython-312.pyc new file mode 100644 index 0000000..2a56c60 Binary files /dev/null and b/backend/__pycache__/developer_ecosystem_manager.cpython-312.pyc differ diff --git a/backend/__pycache__/document_processor.cpython-312.pyc b/backend/__pycache__/document_processor.cpython-312.pyc new file mode 100644 index 0000000..c95ae78 Binary files /dev/null and b/backend/__pycache__/document_processor.cpython-312.pyc differ diff --git a/backend/__pycache__/enterprise_manager.cpython-312.pyc b/backend/__pycache__/enterprise_manager.cpython-312.pyc new file mode 100644 index 0000000..b08c44d Binary files /dev/null and b/backend/__pycache__/enterprise_manager.cpython-312.pyc differ diff --git a/backend/__pycache__/entity_aligner.cpython-312.pyc b/backend/__pycache__/entity_aligner.cpython-312.pyc new file mode 100644 index 0000000..7496bc3 Binary files /dev/null and b/backend/__pycache__/entity_aligner.cpython-312.pyc differ diff --git a/backend/__pycache__/export_manager.cpython-312.pyc b/backend/__pycache__/export_manager.cpython-312.pyc new file mode 100644 index 0000000..2ab372a Binary files /dev/null and b/backend/__pycache__/export_manager.cpython-312.pyc differ diff --git a/backend/__pycache__/growth_manager.cpython-312.pyc b/backend/__pycache__/growth_manager.cpython-312.pyc new file mode 100644 index 0000000..9a408ba Binary files /dev/null and b/backend/__pycache__/growth_manager.cpython-312.pyc differ diff --git a/backend/__pycache__/image_processor.cpython-312.pyc b/backend/__pycache__/image_processor.cpython-312.pyc new file mode 100644 index 0000000..325396c Binary files /dev/null and b/backend/__pycache__/image_processor.cpython-312.pyc differ diff --git a/backend/__pycache__/init_db.cpython-312.pyc b/backend/__pycache__/init_db.cpython-312.pyc new file mode 100644 index 0000000..54ce54f Binary files /dev/null and b/backend/__pycache__/init_db.cpython-312.pyc differ diff --git a/backend/__pycache__/knowledge_reasoner.cpython-312.pyc b/backend/__pycache__/knowledge_reasoner.cpython-312.pyc new file mode 100644 index 0000000..49659d3 Binary files /dev/null and b/backend/__pycache__/knowledge_reasoner.cpython-312.pyc differ diff --git a/backend/__pycache__/llm_client.cpython-312.pyc b/backend/__pycache__/llm_client.cpython-312.pyc new file mode 100644 index 0000000..96ebdb1 Binary files /dev/null and b/backend/__pycache__/llm_client.cpython-312.pyc differ diff --git a/backend/__pycache__/localization_manager.cpython-312.pyc b/backend/__pycache__/localization_manager.cpython-312.pyc new file mode 100644 index 0000000..f832bd2 Binary files /dev/null and b/backend/__pycache__/localization_manager.cpython-312.pyc differ diff --git a/backend/__pycache__/main.cpython-312.pyc b/backend/__pycache__/main.cpython-312.pyc new file mode 100644 index 0000000..80420ec Binary files /dev/null and b/backend/__pycache__/main.cpython-312.pyc differ diff --git a/backend/__pycache__/multimodal_entity_linker.cpython-312.pyc b/backend/__pycache__/multimodal_entity_linker.cpython-312.pyc new file mode 100644 index 0000000..b0ebce0 Binary files /dev/null and b/backend/__pycache__/multimodal_entity_linker.cpython-312.pyc differ diff --git a/backend/__pycache__/multimodal_processor.cpython-312.pyc b/backend/__pycache__/multimodal_processor.cpython-312.pyc new file mode 100644 index 0000000..556c2e9 Binary files /dev/null and b/backend/__pycache__/multimodal_processor.cpython-312.pyc differ diff --git a/backend/__pycache__/neo4j_manager.cpython-312.pyc b/backend/__pycache__/neo4j_manager.cpython-312.pyc new file mode 100644 index 0000000..baed9bb Binary files /dev/null and b/backend/__pycache__/neo4j_manager.cpython-312.pyc differ diff --git a/backend/__pycache__/ops_manager.cpython-312.pyc b/backend/__pycache__/ops_manager.cpython-312.pyc new file mode 100644 index 0000000..0e08d9c Binary files /dev/null and b/backend/__pycache__/ops_manager.cpython-312.pyc differ diff --git a/backend/__pycache__/oss_uploader.cpython-312.pyc b/backend/__pycache__/oss_uploader.cpython-312.pyc new file mode 100644 index 0000000..e35b950 Binary files /dev/null and b/backend/__pycache__/oss_uploader.cpython-312.pyc differ diff --git a/backend/__pycache__/performance_manager.cpython-312.pyc b/backend/__pycache__/performance_manager.cpython-312.pyc new file mode 100644 index 0000000..0545cb7 Binary files /dev/null and b/backend/__pycache__/performance_manager.cpython-312.pyc differ diff --git a/backend/__pycache__/plugin_manager.cpython-312.pyc b/backend/__pycache__/plugin_manager.cpython-312.pyc new file mode 100644 index 0000000..2e42c6e Binary files /dev/null and b/backend/__pycache__/plugin_manager.cpython-312.pyc differ diff --git a/backend/__pycache__/rate_limiter.cpython-312.pyc b/backend/__pycache__/rate_limiter.cpython-312.pyc new file mode 100644 index 0000000..17535b5 Binary files /dev/null and b/backend/__pycache__/rate_limiter.cpython-312.pyc differ diff --git a/backend/__pycache__/search_manager.cpython-312.pyc b/backend/__pycache__/search_manager.cpython-312.pyc new file mode 100644 index 0000000..a8f6ece Binary files /dev/null and b/backend/__pycache__/search_manager.cpython-312.pyc differ diff --git a/backend/__pycache__/security_manager.cpython-312.pyc b/backend/__pycache__/security_manager.cpython-312.pyc new file mode 100644 index 0000000..db594a8 Binary files /dev/null and b/backend/__pycache__/security_manager.cpython-312.pyc differ diff --git a/backend/__pycache__/subscription_manager.cpython-312.pyc b/backend/__pycache__/subscription_manager.cpython-312.pyc new file mode 100644 index 0000000..48a8354 Binary files /dev/null and b/backend/__pycache__/subscription_manager.cpython-312.pyc differ diff --git a/backend/__pycache__/tenant_manager.cpython-312.pyc b/backend/__pycache__/tenant_manager.cpython-312.pyc new file mode 100644 index 0000000..3363b77 Binary files /dev/null and b/backend/__pycache__/tenant_manager.cpython-312.pyc differ diff --git a/backend/__pycache__/test_multimodal.cpython-312.pyc b/backend/__pycache__/test_multimodal.cpython-312.pyc new file mode 100644 index 0000000..a75526f Binary files /dev/null and b/backend/__pycache__/test_multimodal.cpython-312.pyc differ diff --git a/backend/__pycache__/test_phase7_task6_8.cpython-312.pyc b/backend/__pycache__/test_phase7_task6_8.cpython-312.pyc new file mode 100644 index 0000000..f1a3764 Binary files /dev/null and b/backend/__pycache__/test_phase7_task6_8.cpython-312.pyc differ diff --git a/backend/__pycache__/test_phase8_task1.cpython-312.pyc b/backend/__pycache__/test_phase8_task1.cpython-312.pyc new file mode 100644 index 0000000..40d7b75 Binary files /dev/null and b/backend/__pycache__/test_phase8_task1.cpython-312.pyc differ diff --git a/backend/__pycache__/test_phase8_task2.cpython-312.pyc b/backend/__pycache__/test_phase8_task2.cpython-312.pyc new file mode 100644 index 0000000..0e823bc Binary files /dev/null and b/backend/__pycache__/test_phase8_task2.cpython-312.pyc differ diff --git a/backend/__pycache__/test_phase8_task4.cpython-312.pyc b/backend/__pycache__/test_phase8_task4.cpython-312.pyc new file mode 100644 index 0000000..f4a2651 Binary files /dev/null and b/backend/__pycache__/test_phase8_task4.cpython-312.pyc differ diff --git a/backend/__pycache__/test_phase8_task5.cpython-312.pyc b/backend/__pycache__/test_phase8_task5.cpython-312.pyc new file mode 100644 index 0000000..7ef10ed Binary files /dev/null and b/backend/__pycache__/test_phase8_task5.cpython-312.pyc differ diff --git a/backend/__pycache__/test_phase8_task6.cpython-312.pyc b/backend/__pycache__/test_phase8_task6.cpython-312.pyc new file mode 100644 index 0000000..63c112b Binary files /dev/null and b/backend/__pycache__/test_phase8_task6.cpython-312.pyc differ diff --git a/backend/__pycache__/test_phase8_task8.cpython-312.pyc b/backend/__pycache__/test_phase8_task8.cpython-312.pyc new file mode 100644 index 0000000..ba0f0cc Binary files /dev/null and b/backend/__pycache__/test_phase8_task8.cpython-312.pyc differ diff --git a/backend/__pycache__/tingwu_client.cpython-312.pyc b/backend/__pycache__/tingwu_client.cpython-312.pyc new file mode 100644 index 0000000..dc92539 Binary files /dev/null and b/backend/__pycache__/tingwu_client.cpython-312.pyc differ diff --git a/backend/__pycache__/workflow_manager.cpython-312.pyc b/backend/__pycache__/workflow_manager.cpython-312.pyc new file mode 100644 index 0000000..7c08712 Binary files /dev/null and b/backend/__pycache__/workflow_manager.cpython-312.pyc differ diff --git a/backend/ai_manager.py b/backend/ai_manager.py index b8e7a68..b1e5699 100644 --- a/backend/ai_manager.py +++ b/backend/ai_manager.py @@ -25,44 +25,44 @@ from enum import StrEnum import httpx # Database path -DB_PATH = os.path.join(os.path.dirname(__file__), "insightflow.db") +DB_PATH = os.path.join(os.path.dirname(__file__), "insightflow.db") class ModelType(StrEnum): """模型类型""" - CUSTOM_NER = "custom_ner" # 自定义实体识别 - MULTIMODAL = "multimodal" # 多模态 - SUMMARIZATION = "summarization" # 摘要 - PREDICTION = "prediction" # 预测 + CUSTOM_NER = "custom_ner" # 自定义实体识别 + MULTIMODAL = "multimodal" # 多模态 + SUMMARIZATION = "summarization" # 摘要 + PREDICTION = "prediction" # 预测 class ModelStatus(StrEnum): """模型状态""" - PENDING = "pending" - TRAINING = "training" - READY = "ready" - FAILED = "failed" - ARCHIVED = "archived" + PENDING = "pending" + TRAINING = "training" + READY = "ready" + FAILED = "failed" + ARCHIVED = "archived" class MultimodalProvider(StrEnum): """多模态模型提供商""" - GPT4V = "gpt-4-vision" - CLAUDE3 = "claude-3" - GEMINI = "gemini-pro-vision" - KIMI_VL = "kimi-vl" + GPT4V = "gpt-4-vision" + CLAUDE3 = "claude-3" + GEMINI = "gemini-pro-vision" + KIMI_VL = "kimi-vl" class PredictionType(StrEnum): """预测类型""" - TREND = "trend" # 趋势预测 - ANOMALY = "anomaly" # 异常检测 - ENTITY_GROWTH = "entity_growth" # 实体增长预测 - RELATION_EVOLUTION = "relation_evolution" # 关系演变预测 + TREND = "trend" # 趋势预测 + ANOMALY = "anomaly" # 异常检测 + ENTITY_GROWTH = "entity_growth" # 实体增长预测 + RELATION_EVOLUTION = "relation_evolution" # 关系演变预测 @dataclass @@ -204,17 +204,17 @@ class SmartSummary: class AIManager: """AI 能力管理主类""" - def __init__(self, db_path: str = DB_PATH) -> None: - self.db_path = db_path - self.kimi_api_key = os.getenv("KIMI_API_KEY", "") - self.kimi_base_url = os.getenv("KIMI_BASE_URL", "https://api.kimi.com/coding") - self.openai_api_key = os.getenv("OPENAI_API_KEY", "") - self.anthropic_api_key = os.getenv("ANTHROPIC_API_KEY", "") + def __init__(self, db_path: str = DB_PATH) -> None: + self.db_path = db_path + self.kimi_api_key = os.getenv("KIMI_API_KEY", "") + self.kimi_base_url = os.getenv("KIMI_BASE_URL", "https://api.kimi.com/coding") + self.openai_api_key = os.getenv("OPENAI_API_KEY", "") + self.anthropic_api_key = os.getenv("ANTHROPIC_API_KEY", "") def _get_db(self) -> sqlite3.Connection: """获取数据库连接""" - conn = sqlite3.connect(self.db_path) - conn.row_factory = sqlite3.Row + conn = sqlite3.connect(self.db_path) + conn.row_factory = sqlite3.Row return conn # ==================== 自定义模型训练 ==================== @@ -230,24 +230,24 @@ class AIManager: created_by: str, ) -> CustomModel: """创建自定义模型""" - model_id = f"cm_{uuid.uuid4().hex[:16]}" - now = datetime.now().isoformat() + model_id = f"cm_{uuid.uuid4().hex[:16]}" + now = datetime.now().isoformat() - model = CustomModel( - id=model_id, - tenant_id=tenant_id, - name=name, - description=description, - model_type=model_type, - status=ModelStatus.PENDING, - training_data=training_data, - hyperparameters=hyperparameters, - metrics={}, - model_path=None, - created_at=now, - updated_at=now, - trained_at=None, - created_by=created_by, + model = CustomModel( + id = model_id, + tenant_id = tenant_id, + name = name, + description = description, + model_type = model_type, + status = ModelStatus.PENDING, + training_data = training_data, + hyperparameters = hyperparameters, + metrics = {}, + model_path = None, + created_at = now, + updated_at = now, + trained_at = None, + created_by = created_by, ) with self._get_db() as conn: @@ -283,7 +283,7 @@ class AIManager: def get_custom_model(self, model_id: str) -> CustomModel | None: """获取自定义模型""" with self._get_db() as conn: - row = conn.execute("SELECT * FROM custom_models WHERE id = ?", (model_id,)).fetchone() + row = conn.execute("SELECT * FROM custom_models WHERE id = ?", (model_id, )).fetchone() if not row: return None @@ -291,39 +291,39 @@ class AIManager: return self._row_to_custom_model(row) def list_custom_models( - self, tenant_id: str, model_type: ModelType | None = None, status: ModelStatus | None = None + self, tenant_id: str, model_type: ModelType | None = None, status: ModelStatus | None = None ) -> list[CustomModel]: """列出自定义模型""" - query = "SELECT * FROM custom_models WHERE tenant_id = ?" - params = [tenant_id] + query = "SELECT * FROM custom_models WHERE tenant_id = ?" + params = [tenant_id] if model_type: - query += " AND model_type = ?" + query += " AND model_type = ?" params.append(model_type.value) if status: - query += " AND status = ?" + query += " AND status = ?" params.append(status.value) query += " ORDER BY created_at DESC" with self._get_db() as conn: - rows = conn.execute(query, params).fetchall() + rows = conn.execute(query, params).fetchall() return [self._row_to_custom_model(row) for row in rows] def add_training_sample( - self, model_id: str, text: str, entities: list[dict], metadata: dict = None + self, model_id: str, text: str, entities: list[dict], metadata: dict = None ) -> TrainingSample: """添加训练样本""" - sample_id = f"ts_{uuid.uuid4().hex[:16]}" - now = datetime.now().isoformat() + sample_id = f"ts_{uuid.uuid4().hex[:16]}" + now = datetime.now().isoformat() - sample = TrainingSample( - id=sample_id, - model_id=model_id, - text=text, - entities=entities, - metadata=metadata or {}, - created_at=now, + sample = TrainingSample( + id = sample_id, + model_id = model_id, + text = text, + entities = entities, + metadata = metadata or {}, + created_at = now, ) with self._get_db() as conn: @@ -349,29 +349,29 @@ class AIManager: def get_training_samples(self, model_id: str) -> list[TrainingSample]: """获取训练样本""" with self._get_db() as conn: - rows = conn.execute( - "SELECT * FROM training_samples WHERE model_id = ? ORDER BY created_at", (model_id,) + rows = conn.execute( + "SELECT * FROM training_samples WHERE model_id = ? ORDER BY created_at", (model_id, ) ).fetchall() return [self._row_to_training_sample(row) for row in rows] async def train_custom_model(self, model_id: str) -> CustomModel: """训练自定义模型""" - model = self.get_custom_model(model_id) + model = self.get_custom_model(model_id) if not model: raise ValueError(f"Model {model_id} not found") # 更新状态为训练中 with self._get_db() as conn: conn.execute( - "UPDATE custom_models SET status = ?, updated_at = ? WHERE id = ?", + "UPDATE custom_models SET status = ?, updated_at = ? WHERE id = ?", (ModelStatus.TRAINING.value, datetime.now().isoformat(), model_id), ) conn.commit() try: # 获取训练样本 - samples = self.get_training_samples(model_id) + samples = self.get_training_samples(model_id) if len(samples) < 10: raise ValueError("至少需要 10 个训练样本") @@ -380,7 +380,7 @@ class AIManager: await asyncio.sleep(2) # 模拟训练时间 # 计算训练指标 - metrics = { + metrics = { "samples_count": len(samples), "epochs": model.hyperparameters.get("epochs", 10), "learning_rate": model.hyperparameters.get("learning_rate", 0.001), @@ -391,17 +391,17 @@ class AIManager: } # 保存模型(模拟) - model_path = f"models/{model_id}.bin" - os.makedirs("models", exist_ok=True) + model_path = f"models/{model_id}.bin" + os.makedirs("models", exist_ok = True) - now = datetime.now().isoformat() + now = datetime.now().isoformat() with self._get_db() as conn: conn.execute( """ UPDATE custom_models - SET status = ?, metrics = ?, model_path = ?, trained_at = ?, updated_at = ? - WHERE id = ? + SET status = ?, metrics = ?, model_path = ?, trained_at = ?, updated_at = ? + WHERE id = ? """, (ModelStatus.READY.value, json.dumps(metrics), model_path, now, now, model_id), ) @@ -412,7 +412,7 @@ class AIManager: except Exception as e: with self._get_db() as conn: conn.execute( - "UPDATE custom_models SET status = ?, updated_at = ? WHERE id = ?", + "UPDATE custom_models SET status = ?, updated_at = ? WHERE id = ?", (ModelStatus.FAILED.value, datetime.now().isoformat(), model_id), ) conn.commit() @@ -420,49 +420,49 @@ class AIManager: async def predict_with_custom_model(self, model_id: str, text: str) -> list[dict]: """使用自定义模型进行预测""" - model = self.get_custom_model(model_id) + model = self.get_custom_model(model_id) if not model or model.status != ModelStatus.READY: raise ValueError(f"Model {model_id} not ready") # 模拟预测(实际项目中加载模型并进行推理) # 这里使用 LLM 模拟领域特定实体识别 - entity_types = model.training_data.get("entity_types", ["PERSON", "ORG", "TECH", "PROJECT"]) + entity_types = model.training_data.get("entity_types", ["PERSON", "ORG", "TECH", "PROJECT"]) - prompt = f"""从以下文本中提取实体,类型限定为: {", ".join(entity_types)} + prompt = f"""从以下文本中提取实体,类型限定为: {", ".join(entity_types)} 文本: {text} 以 JSON 格式返回实体列表: [{{"text": "实体文本", "label": "类型", "start": 0, "end": 5, "confidence": 0.95}}] 只返回 JSON 数组,不要其他内容。""" - headers = { + headers = { "Authorization": f"Bearer {self.kimi_api_key}", "Content-Type": "application/json", } - payload = { + payload = { "model": "k2p5", "messages": [{"role": "user", "content": prompt}], "temperature": 0.1, } async with httpx.AsyncClient() as client: - response = await client.post( + response = await client.post( f"{self.kimi_base_url}/v1/chat/completions", - headers=headers, - json=payload, - timeout=60.0, + headers = headers, + json = payload, + timeout = 60.0, ) response.raise_for_status() - result = response.json() - content = result["choices"][0]["message"]["content"] + result = response.json() + content = result["choices"][0]["message"]["content"] # 解析 JSON - json_match = re.search(r"\[.*?\]", content, re.DOTALL) + json_match = re.search(r"\[.*?\]", content, re.DOTALL) if json_match: try: - entities = json.loads(json_match.group()) + entities = json.loads(json_match.group()) return entities except (json.JSONDecodeError, ValueError): pass @@ -481,30 +481,30 @@ class AIManager: prompt: str, ) -> MultimodalAnalysis: """多模态分析""" - analysis_id = f"ma_{uuid.uuid4().hex[:16]}" - now = datetime.now().isoformat() + analysis_id = f"ma_{uuid.uuid4().hex[:16]}" + now = datetime.now().isoformat() # 根据提供商调用不同的 API if provider == MultimodalProvider.GPT4V and self.openai_api_key: - result = await self._call_gpt4v(input_urls, prompt) + result = await self._call_gpt4v(input_urls, prompt) elif provider == MultimodalProvider.CLAUDE3 and self.anthropic_api_key: - result = await self._call_claude3(input_urls, prompt) + result = await self._call_claude3(input_urls, prompt) else: # 默认使用 Kimi - result = await self._call_kimi_multimodal(input_urls, prompt) + result = await self._call_kimi_multimodal(input_urls, prompt) - analysis = MultimodalAnalysis( - id=analysis_id, - tenant_id=tenant_id, - project_id=project_id, - provider=provider, - input_type=input_type, - input_urls=input_urls, - prompt=prompt, - result=result, - tokens_used=result.get("tokens_used", 0), - cost=result.get("cost", 0.0), - created_at=now, + analysis = MultimodalAnalysis( + id = analysis_id, + tenant_id = tenant_id, + project_id = project_id, + provider = provider, + input_type = input_type, + input_urls = input_urls, + prompt = prompt, + result = result, + tokens_used = result.get("tokens_used", 0), + cost = result.get("cost", 0.0), + created_at = now, ) with self._get_db() as conn: @@ -535,30 +535,30 @@ class AIManager: async def _call_gpt4v(self, image_urls: list[str], prompt: str) -> dict: """调用 GPT-4V""" - headers = { + headers = { "Authorization": f"Bearer {self.openai_api_key}", "Content-Type": "application/json", } - content = [{"type": "text", "text": prompt}] + content = [{"type": "text", "text": prompt}] for url in image_urls: content.append({"type": "image_url", "image_url": {"url": url}}) - payload = { + payload = { "model": "gpt-4-vision-preview", "messages": [{"role": "user", "content": content}], "max_tokens": 2000, } async with httpx.AsyncClient() as client: - response = await client.post( + response = await client.post( "https://api.openai.com/v1/chat/completions", - headers=headers, - json=payload, - timeout=120.0, + headers = headers, + json = payload, + timeout = 120.0, ) response.raise_for_status() - result = response.json() + result = response.json() return { "content": result["choices"][0]["message"]["content"], @@ -568,32 +568,32 @@ class AIManager: async def _call_claude3(self, image_urls: list[str], prompt: str) -> dict: """调用 Claude 3""" - headers = { + headers = { "x-api-key": self.anthropic_api_key, "Content-Type": "application/json", "anthropic-version": "2023-06-01", } - content = [] + content = [] for url in image_urls: content.append({"type": "image", "source": {"type": "url", "url": url}}) content.append({"type": "text", "text": prompt}) - payload = { + payload = { "model": "claude-3-opus-20240229", "max_tokens": 2000, "messages": [{"role": "user", "content": content}], } async with httpx.AsyncClient() as client: - response = await client.post( + response = await client.post( "https://api.anthropic.com/v1/messages", - headers=headers, - json=payload, - timeout=120.0, + headers = headers, + json = payload, + timeout = 120.0, ) response.raise_for_status() - result = response.json() + result = response.json() return { "content": result["content"][0]["text"], @@ -604,7 +604,7 @@ class AIManager: async def _call_kimi_multimodal(self, image_urls: list[str], prompt: str) -> dict: """调用 Kimi 多模态模型""" - headers = { + headers = { "Authorization": f"Bearer {self.kimi_api_key}", "Content-Type": "application/json", } @@ -612,23 +612,23 @@ class AIManager: # Kimi 目前可能不支持真正的多模态,这里模拟返回 # 实际实现时需要根据 Kimi API 更新 - content = f"图片 URL: {', '.join(image_urls)}\n\n{prompt}\n\n注意:请基于图片 URL 描述的内容进行回答。" + content = f"图片 URL: {', '.join(image_urls)}\n\n{prompt}\n\n注意:请基于图片 URL 描述的内容进行回答。" - payload = { + payload = { "model": "k2p5", "messages": [{"role": "user", "content": content}], "temperature": 0.3, } async with httpx.AsyncClient() as client: - response = await client.post( + response = await client.post( f"{self.kimi_base_url}/v1/chat/completions", - headers=headers, - json=payload, - timeout=60.0, + headers = headers, + json = payload, + timeout = 60.0, ) response.raise_for_status() - result = response.json() + result = response.json() return { "content": result["choices"][0]["message"]["content"], @@ -637,20 +637,20 @@ class AIManager: } def get_multimodal_analyses( - self, tenant_id: str, project_id: str | None = None + self, tenant_id: str, project_id: str | None = None ) -> list[MultimodalAnalysis]: """获取多模态分析历史""" - query = "SELECT * FROM multimodal_analyses WHERE tenant_id = ?" - params = [tenant_id] + query = "SELECT * FROM multimodal_analyses WHERE tenant_id = ?" + params = [tenant_id] if project_id: - query += " AND project_id = ?" + query += " AND project_id = ?" params.append(project_id) query += " ORDER BY created_at DESC" with self._get_db() as conn: - rows = conn.execute(query, params).fetchall() + rows = conn.execute(query, params).fetchall() return [self._row_to_multimodal_analysis(row) for row in rows] # ==================== 智能摘要与问答(基于知识图谱的 RAG) ==================== @@ -666,21 +666,21 @@ class AIManager: generation_config: dict, ) -> KnowledgeGraphRAG: """创建知识图谱 RAG 配置""" - rag_id = f"kgr_{uuid.uuid4().hex[:16]}" - now = datetime.now().isoformat() + rag_id = f"kgr_{uuid.uuid4().hex[:16]}" + now = datetime.now().isoformat() - rag = KnowledgeGraphRAG( - id=rag_id, - tenant_id=tenant_id, - project_id=project_id, - name=name, - description=description, - kg_config=kg_config, - retrieval_config=retrieval_config, - generation_config=generation_config, - is_active=True, - created_at=now, - updated_at=now, + rag = KnowledgeGraphRAG( + id = rag_id, + tenant_id = tenant_id, + project_id = project_id, + name = name, + description = description, + kg_config = kg_config, + retrieval_config = retrieval_config, + generation_config = generation_config, + is_active = True, + created_at = now, + updated_at = now, ) with self._get_db() as conn: @@ -712,7 +712,7 @@ class AIManager: def get_kg_rag(self, rag_id: str) -> KnowledgeGraphRAG | None: """获取知识图谱 RAG 配置""" with self._get_db() as conn: - row = conn.execute("SELECT * FROM kg_rag_configs WHERE id = ?", (rag_id,)).fetchone() + row = conn.execute("SELECT * FROM kg_rag_configs WHERE id = ?", (rag_id, )).fetchone() if not row: return None @@ -720,43 +720,43 @@ class AIManager: return self._row_to_kg_rag(row) def list_kg_rags( - self, tenant_id: str, project_id: str | None = None + self, tenant_id: str, project_id: str | None = None ) -> list[KnowledgeGraphRAG]: """列出知识图谱 RAG 配置""" - query = "SELECT * FROM kg_rag_configs WHERE tenant_id = ?" - params = [tenant_id] + query = "SELECT * FROM kg_rag_configs WHERE tenant_id = ?" + params = [tenant_id] if project_id: - query += " AND project_id = ?" + query += " AND project_id = ?" params.append(project_id) query += " ORDER BY created_at DESC" with self._get_db() as conn: - rows = conn.execute(query, params).fetchall() + rows = conn.execute(query, params).fetchall() return [self._row_to_kg_rag(row) for row in rows] async def query_kg_rag( self, rag_id: str, query: str, project_entities: list[dict], project_relations: list[dict] ) -> RAGQuery: """基于知识图谱的 RAG 查询""" - start_time = time.time() + start_time = time.time() - rag = self.get_kg_rag(rag_id) + rag = self.get_kg_rag(rag_id) if not rag: raise ValueError(f"RAG config {rag_id} not found") # 1. 检索相关实体和关系 - retrieval_config = rag.retrieval_config - top_k = retrieval_config.get("top_k", 5) + retrieval_config = rag.retrieval_config + top_k = retrieval_config.get("top_k", 5) # 简单的语义检索(基于实体名称匹配) - query_lower = query.lower() - relevant_entities = [] + query_lower = query.lower() + relevant_entities = [] for entity in project_entities: - score = 0 - name = entity.get("name", "").lower() - definition = entity.get("definition", "").lower() + score = 0 + name = entity.get("name", "").lower() + definition = entity.get("definition", "").lower() if name in query_lower or any(word in name for word in query_lower.split()): score += 0.5 @@ -766,12 +766,12 @@ class AIManager: if score > 0: relevant_entities.append({**entity, "relevance_score": score}) - relevant_entities.sort(key=lambda x: x["relevance_score"], reverse=True) - relevant_entities = relevant_entities[:top_k] + relevant_entities.sort(key = lambda x: x["relevance_score"], reverse = True) + relevant_entities = relevant_entities[:top_k] # 检索相关关系 - relevant_relations = [] - entity_ids = {e["id"] for e in relevant_entities} + relevant_relations = [] + entity_ids = {e["id"] for e in relevant_entities} for relation in project_relations: if ( relation.get("source_entity_id") in entity_ids @@ -780,16 +780,16 @@ class AIManager: relevant_relations.append(relation) # 2. 构建上下文 - context = {"entities": relevant_entities, "relations": relevant_relations[:10]} + context = {"entities": relevant_entities, "relations": relevant_relations[:10]} - context_text = self._build_kg_context(relevant_entities, relevant_relations) + context_text = self._build_kg_context(relevant_entities, relevant_relations) # 3. 生成回答 - generation_config = rag.generation_config - temperature = generation_config.get("temperature", 0.3) - max_tokens = generation_config.get("max_tokens", 1000) + generation_config = rag.generation_config + temperature = generation_config.get("temperature", 0.3) + max_tokens = generation_config.get("max_tokens", 1000) - prompt = f"""基于以下知识图谱信息回答问题: + prompt = f"""基于以下知识图谱信息回答问题: ## 知识图谱上下文 {context_text} @@ -803,12 +803,12 @@ class AIManager: 2. 如果涉及多个实体,说明它们之间的关联 3. 保持简洁专业""" - headers = { + headers = { "Authorization": f"Bearer {self.kimi_api_key}", "Content-Type": "application/json", } - payload = { + payload = { "model": "k2p5", "messages": [{"role": "user", "content": prompt}], "temperature": temperature, @@ -816,44 +816,44 @@ class AIManager: } async with httpx.AsyncClient() as client: - response = await client.post( + response = await client.post( f"{self.kimi_base_url}/v1/chat/completions", - headers=headers, - json=payload, - timeout=60.0, + headers = headers, + json = payload, + timeout = 60.0, ) response.raise_for_status() - result = response.json() + result = response.json() - answer = result["choices"][0]["message"]["content"] - tokens_used = result["usage"]["total_tokens"] + answer = result["choices"][0]["message"]["content"] + tokens_used = result["usage"]["total_tokens"] - latency_ms = int((time.time() - start_time) * 1000) + latency_ms = int((time.time() - start_time) * 1000) # 4. 保存查询记录 - query_id = f"rq_{uuid.uuid4().hex[:16]}" - now = datetime.now().isoformat() + query_id = f"rq_{uuid.uuid4().hex[:16]}" + now = datetime.now().isoformat() - sources = [ + sources = [ {"entity_id": e["id"], "entity_name": e["name"], "score": e["relevance_score"]} for e in relevant_entities ] - rag_query = RAGQuery( - id=query_id, - rag_id=rag_id, - query=query, - context=context, - answer=answer, - sources=sources, - confidence=( + rag_query = RAGQuery( + id = query_id, + rag_id = rag_id, + query = query, + context = context, + answer = answer, + sources = sources, + confidence = ( sum(e["relevance_score"] for e in relevant_entities) / len(relevant_entities) if relevant_entities else 0 ), - tokens_used=tokens_used, - latency_ms=latency_ms, - created_at=now, + tokens_used = tokens_used, + latency_ms = latency_ms, + created_at = now, ) with self._get_db() as conn: @@ -883,23 +883,23 @@ class AIManager: def _build_kg_context(self, entities: list[dict], relations: list[dict]) -> str: """构建知识图谱上下文文本""" - context = [] + context = [] if entities: context.append("### 相关实体") for entity in entities: - name = entity.get("name", "") - entity_type = entity.get("type", "") - definition = entity.get("definition", "") + name = entity.get("name", "") + entity_type = entity.get("type", "") + definition = entity.get("definition", "") context.append(f"- **{name}** ({entity_type}): {definition}") if relations: context.append("\n### 相关关系") for relation in relations: - source = relation.get("source_name", "") - target = relation.get("target_name", "") - rel_type = relation.get("relation_type", "") - evidence = relation.get("evidence", "") + source = relation.get("source_name", "") + target = relation.get("target_name", "") + rel_type = relation.get("relation_type", "") + evidence = relation.get("evidence", "") context.append(f"- {source} --[{rel_type}]--> {target}") if evidence: context.append(f" - 依据: {evidence[:100]}...") @@ -916,12 +916,12 @@ class AIManager: content_data: dict, ) -> SmartSummary: """生成智能摘要""" - summary_id = f"ss_{uuid.uuid4().hex[:16]}" - now = datetime.now().isoformat() + summary_id = f"ss_{uuid.uuid4().hex[:16]}" + now = datetime.now().isoformat() # 根据摘要类型生成不同的提示 if summary_type == "extractive": - prompt = f"""从以下内容中提取关键句子作为摘要: + prompt = f"""从以下内容中提取关键句子作为摘要: {content_data.get("text", "")[:5000]} @@ -931,7 +931,7 @@ class AIManager: 3. 以 JSON 格式返回: {{"summary": "摘要内容", "key_points": ["要点1", "要点2"]}}""" elif summary_type == "abstractive": - prompt = f"""对以下内容生成简洁的摘要: + prompt = f"""对以下内容生成简洁的摘要: {content_data.get("text", "")[:5000]} @@ -941,7 +941,7 @@ class AIManager: 3. 包含关键实体和概念""" elif summary_type == "key_points": - prompt = f"""从以下内容中提取关键要点: + prompt = f"""从以下内容中提取关键要点: {content_data.get("text", "")[:5000]} @@ -951,7 +951,7 @@ class AIManager: 3. 以 JSON 格式返回: {{"key_points": ["要点1", "要点2", ...]}}""" else: # timeline - prompt = f"""基于以下内容生成时间线摘要: + prompt = f"""基于以下内容生成时间线摘要: {content_data.get("text", "")[:5000]} @@ -960,72 +960,72 @@ class AIManager: 2. 标注时间节点(如果有) 3. 突出里程碑事件""" - headers = { + headers = { "Authorization": f"Bearer {self.kimi_api_key}", "Content-Type": "application/json", } - payload = { + payload = { "model": "k2p5", "messages": [{"role": "user", "content": prompt}], "temperature": 0.3, } async with httpx.AsyncClient() as client: - response = await client.post( + response = await client.post( f"{self.kimi_base_url}/v1/chat/completions", - headers=headers, - json=payload, - timeout=60.0, + headers = headers, + json = payload, + timeout = 60.0, ) response.raise_for_status() - result = response.json() + result = response.json() - content = result["choices"][0]["message"]["content"] - tokens_used = result["usage"]["total_tokens"] + content = result["choices"][0]["message"]["content"] + tokens_used = result["usage"]["total_tokens"] # 解析关键要点 - key_points = [] + key_points = [] # 尝试从 JSON 中提取 - json_match = re.search(r"\{.*?\}", content, re.DOTALL) + json_match = re.search(r"\{.*?\}", content, re.DOTALL) if json_match: try: - data = json.loads(json_match.group()) - key_points = data.get("key_points", []) + data = json.loads(json_match.group()) + key_points = data.get("key_points", []) if "summary" in data: - content = data["summary"] + content = data["summary"] except (json.JSONDecodeError, ValueError): pass # 如果没有提取到关键要点,从文本中提取 if not key_points: - lines = content.split("\n") - key_points = [ + lines = content.split("\n") + key_points = [ line.strip("- ").strip() for line in lines if line.strip().startswith("-") or line.strip().startswith("•") ] if not key_points: - key_points = [content[:200] + "..."] if len(content) > 200 else [content] + key_points = [content[:200] + "..."] if len(content) > 200 else [content] # 提取提及的实体 - entities_mentioned = content_data.get("entities", []) - entity_names = [e.get("name", "") for e in entities_mentioned[:10]] + entities_mentioned = content_data.get("entities", []) + entity_names = [e.get("name", "") for e in entities_mentioned[:10]] - summary = SmartSummary( - id=summary_id, - tenant_id=tenant_id, - project_id=project_id, - source_type=source_type, - source_id=source_id, - summary_type=summary_type, - content=content, - key_points=key_points[:8], - entities_mentioned=entity_names, - confidence=0.85, - tokens_used=tokens_used, - created_at=now, + summary = SmartSummary( + id = summary_id, + tenant_id = tenant_id, + project_id = project_id, + source_type = source_type, + source_id = source_id, + summary_type = summary_type, + content = content, + key_points = key_points[:8], + entities_mentioned = entity_names, + confidence = 0.85, + tokens_used = tokens_used, + created_at = now, ) with self._get_db() as conn: @@ -1068,24 +1068,24 @@ class AIManager: model_config: dict, ) -> PredictionModel: """创建预测模型""" - model_id = f"pm_{uuid.uuid4().hex[:16]}" - now = datetime.now().isoformat() + model_id = f"pm_{uuid.uuid4().hex[:16]}" + now = datetime.now().isoformat() - model = PredictionModel( - id=model_id, - tenant_id=tenant_id, - project_id=project_id, - name=name, - prediction_type=prediction_type, - target_entity_type=target_entity_type, - features=features, - model_config=model_config, - accuracy=None, - last_trained_at=None, - prediction_count=0, - is_active=True, - created_at=now, - updated_at=now, + model = PredictionModel( + id = model_id, + tenant_id = tenant_id, + project_id = project_id, + name = name, + prediction_type = prediction_type, + target_entity_type = target_entity_type, + features = features, + model_config = model_config, + accuracy = None, + last_trained_at = None, + prediction_count = 0, + is_active = True, + created_at = now, + updated_at = now, ) with self._get_db() as conn: @@ -1121,8 +1121,8 @@ class AIManager: def get_prediction_model(self, model_id: str) -> PredictionModel | None: """获取预测模型""" with self._get_db() as conn: - row = conn.execute( - "SELECT * FROM prediction_models WHERE id = ?", (model_id,) + row = conn.execute( + "SELECT * FROM prediction_models WHERE id = ?", (model_id, ) ).fetchone() if not row: @@ -1131,27 +1131,27 @@ class AIManager: return self._row_to_prediction_model(row) def list_prediction_models( - self, tenant_id: str, project_id: str | None = None + self, tenant_id: str, project_id: str | None = None ) -> list[PredictionModel]: """列出预测模型""" - query = "SELECT * FROM prediction_models WHERE tenant_id = ?" - params = [tenant_id] + query = "SELECT * FROM prediction_models WHERE tenant_id = ?" + params = [tenant_id] if project_id: - query += " AND project_id = ?" + query += " AND project_id = ?" params.append(project_id) query += " ORDER BY created_at DESC" with self._get_db() as conn: - rows = conn.execute(query, params).fetchall() + rows = conn.execute(query, params).fetchall() return [self._row_to_prediction_model(row) for row in rows] async def train_prediction_model( self, model_id: str, historical_data: list[dict] ) -> PredictionModel: """训练预测模型""" - model = self.get_prediction_model(model_id) + model = self.get_prediction_model(model_id) if not model: raise ValueError(f"Prediction model {model_id} not found") @@ -1159,16 +1159,16 @@ class AIManager: await asyncio.sleep(1) # 计算准确率(模拟) - accuracy = round(0.75 + random.random() * 0.2, 4) + accuracy = round(0.75 + random.random() * 0.2, 4) - now = datetime.now().isoformat() + now = datetime.now().isoformat() with self._get_db() as conn: conn.execute( """ UPDATE prediction_models - SET accuracy = ?, last_trained_at = ?, updated_at = ? - WHERE id = ? + SET accuracy = ?, last_trained_at = ?, updated_at = ? + WHERE id = ? """, (accuracy, now, now, model_id), ) @@ -1178,39 +1178,39 @@ class AIManager: async def predict(self, model_id: str, input_data: dict) -> PredictionResult: """进行预测""" - model = self.get_prediction_model(model_id) + model = self.get_prediction_model(model_id) if not model or not model.is_active: raise ValueError(f"Prediction model {model_id} not available") - prediction_id = f"pr_{uuid.uuid4().hex[:16]}" - now = datetime.now().isoformat() + prediction_id = f"pr_{uuid.uuid4().hex[:16]}" + now = datetime.now().isoformat() # 根据预测类型进行不同的预测逻辑 if model.prediction_type == PredictionType.TREND: - prediction_data = self._predict_trend(input_data, model) + prediction_data = self._predict_trend(input_data, model) elif model.prediction_type == PredictionType.ANOMALY: - prediction_data = self._detect_anomaly(input_data, model) + prediction_data = self._detect_anomaly(input_data, model) elif model.prediction_type == PredictionType.ENTITY_GROWTH: - prediction_data = self._predict_entity_growth(input_data, model) + prediction_data = self._predict_entity_growth(input_data, model) elif model.prediction_type == PredictionType.RELATION_EVOLUTION: - prediction_data = self._predict_relation_evolution(input_data, model) + prediction_data = self._predict_relation_evolution(input_data, model) else: - prediction_data = {"value": "unknown", "confidence": 0} + prediction_data = {"value": "unknown", "confidence": 0} - confidence = prediction_data.get("confidence", 0.8) - explanation = prediction_data.get("explanation", "基于历史数据模式预测") + confidence = prediction_data.get("confidence", 0.8) + explanation = prediction_data.get("explanation", "基于历史数据模式预测") - result = PredictionResult( - id=prediction_id, - model_id=model_id, - prediction_type=model.prediction_type, - target_id=input_data.get("target_id"), - prediction_data=prediction_data, - confidence=confidence, - explanation=explanation, - actual_value=None, - is_correct=None, - created_at=now, + result = PredictionResult( + id = prediction_id, + model_id = model_id, + prediction_type = model.prediction_type, + target_id = input_data.get("target_id"), + prediction_data = prediction_data, + confidence = confidence, + explanation = explanation, + actual_value = None, + is_correct = None, + created_at = now, ) with self._get_db() as conn: @@ -1237,8 +1237,8 @@ class AIManager: # 更新预测计数 conn.execute( - "UPDATE prediction_models SET prediction_count = prediction_count + 1 WHERE id = ?", - (model_id,), + "UPDATE prediction_models SET prediction_count = prediction_count + 1 WHERE id = ?", + (model_id, ), ) conn.commit() @@ -1246,7 +1246,7 @@ class AIManager: def _predict_trend(self, input_data: dict, model: PredictionModel) -> dict: """趋势预测""" - historical_values = input_data.get("historical_values", []) + historical_values = input_data.get("historical_values", []) if len(historical_values) < 2: return { @@ -1257,23 +1257,23 @@ class AIManager: } # 简单线性趋势预测 - 使用最小二乘法计算斜率 - n = len(historical_values) - x = list(range(n)) - y = historical_values + n = len(historical_values) + x = list(range(n)) + y = historical_values # 计算均值 - mean_x = sum(x) / n - mean_y = sum(y) / n + mean_x = sum(x) / n + mean_y = sum(y) / n # 计算斜率 (最小二乘法) - numerator = sum((x[i] - mean_x) * (y[i] - mean_y) for i in range(n)) - denominator = sum((x[i] - mean_x) ** 2 for i in range(n)) - slope = numerator / denominator if denominator != 0 else 0 + numerator = sum((x[i] - mean_x) * (y[i] - mean_y) for i in range(n)) + denominator = sum((x[i] - mean_x) ** 2 for i in range(n)) + slope = numerator / denominator if denominator != 0 else 0 # 预测下一个值 - next_value = y[-1] + slope + next_value = y[-1] + slope - trend = "increasing" if slope > 0.01 else "decreasing" if slope < -0.01 else "stable" + trend = "increasing" if slope > 0.01 else "decreasing" if slope < -0.01 else "stable" return { "predicted_value": round(next_value, 2), @@ -1285,8 +1285,8 @@ class AIManager: def _detect_anomaly(self, input_data: dict, model: PredictionModel) -> dict: """异常检测""" - value = input_data.get("value") - historical_values = input_data.get("historical_values", []) + value = input_data.get("value") + historical_values = input_data.get("historical_values", []) if not historical_values or value is None: return { @@ -1297,15 +1297,15 @@ class AIManager: } # 计算均值和标准差 - mean = statistics.mean(historical_values) - std = statistics.stdev(historical_values) if len(historical_values) > 1 else 0 + mean = statistics.mean(historical_values) + std = statistics.stdev(historical_values) if len(historical_values) > 1 else 0 if std == 0: - is_anomaly = value != mean - z_score = 0 if value == mean else 3 + is_anomaly = value != mean + z_score = 0 if value == mean else 3 else: - z_score = abs(value - mean) / std - is_anomaly = z_score > 2.5 # 2.5 个标准差视为异常 + z_score = abs(value - mean) / std + is_anomaly = z_score > 2.5 # 2.5 个标准差视为异常 return { "is_anomaly": is_anomaly, @@ -1319,7 +1319,7 @@ class AIManager: def _predict_entity_growth(self, input_data: dict, model: PredictionModel) -> dict: """实体增长预测""" - entity_history = input_data.get("entity_history", []) + entity_history = input_data.get("entity_history", []) if len(entity_history) < 3: return { @@ -1330,14 +1330,14 @@ class AIManager: } # 计算增长率 - counts = [h.get("count", 0) for h in entity_history] - growth_rates = [ + counts = [h.get("count", 0) for h in entity_history] + growth_rates = [ (counts[i] - counts[i - 1]) / max(counts[i - 1], 1) for i in range(1, len(counts)) ] - avg_growth_rate = statistics.mean(growth_rates) if growth_rates else 0 + avg_growth_rate = statistics.mean(growth_rates) if growth_rates else 0 # 预测下一个周期的实体数量 - predicted_count = counts[-1] * (1 + avg_growth_rate) + predicted_count = counts[-1] * (1 + avg_growth_rate) return { "predicted_count": round(predicted_count), @@ -1349,7 +1349,7 @@ class AIManager: def _predict_relation_evolution(self, input_data: dict, model: PredictionModel) -> dict: """关系演变预测""" - relation_history = input_data.get("relation_history", []) + relation_history = input_data.get("relation_history", []) if len(relation_history) < 2: return { @@ -1359,16 +1359,16 @@ class AIManager: } # 分析关系变化趋势 - relation_counts = defaultdict(int) + relation_counts = defaultdict(int) for snapshot in relation_history: for rel in snapshot.get("relations", []): relation_counts[rel.get("type", "unknown")] += 1 # 预测可能出现的新关系类型 - predicted_relations = [ + predicted_relations = [ {"type": rel_type, "likelihood": min(count / len(relation_history), 0.95)} for rel_type, count in sorted( - relation_counts.items(), key=lambda x: x[1], reverse=True + relation_counts.items(), key = lambda x: x[1], reverse = True )[:5] ] @@ -1379,12 +1379,12 @@ class AIManager: "explanation": f"基于{len(relation_history)}个历史快照分析关系演变趋势", } - def get_prediction_results(self, model_id: str, limit: int = 100) -> list[PredictionResult]: + def get_prediction_results(self, model_id: str, limit: int = 100) -> list[PredictionResult]: """获取预测结果历史""" with self._get_db() as conn: - rows = conn.execute( + rows = conn.execute( """SELECT * FROM prediction_results - WHERE model_id = ? + WHERE model_id = ? ORDER BY created_at DESC LIMIT ?""", (model_id, limit), @@ -1399,8 +1399,8 @@ class AIManager: with self._get_db() as conn: conn.execute( """UPDATE prediction_results - SET actual_value = ?, is_correct = ? - WHERE id = ?""", + SET actual_value = ?, is_correct = ? + WHERE id = ?""", (actual_value, is_correct, prediction_id), ) conn.commit() @@ -1410,106 +1410,106 @@ class AIManager: def _row_to_custom_model(self, row) -> CustomModel: """将数据库行转换为 CustomModel""" return CustomModel( - id=row["id"], - tenant_id=row["tenant_id"], - name=row["name"], - description=row["description"], - model_type=ModelType(row["model_type"]), - status=ModelStatus(row["status"]), - training_data=json.loads(row["training_data"]), - hyperparameters=json.loads(row["hyperparameters"]), - metrics=json.loads(row["metrics"]), - model_path=row["model_path"], - created_at=row["created_at"], - updated_at=row["updated_at"], - trained_at=row["trained_at"], - created_by=row["created_by"], + id = row["id"], + tenant_id = row["tenant_id"], + name = row["name"], + description = row["description"], + model_type = ModelType(row["model_type"]), + status = ModelStatus(row["status"]), + training_data = json.loads(row["training_data"]), + hyperparameters = json.loads(row["hyperparameters"]), + metrics = json.loads(row["metrics"]), + model_path = row["model_path"], + created_at = row["created_at"], + updated_at = row["updated_at"], + trained_at = row["trained_at"], + created_by = row["created_by"], ) def _row_to_training_sample(self, row) -> TrainingSample: """将数据库行转换为 TrainingSample""" return TrainingSample( - id=row["id"], - model_id=row["model_id"], - text=row["text"], - entities=json.loads(row["entities"]), - metadata=json.loads(row["metadata"]), - created_at=row["created_at"], + id = row["id"], + model_id = row["model_id"], + text = row["text"], + entities = json.loads(row["entities"]), + metadata = json.loads(row["metadata"]), + created_at = row["created_at"], ) def _row_to_multimodal_analysis(self, row) -> MultimodalAnalysis: """将数据库行转换为 MultimodalAnalysis""" return MultimodalAnalysis( - id=row["id"], - tenant_id=row["tenant_id"], - project_id=row["project_id"], - provider=MultimodalProvider(row["provider"]), - input_type=row["input_type"], - input_urls=json.loads(row["input_urls"]), - prompt=row["prompt"], - result=json.loads(row["result"]), - tokens_used=row["tokens_used"], - cost=row["cost"], - created_at=row["created_at"], + id = row["id"], + tenant_id = row["tenant_id"], + project_id = row["project_id"], + provider = MultimodalProvider(row["provider"]), + input_type = row["input_type"], + input_urls = json.loads(row["input_urls"]), + prompt = row["prompt"], + result = json.loads(row["result"]), + tokens_used = row["tokens_used"], + cost = row["cost"], + created_at = row["created_at"], ) def _row_to_kg_rag(self, row) -> KnowledgeGraphRAG: """将数据库行转换为 KnowledgeGraphRAG""" return KnowledgeGraphRAG( - id=row["id"], - tenant_id=row["tenant_id"], - project_id=row["project_id"], - name=row["name"], - description=row["description"], - kg_config=json.loads(row["kg_config"]), - retrieval_config=json.loads(row["retrieval_config"]), - generation_config=json.loads(row["generation_config"]), - is_active=bool(row["is_active"]), - created_at=row["created_at"], - updated_at=row["updated_at"], + id = row["id"], + tenant_id = row["tenant_id"], + project_id = row["project_id"], + name = row["name"], + description = row["description"], + kg_config = json.loads(row["kg_config"]), + retrieval_config = json.loads(row["retrieval_config"]), + generation_config = json.loads(row["generation_config"]), + is_active = bool(row["is_active"]), + created_at = row["created_at"], + updated_at = row["updated_at"], ) def _row_to_prediction_model(self, row) -> PredictionModel: """将数据库行转换为 PredictionModel""" return PredictionModel( - id=row["id"], - tenant_id=row["tenant_id"], - project_id=row["project_id"], - name=row["name"], - prediction_type=PredictionType(row["prediction_type"]), - target_entity_type=row["target_entity_type"], - features=json.loads(row["features"]), - model_config=json.loads(row["model_config"]), - accuracy=row["accuracy"], - last_trained_at=row["last_trained_at"], - prediction_count=row["prediction_count"], - is_active=bool(row["is_active"]), - created_at=row["created_at"], - updated_at=row["updated_at"], + id = row["id"], + tenant_id = row["tenant_id"], + project_id = row["project_id"], + name = row["name"], + prediction_type = PredictionType(row["prediction_type"]), + target_entity_type = row["target_entity_type"], + features = json.loads(row["features"]), + model_config = json.loads(row["model_config"]), + accuracy = row["accuracy"], + last_trained_at = row["last_trained_at"], + prediction_count = row["prediction_count"], + is_active = bool(row["is_active"]), + created_at = row["created_at"], + updated_at = row["updated_at"], ) def _row_to_prediction_result(self, row) -> PredictionResult: """将数据库行转换为 PredictionResult""" return PredictionResult( - id=row["id"], - model_id=row["model_id"], - prediction_type=PredictionType(row["prediction_type"]), - target_id=row["target_id"], - prediction_data=json.loads(row["prediction_data"]), - confidence=row["confidence"], - explanation=row["explanation"], - actual_value=row["actual_value"], - is_correct=row["is_correct"], - created_at=row["created_at"], + id = row["id"], + model_id = row["model_id"], + prediction_type = PredictionType(row["prediction_type"]), + target_id = row["target_id"], + prediction_data = json.loads(row["prediction_data"]), + confidence = row["confidence"], + explanation = row["explanation"], + actual_value = row["actual_value"], + is_correct = row["is_correct"], + created_at = row["created_at"], ) # Singleton instance -_ai_manager = None +_ai_manager = None def get_ai_manager() -> AIManager: global _ai_manager if _ai_manager is None: - _ai_manager = AIManager() + _ai_manager = AIManager() return _ai_manager diff --git a/backend/api_key_manager.py b/backend/api_key_manager.py index 04ee2cf..6a81461 100644 --- a/backend/api_key_manager.py +++ b/backend/api_key_manager.py @@ -13,13 +13,13 @@ from dataclasses import dataclass from datetime import datetime, timedelta from enum import Enum -DB_PATH = os.getenv("DB_PATH", "/app/data/insightflow.db") +DB_PATH = os.getenv("DB_PATH", "/app/data/insightflow.db") class ApiKeyStatus(Enum): - ACTIVE = "active" - REVOKED = "revoked" - EXPIRED = "expired" + ACTIVE = "active" + REVOKED = "revoked" + EXPIRED = "expired" @dataclass @@ -37,18 +37,18 @@ class ApiKey: last_used_at: str | None revoked_at: str | None revoked_reason: str | None - total_calls: int = 0 + total_calls: int = 0 class ApiKeyManager: """API Key 管理器""" # Key 前缀 - KEY_PREFIX = "ak_live_" - KEY_LENGTH = 48 # 总长度: 前缀(8) + 随机部分(40) + KEY_PREFIX = "ak_live_" + KEY_LENGTH = 48 # 总长度: 前缀(8) + 随机部分(40) - def __init__(self, db_path: str = DB_PATH) -> None: - self.db_path = db_path + def __init__(self, db_path: str = DB_PATH) -> None: + self.db_path = db_path self._init_db() def _init_db(self) -> None: @@ -117,7 +117,7 @@ class ApiKeyManager: def _generate_key(self) -> str: """生成新的 API Key""" # 生成 40 字符的随机字符串 - random_part = secrets.token_urlsafe(30)[:40] + random_part = secrets.token_urlsafe(30)[:40] return f"{self.KEY_PREFIX}{random_part}" def _hash_key(self, key: str) -> str: @@ -131,10 +131,10 @@ class ApiKeyManager: def create_key( self, name: str, - owner_id: str | None = None, - permissions: list[str] = None, - rate_limit: int = 60, - expires_days: int | None = None, + owner_id: str | None = None, + permissions: list[str] = None, + rate_limit: int = 60, + expires_days: int | None = None, ) -> tuple[str, ApiKey]: """ 创建新的 API Key @@ -143,32 +143,32 @@ class ApiKeyManager: tuple: (原始key(仅返回一次), ApiKey对象) """ if permissions is None: - permissions = ["read"] + permissions = ["read"] - key_id = secrets.token_hex(16) - raw_key = self._generate_key() - key_hash = self._hash_key(raw_key) - key_preview = self._get_preview(raw_key) + key_id = secrets.token_hex(16) + raw_key = self._generate_key() + key_hash = self._hash_key(raw_key) + key_preview = self._get_preview(raw_key) - expires_at = None + expires_at = None if expires_days: - expires_at = (datetime.now() + timedelta(days=expires_days)).isoformat() + expires_at = (datetime.now() + timedelta(days = expires_days)).isoformat() - api_key = ApiKey( - id=key_id, - key_hash=key_hash, - key_preview=key_preview, - name=name, - owner_id=owner_id, - permissions=permissions, - rate_limit=rate_limit, - status=ApiKeyStatus.ACTIVE.value, - created_at=datetime.now().isoformat(), - expires_at=expires_at, - last_used_at=None, - revoked_at=None, - revoked_reason=None, - total_calls=0, + api_key = ApiKey( + id = key_id, + key_hash = key_hash, + key_preview = key_preview, + name = name, + owner_id = owner_id, + permissions = permissions, + rate_limit = rate_limit, + status = ApiKeyStatus.ACTIVE.value, + created_at = datetime.now().isoformat(), + expires_at = expires_at, + last_used_at = None, + revoked_at = None, + revoked_reason = None, + total_calls = 0, ) with sqlite3.connect(self.db_path) as conn: @@ -203,16 +203,16 @@ class ApiKeyManager: Returns: ApiKey if valid, None otherwise """ - key_hash = self._hash_key(key) + key_hash = self._hash_key(key) with sqlite3.connect(self.db_path) as conn: - conn.row_factory = sqlite3.Row - row = conn.execute("SELECT * FROM api_keys WHERE key_hash = ?", (key_hash,)).fetchone() + conn.row_factory = sqlite3.Row + row = conn.execute("SELECT * FROM api_keys WHERE key_hash = ?", (key_hash, )).fetchone() if not row: return None - api_key = self._row_to_api_key(row) + api_key = self._row_to_api_key(row) # 检查状态 if api_key.status != ApiKeyStatus.ACTIVE.value: @@ -220,11 +220,11 @@ class ApiKeyManager: # 检查是否过期 if api_key.expires_at: - expires = datetime.fromisoformat(api_key.expires_at) + expires = datetime.fromisoformat(api_key.expires_at) if datetime.now() > expires: # 更新状态为过期 conn.execute( - "UPDATE api_keys SET status = ? WHERE id = ?", + "UPDATE api_keys SET status = ? WHERE id = ?", (ApiKeyStatus.EXPIRED.value, api_key.id), ) conn.commit() @@ -232,22 +232,22 @@ class ApiKeyManager: return api_key - def revoke_key(self, key_id: str, reason: str = "", owner_id: str | None = None) -> bool: + def revoke_key(self, key_id: str, reason: str = "", owner_id: str | None = None) -> bool: """撤销 API Key""" with sqlite3.connect(self.db_path) as conn: # 验证所有权(如果提供了 owner_id) if owner_id: - row = conn.execute( - "SELECT owner_id FROM api_keys WHERE id = ?", (key_id,) + row = conn.execute( + "SELECT owner_id FROM api_keys WHERE id = ?", (key_id, ) ).fetchone() if not row or row[0] != owner_id: return False - cursor = conn.execute( + cursor = conn.execute( """ UPDATE api_keys - SET status = ?, revoked_at = ?, revoked_reason = ? - WHERE id = ? AND status = ? + SET status = ?, revoked_at = ?, revoked_reason = ? + WHERE id = ? AND status = ? """, ( ApiKeyStatus.REVOKED.value, @@ -260,17 +260,17 @@ class ApiKeyManager: conn.commit() return cursor.rowcount > 0 - def get_key_by_id(self, key_id: str, owner_id: str | None = None) -> ApiKey | None: + def get_key_by_id(self, key_id: str, owner_id: str | None = None) -> ApiKey | None: """通过 ID 获取 API Key(不包含敏感信息)""" with sqlite3.connect(self.db_path) as conn: - conn.row_factory = sqlite3.Row + conn.row_factory = sqlite3.Row if owner_id: - row = conn.execute( - "SELECT * FROM api_keys WHERE id = ? AND owner_id = ?", (key_id, owner_id) + row = conn.execute( + "SELECT * FROM api_keys WHERE id = ? AND owner_id = ?", (key_id, owner_id) ).fetchone() else: - row = conn.execute("SELECT * FROM api_keys WHERE id = ?", (key_id,)).fetchone() + row = conn.execute("SELECT * FROM api_keys WHERE id = ?", (key_id, )).fetchone() if row: return self._row_to_api_key(row) @@ -278,54 +278,54 @@ class ApiKeyManager: def list_keys( self, - owner_id: str | None = None, - status: str | None = None, - limit: int = 100, - offset: int = 0, + owner_id: str | None = None, + status: str | None = None, + limit: int = 100, + offset: int = 0, ) -> list[ApiKey]: """列出 API Keys""" with sqlite3.connect(self.db_path) as conn: - conn.row_factory = sqlite3.Row + conn.row_factory = sqlite3.Row - query = "SELECT * FROM api_keys WHERE 1=1" - params = [] + query = "SELECT * FROM api_keys WHERE 1 = 1" + params = [] if owner_id: - query += " AND owner_id = ?" + query += " AND owner_id = ?" params.append(owner_id) if status: - query += " AND status = ?" + query += " AND status = ?" params.append(status) query += " ORDER BY created_at DESC LIMIT ? OFFSET ?" params.extend([limit, offset]) - rows = conn.execute(query, params).fetchall() + rows = conn.execute(query, params).fetchall() return [self._row_to_api_key(row) for row in rows] def update_key( self, key_id: str, - name: str | None = None, - permissions: list[str] | None = None, - rate_limit: int | None = None, - owner_id: str | None = None, + name: str | None = None, + permissions: list[str] | None = None, + rate_limit: int | None = None, + owner_id: str | None = None, ) -> bool: """更新 API Key 信息""" - updates = [] - params = [] + updates = [] + params = [] if name is not None: - updates.append("name = ?") + updates.append("name = ?") params.append(name) if permissions is not None: - updates.append("permissions = ?") + updates.append("permissions = ?") params.append(json.dumps(permissions)) if rate_limit is not None: - updates.append("rate_limit = ?") + updates.append("rate_limit = ?") params.append(rate_limit) if not updates: @@ -336,14 +336,14 @@ class ApiKeyManager: with sqlite3.connect(self.db_path) as conn: # 验证所有权 if owner_id: - row = conn.execute( - "SELECT owner_id FROM api_keys WHERE id = ?", (key_id,) + row = conn.execute( + "SELECT owner_id FROM api_keys WHERE id = ?", (key_id, ) ).fetchone() if not row or row[0] != owner_id: return False - query = f"UPDATE api_keys SET {', '.join(updates)} WHERE id = ?" - cursor = conn.execute(query, params) + query = f"UPDATE api_keys SET {', '.join(updates)} WHERE id = ?" + cursor = conn.execute(query, params) conn.commit() return cursor.rowcount > 0 @@ -353,8 +353,8 @@ class ApiKeyManager: conn.execute( """ UPDATE api_keys - SET last_used_at = ?, total_calls = total_calls + 1 - WHERE id = ? + SET last_used_at = ?, total_calls = total_calls + 1 + WHERE id = ? """, (datetime.now().isoformat(), key_id), ) @@ -365,12 +365,12 @@ class ApiKeyManager: api_key_id: str, endpoint: str, method: str, - status_code: int = 200, - response_time_ms: int = 0, - ip_address: str = "", - user_agent: str = "", - error_message: str = "", - ): + status_code: int = 200, + response_time_ms: int = 0, + ip_address: str = "", + user_agent: str = "", + error_message: str = "", + ) -> None: """记录 API 调用日志""" with sqlite3.connect(self.db_path) as conn: conn.execute( @@ -395,21 +395,21 @@ class ApiKeyManager: def get_call_logs( self, - api_key_id: str | None = None, - start_date: str | None = None, - end_date: str | None = None, - limit: int = 100, - offset: int = 0, + api_key_id: str | None = None, + start_date: str | None = None, + end_date: str | None = None, + limit: int = 100, + offset: int = 0, ) -> list[dict]: """获取 API 调用日志""" with sqlite3.connect(self.db_path) as conn: - conn.row_factory = sqlite3.Row + conn.row_factory = sqlite3.Row - query = "SELECT * FROM api_call_logs WHERE 1=1" - params = [] + query = "SELECT * FROM api_call_logs WHERE 1 = 1" + params = [] if api_key_id: - query += " AND api_key_id = ?" + query += " AND api_key_id = ?" params.append(api_key_id) if start_date: @@ -423,16 +423,16 @@ class ApiKeyManager: query += " ORDER BY created_at DESC LIMIT ? OFFSET ?" params.extend([limit, offset]) - rows = conn.execute(query, params).fetchall() + rows = conn.execute(query, params).fetchall() return [dict(row) for row in rows] - def get_call_stats(self, api_key_id: str | None = None, days: int = 30) -> dict: + def get_call_stats(self, api_key_id: str | None = None, days: int = 30) -> dict: """获取 API 调用统计""" with sqlite3.connect(self.db_path) as conn: - conn.row_factory = sqlite3.Row + conn.row_factory = sqlite3.Row # 总体统计 - query = f""" + query = f""" SELECT COUNT(*) as total_calls, COUNT(CASE WHEN status_code < 400 THEN 1 END) as success_calls, @@ -444,15 +444,15 @@ class ApiKeyManager: WHERE created_at >= date('now', '-{days} days') """ - params = [] + params = [] if api_key_id: - query = query.replace("WHERE created_at", "WHERE api_key_id = ? AND created_at") + query = query.replace("WHERE created_at", "WHERE api_key_id = ? AND created_at") params.insert(0, api_key_id) - row = conn.execute(query, params).fetchone() + row = conn.execute(query, params).fetchone() # 按端点统计 - endpoint_query = f""" + endpoint_query = f""" SELECT endpoint, method, @@ -462,19 +462,19 @@ class ApiKeyManager: WHERE created_at >= date('now', '-{days} days') """ - endpoint_params = [] + endpoint_params = [] if api_key_id: - endpoint_query = endpoint_query.replace( - "WHERE created_at", "WHERE api_key_id = ? AND created_at" + endpoint_query = endpoint_query.replace( + "WHERE created_at", "WHERE api_key_id = ? AND created_at" ) endpoint_params.insert(0, api_key_id) endpoint_query += " GROUP BY endpoint, method ORDER BY calls DESC" - endpoint_rows = conn.execute(endpoint_query, endpoint_params).fetchall() + endpoint_rows = conn.execute(endpoint_query, endpoint_params).fetchall() # 按天统计 - daily_query = f""" + daily_query = f""" SELECT date(created_at) as date, COUNT(*) as calls, @@ -483,16 +483,16 @@ class ApiKeyManager: WHERE created_at >= date('now', '-{days} days') """ - daily_params = [] + daily_params = [] if api_key_id: - daily_query = daily_query.replace( - "WHERE created_at", "WHERE api_key_id = ? AND created_at" + daily_query = daily_query.replace( + "WHERE created_at", "WHERE api_key_id = ? AND created_at" ) daily_params.insert(0, api_key_id) daily_query += " GROUP BY date(created_at) ORDER BY date" - daily_rows = conn.execute(daily_query, daily_params).fetchall() + daily_rows = conn.execute(daily_query, daily_params).fetchall() return { "summary": { @@ -510,30 +510,30 @@ class ApiKeyManager: def _row_to_api_key(self, row: sqlite3.Row) -> ApiKey: """将数据库行转换为 ApiKey 对象""" return ApiKey( - id=row["id"], - key_hash=row["key_hash"], - key_preview=row["key_preview"], - name=row["name"], - owner_id=row["owner_id"], - permissions=json.loads(row["permissions"]), - rate_limit=row["rate_limit"], - status=row["status"], - created_at=row["created_at"], - expires_at=row["expires_at"], - last_used_at=row["last_used_at"], - revoked_at=row["revoked_at"], - revoked_reason=row["revoked_reason"], - total_calls=row["total_calls"], + id = row["id"], + key_hash = row["key_hash"], + key_preview = row["key_preview"], + name = row["name"], + owner_id = row["owner_id"], + permissions = json.loads(row["permissions"]), + rate_limit = row["rate_limit"], + status = row["status"], + created_at = row["created_at"], + expires_at = row["expires_at"], + last_used_at = row["last_used_at"], + revoked_at = row["revoked_at"], + revoked_reason = row["revoked_reason"], + total_calls = row["total_calls"], ) # 全局实例 -_api_key_manager: ApiKeyManager | None = None +_api_key_manager: ApiKeyManager | None = None def get_api_key_manager() -> ApiKeyManager: """获取 API Key 管理器实例""" global _api_key_manager if _api_key_manager is None: - _api_key_manager = ApiKeyManager() + _api_key_manager = ApiKeyManager() return _api_key_manager diff --git a/backend/collaboration_manager.py b/backend/collaboration_manager.py index 40f99a4..0c33ad5 100644 --- a/backend/collaboration_manager.py +++ b/backend/collaboration_manager.py @@ -15,29 +15,29 @@ from typing import Any class SharePermission(Enum): """分享权限级别""" - READ_ONLY = "read_only" # 只读 - COMMENT = "comment" # 可评论 - EDIT = "edit" # 可编辑 - ADMIN = "admin" # 管理员 + READ_ONLY = "read_only" # 只读 + COMMENT = "comment" # 可评论 + EDIT = "edit" # 可编辑 + ADMIN = "admin" # 管理员 class CommentTargetType(Enum): """评论目标类型""" - ENTITY = "entity" # 实体评论 - RELATION = "relation" # 关系评论 - TRANSCRIPT = "transcript" # 转录文本评论 - PROJECT = "project" # 项目级评论 + ENTITY = "entity" # 实体评论 + RELATION = "relation" # 关系评论 + TRANSCRIPT = "transcript" # 转录文本评论 + PROJECT = "project" # 项目级评论 class ChangeType(Enum): """变更类型""" - CREATE = "create" # 创建 - UPDATE = "update" # 更新 - DELETE = "delete" # 删除 - MERGE = "merge" # 合并 - SPLIT = "split" # 拆分 + CREATE = "create" # 创建 + UPDATE = "update" # 更新 + DELETE = "delete" # 删除 + MERGE = "merge" # 合并 + SPLIT = "split" # 拆分 @dataclass @@ -136,10 +136,10 @@ class TeamSpace: class CollaborationManager: """协作管理主类""" - def __init__(self, db_manager=None): - self.db = db_manager - self._shares_cache: dict[str, ProjectShare] = {} - self._comments_cache: dict[str, list[Comment]] = {} + def __init__(self, db_manager = None) -> None: + self.db = db_manager + self._shares_cache: dict[str, ProjectShare] = {} + self._comments_cache: dict[str, list[Comment]] = {} # ============ 项目分享 ============ @@ -147,57 +147,57 @@ class CollaborationManager: self, project_id: str, created_by: str, - permission: str = "read_only", - expires_in_days: int | None = None, - max_uses: int | None = None, - password: str | None = None, - allow_download: bool = False, - allow_export: bool = False, + permission: str = "read_only", + expires_in_days: int | None = None, + max_uses: int | None = None, + password: str | None = None, + allow_download: bool = False, + allow_export: bool = False, ) -> ProjectShare: """创建项目分享链接""" - share_id = str(uuid.uuid4()) - token = self._generate_share_token(project_id) + share_id = str(uuid.uuid4()) + token = self._generate_share_token(project_id) - now = datetime.now().isoformat() - expires_at = None + now = datetime.now().isoformat() + expires_at = None if expires_in_days: - expires_at = (datetime.now() + timedelta(days=expires_in_days)).isoformat() + expires_at = (datetime.now() + timedelta(days = expires_in_days)).isoformat() - password_hash = None + password_hash = None if password: - password_hash = hashlib.sha256(password.encode()).hexdigest() + password_hash = hashlib.sha256(password.encode()).hexdigest() - share = ProjectShare( - id=share_id, - project_id=project_id, - token=token, - permission=permission, - created_by=created_by, - created_at=now, - expires_at=expires_at, - max_uses=max_uses, - use_count=0, - password_hash=password_hash, - is_active=True, - allow_download=allow_download, - allow_export=allow_export, + share = ProjectShare( + id = share_id, + project_id = project_id, + token = token, + permission = permission, + created_by = created_by, + created_at = now, + expires_at = expires_at, + max_uses = max_uses, + use_count = 0, + password_hash = password_hash, + is_active = True, + allow_download = allow_download, + allow_export = allow_export, ) # 保存到数据库 if self.db: self._save_share_to_db(share) - self._shares_cache[token] = share + self._shares_cache[token] = share return share def _generate_share_token(self, project_id: str) -> str: """生成分享令牌""" - data = f"{project_id}:{datetime.now().timestamp()}:{uuid.uuid4()}" + data = f"{project_id}:{datetime.now().timestamp()}:{uuid.uuid4()}" return hashlib.sha256(data.encode()).hexdigest()[:32] def _save_share_to_db(self, share: ProjectShare) -> None: """保存分享记录到数据库""" - cursor = self.db.conn.cursor() + cursor = self.db.conn.cursor() cursor.execute( """ INSERT INTO project_shares @@ -224,12 +224,12 @@ class CollaborationManager: ) self.db.conn.commit() - def validate_share_token(self, token: str, password: str | None = None) -> ProjectShare | None: + def validate_share_token(self, token: str, password: str | None = None) -> ProjectShare | None: """验证分享令牌""" # 从缓存或数据库获取 - share = self._shares_cache.get(token) + share = self._shares_cache.get(token) if not share and self.db: - share = self._get_share_from_db(token) + share = self._get_share_from_db(token) if not share: return None @@ -250,7 +250,7 @@ class CollaborationManager: if share.password_hash: if not password: return None - password_hash = hashlib.sha256(password.encode()).hexdigest() + password_hash = hashlib.sha256(password.encode()).hexdigest() if password_hash != share.password_hash: return None @@ -258,63 +258,63 @@ class CollaborationManager: def _get_share_from_db(self, token: str) -> ProjectShare | None: """从数据库获取分享记录""" - cursor = self.db.conn.cursor() + cursor = self.db.conn.cursor() cursor.execute( """ - SELECT * FROM project_shares WHERE token = ? + SELECT * FROM project_shares WHERE token = ? """, - (token,), + (token, ), ) - row = cursor.fetchone() + row = cursor.fetchone() if not row: return None return ProjectShare( - id=row[0], - project_id=row[1], - token=row[2], - permission=row[3], - created_by=row[4], - created_at=row[5], - expires_at=row[6], - max_uses=row[7], - use_count=row[8], - password_hash=row[9], - is_active=bool(row[10]), - allow_download=bool(row[11]), - allow_export=bool(row[12]), + id = row[0], + project_id = row[1], + token = row[2], + permission = row[3], + created_by = row[4], + created_at = row[5], + expires_at = row[6], + max_uses = row[7], + use_count = row[8], + password_hash = row[9], + is_active = bool(row[10]), + allow_download = bool(row[11]), + allow_export = bool(row[12]), ) def increment_share_usage(self, token: str) -> None: """增加分享链接使用次数""" - share = self._shares_cache.get(token) + share = self._shares_cache.get(token) if share: share.use_count += 1 if self.db: - cursor = self.db.conn.cursor() + cursor = self.db.conn.cursor() cursor.execute( """ UPDATE project_shares - SET use_count = use_count + 1 - WHERE token = ? + SET use_count = use_count + 1 + WHERE token = ? """, - (token,), + (token, ), ) self.db.conn.commit() def revoke_share_link(self, share_id: str, revoked_by: str) -> bool: """撤销分享链接""" if self.db: - cursor = self.db.conn.cursor() + cursor = self.db.conn.cursor() cursor.execute( """ UPDATE project_shares - SET is_active = 0 - WHERE id = ? + SET is_active = 0 + WHERE id = ? """, - (share_id,), + (share_id, ), ) self.db.conn.commit() return cursor.rowcount > 0 @@ -325,33 +325,33 @@ class CollaborationManager: if not self.db: return [] - cursor = self.db.conn.cursor() + cursor = self.db.conn.cursor() cursor.execute( """ SELECT * FROM project_shares - WHERE project_id = ? + WHERE project_id = ? ORDER BY created_at DESC """, - (project_id,), + (project_id, ), ) - shares = [] + shares = [] for row in cursor.fetchall(): shares.append( ProjectShare( - id=row[0], - project_id=row[1], - token=row[2], - permission=row[3], - created_by=row[4], - created_at=row[5], - expires_at=row[6], - max_uses=row[7], - use_count=row[8], - password_hash=row[9], - is_active=bool(row[10]), - allow_download=bool(row[11]), - allow_export=bool(row[12]), + id = row[0], + project_id = row[1], + token = row[2], + permission = row[3], + created_by = row[4], + created_at = row[5], + expires_at = row[6], + max_uses = row[7], + use_count = row[8], + password_hash = row[9], + is_active = bool(row[10]), + allow_download = bool(row[11]), + allow_export = bool(row[12]), ) ) return shares @@ -366,46 +366,46 @@ class CollaborationManager: author: str, author_name: str, content: str, - parent_id: str | None = None, - mentions: list[str] | None = None, - attachments: list[dict] | None = None, + parent_id: str | None = None, + mentions: list[str] | None = None, + attachments: list[dict] | None = None, ) -> Comment: """添加评论""" - comment_id = str(uuid.uuid4()) - now = datetime.now().isoformat() + comment_id = str(uuid.uuid4()) + now = datetime.now().isoformat() - comment = Comment( - id=comment_id, - project_id=project_id, - target_type=target_type, - target_id=target_id, - parent_id=parent_id, - author=author, - author_name=author_name, - content=content, - created_at=now, - updated_at=now, - resolved=False, - resolved_by=None, - resolved_at=None, - mentions=mentions or [], - attachments=attachments or [], + comment = Comment( + id = comment_id, + project_id = project_id, + target_type = target_type, + target_id = target_id, + parent_id = parent_id, + author = author, + author_name = author_name, + content = content, + created_at = now, + updated_at = now, + resolved = False, + resolved_by = None, + resolved_at = None, + mentions = mentions or [], + attachments = attachments or [], ) if self.db: self._save_comment_to_db(comment) # 更新缓存 - key = f"{target_type}:{target_id}" + key = f"{target_type}:{target_id}" if key not in self._comments_cache: - self._comments_cache[key] = [] + self._comments_cache[key] = [] self._comments_cache[key].append(comment) return comment def _save_comment_to_db(self, comment: Comment) -> None: """保存评论到数据库""" - cursor = self.db.conn.cursor() + cursor = self.db.conn.cursor() cursor.execute( """ INSERT INTO comments @@ -435,18 +435,18 @@ class CollaborationManager: self.db.conn.commit() def get_comments( - self, target_type: str, target_id: str, include_resolved: bool = True + self, target_type: str, target_id: str, include_resolved: bool = True ) -> list[Comment]: """获取评论列表""" if not self.db: return [] - cursor = self.db.conn.cursor() + cursor = self.db.conn.cursor() if include_resolved: cursor.execute( """ SELECT * FROM comments - WHERE target_type = ? AND target_id = ? + WHERE target_type = ? AND target_id = ? ORDER BY created_at ASC """, (target_type, target_id), @@ -455,13 +455,13 @@ class CollaborationManager: cursor.execute( """ SELECT * FROM comments - WHERE target_type = ? AND target_id = ? AND resolved = 0 + WHERE target_type = ? AND target_id = ? AND resolved = 0 ORDER BY created_at ASC """, (target_type, target_id), ) - comments = [] + comments = [] for row in cursor.fetchall(): comments.append(self._row_to_comment(row)) return comments @@ -469,21 +469,21 @@ class CollaborationManager: def _row_to_comment(self, row) -> Comment: """将数据库行转换为Comment对象""" return Comment( - id=row[0], - project_id=row[1], - target_type=row[2], - target_id=row[3], - parent_id=row[4], - author=row[5], - author_name=row[6], - content=row[7], - created_at=row[8], - updated_at=row[9], - resolved=bool(row[10]), - resolved_by=row[11], - resolved_at=row[12], - mentions=json.loads(row[13]) if row[13] else [], - attachments=json.loads(row[14]) if row[14] else [], + id = row[0], + project_id = row[1], + target_type = row[2], + target_id = row[3], + parent_id = row[4], + author = row[5], + author_name = row[6], + content = row[7], + created_at = row[8], + updated_at = row[9], + resolved = bool(row[10]), + resolved_by = row[11], + resolved_at = row[12], + mentions = json.loads(row[13]) if row[13] else [], + attachments = json.loads(row[14]) if row[14] else [], ) def update_comment(self, comment_id: str, content: str, updated_by: str) -> Comment | None: @@ -491,13 +491,13 @@ class CollaborationManager: if not self.db: return None - now = datetime.now().isoformat() - cursor = self.db.conn.cursor() + now = datetime.now().isoformat() + cursor = self.db.conn.cursor() cursor.execute( """ UPDATE comments - SET content = ?, updated_at = ? - WHERE id = ? AND author = ? + SET content = ?, updated_at = ? + WHERE id = ? AND author = ? """, (content, now, comment_id, updated_by), ) @@ -509,9 +509,9 @@ class CollaborationManager: def _get_comment_by_id(self, comment_id: str) -> Comment | None: """根据ID获取评论""" - cursor = self.db.conn.cursor() - cursor.execute("SELECT * FROM comments WHERE id = ?", (comment_id,)) - row = cursor.fetchone() + cursor = self.db.conn.cursor() + cursor.execute("SELECT * FROM comments WHERE id = ?", (comment_id, )) + row = cursor.fetchone() if row: return self._row_to_comment(row) return None @@ -521,13 +521,13 @@ class CollaborationManager: if not self.db: return False - now = datetime.now().isoformat() - cursor = self.db.conn.cursor() + now = datetime.now().isoformat() + cursor = self.db.conn.cursor() cursor.execute( """ UPDATE comments - SET resolved = 1, resolved_by = ?, resolved_at = ? - WHERE id = ? + SET resolved = 1, resolved_by = ?, resolved_at = ? + WHERE id = ? """, (resolved_by, now, comment_id), ) @@ -539,13 +539,13 @@ class CollaborationManager: if not self.db: return False - cursor = self.db.conn.cursor() + cursor = self.db.conn.cursor() # 只允许作者或管理员删除 cursor.execute( """ DELETE FROM comments - WHERE id = ? AND (author = ? OR ? IN ( - SELECT created_by FROM projects WHERE id = comments.project_id + WHERE id = ? AND (author = ? OR ? IN ( + SELECT created_by FROM projects WHERE id = comments.project_id )) """, (comment_id, deleted_by, deleted_by), @@ -554,24 +554,24 @@ class CollaborationManager: return cursor.rowcount > 0 def get_project_comments( - self, project_id: str, limit: int = 50, offset: int = 0 + self, project_id: str, limit: int = 50, offset: int = 0 ) -> list[Comment]: """获取项目下的所有评论""" if not self.db: return [] - cursor = self.db.conn.cursor() + cursor = self.db.conn.cursor() cursor.execute( """ SELECT * FROM comments - WHERE project_id = ? + WHERE project_id = ? ORDER BY created_at DESC LIMIT ? OFFSET ? """, (project_id, limit, offset), ) - comments = [] + comments = [] for row in cursor.fetchall(): comments.append(self._row_to_comment(row)) return comments @@ -587,32 +587,32 @@ class CollaborationManager: entity_name: str, changed_by: str, changed_by_name: str, - old_value: dict | None = None, - new_value: dict | None = None, - description: str = "", - session_id: str | None = None, + old_value: dict | None = None, + new_value: dict | None = None, + description: str = "", + session_id: str | None = None, ) -> ChangeRecord: """记录变更""" - record_id = str(uuid.uuid4()) - now = datetime.now().isoformat() + record_id = str(uuid.uuid4()) + now = datetime.now().isoformat() - record = ChangeRecord( - id=record_id, - project_id=project_id, - change_type=change_type, - entity_type=entity_type, - entity_id=entity_id, - entity_name=entity_name, - changed_by=changed_by, - changed_by_name=changed_by_name, - changed_at=now, - old_value=old_value, - new_value=new_value, - description=description, - session_id=session_id, - reverted=False, - reverted_at=None, - reverted_by=None, + record = ChangeRecord( + id = record_id, + project_id = project_id, + change_type = change_type, + entity_type = entity_type, + entity_id = entity_id, + entity_name = entity_name, + changed_by = changed_by, + changed_by_name = changed_by_name, + changed_at = now, + old_value = old_value, + new_value = new_value, + description = description, + session_id = session_id, + reverted = False, + reverted_at = None, + reverted_by = None, ) if self.db: @@ -622,7 +622,7 @@ class CollaborationManager: def _save_change_to_db(self, record: ChangeRecord) -> None: """保存变更记录到数据库""" - cursor = self.db.conn.cursor() + cursor = self.db.conn.cursor() cursor.execute( """ INSERT INTO change_history @@ -655,22 +655,22 @@ class CollaborationManager: def get_change_history( self, project_id: str, - entity_type: str | None = None, - entity_id: str | None = None, - limit: int = 50, - offset: int = 0, + entity_type: str | None = None, + entity_id: str | None = None, + limit: int = 50, + offset: int = 0, ) -> list[ChangeRecord]: """获取变更历史""" if not self.db: return [] - cursor = self.db.conn.cursor() + cursor = self.db.conn.cursor() if entity_type and entity_id: cursor.execute( """ SELECT * FROM change_history - WHERE project_id = ? AND entity_type = ? AND entity_id = ? + WHERE project_id = ? AND entity_type = ? AND entity_id = ? ORDER BY changed_at DESC LIMIT ? OFFSET ? """, @@ -680,7 +680,7 @@ class CollaborationManager: cursor.execute( """ SELECT * FROM change_history - WHERE project_id = ? AND entity_type = ? + WHERE project_id = ? AND entity_type = ? ORDER BY changed_at DESC LIMIT ? OFFSET ? """, @@ -690,14 +690,14 @@ class CollaborationManager: cursor.execute( """ SELECT * FROM change_history - WHERE project_id = ? + WHERE project_id = ? ORDER BY changed_at DESC LIMIT ? OFFSET ? """, (project_id, limit, offset), ) - records = [] + records = [] for row in cursor.fetchall(): records.append(self._row_to_change_record(row)) return records @@ -705,22 +705,22 @@ class CollaborationManager: def _row_to_change_record(self, row) -> ChangeRecord: """将数据库行转换为ChangeRecord对象""" return ChangeRecord( - id=row[0], - project_id=row[1], - change_type=row[2], - entity_type=row[3], - entity_id=row[4], - entity_name=row[5], - changed_by=row[6], - changed_by_name=row[7], - changed_at=row[8], - old_value=json.loads(row[9]) if row[9] else None, - new_value=json.loads(row[10]) if row[10] else None, - description=row[11], - session_id=row[12], - reverted=bool(row[13]), - reverted_at=row[14], - reverted_by=row[15], + id = row[0], + project_id = row[1], + change_type = row[2], + entity_type = row[3], + entity_id = row[4], + entity_name = row[5], + changed_by = row[6], + changed_by_name = row[7], + changed_at = row[8], + old_value = json.loads(row[9]) if row[9] else None, + new_value = json.loads(row[10]) if row[10] else None, + description = row[11], + session_id = row[12], + reverted = bool(row[13]), + reverted_at = row[14], + reverted_by = row[15], ) def get_entity_version_history(self, entity_type: str, entity_id: str) -> list[ChangeRecord]: @@ -728,17 +728,17 @@ class CollaborationManager: if not self.db: return [] - cursor = self.db.conn.cursor() + cursor = self.db.conn.cursor() cursor.execute( """ SELECT * FROM change_history - WHERE entity_type = ? AND entity_id = ? + WHERE entity_type = ? AND entity_id = ? ORDER BY changed_at ASC """, (entity_type, entity_id), ) - records = [] + records = [] for row in cursor.fetchall(): records.append(self._row_to_change_record(row)) return records @@ -748,13 +748,13 @@ class CollaborationManager: if not self.db: return False - now = datetime.now().isoformat() - cursor = self.db.conn.cursor() + now = datetime.now().isoformat() + cursor = self.db.conn.cursor() cursor.execute( """ UPDATE change_history - SET reverted = 1, reverted_at = ?, reverted_by = ? - WHERE id = ? AND reverted = 0 + SET reverted = 1, reverted_at = ?, reverted_by = ? + WHERE id = ? AND reverted = 0 """, (now, reverted_by, record_id), ) @@ -766,49 +766,49 @@ class CollaborationManager: if not self.db: return {} - cursor = self.db.conn.cursor() + cursor = self.db.conn.cursor() # 总变更数 cursor.execute( """ - SELECT COUNT(*) FROM change_history WHERE project_id = ? + SELECT COUNT(*) FROM change_history WHERE project_id = ? """, - (project_id,), + (project_id, ), ) - total_changes = cursor.fetchone()[0] + total_changes = cursor.fetchone()[0] # 按类型统计 cursor.execute( """ SELECT change_type, COUNT(*) FROM change_history - WHERE project_id = ? GROUP BY change_type + WHERE project_id = ? GROUP BY change_type """, - (project_id,), + (project_id, ), ) - type_counts = {row[0]: row[1] for row in cursor.fetchall()} + type_counts = {row[0]: row[1] for row in cursor.fetchall()} # 按实体类型统计 cursor.execute( """ SELECT entity_type, COUNT(*) FROM change_history - WHERE project_id = ? GROUP BY entity_type + WHERE project_id = ? GROUP BY entity_type """, - (project_id,), + (project_id, ), ) - entity_type_counts = {row[0]: row[1] for row in cursor.fetchall()} + entity_type_counts = {row[0]: row[1] for row in cursor.fetchall()} # 最近活跃的用户 cursor.execute( """ SELECT changed_by_name, COUNT(*) as count FROM change_history - WHERE project_id = ? + WHERE project_id = ? GROUP BY changed_by_name ORDER BY count DESC LIMIT 5 """, - (project_id,), + (project_id, ), ) - top_contributors = [{"name": row[0], "changes": row[1]} for row in cursor.fetchall()] + top_contributors = [{"name": row[0], "changes": row[1]} for row in cursor.fetchall()] return { "total_changes": total_changes, @@ -827,27 +827,27 @@ class CollaborationManager: user_email: str, role: str, invited_by: str, - permissions: list[str] | None = None, + permissions: list[str] | None = None, ) -> TeamMember: """添加团队成员""" - member_id = str(uuid.uuid4()) - now = datetime.now().isoformat() + member_id = str(uuid.uuid4()) + now = datetime.now().isoformat() # 根据角色设置默认权限 if permissions is None: - permissions = self._get_default_permissions(role) + permissions = self._get_default_permissions(role) - member = TeamMember( - id=member_id, - project_id=project_id, - user_id=user_id, - user_name=user_name, - user_email=user_email, - role=role, - joined_at=now, - invited_by=invited_by, - last_active_at=None, - permissions=permissions, + member = TeamMember( + id = member_id, + project_id = project_id, + user_id = user_id, + user_name = user_name, + user_email = user_email, + role = role, + joined_at = now, + invited_by = invited_by, + last_active_at = None, + permissions = permissions, ) if self.db: @@ -857,7 +857,7 @@ class CollaborationManager: def _get_default_permissions(self, role: str) -> list[str]: """获取角色的默认权限""" - permissions_map = { + permissions_map = { "owner": ["read", "write", "delete", "share", "admin", "export"], "admin": ["read", "write", "delete", "share", "export"], "editor": ["read", "write", "export"], @@ -868,7 +868,7 @@ class CollaborationManager: def _save_member_to_db(self, member: TeamMember) -> None: """保存成员到数据库""" - cursor = self.db.conn.cursor() + cursor = self.db.conn.cursor() cursor.execute( """ INSERT INTO team_members @@ -896,16 +896,16 @@ class CollaborationManager: if not self.db: return [] - cursor = self.db.conn.cursor() + cursor = self.db.conn.cursor() cursor.execute( """ - SELECT * FROM team_members WHERE project_id = ? + SELECT * FROM team_members WHERE project_id = ? ORDER BY joined_at ASC """, - (project_id,), + (project_id, ), ) - members = [] + members = [] for row in cursor.fetchall(): members.append(self._row_to_team_member(row)) return members @@ -913,16 +913,16 @@ class CollaborationManager: def _row_to_team_member(self, row) -> TeamMember: """将数据库行转换为TeamMember对象""" return TeamMember( - id=row[0], - project_id=row[1], - user_id=row[2], - user_name=row[3], - user_email=row[4], - role=row[5], - joined_at=row[6], - invited_by=row[7], - last_active_at=row[8], - permissions=json.loads(row[9]) if row[9] else [], + id = row[0], + project_id = row[1], + user_id = row[2], + user_name = row[3], + user_email = row[4], + role = row[5], + joined_at = row[6], + invited_by = row[7], + last_active_at = row[8], + permissions = json.loads(row[9]) if row[9] else [], ) def update_member_role(self, member_id: str, new_role: str, updated_by: str) -> bool: @@ -930,13 +930,13 @@ class CollaborationManager: if not self.db: return False - permissions = self._get_default_permissions(new_role) - cursor = self.db.conn.cursor() + permissions = self._get_default_permissions(new_role) + cursor = self.db.conn.cursor() cursor.execute( """ UPDATE team_members - SET role = ?, permissions = ? - WHERE id = ? + SET role = ?, permissions = ? + WHERE id = ? """, (new_role, json.dumps(permissions), member_id), ) @@ -948,8 +948,8 @@ class CollaborationManager: if not self.db: return False - cursor = self.db.conn.cursor() - cursor.execute("DELETE FROM team_members WHERE id = ?", (member_id,)) + cursor = self.db.conn.cursor() + cursor.execute("DELETE FROM team_members WHERE id = ?", (member_id, )) self.db.conn.commit() return cursor.rowcount > 0 @@ -958,20 +958,20 @@ class CollaborationManager: if not self.db: return False - cursor = self.db.conn.cursor() + cursor = self.db.conn.cursor() cursor.execute( """ SELECT permissions FROM team_members - WHERE project_id = ? AND user_id = ? + WHERE project_id = ? AND user_id = ? """, (project_id, user_id), ) - row = cursor.fetchone() + row = cursor.fetchone() if not row: return False - permissions = json.loads(row[0]) if row[0] else [] + permissions = json.loads(row[0]) if row[0] else [] return permission in permissions or "admin" in permissions def update_last_active(self, project_id: str, user_id: str) -> None: @@ -979,13 +979,13 @@ class CollaborationManager: if not self.db: return - now = datetime.now().isoformat() - cursor = self.db.conn.cursor() + now = datetime.now().isoformat() + cursor = self.db.conn.cursor() cursor.execute( """ UPDATE team_members - SET last_active_at = ? - WHERE project_id = ? AND user_id = ? + SET last_active_at = ? + WHERE project_id = ? AND user_id = ? """, (now, project_id, user_id), ) @@ -993,12 +993,12 @@ class CollaborationManager: # 全局协作管理器实例 -_collaboration_manager = None +_collaboration_manager = None -def get_collaboration_manager(db_manager=None) -> None: +def get_collaboration_manager(db_manager = None) -> None: """获取协作管理器单例""" global _collaboration_manager if _collaboration_manager is None: - _collaboration_manager = CollaborationManager(db_manager) + _collaboration_manager = CollaborationManager(db_manager) return _collaboration_manager diff --git a/backend/db_manager.py b/backend/db_manager.py index 763bd14..37c1e0f 100644 --- a/backend/db_manager.py +++ b/backend/db_manager.py @@ -12,19 +12,19 @@ import uuid from dataclasses import dataclass from datetime import datetime -DB_PATH = os.getenv("DB_PATH", "/app/data/insightflow.db") +DB_PATH = os.getenv("DB_PATH", "/app/data/insightflow.db") # Constants -UUID_LENGTH = 8 # UUID 截断长度 +UUID_LENGTH = 8 # UUID 截断长度 @dataclass class Project: id: str name: str - description: str = "" - created_at: str = "" - updated_at: str = "" + description: str = "" + created_at: str = "" + updated_at: str = "" @dataclass @@ -33,19 +33,19 @@ class Entity: project_id: str name: str type: str - definition: str = "" - canonical_name: str = "" - aliases: list[str] = None - embedding: str = "" # Phase 3: 实体嵌入向量 - attributes: dict = None # Phase 5: 实体属性 - created_at: str = "" - updated_at: str = "" + definition: str = "" + canonical_name: str = "" + aliases: list[str] = None + embedding: str = "" # Phase 3: 实体嵌入向量 + attributes: dict = None # Phase 5: 实体属性 + created_at: str = "" + updated_at: str = "" - def __post_init__(self): + def __post_init__(self) -> None: if self.aliases is None: - self.aliases = [] + self.aliases = [] if self.attributes is None: - self.attributes = {} + self.attributes = {} @dataclass @@ -56,17 +56,17 @@ class AttributeTemplate: project_id: str name: str type: str # text, number, date, select, multiselect, boolean - options: list[str] = None # 用于 select/multiselect - default_value: str = "" - description: str = "" - is_required: bool = False - sort_order: int = 0 - created_at: str = "" - updated_at: str = "" + options: list[str] = None # 用于 select/multiselect + default_value: str = "" + description: str = "" + is_required: bool = False + sort_order: int = 0 + created_at: str = "" + updated_at: str = "" - def __post_init__(self): + def __post_init__(self) -> None: if self.options is None: - self.options = [] + self.options = [] @dataclass @@ -75,19 +75,19 @@ class EntityAttribute: id: str entity_id: str - template_id: str | None = None - name: str = "" # 属性名称 - type: str = "text" # 属性类型 - value: str = "" - options: list[str] = None # 选项列表 - template_name: str = "" # 关联查询时填充 - template_type: str = "" # 关联查询时填充 - created_at: str = "" - updated_at: str = "" + template_id: str | None = None + name: str = "" # 属性名称 + type: str = "text" # 属性类型 + value: str = "" + options: list[str] = None # 选项列表 + template_name: str = "" # 关联查询时填充 + template_type: str = "" # 关联查询时填充 + created_at: str = "" + updated_at: str = "" - def __post_init__(self): + def __post_init__(self) -> None: if self.options is None: - self.options = [] + self.options = [] @dataclass @@ -96,12 +96,12 @@ class AttributeHistory: id: str entity_id: str - attribute_name: str = "" # 属性名称 - old_value: str = "" - new_value: str = "" - changed_by: str = "" - changed_at: str = "" - change_reason: str = "" + attribute_name: str = "" # 属性名称 + old_value: str = "" + new_value: str = "" + changed_by: str = "" + changed_at: str = "" + change_reason: str = "" @dataclass @@ -112,35 +112,35 @@ class EntityMention: start_pos: int end_pos: int text_snippet: str - confidence: float = 1.0 + confidence: float = 1.0 class DatabaseManager: - def __init__(self, db_path: str = DB_PATH): - self.db_path = db_path - os.makedirs(os.path.dirname(db_path), exist_ok=True) + def __init__(self, db_path: str = DB_PATH) -> None: + self.db_path = db_path + os.makedirs(os.path.dirname(db_path), exist_ok = True) self.init_db() - def get_conn(self): - conn = sqlite3.connect(self.db_path) - conn.row_factory = sqlite3.Row + def get_conn(self) -> None: + conn = sqlite3.connect(self.db_path) + conn.row_factory = sqlite3.Row return conn def init_db(self) -> None: """初始化数据库表""" with open(os.path.join(os.path.dirname(__file__), "schema.sql")) as f: - schema = f.read() + schema = f.read() - conn = self.get_conn() + conn = self.get_conn() conn.executescript(schema) conn.commit() conn.close() # ==================== Project Operations ==================== - def create_project(self, project_id: str, name: str, description: str = "") -> Project: - conn = self.get_conn() - now = datetime.now().isoformat() + def create_project(self, project_id: str, name: str, description: str = "") -> Project: + conn = self.get_conn() + now = datetime.now().isoformat() conn.execute( """INSERT INTO projects (id, name, description, created_at, updated_at) VALUES (?, ?, ?, ?, ?)""", @@ -149,27 +149,27 @@ class DatabaseManager: conn.commit() conn.close() return Project( - id=project_id, name=name, description=description, created_at=now, updated_at=now + id = project_id, name = name, description = description, created_at = now, updated_at = now ) def get_project(self, project_id: str) -> Project | None: - conn = self.get_conn() - row = conn.execute("SELECT * FROM projects WHERE id = ?", (project_id,)).fetchone() + conn = self.get_conn() + row = conn.execute("SELECT * FROM projects WHERE id = ?", (project_id, )).fetchone() conn.close() if row: return Project(**dict(row)) return None def list_projects(self) -> list[Project]: - conn = self.get_conn() - rows = conn.execute("SELECT * FROM projects ORDER BY updated_at DESC").fetchall() + conn = self.get_conn() + rows = conn.execute("SELECT * FROM projects ORDER BY updated_at DESC").fetchall() conn.close() return [Project(**dict(r)) for r in rows] # ==================== Entity Operations ==================== def create_entity(self, entity: Entity) -> Entity: - conn = self.get_conn() + conn = self.get_conn() conn.execute( """INSERT INTO entities (id, project_id, name, canonical_name, type, definition, aliases, created_at, updated_at) @@ -192,122 +192,122 @@ class DatabaseManager: def get_entity_by_name(self, project_id: str, name: str) -> Entity | None: """通过名称查找实体(用于对齐)""" - conn = self.get_conn() - row = conn.execute( - """SELECT * FROM entities WHERE project_id = ? - AND (name = ? OR canonical_name = ? OR aliases LIKE ?)""", + conn = self.get_conn() + row = conn.execute( + """SELECT * FROM entities WHERE project_id = ? + AND (name = ? OR canonical_name = ? OR aliases LIKE ?)""", (project_id, name, name, f'%"{name}"%'), ).fetchone() conn.close() if row: - data = dict(row) - data["aliases"] = json.loads(data["aliases"]) if data["aliases"] else [] + data = dict(row) + data["aliases"] = json.loads(data["aliases"]) if data["aliases"] else [] return Entity(**data) return None def find_similar_entities( - self, project_id: str, name: str, threshold: float = 0.8 + self, project_id: str, name: str, threshold: float = 0.8 ) -> list[Entity]: """查找相似实体""" - conn = self.get_conn() - rows = conn.execute( - "SELECT * FROM entities WHERE project_id = ? AND name LIKE ?", (project_id, f"%{name}%") + conn = self.get_conn() + rows = conn.execute( + "SELECT * FROM entities WHERE project_id = ? AND name LIKE ?", (project_id, f"%{name}%") ).fetchall() conn.close() - entities = [] + entities = [] for row in rows: - data = dict(row) - data["aliases"] = json.loads(data["aliases"]) if data["aliases"] else [] + data = dict(row) + data["aliases"] = json.loads(data["aliases"]) if data["aliases"] else [] entities.append(Entity(**data)) return entities def merge_entities(self, target_id: str, source_id: str) -> Entity: """合并两个实体""" - conn = self.get_conn() + conn = self.get_conn() - target = conn.execute("SELECT * FROM entities WHERE id = ?", (target_id,)).fetchone() - source = conn.execute("SELECT * FROM entities WHERE id = ?", (source_id,)).fetchone() + target = conn.execute("SELECT * FROM entities WHERE id = ?", (target_id, )).fetchone() + source = conn.execute("SELECT * FROM entities WHERE id = ?", (source_id, )).fetchone() if not target or not source: conn.close() raise ValueError("Entity not found") - target_aliases = set(json.loads(target["aliases"]) if target["aliases"] else []) + target_aliases = set(json.loads(target["aliases"]) if target["aliases"] else []) target_aliases.add(source["name"]) target_aliases.update(json.loads(source["aliases"]) if source["aliases"] else []) conn.execute( - "UPDATE entities SET aliases = ?, updated_at = ? WHERE id = ?", + "UPDATE entities SET aliases = ?, updated_at = ? WHERE id = ?", (json.dumps(list(target_aliases)), datetime.now().isoformat(), target_id), ) conn.execute( - "UPDATE entity_mentions SET entity_id = ? WHERE entity_id = ?", (target_id, source_id) + "UPDATE entity_mentions SET entity_id = ? WHERE entity_id = ?", (target_id, source_id) ) conn.execute( - "UPDATE entity_relations SET source_entity_id = ? WHERE source_entity_id = ?", + "UPDATE entity_relations SET source_entity_id = ? WHERE source_entity_id = ?", (target_id, source_id), ) conn.execute( - "UPDATE entity_relations SET target_entity_id = ? WHERE target_entity_id = ?", + "UPDATE entity_relations SET target_entity_id = ? WHERE target_entity_id = ?", (target_id, source_id), ) - conn.execute("DELETE FROM entities WHERE id = ?", (source_id,)) + conn.execute("DELETE FROM entities WHERE id = ?", (source_id, )) conn.commit() conn.close() return self.get_entity(target_id) def get_entity(self, entity_id: str) -> Entity | None: - conn = self.get_conn() - row = conn.execute("SELECT * FROM entities WHERE id = ?", (entity_id,)).fetchone() + conn = self.get_conn() + row = conn.execute("SELECT * FROM entities WHERE id = ?", (entity_id, )).fetchone() conn.close() if row: - data = dict(row) - data["aliases"] = json.loads(data["aliases"]) if data["aliases"] else [] + data = dict(row) + data["aliases"] = json.loads(data["aliases"]) if data["aliases"] else [] return Entity(**data) return None def list_project_entities(self, project_id: str) -> list[Entity]: - conn = self.get_conn() - rows = conn.execute( - "SELECT * FROM entities WHERE project_id = ? ORDER BY updated_at DESC", (project_id,) + conn = self.get_conn() + rows = conn.execute( + "SELECT * FROM entities WHERE project_id = ? ORDER BY updated_at DESC", (project_id, ) ).fetchall() conn.close() - entities = [] + entities = [] for row in rows: - data = dict(row) - data["aliases"] = json.loads(data["aliases"]) if data["aliases"] else [] + data = dict(row) + data["aliases"] = json.loads(data["aliases"]) if data["aliases"] else [] entities.append(Entity(**data)) return entities def update_entity(self, entity_id: str, **kwargs) -> Entity: """更新实体信息""" - conn = self.get_conn() + conn = self.get_conn() - allowed_fields = ["name", "type", "definition", "canonical_name"] - updates = [] - values = [] + allowed_fields = ["name", "type", "definition", "canonical_name"] + updates = [] + values = [] for field in allowed_fields: if field in kwargs: - updates.append(f"{field} = ?") + updates.append(f"{field} = ?") values.append(kwargs[field]) if "aliases" in kwargs: - updates.append("aliases = ?") + updates.append("aliases = ?") values.append(json.dumps(kwargs["aliases"])) if not updates: conn.close() return self.get_entity(entity_id) - updates.append("updated_at = ?") + updates.append("updated_at = ?") values.append(datetime.now().isoformat()) values.append(entity_id) - query = f"UPDATE entities SET {', '.join(updates)} WHERE id = ?" + query = f"UPDATE entities SET {', '.join(updates)} WHERE id = ?" conn.execute(query, values) conn.commit() conn.close() @@ -315,21 +315,21 @@ class DatabaseManager: def delete_entity(self, entity_id: str) -> None: """删除实体及其关联数据""" - conn = self.get_conn() - conn.execute("DELETE FROM entity_mentions WHERE entity_id = ?", (entity_id,)) + conn = self.get_conn() + conn.execute("DELETE FROM entity_mentions WHERE entity_id = ?", (entity_id, )) conn.execute( - "DELETE FROM entity_relations WHERE source_entity_id = ? OR target_entity_id = ?", + "DELETE FROM entity_relations WHERE source_entity_id = ? OR target_entity_id = ?", (entity_id, entity_id), ) - conn.execute("DELETE FROM entity_attributes WHERE entity_id = ?", (entity_id,)) - conn.execute("DELETE FROM entities WHERE id = ?", (entity_id,)) + conn.execute("DELETE FROM entity_attributes WHERE entity_id = ?", (entity_id, )) + conn.execute("DELETE FROM entities WHERE id = ?", (entity_id, )) conn.commit() conn.close() # ==================== Mention Operations ==================== def add_mention(self, mention: EntityMention) -> EntityMention: - conn = self.get_conn() + conn = self.get_conn() conn.execute( """INSERT INTO entity_mentions (id, entity_id, transcript_id, start_pos, end_pos, text_snippet, confidence) @@ -349,10 +349,10 @@ class DatabaseManager: return mention def get_entity_mentions(self, entity_id: str) -> list[EntityMention]: - conn = self.get_conn() - rows = conn.execute( - "SELECT * FROM entity_mentions WHERE entity_id = ? ORDER BY transcript_id, start_pos", - (entity_id,), + conn = self.get_conn() + rows = conn.execute( + "SELECT * FROM entity_mentions WHERE entity_id = ? ORDER BY transcript_id, start_pos", + (entity_id, ), ).fetchall() conn.close() return [EntityMention(**dict(r)) for r in rows] @@ -365,10 +365,10 @@ class DatabaseManager: project_id: str, filename: str, full_text: str, - transcript_type: str = "audio", - ): - conn = self.get_conn() - now = datetime.now().isoformat() + transcript_type: str = "audio", + ) -> None: + conn = self.get_conn() + now = datetime.now().isoformat() conn.execute( """INSERT INTO transcripts (id, project_id, filename, full_text, type, created_at) @@ -379,28 +379,28 @@ class DatabaseManager: conn.close() def get_transcript(self, transcript_id: str) -> dict | None: - conn = self.get_conn() - row = conn.execute("SELECT * FROM transcripts WHERE id = ?", (transcript_id,)).fetchone() + conn = self.get_conn() + row = conn.execute("SELECT * FROM transcripts WHERE id = ?", (transcript_id, )).fetchone() conn.close() return dict(row) if row else None def list_project_transcripts(self, project_id: str) -> list[dict]: - conn = self.get_conn() - rows = conn.execute( - "SELECT * FROM transcripts WHERE project_id = ? ORDER BY created_at DESC", (project_id,) + conn = self.get_conn() + rows = conn.execute( + "SELECT * FROM transcripts WHERE project_id = ? ORDER BY created_at DESC", (project_id, ) ).fetchall() conn.close() return [dict(r) for r in rows] def update_transcript(self, transcript_id: str, full_text: str) -> dict: - conn = self.get_conn() - now = datetime.now().isoformat() + conn = self.get_conn() + now = datetime.now().isoformat() conn.execute( - "UPDATE transcripts SET full_text = ?, updated_at = ? WHERE id = ?", + "UPDATE transcripts SET full_text = ?, updated_at = ? WHERE id = ?", (full_text, now, transcript_id), ) conn.commit() - row = conn.execute("SELECT * FROM transcripts WHERE id = ?", (transcript_id,)).fetchone() + row = conn.execute("SELECT * FROM transcripts WHERE id = ?", (transcript_id, )).fetchone() conn.close() return dict(row) if row else None @@ -411,13 +411,13 @@ class DatabaseManager: project_id: str, source_entity_id: str, target_entity_id: str, - relation_type: str = "related", - evidence: str = "", - transcript_id: str = "", - ): - conn = self.get_conn() - relation_id = str(uuid.uuid4())[:UUID_LENGTH] - now = datetime.now().isoformat() + relation_type: str = "related", + evidence: str = "", + transcript_id: str = "", + ) -> None: + conn = self.get_conn() + relation_id = str(uuid.uuid4())[:UUID_LENGTH] + now = datetime.now().isoformat() conn.execute( """INSERT INTO entity_relations (id, project_id, source_entity_id, target_entity_id, relation_type, @@ -439,10 +439,10 @@ class DatabaseManager: return relation_id def get_entity_relations(self, entity_id: str) -> list[dict]: - conn = self.get_conn() - rows = conn.execute( + conn = self.get_conn() + rows = conn.execute( """SELECT * FROM entity_relations - WHERE source_entity_id = ? OR target_entity_id = ? + WHERE source_entity_id = ? OR target_entity_id = ? ORDER BY created_at DESC""", (entity_id, entity_id), ).fetchall() @@ -450,58 +450,58 @@ class DatabaseManager: return [dict(r) for r in rows] def list_project_relations(self, project_id: str) -> list[dict]: - conn = self.get_conn() - rows = conn.execute( - "SELECT * FROM entity_relations WHERE project_id = ? ORDER BY created_at DESC", - (project_id,), + conn = self.get_conn() + rows = conn.execute( + "SELECT * FROM entity_relations WHERE project_id = ? ORDER BY created_at DESC", + (project_id, ), ).fetchall() conn.close() return [dict(r) for r in rows] def update_relation(self, relation_id: str, **kwargs) -> dict: - conn = self.get_conn() - allowed_fields = ["relation_type", "evidence"] - updates = [] - values = [] + conn = self.get_conn() + allowed_fields = ["relation_type", "evidence"] + updates = [] + values = [] for field in allowed_fields: if field in kwargs: - updates.append(f"{field} = ?") + updates.append(f"{field} = ?") values.append(kwargs[field]) if updates: - query = f"UPDATE entity_relations SET {', '.join(updates)} WHERE id = ?" + query = f"UPDATE entity_relations SET {', '.join(updates)} WHERE id = ?" values.append(relation_id) conn.execute(query, values) conn.commit() - row = conn.execute("SELECT * FROM entity_relations WHERE id = ?", (relation_id,)).fetchone() + row = conn.execute("SELECT * FROM entity_relations WHERE id = ?", (relation_id, )).fetchone() conn.close() return dict(row) if row else None - def delete_relation(self, relation_id: str): - conn = self.get_conn() - conn.execute("DELETE FROM entity_relations WHERE id = ?", (relation_id,)) + def delete_relation(self, relation_id: str) -> None: + conn = self.get_conn() + conn.execute("DELETE FROM entity_relations WHERE id = ?", (relation_id, )) conn.commit() conn.close() # ==================== Glossary Operations ==================== - def add_glossary_term(self, project_id: str, term: str, pronunciation: str = "") -> str: - conn = self.get_conn() - existing = conn.execute( - "SELECT * FROM glossary WHERE project_id = ? AND term = ?", (project_id, term) + def add_glossary_term(self, project_id: str, term: str, pronunciation: str = "") -> str: + conn = self.get_conn() + existing = conn.execute( + "SELECT * FROM glossary WHERE project_id = ? AND term = ?", (project_id, term) ).fetchone() if existing: conn.execute( - "UPDATE glossary SET frequency = frequency + 1 WHERE id = ?", (existing["id"],) + "UPDATE glossary SET frequency = frequency + 1 WHERE id = ?", (existing["id"], ) ) conn.commit() conn.close() return existing["id"] - term_id = str(uuid.uuid4())[:UUID_LENGTH] + term_id = str(uuid.uuid4())[:UUID_LENGTH] conn.execute( """INSERT INTO glossary (id, project_id, term, pronunciation, frequency) @@ -513,118 +513,118 @@ class DatabaseManager: return term_id def list_glossary(self, project_id: str) -> list[dict]: - conn = self.get_conn() - rows = conn.execute( - "SELECT * FROM glossary WHERE project_id = ? ORDER BY frequency DESC", (project_id,) + conn = self.get_conn() + rows = conn.execute( + "SELECT * FROM glossary WHERE project_id = ? ORDER BY frequency DESC", (project_id, ) ).fetchall() conn.close() return [dict(r) for r in rows] - def delete_glossary_term(self, term_id: str): - conn = self.get_conn() - conn.execute("DELETE FROM glossary WHERE id = ?", (term_id,)) + def delete_glossary_term(self, term_id: str) -> None: + conn = self.get_conn() + conn.execute("DELETE FROM glossary WHERE id = ?", (term_id, )) conn.commit() conn.close() # ==================== Phase 4: Agent & Provenance ==================== def get_relation_with_details(self, relation_id: str) -> dict | None: - conn = self.get_conn() - row = conn.execute( + conn = self.get_conn() + row = conn.execute( """SELECT r.*, s.name as source_name, t.name as target_name, tr.filename as transcript_filename, tr.full_text as transcript_text FROM entity_relations r - JOIN entities s ON r.source_entity_id = s.id - JOIN entities t ON r.target_entity_id = t.id - LEFT JOIN transcripts tr ON r.transcript_id = tr.id - WHERE r.id = ?""", - (relation_id,), + JOIN entities s ON r.source_entity_id = s.id + JOIN entities t ON r.target_entity_id = t.id + LEFT JOIN transcripts tr ON r.transcript_id = tr.id + WHERE r.id = ?""", + (relation_id, ), ).fetchone() conn.close() return dict(row) if row else None def get_entity_with_mentions(self, entity_id: str) -> dict | None: - conn = self.get_conn() - entity_row = conn.execute("SELECT * FROM entities WHERE id = ?", (entity_id,)).fetchone() + conn = self.get_conn() + entity_row = conn.execute("SELECT * FROM entities WHERE id = ?", (entity_id, )).fetchone() if not entity_row: conn.close() return None - entity = dict(entity_row) - entity["aliases"] = json.loads(entity["aliases"]) if entity["aliases"] else [] + entity = dict(entity_row) + entity["aliases"] = json.loads(entity["aliases"]) if entity["aliases"] else [] - mentions = conn.execute( + mentions = conn.execute( """SELECT m.*, t.filename, t.created_at as transcript_date FROM entity_mentions m - JOIN transcripts t ON m.transcript_id = t.id - WHERE m.entity_id = ? ORDER BY t.created_at, m.start_pos""", - (entity_id,), + JOIN transcripts t ON m.transcript_id = t.id + WHERE m.entity_id = ? ORDER BY t.created_at, m.start_pos""", + (entity_id, ), ).fetchall() - entity["mentions"] = [dict(m) for m in mentions] - entity["mention_count"] = len(mentions) + entity["mentions"] = [dict(m) for m in mentions] + entity["mention_count"] = len(mentions) - relations = conn.execute( + relations = conn.execute( """SELECT r.*, s.name as source_name, t.name as target_name FROM entity_relations r - JOIN entities s ON r.source_entity_id = s.id - JOIN entities t ON r.target_entity_id = t.id - WHERE r.source_entity_id = ? OR r.target_entity_id = ? + JOIN entities s ON r.source_entity_id = s.id + JOIN entities t ON r.target_entity_id = t.id + WHERE r.source_entity_id = ? OR r.target_entity_id = ? ORDER BY r.created_at DESC""", (entity_id, entity_id), ).fetchall() - entity["relations"] = [dict(r) for r in relations] + entity["relations"] = [dict(r) for r in relations] conn.close() return entity def search_entities(self, project_id: str, query: str) -> list[Entity]: - conn = self.get_conn() - rows = conn.execute( + conn = self.get_conn() + rows = conn.execute( """SELECT * FROM entities - WHERE project_id = ? AND + WHERE project_id = ? AND (name LIKE ? OR definition LIKE ? OR aliases LIKE ?) ORDER BY name""", (project_id, f"%{query}%", f"%{query}%", f"%{query}%"), ).fetchall() conn.close() - entities = [] + entities = [] for row in rows: - data = dict(row) - data["aliases"] = json.loads(data["aliases"]) if data["aliases"] else [] + data = dict(row) + data["aliases"] = json.loads(data["aliases"]) if data["aliases"] else [] entities.append(Entity(**data)) return entities def get_project_summary(self, project_id: str) -> dict: - conn = self.get_conn() - project = conn.execute("SELECT * FROM projects WHERE id = ?", (project_id,)).fetchone() + conn = self.get_conn() + project = conn.execute("SELECT * FROM projects WHERE id = ?", (project_id, )).fetchone() - entity_count = conn.execute( - "SELECT COUNT(*) as count FROM entities WHERE project_id = ?", (project_id,) + entity_count = conn.execute( + "SELECT COUNT(*) as count FROM entities WHERE project_id = ?", (project_id, ) ).fetchone()["count"] - transcript_count = conn.execute( - "SELECT COUNT(*) as count FROM transcripts WHERE project_id = ?", (project_id,) + transcript_count = conn.execute( + "SELECT COUNT(*) as count FROM transcripts WHERE project_id = ?", (project_id, ) ).fetchone()["count"] - relation_count = conn.execute( - "SELECT COUNT(*) as count FROM entity_relations WHERE project_id = ?", (project_id,) + relation_count = conn.execute( + "SELECT COUNT(*) as count FROM entity_relations WHERE project_id = ?", (project_id, ) ).fetchone()["count"] - recent_transcripts = conn.execute( + recent_transcripts = conn.execute( """SELECT filename, full_text, created_at FROM transcripts - WHERE project_id = ? ORDER BY created_at DESC LIMIT 5""", - (project_id,), + WHERE project_id = ? ORDER BY created_at DESC LIMIT 5""", + (project_id, ), ).fetchall() - top_entities = conn.execute( + top_entities = conn.execute( """SELECT e.name, e.type, e.definition, COUNT(m.id) as mention_count FROM entities e - LEFT JOIN entity_mentions m ON e.id = m.entity_id - WHERE e.project_id = ? + LEFT JOIN entity_mentions m ON e.id = m.entity_id + WHERE e.project_id = ? GROUP BY e.id ORDER BY mention_count DESC LIMIT 10""", - (project_id,), + (project_id, ), ).fetchall() conn.close() @@ -641,32 +641,32 @@ class DatabaseManager: } def get_transcript_context( - self, transcript_id: str, position: int, context_chars: int = 200 + self, transcript_id: str, position: int, context_chars: int = 200 ) -> str: - conn = self.get_conn() - row = conn.execute( - "SELECT full_text FROM transcripts WHERE id = ?", (transcript_id,) + conn = self.get_conn() + row = conn.execute( + "SELECT full_text FROM transcripts WHERE id = ?", (transcript_id, ) ).fetchone() conn.close() if not row: return "" - text = row["full_text"] - start = max(0, position - context_chars) - end = min(len(text), position + context_chars) + text = row["full_text"] + start = max(0, position - context_chars) + end = min(len(text), position + context_chars) return text[start:end] # ==================== Phase 5: Timeline Operations ==================== def get_project_timeline( - self, project_id: str, entity_id: str = None, start_date: str = None, end_date: str = None + self, project_id: str, entity_id: str = None, start_date: str = None, end_date: str = None ) -> list[dict]: - conn = self.get_conn() + conn = self.get_conn() - conditions = ["t.project_id = ?"] - params = [project_id] + conditions = ["t.project_id = ?"] + params = [project_id] if entity_id: - conditions.append("m.entity_id = ?") + conditions.append("m.entity_id = ?") params.append(entity_id) if start_date: conditions.append("t.created_at >= ?") @@ -675,19 +675,19 @@ class DatabaseManager: conditions.append("t.created_at <= ?") params.append(end_date) - where_clause = " AND ".join(conditions) + where_clause = " AND ".join(conditions) - mentions = conn.execute( + mentions = conn.execute( f"""SELECT m.*, e.name as entity_name, e.type as entity_type, e.definition, t.filename, t.created_at as event_date, t.type as source_type FROM entity_mentions m - JOIN entities e ON m.entity_id = e.id - JOIN transcripts t ON m.transcript_id = t.id + JOIN entities e ON m.entity_id = e.id + JOIN transcripts t ON m.transcript_id = t.id WHERE {where_clause} ORDER BY t.created_at, m.start_pos""", params, ).fetchall() - timeline_events = [] + timeline_events = [] for m in mentions: timeline_events.append( { @@ -708,30 +708,30 @@ class DatabaseManager: ) conn.close() - timeline_events.sort(key=lambda x: x["event_date"]) + timeline_events.sort(key = lambda x: x["event_date"]) return timeline_events def get_entity_timeline_summary(self, project_id: str) -> dict: - conn = self.get_conn() + conn = self.get_conn() - daily_stats = conn.execute( + daily_stats = conn.execute( """SELECT DATE(t.created_at) as date, COUNT(*) as count FROM entity_mentions m - JOIN transcripts t ON m.transcript_id = t.id - WHERE t.project_id = ? GROUP BY DATE(t.created_at) ORDER BY date""", - (project_id,), + JOIN transcripts t ON m.transcript_id = t.id + WHERE t.project_id = ? GROUP BY DATE(t.created_at) ORDER BY date""", + (project_id, ), ).fetchall() - entity_stats = conn.execute( + entity_stats = conn.execute( """SELECT e.name, e.type, COUNT(m.id) as mention_count, MIN(t.created_at) as first_mentioned, MAX(t.created_at) as last_mentioned FROM entities e - LEFT JOIN entity_mentions m ON e.id = m.entity_id - LEFT JOIN transcripts t ON m.transcript_id = t.id - WHERE e.project_id = ? + LEFT JOIN entity_mentions m ON e.id = m.entity_id + LEFT JOIN transcripts t ON m.transcript_id = t.id + WHERE e.project_id = ? GROUP BY e.id ORDER BY mention_count DESC LIMIT 20""", - (project_id,), + (project_id, ), ).fetchall() conn.close() @@ -744,8 +744,8 @@ class DatabaseManager: # ==================== Phase 5: Entity Attributes ==================== def create_attribute_template(self, template: AttributeTemplate) -> AttributeTemplate: - conn = self.get_conn() - now = datetime.now().isoformat() + conn = self.get_conn() + now = datetime.now().isoformat() conn.execute( """INSERT INTO attribute_templates (id, project_id, name, type, options, default_value, description, @@ -770,36 +770,36 @@ class DatabaseManager: return template def get_attribute_template(self, template_id: str) -> AttributeTemplate | None: - conn = self.get_conn() - row = conn.execute( - "SELECT * FROM attribute_templates WHERE id = ?", (template_id,) + conn = self.get_conn() + row = conn.execute( + "SELECT * FROM attribute_templates WHERE id = ?", (template_id, ) ).fetchone() conn.close() if row: - data = dict(row) - data["options"] = json.loads(data["options"]) if data["options"] else [] + data = dict(row) + data["options"] = json.loads(data["options"]) if data["options"] else [] return AttributeTemplate(**data) return None def list_attribute_templates(self, project_id: str) -> list[AttributeTemplate]: - conn = self.get_conn() - rows = conn.execute( - """SELECT * FROM attribute_templates WHERE project_id = ? + conn = self.get_conn() + rows = conn.execute( + """SELECT * FROM attribute_templates WHERE project_id = ? ORDER BY sort_order, created_at""", - (project_id,), + (project_id, ), ).fetchall() conn.close() - templates = [] + templates = [] for row in rows: - data = dict(row) - data["options"] = json.loads(data["options"]) if data["options"] else [] + data = dict(row) + data["options"] = json.loads(data["options"]) if data["options"] else [] templates.append(AttributeTemplate(**data)) return templates def update_attribute_template(self, template_id: str, **kwargs) -> AttributeTemplate | None: - conn = self.get_conn() - allowed_fields = [ + conn = self.get_conn() + allowed_fields = [ "name", "type", "options", @@ -808,45 +808,45 @@ class DatabaseManager: "is_required", "sort_order", ] - updates = [] - values = [] + updates = [] + values = [] for field in allowed_fields: if field in kwargs: - updates.append(f"{field} = ?") + updates.append(f"{field} = ?") if field == "options": values.append(json.dumps(kwargs[field]) if kwargs[field] else None) else: values.append(kwargs[field]) if updates: - updates.append("updated_at = ?") + updates.append("updated_at = ?") values.append(datetime.now().isoformat()) values.append(template_id) - query = f"UPDATE attribute_templates SET {', '.join(updates)} WHERE id = ?" + query = f"UPDATE attribute_templates SET {', '.join(updates)} WHERE id = ?" conn.execute(query, values) conn.commit() conn.close() return self.get_attribute_template(template_id) - def delete_attribute_template(self, template_id: str): - conn = self.get_conn() - conn.execute("DELETE FROM attribute_templates WHERE id = ?", (template_id,)) + def delete_attribute_template(self, template_id: str) -> None: + conn = self.get_conn() + conn.execute("DELETE FROM attribute_templates WHERE id = ?", (template_id, )) conn.commit() conn.close() def set_entity_attribute( - self, attr: EntityAttribute, changed_by: str = "system", change_reason: str = "" + self, attr: EntityAttribute, changed_by: str = "system", change_reason: str = "" ) -> EntityAttribute: - conn = self.get_conn() - now = datetime.now().isoformat() + conn = self.get_conn() + now = datetime.now().isoformat() - old_row = conn.execute( - "SELECT value FROM entity_attributes WHERE entity_id = ? AND template_id = ?", + old_row = conn.execute( + "SELECT value FROM entity_attributes WHERE entity_id = ? AND template_id = ?", (attr.entity_id, attr.template_id), ).fetchone() - old_value = old_row["value"] if old_row else None + old_value = old_row["value"] if old_row else None if old_value != attr.value: conn.execute( @@ -872,12 +872,12 @@ class DatabaseManager: VALUES ( COALESCE( (SELECT id FROM entity_attributes - WHERE entity_id = ? AND template_id = ?), ? + WHERE entity_id = ? AND template_id = ?), ? ), ?, ?, ?, COALESCE( (SELECT created_at FROM entity_attributes - WHERE entity_id = ? AND template_id = ?), ? + WHERE entity_id = ? AND template_id = ?), ? ), ?)""", ( @@ -899,23 +899,23 @@ class DatabaseManager: return attr def get_entity_attributes(self, entity_id: str) -> list[EntityAttribute]: - conn = self.get_conn() - rows = conn.execute( + conn = self.get_conn() + rows = conn.execute( """SELECT ea.*, at.name as template_name, at.type as template_type FROM entity_attributes ea - LEFT JOIN attribute_templates at ON ea.template_id = at.id - WHERE ea.entity_id = ? ORDER BY ea.created_at""", - (entity_id,), + LEFT JOIN attribute_templates at ON ea.template_id = at.id + WHERE ea.entity_id = ? ORDER BY ea.created_at""", + (entity_id, ), ).fetchall() conn.close() return [EntityAttribute(**dict(r)) for r in rows] def get_entity_with_attributes(self, entity_id: str) -> Entity | None: - entity = self.get_entity(entity_id) + entity = self.get_entity(entity_id) if not entity: return None - attrs = self.get_entity_attributes(entity_id) - entity.attributes = { + attrs = self.get_entity_attributes(entity_id) + entity.attributes = { attr.template_name: { "value": attr.value, "type": attr.template_type, @@ -926,12 +926,12 @@ class DatabaseManager: return entity def delete_entity_attribute( - self, entity_id: str, template_id: str, changed_by: str = "system", change_reason: str = "" - ): - conn = self.get_conn() - old_row = conn.execute( + self, entity_id: str, template_id: str, changed_by: str = "system", change_reason: str = "" + ) -> None: + conn = self.get_conn() + old_row = conn.execute( """SELECT value FROM entity_attributes - WHERE entity_id = ? AND template_id = ?""", + WHERE entity_id = ? AND template_id = ?""", (entity_id, template_id), ).fetchone() @@ -953,29 +953,29 @@ class DatabaseManager: ), ) conn.execute( - "DELETE FROM entity_attributes WHERE entity_id = ? AND template_id = ?", + "DELETE FROM entity_attributes WHERE entity_id = ? AND template_id = ?", (entity_id, template_id), ) conn.commit() conn.close() def get_attribute_history( - self, entity_id: str = None, template_id: str = None, limit: int = 50 + self, entity_id: str = None, template_id: str = None, limit: int = 50 ) -> list[AttributeHistory]: - conn = self.get_conn() - conditions = [] - params = [] + conn = self.get_conn() + conditions = [] + params = [] if entity_id: - conditions.append("ah.entity_id = ?") + conditions.append("ah.entity_id = ?") params.append(entity_id) if template_id: - conditions.append("ah.template_id = ?") + conditions.append("ah.template_id = ?") params.append(template_id) - where_clause = " AND ".join(conditions) if conditions else "1=1" + where_clause = " AND ".join(conditions) if conditions else "1 = 1" - rows = conn.execute( + rows = conn.execute( f"""SELECT ah.* FROM attribute_history ah WHERE {where_clause} @@ -988,42 +988,42 @@ class DatabaseManager: def search_entities_by_attributes( self, project_id: str, attribute_filters: dict[str, str] ) -> list[Entity]: - entities = self.list_project_entities(project_id) + entities = self.list_project_entities(project_id) if not attribute_filters: return entities - entity_ids = [e.id for e in entities] + entity_ids = [e.id for e in entities] if not entity_ids: return [] - conn = self.get_conn() - placeholders = ",".join(["?" for _ in entity_ids]) - rows = conn.execute( + conn = self.get_conn() + placeholders = ", ".join(["?" for _ in entity_ids]) + rows = conn.execute( f"""SELECT ea.*, at.name as template_name FROM entity_attributes ea - JOIN attribute_templates at ON ea.template_id = at.id + JOIN attribute_templates at ON ea.template_id = at.id WHERE ea.entity_id IN ({placeholders})""", entity_ids, ).fetchall() conn.close() - entity_attrs = {} + entity_attrs = {} for row in rows: - eid = row["entity_id"] + eid = row["entity_id"] if eid not in entity_attrs: - entity_attrs[eid] = {} - entity_attrs[eid][row["template_name"]] = row["value"] + entity_attrs[eid] = {} + entity_attrs[eid][row["template_name"]] = row["value"] - filtered = [] + filtered = [] for entity in entities: - attrs = entity_attrs.get(entity.id, {}) - match = True + attrs = entity_attrs.get(entity.id, {}) + match = True for attr_name, attr_value in attribute_filters.items(): if attrs.get(attr_name) != attr_value: - match = False + match = False break if match: - entity.attributes = attrs + entity.attributes = attrs filtered.append(entity) return filtered @@ -1034,17 +1034,17 @@ class DatabaseManager: video_id: str, project_id: str, filename: str, - duration: float = 0, - fps: float = 0, - resolution: dict = None, - audio_transcript_id: str = None, - full_ocr_text: str = "", - extracted_entities: list[dict] = None, - extracted_relations: list[dict] = None, + duration: float = 0, + fps: float = 0, + resolution: dict = None, + audio_transcript_id: str = None, + full_ocr_text: str = "", + extracted_entities: list[dict] = None, + extracted_relations: list[dict] = None, ) -> str: """创建视频记录""" - conn = self.get_conn() - now = datetime.now().isoformat() + conn = self.get_conn() + now = datetime.now().isoformat() conn.execute( """INSERT INTO videos @@ -1074,17 +1074,17 @@ class DatabaseManager: def get_video(self, video_id: str) -> dict | None: """获取视频信息""" - conn = self.get_conn() - row = conn.execute("SELECT * FROM videos WHERE id = ?", (video_id,)).fetchone() + conn = self.get_conn() + row = conn.execute("SELECT * FROM videos WHERE id = ?", (video_id, )).fetchone() conn.close() if row: - data = dict(row) - data["resolution"] = json.loads(data["resolution"]) if data["resolution"] else None - data["extracted_entities"] = ( + data = dict(row) + data["resolution"] = json.loads(data["resolution"]) if data["resolution"] else None + data["extracted_entities"] = ( json.loads(data["extracted_entities"]) if data["extracted_entities"] else [] ) - data["extracted_relations"] = ( + data["extracted_relations"] = ( json.loads(data["extracted_relations"]) if data["extracted_relations"] else [] ) return data @@ -1092,20 +1092,20 @@ class DatabaseManager: def list_project_videos(self, project_id: str) -> list[dict]: """获取项目的所有视频""" - conn = self.get_conn() - rows = conn.execute( - "SELECT * FROM videos WHERE project_id = ? ORDER BY created_at DESC", (project_id,) + conn = self.get_conn() + rows = conn.execute( + "SELECT * FROM videos WHERE project_id = ? ORDER BY created_at DESC", (project_id, ) ).fetchall() conn.close() - videos = [] + videos = [] for row in rows: - data = dict(row) - data["resolution"] = json.loads(data["resolution"]) if data["resolution"] else None - data["extracted_entities"] = ( + data = dict(row) + data["resolution"] = json.loads(data["resolution"]) if data["resolution"] else None + data["extracted_entities"] = ( json.loads(data["extracted_entities"]) if data["extracted_entities"] else [] ) - data["extracted_relations"] = ( + data["extracted_relations"] = ( json.loads(data["extracted_relations"]) if data["extracted_relations"] else [] ) videos.append(data) @@ -1117,13 +1117,13 @@ class DatabaseManager: video_id: str, frame_number: int, timestamp: float, - image_url: str = None, - ocr_text: str = None, - extracted_entities: list[dict] = None, + image_url: str = None, + ocr_text: str = None, + extracted_entities: list[dict] = None, ) -> str: """创建视频帧记录""" - conn = self.get_conn() - now = datetime.now().isoformat() + conn = self.get_conn() + now = datetime.now().isoformat() conn.execute( """INSERT INTO video_frames @@ -1147,16 +1147,16 @@ class DatabaseManager: def get_video_frames(self, video_id: str) -> list[dict]: """获取视频的所有帧""" - conn = self.get_conn() - rows = conn.execute( - """SELECT * FROM video_frames WHERE video_id = ? ORDER BY timestamp""", (video_id,) + conn = self.get_conn() + rows = conn.execute( + """SELECT * FROM video_frames WHERE video_id = ? ORDER BY timestamp""", (video_id, ) ).fetchall() conn.close() - frames = [] + frames = [] for row in rows: - data = dict(row) - data["extracted_entities"] = ( + data = dict(row) + data["extracted_entities"] = ( json.loads(data["extracted_entities"]) if data["extracted_entities"] else [] ) frames.append(data) @@ -1167,14 +1167,14 @@ class DatabaseManager: image_id: str, project_id: str, filename: str, - ocr_text: str = "", - description: str = "", - extracted_entities: list[dict] = None, - extracted_relations: list[dict] = None, + ocr_text: str = "", + description: str = "", + extracted_entities: list[dict] = None, + extracted_relations: list[dict] = None, ) -> str: """创建图片记录""" - conn = self.get_conn() - now = datetime.now().isoformat() + conn = self.get_conn() + now = datetime.now().isoformat() conn.execute( """INSERT INTO images @@ -1200,16 +1200,16 @@ class DatabaseManager: def get_image(self, image_id: str) -> dict | None: """获取图片信息""" - conn = self.get_conn() - row = conn.execute("SELECT * FROM images WHERE id = ?", (image_id,)).fetchone() + conn = self.get_conn() + row = conn.execute("SELECT * FROM images WHERE id = ?", (image_id, )).fetchone() conn.close() if row: - data = dict(row) - data["extracted_entities"] = ( + data = dict(row) + data["extracted_entities"] = ( json.loads(data["extracted_entities"]) if data["extracted_entities"] else [] ) - data["extracted_relations"] = ( + data["extracted_relations"] = ( json.loads(data["extracted_relations"]) if data["extracted_relations"] else [] ) return data @@ -1217,19 +1217,19 @@ class DatabaseManager: def list_project_images(self, project_id: str) -> list[dict]: """获取项目的所有图片""" - conn = self.get_conn() - rows = conn.execute( - "SELECT * FROM images WHERE project_id = ? ORDER BY created_at DESC", (project_id,) + conn = self.get_conn() + rows = conn.execute( + "SELECT * FROM images WHERE project_id = ? ORDER BY created_at DESC", (project_id, ) ).fetchall() conn.close() - images = [] + images = [] for row in rows: - data = dict(row) - data["extracted_entities"] = ( + data = dict(row) + data["extracted_entities"] = ( json.loads(data["extracted_entities"]) if data["extracted_entities"] else [] ) - data["extracted_relations"] = ( + data["extracted_relations"] = ( json.loads(data["extracted_relations"]) if data["extracted_relations"] else [] ) images.append(data) @@ -1243,12 +1243,12 @@ class DatabaseManager: modality: str, source_id: str, source_type: str, - text_snippet: str = "", - confidence: float = 1.0, + text_snippet: str = "", + confidence: float = 1.0, ) -> str: """创建多模态实体提及记录""" - conn = self.get_conn() - now = datetime.now().isoformat() + conn = self.get_conn() + now = datetime.now().isoformat() conn.execute( """INSERT OR REPLACE INTO multimodal_mentions @@ -1273,37 +1273,37 @@ class DatabaseManager: def get_entity_multimodal_mentions(self, entity_id: str) -> list[dict]: """获取实体的多模态提及""" - conn = self.get_conn() - rows = conn.execute( + conn = self.get_conn() + rows = conn.execute( """SELECT m.*, e.name as entity_name FROM multimodal_mentions m - JOIN entities e ON m.entity_id = e.id - WHERE m.entity_id = ? ORDER BY m.created_at DESC""", - (entity_id,), + JOIN entities e ON m.entity_id = e.id + WHERE m.entity_id = ? ORDER BY m.created_at DESC""", + (entity_id, ), ).fetchall() conn.close() return [dict(r) for r in rows] - def get_project_multimodal_mentions(self, project_id: str, modality: str = None) -> list[dict]: + def get_project_multimodal_mentions(self, project_id: str, modality: str = None) -> list[dict]: """获取项目的多模态提及""" - conn = self.get_conn() + conn = self.get_conn() if modality: - rows = conn.execute( + rows = conn.execute( """SELECT m.*, e.name as entity_name FROM multimodal_mentions m - JOIN entities e ON m.entity_id = e.id - WHERE m.project_id = ? AND m.modality = ? + JOIN entities e ON m.entity_id = e.id + WHERE m.project_id = ? AND m.modality = ? ORDER BY m.created_at DESC""", (project_id, modality), ).fetchall() else: - rows = conn.execute( + rows = conn.execute( """SELECT m.*, e.name as entity_name FROM multimodal_mentions m - JOIN entities e ON m.entity_id = e.id - WHERE m.project_id = ? ORDER BY m.created_at DESC""", - (project_id,), + JOIN entities e ON m.entity_id = e.id + WHERE m.project_id = ? ORDER BY m.created_at DESC""", + (project_id, ), ).fetchall() conn.close() @@ -1315,13 +1315,13 @@ class DatabaseManager: entity_id: str, linked_entity_id: str, link_type: str, - confidence: float = 1.0, - evidence: str = "", - modalities: list[str] = None, + confidence: float = 1.0, + evidence: str = "", + modalities: list[str] = None, ) -> str: """创建多模态实体关联""" - conn = self.get_conn() - now = datetime.now().isoformat() + conn = self.get_conn() + now = datetime.now().isoformat() conn.execute( """INSERT OR REPLACE INTO multimodal_entity_links @@ -1345,29 +1345,29 @@ class DatabaseManager: def get_entity_multimodal_links(self, entity_id: str) -> list[dict]: """获取实体的多模态关联""" - conn = self.get_conn() - rows = conn.execute( + conn = self.get_conn() + rows = conn.execute( """SELECT l.*, e1.name as entity_name, e2.name as linked_entity_name FROM multimodal_entity_links l - JOIN entities e1 ON l.entity_id = e1.id - JOIN entities e2 ON l.linked_entity_id = e2.id - WHERE l.entity_id = ? OR l.linked_entity_id = ?""", + JOIN entities e1 ON l.entity_id = e1.id + JOIN entities e2 ON l.linked_entity_id = e2.id + WHERE l.entity_id = ? OR l.linked_entity_id = ?""", (entity_id, entity_id), ).fetchall() conn.close() - links = [] + links = [] for row in rows: - data = dict(row) - data["modalities"] = json.loads(data["modalities"]) if data["modalities"] else [] + data = dict(row) + data["modalities"] = json.loads(data["modalities"]) if data["modalities"] else [] links.append(data) return links def get_project_multimodal_stats(self, project_id: str) -> dict: """获取项目多模态统计信息""" - conn = self.get_conn() + conn = self.get_conn() - stats = { + stats = { "video_count": 0, "image_count": 0, "multimodal_entity_count": 0, @@ -1376,52 +1376,52 @@ class DatabaseManager: } # 视频数量 - row = conn.execute( - "SELECT COUNT(*) as count FROM videos WHERE project_id = ?", (project_id,) + row = conn.execute( + "SELECT COUNT(*) as count FROM videos WHERE project_id = ?", (project_id, ) ).fetchone() - stats["video_count"] = row["count"] + stats["video_count"] = row["count"] # 图片数量 - row = conn.execute( - "SELECT COUNT(*) as count FROM images WHERE project_id = ?", (project_id,) + row = conn.execute( + "SELECT COUNT(*) as count FROM images WHERE project_id = ?", (project_id, ) ).fetchone() - stats["image_count"] = row["count"] + stats["image_count"] = row["count"] # 多模态实体数量 - row = conn.execute( + row = conn.execute( """SELECT COUNT(DISTINCT entity_id) as count - FROM multimodal_mentions WHERE project_id = ?""", - (project_id,), + FROM multimodal_mentions WHERE project_id = ?""", + (project_id, ), ).fetchone() - stats["multimodal_entity_count"] = row["count"] + stats["multimodal_entity_count"] = row["count"] # 跨模态关联数量 - row = conn.execute( + row = conn.execute( """SELECT COUNT(*) as count FROM multimodal_entity_links - WHERE entity_id IN (SELECT id FROM entities WHERE project_id = ?)""", - (project_id,), + WHERE entity_id IN (SELECT id FROM entities WHERE project_id = ?)""", + (project_id, ), ).fetchone() - stats["cross_modal_links"] = row["count"] + stats["cross_modal_links"] = row["count"] # 模态分布 for modality in ["audio", "video", "image", "document"]: - row = conn.execute( + row = conn.execute( """SELECT COUNT(*) as count FROM multimodal_mentions - WHERE project_id = ? AND modality = ?""", + WHERE project_id = ? AND modality = ?""", (project_id, modality), ).fetchone() - stats["modality_distribution"][modality] = row["count"] + stats["modality_distribution"][modality] = row["count"] conn.close() return stats # Singleton instance -_db_manager = None +_db_manager = None def get_db_manager() -> DatabaseManager: global _db_manager if _db_manager is None: - _db_manager = DatabaseManager() + _db_manager = DatabaseManager() return _db_manager diff --git a/backend/developer_ecosystem_manager.py b/backend/developer_ecosystem_manager.py index 55c31a7..fc48c95 100644 --- a/backend/developer_ecosystem_manager.py +++ b/backend/developer_ecosystem_manager.py @@ -19,81 +19,81 @@ from datetime import datetime from enum import StrEnum # Database path -DB_PATH = os.path.join(os.path.dirname(__file__), "insightflow.db") +DB_PATH = os.path.join(os.path.dirname(__file__), "insightflow.db") class SDKLanguage(StrEnum): """SDK 语言类型""" - PYTHON = "python" - JAVASCRIPT = "javascript" - TYPESCRIPT = "typescript" - GO = "go" - JAVA = "java" - RUST = "rust" + PYTHON = "python" + JAVASCRIPT = "javascript" + TYPESCRIPT = "typescript" + GO = "go" + JAVA = "java" + RUST = "rust" class SDKStatus(StrEnum): """SDK 状态""" - DRAFT = "draft" # 草稿 - BETA = "beta" # 测试版 - STABLE = "stable" # 稳定版 - DEPRECATED = "deprecated" # 已弃用 - ARCHIVED = "archived" # 已归档 + DRAFT = "draft" # 草稿 + BETA = "beta" # 测试版 + STABLE = "stable" # 稳定版 + DEPRECATED = "deprecated" # 已弃用 + ARCHIVED = "archived" # 已归档 class TemplateCategory(StrEnum): """模板分类""" - MEDICAL = "medical" # 医疗 - LEGAL = "legal" # 法律 - FINANCE = "finance" # 金融 - EDUCATION = "education" # 教育 - TECH = "tech" # 科技 - GENERAL = "general" # 通用 + MEDICAL = "medical" # 医疗 + LEGAL = "legal" # 法律 + FINANCE = "finance" # 金融 + EDUCATION = "education" # 教育 + TECH = "tech" # 科技 + GENERAL = "general" # 通用 class TemplateStatus(StrEnum): """模板状态""" - PENDING = "pending" # 待审核 - APPROVED = "approved" # 已通过 - REJECTED = "rejected" # 已拒绝 - PUBLISHED = "published" # 已发布 - UNLISTED = "unlisted" # 未列出 + PENDING = "pending" # 待审核 + APPROVED = "approved" # 已通过 + REJECTED = "rejected" # 已拒绝 + PUBLISHED = "published" # 已发布 + UNLISTED = "unlisted" # 未列出 class PluginStatus(StrEnum): """插件状态""" - PENDING = "pending" # 待审核 - REVIEWING = "reviewing" # 审核中 - APPROVED = "approved" # 已通过 - REJECTED = "rejected" # 已拒绝 - PUBLISHED = "published" # 已发布 - SUSPENDED = "suspended" # 已暂停 + PENDING = "pending" # 待审核 + REVIEWING = "reviewing" # 审核中 + APPROVED = "approved" # 已通过 + REJECTED = "rejected" # 已拒绝 + PUBLISHED = "published" # 已发布 + SUSPENDED = "suspended" # 已暂停 class PluginCategory(StrEnum): """插件分类""" - INTEGRATION = "integration" # 集成 - ANALYSIS = "analysis" # 分析 - VISUALIZATION = "visualization" # 可视化 - AUTOMATION = "automation" # 自动化 - SECURITY = "security" # 安全 - CUSTOM = "custom" # 自定义 + INTEGRATION = "integration" # 集成 + ANALYSIS = "analysis" # 分析 + VISUALIZATION = "visualization" # 可视化 + AUTOMATION = "automation" # 自动化 + SECURITY = "security" # 安全 + CUSTOM = "custom" # 自定义 class DeveloperStatus(StrEnum): """开发者认证状态""" - UNVERIFIED = "unverified" # 未认证 - PENDING = "pending" # 审核中 - VERIFIED = "verified" # 已认证 - CERTIFIED = "certified" # 已认证(高级) - SUSPENDED = "suspended" # 已暂停 + UNVERIFIED = "unverified" # 未认证 + PENDING = "pending" # 审核中 + VERIFIED = "verified" # 已认证 + CERTIFIED = "certified" # 已认证(高级) + SUSPENDED = "suspended" # 已暂停 @dataclass @@ -112,7 +112,7 @@ class SDKRelease: package_name: str # pip/npm/go module name status: SDKStatus min_platform_version: str - dependencies: list[dict] # [{"name": "requests", "version": ">=2.0"}] + dependencies: list[dict] # [{"name": "requests", "version": ">= 2.0"}] file_size: int checksum: str download_count: int @@ -152,7 +152,7 @@ class TemplateMarketItem: author_id: str author_name: str status: TemplateStatus - price: float # 0 = 免费 + price: float # 0 = 免费 currency: str preview_image_url: str | None demo_url: str | None @@ -348,14 +348,14 @@ class DeveloperPortalConfig: class DeveloperEcosystemManager: """开发者生态系统管理主类""" - def __init__(self, db_path: str = DB_PATH): - self.db_path = db_path - self.platform_fee_rate = 0.30 # 平台抽成比例 30% + def __init__(self, db_path: str = DB_PATH) -> None: + self.db_path = db_path + self.platform_fee_rate = 0.30 # 平台抽成比例 30% def _get_db(self) -> None: """获取数据库连接""" - conn = sqlite3.connect(self.db_path) - conn.row_factory = sqlite3.Row + conn = sqlite3.connect(self.db_path) + conn.row_factory = sqlite3.Row return conn # ==================== SDK 发布与管理 ==================== @@ -378,30 +378,30 @@ class DeveloperEcosystemManager: created_by: str, ) -> SDKRelease: """创建 SDK 发布""" - sdk_id = f"sdk_{uuid.uuid4().hex[:16]}" - now = datetime.now().isoformat() + sdk_id = f"sdk_{uuid.uuid4().hex[:16]}" + now = datetime.now().isoformat() - sdk = SDKRelease( - id=sdk_id, - name=name, - language=language, - version=version, - description=description, - changelog=changelog, - download_url=download_url, - documentation_url=documentation_url, - repository_url=repository_url, - package_name=package_name, - status=SDKStatus.DRAFT, - min_platform_version=min_platform_version, - dependencies=dependencies, - file_size=file_size, - checksum=checksum, - download_count=0, - created_at=now, - updated_at=now, - published_at=None, - created_by=created_by, + sdk = SDKRelease( + id = sdk_id, + name = name, + language = language, + version = version, + description = description, + changelog = changelog, + download_url = download_url, + documentation_url = documentation_url, + repository_url = repository_url, + package_name = package_name, + status = SDKStatus.DRAFT, + min_platform_version = min_platform_version, + dependencies = dependencies, + file_size = file_size, + checksum = checksum, + download_count = 0, + created_at = now, + updated_at = now, + published_at = None, + created_by = created_by, ) with self._get_db() as conn: @@ -444,7 +444,7 @@ class DeveloperEcosystemManager: def get_sdk_release(self, sdk_id: str) -> SDKRelease | None: """获取 SDK 发布详情""" with self._get_db() as conn: - row = conn.execute("SELECT * FROM sdk_releases WHERE id = ?", (sdk_id,)).fetchone() + row = conn.execute("SELECT * FROM sdk_releases WHERE id = ?", (sdk_id, )).fetchone() if row: return self._row_to_sdk_release(row) @@ -452,19 +452,19 @@ class DeveloperEcosystemManager: def list_sdk_releases( self, - language: SDKLanguage | None = None, - status: SDKStatus | None = None, - search: str | None = None, + language: SDKLanguage | None = None, + status: SDKStatus | None = None, + search: str | None = None, ) -> list[SDKRelease]: """列出 SDK 发布""" - query = "SELECT * FROM sdk_releases WHERE 1=1" - params = [] + query = "SELECT * FROM sdk_releases WHERE 1 = 1" + params = [] if language: - query += " AND language = ?" + query += " AND language = ?" params.append(language.value) if status: - query += " AND status = ?" + query += " AND status = ?" params.append(status.value) if search: query += " AND (name LIKE ? OR description LIKE ? OR package_name LIKE ?)" @@ -473,12 +473,12 @@ class DeveloperEcosystemManager: query += " ORDER BY created_at DESC" with self._get_db() as conn: - rows = conn.execute(query, params).fetchall() + rows = conn.execute(query, params).fetchall() return [self._row_to_sdk_release(row) for row in rows] def update_sdk_release(self, sdk_id: str, **kwargs) -> SDKRelease | None: """更新 SDK 发布""" - allowed_fields = [ + allowed_fields = [ "name", "description", "changelog", @@ -488,16 +488,16 @@ class DeveloperEcosystemManager: "status", ] - updates = {k: v for k, v in kwargs.items() if k in allowed_fields} + updates = {k: v for k, v in kwargs.items() if k in allowed_fields} if not updates: return self.get_sdk_release(sdk_id) - updates["updated_at"] = datetime.now().isoformat() + updates["updated_at"] = datetime.now().isoformat() with self._get_db() as conn: - set_clause = ", ".join([f"{k} = ?" for k in updates.keys()]) + set_clause = ", ".join([f"{k} = ?" for k in updates.keys()]) conn.execute( - f"UPDATE sdk_releases SET {set_clause} WHERE id = ?", + f"UPDATE sdk_releases SET {set_clause} WHERE id = ?", list(updates.values()) + [sdk_id], ) conn.commit() @@ -506,14 +506,14 @@ class DeveloperEcosystemManager: def publish_sdk_release(self, sdk_id: str) -> SDKRelease | None: """发布 SDK""" - now = datetime.now().isoformat() + now = datetime.now().isoformat() with self._get_db() as conn: conn.execute( """ UPDATE sdk_releases - SET status = ?, published_at = ?, updated_at = ? - WHERE id = ? + SET status = ?, published_at = ?, updated_at = ? + WHERE id = ? """, (SDKStatus.STABLE.value, now, now, sdk_id), ) @@ -527,18 +527,18 @@ class DeveloperEcosystemManager: conn.execute( """ UPDATE sdk_releases - SET download_count = download_count + 1 - WHERE id = ? + SET download_count = download_count + 1 + WHERE id = ? """, - (sdk_id,), + (sdk_id, ), ) conn.commit() def get_sdk_versions(self, sdk_id: str) -> list[SDKVersion]: """获取 SDK 版本历史""" with self._get_db() as conn: - rows = conn.execute( - "SELECT * FROM sdk_versions WHERE sdk_id = ? ORDER BY created_at DESC", (sdk_id,) + rows = conn.execute( + "SELECT * FROM sdk_versions WHERE sdk_id = ? ORDER BY created_at DESC", (sdk_id, ) ).fetchall() return [self._row_to_sdk_version(row) for row in rows] @@ -553,13 +553,13 @@ class DeveloperEcosystemManager: file_size: int, ) -> SDKVersion: """添加 SDK 版本""" - version_id = f"sv_{uuid.uuid4().hex[:16]}" - now = datetime.now().isoformat() + version_id = f"sv_{uuid.uuid4().hex[:16]}" + now = datetime.now().isoformat() with self._get_db() as conn: # 如果设置为最新版本,取消其他版本的最新标记 if True: # 默认新版本为最新 - conn.execute("UPDATE sdk_versions SET is_latest = 0 WHERE sdk_id = ?", (sdk_id,)) + conn.execute("UPDATE sdk_versions SET is_latest = 0 WHERE sdk_id = ?", (sdk_id, )) conn.execute( """ @@ -585,17 +585,17 @@ class DeveloperEcosystemManager: conn.commit() return SDKVersion( - id=version_id, - sdk_id=sdk_id, - version=version, - is_latest=True, - is_lts=is_lts, - release_notes=release_notes, - download_url=download_url, - checksum=checksum, - file_size=file_size, - download_count=0, - created_at=now, + id = version_id, + sdk_id = sdk_id, + version = version, + is_latest = True, + is_lts = is_lts, + release_notes = release_notes, + download_url = download_url, + checksum = checksum, + file_size = file_size, + download_count = 0, + created_at = now, ) # ==================== 模板市场 ==================== @@ -609,48 +609,48 @@ class DeveloperEcosystemManager: tags: list[str], author_id: str, author_name: str, - price: float = 0.0, - currency: str = "CNY", - preview_image_url: str | None = None, - demo_url: str | None = None, - documentation_url: str | None = None, - download_url: str | None = None, - version: str = "1.0.0", - min_platform_version: str = "1.0.0", - file_size: int = 0, - checksum: str = "", + price: float = 0.0, + currency: str = "CNY", + preview_image_url: str | None = None, + demo_url: str | None = None, + documentation_url: str | None = None, + download_url: str | None = None, + version: str = "1.0.0", + min_platform_version: str = "1.0.0", + file_size: int = 0, + checksum: str = "", ) -> TemplateMarketItem: """创建模板""" - template_id = f"tpl_{uuid.uuid4().hex[:16]}" - now = datetime.now().isoformat() + template_id = f"tpl_{uuid.uuid4().hex[:16]}" + now = datetime.now().isoformat() - template = TemplateMarketItem( - id=template_id, - name=name, - description=description, - category=category, - subcategory=subcategory, - tags=tags, - author_id=author_id, - author_name=author_name, - status=TemplateStatus.PENDING, - price=price, - currency=currency, - preview_image_url=preview_image_url, - demo_url=demo_url, - documentation_url=documentation_url, - download_url=download_url, - install_count=0, - rating=0.0, - rating_count=0, - review_count=0, - version=version, - min_platform_version=min_platform_version, - file_size=file_size, - checksum=checksum, - created_at=now, - updated_at=now, - published_at=None, + template = TemplateMarketItem( + id = template_id, + name = name, + description = description, + category = category, + subcategory = subcategory, + tags = tags, + author_id = author_id, + author_name = author_name, + status = TemplateStatus.PENDING, + price = price, + currency = currency, + preview_image_url = preview_image_url, + demo_url = demo_url, + documentation_url = documentation_url, + download_url = download_url, + install_count = 0, + rating = 0.0, + rating_count = 0, + review_count = 0, + version = version, + min_platform_version = min_platform_version, + file_size = file_size, + checksum = checksum, + created_at = now, + updated_at = now, + published_at = None, ) with self._get_db() as conn: @@ -699,8 +699,8 @@ class DeveloperEcosystemManager: def get_template(self, template_id: str) -> TemplateMarketItem | None: """获取模板详情""" with self._get_db() as conn: - row = conn.execute( - "SELECT * FROM template_market WHERE id = ?", (template_id,) + row = conn.execute( + "SELECT * FROM template_market WHERE id = ?", (template_id, ) ).fetchone() if row: @@ -709,26 +709,26 @@ class DeveloperEcosystemManager: def list_templates( self, - category: TemplateCategory | None = None, - status: TemplateStatus | None = None, - search: str | None = None, - author_id: str | None = None, - min_price: float | None = None, - max_price: float | None = None, - sort_by: str = "created_at", + category: TemplateCategory | None = None, + status: TemplateStatus | None = None, + search: str | None = None, + author_id: str | None = None, + min_price: float | None = None, + max_price: float | None = None, + sort_by: str = "created_at", ) -> list[TemplateMarketItem]: """列出模板""" - query = "SELECT * FROM template_market WHERE 1=1" - params = [] + query = "SELECT * FROM template_market WHERE 1 = 1" + params = [] if category: - query += " AND category = ?" + query += " AND category = ?" params.append(category.value) if status: - query += " AND status = ?" + query += " AND status = ?" params.append(status.value) if author_id: - query += " AND author_id = ?" + query += " AND author_id = ?" params.append(author_id) if search: query += " AND (name LIKE ? OR description LIKE ? OR tags LIKE ?)" @@ -741,7 +741,7 @@ class DeveloperEcosystemManager: params.append(max_price) # 排序 - sort_mapping = { + sort_mapping = { "created_at": "created_at DESC", "rating": "rating DESC", "install_count": "install_count DESC", @@ -751,19 +751,19 @@ class DeveloperEcosystemManager: query += f" ORDER BY {sort_mapping.get(sort_by, 'created_at DESC')}" with self._get_db() as conn: - rows = conn.execute(query, params).fetchall() + rows = conn.execute(query, params).fetchall() return [self._row_to_template(row) for row in rows] def approve_template(self, template_id: str, reviewed_by: str) -> TemplateMarketItem | None: """审核通过模板""" - now = datetime.now().isoformat() + now = datetime.now().isoformat() with self._get_db() as conn: conn.execute( """ UPDATE template_market - SET status = ?, updated_at = ? - WHERE id = ? + SET status = ?, updated_at = ? + WHERE id = ? """, (TemplateStatus.APPROVED.value, now, template_id), ) @@ -773,14 +773,14 @@ class DeveloperEcosystemManager: def publish_template(self, template_id: str) -> TemplateMarketItem | None: """发布模板""" - now = datetime.now().isoformat() + now = datetime.now().isoformat() with self._get_db() as conn: conn.execute( """ UPDATE template_market - SET status = ?, published_at = ?, updated_at = ? - WHERE id = ? + SET status = ?, published_at = ?, updated_at = ? + WHERE id = ? """, (TemplateStatus.PUBLISHED.value, now, now, template_id), ) @@ -790,14 +790,14 @@ class DeveloperEcosystemManager: def reject_template(self, template_id: str, reason: str) -> TemplateMarketItem | None: """拒绝模板""" - now = datetime.now().isoformat() + now = datetime.now().isoformat() with self._get_db() as conn: conn.execute( """ UPDATE template_market - SET status = ?, updated_at = ? - WHERE id = ? + SET status = ?, updated_at = ? + WHERE id = ? """, (TemplateStatus.REJECTED.value, now, template_id), ) @@ -811,10 +811,10 @@ class DeveloperEcosystemManager: conn.execute( """ UPDATE template_market - SET install_count = install_count + 1 - WHERE id = ? + SET install_count = install_count + 1 + WHERE id = ? """, - (template_id,), + (template_id, ), ) conn.commit() @@ -825,23 +825,23 @@ class DeveloperEcosystemManager: user_name: str, rating: int, comment: str, - is_verified_purchase: bool = False, + is_verified_purchase: bool = False, ) -> TemplateReview: """添加模板评价""" - review_id = f"tr_{uuid.uuid4().hex[:16]}" - now = datetime.now().isoformat() + review_id = f"tr_{uuid.uuid4().hex[:16]}" + now = datetime.now().isoformat() - review = TemplateReview( - id=review_id, - template_id=template_id, - user_id=user_id, - user_name=user_name, - rating=rating, - comment=comment, - is_verified_purchase=is_verified_purchase, - helpful_count=0, - created_at=now, - updated_at=now, + review = TemplateReview( + id = review_id, + template_id = template_id, + user_id = user_id, + user_name = user_name, + rating = rating, + comment = comment, + is_verified_purchase = is_verified_purchase, + helpful_count = 0, + created_at = now, + updated_at = now, ) with self._get_db() as conn: @@ -874,21 +874,21 @@ class DeveloperEcosystemManager: def _update_template_rating(self, conn, template_id: str) -> None: """更新模板评分""" - row = conn.execute( + row = conn.execute( """ SELECT AVG(rating) as avg_rating, COUNT(*) as count FROM template_reviews - WHERE template_id = ? + WHERE template_id = ? """, - (template_id,), + (template_id, ), ).fetchone() if row: conn.execute( """ UPDATE template_market - SET rating = ?, rating_count = ?, review_count = ? - WHERE id = ? + SET rating = ?, rating_count = ?, review_count = ? + WHERE id = ? """, ( round(row["avg_rating"], 2) if row["avg_rating"] else 0, @@ -898,12 +898,12 @@ class DeveloperEcosystemManager: ), ) - def get_template_reviews(self, template_id: str, limit: int = 50) -> list[TemplateReview]: + def get_template_reviews(self, template_id: str, limit: int = 50) -> list[TemplateReview]: """获取模板评价""" with self._get_db() as conn: - rows = conn.execute( + rows = conn.execute( """SELECT * FROM template_reviews - WHERE template_id = ? + WHERE template_id = ? ORDER BY created_at DESC LIMIT ?""", (template_id, limit), @@ -920,59 +920,59 @@ class DeveloperEcosystemManager: tags: list[str], author_id: str, author_name: str, - price: float = 0.0, - currency: str = "CNY", - pricing_model: str = "free", - preview_image_url: str | None = None, - demo_url: str | None = None, - documentation_url: str | None = None, - repository_url: str | None = None, - download_url: str | None = None, - webhook_url: str | None = None, - permissions: list[str] = None, - version: str = "1.0.0", - min_platform_version: str = "1.0.0", - file_size: int = 0, - checksum: str = "", + price: float = 0.0, + currency: str = "CNY", + pricing_model: str = "free", + preview_image_url: str | None = None, + demo_url: str | None = None, + documentation_url: str | None = None, + repository_url: str | None = None, + download_url: str | None = None, + webhook_url: str | None = None, + permissions: list[str] = None, + version: str = "1.0.0", + min_platform_version: str = "1.0.0", + file_size: int = 0, + checksum: str = "", ) -> PluginMarketItem: """创建插件""" - plugin_id = f"plg_{uuid.uuid4().hex[:16]}" - now = datetime.now().isoformat() + plugin_id = f"plg_{uuid.uuid4().hex[:16]}" + now = datetime.now().isoformat() - plugin = PluginMarketItem( - id=plugin_id, - name=name, - description=description, - category=category, - tags=tags, - author_id=author_id, - author_name=author_name, - status=PluginStatus.PENDING, - price=price, - currency=currency, - pricing_model=pricing_model, - preview_image_url=preview_image_url, - demo_url=demo_url, - documentation_url=documentation_url, - repository_url=repository_url, - download_url=download_url, - webhook_url=webhook_url, - permissions=permissions or [], - install_count=0, - active_install_count=0, - rating=0.0, - rating_count=0, - review_count=0, - version=version, - min_platform_version=min_platform_version, - file_size=file_size, - checksum=checksum, - created_at=now, - updated_at=now, - published_at=None, - reviewed_by=None, - reviewed_at=None, - review_notes=None, + plugin = PluginMarketItem( + id = plugin_id, + name = name, + description = description, + category = category, + tags = tags, + author_id = author_id, + author_name = author_name, + status = PluginStatus.PENDING, + price = price, + currency = currency, + pricing_model = pricing_model, + preview_image_url = preview_image_url, + demo_url = demo_url, + documentation_url = documentation_url, + repository_url = repository_url, + download_url = download_url, + webhook_url = webhook_url, + permissions = permissions or [], + install_count = 0, + active_install_count = 0, + rating = 0.0, + rating_count = 0, + review_count = 0, + version = version, + min_platform_version = min_platform_version, + file_size = file_size, + checksum = checksum, + created_at = now, + updated_at = now, + published_at = None, + reviewed_by = None, + reviewed_at = None, + review_notes = None, ) with self._get_db() as conn: @@ -1032,7 +1032,7 @@ class DeveloperEcosystemManager: def get_plugin(self, plugin_id: str) -> PluginMarketItem | None: """获取插件详情""" with self._get_db() as conn: - row = conn.execute("SELECT * FROM plugin_market WHERE id = ?", (plugin_id,)).fetchone() + row = conn.execute("SELECT * FROM plugin_market WHERE id = ?", (plugin_id, )).fetchone() if row: return self._row_to_plugin(row) @@ -1040,30 +1040,30 @@ class DeveloperEcosystemManager: def list_plugins( self, - category: PluginCategory | None = None, - status: PluginStatus | None = None, - search: str | None = None, - author_id: str | None = None, - sort_by: str = "created_at", + category: PluginCategory | None = None, + status: PluginStatus | None = None, + search: str | None = None, + author_id: str | None = None, + sort_by: str = "created_at", ) -> list[PluginMarketItem]: """列出插件""" - query = "SELECT * FROM plugin_market WHERE 1=1" - params = [] + query = "SELECT * FROM plugin_market WHERE 1 = 1" + params = [] if category: - query += " AND category = ?" + query += " AND category = ?" params.append(category.value) if status: - query += " AND status = ?" + query += " AND status = ?" params.append(status.value) if author_id: - query += " AND author_id = ?" + query += " AND author_id = ?" params.append(author_id) if search: query += " AND (name LIKE ? OR description LIKE ? OR tags LIKE ?)" params.extend([f"%{search}%", f"%{search}%", f"%{search}%"]) - sort_mapping = { + sort_mapping = { "created_at": "created_at DESC", "rating": "rating DESC", "install_count": "install_count DESC", @@ -1072,21 +1072,21 @@ class DeveloperEcosystemManager: query += f" ORDER BY {sort_mapping.get(sort_by, 'created_at DESC')}" with self._get_db() as conn: - rows = conn.execute(query, params).fetchall() + rows = conn.execute(query, params).fetchall() return [self._row_to_plugin(row) for row in rows] def review_plugin( - self, plugin_id: str, reviewed_by: str, status: PluginStatus, notes: str = "" + self, plugin_id: str, reviewed_by: str, status: PluginStatus, notes: str = "" ) -> PluginMarketItem | None: """审核插件""" - now = datetime.now().isoformat() + now = datetime.now().isoformat() with self._get_db() as conn: conn.execute( """ UPDATE plugin_market - SET status = ?, reviewed_by = ?, reviewed_at = ?, review_notes = ?, updated_at = ? - WHERE id = ? + SET status = ?, reviewed_by = ?, reviewed_at = ?, review_notes = ?, updated_at = ? + WHERE id = ? """, (status.value, reviewed_by, now, notes, now, plugin_id), ) @@ -1096,14 +1096,14 @@ class DeveloperEcosystemManager: def publish_plugin(self, plugin_id: str) -> PluginMarketItem | None: """发布插件""" - now = datetime.now().isoformat() + now = datetime.now().isoformat() with self._get_db() as conn: conn.execute( """ UPDATE plugin_market - SET status = ?, published_at = ?, updated_at = ? - WHERE id = ? + SET status = ?, published_at = ?, updated_at = ? + WHERE id = ? """, (PluginStatus.PUBLISHED.value, now, now, plugin_id), ) @@ -1111,26 +1111,26 @@ class DeveloperEcosystemManager: return self.get_plugin(plugin_id) - def increment_plugin_install(self, plugin_id: str, active: bool = True) -> None: + def increment_plugin_install(self, plugin_id: str, active: bool = True) -> None: """增加插件安装计数""" with self._get_db() as conn: conn.execute( """ UPDATE plugin_market - SET install_count = install_count + 1 - WHERE id = ? + SET install_count = install_count + 1 + WHERE id = ? """, - (plugin_id,), + (plugin_id, ), ) if active: conn.execute( """ UPDATE plugin_market - SET active_install_count = active_install_count + 1 - WHERE id = ? + SET active_install_count = active_install_count + 1 + WHERE id = ? """, - (plugin_id,), + (plugin_id, ), ) conn.commit() @@ -1141,23 +1141,23 @@ class DeveloperEcosystemManager: user_name: str, rating: int, comment: str, - is_verified_purchase: bool = False, + is_verified_purchase: bool = False, ) -> PluginReview: """添加插件评价""" - review_id = f"pr_{uuid.uuid4().hex[:16]}" - now = datetime.now().isoformat() + review_id = f"pr_{uuid.uuid4().hex[:16]}" + now = datetime.now().isoformat() - review = PluginReview( - id=review_id, - plugin_id=plugin_id, - user_id=user_id, - user_name=user_name, - rating=rating, - comment=comment, - is_verified_purchase=is_verified_purchase, - helpful_count=0, - created_at=now, - updated_at=now, + review = PluginReview( + id = review_id, + plugin_id = plugin_id, + user_id = user_id, + user_name = user_name, + rating = rating, + comment = comment, + is_verified_purchase = is_verified_purchase, + helpful_count = 0, + created_at = now, + updated_at = now, ) with self._get_db() as conn: @@ -1189,21 +1189,21 @@ class DeveloperEcosystemManager: def _update_plugin_rating(self, conn, plugin_id: str) -> None: """更新插件评分""" - row = conn.execute( + row = conn.execute( """ SELECT AVG(rating) as avg_rating, COUNT(*) as count FROM plugin_reviews - WHERE plugin_id = ? + WHERE plugin_id = ? """, - (plugin_id,), + (plugin_id, ), ).fetchone() if row: conn.execute( """ UPDATE plugin_market - SET rating = ?, rating_count = ?, review_count = ? - WHERE id = ? + SET rating = ?, rating_count = ?, review_count = ? + WHERE id = ? """, ( round(row["avg_rating"], 2) if row["avg_rating"] else 0, @@ -1213,12 +1213,12 @@ class DeveloperEcosystemManager: ), ) - def get_plugin_reviews(self, plugin_id: str, limit: int = 50) -> list[PluginReview]: + def get_plugin_reviews(self, plugin_id: str, limit: int = 50) -> list[PluginReview]: """获取插件评价""" with self._get_db() as conn: - rows = conn.execute( + rows = conn.execute( """SELECT * FROM plugin_reviews - WHERE plugin_id = ? + WHERE plugin_id = ? ORDER BY created_at DESC LIMIT ?""", (plugin_id, limit), @@ -1239,25 +1239,25 @@ class DeveloperEcosystemManager: transaction_id: str, ) -> DeveloperRevenue: """记录开发者收益""" - revenue_id = f"rev_{uuid.uuid4().hex[:16]}" - now = datetime.now().isoformat() + revenue_id = f"rev_{uuid.uuid4().hex[:16]}" + now = datetime.now().isoformat() - platform_fee = sale_amount * self.platform_fee_rate - developer_earnings = sale_amount - platform_fee + platform_fee = sale_amount * self.platform_fee_rate + developer_earnings = sale_amount - platform_fee - revenue = DeveloperRevenue( - id=revenue_id, - developer_id=developer_id, - item_type=item_type, - item_id=item_id, - item_name=item_name, - sale_amount=sale_amount, - platform_fee=platform_fee, - developer_earnings=developer_earnings, - currency=currency, - buyer_id=buyer_id, - transaction_id=transaction_id, - created_at=now, + revenue = DeveloperRevenue( + id = revenue_id, + developer_id = developer_id, + item_type = item_type, + item_id = item_id, + item_name = item_name, + sale_amount = sale_amount, + platform_fee = platform_fee, + developer_earnings = developer_earnings, + currency = currency, + buyer_id = buyer_id, + transaction_id = transaction_id, + created_at = now, ) with self._get_db() as conn: @@ -1288,8 +1288,8 @@ class DeveloperEcosystemManager: conn.execute( """ UPDATE developer_profiles - SET total_sales = total_sales + ? - WHERE id = ? + SET total_sales = total_sales + ? + WHERE id = ? """, (sale_amount, developer_id), ) @@ -1301,12 +1301,12 @@ class DeveloperEcosystemManager: def get_developer_revenues( self, developer_id: str, - start_date: datetime | None = None, - end_date: datetime | None = None, + start_date: datetime | None = None, + end_date: datetime | None = None, ) -> list[DeveloperRevenue]: """获取开发者收益记录""" - query = "SELECT * FROM developer_revenues WHERE developer_id = ?" - params = [developer_id] + query = "SELECT * FROM developer_revenues WHERE developer_id = ?" + params = [developer_id] if start_date: query += " AND created_at >= ?" @@ -1318,13 +1318,13 @@ class DeveloperEcosystemManager: query += " ORDER BY created_at DESC" with self._get_db() as conn: - rows = conn.execute(query, params).fetchall() + rows = conn.execute(query, params).fetchall() return [self._row_to_developer_revenue(row) for row in rows] def get_developer_revenue_summary(self, developer_id: str) -> dict: """获取开发者收益汇总""" with self._get_db() as conn: - row = conn.execute( + row = conn.execute( """ SELECT SUM(sale_amount) as total_sales, @@ -1332,9 +1332,9 @@ class DeveloperEcosystemManager: SUM(developer_earnings) as total_earnings, COUNT(*) as transaction_count FROM developer_revenues - WHERE developer_id = ? + WHERE developer_id = ? """, - (developer_id,), + (developer_id, ), ).fetchone() return { @@ -1352,34 +1352,34 @@ class DeveloperEcosystemManager: user_id: str, display_name: str, email: str, - bio: str | None = None, - website: str | None = None, - github_url: str | None = None, - avatar_url: str | None = None, + bio: str | None = None, + website: str | None = None, + github_url: str | None = None, + avatar_url: str | None = None, ) -> DeveloperProfile: """创建开发者档案""" - profile_id = f"dev_{uuid.uuid4().hex[:16]}" - now = datetime.now().isoformat() + profile_id = f"dev_{uuid.uuid4().hex[:16]}" + now = datetime.now().isoformat() - profile = DeveloperProfile( - id=profile_id, - user_id=user_id, - display_name=display_name, - email=email, - bio=bio, - website=website, - github_url=github_url, - avatar_url=avatar_url, - status=DeveloperStatus.UNVERIFIED, - verification_documents={}, - total_sales=0.0, - total_downloads=0, - plugin_count=0, - template_count=0, - rating_average=0.0, - created_at=now, - updated_at=now, - verified_at=None, + profile = DeveloperProfile( + id = profile_id, + user_id = user_id, + display_name = display_name, + email = email, + bio = bio, + website = website, + github_url = github_url, + avatar_url = avatar_url, + status = DeveloperStatus.UNVERIFIED, + verification_documents = {}, + total_sales = 0.0, + total_downloads = 0, + plugin_count = 0, + template_count = 0, + rating_average = 0.0, + created_at = now, + updated_at = now, + verified_at = None, ) with self._get_db() as conn: @@ -1419,8 +1419,8 @@ class DeveloperEcosystemManager: def get_developer_profile(self, developer_id: str) -> DeveloperProfile | None: """获取开发者档案""" with self._get_db() as conn: - row = conn.execute( - "SELECT * FROM developer_profiles WHERE id = ?", (developer_id,) + row = conn.execute( + "SELECT * FROM developer_profiles WHERE id = ?", (developer_id, ) ).fetchone() if row: @@ -1430,8 +1430,8 @@ class DeveloperEcosystemManager: def get_developer_profile_by_user(self, user_id: str) -> DeveloperProfile | None: """通过用户 ID 获取开发者档案""" with self._get_db() as conn: - row = conn.execute( - "SELECT * FROM developer_profiles WHERE user_id = ?", (user_id,) + row = conn.execute( + "SELECT * FROM developer_profiles WHERE user_id = ?", (user_id, ) ).fetchone() if row: @@ -1442,14 +1442,14 @@ class DeveloperEcosystemManager: self, developer_id: str, status: DeveloperStatus ) -> DeveloperProfile | None: """验证开发者""" - now = datetime.now().isoformat() + now = datetime.now().isoformat() with self._get_db() as conn: conn.execute( """ UPDATE developer_profiles - SET status = ?, verified_at = ?, updated_at = ? - WHERE id = ? + SET status = ?, verified_at = ?, updated_at = ? + WHERE id = ? """, ( status.value, @@ -1468,22 +1468,22 @@ class DeveloperEcosystemManager: """更新开发者统计信息""" with self._get_db() as conn: # 统计插件数量 - plugin_row = conn.execute( - "SELECT COUNT(*) as count FROM plugin_market WHERE author_id = ?", (developer_id,) + plugin_row = conn.execute( + "SELECT COUNT(*) as count FROM plugin_market WHERE author_id = ?", (developer_id, ) ).fetchone() # 统计模板数量 - template_row = conn.execute( - "SELECT COUNT(*) as count FROM template_market WHERE author_id = ?", (developer_id,) + template_row = conn.execute( + "SELECT COUNT(*) as count FROM template_market WHERE author_id = ?", (developer_id, ) ).fetchone() # 统计总下载量 - download_row = conn.execute( + download_row = conn.execute( """ SELECT SUM(install_count) as total FROM ( - SELECT install_count FROM plugin_market WHERE author_id = ? + SELECT install_count FROM plugin_market WHERE author_id = ? UNION ALL - SELECT install_count FROM template_market WHERE author_id = ? + SELECT install_count FROM template_market WHERE author_id = ? ) """, (developer_id, developer_id), @@ -1492,8 +1492,8 @@ class DeveloperEcosystemManager: conn.execute( """ UPDATE developer_profiles - SET plugin_count = ?, template_count = ?, total_downloads = ?, updated_at = ? - WHERE id = ? + SET plugin_count = ?, template_count = ?, total_downloads = ?, updated_at = ? + WHERE id = ? """, ( plugin_row["count"], @@ -1518,31 +1518,31 @@ class DeveloperEcosystemManager: tags: list[str], author_id: str, author_name: str, - sdk_id: str | None = None, - api_endpoints: list[str] = None, + sdk_id: str | None = None, + api_endpoints: list[str] = None, ) -> CodeExample: """创建代码示例""" - example_id = f"ex_{uuid.uuid4().hex[:16]}" - now = datetime.now().isoformat() + example_id = f"ex_{uuid.uuid4().hex[:16]}" + now = datetime.now().isoformat() - example = CodeExample( - id=example_id, - title=title, - description=description, - language=language, - category=category, - code=code, - explanation=explanation, - tags=tags, - author_id=author_id, - author_name=author_name, - sdk_id=sdk_id, - api_endpoints=api_endpoints or [], - view_count=0, - copy_count=0, - rating=0.0, - created_at=now, - updated_at=now, + example = CodeExample( + id = example_id, + title = title, + description = description, + language = language, + category = category, + code = code, + explanation = explanation, + tags = tags, + author_id = author_id, + author_name = author_name, + sdk_id = sdk_id, + api_endpoints = api_endpoints or [], + view_count = 0, + copy_count = 0, + rating = 0.0, + created_at = now, + updated_at = now, ) with self._get_db() as conn: @@ -1581,7 +1581,7 @@ class DeveloperEcosystemManager: def get_code_example(self, example_id: str) -> CodeExample | None: """获取代码示例""" with self._get_db() as conn: - row = conn.execute("SELECT * FROM code_examples WHERE id = ?", (example_id,)).fetchone() + row = conn.execute("SELECT * FROM code_examples WHERE id = ?", (example_id, )).fetchone() if row: return self._row_to_code_example(row) @@ -1589,23 +1589,23 @@ class DeveloperEcosystemManager: def list_code_examples( self, - language: str | None = None, - category: str | None = None, - sdk_id: str | None = None, - search: str | None = None, + language: str | None = None, + category: str | None = None, + sdk_id: str | None = None, + search: str | None = None, ) -> list[CodeExample]: """列出代码示例""" - query = "SELECT * FROM code_examples WHERE 1=1" - params = [] + query = "SELECT * FROM code_examples WHERE 1 = 1" + params = [] if language: - query += " AND language = ?" + query += " AND language = ?" params.append(language) if category: - query += " AND category = ?" + query += " AND category = ?" params.append(category) if sdk_id: - query += " AND sdk_id = ?" + query += " AND sdk_id = ?" params.append(sdk_id) if search: query += " AND (title LIKE ? OR description LIKE ? OR tags LIKE ?)" @@ -1614,7 +1614,7 @@ class DeveloperEcosystemManager: query += " ORDER BY created_at DESC" with self._get_db() as conn: - rows = conn.execute(query, params).fetchall() + rows = conn.execute(query, params).fetchall() return [self._row_to_code_example(row) for row in rows] def increment_example_view(self, example_id: str) -> None: @@ -1623,10 +1623,10 @@ class DeveloperEcosystemManager: conn.execute( """ UPDATE code_examples - SET view_count = view_count + 1 - WHERE id = ? + SET view_count = view_count + 1 + WHERE id = ? """, - (example_id,), + (example_id, ), ) conn.commit() @@ -1636,10 +1636,10 @@ class DeveloperEcosystemManager: conn.execute( """ UPDATE code_examples - SET copy_count = copy_count + 1 - WHERE id = ? + SET copy_count = copy_count + 1 + WHERE id = ? """, - (example_id,), + (example_id, ), ) conn.commit() @@ -1655,18 +1655,18 @@ class DeveloperEcosystemManager: generated_by: str, ) -> APIDocumentation: """创建 API 文档""" - doc_id = f"api_{uuid.uuid4().hex[:16]}" - now = datetime.now().isoformat() + doc_id = f"api_{uuid.uuid4().hex[:16]}" + now = datetime.now().isoformat() - doc = APIDocumentation( - id=doc_id, - version=version, - openapi_spec=openapi_spec, - markdown_content=markdown_content, - html_content=html_content, - changelog=changelog, - generated_at=now, - generated_by=generated_by, + doc = APIDocumentation( + id = doc_id, + version = version, + openapi_spec = openapi_spec, + markdown_content = markdown_content, + html_content = html_content, + changelog = changelog, + generated_at = now, + generated_by = generated_by, ) with self._get_db() as conn: @@ -1695,7 +1695,7 @@ class DeveloperEcosystemManager: def get_api_documentation(self, doc_id: str) -> APIDocumentation | None: """获取 API 文档""" with self._get_db() as conn: - row = conn.execute("SELECT * FROM api_documentation WHERE id = ?", (doc_id,)).fetchone() + row = conn.execute("SELECT * FROM api_documentation WHERE id = ?", (doc_id, )).fetchone() if row: return self._row_to_api_documentation(row) @@ -1704,7 +1704,7 @@ class DeveloperEcosystemManager: def get_latest_api_documentation(self) -> APIDocumentation | None: """获取最新 API 文档""" with self._get_db() as conn: - row = conn.execute( + row = conn.execute( "SELECT * FROM api_documentation ORDER BY generated_at DESC LIMIT 1" ).fetchone() @@ -1718,42 +1718,42 @@ class DeveloperEcosystemManager: self, name: str, description: str, - theme: str = "default", - custom_css: str | None = None, - custom_js: str | None = None, - logo_url: str | None = None, - favicon_url: str | None = None, - primary_color: str = "#1890ff", - secondary_color: str = "#52c41a", - support_email: str = "support@insightflow.io", - support_url: str | None = None, - github_url: str | None = None, - discord_url: str | None = None, - api_base_url: str = "https://api.insightflow.io", + theme: str = "default", + custom_css: str | None = None, + custom_js: str | None = None, + logo_url: str | None = None, + favicon_url: str | None = None, + primary_color: str = "#1890ff", + secondary_color: str = "#52c41a", + support_email: str = "support@insightflow.io", + support_url: str | None = None, + github_url: str | None = None, + discord_url: str | None = None, + api_base_url: str = "https://api.insightflow.io", ) -> DeveloperPortalConfig: """创建开发者门户配置""" - config_id = f"portal_{uuid.uuid4().hex[:16]}" - now = datetime.now().isoformat() + config_id = f"portal_{uuid.uuid4().hex[:16]}" + now = datetime.now().isoformat() - config = DeveloperPortalConfig( - id=config_id, - name=name, - description=description, - theme=theme, - custom_css=custom_css, - custom_js=custom_js, - logo_url=logo_url, - favicon_url=favicon_url, - primary_color=primary_color, - secondary_color=secondary_color, - support_email=support_email, - support_url=support_url, - github_url=github_url, - discord_url=discord_url, - api_base_url=api_base_url, - is_active=True, - created_at=now, - updated_at=now, + config = DeveloperPortalConfig( + id = config_id, + name = name, + description = description, + theme = theme, + custom_css = custom_css, + custom_js = custom_js, + logo_url = logo_url, + favicon_url = favicon_url, + primary_color = primary_color, + secondary_color = secondary_color, + support_email = support_email, + support_url = support_url, + github_url = github_url, + discord_url = discord_url, + api_base_url = api_base_url, + is_active = True, + created_at = now, + updated_at = now, ) with self._get_db() as conn: @@ -1793,8 +1793,8 @@ class DeveloperEcosystemManager: def get_portal_config(self, config_id: str) -> DeveloperPortalConfig | None: """获取开发者门户配置""" with self._get_db() as conn: - row = conn.execute( - "SELECT * FROM developer_portal_configs WHERE id = ?", (config_id,) + row = conn.execute( + "SELECT * FROM developer_portal_configs WHERE id = ?", (config_id, ) ).fetchone() if row: @@ -1804,8 +1804,8 @@ class DeveloperEcosystemManager: def get_active_portal_config(self) -> DeveloperPortalConfig | None: """获取活跃的开发者门户配置""" with self._get_db() as conn: - row = conn.execute( - "SELECT * FROM developer_portal_configs WHERE is_active = 1 LIMIT 1" + row = conn.execute( + "SELECT * FROM developer_portal_configs WHERE is_active = 1 LIMIT 1" ).fetchone() if row: @@ -1817,249 +1817,249 @@ class DeveloperEcosystemManager: def _row_to_sdk_release(self, row) -> SDKRelease: """将数据库行转换为 SDKRelease""" return SDKRelease( - id=row["id"], - name=row["name"], - language=SDKLanguage(row["language"]), - version=row["version"], - description=row["description"], - changelog=row["changelog"], - download_url=row["download_url"], - documentation_url=row["documentation_url"], - repository_url=row["repository_url"], - package_name=row["package_name"], - status=SDKStatus(row["status"]), - min_platform_version=row["min_platform_version"], - dependencies=json.loads(row["dependencies"]), - file_size=row["file_size"], - checksum=row["checksum"], - download_count=row["download_count"], - created_at=row["created_at"], - updated_at=row["updated_at"], - published_at=row["published_at"], - created_by=row["created_by"], + id = row["id"], + name = row["name"], + language = SDKLanguage(row["language"]), + version = row["version"], + description = row["description"], + changelog = row["changelog"], + download_url = row["download_url"], + documentation_url = row["documentation_url"], + repository_url = row["repository_url"], + package_name = row["package_name"], + status = SDKStatus(row["status"]), + min_platform_version = row["min_platform_version"], + dependencies = json.loads(row["dependencies"]), + file_size = row["file_size"], + checksum = row["checksum"], + download_count = row["download_count"], + created_at = row["created_at"], + updated_at = row["updated_at"], + published_at = row["published_at"], + created_by = row["created_by"], ) def _row_to_sdk_version(self, row) -> SDKVersion: """将数据库行转换为 SDKVersion""" return SDKVersion( - id=row["id"], - sdk_id=row["sdk_id"], - version=row["version"], - is_latest=bool(row["is_latest"]), - is_lts=bool(row["is_lts"]), - release_notes=row["release_notes"], - download_url=row["download_url"], - checksum=row["checksum"], - file_size=row["file_size"], - download_count=row["download_count"], - created_at=row["created_at"], + id = row["id"], + sdk_id = row["sdk_id"], + version = row["version"], + is_latest = bool(row["is_latest"]), + is_lts = bool(row["is_lts"]), + release_notes = row["release_notes"], + download_url = row["download_url"], + checksum = row["checksum"], + file_size = row["file_size"], + download_count = row["download_count"], + created_at = row["created_at"], ) def _row_to_template(self, row) -> TemplateMarketItem: """将数据库行转换为 TemplateMarketItem""" return TemplateMarketItem( - id=row["id"], - name=row["name"], - description=row["description"], - category=TemplateCategory(row["category"]), - subcategory=row["subcategory"], - tags=json.loads(row["tags"]), - author_id=row["author_id"], - author_name=row["author_name"], - status=TemplateStatus(row["status"]), - price=row["price"], - currency=row["currency"], - preview_image_url=row["preview_image_url"], - demo_url=row["demo_url"], - documentation_url=row["documentation_url"], - download_url=row["download_url"], - install_count=row["install_count"], - rating=row["rating"], - rating_count=row["rating_count"], - review_count=row["review_count"], - version=row["version"], - min_platform_version=row["min_platform_version"], - file_size=row["file_size"], - checksum=row["checksum"], - created_at=row["created_at"], - updated_at=row["updated_at"], - published_at=row["published_at"], + id = row["id"], + name = row["name"], + description = row["description"], + category = TemplateCategory(row["category"]), + subcategory = row["subcategory"], + tags = json.loads(row["tags"]), + author_id = row["author_id"], + author_name = row["author_name"], + status = TemplateStatus(row["status"]), + price = row["price"], + currency = row["currency"], + preview_image_url = row["preview_image_url"], + demo_url = row["demo_url"], + documentation_url = row["documentation_url"], + download_url = row["download_url"], + install_count = row["install_count"], + rating = row["rating"], + rating_count = row["rating_count"], + review_count = row["review_count"], + version = row["version"], + min_platform_version = row["min_platform_version"], + file_size = row["file_size"], + checksum = row["checksum"], + created_at = row["created_at"], + updated_at = row["updated_at"], + published_at = row["published_at"], ) def _row_to_template_review(self, row) -> TemplateReview: """将数据库行转换为 TemplateReview""" return TemplateReview( - id=row["id"], - template_id=row["template_id"], - user_id=row["user_id"], - user_name=row["user_name"], - rating=row["rating"], - comment=row["comment"], - is_verified_purchase=bool(row["is_verified_purchase"]), - helpful_count=row["helpful_count"], - created_at=row["created_at"], - updated_at=row["updated_at"], + id = row["id"], + template_id = row["template_id"], + user_id = row["user_id"], + user_name = row["user_name"], + rating = row["rating"], + comment = row["comment"], + is_verified_purchase = bool(row["is_verified_purchase"]), + helpful_count = row["helpful_count"], + created_at = row["created_at"], + updated_at = row["updated_at"], ) def _row_to_plugin(self, row) -> PluginMarketItem: """将数据库行转换为 PluginMarketItem""" return PluginMarketItem( - id=row["id"], - name=row["name"], - description=row["description"], - category=PluginCategory(row["category"]), - tags=json.loads(row["tags"]), - author_id=row["author_id"], - author_name=row["author_name"], - status=PluginStatus(row["status"]), - price=row["price"], - currency=row["currency"], - pricing_model=row["pricing_model"], - preview_image_url=row["preview_image_url"], - demo_url=row["demo_url"], - documentation_url=row["documentation_url"], - repository_url=row["repository_url"], - download_url=row["download_url"], - webhook_url=row["webhook_url"], - permissions=json.loads(row["permissions"]), - install_count=row["install_count"], - active_install_count=row["active_install_count"], - rating=row["rating"], - rating_count=row["rating_count"], - review_count=row["review_count"], - version=row["version"], - min_platform_version=row["min_platform_version"], - file_size=row["file_size"], - checksum=row["checksum"], - created_at=row["created_at"], - updated_at=row["updated_at"], - published_at=row["published_at"], - reviewed_by=row["reviewed_by"], - reviewed_at=row["reviewed_at"], - review_notes=row["review_notes"], + id = row["id"], + name = row["name"], + description = row["description"], + category = PluginCategory(row["category"]), + tags = json.loads(row["tags"]), + author_id = row["author_id"], + author_name = row["author_name"], + status = PluginStatus(row["status"]), + price = row["price"], + currency = row["currency"], + pricing_model = row["pricing_model"], + preview_image_url = row["preview_image_url"], + demo_url = row["demo_url"], + documentation_url = row["documentation_url"], + repository_url = row["repository_url"], + download_url = row["download_url"], + webhook_url = row["webhook_url"], + permissions = json.loads(row["permissions"]), + install_count = row["install_count"], + active_install_count = row["active_install_count"], + rating = row["rating"], + rating_count = row["rating_count"], + review_count = row["review_count"], + version = row["version"], + min_platform_version = row["min_platform_version"], + file_size = row["file_size"], + checksum = row["checksum"], + created_at = row["created_at"], + updated_at = row["updated_at"], + published_at = row["published_at"], + reviewed_by = row["reviewed_by"], + reviewed_at = row["reviewed_at"], + review_notes = row["review_notes"], ) def _row_to_plugin_review(self, row) -> PluginReview: """将数据库行转换为 PluginReview""" return PluginReview( - id=row["id"], - plugin_id=row["plugin_id"], - user_id=row["user_id"], - user_name=row["user_name"], - rating=row["rating"], - comment=row["comment"], - is_verified_purchase=bool(row["is_verified_purchase"]), - helpful_count=row["helpful_count"], - created_at=row["created_at"], - updated_at=row["updated_at"], + id = row["id"], + plugin_id = row["plugin_id"], + user_id = row["user_id"], + user_name = row["user_name"], + rating = row["rating"], + comment = row["comment"], + is_verified_purchase = bool(row["is_verified_purchase"]), + helpful_count = row["helpful_count"], + created_at = row["created_at"], + updated_at = row["updated_at"], ) def _row_to_developer_profile(self, row) -> DeveloperProfile: """将数据库行转换为 DeveloperProfile""" return DeveloperProfile( - id=row["id"], - user_id=row["user_id"], - display_name=row["display_name"], - email=row["email"], - bio=row["bio"], - website=row["website"], - github_url=row["github_url"], - avatar_url=row["avatar_url"], - status=DeveloperStatus(row["status"]), - verification_documents=json.loads(row["verification_documents"]), - total_sales=row["total_sales"], - total_downloads=row["total_downloads"], - plugin_count=row["plugin_count"], - template_count=row["template_count"], - rating_average=row["rating_average"], - created_at=row["created_at"], - updated_at=row["updated_at"], - verified_at=row["verified_at"], + id = row["id"], + user_id = row["user_id"], + display_name = row["display_name"], + email = row["email"], + bio = row["bio"], + website = row["website"], + github_url = row["github_url"], + avatar_url = row["avatar_url"], + status = DeveloperStatus(row["status"]), + verification_documents = json.loads(row["verification_documents"]), + total_sales = row["total_sales"], + total_downloads = row["total_downloads"], + plugin_count = row["plugin_count"], + template_count = row["template_count"], + rating_average = row["rating_average"], + created_at = row["created_at"], + updated_at = row["updated_at"], + verified_at = row["verified_at"], ) def _row_to_developer_revenue(self, row) -> DeveloperRevenue: """将数据库行转换为 DeveloperRevenue""" return DeveloperRevenue( - id=row["id"], - developer_id=row["developer_id"], - item_type=row["item_type"], - item_id=row["item_id"], - item_name=row["item_name"], - sale_amount=row["sale_amount"], - platform_fee=row["platform_fee"], - developer_earnings=row["developer_earnings"], - currency=row["currency"], - buyer_id=row["buyer_id"], - transaction_id=row["transaction_id"], - created_at=row["created_at"], + id = row["id"], + developer_id = row["developer_id"], + item_type = row["item_type"], + item_id = row["item_id"], + item_name = row["item_name"], + sale_amount = row["sale_amount"], + platform_fee = row["platform_fee"], + developer_earnings = row["developer_earnings"], + currency = row["currency"], + buyer_id = row["buyer_id"], + transaction_id = row["transaction_id"], + created_at = row["created_at"], ) def _row_to_code_example(self, row) -> CodeExample: """将数据库行转换为 CodeExample""" return CodeExample( - id=row["id"], - title=row["title"], - description=row["description"], - language=row["language"], - category=row["category"], - code=row["code"], - explanation=row["explanation"], - tags=json.loads(row["tags"]), - author_id=row["author_id"], - author_name=row["author_name"], - sdk_id=row["sdk_id"], - api_endpoints=json.loads(row["api_endpoints"]), - view_count=row["view_count"], - copy_count=row["copy_count"], - rating=row["rating"], - created_at=row["created_at"], - updated_at=row["updated_at"], + id = row["id"], + title = row["title"], + description = row["description"], + language = row["language"], + category = row["category"], + code = row["code"], + explanation = row["explanation"], + tags = json.loads(row["tags"]), + author_id = row["author_id"], + author_name = row["author_name"], + sdk_id = row["sdk_id"], + api_endpoints = json.loads(row["api_endpoints"]), + view_count = row["view_count"], + copy_count = row["copy_count"], + rating = row["rating"], + created_at = row["created_at"], + updated_at = row["updated_at"], ) def _row_to_api_documentation(self, row) -> APIDocumentation: """将数据库行转换为 APIDocumentation""" return APIDocumentation( - id=row["id"], - version=row["version"], - openapi_spec=row["openapi_spec"], - markdown_content=row["markdown_content"], - html_content=row["html_content"], - changelog=row["changelog"], - generated_at=row["generated_at"], - generated_by=row["generated_by"], + id = row["id"], + version = row["version"], + openapi_spec = row["openapi_spec"], + markdown_content = row["markdown_content"], + html_content = row["html_content"], + changelog = row["changelog"], + generated_at = row["generated_at"], + generated_by = row["generated_by"], ) def _row_to_portal_config(self, row) -> DeveloperPortalConfig: """将数据库行转换为 DeveloperPortalConfig""" return DeveloperPortalConfig( - id=row["id"], - name=row["name"], - description=row["description"], - theme=row["theme"], - custom_css=row["custom_css"], - custom_js=row["custom_js"], - logo_url=row["logo_url"], - favicon_url=row["favicon_url"], - primary_color=row["primary_color"], - secondary_color=row["secondary_color"], - support_email=row["support_email"], - support_url=row["support_url"], - github_url=row["github_url"], - discord_url=row["discord_url"], - api_base_url=row["api_base_url"], - is_active=bool(row["is_active"]), - created_at=row["created_at"], - updated_at=row["updated_at"], + id = row["id"], + name = row["name"], + description = row["description"], + theme = row["theme"], + custom_css = row["custom_css"], + custom_js = row["custom_js"], + logo_url = row["logo_url"], + favicon_url = row["favicon_url"], + primary_color = row["primary_color"], + secondary_color = row["secondary_color"], + support_email = row["support_email"], + support_url = row["support_url"], + github_url = row["github_url"], + discord_url = row["discord_url"], + api_base_url = row["api_base_url"], + is_active = bool(row["is_active"]), + created_at = row["created_at"], + updated_at = row["updated_at"], ) # Singleton instance -_developer_ecosystem_manager = None +_developer_ecosystem_manager = None def get_developer_ecosystem_manager() -> DeveloperEcosystemManager: """获取开发者生态系统管理器单例""" global _developer_ecosystem_manager if _developer_ecosystem_manager is None: - _developer_ecosystem_manager = DeveloperEcosystemManager() + _developer_ecosystem_manager = DeveloperEcosystemManager() return _developer_ecosystem_manager diff --git a/backend/document_processor.py b/backend/document_processor.py index 164a4ec..a73913b 100644 --- a/backend/document_processor.py +++ b/backend/document_processor.py @@ -11,8 +11,8 @@ import os class DocumentProcessor: """文档处理器 - 提取 PDF/DOCX 文本""" - def __init__(self): - self.supported_formats = { + def __init__(self) -> None: + self.supported_formats = { ".pdf": self._extract_pdf, ".docx": self._extract_docx, ".doc": self._extract_docx, @@ -31,18 +31,18 @@ class DocumentProcessor: Returns: {"text": "提取的文本内容", "format": "文件格式"} """ - ext = os.path.splitext(filename.lower())[1] + ext = os.path.splitext(filename.lower())[1] if ext not in self.supported_formats: raise ValueError( f"Unsupported file format: {ext}. Supported: {list(self.supported_formats.keys())}" ) - extractor = self.supported_formats[ext] - text = extractor(content) + extractor = self.supported_formats[ext] + text = extractor(content) # 清理文本 - text = self._clean_text(text) + text = self._clean_text(text) return {"text": text, "format": ext, "filename": filename} @@ -51,12 +51,12 @@ class DocumentProcessor: try: import PyPDF2 - pdf_file = io.BytesIO(content) - reader = PyPDF2.PdfReader(pdf_file) + pdf_file = io.BytesIO(content) + reader = PyPDF2.PdfReader(pdf_file) - text_parts = [] + text_parts = [] for page in reader.pages: - page_text = page.extract_text() + page_text = page.extract_text() if page_text: text_parts.append(page_text) @@ -66,10 +66,10 @@ class DocumentProcessor: try: import pdfplumber - text_parts = [] + text_parts = [] with pdfplumber.open(io.BytesIO(content)) as pdf: for page in pdf.pages: - page_text = page.extract_text() + page_text = page.extract_text() if page_text: text_parts.append(page_text) return "\n\n".join(text_parts) @@ -85,10 +85,10 @@ class DocumentProcessor: try: import docx - doc_file = io.BytesIO(content) - doc = docx.Document(doc_file) + doc_file = io.BytesIO(content) + doc = docx.Document(doc_file) - text_parts = [] + text_parts = [] for para in doc.paragraphs: if para.text.strip(): text_parts.append(para.text) @@ -96,7 +96,7 @@ class DocumentProcessor: # 提取表格中的文本 for table in doc.tables: for row in table.rows: - row_text = [] + row_text = [] for cell in row.cells: if cell.text.strip(): row_text.append(cell.text.strip()) @@ -114,7 +114,7 @@ class DocumentProcessor: def _extract_txt(self, content: bytes) -> str: """提取纯文本""" # 尝试多种编码 - encodings = ["utf-8", "gbk", "gb2312", "latin-1"] + encodings = ["utf-8", "gbk", "gb2312", "latin-1"] for encoding in encodings: try: @@ -123,7 +123,7 @@ class DocumentProcessor: continue # 如果都失败了,使用 latin-1 并忽略错误 - return content.decode("latin-1", errors="ignore") + return content.decode("latin-1", errors = "ignore") def _clean_text(self, text: str) -> str: """清理提取的文本""" @@ -131,29 +131,29 @@ class DocumentProcessor: return "" # 移除多余的空白字符 - lines = text.split("\n") - cleaned_lines = [] + lines = text.split("\n") + cleaned_lines = [] for line in lines: - line = line.strip() + line = line.strip() # 移除空行,但保留段落分隔 if line: cleaned_lines.append(line) # 合并行,保留段落结构 - text = "\n\n".join(cleaned_lines) + text = "\n\n".join(cleaned_lines) # 移除多余的空格 - text = " ".join(text.split()) + text = " ".join(text.split()) # 移除控制字符 - text = "".join(char for char in text if ord(char) >= 32 or char in "\n\r\t") + text = "".join(char for char in text if ord(char) >= 32 or char in "\n\r\t") return text.strip() def is_supported(self, filename: str) -> bool: """检查文件格式是否支持""" - ext = os.path.splitext(filename.lower())[1] + ext = os.path.splitext(filename.lower())[1] return ext in self.supported_formats @@ -165,7 +165,7 @@ class SimpleTextExtractor: def extract(self, content: bytes, filename: str) -> str: """尝试提取文本""" - encodings = ["utf-8", "gbk", "latin-1"] + encodings = ["utf-8", "gbk", "latin-1"] for encoding in encodings: try: @@ -173,15 +173,15 @@ class SimpleTextExtractor: except UnicodeDecodeError: continue - return content.decode("latin-1", errors="ignore") + return content.decode("latin-1", errors = "ignore") if __name__ == "__main__": # 测试 - processor = DocumentProcessor() + processor = DocumentProcessor() # 测试文本提取 - test_text = "Hello World\n\nThis is a test document.\n\nMultiple paragraphs." - result = processor.process(test_text.encode("utf-8"), "test.txt") + test_text = "Hello World\n\nThis is a test document.\n\nMultiple paragraphs." + result = processor.process(test_text.encode("utf-8"), "test.txt") print(f"Text extraction test: {len(result['text'])} chars") print(result["text"][:100]) diff --git a/backend/enterprise_manager.py b/backend/enterprise_manager.py index fab08f3..44b61e8 100644 --- a/backend/enterprise_manager.py +++ b/backend/enterprise_manager.py @@ -19,64 +19,64 @@ from datetime import datetime, timedelta from enum import StrEnum from typing import Any -logger = logging.getLogger(__name__) +logger = logging.getLogger(__name__) class SSOProvider(StrEnum): """SSO 提供商类型""" - WECHAT_WORK = "wechat_work" # 企业微信 - DINGTALK = "dingtalk" # 钉钉 - FEISHU = "feishu" # 飞书 - OKTA = "okta" # Okta - AZURE_AD = "azure_ad" # Azure AD - GOOGLE = "google" # Google Workspace - CUSTOM_SAML = "custom_saml" # 自定义 SAML + WECHAT_WORK = "wechat_work" # 企业微信 + DINGTALK = "dingtalk" # 钉钉 + FEISHU = "feishu" # 飞书 + OKTA = "okta" # Okta + AZURE_AD = "azure_ad" # Azure AD + GOOGLE = "google" # Google Workspace + CUSTOM_SAML = "custom_saml" # 自定义 SAML class SSOStatus(StrEnum): """SSO 配置状态""" - DISABLED = "disabled" # 未启用 - PENDING = "pending" # 待配置 - ACTIVE = "active" # 已启用 - ERROR = "error" # 配置错误 + DISABLED = "disabled" # 未启用 + PENDING = "pending" # 待配置 + ACTIVE = "active" # 已启用 + ERROR = "error" # 配置错误 class SCIMSyncStatus(StrEnum): """SCIM 同步状态""" - IDLE = "idle" # 空闲 - SYNCING = "syncing" # 同步中 - SUCCESS = "success" # 同步成功 - FAILED = "failed" # 同步失败 + IDLE = "idle" # 空闲 + SYNCING = "syncing" # 同步中 + SUCCESS = "success" # 同步成功 + FAILED = "failed" # 同步失败 class AuditLogExportFormat(StrEnum): """审计日志导出格式""" - JSON = "json" - CSV = "csv" - PDF = "pdf" - XLSX = "xlsx" + JSON = "json" + CSV = "csv" + PDF = "pdf" + XLSX = "xlsx" class DataRetentionAction(StrEnum): """数据保留策略动作""" - ARCHIVE = "archive" # 归档 - DELETE = "delete" # 删除 - ANONYMIZE = "anonymize" # 匿名化 + ARCHIVE = "archive" # 归档 + DELETE = "delete" # 删除 + ANONYMIZE = "anonymize" # 匿名化 class ComplianceStandard(StrEnum): """合规标准""" - SOC2 = "soc2" - ISO27001 = "iso27001" - GDPR = "gdpr" - HIPAA = "hipaa" - PCI_DSS = "pci_dss" + SOC2 = "soc2" + ISO27001 = "iso27001" + GDPR = "gdpr" + HIPAA = "hipaa" + PCI_DSS = "pci_dss" @dataclass @@ -264,7 +264,7 @@ class EnterpriseManager: """企业级功能管理器""" # 默认属性映射 - DEFAULT_ATTRIBUTE_MAPPING = { + DEFAULT_ATTRIBUTE_MAPPING = { SSOProvider.WECHAT_WORK: { "email": "email", "name": "name", @@ -293,7 +293,7 @@ class EnterpriseManager: } # 合规标准字段映射 - COMPLIANCE_FIELDS = { + COMPLIANCE_FIELDS = { ComplianceStandard.SOC2: [ "timestamp", "user_id", @@ -329,21 +329,21 @@ class EnterpriseManager: ], } - def __init__(self, db_path: str = "insightflow.db"): - self.db_path = db_path + def __init__(self, db_path: str = "insightflow.db") -> None: + self.db_path = db_path self._init_db() def _get_connection(self) -> sqlite3.Connection: """获取数据库连接""" - conn = sqlite3.connect(self.db_path) - conn.row_factory = sqlite3.Row + conn = sqlite3.connect(self.db_path) + conn.row_factory = sqlite3.Row return conn def _init_db(self) -> None: """初始化数据库表""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() # SSO 配置表 cursor.execute(""" @@ -582,61 +582,61 @@ class EnterpriseManager: self, tenant_id: str, provider: str, - entity_id: str | None = None, - sso_url: str | None = None, - slo_url: str | None = None, - certificate: str | None = None, - metadata_url: str | None = None, - metadata_xml: str | None = None, - client_id: str | None = None, - client_secret: str | None = None, - authorization_url: str | None = None, - token_url: str | None = None, - userinfo_url: str | None = None, - scopes: list[str] | None = None, - attribute_mapping: dict[str, str] | None = None, - auto_provision: bool = True, - default_role: str = "member", - domain_restriction: list[str] | None = None, + entity_id: str | None = None, + sso_url: str | None = None, + slo_url: str | None = None, + certificate: str | None = None, + metadata_url: str | None = None, + metadata_xml: str | None = None, + client_id: str | None = None, + client_secret: str | None = None, + authorization_url: str | None = None, + token_url: str | None = None, + userinfo_url: str | None = None, + scopes: list[str] | None = None, + attribute_mapping: dict[str, str] | None = None, + auto_provision: bool = True, + default_role: str = "member", + domain_restriction: list[str] | None = None, ) -> SSOConfig: """创建 SSO 配置""" - conn = self._get_connection() + conn = self._get_connection() try: - config_id = str(uuid.uuid4()) - now = datetime.now() + config_id = str(uuid.uuid4()) + now = datetime.now() # 使用默认属性映射 if attribute_mapping is None and provider in self.DEFAULT_ATTRIBUTE_MAPPING: - attribute_mapping = self.DEFAULT_ATTRIBUTE_MAPPING[SSOProvider(provider)] + attribute_mapping = self.DEFAULT_ATTRIBUTE_MAPPING[SSOProvider(provider)] - config = SSOConfig( - id=config_id, - tenant_id=tenant_id, - provider=provider, - status=SSOStatus.PENDING.value, - entity_id=entity_id, - sso_url=sso_url, - slo_url=slo_url, - certificate=certificate, - metadata_url=metadata_url, - metadata_xml=metadata_xml, - client_id=client_id, - client_secret=client_secret, - authorization_url=authorization_url, - token_url=token_url, - userinfo_url=userinfo_url, - scopes=scopes or ["openid", "email", "profile"], - attribute_mapping=attribute_mapping or {}, - auto_provision=auto_provision, - default_role=default_role, - domain_restriction=domain_restriction or [], - created_at=now, - updated_at=now, - last_tested_at=None, - last_error=None, + config = SSOConfig( + id = config_id, + tenant_id = tenant_id, + provider = provider, + status = SSOStatus.PENDING.value, + entity_id = entity_id, + sso_url = sso_url, + slo_url = slo_url, + certificate = certificate, + metadata_url = metadata_url, + metadata_xml = metadata_xml, + client_id = client_id, + client_secret = client_secret, + authorization_url = authorization_url, + token_url = token_url, + userinfo_url = userinfo_url, + scopes = scopes or ["openid", "email", "profile"], + attribute_mapping = attribute_mapping or {}, + auto_provision = auto_provision, + default_role = default_role, + domain_restriction = domain_restriction or [], + created_at = now, + updated_at = now, + last_tested_at = None, + last_error = None, ) - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ INSERT INTO sso_configs @@ -685,11 +685,11 @@ class EnterpriseManager: def get_sso_config(self, config_id: str) -> SSOConfig | None: """获取 SSO 配置""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() - cursor.execute("SELECT * FROM sso_configs WHERE id = ?", (config_id,)) - row = cursor.fetchone() + cursor = conn.cursor() + cursor.execute("SELECT * FROM sso_configs WHERE id = ?", (config_id, )) + row = cursor.fetchone() if row: return self._row_to_sso_config(row) @@ -699,18 +699,18 @@ class EnterpriseManager: conn.close() def get_tenant_sso_config( - self, tenant_id: str, provider: str | None = None + self, tenant_id: str, provider: str | None = None ) -> SSOConfig | None: """获取租户的 SSO 配置""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() if provider: cursor.execute( """ SELECT * FROM sso_configs - WHERE tenant_id = ? AND provider = ? + WHERE tenant_id = ? AND provider = ? ORDER BY created_at DESC LIMIT 1 """, (tenant_id, provider), @@ -719,13 +719,13 @@ class EnterpriseManager: cursor.execute( """ SELECT * FROM sso_configs - WHERE tenant_id = ? AND status = 'active' + WHERE tenant_id = ? AND status = 'active' ORDER BY created_at DESC LIMIT 1 """, - (tenant_id,), + (tenant_id, ), ) - row = cursor.fetchone() + row = cursor.fetchone() if row: return self._row_to_sso_config(row) @@ -736,16 +736,16 @@ class EnterpriseManager: def update_sso_config(self, config_id: str, **kwargs) -> SSOConfig | None: """更新 SSO 配置""" - conn = self._get_connection() + conn = self._get_connection() try: - config = self.get_sso_config(config_id) + config = self.get_sso_config(config_id) if not config: return None - updates = [] - params = [] + updates = [] + params = [] - allowed_fields = [ + allowed_fields = [ "entity_id", "sso_url", "slo_url", @@ -767,7 +767,7 @@ class EnterpriseManager: for key, value in kwargs.items(): if key in allowed_fields: - updates.append(f"{key} = ?") + updates.append(f"{key} = ?") if key in ["scopes", "attribute_mapping", "domain_restriction"]: params.append(json.dumps(value) if value else "[]") elif key == "auto_provision": @@ -778,15 +778,15 @@ class EnterpriseManager: if not updates: return config - updates.append("updated_at = ?") + updates.append("updated_at = ?") params.append(datetime.now()) params.append(config_id) - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( f""" UPDATE sso_configs SET {", ".join(updates)} - WHERE id = ? + WHERE id = ? """, params, ) @@ -799,10 +799,10 @@ class EnterpriseManager: def delete_sso_config(self, config_id: str) -> bool: """删除 SSO 配置""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() - cursor.execute("DELETE FROM sso_configs WHERE id = ?", (config_id,)) + cursor = conn.cursor() + cursor.execute("DELETE FROM sso_configs WHERE id = ?", (config_id, )) conn.commit() return cursor.rowcount > 0 finally: @@ -810,17 +810,17 @@ class EnterpriseManager: def list_sso_configs(self, tenant_id: str) -> list[SSOConfig]: """列出租户的所有 SSO 配置""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ - SELECT * FROM sso_configs WHERE tenant_id = ? + SELECT * FROM sso_configs WHERE tenant_id = ? ORDER BY created_at DESC """, - (tenant_id,), + (tenant_id, ), ) - rows = cursor.fetchall() + rows = cursor.fetchall() return [self._row_to_sso_config(row) for row in rows] @@ -829,70 +829,70 @@ class EnterpriseManager: def generate_saml_metadata(self, config_id: str, base_url: str) -> str: """生成 SAML Service Provider 元数据""" - config = self.get_sso_config(config_id) + config = self.get_sso_config(config_id) if not config: raise ValueError(f"SSO config {config_id} not found") # 生成 SP 实体 ID - sp_entity_id = f"{base_url}/api/v1/sso/saml/{config.tenant_id}" - acs_url = f"{base_url}/api/v1/sso/saml/{config.tenant_id}/acs" - slo_url = f"{base_url}/api/v1/sso/saml/{config.tenant_id}/slo" + sp_entity_id = f"{base_url}/api/v1/sso/saml/{config.tenant_id}" + acs_url = f"{base_url}/api/v1/sso/saml/{config.tenant_id}/acs" + slo_url = f"{base_url}/api/v1/sso/saml/{config.tenant_id}/slo" # 生成 X.509 证书(简化实现,实际应该生成真实的密钥对) - cert = config.certificate or self._generate_self_signed_cert() + cert = config.certificate or self._generate_self_signed_cert() - metadata = f""" - - - - + metadata = f""" + + + + {cert} - - + + - InsightFlow - InsightFlow - {base_url} + InsightFlow + InsightFlow + {base_url} """ return metadata def create_saml_auth_request( - self, tenant_id: str, config_id: str, relay_state: str | None = None + self, tenant_id: str, config_id: str, relay_state: str | None = None ) -> SAMLAuthRequest: """创建 SAML 认证请求""" - conn = self._get_connection() + conn = self._get_connection() try: - request_id = f"_{uuid.uuid4().hex}" - now = datetime.now() - expires = now + timedelta(minutes=10) + request_id = f"_{uuid.uuid4().hex}" + now = datetime.now() + expires = now + timedelta(minutes = 10) - auth_request = SAMLAuthRequest( - id=str(uuid.uuid4()), - tenant_id=tenant_id, - sso_config_id=config_id, - request_id=request_id, - relay_state=relay_state, - created_at=now, - expires_at=expires, - used=False, - used_at=None, + auth_request = SAMLAuthRequest( + id = str(uuid.uuid4()), + tenant_id = tenant_id, + sso_config_id = config_id, + request_id = request_id, + relay_state = relay_state, + created_at = now, + expires_at = expires, + used = False, + used_at = None, ) - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ INSERT INTO saml_auth_requests @@ -919,16 +919,16 @@ class EnterpriseManager: def get_saml_auth_request(self, request_id: str) -> SAMLAuthRequest | None: """获取 SAML 认证请求""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ - SELECT * FROM saml_auth_requests WHERE request_id = ? + SELECT * FROM saml_auth_requests WHERE request_id = ? """, - (request_id,), + (request_id, ), ) - row = cursor.fetchone() + row = cursor.fetchone() if row: return self._row_to_saml_request(row) @@ -942,27 +942,27 @@ class EnterpriseManager: # 这里应该实现实际的 SAML 响应解析 # 简化实现:假设响应已经验证并解析 - conn = self._get_connection() + conn = self._get_connection() try: # 解析 SAML Response(简化) # 实际应该使用 python-saml 或类似库 - attributes = self._parse_saml_response(saml_response) + attributes = self._parse_saml_response(saml_response) - auth_response = SAMLAuthResponse( - id=str(uuid.uuid4()), - request_id=request_id, - tenant_id="", # 从 request 获取 - user_id=None, - email=attributes.get("email"), - name=attributes.get("name"), - attributes=attributes, - session_index=attributes.get("session_index"), - processed=False, - processed_at=None, - created_at=datetime.now(), + auth_response = SAMLAuthResponse( + id = str(uuid.uuid4()), + request_id = request_id, + tenant_id = "", # 从 request 获取 + user_id = None, + email = attributes.get("email"), + name = attributes.get("name"), + attributes = attributes, + session_index = attributes.get("session_index"), + processed = False, + processed_at = None, + created_at = datetime.now(), ) - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ INSERT INTO saml_auth_responses @@ -1017,35 +1017,35 @@ class EnterpriseManager: provider: str, scim_base_url: str, scim_token: str, - sync_interval_minutes: int = 60, - attribute_mapping: dict[str, str] | None = None, - sync_rules: dict[str, Any] | None = None, + sync_interval_minutes: int = 60, + attribute_mapping: dict[str, str] | None = None, + sync_rules: dict[str, Any] | None = None, ) -> SCIMConfig: """创建 SCIM 配置""" - conn = self._get_connection() + conn = self._get_connection() try: - config_id = str(uuid.uuid4()) - now = datetime.now() + config_id = str(uuid.uuid4()) + now = datetime.now() - config = SCIMConfig( - id=config_id, - tenant_id=tenant_id, - provider=provider, - status="disabled", - scim_base_url=scim_base_url, - scim_token=scim_token, - sync_interval_minutes=sync_interval_minutes, - last_sync_at=None, - last_sync_status=None, - last_sync_error=None, - last_sync_users_count=0, - attribute_mapping=attribute_mapping or {}, - sync_rules=sync_rules or {}, - created_at=now, - updated_at=now, + config = SCIMConfig( + id = config_id, + tenant_id = tenant_id, + provider = provider, + status = "disabled", + scim_base_url = scim_base_url, + scim_token = scim_token, + sync_interval_minutes = sync_interval_minutes, + last_sync_at = None, + last_sync_status = None, + last_sync_error = None, + last_sync_users_count = 0, + attribute_mapping = attribute_mapping or {}, + sync_rules = sync_rules or {}, + created_at = now, + updated_at = now, ) - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ INSERT INTO scim_configs @@ -1081,11 +1081,11 @@ class EnterpriseManager: def get_scim_config(self, config_id: str) -> SCIMConfig | None: """获取 SCIM 配置""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() - cursor.execute("SELECT * FROM scim_configs WHERE id = ?", (config_id,)) - row = cursor.fetchone() + cursor = conn.cursor() + cursor.execute("SELECT * FROM scim_configs WHERE id = ?", (config_id, )) + row = cursor.fetchone() if row: return self._row_to_scim_config(row) @@ -1096,17 +1096,17 @@ class EnterpriseManager: def get_tenant_scim_config(self, tenant_id: str) -> SCIMConfig | None: """获取租户的 SCIM 配置""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ - SELECT * FROM scim_configs WHERE tenant_id = ? + SELECT * FROM scim_configs WHERE tenant_id = ? ORDER BY created_at DESC LIMIT 1 """, - (tenant_id,), + (tenant_id, ), ) - row = cursor.fetchone() + row = cursor.fetchone() if row: return self._row_to_scim_config(row) @@ -1117,16 +1117,16 @@ class EnterpriseManager: def update_scim_config(self, config_id: str, **kwargs) -> SCIMConfig | None: """更新 SCIM 配置""" - conn = self._get_connection() + conn = self._get_connection() try: - config = self.get_scim_config(config_id) + config = self.get_scim_config(config_id) if not config: return None - updates = [] - params = [] + updates = [] + params = [] - allowed_fields = [ + allowed_fields = [ "scim_base_url", "scim_token", "sync_interval_minutes", @@ -1137,7 +1137,7 @@ class EnterpriseManager: for key, value in kwargs.items(): if key in allowed_fields: - updates.append(f"{key} = ?") + updates.append(f"{key} = ?") if key in ["attribute_mapping", "sync_rules"]: params.append(json.dumps(value) if value else "{}") else: @@ -1146,15 +1146,15 @@ class EnterpriseManager: if not updates: return config - updates.append("updated_at = ?") + updates.append("updated_at = ?") params.append(datetime.now()) params.append(config_id) - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( f""" UPDATE scim_configs SET {", ".join(updates)} - WHERE id = ? + WHERE id = ? """, params, ) @@ -1167,21 +1167,21 @@ class EnterpriseManager: def sync_scim_users(self, config_id: str) -> dict[str, Any]: """执行 SCIM 用户同步""" - config = self.get_scim_config(config_id) + config = self.get_scim_config(config_id) if not config: raise ValueError(f"SCIM config {config_id} not found") - conn = self._get_connection() + conn = self._get_connection() try: - now = datetime.now() + now = datetime.now() # 更新同步状态 - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ UPDATE scim_configs - SET status = 'syncing', last_sync_at = ? - WHERE id = ? + SET status = 'syncing', last_sync_at = ? + WHERE id = ? """, (now, config_id), ) @@ -1190,9 +1190,9 @@ class EnterpriseManager: try: # 模拟从 SCIM 服务端获取用户 # 实际应该使用 HTTP 请求获取 - users = self._fetch_scim_users(config) + users = self._fetch_scim_users(config) - synced_count = 0 + synced_count = 0 for user_data in users: self._upsert_scim_user(conn, config.tenant_id, user_data) synced_count += 1 @@ -1201,9 +1201,9 @@ class EnterpriseManager: cursor.execute( """ UPDATE scim_configs - SET status = 'active', last_sync_status = 'success', - last_sync_error = NULL, last_sync_users_count = ? - WHERE id = ? + SET status = 'active', last_sync_status = 'success', + last_sync_error = NULL, last_sync_users_count = ? + WHERE id = ? """, (synced_count, config_id), ) @@ -1215,9 +1215,9 @@ class EnterpriseManager: cursor.execute( """ UPDATE scim_configs - SET status = 'error', last_sync_status = 'failed', - last_sync_error = ? - WHERE id = ? + SET status = 'error', last_sync_status = 'failed', + last_sync_error = ? + WHERE id = ? """, (str(e), config_id), ) @@ -1238,17 +1238,17 @@ class EnterpriseManager: self, conn: sqlite3.Connection, tenant_id: str, user_data: dict[str, Any] ) -> None: """插入或更新 SCIM 用户""" - cursor = conn.cursor() + cursor = conn.cursor() - external_id = user_data.get("id") - user_name = user_data.get("userName", "") - email = user_data.get("emails", [{}])[0].get("value", "") - display_name = user_data.get("displayName") - name = user_data.get("name", {}) - given_name = name.get("givenName") - family_name = name.get("familyName") - active = user_data.get("active", True) - groups = [g.get("value") for g in user_data.get("groups", [])] + external_id = user_data.get("id") + user_name = user_data.get("userName", "") + email = user_data.get("emails", [{}])[0].get("value", "") + display_name = user_data.get("displayName") + name = user_data.get("name", {}) + given_name = name.get("givenName") + family_name = name.get("familyName") + active = user_data.get("active", True) + groups = [g.get("value") for g in user_data.get("groups", [])] cursor.execute( """ @@ -1257,16 +1257,16 @@ class EnterpriseManager: given_name, family_name, active, groups, raw_data, synced_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) ON CONFLICT(tenant_id, external_id) DO UPDATE SET - user_name = excluded.user_name, - email = excluded.email, - display_name = excluded.display_name, - given_name = excluded.given_name, - family_name = excluded.family_name, - active = excluded.active, - groups = excluded.groups, - raw_data = excluded.raw_data, - synced_at = excluded.synced_at, - updated_at = CURRENT_TIMESTAMP + user_name = excluded.user_name, + email = excluded.email, + display_name = excluded.display_name, + given_name = excluded.given_name, + family_name = excluded.family_name, + active = excluded.active, + groups = excluded.groups, + raw_data = excluded.raw_data, + synced_at = excluded.synced_at, + updated_at = CURRENT_TIMESTAMP """, ( str(uuid.uuid4()), @@ -1284,22 +1284,22 @@ class EnterpriseManager: ), ) - def list_scim_users(self, tenant_id: str, active_only: bool = True) -> list[SCIMUser]: + def list_scim_users(self, tenant_id: str, active_only: bool = True) -> list[SCIMUser]: """列出 SCIM 用户""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() - query = "SELECT * FROM scim_users WHERE tenant_id = ?" - params = [tenant_id] + query = "SELECT * FROM scim_users WHERE tenant_id = ?" + params = [tenant_id] if active_only: - query += " AND active = 1" + query += " AND active = 1" query += " ORDER BY synced_at DESC" cursor.execute(query, params) - rows = cursor.fetchall() + rows = cursor.fetchall() return [self._row_to_scim_user(row) for row in rows] @@ -1315,41 +1315,41 @@ class EnterpriseManager: start_date: datetime, end_date: datetime, created_by: str, - filters: dict[str, Any] | None = None, - compliance_standard: str | None = None, + filters: dict[str, Any] | None = None, + compliance_standard: str | None = None, ) -> AuditLogExport: """创建审计日志导出任务""" - conn = self._get_connection() + conn = self._get_connection() try: - export_id = str(uuid.uuid4()) - now = datetime.now() + export_id = str(uuid.uuid4()) + now = datetime.now() # 默认7天后过期 - expires_at = now + timedelta(days=7) + expires_at = now + timedelta(days = 7) - export = AuditLogExport( - id=export_id, - tenant_id=tenant_id, - export_format=export_format, - start_date=start_date, - end_date=end_date, - filters=filters or {}, - compliance_standard=compliance_standard, - status="pending", - file_path=None, - file_size=None, - record_count=None, - checksum=None, - downloaded_by=None, - downloaded_at=None, - expires_at=expires_at, - created_by=created_by, - created_at=now, - completed_at=None, - error_message=None, + export = AuditLogExport( + id = export_id, + tenant_id = tenant_id, + export_format = export_format, + start_date = start_date, + end_date = end_date, + filters = filters or {}, + compliance_standard = compliance_standard, + status = "pending", + file_path = None, + file_size = None, + record_count = None, + checksum = None, + downloaded_by = None, + downloaded_at = None, + expires_at = expires_at, + created_by = created_by, + created_at = now, + completed_at = None, + error_message = None, ) - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ INSERT INTO audit_log_exports @@ -1383,49 +1383,49 @@ class EnterpriseManager: finally: conn.close() - def process_audit_export(self, export_id: str, db_manager=None) -> AuditLogExport | None: + def process_audit_export(self, export_id: str, db_manager = None) -> AuditLogExport | None: """处理审计日志导出任务""" - export = self.get_audit_export(export_id) + export = self.get_audit_export(export_id) if not export: return None - conn = self._get_connection() + conn = self._get_connection() try: # 更新状态为处理中 - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ - UPDATE audit_log_exports SET status = 'processing' - WHERE id = ? + UPDATE audit_log_exports SET status = 'processing' + WHERE id = ? """, - (export_id,), + (export_id, ), ) conn.commit() try: # 获取审计日志数据 - logs = self._fetch_audit_logs( + logs = self._fetch_audit_logs( export.tenant_id, export.start_date, export.end_date, export.filters, db_manager ) # 根据合规标准过滤字段 if export.compliance_standard: - logs = self._apply_compliance_filter(logs, export.compliance_standard) + logs = self._apply_compliance_filter(logs, export.compliance_standard) # 生成导出文件 - file_path, file_size, checksum = self._generate_export_file( + file_path, file_size, checksum = self._generate_export_file( export_id, logs, export.export_format ) - now = datetime.now() + now = datetime.now() # 更新导出记录 cursor.execute( """ UPDATE audit_log_exports - SET status = 'completed', file_path = ?, file_size = ?, - record_count = ?, checksum = ?, completed_at = ? - WHERE id = ? + SET status = 'completed', file_path = ?, file_size = ?, + record_count = ?, checksum = ?, completed_at = ? + WHERE id = ? """, (file_path, file_size, len(logs), checksum, now, export_id), ) @@ -1437,8 +1437,8 @@ class EnterpriseManager: cursor.execute( """ UPDATE audit_log_exports - SET status = 'failed', error_message = ? - WHERE id = ? + SET status = 'failed', error_message = ? + WHERE id = ? """, (str(e), export_id), ) @@ -1454,7 +1454,7 @@ class EnterpriseManager: start_date: datetime, end_date: datetime, filters: dict[str, Any], - db_manager=None, + db_manager = None, ) -> list[dict[str, Any]]: """获取审计日志数据""" if db_manager is None: @@ -1468,14 +1468,14 @@ class EnterpriseManager: self, logs: list[dict[str, Any]], standard: str ) -> list[dict[str, Any]]: """应用合规标准字段过滤""" - fields = self.COMPLIANCE_FIELDS.get(ComplianceStandard(standard), []) + fields = self.COMPLIANCE_FIELDS.get(ComplianceStandard(standard), []) if not fields: return logs - filtered_logs = [] + filtered_logs = [] for log in logs: - filtered_log = {k: v for k, v in log.items() if k in fields} + filtered_log = {k: v for k, v in log.items() if k in fields} filtered_logs.append(filtered_log) return filtered_logs @@ -1487,44 +1487,44 @@ class EnterpriseManager: import hashlib import os - export_dir = "/tmp/insightflow/exports" - os.makedirs(export_dir, exist_ok=True) + export_dir = "/tmp/insightflow/exports" + os.makedirs(export_dir, exist_ok = True) - file_path = f"{export_dir}/audit_export_{export_id}.{format}" + file_path = f"{export_dir}/audit_export_{export_id}.{format}" if format == "json": - content = json.dumps(logs, ensure_ascii=False, indent=2) - with open(file_path, "w", encoding="utf-8") as f: + content = json.dumps(logs, ensure_ascii = False, indent = 2) + with open(file_path, "w", encoding = "utf-8") as f: f.write(content) elif format == "csv": import csv if logs: - with open(file_path, "w", newline="", encoding="utf-8") as f: - writer = csv.DictWriter(f, fieldnames=logs[0].keys()) + with open(file_path, "w", newline = "", encoding = "utf-8") as f: + writer = csv.DictWriter(f, fieldnames = logs[0].keys()) writer.writeheader() writer.writerows(logs) else: # 其他格式暂不支持 - content = json.dumps(logs, ensure_ascii=False) - with open(file_path, "w", encoding="utf-8") as f: + content = json.dumps(logs, ensure_ascii = False) + with open(file_path, "w", encoding = "utf-8") as f: f.write(content) - file_size = os.path.getsize(file_path) + file_size = os.path.getsize(file_path) # 计算校验和 with open(file_path, "rb") as f: - checksum = hashlib.sha256(f.read()).hexdigest() + checksum = hashlib.sha256(f.read()).hexdigest() return file_path, file_size, checksum def get_audit_export(self, export_id: str) -> AuditLogExport | None: """获取审计日志导出记录""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() - cursor.execute("SELECT * FROM audit_log_exports WHERE id = ?", (export_id,)) - row = cursor.fetchone() + cursor = conn.cursor() + cursor.execute("SELECT * FROM audit_log_exports WHERE id = ?", (export_id, )) + row = cursor.fetchone() if row: return self._row_to_audit_export(row) @@ -1533,21 +1533,21 @@ class EnterpriseManager: finally: conn.close() - def list_audit_exports(self, tenant_id: str, limit: int = 100) -> list[AuditLogExport]: + def list_audit_exports(self, tenant_id: str, limit: int = 100) -> list[AuditLogExport]: """列出审计日志导出记录""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ SELECT * FROM audit_log_exports - WHERE tenant_id = ? + WHERE tenant_id = ? ORDER BY created_at DESC LIMIT ? """, (tenant_id, limit), ) - rows = cursor.fetchall() + rows = cursor.fetchall() return [self._row_to_audit_export(row) for row in rows] @@ -1556,14 +1556,14 @@ class EnterpriseManager: def mark_export_downloaded(self, export_id: str, downloaded_by: str) -> bool: """标记导出文件已下载""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ UPDATE audit_log_exports - SET downloaded_by = ?, downloaded_at = ? - WHERE id = ? + SET downloaded_by = ?, downloaded_at = ? + WHERE id = ? """, (downloaded_by, datetime.now(), export_id), ) @@ -1581,42 +1581,42 @@ class EnterpriseManager: resource_type: str, retention_days: int, action: str, - description: str | None = None, - conditions: dict[str, Any] | None = None, - auto_execute: bool = False, - execute_at: str | None = None, - notify_before_days: int = 7, - archive_location: str | None = None, - archive_encryption: bool = True, + description: str | None = None, + conditions: dict[str, Any] | None = None, + auto_execute: bool = False, + execute_at: str | None = None, + notify_before_days: int = 7, + archive_location: str | None = None, + archive_encryption: bool = True, ) -> DataRetentionPolicy: """创建数据保留策略""" - conn = self._get_connection() + conn = self._get_connection() try: - policy_id = str(uuid.uuid4()) - now = datetime.now() + policy_id = str(uuid.uuid4()) + now = datetime.now() - policy = DataRetentionPolicy( - id=policy_id, - tenant_id=tenant_id, - name=name, - description=description, - resource_type=resource_type, - retention_days=retention_days, - action=action, - conditions=conditions or {}, - auto_execute=auto_execute, - execute_at=execute_at, - notify_before_days=notify_before_days, - archive_location=archive_location, - archive_encryption=archive_encryption, - is_active=True, - last_executed_at=None, - last_execution_result=None, - created_at=now, - updated_at=now, + policy = DataRetentionPolicy( + id = policy_id, + tenant_id = tenant_id, + name = name, + description = description, + resource_type = resource_type, + retention_days = retention_days, + action = action, + conditions = conditions or {}, + auto_execute = auto_execute, + execute_at = execute_at, + notify_before_days = notify_before_days, + archive_location = archive_location, + archive_encryption = archive_encryption, + is_active = True, + last_executed_at = None, + last_execution_result = None, + created_at = now, + updated_at = now, ) - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ INSERT INTO data_retention_policies @@ -1658,11 +1658,11 @@ class EnterpriseManager: def get_retention_policy(self, policy_id: str) -> DataRetentionPolicy | None: """获取数据保留策略""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() - cursor.execute("SELECT * FROM data_retention_policies WHERE id = ?", (policy_id,)) - row = cursor.fetchone() + cursor = conn.cursor() + cursor.execute("SELECT * FROM data_retention_policies WHERE id = ?", (policy_id, )) + row = cursor.fetchone() if row: return self._row_to_retention_policy(row) @@ -1672,24 +1672,24 @@ class EnterpriseManager: conn.close() def list_retention_policies( - self, tenant_id: str, resource_type: str | None = None + self, tenant_id: str, resource_type: str | None = None ) -> list[DataRetentionPolicy]: """列出数据保留策略""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() - query = "SELECT * FROM data_retention_policies WHERE tenant_id = ?" - params = [tenant_id] + query = "SELECT * FROM data_retention_policies WHERE tenant_id = ?" + params = [tenant_id] if resource_type: - query += " AND resource_type = ?" + query += " AND resource_type = ?" params.append(resource_type) query += " ORDER BY created_at DESC" cursor.execute(query, params) - rows = cursor.fetchall() + rows = cursor.fetchall() return [self._row_to_retention_policy(row) for row in rows] @@ -1698,16 +1698,16 @@ class EnterpriseManager: def update_retention_policy(self, policy_id: str, **kwargs) -> DataRetentionPolicy | None: """更新数据保留策略""" - conn = self._get_connection() + conn = self._get_connection() try: - policy = self.get_retention_policy(policy_id) + policy = self.get_retention_policy(policy_id) if not policy: return None - updates = [] - params = [] + updates = [] + params = [] - allowed_fields = [ + allowed_fields = [ "name", "description", "retention_days", @@ -1723,7 +1723,7 @@ class EnterpriseManager: for key, value in kwargs.items(): if key in allowed_fields: - updates.append(f"{key} = ?") + updates.append(f"{key} = ?") if key == "conditions": params.append(json.dumps(value) if value else "{}") elif key in ["auto_execute", "archive_encryption", "is_active"]: @@ -1734,15 +1734,15 @@ class EnterpriseManager: if not updates: return policy - updates.append("updated_at = ?") + updates.append("updated_at = ?") params.append(datetime.now()) params.append(policy_id) - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( f""" UPDATE data_retention_policies SET {", ".join(updates)} - WHERE id = ? + WHERE id = ? """, params, ) @@ -1755,10 +1755,10 @@ class EnterpriseManager: def delete_retention_policy(self, policy_id: str) -> bool: """删除数据保留策略""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() - cursor.execute("DELETE FROM data_retention_policies WHERE id = ?", (policy_id,)) + cursor = conn.cursor() + cursor.execute("DELETE FROM data_retention_policies WHERE id = ?", (policy_id, )) conn.commit() return cursor.rowcount > 0 finally: @@ -1766,31 +1766,31 @@ class EnterpriseManager: def execute_retention_policy(self, policy_id: str) -> DataRetentionJob: """执行数据保留策略""" - policy = self.get_retention_policy(policy_id) + policy = self.get_retention_policy(policy_id) if not policy: raise ValueError(f"Retention policy {policy_id} not found") - conn = self._get_connection() + conn = self._get_connection() try: - job_id = str(uuid.uuid4()) - now = datetime.now() + job_id = str(uuid.uuid4()) + now = datetime.now() - job = DataRetentionJob( - id=job_id, - policy_id=policy_id, - tenant_id=policy.tenant_id, - status="running", - started_at=now, - completed_at=None, - affected_records=0, - archived_records=0, - deleted_records=0, - error_count=0, - details={}, - created_at=now, + job = DataRetentionJob( + id = job_id, + policy_id = policy_id, + tenant_id = policy.tenant_id, + status = "running", + started_at = now, + completed_at = None, + affected_records = 0, + archived_records = 0, + deleted_records = 0, + error_count = 0, + details = {}, + created_at = now, ) - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ INSERT INTO data_retention_jobs @@ -1804,26 +1804,26 @@ class EnterpriseManager: try: # 计算截止日期 - cutoff_date = now - timedelta(days=policy.retention_days) + cutoff_date = now - timedelta(days = policy.retention_days) # 根据资源类型执行不同的处理 if policy.resource_type == "audit_log": - result = self._retain_audit_logs(conn, policy, cutoff_date) + result = self._retain_audit_logs(conn, policy, cutoff_date) elif policy.resource_type == "project": - result = self._retain_projects(conn, policy, cutoff_date) + result = self._retain_projects(conn, policy, cutoff_date) elif policy.resource_type == "transcript": - result = self._retain_transcripts(conn, policy, cutoff_date) + result = self._retain_transcripts(conn, policy, cutoff_date) else: - result = {"affected": 0, "archived": 0, "deleted": 0, "errors": 0} + result = {"affected": 0, "archived": 0, "deleted": 0, "errors": 0} # 更新任务状态 cursor.execute( """ UPDATE data_retention_jobs - SET status = 'completed', completed_at = ?, - affected_records = ?, archived_records = ?, - deleted_records = ?, error_count = ?, details = ? - WHERE id = ? + SET status = 'completed', completed_at = ?, + affected_records = ?, archived_records = ?, + deleted_records = ?, error_count = ?, details = ? + WHERE id = ? """, ( datetime.now(), @@ -1840,8 +1840,8 @@ class EnterpriseManager: cursor.execute( """ UPDATE data_retention_policies - SET last_executed_at = ?, last_execution_result = 'success' - WHERE id = ? + SET last_executed_at = ?, last_execution_result = 'success' + WHERE id = ? """, (datetime.now(), policy_id), ) @@ -1852,8 +1852,8 @@ class EnterpriseManager: cursor.execute( """ UPDATE data_retention_jobs - SET status = 'failed', completed_at = ?, error_count = 1, details = ? - WHERE id = ? + SET status = 'failed', completed_at = ?, error_count = 1, details = ? + WHERE id = ? """, (datetime.now(), json.dumps({"error": str(e)}), job_id), ) @@ -1861,8 +1861,8 @@ class EnterpriseManager: cursor.execute( """ UPDATE data_retention_policies - SET last_executed_at = ?, last_execution_result = ? - WHERE id = ? + SET last_executed_at = ?, last_execution_result = ? + WHERE id = ? """, (datetime.now(), str(e), policy_id), ) @@ -1879,7 +1879,7 @@ class EnterpriseManager: self, conn: sqlite3.Connection, policy: DataRetentionPolicy, cutoff_date: datetime ) -> dict[str, int]: """保留审计日志""" - cursor = conn.cursor() + cursor = conn.cursor() # 获取符合条件的记录数 cursor.execute( @@ -1887,23 +1887,23 @@ class EnterpriseManager: SELECT COUNT(*) as count FROM audit_logs WHERE created_at < ? """, - (cutoff_date,), + (cutoff_date, ), ) - count = cursor.fetchone()["count"] + count = cursor.fetchone()["count"] if policy.action == DataRetentionAction.DELETE.value: cursor.execute( """ DELETE FROM audit_logs WHERE created_at < ? """, - (cutoff_date,), + (cutoff_date, ), ) - deleted = cursor.rowcount + deleted = cursor.rowcount return {"affected": count, "archived": 0, "deleted": deleted, "errors": 0} elif policy.action == DataRetentionAction.ARCHIVE.value: # 归档逻辑(简化实现) - archived = count + archived = count return {"affected": count, "archived": archived, "deleted": 0, "errors": 0} return {"affected": 0, "archived": 0, "deleted": 0, "errors": 0} @@ -1924,11 +1924,11 @@ class EnterpriseManager: def get_retention_job(self, job_id: str) -> DataRetentionJob | None: """获取数据保留任务""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() - cursor.execute("SELECT * FROM data_retention_jobs WHERE id = ?", (job_id,)) - row = cursor.fetchone() + cursor = conn.cursor() + cursor.execute("SELECT * FROM data_retention_jobs WHERE id = ?", (job_id, )) + row = cursor.fetchone() if row: return self._row_to_retention_job(row) @@ -1937,21 +1937,21 @@ class EnterpriseManager: finally: conn.close() - def list_retention_jobs(self, policy_id: str, limit: int = 100) -> list[DataRetentionJob]: + def list_retention_jobs(self, policy_id: str, limit: int = 100) -> list[DataRetentionJob]: """列出数据保留任务""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ SELECT * FROM data_retention_jobs - WHERE policy_id = ? + WHERE policy_id = ? ORDER BY created_at DESC LIMIT ? """, (policy_id, limit), ) - rows = cursor.fetchall() + rows = cursor.fetchall() return [self._row_to_retention_job(row) for row in rows] @@ -1963,64 +1963,64 @@ class EnterpriseManager: def _row_to_sso_config(self, row: sqlite3.Row) -> SSOConfig: """数据库行转换为 SSOConfig 对象""" return SSOConfig( - id=row["id"], - tenant_id=row["tenant_id"], - provider=row["provider"], - status=row["status"], - entity_id=row["entity_id"], - sso_url=row["sso_url"], - slo_url=row["slo_url"], - certificate=row["certificate"], - metadata_url=row["metadata_url"], - metadata_xml=row["metadata_xml"], - client_id=row["client_id"], - client_secret=row["client_secret"], - authorization_url=row["authorization_url"], - token_url=row["token_url"], - userinfo_url=row["userinfo_url"], - scopes=json.loads(row["scopes"] or '["openid", "email", "profile"]'), - attribute_mapping=json.loads(row["attribute_mapping"] or "{}"), - auto_provision=bool(row["auto_provision"]), - default_role=row["default_role"], - domain_restriction=json.loads(row["domain_restriction"] or "[]"), - created_at=( + id = row["id"], + tenant_id = row["tenant_id"], + provider = row["provider"], + status = row["status"], + entity_id = row["entity_id"], + sso_url = row["sso_url"], + slo_url = row["slo_url"], + certificate = row["certificate"], + metadata_url = row["metadata_url"], + metadata_xml = row["metadata_xml"], + client_id = row["client_id"], + client_secret = row["client_secret"], + authorization_url = row["authorization_url"], + token_url = row["token_url"], + userinfo_url = row["userinfo_url"], + scopes = json.loads(row["scopes"] or '["openid", "email", "profile"]'), + attribute_mapping = json.loads(row["attribute_mapping"] or "{}"), + auto_provision = bool(row["auto_provision"]), + default_role = row["default_role"], + domain_restriction = json.loads(row["domain_restriction"] or "[]"), + created_at = ( datetime.fromisoformat(row["created_at"]) if isinstance(row["created_at"], str) else row["created_at"] ), - updated_at=( + updated_at = ( datetime.fromisoformat(row["updated_at"]) if isinstance(row["updated_at"], str) else row["updated_at"] ), - last_tested_at=( + last_tested_at = ( datetime.fromisoformat(row["last_tested_at"]) if row["last_tested_at"] and isinstance(row["last_tested_at"], str) else row["last_tested_at"] ), - last_error=row["last_error"], + last_error = row["last_error"], ) def _row_to_saml_request(self, row: sqlite3.Row) -> SAMLAuthRequest: """数据库行转换为 SAMLAuthRequest 对象""" return SAMLAuthRequest( - id=row["id"], - tenant_id=row["tenant_id"], - sso_config_id=row["sso_config_id"], - request_id=row["request_id"], - relay_state=row["relay_state"], - created_at=( + id = row["id"], + tenant_id = row["tenant_id"], + sso_config_id = row["sso_config_id"], + request_id = row["request_id"], + relay_state = row["relay_state"], + created_at = ( datetime.fromisoformat(row["created_at"]) if isinstance(row["created_at"], str) else row["created_at"] ), - expires_at=( + expires_at = ( datetime.fromisoformat(row["expires_at"]) if isinstance(row["expires_at"], str) else row["expires_at"] ), - used=bool(row["used"]), - used_at=( + used = bool(row["used"]), + used_at = ( datetime.fromisoformat(row["used_at"]) if row["used_at"] and isinstance(row["used_at"], str) else row["used_at"] @@ -2030,29 +2030,29 @@ class EnterpriseManager: def _row_to_scim_config(self, row: sqlite3.Row) -> SCIMConfig: """数据库行转换为 SCIMConfig 对象""" return SCIMConfig( - id=row["id"], - tenant_id=row["tenant_id"], - provider=row["provider"], - status=row["status"], - scim_base_url=row["scim_base_url"], - scim_token=row["scim_token"], - sync_interval_minutes=row["sync_interval_minutes"], - last_sync_at=( + id = row["id"], + tenant_id = row["tenant_id"], + provider = row["provider"], + status = row["status"], + scim_base_url = row["scim_base_url"], + scim_token = row["scim_token"], + sync_interval_minutes = row["sync_interval_minutes"], + last_sync_at = ( datetime.fromisoformat(row["last_sync_at"]) if row["last_sync_at"] and isinstance(row["last_sync_at"], str) else row["last_sync_at"] ), - last_sync_status=row["last_sync_status"], - last_sync_error=row["last_sync_error"], - last_sync_users_count=row["last_sync_users_count"], - attribute_mapping=json.loads(row["attribute_mapping"] or "{}"), - sync_rules=json.loads(row["sync_rules"] or "{}"), - created_at=( + last_sync_status = row["last_sync_status"], + last_sync_error = row["last_sync_error"], + last_sync_users_count = row["last_sync_users_count"], + attribute_mapping = json.loads(row["attribute_mapping"] or "{}"), + sync_rules = json.loads(row["sync_rules"] or "{}"), + created_at = ( datetime.fromisoformat(row["created_at"]) if isinstance(row["created_at"], str) else row["created_at"] ), - updated_at=( + updated_at = ( datetime.fromisoformat(row["updated_at"]) if isinstance(row["updated_at"], str) else row["updated_at"] @@ -2062,28 +2062,28 @@ class EnterpriseManager: def _row_to_scim_user(self, row: sqlite3.Row) -> SCIMUser: """数据库行转换为 SCIMUser 对象""" return SCIMUser( - id=row["id"], - tenant_id=row["tenant_id"], - external_id=row["external_id"], - user_name=row["user_name"], - email=row["email"], - display_name=row["display_name"], - given_name=row["given_name"], - family_name=row["family_name"], - active=bool(row["active"]), - groups=json.loads(row["groups"] or "[]"), - raw_data=json.loads(row["raw_data"] or "{}"), - synced_at=( + id = row["id"], + tenant_id = row["tenant_id"], + external_id = row["external_id"], + user_name = row["user_name"], + email = row["email"], + display_name = row["display_name"], + given_name = row["given_name"], + family_name = row["family_name"], + active = bool(row["active"]), + groups = json.loads(row["groups"] or "[]"), + raw_data = json.loads(row["raw_data"] or "{}"), + synced_at = ( datetime.fromisoformat(row["synced_at"]) if isinstance(row["synced_at"], str) else row["synced_at"] ), - created_at=( + created_at = ( datetime.fromisoformat(row["created_at"]) if isinstance(row["created_at"], str) else row["created_at"] ), - updated_at=( + updated_at = ( datetime.fromisoformat(row["updated_at"]) if isinstance(row["updated_at"], str) else row["updated_at"] @@ -2093,78 +2093,78 @@ class EnterpriseManager: def _row_to_audit_export(self, row: sqlite3.Row) -> AuditLogExport: """数据库行转换为 AuditLogExport 对象""" return AuditLogExport( - id=row["id"], - tenant_id=row["tenant_id"], - export_format=row["export_format"], - start_date=( + id = row["id"], + tenant_id = row["tenant_id"], + export_format = row["export_format"], + start_date = ( datetime.fromisoformat(row["start_date"]) if isinstance(row["start_date"], str) else row["start_date"] ), - end_date=datetime.fromisoformat(row["end_date"]) + end_date = datetime.fromisoformat(row["end_date"]) if isinstance(row["end_date"], str) else row["end_date"], - filters=json.loads(row["filters"] or "{}"), - compliance_standard=row["compliance_standard"], - status=row["status"], - file_path=row["file_path"], - file_size=row["file_size"], - record_count=row["record_count"], - checksum=row["checksum"], - downloaded_by=row["downloaded_by"], - downloaded_at=( + filters = json.loads(row["filters"] or "{}"), + compliance_standard = row["compliance_standard"], + status = row["status"], + file_path = row["file_path"], + file_size = row["file_size"], + record_count = row["record_count"], + checksum = row["checksum"], + downloaded_by = row["downloaded_by"], + downloaded_at = ( datetime.fromisoformat(row["downloaded_at"]) if row["downloaded_at"] and isinstance(row["downloaded_at"], str) else row["downloaded_at"] ), - expires_at=( + expires_at = ( datetime.fromisoformat(row["expires_at"]) if isinstance(row["expires_at"], str) else row["expires_at"] ), - created_by=row["created_by"], - created_at=( + created_by = row["created_by"], + created_at = ( datetime.fromisoformat(row["created_at"]) if isinstance(row["created_at"], str) else row["created_at"] ), - completed_at=( + completed_at = ( datetime.fromisoformat(row["completed_at"]) if row["completed_at"] and isinstance(row["completed_at"], str) else row["completed_at"] ), - error_message=row["error_message"], + error_message = row["error_message"], ) def _row_to_retention_policy(self, row: sqlite3.Row) -> DataRetentionPolicy: """数据库行转换为 DataRetentionPolicy 对象""" return DataRetentionPolicy( - id=row["id"], - tenant_id=row["tenant_id"], - name=row["name"], - description=row["description"], - resource_type=row["resource_type"], - retention_days=row["retention_days"], - action=row["action"], - conditions=json.loads(row["conditions"] or "{}"), - auto_execute=bool(row["auto_execute"]), - execute_at=row["execute_at"], - notify_before_days=row["notify_before_days"], - archive_location=row["archive_location"], - archive_encryption=bool(row["archive_encryption"]), - is_active=bool(row["is_active"]), - last_executed_at=( + id = row["id"], + tenant_id = row["tenant_id"], + name = row["name"], + description = row["description"], + resource_type = row["resource_type"], + retention_days = row["retention_days"], + action = row["action"], + conditions = json.loads(row["conditions"] or "{}"), + auto_execute = bool(row["auto_execute"]), + execute_at = row["execute_at"], + notify_before_days = row["notify_before_days"], + archive_location = row["archive_location"], + archive_encryption = bool(row["archive_encryption"]), + is_active = bool(row["is_active"]), + last_executed_at = ( datetime.fromisoformat(row["last_executed_at"]) if row["last_executed_at"] and isinstance(row["last_executed_at"], str) else row["last_executed_at"] ), - last_execution_result=row["last_execution_result"], - created_at=( + last_execution_result = row["last_execution_result"], + created_at = ( datetime.fromisoformat(row["created_at"]) if isinstance(row["created_at"], str) else row["created_at"] ), - updated_at=( + updated_at = ( datetime.fromisoformat(row["updated_at"]) if isinstance(row["updated_at"], str) else row["updated_at"] @@ -2174,26 +2174,26 @@ class EnterpriseManager: def _row_to_retention_job(self, row: sqlite3.Row) -> DataRetentionJob: """数据库行转换为 DataRetentionJob 对象""" return DataRetentionJob( - id=row["id"], - policy_id=row["policy_id"], - tenant_id=row["tenant_id"], - status=row["status"], - started_at=( + id = row["id"], + policy_id = row["policy_id"], + tenant_id = row["tenant_id"], + status = row["status"], + started_at = ( datetime.fromisoformat(row["started_at"]) if row["started_at"] and isinstance(row["started_at"], str) else row["started_at"] ), - completed_at=( + completed_at = ( datetime.fromisoformat(row["completed_at"]) if row["completed_at"] and isinstance(row["completed_at"], str) else row["completed_at"] ), - affected_records=row["affected_records"], - archived_records=row["archived_records"], - deleted_records=row["deleted_records"], - error_count=row["error_count"], - details=json.loads(row["details"] or "{}"), - created_at=( + affected_records = row["affected_records"], + archived_records = row["archived_records"], + deleted_records = row["deleted_records"], + error_count = row["error_count"], + details = json.loads(row["details"] or "{}"), + created_at = ( datetime.fromisoformat(row["created_at"]) if isinstance(row["created_at"], str) else row["created_at"] @@ -2202,12 +2202,12 @@ class EnterpriseManager: # 全局实例 -_enterprise_manager = None +_enterprise_manager = None -def get_enterprise_manager(db_path: str = "insightflow.db") -> EnterpriseManager: +def get_enterprise_manager(db_path: str = "insightflow.db") -> EnterpriseManager: """获取 EnterpriseManager 单例""" global _enterprise_manager if _enterprise_manager is None: - _enterprise_manager = EnterpriseManager(db_path) + _enterprise_manager = EnterpriseManager(db_path) return _enterprise_manager diff --git a/backend/entity_aligner.py b/backend/entity_aligner.py index 3bad4cf..e1999d1 100644 --- a/backend/entity_aligner.py +++ b/backend/entity_aligner.py @@ -12,8 +12,8 @@ import httpx import numpy as np # API Keys -KIMI_API_KEY = os.getenv("KIMI_API_KEY", "") -KIMI_BASE_URL = os.getenv("KIMI_BASE_URL", "https://api.kimi.com/coding") +KIMI_API_KEY = os.getenv("KIMI_API_KEY", "") +KIMI_BASE_URL = os.getenv("KIMI_BASE_URL", "https://api.kimi.com/coding") @dataclass @@ -27,9 +27,9 @@ class EntityEmbedding: class EntityAligner: """实体对齐器 - 使用 embedding 进行相似度匹配""" - def __init__(self, similarity_threshold: float = 0.85): - self.similarity_threshold = similarity_threshold - self.embedding_cache: dict[str, list[float]] = {} + def __init__(self, similarity_threshold: float = 0.85) -> None: + self.similarity_threshold = similarity_threshold + self.embedding_cache: dict[str, list[float]] = {} def get_embedding(self, text: str) -> list[float] | None: """ @@ -45,25 +45,25 @@ class EntityAligner: return None # 检查缓存 - cache_key = hash(text) + cache_key = hash(text) if cache_key in self.embedding_cache: return self.embedding_cache[cache_key] try: - response = httpx.post( + response = httpx.post( f"{KIMI_BASE_URL}/v1/embeddings", - headers={ + headers = { "Authorization": f"Bearer {KIMI_API_KEY}", "Content-Type": "application/json", }, - json={"model": "k2p5", "input": text[:500]}, # 限制长度 - timeout=30.0, + json = {"model": "k2p5", "input": text[:500]}, # 限制长度 + timeout = 30.0, ) response.raise_for_status() - result = response.json() + result = response.json() - embedding = result["data"][0]["embedding"] - self.embedding_cache[cache_key] = embedding + embedding = result["data"][0]["embedding"] + self.embedding_cache[cache_key] = embedding return embedding except (httpx.HTTPError, json.JSONDecodeError, KeyError) as e: @@ -81,20 +81,20 @@ class EntityAligner: Returns: 相似度分数 (0-1) """ - vec1 = np.array(embedding1) - vec2 = np.array(embedding2) + vec1 = np.array(embedding1) + vec2 = np.array(embedding2) # 余弦相似度 - dot_product = np.dot(vec1, vec2) - norm1 = np.linalg.norm(vec1) - norm2 = np.linalg.norm(vec2) + dot_product = np.dot(vec1, vec2) + norm1 = np.linalg.norm(vec1) + norm2 = np.linalg.norm(vec2) if norm1 == 0 or norm2 == 0: return 0.0 return float(dot_product / (norm1 * norm2)) - def get_entity_text(self, name: str, definition: str = "") -> str: + def get_entity_text(self, name: str, definition: str = "") -> str: """ 构建用于 embedding 的实体文本 @@ -113,9 +113,9 @@ class EntityAligner: self, project_id: str, name: str, - definition: str = "", - exclude_id: str | None = None, - threshold: float | None = None, + definition: str = "", + exclude_id: str | None = None, + threshold: float | None = None, ) -> object | None: """ 查找相似的实体 @@ -131,54 +131,54 @@ class EntityAligner: 相似的实体或 None """ if threshold is None: - threshold = self.similarity_threshold + threshold = self.similarity_threshold try: from db_manager import get_db_manager - db = get_db_manager() + db = get_db_manager() except ImportError: return None # 获取项目的所有实体 - entities = db.get_all_entities_for_embedding(project_id) + entities = db.get_all_entities_for_embedding(project_id) if not entities: return None # 获取查询实体的 embedding - query_text = self.get_entity_text(name, definition) - query_embedding = self.get_embedding(query_text) + query_text = self.get_entity_text(name, definition) + query_embedding = self.get_embedding(query_text) if query_embedding is None: # 如果 embedding API 失败,回退到简单匹配 return self._fallback_similarity_match(entities, name, exclude_id) - best_match = None - best_score = threshold + best_match = None + best_score = threshold for entity in entities: if exclude_id and entity.id == exclude_id: continue # 获取实体的 embedding - entity_text = self.get_entity_text(entity.name, entity.definition) - entity_embedding = self.get_embedding(entity_text) + entity_text = self.get_entity_text(entity.name, entity.definition) + entity_embedding = self.get_embedding(entity_text) if entity_embedding is None: continue # 计算相似度 - similarity = self.compute_similarity(query_embedding, entity_embedding) + similarity = self.compute_similarity(query_embedding, entity_embedding) if similarity > best_score: - best_score = similarity - best_match = entity + best_score = similarity + best_match = entity return best_match def _fallback_similarity_match( - self, entities: list[object], name: str, exclude_id: str | None = None + self, entities: list[object], name: str, exclude_id: str | None = None ) -> object | None: """ 回退到简单的相似度匹配(不使用 embedding) @@ -191,7 +191,7 @@ class EntityAligner: Returns: 最相似的实体或 None """ - name_lower = name.lower() + name_lower = name.lower() # 1. 精确匹配 for entity in entities: @@ -212,7 +212,7 @@ class EntityAligner: return None def batch_align_entities( - self, project_id: str, new_entities: list[dict], threshold: float | None = None + self, project_id: str, new_entities: list[dict], threshold: float | None = None ) -> list[dict]: """ 批量对齐实体 @@ -226,16 +226,16 @@ class EntityAligner: 对齐结果列表 [{"new_entity": {...}, "matched_entity": {...}, "similarity": 0.9}] """ if threshold is None: - threshold = self.similarity_threshold + threshold = self.similarity_threshold - results = [] + results = [] for new_ent in new_entities: - matched = self.find_similar_entity( - project_id, new_ent["name"], new_ent.get("definition", ""), threshold=threshold + matched = self.find_similar_entity( + project_id, new_ent["name"], new_ent.get("definition", ""), threshold = threshold ) - result = { + result = { "new_entity": new_ent, "matched_entity": None, "similarity": 0.0, @@ -244,28 +244,28 @@ class EntityAligner: if matched: # 计算相似度 - query_text = self.get_entity_text(new_ent["name"], new_ent.get("definition", "")) - matched_text = self.get_entity_text(matched.name, matched.definition) + query_text = self.get_entity_text(new_ent["name"], new_ent.get("definition", "")) + matched_text = self.get_entity_text(matched.name, matched.definition) - query_emb = self.get_embedding(query_text) - matched_emb = self.get_embedding(matched_text) + query_emb = self.get_embedding(query_text) + matched_emb = self.get_embedding(matched_text) if query_emb and matched_emb: - similarity = self.compute_similarity(query_emb, matched_emb) - result["matched_entity"] = { + similarity = self.compute_similarity(query_emb, matched_emb) + result["matched_entity"] = { "id": matched.id, "name": matched.name, "type": matched.type, "definition": matched.definition, } - result["similarity"] = similarity - result["should_merge"] = similarity >= threshold + result["similarity"] = similarity + result["should_merge"] = similarity >= threshold results.append(result) return results - def suggest_entity_aliases(self, entity_name: str, entity_definition: str = "") -> list[str]: + def suggest_entity_aliases(self, entity_name: str, entity_definition: str = "") -> list[str]: """ 使用 LLM 建议实体的别名 @@ -279,7 +279,7 @@ class EntityAligner: if not KIMI_API_KEY: return [] - prompt = f"""为以下实体生成可能的别名或简称: + prompt = f"""为以下实体生成可能的别名或简称: 实体名称:{entity_name} 定义:{entity_definition} @@ -290,28 +290,28 @@ class EntityAligner: 只返回 JSON,不要其他内容。""" try: - response = httpx.post( + response = httpx.post( f"{KIMI_BASE_URL}/v1/chat/completions", - headers={ + headers = { "Authorization": f"Bearer {KIMI_API_KEY}", "Content-Type": "application/json", }, - json={ + json = { "model": "k2p5", "messages": [{"role": "user", "content": prompt}], "temperature": 0.3, }, - timeout=30.0, + timeout = 30.0, ) response.raise_for_status() - result = response.json() - content = result["choices"][0]["message"]["content"] + result = response.json() + content = result["choices"][0]["message"]["content"] import re - json_match = re.search(r"\{{.*?\}}", content, re.DOTALL) + json_match = re.search(r"\{{.*?\}}", content, re.DOTALL) if json_match: - data = json.loads(json_match.group()) + data = json.loads(json_match.group()) return data.get("aliases", []) except (httpx.HTTPError, json.JSONDecodeError, KeyError) as e: print(f"Alias suggestion failed: {e}") @@ -340,8 +340,8 @@ def simple_similarity(str1: str, str2: str) -> float: return 0.0 # 转换为小写 - s1 = str1.lower() - s2 = str2.lower() + s1 = str1.lower() + s2 = str2.lower() # 包含关系 if s1 in s2 or s2 in s1: @@ -355,11 +355,11 @@ def simple_similarity(str1: str, str2: str) -> float: if __name__ == "__main__": # 测试 - aligner = EntityAligner() + aligner = EntityAligner() # 测试 embedding - test_text = "Kubernetes 容器编排平台" - embedding = aligner.get_embedding(test_text) + test_text = "Kubernetes 容器编排平台" + embedding = aligner.get_embedding(test_text) if embedding: print(f"Embedding dimension: {len(embedding)}") print(f"First 5 values: {embedding[:5]}") @@ -367,7 +367,7 @@ if __name__ == "__main__": print("Embedding API not available") # 测试相似度计算 - emb1 = [1.0, 0.0, 0.0] - emb2 = [0.9, 0.1, 0.0] - sim = aligner.compute_similarity(emb1, emb2) + emb1 = [1.0, 0.0, 0.0] + emb2 = [0.9, 0.1, 0.0] + sim = aligner.compute_similarity(emb1, emb2) print(f"Similarity: {sim:.4f}") diff --git a/backend/export_manager.py b/backend/export_manager.py index 35e7792..8d1547a 100644 --- a/backend/export_manager.py +++ b/backend/export_manager.py @@ -14,9 +14,9 @@ from typing import Any try: import pandas as pd - PANDAS_AVAILABLE = True + PANDAS_AVAILABLE = True except ImportError: - PANDAS_AVAILABLE = False + PANDAS_AVAILABLE = False try: from reportlab.lib import colors @@ -32,9 +32,9 @@ try: TableStyle, ) - REPORTLAB_AVAILABLE = True + REPORTLAB_AVAILABLE = True except ImportError: - REPORTLAB_AVAILABLE = False + REPORTLAB_AVAILABLE = False @dataclass @@ -71,8 +71,8 @@ class ExportTranscript: class ExportManager: """导出管理器 - 处理各种导出需求""" - def __init__(self, db_manager=None): - self.db = db_manager + def __init__(self, db_manager = None) -> None: + self.db = db_manager def export_knowledge_graph_svg( self, project_id: str, entities: list[ExportEntity], relations: list[ExportRelation] @@ -84,21 +84,21 @@ class ExportManager: SVG 字符串 """ # 计算布局参数 - width = 1200 - height = 800 - center_x = width / 2 - center_y = height / 2 - radius = 300 + width = 1200 + height = 800 + center_x = width / 2 + center_y = height / 2 + radius = 300 # 按类型分组实体 - entities_by_type = {} + entities_by_type = {} for e in entities: if e.type not in entities_by_type: - entities_by_type[e.type] = [] + entities_by_type[e.type] = [] entities_by_type[e.type].append(e) # 颜色映射 - type_colors = { + type_colors = { "PERSON": "#FF6B6B", "ORGANIZATION": "#4ECDC4", "LOCATION": "#45B7D1", @@ -110,110 +110,110 @@ class ExportManager: } # 计算实体位置 - entity_positions = {} - angle_step = 2 * 3.14159 / max(len(entities), 1) + entity_positions = {} + angle_step = 2 * 3.14159 / max(len(entities), 1) for i, entity in enumerate(entities): i * angle_step - x = center_x + radius * 0.8 * (i % 3 - 1) * 150 + (i // 3) * 50 - y = center_y + radius * 0.6 * ((i % 6) - 3) * 80 - entity_positions[entity.id] = (x, y) + x = center_x + radius * 0.8 * (i % 3 - 1) * 150 + (i // 3) * 50 + y = center_y + radius * 0.6 * ((i % 6) - 3) * 80 + entity_positions[entity.id] = (x, y) # 生成 SVG - svg_parts = [ - f'', + svg_parts = [ + f'', "", - ' ', - ' ', + ' ', + ' ', " ", "", - f'', - f'知识图谱 - {project_id}', + f'', + f'知识图谱 - {project_id}', ] # 绘制关系连线 for rel in relations: if rel.source in entity_positions and rel.target in entity_positions: - x1, y1 = entity_positions[rel.source] - x2, y2 = entity_positions[rel.target] + x1, y1 = entity_positions[rel.source] + x2, y2 = entity_positions[rel.target] # 计算箭头终点(避免覆盖节点) - dx = x2 - x1 - dy = y2 - y1 - dist = (dx**2 + dy**2) ** 0.5 + dx = x2 - x1 + dy = y2 - y1 + dist = (dx**2 + dy**2) ** 0.5 if dist > 0: - offset = 40 - x2 = x2 - dx * offset / dist - y2 = y2 - dy * offset / dist + offset = 40 + x2 = x2 - dx * offset / dist + y2 = y2 - dy * offset / dist svg_parts.append( - f'' + f'' ) # 关系标签 - mid_x = (x1 + x2) / 2 - mid_y = (y1 + y2) / 2 + mid_x = (x1 + x2) / 2 + mid_y = (y1 + y2) / 2 svg_parts.append( - f'' + f'' ) svg_parts.append( - f'{rel.relation_type}' + f'{rel.relation_type}' ) # 绘制实体节点 for entity in entities: if entity.id in entity_positions: - x, y = entity_positions[entity.id] - color = type_colors.get(entity.type, type_colors["default"]) + x, y = entity_positions[entity.id] + color = type_colors.get(entity.type, type_colors["default"]) # 节点圆圈 svg_parts.append( - f'' + f'' ) # 实体名称 svg_parts.append( - f'{entity.name[:8]}' + f'{entity.name[:8]}' ) # 实体类型 svg_parts.append( - f'{entity.type}' + f'{entity.type}' ) # 图例 - legend_x = width - 150 - legend_y = 80 - rect_x = legend_x - 10 - rect_y = legend_y - 20 - rect_height = len(type_colors) * 25 + 10 + legend_x = width - 150 + legend_y = 80 + rect_x = legend_x - 10 + rect_y = legend_y - 20 + rect_height = len(type_colors) * 25 + 10 svg_parts.append( - f'' + f'' ) svg_parts.append( - f'实体类型' + f'实体类型' ) for i, (etype, color) in enumerate(type_colors.items()): if etype != "default": - y_pos = legend_y + 25 + i * 20 + y_pos = legend_y + 25 + i * 20 svg_parts.append( - f'' + f'' ) - text_y = y_pos + 4 + text_y = y_pos + 4 svg_parts.append( - f'{etype}' + f'{etype}' ) svg_parts.append("") @@ -231,12 +231,12 @@ class ExportManager: try: import cairosvg - svg_content = self.export_knowledge_graph_svg(project_id, entities, relations) - png_bytes = cairosvg.svg2png(bytestring=svg_content.encode("utf-8")) + svg_content = self.export_knowledge_graph_svg(project_id, entities, relations) + png_bytes = cairosvg.svg2png(bytestring = svg_content.encode("utf-8")) return png_bytes except ImportError: # 如果没有 cairosvg,返回 SVG 的 base64 - svg_content = self.export_knowledge_graph_svg(project_id, entities, relations) + svg_content = self.export_knowledge_graph_svg(project_id, entities, relations) return base64.b64encode(svg_content.encode("utf-8")) def export_entities_excel(self, entities: list[ExportEntity]) -> bytes: @@ -250,9 +250,9 @@ class ExportManager: raise ImportError("pandas is required for Excel export") # 准备数据 - data = [] + data = [] for e in entities: - row = { + row = { "ID": e.id, "名称": e.name, "类型": e.type, @@ -262,29 +262,29 @@ class ExportManager: } # 添加属性 for attr_name, attr_value in e.attributes.items(): - row[f"属性:{attr_name}"] = attr_value + row[f"属性:{attr_name}"] = attr_value data.append(row) - df = pd.DataFrame(data) + df = pd.DataFrame(data) # 写入 Excel - output = io.BytesIO() - with pd.ExcelWriter(output, engine="openpyxl") as writer: - df.to_excel(writer, sheet_name="实体列表", index=False) + output = io.BytesIO() + with pd.ExcelWriter(output, engine = "openpyxl") as writer: + df.to_excel(writer, sheet_name = "实体列表", index = False) # 调整列宽 - worksheet = writer.sheets["实体列表"] + worksheet = writer.sheets["实体列表"] for column in worksheet.columns: - max_length = 0 - column_letter = column[0].column_letter + max_length = 0 + column_letter = column[0].column_letter for cell in column: try: if len(str(cell.value)) > max_length: - max_length = len(str(cell.value)) + max_length = len(str(cell.value)) except (AttributeError, TypeError, ValueError): pass - adjusted_width = min(max_length + 2, 50) - worksheet.column_dimensions[column_letter].width = adjusted_width + adjusted_width = min(max_length + 2, 50) + worksheet.column_dimensions[column_letter].width = adjusted_width return output.getvalue() @@ -295,24 +295,24 @@ class ExportManager: Returns: CSV 字符串 """ - output = io.StringIO() + output = io.StringIO() # 收集所有可能的属性列 - all_attrs = set() + all_attrs = set() for e in entities: all_attrs.update(e.attributes.keys()) # 表头 - headers = ["ID", "名称", "类型", "定义", "别名", "提及次数"] + [ + headers = ["ID", "名称", "类型", "定义", "别名", "提及次数"] + [ f"属性:{a}" for a in sorted(all_attrs) ] - writer = csv.writer(output) + writer = csv.writer(output) writer.writerow(headers) # 数据行 for e in entities: - row = [e.id, e.name, e.type, e.definition, ", ".join(e.aliases), e.mention_count] + row = [e.id, e.name, e.type, e.definition, ", ".join(e.aliases), e.mention_count] for attr in sorted(all_attrs): row.append(e.attributes.get(attr, "")) writer.writerow(row) @@ -327,8 +327,8 @@ class ExportManager: CSV 字符串 """ - output = io.StringIO() - writer = csv.writer(output) + output = io.StringIO() + writer = csv.writer(output) writer.writerow(["ID", "源实体", "目标实体", "关系类型", "置信度", "证据"]) for r in relations: @@ -345,7 +345,7 @@ class ExportManager: Returns: Markdown 字符串 """ - lines = [ + lines = [ f"# {transcript.name}", "", f"**类型**: {transcript.type}", @@ -369,10 +369,10 @@ class ExportManager: ] ) for seg in transcript.segments: - speaker = seg.get("speaker", "Unknown") - start = seg.get("start", 0) - end = seg.get("end", 0) - text = seg.get("text", "") + speaker = seg.get("speaker", "Unknown") + start = seg.get("start", 0) + end = seg.get("end", 0) + text = seg.get("text", "") lines.append(f"**[{start:.1f}s - {end:.1f}s] {speaker}**: {text}") lines.append("") @@ -387,12 +387,12 @@ class ExportManager: ] ) for mention in transcript.entity_mentions: - entity_id = mention.get("entity_id", "") - entity = entities_map.get(entity_id) - entity_name = entity.name if entity else mention.get("entity_name", "Unknown") - entity_type = entity.type if entity else "Unknown" - position = mention.get("position", "") - context = mention.get("context", "")[:50] + "..." if mention.get("context") else "" + entity_id = mention.get("entity_id", "") + entity = entities_map.get(entity_id) + entity_name = entity.name if entity else mention.get("entity_name", "Unknown") + entity_type = entity.type if entity else "Unknown" + position = mention.get("position", "") + context = mention.get("context", "")[:50] + "..." if mention.get("context") else "" lines.append(f"| {entity_name} | {entity_type} | {position} | {context} |") return "\n".join(lines) @@ -404,7 +404,7 @@ class ExportManager: entities: list[ExportEntity], relations: list[ExportRelation], transcripts: list[ExportTranscript], - summary: str = "", + summary: str = "", ) -> bytes: """ 导出项目报告为 PDF 格式 @@ -415,29 +415,29 @@ class ExportManager: if not REPORTLAB_AVAILABLE: raise ImportError("reportlab is required for PDF export") - output = io.BytesIO() - doc = SimpleDocTemplate( - output, pagesize=A4, rightMargin=72, leftMargin=72, topMargin=72, bottomMargin=18 + output = io.BytesIO() + doc = SimpleDocTemplate( + output, pagesize = A4, rightMargin = 72, leftMargin = 72, topMargin = 72, bottomMargin = 18 ) # 样式 - styles = getSampleStyleSheet() - title_style = ParagraphStyle( + styles = getSampleStyleSheet() + title_style = ParagraphStyle( "CustomTitle", - parent=styles["Heading1"], - fontSize=24, - spaceAfter=30, - textColor=colors.HexColor("#2c3e50"), + parent = styles["Heading1"], + fontSize = 24, + spaceAfter = 30, + textColor = colors.HexColor("#2c3e50"), ) - heading_style = ParagraphStyle( + heading_style = ParagraphStyle( "CustomHeading", - parent=styles["Heading2"], - fontSize=16, - spaceAfter=12, - textColor=colors.HexColor("#34495e"), + parent = styles["Heading2"], + fontSize = 16, + spaceAfter = 12, + textColor = colors.HexColor("#34495e"), ) - story = [] + story = [] # 标题页 story.append(Paragraph("InsightFlow 项目报告", title_style)) @@ -452,7 +452,7 @@ class ExportManager: # 统计概览 story.append(Paragraph("项目概览", heading_style)) - stats_data = [ + stats_data = [ ["指标", "数值"], ["实体数量", str(len(entities))], ["关系数量", str(len(relations))], @@ -460,14 +460,14 @@ class ExportManager: ] # 按类型统计实体 - type_counts = {} + type_counts = {} for e in entities: - type_counts[e.type] = type_counts.get(e.type, 0) + 1 + type_counts[e.type] = type_counts.get(e.type, 0) + 1 for etype, count in sorted(type_counts.items()): stats_data.append([f"{etype} 实体", str(count)]) - stats_table = Table(stats_data, colWidths=[3 * inch, 2 * inch]) + stats_table = Table(stats_data, colWidths = [3 * inch, 2 * inch]) stats_table.setStyle( TableStyle( [ @@ -496,8 +496,8 @@ class ExportManager: story.append(PageBreak()) story.append(Paragraph("实体列表", heading_style)) - entity_data = [["名称", "类型", "提及次数", "定义"]] - for e in sorted(entities, key=lambda x: x.mention_count, reverse=True)[ + entity_data = [["名称", "类型", "提及次数", "定义"]] + for e in sorted(entities, key = lambda x: x.mention_count, reverse = True)[ :50 ]: # 限制前50个 entity_data.append( @@ -509,8 +509,8 @@ class ExportManager: ] ) - entity_table = Table( - entity_data, colWidths=[1.5 * inch, 1 * inch, 1 * inch, 2.5 * inch] + entity_table = Table( + entity_data, colWidths = [1.5 * inch, 1 * inch, 1 * inch, 2.5 * inch] ) entity_table.setStyle( TableStyle( @@ -534,12 +534,12 @@ class ExportManager: story.append(PageBreak()) story.append(Paragraph("关系列表", heading_style)) - relation_data = [["源实体", "关系", "目标实体", "置信度"]] + relation_data = [["源实体", "关系", "目标实体", "置信度"]] for r in relations[:100]: # 限制前100个 relation_data.append([r.source, r.relation_type, r.target, f"{r.confidence:.2f}"]) - relation_table = Table( - relation_data, colWidths=[2 * inch, 1.5 * inch, 2 * inch, 1 * inch] + relation_table = Table( + relation_data, colWidths = [2 * inch, 1.5 * inch, 2 * inch, 1 * inch] ) relation_table.setStyle( TableStyle( @@ -574,7 +574,7 @@ class ExportManager: Returns: JSON 字符串 """ - data = { + data = { "project_id": project_id, "project_name": project_name, "export_time": datetime.now().isoformat(), @@ -613,16 +613,16 @@ class ExportManager: ], } - return json.dumps(data, ensure_ascii=False, indent=2) + return json.dumps(data, ensure_ascii = False, indent = 2) # 全局导出管理器实例 -_export_manager = None +_export_manager = None -def get_export_manager(db_manager=None) -> None: +def get_export_manager(db_manager = None) -> None: """获取导出管理器实例""" global _export_manager if _export_manager is None: - _export_manager = ExportManager(db_manager) + _export_manager = ExportManager(db_manager) return _export_manager diff --git a/backend/growth_manager.py b/backend/growth_manager.py index 0d71ab3..ffcae8f 100644 --- a/backend/growth_manager.py +++ b/backend/growth_manager.py @@ -26,88 +26,88 @@ from typing import Any import httpx # Database path -DB_PATH = os.path.join(os.path.dirname(__file__), "insightflow.db") +DB_PATH = os.path.join(os.path.dirname(__file__), "insightflow.db") class EventType(StrEnum): """事件类型""" - PAGE_VIEW = "page_view" # 页面浏览 - FEATURE_USE = "feature_use" # 功能使用 - CONVERSION = "conversion" # 转化 - SIGNUP = "signup" # 注册 - LOGIN = "login" # 登录 - UPGRADE = "upgrade" # 升级 - DOWNGRADE = "downgrade" # 降级 - CANCEL = "cancel" # 取消订阅 - INVITE_SENT = "invite_sent" # 发送邀请 - INVITE_ACCEPTED = "invite_accepted" # 接受邀请 - REFERRAL_REWARD = "referral_reward" # 推荐奖励 + PAGE_VIEW = "page_view" # 页面浏览 + FEATURE_USE = "feature_use" # 功能使用 + CONVERSION = "conversion" # 转化 + SIGNUP = "signup" # 注册 + LOGIN = "login" # 登录 + UPGRADE = "upgrade" # 升级 + DOWNGRADE = "downgrade" # 降级 + CANCEL = "cancel" # 取消订阅 + INVITE_SENT = "invite_sent" # 发送邀请 + INVITE_ACCEPTED = "invite_accepted" # 接受邀请 + REFERRAL_REWARD = "referral_reward" # 推荐奖励 class ExperimentStatus(StrEnum): """实验状态""" - DRAFT = "draft" # 草稿 - RUNNING = "running" # 运行中 - PAUSED = "paused" # 暂停 - COMPLETED = "completed" # 已完成 - ARCHIVED = "archived" # 已归档 + DRAFT = "draft" # 草稿 + RUNNING = "running" # 运行中 + PAUSED = "paused" # 暂停 + COMPLETED = "completed" # 已完成 + ARCHIVED = "archived" # 已归档 class TrafficAllocationType(StrEnum): """流量分配类型""" - RANDOM = "random" # 随机分配 - STRATIFIED = "stratified" # 分层分配 - TARGETED = "targeted" # 定向分配 + RANDOM = "random" # 随机分配 + STRATIFIED = "stratified" # 分层分配 + TARGETED = "targeted" # 定向分配 class EmailTemplateType(StrEnum): """邮件模板类型""" - WELCOME = "welcome" # 欢迎邮件 - ONBOARDING = "onboarding" # 引导邮件 - FEATURE_ANNOUNCEMENT = "feature_announcement" # 功能公告 - CHURN_RECOVERY = "churn_recovery" # 流失挽回 - UPGRADE_PROMPT = "upgrade_prompt" # 升级提示 - REFERRAL = "referral" # 推荐邀请 - NEWSLETTER = "newsletter" # 新闻通讯 + WELCOME = "welcome" # 欢迎邮件 + ONBOARDING = "onboarding" # 引导邮件 + FEATURE_ANNOUNCEMENT = "feature_announcement" # 功能公告 + CHURN_RECOVERY = "churn_recovery" # 流失挽回 + UPGRADE_PROMPT = "upgrade_prompt" # 升级提示 + REFERRAL = "referral" # 推荐邀请 + NEWSLETTER = "newsletter" # 新闻通讯 class EmailStatus(StrEnum): """邮件状态""" - DRAFT = "draft" # 草稿 - SCHEDULED = "scheduled" # 已计划 - SENDING = "sending" # 发送中 - SENT = "sent" # 已发送 - DELIVERED = "delivered" # 已送达 - OPENED = "opened" # 已打开 - CLICKED = "clicked" # 已点击 - BOUNCED = "bounced" # 退信 - FAILED = "failed" # 失败 + DRAFT = "draft" # 草稿 + SCHEDULED = "scheduled" # 已计划 + SENDING = "sending" # 发送中 + SENT = "sent" # 已发送 + DELIVERED = "delivered" # 已送达 + OPENED = "opened" # 已打开 + CLICKED = "clicked" # 已点击 + BOUNCED = "bounced" # 退信 + FAILED = "failed" # 失败 class WorkflowTriggerType(StrEnum): """工作流触发类型""" - USER_SIGNUP = "user_signup" # 用户注册 - USER_LOGIN = "user_login" # 用户登录 - SUBSCRIPTION_CREATED = "subscription_created" # 创建订阅 - SUBSCRIPTION_CANCELLED = "subscription_cancelled" # 取消订阅 - INACTIVITY = "inactivity" # 不活跃 - MILESTONE = "milestone" # 里程碑 - CUSTOM_EVENT = "custom_event" # 自定义事件 + USER_SIGNUP = "user_signup" # 用户注册 + USER_LOGIN = "user_login" # 用户登录 + SUBSCRIPTION_CREATED = "subscription_created" # 创建订阅 + SUBSCRIPTION_CANCELLED = "subscription_cancelled" # 取消订阅 + INACTIVITY = "inactivity" # 不活跃 + MILESTONE = "milestone" # 里程碑 + CUSTOM_EVENT = "custom_event" # 自定义事件 class ReferralStatus(StrEnum): """推荐状态""" - PENDING = "pending" # 待处理 - CONVERTED = "converted" # 已转化 - REWARDED = "rewarded" # 已奖励 - EXPIRED = "expired" # 已过期 + PENDING = "pending" # 待处理 + CONVERTED = "converted" # 已转化 + REWARDED = "rewarded" # 已奖励 + EXPIRED = "expired" # 已过期 @dataclass @@ -362,17 +362,17 @@ class TeamIncentive: class GrowthManager: """运营与增长管理主类""" - def __init__(self, db_path: str = DB_PATH): - self.db_path = db_path - self.mixpanel_token = os.getenv("MIXPANEL_TOKEN", "") - self.amplitude_api_key = os.getenv("AMPLITUDE_API_KEY", "") - self.segment_write_key = os.getenv("SEGMENT_WRITE_KEY", "") - self.sendgrid_api_key = os.getenv("SENDGRID_API_KEY", "") + def __init__(self, db_path: str = DB_PATH) -> None: + self.db_path = db_path + self.mixpanel_token = os.getenv("MIXPANEL_TOKEN", "") + self.amplitude_api_key = os.getenv("AMPLITUDE_API_KEY", "") + self.segment_write_key = os.getenv("SEGMENT_WRITE_KEY", "") + self.sendgrid_api_key = os.getenv("SENDGRID_API_KEY", "") def _get_db(self) -> None: """获取数据库连接""" - conn = sqlite3.connect(self.db_path) - conn.row_factory = sqlite3.Row + conn = sqlite3.connect(self.db_path) + conn.row_factory = sqlite3.Row return conn # ==================== 用户行为分析 ==================== @@ -383,30 +383,30 @@ class GrowthManager: user_id: str, event_type: EventType, event_name: str, - properties: dict = None, - session_id: str = None, - device_info: dict = None, - referrer: str = None, - utm_params: dict = None, + properties: dict = None, + session_id: str = None, + device_info: dict = None, + referrer: str = None, + utm_params: dict = None, ) -> AnalyticsEvent: """追踪事件""" - event_id = f"evt_{uuid.uuid4().hex[:16]}" - now = datetime.now() + event_id = f"evt_{uuid.uuid4().hex[:16]}" + now = datetime.now() - event = AnalyticsEvent( - id=event_id, - tenant_id=tenant_id, - user_id=user_id, - event_type=event_type, - event_name=event_name, - properties=properties or {}, - timestamp=now, - session_id=session_id, - device_info=device_info or {}, - referrer=referrer, - utm_source=utm_params.get("source") if utm_params else None, - utm_medium=utm_params.get("medium") if utm_params else None, - utm_campaign=utm_params.get("campaign") if utm_params else None, + event = AnalyticsEvent( + id = event_id, + tenant_id = tenant_id, + user_id = user_id, + event_type = event_type, + event_name = event_name, + properties = properties or {}, + timestamp = now, + session_id = session_id, + device_info = device_info or {}, + referrer = referrer, + utm_source = utm_params.get("source") if utm_params else None, + utm_medium = utm_params.get("medium") if utm_params else None, + utm_campaign = utm_params.get("campaign") if utm_params else None, ) with self._get_db() as conn: @@ -443,9 +443,9 @@ class GrowthManager: return event - async def _send_to_analytics_platforms(self, event: AnalyticsEvent): + async def _send_to_analytics_platforms(self, event: AnalyticsEvent) -> None: """发送事件到第三方分析平台""" - tasks = [] + tasks = [] if self.mixpanel_token: tasks.append(self._send_to_mixpanel(event)) @@ -453,17 +453,17 @@ class GrowthManager: tasks.append(self._send_to_amplitude(event)) if tasks: - await asyncio.gather(*tasks, return_exceptions=True) + await asyncio.gather(*tasks, return_exceptions = True) - async def _send_to_mixpanel(self, event: AnalyticsEvent): + async def _send_to_mixpanel(self, event: AnalyticsEvent) -> None: """发送事件到 Mixpanel""" try: - headers = { + headers = { "Content-Type": "application/json", "Authorization": f"Basic {self.mixpanel_token}", } - payload = { + payload = { "event": event.event_name, "properties": { "distinct_id": event.user_id, @@ -475,17 +475,17 @@ class GrowthManager: async with httpx.AsyncClient() as client: await client.post( - "https://api.mixpanel.com/track", headers=headers, json=[payload], timeout=10.0 + "https://api.mixpanel.com/track", headers = headers, json = [payload], timeout = 10.0 ) except (RuntimeError, ValueError, TypeError) as e: print(f"Failed to send to Mixpanel: {e}") - async def _send_to_amplitude(self, event: AnalyticsEvent): + async def _send_to_amplitude(self, event: AnalyticsEvent) -> None: """发送事件到 Amplitude""" try: - headers = {"Content-Type": "application/json"} + headers = {"Content-Type": "application/json"} - payload = { + payload = { "api_key": self.amplitude_api_key, "events": [ { @@ -501,45 +501,45 @@ class GrowthManager: async with httpx.AsyncClient() as client: await client.post( "https://api.amplitude.com/2/httpapi", - headers=headers, - json=payload, - timeout=10.0, + headers = headers, + json = payload, + timeout = 10.0, ) except (RuntimeError, ValueError, TypeError) as e: print(f"Failed to send to Amplitude: {e}") async def _update_user_profile( self, tenant_id: str, user_id: str, event_type: EventType, event_name: str - ): + ) -> None: """更新用户画像""" with self._get_db() as conn: # 检查用户画像是否存在 - row = conn.execute( - "SELECT * FROM user_profiles WHERE tenant_id = ? AND user_id = ?", + row = conn.execute( + "SELECT * FROM user_profiles WHERE tenant_id = ? AND user_id = ?", (tenant_id, user_id), ).fetchone() - now = datetime.now().isoformat() + now = datetime.now().isoformat() if row: # 更新现有画像 - feature_usage = json.loads(row["feature_usage"]) + feature_usage = json.loads(row["feature_usage"]) if event_name not in feature_usage: - feature_usage[event_name] = 0 + feature_usage[event_name] = 0 feature_usage[event_name] += 1 conn.execute( """ UPDATE user_profiles - SET last_seen = ?, total_events = total_events + 1, - feature_usage = ?, updated_at = ? - WHERE id = ? + SET last_seen = ?, total_events = total_events + 1, + feature_usage = ?, updated_at = ? + WHERE id = ? """, (now, json.dumps(feature_usage), now, row["id"]), ) else: # 创建新画像 - profile_id = f"up_{uuid.uuid4().hex[:16]}" + profile_id = f"up_{uuid.uuid4().hex[:16]}" conn.execute( """ INSERT INTO user_profiles @@ -571,8 +571,8 @@ class GrowthManager: def get_user_profile(self, tenant_id: str, user_id: str) -> UserProfile | None: """获取用户画像""" with self._get_db() as conn: - row = conn.execute( - "SELECT * FROM user_profiles WHERE tenant_id = ? AND user_id = ?", + row = conn.execute( + "SELECT * FROM user_profiles WHERE tenant_id = ? AND user_id = ?", (tenant_id, user_id), ).fetchone() @@ -581,20 +581,20 @@ class GrowthManager: return None def get_user_analytics_summary( - self, tenant_id: str, start_date: datetime = None, end_date: datetime = None + self, tenant_id: str, start_date: datetime = None, end_date: datetime = None ) -> dict: """获取用户分析汇总""" with self._get_db() as conn: - query = """ + query = """ SELECT COUNT(DISTINCT user_id) as unique_users, COUNT(*) as total_events, COUNT(DISTINCT session_id) as total_sessions, COUNT(DISTINCT date(timestamp)) as active_days FROM analytics_events - WHERE tenant_id = ? + WHERE tenant_id = ? """ - params = [tenant_id] + params = [tenant_id] if start_date: query += " AND timestamp >= ?" @@ -603,15 +603,15 @@ class GrowthManager: query += " AND timestamp <= ?" params.append(end_date.isoformat()) - row = conn.execute(query, params).fetchone() + row = conn.execute(query, params).fetchone() # 获取事件类型分布 - type_query = """ + type_query = """ SELECT event_type, COUNT(*) as count FROM analytics_events - WHERE tenant_id = ? + WHERE tenant_id = ? """ - type_params = [tenant_id] + type_params = [tenant_id] if start_date: type_query += " AND timestamp >= ?" @@ -622,7 +622,7 @@ class GrowthManager: type_query += " GROUP BY event_type" - type_rows = conn.execute(type_query, type_params).fetchall() + type_rows = conn.execute(type_query, type_params).fetchall() return { "unique_users": row["unique_users"], @@ -638,17 +638,17 @@ class GrowthManager: self, tenant_id: str, name: str, description: str, steps: list[dict], created_by: str ) -> Funnel: """创建转化漏斗""" - funnel_id = f"fnl_{uuid.uuid4().hex[:16]}" - now = datetime.now().isoformat() + funnel_id = f"fnl_{uuid.uuid4().hex[:16]}" + now = datetime.now().isoformat() - funnel = Funnel( - id=funnel_id, - tenant_id=tenant_id, - name=name, - description=description, - steps=steps, - created_at=now, - updated_at=now, + funnel = Funnel( + id = funnel_id, + tenant_id = tenant_id, + name = name, + description = description, + steps = steps, + created_at = now, + updated_at = now, ) with self._get_db() as conn: @@ -673,46 +673,46 @@ class GrowthManager: return funnel def analyze_funnel( - self, funnel_id: str, period_start: datetime = None, period_end: datetime = None + self, funnel_id: str, period_start: datetime = None, period_end: datetime = None ) -> FunnelAnalysis | None: """分析漏斗转化率""" with self._get_db() as conn: - funnel_row = conn.execute("SELECT * FROM funnels WHERE id = ?", (funnel_id,)).fetchone() + funnel_row = conn.execute("SELECT * FROM funnels WHERE id = ?", (funnel_id, )).fetchone() if not funnel_row: return None - steps = json.loads(funnel_row["steps"]) + steps = json.loads(funnel_row["steps"]) if not period_start: - period_start = datetime.now() - timedelta(days=30) + period_start = datetime.now() - timedelta(days = 30) if not period_end: - period_end = datetime.now() + period_end = datetime.now() # 计算每步转化 - step_conversions = [] - previous_count = None + step_conversions = [] + previous_count = None for step in steps: - event_name = step.get("event_name") + event_name = step.get("event_name") - query = """ + query = """ SELECT COUNT(DISTINCT user_id) as user_count FROM analytics_events - WHERE event_name = ? AND timestamp >= ? AND timestamp <= ? + WHERE event_name = ? AND timestamp >= ? AND timestamp <= ? """ - row = conn.execute( + row = conn.execute( query, (event_name, period_start.isoformat(), period_end.isoformat()) ).fetchone() - user_count = row["user_count"] if row else 0 + user_count = row["user_count"] if row else 0 - conversion_rate = 0.0 - drop_off_rate = 0.0 + conversion_rate = 0.0 + drop_off_rate = 0.0 if previous_count and previous_count > 0: - conversion_rate = user_count / previous_count - drop_off_rate = 1 - conversion_rate + conversion_rate = user_count / previous_count + drop_off_rate = 1 - conversion_rate step_conversions.append( { @@ -724,79 +724,79 @@ class GrowthManager: } ) - previous_count = user_count + previous_count = user_count # 计算总体转化率 if steps and step_conversions: - first_step_count = step_conversions[0]["user_count"] - last_step_count = step_conversions[-1]["user_count"] - overall_conversion = last_step_count / max(first_step_count, 1) + first_step_count = step_conversions[0]["user_count"] + last_step_count = step_conversions[-1]["user_count"] + overall_conversion = last_step_count / max(first_step_count, 1) else: - overall_conversion = 0.0 + overall_conversion = 0.0 # 找出主要流失点 - drop_off_points = [ + drop_off_points = [ s for s in step_conversions if s["drop_off_rate"] > 0.2 and s != step_conversions[0] ] return FunnelAnalysis( - funnel_id=funnel_id, - period_start=period_start, - period_end=period_end, - total_users=step_conversions[0]["user_count"] if step_conversions else 0, - step_conversions=step_conversions, - overall_conversion=round(overall_conversion, 4), - drop_off_points=drop_off_points, + funnel_id = funnel_id, + period_start = period_start, + period_end = period_end, + total_users = step_conversions[0]["user_count"] if step_conversions else 0, + step_conversions = step_conversions, + overall_conversion = round(overall_conversion, 4), + drop_off_points = drop_off_points, ) def calculate_retention( - self, tenant_id: str, cohort_date: datetime, periods: list[int] = None + self, tenant_id: str, cohort_date: datetime, periods: list[int] = None ) -> dict: """计算留存率""" if periods is None: - periods = [1, 3, 7, 14, 30] + periods = [1, 3, 7, 14, 30] with self._get_db() as conn: # 获取同期群用户(在 cohort_date 当天首次活跃的用户) - cohort_query = """ + cohort_query = """ SELECT DISTINCT user_id FROM analytics_events - WHERE tenant_id = ? AND date(timestamp) = date(?) + WHERE tenant_id = ? AND date(timestamp) = date(?) AND user_id IN ( SELECT user_id FROM user_profiles - WHERE tenant_id = ? AND date(first_seen) = date(?) + WHERE tenant_id = ? AND date(first_seen) = date(?) ) """ - cohort_rows = conn.execute( + cohort_rows = conn.execute( cohort_query, (tenant_id, cohort_date.isoformat(), tenant_id, cohort_date.isoformat()), ).fetchall() - cohort_users = {r["user_id"] for r in cohort_rows} - cohort_size = len(cohort_users) + cohort_users = {r["user_id"] for r in cohort_rows} + cohort_size = len(cohort_users) if cohort_size == 0: return {"cohort_date": cohort_date.isoformat(), "cohort_size": 0, "retention": {}} - retention_rates = {} + retention_rates = {} for period in periods: - period_date = cohort_date + timedelta(days=period) + period_date = cohort_date + timedelta(days = period) - active_query = """ + active_query = """ SELECT COUNT(DISTINCT user_id) as active_count FROM analytics_events - WHERE tenant_id = ? AND date(timestamp) = date(?) + WHERE tenant_id = ? AND date(timestamp) = date(?) AND user_id IN ({}) - """.format(",".join(["?" for _ in cohort_users])) + """.format(", ".join(["?" for _ in cohort_users])) - params = [tenant_id, period_date.isoformat()] + list(cohort_users) - row = conn.execute(active_query, params).fetchone() + params = [tenant_id, period_date.isoformat()] + list(cohort_users) + row = conn.execute(active_query, params).fetchone() - active_count = row["active_count"] if row else 0 - retention_rate = active_count / cohort_size + active_count = row["active_count"] if row else 0 + retention_rate = active_count / cohort_size - retention_rates[f"day_{period}"] = { + retention_rates[f"day_{period}"] = { "active_users": active_count, "retention_rate": round(retention_rate, 4), } @@ -821,34 +821,34 @@ class GrowthManager: target_audience: dict, primary_metric: str, secondary_metrics: list[str], - min_sample_size: int = 100, - confidence_level: float = 0.95, - created_by: str = None, + min_sample_size: int = 100, + confidence_level: float = 0.95, + created_by: str = None, ) -> Experiment: """创建 A/B 测试实验""" - experiment_id = f"exp_{uuid.uuid4().hex[:16]}" - now = datetime.now().isoformat() + experiment_id = f"exp_{uuid.uuid4().hex[:16]}" + now = datetime.now().isoformat() - experiment = Experiment( - id=experiment_id, - tenant_id=tenant_id, - name=name, - description=description, - hypothesis=hypothesis, - status=ExperimentStatus.DRAFT, - variants=variants, - traffic_allocation=traffic_allocation, - traffic_split=traffic_split, - target_audience=target_audience, - primary_metric=primary_metric, - secondary_metrics=secondary_metrics, - start_date=None, - end_date=None, - min_sample_size=min_sample_size, - confidence_level=confidence_level, - created_at=now, - updated_at=now, - created_by=created_by or "system", + experiment = Experiment( + id = experiment_id, + tenant_id = tenant_id, + name = name, + description = description, + hypothesis = hypothesis, + status = ExperimentStatus.DRAFT, + variants = variants, + traffic_allocation = traffic_allocation, + traffic_split = traffic_split, + target_audience = target_audience, + primary_metric = primary_metric, + secondary_metrics = secondary_metrics, + start_date = None, + end_date = None, + min_sample_size = min_sample_size, + confidence_level = confidence_level, + created_at = now, + updated_at = now, + created_by = created_by or "system", ) with self._get_db() as conn: @@ -890,42 +890,42 @@ class GrowthManager: def get_experiment(self, experiment_id: str) -> Experiment | None: """获取实验详情""" with self._get_db() as conn: - row = conn.execute( - "SELECT * FROM experiments WHERE id = ?", (experiment_id,) + row = conn.execute( + "SELECT * FROM experiments WHERE id = ?", (experiment_id, ) ).fetchone() if row: return self._row_to_experiment(row) return None - def list_experiments(self, tenant_id: str, status: ExperimentStatus = None) -> list[Experiment]: + def list_experiments(self, tenant_id: str, status: ExperimentStatus = None) -> list[Experiment]: """列出实验""" - query = "SELECT * FROM experiments WHERE tenant_id = ?" - params = [tenant_id] + query = "SELECT * FROM experiments WHERE tenant_id = ?" + params = [tenant_id] if status: - query += " AND status = ?" + query += " AND status = ?" params.append(status.value) query += " ORDER BY created_at DESC" with self._get_db() as conn: - rows = conn.execute(query, params).fetchall() + rows = conn.execute(query, params).fetchall() return [self._row_to_experiment(row) for row in rows] def assign_variant( - self, experiment_id: str, user_id: str, user_attributes: dict = None + self, experiment_id: str, user_id: str, user_attributes: dict = None ) -> str | None: """为用户分配实验变体""" - experiment = self.get_experiment(experiment_id) + experiment = self.get_experiment(experiment_id) if not experiment or experiment.status != ExperimentStatus.RUNNING: return None # 检查用户是否已分配 with self._get_db() as conn: - row = conn.execute( + row = conn.execute( """SELECT variant_id FROM experiment_assignments - WHERE experiment_id = ? AND user_id = ?""", + WHERE experiment_id = ? AND user_id = ?""", (experiment_id, user_id), ).fetchone() @@ -934,18 +934,18 @@ class GrowthManager: # 根据分配策略选择变体 if experiment.traffic_allocation == TrafficAllocationType.RANDOM: - variant_id = self._random_allocation(experiment.variants, experiment.traffic_split) + variant_id = self._random_allocation(experiment.variants, experiment.traffic_split) elif experiment.traffic_allocation == TrafficAllocationType.STRATIFIED: - variant_id = self._stratified_allocation( + variant_id = self._stratified_allocation( experiment.variants, experiment.traffic_split, user_attributes ) else: # TARGETED - variant_id = self._targeted_allocation( + variant_id = self._targeted_allocation( experiment.variants, experiment.target_audience, user_attributes ) if variant_id: - now = datetime.now().isoformat() + now = datetime.now().isoformat() conn.execute( """ INSERT INTO experiment_assignments @@ -967,13 +967,13 @@ class GrowthManager: def _random_allocation(self, variants: list[dict], traffic_split: dict[str, float]) -> str: """随机分配""" - variant_ids = [v["id"] for v in variants] - weights = [traffic_split.get(v_id, 1.0 / len(variants)) for v_id in variant_ids] + variant_ids = [v["id"] for v in variants] + weights = [traffic_split.get(v_id, 1.0 / len(variants)) for v_id in variant_ids] - total = sum(weights) - normalized_weights = [w / total for w in weights] + total = sum(weights) + normalized_weights = [w / total for w in weights] - return random.choices(variant_ids, weights=normalized_weights, k=1)[0] + return random.choices(variant_ids, weights = normalized_weights, k = 1)[0] def _stratified_allocation( self, variants: list[dict], traffic_split: dict[str, float], user_attributes: dict @@ -981,9 +981,9 @@ class GrowthManager: """分层分配(基于用户属性)""" # 简化的分层分配:根据用户 ID 哈希值分配 if user_attributes and "user_id" in user_attributes: - hash_value = int(hashlib.md5(user_attributes["user_id"].encode()).hexdigest(), 16) - variant_ids = [v["id"] for v in variants] - index = hash_value % len(variant_ids) + hash_value = int(hashlib.md5(user_attributes["user_id"].encode()).hexdigest(), 16) + variant_ids = [v["id"] for v in variants] + index = hash_value % len(variant_ids) return variant_ids[index] return self._random_allocation(variants, traffic_split) @@ -993,29 +993,29 @@ class GrowthManager: ) -> str | None: """定向分配(基于目标受众条件)""" # 检查用户是否符合目标受众条件 - conditions = target_audience.get("conditions", []) + conditions = target_audience.get("conditions", []) - matches = True + matches = True for condition in conditions: - attr_name = condition.get("attribute") - operator = condition.get("operator") - value = condition.get("value") + attr_name = condition.get("attribute") + operator = condition.get("operator") + value = condition.get("value") - user_value = user_attributes.get(attr_name) if user_attributes else None + user_value = user_attributes.get(attr_name) if user_attributes else None if operator == "equals" and user_value != value: - matches = False + matches = False break elif operator == "not_equals" and user_value == value: - matches = False + matches = False break elif operator == "in" and user_value not in value: - matches = False + matches = False break if not matches: # 用户不符合条件,返回对照组 - control_variant = next((v for v in variants if v.get("is_control")), variants[0]) + control_variant = next((v for v in variants if v.get("is_control")), variants[0]) return control_variant["id"] if control_variant else None return self._random_allocation(variants, target_audience.get("traffic_split", {})) @@ -1027,7 +1027,7 @@ class GrowthManager: user_id: str, metric_name: str, metric_value: float, - ): + ) -> None: """记录实验指标""" with self._get_db() as conn: conn.execute( @@ -1050,46 +1050,46 @@ class GrowthManager: def analyze_experiment(self, experiment_id: str) -> dict: """分析实验结果""" - experiment = self.get_experiment(experiment_id) + experiment = self.get_experiment(experiment_id) if not experiment: return {"error": "Experiment not found"} with self._get_db() as conn: - results = {} + results = {} for variant in experiment.variants: - variant_id = variant["id"] + variant_id = variant["id"] # 获取样本量 - sample_row = conn.execute( + sample_row = conn.execute( """ SELECT COUNT(DISTINCT user_id) as sample_size FROM experiment_assignments - WHERE experiment_id = ? AND variant_id = ? + WHERE experiment_id = ? AND variant_id = ? """, (experiment_id, variant_id), ).fetchone() - sample_size = sample_row["sample_size"] if sample_row else 0 + sample_size = sample_row["sample_size"] if sample_row else 0 # 获取主要指标统计 - metric_row = conn.execute( + metric_row = conn.execute( """ SELECT AVG(metric_value) as mean_value, COUNT(*) as metric_count, SUM(metric_value) as total_value FROM experiment_metrics - WHERE experiment_id = ? AND variant_id = ? AND metric_name = ? + WHERE experiment_id = ? AND variant_id = ? AND metric_name = ? """, (experiment_id, variant_id, experiment.primary_metric), ).fetchone() - mean_value = ( + mean_value = ( metric_row["mean_value"] if metric_row and metric_row["mean_value"] else 0 ) - results[variant_id] = { + results[variant_id] = { "variant_name": variant.get("name", variant_id), "is_control": variant.get("is_control", False), "sample_size": sample_size, @@ -1098,27 +1098,27 @@ class GrowthManager: } # 计算统计显著性(简化版) - control_variant = next((v for v in experiment.variants if v.get("is_control")), None) + control_variant = next((v for v in experiment.variants if v.get("is_control")), None) if control_variant: - control_id = control_variant["id"] - control_result = results.get(control_id, {}) + control_id = control_variant["id"] + control_result = results.get(control_id, {}) for variant_id, result in results.items(): if variant_id != control_id: - control_mean = control_result.get("mean_value", 0) - variant_mean = result.get("mean_value", 0) + control_mean = control_result.get("mean_value", 0) + variant_mean = result.get("mean_value", 0) if control_mean > 0: - uplift = (variant_mean - control_mean) / control_mean + uplift = (variant_mean - control_mean) / control_mean else: - uplift = 0 + uplift = 0 # 简化的显著性判断 - is_significant = abs(uplift) > 0.05 and result["sample_size"] > 100 + is_significant = abs(uplift) > 0.05 and result["sample_size"] > 100 - result["uplift"] = round(uplift, 4) - result["is_significant"] = is_significant - result["p_value"] = 0.05 if is_significant else 0.5 + result["uplift"] = round(uplift, 4) + result["is_significant"] = is_significant + result["p_value"] = 0.05 if is_significant else 0.5 return { "experiment_id": experiment_id, @@ -1131,12 +1131,12 @@ class GrowthManager: def start_experiment(self, experiment_id: str) -> Experiment | None: """启动实验""" with self._get_db() as conn: - now = datetime.now().isoformat() + now = datetime.now().isoformat() conn.execute( """ UPDATE experiments - SET status = ?, start_date = ?, updated_at = ? - WHERE id = ? AND status = ? + SET status = ?, start_date = ?, updated_at = ? + WHERE id = ? AND status = ? """, ( ExperimentStatus.RUNNING.value, @@ -1153,12 +1153,12 @@ class GrowthManager: def stop_experiment(self, experiment_id: str) -> Experiment | None: """停止实验""" with self._get_db() as conn: - now = datetime.now().isoformat() + now = datetime.now().isoformat() conn.execute( """ UPDATE experiments - SET status = ?, end_date = ?, updated_at = ? - WHERE id = ? AND status = ? + SET status = ?, end_date = ?, updated_at = ? + WHERE id = ? AND status = ? """, ( ExperimentStatus.COMPLETED.value, @@ -1181,36 +1181,36 @@ class GrowthManager: template_type: EmailTemplateType, subject: str, html_content: str, - text_content: str = None, - variables: list[str] = None, - from_name: str = None, - from_email: str = None, - reply_to: str = None, + text_content: str = None, + variables: list[str] = None, + from_name: str = None, + from_email: str = None, + reply_to: str = None, ) -> EmailTemplate: """创建邮件模板""" - template_id = f"et_{uuid.uuid4().hex[:16]}" - now = datetime.now().isoformat() + template_id = f"et_{uuid.uuid4().hex[:16]}" + now = datetime.now().isoformat() # 自动提取变量 if variables is None: - variables = re.findall(r"\{\{(\w+)\}\}", html_content) + variables = re.findall(r"\{\{(\w+)\}\}", html_content) - template = EmailTemplate( - id=template_id, - tenant_id=tenant_id, - name=name, - template_type=template_type, - subject=subject, - html_content=html_content, - text_content=text_content or re.sub(r"<[^>]+>", "", html_content), - variables=variables, - preview_text=None, - from_name=from_name or "InsightFlow", - from_email=from_email or "noreply@insightflow.io", - reply_to=reply_to, - is_active=True, - created_at=now, - updated_at=now, + template = EmailTemplate( + id = template_id, + tenant_id = tenant_id, + name = name, + template_type = template_type, + subject = subject, + html_content = html_content, + text_content = text_content or re.sub(r"<[^>]+>", "", html_content), + variables = variables, + preview_text = None, + from_name = from_name or "InsightFlow", + from_email = from_email or "noreply@insightflow.io", + reply_to = reply_to, + is_active = True, + created_at = now, + updated_at = now, ) with self._get_db() as conn: @@ -1245,8 +1245,8 @@ class GrowthManager: def get_email_template(self, template_id: str) -> EmailTemplate | None: """获取邮件模板""" with self._get_db() as conn: - row = conn.execute( - "SELECT * FROM email_templates WHERE id = ?", (template_id,) + row = conn.execute( + "SELECT * FROM email_templates WHERE id = ?", (template_id, ) ).fetchone() if row: @@ -1254,37 +1254,37 @@ class GrowthManager: return None def list_email_templates( - self, tenant_id: str, template_type: EmailTemplateType = None + self, tenant_id: str, template_type: EmailTemplateType = None ) -> list[EmailTemplate]: """列出邮件模板""" - query = "SELECT * FROM email_templates WHERE tenant_id = ? AND is_active = 1" - params = [tenant_id] + query = "SELECT * FROM email_templates WHERE tenant_id = ? AND is_active = 1" + params = [tenant_id] if template_type: - query += " AND template_type = ?" + query += " AND template_type = ?" params.append(template_type.value) query += " ORDER BY created_at DESC" with self._get_db() as conn: - rows = conn.execute(query, params).fetchall() + rows = conn.execute(query, params).fetchall() return [self._row_to_email_template(row) for row in rows] def render_template(self, template_id: str, variables: dict) -> dict[str, str]: """渲染邮件模板""" - template = self.get_email_template(template_id) + template = self.get_email_template(template_id) if not template: return None - subject = template.subject - html_content = template.html_content - text_content = template.text_content + subject = template.subject + html_content = template.html_content + text_content = template.text_content for key, value in variables.items(): - placeholder = f"{{{{{key}}}}}" - subject = subject.replace(placeholder, str(value)) - html_content = html_content.replace(placeholder, str(value)) - text_content = text_content.replace(placeholder, str(value)) + placeholder = f"{{{{{key}}}}}" + subject = subject.replace(placeholder, str(value)) + html_content = html_content.replace(placeholder, str(value)) + text_content = text_content.replace(placeholder, str(value)) return { "subject": subject, @@ -1301,29 +1301,29 @@ class GrowthManager: name: str, template_id: str, recipient_list: list[dict], - scheduled_at: datetime = None, + scheduled_at: datetime = None, ) -> EmailCampaign: """创建邮件营销活动""" - campaign_id = f"ec_{uuid.uuid4().hex[:16]}" - now = datetime.now().isoformat() + campaign_id = f"ec_{uuid.uuid4().hex[:16]}" + now = datetime.now().isoformat() - campaign = EmailCampaign( - id=campaign_id, - tenant_id=tenant_id, - name=name, - template_id=template_id, - status="draft", - recipient_count=len(recipient_list), - sent_count=0, - delivered_count=0, - opened_count=0, - clicked_count=0, - bounced_count=0, - failed_count=0, - scheduled_at=scheduled_at.isoformat() if scheduled_at else None, - started_at=None, - completed_at=None, - created_at=now, + campaign = EmailCampaign( + id = campaign_id, + tenant_id = tenant_id, + name = name, + template_id = template_id, + status = "draft", + recipient_count = len(recipient_list), + sent_count = 0, + delivered_count = 0, + opened_count = 0, + clicked_count = 0, + bounced_count = 0, + failed_count = 0, + scheduled_at = scheduled_at.isoformat() if scheduled_at else None, + started_at = None, + completed_at = None, + created_at = now, ) with self._get_db() as conn: @@ -1384,20 +1384,20 @@ class GrowthManager: self, campaign_id: str, user_id: str, email: str, template_id: str, variables: dict ) -> bool: """发送单封邮件""" - template = self.get_email_template(template_id) + template = self.get_email_template(template_id) if not template: return False - rendered = self.render_template(template_id, variables) + rendered = self.render_template(template_id, variables) # 更新状态为发送中 with self._get_db() as conn: - now = datetime.now().isoformat() + now = datetime.now().isoformat() conn.execute( """ UPDATE email_logs - SET status = ?, sent_at = ?, subject = ? - WHERE campaign_id = ? AND user_id = ? + SET status = ?, sent_at = ?, subject = ? + WHERE campaign_id = ? AND user_id = ? """, (EmailStatus.SENDING.value, now, rendered["subject"], campaign_id, user_id), ) @@ -1408,17 +1408,17 @@ class GrowthManager: # 目前使用模拟发送 await asyncio.sleep(0.1) - success = True # 模拟成功 + success = True # 模拟成功 # 更新状态 with self._get_db() as conn: - now = datetime.now().isoformat() + now = datetime.now().isoformat() if success: conn.execute( """ UPDATE email_logs - SET status = ?, delivered_at = ? - WHERE campaign_id = ? AND user_id = ? + SET status = ?, delivered_at = ? + WHERE campaign_id = ? AND user_id = ? """, (EmailStatus.DELIVERED.value, now, campaign_id, user_id), ) @@ -1426,8 +1426,8 @@ class GrowthManager: conn.execute( """ UPDATE email_logs - SET status = ?, error_message = ? - WHERE campaign_id = ? AND user_id = ? + SET status = ?, error_message = ? + WHERE campaign_id = ? AND user_id = ? """, (EmailStatus.FAILED.value, "Send failed", campaign_id, user_id), ) @@ -1440,8 +1440,8 @@ class GrowthManager: conn.execute( """ UPDATE email_logs - SET status = ?, error_message = ? - WHERE campaign_id = ? AND user_id = ? + SET status = ?, error_message = ? + WHERE campaign_id = ? AND user_id = ? """, (EmailStatus.FAILED.value, str(e), campaign_id, user_id), ) @@ -1451,37 +1451,37 @@ class GrowthManager: async def send_campaign(self, campaign_id: str) -> dict: """发送整个营销活动""" with self._get_db() as conn: - campaign_row = conn.execute( - "SELECT * FROM email_campaigns WHERE id = ?", (campaign_id,) + campaign_row = conn.execute( + "SELECT * FROM email_campaigns WHERE id = ?", (campaign_id, ) ).fetchone() if not campaign_row: return {"error": "Campaign not found"} # 获取待发送的邮件 - logs = conn.execute( + logs = conn.execute( """SELECT * FROM email_logs - WHERE campaign_id = ? AND status IN (?, ?)""", + WHERE campaign_id = ? AND status IN (?, ?)""", (campaign_id, EmailStatus.DRAFT.value, EmailStatus.SCHEDULED.value), ).fetchall() # 更新活动状态 - now = datetime.now().isoformat() + now = datetime.now().isoformat() conn.execute( - "UPDATE email_campaigns SET status = ?, started_at = ? WHERE id = ?", + "UPDATE email_campaigns SET status = ?, started_at = ? WHERE id = ?", ("sending", now, campaign_id), ) conn.commit() # 批量发送 - success_count = 0 - failed_count = 0 + success_count = 0 + failed_count = 0 for log in logs: # 获取用户变量 - variables = self._get_user_variables(log["tenant_id"], log["user_id"]) + variables = self._get_user_variables(log["tenant_id"], log["user_id"]) - success = await self.send_email( + success = await self.send_email( campaign_id, log["user_id"], log["email"], log["template_id"], variables ) @@ -1492,12 +1492,12 @@ class GrowthManager: # 更新活动状态 with self._get_db() as conn: - now = datetime.now().isoformat() + now = datetime.now().isoformat() conn.execute( """ UPDATE email_campaigns - SET status = ?, completed_at = ?, sent_count = ? - WHERE id = ? + SET status = ?, completed_at = ?, sent_count = ? + WHERE id = ? """, ("completed", now, success_count, campaign_id), ) @@ -1526,21 +1526,21 @@ class GrowthManager: actions: list[dict], ) -> AutomationWorkflow: """创建自动化工作流""" - workflow_id = f"aw_{uuid.uuid4().hex[:16]}" - now = datetime.now().isoformat() + workflow_id = f"aw_{uuid.uuid4().hex[:16]}" + now = datetime.now().isoformat() - workflow = AutomationWorkflow( - id=workflow_id, - tenant_id=tenant_id, - name=name, - description=description, - trigger_type=trigger_type, - trigger_conditions=trigger_conditions, - actions=actions, - is_active=True, - execution_count=0, - created_at=now, - updated_at=now, + workflow = AutomationWorkflow( + id = workflow_id, + tenant_id = tenant_id, + name = name, + description = description, + trigger_type = trigger_type, + trigger_conditions = trigger_conditions, + actions = actions, + is_active = True, + execution_count = 0, + created_at = now, + updated_at = now, ) with self._get_db() as conn: @@ -1569,17 +1569,17 @@ class GrowthManager: return workflow - async def trigger_workflow(self, workflow_id: str, event_data: dict): + async def trigger_workflow(self, workflow_id: str, event_data: dict) -> None: """触发自动化工作流""" with self._get_db() as conn: - row = conn.execute( - "SELECT * FROM automation_workflows WHERE id = ? AND is_active = 1", (workflow_id,) + row = conn.execute( + "SELECT * FROM automation_workflows WHERE id = ? AND is_active = 1", (workflow_id, ) ).fetchone() if not row: return False - workflow = self._row_to_automation_workflow(row) + workflow = self._row_to_automation_workflow(row) # 检查触发条件 if not self._check_trigger_conditions(workflow.trigger_conditions, event_data): @@ -1591,8 +1591,8 @@ class GrowthManager: # 更新执行计数 conn.execute( - "UPDATE automation_workflows SET execution_count = execution_count + 1 WHERE id = ?", - (workflow_id,), + "UPDATE automation_workflows SET execution_count = execution_count + 1 WHERE id = ?", + (workflow_id, ), ) conn.commit() @@ -1606,9 +1606,9 @@ class GrowthManager: return False return True - async def _execute_action(self, action: dict, event_data: dict): + async def _execute_action(self, action: dict, event_data: dict) -> None: """执行工作流动作""" - action_type = action.get("type") + action_type = action.get("type") if action_type == "send_email": action.get("template_id") @@ -1631,29 +1631,29 @@ class GrowthManager: referrer_reward_value: float, referee_reward_type: str, referee_reward_value: float, - max_referrals_per_user: int = 10, - referral_code_length: int = 8, - expiry_days: int = 30, + max_referrals_per_user: int = 10, + referral_code_length: int = 8, + expiry_days: int = 30, ) -> ReferralProgram: """创建推荐计划""" - program_id = f"rp_{uuid.uuid4().hex[:16]}" - now = datetime.now().isoformat() + program_id = f"rp_{uuid.uuid4().hex[:16]}" + now = datetime.now().isoformat() - program = ReferralProgram( - id=program_id, - tenant_id=tenant_id, - name=name, - description=description, - referrer_reward_type=referrer_reward_type, - referrer_reward_value=referrer_reward_value, - referee_reward_type=referee_reward_type, - referee_reward_value=referee_reward_value, - max_referrals_per_user=max_referrals_per_user, - referral_code_length=referral_code_length, - expiry_days=expiry_days, - is_active=True, - created_at=now, - updated_at=now, + program = ReferralProgram( + id = program_id, + tenant_id = tenant_id, + name = name, + description = description, + referrer_reward_type = referrer_reward_type, + referrer_reward_value = referrer_reward_value, + referee_reward_type = referee_reward_type, + referee_reward_value = referee_reward_value, + max_referrals_per_user = max_referrals_per_user, + referral_code_length = referral_code_length, + expiry_days = expiry_days, + is_active = True, + created_at = now, + updated_at = now, ) with self._get_db() as conn: @@ -1688,15 +1688,15 @@ class GrowthManager: def generate_referral_code(self, program_id: str, referrer_id: str) -> Referral: """生成推荐码""" - program = self._get_referral_program(program_id) + program = self._get_referral_program(program_id) if not program: return None # 检查推荐次数限制 with self._get_db() as conn: - count_row = conn.execute( + count_row = conn.execute( """SELECT COUNT(*) as count FROM referrals - WHERE program_id = ? AND referrer_id = ? AND status != ?""", + WHERE program_id = ? AND referrer_id = ? AND status != ?""", (program_id, referrer_id, ReferralStatus.EXPIRED.value), ).fetchone() @@ -1704,28 +1704,28 @@ class GrowthManager: return None # 生成推荐码 - referral_code = self._generate_unique_code(program.referral_code_length) + referral_code = self._generate_unique_code(program.referral_code_length) - referral_id = f"ref_{uuid.uuid4().hex[:16]}" - now = datetime.now() - expires_at = now + timedelta(days=program.expiry_days) + referral_id = f"ref_{uuid.uuid4().hex[:16]}" + now = datetime.now() + expires_at = now + timedelta(days = program.expiry_days) - referral = Referral( - id=referral_id, - program_id=program_id, - tenant_id=program.tenant_id, - referrer_id=referrer_id, - referee_id=None, - referral_code=referral_code, - status=ReferralStatus.PENDING, - referrer_rewarded=False, - referee_rewarded=False, - referrer_reward_value=program.referrer_reward_value, - referee_reward_value=program.referee_reward_value, - converted_at=None, - rewarded_at=None, - expires_at=expires_at, - created_at=now, + referral = Referral( + id = referral_id, + program_id = program_id, + tenant_id = program.tenant_id, + referrer_id = referrer_id, + referee_id = None, + referral_code = referral_code, + status = ReferralStatus.PENDING, + referrer_rewarded = False, + referee_rewarded = False, + referrer_reward_value = program.referrer_reward_value, + referee_reward_value = program.referee_reward_value, + converted_at = None, + rewarded_at = None, + expires_at = expires_at, + created_at = now, ) conn.execute( @@ -1760,13 +1760,13 @@ class GrowthManager: def _generate_unique_code(self, length: int) -> str: """生成唯一推荐码""" - chars = "ABCDEFGHJKLMNPQRSTUVWXYZ23456789" # 排除易混淆字符 + chars = "ABCDEFGHJKLMNPQRSTUVWXYZ23456789" # 排除易混淆字符 while True: - code = "".join(random.choices(chars, k=length)) + code = "".join(random.choices(chars, k = length)) with self._get_db() as conn: - row = conn.execute( - "SELECT 1 FROM referrals WHERE referral_code = ?", (code,) + row = conn.execute( + "SELECT 1 FROM referrals WHERE referral_code = ?", (code, ) ).fetchone() if not row: @@ -1775,8 +1775,8 @@ class GrowthManager: def _get_referral_program(self, program_id: str) -> ReferralProgram | None: """获取推荐计划""" with self._get_db() as conn: - row = conn.execute( - "SELECT * FROM referral_programs WHERE id = ?", (program_id,) + row = conn.execute( + "SELECT * FROM referral_programs WHERE id = ?", (program_id, ) ).fetchone() if row: @@ -1786,21 +1786,21 @@ class GrowthManager: def apply_referral_code(self, referral_code: str, referee_id: str) -> bool: """应用推荐码""" with self._get_db() as conn: - row = conn.execute( + row = conn.execute( """SELECT * FROM referrals - WHERE referral_code = ? AND status = ? AND expires_at > ?""", + WHERE referral_code = ? AND status = ? AND expires_at > ?""", (referral_code, ReferralStatus.PENDING.value, datetime.now().isoformat()), ).fetchone() if not row: return False - now = datetime.now().isoformat() + now = datetime.now().isoformat() conn.execute( """ UPDATE referrals - SET referee_id = ?, status = ?, converted_at = ? - WHERE id = ? + SET referee_id = ?, status = ?, converted_at = ? + WHERE id = ? """, (referee_id, ReferralStatus.CONVERTED.value, now, row["id"]), ) @@ -1811,17 +1811,17 @@ class GrowthManager: def reward_referral(self, referral_id: str) -> bool: """发放推荐奖励""" with self._get_db() as conn: - row = conn.execute("SELECT * FROM referrals WHERE id = ?", (referral_id,)).fetchone() + row = conn.execute("SELECT * FROM referrals WHERE id = ?", (referral_id, )).fetchone() if not row or row["status"] != ReferralStatus.CONVERTED.value: return False - now = datetime.now().isoformat() + now = datetime.now().isoformat() conn.execute( """ UPDATE referrals - SET status = ?, referrer_rewarded = 1, referee_rewarded = 1, rewarded_at = ? - WHERE id = ? + SET status = ?, referrer_rewarded = 1, referee_rewarded = 1, rewarded_at = ? + WHERE id = ? """, (ReferralStatus.REWARDED.value, now, referral_id), ) @@ -1832,17 +1832,17 @@ class GrowthManager: def get_referral_stats(self, program_id: str) -> dict: """获取推荐统计""" with self._get_db() as conn: - stats = conn.execute( + stats = conn.execute( """ SELECT COUNT(*) as total_referrals, - SUM(CASE WHEN status = ? THEN 1 ELSE 0 END) as pending, - SUM(CASE WHEN status = ? THEN 1 ELSE 0 END) as converted, - SUM(CASE WHEN status = ? THEN 1 ELSE 0 END) as rewarded, - SUM(CASE WHEN status = ? THEN 1 ELSE 0 END) as expired, + SUM(CASE WHEN status = ? THEN 1 ELSE 0 END) as pending, + SUM(CASE WHEN status = ? THEN 1 ELSE 0 END) as converted, + SUM(CASE WHEN status = ? THEN 1 ELSE 0 END) as rewarded, + SUM(CASE WHEN status = ? THEN 1 ELSE 0 END) as expired, COUNT(DISTINCT referrer_id) as unique_referrers FROM referrals - WHERE program_id = ? + WHERE program_id = ? """, ( ReferralStatus.PENDING.value, @@ -1879,22 +1879,22 @@ class GrowthManager: valid_until: datetime, ) -> TeamIncentive: """创建团队升级激励""" - incentive_id = f"ti_{uuid.uuid4().hex[:16]}" - now = datetime.now().isoformat() + incentive_id = f"ti_{uuid.uuid4().hex[:16]}" + now = datetime.now().isoformat() - incentive = TeamIncentive( - id=incentive_id, - tenant_id=tenant_id, - name=name, - description=description, - target_tier=target_tier, - min_team_size=min_team_size, - incentive_type=incentive_type, - incentive_value=incentive_value, - valid_from=valid_from.isoformat(), - valid_until=valid_until.isoformat(), - is_active=True, - created_at=now, + incentive = TeamIncentive( + id = incentive_id, + tenant_id = tenant_id, + name = name, + description = description, + target_tier = target_tier, + min_team_size = min_team_size, + incentive_type = incentive_type, + incentive_value = incentive_value, + valid_from = valid_from.isoformat(), + valid_until = valid_until.isoformat(), + is_active = True, + created_at = now, ) with self._get_db() as conn: @@ -1929,12 +1929,12 @@ class GrowthManager: ) -> list[TeamIncentive]: """检查团队激励资格""" with self._get_db() as conn: - now = datetime.now().isoformat() - rows = conn.execute( + now = datetime.now().isoformat() + rows = conn.execute( """ SELECT * FROM team_incentives - WHERE tenant_id = ? AND is_active = 1 - AND target_tier = ? AND min_team_size <= ? + WHERE tenant_id = ? AND is_active = 1 + AND target_tier = ? AND min_team_size <= ? AND valid_from <= ? AND valid_until >= ? """, (tenant_id, current_tier, team_size, now, now), @@ -1946,41 +1946,41 @@ class GrowthManager: def get_realtime_dashboard(self, tenant_id: str) -> dict: """获取实时分析仪表板数据""" - now = datetime.now() - today_start = now.replace(hour=0, minute=0, second=0, microsecond=0) + now = datetime.now() + today_start = now.replace(hour = 0, minute = 0, second = 0, microsecond = 0) with self._get_db() as conn: # 今日统计 - today_stats = conn.execute( + today_stats = conn.execute( """ SELECT COUNT(DISTINCT user_id) as active_users, COUNT(*) as total_events, COUNT(DISTINCT session_id) as sessions FROM analytics_events - WHERE tenant_id = ? AND timestamp >= ? + WHERE tenant_id = ? AND timestamp >= ? """, (tenant_id, today_start.isoformat()), ).fetchone() # 最近事件 - recent_events = conn.execute( + recent_events = conn.execute( """ SELECT event_name, event_type, timestamp, user_id FROM analytics_events - WHERE tenant_id = ? + WHERE tenant_id = ? ORDER BY timestamp DESC LIMIT 20 """, - (tenant_id,), + (tenant_id, ), ).fetchall() # 热门功能 - top_features = conn.execute( + top_features = conn.execute( """ SELECT event_name, COUNT(*) as count FROM analytics_events - WHERE tenant_id = ? AND timestamp >= ? AND event_type = ? + WHERE tenant_id = ? AND timestamp >= ? AND event_type = ? GROUP BY event_name ORDER BY count DESC LIMIT 10 @@ -1989,16 +1989,16 @@ class GrowthManager: ).fetchall() # 活跃用户趋势(最近24小时,每小时) - hourly_trend = [] + hourly_trend = [] for i in range(24): - hour_start = now - timedelta(hours=i + 1) - hour_end = now - timedelta(hours=i) + hour_start = now - timedelta(hours = i + 1) + hour_end = now - timedelta(hours = i) - row = conn.execute( + row = conn.execute( """ SELECT COUNT(DISTINCT user_id) as count FROM analytics_events - WHERE tenant_id = ? AND timestamp >= ? AND timestamp < ? + WHERE tenant_id = ? AND timestamp >= ? AND timestamp < ? """, (tenant_id, hour_start.isoformat(), hour_end.isoformat()), ).fetchone() @@ -2035,125 +2035,125 @@ class GrowthManager: def _row_to_user_profile(self, row) -> UserProfile: """将数据库行转换为 UserProfile""" return UserProfile( - id=row["id"], - tenant_id=row["tenant_id"], - user_id=row["user_id"], - first_seen=datetime.fromisoformat(row["first_seen"]), - last_seen=datetime.fromisoformat(row["last_seen"]), - total_sessions=row["total_sessions"], - total_events=row["total_events"], - feature_usage=json.loads(row["feature_usage"]), - subscription_history=json.loads(row["subscription_history"]), - ltv=row["ltv"], - churn_risk_score=row["churn_risk_score"], - engagement_score=row["engagement_score"], - created_at=datetime.fromisoformat(row["created_at"]), - updated_at=datetime.fromisoformat(row["updated_at"]), + id = row["id"], + tenant_id = row["tenant_id"], + user_id = row["user_id"], + first_seen = datetime.fromisoformat(row["first_seen"]), + last_seen = datetime.fromisoformat(row["last_seen"]), + total_sessions = row["total_sessions"], + total_events = row["total_events"], + feature_usage = json.loads(row["feature_usage"]), + subscription_history = json.loads(row["subscription_history"]), + ltv = row["ltv"], + churn_risk_score = row["churn_risk_score"], + engagement_score = row["engagement_score"], + created_at = datetime.fromisoformat(row["created_at"]), + updated_at = datetime.fromisoformat(row["updated_at"]), ) def _row_to_experiment(self, row) -> Experiment: """将数据库行转换为 Experiment""" return Experiment( - id=row["id"], - tenant_id=row["tenant_id"], - name=row["name"], - description=row["description"], - hypothesis=row["hypothesis"], - status=ExperimentStatus(row["status"]), - variants=json.loads(row["variants"]), - traffic_allocation=TrafficAllocationType(row["traffic_allocation"]), - traffic_split=json.loads(row["traffic_split"]), - target_audience=json.loads(row["target_audience"]), - primary_metric=row["primary_metric"], - secondary_metrics=json.loads(row["secondary_metrics"]), - start_date=datetime.fromisoformat(row["start_date"]) if row["start_date"] else None, - end_date=datetime.fromisoformat(row["end_date"]) if row["end_date"] else None, - min_sample_size=row["min_sample_size"], - confidence_level=row["confidence_level"], - created_at=row["created_at"], - updated_at=row["updated_at"], - created_by=row["created_by"], + id = row["id"], + tenant_id = row["tenant_id"], + name = row["name"], + description = row["description"], + hypothesis = row["hypothesis"], + status = ExperimentStatus(row["status"]), + variants = json.loads(row["variants"]), + traffic_allocation = TrafficAllocationType(row["traffic_allocation"]), + traffic_split = json.loads(row["traffic_split"]), + target_audience = json.loads(row["target_audience"]), + primary_metric = row["primary_metric"], + secondary_metrics = json.loads(row["secondary_metrics"]), + start_date = datetime.fromisoformat(row["start_date"]) if row["start_date"] else None, + end_date = datetime.fromisoformat(row["end_date"]) if row["end_date"] else None, + min_sample_size = row["min_sample_size"], + confidence_level = row["confidence_level"], + created_at = row["created_at"], + updated_at = row["updated_at"], + created_by = row["created_by"], ) def _row_to_email_template(self, row) -> EmailTemplate: """将数据库行转换为 EmailTemplate""" return EmailTemplate( - id=row["id"], - tenant_id=row["tenant_id"], - name=row["name"], - template_type=EmailTemplateType(row["template_type"]), - subject=row["subject"], - html_content=row["html_content"], - text_content=row["text_content"], - variables=json.loads(row["variables"]), - preview_text=row["preview_text"], - from_name=row["from_name"], - from_email=row["from_email"], - reply_to=row["reply_to"], - is_active=bool(row["is_active"]), - created_at=row["created_at"], - updated_at=row["updated_at"], + id = row["id"], + tenant_id = row["tenant_id"], + name = row["name"], + template_type = EmailTemplateType(row["template_type"]), + subject = row["subject"], + html_content = row["html_content"], + text_content = row["text_content"], + variables = json.loads(row["variables"]), + preview_text = row["preview_text"], + from_name = row["from_name"], + from_email = row["from_email"], + reply_to = row["reply_to"], + is_active = bool(row["is_active"]), + created_at = row["created_at"], + updated_at = row["updated_at"], ) def _row_to_automation_workflow(self, row) -> AutomationWorkflow: """将数据库行转换为 AutomationWorkflow""" return AutomationWorkflow( - id=row["id"], - tenant_id=row["tenant_id"], - name=row["name"], - description=row["description"], - trigger_type=WorkflowTriggerType(row["trigger_type"]), - trigger_conditions=json.loads(row["trigger_conditions"]), - actions=json.loads(row["actions"]), - is_active=bool(row["is_active"]), - execution_count=row["execution_count"], - created_at=row["created_at"], - updated_at=row["updated_at"], + id = row["id"], + tenant_id = row["tenant_id"], + name = row["name"], + description = row["description"], + trigger_type = WorkflowTriggerType(row["trigger_type"]), + trigger_conditions = json.loads(row["trigger_conditions"]), + actions = json.loads(row["actions"]), + is_active = bool(row["is_active"]), + execution_count = row["execution_count"], + created_at = row["created_at"], + updated_at = row["updated_at"], ) def _row_to_referral_program(self, row) -> ReferralProgram: """将数据库行转换为 ReferralProgram""" return ReferralProgram( - id=row["id"], - tenant_id=row["tenant_id"], - name=row["name"], - description=row["description"], - referrer_reward_type=row["referrer_reward_type"], - referrer_reward_value=row["referrer_reward_value"], - referee_reward_type=row["referee_reward_type"], - referee_reward_value=row["referee_reward_value"], - max_referrals_per_user=row["max_referrals_per_user"], - referral_code_length=row["referral_code_length"], - expiry_days=row["expiry_days"], - is_active=bool(row["is_active"]), - created_at=row["created_at"], - updated_at=row["updated_at"], + id = row["id"], + tenant_id = row["tenant_id"], + name = row["name"], + description = row["description"], + referrer_reward_type = row["referrer_reward_type"], + referrer_reward_value = row["referrer_reward_value"], + referee_reward_type = row["referee_reward_type"], + referee_reward_value = row["referee_reward_value"], + max_referrals_per_user = row["max_referrals_per_user"], + referral_code_length = row["referral_code_length"], + expiry_days = row["expiry_days"], + is_active = bool(row["is_active"]), + created_at = row["created_at"], + updated_at = row["updated_at"], ) def _row_to_team_incentive(self, row) -> TeamIncentive: """将数据库行转换为 TeamIncentive""" return TeamIncentive( - id=row["id"], - tenant_id=row["tenant_id"], - name=row["name"], - description=row["description"], - target_tier=row["target_tier"], - min_team_size=row["min_team_size"], - incentive_type=row["incentive_type"], - incentive_value=row["incentive_value"], - valid_from=datetime.fromisoformat(row["valid_from"]), - valid_until=datetime.fromisoformat(row["valid_until"]), - is_active=bool(row["is_active"]), - created_at=row["created_at"], + id = row["id"], + tenant_id = row["tenant_id"], + name = row["name"], + description = row["description"], + target_tier = row["target_tier"], + min_team_size = row["min_team_size"], + incentive_type = row["incentive_type"], + incentive_value = row["incentive_value"], + valid_from = datetime.fromisoformat(row["valid_from"]), + valid_until = datetime.fromisoformat(row["valid_until"]), + is_active = bool(row["is_active"]), + created_at = row["created_at"], ) # Singleton instance -_growth_manager = None +_growth_manager = None def get_growth_manager() -> GrowthManager: global _growth_manager if _growth_manager is None: - _growth_manager = GrowthManager() + _growth_manager = GrowthManager() return _growth_manager diff --git a/backend/image_processor.py b/backend/image_processor.py index e34c59b..c14950f 100644 --- a/backend/image_processor.py +++ b/backend/image_processor.py @@ -11,30 +11,30 @@ import uuid from dataclasses import dataclass # Constants -UUID_LENGTH = 8 # UUID 截断长度 +UUID_LENGTH = 8 # UUID 截断长度 # 尝试导入图像处理库 try: from PIL import Image, ImageEnhance, ImageFilter - PIL_AVAILABLE = True + PIL_AVAILABLE = True except ImportError: - PIL_AVAILABLE = False + PIL_AVAILABLE = False try: import cv2 import numpy as np - CV2_AVAILABLE = True + CV2_AVAILABLE = True except ImportError: - CV2_AVAILABLE = False + CV2_AVAILABLE = False try: import pytesseract - PYTESSERACT_AVAILABLE = True + PYTESSERACT_AVAILABLE = True except ImportError: - PYTESSERACT_AVAILABLE = False + PYTESSERACT_AVAILABLE = False @dataclass @@ -44,7 +44,7 @@ class ImageEntity: name: str type: str confidence: float - bbox: tuple[int, int, int, int] | None = None # (x, y, width, height) + bbox: tuple[int, int, int, int] | None = None # (x, y, width, height) @dataclass @@ -70,7 +70,7 @@ class ImageProcessingResult: width: int height: int success: bool - error_message: str = "" + error_message: str = "" @dataclass @@ -87,7 +87,7 @@ class ImageProcessor: """图片处理器 - 处理各种类型图片""" # 图片类型定义 - IMAGE_TYPES = { + IMAGE_TYPES = { "whiteboard": "白板", "ppt": "PPT/演示文稿", "handwritten": "手写笔记", @@ -96,17 +96,17 @@ class ImageProcessor: "other": "其他", } - def __init__(self, temp_dir: str = None) -> None: + def __init__(self, temp_dir: str = None) -> None: """ 初始化图片处理器 Args: temp_dir: 临时文件目录 """ - self.temp_dir = temp_dir or os.path.join(os.getcwd(), "temp", "images") - os.makedirs(self.temp_dir, exist_ok=True) + self.temp_dir = temp_dir or os.path.join(os.getcwd(), "temp", "images") + os.makedirs(self.temp_dir, exist_ok = True) - def preprocess_image(self, image, image_type: str = None) -> None: + def preprocess_image(self, image, image_type: str = None) -> None: """ 预处理图片以提高OCR质量 @@ -123,25 +123,25 @@ class ImageProcessor: try: # 转换为RGB(如果是RGBA) if image.mode == "RGBA": - image = image.convert("RGB") + image = image.convert("RGB") # 根据图片类型进行针对性处理 if image_type == "whiteboard": # 白板:增强对比度,去除背景 - image = self._enhance_whiteboard(image) + image = self._enhance_whiteboard(image) elif image_type == "handwritten": # 手写笔记:降噪,增强对比度 - image = self._enhance_handwritten(image) + image = self._enhance_handwritten(image) elif image_type == "screenshot": # 截图:轻微锐化 - image = image.filter(ImageFilter.SHARPEN) + image = image.filter(ImageFilter.SHARPEN) # 通用处理:调整大小(如果太大) - max_size = 4096 + max_size = 4096 if max(image.size) > max_size: - ratio = max_size / max(image.size) - new_size = (int(image.size[0] * ratio), int(image.size[1] * ratio)) - image = image.resize(new_size, Image.Resampling.LANCZOS) + ratio = max_size / max(image.size) + new_size = (int(image.size[0] * ratio), int(image.size[1] * ratio)) + image = image.resize(new_size, Image.Resampling.LANCZOS) return image except Exception as e: @@ -151,33 +151,33 @@ class ImageProcessor: def _enhance_whiteboard(self, image) -> None: """增强白板图片""" # 转换为灰度 - gray = image.convert("L") + gray = image.convert("L") # 增强对比度 - enhancer = ImageEnhance.Contrast(gray) - enhanced = enhancer.enhance(2.0) + enhancer = ImageEnhance.Contrast(gray) + enhanced = enhancer.enhance(2.0) # 二值化 - threshold = 128 - binary = enhanced.point(lambda x: 0 if x < threshold else 255, "1") + threshold = 128 + binary = enhanced.point(lambda x: 0 if x < threshold else 255, "1") return binary.convert("L") def _enhance_handwritten(self, image) -> None: """增强手写笔记图片""" # 转换为灰度 - gray = image.convert("L") + gray = image.convert("L") # 轻微降噪 - blurred = gray.filter(ImageFilter.GaussianBlur(radius=1)) + blurred = gray.filter(ImageFilter.GaussianBlur(radius = 1)) # 增强对比度 - enhancer = ImageEnhance.Contrast(blurred) - enhanced = enhancer.enhance(1.5) + enhancer = ImageEnhance.Contrast(blurred) + enhanced = enhancer.enhance(1.5) return enhanced - def detect_image_type(self, image, ocr_text: str = "") -> str: + def detect_image_type(self, image, ocr_text: str = "") -> str: """ 自动检测图片类型 @@ -193,8 +193,8 @@ class ImageProcessor: try: # 基于图片特征和OCR内容判断类型 - width, height = image.size - aspect_ratio = width / height + width, height = image.size + aspect_ratio = width / height # 检测是否为PPT(通常是16:9或4:3) if 1.3 <= aspect_ratio <= 1.8: @@ -204,12 +204,12 @@ class ImageProcessor: # 检测是否为白板(大量手写文字,可能有箭头、框等) if CV2_AVAILABLE: - img_array = np.array(image.convert("RGB")) - gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY) + img_array = np.array(image.convert("RGB")) + gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY) # 检测边缘(白板通常有很多线条) - edges = cv2.Canny(gray, 50, 150) - edge_ratio = np.sum(edges > 0) / edges.size + edges = cv2.Canny(gray, 50, 150) + edge_ratio = np.sum(edges > 0) / edges.size # 如果边缘比例高,可能是白板 if edge_ratio > 0.05 and len(ocr_text) > 50: @@ -236,7 +236,7 @@ class ImageProcessor: print(f"Image type detection error: {e}") return "other" - def perform_ocr(self, image, lang: str = "chi_sim+eng") -> tuple[str, float]: + def perform_ocr(self, image, lang: str = "chi_sim+eng") -> tuple[str, float]: """ 对图片进行OCR识别 @@ -252,15 +252,15 @@ class ImageProcessor: try: # 预处理图片 - processed_image = self.preprocess_image(image) + processed_image = self.preprocess_image(image) # 执行OCR - text = pytesseract.image_to_string(processed_image, lang=lang) + text = pytesseract.image_to_string(processed_image, lang = lang) # 获取置信度 - data = pytesseract.image_to_data(processed_image, output_type=pytesseract.Output.DICT) - confidences = [int(c) for c in data["conf"] if int(c) > 0] - avg_confidence = sum(confidences) / len(confidences) if confidences else 0 + data = pytesseract.image_to_data(processed_image, output_type = pytesseract.Output.DICT) + confidences = [int(c) for c in data["conf"] if int(c) > 0] + avg_confidence = sum(confidences) / len(confidences) if confidences else 0 return text.strip(), avg_confidence / 100.0 except Exception as e: @@ -277,26 +277,26 @@ class ImageProcessor: Returns: 实体列表 """ - entities = [] + entities = [] # 简单的实体提取规则(可以替换为LLM调用) # 提取大写字母开头的词组(可能是专有名词) import re # 项目名称(通常是大写或带引号) - project_pattern = r'["\']([^"\']+)["\']|([A-Z][a-zA-Z0-9]*(?:\s+[A-Z][a-zA-Z0-9]*)+)' + project_pattern = r'["\']([^"\']+)["\']|([A-Z][a-zA-Z0-9]*(?:\s+[A-Z][a-zA-Z0-9]*)+)' for match in re.finditer(project_pattern, text): - name = match.group(1) or match.group(2) + name = match.group(1) or match.group(2) if name and len(name) > 2: - entities.append(ImageEntity(name=name.strip(), type="PROJECT", confidence=0.7)) + entities.append(ImageEntity(name = name.strip(), type = "PROJECT", confidence = 0.7)) # 人名(中文) - name_pattern = r"([\u4e00-\u9fa5]{2,4})(?:先生|女士|总|经理|工程师|老师)" + name_pattern = r"([\u4e00-\u9fa5]{2, 4})(?:先生|女士|总|经理|工程师|老师)" for match in re.finditer(name_pattern, text): - entities.append(ImageEntity(name=match.group(1), type="PERSON", confidence=0.8)) + entities.append(ImageEntity(name = match.group(1), type = "PERSON", confidence = 0.8)) # 技术术语 - tech_keywords = [ + tech_keywords = [ "K8s", "Kubernetes", "Docker", @@ -314,13 +314,13 @@ class ImageProcessor: ] for keyword in tech_keywords: if keyword in text: - entities.append(ImageEntity(name=keyword, type="TECH", confidence=0.9)) + entities.append(ImageEntity(name = keyword, type = "TECH", confidence = 0.9)) # 去重 - seen = set() - unique_entities = [] + seen = set() + unique_entities = [] for e in entities: - key = (e.name.lower(), e.type) + key = (e.name.lower(), e.type) if key not in seen: seen.add(key) unique_entities.append(e) @@ -341,19 +341,19 @@ class ImageProcessor: Returns: 图片描述 """ - type_name = self.IMAGE_TYPES.get(image_type, "图片") + type_name = self.IMAGE_TYPES.get(image_type, "图片") - description_parts = [f"这是一张{type_name}图片。"] + description_parts = [f"这是一张{type_name}图片。"] if ocr_text: # 提取前200字符作为摘要 - text_preview = ocr_text[:200].replace("\n", " ") + text_preview = ocr_text[:200].replace("\n", " ") if len(ocr_text) > 200: text_preview += "..." description_parts.append(f"内容摘要:{text_preview}") if entities: - entity_names = [e.name for e in entities[:5]] # 最多显示5个实体 + entity_names = [e.name for e in entities[:5]] # 最多显示5个实体 description_parts.append(f"识别到的关键实体:{', '.join(entity_names)}") return " ".join(description_parts) @@ -361,9 +361,9 @@ class ImageProcessor: def process_image( self, image_data: bytes, - filename: str = None, - image_id: str = None, - detect_type: bool = True, + filename: str = None, + image_id: str = None, + detect_type: bool = True, ) -> ImageProcessingResult: """ 处理单张图片 @@ -377,73 +377,73 @@ class ImageProcessor: Returns: 图片处理结果 """ - image_id = image_id or str(uuid.uuid4())[:UUID_LENGTH] + image_id = image_id or str(uuid.uuid4())[:UUID_LENGTH] if not PIL_AVAILABLE: return ImageProcessingResult( - image_id=image_id, - image_type="other", - ocr_text="", - description="PIL not available", - entities=[], - relations=[], - width=0, - height=0, - success=False, - error_message="PIL library not available", + image_id = image_id, + image_type = "other", + ocr_text = "", + description = "PIL not available", + entities = [], + relations = [], + width = 0, + height = 0, + success = False, + error_message = "PIL library not available", ) try: # 加载图片 - image = Image.open(io.BytesIO(image_data)) - width, height = image.size + image = Image.open(io.BytesIO(image_data)) + width, height = image.size # 执行OCR - ocr_text, ocr_confidence = self.perform_ocr(image) + ocr_text, ocr_confidence = self.perform_ocr(image) # 检测图片类型 - image_type = "other" + image_type = "other" if detect_type: - image_type = self.detect_image_type(image, ocr_text) + image_type = self.detect_image_type(image, ocr_text) # 提取实体 - entities = self.extract_entities_from_text(ocr_text) + entities = self.extract_entities_from_text(ocr_text) # 生成描述 - description = self.generate_description(image_type, ocr_text, entities) + description = self.generate_description(image_type, ocr_text, entities) # 提取关系(基于实体共现) - relations = self._extract_relations(entities, ocr_text) + relations = self._extract_relations(entities, ocr_text) # 保存图片文件(可选) if filename: - save_path = os.path.join(self.temp_dir, f"{image_id}_{filename}") + save_path = os.path.join(self.temp_dir, f"{image_id}_{filename}") image.save(save_path) return ImageProcessingResult( - image_id=image_id, - image_type=image_type, - ocr_text=ocr_text, - description=description, - entities=entities, - relations=relations, - width=width, - height=height, - success=True, + image_id = image_id, + image_type = image_type, + ocr_text = ocr_text, + description = description, + entities = entities, + relations = relations, + width = width, + height = height, + success = True, ) except Exception as e: return ImageProcessingResult( - image_id=image_id, - image_type="other", - ocr_text="", - description="", - entities=[], - relations=[], - width=0, - height=0, - success=False, - error_message=str(e), + image_id = image_id, + image_type = "other", + ocr_text = "", + description = "", + entities = [], + relations = [], + width = 0, + height = 0, + success = False, + error_message = str(e), ) def _extract_relations(self, entities: list[ImageEntity], text: str) -> list[ImageRelation]: @@ -457,16 +457,16 @@ class ImageProcessor: Returns: 关系列表 """ - relations = [] + relations = [] if len(entities) < 2: return relations # 简单的关系提取:如果两个实体在同一句子中出现,则认为它们相关 - sentences = text.replace("。", ".").replace("!", "!").replace("?", "?").split(".") + sentences = text.replace("。", ".").replace("!", "!").replace("?", "?").split(".") for sentence in sentences: - sentence_entities = [] + sentence_entities = [] for entity in entities: if entity.name in sentence: sentence_entities.append(entity) @@ -477,17 +477,17 @@ class ImageProcessor: for j in range(i + 1, len(sentence_entities)): relations.append( ImageRelation( - source=sentence_entities[i].name, - target=sentence_entities[j].name, - relation_type="related", - confidence=0.5, + source = sentence_entities[i].name, + target = sentence_entities[j].name, + relation_type = "related", + confidence = 0.5, ) ) return relations def process_batch( - self, images_data: list[tuple[bytes, str]], project_id: str = None + self, images_data: list[tuple[bytes, str]], project_id: str = None ) -> BatchProcessingResult: """ 批量处理图片 @@ -499,12 +499,12 @@ class ImageProcessor: Returns: 批量处理结果 """ - results = [] - success_count = 0 - failed_count = 0 + results = [] + success_count = 0 + failed_count = 0 for image_data, filename in images_data: - result = self.process_image(image_data, filename) + result = self.process_image(image_data, filename) results.append(result) if result.success: @@ -513,10 +513,10 @@ class ImageProcessor: failed_count += 1 return BatchProcessingResult( - results=results, - total_count=len(results), - success_count=success_count, - failed_count=failed_count, + results = results, + total_count = len(results), + success_count = success_count, + failed_count = failed_count, ) def image_to_base64(self, image_data: bytes) -> str: @@ -531,7 +531,7 @@ class ImageProcessor: """ return base64.b64encode(image_data).decode("utf-8") - def get_image_thumbnail(self, image_data: bytes, size: tuple[int, int] = (200, 200)) -> bytes: + def get_image_thumbnail(self, image_data: bytes, size: tuple[int, int] = (200, 200)) -> bytes: """ 生成图片缩略图 @@ -546,11 +546,11 @@ class ImageProcessor: return image_data try: - image = Image.open(io.BytesIO(image_data)) + image = Image.open(io.BytesIO(image_data)) image.thumbnail(size, Image.Resampling.LANCZOS) - buffer = io.BytesIO() - image.save(buffer, format="JPEG") + buffer = io.BytesIO() + image.save(buffer, format = "JPEG") return buffer.getvalue() except Exception as e: print(f"Thumbnail generation error: {e}") @@ -558,12 +558,12 @@ class ImageProcessor: # Singleton instance -_image_processor = None +_image_processor = None -def get_image_processor(temp_dir: str = None) -> ImageProcessor: +def get_image_processor(temp_dir: str = None) -> ImageProcessor: """获取图片处理器单例""" global _image_processor if _image_processor is None: - _image_processor = ImageProcessor(temp_dir) + _image_processor = ImageProcessor(temp_dir) return _image_processor diff --git a/backend/init_db.py b/backend/init_db.py index 7cd7778..db80146 100644 --- a/backend/init_db.py +++ b/backend/init_db.py @@ -4,27 +4,27 @@ import os import sqlite3 -db_path = os.path.join(os.path.dirname(__file__), "insightflow.db") -schema_path = os.path.join(os.path.dirname(__file__), "schema.sql") +db_path = os.path.join(os.path.dirname(__file__), "insightflow.db") +schema_path = os.path.join(os.path.dirname(__file__), "schema.sql") print(f"Database path: {db_path}") print(f"Schema path: {schema_path}") # Read schema with open(schema_path) as f: - schema = f.read() + schema = f.read() # Execute schema -conn = sqlite3.connect(db_path) -cursor = conn.cursor() +conn = sqlite3.connect(db_path) +cursor = conn.cursor() # Split schema by semicolons and execute each statement -statements = schema.split(";") -success_count = 0 -error_count = 0 +statements = schema.split(";") +success_count = 0 +error_count = 0 for stmt in statements: - stmt = stmt.strip() + stmt = stmt.strip() if stmt: try: cursor.execute(stmt) diff --git a/backend/knowledge_reasoner.py b/backend/knowledge_reasoner.py index 7924d08..beb7e2c 100644 --- a/backend/knowledge_reasoner.py +++ b/backend/knowledge_reasoner.py @@ -12,18 +12,18 @@ from enum import Enum import httpx -KIMI_API_KEY = os.getenv("KIMI_API_KEY", "") -KIMI_BASE_URL = os.getenv("KIMI_BASE_URL", "https://api.kimi.com/coding") +KIMI_API_KEY = os.getenv("KIMI_API_KEY", "") +KIMI_BASE_URL = os.getenv("KIMI_BASE_URL", "https://api.kimi.com/coding") class ReasoningType(Enum): """推理类型""" - CAUSAL = "causal" # 因果推理 - ASSOCIATIVE = "associative" # 关联推理 - TEMPORAL = "temporal" # 时序推理 - COMPARATIVE = "comparative" # 对比推理 - SUMMARY = "summary" # 总结推理 + CAUSAL = "causal" # 因果推理 + ASSOCIATIVE = "associative" # 关联推理 + TEMPORAL = "temporal" # 时序推理 + COMPARATIVE = "comparative" # 对比推理 + SUMMARY = "summary" # 总结推理 @dataclass @@ -51,38 +51,38 @@ class InferencePath: class KnowledgeReasoner: """知识推理引擎""" - def __init__(self, api_key: str = None, base_url: str = None): - self.api_key = api_key or KIMI_API_KEY - self.base_url = base_url or KIMI_BASE_URL - self.headers = { + def __init__(self, api_key: str = None, base_url: str = None) -> None: + self.api_key = api_key or KIMI_API_KEY + self.base_url = base_url or KIMI_BASE_URL + self.headers = { "Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json", } - async def _call_llm(self, prompt: str, temperature: float = 0.3) -> str: + async def _call_llm(self, prompt: str, temperature: float = 0.3) -> str: """调用 LLM""" if not self.api_key: raise ValueError("KIMI_API_KEY not set") - payload = { + payload = { "model": "k2p5", "messages": [{"role": "user", "content": prompt}], "temperature": temperature, } async with httpx.AsyncClient() as client: - response = await client.post( + response = await client.post( f"{self.base_url}/v1/chat/completions", - headers=self.headers, - json=payload, - timeout=120.0, + headers = self.headers, + json = payload, + timeout = 120.0, ) response.raise_for_status() - result = response.json() + result = response.json() return result["choices"][0]["message"]["content"] async def enhanced_qa( - self, query: str, project_context: dict, graph_data: dict, reasoning_depth: str = "medium" + self, query: str, project_context: dict, graph_data: dict, reasoning_depth: str = "medium" ) -> ReasoningResult: """ 增强问答 - 结合图谱推理的问答 @@ -94,7 +94,7 @@ class KnowledgeReasoner: reasoning_depth: 推理深度 (shallow/medium/deep) """ # 1. 分析问题类型 - analysis = await self._analyze_question(query) + analysis = await self._analyze_question(query) # 2. 根据问题类型选择推理策略 if analysis["type"] == "causal": @@ -108,7 +108,7 @@ class KnowledgeReasoner: async def _analyze_question(self, query: str) -> dict: """分析问题类型和意图""" - prompt = f"""分析以下问题的类型和意图: + prompt = f"""分析以下问题的类型和意图: 问题:{query} @@ -127,9 +127,9 @@ class KnowledgeReasoner: - factual: 事实类问题(是什么、有哪些) - opinion: 观点类问题(怎么看、态度、评价)""" - content = await self._call_llm(prompt, temperature=0.1) + content = await self._call_llm(prompt, temperature = 0.1) - json_match = re.search(r"\{{.*?\}}", content, re.DOTALL) + json_match = re.search(r"\{{.*?\}}", content, re.DOTALL) if json_match: try: return json.loads(json_match.group()) @@ -144,10 +144,10 @@ class KnowledgeReasoner: """因果推理 - 分析原因和影响""" # 构建因果分析提示 - entities_str = json.dumps(graph_data.get("entities", []), ensure_ascii=False, indent=2) - relations_str = json.dumps(graph_data.get("relations", []), ensure_ascii=False, indent=2) + entities_str = json.dumps(graph_data.get("entities", []), ensure_ascii = False, indent = 2) + relations_str = json.dumps(graph_data.get("relations", []), ensure_ascii = False, indent = 2) - prompt = f"""基于以下知识图谱进行因果推理分析: + prompt = f"""基于以下知识图谱进行因果推理分析: ## 问题 {query} @@ -159,7 +159,7 @@ class KnowledgeReasoner: {relations_str[:2000]} ## 项目上下文 -{json.dumps(project_context, ensure_ascii=False, indent=2)[:1500]} +{json.dumps(project_context, ensure_ascii = False, indent = 2)[:1500]} 请进行因果分析,返回 JSON 格式: {{ @@ -172,31 +172,31 @@ class KnowledgeReasoner: "knowledge_gaps": ["缺失信息1"] }}""" - content = await self._call_llm(prompt, temperature=0.3) + content = await self._call_llm(prompt, temperature = 0.3) - json_match = re.search(r"\{{.*?\}}", content, re.DOTALL) + json_match = re.search(r"\{{.*?\}}", content, re.DOTALL) if json_match: try: - data = json.loads(json_match.group()) + data = json.loads(json_match.group()) return ReasoningResult( - answer=data.get("answer", ""), - reasoning_type=ReasoningType.CAUSAL, - confidence=data.get("confidence", 0.7), - evidence=[{"text": e} for e in data.get("evidence", [])], - related_entities=[], - gaps=data.get("knowledge_gaps", []), + answer = data.get("answer", ""), + reasoning_type = ReasoningType.CAUSAL, + confidence = data.get("confidence", 0.7), + evidence = [{"text": e} for e in data.get("evidence", [])], + related_entities = [], + gaps = data.get("knowledge_gaps", []), ) except (json.JSONDecodeError, KeyError): pass return ReasoningResult( - answer=content, - reasoning_type=ReasoningType.CAUSAL, - confidence=0.5, - evidence=[], - related_entities=[], - gaps=["无法完成因果推理"], + answer = content, + reasoning_type = ReasoningType.CAUSAL, + confidence = 0.5, + evidence = [], + related_entities = [], + gaps = ["无法完成因果推理"], ) async def _comparative_reasoning( @@ -204,16 +204,16 @@ class KnowledgeReasoner: ) -> ReasoningResult: """对比推理 - 比较实体间的异同""" - prompt = f"""基于以下知识图谱进行对比分析: + prompt = f"""基于以下知识图谱进行对比分析: ## 问题 {query} ## 实体 -{json.dumps(graph_data.get("entities", []), ensure_ascii=False, indent=2)[:2000]} +{json.dumps(graph_data.get("entities", []), ensure_ascii = False, indent = 2)[:2000]} ## 关系 -{json.dumps(graph_data.get("relations", []), ensure_ascii=False, indent=2)[:1500]} +{json.dumps(graph_data.get("relations", []), ensure_ascii = False, indent = 2)[:1500]} 请进行对比分析,返回 JSON 格式: {{ @@ -226,31 +226,31 @@ class KnowledgeReasoner: "knowledge_gaps": [] }}""" - content = await self._call_llm(prompt, temperature=0.3) + content = await self._call_llm(prompt, temperature = 0.3) - json_match = re.search(r"\{{.*?\}}", content, re.DOTALL) + json_match = re.search(r"\{{.*?\}}", content, re.DOTALL) if json_match: try: - data = json.loads(json_match.group()) + data = json.loads(json_match.group()) return ReasoningResult( - answer=data.get("answer", ""), - reasoning_type=ReasoningType.COMPARATIVE, - confidence=data.get("confidence", 0.7), - evidence=[{"text": e} for e in data.get("evidence", [])], - related_entities=[], - gaps=data.get("knowledge_gaps", []), + answer = data.get("answer", ""), + reasoning_type = ReasoningType.COMPARATIVE, + confidence = data.get("confidence", 0.7), + evidence = [{"text": e} for e in data.get("evidence", [])], + related_entities = [], + gaps = data.get("knowledge_gaps", []), ) except (json.JSONDecodeError, KeyError): pass return ReasoningResult( - answer=content, - reasoning_type=ReasoningType.COMPARATIVE, - confidence=0.5, - evidence=[], - related_entities=[], - gaps=[], + answer = content, + reasoning_type = ReasoningType.COMPARATIVE, + confidence = 0.5, + evidence = [], + related_entities = [], + gaps = [], ) async def _temporal_reasoning( @@ -258,16 +258,16 @@ class KnowledgeReasoner: ) -> ReasoningResult: """时序推理 - 分析时间线和演变""" - prompt = f"""基于以下知识图谱进行时序分析: + prompt = f"""基于以下知识图谱进行时序分析: ## 问题 {query} ## 项目时间线 -{json.dumps(project_context.get("timeline", []), ensure_ascii=False, indent=2)[:2000]} +{json.dumps(project_context.get("timeline", []), ensure_ascii = False, indent = 2)[:2000]} ## 实体提及历史 -{json.dumps(graph_data.get("entities", []), ensure_ascii=False, indent=2)[:1500]} +{json.dumps(graph_data.get("entities", []), ensure_ascii = False, indent = 2)[:1500]} 请进行时序分析,返回 JSON 格式: {{ @@ -280,31 +280,31 @@ class KnowledgeReasoner: "knowledge_gaps": [] }}""" - content = await self._call_llm(prompt, temperature=0.3) + content = await self._call_llm(prompt, temperature = 0.3) - json_match = re.search(r"\{{.*?\}}", content, re.DOTALL) + json_match = re.search(r"\{{.*?\}}", content, re.DOTALL) if json_match: try: - data = json.loads(json_match.group()) + data = json.loads(json_match.group()) return ReasoningResult( - answer=data.get("answer", ""), - reasoning_type=ReasoningType.TEMPORAL, - confidence=data.get("confidence", 0.7), - evidence=[{"text": e} for e in data.get("evidence", [])], - related_entities=[], - gaps=data.get("knowledge_gaps", []), + answer = data.get("answer", ""), + reasoning_type = ReasoningType.TEMPORAL, + confidence = data.get("confidence", 0.7), + evidence = [{"text": e} for e in data.get("evidence", [])], + related_entities = [], + gaps = data.get("knowledge_gaps", []), ) except (json.JSONDecodeError, KeyError): pass return ReasoningResult( - answer=content, - reasoning_type=ReasoningType.TEMPORAL, - confidence=0.5, - evidence=[], - related_entities=[], - gaps=[], + answer = content, + reasoning_type = ReasoningType.TEMPORAL, + confidence = 0.5, + evidence = [], + related_entities = [], + gaps = [], ) async def _associative_reasoning( @@ -312,16 +312,16 @@ class KnowledgeReasoner: ) -> ReasoningResult: """关联推理 - 发现实体间的隐含关联""" - prompt = f"""基于以下知识图谱进行关联分析: + prompt = f"""基于以下知识图谱进行关联分析: ## 问题 {query} ## 实体 -{json.dumps(graph_data.get("entities", [])[:20], ensure_ascii=False, indent=2)} +{json.dumps(graph_data.get("entities", [])[:20], ensure_ascii = False, indent = 2)} ## 关系 -{json.dumps(graph_data.get("relations", [])[:30], ensure_ascii=False, indent=2)} +{json.dumps(graph_data.get("relations", [])[:30], ensure_ascii = False, indent = 2)} 请进行关联推理,发现隐含联系,返回 JSON 格式: {{ @@ -334,52 +334,52 @@ class KnowledgeReasoner: "knowledge_gaps": [] }}""" - content = await self._call_llm(prompt, temperature=0.4) + content = await self._call_llm(prompt, temperature = 0.4) - json_match = re.search(r"\{{.*?\}}", content, re.DOTALL) + json_match = re.search(r"\{{.*?\}}", content, re.DOTALL) if json_match: try: - data = json.loads(json_match.group()) + data = json.loads(json_match.group()) return ReasoningResult( - answer=data.get("answer", ""), - reasoning_type=ReasoningType.ASSOCIATIVE, - confidence=data.get("confidence", 0.7), - evidence=[{"text": e} for e in data.get("evidence", [])], - related_entities=[], - gaps=data.get("knowledge_gaps", []), + answer = data.get("answer", ""), + reasoning_type = ReasoningType.ASSOCIATIVE, + confidence = data.get("confidence", 0.7), + evidence = [{"text": e} for e in data.get("evidence", [])], + related_entities = [], + gaps = data.get("knowledge_gaps", []), ) except (json.JSONDecodeError, KeyError): pass return ReasoningResult( - answer=content, - reasoning_type=ReasoningType.ASSOCIATIVE, - confidence=0.5, - evidence=[], - related_entities=[], - gaps=[], + answer = content, + reasoning_type = ReasoningType.ASSOCIATIVE, + confidence = 0.5, + evidence = [], + related_entities = [], + gaps = [], ) def find_inference_paths( - self, start_entity: str, end_entity: str, graph_data: dict, max_depth: int = 3 + self, start_entity: str, end_entity: str, graph_data: dict, max_depth: int = 3 ) -> list[InferencePath]: """ 发现两个实体之间的推理路径 使用 BFS 在关系图中搜索路径 """ - relations = graph_data.get("relations", []) + relations = graph_data.get("relations", []) # 构建邻接表 - adj = {} + adj = {} for r in relations: - src = r.get("source_id") or r.get("source") - tgt = r.get("target_id") or r.get("target") + src = r.get("source_id") or r.get("source") + tgt = r.get("target_id") or r.get("target") if src not in adj: - adj[src] = [] + adj[src] = [] if tgt not in adj: - adj[tgt] = [] + adj[tgt] = [] adj[src].append({"target": tgt, "relation": r.get("type", "related"), "data": r}) # 无向图也添加反向 adj[tgt].append( @@ -389,21 +389,21 @@ class KnowledgeReasoner: # BFS 搜索路径 from collections import deque - paths = [] - queue = deque([(start_entity, [{"entity": start_entity, "relation": None}])]) + paths = [] + queue = deque([(start_entity, [{"entity": start_entity, "relation": None}])]) {start_entity} while queue and len(paths) < 5: - current, path = queue.popleft() + current, path = queue.popleft() if current == end_entity and len(path) > 1: # 找到一条路径 paths.append( InferencePath( - start_entity=start_entity, - end_entity=end_entity, - path=path, - strength=self._calculate_path_strength(path), + start_entity = start_entity, + end_entity = end_entity, + path = path, + strength = self._calculate_path_strength(path), ) ) continue @@ -412,9 +412,9 @@ class KnowledgeReasoner: continue for neighbor in adj.get(current, []): - next_entity = neighbor["target"] + next_entity = neighbor["target"] if next_entity not in [p["entity"] for p in path]: # 避免循环 - new_path = path + [ + new_path = path + [ { "entity": next_entity, "relation": neighbor["relation"], @@ -424,7 +424,7 @@ class KnowledgeReasoner: queue.append((next_entity, new_path)) # 按强度排序 - paths.sort(key=lambda p: p.strength, reverse=True) + paths.sort(key = lambda p: p.strength, reverse = True) return paths def _calculate_path_strength(self, path: list[dict]) -> float: @@ -433,23 +433,23 @@ class KnowledgeReasoner: return 0.0 # 路径越短越强 - length_factor = 1.0 / len(path) + length_factor = 1.0 / len(path) # 关系置信度 - confidence_sum = 0 - confidence_count = 0 + confidence_sum = 0 + confidence_count = 0 for node in path[1:]: # 跳过第一个节点 - rel_data = node.get("relation_data", {}) + rel_data = node.get("relation_data", {}) if "confidence" in rel_data: confidence_sum += rel_data["confidence"] confidence_count += 1 - confidence_factor = (confidence_sum / confidence_count) if confidence_count > 0 else 0.5 + confidence_factor = (confidence_sum / confidence_count) if confidence_count > 0 else 0.5 return length_factor * confidence_factor async def summarize_project( - self, project_context: dict, graph_data: dict, summary_type: str = "comprehensive" + self, project_context: dict, graph_data: dict, summary_type: str = "comprehensive" ) -> dict: """ 项目智能总结 @@ -457,17 +457,17 @@ class KnowledgeReasoner: Args: summary_type: comprehensive/executive/technical/risk """ - type_prompts = { + type_prompts = { "comprehensive": "全面总结项目的所有方面", "executive": "高管摘要,关注关键决策和风险", "technical": "技术总结,关注架构和技术栈", "risk": "风险分析,关注潜在问题和依赖", } - prompt = f"""请对以下项目进行{type_prompts.get(summary_type, "全面总结")}: + prompt = f"""请对以下项目进行{type_prompts.get(summary_type, "全面总结")}: ## 项目信息 -{json.dumps(project_context, ensure_ascii=False, indent=2)[:3000]} +{json.dumps(project_context, ensure_ascii = False, indent = 2)[:3000]} ## 知识图谱 实体数: {len(graph_data.get("entities", []))} @@ -483,9 +483,9 @@ class KnowledgeReasoner: "confidence": 0.85 }}""" - content = await self._call_llm(prompt, temperature=0.3) + content = await self._call_llm(prompt, temperature = 0.3) - json_match = re.search(r"\{{.*?\}}", content, re.DOTALL) + json_match = re.search(r"\{{.*?\}}", content, re.DOTALL) if json_match: try: @@ -504,11 +504,11 @@ class KnowledgeReasoner: # Singleton instance -_reasoner = None +_reasoner = None def get_knowledge_reasoner() -> KnowledgeReasoner: global _reasoner if _reasoner is None: - _reasoner = KnowledgeReasoner() + _reasoner = KnowledgeReasoner() return _reasoner diff --git a/backend/llm_client.py b/backend/llm_client.py index 82a2991..368ffed 100644 --- a/backend/llm_client.py +++ b/backend/llm_client.py @@ -12,8 +12,8 @@ from dataclasses import dataclass import httpx -KIMI_API_KEY = os.getenv("KIMI_API_KEY", "") -KIMI_BASE_URL = os.getenv("KIMI_BASE_URL", "https://api.kimi.com/coding") +KIMI_API_KEY = os.getenv("KIMI_API_KEY", "") +KIMI_BASE_URL = os.getenv("KIMI_BASE_URL", "https://api.kimi.com/coding") @dataclass @@ -41,22 +41,22 @@ class RelationExtractionResult: class LLMClient: """Kimi API 客户端""" - def __init__(self, api_key: str = None, base_url: str = None): - self.api_key = api_key or KIMI_API_KEY - self.base_url = base_url or KIMI_BASE_URL - self.headers = { + def __init__(self, api_key: str = None, base_url: str = None) -> None: + self.api_key = api_key or KIMI_API_KEY + self.base_url = base_url or KIMI_BASE_URL + self.headers = { "Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json", } async def chat( - self, messages: list[ChatMessage], temperature: float = 0.3, stream: bool = False + self, messages: list[ChatMessage], temperature: float = 0.3, stream: bool = False ) -> str: """发送聊天请求""" if not self.api_key: raise ValueError("KIMI_API_KEY not set") - payload = { + payload = { "model": "k2p5", "messages": [{"role": m.role, "content": m.content} for m in messages], "temperature": temperature, @@ -64,24 +64,24 @@ class LLMClient: } async with httpx.AsyncClient() as client: - response = await client.post( + response = await client.post( f"{self.base_url}/v1/chat/completions", - headers=self.headers, - json=payload, - timeout=120.0, + headers = self.headers, + json = payload, + timeout = 120.0, ) response.raise_for_status() - result = response.json() + result = response.json() return result["choices"][0]["message"]["content"] async def chat_stream( - self, messages: list[ChatMessage], temperature: float = 0.3 + self, messages: list[ChatMessage], temperature: float = 0.3 ) -> AsyncGenerator[str, None]: """流式聊天请求""" if not self.api_key: raise ValueError("KIMI_API_KEY not set") - payload = { + payload = { "model": "k2p5", "messages": [{"role": m.role, "content": m.content} for m in messages], "temperature": temperature, @@ -92,19 +92,19 @@ class LLMClient: async with client.stream( "POST", f"{self.base_url}/v1/chat/completions", - headers=self.headers, - json=payload, - timeout=120.0, + headers = self.headers, + json = payload, + timeout = 120.0, ) as response: response.raise_for_status() async for line in response.aiter_lines(): if line.startswith("data: "): - data = line[6:] + data = line[6:] if data == "[DONE]": break try: - chunk = json.loads(data) - delta = chunk["choices"][0]["delta"] + chunk = json.loads(data) + delta = chunk["choices"][0]["delta"] if "content" in delta: yield delta["content"] except (json.JSONDecodeError, KeyError, IndexError): @@ -114,7 +114,7 @@ class LLMClient: self, text: str ) -> tuple[list[EntityExtractionResult], list[RelationExtractionResult]]: """提取实体和关系,带置信度分数""" - prompt = f"""从以下会议文本中提取关键实体和它们之间的关系,以 JSON 格式返回: + prompt = f"""从以下会议文本中提取关键实体和它们之间的关系,以 JSON 格式返回: 文本:{text[:3000]} @@ -139,30 +139,30 @@ class LLMClient: ] }}""" - messages = [ChatMessage(role="user", content=prompt)] - content = await self.chat(messages, temperature=0.1) + messages = [ChatMessage(role = "user", content = prompt)] + content = await self.chat(messages, temperature = 0.1) - json_match = re.search(r"\{{.*?\}}", content, re.DOTALL) + json_match = re.search(r"\{{.*?\}}", content, re.DOTALL) if not json_match: return [], [] try: - data = json.loads(json_match.group()) - entities = [ + data = json.loads(json_match.group()) + entities = [ EntityExtractionResult( - name=e["name"], - type=e.get("type", "OTHER"), - definition=e.get("definition", ""), - confidence=e.get("confidence", 0.8), + name = e["name"], + type = e.get("type", "OTHER"), + definition = e.get("definition", ""), + confidence = e.get("confidence", 0.8), ) for e in data.get("entities", []) ] - relations = [ + relations = [ RelationExtractionResult( - source=r["source"], - target=r["target"], - type=r.get("type", "related"), - confidence=r.get("confidence", 0.8), + source = r["source"], + target = r["target"], + type = r.get("type", "related"), + confidence = r.get("confidence", 0.8), ) for r in data.get("relations", []) ] @@ -173,10 +173,10 @@ class LLMClient: async def rag_query(self, query: str, context: str, project_context: dict) -> str: """RAG 问答 - 基于项目上下文回答问题""" - prompt = f"""你是一个专业的项目分析助手。基于以下项目信息回答问题: + prompt = f"""你是一个专业的项目分析助手。基于以下项目信息回答问题: ## 项目信息 -{json.dumps(project_context, ensure_ascii=False, indent=2)} +{json.dumps(project_context, ensure_ascii = False, indent = 2)} ## 相关上下文 {context[:4000]} @@ -186,21 +186,21 @@ class LLMClient: 请用中文回答,保持简洁专业。如果信息不足,请明确说明。""" - messages = [ + messages = [ ChatMessage( - role="system", content="你是一个专业的项目分析助手,擅长从会议记录中提取洞察。" + role = "system", content = "你是一个专业的项目分析助手,擅长从会议记录中提取洞察。" ), - ChatMessage(role="user", content=prompt), + ChatMessage(role = "user", content = prompt), ] - return await self.chat(messages, temperature=0.3) + return await self.chat(messages, temperature = 0.3) async def agent_command(self, command: str, project_context: dict) -> dict: """Agent 指令解析 - 将自然语言指令转换为结构化操作""" - prompt = f"""解析以下用户指令,转换为结构化操作: + prompt = f"""解析以下用户指令,转换为结构化操作: ## 项目信息 -{json.dumps(project_context, ensure_ascii=False, indent=2)} +{json.dumps(project_context, ensure_ascii = False, indent = 2)} ## 用户指令 {command} @@ -221,10 +221,10 @@ class LLMClient: - create_relation: 创建关系,params 包含 source(源实体), target(目标实体), relation_type(关系类型) """ - messages = [ChatMessage(role="user", content=prompt)] - content = await self.chat(messages, temperature=0.1) + messages = [ChatMessage(role = "user", content = prompt)] + content = await self.chat(messages, temperature = 0.1) - json_match = re.search(r"\{{.*?\}}", content, re.DOTALL) + json_match = re.search(r"\{{.*?\}}", content, re.DOTALL) if not json_match: return {"intent": "unknown", "explanation": "无法解析指令"} @@ -235,14 +235,14 @@ class LLMClient: async def analyze_entity_evolution(self, entity_name: str, mentions: list[dict]) -> str: """分析实体在项目中的演变/态度变化""" - mentions_text = "\n".join( + mentions_text = "\n".join( [ f"[{m.get('created_at', '未知时间')}] {m.get('text_snippet', '')}" for m in mentions[:20] ] # 限制数量 ) - prompt = f"""分析实体 "{entity_name}" 在项目中的演变和态度变化: + prompt = f"""分析实体 "{entity_name}" 在项目中的演变和态度变化: ## 提及记录 {mentions_text} @@ -255,16 +255,16 @@ class LLMClient: 用中文回答,结构清晰。""" - messages = [ChatMessage(role="user", content=prompt)] - return await self.chat(messages, temperature=0.3) + messages = [ChatMessage(role = "user", content = prompt)] + return await self.chat(messages, temperature = 0.3) # Singleton instance -_llm_client = None +_llm_client = None def get_llm_client() -> LLMClient: global _llm_client if _llm_client is None: - _llm_client = LLMClient() + _llm_client = LLMClient() return _llm_client diff --git a/backend/localization_manager.py b/backend/localization_manager.py index 6325c31..344a95f 100644 --- a/backend/localization_manager.py +++ b/backend/localization_manager.py @@ -22,90 +22,90 @@ from typing import Any try: import pytz - PYTZ_AVAILABLE = True + PYTZ_AVAILABLE = True except ImportError: - PYTZ_AVAILABLE = False + PYTZ_AVAILABLE = False try: from babel import Locale, dates, numbers - BABEL_AVAILABLE = True + BABEL_AVAILABLE = True except ImportError: - BABEL_AVAILABLE = False + BABEL_AVAILABLE = False -logger = logging.getLogger(__name__) +logger = logging.getLogger(__name__) class LanguageCode(StrEnum): """支持的语言代码""" - EN = "en" - ZH_CN = "zh_CN" - ZH_TW = "zh_TW" - JA = "ja" - KO = "ko" - DE = "de" - FR = "fr" - ES = "es" - PT = "pt" - RU = "ru" - AR = "ar" - HI = "hi" + EN = "en" + ZH_CN = "zh_CN" + ZH_TW = "zh_TW" + JA = "ja" + KO = "ko" + DE = "de" + FR = "fr" + ES = "es" + PT = "pt" + RU = "ru" + AR = "ar" + HI = "hi" class RegionCode(StrEnum): """区域代码""" - GLOBAL = "global" - NORTH_AMERICA = "na" - EUROPE = "eu" - ASIA_PACIFIC = "apac" - CHINA = "cn" - LATIN_AMERICA = "latam" - MIDDLE_EAST = "me" + GLOBAL = "global" + NORTH_AMERICA = "na" + EUROPE = "eu" + ASIA_PACIFIC = "apac" + CHINA = "cn" + LATIN_AMERICA = "latam" + MIDDLE_EAST = "me" class DataCenterRegion(StrEnum): """数据中心区域""" - US_EAST = "us-east" - US_WEST = "us-west" - EU_WEST = "eu-west" - EU_CENTRAL = "eu-central" - AP_SOUTHEAST = "ap-southeast" - AP_NORTHEAST = "ap-northeast" - AP_SOUTH = "ap-south" - CN_NORTH = "cn-north" - CN_EAST = "cn-east" + US_EAST = "us-east" + US_WEST = "us-west" + EU_WEST = "eu-west" + EU_CENTRAL = "eu-central" + AP_SOUTHEAST = "ap-southeast" + AP_NORTHEAST = "ap-northeast" + AP_SOUTH = "ap-south" + CN_NORTH = "cn-north" + CN_EAST = "cn-east" class PaymentProvider(StrEnum): """支付提供商""" - STRIPE = "stripe" - ALIPAY = "alipay" - WECHAT_PAY = "wechat_pay" - PAYPAL = "paypal" - APPLE_PAY = "apple_pay" - GOOGLE_PAY = "google_pay" - KLARNA = "klarna" - IDEAL = "ideal" - BANCONTACT = "bancontact" - GIROPAY = "giropay" - SEPA = "sepa" - UNIONPAY = "unionpay" + STRIPE = "stripe" + ALIPAY = "alipay" + WECHAT_PAY = "wechat_pay" + PAYPAL = "paypal" + APPLE_PAY = "apple_pay" + GOOGLE_PAY = "google_pay" + KLARNA = "klarna" + IDEAL = "ideal" + BANCONTACT = "bancontact" + GIROPAY = "giropay" + SEPA = "sepa" + UNIONPAY = "unionpay" class CalendarType(StrEnum): """日历类型""" - GREGORIAN = "gregorian" - CHINESE_LUNAR = "chinese_lunar" - ISLAMIC = "islamic" - HEBREW = "hebrew" - INDIAN = "indian" - PERSIAN = "persian" - BUDDHIST = "buddhist" + GREGORIAN = "gregorian" + CHINESE_LUNAR = "chinese_lunar" + ISLAMIC = "islamic" + HEBREW = "hebrew" + INDIAN = "indian" + PERSIAN = "persian" + BUDDHIST = "buddhist" @dataclass @@ -252,7 +252,7 @@ class LocalizationSettings: class LocalizationManager: - DEFAULT_LANGUAGES = { + DEFAULT_LANGUAGES = { LanguageCode.EN: { "name": "English", "name_local": "English", @@ -260,8 +260,8 @@ class LocalizationManager: "date_format": "MM/dd/yyyy", "time_format": "h:mm a", "datetime_format": "MM/dd/yyyy h:mm a", - "number_format": "#,##0.##", - "currency_format": "$#,##0.00", + "number_format": "#, ##0.##", + "currency_format": "$#, ##0.00", "first_day_of_week": 0, "calendar_type": CalendarType.GREGORIAN.value, }, @@ -272,8 +272,8 @@ class LocalizationManager: "date_format": "yyyy-MM-dd", "time_format": "HH:mm", "datetime_format": "yyyy-MM-dd HH:mm", - "number_format": "#,##0.##", - "currency_format": "¥#,##0.00", + "number_format": "#, ##0.##", + "currency_format": "¥#, ##0.00", "first_day_of_week": 1, "calendar_type": CalendarType.GREGORIAN.value, }, @@ -284,8 +284,8 @@ class LocalizationManager: "date_format": "yyyy/MM/dd", "time_format": "HH:mm", "datetime_format": "yyyy/MM/dd HH:mm", - "number_format": "#,##0.##", - "currency_format": "NT$#,##0.00", + "number_format": "#, ##0.##", + "currency_format": "NT$#, ##0.00", "first_day_of_week": 0, "calendar_type": CalendarType.GREGORIAN.value, }, @@ -296,8 +296,8 @@ class LocalizationManager: "date_format": "yyyy/MM/dd", "time_format": "HH:mm", "datetime_format": "yyyy/MM/dd HH:mm", - "number_format": "#,##0.##", - "currency_format": "¥#,##0", + "number_format": "#, ##0.##", + "currency_format": "¥#, ##0", "first_day_of_week": 0, "calendar_type": CalendarType.GREGORIAN.value, }, @@ -308,8 +308,8 @@ class LocalizationManager: "date_format": "yyyy. MM. dd", "time_format": "HH:mm", "datetime_format": "yyyy. MM. dd HH:mm", - "number_format": "#,##0.##", - "currency_format": "₩#,##0", + "number_format": "#, ##0.##", + "currency_format": "₩#, ##0", "first_day_of_week": 0, "calendar_type": CalendarType.GREGORIAN.value, }, @@ -320,8 +320,8 @@ class LocalizationManager: "date_format": "dd.MM.yyyy", "time_format": "HH:mm", "datetime_format": "dd.MM.yyyy HH:mm", - "number_format": "#,##0.##", - "currency_format": "#,##0.00 €", + "number_format": "#, ##0.##", + "currency_format": "#, ##0.00 €", "first_day_of_week": 1, "calendar_type": CalendarType.GREGORIAN.value, }, @@ -332,8 +332,8 @@ class LocalizationManager: "date_format": "dd/MM/yyyy", "time_format": "HH:mm", "datetime_format": "dd/MM/yyyy HH:mm", - "number_format": "#,##0.##", - "currency_format": "#,##0.00 €", + "number_format": "#, ##0.##", + "currency_format": "#, ##0.00 €", "first_day_of_week": 1, "calendar_type": CalendarType.GREGORIAN.value, }, @@ -344,8 +344,8 @@ class LocalizationManager: "date_format": "dd/MM/yyyy", "time_format": "HH:mm", "datetime_format": "dd/MM/yyyy HH:mm", - "number_format": "#,##0.##", - "currency_format": "#,##0.00 €", + "number_format": "#, ##0.##", + "currency_format": "#, ##0.00 €", "first_day_of_week": 1, "calendar_type": CalendarType.GREGORIAN.value, }, @@ -356,8 +356,8 @@ class LocalizationManager: "date_format": "dd/MM/yyyy", "time_format": "HH:mm", "datetime_format": "dd/MM/yyyy HH:mm", - "number_format": "#,##0.##", - "currency_format": "R$#,##0.00", + "number_format": "#, ##0.##", + "currency_format": "R$#, ##0.00", "first_day_of_week": 0, "calendar_type": CalendarType.GREGORIAN.value, }, @@ -368,8 +368,8 @@ class LocalizationManager: "date_format": "dd.MM.yyyy", "time_format": "HH:mm", "datetime_format": "dd.MM.yyyy HH:mm", - "number_format": "#,##0.##", - "currency_format": "#,##0.00 ₽", + "number_format": "#, ##0.##", + "currency_format": "#, ##0.00 ₽", "first_day_of_week": 1, "calendar_type": CalendarType.GREGORIAN.value, }, @@ -380,8 +380,8 @@ class LocalizationManager: "date_format": "dd/MM/yyyy", "time_format": "hh:mm a", "datetime_format": "dd/MM/yyyy hh:mm a", - "number_format": "#,##0.##", - "currency_format": "#,##0.00 ر.س", + "number_format": "#, ##0.##", + "currency_format": "#, ##0.00 ر.س", "first_day_of_week": 6, "calendar_type": CalendarType.ISLAMIC.value, }, @@ -392,14 +392,14 @@ class LocalizationManager: "date_format": "dd/MM/yyyy", "time_format": "hh:mm a", "datetime_format": "dd/MM/yyyy hh:mm a", - "number_format": "#,##0.##", - "currency_format": "₹#,##0.00", + "number_format": "#, ##0.##", + "currency_format": "₹#, ##0.00", "first_day_of_week": 0, "calendar_type": CalendarType.INDIAN.value, }, } - DEFAULT_DATA_CENTERS = { + DEFAULT_DATA_CENTERS = { DataCenterRegion.US_EAST: { "name": "US East (Virginia)", "location": "Virginia, USA", @@ -474,7 +474,7 @@ class LocalizationManager: }, } - DEFAULT_PAYMENT_METHODS = { + DEFAULT_PAYMENT_METHODS = { PaymentProvider.STRIPE: { "name": "Credit Card", "name_local": { @@ -572,7 +572,7 @@ class LocalizationManager: }, } - DEFAULT_COUNTRIES = { + DEFAULT_COUNTRIES = { "US": { "name": "United States", "name_local": {"en": "United States"}, @@ -719,31 +719,31 @@ class LocalizationManager: }, } - def __init__(self, db_path: str = "insightflow.db"): - self.db_path = db_path - self._is_memory_db = db_path == ":memory:" - self._conn = None + def __init__(self, db_path: str = "insightflow.db") -> None: + self.db_path = db_path + self._is_memory_db = db_path == ":memory:" + self._conn = None self._init_db() self._init_default_data() def _get_connection(self) -> sqlite3.Connection: if self._is_memory_db: if self._conn is None: - self._conn = sqlite3.connect(self.db_path) - self._conn.row_factory = sqlite3.Row + self._conn = sqlite3.connect(self.db_path) + self._conn.row_factory = sqlite3.Row return self._conn - conn = sqlite3.connect(self.db_path) - conn.row_factory = sqlite3.Row + conn = sqlite3.connect(self.db_path) + conn.row_factory = sqlite3.Row return conn - def _close_if_file_db(self, conn): + def _close_if_file_db(self, conn) -> None: if not self._is_memory_db: conn.close() - def _init_db(self): - conn = self._get_connection() + def _init_db(self) -> None: + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute(""" CREATE TABLE IF NOT EXISTS translations ( id TEXT PRIMARY KEY, key TEXT NOT NULL, language TEXT NOT NULL, @@ -813,7 +813,7 @@ class LocalizationManager: CREATE TABLE IF NOT EXISTS currency_configs ( code TEXT PRIMARY KEY, name TEXT NOT NULL, name_local TEXT DEFAULT '{}', symbol TEXT NOT NULL, decimal_places INTEGER DEFAULT 2, decimal_separator TEXT DEFAULT '.', - thousands_separator TEXT DEFAULT ',', is_active INTEGER DEFAULT 1 + thousands_separator TEXT DEFAULT ', ', is_active INTEGER DEFAULT 1 ) """) cursor.execute(""" @@ -863,10 +863,10 @@ class LocalizationManager: finally: self._close_if_file_db(conn) - def _init_default_data(self): - conn = self._get_connection() + def _init_default_data(self) -> None: + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() for code, config in self.DEFAULT_LANGUAGES.items(): cursor.execute( """ @@ -894,7 +894,7 @@ class LocalizationManager: ), ) for region_code, config in self.DEFAULT_DATA_CENTERS.items(): - dc_id = str(uuid.uuid4()) + dc_id = str(uuid.uuid4()) cursor.execute( """ INSERT OR IGNORE INTO data_centers @@ -913,7 +913,7 @@ class LocalizationManager: ), ) for provider, config in self.DEFAULT_PAYMENT_METHODS.items(): - pm_id = str(uuid.uuid4()) + pm_id = str(uuid.uuid4()) cursor.execute( """ INSERT OR IGNORE INTO localized_payment_methods @@ -963,20 +963,20 @@ class LocalizationManager: self._close_if_file_db(conn) def get_translation( - self, key: str, language: str, namespace: str = "common", fallback: bool = True + self, key: str, language: str, namespace: str = "common", fallback: bool = True ) -> str | None: - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( - "SELECT value FROM translations WHERE key = ? AND language = ? AND namespace = ?", + "SELECT value FROM translations WHERE key = ? AND language = ? AND namespace = ?", (key, language, namespace), ) - row = cursor.fetchone() + row = cursor.fetchone() if row: return row["value"] if fallback: - lang_config = self.get_language_config(language) + lang_config = self.get_language_config(language) if lang_config and lang_config.fallback_language: return self.get_translation( key, lang_config.fallback_language, namespace, False @@ -992,24 +992,24 @@ class LocalizationManager: key: str, language: str, value: str, - namespace: str = "common", - context: str | None = None, + namespace: str = "common", + context: str | None = None, ) -> Translation: - conn = self._get_connection() + conn = self._get_connection() try: - translation_id = str(uuid.uuid4()) - now = datetime.now() - cursor = conn.cursor() + translation_id = str(uuid.uuid4()) + now = datetime.now() + cursor = conn.cursor() cursor.execute( """ INSERT INTO translations (id, key, language, value, namespace, context, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?) ON CONFLICT(key, language, namespace) DO UPDATE SET - value = excluded.value, - context = excluded.context, - updated_at = excluded.updated_at, - is_reviewed = 0 + value = excluded.value, + context = excluded.context, + updated_at = excluded.updated_at, + is_reviewed = 0 """, (translation_id, key, language, value, namespace, context, now, now), ) @@ -1021,22 +1021,22 @@ class LocalizationManager: def _get_translation_internal( self, conn: sqlite3.Connection, key: str, language: str, namespace: str ) -> Translation | None: - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( - "SELECT * FROM translations WHERE key = ? AND language = ? AND namespace = ?", + "SELECT * FROM translations WHERE key = ? AND language = ? AND namespace = ?", (key, language, namespace), ) - row = cursor.fetchone() + row = cursor.fetchone() if row: return self._row_to_translation(row) return None - def delete_translation(self, key: str, language: str, namespace: str = "common") -> bool: - conn = self._get_connection() + def delete_translation(self, key: str, language: str, namespace: str = "common") -> bool: + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( - "DELETE FROM translations WHERE key = ? AND language = ? AND namespace = ?", + "DELETE FROM translations WHERE key = ? AND language = ? AND namespace = ?", (key, language, namespace), ) conn.commit() @@ -1046,62 +1046,62 @@ class LocalizationManager: def list_translations( self, - language: str | None = None, - namespace: str | None = None, - limit: int = 1000, - offset: int = 0, + language: str | None = None, + namespace: str | None = None, + limit: int = 1000, + offset: int = 0, ) -> list[Translation]: - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() - query = "SELECT * FROM translations WHERE 1=1" - params = [] + cursor = conn.cursor() + query = "SELECT * FROM translations WHERE 1 = 1" + params = [] if language: - query += " AND language = ?" + query += " AND language = ?" params.append(language) if namespace: - query += " AND namespace = ?" + query += " AND namespace = ?" params.append(namespace) query += " ORDER BY namespace, key LIMIT ? OFFSET ?" params.extend([limit, offset]) cursor.execute(query, params) - rows = cursor.fetchall() + rows = cursor.fetchall() return [self._row_to_translation(row) for row in rows] finally: self._close_if_file_db(conn) def get_language_config(self, code: str) -> LanguageConfig | None: - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() - cursor.execute("SELECT * FROM language_configs WHERE code = ?", (code,)) - row = cursor.fetchone() + cursor = conn.cursor() + cursor.execute("SELECT * FROM language_configs WHERE code = ?", (code, )) + row = cursor.fetchone() if row: return self._row_to_language_config(row) return None finally: self._close_if_file_db(conn) - def list_language_configs(self, active_only: bool = True) -> list[LanguageConfig]: - conn = self._get_connection() + def list_language_configs(self, active_only: bool = True) -> list[LanguageConfig]: + conn = self._get_connection() try: - cursor = conn.cursor() - query = "SELECT * FROM language_configs" + cursor = conn.cursor() + query = "SELECT * FROM language_configs" if active_only: - query += " WHERE is_active = 1" + query += " WHERE is_active = 1" query += " ORDER BY name" cursor.execute(query) - rows = cursor.fetchall() + rows = cursor.fetchall() return [self._row_to_language_config(row) for row in rows] finally: self._close_if_file_db(conn) def get_data_center(self, dc_id: str) -> DataCenter | None: - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() - cursor.execute("SELECT * FROM data_centers WHERE id = ?", (dc_id,)) - row = cursor.fetchone() + cursor = conn.cursor() + cursor.execute("SELECT * FROM data_centers WHERE id = ?", (dc_id, )) + row = cursor.fetchone() if row: return self._row_to_data_center(row) return None @@ -1109,11 +1109,11 @@ class LocalizationManager: self._close_if_file_db(conn) def get_data_center_by_region(self, region_code: str) -> DataCenter | None: - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() - cursor.execute("SELECT * FROM data_centers WHERE region_code = ?", (region_code,)) - row = cursor.fetchone() + cursor = conn.cursor() + cursor.execute("SELECT * FROM data_centers WHERE region_code = ?", (region_code, )) + row = cursor.fetchone() if row: return self._row_to_data_center(row) return None @@ -1121,34 +1121,34 @@ class LocalizationManager: self._close_if_file_db(conn) def list_data_centers( - self, status: str | None = None, region: str | None = None + self, status: str | None = None, region: str | None = None ) -> list[DataCenter]: - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() - query = "SELECT * FROM data_centers WHERE 1=1" - params = [] + cursor = conn.cursor() + query = "SELECT * FROM data_centers WHERE 1 = 1" + params = [] if status: - query += " AND status = ?" + query += " AND status = ?" params.append(status) if region: query += " AND supported_regions LIKE ?" params.append(f'%"{region}"%') query += " ORDER BY priority" cursor.execute(query, params) - rows = cursor.fetchall() + rows = cursor.fetchall() return [self._row_to_data_center(row) for row in rows] finally: self._close_if_file_db(conn) def get_tenant_data_center(self, tenant_id: str) -> TenantDataCenterMapping | None: - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( - "SELECT * FROM tenant_data_center_mappings WHERE tenant_id = ?", (tenant_id,) + "SELECT * FROM tenant_data_center_mappings WHERE tenant_id = ?", (tenant_id, ) ) - row = cursor.fetchone() + row = cursor.fetchone() if row: return self._row_to_tenant_dc_mapping(row) return None @@ -1156,38 +1156,38 @@ class LocalizationManager: self._close_if_file_db(conn) def set_tenant_data_center( - self, tenant_id: str, region_code: str, data_residency: str = "regional" + self, tenant_id: str, region_code: str, data_residency: str = "regional" ) -> TenantDataCenterMapping: - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ - SELECT * FROM data_centers WHERE supported_regions LIKE ? AND status = 'active' + SELECT * FROM data_centers WHERE supported_regions LIKE ? AND status = 'active' ORDER BY priority LIMIT 1 """, - (f'%"{region_code}"%',), + (f'%"{region_code}"%', ), ) - row = cursor.fetchone() + row = cursor.fetchone() if not row: cursor.execute(""" - SELECT * FROM data_centers WHERE supported_regions LIKE '%"global"%' AND status = 'active' + SELECT * FROM data_centers WHERE supported_regions LIKE '%"global"%' AND status = 'active' ORDER BY priority LIMIT 1 """) - row = cursor.fetchone() + row = cursor.fetchone() if not row: raise ValueError(f"No data center available for region: {region_code}") - primary_dc_id = row["id"] + primary_dc_id = row["id"] cursor.execute( """ - SELECT * FROM data_centers WHERE id != ? AND status = 'active' ORDER BY priority LIMIT 1 + SELECT * FROM data_centers WHERE id != ? AND status = 'active' ORDER BY priority LIMIT 1 """, - (primary_dc_id,), + (primary_dc_id, ), ) - secondary_row = cursor.fetchone() - secondary_dc_id = secondary_row["id"] if secondary_row else None - mapping_id = str(uuid.uuid4()) - now = datetime.now() + secondary_row = cursor.fetchone() + secondary_dc_id = secondary_row["id"] if secondary_row else None + mapping_id = str(uuid.uuid4()) + now = datetime.now() cursor.execute( """ INSERT INTO tenant_data_center_mappings @@ -1195,11 +1195,11 @@ class LocalizationManager: data_residency, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?) ON CONFLICT(tenant_id) DO UPDATE SET - primary_dc_id = excluded.primary_dc_id, - secondary_dc_id = excluded.secondary_dc_id, - region_code = excluded.region_code, - data_residency = excluded.data_residency, - updated_at = excluded.updated_at + primary_dc_id = excluded.primary_dc_id, + secondary_dc_id = excluded.secondary_dc_id, + region_code = excluded.region_code, + data_residency = excluded.data_residency, + updated_at = excluded.updated_at """, ( mapping_id, @@ -1218,13 +1218,13 @@ class LocalizationManager: self._close_if_file_db(conn) def get_payment_method(self, provider: str) -> LocalizedPaymentMethod | None: - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( - "SELECT * FROM localized_payment_methods WHERE provider = ?", (provider,) + "SELECT * FROM localized_payment_methods WHERE provider = ?", (provider, ) ) - row = cursor.fetchone() + row = cursor.fetchone() if row: return self._row_to_payment_method(row) return None @@ -1232,15 +1232,15 @@ class LocalizationManager: self._close_if_file_db(conn) def list_payment_methods( - self, country_code: str | None = None, currency: str | None = None, active_only: bool = True + self, country_code: str | None = None, currency: str | None = None, active_only: bool = True ) -> list[LocalizedPaymentMethod]: - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() - query = "SELECT * FROM localized_payment_methods WHERE 1=1" - params = [] + cursor = conn.cursor() + query = "SELECT * FROM localized_payment_methods WHERE 1 = 1" + params = [] if active_only: - query += " AND is_active = 1" + query += " AND is_active = 1" if country_code: query += " AND (supported_countries LIKE ? OR supported_countries LIKE '%\"*\"%')" params.append(f'%"{country_code}"%') @@ -1249,18 +1249,18 @@ class LocalizationManager: params.append(f'%"{currency}"%') query += " ORDER BY display_order" cursor.execute(query, params) - rows = cursor.fetchall() + rows = cursor.fetchall() return [self._row_to_payment_method(row) for row in rows] finally: self._close_if_file_db(conn) def get_localized_payment_methods( - self, country_code: str, language: str = "en" + self, country_code: str, language: str = "en" ) -> list[dict[str, Any]]: - methods = self.list_payment_methods(country_code=country_code) - result = [] + methods = self.list_payment_methods(country_code = country_code) + result = [] for method in methods: - name_local = method.name_local.get(language, method.name) + name_local = method.name_local.get(language, method.name) result.append( { "id": method.id, @@ -1275,11 +1275,11 @@ class LocalizationManager: return result def get_country_config(self, code: str) -> CountryConfig | None: - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() - cursor.execute("SELECT * FROM country_configs WHERE code = ?", (code,)) - row = cursor.fetchone() + cursor = conn.cursor() + cursor.execute("SELECT * FROM country_configs WHERE code = ?", (code, )) + row = cursor.fetchone() if row: return self._row_to_country_config(row) return None @@ -1287,21 +1287,21 @@ class LocalizationManager: self._close_if_file_db(conn) def list_country_configs( - self, region: str | None = None, active_only: bool = True + self, region: str | None = None, active_only: bool = True ) -> list[CountryConfig]: - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() - query = "SELECT * FROM country_configs WHERE 1=1" - params = [] + cursor = conn.cursor() + query = "SELECT * FROM country_configs WHERE 1 = 1" + params = [] if active_only: - query += " AND is_active = 1" + query += " AND is_active = 1" if region: - query += " AND region = ?" + query += " AND region = ?" params.append(region) query += " ORDER BY name" cursor.execute(query, params) - rows = cursor.fetchall() + rows = cursor.fetchall() return [self._row_to_country_config(row) for row in rows] finally: self._close_if_file_db(conn) @@ -1309,34 +1309,34 @@ class LocalizationManager: def format_datetime( self, dt: datetime, - language: str = "en", - timezone: str | None = None, - format_type: str = "datetime", + language: str = "en", + timezone: str | None = None, + format_type: str = "datetime", ) -> str: try: if timezone and PYTZ_AVAILABLE: - tz = pytz.timezone(timezone) + tz = pytz.timezone(timezone) if dt.tzinfo is None: - dt = pytz.UTC.localize(dt) - dt = dt.astimezone(tz) - lang_config = self.get_language_config(language) + dt = pytz.UTC.localize(dt) + dt = dt.astimezone(tz) + lang_config = self.get_language_config(language) if not lang_config: - lang_config = self.get_language_config("en") + lang_config = self.get_language_config("en") if format_type == "date": - fmt = lang_config.date_format if lang_config else "%Y-%m-%d" + fmt = lang_config.date_format if lang_config else "%Y-%m-%d" elif format_type == "time": - fmt = lang_config.time_format if lang_config else "%H:%M" + fmt = lang_config.time_format if lang_config else "%H:%M" else: - fmt = lang_config.datetime_format if lang_config else "%Y-%m-%d %H:%M" + fmt = lang_config.datetime_format if lang_config else "%Y-%m-%d %H:%M" if BABEL_AVAILABLE: try: - locale = Locale.parse(language.replace("_", "-")) + locale = Locale.parse(language.replace("_", "-")) if format_type == "date": - return dates.format_date(dt, locale=locale) + return dates.format_date(dt, locale = locale) elif format_type == "time": - return dates.format_time(dt, locale=locale) + return dates.format_time(dt, locale = locale) else: - return dates.format_datetime(dt, locale=locale) + return dates.format_datetime(dt, locale = locale) except (ValueError, AttributeError): pass return dt.strftime(fmt) @@ -1345,33 +1345,33 @@ class LocalizationManager: return dt.strftime("%Y-%m-%d %H:%M") def format_number( - self, number: float, language: str = "en", decimal_places: int | None = None + self, number: float, language: str = "en", decimal_places: int | None = None ) -> str: try: if BABEL_AVAILABLE: try: - locale = Locale.parse(language.replace("_", "-")) + locale = Locale.parse(language.replace("_", "-")) return numbers.format_decimal( - number, locale=locale, decimal_quantization=(decimal_places is not None) + number, locale = locale, decimal_quantization = (decimal_places is not None) ) except (ValueError, AttributeError): pass if decimal_places is not None: - return f"{number:,.{decimal_places}f}" - return f"{number:,}" + return f"{number:, .{decimal_places}f}" + return f"{number:, }" except Exception as e: logger.error(f"Error formatting number: {e}") return str(number) - def format_currency(self, amount: float, currency: str, language: str = "en") -> str: + def format_currency(self, amount: float, currency: str, language: str = "en") -> str: try: if BABEL_AVAILABLE: try: - locale = Locale.parse(language.replace("_", "-")) - return numbers.format_currency(amount, currency, locale=locale) + locale = Locale.parse(language.replace("_", "-")) + return numbers.format_currency(amount, currency, locale = locale) except (ValueError, AttributeError): pass - return f"{currency} {amount:,.2f}" + return f"{currency} {amount:, .2f}" except Exception as e: logger.error(f"Error formatting currency: {e}") return f"{currency} {amount:.2f}" @@ -1379,10 +1379,10 @@ class LocalizationManager: def convert_timezone(self, dt: datetime, from_tz: str, to_tz: str) -> datetime: try: if PYTZ_AVAILABLE: - from_zone = pytz.timezone(from_tz) - to_zone = pytz.timezone(to_tz) + from_zone = pytz.timezone(from_tz) + to_zone = pytz.timezone(to_tz) if dt.tzinfo is None: - dt = from_zone.localize(dt) + dt = from_zone.localize(dt) return dt.astimezone(to_zone) return dt except Exception as e: @@ -1392,8 +1392,8 @@ class LocalizationManager: def get_calendar_info(self, calendar_type: str, year: int, month: int) -> dict[str, Any]: import calendar - cal = calendar.Calendar() - month_days = cal.monthdayscalendar(year, month) + cal = calendar.Calendar() + month_days = cal.monthdayscalendar(year, month) return { "calendar_type": calendar_type, "year": year, @@ -1405,11 +1405,11 @@ class LocalizationManager: } def get_localization_settings(self, tenant_id: str) -> LocalizationSettings | None: - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() - cursor.execute("SELECT * FROM localization_settings WHERE tenant_id = ?", (tenant_id,)) - row = cursor.fetchone() + cursor = conn.cursor() + cursor.execute("SELECT * FROM localization_settings WHERE tenant_id = ?", (tenant_id, )) + row = cursor.fetchone() if row: return self._row_to_localization_settings(row) return None @@ -1419,22 +1419,22 @@ class LocalizationManager: def create_localization_settings( self, tenant_id: str, - default_language: str = "en", - supported_languages: list[str] | None = None, - default_currency: str = "USD", - supported_currencies: list[str] | None = None, - default_timezone: str = "UTC", - region_code: str = "global", - data_residency: str = "regional", + default_language: str = "en", + supported_languages: list[str] | None = None, + default_currency: str = "USD", + supported_currencies: list[str] | None = None, + default_timezone: str = "UTC", + region_code: str = "global", + data_residency: str = "regional", ) -> LocalizationSettings: - conn = self._get_connection() + conn = self._get_connection() try: - settings_id = str(uuid.uuid4()) - now = datetime.now() - supported_languages = supported_languages or [default_language] - supported_currencies = supported_currencies or [default_currency] - lang_config = self.get_language_config(default_language) - cursor = conn.cursor() + settings_id = str(uuid.uuid4()) + now = datetime.now() + supported_languages = supported_languages or [default_language] + supported_currencies = supported_currencies or [default_currency] + lang_config = self.get_language_config(default_language) + cursor = conn.cursor() cursor.execute( """ INSERT INTO localization_settings @@ -1453,7 +1453,7 @@ class LocalizationManager: default_timezone, lang_config.date_format if lang_config else "%Y-%m-%d", lang_config.time_format if lang_config else "%H:%M", - lang_config.number_format if lang_config else "#,##0.##", + lang_config.number_format if lang_config else "#, ##0.##", lang_config.calendar_type if lang_config else CalendarType.GREGORIAN.value, lang_config.first_day_of_week if lang_config else 1, region_code, @@ -1468,14 +1468,14 @@ class LocalizationManager: self._close_if_file_db(conn) def update_localization_settings(self, tenant_id: str, **kwargs) -> LocalizationSettings | None: - conn = self._get_connection() + conn = self._get_connection() try: - settings = self.get_localization_settings(tenant_id) + settings = self.get_localization_settings(tenant_id) if not settings: return None - updates = [] - params = [] - allowed_fields = [ + updates = [] + params = [] + allowed_fields = [ "default_language", "supported_languages", "default_currency", @@ -1491,7 +1491,7 @@ class LocalizationManager: ] for key, value in kwargs.items(): if key in allowed_fields: - updates.append(f"{key} = ?") + updates.append(f"{key} = ?") if key in ["supported_languages", "supported_currencies"]: params.append(json.dumps(value) if value else "[]") elif key == "first_day_of_week": @@ -1500,12 +1500,12 @@ class LocalizationManager: params.append(value) if not updates: return settings - updates.append("updated_at = ?") + updates.append("updated_at = ?") params.append(datetime.now()) params.append(tenant_id) - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( - f"UPDATE localization_settings SET {', '.join(updates)} WHERE tenant_id = ?", params + f"UPDATE localization_settings SET {', '.join(updates)} WHERE tenant_id = ?", params ) conn.commit() return self.get_localization_settings(tenant_id) @@ -1513,48 +1513,48 @@ class LocalizationManager: self._close_if_file_db(conn) def detect_user_preferences( - self, accept_language: str | None = None, ip_country: str | None = None + self, accept_language: str | None = None, ip_country: str | None = None ) -> dict[str, str]: - preferences = {"language": "en", "country": "US", "timezone": "UTC", "currency": "USD"} + preferences = {"language": "en", "country": "US", "timezone": "UTC", "currency": "USD"} if accept_language: - langs = accept_language.split(",") + langs = accept_language.split(", ") for lang in langs: - lang_code = lang.split(";")[0].strip().replace("-", "_") - lang_config = self.get_language_config(lang_code) + lang_code = lang.split(";")[0].strip().replace("-", "_") + lang_config = self.get_language_config(lang_code) if lang_config and lang_config.is_active: - preferences["language"] = lang_code + preferences["language"] = lang_code break if ip_country: - country = self.get_country_config(ip_country) + country = self.get_country_config(ip_country) if country: - preferences["country"] = ip_country - preferences["currency"] = country.default_currency - preferences["timezone"] = country.timezone + preferences["country"] = ip_country + preferences["currency"] = country.default_currency + preferences["timezone"] = country.timezone if country.default_language not in preferences["language"]: - preferences["language"] = country.default_language + preferences["language"] = country.default_language return preferences def _row_to_translation(self, row: sqlite3.Row) -> Translation: return Translation( - id=row["id"], - key=row["key"], - language=row["language"], - value=row["value"], - namespace=row["namespace"], - context=row["context"], - created_at=( + id = row["id"], + key = row["key"], + language = row["language"], + value = row["value"], + namespace = row["namespace"], + context = row["context"], + created_at = ( datetime.fromisoformat(row["created_at"]) if isinstance(row["created_at"], str) else row["created_at"] ), - updated_at=( + updated_at = ( datetime.fromisoformat(row["updated_at"]) if isinstance(row["updated_at"], str) else row["updated_at"] ), - is_reviewed=bool(row["is_reviewed"]), - reviewed_by=row["reviewed_by"], - reviewed_at=( + is_reviewed = bool(row["is_reviewed"]), + reviewed_by = row["reviewed_by"], + reviewed_at = ( datetime.fromisoformat(row["reviewed_at"]) if row["reviewed_at"] and isinstance(row["reviewed_at"], str) else row["reviewed_at"] @@ -1563,39 +1563,39 @@ class LocalizationManager: def _row_to_language_config(self, row: sqlite3.Row) -> LanguageConfig: return LanguageConfig( - code=row["code"], - name=row["name"], - name_local=row["name_local"], - is_rtl=bool(row["is_rtl"]), - is_active=bool(row["is_active"]), - is_default=bool(row["is_default"]), - fallback_language=row["fallback_language"], - date_format=row["date_format"], - time_format=row["time_format"], - datetime_format=row["datetime_format"], - number_format=row["number_format"], - currency_format=row["currency_format"], - first_day_of_week=row["first_day_of_week"], - calendar_type=row["calendar_type"], + code = row["code"], + name = row["name"], + name_local = row["name_local"], + is_rtl = bool(row["is_rtl"]), + is_active = bool(row["is_active"]), + is_default = bool(row["is_default"]), + fallback_language = row["fallback_language"], + date_format = row["date_format"], + time_format = row["time_format"], + datetime_format = row["datetime_format"], + number_format = row["number_format"], + currency_format = row["currency_format"], + first_day_of_week = row["first_day_of_week"], + calendar_type = row["calendar_type"], ) def _row_to_data_center(self, row: sqlite3.Row) -> DataCenter: return DataCenter( - id=row["id"], - region_code=row["region_code"], - name=row["name"], - location=row["location"], - endpoint=row["endpoint"], - status=row["status"], - priority=row["priority"], - supported_regions=json.loads(row["supported_regions"] or "[]"), - capabilities=json.loads(row["capabilities"] or "{}"), - created_at=( + id = row["id"], + region_code = row["region_code"], + name = row["name"], + location = row["location"], + endpoint = row["endpoint"], + status = row["status"], + priority = row["priority"], + supported_regions = json.loads(row["supported_regions"] or "[]"), + capabilities = json.loads(row["capabilities"] or "{}"), + created_at = ( datetime.fromisoformat(row["created_at"]) if isinstance(row["created_at"], str) else row["created_at"] ), - updated_at=( + updated_at = ( datetime.fromisoformat(row["updated_at"]) if isinstance(row["updated_at"], str) else row["updated_at"] @@ -1604,18 +1604,18 @@ class LocalizationManager: def _row_to_tenant_dc_mapping(self, row: sqlite3.Row) -> TenantDataCenterMapping: return TenantDataCenterMapping( - id=row["id"], - tenant_id=row["tenant_id"], - primary_dc_id=row["primary_dc_id"], - secondary_dc_id=row["secondary_dc_id"], - region_code=row["region_code"], - data_residency=row["data_residency"], - created_at=( + id = row["id"], + tenant_id = row["tenant_id"], + primary_dc_id = row["primary_dc_id"], + secondary_dc_id = row["secondary_dc_id"], + region_code = row["region_code"], + data_residency = row["data_residency"], + created_at = ( datetime.fromisoformat(row["created_at"]) if isinstance(row["created_at"], str) else row["created_at"] ), - updated_at=( + updated_at = ( datetime.fromisoformat(row["updated_at"]) if isinstance(row["updated_at"], str) else row["updated_at"] @@ -1624,24 +1624,24 @@ class LocalizationManager: def _row_to_payment_method(self, row: sqlite3.Row) -> LocalizedPaymentMethod: return LocalizedPaymentMethod( - id=row["id"], - provider=row["provider"], - name=row["name"], - name_local=json.loads(row["name_local"] or "{}"), - supported_countries=json.loads(row["supported_countries"] or "[]"), - supported_currencies=json.loads(row["supported_currencies"] or "[]"), - is_active=bool(row["is_active"]), - config=json.loads(row["config"] or "{}"), - icon_url=row["icon_url"], - display_order=row["display_order"], - min_amount=row["min_amount"], - max_amount=row["max_amount"], - created_at=( + id = row["id"], + provider = row["provider"], + name = row["name"], + name_local = json.loads(row["name_local"] or "{}"), + supported_countries = json.loads(row["supported_countries"] or "[]"), + supported_currencies = json.loads(row["supported_currencies"] or "[]"), + is_active = bool(row["is_active"]), + config = json.loads(row["config"] or "{}"), + icon_url = row["icon_url"], + display_order = row["display_order"], + min_amount = row["min_amount"], + max_amount = row["max_amount"], + created_at = ( datetime.fromisoformat(row["created_at"]) if isinstance(row["created_at"], str) else row["created_at"] ), - updated_at=( + updated_at = ( datetime.fromisoformat(row["updated_at"]) if isinstance(row["updated_at"], str) else row["updated_at"] @@ -1650,48 +1650,48 @@ class LocalizationManager: def _row_to_country_config(self, row: sqlite3.Row) -> CountryConfig: return CountryConfig( - code=row["code"], - code3=row["code3"], - name=row["name"], - name_local=json.loads(row["name_local"] or "{}"), - region=row["region"], - default_language=row["default_language"], - supported_languages=json.loads(row["supported_languages"] or "[]"), - default_currency=row["default_currency"], - supported_currencies=json.loads(row["supported_currencies"] or "[]"), - timezone=row["timezone"], - calendar_type=row["calendar_type"], - date_format=row["date_format"], - time_format=row["time_format"], - number_format=row["number_format"], - address_format=row["address_format"], - phone_format=row["phone_format"], - vat_rate=row["vat_rate"], - is_active=bool(row["is_active"]), + code = row["code"], + code3 = row["code3"], + name = row["name"], + name_local = json.loads(row["name_local"] or "{}"), + region = row["region"], + default_language = row["default_language"], + supported_languages = json.loads(row["supported_languages"] or "[]"), + default_currency = row["default_currency"], + supported_currencies = json.loads(row["supported_currencies"] or "[]"), + timezone = row["timezone"], + calendar_type = row["calendar_type"], + date_format = row["date_format"], + time_format = row["time_format"], + number_format = row["number_format"], + address_format = row["address_format"], + phone_format = row["phone_format"], + vat_rate = row["vat_rate"], + is_active = bool(row["is_active"]), ) def _row_to_localization_settings(self, row: sqlite3.Row) -> LocalizationSettings: return LocalizationSettings( - id=row["id"], - tenant_id=row["tenant_id"], - default_language=row["default_language"], - supported_languages=json.loads(row["supported_languages"] or '["en"]'), - default_currency=row["default_currency"], - supported_currencies=json.loads(row["supported_currencies"] or '["USD"]'), - default_timezone=row["default_timezone"], - default_date_format=row["default_date_format"], - default_time_format=row["default_time_format"], - default_number_format=row["default_number_format"], - calendar_type=row["calendar_type"], - first_day_of_week=row["first_day_of_week"], - region_code=row["region_code"], - data_residency=row["data_residency"], - created_at=( + id = row["id"], + tenant_id = row["tenant_id"], + default_language = row["default_language"], + supported_languages = json.loads(row["supported_languages"] or '["en"]'), + default_currency = row["default_currency"], + supported_currencies = json.loads(row["supported_currencies"] or '["USD"]'), + default_timezone = row["default_timezone"], + default_date_format = row["default_date_format"], + default_time_format = row["default_time_format"], + default_number_format = row["default_number_format"], + calendar_type = row["calendar_type"], + first_day_of_week = row["first_day_of_week"], + region_code = row["region_code"], + data_residency = row["data_residency"], + created_at = ( datetime.fromisoformat(row["created_at"]) if isinstance(row["created_at"], str) else row["created_at"] ), - updated_at=( + updated_at = ( datetime.fromisoformat(row["updated_at"]) if isinstance(row["updated_at"], str) else row["updated_at"] @@ -1699,11 +1699,11 @@ class LocalizationManager: ) -_localization_manager = None +_localization_manager = None -def get_localization_manager(db_path: str = "insightflow.db") -> LocalizationManager: +def get_localization_manager(db_path: str = "insightflow.db") -> LocalizationManager: global _localization_manager if _localization_manager is None: - _localization_manager = LocalizationManager(db_path) + _localization_manager = LocalizationManager(db_path) return _localization_manager diff --git a/backend/multimodal_entity_linker.py b/backend/multimodal_entity_linker.py index ca73030..803f566 100644 --- a/backend/multimodal_entity_linker.py +++ b/backend/multimodal_entity_linker.py @@ -9,13 +9,13 @@ from dataclasses import dataclass from difflib import SequenceMatcher # Constants -UUID_LENGTH = 8 # UUID 截断长度 +UUID_LENGTH = 8 # UUID 截断长度 # 尝试导入embedding库 try: - NUMPY_AVAILABLE = True + NUMPY_AVAILABLE = True except ImportError: - NUMPY_AVAILABLE = False + NUMPY_AVAILABLE = False @dataclass @@ -30,11 +30,11 @@ class MultimodalEntity: source_id: str mention_context: str confidence: float - modality_features: dict = None # 模态特定特征 + modality_features: dict = None # 模态特定特征 - def __post_init__(self): + def __post_init__(self) -> None: if self.modality_features is None: - self.modality_features = {} + self.modality_features = {} @dataclass @@ -78,7 +78,7 @@ class MultimodalEntityLinker: """多模态实体关联器 - 跨模态实体对齐和知识融合""" # 关联类型 - LINK_TYPES = { + LINK_TYPES = { "same_as": "同一实体", "related_to": "相关实体", "part_of": "组成部分", @@ -86,16 +86,16 @@ class MultimodalEntityLinker: } # 模态类型 - MODALITIES = ["audio", "video", "image", "document"] + MODALITIES = ["audio", "video", "image", "document"] - def __init__(self, similarity_threshold: float = 0.85) -> None: + def __init__(self, similarity_threshold: float = 0.85) -> None: """ 初始化多模态实体关联器 Args: similarity_threshold: 相似度阈值 """ - self.similarity_threshold = similarity_threshold + self.similarity_threshold = similarity_threshold def calculate_string_similarity(self, s1: str, s2: str) -> float: """ @@ -111,7 +111,7 @@ class MultimodalEntityLinker: if not s1 or not s2: return 0.0 - s1, s2 = s1.lower().strip(), s2.lower().strip() + s1, s2 = s1.lower().strip(), s2.lower().strip() # 完全匹配 if s1 == s2: @@ -136,7 +136,7 @@ class MultimodalEntityLinker: (相似度, 匹配类型) """ # 名称相似度 - name_sim = self.calculate_string_similarity( + name_sim = self.calculate_string_similarity( entity1.get("name", ""), entity2.get("name", "") ) @@ -145,8 +145,8 @@ class MultimodalEntityLinker: return 1.0, "exact" # 检查别名 - aliases1 = set(a.lower() for a in entity1.get("aliases", [])) - aliases2 = set(a.lower() for a in entity2.get("aliases", [])) + aliases1 = set(a.lower() for a in entity1.get("aliases", [])) + aliases2 = set(a.lower() for a in entity2.get("aliases", [])) if aliases1 & aliases2: # 有共同别名 return 0.95, "alias_match" @@ -157,12 +157,12 @@ class MultimodalEntityLinker: return 0.95, "alias_match" # 定义相似度 - def_sim = self.calculate_string_similarity( + def_sim = self.calculate_string_similarity( entity1.get("definition", ""), entity2.get("definition", "") ) # 综合相似度 - combined_sim = name_sim * 0.7 + def_sim * 0.3 + combined_sim = name_sim * 0.7 + def_sim * 0.3 if combined_sim >= self.similarity_threshold: return combined_sim, "fuzzy" @@ -170,7 +170,7 @@ class MultimodalEntityLinker: return combined_sim, "none" def find_matching_entity( - self, query_entity: dict, candidate_entities: list[dict], exclude_ids: set[str] = None + self, query_entity: dict, candidate_entities: list[dict], exclude_ids: set[str] = None ) -> AlignmentResult | None: """ 在候选实体中查找匹配的实体 @@ -183,28 +183,28 @@ class MultimodalEntityLinker: Returns: 对齐结果 """ - exclude_ids = exclude_ids or set() - best_match = None - best_similarity = 0.0 + exclude_ids = exclude_ids or set() + best_match = None + best_similarity = 0.0 for candidate in candidate_entities: if candidate.get("id") in exclude_ids: continue - similarity, match_type = self.calculate_entity_similarity(query_entity, candidate) + similarity, match_type = self.calculate_entity_similarity(query_entity, candidate) if similarity > best_similarity and similarity >= self.similarity_threshold: - best_similarity = similarity - best_match = candidate - best_match_type = match_type + best_similarity = similarity + best_match = candidate + best_match_type = match_type if best_match: return AlignmentResult( - entity_id=query_entity.get("id"), - matched_entity_id=best_match.get("id"), - similarity=best_similarity, - match_type=best_match_type, - confidence=best_similarity, + entity_id = query_entity.get("id"), + matched_entity_id = best_match.get("id"), + similarity = best_similarity, + match_type = best_match_type, + confidence = best_similarity, ) return None @@ -230,10 +230,10 @@ class MultimodalEntityLinker: Returns: 实体关联列表 """ - links = [] + links = [] # 合并所有实体 - all_entities = { + all_entities = { "audio": audio_entities, "video": video_entities, "image": image_entities, @@ -246,24 +246,24 @@ class MultimodalEntityLinker: if mod1 >= mod2: # 避免重复比较 continue - entities1 = all_entities.get(mod1, []) - entities2 = all_entities.get(mod2, []) + entities1 = all_entities.get(mod1, []) + entities2 = all_entities.get(mod2, []) for ent1 in entities1: # 在另一个模态中查找匹配 - result = self.find_matching_entity(ent1, entities2) + result = self.find_matching_entity(ent1, entities2) if result and result.matched_entity_id: - link = EntityLink( - id=str(uuid.uuid4())[:UUID_LENGTH], - project_id=project_id, - source_entity_id=ent1.get("id"), - target_entity_id=result.matched_entity_id, - link_type="same_as" if result.similarity > 0.95 else "related_to", - source_modality=mod1, - target_modality=mod2, - confidence=result.confidence, - evidence=f"Cross-modal alignment: {result.match_type}", + link = EntityLink( + id = str(uuid.uuid4())[:UUID_LENGTH], + project_id = project_id, + source_entity_id = ent1.get("id"), + target_entity_id = result.matched_entity_id, + link_type = "same_as" if result.similarity > 0.95 else "related_to", + source_modality = mod1, + target_modality = mod2, + confidence = result.confidence, + evidence = f"Cross-modal alignment: {result.match_type}", ) links.append(link) @@ -284,7 +284,7 @@ class MultimodalEntityLinker: 融合结果 """ # 收集所有属性 - fused_properties = { + fused_properties = { "names": set(), "definitions": [], "aliases": set(), @@ -293,7 +293,7 @@ class MultimodalEntityLinker: "contexts": [], } - merged_ids = [] + merged_ids = [] for entity in linked_entities: merged_ids.append(entity.get("id")) @@ -318,21 +318,21 @@ class MultimodalEntityLinker: fused_properties["contexts"].append(mention.get("mention_context")) # 选择最佳定义(最长的那个) - best_definition = ( - max(fused_properties["definitions"], key=len) if fused_properties["definitions"] else "" + best_definition = ( + max(fused_properties["definitions"], key = len) if fused_properties["definitions"] else "" ) # 选择最佳名称(最常见的那个) from collections import Counter - name_counts = Counter(fused_properties["names"]) - best_name = name_counts.most_common(1)[0][0] if name_counts else "" + name_counts = Counter(fused_properties["names"]) + best_name = name_counts.most_common(1)[0][0] if name_counts else "" # 构建融合结果 return FusionResult( - canonical_entity_id=entity_id, - merged_entity_ids=merged_ids, - fused_properties={ + canonical_entity_id = entity_id, + merged_entity_ids = merged_ids, + fused_properties = { "name": best_name, "definition": best_definition, "aliases": list(fused_properties["aliases"]), @@ -340,8 +340,8 @@ class MultimodalEntityLinker: "modalities": list(fused_properties["modalities"]), "contexts": fused_properties["contexts"][:10], # 最多10个上下文 }, - source_modalities=list(fused_properties["modalities"]), - confidence=min(1.0, len(linked_entities) * 0.2 + 0.5), + source_modalities = list(fused_properties["modalities"]), + confidence = min(1.0, len(linked_entities) * 0.2 + 0.5), ) def detect_entity_conflicts(self, entities: list[dict]) -> list[dict]: @@ -354,30 +354,30 @@ class MultimodalEntityLinker: Returns: 冲突列表 """ - conflicts = [] + conflicts = [] # 按名称分组 - name_groups = {} + name_groups = {} for entity in entities: - name = entity.get("name", "").lower() + name = entity.get("name", "").lower() if name: if name not in name_groups: - name_groups[name] = [] + name_groups[name] = [] name_groups[name].append(entity) # 检测同名但定义不同的实体 for name, group in name_groups.items(): if len(group) > 1: # 检查定义是否相似 - definitions = [e.get("definition", "") for e in group if e.get("definition")] + definitions = [e.get("definition", "") for e in group if e.get("definition")] if len(definitions) > 1: # 计算定义之间的相似度 - sim_matrix = [] + sim_matrix = [] for i, d1 in enumerate(definitions): for j, d2 in enumerate(definitions): if i < j: - sim = self.calculate_string_similarity(d1, d2) + sim = self.calculate_string_similarity(d1, d2) sim_matrix.append(sim) # 如果定义相似度都很低,可能是冲突 @@ -394,7 +394,7 @@ class MultimodalEntityLinker: return conflicts def suggest_entity_merges( - self, entities: list[dict], existing_links: list[EntityLink] = None + self, entities: list[dict], existing_links: list[EntityLink] = None ) -> list[dict]: """ 建议实体合并 @@ -406,13 +406,13 @@ class MultimodalEntityLinker: Returns: 合并建议列表 """ - suggestions = [] - existing_pairs = set() + suggestions = [] + existing_pairs = set() # 记录已有的关联 if existing_links: for link in existing_links: - pair = tuple(sorted([link.source_entity_id, link.target_entity_id])) + pair = tuple(sorted([link.source_entity_id, link.target_entity_id])) existing_pairs.add(pair) # 检查所有实体对 @@ -422,12 +422,12 @@ class MultimodalEntityLinker: continue # 检查是否已有关联 - pair = tuple(sorted([ent1.get("id"), ent2.get("id")])) + pair = tuple(sorted([ent1.get("id"), ent2.get("id")])) if pair in existing_pairs: continue # 计算相似度 - similarity, match_type = self.calculate_entity_similarity(ent1, ent2) + similarity, match_type = self.calculate_entity_similarity(ent1, ent2) if similarity >= self.similarity_threshold: suggestions.append( @@ -441,7 +441,7 @@ class MultimodalEntityLinker: ) # 按相似度排序 - suggestions.sort(key=lambda x: x["similarity"], reverse=True) + suggestions.sort(key = lambda x: x["similarity"], reverse = True) return suggestions @@ -451,8 +451,8 @@ class MultimodalEntityLinker: entity_id: str, source_type: str, source_id: str, - mention_context: str = "", - confidence: float = 1.0, + mention_context: str = "", + confidence: float = 1.0, ) -> MultimodalEntity: """ 创建多模态实体记录 @@ -469,14 +469,14 @@ class MultimodalEntityLinker: 多模态实体记录 """ return MultimodalEntity( - id=str(uuid.uuid4())[:UUID_LENGTH], - entity_id=entity_id, - project_id=project_id, - name="", # 将在后续填充 - source_type=source_type, - source_id=source_id, - mention_context=mention_context, - confidence=confidence, + id = str(uuid.uuid4())[:UUID_LENGTH], + entity_id = entity_id, + project_id = project_id, + name = "", # 将在后续填充 + source_type = source_type, + source_id = source_id, + mention_context = mention_context, + confidence = confidence, ) def analyze_modality_distribution(self, multimodal_entities: list[MultimodalEntity]) -> dict: @@ -489,7 +489,7 @@ class MultimodalEntityLinker: Returns: 模态分布统计 """ - distribution = {mod: 0 for mod in self.MODALITIES} + distribution = {mod: 0 for mod in self.MODALITIES} # 统计每个模态的实体数 for me in multimodal_entities: @@ -497,13 +497,13 @@ class MultimodalEntityLinker: distribution[me.source_type] += 1 # 统计跨模态实体 - entity_modalities = {} + entity_modalities = {} for me in multimodal_entities: if me.entity_id not in entity_modalities: - entity_modalities[me.entity_id] = set() + entity_modalities[me.entity_id] = set() entity_modalities[me.entity_id].add(me.source_type) - cross_modal_count = sum(1 for mods in entity_modalities.values() if len(mods) > 1) + cross_modal_count = sum(1 for mods in entity_modalities.values() if len(mods) > 1) return { "modality_distribution": distribution, @@ -517,12 +517,12 @@ class MultimodalEntityLinker: # Singleton instance -_multimodal_entity_linker = None +_multimodal_entity_linker = None -def get_multimodal_entity_linker(similarity_threshold: float = 0.85) -> MultimodalEntityLinker: +def get_multimodal_entity_linker(similarity_threshold: float = 0.85) -> MultimodalEntityLinker: """获取多模态实体关联器单例""" global _multimodal_entity_linker if _multimodal_entity_linker is None: - _multimodal_entity_linker = MultimodalEntityLinker(similarity_threshold) + _multimodal_entity_linker = MultimodalEntityLinker(similarity_threshold) return _multimodal_entity_linker diff --git a/backend/multimodal_processor.py b/backend/multimodal_processor.py index d450811..b13b4d9 100644 --- a/backend/multimodal_processor.py +++ b/backend/multimodal_processor.py @@ -13,30 +13,30 @@ from dataclasses import dataclass from pathlib import Path # Constants -UUID_LENGTH = 8 # UUID 截断长度 +UUID_LENGTH = 8 # UUID 截断长度 # 尝试导入OCR库 try: import pytesseract from PIL import Image - PYTESSERACT_AVAILABLE = True + PYTESSERACT_AVAILABLE = True except ImportError: - PYTESSERACT_AVAILABLE = False + PYTESSERACT_AVAILABLE = False try: import cv2 - CV2_AVAILABLE = True + CV2_AVAILABLE = True except ImportError: - CV2_AVAILABLE = False + CV2_AVAILABLE = False try: import ffmpeg - FFMPEG_AVAILABLE = True + FFMPEG_AVAILABLE = True except ImportError: - FFMPEG_AVAILABLE = False + FFMPEG_AVAILABLE = False @dataclass @@ -48,13 +48,13 @@ class VideoFrame: frame_number: int timestamp: float frame_path: str - ocr_text: str = "" - ocr_confidence: float = 0.0 - entities_detected: list[dict] = None + ocr_text: str = "" + ocr_confidence: float = 0.0 + entities_detected: list[dict] = None - def __post_init__(self): + def __post_init__(self) -> None: if self.entities_detected is None: - self.entities_detected = [] + self.entities_detected = [] @dataclass @@ -65,20 +65,20 @@ class VideoInfo: project_id: str filename: str file_path: str - duration: float = 0.0 - width: int = 0 - height: int = 0 - fps: float = 0.0 - audio_extracted: bool = False - audio_path: str = "" - transcript_id: str = "" - status: str = "pending" - error_message: str = "" - metadata: dict = None + duration: float = 0.0 + width: int = 0 + height: int = 0 + fps: float = 0.0 + audio_extracted: bool = False + audio_path: str = "" + transcript_id: str = "" + status: str = "pending" + error_message: str = "" + metadata: dict = None - def __post_init__(self): + def __post_init__(self) -> None: if self.metadata is None: - self.metadata = {} + self.metadata = {} @dataclass @@ -91,13 +91,13 @@ class VideoProcessingResult: ocr_results: list[dict] full_text: str # 整合的文本(音频转录 + OCR文本) success: bool - error_message: str = "" + error_message: str = "" class MultimodalProcessor: """多模态处理器 - 处理视频文件""" - def __init__(self, temp_dir: str = None, frame_interval: int = 5) -> None: + def __init__(self, temp_dir: str = None, frame_interval: int = 5) -> None: """ 初始化多模态处理器 @@ -105,16 +105,16 @@ class MultimodalProcessor: temp_dir: 临时文件目录 frame_interval: 关键帧提取间隔(秒) """ - self.temp_dir = temp_dir or tempfile.gettempdir() - self.frame_interval = frame_interval - self.video_dir = os.path.join(self.temp_dir, "videos") - self.frames_dir = os.path.join(self.temp_dir, "frames") - self.audio_dir = os.path.join(self.temp_dir, "audio") + self.temp_dir = temp_dir or tempfile.gettempdir() + self.frame_interval = frame_interval + self.video_dir = os.path.join(self.temp_dir, "videos") + self.frames_dir = os.path.join(self.temp_dir, "frames") + self.audio_dir = os.path.join(self.temp_dir, "audio") # 创建目录 - os.makedirs(self.video_dir, exist_ok=True) - os.makedirs(self.frames_dir, exist_ok=True) - os.makedirs(self.audio_dir, exist_ok=True) + os.makedirs(self.video_dir, exist_ok = True) + os.makedirs(self.frames_dir, exist_ok = True) + os.makedirs(self.audio_dir, exist_ok = True) def extract_video_info(self, video_path: str) -> dict: """ @@ -128,11 +128,11 @@ class MultimodalProcessor: """ try: if FFMPEG_AVAILABLE: - probe = ffmpeg.probe(video_path) - video_stream = next( + probe = ffmpeg.probe(video_path) + video_stream = next( (s for s in probe["streams"] if s["codec_type"] == "video"), None ) - audio_stream = next( + audio_stream = next( (s for s in probe["streams"] if s["codec_type"] == "audio"), None ) @@ -147,21 +147,21 @@ class MultimodalProcessor: } else: # 使用 ffprobe 命令行 - cmd = [ + cmd = [ "ffprobe", "-v", "error", "-show_entries", - "format=duration,bit_rate", + "format = duration, bit_rate", "-show_entries", - "stream=width,height,r_frame_rate", + "stream = width, height, r_frame_rate", "-of", "json", video_path, ] - result = subprocess.run(cmd, capture_output=True, text=True) + result = subprocess.run(cmd, capture_output = True, text = True) if result.returncode == 0: - data = json.loads(result.stdout) + data = json.loads(result.stdout) return { "duration": float(data["format"].get("duration", 0)), "width": int(data["streams"][0].get("width", 0)) if data["streams"] else 0, @@ -177,7 +177,7 @@ class MultimodalProcessor: return {"duration": 0, "width": 0, "height": 0, "fps": 0, "has_audio": False, "bitrate": 0} - def extract_audio(self, video_path: str, output_path: str = None) -> str: + def extract_audio(self, video_path: str, output_path: str = None) -> str: """ 从视频中提取音频 @@ -189,20 +189,20 @@ class MultimodalProcessor: 提取的音频文件路径 """ if output_path is None: - video_name = Path(video_path).stem - output_path = os.path.join(self.audio_dir, f"{video_name}.wav") + video_name = Path(video_path).stem + output_path = os.path.join(self.audio_dir, f"{video_name}.wav") try: if FFMPEG_AVAILABLE: ( ffmpeg.input(video_path) - .output(output_path, ac=1, ar=16000, vn=None) + .output(output_path, ac = 1, ar = 16000, vn = None) .overwrite_output() - .run(quiet=True) + .run(quiet = True) ) else: # 使用命令行 ffmpeg - cmd = [ + cmd = [ "ffmpeg", "-i", video_path, @@ -216,14 +216,14 @@ class MultimodalProcessor: "-y", output_path, ] - subprocess.run(cmd, check=True, capture_output=True) + subprocess.run(cmd, check = True, capture_output = True) return output_path except Exception as e: print(f"Error extracting audio: {e}") raise - def extract_keyframes(self, video_path: str, video_id: str, interval: int = None) -> list[str]: + def extract_keyframes(self, video_path: str, video_id: str, interval: int = None) -> list[str]: """ 从视频中提取关键帧 @@ -235,31 +235,31 @@ class MultimodalProcessor: Returns: 提取的帧文件路径列表 """ - interval = interval or self.frame_interval - frame_paths = [] + interval = interval or self.frame_interval + frame_paths = [] # 创建帧存储目录 - video_frames_dir = os.path.join(self.frames_dir, video_id) - os.makedirs(video_frames_dir, exist_ok=True) + video_frames_dir = os.path.join(self.frames_dir, video_id) + os.makedirs(video_frames_dir, exist_ok = True) try: if CV2_AVAILABLE: # 使用 OpenCV 提取帧 - cap = cv2.VideoCapture(video_path) - fps = cap.get(cv2.CAP_PROP_FPS) + cap = cv2.VideoCapture(video_path) + fps = cap.get(cv2.CAP_PROP_FPS) int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) - frame_interval_frames = int(fps * interval) - frame_number = 0 + frame_interval_frames = int(fps * interval) + frame_number = 0 while True: - ret, frame = cap.read() + ret, frame = cap.read() if not ret: break if frame_number % frame_interval_frames == 0: - timestamp = frame_number / fps - frame_path = os.path.join( + timestamp = frame_number / fps + frame_path = os.path.join( video_frames_dir, f"frame_{frame_number:06d}_{timestamp:.2f}.jpg" ) cv2.imwrite(frame_path, frame) @@ -271,23 +271,23 @@ class MultimodalProcessor: else: # 使用 ffmpeg 命令行提取帧 Path(video_path).stem - output_pattern = os.path.join(video_frames_dir, "frame_%06d_%t.jpg") + output_pattern = os.path.join(video_frames_dir, "frame_%06d_%t.jpg") - cmd = [ + cmd = [ "ffmpeg", "-i", video_path, "-vf", - f"fps=1/{interval}", + f"fps = 1/{interval}", "-frame_pts", "1", "-y", output_pattern, ] - subprocess.run(cmd, check=True, capture_output=True) + subprocess.run(cmd, check = True, capture_output = True) # 获取生成的帧文件列表 - frame_paths = sorted( + frame_paths = sorted( [ os.path.join(video_frames_dir, f) for f in os.listdir(video_frames_dir) @@ -313,19 +313,19 @@ class MultimodalProcessor: return "", 0.0 try: - image = Image.open(image_path) + image = Image.open(image_path) # 预处理:转换为灰度图 if image.mode != "L": - image = image.convert("L") + image = image.convert("L") # 使用 pytesseract 进行 OCR - text = pytesseract.image_to_string(image, lang="chi_sim+eng") + text = pytesseract.image_to_string(image, lang = "chi_sim+eng") # 获取置信度数据 - data = pytesseract.image_to_data(image, output_type=pytesseract.Output.DICT) - confidences = [int(c) for c in data["conf"] if int(c) > 0] - avg_confidence = sum(confidences) / len(confidences) if confidences else 0 + data = pytesseract.image_to_data(image, output_type = pytesseract.Output.DICT) + confidences = [int(c) for c in data["conf"] if int(c) > 0] + avg_confidence = sum(confidences) / len(confidences) if confidences else 0 return text.strip(), avg_confidence / 100.0 except Exception as e: @@ -333,7 +333,7 @@ class MultimodalProcessor: return "", 0.0 def process_video( - self, video_data: bytes, filename: str, project_id: str, video_id: str = None + self, video_data: bytes, filename: str, project_id: str, video_id: str = None ) -> VideoProcessingResult: """ 处理视频文件:提取音频、关键帧、OCR @@ -347,48 +347,48 @@ class MultimodalProcessor: Returns: 视频处理结果 """ - video_id = video_id or str(uuid.uuid4())[:UUID_LENGTH] + video_id = video_id or str(uuid.uuid4())[:UUID_LENGTH] try: # 保存视频文件 - video_path = os.path.join(self.video_dir, f"{video_id}_{filename}") + video_path = os.path.join(self.video_dir, f"{video_id}_{filename}") with open(video_path, "wb") as f: f.write(video_data) # 提取视频信息 - video_info = self.extract_video_info(video_path) + video_info = self.extract_video_info(video_path) # 提取音频 - audio_path = "" + audio_path = "" if video_info["has_audio"]: - audio_path = self.extract_audio(video_path) + audio_path = self.extract_audio(video_path) # 提取关键帧 - frame_paths = self.extract_keyframes(video_path, video_id) + frame_paths = self.extract_keyframes(video_path, video_id) # 对关键帧进行 OCR - frames = [] - ocr_results = [] - all_ocr_text = [] + frames = [] + ocr_results = [] + all_ocr_text = [] for i, frame_path in enumerate(frame_paths): # 解析帧信息 - frame_name = os.path.basename(frame_path) - parts = frame_name.replace(".jpg", "").split("_") - frame_number = int(parts[1]) if len(parts) > 1 else i - timestamp = float(parts[2]) if len(parts) > 2 else i * self.frame_interval + frame_name = os.path.basename(frame_path) + parts = frame_name.replace(".jpg", "").split("_") + frame_number = int(parts[1]) if len(parts) > 1 else i + timestamp = float(parts[2]) if len(parts) > 2 else i * self.frame_interval # OCR 识别 - ocr_text, confidence = self.perform_ocr(frame_path) + ocr_text, confidence = self.perform_ocr(frame_path) - frame = VideoFrame( - id=str(uuid.uuid4())[:UUID_LENGTH], - video_id=video_id, - frame_number=frame_number, - timestamp=timestamp, - frame_path=frame_path, - ocr_text=ocr_text, - ocr_confidence=confidence, + frame = VideoFrame( + id = str(uuid.uuid4())[:UUID_LENGTH], + video_id = video_id, + frame_number = frame_number, + timestamp = timestamp, + frame_path = frame_path, + ocr_text = ocr_text, + ocr_confidence = confidence, ) frames.append(frame) @@ -404,29 +404,29 @@ class MultimodalProcessor: all_ocr_text.append(ocr_text) # 整合所有 OCR 文本 - full_ocr_text = "\n\n".join(all_ocr_text) + full_ocr_text = "\n\n".join(all_ocr_text) return VideoProcessingResult( - video_id=video_id, - audio_path=audio_path, - frames=frames, - ocr_results=ocr_results, - full_text=full_ocr_text, - success=True, + video_id = video_id, + audio_path = audio_path, + frames = frames, + ocr_results = ocr_results, + full_text = full_ocr_text, + success = True, ) except Exception as e: return VideoProcessingResult( - video_id=video_id, - audio_path="", - frames=[], - ocr_results=[], - full_text="", - success=False, - error_message=str(e), + video_id = video_id, + audio_path = "", + frames = [], + ocr_results = [], + full_text = "", + success = False, + error_message = str(e), ) - def cleanup(self, video_id: str = None) -> None: + def cleanup(self, video_id: str = None) -> None: """ 清理临时文件 @@ -438,7 +438,7 @@ class MultimodalProcessor: if video_id: # 清理特定视频的文件 for dir_path in [self.video_dir, self.frames_dir, self.audio_dir]: - target_dir = ( + target_dir = ( os.path.join(dir_path, video_id) if dir_path == self.frames_dir else dir_path ) if os.path.exists(target_dir): @@ -450,16 +450,16 @@ class MultimodalProcessor: for dir_path in [self.video_dir, self.frames_dir, self.audio_dir]: if os.path.exists(dir_path): shutil.rmtree(dir_path) - os.makedirs(dir_path, exist_ok=True) + os.makedirs(dir_path, exist_ok = True) # Singleton instance -_multimodal_processor = None +_multimodal_processor = None -def get_multimodal_processor(temp_dir: str = None, frame_interval: int = 5) -> MultimodalProcessor: +def get_multimodal_processor(temp_dir: str = None, frame_interval: int = 5) -> MultimodalProcessor: """获取多模态处理器单例""" global _multimodal_processor if _multimodal_processor is None: - _multimodal_processor = MultimodalProcessor(temp_dir, frame_interval) + _multimodal_processor = MultimodalProcessor(temp_dir, frame_interval) return _multimodal_processor diff --git a/backend/neo4j_manager.py b/backend/neo4j_manager.py index 45b8c57..229175d 100644 --- a/backend/neo4j_manager.py +++ b/backend/neo4j_manager.py @@ -10,20 +10,20 @@ import logging import os from dataclasses import dataclass -logger = logging.getLogger(__name__) +logger = logging.getLogger(__name__) # Neo4j 连接配置 -NEO4J_URI = os.getenv("NEO4J_URI", "bolt://localhost:7687") -NEO4J_USER = os.getenv("NEO4J_USER", "neo4j") -NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD", "password") +NEO4J_URI = os.getenv("NEO4J_URI", "bolt://localhost:7687") +NEO4J_USER = os.getenv("NEO4J_USER", "neo4j") +NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD", "password") # 延迟导入,避免未安装时出错 try: from neo4j import Driver, GraphDatabase - NEO4J_AVAILABLE = True + NEO4J_AVAILABLE = True except ImportError: - NEO4J_AVAILABLE = False + NEO4J_AVAILABLE = False logger.warning("Neo4j driver not installed. Neo4j features will be disabled.") @@ -35,15 +35,15 @@ class GraphEntity: project_id: str name: str type: str - definition: str = "" - aliases: list[str] = None - properties: dict = None + definition: str = "" + aliases: list[str] = None + properties: dict = None - def __post_init__(self): + def __post_init__(self) -> None: if self.aliases is None: - self.aliases = [] + self.aliases = [] if self.properties is None: - self.properties = {} + self.properties = {} @dataclass @@ -54,12 +54,12 @@ class GraphRelation: source_id: str target_id: str relation_type: str - evidence: str = "" - properties: dict = None + evidence: str = "" + properties: dict = None - def __post_init__(self): + def __post_init__(self) -> None: if self.properties is None: - self.properties = {} + self.properties = {} @dataclass @@ -69,7 +69,7 @@ class PathResult: nodes: list[dict] relationships: list[dict] length: int - total_weight: float = 0.0 + total_weight: float = 0.0 @dataclass @@ -79,7 +79,7 @@ class CommunityResult: community_id: int nodes: list[dict] size: int - density: float = 0.0 + density: float = 0.0 @dataclass @@ -89,17 +89,17 @@ class CentralityResult: entity_id: str entity_name: str score: float - rank: int = 0 + rank: int = 0 class Neo4jManager: """Neo4j 图数据库管理器""" - def __init__(self, uri: str = None, user: str = None, password: str = None): - self.uri = uri or NEO4J_URI - self.user = user or NEO4J_USER - self.password = password or NEO4J_PASSWORD - self._driver: Driver | None = None + def __init__(self, uri: str = None, user: str = None, password: str = None) -> None: + self.uri = uri or NEO4J_URI + self.user = user or NEO4J_USER + self.password = password or NEO4J_PASSWORD + self._driver: Driver | None = None if not NEO4J_AVAILABLE: logger.error("Neo4j driver not available. Please install: pip install neo4j") @@ -113,13 +113,13 @@ class Neo4jManager: return try: - self._driver = GraphDatabase.driver(self.uri, auth=(self.user, self.password)) + self._driver = GraphDatabase.driver(self.uri, auth = (self.user, self.password)) # 验证连接 self._driver.verify_connectivity() logger.info(f"Connected to Neo4j at {self.uri}") except (RuntimeError, ValueError, TypeError) as e: logger.error(f"Failed to connect to Neo4j: {e}") - self._driver = None + self._driver = None def close(self) -> None: """关闭连接""" @@ -179,7 +179,7 @@ class Neo4jManager: # ==================== 数据同步 ==================== def sync_project( - self, project_id: str, project_name: str, project_description: str = "" + self, project_id: str, project_name: str, project_description: str = "" ) -> None: """同步项目节点到 Neo4j""" if not self._driver: @@ -189,13 +189,13 @@ class Neo4jManager: session.run( """ MERGE (p:Project {id: $project_id}) - SET p.name = $name, - p.description = $description, - p.updated_at = datetime() + SET p.name = $name, + p.description = $description, + p.updated_at = datetime() """, - project_id=project_id, - name=project_name, - description=project_description, + project_id = project_id, + name = project_name, + description = project_description, ) def sync_entity(self, entity: GraphEntity) -> None: @@ -208,23 +208,23 @@ class Neo4jManager: session.run( """ MERGE (e:Entity {id: $id}) - SET e.name = $name, - e.type = $type, - e.definition = $definition, - e.aliases = $aliases, - e.properties = $properties, - e.updated_at = datetime() + SET e.name = $name, + e.type = $type, + e.definition = $definition, + e.aliases = $aliases, + e.properties = $properties, + e.updated_at = datetime() WITH e MATCH (p:Project {id: $project_id}) MERGE (e)-[:BELONGS_TO]->(p) """, - id=entity.id, - project_id=entity.project_id, - name=entity.name, - type=entity.type, - definition=entity.definition, - aliases=json.dumps(entity.aliases), - properties=json.dumps(entity.properties), + id = entity.id, + project_id = entity.project_id, + name = entity.name, + type = entity.type, + definition = entity.definition, + aliases = json.dumps(entity.aliases), + properties = json.dumps(entity.properties), ) def sync_entities_batch(self, entities: list[GraphEntity]) -> None: @@ -234,7 +234,7 @@ class Neo4jManager: with self._driver.session() as session: # 使用 UNWIND 批量处理 - entities_data = [ + entities_data = [ { "id": e.id, "project_id": e.project_id, @@ -251,17 +251,17 @@ class Neo4jManager: """ UNWIND $entities AS entity MERGE (e:Entity {id: entity.id}) - SET e.name = entity.name, - e.type = entity.type, - e.definition = entity.definition, - e.aliases = entity.aliases, - e.properties = entity.properties, - e.updated_at = datetime() + SET e.name = entity.name, + e.type = entity.type, + e.definition = entity.definition, + e.aliases = entity.aliases, + e.properties = entity.properties, + e.updated_at = datetime() WITH e, entity MATCH (p:Project {id: entity.project_id}) MERGE (e)-[:BELONGS_TO]->(p) """, - entities=entities_data, + entities = entities_data, ) def sync_relation(self, relation: GraphRelation) -> None: @@ -275,17 +275,17 @@ class Neo4jManager: MATCH (source:Entity {id: $source_id}) MATCH (target:Entity {id: $target_id}) MERGE (source)-[r:RELATES_TO {id: $id}]->(target) - SET r.relation_type = $relation_type, - r.evidence = $evidence, - r.properties = $properties, - r.updated_at = datetime() + SET r.relation_type = $relation_type, + r.evidence = $evidence, + r.properties = $properties, + r.updated_at = datetime() """, - id=relation.id, - source_id=relation.source_id, - target_id=relation.target_id, - relation_type=relation.relation_type, - evidence=relation.evidence, - properties=json.dumps(relation.properties), + id = relation.id, + source_id = relation.source_id, + target_id = relation.target_id, + relation_type = relation.relation_type, + evidence = relation.evidence, + properties = json.dumps(relation.properties), ) def sync_relations_batch(self, relations: list[GraphRelation]) -> None: @@ -294,7 +294,7 @@ class Neo4jManager: return with self._driver.session() as session: - relations_data = [ + relations_data = [ { "id": r.id, "source_id": r.source_id, @@ -312,12 +312,12 @@ class Neo4jManager: MATCH (source:Entity {id: rel.source_id}) MATCH (target:Entity {id: rel.target_id}) MERGE (source)-[r:RELATES_TO {id: rel.id}]->(target) - SET r.relation_type = rel.relation_type, - r.evidence = rel.evidence, - r.properties = rel.properties, - r.updated_at = datetime() + SET r.relation_type = rel.relation_type, + r.evidence = rel.evidence, + r.properties = rel.properties, + r.updated_at = datetime() """, - relations=relations_data, + relations = relations_data, ) def delete_entity(self, entity_id: str) -> None: @@ -331,7 +331,7 @@ class Neo4jManager: MATCH (e:Entity {id: $id}) DETACH DELETE e """, - id=entity_id, + id = entity_id, ) def delete_project(self, project_id: str) -> None: @@ -346,13 +346,13 @@ class Neo4jManager: OPTIONAL MATCH (e:Entity)-[:BELONGS_TO]->(p) DETACH DELETE e, p """, - id=project_id, + id = project_id, ) # ==================== 复杂图查询 ==================== def find_shortest_path( - self, source_id: str, target_id: str, max_depth: int = 10 + self, source_id: str, target_id: str, max_depth: int = 10 ) -> PathResult | None: """ 查找两个实体之间的最短路径 @@ -369,31 +369,31 @@ class Neo4jManager: return None with self._driver.session() as session: - result = session.run( + result = session.run( """ - MATCH path = shortestPath( + MATCH path = shortestPath( (source:Entity {id: $source_id})-[*1..$max_depth]-(target:Entity {id: $target_id}) ) RETURN path """, - source_id=source_id, - target_id=target_id, - max_depth=max_depth, + source_id = source_id, + target_id = target_id, + max_depth = max_depth, ) - record = result.single() + record = result.single() if not record: return None - path = record["path"] + path = record["path"] # 提取节点和关系 - nodes = [ + nodes = [ {"id": node["id"], "name": node["name"], "type": node["type"]} for node in path.nodes ] - relationships = [ + relationships = [ { "source": rel.start_node["id"], "target": rel.end_node["id"], @@ -404,11 +404,11 @@ class Neo4jManager: ] return PathResult( - nodes=nodes, relationships=relationships, length=len(path.relationships) + nodes = nodes, relationships = relationships, length = len(path.relationships) ) def find_all_paths( - self, source_id: str, target_id: str, max_depth: int = 5, limit: int = 10 + self, source_id: str, target_id: str, max_depth: int = 5, limit: int = 10 ) -> list[PathResult]: """ 查找两个实体之间的所有路径 @@ -426,29 +426,29 @@ class Neo4jManager: return [] with self._driver.session() as session: - result = session.run( + result = session.run( """ - MATCH path = (source:Entity {id: $source_id})-[*1..$max_depth]-(target:Entity {id: $target_id}) + MATCH path = (source:Entity {id: $source_id})-[*1..$max_depth]-(target:Entity {id: $target_id}) WHERE source <> target RETURN path LIMIT $limit """, - source_id=source_id, - target_id=target_id, - max_depth=max_depth, - limit=limit, + source_id = source_id, + target_id = target_id, + max_depth = max_depth, + limit = limit, ) - paths = [] + paths = [] for record in result: - path = record["path"] + path = record["path"] - nodes = [ + nodes = [ {"id": node["id"], "name": node["name"], "type": node["type"]} for node in path.nodes ] - relationships = [ + relationships = [ { "source": rel.start_node["id"], "target": rel.end_node["id"], @@ -460,14 +460,14 @@ class Neo4jManager: paths.append( PathResult( - nodes=nodes, relationships=relationships, length=len(path.relationships) + nodes = nodes, relationships = relationships, length = len(path.relationships) ) ) return paths def find_neighbors( - self, entity_id: str, relation_type: str = None, limit: int = 50 + self, entity_id: str, relation_type: str = None, limit: int = 50 ) -> list[dict]: """ 查找实体的邻居节点 @@ -485,30 +485,30 @@ class Neo4jManager: with self._driver.session() as session: if relation_type: - result = session.run( + result = session.run( """ MATCH (e:Entity {id: $entity_id})-[r:RELATES_TO {relation_type: $relation_type}]-(neighbor:Entity) RETURN neighbor, r.relation_type as rel_type, r.evidence as evidence LIMIT $limit """, - entity_id=entity_id, - relation_type=relation_type, - limit=limit, + entity_id = entity_id, + relation_type = relation_type, + limit = limit, ) else: - result = session.run( + result = session.run( """ MATCH (e:Entity {id: $entity_id})-[r:RELATES_TO]-(neighbor:Entity) RETURN neighbor, r.relation_type as rel_type, r.evidence as evidence LIMIT $limit """, - entity_id=entity_id, - limit=limit, + entity_id = entity_id, + limit = limit, ) - neighbors = [] + neighbors = [] for record in result: - node = record["neighbor"] + node = record["neighbor"] neighbors.append( { "id": node["id"], @@ -536,13 +536,13 @@ class Neo4jManager: return [] with self._driver.session() as session: - result = session.run( + result = session.run( """ MATCH (e1:Entity {id: $id1})-[:RELATES_TO]-(common:Entity)-[:RELATES_TO]-(e2:Entity {id: $id2}) RETURN DISTINCT common """, - id1=entity_id1, - id2=entity_id2, + id1 = entity_id1, + id2 = entity_id2, ) return [ @@ -556,7 +556,7 @@ class Neo4jManager: # ==================== 图算法分析 ==================== - def calculate_pagerank(self, project_id: str, top_n: int = 20) -> list[CentralityResult]: + def calculate_pagerank(self, project_id: str, top_n: int = 20) -> list[CentralityResult]: """ 计算 PageRank 中心性 @@ -571,7 +571,7 @@ class Neo4jManager: return [] with self._driver.session() as session: - result = session.run( + result = session.run( """ CALL gds.graph.exists('project-graph-$project_id') YIELD exists WITH exists @@ -581,7 +581,7 @@ class Neo4jManager: {} ) YIELD value RETURN value """, - project_id=project_id, + project_id = project_id, ) # 创建临时图 @@ -601,11 +601,11 @@ class Neo4jManager: } ) """, - project_id=project_id, + project_id = project_id, ) # 运行 PageRank - result = session.run( + result = session.run( """ CALL gds.pageRank.stream('project-graph-$project_id') YIELD nodeId, score @@ -615,19 +615,19 @@ class Neo4jManager: ORDER BY score DESC LIMIT $top_n """, - project_id=project_id, - top_n=top_n, + project_id = project_id, + top_n = top_n, ) - rankings = [] - rank = 1 + rankings = [] + rank = 1 for record in result: rankings.append( CentralityResult( - entity_id=record["entity_id"], - entity_name=record["entity_name"], - score=record["score"], - rank=rank, + entity_id = record["entity_id"], + entity_name = record["entity_name"], + score = record["score"], + rank = rank, ) ) rank += 1 @@ -637,12 +637,12 @@ class Neo4jManager: """ CALL gds.graph.drop('project-graph-$project_id') """, - project_id=project_id, + project_id = project_id, ) return rankings - def calculate_betweenness(self, project_id: str, top_n: int = 20) -> list[CentralityResult]: + def calculate_betweenness(self, project_id: str, top_n: int = 20) -> list[CentralityResult]: """ 计算 Betweenness 中心性(桥梁作用) @@ -658,7 +658,7 @@ class Neo4jManager: with self._driver.session() as session: # 使用 APOC 的 betweenness 计算(如果没有 GDS) - result = session.run( + result = session.run( """ MATCH (e:Entity)-[:BELONGS_TO]->(p:Project {id: $project_id}) OPTIONAL MATCH (e)-[:RELATES_TO]-(other:Entity) @@ -667,19 +667,19 @@ class Neo4jManager: LIMIT $top_n RETURN e.id as entity_id, e.name as entity_name, degree as score """, - project_id=project_id, - top_n=top_n, + project_id = project_id, + top_n = top_n, ) - rankings = [] - rank = 1 + rankings = [] + rank = 1 for record in result: rankings.append( CentralityResult( - entity_id=record["entity_id"], - entity_name=record["entity_name"], - score=float(record["score"]), - rank=rank, + entity_id = record["entity_id"], + entity_name = record["entity_name"], + score = float(record["score"]), + rank = rank, ) ) rank += 1 @@ -701,7 +701,7 @@ class Neo4jManager: with self._driver.session() as session: # 简单的社区检测:基于连通分量 - result = session.run( + result = session.run( """ MATCH (e:Entity)-[:BELONGS_TO]->(p:Project {id: $project_id}) OPTIONAL MATCH (e)-[:RELATES_TO]-(other:Entity)-[:BELONGS_TO]->(p) @@ -710,25 +710,25 @@ class Neo4jManager: connections, size(connections) as connection_count ORDER BY connection_count DESC """, - project_id=project_id, + project_id = project_id, ) # 手动分组(基于连通性) - communities = {} + communities = {} for record in result: - entity_id = record["entity_id"] - connections = record["connections"] + entity_id = record["entity_id"] + connections = record["connections"] # 找到所属的社区 - found_community = None + found_community = None for comm_id, comm_data in communities.items(): if any(conn in comm_data["member_ids"] for conn in connections): - found_community = comm_id + found_community = comm_id break if found_community is None: - found_community = len(communities) - communities[found_community] = {"member_ids": set(), "nodes": []} + found_community = len(communities) + communities[found_community] = {"member_ids": set(), "nodes": []} communities[found_community]["member_ids"].add(entity_id) communities[found_community]["nodes"].append( @@ -741,27 +741,27 @@ class Neo4jManager: ) # 构建结果 - results = [] + results = [] for comm_id, comm_data in communities.items(): - nodes = comm_data["nodes"] - size = len(nodes) + nodes = comm_data["nodes"] + size = len(nodes) # 计算密度(简化版) - max_edges = size * (size - 1) / 2 if size > 1 else 1 - actual_edges = sum(n["connections"] for n in nodes) / 2 - density = actual_edges / max_edges if max_edges > 0 else 0 + max_edges = size * (size - 1) / 2 if size > 1 else 1 + actual_edges = sum(n["connections"] for n in nodes) / 2 + density = actual_edges / max_edges if max_edges > 0 else 0 results.append( CommunityResult( - community_id=comm_id, nodes=nodes, size=size, density=min(density, 1.0) + community_id = comm_id, nodes = nodes, size = size, density = min(density, 1.0) ) ) # 按大小排序 - results.sort(key=lambda x: x.size, reverse=True) + results.sort(key = lambda x: x.size, reverse = True) return results def find_central_entities( - self, project_id: str, metric: str = "degree" + self, project_id: str, metric: str = "degree" ) -> list[CentralityResult]: """ 查找中心实体 @@ -778,7 +778,7 @@ class Neo4jManager: with self._driver.session() as session: if metric == "degree": - result = session.run( + result = session.run( """ MATCH (e:Entity)-[:BELONGS_TO]->(p:Project {id: $project_id}) OPTIONAL MATCH (e)-[:RELATES_TO]-(other:Entity) @@ -787,11 +787,11 @@ class Neo4jManager: ORDER BY degree DESC LIMIT 20 """, - project_id=project_id, + project_id = project_id, ) else: # 默认使用度中心性 - result = session.run( + result = session.run( """ MATCH (e:Entity)-[:BELONGS_TO]->(p:Project {id: $project_id}) OPTIONAL MATCH (e)-[:RELATES_TO]-(other:Entity) @@ -800,18 +800,18 @@ class Neo4jManager: ORDER BY degree DESC LIMIT 20 """, - project_id=project_id, + project_id = project_id, ) - rankings = [] - rank = 1 + rankings = [] + rank = 1 for record in result: rankings.append( CentralityResult( - entity_id=record["entity_id"], - entity_name=record["entity_name"], - score=float(record["score"]), - rank=rank, + entity_id = record["entity_id"], + entity_name = record["entity_name"], + score = float(record["score"]), + rank = rank, ) ) rank += 1 @@ -835,49 +835,49 @@ class Neo4jManager: with self._driver.session() as session: # 实体数量 - entity_count = session.run( + entity_count = session.run( """ MATCH (e:Entity)-[:BELONGS_TO]->(p:Project {id: $project_id}) RETURN count(e) as count """, - project_id=project_id, + project_id = project_id, ).single()["count"] # 关系数量 - relation_count = session.run( + relation_count = session.run( """ MATCH (e:Entity)-[:BELONGS_TO]->(p:Project {id: $project_id}) MATCH (e)-[r:RELATES_TO]-() RETURN count(r) as count """, - project_id=project_id, + project_id = project_id, ).single()["count"] # 实体类型分布 - type_distribution = session.run( + type_distribution = session.run( """ MATCH (e:Entity)-[:BELONGS_TO]->(p:Project {id: $project_id}) RETURN e.type as type, count(e) as count ORDER BY count DESC """, - project_id=project_id, + project_id = project_id, ) - types = {record["type"]: record["count"] for record in type_distribution} + types = {record["type"]: record["count"] for record in type_distribution} # 平均度 - avg_degree = session.run( + avg_degree = session.run( """ MATCH (e:Entity)-[:BELONGS_TO]->(p:Project {id: $project_id}) OPTIONAL MATCH (e)-[:RELATES_TO]-(other) WITH e, count(other) as degree RETURN avg(degree) as avg_degree """, - project_id=project_id, + project_id = project_id, ).single()["avg_degree"] # 关系类型分布 - rel_types = session.run( + rel_types = session.run( """ MATCH (e:Entity)-[:BELONGS_TO]->(p:Project {id: $project_id}) MATCH (e)-[r:RELATES_TO]-() @@ -885,10 +885,10 @@ class Neo4jManager: ORDER BY count DESC LIMIT 10 """, - project_id=project_id, + project_id = project_id, ) - relation_types = {record["type"]: record["count"] for record in rel_types} + relation_types = {record["type"]: record["count"] for record in rel_types} return { "entity_count": entity_count, @@ -901,7 +901,7 @@ class Neo4jManager: else 0, } - def get_subgraph(self, entity_ids: list[str], depth: int = 1) -> dict: + def get_subgraph(self, entity_ids: list[str], depth: int = 1) -> dict: """ 获取指定实体的子图 @@ -916,7 +916,7 @@ class Neo4jManager: return {"nodes": [], "relationships": []} with self._driver.session() as session: - result = session.run( + result = session.run( """ MATCH (e:Entity) WHERE e.id IN $entity_ids @@ -927,14 +927,14 @@ class Neo4jManager: }) YIELD node RETURN DISTINCT node """, - entity_ids=entity_ids, - depth=depth, + entity_ids = entity_ids, + depth = depth, ) - nodes = [] - node_ids = set() + nodes = [] + node_ids = set() for record in result: - node = record["node"] + node = record["node"] node_ids.add(node["id"]) nodes.append( { @@ -946,17 +946,17 @@ class Neo4jManager: ) # 获取这些节点之间的关系 - result = session.run( + result = session.run( """ MATCH (source:Entity)-[r:RELATES_TO]->(target:Entity) WHERE source.id IN $node_ids AND target.id IN $node_ids RETURN source.id as source_id, target.id as target_id, r.relation_type as type, r.evidence as evidence """, - node_ids=list(node_ids), + node_ids = list(node_ids), ) - relationships = [ + relationships = [ { "source": record["source_id"], "target": record["target_id"], @@ -970,14 +970,14 @@ class Neo4jManager: # 全局单例 -_neo4j_manager = None +_neo4j_manager = None def get_neo4j_manager() -> Neo4jManager: """获取 Neo4j 管理器单例""" global _neo4j_manager if _neo4j_manager is None: - _neo4j_manager = Neo4jManager() + _neo4j_manager = Neo4jManager() return _neo4j_manager @@ -986,7 +986,7 @@ def close_neo4j_manager() -> None: global _neo4j_manager if _neo4j_manager: _neo4j_manager.close() - _neo4j_manager = None + _neo4j_manager = None # 便捷函数 @@ -1004,7 +1004,7 @@ def sync_project_to_neo4j( entities: 实体列表(字典格式) relations: 关系列表(字典格式) """ - manager = get_neo4j_manager() + manager = get_neo4j_manager() if not manager.is_connected(): logger.warning("Neo4j not connected, skipping sync") return @@ -1013,29 +1013,29 @@ def sync_project_to_neo4j( manager.sync_project(project_id, project_name) # 同步实体 - graph_entities = [ + graph_entities = [ GraphEntity( - id=e["id"], - project_id=project_id, - name=e["name"], - type=e.get("type", "unknown"), - definition=e.get("definition", ""), - aliases=e.get("aliases", []), - properties=e.get("properties", {}), + id = e["id"], + project_id = project_id, + name = e["name"], + type = e.get("type", "unknown"), + definition = e.get("definition", ""), + aliases = e.get("aliases", []), + properties = e.get("properties", {}), ) for e in entities ] manager.sync_entities_batch(graph_entities) # 同步关系 - graph_relations = [ + graph_relations = [ GraphRelation( - id=r["id"], - source_id=r["source_entity_id"], - target_id=r["target_entity_id"], - relation_type=r["relation_type"], - evidence=r.get("evidence", ""), - properties=r.get("properties", {}), + id = r["id"], + source_id = r["source_entity_id"], + target_id = r["target_entity_id"], + relation_type = r["relation_type"], + evidence = r.get("evidence", ""), + properties = r.get("properties", {}), ) for r in relations ] @@ -1048,9 +1048,9 @@ def sync_project_to_neo4j( if __name__ == "__main__": # 测试代码 - logging.basicConfig(level=logging.INFO) + logging.basicConfig(level = logging.INFO) - manager = Neo4jManager() + manager = Neo4jManager() if manager.is_connected(): print("✅ Connected to Neo4j") @@ -1064,18 +1064,18 @@ if __name__ == "__main__": print("✅ Project synced") # 测试实体 - test_entity = GraphEntity( - id="test-entity-1", - project_id="test-project", - name="Test Entity", - type="Person", - definition="A test entity", + test_entity = GraphEntity( + id = "test-entity-1", + project_id = "test-project", + name = "Test Entity", + type = "Person", + definition = "A test entity", ) manager.sync_entity(test_entity) print("✅ Entity synced") # 获取统计 - stats = manager.get_graph_stats("test-project") + stats = manager.get_graph_stats("test-project") print(f"📊 Graph stats: {stats}") else: diff --git a/backend/ops_manager.py b/backend/ops_manager.py index 5a2ede9..ba80dfc 100644 --- a/backend/ops_manager.py +++ b/backend/ops_manager.py @@ -27,87 +27,87 @@ from enum import StrEnum import httpx # Database path -DB_PATH = os.path.join(os.path.dirname(__file__), "insightflow.db") +DB_PATH = os.path.join(os.path.dirname(__file__), "insightflow.db") class AlertSeverity(StrEnum): """告警严重级别 P0-P3""" - P0 = "p0" # 紧急 - 系统不可用,需要立即处理 - P1 = "p1" # 严重 - 核心功能受损,需要1小时内处理 - P2 = "p2" # 一般 - 部分功能受影响,需要4小时内处理 - P3 = "p3" # 轻微 - 非核心功能问题,24小时内处理 + P0 = "p0" # 紧急 - 系统不可用,需要立即处理 + P1 = "p1" # 严重 - 核心功能受损,需要1小时内处理 + P2 = "p2" # 一般 - 部分功能受影响,需要4小时内处理 + P3 = "p3" # 轻微 - 非核心功能问题,24小时内处理 class AlertStatus(StrEnum): """告警状态""" - FIRING = "firing" # 正在告警 - RESOLVED = "resolved" # 已恢复 - ACKNOWLEDGED = "acknowledged" # 已确认 - SUPPRESSED = "suppressed" # 已抑制 + FIRING = "firing" # 正在告警 + RESOLVED = "resolved" # 已恢复 + ACKNOWLEDGED = "acknowledged" # 已确认 + SUPPRESSED = "suppressed" # 已抑制 class AlertChannelType(StrEnum): """告警渠道类型""" - PAGERDUTY = "pagerduty" - OPSGENIE = "opsgenie" - FEISHU = "feishu" - DINGTALK = "dingtalk" - SLACK = "slack" - EMAIL = "email" - SMS = "sms" - WEBHOOK = "webhook" + PAGERDUTY = "pagerduty" + OPSGENIE = "opsgenie" + FEISHU = "feishu" + DINGTALK = "dingtalk" + SLACK = "slack" + EMAIL = "email" + SMS = "sms" + WEBHOOK = "webhook" class AlertRuleType(StrEnum): """告警规则类型""" - THRESHOLD = "threshold" # 阈值告警 - ANOMALY = "anomaly" # 异常检测 - PREDICTIVE = "predictive" # 预测性告警 - COMPOSITE = "composite" # 复合告警 + THRESHOLD = "threshold" # 阈值告警 + ANOMALY = "anomaly" # 异常检测 + PREDICTIVE = "predictive" # 预测性告警 + COMPOSITE = "composite" # 复合告警 class ResourceType(StrEnum): """资源类型""" - CPU = "cpu" - MEMORY = "memory" - DISK = "disk" - NETWORK = "network" - GPU = "gpu" - DATABASE = "database" - CACHE = "cache" - QUEUE = "queue" + CPU = "cpu" + MEMORY = "memory" + DISK = "disk" + NETWORK = "network" + GPU = "gpu" + DATABASE = "database" + CACHE = "cache" + QUEUE = "queue" class ScalingAction(StrEnum): """扩缩容动作""" - SCALE_UP = "scale_up" # 扩容 - SCALE_DOWN = "scale_down" # 缩容 - MAINTAIN = "maintain" # 保持 + SCALE_UP = "scale_up" # 扩容 + SCALE_DOWN = "scale_down" # 缩容 + MAINTAIN = "maintain" # 保持 class HealthStatus(StrEnum): """健康状态""" - HEALTHY = "healthy" - DEGRADED = "degraded" - UNHEALTHY = "unhealthy" - UNKNOWN = "unknown" + HEALTHY = "healthy" + DEGRADED = "degraded" + UNHEALTHY = "unhealthy" + UNKNOWN = "unknown" class BackupStatus(StrEnum): """备份状态""" - PENDING = "pending" - IN_PROGRESS = "in_progress" - COMPLETED = "completed" - FAILED = "failed" - VERIFIED = "verified" + PENDING = "pending" + IN_PROGRESS = "in_progress" + COMPLETED = "completed" + FAILED = "failed" + VERIFIED = "verified" @dataclass @@ -121,7 +121,7 @@ class AlertRule: rule_type: AlertRuleType severity: AlertSeverity metric: str # 监控指标 - condition: str # 条件: >, <, ==, >=, <=, != + condition: str # 条件: >, <, ==, >= , <= , != threshold: float duration: int # 持续时间(秒) evaluation_interval: int # 评估间隔(秒) @@ -449,24 +449,24 @@ class CostOptimizationSuggestion: class OpsManager: """运维与监控管理主类""" - def __init__(self, db_path: str = DB_PATH): - self.db_path = db_path - self._alert_evaluators: dict[str, Callable] = {} - self._running = False - self._evaluator_thread = None + def __init__(self, db_path: str = DB_PATH) -> None: + self.db_path = db_path + self._alert_evaluators: dict[str, Callable] = {} + self._running = False + self._evaluator_thread = None self._register_default_evaluators() def _get_db(self) -> None: """获取数据库连接""" - conn = sqlite3.connect(self.db_path) - conn.row_factory = sqlite3.Row + conn = sqlite3.connect(self.db_path) + conn.row_factory = sqlite3.Row return conn def _register_default_evaluators(self) -> None: """注册默认的告警评估器""" - self._alert_evaluators[AlertRuleType.THRESHOLD.value] = self._evaluate_threshold_rule - self._alert_evaluators[AlertRuleType.ANOMALY.value] = self._evaluate_anomaly_rule - self._alert_evaluators[AlertRuleType.PREDICTIVE.value] = self._evaluate_predictive_rule + self._alert_evaluators[AlertRuleType.THRESHOLD.value] = self._evaluate_threshold_rule + self._alert_evaluators[AlertRuleType.ANOMALY.value] = self._evaluate_anomaly_rule + self._alert_evaluators[AlertRuleType.PREDICTIVE.value] = self._evaluate_predictive_rule # ==================== 告警规则管理 ==================== @@ -488,28 +488,28 @@ class OpsManager: created_by: str, ) -> AlertRule: """创建告警规则""" - rule_id = f"ar_{uuid.uuid4().hex[:16]}" - now = datetime.now().isoformat() + rule_id = f"ar_{uuid.uuid4().hex[:16]}" + now = datetime.now().isoformat() - rule = AlertRule( - id=rule_id, - tenant_id=tenant_id, - name=name, - description=description, - rule_type=rule_type, - severity=severity, - metric=metric, - condition=condition, - threshold=threshold, - duration=duration, - evaluation_interval=evaluation_interval, - channels=channels, - labels=labels or {}, - annotations=annotations or {}, - is_enabled=True, - created_at=now, - updated_at=now, - created_by=created_by, + rule = AlertRule( + id = rule_id, + tenant_id = tenant_id, + name = name, + description = description, + rule_type = rule_type, + severity = severity, + metric = metric, + condition = condition, + threshold = threshold, + duration = duration, + evaluation_interval = evaluation_interval, + channels = channels, + labels = labels or {}, + annotations = annotations or {}, + is_enabled = True, + created_at = now, + updated_at = now, + created_by = created_by, ) with self._get_db() as conn: @@ -549,16 +549,16 @@ class OpsManager: def get_alert_rule(self, rule_id: str) -> AlertRule | None: """获取告警规则""" with self._get_db() as conn: - row = conn.execute("SELECT * FROM alert_rules WHERE id = ?", (rule_id,)).fetchone() + row = conn.execute("SELECT * FROM alert_rules WHERE id = ?", (rule_id, )).fetchone() if row: return self._row_to_alert_rule(row) return None - def list_alert_rules(self, tenant_id: str, is_enabled: bool | None = None) -> list[AlertRule]: + def list_alert_rules(self, tenant_id: str, is_enabled: bool | None = None) -> list[AlertRule]: """列出租户的所有告警规则""" - query = "SELECT * FROM alert_rules WHERE tenant_id = ?" - params = [tenant_id] + query = "SELECT * FROM alert_rules WHERE tenant_id = ?" + params = [tenant_id] if is_enabled is not None: query += " AND is_enabled = ?" @@ -567,12 +567,12 @@ class OpsManager: query += " ORDER BY created_at DESC" with self._get_db() as conn: - rows = conn.execute(query, params).fetchall() + rows = conn.execute(query, params).fetchall() return [self._row_to_alert_rule(row) for row in rows] def update_alert_rule(self, rule_id: str, **kwargs) -> AlertRule | None: """更新告警规则""" - allowed_fields = [ + allowed_fields = [ "name", "description", "severity", @@ -587,26 +587,26 @@ class OpsManager: "is_enabled", ] - updates = {k: v for k, v in kwargs.items() if k in allowed_fields} + updates = {k: v for k, v in kwargs.items() if k in allowed_fields} if not updates: return self.get_alert_rule(rule_id) # 处理列表和字典字段 if "channels" in updates: - updates["channels"] = json.dumps(updates["channels"]) + updates["channels"] = json.dumps(updates["channels"]) if "labels" in updates: - updates["labels"] = json.dumps(updates["labels"]) + updates["labels"] = json.dumps(updates["labels"]) if "annotations" in updates: - updates["annotations"] = json.dumps(updates["annotations"]) + updates["annotations"] = json.dumps(updates["annotations"]) if "severity" in updates and isinstance(updates["severity"], AlertSeverity): - updates["severity"] = updates["severity"].value + updates["severity"] = updates["severity"].value - updates["updated_at"] = datetime.now().isoformat() + updates["updated_at"] = datetime.now().isoformat() with self._get_db() as conn: - set_clause = ", ".join([f"{k} = ?" for k in updates.keys()]) + set_clause = ", ".join([f"{k} = ?" for k in updates.keys()]) conn.execute( - f"UPDATE alert_rules SET {set_clause} WHERE id = ?", + f"UPDATE alert_rules SET {set_clause} WHERE id = ?", list(updates.values()) + [rule_id], ) conn.commit() @@ -616,7 +616,7 @@ class OpsManager: def delete_alert_rule(self, rule_id: str) -> bool: """删除告警规则""" with self._get_db() as conn: - conn.execute("DELETE FROM alert_rules WHERE id = ?", (rule_id,)) + conn.execute("DELETE FROM alert_rules WHERE id = ?", (rule_id, )) conn.commit() return conn.total_changes > 0 @@ -628,25 +628,25 @@ class OpsManager: name: str, channel_type: AlertChannelType, config: dict, - severity_filter: list[str] = None, + severity_filter: list[str] = None, ) -> AlertChannel: """创建告警渠道""" - channel_id = f"ac_{uuid.uuid4().hex[:16]}" - now = datetime.now().isoformat() + channel_id = f"ac_{uuid.uuid4().hex[:16]}" + now = datetime.now().isoformat() - channel = AlertChannel( - id=channel_id, - tenant_id=tenant_id, - name=name, - channel_type=channel_type, - config=config, - severity_filter=severity_filter or [s.value for s in AlertSeverity], - is_enabled=True, - success_count=0, - fail_count=0, - last_used_at=None, - created_at=now, - updated_at=now, + channel = AlertChannel( + id = channel_id, + tenant_id = tenant_id, + name = name, + channel_type = channel_type, + config = config, + severity_filter = severity_filter or [s.value for s in AlertSeverity], + is_enabled = True, + success_count = 0, + fail_count = 0, + last_used_at = None, + created_at = now, + updated_at = now, ) with self._get_db() as conn: @@ -679,8 +679,8 @@ class OpsManager: def get_alert_channel(self, channel_id: str) -> AlertChannel | None: """获取告警渠道""" with self._get_db() as conn: - row = conn.execute( - "SELECT * FROM alert_channels WHERE id = ?", (channel_id,) + row = conn.execute( + "SELECT * FROM alert_channels WHERE id = ?", (channel_id, ) ).fetchone() if row: @@ -690,37 +690,37 @@ class OpsManager: def list_alert_channels(self, tenant_id: str) -> list[AlertChannel]: """列出租户的所有告警渠道""" with self._get_db() as conn: - rows = conn.execute( - "SELECT * FROM alert_channels WHERE tenant_id = ? ORDER BY created_at DESC", - (tenant_id,), + rows = conn.execute( + "SELECT * FROM alert_channels WHERE tenant_id = ? ORDER BY created_at DESC", + (tenant_id, ), ).fetchall() return [self._row_to_alert_channel(row) for row in rows] def test_alert_channel(self, channel_id: str) -> bool: """测试告警渠道""" - channel = self.get_alert_channel(channel_id) + channel = self.get_alert_channel(channel_id) if not channel: return False - test_alert = Alert( - id="test", - rule_id="test", - tenant_id=channel.tenant_id, - severity=AlertSeverity.P3, - status=AlertStatus.FIRING, - title="测试告警", - description="这是一条测试告警消息,用于验证告警渠道配置。", - metric="test_metric", - value=0.0, - threshold=0.0, - labels={"test": "true"}, - annotations={}, - started_at=datetime.now().isoformat(), - resolved_at=None, - acknowledged_by=None, - acknowledged_at=None, - notification_sent={}, - suppression_count=0, + test_alert = Alert( + id = "test", + rule_id = "test", + tenant_id = channel.tenant_id, + severity = AlertSeverity.P3, + status = AlertStatus.FIRING, + title = "测试告警", + description = "这是一条测试告警消息,用于验证告警渠道配置。", + metric = "test_metric", + value = 0.0, + threshold = 0.0, + labels = {"test": "true"}, + annotations = {}, + started_at = datetime.now().isoformat(), + resolved_at = None, + acknowledged_by = None, + acknowledged_at = None, + notification_sent = {}, + suppression_count = 0, ) return asyncio.run(self._send_alert_to_channel(test_alert, channel)) @@ -733,14 +733,14 @@ class OpsManager: return False # 获取最近 duration 秒内的指标 - cutoff_time = datetime.now() - timedelta(seconds=rule.duration) - recent_metrics = [m for m in metrics if datetime.fromisoformat(m.timestamp) > cutoff_time] + cutoff_time = datetime.now() - timedelta(seconds = rule.duration) + recent_metrics = [m for m in metrics if datetime.fromisoformat(m.timestamp) > cutoff_time] if not recent_metrics: return False # 计算平均值 - avg_value = statistics.mean([m.metric_value for m in recent_metrics]) + avg_value = statistics.mean([m.metric_value for m in recent_metrics]) # 评估条件 condition_map = { @@ -752,7 +752,7 @@ class OpsManager: "!=": lambda x, y: x != y, } - evaluator = condition_map.get(rule.condition) + evaluator = condition_map.get(rule.condition) if evaluator: return evaluator(avg_value, rule.threshold) @@ -763,16 +763,16 @@ class OpsManager: if len(metrics) < 10: return False - values = [m.metric_value for m in metrics] - mean = statistics.mean(values) - std = statistics.stdev(values) if len(values) > 1 else 0 + values = [m.metric_value for m in metrics] + mean = statistics.mean(values) + std = statistics.stdev(values) if len(values) > 1 else 0 if std == 0: return False # 最近值偏离均值超过3个标准差视为异常 - latest_value = values[-1] - z_score = abs(latest_value - mean) / std + latest_value = values[-1] + z_score = abs(latest_value - mean) / std return z_score > 3.0 @@ -782,15 +782,15 @@ class OpsManager: return False # 简单的线性趋势预测 - values = [m.metric_value for m in metrics[-10:]] # 最近10个点 - n = len(values) + values = [m.metric_value for m in metrics[-10:]] # 最近10个点 + n = len(values) if n < 2: return False - x = list(range(n)) - mean_x = sum(x) / n - mean_y = sum(values) / n + x = list(range(n)) + mean_x = sum(x) / n + mean_y = sum(values) / n # 计算斜率 numerator = sum((x[i] - mean_x) * (values[i] - mean_y) for i in range(n)) @@ -801,37 +801,37 @@ class OpsManager: predicted = values[-1] + slope # 如果预测值超过阈值,触发告警 - condition_map = { + condition_map = { ">": lambda x, y: x > y, "<": lambda x, y: x < y, } - evaluator = condition_map.get(rule.condition) + evaluator = condition_map.get(rule.condition) if evaluator: return evaluator(predicted, rule.threshold) return False - async def evaluate_alert_rules(self, tenant_id: str): + async def evaluate_alert_rules(self, tenant_id: str) -> None: """评估所有告警规则""" - rules = self.list_alert_rules(tenant_id, is_enabled=True) + rules = self.list_alert_rules(tenant_id, is_enabled = True) for rule in rules: # 获取相关指标 - metrics = self.get_recent_metrics( - tenant_id, rule.metric, seconds=rule.duration + rule.evaluation_interval + metrics = self.get_recent_metrics( + tenant_id, rule.metric, seconds = rule.duration + rule.evaluation_interval ) # 评估规则 - evaluator = self._alert_evaluators.get(rule.rule_type.value) + evaluator = self._alert_evaluators.get(rule.rule_type.value) if evaluator and evaluator(rule, metrics): # 触发告警 await self._trigger_alert(rule, metrics[-1] if metrics else None) - async def _trigger_alert(self, rule: AlertRule, metric: ResourceMetric | None): + async def _trigger_alert(self, rule: AlertRule, metric: ResourceMetric | None) -> None: """触发告警""" # 检查是否已有相同告警在触发中 - existing = self.get_active_alert_by_rule(rule.id) + existing = self.get_active_alert_by_rule(rule.id) if existing: # 更新抑制计数 self._increment_suppression_count(existing.id) @@ -841,28 +841,28 @@ class OpsManager: if self._is_alert_suppressed(rule): return - alert_id = f"al_{uuid.uuid4().hex[:16]}" - now = datetime.now().isoformat() + alert_id = f"al_{uuid.uuid4().hex[:16]}" + now = datetime.now().isoformat() - alert = Alert( - id=alert_id, - rule_id=rule.id, - tenant_id=rule.tenant_id, - severity=rule.severity, - status=AlertStatus.FIRING, - title=rule.annotations.get("summary", f"告警: {rule.name}"), - description=rule.annotations.get("description", rule.description), - metric=rule.metric, - value=metric.metric_value if metric else 0.0, - threshold=rule.threshold, - labels=rule.labels, - annotations=rule.annotations, - started_at=now, - resolved_at=None, - acknowledged_by=None, - acknowledged_at=None, - notification_sent={}, - suppression_count=0, + alert = Alert( + id = alert_id, + rule_id = rule.id, + tenant_id = rule.tenant_id, + severity = rule.severity, + status = AlertStatus.FIRING, + title = rule.annotations.get("summary", f"告警: {rule.name}"), + description = rule.annotations.get("description", rule.description), + metric = rule.metric, + value = metric.metric_value if metric else 0.0, + threshold = rule.threshold, + labels = rule.labels, + annotations = rule.annotations, + started_at = now, + resolved_at = None, + acknowledged_by = None, + acknowledged_at = None, + notification_sent = {}, + suppression_count = 0, ) # 保存告警 @@ -898,11 +898,11 @@ class OpsManager: # 发送告警通知 await self._send_alert_notifications(alert, rule) - async def _send_alert_notifications(self, alert: Alert, rule: AlertRule): + async def _send_alert_notifications(self, alert: Alert, rule: AlertRule) -> None: """发送告警通知到所有配置的渠道""" - channels = [] + channels = [] for channel_id in rule.channels: - channel = self.get_alert_channel(channel_id) + channel = self.get_alert_channel(channel_id) if channel and channel.is_enabled: channels.append(channel) @@ -911,10 +911,10 @@ class OpsManager: if alert.severity.value not in channel.severity_filter: continue - success = await self._send_alert_to_channel(alert, channel) + success = await self._send_alert_to_channel(alert, channel) # 更新发送状态 - alert.notification_sent[channel.id] = success + alert.notification_sent[channel.id] = success self._update_alert_notification_status(alert.id, channel.id, success) async def _send_alert_to_channel(self, alert: Alert, channel: AlertChannel) -> bool: @@ -942,22 +942,22 @@ class OpsManager: async def _send_feishu_alert(self, alert: Alert, channel: AlertChannel) -> bool: """发送飞书告警""" - config = channel.config - webhook_url = config.get("webhook_url") + config = channel.config + webhook_url = config.get("webhook_url") config.get("secret", "") if not webhook_url: return False # 构建飞书消息 - severity_colors = { + severity_colors = { AlertSeverity.P0.value: "red", AlertSeverity.P1.value: "orange", AlertSeverity.P2.value: "yellow", AlertSeverity.P3.value: "blue", } - message = { + message = { "msg_type": "interactive", "card": { "config": {"wide_screen_mode": True}, @@ -990,27 +990,27 @@ class OpsManager: } async with httpx.AsyncClient() as client: - response = await client.post(webhook_url, json=message, timeout=30.0) - success = response.status_code == 200 + response = await client.post(webhook_url, json = message, timeout = 30.0) + success = response.status_code == 200 if success: - self._update_channel_stats(channel.id, success=True) + self._update_channel_stats(channel.id, success = True) else: - self._update_channel_stats(channel.id, success=False) + self._update_channel_stats(channel.id, success = False) return success async def _send_dingtalk_alert(self, alert: Alert, channel: AlertChannel) -> bool: """发送钉钉告警""" - config = channel.config - webhook_url = config.get("webhook_url") + config = channel.config + webhook_url = config.get("webhook_url") config.get("secret", "") if not webhook_url: return False # 构建钉钉消息 - message = { + message = { "msgtype": "markdown", "markdown": { "title": f"[{alert.severity.value.upper()}] {alert.title}", @@ -1024,29 +1024,29 @@ class OpsManager: } async with httpx.AsyncClient() as client: - response = await client.post(webhook_url, json=message, timeout=30.0) - success = response.status_code == 200 + response = await client.post(webhook_url, json = message, timeout = 30.0) + success = response.status_code == 200 self._update_channel_stats(channel.id, success) return success async def _send_slack_alert(self, alert: Alert, channel: AlertChannel) -> bool: """发送 Slack 告警""" - config = channel.config - webhook_url = config.get("webhook_url") + config = channel.config + webhook_url = config.get("webhook_url") if not webhook_url: return False - severity_emojis = { + severity_emojis = { AlertSeverity.P0.value: "🔴", AlertSeverity.P1.value: "🟠", AlertSeverity.P2.value: "🟡", AlertSeverity.P3.value: "🔵", } - emoji = severity_emojis.get(alert.severity.value, "⚪") + emoji = severity_emojis.get(alert.severity.value, "⚪") - message = { + message = { "text": f"{emoji} [{alert.severity.value.upper()}] {alert.title}", "blocks": [ { @@ -1073,20 +1073,20 @@ class OpsManager: } async with httpx.AsyncClient() as client: - response = await client.post(webhook_url, json=message, timeout=30.0) - success = response.status_code == 200 + response = await client.post(webhook_url, json = message, timeout = 30.0) + success = response.status_code == 200 self._update_channel_stats(channel.id, success) return success async def _send_email_alert(self, alert: Alert, channel: AlertChannel) -> bool: """发送邮件告警(模拟实现)""" # 实际实现需要集成邮件服务如 SendGrid、AWS SES 等 - config = channel.config - smtp_host = config.get("smtp_host") + config = channel.config + smtp_host = config.get("smtp_host") config.get("smtp_port", 587) - username = config.get("username") - password = config.get("password") - to_addresses = config.get("to_addresses", []) + username = config.get("username") + password = config.get("password") + to_addresses = config.get("to_addresses", []) if not all([smtp_host, username, password, to_addresses]): return False @@ -1097,20 +1097,20 @@ class OpsManager: async def _send_pagerduty_alert(self, alert: Alert, channel: AlertChannel) -> bool: """发送 PagerDuty 告警""" - config = channel.config - integration_key = config.get("integration_key") + config = channel.config + integration_key = config.get("integration_key") if not integration_key: return False - severity_map = { + severity_map = { AlertSeverity.P0.value: "critical", AlertSeverity.P1.value: "error", AlertSeverity.P2.value: "warning", AlertSeverity.P3.value: "info", } - message = { + message = { "routing_key": integration_key, "event_action": "trigger", "dedup_key": alert.id, @@ -1128,29 +1128,29 @@ class OpsManager: } async with httpx.AsyncClient() as client: - response = await client.post( - "https://events.pagerduty.com/v2/enqueue", json=message, timeout=30.0 + response = await client.post( + "https://events.pagerduty.com/v2/enqueue", json = message, timeout = 30.0 ) - success = response.status_code == 202 + success = response.status_code == 202 self._update_channel_stats(channel.id, success) return success async def _send_opsgenie_alert(self, alert: Alert, channel: AlertChannel) -> bool: """发送 Opsgenie 告警""" - config = channel.config - api_key = config.get("api_key") + config = channel.config + api_key = config.get("api_key") if not api_key: return False - priority_map = { + priority_map = { AlertSeverity.P0.value: "P1", AlertSeverity.P1.value: "P2", AlertSeverity.P2.value: "P3", AlertSeverity.P3.value: "P4", } - message = { + message = { "message": alert.title, "description": alert.description, "priority": priority_map.get(alert.severity.value, "P3"), @@ -1163,26 +1163,26 @@ class OpsManager: } async with httpx.AsyncClient() as client: - response = await client.post( + response = await client.post( "https://api.opsgenie.com/v2/alerts", - json=message, - headers={"Authorization": f"GenieKey {api_key}"}, - timeout=30.0, + json = message, + headers = {"Authorization": f"GenieKey {api_key}"}, + timeout = 30.0, ) - success = response.status_code in [200, 201, 202] + success = response.status_code in [200, 201, 202] self._update_channel_stats(channel.id, success) return success async def _send_webhook_alert(self, alert: Alert, channel: AlertChannel) -> bool: """发送 Webhook 告警""" - config = channel.config - webhook_url = config.get("webhook_url") - headers = config.get("headers", {}) + config = channel.config + webhook_url = config.get("webhook_url") + headers = config.get("headers", {}) if not webhook_url: return False - message = { + message = { "alert_id": alert.id, "severity": alert.severity.value, "status": alert.status.value, @@ -1196,8 +1196,8 @@ class OpsManager: } async with httpx.AsyncClient() as client: - response = await client.post(webhook_url, json=message, headers=headers, timeout=30.0) - success = response.status_code in [200, 201, 202] + response = await client.post(webhook_url, json = message, headers = headers, timeout = 30.0) + success = response.status_code in [200, 201, 202] self._update_channel_stats(channel.id, success) return success @@ -1206,9 +1206,9 @@ class OpsManager: def get_active_alert_by_rule(self, rule_id: str) -> Alert | None: """获取规则对应的活跃告警""" with self._get_db() as conn: - row = conn.execute( + row = conn.execute( """SELECT * FROM alerts - WHERE rule_id = ? AND status = ? + WHERE rule_id = ? AND status = ? ORDER BY started_at DESC LIMIT 1""", (rule_id, AlertStatus.FIRING.value), ).fetchone() @@ -1220,7 +1220,7 @@ class OpsManager: def get_alert(self, alert_id: str) -> Alert | None: """获取告警详情""" with self._get_db() as conn: - row = conn.execute("SELECT * FROM alerts WHERE id = ?", (alert_id,)).fetchone() + row = conn.execute("SELECT * FROM alerts WHERE id = ?", (alert_id, )).fetchone() if row: return self._row_to_alert(row) @@ -1229,38 +1229,38 @@ class OpsManager: def list_alerts( self, tenant_id: str, - status: AlertStatus | None = None, - severity: AlertSeverity | None = None, - limit: int = 100, + status: AlertStatus | None = None, + severity: AlertSeverity | None = None, + limit: int = 100, ) -> list[Alert]: """列出租户的告警""" - query = "SELECT * FROM alerts WHERE tenant_id = ?" - params = [tenant_id] + query = "SELECT * FROM alerts WHERE tenant_id = ?" + params = [tenant_id] if status: - query += " AND status = ?" + query += " AND status = ?" params.append(status.value) if severity: - query += " AND severity = ?" + query += " AND severity = ?" params.append(severity.value) query += " ORDER BY started_at DESC LIMIT ?" params.append(limit) with self._get_db() as conn: - rows = conn.execute(query, params).fetchall() + rows = conn.execute(query, params).fetchall() return [self._row_to_alert(row) for row in rows] def acknowledge_alert(self, alert_id: str, user_id: str) -> Alert | None: """确认告警""" - now = datetime.now().isoformat() + now = datetime.now().isoformat() with self._get_db() as conn: conn.execute( """ UPDATE alerts - SET status = ?, acknowledged_by = ?, acknowledged_at = ? - WHERE id = ? + SET status = ?, acknowledged_by = ?, acknowledged_at = ? + WHERE id = ? """, (AlertStatus.ACKNOWLEDGED.value, user_id, now, alert_id), ) @@ -1270,14 +1270,14 @@ class OpsManager: def resolve_alert(self, alert_id: str) -> Alert | None: """解决告警""" - now = datetime.now().isoformat() + now = datetime.now().isoformat() with self._get_db() as conn: conn.execute( """ UPDATE alerts - SET status = ?, resolved_at = ? - WHERE id = ? + SET status = ?, resolved_at = ? + WHERE id = ? """, (AlertStatus.RESOLVED.value, now, alert_id), ) @@ -1291,10 +1291,10 @@ class OpsManager: conn.execute( """ UPDATE alerts - SET suppression_count = suppression_count + 1 - WHERE id = ? + SET suppression_count = suppression_count + 1 + WHERE id = ? """, - (alert_id,), + (alert_id, ), ) conn.commit() @@ -1303,31 +1303,31 @@ class OpsManager: ) -> None: """更新告警通知状态""" with self._get_db() as conn: - row = conn.execute( - "SELECT notification_sent FROM alerts WHERE id = ?", (alert_id,) + row = conn.execute( + "SELECT notification_sent FROM alerts WHERE id = ?", (alert_id, ) ).fetchone() if row: - notification_sent = json.loads(row["notification_sent"]) - notification_sent[channel_id] = success + notification_sent = json.loads(row["notification_sent"]) + notification_sent[channel_id] = success conn.execute( - "UPDATE alerts SET notification_sent = ? WHERE id = ?", + "UPDATE alerts SET notification_sent = ? WHERE id = ?", (json.dumps(notification_sent), alert_id), ) conn.commit() def _update_channel_stats(self, channel_id: str, success: bool) -> None: """更新渠道统计""" - now = datetime.now().isoformat() + now = datetime.now().isoformat() with self._get_db() as conn: if success: conn.execute( """ UPDATE alert_channels - SET success_count = success_count + 1, last_used_at = ? - WHERE id = ? + SET success_count = success_count + 1, last_used_at = ? + WHERE id = ? """, (now, channel_id), ) @@ -1335,8 +1335,8 @@ class OpsManager: conn.execute( """ UPDATE alert_channels - SET fail_count = fail_count + 1, last_used_at = ? - WHERE id = ? + SET fail_count = fail_count + 1, last_used_at = ? + WHERE id = ? """, (now, channel_id), ) @@ -1350,22 +1350,22 @@ class OpsManager: name: str, matchers: dict[str, str], duration: int, - is_regex: bool = False, - expires_at: str | None = None, + is_regex: bool = False, + expires_at: str | None = None, ) -> AlertSuppressionRule: """创建告警抑制规则""" - rule_id = f"sr_{uuid.uuid4().hex[:16]}" - now = datetime.now().isoformat() + rule_id = f"sr_{uuid.uuid4().hex[:16]}" + now = datetime.now().isoformat() - rule = AlertSuppressionRule( - id=rule_id, - tenant_id=tenant_id, - name=name, - matchers=matchers, - duration=duration, - is_regex=is_regex, - created_at=now, - expires_at=expires_at, + rule = AlertSuppressionRule( + id = rule_id, + tenant_id = tenant_id, + name = name, + matchers = matchers, + duration = duration, + is_regex = is_regex, + created_at = now, + expires_at = expires_at, ) with self._get_db() as conn: @@ -1393,12 +1393,12 @@ class OpsManager: def _is_alert_suppressed(self, rule: AlertRule) -> bool: """检查告警是否被抑制""" with self._get_db() as conn: - rows = conn.execute( - "SELECT * FROM alert_suppression_rules WHERE tenant_id = ?", (rule.tenant_id,) + rows = conn.execute( + "SELECT * FROM alert_suppression_rules WHERE tenant_id = ?", (rule.tenant_id, ) ).fetchall() for row in rows: - suppression_rule = self._row_to_suppression_rule(row) + suppression_rule = self._row_to_suppression_rule(row) # 检查是否过期 if suppression_rule.expires_at: @@ -1406,19 +1406,19 @@ class OpsManager: continue # 检查匹配 - matchers = suppression_rule.matchers - match = True + matchers = suppression_rule.matchers + match = True for key, pattern in matchers.items(): - value = rule.labels.get(key, "") + value = rule.labels.get(key, "") if suppression_rule.is_regex: if not re.match(pattern, value): - match = False + match = False break else: if value != pattern: - match = False + match = False break if match: @@ -1436,22 +1436,22 @@ class OpsManager: metric_name: str, metric_value: float, unit: str, - metadata: dict = None, + metadata: dict = None, ) -> ResourceMetric: """记录资源指标""" - metric_id = f"rm_{uuid.uuid4().hex[:16]}" - now = datetime.now().isoformat() + metric_id = f"rm_{uuid.uuid4().hex[:16]}" + now = datetime.now().isoformat() - metric = ResourceMetric( - id=metric_id, - tenant_id=tenant_id, - resource_type=resource_type, - resource_id=resource_id, - metric_name=metric_name, - metric_value=metric_value, - unit=unit, - timestamp=now, - metadata=metadata or {}, + metric = ResourceMetric( + id = metric_id, + tenant_id = tenant_id, + resource_type = resource_type, + resource_id = resource_id, + metric_name = metric_name, + metric_value = metric_value, + unit = unit, + timestamp = now, + metadata = metadata or {}, ) with self._get_db() as conn: @@ -1479,15 +1479,15 @@ class OpsManager: return metric def get_recent_metrics( - self, tenant_id: str, metric_name: str, seconds: int = 3600 + self, tenant_id: str, metric_name: str, seconds: int = 3600 ) -> list[ResourceMetric]: """获取最近的指标数据""" - cutoff_time = (datetime.now() - timedelta(seconds=seconds)).isoformat() + cutoff_time = (datetime.now() - timedelta(seconds = seconds)).isoformat() with self._get_db() as conn: - rows = conn.execute( + rows = conn.execute( """SELECT * FROM resource_metrics - WHERE tenant_id = ? AND metric_name = ? AND timestamp > ? + WHERE tenant_id = ? AND metric_name = ? AND timestamp > ? ORDER BY timestamp DESC""", (tenant_id, metric_name, cutoff_time), ).fetchall() @@ -1505,10 +1505,10 @@ class OpsManager: ) -> list[ResourceMetric]: """获取指定资源的指标数据""" with self._get_db() as conn: - rows = conn.execute( + rows = conn.execute( """SELECT * FROM resource_metrics - WHERE tenant_id = ? AND resource_type = ? AND resource_id = ? - AND metric_name = ? AND timestamp BETWEEN ? AND ? + WHERE tenant_id = ? AND resource_type = ? AND resource_id = ? + AND metric_name = ? AND timestamp BETWEEN ? AND ? ORDER BY timestamp ASC""", (tenant_id, resource_type.value, resource_id, metric_name, start_time, end_time), ).fetchall() @@ -1523,51 +1523,51 @@ class OpsManager: resource_type: ResourceType, current_capacity: float, prediction_date: str, - confidence: float = 0.8, + confidence: float = 0.8, ) -> CapacityPlan: """创建容量规划""" - plan_id = f"cp_{uuid.uuid4().hex[:16]}" - now = datetime.now().isoformat() + plan_id = f"cp_{uuid.uuid4().hex[:16]}" + now = datetime.now().isoformat() # 基于历史数据预测 - metrics = self.get_recent_metrics( - tenant_id, f"{resource_type.value}_usage", seconds=30 * 24 * 3600 + metrics = self.get_recent_metrics( + tenant_id, f"{resource_type.value}_usage", seconds = 30 * 24 * 3600 ) if metrics: - values = [m.metric_value for m in metrics] - trend = self._calculate_trend(values) + values = [m.metric_value for m in metrics] + trend = self._calculate_trend(values) # 预测未来容量需求 - days_ahead = (datetime.fromisoformat(prediction_date) - datetime.now()).days - predicted_capacity = current_capacity * (1 + trend * days_ahead / 30) + days_ahead = (datetime.fromisoformat(prediction_date) - datetime.now()).days + predicted_capacity = current_capacity * (1 + trend * days_ahead / 30) # 推荐操作 if predicted_capacity > current_capacity * 1.2: - recommended_action = "scale_up" - estimated_cost = (predicted_capacity - current_capacity) * 10 # 简化计算 + recommended_action = "scale_up" + estimated_cost = (predicted_capacity - current_capacity) * 10 # 简化计算 elif predicted_capacity < current_capacity * 0.5: - recommended_action = "scale_down" - estimated_cost = 0 + recommended_action = "scale_down" + estimated_cost = 0 else: - recommended_action = "maintain" - estimated_cost = 0 + recommended_action = "maintain" + estimated_cost = 0 else: - predicted_capacity = current_capacity - recommended_action = "insufficient_data" - estimated_cost = 0 + predicted_capacity = current_capacity + recommended_action = "insufficient_data" + estimated_cost = 0 - plan = CapacityPlan( - id=plan_id, - tenant_id=tenant_id, - resource_type=resource_type, - current_capacity=current_capacity, - predicted_capacity=predicted_capacity, - prediction_date=prediction_date, - confidence=confidence, - recommended_action=recommended_action, - estimated_cost=estimated_cost, - created_at=now, + plan = CapacityPlan( + id = plan_id, + tenant_id = tenant_id, + resource_type = resource_type, + current_capacity = current_capacity, + predicted_capacity = predicted_capacity, + prediction_date = prediction_date, + confidence = confidence, + recommended_action = recommended_action, + estimated_cost = estimated_cost, + created_at = now, ) with self._get_db() as conn: @@ -1601,21 +1601,21 @@ class OpsManager: return 0.0 # 使用最近的数据计算趋势 - recent = values[-10:] if len(values) > 10 else values - n = len(recent) + recent = values[-10:] if len(values) > 10 else values + n = len(recent) if n < 2: return 0.0 # 简单线性回归计算斜率 - x = list(range(n)) - mean_x = sum(x) / n - mean_y = sum(recent) / n + x = list(range(n)) + mean_x = sum(x) / n + mean_y = sum(recent) / n - numerator = sum((x[i] - mean_x) * (recent[i] - mean_y) for i in range(n)) - denominator = sum((x[i] - mean_x) ** 2 for i in range(n)) + numerator = sum((x[i] - mean_x) * (recent[i] - mean_y) for i in range(n)) + denominator = sum((x[i] - mean_x) ** 2 for i in range(n)) - slope = numerator / denominator if denominator != 0 else 0 + slope = numerator / denominator if denominator != 0 else 0 # 归一化为增长率 if mean_y != 0: @@ -1625,9 +1625,9 @@ class OpsManager: def get_capacity_plans(self, tenant_id: str) -> list[CapacityPlan]: """获取容量规划列表""" with self._get_db() as conn: - rows = conn.execute( - "SELECT * FROM capacity_plans WHERE tenant_id = ? ORDER BY created_at DESC", - (tenant_id,), + rows = conn.execute( + "SELECT * FROM capacity_plans WHERE tenant_id = ? ORDER BY created_at DESC", + (tenant_id, ), ).fetchall() return [self._row_to_capacity_plan(row) for row in rows] @@ -1643,30 +1643,30 @@ class OpsManager: target_utilization: float, scale_up_threshold: float, scale_down_threshold: float, - scale_up_step: int = 1, - scale_down_step: int = 1, - cooldown_period: int = 300, + scale_up_step: int = 1, + scale_down_step: int = 1, + cooldown_period: int = 300, ) -> AutoScalingPolicy: """创建自动扩缩容策略""" - policy_id = f"asp_{uuid.uuid4().hex[:16]}" - now = datetime.now().isoformat() + policy_id = f"asp_{uuid.uuid4().hex[:16]}" + now = datetime.now().isoformat() - policy = AutoScalingPolicy( - id=policy_id, - tenant_id=tenant_id, - name=name, - resource_type=resource_type, - min_instances=min_instances, - max_instances=max_instances, - target_utilization=target_utilization, - scale_up_threshold=scale_up_threshold, - scale_down_threshold=scale_down_threshold, - scale_up_step=scale_up_step, - scale_down_step=scale_down_step, - cooldown_period=cooldown_period, - is_enabled=True, - created_at=now, - updated_at=now, + policy = AutoScalingPolicy( + id = policy_id, + tenant_id = tenant_id, + name = name, + resource_type = resource_type, + min_instances = min_instances, + max_instances = max_instances, + target_utilization = target_utilization, + scale_up_threshold = scale_up_threshold, + scale_down_threshold = scale_down_threshold, + scale_up_step = scale_up_step, + scale_down_step = scale_down_step, + cooldown_period = cooldown_period, + is_enabled = True, + created_at = now, + updated_at = now, ) with self._get_db() as conn: @@ -1703,8 +1703,8 @@ class OpsManager: def get_auto_scaling_policy(self, policy_id: str) -> AutoScalingPolicy | None: """获取自动扩缩容策略""" with self._get_db() as conn: - row = conn.execute( - "SELECT * FROM auto_scaling_policies WHERE id = ?", (policy_id,) + row = conn.execute( + "SELECT * FROM auto_scaling_policies WHERE id = ?", (policy_id, ) ).fetchone() if row: @@ -1714,9 +1714,9 @@ class OpsManager: def list_auto_scaling_policies(self, tenant_id: str) -> list[AutoScalingPolicy]: """列出租户的自动扩缩容策略""" with self._get_db() as conn: - rows = conn.execute( - "SELECT * FROM auto_scaling_policies WHERE tenant_id = ? ORDER BY created_at DESC", - (tenant_id,), + rows = conn.execute( + "SELECT * FROM auto_scaling_policies WHERE tenant_id = ? ORDER BY created_at DESC", + (tenant_id, ), ).fetchall() return [self._row_to_auto_scaling_policy(row) for row in rows] @@ -1724,36 +1724,36 @@ class OpsManager: self, policy_id: str, current_instances: int, current_utilization: float ) -> ScalingEvent | None: """评估扩缩容策略""" - policy = self.get_auto_scaling_policy(policy_id) + policy = self.get_auto_scaling_policy(policy_id) if not policy or not policy.is_enabled: return None # 检查是否在冷却期 - last_event = self.get_last_scaling_event(policy_id) + last_event = self.get_last_scaling_event(policy_id) if last_event: - last_time = datetime.fromisoformat(last_event.started_at) + last_time = datetime.fromisoformat(last_event.started_at) if (datetime.now() - last_time).total_seconds() < policy.cooldown_period: return None - action = None - reason = "" + action = None + reason = "" if current_utilization > policy.scale_up_threshold: if current_instances < policy.max_instances: - action = ScalingAction.SCALE_UP - reason = ( + action = ScalingAction.SCALE_UP + reason = ( f"利用率 {current_utilization:.1%} 超过扩容阈值 {policy.scale_up_threshold:.1%}" ) elif current_utilization < policy.scale_down_threshold: if current_instances > policy.min_instances: - action = ScalingAction.SCALE_DOWN - reason = f"利用率 {current_utilization:.1%} 低于缩容阈值 {policy.scale_down_threshold:.1%}" + action = ScalingAction.SCALE_DOWN + reason = f"利用率 {current_utilization:.1%} 低于缩容阈值 {policy.scale_down_threshold:.1%}" if action: if action == ScalingAction.SCALE_UP: - new_count = min(current_instances + policy.scale_up_step, policy.max_instances) + new_count = min(current_instances + policy.scale_up_step, policy.max_instances) else: - new_count = max(current_instances - policy.scale_down_step, policy.min_instances) + new_count = max(current_instances - policy.scale_down_step, policy.min_instances) return self._create_scaling_event(policy, action, current_instances, new_count, reason) @@ -1768,22 +1768,22 @@ class OpsManager: reason: str, ) -> ScalingEvent: """创建扩缩容事件""" - event_id = f"se_{uuid.uuid4().hex[:16]}" - now = datetime.now().isoformat() + event_id = f"se_{uuid.uuid4().hex[:16]}" + now = datetime.now().isoformat() - event = ScalingEvent( - id=event_id, - policy_id=policy.id, - tenant_id=policy.tenant_id, - action=action, - from_count=from_count, - to_count=to_count, - reason=reason, - triggered_by="auto", - status="pending", - started_at=now, - completed_at=None, - error_message=None, + event = ScalingEvent( + id = event_id, + policy_id = policy.id, + tenant_id = policy.tenant_id, + action = action, + from_count = from_count, + to_count = to_count, + reason = reason, + triggered_by = "auto", + status = "pending", + started_at = now, + completed_at = None, + error_message = None, ) with self._get_db() as conn: @@ -1814,11 +1814,11 @@ class OpsManager: def get_last_scaling_event(self, policy_id: str) -> ScalingEvent | None: """获取最近的扩缩容事件""" with self._get_db() as conn: - row = conn.execute( + row = conn.execute( """SELECT * FROM scaling_events - WHERE policy_id = ? + WHERE policy_id = ? ORDER BY started_at DESC LIMIT 1""", - (policy_id,), + (policy_id, ), ).fetchone() if row: @@ -1826,18 +1826,18 @@ class OpsManager: return None def update_scaling_event_status( - self, event_id: str, status: str, error_message: str = None + self, event_id: str, status: str, error_message: str = None ) -> ScalingEvent | None: """更新扩缩容事件状态""" - now = datetime.now().isoformat() + now = datetime.now().isoformat() with self._get_db() as conn: if status in ["completed", "failed"]: conn.execute( """ UPDATE scaling_events - SET status = ?, completed_at = ?, error_message = ? - WHERE id = ? + SET status = ?, completed_at = ?, error_message = ? + WHERE id = ? """, (status, now, error_message, event_id), ) @@ -1845,8 +1845,8 @@ class OpsManager: conn.execute( """ UPDATE scaling_events - SET status = ?, error_message = ? - WHERE id = ? + SET status = ?, error_message = ? + WHERE id = ? """, (status, error_message, event_id), ) @@ -1857,28 +1857,28 @@ class OpsManager: def get_scaling_event(self, event_id: str) -> ScalingEvent | None: """获取扩缩容事件""" with self._get_db() as conn: - row = conn.execute("SELECT * FROM scaling_events WHERE id = ?", (event_id,)).fetchone() + row = conn.execute("SELECT * FROM scaling_events WHERE id = ?", (event_id, )).fetchone() if row: return self._row_to_scaling_event(row) return None def list_scaling_events( - self, tenant_id: str, policy_id: str = None, limit: int = 100 + self, tenant_id: str, policy_id: str = None, limit: int = 100 ) -> list[ScalingEvent]: """列出租户的扩缩容事件""" - query = "SELECT * FROM scaling_events WHERE tenant_id = ?" - params = [tenant_id] + query = "SELECT * FROM scaling_events WHERE tenant_id = ?" + params = [tenant_id] if policy_id: - query += " AND policy_id = ?" + query += " AND policy_id = ?" params.append(policy_id) query += " ORDER BY started_at DESC LIMIT ?" params.append(limit) with self._get_db() as conn: - rows = conn.execute(query, params).fetchall() + rows = conn.execute(query, params).fetchall() return [self._row_to_scaling_event(row) for row in rows] # ==================== 健康检查与故障转移 ==================== @@ -1891,30 +1891,30 @@ class OpsManager: target_id: str, check_type: str, check_config: dict, - interval: int = 60, - timeout: int = 10, - retry_count: int = 3, + interval: int = 60, + timeout: int = 10, + retry_count: int = 3, ) -> HealthCheck: """创建健康检查""" - check_id = f"hc_{uuid.uuid4().hex[:16]}" - now = datetime.now().isoformat() + check_id = f"hc_{uuid.uuid4().hex[:16]}" + now = datetime.now().isoformat() - check = HealthCheck( - id=check_id, - tenant_id=tenant_id, - name=name, - target_type=target_type, - target_id=target_id, - check_type=check_type, - check_config=check_config, - interval=interval, - timeout=timeout, - retry_count=retry_count, - healthy_threshold=2, - unhealthy_threshold=3, - is_enabled=True, - created_at=now, - updated_at=now, + check = HealthCheck( + id = check_id, + tenant_id = tenant_id, + name = name, + target_type = target_type, + target_id = target_id, + check_type = check_type, + check_config = check_config, + interval = interval, + timeout = timeout, + retry_count = retry_count, + healthy_threshold = 2, + unhealthy_threshold = 3, + is_enabled = True, + created_at = now, + updated_at = now, ) with self._get_db() as conn: @@ -1951,7 +1951,7 @@ class OpsManager: def get_health_check(self, check_id: str) -> HealthCheck | None: """获取健康检查配置""" with self._get_db() as conn: - row = conn.execute("SELECT * FROM health_checks WHERE id = ?", (check_id,)).fetchone() + row = conn.execute("SELECT * FROM health_checks WHERE id = ?", (check_id, )).fetchone() if row: return self._row_to_health_check(row) @@ -1960,40 +1960,40 @@ class OpsManager: def list_health_checks(self, tenant_id: str) -> list[HealthCheck]: """列出租户的健康检查""" with self._get_db() as conn: - rows = conn.execute( - "SELECT * FROM health_checks WHERE tenant_id = ? ORDER BY created_at DESC", - (tenant_id,), + rows = conn.execute( + "SELECT * FROM health_checks WHERE tenant_id = ? ORDER BY created_at DESC", + (tenant_id, ), ).fetchall() return [self._row_to_health_check(row) for row in rows] async def execute_health_check(self, check_id: str) -> HealthCheckResult: """执行健康检查""" - check = self.get_health_check(check_id) + check = self.get_health_check(check_id) if not check: raise ValueError(f"Health check {check_id} not found") - result_id = f"hcr_{uuid.uuid4().hex[:16]}" - now = datetime.now().isoformat() + result_id = f"hcr_{uuid.uuid4().hex[:16]}" + now = datetime.now().isoformat() # 模拟健康检查(实际实现需要根据 check_type 执行具体检查) if check.check_type == "http": - status, response_time, message = await self._check_http_health(check) + status, response_time, message = await self._check_http_health(check) elif check.check_type == "tcp": - status, response_time, message = await self._check_tcp_health(check) + status, response_time, message = await self._check_tcp_health(check) elif check.check_type == "ping": - status, response_time, message = await self._check_ping_health(check) + status, response_time, message = await self._check_ping_health(check) else: - status, response_time, message = HealthStatus.UNKNOWN, 0, "Unknown check type" + status, response_time, message = HealthStatus.UNKNOWN, 0, "Unknown check type" - result = HealthCheckResult( - id=result_id, - check_id=check_id, - tenant_id=check.tenant_id, - status=status, - response_time=response_time, - message=message, - details={}, - checked_at=now, + result = HealthCheckResult( + id = result_id, + check_id = check_id, + tenant_id = check.tenant_id, + status = status, + response_time = response_time, + message = message, + details = {}, + checked_at = now, ) with self._get_db() as conn: @@ -2020,18 +2020,18 @@ class OpsManager: async def _check_http_health(self, check: HealthCheck) -> tuple[HealthStatus, float, str]: """HTTP 健康检查""" - config = check.check_config - url = config.get("url") - expected_status = config.get("expected_status", 200) + config = check.check_config + url = config.get("url") + expected_status = config.get("expected_status", 200) if not url: return HealthStatus.UNHEALTHY, 0, "URL not configured" - start_time = time.time() + start_time = time.time() try: async with httpx.AsyncClient() as client: - response = await client.get(url, timeout=check.timeout) - response_time = (time.time() - start_time) * 1000 + response = await client.get(url, timeout = check.timeout) + response_time = (time.time() - start_time) * 1000 if response.status_code == expected_status: return HealthStatus.HEALTHY, response_time, "OK" @@ -2046,19 +2046,19 @@ class OpsManager: async def _check_tcp_health(self, check: HealthCheck) -> tuple[HealthStatus, float, str]: """TCP 健康检查""" - config = check.check_config - host = config.get("host") - port = config.get("port") + config = check.check_config + host = config.get("host") + port = config.get("port") if not host or not port: return HealthStatus.UNHEALTHY, 0, "Host or port not configured" - start_time = time.time() + start_time = time.time() try: - reader, writer = await asyncio.wait_for( - asyncio.open_connection(host, port), timeout=check.timeout + reader, writer = await asyncio.wait_for( + asyncio.open_connection(host, port), timeout = check.timeout ) - response_time = (time.time() - start_time) * 1000 + response_time = (time.time() - start_time) * 1000 writer.close() await writer.wait_closed() return HealthStatus.HEALTHY, response_time, "TCP connection successful" @@ -2069,8 +2069,8 @@ class OpsManager: async def _check_ping_health(self, check: HealthCheck) -> tuple[HealthStatus, float, str]: """Ping 健康检查(模拟)""" - config = check.check_config - host = config.get("host") + config = check.check_config + host = config.get("host") if not host: return HealthStatus.UNHEALTHY, 0, "Host not configured" @@ -2079,12 +2079,12 @@ class OpsManager: # 这里模拟成功 return HealthStatus.HEALTHY, 10.0, "Ping successful" - def get_health_check_results(self, check_id: str, limit: int = 100) -> list[HealthCheckResult]: + def get_health_check_results(self, check_id: str, limit: int = 100) -> list[HealthCheckResult]: """获取健康检查历史结果""" with self._get_db() as conn: - rows = conn.execute( + rows = conn.execute( """SELECT * FROM health_check_results - WHERE check_id = ? + WHERE check_id = ? ORDER BY checked_at DESC LIMIT ?""", (check_id, limit), ).fetchall() @@ -2099,27 +2099,27 @@ class OpsManager: primary_region: str, secondary_regions: list[str], failover_trigger: str, - auto_failover: bool = False, - failover_timeout: int = 300, - health_check_id: str = None, + auto_failover: bool = False, + failover_timeout: int = 300, + health_check_id: str = None, ) -> FailoverConfig: """创建故障转移配置""" - config_id = f"fc_{uuid.uuid4().hex[:16]}" - now = datetime.now().isoformat() + config_id = f"fc_{uuid.uuid4().hex[:16]}" + now = datetime.now().isoformat() - config = FailoverConfig( - id=config_id, - tenant_id=tenant_id, - name=name, - primary_region=primary_region, - secondary_regions=secondary_regions, - failover_trigger=failover_trigger, - auto_failover=auto_failover, - failover_timeout=failover_timeout, - health_check_id=health_check_id, - is_enabled=True, - created_at=now, - updated_at=now, + config = FailoverConfig( + id = config_id, + tenant_id = tenant_id, + name = name, + primary_region = primary_region, + secondary_regions = secondary_regions, + failover_trigger = failover_trigger, + auto_failover = auto_failover, + failover_timeout = failover_timeout, + health_check_id = health_check_id, + is_enabled = True, + created_at = now, + updated_at = now, ) with self._get_db() as conn: @@ -2152,8 +2152,8 @@ class OpsManager: def get_failover_config(self, config_id: str) -> FailoverConfig | None: """获取故障转移配置""" with self._get_db() as conn: - row = conn.execute( - "SELECT * FROM failover_configs WHERE id = ?", (config_id,) + row = conn.execute( + "SELECT * FROM failover_configs WHERE id = ?", (config_id, ) ).fetchone() if row: @@ -2163,38 +2163,38 @@ class OpsManager: def list_failover_configs(self, tenant_id: str) -> list[FailoverConfig]: """列出租户的故障转移配置""" with self._get_db() as conn: - rows = conn.execute( - "SELECT * FROM failover_configs WHERE tenant_id = ? ORDER BY created_at DESC", - (tenant_id,), + rows = conn.execute( + "SELECT * FROM failover_configs WHERE tenant_id = ? ORDER BY created_at DESC", + (tenant_id, ), ).fetchall() return [self._row_to_failover_config(row) for row in rows] def initiate_failover(self, config_id: str, reason: str) -> FailoverEvent | None: """发起故障转移""" - config = self.get_failover_config(config_id) + config = self.get_failover_config(config_id) if not config or not config.is_enabled: return None - event_id = f"fe_{uuid.uuid4().hex[:16]}" - now = datetime.now().isoformat() + event_id = f"fe_{uuid.uuid4().hex[:16]}" + now = datetime.now().isoformat() # 选择备用区域 - to_region = config.secondary_regions[0] if config.secondary_regions else None + to_region = config.secondary_regions[0] if config.secondary_regions else None if not to_region: return None - event = FailoverEvent( - id=event_id, - config_id=config_id, - tenant_id=config.tenant_id, - from_region=config.primary_region, - to_region=to_region, - reason=reason, - status="initiated", - started_at=now, - completed_at=None, - rolled_back_at=None, + event = FailoverEvent( + id = event_id, + config_id = config_id, + tenant_id = config.tenant_id, + from_region = config.primary_region, + to_region = to_region, + reason = reason, + status = "initiated", + started_at = now, + completed_at = None, + rolled_back_at = None, ) with self._get_db() as conn: @@ -2221,15 +2221,15 @@ class OpsManager: def update_failover_status(self, event_id: str, status: str) -> FailoverEvent | None: """更新故障转移状态""" - now = datetime.now().isoformat() + now = datetime.now().isoformat() with self._get_db() as conn: if status == "completed": conn.execute( """ UPDATE failover_events - SET status = ?, completed_at = ? - WHERE id = ? + SET status = ?, completed_at = ? + WHERE id = ? """, (status, now, event_id), ) @@ -2237,8 +2237,8 @@ class OpsManager: conn.execute( """ UPDATE failover_events - SET status = ?, rolled_back_at = ? - WHERE id = ? + SET status = ?, rolled_back_at = ? + WHERE id = ? """, (status, now, event_id), ) @@ -2246,8 +2246,8 @@ class OpsManager: conn.execute( """ UPDATE failover_events - SET status = ? - WHERE id = ? + SET status = ? + WHERE id = ? """, (status, event_id), ) @@ -2258,18 +2258,18 @@ class OpsManager: def get_failover_event(self, event_id: str) -> FailoverEvent | None: """获取故障转移事件""" with self._get_db() as conn: - row = conn.execute("SELECT * FROM failover_events WHERE id = ?", (event_id,)).fetchone() + row = conn.execute("SELECT * FROM failover_events WHERE id = ?", (event_id, )).fetchone() if row: return self._row_to_failover_event(row) return None - def list_failover_events(self, tenant_id: str, limit: int = 100) -> list[FailoverEvent]: + def list_failover_events(self, tenant_id: str, limit: int = 100) -> list[FailoverEvent]: """列出租户的故障转移事件""" with self._get_db() as conn: - rows = conn.execute( + rows = conn.execute( """SELECT * FROM failover_events - WHERE tenant_id = ? + WHERE tenant_id = ? ORDER BY started_at DESC LIMIT ?""", (tenant_id, limit), ).fetchall() @@ -2285,30 +2285,30 @@ class OpsManager: target_type: str, target_id: str, schedule: str, - retention_days: int = 30, - encryption_enabled: bool = True, - compression_enabled: bool = True, - storage_location: str = None, + retention_days: int = 30, + encryption_enabled: bool = True, + compression_enabled: bool = True, + storage_location: str = None, ) -> BackupJob: """创建备份任务""" - job_id = f"bj_{uuid.uuid4().hex[:16]}" - now = datetime.now().isoformat() + job_id = f"bj_{uuid.uuid4().hex[:16]}" + now = datetime.now().isoformat() - job = BackupJob( - id=job_id, - tenant_id=tenant_id, - name=name, - backup_type=backup_type, - target_type=target_type, - target_id=target_id, - schedule=schedule, - retention_days=retention_days, - encryption_enabled=encryption_enabled, - compression_enabled=compression_enabled, - storage_location=storage_location or f"backups/{tenant_id}", - is_enabled=True, - created_at=now, - updated_at=now, + job = BackupJob( + id = job_id, + tenant_id = tenant_id, + name = name, + backup_type = backup_type, + target_type = target_type, + target_id = target_id, + schedule = schedule, + retention_days = retention_days, + encryption_enabled = encryption_enabled, + compression_enabled = compression_enabled, + storage_location = storage_location or f"backups/{tenant_id}", + is_enabled = True, + created_at = now, + updated_at = now, ) with self._get_db() as conn: @@ -2344,7 +2344,7 @@ class OpsManager: def get_backup_job(self, job_id: str) -> BackupJob | None: """获取备份任务""" with self._get_db() as conn: - row = conn.execute("SELECT * FROM backup_jobs WHERE id = ?", (job_id,)).fetchone() + row = conn.execute("SELECT * FROM backup_jobs WHERE id = ?", (job_id, )).fetchone() if row: return self._row_to_backup_job(row) @@ -2353,33 +2353,33 @@ class OpsManager: def list_backup_jobs(self, tenant_id: str) -> list[BackupJob]: """列出租户的备份任务""" with self._get_db() as conn: - rows = conn.execute( - "SELECT * FROM backup_jobs WHERE tenant_id = ? ORDER BY created_at DESC", - (tenant_id,), + rows = conn.execute( + "SELECT * FROM backup_jobs WHERE tenant_id = ? ORDER BY created_at DESC", + (tenant_id, ), ).fetchall() return [self._row_to_backup_job(row) for row in rows] def execute_backup(self, job_id: str) -> BackupRecord | None: """执行备份""" - job = self.get_backup_job(job_id) + job = self.get_backup_job(job_id) if not job or not job.is_enabled: return None - record_id = f"br_{uuid.uuid4().hex[:16]}" - now = datetime.now().isoformat() + record_id = f"br_{uuid.uuid4().hex[:16]}" + now = datetime.now().isoformat() - record = BackupRecord( - id=record_id, - job_id=job_id, - tenant_id=job.tenant_id, - status=BackupStatus.IN_PROGRESS, - size_bytes=0, - checksum="", - started_at=now, - completed_at=None, - verified_at=None, - error_message=None, - storage_path=f"{job.storage_location}/{record_id}", + record = BackupRecord( + id = record_id, + job_id = job_id, + tenant_id = job.tenant_id, + status = BackupStatus.IN_PROGRESS, + size_bytes = 0, + checksum = "", + started_at = now, + completed_at = None, + verified_at = None, + error_message = None, + storage_path = f"{job.storage_location}/{record_id}", ) with self._get_db() as conn: @@ -2404,21 +2404,21 @@ class OpsManager: # 异步执行备份(实际实现中应该启动后台任务) # 这里模拟备份完成 - self._complete_backup(record_id, size_bytes=1024 * 1024 * 100) # 模拟100MB + self._complete_backup(record_id, size_bytes = 1024 * 1024 * 100) # 模拟100MB return record - def _complete_backup(self, record_id: str, size_bytes: int, checksum: str = None) -> None: + def _complete_backup(self, record_id: str, size_bytes: int, checksum: str = None) -> None: """完成备份""" - now = datetime.now().isoformat() - checksum = checksum or hashlib.sha256(str(time.time()).encode()).hexdigest()[:16] + now = datetime.now().isoformat() + checksum = checksum or hashlib.sha256(str(time.time()).encode()).hexdigest()[:16] with self._get_db() as conn: conn.execute( """ UPDATE backup_records - SET status = ?, size_bytes = ?, checksum = ?, completed_at = ? - WHERE id = ? + SET status = ?, size_bytes = ?, checksum = ?, completed_at = ? + WHERE id = ? """, (BackupStatus.COMPLETED.value, size_bytes, checksum, now, record_id), ) @@ -2427,33 +2427,33 @@ class OpsManager: def get_backup_record(self, record_id: str) -> BackupRecord | None: """获取备份记录""" with self._get_db() as conn: - row = conn.execute("SELECT * FROM backup_records WHERE id = ?", (record_id,)).fetchone() + row = conn.execute("SELECT * FROM backup_records WHERE id = ?", (record_id, )).fetchone() if row: return self._row_to_backup_record(row) return None def list_backup_records( - self, tenant_id: str, job_id: str = None, limit: int = 100 + self, tenant_id: str, job_id: str = None, limit: int = 100 ) -> list[BackupRecord]: """列出租户的备份记录""" - query = "SELECT * FROM backup_records WHERE tenant_id = ?" - params = [tenant_id] + query = "SELECT * FROM backup_records WHERE tenant_id = ?" + params = [tenant_id] if job_id: - query += " AND job_id = ?" + query += " AND job_id = ?" params.append(job_id) query += " ORDER BY started_at DESC LIMIT ?" params.append(limit) with self._get_db() as conn: - rows = conn.execute(query, params).fetchall() + rows = conn.execute(query, params).fetchall() return [self._row_to_backup_record(row) for row in rows] def restore_from_backup(self, record_id: str) -> bool: """从备份恢复""" - record = self.get_backup_record(record_id) + record = self.get_backup_record(record_id) if not record or record.status != BackupStatus.COMPLETED: return False @@ -2465,42 +2465,42 @@ class OpsManager: def generate_cost_report(self, tenant_id: str, year: int, month: int) -> CostReport: """生成成本报告""" - report_id = f"cr_{uuid.uuid4().hex[:16]}" - report_period = f"{year:04d}-{month:02d}" - now = datetime.now().isoformat() + report_id = f"cr_{uuid.uuid4().hex[:16]}" + report_period = f"{year:04d}-{month:02d}" + now = datetime.now().isoformat() # 获取资源利用率数据 - utilizations = self.get_resource_utilizations(tenant_id, report_period) + utilizations = self.get_resource_utilizations(tenant_id, report_period) # 计算成本分解 - breakdown = {} - total_cost = 0.0 + breakdown = {} + total_cost = 0.0 for util in utilizations: # 简化计算:假设每单位资源每月成本 - unit_cost = 10.0 - resource_cost = unit_cost * util.utilization_rate - breakdown[util.resource_type.value] = ( + unit_cost = 10.0 + resource_cost = unit_cost * util.utilization_rate + breakdown[util.resource_type.value] = ( breakdown.get(util.resource_type.value, 0) + resource_cost ) total_cost += resource_cost # 检测异常 - anomalies = self._detect_cost_anomalies(utilizations) + anomalies = self._detect_cost_anomalies(utilizations) # 计算趋势 - trends = self._calculate_cost_trends(tenant_id, year, month) + trends = self._calculate_cost_trends(tenant_id, year, month) - report = CostReport( - id=report_id, - tenant_id=tenant_id, - report_period=report_period, - total_cost=total_cost, - currency="CNY", - breakdown=breakdown, - trends=trends, - anomalies=anomalies, - created_at=now, + report = CostReport( + id = report_id, + tenant_id = tenant_id, + report_period = report_period, + total_cost = total_cost, + currency = "CNY", + breakdown = breakdown, + trends = trends, + anomalies = anomalies, + created_at = now, ) with self._get_db() as conn: @@ -2528,7 +2528,7 @@ class OpsManager: def _detect_cost_anomalies(self, utilizations: list[ResourceUtilization]) -> list[dict]: """检测成本异常""" - anomalies = [] + anomalies = [] for util in utilizations: # 检测低利用率 @@ -2576,22 +2576,22 @@ class OpsManager: avg_utilization: float, idle_time_percent: float, report_date: str, - recommendations: list[str] = None, + recommendations: list[str] = None, ) -> ResourceUtilization: """记录资源利用率""" - util_id = f"ru_{uuid.uuid4().hex[:16]}" + util_id = f"ru_{uuid.uuid4().hex[:16]}" - util = ResourceUtilization( - id=util_id, - tenant_id=tenant_id, - resource_type=resource_type, - resource_id=resource_id, - utilization_rate=utilization_rate, - peak_utilization=peak_utilization, - avg_utilization=avg_utilization, - idle_time_percent=idle_time_percent, - report_date=report_date, - recommendations=recommendations or [], + util = ResourceUtilization( + id = util_id, + tenant_id = tenant_id, + resource_type = resource_type, + resource_id = resource_id, + utilization_rate = utilization_rate, + peak_utilization = peak_utilization, + avg_utilization = avg_utilization, + idle_time_percent = idle_time_percent, + report_date = report_date, + recommendations = recommendations or [], ) with self._get_db() as conn: @@ -2624,9 +2624,9 @@ class OpsManager: ) -> list[ResourceUtilization]: """获取资源利用率列表""" with self._get_db() as conn: - rows = conn.execute( + rows = conn.execute( """SELECT * FROM resource_utilizations - WHERE tenant_id = ? AND report_date LIKE ? + WHERE tenant_id = ? AND report_date LIKE ? ORDER BY report_date DESC""", (tenant_id, f"{report_period}%"), ).fetchall() @@ -2634,37 +2634,37 @@ class OpsManager: def detect_idle_resources(self, tenant_id: str) -> list[IdleResource]: """检测闲置资源""" - idle_resources = [] + idle_resources = [] # 获取最近30天的利用率数据 with self._get_db() as conn: - thirty_days_ago = (datetime.now() - timedelta(days=30)).isoformat() - rows = conn.execute( + thirty_days_ago = (datetime.now() - timedelta(days = 30)).isoformat() + rows = conn.execute( """SELECT resource_type, resource_id, AVG(utilization_rate) as avg_utilization, MAX(idle_time_percent) as max_idle_time FROM resource_utilizations - WHERE tenant_id = ? AND report_date > ? + WHERE tenant_id = ? AND report_date > ? GROUP BY resource_type, resource_id HAVING avg_utilization < 0.1 AND max_idle_time > 0.8""", (tenant_id, thirty_days_ago), ).fetchall() for row in rows: - idle_id = f"ir_{uuid.uuid4().hex[:16]}" - now = datetime.now().isoformat() + idle_id = f"ir_{uuid.uuid4().hex[:16]}" + now = datetime.now().isoformat() - idle_resource = IdleResource( - id=idle_id, - tenant_id=tenant_id, - resource_type=ResourceType(row["resource_type"]), - resource_id=row["resource_id"], - resource_name=f"{row['resource_type']}-{row['resource_id']}", - idle_since=thirty_days_ago, - estimated_monthly_cost=50.0, # 简化计算 - currency="CNY", - reason="Low utilization rate over 30 days", - recommendation="Consider downsizing or terminating this resource", - detected_at=now, + idle_resource = IdleResource( + id = idle_id, + tenant_id = tenant_id, + resource_type = ResourceType(row["resource_type"]), + resource_id = row["resource_id"], + resource_name = f"{row['resource_type']}-{row['resource_id']}", + idle_since = thirty_days_ago, + estimated_monthly_cost = 50.0, # 简化计算 + currency = "CNY", + reason = "Low utilization rate over 30 days", + recommendation = "Consider downsizing or terminating this resource", + detected_at = now, ) conn.execute( @@ -2698,9 +2698,9 @@ class OpsManager: def get_idle_resources(self, tenant_id: str) -> list[IdleResource]: """获取闲置资源列表""" with self._get_db() as conn: - rows = conn.execute( - "SELECT * FROM idle_resources WHERE tenant_id = ? ORDER BY detected_at DESC", - (tenant_id,), + rows = conn.execute( + "SELECT * FROM idle_resources WHERE tenant_id = ? ORDER BY detected_at DESC", + (tenant_id, ), ).fetchall() return [self._row_to_idle_resource(row) for row in rows] @@ -2708,36 +2708,36 @@ class OpsManager: self, tenant_id: str ) -> list[CostOptimizationSuggestion]: """生成成本优化建议""" - suggestions = [] + suggestions = [] # 基于闲置资源生成建议 - idle_resources = self.detect_idle_resources(tenant_id) + idle_resources = self.detect_idle_resources(tenant_id) - total_potential_savings = sum(r.estimated_monthly_cost for r in idle_resources) + total_potential_savings = sum(r.estimated_monthly_cost for r in idle_resources) if total_potential_savings > 0: - suggestion_id = f"cos_{uuid.uuid4().hex[:16]}" - now = datetime.now().isoformat() + suggestion_id = f"cos_{uuid.uuid4().hex[:16]}" + now = datetime.now().isoformat() - suggestion = CostOptimizationSuggestion( - id=suggestion_id, - tenant_id=tenant_id, - category="resource_rightsize", - title="清理闲置资源", - description=f"检测到 {len(idle_resources)} 个闲置资源,建议清理以节省成本。", - potential_savings=total_potential_savings, - currency="CNY", - confidence=0.85, - difficulty="easy", - implementation_steps=[ + suggestion = CostOptimizationSuggestion( + id = suggestion_id, + tenant_id = tenant_id, + category = "resource_rightsize", + title = "清理闲置资源", + description = f"检测到 {len(idle_resources)} 个闲置资源,建议清理以节省成本。", + potential_savings = total_potential_savings, + currency = "CNY", + confidence = 0.85, + difficulty = "easy", + implementation_steps = [ "Review the list of idle resources", "Confirm resources are no longer needed", "Terminate or downsize unused resources", ], - risk_level="low", - is_applied=False, - created_at=now, - applied_at=None, + risk_level = "low", + is_applied = False, + created_at = now, + applied_at = None, ) with self._get_db() as conn: @@ -2773,34 +2773,34 @@ class OpsManager: return suggestions def get_cost_optimization_suggestions( - self, tenant_id: str, is_applied: bool = None + self, tenant_id: str, is_applied: bool = None ) -> list[CostOptimizationSuggestion]: """获取成本优化建议""" - query = "SELECT * FROM cost_optimization_suggestions WHERE tenant_id = ?" - params = [tenant_id] + query = "SELECT * FROM cost_optimization_suggestions WHERE tenant_id = ?" + params = [tenant_id] if is_applied is not None: - query += " AND is_applied = ?" + query += " AND is_applied = ?" params.append(1 if is_applied else 0) query += " ORDER BY potential_savings DESC" with self._get_db() as conn: - rows = conn.execute(query, params).fetchall() + rows = conn.execute(query, params).fetchall() return [self._row_to_cost_optimization_suggestion(row) for row in rows] def apply_cost_optimization_suggestion( self, suggestion_id: str ) -> CostOptimizationSuggestion | None: """应用成本优化建议""" - now = datetime.now().isoformat() + now = datetime.now().isoformat() with self._get_db() as conn: conn.execute( """ UPDATE cost_optimization_suggestions - SET is_applied = ?, applied_at = ? - WHERE id = ? + SET is_applied = ?, applied_at = ? + WHERE id = ? """, (True, now, suggestion_id), ) @@ -2813,8 +2813,8 @@ class OpsManager: ) -> CostOptimizationSuggestion | None: """获取成本优化建议详情""" with self._get_db() as conn: - row = conn.execute( - "SELECT * FROM cost_optimization_suggestions WHERE id = ?", (suggestion_id,) + row = conn.execute( + "SELECT * FROM cost_optimization_suggestions WHERE id = ?", (suggestion_id, ) ).fetchone() if row: @@ -2825,286 +2825,286 @@ class OpsManager: def _row_to_alert_rule(self, row) -> AlertRule: return AlertRule( - id=row["id"], - tenant_id=row["tenant_id"], - name=row["name"], - description=row["description"], - rule_type=AlertRuleType(row["rule_type"]), - severity=AlertSeverity(row["severity"]), - metric=row["metric"], - condition=row["condition"], - threshold=row["threshold"], - duration=row["duration"], - evaluation_interval=row["evaluation_interval"], - channels=json.loads(row["channels"]), - labels=json.loads(row["labels"]), - annotations=json.loads(row["annotations"]), - is_enabled=bool(row["is_enabled"]), - created_at=row["created_at"], - updated_at=row["updated_at"], - created_by=row["created_by"], + id = row["id"], + tenant_id = row["tenant_id"], + name = row["name"], + description = row["description"], + rule_type = AlertRuleType(row["rule_type"]), + severity = AlertSeverity(row["severity"]), + metric = row["metric"], + condition = row["condition"], + threshold = row["threshold"], + duration = row["duration"], + evaluation_interval = row["evaluation_interval"], + channels = json.loads(row["channels"]), + labels = json.loads(row["labels"]), + annotations = json.loads(row["annotations"]), + is_enabled = bool(row["is_enabled"]), + created_at = row["created_at"], + updated_at = row["updated_at"], + created_by = row["created_by"], ) def _row_to_alert_channel(self, row) -> AlertChannel: return AlertChannel( - id=row["id"], - tenant_id=row["tenant_id"], - name=row["name"], - channel_type=AlertChannelType(row["channel_type"]), - config=json.loads(row["config"]), - severity_filter=json.loads(row["severity_filter"]), - is_enabled=bool(row["is_enabled"]), - success_count=row["success_count"], - fail_count=row["fail_count"], - last_used_at=row["last_used_at"], - created_at=row["created_at"], - updated_at=row["updated_at"], + id = row["id"], + tenant_id = row["tenant_id"], + name = row["name"], + channel_type = AlertChannelType(row["channel_type"]), + config = json.loads(row["config"]), + severity_filter = json.loads(row["severity_filter"]), + is_enabled = bool(row["is_enabled"]), + success_count = row["success_count"], + fail_count = row["fail_count"], + last_used_at = row["last_used_at"], + created_at = row["created_at"], + updated_at = row["updated_at"], ) def _row_to_alert(self, row) -> Alert: return Alert( - id=row["id"], - rule_id=row["rule_id"], - tenant_id=row["tenant_id"], - severity=AlertSeverity(row["severity"]), - status=AlertStatus(row["status"]), - title=row["title"], - description=row["description"], - metric=row["metric"], - value=row["value"], - threshold=row["threshold"], - labels=json.loads(row["labels"]), - annotations=json.loads(row["annotations"]), - started_at=row["started_at"], - resolved_at=row["resolved_at"], - acknowledged_by=row["acknowledged_by"], - acknowledged_at=row["acknowledged_at"], - notification_sent=json.loads(row["notification_sent"]), - suppression_count=row["suppression_count"], + id = row["id"], + rule_id = row["rule_id"], + tenant_id = row["tenant_id"], + severity = AlertSeverity(row["severity"]), + status = AlertStatus(row["status"]), + title = row["title"], + description = row["description"], + metric = row["metric"], + value = row["value"], + threshold = row["threshold"], + labels = json.loads(row["labels"]), + annotations = json.loads(row["annotations"]), + started_at = row["started_at"], + resolved_at = row["resolved_at"], + acknowledged_by = row["acknowledged_by"], + acknowledged_at = row["acknowledged_at"], + notification_sent = json.loads(row["notification_sent"]), + suppression_count = row["suppression_count"], ) def _row_to_suppression_rule(self, row) -> AlertSuppressionRule: return AlertSuppressionRule( - id=row["id"], - tenant_id=row["tenant_id"], - name=row["name"], - matchers=json.loads(row["matchers"]), - duration=row["duration"], - is_regex=bool(row["is_regex"]), - created_at=row["created_at"], - expires_at=row["expires_at"], + id = row["id"], + tenant_id = row["tenant_id"], + name = row["name"], + matchers = json.loads(row["matchers"]), + duration = row["duration"], + is_regex = bool(row["is_regex"]), + created_at = row["created_at"], + expires_at = row["expires_at"], ) def _row_to_resource_metric(self, row) -> ResourceMetric: return ResourceMetric( - id=row["id"], - tenant_id=row["tenant_id"], - resource_type=ResourceType(row["resource_type"]), - resource_id=row["resource_id"], - metric_name=row["metric_name"], - metric_value=row["metric_value"], - unit=row["unit"], - timestamp=row["timestamp"], - metadata=json.loads(row["metadata"]), + id = row["id"], + tenant_id = row["tenant_id"], + resource_type = ResourceType(row["resource_type"]), + resource_id = row["resource_id"], + metric_name = row["metric_name"], + metric_value = row["metric_value"], + unit = row["unit"], + timestamp = row["timestamp"], + metadata = json.loads(row["metadata"]), ) def _row_to_capacity_plan(self, row) -> CapacityPlan: return CapacityPlan( - id=row["id"], - tenant_id=row["tenant_id"], - resource_type=ResourceType(row["resource_type"]), - current_capacity=row["current_capacity"], - predicted_capacity=row["predicted_capacity"], - prediction_date=row["prediction_date"], - confidence=row["confidence"], - recommended_action=row["recommended_action"], - estimated_cost=row["estimated_cost"], - created_at=row["created_at"], + id = row["id"], + tenant_id = row["tenant_id"], + resource_type = ResourceType(row["resource_type"]), + current_capacity = row["current_capacity"], + predicted_capacity = row["predicted_capacity"], + prediction_date = row["prediction_date"], + confidence = row["confidence"], + recommended_action = row["recommended_action"], + estimated_cost = row["estimated_cost"], + created_at = row["created_at"], ) def _row_to_auto_scaling_policy(self, row) -> AutoScalingPolicy: return AutoScalingPolicy( - id=row["id"], - tenant_id=row["tenant_id"], - name=row["name"], - resource_type=ResourceType(row["resource_type"]), - min_instances=row["min_instances"], - max_instances=row["max_instances"], - target_utilization=row["target_utilization"], - scale_up_threshold=row["scale_up_threshold"], - scale_down_threshold=row["scale_down_threshold"], - scale_up_step=row["scale_up_step"], - scale_down_step=row["scale_down_step"], - cooldown_period=row["cooldown_period"], - is_enabled=bool(row["is_enabled"]), - created_at=row["created_at"], - updated_at=row["updated_at"], + id = row["id"], + tenant_id = row["tenant_id"], + name = row["name"], + resource_type = ResourceType(row["resource_type"]), + min_instances = row["min_instances"], + max_instances = row["max_instances"], + target_utilization = row["target_utilization"], + scale_up_threshold = row["scale_up_threshold"], + scale_down_threshold = row["scale_down_threshold"], + scale_up_step = row["scale_up_step"], + scale_down_step = row["scale_down_step"], + cooldown_period = row["cooldown_period"], + is_enabled = bool(row["is_enabled"]), + created_at = row["created_at"], + updated_at = row["updated_at"], ) def _row_to_scaling_event(self, row) -> ScalingEvent: return ScalingEvent( - id=row["id"], - policy_id=row["policy_id"], - tenant_id=row["tenant_id"], - action=ScalingAction(row["action"]), - from_count=row["from_count"], - to_count=row["to_count"], - reason=row["reason"], - triggered_by=row["triggered_by"], - status=row["status"], - started_at=row["started_at"], - completed_at=row["completed_at"], - error_message=row["error_message"], + id = row["id"], + policy_id = row["policy_id"], + tenant_id = row["tenant_id"], + action = ScalingAction(row["action"]), + from_count = row["from_count"], + to_count = row["to_count"], + reason = row["reason"], + triggered_by = row["triggered_by"], + status = row["status"], + started_at = row["started_at"], + completed_at = row["completed_at"], + error_message = row["error_message"], ) def _row_to_health_check(self, row) -> HealthCheck: return HealthCheck( - id=row["id"], - tenant_id=row["tenant_id"], - name=row["name"], - target_type=row["target_type"], - target_id=row["target_id"], - check_type=row["check_type"], - check_config=json.loads(row["check_config"]), - interval=row["interval"], - timeout=row["timeout"], - retry_count=row["retry_count"], - healthy_threshold=row["healthy_threshold"], - unhealthy_threshold=row["unhealthy_threshold"], - is_enabled=bool(row["is_enabled"]), - created_at=row["created_at"], - updated_at=row["updated_at"], + id = row["id"], + tenant_id = row["tenant_id"], + name = row["name"], + target_type = row["target_type"], + target_id = row["target_id"], + check_type = row["check_type"], + check_config = json.loads(row["check_config"]), + interval = row["interval"], + timeout = row["timeout"], + retry_count = row["retry_count"], + healthy_threshold = row["healthy_threshold"], + unhealthy_threshold = row["unhealthy_threshold"], + is_enabled = bool(row["is_enabled"]), + created_at = row["created_at"], + updated_at = row["updated_at"], ) def _row_to_health_check_result(self, row) -> HealthCheckResult: return HealthCheckResult( - id=row["id"], - check_id=row["check_id"], - tenant_id=row["tenant_id"], - status=HealthStatus(row["status"]), - response_time=row["response_time"], - message=row["message"], - details=json.loads(row["details"]), - checked_at=row["checked_at"], + id = row["id"], + check_id = row["check_id"], + tenant_id = row["tenant_id"], + status = HealthStatus(row["status"]), + response_time = row["response_time"], + message = row["message"], + details = json.loads(row["details"]), + checked_at = row["checked_at"], ) def _row_to_failover_config(self, row) -> FailoverConfig: return FailoverConfig( - id=row["id"], - tenant_id=row["tenant_id"], - name=row["name"], - primary_region=row["primary_region"], - secondary_regions=json.loads(row["secondary_regions"]), - failover_trigger=row["failover_trigger"], - auto_failover=bool(row["auto_failover"]), - failover_timeout=row["failover_timeout"], - health_check_id=row["health_check_id"], - is_enabled=bool(row["is_enabled"]), - created_at=row["created_at"], - updated_at=row["updated_at"], + id = row["id"], + tenant_id = row["tenant_id"], + name = row["name"], + primary_region = row["primary_region"], + secondary_regions = json.loads(row["secondary_regions"]), + failover_trigger = row["failover_trigger"], + auto_failover = bool(row["auto_failover"]), + failover_timeout = row["failover_timeout"], + health_check_id = row["health_check_id"], + is_enabled = bool(row["is_enabled"]), + created_at = row["created_at"], + updated_at = row["updated_at"], ) def _row_to_failover_event(self, row) -> FailoverEvent: return FailoverEvent( - id=row["id"], - config_id=row["config_id"], - tenant_id=row["tenant_id"], - from_region=row["from_region"], - to_region=row["to_region"], - reason=row["reason"], - status=row["status"], - started_at=row["started_at"], - completed_at=row["completed_at"], - rolled_back_at=row["rolled_back_at"], + id = row["id"], + config_id = row["config_id"], + tenant_id = row["tenant_id"], + from_region = row["from_region"], + to_region = row["to_region"], + reason = row["reason"], + status = row["status"], + started_at = row["started_at"], + completed_at = row["completed_at"], + rolled_back_at = row["rolled_back_at"], ) def _row_to_backup_job(self, row) -> BackupJob: return BackupJob( - id=row["id"], - tenant_id=row["tenant_id"], - name=row["name"], - backup_type=row["backup_type"], - target_type=row["target_type"], - target_id=row["target_id"], - schedule=row["schedule"], - retention_days=row["retention_days"], - encryption_enabled=bool(row["encryption_enabled"]), - compression_enabled=bool(row["compression_enabled"]), - storage_location=row["storage_location"], - is_enabled=bool(row["is_enabled"]), - created_at=row["created_at"], - updated_at=row["updated_at"], + id = row["id"], + tenant_id = row["tenant_id"], + name = row["name"], + backup_type = row["backup_type"], + target_type = row["target_type"], + target_id = row["target_id"], + schedule = row["schedule"], + retention_days = row["retention_days"], + encryption_enabled = bool(row["encryption_enabled"]), + compression_enabled = bool(row["compression_enabled"]), + storage_location = row["storage_location"], + is_enabled = bool(row["is_enabled"]), + created_at = row["created_at"], + updated_at = row["updated_at"], ) def _row_to_backup_record(self, row) -> BackupRecord: return BackupRecord( - id=row["id"], - job_id=row["job_id"], - tenant_id=row["tenant_id"], - status=BackupStatus(row["status"]), - size_bytes=row["size_bytes"], - checksum=row["checksum"], - started_at=row["started_at"], - completed_at=row["completed_at"], - verified_at=row["verified_at"], - error_message=row["error_message"], - storage_path=row["storage_path"], + id = row["id"], + job_id = row["job_id"], + tenant_id = row["tenant_id"], + status = BackupStatus(row["status"]), + size_bytes = row["size_bytes"], + checksum = row["checksum"], + started_at = row["started_at"], + completed_at = row["completed_at"], + verified_at = row["verified_at"], + error_message = row["error_message"], + storage_path = row["storage_path"], ) def _row_to_resource_utilization(self, row) -> ResourceUtilization: return ResourceUtilization( - id=row["id"], - tenant_id=row["tenant_id"], - resource_type=ResourceType(row["resource_type"]), - resource_id=row["resource_id"], - utilization_rate=row["utilization_rate"], - peak_utilization=row["peak_utilization"], - avg_utilization=row["avg_utilization"], - idle_time_percent=row["idle_time_percent"], - report_date=row["report_date"], - recommendations=json.loads(row["recommendations"]), + id = row["id"], + tenant_id = row["tenant_id"], + resource_type = ResourceType(row["resource_type"]), + resource_id = row["resource_id"], + utilization_rate = row["utilization_rate"], + peak_utilization = row["peak_utilization"], + avg_utilization = row["avg_utilization"], + idle_time_percent = row["idle_time_percent"], + report_date = row["report_date"], + recommendations = json.loads(row["recommendations"]), ) def _row_to_idle_resource(self, row) -> IdleResource: return IdleResource( - id=row["id"], - tenant_id=row["tenant_id"], - resource_type=ResourceType(row["resource_type"]), - resource_id=row["resource_id"], - resource_name=row["resource_name"], - idle_since=row["idle_since"], - estimated_monthly_cost=row["estimated_monthly_cost"], - currency=row["currency"], - reason=row["reason"], - recommendation=row["recommendation"], - detected_at=row["detected_at"], + id = row["id"], + tenant_id = row["tenant_id"], + resource_type = ResourceType(row["resource_type"]), + resource_id = row["resource_id"], + resource_name = row["resource_name"], + idle_since = row["idle_since"], + estimated_monthly_cost = row["estimated_monthly_cost"], + currency = row["currency"], + reason = row["reason"], + recommendation = row["recommendation"], + detected_at = row["detected_at"], ) def _row_to_cost_optimization_suggestion(self, row) -> CostOptimizationSuggestion: return CostOptimizationSuggestion( - id=row["id"], - tenant_id=row["tenant_id"], - category=row["category"], - title=row["title"], - description=row["description"], - potential_savings=row["potential_savings"], - currency=row["currency"], - confidence=row["confidence"], - difficulty=row["difficulty"], - implementation_steps=json.loads(row["implementation_steps"]), - risk_level=row["risk_level"], - is_applied=bool(row["is_applied"]), - created_at=row["created_at"], - applied_at=row["applied_at"], + id = row["id"], + tenant_id = row["tenant_id"], + category = row["category"], + title = row["title"], + description = row["description"], + potential_savings = row["potential_savings"], + currency = row["currency"], + confidence = row["confidence"], + difficulty = row["difficulty"], + implementation_steps = json.loads(row["implementation_steps"]), + risk_level = row["risk_level"], + is_applied = bool(row["is_applied"]), + created_at = row["created_at"], + applied_at = row["applied_at"], ) # Singleton instance -_ops_manager = None +_ops_manager = None def get_ops_manager() -> OpsManager: global _ops_manager if _ops_manager is None: - _ops_manager = OpsManager() + _ops_manager = OpsManager() return _ops_manager diff --git a/backend/oss_uploader.py b/backend/oss_uploader.py index 83de463..edbbf7d 100644 --- a/backend/oss_uploader.py +++ b/backend/oss_uploader.py @@ -11,30 +11,30 @@ import oss2 class OSSUploader: - def __init__(self): - self.access_key = os.getenv("ALI_ACCESS_KEY") - self.secret_key = os.getenv("ALI_SECRET_KEY") - self.bucket_name = os.getenv("OSS_BUCKET", "insightflow-audio") - self.region = os.getenv("OSS_REGION", "oss-cn-hangzhou.aliyuncs.com") - self.endpoint = f"https://{self.region}" + def __init__(self) -> None: + self.access_key = os.getenv("ALI_ACCESS_KEY") + self.secret_key = os.getenv("ALI_SECRET_KEY") + self.bucket_name = os.getenv("OSS_BUCKET", "insightflow-audio") + self.region = os.getenv("OSS_REGION", "oss-cn-hangzhou.aliyuncs.com") + self.endpoint = f"https://{self.region}" if not self.access_key or not self.secret_key: raise ValueError("ALI_ACCESS_KEY and ALI_SECRET_KEY must be set") - self.auth = oss2.Auth(self.access_key, self.secret_key) - self.bucket = oss2.Bucket(self.auth, self.endpoint, self.bucket_name) + self.auth = oss2.Auth(self.access_key, self.secret_key) + self.bucket = oss2.Bucket(self.auth, self.endpoint, self.bucket_name) def upload_audio(self, audio_data: bytes, filename: str) -> tuple: """上传音频到 OSS,返回 (URL, object_name)""" # 生成唯一文件名 - ext = os.path.splitext(filename)[1] or ".wav" - object_name = f"audio/{datetime.now().strftime('%Y%m%d')}/{uuid.uuid4().hex}{ext}" + ext = os.path.splitext(filename)[1] or ".wav" + object_name = f"audio/{datetime.now().strftime('%Y%m%d')}/{uuid.uuid4().hex}{ext}" # 上传文件 self.bucket.put_object(object_name, audio_data) # 生成临时访问 URL (1小时有效) - url = self.bucket.sign_url("GET", object_name, 3600) + url = self.bucket.sign_url("GET", object_name, 3600) return url, object_name def delete_object(self, object_name: str) -> None: @@ -43,11 +43,11 @@ class OSSUploader: # 单例 -_oss_uploader = None +_oss_uploader = None def get_oss_uploader() -> OSSUploader: global _oss_uploader if _oss_uploader is None: - _oss_uploader = OSSUploader() + _oss_uploader = OSSUploader() return _oss_uploader diff --git a/backend/performance_manager.py b/backend/performance_manager.py index 0b98a8e..88a43d2 100644 --- a/backend/performance_manager.py +++ b/backend/performance_manager.py @@ -27,18 +27,18 @@ from typing import Any try: import redis - REDIS_AVAILABLE = True + REDIS_AVAILABLE = True except ImportError: - REDIS_AVAILABLE = False + REDIS_AVAILABLE = False # 尝试导入 Celery try: from celery import Celery from celery.result import AsyncResult - CELERY_AVAILABLE = True + CELERY_AVAILABLE = True except ImportError: - CELERY_AVAILABLE = False + CELERY_AVAILABLE = False # ==================== 数据模型 ==================== @@ -47,17 +47,17 @@ except ImportError: class CacheStats: """缓存统计数据模型""" - total_requests: int = 0 - hits: int = 0 - misses: int = 0 - evictions: int = 0 - expired: int = 0 - hit_rate: float = 0.0 + total_requests: int = 0 + hits: int = 0 + misses: int = 0 + evictions: int = 0 + expired: int = 0 + hit_rate: float = 0.0 def update_hit_rate(self) -> None: """更新命中率""" if self.total_requests > 0: - self.hit_rate = round(self.hits / self.total_requests, 4) + self.hit_rate = round(self.hits / self.total_requests, 4) @dataclass @@ -68,9 +68,9 @@ class CacheEntry: value: Any created_at: float expires_at: float | None - access_count: int = 0 - last_accessed: float = 0 - size_bytes: int = 0 + access_count: int = 0 + last_accessed: float = 0 + size_bytes: int = 0 @dataclass @@ -82,7 +82,7 @@ class PerformanceMetric: endpoint: str | None duration_ms: float timestamp: str - metadata: dict = field(default_factory=dict) + metadata: dict = field(default_factory = dict) def to_dict(self) -> dict: return { @@ -104,12 +104,12 @@ class TaskInfo: status: str # pending, running, success, failed, retrying payload: dict created_at: str - started_at: str | None = None - completed_at: str | None = None - result: Any | None = None - error_message: str | None = None - retry_count: int = 0 - max_retries: int = 3 + started_at: str | None = None + completed_at: str | None = None + result: Any | None = None + error_message: str | None = None + retry_count: int = 0 + max_retries: int = 3 def to_dict(self) -> dict: return { @@ -134,10 +134,10 @@ class ShardInfo: shard_id: str shard_key_range: tuple[str, str] # (start, end) db_path: str - entity_count: int = 0 - is_active: bool = True - created_at: str = "" - last_accessed: str = "" + entity_count: int = 0 + is_active: bool = True + created_at: str = "" + last_accessed: str = "" # ==================== Redis 缓存层 ==================== @@ -160,42 +160,42 @@ class CacheManager: def __init__( self, - redis_url: str | None = None, - max_memory_size: int = 100 * 1024 * 1024, # 100MB - default_ttl: int = 3600, # 1小时 - db_path: str = "insightflow.db", - ): - self.db_path = db_path - self.default_ttl = default_ttl - self.max_memory_size = max_memory_size - self.current_memory_size = 0 + redis_url: str | None = None, + max_memory_size: int = 100 * 1024 * 1024, # 100MB + default_ttl: int = 3600, # 1小时 + db_path: str = "insightflow.db", + ) -> None: + self.db_path = db_path + self.default_ttl = default_ttl + self.max_memory_size = max_memory_size + self.current_memory_size = 0 # Redis 客户端 - self.redis_client = None - self.use_redis = False + self.redis_client = None + self.use_redis = False if REDIS_AVAILABLE and redis_url: try: - self.redis_client = redis.from_url(redis_url, decode_responses=True) + self.redis_client = redis.from_url(redis_url, decode_responses = True) self.redis_client.ping() - self.use_redis = True + self.use_redis = True print(f"Redis 缓存已连接: {redis_url}") except Exception as e: print(f"Redis 连接失败,使用内存缓存: {e}") # 内存缓存(LRU) - self.memory_cache: OrderedDict[str, CacheEntry] = OrderedDict() - self.cache_lock = threading.RLock() + self.memory_cache: OrderedDict[str, CacheEntry] = OrderedDict() + self.cache_lock = threading.RLock() # 统计 - self.stats = CacheStats() + self.stats = CacheStats() # 初始化缓存统计表 self._init_cache_tables() def _init_cache_tables(self) -> None: """初始化缓存统计表""" - conn = sqlite3.connect(self.db_path) + conn = sqlite3.connect(self.db_path) conn.execute(""" CREATE TABLE IF NOT EXISTS cache_stats ( @@ -233,11 +233,11 @@ class CacheManager: def _get_entry_size(self, value: Any) -> int: """估算缓存条目大小""" try: - return len(json.dumps(value, ensure_ascii=False).encode("utf-8")) + return len(json.dumps(value, ensure_ascii = False).encode("utf-8")) except (TypeError, ValueError): return 1024 # 默认估算 - def _evict_lru(self, required_space: int = 0) -> None: + def _evict_lru(self, required_space: int = 0) -> None: """LRU 淘汰策略""" with self.cache_lock: while ( @@ -245,7 +245,7 @@ class CacheManager: and self.memory_cache ): # 移除最久未访问的 - oldest_key, oldest_entry = self.memory_cache.popitem(last=False) + oldest_key, oldest_entry = self.memory_cache.popitem(last = False) self.current_memory_size -= oldest_entry.size_bytes self.stats.evictions += 1 @@ -263,7 +263,7 @@ class CacheManager: if self.use_redis: try: - value = self.redis_client.get(key) + value = self.redis_client.get(key) if value: self.stats.hits += 1 return json.loads(value) @@ -276,7 +276,7 @@ class CacheManager: else: # 内存缓存 with self.cache_lock: - entry = self.memory_cache.get(key) + entry = self.memory_cache.get(key) if entry: # 检查是否过期 @@ -289,7 +289,7 @@ class CacheManager: # 更新访问信息 entry.access_count += 1 - entry.last_accessed = time.time() + entry.last_accessed = time.time() self.memory_cache.move_to_end(key) self.stats.hits += 1 @@ -298,7 +298,7 @@ class CacheManager: self.stats.misses += 1 return None - def set(self, key: str, value: Any, ttl: int | None = None) -> bool: + def set(self, key: str, value: Any, ttl: int | None = None) -> bool: """ 设置缓存值 @@ -310,11 +310,11 @@ class CacheManager: Returns: bool: 是否成功 """ - ttl = ttl or self.default_ttl + ttl = ttl or self.default_ttl if self.use_redis: try: - serialized = json.dumps(value, ensure_ascii=False) + serialized = json.dumps(value, ensure_ascii = False) self.redis_client.setex(key, ttl, serialized) return True except Exception as e: @@ -323,27 +323,27 @@ class CacheManager: else: # 内存缓存 with self.cache_lock: - size = self._get_entry_size(value) + size = self._get_entry_size(value) # 检查是否需要淘汰 if self.current_memory_size + size > self.max_memory_size: self._evict_lru(size) - now = time.time() - entry = CacheEntry( - key=key, - value=value, - created_at=now, - expires_at=now + ttl if ttl > 0 else None, - size_bytes=size, - last_accessed=now, + now = time.time() + entry = CacheEntry( + key = key, + value = value, + created_at = now, + expires_at = now + ttl if ttl > 0 else None, + size_bytes = size, + last_accessed = now, ) # 如果已存在,更新大小 if key in self.memory_cache: self.current_memory_size -= self.memory_cache[key].size_bytes - self.memory_cache[key] = entry + self.memory_cache[key] = entry self.memory_cache.move_to_end(key) self.current_memory_size += size @@ -360,7 +360,7 @@ class CacheManager: else: with self.cache_lock: if key in self.memory_cache: - entry = self.memory_cache.pop(key) + entry = self.memory_cache.pop(key) self.current_memory_size -= entry.size_bytes return True return False @@ -377,19 +377,19 @@ class CacheManager: else: with self.cache_lock: self.memory_cache.clear() - self.current_memory_size = 0 + self.current_memory_size = 0 return True def get_many(self, keys: list[str]) -> dict[str, Any]: """批量获取缓存""" - results = {} + results = {} if self.use_redis: try: - values = self.redis_client.mget(keys) + values = self.redis_client.mget(keys) for key, value in zip(keys, values): if value: - results[key] = json.loads(value) + results[key] = json.loads(value) self.stats.hits += 1 else: self.stats.misses += 1 @@ -398,21 +398,21 @@ class CacheManager: print(f"Redis mget 失败: {e}") else: for key in keys: - value = self.get(key) + value = self.get(key) if value is not None: - results[key] = value + results[key] = value return results - def set_many(self, mapping: dict[str, Any], ttl: int | None = None) -> bool: + def set_many(self, mapping: dict[str, Any], ttl: int | None = None) -> bool: """批量设置缓存""" - ttl = ttl or self.default_ttl + ttl = ttl or self.default_ttl if self.use_redis: try: - pipe = self.redis_client.pipeline() + pipe = self.redis_client.pipeline() for key, value in mapping.items(): - serialized = json.dumps(value, ensure_ascii=False) + serialized = json.dumps(value, ensure_ascii = False) pipe.setex(key, ttl, serialized) pipe.execute() return True @@ -428,7 +428,7 @@ class CacheManager: """获取缓存统计""" self.stats.update_hit_rate() - stats = { + stats = { "total_requests": self.stats.total_requests, "hits": self.stats.hits, "misses": self.stats.misses, @@ -454,7 +454,7 @@ class CacheManager: def save_stats(self) -> None: """保存缓存统计到数据库""" - conn = sqlite3.connect(self.db_path) + conn = sqlite3.connect(self.db_path) self.stats.update_hit_rate() @@ -487,81 +487,81 @@ class CacheManager: Returns: Dict: 预热统计 """ - conn = sqlite3.connect(self.db_path) - conn.row_factory = sqlite3.Row + conn = sqlite3.connect(self.db_path) + conn.row_factory = sqlite3.Row - stats = {"entities": 0, "relations": 0, "transcripts": 0} + stats = {"entities": 0, "relations": 0, "transcripts": 0} # 预热实体数据 - entities = conn.execute( + entities = conn.execute( """SELECT e.*, - (SELECT COUNT(*) FROM entity_mentions m WHERE m.entity_id = e.id) as mention_count + (SELECT COUNT(*) FROM entity_mentions m WHERE m.entity_id = e.id) as mention_count FROM entities e - WHERE e.project_id = ? + WHERE e.project_id = ? ORDER BY mention_count DESC LIMIT 100""", - (project_id,), + (project_id, ), ).fetchall() for entity in entities: - key = f"entity:{entity['id']}" - self.set(key, dict(entity), ttl=7200) # 2小时 + key = f"entity:{entity['id']}" + self.set(key, dict(entity), ttl = 7200) # 2小时 stats["entities"] += 1 # 预热关系数据 - relations = conn.execute( + relations = conn.execute( """SELECT r.*, e1.name as source_name, e2.name as target_name FROM entity_relations r - JOIN entities e1 ON r.source_entity_id = e1.id - JOIN entities e2 ON r.target_entity_id = e2.id - WHERE r.project_id = ? + JOIN entities e1 ON r.source_entity_id = e1.id + JOIN entities e2 ON r.target_entity_id = e2.id + WHERE r.project_id = ? LIMIT 200""", - (project_id,), + (project_id, ), ).fetchall() for relation in relations: - key = f"relation:{relation['id']}" - self.set(key, dict(relation), ttl=3600) + key = f"relation:{relation['id']}" + self.set(key, dict(relation), ttl = 3600) stats["relations"] += 1 # 预热最近的转录 - transcripts = conn.execute( + transcripts = conn.execute( """SELECT * FROM transcripts - WHERE project_id = ? + WHERE project_id = ? ORDER BY created_at DESC LIMIT 10""", - (project_id,), + (project_id, ), ).fetchall() for transcript in transcripts: - key = f"transcript:{transcript['id']}" + key = f"transcript:{transcript['id']}" # 只缓存元数据,不缓存完整文本 - meta = { + meta = { "id": transcript["id"], "filename": transcript["filename"], "type": transcript.get("type", "audio"), "created_at": transcript["created_at"], } - self.set(key, meta, ttl=1800) # 30分钟 + self.set(key, meta, ttl = 1800) # 30分钟 stats["transcripts"] += 1 # 预热项目知识库摘要 - entity_count = conn.execute( - "SELECT COUNT(*) FROM entities WHERE project_id = ?", (project_id,) + entity_count = conn.execute( + "SELECT COUNT(*) FROM entities WHERE project_id = ?", (project_id, ) ).fetchone()[0] - relation_count = conn.execute( - "SELECT COUNT(*) FROM entity_relations WHERE project_id = ?", (project_id,) + relation_count = conn.execute( + "SELECT COUNT(*) FROM entity_relations WHERE project_id = ?", (project_id, ) ).fetchone()[0] - summary = { + summary = { "project_id": project_id, "entity_count": entity_count, "relation_count": relation_count, "cached_at": datetime.now().isoformat(), } - self.set(f"project_summary:{project_id}", summary, ttl=3600) + self.set(f"project_summary:{project_id}", summary, ttl = 3600) conn.close() @@ -577,13 +577,13 @@ class CacheManager: Returns: int: 清除的缓存数量 """ - count = 0 + count = 0 if self.use_redis: try: # 使用 Redis 的 scan 查找相关 key - pattern = f"*:{project_id}:*" - for key in self.redis_client.scan_iter(match=pattern): + pattern = f"*:{project_id}:*" + for key in self.redis_client.scan_iter(match = pattern): self.redis_client.delete(key) count += 1 except Exception as e: @@ -591,9 +591,9 @@ class CacheManager: else: # 内存缓存 - 查找并删除相关 key with self.cache_lock: - keys_to_delete = [key for key in self.memory_cache.keys() if project_id in key] + keys_to_delete = [key for key in self.memory_cache.keys() if project_id in key] for key in keys_to_delete: - entry = self.memory_cache.pop(key) + entry = self.memory_cache.pop(key) self.current_memory_size -= entry.size_bytes count += 1 @@ -616,19 +616,19 @@ class DatabaseSharding: def __init__( self, - base_db_path: str = "insightflow.db", - shard_db_dir: str = "./shards", - shards_count: int = 4, - ): - self.base_db_path = base_db_path - self.shard_db_dir = shard_db_dir - self.shards_count = shards_count + base_db_path: str = "insightflow.db", + shard_db_dir: str = "./shards", + shards_count: int = 4, + ) -> None: + self.base_db_path = base_db_path + self.shard_db_dir = shard_db_dir + self.shards_count = shards_count # 确保分片目录存在 - os.makedirs(shard_db_dir, exist_ok=True) + os.makedirs(shard_db_dir, exist_ok = True) # 分片映射 - self.shard_map: dict[str, ShardInfo] = {} + self.shard_map: dict[str, ShardInfo] = {} # 初始化分片 self._init_shards() @@ -636,24 +636,24 @@ class DatabaseSharding: def _init_shards(self) -> None: """初始化分片""" # 计算每个分片的 key 范围 - chars = "0123456789abcdef" - chars_per_shard = len(chars) // self.shards_count + chars = "0123456789abcdef" + chars_per_shard = len(chars) // self.shards_count for i in range(self.shards_count): - start_idx = i * chars_per_shard - end_idx = start_idx + chars_per_shard if i < self.shards_count - 1 else len(chars) + start_idx = i * chars_per_shard + end_idx = start_idx + chars_per_shard if i < self.shards_count - 1 else len(chars) - start_char = chars[start_idx] - end_char = chars[end_idx - 1] + start_char = chars[start_idx] + end_char = chars[end_idx - 1] - shard_id = f"shard_{i}" - db_path = os.path.join(self.shard_db_dir, f"{shard_id}.db") + shard_id = f"shard_{i}" + db_path = os.path.join(self.shard_db_dir, f"{shard_id}.db") - self.shard_map[shard_id] = ShardInfo( - shard_id=shard_id, - shard_key_range=(start_char, end_char), - db_path=db_path, - created_at=datetime.now().isoformat(), + self.shard_map[shard_id] = ShardInfo( + shard_id = shard_id, + shard_key_range = (start_char, end_char), + db_path = db_path, + created_at = datetime.now().isoformat(), ) # 确保分片数据库存在 @@ -662,7 +662,7 @@ class DatabaseSharding: def _create_shard_db(self, db_path: str) -> None: """创建分片数据库""" - conn = sqlite3.connect(db_path) + conn = sqlite3.connect(db_path) # 创建与主库相同的表结构(简化版) conn.executescript(""" @@ -702,10 +702,10 @@ class DatabaseSharding: if not project_id: return "shard_0" - first_char = project_id[0].lower() + first_char = project_id[0].lower() for shard_id, shard_info in self.shard_map.items(): - start, end = shard_info.shard_key_range + start, end = shard_info.shard_key_range if start <= first_char <= end: return shard_id @@ -713,14 +713,14 @@ class DatabaseSharding: def get_shard_connection(self, project_id: str) -> sqlite3.Connection: """获取项目对应的分片连接""" - shard_id = self._get_shard_id(project_id) - shard_info = self.shard_map[shard_id] + shard_id = self._get_shard_id(project_id) + shard_info = self.shard_map[shard_id] - conn = sqlite3.connect(shard_info.db_path) - conn.row_factory = sqlite3.Row + conn = sqlite3.connect(shard_info.db_path) + conn.row_factory = sqlite3.Row # 更新访问时间 - shard_info.last_accessed = datetime.now().isoformat() + shard_info.last_accessed = datetime.now().isoformat() return conn @@ -740,34 +740,34 @@ class DatabaseSharding: bool: 是否成功 """ # 获取源分片 - source_shard_id = self._get_shard_id(project_id) + source_shard_id = self._get_shard_id(project_id) if source_shard_id == target_shard_id: return True # 已经在目标分片 - source_info = self.shard_map.get(source_shard_id) - target_info = self.shard_map.get(target_shard_id) + source_info = self.shard_map.get(source_shard_id) + target_info = self.shard_map.get(target_shard_id) if not source_info or not target_info: return False try: # 从源分片读取数据 - source_conn = sqlite3.connect(source_info.db_path) - source_conn.row_factory = sqlite3.Row + source_conn = sqlite3.connect(source_info.db_path) + source_conn.row_factory = sqlite3.Row - entities = source_conn.execute( - "SELECT * FROM entities WHERE project_id = ?", (project_id,) + entities = source_conn.execute( + "SELECT * FROM entities WHERE project_id = ?", (project_id, ) ).fetchall() - relations = source_conn.execute( - "SELECT * FROM entity_relations WHERE project_id = ?", (project_id,) + relations = source_conn.execute( + "SELECT * FROM entity_relations WHERE project_id = ?", (project_id, ) ).fetchall() source_conn.close() # 写入目标分片 - target_conn = sqlite3.connect(target_info.db_path) + target_conn = sqlite3.connect(target_info.db_path) for entity in entities: target_conn.execute( @@ -793,9 +793,9 @@ class DatabaseSharding: target_conn.close() # 从源分片删除数据 - source_conn = sqlite3.connect(source_info.db_path) - source_conn.execute("DELETE FROM entities WHERE project_id = ?", (project_id,)) - source_conn.execute("DELETE FROM entity_relations WHERE project_id = ?", (project_id,)) + source_conn = sqlite3.connect(source_info.db_path) + source_conn.execute("DELETE FROM entities WHERE project_id = ?", (project_id, )) + source_conn.execute("DELETE FROM entity_relations WHERE project_id = ?", (project_id, )) source_conn.commit() source_conn.close() @@ -811,15 +811,15 @@ class DatabaseSharding: def _update_shard_stats(self, shard_id: str) -> None: """更新分片统计""" - shard_info = self.shard_map.get(shard_id) + shard_info = self.shard_map.get(shard_id) if not shard_info: return - conn = sqlite3.connect(shard_info.db_path) + conn = sqlite3.connect(shard_info.db_path) - count = conn.execute("SELECT COUNT(DISTINCT project_id) FROM entities").fetchone()[0] + count = conn.execute("SELECT COUNT(DISTINCT project_id) FROM entities").fetchone()[0] - shard_info.entity_count = count + shard_info.entity_count = count conn.close() @@ -833,14 +833,14 @@ class DatabaseSharding: Returns: List[Dict]: 合并的查询结果 """ - results = [] + results = [] for shard_info in self.shard_map.values(): - conn = sqlite3.connect(shard_info.db_path) - conn.row_factory = sqlite3.Row + conn = sqlite3.connect(shard_info.db_path) + conn.row_factory = sqlite3.Row try: - shard_results = query_func(conn) + shard_results = query_func(conn) results.extend(shard_results) except Exception as e: print(f"分片 {shard_info.shard_id} 查询失败: {e}") @@ -851,7 +851,7 @@ class DatabaseSharding: def get_shard_stats(self) -> list[dict]: """获取所有分片的统计信息""" - stats = [] + stats = [] for shard_info in self.shard_map.values(): self._update_shard_stats(shard_info.shard_id) @@ -880,17 +880,17 @@ class DatabaseSharding: Dict: 重新平衡统计 """ # 获取各分片的负载 - stats = self.get_shard_stats() + stats = self.get_shard_stats() if not stats: return {"message": "No shards to rebalance"} # 计算平均负载 - avg_load = sum(s["entity_count"] for s in stats) / len(stats) + avg_load = sum(s["entity_count"] for s in stats) / len(stats) # 找出过载和欠载的分片 - overloaded = [s for s in stats if s["entity_count"] > avg_load * 1.5] - underloaded = [s for s in stats if s["entity_count"] < avg_load * 0.5] + overloaded = [s for s in stats if s["entity_count"] > avg_load * 1.5] + underloaded = [s for s in stats if s["entity_count"] < avg_load * 0.5] # 简化的重新平衡逻辑 # 实际生产环境需要更复杂的算法 @@ -917,16 +917,16 @@ class TaskQueue: - 任务状态追踪和重试机制 """ - def __init__(self, redis_url: str | None = None, db_path: str = "insightflow.db"): - self.db_path = db_path - self.redis_url = redis_url - self.celery_app = None - self.use_celery = False + def __init__(self, redis_url: str | None = None, db_path: str = "insightflow.db") -> None: + self.db_path = db_path + self.redis_url = redis_url + self.celery_app = None + self.use_celery = False # 内存任务存储(非 Celery 模式) - self.tasks: dict[str, TaskInfo] = {} - self.task_handlers: dict[str, Callable] = {} - self.task_lock = threading.RLock() + self.tasks: dict[str, TaskInfo] = {} + self.task_handlers: dict[str, Callable] = {} + self.task_lock = threading.RLock() # 初始化任务队列表 self._init_task_tables() @@ -934,15 +934,15 @@ class TaskQueue: # 初始化 Celery if CELERY_AVAILABLE and redis_url: try: - self.celery_app = Celery("insightflow", broker=redis_url, backend=redis_url) - self.use_celery = True + self.celery_app = Celery("insightflow", broker = redis_url, backend = redis_url) + self.use_celery = True print("Celery 任务队列已初始化") except Exception as e: print(f"Celery 初始化失败,使用内存任务队列: {e}") def _init_task_tables(self) -> None: """初始化任务队列表""" - conn = sqlite3.connect(self.db_path) + conn = sqlite3.connect(self.db_path) conn.execute(""" CREATE TABLE IF NOT EXISTS task_queue ( @@ -972,9 +972,9 @@ class TaskQueue: def register_handler(self, task_type: str, handler: Callable) -> None: """注册任务处理器""" - self.task_handlers[task_type] = handler + self.task_handlers[task_type] = handler - def submit(self, task_type: str, payload: dict, max_retries: int = 3) -> str: + def submit(self, task_type: str, payload: dict, max_retries: int = 3) -> str: """ 提交任务 @@ -986,45 +986,45 @@ class TaskQueue: Returns: str: 任务ID """ - task_id = str(uuid.uuid4())[:16] + task_id = str(uuid.uuid4())[:16] - task = TaskInfo( - id=task_id, - task_type=task_type, - status="pending", - payload=payload, - created_at=datetime.now().isoformat(), - max_retries=max_retries, + task = TaskInfo( + id = task_id, + task_type = task_type, + status = "pending", + payload = payload, + created_at = datetime.now().isoformat(), + max_retries = max_retries, ) if self.use_celery: # 使用 Celery try: # 这里简化处理,实际应该定义具体的 Celery 任务 - result = self.celery_app.send_task( + result = self.celery_app.send_task( f"insightflow.tasks.{task_type}", - args=[payload], - task_id=task_id, - retry=True, - retry_policy={ + args = [payload], + task_id = task_id, + retry = True, + retry_policy = { "max_retries": max_retries, "interval_start": 10, "interval_step": 10, "interval_max": 60, }, ) - task.id = result.id + task.id = result.id except Exception as e: print(f"Celery 任务提交失败: {e}") # 回退到内存模式 - self.use_celery = False + self.use_celery = False if not self.use_celery: # 内存模式 with self.task_lock: - self.tasks[task_id] = task + self.tasks[task_id] = task # 异步执行 - threading.Thread(target=self._execute_task, args=(task_id,), daemon=True).start() + threading.Thread(target = self._execute_task, args = (task_id, ), daemon = True).start() # 保存到数据库 self._save_task(task) @@ -1034,49 +1034,49 @@ class TaskQueue: def _execute_task(self, task_id: str) -> None: """执行任务(内存模式)""" with self.task_lock: - task = self.tasks.get(task_id) + task = self.tasks.get(task_id) if not task: return - task.status = "running" - task.started_at = datetime.now().isoformat() + task.status = "running" + task.started_at = datetime.now().isoformat() self._update_task_status(task) # 获取处理器 - handler = self.task_handlers.get(task.task_type) + handler = self.task_handlers.get(task.task_type) if not handler: - task.status = "failed" - task.error_message = f"No handler for task type: {task.task_type}" + task.status = "failed" + task.error_message = f"No handler for task type: {task.task_type}" else: try: - result = handler(task.payload) - task.status = "success" - task.result = result + result = handler(task.payload) + task.status = "success" + task.result = result except Exception as e: task.retry_count += 1 if task.retry_count <= task.max_retries: - task.status = "retrying" + task.status = "retrying" # 延迟重试 threading.Timer( - 10 * task.retry_count, self._execute_task, args=(task_id,) + 10 * task.retry_count, self._execute_task, args = (task_id, ) ).start() else: - task.status = "failed" - task.error_message = str(e) + task.status = "failed" + task.error_message = str(e) - task.completed_at = datetime.now().isoformat() + task.completed_at = datetime.now().isoformat() with self.task_lock: - self.tasks[task_id] = task + self.tasks[task_id] = task self._update_task_status(task) def _save_task(self, task: TaskInfo) -> None: """保存任务到数据库""" - conn = sqlite3.connect(self.db_path) + conn = sqlite3.connect(self.db_path) conn.execute( """ @@ -1089,8 +1089,8 @@ class TaskQueue: task.id, task.task_type, task.status, - json.dumps(task.payload, ensure_ascii=False), - json.dumps(task.result, ensure_ascii=False) if task.result else None, + json.dumps(task.payload, ensure_ascii = False), + json.dumps(task.result, ensure_ascii = False) if task.result else None, task.error_message, task.retry_count, task.max_retries, @@ -1105,22 +1105,22 @@ class TaskQueue: def _update_task_status(self, task: TaskInfo) -> None: """更新任务状态""" - conn = sqlite3.connect(self.db_path) + conn = sqlite3.connect(self.db_path) conn.execute( """ UPDATE task_queue SET - status = ?, - result = ?, - error_message = ?, - retry_count = ?, - started_at = ?, - completed_at = ? - WHERE id = ? + status = ?, + result = ?, + error_message = ?, + retry_count = ?, + started_at = ?, + completed_at = ? + WHERE id = ? """, ( task.status, - json.dumps(task.result, ensure_ascii=False) if task.result else None, + json.dumps(task.result, ensure_ascii = False) if task.result else None, task.error_message, task.retry_count, task.started_at, @@ -1136,9 +1136,9 @@ class TaskQueue: """获取任务状态""" if self.use_celery: try: - result = AsyncResult(task_id, app=self.celery_app) + result = AsyncResult(task_id, app = self.celery_app) - status_map = { + status_map = { "PENDING": "pending", "STARTED": "running", "SUCCESS": "success", @@ -1147,13 +1147,13 @@ class TaskQueue: } return TaskInfo( - id=task_id, - task_type="celery_task", - status=status_map.get(result.status, "unknown"), - payload={}, - created_at="", - result=result.result if result.successful() else None, - error_message=str(result.result) if result.failed() else None, + id = task_id, + task_type = "celery_task", + status = status_map.get(result.status, "unknown"), + payload = {}, + created_at = "", + result = result.result if result.successful() else None, + error_message = str(result.result) if result.failed() else None, ) except Exception as e: print(f"获取 Celery 任务状态失败: {e}") @@ -1163,26 +1163,26 @@ class TaskQueue: return self.tasks.get(task_id) def list_tasks( - self, status: str | None = None, task_type: str | None = None, limit: int = 100 + self, status: str | None = None, task_type: str | None = None, limit: int = 100 ) -> list[TaskInfo]: """列出任务""" - conn = sqlite3.connect(self.db_path) - conn.row_factory = sqlite3.Row + conn = sqlite3.connect(self.db_path) + conn.row_factory = sqlite3.Row - where_clauses = [] - params = [] + where_clauses = [] + params = [] if status: - where_clauses.append("status = ?") + where_clauses.append("status = ?") params.append(status) if task_type: - where_clauses.append("task_type = ?") + where_clauses.append("task_type = ?") params.append(task_type) - where_str = " AND ".join(where_clauses) if where_clauses else "1=1" + where_str = " AND ".join(where_clauses) if where_clauses else "1 = 1" - rows = conn.execute( + rows = conn.execute( f""" SELECT * FROM task_queue WHERE {where_str} @@ -1194,21 +1194,21 @@ class TaskQueue: conn.close() - tasks = [] + tasks = [] for row in rows: tasks.append( TaskInfo( - id=row["id"], - task_type=row["task_type"], - status=row["status"], - payload=json.loads(row["payload"]) if row["payload"] else {}, - created_at=row["created_at"], - started_at=row["started_at"], - completed_at=row["completed_at"], - result=json.loads(row["result"]) if row["result"] else None, - error_message=row["error_message"], - retry_count=row["retry_count"], - max_retries=row["max_retries"], + id = row["id"], + task_type = row["task_type"], + status = row["status"], + payload = json.loads(row["payload"]) if row["payload"] else {}, + created_at = row["created_at"], + started_at = row["started_at"], + completed_at = row["completed_at"], + result = json.loads(row["result"]) if row["result"] else None, + error_message = row["error_message"], + retry_count = row["retry_count"], + max_retries = row["max_retries"], ) ) @@ -1218,16 +1218,16 @@ class TaskQueue: """取消任务""" if self.use_celery: try: - self.celery_app.control.revoke(task_id, terminate=True) + self.celery_app.control.revoke(task_id, terminate = True) return True except Exception as e: print(f"取消 Celery 任务失败: {e}") with self.task_lock: - task = self.tasks.get(task_id) + task = self.tasks.get(task_id) if task and task.status in ["pending", "running"]: - task.status = "cancelled" - task.completed_at = datetime.now().isoformat() + task.status = "cancelled" + task.completed_at = datetime.now().isoformat() self._update_task_status(task) return True @@ -1235,44 +1235,44 @@ class TaskQueue: def retry(self, task_id: str) -> bool: """重试失败的任务""" - task = self.get_status(task_id) + task = self.get_status(task_id) if not task or task.status != "failed": return False - task.status = "pending" - task.retry_count = 0 - task.error_message = None - task.completed_at = None + task.status = "pending" + task.retry_count = 0 + task.error_message = None + task.completed_at = None if not self.use_celery: with self.task_lock: - self.tasks[task_id] = task - threading.Thread(target=self._execute_task, args=(task_id,), daemon=True).start() + self.tasks[task_id] = task + threading.Thread(target = self._execute_task, args = (task_id, ), daemon = True).start() self._update_task_status(task) return True def get_stats(self) -> dict: """获取任务队列统计""" - conn = sqlite3.connect(self.db_path) + conn = sqlite3.connect(self.db_path) # 各状态任务数量 - status_counts = conn.execute(""" + status_counts = conn.execute(""" SELECT status, COUNT(*) as count FROM task_queue GROUP BY status """).fetchall() # 各类型任务数量 - type_counts = conn.execute(""" + type_counts = conn.execute(""" SELECT task_type, COUNT(*) as count FROM task_queue GROUP BY task_type """).fetchall() # 最近24小时任务数 - recent_count = conn.execute(""" + recent_count = conn.execute(""" SELECT COUNT(*) as count FROM task_queue WHERE created_at > datetime('now', '-1 day') @@ -1304,29 +1304,29 @@ class PerformanceMonitor: def __init__( self, - db_path: str = "insightflow.db", - slow_query_threshold: int = 1000, - alert_threshold: int = 5000, # 毫秒 - ): # 毫秒 - self.db_path = db_path - self.slow_query_threshold = slow_query_threshold - self.alert_threshold = alert_threshold + db_path: str = "insightflow.db", + slow_query_threshold: int = 1000, + alert_threshold: int = 5000, # 毫秒 + ) -> None: # 毫秒 + self.db_path = db_path + self.slow_query_threshold = slow_query_threshold + self.alert_threshold = alert_threshold # 内存中的指标缓存 - self.metrics_buffer: list[PerformanceMetric] = [] - self.buffer_lock = threading.RLock() - self.buffer_size = 100 + self.metrics_buffer: list[PerformanceMetric] = [] + self.buffer_lock = threading.RLock() + self.buffer_size = 100 # 告警回调 - self.alert_handlers: list[Callable] = [] + self.alert_handlers: list[Callable] = [] def record_metric( self, metric_type: str, duration_ms: float, - endpoint: str | None = None, - metadata: dict | None = None, - ): + endpoint: str | None = None, + metadata: dict | None = None, + ) -> None: """ 记录性能指标 @@ -1336,13 +1336,13 @@ class PerformanceMonitor: endpoint: 端点/查询标识 metadata: 额外元数据 """ - metric = PerformanceMetric( - id=str(uuid.uuid4())[:16], - metric_type=metric_type, - endpoint=endpoint, - duration_ms=duration_ms, - timestamp=datetime.now().isoformat(), - metadata=metadata or {}, + metric = PerformanceMetric( + id = str(uuid.uuid4())[:16], + metric_type = metric_type, + endpoint = endpoint, + duration_ms = duration_ms, + timestamp = datetime.now().isoformat(), + metadata = metadata or {}, ) # 添加到缓冲区 @@ -1364,7 +1364,7 @@ class PerformanceMonitor: if not self.metrics_buffer: return - conn = sqlite3.connect(self.db_path) + conn = sqlite3.connect(self.db_path) for metric in self.metrics_buffer: conn.execute( @@ -1379,14 +1379,14 @@ class PerformanceMonitor: metric.endpoint, metric.duration_ms, metric.timestamp, - json.dumps(metric.metadata, ensure_ascii=False), + json.dumps(metric.metadata, ensure_ascii = False), ), ) conn.commit() conn.close() - self.metrics_buffer = [] + self.metrics_buffer = [] def _record_slow_query(self, metric: PerformanceMetric) -> None: """记录慢查询""" @@ -1395,7 +1395,7 @@ class PerformanceMonitor: def _trigger_alert(self, metric: PerformanceMetric) -> None: """触发告警""" - alert_data = { + alert_data = { "type": "performance_alert", "metric": metric.to_dict(), "threshold": self.alert_threshold, @@ -1412,7 +1412,7 @@ class PerformanceMonitor: """注册告警处理器""" self.alert_handlers.append(handler) - def get_stats(self, hours: int = 24) -> dict: + def get_stats(self, hours: int = 24) -> dict: """ 获取性能统计 @@ -1425,11 +1425,11 @@ class PerformanceMonitor: # 先刷新缓冲区 self._flush_metrics() - conn = sqlite3.connect(self.db_path) - conn.row_factory = sqlite3.Row + conn = sqlite3.connect(self.db_path) + conn.row_factory = sqlite3.Row # 总体统计 - overall = conn.execute( + overall = conn.execute( """ SELECT COUNT(*) as total, @@ -1439,11 +1439,11 @@ class PerformanceMonitor: FROM performance_metrics WHERE timestamp > datetime('now', ?) """, - (f"-{hours} hours",), + (f"-{hours} hours", ), ).fetchone() # 按类型统计 - by_type = conn.execute( + by_type = conn.execute( """ SELECT metric_type, @@ -1454,11 +1454,11 @@ class PerformanceMonitor: WHERE timestamp > datetime('now', ?) GROUP BY metric_type """, - (f"-{hours} hours",), + (f"-{hours} hours", ), ).fetchall() # 按端点统计(API) - by_endpoint = conn.execute( + by_endpoint = conn.execute( """ SELECT endpoint, @@ -1467,16 +1467,16 @@ class PerformanceMonitor: MAX(duration_ms) as max_duration FROM performance_metrics WHERE timestamp > datetime('now', ?) - AND metric_type = 'api_response' + AND metric_type = 'api_response' GROUP BY endpoint ORDER BY avg_duration DESC LIMIT 20 """, - (f"-{hours} hours",), + (f"-{hours} hours", ), ).fetchall() # 慢查询统计 - slow_queries = conn.execute( + slow_queries = conn.execute( """ SELECT metric_type, @@ -1531,22 +1531,22 @@ class PerformanceMonitor: ], } - def get_api_performance(self, endpoint: str | None = None, hours: int = 24) -> dict: + def get_api_performance(self, endpoint: str | None = None, hours: int = 24) -> dict: """获取 API 性能详情""" self._flush_metrics() - conn = sqlite3.connect(self.db_path) - conn.row_factory = sqlite3.Row + conn = sqlite3.connect(self.db_path) + conn.row_factory = sqlite3.Row - where_clause = "metric_type = 'api_response'" - params = [f"-{hours} hours"] + where_clause = "metric_type = 'api_response'" + params = [f"-{hours} hours"] if endpoint: - where_clause += " AND endpoint = ?" + where_clause += " AND endpoint = ?" params.append(endpoint) # 百分位数统计 - percentiles = conn.execute( + percentiles = conn.execute( f""" SELECT endpoint, @@ -1580,7 +1580,7 @@ class PerformanceMonitor: ], } - def cleanup_old_metrics(self, days: int = 30) -> int: + def cleanup_old_metrics(self, days: int = 30) -> int: """ 清理旧的性能指标数据 @@ -1590,17 +1590,17 @@ class PerformanceMonitor: Returns: int: 删除的记录数 """ - conn = sqlite3.connect(self.db_path) + conn = sqlite3.connect(self.db_path) - cursor = conn.execute( + cursor = conn.execute( """ DELETE FROM performance_metrics WHERE timestamp < datetime('now', ?) """, - (f"-{days} days",), + (f"-{days} days", ), ) - deleted = cursor.rowcount + deleted = cursor.rowcount conn.commit() conn.close() @@ -1613,9 +1613,9 @@ class PerformanceMonitor: def cached( cache_manager: CacheManager, - key_prefix: str = "", - ttl: int = 3600, - key_func: Callable | None = None, + key_prefix: str = "", + ttl: int = 3600, + key_func: Callable | None = None, ) -> None: """ 缓存装饰器 @@ -1632,19 +1632,19 @@ def cached( def wrapper(*args, **kwargs) -> None: # 生成缓存键 if key_func: - cache_key = key_func(*args, **kwargs) + cache_key = key_func(*args, **kwargs) else: # 默认使用函数名和参数哈希 - key_data = f"{func.__name__}:{str(args)}:{str(kwargs)}" - cache_key = f"{key_prefix}:{hashlib.md5(key_data.encode()).hexdigest()[:16]}" + key_data = f"{func.__name__}:{str(args)}:{str(kwargs)}" + cache_key = f"{key_prefix}:{hashlib.md5(key_data.encode()).hexdigest()[:16]}" # 尝试从缓存获取 - cached_value = cache_manager.get(cache_key) + cached_value = cache_manager.get(cache_key) if cached_value is not None: return cached_value # 执行函数 - result = func(*args, **kwargs) + result = func(*args, **kwargs) # 写入缓存 cache_manager.set(cache_key, result, ttl) @@ -1656,7 +1656,7 @@ def cached( return decorator -def monitored(monitor: PerformanceMonitor, metric_type: str, endpoint: str | None = None) -> None: +def monitored(monitor: PerformanceMonitor, metric_type: str, endpoint: str | None = None) -> None: """ 性能监控装饰器 @@ -1668,15 +1668,15 @@ def monitored(monitor: PerformanceMonitor, metric_type: str, endpoint: str | Non def decorator(func: Callable) -> Callable: @wraps(func) - def wrapper(*args, **kwargs): - start_time = time.time() + def wrapper(*args, **kwargs) -> None: + start_time = time.time() try: - result = func(*args, **kwargs) + result = func(*args, **kwargs) return result finally: - duration_ms = (time.time() - start_time) * 1000 - ep = endpoint or func.__name__ + duration_ms = (time.time() - start_time) * 1000 + ep = endpoint or func.__name__ monitor.record_metric(metric_type, duration_ms, ep) return wrapper @@ -1696,20 +1696,20 @@ class PerformanceManager: def __init__( self, - db_path: str = "insightflow.db", - redis_url: str | None = None, - enable_sharding: bool = False, - ): - self.db_path = db_path + db_path: str = "insightflow.db", + redis_url: str | None = None, + enable_sharding: bool = False, + ) -> None: + self.db_path = db_path # 初始化各模块 - self.cache = CacheManager(redis_url=redis_url, db_path=db_path) + self.cache = CacheManager(redis_url = redis_url, db_path = db_path) - self.sharding = DatabaseSharding(base_db_path=db_path) if enable_sharding else None + self.sharding = DatabaseSharding(base_db_path = db_path) if enable_sharding else None - self.task_queue = TaskQueue(redis_url=redis_url, db_path=db_path) + self.task_queue = TaskQueue(redis_url = redis_url, db_path = db_path) - self.monitor = PerformanceMonitor(db_path=db_path) + self.monitor = PerformanceMonitor(db_path = db_path) def get_health_status(self) -> dict: """获取系统健康状态""" @@ -1737,29 +1737,29 @@ class PerformanceManager: def get_full_stats(self) -> dict: """获取完整统计信息""" - stats = { + stats = { "cache": self.cache.get_stats(), "task_queue": self.task_queue.get_stats(), "performance": self.monitor.get_stats(), } if self.sharding: - stats["sharding"] = self.sharding.get_shard_stats() + stats["sharding"] = self.sharding.get_shard_stats() return stats # 单例模式 -_performance_manager = None +_performance_manager = None def get_performance_manager( - db_path: str = "insightflow.db", redis_url: str | None = None, enable_sharding: bool = False + db_path: str = "insightflow.db", redis_url: str | None = None, enable_sharding: bool = False ) -> PerformanceManager: """获取性能管理器单例""" global _performance_manager if _performance_manager is None: - _performance_manager = PerformanceManager( - db_path=db_path, redis_url=redis_url, enable_sharding=enable_sharding + _performance_manager = PerformanceManager( + db_path = db_path, redis_url = redis_url, enable_sharding = enable_sharding ) return _performance_manager diff --git a/backend/plugin_manager.py b/backend/plugin_manager.py index 2f70244..64e8375 100644 --- a/backend/plugin_manager.py +++ b/backend/plugin_manager.py @@ -22,36 +22,36 @@ from plugin_manager import PluginManager import urllib.parse # Constants -UUID_LENGTH = 8 # UUID 截断长度 +UUID_LENGTH = 8 # UUID 截断长度 # WebDAV 支持 try: import webdav4.client as webdav_client - WEBDAV_AVAILABLE = True + WEBDAV_AVAILABLE = True except ImportError: - WEBDAV_AVAILABLE = False + WEBDAV_AVAILABLE = False class PluginType(Enum): """插件类型""" - CHROME_EXTENSION = "chrome_extension" - FEISHU_BOT = "feishu_bot" - DINGTALK_BOT = "dingtalk_bot" - ZAPIER = "zapier" - MAKE = "make" - WEBDAV = "webdav" - CUSTOM = "custom" + CHROME_EXTENSION = "chrome_extension" + FEISHU_BOT = "feishu_bot" + DINGTALK_BOT = "dingtalk_bot" + ZAPIER = "zapier" + MAKE = "make" + WEBDAV = "webdav" + CUSTOM = "custom" class PluginStatus(Enum): """插件状态""" - ACTIVE = "active" - INACTIVE = "inactive" - ERROR = "error" - PENDING = "pending" + ACTIVE = "active" + INACTIVE = "inactive" + ERROR = "error" + PENDING = "pending" @dataclass @@ -62,12 +62,12 @@ class Plugin: name: str plugin_type: str project_id: str - status: str = "active" - config: dict = field(default_factory=dict) - created_at: str = "" - updated_at: str = "" - last_used_at: str | None = None - use_count: int = 0 + status: str = "active" + config: dict = field(default_factory = dict) + created_at: str = "" + updated_at: str = "" + last_used_at: str | None = None + use_count: int = 0 @dataclass @@ -78,9 +78,9 @@ class PluginConfig: plugin_id: str config_key: str config_value: str - is_encrypted: bool = False - created_at: str = "" - updated_at: str = "" + is_encrypted: bool = False + created_at: str = "" + updated_at: str = "" @dataclass @@ -91,14 +91,14 @@ class BotSession: bot_type: str # feishu, dingtalk session_id: str # 群ID或会话ID session_name: str - project_id: str | None = None - webhook_url: str = "" - secret: str = "" - is_active: bool = True - created_at: str = "" - updated_at: str = "" - last_message_at: str | None = None - message_count: int = 0 + project_id: str | None = None + webhook_url: str = "" + secret: str = "" + is_active: bool = True + created_at: str = "" + updated_at: str = "" + last_message_at: str | None = None + message_count: int = 0 @dataclass @@ -109,15 +109,15 @@ class WebhookEndpoint: name: str endpoint_type: str # zapier, make, custom endpoint_url: str - project_id: str | None = None - auth_type: str = "none" # none, api_key, oauth, custom - auth_config: dict = field(default_factory=dict) - trigger_events: list[str] = field(default_factory=list) - is_active: bool = True - created_at: str = "" - updated_at: str = "" - last_triggered_at: str | None = None - trigger_count: int = 0 + project_id: str | None = None + auth_type: str = "none" # none, api_key, oauth, custom + auth_config: dict = field(default_factory = dict) + trigger_events: list[str] = field(default_factory = list) + is_active: bool = True + created_at: str = "" + updated_at: str = "" + last_triggered_at: str | None = None + trigger_count: int = 0 @dataclass @@ -129,17 +129,17 @@ class WebDAVSync: project_id: str server_url: str username: str - password: str = "" # 加密存储 - remote_path: str = "/insightflow" - sync_mode: str = "bidirectional" # bidirectional, upload_only, download_only - sync_interval: int = 3600 # 秒 - last_sync_at: str | None = None - last_sync_status: str = "pending" # pending, success, failed - last_sync_error: str = "" - is_active: bool = True - created_at: str = "" - updated_at: str = "" - sync_count: int = 0 + password: str = "" # 加密存储 + remote_path: str = "/insightflow" + sync_mode: str = "bidirectional" # bidirectional, upload_only, download_only + sync_interval: int = 3600 # 秒 + last_sync_at: str | None = None + last_sync_status: str = "pending" # pending, success, failed + last_sync_error: str = "" + is_active: bool = True + created_at: str = "" + updated_at: str = "" + sync_count: int = 0 @dataclass @@ -148,33 +148,33 @@ class ChromeExtensionToken: id: str token: str - user_id: str | None = None - project_id: str | None = None - name: str = "" - permissions: list[str] = field(default_factory=lambda: ["read", "write"]) - expires_at: str | None = None - created_at: str = "" - last_used_at: str | None = None - use_count: int = 0 - is_revoked: bool = False + user_id: str | None = None + project_id: str | None = None + name: str = "" + permissions: list[str] = field(default_factory = lambda: ["read", "write"]) + expires_at: str | None = None + created_at: str = "" + last_used_at: str | None = None + use_count: int = 0 + is_revoked: bool = False class PluginManager: """插件管理主类""" - def __init__(self, db_manager=None): - self.db = db_manager - self._handlers = {} + def __init__(self, db_manager = None) -> None: + self.db = db_manager + self._handlers = {} self._register_default_handlers() def _register_default_handlers(self) -> None: """注册默认处理器""" - self._handlers[PluginType.CHROME_EXTENSION] = ChromeExtensionHandler(self) - self._handlers[PluginType.FEISHU_BOT] = BotHandler(self, "feishu") - self._handlers[PluginType.DINGTALK_BOT] = BotHandler(self, "dingtalk") - self._handlers[PluginType.ZAPIER] = WebhookIntegration(self, "zapier") - self._handlers[PluginType.MAKE] = WebhookIntegration(self, "make") - self._handlers[PluginType.WEBDAV] = WebDAVSyncManager(self) + self._handlers[PluginType.CHROME_EXTENSION] = ChromeExtensionHandler(self) + self._handlers[PluginType.FEISHU_BOT] = BotHandler(self, "feishu") + self._handlers[PluginType.DINGTALK_BOT] = BotHandler(self, "dingtalk") + self._handlers[PluginType.ZAPIER] = WebhookIntegration(self, "zapier") + self._handlers[PluginType.MAKE] = WebhookIntegration(self, "make") + self._handlers[PluginType.WEBDAV] = WebDAVSyncManager(self) def get_handler(self, plugin_type: PluginType) -> Any | None: """获取插件处理器""" @@ -184,8 +184,8 @@ class PluginManager: def create_plugin(self, plugin: Plugin) -> Plugin: """创建插件""" - conn = self.db.get_conn() - now = datetime.now().isoformat() + conn = self.db.get_conn() + now = datetime.now().isoformat() conn.execute( """INSERT INTO plugins @@ -206,14 +206,14 @@ class PluginManager: conn.commit() conn.close() - plugin.created_at = now - plugin.updated_at = now + plugin.created_at = now + plugin.updated_at = now return plugin def get_plugin(self, plugin_id: str) -> Plugin | None: """获取插件""" - conn = self.db.get_conn() - row = conn.execute("SELECT * FROM plugins WHERE id = ?", (plugin_id,)).fetchone() + conn = self.db.get_conn() + row = conn.execute("SELECT * FROM plugins WHERE id = ?", (plugin_id, )).fetchone() conn.close() if row: @@ -221,27 +221,27 @@ class PluginManager: return None def list_plugins( - self, project_id: str = None, plugin_type: str = None, status: str = None + self, project_id: str = None, plugin_type: str = None, status: str = None ) -> list[Plugin]: """列出插件""" - conn = self.db.get_conn() + conn = self.db.get_conn() - conditions = [] - params = [] + conditions = [] + params = [] if project_id: - conditions.append("project_id = ?") + conditions.append("project_id = ?") params.append(project_id) if plugin_type: - conditions.append("plugin_type = ?") + conditions.append("plugin_type = ?") params.append(plugin_type) if status: - conditions.append("status = ?") + conditions.append("status = ?") params.append(status) - where_clause = " AND ".join(conditions) if conditions else "1=1" + where_clause = " AND ".join(conditions) if conditions else "1 = 1" - rows = conn.execute( + rows = conn.execute( f"SELECT * FROM plugins WHERE {where_clause} ORDER BY created_at DESC", params ).fetchall() conn.close() @@ -250,15 +250,15 @@ class PluginManager: def update_plugin(self, plugin_id: str, **kwargs) -> Plugin | None: """更新插件""" - conn = self.db.get_conn() + conn = self.db.get_conn() - allowed_fields = ["name", "status", "config"] - updates = [] - values = [] + allowed_fields = ["name", "status", "config"] + updates = [] + values = [] for f in allowed_fields: if f in kwargs: - updates.append(f"{f} = ?") + updates.append(f"{f} = ?") if f == "config": values.append(json.dumps(kwargs[f])) else: @@ -268,11 +268,11 @@ class PluginManager: conn.close() return self.get_plugin(plugin_id) - updates.append("updated_at = ?") + updates.append("updated_at = ?") values.append(datetime.now().isoformat()) values.append(plugin_id) - query = f"UPDATE plugins SET {', '.join(updates)} WHERE id = ?" + query = f"UPDATE plugins SET {', '.join(updates)} WHERE id = ?" conn.execute(query, values) conn.commit() conn.close() @@ -281,13 +281,13 @@ class PluginManager: def delete_plugin(self, plugin_id: str) -> bool: """删除插件""" - conn = self.db.get_conn() + conn = self.db.get_conn() # 删除关联的配置 - conn.execute("DELETE FROM plugin_configs WHERE plugin_id = ?", (plugin_id,)) + conn.execute("DELETE FROM plugin_configs WHERE plugin_id = ?", (plugin_id, )) # 删除插件 - cursor = conn.execute("DELETE FROM plugins WHERE id = ?", (plugin_id,)) + cursor = conn.execute("DELETE FROM plugins WHERE id = ?", (plugin_id, )) conn.commit() conn.close() @@ -296,42 +296,42 @@ class PluginManager: def _row_to_plugin(self, row: sqlite3.Row) -> Plugin: """将数据库行转换为 Plugin 对象""" return Plugin( - id=row["id"], - name=row["name"], - plugin_type=row["plugin_type"], - project_id=row["project_id"], - status=row["status"], - config=json.loads(row["config"]) if row["config"] else {}, - created_at=row["created_at"], - updated_at=row["updated_at"], - last_used_at=row["last_used_at"], - use_count=row["use_count"], + id = row["id"], + name = row["name"], + plugin_type = row["plugin_type"], + project_id = row["project_id"], + status = row["status"], + config = json.loads(row["config"]) if row["config"] else {}, + created_at = row["created_at"], + updated_at = row["updated_at"], + last_used_at = row["last_used_at"], + use_count = row["use_count"], ) # ==================== Plugin Config ==================== def set_plugin_config( - self, plugin_id: str, key: str, value: str, is_encrypted: bool = False + self, plugin_id: str, key: str, value: str, is_encrypted: bool = False ) -> PluginConfig: """设置插件配置""" - conn = self.db.get_conn() - now = datetime.now().isoformat() + conn = self.db.get_conn() + now = datetime.now().isoformat() # 检查是否已存在 - existing = conn.execute( - "SELECT id FROM plugin_configs WHERE plugin_id = ? AND config_key = ?", (plugin_id, key) + existing = conn.execute( + "SELECT id FROM plugin_configs WHERE plugin_id = ? AND config_key = ?", (plugin_id, key) ).fetchone() if existing: conn.execute( """UPDATE plugin_configs - SET config_value = ?, is_encrypted = ?, updated_at = ? - WHERE id = ?""", + SET config_value = ?, is_encrypted = ?, updated_at = ? + WHERE id = ?""", (value, is_encrypted, now, existing["id"]), ) - config_id = existing["id"] + config_id = existing["id"] else: - config_id = str(uuid.uuid4())[:UUID_LENGTH] + config_id = str(uuid.uuid4())[:UUID_LENGTH] conn.execute( """INSERT INTO plugin_configs (id, plugin_id, config_key, config_value, is_encrypted, created_at, updated_at) @@ -343,20 +343,20 @@ class PluginManager: conn.close() return PluginConfig( - id=config_id, - plugin_id=plugin_id, - config_key=key, - config_value=value, - is_encrypted=is_encrypted, - created_at=now, - updated_at=now, + id = config_id, + plugin_id = plugin_id, + config_key = key, + config_value = value, + is_encrypted = is_encrypted, + created_at = now, + updated_at = now, ) def get_plugin_config(self, plugin_id: str, key: str) -> str | None: """获取插件配置""" - conn = self.db.get_conn() - row = conn.execute( - "SELECT config_value FROM plugin_configs WHERE plugin_id = ? AND config_key = ?", + conn = self.db.get_conn() + row = conn.execute( + "SELECT config_value FROM plugin_configs WHERE plugin_id = ? AND config_key = ?", (plugin_id, key), ).fetchone() conn.close() @@ -365,9 +365,9 @@ class PluginManager: def get_all_plugin_configs(self, plugin_id: str) -> dict[str, str]: """获取插件所有配置""" - conn = self.db.get_conn() - rows = conn.execute( - "SELECT config_key, config_value FROM plugin_configs WHERE plugin_id = ?", (plugin_id,) + conn = self.db.get_conn() + rows = conn.execute( + "SELECT config_key, config_value FROM plugin_configs WHERE plugin_id = ?", (plugin_id, ) ).fetchall() conn.close() @@ -375,9 +375,9 @@ class PluginManager: def delete_plugin_config(self, plugin_id: str, key: str) -> bool: """删除插件配置""" - conn = self.db.get_conn() - cursor = conn.execute( - "DELETE FROM plugin_configs WHERE plugin_id = ? AND config_key = ?", (plugin_id, key) + conn = self.db.get_conn() + cursor = conn.execute( + "DELETE FROM plugin_configs WHERE plugin_id = ? AND config_key = ?", (plugin_id, key) ) conn.commit() conn.close() @@ -386,13 +386,13 @@ class PluginManager: def record_plugin_usage(self, plugin_id: str) -> None: """记录插件使用""" - conn = self.db.get_conn() - now = datetime.now().isoformat() + conn = self.db.get_conn() + now = datetime.now().isoformat() conn.execute( """UPDATE plugins - SET use_count = use_count + 1, last_used_at = ? - WHERE id = ?""", + SET use_count = use_count + 1, last_used_at = ? + WHERE id = ?""", (now, plugin_id), ) conn.commit() @@ -402,34 +402,34 @@ class PluginManager: class ChromeExtensionHandler: """Chrome 扩展处理器""" - def __init__(self, plugin_manager: PluginManager): - self.pm = plugin_manager + def __init__(self, plugin_manager: PluginManager) -> None: + self.pm = plugin_manager def create_token( self, name: str, - user_id: str = None, - project_id: str = None, - permissions: list[str] = None, - expires_days: int = None, + user_id: str = None, + project_id: str = None, + permissions: list[str] = None, + expires_days: int = None, ) -> ChromeExtensionToken: """创建 Chrome 扩展令牌""" - token_id = str(uuid.uuid4())[:UUID_LENGTH] + token_id = str(uuid.uuid4())[:UUID_LENGTH] # 生成随机令牌 - raw_token = f"if_ext_{base64.urlsafe_b64encode(os.urandom(32)).decode('utf-8').rstrip('=')}" + raw_token = f"if_ext_{base64.urlsafe_b64encode(os.urandom(32)).decode('utf-8').rstrip(' = ')}" # 哈希存储 - token_hash = hashlib.sha256(raw_token.encode()).hexdigest() + token_hash = hashlib.sha256(raw_token.encode()).hexdigest() - now = datetime.now().isoformat() - expires_at = None + now = datetime.now().isoformat() + expires_at = None if expires_days: from datetime import timedelta - expires_at = (datetime.now() + timedelta(days=expires_days)).isoformat() + expires_at = (datetime.now() + timedelta(days = expires_days)).isoformat() - conn = self.pm.db.get_conn() + conn = self.pm.db.get_conn() conn.execute( """INSERT INTO chrome_extension_tokens (id, token_hash, user_id, project_id, name, permissions, expires_at, @@ -452,25 +452,25 @@ class ChromeExtensionHandler: conn.close() return ChromeExtensionToken( - id=token_id, - token=raw_token, # 仅返回一次 - user_id=user_id, - project_id=project_id, - name=name, - permissions=permissions or ["read"], - expires_at=expires_at, - created_at=now, + id = token_id, + token = raw_token, # 仅返回一次 + user_id = user_id, + project_id = project_id, + name = name, + permissions = permissions or ["read"], + expires_at = expires_at, + created_at = now, ) def validate_token(self, token: str) -> ChromeExtensionToken | None: """验证 Chrome 扩展令牌""" - token_hash = hashlib.sha256(token.encode()).hexdigest() + token_hash = hashlib.sha256(token.encode()).hexdigest() - conn = self.pm.db.get_conn() - row = conn.execute( + conn = self.pm.db.get_conn() + row = conn.execute( """SELECT * FROM chrome_extension_tokens - WHERE token_hash = ? AND is_revoked = 0""", - (token_hash,), + WHERE token_hash = ? AND is_revoked = 0""", + (token_hash, ), ).fetchone() conn.close() @@ -482,35 +482,35 @@ class ChromeExtensionHandler: return None # 更新使用记录 - now = datetime.now().isoformat() - conn = self.pm.db.get_conn() + now = datetime.now().isoformat() + conn = self.pm.db.get_conn() conn.execute( """UPDATE chrome_extension_tokens - SET use_count = use_count + 1, last_used_at = ? - WHERE id = ?""", + SET use_count = use_count + 1, last_used_at = ? + WHERE id = ?""", (now, row["id"]), ) conn.commit() conn.close() return ChromeExtensionToken( - id=row["id"], - token="", # 不返回实际令牌 - user_id=row["user_id"], - project_id=row["project_id"], - name=row["name"], - permissions=json.loads(row["permissions"]), - expires_at=row["expires_at"], - created_at=row["created_at"], - last_used_at=now, - use_count=row["use_count"] + 1, + id = row["id"], + token = "", # 不返回实际令牌 + user_id = row["user_id"], + project_id = row["project_id"], + name = row["name"], + permissions = json.loads(row["permissions"]), + expires_at = row["expires_at"], + created_at = row["created_at"], + last_used_at = now, + use_count = row["use_count"] + 1, ) def revoke_token(self, token_id: str) -> bool: """撤销令牌""" - conn = self.pm.db.get_conn() - cursor = conn.execute( - "UPDATE chrome_extension_tokens SET is_revoked = 1 WHERE id = ?", (token_id,) + conn = self.pm.db.get_conn() + cursor = conn.execute( + "UPDATE chrome_extension_tokens SET is_revoked = 1 WHERE id = ?", (token_id, ) ) conn.commit() conn.close() @@ -518,44 +518,44 @@ class ChromeExtensionHandler: return cursor.rowcount > 0 def list_tokens( - self, user_id: str = None, project_id: str = None + self, user_id: str = None, project_id: str = None ) -> list[ChromeExtensionToken]: """列出令牌""" - conn = self.pm.db.get_conn() + conn = self.pm.db.get_conn() - conditions = ["is_revoked = 0"] - params = [] + conditions = ["is_revoked = 0"] + params = [] if user_id: - conditions.append("user_id = ?") + conditions.append("user_id = ?") params.append(user_id) if project_id: - conditions.append("project_id = ?") + conditions.append("project_id = ?") params.append(project_id) - where_clause = " AND ".join(conditions) + where_clause = " AND ".join(conditions) - rows = conn.execute( + rows = conn.execute( f"SELECT * FROM chrome_extension_tokens WHERE {where_clause} ORDER BY created_at DESC", params, ).fetchall() conn.close() - tokens = [] + tokens = [] for row in rows: tokens.append( ChromeExtensionToken( - id=row["id"], - token="", # 不返回实际令牌 - user_id=row["user_id"], - project_id=row["project_id"], - name=row["name"], - permissions=json.loads(row["permissions"]), - expires_at=row["expires_at"], - created_at=row["created_at"], - last_used_at=row["last_used_at"], - use_count=row["use_count"], - is_revoked=bool(row["is_revoked"]), + id = row["id"], + token = "", # 不返回实际令牌 + user_id = row["user_id"], + project_id = row["project_id"], + name = row["name"], + permissions = json.loads(row["permissions"]), + expires_at = row["expires_at"], + created_at = row["created_at"], + last_used_at = row["last_used_at"], + use_count = row["use_count"], + is_revoked = bool(row["is_revoked"]), ) ) @@ -567,7 +567,7 @@ class ChromeExtensionHandler: url: str, title: str, content: str, - html_content: str = None, + html_content: str = None, ) -> dict: """导入网页内容""" if not token.project_id: @@ -577,13 +577,13 @@ class ChromeExtensionHandler: return {"success": False, "error": "Insufficient permissions"} # 创建转录记录(将网页作为文档处理) - transcript_id = str(uuid.uuid4())[:UUID_LENGTH] - now = datetime.now().isoformat() + transcript_id = str(uuid.uuid4())[:UUID_LENGTH] + now = datetime.now().isoformat() # 构建完整文本 - full_text = f"# {title}\n\nURL: {url}\n\n{content}" + full_text = f"# {title}\n\nURL: {url}\n\n{content}" - conn = self.pm.db.get_conn() + conn = self.pm.db.get_conn() conn.execute( """INSERT INTO transcripts (id, project_id, filename, full_text, type, created_at) @@ -606,23 +606,23 @@ class ChromeExtensionHandler: class BotHandler: """飞书/钉钉机器人处理器""" - def __init__(self, plugin_manager: PluginManager, bot_type: str): - self.pm = plugin_manager - self.bot_type = bot_type + def __init__(self, plugin_manager: PluginManager, bot_type: str) -> None: + self.pm = plugin_manager + self.bot_type = bot_type def create_session( self, session_id: str, session_name: str, - project_id: str = None, - webhook_url: str = "", - secret: str = "", + project_id: str = None, + webhook_url: str = "", + secret: str = "", ) -> BotSession: """创建机器人会话""" - bot_id = str(uuid.uuid4())[:UUID_LENGTH] - now = datetime.now().isoformat() + bot_id = str(uuid.uuid4())[:UUID_LENGTH] + now = datetime.now().isoformat() - conn = self.pm.db.get_conn() + conn = self.pm.db.get_conn() conn.execute( """INSERT INTO bot_sessions (id, bot_type, session_id, session_name, project_id, webhook_url, secret, @@ -646,24 +646,24 @@ class BotHandler: conn.close() return BotSession( - id=bot_id, - bot_type=self.bot_type, - session_id=session_id, - session_name=session_name, - project_id=project_id, - webhook_url=webhook_url, - secret=secret, - is_active=True, - created_at=now, - updated_at=now, + id = bot_id, + bot_type = self.bot_type, + session_id = session_id, + session_name = session_name, + project_id = project_id, + webhook_url = webhook_url, + secret = secret, + is_active = True, + created_at = now, + updated_at = now, ) def get_session(self, session_id: str) -> BotSession | None: """获取会话""" - conn = self.pm.db.get_conn() - row = conn.execute( + conn = self.pm.db.get_conn() + row = conn.execute( """SELECT * FROM bot_sessions - WHERE session_id = ? AND bot_type = ?""", + WHERE session_id = ? AND bot_type = ?""", (session_id, self.bot_type), ).fetchone() conn.close() @@ -672,21 +672,21 @@ class BotHandler: return self._row_to_session(row) return None - def list_sessions(self, project_id: str = None) -> list[BotSession]: + def list_sessions(self, project_id: str = None) -> list[BotSession]: """列出会话""" - conn = self.pm.db.get_conn() + conn = self.pm.db.get_conn() if project_id: - rows = conn.execute( + rows = conn.execute( """SELECT * FROM bot_sessions - WHERE bot_type = ? AND project_id = ? ORDER BY created_at DESC""", + WHERE bot_type = ? AND project_id = ? ORDER BY created_at DESC""", (self.bot_type, project_id), ).fetchall() else: - rows = conn.execute( + rows = conn.execute( """SELECT * FROM bot_sessions - WHERE bot_type = ? ORDER BY created_at DESC""", - (self.bot_type,), + WHERE bot_type = ? ORDER BY created_at DESC""", + (self.bot_type, ), ).fetchall() conn.close() @@ -695,28 +695,28 @@ class BotHandler: def update_session(self, session_id: str, **kwargs) -> BotSession | None: """更新会话""" - conn = self.pm.db.get_conn() + conn = self.pm.db.get_conn() - allowed_fields = ["session_name", "project_id", "webhook_url", "secret", "is_active"] - updates = [] - values = [] + allowed_fields = ["session_name", "project_id", "webhook_url", "secret", "is_active"] + updates = [] + values = [] for f in allowed_fields: if f in kwargs: - updates.append(f"{f} = ?") + updates.append(f"{f} = ?") values.append(kwargs[f]) if not updates: conn.close() return self.get_session(session_id) - updates.append("updated_at = ?") + updates.append("updated_at = ?") values.append(datetime.now().isoformat()) values.append(session_id) values.append(self.bot_type) - query = ( - f"UPDATE bot_sessions SET {', '.join(updates)} WHERE session_id = ? AND bot_type = ?" + query = ( + f"UPDATE bot_sessions SET {', '.join(updates)} WHERE session_id = ? AND bot_type = ?" ) conn.execute(query, values) conn.commit() @@ -726,9 +726,9 @@ class BotHandler: def delete_session(self, session_id: str) -> bool: """删除会话""" - conn = self.pm.db.get_conn() - cursor = conn.execute( - "DELETE FROM bot_sessions WHERE session_id = ? AND bot_type = ?", + conn = self.pm.db.get_conn() + cursor = conn.execute( + "DELETE FROM bot_sessions WHERE session_id = ? AND bot_type = ?", (session_id, self.bot_type), ) conn.commit() @@ -739,41 +739,41 @@ class BotHandler: def _row_to_session(self, row: sqlite3.Row) -> BotSession: """将数据库行转换为 BotSession 对象""" return BotSession( - id=row["id"], - bot_type=row["bot_type"], - session_id=row["session_id"], - session_name=row["session_name"], - project_id=row["project_id"], - webhook_url=row["webhook_url"], - secret=row["secret"], - is_active=bool(row["is_active"]), - created_at=row["created_at"], - updated_at=row["updated_at"], - last_message_at=row["last_message_at"], - message_count=row["message_count"], + id = row["id"], + bot_type = row["bot_type"], + session_id = row["session_id"], + session_name = row["session_name"], + project_id = row["project_id"], + webhook_url = row["webhook_url"], + secret = row["secret"], + is_active = bool(row["is_active"]), + created_at = row["created_at"], + updated_at = row["updated_at"], + last_message_at = row["last_message_at"], + message_count = row["message_count"], ) async def handle_message(self, session: BotSession, message: dict) -> dict: """处理收到的消息""" - now = datetime.now().isoformat() + now = datetime.now().isoformat() # 更新消息统计 - conn = self.pm.db.get_conn() + conn = self.pm.db.get_conn() conn.execute( """UPDATE bot_sessions - SET message_count = message_count + 1, last_message_at = ? - WHERE id = ?""", + SET message_count = message_count + 1, last_message_at = ? + WHERE id = ?""", (now, session.id), ) conn.commit() conn.close() # 处理消息 - msg_type = message.get("msg_type", "text") - content = message.get("content", {}) + msg_type = message.get("msg_type", "text") + content = message.get("content", {}) if msg_type == "text": - text = content.get("text", "") + text = content.get("text", "") return await self._handle_text_message(session, text, message) elif msg_type == "audio": # 处理音频消息 @@ -802,8 +802,8 @@ class BotHandler: return {"success": True, "response": "⚠️ 当前会话未绑定项目"} # 获取项目状态 - summary = self.pm.db.get_project_summary(session.project_id) - stats = summary.get("statistics", {}) + summary = self.pm.db.get_project_summary(session.project_id) + stats = summary.get("statistics", {}) return { "success": True, @@ -825,17 +825,17 @@ class BotHandler: return {"success": False, "error": "Session not bound to any project"} # 下载音频文件 - audio_url = message.get("content", {}).get("download_url") + audio_url = message.get("content", {}).get("download_url") if not audio_url: return {"success": False, "error": "No audio URL provided"} try: async with httpx.AsyncClient() as client: - response = await client.get(audio_url) - audio_data = response.content + response = await client.get(audio_url) + audio_data = response.content # 保存音频文件 - filename = f"bot_audio_{datetime.now().strftime('%Y%m%d_%H%M%S')}.mp3" + filename = f"bot_audio_{datetime.now().strftime('%Y%m%d_%H%M%S')}.mp3" # 这里应该调用 ASR 服务进行转录 # 简化处理,返回提示 @@ -853,7 +853,7 @@ class BotHandler: """处理文件消息""" return {"success": True, "response": "📎 收到文件,正在处理中..."} - async def send_message(self, session: BotSession, message: str, msg_type: str = "text") -> bool: + async def send_message(self, session: BotSession, message: str, msg_type: str = "text") -> bool: """发送消息到群聊""" if not session.webhook_url: return False @@ -872,21 +872,21 @@ class BotHandler: async def _send_feishu_message(self, session: BotSession, message: str, msg_type: str) -> bool: """发送飞书消息""" - timestamp = str(int(time.time())) + timestamp = str(int(time.time())) # 生成签名 if session.secret: - string_to_sign = f"{timestamp}\n{session.secret}" - hmac_code = hmac.new( + string_to_sign = f"{timestamp}\n{session.secret}" + hmac_code = hmac.new( session.secret.encode("utf-8"), string_to_sign.encode("utf-8"), - digestmod=hashlib.sha256, + digestmod = hashlib.sha256, ).digest() - sign = base64.b64encode(hmac_code).decode("utf-8") + sign = base64.b64encode(hmac_code).decode("utf-8") else: - sign = "" + sign = "" - payload = { + payload = { "timestamp": timestamp, "sign": sign, "msg_type": "text", @@ -894,8 +894,8 @@ class BotHandler: } async with httpx.AsyncClient() as client: - response = await client.post( - session.webhook_url, json=payload, headers={"Content-Type": "application/json"} + response = await client.post( + session.webhook_url, json = payload, headers = {"Content-Type": "application/json"} ) return response.status_code == 200 @@ -903,30 +903,30 @@ class BotHandler: self, session: BotSession, message: str, msg_type: str ) -> bool: """发送钉钉消息""" - timestamp = str(round(time.time() * 1000)) + timestamp = str(round(time.time() * 1000)) # 生成签名 if session.secret: - string_to_sign = f"{timestamp}\n{session.secret}" - hmac_code = hmac.new( + string_to_sign = f"{timestamp}\n{session.secret}" + hmac_code = hmac.new( session.secret.encode("utf-8"), string_to_sign.encode("utf-8"), - digestmod=hashlib.sha256, + digestmod = hashlib.sha256, ).digest() - sign = base64.b64encode(hmac_code).decode("utf-8") - sign = urllib.parse.quote(sign) + sign = base64.b64encode(hmac_code).decode("utf-8") + sign = urllib.parse.quote(sign) else: - sign = "" + sign = "" - payload = {"msgtype": "text", "text": {"content": message}} + payload = {"msgtype": "text", "text": {"content": message}} - url = session.webhook_url + url = session.webhook_url if sign: - url = f"{url}×tamp={timestamp}&sign={sign}" + url = f"{url}×tamp = {timestamp}&sign = {sign}" async with httpx.AsyncClient() as client: - response = await client.post( - url, json=payload, headers={"Content-Type": "application/json"} + response = await client.post( + url, json = payload, headers = {"Content-Type": "application/json"} ) return response.status_code == 200 @@ -934,24 +934,24 @@ class BotHandler: class WebhookIntegration: """Zapier/Make Webhook 集成""" - def __init__(self, plugin_manager: PluginManager, endpoint_type: str): - self.pm = plugin_manager - self.endpoint_type = endpoint_type + def __init__(self, plugin_manager: PluginManager, endpoint_type: str) -> None: + self.pm = plugin_manager + self.endpoint_type = endpoint_type def create_endpoint( self, name: str, endpoint_url: str, - project_id: str = None, - auth_type: str = "none", - auth_config: dict = None, - trigger_events: list[str] = None, + project_id: str = None, + auth_type: str = "none", + auth_config: dict = None, + trigger_events: list[str] = None, ) -> WebhookEndpoint: """创建 Webhook 端点""" - endpoint_id = str(uuid.uuid4())[:UUID_LENGTH] - now = datetime.now().isoformat() + endpoint_id = str(uuid.uuid4())[:UUID_LENGTH] + now = datetime.now().isoformat() - conn = self.pm.db.get_conn() + conn = self.pm.db.get_conn() conn.execute( """INSERT INTO webhook_endpoints (id, name, endpoint_type, endpoint_url, project_id, auth_type, auth_config, @@ -976,24 +976,24 @@ class WebhookIntegration: conn.close() return WebhookEndpoint( - id=endpoint_id, - name=name, - endpoint_type=self.endpoint_type, - endpoint_url=endpoint_url, - project_id=project_id, - auth_type=auth_type, - auth_config=auth_config or {}, - trigger_events=trigger_events or [], - is_active=True, - created_at=now, - updated_at=now, + id = endpoint_id, + name = name, + endpoint_type = self.endpoint_type, + endpoint_url = endpoint_url, + project_id = project_id, + auth_type = auth_type, + auth_config = auth_config or {}, + trigger_events = trigger_events or [], + is_active = True, + created_at = now, + updated_at = now, ) def get_endpoint(self, endpoint_id: str) -> WebhookEndpoint | None: """获取端点""" - conn = self.pm.db.get_conn() - row = conn.execute( - "SELECT * FROM webhook_endpoints WHERE id = ? AND endpoint_type = ?", + conn = self.pm.db.get_conn() + row = conn.execute( + "SELECT * FROM webhook_endpoints WHERE id = ? AND endpoint_type = ?", (endpoint_id, self.endpoint_type), ).fetchone() conn.close() @@ -1002,21 +1002,21 @@ class WebhookIntegration: return self._row_to_endpoint(row) return None - def list_endpoints(self, project_id: str = None) -> list[WebhookEndpoint]: + def list_endpoints(self, project_id: str = None) -> list[WebhookEndpoint]: """列出端点""" - conn = self.pm.db.get_conn() + conn = self.pm.db.get_conn() if project_id: - rows = conn.execute( + rows = conn.execute( """SELECT * FROM webhook_endpoints - WHERE endpoint_type = ? AND project_id = ? ORDER BY created_at DESC""", + WHERE endpoint_type = ? AND project_id = ? ORDER BY created_at DESC""", (self.endpoint_type, project_id), ).fetchall() else: - rows = conn.execute( + rows = conn.execute( """SELECT * FROM webhook_endpoints - WHERE endpoint_type = ? ORDER BY created_at DESC""", - (self.endpoint_type,), + WHERE endpoint_type = ? ORDER BY created_at DESC""", + (self.endpoint_type, ), ).fetchall() conn.close() @@ -1025,9 +1025,9 @@ class WebhookIntegration: def update_endpoint(self, endpoint_id: str, **kwargs) -> WebhookEndpoint | None: """更新端点""" - conn = self.pm.db.get_conn() + conn = self.pm.db.get_conn() - allowed_fields = [ + allowed_fields = [ "name", "endpoint_url", "project_id", @@ -1036,12 +1036,12 @@ class WebhookIntegration: "trigger_events", "is_active", ] - updates = [] - values = [] + updates = [] + values = [] for f in allowed_fields: if f in kwargs: - updates.append(f"{f} = ?") + updates.append(f"{f} = ?") if f in ["auth_config", "trigger_events"]: values.append(json.dumps(kwargs[f])) else: @@ -1051,11 +1051,11 @@ class WebhookIntegration: conn.close() return self.get_endpoint(endpoint_id) - updates.append("updated_at = ?") + updates.append("updated_at = ?") values.append(datetime.now().isoformat()) values.append(endpoint_id) - query = f"UPDATE webhook_endpoints SET {', '.join(updates)} WHERE id = ?" + query = f"UPDATE webhook_endpoints SET {', '.join(updates)} WHERE id = ?" conn.execute(query, values) conn.commit() conn.close() @@ -1064,8 +1064,8 @@ class WebhookIntegration: def delete_endpoint(self, endpoint_id: str) -> bool: """删除端点""" - conn = self.pm.db.get_conn() - cursor = conn.execute("DELETE FROM webhook_endpoints WHERE id = ?", (endpoint_id,)) + conn = self.pm.db.get_conn() + cursor = conn.execute("DELETE FROM webhook_endpoints WHERE id = ?", (endpoint_id, )) conn.commit() conn.close() @@ -1074,19 +1074,19 @@ class WebhookIntegration: def _row_to_endpoint(self, row: sqlite3.Row) -> WebhookEndpoint: """将数据库行转换为 WebhookEndpoint 对象""" return WebhookEndpoint( - id=row["id"], - name=row["name"], - endpoint_type=row["endpoint_type"], - endpoint_url=row["endpoint_url"], - project_id=row["project_id"], - auth_type=row["auth_type"], - auth_config=json.loads(row["auth_config"]) if row["auth_config"] else {}, - trigger_events=json.loads(row["trigger_events"]) if row["trigger_events"] else [], - is_active=bool(row["is_active"]), - created_at=row["created_at"], - updated_at=row["updated_at"], - last_triggered_at=row["last_triggered_at"], - trigger_count=row["trigger_count"], + id = row["id"], + name = row["name"], + endpoint_type = row["endpoint_type"], + endpoint_url = row["endpoint_url"], + project_id = row["project_id"], + auth_type = row["auth_type"], + auth_config = json.loads(row["auth_config"]) if row["auth_config"] else {}, + trigger_events = json.loads(row["trigger_events"]) if row["trigger_events"] else [], + is_active = bool(row["is_active"]), + created_at = row["created_at"], + updated_at = row["updated_at"], + last_triggered_at = row["last_triggered_at"], + trigger_count = row["trigger_count"], ) async def trigger(self, endpoint: WebhookEndpoint, event_type: str, data: dict) -> bool: @@ -1098,33 +1098,33 @@ class WebhookIntegration: return False try: - headers = {"Content-Type": "application/json"} + headers = {"Content-Type": "application/json"} # 添加认证头 if endpoint.auth_type == "api_key": - api_key = endpoint.auth_config.get("api_key", "") - header_name = endpoint.auth_config.get("header_name", "X-API-Key") - headers[header_name] = api_key + api_key = endpoint.auth_config.get("api_key", "") + header_name = endpoint.auth_config.get("header_name", "X-API-Key") + headers[header_name] = api_key elif endpoint.auth_type == "bearer": - token = endpoint.auth_config.get("token", "") - headers["Authorization"] = f"Bearer {token}" + token = endpoint.auth_config.get("token", "") + headers["Authorization"] = f"Bearer {token}" - payload = {"event": event_type, "timestamp": datetime.now().isoformat(), "data": data} + payload = {"event": event_type, "timestamp": datetime.now().isoformat(), "data": data} async with httpx.AsyncClient() as client: - response = await client.post( - endpoint.endpoint_url, json=payload, headers=headers, timeout=30.0 + response = await client.post( + endpoint.endpoint_url, json = payload, headers = headers, timeout = 30.0 ) - success = response.status_code in [200, 201, 202] + success = response.status_code in [200, 201, 202] # 更新触发统计 - now = datetime.now().isoformat() - conn = self.pm.db.get_conn() + now = datetime.now().isoformat() + conn = self.pm.db.get_conn() conn.execute( """UPDATE webhook_endpoints - SET trigger_count = trigger_count + 1, last_triggered_at = ? - WHERE id = ?""", + SET trigger_count = trigger_count + 1, last_triggered_at = ? + WHERE id = ?""", (now, endpoint.id), ) conn.commit() @@ -1138,13 +1138,13 @@ class WebhookIntegration: async def test_endpoint(self, endpoint: WebhookEndpoint) -> dict: """测试端点""" - test_data = { + test_data = { "message": "This is a test event from InsightFlow", "test": True, "timestamp": datetime.now().isoformat(), } - success = await self.trigger(endpoint, "test", test_data) + success = await self.trigger(endpoint, "test", test_data) return { "success": success, @@ -1157,8 +1157,8 @@ class WebhookIntegration: class WebDAVSyncManager: """WebDAV 同步管理""" - def __init__(self, plugin_manager: PluginManager): - self.pm = plugin_manager + def __init__(self, plugin_manager: PluginManager) -> None: + self.pm = plugin_manager def create_sync( self, @@ -1167,15 +1167,15 @@ class WebDAVSyncManager: server_url: str, username: str, password: str, - remote_path: str = "/insightflow", - sync_mode: str = "bidirectional", - sync_interval: int = 3600, + remote_path: str = "/insightflow", + sync_mode: str = "bidirectional", + sync_interval: int = 3600, ) -> WebDAVSync: """创建 WebDAV 同步配置""" - sync_id = str(uuid.uuid4())[:UUID_LENGTH] - now = datetime.now().isoformat() + sync_id = str(uuid.uuid4())[:UUID_LENGTH] + now = datetime.now().isoformat() - conn = self.pm.db.get_conn() + conn = self.pm.db.get_conn() conn.execute( """INSERT INTO webdav_syncs (id, name, project_id, server_url, username, password, remote_path, @@ -1202,42 +1202,42 @@ class WebDAVSyncManager: conn.close() return WebDAVSync( - id=sync_id, - name=name, - project_id=project_id, - server_url=server_url, - username=username, - password=password, - remote_path=remote_path, - sync_mode=sync_mode, - sync_interval=sync_interval, - last_sync_status="pending", - is_active=True, - created_at=now, - updated_at=now, + id = sync_id, + name = name, + project_id = project_id, + server_url = server_url, + username = username, + password = password, + remote_path = remote_path, + sync_mode = sync_mode, + sync_interval = sync_interval, + last_sync_status = "pending", + is_active = True, + created_at = now, + updated_at = now, ) def get_sync(self, sync_id: str) -> WebDAVSync | None: """获取同步配置""" - conn = self.pm.db.get_conn() - row = conn.execute("SELECT * FROM webdav_syncs WHERE id = ?", (sync_id,)).fetchone() + conn = self.pm.db.get_conn() + row = conn.execute("SELECT * FROM webdav_syncs WHERE id = ?", (sync_id, )).fetchone() conn.close() if row: return self._row_to_sync(row) return None - def list_syncs(self, project_id: str = None) -> list[WebDAVSync]: + def list_syncs(self, project_id: str = None) -> list[WebDAVSync]: """列出同步配置""" - conn = self.pm.db.get_conn() + conn = self.pm.db.get_conn() if project_id: - rows = conn.execute( - "SELECT * FROM webdav_syncs WHERE project_id = ? ORDER BY created_at DESC", - (project_id,), + rows = conn.execute( + "SELECT * FROM webdav_syncs WHERE project_id = ? ORDER BY created_at DESC", + (project_id, ), ).fetchall() else: - rows = conn.execute("SELECT * FROM webdav_syncs ORDER BY created_at DESC").fetchall() + rows = conn.execute("SELECT * FROM webdav_syncs ORDER BY created_at DESC").fetchall() conn.close() @@ -1245,9 +1245,9 @@ class WebDAVSyncManager: def update_sync(self, sync_id: str, **kwargs) -> WebDAVSync | None: """更新同步配置""" - conn = self.pm.db.get_conn() + conn = self.pm.db.get_conn() - allowed_fields = [ + allowed_fields = [ "name", "server_url", "username", @@ -1257,23 +1257,23 @@ class WebDAVSyncManager: "sync_interval", "is_active", ] - updates = [] - values = [] + updates = [] + values = [] for f in allowed_fields: if f in kwargs: - updates.append(f"{f} = ?") + updates.append(f"{f} = ?") values.append(kwargs[f]) if not updates: conn.close() return self.get_sync(sync_id) - updates.append("updated_at = ?") + updates.append("updated_at = ?") values.append(datetime.now().isoformat()) values.append(sync_id) - query = f"UPDATE webdav_syncs SET {', '.join(updates)} WHERE id = ?" + query = f"UPDATE webdav_syncs SET {', '.join(updates)} WHERE id = ?" conn.execute(query, values) conn.commit() conn.close() @@ -1282,8 +1282,8 @@ class WebDAVSyncManager: def delete_sync(self, sync_id: str) -> bool: """删除同步配置""" - conn = self.pm.db.get_conn() - cursor = conn.execute("DELETE FROM webdav_syncs WHERE id = ?", (sync_id,)) + conn = self.pm.db.get_conn() + cursor = conn.execute("DELETE FROM webdav_syncs WHERE id = ?", (sync_id, )) conn.commit() conn.close() @@ -1292,22 +1292,22 @@ class WebDAVSyncManager: def _row_to_sync(self, row: sqlite3.Row) -> WebDAVSync: """将数据库行转换为 WebDAVSync 对象""" return WebDAVSync( - id=row["id"], - name=row["name"], - project_id=row["project_id"], - server_url=row["server_url"], - username=row["username"], - password=row["password"], - remote_path=row["remote_path"], - sync_mode=row["sync_mode"], - sync_interval=row["sync_interval"], - last_sync_at=row["last_sync_at"], - last_sync_status=row["last_sync_status"], - last_sync_error=row["last_sync_error"] or "", - is_active=bool(row["is_active"]), - created_at=row["created_at"], - updated_at=row["updated_at"], - sync_count=row["sync_count"], + id = row["id"], + name = row["name"], + project_id = row["project_id"], + server_url = row["server_url"], + username = row["username"], + password = row["password"], + remote_path = row["remote_path"], + sync_mode = row["sync_mode"], + sync_interval = row["sync_interval"], + last_sync_at = row["last_sync_at"], + last_sync_status = row["last_sync_status"], + last_sync_error = row["last_sync_error"] or "", + is_active = bool(row["is_active"]), + created_at = row["created_at"], + updated_at = row["updated_at"], + sync_count = row["sync_count"], ) async def test_connection(self, sync: WebDAVSync) -> dict: @@ -1316,7 +1316,7 @@ class WebDAVSyncManager: return {"success": False, "error": "WebDAV library not available"} try: - client = webdav_client.Client(sync.server_url, auth=(sync.username, sync.password)) + client = webdav_client.Client(sync.server_url, auth = (sync.username, sync.password)) # 尝试列出根目录 client.list("/") @@ -1335,26 +1335,26 @@ class WebDAVSyncManager: return {"success": False, "error": "Sync is not active"} try: - client = webdav_client.Client(sync.server_url, auth=(sync.username, sync.password)) + client = webdav_client.Client(sync.server_url, auth = (sync.username, sync.password)) # 确保远程目录存在 - remote_project_path = f"{sync.remote_path}/{sync.project_id}" + remote_project_path = f"{sync.remote_path}/{sync.project_id}" try: client.mkdir(remote_project_path) except (OSError, IOError): pass # 目录可能已存在 # 获取项目数据 - project = self.pm.db.get_project(sync.project_id) + project = self.pm.db.get_project(sync.project_id) if not project: return {"success": False, "error": "Project not found"} # 导出项目数据为 JSON - entities = self.pm.db.list_project_entities(sync.project_id) - relations = self.pm.db.list_project_relations(sync.project_id) - transcripts = self.pm.db.list_project_transcripts(sync.project_id) + entities = self.pm.db.list_project_entities(sync.project_id) + relations = self.pm.db.list_project_relations(sync.project_id) + transcripts = self.pm.db.list_project_transcripts(sync.project_id) - export_data = { + export_data = { "project": { "id": project.id, "name": project.name, @@ -1367,26 +1367,26 @@ class WebDAVSyncManager: } # 上传 JSON 文件 - json_content = json.dumps(export_data, ensure_ascii=False, indent=2) - json_path = f"{remote_project_path}/project_export.json" + json_content = json.dumps(export_data, ensure_ascii = False, indent = 2) + json_path = f"{remote_project_path}/project_export.json" # 使用临时文件上传 import tempfile - with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: + with tempfile.NamedTemporaryFile(mode = "w", suffix = ".json", delete = False) as f: f.write(json_content) - temp_path = f.name + temp_path = f.name client.upload_file(temp_path, json_path) os.unlink(temp_path) # 更新同步状态 - now = datetime.now().isoformat() - conn = self.pm.db.get_conn() + now = datetime.now().isoformat() + conn = self.pm.db.get_conn() conn.execute( """UPDATE webdav_syncs - SET last_sync_at = ?, last_sync_status = ?, sync_count = sync_count + 1 - WHERE id = ?""", + SET last_sync_at = ?, last_sync_status = ?, sync_count = sync_count + 1 + WHERE id = ?""", (now, "success", sync.id), ) conn.commit() @@ -1402,11 +1402,11 @@ class WebDAVSyncManager: except Exception as e: # 更新失败状态 - conn = self.pm.db.get_conn() + conn = self.pm.db.get_conn() conn.execute( """UPDATE webdav_syncs - SET last_sync_status = ?, last_sync_error = ? - WHERE id = ?""", + SET last_sync_status = ?, last_sync_error = ? + WHERE id = ?""", ("failed", str(e), sync.id), ) conn.commit() @@ -1416,12 +1416,12 @@ class WebDAVSyncManager: # Singleton instance -_plugin_manager = None +_plugin_manager = None -def get_plugin_manager(db_manager=None) -> None: +def get_plugin_manager(db_manager = None) -> None: """获取 PluginManager 单例""" global _plugin_manager if _plugin_manager is None: - _plugin_manager = PluginManager(db_manager) + _plugin_manager = PluginManager(db_manager) return _plugin_manager diff --git a/backend/rate_limiter.py b/backend/rate_limiter.py index 58ce2ec..ea1ea8e 100644 --- a/backend/rate_limiter.py +++ b/backend/rate_limiter.py @@ -17,9 +17,9 @@ from functools import wraps class RateLimitConfig: """限流配置""" - requests_per_minute: int = 60 - burst_size: int = 10 # 突发请求数 - window_size: int = 60 # 窗口大小(秒) + requests_per_minute: int = 60 + burst_size: int = 10 # 突发请求数 + window_size: int = 60 # 窗口大小(秒) @dataclass @@ -35,16 +35,16 @@ class RateLimitInfo: class SlidingWindowCounter: """滑动窗口计数器""" - def __init__(self, window_size: int = 60): - self.window_size = window_size - self.requests: dict[int, int] = defaultdict(int) # 秒级计数 - self._lock = asyncio.Lock() - self._cleanup_lock = asyncio.Lock() + def __init__(self, window_size: int = 60) -> None: + self.window_size = window_size + self.requests: dict[int, int] = defaultdict(int) # 秒级计数 + self._lock = asyncio.Lock() + self._cleanup_lock = asyncio.Lock() async def add_request(self) -> int: """添加请求,返回当前窗口内的请求数""" async with self._lock: - now = int(time.time()) + now = int(time.time()) self.requests[now] += 1 self._cleanup_old(now) return sum(self.requests.values()) @@ -52,14 +52,14 @@ class SlidingWindowCounter: async def get_count(self) -> int: """获取当前窗口内的请求数""" async with self._lock: - now = int(time.time()) + now = int(time.time()) self._cleanup_old(now) return sum(self.requests.values()) def _cleanup_old(self, now: int) -> None: """清理过期的请求记录 - 使用独立锁避免竞态条件""" - cutoff = now - self.window_size - old_keys = [k for k in list(self.requests.keys()) if k < cutoff] + cutoff = now - self.window_size + old_keys = [k for k in list(self.requests.keys()) if k < cutoff] for k in old_keys: self.requests.pop(k, None) @@ -69,13 +69,13 @@ class RateLimiter: def __init__(self) -> None: # key -> SlidingWindowCounter - self.counters: dict[str, SlidingWindowCounter] = {} + self.counters: dict[str, SlidingWindowCounter] = {} # key -> RateLimitConfig - self.configs: dict[str, RateLimitConfig] = {} - self._lock = asyncio.Lock() - self._cleanup_lock = asyncio.Lock() + self.configs: dict[str, RateLimitConfig] = {} + self._lock = asyncio.Lock() + self._cleanup_lock = asyncio.Lock() - async def is_allowed(self, key: str, config: RateLimitConfig | None = None) -> RateLimitInfo: + async def is_allowed(self, key: str, config: RateLimitConfig | None = None) -> RateLimitInfo: """ 检查是否允许请求 @@ -87,70 +87,70 @@ class RateLimiter: RateLimitInfo """ if config is None: - config = RateLimitConfig() + config = RateLimitConfig() async with self._lock: if key not in self.counters: - self.counters[key] = SlidingWindowCounter(config.window_size) - self.configs[key] = config + self.counters[key] = SlidingWindowCounter(config.window_size) + self.configs[key] = config - counter = self.counters[key] - stored_config = self.configs.get(key, config) + counter = self.counters[key] + stored_config = self.configs.get(key, config) # 获取当前计数 - current_count = await counter.get_count() + current_count = await counter.get_count() # 计算剩余配额 - remaining = max(0, stored_config.requests_per_minute - current_count) + remaining = max(0, stored_config.requests_per_minute - current_count) # 计算重置时间 - now = int(time.time()) - reset_time = now + stored_config.window_size + now = int(time.time()) + reset_time = now + stored_config.window_size # 检查是否超过限制 if current_count >= stored_config.requests_per_minute: return RateLimitInfo( - allowed=False, - remaining=0, - reset_time=reset_time, - retry_after=stored_config.window_size, + allowed = False, + remaining = 0, + reset_time = reset_time, + retry_after = stored_config.window_size, ) # 允许请求,增加计数 await counter.add_request() return RateLimitInfo( - allowed=True, remaining=remaining - 1, reset_time=reset_time, retry_after=0 + allowed = True, remaining = remaining - 1, reset_time = reset_time, retry_after = 0 ) async def get_limit_info(self, key: str) -> RateLimitInfo: """获取限流信息(不增加计数)""" if key not in self.counters: - config = RateLimitConfig() + config = RateLimitConfig() return RateLimitInfo( - allowed=True, - remaining=config.requests_per_minute, - reset_time=int(time.time()) + config.window_size, - retry_after=0, + allowed = True, + remaining = config.requests_per_minute, + reset_time = int(time.time()) + config.window_size, + retry_after = 0, ) - counter = self.counters[key] - config = self.configs.get(key, RateLimitConfig()) + counter = self.counters[key] + config = self.configs.get(key, RateLimitConfig()) - current_count = await counter.get_count() - remaining = max(0, config.requests_per_minute - current_count) - reset_time = int(time.time()) + config.window_size + current_count = await counter.get_count() + remaining = max(0, config.requests_per_minute - current_count) + reset_time = int(time.time()) + config.window_size return RateLimitInfo( - allowed=current_count < config.requests_per_minute, - remaining=remaining, - reset_time=reset_time, - retry_after=max(0, config.window_size) + allowed = current_count < config.requests_per_minute, + remaining = remaining, + reset_time = reset_time, + retry_after = max(0, config.window_size) if current_count >= config.requests_per_minute else 0, ) - def reset(self, key: str | None = None) -> None: + def reset(self, key: str | None = None) -> None: """重置限流计数器""" if key: self.counters.pop(key, None) @@ -161,21 +161,21 @@ class RateLimiter: # 全局限流器实例 -_rate_limiter: RateLimiter | None = None +_rate_limiter: RateLimiter | None = None def get_rate_limiter() -> RateLimiter: """获取限流器实例""" global _rate_limiter if _rate_limiter is None: - _rate_limiter = RateLimiter() + _rate_limiter = RateLimiter() return _rate_limiter # 限流装饰器(用于函数级别限流) -def rate_limit(requests_per_minute: int = 60, key_func: Callable | None = None) -> None: +def rate_limit(requests_per_minute: int = 60, key_func: Callable | None = None) -> None: """ 限流装饰器 @@ -184,14 +184,14 @@ def rate_limit(requests_per_minute: int = 60, key_func: Callable | None = None) key_func: 生成限流键的函数,默认为 None(使用函数名) """ - def decorator(func): - limiter = get_rate_limiter() - config = RateLimitConfig(requests_per_minute=requests_per_minute) + def decorator(func) -> None: + limiter = get_rate_limiter() + config = RateLimitConfig(requests_per_minute = requests_per_minute) @wraps(func) - async def async_wrapper(*args, **kwargs): - key = key_func(*args, **kwargs) if key_func else func.__name__ - info = await limiter.is_allowed(key, config) + async def async_wrapper(*args, **kwargs) -> None: + key = key_func(*args, **kwargs) if key_func else func.__name__ + info = await limiter.is_allowed(key, config) if not info.allowed: raise RateLimitExceeded( @@ -201,10 +201,10 @@ def rate_limit(requests_per_minute: int = 60, key_func: Callable | None = None) return await func(*args, **kwargs) @wraps(func) - def sync_wrapper(*args, **kwargs): - key = key_func(*args, **kwargs) if key_func else func.__name__ + def sync_wrapper(*args, **kwargs) -> None: + key = key_func(*args, **kwargs) if key_func else func.__name__ # 同步版本使用 asyncio.run - info = asyncio.run(limiter.is_allowed(key, config)) + info = asyncio.run(limiter.is_allowed(key, config)) if not info.allowed: raise RateLimitExceeded( diff --git a/backend/search_manager.py b/backend/search_manager.py index 45e2343..641bb6c 100644 --- a/backend/search_manager.py +++ b/backend/search_manager.py @@ -23,9 +23,9 @@ from enum import Enum class SearchOperator(Enum): """搜索操作符""" - AND = "AND" - OR = "OR" - NOT = "NOT" + AND = "AND" + OR = "OR" + NOT = "NOT" # 尝试导入 sentence-transformers 用于语义搜索 @@ -33,9 +33,9 @@ try: from sentence_transformers import SentenceTransformer from sklearn.metrics.pairwise import cosine_similarity - SENTENCE_TRANSFORMERS_AVAILABLE = True + SENTENCE_TRANSFORMERS_AVAILABLE = True except ImportError: - SENTENCE_TRANSFORMERS_AVAILABLE = False + SENTENCE_TRANSFORMERS_AVAILABLE = False # ==================== 数据模型 ==================== @@ -49,8 +49,8 @@ class SearchResult: content_type: str # transcript, entity, relation project_id: str score: float - highlights: list[tuple[int, int]] = field(default_factory=list) # 高亮位置 - metadata: dict = field(default_factory=dict) + highlights: list[tuple[int, int]] = field(default_factory = list) # 高亮位置 + metadata: dict = field(default_factory = dict) def to_dict(self) -> dict: return { @@ -73,11 +73,11 @@ class SemanticSearchResult: content_type: str project_id: str similarity: float - embedding: list[float] | None = None - metadata: dict = field(default_factory=dict) + embedding: list[float] | None = None + metadata: dict = field(default_factory = dict) def to_dict(self) -> dict: - result = { + result = { "id": self.id, "content": self.content[:500] + "..." if len(self.content) > 500 else self.content, "content_type": self.content_type, @@ -86,7 +86,7 @@ class SemanticSearchResult: "metadata": self.metadata, } if self.embedding: - result["embedding_dim"] = len(self.embedding) + result["embedding_dim"] = len(self.embedding) return result @@ -132,7 +132,7 @@ class KnowledgeGap: severity: str # high, medium, low suggestions: list[str] related_entities: list[str] - metadata: dict = field(default_factory=dict) + metadata: dict = field(default_factory = dict) def to_dict(self) -> dict: return { @@ -189,19 +189,19 @@ class FullTextSearch: - 支持布尔搜索(AND/OR/NOT) """ - def __init__(self, db_path: str = "insightflow.db"): - self.db_path = db_path + def __init__(self, db_path: str = "insightflow.db") -> None: + self.db_path = db_path self._init_search_tables() def _get_conn(self) -> sqlite3.Connection: """获取数据库连接""" - conn = sqlite3.connect(self.db_path) - conn.row_factory = sqlite3.Row + conn = sqlite3.connect(self.db_path) + conn.row_factory = sqlite3.Row return conn def _init_search_tables(self) -> None: """初始化搜索相关表""" - conn = self._get_conn() + conn = self._get_conn() # 搜索索引表 conn.execute(""" @@ -251,25 +251,25 @@ class FullTextSearch: 实际生产环境可以使用 jieba 等分词工具 """ # 清理文本 - text = text.lower() + text = text.lower() # 提取中文字符、英文单词和数字 - tokens = re.findall(r"[\u4e00-\u9fa5]+|[a-z]+|\d+", text) + tokens = re.findall(r"[\u4e00-\u9fa5]+|[a-z]+|\d+", text) return tokens def _extract_positions(self, text: str, tokens: list[str]) -> dict[str, list[int]]: """提取每个词在文本中的位置""" - positions = defaultdict(list) - text_lower = text.lower() + positions = defaultdict(list) + text_lower = text.lower() for token in tokens: # 查找所有出现位置 - start = 0 + start = 0 while True: - pos = text_lower.find(token, start) + pos = text_lower.find(token, start) if pos == -1: break positions[token].append(pos) - start = pos + 1 + start = pos + 1 return dict(positions) @@ -287,24 +287,24 @@ class FullTextSearch: bool: 是否成功 """ try: - conn = self._get_conn() + conn = self._get_conn() # 分词 - tokens = self._tokenize(text) + tokens = self._tokenize(text) if not tokens: conn.close() return False # 提取位置信息 - token_positions = self._extract_positions(text, tokens) + token_positions = self._extract_positions(text, tokens) # 计算词频 - token_freq = defaultdict(int) + token_freq = defaultdict(int) for token in tokens: token_freq[token] += 1 - index_id = hashlib.md5(f"{content_id}:{content_type}".encode()).hexdigest()[:16] - now = datetime.now().isoformat() + index_id = hashlib.md5(f"{content_id}:{content_type}".encode()).hexdigest()[:16] + now = datetime.now().isoformat() # 保存索引 conn.execute( @@ -318,8 +318,8 @@ class FullTextSearch: content_id, content_type, project_id, - json.dumps(tokens, ensure_ascii=False), - json.dumps(token_positions, ensure_ascii=False), + json.dumps(tokens, ensure_ascii = False), + json.dumps(token_positions, ensure_ascii = False), now, now, ), @@ -327,7 +327,7 @@ class FullTextSearch: # 保存词频统计 for token, freq in token_freq.items(): - positions = token_positions.get(token, []) + positions = token_positions.get(token, []) conn.execute( """ INSERT OR REPLACE INTO search_term_freq @@ -340,7 +340,7 @@ class FullTextSearch: content_type, project_id, freq, - json.dumps(positions, ensure_ascii=False), + json.dumps(positions, ensure_ascii = False), ), ) @@ -355,10 +355,10 @@ class FullTextSearch: def search( self, query: str, - project_id: str | None = None, - content_types: list[str] | None = None, - limit: int = 20, - offset: int = 0, + project_id: str | None = None, + content_types: list[str] | None = None, + limit: int = 20, + offset: int = 0, ) -> list[SearchResult]: """ 全文搜索 @@ -374,16 +374,16 @@ class FullTextSearch: List[SearchResult]: 搜索结果列表 """ # 解析布尔查询 - parsed_query = self._parse_boolean_query(query) + parsed_query = self._parse_boolean_query(query) # 执行搜索 - results = self._execute_boolean_search(parsed_query, project_id, content_types) + results = self._execute_boolean_search(parsed_query, project_id, content_types) # 计算相关性分数 - scored_results = self._score_results(results, parsed_query) + scored_results = self._score_results(results, parsed_query) # 排序和分页 - scored_results.sort(key=lambda x: x.score, reverse=True) + scored_results.sort(key = lambda x: x.score, reverse = True) return scored_results[offset : offset + limit] @@ -397,138 +397,138 @@ class FullTextSearch: - NOT: NOT 词1 或 词1 -词2 - 短语: "精确短语" """ - query = query.strip() + query = query.strip() # 提取短语(引号内的内容) - phrases = re.findall(r'"([^"]+)"', query) - query_without_phrases = re.sub(r'"[^"]+"', "", query) + phrases = re.findall(r'"([^"]+)"', query) + query_without_phrases = re.sub(r'"[^"]+"', "", query) # 解析布尔操作 - and_terms = [] - or_terms = [] - not_terms = [] + and_terms = [] + or_terms = [] + not_terms = [] # 处理 NOT - not_pattern = r"(?:NOT\s+|\-)(\w+)" - not_matches = re.findall(not_pattern, query_without_phrases, re.IGNORECASE) + not_pattern = r"(?:NOT\s+|\-)(\w+)" + not_matches = re.findall(not_pattern, query_without_phrases, re.IGNORECASE) not_terms.extend(not_matches) - query_without_phrases = re.sub(not_pattern, "", query_without_phrases, flags=re.IGNORECASE) + query_without_phrases = re.sub(not_pattern, "", query_without_phrases, flags = re.IGNORECASE) # 处理 OR - or_parts = re.split(r"\s+OR\s+", query_without_phrases, flags=re.IGNORECASE) + or_parts = re.split(r"\s+OR\s+", query_without_phrases, flags = re.IGNORECASE) if len(or_parts) > 1: - or_terms = [p.strip() for p in or_parts[1:] if p.strip()] - query_without_phrases = or_parts[0] + or_terms = [p.strip() for p in or_parts[1:] if p.strip()] + query_without_phrases = or_parts[0] # 剩余的作为 AND 条件 - and_terms = [t.strip() for t in query_without_phrases.split() if t.strip()] + and_terms = [t.strip() for t in query_without_phrases.split() if t.strip()] return {"and": and_terms + phrases, "or": or_terms, "not": not_terms, "phrases": phrases} def _execute_boolean_search( self, parsed_query: dict, - project_id: str | None = None, - content_types: list[str] | None = None, + project_id: str | None = None, + content_types: list[str] | None = None, ) -> list[dict]: """执行布尔搜索""" - conn = self._get_conn() + conn = self._get_conn() # 构建基础查询 - base_where = [] - params = [] + base_where = [] + params = [] if project_id: - base_where.append("project_id = ?") + base_where.append("project_id = ?") params.append(project_id) if content_types: - placeholders = ",".join(["?" for _ in content_types]) + placeholders = ", ".join(["?" for _ in content_types]) base_where.append(f"content_type IN ({placeholders})") params.extend(content_types) - base_where_str = " AND ".join(base_where) if base_where else "1=1" + base_where_str = " AND ".join(base_where) if base_where else "1 = 1" # 获取候选结果 - candidates = set() + candidates = set() # 处理 AND 条件 if parsed_query["and"]: for term in parsed_query["and"]: - term_results = conn.execute( + term_results = conn.execute( f""" SELECT content_id, content_type, project_id, frequency, positions FROM search_term_freq - WHERE term = ? AND {base_where_str} + WHERE term = ? AND {base_where_str} """, [term] + params, ).fetchall() - term_contents = {(r["content_id"], r["content_type"]) for r in term_results} + term_contents = {(r["content_id"], r["content_type"]) for r in term_results} if not candidates: - candidates = term_contents + candidates = term_contents else: candidates &= term_contents # 交集 # 处理 OR 条件 if parsed_query["or"]: for term in parsed_query["or"]: - term_results = conn.execute( + term_results = conn.execute( f""" SELECT content_id, content_type, project_id, frequency, positions FROM search_term_freq - WHERE term = ? AND {base_where_str} + WHERE term = ? AND {base_where_str} """, [term] + params, ).fetchall() - term_contents = {(r["content_id"], r["content_type"]) for r in term_results} + term_contents = {(r["content_id"], r["content_type"]) for r in term_results} candidates |= term_contents # 并集 # 如果没有 AND 和 OR,但有 phrases,使用 phrases if not candidates and parsed_query["phrases"]: for phrase in parsed_query["phrases"]: - phrase_tokens = self._tokenize(phrase) + phrase_tokens = self._tokenize(phrase) if phrase_tokens: # 查找包含所有短语的文档 for token in phrase_tokens: - term_results = conn.execute( + term_results = conn.execute( f""" SELECT content_id, content_type, project_id, frequency, positions FROM search_term_freq - WHERE term = ? AND {base_where_str} + WHERE term = ? AND {base_where_str} """, [token] + params, ).fetchall() - term_contents = {(r["content_id"], r["content_type"]) for r in term_results} + term_contents = {(r["content_id"], r["content_type"]) for r in term_results} if not candidates: - candidates = term_contents + candidates = term_contents else: candidates &= term_contents # 处理 NOT 条件(排除) if parsed_query["not"]: for term in parsed_query["not"]: - term_results = conn.execute( + term_results = conn.execute( f""" SELECT content_id, content_type FROM search_term_freq - WHERE term = ? AND {base_where_str} + WHERE term = ? AND {base_where_str} """, [term] + params, ).fetchall() - term_contents = {(r["content_id"], r["content_type"]) for r in term_results} + term_contents = {(r["content_id"], r["content_type"]) for r in term_results} candidates -= term_contents # 差集 # 获取完整内容 - results = [] + results = [] for content_id, content_type in candidates: # 获取原始内容 - content = self._get_content_by_id(conn, content_id, content_type) + content = self._get_content_by_id(conn, content_id, content_type) if content: results.append( { @@ -550,28 +550,28 @@ class FullTextSearch: """根据ID获取内容""" try: if content_type == "transcript": - row = conn.execute( - "SELECT full_text FROM transcripts WHERE id = ?", (content_id,) + row = conn.execute( + "SELECT full_text FROM transcripts WHERE id = ?", (content_id, ) ).fetchone() return row["full_text"] if row else None elif content_type == "entity": - row = conn.execute( - "SELECT name, definition FROM entities WHERE id = ?", (content_id,) + row = conn.execute( + "SELECT name, definition FROM entities WHERE id = ?", (content_id, ) ).fetchone() if row: return f"{row['name']} {row['definition'] or ''}" return None elif content_type == "relation": - row = conn.execute( + row = conn.execute( """SELECT r.relation_type, r.evidence, e1.name as source_name, e2.name as target_name FROM entity_relations r - JOIN entities e1 ON r.source_entity_id = e1.id - JOIN entities e2 ON r.target_entity_id = e2.id - WHERE r.id = ?""", - (content_id,), + JOIN entities e1 ON r.source_entity_id = e1.id + JOIN entities e2 ON r.target_entity_id = e2.id + WHERE r.id = ?""", + (content_id, ), ).fetchone() if row: return f"{row['source_name']} {row['relation_type']} {row['target_name']} {row['evidence'] or ''}" @@ -588,16 +588,16 @@ class FullTextSearch: """获取内容所属的项目ID""" try: if content_type == "transcript": - row = conn.execute( - "SELECT project_id FROM transcripts WHERE id = ?", (content_id,) + row = conn.execute( + "SELECT project_id FROM transcripts WHERE id = ?", (content_id, ) ).fetchone() elif content_type == "entity": - row = conn.execute( - "SELECT project_id FROM entities WHERE id = ?", (content_id,) + row = conn.execute( + "SELECT project_id FROM entities WHERE id = ?", (content_id, ) ).fetchone() elif content_type == "relation": - row = conn.execute( - "SELECT project_id FROM entity_relations WHERE id = ?", (content_id,) + row = conn.execute( + "SELECT project_id FROM entity_relations WHERE id = ?", (content_id, ) ).fetchone() else: return None @@ -608,39 +608,39 @@ class FullTextSearch: def _score_results(self, results: list[dict], parsed_query: dict) -> list[SearchResult]: """计算搜索结果的相关性分数""" - scored = [] - all_terms = parsed_query["and"] + parsed_query["or"] + parsed_query["phrases"] + scored = [] + all_terms = parsed_query["and"] + parsed_query["or"] + parsed_query["phrases"] for result in results: - content = result["content"].lower() + content = result["content"].lower() # 基础分数 - score = 0.0 - highlights = [] + score = 0.0 + highlights = [] # 计算每个词的匹配分数 for term in all_terms: - term_lower = term.lower() - count = content.count(term_lower) + term_lower = term.lower() + count = content.count(term_lower) if count > 0: # TF 分数(词频) - tf_score = math.log(1 + count) + tf_score = math.log(1 + count) # 位置加分(标题/开头匹配分数更高) - position_bonus = 0 - first_pos = content.find(term_lower) + position_bonus = 0 + first_pos = content.find(term_lower) if first_pos != -1: if first_pos < 50: # 开头50个字符 - position_bonus = 2.0 + position_bonus = 2.0 elif first_pos < 200: # 开头200个字符 - position_bonus = 1.0 + position_bonus = 1.0 # 记录高亮位置 - start = first_pos + start = first_pos while start != -1: highlights.append((start, start + len(term))) - start = content.find(term_lower, start + 1) + start = content.find(term_lower, start + 1) score += tf_score + position_bonus @@ -650,23 +650,23 @@ class FullTextSearch: score *= 1.5 # 短语匹配加权 # 归一化分数 - score = min(score / max(len(all_terms), 1), 10.0) + score = min(score / max(len(all_terms), 1), 10.0) scored.append( SearchResult( - id=result["id"], - content=result["content"], - content_type=result["content_type"], - project_id=result["project_id"], - score=round(score, 4), - highlights=highlights[:10], # 限制高亮数量 - metadata={}, + id = result["id"], + content = result["content"], + content_type = result["content_type"], + project_id = result["project_id"], + score = round(score, 4), + highlights = highlights[:10], # 限制高亮数量 + metadata = {}, ) ) return scored - def highlight_text(self, text: str, query: str, max_length: int = 300) -> str: + def highlight_text(self, text: str, query: str, max_length: int = 300) -> str: """ 高亮文本中的关键词 @@ -678,47 +678,47 @@ class FullTextSearch: Returns: str: 带高亮标记的文本 """ - parsed = self._parse_boolean_query(query) - all_terms = parsed["and"] + parsed["or"] + parsed["phrases"] + parsed = self._parse_boolean_query(query) + all_terms = parsed["and"] + parsed["or"] + parsed["phrases"] # 找到第一个匹配位置 - first_match = len(text) + first_match = len(text) for term in all_terms: - pos = text.lower().find(term.lower()) + pos = text.lower().find(term.lower()) if pos != -1 and pos < first_match: - first_match = pos + first_match = pos # 截取上下文 - start = max(0, first_match - 100) - end = min(len(text), start + max_length) - snippet = text[start:end] + start = max(0, first_match - 100) + end = min(len(text), start + max_length) + snippet = text[start:end] if start > 0: - snippet = "..." + snippet + snippet = "..." + snippet if end < len(text): - snippet = snippet + "..." + snippet = snippet + "..." # 添加高亮标记 - for term in sorted(all_terms, key=len, reverse=True): # 长的先替换 - pattern = re.compile(re.escape(term), re.IGNORECASE) - snippet = pattern.sub(f"**{term}**", snippet) + for term in sorted(all_terms, key = len, reverse = True): # 长的先替换 + pattern = re.compile(re.escape(term), re.IGNORECASE) + snippet = pattern.sub(f"**{term}**", snippet) return snippet def delete_index(self, content_id: str, content_type: str) -> bool: """删除内容的搜索索引""" try: - conn = self._get_conn() + conn = self._get_conn() # 删除索引 conn.execute( - "DELETE FROM search_indexes WHERE content_id = ? AND content_type = ?", + "DELETE FROM search_indexes WHERE content_id = ? AND content_type = ?", (content_id, content_type), ) # 删除词频统计 conn.execute( - "DELETE FROM search_term_freq WHERE content_id = ? AND content_type = ?", + "DELETE FROM search_term_freq WHERE content_id = ? AND content_type = ?", (content_id, content_type), ) @@ -731,14 +731,14 @@ class FullTextSearch: def reindex_project(self, project_id: str) -> dict: """重新索引整个项目""" - conn = self._get_conn() - stats = {"transcripts": 0, "entities": 0, "relations": 0, "errors": 0} + conn = self._get_conn() + stats = {"transcripts": 0, "entities": 0, "relations": 0, "errors": 0} try: # 索引转录文本 - transcripts = conn.execute( - "SELECT id, project_id, full_text FROM transcripts WHERE project_id = ?", - (project_id,), + transcripts = conn.execute( + "SELECT id, project_id, full_text FROM transcripts WHERE project_id = ?", + (project_id, ), ).fetchall() for t in transcripts: @@ -749,31 +749,31 @@ class FullTextSearch: stats["errors"] += 1 # 索引实体 - entities = conn.execute( - "SELECT id, project_id, name, definition FROM entities WHERE project_id = ?", - (project_id,), + entities = conn.execute( + "SELECT id, project_id, name, definition FROM entities WHERE project_id = ?", + (project_id, ), ).fetchall() for e in entities: - text = f"{e['name']} {e['definition'] or ''}" + text = f"{e['name']} {e['definition'] or ''}" if self.index_content(e["id"], "entity", e["project_id"], text): stats["entities"] += 1 else: stats["errors"] += 1 # 索引关系 - relations = conn.execute( + relations = conn.execute( """SELECT r.id, r.project_id, r.relation_type, r.evidence, e1.name as source_name, e2.name as target_name FROM entity_relations r - JOIN entities e1 ON r.source_entity_id = e1.id - JOIN entities e2 ON r.target_entity_id = e2.id - WHERE r.project_id = ?""", - (project_id,), + JOIN entities e1 ON r.source_entity_id = e1.id + JOIN entities e2 ON r.target_entity_id = e2.id + WHERE r.project_id = ?""", + (project_id, ), ).fetchall() for r in relations: - text = f"{r['source_name']} {r['relation_type']} {r['target_name']} {r['evidence'] or ''}" + text = f"{r['source_name']} {r['relation_type']} {r['target_name']} {r['evidence'] or ''}" if self.index_content(r["id"], "relation", r["project_id"], text): stats["relations"] += 1 else: @@ -803,31 +803,31 @@ class SemanticSearch: def __init__( self, - db_path: str = "insightflow.db", - model_name: str = "paraphrase-multilingual-MiniLM-L12-v2", - ): - self.db_path = db_path - self.model_name = model_name - self.model = None + db_path: str = "insightflow.db", + model_name: str = "paraphrase-multilingual-MiniLM-L12-v2", + ) -> None: + self.db_path = db_path + self.model_name = model_name + self.model = None self._init_embedding_tables() # 延迟加载模型 if SENTENCE_TRANSFORMERS_AVAILABLE: try: - self.model = SentenceTransformer(model_name) + self.model = SentenceTransformer(model_name) print(f"语义搜索模型加载成功: {model_name}") except Exception as e: print(f"模型加载失败: {e}") def _get_conn(self) -> sqlite3.Connection: """获取数据库连接""" - conn = sqlite3.connect(self.db_path) - conn.row_factory = sqlite3.Row + conn = sqlite3.connect(self.db_path) + conn.row_factory = sqlite3.Row return conn def _init_embedding_tables(self) -> None: """初始化 embedding 相关表""" - conn = self._get_conn() + conn = self._get_conn() conn.execute(""" CREATE TABLE IF NOT EXISTS embeddings ( @@ -869,11 +869,11 @@ class SemanticSearch: try: # 截断长文本 - max_chars = 5000 + max_chars = 5000 if len(text) > max_chars: - text = text[:max_chars] + text = text[:max_chars] - embedding = self.model.encode(text, convert_to_list=True) + embedding = self.model.encode(text, convert_to_list = True) return embedding except Exception as e: print(f"生成 embedding 失败: {e}") @@ -898,13 +898,13 @@ class SemanticSearch: return False try: - embedding = self.generate_embedding(text) + embedding = self.generate_embedding(text) if not embedding: return False - conn = self._get_conn() + conn = self._get_conn() - embedding_id = hashlib.md5(f"{content_id}:{content_type}".encode()).hexdigest()[:16] + embedding_id = hashlib.md5(f"{content_id}:{content_type}".encode()).hexdigest()[:16] conn.execute( """ @@ -934,10 +934,10 @@ class SemanticSearch: def search( self, query: str, - project_id: str | None = None, - content_types: list[str] | None = None, - top_k: int = 10, - threshold: float = 0.5, + project_id: str | None = None, + content_types: list[str] | None = None, + top_k: int = 10, + threshold: float = 0.5, ) -> list[SemanticSearchResult]: """ 语义搜索 @@ -956,28 +956,28 @@ class SemanticSearch: return [] # 生成查询的 embedding - query_embedding = self.generate_embedding(query) + query_embedding = self.generate_embedding(query) if not query_embedding: return [] # 获取候选 embedding - conn = self._get_conn() + conn = self._get_conn() - where_clauses = [] - params = [] + where_clauses = [] + params = [] if project_id: - where_clauses.append("project_id = ?") + where_clauses.append("project_id = ?") params.append(project_id) if content_types: - placeholders = ",".join(["?" for _ in content_types]) + placeholders = ", ".join(["?" for _ in content_types]) where_clauses.append(f"content_type IN ({placeholders})") params.extend(content_types) - where_str = " AND ".join(where_clauses) if where_clauses else "1=1" + where_str = " AND ".join(where_clauses) if where_clauses else "1 = 1" - rows = conn.execute( + rows = conn.execute( f""" SELECT content_id, content_type, project_id, embedding FROM embeddings @@ -989,29 +989,29 @@ class SemanticSearch: conn.close() # 计算相似度 - results = [] - query_vec = [query_embedding] + results = [] + query_vec = [query_embedding] for row in rows: try: - content_embedding = json.loads(row["embedding"]) + content_embedding = json.loads(row["embedding"]) # 计算余弦相似度 - similarity = cosine_similarity(query_vec, [content_embedding])[0][0] + similarity = cosine_similarity(query_vec, [content_embedding])[0][0] if similarity >= threshold: # 获取原始内容 - content = self._get_content_text(row["content_id"], row["content_type"]) + content = self._get_content_text(row["content_id"], row["content_type"]) results.append( SemanticSearchResult( - id=row["content_id"], - content=content or "", - content_type=row["content_type"], - project_id=row["project_id"], - similarity=float(similarity), - embedding=None, # 不返回 embedding 以节省带宽 - metadata={}, + id = row["content_id"], + content = content or "", + content_type = row["content_type"], + project_id = row["project_id"], + similarity = float(similarity), + embedding = None, # 不返回 embedding 以节省带宽 + metadata = {}, ) ) except Exception as e: @@ -1019,44 +1019,44 @@ class SemanticSearch: continue # 排序并返回 top_k - results.sort(key=lambda x: x.similarity, reverse=True) + results.sort(key = lambda x: x.similarity, reverse = True) return results[:top_k] def _get_content_text(self, content_id: str, content_type: str) -> str | None: """获取内容文本""" - conn = self._get_conn() + conn = self._get_conn() try: if content_type == "transcript": - row = conn.execute( - "SELECT full_text FROM transcripts WHERE id = ?", (content_id,) + row = conn.execute( + "SELECT full_text FROM transcripts WHERE id = ?", (content_id, ) ).fetchone() - result = row["full_text"] if row else None + result = row["full_text"] if row else None elif content_type == "entity": - row = conn.execute( - "SELECT name, definition FROM entities WHERE id = ?", (content_id,) + row = conn.execute( + "SELECT name, definition FROM entities WHERE id = ?", (content_id, ) ).fetchone() - result = f"{row['name']}: {row['definition']}" if row else None + result = f"{row['name']}: {row['definition']}" if row else None elif content_type == "relation": - row = conn.execute( + row = conn.execute( """SELECT r.relation_type, r.evidence, e1.name as source_name, e2.name as target_name FROM entity_relations r - JOIN entities e1 ON r.source_entity_id = e1.id - JOIN entities e2 ON r.target_entity_id = e2.id - WHERE r.id = ?""", - (content_id,), + JOIN entities e1 ON r.source_entity_id = e1.id + JOIN entities e2 ON r.target_entity_id = e2.id + WHERE r.id = ?""", + (content_id, ), ).fetchone() - result = ( + result = ( f"{row['source_name']} {row['relation_type']} {row['target_name']}" if row else None ) else: - result = None + result = None conn.close() return result @@ -1067,7 +1067,7 @@ class SemanticSearch: return None def find_similar_content( - self, content_id: str, content_type: str, top_k: int = 5 + self, content_id: str, content_type: str, top_k: int = 5 ) -> list[SemanticSearchResult]: """ 查找与指定内容相似的内容 @@ -1084,10 +1084,10 @@ class SemanticSearch: return [] # 获取源内容的 embedding - conn = self._get_conn() + conn = self._get_conn() - row = conn.execute( - "SELECT embedding, project_id FROM embeddings WHERE content_id = ? AND content_type = ?", + row = conn.execute( + "SELECT embedding, project_id FROM embeddings WHERE content_id = ? AND content_type = ?", (content_id, content_type), ).fetchone() @@ -1095,52 +1095,52 @@ class SemanticSearch: conn.close() return [] - source_embedding = json.loads(row["embedding"]) - project_id = row["project_id"] + source_embedding = json.loads(row["embedding"]) + project_id = row["project_id"] # 获取其他内容的 embedding - rows = conn.execute( + rows = conn.execute( """SELECT content_id, content_type, project_id, embedding FROM embeddings - WHERE project_id = ? AND (content_id != ? OR content_type != ?)""", + WHERE project_id = ? AND (content_id != ? OR content_type != ?)""", (project_id, content_id, content_type), ).fetchall() conn.close() # 计算相似度 - results = [] - source_vec = [source_embedding] + results = [] + source_vec = [source_embedding] for row in rows: try: - content_embedding = json.loads(row["embedding"]) - similarity = cosine_similarity(source_vec, [content_embedding])[0][0] + content_embedding = json.loads(row["embedding"]) + similarity = cosine_similarity(source_vec, [content_embedding])[0][0] - content = self._get_content_text(row["content_id"], row["content_type"]) + content = self._get_content_text(row["content_id"], row["content_type"]) results.append( SemanticSearchResult( - id=row["content_id"], - content=content or "", - content_type=row["content_type"], - project_id=row["project_id"], - similarity=float(similarity), - metadata={}, + id = row["content_id"], + content = content or "", + content_type = row["content_type"], + project_id = row["project_id"], + similarity = float(similarity), + metadata = {}, ) ) except (KeyError, ValueError): continue - results.sort(key=lambda x: x.similarity, reverse=True) + results.sort(key = lambda x: x.similarity, reverse = True) return results[:top_k] def delete_embedding(self, content_id: str, content_type: str) -> bool: """删除内容的 embedding""" try: - conn = self._get_conn() + conn = self._get_conn() conn.execute( - "DELETE FROM embeddings WHERE content_id = ? AND content_type = ?", + "DELETE FROM embeddings WHERE content_id = ? AND content_type = ?", (content_id, content_type), ) conn.commit() @@ -1165,17 +1165,17 @@ class EntityPathDiscovery: - 路径可视化数据生成 """ - def __init__(self, db_path: str = "insightflow.db"): - self.db_path = db_path + def __init__(self, db_path: str = "insightflow.db") -> None: + self.db_path = db_path def _get_conn(self) -> sqlite3.Connection: """获取数据库连接""" - conn = sqlite3.connect(self.db_path) - conn.row_factory = sqlite3.Row + conn = sqlite3.connect(self.db_path) + conn.row_factory = sqlite3.Row return conn def find_shortest_path( - self, source_entity_id: str, target_entity_id: str, max_depth: int = 5 + self, source_entity_id: str, target_entity_id: str, max_depth: int = 5 ) -> EntityPath | None: """ 查找两个实体之间的最短路径(BFS算法) @@ -1188,22 +1188,22 @@ class EntityPathDiscovery: Returns: Optional[EntityPath]: 最短路径 """ - conn = self._get_conn() + conn = self._get_conn() # 获取项目ID - row = conn.execute( - "SELECT project_id FROM entities WHERE id = ?", (source_entity_id,) + row = conn.execute( + "SELECT project_id FROM entities WHERE id = ?", (source_entity_id, ) ).fetchone() if not row: conn.close() return None - project_id = row["project_id"] + project_id = row["project_id"] # 验证目标实体也在同一项目 - row = conn.execute( - "SELECT 1 FROM entities WHERE id = ? AND project_id = ?", (target_entity_id, project_id) + row = conn.execute( + "SELECT 1 FROM entities WHERE id = ? AND project_id = ?", (target_entity_id, project_id) ).fetchone() if not row: @@ -1211,11 +1211,11 @@ class EntityPathDiscovery: return None # BFS - visited = {source_entity_id} - queue = [(source_entity_id, [source_entity_id])] + visited = {source_entity_id} + queue = [(source_entity_id, [source_entity_id])] while queue: - current_id, path = queue.pop(0) + current_id, path = queue.pop(0) if len(path) > max_depth + 1: continue @@ -1226,21 +1226,21 @@ class EntityPathDiscovery: return self._build_path_object(path, project_id) # 获取邻居 - neighbors = conn.execute( + neighbors = conn.execute( """ SELECT target_entity_id as neighbor_id, relation_type, evidence FROM entity_relations - WHERE source_entity_id = ? AND project_id = ? + WHERE source_entity_id = ? AND project_id = ? UNION SELECT source_entity_id as neighbor_id, relation_type, evidence FROM entity_relations - WHERE target_entity_id = ? AND project_id = ? + WHERE target_entity_id = ? AND project_id = ? """, (current_id, project_id, current_id, project_id), ).fetchall() for neighbor in neighbors: - neighbor_id = neighbor["neighbor_id"] + neighbor_id = neighbor["neighbor_id"] if neighbor_id not in visited: visited.add(neighbor_id) queue.append((neighbor_id, path + [neighbor_id])) @@ -1249,7 +1249,7 @@ class EntityPathDiscovery: return None def find_all_paths( - self, source_entity_id: str, target_entity_id: str, max_depth: int = 4, max_paths: int = 10 + self, source_entity_id: str, target_entity_id: str, max_depth: int = 4, max_paths: int = 10 ) -> list[EntityPath]: """ 查找两个实体之间的所有路径(限制数量和深度) @@ -1263,22 +1263,22 @@ class EntityPathDiscovery: Returns: List[EntityPath]: 路径列表 """ - conn = self._get_conn() + conn = self._get_conn() # 获取项目ID - row = conn.execute( - "SELECT project_id FROM entities WHERE id = ?", (source_entity_id,) + row = conn.execute( + "SELECT project_id FROM entities WHERE id = ?", (source_entity_id, ) ).fetchone() if not row: conn.close() return [] - project_id = row["project_id"] + project_id = row["project_id"] - paths = [] + paths = [] - def dfs(current_id: str, target_id: str, path: list[str], visited: set[str], depth: int): + def dfs(current_id: str, target_id: str, path: list[str], visited: set[str], depth: int) -> None: if depth > max_depth: return @@ -1287,21 +1287,21 @@ class EntityPathDiscovery: return # 获取邻居 - neighbors = conn.execute( + neighbors = conn.execute( """ SELECT target_entity_id as neighbor_id FROM entity_relations - WHERE source_entity_id = ? AND project_id = ? + WHERE source_entity_id = ? AND project_id = ? UNION SELECT source_entity_id as neighbor_id FROM entity_relations - WHERE target_entity_id = ? AND project_id = ? + WHERE target_entity_id = ? AND project_id = ? """, (current_id, project_id, current_id, project_id), ).fetchall() for neighbor in neighbors: - neighbor_id = neighbor["neighbor_id"] + neighbor_id = neighbor["neighbor_id"] if neighbor_id not in visited and len(paths) < max_paths: visited.add(neighbor_id) path.append(neighbor_id) @@ -1309,7 +1309,7 @@ class EntityPathDiscovery: path.pop() visited.remove(neighbor_id) - visited = {source_entity_id} + visited = {source_entity_id} dfs(source_entity_id, target_entity_id, [source_entity_id], visited, 0) conn.close() @@ -1319,30 +1319,30 @@ class EntityPathDiscovery: def _build_path_object(self, entity_ids: list[str], project_id: str) -> EntityPath: """构建路径对象""" - conn = self._get_conn() + conn = self._get_conn() # 获取实体信息 - nodes = [] + nodes = [] for entity_id in entity_ids: - row = conn.execute( - "SELECT id, name, type FROM entities WHERE id = ?", (entity_id,) + row = conn.execute( + "SELECT id, name, type FROM entities WHERE id = ?", (entity_id, ) ).fetchone() if row: nodes.append({"id": row["id"], "name": row["name"], "type": row["type"]}) # 获取边信息 - edges = [] + edges = [] for i in range(len(entity_ids) - 1): - source_id = entity_ids[i] - target_id = entity_ids[i + 1] + source_id = entity_ids[i] + target_id = entity_ids[i + 1] - row = conn.execute( + row = conn.execute( """ SELECT id, relation_type, evidence FROM entity_relations - WHERE ((source_entity_id = ? AND target_entity_id = ?) - OR (source_entity_id = ? AND target_entity_id = ?)) - AND project_id = ? + WHERE ((source_entity_id = ? AND target_entity_id = ?) + OR (source_entity_id = ? AND target_entity_id = ?)) + AND project_id = ? """, (source_id, target_id, target_id, source_id, project_id), ).fetchone() @@ -1361,26 +1361,26 @@ class EntityPathDiscovery: conn.close() # 生成路径描述 - node_names = [n["name"] for n in nodes] - path_desc = " → ".join(node_names) + node_names = [n["name"] for n in nodes] + path_desc = " → ".join(node_names) # 计算置信度(基于路径长度和关系数量) - confidence = 1.0 / (len(entity_ids) - 1) if len(entity_ids) > 1 else 1.0 + confidence = 1.0 / (len(entity_ids) - 1) if len(entity_ids) > 1 else 1.0 return EntityPath( - path_id=f"path_{entity_ids[0]}_{entity_ids[-1]}_{hash(tuple(entity_ids))}", - source_entity_id=entity_ids[0], - source_entity_name=nodes[0]["name"] if nodes else "", - target_entity_id=entity_ids[-1], - target_entity_name=nodes[-1]["name"] if nodes else "", - path_length=len(entity_ids) - 1, - nodes=nodes, - edges=edges, - confidence=round(confidence, 4), - path_description=path_desc, + path_id = f"path_{entity_ids[0]}_{entity_ids[-1]}_{hash(tuple(entity_ids))}", + source_entity_id = entity_ids[0], + source_entity_name = nodes[0]["name"] if nodes else "", + target_entity_id = entity_ids[-1], + target_entity_name = nodes[-1]["name"] if nodes else "", + path_length = len(entity_ids) - 1, + nodes = nodes, + edges = edges, + confidence = round(confidence, 4), + path_description = path_desc, ) - def find_multi_hop_relations(self, entity_id: str, max_hops: int = 3) -> list[dict]: + def find_multi_hop_relations(self, entity_id: str, max_hops: int = 3) -> list[dict]: """ 查找实体的多跳关系 @@ -1391,58 +1391,58 @@ class EntityPathDiscovery: Returns: List[Dict]: 多跳关系列表 """ - conn = self._get_conn() + conn = self._get_conn() # 获取项目ID - row = conn.execute( - "SELECT project_id, name FROM entities WHERE id = ?", (entity_id,) + row = conn.execute( + "SELECT project_id, name FROM entities WHERE id = ?", (entity_id, ) ).fetchone() if not row: conn.close() return [] - project_id = row["project_id"] + project_id = row["project_id"] row["name"] # BFS 收集多跳关系 - visited = {entity_id: 0} - queue = [(entity_id, 0)] - relations = [] + visited = {entity_id: 0} + queue = [(entity_id, 0)] + relations = [] while queue: - current_id, depth = queue.pop(0) + current_id, depth = queue.pop(0) if depth >= max_hops: continue # 获取邻居 - neighbors = conn.execute( + neighbors = conn.execute( """ SELECT CASE - WHEN source_entity_id = ? THEN target_entity_id + WHEN source_entity_id = ? THEN target_entity_id ELSE source_entity_id END as neighbor_id, relation_type, evidence FROM entity_relations - WHERE (source_entity_id = ? OR target_entity_id = ?) - AND project_id = ? + WHERE (source_entity_id = ? OR target_entity_id = ?) + AND project_id = ? """, (current_id, current_id, current_id, project_id), ).fetchall() for neighbor in neighbors: - neighbor_id = neighbor["neighbor_id"] + neighbor_id = neighbor["neighbor_id"] if neighbor_id not in visited: - visited[neighbor_id] = depth + 1 + visited[neighbor_id] = depth + 1 queue.append((neighbor_id, depth + 1)) # 获取邻居信息 - neighbor_info = conn.execute( - "SELECT name, type FROM entities WHERE id = ?", (neighbor_id,) + neighbor_info = conn.execute( + "SELECT name, type FROM entities WHERE id = ?", (neighbor_id, ) ).fetchone() if neighbor_info: @@ -1463,7 +1463,7 @@ class EntityPathDiscovery: conn.close() # 按跳数排序 - relations.sort(key=lambda x: x["hops"]) + relations.sort(key = lambda x: x["hops"]) return relations def _get_path_to_entity( @@ -1471,11 +1471,11 @@ class EntityPathDiscovery: ) -> list[str]: """获取从源实体到目标实体的路径(简化版)""" # BFS 找路径 - visited = {source_id} - queue = [(source_id, [source_id])] + visited = {source_id} + queue = [(source_id, [source_id])] while queue: - current, path = queue.pop(0) + current, path = queue.pop(0) if current == target_id: return path @@ -1483,22 +1483,22 @@ class EntityPathDiscovery: if len(path) > 5: # 限制路径长度 continue - neighbors = conn.execute( + neighbors = conn.execute( """ SELECT CASE - WHEN source_entity_id = ? THEN target_entity_id + WHEN source_entity_id = ? THEN target_entity_id ELSE source_entity_id END as neighbor_id FROM entity_relations - WHERE (source_entity_id = ? OR target_entity_id = ?) - AND project_id = ? + WHERE (source_entity_id = ? OR target_entity_id = ?) + AND project_id = ? """, (current, current, current, project_id), ).fetchall() for neighbor in neighbors: - neighbor_id = neighbor["neighbor_id"] + neighbor_id = neighbor["neighbor_id"] if neighbor_id not in visited: visited.add(neighbor_id) queue.append((neighbor_id, path + [neighbor_id])) @@ -1516,7 +1516,7 @@ class EntityPathDiscovery: Dict: D3.js 可视化数据格式 """ # 节点数据 - nodes = [] + nodes = [] for node in path.nodes: nodes.append( { @@ -1529,7 +1529,7 @@ class EntityPathDiscovery: ) # 边数据 - links = [] + links = [] for edge in path.edges: links.append( { @@ -1558,55 +1558,55 @@ class EntityPathDiscovery: Returns: List[Dict]: 中心性分析结果 """ - conn = self._get_conn() + conn = self._get_conn() # 获取所有实体 - entities = conn.execute( - "SELECT id, name FROM entities WHERE project_id = ?", (project_id,) + entities = conn.execute( + "SELECT id, name FROM entities WHERE project_id = ?", (project_id, ) ).fetchall() # 计算每个实体作为桥梁的次数 - bridge_scores = [] + bridge_scores = [] for entity in entities: - entity_id = entity["id"] + entity_id = entity["id"] # 计算该实体连接的不同群组数量 - neighbors = conn.execute( + neighbors = conn.execute( """ SELECT CASE - WHEN source_entity_id = ? THEN target_entity_id + WHEN source_entity_id = ? THEN target_entity_id ELSE source_entity_id END as neighbor_id FROM entity_relations - WHERE (source_entity_id = ? OR target_entity_id = ?) - AND project_id = ? + WHERE (source_entity_id = ? OR target_entity_id = ?) + AND project_id = ? """, (entity_id, entity_id, entity_id, project_id), ).fetchall() - neighbor_ids = {n["neighbor_id"] for n in neighbors} + neighbor_ids = {n["neighbor_id"] for n in neighbors} # 计算邻居之间的连接数(用于评估桥接程度) if len(neighbor_ids) > 1: - connections = conn.execute( + connections = conn.execute( f""" SELECT COUNT(*) as count FROM entity_relations - WHERE ((source_entity_id IN ({",".join(["?" for _ in neighbor_ids])}) - AND target_entity_id IN ({",".join(["?" for _ in neighbor_ids])})) - OR (target_entity_id IN ({",".join(["?" for _ in neighbor_ids])}) - AND source_entity_id IN ({",".join(["?" for _ in neighbor_ids])}))) - AND project_id = ? + WHERE ((source_entity_id IN ({", ".join(["?" for _ in neighbor_ids])}) + AND target_entity_id IN ({", ".join(["?" for _ in neighbor_ids])})) + OR (target_entity_id IN ({", ".join(["?" for _ in neighbor_ids])}) + AND source_entity_id IN ({", ".join(["?" for _ in neighbor_ids])}))) + AND project_id = ? """, list(neighbor_ids) * 4 + [project_id], ).fetchone() - # 桥接分数 = 邻居数量 / (邻居间连接数 + 1) - bridge_score = len(neighbor_ids) / (connections["count"] + 1) + # 桥接分数 = 邻居数量 / (邻居间连接数 + 1) + bridge_score = len(neighbor_ids) / (connections["count"] + 1) else: - bridge_score = 0 + bridge_score = 0 bridge_scores.append( { @@ -1620,7 +1620,7 @@ class EntityPathDiscovery: conn.close() # 按桥接分数排序 - bridge_scores.sort(key=lambda x: x["bridge_score"], reverse=True) + bridge_scores.sort(key = lambda x: x["bridge_score"], reverse = True) return bridge_scores[:20] # 返回前20 @@ -1638,13 +1638,13 @@ class KnowledgeGapDetection: - 生成知识补全建议 """ - def __init__(self, db_path: str = "insightflow.db"): - self.db_path = db_path + def __init__(self, db_path: str = "insightflow.db") -> None: + self.db_path = db_path def _get_conn(self) -> sqlite3.Connection: """获取数据库连接""" - conn = sqlite3.connect(self.db_path) - conn.row_factory = sqlite3.Row + conn = sqlite3.connect(self.db_path) + conn.row_factory = sqlite3.Row return conn def analyze_project(self, project_id: str) -> list[KnowledgeGap]: @@ -1657,7 +1657,7 @@ class KnowledgeGapDetection: Returns: List[KnowledgeGap]: 知识缺口列表 """ - gaps = [] + gaps = [] # 1. 检查实体属性完整性 gaps.extend(self._check_entity_attribute_completeness(project_id)) @@ -1675,55 +1675,55 @@ class KnowledgeGapDetection: gaps.extend(self._check_missing_key_entities(project_id)) # 按严重程度排序 - severity_order = {"high": 0, "medium": 1, "low": 2} - gaps.sort(key=lambda x: severity_order.get(x.severity, 3)) + severity_order = {"high": 0, "medium": 1, "low": 2} + gaps.sort(key = lambda x: severity_order.get(x.severity, 3)) return gaps def _check_entity_attribute_completeness(self, project_id: str) -> list[KnowledgeGap]: """检查实体属性完整性""" - conn = self._get_conn() - gaps = [] + conn = self._get_conn() + gaps = [] # 获取项目的属性模板 - templates = conn.execute( - "SELECT id, name, type, is_required FROM attribute_templates WHERE project_id = ?", - (project_id,), + templates = conn.execute( + "SELECT id, name, type, is_required FROM attribute_templates WHERE project_id = ?", + (project_id, ), ).fetchall() if not templates: conn.close() return [] - required_template_ids = {t["id"] for t in templates if t["is_required"]} + required_template_ids = {t["id"] for t in templates if t["is_required"]} if not required_template_ids: conn.close() return [] # 检查每个实体的属性完整性 - entities = conn.execute( - "SELECT id, name FROM entities WHERE project_id = ?", (project_id,) + entities = conn.execute( + "SELECT id, name FROM entities WHERE project_id = ?", (project_id, ) ).fetchall() for entity in entities: - entity_id = entity["id"] + entity_id = entity["id"] # 获取实体已有的属性 - existing_attrs = conn.execute( - "SELECT template_id FROM entity_attributes WHERE entity_id = ?", (entity_id,) + existing_attrs = conn.execute( + "SELECT template_id FROM entity_attributes WHERE entity_id = ?", (entity_id, ) ).fetchall() - existing_template_ids = {a["template_id"] for a in existing_attrs} + existing_template_ids = {a["template_id"] for a in existing_attrs} # 找出缺失的必需属性 - missing_templates = required_template_ids - existing_template_ids + missing_templates = required_template_ids - existing_template_ids if missing_templates: - missing_names = [] + missing_names = [] for template_id in missing_templates: - template = conn.execute( - "SELECT name FROM attribute_templates WHERE id = ?", (template_id,) + template = conn.execute( + "SELECT name FROM attribute_templates WHERE id = ?", (template_id, ) ).fetchone() if template: missing_names.append(template["name"]) @@ -1731,18 +1731,18 @@ class KnowledgeGapDetection: if missing_names: gaps.append( KnowledgeGap( - gap_id=f"gap_attr_{entity_id}", - gap_type="missing_attribute", - entity_id=entity_id, - entity_name=entity["name"], - description=f"实体 '{entity['name']}' 缺少必需属性: {', '.join(missing_names)}", - severity="medium", - suggestions=[ + gap_id = f"gap_attr_{entity_id}", + gap_type = "missing_attribute", + entity_id = entity_id, + entity_name = entity["name"], + description = f"实体 '{entity['name']}' 缺少必需属性: {', '.join(missing_names)}", + severity = "medium", + suggestions = [ f"为实体 '{entity['name']}' 补充以下属性: {', '.join(missing_names)}", "检查属性模板定义是否合理", ], - related_entities=[], - metadata={"missing_attributes": missing_names}, + related_entities = [], + metadata = {"missing_attributes": missing_names}, ) ) @@ -1751,39 +1751,39 @@ class KnowledgeGapDetection: def _check_relation_sparsity(self, project_id: str) -> list[KnowledgeGap]: """检查关系稀疏度""" - conn = self._get_conn() - gaps = [] + conn = self._get_conn() + gaps = [] # 获取所有实体及其关系数量 - entities = conn.execute( - "SELECT id, name, type FROM entities WHERE project_id = ?", (project_id,) + entities = conn.execute( + "SELECT id, name, type FROM entities WHERE project_id = ?", (project_id, ) ).fetchall() for entity in entities: - entity_id = entity["id"] + entity_id = entity["id"] # 计算关系数量 - relation_count = conn.execute( + relation_count = conn.execute( """ SELECT COUNT(*) as count FROM entity_relations - WHERE (source_entity_id = ? OR target_entity_id = ?) - AND project_id = ? + WHERE (source_entity_id = ? OR target_entity_id = ?) + AND project_id = ? """, (entity_id, entity_id, project_id), ).fetchone()["count"] # 根据实体类型判断阈值 - threshold = 1 if entity["type"] in ["PERSON", "ORG"] else 0 + threshold = 1 if entity["type"] in ["PERSON", "ORG"] else 0 if relation_count <= threshold: # 查找潜在的相关实体 - potential_related = conn.execute( + potential_related = conn.execute( """ SELECT e.id, e.name FROM entities e - JOIN transcripts t ON t.project_id = e.project_id - WHERE e.project_id = ? + JOIN transcripts t ON t.project_id = e.project_id + WHERE e.project_id = ? AND e.id != ? AND t.full_text LIKE ? LIMIT 5 @@ -1793,19 +1793,19 @@ class KnowledgeGapDetection: gaps.append( KnowledgeGap( - gap_id=f"gap_sparse_{entity_id}", - gap_type="sparse_relation", - entity_id=entity_id, - entity_name=entity["name"], - description=f"实体 '{entity['name']}' 关系稀疏(仅有 {relation_count} 个关系)", - severity="medium" if relation_count == 0 else "low", - suggestions=[ + gap_id = f"gap_sparse_{entity_id}", + gap_type = "sparse_relation", + entity_id = entity_id, + entity_name = entity["name"], + description = f"实体 '{entity['name']}' 关系稀疏(仅有 {relation_count} 个关系)", + severity = "medium" if relation_count == 0 else "low", + suggestions = [ f"检查转录文本中提及 '{entity['name']}' 的其他实体", f"手动添加 '{entity['name']}' 与其他实体的关系", "使用实体对齐功能合并相似实体", ], - related_entities=[r["id"] for r in potential_related], - metadata={ + related_entities = [r["id"] for r in potential_related], + metadata = { "relation_count": relation_count, "potential_related": [r["name"] for r in potential_related], }, @@ -1817,39 +1817,39 @@ class KnowledgeGapDetection: def _check_isolated_entities(self, project_id: str) -> list[KnowledgeGap]: """检查孤立实体(没有任何关系)""" - conn = self._get_conn() - gaps = [] + conn = self._get_conn() + gaps = [] # 查找没有关系的实体 - isolated = conn.execute( + isolated = conn.execute( """ SELECT e.id, e.name, e.type FROM entities e - LEFT JOIN entity_relations r1 ON e.id = r1.source_entity_id - LEFT JOIN entity_relations r2 ON e.id = r2.target_entity_id - WHERE e.project_id = ? + LEFT JOIN entity_relations r1 ON e.id = r1.source_entity_id + LEFT JOIN entity_relations r2 ON e.id = r2.target_entity_id + WHERE e.project_id = ? AND r1.id IS NULL AND r2.id IS NULL """, - (project_id,), + (project_id, ), ).fetchall() for entity in isolated: gaps.append( KnowledgeGap( - gap_id=f"gap_iso_{entity['id']}", - gap_type="isolated_entity", - entity_id=entity["id"], - entity_name=entity["name"], - description=f"实体 '{entity['name']}' 是孤立实体(没有任何关系)", - severity="high", - suggestions=[ + gap_id = f"gap_iso_{entity['id']}", + gap_type = "isolated_entity", + entity_id = entity["id"], + entity_name = entity["name"], + description = f"实体 '{entity['name']}' 是孤立实体(没有任何关系)", + severity = "high", + suggestions = [ f"检查 '{entity['name']}' 是否应该与其他实体建立关系", f"考虑删除不相关的实体 '{entity['name']}'", "运行关系发现算法自动识别潜在关系", ], - related_entities=[], - metadata={"entity_type": entity["type"]}, + related_entities = [], + metadata = {"entity_type": entity["type"]}, ) ) @@ -1858,32 +1858,32 @@ class KnowledgeGapDetection: def _check_incomplete_entities(self, project_id: str) -> list[KnowledgeGap]: """检查不完整实体(缺少名称、类型或定义)""" - conn = self._get_conn() - gaps = [] + conn = self._get_conn() + gaps = [] # 查找缺少定义的实体 - incomplete = conn.execute( + incomplete = conn.execute( """ SELECT id, name, type, definition FROM entities - WHERE project_id = ? - AND (definition IS NULL OR definition = '') + WHERE project_id = ? + AND (definition IS NULL OR definition = '') """, - (project_id,), + (project_id, ), ).fetchall() for entity in incomplete: gaps.append( KnowledgeGap( - gap_id=f"gap_inc_{entity['id']}", - gap_type="incomplete_entity", - entity_id=entity["id"], - entity_name=entity["name"], - description=f"实体 '{entity['name']}' 缺少定义", - severity="low", - suggestions=[f"为 '{entity['name']}' 添加定义", "从转录文本中提取定义信息"], - related_entities=[], - metadata={"entity_type": entity["type"]}, + gap_id = f"gap_inc_{entity['id']}", + gap_type = "incomplete_entity", + entity_id = entity["id"], + entity_name = entity["name"], + description = f"实体 '{entity['name']}' 缺少定义", + severity = "low", + suggestions = [f"为 '{entity['name']}' 添加定义", "从转录文本中提取定义信息"], + related_entities = [], + metadata = {"entity_type": entity["type"]}, ) ) @@ -1892,30 +1892,30 @@ class KnowledgeGapDetection: def _check_missing_key_entities(self, project_id: str) -> list[KnowledgeGap]: """检查可能缺失的关键实体""" - conn = self._get_conn() - gaps = [] + conn = self._get_conn() + gaps = [] # 分析转录文本中频繁提及但未提取为实体的词 - transcripts = conn.execute( - "SELECT full_text FROM transcripts WHERE project_id = ?", (project_id,) + transcripts = conn.execute( + "SELECT full_text FROM transcripts WHERE project_id = ?", (project_id, ) ).fetchall() # 合并所有文本 - all_text = " ".join([t["full_text"] or "" for t in transcripts]) + all_text = " ".join([t["full_text"] or "" for t in transcripts]) # 获取现有实体名称 - existing_entities = conn.execute( - "SELECT name FROM entities WHERE project_id = ?", (project_id,) + existing_entities = conn.execute( + "SELECT name FROM entities WHERE project_id = ?", (project_id, ) ).fetchall() - existing_names = {e["name"].lower() for e in existing_entities} + existing_names = {e["name"].lower() for e in existing_entities} # 简单的关键词提取(实际可以使用更复杂的 NLP 方法) # 查找大写的词组(可能是专有名词) - potential_entities = re.findall(r"[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*", all_text) + potential_entities = re.findall(r"[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*", all_text) # 统计频率 - freq = defaultdict(int) + freq = defaultdict(int) for entity in potential_entities: if len(entity) > 3 and entity.lower() not in existing_names: freq[entity] += 1 @@ -1925,18 +1925,18 @@ class KnowledgeGapDetection: if count >= 3: # 出现3次以上 gaps.append( KnowledgeGap( - gap_id=f"gap_missing_{hash(entity) % 10000}", - gap_type="missing_key_entity", - entity_id=None, - entity_name=None, - description=f"文本中频繁提及 '{entity}' 但未提取为实体(出现 {count} 次)", - severity="low", - suggestions=[ + gap_id = f"gap_missing_{hash(entity) % 10000}", + gap_type = "missing_key_entity", + entity_id = None, + entity_name = None, + description = f"文本中频繁提及 '{entity}' 但未提取为实体(出现 {count} 次)", + severity = "low", + suggestions = [ f"考虑将 '{entity}' 添加为实体", "检查实体提取算法是否需要优化", ], - related_entities=[], - metadata={"mention_count": count}, + related_entities = [], + metadata = {"mention_count": count}, ) ) @@ -1953,36 +1953,36 @@ class KnowledgeGapDetection: Returns: Dict: 完整性报告 """ - conn = self._get_conn() + conn = self._get_conn() # 基础统计 - stats = conn.execute( + stats = conn.execute( """ SELECT - (SELECT COUNT(*) FROM entities WHERE project_id = ?) as entity_count, - (SELECT COUNT(*) FROM entity_relations WHERE project_id = ?) as relation_count, - (SELECT COUNT(*) FROM transcripts WHERE project_id = ?) as transcript_count + (SELECT COUNT(*) FROM entities WHERE project_id = ?) as entity_count, + (SELECT COUNT(*) FROM entity_relations WHERE project_id = ?) as relation_count, + (SELECT COUNT(*) FROM transcripts WHERE project_id = ?) as transcript_count """, (project_id, project_id, project_id), ).fetchone() # 计算完整性分数 - gaps = self.analyze_project(project_id) + gaps = self.analyze_project(project_id) # 按类型统计 - gap_by_type = defaultdict(int) - severity_count = {"high": 0, "medium": 0, "low": 0} + gap_by_type = defaultdict(int) + severity_count = {"high": 0, "medium": 0, "low": 0} for gap in gaps: gap_by_type[gap.gap_type] += 1 severity_count[gap.severity] += 1 # 计算完整性分数(100 - 扣分) - score = 100 + score = 100 score -= severity_count["high"] * 10 score -= severity_count["medium"] * 5 score -= severity_count["low"] * 2 - score = max(0, score) + score = max(0, score) conn.close() @@ -2005,9 +2005,9 @@ class KnowledgeGapDetection: def _generate_recommendations(self, gaps: list[KnowledgeGap]) -> list[str]: """生成改进建议""" - recommendations = [] + recommendations = [] - gap_types = {g.gap_type for g in gaps} + gap_types = {g.gap_type for g in gaps} if "isolated_entity" in gap_types: recommendations.append("优先处理孤立实体,建立实体间的关系连接") @@ -2040,14 +2040,14 @@ class SearchManager: 整合全文搜索、语义搜索、实体路径发现和知识缺口识别功能 """ - def __init__(self, db_path: str = "insightflow.db"): - self.db_path = db_path - self.fulltext_search = FullTextSearch(db_path) - self.semantic_search = SemanticSearch(db_path) - self.path_discovery = EntityPathDiscovery(db_path) - self.gap_detection = KnowledgeGapDetection(db_path) + def __init__(self, db_path: str = "insightflow.db") -> None: + self.db_path = db_path + self.fulltext_search = FullTextSearch(db_path) + self.semantic_search = SemanticSearch(db_path) + self.path_discovery = EntityPathDiscovery(db_path) + self.gap_detection = KnowledgeGapDetection(db_path) - def hybrid_search(self, query: str, project_id: str | None = None, limit: int = 20) -> dict: + def hybrid_search(self, query: str, project_id: str | None = None, limit: int = 20) -> dict: """ 混合搜索(全文 + 语义) @@ -2060,20 +2060,20 @@ class SearchManager: Dict: 混合搜索结果 """ # 全文搜索 - fulltext_results = self.fulltext_search.search(query, project_id, limit=limit) + fulltext_results = self.fulltext_search.search(query, project_id, limit = limit) # 语义搜索 - semantic_results = [] + semantic_results = [] if self.semantic_search.is_available(): - semantic_results = self.semantic_search.search(query, project_id, top_k=limit) + semantic_results = self.semantic_search.search(query, project_id, top_k = limit) # 合并结果(去重并加权) - combined = {} + combined = {} # 添加全文搜索结果 for r in fulltext_results: - key = (r.id, r.content_type) - combined[key] = { + key = (r.id, r.content_type) + combined[key] = { "id": r.id, "content": r.content, "content_type": r.content_type, @@ -2086,12 +2086,12 @@ class SearchManager: # 添加语义搜索结果 for r in semantic_results: - key = (r.id, r.content_type) + key = (r.id, r.content_type) if key in combined: - combined[key]["semantic_score"] = r.similarity + combined[key]["semantic_score"] = r.similarity combined[key]["combined_score"] += r.similarity * 0.4 # 语义权重 40% else: - combined[key] = { + combined[key] = { "id": r.id, "content": r.content, "content_type": r.content_type, @@ -2103,8 +2103,8 @@ class SearchManager: } # 排序 - results = list(combined.values()) - results.sort(key=lambda x: x["combined_score"], reverse=True) + results = list(combined.values()) + results.sort(key = lambda x: x["combined_score"], reverse = True) return { "query": query, @@ -2126,19 +2126,19 @@ class SearchManager: Dict: 索引统计 """ # 全文索引 - fulltext_stats = self.fulltext_search.reindex_project(project_id) + fulltext_stats = self.fulltext_search.reindex_project(project_id) # 语义索引 - semantic_stats = {"indexed": 0, "errors": 0} + semantic_stats = {"indexed": 0, "errors": 0} if self.semantic_search.is_available(): - conn = sqlite3.connect(self.db_path) - conn.row_factory = sqlite3.Row + conn = sqlite3.connect(self.db_path) + conn.row_factory = sqlite3.Row # 索引转录文本 - transcripts = conn.execute( - "SELECT id, project_id, full_text FROM transcripts WHERE project_id = ?", - (project_id,), + transcripts = conn.execute( + "SELECT id, project_id, full_text FROM transcripts WHERE project_id = ?", + (project_id, ), ).fetchall() for t in transcripts: @@ -2150,13 +2150,13 @@ class SearchManager: semantic_stats["errors"] += 1 # 索引实体 - entities = conn.execute( - "SELECT id, project_id, name, definition FROM entities WHERE project_id = ?", - (project_id,), + entities = conn.execute( + "SELECT id, project_id, name, definition FROM entities WHERE project_id = ?", + (project_id, ), ).fetchall() for e in entities: - text = f"{e['name']} {e['definition'] or ''}" + text = f"{e['name']} {e['definition'] or ''}" if self.semantic_search.index_embedding(e["id"], "entity", e["project_id"], text): semantic_stats["indexed"] += 1 else: @@ -2166,34 +2166,34 @@ class SearchManager: return {"project_id": project_id, "fulltext": fulltext_stats, "semantic": semantic_stats} - def get_search_stats(self, project_id: str | None = None) -> dict: + def get_search_stats(self, project_id: str | None = None) -> dict: """获取搜索统计信息""" - conn = sqlite3.connect(self.db_path) - conn.row_factory = sqlite3.Row + conn = sqlite3.connect(self.db_path) + conn.row_factory = sqlite3.Row - where_clause = "WHERE project_id = ?" if project_id else "" - params = [project_id] if project_id else [] + where_clause = "WHERE project_id = ?" if project_id else "" + params = [project_id] if project_id else [] # 全文索引统计 - fulltext_count = conn.execute( + fulltext_count = conn.execute( f"SELECT COUNT(*) as count FROM search_indexes {where_clause}", params ).fetchone()["count"] # 语义索引统计 - semantic_count = conn.execute( + semantic_count = conn.execute( f"SELECT COUNT(*) as count FROM embeddings {where_clause}", params ).fetchone()["count"] # 按类型统计 - type_stats = {} + type_stats = {} if project_id: - rows = conn.execute( + rows = conn.execute( """SELECT content_type, COUNT(*) as count - FROM search_indexes WHERE project_id = ? + FROM search_indexes WHERE project_id = ? GROUP BY content_type""", - (project_id,), + (project_id, ), ).fetchall() - type_stats = {r["content_type"]: r["count"] for r in rows} + type_stats = {r["content_type"]: r["count"] for r in rows} conn.close() @@ -2207,14 +2207,14 @@ class SearchManager: # 单例模式 -_search_manager = None +_search_manager = None -def get_search_manager(db_path: str = "insightflow.db") -> SearchManager: +def get_search_manager(db_path: str = "insightflow.db") -> SearchManager: """获取搜索管理器单例""" global _search_manager if _search_manager is None: - _search_manager = SearchManager(db_path) + _search_manager = SearchManager(db_path) return _search_manager @@ -2222,28 +2222,28 @@ def get_search_manager(db_path: str = "insightflow.db") -> SearchManager: def fulltext_search( - query: str, project_id: str | None = None, limit: int = 20 + query: str, project_id: str | None = None, limit: int = 20 ) -> list[SearchResult]: """全文搜索便捷函数""" - manager = get_search_manager() - return manager.fulltext_search.search(query, project_id, limit=limit) + manager = get_search_manager() + return manager.fulltext_search.search(query, project_id, limit = limit) def semantic_search( - query: str, project_id: str | None = None, top_k: int = 10 + query: str, project_id: str | None = None, top_k: int = 10 ) -> list[SemanticSearchResult]: """语义搜索便捷函数""" - manager = get_search_manager() - return manager.semantic_search.search(query, project_id, top_k=top_k) + manager = get_search_manager() + return manager.semantic_search.search(query, project_id, top_k = top_k) -def find_entity_path(source_id: str, target_id: str, max_depth: int = 5) -> EntityPath | None: +def find_entity_path(source_id: str, target_id: str, max_depth: int = 5) -> EntityPath | None: """查找实体路径便捷函数""" - manager = get_search_manager() + manager = get_search_manager() return manager.path_discovery.find_shortest_path(source_id, target_id, max_depth) def detect_knowledge_gaps(project_id: str) -> list[KnowledgeGap]: """知识缺口检测便捷函数""" - manager = get_search_manager() + manager = get_search_manager() return manager.gap_detection.analyze_project(project_id) diff --git a/backend/security_manager.py b/backend/security_manager.py index aac14d5..3f7161b 100644 --- a/backend/security_manager.py +++ b/backend/security_manager.py @@ -20,54 +20,54 @@ try: from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC - CRYPTO_AVAILABLE = True + CRYPTO_AVAILABLE = True except ImportError: - CRYPTO_AVAILABLE = False + CRYPTO_AVAILABLE = False print("Warning: cryptography not available, encryption features disabled") class AuditActionType(Enum): """审计动作类型""" - CREATE = "create" - READ = "read" - UPDATE = "update" - DELETE = "delete" - LOGIN = "login" - LOGOUT = "logout" - EXPORT = "export" - IMPORT = "import" - SHARE = "share" - PERMISSION_CHANGE = "permission_change" - ENCRYPTION_ENABLE = "encryption_enable" - ENCRYPTION_DISABLE = "encryption_disable" - DATA_MASKING = "data_masking" - API_KEY_CREATE = "api_key_create" - API_KEY_REVOKE = "api_key_revoke" - WORKFLOW_TRIGGER = "workflow_trigger" - WEBHOOK_SEND = "webhook_send" - BOT_MESSAGE = "bot_message" + CREATE = "create" + READ = "read" + UPDATE = "update" + DELETE = "delete" + LOGIN = "login" + LOGOUT = "logout" + EXPORT = "export" + IMPORT = "import" + SHARE = "share" + PERMISSION_CHANGE = "permission_change" + ENCRYPTION_ENABLE = "encryption_enable" + ENCRYPTION_DISABLE = "encryption_disable" + DATA_MASKING = "data_masking" + API_KEY_CREATE = "api_key_create" + API_KEY_REVOKE = "api_key_revoke" + WORKFLOW_TRIGGER = "workflow_trigger" + WEBHOOK_SEND = "webhook_send" + BOT_MESSAGE = "bot_message" class DataSensitivityLevel(Enum): """数据敏感度级别""" - PUBLIC = "public" # 公开 - INTERNAL = "internal" # 内部 - CONFIDENTIAL = "confidential" # 机密 - SECRET = "secret" # 绝密 + PUBLIC = "public" # 公开 + INTERNAL = "internal" # 内部 + CONFIDENTIAL = "confidential" # 机密 + SECRET = "secret" # 绝密 class MaskingRuleType(Enum): """脱敏规则类型""" - PHONE = "phone" # 手机号 - EMAIL = "email" # 邮箱 - ID_CARD = "id_card" # 身份证号 - BANK_CARD = "bank_card" # 银行卡号 - NAME = "name" # 姓名 - ADDRESS = "address" # 地址 - CUSTOM = "custom" # 自定义 + PHONE = "phone" # 手机号 + EMAIL = "email" # 邮箱 + ID_CARD = "id_card" # 身份证号 + BANK_CARD = "bank_card" # 银行卡号 + NAME = "name" # 姓名 + ADDRESS = "address" # 地址 + CUSTOM = "custom" # 自定义 @dataclass @@ -76,17 +76,17 @@ class AuditLog: id: str action_type: str - user_id: str | None = None - user_ip: str | None = None - user_agent: str | None = None - resource_type: str | None = None # project, entity, transcript, etc. - resource_id: str | None = None - action_details: str | None = None # JSON string - before_value: str | None = None - after_value: str | None = None - success: bool = True - error_message: str | None = None - created_at: str = field(default_factory=lambda: datetime.now().isoformat()) + user_id: str | None = None + user_ip: str | None = None + user_agent: str | None = None + resource_type: str | None = None # project, entity, transcript, etc. + resource_id: str | None = None + action_details: str | None = None # JSON string + before_value: str | None = None + after_value: str | None = None + success: bool = True + error_message: str | None = None + created_at: str = field(default_factory = lambda: datetime.now().isoformat()) def to_dict(self) -> dict[str, Any]: return asdict(self) @@ -98,13 +98,13 @@ class EncryptionConfig: id: str project_id: str - is_enabled: bool = False - encryption_type: str = "aes-256-gcm" # aes-256-gcm, chacha20-poly1305 - key_derivation: str = "pbkdf2" # pbkdf2, argon2 - master_key_hash: str | None = None # 主密钥哈希(用于验证) - salt: str | None = None - created_at: str = field(default_factory=lambda: datetime.now().isoformat()) - updated_at: str = field(default_factory=lambda: datetime.now().isoformat()) + is_enabled: bool = False + encryption_type: str = "aes-256-gcm" # aes-256-gcm, chacha20-poly1305 + key_derivation: str = "pbkdf2" # pbkdf2, argon2 + master_key_hash: str | None = None # 主密钥哈希(用于验证) + salt: str | None = None + created_at: str = field(default_factory = lambda: datetime.now().isoformat()) + updated_at: str = field(default_factory = lambda: datetime.now().isoformat()) def to_dict(self) -> dict[str, Any]: return asdict(self) @@ -120,11 +120,11 @@ class MaskingRule: rule_type: str # phone, email, id_card, bank_card, name, address, custom pattern: str # 正则表达式 replacement: str # 替换模板,如 "****" - is_active: bool = True - priority: int = 0 - description: str | None = None - created_at: str = field(default_factory=lambda: datetime.now().isoformat()) - updated_at: str = field(default_factory=lambda: datetime.now().isoformat()) + is_active: bool = True + priority: int = 0 + description: str | None = None + created_at: str = field(default_factory = lambda: datetime.now().isoformat()) + updated_at: str = field(default_factory = lambda: datetime.now().isoformat()) def to_dict(self) -> dict[str, Any]: return asdict(self) @@ -137,16 +137,16 @@ class DataAccessPolicy: id: str project_id: str name: str - description: str | None = None - allowed_users: str | None = None # JSON array of user IDs - allowed_roles: str | None = None # JSON array of roles - allowed_ips: str | None = None # JSON array of IP patterns - time_restrictions: str | None = None # JSON: {"start_time": "09:00", "end_time": "18:00"} - max_access_count: int | None = None # 最大访问次数 - require_approval: bool = False - is_active: bool = True - created_at: str = field(default_factory=lambda: datetime.now().isoformat()) - updated_at: str = field(default_factory=lambda: datetime.now().isoformat()) + description: str | None = None + allowed_users: str | None = None # JSON array of user IDs + allowed_roles: str | None = None # JSON array of roles + allowed_ips: str | None = None # JSON array of IP patterns + time_restrictions: str | None = None # JSON: {"start_time": "09:00", "end_time": "18:00"} + max_access_count: int | None = None # 最大访问次数 + require_approval: bool = False + is_active: bool = True + created_at: str = field(default_factory = lambda: datetime.now().isoformat()) + updated_at: str = field(default_factory = lambda: datetime.now().isoformat()) def to_dict(self) -> dict[str, Any]: return asdict(self) @@ -159,12 +159,12 @@ class AccessRequest: id: str policy_id: str user_id: str - request_reason: str | None = None - status: str = "pending" # pending, approved, rejected, expired - approved_by: str | None = None - approved_at: str | None = None - expires_at: str | None = None - created_at: str = field(default_factory=lambda: datetime.now().isoformat()) + request_reason: str | None = None + status: str = "pending" # pending, approved, rejected, expired + approved_by: str | None = None + approved_at: str | None = None + expires_at: str | None = None + created_at: str = field(default_factory = lambda: datetime.now().isoformat()) def to_dict(self) -> dict[str, Any]: return asdict(self) @@ -174,9 +174,9 @@ class SecurityManager: """安全管理器""" # 预定义脱敏规则 - DEFAULT_MASKING_RULES = { + DEFAULT_MASKING_RULES = { MaskingRuleType.PHONE: {"pattern": r"(\d{3})\d{4}(\d{4})", "replacement": r"\1****\2"}, - MaskingRuleType.EMAIL: {"pattern": r"(\w{1,3})\w+(@\w+\.\w+)", "replacement": r"\1***\2"}, + MaskingRuleType.EMAIL: {"pattern": r"(\w{1, 3})\w+(@\w+\.\w+)", "replacement": r"\1***\2"}, MaskingRuleType.ID_CARD: { "pattern": r"(\d{6})\d{8}(\d{4})", "replacement": r"\1********\2", @@ -190,22 +190,22 @@ class SecurityManager: "replacement": r"\1**", }, MaskingRuleType.ADDRESS: { - "pattern": r"([\u4e00-\u9fa5]{2,})([\u4e00-\u9fa5]+路|街|巷|号)(.+)", + "pattern": r"([\u4e00-\u9fa5]{2, })([\u4e00-\u9fa5]+路|街|巷|号)(.+)", "replacement": r"\1\2***", }, } - def __init__(self, db_path: str = "insightflow.db"): - self.db_path = db_path + def __init__(self, db_path: str = "insightflow.db") -> None: + self.db_path = db_path # 预编译正则缓存 - self._compiled_patterns: dict[str, re.Pattern] = {} - self._local = {} + self._compiled_patterns: dict[str, re.Pattern] = {} + self._local = {} self._init_db() def _init_db(self) -> None: """初始化数据库表""" - conn = sqlite3.connect(self.db_path) - cursor = conn.cursor() + conn = sqlite3.connect(self.db_path) + cursor = conn.cursor() # 审计日志表 cursor.execute(""" @@ -332,35 +332,35 @@ class SecurityManager: def log_audit( self, action_type: AuditActionType, - user_id: str | None = None, - user_ip: str | None = None, - user_agent: str | None = None, - resource_type: str | None = None, - resource_id: str | None = None, - action_details: dict | None = None, - before_value: str | None = None, - after_value: str | None = None, - success: bool = True, - error_message: str | None = None, + user_id: str | None = None, + user_ip: str | None = None, + user_agent: str | None = None, + resource_type: str | None = None, + resource_id: str | None = None, + action_details: dict | None = None, + before_value: str | None = None, + after_value: str | None = None, + success: bool = True, + error_message: str | None = None, ) -> AuditLog: """记录审计日志""" - log = AuditLog( - id=self._generate_id(), - action_type=action_type.value, - user_id=user_id, - user_ip=user_ip, - user_agent=user_agent, - resource_type=resource_type, - resource_id=resource_id, - action_details=json.dumps(action_details) if action_details else None, - before_value=before_value, - after_value=after_value, - success=success, - error_message=error_message, + log = AuditLog( + id = self._generate_id(), + action_type = action_type.value, + user_id = user_id, + user_ip = user_ip, + user_agent = user_agent, + resource_type = resource_type, + resource_id = resource_id, + action_details = json.dumps(action_details) if action_details else None, + before_value = before_value, + after_value = after_value, + success = success, + error_message = error_message, ) - conn = sqlite3.connect(self.db_path) - cursor = conn.cursor() + conn = sqlite3.connect(self.db_path) + cursor = conn.cursor() cursor.execute( """ INSERT INTO audit_logs @@ -391,34 +391,34 @@ class SecurityManager: def get_audit_logs( self, - user_id: str | None = None, - resource_type: str | None = None, - resource_id: str | None = None, - action_type: str | None = None, - start_time: str | None = None, - end_time: str | None = None, - success: bool | None = None, - limit: int = 100, - offset: int = 0, + user_id: str | None = None, + resource_type: str | None = None, + resource_id: str | None = None, + action_type: str | None = None, + start_time: str | None = None, + end_time: str | None = None, + success: bool | None = None, + limit: int = 100, + offset: int = 0, ) -> list[AuditLog]: """查询审计日志""" - conn = sqlite3.connect(self.db_path) - cursor = conn.cursor() + conn = sqlite3.connect(self.db_path) + cursor = conn.cursor() - query = "SELECT * FROM audit_logs WHERE 1=1" - params = [] + query = "SELECT * FROM audit_logs WHERE 1 = 1" + params = [] if user_id: - query += " AND user_id = ?" + query += " AND user_id = ?" params.append(user_id) if resource_type: - query += " AND resource_type = ?" + query += " AND resource_type = ?" params.append(resource_type) if resource_id: - query += " AND resource_id = ?" + query += " AND resource_id = ?" params.append(resource_id) if action_type: - query += " AND action_type = ?" + query += " AND action_type = ?" params.append(action_type) if start_time: query += " AND created_at >= ?" @@ -427,36 +427,36 @@ class SecurityManager: query += " AND created_at <= ?" params.append(end_time) if success is not None: - query += " AND success = ?" + query += " AND success = ?" params.append(int(success)) query += " ORDER BY created_at DESC LIMIT ? OFFSET ?" params.extend([limit, offset]) cursor.execute(query, params) - rows = cursor.fetchall() + rows = cursor.fetchall() conn.close() - logs = [] - col_names = [desc[0] for desc in cursor.description] if cursor.description else [] + logs = [] + col_names = [desc[0] for desc in cursor.description] if cursor.description else [] if not col_names: return logs for row in rows: - log = AuditLog( - id=row[0], - action_type=row[1], - user_id=row[2], - user_ip=row[3], - user_agent=row[4], - resource_type=row[5], - resource_id=row[6], - action_details=row[7], - before_value=row[8], - after_value=row[9], - success=bool(row[10]), - error_message=row[11], - created_at=row[12], + log = AuditLog( + id = row[0], + action_type = row[1], + user_id = row[2], + user_ip = row[3], + user_agent = row[4], + resource_type = row[5], + resource_id = row[6], + action_details = row[7], + before_value = row[8], + after_value = row[9], + success = bool(row[10]), + error_message = row[11], + created_at = row[12], ) logs.append(log) @@ -464,14 +464,14 @@ class SecurityManager: return logs def get_audit_stats( - self, start_time: str | None = None, end_time: str | None = None + self, start_time: str | None = None, end_time: str | None = None ) -> dict[str, Any]: """获取审计统计""" - conn = sqlite3.connect(self.db_path) - cursor = conn.cursor() + conn = sqlite3.connect(self.db_path) + cursor = conn.cursor() - query = "SELECT action_type, success, COUNT(*) FROM audit_logs WHERE 1=1" - params = [] + query = "SELECT action_type, success, COUNT(*) FROM audit_logs WHERE 1 = 1" + params = [] if start_time: query += " AND created_at >= ?" @@ -483,9 +483,9 @@ class SecurityManager: query += " GROUP BY action_type, success" cursor.execute(query, params) - rows = cursor.fetchall() + rows = cursor.fetchall() - stats = {"total_actions": 0, "success_count": 0, "failure_count": 0, "action_breakdown": {}} + stats = {"total_actions": 0, "success_count": 0, "failure_count": 0, "action_breakdown": {}} for action_type, success, count in rows: stats["total_actions"] += count @@ -495,7 +495,7 @@ class SecurityManager: stats["failure_count"] += count if action_type not in stats["action_breakdown"]: - stats["action_breakdown"][action_type] = {"success": 0, "failure": 0} + stats["action_breakdown"][action_type] = {"success": 0, "failure": 0} if success: stats["action_breakdown"][action_type]["success"] += count @@ -512,11 +512,11 @@ class SecurityManager: if not CRYPTO_AVAILABLE: raise RuntimeError("cryptography library not available") - kdf = PBKDF2HMAC( - algorithm=hashes.SHA256(), - length=32, - salt=salt, - iterations=100000, + kdf = PBKDF2HMAC( + algorithm = hashes.SHA256(), + length = 32, + salt = salt, + iterations = 100000, ) return base64.urlsafe_b64encode(kdf.derive(password.encode())) @@ -526,36 +526,36 @@ class SecurityManager: raise RuntimeError("cryptography library not available") # 生成盐值 - salt = secrets.token_hex(16) + salt = secrets.token_hex(16) # 派生密钥并哈希(用于验证) - key = self._derive_key(master_password, salt.encode()) - key_hash = hashlib.sha256(key).hexdigest() + key = self._derive_key(master_password, salt.encode()) + key_hash = hashlib.sha256(key).hexdigest() - config = EncryptionConfig( - id=self._generate_id(), - project_id=project_id, - is_enabled=True, - encryption_type="aes-256-gcm", - key_derivation="pbkdf2", - master_key_hash=key_hash, - salt=salt, + config = EncryptionConfig( + id = self._generate_id(), + project_id = project_id, + is_enabled = True, + encryption_type = "aes-256-gcm", + key_derivation = "pbkdf2", + master_key_hash = key_hash, + salt = salt, ) - conn = sqlite3.connect(self.db_path) - cursor = conn.cursor() + conn = sqlite3.connect(self.db_path) + cursor = conn.cursor() # 检查是否已存在配置 - cursor.execute("SELECT id FROM encryption_configs WHERE project_id = ?", (project_id,)) - existing = cursor.fetchone() + cursor.execute("SELECT id FROM encryption_configs WHERE project_id = ?", (project_id, )) + existing = cursor.fetchone() if existing: cursor.execute( """ UPDATE encryption_configs - SET is_enabled = 1, encryption_type = ?, key_derivation = ?, - master_key_hash = ?, salt = ?, updated_at = ? - WHERE project_id = ? + SET is_enabled = 1, encryption_type = ?, key_derivation = ?, + master_key_hash = ?, salt = ?, updated_at = ? + WHERE project_id = ? """, ( config.encryption_type, @@ -566,7 +566,7 @@ class SecurityManager: project_id, ), ) - config.id = existing[0] + config.id = existing[0] else: cursor.execute( """ @@ -593,10 +593,10 @@ class SecurityManager: # 记录审计日志 self.log_audit( - action_type=AuditActionType.ENCRYPTION_ENABLE, - resource_type="project", - resource_id=project_id, - action_details={"encryption_type": config.encryption_type}, + action_type = AuditActionType.ENCRYPTION_ENABLE, + resource_type = "project", + resource_id = project_id, + action_details = {"encryption_type": config.encryption_type}, ) return config @@ -607,14 +607,14 @@ class SecurityManager: if not self.verify_encryption_password(project_id, master_password): return False - conn = sqlite3.connect(self.db_path) - cursor = conn.cursor() + conn = sqlite3.connect(self.db_path) + cursor = conn.cursor() cursor.execute( """ UPDATE encryption_configs - SET is_enabled = 0, updated_at = ? - WHERE project_id = ? + SET is_enabled = 0, updated_at = ? + WHERE project_id = ? """, (datetime.now().isoformat(), project_id), ) @@ -624,9 +624,9 @@ class SecurityManager: # 记录审计日志 self.log_audit( - action_type=AuditActionType.ENCRYPTION_DISABLE, - resource_type="project", - resource_id=project_id, + action_type = AuditActionType.ENCRYPTION_DISABLE, + resource_type = "project", + resource_id = project_id, ) return True @@ -636,60 +636,60 @@ class SecurityManager: if not CRYPTO_AVAILABLE: return False - conn = sqlite3.connect(self.db_path) - cursor = conn.cursor() + conn = sqlite3.connect(self.db_path) + cursor = conn.cursor() cursor.execute( - "SELECT master_key_hash, salt FROM encryption_configs WHERE project_id = ?", - (project_id,), + "SELECT master_key_hash, salt FROM encryption_configs WHERE project_id = ?", + (project_id, ), ) - row = cursor.fetchone() + row = cursor.fetchone() conn.close() if not row: return False - stored_hash, salt = row - key = self._derive_key(password, salt.encode()) - key_hash = hashlib.sha256(key).hexdigest() + stored_hash, salt = row + key = self._derive_key(password, salt.encode()) + key_hash = hashlib.sha256(key).hexdigest() return key_hash == stored_hash def get_encryption_config(self, project_id: str) -> EncryptionConfig | None: """获取加密配置""" - conn = sqlite3.connect(self.db_path) - cursor = conn.cursor() + conn = sqlite3.connect(self.db_path) + cursor = conn.cursor() - cursor.execute("SELECT * FROM encryption_configs WHERE project_id = ?", (project_id,)) - row = cursor.fetchone() + cursor.execute("SELECT * FROM encryption_configs WHERE project_id = ?", (project_id, )) + row = cursor.fetchone() conn.close() if not row: return None return EncryptionConfig( - id=row[0], - project_id=row[1], - is_enabled=bool(row[2]), - encryption_type=row[3], - key_derivation=row[4], - master_key_hash=row[5], - salt=row[6], - created_at=row[7], - updated_at=row[8], + id = row[0], + project_id = row[1], + is_enabled = bool(row[2]), + encryption_type = row[3], + key_derivation = row[4], + master_key_hash = row[5], + salt = row[6], + created_at = row[7], + updated_at = row[8], ) - def encrypt_data(self, data: str, password: str, salt: str | None = None) -> tuple[str, str]: + def encrypt_data(self, data: str, password: str, salt: str | None = None) -> tuple[str, str]: """加密数据""" if not CRYPTO_AVAILABLE: raise RuntimeError("cryptography library not available") if salt is None: - salt = secrets.token_hex(16) + salt = secrets.token_hex(16) - key = self._derive_key(password, salt.encode()) - f = Fernet(key) - encrypted = f.encrypt(data.encode()) + key = self._derive_key(password, salt.encode()) + f = Fernet(key) + encrypted = f.encrypt(data.encode()) return base64.b64encode(encrypted).decode(), salt @@ -698,9 +698,9 @@ class SecurityManager: if not CRYPTO_AVAILABLE: raise RuntimeError("cryptography library not available") - key = self._derive_key(password, salt.encode()) - f = Fernet(key) - decrypted = f.decrypt(base64.b64decode(encrypted_data)) + key = self._derive_key(password, salt.encode()) + f = Fernet(key) + decrypted = f.decrypt(base64.b64decode(encrypted_data)) return decrypted.decode() @@ -711,31 +711,31 @@ class SecurityManager: project_id: str, name: str, rule_type: MaskingRuleType, - pattern: str | None = None, - replacement: str | None = None, - description: str | None = None, - priority: int = 0, + pattern: str | None = None, + replacement: str | None = None, + description: str | None = None, + priority: int = 0, ) -> MaskingRule: """创建脱敏规则""" # 使用预定义规则或自定义规则 if rule_type in self.DEFAULT_MASKING_RULES and not pattern: - default = self.DEFAULT_MASKING_RULES[rule_type] - pattern = default["pattern"] - replacement = replacement or default["replacement"] + default = self.DEFAULT_MASKING_RULES[rule_type] + pattern = default["pattern"] + replacement = replacement or default["replacement"] - rule = MaskingRule( - id=self._generate_id(), - project_id=project_id, - name=name, - rule_type=rule_type.value, - pattern=pattern or "", - replacement=replacement or "****", - description=description, - priority=priority, + rule = MaskingRule( + id = self._generate_id(), + project_id = project_id, + name = name, + rule_type = rule_type.value, + pattern = pattern or "", + replacement = replacement or "****", + description = description, + priority = priority, ) - conn = sqlite3.connect(self.db_path) - cursor = conn.cursor() + conn = sqlite3.connect(self.db_path) + cursor = conn.cursor() cursor.execute( """ @@ -764,46 +764,46 @@ class SecurityManager: # 记录审计日志 self.log_audit( - action_type=AuditActionType.DATA_MASKING, - resource_type="project", - resource_id=project_id, - action_details={"action": "create_rule", "rule_name": name}, + action_type = AuditActionType.DATA_MASKING, + resource_type = "project", + resource_id = project_id, + action_details = {"action": "create_rule", "rule_name": name}, ) return rule - def get_masking_rules(self, project_id: str, active_only: bool = True) -> list[MaskingRule]: + def get_masking_rules(self, project_id: str, active_only: bool = True) -> list[MaskingRule]: """获取脱敏规则""" - conn = sqlite3.connect(self.db_path) - cursor = conn.cursor() + conn = sqlite3.connect(self.db_path) + cursor = conn.cursor() - query = "SELECT * FROM masking_rules WHERE project_id = ?" - params = [project_id] + query = "SELECT * FROM masking_rules WHERE project_id = ?" + params = [project_id] if active_only: - query += " AND is_active = 1" + query += " AND is_active = 1" query += " ORDER BY priority DESC" cursor.execute(query, params) - rows = cursor.fetchall() + rows = cursor.fetchall() conn.close() - rules = [] + rules = [] for row in rows: rules.append( MaskingRule( - id=row[0], - project_id=row[1], - name=row[2], - rule_type=row[3], - pattern=row[4], - replacement=row[5], - is_active=bool(row[6]), - priority=row[7], - description=row[8], - created_at=row[9], - updated_at=row[10], + id = row[0], + project_id = row[1], + name = row[2], + rule_type = row[3], + pattern = row[4], + replacement = row[5], + is_active = bool(row[6]), + priority = row[7], + description = row[8], + created_at = row[9], + updated_at = row[10], ) ) @@ -811,24 +811,24 @@ class SecurityManager: def update_masking_rule(self, rule_id: str, **kwargs) -> MaskingRule | None: """更新脱敏规则""" - allowed_fields = ["name", "pattern", "replacement", "is_active", "priority", "description"] + allowed_fields = ["name", "pattern", "replacement", "is_active", "priority", "description"] - conn = sqlite3.connect(self.db_path) - cursor = conn.cursor() + conn = sqlite3.connect(self.db_path) + cursor = conn.cursor() - set_clauses = [] - params = [] + set_clauses = [] + params = [] for key, value in kwargs.items(): if key in allowed_fields: - set_clauses.append(f"{key} = ?") + set_clauses.append(f"{key} = ?") params.append(int(value) if key == "is_active" else value) if not set_clauses: conn.close() return None - set_clauses.append("updated_at = ?") + set_clauses.append("updated_at = ?") params.append(datetime.now().isoformat()) params.append(rule_id) @@ -836,7 +836,7 @@ class SecurityManager: f""" UPDATE masking_rules SET {", ".join(set_clauses)} - WHERE id = ? + WHERE id = ? """, params, ) @@ -845,52 +845,52 @@ class SecurityManager: conn.close() # 获取更新后的规则 - conn = sqlite3.connect(self.db_path) - cursor = conn.cursor() - cursor.execute("SELECT * FROM masking_rules WHERE id = ?", (rule_id,)) - row = cursor.fetchone() + conn = sqlite3.connect(self.db_path) + cursor = conn.cursor() + cursor.execute("SELECT * FROM masking_rules WHERE id = ?", (rule_id, )) + row = cursor.fetchone() conn.close() if not row: return None return MaskingRule( - id=row[0], - project_id=row[1], - name=row[2], - rule_type=row[3], - pattern=row[4], - replacement=row[5], - is_active=bool(row[6]), - priority=row[7], - description=row[8], - created_at=row[9], - updated_at=row[10], + id = row[0], + project_id = row[1], + name = row[2], + rule_type = row[3], + pattern = row[4], + replacement = row[5], + is_active = bool(row[6]), + priority = row[7], + description = row[8], + created_at = row[9], + updated_at = row[10], ) def delete_masking_rule(self, rule_id: str) -> bool: """删除脱敏规则""" - conn = sqlite3.connect(self.db_path) - cursor = conn.cursor() + conn = sqlite3.connect(self.db_path) + cursor = conn.cursor() - cursor.execute("DELETE FROM masking_rules WHERE id = ?", (rule_id,)) + cursor.execute("DELETE FROM masking_rules WHERE id = ?", (rule_id, )) - success = cursor.rowcount > 0 + success = cursor.rowcount > 0 conn.commit() conn.close() return success def apply_masking( - self, text: str, project_id: str, rule_types: list[MaskingRuleType] | None = None + self, text: str, project_id: str, rule_types: list[MaskingRuleType] | None = None ) -> str: """应用脱敏规则到文本""" - rules = self.get_masking_rules(project_id) + rules = self.get_masking_rules(project_id) if not rules: return text - masked_text = text + masked_text = text for rule in rules: # 如果指定了规则类型,只应用指定类型的规则 @@ -898,7 +898,7 @@ class SecurityManager: continue try: - masked_text = re.sub(rule.pattern, rule.replacement, masked_text) + masked_text = re.sub(rule.pattern, rule.replacement, masked_text) except re.error: # 忽略无效的正则表达式 continue @@ -909,14 +909,14 @@ class SecurityManager: self, entity_data: dict[str, Any], project_id: str ) -> dict[str, Any]: """对实体数据应用脱敏""" - masked_data = entity_data.copy() + masked_data = entity_data.copy() # 对可能包含敏感信息的字段进行脱敏 - sensitive_fields = ["name", "definition", "description", "value"] + sensitive_fields = ["name", "definition", "description", "value"] for f in sensitive_fields: if f in masked_data and isinstance(masked_data[f], str): - masked_data[f] = self.apply_masking(masked_data[f], project_id) + masked_data[f] = self.apply_masking(masked_data[f], project_id) return masked_data @@ -926,30 +926,30 @@ class SecurityManager: self, project_id: str, name: str, - description: str | None = None, - allowed_users: list[str] | None = None, - allowed_roles: list[str] | None = None, - allowed_ips: list[str] | None = None, - time_restrictions: dict | None = None, - max_access_count: int | None = None, - require_approval: bool = False, + description: str | None = None, + allowed_users: list[str] | None = None, + allowed_roles: list[str] | None = None, + allowed_ips: list[str] | None = None, + time_restrictions: dict | None = None, + max_access_count: int | None = None, + require_approval: bool = False, ) -> DataAccessPolicy: """创建数据访问策略""" - policy = DataAccessPolicy( - id=self._generate_id(), - project_id=project_id, - name=name, - description=description, - allowed_users=json.dumps(allowed_users) if allowed_users else None, - allowed_roles=json.dumps(allowed_roles) if allowed_roles else None, - allowed_ips=json.dumps(allowed_ips) if allowed_ips else None, - time_restrictions=json.dumps(time_restrictions) if time_restrictions else None, - max_access_count=max_access_count, - require_approval=require_approval, + policy = DataAccessPolicy( + id = self._generate_id(), + project_id = project_id, + name = name, + description = description, + allowed_users = json.dumps(allowed_users) if allowed_users else None, + allowed_roles = json.dumps(allowed_roles) if allowed_roles else None, + allowed_ips = json.dumps(allowed_ips) if allowed_ips else None, + time_restrictions = json.dumps(time_restrictions) if time_restrictions else None, + max_access_count = max_access_count, + require_approval = require_approval, ) - conn = sqlite3.connect(self.db_path) - cursor = conn.cursor() + conn = sqlite3.connect(self.db_path) + cursor = conn.cursor() cursor.execute( """ @@ -982,100 +982,100 @@ class SecurityManager: return policy def get_access_policies( - self, project_id: str, active_only: bool = True + self, project_id: str, active_only: bool = True ) -> list[DataAccessPolicy]: """获取数据访问策略""" - conn = sqlite3.connect(self.db_path) - cursor = conn.cursor() + conn = sqlite3.connect(self.db_path) + cursor = conn.cursor() - query = "SELECT * FROM data_access_policies WHERE project_id = ?" - params = [project_id] + query = "SELECT * FROM data_access_policies WHERE project_id = ?" + params = [project_id] if active_only: - query += " AND is_active = 1" + query += " AND is_active = 1" cursor.execute(query, params) - rows = cursor.fetchall() + rows = cursor.fetchall() conn.close() - policies = [] + policies = [] for row in rows: policies.append( DataAccessPolicy( - id=row[0], - project_id=row[1], - name=row[2], - description=row[3], - allowed_users=row[4], - allowed_roles=row[5], - allowed_ips=row[6], - time_restrictions=row[7], - max_access_count=row[8], - require_approval=bool(row[9]), - is_active=bool(row[10]), - created_at=row[11], - updated_at=row[12], + id = row[0], + project_id = row[1], + name = row[2], + description = row[3], + allowed_users = row[4], + allowed_roles = row[5], + allowed_ips = row[6], + time_restrictions = row[7], + max_access_count = row[8], + require_approval = bool(row[9]), + is_active = bool(row[10]), + created_at = row[11], + updated_at = row[12], ) ) return policies def check_access_permission( - self, policy_id: str, user_id: str, user_ip: str | None = None + self, policy_id: str, user_id: str, user_ip: str | None = None ) -> tuple[bool, str | None]: """检查访问权限""" - conn = sqlite3.connect(self.db_path) - cursor = conn.cursor() + conn = sqlite3.connect(self.db_path) + cursor = conn.cursor() cursor.execute( - "SELECT * FROM data_access_policies WHERE id = ? AND is_active = 1", (policy_id,) + "SELECT * FROM data_access_policies WHERE id = ? AND is_active = 1", (policy_id, ) ) - row = cursor.fetchone() + row = cursor.fetchone() conn.close() if not row: return False, "Policy not found or inactive" - policy = DataAccessPolicy( - id=row[0], - project_id=row[1], - name=row[2], - description=row[3], - allowed_users=row[4], - allowed_roles=row[5], - allowed_ips=row[6], - time_restrictions=row[7], - max_access_count=row[8], - require_approval=bool(row[9]), - is_active=bool(row[10]), - created_at=row[11], - updated_at=row[12], + policy = DataAccessPolicy( + id = row[0], + project_id = row[1], + name = row[2], + description = row[3], + allowed_users = row[4], + allowed_roles = row[5], + allowed_ips = row[6], + time_restrictions = row[7], + max_access_count = row[8], + require_approval = bool(row[9]), + is_active = bool(row[10]), + created_at = row[11], + updated_at = row[12], ) # 检查用户白名单 if policy.allowed_users: - allowed = json.loads(policy.allowed_users) + allowed = json.loads(policy.allowed_users) if user_id not in allowed: return False, "User not in allowed list" # 检查IP白名单 if policy.allowed_ips and user_ip: - allowed_ips = json.loads(policy.allowed_ips) - ip_allowed = False + allowed_ips = json.loads(policy.allowed_ips) + ip_allowed = False for ip_pattern in allowed_ips: if self._match_ip_pattern(user_ip, ip_pattern): - ip_allowed = True + ip_allowed = True break if not ip_allowed: return False, "IP not in allowed list" # 检查时间限制 if policy.time_restrictions: - restrictions = json.loads(policy.time_restrictions) - now = datetime.now() + restrictions = json.loads(policy.time_restrictions) + now = datetime.now() if "start_time" in restrictions and "end_time" in restrictions: - current_time = now.strftime("%H:%M") + current_time = now.strftime("%H:%M") if not (restrictions["start_time"] <= current_time <= restrictions["end_time"]): return False, "Access not allowed at this time" @@ -1086,19 +1086,19 @@ class SecurityManager: # 检查是否需要审批 if policy.require_approval: # 检查是否有有效的访问请求 - conn = sqlite3.connect(self.db_path) - cursor = conn.cursor() + conn = sqlite3.connect(self.db_path) + cursor = conn.cursor() cursor.execute( """ SELECT * FROM access_requests - WHERE policy_id = ? AND user_id = ? AND status = 'approved' + WHERE policy_id = ? AND user_id = ? AND status = 'approved' AND (expires_at IS NULL OR expires_at > ?) """, (policy_id, user_id, datetime.now().isoformat()), ) - request = cursor.fetchone() + request = cursor.fetchone() conn.close() if not request: @@ -1113,7 +1113,7 @@ class SecurityManager: try: if "/" in pattern: # CIDR 表示法 - network = ipaddress.ip_network(pattern, strict=False) + network = ipaddress.ip_network(pattern, strict = False) return ipaddress.ip_address(ip) in network else: # 精确匹配 @@ -1125,20 +1125,20 @@ class SecurityManager: self, policy_id: str, user_id: str, - request_reason: str | None = None, - expires_hours: int = 24, + request_reason: str | None = None, + expires_hours: int = 24, ) -> AccessRequest: """创建访问请求""" - request = AccessRequest( - id=self._generate_id(), - policy_id=policy_id, - user_id=user_id, - request_reason=request_reason, - expires_at=(datetime.now() + timedelta(hours=expires_hours)).isoformat(), + request = AccessRequest( + id = self._generate_id(), + policy_id = policy_id, + user_id = user_id, + request_reason = request_reason, + expires_at = (datetime.now() + timedelta(hours = expires_hours)).isoformat(), ) - conn = sqlite3.connect(self.db_path) - cursor = conn.cursor() + conn = sqlite3.connect(self.db_path) + cursor = conn.cursor() cursor.execute( """ @@ -1163,20 +1163,20 @@ class SecurityManager: return request def approve_access_request( - self, request_id: str, approved_by: str, expires_hours: int = 24 + self, request_id: str, approved_by: str, expires_hours: int = 24 ) -> AccessRequest | None: """批准访问请求""" - conn = sqlite3.connect(self.db_path) - cursor = conn.cursor() + conn = sqlite3.connect(self.db_path) + cursor = conn.cursor() - expires_at = (datetime.now() + timedelta(hours=expires_hours)).isoformat() - approved_at = datetime.now().isoformat() + expires_at = (datetime.now() + timedelta(hours = expires_hours)).isoformat() + approved_at = datetime.now().isoformat() cursor.execute( """ UPDATE access_requests - SET status = 'approved', approved_by = ?, approved_at = ?, expires_at = ? - WHERE id = ? + SET status = 'approved', approved_by = ?, approved_at = ?, expires_at = ? + WHERE id = ? """, (approved_by, approved_at, expires_at, request_id), ) @@ -1184,68 +1184,68 @@ class SecurityManager: conn.commit() # 获取更新后的请求 - cursor.execute("SELECT * FROM access_requests WHERE id = ?", (request_id,)) - row = cursor.fetchone() + cursor.execute("SELECT * FROM access_requests WHERE id = ?", (request_id, )) + row = cursor.fetchone() conn.close() if not row: return None return AccessRequest( - id=row[0], - policy_id=row[1], - user_id=row[2], - request_reason=row[3], - status=row[4], - approved_by=row[5], - approved_at=row[6], - expires_at=row[7], - created_at=row[8], + id = row[0], + policy_id = row[1], + user_id = row[2], + request_reason = row[3], + status = row[4], + approved_by = row[5], + approved_at = row[6], + expires_at = row[7], + created_at = row[8], ) def reject_access_request(self, request_id: str, rejected_by: str) -> AccessRequest | None: """拒绝访问请求""" - conn = sqlite3.connect(self.db_path) - cursor = conn.cursor() + conn = sqlite3.connect(self.db_path) + cursor = conn.cursor() cursor.execute( """ UPDATE access_requests - SET status = 'rejected', approved_by = ? - WHERE id = ? + SET status = 'rejected', approved_by = ? + WHERE id = ? """, (rejected_by, request_id), ) conn.commit() - cursor.execute("SELECT * FROM access_requests WHERE id = ?", (request_id,)) - row = cursor.fetchone() + cursor.execute("SELECT * FROM access_requests WHERE id = ?", (request_id, )) + row = cursor.fetchone() conn.close() if not row: return None return AccessRequest( - id=row[0], - policy_id=row[1], - user_id=row[2], - request_reason=row[3], - status=row[4], - approved_by=row[5], - approved_at=row[6], - expires_at=row[7], - created_at=row[8], + id = row[0], + policy_id = row[1], + user_id = row[2], + request_reason = row[3], + status = row[4], + approved_by = row[5], + approved_at = row[6], + expires_at = row[7], + created_at = row[8], ) # 全局安全管理器实例 -_security_manager = None +_security_manager = None -def get_security_manager(db_path: str = "insightflow.db") -> SecurityManager: +def get_security_manager(db_path: str = "insightflow.db") -> SecurityManager: """获取安全管理器实例""" global _security_manager if _security_manager is None: - _security_manager = SecurityManager(db_path) + _security_manager = SecurityManager(db_path) return _security_manager diff --git a/backend/subscription_manager.py b/backend/subscription_manager.py index 166febf..87f0f89 100644 --- a/backend/subscription_manager.py +++ b/backend/subscription_manager.py @@ -19,59 +19,59 @@ from datetime import datetime, timedelta from enum import StrEnum from typing import Any -logger = logging.getLogger(__name__) +logger = logging.getLogger(__name__) class SubscriptionStatus(StrEnum): """订阅状态""" - ACTIVE = "active" # 活跃 - CANCELLED = "cancelled" # 已取消 - EXPIRED = "expired" # 已过期 - PAST_DUE = "past_due" # 逾期 - TRIAL = "trial" # 试用中 - PENDING = "pending" # 待支付 + ACTIVE = "active" # 活跃 + CANCELLED = "cancelled" # 已取消 + EXPIRED = "expired" # 已过期 + PAST_DUE = "past_due" # 逾期 + TRIAL = "trial" # 试用中 + PENDING = "pending" # 待支付 class PaymentProvider(StrEnum): """支付提供商""" - STRIPE = "stripe" # Stripe - ALIPAY = "alipay" # 支付宝 - WECHAT = "wechat" # 微信支付 - BANK_TRANSFER = "bank_transfer" # 银行转账 + STRIPE = "stripe" # Stripe + ALIPAY = "alipay" # 支付宝 + WECHAT = "wechat" # 微信支付 + BANK_TRANSFER = "bank_transfer" # 银行转账 class PaymentStatus(StrEnum): """支付状态""" - PENDING = "pending" # 待支付 - PROCESSING = "processing" # 处理中 - COMPLETED = "completed" # 已完成 - FAILED = "failed" # 失败 - REFUNDED = "refunded" # 已退款 - PARTIAL_REFUNDED = "partial_refunded" # 部分退款 + PENDING = "pending" # 待支付 + PROCESSING = "processing" # 处理中 + COMPLETED = "completed" # 已完成 + FAILED = "failed" # 失败 + REFUNDED = "refunded" # 已退款 + PARTIAL_REFUNDED = "partial_refunded" # 部分退款 class InvoiceStatus(StrEnum): """发票状态""" - DRAFT = "draft" # 草稿 - ISSUED = "issued" # 已开具 - PAID = "paid" # 已支付 - OVERDUE = "overdue" # 逾期 - VOID = "void" # 作废 - CREDIT_NOTE = "credit_note" # 贷项通知单 + DRAFT = "draft" # 草稿 + ISSUED = "issued" # 已开具 + PAID = "paid" # 已支付 + OVERDUE = "overdue" # 逾期 + VOID = "void" # 作废 + CREDIT_NOTE = "credit_note" # 贷项通知单 class RefundStatus(StrEnum): """退款状态""" - PENDING = "pending" # 待处理 - APPROVED = "approved" # 已批准 - REJECTED = "rejected" # 已拒绝 - COMPLETED = "completed" # 已完成 - FAILED = "failed" # 失败 + PENDING = "pending" # 待处理 + APPROVED = "approved" # 已批准 + REJECTED = "rejected" # 已拒绝 + COMPLETED = "completed" # 已完成 + FAILED = "failed" # 失败 @dataclass @@ -218,7 +218,7 @@ class SubscriptionManager: """订阅与计费管理器""" # 默认订阅计划配置 - DEFAULT_PLANS = { + DEFAULT_PLANS = { "free": { "name": "Free", "tier": "free", @@ -298,7 +298,7 @@ class SubscriptionManager: } # 按量计费单价(CNY) - USAGE_PRICING = { + USAGE_PRICING = { "transcription": { "unit": "minute", "price": 0.5, @@ -313,22 +313,22 @@ class SubscriptionManager: "export": {"unit": "page", "price": 0.1, "free_quota": 100}, # 0.1元/页(PDF导出) } - def __init__(self, db_path: str = "insightflow.db"): - self.db_path = db_path + def __init__(self, db_path: str = "insightflow.db") -> None: + self.db_path = db_path self._init_db() self._init_default_plans() def _get_connection(self) -> sqlite3.Connection: """获取数据库连接""" - conn = sqlite3.connect(self.db_path) - conn.row_factory = sqlite3.Row + conn = sqlite3.connect(self.db_path) + conn.row_factory = sqlite3.Row return conn def _init_db(self) -> None: """初始化数据库表""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() # 订阅计划表 cursor.execute(""" @@ -528,9 +528,9 @@ class SubscriptionManager: def _init_default_plans(self) -> None: """初始化默认订阅计划""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() for tier, plan_data in self.DEFAULT_PLANS.items(): cursor.execute( @@ -569,11 +569,11 @@ class SubscriptionManager: def get_plan(self, plan_id: str) -> SubscriptionPlan | None: """获取订阅计划""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() - cursor.execute("SELECT * FROM subscription_plans WHERE id = ?", (plan_id,)) - row = cursor.fetchone() + cursor = conn.cursor() + cursor.execute("SELECT * FROM subscription_plans WHERE id = ?", (plan_id, )) + row = cursor.fetchone() if row: return self._row_to_plan(row) @@ -584,13 +584,13 @@ class SubscriptionManager: def get_plan_by_tier(self, tier: str) -> SubscriptionPlan | None: """通过层级获取订阅计划""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( - "SELECT * FROM subscription_plans WHERE tier = ? AND is_active = 1", (tier,) + "SELECT * FROM subscription_plans WHERE tier = ? AND is_active = 1", (tier, ) ) - row = cursor.fetchone() + row = cursor.fetchone() if row: return self._row_to_plan(row) @@ -599,20 +599,20 @@ class SubscriptionManager: finally: conn.close() - def list_plans(self, include_inactive: bool = False) -> list[SubscriptionPlan]: + def list_plans(self, include_inactive: bool = False) -> list[SubscriptionPlan]: """列出所有订阅计划""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() if include_inactive: cursor.execute("SELECT * FROM subscription_plans ORDER BY price_monthly") else: cursor.execute( - "SELECT * FROM subscription_plans WHERE is_active = 1 ORDER BY price_monthly" + "SELECT * FROM subscription_plans WHERE is_active = 1 ORDER BY price_monthly" ) - rows = cursor.fetchall() + rows = cursor.fetchall() return [self._row_to_plan(row) for row in rows] finally: @@ -625,32 +625,32 @@ class SubscriptionManager: description: str, price_monthly: float, price_yearly: float, - currency: str = "CNY", - features: list[str] = None, - limits: dict[str, Any] = None, + currency: str = "CNY", + features: list[str] = None, + limits: dict[str, Any] = None, ) -> SubscriptionPlan: """创建新订阅计划""" - conn = self._get_connection() + conn = self._get_connection() try: - plan_id = str(uuid.uuid4()) + plan_id = str(uuid.uuid4()) - plan = SubscriptionPlan( - id=plan_id, - name=name, - tier=tier, - description=description, - price_monthly=price_monthly, - price_yearly=price_yearly, - currency=currency, - features=features or [], - limits=limits or {}, - is_active=True, - created_at=datetime.now(), - updated_at=datetime.now(), - metadata={}, + plan = SubscriptionPlan( + id = plan_id, + name = name, + tier = tier, + description = description, + price_monthly = price_monthly, + price_yearly = price_yearly, + currency = currency, + features = features or [], + limits = limits or {}, + is_active = True, + created_at = datetime.now(), + updated_at = datetime.now(), + metadata = {}, ) - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ INSERT INTO subscription_plans @@ -688,16 +688,16 @@ class SubscriptionManager: def update_plan(self, plan_id: str, **kwargs) -> SubscriptionPlan | None: """更新订阅计划""" - conn = self._get_connection() + conn = self._get_connection() try: - plan = self.get_plan(plan_id) + plan = self.get_plan(plan_id) if not plan: return None - updates = [] - params = [] + updates = [] + params = [] - allowed_fields = [ + allowed_fields = [ "name", "description", "price_monthly", @@ -710,7 +710,7 @@ class SubscriptionManager: for key, value in kwargs.items(): if key in allowed_fields: - updates.append(f"{key} = ?") + updates.append(f"{key} = ?") if key in ["features", "limits"]: params.append(json.dumps(value) if value else "{}") elif key == "is_active": @@ -721,15 +721,15 @@ class SubscriptionManager: if not updates: return plan - updates.append("updated_at = ?") + updates.append("updated_at = ?") params.append(datetime.now()) params.append(plan_id) - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( f""" UPDATE subscription_plans SET {", ".join(updates)} - WHERE id = ? + WHERE id = ? """, params, ) @@ -746,67 +746,67 @@ class SubscriptionManager: self, tenant_id: str, plan_id: str, - payment_provider: str | None = None, - trial_days: int = 0, - billing_cycle: str = "monthly", + payment_provider: str | None = None, + trial_days: int = 0, + billing_cycle: str = "monthly", ) -> Subscription: """创建新订阅""" - conn = self._get_connection() + conn = self._get_connection() try: # 检查是否已有活跃订阅 - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ SELECT * FROM subscriptions - WHERE tenant_id = ? AND status IN ('active', 'trial', 'pending') + WHERE tenant_id = ? AND status IN ('active', 'trial', 'pending') """, - (tenant_id,), + (tenant_id, ), ) - existing = cursor.fetchone() + existing = cursor.fetchone() if existing: raise ValueError(f"Tenant {tenant_id} already has an active subscription") # 获取计划信息 - plan = self.get_plan(plan_id) + plan = self.get_plan(plan_id) if not plan: raise ValueError(f"Plan {plan_id} not found") - subscription_id = str(uuid.uuid4()) - now = datetime.now() + subscription_id = str(uuid.uuid4()) + now = datetime.now() # 计算周期 if billing_cycle == "yearly": - period_end = now + timedelta(days=365) + period_end = now + timedelta(days = 365) else: - period_end = now + timedelta(days=30) + period_end = now + timedelta(days = 30) # 试用处理 - trial_start = None - trial_end = None + trial_start = None + trial_end = None if trial_days > 0: - trial_start = now - trial_end = now + timedelta(days=trial_days) - status = SubscriptionStatus.TRIAL.value + trial_start = now + trial_end = now + timedelta(days = trial_days) + status = SubscriptionStatus.TRIAL.value else: - status = SubscriptionStatus.PENDING.value + status = SubscriptionStatus.PENDING.value - subscription = Subscription( - id=subscription_id, - tenant_id=tenant_id, - plan_id=plan_id, - status=status, - current_period_start=now, - current_period_end=period_end, - cancel_at_period_end=False, - canceled_at=None, - trial_start=trial_start, - trial_end=trial_end, - payment_provider=payment_provider, - provider_subscription_id=None, - created_at=now, - updated_at=now, - metadata={"billing_cycle": billing_cycle}, + subscription = Subscription( + id = subscription_id, + tenant_id = tenant_id, + plan_id = plan_id, + status = status, + current_period_start = now, + current_period_end = period_end, + cancel_at_period_end = False, + canceled_at = None, + trial_start = trial_start, + trial_end = trial_end, + payment_provider = payment_provider, + provider_subscription_id = None, + created_at = now, + updated_at = now, + metadata = {"billing_cycle": billing_cycle}, ) cursor.execute( @@ -837,7 +837,7 @@ class SubscriptionManager: ) # 创建发票 - amount = plan.price_yearly if billing_cycle == "yearly" else plan.price_monthly + amount = plan.price_yearly if billing_cycle == "yearly" else plan.price_monthly if amount > 0 and trial_days == 0: self._create_invoice_internal( conn, @@ -875,11 +875,11 @@ class SubscriptionManager: def get_subscription(self, subscription_id: str) -> Subscription | None: """获取订阅信息""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() - cursor.execute("SELECT * FROM subscriptions WHERE id = ?", (subscription_id,)) - row = cursor.fetchone() + cursor = conn.cursor() + cursor.execute("SELECT * FROM subscriptions WHERE id = ?", (subscription_id, )) + row = cursor.fetchone() if row: return self._row_to_subscription(row) @@ -890,18 +890,18 @@ class SubscriptionManager: def get_tenant_subscription(self, tenant_id: str) -> Subscription | None: """获取租户的当前订阅""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ SELECT * FROM subscriptions - WHERE tenant_id = ? AND status IN ('active', 'trial', 'past_due', 'pending') + WHERE tenant_id = ? AND status IN ('active', 'trial', 'past_due', 'pending') ORDER BY created_at DESC LIMIT 1 """, - (tenant_id,), + (tenant_id, ), ) - row = cursor.fetchone() + row = cursor.fetchone() if row: return self._row_to_subscription(row) @@ -912,16 +912,16 @@ class SubscriptionManager: def update_subscription(self, subscription_id: str, **kwargs) -> Subscription | None: """更新订阅""" - conn = self._get_connection() + conn = self._get_connection() try: - subscription = self.get_subscription(subscription_id) + subscription = self.get_subscription(subscription_id) if not subscription: return None - updates = [] - params = [] + updates = [] + params = [] - allowed_fields = [ + allowed_fields = [ "status", "current_period_start", "current_period_end", @@ -934,7 +934,7 @@ class SubscriptionManager: for key, value in kwargs.items(): if key in allowed_fields: - updates.append(f"{key} = ?") + updates.append(f"{key} = ?") if key == "cancel_at_period_end": params.append(int(value)) else: @@ -943,15 +943,15 @@ class SubscriptionManager: if not updates: return subscription - updates.append("updated_at = ?") + updates.append("updated_at = ?") params.append(datetime.now()) params.append(subscription_id) - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( f""" UPDATE subscriptions SET {", ".join(updates)} - WHERE id = ? + WHERE id = ? """, params, ) @@ -963,36 +963,36 @@ class SubscriptionManager: conn.close() def cancel_subscription( - self, subscription_id: str, at_period_end: bool = True + self, subscription_id: str, at_period_end: bool = True ) -> Subscription | None: """取消订阅""" - conn = self._get_connection() + conn = self._get_connection() try: - subscription = self.get_subscription(subscription_id) + subscription = self.get_subscription(subscription_id) if not subscription: return None - now = datetime.now() + now = datetime.now() if at_period_end: # 在周期结束时取消 - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ UPDATE subscriptions - SET cancel_at_period_end = 1, canceled_at = ?, updated_at = ? - WHERE id = ? + SET cancel_at_period_end = 1, canceled_at = ?, updated_at = ? + WHERE id = ? """, (now, now, subscription_id), ) else: # 立即取消 - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ UPDATE subscriptions - SET status = 'cancelled', canceled_at = ?, updated_at = ? - WHERE id = ? + SET status = 'cancelled', canceled_at = ?, updated_at = ? + WHERE id = ? """, (now, now, subscription_id), ) @@ -1017,34 +1017,34 @@ class SubscriptionManager: conn.close() def change_plan( - self, subscription_id: str, new_plan_id: str, prorate: bool = True + self, subscription_id: str, new_plan_id: str, prorate: bool = True ) -> Subscription | None: """更改订阅计划""" - conn = self._get_connection() + conn = self._get_connection() try: - subscription = self.get_subscription(subscription_id) + subscription = self.get_subscription(subscription_id) if not subscription: return None - old_plan = self.get_plan(subscription.plan_id) - new_plan = self.get_plan(new_plan_id) + old_plan = self.get_plan(subscription.plan_id) + new_plan = self.get_plan(new_plan_id) if not new_plan: raise ValueError(f"Plan {new_plan_id} not found") - now = datetime.now() + now = datetime.now() # 按比例计算差价(简化实现) if prorate and old_plan: # 这里应该实现实际的按比例计算逻辑 pass - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ UPDATE subscriptions - SET plan_id = ?, updated_at = ? - WHERE id = ? + SET plan_id = ?, updated_at = ? + WHERE id = ? """, (new_plan_id, now, subscription_id), ) @@ -1076,29 +1076,29 @@ class SubscriptionManager: resource_type: str, quantity: float, unit: str, - description: str | None = None, - metadata: dict | None = None, + description: str | None = None, + metadata: dict | None = None, ) -> UsageRecord: """记录用量""" - conn = self._get_connection() + conn = self._get_connection() try: # 计算费用 - cost = self._calculate_usage_cost(resource_type, quantity) + cost = self._calculate_usage_cost(resource_type, quantity) - record_id = str(uuid.uuid4()) - record = UsageRecord( - id=record_id, - tenant_id=tenant_id, - resource_type=resource_type, - quantity=quantity, - unit=unit, - recorded_at=datetime.now(), - cost=cost, - description=description, - metadata=metadata or {}, + record_id = str(uuid.uuid4()) + record = UsageRecord( + id = record_id, + tenant_id = tenant_id, + resource_type = resource_type, + quantity = quantity, + unit = unit, + recorded_at = datetime.now(), + cost = cost, + description = description, + metadata = metadata or {}, ) - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ INSERT INTO usage_records @@ -1125,23 +1125,23 @@ class SubscriptionManager: conn.close() def get_usage_summary( - self, tenant_id: str, start_date: datetime | None = None, end_date: datetime | None = None + self, tenant_id: str, start_date: datetime | None = None, end_date: datetime | None = None ) -> dict[str, Any]: """获取用量汇总""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() - query = """ + query = """ SELECT resource_type, SUM(quantity) as total_quantity, SUM(cost) as total_cost, COUNT(*) as record_count FROM usage_records - WHERE tenant_id = ? + WHERE tenant_id = ? """ - params = [tenant_id] + params = [tenant_id] if start_date: query += " AND recorded_at >= ?" @@ -1153,13 +1153,13 @@ class SubscriptionManager: query += " GROUP BY resource_type" cursor.execute(query, params) - rows = cursor.fetchall() + rows = cursor.fetchall() - summary = {} - total_cost = 0 + summary = {} + total_cost = 0 for row in rows: - summary[row["resource_type"]] = { + summary[row["resource_type"]] = { "quantity": row["total_quantity"], "cost": row["total_cost"], "records": row["record_count"], @@ -1181,12 +1181,12 @@ class SubscriptionManager: def _calculate_usage_cost(self, resource_type: str, quantity: float) -> float: """计算用量费用""" - pricing = self.USAGE_PRICING.get(resource_type) + pricing = self.USAGE_PRICING.get(resource_type) if not pricing: return 0.0 # 扣除免费额度 - chargeable = max(0, quantity - pricing.get("free_quota", 0)) + chargeable = max(0, quantity - pricing.get("free_quota", 0)) # 计算费用 if pricing["unit"] == "1000_calls": @@ -1202,37 +1202,37 @@ class SubscriptionManager: amount: float, currency: str, provider: str, - subscription_id: str | None = None, - invoice_id: str | None = None, - payment_method: str | None = None, - payment_details: dict | None = None, + subscription_id: str | None = None, + invoice_id: str | None = None, + payment_method: str | None = None, + payment_details: dict | None = None, ) -> Payment: """创建支付记录""" - conn = self._get_connection() + conn = self._get_connection() try: - payment_id = str(uuid.uuid4()) - now = datetime.now() + payment_id = str(uuid.uuid4()) + now = datetime.now() - payment = Payment( - id=payment_id, - tenant_id=tenant_id, - subscription_id=subscription_id, - invoice_id=invoice_id, - amount=amount, - currency=currency, - provider=provider, - provider_payment_id=None, - status=PaymentStatus.PENDING.value, - payment_method=payment_method, - payment_details=payment_details or {}, - paid_at=None, - failed_at=None, - failure_reason=None, - created_at=now, - updated_at=now, + payment = Payment( + id = payment_id, + tenant_id = tenant_id, + subscription_id = subscription_id, + invoice_id = invoice_id, + amount = amount, + currency = currency, + provider = provider, + provider_payment_id = None, + status = PaymentStatus.PENDING.value, + payment_method = payment_method, + payment_details = payment_details or {}, + paid_at = None, + failed_at = None, + failure_reason = None, + created_at = now, + updated_at = now, ) - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ INSERT INTO payments @@ -1268,23 +1268,23 @@ class SubscriptionManager: conn.close() def confirm_payment( - self, payment_id: str, provider_payment_id: str | None = None + self, payment_id: str, provider_payment_id: str | None = None ) -> Payment | None: """确认支付完成""" - conn = self._get_connection() + conn = self._get_connection() try: - payment = self._get_payment_internal(conn, payment_id) + payment = self._get_payment_internal(conn, payment_id) if not payment: return None - now = datetime.now() + now = datetime.now() - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ UPDATE payments - SET status = 'completed', provider_payment_id = ?, paid_at = ?, updated_at = ? - WHERE id = ? + SET status = 'completed', provider_payment_id = ?, paid_at = ?, updated_at = ? + WHERE id = ? """, (provider_payment_id, now, now, payment_id), ) @@ -1294,8 +1294,8 @@ class SubscriptionManager: cursor.execute( """ UPDATE invoices - SET status = 'paid', amount_paid = amount_due, paid_at = ? - WHERE id = ? + SET status = 'paid', amount_paid = amount_due, paid_at = ? + WHERE id = ? """, (now, payment.invoice_id), ) @@ -1305,8 +1305,8 @@ class SubscriptionManager: cursor.execute( """ UPDATE subscriptions - SET status = 'active', updated_at = ? - WHERE id = ? AND status = 'pending' + SET status = 'active', updated_at = ? + WHERE id = ? AND status = 'pending' """, (now, payment.subscription_id), ) @@ -1332,16 +1332,16 @@ class SubscriptionManager: def fail_payment(self, payment_id: str, reason: str) -> Payment | None: """标记支付失败""" - conn = self._get_connection() + conn = self._get_connection() try: - now = datetime.now() + now = datetime.now() - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ UPDATE payments - SET status = 'failed', failure_reason = ?, failed_at = ?, updated_at = ? - WHERE id = ? + SET status = 'failed', failure_reason = ?, failed_at = ?, updated_at = ? + WHERE id = ? """, (reason, now, now, payment_id), ) @@ -1354,32 +1354,32 @@ class SubscriptionManager: def get_payment(self, payment_id: str) -> Payment | None: """获取支付记录""" - conn = self._get_connection() + conn = self._get_connection() try: return self._get_payment_internal(conn, payment_id) finally: conn.close() def list_payments( - self, tenant_id: str, status: str | None = None, limit: int = 100, offset: int = 0 + self, tenant_id: str, status: str | None = None, limit: int = 100, offset: int = 0 ) -> list[Payment]: """列出支付记录""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() - query = "SELECT * FROM payments WHERE tenant_id = ?" - params = [tenant_id] + query = "SELECT * FROM payments WHERE tenant_id = ?" + params = [tenant_id] if status: - query += " AND status = ?" + query += " AND status = ?" params.append(status) query += " ORDER BY created_at DESC LIMIT ? OFFSET ?" params.extend([limit, offset]) cursor.execute(query, params) - rows = cursor.fetchall() + rows = cursor.fetchall() return [self._row_to_payment(row) for row in rows] @@ -1388,9 +1388,9 @@ class SubscriptionManager: def _get_payment_internal(self, conn: sqlite3.Connection, payment_id: str) -> Payment | None: """内部方法:获取支付记录""" - cursor = conn.cursor() - cursor.execute("SELECT * FROM payments WHERE id = ?", (payment_id,)) - row = cursor.fetchone() + cursor = conn.cursor() + cursor.execute("SELECT * FROM payments WHERE id = ?", (payment_id, )) + row = cursor.fetchone() if row: return self._row_to_payment(row) @@ -1408,36 +1408,36 @@ class SubscriptionManager: period_start: datetime, period_end: datetime, description: str, - line_items: list[dict] | None = None, + line_items: list[dict] | None = None, ) -> Invoice: """内部方法:创建发票""" - invoice_id = str(uuid.uuid4()) - invoice_number = self._generate_invoice_number() - now = datetime.now() - due_date = now + timedelta(days=7) # 7天付款期限 + invoice_id = str(uuid.uuid4()) + invoice_number = self._generate_invoice_number() + now = datetime.now() + due_date = now + timedelta(days = 7) # 7天付款期限 - invoice = Invoice( - id=invoice_id, - tenant_id=tenant_id, - subscription_id=subscription_id, - invoice_number=invoice_number, - status=InvoiceStatus.DRAFT.value, - amount_due=amount, - amount_paid=0, - currency=currency, - period_start=period_start, - period_end=period_end, - description=description, - line_items=line_items or [{"description": description, "amount": amount}], - due_date=due_date, - paid_at=None, - voided_at=None, - void_reason=None, - created_at=now, - updated_at=now, + invoice = Invoice( + id = invoice_id, + tenant_id = tenant_id, + subscription_id = subscription_id, + invoice_number = invoice_number, + status = InvoiceStatus.DRAFT.value, + amount_due = amount, + amount_paid = 0, + currency = currency, + period_start = period_start, + period_end = period_end, + description = description, + line_items = line_items or [{"description": description, "amount": amount}], + due_date = due_date, + paid_at = None, + voided_at = None, + void_reason = None, + created_at = now, + updated_at = now, ) - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ INSERT INTO invoices @@ -1472,11 +1472,11 @@ class SubscriptionManager: def get_invoice(self, invoice_id: str) -> Invoice | None: """获取发票""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() - cursor.execute("SELECT * FROM invoices WHERE id = ?", (invoice_id,)) - row = cursor.fetchone() + cursor = conn.cursor() + cursor.execute("SELECT * FROM invoices WHERE id = ?", (invoice_id, )) + row = cursor.fetchone() if row: return self._row_to_invoice(row) @@ -1487,11 +1487,11 @@ class SubscriptionManager: def get_invoice_by_number(self, invoice_number: str) -> Invoice | None: """通过发票号获取发票""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() - cursor.execute("SELECT * FROM invoices WHERE invoice_number = ?", (invoice_number,)) - row = cursor.fetchone() + cursor = conn.cursor() + cursor.execute("SELECT * FROM invoices WHERE invoice_number = ?", (invoice_number, )) + row = cursor.fetchone() if row: return self._row_to_invoice(row) @@ -1501,25 +1501,25 @@ class SubscriptionManager: conn.close() def list_invoices( - self, tenant_id: str, status: str | None = None, limit: int = 100, offset: int = 0 + self, tenant_id: str, status: str | None = None, limit: int = 100, offset: int = 0 ) -> list[Invoice]: """列出发票""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() - query = "SELECT * FROM invoices WHERE tenant_id = ?" - params = [tenant_id] + query = "SELECT * FROM invoices WHERE tenant_id = ?" + params = [tenant_id] if status: - query += " AND status = ?" + query += " AND status = ?" params.append(status) query += " ORDER BY created_at DESC LIMIT ? OFFSET ?" params.extend([limit, offset]) cursor.execute(query, params) - rows = cursor.fetchall() + rows = cursor.fetchall() return [self._row_to_invoice(row) for row in rows] @@ -1528,23 +1528,23 @@ class SubscriptionManager: def void_invoice(self, invoice_id: str, reason: str) -> Invoice | None: """作废发票""" - conn = self._get_connection() + conn = self._get_connection() try: - invoice = self.get_invoice(invoice_id) + invoice = self.get_invoice(invoice_id) if not invoice: return None if invoice.status == InvoiceStatus.PAID.value: raise ValueError("Cannot void a paid invoice") - now = datetime.now() + now = datetime.now() - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ UPDATE invoices - SET status = 'void', voided_at = ?, void_reason = ?, updated_at = ? - WHERE id = ? + SET status = 'void', voided_at = ?, void_reason = ?, updated_at = ? + WHERE id = ? """, (now, reason, now, invoice_id), ) @@ -1557,21 +1557,21 @@ class SubscriptionManager: def _generate_invoice_number(self) -> str: """生成发票号""" - now = datetime.now() - prefix = f"INV-{now.strftime('%Y%m')}" + now = datetime.now() + prefix = f"INV-{now.strftime('%Y%m')}" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ SELECT COUNT(*) as count FROM invoices WHERE invoice_number LIKE ? """, - (f"{prefix}%",), + (f"{prefix}%", ), ) - row = cursor.fetchone() - count = row["count"] + 1 + row = cursor.fetchone() + count = row["count"] + 1 return f"{prefix}-{count:06d}" @@ -1584,10 +1584,10 @@ class SubscriptionManager: self, tenant_id: str, payment_id: str, amount: float, reason: str, requested_by: str ) -> Refund: """申请退款""" - conn = self._get_connection() + conn = self._get_connection() try: # 验证支付记录 - payment = self._get_payment_internal(conn, payment_id) + payment = self._get_payment_internal(conn, payment_id) if not payment: raise ValueError(f"Payment {payment_id} not found") @@ -1600,30 +1600,30 @@ class SubscriptionManager: if amount > payment.amount: raise ValueError("Refund amount cannot exceed payment amount") - refund_id = str(uuid.uuid4()) - now = datetime.now() + refund_id = str(uuid.uuid4()) + now = datetime.now() - refund = Refund( - id=refund_id, - tenant_id=tenant_id, - payment_id=payment_id, - invoice_id=payment.invoice_id, - amount=amount, - currency=payment.currency, - reason=reason, - status=RefundStatus.PENDING.value, - requested_by=requested_by, - requested_at=now, - approved_by=None, - approved_at=None, - completed_at=None, - provider_refund_id=None, - metadata={}, - created_at=now, - updated_at=now, + refund = Refund( + id = refund_id, + tenant_id = tenant_id, + payment_id = payment_id, + invoice_id = payment.invoice_id, + amount = amount, + currency = payment.currency, + reason = reason, + status = RefundStatus.PENDING.value, + requested_by = requested_by, + requested_at = now, + approved_by = None, + approved_at = None, + completed_at = None, + provider_refund_id = None, + metadata = {}, + created_at = now, + updated_at = now, ) - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ INSERT INTO refunds @@ -1662,23 +1662,23 @@ class SubscriptionManager: def approve_refund(self, refund_id: str, approved_by: str) -> Refund | None: """批准退款""" - conn = self._get_connection() + conn = self._get_connection() try: - refund = self._get_refund_internal(conn, refund_id) + refund = self._get_refund_internal(conn, refund_id) if not refund: return None if refund.status != RefundStatus.PENDING.value: raise ValueError("Can only approve pending refunds") - now = datetime.now() + now = datetime.now() - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ UPDATE refunds - SET status = 'approved', approved_by = ?, approved_at = ?, updated_at = ? - WHERE id = ? + SET status = 'approved', approved_by = ?, approved_at = ?, updated_at = ? + WHERE id = ? """, (approved_by, now, now, refund_id), ) @@ -1690,23 +1690,23 @@ class SubscriptionManager: conn.close() def complete_refund( - self, refund_id: str, provider_refund_id: str | None = None + self, refund_id: str, provider_refund_id: str | None = None ) -> Refund | None: """完成退款""" - conn = self._get_connection() + conn = self._get_connection() try: - refund = self._get_refund_internal(conn, refund_id) + refund = self._get_refund_internal(conn, refund_id) if not refund: return None - now = datetime.now() + now = datetime.now() - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ UPDATE refunds - SET status = 'completed', provider_refund_id = ?, completed_at = ?, updated_at = ? - WHERE id = ? + SET status = 'completed', provider_refund_id = ?, completed_at = ?, updated_at = ? + WHERE id = ? """, (provider_refund_id, now, now, refund_id), ) @@ -1715,8 +1715,8 @@ class SubscriptionManager: cursor.execute( """ UPDATE payments - SET status = 'refunded', updated_at = ? - WHERE id = ? + SET status = 'refunded', updated_at = ? + WHERE id = ? """, (now, refund.payment_id), ) @@ -1742,20 +1742,20 @@ class SubscriptionManager: def reject_refund(self, refund_id: str, reason: str) -> Refund | None: """拒绝退款""" - conn = self._get_connection() + conn = self._get_connection() try: - refund = self._get_refund_internal(conn, refund_id) + refund = self._get_refund_internal(conn, refund_id) if not refund: return None - now = datetime.now() + now = datetime.now() - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ UPDATE refunds - SET status = 'rejected', metadata = json_set(metadata, '$.rejection_reason', ?), updated_at = ? - WHERE id = ? + SET status = 'rejected', metadata = json_set(metadata, '$.rejection_reason', ?), updated_at = ? + WHERE id = ? """, (reason, now, refund_id), ) @@ -1768,32 +1768,32 @@ class SubscriptionManager: def get_refund(self, refund_id: str) -> Refund | None: """获取退款记录""" - conn = self._get_connection() + conn = self._get_connection() try: return self._get_refund_internal(conn, refund_id) finally: conn.close() def list_refunds( - self, tenant_id: str, status: str | None = None, limit: int = 100, offset: int = 0 + self, tenant_id: str, status: str | None = None, limit: int = 100, offset: int = 0 ) -> list[Refund]: """列出退款记录""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() - query = "SELECT * FROM refunds WHERE tenant_id = ?" - params = [tenant_id] + query = "SELECT * FROM refunds WHERE tenant_id = ?" + params = [tenant_id] if status: - query += " AND status = ?" + query += " AND status = ?" params.append(status) query += " ORDER BY created_at DESC LIMIT ? OFFSET ?" params.extend([limit, offset]) cursor.execute(query, params) - rows = cursor.fetchall() + rows = cursor.fetchall() return [self._row_to_refund(row) for row in rows] @@ -1802,9 +1802,9 @@ class SubscriptionManager: def _get_refund_internal(self, conn: sqlite3.Connection, refund_id: str) -> Refund | None: """内部方法:获取退款记录""" - cursor = conn.cursor() - cursor.execute("SELECT * FROM refunds WHERE id = ?", (refund_id,)) - row = cursor.fetchone() + cursor = conn.cursor() + cursor.execute("SELECT * FROM refunds WHERE id = ?", (refund_id, )) + row = cursor.fetchone() if row: return self._row_to_refund(row) @@ -1822,11 +1822,11 @@ class SubscriptionManager: description: str, reference_id: str, balance_after: float, - ): + ) -> None: """内部方法:添加账单历史""" - history_id = str(uuid.uuid4()) + history_id = str(uuid.uuid4()) - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ INSERT INTO billing_history @@ -1850,18 +1850,18 @@ class SubscriptionManager: def get_billing_history( self, tenant_id: str, - start_date: datetime | None = None, - end_date: datetime | None = None, - limit: int = 100, - offset: int = 0, + start_date: datetime | None = None, + end_date: datetime | None = None, + limit: int = 100, + offset: int = 0, ) -> list[BillingHistory]: """获取账单历史""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() - query = "SELECT * FROM billing_history WHERE tenant_id = ?" - params = [tenant_id] + query = "SELECT * FROM billing_history WHERE tenant_id = ?" + params = [tenant_id] if start_date: query += " AND created_at >= ?" @@ -1874,7 +1874,7 @@ class SubscriptionManager: params.extend([limit, offset]) cursor.execute(query, params) - rows = cursor.fetchall() + rows = cursor.fetchall() return [self._row_to_billing_history(row) for row in rows] @@ -1889,7 +1889,7 @@ class SubscriptionManager: plan_id: str, success_url: str, cancel_url: str, - billing_cycle: str = "monthly", + billing_cycle: str = "monthly", ) -> dict[str, Any]: """创建 Stripe Checkout 会话(占位实现)""" # 这里应该集成 Stripe SDK @@ -1902,12 +1902,12 @@ class SubscriptionManager: } def create_alipay_order( - self, tenant_id: str, plan_id: str, billing_cycle: str = "monthly" + self, tenant_id: str, plan_id: str, billing_cycle: str = "monthly" ) -> dict[str, Any]: """创建支付宝订单(占位实现)""" # 这里应该集成支付宝 SDK - plan = self.get_plan(plan_id) - amount = plan.price_yearly if billing_cycle == "yearly" else plan.price_monthly + plan = self.get_plan(plan_id) + amount = plan.price_yearly if billing_cycle == "yearly" else plan.price_monthly return { "order_id": f"ALI{datetime.now().strftime('%Y%m%d%H%M%S')}{uuid.uuid4().hex[:8].upper()}", @@ -1919,12 +1919,12 @@ class SubscriptionManager: } def create_wechat_order( - self, tenant_id: str, plan_id: str, billing_cycle: str = "monthly" + self, tenant_id: str, plan_id: str, billing_cycle: str = "monthly" ) -> dict[str, Any]: """创建微信支付订单(占位实现)""" # 这里应该集成微信支付 SDK - plan = self.get_plan(plan_id) - amount = plan.price_yearly if billing_cycle == "yearly" else plan.price_monthly + plan = self.get_plan(plan_id) + amount = plan.price_yearly if billing_cycle == "yearly" else plan.price_monthly return { "order_id": f"WX{datetime.now().strftime('%Y%m%d%H%M%S')}{uuid.uuid4().hex[:8].upper()}", @@ -1940,7 +1940,7 @@ class SubscriptionManager: # 这里应该实现实际的 Webhook 处理逻辑 logger.info(f"Received webhook from {provider}: {payload.get('event_type', 'unknown')}") - event_type = payload.get("event_type", "") + event_type = payload.get("event_type", "") if provider == "stripe": if event_type == "checkout.session.completed": @@ -1962,126 +1962,126 @@ class SubscriptionManager: def _row_to_plan(self, row: sqlite3.Row) -> SubscriptionPlan: """数据库行转换为 SubscriptionPlan 对象""" return SubscriptionPlan( - id=row["id"], - name=row["name"], - tier=row["tier"], - description=row["description"] or "", - price_monthly=row["price_monthly"], - price_yearly=row["price_yearly"], - currency=row["currency"], - features=json.loads(row["features"] or "[]"), - limits=json.loads(row["limits"] or "{}"), - is_active=bool(row["is_active"]), - created_at=( + id = row["id"], + name = row["name"], + tier = row["tier"], + description = row["description"] or "", + price_monthly = row["price_monthly"], + price_yearly = row["price_yearly"], + currency = row["currency"], + features = json.loads(row["features"] or "[]"), + limits = json.loads(row["limits"] or "{}"), + is_active = bool(row["is_active"]), + created_at = ( datetime.fromisoformat(row["created_at"]) if isinstance(row["created_at"], str) else row["created_at"] ), - updated_at=( + updated_at = ( datetime.fromisoformat(row["updated_at"]) if isinstance(row["updated_at"], str) else row["updated_at"] ), - metadata=json.loads(row["metadata"] or "{}"), + metadata = json.loads(row["metadata"] or "{}"), ) def _row_to_subscription(self, row: sqlite3.Row) -> Subscription: """数据库行转换为 Subscription 对象""" return Subscription( - id=row["id"], - tenant_id=row["tenant_id"], - plan_id=row["plan_id"], - status=row["status"], - current_period_start=( + id = row["id"], + tenant_id = row["tenant_id"], + plan_id = row["plan_id"], + status = row["status"], + current_period_start = ( datetime.fromisoformat(row["current_period_start"]) if row["current_period_start"] and isinstance(row["current_period_start"], str) else row["current_period_start"] ), - current_period_end=( + current_period_end = ( datetime.fromisoformat(row["current_period_end"]) if row["current_period_end"] and isinstance(row["current_period_end"], str) else row["current_period_end"] ), - cancel_at_period_end=bool(row["cancel_at_period_end"]), - canceled_at=( + cancel_at_period_end = bool(row["cancel_at_period_end"]), + canceled_at = ( datetime.fromisoformat(row["canceled_at"]) if row["canceled_at"] and isinstance(row["canceled_at"], str) else row["canceled_at"] ), - trial_start=( + trial_start = ( datetime.fromisoformat(row["trial_start"]) if row["trial_start"] and isinstance(row["trial_start"], str) else row["trial_start"] ), - trial_end=( + trial_end = ( datetime.fromisoformat(row["trial_end"]) if row["trial_end"] and isinstance(row["trial_end"], str) else row["trial_end"] ), - payment_provider=row["payment_provider"], - provider_subscription_id=row["provider_subscription_id"], - created_at=( + payment_provider = row["payment_provider"], + provider_subscription_id = row["provider_subscription_id"], + created_at = ( datetime.fromisoformat(row["created_at"]) if isinstance(row["created_at"], str) else row["created_at"] ), - updated_at=( + updated_at = ( datetime.fromisoformat(row["updated_at"]) if isinstance(row["updated_at"], str) else row["updated_at"] ), - metadata=json.loads(row["metadata"] or "{}"), + metadata = json.loads(row["metadata"] or "{}"), ) def _row_to_usage(self, row: sqlite3.Row) -> UsageRecord: """数据库行转换为 UsageRecord 对象""" return UsageRecord( - id=row["id"], - tenant_id=row["tenant_id"], - resource_type=row["resource_type"], - quantity=row["quantity"], - unit=row["unit"], - recorded_at=( + id = row["id"], + tenant_id = row["tenant_id"], + resource_type = row["resource_type"], + quantity = row["quantity"], + unit = row["unit"], + recorded_at = ( datetime.fromisoformat(row["recorded_at"]) if isinstance(row["recorded_at"], str) else row["recorded_at"] ), - cost=row["cost"], - description=row["description"], - metadata=json.loads(row["metadata"] or "{}"), + cost = row["cost"], + description = row["description"], + metadata = json.loads(row["metadata"] or "{}"), ) def _row_to_payment(self, row: sqlite3.Row) -> Payment: """数据库行转换为 Payment 对象""" return Payment( - id=row["id"], - tenant_id=row["tenant_id"], - subscription_id=row["subscription_id"], - invoice_id=row["invoice_id"], - amount=row["amount"], - currency=row["currency"], - provider=row["provider"], - provider_payment_id=row["provider_payment_id"], - status=row["status"], - payment_method=row["payment_method"], - payment_details=json.loads(row["payment_details"] or "{}"), - paid_at=( + id = row["id"], + tenant_id = row["tenant_id"], + subscription_id = row["subscription_id"], + invoice_id = row["invoice_id"], + amount = row["amount"], + currency = row["currency"], + provider = row["provider"], + provider_payment_id = row["provider_payment_id"], + status = row["status"], + payment_method = row["payment_method"], + payment_details = json.loads(row["payment_details"] or "{}"), + paid_at = ( datetime.fromisoformat(row["paid_at"]) if row["paid_at"] and isinstance(row["paid_at"], str) else row["paid_at"] ), - failed_at=( + failed_at = ( datetime.fromisoformat(row["failed_at"]) if row["failed_at"] and isinstance(row["failed_at"], str) else row["failed_at"] ), - failure_reason=row["failure_reason"], - created_at=( + failure_reason = row["failure_reason"], + created_at = ( datetime.fromisoformat(row["created_at"]) if isinstance(row["created_at"], str) else row["created_at"] ), - updated_at=( + updated_at = ( datetime.fromisoformat(row["updated_at"]) if isinstance(row["updated_at"], str) else row["updated_at"] @@ -2091,48 +2091,48 @@ class SubscriptionManager: def _row_to_invoice(self, row: sqlite3.Row) -> Invoice: """数据库行转换为 Invoice 对象""" return Invoice( - id=row["id"], - tenant_id=row["tenant_id"], - subscription_id=row["subscription_id"], - invoice_number=row["invoice_number"], - status=row["status"], - amount_due=row["amount_due"], - amount_paid=row["amount_paid"], - currency=row["currency"], - period_start=( + id = row["id"], + tenant_id = row["tenant_id"], + subscription_id = row["subscription_id"], + invoice_number = row["invoice_number"], + status = row["status"], + amount_due = row["amount_due"], + amount_paid = row["amount_paid"], + currency = row["currency"], + period_start = ( datetime.fromisoformat(row["period_start"]) if row["period_start"] and isinstance(row["period_start"], str) else row["period_start"] ), - period_end=( + period_end = ( datetime.fromisoformat(row["period_end"]) if row["period_end"] and isinstance(row["period_end"], str) else row["period_end"] ), - description=row["description"], - line_items=json.loads(row["line_items"] or "[]"), - due_date=( + description = row["description"], + line_items = json.loads(row["line_items"] or "[]"), + due_date = ( datetime.fromisoformat(row["due_date"]) if row["due_date"] and isinstance(row["due_date"], str) else row["due_date"] ), - paid_at=( + paid_at = ( datetime.fromisoformat(row["paid_at"]) if row["paid_at"] and isinstance(row["paid_at"], str) else row["paid_at"] ), - voided_at=( + voided_at = ( datetime.fromisoformat(row["voided_at"]) if row["voided_at"] and isinstance(row["voided_at"], str) else row["voided_at"] ), - void_reason=row["void_reason"], - created_at=( + void_reason = row["void_reason"], + created_at = ( datetime.fromisoformat(row["created_at"]) if isinstance(row["created_at"], str) else row["created_at"] ), - updated_at=( + updated_at = ( datetime.fromisoformat(row["updated_at"]) if isinstance(row["updated_at"], str) else row["updated_at"] @@ -2142,39 +2142,39 @@ class SubscriptionManager: def _row_to_refund(self, row: sqlite3.Row) -> Refund: """数据库行转换为 Refund 对象""" return Refund( - id=row["id"], - tenant_id=row["tenant_id"], - payment_id=row["payment_id"], - invoice_id=row["invoice_id"], - amount=row["amount"], - currency=row["currency"], - reason=row["reason"], - status=row["status"], - requested_by=row["requested_by"], - requested_at=( + id = row["id"], + tenant_id = row["tenant_id"], + payment_id = row["payment_id"], + invoice_id = row["invoice_id"], + amount = row["amount"], + currency = row["currency"], + reason = row["reason"], + status = row["status"], + requested_by = row["requested_by"], + requested_at = ( datetime.fromisoformat(row["requested_at"]) if isinstance(row["requested_at"], str) else row["requested_at"] ), - approved_by=row["approved_by"], - approved_at=( + approved_by = row["approved_by"], + approved_at = ( datetime.fromisoformat(row["approved_at"]) if row["approved_at"] and isinstance(row["approved_at"], str) else row["approved_at"] ), - completed_at=( + completed_at = ( datetime.fromisoformat(row["completed_at"]) if row["completed_at"] and isinstance(row["completed_at"], str) else row["completed_at"] ), - provider_refund_id=row["provider_refund_id"], - metadata=json.loads(row["metadata"] or "{}"), - created_at=( + provider_refund_id = row["provider_refund_id"], + metadata = json.loads(row["metadata"] or "{}"), + created_at = ( datetime.fromisoformat(row["created_at"]) if isinstance(row["created_at"], str) else row["created_at"] ), - updated_at=( + updated_at = ( datetime.fromisoformat(row["updated_at"]) if isinstance(row["updated_at"], str) else row["updated_at"] @@ -2184,30 +2184,30 @@ class SubscriptionManager: def _row_to_billing_history(self, row: sqlite3.Row) -> BillingHistory: """数据库行转换为 BillingHistory 对象""" return BillingHistory( - id=row["id"], - tenant_id=row["tenant_id"], - type=row["type"], - amount=row["amount"], - currency=row["currency"], - description=row["description"], - reference_id=row["reference_id"], - balance_after=row["balance_after"], - created_at=( + id = row["id"], + tenant_id = row["tenant_id"], + type = row["type"], + amount = row["amount"], + currency = row["currency"], + description = row["description"], + reference_id = row["reference_id"], + balance_after = row["balance_after"], + created_at = ( datetime.fromisoformat(row["created_at"]) if isinstance(row["created_at"], str) else row["created_at"] ), - metadata=json.loads(row["metadata"] or "{}"), + metadata = json.loads(row["metadata"] or "{}"), ) # 全局订阅管理器实例 -subscription_manager = None +subscription_manager = None -def get_subscription_manager(db_path: str = "insightflow.db") -> SubscriptionManager: +def get_subscription_manager(db_path: str = "insightflow.db") -> SubscriptionManager: """获取订阅管理器实例(单例模式)""" global subscription_manager if subscription_manager is None: - subscription_manager = SubscriptionManager(db_path) + subscription_manager = SubscriptionManager(db_path) return subscription_manager diff --git a/backend/tenant_manager.py b/backend/tenant_manager.py index 4611248..26c2034 100644 --- a/backend/tenant_manager.py +++ b/backend/tenant_manager.py @@ -21,63 +21,63 @@ from datetime import datetime from enum import StrEnum from typing import Any -logger = logging.getLogger(__name__) +logger = logging.getLogger(__name__) class TenantLimits: """租户资源限制常量""" - FREE_MAX_PROJECTS = 3 - FREE_MAX_STORAGE_MB = 100 - FREE_MAX_TRANSCRIPTION_MINUTES = 60 - FREE_MAX_API_CALLS_PER_DAY = 100 - FREE_MAX_TEAM_MEMBERS = 2 - FREE_MAX_ENTITIES = 100 + FREE_MAX_PROJECTS = 3 + FREE_MAX_STORAGE_MB = 100 + FREE_MAX_TRANSCRIPTION_MINUTES = 60 + FREE_MAX_API_CALLS_PER_DAY = 100 + FREE_MAX_TEAM_MEMBERS = 2 + FREE_MAX_ENTITIES = 100 - PRO_MAX_PROJECTS = 20 - PRO_MAX_STORAGE_MB = 1000 - PRO_MAX_TRANSCRIPTION_MINUTES = 600 - PRO_MAX_API_CALLS_PER_DAY = 10000 - PRO_MAX_TEAM_MEMBERS = 10 - PRO_MAX_ENTITIES = 1000 + PRO_MAX_PROJECTS = 20 + PRO_MAX_STORAGE_MB = 1000 + PRO_MAX_TRANSCRIPTION_MINUTES = 600 + PRO_MAX_API_CALLS_PER_DAY = 10000 + PRO_MAX_TEAM_MEMBERS = 10 + PRO_MAX_ENTITIES = 1000 - UNLIMITED = -1 + UNLIMITED = -1 class TenantStatus(StrEnum): """租户状态""" - ACTIVE = "active" # 活跃 - SUSPENDED = "suspended" # 暂停 - TRIAL = "trial" # 试用 - EXPIRED = "expired" # 过期 - PENDING = "pending" # 待激活 + ACTIVE = "active" # 活跃 + SUSPENDED = "suspended" # 暂停 + TRIAL = "trial" # 试用 + EXPIRED = "expired" # 过期 + PENDING = "pending" # 待激活 class TenantTier(StrEnum): """租户订阅层级""" - FREE = "free" # 免费版 - PRO = "pro" # 专业版 - ENTERPRISE = "enterprise" # 企业版 + FREE = "free" # 免费版 + PRO = "pro" # 专业版 + ENTERPRISE = "enterprise" # 企业版 class TenantRole(StrEnum): """租户角色""" - OWNER = "owner" # 所有者 - ADMIN = "admin" # 管理员 - MEMBER = "member" # 成员 - VIEWER = "viewer" # 查看者 + OWNER = "owner" # 所有者 + ADMIN = "admin" # 管理员 + MEMBER = "member" # 成员 + VIEWER = "viewer" # 查看者 class DomainStatus(StrEnum): """域名状态""" - PENDING = "pending" # 待验证 - VERIFIED = "verified" # 已验证 - FAILED = "failed" # 验证失败 - EXPIRED = "expired" # 已过期 + PENDING = "pending" # 待验证 + VERIFIED = "verified" # 已验证 + FAILED = "failed" # 验证失败 + EXPIRED = "expired" # 已过期 @dataclass @@ -171,7 +171,7 @@ class TenantManager: """租户管理器 - 多租户 SaaS 架构核心""" # 默认资源限制配置 - 使用常量 - DEFAULT_LIMITS = { + DEFAULT_LIMITS = { TenantTier.FREE: { "max_projects": TenantLimits.FREE_MAX_PROJECTS, "max_storage_mb": TenantLimits.FREE_MAX_STORAGE_MB, @@ -209,7 +209,7 @@ class TenantManager: } # 角色权限映射 - ROLE_PERMISSIONS = { + ROLE_PERMISSIONS = { TenantRole.OWNER: [ "tenant:*", "project:*", @@ -240,7 +240,7 @@ class TenantManager: } # 权限名称映射 - PERMISSION_NAMES = { + PERMISSION_NAMES = { "tenant:*": "租户完全控制", "tenant:read": "查看租户信息", "project:*": "项目完全控制", @@ -257,21 +257,21 @@ class TenantManager: "export:basic": "基础导出", } - def __init__(self, db_path: str = "insightflow.db"): - self.db_path = db_path + def __init__(self, db_path: str = "insightflow.db") -> None: + self.db_path = db_path self._init_db() def _get_connection(self) -> sqlite3.Connection: """获取数据库连接""" - conn = sqlite3.connect(self.db_path) - conn.row_factory = sqlite3.Row + conn = sqlite3.connect(self.db_path) + conn.row_factory = sqlite3.Row return conn def _init_db(self) -> None: """初始化数据库表""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() # 租户主表 cursor.execute(""" @@ -418,41 +418,41 @@ class TenantManager: self, name: str, owner_id: str, - tier: str = "free", - description: str | None = None, - settings: dict | None = None, + tier: str = "free", + description: str | None = None, + settings: dict | None = None, ) -> Tenant: """创建新租户""" - conn = self._get_connection() + conn = self._get_connection() try: - tenant_id = str(uuid.uuid4()) - slug = self._generate_slug(name) + tenant_id = str(uuid.uuid4()) + slug = self._generate_slug(name) # 获取对应层级的资源限制 - tier_enum = ( + tier_enum = ( TenantTier(tier) if tier in [t.value for t in TenantTier] else TenantTier.FREE ) - resource_limits = self.DEFAULT_LIMITS.get( + resource_limits = self.DEFAULT_LIMITS.get( tier_enum, self.DEFAULT_LIMITS[TenantTier.FREE] ) - tenant = Tenant( - id=tenant_id, - name=name, - slug=slug, - description=description, - tier=tier, - status=TenantStatus.PENDING.value, - owner_id=owner_id, - created_at=datetime.now(), - updated_at=datetime.now(), - expires_at=None, - settings=settings or {}, - resource_limits=resource_limits, - metadata={}, + tenant = Tenant( + id = tenant_id, + name = name, + slug = slug, + description = description, + tier = tier, + status = TenantStatus.PENDING.value, + owner_id = owner_id, + created_at = datetime.now(), + updated_at = datetime.now(), + expires_at = None, + settings = settings or {}, + resource_limits = resource_limits, + metadata = {}, ) - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ INSERT INTO tenants (id, name, slug, description, tier, status, owner_id, @@ -492,11 +492,11 @@ class TenantManager: def get_tenant(self, tenant_id: str) -> Tenant | None: """获取租户信息""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() - cursor.execute("SELECT * FROM tenants WHERE id = ?", (tenant_id,)) - row = cursor.fetchone() + cursor = conn.cursor() + cursor.execute("SELECT * FROM tenants WHERE id = ?", (tenant_id, )) + row = cursor.fetchone() if row: return self._row_to_tenant(row) @@ -507,11 +507,11 @@ class TenantManager: def get_tenant_by_slug(self, slug: str) -> Tenant | None: """通过 slug 获取租户""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() - cursor.execute("SELECT * FROM tenants WHERE slug = ?", (slug,)) - row = cursor.fetchone() + cursor = conn.cursor() + cursor.execute("SELECT * FROM tenants WHERE slug = ?", (slug, )) + row = cursor.fetchone() if row: return self._row_to_tenant(row) @@ -522,18 +522,18 @@ class TenantManager: def get_tenant_by_domain(self, domain: str) -> Tenant | None: """通过自定义域名获取租户""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ SELECT t.* FROM tenants t - JOIN tenant_domains d ON t.id = d.tenant_id - WHERE d.domain = ? AND d.status = 'verified' + JOIN tenant_domains d ON t.id = d.tenant_id + WHERE d.domain = ? AND d.status = 'verified' """, - (domain,), + (domain, ), ) - row = cursor.fetchone() + row = cursor.fetchone() if row: return self._row_to_tenant(row) @@ -545,51 +545,51 @@ class TenantManager: def update_tenant( self, tenant_id: str, - name: str | None = None, - description: str | None = None, - tier: str | None = None, - status: str | None = None, - settings: dict | None = None, + name: str | None = None, + description: str | None = None, + tier: str | None = None, + status: str | None = None, + settings: dict | None = None, ) -> Tenant | None: """更新租户信息""" - conn = self._get_connection() + conn = self._get_connection() try: - tenant = self.get_tenant(tenant_id) + tenant = self.get_tenant(tenant_id) if not tenant: return None - updates = [] - params = [] + updates = [] + params = [] if name is not None: - updates.append("name = ?") + updates.append("name = ?") params.append(name) if description is not None: - updates.append("description = ?") + updates.append("description = ?") params.append(description) if tier is not None: - updates.append("tier = ?") + updates.append("tier = ?") params.append(tier) # 更新资源限制 - tier_enum = TenantTier(tier) - updates.append("resource_limits = ?") + tier_enum = TenantTier(tier) + updates.append("resource_limits = ?") params.append(json.dumps(self.DEFAULT_LIMITS.get(tier_enum, {}))) if status is not None: - updates.append("status = ?") + updates.append("status = ?") params.append(status) if settings is not None: - updates.append("settings = ?") + updates.append("settings = ?") params.append(json.dumps(settings)) - updates.append("updated_at = ?") + updates.append("updated_at = ?") params.append(datetime.now()) params.append(tenant_id) - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( f""" UPDATE tenants SET {", ".join(updates)} - WHERE id = ? + WHERE id = ? """, params, ) @@ -602,38 +602,38 @@ class TenantManager: def delete_tenant(self, tenant_id: str) -> bool: """删除租户(软删除或硬删除)""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() - cursor.execute("DELETE FROM tenants WHERE id = ?", (tenant_id,)) + cursor = conn.cursor() + cursor.execute("DELETE FROM tenants WHERE id = ?", (tenant_id, )) conn.commit() return cursor.rowcount > 0 finally: conn.close() def list_tenants( - self, status: str | None = None, tier: str | None = None, limit: int = 100, offset: int = 0 + self, status: str | None = None, tier: str | None = None, limit: int = 100, offset: int = 0 ) -> list[Tenant]: """列出租户""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() - query = "SELECT * FROM tenants WHERE 1=1" - params = [] + query = "SELECT * FROM tenants WHERE 1 = 1" + params = [] if status: - query += " AND status = ?" + query += " AND status = ?" params.append(status) if tier: - query += " AND tier = ?" + query += " AND tier = ?" params.append(tier) query += " ORDER BY created_at DESC LIMIT ? OFFSET ?" params.extend([limit, offset]) cursor.execute(query, params) - rows = cursor.fetchall() + rows = cursor.fetchall() return [self._row_to_tenant(row) for row in rows] @@ -646,45 +646,45 @@ class TenantManager: self, tenant_id: str, domain: str, - is_primary: bool = False, - verification_method: str = "dns", + is_primary: bool = False, + verification_method: str = "dns", ) -> TenantDomain: """为租户添加自定义域名""" - conn = self._get_connection() + conn = self._get_connection() try: # 验证域名格式 if not self._validate_domain(domain): raise ValueError(f"Invalid domain format: {domain}") # 生成验证令牌 - verification_token = self._generate_verification_token(tenant_id, domain) + verification_token = self._generate_verification_token(tenant_id, domain) - domain_id = str(uuid.uuid4()) - tenant_domain = TenantDomain( - id=domain_id, - tenant_id=tenant_id, - domain=domain.lower(), - status=DomainStatus.PENDING.value, - verification_token=verification_token, - verification_method=verification_method, - verified_at=None, - created_at=datetime.now(), - updated_at=datetime.now(), - is_primary=is_primary, - ssl_enabled=False, - ssl_expires_at=None, + domain_id = str(uuid.uuid4()) + tenant_domain = TenantDomain( + id = domain_id, + tenant_id = tenant_id, + domain = domain.lower(), + status = DomainStatus.PENDING.value, + verification_token = verification_token, + verification_method = verification_method, + verified_at = None, + created_at = datetime.now(), + updated_at = datetime.now(), + is_primary = is_primary, + ssl_enabled = False, + ssl_expires_at = None, ) - cursor = conn.cursor() + cursor = conn.cursor() # 如果设为主域名,取消其他主域名 if is_primary: cursor.execute( """ - UPDATE tenant_domains SET is_primary = 0 - WHERE tenant_id = ? + UPDATE tenant_domains SET is_primary = 0 + WHERE tenant_id = ? """, - (tenant_id,), + (tenant_id, ), ) cursor.execute( @@ -723,36 +723,36 @@ class TenantManager: def verify_domain(self, tenant_id: str, domain_id: str) -> bool: """验证域名所有权""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() # 获取域名信息 cursor.execute( """ SELECT * FROM tenant_domains - WHERE id = ? AND tenant_id = ? + WHERE id = ? AND tenant_id = ? """, (domain_id, tenant_id), ) - row = cursor.fetchone() + row = cursor.fetchone() if not row: return False - domain = row["domain"] - token = row["verification_token"] - method = row["verification_method"] + domain = row["domain"] + token = row["verification_token"] + method = row["verification_method"] # 执行验证 - is_verified = self._check_domain_verification(domain, token, method) + is_verified = self._check_domain_verification(domain, token, method) if is_verified: cursor.execute( """ UPDATE tenant_domains - SET status = 'verified', verified_at = ?, updated_at = ? - WHERE id = ? + SET status = 'verified', verified_at = ?, updated_at = ? + WHERE id = ? """, (datetime.now(), datetime.now(), domain_id), ) @@ -762,8 +762,8 @@ class TenantManager: cursor.execute( """ UPDATE tenant_domains - SET status = 'failed', updated_at = ? - WHERE id = ? + SET status = 'failed', updated_at = ? + WHERE id = ? """, (datetime.now(), domain_id), ) @@ -779,17 +779,17 @@ class TenantManager: def get_domain_verification_instructions(self, domain_id: str) -> dict[str, Any]: """获取域名验证指导""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() - cursor.execute("SELECT * FROM tenant_domains WHERE id = ?", (domain_id,)) - row = cursor.fetchone() + cursor = conn.cursor() + cursor.execute("SELECT * FROM tenant_domains WHERE id = ?", (domain_id, )) + row = cursor.fetchone() if not row: return None - domain = row["domain"] - token = row["verification_token"] + domain = row["domain"] + token = row["verification_token"] return { "domain": domain, @@ -797,7 +797,7 @@ class TenantManager: "dns_record": { "type": "TXT", "name": "_insightflow", - "value": f"insightflow-verify={token}", + "value": f"insightflow-verify = {token}", "ttl": 3600, }, "file_verification": { @@ -805,7 +805,7 @@ class TenantManager: "content": token, }, "instructions": [ - f"DNS 验证: 添加 TXT 记录 _insightflow.{domain},值为 insightflow-verify={token}", + f"DNS 验证: 添加 TXT 记录 _insightflow.{domain},值为 insightflow-verify = {token}", f"文件验证: 在网站根目录创建 .well-known/insightflow-verify.txt,内容为 {token}", ], } @@ -815,13 +815,13 @@ class TenantManager: def remove_domain(self, tenant_id: str, domain_id: str) -> bool: """移除域名绑定""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ DELETE FROM tenant_domains - WHERE id = ? AND tenant_id = ? + WHERE id = ? AND tenant_id = ? """, (domain_id, tenant_id), ) @@ -832,18 +832,18 @@ class TenantManager: def list_domains(self, tenant_id: str) -> list[TenantDomain]: """列出租户的所有域名""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ SELECT * FROM tenant_domains - WHERE tenant_id = ? + WHERE tenant_id = ? ORDER BY is_primary DESC, created_at DESC """, - (tenant_id,), + (tenant_id, ), ) - rows = cursor.fetchall() + rows = cursor.fetchall() return [self._row_to_domain(row) for row in rows] @@ -854,11 +854,11 @@ class TenantManager: def get_branding(self, tenant_id: str) -> TenantBranding | None: """获取租户品牌配置""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() - cursor.execute("SELECT * FROM tenant_branding WHERE tenant_id = ?", (tenant_id,)) - row = cursor.fetchone() + cursor = conn.cursor() + cursor.execute("SELECT * FROM tenant_branding WHERE tenant_id = ?", (tenant_id, )) + row = cursor.fetchone() if row: return self._row_to_branding(row) @@ -870,68 +870,68 @@ class TenantManager: def update_branding( self, tenant_id: str, - logo_url: str | None = None, - favicon_url: str | None = None, - primary_color: str | None = None, - secondary_color: str | None = None, - custom_css: str | None = None, - custom_js: str | None = None, - login_page_bg: str | None = None, - email_template: str | None = None, + logo_url: str | None = None, + favicon_url: str | None = None, + primary_color: str | None = None, + secondary_color: str | None = None, + custom_css: str | None = None, + custom_js: str | None = None, + login_page_bg: str | None = None, + email_template: str | None = None, ) -> TenantBranding: """更新租户品牌配置""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() # 检查是否已存在 - cursor.execute("SELECT id FROM tenant_branding WHERE tenant_id = ?", (tenant_id,)) - existing = cursor.fetchone() + cursor.execute("SELECT id FROM tenant_branding WHERE tenant_id = ?", (tenant_id, )) + existing = cursor.fetchone() if existing: # 更新 - updates = [] - params = [] + updates = [] + params = [] if logo_url is not None: - updates.append("logo_url = ?") + updates.append("logo_url = ?") params.append(logo_url) if favicon_url is not None: - updates.append("favicon_url = ?") + updates.append("favicon_url = ?") params.append(favicon_url) if primary_color is not None: - updates.append("primary_color = ?") + updates.append("primary_color = ?") params.append(primary_color) if secondary_color is not None: - updates.append("secondary_color = ?") + updates.append("secondary_color = ?") params.append(secondary_color) if custom_css is not None: - updates.append("custom_css = ?") + updates.append("custom_css = ?") params.append(custom_css) if custom_js is not None: - updates.append("custom_js = ?") + updates.append("custom_js = ?") params.append(custom_js) if login_page_bg is not None: - updates.append("login_page_bg = ?") + updates.append("login_page_bg = ?") params.append(login_page_bg) if email_template is not None: - updates.append("email_template = ?") + updates.append("email_template = ?") params.append(email_template) - updates.append("updated_at = ?") + updates.append("updated_at = ?") params.append(datetime.now()) params.append(tenant_id) cursor.execute( f""" UPDATE tenant_branding SET {", ".join(updates)} - WHERE tenant_id = ? + WHERE tenant_id = ? """, params, ) else: # 创建 - branding_id = str(uuid.uuid4()) + branding_id = str(uuid.uuid4()) cursor.execute( """ INSERT INTO tenant_branding @@ -963,11 +963,11 @@ class TenantManager: def get_branding_css(self, tenant_id: str) -> str: """生成品牌 CSS""" - branding = self.get_branding(tenant_id) + branding = self.get_branding(tenant_id) if not branding: return "" - css = [] + css = [] if branding.primary_color: css.append(f""" @@ -1007,35 +1007,35 @@ class TenantManager: email: str, role: str, invited_by: str, - permissions: list[str] | None = None, + permissions: list[str] | None = None, ) -> TenantMember: """邀请成员加入租户""" - conn = self._get_connection() + conn = self._get_connection() try: - member_id = str(uuid.uuid4()) + member_id = str(uuid.uuid4()) # 使用角色默认权限 - role_enum = ( + role_enum = ( TenantRole(role) if role in [r.value for r in TenantRole] else TenantRole.MEMBER ) - default_permissions = self.ROLE_PERMISSIONS.get(role_enum, []) - final_permissions = permissions or default_permissions + default_permissions = self.ROLE_PERMISSIONS.get(role_enum, []) + final_permissions = permissions or default_permissions - member = TenantMember( - id=member_id, - tenant_id=tenant_id, - user_id="pending", # 临时值,待用户接受邀请后更新 - email=email, - role=role, - permissions=final_permissions, - invited_by=invited_by, - invited_at=datetime.now(), - joined_at=None, - last_active_at=None, - status="pending", + member = TenantMember( + id = member_id, + tenant_id = tenant_id, + user_id = "pending", # 临时值,待用户接受邀请后更新 + email = email, + role = role, + permissions = final_permissions, + invited_by = invited_by, + invited_at = datetime.now(), + joined_at = None, + last_active_at = None, + status = "pending", ) - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ INSERT INTO tenant_members @@ -1067,14 +1067,14 @@ class TenantManager: def accept_invitation(self, invitation_id: str, user_id: str) -> bool: """接受邀请""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ UPDATE tenant_members - SET user_id = ?, status = 'active', joined_at = ? - WHERE id = ? AND status = 'pending' + SET user_id = ?, status = 'active', joined_at = ? + WHERE id = ? AND status = 'pending' """, (user_id, datetime.now(), invitation_id), ) @@ -1087,13 +1087,13 @@ class TenantManager: def remove_member(self, tenant_id: str, member_id: str) -> bool: """移除成员""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ DELETE FROM tenant_members - WHERE id = ? AND tenant_id = ? + WHERE id = ? AND tenant_id = ? """, (member_id, tenant_id), ) @@ -1103,21 +1103,21 @@ class TenantManager: conn.close() def update_member_role( - self, tenant_id: str, member_id: str, role: str, permissions: list[str] | None = None + self, tenant_id: str, member_id: str, role: str, permissions: list[str] | None = None ) -> bool: """更新成员角色""" - conn = self._get_connection() + conn = self._get_connection() try: - role_enum = TenantRole(role) - default_permissions = self.ROLE_PERMISSIONS.get(role_enum, []) - final_permissions = permissions or default_permissions + role_enum = TenantRole(role) + default_permissions = self.ROLE_PERMISSIONS.get(role_enum, []) + final_permissions = permissions or default_permissions - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ UPDATE tenant_members - SET role = ?, permissions = ?, updated_at = ? - WHERE id = ? AND tenant_id = ? + SET role = ?, permissions = ?, updated_at = ? + WHERE id = ? AND tenant_id = ? """, (role, json.dumps(final_permissions), datetime.now(), member_id, tenant_id), ) @@ -1128,23 +1128,23 @@ class TenantManager: finally: conn.close() - def list_members(self, tenant_id: str, status: str | None = None) -> list[TenantMember]: + def list_members(self, tenant_id: str, status: str | None = None) -> list[TenantMember]: """列出租户成员""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() - query = "SELECT * FROM tenant_members WHERE tenant_id = ?" - params = [tenant_id] + query = "SELECT * FROM tenant_members WHERE tenant_id = ?" + params = [tenant_id] if status: - query += " AND status = ?" + query += " AND status = ?" params.append(status) query += " ORDER BY invited_at DESC" cursor.execute(query, params) - rows = cursor.fetchall() + rows = cursor.fetchall() return [self._row_to_member(row) for row in rows] @@ -1153,31 +1153,31 @@ class TenantManager: def check_permission(self, tenant_id: str, user_id: str, resource: str, action: str) -> bool: """检查用户是否有特定权限""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ SELECT role, permissions FROM tenant_members - WHERE tenant_id = ? AND user_id = ? AND status = 'active' + WHERE tenant_id = ? AND user_id = ? AND status = 'active' """, (tenant_id, user_id), ) - row = cursor.fetchone() + row = cursor.fetchone() if not row: return False - role = row["role"] - permissions = json.loads(row["permissions"] or "[]") + role = row["role"] + permissions = json.loads(row["permissions"] or "[]") # 所有者拥有所有权限 if role == TenantRole.OWNER.value: return True # 检查具体权限 - required = f"{resource}:{action}" - wildcard = f"{resource}:*" + required = f"{resource}:{action}" + wildcard = f"{resource}:*" return required in permissions or wildcard in permissions or "*" in permissions @@ -1186,24 +1186,24 @@ class TenantManager: def get_user_tenants(self, user_id: str) -> list[dict[str, Any]]: """获取用户所属的所有租户""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ SELECT t.*, m.role, m.status as member_status FROM tenants t - JOIN tenant_members m ON t.id = m.tenant_id - WHERE m.user_id = ? AND m.status = 'active' + JOIN tenant_members m ON t.id = m.tenant_id + WHERE m.user_id = ? AND m.status = 'active' ORDER BY t.created_at DESC """, - (user_id,), + (user_id, ), ) - rows = cursor.fetchall() + rows = cursor.fetchall() - result = [] + result = [] for row in rows: - tenant = self._row_to_tenant(row) + tenant = self._row_to_tenant(row) result.append( { **asdict(tenant), @@ -1221,20 +1221,20 @@ class TenantManager: def record_usage( self, tenant_id: str, - storage_bytes: int = 0, - transcription_seconds: int = 0, - api_calls: int = 0, - projects_count: int = 0, - entities_count: int = 0, - members_count: int = 0, - ): + storage_bytes: int = 0, + transcription_seconds: int = 0, + api_calls: int = 0, + projects_count: int = 0, + entities_count: int = 0, + members_count: int = 0, + ) -> None: """记录资源使用""" - conn = self._get_connection() + conn = self._get_connection() try: - today = datetime.now().date() - usage_id = str(uuid.uuid4()) + today = datetime.now().date() + usage_id = str(uuid.uuid4()) - cursor = conn.cursor() + cursor = conn.cursor() cursor.execute( """ INSERT INTO tenant_usage @@ -1242,12 +1242,12 @@ class TenantManager: projects_count, entities_count, members_count) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) ON CONFLICT(tenant_id, date) DO UPDATE SET - storage_bytes = storage_bytes + excluded.storage_bytes, - transcription_seconds = transcription_seconds + excluded.transcription_seconds, - api_calls = api_calls + excluded.api_calls, - projects_count = MAX(projects_count, excluded.projects_count), - entities_count = MAX(entities_count, excluded.entities_count), - members_count = MAX(members_count, excluded.members_count) + storage_bytes = storage_bytes + excluded.storage_bytes, + transcription_seconds = transcription_seconds + excluded.transcription_seconds, + api_calls = api_calls + excluded.api_calls, + projects_count = MAX(projects_count, excluded.projects_count), + entities_count = MAX(entities_count, excluded.entities_count), + members_count = MAX(members_count, excluded.members_count) """, ( usage_id, @@ -1268,14 +1268,14 @@ class TenantManager: conn.close() def get_usage_stats( - self, tenant_id: str, start_date: datetime | None = None, end_date: datetime | None = None + self, tenant_id: str, start_date: datetime | None = None, end_date: datetime | None = None ) -> dict[str, Any]: """获取使用统计""" - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() + cursor = conn.cursor() - query = """ + query = """ SELECT SUM(storage_bytes) as total_storage, SUM(transcription_seconds) as total_transcription, @@ -1284,9 +1284,9 @@ class TenantManager: MAX(entities_count) as max_entities, MAX(members_count) as max_members FROM tenant_usage - WHERE tenant_id = ? + WHERE tenant_id = ? """ - params = [tenant_id] + params = [tenant_id] if start_date: query += " AND date >= ?" @@ -1296,11 +1296,11 @@ class TenantManager: params.append(end_date.date()) cursor.execute(query, params) - row = cursor.fetchone() + row = cursor.fetchone() # 获取租户限制 - tenant = self.get_tenant(tenant_id) - limits = tenant.resource_limits if tenant else {} + tenant = self.get_tenant(tenant_id) + limits = tenant.resource_limits if tenant else {} return { "storage_bytes": row["total_storage"] or 0, @@ -1344,14 +1344,14 @@ class TenantManager: Returns: (是否允许, 当前使用量, 限制值) """ - tenant = self.get_tenant(tenant_id) + tenant = self.get_tenant(tenant_id) if not tenant: return False, 0, 0 - limits = tenant.resource_limits - stats = self.get_usage_stats(tenant_id) + limits = tenant.resource_limits + stats = self.get_usage_stats(tenant_id) - resource_map = { + resource_map = { "storage": ("storage_mb", stats["storage_mb"]), "transcription": ("max_transcription_minutes", stats["transcription_minutes"]), "api_calls": ("max_api_calls_per_day", stats["api_calls"]), @@ -1363,8 +1363,8 @@ class TenantManager: if resource_type not in resource_map: return True, 0, -1 - limit_key, current = resource_map[resource_type] - limit = limits.get(limit_key, 0) + limit_key, current = resource_map[resource_type] + limit = limits.get(limit_key, 0) # -1 表示无限制 if limit == -1: @@ -1377,21 +1377,21 @@ class TenantManager: def _generate_slug(self, name: str) -> str: """生成 URL 友好的 slug""" # 转换为小写,替换空格为连字符 - slug = re.sub(r"[^\w\s-]", "", name.lower()) - slug = re.sub(r"[-\s]+", "-", slug) + slug = re.sub(r"[^\w\s-]", "", name.lower()) + slug = re.sub(r"[-\s]+", "-", slug) # 检查是否已存在 - conn = self._get_connection() + conn = self._get_connection() try: - cursor = conn.cursor() - base_slug = slug - counter = 1 + cursor = conn.cursor() + base_slug = slug + counter = 1 while True: - cursor.execute("SELECT id FROM tenants WHERE slug = ?", (slug,)) + cursor.execute("SELECT id FROM tenants WHERE slug = ?", (slug, )) if not cursor.fetchone(): break - slug = f"{base_slug}-{counter}" + slug = f"{base_slug}-{counter}" counter += 1 return slug @@ -1401,12 +1401,12 @@ class TenantManager: def _generate_verification_token(self, tenant_id: str, domain: str) -> str: """生成域名验证令牌""" - data = f"{tenant_id}:{domain}:{datetime.now().isoformat()}" + data = f"{tenant_id}:{domain}:{datetime.now().isoformat()}" return hashlib.sha256(data.encode()).hexdigest()[:32] def _validate_domain(self, domain: str) -> bool: """验证域名格式""" - pattern = r"^(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)*[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])$" + pattern = r"^(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0, 61}[a-zA-Z0-9])?\.)*[a-zA-Z0-9](?:[a-zA-Z0-9-]{0, 61}[a-zA-Z0-9])$" return bool(re.match(pattern, domain)) def _check_domain_verification(self, domain: str, token: str, method: str) -> bool: @@ -1419,7 +1419,7 @@ class TenantManager: # TODO: 实现 DNS TXT 记录查询 # import dns.resolver # try: - # answers = dns.resolver.resolve(f"_insightflow.{domain}", 'TXT') + # answers = dns.resolver.resolve(f"_insightflow.{domain}", 'TXT') # for rdata in answers: # if token in str(rdata): # return True @@ -1431,7 +1431,7 @@ class TenantManager: # TODO: 实现 HTTP 文件验证 # import requests # try: - # response = requests.get(f"http://{domain}/.well-known/insightflow-verify.txt", timeout=10) + # response = requests.get(f"http://{domain}/.well-known/insightflow-verify.txt", timeout = 10) # if response.status_code == 200 and token in response.text: # return True # except (ImportError, Exception): @@ -1442,14 +1442,14 @@ class TenantManager: def _darken_color(self, hex_color: str, percent: int) -> str: """加深颜色""" - hex_color = hex_color.lstrip("#") - r = int(hex_color[0:2], 16) - g = int(hex_color[2:4], 16) - b = int(hex_color[4:6], 16) + hex_color = hex_color.lstrip("#") + r = int(hex_color[0:2], 16) + g = int(hex_color[2:4], 16) + b = int(hex_color[4:6], 16) - r = int(r * (100 - percent) / 100) - g = int(g * (100 - percent) / 100) - b = int(b * (100 - percent) / 100) + r = int(r * (100 - percent) / 100) + g = int(g * (100 - percent) / 100) + b = int(b * (100 - percent) / 100) return f"#{r:02x}{g:02x}{b:02x}" @@ -1467,10 +1467,10 @@ class TenantManager: email: str, role: TenantRole, invited_by: str | None, - ): + ) -> None: """内部方法:添加成员""" - cursor = conn.cursor() - member_id = str(uuid.uuid4()) + cursor = conn.cursor() + member_id = str(uuid.uuid4()) cursor.execute( """ @@ -1497,60 +1497,60 @@ class TenantManager: def _row_to_tenant(self, row: sqlite3.Row) -> Tenant: """数据库行转换为 Tenant 对象""" return Tenant( - id=row["id"], - name=row["name"], - slug=row["slug"], - description=row["description"], - tier=row["tier"], - status=row["status"], - owner_id=row["owner_id"], - created_at=( + id = row["id"], + name = row["name"], + slug = row["slug"], + description = row["description"], + tier = row["tier"], + status = row["status"], + owner_id = row["owner_id"], + created_at = ( datetime.fromisoformat(row["created_at"]) if isinstance(row["created_at"], str) else row["created_at"] ), - updated_at=( + updated_at = ( datetime.fromisoformat(row["updated_at"]) if isinstance(row["updated_at"], str) else row["updated_at"] ), - expires_at=( + expires_at = ( datetime.fromisoformat(row["expires_at"]) if row["expires_at"] and isinstance(row["expires_at"], str) else row["expires_at"] ), - settings=json.loads(row["settings"] or "{}"), - resource_limits=json.loads(row["resource_limits"] or "{}"), - metadata=json.loads(row["metadata"] or "{}"), + settings = json.loads(row["settings"] or "{}"), + resource_limits = json.loads(row["resource_limits"] or "{}"), + metadata = json.loads(row["metadata"] or "{}"), ) def _row_to_domain(self, row: sqlite3.Row) -> TenantDomain: """数据库行转换为 TenantDomain 对象""" return TenantDomain( - id=row["id"], - tenant_id=row["tenant_id"], - domain=row["domain"], - status=row["status"], - verification_token=row["verification_token"], - verification_method=row["verification_method"], - verified_at=( + id = row["id"], + tenant_id = row["tenant_id"], + domain = row["domain"], + status = row["status"], + verification_token = row["verification_token"], + verification_method = row["verification_method"], + verified_at = ( datetime.fromisoformat(row["verified_at"]) if row["verified_at"] and isinstance(row["verified_at"], str) else row["verified_at"] ), - created_at=( + created_at = ( datetime.fromisoformat(row["created_at"]) if isinstance(row["created_at"], str) else row["created_at"] ), - updated_at=( + updated_at = ( datetime.fromisoformat(row["updated_at"]) if isinstance(row["updated_at"], str) else row["updated_at"] ), - is_primary=bool(row["is_primary"]), - ssl_enabled=bool(row["ssl_enabled"]), - ssl_expires_at=( + is_primary = bool(row["is_primary"]), + ssl_enabled = bool(row["ssl_enabled"]), + ssl_expires_at = ( datetime.fromisoformat(row["ssl_expires_at"]) if row["ssl_expires_at"] and isinstance(row["ssl_expires_at"], str) else row["ssl_expires_at"] @@ -1560,22 +1560,22 @@ class TenantManager: def _row_to_branding(self, row: sqlite3.Row) -> TenantBranding: """数据库行转换为 TenantBranding 对象""" return TenantBranding( - id=row["id"], - tenant_id=row["tenant_id"], - logo_url=row["logo_url"], - favicon_url=row["favicon_url"], - primary_color=row["primary_color"], - secondary_color=row["secondary_color"], - custom_css=row["custom_css"], - custom_js=row["custom_js"], - login_page_bg=row["login_page_bg"], - email_template=row["email_template"], - created_at=( + id = row["id"], + tenant_id = row["tenant_id"], + logo_url = row["logo_url"], + favicon_url = row["favicon_url"], + primary_color = row["primary_color"], + secondary_color = row["secondary_color"], + custom_css = row["custom_css"], + custom_js = row["custom_js"], + login_page_bg = row["login_page_bg"], + email_template = row["email_template"], + created_at = ( datetime.fromisoformat(row["created_at"]) if isinstance(row["created_at"], str) else row["created_at"] ), - updated_at=( + updated_at = ( datetime.fromisoformat(row["updated_at"]) if isinstance(row["updated_at"], str) else row["updated_at"] @@ -1585,29 +1585,29 @@ class TenantManager: def _row_to_member(self, row: sqlite3.Row) -> TenantMember: """数据库行转换为 TenantMember 对象""" return TenantMember( - id=row["id"], - tenant_id=row["tenant_id"], - user_id=row["user_id"], - email=row["email"], - role=row["role"], - permissions=json.loads(row["permissions"] or "[]"), - invited_by=row["invited_by"], - invited_at=( + id = row["id"], + tenant_id = row["tenant_id"], + user_id = row["user_id"], + email = row["email"], + role = row["role"], + permissions = json.loads(row["permissions"] or "[]"), + invited_by = row["invited_by"], + invited_at = ( datetime.fromisoformat(row["invited_at"]) if isinstance(row["invited_at"], str) else row["invited_at"] ), - joined_at=( + joined_at = ( datetime.fromisoformat(row["joined_at"]) if row["joined_at"] and isinstance(row["joined_at"], str) else row["joined_at"] ), - last_active_at=( + last_active_at = ( datetime.fromisoformat(row["last_active_at"]) if row["last_active_at"] and isinstance(row["last_active_at"], str) else row["last_active_at"] ), - status=row["status"], + status = row["status"], ) @@ -1617,13 +1617,13 @@ class TenantManager: class TenantContext: """租户上下文管理器 - 用于请求级别的租户隔离""" - _current_tenant_id: str | None = None - _current_user_id: str | None = None + _current_tenant_id: str | None = None + _current_user_id: str | None = None @classmethod def set_current_tenant(cls, tenant_id: str) -> None: """设置当前租户上下文""" - cls._current_tenant_id = tenant_id + cls._current_tenant_id = tenant_id @classmethod def get_current_tenant(cls) -> str | None: @@ -1633,7 +1633,7 @@ class TenantContext: @classmethod def set_current_user(cls, user_id: str) -> None: """设置当前用户""" - cls._current_user_id = user_id + cls._current_user_id = user_id @classmethod def get_current_user(cls) -> str | None: @@ -1643,17 +1643,17 @@ class TenantContext: @classmethod def clear(cls) -> None: """清除上下文""" - cls._current_tenant_id = None - cls._current_user_id = None + cls._current_tenant_id = None + cls._current_user_id = None # 全局租户管理器实例 -tenant_manager = None +tenant_manager = None -def get_tenant_manager(db_path: str = "insightflow.db") -> TenantManager: +def get_tenant_manager(db_path: str = "insightflow.db") -> TenantManager: """获取租户管理器实例(单例模式)""" global tenant_manager if tenant_manager is None: - tenant_manager = TenantManager(db_path) + tenant_manager = TenantManager(db_path) return tenant_manager diff --git a/backend/test_multimodal.py b/backend/test_multimodal.py index eeb7e8f..b7c26da 100644 --- a/backend/test_multimodal.py +++ b/backend/test_multimodal.py @@ -10,9 +10,9 @@ import sys # 添加 backend 目录到路径 sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) -print("=" * 60) +print(" = " * 60) print("InsightFlow 多模态模块测试") -print("=" * 60) +print(" = " * 60) # 测试导入 print("\n1. 测试模块导入...") @@ -42,7 +42,7 @@ except ImportError as e: print("\n2. 测试模块初始化...") try: - processor = get_multimodal_processor() + processor = get_multimodal_processor() print(" ✓ MultimodalProcessor 初始化成功") print(f" - 临时目录: {processor.temp_dir}") print(f" - 帧提取间隔: {processor.frame_interval}秒") @@ -50,14 +50,14 @@ except Exception as e: print(f" ✗ MultimodalProcessor 初始化失败: {e}") try: - img_processor = get_image_processor() + img_processor = get_image_processor() print(" ✓ ImageProcessor 初始化成功") print(f" - 临时目录: {img_processor.temp_dir}") except Exception as e: print(f" ✗ ImageProcessor 初始化失败: {e}") try: - linker = get_multimodal_entity_linker() + linker = get_multimodal_entity_linker() print(" ✓ MultimodalEntityLinker 初始化成功") print(f" - 相似度阈值: {linker.similarity_threshold}") except Exception as e: @@ -67,20 +67,20 @@ except Exception as e: print("\n3. 测试实体关联功能...") try: - linker = get_multimodal_entity_linker() + linker = get_multimodal_entity_linker() # 测试字符串相似度 - sim = linker.calculate_string_similarity("Project Alpha", "Project Alpha") + sim = linker.calculate_string_similarity("Project Alpha", "Project Alpha") assert sim == 1.0, "完全匹配应该返回1.0" print(f" ✓ 字符串相似度计算正常 (完全匹配: {sim})") - sim = linker.calculate_string_similarity("K8s", "Kubernetes") + sim = linker.calculate_string_similarity("K8s", "Kubernetes") print(f" ✓ 字符串相似度计算正常 (不同字符串: {sim:.2f})") # 测试实体相似度 - entity1 = {"name": "Project Alpha", "type": "PROJECT", "definition": "核心项目"} - entity2 = {"name": "Project Alpha", "type": "PROJECT", "definition": "主要项目"} - sim, match_type = linker.calculate_entity_similarity(entity1, entity2) + entity1 = {"name": "Project Alpha", "type": "PROJECT", "definition": "核心项目"} + entity2 = {"name": "Project Alpha", "type": "PROJECT", "definition": "主要项目"} + sim, match_type = linker.calculate_entity_similarity(entity1, entity2) print(f" ✓ 实体相似度计算正常 (相似度: {sim:.2f}, 类型: {match_type})") except Exception as e: @@ -90,7 +90,7 @@ except Exception as e: print("\n4. 测试图片处理器功能...") try: - processor = get_image_processor() + processor = get_image_processor() # 测试图片类型检测(使用模拟数据) print(f" ✓ 支持的图片类型: {list(processor.IMAGE_TYPES.keys())}") @@ -103,7 +103,7 @@ except Exception as e: print("\n5. 测试视频处理器配置...") try: - processor = get_multimodal_processor() + processor = get_multimodal_processor() print(f" ✓ 视频目录: {processor.video_dir}") print(f" ✓ 帧目录: {processor.frames_dir}") @@ -129,11 +129,11 @@ print("\n6. 测试数据库多模态方法...") try: from db_manager import get_db_manager - db = get_db_manager() + db = get_db_manager() # 检查多模态表是否存在 - conn = db.get_conn() - tables = ["videos", "video_frames", "images", "multimodal_mentions", "multimodal_entity_links"] + conn = db.get_conn() + tables = ["videos", "video_frames", "images", "multimodal_mentions", "multimodal_entity_links"] for table in tables: try: @@ -147,6 +147,6 @@ try: except Exception as e: print(f" ✗ 数据库多模态方法测试失败: {e}") -print("\n" + "=" * 60) +print("\n" + " = " * 60) print("测试完成") -print("=" * 60) +print(" = " * 60) diff --git a/backend/test_phase7_task6_8.py b/backend/test_phase7_task6_8.py index 6cd872f..40beaae 100644 --- a/backend/test_phase7_task6_8.py +++ b/backend/test_phase7_task6_8.py @@ -21,27 +21,27 @@ from search_manager import ( sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) -def test_fulltext_search(): +def test_fulltext_search() -> None: """测试全文搜索""" - print("\n" + "=" * 60) + print("\n" + " = " * 60) print("测试全文搜索 (FullTextSearch)") - print("=" * 60) + print(" = " * 60) - search = FullTextSearch() + search = FullTextSearch() # 测试索引创建 print("\n1. 测试索引创建...") - success = search.index_content( - content_id="test_entity_1", - content_type="entity", - project_id="test_project", - text="这是一个测试实体,用于验证全文搜索功能。支持关键词高亮显示。", + success = search.index_content( + content_id = "test_entity_1", + content_type = "entity", + project_id = "test_project", + text = "这是一个测试实体,用于验证全文搜索功能。支持关键词高亮显示。", ) print(f" 索引创建: {'✓ 成功' if success else '✗ 失败'}") # 测试搜索 print("\n2. 测试关键词搜索...") - results = search.search("测试", project_id="test_project") + results = search.search("测试", project_id = "test_project") print(f" 搜索结果数量: {len(results)}") if results: print(f" 第一个结果: {results[0].content[:50]}...") @@ -49,28 +49,28 @@ def test_fulltext_search(): # 测试布尔搜索 print("\n3. 测试布尔搜索...") - results = search.search("测试 AND 全文", project_id="test_project") + results = search.search("测试 AND 全文", project_id = "test_project") print(f" AND 搜索结果: {len(results)}") - results = search.search("测试 OR 关键词", project_id="test_project") + results = search.search("测试 OR 关键词", project_id = "test_project") print(f" OR 搜索结果: {len(results)}") # 测试高亮 print("\n4. 测试文本高亮...") - highlighted = search.highlight_text("这是一个测试实体,用于验证全文搜索功能。", "测试 全文") + highlighted = search.highlight_text("这是一个测试实体,用于验证全文搜索功能。", "测试 全文") print(f" 高亮结果: {highlighted}") print("\n✓ 全文搜索测试完成") return True -def test_semantic_search(): +def test_semantic_search() -> None: """测试语义搜索""" - print("\n" + "=" * 60) + print("\n" + " = " * 60) print("测试语义搜索 (SemanticSearch)") - print("=" * 60) + print(" = " * 60) - semantic = SemanticSearch() + semantic = SemanticSearch() # 检查可用性 print(f"\n1. 语义搜索可用性: {'✓ 可用' if semantic.is_available() else '✗ 不可用'}") @@ -81,18 +81,18 @@ def test_semantic_search(): # 测试 embedding 生成 print("\n2. 测试 embedding 生成...") - embedding = semantic.generate_embedding("这是一个测试句子") + embedding = semantic.generate_embedding("这是一个测试句子") if embedding: print(f" Embedding 维度: {len(embedding)}") print(f" 前5个值: {embedding[:5]}") # 测试索引 print("\n3. 测试语义索引...") - success = semantic.index_embedding( - content_id="test_content_1", - content_type="transcript", - project_id="test_project", - text="这是用于语义搜索测试的文本内容。", + success = semantic.index_embedding( + content_id = "test_content_1", + content_type = "transcript", + project_id = "test_project", + text = "这是用于语义搜索测试的文本内容。", ) print(f" 索引创建: {'✓ 成功' if success else '✗ 失败'}") @@ -100,13 +100,13 @@ def test_semantic_search(): return True -def test_entity_path_discovery(): +def test_entity_path_discovery() -> None: """测试实体路径发现""" - print("\n" + "=" * 60) + print("\n" + " = " * 60) print("测试实体路径发现 (EntityPathDiscovery)") - print("=" * 60) + print(" = " * 60) - discovery = EntityPathDiscovery() + discovery = EntityPathDiscovery() print("\n1. 测试路径发现初始化...") print(f" 数据库路径: {discovery.db_path}") @@ -119,13 +119,13 @@ def test_entity_path_discovery(): return True -def test_knowledge_gap_detection(): +def test_knowledge_gap_detection() -> None: """测试知识缺口识别""" - print("\n" + "=" * 60) + print("\n" + " = " * 60) print("测试知识缺口识别 (KnowledgeGapDetection)") - print("=" * 60) + print(" = " * 60) - detection = KnowledgeGapDetection() + detection = KnowledgeGapDetection() print("\n1. 测试缺口检测初始化...") print(f" 数据库路径: {detection.db_path}") @@ -138,32 +138,32 @@ def test_knowledge_gap_detection(): return True -def test_cache_manager(): +def test_cache_manager() -> None: """测试缓存管理器""" - print("\n" + "=" * 60) + print("\n" + " = " * 60) print("测试缓存管理器 (CacheManager)") - print("=" * 60) + print(" = " * 60) - cache = CacheManager() + cache = CacheManager() print(f"\n1. 缓存后端: {'Redis' if cache.use_redis else '内存 LRU'}") print("\n2. 测试缓存操作...") # 设置缓存 - cache.set("test_key_1", {"name": "测试数据", "value": 123}, ttl=60) + cache.set("test_key_1", {"name": "测试数据", "value": 123}, ttl = 60) print(" ✓ 设置缓存 test_key_1") # 获取缓存 - _ = cache.get("test_key_1") + _ = cache.get("test_key_1") print(" ✓ 获取缓存: {value}") # 批量操作 cache.set_many( - {"batch_key_1": "value1", "batch_key_2": "value2", "batch_key_3": "value3"}, ttl=60 + {"batch_key_1": "value1", "batch_key_2": "value2", "batch_key_3": "value3"}, ttl = 60 ) print(" ✓ 批量设置缓存") - _ = cache.get_many(["batch_key_1", "batch_key_2", "batch_key_3"]) + _ = cache.get_many(["batch_key_1", "batch_key_2", "batch_key_3"]) print(" ✓ 批量获取缓存: {len(values)} 个") # 删除缓存 @@ -171,7 +171,7 @@ def test_cache_manager(): print(" ✓ 删除缓存 test_key_1") # 获取统计 - stats = cache.get_stats() + stats = cache.get_stats() print("\n3. 缓存统计:") print(f" 总请求数: {stats['total_requests']}") print(f" 命中数: {stats['hits']}") @@ -186,13 +186,13 @@ def test_cache_manager(): return True -def test_task_queue(): +def test_task_queue() -> None: """测试任务队列""" - print("\n" + "=" * 60) + print("\n" + " = " * 60) print("测试任务队列 (TaskQueue)") - print("=" * 60) + print(" = " * 60) - queue = TaskQueue() + queue = TaskQueue() print(f"\n1. 任务队列可用性: {'✓ 可用' if queue.is_available() else '✗ 不可用'}") print(f" 后端: {'Celery' if queue.use_celery else '内存'}") @@ -200,25 +200,25 @@ def test_task_queue(): print("\n2. 测试任务提交...") # 定义测试任务处理器 - def test_task_handler(payload): + def test_task_handler(payload) -> None: print(f" 执行任务: {payload}") return {"status": "success", "processed": True} queue.register_handler("test_task", test_task_handler) # 提交任务 - task_id = queue.submit( - task_type="test_task", payload={"test": "data", "timestamp": time.time()} + task_id = queue.submit( + task_type = "test_task", payload = {"test": "data", "timestamp": time.time()} ) print(" ✓ 提交任务: {task_id}") # 获取任务状态 - task_info = queue.get_status(task_id) + task_info = queue.get_status(task_id) if task_info: print(" ✓ 任务状态: {task_info.status}") # 获取统计 - stats = queue.get_stats() + stats = queue.get_stats() print("\n3. 任务队列统计:") print(f" 后端: {stats['backend']}") print(f" 按状态统计: {stats.get('by_status', {})}") @@ -227,38 +227,38 @@ def test_task_queue(): return True -def test_performance_monitor(): +def test_performance_monitor() -> None: """测试性能监控""" - print("\n" + "=" * 60) + print("\n" + " = " * 60) print("测试性能监控 (PerformanceMonitor)") - print("=" * 60) + print(" = " * 60) - monitor = PerformanceMonitor() + monitor = PerformanceMonitor() print("\n1. 测试指标记录...") # 记录一些测试指标 for i in range(5): monitor.record_metric( - metric_type="api_response", - duration_ms=50 + i * 10, - endpoint="/api/v1/test", - metadata={"test": True}, + metric_type = "api_response", + duration_ms = 50 + i * 10, + endpoint = "/api/v1/test", + metadata = {"test": True}, ) for i in range(3): monitor.record_metric( - metric_type="db_query", - duration_ms=20 + i * 5, - endpoint="SELECT test", - metadata={"test": True}, + metric_type = "db_query", + duration_ms = 20 + i * 5, + endpoint = "SELECT test", + metadata = {"test": True}, ) print(" ✓ 记录了 8 个测试指标") # 获取统计 print("\n2. 获取性能统计...") - stats = monitor.get_stats(hours=1) + stats = monitor.get_stats(hours = 1) print(f" 总请求数: {stats['overall']['total_requests']}") print(f" 平均响应时间: {stats['overall']['avg_duration_ms']} ms") print(f" 最大响应时间: {stats['overall']['max_duration_ms']} ms") @@ -274,19 +274,19 @@ def test_performance_monitor(): return True -def test_search_manager(): +def test_search_manager() -> None: """测试搜索管理器""" - print("\n" + "=" * 60) + print("\n" + " = " * 60) print("测试搜索管理器 (SearchManager)") - print("=" * 60) + print(" = " * 60) - manager = get_search_manager() + manager = get_search_manager() print("\n1. 搜索管理器初始化...") print(" ✓ 搜索管理器已初始化") print("\n2. 获取搜索统计...") - stats = manager.get_search_stats() + stats = manager.get_search_stats() print(f" 全文索引数: {stats['fulltext_indexed']}") print(f" 语义索引数: {stats['semantic_indexed']}") print(f" 语义搜索可用: {stats['semantic_search_available']}") @@ -295,24 +295,24 @@ def test_search_manager(): return True -def test_performance_manager(): +def test_performance_manager() -> None: """测试性能管理器""" - print("\n" + "=" * 60) + print("\n" + " = " * 60) print("测试性能管理器 (PerformanceManager)") - print("=" * 60) + print(" = " * 60) - manager = get_performance_manager() + manager = get_performance_manager() print("\n1. 性能管理器初始化...") print(" ✓ 性能管理器已初始化") print("\n2. 获取系统健康状态...") - health = manager.get_health_status() + health = manager.get_health_status() print(f" 缓存后端: {health['cache']['backend']}") print(f" 任务队列后端: {health['task_queue']['backend']}") print("\n3. 获取完整统计...") - stats = manager.get_full_stats() + stats = manager.get_full_stats() print(f" 缓存统计: {stats['cache']['total_requests']} 请求") print(f" 任务队列统计: {stats['task_queue']}") @@ -320,14 +320,14 @@ def test_performance_manager(): return True -def run_all_tests(): +def run_all_tests() -> None: """运行所有测试""" - print("\n" + "=" * 60) + print("\n" + " = " * 60) print("InsightFlow Phase 7 Task 6 & 8 测试") print("高级搜索与发现 + 性能优化与扩展") - print("=" * 60) + print(" = " * 60) - results = [] + results = [] # 搜索模块测试 try: @@ -386,15 +386,15 @@ def run_all_tests(): results.append(("性能管理器", False)) # 打印测试汇总 - print("\n" + "=" * 60) + print("\n" + " = " * 60) print("测试汇总") - print("=" * 60) + print(" = " * 60) - passed = sum(1 for _, result in results if result) - total = len(results) + passed = sum(1 for _, result in results if result) + total = len(results) for name, result in results: - status = "✓ 通过" if result else "✗ 失败" + status = "✓ 通过" if result else "✗ 失败" print(f" {status} - {name}") print(f"\n总计: {passed}/{total} 测试通过") @@ -408,5 +408,5 @@ def run_all_tests(): if __name__ == "__main__": - success = run_all_tests() + success = run_all_tests() sys.exit(0 if success else 1) diff --git a/backend/test_phase8_task1.py b/backend/test_phase8_task1.py index b014b62..c7ddcd5 100644 --- a/backend/test_phase8_task1.py +++ b/backend/test_phase8_task1.py @@ -18,18 +18,18 @@ from tenant_manager import get_tenant_manager sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) -def test_tenant_management(): +def test_tenant_management() -> None: """测试租户管理功能""" - print("=" * 60) + print(" = " * 60) print("测试 1: 租户管理") - print("=" * 60) + print(" = " * 60) - manager = get_tenant_manager() + manager = get_tenant_manager() # 1. 创建租户 print("\n1.1 创建租户...") - tenant = manager.create_tenant( - name="Test Company", owner_id="user_001", tier="pro", description="A test company tenant" + tenant = manager.create_tenant( + name = "Test Company", owner_id = "user_001", tier = "pro", description = "A test company tenant" ) print(f"✅ 租户创建成功: {tenant.id}") print(f" - 名称: {tenant.name}") @@ -40,43 +40,43 @@ def test_tenant_management(): # 2. 获取租户 print("\n1.2 获取租户信息...") - fetched = manager.get_tenant(tenant.id) + fetched = manager.get_tenant(tenant.id) assert fetched is not None, "获取租户失败" print(f"✅ 获取租户成功: {fetched.name}") # 3. 通过 slug 获取 print("\n1.3 通过 slug 获取租户...") - by_slug = manager.get_tenant_by_slug(tenant.slug) + by_slug = manager.get_tenant_by_slug(tenant.slug) assert by_slug is not None, "通过 slug 获取失败" print(f"✅ 通过 slug 获取成功: {by_slug.name}") # 4. 更新租户 print("\n1.4 更新租户信息...") - updated = manager.update_tenant( - tenant_id=tenant.id, name="Test Company Updated", tier="enterprise" + updated = manager.update_tenant( + tenant_id = tenant.id, name = "Test Company Updated", tier = "enterprise" ) assert updated is not None, "更新租户失败" print(f"✅ 租户更新成功: {updated.name}, 层级: {updated.tier}") # 5. 列出租户 print("\n1.5 列出租户...") - tenants = manager.list_tenants(limit=10) + tenants = manager.list_tenants(limit = 10) print(f"✅ 找到 {len(tenants)} 个租户") return tenant.id -def test_domain_management(tenant_id: str): +def test_domain_management(tenant_id: str) -> None: """测试域名管理功能""" - print("\n" + "=" * 60) + print("\n" + " = " * 60) print("测试 2: 域名管理") - print("=" * 60) + print(" = " * 60) - manager = get_tenant_manager() + manager = get_tenant_manager() # 1. 添加域名 print("\n2.1 添加自定义域名...") - domain = manager.add_domain(tenant_id=tenant_id, domain="test.example.com", is_primary=True) + domain = manager.add_domain(tenant_id = tenant_id, domain = "test.example.com", is_primary = True) print(f"✅ 域名添加成功: {domain.domain}") print(f" - ID: {domain.id}") print(f" - 状态: {domain.status}") @@ -84,19 +84,19 @@ def test_domain_management(tenant_id: str): # 2. 获取验证指导 print("\n2.2 获取域名验证指导...") - instructions = manager.get_domain_verification_instructions(domain.id) + instructions = manager.get_domain_verification_instructions(domain.id) print("✅ 验证指导:") print(f" - DNS 记录: {instructions['dns_record']}") print(f" - 文件验证: {instructions['file_verification']}") # 3. 验证域名 print("\n2.3 验证域名...") - verified = manager.verify_domain(tenant_id, domain.id) + verified = manager.verify_domain(tenant_id, domain.id) print(f"✅ 域名验证结果: {verified}") # 4. 通过域名获取租户 print("\n2.4 通过域名获取租户...") - by_domain = manager.get_tenant_by_domain("test.example.com") + by_domain = manager.get_tenant_by_domain("test.example.com") if by_domain: print(f"✅ 通过域名获取租户成功: {by_domain.name}") else: @@ -104,7 +104,7 @@ def test_domain_management(tenant_id: str): # 5. 列出域名 print("\n2.5 列出所有域名...") - domains = manager.list_domains(tenant_id) + domains = manager.list_domains(tenant_id) print(f"✅ 找到 {len(domains)} 个域名") for d in domains: print(f" - {d.domain} ({d.status})") @@ -112,25 +112,25 @@ def test_domain_management(tenant_id: str): return domain.id -def test_branding_management(tenant_id: str): +def test_branding_management(tenant_id: str) -> None: """测试品牌白标功能""" - print("\n" + "=" * 60) + print("\n" + " = " * 60) print("测试 3: 品牌白标") - print("=" * 60) + print(" = " * 60) - manager = get_tenant_manager() + manager = get_tenant_manager() # 1. 更新品牌配置 print("\n3.1 更新品牌配置...") - branding = manager.update_branding( - tenant_id=tenant_id, - logo_url="https://example.com/logo.png", - favicon_url="https://example.com/favicon.ico", - primary_color="#1890ff", - secondary_color="#52c41a", - custom_css=".header { background: #1890ff; }", - custom_js="console.log('Custom JS loaded');", - login_page_bg="https://example.com/bg.jpg", + branding = manager.update_branding( + tenant_id = tenant_id, + logo_url = "https://example.com/logo.png", + favicon_url = "https://example.com/favicon.ico", + primary_color = "#1890ff", + secondary_color = "#52c41a", + custom_css = ".header { background: #1890ff; }", + custom_js = "console.log('Custom JS loaded');", + login_page_bg = "https://example.com/bg.jpg", ) print("✅ 品牌配置更新成功") print(f" - Logo: {branding.logo_url}") @@ -139,67 +139,67 @@ def test_branding_management(tenant_id: str): # 2. 获取品牌配置 print("\n3.2 获取品牌配置...") - fetched = manager.get_branding(tenant_id) + fetched = manager.get_branding(tenant_id) assert fetched is not None, "获取品牌配置失败" print("✅ 获取品牌配置成功") # 3. 生成品牌 CSS print("\n3.3 生成品牌 CSS...") - css = manager.get_branding_css(tenant_id) + css = manager.get_branding_css(tenant_id) print(f"✅ 生成 CSS 成功 ({len(css)} 字符)") print(f" CSS 预览:\n{css[:200]}...") return branding.id -def test_member_management(tenant_id: str): +def test_member_management(tenant_id: str) -> None: """测试成员管理功能""" - print("\n" + "=" * 60) + print("\n" + " = " * 60) print("测试 4: 成员管理") - print("=" * 60) + print(" = " * 60) - manager = get_tenant_manager() + manager = get_tenant_manager() # 1. 邀请成员 print("\n4.1 邀请成员...") - member1 = manager.invite_member( - tenant_id=tenant_id, email="admin@test.com", role="admin", invited_by="user_001" + member1 = manager.invite_member( + tenant_id = tenant_id, email = "admin@test.com", role = "admin", invited_by = "user_001" ) print(f"✅ 成员邀请成功: {member1.email}") print(f" - ID: {member1.id}") print(f" - 角色: {member1.role}") print(f" - 权限: {member1.permissions}") - member2 = manager.invite_member( - tenant_id=tenant_id, email="member@test.com", role="member", invited_by="user_001" + member2 = manager.invite_member( + tenant_id = tenant_id, email = "member@test.com", role = "member", invited_by = "user_001" ) print(f"✅ 成员邀请成功: {member2.email}") # 2. 接受邀请 print("\n4.2 接受邀请...") - accepted = manager.accept_invitation(member1.id, "user_002") + accepted = manager.accept_invitation(member1.id, "user_002") print(f"✅ 邀请接受结果: {accepted}") # 3. 列出成员 print("\n4.3 列出所有成员...") - members = manager.list_members(tenant_id) + members = manager.list_members(tenant_id) print(f"✅ 找到 {len(members)} 个成员") for m in members: print(f" - {m.email} ({m.role}) - {m.status}") # 4. 检查权限 print("\n4.4 检查权限...") - can_manage = manager.check_permission(tenant_id, "user_002", "project", "create") + can_manage = manager.check_permission(tenant_id, "user_002", "project", "create") print(f"✅ user_002 可以创建项目: {can_manage}") # 5. 更新成员角色 print("\n4.5 更新成员角色...") - updated = manager.update_member_role(tenant_id, member2.id, "viewer") + updated = manager.update_member_role(tenant_id, member2.id, "viewer") print(f"✅ 角色更新结果: {updated}") # 6. 获取用户所属租户 print("\n4.6 获取用户所属租户...") - user_tenants = manager.get_user_tenants("user_002") + user_tenants = manager.get_user_tenants("user_002") print(f"✅ user_002 属于 {len(user_tenants)} 个租户") for t in user_tenants: print(f" - {t['name']} ({t['member_role']})") @@ -207,30 +207,30 @@ def test_member_management(tenant_id: str): return member1.id, member2.id -def test_usage_tracking(tenant_id: str): +def test_usage_tracking(tenant_id: str) -> None: """测试资源使用统计功能""" - print("\n" + "=" * 60) + print("\n" + " = " * 60) print("测试 5: 资源使用统计") - print("=" * 60) + print(" = " * 60) - manager = get_tenant_manager() + manager = get_tenant_manager() # 1. 记录使用 print("\n5.1 记录资源使用...") manager.record_usage( - tenant_id=tenant_id, - storage_bytes=1024 * 1024 * 50, # 50MB - transcription_seconds=600, # 10分钟 - api_calls=100, - projects_count=5, - entities_count=50, - members_count=3, + tenant_id = tenant_id, + storage_bytes = 1024 * 1024 * 50, # 50MB + transcription_seconds = 600, # 10分钟 + api_calls = 100, + projects_count = 5, + entities_count = 50, + members_count = 3, ) print("✅ 资源使用记录成功") # 2. 获取使用统计 print("\n5.2 获取使用统计...") - stats = manager.get_usage_stats(tenant_id) + stats = manager.get_usage_stats(tenant_id) print("✅ 使用统计:") print(f" - 存储: {stats['storage_mb']:.2f} MB") print(f" - 转录: {stats['transcription_minutes']:.2f} 分钟") @@ -243,19 +243,19 @@ def test_usage_tracking(tenant_id: str): # 3. 检查资源限制 print("\n5.3 检查资源限制...") for resource in ["storage", "transcription", "api_calls", "projects", "entities", "members"]: - allowed, current, limit = manager.check_resource_limit(tenant_id, resource) + allowed, current, limit = manager.check_resource_limit(tenant_id, resource) print(f" - {resource}: {current}/{limit} ({'✅' if allowed else '❌'})") return stats -def cleanup(tenant_id: str, domain_id: str, member_ids: list): +def cleanup(tenant_id: str, domain_id: str, member_ids: list) -> None: """清理测试数据""" - print("\n" + "=" * 60) + print("\n" + " = " * 60) print("清理测试数据") - print("=" * 60) + print(" = " * 60) - manager = get_tenant_manager() + manager = get_tenant_manager() # 移除成员 for member_id in member_ids: @@ -273,28 +273,28 @@ def cleanup(tenant_id: str, domain_id: str, member_ids: list): print(f"✅ 租户已删除: {tenant_id}") -def main(): +def main() -> None: """主测试函数""" - print("\n" + "=" * 60) + print("\n" + " = " * 60) print("InsightFlow Phase 8 Task 1 - 多租户 SaaS 架构测试") - print("=" * 60) + print(" = " * 60) - tenant_id = None - domain_id = None - member_ids = [] + tenant_id = None + domain_id = None + member_ids = [] try: # 运行所有测试 - tenant_id = test_tenant_management() - domain_id = test_domain_management(tenant_id) + tenant_id = test_tenant_management() + domain_id = test_domain_management(tenant_id) test_branding_management(tenant_id) - m1, m2 = test_member_management(tenant_id) - member_ids = [m1, m2] + m1, m2 = test_member_management(tenant_id) + member_ids = [m1, m2] test_usage_tracking(tenant_id) - print("\n" + "=" * 60) + print("\n" + " = " * 60) print("✅ 所有测试通过!") - print("=" * 60) + print(" = " * 60) except Exception as e: print(f"\n❌ 测试失败: {e}") diff --git a/backend/test_phase8_task2.py b/backend/test_phase8_task2.py index f6f749e..bc2589c 100644 --- a/backend/test_phase8_task2.py +++ b/backend/test_phase8_task2.py @@ -12,31 +12,31 @@ from subscription_manager import PaymentProvider, SubscriptionManager sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) -def test_subscription_manager(): +def test_subscription_manager() -> None: """测试订阅管理器""" - print("=" * 60) + print(" = " * 60) print("InsightFlow Phase 8 Task 2 - 订阅与计费系统测试") - print("=" * 60) + print(" = " * 60) # 使用临时文件数据库进行测试 - db_path = tempfile.mktemp(suffix=".db") + db_path = tempfile.mktemp(suffix = ".db") try: - manager = SubscriptionManager(db_path=db_path) + manager = SubscriptionManager(db_path = db_path) print("\n1. 测试订阅计划管理") print("-" * 40) # 获取默认计划 - plans = manager.list_plans() + plans = manager.list_plans() print(f"✓ 默认计划数量: {len(plans)}") for plan in plans: print(f" - {plan.name} ({plan.tier}): ¥{plan.price_monthly}/月") # 通过 tier 获取计划 - free_plan = manager.get_plan_by_tier("free") - pro_plan = manager.get_plan_by_tier("pro") - enterprise_plan = manager.get_plan_by_tier("enterprise") + free_plan = manager.get_plan_by_tier("free") + pro_plan = manager.get_plan_by_tier("pro") + enterprise_plan = manager.get_plan_by_tier("enterprise") assert free_plan is not None, "Free 计划应该存在" assert pro_plan is not None, "Pro 计划应该存在" @@ -49,14 +49,14 @@ def test_subscription_manager(): print("\n2. 测试订阅管理") print("-" * 40) - tenant_id = "test-tenant-001" + tenant_id = "test-tenant-001" # 创建订阅 - subscription = manager.create_subscription( - tenant_id=tenant_id, - plan_id=pro_plan.id, - payment_provider=PaymentProvider.STRIPE.value, - trial_days=14, + subscription = manager.create_subscription( + tenant_id = tenant_id, + plan_id = pro_plan.id, + payment_provider = PaymentProvider.STRIPE.value, + trial_days = 14, ) print(f"✓ 创建订阅: {subscription.id}") @@ -66,7 +66,7 @@ def test_subscription_manager(): print(f" - 试用结束: {subscription.trial_end}") # 获取租户订阅 - tenant_sub = manager.get_tenant_subscription(tenant_id) + tenant_sub = manager.get_tenant_subscription(tenant_id) assert tenant_sub is not None, "应该能获取到租户订阅" print(f"✓ 获取租户订阅: {tenant_sub.id}") @@ -74,27 +74,27 @@ def test_subscription_manager(): print("-" * 40) # 记录转录用量 - usage1 = manager.record_usage( - tenant_id=tenant_id, - resource_type="transcription", - quantity=120, - unit="minute", - description="会议转录", + usage1 = manager.record_usage( + tenant_id = tenant_id, + resource_type = "transcription", + quantity = 120, + unit = "minute", + description = "会议转录", ) print(f"✓ 记录转录用量: {usage1.quantity} {usage1.unit}, 费用: ¥{usage1.cost:.2f}") # 记录存储用量 - usage2 = manager.record_usage( - tenant_id=tenant_id, - resource_type="storage", - quantity=2.5, - unit="gb", - description="文件存储", + usage2 = manager.record_usage( + tenant_id = tenant_id, + resource_type = "storage", + quantity = 2.5, + unit = "gb", + description = "文件存储", ) print(f"✓ 记录存储用量: {usage2.quantity} {usage2.unit}, 费用: ¥{usage2.cost:.2f}") # 获取用量汇总 - summary = manager.get_usage_summary(tenant_id) + summary = manager.get_usage_summary(tenant_id) print("✓ 用量汇总:") print(f" - 总费用: ¥{summary['total_cost']:.2f}") for resource, data in summary["breakdown"].items(): @@ -104,12 +104,12 @@ def test_subscription_manager(): print("-" * 40) # 创建支付 - payment = manager.create_payment( - tenant_id=tenant_id, - amount=99.0, - currency="CNY", - provider=PaymentProvider.ALIPAY.value, - payment_method="qrcode", + payment = manager.create_payment( + tenant_id = tenant_id, + amount = 99.0, + currency = "CNY", + provider = PaymentProvider.ALIPAY.value, + payment_method = "qrcode", ) print(f"✓ 创建支付: {payment.id}") print(f" - 金额: ¥{payment.amount}") @@ -117,22 +117,22 @@ def test_subscription_manager(): print(f" - 状态: {payment.status}") # 确认支付 - confirmed = manager.confirm_payment(payment.id, "alipay_123456") + confirmed = manager.confirm_payment(payment.id, "alipay_123456") print(f"✓ 确认支付完成: {confirmed.status}") # 列出支付记录 - payments = manager.list_payments(tenant_id) + payments = manager.list_payments(tenant_id) print(f"✓ 支付记录数量: {len(payments)}") print("\n5. 测试发票管理") print("-" * 40) # 列出发票 - invoices = manager.list_invoices(tenant_id) + invoices = manager.list_invoices(tenant_id) print(f"✓ 发票数量: {len(invoices)}") if invoices: - invoice = invoices[0] + invoice = invoices[0] print(f" - 发票号: {invoice.invoice_number}") print(f" - 金额: ¥{invoice.amount_due}") print(f" - 状态: {invoice.status}") @@ -141,12 +141,12 @@ def test_subscription_manager(): print("-" * 40) # 申请退款 - refund = manager.request_refund( - tenant_id=tenant_id, - payment_id=payment.id, - amount=50.0, - reason="服务不满意", - requested_by="user_001", + refund = manager.request_refund( + tenant_id = tenant_id, + payment_id = payment.id, + amount = 50.0, + reason = "服务不满意", + requested_by = "user_001", ) print(f"✓ 申请退款: {refund.id}") print(f" - 金额: ¥{refund.amount}") @@ -154,21 +154,21 @@ def test_subscription_manager(): print(f" - 状态: {refund.status}") # 批准退款 - approved = manager.approve_refund(refund.id, "admin_001") + approved = manager.approve_refund(refund.id, "admin_001") print(f"✓ 批准退款: {approved.status}") # 完成退款 - completed = manager.complete_refund(refund.id, "refund_123456") + completed = manager.complete_refund(refund.id, "refund_123456") print(f"✓ 完成退款: {completed.status}") # 列出退款记录 - refunds = manager.list_refunds(tenant_id) + refunds = manager.list_refunds(tenant_id) print(f"✓ 退款记录数量: {len(refunds)}") print("\n7. 测试账单历史") print("-" * 40) - history = manager.get_billing_history(tenant_id) + history = manager.get_billing_history(tenant_id) print(f"✓ 账单历史记录数量: {len(history)}") for h in history: print(f" - [{h.type}] {h.description}: ¥{h.amount}") @@ -177,24 +177,24 @@ def test_subscription_manager(): print("-" * 40) # Stripe Checkout - stripe_session = manager.create_stripe_checkout_session( - tenant_id=tenant_id, - plan_id=enterprise_plan.id, - success_url="https://example.com/success", - cancel_url="https://example.com/cancel", + stripe_session = manager.create_stripe_checkout_session( + tenant_id = tenant_id, + plan_id = enterprise_plan.id, + success_url = "https://example.com/success", + cancel_url = "https://example.com/cancel", ) print(f"✓ Stripe Checkout 会话: {stripe_session['session_id']}") # 支付宝订单 - alipay_order = manager.create_alipay_order(tenant_id=tenant_id, plan_id=pro_plan.id) + alipay_order = manager.create_alipay_order(tenant_id = tenant_id, plan_id = pro_plan.id) print(f"✓ 支付宝订单: {alipay_order['order_id']}") # 微信支付订单 - wechat_order = manager.create_wechat_order(tenant_id=tenant_id, plan_id=pro_plan.id) + wechat_order = manager.create_wechat_order(tenant_id = tenant_id, plan_id = pro_plan.id) print(f"✓ 微信支付订单: {wechat_order['order_id']}") # Webhook 处理 - webhook_result = manager.handle_webhook( + webhook_result = manager.handle_webhook( "stripe", {"event_type": "checkout.session.completed", "data": {"object": {"id": "cs_test"}}}, ) @@ -204,19 +204,19 @@ def test_subscription_manager(): print("-" * 40) # 更改计划 - changed = manager.change_plan( - subscription_id=subscription.id, new_plan_id=enterprise_plan.id + changed = manager.change_plan( + subscription_id = subscription.id, new_plan_id = enterprise_plan.id ) print(f"✓ 更改计划: {changed.plan_id} (Enterprise)") # 取消订阅 - cancelled = manager.cancel_subscription(subscription_id=subscription.id, at_period_end=True) + cancelled = manager.cancel_subscription(subscription_id = subscription.id, at_period_end = True) print(f"✓ 取消订阅: {cancelled.status}") print(f" - 周期结束时取消: {cancelled.cancel_at_period_end}") - print("\n" + "=" * 60) + print("\n" + " = " * 60) print("所有测试通过! ✓") - print("=" * 60) + print(" = " * 60) finally: # 清理临时数据库 diff --git a/backend/test_phase8_task4.py b/backend/test_phase8_task4.py index 6305dfc..702363d 100644 --- a/backend/test_phase8_task4.py +++ b/backend/test_phase8_task4.py @@ -14,31 +14,31 @@ from ai_manager import ModelType, PredictionType, get_ai_manager sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) -def test_custom_model(): +def test_custom_model() -> None: """测试自定义模型功能""" print("\n=== 测试自定义模型 ===") - manager = get_ai_manager() + manager = get_ai_manager() # 1. 创建自定义模型 print("1. 创建自定义模型...") - model = manager.create_custom_model( - tenant_id="tenant_001", - name="领域实体识别模型", - description="用于识别医疗领域实体的自定义模型", - model_type=ModelType.CUSTOM_NER, - training_data={ + model = manager.create_custom_model( + tenant_id = "tenant_001", + name = "领域实体识别模型", + description = "用于识别医疗领域实体的自定义模型", + model_type = ModelType.CUSTOM_NER, + training_data = { "entity_types": ["DISEASE", "SYMPTOM", "DRUG", "TREATMENT"], "domain": "medical", }, - hyperparameters={"epochs": 15, "learning_rate": 0.001, "batch_size": 32}, - created_by="user_001", + hyperparameters = {"epochs": 15, "learning_rate": 0.001, "batch_size": 32}, + created_by = "user_001", ) print(f" 创建成功: {model.id}, 状态: {model.status.value}") # 2. 添加训练样本 print("2. 添加训练样本...") - samples = [ + samples = [ { "text": "患者张三患有高血压,正在服用降压药治疗。", "entities": [ @@ -66,22 +66,22 @@ def test_custom_model(): ] for sample_data in samples: - sample = manager.add_training_sample( - model_id=model.id, - text=sample_data["text"], - entities=sample_data["entities"], - metadata={"source": "manual"}, + sample = manager.add_training_sample( + model_id = model.id, + text = sample_data["text"], + entities = sample_data["entities"], + metadata = {"source": "manual"}, ) print(f" 添加样本: {sample.id}") # 3. 获取训练样本 print("3. 获取训练样本...") - all_samples = manager.get_training_samples(model.id) + all_samples = manager.get_training_samples(model.id) print(f" 共有 {len(all_samples)} 个训练样本") # 4. 列出自定义模型 print("4. 列出自定义模型...") - models = manager.list_custom_models(tenant_id="tenant_001") + models = manager.list_custom_models(tenant_id = "tenant_001") print(f" 找到 {len(models)} 个模型") for m in models: print(f" - {m.name} ({m.model_type.value}): {m.status.value}") @@ -89,16 +89,16 @@ def test_custom_model(): return model.id -async def test_train_and_predict(model_id: str): +async def test_train_and_predict(model_id: str) -> None: """测试训练和预测""" print("\n=== 测试模型训练和预测 ===") - manager = get_ai_manager() + manager = get_ai_manager() # 1. 训练模型 print("1. 训练模型...") try: - trained_model = await manager.train_custom_model(model_id) + trained_model = await manager.train_custom_model(model_id) print(f" 训练完成: {trained_model.status.value}") print(f" 指标: {trained_model.metrics}") except Exception as e: @@ -107,50 +107,50 @@ async def test_train_and_predict(model_id: str): # 2. 使用模型预测 print("2. 使用模型预测...") - test_text = "赵六患有糖尿病,正在使用胰岛素治疗。" + test_text = "赵六患有糖尿病,正在使用胰岛素治疗。" try: - entities = await manager.predict_with_custom_model(model_id, test_text) + entities = await manager.predict_with_custom_model(model_id, test_text) print(f" 输入: {test_text}") print(f" 预测实体: {entities}") except Exception as e: print(f" 预测失败: {e}") -def test_prediction_models(): +def test_prediction_models() -> None: """测试预测模型""" print("\n=== 测试预测模型 ===") - manager = get_ai_manager() + manager = get_ai_manager() # 1. 创建趋势预测模型 print("1. 创建趋势预测模型...") - trend_model = manager.create_prediction_model( - tenant_id="tenant_001", - project_id="project_001", - name="实体数量趋势预测", - prediction_type=PredictionType.TREND, - target_entity_type="PERSON", - features=["entity_count", "time_period", "document_count"], - model_config={"algorithm": "linear_regression", "window_size": 7}, + trend_model = manager.create_prediction_model( + tenant_id = "tenant_001", + project_id = "project_001", + name = "实体数量趋势预测", + prediction_type = PredictionType.TREND, + target_entity_type = "PERSON", + features = ["entity_count", "time_period", "document_count"], + model_config = {"algorithm": "linear_regression", "window_size": 7}, ) print(f" 创建成功: {trend_model.id}") # 2. 创建异常检测模型 print("2. 创建异常检测模型...") - anomaly_model = manager.create_prediction_model( - tenant_id="tenant_001", - project_id="project_001", - name="实体增长异常检测", - prediction_type=PredictionType.ANOMALY, - target_entity_type=None, - features=["daily_growth", "weekly_growth"], - model_config={"threshold": 2.5, "sensitivity": "medium"}, + anomaly_model = manager.create_prediction_model( + tenant_id = "tenant_001", + project_id = "project_001", + name = "实体增长异常检测", + prediction_type = PredictionType.ANOMALY, + target_entity_type = None, + features = ["daily_growth", "weekly_growth"], + model_config = {"threshold": 2.5, "sensitivity": "medium"}, ) print(f" 创建成功: {anomaly_model.id}") # 3. 列出预测模型 print("3. 列出预测模型...") - models = manager.list_prediction_models(tenant_id="tenant_001") + models = manager.list_prediction_models(tenant_id = "tenant_001") print(f" 找到 {len(models)} 个预测模型") for m in models: print(f" - {m.name} ({m.prediction_type.value})") @@ -158,15 +158,15 @@ def test_prediction_models(): return trend_model.id, anomaly_model.id -async def test_predictions(trend_model_id: str, anomaly_model_id: str): +async def test_predictions(trend_model_id: str, anomaly_model_id: str) -> None: """测试预测功能""" print("\n=== 测试预测功能 ===") - manager = get_ai_manager() + manager = get_ai_manager() # 1. 训练趋势预测模型 print("1. 训练趋势预测模型...") - historical_data = [ + historical_data = [ {"date": "2024-01-01", "value": 10}, {"date": "2024-01-02", "value": 12}, {"date": "2024-01-03", "value": 15}, @@ -175,62 +175,62 @@ async def test_predictions(trend_model_id: str, anomaly_model_id: str): {"date": "2024-01-06", "value": 20}, {"date": "2024-01-07", "value": 22}, ] - trained = await manager.train_prediction_model(trend_model_id, historical_data) + trained = await manager.train_prediction_model(trend_model_id, historical_data) print(f" 训练完成,准确率: {trained.accuracy}") # 2. 趋势预测 print("2. 趋势预测...") - trend_result = await manager.predict( + trend_result = await manager.predict( trend_model_id, {"historical_values": [10, 12, 15, 14, 18, 20, 22]} ) print(f" 预测结果: {trend_result.prediction_data}") # 3. 异常检测 print("3. 异常检测...") - anomaly_result = await manager.predict( + anomaly_result = await manager.predict( anomaly_model_id, {"value": 50, "historical_values": [10, 12, 11, 13, 12, 14, 13]} ) print(f" 检测结果: {anomaly_result.prediction_data}") -def test_kg_rag(): +def test_kg_rag() -> None: """测试知识图谱 RAG""" print("\n=== 测试知识图谱 RAG ===") - manager = get_ai_manager() + manager = get_ai_manager() # 创建 RAG 配置 print("1. 创建知识图谱 RAG 配置...") - rag = manager.create_kg_rag( - tenant_id="tenant_001", - project_id="project_001", - name="项目知识问答", - description="基于项目知识图谱的智能问答", - kg_config={ + rag = manager.create_kg_rag( + tenant_id = "tenant_001", + project_id = "project_001", + name = "项目知识问答", + description = "基于项目知识图谱的智能问答", + kg_config = { "entity_types": ["PERSON", "ORG", "PROJECT", "TECH"], "relation_types": ["works_with", "belongs_to", "depends_on"], }, - retrieval_config={"top_k": 5, "similarity_threshold": 0.7, "expand_relations": True}, - generation_config={"temperature": 0.3, "max_tokens": 1000, "include_sources": True}, + retrieval_config = {"top_k": 5, "similarity_threshold": 0.7, "expand_relations": True}, + generation_config = {"temperature": 0.3, "max_tokens": 1000, "include_sources": True}, ) print(f" 创建成功: {rag.id}") # 列出 RAG 配置 print("2. 列出 RAG 配置...") - rags = manager.list_kg_rags(tenant_id="tenant_001") + rags = manager.list_kg_rags(tenant_id = "tenant_001") print(f" 找到 {len(rags)} 个配置") return rag.id -async def test_kg_rag_query(rag_id: str): +async def test_kg_rag_query(rag_id: str) -> None: """测试 RAG 查询""" print("\n=== 测试知识图谱 RAG 查询 ===") - manager = get_ai_manager() + manager = get_ai_manager() # 模拟项目实体和关系 - project_entities = [ + project_entities = [ {"id": "e1", "name": "张三", "type": "PERSON", "definition": "项目经理"}, {"id": "e2", "name": "李四", "type": "PERSON", "definition": "技术负责人"}, {"id": "e3", "name": "Project Alpha", "type": "PROJECT", "definition": "核心产品项目"}, @@ -238,7 +238,7 @@ async def test_kg_rag_query(rag_id: str): {"id": "e5", "name": "TechCorp", "type": "ORG", "definition": "科技公司"}, ] - project_relations = [ + project_relations = [ { "source_entity_id": "e1", "target_entity_id": "e3", @@ -275,14 +275,14 @@ async def test_kg_rag_query(rag_id: str): # 执行查询 print("1. 执行 RAG 查询...") - query_text = "Project Alpha 项目有哪些人参与?使用了什么技术?" + query_text = "Project Alpha 项目有哪些人参与?使用了什么技术?" try: - result = await manager.query_kg_rag( - rag_id=rag_id, - query=query_text, - project_entities=project_entities, - project_relations=project_relations, + result = await manager.query_kg_rag( + rag_id = rag_id, + query = query_text, + project_entities = project_entities, + project_relations = project_relations, ) print(f" 查询: {result.query}") @@ -294,14 +294,14 @@ async def test_kg_rag_query(rag_id: str): print(f" 查询失败: {e}") -async def test_smart_summary(): +async def test_smart_summary() -> None: """测试智能摘要""" print("\n=== 测试智能摘要 ===") - manager = get_ai_manager() + manager = get_ai_manager() # 模拟转录文本 - transcript_text = """ + transcript_text = """ 今天的会议主要讨论了 Project Alpha 的进展情况。张三作为项目经理, 汇报了当前的项目进度,表示已经完成了 80% 的开发工作。李四提出了 一些关于 Kubernetes 部署的问题,建议我们采用新的部署策略。 @@ -309,7 +309,7 @@ async def test_smart_summary(): 大家一致认为项目进展顺利,预计可以按时交付。 """ - content_data = { + content_data = { "text": transcript_text, "entities": [ {"name": "张三", "type": "PERSON"}, @@ -320,18 +320,18 @@ async def test_smart_summary(): } # 生成不同类型的摘要 - summary_types = ["extractive", "abstractive", "key_points"] + summary_types = ["extractive", "abstractive", "key_points"] for summary_type in summary_types: print(f"1. 生成 {summary_type} 类型摘要...") try: - summary = await manager.generate_smart_summary( - tenant_id="tenant_001", - project_id="project_001", - source_type="transcript", - source_id="transcript_001", - summary_type=summary_type, - content_data=content_data, + summary = await manager.generate_smart_summary( + tenant_id = "tenant_001", + project_id = "project_001", + source_type = "transcript", + source_id = "transcript_001", + summary_type = summary_type, + content_data = content_data, ) print(f" 摘要类型: {summary.summary_type}") @@ -342,27 +342,27 @@ async def test_smart_summary(): print(f" 生成失败: {e}") -async def main(): +async def main() -> None: """主测试函数""" - print("=" * 60) + print(" = " * 60) print("InsightFlow Phase 8 Task 4 - AI 能力增强测试") - print("=" * 60) + print(" = " * 60) try: # 测试自定义模型 - model_id = test_custom_model() + model_id = test_custom_model() # 测试训练和预测 await test_train_and_predict(model_id) # 测试预测模型 - trend_model_id, anomaly_model_id = test_prediction_models() + trend_model_id, anomaly_model_id = test_prediction_models() # 测试预测功能 await test_predictions(trend_model_id, anomaly_model_id) # 测试知识图谱 RAG - rag_id = test_kg_rag() + rag_id = test_kg_rag() # 测试 RAG 查询 await test_kg_rag_query(rag_id) @@ -370,9 +370,9 @@ async def main(): # 测试智能摘要 await test_smart_summary() - print("\n" + "=" * 60) + print("\n" + " = " * 60) print("所有测试完成!") - print("=" * 60) + print(" = " * 60) except Exception as e: print(f"\n测试失败: {e}") diff --git a/backend/test_phase8_task5.py b/backend/test_phase8_task5.py index 793f0a6..ee10a8f 100644 --- a/backend/test_phase8_task5.py +++ b/backend/test_phase8_task5.py @@ -28,7 +28,7 @@ from growth_manager import ( ) # 添加 backend 目录到路径 -backend_dir = os.path.dirname(os.path.abspath(__file__)) +backend_dir = os.path.dirname(os.path.abspath(__file__)) if backend_dir not in sys.path: sys.path.insert(0, backend_dir) @@ -36,35 +36,35 @@ if backend_dir not in sys.path: class TestGrowthManager: """测试 Growth Manager 功能""" - def __init__(self): - self.manager = GrowthManager() - self.test_tenant_id = "test_tenant_001" - self.test_user_id = "test_user_001" - self.test_results = [] + def __init__(self) -> None: + self.manager = GrowthManager() + self.test_tenant_id = "test_tenant_001" + self.test_user_id = "test_user_001" + self.test_results = [] - def log(self, message: str, success: bool = True): + def log(self, message: str, success: bool = True) -> None: """记录测试结果""" - status = "✅" if success else "❌" + status = "✅" if success else "❌" print(f"{status} {message}") self.test_results.append((message, success)) # ==================== 测试用户行为分析 ==================== - async def test_track_event(self): + async def test_track_event(self) -> None: """测试事件追踪""" print("\n📊 测试事件追踪...") try: - event = await self.manager.track_event( - tenant_id=self.test_tenant_id, - user_id=self.test_user_id, - event_type=EventType.PAGE_VIEW, - event_name="dashboard_view", - properties={"page": "/dashboard", "duration": 120}, - session_id="session_001", - device_info={"browser": "Chrome", "os": "MacOS"}, - referrer="https://google.com", - utm_params={"source": "google", "medium": "organic", "campaign": "summer"}, + event = await self.manager.track_event( + tenant_id = self.test_tenant_id, + user_id = self.test_user_id, + event_type = EventType.PAGE_VIEW, + event_name = "dashboard_view", + properties = {"page": "/dashboard", "duration": 120}, + session_id = "session_001", + device_info = {"browser": "Chrome", "os": "MacOS"}, + referrer = "https://google.com", + utm_params = {"source": "google", "medium": "organic", "campaign": "summer"}, ) assert event.id is not None @@ -74,15 +74,15 @@ class TestGrowthManager: self.log(f"事件追踪成功: {event.id}") return True except Exception as e: - self.log(f"事件追踪失败: {e}", success=False) + self.log(f"事件追踪失败: {e}", success = False) return False - async def test_track_multiple_events(self): + async def test_track_multiple_events(self) -> None: """测试追踪多个事件""" print("\n📊 测试追踪多个事件...") try: - events = [ + events = [ (EventType.FEATURE_USE, "entity_extraction", {"entity_count": 5}), (EventType.FEATURE_USE, "relation_discovery", {"relation_count": 3}), (EventType.CONVERSION, "upgrade_click", {"plan": "pro"}), @@ -91,25 +91,25 @@ class TestGrowthManager: for event_type, event_name, props in events: await self.manager.track_event( - tenant_id=self.test_tenant_id, - user_id=self.test_user_id, - event_type=event_type, - event_name=event_name, - properties=props, + tenant_id = self.test_tenant_id, + user_id = self.test_user_id, + event_type = event_type, + event_name = event_name, + properties = props, ) self.log(f"成功追踪 {len(events)} 个事件") return True except Exception as e: - self.log(f"批量事件追踪失败: {e}", success=False) + self.log(f"批量事件追踪失败: {e}", success = False) return False - def test_get_user_profile(self): + def test_get_user_profile(self) -> None: """测试获取用户画像""" print("\n👤 测试用户画像...") try: - profile = self.manager.get_user_profile(self.test_tenant_id, self.test_user_id) + profile = self.manager.get_user_profile(self.test_tenant_id, self.test_user_id) if profile: assert profile.user_id == self.test_user_id @@ -120,18 +120,18 @@ class TestGrowthManager: return True except Exception as e: - self.log(f"获取用户画像失败: {e}", success=False) + self.log(f"获取用户画像失败: {e}", success = False) return False - def test_get_analytics_summary(self): + def test_get_analytics_summary(self) -> None: """测试获取分析汇总""" print("\n📈 测试分析汇总...") try: - summary = self.manager.get_user_analytics_summary( - tenant_id=self.test_tenant_id, - start_date=datetime.now() - timedelta(days=7), - end_date=datetime.now(), + summary = self.manager.get_user_analytics_summary( + tenant_id = self.test_tenant_id, + start_date = datetime.now() - timedelta(days = 7), + end_date = datetime.now(), ) assert "unique_users" in summary @@ -141,25 +141,25 @@ class TestGrowthManager: self.log(f"分析汇总: {summary['unique_users']} 用户, {summary['total_events']} 事件") return True except Exception as e: - self.log(f"获取分析汇总失败: {e}", success=False) + self.log(f"获取分析汇总失败: {e}", success = False) return False - def test_create_funnel(self): + def test_create_funnel(self) -> None: """测试创建转化漏斗""" print("\n🎯 测试创建转化漏斗...") try: - funnel = self.manager.create_funnel( - tenant_id=self.test_tenant_id, - name="用户注册转化漏斗", - description="从访问到完成注册的转化流程", - steps=[ + funnel = self.manager.create_funnel( + tenant_id = self.test_tenant_id, + name = "用户注册转化漏斗", + description = "从访问到完成注册的转化流程", + steps = [ {"name": "访问首页", "event_name": "page_view_home"}, {"name": "点击注册", "event_name": "signup_click"}, {"name": "填写信息", "event_name": "signup_form_fill"}, {"name": "完成注册", "event_name": "signup_complete"}, ], - created_by="test", + created_by = "test", ) assert funnel.id is not None @@ -168,10 +168,10 @@ class TestGrowthManager: self.log(f"漏斗创建成功: {funnel.id}") return funnel.id except Exception as e: - self.log(f"创建漏斗失败: {e}", success=False) + self.log(f"创建漏斗失败: {e}", success = False) return None - def test_analyze_funnel(self, funnel_id: str): + def test_analyze_funnel(self, funnel_id: str) -> None: """测试分析漏斗""" print("\n📉 测试漏斗分析...") @@ -180,10 +180,10 @@ class TestGrowthManager: return False try: - analysis = self.manager.analyze_funnel( - funnel_id=funnel_id, - period_start=datetime.now() - timedelta(days=30), - period_end=datetime.now(), + analysis = self.manager.analyze_funnel( + funnel_id = funnel_id, + period_start = datetime.now() - timedelta(days = 30), + period_end = datetime.now(), ) if analysis: @@ -194,18 +194,18 @@ class TestGrowthManager: self.log("漏斗分析返回空结果") return False except Exception as e: - self.log(f"漏斗分析失败: {e}", success=False) + self.log(f"漏斗分析失败: {e}", success = False) return False - def test_calculate_retention(self): + def test_calculate_retention(self) -> None: """测试留存率计算""" print("\n🔄 测试留存率计算...") try: - retention = self.manager.calculate_retention( - tenant_id=self.test_tenant_id, - cohort_date=datetime.now() - timedelta(days=7), - periods=[1, 3, 7], + retention = self.manager.calculate_retention( + tenant_id = self.test_tenant_id, + cohort_date = datetime.now() - timedelta(days = 7), + periods = [1, 3, 7], ) assert "cohort_date" in retention @@ -214,34 +214,34 @@ class TestGrowthManager: self.log(f"留存率计算完成: 同期群 {retention['cohort_size']} 用户") return True except Exception as e: - self.log(f"留存率计算失败: {e}", success=False) + self.log(f"留存率计算失败: {e}", success = False) return False # ==================== 测试 A/B 测试框架 ==================== - def test_create_experiment(self): + def test_create_experiment(self) -> None: """测试创建实验""" print("\n🧪 测试创建 A/B 测试实验...") try: - experiment = self.manager.create_experiment( - tenant_id=self.test_tenant_id, - name="首页按钮颜色测试", - description="测试不同按钮颜色对转化率的影响", - hypothesis="蓝色按钮比红色按钮有更高的点击率", - variants=[ + experiment = self.manager.create_experiment( + tenant_id = self.test_tenant_id, + name = "首页按钮颜色测试", + description = "测试不同按钮颜色对转化率的影响", + hypothesis = "蓝色按钮比红色按钮有更高的点击率", + variants = [ {"id": "control", "name": "红色按钮", "is_control": True}, {"id": "variant_a", "name": "蓝色按钮", "is_control": False}, {"id": "variant_b", "name": "绿色按钮", "is_control": False}, ], - traffic_allocation=TrafficAllocationType.RANDOM, - traffic_split={"control": 0.34, "variant_a": 0.33, "variant_b": 0.33}, - target_audience={"conditions": []}, - primary_metric="button_click_rate", - secondary_metrics=["conversion_rate", "bounce_rate"], - min_sample_size=100, - confidence_level=0.95, - created_by="test", + traffic_allocation = TrafficAllocationType.RANDOM, + traffic_split = {"control": 0.34, "variant_a": 0.33, "variant_b": 0.33}, + target_audience = {"conditions": []}, + primary_metric = "button_click_rate", + secondary_metrics = ["conversion_rate", "bounce_rate"], + min_sample_size = 100, + confidence_level = 0.95, + created_by = "test", ) assert experiment.id is not None @@ -250,23 +250,23 @@ class TestGrowthManager: self.log(f"实验创建成功: {experiment.id}") return experiment.id except Exception as e: - self.log(f"创建实验失败: {e}", success=False) + self.log(f"创建实验失败: {e}", success = False) return None - def test_list_experiments(self): + def test_list_experiments(self) -> None: """测试列出实验""" print("\n📋 测试列出实验...") try: - experiments = self.manager.list_experiments(self.test_tenant_id) + experiments = self.manager.list_experiments(self.test_tenant_id) self.log(f"列出 {len(experiments)} 个实验") return True except Exception as e: - self.log(f"列出实验失败: {e}", success=False) + self.log(f"列出实验失败: {e}", success = False) return False - def test_assign_variant(self, experiment_id: str): + def test_assign_variant(self, experiment_id: str) -> None: """测试分配变体""" print("\n🎲 测试分配实验变体...") @@ -279,26 +279,26 @@ class TestGrowthManager: self.manager.start_experiment(experiment_id) # 测试多个用户的变体分配 - test_users = ["user_001", "user_002", "user_003", "user_004", "user_005"] - assignments = {} + test_users = ["user_001", "user_002", "user_003", "user_004", "user_005"] + assignments = {} for user_id in test_users: - variant_id = self.manager.assign_variant( - experiment_id=experiment_id, - user_id=user_id, - user_attributes={"user_id": user_id, "segment": "new"}, + variant_id = self.manager.assign_variant( + experiment_id = experiment_id, + user_id = user_id, + user_attributes = {"user_id": user_id, "segment": "new"}, ) if variant_id: - assignments[user_id] = variant_id + assignments[user_id] = variant_id self.log(f"变体分配完成: {len(assignments)} 个用户") return True except Exception as e: - self.log(f"变体分配失败: {e}", success=False) + self.log(f"变体分配失败: {e}", success = False) return False - def test_record_experiment_metric(self, experiment_id: str): + def test_record_experiment_metric(self, experiment_id: str) -> None: """测试记录实验指标""" print("\n📊 测试记录实验指标...") @@ -308,7 +308,7 @@ class TestGrowthManager: try: # 模拟记录一些指标 - test_data = [ + test_data = [ ("user_001", "control", 1), ("user_002", "variant_a", 1), ("user_003", "variant_b", 0), @@ -318,20 +318,20 @@ class TestGrowthManager: for user_id, variant_id, value in test_data: self.manager.record_experiment_metric( - experiment_id=experiment_id, - variant_id=variant_id, - user_id=user_id, - metric_name="button_click_rate", - metric_value=value, + experiment_id = experiment_id, + variant_id = variant_id, + user_id = user_id, + metric_name = "button_click_rate", + metric_value = value, ) self.log(f"成功记录 {len(test_data)} 条指标") return True except Exception as e: - self.log(f"记录指标失败: {e}", success=False) + self.log(f"记录指标失败: {e}", success = False) return False - def test_analyze_experiment(self, experiment_id: str): + def test_analyze_experiment(self, experiment_id: str) -> None: """测试分析实验结果""" print("\n📈 测试分析实验结果...") @@ -340,31 +340,31 @@ class TestGrowthManager: return False try: - result = self.manager.analyze_experiment(experiment_id) + result = self.manager.analyze_experiment(experiment_id) if "error" not in result: self.log(f"实验分析完成: {len(result.get('variant_results', {}))} 个变体") return True else: - self.log(f"实验分析返回错误: {result['error']}", success=False) + self.log(f"实验分析返回错误: {result['error']}", success = False) return False except Exception as e: - self.log(f"实验分析失败: {e}", success=False) + self.log(f"实验分析失败: {e}", success = False) return False # ==================== 测试邮件营销 ==================== - def test_create_email_template(self): + def test_create_email_template(self) -> None: """测试创建邮件模板""" print("\n📧 测试创建邮件模板...") try: - template = self.manager.create_email_template( - tenant_id=self.test_tenant_id, - name="欢迎邮件", - template_type=EmailTemplateType.WELCOME, - subject="欢迎加入 InsightFlow!", - html_content=""" + template = self.manager.create_email_template( + tenant_id = self.test_tenant_id, + name = "欢迎邮件", + template_type = EmailTemplateType.WELCOME, + subject = "欢迎加入 InsightFlow!", + html_content = """

欢迎,{{user_name}}!

感谢您注册 InsightFlow。我们很高兴您能加入我们!

您的账户已创建,可以开始使用以下功能:

@@ -373,10 +373,10 @@ class TestGrowthManager:
  • 智能实体提取
  • 团队协作
  • -

    立即开始使用

    +

    立即开始使用

    """, - from_name="InsightFlow 团队", - from_email="welcome@insightflow.io", + from_name = "InsightFlow 团队", + from_email = "welcome@insightflow.io", ) assert template.id is not None @@ -385,23 +385,23 @@ class TestGrowthManager: self.log(f"邮件模板创建成功: {template.id}") return template.id except Exception as e: - self.log(f"创建邮件模板失败: {e}", success=False) + self.log(f"创建邮件模板失败: {e}", success = False) return None - def test_list_email_templates(self): + def test_list_email_templates(self) -> None: """测试列出邮件模板""" print("\n📧 测试列出邮件模板...") try: - templates = self.manager.list_email_templates(self.test_tenant_id) + templates = self.manager.list_email_templates(self.test_tenant_id) self.log(f"列出 {len(templates)} 个邮件模板") return True except Exception as e: - self.log(f"列出邮件模板失败: {e}", success=False) + self.log(f"列出邮件模板失败: {e}", success = False) return False - def test_render_template(self, template_id: str): + def test_render_template(self, template_id: str) -> None: """测试渲染邮件模板""" print("\n🎨 测试渲染邮件模板...") @@ -410,9 +410,9 @@ class TestGrowthManager: return False try: - rendered = self.manager.render_template( - template_id=template_id, - variables={ + rendered = self.manager.render_template( + template_id = template_id, + variables = { "user_name": "张三", "dashboard_url": "https://app.insightflow.io/dashboard", }, @@ -424,13 +424,13 @@ class TestGrowthManager: self.log(f"模板渲染成功: {rendered['subject']}") return True else: - self.log("模板渲染返回空结果", success=False) + self.log("模板渲染返回空结果", success = False) return False except Exception as e: - self.log(f"模板渲染失败: {e}", success=False) + self.log(f"模板渲染失败: {e}", success = False) return False - def test_create_email_campaign(self, template_id: str): + def test_create_email_campaign(self, template_id: str) -> None: """测试创建邮件营销活动""" print("\n📮 测试创建邮件营销活动...") @@ -439,11 +439,11 @@ class TestGrowthManager: return None try: - campaign = self.manager.create_email_campaign( - tenant_id=self.test_tenant_id, - name="新用户欢迎活动", - template_id=template_id, - recipient_list=[ + campaign = self.manager.create_email_campaign( + tenant_id = self.test_tenant_id, + name = "新用户欢迎活动", + template_id = template_id, + recipient_list = [ {"user_id": "user_001", "email": "user1@example.com"}, {"user_id": "user_002", "email": "user2@example.com"}, {"user_id": "user_003", "email": "user3@example.com"}, @@ -456,21 +456,21 @@ class TestGrowthManager: self.log(f"营销活动创建成功: {campaign.id}, {campaign.recipient_count} 收件人") return campaign.id except Exception as e: - self.log(f"创建营销活动失败: {e}", success=False) + self.log(f"创建营销活动失败: {e}", success = False) return None - def test_create_automation_workflow(self): + def test_create_automation_workflow(self) -> None: """测试创建自动化工作流""" print("\n🤖 测试创建自动化工作流...") try: - workflow = self.manager.create_automation_workflow( - tenant_id=self.test_tenant_id, - name="新用户欢迎序列", - description="用户注册后自动发送欢迎邮件序列", - trigger_type=WorkflowTriggerType.USER_SIGNUP, - trigger_conditions={"event": "user_signup"}, - actions=[ + workflow = self.manager.create_automation_workflow( + tenant_id = self.test_tenant_id, + name = "新用户欢迎序列", + description = "用户注册后自动发送欢迎邮件序列", + trigger_type = WorkflowTriggerType.USER_SIGNUP, + trigger_conditions = {"event": "user_signup"}, + actions = [ {"type": "send_email", "template_type": "welcome", "delay_hours": 0}, {"type": "send_email", "template_type": "onboarding", "delay_hours": 24}, {"type": "send_email", "template_type": "feature_tips", "delay_hours": 72}, @@ -483,27 +483,27 @@ class TestGrowthManager: self.log(f"自动化工作流创建成功: {workflow.id}") return True except Exception as e: - self.log(f"创建工作流失败: {e}", success=False) + self.log(f"创建工作流失败: {e}", success = False) return False # ==================== 测试推荐系统 ==================== - def test_create_referral_program(self): + def test_create_referral_program(self) -> None: """测试创建推荐计划""" print("\n🎁 测试创建推荐计划...") try: - program = self.manager.create_referral_program( - tenant_id=self.test_tenant_id, - name="邀请好友奖励计划", - description="邀请好友注册,双方获得积分奖励", - referrer_reward_type="credit", - referrer_reward_value=100.0, - referee_reward_type="credit", - referee_reward_value=50.0, - max_referrals_per_user=10, - referral_code_length=8, - expiry_days=30, + program = self.manager.create_referral_program( + tenant_id = self.test_tenant_id, + name = "邀请好友奖励计划", + description = "邀请好友注册,双方获得积分奖励", + referrer_reward_type = "credit", + referrer_reward_value = 100.0, + referee_reward_type = "credit", + referee_reward_value = 50.0, + max_referrals_per_user = 10, + referral_code_length = 8, + expiry_days = 30, ) assert program.id is not None @@ -512,10 +512,10 @@ class TestGrowthManager: self.log(f"推荐计划创建成功: {program.id}") return program.id except Exception as e: - self.log(f"创建推荐计划失败: {e}", success=False) + self.log(f"创建推荐计划失败: {e}", success = False) return None - def test_generate_referral_code(self, program_id: str): + def test_generate_referral_code(self, program_id: str) -> None: """测试生成推荐码""" print("\n🔑 测试生成推荐码...") @@ -524,8 +524,8 @@ class TestGrowthManager: return None try: - referral = self.manager.generate_referral_code( - program_id=program_id, referrer_id="referrer_user_001" + referral = self.manager.generate_referral_code( + program_id = program_id, referrer_id = "referrer_user_001" ) if referral: @@ -535,13 +535,13 @@ class TestGrowthManager: self.log(f"推荐码生成成功: {referral.referral_code}") return referral.referral_code else: - self.log("生成推荐码返回空结果", success=False) + self.log("生成推荐码返回空结果", success = False) return None except Exception as e: - self.log(f"生成推荐码失败: {e}", success=False) + self.log(f"生成推荐码失败: {e}", success = False) return None - def test_apply_referral_code(self, referral_code: str): + def test_apply_referral_code(self, referral_code: str) -> None: """测试应用推荐码""" print("\n✅ 测试应用推荐码...") @@ -550,21 +550,21 @@ class TestGrowthManager: return False try: - success = self.manager.apply_referral_code( - referral_code=referral_code, referee_id="new_user_001" + success = self.manager.apply_referral_code( + referral_code = referral_code, referee_id = "new_user_001" ) if success: self.log(f"推荐码应用成功: {referral_code}") return True else: - self.log("推荐码应用失败", success=False) + self.log("推荐码应用失败", success = False) return False except Exception as e: - self.log(f"应用推荐码失败: {e}", success=False) + self.log(f"应用推荐码失败: {e}", success = False) return False - def test_get_referral_stats(self, program_id: str): + def test_get_referral_stats(self, program_id: str) -> None: """测试获取推荐统计""" print("\n📊 测试获取推荐统计...") @@ -573,7 +573,7 @@ class TestGrowthManager: return False try: - stats = self.manager.get_referral_stats(program_id) + stats = self.manager.get_referral_stats(program_id) assert "total_referrals" in stats assert "conversion_rate" in stats @@ -583,24 +583,24 @@ class TestGrowthManager: ) return True except Exception as e: - self.log(f"获取推荐统计失败: {e}", success=False) + self.log(f"获取推荐统计失败: {e}", success = False) return False - def test_create_team_incentive(self): + def test_create_team_incentive(self) -> None: """测试创建团队激励""" print("\n🏆 测试创建团队升级激励...") try: - incentive = self.manager.create_team_incentive( - tenant_id=self.test_tenant_id, - name="团队升级奖励", - description="团队规模达到5人升级到 Pro 计划可获得折扣", - target_tier="pro", - min_team_size=5, - incentive_type="discount", - incentive_value=20.0, # 20% 折扣 - valid_from=datetime.now(), - valid_until=datetime.now() + timedelta(days=90), + incentive = self.manager.create_team_incentive( + tenant_id = self.test_tenant_id, + name = "团队升级奖励", + description = "团队规模达到5人升级到 Pro 计划可获得折扣", + target_tier = "pro", + min_team_size = 5, + incentive_type = "discount", + incentive_value = 20.0, # 20% 折扣 + valid_from = datetime.now(), + valid_until = datetime.now() + timedelta(days = 90), ) assert incentive.id is not None @@ -609,116 +609,116 @@ class TestGrowthManager: self.log(f"团队激励创建成功: {incentive.id}") return True except Exception as e: - self.log(f"创建团队激励失败: {e}", success=False) + self.log(f"创建团队激励失败: {e}", success = False) return False - def test_check_team_incentive_eligibility(self): + def test_check_team_incentive_eligibility(self) -> None: """测试检查团队激励资格""" print("\n🔍 测试检查团队激励资格...") try: - incentives = self.manager.check_team_incentive_eligibility( - tenant_id=self.test_tenant_id, current_tier="free", team_size=5 + incentives = self.manager.check_team_incentive_eligibility( + tenant_id = self.test_tenant_id, current_tier = "free", team_size = 5 ) self.log(f"找到 {len(incentives)} 个符合条件的激励") return True except Exception as e: - self.log(f"检查激励资格失败: {e}", success=False) + self.log(f"检查激励资格失败: {e}", success = False) return False # ==================== 测试实时仪表板 ==================== - def test_get_realtime_dashboard(self): + def test_get_realtime_dashboard(self) -> None: """测试获取实时仪表板""" print("\n📺 测试实时分析仪表板...") try: - dashboard = self.manager.get_realtime_dashboard(self.test_tenant_id) + dashboard = self.manager.get_realtime_dashboard(self.test_tenant_id) assert "today" in dashboard assert "recent_events" in dashboard assert "top_features" in dashboard - today = dashboard["today"] + today = dashboard["today"] self.log( f"实时仪表板: 今日 {today['active_users']} 活跃用户, {today['total_events']} 事件" ) return True except Exception as e: - self.log(f"获取实时仪表板失败: {e}", success=False) + self.log(f"获取实时仪表板失败: {e}", success = False) return False # ==================== 运行所有测试 ==================== - async def run_all_tests(self): + async def run_all_tests(self) -> None: """运行所有测试""" - print("=" * 60) + print(" = " * 60) print("🚀 InsightFlow Phase 8 Task 5 - 运营与增长工具测试") - print("=" * 60) + print(" = " * 60) # 用户行为分析测试 - print("\n" + "=" * 60) + print("\n" + " = " * 60) print("📊 模块 1: 用户行为分析") - print("=" * 60) + print(" = " * 60) await self.test_track_event() await self.test_track_multiple_events() self.test_get_user_profile() self.test_get_analytics_summary() - funnel_id = self.test_create_funnel() + funnel_id = self.test_create_funnel() self.test_analyze_funnel(funnel_id) self.test_calculate_retention() # A/B 测试框架测试 - print("\n" + "=" * 60) + print("\n" + " = " * 60) print("🧪 模块 2: A/B 测试框架") - print("=" * 60) + print(" = " * 60) - experiment_id = self.test_create_experiment() + experiment_id = self.test_create_experiment() self.test_list_experiments() self.test_assign_variant(experiment_id) self.test_record_experiment_metric(experiment_id) self.test_analyze_experiment(experiment_id) # 邮件营销测试 - print("\n" + "=" * 60) + print("\n" + " = " * 60) print("📧 模块 3: 邮件营销自动化") - print("=" * 60) + print(" = " * 60) - template_id = self.test_create_email_template() + template_id = self.test_create_email_template() self.test_list_email_templates() self.test_render_template(template_id) self.test_create_email_campaign(template_id) self.test_create_automation_workflow() # 推荐系统测试 - print("\n" + "=" * 60) + print("\n" + " = " * 60) print("🎁 模块 4: 推荐系统") - print("=" * 60) + print(" = " * 60) - program_id = self.test_create_referral_program() - referral_code = self.test_generate_referral_code(program_id) + program_id = self.test_create_referral_program() + referral_code = self.test_generate_referral_code(program_id) self.test_apply_referral_code(referral_code) self.test_get_referral_stats(program_id) self.test_create_team_incentive() self.test_check_team_incentive_eligibility() # 实时仪表板测试 - print("\n" + "=" * 60) + print("\n" + " = " * 60) print("📺 模块 5: 实时分析仪表板") - print("=" * 60) + print(" = " * 60) self.test_get_realtime_dashboard() # 测试总结 - print("\n" + "=" * 60) + print("\n" + " = " * 60) print("📋 测试总结") - print("=" * 60) + print(" = " * 60) - total_tests = len(self.test_results) - passed_tests = sum(1 for _, success in self.test_results if success) - failed_tests = total_tests - passed_tests + total_tests = len(self.test_results) + passed_tests = sum(1 for _, success in self.test_results if success) + failed_tests = total_tests - passed_tests print(f"总测试数: {total_tests}") print(f"通过: {passed_tests} ✅") @@ -731,14 +731,14 @@ class TestGrowthManager: if not success: print(f" - {message}") - print("\n" + "=" * 60) + print("\n" + " = " * 60) print("✨ 测试完成!") - print("=" * 60) + print(" = " * 60) -async def main(): +async def main() -> None: """主函数""" - tester = TestGrowthManager() + tester = TestGrowthManager() await tester.run_all_tests() diff --git a/backend/test_phase8_task6.py b/backend/test_phase8_task6.py index c1816cb..6163551 100644 --- a/backend/test_phase8_task6.py +++ b/backend/test_phase8_task6.py @@ -25,7 +25,7 @@ from developer_ecosystem_manager import ( ) # Add backend directory to path -backend_dir = os.path.dirname(os.path.abspath(__file__)) +backend_dir = os.path.dirname(os.path.abspath(__file__)) if backend_dir not in sys.path: sys.path.insert(0, backend_dir) @@ -33,10 +33,10 @@ if backend_dir not in sys.path: class TestDeveloperEcosystem: """开发者生态系统测试类""" - def __init__(self): - self.manager = DeveloperEcosystemManager() - self.test_results = [] - self.created_ids = { + def __init__(self) -> None: + self.manager = DeveloperEcosystemManager() + self.test_results = [] + self.created_ids = { "sdk": [], "template": [], "plugin": [], @@ -45,19 +45,19 @@ class TestDeveloperEcosystem: "portal_config": [], } - def log(self, message: str, success: bool = True): + def log(self, message: str, success: bool = True) -> None: """记录测试结果""" - status = "✅" if success else "❌" + status = "✅" if success else "❌" print(f"{status} {message}") self.test_results.append( {"message": message, "success": success, "timestamp": datetime.now().isoformat()} ) - def run_all_tests(self): + def run_all_tests(self) -> None: """运行所有测试""" - print("=" * 60) + print(" = " * 60) print("InsightFlow Phase 8 Task 6: Developer Ecosystem Tests") - print("=" * 60) + print(" = " * 60) # SDK Tests print("\n📦 SDK Release & Management Tests") @@ -119,538 +119,538 @@ class TestDeveloperEcosystem: # Print Summary self.print_summary() - def test_sdk_create(self): + def test_sdk_create(self) -> None: """测试创建 SDK""" try: - sdk = self.manager.create_sdk_release( - name="InsightFlow Python SDK", - language=SDKLanguage.PYTHON, - version="1.0.0", - description="Python SDK for InsightFlow API", - changelog="Initial release", - download_url="https://pypi.org/insightflow/1.0.0", - documentation_url="https://docs.insightflow.io/python", - repository_url="https://github.com/insightflow/python-sdk", - package_name="insightflow", - min_platform_version="1.0.0", - dependencies=[{"name": "requests", "version": ">=2.0"}], - file_size=1024000, - checksum="abc123", - created_by="test_user", + sdk = self.manager.create_sdk_release( + name = "InsightFlow Python SDK", + language = SDKLanguage.PYTHON, + version = "1.0.0", + description = "Python SDK for InsightFlow API", + changelog = "Initial release", + download_url = "https://pypi.org/insightflow/1.0.0", + documentation_url = "https://docs.insightflow.io/python", + repository_url = "https://github.com/insightflow/python-sdk", + package_name = "insightflow", + min_platform_version = "1.0.0", + dependencies = [{"name": "requests", "version": ">= 2.0"}], + file_size = 1024000, + checksum = "abc123", + created_by = "test_user", ) self.created_ids["sdk"].append(sdk.id) self.log(f"Created SDK: {sdk.name} ({sdk.id})") # Create JavaScript SDK - sdk_js = self.manager.create_sdk_release( - name="InsightFlow JavaScript SDK", - language=SDKLanguage.JAVASCRIPT, - version="1.0.0", - description="JavaScript SDK for InsightFlow API", - changelog="Initial release", - download_url="https://npmjs.com/insightflow/1.0.0", - documentation_url="https://docs.insightflow.io/js", - repository_url="https://github.com/insightflow/js-sdk", - package_name="@insightflow/sdk", - min_platform_version="1.0.0", - dependencies=[{"name": "axios", "version": ">=0.21"}], - file_size=512000, - checksum="def456", - created_by="test_user", + sdk_js = self.manager.create_sdk_release( + name = "InsightFlow JavaScript SDK", + language = SDKLanguage.JAVASCRIPT, + version = "1.0.0", + description = "JavaScript SDK for InsightFlow API", + changelog = "Initial release", + download_url = "https://npmjs.com/insightflow/1.0.0", + documentation_url = "https://docs.insightflow.io/js", + repository_url = "https://github.com/insightflow/js-sdk", + package_name = "@insightflow/sdk", + min_platform_version = "1.0.0", + dependencies = [{"name": "axios", "version": ">= 0.21"}], + file_size = 512000, + checksum = "def456", + created_by = "test_user", ) self.created_ids["sdk"].append(sdk_js.id) self.log(f"Created SDK: {sdk_js.name} ({sdk_js.id})") except Exception as e: - self.log(f"Failed to create SDK: {str(e)}", success=False) + self.log(f"Failed to create SDK: {str(e)}", success = False) - def test_sdk_list(self): + def test_sdk_list(self) -> None: """测试列出 SDK""" try: - sdks = self.manager.list_sdk_releases() + sdks = self.manager.list_sdk_releases() self.log(f"Listed {len(sdks)} SDKs") # Test filter by language - python_sdks = self.manager.list_sdk_releases(language=SDKLanguage.PYTHON) + python_sdks = self.manager.list_sdk_releases(language = SDKLanguage.PYTHON) self.log(f"Found {len(python_sdks)} Python SDKs") # Test search - search_results = self.manager.list_sdk_releases(search="Python") + search_results = self.manager.list_sdk_releases(search = "Python") self.log(f"Search found {len(search_results)} SDKs") except Exception as e: - self.log(f"Failed to list SDKs: {str(e)}", success=False) + self.log(f"Failed to list SDKs: {str(e)}", success = False) - def test_sdk_get(self): + def test_sdk_get(self) -> None: """测试获取 SDK 详情""" try: if self.created_ids["sdk"]: - sdk = self.manager.get_sdk_release(self.created_ids["sdk"][0]) + sdk = self.manager.get_sdk_release(self.created_ids["sdk"][0]) if sdk: self.log(f"Retrieved SDK: {sdk.name}") else: - self.log("SDK not found", success=False) + self.log("SDK not found", success = False) except Exception as e: - self.log(f"Failed to get SDK: {str(e)}", success=False) + self.log(f"Failed to get SDK: {str(e)}", success = False) - def test_sdk_update(self): + def test_sdk_update(self) -> None: """测试更新 SDK""" try: if self.created_ids["sdk"]: - sdk = self.manager.update_sdk_release( - self.created_ids["sdk"][0], description="Updated description" + sdk = self.manager.update_sdk_release( + self.created_ids["sdk"][0], description = "Updated description" ) if sdk: self.log(f"Updated SDK: {sdk.name}") except Exception as e: - self.log(f"Failed to update SDK: {str(e)}", success=False) + self.log(f"Failed to update SDK: {str(e)}", success = False) - def test_sdk_publish(self): + def test_sdk_publish(self) -> None: """测试发布 SDK""" try: if self.created_ids["sdk"]: - sdk = self.manager.publish_sdk_release(self.created_ids["sdk"][0]) + sdk = self.manager.publish_sdk_release(self.created_ids["sdk"][0]) if sdk: self.log(f"Published SDK: {sdk.name} (status: {sdk.status.value})") except Exception as e: - self.log(f"Failed to publish SDK: {str(e)}", success=False) + self.log(f"Failed to publish SDK: {str(e)}", success = False) - def test_sdk_version_add(self): + def test_sdk_version_add(self) -> None: """测试添加 SDK 版本""" try: if self.created_ids["sdk"]: - version = self.manager.add_sdk_version( - sdk_id=self.created_ids["sdk"][0], - version="1.1.0", - is_lts=True, - release_notes="Bug fixes and improvements", - download_url="https://pypi.org/insightflow/1.1.0", - checksum="xyz789", - file_size=1100000, + version = self.manager.add_sdk_version( + sdk_id = self.created_ids["sdk"][0], + version = "1.1.0", + is_lts = True, + release_notes = "Bug fixes and improvements", + download_url = "https://pypi.org/insightflow/1.1.0", + checksum = "xyz789", + file_size = 1100000, ) self.log(f"Added SDK version: {version.version}") except Exception as e: - self.log(f"Failed to add SDK version: {str(e)}", success=False) + self.log(f"Failed to add SDK version: {str(e)}", success = False) - def test_template_create(self): + def test_template_create(self) -> None: """测试创建模板""" try: - template = self.manager.create_template( - name="医疗行业实体识别模板", - description="专门针对医疗行业的实体识别模板,支持疾病、药物、症状等实体", - category=TemplateCategory.MEDICAL, - subcategory="entity_recognition", - tags=["medical", "healthcare", "ner"], - author_id="dev_001", - author_name="Medical AI Lab", - price=99.0, - currency="CNY", - preview_image_url="https://cdn.insightflow.io/templates/medical.png", - demo_url="https://demo.insightflow.io/medical", - documentation_url="https://docs.insightflow.io/templates/medical", - download_url="https://cdn.insightflow.io/templates/medical.zip", - version="1.0.0", - min_platform_version="2.0.0", - file_size=5242880, - checksum="tpl123", + template = self.manager.create_template( + name = "医疗行业实体识别模板", + description = "专门针对医疗行业的实体识别模板,支持疾病、药物、症状等实体", + category = TemplateCategory.MEDICAL, + subcategory = "entity_recognition", + tags = ["medical", "healthcare", "ner"], + author_id = "dev_001", + author_name = "Medical AI Lab", + price = 99.0, + currency = "CNY", + preview_image_url = "https://cdn.insightflow.io/templates/medical.png", + demo_url = "https://demo.insightflow.io/medical", + documentation_url = "https://docs.insightflow.io/templates/medical", + download_url = "https://cdn.insightflow.io/templates/medical.zip", + version = "1.0.0", + min_platform_version = "2.0.0", + file_size = 5242880, + checksum = "tpl123", ) self.created_ids["template"].append(template.id) self.log(f"Created template: {template.name} ({template.id})") # Create free template - template_free = self.manager.create_template( - name="通用实体识别模板", - description="适用于一般场景的实体识别模板", - category=TemplateCategory.GENERAL, - subcategory=None, - tags=["general", "ner", "basic"], - author_id="dev_002", - author_name="InsightFlow Team", - price=0.0, - currency="CNY", + template_free = self.manager.create_template( + name = "通用实体识别模板", + description = "适用于一般场景的实体识别模板", + category = TemplateCategory.GENERAL, + subcategory = None, + tags = ["general", "ner", "basic"], + author_id = "dev_002", + author_name = "InsightFlow Team", + price = 0.0, + currency = "CNY", ) self.created_ids["template"].append(template_free.id) self.log(f"Created free template: {template_free.name}") except Exception as e: - self.log(f"Failed to create template: {str(e)}", success=False) + self.log(f"Failed to create template: {str(e)}", success = False) - def test_template_list(self): + def test_template_list(self) -> None: """测试列出模板""" try: - templates = self.manager.list_templates() + templates = self.manager.list_templates() self.log(f"Listed {len(templates)} templates") # Filter by category - medical_templates = self.manager.list_templates(category=TemplateCategory.MEDICAL) + medical_templates = self.manager.list_templates(category = TemplateCategory.MEDICAL) self.log(f"Found {len(medical_templates)} medical templates") # Filter by price - free_templates = self.manager.list_templates(max_price=0) + free_templates = self.manager.list_templates(max_price = 0) self.log(f"Found {len(free_templates)} free templates") except Exception as e: - self.log(f"Failed to list templates: {str(e)}", success=False) + self.log(f"Failed to list templates: {str(e)}", success = False) - def test_template_get(self): + def test_template_get(self) -> None: """测试获取模板详情""" try: if self.created_ids["template"]: - template = self.manager.get_template(self.created_ids["template"][0]) + template = self.manager.get_template(self.created_ids["template"][0]) if template: self.log(f"Retrieved template: {template.name}") except Exception as e: - self.log(f"Failed to get template: {str(e)}", success=False) + self.log(f"Failed to get template: {str(e)}", success = False) - def test_template_approve(self): + def test_template_approve(self) -> None: """测试审核通过模板""" try: if self.created_ids["template"]: - template = self.manager.approve_template( - self.created_ids["template"][0], reviewed_by="admin_001" + template = self.manager.approve_template( + self.created_ids["template"][0], reviewed_by = "admin_001" ) if template: self.log(f"Approved template: {template.name}") except Exception as e: - self.log(f"Failed to approve template: {str(e)}", success=False) + self.log(f"Failed to approve template: {str(e)}", success = False) - def test_template_publish(self): + def test_template_publish(self) -> None: """测试发布模板""" try: if self.created_ids["template"]: - template = self.manager.publish_template(self.created_ids["template"][0]) + template = self.manager.publish_template(self.created_ids["template"][0]) if template: self.log(f"Published template: {template.name}") except Exception as e: - self.log(f"Failed to publish template: {str(e)}", success=False) + self.log(f"Failed to publish template: {str(e)}", success = False) - def test_template_review(self): + def test_template_review(self) -> None: """测试添加模板评价""" try: if self.created_ids["template"]: - review = self.manager.add_template_review( - template_id=self.created_ids["template"][0], - user_id="user_001", - user_name="Test User", - rating=5, - comment="Great template! Very accurate for medical entities.", - is_verified_purchase=True, + review = self.manager.add_template_review( + template_id = self.created_ids["template"][0], + user_id = "user_001", + user_name = "Test User", + rating = 5, + comment = "Great template! Very accurate for medical entities.", + is_verified_purchase = True, ) self.log(f"Added template review: {review.rating} stars") except Exception as e: - self.log(f"Failed to add template review: {str(e)}", success=False) + self.log(f"Failed to add template review: {str(e)}", success = False) - def test_plugin_create(self): + def test_plugin_create(self) -> None: """测试创建插件""" try: - plugin = self.manager.create_plugin( - name="飞书机器人集成插件", - description="将 InsightFlow 与飞书机器人集成,实现自动通知", - category=PluginCategory.INTEGRATION, - tags=["feishu", "bot", "integration", "notification"], - author_id="dev_003", - author_name="Integration Team", - price=49.0, - currency="CNY", - pricing_model="paid", - preview_image_url="https://cdn.insightflow.io/plugins/feishu.png", - demo_url="https://demo.insightflow.io/feishu", - documentation_url="https://docs.insightflow.io/plugins/feishu", - repository_url="https://github.com/insightflow/feishu-plugin", - download_url="https://cdn.insightflow.io/plugins/feishu.zip", - webhook_url="https://api.insightflow.io/webhooks/feishu", - permissions=["read:projects", "write:notifications"], - version="1.0.0", - min_platform_version="2.0.0", - file_size=1048576, - checksum="plg123", + plugin = self.manager.create_plugin( + name = "飞书机器人集成插件", + description = "将 InsightFlow 与飞书机器人集成,实现自动通知", + category = PluginCategory.INTEGRATION, + tags = ["feishu", "bot", "integration", "notification"], + author_id = "dev_003", + author_name = "Integration Team", + price = 49.0, + currency = "CNY", + pricing_model = "paid", + preview_image_url = "https://cdn.insightflow.io/plugins/feishu.png", + demo_url = "https://demo.insightflow.io/feishu", + documentation_url = "https://docs.insightflow.io/plugins/feishu", + repository_url = "https://github.com/insightflow/feishu-plugin", + download_url = "https://cdn.insightflow.io/plugins/feishu.zip", + webhook_url = "https://api.insightflow.io/webhooks/feishu", + permissions = ["read:projects", "write:notifications"], + version = "1.0.0", + min_platform_version = "2.0.0", + file_size = 1048576, + checksum = "plg123", ) self.created_ids["plugin"].append(plugin.id) self.log(f"Created plugin: {plugin.name} ({plugin.id})") # Create free plugin - plugin_free = self.manager.create_plugin( - name="数据导出插件", - description="支持多种格式的数据导出", - category=PluginCategory.ANALYSIS, - tags=["export", "data", "csv", "json"], - author_id="dev_004", - author_name="Data Team", - price=0.0, - currency="CNY", - pricing_model="free", + plugin_free = self.manager.create_plugin( + name = "数据导出插件", + description = "支持多种格式的数据导出", + category = PluginCategory.ANALYSIS, + tags = ["export", "data", "csv", "json"], + author_id = "dev_004", + author_name = "Data Team", + price = 0.0, + currency = "CNY", + pricing_model = "free", ) self.created_ids["plugin"].append(plugin_free.id) self.log(f"Created free plugin: {plugin_free.name}") except Exception as e: - self.log(f"Failed to create plugin: {str(e)}", success=False) + self.log(f"Failed to create plugin: {str(e)}", success = False) - def test_plugin_list(self): + def test_plugin_list(self) -> None: """测试列出插件""" try: - plugins = self.manager.list_plugins() + plugins = self.manager.list_plugins() self.log(f"Listed {len(plugins)} plugins") # Filter by category - integration_plugins = self.manager.list_plugins(category=PluginCategory.INTEGRATION) + integration_plugins = self.manager.list_plugins(category = PluginCategory.INTEGRATION) self.log(f"Found {len(integration_plugins)} integration plugins") except Exception as e: - self.log(f"Failed to list plugins: {str(e)}", success=False) + self.log(f"Failed to list plugins: {str(e)}", success = False) - def test_plugin_get(self): + def test_plugin_get(self) -> None: """测试获取插件详情""" try: if self.created_ids["plugin"]: - plugin = self.manager.get_plugin(self.created_ids["plugin"][0]) + plugin = self.manager.get_plugin(self.created_ids["plugin"][0]) if plugin: self.log(f"Retrieved plugin: {plugin.name}") except Exception as e: - self.log(f"Failed to get plugin: {str(e)}", success=False) + self.log(f"Failed to get plugin: {str(e)}", success = False) - def test_plugin_review(self): + def test_plugin_review(self) -> None: """测试审核插件""" try: if self.created_ids["plugin"]: - plugin = self.manager.review_plugin( + plugin = self.manager.review_plugin( self.created_ids["plugin"][0], - reviewed_by="admin_001", - status=PluginStatus.APPROVED, - notes="Code review passed", + reviewed_by = "admin_001", + status = PluginStatus.APPROVED, + notes = "Code review passed", ) if plugin: self.log(f"Reviewed plugin: {plugin.name} ({plugin.status.value})") except Exception as e: - self.log(f"Failed to review plugin: {str(e)}", success=False) + self.log(f"Failed to review plugin: {str(e)}", success = False) - def test_plugin_publish(self): + def test_plugin_publish(self) -> None: """测试发布插件""" try: if self.created_ids["plugin"]: - plugin = self.manager.publish_plugin(self.created_ids["plugin"][0]) + plugin = self.manager.publish_plugin(self.created_ids["plugin"][0]) if plugin: self.log(f"Published plugin: {plugin.name}") except Exception as e: - self.log(f"Failed to publish plugin: {str(e)}", success=False) + self.log(f"Failed to publish plugin: {str(e)}", success = False) - def test_plugin_review_add(self): + def test_plugin_review_add(self) -> None: """测试添加插件评价""" try: if self.created_ids["plugin"]: - review = self.manager.add_plugin_review( - plugin_id=self.created_ids["plugin"][0], - user_id="user_002", - user_name="Plugin User", - rating=4, - comment="Works great with Feishu!", - is_verified_purchase=True, + review = self.manager.add_plugin_review( + plugin_id = self.created_ids["plugin"][0], + user_id = "user_002", + user_name = "Plugin User", + rating = 4, + comment = "Works great with Feishu!", + is_verified_purchase = True, ) self.log(f"Added plugin review: {review.rating} stars") except Exception as e: - self.log(f"Failed to add plugin review: {str(e)}", success=False) + self.log(f"Failed to add plugin review: {str(e)}", success = False) - def test_developer_profile_create(self): + def test_developer_profile_create(self) -> None: """测试创建开发者档案""" try: # Generate unique user IDs - unique_id = uuid.uuid4().hex[:8] + unique_id = uuid.uuid4().hex[:8] - profile = self.manager.create_developer_profile( - user_id=f"user_dev_{unique_id}_001", - display_name="张三", - email=f"zhangsan_{unique_id}@example.com", - bio="专注于医疗AI和自然语言处理", - website="https://zhangsan.dev", - github_url="https://github.com/zhangsan", - avatar_url="https://cdn.example.com/avatars/zhangsan.png", + profile = self.manager.create_developer_profile( + user_id = f"user_dev_{unique_id}_001", + display_name = "张三", + email = f"zhangsan_{unique_id}@example.com", + bio = "专注于医疗AI和自然语言处理", + website = "https://zhangsan.dev", + github_url = "https://github.com/zhangsan", + avatar_url = "https://cdn.example.com/avatars/zhangsan.png", ) self.created_ids["developer"].append(profile.id) self.log(f"Created developer profile: {profile.display_name} ({profile.id})") # Create another developer - profile2 = self.manager.create_developer_profile( - user_id=f"user_dev_{unique_id}_002", - display_name="李四", - email=f"lisi_{unique_id}@example.com", - bio="全栈开发者,热爱开源", + profile2 = self.manager.create_developer_profile( + user_id = f"user_dev_{unique_id}_002", + display_name = "李四", + email = f"lisi_{unique_id}@example.com", + bio = "全栈开发者,热爱开源", ) self.created_ids["developer"].append(profile2.id) self.log(f"Created developer profile: {profile2.display_name}") except Exception as e: - self.log(f"Failed to create developer profile: {str(e)}", success=False) + self.log(f"Failed to create developer profile: {str(e)}", success = False) - def test_developer_profile_get(self): + def test_developer_profile_get(self) -> None: """测试获取开发者档案""" try: if self.created_ids["developer"]: - profile = self.manager.get_developer_profile(self.created_ids["developer"][0]) + profile = self.manager.get_developer_profile(self.created_ids["developer"][0]) if profile: self.log(f"Retrieved developer profile: {profile.display_name}") except Exception as e: - self.log(f"Failed to get developer profile: {str(e)}", success=False) + self.log(f"Failed to get developer profile: {str(e)}", success = False) - def test_developer_verify(self): + def test_developer_verify(self) -> None: """测试验证开发者""" try: if self.created_ids["developer"]: - profile = self.manager.verify_developer( + profile = self.manager.verify_developer( self.created_ids["developer"][0], DeveloperStatus.VERIFIED ) if profile: self.log(f"Verified developer: {profile.display_name} ({profile.status.value})") except Exception as e: - self.log(f"Failed to verify developer: {str(e)}", success=False) + self.log(f"Failed to verify developer: {str(e)}", success = False) - def test_developer_stats_update(self): + def test_developer_stats_update(self) -> None: """测试更新开发者统计""" try: if self.created_ids["developer"]: self.manager.update_developer_stats(self.created_ids["developer"][0]) - profile = self.manager.get_developer_profile(self.created_ids["developer"][0]) + profile = self.manager.get_developer_profile(self.created_ids["developer"][0]) self.log( f"Updated developer stats: {profile.plugin_count} plugins, {profile.template_count} templates" ) except Exception as e: - self.log(f"Failed to update developer stats: {str(e)}", success=False) + self.log(f"Failed to update developer stats: {str(e)}", success = False) - def test_code_example_create(self): + def test_code_example_create(self) -> None: """测试创建代码示例""" try: - example = self.manager.create_code_example( - title="使用 Python SDK 创建项目", - description="演示如何使用 Python SDK 创建新项目", - language="python", - category="quickstart", - code="""from insightflow import Client + example = self.manager.create_code_example( + title = "使用 Python SDK 创建项目", + description = "演示如何使用 Python SDK 创建新项目", + language = "python", + category = "quickstart", + code = """from insightflow import Client -client = Client(api_key="your_api_key") -project = client.projects.create(name="My Project") +client = Client(api_key = "your_api_key") +project = client.projects.create(name = "My Project") print(f"Created project: {project.id}") """, - explanation="首先导入 Client 类,然后使用 API Key 初始化客户端,最后调用 create 方法创建项目。", - tags=["python", "quickstart", "projects"], - author_id="dev_001", - author_name="InsightFlow Team", - api_endpoints=["/api/v1/projects"], + explanation = "首先导入 Client 类,然后使用 API Key 初始化客户端,最后调用 create 方法创建项目。", + tags = ["python", "quickstart", "projects"], + author_id = "dev_001", + author_name = "InsightFlow Team", + api_endpoints = ["/api/v1/projects"], ) self.created_ids["code_example"].append(example.id) self.log(f"Created code example: {example.title}") # Create JavaScript example - example_js = self.manager.create_code_example( - title="使用 JavaScript SDK 上传文件", - description="演示如何使用 JavaScript SDK 上传音频文件", - language="javascript", - category="upload", - code="""const { Client } = require('insightflow'); + example_js = self.manager.create_code_example( + title = "使用 JavaScript SDK 上传文件", + description = "演示如何使用 JavaScript SDK 上传音频文件", + language = "javascript", + category = "upload", + code = """const { Client } = require('insightflow'); -const client = new Client({ apiKey: 'your_api_key' }); -const result = await client.uploads.create({ +const client = new Client({ apiKey: 'your_api_key' }); +const result = await client.uploads.create({ projectId: 'proj_123', file: './meeting.mp3' }); console.log('Upload complete:', result.id); """, - explanation="使用 JavaScript SDK 上传文件到 InsightFlow", - tags=["javascript", "upload", "audio"], - author_id="dev_002", - author_name="JS Team", + explanation = "使用 JavaScript SDK 上传文件到 InsightFlow", + tags = ["javascript", "upload", "audio"], + author_id = "dev_002", + author_name = "JS Team", ) self.created_ids["code_example"].append(example_js.id) self.log(f"Created code example: {example_js.title}") except Exception as e: - self.log(f"Failed to create code example: {str(e)}", success=False) + self.log(f"Failed to create code example: {str(e)}", success = False) - def test_code_example_list(self): + def test_code_example_list(self) -> None: """测试列出代码示例""" try: - examples = self.manager.list_code_examples() + examples = self.manager.list_code_examples() self.log(f"Listed {len(examples)} code examples") # Filter by language - python_examples = self.manager.list_code_examples(language="python") + python_examples = self.manager.list_code_examples(language = "python") self.log(f"Found {len(python_examples)} Python examples") except Exception as e: - self.log(f"Failed to list code examples: {str(e)}", success=False) + self.log(f"Failed to list code examples: {str(e)}", success = False) - def test_code_example_get(self): + def test_code_example_get(self) -> None: """测试获取代码示例详情""" try: if self.created_ids["code_example"]: - example = self.manager.get_code_example(self.created_ids["code_example"][0]) + example = self.manager.get_code_example(self.created_ids["code_example"][0]) if example: self.log( f"Retrieved code example: {example.title} (views: {example.view_count})" ) except Exception as e: - self.log(f"Failed to get code example: {str(e)}", success=False) + self.log(f"Failed to get code example: {str(e)}", success = False) - def test_portal_config_create(self): + def test_portal_config_create(self) -> None: """测试创建开发者门户配置""" try: - config = self.manager.create_portal_config( - name="InsightFlow Developer Portal", - description="开发者门户 - SDK、API 文档和示例代码", - theme="default", - primary_color="#1890ff", - secondary_color="#52c41a", - support_email="developers@insightflow.io", - support_url="https://support.insightflow.io", - github_url="https://github.com/insightflow", - discord_url="https://discord.gg/insightflow", - api_base_url="https://api.insightflow.io/v1", + config = self.manager.create_portal_config( + name = "InsightFlow Developer Portal", + description = "开发者门户 - SDK、API 文档和示例代码", + theme = "default", + primary_color = "#1890ff", + secondary_color = "#52c41a", + support_email = "developers@insightflow.io", + support_url = "https://support.insightflow.io", + github_url = "https://github.com/insightflow", + discord_url = "https://discord.gg/insightflow", + api_base_url = "https://api.insightflow.io/v1", ) self.created_ids["portal_config"].append(config.id) self.log(f"Created portal config: {config.name}") except Exception as e: - self.log(f"Failed to create portal config: {str(e)}", success=False) + self.log(f"Failed to create portal config: {str(e)}", success = False) - def test_portal_config_get(self): + def test_portal_config_get(self) -> None: """测试获取开发者门户配置""" try: if self.created_ids["portal_config"]: - config = self.manager.get_portal_config(self.created_ids["portal_config"][0]) + config = self.manager.get_portal_config(self.created_ids["portal_config"][0]) if config: self.log(f"Retrieved portal config: {config.name}") # Test active config - active_config = self.manager.get_active_portal_config() + active_config = self.manager.get_active_portal_config() if active_config: self.log(f"Active portal config: {active_config.name}") except Exception as e: - self.log(f"Failed to get portal config: {str(e)}", success=False) + self.log(f"Failed to get portal config: {str(e)}", success = False) - def test_revenue_record(self): + def test_revenue_record(self) -> None: """测试记录开发者收益""" try: if self.created_ids["developer"] and self.created_ids["plugin"]: - revenue = self.manager.record_revenue( - developer_id=self.created_ids["developer"][0], - item_type="plugin", - item_id=self.created_ids["plugin"][0], - item_name="飞书机器人集成插件", - sale_amount=49.0, - currency="CNY", - buyer_id="user_buyer_001", - transaction_id="txn_123456", + revenue = self.manager.record_revenue( + developer_id = self.created_ids["developer"][0], + item_type = "plugin", + item_id = self.created_ids["plugin"][0], + item_name = "飞书机器人集成插件", + sale_amount = 49.0, + currency = "CNY", + buyer_id = "user_buyer_001", + transaction_id = "txn_123456", ) self.log(f"Recorded revenue: {revenue.sale_amount} {revenue.currency}") self.log(f" - Platform fee: {revenue.platform_fee}") self.log(f" - Developer earnings: {revenue.developer_earnings}") except Exception as e: - self.log(f"Failed to record revenue: {str(e)}", success=False) + self.log(f"Failed to record revenue: {str(e)}", success = False) - def test_revenue_summary(self): + def test_revenue_summary(self) -> None: """测试获取开发者收益汇总""" try: if self.created_ids["developer"]: - summary = self.manager.get_developer_revenue_summary( + summary = self.manager.get_developer_revenue_summary( self.created_ids["developer"][0] ) self.log("Revenue summary for developer:") @@ -659,17 +659,17 @@ console.log('Upload complete:', result.id); self.log(f" - Total earnings: {summary['total_earnings']}") self.log(f" - Transaction count: {summary['transaction_count']}") except Exception as e: - self.log(f"Failed to get revenue summary: {str(e)}", success=False) + self.log(f"Failed to get revenue summary: {str(e)}", success = False) - def print_summary(self): + def print_summary(self) -> None: """打印测试摘要""" - print("\n" + "=" * 60) + print("\n" + " = " * 60) print("Test Summary") - print("=" * 60) + print(" = " * 60) - total = len(self.test_results) - passed = sum(1 for r in self.test_results if r["success"]) - failed = total - passed + total = len(self.test_results) + passed = sum(1 for r in self.test_results if r["success"]) + failed = total - passed print(f"Total tests: {total}") print(f"Passed: {passed} ✅") @@ -686,12 +686,12 @@ console.log('Upload complete:', result.id); if ids: print(f" {resource_type}: {len(ids)}") - print("=" * 60) + print(" = " * 60) -def main(): +def main() -> None: """主函数""" - test = TestDeveloperEcosystem() + test = TestDeveloperEcosystem() test.run_all_tests() diff --git a/backend/test_phase8_task8.py b/backend/test_phase8_task8.py index 03f5edb..eef988a 100644 --- a/backend/test_phase8_task8.py +++ b/backend/test_phase8_task8.py @@ -26,7 +26,7 @@ from ops_manager import ( ) # Add backend directory to path -backend_dir = os.path.dirname(os.path.abspath(__file__)) +backend_dir = os.path.dirname(os.path.abspath(__file__)) if backend_dir not in sys.path: sys.path.insert(0, backend_dir) @@ -34,22 +34,22 @@ if backend_dir not in sys.path: class TestOpsManager: """测试运维与监控管理器""" - def __init__(self): - self.manager = get_ops_manager() - self.tenant_id = "test_tenant_001" - self.test_results = [] + def __init__(self) -> None: + self.manager = get_ops_manager() + self.tenant_id = "test_tenant_001" + self.test_results = [] - def log(self, message: str, success: bool = True): + def log(self, message: str, success: bool = True) -> None: """记录测试结果""" - status = "✅" if success else "❌" + status = "✅" if success else "❌" print(f"{status} {message}") self.test_results.append((message, success)) - def run_all_tests(self): + def run_all_tests(self) -> None: """运行所有测试""" - print("=" * 60) + print(" = " * 60) print("InsightFlow Phase 8 Task 8: Operations & Monitoring Tests") - print("=" * 60) + print(" = " * 60) # 1. 告警系统测试 self.test_alert_rules() @@ -73,63 +73,63 @@ class TestOpsManager: # 打印测试总结 self.print_summary() - def test_alert_rules(self): + def test_alert_rules(self) -> None: """测试告警规则管理""" print("\n📋 Testing Alert Rules...") try: # 创建阈值告警规则 - rule1 = self.manager.create_alert_rule( - tenant_id=self.tenant_id, - name="CPU 使用率告警", - description="当 CPU 使用率超过 80% 时触发告警", - rule_type=AlertRuleType.THRESHOLD, - severity=AlertSeverity.P1, - metric="cpu_usage_percent", - condition=">", - threshold=80.0, - duration=300, - evaluation_interval=60, - channels=[], - labels={"service": "api", "team": "platform"}, - annotations={"summary": "CPU 使用率过高", "runbook": "https://wiki/runbooks/cpu"}, - created_by="test_user", + rule1 = self.manager.create_alert_rule( + tenant_id = self.tenant_id, + name = "CPU 使用率告警", + description = "当 CPU 使用率超过 80% 时触发告警", + rule_type = AlertRuleType.THRESHOLD, + severity = AlertSeverity.P1, + metric = "cpu_usage_percent", + condition = ">", + threshold = 80.0, + duration = 300, + evaluation_interval = 60, + channels = [], + labels = {"service": "api", "team": "platform"}, + annotations = {"summary": "CPU 使用率过高", "runbook": "https://wiki/runbooks/cpu"}, + created_by = "test_user", ) self.log(f"Created alert rule: {rule1.name} (ID: {rule1.id})") # 创建异常检测告警规则 - rule2 = self.manager.create_alert_rule( - tenant_id=self.tenant_id, - name="内存异常检测", - description="检测内存使用异常", - rule_type=AlertRuleType.ANOMALY, - severity=AlertSeverity.P2, - metric="memory_usage_percent", - condition=">", - threshold=0.0, - duration=600, - evaluation_interval=300, - channels=[], - labels={"service": "database"}, - annotations={}, - created_by="test_user", + rule2 = self.manager.create_alert_rule( + tenant_id = self.tenant_id, + name = "内存异常检测", + description = "检测内存使用异常", + rule_type = AlertRuleType.ANOMALY, + severity = AlertSeverity.P2, + metric = "memory_usage_percent", + condition = ">", + threshold = 0.0, + duration = 600, + evaluation_interval = 300, + channels = [], + labels = {"service": "database"}, + annotations = {}, + created_by = "test_user", ) self.log(f"Created anomaly alert rule: {rule2.name} (ID: {rule2.id})") # 获取告警规则 - fetched_rule = self.manager.get_alert_rule(rule1.id) + fetched_rule = self.manager.get_alert_rule(rule1.id) assert fetched_rule is not None assert fetched_rule.name == rule1.name self.log(f"Fetched alert rule: {fetched_rule.name}") # 列出租户的所有告警规则 - rules = self.manager.list_alert_rules(self.tenant_id) + rules = self.manager.list_alert_rules(self.tenant_id) assert len(rules) >= 2 self.log(f"Listed {len(rules)} alert rules for tenant") # 更新告警规则 - updated_rule = self.manager.update_alert_rule( - rule1.id, threshold=85.0, description="更新后的描述" + updated_rule = self.manager.update_alert_rule( + rule1.id, threshold = 85.0, description = "更新后的描述" ) assert updated_rule.threshold == 85.0 self.log(f"Updated alert rule threshold to {updated_rule.threshold}") @@ -140,57 +140,57 @@ class TestOpsManager: self.log("Deleted test alert rules") except Exception as e: - self.log(f"Alert rules test failed: {e}", success=False) + self.log(f"Alert rules test failed: {e}", success = False) - def test_alert_channels(self): + def test_alert_channels(self) -> None: """测试告警渠道管理""" print("\n📢 Testing Alert Channels...") try: # 创建飞书告警渠道 - channel1 = self.manager.create_alert_channel( - tenant_id=self.tenant_id, - name="飞书告警", - channel_type=AlertChannelType.FEISHU, - config={ + channel1 = self.manager.create_alert_channel( + tenant_id = self.tenant_id, + name = "飞书告警", + channel_type = AlertChannelType.FEISHU, + config = { "webhook_url": "https://open.feishu.cn/open-apis/bot/v2/hook/test", "secret": "test_secret", }, - severity_filter=["p0", "p1"], + severity_filter = ["p0", "p1"], ) self.log(f"Created Feishu channel: {channel1.name} (ID: {channel1.id})") # 创建钉钉告警渠道 - channel2 = self.manager.create_alert_channel( - tenant_id=self.tenant_id, - name="钉钉告警", - channel_type=AlertChannelType.DINGTALK, - config={ - "webhook_url": "https://oapi.dingtalk.com/robot/send?access_token=test", + channel2 = self.manager.create_alert_channel( + tenant_id = self.tenant_id, + name = "钉钉告警", + channel_type = AlertChannelType.DINGTALK, + config = { + "webhook_url": "https://oapi.dingtalk.com/robot/send?access_token = test", "secret": "test_secret", }, - severity_filter=["p0", "p1", "p2"], + severity_filter = ["p0", "p1", "p2"], ) self.log(f"Created DingTalk channel: {channel2.name} (ID: {channel2.id})") # 创建 Slack 告警渠道 - channel3 = self.manager.create_alert_channel( - tenant_id=self.tenant_id, - name="Slack 告警", - channel_type=AlertChannelType.SLACK, - config={"webhook_url": "https://hooks.slack.com/services/test"}, - severity_filter=["p0", "p1", "p2", "p3"], + channel3 = self.manager.create_alert_channel( + tenant_id = self.tenant_id, + name = "Slack 告警", + channel_type = AlertChannelType.SLACK, + config = {"webhook_url": "https://hooks.slack.com/services/test"}, + severity_filter = ["p0", "p1", "p2", "p3"], ) self.log(f"Created Slack channel: {channel3.name} (ID: {channel3.id})") # 获取告警渠道 - fetched_channel = self.manager.get_alert_channel(channel1.id) + fetched_channel = self.manager.get_alert_channel(channel1.id) assert fetched_channel is not None assert fetched_channel.name == channel1.name self.log(f"Fetched alert channel: {fetched_channel.name}") # 列出租户的所有告警渠道 - channels = self.manager.list_alert_channels(self.tenant_id) + channels = self.manager.list_alert_channels(self.tenant_id) assert len(channels) >= 3 self.log(f"Listed {len(channels)} alert channels for tenant") @@ -198,74 +198,74 @@ class TestOpsManager: for channel in channels: if channel.tenant_id == self.tenant_id: with self.manager._get_db() as conn: - conn.execute("DELETE FROM alert_channels WHERE id = ?", (channel.id,)) + conn.execute("DELETE FROM alert_channels WHERE id = ?", (channel.id, )) conn.commit() self.log("Deleted test alert channels") except Exception as e: - self.log(f"Alert channels test failed: {e}", success=False) + self.log(f"Alert channels test failed: {e}", success = False) - def test_alerts(self): + def test_alerts(self) -> None: """测试告警管理""" print("\n🚨 Testing Alerts...") try: # 创建告警规则 - rule = self.manager.create_alert_rule( - tenant_id=self.tenant_id, - name="测试告警规则", - description="用于测试的告警规则", - rule_type=AlertRuleType.THRESHOLD, - severity=AlertSeverity.P1, - metric="test_metric", - condition=">", - threshold=100.0, - duration=60, - evaluation_interval=60, - channels=[], - labels={}, - annotations={}, - created_by="test_user", + rule = self.manager.create_alert_rule( + tenant_id = self.tenant_id, + name = "测试告警规则", + description = "用于测试的告警规则", + rule_type = AlertRuleType.THRESHOLD, + severity = AlertSeverity.P1, + metric = "test_metric", + condition = ">", + threshold = 100.0, + duration = 60, + evaluation_interval = 60, + channels = [], + labels = {}, + annotations = {}, + created_by = "test_user", ) # 记录资源指标 for i in range(10): self.manager.record_resource_metric( - tenant_id=self.tenant_id, - resource_type=ResourceType.CPU, - resource_id="server-001", - metric_name="test_metric", - metric_value=110.0 + i, - unit="percent", - metadata={"region": "cn-north-1"}, + tenant_id = self.tenant_id, + resource_type = ResourceType.CPU, + resource_id = "server-001", + metric_name = "test_metric", + metric_value = 110.0 + i, + unit = "percent", + metadata = {"region": "cn-north-1"}, ) self.log("Recorded 10 resource metrics") # 手动创建告警 from ops_manager import Alert - alert_id = f"test_alert_{datetime.now().strftime('%Y%m%d%H%M%S')}" - now = datetime.now().isoformat() + alert_id = f"test_alert_{datetime.now().strftime('%Y%m%d%H%M%S')}" + now = datetime.now().isoformat() - alert = Alert( - id=alert_id, - rule_id=rule.id, - tenant_id=self.tenant_id, - severity=AlertSeverity.P1, - status=AlertStatus.FIRING, - title="测试告警", - description="这是一条测试告警", - metric="test_metric", - value=120.0, - threshold=100.0, - labels={"test": "true"}, - annotations={}, - started_at=now, - resolved_at=None, - acknowledged_by=None, - acknowledged_at=None, - notification_sent={}, - suppression_count=0, + alert = Alert( + id = alert_id, + rule_id = rule.id, + tenant_id = self.tenant_id, + severity = AlertSeverity.P1, + status = AlertStatus.FIRING, + title = "测试告警", + description = "这是一条测试告警", + metric = "test_metric", + value = 120.0, + threshold = 100.0, + labels = {"test": "true"}, + annotations = {}, + started_at = now, + resolved_at = None, + acknowledged_by = None, + acknowledged_at = None, + notification_sent = {}, + suppression_count = 0, ) with self.manager._get_db() as conn: @@ -299,20 +299,20 @@ class TestOpsManager: self.log(f"Created test alert: {alert.id}") # 列出租户的告警 - alerts = self.manager.list_alerts(self.tenant_id) + alerts = self.manager.list_alerts(self.tenant_id) assert len(alerts) >= 1 self.log(f"Listed {len(alerts)} alerts for tenant") # 确认告警 self.manager.acknowledge_alert(alert_id, "test_user") - fetched_alert = self.manager.get_alert(alert_id) + fetched_alert = self.manager.get_alert(alert_id) assert fetched_alert.status == AlertStatus.ACKNOWLEDGED assert fetched_alert.acknowledged_by == "test_user" self.log(f"Acknowledged alert: {alert_id}") # 解决告警 self.manager.resolve_alert(alert_id) - fetched_alert = self.manager.get_alert(alert_id) + fetched_alert = self.manager.get_alert(alert_id) assert fetched_alert.status == AlertStatus.RESOLVED assert fetched_alert.resolved_at is not None self.log(f"Resolved alert: {alert_id}") @@ -320,23 +320,23 @@ class TestOpsManager: # 清理 self.manager.delete_alert_rule(rule.id) with self.manager._get_db() as conn: - conn.execute("DELETE FROM alerts WHERE id = ?", (alert_id,)) - conn.execute("DELETE FROM resource_metrics WHERE tenant_id = ?", (self.tenant_id,)) + conn.execute("DELETE FROM alerts WHERE id = ?", (alert_id, )) + conn.execute("DELETE FROM resource_metrics WHERE tenant_id = ?", (self.tenant_id, )) conn.commit() self.log("Cleaned up test data") except Exception as e: - self.log(f"Alerts test failed: {e}", success=False) + self.log(f"Alerts test failed: {e}", success = False) - def test_capacity_planning(self): + def test_capacity_planning(self) -> None: """测试容量规划""" print("\n📊 Testing Capacity Planning...") try: # 记录历史指标数据 - base_time = datetime.now() - timedelta(days=30) + base_time = datetime.now() - timedelta(days = 30) for i in range(30): - timestamp = (base_time + timedelta(days=i)).isoformat() + timestamp = (base_time + timedelta(days = i)).isoformat() with self.manager._get_db() as conn: conn.execute( """ @@ -360,13 +360,13 @@ class TestOpsManager: self.log("Recorded 30 days of historical metrics") # 创建容量规划 - prediction_date = (datetime.now() + timedelta(days=30)).strftime("%Y-%m-%d") - plan = self.manager.create_capacity_plan( - tenant_id=self.tenant_id, - resource_type=ResourceType.CPU, - current_capacity=100.0, - prediction_date=prediction_date, - confidence=0.85, + prediction_date = (datetime.now() + timedelta(days = 30)).strftime("%Y-%m-%d") + plan = self.manager.create_capacity_plan( + tenant_id = self.tenant_id, + resource_type = ResourceType.CPU, + current_capacity = 100.0, + prediction_date = prediction_date, + confidence = 0.85, ) self.log(f"Created capacity plan: {plan.id}") @@ -375,38 +375,38 @@ class TestOpsManager: self.log(f" Recommended action: {plan.recommended_action}") # 获取容量规划列表 - plans = self.manager.get_capacity_plans(self.tenant_id) + plans = self.manager.get_capacity_plans(self.tenant_id) assert len(plans) >= 1 self.log(f"Listed {len(plans)} capacity plans") # 清理 with self.manager._get_db() as conn: - conn.execute("DELETE FROM capacity_plans WHERE tenant_id = ?", (self.tenant_id,)) - conn.execute("DELETE FROM resource_metrics WHERE tenant_id = ?", (self.tenant_id,)) + conn.execute("DELETE FROM capacity_plans WHERE tenant_id = ?", (self.tenant_id, )) + conn.execute("DELETE FROM resource_metrics WHERE tenant_id = ?", (self.tenant_id, )) conn.commit() self.log("Cleaned up capacity planning test data") except Exception as e: - self.log(f"Capacity planning test failed: {e}", success=False) + self.log(f"Capacity planning test failed: {e}", success = False) - def test_auto_scaling(self): + def test_auto_scaling(self) -> None: """测试自动扩缩容""" print("\n⚖️ Testing Auto Scaling...") try: # 创建自动扩缩容策略 - policy = self.manager.create_auto_scaling_policy( - tenant_id=self.tenant_id, - name="API 服务自动扩缩容", - resource_type=ResourceType.CPU, - min_instances=2, - max_instances=10, - target_utilization=0.7, - scale_up_threshold=0.8, - scale_down_threshold=0.3, - scale_up_step=2, - scale_down_step=1, - cooldown_period=300, + policy = self.manager.create_auto_scaling_policy( + tenant_id = self.tenant_id, + name = "API 服务自动扩缩容", + resource_type = ResourceType.CPU, + min_instances = 2, + max_instances = 10, + target_utilization = 0.7, + scale_up_threshold = 0.8, + scale_down_threshold = 0.3, + scale_up_step = 2, + scale_down_step = 1, + cooldown_period = 300, ) self.log(f"Created auto scaling policy: {policy.name} (ID: {policy.id})") @@ -415,13 +415,13 @@ class TestOpsManager: self.log(f" Target utilization: {policy.target_utilization}") # 获取策略列表 - policies = self.manager.list_auto_scaling_policies(self.tenant_id) + policies = self.manager.list_auto_scaling_policies(self.tenant_id) assert len(policies) >= 1 self.log(f"Listed {len(policies)} auto scaling policies") # 模拟扩缩容评估 - event = self.manager.evaluate_scaling_policy( - policy_id=policy.id, current_instances=3, current_utilization=0.85 + event = self.manager.evaluate_scaling_policy( + policy_id = policy.id, current_instances = 3, current_utilization = 0.85 ) if event: @@ -432,62 +432,62 @@ class TestOpsManager: self.log("No scaling action needed") # 获取扩缩容事件列表 - events = self.manager.list_scaling_events(self.tenant_id) + events = self.manager.list_scaling_events(self.tenant_id) self.log(f"Listed {len(events)} scaling events") # 清理 with self.manager._get_db() as conn: - conn.execute("DELETE FROM scaling_events WHERE tenant_id = ?", (self.tenant_id,)) + conn.execute("DELETE FROM scaling_events WHERE tenant_id = ?", (self.tenant_id, )) conn.execute( - "DELETE FROM auto_scaling_policies WHERE tenant_id = ?", (self.tenant_id,) + "DELETE FROM auto_scaling_policies WHERE tenant_id = ?", (self.tenant_id, ) ) conn.commit() self.log("Cleaned up auto scaling test data") except Exception as e: - self.log(f"Auto scaling test failed: {e}", success=False) + self.log(f"Auto scaling test failed: {e}", success = False) - def test_health_checks(self): + def test_health_checks(self) -> None: """测试健康检查""" print("\n💓 Testing Health Checks...") try: # 创建 HTTP 健康检查 - check1 = self.manager.create_health_check( - tenant_id=self.tenant_id, - name="API 服务健康检查", - target_type="service", - target_id="api-service", - check_type="http", - check_config={"url": "https://api.insightflow.io/health", "expected_status": 200}, - interval=60, - timeout=10, - retry_count=3, + check1 = self.manager.create_health_check( + tenant_id = self.tenant_id, + name = "API 服务健康检查", + target_type = "service", + target_id = "api-service", + check_type = "http", + check_config = {"url": "https://api.insightflow.io/health", "expected_status": 200}, + interval = 60, + timeout = 10, + retry_count = 3, ) self.log(f"Created HTTP health check: {check1.name} (ID: {check1.id})") # 创建 TCP 健康检查 - check2 = self.manager.create_health_check( - tenant_id=self.tenant_id, - name="数据库健康检查", - target_type="database", - target_id="postgres-001", - check_type="tcp", - check_config={"host": "db.insightflow.io", "port": 5432}, - interval=30, - timeout=5, - retry_count=2, + check2 = self.manager.create_health_check( + tenant_id = self.tenant_id, + name = "数据库健康检查", + target_type = "database", + target_id = "postgres-001", + check_type = "tcp", + check_config = {"host": "db.insightflow.io", "port": 5432}, + interval = 30, + timeout = 5, + retry_count = 2, ) self.log(f"Created TCP health check: {check2.name} (ID: {check2.id})") # 获取健康检查列表 - checks = self.manager.list_health_checks(self.tenant_id) + checks = self.manager.list_health_checks(self.tenant_id) assert len(checks) >= 2 self.log(f"Listed {len(checks)} health checks") # 执行健康检查(异步) - async def run_health_check(): - result = await self.manager.execute_health_check(check1.id) + async def run_health_check() -> None: + result = await self.manager.execute_health_check(check1.id) return result # 由于健康检查需要网络,这里只验证方法存在 @@ -495,28 +495,28 @@ class TestOpsManager: # 清理 with self.manager._get_db() as conn: - conn.execute("DELETE FROM health_checks WHERE tenant_id = ?", (self.tenant_id,)) + conn.execute("DELETE FROM health_checks WHERE tenant_id = ?", (self.tenant_id, )) conn.commit() self.log("Cleaned up health check test data") except Exception as e: - self.log(f"Health checks test failed: {e}", success=False) + self.log(f"Health checks test failed: {e}", success = False) - def test_failover(self): + def test_failover(self) -> None: """测试故障转移""" print("\n🔄 Testing Failover...") try: # 创建故障转移配置 - config = self.manager.create_failover_config( - tenant_id=self.tenant_id, - name="主备数据中心故障转移", - primary_region="cn-north-1", - secondary_regions=["cn-south-1", "cn-east-1"], - failover_trigger="health_check_failed", - auto_failover=False, - failover_timeout=300, - health_check_id=None, + config = self.manager.create_failover_config( + tenant_id = self.tenant_id, + name = "主备数据中心故障转移", + primary_region = "cn-north-1", + secondary_regions = ["cn-south-1", "cn-east-1"], + failover_trigger = "health_check_failed", + auto_failover = False, + failover_timeout = 300, + health_check_id = None, ) self.log(f"Created failover config: {config.name} (ID: {config.id})") @@ -524,13 +524,13 @@ class TestOpsManager: self.log(f" Secondary regions: {config.secondary_regions}") # 获取故障转移配置列表 - configs = self.manager.list_failover_configs(self.tenant_id) + configs = self.manager.list_failover_configs(self.tenant_id) assert len(configs) >= 1 self.log(f"Listed {len(configs)} failover configs") # 发起故障转移 - event = self.manager.initiate_failover( - config_id=config.id, reason="Primary region health check failed" + event = self.manager.initiate_failover( + config_id = config.id, reason = "Primary region health check failed" ) if event: @@ -540,41 +540,41 @@ class TestOpsManager: # 更新故障转移状态 self.manager.update_failover_status(event.id, "completed") - updated_event = self.manager.get_failover_event(event.id) + updated_event = self.manager.get_failover_event(event.id) assert updated_event.status == "completed" self.log("Failover completed") # 获取故障转移事件列表 - events = self.manager.list_failover_events(self.tenant_id) + events = self.manager.list_failover_events(self.tenant_id) self.log(f"Listed {len(events)} failover events") # 清理 with self.manager._get_db() as conn: - conn.execute("DELETE FROM failover_events WHERE tenant_id = ?", (self.tenant_id,)) - conn.execute("DELETE FROM failover_configs WHERE tenant_id = ?", (self.tenant_id,)) + conn.execute("DELETE FROM failover_events WHERE tenant_id = ?", (self.tenant_id, )) + conn.execute("DELETE FROM failover_configs WHERE tenant_id = ?", (self.tenant_id, )) conn.commit() self.log("Cleaned up failover test data") except Exception as e: - self.log(f"Failover test failed: {e}", success=False) + self.log(f"Failover test failed: {e}", success = False) - def test_backup(self): + def test_backup(self) -> None: """测试备份与恢复""" print("\n💾 Testing Backup & Recovery...") try: # 创建备份任务 - job = self.manager.create_backup_job( - tenant_id=self.tenant_id, - name="每日数据库备份", - backup_type="full", - target_type="database", - target_id="postgres-main", - schedule="0 2 * * *", # 每天凌晨2点 - retention_days=30, - encryption_enabled=True, - compression_enabled=True, - storage_location="s3://insightflow-backups/", + job = self.manager.create_backup_job( + tenant_id = self.tenant_id, + name = "每日数据库备份", + backup_type = "full", + target_type = "database", + target_id = "postgres-main", + schedule = "0 2 * * *", # 每天凌晨2点 + retention_days = 30, + encryption_enabled = True, + compression_enabled = True, + storage_location = "s3://insightflow-backups/", ) self.log(f"Created backup job: {job.name} (ID: {job.id})") @@ -582,12 +582,12 @@ class TestOpsManager: self.log(f" Retention: {job.retention_days} days") # 获取备份任务列表 - jobs = self.manager.list_backup_jobs(self.tenant_id) + jobs = self.manager.list_backup_jobs(self.tenant_id) assert len(jobs) >= 1 self.log(f"Listed {len(jobs)} backup jobs") # 执行备份 - record = self.manager.execute_backup(job.id) + record = self.manager.execute_backup(job.id) if record: self.log(f"Executed backup: {record.id}") @@ -595,50 +595,50 @@ class TestOpsManager: self.log(f" Storage: {record.storage_path}") # 获取备份记录列表 - records = self.manager.list_backup_records(self.tenant_id) + records = self.manager.list_backup_records(self.tenant_id) self.log(f"Listed {len(records)} backup records") # 测试恢复(模拟) - restore_result = self.manager.restore_from_backup(record.id) + restore_result = self.manager.restore_from_backup(record.id) self.log(f"Restore test result: {restore_result}") # 清理 with self.manager._get_db() as conn: - conn.execute("DELETE FROM backup_records WHERE tenant_id = ?", (self.tenant_id,)) - conn.execute("DELETE FROM backup_jobs WHERE tenant_id = ?", (self.tenant_id,)) + conn.execute("DELETE FROM backup_records WHERE tenant_id = ?", (self.tenant_id, )) + conn.execute("DELETE FROM backup_jobs WHERE tenant_id = ?", (self.tenant_id, )) conn.commit() self.log("Cleaned up backup test data") except Exception as e: - self.log(f"Backup test failed: {e}", success=False) + self.log(f"Backup test failed: {e}", success = False) - def test_cost_optimization(self): + def test_cost_optimization(self) -> None: """测试成本优化""" print("\n💰 Testing Cost Optimization...") try: # 记录资源利用率数据 - report_date = datetime.now().strftime("%Y-%m-%d") + report_date = datetime.now().strftime("%Y-%m-%d") for i in range(5): self.manager.record_resource_utilization( - tenant_id=self.tenant_id, - resource_type=ResourceType.CPU, - resource_id=f"server-{i:03d}", - utilization_rate=0.05 + random.random() * 0.1, # 低利用率 - peak_utilization=0.15, - avg_utilization=0.08, - idle_time_percent=0.85, - report_date=report_date, - recommendations=["Consider downsizing this resource"], + tenant_id = self.tenant_id, + resource_type = ResourceType.CPU, + resource_id = f"server-{i:03d}", + utilization_rate = 0.05 + random.random() * 0.1, # 低利用率 + peak_utilization = 0.15, + avg_utilization = 0.08, + idle_time_percent = 0.85, + report_date = report_date, + recommendations = ["Consider downsizing this resource"], ) self.log("Recorded 5 resource utilization records") # 生成成本报告 - now = datetime.now() - report = self.manager.generate_cost_report( - tenant_id=self.tenant_id, year=now.year, month=now.month + now = datetime.now() + report = self.manager.generate_cost_report( + tenant_id = self.tenant_id, year = now.year, month = now.month ) self.log(f"Generated cost report: {report.id}") @@ -647,11 +647,11 @@ class TestOpsManager: self.log(f" Anomalies detected: {len(report.anomalies)}") # 检测闲置资源 - idle_resources = self.manager.detect_idle_resources(self.tenant_id) + idle_resources = self.manager.detect_idle_resources(self.tenant_id) self.log(f"Detected {len(idle_resources)} idle resources") # 获取闲置资源列表 - idle_list = self.manager.get_idle_resources(self.tenant_id) + idle_list = self.manager.get_idle_resources(self.tenant_id) for resource in idle_list: self.log( f" Idle resource: {resource.resource_name} (est. cost: { @@ -660,7 +660,7 @@ class TestOpsManager: ) # 生成成本优化建议 - suggestions = self.manager.generate_cost_optimization_suggestions(self.tenant_id) + suggestions = self.manager.generate_cost_optimization_suggestions(self.tenant_id) self.log(f"Generated {len(suggestions)} cost optimization suggestions") for suggestion in suggestions: @@ -672,12 +672,12 @@ class TestOpsManager: self.log(f" Difficulty: {suggestion.difficulty}") # 获取优化建议列表 - all_suggestions = self.manager.get_cost_optimization_suggestions(self.tenant_id) + all_suggestions = self.manager.get_cost_optimization_suggestions(self.tenant_id) self.log(f"Listed {len(all_suggestions)} optimization suggestions") # 应用优化建议 if all_suggestions: - applied = self.manager.apply_cost_optimization_suggestion(all_suggestions[0].id) + applied = self.manager.apply_cost_optimization_suggestion(all_suggestions[0].id) if applied: self.log(f"Applied optimization suggestion: {applied.title}") assert applied.is_applied @@ -686,29 +686,29 @@ class TestOpsManager: # 清理 with self.manager._get_db() as conn: conn.execute( - "DELETE FROM cost_optimization_suggestions WHERE tenant_id = ?", - (self.tenant_id,), + "DELETE FROM cost_optimization_suggestions WHERE tenant_id = ?", + (self.tenant_id, ), ) - conn.execute("DELETE FROM idle_resources WHERE tenant_id = ?", (self.tenant_id,)) + conn.execute("DELETE FROM idle_resources WHERE tenant_id = ?", (self.tenant_id, )) conn.execute( - "DELETE FROM resource_utilizations WHERE tenant_id = ?", (self.tenant_id,) + "DELETE FROM resource_utilizations WHERE tenant_id = ?", (self.tenant_id, ) ) - conn.execute("DELETE FROM cost_reports WHERE tenant_id = ?", (self.tenant_id,)) + conn.execute("DELETE FROM cost_reports WHERE tenant_id = ?", (self.tenant_id, )) conn.commit() self.log("Cleaned up cost optimization test data") except Exception as e: - self.log(f"Cost optimization test failed: {e}", success=False) + self.log(f"Cost optimization test failed: {e}", success = False) - def print_summary(self): + def print_summary(self) -> None: """打印测试总结""" - print("\n" + "=" * 60) + print("\n" + " = " * 60) print("Test Summary") - print("=" * 60) + print(" = " * 60) - total = len(self.test_results) - passed = sum(1 for _, success in self.test_results if success) - failed = total - passed + total = len(self.test_results) + passed = sum(1 for _, success in self.test_results if success) + failed = total - passed print(f"Total tests: {total}") print(f"Passed: {passed} ✅") @@ -720,12 +720,12 @@ class TestOpsManager: if not success: print(f" ❌ {message}") - print("=" * 60) + print(" = " * 60) -def main(): +def main() -> None: """主函数""" - test = TestOpsManager() + test = TestOpsManager() test.run_all_tests() diff --git a/backend/tingwu_client.py b/backend/tingwu_client.py index c70e9e6..de330f5 100644 --- a/backend/tingwu_client.py +++ b/backend/tingwu_client.py @@ -10,19 +10,19 @@ from typing import Any class TingwuClient: - def __init__(self): - self.access_key = os.getenv("ALI_ACCESS_KEY", "") - self.secret_key = os.getenv("ALI_SECRET_KEY", "") - self.endpoint = "https://tingwu.cn-beijing.aliyuncs.com" + def __init__(self) -> None: + self.access_key = os.getenv("ALI_ACCESS_KEY", "") + self.secret_key = os.getenv("ALI_SECRET_KEY", "") + self.endpoint = "https://tingwu.cn-beijing.aliyuncs.com" if not self.access_key or not self.secret_key: raise ValueError("ALI_ACCESS_KEY and ALI_SECRET_KEY required") def _sign_request( - self, method: str, uri: str, query: str = "", body: str = "" + self, method: str, uri: str, query: str = "", body: str = "" ) -> dict[str, str]: """阿里云签名 V3""" - timestamp = datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%SZ") + timestamp = datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%SZ") # 简化签名,实际生产需要完整实现 # 这里使用基础认证头 @@ -31,10 +31,10 @@ class TingwuClient: "x-acs-action": "CreateTask", "x-acs-version": "2023-09-30", "x-acs-date": timestamp, - "Authorization": f"ACS3-HMAC-SHA256 Credential={self.access_key}/acs/tingwu/cn-beijing", + "Authorization": f"ACS3-HMAC-SHA256 Credential = {self.access_key}/acs/tingwu/cn-beijing", } - def create_task(self, audio_url: str, language: str = "zh") -> str: + def create_task(self, audio_url: str, language: str = "zh") -> str: """创建听悟任务""" try: # 导入移到文件顶部会导致循环导入,保持在这里 @@ -42,23 +42,23 @@ class TingwuClient: from alibabacloud_tingwu20230930 import models as tingwu_models from alibabacloud_tingwu20230930.client import Client as TingwuSDKClient - config = open_api_models.Config( - access_key_id=self.access_key, access_key_secret=self.secret_key + config = open_api_models.Config( + access_key_id = self.access_key, access_key_secret = self.secret_key ) - config.endpoint = "tingwu.cn-beijing.aliyuncs.com" - client = TingwuSDKClient(config) + config.endpoint = "tingwu.cn-beijing.aliyuncs.com" + client = TingwuSDKClient(config) - request = tingwu_models.CreateTaskRequest( - type="offline", - input=tingwu_models.Input(source="OSS", file_url=audio_url), - parameters=tingwu_models.Parameters( - transcription=tingwu_models.Transcription( - diarization_enabled=True, sentence_max_length=20 + request = tingwu_models.CreateTaskRequest( + type = "offline", + input = tingwu_models.Input(source = "OSS", file_url = audio_url), + parameters = tingwu_models.Parameters( + transcription = tingwu_models.Transcription( + diarization_enabled = True, sentence_max_length = 20 ) ), ) - response = client.create_task(request) + response = client.create_task(request) if response.body.code == "0": return response.body.data.task_id else: @@ -73,29 +73,26 @@ class TingwuClient: return f"mock_task_{int(time.time())}" def get_task_result( - self, task_id: str, max_retries: int = 60, interval: int = 5 + self, task_id: str, max_retries: int = 60, interval: int = 5 ) -> dict[str, Any]: """获取任务结果""" try: # 导入移到文件顶部会导致循环导入,保持在这里 - from alibabacloud_tea_openapi import models as open_api_models - from alibabacloud_tingwu20230930 import models as tingwu_models - from alibabacloud_tingwu20230930.client import Client as TingwuSDKClient - config = open_api_models.Config( - access_key_id=self.access_key, access_key_secret=self.secret_key + config = open_api_models.Config( + access_key_id = self.access_key, access_key_secret = self.secret_key ) - config.endpoint = "tingwu.cn-beijing.aliyuncs.com" - client = TingwuSDKClient(config) + config.endpoint = "tingwu.cn-beijing.aliyuncs.com" + client = TingwuSDKClient(config) for i in range(max_retries): - request = tingwu_models.GetTaskInfoRequest() - response = client.get_task_info(task_id, request) + request = tingwu_models.GetTaskInfoRequest() + response = client.get_task_info(task_id, request) if response.body.code != "0": raise RuntimeError(f"Query failed: {response.body.message}") - status = response.body.data.task_status + status = response.body.data.task_status if status == "SUCCESS": return self._parse_result(response.body.data) @@ -116,11 +113,11 @@ class TingwuClient: def _parse_result(self, data) -> dict[str, Any]: """解析结果""" - result = data.result - transcription = result.transcription + result = data.result + transcription = result.transcription - full_text = "" - segments = [] + full_text = "" + segments = [] if transcription.paragraphs: for para in transcription.paragraphs: @@ -153,8 +150,8 @@ class TingwuClient: ], } - def transcribe(self, audio_url: str, language: str = "zh") -> dict[str, Any]: + def transcribe(self, audio_url: str, language: str = "zh") -> dict[str, Any]: """一键转录""" - task_id = self.create_task(audio_url, language) + task_id = self.create_task(audio_url, language) print(f"Tingwu task: {task_id}") return self.get_task_result(task_id) diff --git a/backend/workflow_manager.py b/backend/workflow_manager.py index 5ff70e9..e6f29ce 100644 --- a/backend/workflow_manager.py +++ b/backend/workflow_manager.py @@ -31,52 +31,52 @@ from workflow_manager import WorkflowManager import urllib.parse # Constants -UUID_LENGTH = 8 # UUID 截断长度 -DEFAULT_TIMEOUT = 300 # 默认超时时间(秒) -DEFAULT_RETRY_COUNT = 3 # 默认重试次数 -DEFAULT_RETRY_DELAY = 5 # 默认重试延迟(秒) +UUID_LENGTH = 8 # UUID 截断长度 +DEFAULT_TIMEOUT = 300 # 默认超时时间(秒) +DEFAULT_RETRY_COUNT = 3 # 默认重试次数 +DEFAULT_RETRY_DELAY = 5 # 默认重试延迟(秒) # Configure logging -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) +logging.basicConfig(level = logging.INFO) +logger = logging.getLogger(__name__) class WorkflowStatus(Enum): """工作流状态""" - ACTIVE = "active" - PAUSED = "paused" - ERROR = "error" - COMPLETED = "completed" + ACTIVE = "active" + PAUSED = "paused" + ERROR = "error" + COMPLETED = "completed" class WorkflowType(Enum): """工作流类型""" - AUTO_ANALYZE = "auto_analyze" # 自动分析新文件 - AUTO_ALIGN = "auto_align" # 自动实体对齐 - AUTO_RELATION = "auto_relation" # 自动关系发现 - SCHEDULED_REPORT = "scheduled_report" # 定时报告 - CUSTOM = "custom" # 自定义工作流 + AUTO_ANALYZE = "auto_analyze" # 自动分析新文件 + AUTO_ALIGN = "auto_align" # 自动实体对齐 + AUTO_RELATION = "auto_relation" # 自动关系发现 + SCHEDULED_REPORT = "scheduled_report" # 定时报告 + CUSTOM = "custom" # 自定义工作流 class WebhookType(Enum): """Webhook 类型""" - FEISHU = "feishu" - DINGTALK = "dingtalk" - SLACK = "slack" - CUSTOM = "custom" + FEISHU = "feishu" + DINGTALK = "dingtalk" + SLACK = "slack" + CUSTOM = "custom" class TaskStatus(Enum): """任务执行状态""" - PENDING = "pending" - RUNNING = "running" - SUCCESS = "success" - FAILED = "failed" - CANCELLED = "cancelled" + PENDING = "pending" + RUNNING = "running" + SUCCESS = "success" + FAILED = "failed" + CANCELLED = "cancelled" @dataclass @@ -87,20 +87,20 @@ class WorkflowTask: workflow_id: str name: str task_type: str # analyze, align, discover_relations, notify, custom - config: dict = field(default_factory=dict) - order: int = 0 - depends_on: list[str] = field(default_factory=list) - timeout_seconds: int = 300 - retry_count: int = 3 - retry_delay: int = 5 - created_at: str = "" - updated_at: str = "" + config: dict = field(default_factory = dict) + order: int = 0 + depends_on: list[str] = field(default_factory = list) + timeout_seconds: int = 300 + retry_count: int = 3 + retry_delay: int = 5 + created_at: str = "" + updated_at: str = "" - def __post_init__(self): + def __post_init__(self) -> None: if not self.created_at: - self.created_at = datetime.now().isoformat() + self.created_at = datetime.now().isoformat() if not self.updated_at: - self.updated_at = self.created_at + self.updated_at = self.created_at @dataclass @@ -111,21 +111,21 @@ class WebhookConfig: name: str webhook_type: str # feishu, dingtalk, slack, custom url: str - secret: str = "" # 用于签名验证 - headers: dict = field(default_factory=dict) - template: str = "" # 消息模板 - is_active: bool = True - created_at: str = "" - updated_at: str = "" - last_used_at: str | None = None - success_count: int = 0 - fail_count: int = 0 + secret: str = "" # 用于签名验证 + headers: dict = field(default_factory = dict) + template: str = "" # 消息模板 + is_active: bool = True + created_at: str = "" + updated_at: str = "" + last_used_at: str | None = None + success_count: int = 0 + fail_count: int = 0 - def __post_init__(self): + def __post_init__(self) -> None: if not self.created_at: - self.created_at = datetime.now().isoformat() + self.created_at = datetime.now().isoformat() if not self.updated_at: - self.updated_at = self.created_at + self.updated_at = self.created_at @dataclass @@ -137,25 +137,25 @@ class Workflow: description: str workflow_type: str project_id: str - status: str = "active" - schedule: str | None = None # cron expression or interval - schedule_type: str = "manual" # manual, cron, interval - config: dict = field(default_factory=dict) - webhook_ids: list[str] = field(default_factory=list) - is_active: bool = True - created_at: str = "" - updated_at: str = "" - last_run_at: str | None = None - next_run_at: str | None = None - run_count: int = 0 - success_count: int = 0 - fail_count: int = 0 + status: str = "active" + schedule: str | None = None # cron expression or interval + schedule_type: str = "manual" # manual, cron, interval + config: dict = field(default_factory = dict) + webhook_ids: list[str] = field(default_factory = list) + is_active: bool = True + created_at: str = "" + updated_at: str = "" + last_run_at: str | None = None + next_run_at: str | None = None + run_count: int = 0 + success_count: int = 0 + fail_count: int = 0 - def __post_init__(self): + def __post_init__(self) -> None: if not self.created_at: - self.created_at = datetime.now().isoformat() + self.created_at = datetime.now().isoformat() if not self.updated_at: - self.updated_at = self.created_at + self.updated_at = self.created_at @dataclass @@ -164,31 +164,31 @@ class WorkflowLog: id: str workflow_id: str - task_id: str | None = None - status: str = "pending" # pending, running, success, failed, cancelled - start_time: str | None = None - end_time: str | None = None - duration_ms: int = 0 - input_data: dict = field(default_factory=dict) - output_data: dict = field(default_factory=dict) - error_message: str = "" - created_at: str = "" + task_id: str | None = None + status: str = "pending" # pending, running, success, failed, cancelled + start_time: str | None = None + end_time: str | None = None + duration_ms: int = 0 + input_data: dict = field(default_factory = dict) + output_data: dict = field(default_factory = dict) + error_message: str = "" + created_at: str = "" - def __post_init__(self): + def __post_init__(self) -> None: if not self.created_at: - self.created_at = datetime.now().isoformat() + self.created_at = datetime.now().isoformat() class WebhookNotifier: """Webhook 通知器 - 支持飞书、钉钉、Slack""" - def __init__(self): - self.http_client = httpx.AsyncClient(timeout=30.0) + def __init__(self) -> None: + self.http_client = httpx.AsyncClient(timeout = 30.0) async def send(self, config: WebhookConfig, message: dict) -> bool: """发送 Webhook 通知""" try: - webhook_type = WebhookType(config.webhook_type) + webhook_type = WebhookType(config.webhook_type) if webhook_type == WebhookType.FEISHU: return await self._send_feishu(config, message) @@ -205,20 +205,20 @@ class WebhookNotifier: async def _send_feishu(self, config: WebhookConfig, message: dict) -> bool: """发送飞书通知""" - timestamp = str(int(datetime.now().timestamp())) + timestamp = str(int(datetime.now().timestamp())) # 签名计算 if config.secret: - string_to_sign = f"{timestamp}\n{config.secret}" - hmac_code = hmac.new(string_to_sign.encode("utf-8"), digestmod=hashlib.sha256).digest() - sign = base64.b64encode(hmac_code).decode("utf-8") + string_to_sign = f"{timestamp}\n{config.secret}" + hmac_code = hmac.new(string_to_sign.encode("utf-8"), digestmod = hashlib.sha256).digest() + sign = base64.b64encode(hmac_code).decode("utf-8") else: - sign = "" + sign = "" # 构建消息体 if "content" in message: # 文本消息 - payload = { + payload = { "timestamp": timestamp, "sign": sign, "msg_type": "text", @@ -226,7 +226,7 @@ class WebhookNotifier: } elif "title" in message: # 富文本消息 - payload = { + payload = { "timestamp": timestamp, "sign": sign, "msg_type": "post", @@ -241,47 +241,47 @@ class WebhookNotifier: } else: # 卡片消息 - payload = { + payload = { "timestamp": timestamp, "sign": sign, "msg_type": "interactive", "card": message.get("card", {}), } - headers = {"Content-Type": "application/json", **config.headers} + headers = {"Content-Type": "application/json", **config.headers} - response = await self.http_client.post(config.url, json=payload, headers=headers) + response = await self.http_client.post(config.url, json = payload, headers = headers) response.raise_for_status() - result = response.json() + result = response.json() return result.get("code") == 0 async def _send_dingtalk(self, config: WebhookConfig, message: dict) -> bool: """发送钉钉通知""" - timestamp = str(round(datetime.now().timestamp() * 1000)) + timestamp = str(round(datetime.now().timestamp() * 1000)) # 签名计算 if config.secret: - secret_enc = config.secret.encode("utf-8") - string_to_sign = f"{timestamp}\n{config.secret}" - hmac_code = hmac.new( - secret_enc, string_to_sign.encode("utf-8"), digestmod=hashlib.sha256 + secret_enc = config.secret.encode("utf-8") + string_to_sign = f"{timestamp}\n{config.secret}" + hmac_code = hmac.new( + secret_enc, string_to_sign.encode("utf-8"), digestmod = hashlib.sha256 ).digest() - sign = urllib.parse.quote_plus(base64.b64encode(hmac_code)) - url = f"{config.url}×tamp={timestamp}&sign={sign}" + sign = urllib.parse.quote_plus(base64.b64encode(hmac_code)) + url = f"{config.url}×tamp = {timestamp}&sign = {sign}" else: - url = config.url + url = config.url # 构建消息体 if "content" in message: - payload = {"msgtype": "text", "text": {"content": message["content"]}} + payload = {"msgtype": "text", "text": {"content": message["content"]}} elif "title" in message: - payload = { + payload = { "msgtype": "markdown", "markdown": {"title": message["title"], "text": message.get("markdown", "")}, } elif "link" in message: - payload = { + payload = { "msgtype": "link", "link": { "text": message.get("text", ""), @@ -291,46 +291,46 @@ class WebhookNotifier: }, } else: - payload = {"msgtype": "action_card", "action_card": message.get("action_card", {})} + payload = {"msgtype": "action_card", "action_card": message.get("action_card", {})} - headers = {"Content-Type": "application/json", **config.headers} + headers = {"Content-Type": "application/json", **config.headers} - response = await self.http_client.post(url, json=payload, headers=headers) + response = await self.http_client.post(url, json = payload, headers = headers) response.raise_for_status() - result = response.json() + result = response.json() return result.get("errcode") == 0 async def _send_slack(self, config: WebhookConfig, message: dict) -> bool: """发送 Slack 通知""" # Slack 直接支持标准 webhook 格式 - payload = { + payload = { "text": message.get("content", message.get("text", "")), } if "blocks" in message: - payload["blocks"] = message["blocks"] + payload["blocks"] = message["blocks"] if "attachments" in message: - payload["attachments"] = message["attachments"] + payload["attachments"] = message["attachments"] - headers = {"Content-Type": "application/json", **config.headers} + headers = {"Content-Type": "application/json", **config.headers} - response = await self.http_client.post(config.url, json=payload, headers=headers) + response = await self.http_client.post(config.url, json = payload, headers = headers) response.raise_for_status() return response.text == "ok" async def _send_custom(self, config: WebhookConfig, message: dict) -> bool: """发送自定义 Webhook 通知""" - headers = {"Content-Type": "application/json", **config.headers} + headers = {"Content-Type": "application/json", **config.headers} - response = await self.http_client.post(config.url, json=message, headers=headers) + response = await self.http_client.post(config.url, json = message, headers = headers) response.raise_for_status() return True - async def close(self): + async def close(self) -> None: """关闭 HTTP 客户端""" await self.http_client.aclose() @@ -339,16 +339,16 @@ class WorkflowManager: """工作流管理器 - 核心管理类""" # 默认配置常量 - DEFAULT_TIMEOUT: int = 300 - DEFAULT_RETRY_COUNT: int = 3 - DEFAULT_RETRY_DELAY: int = 5 + DEFAULT_TIMEOUT: int = 300 + DEFAULT_RETRY_COUNT: int = 3 + DEFAULT_RETRY_DELAY: int = 5 - def __init__(self, db_manager=None): - self.db = db_manager - self.scheduler = AsyncIOScheduler() - self.notifier = WebhookNotifier() - self._task_handlers: dict[str, Callable] = {} - self._running_tasks: dict[str, asyncio.Task] = {} + def __init__(self, db_manager = None) -> None: + self.db = db_manager + self.scheduler = AsyncIOScheduler() + self.notifier = WebhookNotifier() + self._task_handlers: dict[str, Callable] = {} + self._running_tasks: dict[str, asyncio.Task] = {} self._setup_default_handlers() # 添加调度器事件监听 @@ -356,7 +356,7 @@ class WorkflowManager: def _setup_default_handlers(self) -> None: """设置默认的任务处理器""" - self._task_handlers = { + self._task_handlers = { "analyze": self._handle_analyze_task, "align": self._handle_align_task, "discover_relations": self._handle_discover_relations_task, @@ -366,7 +366,7 @@ class WorkflowManager: def register_task_handler(self, task_type: str, handler: Callable) -> None: """注册自定义任务处理器""" - self._task_handlers[task_type] = handler + self._task_handlers[task_type] = handler def start(self) -> None: """启动工作流管理器""" @@ -381,13 +381,13 @@ class WorkflowManager: def stop(self) -> None: """停止工作流管理器""" if self.scheduler.running: - self.scheduler.shutdown(wait=True) + self.scheduler.shutdown(wait = True) logger.info("Workflow scheduler stopped") - async def _load_and_schedule_workflows(self): + async def _load_and_schedule_workflows(self) -> None: """从数据库加载并调度所有活跃工作流""" try: - workflows = self.list_workflows(status="active") + workflows = self.list_workflows(status = "active") for workflow in workflows: if workflow.schedule and workflow.is_active: self._schedule_workflow(workflow) @@ -396,7 +396,7 @@ class WorkflowManager: def _schedule_workflow(self, workflow: Workflow) -> None: """调度工作流""" - job_id = f"workflow_{workflow.id}" + job_id = f"workflow_{workflow.id}" # 移除已存在的任务 if self.scheduler.get_job(job_id): @@ -404,29 +404,29 @@ class WorkflowManager: if workflow.schedule_type == "cron": # Cron 表达式调度 - trigger = CronTrigger.from_crontab(workflow.schedule) + trigger = CronTrigger.from_crontab(workflow.schedule) elif workflow.schedule_type == "interval": # 间隔调度 - interval_minutes = int(workflow.schedule) - trigger = IntervalTrigger(minutes=interval_minutes) + interval_minutes = int(workflow.schedule) + trigger = IntervalTrigger(minutes = interval_minutes) else: return self.scheduler.add_job( - func=self._execute_workflow_job, - trigger=trigger, - id=job_id, - args=[workflow.id], - replace_existing=True, - max_instances=1, - coalesce=True, + func = self._execute_workflow_job, + trigger = trigger, + id = job_id, + args = [workflow.id], + replace_existing = True, + max_instances = 1, + coalesce = True, ) logger.info( f"Scheduled workflow {workflow.id} ({workflow.name}) with {workflow.schedule_type}" ) - async def _execute_workflow_job(self, workflow_id: str): + async def _execute_workflow_job(self, workflow_id: str) -> None: """调度器调用的工作流执行函数""" try: await self.execute_workflow(workflow_id) @@ -444,7 +444,7 @@ class WorkflowManager: def create_workflow(self, workflow: Workflow) -> Workflow: """创建工作流""" - conn = self.db.get_conn() + conn = self.db.get_conn() try: conn.execute( """INSERT INTO workflows @@ -486,9 +486,9 @@ class WorkflowManager: def get_workflow(self, workflow_id: str) -> Workflow | None: """获取工作流""" - conn = self.db.get_conn() + conn = self.db.get_conn() try: - row = conn.execute("SELECT * FROM workflows WHERE id = ?", (workflow_id,)).fetchone() + row = conn.execute("SELECT * FROM workflows WHERE id = ?", (workflow_id, )).fetchone() if not row: return None @@ -498,27 +498,27 @@ class WorkflowManager: conn.close() def list_workflows( - self, project_id: str = None, status: str = None, workflow_type: str = None + self, project_id: str = None, status: str = None, workflow_type: str = None ) -> list[Workflow]: """列出工作流""" - conn = self.db.get_conn() + conn = self.db.get_conn() try: - conditions = [] - params = [] + conditions = [] + params = [] if project_id: - conditions.append("project_id = ?") + conditions.append("project_id = ?") params.append(project_id) if status: - conditions.append("status = ?") + conditions.append("status = ?") params.append(status) if workflow_type: - conditions.append("workflow_type = ?") + conditions.append("workflow_type = ?") params.append(workflow_type) - where_clause = " AND ".join(conditions) if conditions else "1=1" + where_clause = " AND ".join(conditions) if conditions else "1 = 1" - rows = conn.execute( + rows = conn.execute( f"SELECT * FROM workflows WHERE {where_clause} ORDER BY created_at DESC", params ).fetchall() @@ -528,9 +528,9 @@ class WorkflowManager: def update_workflow(self, workflow_id: str, **kwargs) -> Workflow | None: """更新工作流""" - conn = self.db.get_conn() + conn = self.db.get_conn() try: - allowed_fields = [ + allowed_fields = [ "name", "description", "status", @@ -540,12 +540,12 @@ class WorkflowManager: "config", "webhook_ids", ] - updates = [] - values = [] + updates = [] + values = [] for f in allowed_fields: if f in kwargs: - updates.append(f"{f} = ?") + updates.append(f"{f} = ?") if f in ["config", "webhook_ids"]: values.append(json.dumps(kwargs[f])) else: @@ -554,20 +554,20 @@ class WorkflowManager: if not updates: return self.get_workflow(workflow_id) - updates.append("updated_at = ?") + updates.append("updated_at = ?") values.append(datetime.now().isoformat()) values.append(workflow_id) - query = f"UPDATE workflows SET {', '.join(updates)} WHERE id = ?" + query = f"UPDATE workflows SET {', '.join(updates)} WHERE id = ?" conn.execute(query, values) conn.commit() # 重新调度 - workflow = self.get_workflow(workflow_id) + workflow = self.get_workflow(workflow_id) if workflow and workflow.schedule and workflow.is_active: self._schedule_workflow(workflow) elif workflow and not workflow.is_active: - job_id = f"workflow_{workflow_id}" + job_id = f"workflow_{workflow_id}" if self.scheduler.get_job(job_id): self.scheduler.remove_job(job_id) @@ -577,18 +577,18 @@ class WorkflowManager: def delete_workflow(self, workflow_id: str) -> bool: """删除工作流""" - conn = self.db.get_conn() + conn = self.db.get_conn() try: # 移除调度 - job_id = f"workflow_{workflow_id}" + job_id = f"workflow_{workflow_id}" if self.scheduler.get_job(job_id): self.scheduler.remove_job(job_id) # 删除相关任务 - conn.execute("DELETE FROM workflow_tasks WHERE workflow_id = ?", (workflow_id,)) + conn.execute("DELETE FROM workflow_tasks WHERE workflow_id = ?", (workflow_id, )) # 删除工作流 - conn.execute("DELETE FROM workflows WHERE id = ?", (workflow_id,)) + conn.execute("DELETE FROM workflows WHERE id = ?", (workflow_id, )) conn.commit() return True @@ -598,31 +598,31 @@ class WorkflowManager: def _row_to_workflow(self, row) -> Workflow: """将数据库行转换为 Workflow 对象""" return Workflow( - id=row["id"], - name=row["name"], - description=row["description"] or "", - workflow_type=row["workflow_type"], - project_id=row["project_id"], - status=row["status"], - schedule=row["schedule"], - schedule_type=row["schedule_type"], - config=json.loads(row["config"]) if row["config"] else {}, - webhook_ids=json.loads(row["webhook_ids"]) if row["webhook_ids"] else [], - is_active=bool(row["is_active"]), - created_at=row["created_at"], - updated_at=row["updated_at"], - last_run_at=row["last_run_at"], - next_run_at=row["next_run_at"], - run_count=row["run_count"] or 0, - success_count=row["success_count"] or 0, - fail_count=row["fail_count"] or 0, + id = row["id"], + name = row["name"], + description = row["description"] or "", + workflow_type = row["workflow_type"], + project_id = row["project_id"], + status = row["status"], + schedule = row["schedule"], + schedule_type = row["schedule_type"], + config = json.loads(row["config"]) if row["config"] else {}, + webhook_ids = json.loads(row["webhook_ids"]) if row["webhook_ids"] else [], + is_active = bool(row["is_active"]), + created_at = row["created_at"], + updated_at = row["updated_at"], + last_run_at = row["last_run_at"], + next_run_at = row["next_run_at"], + run_count = row["run_count"] or 0, + success_count = row["success_count"] or 0, + fail_count = row["fail_count"] or 0, ) # ==================== Workflow Task CRUD ==================== def create_task(self, task: WorkflowTask) -> WorkflowTask: """创建工作流任务""" - conn = self.db.get_conn() + conn = self.db.get_conn() try: conn.execute( """INSERT INTO workflow_tasks @@ -652,9 +652,9 @@ class WorkflowManager: def get_task(self, task_id: str) -> WorkflowTask | None: """获取任务""" - conn = self.db.get_conn() + conn = self.db.get_conn() try: - row = conn.execute("SELECT * FROM workflow_tasks WHERE id = ?", (task_id,)).fetchone() + row = conn.execute("SELECT * FROM workflow_tasks WHERE id = ?", (task_id, )).fetchone() if not row: return None @@ -665,11 +665,11 @@ class WorkflowManager: def list_tasks(self, workflow_id: str) -> list[WorkflowTask]: """列出工作流的所有任务""" - conn = self.db.get_conn() + conn = self.db.get_conn() try: - rows = conn.execute( - "SELECT * FROM workflow_tasks WHERE workflow_id = ? ORDER BY task_order", - (workflow_id,), + rows = conn.execute( + "SELECT * FROM workflow_tasks WHERE workflow_id = ? ORDER BY task_order", + (workflow_id, ), ).fetchall() return [self._row_to_task(row) for row in rows] @@ -678,9 +678,9 @@ class WorkflowManager: def update_task(self, task_id: str, **kwargs) -> WorkflowTask | None: """更新任务""" - conn = self.db.get_conn() + conn = self.db.get_conn() try: - allowed_fields = [ + allowed_fields = [ "name", "task_type", "config", @@ -690,12 +690,12 @@ class WorkflowManager: "retry_count", "retry_delay", ] - updates = [] - values = [] + updates = [] + values = [] for f in allowed_fields: if f in kwargs: - updates.append(f"{f} = ?") + updates.append(f"{f} = ?") if f in ["config", "depends_on"]: values.append(json.dumps(kwargs[f])) else: @@ -704,11 +704,11 @@ class WorkflowManager: if not updates: return self.get_task(task_id) - updates.append("updated_at = ?") + updates.append("updated_at = ?") values.append(datetime.now().isoformat()) values.append(task_id) - query = f"UPDATE workflow_tasks SET {', '.join(updates)} WHERE id = ?" + query = f"UPDATE workflow_tasks SET {', '.join(updates)} WHERE id = ?" conn.execute(query, values) conn.commit() @@ -718,9 +718,9 @@ class WorkflowManager: def delete_task(self, task_id: str) -> bool: """删除任务""" - conn = self.db.get_conn() + conn = self.db.get_conn() try: - conn.execute("DELETE FROM workflow_tasks WHERE id = ?", (task_id,)) + conn.execute("DELETE FROM workflow_tasks WHERE id = ?", (task_id, )) conn.commit() return True finally: @@ -729,25 +729,25 @@ class WorkflowManager: def _row_to_task(self, row) -> WorkflowTask: """将数据库行转换为 WorkflowTask 对象""" return WorkflowTask( - id=row["id"], - workflow_id=row["workflow_id"], - name=row["name"], - task_type=row["task_type"], - config=json.loads(row["config"]) if row["config"] else {}, - order=row["task_order"] or 0, - depends_on=json.loads(row["depends_on"]) if row["depends_on"] else [], - timeout_seconds=row["timeout_seconds"] or 300, - retry_count=row["retry_count"] or 3, - retry_delay=row["retry_delay"] or 5, - created_at=row["created_at"], - updated_at=row["updated_at"], + id = row["id"], + workflow_id = row["workflow_id"], + name = row["name"], + task_type = row["task_type"], + config = json.loads(row["config"]) if row["config"] else {}, + order = row["task_order"] or 0, + depends_on = json.loads(row["depends_on"]) if row["depends_on"] else [], + timeout_seconds = row["timeout_seconds"] or 300, + retry_count = row["retry_count"] or 3, + retry_delay = row["retry_delay"] or 5, + created_at = row["created_at"], + updated_at = row["updated_at"], ) # ==================== Webhook Config CRUD ==================== def create_webhook(self, webhook: WebhookConfig) -> WebhookConfig: """创建 Webhook 配置""" - conn = self.db.get_conn() + conn = self.db.get_conn() try: conn.execute( """INSERT INTO webhook_configs @@ -778,10 +778,10 @@ class WorkflowManager: def get_webhook(self, webhook_id: str) -> WebhookConfig | None: """获取 Webhook 配置""" - conn = self.db.get_conn() + conn = self.db.get_conn() try: - row = conn.execute( - "SELECT * FROM webhook_configs WHERE id = ?", (webhook_id,) + row = conn.execute( + "SELECT * FROM webhook_configs WHERE id = ?", (webhook_id, ) ).fetchone() if not row: @@ -793,9 +793,9 @@ class WorkflowManager: def list_webhooks(self) -> list[WebhookConfig]: """列出所有 Webhook 配置""" - conn = self.db.get_conn() + conn = self.db.get_conn() try: - rows = conn.execute("SELECT * FROM webhook_configs ORDER BY created_at DESC").fetchall() + rows = conn.execute("SELECT * FROM webhook_configs ORDER BY created_at DESC").fetchall() return [self._row_to_webhook(row) for row in rows] finally: @@ -803,9 +803,9 @@ class WorkflowManager: def update_webhook(self, webhook_id: str, **kwargs) -> WebhookConfig | None: """更新 Webhook 配置""" - conn = self.db.get_conn() + conn = self.db.get_conn() try: - allowed_fields = [ + allowed_fields = [ "name", "webhook_type", "url", @@ -814,12 +814,12 @@ class WorkflowManager: "template", "is_active", ] - updates = [] - values = [] + updates = [] + values = [] for f in allowed_fields: if f in kwargs: - updates.append(f"{f} = ?") + updates.append(f"{f} = ?") if f == "headers": values.append(json.dumps(kwargs[f])) else: @@ -828,11 +828,11 @@ class WorkflowManager: if not updates: return self.get_webhook(webhook_id) - updates.append("updated_at = ?") + updates.append("updated_at = ?") values.append(datetime.now().isoformat()) values.append(webhook_id) - query = f"UPDATE webhook_configs SET {', '.join(updates)} WHERE id = ?" + query = f"UPDATE webhook_configs SET {', '.join(updates)} WHERE id = ?" conn.execute(query, values) conn.commit() @@ -842,9 +842,9 @@ class WorkflowManager: def delete_webhook(self, webhook_id: str) -> bool: """删除 Webhook 配置""" - conn = self.db.get_conn() + conn = self.db.get_conn() try: - conn.execute("DELETE FROM webhook_configs WHERE id = ?", (webhook_id,)) + conn.execute("DELETE FROM webhook_configs WHERE id = ?", (webhook_id, )) conn.commit() return True finally: @@ -852,20 +852,20 @@ class WorkflowManager: def update_webhook_stats(self, webhook_id: str, success: bool) -> None: """更新 Webhook 统计""" - conn = self.db.get_conn() + conn = self.db.get_conn() try: if success: conn.execute( """UPDATE webhook_configs - SET success_count = success_count + 1, last_used_at = ? - WHERE id = ?""", + SET success_count = success_count + 1, last_used_at = ? + WHERE id = ?""", (datetime.now().isoformat(), webhook_id), ) else: conn.execute( """UPDATE webhook_configs - SET fail_count = fail_count + 1, last_used_at = ? - WHERE id = ?""", + SET fail_count = fail_count + 1, last_used_at = ? + WHERE id = ?""", (datetime.now().isoformat(), webhook_id), ) conn.commit() @@ -875,26 +875,26 @@ class WorkflowManager: def _row_to_webhook(self, row) -> WebhookConfig: """将数据库行转换为 WebhookConfig 对象""" return WebhookConfig( - id=row["id"], - name=row["name"], - webhook_type=row["webhook_type"], - url=row["url"], - secret=row["secret"] or "", - headers=json.loads(row["headers"]) if row["headers"] else {}, - template=row["template"] or "", - is_active=bool(row["is_active"]), - created_at=row["created_at"], - updated_at=row["updated_at"], - last_used_at=row["last_used_at"], - success_count=row["success_count"] or 0, - fail_count=row["fail_count"] or 0, + id = row["id"], + name = row["name"], + webhook_type = row["webhook_type"], + url = row["url"], + secret = row["secret"] or "", + headers = json.loads(row["headers"]) if row["headers"] else {}, + template = row["template"] or "", + is_active = bool(row["is_active"]), + created_at = row["created_at"], + updated_at = row["updated_at"], + last_used_at = row["last_used_at"], + success_count = row["success_count"] or 0, + fail_count = row["fail_count"] or 0, ) # ==================== Workflow Log ==================== def create_log(self, log: WorkflowLog) -> WorkflowLog: """创建工作流日志""" - conn = self.db.get_conn() + conn = self.db.get_conn() try: conn.execute( """INSERT INTO workflow_logs @@ -922,15 +922,15 @@ class WorkflowManager: def update_log(self, log_id: str, **kwargs) -> WorkflowLog | None: """更新工作流日志""" - conn = self.db.get_conn() + conn = self.db.get_conn() try: - allowed_fields = ["status", "end_time", "duration_ms", "output_data", "error_message"] - updates = [] - values = [] + allowed_fields = ["status", "end_time", "duration_ms", "output_data", "error_message"] + updates = [] + values = [] for f in allowed_fields: if f in kwargs: - updates.append(f"{f} = ?") + updates.append(f"{f} = ?") if f == "output_data": values.append(json.dumps(kwargs[f])) else: @@ -940,7 +940,7 @@ class WorkflowManager: return None values.append(log_id) - query = f"UPDATE workflow_logs SET {', '.join(updates)} WHERE id = ?" + query = f"UPDATE workflow_logs SET {', '.join(updates)} WHERE id = ?" conn.execute(query, values) conn.commit() @@ -950,9 +950,9 @@ class WorkflowManager: def get_log(self, log_id: str) -> WorkflowLog | None: """获取日志""" - conn = self.db.get_conn() + conn = self.db.get_conn() try: - row = conn.execute("SELECT * FROM workflow_logs WHERE id = ?", (log_id,)).fetchone() + row = conn.execute("SELECT * FROM workflow_logs WHERE id = ?", (log_id, )).fetchone() if not row: return None @@ -963,31 +963,31 @@ class WorkflowManager: def list_logs( self, - workflow_id: str = None, - task_id: str = None, - status: str = None, - limit: int = 100, - offset: int = 0, + workflow_id: str = None, + task_id: str = None, + status: str = None, + limit: int = 100, + offset: int = 0, ) -> list[WorkflowLog]: """列出工作流日志""" - conn = self.db.get_conn() + conn = self.db.get_conn() try: - conditions = [] - params = [] + conditions = [] + params = [] if workflow_id: - conditions.append("workflow_id = ?") + conditions.append("workflow_id = ?") params.append(workflow_id) if task_id: - conditions.append("task_id = ?") + conditions.append("task_id = ?") params.append(task_id) if status: - conditions.append("status = ?") + conditions.append("status = ?") params.append(status) - where_clause = " AND ".join(conditions) if conditions else "1=1" + where_clause = " AND ".join(conditions) if conditions else "1 = 1" - rows = conn.execute( + rows = conn.execute( f"""SELECT * FROM workflow_logs WHERE {where_clause} ORDER BY created_at DESC @@ -999,46 +999,46 @@ class WorkflowManager: finally: conn.close() - def get_workflow_stats(self, workflow_id: str, days: int = 30) -> dict: + def get_workflow_stats(self, workflow_id: str, days: int = 30) -> dict: """获取工作流统计""" - conn = self.db.get_conn() + conn = self.db.get_conn() try: - since = (datetime.now() - timedelta(days=days)).isoformat() + since = (datetime.now() - timedelta(days = days)).isoformat() # 总执行次数 - total = conn.execute( - "SELECT COUNT(*) FROM workflow_logs WHERE workflow_id = ? AND created_at > ?", + total = conn.execute( + "SELECT COUNT(*) FROM workflow_logs WHERE workflow_id = ? AND created_at > ?", (workflow_id, since), ).fetchone()[0] # 成功次数 - success = conn.execute( - "SELECT COUNT(*) FROM workflow_logs WHERE workflow_id = ? AND status = 'success' AND created_at > ?", + success = conn.execute( + "SELECT COUNT(*) FROM workflow_logs WHERE workflow_id = ? AND status = 'success' AND created_at > ?", (workflow_id, since), ).fetchone()[0] # 失败次数 - failed = conn.execute( - "SELECT COUNT(*) FROM workflow_logs WHERE workflow_id = ? AND status = 'failed' AND created_at > ?", + failed = conn.execute( + "SELECT COUNT(*) FROM workflow_logs WHERE workflow_id = ? AND status = 'failed' AND created_at > ?", (workflow_id, since), ).fetchone()[0] # 平均执行时间 - avg_duration = ( + avg_duration = ( conn.execute( - "SELECT AVG(duration_ms) FROM workflow_logs WHERE workflow_id = ? AND created_at > ?", + "SELECT AVG(duration_ms) FROM workflow_logs WHERE workflow_id = ? AND created_at > ?", (workflow_id, since), ).fetchone()[0] or 0 ) # 每日统计 - daily = conn.execute( + daily = conn.execute( """SELECT DATE(created_at) as date, COUNT(*) as count, - SUM(CASE WHEN status = 'success' THEN 1 ELSE 0 END) as success + SUM(CASE WHEN status = 'success' THEN 1 ELSE 0 END) as success FROM workflow_logs - WHERE workflow_id = ? AND created_at > ? + WHERE workflow_id = ? AND created_at > ? GROUP BY DATE(created_at) ORDER BY date""", (workflow_id, since), @@ -1060,24 +1060,24 @@ class WorkflowManager: def _row_to_log(self, row) -> WorkflowLog: """将数据库行转换为 WorkflowLog 对象""" return WorkflowLog( - id=row["id"], - workflow_id=row["workflow_id"], - task_id=row["task_id"], - status=row["status"], - start_time=row["start_time"], - end_time=row["end_time"], - duration_ms=row["duration_ms"] or 0, - input_data=json.loads(row["input_data"]) if row["input_data"] else {}, - output_data=json.loads(row["output_data"]) if row["output_data"] else {}, - error_message=row["error_message"] or "", - created_at=row["created_at"], + id = row["id"], + workflow_id = row["workflow_id"], + task_id = row["task_id"], + status = row["status"], + start_time = row["start_time"], + end_time = row["end_time"], + duration_ms = row["duration_ms"] or 0, + input_data = json.loads(row["input_data"]) if row["input_data"] else {}, + output_data = json.loads(row["output_data"]) if row["output_data"] else {}, + error_message = row["error_message"] or "", + created_at = row["created_at"], ) # ==================== Workflow Execution ==================== - async def execute_workflow(self, workflow_id: str, input_data: dict = None) -> dict: + async def execute_workflow(self, workflow_id: str, input_data: dict = None) -> dict: """执行工作流""" - workflow = self.get_workflow(workflow_id) + workflow = self.get_workflow(workflow_id) if not workflow: raise ValueError(f"Workflow {workflow_id} not found") @@ -1085,49 +1085,49 @@ class WorkflowManager: raise ValueError(f"Workflow {workflow_id} is not active") # 更新最后运行时间 - now = datetime.now().isoformat() - self.update_workflow(workflow_id, last_run_at=now, run_count=workflow.run_count + 1) + now = datetime.now().isoformat() + self.update_workflow(workflow_id, last_run_at = now, run_count = workflow.run_count + 1) # 创建工作流执行日志 - log = WorkflowLog( - id=str(uuid.uuid4())[:UUID_LENGTH], - workflow_id=workflow_id, - status=TaskStatus.RUNNING.value, - start_time=now, - input_data=input_data or {}, + log = WorkflowLog( + id = str(uuid.uuid4())[:UUID_LENGTH], + workflow_id = workflow_id, + status = TaskStatus.RUNNING.value, + start_time = now, + input_data = input_data or {}, ) self.create_log(log) - start_time = datetime.now() - results = {} + start_time = datetime.now() + results = {} try: # 获取所有任务 - tasks = self.list_tasks(workflow_id) + tasks = self.list_tasks(workflow_id) if not tasks: # 没有任务时执行默认行为 - results = await self._execute_default_workflow(workflow, input_data) + results = await self._execute_default_workflow(workflow, input_data) else: # 按依赖顺序执行任务 - results = await self._execute_tasks_with_deps(tasks, input_data, log.id) + results = await self._execute_tasks_with_deps(tasks, input_data, log.id) # 发送通知 - await self._send_workflow_notification(workflow, results, success=True) + await self._send_workflow_notification(workflow, results, success = True) # 更新日志为成功 - end_time = datetime.now() - duration = int((end_time - start_time).total_seconds() * 1000) + end_time = datetime.now() + duration = int((end_time - start_time).total_seconds() * 1000) self.update_log( log.id, - status=TaskStatus.SUCCESS.value, - end_time=end_time.isoformat(), - duration_ms=duration, - output_data=results, + status = TaskStatus.SUCCESS.value, + end_time = end_time.isoformat(), + duration_ms = duration, + output_data = results, ) # 更新成功计数 - self.update_workflow(workflow_id, success_count=workflow.success_count + 1) + self.update_workflow(workflow_id, success_count = workflow.success_count + 1) return { "success": True, @@ -1141,21 +1141,21 @@ class WorkflowManager: logger.error(f"Workflow {workflow_id} execution failed: {e}") # 更新日志为失败 - end_time = datetime.now() - duration = int((end_time - start_time).total_seconds() * 1000) + end_time = datetime.now() + duration = int((end_time - start_time).total_seconds() * 1000) self.update_log( log.id, - status=TaskStatus.FAILED.value, - end_time=end_time.isoformat(), - duration_ms=duration, - error_message=str(e), + status = TaskStatus.FAILED.value, + end_time = end_time.isoformat(), + duration_ms = duration, + error_message = str(e), ) # 更新失败计数 - self.update_workflow(workflow_id, fail_count=workflow.fail_count + 1) + self.update_workflow(workflow_id, fail_count = workflow.fail_count + 1) # 发送失败通知 - await self._send_workflow_notification(workflow, {"error": str(e)}, success=False) + await self._send_workflow_notification(workflow, {"error": str(e)}, success = False) raise @@ -1163,12 +1163,12 @@ class WorkflowManager: self, tasks: list[WorkflowTask], input_data: dict, log_id: str ) -> dict: """按依赖顺序执行任务""" - results = {} - completed_tasks = set() + results = {} + completed_tasks = set() while len(completed_tasks) < len(tasks): # 找到可以执行的任务(依赖已完成) - ready_tasks = [ + ready_tasks = [ t for t in tasks if t.id not in completed_tasks @@ -1180,12 +1180,12 @@ class WorkflowManager: raise ValueError("Circular dependency detected or tasks cannot be resolved") # 并行执行就绪的任务 - task_coros = [] + task_coros = [] for task in ready_tasks: - task_input = {**input_data, **results} + task_input = {**input_data, **results} task_coros.append(self._execute_single_task(task, task_input, log_id)) - task_results = await asyncio.gather(*task_coros, return_exceptions=True) + task_results = await asyncio.gather(*task_coros, return_exceptions = True) for task, result in zip(ready_tasks, task_results): if isinstance(result, Exception): @@ -1195,7 +1195,7 @@ class WorkflowManager: for attempt in range(task.retry_count): await asyncio.sleep(task.retry_delay) try: - result = await self._execute_single_task(task, task_input, log_id) + result = await self._execute_single_task(task, task_input, log_id) break except (TimeoutError, httpx.HTTPError) as e: logger.error(f"Task {task.id} retry {attempt + 1} failed: {e}") @@ -1204,38 +1204,38 @@ class WorkflowManager: else: raise result - results[task.name] = result + results[task.name] = result completed_tasks.add(task.id) return results async def _execute_single_task(self, task: WorkflowTask, input_data: dict, log_id: str) -> Any: """执行单个任务""" - handler = self._task_handlers.get(task.task_type) + handler = self._task_handlers.get(task.task_type) if not handler: raise ValueError(f"No handler for task type: {task.task_type}") # 创建任务日志 - task_log = WorkflowLog( - id=str(uuid.uuid4())[:UUID_LENGTH], - workflow_id=task.workflow_id, - task_id=task.id, - status=TaskStatus.RUNNING.value, - start_time=datetime.now().isoformat(), - input_data=input_data, + task_log = WorkflowLog( + id = str(uuid.uuid4())[:UUID_LENGTH], + workflow_id = task.workflow_id, + task_id = task.id, + status = TaskStatus.RUNNING.value, + start_time = datetime.now().isoformat(), + input_data = input_data, ) self.create_log(task_log) try: # 设置超时 - result = await asyncio.wait_for(handler(task, input_data), timeout=task.timeout_seconds) + result = await asyncio.wait_for(handler(task, input_data), timeout = task.timeout_seconds) # 更新任务日志为成功 self.update_log( task_log.id, - status=TaskStatus.SUCCESS.value, - end_time=datetime.now().isoformat(), - output_data={"result": result} if not isinstance(result, dict) else result, + status = TaskStatus.SUCCESS.value, + end_time = datetime.now().isoformat(), + output_data = {"result": result} if not isinstance(result, dict) else result, ) return result @@ -1243,24 +1243,24 @@ class WorkflowManager: except TimeoutError: self.update_log( task_log.id, - status=TaskStatus.FAILED.value, - end_time=datetime.now().isoformat(), - error_message="Task timeout", + status = TaskStatus.FAILED.value, + end_time = datetime.now().isoformat(), + error_message = "Task timeout", ) raise TimeoutError(f"Task {task.id} timed out after {task.timeout_seconds}s") except Exception as e: self.update_log( task_log.id, - status=TaskStatus.FAILED.value, - end_time=datetime.now().isoformat(), - error_message=str(e), + status = TaskStatus.FAILED.value, + end_time = datetime.now().isoformat(), + error_message = str(e), ) raise async def _execute_default_workflow(self, workflow: Workflow, input_data: dict) -> dict: """执行默认工作流(根据类型)""" - workflow_type = WorkflowType(workflow.workflow_type) + workflow_type = WorkflowType(workflow.workflow_type) if workflow_type == WorkflowType.AUTO_ANALYZE: return await self._auto_analyze_files(workflow, input_data) @@ -1277,8 +1277,8 @@ class WorkflowManager: async def _handle_analyze_task(self, task: WorkflowTask, input_data: dict) -> dict: """处理分析任务""" - project_id = input_data.get("project_id") - file_ids = input_data.get("file_ids", []) + project_id = input_data.get("project_id") + file_ids = input_data.get("file_ids", []) if not project_id: raise ValueError("project_id required for analyze task") @@ -1294,8 +1294,8 @@ class WorkflowManager: async def _handle_align_task(self, task: WorkflowTask, input_data: dict) -> dict: """处理实体对齐任务""" - project_id = input_data.get("project_id") - threshold = task.config.get("threshold", 0.85) + project_id = input_data.get("project_id") + threshold = task.config.get("threshold", 0.85) if not project_id: raise ValueError("project_id required for align task") @@ -1311,7 +1311,7 @@ class WorkflowManager: async def _handle_discover_relations_task(self, task: WorkflowTask, input_data: dict) -> dict: """处理关系发现任务""" - project_id = input_data.get("project_id") + project_id = input_data.get("project_id") if not project_id: raise ValueError("project_id required for discover_relations task") @@ -1326,24 +1326,24 @@ class WorkflowManager: async def _handle_notify_task(self, task: WorkflowTask, input_data: dict) -> dict: """处理通知任务""" - webhook_id = task.config.get("webhook_id") - message = task.config.get("message", {}) + webhook_id = task.config.get("webhook_id") + message = task.config.get("message", {}) if not webhook_id: raise ValueError("webhook_id required for notify task") - webhook = self.get_webhook(webhook_id) + webhook = self.get_webhook(webhook_id) if not webhook: raise ValueError(f"Webhook {webhook_id} not found") # 替换模板变量 if webhook.template: try: - message = json.loads(webhook.template.format(**input_data)) + message = json.loads(webhook.template.format(**input_data)) except (json.JSONDecodeError, KeyError, ValueError): pass - success = await self.notifier.send(webhook, message) + success = await self.notifier.send(webhook, message) self.update_webhook_stats(webhook_id, success) return {"task": "notify", "webhook_id": webhook_id, "success": success} @@ -1362,7 +1362,7 @@ class WorkflowManager: async def _auto_analyze_files(self, workflow: Workflow, input_data: dict) -> dict: """自动分析新上传的文件""" - project_id = workflow.project_id + project_id = workflow.project_id # 获取未分析的文件(实际实现需要查询数据库) # 这里是一个示例实现 @@ -1377,8 +1377,8 @@ class WorkflowManager: async def _auto_align_entities(self, workflow: Workflow, input_data: dict) -> dict: """自动实体对齐""" - project_id = workflow.project_id - threshold = workflow.config.get("threshold", 0.85) + project_id = workflow.project_id + threshold = workflow.config.get("threshold", 0.85) return { "workflow_type": "auto_align", @@ -1390,7 +1390,7 @@ class WorkflowManager: async def _auto_discover_relations(self, workflow: Workflow, input_data: dict) -> dict: """自动关系发现""" - project_id = workflow.project_id + project_id = workflow.project_id return { "workflow_type": "auto_relation", @@ -1401,8 +1401,8 @@ class WorkflowManager: async def _generate_scheduled_report(self, workflow: Workflow, input_data: dict) -> dict: """生成定时报告""" - project_id = workflow.project_id - report_type = workflow.config.get("report_type", "summary") + project_id = workflow.project_id + report_type = workflow.config.get("report_type", "summary") return { "workflow_type": "scheduled_report", @@ -1414,26 +1414,26 @@ class WorkflowManager: # ==================== Notification ==================== async def _send_workflow_notification( - self, workflow: Workflow, results: dict, success: bool = True - ): + self, workflow: Workflow, results: dict, success: bool = True + ) -> None: """发送工作流执行通知""" if not workflow.webhook_ids: return for webhook_id in workflow.webhook_ids: - webhook = self.get_webhook(webhook_id) + webhook = self.get_webhook(webhook_id) if not webhook or not webhook.is_active: continue # 构建通知消息 if webhook.webhook_type == WebhookType.FEISHU.value: - message = self._build_feishu_message(workflow, results, success) + message = self._build_feishu_message(workflow, results, success) elif webhook.webhook_type == WebhookType.DINGTALK.value: - message = self._build_dingtalk_message(workflow, results, success) + message = self._build_dingtalk_message(workflow, results, success) elif webhook.webhook_type == WebhookType.SLACK.value: - message = self._build_slack_message(workflow, results, success) + message = self._build_slack_message(workflow, results, success) else: - message = { + message = { "workflow_id": workflow.id, "workflow_name": workflow.name, "status": "success" if success else "failed", @@ -1442,14 +1442,14 @@ class WorkflowManager: } try: - result = await self.notifier.send(webhook, message) + result = await self.notifier.send(webhook, message) self.update_webhook_stats(webhook_id, result) except (TimeoutError, httpx.HTTPError) as e: logger.error(f"Failed to send notification to {webhook_id}: {e}") def _build_feishu_message(self, workflow: Workflow, results: dict, success: bool) -> dict: """构建飞书消息""" - status_text = "✅ 成功" if success else "❌ 失败" + status_text = "✅ 成功" if success else "❌ 失败" return { "title": f"工作流执行通知: {workflow.name}", @@ -1462,7 +1462,7 @@ class WorkflowManager: def _build_dingtalk_message(self, workflow: Workflow, results: dict, success: bool) -> dict: """构建钉钉消息""" - status_text = "✅ 成功" if success else "❌ 失败" + status_text = "✅ 成功" if success else "❌ 失败" return { "title": f"工作流执行通知: {workflow.name}", @@ -1476,15 +1476,15 @@ class WorkflowManager: **结果:** ```json -{json.dumps(results, ensure_ascii=False, indent=2)} +{json.dumps(results, ensure_ascii = False, indent = 2)} ``` """, } def _build_slack_message(self, workflow: Workflow, results: dict, success: bool) -> dict: """构建 Slack 消息""" - color = "#36a64f" if success else "#ff0000" - status_text = "Success" if success else "Failed" + color = "#36a64f" if success else "#ff0000" + status_text = "Success" if success else "Failed" return { "attachments": [ @@ -1507,12 +1507,12 @@ class WorkflowManager: # Singleton instance -_workflow_manager = None +_workflow_manager = None -def get_workflow_manager(db_manager=None) -> WorkflowManager: +def get_workflow_manager(db_manager = None) -> WorkflowManager: """获取 WorkflowManager 单例""" global _workflow_manager if _workflow_manager is None: - _workflow_manager = WorkflowManager(db_manager) + _workflow_manager = WorkflowManager(db_manager) return _workflow_manager diff --git a/code_review_fixer.py b/code_review_fixer.py index 4d2b639..e84141e 100644 --- a/code_review_fixer.py +++ b/code_review_fixer.py @@ -11,10 +11,10 @@ from pathlib import Path from typing import Any # 项目路径 -PROJECT_PATH = Path("/root/.openclaw/workspace/projects/insightflow") +PROJECT_PATH = Path("/root/.openclaw/workspace/projects/insightflow") # 修复报告 -report = { +report = { "fixed": [], "manual_review": [], "errors": [] @@ -22,7 +22,7 @@ report = { def find_python_files() -> list[Path]: """查找所有 Python 文件""" - py_files = [] + py_files = [] for py_file in PROJECT_PATH.rglob("*.py"): if "__pycache__" not in str(py_file): py_files.append(py_file) @@ -30,12 +30,12 @@ def find_python_files() -> list[Path]: def check_duplicate_imports(content: str, file_path: Path) -> list[dict]: """检查重复导入""" - issues = [] - lines = content.split('\n') - imports = {} - + issues = [] + lines = content.split('\n') + imports = {} + for i, line in enumerate(lines, 1): - line_stripped = line.strip() + line_stripped = line.strip() if line_stripped.startswith('import ') or line_stripped.startswith('from '): if line_stripped in imports: issues.append({ @@ -45,17 +45,17 @@ def check_duplicate_imports(content: str, file_path: Path) -> list[dict]: "original_line": imports[line_stripped] }) else: - imports[line_stripped] = i + imports[line_stripped] = i return issues def check_bare_excepts(content: str, file_path: Path) -> list[dict]: """检查裸异常捕获""" - issues = [] - lines = content.split('\n') - + issues = [] + lines = content.split('\n') + for i, line in enumerate(lines, 1): - stripped = line.strip() - # 检查 except Exception: 或 except : + stripped = line.strip() + # 检查 except Exception: 或 except Exception: if re.match(r'^except\s*:', stripped): issues.append({ "line": i, @@ -66,9 +66,9 @@ def check_bare_excepts(content: str, file_path: Path) -> list[dict]: def check_line_length(content: str, file_path: Path) -> list[dict]: """检查行长度(PEP8: 79字符,这里放宽到 100)""" - issues = [] - lines = content.split('\n') - + issues = [] + lines = content.split('\n') + for i, line in enumerate(lines, 1): if len(line) > 100: issues.append({ @@ -81,24 +81,24 @@ def check_line_length(content: str, file_path: Path) -> list[dict]: def check_unused_imports(content: str, file_path: Path) -> list[dict]: """检查未使用的导入""" - issues = [] + issues = [] try: - tree = ast.parse(content) - imports = {} - used_names = set() - + tree = ast.parse(content) + imports = {} + used_names = set() + for node in ast.walk(tree): if isinstance(node, ast.Import): for alias in node.names: - imports[alias.asname or alias.name] = node + imports[alias.asname or alias.name] = node elif isinstance(node, ast.ImportFrom): for alias in node.names: - name = alias.asname or alias.name + name = alias.asname or alias.name if name != '*': - imports[name] = node + imports[name] = node elif isinstance(node, ast.Name): used_names.add(node.id) - + for name, node in imports.items(): if name not in used_names and not name.startswith('_'): issues.append({ @@ -112,9 +112,9 @@ def check_unused_imports(content: str, file_path: Path) -> list[dict]: def check_string_formatting(content: str, file_path: Path) -> list[dict]: """检查混合字符串格式化(建议使用 f-string)""" - issues = [] - lines = content.split('\n') - + issues = [] + lines = content.split('\n') + for i, line in enumerate(lines, 1): # 检查 % 格式化 if re.search(r'["\'].*%\s*\w+', line) and '%' in line: @@ -136,18 +136,18 @@ def check_string_formatting(content: str, file_path: Path) -> list[dict]: def check_magic_numbers(content: str, file_path: Path) -> list[dict]: """检查魔法数字""" - issues = [] - lines = content.split('\n') - + issues = [] + lines = content.split('\n') + # 常见魔法数字模式(排除常见索引和简单值) - magic_pattern = re.compile(r'(? list[dict]: def check_sql_injection(content: str, file_path: Path) -> list[dict]: """检查 SQL 注入风险""" - issues = [] - lines = content.split('\n') - + issues = [] + lines = content.split('\n') + for i, line in enumerate(lines, 1): # 检查字符串拼接的 SQL if 'execute(' in line or 'executescript(' in line or 'executemany(' in line: @@ -179,9 +179,9 @@ def check_sql_injection(content: str, file_path: Path) -> list[dict]: def check_cors_config(content: str, file_path: Path) -> list[dict]: """检查 CORS 配置""" - issues = [] - lines = content.split('\n') - + issues = [] + lines = content.split('\n') + for i, line in enumerate(lines, 1): if 'allow_origins' in line and '["*"]' in line: issues.append({ @@ -194,48 +194,48 @@ def check_cors_config(content: str, file_path: Path) -> list[dict]: def fix_bare_excepts(content: str) -> str: """修复裸异常捕获""" - lines = content.split('\n') - new_lines = [] - + lines = content.split('\n') + new_lines = [] + for line in lines: - stripped = line.strip() + stripped = line.strip() if re.match(r'^except\s*:', stripped): # 替换为具体异常 - indent = len(line) - len(line.lstrip()) - new_line = ' ' * indent + 'except (RuntimeError, ValueError, TypeError):' + indent = len(line) - len(line.lstrip()) + new_line = ' ' * indent + 'except (RuntimeError, ValueError, TypeError):' new_lines.append(new_line) else: new_lines.append(line) - + return '\n'.join(new_lines) def fix_line_length(content: str) -> str: """修复行长度问题(简单折行)""" - lines = content.split('\n') - new_lines = [] - + lines = content.split('\n') + new_lines = [] + for line in lines: if len(line) > 100: # 尝试在逗号或运算符处折行 - if ',' in line[80:]: + if ', ' in line[80:]: # 简单处理:截断并添加续行 - indent = len(line) - len(line.lstrip()) + indent = len(line) - len(line.lstrip()) new_lines.append(line) else: new_lines.append(line) else: new_lines.append(line) - + return '\n'.join(new_lines) def analyze_file(file_path: Path) -> dict: """分析单个文件""" try: - content = file_path.read_text(encoding='utf-8') + content = file_path.read_text(encoding = 'utf-8') except Exception as e: return {"error": str(e)} - - issues = { + + issues = { "duplicate_imports": check_duplicate_imports(content, file_path), "bare_excepts": check_bare_excepts(content, file_path), "line_length": check_line_length(content, file_path), @@ -245,22 +245,22 @@ def analyze_file(file_path: Path) -> dict: "sql_injection": check_sql_injection(content, file_path), "cors_config": check_cors_config(content, file_path), } - + return issues def fix_file(file_path: Path, issues: dict) -> bool: """自动修复文件问题""" try: - content = file_path.read_text(encoding='utf-8') - original_content = content - + content = file_path.read_text(encoding = 'utf-8') + original_content = content + # 修复裸异常 if issues.get("bare_excepts"): - content = fix_bare_excepts(content) - + content = fix_bare_excepts(content) + # 如果有修改,写回文件 if content != original_content: - file_path.write_text(content, encoding='utf-8') + file_path.write_text(content, encoding = 'utf-8') return True return False except Exception as e: @@ -269,88 +269,88 @@ def fix_file(file_path: Path, issues: dict) -> bool: def generate_report(all_issues: dict) -> str: """生成修复报告""" - lines = [] + lines = [] lines.append("# InsightFlow 代码审查报告") lines.append(f"\n生成时间: {__import__('datetime').datetime.now().isoformat()}") lines.append("\n## 自动修复的问题\n") - - total_fixed = 0 + + total_fixed = 0 for file_path, issues in all_issues.items(): - fixed_count = 0 + fixed_count = 0 for issue_type, issue_list in issues.items(): if issue_type in ["bare_excepts"] and issue_list: fixed_count += len(issue_list) - + if fixed_count > 0: lines.append(f"### {file_path}") lines.append(f"- 修复裸异常捕获: {fixed_count} 处") total_fixed += fixed_count - + if total_fixed == 0: lines.append("未发现需要自动修复的问题。") - + lines.append(f"\n**总计自动修复: {total_fixed} 处**") - + lines.append("\n## 需要人工确认的问题\n") - - total_manual = 0 + + total_manual = 0 for file_path, issues in all_issues.items(): - manual_issues = [] - + manual_issues = [] + if issues.get("sql_injection"): manual_issues.extend(issues["sql_injection"]) if issues.get("cors_config"): manual_issues.extend(issues["cors_config"]) - + if manual_issues: lines.append(f"### {file_path}") for issue in manual_issues: lines.append(f"- **{issue['type']}** (第 {issue['line']} 行): {issue.get('content', '')}") total_manual += len(manual_issues) - + if total_manual == 0: lines.append("未发现需要人工确认的问题。") - + lines.append(f"\n**总计待确认: {total_manual} 处**") - + lines.append("\n## 代码风格建议\n") - + for file_path, issues in all_issues.items(): - style_issues = [] + style_issues = [] if issues.get("line_length"): style_issues.extend(issues["line_length"]) if issues.get("string_formatting"): style_issues.extend(issues["string_formatting"]) if issues.get("magic_numbers"): style_issues.extend(issues["magic_numbers"]) - + if style_issues: lines.append(f"### {file_path}") for issue in style_issues[:5]: # 只显示前5个 lines.append(f"- 第 {issue['line']} 行: {issue['type']}") if len(style_issues) > 5: lines.append(f"- ... 还有 {len(style_issues) - 5} 个类似问题") - + return '\n'.join(lines) -def git_commit_and_push(): +def git_commit_and_push() -> None: """提交并推送代码""" try: os.chdir(PROJECT_PATH) - + # 检查是否有修改 - result = subprocess.run( + result = subprocess.run( ["git", "status", "--porcelain"], - capture_output=True, - text=True + capture_output = True, + text = True ) - + if not result.stdout.strip(): return "没有需要提交的更改" - + # 添加所有修改 - subprocess.run(["git", "add", "-A"], check=True) - + subprocess.run(["git", "add", "-A"], check = True) + # 提交 subprocess.run( ["git", "commit", "-m", """fix: auto-fix code issues (cron) @@ -359,52 +359,52 @@ def git_commit_and_push(): - 修复异常处理 - 修复PEP8格式问题 - 添加类型注解"""], - check=True + check = True ) - + # 推送 - subprocess.run(["git", "push"], check=True) - + subprocess.run(["git", "push"], check = True) + return "✅ 提交并推送成功" except subprocess.CalledProcessError as e: return f"❌ Git 操作失败: {e}" except Exception as e: return f"❌ 错误: {e}" -def main(): +def main() -> None: """主函数""" print("🔍 开始代码审查...") - - py_files = find_python_files() + + py_files = find_python_files() print(f"📁 找到 {len(py_files)} 个 Python 文件") - - all_issues = {} - + + all_issues = {} + for py_file in py_files: print(f" 分析: {py_file.name}") - issues = analyze_file(py_file) - all_issues[py_file] = issues - + issues = analyze_file(py_file) + all_issues[py_file] = issues + # 自动修复 if fix_file(py_file, issues): report["fixed"].append(str(py_file)) - + # 生成报告 - report_content = generate_report(all_issues) - report_path = PROJECT_PATH / "AUTO_CODE_REVIEW_REPORT.md" - report_path.write_text(report_content, encoding='utf-8') - + report_content = generate_report(all_issues) + report_path = PROJECT_PATH / "AUTO_CODE_REVIEW_REPORT.md" + report_path.write_text(report_content, encoding = 'utf-8') + print("\n📄 报告已生成:", report_path) - + # Git 提交 print("\n🚀 提交代码...") - git_result = git_commit_and_push() + git_result = git_commit_and_push() print(git_result) - + # 追加提交结果到报告 - with open(report_path, 'a', encoding='utf-8') as f: + with open(report_path, 'a', encoding = 'utf-8') as f: f.write(f"\n\n## Git 提交结果\n\n{git_result}\n") - + print("\n✅ 代码审查完成!") return report_content diff --git a/code_reviewer.py b/code_reviewer.py index 251d0d4..90f84bd 100644 --- a/code_reviewer.py +++ b/code_reviewer.py @@ -15,25 +15,25 @@ class CodeIssue: line_no: int, issue_type: str, message: str, - severity: str = "info", - ): - self.file_path = file_path - self.line_no = line_no - self.issue_type = issue_type - self.message = message - self.severity = severity # info, warning, error - self.fixed = False + severity: str = "info", + ) -> None: + self.file_path = file_path + self.line_no = line_no + self.issue_type = issue_type + self.message = message + self.severity = severity # info, warning, error + self.fixed = False - def __repr__(self): + def __repr__(self) -> None: return f"{self.severity.upper()}: {self.file_path}:{self.line_no} - {self.issue_type}: {self.message}" class CodeReviewer: - def __init__(self, base_path: str): - self.base_path = Path(base_path) - self.issues: list[CodeIssue] = [] - self.fixed_issues: list[CodeIssue] = [] - self.manual_review_issues: list[CodeIssue] = [] + def __init__(self, base_path: str) -> None: + self.base_path = Path(base_path) + self.issues: list[CodeIssue] = [] + self.fixed_issues: list[CodeIssue] = [] + self.manual_review_issues: list[CodeIssue] = [] def scan_all(self) -> None: """扫描所有 Python 文件""" @@ -45,14 +45,14 @@ class CodeReviewer: def scan_file(self, file_path: Path) -> None: """扫描单个文件""" try: - with open(file_path, "r", encoding="utf-8") as f: - content = f.read() - lines = content.split("\n") + with open(file_path, "r", encoding = "utf-8") as f: + content = f.read() + lines = content.split("\n") except Exception as e: print(f"Error reading {file_path}: {e}") return - rel_path = str(file_path.relative_to(self.base_path)) + rel_path = str(file_path.relative_to(self.base_path)) # 1. 检查裸异常捕获 self._check_bare_exceptions(content, lines, rel_path) @@ -92,7 +92,7 @@ class CodeReviewer: # 跳过有注释说明的情况 if "# noqa" in line or "# intentional" in line.lower(): continue - issue = CodeIssue( + issue = CodeIssue( file_path, i, "bare_exception", @@ -105,17 +105,17 @@ class CodeReviewer: self, content: str, lines: list[str], file_path: str ) -> None: """检查重复导入""" - imports = {} + imports = {} for i, line in enumerate(lines, 1): - match = re.match(r"^(?:from\s+(\S+)\s+)?import\s+(.+)$", line.strip()) + match = re.match(r"^(?:from\s+(\S+)\s+)?import\s+(.+)$", line.strip()) if match: - module = match.group(1) or "" - names = match.group(2).split(",") + module = match.group(1) or "" + names = match.group(2).split(", ") for name in names: - name = name.strip().split()[0] # 处理 'as' 别名 - key = f"{module}.{name}" if module else name + name = name.strip().split()[0] # 处理 'as' 别名 + key = f"{module}.{name}" if module else name if key in imports: - issue = CodeIssue( + issue = CodeIssue( file_path, i, "duplicate_import", @@ -123,7 +123,7 @@ class CodeReviewer: "warning", ) self.issues.append(issue) - imports[key] = i + imports[key] = i def _check_pep8_issues( self, content: str, lines: list[str], file_path: str @@ -132,7 +132,7 @@ class CodeReviewer: for i, line in enumerate(lines, 1): # 行长度超过 120 if len(line) > 120: - issue = CodeIssue( + issue = CodeIssue( file_path, i, "line_too_long", @@ -143,7 +143,7 @@ class CodeReviewer: # 行尾空格 if line.rstrip() != line: - issue = CodeIssue( + issue = CodeIssue( file_path, i, "trailing_whitespace", "行尾有空格", "info" ) self.issues.append(issue) @@ -151,7 +151,7 @@ class CodeReviewer: # 多余的空行 if i > 1 and line.strip() == "" and lines[i - 2].strip() == "": if i < len(lines) and lines[i].strip() == "": - issue = CodeIssue( + issue = CodeIssue( file_path, i, "extra_blank_line", "多余的空行", "info" ) self.issues.append(issue) @@ -161,23 +161,23 @@ class CodeReviewer: ) -> None: """检查未使用的导入""" try: - tree = ast.parse(content) + tree = ast.parse(content) except SyntaxError: return - imported_names = {} - used_names = set() + imported_names = {} + used_names = set() for node in ast.walk(tree): if isinstance(node, ast.Import): for alias in node.names: - name = alias.asname if alias.asname else alias.name - imported_names[name] = node.lineno + name = alias.asname if alias.asname else alias.name + imported_names[name] = node.lineno elif isinstance(node, ast.ImportFrom): for alias in node.names: - name = alias.asname if alias.asname else alias.name + name = alias.asname if alias.asname else alias.name if name != "*": - imported_names[name] = node.lineno + imported_names[name] = node.lineno elif isinstance(node, ast.Name): used_names.add(node.id) @@ -186,7 +186,7 @@ class CodeReviewer: # 排除一些常见例外 if name in ["annotations", "TYPE_CHECKING"]: continue - issue = CodeIssue( + issue = CodeIssue( file_path, lineno, "unused_import", f"未使用的导入: {name}", "info" ) self.issues.append(issue) @@ -195,20 +195,20 @@ class CodeReviewer: self, content: str, lines: list[str], file_path: str ) -> None: """检查混合字符串格式化""" - has_fstring = False - has_percent = False - has_format = False + has_fstring = False + has_percent = False + has_format = False for i, line in enumerate(lines, 1): if re.search(r'f["\']', line): - has_fstring = True + has_fstring = True if re.search(r"%[sdfr]", line) and not re.search(r"\d+%", line): - has_percent = True + has_percent = True if ".format(" in line: - has_format = True + has_format = True if has_fstring and (has_percent or has_format): - issue = CodeIssue( + issue = CodeIssue( file_path, 0, "mixed_formatting", @@ -222,25 +222,25 @@ class CodeReviewer: ) -> None: """检查魔法数字""" # 常见的魔法数字模式 - magic_patterns = [ - (r"=\s*(\d{3,})\s*[^:]", "可能的魔法数字"), - (r"timeout\s*=\s*(\d+)", "timeout 魔法数字"), - (r"limit\s*=\s*(\d+)", "limit 魔法数字"), - (r"port\s*=\s*(\d+)", "port 魔法数字"), + magic_patterns = [ + (r" = \s*(\d{3, })\s*[^:]", "可能的魔法数字"), + (r"timeout\s* = \s*(\d+)", "timeout 魔法数字"), + (r"limit\s* = \s*(\d+)", "limit 魔法数字"), + (r"port\s* = \s*(\d+)", "port 魔法数字"), ] for i, line in enumerate(lines, 1): # 跳过注释和字符串 - code_part = line.split("#")[0] + code_part = line.split("#")[0] if not code_part.strip(): continue for pattern, msg in magic_patterns: if re.search(pattern, code_part, re.IGNORECASE): # 排除常见的合理数字 - match = re.search(r"(\d{3,})", code_part) + match = re.search(r"(\d{3, })", code_part) if match: - num = int(match.group(1)) + num = int(match.group(1)) if num in [ 200, 404, @@ -257,7 +257,7 @@ class CodeReviewer: 8000, ]: continue - issue = CodeIssue( + issue = CodeIssue( file_path, i, "magic_number", f"{msg}: {num}", "info" ) self.issues.append(issue) @@ -272,7 +272,7 @@ class CodeReviewer: r'execute\s*\(\s*f["\']', line ): if "?" not in line and "%s" in line: - issue = CodeIssue( + issue = CodeIssue( file_path, i, "sql_injection_risk", @@ -287,7 +287,7 @@ class CodeReviewer: """检查 CORS 配置""" for i, line in enumerate(lines, 1): if "allow_origins" in line and '["*"]' in line: - issue = CodeIssue( + issue = CodeIssue( file_path, i, "cors_wildcard", @@ -303,7 +303,7 @@ class CodeReviewer: for i, line in enumerate(lines, 1): # 检查硬编码密钥 if re.search( - r'(password|secret|key|token)\s*=\s*["\'][^"\']+["\']', + r'(password|secret|key|token)\s* = \s*["\'][^"\']+["\']', line, re.IGNORECASE, ): @@ -316,7 +316,7 @@ class CodeReviewer: if not re.search(r'["\']\*+["\']', line) and not re.search( r'["\']<[^"\']*>["\']', line ): - issue = CodeIssue( + issue = CodeIssue( file_path, i, "hardcoded_secret", @@ -328,62 +328,62 @@ class CodeReviewer: def auto_fix(self) -> None: """自动修复问题""" # 按文件分组问题 - issues_by_file: dict[str, list[CodeIssue]] = {} + issues_by_file: dict[str, list[CodeIssue]] = {} for issue in self.issues: if issue.file_path not in issues_by_file: - issues_by_file[issue.file_path] = [] + issues_by_file[issue.file_path] = [] issues_by_file[issue.file_path].append(issue) for file_path, issues in issues_by_file.items(): - full_path = self.base_path / file_path + full_path = self.base_path / file_path if not full_path.exists(): continue try: - with open(full_path, "r", encoding="utf-8") as f: - content = f.read() - lines = content.split("\n") + with open(full_path, "r", encoding = "utf-8") as f: + content = f.read() + lines = content.split("\n") except Exception as e: print(f"Error reading {full_path}: {e}") continue - original_lines = lines.copy() + original_lines = lines.copy() # 修复行尾空格 for issue in issues: if issue.issue_type == "trailing_whitespace": - idx = issue.line_no - 1 + idx = issue.line_no - 1 if 0 <= idx < len(lines): - lines[idx] = lines[idx].rstrip() - issue.fixed = True + lines[idx] = lines[idx].rstrip() + issue.fixed = True # 修复裸异常 for issue in issues: if issue.issue_type == "bare_exception": - idx = issue.line_no - 1 + idx = issue.line_no - 1 if 0 <= idx < len(lines): - line = lines[idx] - # 将 except: 改为 except Exception: + line = lines[idx] + # 将 except Exception: 改为 except Exception: if re.search(r"except\s*:\s*$", line.strip()): - lines[idx] = line.replace("except:", "except Exception:") - issue.fixed = True + lines[idx] = line.replace("except Exception:", "except Exception:") + issue.fixed = True elif re.search(r"except\s+Exception\s*:\s*$", line.strip()): # 已经是 Exception,但可能需要更具体 pass # 如果文件有修改,写回 if lines != original_lines: - with open(full_path, "w", encoding="utf-8") as f: + with open(full_path, "w", encoding = "utf-8") as f: f.write("\n".join(lines)) print(f"Fixed issues in {file_path}") # 移动到已修复列表 - self.fixed_issues = [i for i in self.issues if i.fixed] - self.issues = [i for i in self.issues if not i.fixed] + self.fixed_issues = [i for i in self.issues if i.fixed] + self.issues = [i for i in self.issues if not i.fixed] def generate_report(self) -> str: """生成审查报告""" - report = [] + report = [] report.append("# InsightFlow 代码审查报告") report.append(f"\n扫描路径: {self.base_path}") report.append(f"扫描时间: {__import__('datetime').datetime.now().isoformat()}") @@ -421,9 +421,9 @@ class CodeReviewer: return "\n".join(report) -def main(): - base_path = "/root/.openclaw/workspace/projects/insightflow/backend" - reviewer = CodeReviewer(base_path) +def main() -> None: + base_path = "/root/.openclaw/workspace/projects/insightflow/backend" + reviewer = CodeReviewer(base_path) print("开始扫描代码...") reviewer.scan_all() @@ -437,9 +437,9 @@ def main(): print(f"\n已修复 {len(reviewer.fixed_issues)} 个问题") # 生成报告 - report = reviewer.generate_report() - report_path = Path(base_path).parent / "CODE_REVIEW_REPORT.md" - with open(report_path, "w", encoding="utf-8") as f: + report = reviewer.generate_report() + report_path = Path(base_path).parent / "CODE_REVIEW_REPORT.md" + with open(report_path, "w", encoding = "utf-8") as f: f.write(report) print(f"\n报告已保存到: {report_path}")