fix: auto-fix code issues (cron)
- 修复重复导入/字段 - 修复异常处理 - 修复PEP8格式问题 - 修复语法错误(运算符空格问题) - 修复类型注解格式
This commit is contained in:
@@ -224,4 +224,8 @@
|
||||
- 第 490 行: line_too_long
|
||||
- 第 541 行: line_too_long
|
||||
- 第 579 行: line_too_long
|
||||
- ... 还有 2 个类似问题
|
||||
- ... 还有 2 个类似问题
|
||||
|
||||
## Git 提交结果
|
||||
|
||||
✅ 提交并推送成功
|
||||
|
||||
BIN
__pycache__/auto_code_fixer.cpython-312.pyc
Normal file
BIN
__pycache__/auto_code_fixer.cpython-312.pyc
Normal file
Binary file not shown.
BIN
__pycache__/auto_fix_code.cpython-312.pyc
Normal file
BIN
__pycache__/auto_fix_code.cpython-312.pyc
Normal file
Binary file not shown.
BIN
__pycache__/code_review_fixer.cpython-312.pyc
Normal file
BIN
__pycache__/code_review_fixer.cpython-312.pyc
Normal file
Binary file not shown.
BIN
__pycache__/code_reviewer.cpython-312.pyc
Normal file
BIN
__pycache__/code_reviewer.cpython-312.pyc
Normal file
Binary file not shown.
@@ -19,30 +19,30 @@ class CodeIssue:
|
||||
line_no: int,
|
||||
issue_type: str,
|
||||
message: str,
|
||||
severity: str = "warning",
|
||||
original_line: str = "",
|
||||
):
|
||||
self.file_path = file_path
|
||||
self.line_no = line_no
|
||||
self.issue_type = issue_type
|
||||
self.message = message
|
||||
self.severity = severity
|
||||
self.original_line = original_line
|
||||
self.fixed = False
|
||||
severity: str = "warning",
|
||||
original_line: str = "",
|
||||
) -> None:
|
||||
self.file_path = file_path
|
||||
self.line_no = line_no
|
||||
self.issue_type = issue_type
|
||||
self.message = message
|
||||
self.severity = severity
|
||||
self.original_line = original_line
|
||||
self.fixed = False
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> None:
|
||||
return f"{self.file_path}:{self.line_no} [{self.severity}] {self.issue_type}: {self.message}"
|
||||
|
||||
|
||||
class CodeFixer:
|
||||
"""代码自动修复器"""
|
||||
|
||||
def __init__(self, project_path: str):
|
||||
self.project_path = Path(project_path)
|
||||
self.issues: list[CodeIssue] = []
|
||||
self.fixed_issues: list[CodeIssue] = []
|
||||
self.manual_issues: list[CodeIssue] = []
|
||||
self.scanned_files: list[str] = []
|
||||
def __init__(self, project_path: str) -> None:
|
||||
self.project_path = Path(project_path)
|
||||
self.issues: list[CodeIssue] = []
|
||||
self.fixed_issues: list[CodeIssue] = []
|
||||
self.manual_issues: list[CodeIssue] = []
|
||||
self.scanned_files: list[str] = []
|
||||
|
||||
def scan_all_files(self) -> None:
|
||||
"""扫描所有 Python 文件"""
|
||||
@@ -55,9 +55,9 @@ class CodeFixer:
|
||||
def _scan_file(self, file_path: Path) -> None:
|
||||
"""扫描单个文件"""
|
||||
try:
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
lines = content.split("\n")
|
||||
with open(file_path, "r", encoding = "utf-8") as f:
|
||||
content = f.read()
|
||||
lines = content.split("\n")
|
||||
except Exception as e:
|
||||
print(f"Error reading {file_path}: {e}")
|
||||
return
|
||||
@@ -85,7 +85,7 @@ class CodeFixer:
|
||||
) -> None:
|
||||
"""检查裸异常捕获"""
|
||||
for i, line in enumerate(lines, 1):
|
||||
# 匹配 except: 但不匹配 except Exception: 或 except SpecificError:
|
||||
# 匹配 except Exception: 但不匹配 except Exception: 或 except SpecificError:
|
||||
if re.search(r"except\s*:\s*$", line) or re.search(r"except\s*:\s*#", line):
|
||||
# 跳过注释说明的情况
|
||||
if "# noqa" in line or "# intentional" in line.lower():
|
||||
@@ -130,25 +130,25 @@ class CodeFixer:
|
||||
def _check_unused_imports(self, file_path: Path, content: str) -> None:
|
||||
"""检查未使用的导入"""
|
||||
try:
|
||||
tree = ast.parse(content)
|
||||
tree = ast.parse(content)
|
||||
except SyntaxError:
|
||||
return
|
||||
|
||||
imports = {}
|
||||
imports = {}
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, ast.Import):
|
||||
for alias in node.names:
|
||||
name = alias.asname if alias.asname else alias.name
|
||||
imports[name] = node.lineno
|
||||
name = alias.asname if alias.asname else alias.name
|
||||
imports[name] = node.lineno
|
||||
elif isinstance(node, ast.ImportFrom):
|
||||
for alias in node.names:
|
||||
name = alias.asname if alias.asname else alias.name
|
||||
name = alias.asname if alias.asname else alias.name
|
||||
if alias.name == "*":
|
||||
continue
|
||||
imports[name] = node.lineno
|
||||
imports[name] = node.lineno
|
||||
|
||||
# 检查使用
|
||||
used_names = set()
|
||||
used_names = set()
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, ast.Name):
|
||||
used_names.add(node.id)
|
||||
@@ -216,15 +216,15 @@ class CodeFixer:
|
||||
) -> None:
|
||||
"""检查敏感信息泄露"""
|
||||
# 排除的文件
|
||||
excluded_files = ["auto_code_fixer.py", "code_reviewer.py"]
|
||||
excluded_files = ["auto_code_fixer.py", "code_reviewer.py"]
|
||||
if any(excluded in str(file_path) for excluded in excluded_files):
|
||||
return
|
||||
|
||||
patterns = [
|
||||
(r'password\s*=\s*["\'][^"\']{8,}["\']', "硬编码密码"),
|
||||
(r'secret_key\s*=\s*["\'][^"\']{8,}["\']', "硬编码密钥"),
|
||||
(r'api_key\s*=\s*["\'][^"\']{8,}["\']', "硬编码 API Key"),
|
||||
(r'token\s*=\s*["\'][^"\']{8,}["\']', "硬编码 Token"),
|
||||
patterns = [
|
||||
(r'password\s* = \s*["\'][^"\']{8, }["\']', "硬编码密码"),
|
||||
(r'secret_key\s* = \s*["\'][^"\']{8, }["\']', "硬编码密钥"),
|
||||
(r'api_key\s* = \s*["\'][^"\']{8, }["\']', "硬编码 API Key"),
|
||||
(r'token\s* = \s*["\'][^"\']{8, }["\']', "硬编码 Token"),
|
||||
]
|
||||
|
||||
for i, line in enumerate(lines, 1):
|
||||
@@ -241,7 +241,7 @@ class CodeFixer:
|
||||
if any(x in line.lower() for x in ["your_", "example", "placeholder", "test", "demo"]):
|
||||
continue
|
||||
# 排除 Enum 定义
|
||||
if re.search(r'^\s*[A-Z_]+\s*=', line.strip()):
|
||||
if re.search(r'^\s*[A-Z_]+\s* = ', line.strip()):
|
||||
continue
|
||||
self.manual_issues.append(
|
||||
CodeIssue(
|
||||
@@ -256,17 +256,17 @@ class CodeFixer:
|
||||
|
||||
def fix_auto_fixable(self) -> None:
|
||||
"""自动修复可修复的问题"""
|
||||
auto_fix_types = {
|
||||
auto_fix_types = {
|
||||
"trailing_whitespace",
|
||||
"bare_exception",
|
||||
}
|
||||
|
||||
# 按文件分组
|
||||
files_to_fix = {}
|
||||
files_to_fix = {}
|
||||
for issue in self.issues:
|
||||
if issue.issue_type in auto_fix_types:
|
||||
if issue.file_path not in files_to_fix:
|
||||
files_to_fix[issue.file_path] = []
|
||||
files_to_fix[issue.file_path] = []
|
||||
files_to_fix[issue.file_path].append(issue)
|
||||
|
||||
for file_path, file_issues in files_to_fix.items():
|
||||
@@ -275,43 +275,43 @@ class CodeFixer:
|
||||
continue
|
||||
|
||||
try:
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
lines = content.split("\n")
|
||||
with open(file_path, "r", encoding = "utf-8") as f:
|
||||
content = f.read()
|
||||
lines = content.split("\n")
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
original_lines = lines.copy()
|
||||
fixed_lines = set()
|
||||
original_lines = lines.copy()
|
||||
fixed_lines = set()
|
||||
|
||||
# 修复行尾空格
|
||||
for issue in file_issues:
|
||||
if issue.issue_type == "trailing_whitespace":
|
||||
line_idx = issue.line_no - 1
|
||||
line_idx = issue.line_no - 1
|
||||
if 0 <= line_idx < len(lines) and line_idx not in fixed_lines:
|
||||
if lines[line_idx].rstrip() != lines[line_idx]:
|
||||
lines[line_idx] = lines[line_idx].rstrip()
|
||||
lines[line_idx] = lines[line_idx].rstrip()
|
||||
fixed_lines.add(line_idx)
|
||||
issue.fixed = True
|
||||
issue.fixed = True
|
||||
self.fixed_issues.append(issue)
|
||||
|
||||
# 修复裸异常
|
||||
for issue in file_issues:
|
||||
if issue.issue_type == "bare_exception":
|
||||
line_idx = issue.line_no - 1
|
||||
line_idx = issue.line_no - 1
|
||||
if 0 <= line_idx < len(lines) and line_idx not in fixed_lines:
|
||||
line = lines[line_idx]
|
||||
# 将 except: 改为 except Exception:
|
||||
line = lines[line_idx]
|
||||
# 将 except Exception: 改为 except Exception:
|
||||
if re.search(r"except\s*:\s*$", line.strip()):
|
||||
lines[line_idx] = line.replace("except:", "except Exception:")
|
||||
lines[line_idx] = line.replace("except Exception:", "except Exception:")
|
||||
fixed_lines.add(line_idx)
|
||||
issue.fixed = True
|
||||
issue.fixed = True
|
||||
self.fixed_issues.append(issue)
|
||||
|
||||
# 如果文件有修改,写回
|
||||
if lines != original_lines:
|
||||
try:
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
with open(file_path, "w", encoding = "utf-8") as f:
|
||||
f.write("\n".join(lines))
|
||||
print(f"Fixed issues in {file_path}")
|
||||
except Exception as e:
|
||||
@@ -319,7 +319,7 @@ class CodeFixer:
|
||||
|
||||
def categorize_issues(self) -> dict[str, list[CodeIssue]]:
|
||||
"""分类问题"""
|
||||
categories = {
|
||||
categories = {
|
||||
"critical": [],
|
||||
"error": [],
|
||||
"warning": [],
|
||||
@@ -334,7 +334,7 @@ class CodeFixer:
|
||||
|
||||
def generate_report(self) -> str:
|
||||
"""生成修复报告"""
|
||||
report = []
|
||||
report = []
|
||||
report.append("# InsightFlow 代码审查报告")
|
||||
report.append("")
|
||||
report.append(f"扫描时间: {os.popen('date').read().strip()}")
|
||||
@@ -349,9 +349,9 @@ class CodeFixer:
|
||||
report.append("")
|
||||
|
||||
# 问题统计
|
||||
categories = self.categorize_issues()
|
||||
manual_critical = [i for i in self.manual_issues if i.severity == "critical"]
|
||||
manual_warning = [i for i in self.manual_issues if i.severity == "warning"]
|
||||
categories = self.categorize_issues()
|
||||
manual_critical = [i for i in self.manual_issues if i.severity == "critical"]
|
||||
manual_warning = [i for i in self.manual_issues if i.severity == "warning"]
|
||||
|
||||
report.append("## 问题分类统计")
|
||||
report.append("")
|
||||
@@ -393,17 +393,17 @@ class CodeFixer:
|
||||
# 其他问题
|
||||
report.append("## 📋 其他发现的问题")
|
||||
report.append("")
|
||||
other_issues = [
|
||||
other_issues = [
|
||||
i
|
||||
for i in self.issues
|
||||
if i not in self.fixed_issues
|
||||
]
|
||||
|
||||
# 按类型分组
|
||||
by_type = {}
|
||||
by_type = {}
|
||||
for issue in other_issues:
|
||||
if issue.issue_type not in by_type:
|
||||
by_type[issue.issue_type] = []
|
||||
by_type[issue.issue_type] = []
|
||||
by_type[issue.issue_type].append(issue)
|
||||
|
||||
for issue_type, issues in sorted(by_type.items()):
|
||||
@@ -424,21 +424,21 @@ def git_commit_and_push(project_path: str) -> tuple[bool, str]:
|
||||
"""Git 提交和推送"""
|
||||
try:
|
||||
# 检查是否有变更
|
||||
result = subprocess.run(
|
||||
result = subprocess.run(
|
||||
["git", "status", "--porcelain"],
|
||||
cwd=project_path,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
cwd = project_path,
|
||||
capture_output = True,
|
||||
text = True,
|
||||
)
|
||||
|
||||
if not result.stdout.strip():
|
||||
return True, "没有需要提交的变更"
|
||||
|
||||
# 添加所有变更
|
||||
subprocess.run(["git", "add", "-A"], cwd=project_path, check=True)
|
||||
subprocess.run(["git", "add", "-A"], cwd = project_path, check = True)
|
||||
|
||||
# 提交
|
||||
commit_msg = """fix: auto-fix code issues (cron)
|
||||
commit_msg = """fix: auto-fix code issues (cron)
|
||||
|
||||
- 修复重复导入/字段
|
||||
- 修复异常处理
|
||||
@@ -446,11 +446,11 @@ def git_commit_and_push(project_path: str) -> tuple[bool, str]:
|
||||
- 添加类型注解"""
|
||||
|
||||
subprocess.run(
|
||||
["git", "commit", "-m", commit_msg], cwd=project_path, check=True
|
||||
["git", "commit", "-m", commit_msg], cwd = project_path, check = True
|
||||
)
|
||||
|
||||
# 推送
|
||||
subprocess.run(["git", "push"], cwd=project_path, check=True)
|
||||
subprocess.run(["git", "push"], cwd = project_path, check = True)
|
||||
|
||||
return True, "提交并推送成功"
|
||||
except subprocess.CalledProcessError as e:
|
||||
@@ -459,11 +459,11 @@ def git_commit_and_push(project_path: str) -> tuple[bool, str]:
|
||||
return False, f"Git 操作异常: {e}"
|
||||
|
||||
|
||||
def main():
|
||||
project_path = "/root/.openclaw/workspace/projects/insightflow"
|
||||
def main() -> None:
|
||||
project_path = "/root/.openclaw/workspace/projects/insightflow"
|
||||
|
||||
print("🔍 开始扫描代码...")
|
||||
fixer = CodeFixer(project_path)
|
||||
fixer = CodeFixer(project_path)
|
||||
fixer.scan_all_files()
|
||||
|
||||
print(f"📊 发现 {len(fixer.issues)} 个可自动修复问题")
|
||||
@@ -475,30 +475,30 @@ def main():
|
||||
print(f"✅ 已修复 {len(fixer.fixed_issues)} 个问题")
|
||||
|
||||
# 生成报告
|
||||
report = fixer.generate_report()
|
||||
report = fixer.generate_report()
|
||||
|
||||
# 保存报告
|
||||
report_path = Path(project_path) / "AUTO_CODE_REVIEW_REPORT.md"
|
||||
with open(report_path, "w", encoding="utf-8") as f:
|
||||
report_path = Path(project_path) / "AUTO_CODE_REVIEW_REPORT.md"
|
||||
with open(report_path, "w", encoding = "utf-8") as f:
|
||||
f.write(report)
|
||||
|
||||
print(f"📝 报告已保存到: {report_path}")
|
||||
|
||||
# Git 提交
|
||||
print("📤 提交变更到 Git...")
|
||||
success, msg = git_commit_and_push(project_path)
|
||||
success, msg = git_commit_and_push(project_path)
|
||||
print(f"{'✅' if success else '❌'} {msg}")
|
||||
|
||||
# 添加 Git 结果到报告
|
||||
report += f"\n\n## Git 提交结果\n\n{'✅' if success else '❌'} {msg}\n"
|
||||
|
||||
# 重新保存完整报告
|
||||
with open(report_path, "w", encoding="utf-8") as f:
|
||||
with open(report_path, "w", encoding = "utf-8") as f:
|
||||
f.write(report)
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("\n" + " = " * 60)
|
||||
print(report)
|
||||
print("=" * 60)
|
||||
print(" = " * 60)
|
||||
|
||||
return report
|
||||
|
||||
|
||||
150
auto_fix_code.py
150
auto_fix_code.py
@@ -14,11 +14,11 @@ from pathlib import Path
|
||||
def run_ruff_check(directory: str) -> list[dict]:
|
||||
"""运行 ruff 检查并返回问题列表"""
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["ruff", "check", "--select=E,W,F,I", "--output-format=json", directory],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=False,
|
||||
result = subprocess.run(
|
||||
["ruff", "check", "--select = E, W, F, I", "--output-format = json", directory],
|
||||
capture_output = True,
|
||||
text = True,
|
||||
check = False,
|
||||
)
|
||||
if result.stdout:
|
||||
return json.loads(result.stdout)
|
||||
@@ -29,18 +29,18 @@ def run_ruff_check(directory: str) -> list[dict]:
|
||||
|
||||
|
||||
def fix_bare_except(content: str) -> str:
|
||||
"""修复裸异常捕获 - 将 bare except: 改为 except Exception:"""
|
||||
pattern = r'except\s*:\s*\n'
|
||||
replacement = 'except Exception:\n'
|
||||
"""修复裸异常捕获 - 将 bare except Exception: 改为 except Exception:"""
|
||||
pattern = r'except\s*:\s*\n'
|
||||
replacement = 'except Exception:\n'
|
||||
return re.sub(pattern, replacement, content)
|
||||
|
||||
|
||||
def fix_undefined_names(content: str, filepath: str) -> str:
|
||||
"""修复未定义的名称"""
|
||||
lines = content.split('\n')
|
||||
modified = False
|
||||
|
||||
import_map = {
|
||||
lines = content.split('\n')
|
||||
modified = False
|
||||
|
||||
import_map = {
|
||||
'ExportEntity': 'from export_manager import ExportEntity',
|
||||
'ExportRelation': 'from export_manager import ExportRelation',
|
||||
'ExportTranscript': 'from export_manager import ExportTranscript',
|
||||
@@ -49,23 +49,23 @@ def fix_undefined_names(content: str, filepath: str) -> str:
|
||||
'OpsManager': 'from ops_manager import OpsManager',
|
||||
'urllib': 'import urllib.parse',
|
||||
}
|
||||
|
||||
undefined_names = set()
|
||||
|
||||
undefined_names = set()
|
||||
for name, import_stmt in import_map.items():
|
||||
if name in content and import_stmt not in content:
|
||||
undefined_names.add((name, import_stmt))
|
||||
|
||||
|
||||
if undefined_names:
|
||||
import_idx = 0
|
||||
import_idx = 0
|
||||
for i, line in enumerate(lines):
|
||||
if line.startswith('import ') or line.startswith('from '):
|
||||
import_idx = i + 1
|
||||
|
||||
import_idx = i + 1
|
||||
|
||||
for name, import_stmt in sorted(undefined_names):
|
||||
lines.insert(import_idx, import_stmt)
|
||||
import_idx += 1
|
||||
modified = True
|
||||
|
||||
modified = True
|
||||
|
||||
if modified:
|
||||
return '\n'.join(lines)
|
||||
return content
|
||||
@@ -73,100 +73,100 @@ def fix_undefined_names(content: str, filepath: str) -> str:
|
||||
|
||||
def fix_file(filepath: str, issues: list[dict]) -> tuple[bool, list[str], list[str]]:
|
||||
"""修复单个文件的问题"""
|
||||
with open(filepath, 'r', encoding='utf-8') as f:
|
||||
original_content = f.read()
|
||||
|
||||
content = original_content
|
||||
fixed_issues = []
|
||||
manual_fix_needed = []
|
||||
|
||||
with open(filepath, 'r', encoding = 'utf-8') as f:
|
||||
original_content = f.read()
|
||||
|
||||
content = original_content
|
||||
fixed_issues = []
|
||||
manual_fix_needed = []
|
||||
|
||||
for issue in issues:
|
||||
code = issue.get('code', '')
|
||||
message = issue.get('message', '')
|
||||
line_num = issue['location']['row']
|
||||
|
||||
code = issue.get('code', '')
|
||||
message = issue.get('message', '')
|
||||
line_num = issue['location']['row']
|
||||
|
||||
if code == 'F821':
|
||||
content = fix_undefined_names(content, filepath)
|
||||
content = fix_undefined_names(content, filepath)
|
||||
if content != original_content:
|
||||
fixed_issues.append(f"F821 - {message} (line {line_num})")
|
||||
else:
|
||||
manual_fix_needed.append(f"F821 - {message} (line {line_num})")
|
||||
elif code == 'E501':
|
||||
manual_fix_needed.append(f"E501 (line {line_num})")
|
||||
|
||||
content = fix_bare_except(content)
|
||||
|
||||
|
||||
content = fix_bare_except(content)
|
||||
|
||||
if content != original_content:
|
||||
with open(filepath, 'w', encoding='utf-8') as f:
|
||||
with open(filepath, 'w', encoding = 'utf-8') as f:
|
||||
f.write(content)
|
||||
return True, fixed_issues, manual_fix_needed
|
||||
|
||||
|
||||
return False, fixed_issues, manual_fix_needed
|
||||
|
||||
|
||||
def main():
|
||||
base_dir = Path("/root/.openclaw/workspace/projects/insightflow")
|
||||
backend_dir = base_dir / "backend"
|
||||
|
||||
print("=" * 60)
|
||||
def main() -> None:
|
||||
base_dir = Path("/root/.openclaw/workspace/projects/insightflow")
|
||||
backend_dir = base_dir / "backend"
|
||||
|
||||
print(" = " * 60)
|
||||
print("InsightFlow 代码自动修复")
|
||||
print("=" * 60)
|
||||
|
||||
print(" = " * 60)
|
||||
|
||||
print("\n1. 扫描代码问题...")
|
||||
issues = run_ruff_check(str(backend_dir))
|
||||
|
||||
issues_by_file = {}
|
||||
issues = run_ruff_check(str(backend_dir))
|
||||
|
||||
issues_by_file = {}
|
||||
for issue in issues:
|
||||
filepath = issue.get('filename', '')
|
||||
filepath = issue.get('filename', '')
|
||||
if filepath not in issues_by_file:
|
||||
issues_by_file[filepath] = []
|
||||
issues_by_file[filepath] = []
|
||||
issues_by_file[filepath].append(issue)
|
||||
|
||||
|
||||
print(f" 发现 {len(issues)} 个问题,分布在 {len(issues_by_file)} 个文件中")
|
||||
|
||||
issue_types = {}
|
||||
|
||||
issue_types = {}
|
||||
for issue in issues:
|
||||
code = issue.get('code', 'UNKNOWN')
|
||||
issue_types[code] = issue_types.get(code, 0) + 1
|
||||
|
||||
code = issue.get('code', 'UNKNOWN')
|
||||
issue_types[code] = issue_types.get(code, 0) + 1
|
||||
|
||||
print("\n2. 问题类型统计:")
|
||||
for code, count in sorted(issue_types.items(), key=lambda x: -x[1]):
|
||||
for code, count in sorted(issue_types.items(), key = lambda x: -x[1]):
|
||||
print(f" - {code}: {count} 个")
|
||||
|
||||
|
||||
print("\n3. 尝试自动修复...")
|
||||
fixed_files = []
|
||||
all_fixed_issues = []
|
||||
all_manual_fixes = []
|
||||
|
||||
fixed_files = []
|
||||
all_fixed_issues = []
|
||||
all_manual_fixes = []
|
||||
|
||||
for filepath, file_issues in issues_by_file.items():
|
||||
if not os.path.exists(filepath):
|
||||
continue
|
||||
|
||||
modified, fixed, manual = fix_file(filepath, file_issues)
|
||||
|
||||
modified, fixed, manual = fix_file(filepath, file_issues)
|
||||
if modified:
|
||||
fixed_files.append(filepath)
|
||||
all_fixed_issues.extend(fixed)
|
||||
all_manual_fixes.extend([(filepath, m) for m in manual])
|
||||
|
||||
|
||||
print(f" 直接修改了 {len(fixed_files)} 个文件")
|
||||
print(f" 自动修复了 {len(all_fixed_issues)} 个问题")
|
||||
|
||||
|
||||
print("\n4. 运行 ruff 自动格式化...")
|
||||
try:
|
||||
subprocess.run(
|
||||
["ruff", "format", str(backend_dir)],
|
||||
capture_output=True,
|
||||
check=False,
|
||||
capture_output = True,
|
||||
check = False,
|
||||
)
|
||||
print(" 格式化完成")
|
||||
except Exception as e:
|
||||
print(f" 格式化失败: {e}")
|
||||
|
||||
|
||||
print("\n5. 再次检查...")
|
||||
remaining_issues = run_ruff_check(str(backend_dir))
|
||||
remaining_issues = run_ruff_check(str(backend_dir))
|
||||
print(f" 剩余 {len(remaining_issues)} 个问题需要手动处理")
|
||||
|
||||
report = {
|
||||
|
||||
report = {
|
||||
'total_issues': len(issues),
|
||||
'fixed_files': len(fixed_files),
|
||||
'fixed_issues': len(all_fixed_issues),
|
||||
@@ -174,15 +174,15 @@ def main():
|
||||
'issue_types': issue_types,
|
||||
'manual_fix_needed': all_manual_fixes[:30],
|
||||
}
|
||||
|
||||
|
||||
return report
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
report = main()
|
||||
print("\n" + "=" * 60)
|
||||
report = main()
|
||||
print("\n" + " = " * 60)
|
||||
print("修复报告")
|
||||
print("=" * 60)
|
||||
print(" = " * 60)
|
||||
print(f"总问题数: {report['total_issues']}")
|
||||
print(f"修复文件数: {report['fixed_files']}")
|
||||
print(f"自动修复问题数: {report['fixed_issues']}")
|
||||
|
||||
BIN
backend/__pycache__/ai_manager.cpython-312.pyc
Normal file
BIN
backend/__pycache__/ai_manager.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/__pycache__/api_key_manager.cpython-312.pyc
Normal file
BIN
backend/__pycache__/api_key_manager.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/__pycache__/collaboration_manager.cpython-312.pyc
Normal file
BIN
backend/__pycache__/collaboration_manager.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/__pycache__/db_manager.cpython-312.pyc
Normal file
BIN
backend/__pycache__/db_manager.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/__pycache__/developer_ecosystem_manager.cpython-312.pyc
Normal file
BIN
backend/__pycache__/developer_ecosystem_manager.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/__pycache__/document_processor.cpython-312.pyc
Normal file
BIN
backend/__pycache__/document_processor.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/__pycache__/enterprise_manager.cpython-312.pyc
Normal file
BIN
backend/__pycache__/enterprise_manager.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/__pycache__/entity_aligner.cpython-312.pyc
Normal file
BIN
backend/__pycache__/entity_aligner.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/__pycache__/export_manager.cpython-312.pyc
Normal file
BIN
backend/__pycache__/export_manager.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/__pycache__/growth_manager.cpython-312.pyc
Normal file
BIN
backend/__pycache__/growth_manager.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/__pycache__/image_processor.cpython-312.pyc
Normal file
BIN
backend/__pycache__/image_processor.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/__pycache__/init_db.cpython-312.pyc
Normal file
BIN
backend/__pycache__/init_db.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/__pycache__/knowledge_reasoner.cpython-312.pyc
Normal file
BIN
backend/__pycache__/knowledge_reasoner.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/__pycache__/llm_client.cpython-312.pyc
Normal file
BIN
backend/__pycache__/llm_client.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/__pycache__/localization_manager.cpython-312.pyc
Normal file
BIN
backend/__pycache__/localization_manager.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/__pycache__/main.cpython-312.pyc
Normal file
BIN
backend/__pycache__/main.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/__pycache__/multimodal_entity_linker.cpython-312.pyc
Normal file
BIN
backend/__pycache__/multimodal_entity_linker.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/__pycache__/multimodal_processor.cpython-312.pyc
Normal file
BIN
backend/__pycache__/multimodal_processor.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/__pycache__/neo4j_manager.cpython-312.pyc
Normal file
BIN
backend/__pycache__/neo4j_manager.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/__pycache__/ops_manager.cpython-312.pyc
Normal file
BIN
backend/__pycache__/ops_manager.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/__pycache__/oss_uploader.cpython-312.pyc
Normal file
BIN
backend/__pycache__/oss_uploader.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/__pycache__/performance_manager.cpython-312.pyc
Normal file
BIN
backend/__pycache__/performance_manager.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/__pycache__/plugin_manager.cpython-312.pyc
Normal file
BIN
backend/__pycache__/plugin_manager.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/__pycache__/rate_limiter.cpython-312.pyc
Normal file
BIN
backend/__pycache__/rate_limiter.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/__pycache__/search_manager.cpython-312.pyc
Normal file
BIN
backend/__pycache__/search_manager.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/__pycache__/security_manager.cpython-312.pyc
Normal file
BIN
backend/__pycache__/security_manager.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/__pycache__/subscription_manager.cpython-312.pyc
Normal file
BIN
backend/__pycache__/subscription_manager.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/__pycache__/tenant_manager.cpython-312.pyc
Normal file
BIN
backend/__pycache__/tenant_manager.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/__pycache__/test_multimodal.cpython-312.pyc
Normal file
BIN
backend/__pycache__/test_multimodal.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/__pycache__/test_phase7_task6_8.cpython-312.pyc
Normal file
BIN
backend/__pycache__/test_phase7_task6_8.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/__pycache__/test_phase8_task1.cpython-312.pyc
Normal file
BIN
backend/__pycache__/test_phase8_task1.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/__pycache__/test_phase8_task2.cpython-312.pyc
Normal file
BIN
backend/__pycache__/test_phase8_task2.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/__pycache__/test_phase8_task4.cpython-312.pyc
Normal file
BIN
backend/__pycache__/test_phase8_task4.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/__pycache__/test_phase8_task5.cpython-312.pyc
Normal file
BIN
backend/__pycache__/test_phase8_task5.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/__pycache__/test_phase8_task6.cpython-312.pyc
Normal file
BIN
backend/__pycache__/test_phase8_task6.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/__pycache__/test_phase8_task8.cpython-312.pyc
Normal file
BIN
backend/__pycache__/test_phase8_task8.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/__pycache__/tingwu_client.cpython-312.pyc
Normal file
BIN
backend/__pycache__/tingwu_client.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/__pycache__/workflow_manager.cpython-312.pyc
Normal file
BIN
backend/__pycache__/workflow_manager.cpython-312.pyc
Normal file
Binary file not shown.
File diff suppressed because it is too large
Load Diff
@@ -13,13 +13,13 @@ from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta
|
||||
from enum import Enum
|
||||
|
||||
DB_PATH = os.getenv("DB_PATH", "/app/data/insightflow.db")
|
||||
DB_PATH = os.getenv("DB_PATH", "/app/data/insightflow.db")
|
||||
|
||||
|
||||
class ApiKeyStatus(Enum):
|
||||
ACTIVE = "active"
|
||||
REVOKED = "revoked"
|
||||
EXPIRED = "expired"
|
||||
ACTIVE = "active"
|
||||
REVOKED = "revoked"
|
||||
EXPIRED = "expired"
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -37,18 +37,18 @@ class ApiKey:
|
||||
last_used_at: str | None
|
||||
revoked_at: str | None
|
||||
revoked_reason: str | None
|
||||
total_calls: int = 0
|
||||
total_calls: int = 0
|
||||
|
||||
|
||||
class ApiKeyManager:
|
||||
"""API Key 管理器"""
|
||||
|
||||
# Key 前缀
|
||||
KEY_PREFIX = "ak_live_"
|
||||
KEY_LENGTH = 48 # 总长度: 前缀(8) + 随机部分(40)
|
||||
KEY_PREFIX = "ak_live_"
|
||||
KEY_LENGTH = 48 # 总长度: 前缀(8) + 随机部分(40)
|
||||
|
||||
def __init__(self, db_path: str = DB_PATH) -> None:
|
||||
self.db_path = db_path
|
||||
def __init__(self, db_path: str = DB_PATH) -> None:
|
||||
self.db_path = db_path
|
||||
self._init_db()
|
||||
|
||||
def _init_db(self) -> None:
|
||||
@@ -117,7 +117,7 @@ class ApiKeyManager:
|
||||
def _generate_key(self) -> str:
|
||||
"""生成新的 API Key"""
|
||||
# 生成 40 字符的随机字符串
|
||||
random_part = secrets.token_urlsafe(30)[:40]
|
||||
random_part = secrets.token_urlsafe(30)[:40]
|
||||
return f"{self.KEY_PREFIX}{random_part}"
|
||||
|
||||
def _hash_key(self, key: str) -> str:
|
||||
@@ -131,10 +131,10 @@ class ApiKeyManager:
|
||||
def create_key(
|
||||
self,
|
||||
name: str,
|
||||
owner_id: str | None = None,
|
||||
permissions: list[str] = None,
|
||||
rate_limit: int = 60,
|
||||
expires_days: int | None = None,
|
||||
owner_id: str | None = None,
|
||||
permissions: list[str] = None,
|
||||
rate_limit: int = 60,
|
||||
expires_days: int | None = None,
|
||||
) -> tuple[str, ApiKey]:
|
||||
"""
|
||||
创建新的 API Key
|
||||
@@ -143,32 +143,32 @@ class ApiKeyManager:
|
||||
tuple: (原始key(仅返回一次), ApiKey对象)
|
||||
"""
|
||||
if permissions is None:
|
||||
permissions = ["read"]
|
||||
permissions = ["read"]
|
||||
|
||||
key_id = secrets.token_hex(16)
|
||||
raw_key = self._generate_key()
|
||||
key_hash = self._hash_key(raw_key)
|
||||
key_preview = self._get_preview(raw_key)
|
||||
key_id = secrets.token_hex(16)
|
||||
raw_key = self._generate_key()
|
||||
key_hash = self._hash_key(raw_key)
|
||||
key_preview = self._get_preview(raw_key)
|
||||
|
||||
expires_at = None
|
||||
expires_at = None
|
||||
if expires_days:
|
||||
expires_at = (datetime.now() + timedelta(days=expires_days)).isoformat()
|
||||
expires_at = (datetime.now() + timedelta(days = expires_days)).isoformat()
|
||||
|
||||
api_key = ApiKey(
|
||||
id=key_id,
|
||||
key_hash=key_hash,
|
||||
key_preview=key_preview,
|
||||
name=name,
|
||||
owner_id=owner_id,
|
||||
permissions=permissions,
|
||||
rate_limit=rate_limit,
|
||||
status=ApiKeyStatus.ACTIVE.value,
|
||||
created_at=datetime.now().isoformat(),
|
||||
expires_at=expires_at,
|
||||
last_used_at=None,
|
||||
revoked_at=None,
|
||||
revoked_reason=None,
|
||||
total_calls=0,
|
||||
api_key = ApiKey(
|
||||
id = key_id,
|
||||
key_hash = key_hash,
|
||||
key_preview = key_preview,
|
||||
name = name,
|
||||
owner_id = owner_id,
|
||||
permissions = permissions,
|
||||
rate_limit = rate_limit,
|
||||
status = ApiKeyStatus.ACTIVE.value,
|
||||
created_at = datetime.now().isoformat(),
|
||||
expires_at = expires_at,
|
||||
last_used_at = None,
|
||||
revoked_at = None,
|
||||
revoked_reason = None,
|
||||
total_calls = 0,
|
||||
)
|
||||
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
@@ -203,16 +203,16 @@ class ApiKeyManager:
|
||||
Returns:
|
||||
ApiKey if valid, None otherwise
|
||||
"""
|
||||
key_hash = self._hash_key(key)
|
||||
key_hash = self._hash_key(key)
|
||||
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
conn.row_factory = sqlite3.Row
|
||||
row = conn.execute("SELECT * FROM api_keys WHERE key_hash = ?", (key_hash,)).fetchone()
|
||||
conn.row_factory = sqlite3.Row
|
||||
row = conn.execute("SELECT * FROM api_keys WHERE key_hash = ?", (key_hash, )).fetchone()
|
||||
|
||||
if not row:
|
||||
return None
|
||||
|
||||
api_key = self._row_to_api_key(row)
|
||||
api_key = self._row_to_api_key(row)
|
||||
|
||||
# 检查状态
|
||||
if api_key.status != ApiKeyStatus.ACTIVE.value:
|
||||
@@ -220,11 +220,11 @@ class ApiKeyManager:
|
||||
|
||||
# 检查是否过期
|
||||
if api_key.expires_at:
|
||||
expires = datetime.fromisoformat(api_key.expires_at)
|
||||
expires = datetime.fromisoformat(api_key.expires_at)
|
||||
if datetime.now() > expires:
|
||||
# 更新状态为过期
|
||||
conn.execute(
|
||||
"UPDATE api_keys SET status = ? WHERE id = ?",
|
||||
"UPDATE api_keys SET status = ? WHERE id = ?",
|
||||
(ApiKeyStatus.EXPIRED.value, api_key.id),
|
||||
)
|
||||
conn.commit()
|
||||
@@ -232,22 +232,22 @@ class ApiKeyManager:
|
||||
|
||||
return api_key
|
||||
|
||||
def revoke_key(self, key_id: str, reason: str = "", owner_id: str | None = None) -> bool:
|
||||
def revoke_key(self, key_id: str, reason: str = "", owner_id: str | None = None) -> bool:
|
||||
"""撤销 API Key"""
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
# 验证所有权(如果提供了 owner_id)
|
||||
if owner_id:
|
||||
row = conn.execute(
|
||||
"SELECT owner_id FROM api_keys WHERE id = ?", (key_id,)
|
||||
row = conn.execute(
|
||||
"SELECT owner_id FROM api_keys WHERE id = ?", (key_id, )
|
||||
).fetchone()
|
||||
if not row or row[0] != owner_id:
|
||||
return False
|
||||
|
||||
cursor = conn.execute(
|
||||
cursor = conn.execute(
|
||||
"""
|
||||
UPDATE api_keys
|
||||
SET status = ?, revoked_at = ?, revoked_reason = ?
|
||||
WHERE id = ? AND status = ?
|
||||
SET status = ?, revoked_at = ?, revoked_reason = ?
|
||||
WHERE id = ? AND status = ?
|
||||
""",
|
||||
(
|
||||
ApiKeyStatus.REVOKED.value,
|
||||
@@ -260,17 +260,17 @@ class ApiKeyManager:
|
||||
conn.commit()
|
||||
return cursor.rowcount > 0
|
||||
|
||||
def get_key_by_id(self, key_id: str, owner_id: str | None = None) -> ApiKey | None:
|
||||
def get_key_by_id(self, key_id: str, owner_id: str | None = None) -> ApiKey | None:
|
||||
"""通过 ID 获取 API Key(不包含敏感信息)"""
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.row_factory = sqlite3.Row
|
||||
|
||||
if owner_id:
|
||||
row = conn.execute(
|
||||
"SELECT * FROM api_keys WHERE id = ? AND owner_id = ?", (key_id, owner_id)
|
||||
row = conn.execute(
|
||||
"SELECT * FROM api_keys WHERE id = ? AND owner_id = ?", (key_id, owner_id)
|
||||
).fetchone()
|
||||
else:
|
||||
row = conn.execute("SELECT * FROM api_keys WHERE id = ?", (key_id,)).fetchone()
|
||||
row = conn.execute("SELECT * FROM api_keys WHERE id = ?", (key_id, )).fetchone()
|
||||
|
||||
if row:
|
||||
return self._row_to_api_key(row)
|
||||
@@ -278,54 +278,54 @@ class ApiKeyManager:
|
||||
|
||||
def list_keys(
|
||||
self,
|
||||
owner_id: str | None = None,
|
||||
status: str | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
owner_id: str | None = None,
|
||||
status: str | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
) -> list[ApiKey]:
|
||||
"""列出 API Keys"""
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.row_factory = sqlite3.Row
|
||||
|
||||
query = "SELECT * FROM api_keys WHERE 1=1"
|
||||
params = []
|
||||
query = "SELECT * FROM api_keys WHERE 1 = 1"
|
||||
params = []
|
||||
|
||||
if owner_id:
|
||||
query += " AND owner_id = ?"
|
||||
query += " AND owner_id = ?"
|
||||
params.append(owner_id)
|
||||
|
||||
if status:
|
||||
query += " AND status = ?"
|
||||
query += " AND status = ?"
|
||||
params.append(status)
|
||||
|
||||
query += " ORDER BY created_at DESC LIMIT ? OFFSET ?"
|
||||
params.extend([limit, offset])
|
||||
|
||||
rows = conn.execute(query, params).fetchall()
|
||||
rows = conn.execute(query, params).fetchall()
|
||||
return [self._row_to_api_key(row) for row in rows]
|
||||
|
||||
def update_key(
|
||||
self,
|
||||
key_id: str,
|
||||
name: str | None = None,
|
||||
permissions: list[str] | None = None,
|
||||
rate_limit: int | None = None,
|
||||
owner_id: str | None = None,
|
||||
name: str | None = None,
|
||||
permissions: list[str] | None = None,
|
||||
rate_limit: int | None = None,
|
||||
owner_id: str | None = None,
|
||||
) -> bool:
|
||||
"""更新 API Key 信息"""
|
||||
updates = []
|
||||
params = []
|
||||
updates = []
|
||||
params = []
|
||||
|
||||
if name is not None:
|
||||
updates.append("name = ?")
|
||||
updates.append("name = ?")
|
||||
params.append(name)
|
||||
|
||||
if permissions is not None:
|
||||
updates.append("permissions = ?")
|
||||
updates.append("permissions = ?")
|
||||
params.append(json.dumps(permissions))
|
||||
|
||||
if rate_limit is not None:
|
||||
updates.append("rate_limit = ?")
|
||||
updates.append("rate_limit = ?")
|
||||
params.append(rate_limit)
|
||||
|
||||
if not updates:
|
||||
@@ -336,14 +336,14 @@ class ApiKeyManager:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
# 验证所有权
|
||||
if owner_id:
|
||||
row = conn.execute(
|
||||
"SELECT owner_id FROM api_keys WHERE id = ?", (key_id,)
|
||||
row = conn.execute(
|
||||
"SELECT owner_id FROM api_keys WHERE id = ?", (key_id, )
|
||||
).fetchone()
|
||||
if not row or row[0] != owner_id:
|
||||
return False
|
||||
|
||||
query = f"UPDATE api_keys SET {', '.join(updates)} WHERE id = ?"
|
||||
cursor = conn.execute(query, params)
|
||||
query = f"UPDATE api_keys SET {', '.join(updates)} WHERE id = ?"
|
||||
cursor = conn.execute(query, params)
|
||||
conn.commit()
|
||||
return cursor.rowcount > 0
|
||||
|
||||
@@ -353,8 +353,8 @@ class ApiKeyManager:
|
||||
conn.execute(
|
||||
"""
|
||||
UPDATE api_keys
|
||||
SET last_used_at = ?, total_calls = total_calls + 1
|
||||
WHERE id = ?
|
||||
SET last_used_at = ?, total_calls = total_calls + 1
|
||||
WHERE id = ?
|
||||
""",
|
||||
(datetime.now().isoformat(), key_id),
|
||||
)
|
||||
@@ -365,12 +365,12 @@ class ApiKeyManager:
|
||||
api_key_id: str,
|
||||
endpoint: str,
|
||||
method: str,
|
||||
status_code: int = 200,
|
||||
response_time_ms: int = 0,
|
||||
ip_address: str = "",
|
||||
user_agent: str = "",
|
||||
error_message: str = "",
|
||||
):
|
||||
status_code: int = 200,
|
||||
response_time_ms: int = 0,
|
||||
ip_address: str = "",
|
||||
user_agent: str = "",
|
||||
error_message: str = "",
|
||||
) -> None:
|
||||
"""记录 API 调用日志"""
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
conn.execute(
|
||||
@@ -395,21 +395,21 @@ class ApiKeyManager:
|
||||
|
||||
def get_call_logs(
|
||||
self,
|
||||
api_key_id: str | None = None,
|
||||
start_date: str | None = None,
|
||||
end_date: str | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
api_key_id: str | None = None,
|
||||
start_date: str | None = None,
|
||||
end_date: str | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
) -> list[dict]:
|
||||
"""获取 API 调用日志"""
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.row_factory = sqlite3.Row
|
||||
|
||||
query = "SELECT * FROM api_call_logs WHERE 1=1"
|
||||
params = []
|
||||
query = "SELECT * FROM api_call_logs WHERE 1 = 1"
|
||||
params = []
|
||||
|
||||
if api_key_id:
|
||||
query += " AND api_key_id = ?"
|
||||
query += " AND api_key_id = ?"
|
||||
params.append(api_key_id)
|
||||
|
||||
if start_date:
|
||||
@@ -423,16 +423,16 @@ class ApiKeyManager:
|
||||
query += " ORDER BY created_at DESC LIMIT ? OFFSET ?"
|
||||
params.extend([limit, offset])
|
||||
|
||||
rows = conn.execute(query, params).fetchall()
|
||||
rows = conn.execute(query, params).fetchall()
|
||||
return [dict(row) for row in rows]
|
||||
|
||||
def get_call_stats(self, api_key_id: str | None = None, days: int = 30) -> dict:
|
||||
def get_call_stats(self, api_key_id: str | None = None, days: int = 30) -> dict:
|
||||
"""获取 API 调用统计"""
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.row_factory = sqlite3.Row
|
||||
|
||||
# 总体统计
|
||||
query = f"""
|
||||
query = f"""
|
||||
SELECT
|
||||
COUNT(*) as total_calls,
|
||||
COUNT(CASE WHEN status_code < 400 THEN 1 END) as success_calls,
|
||||
@@ -444,15 +444,15 @@ class ApiKeyManager:
|
||||
WHERE created_at >= date('now', '-{days} days')
|
||||
"""
|
||||
|
||||
params = []
|
||||
params = []
|
||||
if api_key_id:
|
||||
query = query.replace("WHERE created_at", "WHERE api_key_id = ? AND created_at")
|
||||
query = query.replace("WHERE created_at", "WHERE api_key_id = ? AND created_at")
|
||||
params.insert(0, api_key_id)
|
||||
|
||||
row = conn.execute(query, params).fetchone()
|
||||
row = conn.execute(query, params).fetchone()
|
||||
|
||||
# 按端点统计
|
||||
endpoint_query = f"""
|
||||
endpoint_query = f"""
|
||||
SELECT
|
||||
endpoint,
|
||||
method,
|
||||
@@ -462,19 +462,19 @@ class ApiKeyManager:
|
||||
WHERE created_at >= date('now', '-{days} days')
|
||||
"""
|
||||
|
||||
endpoint_params = []
|
||||
endpoint_params = []
|
||||
if api_key_id:
|
||||
endpoint_query = endpoint_query.replace(
|
||||
"WHERE created_at", "WHERE api_key_id = ? AND created_at"
|
||||
endpoint_query = endpoint_query.replace(
|
||||
"WHERE created_at", "WHERE api_key_id = ? AND created_at"
|
||||
)
|
||||
endpoint_params.insert(0, api_key_id)
|
||||
|
||||
endpoint_query += " GROUP BY endpoint, method ORDER BY calls DESC"
|
||||
|
||||
endpoint_rows = conn.execute(endpoint_query, endpoint_params).fetchall()
|
||||
endpoint_rows = conn.execute(endpoint_query, endpoint_params).fetchall()
|
||||
|
||||
# 按天统计
|
||||
daily_query = f"""
|
||||
daily_query = f"""
|
||||
SELECT
|
||||
date(created_at) as date,
|
||||
COUNT(*) as calls,
|
||||
@@ -483,16 +483,16 @@ class ApiKeyManager:
|
||||
WHERE created_at >= date('now', '-{days} days')
|
||||
"""
|
||||
|
||||
daily_params = []
|
||||
daily_params = []
|
||||
if api_key_id:
|
||||
daily_query = daily_query.replace(
|
||||
"WHERE created_at", "WHERE api_key_id = ? AND created_at"
|
||||
daily_query = daily_query.replace(
|
||||
"WHERE created_at", "WHERE api_key_id = ? AND created_at"
|
||||
)
|
||||
daily_params.insert(0, api_key_id)
|
||||
|
||||
daily_query += " GROUP BY date(created_at) ORDER BY date"
|
||||
|
||||
daily_rows = conn.execute(daily_query, daily_params).fetchall()
|
||||
daily_rows = conn.execute(daily_query, daily_params).fetchall()
|
||||
|
||||
return {
|
||||
"summary": {
|
||||
@@ -510,30 +510,30 @@ class ApiKeyManager:
|
||||
def _row_to_api_key(self, row: sqlite3.Row) -> ApiKey:
|
||||
"""将数据库行转换为 ApiKey 对象"""
|
||||
return ApiKey(
|
||||
id=row["id"],
|
||||
key_hash=row["key_hash"],
|
||||
key_preview=row["key_preview"],
|
||||
name=row["name"],
|
||||
owner_id=row["owner_id"],
|
||||
permissions=json.loads(row["permissions"]),
|
||||
rate_limit=row["rate_limit"],
|
||||
status=row["status"],
|
||||
created_at=row["created_at"],
|
||||
expires_at=row["expires_at"],
|
||||
last_used_at=row["last_used_at"],
|
||||
revoked_at=row["revoked_at"],
|
||||
revoked_reason=row["revoked_reason"],
|
||||
total_calls=row["total_calls"],
|
||||
id = row["id"],
|
||||
key_hash = row["key_hash"],
|
||||
key_preview = row["key_preview"],
|
||||
name = row["name"],
|
||||
owner_id = row["owner_id"],
|
||||
permissions = json.loads(row["permissions"]),
|
||||
rate_limit = row["rate_limit"],
|
||||
status = row["status"],
|
||||
created_at = row["created_at"],
|
||||
expires_at = row["expires_at"],
|
||||
last_used_at = row["last_used_at"],
|
||||
revoked_at = row["revoked_at"],
|
||||
revoked_reason = row["revoked_reason"],
|
||||
total_calls = row["total_calls"],
|
||||
)
|
||||
|
||||
|
||||
# 全局实例
|
||||
_api_key_manager: ApiKeyManager | None = None
|
||||
_api_key_manager: ApiKeyManager | None = None
|
||||
|
||||
|
||||
def get_api_key_manager() -> ApiKeyManager:
|
||||
"""获取 API Key 管理器实例"""
|
||||
global _api_key_manager
|
||||
if _api_key_manager is None:
|
||||
_api_key_manager = ApiKeyManager()
|
||||
_api_key_manager = ApiKeyManager()
|
||||
return _api_key_manager
|
||||
|
||||
@@ -15,29 +15,29 @@ from typing import Any
|
||||
class SharePermission(Enum):
|
||||
"""分享权限级别"""
|
||||
|
||||
READ_ONLY = "read_only" # 只读
|
||||
COMMENT = "comment" # 可评论
|
||||
EDIT = "edit" # 可编辑
|
||||
ADMIN = "admin" # 管理员
|
||||
READ_ONLY = "read_only" # 只读
|
||||
COMMENT = "comment" # 可评论
|
||||
EDIT = "edit" # 可编辑
|
||||
ADMIN = "admin" # 管理员
|
||||
|
||||
|
||||
class CommentTargetType(Enum):
|
||||
"""评论目标类型"""
|
||||
|
||||
ENTITY = "entity" # 实体评论
|
||||
RELATION = "relation" # 关系评论
|
||||
TRANSCRIPT = "transcript" # 转录文本评论
|
||||
PROJECT = "project" # 项目级评论
|
||||
ENTITY = "entity" # 实体评论
|
||||
RELATION = "relation" # 关系评论
|
||||
TRANSCRIPT = "transcript" # 转录文本评论
|
||||
PROJECT = "project" # 项目级评论
|
||||
|
||||
|
||||
class ChangeType(Enum):
|
||||
"""变更类型"""
|
||||
|
||||
CREATE = "create" # 创建
|
||||
UPDATE = "update" # 更新
|
||||
DELETE = "delete" # 删除
|
||||
MERGE = "merge" # 合并
|
||||
SPLIT = "split" # 拆分
|
||||
CREATE = "create" # 创建
|
||||
UPDATE = "update" # 更新
|
||||
DELETE = "delete" # 删除
|
||||
MERGE = "merge" # 合并
|
||||
SPLIT = "split" # 拆分
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -136,10 +136,10 @@ class TeamSpace:
|
||||
class CollaborationManager:
|
||||
"""协作管理主类"""
|
||||
|
||||
def __init__(self, db_manager=None):
|
||||
self.db = db_manager
|
||||
self._shares_cache: dict[str, ProjectShare] = {}
|
||||
self._comments_cache: dict[str, list[Comment]] = {}
|
||||
def __init__(self, db_manager = None) -> None:
|
||||
self.db = db_manager
|
||||
self._shares_cache: dict[str, ProjectShare] = {}
|
||||
self._comments_cache: dict[str, list[Comment]] = {}
|
||||
|
||||
# ============ 项目分享 ============
|
||||
|
||||
@@ -147,57 +147,57 @@ class CollaborationManager:
|
||||
self,
|
||||
project_id: str,
|
||||
created_by: str,
|
||||
permission: str = "read_only",
|
||||
expires_in_days: int | None = None,
|
||||
max_uses: int | None = None,
|
||||
password: str | None = None,
|
||||
allow_download: bool = False,
|
||||
allow_export: bool = False,
|
||||
permission: str = "read_only",
|
||||
expires_in_days: int | None = None,
|
||||
max_uses: int | None = None,
|
||||
password: str | None = None,
|
||||
allow_download: bool = False,
|
||||
allow_export: bool = False,
|
||||
) -> ProjectShare:
|
||||
"""创建项目分享链接"""
|
||||
share_id = str(uuid.uuid4())
|
||||
token = self._generate_share_token(project_id)
|
||||
share_id = str(uuid.uuid4())
|
||||
token = self._generate_share_token(project_id)
|
||||
|
||||
now = datetime.now().isoformat()
|
||||
expires_at = None
|
||||
now = datetime.now().isoformat()
|
||||
expires_at = None
|
||||
if expires_in_days:
|
||||
expires_at = (datetime.now() + timedelta(days=expires_in_days)).isoformat()
|
||||
expires_at = (datetime.now() + timedelta(days = expires_in_days)).isoformat()
|
||||
|
||||
password_hash = None
|
||||
password_hash = None
|
||||
if password:
|
||||
password_hash = hashlib.sha256(password.encode()).hexdigest()
|
||||
password_hash = hashlib.sha256(password.encode()).hexdigest()
|
||||
|
||||
share = ProjectShare(
|
||||
id=share_id,
|
||||
project_id=project_id,
|
||||
token=token,
|
||||
permission=permission,
|
||||
created_by=created_by,
|
||||
created_at=now,
|
||||
expires_at=expires_at,
|
||||
max_uses=max_uses,
|
||||
use_count=0,
|
||||
password_hash=password_hash,
|
||||
is_active=True,
|
||||
allow_download=allow_download,
|
||||
allow_export=allow_export,
|
||||
share = ProjectShare(
|
||||
id = share_id,
|
||||
project_id = project_id,
|
||||
token = token,
|
||||
permission = permission,
|
||||
created_by = created_by,
|
||||
created_at = now,
|
||||
expires_at = expires_at,
|
||||
max_uses = max_uses,
|
||||
use_count = 0,
|
||||
password_hash = password_hash,
|
||||
is_active = True,
|
||||
allow_download = allow_download,
|
||||
allow_export = allow_export,
|
||||
)
|
||||
|
||||
# 保存到数据库
|
||||
if self.db:
|
||||
self._save_share_to_db(share)
|
||||
|
||||
self._shares_cache[token] = share
|
||||
self._shares_cache[token] = share
|
||||
return share
|
||||
|
||||
def _generate_share_token(self, project_id: str) -> str:
|
||||
"""生成分享令牌"""
|
||||
data = f"{project_id}:{datetime.now().timestamp()}:{uuid.uuid4()}"
|
||||
data = f"{project_id}:{datetime.now().timestamp()}:{uuid.uuid4()}"
|
||||
return hashlib.sha256(data.encode()).hexdigest()[:32]
|
||||
|
||||
def _save_share_to_db(self, share: ProjectShare) -> None:
|
||||
"""保存分享记录到数据库"""
|
||||
cursor = self.db.conn.cursor()
|
||||
cursor = self.db.conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO project_shares
|
||||
@@ -224,12 +224,12 @@ class CollaborationManager:
|
||||
)
|
||||
self.db.conn.commit()
|
||||
|
||||
def validate_share_token(self, token: str, password: str | None = None) -> ProjectShare | None:
|
||||
def validate_share_token(self, token: str, password: str | None = None) -> ProjectShare | None:
|
||||
"""验证分享令牌"""
|
||||
# 从缓存或数据库获取
|
||||
share = self._shares_cache.get(token)
|
||||
share = self._shares_cache.get(token)
|
||||
if not share and self.db:
|
||||
share = self._get_share_from_db(token)
|
||||
share = self._get_share_from_db(token)
|
||||
|
||||
if not share:
|
||||
return None
|
||||
@@ -250,7 +250,7 @@ class CollaborationManager:
|
||||
if share.password_hash:
|
||||
if not password:
|
||||
return None
|
||||
password_hash = hashlib.sha256(password.encode()).hexdigest()
|
||||
password_hash = hashlib.sha256(password.encode()).hexdigest()
|
||||
if password_hash != share.password_hash:
|
||||
return None
|
||||
|
||||
@@ -258,63 +258,63 @@ class CollaborationManager:
|
||||
|
||||
def _get_share_from_db(self, token: str) -> ProjectShare | None:
|
||||
"""从数据库获取分享记录"""
|
||||
cursor = self.db.conn.cursor()
|
||||
cursor = self.db.conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT * FROM project_shares WHERE token = ?
|
||||
SELECT * FROM project_shares WHERE token = ?
|
||||
""",
|
||||
(token,),
|
||||
(token, ),
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
row = cursor.fetchone()
|
||||
|
||||
if not row:
|
||||
return None
|
||||
|
||||
return ProjectShare(
|
||||
id=row[0],
|
||||
project_id=row[1],
|
||||
token=row[2],
|
||||
permission=row[3],
|
||||
created_by=row[4],
|
||||
created_at=row[5],
|
||||
expires_at=row[6],
|
||||
max_uses=row[7],
|
||||
use_count=row[8],
|
||||
password_hash=row[9],
|
||||
is_active=bool(row[10]),
|
||||
allow_download=bool(row[11]),
|
||||
allow_export=bool(row[12]),
|
||||
id = row[0],
|
||||
project_id = row[1],
|
||||
token = row[2],
|
||||
permission = row[3],
|
||||
created_by = row[4],
|
||||
created_at = row[5],
|
||||
expires_at = row[6],
|
||||
max_uses = row[7],
|
||||
use_count = row[8],
|
||||
password_hash = row[9],
|
||||
is_active = bool(row[10]),
|
||||
allow_download = bool(row[11]),
|
||||
allow_export = bool(row[12]),
|
||||
)
|
||||
|
||||
def increment_share_usage(self, token: str) -> None:
|
||||
"""增加分享链接使用次数"""
|
||||
share = self._shares_cache.get(token)
|
||||
share = self._shares_cache.get(token)
|
||||
if share:
|
||||
share.use_count += 1
|
||||
|
||||
if self.db:
|
||||
cursor = self.db.conn.cursor()
|
||||
cursor = self.db.conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
UPDATE project_shares
|
||||
SET use_count = use_count + 1
|
||||
WHERE token = ?
|
||||
SET use_count = use_count + 1
|
||||
WHERE token = ?
|
||||
""",
|
||||
(token,),
|
||||
(token, ),
|
||||
)
|
||||
self.db.conn.commit()
|
||||
|
||||
def revoke_share_link(self, share_id: str, revoked_by: str) -> bool:
|
||||
"""撤销分享链接"""
|
||||
if self.db:
|
||||
cursor = self.db.conn.cursor()
|
||||
cursor = self.db.conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
UPDATE project_shares
|
||||
SET is_active = 0
|
||||
WHERE id = ?
|
||||
SET is_active = 0
|
||||
WHERE id = ?
|
||||
""",
|
||||
(share_id,),
|
||||
(share_id, ),
|
||||
)
|
||||
self.db.conn.commit()
|
||||
return cursor.rowcount > 0
|
||||
@@ -325,33 +325,33 @@ class CollaborationManager:
|
||||
if not self.db:
|
||||
return []
|
||||
|
||||
cursor = self.db.conn.cursor()
|
||||
cursor = self.db.conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT * FROM project_shares
|
||||
WHERE project_id = ?
|
||||
WHERE project_id = ?
|
||||
ORDER BY created_at DESC
|
||||
""",
|
||||
(project_id,),
|
||||
(project_id, ),
|
||||
)
|
||||
|
||||
shares = []
|
||||
shares = []
|
||||
for row in cursor.fetchall():
|
||||
shares.append(
|
||||
ProjectShare(
|
||||
id=row[0],
|
||||
project_id=row[1],
|
||||
token=row[2],
|
||||
permission=row[3],
|
||||
created_by=row[4],
|
||||
created_at=row[5],
|
||||
expires_at=row[6],
|
||||
max_uses=row[7],
|
||||
use_count=row[8],
|
||||
password_hash=row[9],
|
||||
is_active=bool(row[10]),
|
||||
allow_download=bool(row[11]),
|
||||
allow_export=bool(row[12]),
|
||||
id = row[0],
|
||||
project_id = row[1],
|
||||
token = row[2],
|
||||
permission = row[3],
|
||||
created_by = row[4],
|
||||
created_at = row[5],
|
||||
expires_at = row[6],
|
||||
max_uses = row[7],
|
||||
use_count = row[8],
|
||||
password_hash = row[9],
|
||||
is_active = bool(row[10]),
|
||||
allow_download = bool(row[11]),
|
||||
allow_export = bool(row[12]),
|
||||
)
|
||||
)
|
||||
return shares
|
||||
@@ -366,46 +366,46 @@ class CollaborationManager:
|
||||
author: str,
|
||||
author_name: str,
|
||||
content: str,
|
||||
parent_id: str | None = None,
|
||||
mentions: list[str] | None = None,
|
||||
attachments: list[dict] | None = None,
|
||||
parent_id: str | None = None,
|
||||
mentions: list[str] | None = None,
|
||||
attachments: list[dict] | None = None,
|
||||
) -> Comment:
|
||||
"""添加评论"""
|
||||
comment_id = str(uuid.uuid4())
|
||||
now = datetime.now().isoformat()
|
||||
comment_id = str(uuid.uuid4())
|
||||
now = datetime.now().isoformat()
|
||||
|
||||
comment = Comment(
|
||||
id=comment_id,
|
||||
project_id=project_id,
|
||||
target_type=target_type,
|
||||
target_id=target_id,
|
||||
parent_id=parent_id,
|
||||
author=author,
|
||||
author_name=author_name,
|
||||
content=content,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
resolved=False,
|
||||
resolved_by=None,
|
||||
resolved_at=None,
|
||||
mentions=mentions or [],
|
||||
attachments=attachments or [],
|
||||
comment = Comment(
|
||||
id = comment_id,
|
||||
project_id = project_id,
|
||||
target_type = target_type,
|
||||
target_id = target_id,
|
||||
parent_id = parent_id,
|
||||
author = author,
|
||||
author_name = author_name,
|
||||
content = content,
|
||||
created_at = now,
|
||||
updated_at = now,
|
||||
resolved = False,
|
||||
resolved_by = None,
|
||||
resolved_at = None,
|
||||
mentions = mentions or [],
|
||||
attachments = attachments or [],
|
||||
)
|
||||
|
||||
if self.db:
|
||||
self._save_comment_to_db(comment)
|
||||
|
||||
# 更新缓存
|
||||
key = f"{target_type}:{target_id}"
|
||||
key = f"{target_type}:{target_id}"
|
||||
if key not in self._comments_cache:
|
||||
self._comments_cache[key] = []
|
||||
self._comments_cache[key] = []
|
||||
self._comments_cache[key].append(comment)
|
||||
|
||||
return comment
|
||||
|
||||
def _save_comment_to_db(self, comment: Comment) -> None:
|
||||
"""保存评论到数据库"""
|
||||
cursor = self.db.conn.cursor()
|
||||
cursor = self.db.conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO comments
|
||||
@@ -435,18 +435,18 @@ class CollaborationManager:
|
||||
self.db.conn.commit()
|
||||
|
||||
def get_comments(
|
||||
self, target_type: str, target_id: str, include_resolved: bool = True
|
||||
self, target_type: str, target_id: str, include_resolved: bool = True
|
||||
) -> list[Comment]:
|
||||
"""获取评论列表"""
|
||||
if not self.db:
|
||||
return []
|
||||
|
||||
cursor = self.db.conn.cursor()
|
||||
cursor = self.db.conn.cursor()
|
||||
if include_resolved:
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT * FROM comments
|
||||
WHERE target_type = ? AND target_id = ?
|
||||
WHERE target_type = ? AND target_id = ?
|
||||
ORDER BY created_at ASC
|
||||
""",
|
||||
(target_type, target_id),
|
||||
@@ -455,13 +455,13 @@ class CollaborationManager:
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT * FROM comments
|
||||
WHERE target_type = ? AND target_id = ? AND resolved = 0
|
||||
WHERE target_type = ? AND target_id = ? AND resolved = 0
|
||||
ORDER BY created_at ASC
|
||||
""",
|
||||
(target_type, target_id),
|
||||
)
|
||||
|
||||
comments = []
|
||||
comments = []
|
||||
for row in cursor.fetchall():
|
||||
comments.append(self._row_to_comment(row))
|
||||
return comments
|
||||
@@ -469,21 +469,21 @@ class CollaborationManager:
|
||||
def _row_to_comment(self, row) -> Comment:
|
||||
"""将数据库行转换为Comment对象"""
|
||||
return Comment(
|
||||
id=row[0],
|
||||
project_id=row[1],
|
||||
target_type=row[2],
|
||||
target_id=row[3],
|
||||
parent_id=row[4],
|
||||
author=row[5],
|
||||
author_name=row[6],
|
||||
content=row[7],
|
||||
created_at=row[8],
|
||||
updated_at=row[9],
|
||||
resolved=bool(row[10]),
|
||||
resolved_by=row[11],
|
||||
resolved_at=row[12],
|
||||
mentions=json.loads(row[13]) if row[13] else [],
|
||||
attachments=json.loads(row[14]) if row[14] else [],
|
||||
id = row[0],
|
||||
project_id = row[1],
|
||||
target_type = row[2],
|
||||
target_id = row[3],
|
||||
parent_id = row[4],
|
||||
author = row[5],
|
||||
author_name = row[6],
|
||||
content = row[7],
|
||||
created_at = row[8],
|
||||
updated_at = row[9],
|
||||
resolved = bool(row[10]),
|
||||
resolved_by = row[11],
|
||||
resolved_at = row[12],
|
||||
mentions = json.loads(row[13]) if row[13] else [],
|
||||
attachments = json.loads(row[14]) if row[14] else [],
|
||||
)
|
||||
|
||||
def update_comment(self, comment_id: str, content: str, updated_by: str) -> Comment | None:
|
||||
@@ -491,13 +491,13 @@ class CollaborationManager:
|
||||
if not self.db:
|
||||
return None
|
||||
|
||||
now = datetime.now().isoformat()
|
||||
cursor = self.db.conn.cursor()
|
||||
now = datetime.now().isoformat()
|
||||
cursor = self.db.conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
UPDATE comments
|
||||
SET content = ?, updated_at = ?
|
||||
WHERE id = ? AND author = ?
|
||||
SET content = ?, updated_at = ?
|
||||
WHERE id = ? AND author = ?
|
||||
""",
|
||||
(content, now, comment_id, updated_by),
|
||||
)
|
||||
@@ -509,9 +509,9 @@ class CollaborationManager:
|
||||
|
||||
def _get_comment_by_id(self, comment_id: str) -> Comment | None:
|
||||
"""根据ID获取评论"""
|
||||
cursor = self.db.conn.cursor()
|
||||
cursor.execute("SELECT * FROM comments WHERE id = ?", (comment_id,))
|
||||
row = cursor.fetchone()
|
||||
cursor = self.db.conn.cursor()
|
||||
cursor.execute("SELECT * FROM comments WHERE id = ?", (comment_id, ))
|
||||
row = cursor.fetchone()
|
||||
if row:
|
||||
return self._row_to_comment(row)
|
||||
return None
|
||||
@@ -521,13 +521,13 @@ class CollaborationManager:
|
||||
if not self.db:
|
||||
return False
|
||||
|
||||
now = datetime.now().isoformat()
|
||||
cursor = self.db.conn.cursor()
|
||||
now = datetime.now().isoformat()
|
||||
cursor = self.db.conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
UPDATE comments
|
||||
SET resolved = 1, resolved_by = ?, resolved_at = ?
|
||||
WHERE id = ?
|
||||
SET resolved = 1, resolved_by = ?, resolved_at = ?
|
||||
WHERE id = ?
|
||||
""",
|
||||
(resolved_by, now, comment_id),
|
||||
)
|
||||
@@ -539,13 +539,13 @@ class CollaborationManager:
|
||||
if not self.db:
|
||||
return False
|
||||
|
||||
cursor = self.db.conn.cursor()
|
||||
cursor = self.db.conn.cursor()
|
||||
# 只允许作者或管理员删除
|
||||
cursor.execute(
|
||||
"""
|
||||
DELETE FROM comments
|
||||
WHERE id = ? AND (author = ? OR ? IN (
|
||||
SELECT created_by FROM projects WHERE id = comments.project_id
|
||||
WHERE id = ? AND (author = ? OR ? IN (
|
||||
SELECT created_by FROM projects WHERE id = comments.project_id
|
||||
))
|
||||
""",
|
||||
(comment_id, deleted_by, deleted_by),
|
||||
@@ -554,24 +554,24 @@ class CollaborationManager:
|
||||
return cursor.rowcount > 0
|
||||
|
||||
def get_project_comments(
|
||||
self, project_id: str, limit: int = 50, offset: int = 0
|
||||
self, project_id: str, limit: int = 50, offset: int = 0
|
||||
) -> list[Comment]:
|
||||
"""获取项目下的所有评论"""
|
||||
if not self.db:
|
||||
return []
|
||||
|
||||
cursor = self.db.conn.cursor()
|
||||
cursor = self.db.conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT * FROM comments
|
||||
WHERE project_id = ?
|
||||
WHERE project_id = ?
|
||||
ORDER BY created_at DESC
|
||||
LIMIT ? OFFSET ?
|
||||
""",
|
||||
(project_id, limit, offset),
|
||||
)
|
||||
|
||||
comments = []
|
||||
comments = []
|
||||
for row in cursor.fetchall():
|
||||
comments.append(self._row_to_comment(row))
|
||||
return comments
|
||||
@@ -587,32 +587,32 @@ class CollaborationManager:
|
||||
entity_name: str,
|
||||
changed_by: str,
|
||||
changed_by_name: str,
|
||||
old_value: dict | None = None,
|
||||
new_value: dict | None = None,
|
||||
description: str = "",
|
||||
session_id: str | None = None,
|
||||
old_value: dict | None = None,
|
||||
new_value: dict | None = None,
|
||||
description: str = "",
|
||||
session_id: str | None = None,
|
||||
) -> ChangeRecord:
|
||||
"""记录变更"""
|
||||
record_id = str(uuid.uuid4())
|
||||
now = datetime.now().isoformat()
|
||||
record_id = str(uuid.uuid4())
|
||||
now = datetime.now().isoformat()
|
||||
|
||||
record = ChangeRecord(
|
||||
id=record_id,
|
||||
project_id=project_id,
|
||||
change_type=change_type,
|
||||
entity_type=entity_type,
|
||||
entity_id=entity_id,
|
||||
entity_name=entity_name,
|
||||
changed_by=changed_by,
|
||||
changed_by_name=changed_by_name,
|
||||
changed_at=now,
|
||||
old_value=old_value,
|
||||
new_value=new_value,
|
||||
description=description,
|
||||
session_id=session_id,
|
||||
reverted=False,
|
||||
reverted_at=None,
|
||||
reverted_by=None,
|
||||
record = ChangeRecord(
|
||||
id = record_id,
|
||||
project_id = project_id,
|
||||
change_type = change_type,
|
||||
entity_type = entity_type,
|
||||
entity_id = entity_id,
|
||||
entity_name = entity_name,
|
||||
changed_by = changed_by,
|
||||
changed_by_name = changed_by_name,
|
||||
changed_at = now,
|
||||
old_value = old_value,
|
||||
new_value = new_value,
|
||||
description = description,
|
||||
session_id = session_id,
|
||||
reverted = False,
|
||||
reverted_at = None,
|
||||
reverted_by = None,
|
||||
)
|
||||
|
||||
if self.db:
|
||||
@@ -622,7 +622,7 @@ class CollaborationManager:
|
||||
|
||||
def _save_change_to_db(self, record: ChangeRecord) -> None:
|
||||
"""保存变更记录到数据库"""
|
||||
cursor = self.db.conn.cursor()
|
||||
cursor = self.db.conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO change_history
|
||||
@@ -655,22 +655,22 @@ class CollaborationManager:
|
||||
def get_change_history(
|
||||
self,
|
||||
project_id: str,
|
||||
entity_type: str | None = None,
|
||||
entity_id: str | None = None,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
entity_type: str | None = None,
|
||||
entity_id: str | None = None,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
) -> list[ChangeRecord]:
|
||||
"""获取变更历史"""
|
||||
if not self.db:
|
||||
return []
|
||||
|
||||
cursor = self.db.conn.cursor()
|
||||
cursor = self.db.conn.cursor()
|
||||
|
||||
if entity_type and entity_id:
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT * FROM change_history
|
||||
WHERE project_id = ? AND entity_type = ? AND entity_id = ?
|
||||
WHERE project_id = ? AND entity_type = ? AND entity_id = ?
|
||||
ORDER BY changed_at DESC
|
||||
LIMIT ? OFFSET ?
|
||||
""",
|
||||
@@ -680,7 +680,7 @@ class CollaborationManager:
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT * FROM change_history
|
||||
WHERE project_id = ? AND entity_type = ?
|
||||
WHERE project_id = ? AND entity_type = ?
|
||||
ORDER BY changed_at DESC
|
||||
LIMIT ? OFFSET ?
|
||||
""",
|
||||
@@ -690,14 +690,14 @@ class CollaborationManager:
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT * FROM change_history
|
||||
WHERE project_id = ?
|
||||
WHERE project_id = ?
|
||||
ORDER BY changed_at DESC
|
||||
LIMIT ? OFFSET ?
|
||||
""",
|
||||
(project_id, limit, offset),
|
||||
)
|
||||
|
||||
records = []
|
||||
records = []
|
||||
for row in cursor.fetchall():
|
||||
records.append(self._row_to_change_record(row))
|
||||
return records
|
||||
@@ -705,22 +705,22 @@ class CollaborationManager:
|
||||
def _row_to_change_record(self, row) -> ChangeRecord:
|
||||
"""将数据库行转换为ChangeRecord对象"""
|
||||
return ChangeRecord(
|
||||
id=row[0],
|
||||
project_id=row[1],
|
||||
change_type=row[2],
|
||||
entity_type=row[3],
|
||||
entity_id=row[4],
|
||||
entity_name=row[5],
|
||||
changed_by=row[6],
|
||||
changed_by_name=row[7],
|
||||
changed_at=row[8],
|
||||
old_value=json.loads(row[9]) if row[9] else None,
|
||||
new_value=json.loads(row[10]) if row[10] else None,
|
||||
description=row[11],
|
||||
session_id=row[12],
|
||||
reverted=bool(row[13]),
|
||||
reverted_at=row[14],
|
||||
reverted_by=row[15],
|
||||
id = row[0],
|
||||
project_id = row[1],
|
||||
change_type = row[2],
|
||||
entity_type = row[3],
|
||||
entity_id = row[4],
|
||||
entity_name = row[5],
|
||||
changed_by = row[6],
|
||||
changed_by_name = row[7],
|
||||
changed_at = row[8],
|
||||
old_value = json.loads(row[9]) if row[9] else None,
|
||||
new_value = json.loads(row[10]) if row[10] else None,
|
||||
description = row[11],
|
||||
session_id = row[12],
|
||||
reverted = bool(row[13]),
|
||||
reverted_at = row[14],
|
||||
reverted_by = row[15],
|
||||
)
|
||||
|
||||
def get_entity_version_history(self, entity_type: str, entity_id: str) -> list[ChangeRecord]:
|
||||
@@ -728,17 +728,17 @@ class CollaborationManager:
|
||||
if not self.db:
|
||||
return []
|
||||
|
||||
cursor = self.db.conn.cursor()
|
||||
cursor = self.db.conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT * FROM change_history
|
||||
WHERE entity_type = ? AND entity_id = ?
|
||||
WHERE entity_type = ? AND entity_id = ?
|
||||
ORDER BY changed_at ASC
|
||||
""",
|
||||
(entity_type, entity_id),
|
||||
)
|
||||
|
||||
records = []
|
||||
records = []
|
||||
for row in cursor.fetchall():
|
||||
records.append(self._row_to_change_record(row))
|
||||
return records
|
||||
@@ -748,13 +748,13 @@ class CollaborationManager:
|
||||
if not self.db:
|
||||
return False
|
||||
|
||||
now = datetime.now().isoformat()
|
||||
cursor = self.db.conn.cursor()
|
||||
now = datetime.now().isoformat()
|
||||
cursor = self.db.conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
UPDATE change_history
|
||||
SET reverted = 1, reverted_at = ?, reverted_by = ?
|
||||
WHERE id = ? AND reverted = 0
|
||||
SET reverted = 1, reverted_at = ?, reverted_by = ?
|
||||
WHERE id = ? AND reverted = 0
|
||||
""",
|
||||
(now, reverted_by, record_id),
|
||||
)
|
||||
@@ -766,49 +766,49 @@ class CollaborationManager:
|
||||
if not self.db:
|
||||
return {}
|
||||
|
||||
cursor = self.db.conn.cursor()
|
||||
cursor = self.db.conn.cursor()
|
||||
|
||||
# 总变更数
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT COUNT(*) FROM change_history WHERE project_id = ?
|
||||
SELECT COUNT(*) FROM change_history WHERE project_id = ?
|
||||
""",
|
||||
(project_id,),
|
||||
(project_id, ),
|
||||
)
|
||||
total_changes = cursor.fetchone()[0]
|
||||
total_changes = cursor.fetchone()[0]
|
||||
|
||||
# 按类型统计
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT change_type, COUNT(*) FROM change_history
|
||||
WHERE project_id = ? GROUP BY change_type
|
||||
WHERE project_id = ? GROUP BY change_type
|
||||
""",
|
||||
(project_id,),
|
||||
(project_id, ),
|
||||
)
|
||||
type_counts = {row[0]: row[1] for row in cursor.fetchall()}
|
||||
type_counts = {row[0]: row[1] for row in cursor.fetchall()}
|
||||
|
||||
# 按实体类型统计
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT entity_type, COUNT(*) FROM change_history
|
||||
WHERE project_id = ? GROUP BY entity_type
|
||||
WHERE project_id = ? GROUP BY entity_type
|
||||
""",
|
||||
(project_id,),
|
||||
(project_id, ),
|
||||
)
|
||||
entity_type_counts = {row[0]: row[1] for row in cursor.fetchall()}
|
||||
entity_type_counts = {row[0]: row[1] for row in cursor.fetchall()}
|
||||
|
||||
# 最近活跃的用户
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT changed_by_name, COUNT(*) as count FROM change_history
|
||||
WHERE project_id = ?
|
||||
WHERE project_id = ?
|
||||
GROUP BY changed_by_name
|
||||
ORDER BY count DESC
|
||||
LIMIT 5
|
||||
""",
|
||||
(project_id,),
|
||||
(project_id, ),
|
||||
)
|
||||
top_contributors = [{"name": row[0], "changes": row[1]} for row in cursor.fetchall()]
|
||||
top_contributors = [{"name": row[0], "changes": row[1]} for row in cursor.fetchall()]
|
||||
|
||||
return {
|
||||
"total_changes": total_changes,
|
||||
@@ -827,27 +827,27 @@ class CollaborationManager:
|
||||
user_email: str,
|
||||
role: str,
|
||||
invited_by: str,
|
||||
permissions: list[str] | None = None,
|
||||
permissions: list[str] | None = None,
|
||||
) -> TeamMember:
|
||||
"""添加团队成员"""
|
||||
member_id = str(uuid.uuid4())
|
||||
now = datetime.now().isoformat()
|
||||
member_id = str(uuid.uuid4())
|
||||
now = datetime.now().isoformat()
|
||||
|
||||
# 根据角色设置默认权限
|
||||
if permissions is None:
|
||||
permissions = self._get_default_permissions(role)
|
||||
permissions = self._get_default_permissions(role)
|
||||
|
||||
member = TeamMember(
|
||||
id=member_id,
|
||||
project_id=project_id,
|
||||
user_id=user_id,
|
||||
user_name=user_name,
|
||||
user_email=user_email,
|
||||
role=role,
|
||||
joined_at=now,
|
||||
invited_by=invited_by,
|
||||
last_active_at=None,
|
||||
permissions=permissions,
|
||||
member = TeamMember(
|
||||
id = member_id,
|
||||
project_id = project_id,
|
||||
user_id = user_id,
|
||||
user_name = user_name,
|
||||
user_email = user_email,
|
||||
role = role,
|
||||
joined_at = now,
|
||||
invited_by = invited_by,
|
||||
last_active_at = None,
|
||||
permissions = permissions,
|
||||
)
|
||||
|
||||
if self.db:
|
||||
@@ -857,7 +857,7 @@ class CollaborationManager:
|
||||
|
||||
def _get_default_permissions(self, role: str) -> list[str]:
|
||||
"""获取角色的默认权限"""
|
||||
permissions_map = {
|
||||
permissions_map = {
|
||||
"owner": ["read", "write", "delete", "share", "admin", "export"],
|
||||
"admin": ["read", "write", "delete", "share", "export"],
|
||||
"editor": ["read", "write", "export"],
|
||||
@@ -868,7 +868,7 @@ class CollaborationManager:
|
||||
|
||||
def _save_member_to_db(self, member: TeamMember) -> None:
|
||||
"""保存成员到数据库"""
|
||||
cursor = self.db.conn.cursor()
|
||||
cursor = self.db.conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO team_members
|
||||
@@ -896,16 +896,16 @@ class CollaborationManager:
|
||||
if not self.db:
|
||||
return []
|
||||
|
||||
cursor = self.db.conn.cursor()
|
||||
cursor = self.db.conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT * FROM team_members WHERE project_id = ?
|
||||
SELECT * FROM team_members WHERE project_id = ?
|
||||
ORDER BY joined_at ASC
|
||||
""",
|
||||
(project_id,),
|
||||
(project_id, ),
|
||||
)
|
||||
|
||||
members = []
|
||||
members = []
|
||||
for row in cursor.fetchall():
|
||||
members.append(self._row_to_team_member(row))
|
||||
return members
|
||||
@@ -913,16 +913,16 @@ class CollaborationManager:
|
||||
def _row_to_team_member(self, row) -> TeamMember:
|
||||
"""将数据库行转换为TeamMember对象"""
|
||||
return TeamMember(
|
||||
id=row[0],
|
||||
project_id=row[1],
|
||||
user_id=row[2],
|
||||
user_name=row[3],
|
||||
user_email=row[4],
|
||||
role=row[5],
|
||||
joined_at=row[6],
|
||||
invited_by=row[7],
|
||||
last_active_at=row[8],
|
||||
permissions=json.loads(row[9]) if row[9] else [],
|
||||
id = row[0],
|
||||
project_id = row[1],
|
||||
user_id = row[2],
|
||||
user_name = row[3],
|
||||
user_email = row[4],
|
||||
role = row[5],
|
||||
joined_at = row[6],
|
||||
invited_by = row[7],
|
||||
last_active_at = row[8],
|
||||
permissions = json.loads(row[9]) if row[9] else [],
|
||||
)
|
||||
|
||||
def update_member_role(self, member_id: str, new_role: str, updated_by: str) -> bool:
|
||||
@@ -930,13 +930,13 @@ class CollaborationManager:
|
||||
if not self.db:
|
||||
return False
|
||||
|
||||
permissions = self._get_default_permissions(new_role)
|
||||
cursor = self.db.conn.cursor()
|
||||
permissions = self._get_default_permissions(new_role)
|
||||
cursor = self.db.conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
UPDATE team_members
|
||||
SET role = ?, permissions = ?
|
||||
WHERE id = ?
|
||||
SET role = ?, permissions = ?
|
||||
WHERE id = ?
|
||||
""",
|
||||
(new_role, json.dumps(permissions), member_id),
|
||||
)
|
||||
@@ -948,8 +948,8 @@ class CollaborationManager:
|
||||
if not self.db:
|
||||
return False
|
||||
|
||||
cursor = self.db.conn.cursor()
|
||||
cursor.execute("DELETE FROM team_members WHERE id = ?", (member_id,))
|
||||
cursor = self.db.conn.cursor()
|
||||
cursor.execute("DELETE FROM team_members WHERE id = ?", (member_id, ))
|
||||
self.db.conn.commit()
|
||||
return cursor.rowcount > 0
|
||||
|
||||
@@ -958,20 +958,20 @@ class CollaborationManager:
|
||||
if not self.db:
|
||||
return False
|
||||
|
||||
cursor = self.db.conn.cursor()
|
||||
cursor = self.db.conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT permissions FROM team_members
|
||||
WHERE project_id = ? AND user_id = ?
|
||||
WHERE project_id = ? AND user_id = ?
|
||||
""",
|
||||
(project_id, user_id),
|
||||
)
|
||||
|
||||
row = cursor.fetchone()
|
||||
row = cursor.fetchone()
|
||||
if not row:
|
||||
return False
|
||||
|
||||
permissions = json.loads(row[0]) if row[0] else []
|
||||
permissions = json.loads(row[0]) if row[0] else []
|
||||
return permission in permissions or "admin" in permissions
|
||||
|
||||
def update_last_active(self, project_id: str, user_id: str) -> None:
|
||||
@@ -979,13 +979,13 @@ class CollaborationManager:
|
||||
if not self.db:
|
||||
return
|
||||
|
||||
now = datetime.now().isoformat()
|
||||
cursor = self.db.conn.cursor()
|
||||
now = datetime.now().isoformat()
|
||||
cursor = self.db.conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
UPDATE team_members
|
||||
SET last_active_at = ?
|
||||
WHERE project_id = ? AND user_id = ?
|
||||
SET last_active_at = ?
|
||||
WHERE project_id = ? AND user_id = ?
|
||||
""",
|
||||
(now, project_id, user_id),
|
||||
)
|
||||
@@ -993,12 +993,12 @@ class CollaborationManager:
|
||||
|
||||
|
||||
# 全局协作管理器实例
|
||||
_collaboration_manager = None
|
||||
_collaboration_manager = None
|
||||
|
||||
|
||||
def get_collaboration_manager(db_manager=None) -> None:
|
||||
def get_collaboration_manager(db_manager = None) -> None:
|
||||
"""获取协作管理器单例"""
|
||||
global _collaboration_manager
|
||||
if _collaboration_manager is None:
|
||||
_collaboration_manager = CollaborationManager(db_manager)
|
||||
_collaboration_manager = CollaborationManager(db_manager)
|
||||
return _collaboration_manager
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -11,8 +11,8 @@ import os
|
||||
class DocumentProcessor:
|
||||
"""文档处理器 - 提取 PDF/DOCX 文本"""
|
||||
|
||||
def __init__(self):
|
||||
self.supported_formats = {
|
||||
def __init__(self) -> None:
|
||||
self.supported_formats = {
|
||||
".pdf": self._extract_pdf,
|
||||
".docx": self._extract_docx,
|
||||
".doc": self._extract_docx,
|
||||
@@ -31,18 +31,18 @@ class DocumentProcessor:
|
||||
Returns:
|
||||
{"text": "提取的文本内容", "format": "文件格式"}
|
||||
"""
|
||||
ext = os.path.splitext(filename.lower())[1]
|
||||
ext = os.path.splitext(filename.lower())[1]
|
||||
|
||||
if ext not in self.supported_formats:
|
||||
raise ValueError(
|
||||
f"Unsupported file format: {ext}. Supported: {list(self.supported_formats.keys())}"
|
||||
)
|
||||
|
||||
extractor = self.supported_formats[ext]
|
||||
text = extractor(content)
|
||||
extractor = self.supported_formats[ext]
|
||||
text = extractor(content)
|
||||
|
||||
# 清理文本
|
||||
text = self._clean_text(text)
|
||||
text = self._clean_text(text)
|
||||
|
||||
return {"text": text, "format": ext, "filename": filename}
|
||||
|
||||
@@ -51,12 +51,12 @@ class DocumentProcessor:
|
||||
try:
|
||||
import PyPDF2
|
||||
|
||||
pdf_file = io.BytesIO(content)
|
||||
reader = PyPDF2.PdfReader(pdf_file)
|
||||
pdf_file = io.BytesIO(content)
|
||||
reader = PyPDF2.PdfReader(pdf_file)
|
||||
|
||||
text_parts = []
|
||||
text_parts = []
|
||||
for page in reader.pages:
|
||||
page_text = page.extract_text()
|
||||
page_text = page.extract_text()
|
||||
if page_text:
|
||||
text_parts.append(page_text)
|
||||
|
||||
@@ -66,10 +66,10 @@ class DocumentProcessor:
|
||||
try:
|
||||
import pdfplumber
|
||||
|
||||
text_parts = []
|
||||
text_parts = []
|
||||
with pdfplumber.open(io.BytesIO(content)) as pdf:
|
||||
for page in pdf.pages:
|
||||
page_text = page.extract_text()
|
||||
page_text = page.extract_text()
|
||||
if page_text:
|
||||
text_parts.append(page_text)
|
||||
return "\n\n".join(text_parts)
|
||||
@@ -85,10 +85,10 @@ class DocumentProcessor:
|
||||
try:
|
||||
import docx
|
||||
|
||||
doc_file = io.BytesIO(content)
|
||||
doc = docx.Document(doc_file)
|
||||
doc_file = io.BytesIO(content)
|
||||
doc = docx.Document(doc_file)
|
||||
|
||||
text_parts = []
|
||||
text_parts = []
|
||||
for para in doc.paragraphs:
|
||||
if para.text.strip():
|
||||
text_parts.append(para.text)
|
||||
@@ -96,7 +96,7 @@ class DocumentProcessor:
|
||||
# 提取表格中的文本
|
||||
for table in doc.tables:
|
||||
for row in table.rows:
|
||||
row_text = []
|
||||
row_text = []
|
||||
for cell in row.cells:
|
||||
if cell.text.strip():
|
||||
row_text.append(cell.text.strip())
|
||||
@@ -114,7 +114,7 @@ class DocumentProcessor:
|
||||
def _extract_txt(self, content: bytes) -> str:
|
||||
"""提取纯文本"""
|
||||
# 尝试多种编码
|
||||
encodings = ["utf-8", "gbk", "gb2312", "latin-1"]
|
||||
encodings = ["utf-8", "gbk", "gb2312", "latin-1"]
|
||||
|
||||
for encoding in encodings:
|
||||
try:
|
||||
@@ -123,7 +123,7 @@ class DocumentProcessor:
|
||||
continue
|
||||
|
||||
# 如果都失败了,使用 latin-1 并忽略错误
|
||||
return content.decode("latin-1", errors="ignore")
|
||||
return content.decode("latin-1", errors = "ignore")
|
||||
|
||||
def _clean_text(self, text: str) -> str:
|
||||
"""清理提取的文本"""
|
||||
@@ -131,29 +131,29 @@ class DocumentProcessor:
|
||||
return ""
|
||||
|
||||
# 移除多余的空白字符
|
||||
lines = text.split("\n")
|
||||
cleaned_lines = []
|
||||
lines = text.split("\n")
|
||||
cleaned_lines = []
|
||||
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
line = line.strip()
|
||||
# 移除空行,但保留段落分隔
|
||||
if line:
|
||||
cleaned_lines.append(line)
|
||||
|
||||
# 合并行,保留段落结构
|
||||
text = "\n\n".join(cleaned_lines)
|
||||
text = "\n\n".join(cleaned_lines)
|
||||
|
||||
# 移除多余的空格
|
||||
text = " ".join(text.split())
|
||||
text = " ".join(text.split())
|
||||
|
||||
# 移除控制字符
|
||||
text = "".join(char for char in text if ord(char) >= 32 or char in "\n\r\t")
|
||||
text = "".join(char for char in text if ord(char) >= 32 or char in "\n\r\t")
|
||||
|
||||
return text.strip()
|
||||
|
||||
def is_supported(self, filename: str) -> bool:
|
||||
"""检查文件格式是否支持"""
|
||||
ext = os.path.splitext(filename.lower())[1]
|
||||
ext = os.path.splitext(filename.lower())[1]
|
||||
return ext in self.supported_formats
|
||||
|
||||
|
||||
@@ -165,7 +165,7 @@ class SimpleTextExtractor:
|
||||
|
||||
def extract(self, content: bytes, filename: str) -> str:
|
||||
"""尝试提取文本"""
|
||||
encodings = ["utf-8", "gbk", "latin-1"]
|
||||
encodings = ["utf-8", "gbk", "latin-1"]
|
||||
|
||||
for encoding in encodings:
|
||||
try:
|
||||
@@ -173,15 +173,15 @@ class SimpleTextExtractor:
|
||||
except UnicodeDecodeError:
|
||||
continue
|
||||
|
||||
return content.decode("latin-1", errors="ignore")
|
||||
return content.decode("latin-1", errors = "ignore")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 测试
|
||||
processor = DocumentProcessor()
|
||||
processor = DocumentProcessor()
|
||||
|
||||
# 测试文本提取
|
||||
test_text = "Hello World\n\nThis is a test document.\n\nMultiple paragraphs."
|
||||
result = processor.process(test_text.encode("utf-8"), "test.txt")
|
||||
test_text = "Hello World\n\nThis is a test document.\n\nMultiple paragraphs."
|
||||
result = processor.process(test_text.encode("utf-8"), "test.txt")
|
||||
print(f"Text extraction test: {len(result['text'])} chars")
|
||||
print(result["text"][:100])
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -12,8 +12,8 @@ import httpx
|
||||
import numpy as np
|
||||
|
||||
# API Keys
|
||||
KIMI_API_KEY = os.getenv("KIMI_API_KEY", "")
|
||||
KIMI_BASE_URL = os.getenv("KIMI_BASE_URL", "https://api.kimi.com/coding")
|
||||
KIMI_API_KEY = os.getenv("KIMI_API_KEY", "")
|
||||
KIMI_BASE_URL = os.getenv("KIMI_BASE_URL", "https://api.kimi.com/coding")
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -27,9 +27,9 @@ class EntityEmbedding:
|
||||
class EntityAligner:
|
||||
"""实体对齐器 - 使用 embedding 进行相似度匹配"""
|
||||
|
||||
def __init__(self, similarity_threshold: float = 0.85):
|
||||
self.similarity_threshold = similarity_threshold
|
||||
self.embedding_cache: dict[str, list[float]] = {}
|
||||
def __init__(self, similarity_threshold: float = 0.85) -> None:
|
||||
self.similarity_threshold = similarity_threshold
|
||||
self.embedding_cache: dict[str, list[float]] = {}
|
||||
|
||||
def get_embedding(self, text: str) -> list[float] | None:
|
||||
"""
|
||||
@@ -45,25 +45,25 @@ class EntityAligner:
|
||||
return None
|
||||
|
||||
# 检查缓存
|
||||
cache_key = hash(text)
|
||||
cache_key = hash(text)
|
||||
if cache_key in self.embedding_cache:
|
||||
return self.embedding_cache[cache_key]
|
||||
|
||||
try:
|
||||
response = httpx.post(
|
||||
response = httpx.post(
|
||||
f"{KIMI_BASE_URL}/v1/embeddings",
|
||||
headers={
|
||||
headers = {
|
||||
"Authorization": f"Bearer {KIMI_API_KEY}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json={"model": "k2p5", "input": text[:500]}, # 限制长度
|
||||
timeout=30.0,
|
||||
json = {"model": "k2p5", "input": text[:500]}, # 限制长度
|
||||
timeout = 30.0,
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
result = response.json()
|
||||
|
||||
embedding = result["data"][0]["embedding"]
|
||||
self.embedding_cache[cache_key] = embedding
|
||||
embedding = result["data"][0]["embedding"]
|
||||
self.embedding_cache[cache_key] = embedding
|
||||
return embedding
|
||||
|
||||
except (httpx.HTTPError, json.JSONDecodeError, KeyError) as e:
|
||||
@@ -81,20 +81,20 @@ class EntityAligner:
|
||||
Returns:
|
||||
相似度分数 (0-1)
|
||||
"""
|
||||
vec1 = np.array(embedding1)
|
||||
vec2 = np.array(embedding2)
|
||||
vec1 = np.array(embedding1)
|
||||
vec2 = np.array(embedding2)
|
||||
|
||||
# 余弦相似度
|
||||
dot_product = np.dot(vec1, vec2)
|
||||
norm1 = np.linalg.norm(vec1)
|
||||
norm2 = np.linalg.norm(vec2)
|
||||
dot_product = np.dot(vec1, vec2)
|
||||
norm1 = np.linalg.norm(vec1)
|
||||
norm2 = np.linalg.norm(vec2)
|
||||
|
||||
if norm1 == 0 or norm2 == 0:
|
||||
return 0.0
|
||||
|
||||
return float(dot_product / (norm1 * norm2))
|
||||
|
||||
def get_entity_text(self, name: str, definition: str = "") -> str:
|
||||
def get_entity_text(self, name: str, definition: str = "") -> str:
|
||||
"""
|
||||
构建用于 embedding 的实体文本
|
||||
|
||||
@@ -113,9 +113,9 @@ class EntityAligner:
|
||||
self,
|
||||
project_id: str,
|
||||
name: str,
|
||||
definition: str = "",
|
||||
exclude_id: str | None = None,
|
||||
threshold: float | None = None,
|
||||
definition: str = "",
|
||||
exclude_id: str | None = None,
|
||||
threshold: float | None = None,
|
||||
) -> object | None:
|
||||
"""
|
||||
查找相似的实体
|
||||
@@ -131,54 +131,54 @@ class EntityAligner:
|
||||
相似的实体或 None
|
||||
"""
|
||||
if threshold is None:
|
||||
threshold = self.similarity_threshold
|
||||
threshold = self.similarity_threshold
|
||||
|
||||
try:
|
||||
from db_manager import get_db_manager
|
||||
|
||||
db = get_db_manager()
|
||||
db = get_db_manager()
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
# 获取项目的所有实体
|
||||
entities = db.get_all_entities_for_embedding(project_id)
|
||||
entities = db.get_all_entities_for_embedding(project_id)
|
||||
|
||||
if not entities:
|
||||
return None
|
||||
|
||||
# 获取查询实体的 embedding
|
||||
query_text = self.get_entity_text(name, definition)
|
||||
query_embedding = self.get_embedding(query_text)
|
||||
query_text = self.get_entity_text(name, definition)
|
||||
query_embedding = self.get_embedding(query_text)
|
||||
|
||||
if query_embedding is None:
|
||||
# 如果 embedding API 失败,回退到简单匹配
|
||||
return self._fallback_similarity_match(entities, name, exclude_id)
|
||||
|
||||
best_match = None
|
||||
best_score = threshold
|
||||
best_match = None
|
||||
best_score = threshold
|
||||
|
||||
for entity in entities:
|
||||
if exclude_id and entity.id == exclude_id:
|
||||
continue
|
||||
|
||||
# 获取实体的 embedding
|
||||
entity_text = self.get_entity_text(entity.name, entity.definition)
|
||||
entity_embedding = self.get_embedding(entity_text)
|
||||
entity_text = self.get_entity_text(entity.name, entity.definition)
|
||||
entity_embedding = self.get_embedding(entity_text)
|
||||
|
||||
if entity_embedding is None:
|
||||
continue
|
||||
|
||||
# 计算相似度
|
||||
similarity = self.compute_similarity(query_embedding, entity_embedding)
|
||||
similarity = self.compute_similarity(query_embedding, entity_embedding)
|
||||
|
||||
if similarity > best_score:
|
||||
best_score = similarity
|
||||
best_match = entity
|
||||
best_score = similarity
|
||||
best_match = entity
|
||||
|
||||
return best_match
|
||||
|
||||
def _fallback_similarity_match(
|
||||
self, entities: list[object], name: str, exclude_id: str | None = None
|
||||
self, entities: list[object], name: str, exclude_id: str | None = None
|
||||
) -> object | None:
|
||||
"""
|
||||
回退到简单的相似度匹配(不使用 embedding)
|
||||
@@ -191,7 +191,7 @@ class EntityAligner:
|
||||
Returns:
|
||||
最相似的实体或 None
|
||||
"""
|
||||
name_lower = name.lower()
|
||||
name_lower = name.lower()
|
||||
|
||||
# 1. 精确匹配
|
||||
for entity in entities:
|
||||
@@ -212,7 +212,7 @@ class EntityAligner:
|
||||
return None
|
||||
|
||||
def batch_align_entities(
|
||||
self, project_id: str, new_entities: list[dict], threshold: float | None = None
|
||||
self, project_id: str, new_entities: list[dict], threshold: float | None = None
|
||||
) -> list[dict]:
|
||||
"""
|
||||
批量对齐实体
|
||||
@@ -226,16 +226,16 @@ class EntityAligner:
|
||||
对齐结果列表 [{"new_entity": {...}, "matched_entity": {...}, "similarity": 0.9}]
|
||||
"""
|
||||
if threshold is None:
|
||||
threshold = self.similarity_threshold
|
||||
threshold = self.similarity_threshold
|
||||
|
||||
results = []
|
||||
results = []
|
||||
|
||||
for new_ent in new_entities:
|
||||
matched = self.find_similar_entity(
|
||||
project_id, new_ent["name"], new_ent.get("definition", ""), threshold=threshold
|
||||
matched = self.find_similar_entity(
|
||||
project_id, new_ent["name"], new_ent.get("definition", ""), threshold = threshold
|
||||
)
|
||||
|
||||
result = {
|
||||
result = {
|
||||
"new_entity": new_ent,
|
||||
"matched_entity": None,
|
||||
"similarity": 0.0,
|
||||
@@ -244,28 +244,28 @@ class EntityAligner:
|
||||
|
||||
if matched:
|
||||
# 计算相似度
|
||||
query_text = self.get_entity_text(new_ent["name"], new_ent.get("definition", ""))
|
||||
matched_text = self.get_entity_text(matched.name, matched.definition)
|
||||
query_text = self.get_entity_text(new_ent["name"], new_ent.get("definition", ""))
|
||||
matched_text = self.get_entity_text(matched.name, matched.definition)
|
||||
|
||||
query_emb = self.get_embedding(query_text)
|
||||
matched_emb = self.get_embedding(matched_text)
|
||||
query_emb = self.get_embedding(query_text)
|
||||
matched_emb = self.get_embedding(matched_text)
|
||||
|
||||
if query_emb and matched_emb:
|
||||
similarity = self.compute_similarity(query_emb, matched_emb)
|
||||
result["matched_entity"] = {
|
||||
similarity = self.compute_similarity(query_emb, matched_emb)
|
||||
result["matched_entity"] = {
|
||||
"id": matched.id,
|
||||
"name": matched.name,
|
||||
"type": matched.type,
|
||||
"definition": matched.definition,
|
||||
}
|
||||
result["similarity"] = similarity
|
||||
result["should_merge"] = similarity >= threshold
|
||||
result["similarity"] = similarity
|
||||
result["should_merge"] = similarity >= threshold
|
||||
|
||||
results.append(result)
|
||||
|
||||
return results
|
||||
|
||||
def suggest_entity_aliases(self, entity_name: str, entity_definition: str = "") -> list[str]:
|
||||
def suggest_entity_aliases(self, entity_name: str, entity_definition: str = "") -> list[str]:
|
||||
"""
|
||||
使用 LLM 建议实体的别名
|
||||
|
||||
@@ -279,7 +279,7 @@ class EntityAligner:
|
||||
if not KIMI_API_KEY:
|
||||
return []
|
||||
|
||||
prompt = f"""为以下实体生成可能的别名或简称:
|
||||
prompt = f"""为以下实体生成可能的别名或简称:
|
||||
|
||||
实体名称:{entity_name}
|
||||
定义:{entity_definition}
|
||||
@@ -290,28 +290,28 @@ class EntityAligner:
|
||||
只返回 JSON,不要其他内容。"""
|
||||
|
||||
try:
|
||||
response = httpx.post(
|
||||
response = httpx.post(
|
||||
f"{KIMI_BASE_URL}/v1/chat/completions",
|
||||
headers={
|
||||
headers = {
|
||||
"Authorization": f"Bearer {KIMI_API_KEY}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json={
|
||||
json = {
|
||||
"model": "k2p5",
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
"temperature": 0.3,
|
||||
},
|
||||
timeout=30.0,
|
||||
timeout = 30.0,
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
content = result["choices"][0]["message"]["content"]
|
||||
result = response.json()
|
||||
content = result["choices"][0]["message"]["content"]
|
||||
|
||||
import re
|
||||
|
||||
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
|
||||
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
|
||||
if json_match:
|
||||
data = json.loads(json_match.group())
|
||||
data = json.loads(json_match.group())
|
||||
return data.get("aliases", [])
|
||||
except (httpx.HTTPError, json.JSONDecodeError, KeyError) as e:
|
||||
print(f"Alias suggestion failed: {e}")
|
||||
@@ -340,8 +340,8 @@ def simple_similarity(str1: str, str2: str) -> float:
|
||||
return 0.0
|
||||
|
||||
# 转换为小写
|
||||
s1 = str1.lower()
|
||||
s2 = str2.lower()
|
||||
s1 = str1.lower()
|
||||
s2 = str2.lower()
|
||||
|
||||
# 包含关系
|
||||
if s1 in s2 or s2 in s1:
|
||||
@@ -355,11 +355,11 @@ def simple_similarity(str1: str, str2: str) -> float:
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 测试
|
||||
aligner = EntityAligner()
|
||||
aligner = EntityAligner()
|
||||
|
||||
# 测试 embedding
|
||||
test_text = "Kubernetes 容器编排平台"
|
||||
embedding = aligner.get_embedding(test_text)
|
||||
test_text = "Kubernetes 容器编排平台"
|
||||
embedding = aligner.get_embedding(test_text)
|
||||
if embedding:
|
||||
print(f"Embedding dimension: {len(embedding)}")
|
||||
print(f"First 5 values: {embedding[:5]}")
|
||||
@@ -367,7 +367,7 @@ if __name__ == "__main__":
|
||||
print("Embedding API not available")
|
||||
|
||||
# 测试相似度计算
|
||||
emb1 = [1.0, 0.0, 0.0]
|
||||
emb2 = [0.9, 0.1, 0.0]
|
||||
sim = aligner.compute_similarity(emb1, emb2)
|
||||
emb1 = [1.0, 0.0, 0.0]
|
||||
emb2 = [0.9, 0.1, 0.0]
|
||||
sim = aligner.compute_similarity(emb1, emb2)
|
||||
print(f"Similarity: {sim:.4f}")
|
||||
|
||||
@@ -14,9 +14,9 @@ from typing import Any
|
||||
try:
|
||||
import pandas as pd
|
||||
|
||||
PANDAS_AVAILABLE = True
|
||||
PANDAS_AVAILABLE = True
|
||||
except ImportError:
|
||||
PANDAS_AVAILABLE = False
|
||||
PANDAS_AVAILABLE = False
|
||||
|
||||
try:
|
||||
from reportlab.lib import colors
|
||||
@@ -32,9 +32,9 @@ try:
|
||||
TableStyle,
|
||||
)
|
||||
|
||||
REPORTLAB_AVAILABLE = True
|
||||
REPORTLAB_AVAILABLE = True
|
||||
except ImportError:
|
||||
REPORTLAB_AVAILABLE = False
|
||||
REPORTLAB_AVAILABLE = False
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -71,8 +71,8 @@ class ExportTranscript:
|
||||
class ExportManager:
|
||||
"""导出管理器 - 处理各种导出需求"""
|
||||
|
||||
def __init__(self, db_manager=None):
|
||||
self.db = db_manager
|
||||
def __init__(self, db_manager = None) -> None:
|
||||
self.db = db_manager
|
||||
|
||||
def export_knowledge_graph_svg(
|
||||
self, project_id: str, entities: list[ExportEntity], relations: list[ExportRelation]
|
||||
@@ -84,21 +84,21 @@ class ExportManager:
|
||||
SVG 字符串
|
||||
"""
|
||||
# 计算布局参数
|
||||
width = 1200
|
||||
height = 800
|
||||
center_x = width / 2
|
||||
center_y = height / 2
|
||||
radius = 300
|
||||
width = 1200
|
||||
height = 800
|
||||
center_x = width / 2
|
||||
center_y = height / 2
|
||||
radius = 300
|
||||
|
||||
# 按类型分组实体
|
||||
entities_by_type = {}
|
||||
entities_by_type = {}
|
||||
for e in entities:
|
||||
if e.type not in entities_by_type:
|
||||
entities_by_type[e.type] = []
|
||||
entities_by_type[e.type] = []
|
||||
entities_by_type[e.type].append(e)
|
||||
|
||||
# 颜色映射
|
||||
type_colors = {
|
||||
type_colors = {
|
||||
"PERSON": "#FF6B6B",
|
||||
"ORGANIZATION": "#4ECDC4",
|
||||
"LOCATION": "#45B7D1",
|
||||
@@ -110,110 +110,110 @@ class ExportManager:
|
||||
}
|
||||
|
||||
# 计算实体位置
|
||||
entity_positions = {}
|
||||
angle_step = 2 * 3.14159 / max(len(entities), 1)
|
||||
entity_positions = {}
|
||||
angle_step = 2 * 3.14159 / max(len(entities), 1)
|
||||
|
||||
for i, entity in enumerate(entities):
|
||||
i * angle_step
|
||||
x = center_x + radius * 0.8 * (i % 3 - 1) * 150 + (i // 3) * 50
|
||||
y = center_y + radius * 0.6 * ((i % 6) - 3) * 80
|
||||
entity_positions[entity.id] = (x, y)
|
||||
x = center_x + radius * 0.8 * (i % 3 - 1) * 150 + (i // 3) * 50
|
||||
y = center_y + radius * 0.6 * ((i % 6) - 3) * 80
|
||||
entity_positions[entity.id] = (x, y)
|
||||
|
||||
# 生成 SVG
|
||||
svg_parts = [
|
||||
f'<svg xmlns="http://www.w3.org/2000/svg" width="{width}" height="{height}" '
|
||||
f'viewBox="0 0 {width} {height}">',
|
||||
svg_parts = [
|
||||
f'<svg xmlns = "http://www.w3.org/2000/svg" width = "{width}" height = "{height}" '
|
||||
f'viewBox = "0 0 {width} {height}">',
|
||||
"<defs>",
|
||||
' <marker id="arrowhead" markerWidth="10" markerHeight="7" '
|
||||
'refX="9" refY="3.5" orient="auto">',
|
||||
' <polygon points="0 0, 10 3.5, 0 7" fill="#7f8c8d"/>',
|
||||
' <marker id = "arrowhead" markerWidth = "10" markerHeight = "7" '
|
||||
'refX = "9" refY = "3.5" orient = "auto">',
|
||||
' <polygon points = "0 0, 10 3.5, 0 7" fill = "#7f8c8d"/>',
|
||||
" </marker>",
|
||||
"</defs>",
|
||||
f'<rect width="{width}" height="{height}" fill="#f8f9fa"/>',
|
||||
f'<text x="{center_x}" y="30" text-anchor="middle" font-size="20" '
|
||||
f'font-weight="bold" fill="#2c3e50">知识图谱 - {project_id}</text>',
|
||||
f'<rect width = "{width}" height = "{height}" fill = "#f8f9fa"/>',
|
||||
f'<text x = "{center_x}" y = "30" text-anchor = "middle" font-size = "20" '
|
||||
f'font-weight = "bold" fill = "#2c3e50">知识图谱 - {project_id}</text>',
|
||||
]
|
||||
|
||||
# 绘制关系连线
|
||||
for rel in relations:
|
||||
if rel.source in entity_positions and rel.target in entity_positions:
|
||||
x1, y1 = entity_positions[rel.source]
|
||||
x2, y2 = entity_positions[rel.target]
|
||||
x1, y1 = entity_positions[rel.source]
|
||||
x2, y2 = entity_positions[rel.target]
|
||||
|
||||
# 计算箭头终点(避免覆盖节点)
|
||||
dx = x2 - x1
|
||||
dy = y2 - y1
|
||||
dist = (dx**2 + dy**2) ** 0.5
|
||||
dx = x2 - x1
|
||||
dy = y2 - y1
|
||||
dist = (dx**2 + dy**2) ** 0.5
|
||||
if dist > 0:
|
||||
offset = 40
|
||||
x2 = x2 - dx * offset / dist
|
||||
y2 = y2 - dy * offset / dist
|
||||
offset = 40
|
||||
x2 = x2 - dx * offset / dist
|
||||
y2 = y2 - dy * offset / dist
|
||||
|
||||
svg_parts.append(
|
||||
f'<line x1="{x1}" y1="{y1}" x2="{x2}" y2="{y2}" '
|
||||
f'stroke="#7f8c8d" stroke-width="2" marker-end="url(#arrowhead)" opacity="0.6"/>'
|
||||
f'<line x1 = "{x1}" y1 = "{y1}" x2 = "{x2}" y2 = "{y2}" '
|
||||
f'stroke = "#7f8c8d" stroke-width = "2" marker-end = "url(#arrowhead)" opacity = "0.6"/>'
|
||||
)
|
||||
|
||||
# 关系标签
|
||||
mid_x = (x1 + x2) / 2
|
||||
mid_y = (y1 + y2) / 2
|
||||
mid_x = (x1 + x2) / 2
|
||||
mid_y = (y1 + y2) / 2
|
||||
svg_parts.append(
|
||||
f'<rect x="{mid_x - 30}" y="{mid_y - 10}" width="60" height="20" '
|
||||
f'fill="white" stroke="#bdc3c7" rx="3"/>'
|
||||
f'<rect x = "{mid_x - 30}" y = "{mid_y - 10}" width = "60" height = "20" '
|
||||
f'fill = "white" stroke = "#bdc3c7" rx = "3"/>'
|
||||
)
|
||||
svg_parts.append(
|
||||
f'<text x="{mid_x}" y="{mid_y + 5}" text-anchor="middle" '
|
||||
f'font-size="10" fill="#2c3e50">{rel.relation_type}</text>'
|
||||
f'<text x = "{mid_x}" y = "{mid_y + 5}" text-anchor = "middle" '
|
||||
f'font-size = "10" fill = "#2c3e50">{rel.relation_type}</text>'
|
||||
)
|
||||
|
||||
# 绘制实体节点
|
||||
for entity in entities:
|
||||
if entity.id in entity_positions:
|
||||
x, y = entity_positions[entity.id]
|
||||
color = type_colors.get(entity.type, type_colors["default"])
|
||||
x, y = entity_positions[entity.id]
|
||||
color = type_colors.get(entity.type, type_colors["default"])
|
||||
|
||||
# 节点圆圈
|
||||
svg_parts.append(
|
||||
f'<circle cx="{x}" cy="{y}" r="35" fill="{color}" stroke="white" stroke-width="3"/>'
|
||||
f'<circle cx = "{x}" cy = "{y}" r = "35" fill = "{color}" stroke = "white" stroke-width = "3"/>'
|
||||
)
|
||||
|
||||
# 实体名称
|
||||
svg_parts.append(
|
||||
f'<text x="{x}" y="{y + 5}" text-anchor="middle" font-size="12" '
|
||||
f'font-weight="bold" fill="white">{entity.name[:8]}</text>'
|
||||
f'<text x = "{x}" y = "{y + 5}" text-anchor = "middle" font-size = "12" '
|
||||
f'font-weight = "bold" fill = "white">{entity.name[:8]}</text>'
|
||||
)
|
||||
|
||||
# 实体类型
|
||||
svg_parts.append(
|
||||
f'<text x="{x}" y="{y + 55}" text-anchor="middle" font-size="10" '
|
||||
f'fill="#7f8c8d">{entity.type}</text>'
|
||||
f'<text x = "{x}" y = "{y + 55}" text-anchor = "middle" font-size = "10" '
|
||||
f'fill = "#7f8c8d">{entity.type}</text>'
|
||||
)
|
||||
|
||||
# 图例
|
||||
legend_x = width - 150
|
||||
legend_y = 80
|
||||
rect_x = legend_x - 10
|
||||
rect_y = legend_y - 20
|
||||
rect_height = len(type_colors) * 25 + 10
|
||||
legend_x = width - 150
|
||||
legend_y = 80
|
||||
rect_x = legend_x - 10
|
||||
rect_y = legend_y - 20
|
||||
rect_height = len(type_colors) * 25 + 10
|
||||
svg_parts.append(
|
||||
f'<rect x="{rect_x}" y="{rect_y}" width="140" height="{rect_height}" '
|
||||
f'fill="white" stroke="#bdc3c7" rx="5"/>'
|
||||
f'<rect x = "{rect_x}" y = "{rect_y}" width = "140" height = "{rect_height}" '
|
||||
f'fill = "white" stroke = "#bdc3c7" rx = "5"/>'
|
||||
)
|
||||
svg_parts.append(
|
||||
f'<text x="{legend_x}" y="{legend_y}" font-size="12" font-weight="bold" '
|
||||
f'fill="#2c3e50">实体类型</text>'
|
||||
f'<text x = "{legend_x}" y = "{legend_y}" font-size = "12" font-weight = "bold" '
|
||||
f'fill = "#2c3e50">实体类型</text>'
|
||||
)
|
||||
|
||||
for i, (etype, color) in enumerate(type_colors.items()):
|
||||
if etype != "default":
|
||||
y_pos = legend_y + 25 + i * 20
|
||||
y_pos = legend_y + 25 + i * 20
|
||||
svg_parts.append(
|
||||
f'<circle cx="{legend_x + 10}" cy="{y_pos}" r="8" fill="{color}"/>'
|
||||
f'<circle cx = "{legend_x + 10}" cy = "{y_pos}" r = "8" fill = "{color}"/>'
|
||||
)
|
||||
text_y = y_pos + 4
|
||||
text_y = y_pos + 4
|
||||
svg_parts.append(
|
||||
f'<text x="{legend_x + 25}" y="{text_y}" font-size="10" '
|
||||
f'fill="#2c3e50">{etype}</text>'
|
||||
f'<text x = "{legend_x + 25}" y = "{text_y}" font-size = "10" '
|
||||
f'fill = "#2c3e50">{etype}</text>'
|
||||
)
|
||||
|
||||
svg_parts.append("</svg>")
|
||||
@@ -231,12 +231,12 @@ class ExportManager:
|
||||
try:
|
||||
import cairosvg
|
||||
|
||||
svg_content = self.export_knowledge_graph_svg(project_id, entities, relations)
|
||||
png_bytes = cairosvg.svg2png(bytestring=svg_content.encode("utf-8"))
|
||||
svg_content = self.export_knowledge_graph_svg(project_id, entities, relations)
|
||||
png_bytes = cairosvg.svg2png(bytestring = svg_content.encode("utf-8"))
|
||||
return png_bytes
|
||||
except ImportError:
|
||||
# 如果没有 cairosvg,返回 SVG 的 base64
|
||||
svg_content = self.export_knowledge_graph_svg(project_id, entities, relations)
|
||||
svg_content = self.export_knowledge_graph_svg(project_id, entities, relations)
|
||||
return base64.b64encode(svg_content.encode("utf-8"))
|
||||
|
||||
def export_entities_excel(self, entities: list[ExportEntity]) -> bytes:
|
||||
@@ -250,9 +250,9 @@ class ExportManager:
|
||||
raise ImportError("pandas is required for Excel export")
|
||||
|
||||
# 准备数据
|
||||
data = []
|
||||
data = []
|
||||
for e in entities:
|
||||
row = {
|
||||
row = {
|
||||
"ID": e.id,
|
||||
"名称": e.name,
|
||||
"类型": e.type,
|
||||
@@ -262,29 +262,29 @@ class ExportManager:
|
||||
}
|
||||
# 添加属性
|
||||
for attr_name, attr_value in e.attributes.items():
|
||||
row[f"属性:{attr_name}"] = attr_value
|
||||
row[f"属性:{attr_name}"] = attr_value
|
||||
data.append(row)
|
||||
|
||||
df = pd.DataFrame(data)
|
||||
df = pd.DataFrame(data)
|
||||
|
||||
# 写入 Excel
|
||||
output = io.BytesIO()
|
||||
with pd.ExcelWriter(output, engine="openpyxl") as writer:
|
||||
df.to_excel(writer, sheet_name="实体列表", index=False)
|
||||
output = io.BytesIO()
|
||||
with pd.ExcelWriter(output, engine = "openpyxl") as writer:
|
||||
df.to_excel(writer, sheet_name = "实体列表", index = False)
|
||||
|
||||
# 调整列宽
|
||||
worksheet = writer.sheets["实体列表"]
|
||||
worksheet = writer.sheets["实体列表"]
|
||||
for column in worksheet.columns:
|
||||
max_length = 0
|
||||
column_letter = column[0].column_letter
|
||||
max_length = 0
|
||||
column_letter = column[0].column_letter
|
||||
for cell in column:
|
||||
try:
|
||||
if len(str(cell.value)) > max_length:
|
||||
max_length = len(str(cell.value))
|
||||
max_length = len(str(cell.value))
|
||||
except (AttributeError, TypeError, ValueError):
|
||||
pass
|
||||
adjusted_width = min(max_length + 2, 50)
|
||||
worksheet.column_dimensions[column_letter].width = adjusted_width
|
||||
adjusted_width = min(max_length + 2, 50)
|
||||
worksheet.column_dimensions[column_letter].width = adjusted_width
|
||||
|
||||
return output.getvalue()
|
||||
|
||||
@@ -295,24 +295,24 @@ class ExportManager:
|
||||
Returns:
|
||||
CSV 字符串
|
||||
"""
|
||||
output = io.StringIO()
|
||||
output = io.StringIO()
|
||||
|
||||
# 收集所有可能的属性列
|
||||
all_attrs = set()
|
||||
all_attrs = set()
|
||||
for e in entities:
|
||||
all_attrs.update(e.attributes.keys())
|
||||
|
||||
# 表头
|
||||
headers = ["ID", "名称", "类型", "定义", "别名", "提及次数"] + [
|
||||
headers = ["ID", "名称", "类型", "定义", "别名", "提及次数"] + [
|
||||
f"属性:{a}" for a in sorted(all_attrs)
|
||||
]
|
||||
|
||||
writer = csv.writer(output)
|
||||
writer = csv.writer(output)
|
||||
writer.writerow(headers)
|
||||
|
||||
# 数据行
|
||||
for e in entities:
|
||||
row = [e.id, e.name, e.type, e.definition, ", ".join(e.aliases), e.mention_count]
|
||||
row = [e.id, e.name, e.type, e.definition, ", ".join(e.aliases), e.mention_count]
|
||||
for attr in sorted(all_attrs):
|
||||
row.append(e.attributes.get(attr, ""))
|
||||
writer.writerow(row)
|
||||
@@ -327,8 +327,8 @@ class ExportManager:
|
||||
CSV 字符串
|
||||
"""
|
||||
|
||||
output = io.StringIO()
|
||||
writer = csv.writer(output)
|
||||
output = io.StringIO()
|
||||
writer = csv.writer(output)
|
||||
writer.writerow(["ID", "源实体", "目标实体", "关系类型", "置信度", "证据"])
|
||||
|
||||
for r in relations:
|
||||
@@ -345,7 +345,7 @@ class ExportManager:
|
||||
Returns:
|
||||
Markdown 字符串
|
||||
"""
|
||||
lines = [
|
||||
lines = [
|
||||
f"# {transcript.name}",
|
||||
"",
|
||||
f"**类型**: {transcript.type}",
|
||||
@@ -369,10 +369,10 @@ class ExportManager:
|
||||
]
|
||||
)
|
||||
for seg in transcript.segments:
|
||||
speaker = seg.get("speaker", "Unknown")
|
||||
start = seg.get("start", 0)
|
||||
end = seg.get("end", 0)
|
||||
text = seg.get("text", "")
|
||||
speaker = seg.get("speaker", "Unknown")
|
||||
start = seg.get("start", 0)
|
||||
end = seg.get("end", 0)
|
||||
text = seg.get("text", "")
|
||||
lines.append(f"**[{start:.1f}s - {end:.1f}s] {speaker}**: {text}")
|
||||
lines.append("")
|
||||
|
||||
@@ -387,12 +387,12 @@ class ExportManager:
|
||||
]
|
||||
)
|
||||
for mention in transcript.entity_mentions:
|
||||
entity_id = mention.get("entity_id", "")
|
||||
entity = entities_map.get(entity_id)
|
||||
entity_name = entity.name if entity else mention.get("entity_name", "Unknown")
|
||||
entity_type = entity.type if entity else "Unknown"
|
||||
position = mention.get("position", "")
|
||||
context = mention.get("context", "")[:50] + "..." if mention.get("context") else ""
|
||||
entity_id = mention.get("entity_id", "")
|
||||
entity = entities_map.get(entity_id)
|
||||
entity_name = entity.name if entity else mention.get("entity_name", "Unknown")
|
||||
entity_type = entity.type if entity else "Unknown"
|
||||
position = mention.get("position", "")
|
||||
context = mention.get("context", "")[:50] + "..." if mention.get("context") else ""
|
||||
lines.append(f"| {entity_name} | {entity_type} | {position} | {context} |")
|
||||
|
||||
return "\n".join(lines)
|
||||
@@ -404,7 +404,7 @@ class ExportManager:
|
||||
entities: list[ExportEntity],
|
||||
relations: list[ExportRelation],
|
||||
transcripts: list[ExportTranscript],
|
||||
summary: str = "",
|
||||
summary: str = "",
|
||||
) -> bytes:
|
||||
"""
|
||||
导出项目报告为 PDF 格式
|
||||
@@ -415,29 +415,29 @@ class ExportManager:
|
||||
if not REPORTLAB_AVAILABLE:
|
||||
raise ImportError("reportlab is required for PDF export")
|
||||
|
||||
output = io.BytesIO()
|
||||
doc = SimpleDocTemplate(
|
||||
output, pagesize=A4, rightMargin=72, leftMargin=72, topMargin=72, bottomMargin=18
|
||||
output = io.BytesIO()
|
||||
doc = SimpleDocTemplate(
|
||||
output, pagesize = A4, rightMargin = 72, leftMargin = 72, topMargin = 72, bottomMargin = 18
|
||||
)
|
||||
|
||||
# 样式
|
||||
styles = getSampleStyleSheet()
|
||||
title_style = ParagraphStyle(
|
||||
styles = getSampleStyleSheet()
|
||||
title_style = ParagraphStyle(
|
||||
"CustomTitle",
|
||||
parent=styles["Heading1"],
|
||||
fontSize=24,
|
||||
spaceAfter=30,
|
||||
textColor=colors.HexColor("#2c3e50"),
|
||||
parent = styles["Heading1"],
|
||||
fontSize = 24,
|
||||
spaceAfter = 30,
|
||||
textColor = colors.HexColor("#2c3e50"),
|
||||
)
|
||||
heading_style = ParagraphStyle(
|
||||
heading_style = ParagraphStyle(
|
||||
"CustomHeading",
|
||||
parent=styles["Heading2"],
|
||||
fontSize=16,
|
||||
spaceAfter=12,
|
||||
textColor=colors.HexColor("#34495e"),
|
||||
parent = styles["Heading2"],
|
||||
fontSize = 16,
|
||||
spaceAfter = 12,
|
||||
textColor = colors.HexColor("#34495e"),
|
||||
)
|
||||
|
||||
story = []
|
||||
story = []
|
||||
|
||||
# 标题页
|
||||
story.append(Paragraph("InsightFlow 项目报告", title_style))
|
||||
@@ -452,7 +452,7 @@ class ExportManager:
|
||||
|
||||
# 统计概览
|
||||
story.append(Paragraph("项目概览", heading_style))
|
||||
stats_data = [
|
||||
stats_data = [
|
||||
["指标", "数值"],
|
||||
["实体数量", str(len(entities))],
|
||||
["关系数量", str(len(relations))],
|
||||
@@ -460,14 +460,14 @@ class ExportManager:
|
||||
]
|
||||
|
||||
# 按类型统计实体
|
||||
type_counts = {}
|
||||
type_counts = {}
|
||||
for e in entities:
|
||||
type_counts[e.type] = type_counts.get(e.type, 0) + 1
|
||||
type_counts[e.type] = type_counts.get(e.type, 0) + 1
|
||||
|
||||
for etype, count in sorted(type_counts.items()):
|
||||
stats_data.append([f"{etype} 实体", str(count)])
|
||||
|
||||
stats_table = Table(stats_data, colWidths=[3 * inch, 2 * inch])
|
||||
stats_table = Table(stats_data, colWidths = [3 * inch, 2 * inch])
|
||||
stats_table.setStyle(
|
||||
TableStyle(
|
||||
[
|
||||
@@ -496,8 +496,8 @@ class ExportManager:
|
||||
story.append(PageBreak())
|
||||
story.append(Paragraph("实体列表", heading_style))
|
||||
|
||||
entity_data = [["名称", "类型", "提及次数", "定义"]]
|
||||
for e in sorted(entities, key=lambda x: x.mention_count, reverse=True)[
|
||||
entity_data = [["名称", "类型", "提及次数", "定义"]]
|
||||
for e in sorted(entities, key = lambda x: x.mention_count, reverse = True)[
|
||||
:50
|
||||
]: # 限制前50个
|
||||
entity_data.append(
|
||||
@@ -509,8 +509,8 @@ class ExportManager:
|
||||
]
|
||||
)
|
||||
|
||||
entity_table = Table(
|
||||
entity_data, colWidths=[1.5 * inch, 1 * inch, 1 * inch, 2.5 * inch]
|
||||
entity_table = Table(
|
||||
entity_data, colWidths = [1.5 * inch, 1 * inch, 1 * inch, 2.5 * inch]
|
||||
)
|
||||
entity_table.setStyle(
|
||||
TableStyle(
|
||||
@@ -534,12 +534,12 @@ class ExportManager:
|
||||
story.append(PageBreak())
|
||||
story.append(Paragraph("关系列表", heading_style))
|
||||
|
||||
relation_data = [["源实体", "关系", "目标实体", "置信度"]]
|
||||
relation_data = [["源实体", "关系", "目标实体", "置信度"]]
|
||||
for r in relations[:100]: # 限制前100个
|
||||
relation_data.append([r.source, r.relation_type, r.target, f"{r.confidence:.2f}"])
|
||||
|
||||
relation_table = Table(
|
||||
relation_data, colWidths=[2 * inch, 1.5 * inch, 2 * inch, 1 * inch]
|
||||
relation_table = Table(
|
||||
relation_data, colWidths = [2 * inch, 1.5 * inch, 2 * inch, 1 * inch]
|
||||
)
|
||||
relation_table.setStyle(
|
||||
TableStyle(
|
||||
@@ -574,7 +574,7 @@ class ExportManager:
|
||||
Returns:
|
||||
JSON 字符串
|
||||
"""
|
||||
data = {
|
||||
data = {
|
||||
"project_id": project_id,
|
||||
"project_name": project_name,
|
||||
"export_time": datetime.now().isoformat(),
|
||||
@@ -613,16 +613,16 @@ class ExportManager:
|
||||
],
|
||||
}
|
||||
|
||||
return json.dumps(data, ensure_ascii=False, indent=2)
|
||||
return json.dumps(data, ensure_ascii = False, indent = 2)
|
||||
|
||||
|
||||
# 全局导出管理器实例
|
||||
_export_manager = None
|
||||
_export_manager = None
|
||||
|
||||
|
||||
def get_export_manager(db_manager=None) -> None:
|
||||
def get_export_manager(db_manager = None) -> None:
|
||||
"""获取导出管理器实例"""
|
||||
global _export_manager
|
||||
if _export_manager is None:
|
||||
_export_manager = ExportManager(db_manager)
|
||||
_export_manager = ExportManager(db_manager)
|
||||
return _export_manager
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -11,30 +11,30 @@ import uuid
|
||||
from dataclasses import dataclass
|
||||
|
||||
# Constants
|
||||
UUID_LENGTH = 8 # UUID 截断长度
|
||||
UUID_LENGTH = 8 # UUID 截断长度
|
||||
|
||||
# 尝试导入图像处理库
|
||||
try:
|
||||
from PIL import Image, ImageEnhance, ImageFilter
|
||||
|
||||
PIL_AVAILABLE = True
|
||||
PIL_AVAILABLE = True
|
||||
except ImportError:
|
||||
PIL_AVAILABLE = False
|
||||
PIL_AVAILABLE = False
|
||||
|
||||
try:
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
CV2_AVAILABLE = True
|
||||
CV2_AVAILABLE = True
|
||||
except ImportError:
|
||||
CV2_AVAILABLE = False
|
||||
CV2_AVAILABLE = False
|
||||
|
||||
try:
|
||||
import pytesseract
|
||||
|
||||
PYTESSERACT_AVAILABLE = True
|
||||
PYTESSERACT_AVAILABLE = True
|
||||
except ImportError:
|
||||
PYTESSERACT_AVAILABLE = False
|
||||
PYTESSERACT_AVAILABLE = False
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -44,7 +44,7 @@ class ImageEntity:
|
||||
name: str
|
||||
type: str
|
||||
confidence: float
|
||||
bbox: tuple[int, int, int, int] | None = None # (x, y, width, height)
|
||||
bbox: tuple[int, int, int, int] | None = None # (x, y, width, height)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -70,7 +70,7 @@ class ImageProcessingResult:
|
||||
width: int
|
||||
height: int
|
||||
success: bool
|
||||
error_message: str = ""
|
||||
error_message: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -87,7 +87,7 @@ class ImageProcessor:
|
||||
"""图片处理器 - 处理各种类型图片"""
|
||||
|
||||
# 图片类型定义
|
||||
IMAGE_TYPES = {
|
||||
IMAGE_TYPES = {
|
||||
"whiteboard": "白板",
|
||||
"ppt": "PPT/演示文稿",
|
||||
"handwritten": "手写笔记",
|
||||
@@ -96,17 +96,17 @@ class ImageProcessor:
|
||||
"other": "其他",
|
||||
}
|
||||
|
||||
def __init__(self, temp_dir: str = None) -> None:
|
||||
def __init__(self, temp_dir: str = None) -> None:
|
||||
"""
|
||||
初始化图片处理器
|
||||
|
||||
Args:
|
||||
temp_dir: 临时文件目录
|
||||
"""
|
||||
self.temp_dir = temp_dir or os.path.join(os.getcwd(), "temp", "images")
|
||||
os.makedirs(self.temp_dir, exist_ok=True)
|
||||
self.temp_dir = temp_dir or os.path.join(os.getcwd(), "temp", "images")
|
||||
os.makedirs(self.temp_dir, exist_ok = True)
|
||||
|
||||
def preprocess_image(self, image, image_type: str = None) -> None:
|
||||
def preprocess_image(self, image, image_type: str = None) -> None:
|
||||
"""
|
||||
预处理图片以提高OCR质量
|
||||
|
||||
@@ -123,25 +123,25 @@ class ImageProcessor:
|
||||
try:
|
||||
# 转换为RGB(如果是RGBA)
|
||||
if image.mode == "RGBA":
|
||||
image = image.convert("RGB")
|
||||
image = image.convert("RGB")
|
||||
|
||||
# 根据图片类型进行针对性处理
|
||||
if image_type == "whiteboard":
|
||||
# 白板:增强对比度,去除背景
|
||||
image = self._enhance_whiteboard(image)
|
||||
image = self._enhance_whiteboard(image)
|
||||
elif image_type == "handwritten":
|
||||
# 手写笔记:降噪,增强对比度
|
||||
image = self._enhance_handwritten(image)
|
||||
image = self._enhance_handwritten(image)
|
||||
elif image_type == "screenshot":
|
||||
# 截图:轻微锐化
|
||||
image = image.filter(ImageFilter.SHARPEN)
|
||||
image = image.filter(ImageFilter.SHARPEN)
|
||||
|
||||
# 通用处理:调整大小(如果太大)
|
||||
max_size = 4096
|
||||
max_size = 4096
|
||||
if max(image.size) > max_size:
|
||||
ratio = max_size / max(image.size)
|
||||
new_size = (int(image.size[0] * ratio), int(image.size[1] * ratio))
|
||||
image = image.resize(new_size, Image.Resampling.LANCZOS)
|
||||
ratio = max_size / max(image.size)
|
||||
new_size = (int(image.size[0] * ratio), int(image.size[1] * ratio))
|
||||
image = image.resize(new_size, Image.Resampling.LANCZOS)
|
||||
|
||||
return image
|
||||
except Exception as e:
|
||||
@@ -151,33 +151,33 @@ class ImageProcessor:
|
||||
def _enhance_whiteboard(self, image) -> None:
|
||||
"""增强白板图片"""
|
||||
# 转换为灰度
|
||||
gray = image.convert("L")
|
||||
gray = image.convert("L")
|
||||
|
||||
# 增强对比度
|
||||
enhancer = ImageEnhance.Contrast(gray)
|
||||
enhanced = enhancer.enhance(2.0)
|
||||
enhancer = ImageEnhance.Contrast(gray)
|
||||
enhanced = enhancer.enhance(2.0)
|
||||
|
||||
# 二值化
|
||||
threshold = 128
|
||||
binary = enhanced.point(lambda x: 0 if x < threshold else 255, "1")
|
||||
threshold = 128
|
||||
binary = enhanced.point(lambda x: 0 if x < threshold else 255, "1")
|
||||
|
||||
return binary.convert("L")
|
||||
|
||||
def _enhance_handwritten(self, image) -> None:
|
||||
"""增强手写笔记图片"""
|
||||
# 转换为灰度
|
||||
gray = image.convert("L")
|
||||
gray = image.convert("L")
|
||||
|
||||
# 轻微降噪
|
||||
blurred = gray.filter(ImageFilter.GaussianBlur(radius=1))
|
||||
blurred = gray.filter(ImageFilter.GaussianBlur(radius = 1))
|
||||
|
||||
# 增强对比度
|
||||
enhancer = ImageEnhance.Contrast(blurred)
|
||||
enhanced = enhancer.enhance(1.5)
|
||||
enhancer = ImageEnhance.Contrast(blurred)
|
||||
enhanced = enhancer.enhance(1.5)
|
||||
|
||||
return enhanced
|
||||
|
||||
def detect_image_type(self, image, ocr_text: str = "") -> str:
|
||||
def detect_image_type(self, image, ocr_text: str = "") -> str:
|
||||
"""
|
||||
自动检测图片类型
|
||||
|
||||
@@ -193,8 +193,8 @@ class ImageProcessor:
|
||||
|
||||
try:
|
||||
# 基于图片特征和OCR内容判断类型
|
||||
width, height = image.size
|
||||
aspect_ratio = width / height
|
||||
width, height = image.size
|
||||
aspect_ratio = width / height
|
||||
|
||||
# 检测是否为PPT(通常是16:9或4:3)
|
||||
if 1.3 <= aspect_ratio <= 1.8:
|
||||
@@ -204,12 +204,12 @@ class ImageProcessor:
|
||||
|
||||
# 检测是否为白板(大量手写文字,可能有箭头、框等)
|
||||
if CV2_AVAILABLE:
|
||||
img_array = np.array(image.convert("RGB"))
|
||||
gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
|
||||
img_array = np.array(image.convert("RGB"))
|
||||
gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
|
||||
|
||||
# 检测边缘(白板通常有很多线条)
|
||||
edges = cv2.Canny(gray, 50, 150)
|
||||
edge_ratio = np.sum(edges > 0) / edges.size
|
||||
edges = cv2.Canny(gray, 50, 150)
|
||||
edge_ratio = np.sum(edges > 0) / edges.size
|
||||
|
||||
# 如果边缘比例高,可能是白板
|
||||
if edge_ratio > 0.05 and len(ocr_text) > 50:
|
||||
@@ -236,7 +236,7 @@ class ImageProcessor:
|
||||
print(f"Image type detection error: {e}")
|
||||
return "other"
|
||||
|
||||
def perform_ocr(self, image, lang: str = "chi_sim+eng") -> tuple[str, float]:
|
||||
def perform_ocr(self, image, lang: str = "chi_sim+eng") -> tuple[str, float]:
|
||||
"""
|
||||
对图片进行OCR识别
|
||||
|
||||
@@ -252,15 +252,15 @@ class ImageProcessor:
|
||||
|
||||
try:
|
||||
# 预处理图片
|
||||
processed_image = self.preprocess_image(image)
|
||||
processed_image = self.preprocess_image(image)
|
||||
|
||||
# 执行OCR
|
||||
text = pytesseract.image_to_string(processed_image, lang=lang)
|
||||
text = pytesseract.image_to_string(processed_image, lang = lang)
|
||||
|
||||
# 获取置信度
|
||||
data = pytesseract.image_to_data(processed_image, output_type=pytesseract.Output.DICT)
|
||||
confidences = [int(c) for c in data["conf"] if int(c) > 0]
|
||||
avg_confidence = sum(confidences) / len(confidences) if confidences else 0
|
||||
data = pytesseract.image_to_data(processed_image, output_type = pytesseract.Output.DICT)
|
||||
confidences = [int(c) for c in data["conf"] if int(c) > 0]
|
||||
avg_confidence = sum(confidences) / len(confidences) if confidences else 0
|
||||
|
||||
return text.strip(), avg_confidence / 100.0
|
||||
except Exception as e:
|
||||
@@ -277,26 +277,26 @@ class ImageProcessor:
|
||||
Returns:
|
||||
实体列表
|
||||
"""
|
||||
entities = []
|
||||
entities = []
|
||||
|
||||
# 简单的实体提取规则(可以替换为LLM调用)
|
||||
# 提取大写字母开头的词组(可能是专有名词)
|
||||
import re
|
||||
|
||||
# 项目名称(通常是大写或带引号)
|
||||
project_pattern = r'["\']([^"\']+)["\']|([A-Z][a-zA-Z0-9]*(?:\s+[A-Z][a-zA-Z0-9]*)+)'
|
||||
project_pattern = r'["\']([^"\']+)["\']|([A-Z][a-zA-Z0-9]*(?:\s+[A-Z][a-zA-Z0-9]*)+)'
|
||||
for match in re.finditer(project_pattern, text):
|
||||
name = match.group(1) or match.group(2)
|
||||
name = match.group(1) or match.group(2)
|
||||
if name and len(name) > 2:
|
||||
entities.append(ImageEntity(name=name.strip(), type="PROJECT", confidence=0.7))
|
||||
entities.append(ImageEntity(name = name.strip(), type = "PROJECT", confidence = 0.7))
|
||||
|
||||
# 人名(中文)
|
||||
name_pattern = r"([\u4e00-\u9fa5]{2,4})(?:先生|女士|总|经理|工程师|老师)"
|
||||
name_pattern = r"([\u4e00-\u9fa5]{2, 4})(?:先生|女士|总|经理|工程师|老师)"
|
||||
for match in re.finditer(name_pattern, text):
|
||||
entities.append(ImageEntity(name=match.group(1), type="PERSON", confidence=0.8))
|
||||
entities.append(ImageEntity(name = match.group(1), type = "PERSON", confidence = 0.8))
|
||||
|
||||
# 技术术语
|
||||
tech_keywords = [
|
||||
tech_keywords = [
|
||||
"K8s",
|
||||
"Kubernetes",
|
||||
"Docker",
|
||||
@@ -314,13 +314,13 @@ class ImageProcessor:
|
||||
]
|
||||
for keyword in tech_keywords:
|
||||
if keyword in text:
|
||||
entities.append(ImageEntity(name=keyword, type="TECH", confidence=0.9))
|
||||
entities.append(ImageEntity(name = keyword, type = "TECH", confidence = 0.9))
|
||||
|
||||
# 去重
|
||||
seen = set()
|
||||
unique_entities = []
|
||||
seen = set()
|
||||
unique_entities = []
|
||||
for e in entities:
|
||||
key = (e.name.lower(), e.type)
|
||||
key = (e.name.lower(), e.type)
|
||||
if key not in seen:
|
||||
seen.add(key)
|
||||
unique_entities.append(e)
|
||||
@@ -341,19 +341,19 @@ class ImageProcessor:
|
||||
Returns:
|
||||
图片描述
|
||||
"""
|
||||
type_name = self.IMAGE_TYPES.get(image_type, "图片")
|
||||
type_name = self.IMAGE_TYPES.get(image_type, "图片")
|
||||
|
||||
description_parts = [f"这是一张{type_name}图片。"]
|
||||
description_parts = [f"这是一张{type_name}图片。"]
|
||||
|
||||
if ocr_text:
|
||||
# 提取前200字符作为摘要
|
||||
text_preview = ocr_text[:200].replace("\n", " ")
|
||||
text_preview = ocr_text[:200].replace("\n", " ")
|
||||
if len(ocr_text) > 200:
|
||||
text_preview += "..."
|
||||
description_parts.append(f"内容摘要:{text_preview}")
|
||||
|
||||
if entities:
|
||||
entity_names = [e.name for e in entities[:5]] # 最多显示5个实体
|
||||
entity_names = [e.name for e in entities[:5]] # 最多显示5个实体
|
||||
description_parts.append(f"识别到的关键实体:{', '.join(entity_names)}")
|
||||
|
||||
return " ".join(description_parts)
|
||||
@@ -361,9 +361,9 @@ class ImageProcessor:
|
||||
def process_image(
|
||||
self,
|
||||
image_data: bytes,
|
||||
filename: str = None,
|
||||
image_id: str = None,
|
||||
detect_type: bool = True,
|
||||
filename: str = None,
|
||||
image_id: str = None,
|
||||
detect_type: bool = True,
|
||||
) -> ImageProcessingResult:
|
||||
"""
|
||||
处理单张图片
|
||||
@@ -377,73 +377,73 @@ class ImageProcessor:
|
||||
Returns:
|
||||
图片处理结果
|
||||
"""
|
||||
image_id = image_id or str(uuid.uuid4())[:UUID_LENGTH]
|
||||
image_id = image_id or str(uuid.uuid4())[:UUID_LENGTH]
|
||||
|
||||
if not PIL_AVAILABLE:
|
||||
return ImageProcessingResult(
|
||||
image_id=image_id,
|
||||
image_type="other",
|
||||
ocr_text="",
|
||||
description="PIL not available",
|
||||
entities=[],
|
||||
relations=[],
|
||||
width=0,
|
||||
height=0,
|
||||
success=False,
|
||||
error_message="PIL library not available",
|
||||
image_id = image_id,
|
||||
image_type = "other",
|
||||
ocr_text = "",
|
||||
description = "PIL not available",
|
||||
entities = [],
|
||||
relations = [],
|
||||
width = 0,
|
||||
height = 0,
|
||||
success = False,
|
||||
error_message = "PIL library not available",
|
||||
)
|
||||
|
||||
try:
|
||||
# 加载图片
|
||||
image = Image.open(io.BytesIO(image_data))
|
||||
width, height = image.size
|
||||
image = Image.open(io.BytesIO(image_data))
|
||||
width, height = image.size
|
||||
|
||||
# 执行OCR
|
||||
ocr_text, ocr_confidence = self.perform_ocr(image)
|
||||
ocr_text, ocr_confidence = self.perform_ocr(image)
|
||||
|
||||
# 检测图片类型
|
||||
image_type = "other"
|
||||
image_type = "other"
|
||||
if detect_type:
|
||||
image_type = self.detect_image_type(image, ocr_text)
|
||||
image_type = self.detect_image_type(image, ocr_text)
|
||||
|
||||
# 提取实体
|
||||
entities = self.extract_entities_from_text(ocr_text)
|
||||
entities = self.extract_entities_from_text(ocr_text)
|
||||
|
||||
# 生成描述
|
||||
description = self.generate_description(image_type, ocr_text, entities)
|
||||
description = self.generate_description(image_type, ocr_text, entities)
|
||||
|
||||
# 提取关系(基于实体共现)
|
||||
relations = self._extract_relations(entities, ocr_text)
|
||||
relations = self._extract_relations(entities, ocr_text)
|
||||
|
||||
# 保存图片文件(可选)
|
||||
if filename:
|
||||
save_path = os.path.join(self.temp_dir, f"{image_id}_{filename}")
|
||||
save_path = os.path.join(self.temp_dir, f"{image_id}_{filename}")
|
||||
image.save(save_path)
|
||||
|
||||
return ImageProcessingResult(
|
||||
image_id=image_id,
|
||||
image_type=image_type,
|
||||
ocr_text=ocr_text,
|
||||
description=description,
|
||||
entities=entities,
|
||||
relations=relations,
|
||||
width=width,
|
||||
height=height,
|
||||
success=True,
|
||||
image_id = image_id,
|
||||
image_type = image_type,
|
||||
ocr_text = ocr_text,
|
||||
description = description,
|
||||
entities = entities,
|
||||
relations = relations,
|
||||
width = width,
|
||||
height = height,
|
||||
success = True,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
return ImageProcessingResult(
|
||||
image_id=image_id,
|
||||
image_type="other",
|
||||
ocr_text="",
|
||||
description="",
|
||||
entities=[],
|
||||
relations=[],
|
||||
width=0,
|
||||
height=0,
|
||||
success=False,
|
||||
error_message=str(e),
|
||||
image_id = image_id,
|
||||
image_type = "other",
|
||||
ocr_text = "",
|
||||
description = "",
|
||||
entities = [],
|
||||
relations = [],
|
||||
width = 0,
|
||||
height = 0,
|
||||
success = False,
|
||||
error_message = str(e),
|
||||
)
|
||||
|
||||
def _extract_relations(self, entities: list[ImageEntity], text: str) -> list[ImageRelation]:
|
||||
@@ -457,16 +457,16 @@ class ImageProcessor:
|
||||
Returns:
|
||||
关系列表
|
||||
"""
|
||||
relations = []
|
||||
relations = []
|
||||
|
||||
if len(entities) < 2:
|
||||
return relations
|
||||
|
||||
# 简单的关系提取:如果两个实体在同一句子中出现,则认为它们相关
|
||||
sentences = text.replace("。", ".").replace("!", "!").replace("?", "?").split(".")
|
||||
sentences = text.replace("。", ".").replace("!", "!").replace("?", "?").split(".")
|
||||
|
||||
for sentence in sentences:
|
||||
sentence_entities = []
|
||||
sentence_entities = []
|
||||
for entity in entities:
|
||||
if entity.name in sentence:
|
||||
sentence_entities.append(entity)
|
||||
@@ -477,17 +477,17 @@ class ImageProcessor:
|
||||
for j in range(i + 1, len(sentence_entities)):
|
||||
relations.append(
|
||||
ImageRelation(
|
||||
source=sentence_entities[i].name,
|
||||
target=sentence_entities[j].name,
|
||||
relation_type="related",
|
||||
confidence=0.5,
|
||||
source = sentence_entities[i].name,
|
||||
target = sentence_entities[j].name,
|
||||
relation_type = "related",
|
||||
confidence = 0.5,
|
||||
)
|
||||
)
|
||||
|
||||
return relations
|
||||
|
||||
def process_batch(
|
||||
self, images_data: list[tuple[bytes, str]], project_id: str = None
|
||||
self, images_data: list[tuple[bytes, str]], project_id: str = None
|
||||
) -> BatchProcessingResult:
|
||||
"""
|
||||
批量处理图片
|
||||
@@ -499,12 +499,12 @@ class ImageProcessor:
|
||||
Returns:
|
||||
批量处理结果
|
||||
"""
|
||||
results = []
|
||||
success_count = 0
|
||||
failed_count = 0
|
||||
results = []
|
||||
success_count = 0
|
||||
failed_count = 0
|
||||
|
||||
for image_data, filename in images_data:
|
||||
result = self.process_image(image_data, filename)
|
||||
result = self.process_image(image_data, filename)
|
||||
results.append(result)
|
||||
|
||||
if result.success:
|
||||
@@ -513,10 +513,10 @@ class ImageProcessor:
|
||||
failed_count += 1
|
||||
|
||||
return BatchProcessingResult(
|
||||
results=results,
|
||||
total_count=len(results),
|
||||
success_count=success_count,
|
||||
failed_count=failed_count,
|
||||
results = results,
|
||||
total_count = len(results),
|
||||
success_count = success_count,
|
||||
failed_count = failed_count,
|
||||
)
|
||||
|
||||
def image_to_base64(self, image_data: bytes) -> str:
|
||||
@@ -531,7 +531,7 @@ class ImageProcessor:
|
||||
"""
|
||||
return base64.b64encode(image_data).decode("utf-8")
|
||||
|
||||
def get_image_thumbnail(self, image_data: bytes, size: tuple[int, int] = (200, 200)) -> bytes:
|
||||
def get_image_thumbnail(self, image_data: bytes, size: tuple[int, int] = (200, 200)) -> bytes:
|
||||
"""
|
||||
生成图片缩略图
|
||||
|
||||
@@ -546,11 +546,11 @@ class ImageProcessor:
|
||||
return image_data
|
||||
|
||||
try:
|
||||
image = Image.open(io.BytesIO(image_data))
|
||||
image = Image.open(io.BytesIO(image_data))
|
||||
image.thumbnail(size, Image.Resampling.LANCZOS)
|
||||
|
||||
buffer = io.BytesIO()
|
||||
image.save(buffer, format="JPEG")
|
||||
buffer = io.BytesIO()
|
||||
image.save(buffer, format = "JPEG")
|
||||
return buffer.getvalue()
|
||||
except Exception as e:
|
||||
print(f"Thumbnail generation error: {e}")
|
||||
@@ -558,12 +558,12 @@ class ImageProcessor:
|
||||
|
||||
|
||||
# Singleton instance
|
||||
_image_processor = None
|
||||
_image_processor = None
|
||||
|
||||
|
||||
def get_image_processor(temp_dir: str = None) -> ImageProcessor:
|
||||
def get_image_processor(temp_dir: str = None) -> ImageProcessor:
|
||||
"""获取图片处理器单例"""
|
||||
global _image_processor
|
||||
if _image_processor is None:
|
||||
_image_processor = ImageProcessor(temp_dir)
|
||||
_image_processor = ImageProcessor(temp_dir)
|
||||
return _image_processor
|
||||
|
||||
@@ -4,27 +4,27 @@
|
||||
import os
|
||||
import sqlite3
|
||||
|
||||
db_path = os.path.join(os.path.dirname(__file__), "insightflow.db")
|
||||
schema_path = os.path.join(os.path.dirname(__file__), "schema.sql")
|
||||
db_path = os.path.join(os.path.dirname(__file__), "insightflow.db")
|
||||
schema_path = os.path.join(os.path.dirname(__file__), "schema.sql")
|
||||
|
||||
print(f"Database path: {db_path}")
|
||||
print(f"Schema path: {schema_path}")
|
||||
|
||||
# Read schema
|
||||
with open(schema_path) as f:
|
||||
schema = f.read()
|
||||
schema = f.read()
|
||||
|
||||
# Execute schema
|
||||
conn = sqlite3.connect(db_path)
|
||||
cursor = conn.cursor()
|
||||
conn = sqlite3.connect(db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Split schema by semicolons and execute each statement
|
||||
statements = schema.split(";")
|
||||
success_count = 0
|
||||
error_count = 0
|
||||
statements = schema.split(";")
|
||||
success_count = 0
|
||||
error_count = 0
|
||||
|
||||
for stmt in statements:
|
||||
stmt = stmt.strip()
|
||||
stmt = stmt.strip()
|
||||
if stmt:
|
||||
try:
|
||||
cursor.execute(stmt)
|
||||
|
||||
@@ -12,18 +12,18 @@ from enum import Enum
|
||||
|
||||
import httpx
|
||||
|
||||
KIMI_API_KEY = os.getenv("KIMI_API_KEY", "")
|
||||
KIMI_BASE_URL = os.getenv("KIMI_BASE_URL", "https://api.kimi.com/coding")
|
||||
KIMI_API_KEY = os.getenv("KIMI_API_KEY", "")
|
||||
KIMI_BASE_URL = os.getenv("KIMI_BASE_URL", "https://api.kimi.com/coding")
|
||||
|
||||
|
||||
class ReasoningType(Enum):
|
||||
"""推理类型"""
|
||||
|
||||
CAUSAL = "causal" # 因果推理
|
||||
ASSOCIATIVE = "associative" # 关联推理
|
||||
TEMPORAL = "temporal" # 时序推理
|
||||
COMPARATIVE = "comparative" # 对比推理
|
||||
SUMMARY = "summary" # 总结推理
|
||||
CAUSAL = "causal" # 因果推理
|
||||
ASSOCIATIVE = "associative" # 关联推理
|
||||
TEMPORAL = "temporal" # 时序推理
|
||||
COMPARATIVE = "comparative" # 对比推理
|
||||
SUMMARY = "summary" # 总结推理
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -51,38 +51,38 @@ class InferencePath:
|
||||
class KnowledgeReasoner:
|
||||
"""知识推理引擎"""
|
||||
|
||||
def __init__(self, api_key: str = None, base_url: str = None):
|
||||
self.api_key = api_key or KIMI_API_KEY
|
||||
self.base_url = base_url or KIMI_BASE_URL
|
||||
self.headers = {
|
||||
def __init__(self, api_key: str = None, base_url: str = None) -> None:
|
||||
self.api_key = api_key or KIMI_API_KEY
|
||||
self.base_url = base_url or KIMI_BASE_URL
|
||||
self.headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
async def _call_llm(self, prompt: str, temperature: float = 0.3) -> str:
|
||||
async def _call_llm(self, prompt: str, temperature: float = 0.3) -> str:
|
||||
"""调用 LLM"""
|
||||
if not self.api_key:
|
||||
raise ValueError("KIMI_API_KEY not set")
|
||||
|
||||
payload = {
|
||||
payload = {
|
||||
"model": "k2p5",
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
"temperature": temperature,
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
response = await client.post(
|
||||
f"{self.base_url}/v1/chat/completions",
|
||||
headers=self.headers,
|
||||
json=payload,
|
||||
timeout=120.0,
|
||||
headers = self.headers,
|
||||
json = payload,
|
||||
timeout = 120.0,
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
result = response.json()
|
||||
return result["choices"][0]["message"]["content"]
|
||||
|
||||
async def enhanced_qa(
|
||||
self, query: str, project_context: dict, graph_data: dict, reasoning_depth: str = "medium"
|
||||
self, query: str, project_context: dict, graph_data: dict, reasoning_depth: str = "medium"
|
||||
) -> ReasoningResult:
|
||||
"""
|
||||
增强问答 - 结合图谱推理的问答
|
||||
@@ -94,7 +94,7 @@ class KnowledgeReasoner:
|
||||
reasoning_depth: 推理深度 (shallow/medium/deep)
|
||||
"""
|
||||
# 1. 分析问题类型
|
||||
analysis = await self._analyze_question(query)
|
||||
analysis = await self._analyze_question(query)
|
||||
|
||||
# 2. 根据问题类型选择推理策略
|
||||
if analysis["type"] == "causal":
|
||||
@@ -108,7 +108,7 @@ class KnowledgeReasoner:
|
||||
|
||||
async def _analyze_question(self, query: str) -> dict:
|
||||
"""分析问题类型和意图"""
|
||||
prompt = f"""分析以下问题的类型和意图:
|
||||
prompt = f"""分析以下问题的类型和意图:
|
||||
|
||||
问题:{query}
|
||||
|
||||
@@ -127,9 +127,9 @@ class KnowledgeReasoner:
|
||||
- factual: 事实类问题(是什么、有哪些)
|
||||
- opinion: 观点类问题(怎么看、态度、评价)"""
|
||||
|
||||
content = await self._call_llm(prompt, temperature=0.1)
|
||||
content = await self._call_llm(prompt, temperature = 0.1)
|
||||
|
||||
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
|
||||
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
|
||||
if json_match:
|
||||
try:
|
||||
return json.loads(json_match.group())
|
||||
@@ -144,10 +144,10 @@ class KnowledgeReasoner:
|
||||
"""因果推理 - 分析原因和影响"""
|
||||
|
||||
# 构建因果分析提示
|
||||
entities_str = json.dumps(graph_data.get("entities", []), ensure_ascii=False, indent=2)
|
||||
relations_str = json.dumps(graph_data.get("relations", []), ensure_ascii=False, indent=2)
|
||||
entities_str = json.dumps(graph_data.get("entities", []), ensure_ascii = False, indent = 2)
|
||||
relations_str = json.dumps(graph_data.get("relations", []), ensure_ascii = False, indent = 2)
|
||||
|
||||
prompt = f"""基于以下知识图谱进行因果推理分析:
|
||||
prompt = f"""基于以下知识图谱进行因果推理分析:
|
||||
|
||||
## 问题
|
||||
{query}
|
||||
@@ -159,7 +159,7 @@ class KnowledgeReasoner:
|
||||
{relations_str[:2000]}
|
||||
|
||||
## 项目上下文
|
||||
{json.dumps(project_context, ensure_ascii=False, indent=2)[:1500]}
|
||||
{json.dumps(project_context, ensure_ascii = False, indent = 2)[:1500]}
|
||||
|
||||
请进行因果分析,返回 JSON 格式:
|
||||
{{
|
||||
@@ -172,31 +172,31 @@ class KnowledgeReasoner:
|
||||
"knowledge_gaps": ["缺失信息1"]
|
||||
}}"""
|
||||
|
||||
content = await self._call_llm(prompt, temperature=0.3)
|
||||
content = await self._call_llm(prompt, temperature = 0.3)
|
||||
|
||||
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
|
||||
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
|
||||
|
||||
if json_match:
|
||||
try:
|
||||
data = json.loads(json_match.group())
|
||||
data = json.loads(json_match.group())
|
||||
return ReasoningResult(
|
||||
answer=data.get("answer", ""),
|
||||
reasoning_type=ReasoningType.CAUSAL,
|
||||
confidence=data.get("confidence", 0.7),
|
||||
evidence=[{"text": e} for e in data.get("evidence", [])],
|
||||
related_entities=[],
|
||||
gaps=data.get("knowledge_gaps", []),
|
||||
answer = data.get("answer", ""),
|
||||
reasoning_type = ReasoningType.CAUSAL,
|
||||
confidence = data.get("confidence", 0.7),
|
||||
evidence = [{"text": e} for e in data.get("evidence", [])],
|
||||
related_entities = [],
|
||||
gaps = data.get("knowledge_gaps", []),
|
||||
)
|
||||
except (json.JSONDecodeError, KeyError):
|
||||
pass
|
||||
|
||||
return ReasoningResult(
|
||||
answer=content,
|
||||
reasoning_type=ReasoningType.CAUSAL,
|
||||
confidence=0.5,
|
||||
evidence=[],
|
||||
related_entities=[],
|
||||
gaps=["无法完成因果推理"],
|
||||
answer = content,
|
||||
reasoning_type = ReasoningType.CAUSAL,
|
||||
confidence = 0.5,
|
||||
evidence = [],
|
||||
related_entities = [],
|
||||
gaps = ["无法完成因果推理"],
|
||||
)
|
||||
|
||||
async def _comparative_reasoning(
|
||||
@@ -204,16 +204,16 @@ class KnowledgeReasoner:
|
||||
) -> ReasoningResult:
|
||||
"""对比推理 - 比较实体间的异同"""
|
||||
|
||||
prompt = f"""基于以下知识图谱进行对比分析:
|
||||
prompt = f"""基于以下知识图谱进行对比分析:
|
||||
|
||||
## 问题
|
||||
{query}
|
||||
|
||||
## 实体
|
||||
{json.dumps(graph_data.get("entities", []), ensure_ascii=False, indent=2)[:2000]}
|
||||
{json.dumps(graph_data.get("entities", []), ensure_ascii = False, indent = 2)[:2000]}
|
||||
|
||||
## 关系
|
||||
{json.dumps(graph_data.get("relations", []), ensure_ascii=False, indent=2)[:1500]}
|
||||
{json.dumps(graph_data.get("relations", []), ensure_ascii = False, indent = 2)[:1500]}
|
||||
|
||||
请进行对比分析,返回 JSON 格式:
|
||||
{{
|
||||
@@ -226,31 +226,31 @@ class KnowledgeReasoner:
|
||||
"knowledge_gaps": []
|
||||
}}"""
|
||||
|
||||
content = await self._call_llm(prompt, temperature=0.3)
|
||||
content = await self._call_llm(prompt, temperature = 0.3)
|
||||
|
||||
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
|
||||
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
|
||||
|
||||
if json_match:
|
||||
try:
|
||||
data = json.loads(json_match.group())
|
||||
data = json.loads(json_match.group())
|
||||
return ReasoningResult(
|
||||
answer=data.get("answer", ""),
|
||||
reasoning_type=ReasoningType.COMPARATIVE,
|
||||
confidence=data.get("confidence", 0.7),
|
||||
evidence=[{"text": e} for e in data.get("evidence", [])],
|
||||
related_entities=[],
|
||||
gaps=data.get("knowledge_gaps", []),
|
||||
answer = data.get("answer", ""),
|
||||
reasoning_type = ReasoningType.COMPARATIVE,
|
||||
confidence = data.get("confidence", 0.7),
|
||||
evidence = [{"text": e} for e in data.get("evidence", [])],
|
||||
related_entities = [],
|
||||
gaps = data.get("knowledge_gaps", []),
|
||||
)
|
||||
except (json.JSONDecodeError, KeyError):
|
||||
pass
|
||||
|
||||
return ReasoningResult(
|
||||
answer=content,
|
||||
reasoning_type=ReasoningType.COMPARATIVE,
|
||||
confidence=0.5,
|
||||
evidence=[],
|
||||
related_entities=[],
|
||||
gaps=[],
|
||||
answer = content,
|
||||
reasoning_type = ReasoningType.COMPARATIVE,
|
||||
confidence = 0.5,
|
||||
evidence = [],
|
||||
related_entities = [],
|
||||
gaps = [],
|
||||
)
|
||||
|
||||
async def _temporal_reasoning(
|
||||
@@ -258,16 +258,16 @@ class KnowledgeReasoner:
|
||||
) -> ReasoningResult:
|
||||
"""时序推理 - 分析时间线和演变"""
|
||||
|
||||
prompt = f"""基于以下知识图谱进行时序分析:
|
||||
prompt = f"""基于以下知识图谱进行时序分析:
|
||||
|
||||
## 问题
|
||||
{query}
|
||||
|
||||
## 项目时间线
|
||||
{json.dumps(project_context.get("timeline", []), ensure_ascii=False, indent=2)[:2000]}
|
||||
{json.dumps(project_context.get("timeline", []), ensure_ascii = False, indent = 2)[:2000]}
|
||||
|
||||
## 实体提及历史
|
||||
{json.dumps(graph_data.get("entities", []), ensure_ascii=False, indent=2)[:1500]}
|
||||
{json.dumps(graph_data.get("entities", []), ensure_ascii = False, indent = 2)[:1500]}
|
||||
|
||||
请进行时序分析,返回 JSON 格式:
|
||||
{{
|
||||
@@ -280,31 +280,31 @@ class KnowledgeReasoner:
|
||||
"knowledge_gaps": []
|
||||
}}"""
|
||||
|
||||
content = await self._call_llm(prompt, temperature=0.3)
|
||||
content = await self._call_llm(prompt, temperature = 0.3)
|
||||
|
||||
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
|
||||
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
|
||||
|
||||
if json_match:
|
||||
try:
|
||||
data = json.loads(json_match.group())
|
||||
data = json.loads(json_match.group())
|
||||
return ReasoningResult(
|
||||
answer=data.get("answer", ""),
|
||||
reasoning_type=ReasoningType.TEMPORAL,
|
||||
confidence=data.get("confidence", 0.7),
|
||||
evidence=[{"text": e} for e in data.get("evidence", [])],
|
||||
related_entities=[],
|
||||
gaps=data.get("knowledge_gaps", []),
|
||||
answer = data.get("answer", ""),
|
||||
reasoning_type = ReasoningType.TEMPORAL,
|
||||
confidence = data.get("confidence", 0.7),
|
||||
evidence = [{"text": e} for e in data.get("evidence", [])],
|
||||
related_entities = [],
|
||||
gaps = data.get("knowledge_gaps", []),
|
||||
)
|
||||
except (json.JSONDecodeError, KeyError):
|
||||
pass
|
||||
|
||||
return ReasoningResult(
|
||||
answer=content,
|
||||
reasoning_type=ReasoningType.TEMPORAL,
|
||||
confidence=0.5,
|
||||
evidence=[],
|
||||
related_entities=[],
|
||||
gaps=[],
|
||||
answer = content,
|
||||
reasoning_type = ReasoningType.TEMPORAL,
|
||||
confidence = 0.5,
|
||||
evidence = [],
|
||||
related_entities = [],
|
||||
gaps = [],
|
||||
)
|
||||
|
||||
async def _associative_reasoning(
|
||||
@@ -312,16 +312,16 @@ class KnowledgeReasoner:
|
||||
) -> ReasoningResult:
|
||||
"""关联推理 - 发现实体间的隐含关联"""
|
||||
|
||||
prompt = f"""基于以下知识图谱进行关联分析:
|
||||
prompt = f"""基于以下知识图谱进行关联分析:
|
||||
|
||||
## 问题
|
||||
{query}
|
||||
|
||||
## 实体
|
||||
{json.dumps(graph_data.get("entities", [])[:20], ensure_ascii=False, indent=2)}
|
||||
{json.dumps(graph_data.get("entities", [])[:20], ensure_ascii = False, indent = 2)}
|
||||
|
||||
## 关系
|
||||
{json.dumps(graph_data.get("relations", [])[:30], ensure_ascii=False, indent=2)}
|
||||
{json.dumps(graph_data.get("relations", [])[:30], ensure_ascii = False, indent = 2)}
|
||||
|
||||
请进行关联推理,发现隐含联系,返回 JSON 格式:
|
||||
{{
|
||||
@@ -334,52 +334,52 @@ class KnowledgeReasoner:
|
||||
"knowledge_gaps": []
|
||||
}}"""
|
||||
|
||||
content = await self._call_llm(prompt, temperature=0.4)
|
||||
content = await self._call_llm(prompt, temperature = 0.4)
|
||||
|
||||
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
|
||||
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
|
||||
|
||||
if json_match:
|
||||
try:
|
||||
data = json.loads(json_match.group())
|
||||
data = json.loads(json_match.group())
|
||||
return ReasoningResult(
|
||||
answer=data.get("answer", ""),
|
||||
reasoning_type=ReasoningType.ASSOCIATIVE,
|
||||
confidence=data.get("confidence", 0.7),
|
||||
evidence=[{"text": e} for e in data.get("evidence", [])],
|
||||
related_entities=[],
|
||||
gaps=data.get("knowledge_gaps", []),
|
||||
answer = data.get("answer", ""),
|
||||
reasoning_type = ReasoningType.ASSOCIATIVE,
|
||||
confidence = data.get("confidence", 0.7),
|
||||
evidence = [{"text": e} for e in data.get("evidence", [])],
|
||||
related_entities = [],
|
||||
gaps = data.get("knowledge_gaps", []),
|
||||
)
|
||||
except (json.JSONDecodeError, KeyError):
|
||||
pass
|
||||
|
||||
return ReasoningResult(
|
||||
answer=content,
|
||||
reasoning_type=ReasoningType.ASSOCIATIVE,
|
||||
confidence=0.5,
|
||||
evidence=[],
|
||||
related_entities=[],
|
||||
gaps=[],
|
||||
answer = content,
|
||||
reasoning_type = ReasoningType.ASSOCIATIVE,
|
||||
confidence = 0.5,
|
||||
evidence = [],
|
||||
related_entities = [],
|
||||
gaps = [],
|
||||
)
|
||||
|
||||
def find_inference_paths(
|
||||
self, start_entity: str, end_entity: str, graph_data: dict, max_depth: int = 3
|
||||
self, start_entity: str, end_entity: str, graph_data: dict, max_depth: int = 3
|
||||
) -> list[InferencePath]:
|
||||
"""
|
||||
发现两个实体之间的推理路径
|
||||
|
||||
使用 BFS 在关系图中搜索路径
|
||||
"""
|
||||
relations = graph_data.get("relations", [])
|
||||
relations = graph_data.get("relations", [])
|
||||
|
||||
# 构建邻接表
|
||||
adj = {}
|
||||
adj = {}
|
||||
for r in relations:
|
||||
src = r.get("source_id") or r.get("source")
|
||||
tgt = r.get("target_id") or r.get("target")
|
||||
src = r.get("source_id") or r.get("source")
|
||||
tgt = r.get("target_id") or r.get("target")
|
||||
if src not in adj:
|
||||
adj[src] = []
|
||||
adj[src] = []
|
||||
if tgt not in adj:
|
||||
adj[tgt] = []
|
||||
adj[tgt] = []
|
||||
adj[src].append({"target": tgt, "relation": r.get("type", "related"), "data": r})
|
||||
# 无向图也添加反向
|
||||
adj[tgt].append(
|
||||
@@ -389,21 +389,21 @@ class KnowledgeReasoner:
|
||||
# BFS 搜索路径
|
||||
from collections import deque
|
||||
|
||||
paths = []
|
||||
queue = deque([(start_entity, [{"entity": start_entity, "relation": None}])])
|
||||
paths = []
|
||||
queue = deque([(start_entity, [{"entity": start_entity, "relation": None}])])
|
||||
{start_entity}
|
||||
|
||||
while queue and len(paths) < 5:
|
||||
current, path = queue.popleft()
|
||||
current, path = queue.popleft()
|
||||
|
||||
if current == end_entity and len(path) > 1:
|
||||
# 找到一条路径
|
||||
paths.append(
|
||||
InferencePath(
|
||||
start_entity=start_entity,
|
||||
end_entity=end_entity,
|
||||
path=path,
|
||||
strength=self._calculate_path_strength(path),
|
||||
start_entity = start_entity,
|
||||
end_entity = end_entity,
|
||||
path = path,
|
||||
strength = self._calculate_path_strength(path),
|
||||
)
|
||||
)
|
||||
continue
|
||||
@@ -412,9 +412,9 @@ class KnowledgeReasoner:
|
||||
continue
|
||||
|
||||
for neighbor in adj.get(current, []):
|
||||
next_entity = neighbor["target"]
|
||||
next_entity = neighbor["target"]
|
||||
if next_entity not in [p["entity"] for p in path]: # 避免循环
|
||||
new_path = path + [
|
||||
new_path = path + [
|
||||
{
|
||||
"entity": next_entity,
|
||||
"relation": neighbor["relation"],
|
||||
@@ -424,7 +424,7 @@ class KnowledgeReasoner:
|
||||
queue.append((next_entity, new_path))
|
||||
|
||||
# 按强度排序
|
||||
paths.sort(key=lambda p: p.strength, reverse=True)
|
||||
paths.sort(key = lambda p: p.strength, reverse = True)
|
||||
return paths
|
||||
|
||||
def _calculate_path_strength(self, path: list[dict]) -> float:
|
||||
@@ -433,23 +433,23 @@ class KnowledgeReasoner:
|
||||
return 0.0
|
||||
|
||||
# 路径越短越强
|
||||
length_factor = 1.0 / len(path)
|
||||
length_factor = 1.0 / len(path)
|
||||
|
||||
# 关系置信度
|
||||
confidence_sum = 0
|
||||
confidence_count = 0
|
||||
confidence_sum = 0
|
||||
confidence_count = 0
|
||||
for node in path[1:]: # 跳过第一个节点
|
||||
rel_data = node.get("relation_data", {})
|
||||
rel_data = node.get("relation_data", {})
|
||||
if "confidence" in rel_data:
|
||||
confidence_sum += rel_data["confidence"]
|
||||
confidence_count += 1
|
||||
|
||||
confidence_factor = (confidence_sum / confidence_count) if confidence_count > 0 else 0.5
|
||||
confidence_factor = (confidence_sum / confidence_count) if confidence_count > 0 else 0.5
|
||||
|
||||
return length_factor * confidence_factor
|
||||
|
||||
async def summarize_project(
|
||||
self, project_context: dict, graph_data: dict, summary_type: str = "comprehensive"
|
||||
self, project_context: dict, graph_data: dict, summary_type: str = "comprehensive"
|
||||
) -> dict:
|
||||
"""
|
||||
项目智能总结
|
||||
@@ -457,17 +457,17 @@ class KnowledgeReasoner:
|
||||
Args:
|
||||
summary_type: comprehensive/executive/technical/risk
|
||||
"""
|
||||
type_prompts = {
|
||||
type_prompts = {
|
||||
"comprehensive": "全面总结项目的所有方面",
|
||||
"executive": "高管摘要,关注关键决策和风险",
|
||||
"technical": "技术总结,关注架构和技术栈",
|
||||
"risk": "风险分析,关注潜在问题和依赖",
|
||||
}
|
||||
|
||||
prompt = f"""请对以下项目进行{type_prompts.get(summary_type, "全面总结")}:
|
||||
prompt = f"""请对以下项目进行{type_prompts.get(summary_type, "全面总结")}:
|
||||
|
||||
## 项目信息
|
||||
{json.dumps(project_context, ensure_ascii=False, indent=2)[:3000]}
|
||||
{json.dumps(project_context, ensure_ascii = False, indent = 2)[:3000]}
|
||||
|
||||
## 知识图谱
|
||||
实体数: {len(graph_data.get("entities", []))}
|
||||
@@ -483,9 +483,9 @@ class KnowledgeReasoner:
|
||||
"confidence": 0.85
|
||||
}}"""
|
||||
|
||||
content = await self._call_llm(prompt, temperature=0.3)
|
||||
content = await self._call_llm(prompt, temperature = 0.3)
|
||||
|
||||
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
|
||||
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
|
||||
|
||||
if json_match:
|
||||
try:
|
||||
@@ -504,11 +504,11 @@ class KnowledgeReasoner:
|
||||
|
||||
|
||||
# Singleton instance
|
||||
_reasoner = None
|
||||
_reasoner = None
|
||||
|
||||
|
||||
def get_knowledge_reasoner() -> KnowledgeReasoner:
|
||||
global _reasoner
|
||||
if _reasoner is None:
|
||||
_reasoner = KnowledgeReasoner()
|
||||
_reasoner = KnowledgeReasoner()
|
||||
return _reasoner
|
||||
|
||||
@@ -12,8 +12,8 @@ from dataclasses import dataclass
|
||||
|
||||
import httpx
|
||||
|
||||
KIMI_API_KEY = os.getenv("KIMI_API_KEY", "")
|
||||
KIMI_BASE_URL = os.getenv("KIMI_BASE_URL", "https://api.kimi.com/coding")
|
||||
KIMI_API_KEY = os.getenv("KIMI_API_KEY", "")
|
||||
KIMI_BASE_URL = os.getenv("KIMI_BASE_URL", "https://api.kimi.com/coding")
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -41,22 +41,22 @@ class RelationExtractionResult:
|
||||
class LLMClient:
|
||||
"""Kimi API 客户端"""
|
||||
|
||||
def __init__(self, api_key: str = None, base_url: str = None):
|
||||
self.api_key = api_key or KIMI_API_KEY
|
||||
self.base_url = base_url or KIMI_BASE_URL
|
||||
self.headers = {
|
||||
def __init__(self, api_key: str = None, base_url: str = None) -> None:
|
||||
self.api_key = api_key or KIMI_API_KEY
|
||||
self.base_url = base_url or KIMI_BASE_URL
|
||||
self.headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
async def chat(
|
||||
self, messages: list[ChatMessage], temperature: float = 0.3, stream: bool = False
|
||||
self, messages: list[ChatMessage], temperature: float = 0.3, stream: bool = False
|
||||
) -> str:
|
||||
"""发送聊天请求"""
|
||||
if not self.api_key:
|
||||
raise ValueError("KIMI_API_KEY not set")
|
||||
|
||||
payload = {
|
||||
payload = {
|
||||
"model": "k2p5",
|
||||
"messages": [{"role": m.role, "content": m.content} for m in messages],
|
||||
"temperature": temperature,
|
||||
@@ -64,24 +64,24 @@ class LLMClient:
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
response = await client.post(
|
||||
f"{self.base_url}/v1/chat/completions",
|
||||
headers=self.headers,
|
||||
json=payload,
|
||||
timeout=120.0,
|
||||
headers = self.headers,
|
||||
json = payload,
|
||||
timeout = 120.0,
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
result = response.json()
|
||||
return result["choices"][0]["message"]["content"]
|
||||
|
||||
async def chat_stream(
|
||||
self, messages: list[ChatMessage], temperature: float = 0.3
|
||||
self, messages: list[ChatMessage], temperature: float = 0.3
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""流式聊天请求"""
|
||||
if not self.api_key:
|
||||
raise ValueError("KIMI_API_KEY not set")
|
||||
|
||||
payload = {
|
||||
payload = {
|
||||
"model": "k2p5",
|
||||
"messages": [{"role": m.role, "content": m.content} for m in messages],
|
||||
"temperature": temperature,
|
||||
@@ -92,19 +92,19 @@ class LLMClient:
|
||||
async with client.stream(
|
||||
"POST",
|
||||
f"{self.base_url}/v1/chat/completions",
|
||||
headers=self.headers,
|
||||
json=payload,
|
||||
timeout=120.0,
|
||||
headers = self.headers,
|
||||
json = payload,
|
||||
timeout = 120.0,
|
||||
) as response:
|
||||
response.raise_for_status()
|
||||
async for line in response.aiter_lines():
|
||||
if line.startswith("data: "):
|
||||
data = line[6:]
|
||||
data = line[6:]
|
||||
if data == "[DONE]":
|
||||
break
|
||||
try:
|
||||
chunk = json.loads(data)
|
||||
delta = chunk["choices"][0]["delta"]
|
||||
chunk = json.loads(data)
|
||||
delta = chunk["choices"][0]["delta"]
|
||||
if "content" in delta:
|
||||
yield delta["content"]
|
||||
except (json.JSONDecodeError, KeyError, IndexError):
|
||||
@@ -114,7 +114,7 @@ class LLMClient:
|
||||
self, text: str
|
||||
) -> tuple[list[EntityExtractionResult], list[RelationExtractionResult]]:
|
||||
"""提取实体和关系,带置信度分数"""
|
||||
prompt = f"""从以下会议文本中提取关键实体和它们之间的关系,以 JSON 格式返回:
|
||||
prompt = f"""从以下会议文本中提取关键实体和它们之间的关系,以 JSON 格式返回:
|
||||
|
||||
文本:{text[:3000]}
|
||||
|
||||
@@ -139,30 +139,30 @@ class LLMClient:
|
||||
]
|
||||
}}"""
|
||||
|
||||
messages = [ChatMessage(role="user", content=prompt)]
|
||||
content = await self.chat(messages, temperature=0.1)
|
||||
messages = [ChatMessage(role = "user", content = prompt)]
|
||||
content = await self.chat(messages, temperature = 0.1)
|
||||
|
||||
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
|
||||
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
|
||||
if not json_match:
|
||||
return [], []
|
||||
|
||||
try:
|
||||
data = json.loads(json_match.group())
|
||||
entities = [
|
||||
data = json.loads(json_match.group())
|
||||
entities = [
|
||||
EntityExtractionResult(
|
||||
name=e["name"],
|
||||
type=e.get("type", "OTHER"),
|
||||
definition=e.get("definition", ""),
|
||||
confidence=e.get("confidence", 0.8),
|
||||
name = e["name"],
|
||||
type = e.get("type", "OTHER"),
|
||||
definition = e.get("definition", ""),
|
||||
confidence = e.get("confidence", 0.8),
|
||||
)
|
||||
for e in data.get("entities", [])
|
||||
]
|
||||
relations = [
|
||||
relations = [
|
||||
RelationExtractionResult(
|
||||
source=r["source"],
|
||||
target=r["target"],
|
||||
type=r.get("type", "related"),
|
||||
confidence=r.get("confidence", 0.8),
|
||||
source = r["source"],
|
||||
target = r["target"],
|
||||
type = r.get("type", "related"),
|
||||
confidence = r.get("confidence", 0.8),
|
||||
)
|
||||
for r in data.get("relations", [])
|
||||
]
|
||||
@@ -173,10 +173,10 @@ class LLMClient:
|
||||
|
||||
async def rag_query(self, query: str, context: str, project_context: dict) -> str:
|
||||
"""RAG 问答 - 基于项目上下文回答问题"""
|
||||
prompt = f"""你是一个专业的项目分析助手。基于以下项目信息回答问题:
|
||||
prompt = f"""你是一个专业的项目分析助手。基于以下项目信息回答问题:
|
||||
|
||||
## 项目信息
|
||||
{json.dumps(project_context, ensure_ascii=False, indent=2)}
|
||||
{json.dumps(project_context, ensure_ascii = False, indent = 2)}
|
||||
|
||||
## 相关上下文
|
||||
{context[:4000]}
|
||||
@@ -186,21 +186,21 @@ class LLMClient:
|
||||
|
||||
请用中文回答,保持简洁专业。如果信息不足,请明确说明。"""
|
||||
|
||||
messages = [
|
||||
messages = [
|
||||
ChatMessage(
|
||||
role="system", content="你是一个专业的项目分析助手,擅长从会议记录中提取洞察。"
|
||||
role = "system", content = "你是一个专业的项目分析助手,擅长从会议记录中提取洞察。"
|
||||
),
|
||||
ChatMessage(role="user", content=prompt),
|
||||
ChatMessage(role = "user", content = prompt),
|
||||
]
|
||||
|
||||
return await self.chat(messages, temperature=0.3)
|
||||
return await self.chat(messages, temperature = 0.3)
|
||||
|
||||
async def agent_command(self, command: str, project_context: dict) -> dict:
|
||||
"""Agent 指令解析 - 将自然语言指令转换为结构化操作"""
|
||||
prompt = f"""解析以下用户指令,转换为结构化操作:
|
||||
prompt = f"""解析以下用户指令,转换为结构化操作:
|
||||
|
||||
## 项目信息
|
||||
{json.dumps(project_context, ensure_ascii=False, indent=2)}
|
||||
{json.dumps(project_context, ensure_ascii = False, indent = 2)}
|
||||
|
||||
## 用户指令
|
||||
{command}
|
||||
@@ -221,10 +221,10 @@ class LLMClient:
|
||||
- create_relation: 创建关系,params 包含 source(源实体), target(目标实体), relation_type(关系类型)
|
||||
"""
|
||||
|
||||
messages = [ChatMessage(role="user", content=prompt)]
|
||||
content = await self.chat(messages, temperature=0.1)
|
||||
messages = [ChatMessage(role = "user", content = prompt)]
|
||||
content = await self.chat(messages, temperature = 0.1)
|
||||
|
||||
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
|
||||
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
|
||||
if not json_match:
|
||||
return {"intent": "unknown", "explanation": "无法解析指令"}
|
||||
|
||||
@@ -235,14 +235,14 @@ class LLMClient:
|
||||
|
||||
async def analyze_entity_evolution(self, entity_name: str, mentions: list[dict]) -> str:
|
||||
"""分析实体在项目中的演变/态度变化"""
|
||||
mentions_text = "\n".join(
|
||||
mentions_text = "\n".join(
|
||||
[
|
||||
f"[{m.get('created_at', '未知时间')}] {m.get('text_snippet', '')}"
|
||||
for m in mentions[:20]
|
||||
] # 限制数量
|
||||
)
|
||||
|
||||
prompt = f"""分析实体 "{entity_name}" 在项目中的演变和态度变化:
|
||||
prompt = f"""分析实体 "{entity_name}" 在项目中的演变和态度变化:
|
||||
|
||||
## 提及记录
|
||||
{mentions_text}
|
||||
@@ -255,16 +255,16 @@ class LLMClient:
|
||||
|
||||
用中文回答,结构清晰。"""
|
||||
|
||||
messages = [ChatMessage(role="user", content=prompt)]
|
||||
return await self.chat(messages, temperature=0.3)
|
||||
messages = [ChatMessage(role = "user", content = prompt)]
|
||||
return await self.chat(messages, temperature = 0.3)
|
||||
|
||||
|
||||
# Singleton instance
|
||||
_llm_client = None
|
||||
_llm_client = None
|
||||
|
||||
|
||||
def get_llm_client() -> LLMClient:
|
||||
global _llm_client
|
||||
if _llm_client is None:
|
||||
_llm_client = LLMClient()
|
||||
_llm_client = LLMClient()
|
||||
return _llm_client
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -9,13 +9,13 @@ from dataclasses import dataclass
|
||||
from difflib import SequenceMatcher
|
||||
|
||||
# Constants
|
||||
UUID_LENGTH = 8 # UUID 截断长度
|
||||
UUID_LENGTH = 8 # UUID 截断长度
|
||||
|
||||
# 尝试导入embedding库
|
||||
try:
|
||||
NUMPY_AVAILABLE = True
|
||||
NUMPY_AVAILABLE = True
|
||||
except ImportError:
|
||||
NUMPY_AVAILABLE = False
|
||||
NUMPY_AVAILABLE = False
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -30,11 +30,11 @@ class MultimodalEntity:
|
||||
source_id: str
|
||||
mention_context: str
|
||||
confidence: float
|
||||
modality_features: dict = None # 模态特定特征
|
||||
modality_features: dict = None # 模态特定特征
|
||||
|
||||
def __post_init__(self):
|
||||
def __post_init__(self) -> None:
|
||||
if self.modality_features is None:
|
||||
self.modality_features = {}
|
||||
self.modality_features = {}
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -78,7 +78,7 @@ class MultimodalEntityLinker:
|
||||
"""多模态实体关联器 - 跨模态实体对齐和知识融合"""
|
||||
|
||||
# 关联类型
|
||||
LINK_TYPES = {
|
||||
LINK_TYPES = {
|
||||
"same_as": "同一实体",
|
||||
"related_to": "相关实体",
|
||||
"part_of": "组成部分",
|
||||
@@ -86,16 +86,16 @@ class MultimodalEntityLinker:
|
||||
}
|
||||
|
||||
# 模态类型
|
||||
MODALITIES = ["audio", "video", "image", "document"]
|
||||
MODALITIES = ["audio", "video", "image", "document"]
|
||||
|
||||
def __init__(self, similarity_threshold: float = 0.85) -> None:
|
||||
def __init__(self, similarity_threshold: float = 0.85) -> None:
|
||||
"""
|
||||
初始化多模态实体关联器
|
||||
|
||||
Args:
|
||||
similarity_threshold: 相似度阈值
|
||||
"""
|
||||
self.similarity_threshold = similarity_threshold
|
||||
self.similarity_threshold = similarity_threshold
|
||||
|
||||
def calculate_string_similarity(self, s1: str, s2: str) -> float:
|
||||
"""
|
||||
@@ -111,7 +111,7 @@ class MultimodalEntityLinker:
|
||||
if not s1 or not s2:
|
||||
return 0.0
|
||||
|
||||
s1, s2 = s1.lower().strip(), s2.lower().strip()
|
||||
s1, s2 = s1.lower().strip(), s2.lower().strip()
|
||||
|
||||
# 完全匹配
|
||||
if s1 == s2:
|
||||
@@ -136,7 +136,7 @@ class MultimodalEntityLinker:
|
||||
(相似度, 匹配类型)
|
||||
"""
|
||||
# 名称相似度
|
||||
name_sim = self.calculate_string_similarity(
|
||||
name_sim = self.calculate_string_similarity(
|
||||
entity1.get("name", ""), entity2.get("name", "")
|
||||
)
|
||||
|
||||
@@ -145,8 +145,8 @@ class MultimodalEntityLinker:
|
||||
return 1.0, "exact"
|
||||
|
||||
# 检查别名
|
||||
aliases1 = set(a.lower() for a in entity1.get("aliases", []))
|
||||
aliases2 = set(a.lower() for a in entity2.get("aliases", []))
|
||||
aliases1 = set(a.lower() for a in entity1.get("aliases", []))
|
||||
aliases2 = set(a.lower() for a in entity2.get("aliases", []))
|
||||
|
||||
if aliases1 & aliases2: # 有共同别名
|
||||
return 0.95, "alias_match"
|
||||
@@ -157,12 +157,12 @@ class MultimodalEntityLinker:
|
||||
return 0.95, "alias_match"
|
||||
|
||||
# 定义相似度
|
||||
def_sim = self.calculate_string_similarity(
|
||||
def_sim = self.calculate_string_similarity(
|
||||
entity1.get("definition", ""), entity2.get("definition", "")
|
||||
)
|
||||
|
||||
# 综合相似度
|
||||
combined_sim = name_sim * 0.7 + def_sim * 0.3
|
||||
combined_sim = name_sim * 0.7 + def_sim * 0.3
|
||||
|
||||
if combined_sim >= self.similarity_threshold:
|
||||
return combined_sim, "fuzzy"
|
||||
@@ -170,7 +170,7 @@ class MultimodalEntityLinker:
|
||||
return combined_sim, "none"
|
||||
|
||||
def find_matching_entity(
|
||||
self, query_entity: dict, candidate_entities: list[dict], exclude_ids: set[str] = None
|
||||
self, query_entity: dict, candidate_entities: list[dict], exclude_ids: set[str] = None
|
||||
) -> AlignmentResult | None:
|
||||
"""
|
||||
在候选实体中查找匹配的实体
|
||||
@@ -183,28 +183,28 @@ class MultimodalEntityLinker:
|
||||
Returns:
|
||||
对齐结果
|
||||
"""
|
||||
exclude_ids = exclude_ids or set()
|
||||
best_match = None
|
||||
best_similarity = 0.0
|
||||
exclude_ids = exclude_ids or set()
|
||||
best_match = None
|
||||
best_similarity = 0.0
|
||||
|
||||
for candidate in candidate_entities:
|
||||
if candidate.get("id") in exclude_ids:
|
||||
continue
|
||||
|
||||
similarity, match_type = self.calculate_entity_similarity(query_entity, candidate)
|
||||
similarity, match_type = self.calculate_entity_similarity(query_entity, candidate)
|
||||
|
||||
if similarity > best_similarity and similarity >= self.similarity_threshold:
|
||||
best_similarity = similarity
|
||||
best_match = candidate
|
||||
best_match_type = match_type
|
||||
best_similarity = similarity
|
||||
best_match = candidate
|
||||
best_match_type = match_type
|
||||
|
||||
if best_match:
|
||||
return AlignmentResult(
|
||||
entity_id=query_entity.get("id"),
|
||||
matched_entity_id=best_match.get("id"),
|
||||
similarity=best_similarity,
|
||||
match_type=best_match_type,
|
||||
confidence=best_similarity,
|
||||
entity_id = query_entity.get("id"),
|
||||
matched_entity_id = best_match.get("id"),
|
||||
similarity = best_similarity,
|
||||
match_type = best_match_type,
|
||||
confidence = best_similarity,
|
||||
)
|
||||
|
||||
return None
|
||||
@@ -230,10 +230,10 @@ class MultimodalEntityLinker:
|
||||
Returns:
|
||||
实体关联列表
|
||||
"""
|
||||
links = []
|
||||
links = []
|
||||
|
||||
# 合并所有实体
|
||||
all_entities = {
|
||||
all_entities = {
|
||||
"audio": audio_entities,
|
||||
"video": video_entities,
|
||||
"image": image_entities,
|
||||
@@ -246,24 +246,24 @@ class MultimodalEntityLinker:
|
||||
if mod1 >= mod2: # 避免重复比较
|
||||
continue
|
||||
|
||||
entities1 = all_entities.get(mod1, [])
|
||||
entities2 = all_entities.get(mod2, [])
|
||||
entities1 = all_entities.get(mod1, [])
|
||||
entities2 = all_entities.get(mod2, [])
|
||||
|
||||
for ent1 in entities1:
|
||||
# 在另一个模态中查找匹配
|
||||
result = self.find_matching_entity(ent1, entities2)
|
||||
result = self.find_matching_entity(ent1, entities2)
|
||||
|
||||
if result and result.matched_entity_id:
|
||||
link = EntityLink(
|
||||
id=str(uuid.uuid4())[:UUID_LENGTH],
|
||||
project_id=project_id,
|
||||
source_entity_id=ent1.get("id"),
|
||||
target_entity_id=result.matched_entity_id,
|
||||
link_type="same_as" if result.similarity > 0.95 else "related_to",
|
||||
source_modality=mod1,
|
||||
target_modality=mod2,
|
||||
confidence=result.confidence,
|
||||
evidence=f"Cross-modal alignment: {result.match_type}",
|
||||
link = EntityLink(
|
||||
id = str(uuid.uuid4())[:UUID_LENGTH],
|
||||
project_id = project_id,
|
||||
source_entity_id = ent1.get("id"),
|
||||
target_entity_id = result.matched_entity_id,
|
||||
link_type = "same_as" if result.similarity > 0.95 else "related_to",
|
||||
source_modality = mod1,
|
||||
target_modality = mod2,
|
||||
confidence = result.confidence,
|
||||
evidence = f"Cross-modal alignment: {result.match_type}",
|
||||
)
|
||||
links.append(link)
|
||||
|
||||
@@ -284,7 +284,7 @@ class MultimodalEntityLinker:
|
||||
融合结果
|
||||
"""
|
||||
# 收集所有属性
|
||||
fused_properties = {
|
||||
fused_properties = {
|
||||
"names": set(),
|
||||
"definitions": [],
|
||||
"aliases": set(),
|
||||
@@ -293,7 +293,7 @@ class MultimodalEntityLinker:
|
||||
"contexts": [],
|
||||
}
|
||||
|
||||
merged_ids = []
|
||||
merged_ids = []
|
||||
|
||||
for entity in linked_entities:
|
||||
merged_ids.append(entity.get("id"))
|
||||
@@ -318,21 +318,21 @@ class MultimodalEntityLinker:
|
||||
fused_properties["contexts"].append(mention.get("mention_context"))
|
||||
|
||||
# 选择最佳定义(最长的那个)
|
||||
best_definition = (
|
||||
max(fused_properties["definitions"], key=len) if fused_properties["definitions"] else ""
|
||||
best_definition = (
|
||||
max(fused_properties["definitions"], key = len) if fused_properties["definitions"] else ""
|
||||
)
|
||||
|
||||
# 选择最佳名称(最常见的那个)
|
||||
from collections import Counter
|
||||
|
||||
name_counts = Counter(fused_properties["names"])
|
||||
best_name = name_counts.most_common(1)[0][0] if name_counts else ""
|
||||
name_counts = Counter(fused_properties["names"])
|
||||
best_name = name_counts.most_common(1)[0][0] if name_counts else ""
|
||||
|
||||
# 构建融合结果
|
||||
return FusionResult(
|
||||
canonical_entity_id=entity_id,
|
||||
merged_entity_ids=merged_ids,
|
||||
fused_properties={
|
||||
canonical_entity_id = entity_id,
|
||||
merged_entity_ids = merged_ids,
|
||||
fused_properties = {
|
||||
"name": best_name,
|
||||
"definition": best_definition,
|
||||
"aliases": list(fused_properties["aliases"]),
|
||||
@@ -340,8 +340,8 @@ class MultimodalEntityLinker:
|
||||
"modalities": list(fused_properties["modalities"]),
|
||||
"contexts": fused_properties["contexts"][:10], # 最多10个上下文
|
||||
},
|
||||
source_modalities=list(fused_properties["modalities"]),
|
||||
confidence=min(1.0, len(linked_entities) * 0.2 + 0.5),
|
||||
source_modalities = list(fused_properties["modalities"]),
|
||||
confidence = min(1.0, len(linked_entities) * 0.2 + 0.5),
|
||||
)
|
||||
|
||||
def detect_entity_conflicts(self, entities: list[dict]) -> list[dict]:
|
||||
@@ -354,30 +354,30 @@ class MultimodalEntityLinker:
|
||||
Returns:
|
||||
冲突列表
|
||||
"""
|
||||
conflicts = []
|
||||
conflicts = []
|
||||
|
||||
# 按名称分组
|
||||
name_groups = {}
|
||||
name_groups = {}
|
||||
for entity in entities:
|
||||
name = entity.get("name", "").lower()
|
||||
name = entity.get("name", "").lower()
|
||||
if name:
|
||||
if name not in name_groups:
|
||||
name_groups[name] = []
|
||||
name_groups[name] = []
|
||||
name_groups[name].append(entity)
|
||||
|
||||
# 检测同名但定义不同的实体
|
||||
for name, group in name_groups.items():
|
||||
if len(group) > 1:
|
||||
# 检查定义是否相似
|
||||
definitions = [e.get("definition", "") for e in group if e.get("definition")]
|
||||
definitions = [e.get("definition", "") for e in group if e.get("definition")]
|
||||
|
||||
if len(definitions) > 1:
|
||||
# 计算定义之间的相似度
|
||||
sim_matrix = []
|
||||
sim_matrix = []
|
||||
for i, d1 in enumerate(definitions):
|
||||
for j, d2 in enumerate(definitions):
|
||||
if i < j:
|
||||
sim = self.calculate_string_similarity(d1, d2)
|
||||
sim = self.calculate_string_similarity(d1, d2)
|
||||
sim_matrix.append(sim)
|
||||
|
||||
# 如果定义相似度都很低,可能是冲突
|
||||
@@ -394,7 +394,7 @@ class MultimodalEntityLinker:
|
||||
return conflicts
|
||||
|
||||
def suggest_entity_merges(
|
||||
self, entities: list[dict], existing_links: list[EntityLink] = None
|
||||
self, entities: list[dict], existing_links: list[EntityLink] = None
|
||||
) -> list[dict]:
|
||||
"""
|
||||
建议实体合并
|
||||
@@ -406,13 +406,13 @@ class MultimodalEntityLinker:
|
||||
Returns:
|
||||
合并建议列表
|
||||
"""
|
||||
suggestions = []
|
||||
existing_pairs = set()
|
||||
suggestions = []
|
||||
existing_pairs = set()
|
||||
|
||||
# 记录已有的关联
|
||||
if existing_links:
|
||||
for link in existing_links:
|
||||
pair = tuple(sorted([link.source_entity_id, link.target_entity_id]))
|
||||
pair = tuple(sorted([link.source_entity_id, link.target_entity_id]))
|
||||
existing_pairs.add(pair)
|
||||
|
||||
# 检查所有实体对
|
||||
@@ -422,12 +422,12 @@ class MultimodalEntityLinker:
|
||||
continue
|
||||
|
||||
# 检查是否已有关联
|
||||
pair = tuple(sorted([ent1.get("id"), ent2.get("id")]))
|
||||
pair = tuple(sorted([ent1.get("id"), ent2.get("id")]))
|
||||
if pair in existing_pairs:
|
||||
continue
|
||||
|
||||
# 计算相似度
|
||||
similarity, match_type = self.calculate_entity_similarity(ent1, ent2)
|
||||
similarity, match_type = self.calculate_entity_similarity(ent1, ent2)
|
||||
|
||||
if similarity >= self.similarity_threshold:
|
||||
suggestions.append(
|
||||
@@ -441,7 +441,7 @@ class MultimodalEntityLinker:
|
||||
)
|
||||
|
||||
# 按相似度排序
|
||||
suggestions.sort(key=lambda x: x["similarity"], reverse=True)
|
||||
suggestions.sort(key = lambda x: x["similarity"], reverse = True)
|
||||
|
||||
return suggestions
|
||||
|
||||
@@ -451,8 +451,8 @@ class MultimodalEntityLinker:
|
||||
entity_id: str,
|
||||
source_type: str,
|
||||
source_id: str,
|
||||
mention_context: str = "",
|
||||
confidence: float = 1.0,
|
||||
mention_context: str = "",
|
||||
confidence: float = 1.0,
|
||||
) -> MultimodalEntity:
|
||||
"""
|
||||
创建多模态实体记录
|
||||
@@ -469,14 +469,14 @@ class MultimodalEntityLinker:
|
||||
多模态实体记录
|
||||
"""
|
||||
return MultimodalEntity(
|
||||
id=str(uuid.uuid4())[:UUID_LENGTH],
|
||||
entity_id=entity_id,
|
||||
project_id=project_id,
|
||||
name="", # 将在后续填充
|
||||
source_type=source_type,
|
||||
source_id=source_id,
|
||||
mention_context=mention_context,
|
||||
confidence=confidence,
|
||||
id = str(uuid.uuid4())[:UUID_LENGTH],
|
||||
entity_id = entity_id,
|
||||
project_id = project_id,
|
||||
name = "", # 将在后续填充
|
||||
source_type = source_type,
|
||||
source_id = source_id,
|
||||
mention_context = mention_context,
|
||||
confidence = confidence,
|
||||
)
|
||||
|
||||
def analyze_modality_distribution(self, multimodal_entities: list[MultimodalEntity]) -> dict:
|
||||
@@ -489,7 +489,7 @@ class MultimodalEntityLinker:
|
||||
Returns:
|
||||
模态分布统计
|
||||
"""
|
||||
distribution = {mod: 0 for mod in self.MODALITIES}
|
||||
distribution = {mod: 0 for mod in self.MODALITIES}
|
||||
|
||||
# 统计每个模态的实体数
|
||||
for me in multimodal_entities:
|
||||
@@ -497,13 +497,13 @@ class MultimodalEntityLinker:
|
||||
distribution[me.source_type] += 1
|
||||
|
||||
# 统计跨模态实体
|
||||
entity_modalities = {}
|
||||
entity_modalities = {}
|
||||
for me in multimodal_entities:
|
||||
if me.entity_id not in entity_modalities:
|
||||
entity_modalities[me.entity_id] = set()
|
||||
entity_modalities[me.entity_id] = set()
|
||||
entity_modalities[me.entity_id].add(me.source_type)
|
||||
|
||||
cross_modal_count = sum(1 for mods in entity_modalities.values() if len(mods) > 1)
|
||||
cross_modal_count = sum(1 for mods in entity_modalities.values() if len(mods) > 1)
|
||||
|
||||
return {
|
||||
"modality_distribution": distribution,
|
||||
@@ -517,12 +517,12 @@ class MultimodalEntityLinker:
|
||||
|
||||
|
||||
# Singleton instance
|
||||
_multimodal_entity_linker = None
|
||||
_multimodal_entity_linker = None
|
||||
|
||||
|
||||
def get_multimodal_entity_linker(similarity_threshold: float = 0.85) -> MultimodalEntityLinker:
|
||||
def get_multimodal_entity_linker(similarity_threshold: float = 0.85) -> MultimodalEntityLinker:
|
||||
"""获取多模态实体关联器单例"""
|
||||
global _multimodal_entity_linker
|
||||
if _multimodal_entity_linker is None:
|
||||
_multimodal_entity_linker = MultimodalEntityLinker(similarity_threshold)
|
||||
_multimodal_entity_linker = MultimodalEntityLinker(similarity_threshold)
|
||||
return _multimodal_entity_linker
|
||||
|
||||
@@ -13,30 +13,30 @@ from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
# Constants
|
||||
UUID_LENGTH = 8 # UUID 截断长度
|
||||
UUID_LENGTH = 8 # UUID 截断长度
|
||||
|
||||
# 尝试导入OCR库
|
||||
try:
|
||||
import pytesseract
|
||||
from PIL import Image
|
||||
|
||||
PYTESSERACT_AVAILABLE = True
|
||||
PYTESSERACT_AVAILABLE = True
|
||||
except ImportError:
|
||||
PYTESSERACT_AVAILABLE = False
|
||||
PYTESSERACT_AVAILABLE = False
|
||||
|
||||
try:
|
||||
import cv2
|
||||
|
||||
CV2_AVAILABLE = True
|
||||
CV2_AVAILABLE = True
|
||||
except ImportError:
|
||||
CV2_AVAILABLE = False
|
||||
CV2_AVAILABLE = False
|
||||
|
||||
try:
|
||||
import ffmpeg
|
||||
|
||||
FFMPEG_AVAILABLE = True
|
||||
FFMPEG_AVAILABLE = True
|
||||
except ImportError:
|
||||
FFMPEG_AVAILABLE = False
|
||||
FFMPEG_AVAILABLE = False
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -48,13 +48,13 @@ class VideoFrame:
|
||||
frame_number: int
|
||||
timestamp: float
|
||||
frame_path: str
|
||||
ocr_text: str = ""
|
||||
ocr_confidence: float = 0.0
|
||||
entities_detected: list[dict] = None
|
||||
ocr_text: str = ""
|
||||
ocr_confidence: float = 0.0
|
||||
entities_detected: list[dict] = None
|
||||
|
||||
def __post_init__(self):
|
||||
def __post_init__(self) -> None:
|
||||
if self.entities_detected is None:
|
||||
self.entities_detected = []
|
||||
self.entities_detected = []
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -65,20 +65,20 @@ class VideoInfo:
|
||||
project_id: str
|
||||
filename: str
|
||||
file_path: str
|
||||
duration: float = 0.0
|
||||
width: int = 0
|
||||
height: int = 0
|
||||
fps: float = 0.0
|
||||
audio_extracted: bool = False
|
||||
audio_path: str = ""
|
||||
transcript_id: str = ""
|
||||
status: str = "pending"
|
||||
error_message: str = ""
|
||||
metadata: dict = None
|
||||
duration: float = 0.0
|
||||
width: int = 0
|
||||
height: int = 0
|
||||
fps: float = 0.0
|
||||
audio_extracted: bool = False
|
||||
audio_path: str = ""
|
||||
transcript_id: str = ""
|
||||
status: str = "pending"
|
||||
error_message: str = ""
|
||||
metadata: dict = None
|
||||
|
||||
def __post_init__(self):
|
||||
def __post_init__(self) -> None:
|
||||
if self.metadata is None:
|
||||
self.metadata = {}
|
||||
self.metadata = {}
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -91,13 +91,13 @@ class VideoProcessingResult:
|
||||
ocr_results: list[dict]
|
||||
full_text: str # 整合的文本(音频转录 + OCR文本)
|
||||
success: bool
|
||||
error_message: str = ""
|
||||
error_message: str = ""
|
||||
|
||||
|
||||
class MultimodalProcessor:
|
||||
"""多模态处理器 - 处理视频文件"""
|
||||
|
||||
def __init__(self, temp_dir: str = None, frame_interval: int = 5) -> None:
|
||||
def __init__(self, temp_dir: str = None, frame_interval: int = 5) -> None:
|
||||
"""
|
||||
初始化多模态处理器
|
||||
|
||||
@@ -105,16 +105,16 @@ class MultimodalProcessor:
|
||||
temp_dir: 临时文件目录
|
||||
frame_interval: 关键帧提取间隔(秒)
|
||||
"""
|
||||
self.temp_dir = temp_dir or tempfile.gettempdir()
|
||||
self.frame_interval = frame_interval
|
||||
self.video_dir = os.path.join(self.temp_dir, "videos")
|
||||
self.frames_dir = os.path.join(self.temp_dir, "frames")
|
||||
self.audio_dir = os.path.join(self.temp_dir, "audio")
|
||||
self.temp_dir = temp_dir or tempfile.gettempdir()
|
||||
self.frame_interval = frame_interval
|
||||
self.video_dir = os.path.join(self.temp_dir, "videos")
|
||||
self.frames_dir = os.path.join(self.temp_dir, "frames")
|
||||
self.audio_dir = os.path.join(self.temp_dir, "audio")
|
||||
|
||||
# 创建目录
|
||||
os.makedirs(self.video_dir, exist_ok=True)
|
||||
os.makedirs(self.frames_dir, exist_ok=True)
|
||||
os.makedirs(self.audio_dir, exist_ok=True)
|
||||
os.makedirs(self.video_dir, exist_ok = True)
|
||||
os.makedirs(self.frames_dir, exist_ok = True)
|
||||
os.makedirs(self.audio_dir, exist_ok = True)
|
||||
|
||||
def extract_video_info(self, video_path: str) -> dict:
|
||||
"""
|
||||
@@ -128,11 +128,11 @@ class MultimodalProcessor:
|
||||
"""
|
||||
try:
|
||||
if FFMPEG_AVAILABLE:
|
||||
probe = ffmpeg.probe(video_path)
|
||||
video_stream = next(
|
||||
probe = ffmpeg.probe(video_path)
|
||||
video_stream = next(
|
||||
(s for s in probe["streams"] if s["codec_type"] == "video"), None
|
||||
)
|
||||
audio_stream = next(
|
||||
audio_stream = next(
|
||||
(s for s in probe["streams"] if s["codec_type"] == "audio"), None
|
||||
)
|
||||
|
||||
@@ -147,21 +147,21 @@ class MultimodalProcessor:
|
||||
}
|
||||
else:
|
||||
# 使用 ffprobe 命令行
|
||||
cmd = [
|
||||
cmd = [
|
||||
"ffprobe",
|
||||
"-v",
|
||||
"error",
|
||||
"-show_entries",
|
||||
"format=duration,bit_rate",
|
||||
"format = duration, bit_rate",
|
||||
"-show_entries",
|
||||
"stream=width,height,r_frame_rate",
|
||||
"stream = width, height, r_frame_rate",
|
||||
"-of",
|
||||
"json",
|
||||
video_path,
|
||||
]
|
||||
result = subprocess.run(cmd, capture_output=True, text=True)
|
||||
result = subprocess.run(cmd, capture_output = True, text = True)
|
||||
if result.returncode == 0:
|
||||
data = json.loads(result.stdout)
|
||||
data = json.loads(result.stdout)
|
||||
return {
|
||||
"duration": float(data["format"].get("duration", 0)),
|
||||
"width": int(data["streams"][0].get("width", 0)) if data["streams"] else 0,
|
||||
@@ -177,7 +177,7 @@ class MultimodalProcessor:
|
||||
|
||||
return {"duration": 0, "width": 0, "height": 0, "fps": 0, "has_audio": False, "bitrate": 0}
|
||||
|
||||
def extract_audio(self, video_path: str, output_path: str = None) -> str:
|
||||
def extract_audio(self, video_path: str, output_path: str = None) -> str:
|
||||
"""
|
||||
从视频中提取音频
|
||||
|
||||
@@ -189,20 +189,20 @@ class MultimodalProcessor:
|
||||
提取的音频文件路径
|
||||
"""
|
||||
if output_path is None:
|
||||
video_name = Path(video_path).stem
|
||||
output_path = os.path.join(self.audio_dir, f"{video_name}.wav")
|
||||
video_name = Path(video_path).stem
|
||||
output_path = os.path.join(self.audio_dir, f"{video_name}.wav")
|
||||
|
||||
try:
|
||||
if FFMPEG_AVAILABLE:
|
||||
(
|
||||
ffmpeg.input(video_path)
|
||||
.output(output_path, ac=1, ar=16000, vn=None)
|
||||
.output(output_path, ac = 1, ar = 16000, vn = None)
|
||||
.overwrite_output()
|
||||
.run(quiet=True)
|
||||
.run(quiet = True)
|
||||
)
|
||||
else:
|
||||
# 使用命令行 ffmpeg
|
||||
cmd = [
|
||||
cmd = [
|
||||
"ffmpeg",
|
||||
"-i",
|
||||
video_path,
|
||||
@@ -216,14 +216,14 @@ class MultimodalProcessor:
|
||||
"-y",
|
||||
output_path,
|
||||
]
|
||||
subprocess.run(cmd, check=True, capture_output=True)
|
||||
subprocess.run(cmd, check = True, capture_output = True)
|
||||
|
||||
return output_path
|
||||
except Exception as e:
|
||||
print(f"Error extracting audio: {e}")
|
||||
raise
|
||||
|
||||
def extract_keyframes(self, video_path: str, video_id: str, interval: int = None) -> list[str]:
|
||||
def extract_keyframes(self, video_path: str, video_id: str, interval: int = None) -> list[str]:
|
||||
"""
|
||||
从视频中提取关键帧
|
||||
|
||||
@@ -235,31 +235,31 @@ class MultimodalProcessor:
|
||||
Returns:
|
||||
提取的帧文件路径列表
|
||||
"""
|
||||
interval = interval or self.frame_interval
|
||||
frame_paths = []
|
||||
interval = interval or self.frame_interval
|
||||
frame_paths = []
|
||||
|
||||
# 创建帧存储目录
|
||||
video_frames_dir = os.path.join(self.frames_dir, video_id)
|
||||
os.makedirs(video_frames_dir, exist_ok=True)
|
||||
video_frames_dir = os.path.join(self.frames_dir, video_id)
|
||||
os.makedirs(video_frames_dir, exist_ok = True)
|
||||
|
||||
try:
|
||||
if CV2_AVAILABLE:
|
||||
# 使用 OpenCV 提取帧
|
||||
cap = cv2.VideoCapture(video_path)
|
||||
fps = cap.get(cv2.CAP_PROP_FPS)
|
||||
cap = cv2.VideoCapture(video_path)
|
||||
fps = cap.get(cv2.CAP_PROP_FPS)
|
||||
int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
|
||||
frame_interval_frames = int(fps * interval)
|
||||
frame_number = 0
|
||||
frame_interval_frames = int(fps * interval)
|
||||
frame_number = 0
|
||||
|
||||
while True:
|
||||
ret, frame = cap.read()
|
||||
ret, frame = cap.read()
|
||||
if not ret:
|
||||
break
|
||||
|
||||
if frame_number % frame_interval_frames == 0:
|
||||
timestamp = frame_number / fps
|
||||
frame_path = os.path.join(
|
||||
timestamp = frame_number / fps
|
||||
frame_path = os.path.join(
|
||||
video_frames_dir, f"frame_{frame_number:06d}_{timestamp:.2f}.jpg"
|
||||
)
|
||||
cv2.imwrite(frame_path, frame)
|
||||
@@ -271,23 +271,23 @@ class MultimodalProcessor:
|
||||
else:
|
||||
# 使用 ffmpeg 命令行提取帧
|
||||
Path(video_path).stem
|
||||
output_pattern = os.path.join(video_frames_dir, "frame_%06d_%t.jpg")
|
||||
output_pattern = os.path.join(video_frames_dir, "frame_%06d_%t.jpg")
|
||||
|
||||
cmd = [
|
||||
cmd = [
|
||||
"ffmpeg",
|
||||
"-i",
|
||||
video_path,
|
||||
"-vf",
|
||||
f"fps=1/{interval}",
|
||||
f"fps = 1/{interval}",
|
||||
"-frame_pts",
|
||||
"1",
|
||||
"-y",
|
||||
output_pattern,
|
||||
]
|
||||
subprocess.run(cmd, check=True, capture_output=True)
|
||||
subprocess.run(cmd, check = True, capture_output = True)
|
||||
|
||||
# 获取生成的帧文件列表
|
||||
frame_paths = sorted(
|
||||
frame_paths = sorted(
|
||||
[
|
||||
os.path.join(video_frames_dir, f)
|
||||
for f in os.listdir(video_frames_dir)
|
||||
@@ -313,19 +313,19 @@ class MultimodalProcessor:
|
||||
return "", 0.0
|
||||
|
||||
try:
|
||||
image = Image.open(image_path)
|
||||
image = Image.open(image_path)
|
||||
|
||||
# 预处理:转换为灰度图
|
||||
if image.mode != "L":
|
||||
image = image.convert("L")
|
||||
image = image.convert("L")
|
||||
|
||||
# 使用 pytesseract 进行 OCR
|
||||
text = pytesseract.image_to_string(image, lang="chi_sim+eng")
|
||||
text = pytesseract.image_to_string(image, lang = "chi_sim+eng")
|
||||
|
||||
# 获取置信度数据
|
||||
data = pytesseract.image_to_data(image, output_type=pytesseract.Output.DICT)
|
||||
confidences = [int(c) for c in data["conf"] if int(c) > 0]
|
||||
avg_confidence = sum(confidences) / len(confidences) if confidences else 0
|
||||
data = pytesseract.image_to_data(image, output_type = pytesseract.Output.DICT)
|
||||
confidences = [int(c) for c in data["conf"] if int(c) > 0]
|
||||
avg_confidence = sum(confidences) / len(confidences) if confidences else 0
|
||||
|
||||
return text.strip(), avg_confidence / 100.0
|
||||
except Exception as e:
|
||||
@@ -333,7 +333,7 @@ class MultimodalProcessor:
|
||||
return "", 0.0
|
||||
|
||||
def process_video(
|
||||
self, video_data: bytes, filename: str, project_id: str, video_id: str = None
|
||||
self, video_data: bytes, filename: str, project_id: str, video_id: str = None
|
||||
) -> VideoProcessingResult:
|
||||
"""
|
||||
处理视频文件:提取音频、关键帧、OCR
|
||||
@@ -347,48 +347,48 @@ class MultimodalProcessor:
|
||||
Returns:
|
||||
视频处理结果
|
||||
"""
|
||||
video_id = video_id or str(uuid.uuid4())[:UUID_LENGTH]
|
||||
video_id = video_id or str(uuid.uuid4())[:UUID_LENGTH]
|
||||
|
||||
try:
|
||||
# 保存视频文件
|
||||
video_path = os.path.join(self.video_dir, f"{video_id}_{filename}")
|
||||
video_path = os.path.join(self.video_dir, f"{video_id}_{filename}")
|
||||
with open(video_path, "wb") as f:
|
||||
f.write(video_data)
|
||||
|
||||
# 提取视频信息
|
||||
video_info = self.extract_video_info(video_path)
|
||||
video_info = self.extract_video_info(video_path)
|
||||
|
||||
# 提取音频
|
||||
audio_path = ""
|
||||
audio_path = ""
|
||||
if video_info["has_audio"]:
|
||||
audio_path = self.extract_audio(video_path)
|
||||
audio_path = self.extract_audio(video_path)
|
||||
|
||||
# 提取关键帧
|
||||
frame_paths = self.extract_keyframes(video_path, video_id)
|
||||
frame_paths = self.extract_keyframes(video_path, video_id)
|
||||
|
||||
# 对关键帧进行 OCR
|
||||
frames = []
|
||||
ocr_results = []
|
||||
all_ocr_text = []
|
||||
frames = []
|
||||
ocr_results = []
|
||||
all_ocr_text = []
|
||||
|
||||
for i, frame_path in enumerate(frame_paths):
|
||||
# 解析帧信息
|
||||
frame_name = os.path.basename(frame_path)
|
||||
parts = frame_name.replace(".jpg", "").split("_")
|
||||
frame_number = int(parts[1]) if len(parts) > 1 else i
|
||||
timestamp = float(parts[2]) if len(parts) > 2 else i * self.frame_interval
|
||||
frame_name = os.path.basename(frame_path)
|
||||
parts = frame_name.replace(".jpg", "").split("_")
|
||||
frame_number = int(parts[1]) if len(parts) > 1 else i
|
||||
timestamp = float(parts[2]) if len(parts) > 2 else i * self.frame_interval
|
||||
|
||||
# OCR 识别
|
||||
ocr_text, confidence = self.perform_ocr(frame_path)
|
||||
ocr_text, confidence = self.perform_ocr(frame_path)
|
||||
|
||||
frame = VideoFrame(
|
||||
id=str(uuid.uuid4())[:UUID_LENGTH],
|
||||
video_id=video_id,
|
||||
frame_number=frame_number,
|
||||
timestamp=timestamp,
|
||||
frame_path=frame_path,
|
||||
ocr_text=ocr_text,
|
||||
ocr_confidence=confidence,
|
||||
frame = VideoFrame(
|
||||
id = str(uuid.uuid4())[:UUID_LENGTH],
|
||||
video_id = video_id,
|
||||
frame_number = frame_number,
|
||||
timestamp = timestamp,
|
||||
frame_path = frame_path,
|
||||
ocr_text = ocr_text,
|
||||
ocr_confidence = confidence,
|
||||
)
|
||||
frames.append(frame)
|
||||
|
||||
@@ -404,29 +404,29 @@ class MultimodalProcessor:
|
||||
all_ocr_text.append(ocr_text)
|
||||
|
||||
# 整合所有 OCR 文本
|
||||
full_ocr_text = "\n\n".join(all_ocr_text)
|
||||
full_ocr_text = "\n\n".join(all_ocr_text)
|
||||
|
||||
return VideoProcessingResult(
|
||||
video_id=video_id,
|
||||
audio_path=audio_path,
|
||||
frames=frames,
|
||||
ocr_results=ocr_results,
|
||||
full_text=full_ocr_text,
|
||||
success=True,
|
||||
video_id = video_id,
|
||||
audio_path = audio_path,
|
||||
frames = frames,
|
||||
ocr_results = ocr_results,
|
||||
full_text = full_ocr_text,
|
||||
success = True,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
return VideoProcessingResult(
|
||||
video_id=video_id,
|
||||
audio_path="",
|
||||
frames=[],
|
||||
ocr_results=[],
|
||||
full_text="",
|
||||
success=False,
|
||||
error_message=str(e),
|
||||
video_id = video_id,
|
||||
audio_path = "",
|
||||
frames = [],
|
||||
ocr_results = [],
|
||||
full_text = "",
|
||||
success = False,
|
||||
error_message = str(e),
|
||||
)
|
||||
|
||||
def cleanup(self, video_id: str = None) -> None:
|
||||
def cleanup(self, video_id: str = None) -> None:
|
||||
"""
|
||||
清理临时文件
|
||||
|
||||
@@ -438,7 +438,7 @@ class MultimodalProcessor:
|
||||
if video_id:
|
||||
# 清理特定视频的文件
|
||||
for dir_path in [self.video_dir, self.frames_dir, self.audio_dir]:
|
||||
target_dir = (
|
||||
target_dir = (
|
||||
os.path.join(dir_path, video_id) if dir_path == self.frames_dir else dir_path
|
||||
)
|
||||
if os.path.exists(target_dir):
|
||||
@@ -450,16 +450,16 @@ class MultimodalProcessor:
|
||||
for dir_path in [self.video_dir, self.frames_dir, self.audio_dir]:
|
||||
if os.path.exists(dir_path):
|
||||
shutil.rmtree(dir_path)
|
||||
os.makedirs(dir_path, exist_ok=True)
|
||||
os.makedirs(dir_path, exist_ok = True)
|
||||
|
||||
|
||||
# Singleton instance
|
||||
_multimodal_processor = None
|
||||
_multimodal_processor = None
|
||||
|
||||
|
||||
def get_multimodal_processor(temp_dir: str = None, frame_interval: int = 5) -> MultimodalProcessor:
|
||||
def get_multimodal_processor(temp_dir: str = None, frame_interval: int = 5) -> MultimodalProcessor:
|
||||
"""获取多模态处理器单例"""
|
||||
global _multimodal_processor
|
||||
if _multimodal_processor is None:
|
||||
_multimodal_processor = MultimodalProcessor(temp_dir, frame_interval)
|
||||
_multimodal_processor = MultimodalProcessor(temp_dir, frame_interval)
|
||||
return _multimodal_processor
|
||||
|
||||
@@ -10,20 +10,20 @@ import logging
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Neo4j 连接配置
|
||||
NEO4J_URI = os.getenv("NEO4J_URI", "bolt://localhost:7687")
|
||||
NEO4J_USER = os.getenv("NEO4J_USER", "neo4j")
|
||||
NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD", "password")
|
||||
NEO4J_URI = os.getenv("NEO4J_URI", "bolt://localhost:7687")
|
||||
NEO4J_USER = os.getenv("NEO4J_USER", "neo4j")
|
||||
NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD", "password")
|
||||
|
||||
# 延迟导入,避免未安装时出错
|
||||
try:
|
||||
from neo4j import Driver, GraphDatabase
|
||||
|
||||
NEO4J_AVAILABLE = True
|
||||
NEO4J_AVAILABLE = True
|
||||
except ImportError:
|
||||
NEO4J_AVAILABLE = False
|
||||
NEO4J_AVAILABLE = False
|
||||
logger.warning("Neo4j driver not installed. Neo4j features will be disabled.")
|
||||
|
||||
|
||||
@@ -35,15 +35,15 @@ class GraphEntity:
|
||||
project_id: str
|
||||
name: str
|
||||
type: str
|
||||
definition: str = ""
|
||||
aliases: list[str] = None
|
||||
properties: dict = None
|
||||
definition: str = ""
|
||||
aliases: list[str] = None
|
||||
properties: dict = None
|
||||
|
||||
def __post_init__(self):
|
||||
def __post_init__(self) -> None:
|
||||
if self.aliases is None:
|
||||
self.aliases = []
|
||||
self.aliases = []
|
||||
if self.properties is None:
|
||||
self.properties = {}
|
||||
self.properties = {}
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -54,12 +54,12 @@ class GraphRelation:
|
||||
source_id: str
|
||||
target_id: str
|
||||
relation_type: str
|
||||
evidence: str = ""
|
||||
properties: dict = None
|
||||
evidence: str = ""
|
||||
properties: dict = None
|
||||
|
||||
def __post_init__(self):
|
||||
def __post_init__(self) -> None:
|
||||
if self.properties is None:
|
||||
self.properties = {}
|
||||
self.properties = {}
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -69,7 +69,7 @@ class PathResult:
|
||||
nodes: list[dict]
|
||||
relationships: list[dict]
|
||||
length: int
|
||||
total_weight: float = 0.0
|
||||
total_weight: float = 0.0
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -79,7 +79,7 @@ class CommunityResult:
|
||||
community_id: int
|
||||
nodes: list[dict]
|
||||
size: int
|
||||
density: float = 0.0
|
||||
density: float = 0.0
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -89,17 +89,17 @@ class CentralityResult:
|
||||
entity_id: str
|
||||
entity_name: str
|
||||
score: float
|
||||
rank: int = 0
|
||||
rank: int = 0
|
||||
|
||||
|
||||
class Neo4jManager:
|
||||
"""Neo4j 图数据库管理器"""
|
||||
|
||||
def __init__(self, uri: str = None, user: str = None, password: str = None):
|
||||
self.uri = uri or NEO4J_URI
|
||||
self.user = user or NEO4J_USER
|
||||
self.password = password or NEO4J_PASSWORD
|
||||
self._driver: Driver | None = None
|
||||
def __init__(self, uri: str = None, user: str = None, password: str = None) -> None:
|
||||
self.uri = uri or NEO4J_URI
|
||||
self.user = user or NEO4J_USER
|
||||
self.password = password or NEO4J_PASSWORD
|
||||
self._driver: Driver | None = None
|
||||
|
||||
if not NEO4J_AVAILABLE:
|
||||
logger.error("Neo4j driver not available. Please install: pip install neo4j")
|
||||
@@ -113,13 +113,13 @@ class Neo4jManager:
|
||||
return
|
||||
|
||||
try:
|
||||
self._driver = GraphDatabase.driver(self.uri, auth=(self.user, self.password))
|
||||
self._driver = GraphDatabase.driver(self.uri, auth = (self.user, self.password))
|
||||
# 验证连接
|
||||
self._driver.verify_connectivity()
|
||||
logger.info(f"Connected to Neo4j at {self.uri}")
|
||||
except (RuntimeError, ValueError, TypeError) as e:
|
||||
logger.error(f"Failed to connect to Neo4j: {e}")
|
||||
self._driver = None
|
||||
self._driver = None
|
||||
|
||||
def close(self) -> None:
|
||||
"""关闭连接"""
|
||||
@@ -179,7 +179,7 @@ class Neo4jManager:
|
||||
# ==================== 数据同步 ====================
|
||||
|
||||
def sync_project(
|
||||
self, project_id: str, project_name: str, project_description: str = ""
|
||||
self, project_id: str, project_name: str, project_description: str = ""
|
||||
) -> None:
|
||||
"""同步项目节点到 Neo4j"""
|
||||
if not self._driver:
|
||||
@@ -189,13 +189,13 @@ class Neo4jManager:
|
||||
session.run(
|
||||
"""
|
||||
MERGE (p:Project {id: $project_id})
|
||||
SET p.name = $name,
|
||||
p.description = $description,
|
||||
p.updated_at = datetime()
|
||||
SET p.name = $name,
|
||||
p.description = $description,
|
||||
p.updated_at = datetime()
|
||||
""",
|
||||
project_id=project_id,
|
||||
name=project_name,
|
||||
description=project_description,
|
||||
project_id = project_id,
|
||||
name = project_name,
|
||||
description = project_description,
|
||||
)
|
||||
|
||||
def sync_entity(self, entity: GraphEntity) -> None:
|
||||
@@ -208,23 +208,23 @@ class Neo4jManager:
|
||||
session.run(
|
||||
"""
|
||||
MERGE (e:Entity {id: $id})
|
||||
SET e.name = $name,
|
||||
e.type = $type,
|
||||
e.definition = $definition,
|
||||
e.aliases = $aliases,
|
||||
e.properties = $properties,
|
||||
e.updated_at = datetime()
|
||||
SET e.name = $name,
|
||||
e.type = $type,
|
||||
e.definition = $definition,
|
||||
e.aliases = $aliases,
|
||||
e.properties = $properties,
|
||||
e.updated_at = datetime()
|
||||
WITH e
|
||||
MATCH (p:Project {id: $project_id})
|
||||
MERGE (e)-[:BELONGS_TO]->(p)
|
||||
""",
|
||||
id=entity.id,
|
||||
project_id=entity.project_id,
|
||||
name=entity.name,
|
||||
type=entity.type,
|
||||
definition=entity.definition,
|
||||
aliases=json.dumps(entity.aliases),
|
||||
properties=json.dumps(entity.properties),
|
||||
id = entity.id,
|
||||
project_id = entity.project_id,
|
||||
name = entity.name,
|
||||
type = entity.type,
|
||||
definition = entity.definition,
|
||||
aliases = json.dumps(entity.aliases),
|
||||
properties = json.dumps(entity.properties),
|
||||
)
|
||||
|
||||
def sync_entities_batch(self, entities: list[GraphEntity]) -> None:
|
||||
@@ -234,7 +234,7 @@ class Neo4jManager:
|
||||
|
||||
with self._driver.session() as session:
|
||||
# 使用 UNWIND 批量处理
|
||||
entities_data = [
|
||||
entities_data = [
|
||||
{
|
||||
"id": e.id,
|
||||
"project_id": e.project_id,
|
||||
@@ -251,17 +251,17 @@ class Neo4jManager:
|
||||
"""
|
||||
UNWIND $entities AS entity
|
||||
MERGE (e:Entity {id: entity.id})
|
||||
SET e.name = entity.name,
|
||||
e.type = entity.type,
|
||||
e.definition = entity.definition,
|
||||
e.aliases = entity.aliases,
|
||||
e.properties = entity.properties,
|
||||
e.updated_at = datetime()
|
||||
SET e.name = entity.name,
|
||||
e.type = entity.type,
|
||||
e.definition = entity.definition,
|
||||
e.aliases = entity.aliases,
|
||||
e.properties = entity.properties,
|
||||
e.updated_at = datetime()
|
||||
WITH e, entity
|
||||
MATCH (p:Project {id: entity.project_id})
|
||||
MERGE (e)-[:BELONGS_TO]->(p)
|
||||
""",
|
||||
entities=entities_data,
|
||||
entities = entities_data,
|
||||
)
|
||||
|
||||
def sync_relation(self, relation: GraphRelation) -> None:
|
||||
@@ -275,17 +275,17 @@ class Neo4jManager:
|
||||
MATCH (source:Entity {id: $source_id})
|
||||
MATCH (target:Entity {id: $target_id})
|
||||
MERGE (source)-[r:RELATES_TO {id: $id}]->(target)
|
||||
SET r.relation_type = $relation_type,
|
||||
r.evidence = $evidence,
|
||||
r.properties = $properties,
|
||||
r.updated_at = datetime()
|
||||
SET r.relation_type = $relation_type,
|
||||
r.evidence = $evidence,
|
||||
r.properties = $properties,
|
||||
r.updated_at = datetime()
|
||||
""",
|
||||
id=relation.id,
|
||||
source_id=relation.source_id,
|
||||
target_id=relation.target_id,
|
||||
relation_type=relation.relation_type,
|
||||
evidence=relation.evidence,
|
||||
properties=json.dumps(relation.properties),
|
||||
id = relation.id,
|
||||
source_id = relation.source_id,
|
||||
target_id = relation.target_id,
|
||||
relation_type = relation.relation_type,
|
||||
evidence = relation.evidence,
|
||||
properties = json.dumps(relation.properties),
|
||||
)
|
||||
|
||||
def sync_relations_batch(self, relations: list[GraphRelation]) -> None:
|
||||
@@ -294,7 +294,7 @@ class Neo4jManager:
|
||||
return
|
||||
|
||||
with self._driver.session() as session:
|
||||
relations_data = [
|
||||
relations_data = [
|
||||
{
|
||||
"id": r.id,
|
||||
"source_id": r.source_id,
|
||||
@@ -312,12 +312,12 @@ class Neo4jManager:
|
||||
MATCH (source:Entity {id: rel.source_id})
|
||||
MATCH (target:Entity {id: rel.target_id})
|
||||
MERGE (source)-[r:RELATES_TO {id: rel.id}]->(target)
|
||||
SET r.relation_type = rel.relation_type,
|
||||
r.evidence = rel.evidence,
|
||||
r.properties = rel.properties,
|
||||
r.updated_at = datetime()
|
||||
SET r.relation_type = rel.relation_type,
|
||||
r.evidence = rel.evidence,
|
||||
r.properties = rel.properties,
|
||||
r.updated_at = datetime()
|
||||
""",
|
||||
relations=relations_data,
|
||||
relations = relations_data,
|
||||
)
|
||||
|
||||
def delete_entity(self, entity_id: str) -> None:
|
||||
@@ -331,7 +331,7 @@ class Neo4jManager:
|
||||
MATCH (e:Entity {id: $id})
|
||||
DETACH DELETE e
|
||||
""",
|
||||
id=entity_id,
|
||||
id = entity_id,
|
||||
)
|
||||
|
||||
def delete_project(self, project_id: str) -> None:
|
||||
@@ -346,13 +346,13 @@ class Neo4jManager:
|
||||
OPTIONAL MATCH (e:Entity)-[:BELONGS_TO]->(p)
|
||||
DETACH DELETE e, p
|
||||
""",
|
||||
id=project_id,
|
||||
id = project_id,
|
||||
)
|
||||
|
||||
# ==================== 复杂图查询 ====================
|
||||
|
||||
def find_shortest_path(
|
||||
self, source_id: str, target_id: str, max_depth: int = 10
|
||||
self, source_id: str, target_id: str, max_depth: int = 10
|
||||
) -> PathResult | None:
|
||||
"""
|
||||
查找两个实体之间的最短路径
|
||||
@@ -369,31 +369,31 @@ class Neo4jManager:
|
||||
return None
|
||||
|
||||
with self._driver.session() as session:
|
||||
result = session.run(
|
||||
result = session.run(
|
||||
"""
|
||||
MATCH path = shortestPath(
|
||||
MATCH path = shortestPath(
|
||||
(source:Entity {id: $source_id})-[*1..$max_depth]-(target:Entity {id: $target_id})
|
||||
)
|
||||
RETURN path
|
||||
""",
|
||||
source_id=source_id,
|
||||
target_id=target_id,
|
||||
max_depth=max_depth,
|
||||
source_id = source_id,
|
||||
target_id = target_id,
|
||||
max_depth = max_depth,
|
||||
)
|
||||
|
||||
record = result.single()
|
||||
record = result.single()
|
||||
if not record:
|
||||
return None
|
||||
|
||||
path = record["path"]
|
||||
path = record["path"]
|
||||
|
||||
# 提取节点和关系
|
||||
nodes = [
|
||||
nodes = [
|
||||
{"id": node["id"], "name": node["name"], "type": node["type"]}
|
||||
for node in path.nodes
|
||||
]
|
||||
|
||||
relationships = [
|
||||
relationships = [
|
||||
{
|
||||
"source": rel.start_node["id"],
|
||||
"target": rel.end_node["id"],
|
||||
@@ -404,11 +404,11 @@ class Neo4jManager:
|
||||
]
|
||||
|
||||
return PathResult(
|
||||
nodes=nodes, relationships=relationships, length=len(path.relationships)
|
||||
nodes = nodes, relationships = relationships, length = len(path.relationships)
|
||||
)
|
||||
|
||||
def find_all_paths(
|
||||
self, source_id: str, target_id: str, max_depth: int = 5, limit: int = 10
|
||||
self, source_id: str, target_id: str, max_depth: int = 5, limit: int = 10
|
||||
) -> list[PathResult]:
|
||||
"""
|
||||
查找两个实体之间的所有路径
|
||||
@@ -426,29 +426,29 @@ class Neo4jManager:
|
||||
return []
|
||||
|
||||
with self._driver.session() as session:
|
||||
result = session.run(
|
||||
result = session.run(
|
||||
"""
|
||||
MATCH path = (source:Entity {id: $source_id})-[*1..$max_depth]-(target:Entity {id: $target_id})
|
||||
MATCH path = (source:Entity {id: $source_id})-[*1..$max_depth]-(target:Entity {id: $target_id})
|
||||
WHERE source <> target
|
||||
RETURN path
|
||||
LIMIT $limit
|
||||
""",
|
||||
source_id=source_id,
|
||||
target_id=target_id,
|
||||
max_depth=max_depth,
|
||||
limit=limit,
|
||||
source_id = source_id,
|
||||
target_id = target_id,
|
||||
max_depth = max_depth,
|
||||
limit = limit,
|
||||
)
|
||||
|
||||
paths = []
|
||||
paths = []
|
||||
for record in result:
|
||||
path = record["path"]
|
||||
path = record["path"]
|
||||
|
||||
nodes = [
|
||||
nodes = [
|
||||
{"id": node["id"], "name": node["name"], "type": node["type"]}
|
||||
for node in path.nodes
|
||||
]
|
||||
|
||||
relationships = [
|
||||
relationships = [
|
||||
{
|
||||
"source": rel.start_node["id"],
|
||||
"target": rel.end_node["id"],
|
||||
@@ -460,14 +460,14 @@ class Neo4jManager:
|
||||
|
||||
paths.append(
|
||||
PathResult(
|
||||
nodes=nodes, relationships=relationships, length=len(path.relationships)
|
||||
nodes = nodes, relationships = relationships, length = len(path.relationships)
|
||||
)
|
||||
)
|
||||
|
||||
return paths
|
||||
|
||||
def find_neighbors(
|
||||
self, entity_id: str, relation_type: str = None, limit: int = 50
|
||||
self, entity_id: str, relation_type: str = None, limit: int = 50
|
||||
) -> list[dict]:
|
||||
"""
|
||||
查找实体的邻居节点
|
||||
@@ -485,30 +485,30 @@ class Neo4jManager:
|
||||
|
||||
with self._driver.session() as session:
|
||||
if relation_type:
|
||||
result = session.run(
|
||||
result = session.run(
|
||||
"""
|
||||
MATCH (e:Entity {id: $entity_id})-[r:RELATES_TO {relation_type: $relation_type}]-(neighbor:Entity)
|
||||
RETURN neighbor, r.relation_type as rel_type, r.evidence as evidence
|
||||
LIMIT $limit
|
||||
""",
|
||||
entity_id=entity_id,
|
||||
relation_type=relation_type,
|
||||
limit=limit,
|
||||
entity_id = entity_id,
|
||||
relation_type = relation_type,
|
||||
limit = limit,
|
||||
)
|
||||
else:
|
||||
result = session.run(
|
||||
result = session.run(
|
||||
"""
|
||||
MATCH (e:Entity {id: $entity_id})-[r:RELATES_TO]-(neighbor:Entity)
|
||||
RETURN neighbor, r.relation_type as rel_type, r.evidence as evidence
|
||||
LIMIT $limit
|
||||
""",
|
||||
entity_id=entity_id,
|
||||
limit=limit,
|
||||
entity_id = entity_id,
|
||||
limit = limit,
|
||||
)
|
||||
|
||||
neighbors = []
|
||||
neighbors = []
|
||||
for record in result:
|
||||
node = record["neighbor"]
|
||||
node = record["neighbor"]
|
||||
neighbors.append(
|
||||
{
|
||||
"id": node["id"],
|
||||
@@ -536,13 +536,13 @@ class Neo4jManager:
|
||||
return []
|
||||
|
||||
with self._driver.session() as session:
|
||||
result = session.run(
|
||||
result = session.run(
|
||||
"""
|
||||
MATCH (e1:Entity {id: $id1})-[:RELATES_TO]-(common:Entity)-[:RELATES_TO]-(e2:Entity {id: $id2})
|
||||
RETURN DISTINCT common
|
||||
""",
|
||||
id1=entity_id1,
|
||||
id2=entity_id2,
|
||||
id1 = entity_id1,
|
||||
id2 = entity_id2,
|
||||
)
|
||||
|
||||
return [
|
||||
@@ -556,7 +556,7 @@ class Neo4jManager:
|
||||
|
||||
# ==================== 图算法分析 ====================
|
||||
|
||||
def calculate_pagerank(self, project_id: str, top_n: int = 20) -> list[CentralityResult]:
|
||||
def calculate_pagerank(self, project_id: str, top_n: int = 20) -> list[CentralityResult]:
|
||||
"""
|
||||
计算 PageRank 中心性
|
||||
|
||||
@@ -571,7 +571,7 @@ class Neo4jManager:
|
||||
return []
|
||||
|
||||
with self._driver.session() as session:
|
||||
result = session.run(
|
||||
result = session.run(
|
||||
"""
|
||||
CALL gds.graph.exists('project-graph-$project_id') YIELD exists
|
||||
WITH exists
|
||||
@@ -581,7 +581,7 @@ class Neo4jManager:
|
||||
{}
|
||||
) YIELD value RETURN value
|
||||
""",
|
||||
project_id=project_id,
|
||||
project_id = project_id,
|
||||
)
|
||||
|
||||
# 创建临时图
|
||||
@@ -601,11 +601,11 @@ class Neo4jManager:
|
||||
}
|
||||
)
|
||||
""",
|
||||
project_id=project_id,
|
||||
project_id = project_id,
|
||||
)
|
||||
|
||||
# 运行 PageRank
|
||||
result = session.run(
|
||||
result = session.run(
|
||||
"""
|
||||
CALL gds.pageRank.stream('project-graph-$project_id')
|
||||
YIELD nodeId, score
|
||||
@@ -615,19 +615,19 @@ class Neo4jManager:
|
||||
ORDER BY score DESC
|
||||
LIMIT $top_n
|
||||
""",
|
||||
project_id=project_id,
|
||||
top_n=top_n,
|
||||
project_id = project_id,
|
||||
top_n = top_n,
|
||||
)
|
||||
|
||||
rankings = []
|
||||
rank = 1
|
||||
rankings = []
|
||||
rank = 1
|
||||
for record in result:
|
||||
rankings.append(
|
||||
CentralityResult(
|
||||
entity_id=record["entity_id"],
|
||||
entity_name=record["entity_name"],
|
||||
score=record["score"],
|
||||
rank=rank,
|
||||
entity_id = record["entity_id"],
|
||||
entity_name = record["entity_name"],
|
||||
score = record["score"],
|
||||
rank = rank,
|
||||
)
|
||||
)
|
||||
rank += 1
|
||||
@@ -637,12 +637,12 @@ class Neo4jManager:
|
||||
"""
|
||||
CALL gds.graph.drop('project-graph-$project_id')
|
||||
""",
|
||||
project_id=project_id,
|
||||
project_id = project_id,
|
||||
)
|
||||
|
||||
return rankings
|
||||
|
||||
def calculate_betweenness(self, project_id: str, top_n: int = 20) -> list[CentralityResult]:
|
||||
def calculate_betweenness(self, project_id: str, top_n: int = 20) -> list[CentralityResult]:
|
||||
"""
|
||||
计算 Betweenness 中心性(桥梁作用)
|
||||
|
||||
@@ -658,7 +658,7 @@ class Neo4jManager:
|
||||
|
||||
with self._driver.session() as session:
|
||||
# 使用 APOC 的 betweenness 计算(如果没有 GDS)
|
||||
result = session.run(
|
||||
result = session.run(
|
||||
"""
|
||||
MATCH (e:Entity)-[:BELONGS_TO]->(p:Project {id: $project_id})
|
||||
OPTIONAL MATCH (e)-[:RELATES_TO]-(other:Entity)
|
||||
@@ -667,19 +667,19 @@ class Neo4jManager:
|
||||
LIMIT $top_n
|
||||
RETURN e.id as entity_id, e.name as entity_name, degree as score
|
||||
""",
|
||||
project_id=project_id,
|
||||
top_n=top_n,
|
||||
project_id = project_id,
|
||||
top_n = top_n,
|
||||
)
|
||||
|
||||
rankings = []
|
||||
rank = 1
|
||||
rankings = []
|
||||
rank = 1
|
||||
for record in result:
|
||||
rankings.append(
|
||||
CentralityResult(
|
||||
entity_id=record["entity_id"],
|
||||
entity_name=record["entity_name"],
|
||||
score=float(record["score"]),
|
||||
rank=rank,
|
||||
entity_id = record["entity_id"],
|
||||
entity_name = record["entity_name"],
|
||||
score = float(record["score"]),
|
||||
rank = rank,
|
||||
)
|
||||
)
|
||||
rank += 1
|
||||
@@ -701,7 +701,7 @@ class Neo4jManager:
|
||||
|
||||
with self._driver.session() as session:
|
||||
# 简单的社区检测:基于连通分量
|
||||
result = session.run(
|
||||
result = session.run(
|
||||
"""
|
||||
MATCH (e:Entity)-[:BELONGS_TO]->(p:Project {id: $project_id})
|
||||
OPTIONAL MATCH (e)-[:RELATES_TO]-(other:Entity)-[:BELONGS_TO]->(p)
|
||||
@@ -710,25 +710,25 @@ class Neo4jManager:
|
||||
connections, size(connections) as connection_count
|
||||
ORDER BY connection_count DESC
|
||||
""",
|
||||
project_id=project_id,
|
||||
project_id = project_id,
|
||||
)
|
||||
|
||||
# 手动分组(基于连通性)
|
||||
communities = {}
|
||||
communities = {}
|
||||
for record in result:
|
||||
entity_id = record["entity_id"]
|
||||
connections = record["connections"]
|
||||
entity_id = record["entity_id"]
|
||||
connections = record["connections"]
|
||||
|
||||
# 找到所属的社区
|
||||
found_community = None
|
||||
found_community = None
|
||||
for comm_id, comm_data in communities.items():
|
||||
if any(conn in comm_data["member_ids"] for conn in connections):
|
||||
found_community = comm_id
|
||||
found_community = comm_id
|
||||
break
|
||||
|
||||
if found_community is None:
|
||||
found_community = len(communities)
|
||||
communities[found_community] = {"member_ids": set(), "nodes": []}
|
||||
found_community = len(communities)
|
||||
communities[found_community] = {"member_ids": set(), "nodes": []}
|
||||
|
||||
communities[found_community]["member_ids"].add(entity_id)
|
||||
communities[found_community]["nodes"].append(
|
||||
@@ -741,27 +741,27 @@ class Neo4jManager:
|
||||
)
|
||||
|
||||
# 构建结果
|
||||
results = []
|
||||
results = []
|
||||
for comm_id, comm_data in communities.items():
|
||||
nodes = comm_data["nodes"]
|
||||
size = len(nodes)
|
||||
nodes = comm_data["nodes"]
|
||||
size = len(nodes)
|
||||
# 计算密度(简化版)
|
||||
max_edges = size * (size - 1) / 2 if size > 1 else 1
|
||||
actual_edges = sum(n["connections"] for n in nodes) / 2
|
||||
density = actual_edges / max_edges if max_edges > 0 else 0
|
||||
max_edges = size * (size - 1) / 2 if size > 1 else 1
|
||||
actual_edges = sum(n["connections"] for n in nodes) / 2
|
||||
density = actual_edges / max_edges if max_edges > 0 else 0
|
||||
|
||||
results.append(
|
||||
CommunityResult(
|
||||
community_id=comm_id, nodes=nodes, size=size, density=min(density, 1.0)
|
||||
community_id = comm_id, nodes = nodes, size = size, density = min(density, 1.0)
|
||||
)
|
||||
)
|
||||
|
||||
# 按大小排序
|
||||
results.sort(key=lambda x: x.size, reverse=True)
|
||||
results.sort(key = lambda x: x.size, reverse = True)
|
||||
return results
|
||||
|
||||
def find_central_entities(
|
||||
self, project_id: str, metric: str = "degree"
|
||||
self, project_id: str, metric: str = "degree"
|
||||
) -> list[CentralityResult]:
|
||||
"""
|
||||
查找中心实体
|
||||
@@ -778,7 +778,7 @@ class Neo4jManager:
|
||||
|
||||
with self._driver.session() as session:
|
||||
if metric == "degree":
|
||||
result = session.run(
|
||||
result = session.run(
|
||||
"""
|
||||
MATCH (e:Entity)-[:BELONGS_TO]->(p:Project {id: $project_id})
|
||||
OPTIONAL MATCH (e)-[:RELATES_TO]-(other:Entity)
|
||||
@@ -787,11 +787,11 @@ class Neo4jManager:
|
||||
ORDER BY degree DESC
|
||||
LIMIT 20
|
||||
""",
|
||||
project_id=project_id,
|
||||
project_id = project_id,
|
||||
)
|
||||
else:
|
||||
# 默认使用度中心性
|
||||
result = session.run(
|
||||
result = session.run(
|
||||
"""
|
||||
MATCH (e:Entity)-[:BELONGS_TO]->(p:Project {id: $project_id})
|
||||
OPTIONAL MATCH (e)-[:RELATES_TO]-(other:Entity)
|
||||
@@ -800,18 +800,18 @@ class Neo4jManager:
|
||||
ORDER BY degree DESC
|
||||
LIMIT 20
|
||||
""",
|
||||
project_id=project_id,
|
||||
project_id = project_id,
|
||||
)
|
||||
|
||||
rankings = []
|
||||
rank = 1
|
||||
rankings = []
|
||||
rank = 1
|
||||
for record in result:
|
||||
rankings.append(
|
||||
CentralityResult(
|
||||
entity_id=record["entity_id"],
|
||||
entity_name=record["entity_name"],
|
||||
score=float(record["score"]),
|
||||
rank=rank,
|
||||
entity_id = record["entity_id"],
|
||||
entity_name = record["entity_name"],
|
||||
score = float(record["score"]),
|
||||
rank = rank,
|
||||
)
|
||||
)
|
||||
rank += 1
|
||||
@@ -835,49 +835,49 @@ class Neo4jManager:
|
||||
|
||||
with self._driver.session() as session:
|
||||
# 实体数量
|
||||
entity_count = session.run(
|
||||
entity_count = session.run(
|
||||
"""
|
||||
MATCH (e:Entity)-[:BELONGS_TO]->(p:Project {id: $project_id})
|
||||
RETURN count(e) as count
|
||||
""",
|
||||
project_id=project_id,
|
||||
project_id = project_id,
|
||||
).single()["count"]
|
||||
|
||||
# 关系数量
|
||||
relation_count = session.run(
|
||||
relation_count = session.run(
|
||||
"""
|
||||
MATCH (e:Entity)-[:BELONGS_TO]->(p:Project {id: $project_id})
|
||||
MATCH (e)-[r:RELATES_TO]-()
|
||||
RETURN count(r) as count
|
||||
""",
|
||||
project_id=project_id,
|
||||
project_id = project_id,
|
||||
).single()["count"]
|
||||
|
||||
# 实体类型分布
|
||||
type_distribution = session.run(
|
||||
type_distribution = session.run(
|
||||
"""
|
||||
MATCH (e:Entity)-[:BELONGS_TO]->(p:Project {id: $project_id})
|
||||
RETURN e.type as type, count(e) as count
|
||||
ORDER BY count DESC
|
||||
""",
|
||||
project_id=project_id,
|
||||
project_id = project_id,
|
||||
)
|
||||
|
||||
types = {record["type"]: record["count"] for record in type_distribution}
|
||||
types = {record["type"]: record["count"] for record in type_distribution}
|
||||
|
||||
# 平均度
|
||||
avg_degree = session.run(
|
||||
avg_degree = session.run(
|
||||
"""
|
||||
MATCH (e:Entity)-[:BELONGS_TO]->(p:Project {id: $project_id})
|
||||
OPTIONAL MATCH (e)-[:RELATES_TO]-(other)
|
||||
WITH e, count(other) as degree
|
||||
RETURN avg(degree) as avg_degree
|
||||
""",
|
||||
project_id=project_id,
|
||||
project_id = project_id,
|
||||
).single()["avg_degree"]
|
||||
|
||||
# 关系类型分布
|
||||
rel_types = session.run(
|
||||
rel_types = session.run(
|
||||
"""
|
||||
MATCH (e:Entity)-[:BELONGS_TO]->(p:Project {id: $project_id})
|
||||
MATCH (e)-[r:RELATES_TO]-()
|
||||
@@ -885,10 +885,10 @@ class Neo4jManager:
|
||||
ORDER BY count DESC
|
||||
LIMIT 10
|
||||
""",
|
||||
project_id=project_id,
|
||||
project_id = project_id,
|
||||
)
|
||||
|
||||
relation_types = {record["type"]: record["count"] for record in rel_types}
|
||||
relation_types = {record["type"]: record["count"] for record in rel_types}
|
||||
|
||||
return {
|
||||
"entity_count": entity_count,
|
||||
@@ -901,7 +901,7 @@ class Neo4jManager:
|
||||
else 0,
|
||||
}
|
||||
|
||||
def get_subgraph(self, entity_ids: list[str], depth: int = 1) -> dict:
|
||||
def get_subgraph(self, entity_ids: list[str], depth: int = 1) -> dict:
|
||||
"""
|
||||
获取指定实体的子图
|
||||
|
||||
@@ -916,7 +916,7 @@ class Neo4jManager:
|
||||
return {"nodes": [], "relationships": []}
|
||||
|
||||
with self._driver.session() as session:
|
||||
result = session.run(
|
||||
result = session.run(
|
||||
"""
|
||||
MATCH (e:Entity)
|
||||
WHERE e.id IN $entity_ids
|
||||
@@ -927,14 +927,14 @@ class Neo4jManager:
|
||||
}) YIELD node
|
||||
RETURN DISTINCT node
|
||||
""",
|
||||
entity_ids=entity_ids,
|
||||
depth=depth,
|
||||
entity_ids = entity_ids,
|
||||
depth = depth,
|
||||
)
|
||||
|
||||
nodes = []
|
||||
node_ids = set()
|
||||
nodes = []
|
||||
node_ids = set()
|
||||
for record in result:
|
||||
node = record["node"]
|
||||
node = record["node"]
|
||||
node_ids.add(node["id"])
|
||||
nodes.append(
|
||||
{
|
||||
@@ -946,17 +946,17 @@ class Neo4jManager:
|
||||
)
|
||||
|
||||
# 获取这些节点之间的关系
|
||||
result = session.run(
|
||||
result = session.run(
|
||||
"""
|
||||
MATCH (source:Entity)-[r:RELATES_TO]->(target:Entity)
|
||||
WHERE source.id IN $node_ids AND target.id IN $node_ids
|
||||
RETURN source.id as source_id, target.id as target_id,
|
||||
r.relation_type as type, r.evidence as evidence
|
||||
""",
|
||||
node_ids=list(node_ids),
|
||||
node_ids = list(node_ids),
|
||||
)
|
||||
|
||||
relationships = [
|
||||
relationships = [
|
||||
{
|
||||
"source": record["source_id"],
|
||||
"target": record["target_id"],
|
||||
@@ -970,14 +970,14 @@ class Neo4jManager:
|
||||
|
||||
|
||||
# 全局单例
|
||||
_neo4j_manager = None
|
||||
_neo4j_manager = None
|
||||
|
||||
|
||||
def get_neo4j_manager() -> Neo4jManager:
|
||||
"""获取 Neo4j 管理器单例"""
|
||||
global _neo4j_manager
|
||||
if _neo4j_manager is None:
|
||||
_neo4j_manager = Neo4jManager()
|
||||
_neo4j_manager = Neo4jManager()
|
||||
return _neo4j_manager
|
||||
|
||||
|
||||
@@ -986,7 +986,7 @@ def close_neo4j_manager() -> None:
|
||||
global _neo4j_manager
|
||||
if _neo4j_manager:
|
||||
_neo4j_manager.close()
|
||||
_neo4j_manager = None
|
||||
_neo4j_manager = None
|
||||
|
||||
|
||||
# 便捷函数
|
||||
@@ -1004,7 +1004,7 @@ def sync_project_to_neo4j(
|
||||
entities: 实体列表(字典格式)
|
||||
relations: 关系列表(字典格式)
|
||||
"""
|
||||
manager = get_neo4j_manager()
|
||||
manager = get_neo4j_manager()
|
||||
if not manager.is_connected():
|
||||
logger.warning("Neo4j not connected, skipping sync")
|
||||
return
|
||||
@@ -1013,29 +1013,29 @@ def sync_project_to_neo4j(
|
||||
manager.sync_project(project_id, project_name)
|
||||
|
||||
# 同步实体
|
||||
graph_entities = [
|
||||
graph_entities = [
|
||||
GraphEntity(
|
||||
id=e["id"],
|
||||
project_id=project_id,
|
||||
name=e["name"],
|
||||
type=e.get("type", "unknown"),
|
||||
definition=e.get("definition", ""),
|
||||
aliases=e.get("aliases", []),
|
||||
properties=e.get("properties", {}),
|
||||
id = e["id"],
|
||||
project_id = project_id,
|
||||
name = e["name"],
|
||||
type = e.get("type", "unknown"),
|
||||
definition = e.get("definition", ""),
|
||||
aliases = e.get("aliases", []),
|
||||
properties = e.get("properties", {}),
|
||||
)
|
||||
for e in entities
|
||||
]
|
||||
manager.sync_entities_batch(graph_entities)
|
||||
|
||||
# 同步关系
|
||||
graph_relations = [
|
||||
graph_relations = [
|
||||
GraphRelation(
|
||||
id=r["id"],
|
||||
source_id=r["source_entity_id"],
|
||||
target_id=r["target_entity_id"],
|
||||
relation_type=r["relation_type"],
|
||||
evidence=r.get("evidence", ""),
|
||||
properties=r.get("properties", {}),
|
||||
id = r["id"],
|
||||
source_id = r["source_entity_id"],
|
||||
target_id = r["target_entity_id"],
|
||||
relation_type = r["relation_type"],
|
||||
evidence = r.get("evidence", ""),
|
||||
properties = r.get("properties", {}),
|
||||
)
|
||||
for r in relations
|
||||
]
|
||||
@@ -1048,9 +1048,9 @@ def sync_project_to_neo4j(
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 测试代码
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logging.basicConfig(level = logging.INFO)
|
||||
|
||||
manager = Neo4jManager()
|
||||
manager = Neo4jManager()
|
||||
|
||||
if manager.is_connected():
|
||||
print("✅ Connected to Neo4j")
|
||||
@@ -1064,18 +1064,18 @@ if __name__ == "__main__":
|
||||
print("✅ Project synced")
|
||||
|
||||
# 测试实体
|
||||
test_entity = GraphEntity(
|
||||
id="test-entity-1",
|
||||
project_id="test-project",
|
||||
name="Test Entity",
|
||||
type="Person",
|
||||
definition="A test entity",
|
||||
test_entity = GraphEntity(
|
||||
id = "test-entity-1",
|
||||
project_id = "test-project",
|
||||
name = "Test Entity",
|
||||
type = "Person",
|
||||
definition = "A test entity",
|
||||
)
|
||||
manager.sync_entity(test_entity)
|
||||
print("✅ Entity synced")
|
||||
|
||||
# 获取统计
|
||||
stats = manager.get_graph_stats("test-project")
|
||||
stats = manager.get_graph_stats("test-project")
|
||||
print(f"📊 Graph stats: {stats}")
|
||||
|
||||
else:
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -11,30 +11,30 @@ import oss2
|
||||
|
||||
|
||||
class OSSUploader:
|
||||
def __init__(self):
|
||||
self.access_key = os.getenv("ALI_ACCESS_KEY")
|
||||
self.secret_key = os.getenv("ALI_SECRET_KEY")
|
||||
self.bucket_name = os.getenv("OSS_BUCKET", "insightflow-audio")
|
||||
self.region = os.getenv("OSS_REGION", "oss-cn-hangzhou.aliyuncs.com")
|
||||
self.endpoint = f"https://{self.region}"
|
||||
def __init__(self) -> None:
|
||||
self.access_key = os.getenv("ALI_ACCESS_KEY")
|
||||
self.secret_key = os.getenv("ALI_SECRET_KEY")
|
||||
self.bucket_name = os.getenv("OSS_BUCKET", "insightflow-audio")
|
||||
self.region = os.getenv("OSS_REGION", "oss-cn-hangzhou.aliyuncs.com")
|
||||
self.endpoint = f"https://{self.region}"
|
||||
|
||||
if not self.access_key or not self.secret_key:
|
||||
raise ValueError("ALI_ACCESS_KEY and ALI_SECRET_KEY must be set")
|
||||
|
||||
self.auth = oss2.Auth(self.access_key, self.secret_key)
|
||||
self.bucket = oss2.Bucket(self.auth, self.endpoint, self.bucket_name)
|
||||
self.auth = oss2.Auth(self.access_key, self.secret_key)
|
||||
self.bucket = oss2.Bucket(self.auth, self.endpoint, self.bucket_name)
|
||||
|
||||
def upload_audio(self, audio_data: bytes, filename: str) -> tuple:
|
||||
"""上传音频到 OSS,返回 (URL, object_name)"""
|
||||
# 生成唯一文件名
|
||||
ext = os.path.splitext(filename)[1] or ".wav"
|
||||
object_name = f"audio/{datetime.now().strftime('%Y%m%d')}/{uuid.uuid4().hex}{ext}"
|
||||
ext = os.path.splitext(filename)[1] or ".wav"
|
||||
object_name = f"audio/{datetime.now().strftime('%Y%m%d')}/{uuid.uuid4().hex}{ext}"
|
||||
|
||||
# 上传文件
|
||||
self.bucket.put_object(object_name, audio_data)
|
||||
|
||||
# 生成临时访问 URL (1小时有效)
|
||||
url = self.bucket.sign_url("GET", object_name, 3600)
|
||||
url = self.bucket.sign_url("GET", object_name, 3600)
|
||||
return url, object_name
|
||||
|
||||
def delete_object(self, object_name: str) -> None:
|
||||
@@ -43,11 +43,11 @@ class OSSUploader:
|
||||
|
||||
|
||||
# 单例
|
||||
_oss_uploader = None
|
||||
_oss_uploader = None
|
||||
|
||||
|
||||
def get_oss_uploader() -> OSSUploader:
|
||||
global _oss_uploader
|
||||
if _oss_uploader is None:
|
||||
_oss_uploader = OSSUploader()
|
||||
_oss_uploader = OSSUploader()
|
||||
return _oss_uploader
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -17,9 +17,9 @@ from functools import wraps
|
||||
class RateLimitConfig:
|
||||
"""限流配置"""
|
||||
|
||||
requests_per_minute: int = 60
|
||||
burst_size: int = 10 # 突发请求数
|
||||
window_size: int = 60 # 窗口大小(秒)
|
||||
requests_per_minute: int = 60
|
||||
burst_size: int = 10 # 突发请求数
|
||||
window_size: int = 60 # 窗口大小(秒)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -35,16 +35,16 @@ class RateLimitInfo:
|
||||
class SlidingWindowCounter:
|
||||
"""滑动窗口计数器"""
|
||||
|
||||
def __init__(self, window_size: int = 60):
|
||||
self.window_size = window_size
|
||||
self.requests: dict[int, int] = defaultdict(int) # 秒级计数
|
||||
self._lock = asyncio.Lock()
|
||||
self._cleanup_lock = asyncio.Lock()
|
||||
def __init__(self, window_size: int = 60) -> None:
|
||||
self.window_size = window_size
|
||||
self.requests: dict[int, int] = defaultdict(int) # 秒级计数
|
||||
self._lock = asyncio.Lock()
|
||||
self._cleanup_lock = asyncio.Lock()
|
||||
|
||||
async def add_request(self) -> int:
|
||||
"""添加请求,返回当前窗口内的请求数"""
|
||||
async with self._lock:
|
||||
now = int(time.time())
|
||||
now = int(time.time())
|
||||
self.requests[now] += 1
|
||||
self._cleanup_old(now)
|
||||
return sum(self.requests.values())
|
||||
@@ -52,14 +52,14 @@ class SlidingWindowCounter:
|
||||
async def get_count(self) -> int:
|
||||
"""获取当前窗口内的请求数"""
|
||||
async with self._lock:
|
||||
now = int(time.time())
|
||||
now = int(time.time())
|
||||
self._cleanup_old(now)
|
||||
return sum(self.requests.values())
|
||||
|
||||
def _cleanup_old(self, now: int) -> None:
|
||||
"""清理过期的请求记录 - 使用独立锁避免竞态条件"""
|
||||
cutoff = now - self.window_size
|
||||
old_keys = [k for k in list(self.requests.keys()) if k < cutoff]
|
||||
cutoff = now - self.window_size
|
||||
old_keys = [k for k in list(self.requests.keys()) if k < cutoff]
|
||||
for k in old_keys:
|
||||
self.requests.pop(k, None)
|
||||
|
||||
@@ -69,13 +69,13 @@ class RateLimiter:
|
||||
|
||||
def __init__(self) -> None:
|
||||
# key -> SlidingWindowCounter
|
||||
self.counters: dict[str, SlidingWindowCounter] = {}
|
||||
self.counters: dict[str, SlidingWindowCounter] = {}
|
||||
# key -> RateLimitConfig
|
||||
self.configs: dict[str, RateLimitConfig] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
self._cleanup_lock = asyncio.Lock()
|
||||
self.configs: dict[str, RateLimitConfig] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
self._cleanup_lock = asyncio.Lock()
|
||||
|
||||
async def is_allowed(self, key: str, config: RateLimitConfig | None = None) -> RateLimitInfo:
|
||||
async def is_allowed(self, key: str, config: RateLimitConfig | None = None) -> RateLimitInfo:
|
||||
"""
|
||||
检查是否允许请求
|
||||
|
||||
@@ -87,70 +87,70 @@ class RateLimiter:
|
||||
RateLimitInfo
|
||||
"""
|
||||
if config is None:
|
||||
config = RateLimitConfig()
|
||||
config = RateLimitConfig()
|
||||
|
||||
async with self._lock:
|
||||
if key not in self.counters:
|
||||
self.counters[key] = SlidingWindowCounter(config.window_size)
|
||||
self.configs[key] = config
|
||||
self.counters[key] = SlidingWindowCounter(config.window_size)
|
||||
self.configs[key] = config
|
||||
|
||||
counter = self.counters[key]
|
||||
stored_config = self.configs.get(key, config)
|
||||
counter = self.counters[key]
|
||||
stored_config = self.configs.get(key, config)
|
||||
|
||||
# 获取当前计数
|
||||
current_count = await counter.get_count()
|
||||
current_count = await counter.get_count()
|
||||
|
||||
# 计算剩余配额
|
||||
remaining = max(0, stored_config.requests_per_minute - current_count)
|
||||
remaining = max(0, stored_config.requests_per_minute - current_count)
|
||||
|
||||
# 计算重置时间
|
||||
now = int(time.time())
|
||||
reset_time = now + stored_config.window_size
|
||||
now = int(time.time())
|
||||
reset_time = now + stored_config.window_size
|
||||
|
||||
# 检查是否超过限制
|
||||
if current_count >= stored_config.requests_per_minute:
|
||||
return RateLimitInfo(
|
||||
allowed=False,
|
||||
remaining=0,
|
||||
reset_time=reset_time,
|
||||
retry_after=stored_config.window_size,
|
||||
allowed = False,
|
||||
remaining = 0,
|
||||
reset_time = reset_time,
|
||||
retry_after = stored_config.window_size,
|
||||
)
|
||||
|
||||
# 允许请求,增加计数
|
||||
await counter.add_request()
|
||||
|
||||
return RateLimitInfo(
|
||||
allowed=True, remaining=remaining - 1, reset_time=reset_time, retry_after=0
|
||||
allowed = True, remaining = remaining - 1, reset_time = reset_time, retry_after = 0
|
||||
)
|
||||
|
||||
async def get_limit_info(self, key: str) -> RateLimitInfo:
|
||||
"""获取限流信息(不增加计数)"""
|
||||
if key not in self.counters:
|
||||
config = RateLimitConfig()
|
||||
config = RateLimitConfig()
|
||||
return RateLimitInfo(
|
||||
allowed=True,
|
||||
remaining=config.requests_per_minute,
|
||||
reset_time=int(time.time()) + config.window_size,
|
||||
retry_after=0,
|
||||
allowed = True,
|
||||
remaining = config.requests_per_minute,
|
||||
reset_time = int(time.time()) + config.window_size,
|
||||
retry_after = 0,
|
||||
)
|
||||
|
||||
counter = self.counters[key]
|
||||
config = self.configs.get(key, RateLimitConfig())
|
||||
counter = self.counters[key]
|
||||
config = self.configs.get(key, RateLimitConfig())
|
||||
|
||||
current_count = await counter.get_count()
|
||||
remaining = max(0, config.requests_per_minute - current_count)
|
||||
reset_time = int(time.time()) + config.window_size
|
||||
current_count = await counter.get_count()
|
||||
remaining = max(0, config.requests_per_minute - current_count)
|
||||
reset_time = int(time.time()) + config.window_size
|
||||
|
||||
return RateLimitInfo(
|
||||
allowed=current_count < config.requests_per_minute,
|
||||
remaining=remaining,
|
||||
reset_time=reset_time,
|
||||
retry_after=max(0, config.window_size)
|
||||
allowed = current_count < config.requests_per_minute,
|
||||
remaining = remaining,
|
||||
reset_time = reset_time,
|
||||
retry_after = max(0, config.window_size)
|
||||
if current_count >= config.requests_per_minute
|
||||
else 0,
|
||||
)
|
||||
|
||||
def reset(self, key: str | None = None) -> None:
|
||||
def reset(self, key: str | None = None) -> None:
|
||||
"""重置限流计数器"""
|
||||
if key:
|
||||
self.counters.pop(key, None)
|
||||
@@ -161,21 +161,21 @@ class RateLimiter:
|
||||
|
||||
|
||||
# 全局限流器实例
|
||||
_rate_limiter: RateLimiter | None = None
|
||||
_rate_limiter: RateLimiter | None = None
|
||||
|
||||
|
||||
def get_rate_limiter() -> RateLimiter:
|
||||
"""获取限流器实例"""
|
||||
global _rate_limiter
|
||||
if _rate_limiter is None:
|
||||
_rate_limiter = RateLimiter()
|
||||
_rate_limiter = RateLimiter()
|
||||
return _rate_limiter
|
||||
|
||||
|
||||
# 限流装饰器(用于函数级别限流)
|
||||
|
||||
|
||||
def rate_limit(requests_per_minute: int = 60, key_func: Callable | None = None) -> None:
|
||||
def rate_limit(requests_per_minute: int = 60, key_func: Callable | None = None) -> None:
|
||||
"""
|
||||
限流装饰器
|
||||
|
||||
@@ -184,14 +184,14 @@ def rate_limit(requests_per_minute: int = 60, key_func: Callable | None = None)
|
||||
key_func: 生成限流键的函数,默认为 None(使用函数名)
|
||||
"""
|
||||
|
||||
def decorator(func):
|
||||
limiter = get_rate_limiter()
|
||||
config = RateLimitConfig(requests_per_minute=requests_per_minute)
|
||||
def decorator(func) -> None:
|
||||
limiter = get_rate_limiter()
|
||||
config = RateLimitConfig(requests_per_minute = requests_per_minute)
|
||||
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
key = key_func(*args, **kwargs) if key_func else func.__name__
|
||||
info = await limiter.is_allowed(key, config)
|
||||
async def async_wrapper(*args, **kwargs) -> None:
|
||||
key = key_func(*args, **kwargs) if key_func else func.__name__
|
||||
info = await limiter.is_allowed(key, config)
|
||||
|
||||
if not info.allowed:
|
||||
raise RateLimitExceeded(
|
||||
@@ -201,10 +201,10 @@ def rate_limit(requests_per_minute: int = 60, key_func: Callable | None = None)
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
@wraps(func)
|
||||
def sync_wrapper(*args, **kwargs):
|
||||
key = key_func(*args, **kwargs) if key_func else func.__name__
|
||||
def sync_wrapper(*args, **kwargs) -> None:
|
||||
key = key_func(*args, **kwargs) if key_func else func.__name__
|
||||
# 同步版本使用 asyncio.run
|
||||
info = asyncio.run(limiter.is_allowed(key, config))
|
||||
info = asyncio.run(limiter.is_allowed(key, config))
|
||||
|
||||
if not info.allowed:
|
||||
raise RateLimitExceeded(
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -10,9 +10,9 @@ import sys
|
||||
# 添加 backend 目录到路径
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
print("=" * 60)
|
||||
print(" = " * 60)
|
||||
print("InsightFlow 多模态模块测试")
|
||||
print("=" * 60)
|
||||
print(" = " * 60)
|
||||
|
||||
# 测试导入
|
||||
print("\n1. 测试模块导入...")
|
||||
@@ -42,7 +42,7 @@ except ImportError as e:
|
||||
print("\n2. 测试模块初始化...")
|
||||
|
||||
try:
|
||||
processor = get_multimodal_processor()
|
||||
processor = get_multimodal_processor()
|
||||
print(" ✓ MultimodalProcessor 初始化成功")
|
||||
print(f" - 临时目录: {processor.temp_dir}")
|
||||
print(f" - 帧提取间隔: {processor.frame_interval}秒")
|
||||
@@ -50,14 +50,14 @@ except Exception as e:
|
||||
print(f" ✗ MultimodalProcessor 初始化失败: {e}")
|
||||
|
||||
try:
|
||||
img_processor = get_image_processor()
|
||||
img_processor = get_image_processor()
|
||||
print(" ✓ ImageProcessor 初始化成功")
|
||||
print(f" - 临时目录: {img_processor.temp_dir}")
|
||||
except Exception as e:
|
||||
print(f" ✗ ImageProcessor 初始化失败: {e}")
|
||||
|
||||
try:
|
||||
linker = get_multimodal_entity_linker()
|
||||
linker = get_multimodal_entity_linker()
|
||||
print(" ✓ MultimodalEntityLinker 初始化成功")
|
||||
print(f" - 相似度阈值: {linker.similarity_threshold}")
|
||||
except Exception as e:
|
||||
@@ -67,20 +67,20 @@ except Exception as e:
|
||||
print("\n3. 测试实体关联功能...")
|
||||
|
||||
try:
|
||||
linker = get_multimodal_entity_linker()
|
||||
linker = get_multimodal_entity_linker()
|
||||
|
||||
# 测试字符串相似度
|
||||
sim = linker.calculate_string_similarity("Project Alpha", "Project Alpha")
|
||||
sim = linker.calculate_string_similarity("Project Alpha", "Project Alpha")
|
||||
assert sim == 1.0, "完全匹配应该返回1.0"
|
||||
print(f" ✓ 字符串相似度计算正常 (完全匹配: {sim})")
|
||||
|
||||
sim = linker.calculate_string_similarity("K8s", "Kubernetes")
|
||||
sim = linker.calculate_string_similarity("K8s", "Kubernetes")
|
||||
print(f" ✓ 字符串相似度计算正常 (不同字符串: {sim:.2f})")
|
||||
|
||||
# 测试实体相似度
|
||||
entity1 = {"name": "Project Alpha", "type": "PROJECT", "definition": "核心项目"}
|
||||
entity2 = {"name": "Project Alpha", "type": "PROJECT", "definition": "主要项目"}
|
||||
sim, match_type = linker.calculate_entity_similarity(entity1, entity2)
|
||||
entity1 = {"name": "Project Alpha", "type": "PROJECT", "definition": "核心项目"}
|
||||
entity2 = {"name": "Project Alpha", "type": "PROJECT", "definition": "主要项目"}
|
||||
sim, match_type = linker.calculate_entity_similarity(entity1, entity2)
|
||||
print(f" ✓ 实体相似度计算正常 (相似度: {sim:.2f}, 类型: {match_type})")
|
||||
|
||||
except Exception as e:
|
||||
@@ -90,7 +90,7 @@ except Exception as e:
|
||||
print("\n4. 测试图片处理器功能...")
|
||||
|
||||
try:
|
||||
processor = get_image_processor()
|
||||
processor = get_image_processor()
|
||||
|
||||
# 测试图片类型检测(使用模拟数据)
|
||||
print(f" ✓ 支持的图片类型: {list(processor.IMAGE_TYPES.keys())}")
|
||||
@@ -103,7 +103,7 @@ except Exception as e:
|
||||
print("\n5. 测试视频处理器配置...")
|
||||
|
||||
try:
|
||||
processor = get_multimodal_processor()
|
||||
processor = get_multimodal_processor()
|
||||
|
||||
print(f" ✓ 视频目录: {processor.video_dir}")
|
||||
print(f" ✓ 帧目录: {processor.frames_dir}")
|
||||
@@ -129,11 +129,11 @@ print("\n6. 测试数据库多模态方法...")
|
||||
try:
|
||||
from db_manager import get_db_manager
|
||||
|
||||
db = get_db_manager()
|
||||
db = get_db_manager()
|
||||
|
||||
# 检查多模态表是否存在
|
||||
conn = db.get_conn()
|
||||
tables = ["videos", "video_frames", "images", "multimodal_mentions", "multimodal_entity_links"]
|
||||
conn = db.get_conn()
|
||||
tables = ["videos", "video_frames", "images", "multimodal_mentions", "multimodal_entity_links"]
|
||||
|
||||
for table in tables:
|
||||
try:
|
||||
@@ -147,6 +147,6 @@ try:
|
||||
except Exception as e:
|
||||
print(f" ✗ 数据库多模态方法测试失败: {e}")
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("\n" + " = " * 60)
|
||||
print("测试完成")
|
||||
print("=" * 60)
|
||||
print(" = " * 60)
|
||||
|
||||
@@ -21,27 +21,27 @@ from search_manager import (
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
|
||||
def test_fulltext_search():
|
||||
def test_fulltext_search() -> None:
|
||||
"""测试全文搜索"""
|
||||
print("\n" + "=" * 60)
|
||||
print("\n" + " = " * 60)
|
||||
print("测试全文搜索 (FullTextSearch)")
|
||||
print("=" * 60)
|
||||
print(" = " * 60)
|
||||
|
||||
search = FullTextSearch()
|
||||
search = FullTextSearch()
|
||||
|
||||
# 测试索引创建
|
||||
print("\n1. 测试索引创建...")
|
||||
success = search.index_content(
|
||||
content_id="test_entity_1",
|
||||
content_type="entity",
|
||||
project_id="test_project",
|
||||
text="这是一个测试实体,用于验证全文搜索功能。支持关键词高亮显示。",
|
||||
success = search.index_content(
|
||||
content_id = "test_entity_1",
|
||||
content_type = "entity",
|
||||
project_id = "test_project",
|
||||
text = "这是一个测试实体,用于验证全文搜索功能。支持关键词高亮显示。",
|
||||
)
|
||||
print(f" 索引创建: {'✓ 成功' if success else '✗ 失败'}")
|
||||
|
||||
# 测试搜索
|
||||
print("\n2. 测试关键词搜索...")
|
||||
results = search.search("测试", project_id="test_project")
|
||||
results = search.search("测试", project_id = "test_project")
|
||||
print(f" 搜索结果数量: {len(results)}")
|
||||
if results:
|
||||
print(f" 第一个结果: {results[0].content[:50]}...")
|
||||
@@ -49,28 +49,28 @@ def test_fulltext_search():
|
||||
|
||||
# 测试布尔搜索
|
||||
print("\n3. 测试布尔搜索...")
|
||||
results = search.search("测试 AND 全文", project_id="test_project")
|
||||
results = search.search("测试 AND 全文", project_id = "test_project")
|
||||
print(f" AND 搜索结果: {len(results)}")
|
||||
|
||||
results = search.search("测试 OR 关键词", project_id="test_project")
|
||||
results = search.search("测试 OR 关键词", project_id = "test_project")
|
||||
print(f" OR 搜索结果: {len(results)}")
|
||||
|
||||
# 测试高亮
|
||||
print("\n4. 测试文本高亮...")
|
||||
highlighted = search.highlight_text("这是一个测试实体,用于验证全文搜索功能。", "测试 全文")
|
||||
highlighted = search.highlight_text("这是一个测试实体,用于验证全文搜索功能。", "测试 全文")
|
||||
print(f" 高亮结果: {highlighted}")
|
||||
|
||||
print("\n✓ 全文搜索测试完成")
|
||||
return True
|
||||
|
||||
|
||||
def test_semantic_search():
|
||||
def test_semantic_search() -> None:
|
||||
"""测试语义搜索"""
|
||||
print("\n" + "=" * 60)
|
||||
print("\n" + " = " * 60)
|
||||
print("测试语义搜索 (SemanticSearch)")
|
||||
print("=" * 60)
|
||||
print(" = " * 60)
|
||||
|
||||
semantic = SemanticSearch()
|
||||
semantic = SemanticSearch()
|
||||
|
||||
# 检查可用性
|
||||
print(f"\n1. 语义搜索可用性: {'✓ 可用' if semantic.is_available() else '✗ 不可用'}")
|
||||
@@ -81,18 +81,18 @@ def test_semantic_search():
|
||||
|
||||
# 测试 embedding 生成
|
||||
print("\n2. 测试 embedding 生成...")
|
||||
embedding = semantic.generate_embedding("这是一个测试句子")
|
||||
embedding = semantic.generate_embedding("这是一个测试句子")
|
||||
if embedding:
|
||||
print(f" Embedding 维度: {len(embedding)}")
|
||||
print(f" 前5个值: {embedding[:5]}")
|
||||
|
||||
# 测试索引
|
||||
print("\n3. 测试语义索引...")
|
||||
success = semantic.index_embedding(
|
||||
content_id="test_content_1",
|
||||
content_type="transcript",
|
||||
project_id="test_project",
|
||||
text="这是用于语义搜索测试的文本内容。",
|
||||
success = semantic.index_embedding(
|
||||
content_id = "test_content_1",
|
||||
content_type = "transcript",
|
||||
project_id = "test_project",
|
||||
text = "这是用于语义搜索测试的文本内容。",
|
||||
)
|
||||
print(f" 索引创建: {'✓ 成功' if success else '✗ 失败'}")
|
||||
|
||||
@@ -100,13 +100,13 @@ def test_semantic_search():
|
||||
return True
|
||||
|
||||
|
||||
def test_entity_path_discovery():
|
||||
def test_entity_path_discovery() -> None:
|
||||
"""测试实体路径发现"""
|
||||
print("\n" + "=" * 60)
|
||||
print("\n" + " = " * 60)
|
||||
print("测试实体路径发现 (EntityPathDiscovery)")
|
||||
print("=" * 60)
|
||||
print(" = " * 60)
|
||||
|
||||
discovery = EntityPathDiscovery()
|
||||
discovery = EntityPathDiscovery()
|
||||
|
||||
print("\n1. 测试路径发现初始化...")
|
||||
print(f" 数据库路径: {discovery.db_path}")
|
||||
@@ -119,13 +119,13 @@ def test_entity_path_discovery():
|
||||
return True
|
||||
|
||||
|
||||
def test_knowledge_gap_detection():
|
||||
def test_knowledge_gap_detection() -> None:
|
||||
"""测试知识缺口识别"""
|
||||
print("\n" + "=" * 60)
|
||||
print("\n" + " = " * 60)
|
||||
print("测试知识缺口识别 (KnowledgeGapDetection)")
|
||||
print("=" * 60)
|
||||
print(" = " * 60)
|
||||
|
||||
detection = KnowledgeGapDetection()
|
||||
detection = KnowledgeGapDetection()
|
||||
|
||||
print("\n1. 测试缺口检测初始化...")
|
||||
print(f" 数据库路径: {detection.db_path}")
|
||||
@@ -138,32 +138,32 @@ def test_knowledge_gap_detection():
|
||||
return True
|
||||
|
||||
|
||||
def test_cache_manager():
|
||||
def test_cache_manager() -> None:
|
||||
"""测试缓存管理器"""
|
||||
print("\n" + "=" * 60)
|
||||
print("\n" + " = " * 60)
|
||||
print("测试缓存管理器 (CacheManager)")
|
||||
print("=" * 60)
|
||||
print(" = " * 60)
|
||||
|
||||
cache = CacheManager()
|
||||
cache = CacheManager()
|
||||
|
||||
print(f"\n1. 缓存后端: {'Redis' if cache.use_redis else '内存 LRU'}")
|
||||
|
||||
print("\n2. 测试缓存操作...")
|
||||
# 设置缓存
|
||||
cache.set("test_key_1", {"name": "测试数据", "value": 123}, ttl=60)
|
||||
cache.set("test_key_1", {"name": "测试数据", "value": 123}, ttl = 60)
|
||||
print(" ✓ 设置缓存 test_key_1")
|
||||
|
||||
# 获取缓存
|
||||
_ = cache.get("test_key_1")
|
||||
_ = cache.get("test_key_1")
|
||||
print(" ✓ 获取缓存: {value}")
|
||||
|
||||
# 批量操作
|
||||
cache.set_many(
|
||||
{"batch_key_1": "value1", "batch_key_2": "value2", "batch_key_3": "value3"}, ttl=60
|
||||
{"batch_key_1": "value1", "batch_key_2": "value2", "batch_key_3": "value3"}, ttl = 60
|
||||
)
|
||||
print(" ✓ 批量设置缓存")
|
||||
|
||||
_ = cache.get_many(["batch_key_1", "batch_key_2", "batch_key_3"])
|
||||
_ = cache.get_many(["batch_key_1", "batch_key_2", "batch_key_3"])
|
||||
print(" ✓ 批量获取缓存: {len(values)} 个")
|
||||
|
||||
# 删除缓存
|
||||
@@ -171,7 +171,7 @@ def test_cache_manager():
|
||||
print(" ✓ 删除缓存 test_key_1")
|
||||
|
||||
# 获取统计
|
||||
stats = cache.get_stats()
|
||||
stats = cache.get_stats()
|
||||
print("\n3. 缓存统计:")
|
||||
print(f" 总请求数: {stats['total_requests']}")
|
||||
print(f" 命中数: {stats['hits']}")
|
||||
@@ -186,13 +186,13 @@ def test_cache_manager():
|
||||
return True
|
||||
|
||||
|
||||
def test_task_queue():
|
||||
def test_task_queue() -> None:
|
||||
"""测试任务队列"""
|
||||
print("\n" + "=" * 60)
|
||||
print("\n" + " = " * 60)
|
||||
print("测试任务队列 (TaskQueue)")
|
||||
print("=" * 60)
|
||||
print(" = " * 60)
|
||||
|
||||
queue = TaskQueue()
|
||||
queue = TaskQueue()
|
||||
|
||||
print(f"\n1. 任务队列可用性: {'✓ 可用' if queue.is_available() else '✗ 不可用'}")
|
||||
print(f" 后端: {'Celery' if queue.use_celery else '内存'}")
|
||||
@@ -200,25 +200,25 @@ def test_task_queue():
|
||||
print("\n2. 测试任务提交...")
|
||||
|
||||
# 定义测试任务处理器
|
||||
def test_task_handler(payload):
|
||||
def test_task_handler(payload) -> None:
|
||||
print(f" 执行任务: {payload}")
|
||||
return {"status": "success", "processed": True}
|
||||
|
||||
queue.register_handler("test_task", test_task_handler)
|
||||
|
||||
# 提交任务
|
||||
task_id = queue.submit(
|
||||
task_type="test_task", payload={"test": "data", "timestamp": time.time()}
|
||||
task_id = queue.submit(
|
||||
task_type = "test_task", payload = {"test": "data", "timestamp": time.time()}
|
||||
)
|
||||
print(" ✓ 提交任务: {task_id}")
|
||||
|
||||
# 获取任务状态
|
||||
task_info = queue.get_status(task_id)
|
||||
task_info = queue.get_status(task_id)
|
||||
if task_info:
|
||||
print(" ✓ 任务状态: {task_info.status}")
|
||||
|
||||
# 获取统计
|
||||
stats = queue.get_stats()
|
||||
stats = queue.get_stats()
|
||||
print("\n3. 任务队列统计:")
|
||||
print(f" 后端: {stats['backend']}")
|
||||
print(f" 按状态统计: {stats.get('by_status', {})}")
|
||||
@@ -227,38 +227,38 @@ def test_task_queue():
|
||||
return True
|
||||
|
||||
|
||||
def test_performance_monitor():
|
||||
def test_performance_monitor() -> None:
|
||||
"""测试性能监控"""
|
||||
print("\n" + "=" * 60)
|
||||
print("\n" + " = " * 60)
|
||||
print("测试性能监控 (PerformanceMonitor)")
|
||||
print("=" * 60)
|
||||
print(" = " * 60)
|
||||
|
||||
monitor = PerformanceMonitor()
|
||||
monitor = PerformanceMonitor()
|
||||
|
||||
print("\n1. 测试指标记录...")
|
||||
|
||||
# 记录一些测试指标
|
||||
for i in range(5):
|
||||
monitor.record_metric(
|
||||
metric_type="api_response",
|
||||
duration_ms=50 + i * 10,
|
||||
endpoint="/api/v1/test",
|
||||
metadata={"test": True},
|
||||
metric_type = "api_response",
|
||||
duration_ms = 50 + i * 10,
|
||||
endpoint = "/api/v1/test",
|
||||
metadata = {"test": True},
|
||||
)
|
||||
|
||||
for i in range(3):
|
||||
monitor.record_metric(
|
||||
metric_type="db_query",
|
||||
duration_ms=20 + i * 5,
|
||||
endpoint="SELECT test",
|
||||
metadata={"test": True},
|
||||
metric_type = "db_query",
|
||||
duration_ms = 20 + i * 5,
|
||||
endpoint = "SELECT test",
|
||||
metadata = {"test": True},
|
||||
)
|
||||
|
||||
print(" ✓ 记录了 8 个测试指标")
|
||||
|
||||
# 获取统计
|
||||
print("\n2. 获取性能统计...")
|
||||
stats = monitor.get_stats(hours=1)
|
||||
stats = monitor.get_stats(hours = 1)
|
||||
print(f" 总请求数: {stats['overall']['total_requests']}")
|
||||
print(f" 平均响应时间: {stats['overall']['avg_duration_ms']} ms")
|
||||
print(f" 最大响应时间: {stats['overall']['max_duration_ms']} ms")
|
||||
@@ -274,19 +274,19 @@ def test_performance_monitor():
|
||||
return True
|
||||
|
||||
|
||||
def test_search_manager():
|
||||
def test_search_manager() -> None:
|
||||
"""测试搜索管理器"""
|
||||
print("\n" + "=" * 60)
|
||||
print("\n" + " = " * 60)
|
||||
print("测试搜索管理器 (SearchManager)")
|
||||
print("=" * 60)
|
||||
print(" = " * 60)
|
||||
|
||||
manager = get_search_manager()
|
||||
manager = get_search_manager()
|
||||
|
||||
print("\n1. 搜索管理器初始化...")
|
||||
print(" ✓ 搜索管理器已初始化")
|
||||
|
||||
print("\n2. 获取搜索统计...")
|
||||
stats = manager.get_search_stats()
|
||||
stats = manager.get_search_stats()
|
||||
print(f" 全文索引数: {stats['fulltext_indexed']}")
|
||||
print(f" 语义索引数: {stats['semantic_indexed']}")
|
||||
print(f" 语义搜索可用: {stats['semantic_search_available']}")
|
||||
@@ -295,24 +295,24 @@ def test_search_manager():
|
||||
return True
|
||||
|
||||
|
||||
def test_performance_manager():
|
||||
def test_performance_manager() -> None:
|
||||
"""测试性能管理器"""
|
||||
print("\n" + "=" * 60)
|
||||
print("\n" + " = " * 60)
|
||||
print("测试性能管理器 (PerformanceManager)")
|
||||
print("=" * 60)
|
||||
print(" = " * 60)
|
||||
|
||||
manager = get_performance_manager()
|
||||
manager = get_performance_manager()
|
||||
|
||||
print("\n1. 性能管理器初始化...")
|
||||
print(" ✓ 性能管理器已初始化")
|
||||
|
||||
print("\n2. 获取系统健康状态...")
|
||||
health = manager.get_health_status()
|
||||
health = manager.get_health_status()
|
||||
print(f" 缓存后端: {health['cache']['backend']}")
|
||||
print(f" 任务队列后端: {health['task_queue']['backend']}")
|
||||
|
||||
print("\n3. 获取完整统计...")
|
||||
stats = manager.get_full_stats()
|
||||
stats = manager.get_full_stats()
|
||||
print(f" 缓存统计: {stats['cache']['total_requests']} 请求")
|
||||
print(f" 任务队列统计: {stats['task_queue']}")
|
||||
|
||||
@@ -320,14 +320,14 @@ def test_performance_manager():
|
||||
return True
|
||||
|
||||
|
||||
def run_all_tests():
|
||||
def run_all_tests() -> None:
|
||||
"""运行所有测试"""
|
||||
print("\n" + "=" * 60)
|
||||
print("\n" + " = " * 60)
|
||||
print("InsightFlow Phase 7 Task 6 & 8 测试")
|
||||
print("高级搜索与发现 + 性能优化与扩展")
|
||||
print("=" * 60)
|
||||
print(" = " * 60)
|
||||
|
||||
results = []
|
||||
results = []
|
||||
|
||||
# 搜索模块测试
|
||||
try:
|
||||
@@ -386,15 +386,15 @@ def run_all_tests():
|
||||
results.append(("性能管理器", False))
|
||||
|
||||
# 打印测试汇总
|
||||
print("\n" + "=" * 60)
|
||||
print("\n" + " = " * 60)
|
||||
print("测试汇总")
|
||||
print("=" * 60)
|
||||
print(" = " * 60)
|
||||
|
||||
passed = sum(1 for _, result in results if result)
|
||||
total = len(results)
|
||||
passed = sum(1 for _, result in results if result)
|
||||
total = len(results)
|
||||
|
||||
for name, result in results:
|
||||
status = "✓ 通过" if result else "✗ 失败"
|
||||
status = "✓ 通过" if result else "✗ 失败"
|
||||
print(f" {status} - {name}")
|
||||
|
||||
print(f"\n总计: {passed}/{total} 测试通过")
|
||||
@@ -408,5 +408,5 @@ def run_all_tests():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
success = run_all_tests()
|
||||
success = run_all_tests()
|
||||
sys.exit(0 if success else 1)
|
||||
|
||||
@@ -18,18 +18,18 @@ from tenant_manager import get_tenant_manager
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
|
||||
def test_tenant_management():
|
||||
def test_tenant_management() -> None:
|
||||
"""测试租户管理功能"""
|
||||
print("=" * 60)
|
||||
print(" = " * 60)
|
||||
print("测试 1: 租户管理")
|
||||
print("=" * 60)
|
||||
print(" = " * 60)
|
||||
|
||||
manager = get_tenant_manager()
|
||||
manager = get_tenant_manager()
|
||||
|
||||
# 1. 创建租户
|
||||
print("\n1.1 创建租户...")
|
||||
tenant = manager.create_tenant(
|
||||
name="Test Company", owner_id="user_001", tier="pro", description="A test company tenant"
|
||||
tenant = manager.create_tenant(
|
||||
name = "Test Company", owner_id = "user_001", tier = "pro", description = "A test company tenant"
|
||||
)
|
||||
print(f"✅ 租户创建成功: {tenant.id}")
|
||||
print(f" - 名称: {tenant.name}")
|
||||
@@ -40,43 +40,43 @@ def test_tenant_management():
|
||||
|
||||
# 2. 获取租户
|
||||
print("\n1.2 获取租户信息...")
|
||||
fetched = manager.get_tenant(tenant.id)
|
||||
fetched = manager.get_tenant(tenant.id)
|
||||
assert fetched is not None, "获取租户失败"
|
||||
print(f"✅ 获取租户成功: {fetched.name}")
|
||||
|
||||
# 3. 通过 slug 获取
|
||||
print("\n1.3 通过 slug 获取租户...")
|
||||
by_slug = manager.get_tenant_by_slug(tenant.slug)
|
||||
by_slug = manager.get_tenant_by_slug(tenant.slug)
|
||||
assert by_slug is not None, "通过 slug 获取失败"
|
||||
print(f"✅ 通过 slug 获取成功: {by_slug.name}")
|
||||
|
||||
# 4. 更新租户
|
||||
print("\n1.4 更新租户信息...")
|
||||
updated = manager.update_tenant(
|
||||
tenant_id=tenant.id, name="Test Company Updated", tier="enterprise"
|
||||
updated = manager.update_tenant(
|
||||
tenant_id = tenant.id, name = "Test Company Updated", tier = "enterprise"
|
||||
)
|
||||
assert updated is not None, "更新租户失败"
|
||||
print(f"✅ 租户更新成功: {updated.name}, 层级: {updated.tier}")
|
||||
|
||||
# 5. 列出租户
|
||||
print("\n1.5 列出租户...")
|
||||
tenants = manager.list_tenants(limit=10)
|
||||
tenants = manager.list_tenants(limit = 10)
|
||||
print(f"✅ 找到 {len(tenants)} 个租户")
|
||||
|
||||
return tenant.id
|
||||
|
||||
|
||||
def test_domain_management(tenant_id: str):
|
||||
def test_domain_management(tenant_id: str) -> None:
|
||||
"""测试域名管理功能"""
|
||||
print("\n" + "=" * 60)
|
||||
print("\n" + " = " * 60)
|
||||
print("测试 2: 域名管理")
|
||||
print("=" * 60)
|
||||
print(" = " * 60)
|
||||
|
||||
manager = get_tenant_manager()
|
||||
manager = get_tenant_manager()
|
||||
|
||||
# 1. 添加域名
|
||||
print("\n2.1 添加自定义域名...")
|
||||
domain = manager.add_domain(tenant_id=tenant_id, domain="test.example.com", is_primary=True)
|
||||
domain = manager.add_domain(tenant_id = tenant_id, domain = "test.example.com", is_primary = True)
|
||||
print(f"✅ 域名添加成功: {domain.domain}")
|
||||
print(f" - ID: {domain.id}")
|
||||
print(f" - 状态: {domain.status}")
|
||||
@@ -84,19 +84,19 @@ def test_domain_management(tenant_id: str):
|
||||
|
||||
# 2. 获取验证指导
|
||||
print("\n2.2 获取域名验证指导...")
|
||||
instructions = manager.get_domain_verification_instructions(domain.id)
|
||||
instructions = manager.get_domain_verification_instructions(domain.id)
|
||||
print("✅ 验证指导:")
|
||||
print(f" - DNS 记录: {instructions['dns_record']}")
|
||||
print(f" - 文件验证: {instructions['file_verification']}")
|
||||
|
||||
# 3. 验证域名
|
||||
print("\n2.3 验证域名...")
|
||||
verified = manager.verify_domain(tenant_id, domain.id)
|
||||
verified = manager.verify_domain(tenant_id, domain.id)
|
||||
print(f"✅ 域名验证结果: {verified}")
|
||||
|
||||
# 4. 通过域名获取租户
|
||||
print("\n2.4 通过域名获取租户...")
|
||||
by_domain = manager.get_tenant_by_domain("test.example.com")
|
||||
by_domain = manager.get_tenant_by_domain("test.example.com")
|
||||
if by_domain:
|
||||
print(f"✅ 通过域名获取租户成功: {by_domain.name}")
|
||||
else:
|
||||
@@ -104,7 +104,7 @@ def test_domain_management(tenant_id: str):
|
||||
|
||||
# 5. 列出域名
|
||||
print("\n2.5 列出所有域名...")
|
||||
domains = manager.list_domains(tenant_id)
|
||||
domains = manager.list_domains(tenant_id)
|
||||
print(f"✅ 找到 {len(domains)} 个域名")
|
||||
for d in domains:
|
||||
print(f" - {d.domain} ({d.status})")
|
||||
@@ -112,25 +112,25 @@ def test_domain_management(tenant_id: str):
|
||||
return domain.id
|
||||
|
||||
|
||||
def test_branding_management(tenant_id: str):
|
||||
def test_branding_management(tenant_id: str) -> None:
|
||||
"""测试品牌白标功能"""
|
||||
print("\n" + "=" * 60)
|
||||
print("\n" + " = " * 60)
|
||||
print("测试 3: 品牌白标")
|
||||
print("=" * 60)
|
||||
print(" = " * 60)
|
||||
|
||||
manager = get_tenant_manager()
|
||||
manager = get_tenant_manager()
|
||||
|
||||
# 1. 更新品牌配置
|
||||
print("\n3.1 更新品牌配置...")
|
||||
branding = manager.update_branding(
|
||||
tenant_id=tenant_id,
|
||||
logo_url="https://example.com/logo.png",
|
||||
favicon_url="https://example.com/favicon.ico",
|
||||
primary_color="#1890ff",
|
||||
secondary_color="#52c41a",
|
||||
custom_css=".header { background: #1890ff; }",
|
||||
custom_js="console.log('Custom JS loaded');",
|
||||
login_page_bg="https://example.com/bg.jpg",
|
||||
branding = manager.update_branding(
|
||||
tenant_id = tenant_id,
|
||||
logo_url = "https://example.com/logo.png",
|
||||
favicon_url = "https://example.com/favicon.ico",
|
||||
primary_color = "#1890ff",
|
||||
secondary_color = "#52c41a",
|
||||
custom_css = ".header { background: #1890ff; }",
|
||||
custom_js = "console.log('Custom JS loaded');",
|
||||
login_page_bg = "https://example.com/bg.jpg",
|
||||
)
|
||||
print("✅ 品牌配置更新成功")
|
||||
print(f" - Logo: {branding.logo_url}")
|
||||
@@ -139,67 +139,67 @@ def test_branding_management(tenant_id: str):
|
||||
|
||||
# 2. 获取品牌配置
|
||||
print("\n3.2 获取品牌配置...")
|
||||
fetched = manager.get_branding(tenant_id)
|
||||
fetched = manager.get_branding(tenant_id)
|
||||
assert fetched is not None, "获取品牌配置失败"
|
||||
print("✅ 获取品牌配置成功")
|
||||
|
||||
# 3. 生成品牌 CSS
|
||||
print("\n3.3 生成品牌 CSS...")
|
||||
css = manager.get_branding_css(tenant_id)
|
||||
css = manager.get_branding_css(tenant_id)
|
||||
print(f"✅ 生成 CSS 成功 ({len(css)} 字符)")
|
||||
print(f" CSS 预览:\n{css[:200]}...")
|
||||
|
||||
return branding.id
|
||||
|
||||
|
||||
def test_member_management(tenant_id: str):
|
||||
def test_member_management(tenant_id: str) -> None:
|
||||
"""测试成员管理功能"""
|
||||
print("\n" + "=" * 60)
|
||||
print("\n" + " = " * 60)
|
||||
print("测试 4: 成员管理")
|
||||
print("=" * 60)
|
||||
print(" = " * 60)
|
||||
|
||||
manager = get_tenant_manager()
|
||||
manager = get_tenant_manager()
|
||||
|
||||
# 1. 邀请成员
|
||||
print("\n4.1 邀请成员...")
|
||||
member1 = manager.invite_member(
|
||||
tenant_id=tenant_id, email="admin@test.com", role="admin", invited_by="user_001"
|
||||
member1 = manager.invite_member(
|
||||
tenant_id = tenant_id, email = "admin@test.com", role = "admin", invited_by = "user_001"
|
||||
)
|
||||
print(f"✅ 成员邀请成功: {member1.email}")
|
||||
print(f" - ID: {member1.id}")
|
||||
print(f" - 角色: {member1.role}")
|
||||
print(f" - 权限: {member1.permissions}")
|
||||
|
||||
member2 = manager.invite_member(
|
||||
tenant_id=tenant_id, email="member@test.com", role="member", invited_by="user_001"
|
||||
member2 = manager.invite_member(
|
||||
tenant_id = tenant_id, email = "member@test.com", role = "member", invited_by = "user_001"
|
||||
)
|
||||
print(f"✅ 成员邀请成功: {member2.email}")
|
||||
|
||||
# 2. 接受邀请
|
||||
print("\n4.2 接受邀请...")
|
||||
accepted = manager.accept_invitation(member1.id, "user_002")
|
||||
accepted = manager.accept_invitation(member1.id, "user_002")
|
||||
print(f"✅ 邀请接受结果: {accepted}")
|
||||
|
||||
# 3. 列出成员
|
||||
print("\n4.3 列出所有成员...")
|
||||
members = manager.list_members(tenant_id)
|
||||
members = manager.list_members(tenant_id)
|
||||
print(f"✅ 找到 {len(members)} 个成员")
|
||||
for m in members:
|
||||
print(f" - {m.email} ({m.role}) - {m.status}")
|
||||
|
||||
# 4. 检查权限
|
||||
print("\n4.4 检查权限...")
|
||||
can_manage = manager.check_permission(tenant_id, "user_002", "project", "create")
|
||||
can_manage = manager.check_permission(tenant_id, "user_002", "project", "create")
|
||||
print(f"✅ user_002 可以创建项目: {can_manage}")
|
||||
|
||||
# 5. 更新成员角色
|
||||
print("\n4.5 更新成员角色...")
|
||||
updated = manager.update_member_role(tenant_id, member2.id, "viewer")
|
||||
updated = manager.update_member_role(tenant_id, member2.id, "viewer")
|
||||
print(f"✅ 角色更新结果: {updated}")
|
||||
|
||||
# 6. 获取用户所属租户
|
||||
print("\n4.6 获取用户所属租户...")
|
||||
user_tenants = manager.get_user_tenants("user_002")
|
||||
user_tenants = manager.get_user_tenants("user_002")
|
||||
print(f"✅ user_002 属于 {len(user_tenants)} 个租户")
|
||||
for t in user_tenants:
|
||||
print(f" - {t['name']} ({t['member_role']})")
|
||||
@@ -207,30 +207,30 @@ def test_member_management(tenant_id: str):
|
||||
return member1.id, member2.id
|
||||
|
||||
|
||||
def test_usage_tracking(tenant_id: str):
|
||||
def test_usage_tracking(tenant_id: str) -> None:
|
||||
"""测试资源使用统计功能"""
|
||||
print("\n" + "=" * 60)
|
||||
print("\n" + " = " * 60)
|
||||
print("测试 5: 资源使用统计")
|
||||
print("=" * 60)
|
||||
print(" = " * 60)
|
||||
|
||||
manager = get_tenant_manager()
|
||||
manager = get_tenant_manager()
|
||||
|
||||
# 1. 记录使用
|
||||
print("\n5.1 记录资源使用...")
|
||||
manager.record_usage(
|
||||
tenant_id=tenant_id,
|
||||
storage_bytes=1024 * 1024 * 50, # 50MB
|
||||
transcription_seconds=600, # 10分钟
|
||||
api_calls=100,
|
||||
projects_count=5,
|
||||
entities_count=50,
|
||||
members_count=3,
|
||||
tenant_id = tenant_id,
|
||||
storage_bytes = 1024 * 1024 * 50, # 50MB
|
||||
transcription_seconds = 600, # 10分钟
|
||||
api_calls = 100,
|
||||
projects_count = 5,
|
||||
entities_count = 50,
|
||||
members_count = 3,
|
||||
)
|
||||
print("✅ 资源使用记录成功")
|
||||
|
||||
# 2. 获取使用统计
|
||||
print("\n5.2 获取使用统计...")
|
||||
stats = manager.get_usage_stats(tenant_id)
|
||||
stats = manager.get_usage_stats(tenant_id)
|
||||
print("✅ 使用统计:")
|
||||
print(f" - 存储: {stats['storage_mb']:.2f} MB")
|
||||
print(f" - 转录: {stats['transcription_minutes']:.2f} 分钟")
|
||||
@@ -243,19 +243,19 @@ def test_usage_tracking(tenant_id: str):
|
||||
# 3. 检查资源限制
|
||||
print("\n5.3 检查资源限制...")
|
||||
for resource in ["storage", "transcription", "api_calls", "projects", "entities", "members"]:
|
||||
allowed, current, limit = manager.check_resource_limit(tenant_id, resource)
|
||||
allowed, current, limit = manager.check_resource_limit(tenant_id, resource)
|
||||
print(f" - {resource}: {current}/{limit} ({'✅' if allowed else '❌'})")
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
def cleanup(tenant_id: str, domain_id: str, member_ids: list):
|
||||
def cleanup(tenant_id: str, domain_id: str, member_ids: list) -> None:
|
||||
"""清理测试数据"""
|
||||
print("\n" + "=" * 60)
|
||||
print("\n" + " = " * 60)
|
||||
print("清理测试数据")
|
||||
print("=" * 60)
|
||||
print(" = " * 60)
|
||||
|
||||
manager = get_tenant_manager()
|
||||
manager = get_tenant_manager()
|
||||
|
||||
# 移除成员
|
||||
for member_id in member_ids:
|
||||
@@ -273,28 +273,28 @@ def cleanup(tenant_id: str, domain_id: str, member_ids: list):
|
||||
print(f"✅ 租户已删除: {tenant_id}")
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> None:
|
||||
"""主测试函数"""
|
||||
print("\n" + "=" * 60)
|
||||
print("\n" + " = " * 60)
|
||||
print("InsightFlow Phase 8 Task 1 - 多租户 SaaS 架构测试")
|
||||
print("=" * 60)
|
||||
print(" = " * 60)
|
||||
|
||||
tenant_id = None
|
||||
domain_id = None
|
||||
member_ids = []
|
||||
tenant_id = None
|
||||
domain_id = None
|
||||
member_ids = []
|
||||
|
||||
try:
|
||||
# 运行所有测试
|
||||
tenant_id = test_tenant_management()
|
||||
domain_id = test_domain_management(tenant_id)
|
||||
tenant_id = test_tenant_management()
|
||||
domain_id = test_domain_management(tenant_id)
|
||||
test_branding_management(tenant_id)
|
||||
m1, m2 = test_member_management(tenant_id)
|
||||
member_ids = [m1, m2]
|
||||
m1, m2 = test_member_management(tenant_id)
|
||||
member_ids = [m1, m2]
|
||||
test_usage_tracking(tenant_id)
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("\n" + " = " * 60)
|
||||
print("✅ 所有测试通过!")
|
||||
print("=" * 60)
|
||||
print(" = " * 60)
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n❌ 测试失败: {e}")
|
||||
|
||||
@@ -12,31 +12,31 @@ from subscription_manager import PaymentProvider, SubscriptionManager
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
|
||||
def test_subscription_manager():
|
||||
def test_subscription_manager() -> None:
|
||||
"""测试订阅管理器"""
|
||||
print("=" * 60)
|
||||
print(" = " * 60)
|
||||
print("InsightFlow Phase 8 Task 2 - 订阅与计费系统测试")
|
||||
print("=" * 60)
|
||||
print(" = " * 60)
|
||||
|
||||
# 使用临时文件数据库进行测试
|
||||
db_path = tempfile.mktemp(suffix=".db")
|
||||
db_path = tempfile.mktemp(suffix = ".db")
|
||||
|
||||
try:
|
||||
manager = SubscriptionManager(db_path=db_path)
|
||||
manager = SubscriptionManager(db_path = db_path)
|
||||
|
||||
print("\n1. 测试订阅计划管理")
|
||||
print("-" * 40)
|
||||
|
||||
# 获取默认计划
|
||||
plans = manager.list_plans()
|
||||
plans = manager.list_plans()
|
||||
print(f"✓ 默认计划数量: {len(plans)}")
|
||||
for plan in plans:
|
||||
print(f" - {plan.name} ({plan.tier}): ¥{plan.price_monthly}/月")
|
||||
|
||||
# 通过 tier 获取计划
|
||||
free_plan = manager.get_plan_by_tier("free")
|
||||
pro_plan = manager.get_plan_by_tier("pro")
|
||||
enterprise_plan = manager.get_plan_by_tier("enterprise")
|
||||
free_plan = manager.get_plan_by_tier("free")
|
||||
pro_plan = manager.get_plan_by_tier("pro")
|
||||
enterprise_plan = manager.get_plan_by_tier("enterprise")
|
||||
|
||||
assert free_plan is not None, "Free 计划应该存在"
|
||||
assert pro_plan is not None, "Pro 计划应该存在"
|
||||
@@ -49,14 +49,14 @@ def test_subscription_manager():
|
||||
print("\n2. 测试订阅管理")
|
||||
print("-" * 40)
|
||||
|
||||
tenant_id = "test-tenant-001"
|
||||
tenant_id = "test-tenant-001"
|
||||
|
||||
# 创建订阅
|
||||
subscription = manager.create_subscription(
|
||||
tenant_id=tenant_id,
|
||||
plan_id=pro_plan.id,
|
||||
payment_provider=PaymentProvider.STRIPE.value,
|
||||
trial_days=14,
|
||||
subscription = manager.create_subscription(
|
||||
tenant_id = tenant_id,
|
||||
plan_id = pro_plan.id,
|
||||
payment_provider = PaymentProvider.STRIPE.value,
|
||||
trial_days = 14,
|
||||
)
|
||||
|
||||
print(f"✓ 创建订阅: {subscription.id}")
|
||||
@@ -66,7 +66,7 @@ def test_subscription_manager():
|
||||
print(f" - 试用结束: {subscription.trial_end}")
|
||||
|
||||
# 获取租户订阅
|
||||
tenant_sub = manager.get_tenant_subscription(tenant_id)
|
||||
tenant_sub = manager.get_tenant_subscription(tenant_id)
|
||||
assert tenant_sub is not None, "应该能获取到租户订阅"
|
||||
print(f"✓ 获取租户订阅: {tenant_sub.id}")
|
||||
|
||||
@@ -74,27 +74,27 @@ def test_subscription_manager():
|
||||
print("-" * 40)
|
||||
|
||||
# 记录转录用量
|
||||
usage1 = manager.record_usage(
|
||||
tenant_id=tenant_id,
|
||||
resource_type="transcription",
|
||||
quantity=120,
|
||||
unit="minute",
|
||||
description="会议转录",
|
||||
usage1 = manager.record_usage(
|
||||
tenant_id = tenant_id,
|
||||
resource_type = "transcription",
|
||||
quantity = 120,
|
||||
unit = "minute",
|
||||
description = "会议转录",
|
||||
)
|
||||
print(f"✓ 记录转录用量: {usage1.quantity} {usage1.unit}, 费用: ¥{usage1.cost:.2f}")
|
||||
|
||||
# 记录存储用量
|
||||
usage2 = manager.record_usage(
|
||||
tenant_id=tenant_id,
|
||||
resource_type="storage",
|
||||
quantity=2.5,
|
||||
unit="gb",
|
||||
description="文件存储",
|
||||
usage2 = manager.record_usage(
|
||||
tenant_id = tenant_id,
|
||||
resource_type = "storage",
|
||||
quantity = 2.5,
|
||||
unit = "gb",
|
||||
description = "文件存储",
|
||||
)
|
||||
print(f"✓ 记录存储用量: {usage2.quantity} {usage2.unit}, 费用: ¥{usage2.cost:.2f}")
|
||||
|
||||
# 获取用量汇总
|
||||
summary = manager.get_usage_summary(tenant_id)
|
||||
summary = manager.get_usage_summary(tenant_id)
|
||||
print("✓ 用量汇总:")
|
||||
print(f" - 总费用: ¥{summary['total_cost']:.2f}")
|
||||
for resource, data in summary["breakdown"].items():
|
||||
@@ -104,12 +104,12 @@ def test_subscription_manager():
|
||||
print("-" * 40)
|
||||
|
||||
# 创建支付
|
||||
payment = manager.create_payment(
|
||||
tenant_id=tenant_id,
|
||||
amount=99.0,
|
||||
currency="CNY",
|
||||
provider=PaymentProvider.ALIPAY.value,
|
||||
payment_method="qrcode",
|
||||
payment = manager.create_payment(
|
||||
tenant_id = tenant_id,
|
||||
amount = 99.0,
|
||||
currency = "CNY",
|
||||
provider = PaymentProvider.ALIPAY.value,
|
||||
payment_method = "qrcode",
|
||||
)
|
||||
print(f"✓ 创建支付: {payment.id}")
|
||||
print(f" - 金额: ¥{payment.amount}")
|
||||
@@ -117,22 +117,22 @@ def test_subscription_manager():
|
||||
print(f" - 状态: {payment.status}")
|
||||
|
||||
# 确认支付
|
||||
confirmed = manager.confirm_payment(payment.id, "alipay_123456")
|
||||
confirmed = manager.confirm_payment(payment.id, "alipay_123456")
|
||||
print(f"✓ 确认支付完成: {confirmed.status}")
|
||||
|
||||
# 列出支付记录
|
||||
payments = manager.list_payments(tenant_id)
|
||||
payments = manager.list_payments(tenant_id)
|
||||
print(f"✓ 支付记录数量: {len(payments)}")
|
||||
|
||||
print("\n5. 测试发票管理")
|
||||
print("-" * 40)
|
||||
|
||||
# 列出发票
|
||||
invoices = manager.list_invoices(tenant_id)
|
||||
invoices = manager.list_invoices(tenant_id)
|
||||
print(f"✓ 发票数量: {len(invoices)}")
|
||||
|
||||
if invoices:
|
||||
invoice = invoices[0]
|
||||
invoice = invoices[0]
|
||||
print(f" - 发票号: {invoice.invoice_number}")
|
||||
print(f" - 金额: ¥{invoice.amount_due}")
|
||||
print(f" - 状态: {invoice.status}")
|
||||
@@ -141,12 +141,12 @@ def test_subscription_manager():
|
||||
print("-" * 40)
|
||||
|
||||
# 申请退款
|
||||
refund = manager.request_refund(
|
||||
tenant_id=tenant_id,
|
||||
payment_id=payment.id,
|
||||
amount=50.0,
|
||||
reason="服务不满意",
|
||||
requested_by="user_001",
|
||||
refund = manager.request_refund(
|
||||
tenant_id = tenant_id,
|
||||
payment_id = payment.id,
|
||||
amount = 50.0,
|
||||
reason = "服务不满意",
|
||||
requested_by = "user_001",
|
||||
)
|
||||
print(f"✓ 申请退款: {refund.id}")
|
||||
print(f" - 金额: ¥{refund.amount}")
|
||||
@@ -154,21 +154,21 @@ def test_subscription_manager():
|
||||
print(f" - 状态: {refund.status}")
|
||||
|
||||
# 批准退款
|
||||
approved = manager.approve_refund(refund.id, "admin_001")
|
||||
approved = manager.approve_refund(refund.id, "admin_001")
|
||||
print(f"✓ 批准退款: {approved.status}")
|
||||
|
||||
# 完成退款
|
||||
completed = manager.complete_refund(refund.id, "refund_123456")
|
||||
completed = manager.complete_refund(refund.id, "refund_123456")
|
||||
print(f"✓ 完成退款: {completed.status}")
|
||||
|
||||
# 列出退款记录
|
||||
refunds = manager.list_refunds(tenant_id)
|
||||
refunds = manager.list_refunds(tenant_id)
|
||||
print(f"✓ 退款记录数量: {len(refunds)}")
|
||||
|
||||
print("\n7. 测试账单历史")
|
||||
print("-" * 40)
|
||||
|
||||
history = manager.get_billing_history(tenant_id)
|
||||
history = manager.get_billing_history(tenant_id)
|
||||
print(f"✓ 账单历史记录数量: {len(history)}")
|
||||
for h in history:
|
||||
print(f" - [{h.type}] {h.description}: ¥{h.amount}")
|
||||
@@ -177,24 +177,24 @@ def test_subscription_manager():
|
||||
print("-" * 40)
|
||||
|
||||
# Stripe Checkout
|
||||
stripe_session = manager.create_stripe_checkout_session(
|
||||
tenant_id=tenant_id,
|
||||
plan_id=enterprise_plan.id,
|
||||
success_url="https://example.com/success",
|
||||
cancel_url="https://example.com/cancel",
|
||||
stripe_session = manager.create_stripe_checkout_session(
|
||||
tenant_id = tenant_id,
|
||||
plan_id = enterprise_plan.id,
|
||||
success_url = "https://example.com/success",
|
||||
cancel_url = "https://example.com/cancel",
|
||||
)
|
||||
print(f"✓ Stripe Checkout 会话: {stripe_session['session_id']}")
|
||||
|
||||
# 支付宝订单
|
||||
alipay_order = manager.create_alipay_order(tenant_id=tenant_id, plan_id=pro_plan.id)
|
||||
alipay_order = manager.create_alipay_order(tenant_id = tenant_id, plan_id = pro_plan.id)
|
||||
print(f"✓ 支付宝订单: {alipay_order['order_id']}")
|
||||
|
||||
# 微信支付订单
|
||||
wechat_order = manager.create_wechat_order(tenant_id=tenant_id, plan_id=pro_plan.id)
|
||||
wechat_order = manager.create_wechat_order(tenant_id = tenant_id, plan_id = pro_plan.id)
|
||||
print(f"✓ 微信支付订单: {wechat_order['order_id']}")
|
||||
|
||||
# Webhook 处理
|
||||
webhook_result = manager.handle_webhook(
|
||||
webhook_result = manager.handle_webhook(
|
||||
"stripe",
|
||||
{"event_type": "checkout.session.completed", "data": {"object": {"id": "cs_test"}}},
|
||||
)
|
||||
@@ -204,19 +204,19 @@ def test_subscription_manager():
|
||||
print("-" * 40)
|
||||
|
||||
# 更改计划
|
||||
changed = manager.change_plan(
|
||||
subscription_id=subscription.id, new_plan_id=enterprise_plan.id
|
||||
changed = manager.change_plan(
|
||||
subscription_id = subscription.id, new_plan_id = enterprise_plan.id
|
||||
)
|
||||
print(f"✓ 更改计划: {changed.plan_id} (Enterprise)")
|
||||
|
||||
# 取消订阅
|
||||
cancelled = manager.cancel_subscription(subscription_id=subscription.id, at_period_end=True)
|
||||
cancelled = manager.cancel_subscription(subscription_id = subscription.id, at_period_end = True)
|
||||
print(f"✓ 取消订阅: {cancelled.status}")
|
||||
print(f" - 周期结束时取消: {cancelled.cancel_at_period_end}")
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("\n" + " = " * 60)
|
||||
print("所有测试通过! ✓")
|
||||
print("=" * 60)
|
||||
print(" = " * 60)
|
||||
|
||||
finally:
|
||||
# 清理临时数据库
|
||||
|
||||
@@ -14,31 +14,31 @@ from ai_manager import ModelType, PredictionType, get_ai_manager
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
|
||||
def test_custom_model():
|
||||
def test_custom_model() -> None:
|
||||
"""测试自定义模型功能"""
|
||||
print("\n=== 测试自定义模型 ===")
|
||||
|
||||
manager = get_ai_manager()
|
||||
manager = get_ai_manager()
|
||||
|
||||
# 1. 创建自定义模型
|
||||
print("1. 创建自定义模型...")
|
||||
model = manager.create_custom_model(
|
||||
tenant_id="tenant_001",
|
||||
name="领域实体识别模型",
|
||||
description="用于识别医疗领域实体的自定义模型",
|
||||
model_type=ModelType.CUSTOM_NER,
|
||||
training_data={
|
||||
model = manager.create_custom_model(
|
||||
tenant_id = "tenant_001",
|
||||
name = "领域实体识别模型",
|
||||
description = "用于识别医疗领域实体的自定义模型",
|
||||
model_type = ModelType.CUSTOM_NER,
|
||||
training_data = {
|
||||
"entity_types": ["DISEASE", "SYMPTOM", "DRUG", "TREATMENT"],
|
||||
"domain": "medical",
|
||||
},
|
||||
hyperparameters={"epochs": 15, "learning_rate": 0.001, "batch_size": 32},
|
||||
created_by="user_001",
|
||||
hyperparameters = {"epochs": 15, "learning_rate": 0.001, "batch_size": 32},
|
||||
created_by = "user_001",
|
||||
)
|
||||
print(f" 创建成功: {model.id}, 状态: {model.status.value}")
|
||||
|
||||
# 2. 添加训练样本
|
||||
print("2. 添加训练样本...")
|
||||
samples = [
|
||||
samples = [
|
||||
{
|
||||
"text": "患者张三患有高血压,正在服用降压药治疗。",
|
||||
"entities": [
|
||||
@@ -66,22 +66,22 @@ def test_custom_model():
|
||||
]
|
||||
|
||||
for sample_data in samples:
|
||||
sample = manager.add_training_sample(
|
||||
model_id=model.id,
|
||||
text=sample_data["text"],
|
||||
entities=sample_data["entities"],
|
||||
metadata={"source": "manual"},
|
||||
sample = manager.add_training_sample(
|
||||
model_id = model.id,
|
||||
text = sample_data["text"],
|
||||
entities = sample_data["entities"],
|
||||
metadata = {"source": "manual"},
|
||||
)
|
||||
print(f" 添加样本: {sample.id}")
|
||||
|
||||
# 3. 获取训练样本
|
||||
print("3. 获取训练样本...")
|
||||
all_samples = manager.get_training_samples(model.id)
|
||||
all_samples = manager.get_training_samples(model.id)
|
||||
print(f" 共有 {len(all_samples)} 个训练样本")
|
||||
|
||||
# 4. 列出自定义模型
|
||||
print("4. 列出自定义模型...")
|
||||
models = manager.list_custom_models(tenant_id="tenant_001")
|
||||
models = manager.list_custom_models(tenant_id = "tenant_001")
|
||||
print(f" 找到 {len(models)} 个模型")
|
||||
for m in models:
|
||||
print(f" - {m.name} ({m.model_type.value}): {m.status.value}")
|
||||
@@ -89,16 +89,16 @@ def test_custom_model():
|
||||
return model.id
|
||||
|
||||
|
||||
async def test_train_and_predict(model_id: str):
|
||||
async def test_train_and_predict(model_id: str) -> None:
|
||||
"""测试训练和预测"""
|
||||
print("\n=== 测试模型训练和预测 ===")
|
||||
|
||||
manager = get_ai_manager()
|
||||
manager = get_ai_manager()
|
||||
|
||||
# 1. 训练模型
|
||||
print("1. 训练模型...")
|
||||
try:
|
||||
trained_model = await manager.train_custom_model(model_id)
|
||||
trained_model = await manager.train_custom_model(model_id)
|
||||
print(f" 训练完成: {trained_model.status.value}")
|
||||
print(f" 指标: {trained_model.metrics}")
|
||||
except Exception as e:
|
||||
@@ -107,50 +107,50 @@ async def test_train_and_predict(model_id: str):
|
||||
|
||||
# 2. 使用模型预测
|
||||
print("2. 使用模型预测...")
|
||||
test_text = "赵六患有糖尿病,正在使用胰岛素治疗。"
|
||||
test_text = "赵六患有糖尿病,正在使用胰岛素治疗。"
|
||||
try:
|
||||
entities = await manager.predict_with_custom_model(model_id, test_text)
|
||||
entities = await manager.predict_with_custom_model(model_id, test_text)
|
||||
print(f" 输入: {test_text}")
|
||||
print(f" 预测实体: {entities}")
|
||||
except Exception as e:
|
||||
print(f" 预测失败: {e}")
|
||||
|
||||
|
||||
def test_prediction_models():
|
||||
def test_prediction_models() -> None:
|
||||
"""测试预测模型"""
|
||||
print("\n=== 测试预测模型 ===")
|
||||
|
||||
manager = get_ai_manager()
|
||||
manager = get_ai_manager()
|
||||
|
||||
# 1. 创建趋势预测模型
|
||||
print("1. 创建趋势预测模型...")
|
||||
trend_model = manager.create_prediction_model(
|
||||
tenant_id="tenant_001",
|
||||
project_id="project_001",
|
||||
name="实体数量趋势预测",
|
||||
prediction_type=PredictionType.TREND,
|
||||
target_entity_type="PERSON",
|
||||
features=["entity_count", "time_period", "document_count"],
|
||||
model_config={"algorithm": "linear_regression", "window_size": 7},
|
||||
trend_model = manager.create_prediction_model(
|
||||
tenant_id = "tenant_001",
|
||||
project_id = "project_001",
|
||||
name = "实体数量趋势预测",
|
||||
prediction_type = PredictionType.TREND,
|
||||
target_entity_type = "PERSON",
|
||||
features = ["entity_count", "time_period", "document_count"],
|
||||
model_config = {"algorithm": "linear_regression", "window_size": 7},
|
||||
)
|
||||
print(f" 创建成功: {trend_model.id}")
|
||||
|
||||
# 2. 创建异常检测模型
|
||||
print("2. 创建异常检测模型...")
|
||||
anomaly_model = manager.create_prediction_model(
|
||||
tenant_id="tenant_001",
|
||||
project_id="project_001",
|
||||
name="实体增长异常检测",
|
||||
prediction_type=PredictionType.ANOMALY,
|
||||
target_entity_type=None,
|
||||
features=["daily_growth", "weekly_growth"],
|
||||
model_config={"threshold": 2.5, "sensitivity": "medium"},
|
||||
anomaly_model = manager.create_prediction_model(
|
||||
tenant_id = "tenant_001",
|
||||
project_id = "project_001",
|
||||
name = "实体增长异常检测",
|
||||
prediction_type = PredictionType.ANOMALY,
|
||||
target_entity_type = None,
|
||||
features = ["daily_growth", "weekly_growth"],
|
||||
model_config = {"threshold": 2.5, "sensitivity": "medium"},
|
||||
)
|
||||
print(f" 创建成功: {anomaly_model.id}")
|
||||
|
||||
# 3. 列出预测模型
|
||||
print("3. 列出预测模型...")
|
||||
models = manager.list_prediction_models(tenant_id="tenant_001")
|
||||
models = manager.list_prediction_models(tenant_id = "tenant_001")
|
||||
print(f" 找到 {len(models)} 个预测模型")
|
||||
for m in models:
|
||||
print(f" - {m.name} ({m.prediction_type.value})")
|
||||
@@ -158,15 +158,15 @@ def test_prediction_models():
|
||||
return trend_model.id, anomaly_model.id
|
||||
|
||||
|
||||
async def test_predictions(trend_model_id: str, anomaly_model_id: str):
|
||||
async def test_predictions(trend_model_id: str, anomaly_model_id: str) -> None:
|
||||
"""测试预测功能"""
|
||||
print("\n=== 测试预测功能 ===")
|
||||
|
||||
manager = get_ai_manager()
|
||||
manager = get_ai_manager()
|
||||
|
||||
# 1. 训练趋势预测模型
|
||||
print("1. 训练趋势预测模型...")
|
||||
historical_data = [
|
||||
historical_data = [
|
||||
{"date": "2024-01-01", "value": 10},
|
||||
{"date": "2024-01-02", "value": 12},
|
||||
{"date": "2024-01-03", "value": 15},
|
||||
@@ -175,62 +175,62 @@ async def test_predictions(trend_model_id: str, anomaly_model_id: str):
|
||||
{"date": "2024-01-06", "value": 20},
|
||||
{"date": "2024-01-07", "value": 22},
|
||||
]
|
||||
trained = await manager.train_prediction_model(trend_model_id, historical_data)
|
||||
trained = await manager.train_prediction_model(trend_model_id, historical_data)
|
||||
print(f" 训练完成,准确率: {trained.accuracy}")
|
||||
|
||||
# 2. 趋势预测
|
||||
print("2. 趋势预测...")
|
||||
trend_result = await manager.predict(
|
||||
trend_result = await manager.predict(
|
||||
trend_model_id, {"historical_values": [10, 12, 15, 14, 18, 20, 22]}
|
||||
)
|
||||
print(f" 预测结果: {trend_result.prediction_data}")
|
||||
|
||||
# 3. 异常检测
|
||||
print("3. 异常检测...")
|
||||
anomaly_result = await manager.predict(
|
||||
anomaly_result = await manager.predict(
|
||||
anomaly_model_id, {"value": 50, "historical_values": [10, 12, 11, 13, 12, 14, 13]}
|
||||
)
|
||||
print(f" 检测结果: {anomaly_result.prediction_data}")
|
||||
|
||||
|
||||
def test_kg_rag():
|
||||
def test_kg_rag() -> None:
|
||||
"""测试知识图谱 RAG"""
|
||||
print("\n=== 测试知识图谱 RAG ===")
|
||||
|
||||
manager = get_ai_manager()
|
||||
manager = get_ai_manager()
|
||||
|
||||
# 创建 RAG 配置
|
||||
print("1. 创建知识图谱 RAG 配置...")
|
||||
rag = manager.create_kg_rag(
|
||||
tenant_id="tenant_001",
|
||||
project_id="project_001",
|
||||
name="项目知识问答",
|
||||
description="基于项目知识图谱的智能问答",
|
||||
kg_config={
|
||||
rag = manager.create_kg_rag(
|
||||
tenant_id = "tenant_001",
|
||||
project_id = "project_001",
|
||||
name = "项目知识问答",
|
||||
description = "基于项目知识图谱的智能问答",
|
||||
kg_config = {
|
||||
"entity_types": ["PERSON", "ORG", "PROJECT", "TECH"],
|
||||
"relation_types": ["works_with", "belongs_to", "depends_on"],
|
||||
},
|
||||
retrieval_config={"top_k": 5, "similarity_threshold": 0.7, "expand_relations": True},
|
||||
generation_config={"temperature": 0.3, "max_tokens": 1000, "include_sources": True},
|
||||
retrieval_config = {"top_k": 5, "similarity_threshold": 0.7, "expand_relations": True},
|
||||
generation_config = {"temperature": 0.3, "max_tokens": 1000, "include_sources": True},
|
||||
)
|
||||
print(f" 创建成功: {rag.id}")
|
||||
|
||||
# 列出 RAG 配置
|
||||
print("2. 列出 RAG 配置...")
|
||||
rags = manager.list_kg_rags(tenant_id="tenant_001")
|
||||
rags = manager.list_kg_rags(tenant_id = "tenant_001")
|
||||
print(f" 找到 {len(rags)} 个配置")
|
||||
|
||||
return rag.id
|
||||
|
||||
|
||||
async def test_kg_rag_query(rag_id: str):
|
||||
async def test_kg_rag_query(rag_id: str) -> None:
|
||||
"""测试 RAG 查询"""
|
||||
print("\n=== 测试知识图谱 RAG 查询 ===")
|
||||
|
||||
manager = get_ai_manager()
|
||||
manager = get_ai_manager()
|
||||
|
||||
# 模拟项目实体和关系
|
||||
project_entities = [
|
||||
project_entities = [
|
||||
{"id": "e1", "name": "张三", "type": "PERSON", "definition": "项目经理"},
|
||||
{"id": "e2", "name": "李四", "type": "PERSON", "definition": "技术负责人"},
|
||||
{"id": "e3", "name": "Project Alpha", "type": "PROJECT", "definition": "核心产品项目"},
|
||||
@@ -238,7 +238,7 @@ async def test_kg_rag_query(rag_id: str):
|
||||
{"id": "e5", "name": "TechCorp", "type": "ORG", "definition": "科技公司"},
|
||||
]
|
||||
|
||||
project_relations = [
|
||||
project_relations = [
|
||||
{
|
||||
"source_entity_id": "e1",
|
||||
"target_entity_id": "e3",
|
||||
@@ -275,14 +275,14 @@ async def test_kg_rag_query(rag_id: str):
|
||||
|
||||
# 执行查询
|
||||
print("1. 执行 RAG 查询...")
|
||||
query_text = "Project Alpha 项目有哪些人参与?使用了什么技术?"
|
||||
query_text = "Project Alpha 项目有哪些人参与?使用了什么技术?"
|
||||
|
||||
try:
|
||||
result = await manager.query_kg_rag(
|
||||
rag_id=rag_id,
|
||||
query=query_text,
|
||||
project_entities=project_entities,
|
||||
project_relations=project_relations,
|
||||
result = await manager.query_kg_rag(
|
||||
rag_id = rag_id,
|
||||
query = query_text,
|
||||
project_entities = project_entities,
|
||||
project_relations = project_relations,
|
||||
)
|
||||
|
||||
print(f" 查询: {result.query}")
|
||||
@@ -294,14 +294,14 @@ async def test_kg_rag_query(rag_id: str):
|
||||
print(f" 查询失败: {e}")
|
||||
|
||||
|
||||
async def test_smart_summary():
|
||||
async def test_smart_summary() -> None:
|
||||
"""测试智能摘要"""
|
||||
print("\n=== 测试智能摘要 ===")
|
||||
|
||||
manager = get_ai_manager()
|
||||
manager = get_ai_manager()
|
||||
|
||||
# 模拟转录文本
|
||||
transcript_text = """
|
||||
transcript_text = """
|
||||
今天的会议主要讨论了 Project Alpha 的进展情况。张三作为项目经理,
|
||||
汇报了当前的项目进度,表示已经完成了 80% 的开发工作。李四提出了
|
||||
一些关于 Kubernetes 部署的问题,建议我们采用新的部署策略。
|
||||
@@ -309,7 +309,7 @@ async def test_smart_summary():
|
||||
大家一致认为项目进展顺利,预计可以按时交付。
|
||||
"""
|
||||
|
||||
content_data = {
|
||||
content_data = {
|
||||
"text": transcript_text,
|
||||
"entities": [
|
||||
{"name": "张三", "type": "PERSON"},
|
||||
@@ -320,18 +320,18 @@ async def test_smart_summary():
|
||||
}
|
||||
|
||||
# 生成不同类型的摘要
|
||||
summary_types = ["extractive", "abstractive", "key_points"]
|
||||
summary_types = ["extractive", "abstractive", "key_points"]
|
||||
|
||||
for summary_type in summary_types:
|
||||
print(f"1. 生成 {summary_type} 类型摘要...")
|
||||
try:
|
||||
summary = await manager.generate_smart_summary(
|
||||
tenant_id="tenant_001",
|
||||
project_id="project_001",
|
||||
source_type="transcript",
|
||||
source_id="transcript_001",
|
||||
summary_type=summary_type,
|
||||
content_data=content_data,
|
||||
summary = await manager.generate_smart_summary(
|
||||
tenant_id = "tenant_001",
|
||||
project_id = "project_001",
|
||||
source_type = "transcript",
|
||||
source_id = "transcript_001",
|
||||
summary_type = summary_type,
|
||||
content_data = content_data,
|
||||
)
|
||||
|
||||
print(f" 摘要类型: {summary.summary_type}")
|
||||
@@ -342,27 +342,27 @@ async def test_smart_summary():
|
||||
print(f" 生成失败: {e}")
|
||||
|
||||
|
||||
async def main():
|
||||
async def main() -> None:
|
||||
"""主测试函数"""
|
||||
print("=" * 60)
|
||||
print(" = " * 60)
|
||||
print("InsightFlow Phase 8 Task 4 - AI 能力增强测试")
|
||||
print("=" * 60)
|
||||
print(" = " * 60)
|
||||
|
||||
try:
|
||||
# 测试自定义模型
|
||||
model_id = test_custom_model()
|
||||
model_id = test_custom_model()
|
||||
|
||||
# 测试训练和预测
|
||||
await test_train_and_predict(model_id)
|
||||
|
||||
# 测试预测模型
|
||||
trend_model_id, anomaly_model_id = test_prediction_models()
|
||||
trend_model_id, anomaly_model_id = test_prediction_models()
|
||||
|
||||
# 测试预测功能
|
||||
await test_predictions(trend_model_id, anomaly_model_id)
|
||||
|
||||
# 测试知识图谱 RAG
|
||||
rag_id = test_kg_rag()
|
||||
rag_id = test_kg_rag()
|
||||
|
||||
# 测试 RAG 查询
|
||||
await test_kg_rag_query(rag_id)
|
||||
@@ -370,9 +370,9 @@ async def main():
|
||||
# 测试智能摘要
|
||||
await test_smart_summary()
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("\n" + " = " * 60)
|
||||
print("所有测试完成!")
|
||||
print("=" * 60)
|
||||
print(" = " * 60)
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n测试失败: {e}")
|
||||
|
||||
@@ -28,7 +28,7 @@ from growth_manager import (
|
||||
)
|
||||
|
||||
# 添加 backend 目录到路径
|
||||
backend_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
backend_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
if backend_dir not in sys.path:
|
||||
sys.path.insert(0, backend_dir)
|
||||
|
||||
@@ -36,35 +36,35 @@ if backend_dir not in sys.path:
|
||||
class TestGrowthManager:
|
||||
"""测试 Growth Manager 功能"""
|
||||
|
||||
def __init__(self):
|
||||
self.manager = GrowthManager()
|
||||
self.test_tenant_id = "test_tenant_001"
|
||||
self.test_user_id = "test_user_001"
|
||||
self.test_results = []
|
||||
def __init__(self) -> None:
|
||||
self.manager = GrowthManager()
|
||||
self.test_tenant_id = "test_tenant_001"
|
||||
self.test_user_id = "test_user_001"
|
||||
self.test_results = []
|
||||
|
||||
def log(self, message: str, success: bool = True):
|
||||
def log(self, message: str, success: bool = True) -> None:
|
||||
"""记录测试结果"""
|
||||
status = "✅" if success else "❌"
|
||||
status = "✅" if success else "❌"
|
||||
print(f"{status} {message}")
|
||||
self.test_results.append((message, success))
|
||||
|
||||
# ==================== 测试用户行为分析 ====================
|
||||
|
||||
async def test_track_event(self):
|
||||
async def test_track_event(self) -> None:
|
||||
"""测试事件追踪"""
|
||||
print("\n📊 测试事件追踪...")
|
||||
|
||||
try:
|
||||
event = await self.manager.track_event(
|
||||
tenant_id=self.test_tenant_id,
|
||||
user_id=self.test_user_id,
|
||||
event_type=EventType.PAGE_VIEW,
|
||||
event_name="dashboard_view",
|
||||
properties={"page": "/dashboard", "duration": 120},
|
||||
session_id="session_001",
|
||||
device_info={"browser": "Chrome", "os": "MacOS"},
|
||||
referrer="https://google.com",
|
||||
utm_params={"source": "google", "medium": "organic", "campaign": "summer"},
|
||||
event = await self.manager.track_event(
|
||||
tenant_id = self.test_tenant_id,
|
||||
user_id = self.test_user_id,
|
||||
event_type = EventType.PAGE_VIEW,
|
||||
event_name = "dashboard_view",
|
||||
properties = {"page": "/dashboard", "duration": 120},
|
||||
session_id = "session_001",
|
||||
device_info = {"browser": "Chrome", "os": "MacOS"},
|
||||
referrer = "https://google.com",
|
||||
utm_params = {"source": "google", "medium": "organic", "campaign": "summer"},
|
||||
)
|
||||
|
||||
assert event.id is not None
|
||||
@@ -74,15 +74,15 @@ class TestGrowthManager:
|
||||
self.log(f"事件追踪成功: {event.id}")
|
||||
return True
|
||||
except Exception as e:
|
||||
self.log(f"事件追踪失败: {e}", success=False)
|
||||
self.log(f"事件追踪失败: {e}", success = False)
|
||||
return False
|
||||
|
||||
async def test_track_multiple_events(self):
|
||||
async def test_track_multiple_events(self) -> None:
|
||||
"""测试追踪多个事件"""
|
||||
print("\n📊 测试追踪多个事件...")
|
||||
|
||||
try:
|
||||
events = [
|
||||
events = [
|
||||
(EventType.FEATURE_USE, "entity_extraction", {"entity_count": 5}),
|
||||
(EventType.FEATURE_USE, "relation_discovery", {"relation_count": 3}),
|
||||
(EventType.CONVERSION, "upgrade_click", {"plan": "pro"}),
|
||||
@@ -91,25 +91,25 @@ class TestGrowthManager:
|
||||
|
||||
for event_type, event_name, props in events:
|
||||
await self.manager.track_event(
|
||||
tenant_id=self.test_tenant_id,
|
||||
user_id=self.test_user_id,
|
||||
event_type=event_type,
|
||||
event_name=event_name,
|
||||
properties=props,
|
||||
tenant_id = self.test_tenant_id,
|
||||
user_id = self.test_user_id,
|
||||
event_type = event_type,
|
||||
event_name = event_name,
|
||||
properties = props,
|
||||
)
|
||||
|
||||
self.log(f"成功追踪 {len(events)} 个事件")
|
||||
return True
|
||||
except Exception as e:
|
||||
self.log(f"批量事件追踪失败: {e}", success=False)
|
||||
self.log(f"批量事件追踪失败: {e}", success = False)
|
||||
return False
|
||||
|
||||
def test_get_user_profile(self):
|
||||
def test_get_user_profile(self) -> None:
|
||||
"""测试获取用户画像"""
|
||||
print("\n👤 测试用户画像...")
|
||||
|
||||
try:
|
||||
profile = self.manager.get_user_profile(self.test_tenant_id, self.test_user_id)
|
||||
profile = self.manager.get_user_profile(self.test_tenant_id, self.test_user_id)
|
||||
|
||||
if profile:
|
||||
assert profile.user_id == self.test_user_id
|
||||
@@ -120,18 +120,18 @@ class TestGrowthManager:
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
self.log(f"获取用户画像失败: {e}", success=False)
|
||||
self.log(f"获取用户画像失败: {e}", success = False)
|
||||
return False
|
||||
|
||||
def test_get_analytics_summary(self):
|
||||
def test_get_analytics_summary(self) -> None:
|
||||
"""测试获取分析汇总"""
|
||||
print("\n📈 测试分析汇总...")
|
||||
|
||||
try:
|
||||
summary = self.manager.get_user_analytics_summary(
|
||||
tenant_id=self.test_tenant_id,
|
||||
start_date=datetime.now() - timedelta(days=7),
|
||||
end_date=datetime.now(),
|
||||
summary = self.manager.get_user_analytics_summary(
|
||||
tenant_id = self.test_tenant_id,
|
||||
start_date = datetime.now() - timedelta(days = 7),
|
||||
end_date = datetime.now(),
|
||||
)
|
||||
|
||||
assert "unique_users" in summary
|
||||
@@ -141,25 +141,25 @@ class TestGrowthManager:
|
||||
self.log(f"分析汇总: {summary['unique_users']} 用户, {summary['total_events']} 事件")
|
||||
return True
|
||||
except Exception as e:
|
||||
self.log(f"获取分析汇总失败: {e}", success=False)
|
||||
self.log(f"获取分析汇总失败: {e}", success = False)
|
||||
return False
|
||||
|
||||
def test_create_funnel(self):
|
||||
def test_create_funnel(self) -> None:
|
||||
"""测试创建转化漏斗"""
|
||||
print("\n🎯 测试创建转化漏斗...")
|
||||
|
||||
try:
|
||||
funnel = self.manager.create_funnel(
|
||||
tenant_id=self.test_tenant_id,
|
||||
name="用户注册转化漏斗",
|
||||
description="从访问到完成注册的转化流程",
|
||||
steps=[
|
||||
funnel = self.manager.create_funnel(
|
||||
tenant_id = self.test_tenant_id,
|
||||
name = "用户注册转化漏斗",
|
||||
description = "从访问到完成注册的转化流程",
|
||||
steps = [
|
||||
{"name": "访问首页", "event_name": "page_view_home"},
|
||||
{"name": "点击注册", "event_name": "signup_click"},
|
||||
{"name": "填写信息", "event_name": "signup_form_fill"},
|
||||
{"name": "完成注册", "event_name": "signup_complete"},
|
||||
],
|
||||
created_by="test",
|
||||
created_by = "test",
|
||||
)
|
||||
|
||||
assert funnel.id is not None
|
||||
@@ -168,10 +168,10 @@ class TestGrowthManager:
|
||||
self.log(f"漏斗创建成功: {funnel.id}")
|
||||
return funnel.id
|
||||
except Exception as e:
|
||||
self.log(f"创建漏斗失败: {e}", success=False)
|
||||
self.log(f"创建漏斗失败: {e}", success = False)
|
||||
return None
|
||||
|
||||
def test_analyze_funnel(self, funnel_id: str):
|
||||
def test_analyze_funnel(self, funnel_id: str) -> None:
|
||||
"""测试分析漏斗"""
|
||||
print("\n📉 测试漏斗分析...")
|
||||
|
||||
@@ -180,10 +180,10 @@ class TestGrowthManager:
|
||||
return False
|
||||
|
||||
try:
|
||||
analysis = self.manager.analyze_funnel(
|
||||
funnel_id=funnel_id,
|
||||
period_start=datetime.now() - timedelta(days=30),
|
||||
period_end=datetime.now(),
|
||||
analysis = self.manager.analyze_funnel(
|
||||
funnel_id = funnel_id,
|
||||
period_start = datetime.now() - timedelta(days = 30),
|
||||
period_end = datetime.now(),
|
||||
)
|
||||
|
||||
if analysis:
|
||||
@@ -194,18 +194,18 @@ class TestGrowthManager:
|
||||
self.log("漏斗分析返回空结果")
|
||||
return False
|
||||
except Exception as e:
|
||||
self.log(f"漏斗分析失败: {e}", success=False)
|
||||
self.log(f"漏斗分析失败: {e}", success = False)
|
||||
return False
|
||||
|
||||
def test_calculate_retention(self):
|
||||
def test_calculate_retention(self) -> None:
|
||||
"""测试留存率计算"""
|
||||
print("\n🔄 测试留存率计算...")
|
||||
|
||||
try:
|
||||
retention = self.manager.calculate_retention(
|
||||
tenant_id=self.test_tenant_id,
|
||||
cohort_date=datetime.now() - timedelta(days=7),
|
||||
periods=[1, 3, 7],
|
||||
retention = self.manager.calculate_retention(
|
||||
tenant_id = self.test_tenant_id,
|
||||
cohort_date = datetime.now() - timedelta(days = 7),
|
||||
periods = [1, 3, 7],
|
||||
)
|
||||
|
||||
assert "cohort_date" in retention
|
||||
@@ -214,34 +214,34 @@ class TestGrowthManager:
|
||||
self.log(f"留存率计算完成: 同期群 {retention['cohort_size']} 用户")
|
||||
return True
|
||||
except Exception as e:
|
||||
self.log(f"留存率计算失败: {e}", success=False)
|
||||
self.log(f"留存率计算失败: {e}", success = False)
|
||||
return False
|
||||
|
||||
# ==================== 测试 A/B 测试框架 ====================
|
||||
|
||||
def test_create_experiment(self):
|
||||
def test_create_experiment(self) -> None:
|
||||
"""测试创建实验"""
|
||||
print("\n🧪 测试创建 A/B 测试实验...")
|
||||
|
||||
try:
|
||||
experiment = self.manager.create_experiment(
|
||||
tenant_id=self.test_tenant_id,
|
||||
name="首页按钮颜色测试",
|
||||
description="测试不同按钮颜色对转化率的影响",
|
||||
hypothesis="蓝色按钮比红色按钮有更高的点击率",
|
||||
variants=[
|
||||
experiment = self.manager.create_experiment(
|
||||
tenant_id = self.test_tenant_id,
|
||||
name = "首页按钮颜色测试",
|
||||
description = "测试不同按钮颜色对转化率的影响",
|
||||
hypothesis = "蓝色按钮比红色按钮有更高的点击率",
|
||||
variants = [
|
||||
{"id": "control", "name": "红色按钮", "is_control": True},
|
||||
{"id": "variant_a", "name": "蓝色按钮", "is_control": False},
|
||||
{"id": "variant_b", "name": "绿色按钮", "is_control": False},
|
||||
],
|
||||
traffic_allocation=TrafficAllocationType.RANDOM,
|
||||
traffic_split={"control": 0.34, "variant_a": 0.33, "variant_b": 0.33},
|
||||
target_audience={"conditions": []},
|
||||
primary_metric="button_click_rate",
|
||||
secondary_metrics=["conversion_rate", "bounce_rate"],
|
||||
min_sample_size=100,
|
||||
confidence_level=0.95,
|
||||
created_by="test",
|
||||
traffic_allocation = TrafficAllocationType.RANDOM,
|
||||
traffic_split = {"control": 0.34, "variant_a": 0.33, "variant_b": 0.33},
|
||||
target_audience = {"conditions": []},
|
||||
primary_metric = "button_click_rate",
|
||||
secondary_metrics = ["conversion_rate", "bounce_rate"],
|
||||
min_sample_size = 100,
|
||||
confidence_level = 0.95,
|
||||
created_by = "test",
|
||||
)
|
||||
|
||||
assert experiment.id is not None
|
||||
@@ -250,23 +250,23 @@ class TestGrowthManager:
|
||||
self.log(f"实验创建成功: {experiment.id}")
|
||||
return experiment.id
|
||||
except Exception as e:
|
||||
self.log(f"创建实验失败: {e}", success=False)
|
||||
self.log(f"创建实验失败: {e}", success = False)
|
||||
return None
|
||||
|
||||
def test_list_experiments(self):
|
||||
def test_list_experiments(self) -> None:
|
||||
"""测试列出实验"""
|
||||
print("\n📋 测试列出实验...")
|
||||
|
||||
try:
|
||||
experiments = self.manager.list_experiments(self.test_tenant_id)
|
||||
experiments = self.manager.list_experiments(self.test_tenant_id)
|
||||
|
||||
self.log(f"列出 {len(experiments)} 个实验")
|
||||
return True
|
||||
except Exception as e:
|
||||
self.log(f"列出实验失败: {e}", success=False)
|
||||
self.log(f"列出实验失败: {e}", success = False)
|
||||
return False
|
||||
|
||||
def test_assign_variant(self, experiment_id: str):
|
||||
def test_assign_variant(self, experiment_id: str) -> None:
|
||||
"""测试分配变体"""
|
||||
print("\n🎲 测试分配实验变体...")
|
||||
|
||||
@@ -279,26 +279,26 @@ class TestGrowthManager:
|
||||
self.manager.start_experiment(experiment_id)
|
||||
|
||||
# 测试多个用户的变体分配
|
||||
test_users = ["user_001", "user_002", "user_003", "user_004", "user_005"]
|
||||
assignments = {}
|
||||
test_users = ["user_001", "user_002", "user_003", "user_004", "user_005"]
|
||||
assignments = {}
|
||||
|
||||
for user_id in test_users:
|
||||
variant_id = self.manager.assign_variant(
|
||||
experiment_id=experiment_id,
|
||||
user_id=user_id,
|
||||
user_attributes={"user_id": user_id, "segment": "new"},
|
||||
variant_id = self.manager.assign_variant(
|
||||
experiment_id = experiment_id,
|
||||
user_id = user_id,
|
||||
user_attributes = {"user_id": user_id, "segment": "new"},
|
||||
)
|
||||
|
||||
if variant_id:
|
||||
assignments[user_id] = variant_id
|
||||
assignments[user_id] = variant_id
|
||||
|
||||
self.log(f"变体分配完成: {len(assignments)} 个用户")
|
||||
return True
|
||||
except Exception as e:
|
||||
self.log(f"变体分配失败: {e}", success=False)
|
||||
self.log(f"变体分配失败: {e}", success = False)
|
||||
return False
|
||||
|
||||
def test_record_experiment_metric(self, experiment_id: str):
|
||||
def test_record_experiment_metric(self, experiment_id: str) -> None:
|
||||
"""测试记录实验指标"""
|
||||
print("\n📊 测试记录实验指标...")
|
||||
|
||||
@@ -308,7 +308,7 @@ class TestGrowthManager:
|
||||
|
||||
try:
|
||||
# 模拟记录一些指标
|
||||
test_data = [
|
||||
test_data = [
|
||||
("user_001", "control", 1),
|
||||
("user_002", "variant_a", 1),
|
||||
("user_003", "variant_b", 0),
|
||||
@@ -318,20 +318,20 @@ class TestGrowthManager:
|
||||
|
||||
for user_id, variant_id, value in test_data:
|
||||
self.manager.record_experiment_metric(
|
||||
experiment_id=experiment_id,
|
||||
variant_id=variant_id,
|
||||
user_id=user_id,
|
||||
metric_name="button_click_rate",
|
||||
metric_value=value,
|
||||
experiment_id = experiment_id,
|
||||
variant_id = variant_id,
|
||||
user_id = user_id,
|
||||
metric_name = "button_click_rate",
|
||||
metric_value = value,
|
||||
)
|
||||
|
||||
self.log(f"成功记录 {len(test_data)} 条指标")
|
||||
return True
|
||||
except Exception as e:
|
||||
self.log(f"记录指标失败: {e}", success=False)
|
||||
self.log(f"记录指标失败: {e}", success = False)
|
||||
return False
|
||||
|
||||
def test_analyze_experiment(self, experiment_id: str):
|
||||
def test_analyze_experiment(self, experiment_id: str) -> None:
|
||||
"""测试分析实验结果"""
|
||||
print("\n📈 测试分析实验结果...")
|
||||
|
||||
@@ -340,31 +340,31 @@ class TestGrowthManager:
|
||||
return False
|
||||
|
||||
try:
|
||||
result = self.manager.analyze_experiment(experiment_id)
|
||||
result = self.manager.analyze_experiment(experiment_id)
|
||||
|
||||
if "error" not in result:
|
||||
self.log(f"实验分析完成: {len(result.get('variant_results', {}))} 个变体")
|
||||
return True
|
||||
else:
|
||||
self.log(f"实验分析返回错误: {result['error']}", success=False)
|
||||
self.log(f"实验分析返回错误: {result['error']}", success = False)
|
||||
return False
|
||||
except Exception as e:
|
||||
self.log(f"实验分析失败: {e}", success=False)
|
||||
self.log(f"实验分析失败: {e}", success = False)
|
||||
return False
|
||||
|
||||
# ==================== 测试邮件营销 ====================
|
||||
|
||||
def test_create_email_template(self):
|
||||
def test_create_email_template(self) -> None:
|
||||
"""测试创建邮件模板"""
|
||||
print("\n📧 测试创建邮件模板...")
|
||||
|
||||
try:
|
||||
template = self.manager.create_email_template(
|
||||
tenant_id=self.test_tenant_id,
|
||||
name="欢迎邮件",
|
||||
template_type=EmailTemplateType.WELCOME,
|
||||
subject="欢迎加入 InsightFlow!",
|
||||
html_content="""
|
||||
template = self.manager.create_email_template(
|
||||
tenant_id = self.test_tenant_id,
|
||||
name = "欢迎邮件",
|
||||
template_type = EmailTemplateType.WELCOME,
|
||||
subject = "欢迎加入 InsightFlow!",
|
||||
html_content = """
|
||||
<h1>欢迎,{{user_name}}!</h1>
|
||||
<p>感谢您注册 InsightFlow。我们很高兴您能加入我们!</p>
|
||||
<p>您的账户已创建,可以开始使用以下功能:</p>
|
||||
@@ -373,10 +373,10 @@ class TestGrowthManager:
|
||||
<li>智能实体提取</li>
|
||||
<li>团队协作</li>
|
||||
</ul>
|
||||
<p><a href="{{dashboard_url}}">立即开始使用</a></p>
|
||||
<p><a href = "{{dashboard_url}}">立即开始使用</a></p>
|
||||
""",
|
||||
from_name="InsightFlow 团队",
|
||||
from_email="welcome@insightflow.io",
|
||||
from_name = "InsightFlow 团队",
|
||||
from_email = "welcome@insightflow.io",
|
||||
)
|
||||
|
||||
assert template.id is not None
|
||||
@@ -385,23 +385,23 @@ class TestGrowthManager:
|
||||
self.log(f"邮件模板创建成功: {template.id}")
|
||||
return template.id
|
||||
except Exception as e:
|
||||
self.log(f"创建邮件模板失败: {e}", success=False)
|
||||
self.log(f"创建邮件模板失败: {e}", success = False)
|
||||
return None
|
||||
|
||||
def test_list_email_templates(self):
|
||||
def test_list_email_templates(self) -> None:
|
||||
"""测试列出邮件模板"""
|
||||
print("\n📧 测试列出邮件模板...")
|
||||
|
||||
try:
|
||||
templates = self.manager.list_email_templates(self.test_tenant_id)
|
||||
templates = self.manager.list_email_templates(self.test_tenant_id)
|
||||
|
||||
self.log(f"列出 {len(templates)} 个邮件模板")
|
||||
return True
|
||||
except Exception as e:
|
||||
self.log(f"列出邮件模板失败: {e}", success=False)
|
||||
self.log(f"列出邮件模板失败: {e}", success = False)
|
||||
return False
|
||||
|
||||
def test_render_template(self, template_id: str):
|
||||
def test_render_template(self, template_id: str) -> None:
|
||||
"""测试渲染邮件模板"""
|
||||
print("\n🎨 测试渲染邮件模板...")
|
||||
|
||||
@@ -410,9 +410,9 @@ class TestGrowthManager:
|
||||
return False
|
||||
|
||||
try:
|
||||
rendered = self.manager.render_template(
|
||||
template_id=template_id,
|
||||
variables={
|
||||
rendered = self.manager.render_template(
|
||||
template_id = template_id,
|
||||
variables = {
|
||||
"user_name": "张三",
|
||||
"dashboard_url": "https://app.insightflow.io/dashboard",
|
||||
},
|
||||
@@ -424,13 +424,13 @@ class TestGrowthManager:
|
||||
self.log(f"模板渲染成功: {rendered['subject']}")
|
||||
return True
|
||||
else:
|
||||
self.log("模板渲染返回空结果", success=False)
|
||||
self.log("模板渲染返回空结果", success = False)
|
||||
return False
|
||||
except Exception as e:
|
||||
self.log(f"模板渲染失败: {e}", success=False)
|
||||
self.log(f"模板渲染失败: {e}", success = False)
|
||||
return False
|
||||
|
||||
def test_create_email_campaign(self, template_id: str):
|
||||
def test_create_email_campaign(self, template_id: str) -> None:
|
||||
"""测试创建邮件营销活动"""
|
||||
print("\n📮 测试创建邮件营销活动...")
|
||||
|
||||
@@ -439,11 +439,11 @@ class TestGrowthManager:
|
||||
return None
|
||||
|
||||
try:
|
||||
campaign = self.manager.create_email_campaign(
|
||||
tenant_id=self.test_tenant_id,
|
||||
name="新用户欢迎活动",
|
||||
template_id=template_id,
|
||||
recipient_list=[
|
||||
campaign = self.manager.create_email_campaign(
|
||||
tenant_id = self.test_tenant_id,
|
||||
name = "新用户欢迎活动",
|
||||
template_id = template_id,
|
||||
recipient_list = [
|
||||
{"user_id": "user_001", "email": "user1@example.com"},
|
||||
{"user_id": "user_002", "email": "user2@example.com"},
|
||||
{"user_id": "user_003", "email": "user3@example.com"},
|
||||
@@ -456,21 +456,21 @@ class TestGrowthManager:
|
||||
self.log(f"营销活动创建成功: {campaign.id}, {campaign.recipient_count} 收件人")
|
||||
return campaign.id
|
||||
except Exception as e:
|
||||
self.log(f"创建营销活动失败: {e}", success=False)
|
||||
self.log(f"创建营销活动失败: {e}", success = False)
|
||||
return None
|
||||
|
||||
def test_create_automation_workflow(self):
|
||||
def test_create_automation_workflow(self) -> None:
|
||||
"""测试创建自动化工作流"""
|
||||
print("\n🤖 测试创建自动化工作流...")
|
||||
|
||||
try:
|
||||
workflow = self.manager.create_automation_workflow(
|
||||
tenant_id=self.test_tenant_id,
|
||||
name="新用户欢迎序列",
|
||||
description="用户注册后自动发送欢迎邮件序列",
|
||||
trigger_type=WorkflowTriggerType.USER_SIGNUP,
|
||||
trigger_conditions={"event": "user_signup"},
|
||||
actions=[
|
||||
workflow = self.manager.create_automation_workflow(
|
||||
tenant_id = self.test_tenant_id,
|
||||
name = "新用户欢迎序列",
|
||||
description = "用户注册后自动发送欢迎邮件序列",
|
||||
trigger_type = WorkflowTriggerType.USER_SIGNUP,
|
||||
trigger_conditions = {"event": "user_signup"},
|
||||
actions = [
|
||||
{"type": "send_email", "template_type": "welcome", "delay_hours": 0},
|
||||
{"type": "send_email", "template_type": "onboarding", "delay_hours": 24},
|
||||
{"type": "send_email", "template_type": "feature_tips", "delay_hours": 72},
|
||||
@@ -483,27 +483,27 @@ class TestGrowthManager:
|
||||
self.log(f"自动化工作流创建成功: {workflow.id}")
|
||||
return True
|
||||
except Exception as e:
|
||||
self.log(f"创建工作流失败: {e}", success=False)
|
||||
self.log(f"创建工作流失败: {e}", success = False)
|
||||
return False
|
||||
|
||||
# ==================== 测试推荐系统 ====================
|
||||
|
||||
def test_create_referral_program(self):
|
||||
def test_create_referral_program(self) -> None:
|
||||
"""测试创建推荐计划"""
|
||||
print("\n🎁 测试创建推荐计划...")
|
||||
|
||||
try:
|
||||
program = self.manager.create_referral_program(
|
||||
tenant_id=self.test_tenant_id,
|
||||
name="邀请好友奖励计划",
|
||||
description="邀请好友注册,双方获得积分奖励",
|
||||
referrer_reward_type="credit",
|
||||
referrer_reward_value=100.0,
|
||||
referee_reward_type="credit",
|
||||
referee_reward_value=50.0,
|
||||
max_referrals_per_user=10,
|
||||
referral_code_length=8,
|
||||
expiry_days=30,
|
||||
program = self.manager.create_referral_program(
|
||||
tenant_id = self.test_tenant_id,
|
||||
name = "邀请好友奖励计划",
|
||||
description = "邀请好友注册,双方获得积分奖励",
|
||||
referrer_reward_type = "credit",
|
||||
referrer_reward_value = 100.0,
|
||||
referee_reward_type = "credit",
|
||||
referee_reward_value = 50.0,
|
||||
max_referrals_per_user = 10,
|
||||
referral_code_length = 8,
|
||||
expiry_days = 30,
|
||||
)
|
||||
|
||||
assert program.id is not None
|
||||
@@ -512,10 +512,10 @@ class TestGrowthManager:
|
||||
self.log(f"推荐计划创建成功: {program.id}")
|
||||
return program.id
|
||||
except Exception as e:
|
||||
self.log(f"创建推荐计划失败: {e}", success=False)
|
||||
self.log(f"创建推荐计划失败: {e}", success = False)
|
||||
return None
|
||||
|
||||
def test_generate_referral_code(self, program_id: str):
|
||||
def test_generate_referral_code(self, program_id: str) -> None:
|
||||
"""测试生成推荐码"""
|
||||
print("\n🔑 测试生成推荐码...")
|
||||
|
||||
@@ -524,8 +524,8 @@ class TestGrowthManager:
|
||||
return None
|
||||
|
||||
try:
|
||||
referral = self.manager.generate_referral_code(
|
||||
program_id=program_id, referrer_id="referrer_user_001"
|
||||
referral = self.manager.generate_referral_code(
|
||||
program_id = program_id, referrer_id = "referrer_user_001"
|
||||
)
|
||||
|
||||
if referral:
|
||||
@@ -535,13 +535,13 @@ class TestGrowthManager:
|
||||
self.log(f"推荐码生成成功: {referral.referral_code}")
|
||||
return referral.referral_code
|
||||
else:
|
||||
self.log("生成推荐码返回空结果", success=False)
|
||||
self.log("生成推荐码返回空结果", success = False)
|
||||
return None
|
||||
except Exception as e:
|
||||
self.log(f"生成推荐码失败: {e}", success=False)
|
||||
self.log(f"生成推荐码失败: {e}", success = False)
|
||||
return None
|
||||
|
||||
def test_apply_referral_code(self, referral_code: str):
|
||||
def test_apply_referral_code(self, referral_code: str) -> None:
|
||||
"""测试应用推荐码"""
|
||||
print("\n✅ 测试应用推荐码...")
|
||||
|
||||
@@ -550,21 +550,21 @@ class TestGrowthManager:
|
||||
return False
|
||||
|
||||
try:
|
||||
success = self.manager.apply_referral_code(
|
||||
referral_code=referral_code, referee_id="new_user_001"
|
||||
success = self.manager.apply_referral_code(
|
||||
referral_code = referral_code, referee_id = "new_user_001"
|
||||
)
|
||||
|
||||
if success:
|
||||
self.log(f"推荐码应用成功: {referral_code}")
|
||||
return True
|
||||
else:
|
||||
self.log("推荐码应用失败", success=False)
|
||||
self.log("推荐码应用失败", success = False)
|
||||
return False
|
||||
except Exception as e:
|
||||
self.log(f"应用推荐码失败: {e}", success=False)
|
||||
self.log(f"应用推荐码失败: {e}", success = False)
|
||||
return False
|
||||
|
||||
def test_get_referral_stats(self, program_id: str):
|
||||
def test_get_referral_stats(self, program_id: str) -> None:
|
||||
"""测试获取推荐统计"""
|
||||
print("\n📊 测试获取推荐统计...")
|
||||
|
||||
@@ -573,7 +573,7 @@ class TestGrowthManager:
|
||||
return False
|
||||
|
||||
try:
|
||||
stats = self.manager.get_referral_stats(program_id)
|
||||
stats = self.manager.get_referral_stats(program_id)
|
||||
|
||||
assert "total_referrals" in stats
|
||||
assert "conversion_rate" in stats
|
||||
@@ -583,24 +583,24 @@ class TestGrowthManager:
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
self.log(f"获取推荐统计失败: {e}", success=False)
|
||||
self.log(f"获取推荐统计失败: {e}", success = False)
|
||||
return False
|
||||
|
||||
def test_create_team_incentive(self):
|
||||
def test_create_team_incentive(self) -> None:
|
||||
"""测试创建团队激励"""
|
||||
print("\n🏆 测试创建团队升级激励...")
|
||||
|
||||
try:
|
||||
incentive = self.manager.create_team_incentive(
|
||||
tenant_id=self.test_tenant_id,
|
||||
name="团队升级奖励",
|
||||
description="团队规模达到5人升级到 Pro 计划可获得折扣",
|
||||
target_tier="pro",
|
||||
min_team_size=5,
|
||||
incentive_type="discount",
|
||||
incentive_value=20.0, # 20% 折扣
|
||||
valid_from=datetime.now(),
|
||||
valid_until=datetime.now() + timedelta(days=90),
|
||||
incentive = self.manager.create_team_incentive(
|
||||
tenant_id = self.test_tenant_id,
|
||||
name = "团队升级奖励",
|
||||
description = "团队规模达到5人升级到 Pro 计划可获得折扣",
|
||||
target_tier = "pro",
|
||||
min_team_size = 5,
|
||||
incentive_type = "discount",
|
||||
incentive_value = 20.0, # 20% 折扣
|
||||
valid_from = datetime.now(),
|
||||
valid_until = datetime.now() + timedelta(days = 90),
|
||||
)
|
||||
|
||||
assert incentive.id is not None
|
||||
@@ -609,116 +609,116 @@ class TestGrowthManager:
|
||||
self.log(f"团队激励创建成功: {incentive.id}")
|
||||
return True
|
||||
except Exception as e:
|
||||
self.log(f"创建团队激励失败: {e}", success=False)
|
||||
self.log(f"创建团队激励失败: {e}", success = False)
|
||||
return False
|
||||
|
||||
def test_check_team_incentive_eligibility(self):
|
||||
def test_check_team_incentive_eligibility(self) -> None:
|
||||
"""测试检查团队激励资格"""
|
||||
print("\n🔍 测试检查团队激励资格...")
|
||||
|
||||
try:
|
||||
incentives = self.manager.check_team_incentive_eligibility(
|
||||
tenant_id=self.test_tenant_id, current_tier="free", team_size=5
|
||||
incentives = self.manager.check_team_incentive_eligibility(
|
||||
tenant_id = self.test_tenant_id, current_tier = "free", team_size = 5
|
||||
)
|
||||
|
||||
self.log(f"找到 {len(incentives)} 个符合条件的激励")
|
||||
return True
|
||||
except Exception as e:
|
||||
self.log(f"检查激励资格失败: {e}", success=False)
|
||||
self.log(f"检查激励资格失败: {e}", success = False)
|
||||
return False
|
||||
|
||||
# ==================== 测试实时仪表板 ====================
|
||||
|
||||
def test_get_realtime_dashboard(self):
|
||||
def test_get_realtime_dashboard(self) -> None:
|
||||
"""测试获取实时仪表板"""
|
||||
print("\n📺 测试实时分析仪表板...")
|
||||
|
||||
try:
|
||||
dashboard = self.manager.get_realtime_dashboard(self.test_tenant_id)
|
||||
dashboard = self.manager.get_realtime_dashboard(self.test_tenant_id)
|
||||
|
||||
assert "today" in dashboard
|
||||
assert "recent_events" in dashboard
|
||||
assert "top_features" in dashboard
|
||||
|
||||
today = dashboard["today"]
|
||||
today = dashboard["today"]
|
||||
self.log(
|
||||
f"实时仪表板: 今日 {today['active_users']} 活跃用户, {today['total_events']} 事件"
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
self.log(f"获取实时仪表板失败: {e}", success=False)
|
||||
self.log(f"获取实时仪表板失败: {e}", success = False)
|
||||
return False
|
||||
|
||||
# ==================== 运行所有测试 ====================
|
||||
|
||||
async def run_all_tests(self):
|
||||
async def run_all_tests(self) -> None:
|
||||
"""运行所有测试"""
|
||||
print("=" * 60)
|
||||
print(" = " * 60)
|
||||
print("🚀 InsightFlow Phase 8 Task 5 - 运营与增长工具测试")
|
||||
print("=" * 60)
|
||||
print(" = " * 60)
|
||||
|
||||
# 用户行为分析测试
|
||||
print("\n" + "=" * 60)
|
||||
print("\n" + " = " * 60)
|
||||
print("📊 模块 1: 用户行为分析")
|
||||
print("=" * 60)
|
||||
print(" = " * 60)
|
||||
|
||||
await self.test_track_event()
|
||||
await self.test_track_multiple_events()
|
||||
self.test_get_user_profile()
|
||||
self.test_get_analytics_summary()
|
||||
funnel_id = self.test_create_funnel()
|
||||
funnel_id = self.test_create_funnel()
|
||||
self.test_analyze_funnel(funnel_id)
|
||||
self.test_calculate_retention()
|
||||
|
||||
# A/B 测试框架测试
|
||||
print("\n" + "=" * 60)
|
||||
print("\n" + " = " * 60)
|
||||
print("🧪 模块 2: A/B 测试框架")
|
||||
print("=" * 60)
|
||||
print(" = " * 60)
|
||||
|
||||
experiment_id = self.test_create_experiment()
|
||||
experiment_id = self.test_create_experiment()
|
||||
self.test_list_experiments()
|
||||
self.test_assign_variant(experiment_id)
|
||||
self.test_record_experiment_metric(experiment_id)
|
||||
self.test_analyze_experiment(experiment_id)
|
||||
|
||||
# 邮件营销测试
|
||||
print("\n" + "=" * 60)
|
||||
print("\n" + " = " * 60)
|
||||
print("📧 模块 3: 邮件营销自动化")
|
||||
print("=" * 60)
|
||||
print(" = " * 60)
|
||||
|
||||
template_id = self.test_create_email_template()
|
||||
template_id = self.test_create_email_template()
|
||||
self.test_list_email_templates()
|
||||
self.test_render_template(template_id)
|
||||
self.test_create_email_campaign(template_id)
|
||||
self.test_create_automation_workflow()
|
||||
|
||||
# 推荐系统测试
|
||||
print("\n" + "=" * 60)
|
||||
print("\n" + " = " * 60)
|
||||
print("🎁 模块 4: 推荐系统")
|
||||
print("=" * 60)
|
||||
print(" = " * 60)
|
||||
|
||||
program_id = self.test_create_referral_program()
|
||||
referral_code = self.test_generate_referral_code(program_id)
|
||||
program_id = self.test_create_referral_program()
|
||||
referral_code = self.test_generate_referral_code(program_id)
|
||||
self.test_apply_referral_code(referral_code)
|
||||
self.test_get_referral_stats(program_id)
|
||||
self.test_create_team_incentive()
|
||||
self.test_check_team_incentive_eligibility()
|
||||
|
||||
# 实时仪表板测试
|
||||
print("\n" + "=" * 60)
|
||||
print("\n" + " = " * 60)
|
||||
print("📺 模块 5: 实时分析仪表板")
|
||||
print("=" * 60)
|
||||
print(" = " * 60)
|
||||
|
||||
self.test_get_realtime_dashboard()
|
||||
|
||||
# 测试总结
|
||||
print("\n" + "=" * 60)
|
||||
print("\n" + " = " * 60)
|
||||
print("📋 测试总结")
|
||||
print("=" * 60)
|
||||
print(" = " * 60)
|
||||
|
||||
total_tests = len(self.test_results)
|
||||
passed_tests = sum(1 for _, success in self.test_results if success)
|
||||
failed_tests = total_tests - passed_tests
|
||||
total_tests = len(self.test_results)
|
||||
passed_tests = sum(1 for _, success in self.test_results if success)
|
||||
failed_tests = total_tests - passed_tests
|
||||
|
||||
print(f"总测试数: {total_tests}")
|
||||
print(f"通过: {passed_tests} ✅")
|
||||
@@ -731,14 +731,14 @@ class TestGrowthManager:
|
||||
if not success:
|
||||
print(f" - {message}")
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("\n" + " = " * 60)
|
||||
print("✨ 测试完成!")
|
||||
print("=" * 60)
|
||||
print(" = " * 60)
|
||||
|
||||
|
||||
async def main():
|
||||
async def main() -> None:
|
||||
"""主函数"""
|
||||
tester = TestGrowthManager()
|
||||
tester = TestGrowthManager()
|
||||
await tester.run_all_tests()
|
||||
|
||||
|
||||
|
||||
@@ -25,7 +25,7 @@ from developer_ecosystem_manager import (
|
||||
)
|
||||
|
||||
# Add backend directory to path
|
||||
backend_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
backend_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
if backend_dir not in sys.path:
|
||||
sys.path.insert(0, backend_dir)
|
||||
|
||||
@@ -33,10 +33,10 @@ if backend_dir not in sys.path:
|
||||
class TestDeveloperEcosystem:
|
||||
"""开发者生态系统测试类"""
|
||||
|
||||
def __init__(self):
|
||||
self.manager = DeveloperEcosystemManager()
|
||||
self.test_results = []
|
||||
self.created_ids = {
|
||||
def __init__(self) -> None:
|
||||
self.manager = DeveloperEcosystemManager()
|
||||
self.test_results = []
|
||||
self.created_ids = {
|
||||
"sdk": [],
|
||||
"template": [],
|
||||
"plugin": [],
|
||||
@@ -45,19 +45,19 @@ class TestDeveloperEcosystem:
|
||||
"portal_config": [],
|
||||
}
|
||||
|
||||
def log(self, message: str, success: bool = True):
|
||||
def log(self, message: str, success: bool = True) -> None:
|
||||
"""记录测试结果"""
|
||||
status = "✅" if success else "❌"
|
||||
status = "✅" if success else "❌"
|
||||
print(f"{status} {message}")
|
||||
self.test_results.append(
|
||||
{"message": message, "success": success, "timestamp": datetime.now().isoformat()}
|
||||
)
|
||||
|
||||
def run_all_tests(self):
|
||||
def run_all_tests(self) -> None:
|
||||
"""运行所有测试"""
|
||||
print("=" * 60)
|
||||
print(" = " * 60)
|
||||
print("InsightFlow Phase 8 Task 6: Developer Ecosystem Tests")
|
||||
print("=" * 60)
|
||||
print(" = " * 60)
|
||||
|
||||
# SDK Tests
|
||||
print("\n📦 SDK Release & Management Tests")
|
||||
@@ -119,538 +119,538 @@ class TestDeveloperEcosystem:
|
||||
# Print Summary
|
||||
self.print_summary()
|
||||
|
||||
def test_sdk_create(self):
|
||||
def test_sdk_create(self) -> None:
|
||||
"""测试创建 SDK"""
|
||||
try:
|
||||
sdk = self.manager.create_sdk_release(
|
||||
name="InsightFlow Python SDK",
|
||||
language=SDKLanguage.PYTHON,
|
||||
version="1.0.0",
|
||||
description="Python SDK for InsightFlow API",
|
||||
changelog="Initial release",
|
||||
download_url="https://pypi.org/insightflow/1.0.0",
|
||||
documentation_url="https://docs.insightflow.io/python",
|
||||
repository_url="https://github.com/insightflow/python-sdk",
|
||||
package_name="insightflow",
|
||||
min_platform_version="1.0.0",
|
||||
dependencies=[{"name": "requests", "version": ">=2.0"}],
|
||||
file_size=1024000,
|
||||
checksum="abc123",
|
||||
created_by="test_user",
|
||||
sdk = self.manager.create_sdk_release(
|
||||
name = "InsightFlow Python SDK",
|
||||
language = SDKLanguage.PYTHON,
|
||||
version = "1.0.0",
|
||||
description = "Python SDK for InsightFlow API",
|
||||
changelog = "Initial release",
|
||||
download_url = "https://pypi.org/insightflow/1.0.0",
|
||||
documentation_url = "https://docs.insightflow.io/python",
|
||||
repository_url = "https://github.com/insightflow/python-sdk",
|
||||
package_name = "insightflow",
|
||||
min_platform_version = "1.0.0",
|
||||
dependencies = [{"name": "requests", "version": ">= 2.0"}],
|
||||
file_size = 1024000,
|
||||
checksum = "abc123",
|
||||
created_by = "test_user",
|
||||
)
|
||||
self.created_ids["sdk"].append(sdk.id)
|
||||
self.log(f"Created SDK: {sdk.name} ({sdk.id})")
|
||||
|
||||
# Create JavaScript SDK
|
||||
sdk_js = self.manager.create_sdk_release(
|
||||
name="InsightFlow JavaScript SDK",
|
||||
language=SDKLanguage.JAVASCRIPT,
|
||||
version="1.0.0",
|
||||
description="JavaScript SDK for InsightFlow API",
|
||||
changelog="Initial release",
|
||||
download_url="https://npmjs.com/insightflow/1.0.0",
|
||||
documentation_url="https://docs.insightflow.io/js",
|
||||
repository_url="https://github.com/insightflow/js-sdk",
|
||||
package_name="@insightflow/sdk",
|
||||
min_platform_version="1.0.0",
|
||||
dependencies=[{"name": "axios", "version": ">=0.21"}],
|
||||
file_size=512000,
|
||||
checksum="def456",
|
||||
created_by="test_user",
|
||||
sdk_js = self.manager.create_sdk_release(
|
||||
name = "InsightFlow JavaScript SDK",
|
||||
language = SDKLanguage.JAVASCRIPT,
|
||||
version = "1.0.0",
|
||||
description = "JavaScript SDK for InsightFlow API",
|
||||
changelog = "Initial release",
|
||||
download_url = "https://npmjs.com/insightflow/1.0.0",
|
||||
documentation_url = "https://docs.insightflow.io/js",
|
||||
repository_url = "https://github.com/insightflow/js-sdk",
|
||||
package_name = "@insightflow/sdk",
|
||||
min_platform_version = "1.0.0",
|
||||
dependencies = [{"name": "axios", "version": ">= 0.21"}],
|
||||
file_size = 512000,
|
||||
checksum = "def456",
|
||||
created_by = "test_user",
|
||||
)
|
||||
self.created_ids["sdk"].append(sdk_js.id)
|
||||
self.log(f"Created SDK: {sdk_js.name} ({sdk_js.id})")
|
||||
|
||||
except Exception as e:
|
||||
self.log(f"Failed to create SDK: {str(e)}", success=False)
|
||||
self.log(f"Failed to create SDK: {str(e)}", success = False)
|
||||
|
||||
def test_sdk_list(self):
|
||||
def test_sdk_list(self) -> None:
|
||||
"""测试列出 SDK"""
|
||||
try:
|
||||
sdks = self.manager.list_sdk_releases()
|
||||
sdks = self.manager.list_sdk_releases()
|
||||
self.log(f"Listed {len(sdks)} SDKs")
|
||||
|
||||
# Test filter by language
|
||||
python_sdks = self.manager.list_sdk_releases(language=SDKLanguage.PYTHON)
|
||||
python_sdks = self.manager.list_sdk_releases(language = SDKLanguage.PYTHON)
|
||||
self.log(f"Found {len(python_sdks)} Python SDKs")
|
||||
|
||||
# Test search
|
||||
search_results = self.manager.list_sdk_releases(search="Python")
|
||||
search_results = self.manager.list_sdk_releases(search = "Python")
|
||||
self.log(f"Search found {len(search_results)} SDKs")
|
||||
|
||||
except Exception as e:
|
||||
self.log(f"Failed to list SDKs: {str(e)}", success=False)
|
||||
self.log(f"Failed to list SDKs: {str(e)}", success = False)
|
||||
|
||||
def test_sdk_get(self):
|
||||
def test_sdk_get(self) -> None:
|
||||
"""测试获取 SDK 详情"""
|
||||
try:
|
||||
if self.created_ids["sdk"]:
|
||||
sdk = self.manager.get_sdk_release(self.created_ids["sdk"][0])
|
||||
sdk = self.manager.get_sdk_release(self.created_ids["sdk"][0])
|
||||
if sdk:
|
||||
self.log(f"Retrieved SDK: {sdk.name}")
|
||||
else:
|
||||
self.log("SDK not found", success=False)
|
||||
self.log("SDK not found", success = False)
|
||||
except Exception as e:
|
||||
self.log(f"Failed to get SDK: {str(e)}", success=False)
|
||||
self.log(f"Failed to get SDK: {str(e)}", success = False)
|
||||
|
||||
def test_sdk_update(self):
|
||||
def test_sdk_update(self) -> None:
|
||||
"""测试更新 SDK"""
|
||||
try:
|
||||
if self.created_ids["sdk"]:
|
||||
sdk = self.manager.update_sdk_release(
|
||||
self.created_ids["sdk"][0], description="Updated description"
|
||||
sdk = self.manager.update_sdk_release(
|
||||
self.created_ids["sdk"][0], description = "Updated description"
|
||||
)
|
||||
if sdk:
|
||||
self.log(f"Updated SDK: {sdk.name}")
|
||||
except Exception as e:
|
||||
self.log(f"Failed to update SDK: {str(e)}", success=False)
|
||||
self.log(f"Failed to update SDK: {str(e)}", success = False)
|
||||
|
||||
def test_sdk_publish(self):
|
||||
def test_sdk_publish(self) -> None:
|
||||
"""测试发布 SDK"""
|
||||
try:
|
||||
if self.created_ids["sdk"]:
|
||||
sdk = self.manager.publish_sdk_release(self.created_ids["sdk"][0])
|
||||
sdk = self.manager.publish_sdk_release(self.created_ids["sdk"][0])
|
||||
if sdk:
|
||||
self.log(f"Published SDK: {sdk.name} (status: {sdk.status.value})")
|
||||
except Exception as e:
|
||||
self.log(f"Failed to publish SDK: {str(e)}", success=False)
|
||||
self.log(f"Failed to publish SDK: {str(e)}", success = False)
|
||||
|
||||
def test_sdk_version_add(self):
|
||||
def test_sdk_version_add(self) -> None:
|
||||
"""测试添加 SDK 版本"""
|
||||
try:
|
||||
if self.created_ids["sdk"]:
|
||||
version = self.manager.add_sdk_version(
|
||||
sdk_id=self.created_ids["sdk"][0],
|
||||
version="1.1.0",
|
||||
is_lts=True,
|
||||
release_notes="Bug fixes and improvements",
|
||||
download_url="https://pypi.org/insightflow/1.1.0",
|
||||
checksum="xyz789",
|
||||
file_size=1100000,
|
||||
version = self.manager.add_sdk_version(
|
||||
sdk_id = self.created_ids["sdk"][0],
|
||||
version = "1.1.0",
|
||||
is_lts = True,
|
||||
release_notes = "Bug fixes and improvements",
|
||||
download_url = "https://pypi.org/insightflow/1.1.0",
|
||||
checksum = "xyz789",
|
||||
file_size = 1100000,
|
||||
)
|
||||
self.log(f"Added SDK version: {version.version}")
|
||||
except Exception as e:
|
||||
self.log(f"Failed to add SDK version: {str(e)}", success=False)
|
||||
self.log(f"Failed to add SDK version: {str(e)}", success = False)
|
||||
|
||||
def test_template_create(self):
|
||||
def test_template_create(self) -> None:
|
||||
"""测试创建模板"""
|
||||
try:
|
||||
template = self.manager.create_template(
|
||||
name="医疗行业实体识别模板",
|
||||
description="专门针对医疗行业的实体识别模板,支持疾病、药物、症状等实体",
|
||||
category=TemplateCategory.MEDICAL,
|
||||
subcategory="entity_recognition",
|
||||
tags=["medical", "healthcare", "ner"],
|
||||
author_id="dev_001",
|
||||
author_name="Medical AI Lab",
|
||||
price=99.0,
|
||||
currency="CNY",
|
||||
preview_image_url="https://cdn.insightflow.io/templates/medical.png",
|
||||
demo_url="https://demo.insightflow.io/medical",
|
||||
documentation_url="https://docs.insightflow.io/templates/medical",
|
||||
download_url="https://cdn.insightflow.io/templates/medical.zip",
|
||||
version="1.0.0",
|
||||
min_platform_version="2.0.0",
|
||||
file_size=5242880,
|
||||
checksum="tpl123",
|
||||
template = self.manager.create_template(
|
||||
name = "医疗行业实体识别模板",
|
||||
description = "专门针对医疗行业的实体识别模板,支持疾病、药物、症状等实体",
|
||||
category = TemplateCategory.MEDICAL,
|
||||
subcategory = "entity_recognition",
|
||||
tags = ["medical", "healthcare", "ner"],
|
||||
author_id = "dev_001",
|
||||
author_name = "Medical AI Lab",
|
||||
price = 99.0,
|
||||
currency = "CNY",
|
||||
preview_image_url = "https://cdn.insightflow.io/templates/medical.png",
|
||||
demo_url = "https://demo.insightflow.io/medical",
|
||||
documentation_url = "https://docs.insightflow.io/templates/medical",
|
||||
download_url = "https://cdn.insightflow.io/templates/medical.zip",
|
||||
version = "1.0.0",
|
||||
min_platform_version = "2.0.0",
|
||||
file_size = 5242880,
|
||||
checksum = "tpl123",
|
||||
)
|
||||
self.created_ids["template"].append(template.id)
|
||||
self.log(f"Created template: {template.name} ({template.id})")
|
||||
|
||||
# Create free template
|
||||
template_free = self.manager.create_template(
|
||||
name="通用实体识别模板",
|
||||
description="适用于一般场景的实体识别模板",
|
||||
category=TemplateCategory.GENERAL,
|
||||
subcategory=None,
|
||||
tags=["general", "ner", "basic"],
|
||||
author_id="dev_002",
|
||||
author_name="InsightFlow Team",
|
||||
price=0.0,
|
||||
currency="CNY",
|
||||
template_free = self.manager.create_template(
|
||||
name = "通用实体识别模板",
|
||||
description = "适用于一般场景的实体识别模板",
|
||||
category = TemplateCategory.GENERAL,
|
||||
subcategory = None,
|
||||
tags = ["general", "ner", "basic"],
|
||||
author_id = "dev_002",
|
||||
author_name = "InsightFlow Team",
|
||||
price = 0.0,
|
||||
currency = "CNY",
|
||||
)
|
||||
self.created_ids["template"].append(template_free.id)
|
||||
self.log(f"Created free template: {template_free.name}")
|
||||
|
||||
except Exception as e:
|
||||
self.log(f"Failed to create template: {str(e)}", success=False)
|
||||
self.log(f"Failed to create template: {str(e)}", success = False)
|
||||
|
||||
def test_template_list(self):
|
||||
def test_template_list(self) -> None:
|
||||
"""测试列出模板"""
|
||||
try:
|
||||
templates = self.manager.list_templates()
|
||||
templates = self.manager.list_templates()
|
||||
self.log(f"Listed {len(templates)} templates")
|
||||
|
||||
# Filter by category
|
||||
medical_templates = self.manager.list_templates(category=TemplateCategory.MEDICAL)
|
||||
medical_templates = self.manager.list_templates(category = TemplateCategory.MEDICAL)
|
||||
self.log(f"Found {len(medical_templates)} medical templates")
|
||||
|
||||
# Filter by price
|
||||
free_templates = self.manager.list_templates(max_price=0)
|
||||
free_templates = self.manager.list_templates(max_price = 0)
|
||||
self.log(f"Found {len(free_templates)} free templates")
|
||||
|
||||
except Exception as e:
|
||||
self.log(f"Failed to list templates: {str(e)}", success=False)
|
||||
self.log(f"Failed to list templates: {str(e)}", success = False)
|
||||
|
||||
def test_template_get(self):
|
||||
def test_template_get(self) -> None:
|
||||
"""测试获取模板详情"""
|
||||
try:
|
||||
if self.created_ids["template"]:
|
||||
template = self.manager.get_template(self.created_ids["template"][0])
|
||||
template = self.manager.get_template(self.created_ids["template"][0])
|
||||
if template:
|
||||
self.log(f"Retrieved template: {template.name}")
|
||||
except Exception as e:
|
||||
self.log(f"Failed to get template: {str(e)}", success=False)
|
||||
self.log(f"Failed to get template: {str(e)}", success = False)
|
||||
|
||||
def test_template_approve(self):
|
||||
def test_template_approve(self) -> None:
|
||||
"""测试审核通过模板"""
|
||||
try:
|
||||
if self.created_ids["template"]:
|
||||
template = self.manager.approve_template(
|
||||
self.created_ids["template"][0], reviewed_by="admin_001"
|
||||
template = self.manager.approve_template(
|
||||
self.created_ids["template"][0], reviewed_by = "admin_001"
|
||||
)
|
||||
if template:
|
||||
self.log(f"Approved template: {template.name}")
|
||||
except Exception as e:
|
||||
self.log(f"Failed to approve template: {str(e)}", success=False)
|
||||
self.log(f"Failed to approve template: {str(e)}", success = False)
|
||||
|
||||
def test_template_publish(self):
|
||||
def test_template_publish(self) -> None:
|
||||
"""测试发布模板"""
|
||||
try:
|
||||
if self.created_ids["template"]:
|
||||
template = self.manager.publish_template(self.created_ids["template"][0])
|
||||
template = self.manager.publish_template(self.created_ids["template"][0])
|
||||
if template:
|
||||
self.log(f"Published template: {template.name}")
|
||||
except Exception as e:
|
||||
self.log(f"Failed to publish template: {str(e)}", success=False)
|
||||
self.log(f"Failed to publish template: {str(e)}", success = False)
|
||||
|
||||
def test_template_review(self):
|
||||
def test_template_review(self) -> None:
|
||||
"""测试添加模板评价"""
|
||||
try:
|
||||
if self.created_ids["template"]:
|
||||
review = self.manager.add_template_review(
|
||||
template_id=self.created_ids["template"][0],
|
||||
user_id="user_001",
|
||||
user_name="Test User",
|
||||
rating=5,
|
||||
comment="Great template! Very accurate for medical entities.",
|
||||
is_verified_purchase=True,
|
||||
review = self.manager.add_template_review(
|
||||
template_id = self.created_ids["template"][0],
|
||||
user_id = "user_001",
|
||||
user_name = "Test User",
|
||||
rating = 5,
|
||||
comment = "Great template! Very accurate for medical entities.",
|
||||
is_verified_purchase = True,
|
||||
)
|
||||
self.log(f"Added template review: {review.rating} stars")
|
||||
except Exception as e:
|
||||
self.log(f"Failed to add template review: {str(e)}", success=False)
|
||||
self.log(f"Failed to add template review: {str(e)}", success = False)
|
||||
|
||||
def test_plugin_create(self):
|
||||
def test_plugin_create(self) -> None:
|
||||
"""测试创建插件"""
|
||||
try:
|
||||
plugin = self.manager.create_plugin(
|
||||
name="飞书机器人集成插件",
|
||||
description="将 InsightFlow 与飞书机器人集成,实现自动通知",
|
||||
category=PluginCategory.INTEGRATION,
|
||||
tags=["feishu", "bot", "integration", "notification"],
|
||||
author_id="dev_003",
|
||||
author_name="Integration Team",
|
||||
price=49.0,
|
||||
currency="CNY",
|
||||
pricing_model="paid",
|
||||
preview_image_url="https://cdn.insightflow.io/plugins/feishu.png",
|
||||
demo_url="https://demo.insightflow.io/feishu",
|
||||
documentation_url="https://docs.insightflow.io/plugins/feishu",
|
||||
repository_url="https://github.com/insightflow/feishu-plugin",
|
||||
download_url="https://cdn.insightflow.io/plugins/feishu.zip",
|
||||
webhook_url="https://api.insightflow.io/webhooks/feishu",
|
||||
permissions=["read:projects", "write:notifications"],
|
||||
version="1.0.0",
|
||||
min_platform_version="2.0.0",
|
||||
file_size=1048576,
|
||||
checksum="plg123",
|
||||
plugin = self.manager.create_plugin(
|
||||
name = "飞书机器人集成插件",
|
||||
description = "将 InsightFlow 与飞书机器人集成,实现自动通知",
|
||||
category = PluginCategory.INTEGRATION,
|
||||
tags = ["feishu", "bot", "integration", "notification"],
|
||||
author_id = "dev_003",
|
||||
author_name = "Integration Team",
|
||||
price = 49.0,
|
||||
currency = "CNY",
|
||||
pricing_model = "paid",
|
||||
preview_image_url = "https://cdn.insightflow.io/plugins/feishu.png",
|
||||
demo_url = "https://demo.insightflow.io/feishu",
|
||||
documentation_url = "https://docs.insightflow.io/plugins/feishu",
|
||||
repository_url = "https://github.com/insightflow/feishu-plugin",
|
||||
download_url = "https://cdn.insightflow.io/plugins/feishu.zip",
|
||||
webhook_url = "https://api.insightflow.io/webhooks/feishu",
|
||||
permissions = ["read:projects", "write:notifications"],
|
||||
version = "1.0.0",
|
||||
min_platform_version = "2.0.0",
|
||||
file_size = 1048576,
|
||||
checksum = "plg123",
|
||||
)
|
||||
self.created_ids["plugin"].append(plugin.id)
|
||||
self.log(f"Created plugin: {plugin.name} ({plugin.id})")
|
||||
|
||||
# Create free plugin
|
||||
plugin_free = self.manager.create_plugin(
|
||||
name="数据导出插件",
|
||||
description="支持多种格式的数据导出",
|
||||
category=PluginCategory.ANALYSIS,
|
||||
tags=["export", "data", "csv", "json"],
|
||||
author_id="dev_004",
|
||||
author_name="Data Team",
|
||||
price=0.0,
|
||||
currency="CNY",
|
||||
pricing_model="free",
|
||||
plugin_free = self.manager.create_plugin(
|
||||
name = "数据导出插件",
|
||||
description = "支持多种格式的数据导出",
|
||||
category = PluginCategory.ANALYSIS,
|
||||
tags = ["export", "data", "csv", "json"],
|
||||
author_id = "dev_004",
|
||||
author_name = "Data Team",
|
||||
price = 0.0,
|
||||
currency = "CNY",
|
||||
pricing_model = "free",
|
||||
)
|
||||
self.created_ids["plugin"].append(plugin_free.id)
|
||||
self.log(f"Created free plugin: {plugin_free.name}")
|
||||
|
||||
except Exception as e:
|
||||
self.log(f"Failed to create plugin: {str(e)}", success=False)
|
||||
self.log(f"Failed to create plugin: {str(e)}", success = False)
|
||||
|
||||
def test_plugin_list(self):
|
||||
def test_plugin_list(self) -> None:
|
||||
"""测试列出插件"""
|
||||
try:
|
||||
plugins = self.manager.list_plugins()
|
||||
plugins = self.manager.list_plugins()
|
||||
self.log(f"Listed {len(plugins)} plugins")
|
||||
|
||||
# Filter by category
|
||||
integration_plugins = self.manager.list_plugins(category=PluginCategory.INTEGRATION)
|
||||
integration_plugins = self.manager.list_plugins(category = PluginCategory.INTEGRATION)
|
||||
self.log(f"Found {len(integration_plugins)} integration plugins")
|
||||
|
||||
except Exception as e:
|
||||
self.log(f"Failed to list plugins: {str(e)}", success=False)
|
||||
self.log(f"Failed to list plugins: {str(e)}", success = False)
|
||||
|
||||
def test_plugin_get(self):
|
||||
def test_plugin_get(self) -> None:
|
||||
"""测试获取插件详情"""
|
||||
try:
|
||||
if self.created_ids["plugin"]:
|
||||
plugin = self.manager.get_plugin(self.created_ids["plugin"][0])
|
||||
plugin = self.manager.get_plugin(self.created_ids["plugin"][0])
|
||||
if plugin:
|
||||
self.log(f"Retrieved plugin: {plugin.name}")
|
||||
except Exception as e:
|
||||
self.log(f"Failed to get plugin: {str(e)}", success=False)
|
||||
self.log(f"Failed to get plugin: {str(e)}", success = False)
|
||||
|
||||
def test_plugin_review(self):
|
||||
def test_plugin_review(self) -> None:
|
||||
"""测试审核插件"""
|
||||
try:
|
||||
if self.created_ids["plugin"]:
|
||||
plugin = self.manager.review_plugin(
|
||||
plugin = self.manager.review_plugin(
|
||||
self.created_ids["plugin"][0],
|
||||
reviewed_by="admin_001",
|
||||
status=PluginStatus.APPROVED,
|
||||
notes="Code review passed",
|
||||
reviewed_by = "admin_001",
|
||||
status = PluginStatus.APPROVED,
|
||||
notes = "Code review passed",
|
||||
)
|
||||
if plugin:
|
||||
self.log(f"Reviewed plugin: {plugin.name} ({plugin.status.value})")
|
||||
except Exception as e:
|
||||
self.log(f"Failed to review plugin: {str(e)}", success=False)
|
||||
self.log(f"Failed to review plugin: {str(e)}", success = False)
|
||||
|
||||
def test_plugin_publish(self):
|
||||
def test_plugin_publish(self) -> None:
|
||||
"""测试发布插件"""
|
||||
try:
|
||||
if self.created_ids["plugin"]:
|
||||
plugin = self.manager.publish_plugin(self.created_ids["plugin"][0])
|
||||
plugin = self.manager.publish_plugin(self.created_ids["plugin"][0])
|
||||
if plugin:
|
||||
self.log(f"Published plugin: {plugin.name}")
|
||||
except Exception as e:
|
||||
self.log(f"Failed to publish plugin: {str(e)}", success=False)
|
||||
self.log(f"Failed to publish plugin: {str(e)}", success = False)
|
||||
|
||||
def test_plugin_review_add(self):
|
||||
def test_plugin_review_add(self) -> None:
|
||||
"""测试添加插件评价"""
|
||||
try:
|
||||
if self.created_ids["plugin"]:
|
||||
review = self.manager.add_plugin_review(
|
||||
plugin_id=self.created_ids["plugin"][0],
|
||||
user_id="user_002",
|
||||
user_name="Plugin User",
|
||||
rating=4,
|
||||
comment="Works great with Feishu!",
|
||||
is_verified_purchase=True,
|
||||
review = self.manager.add_plugin_review(
|
||||
plugin_id = self.created_ids["plugin"][0],
|
||||
user_id = "user_002",
|
||||
user_name = "Plugin User",
|
||||
rating = 4,
|
||||
comment = "Works great with Feishu!",
|
||||
is_verified_purchase = True,
|
||||
)
|
||||
self.log(f"Added plugin review: {review.rating} stars")
|
||||
except Exception as e:
|
||||
self.log(f"Failed to add plugin review: {str(e)}", success=False)
|
||||
self.log(f"Failed to add plugin review: {str(e)}", success = False)
|
||||
|
||||
def test_developer_profile_create(self):
|
||||
def test_developer_profile_create(self) -> None:
|
||||
"""测试创建开发者档案"""
|
||||
try:
|
||||
# Generate unique user IDs
|
||||
unique_id = uuid.uuid4().hex[:8]
|
||||
unique_id = uuid.uuid4().hex[:8]
|
||||
|
||||
profile = self.manager.create_developer_profile(
|
||||
user_id=f"user_dev_{unique_id}_001",
|
||||
display_name="张三",
|
||||
email=f"zhangsan_{unique_id}@example.com",
|
||||
bio="专注于医疗AI和自然语言处理",
|
||||
website="https://zhangsan.dev",
|
||||
github_url="https://github.com/zhangsan",
|
||||
avatar_url="https://cdn.example.com/avatars/zhangsan.png",
|
||||
profile = self.manager.create_developer_profile(
|
||||
user_id = f"user_dev_{unique_id}_001",
|
||||
display_name = "张三",
|
||||
email = f"zhangsan_{unique_id}@example.com",
|
||||
bio = "专注于医疗AI和自然语言处理",
|
||||
website = "https://zhangsan.dev",
|
||||
github_url = "https://github.com/zhangsan",
|
||||
avatar_url = "https://cdn.example.com/avatars/zhangsan.png",
|
||||
)
|
||||
self.created_ids["developer"].append(profile.id)
|
||||
self.log(f"Created developer profile: {profile.display_name} ({profile.id})")
|
||||
|
||||
# Create another developer
|
||||
profile2 = self.manager.create_developer_profile(
|
||||
user_id=f"user_dev_{unique_id}_002",
|
||||
display_name="李四",
|
||||
email=f"lisi_{unique_id}@example.com",
|
||||
bio="全栈开发者,热爱开源",
|
||||
profile2 = self.manager.create_developer_profile(
|
||||
user_id = f"user_dev_{unique_id}_002",
|
||||
display_name = "李四",
|
||||
email = f"lisi_{unique_id}@example.com",
|
||||
bio = "全栈开发者,热爱开源",
|
||||
)
|
||||
self.created_ids["developer"].append(profile2.id)
|
||||
self.log(f"Created developer profile: {profile2.display_name}")
|
||||
|
||||
except Exception as e:
|
||||
self.log(f"Failed to create developer profile: {str(e)}", success=False)
|
||||
self.log(f"Failed to create developer profile: {str(e)}", success = False)
|
||||
|
||||
def test_developer_profile_get(self):
|
||||
def test_developer_profile_get(self) -> None:
|
||||
"""测试获取开发者档案"""
|
||||
try:
|
||||
if self.created_ids["developer"]:
|
||||
profile = self.manager.get_developer_profile(self.created_ids["developer"][0])
|
||||
profile = self.manager.get_developer_profile(self.created_ids["developer"][0])
|
||||
if profile:
|
||||
self.log(f"Retrieved developer profile: {profile.display_name}")
|
||||
except Exception as e:
|
||||
self.log(f"Failed to get developer profile: {str(e)}", success=False)
|
||||
self.log(f"Failed to get developer profile: {str(e)}", success = False)
|
||||
|
||||
def test_developer_verify(self):
|
||||
def test_developer_verify(self) -> None:
|
||||
"""测试验证开发者"""
|
||||
try:
|
||||
if self.created_ids["developer"]:
|
||||
profile = self.manager.verify_developer(
|
||||
profile = self.manager.verify_developer(
|
||||
self.created_ids["developer"][0], DeveloperStatus.VERIFIED
|
||||
)
|
||||
if profile:
|
||||
self.log(f"Verified developer: {profile.display_name} ({profile.status.value})")
|
||||
except Exception as e:
|
||||
self.log(f"Failed to verify developer: {str(e)}", success=False)
|
||||
self.log(f"Failed to verify developer: {str(e)}", success = False)
|
||||
|
||||
def test_developer_stats_update(self):
|
||||
def test_developer_stats_update(self) -> None:
|
||||
"""测试更新开发者统计"""
|
||||
try:
|
||||
if self.created_ids["developer"]:
|
||||
self.manager.update_developer_stats(self.created_ids["developer"][0])
|
||||
profile = self.manager.get_developer_profile(self.created_ids["developer"][0])
|
||||
profile = self.manager.get_developer_profile(self.created_ids["developer"][0])
|
||||
self.log(
|
||||
f"Updated developer stats: {profile.plugin_count} plugins, {profile.template_count} templates"
|
||||
)
|
||||
except Exception as e:
|
||||
self.log(f"Failed to update developer stats: {str(e)}", success=False)
|
||||
self.log(f"Failed to update developer stats: {str(e)}", success = False)
|
||||
|
||||
def test_code_example_create(self):
|
||||
def test_code_example_create(self) -> None:
|
||||
"""测试创建代码示例"""
|
||||
try:
|
||||
example = self.manager.create_code_example(
|
||||
title="使用 Python SDK 创建项目",
|
||||
description="演示如何使用 Python SDK 创建新项目",
|
||||
language="python",
|
||||
category="quickstart",
|
||||
code="""from insightflow import Client
|
||||
example = self.manager.create_code_example(
|
||||
title = "使用 Python SDK 创建项目",
|
||||
description = "演示如何使用 Python SDK 创建新项目",
|
||||
language = "python",
|
||||
category = "quickstart",
|
||||
code = """from insightflow import Client
|
||||
|
||||
client = Client(api_key="your_api_key")
|
||||
project = client.projects.create(name="My Project")
|
||||
client = Client(api_key = "your_api_key")
|
||||
project = client.projects.create(name = "My Project")
|
||||
print(f"Created project: {project.id}")
|
||||
""",
|
||||
explanation="首先导入 Client 类,然后使用 API Key 初始化客户端,最后调用 create 方法创建项目。",
|
||||
tags=["python", "quickstart", "projects"],
|
||||
author_id="dev_001",
|
||||
author_name="InsightFlow Team",
|
||||
api_endpoints=["/api/v1/projects"],
|
||||
explanation = "首先导入 Client 类,然后使用 API Key 初始化客户端,最后调用 create 方法创建项目。",
|
||||
tags = ["python", "quickstart", "projects"],
|
||||
author_id = "dev_001",
|
||||
author_name = "InsightFlow Team",
|
||||
api_endpoints = ["/api/v1/projects"],
|
||||
)
|
||||
self.created_ids["code_example"].append(example.id)
|
||||
self.log(f"Created code example: {example.title}")
|
||||
|
||||
# Create JavaScript example
|
||||
example_js = self.manager.create_code_example(
|
||||
title="使用 JavaScript SDK 上传文件",
|
||||
description="演示如何使用 JavaScript SDK 上传音频文件",
|
||||
language="javascript",
|
||||
category="upload",
|
||||
code="""const { Client } = require('insightflow');
|
||||
example_js = self.manager.create_code_example(
|
||||
title = "使用 JavaScript SDK 上传文件",
|
||||
description = "演示如何使用 JavaScript SDK 上传音频文件",
|
||||
language = "javascript",
|
||||
category = "upload",
|
||||
code = """const { Client } = require('insightflow');
|
||||
|
||||
const client = new Client({ apiKey: 'your_api_key' });
|
||||
const result = await client.uploads.create({
|
||||
const client = new Client({ apiKey: 'your_api_key' });
|
||||
const result = await client.uploads.create({
|
||||
projectId: 'proj_123',
|
||||
file: './meeting.mp3'
|
||||
});
|
||||
console.log('Upload complete:', result.id);
|
||||
""",
|
||||
explanation="使用 JavaScript SDK 上传文件到 InsightFlow",
|
||||
tags=["javascript", "upload", "audio"],
|
||||
author_id="dev_002",
|
||||
author_name="JS Team",
|
||||
explanation = "使用 JavaScript SDK 上传文件到 InsightFlow",
|
||||
tags = ["javascript", "upload", "audio"],
|
||||
author_id = "dev_002",
|
||||
author_name = "JS Team",
|
||||
)
|
||||
self.created_ids["code_example"].append(example_js.id)
|
||||
self.log(f"Created code example: {example_js.title}")
|
||||
|
||||
except Exception as e:
|
||||
self.log(f"Failed to create code example: {str(e)}", success=False)
|
||||
self.log(f"Failed to create code example: {str(e)}", success = False)
|
||||
|
||||
def test_code_example_list(self):
|
||||
def test_code_example_list(self) -> None:
|
||||
"""测试列出代码示例"""
|
||||
try:
|
||||
examples = self.manager.list_code_examples()
|
||||
examples = self.manager.list_code_examples()
|
||||
self.log(f"Listed {len(examples)} code examples")
|
||||
|
||||
# Filter by language
|
||||
python_examples = self.manager.list_code_examples(language="python")
|
||||
python_examples = self.manager.list_code_examples(language = "python")
|
||||
self.log(f"Found {len(python_examples)} Python examples")
|
||||
|
||||
except Exception as e:
|
||||
self.log(f"Failed to list code examples: {str(e)}", success=False)
|
||||
self.log(f"Failed to list code examples: {str(e)}", success = False)
|
||||
|
||||
def test_code_example_get(self):
|
||||
def test_code_example_get(self) -> None:
|
||||
"""测试获取代码示例详情"""
|
||||
try:
|
||||
if self.created_ids["code_example"]:
|
||||
example = self.manager.get_code_example(self.created_ids["code_example"][0])
|
||||
example = self.manager.get_code_example(self.created_ids["code_example"][0])
|
||||
if example:
|
||||
self.log(
|
||||
f"Retrieved code example: {example.title} (views: {example.view_count})"
|
||||
)
|
||||
except Exception as e:
|
||||
self.log(f"Failed to get code example: {str(e)}", success=False)
|
||||
self.log(f"Failed to get code example: {str(e)}", success = False)
|
||||
|
||||
def test_portal_config_create(self):
|
||||
def test_portal_config_create(self) -> None:
|
||||
"""测试创建开发者门户配置"""
|
||||
try:
|
||||
config = self.manager.create_portal_config(
|
||||
name="InsightFlow Developer Portal",
|
||||
description="开发者门户 - SDK、API 文档和示例代码",
|
||||
theme="default",
|
||||
primary_color="#1890ff",
|
||||
secondary_color="#52c41a",
|
||||
support_email="developers@insightflow.io",
|
||||
support_url="https://support.insightflow.io",
|
||||
github_url="https://github.com/insightflow",
|
||||
discord_url="https://discord.gg/insightflow",
|
||||
api_base_url="https://api.insightflow.io/v1",
|
||||
config = self.manager.create_portal_config(
|
||||
name = "InsightFlow Developer Portal",
|
||||
description = "开发者门户 - SDK、API 文档和示例代码",
|
||||
theme = "default",
|
||||
primary_color = "#1890ff",
|
||||
secondary_color = "#52c41a",
|
||||
support_email = "developers@insightflow.io",
|
||||
support_url = "https://support.insightflow.io",
|
||||
github_url = "https://github.com/insightflow",
|
||||
discord_url = "https://discord.gg/insightflow",
|
||||
api_base_url = "https://api.insightflow.io/v1",
|
||||
)
|
||||
self.created_ids["portal_config"].append(config.id)
|
||||
self.log(f"Created portal config: {config.name}")
|
||||
|
||||
except Exception as e:
|
||||
self.log(f"Failed to create portal config: {str(e)}", success=False)
|
||||
self.log(f"Failed to create portal config: {str(e)}", success = False)
|
||||
|
||||
def test_portal_config_get(self):
|
||||
def test_portal_config_get(self) -> None:
|
||||
"""测试获取开发者门户配置"""
|
||||
try:
|
||||
if self.created_ids["portal_config"]:
|
||||
config = self.manager.get_portal_config(self.created_ids["portal_config"][0])
|
||||
config = self.manager.get_portal_config(self.created_ids["portal_config"][0])
|
||||
if config:
|
||||
self.log(f"Retrieved portal config: {config.name}")
|
||||
|
||||
# Test active config
|
||||
active_config = self.manager.get_active_portal_config()
|
||||
active_config = self.manager.get_active_portal_config()
|
||||
if active_config:
|
||||
self.log(f"Active portal config: {active_config.name}")
|
||||
|
||||
except Exception as e:
|
||||
self.log(f"Failed to get portal config: {str(e)}", success=False)
|
||||
self.log(f"Failed to get portal config: {str(e)}", success = False)
|
||||
|
||||
def test_revenue_record(self):
|
||||
def test_revenue_record(self) -> None:
|
||||
"""测试记录开发者收益"""
|
||||
try:
|
||||
if self.created_ids["developer"] and self.created_ids["plugin"]:
|
||||
revenue = self.manager.record_revenue(
|
||||
developer_id=self.created_ids["developer"][0],
|
||||
item_type="plugin",
|
||||
item_id=self.created_ids["plugin"][0],
|
||||
item_name="飞书机器人集成插件",
|
||||
sale_amount=49.0,
|
||||
currency="CNY",
|
||||
buyer_id="user_buyer_001",
|
||||
transaction_id="txn_123456",
|
||||
revenue = self.manager.record_revenue(
|
||||
developer_id = self.created_ids["developer"][0],
|
||||
item_type = "plugin",
|
||||
item_id = self.created_ids["plugin"][0],
|
||||
item_name = "飞书机器人集成插件",
|
||||
sale_amount = 49.0,
|
||||
currency = "CNY",
|
||||
buyer_id = "user_buyer_001",
|
||||
transaction_id = "txn_123456",
|
||||
)
|
||||
self.log(f"Recorded revenue: {revenue.sale_amount} {revenue.currency}")
|
||||
self.log(f" - Platform fee: {revenue.platform_fee}")
|
||||
self.log(f" - Developer earnings: {revenue.developer_earnings}")
|
||||
except Exception as e:
|
||||
self.log(f"Failed to record revenue: {str(e)}", success=False)
|
||||
self.log(f"Failed to record revenue: {str(e)}", success = False)
|
||||
|
||||
def test_revenue_summary(self):
|
||||
def test_revenue_summary(self) -> None:
|
||||
"""测试获取开发者收益汇总"""
|
||||
try:
|
||||
if self.created_ids["developer"]:
|
||||
summary = self.manager.get_developer_revenue_summary(
|
||||
summary = self.manager.get_developer_revenue_summary(
|
||||
self.created_ids["developer"][0]
|
||||
)
|
||||
self.log("Revenue summary for developer:")
|
||||
@@ -659,17 +659,17 @@ console.log('Upload complete:', result.id);
|
||||
self.log(f" - Total earnings: {summary['total_earnings']}")
|
||||
self.log(f" - Transaction count: {summary['transaction_count']}")
|
||||
except Exception as e:
|
||||
self.log(f"Failed to get revenue summary: {str(e)}", success=False)
|
||||
self.log(f"Failed to get revenue summary: {str(e)}", success = False)
|
||||
|
||||
def print_summary(self):
|
||||
def print_summary(self) -> None:
|
||||
"""打印测试摘要"""
|
||||
print("\n" + "=" * 60)
|
||||
print("\n" + " = " * 60)
|
||||
print("Test Summary")
|
||||
print("=" * 60)
|
||||
print(" = " * 60)
|
||||
|
||||
total = len(self.test_results)
|
||||
passed = sum(1 for r in self.test_results if r["success"])
|
||||
failed = total - passed
|
||||
total = len(self.test_results)
|
||||
passed = sum(1 for r in self.test_results if r["success"])
|
||||
failed = total - passed
|
||||
|
||||
print(f"Total tests: {total}")
|
||||
print(f"Passed: {passed} ✅")
|
||||
@@ -686,12 +686,12 @@ console.log('Upload complete:', result.id);
|
||||
if ids:
|
||||
print(f" {resource_type}: {len(ids)}")
|
||||
|
||||
print("=" * 60)
|
||||
print(" = " * 60)
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> None:
|
||||
"""主函数"""
|
||||
test = TestDeveloperEcosystem()
|
||||
test = TestDeveloperEcosystem()
|
||||
test.run_all_tests()
|
||||
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ from ops_manager import (
|
||||
)
|
||||
|
||||
# Add backend directory to path
|
||||
backend_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
backend_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
if backend_dir not in sys.path:
|
||||
sys.path.insert(0, backend_dir)
|
||||
|
||||
@@ -34,22 +34,22 @@ if backend_dir not in sys.path:
|
||||
class TestOpsManager:
|
||||
"""测试运维与监控管理器"""
|
||||
|
||||
def __init__(self):
|
||||
self.manager = get_ops_manager()
|
||||
self.tenant_id = "test_tenant_001"
|
||||
self.test_results = []
|
||||
def __init__(self) -> None:
|
||||
self.manager = get_ops_manager()
|
||||
self.tenant_id = "test_tenant_001"
|
||||
self.test_results = []
|
||||
|
||||
def log(self, message: str, success: bool = True):
|
||||
def log(self, message: str, success: bool = True) -> None:
|
||||
"""记录测试结果"""
|
||||
status = "✅" if success else "❌"
|
||||
status = "✅" if success else "❌"
|
||||
print(f"{status} {message}")
|
||||
self.test_results.append((message, success))
|
||||
|
||||
def run_all_tests(self):
|
||||
def run_all_tests(self) -> None:
|
||||
"""运行所有测试"""
|
||||
print("=" * 60)
|
||||
print(" = " * 60)
|
||||
print("InsightFlow Phase 8 Task 8: Operations & Monitoring Tests")
|
||||
print("=" * 60)
|
||||
print(" = " * 60)
|
||||
|
||||
# 1. 告警系统测试
|
||||
self.test_alert_rules()
|
||||
@@ -73,63 +73,63 @@ class TestOpsManager:
|
||||
# 打印测试总结
|
||||
self.print_summary()
|
||||
|
||||
def test_alert_rules(self):
|
||||
def test_alert_rules(self) -> None:
|
||||
"""测试告警规则管理"""
|
||||
print("\n📋 Testing Alert Rules...")
|
||||
|
||||
try:
|
||||
# 创建阈值告警规则
|
||||
rule1 = self.manager.create_alert_rule(
|
||||
tenant_id=self.tenant_id,
|
||||
name="CPU 使用率告警",
|
||||
description="当 CPU 使用率超过 80% 时触发告警",
|
||||
rule_type=AlertRuleType.THRESHOLD,
|
||||
severity=AlertSeverity.P1,
|
||||
metric="cpu_usage_percent",
|
||||
condition=">",
|
||||
threshold=80.0,
|
||||
duration=300,
|
||||
evaluation_interval=60,
|
||||
channels=[],
|
||||
labels={"service": "api", "team": "platform"},
|
||||
annotations={"summary": "CPU 使用率过高", "runbook": "https://wiki/runbooks/cpu"},
|
||||
created_by="test_user",
|
||||
rule1 = self.manager.create_alert_rule(
|
||||
tenant_id = self.tenant_id,
|
||||
name = "CPU 使用率告警",
|
||||
description = "当 CPU 使用率超过 80% 时触发告警",
|
||||
rule_type = AlertRuleType.THRESHOLD,
|
||||
severity = AlertSeverity.P1,
|
||||
metric = "cpu_usage_percent",
|
||||
condition = ">",
|
||||
threshold = 80.0,
|
||||
duration = 300,
|
||||
evaluation_interval = 60,
|
||||
channels = [],
|
||||
labels = {"service": "api", "team": "platform"},
|
||||
annotations = {"summary": "CPU 使用率过高", "runbook": "https://wiki/runbooks/cpu"},
|
||||
created_by = "test_user",
|
||||
)
|
||||
self.log(f"Created alert rule: {rule1.name} (ID: {rule1.id})")
|
||||
|
||||
# 创建异常检测告警规则
|
||||
rule2 = self.manager.create_alert_rule(
|
||||
tenant_id=self.tenant_id,
|
||||
name="内存异常检测",
|
||||
description="检测内存使用异常",
|
||||
rule_type=AlertRuleType.ANOMALY,
|
||||
severity=AlertSeverity.P2,
|
||||
metric="memory_usage_percent",
|
||||
condition=">",
|
||||
threshold=0.0,
|
||||
duration=600,
|
||||
evaluation_interval=300,
|
||||
channels=[],
|
||||
labels={"service": "database"},
|
||||
annotations={},
|
||||
created_by="test_user",
|
||||
rule2 = self.manager.create_alert_rule(
|
||||
tenant_id = self.tenant_id,
|
||||
name = "内存异常检测",
|
||||
description = "检测内存使用异常",
|
||||
rule_type = AlertRuleType.ANOMALY,
|
||||
severity = AlertSeverity.P2,
|
||||
metric = "memory_usage_percent",
|
||||
condition = ">",
|
||||
threshold = 0.0,
|
||||
duration = 600,
|
||||
evaluation_interval = 300,
|
||||
channels = [],
|
||||
labels = {"service": "database"},
|
||||
annotations = {},
|
||||
created_by = "test_user",
|
||||
)
|
||||
self.log(f"Created anomaly alert rule: {rule2.name} (ID: {rule2.id})")
|
||||
|
||||
# 获取告警规则
|
||||
fetched_rule = self.manager.get_alert_rule(rule1.id)
|
||||
fetched_rule = self.manager.get_alert_rule(rule1.id)
|
||||
assert fetched_rule is not None
|
||||
assert fetched_rule.name == rule1.name
|
||||
self.log(f"Fetched alert rule: {fetched_rule.name}")
|
||||
|
||||
# 列出租户的所有告警规则
|
||||
rules = self.manager.list_alert_rules(self.tenant_id)
|
||||
rules = self.manager.list_alert_rules(self.tenant_id)
|
||||
assert len(rules) >= 2
|
||||
self.log(f"Listed {len(rules)} alert rules for tenant")
|
||||
|
||||
# 更新告警规则
|
||||
updated_rule = self.manager.update_alert_rule(
|
||||
rule1.id, threshold=85.0, description="更新后的描述"
|
||||
updated_rule = self.manager.update_alert_rule(
|
||||
rule1.id, threshold = 85.0, description = "更新后的描述"
|
||||
)
|
||||
assert updated_rule.threshold == 85.0
|
||||
self.log(f"Updated alert rule threshold to {updated_rule.threshold}")
|
||||
@@ -140,57 +140,57 @@ class TestOpsManager:
|
||||
self.log("Deleted test alert rules")
|
||||
|
||||
except Exception as e:
|
||||
self.log(f"Alert rules test failed: {e}", success=False)
|
||||
self.log(f"Alert rules test failed: {e}", success = False)
|
||||
|
||||
def test_alert_channels(self):
|
||||
def test_alert_channels(self) -> None:
|
||||
"""测试告警渠道管理"""
|
||||
print("\n📢 Testing Alert Channels...")
|
||||
|
||||
try:
|
||||
# 创建飞书告警渠道
|
||||
channel1 = self.manager.create_alert_channel(
|
||||
tenant_id=self.tenant_id,
|
||||
name="飞书告警",
|
||||
channel_type=AlertChannelType.FEISHU,
|
||||
config={
|
||||
channel1 = self.manager.create_alert_channel(
|
||||
tenant_id = self.tenant_id,
|
||||
name = "飞书告警",
|
||||
channel_type = AlertChannelType.FEISHU,
|
||||
config = {
|
||||
"webhook_url": "https://open.feishu.cn/open-apis/bot/v2/hook/test",
|
||||
"secret": "test_secret",
|
||||
},
|
||||
severity_filter=["p0", "p1"],
|
||||
severity_filter = ["p0", "p1"],
|
||||
)
|
||||
self.log(f"Created Feishu channel: {channel1.name} (ID: {channel1.id})")
|
||||
|
||||
# 创建钉钉告警渠道
|
||||
channel2 = self.manager.create_alert_channel(
|
||||
tenant_id=self.tenant_id,
|
||||
name="钉钉告警",
|
||||
channel_type=AlertChannelType.DINGTALK,
|
||||
config={
|
||||
"webhook_url": "https://oapi.dingtalk.com/robot/send?access_token=test",
|
||||
channel2 = self.manager.create_alert_channel(
|
||||
tenant_id = self.tenant_id,
|
||||
name = "钉钉告警",
|
||||
channel_type = AlertChannelType.DINGTALK,
|
||||
config = {
|
||||
"webhook_url": "https://oapi.dingtalk.com/robot/send?access_token = test",
|
||||
"secret": "test_secret",
|
||||
},
|
||||
severity_filter=["p0", "p1", "p2"],
|
||||
severity_filter = ["p0", "p1", "p2"],
|
||||
)
|
||||
self.log(f"Created DingTalk channel: {channel2.name} (ID: {channel2.id})")
|
||||
|
||||
# 创建 Slack 告警渠道
|
||||
channel3 = self.manager.create_alert_channel(
|
||||
tenant_id=self.tenant_id,
|
||||
name="Slack 告警",
|
||||
channel_type=AlertChannelType.SLACK,
|
||||
config={"webhook_url": "https://hooks.slack.com/services/test"},
|
||||
severity_filter=["p0", "p1", "p2", "p3"],
|
||||
channel3 = self.manager.create_alert_channel(
|
||||
tenant_id = self.tenant_id,
|
||||
name = "Slack 告警",
|
||||
channel_type = AlertChannelType.SLACK,
|
||||
config = {"webhook_url": "https://hooks.slack.com/services/test"},
|
||||
severity_filter = ["p0", "p1", "p2", "p3"],
|
||||
)
|
||||
self.log(f"Created Slack channel: {channel3.name} (ID: {channel3.id})")
|
||||
|
||||
# 获取告警渠道
|
||||
fetched_channel = self.manager.get_alert_channel(channel1.id)
|
||||
fetched_channel = self.manager.get_alert_channel(channel1.id)
|
||||
assert fetched_channel is not None
|
||||
assert fetched_channel.name == channel1.name
|
||||
self.log(f"Fetched alert channel: {fetched_channel.name}")
|
||||
|
||||
# 列出租户的所有告警渠道
|
||||
channels = self.manager.list_alert_channels(self.tenant_id)
|
||||
channels = self.manager.list_alert_channels(self.tenant_id)
|
||||
assert len(channels) >= 3
|
||||
self.log(f"Listed {len(channels)} alert channels for tenant")
|
||||
|
||||
@@ -198,74 +198,74 @@ class TestOpsManager:
|
||||
for channel in channels:
|
||||
if channel.tenant_id == self.tenant_id:
|
||||
with self.manager._get_db() as conn:
|
||||
conn.execute("DELETE FROM alert_channels WHERE id = ?", (channel.id,))
|
||||
conn.execute("DELETE FROM alert_channels WHERE id = ?", (channel.id, ))
|
||||
conn.commit()
|
||||
self.log("Deleted test alert channels")
|
||||
|
||||
except Exception as e:
|
||||
self.log(f"Alert channels test failed: {e}", success=False)
|
||||
self.log(f"Alert channels test failed: {e}", success = False)
|
||||
|
||||
def test_alerts(self):
|
||||
def test_alerts(self) -> None:
|
||||
"""测试告警管理"""
|
||||
print("\n🚨 Testing Alerts...")
|
||||
|
||||
try:
|
||||
# 创建告警规则
|
||||
rule = self.manager.create_alert_rule(
|
||||
tenant_id=self.tenant_id,
|
||||
name="测试告警规则",
|
||||
description="用于测试的告警规则",
|
||||
rule_type=AlertRuleType.THRESHOLD,
|
||||
severity=AlertSeverity.P1,
|
||||
metric="test_metric",
|
||||
condition=">",
|
||||
threshold=100.0,
|
||||
duration=60,
|
||||
evaluation_interval=60,
|
||||
channels=[],
|
||||
labels={},
|
||||
annotations={},
|
||||
created_by="test_user",
|
||||
rule = self.manager.create_alert_rule(
|
||||
tenant_id = self.tenant_id,
|
||||
name = "测试告警规则",
|
||||
description = "用于测试的告警规则",
|
||||
rule_type = AlertRuleType.THRESHOLD,
|
||||
severity = AlertSeverity.P1,
|
||||
metric = "test_metric",
|
||||
condition = ">",
|
||||
threshold = 100.0,
|
||||
duration = 60,
|
||||
evaluation_interval = 60,
|
||||
channels = [],
|
||||
labels = {},
|
||||
annotations = {},
|
||||
created_by = "test_user",
|
||||
)
|
||||
|
||||
# 记录资源指标
|
||||
for i in range(10):
|
||||
self.manager.record_resource_metric(
|
||||
tenant_id=self.tenant_id,
|
||||
resource_type=ResourceType.CPU,
|
||||
resource_id="server-001",
|
||||
metric_name="test_metric",
|
||||
metric_value=110.0 + i,
|
||||
unit="percent",
|
||||
metadata={"region": "cn-north-1"},
|
||||
tenant_id = self.tenant_id,
|
||||
resource_type = ResourceType.CPU,
|
||||
resource_id = "server-001",
|
||||
metric_name = "test_metric",
|
||||
metric_value = 110.0 + i,
|
||||
unit = "percent",
|
||||
metadata = {"region": "cn-north-1"},
|
||||
)
|
||||
self.log("Recorded 10 resource metrics")
|
||||
|
||||
# 手动创建告警
|
||||
from ops_manager import Alert
|
||||
|
||||
alert_id = f"test_alert_{datetime.now().strftime('%Y%m%d%H%M%S')}"
|
||||
now = datetime.now().isoformat()
|
||||
alert_id = f"test_alert_{datetime.now().strftime('%Y%m%d%H%M%S')}"
|
||||
now = datetime.now().isoformat()
|
||||
|
||||
alert = Alert(
|
||||
id=alert_id,
|
||||
rule_id=rule.id,
|
||||
tenant_id=self.tenant_id,
|
||||
severity=AlertSeverity.P1,
|
||||
status=AlertStatus.FIRING,
|
||||
title="测试告警",
|
||||
description="这是一条测试告警",
|
||||
metric="test_metric",
|
||||
value=120.0,
|
||||
threshold=100.0,
|
||||
labels={"test": "true"},
|
||||
annotations={},
|
||||
started_at=now,
|
||||
resolved_at=None,
|
||||
acknowledged_by=None,
|
||||
acknowledged_at=None,
|
||||
notification_sent={},
|
||||
suppression_count=0,
|
||||
alert = Alert(
|
||||
id = alert_id,
|
||||
rule_id = rule.id,
|
||||
tenant_id = self.tenant_id,
|
||||
severity = AlertSeverity.P1,
|
||||
status = AlertStatus.FIRING,
|
||||
title = "测试告警",
|
||||
description = "这是一条测试告警",
|
||||
metric = "test_metric",
|
||||
value = 120.0,
|
||||
threshold = 100.0,
|
||||
labels = {"test": "true"},
|
||||
annotations = {},
|
||||
started_at = now,
|
||||
resolved_at = None,
|
||||
acknowledged_by = None,
|
||||
acknowledged_at = None,
|
||||
notification_sent = {},
|
||||
suppression_count = 0,
|
||||
)
|
||||
|
||||
with self.manager._get_db() as conn:
|
||||
@@ -299,20 +299,20 @@ class TestOpsManager:
|
||||
self.log(f"Created test alert: {alert.id}")
|
||||
|
||||
# 列出租户的告警
|
||||
alerts = self.manager.list_alerts(self.tenant_id)
|
||||
alerts = self.manager.list_alerts(self.tenant_id)
|
||||
assert len(alerts) >= 1
|
||||
self.log(f"Listed {len(alerts)} alerts for tenant")
|
||||
|
||||
# 确认告警
|
||||
self.manager.acknowledge_alert(alert_id, "test_user")
|
||||
fetched_alert = self.manager.get_alert(alert_id)
|
||||
fetched_alert = self.manager.get_alert(alert_id)
|
||||
assert fetched_alert.status == AlertStatus.ACKNOWLEDGED
|
||||
assert fetched_alert.acknowledged_by == "test_user"
|
||||
self.log(f"Acknowledged alert: {alert_id}")
|
||||
|
||||
# 解决告警
|
||||
self.manager.resolve_alert(alert_id)
|
||||
fetched_alert = self.manager.get_alert(alert_id)
|
||||
fetched_alert = self.manager.get_alert(alert_id)
|
||||
assert fetched_alert.status == AlertStatus.RESOLVED
|
||||
assert fetched_alert.resolved_at is not None
|
||||
self.log(f"Resolved alert: {alert_id}")
|
||||
@@ -320,23 +320,23 @@ class TestOpsManager:
|
||||
# 清理
|
||||
self.manager.delete_alert_rule(rule.id)
|
||||
with self.manager._get_db() as conn:
|
||||
conn.execute("DELETE FROM alerts WHERE id = ?", (alert_id,))
|
||||
conn.execute("DELETE FROM resource_metrics WHERE tenant_id = ?", (self.tenant_id,))
|
||||
conn.execute("DELETE FROM alerts WHERE id = ?", (alert_id, ))
|
||||
conn.execute("DELETE FROM resource_metrics WHERE tenant_id = ?", (self.tenant_id, ))
|
||||
conn.commit()
|
||||
self.log("Cleaned up test data")
|
||||
|
||||
except Exception as e:
|
||||
self.log(f"Alerts test failed: {e}", success=False)
|
||||
self.log(f"Alerts test failed: {e}", success = False)
|
||||
|
||||
def test_capacity_planning(self):
|
||||
def test_capacity_planning(self) -> None:
|
||||
"""测试容量规划"""
|
||||
print("\n📊 Testing Capacity Planning...")
|
||||
|
||||
try:
|
||||
# 记录历史指标数据
|
||||
base_time = datetime.now() - timedelta(days=30)
|
||||
base_time = datetime.now() - timedelta(days = 30)
|
||||
for i in range(30):
|
||||
timestamp = (base_time + timedelta(days=i)).isoformat()
|
||||
timestamp = (base_time + timedelta(days = i)).isoformat()
|
||||
with self.manager._get_db() as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
@@ -360,13 +360,13 @@ class TestOpsManager:
|
||||
self.log("Recorded 30 days of historical metrics")
|
||||
|
||||
# 创建容量规划
|
||||
prediction_date = (datetime.now() + timedelta(days=30)).strftime("%Y-%m-%d")
|
||||
plan = self.manager.create_capacity_plan(
|
||||
tenant_id=self.tenant_id,
|
||||
resource_type=ResourceType.CPU,
|
||||
current_capacity=100.0,
|
||||
prediction_date=prediction_date,
|
||||
confidence=0.85,
|
||||
prediction_date = (datetime.now() + timedelta(days = 30)).strftime("%Y-%m-%d")
|
||||
plan = self.manager.create_capacity_plan(
|
||||
tenant_id = self.tenant_id,
|
||||
resource_type = ResourceType.CPU,
|
||||
current_capacity = 100.0,
|
||||
prediction_date = prediction_date,
|
||||
confidence = 0.85,
|
||||
)
|
||||
|
||||
self.log(f"Created capacity plan: {plan.id}")
|
||||
@@ -375,38 +375,38 @@ class TestOpsManager:
|
||||
self.log(f" Recommended action: {plan.recommended_action}")
|
||||
|
||||
# 获取容量规划列表
|
||||
plans = self.manager.get_capacity_plans(self.tenant_id)
|
||||
plans = self.manager.get_capacity_plans(self.tenant_id)
|
||||
assert len(plans) >= 1
|
||||
self.log(f"Listed {len(plans)} capacity plans")
|
||||
|
||||
# 清理
|
||||
with self.manager._get_db() as conn:
|
||||
conn.execute("DELETE FROM capacity_plans WHERE tenant_id = ?", (self.tenant_id,))
|
||||
conn.execute("DELETE FROM resource_metrics WHERE tenant_id = ?", (self.tenant_id,))
|
||||
conn.execute("DELETE FROM capacity_plans WHERE tenant_id = ?", (self.tenant_id, ))
|
||||
conn.execute("DELETE FROM resource_metrics WHERE tenant_id = ?", (self.tenant_id, ))
|
||||
conn.commit()
|
||||
self.log("Cleaned up capacity planning test data")
|
||||
|
||||
except Exception as e:
|
||||
self.log(f"Capacity planning test failed: {e}", success=False)
|
||||
self.log(f"Capacity planning test failed: {e}", success = False)
|
||||
|
||||
def test_auto_scaling(self):
|
||||
def test_auto_scaling(self) -> None:
|
||||
"""测试自动扩缩容"""
|
||||
print("\n⚖️ Testing Auto Scaling...")
|
||||
|
||||
try:
|
||||
# 创建自动扩缩容策略
|
||||
policy = self.manager.create_auto_scaling_policy(
|
||||
tenant_id=self.tenant_id,
|
||||
name="API 服务自动扩缩容",
|
||||
resource_type=ResourceType.CPU,
|
||||
min_instances=2,
|
||||
max_instances=10,
|
||||
target_utilization=0.7,
|
||||
scale_up_threshold=0.8,
|
||||
scale_down_threshold=0.3,
|
||||
scale_up_step=2,
|
||||
scale_down_step=1,
|
||||
cooldown_period=300,
|
||||
policy = self.manager.create_auto_scaling_policy(
|
||||
tenant_id = self.tenant_id,
|
||||
name = "API 服务自动扩缩容",
|
||||
resource_type = ResourceType.CPU,
|
||||
min_instances = 2,
|
||||
max_instances = 10,
|
||||
target_utilization = 0.7,
|
||||
scale_up_threshold = 0.8,
|
||||
scale_down_threshold = 0.3,
|
||||
scale_up_step = 2,
|
||||
scale_down_step = 1,
|
||||
cooldown_period = 300,
|
||||
)
|
||||
|
||||
self.log(f"Created auto scaling policy: {policy.name} (ID: {policy.id})")
|
||||
@@ -415,13 +415,13 @@ class TestOpsManager:
|
||||
self.log(f" Target utilization: {policy.target_utilization}")
|
||||
|
||||
# 获取策略列表
|
||||
policies = self.manager.list_auto_scaling_policies(self.tenant_id)
|
||||
policies = self.manager.list_auto_scaling_policies(self.tenant_id)
|
||||
assert len(policies) >= 1
|
||||
self.log(f"Listed {len(policies)} auto scaling policies")
|
||||
|
||||
# 模拟扩缩容评估
|
||||
event = self.manager.evaluate_scaling_policy(
|
||||
policy_id=policy.id, current_instances=3, current_utilization=0.85
|
||||
event = self.manager.evaluate_scaling_policy(
|
||||
policy_id = policy.id, current_instances = 3, current_utilization = 0.85
|
||||
)
|
||||
|
||||
if event:
|
||||
@@ -432,62 +432,62 @@ class TestOpsManager:
|
||||
self.log("No scaling action needed")
|
||||
|
||||
# 获取扩缩容事件列表
|
||||
events = self.manager.list_scaling_events(self.tenant_id)
|
||||
events = self.manager.list_scaling_events(self.tenant_id)
|
||||
self.log(f"Listed {len(events)} scaling events")
|
||||
|
||||
# 清理
|
||||
with self.manager._get_db() as conn:
|
||||
conn.execute("DELETE FROM scaling_events WHERE tenant_id = ?", (self.tenant_id,))
|
||||
conn.execute("DELETE FROM scaling_events WHERE tenant_id = ?", (self.tenant_id, ))
|
||||
conn.execute(
|
||||
"DELETE FROM auto_scaling_policies WHERE tenant_id = ?", (self.tenant_id,)
|
||||
"DELETE FROM auto_scaling_policies WHERE tenant_id = ?", (self.tenant_id, )
|
||||
)
|
||||
conn.commit()
|
||||
self.log("Cleaned up auto scaling test data")
|
||||
|
||||
except Exception as e:
|
||||
self.log(f"Auto scaling test failed: {e}", success=False)
|
||||
self.log(f"Auto scaling test failed: {e}", success = False)
|
||||
|
||||
def test_health_checks(self):
|
||||
def test_health_checks(self) -> None:
|
||||
"""测试健康检查"""
|
||||
print("\n💓 Testing Health Checks...")
|
||||
|
||||
try:
|
||||
# 创建 HTTP 健康检查
|
||||
check1 = self.manager.create_health_check(
|
||||
tenant_id=self.tenant_id,
|
||||
name="API 服务健康检查",
|
||||
target_type="service",
|
||||
target_id="api-service",
|
||||
check_type="http",
|
||||
check_config={"url": "https://api.insightflow.io/health", "expected_status": 200},
|
||||
interval=60,
|
||||
timeout=10,
|
||||
retry_count=3,
|
||||
check1 = self.manager.create_health_check(
|
||||
tenant_id = self.tenant_id,
|
||||
name = "API 服务健康检查",
|
||||
target_type = "service",
|
||||
target_id = "api-service",
|
||||
check_type = "http",
|
||||
check_config = {"url": "https://api.insightflow.io/health", "expected_status": 200},
|
||||
interval = 60,
|
||||
timeout = 10,
|
||||
retry_count = 3,
|
||||
)
|
||||
self.log(f"Created HTTP health check: {check1.name} (ID: {check1.id})")
|
||||
|
||||
# 创建 TCP 健康检查
|
||||
check2 = self.manager.create_health_check(
|
||||
tenant_id=self.tenant_id,
|
||||
name="数据库健康检查",
|
||||
target_type="database",
|
||||
target_id="postgres-001",
|
||||
check_type="tcp",
|
||||
check_config={"host": "db.insightflow.io", "port": 5432},
|
||||
interval=30,
|
||||
timeout=5,
|
||||
retry_count=2,
|
||||
check2 = self.manager.create_health_check(
|
||||
tenant_id = self.tenant_id,
|
||||
name = "数据库健康检查",
|
||||
target_type = "database",
|
||||
target_id = "postgres-001",
|
||||
check_type = "tcp",
|
||||
check_config = {"host": "db.insightflow.io", "port": 5432},
|
||||
interval = 30,
|
||||
timeout = 5,
|
||||
retry_count = 2,
|
||||
)
|
||||
self.log(f"Created TCP health check: {check2.name} (ID: {check2.id})")
|
||||
|
||||
# 获取健康检查列表
|
||||
checks = self.manager.list_health_checks(self.tenant_id)
|
||||
checks = self.manager.list_health_checks(self.tenant_id)
|
||||
assert len(checks) >= 2
|
||||
self.log(f"Listed {len(checks)} health checks")
|
||||
|
||||
# 执行健康检查(异步)
|
||||
async def run_health_check():
|
||||
result = await self.manager.execute_health_check(check1.id)
|
||||
async def run_health_check() -> None:
|
||||
result = await self.manager.execute_health_check(check1.id)
|
||||
return result
|
||||
|
||||
# 由于健康检查需要网络,这里只验证方法存在
|
||||
@@ -495,28 +495,28 @@ class TestOpsManager:
|
||||
|
||||
# 清理
|
||||
with self.manager._get_db() as conn:
|
||||
conn.execute("DELETE FROM health_checks WHERE tenant_id = ?", (self.tenant_id,))
|
||||
conn.execute("DELETE FROM health_checks WHERE tenant_id = ?", (self.tenant_id, ))
|
||||
conn.commit()
|
||||
self.log("Cleaned up health check test data")
|
||||
|
||||
except Exception as e:
|
||||
self.log(f"Health checks test failed: {e}", success=False)
|
||||
self.log(f"Health checks test failed: {e}", success = False)
|
||||
|
||||
def test_failover(self):
|
||||
def test_failover(self) -> None:
|
||||
"""测试故障转移"""
|
||||
print("\n🔄 Testing Failover...")
|
||||
|
||||
try:
|
||||
# 创建故障转移配置
|
||||
config = self.manager.create_failover_config(
|
||||
tenant_id=self.tenant_id,
|
||||
name="主备数据中心故障转移",
|
||||
primary_region="cn-north-1",
|
||||
secondary_regions=["cn-south-1", "cn-east-1"],
|
||||
failover_trigger="health_check_failed",
|
||||
auto_failover=False,
|
||||
failover_timeout=300,
|
||||
health_check_id=None,
|
||||
config = self.manager.create_failover_config(
|
||||
tenant_id = self.tenant_id,
|
||||
name = "主备数据中心故障转移",
|
||||
primary_region = "cn-north-1",
|
||||
secondary_regions = ["cn-south-1", "cn-east-1"],
|
||||
failover_trigger = "health_check_failed",
|
||||
auto_failover = False,
|
||||
failover_timeout = 300,
|
||||
health_check_id = None,
|
||||
)
|
||||
|
||||
self.log(f"Created failover config: {config.name} (ID: {config.id})")
|
||||
@@ -524,13 +524,13 @@ class TestOpsManager:
|
||||
self.log(f" Secondary regions: {config.secondary_regions}")
|
||||
|
||||
# 获取故障转移配置列表
|
||||
configs = self.manager.list_failover_configs(self.tenant_id)
|
||||
configs = self.manager.list_failover_configs(self.tenant_id)
|
||||
assert len(configs) >= 1
|
||||
self.log(f"Listed {len(configs)} failover configs")
|
||||
|
||||
# 发起故障转移
|
||||
event = self.manager.initiate_failover(
|
||||
config_id=config.id, reason="Primary region health check failed"
|
||||
event = self.manager.initiate_failover(
|
||||
config_id = config.id, reason = "Primary region health check failed"
|
||||
)
|
||||
|
||||
if event:
|
||||
@@ -540,41 +540,41 @@ class TestOpsManager:
|
||||
|
||||
# 更新故障转移状态
|
||||
self.manager.update_failover_status(event.id, "completed")
|
||||
updated_event = self.manager.get_failover_event(event.id)
|
||||
updated_event = self.manager.get_failover_event(event.id)
|
||||
assert updated_event.status == "completed"
|
||||
self.log("Failover completed")
|
||||
|
||||
# 获取故障转移事件列表
|
||||
events = self.manager.list_failover_events(self.tenant_id)
|
||||
events = self.manager.list_failover_events(self.tenant_id)
|
||||
self.log(f"Listed {len(events)} failover events")
|
||||
|
||||
# 清理
|
||||
with self.manager._get_db() as conn:
|
||||
conn.execute("DELETE FROM failover_events WHERE tenant_id = ?", (self.tenant_id,))
|
||||
conn.execute("DELETE FROM failover_configs WHERE tenant_id = ?", (self.tenant_id,))
|
||||
conn.execute("DELETE FROM failover_events WHERE tenant_id = ?", (self.tenant_id, ))
|
||||
conn.execute("DELETE FROM failover_configs WHERE tenant_id = ?", (self.tenant_id, ))
|
||||
conn.commit()
|
||||
self.log("Cleaned up failover test data")
|
||||
|
||||
except Exception as e:
|
||||
self.log(f"Failover test failed: {e}", success=False)
|
||||
self.log(f"Failover test failed: {e}", success = False)
|
||||
|
||||
def test_backup(self):
|
||||
def test_backup(self) -> None:
|
||||
"""测试备份与恢复"""
|
||||
print("\n💾 Testing Backup & Recovery...")
|
||||
|
||||
try:
|
||||
# 创建备份任务
|
||||
job = self.manager.create_backup_job(
|
||||
tenant_id=self.tenant_id,
|
||||
name="每日数据库备份",
|
||||
backup_type="full",
|
||||
target_type="database",
|
||||
target_id="postgres-main",
|
||||
schedule="0 2 * * *", # 每天凌晨2点
|
||||
retention_days=30,
|
||||
encryption_enabled=True,
|
||||
compression_enabled=True,
|
||||
storage_location="s3://insightflow-backups/",
|
||||
job = self.manager.create_backup_job(
|
||||
tenant_id = self.tenant_id,
|
||||
name = "每日数据库备份",
|
||||
backup_type = "full",
|
||||
target_type = "database",
|
||||
target_id = "postgres-main",
|
||||
schedule = "0 2 * * *", # 每天凌晨2点
|
||||
retention_days = 30,
|
||||
encryption_enabled = True,
|
||||
compression_enabled = True,
|
||||
storage_location = "s3://insightflow-backups/",
|
||||
)
|
||||
|
||||
self.log(f"Created backup job: {job.name} (ID: {job.id})")
|
||||
@@ -582,12 +582,12 @@ class TestOpsManager:
|
||||
self.log(f" Retention: {job.retention_days} days")
|
||||
|
||||
# 获取备份任务列表
|
||||
jobs = self.manager.list_backup_jobs(self.tenant_id)
|
||||
jobs = self.manager.list_backup_jobs(self.tenant_id)
|
||||
assert len(jobs) >= 1
|
||||
self.log(f"Listed {len(jobs)} backup jobs")
|
||||
|
||||
# 执行备份
|
||||
record = self.manager.execute_backup(job.id)
|
||||
record = self.manager.execute_backup(job.id)
|
||||
|
||||
if record:
|
||||
self.log(f"Executed backup: {record.id}")
|
||||
@@ -595,50 +595,50 @@ class TestOpsManager:
|
||||
self.log(f" Storage: {record.storage_path}")
|
||||
|
||||
# 获取备份记录列表
|
||||
records = self.manager.list_backup_records(self.tenant_id)
|
||||
records = self.manager.list_backup_records(self.tenant_id)
|
||||
self.log(f"Listed {len(records)} backup records")
|
||||
|
||||
# 测试恢复(模拟)
|
||||
restore_result = self.manager.restore_from_backup(record.id)
|
||||
restore_result = self.manager.restore_from_backup(record.id)
|
||||
self.log(f"Restore test result: {restore_result}")
|
||||
|
||||
# 清理
|
||||
with self.manager._get_db() as conn:
|
||||
conn.execute("DELETE FROM backup_records WHERE tenant_id = ?", (self.tenant_id,))
|
||||
conn.execute("DELETE FROM backup_jobs WHERE tenant_id = ?", (self.tenant_id,))
|
||||
conn.execute("DELETE FROM backup_records WHERE tenant_id = ?", (self.tenant_id, ))
|
||||
conn.execute("DELETE FROM backup_jobs WHERE tenant_id = ?", (self.tenant_id, ))
|
||||
conn.commit()
|
||||
self.log("Cleaned up backup test data")
|
||||
|
||||
except Exception as e:
|
||||
self.log(f"Backup test failed: {e}", success=False)
|
||||
self.log(f"Backup test failed: {e}", success = False)
|
||||
|
||||
def test_cost_optimization(self):
|
||||
def test_cost_optimization(self) -> None:
|
||||
"""测试成本优化"""
|
||||
print("\n💰 Testing Cost Optimization...")
|
||||
|
||||
try:
|
||||
# 记录资源利用率数据
|
||||
report_date = datetime.now().strftime("%Y-%m-%d")
|
||||
report_date = datetime.now().strftime("%Y-%m-%d")
|
||||
|
||||
for i in range(5):
|
||||
self.manager.record_resource_utilization(
|
||||
tenant_id=self.tenant_id,
|
||||
resource_type=ResourceType.CPU,
|
||||
resource_id=f"server-{i:03d}",
|
||||
utilization_rate=0.05 + random.random() * 0.1, # 低利用率
|
||||
peak_utilization=0.15,
|
||||
avg_utilization=0.08,
|
||||
idle_time_percent=0.85,
|
||||
report_date=report_date,
|
||||
recommendations=["Consider downsizing this resource"],
|
||||
tenant_id = self.tenant_id,
|
||||
resource_type = ResourceType.CPU,
|
||||
resource_id = f"server-{i:03d}",
|
||||
utilization_rate = 0.05 + random.random() * 0.1, # 低利用率
|
||||
peak_utilization = 0.15,
|
||||
avg_utilization = 0.08,
|
||||
idle_time_percent = 0.85,
|
||||
report_date = report_date,
|
||||
recommendations = ["Consider downsizing this resource"],
|
||||
)
|
||||
|
||||
self.log("Recorded 5 resource utilization records")
|
||||
|
||||
# 生成成本报告
|
||||
now = datetime.now()
|
||||
report = self.manager.generate_cost_report(
|
||||
tenant_id=self.tenant_id, year=now.year, month=now.month
|
||||
now = datetime.now()
|
||||
report = self.manager.generate_cost_report(
|
||||
tenant_id = self.tenant_id, year = now.year, month = now.month
|
||||
)
|
||||
|
||||
self.log(f"Generated cost report: {report.id}")
|
||||
@@ -647,11 +647,11 @@ class TestOpsManager:
|
||||
self.log(f" Anomalies detected: {len(report.anomalies)}")
|
||||
|
||||
# 检测闲置资源
|
||||
idle_resources = self.manager.detect_idle_resources(self.tenant_id)
|
||||
idle_resources = self.manager.detect_idle_resources(self.tenant_id)
|
||||
self.log(f"Detected {len(idle_resources)} idle resources")
|
||||
|
||||
# 获取闲置资源列表
|
||||
idle_list = self.manager.get_idle_resources(self.tenant_id)
|
||||
idle_list = self.manager.get_idle_resources(self.tenant_id)
|
||||
for resource in idle_list:
|
||||
self.log(
|
||||
f" Idle resource: {resource.resource_name} (est. cost: {
|
||||
@@ -660,7 +660,7 @@ class TestOpsManager:
|
||||
)
|
||||
|
||||
# 生成成本优化建议
|
||||
suggestions = self.manager.generate_cost_optimization_suggestions(self.tenant_id)
|
||||
suggestions = self.manager.generate_cost_optimization_suggestions(self.tenant_id)
|
||||
self.log(f"Generated {len(suggestions)} cost optimization suggestions")
|
||||
|
||||
for suggestion in suggestions:
|
||||
@@ -672,12 +672,12 @@ class TestOpsManager:
|
||||
self.log(f" Difficulty: {suggestion.difficulty}")
|
||||
|
||||
# 获取优化建议列表
|
||||
all_suggestions = self.manager.get_cost_optimization_suggestions(self.tenant_id)
|
||||
all_suggestions = self.manager.get_cost_optimization_suggestions(self.tenant_id)
|
||||
self.log(f"Listed {len(all_suggestions)} optimization suggestions")
|
||||
|
||||
# 应用优化建议
|
||||
if all_suggestions:
|
||||
applied = self.manager.apply_cost_optimization_suggestion(all_suggestions[0].id)
|
||||
applied = self.manager.apply_cost_optimization_suggestion(all_suggestions[0].id)
|
||||
if applied:
|
||||
self.log(f"Applied optimization suggestion: {applied.title}")
|
||||
assert applied.is_applied
|
||||
@@ -686,29 +686,29 @@ class TestOpsManager:
|
||||
# 清理
|
||||
with self.manager._get_db() as conn:
|
||||
conn.execute(
|
||||
"DELETE FROM cost_optimization_suggestions WHERE tenant_id = ?",
|
||||
(self.tenant_id,),
|
||||
"DELETE FROM cost_optimization_suggestions WHERE tenant_id = ?",
|
||||
(self.tenant_id, ),
|
||||
)
|
||||
conn.execute("DELETE FROM idle_resources WHERE tenant_id = ?", (self.tenant_id,))
|
||||
conn.execute("DELETE FROM idle_resources WHERE tenant_id = ?", (self.tenant_id, ))
|
||||
conn.execute(
|
||||
"DELETE FROM resource_utilizations WHERE tenant_id = ?", (self.tenant_id,)
|
||||
"DELETE FROM resource_utilizations WHERE tenant_id = ?", (self.tenant_id, )
|
||||
)
|
||||
conn.execute("DELETE FROM cost_reports WHERE tenant_id = ?", (self.tenant_id,))
|
||||
conn.execute("DELETE FROM cost_reports WHERE tenant_id = ?", (self.tenant_id, ))
|
||||
conn.commit()
|
||||
self.log("Cleaned up cost optimization test data")
|
||||
|
||||
except Exception as e:
|
||||
self.log(f"Cost optimization test failed: {e}", success=False)
|
||||
self.log(f"Cost optimization test failed: {e}", success = False)
|
||||
|
||||
def print_summary(self):
|
||||
def print_summary(self) -> None:
|
||||
"""打印测试总结"""
|
||||
print("\n" + "=" * 60)
|
||||
print("\n" + " = " * 60)
|
||||
print("Test Summary")
|
||||
print("=" * 60)
|
||||
print(" = " * 60)
|
||||
|
||||
total = len(self.test_results)
|
||||
passed = sum(1 for _, success in self.test_results if success)
|
||||
failed = total - passed
|
||||
total = len(self.test_results)
|
||||
passed = sum(1 for _, success in self.test_results if success)
|
||||
failed = total - passed
|
||||
|
||||
print(f"Total tests: {total}")
|
||||
print(f"Passed: {passed} ✅")
|
||||
@@ -720,12 +720,12 @@ class TestOpsManager:
|
||||
if not success:
|
||||
print(f" ❌ {message}")
|
||||
|
||||
print("=" * 60)
|
||||
print(" = " * 60)
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> None:
|
||||
"""主函数"""
|
||||
test = TestOpsManager()
|
||||
test = TestOpsManager()
|
||||
test.run_all_tests()
|
||||
|
||||
|
||||
|
||||
@@ -10,19 +10,19 @@ from typing import Any
|
||||
|
||||
|
||||
class TingwuClient:
|
||||
def __init__(self):
|
||||
self.access_key = os.getenv("ALI_ACCESS_KEY", "")
|
||||
self.secret_key = os.getenv("ALI_SECRET_KEY", "")
|
||||
self.endpoint = "https://tingwu.cn-beijing.aliyuncs.com"
|
||||
def __init__(self) -> None:
|
||||
self.access_key = os.getenv("ALI_ACCESS_KEY", "")
|
||||
self.secret_key = os.getenv("ALI_SECRET_KEY", "")
|
||||
self.endpoint = "https://tingwu.cn-beijing.aliyuncs.com"
|
||||
|
||||
if not self.access_key or not self.secret_key:
|
||||
raise ValueError("ALI_ACCESS_KEY and ALI_SECRET_KEY required")
|
||||
|
||||
def _sign_request(
|
||||
self, method: str, uri: str, query: str = "", body: str = ""
|
||||
self, method: str, uri: str, query: str = "", body: str = ""
|
||||
) -> dict[str, str]:
|
||||
"""阿里云签名 V3"""
|
||||
timestamp = datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%SZ")
|
||||
timestamp = datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%SZ")
|
||||
|
||||
# 简化签名,实际生产需要完整实现
|
||||
# 这里使用基础认证头
|
||||
@@ -31,10 +31,10 @@ class TingwuClient:
|
||||
"x-acs-action": "CreateTask",
|
||||
"x-acs-version": "2023-09-30",
|
||||
"x-acs-date": timestamp,
|
||||
"Authorization": f"ACS3-HMAC-SHA256 Credential={self.access_key}/acs/tingwu/cn-beijing",
|
||||
"Authorization": f"ACS3-HMAC-SHA256 Credential = {self.access_key}/acs/tingwu/cn-beijing",
|
||||
}
|
||||
|
||||
def create_task(self, audio_url: str, language: str = "zh") -> str:
|
||||
def create_task(self, audio_url: str, language: str = "zh") -> str:
|
||||
"""创建听悟任务"""
|
||||
try:
|
||||
# 导入移到文件顶部会导致循环导入,保持在这里
|
||||
@@ -42,23 +42,23 @@ class TingwuClient:
|
||||
from alibabacloud_tingwu20230930 import models as tingwu_models
|
||||
from alibabacloud_tingwu20230930.client import Client as TingwuSDKClient
|
||||
|
||||
config = open_api_models.Config(
|
||||
access_key_id=self.access_key, access_key_secret=self.secret_key
|
||||
config = open_api_models.Config(
|
||||
access_key_id = self.access_key, access_key_secret = self.secret_key
|
||||
)
|
||||
config.endpoint = "tingwu.cn-beijing.aliyuncs.com"
|
||||
client = TingwuSDKClient(config)
|
||||
config.endpoint = "tingwu.cn-beijing.aliyuncs.com"
|
||||
client = TingwuSDKClient(config)
|
||||
|
||||
request = tingwu_models.CreateTaskRequest(
|
||||
type="offline",
|
||||
input=tingwu_models.Input(source="OSS", file_url=audio_url),
|
||||
parameters=tingwu_models.Parameters(
|
||||
transcription=tingwu_models.Transcription(
|
||||
diarization_enabled=True, sentence_max_length=20
|
||||
request = tingwu_models.CreateTaskRequest(
|
||||
type = "offline",
|
||||
input = tingwu_models.Input(source = "OSS", file_url = audio_url),
|
||||
parameters = tingwu_models.Parameters(
|
||||
transcription = tingwu_models.Transcription(
|
||||
diarization_enabled = True, sentence_max_length = 20
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
response = client.create_task(request)
|
||||
response = client.create_task(request)
|
||||
if response.body.code == "0":
|
||||
return response.body.data.task_id
|
||||
else:
|
||||
@@ -73,29 +73,26 @@ class TingwuClient:
|
||||
return f"mock_task_{int(time.time())}"
|
||||
|
||||
def get_task_result(
|
||||
self, task_id: str, max_retries: int = 60, interval: int = 5
|
||||
self, task_id: str, max_retries: int = 60, interval: int = 5
|
||||
) -> dict[str, Any]:
|
||||
"""获取任务结果"""
|
||||
try:
|
||||
# 导入移到文件顶部会导致循环导入,保持在这里
|
||||
from alibabacloud_tea_openapi import models as open_api_models
|
||||
from alibabacloud_tingwu20230930 import models as tingwu_models
|
||||
from alibabacloud_tingwu20230930.client import Client as TingwuSDKClient
|
||||
|
||||
config = open_api_models.Config(
|
||||
access_key_id=self.access_key, access_key_secret=self.secret_key
|
||||
config = open_api_models.Config(
|
||||
access_key_id = self.access_key, access_key_secret = self.secret_key
|
||||
)
|
||||
config.endpoint = "tingwu.cn-beijing.aliyuncs.com"
|
||||
client = TingwuSDKClient(config)
|
||||
config.endpoint = "tingwu.cn-beijing.aliyuncs.com"
|
||||
client = TingwuSDKClient(config)
|
||||
|
||||
for i in range(max_retries):
|
||||
request = tingwu_models.GetTaskInfoRequest()
|
||||
response = client.get_task_info(task_id, request)
|
||||
request = tingwu_models.GetTaskInfoRequest()
|
||||
response = client.get_task_info(task_id, request)
|
||||
|
||||
if response.body.code != "0":
|
||||
raise RuntimeError(f"Query failed: {response.body.message}")
|
||||
|
||||
status = response.body.data.task_status
|
||||
status = response.body.data.task_status
|
||||
|
||||
if status == "SUCCESS":
|
||||
return self._parse_result(response.body.data)
|
||||
@@ -116,11 +113,11 @@ class TingwuClient:
|
||||
|
||||
def _parse_result(self, data) -> dict[str, Any]:
|
||||
"""解析结果"""
|
||||
result = data.result
|
||||
transcription = result.transcription
|
||||
result = data.result
|
||||
transcription = result.transcription
|
||||
|
||||
full_text = ""
|
||||
segments = []
|
||||
full_text = ""
|
||||
segments = []
|
||||
|
||||
if transcription.paragraphs:
|
||||
for para in transcription.paragraphs:
|
||||
@@ -153,8 +150,8 @@ class TingwuClient:
|
||||
],
|
||||
}
|
||||
|
||||
def transcribe(self, audio_url: str, language: str = "zh") -> dict[str, Any]:
|
||||
def transcribe(self, audio_url: str, language: str = "zh") -> dict[str, Any]:
|
||||
"""一键转录"""
|
||||
task_id = self.create_task(audio_url, language)
|
||||
task_id = self.create_task(audio_url, language)
|
||||
print(f"Tingwu task: {task_id}")
|
||||
return self.get_task_result(task_id)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -11,10 +11,10 @@ from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
# 项目路径
|
||||
PROJECT_PATH = Path("/root/.openclaw/workspace/projects/insightflow")
|
||||
PROJECT_PATH = Path("/root/.openclaw/workspace/projects/insightflow")
|
||||
|
||||
# 修复报告
|
||||
report = {
|
||||
report = {
|
||||
"fixed": [],
|
||||
"manual_review": [],
|
||||
"errors": []
|
||||
@@ -22,7 +22,7 @@ report = {
|
||||
|
||||
def find_python_files() -> list[Path]:
|
||||
"""查找所有 Python 文件"""
|
||||
py_files = []
|
||||
py_files = []
|
||||
for py_file in PROJECT_PATH.rglob("*.py"):
|
||||
if "__pycache__" not in str(py_file):
|
||||
py_files.append(py_file)
|
||||
@@ -30,12 +30,12 @@ def find_python_files() -> list[Path]:
|
||||
|
||||
def check_duplicate_imports(content: str, file_path: Path) -> list[dict]:
|
||||
"""检查重复导入"""
|
||||
issues = []
|
||||
lines = content.split('\n')
|
||||
imports = {}
|
||||
|
||||
issues = []
|
||||
lines = content.split('\n')
|
||||
imports = {}
|
||||
|
||||
for i, line in enumerate(lines, 1):
|
||||
line_stripped = line.strip()
|
||||
line_stripped = line.strip()
|
||||
if line_stripped.startswith('import ') or line_stripped.startswith('from '):
|
||||
if line_stripped in imports:
|
||||
issues.append({
|
||||
@@ -45,17 +45,17 @@ def check_duplicate_imports(content: str, file_path: Path) -> list[dict]:
|
||||
"original_line": imports[line_stripped]
|
||||
})
|
||||
else:
|
||||
imports[line_stripped] = i
|
||||
imports[line_stripped] = i
|
||||
return issues
|
||||
|
||||
def check_bare_excepts(content: str, file_path: Path) -> list[dict]:
|
||||
"""检查裸异常捕获"""
|
||||
issues = []
|
||||
lines = content.split('\n')
|
||||
|
||||
issues = []
|
||||
lines = content.split('\n')
|
||||
|
||||
for i, line in enumerate(lines, 1):
|
||||
stripped = line.strip()
|
||||
# 检查 except Exception: 或 except :
|
||||
stripped = line.strip()
|
||||
# 检查 except Exception: 或 except Exception:
|
||||
if re.match(r'^except\s*:', stripped):
|
||||
issues.append({
|
||||
"line": i,
|
||||
@@ -66,9 +66,9 @@ def check_bare_excepts(content: str, file_path: Path) -> list[dict]:
|
||||
|
||||
def check_line_length(content: str, file_path: Path) -> list[dict]:
|
||||
"""检查行长度(PEP8: 79字符,这里放宽到 100)"""
|
||||
issues = []
|
||||
lines = content.split('\n')
|
||||
|
||||
issues = []
|
||||
lines = content.split('\n')
|
||||
|
||||
for i, line in enumerate(lines, 1):
|
||||
if len(line) > 100:
|
||||
issues.append({
|
||||
@@ -81,24 +81,24 @@ def check_line_length(content: str, file_path: Path) -> list[dict]:
|
||||
|
||||
def check_unused_imports(content: str, file_path: Path) -> list[dict]:
|
||||
"""检查未使用的导入"""
|
||||
issues = []
|
||||
issues = []
|
||||
try:
|
||||
tree = ast.parse(content)
|
||||
imports = {}
|
||||
used_names = set()
|
||||
|
||||
tree = ast.parse(content)
|
||||
imports = {}
|
||||
used_names = set()
|
||||
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, ast.Import):
|
||||
for alias in node.names:
|
||||
imports[alias.asname or alias.name] = node
|
||||
imports[alias.asname or alias.name] = node
|
||||
elif isinstance(node, ast.ImportFrom):
|
||||
for alias in node.names:
|
||||
name = alias.asname or alias.name
|
||||
name = alias.asname or alias.name
|
||||
if name != '*':
|
||||
imports[name] = node
|
||||
imports[name] = node
|
||||
elif isinstance(node, ast.Name):
|
||||
used_names.add(node.id)
|
||||
|
||||
|
||||
for name, node in imports.items():
|
||||
if name not in used_names and not name.startswith('_'):
|
||||
issues.append({
|
||||
@@ -112,9 +112,9 @@ def check_unused_imports(content: str, file_path: Path) -> list[dict]:
|
||||
|
||||
def check_string_formatting(content: str, file_path: Path) -> list[dict]:
|
||||
"""检查混合字符串格式化(建议使用 f-string)"""
|
||||
issues = []
|
||||
lines = content.split('\n')
|
||||
|
||||
issues = []
|
||||
lines = content.split('\n')
|
||||
|
||||
for i, line in enumerate(lines, 1):
|
||||
# 检查 % 格式化
|
||||
if re.search(r'["\'].*%\s*\w+', line) and '%' in line:
|
||||
@@ -136,18 +136,18 @@ def check_string_formatting(content: str, file_path: Path) -> list[dict]:
|
||||
|
||||
def check_magic_numbers(content: str, file_path: Path) -> list[dict]:
|
||||
"""检查魔法数字"""
|
||||
issues = []
|
||||
lines = content.split('\n')
|
||||
|
||||
issues = []
|
||||
lines = content.split('\n')
|
||||
|
||||
# 常见魔法数字模式(排除常见索引和简单值)
|
||||
magic_pattern = re.compile(r'(?<![\w\d_])(\d{3,})(?![\w\d_])')
|
||||
|
||||
magic_pattern = re.compile(r'(?<![\w\d_])(\d{3, })(?![\w\d_])')
|
||||
|
||||
for i, line in enumerate(lines, 1):
|
||||
if line.strip().startswith('#'):
|
||||
continue
|
||||
matches = magic_pattern.findall(line)
|
||||
matches = magic_pattern.findall(line)
|
||||
for match in matches:
|
||||
num = int(match)
|
||||
num = int(match)
|
||||
# 排除常见值
|
||||
if num not in [200, 201, 204, 301, 302, 400, 401, 403, 404, 429, 500, 502, 503, 3600, 86400]:
|
||||
issues.append({
|
||||
@@ -160,9 +160,9 @@ def check_magic_numbers(content: str, file_path: Path) -> list[dict]:
|
||||
|
||||
def check_sql_injection(content: str, file_path: Path) -> list[dict]:
|
||||
"""检查 SQL 注入风险"""
|
||||
issues = []
|
||||
lines = content.split('\n')
|
||||
|
||||
issues = []
|
||||
lines = content.split('\n')
|
||||
|
||||
for i, line in enumerate(lines, 1):
|
||||
# 检查字符串拼接的 SQL
|
||||
if 'execute(' in line or 'executescript(' in line or 'executemany(' in line:
|
||||
@@ -179,9 +179,9 @@ def check_sql_injection(content: str, file_path: Path) -> list[dict]:
|
||||
|
||||
def check_cors_config(content: str, file_path: Path) -> list[dict]:
|
||||
"""检查 CORS 配置"""
|
||||
issues = []
|
||||
lines = content.split('\n')
|
||||
|
||||
issues = []
|
||||
lines = content.split('\n')
|
||||
|
||||
for i, line in enumerate(lines, 1):
|
||||
if 'allow_origins' in line and '["*"]' in line:
|
||||
issues.append({
|
||||
@@ -194,48 +194,48 @@ def check_cors_config(content: str, file_path: Path) -> list[dict]:
|
||||
|
||||
def fix_bare_excepts(content: str) -> str:
|
||||
"""修复裸异常捕获"""
|
||||
lines = content.split('\n')
|
||||
new_lines = []
|
||||
|
||||
lines = content.split('\n')
|
||||
new_lines = []
|
||||
|
||||
for line in lines:
|
||||
stripped = line.strip()
|
||||
stripped = line.strip()
|
||||
if re.match(r'^except\s*:', stripped):
|
||||
# 替换为具体异常
|
||||
indent = len(line) - len(line.lstrip())
|
||||
new_line = ' ' * indent + 'except (RuntimeError, ValueError, TypeError):'
|
||||
indent = len(line) - len(line.lstrip())
|
||||
new_line = ' ' * indent + 'except (RuntimeError, ValueError, TypeError):'
|
||||
new_lines.append(new_line)
|
||||
else:
|
||||
new_lines.append(line)
|
||||
|
||||
|
||||
return '\n'.join(new_lines)
|
||||
|
||||
def fix_line_length(content: str) -> str:
|
||||
"""修复行长度问题(简单折行)"""
|
||||
lines = content.split('\n')
|
||||
new_lines = []
|
||||
|
||||
lines = content.split('\n')
|
||||
new_lines = []
|
||||
|
||||
for line in lines:
|
||||
if len(line) > 100:
|
||||
# 尝试在逗号或运算符处折行
|
||||
if ',' in line[80:]:
|
||||
if ', ' in line[80:]:
|
||||
# 简单处理:截断并添加续行
|
||||
indent = len(line) - len(line.lstrip())
|
||||
indent = len(line) - len(line.lstrip())
|
||||
new_lines.append(line)
|
||||
else:
|
||||
new_lines.append(line)
|
||||
else:
|
||||
new_lines.append(line)
|
||||
|
||||
|
||||
return '\n'.join(new_lines)
|
||||
|
||||
def analyze_file(file_path: Path) -> dict:
|
||||
"""分析单个文件"""
|
||||
try:
|
||||
content = file_path.read_text(encoding='utf-8')
|
||||
content = file_path.read_text(encoding = 'utf-8')
|
||||
except Exception as e:
|
||||
return {"error": str(e)}
|
||||
|
||||
issues = {
|
||||
|
||||
issues = {
|
||||
"duplicate_imports": check_duplicate_imports(content, file_path),
|
||||
"bare_excepts": check_bare_excepts(content, file_path),
|
||||
"line_length": check_line_length(content, file_path),
|
||||
@@ -245,22 +245,22 @@ def analyze_file(file_path: Path) -> dict:
|
||||
"sql_injection": check_sql_injection(content, file_path),
|
||||
"cors_config": check_cors_config(content, file_path),
|
||||
}
|
||||
|
||||
|
||||
return issues
|
||||
|
||||
def fix_file(file_path: Path, issues: dict) -> bool:
|
||||
"""自动修复文件问题"""
|
||||
try:
|
||||
content = file_path.read_text(encoding='utf-8')
|
||||
original_content = content
|
||||
|
||||
content = file_path.read_text(encoding = 'utf-8')
|
||||
original_content = content
|
||||
|
||||
# 修复裸异常
|
||||
if issues.get("bare_excepts"):
|
||||
content = fix_bare_excepts(content)
|
||||
|
||||
content = fix_bare_excepts(content)
|
||||
|
||||
# 如果有修改,写回文件
|
||||
if content != original_content:
|
||||
file_path.write_text(content, encoding='utf-8')
|
||||
file_path.write_text(content, encoding = 'utf-8')
|
||||
return True
|
||||
return False
|
||||
except Exception as e:
|
||||
@@ -269,88 +269,88 @@ def fix_file(file_path: Path, issues: dict) -> bool:
|
||||
|
||||
def generate_report(all_issues: dict) -> str:
|
||||
"""生成修复报告"""
|
||||
lines = []
|
||||
lines = []
|
||||
lines.append("# InsightFlow 代码审查报告")
|
||||
lines.append(f"\n生成时间: {__import__('datetime').datetime.now().isoformat()}")
|
||||
lines.append("\n## 自动修复的问题\n")
|
||||
|
||||
total_fixed = 0
|
||||
|
||||
total_fixed = 0
|
||||
for file_path, issues in all_issues.items():
|
||||
fixed_count = 0
|
||||
fixed_count = 0
|
||||
for issue_type, issue_list in issues.items():
|
||||
if issue_type in ["bare_excepts"] and issue_list:
|
||||
fixed_count += len(issue_list)
|
||||
|
||||
|
||||
if fixed_count > 0:
|
||||
lines.append(f"### {file_path}")
|
||||
lines.append(f"- 修复裸异常捕获: {fixed_count} 处")
|
||||
total_fixed += fixed_count
|
||||
|
||||
|
||||
if total_fixed == 0:
|
||||
lines.append("未发现需要自动修复的问题。")
|
||||
|
||||
|
||||
lines.append(f"\n**总计自动修复: {total_fixed} 处**")
|
||||
|
||||
|
||||
lines.append("\n## 需要人工确认的问题\n")
|
||||
|
||||
total_manual = 0
|
||||
|
||||
total_manual = 0
|
||||
for file_path, issues in all_issues.items():
|
||||
manual_issues = []
|
||||
|
||||
manual_issues = []
|
||||
|
||||
if issues.get("sql_injection"):
|
||||
manual_issues.extend(issues["sql_injection"])
|
||||
if issues.get("cors_config"):
|
||||
manual_issues.extend(issues["cors_config"])
|
||||
|
||||
|
||||
if manual_issues:
|
||||
lines.append(f"### {file_path}")
|
||||
for issue in manual_issues:
|
||||
lines.append(f"- **{issue['type']}** (第 {issue['line']} 行): {issue.get('content', '')}")
|
||||
total_manual += len(manual_issues)
|
||||
|
||||
|
||||
if total_manual == 0:
|
||||
lines.append("未发现需要人工确认的问题。")
|
||||
|
||||
|
||||
lines.append(f"\n**总计待确认: {total_manual} 处**")
|
||||
|
||||
|
||||
lines.append("\n## 代码风格建议\n")
|
||||
|
||||
|
||||
for file_path, issues in all_issues.items():
|
||||
style_issues = []
|
||||
style_issues = []
|
||||
if issues.get("line_length"):
|
||||
style_issues.extend(issues["line_length"])
|
||||
if issues.get("string_formatting"):
|
||||
style_issues.extend(issues["string_formatting"])
|
||||
if issues.get("magic_numbers"):
|
||||
style_issues.extend(issues["magic_numbers"])
|
||||
|
||||
|
||||
if style_issues:
|
||||
lines.append(f"### {file_path}")
|
||||
for issue in style_issues[:5]: # 只显示前5个
|
||||
lines.append(f"- 第 {issue['line']} 行: {issue['type']}")
|
||||
if len(style_issues) > 5:
|
||||
lines.append(f"- ... 还有 {len(style_issues) - 5} 个类似问题")
|
||||
|
||||
|
||||
return '\n'.join(lines)
|
||||
|
||||
def git_commit_and_push():
|
||||
def git_commit_and_push() -> None:
|
||||
"""提交并推送代码"""
|
||||
try:
|
||||
os.chdir(PROJECT_PATH)
|
||||
|
||||
|
||||
# 检查是否有修改
|
||||
result = subprocess.run(
|
||||
result = subprocess.run(
|
||||
["git", "status", "--porcelain"],
|
||||
capture_output=True,
|
||||
text=True
|
||||
capture_output = True,
|
||||
text = True
|
||||
)
|
||||
|
||||
|
||||
if not result.stdout.strip():
|
||||
return "没有需要提交的更改"
|
||||
|
||||
|
||||
# 添加所有修改
|
||||
subprocess.run(["git", "add", "-A"], check=True)
|
||||
|
||||
subprocess.run(["git", "add", "-A"], check = True)
|
||||
|
||||
# 提交
|
||||
subprocess.run(
|
||||
["git", "commit", "-m", """fix: auto-fix code issues (cron)
|
||||
@@ -359,52 +359,52 @@ def git_commit_and_push():
|
||||
- 修复异常处理
|
||||
- 修复PEP8格式问题
|
||||
- 添加类型注解"""],
|
||||
check=True
|
||||
check = True
|
||||
)
|
||||
|
||||
|
||||
# 推送
|
||||
subprocess.run(["git", "push"], check=True)
|
||||
|
||||
subprocess.run(["git", "push"], check = True)
|
||||
|
||||
return "✅ 提交并推送成功"
|
||||
except subprocess.CalledProcessError as e:
|
||||
return f"❌ Git 操作失败: {e}"
|
||||
except Exception as e:
|
||||
return f"❌ 错误: {e}"
|
||||
|
||||
def main():
|
||||
def main() -> None:
|
||||
"""主函数"""
|
||||
print("🔍 开始代码审查...")
|
||||
|
||||
py_files = find_python_files()
|
||||
|
||||
py_files = find_python_files()
|
||||
print(f"📁 找到 {len(py_files)} 个 Python 文件")
|
||||
|
||||
all_issues = {}
|
||||
|
||||
|
||||
all_issues = {}
|
||||
|
||||
for py_file in py_files:
|
||||
print(f" 分析: {py_file.name}")
|
||||
issues = analyze_file(py_file)
|
||||
all_issues[py_file] = issues
|
||||
|
||||
issues = analyze_file(py_file)
|
||||
all_issues[py_file] = issues
|
||||
|
||||
# 自动修复
|
||||
if fix_file(py_file, issues):
|
||||
report["fixed"].append(str(py_file))
|
||||
|
||||
|
||||
# 生成报告
|
||||
report_content = generate_report(all_issues)
|
||||
report_path = PROJECT_PATH / "AUTO_CODE_REVIEW_REPORT.md"
|
||||
report_path.write_text(report_content, encoding='utf-8')
|
||||
|
||||
report_content = generate_report(all_issues)
|
||||
report_path = PROJECT_PATH / "AUTO_CODE_REVIEW_REPORT.md"
|
||||
report_path.write_text(report_content, encoding = 'utf-8')
|
||||
|
||||
print("\n📄 报告已生成:", report_path)
|
||||
|
||||
|
||||
# Git 提交
|
||||
print("\n🚀 提交代码...")
|
||||
git_result = git_commit_and_push()
|
||||
git_result = git_commit_and_push()
|
||||
print(git_result)
|
||||
|
||||
|
||||
# 追加提交结果到报告
|
||||
with open(report_path, 'a', encoding='utf-8') as f:
|
||||
with open(report_path, 'a', encoding = 'utf-8') as f:
|
||||
f.write(f"\n\n## Git 提交结果\n\n{git_result}\n")
|
||||
|
||||
|
||||
print("\n✅ 代码审查完成!")
|
||||
return report_content
|
||||
|
||||
|
||||
166
code_reviewer.py
166
code_reviewer.py
@@ -15,25 +15,25 @@ class CodeIssue:
|
||||
line_no: int,
|
||||
issue_type: str,
|
||||
message: str,
|
||||
severity: str = "info",
|
||||
):
|
||||
self.file_path = file_path
|
||||
self.line_no = line_no
|
||||
self.issue_type = issue_type
|
||||
self.message = message
|
||||
self.severity = severity # info, warning, error
|
||||
self.fixed = False
|
||||
severity: str = "info",
|
||||
) -> None:
|
||||
self.file_path = file_path
|
||||
self.line_no = line_no
|
||||
self.issue_type = issue_type
|
||||
self.message = message
|
||||
self.severity = severity # info, warning, error
|
||||
self.fixed = False
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> None:
|
||||
return f"{self.severity.upper()}: {self.file_path}:{self.line_no} - {self.issue_type}: {self.message}"
|
||||
|
||||
|
||||
class CodeReviewer:
|
||||
def __init__(self, base_path: str):
|
||||
self.base_path = Path(base_path)
|
||||
self.issues: list[CodeIssue] = []
|
||||
self.fixed_issues: list[CodeIssue] = []
|
||||
self.manual_review_issues: list[CodeIssue] = []
|
||||
def __init__(self, base_path: str) -> None:
|
||||
self.base_path = Path(base_path)
|
||||
self.issues: list[CodeIssue] = []
|
||||
self.fixed_issues: list[CodeIssue] = []
|
||||
self.manual_review_issues: list[CodeIssue] = []
|
||||
|
||||
def scan_all(self) -> None:
|
||||
"""扫描所有 Python 文件"""
|
||||
@@ -45,14 +45,14 @@ class CodeReviewer:
|
||||
def scan_file(self, file_path: Path) -> None:
|
||||
"""扫描单个文件"""
|
||||
try:
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
lines = content.split("\n")
|
||||
with open(file_path, "r", encoding = "utf-8") as f:
|
||||
content = f.read()
|
||||
lines = content.split("\n")
|
||||
except Exception as e:
|
||||
print(f"Error reading {file_path}: {e}")
|
||||
return
|
||||
|
||||
rel_path = str(file_path.relative_to(self.base_path))
|
||||
rel_path = str(file_path.relative_to(self.base_path))
|
||||
|
||||
# 1. 检查裸异常捕获
|
||||
self._check_bare_exceptions(content, lines, rel_path)
|
||||
@@ -92,7 +92,7 @@ class CodeReviewer:
|
||||
# 跳过有注释说明的情况
|
||||
if "# noqa" in line or "# intentional" in line.lower():
|
||||
continue
|
||||
issue = CodeIssue(
|
||||
issue = CodeIssue(
|
||||
file_path,
|
||||
i,
|
||||
"bare_exception",
|
||||
@@ -105,17 +105,17 @@ class CodeReviewer:
|
||||
self, content: str, lines: list[str], file_path: str
|
||||
) -> None:
|
||||
"""检查重复导入"""
|
||||
imports = {}
|
||||
imports = {}
|
||||
for i, line in enumerate(lines, 1):
|
||||
match = re.match(r"^(?:from\s+(\S+)\s+)?import\s+(.+)$", line.strip())
|
||||
match = re.match(r"^(?:from\s+(\S+)\s+)?import\s+(.+)$", line.strip())
|
||||
if match:
|
||||
module = match.group(1) or ""
|
||||
names = match.group(2).split(",")
|
||||
module = match.group(1) or ""
|
||||
names = match.group(2).split(", ")
|
||||
for name in names:
|
||||
name = name.strip().split()[0] # 处理 'as' 别名
|
||||
key = f"{module}.{name}" if module else name
|
||||
name = name.strip().split()[0] # 处理 'as' 别名
|
||||
key = f"{module}.{name}" if module else name
|
||||
if key in imports:
|
||||
issue = CodeIssue(
|
||||
issue = CodeIssue(
|
||||
file_path,
|
||||
i,
|
||||
"duplicate_import",
|
||||
@@ -123,7 +123,7 @@ class CodeReviewer:
|
||||
"warning",
|
||||
)
|
||||
self.issues.append(issue)
|
||||
imports[key] = i
|
||||
imports[key] = i
|
||||
|
||||
def _check_pep8_issues(
|
||||
self, content: str, lines: list[str], file_path: str
|
||||
@@ -132,7 +132,7 @@ class CodeReviewer:
|
||||
for i, line in enumerate(lines, 1):
|
||||
# 行长度超过 120
|
||||
if len(line) > 120:
|
||||
issue = CodeIssue(
|
||||
issue = CodeIssue(
|
||||
file_path,
|
||||
i,
|
||||
"line_too_long",
|
||||
@@ -143,7 +143,7 @@ class CodeReviewer:
|
||||
|
||||
# 行尾空格
|
||||
if line.rstrip() != line:
|
||||
issue = CodeIssue(
|
||||
issue = CodeIssue(
|
||||
file_path, i, "trailing_whitespace", "行尾有空格", "info"
|
||||
)
|
||||
self.issues.append(issue)
|
||||
@@ -151,7 +151,7 @@ class CodeReviewer:
|
||||
# 多余的空行
|
||||
if i > 1 and line.strip() == "" and lines[i - 2].strip() == "":
|
||||
if i < len(lines) and lines[i].strip() == "":
|
||||
issue = CodeIssue(
|
||||
issue = CodeIssue(
|
||||
file_path, i, "extra_blank_line", "多余的空行", "info"
|
||||
)
|
||||
self.issues.append(issue)
|
||||
@@ -161,23 +161,23 @@ class CodeReviewer:
|
||||
) -> None:
|
||||
"""检查未使用的导入"""
|
||||
try:
|
||||
tree = ast.parse(content)
|
||||
tree = ast.parse(content)
|
||||
except SyntaxError:
|
||||
return
|
||||
|
||||
imported_names = {}
|
||||
used_names = set()
|
||||
imported_names = {}
|
||||
used_names = set()
|
||||
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, ast.Import):
|
||||
for alias in node.names:
|
||||
name = alias.asname if alias.asname else alias.name
|
||||
imported_names[name] = node.lineno
|
||||
name = alias.asname if alias.asname else alias.name
|
||||
imported_names[name] = node.lineno
|
||||
elif isinstance(node, ast.ImportFrom):
|
||||
for alias in node.names:
|
||||
name = alias.asname if alias.asname else alias.name
|
||||
name = alias.asname if alias.asname else alias.name
|
||||
if name != "*":
|
||||
imported_names[name] = node.lineno
|
||||
imported_names[name] = node.lineno
|
||||
elif isinstance(node, ast.Name):
|
||||
used_names.add(node.id)
|
||||
|
||||
@@ -186,7 +186,7 @@ class CodeReviewer:
|
||||
# 排除一些常见例外
|
||||
if name in ["annotations", "TYPE_CHECKING"]:
|
||||
continue
|
||||
issue = CodeIssue(
|
||||
issue = CodeIssue(
|
||||
file_path, lineno, "unused_import", f"未使用的导入: {name}", "info"
|
||||
)
|
||||
self.issues.append(issue)
|
||||
@@ -195,20 +195,20 @@ class CodeReviewer:
|
||||
self, content: str, lines: list[str], file_path: str
|
||||
) -> None:
|
||||
"""检查混合字符串格式化"""
|
||||
has_fstring = False
|
||||
has_percent = False
|
||||
has_format = False
|
||||
has_fstring = False
|
||||
has_percent = False
|
||||
has_format = False
|
||||
|
||||
for i, line in enumerate(lines, 1):
|
||||
if re.search(r'f["\']', line):
|
||||
has_fstring = True
|
||||
has_fstring = True
|
||||
if re.search(r"%[sdfr]", line) and not re.search(r"\d+%", line):
|
||||
has_percent = True
|
||||
has_percent = True
|
||||
if ".format(" in line:
|
||||
has_format = True
|
||||
has_format = True
|
||||
|
||||
if has_fstring and (has_percent or has_format):
|
||||
issue = CodeIssue(
|
||||
issue = CodeIssue(
|
||||
file_path,
|
||||
0,
|
||||
"mixed_formatting",
|
||||
@@ -222,25 +222,25 @@ class CodeReviewer:
|
||||
) -> None:
|
||||
"""检查魔法数字"""
|
||||
# 常见的魔法数字模式
|
||||
magic_patterns = [
|
||||
(r"=\s*(\d{3,})\s*[^:]", "可能的魔法数字"),
|
||||
(r"timeout\s*=\s*(\d+)", "timeout 魔法数字"),
|
||||
(r"limit\s*=\s*(\d+)", "limit 魔法数字"),
|
||||
(r"port\s*=\s*(\d+)", "port 魔法数字"),
|
||||
magic_patterns = [
|
||||
(r" = \s*(\d{3, })\s*[^:]", "可能的魔法数字"),
|
||||
(r"timeout\s* = \s*(\d+)", "timeout 魔法数字"),
|
||||
(r"limit\s* = \s*(\d+)", "limit 魔法数字"),
|
||||
(r"port\s* = \s*(\d+)", "port 魔法数字"),
|
||||
]
|
||||
|
||||
for i, line in enumerate(lines, 1):
|
||||
# 跳过注释和字符串
|
||||
code_part = line.split("#")[0]
|
||||
code_part = line.split("#")[0]
|
||||
if not code_part.strip():
|
||||
continue
|
||||
|
||||
for pattern, msg in magic_patterns:
|
||||
if re.search(pattern, code_part, re.IGNORECASE):
|
||||
# 排除常见的合理数字
|
||||
match = re.search(r"(\d{3,})", code_part)
|
||||
match = re.search(r"(\d{3, })", code_part)
|
||||
if match:
|
||||
num = int(match.group(1))
|
||||
num = int(match.group(1))
|
||||
if num in [
|
||||
200,
|
||||
404,
|
||||
@@ -257,7 +257,7 @@ class CodeReviewer:
|
||||
8000,
|
||||
]:
|
||||
continue
|
||||
issue = CodeIssue(
|
||||
issue = CodeIssue(
|
||||
file_path, i, "magic_number", f"{msg}: {num}", "info"
|
||||
)
|
||||
self.issues.append(issue)
|
||||
@@ -272,7 +272,7 @@ class CodeReviewer:
|
||||
r'execute\s*\(\s*f["\']', line
|
||||
):
|
||||
if "?" not in line and "%s" in line:
|
||||
issue = CodeIssue(
|
||||
issue = CodeIssue(
|
||||
file_path,
|
||||
i,
|
||||
"sql_injection_risk",
|
||||
@@ -287,7 +287,7 @@ class CodeReviewer:
|
||||
"""检查 CORS 配置"""
|
||||
for i, line in enumerate(lines, 1):
|
||||
if "allow_origins" in line and '["*"]' in line:
|
||||
issue = CodeIssue(
|
||||
issue = CodeIssue(
|
||||
file_path,
|
||||
i,
|
||||
"cors_wildcard",
|
||||
@@ -303,7 +303,7 @@ class CodeReviewer:
|
||||
for i, line in enumerate(lines, 1):
|
||||
# 检查硬编码密钥
|
||||
if re.search(
|
||||
r'(password|secret|key|token)\s*=\s*["\'][^"\']+["\']',
|
||||
r'(password|secret|key|token)\s* = \s*["\'][^"\']+["\']',
|
||||
line,
|
||||
re.IGNORECASE,
|
||||
):
|
||||
@@ -316,7 +316,7 @@ class CodeReviewer:
|
||||
if not re.search(r'["\']\*+["\']', line) and not re.search(
|
||||
r'["\']<[^"\']*>["\']', line
|
||||
):
|
||||
issue = CodeIssue(
|
||||
issue = CodeIssue(
|
||||
file_path,
|
||||
i,
|
||||
"hardcoded_secret",
|
||||
@@ -328,62 +328,62 @@ class CodeReviewer:
|
||||
def auto_fix(self) -> None:
|
||||
"""自动修复问题"""
|
||||
# 按文件分组问题
|
||||
issues_by_file: dict[str, list[CodeIssue]] = {}
|
||||
issues_by_file: dict[str, list[CodeIssue]] = {}
|
||||
for issue in self.issues:
|
||||
if issue.file_path not in issues_by_file:
|
||||
issues_by_file[issue.file_path] = []
|
||||
issues_by_file[issue.file_path] = []
|
||||
issues_by_file[issue.file_path].append(issue)
|
||||
|
||||
for file_path, issues in issues_by_file.items():
|
||||
full_path = self.base_path / file_path
|
||||
full_path = self.base_path / file_path
|
||||
if not full_path.exists():
|
||||
continue
|
||||
|
||||
try:
|
||||
with open(full_path, "r", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
lines = content.split("\n")
|
||||
with open(full_path, "r", encoding = "utf-8") as f:
|
||||
content = f.read()
|
||||
lines = content.split("\n")
|
||||
except Exception as e:
|
||||
print(f"Error reading {full_path}: {e}")
|
||||
continue
|
||||
|
||||
original_lines = lines.copy()
|
||||
original_lines = lines.copy()
|
||||
|
||||
# 修复行尾空格
|
||||
for issue in issues:
|
||||
if issue.issue_type == "trailing_whitespace":
|
||||
idx = issue.line_no - 1
|
||||
idx = issue.line_no - 1
|
||||
if 0 <= idx < len(lines):
|
||||
lines[idx] = lines[idx].rstrip()
|
||||
issue.fixed = True
|
||||
lines[idx] = lines[idx].rstrip()
|
||||
issue.fixed = True
|
||||
|
||||
# 修复裸异常
|
||||
for issue in issues:
|
||||
if issue.issue_type == "bare_exception":
|
||||
idx = issue.line_no - 1
|
||||
idx = issue.line_no - 1
|
||||
if 0 <= idx < len(lines):
|
||||
line = lines[idx]
|
||||
# 将 except: 改为 except Exception:
|
||||
line = lines[idx]
|
||||
# 将 except Exception: 改为 except Exception:
|
||||
if re.search(r"except\s*:\s*$", line.strip()):
|
||||
lines[idx] = line.replace("except:", "except Exception:")
|
||||
issue.fixed = True
|
||||
lines[idx] = line.replace("except Exception:", "except Exception:")
|
||||
issue.fixed = True
|
||||
elif re.search(r"except\s+Exception\s*:\s*$", line.strip()):
|
||||
# 已经是 Exception,但可能需要更具体
|
||||
pass
|
||||
|
||||
# 如果文件有修改,写回
|
||||
if lines != original_lines:
|
||||
with open(full_path, "w", encoding="utf-8") as f:
|
||||
with open(full_path, "w", encoding = "utf-8") as f:
|
||||
f.write("\n".join(lines))
|
||||
print(f"Fixed issues in {file_path}")
|
||||
|
||||
# 移动到已修复列表
|
||||
self.fixed_issues = [i for i in self.issues if i.fixed]
|
||||
self.issues = [i for i in self.issues if not i.fixed]
|
||||
self.fixed_issues = [i for i in self.issues if i.fixed]
|
||||
self.issues = [i for i in self.issues if not i.fixed]
|
||||
|
||||
def generate_report(self) -> str:
|
||||
"""生成审查报告"""
|
||||
report = []
|
||||
report = []
|
||||
report.append("# InsightFlow 代码审查报告")
|
||||
report.append(f"\n扫描路径: {self.base_path}")
|
||||
report.append(f"扫描时间: {__import__('datetime').datetime.now().isoformat()}")
|
||||
@@ -421,9 +421,9 @@ class CodeReviewer:
|
||||
return "\n".join(report)
|
||||
|
||||
|
||||
def main():
|
||||
base_path = "/root/.openclaw/workspace/projects/insightflow/backend"
|
||||
reviewer = CodeReviewer(base_path)
|
||||
def main() -> None:
|
||||
base_path = "/root/.openclaw/workspace/projects/insightflow/backend"
|
||||
reviewer = CodeReviewer(base_path)
|
||||
|
||||
print("开始扫描代码...")
|
||||
reviewer.scan_all()
|
||||
@@ -437,9 +437,9 @@ def main():
|
||||
print(f"\n已修复 {len(reviewer.fixed_issues)} 个问题")
|
||||
|
||||
# 生成报告
|
||||
report = reviewer.generate_report()
|
||||
report_path = Path(base_path).parent / "CODE_REVIEW_REPORT.md"
|
||||
with open(report_path, "w", encoding="utf-8") as f:
|
||||
report = reviewer.generate_report()
|
||||
report_path = Path(base_path).parent / "CODE_REVIEW_REPORT.md"
|
||||
with open(report_path, "w", encoding = "utf-8") as f:
|
||||
f.write(report)
|
||||
print(f"\n报告已保存到: {report_path}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user