fix: auto-fix code issues (cron)

- 修复重复导入/字段
- 修复异常处理
- 修复PEP8格式问题
- 修复语法错误(运算符空格问题)
- 修复类型注解格式
This commit is contained in:
AutoFix Bot
2026-03-02 06:09:49 +08:00
parent b83265e5fd
commit e23f1fec08
84 changed files with 9492 additions and 9491 deletions

View File

@@ -224,4 +224,8 @@
- 第 490 行: line_too_long
- 第 541 行: line_too_long
- 第 579 行: line_too_long
- ... 还有 2 个类似问题
- ... 还有 2 个类似问题
## Git 提交结果
✅ 提交并推送成功

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -19,30 +19,30 @@ class CodeIssue:
line_no: int,
issue_type: str,
message: str,
severity: str = "warning",
original_line: str = "",
):
self.file_path = file_path
self.line_no = line_no
self.issue_type = issue_type
self.message = message
self.severity = severity
self.original_line = original_line
self.fixed = False
severity: str = "warning",
original_line: str = "",
) -> None:
self.file_path = file_path
self.line_no = line_no
self.issue_type = issue_type
self.message = message
self.severity = severity
self.original_line = original_line
self.fixed = False
def __repr__(self):
def __repr__(self) -> None:
return f"{self.file_path}:{self.line_no} [{self.severity}] {self.issue_type}: {self.message}"
class CodeFixer:
"""代码自动修复器"""
def __init__(self, project_path: str):
self.project_path = Path(project_path)
self.issues: list[CodeIssue] = []
self.fixed_issues: list[CodeIssue] = []
self.manual_issues: list[CodeIssue] = []
self.scanned_files: list[str] = []
def __init__(self, project_path: str) -> None:
self.project_path = Path(project_path)
self.issues: list[CodeIssue] = []
self.fixed_issues: list[CodeIssue] = []
self.manual_issues: list[CodeIssue] = []
self.scanned_files: list[str] = []
def scan_all_files(self) -> None:
"""扫描所有 Python 文件"""
@@ -55,9 +55,9 @@ class CodeFixer:
def _scan_file(self, file_path: Path) -> None:
"""扫描单个文件"""
try:
with open(file_path, "r", encoding="utf-8") as f:
content = f.read()
lines = content.split("\n")
with open(file_path, "r", encoding = "utf-8") as f:
content = f.read()
lines = content.split("\n")
except Exception as e:
print(f"Error reading {file_path}: {e}")
return
@@ -85,7 +85,7 @@ class CodeFixer:
) -> None:
"""检查裸异常捕获"""
for i, line in enumerate(lines, 1):
# 匹配 except: 但不匹配 except Exception: 或 except SpecificError:
# 匹配 except Exception: 但不匹配 except Exception: 或 except SpecificError:
if re.search(r"except\s*:\s*$", line) or re.search(r"except\s*:\s*#", line):
# 跳过注释说明的情况
if "# noqa" in line or "# intentional" in line.lower():
@@ -130,25 +130,25 @@ class CodeFixer:
def _check_unused_imports(self, file_path: Path, content: str) -> None:
"""检查未使用的导入"""
try:
tree = ast.parse(content)
tree = ast.parse(content)
except SyntaxError:
return
imports = {}
imports = {}
for node in ast.walk(tree):
if isinstance(node, ast.Import):
for alias in node.names:
name = alias.asname if alias.asname else alias.name
imports[name] = node.lineno
name = alias.asname if alias.asname else alias.name
imports[name] = node.lineno
elif isinstance(node, ast.ImportFrom):
for alias in node.names:
name = alias.asname if alias.asname else alias.name
name = alias.asname if alias.asname else alias.name
if alias.name == "*":
continue
imports[name] = node.lineno
imports[name] = node.lineno
# 检查使用
used_names = set()
used_names = set()
for node in ast.walk(tree):
if isinstance(node, ast.Name):
used_names.add(node.id)
@@ -216,15 +216,15 @@ class CodeFixer:
) -> None:
"""检查敏感信息泄露"""
# 排除的文件
excluded_files = ["auto_code_fixer.py", "code_reviewer.py"]
excluded_files = ["auto_code_fixer.py", "code_reviewer.py"]
if any(excluded in str(file_path) for excluded in excluded_files):
return
patterns = [
(r'password\s*=\s*["\'][^"\']{8,}["\']', "硬编码密码"),
(r'secret_key\s*=\s*["\'][^"\']{8,}["\']', "硬编码密钥"),
(r'api_key\s*=\s*["\'][^"\']{8,}["\']', "硬编码 API Key"),
(r'token\s*=\s*["\'][^"\']{8,}["\']', "硬编码 Token"),
patterns = [
(r'password\s* = \s*["\'][^"\']{8, }["\']', "硬编码密码"),
(r'secret_key\s* = \s*["\'][^"\']{8, }["\']', "硬编码密钥"),
(r'api_key\s* = \s*["\'][^"\']{8, }["\']', "硬编码 API Key"),
(r'token\s* = \s*["\'][^"\']{8, }["\']', "硬编码 Token"),
]
for i, line in enumerate(lines, 1):
@@ -241,7 +241,7 @@ class CodeFixer:
if any(x in line.lower() for x in ["your_", "example", "placeholder", "test", "demo"]):
continue
# 排除 Enum 定义
if re.search(r'^\s*[A-Z_]+\s*=', line.strip()):
if re.search(r'^\s*[A-Z_]+\s* = ', line.strip()):
continue
self.manual_issues.append(
CodeIssue(
@@ -256,17 +256,17 @@ class CodeFixer:
def fix_auto_fixable(self) -> None:
"""自动修复可修复的问题"""
auto_fix_types = {
auto_fix_types = {
"trailing_whitespace",
"bare_exception",
}
# 按文件分组
files_to_fix = {}
files_to_fix = {}
for issue in self.issues:
if issue.issue_type in auto_fix_types:
if issue.file_path not in files_to_fix:
files_to_fix[issue.file_path] = []
files_to_fix[issue.file_path] = []
files_to_fix[issue.file_path].append(issue)
for file_path, file_issues in files_to_fix.items():
@@ -275,43 +275,43 @@ class CodeFixer:
continue
try:
with open(file_path, "r", encoding="utf-8") as f:
content = f.read()
lines = content.split("\n")
with open(file_path, "r", encoding = "utf-8") as f:
content = f.read()
lines = content.split("\n")
except Exception:
continue
original_lines = lines.copy()
fixed_lines = set()
original_lines = lines.copy()
fixed_lines = set()
# 修复行尾空格
for issue in file_issues:
if issue.issue_type == "trailing_whitespace":
line_idx = issue.line_no - 1
line_idx = issue.line_no - 1
if 0 <= line_idx < len(lines) and line_idx not in fixed_lines:
if lines[line_idx].rstrip() != lines[line_idx]:
lines[line_idx] = lines[line_idx].rstrip()
lines[line_idx] = lines[line_idx].rstrip()
fixed_lines.add(line_idx)
issue.fixed = True
issue.fixed = True
self.fixed_issues.append(issue)
# 修复裸异常
for issue in file_issues:
if issue.issue_type == "bare_exception":
line_idx = issue.line_no - 1
line_idx = issue.line_no - 1
if 0 <= line_idx < len(lines) and line_idx not in fixed_lines:
line = lines[line_idx]
# 将 except: 改为 except Exception:
line = lines[line_idx]
# 将 except Exception: 改为 except Exception:
if re.search(r"except\s*:\s*$", line.strip()):
lines[line_idx] = line.replace("except:", "except Exception:")
lines[line_idx] = line.replace("except Exception:", "except Exception:")
fixed_lines.add(line_idx)
issue.fixed = True
issue.fixed = True
self.fixed_issues.append(issue)
# 如果文件有修改,写回
if lines != original_lines:
try:
with open(file_path, "w", encoding="utf-8") as f:
with open(file_path, "w", encoding = "utf-8") as f:
f.write("\n".join(lines))
print(f"Fixed issues in {file_path}")
except Exception as e:
@@ -319,7 +319,7 @@ class CodeFixer:
def categorize_issues(self) -> dict[str, list[CodeIssue]]:
"""分类问题"""
categories = {
categories = {
"critical": [],
"error": [],
"warning": [],
@@ -334,7 +334,7 @@ class CodeFixer:
def generate_report(self) -> str:
"""生成修复报告"""
report = []
report = []
report.append("# InsightFlow 代码审查报告")
report.append("")
report.append(f"扫描时间: {os.popen('date').read().strip()}")
@@ -349,9 +349,9 @@ class CodeFixer:
report.append("")
# 问题统计
categories = self.categorize_issues()
manual_critical = [i for i in self.manual_issues if i.severity == "critical"]
manual_warning = [i for i in self.manual_issues if i.severity == "warning"]
categories = self.categorize_issues()
manual_critical = [i for i in self.manual_issues if i.severity == "critical"]
manual_warning = [i for i in self.manual_issues if i.severity == "warning"]
report.append("## 问题分类统计")
report.append("")
@@ -393,17 +393,17 @@ class CodeFixer:
# 其他问题
report.append("## 📋 其他发现的问题")
report.append("")
other_issues = [
other_issues = [
i
for i in self.issues
if i not in self.fixed_issues
]
# 按类型分组
by_type = {}
by_type = {}
for issue in other_issues:
if issue.issue_type not in by_type:
by_type[issue.issue_type] = []
by_type[issue.issue_type] = []
by_type[issue.issue_type].append(issue)
for issue_type, issues in sorted(by_type.items()):
@@ -424,21 +424,21 @@ def git_commit_and_push(project_path: str) -> tuple[bool, str]:
"""Git 提交和推送"""
try:
# 检查是否有变更
result = subprocess.run(
result = subprocess.run(
["git", "status", "--porcelain"],
cwd=project_path,
capture_output=True,
text=True,
cwd = project_path,
capture_output = True,
text = True,
)
if not result.stdout.strip():
return True, "没有需要提交的变更"
# 添加所有变更
subprocess.run(["git", "add", "-A"], cwd=project_path, check=True)
subprocess.run(["git", "add", "-A"], cwd = project_path, check = True)
# 提交
commit_msg = """fix: auto-fix code issues (cron)
commit_msg = """fix: auto-fix code issues (cron)
- 修复重复导入/字段
- 修复异常处理
@@ -446,11 +446,11 @@ def git_commit_and_push(project_path: str) -> tuple[bool, str]:
- 添加类型注解"""
subprocess.run(
["git", "commit", "-m", commit_msg], cwd=project_path, check=True
["git", "commit", "-m", commit_msg], cwd = project_path, check = True
)
# 推送
subprocess.run(["git", "push"], cwd=project_path, check=True)
subprocess.run(["git", "push"], cwd = project_path, check = True)
return True, "提交并推送成功"
except subprocess.CalledProcessError as e:
@@ -459,11 +459,11 @@ def git_commit_and_push(project_path: str) -> tuple[bool, str]:
return False, f"Git 操作异常: {e}"
def main():
project_path = "/root/.openclaw/workspace/projects/insightflow"
def main() -> None:
project_path = "/root/.openclaw/workspace/projects/insightflow"
print("🔍 开始扫描代码...")
fixer = CodeFixer(project_path)
fixer = CodeFixer(project_path)
fixer.scan_all_files()
print(f"📊 发现 {len(fixer.issues)} 个可自动修复问题")
@@ -475,30 +475,30 @@ def main():
print(f"✅ 已修复 {len(fixer.fixed_issues)} 个问题")
# 生成报告
report = fixer.generate_report()
report = fixer.generate_report()
# 保存报告
report_path = Path(project_path) / "AUTO_CODE_REVIEW_REPORT.md"
with open(report_path, "w", encoding="utf-8") as f:
report_path = Path(project_path) / "AUTO_CODE_REVIEW_REPORT.md"
with open(report_path, "w", encoding = "utf-8") as f:
f.write(report)
print(f"📝 报告已保存到: {report_path}")
# Git 提交
print("📤 提交变更到 Git...")
success, msg = git_commit_and_push(project_path)
success, msg = git_commit_and_push(project_path)
print(f"{'' if success else ''} {msg}")
# 添加 Git 结果到报告
report += f"\n\n## Git 提交结果\n\n{'' if success else ''} {msg}\n"
# 重新保存完整报告
with open(report_path, "w", encoding="utf-8") as f:
with open(report_path, "w", encoding = "utf-8") as f:
f.write(report)
print("\n" + "=" * 60)
print("\n" + " = " * 60)
print(report)
print("=" * 60)
print(" = " * 60)
return report

View File

@@ -14,11 +14,11 @@ from pathlib import Path
def run_ruff_check(directory: str) -> list[dict]:
"""运行 ruff 检查并返回问题列表"""
try:
result = subprocess.run(
["ruff", "check", "--select=E,W,F,I", "--output-format=json", directory],
capture_output=True,
text=True,
check=False,
result = subprocess.run(
["ruff", "check", "--select = E, W, F, I", "--output-format = json", directory],
capture_output = True,
text = True,
check = False,
)
if result.stdout:
return json.loads(result.stdout)
@@ -29,18 +29,18 @@ def run_ruff_check(directory: str) -> list[dict]:
def fix_bare_except(content: str) -> str:
"""修复裸异常捕获 - 将 bare except: 改为 except Exception:"""
pattern = r'except\s*:\s*\n'
replacement = 'except Exception:\n'
"""修复裸异常捕获 - 将 bare except Exception: 改为 except Exception:"""
pattern = r'except\s*:\s*\n'
replacement = 'except Exception:\n'
return re.sub(pattern, replacement, content)
def fix_undefined_names(content: str, filepath: str) -> str:
"""修复未定义的名称"""
lines = content.split('\n')
modified = False
import_map = {
lines = content.split('\n')
modified = False
import_map = {
'ExportEntity': 'from export_manager import ExportEntity',
'ExportRelation': 'from export_manager import ExportRelation',
'ExportTranscript': 'from export_manager import ExportTranscript',
@@ -49,23 +49,23 @@ def fix_undefined_names(content: str, filepath: str) -> str:
'OpsManager': 'from ops_manager import OpsManager',
'urllib': 'import urllib.parse',
}
undefined_names = set()
undefined_names = set()
for name, import_stmt in import_map.items():
if name in content and import_stmt not in content:
undefined_names.add((name, import_stmt))
if undefined_names:
import_idx = 0
import_idx = 0
for i, line in enumerate(lines):
if line.startswith('import ') or line.startswith('from '):
import_idx = i + 1
import_idx = i + 1
for name, import_stmt in sorted(undefined_names):
lines.insert(import_idx, import_stmt)
import_idx += 1
modified = True
modified = True
if modified:
return '\n'.join(lines)
return content
@@ -73,100 +73,100 @@ def fix_undefined_names(content: str, filepath: str) -> str:
def fix_file(filepath: str, issues: list[dict]) -> tuple[bool, list[str], list[str]]:
"""修复单个文件的问题"""
with open(filepath, 'r', encoding='utf-8') as f:
original_content = f.read()
content = original_content
fixed_issues = []
manual_fix_needed = []
with open(filepath, 'r', encoding = 'utf-8') as f:
original_content = f.read()
content = original_content
fixed_issues = []
manual_fix_needed = []
for issue in issues:
code = issue.get('code', '')
message = issue.get('message', '')
line_num = issue['location']['row']
code = issue.get('code', '')
message = issue.get('message', '')
line_num = issue['location']['row']
if code == 'F821':
content = fix_undefined_names(content, filepath)
content = fix_undefined_names(content, filepath)
if content != original_content:
fixed_issues.append(f"F821 - {message} (line {line_num})")
else:
manual_fix_needed.append(f"F821 - {message} (line {line_num})")
elif code == 'E501':
manual_fix_needed.append(f"E501 (line {line_num})")
content = fix_bare_except(content)
content = fix_bare_except(content)
if content != original_content:
with open(filepath, 'w', encoding='utf-8') as f:
with open(filepath, 'w', encoding = 'utf-8') as f:
f.write(content)
return True, fixed_issues, manual_fix_needed
return False, fixed_issues, manual_fix_needed
def main():
base_dir = Path("/root/.openclaw/workspace/projects/insightflow")
backend_dir = base_dir / "backend"
print("=" * 60)
def main() -> None:
base_dir = Path("/root/.openclaw/workspace/projects/insightflow")
backend_dir = base_dir / "backend"
print(" = " * 60)
print("InsightFlow 代码自动修复")
print("=" * 60)
print(" = " * 60)
print("\n1. 扫描代码问题...")
issues = run_ruff_check(str(backend_dir))
issues_by_file = {}
issues = run_ruff_check(str(backend_dir))
issues_by_file = {}
for issue in issues:
filepath = issue.get('filename', '')
filepath = issue.get('filename', '')
if filepath not in issues_by_file:
issues_by_file[filepath] = []
issues_by_file[filepath] = []
issues_by_file[filepath].append(issue)
print(f" 发现 {len(issues)} 个问题,分布在 {len(issues_by_file)} 个文件中")
issue_types = {}
issue_types = {}
for issue in issues:
code = issue.get('code', 'UNKNOWN')
issue_types[code] = issue_types.get(code, 0) + 1
code = issue.get('code', 'UNKNOWN')
issue_types[code] = issue_types.get(code, 0) + 1
print("\n2. 问题类型统计:")
for code, count in sorted(issue_types.items(), key=lambda x: -x[1]):
for code, count in sorted(issue_types.items(), key = lambda x: -x[1]):
print(f" - {code}: {count}")
print("\n3. 尝试自动修复...")
fixed_files = []
all_fixed_issues = []
all_manual_fixes = []
fixed_files = []
all_fixed_issues = []
all_manual_fixes = []
for filepath, file_issues in issues_by_file.items():
if not os.path.exists(filepath):
continue
modified, fixed, manual = fix_file(filepath, file_issues)
modified, fixed, manual = fix_file(filepath, file_issues)
if modified:
fixed_files.append(filepath)
all_fixed_issues.extend(fixed)
all_manual_fixes.extend([(filepath, m) for m in manual])
print(f" 直接修改了 {len(fixed_files)} 个文件")
print(f" 自动修复了 {len(all_fixed_issues)} 个问题")
print("\n4. 运行 ruff 自动格式化...")
try:
subprocess.run(
["ruff", "format", str(backend_dir)],
capture_output=True,
check=False,
capture_output = True,
check = False,
)
print(" 格式化完成")
except Exception as e:
print(f" 格式化失败: {e}")
print("\n5. 再次检查...")
remaining_issues = run_ruff_check(str(backend_dir))
remaining_issues = run_ruff_check(str(backend_dir))
print(f" 剩余 {len(remaining_issues)} 个问题需要手动处理")
report = {
report = {
'total_issues': len(issues),
'fixed_files': len(fixed_files),
'fixed_issues': len(all_fixed_issues),
@@ -174,15 +174,15 @@ def main():
'issue_types': issue_types,
'manual_fix_needed': all_manual_fixes[:30],
}
return report
if __name__ == "__main__":
report = main()
print("\n" + "=" * 60)
report = main()
print("\n" + " = " * 60)
print("修复报告")
print("=" * 60)
print(" = " * 60)
print(f"总问题数: {report['total_issues']}")
print(f"修复文件数: {report['fixed_files']}")
print(f"自动修复问题数: {report['fixed_issues']}")

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

File diff suppressed because it is too large Load Diff

View File

@@ -13,13 +13,13 @@ from dataclasses import dataclass
from datetime import datetime, timedelta
from enum import Enum
DB_PATH = os.getenv("DB_PATH", "/app/data/insightflow.db")
DB_PATH = os.getenv("DB_PATH", "/app/data/insightflow.db")
class ApiKeyStatus(Enum):
ACTIVE = "active"
REVOKED = "revoked"
EXPIRED = "expired"
ACTIVE = "active"
REVOKED = "revoked"
EXPIRED = "expired"
@dataclass
@@ -37,18 +37,18 @@ class ApiKey:
last_used_at: str | None
revoked_at: str | None
revoked_reason: str | None
total_calls: int = 0
total_calls: int = 0
class ApiKeyManager:
"""API Key 管理器"""
# Key 前缀
KEY_PREFIX = "ak_live_"
KEY_LENGTH = 48 # 总长度: 前缀(8) + 随机部分(40)
KEY_PREFIX = "ak_live_"
KEY_LENGTH = 48 # 总长度: 前缀(8) + 随机部分(40)
def __init__(self, db_path: str = DB_PATH) -> None:
self.db_path = db_path
def __init__(self, db_path: str = DB_PATH) -> None:
self.db_path = db_path
self._init_db()
def _init_db(self) -> None:
@@ -117,7 +117,7 @@ class ApiKeyManager:
def _generate_key(self) -> str:
"""生成新的 API Key"""
# 生成 40 字符的随机字符串
random_part = secrets.token_urlsafe(30)[:40]
random_part = secrets.token_urlsafe(30)[:40]
return f"{self.KEY_PREFIX}{random_part}"
def _hash_key(self, key: str) -> str:
@@ -131,10 +131,10 @@ class ApiKeyManager:
def create_key(
self,
name: str,
owner_id: str | None = None,
permissions: list[str] = None,
rate_limit: int = 60,
expires_days: int | None = None,
owner_id: str | None = None,
permissions: list[str] = None,
rate_limit: int = 60,
expires_days: int | None = None,
) -> tuple[str, ApiKey]:
"""
创建新的 API Key
@@ -143,32 +143,32 @@ class ApiKeyManager:
tuple: (原始key仅返回一次, ApiKey对象)
"""
if permissions is None:
permissions = ["read"]
permissions = ["read"]
key_id = secrets.token_hex(16)
raw_key = self._generate_key()
key_hash = self._hash_key(raw_key)
key_preview = self._get_preview(raw_key)
key_id = secrets.token_hex(16)
raw_key = self._generate_key()
key_hash = self._hash_key(raw_key)
key_preview = self._get_preview(raw_key)
expires_at = None
expires_at = None
if expires_days:
expires_at = (datetime.now() + timedelta(days=expires_days)).isoformat()
expires_at = (datetime.now() + timedelta(days = expires_days)).isoformat()
api_key = ApiKey(
id=key_id,
key_hash=key_hash,
key_preview=key_preview,
name=name,
owner_id=owner_id,
permissions=permissions,
rate_limit=rate_limit,
status=ApiKeyStatus.ACTIVE.value,
created_at=datetime.now().isoformat(),
expires_at=expires_at,
last_used_at=None,
revoked_at=None,
revoked_reason=None,
total_calls=0,
api_key = ApiKey(
id = key_id,
key_hash = key_hash,
key_preview = key_preview,
name = name,
owner_id = owner_id,
permissions = permissions,
rate_limit = rate_limit,
status = ApiKeyStatus.ACTIVE.value,
created_at = datetime.now().isoformat(),
expires_at = expires_at,
last_used_at = None,
revoked_at = None,
revoked_reason = None,
total_calls = 0,
)
with sqlite3.connect(self.db_path) as conn:
@@ -203,16 +203,16 @@ class ApiKeyManager:
Returns:
ApiKey if valid, None otherwise
"""
key_hash = self._hash_key(key)
key_hash = self._hash_key(key)
with sqlite3.connect(self.db_path) as conn:
conn.row_factory = sqlite3.Row
row = conn.execute("SELECT * FROM api_keys WHERE key_hash = ?", (key_hash,)).fetchone()
conn.row_factory = sqlite3.Row
row = conn.execute("SELECT * FROM api_keys WHERE key_hash = ?", (key_hash, )).fetchone()
if not row:
return None
api_key = self._row_to_api_key(row)
api_key = self._row_to_api_key(row)
# 检查状态
if api_key.status != ApiKeyStatus.ACTIVE.value:
@@ -220,11 +220,11 @@ class ApiKeyManager:
# 检查是否过期
if api_key.expires_at:
expires = datetime.fromisoformat(api_key.expires_at)
expires = datetime.fromisoformat(api_key.expires_at)
if datetime.now() > expires:
# 更新状态为过期
conn.execute(
"UPDATE api_keys SET status = ? WHERE id = ?",
"UPDATE api_keys SET status = ? WHERE id = ?",
(ApiKeyStatus.EXPIRED.value, api_key.id),
)
conn.commit()
@@ -232,22 +232,22 @@ class ApiKeyManager:
return api_key
def revoke_key(self, key_id: str, reason: str = "", owner_id: str | None = None) -> bool:
def revoke_key(self, key_id: str, reason: str = "", owner_id: str | None = None) -> bool:
"""撤销 API Key"""
with sqlite3.connect(self.db_path) as conn:
# 验证所有权(如果提供了 owner_id
if owner_id:
row = conn.execute(
"SELECT owner_id FROM api_keys WHERE id = ?", (key_id,)
row = conn.execute(
"SELECT owner_id FROM api_keys WHERE id = ?", (key_id, )
).fetchone()
if not row or row[0] != owner_id:
return False
cursor = conn.execute(
cursor = conn.execute(
"""
UPDATE api_keys
SET status = ?, revoked_at = ?, revoked_reason = ?
WHERE id = ? AND status = ?
SET status = ?, revoked_at = ?, revoked_reason = ?
WHERE id = ? AND status = ?
""",
(
ApiKeyStatus.REVOKED.value,
@@ -260,17 +260,17 @@ class ApiKeyManager:
conn.commit()
return cursor.rowcount > 0
def get_key_by_id(self, key_id: str, owner_id: str | None = None) -> ApiKey | None:
def get_key_by_id(self, key_id: str, owner_id: str | None = None) -> ApiKey | None:
"""通过 ID 获取 API Key不包含敏感信息"""
with sqlite3.connect(self.db_path) as conn:
conn.row_factory = sqlite3.Row
conn.row_factory = sqlite3.Row
if owner_id:
row = conn.execute(
"SELECT * FROM api_keys WHERE id = ? AND owner_id = ?", (key_id, owner_id)
row = conn.execute(
"SELECT * FROM api_keys WHERE id = ? AND owner_id = ?", (key_id, owner_id)
).fetchone()
else:
row = conn.execute("SELECT * FROM api_keys WHERE id = ?", (key_id,)).fetchone()
row = conn.execute("SELECT * FROM api_keys WHERE id = ?", (key_id, )).fetchone()
if row:
return self._row_to_api_key(row)
@@ -278,54 +278,54 @@ class ApiKeyManager:
def list_keys(
self,
owner_id: str | None = None,
status: str | None = None,
limit: int = 100,
offset: int = 0,
owner_id: str | None = None,
status: str | None = None,
limit: int = 100,
offset: int = 0,
) -> list[ApiKey]:
"""列出 API Keys"""
with sqlite3.connect(self.db_path) as conn:
conn.row_factory = sqlite3.Row
conn.row_factory = sqlite3.Row
query = "SELECT * FROM api_keys WHERE 1=1"
params = []
query = "SELECT * FROM api_keys WHERE 1 = 1"
params = []
if owner_id:
query += " AND owner_id = ?"
query += " AND owner_id = ?"
params.append(owner_id)
if status:
query += " AND status = ?"
query += " AND status = ?"
params.append(status)
query += " ORDER BY created_at DESC LIMIT ? OFFSET ?"
params.extend([limit, offset])
rows = conn.execute(query, params).fetchall()
rows = conn.execute(query, params).fetchall()
return [self._row_to_api_key(row) for row in rows]
def update_key(
self,
key_id: str,
name: str | None = None,
permissions: list[str] | None = None,
rate_limit: int | None = None,
owner_id: str | None = None,
name: str | None = None,
permissions: list[str] | None = None,
rate_limit: int | None = None,
owner_id: str | None = None,
) -> bool:
"""更新 API Key 信息"""
updates = []
params = []
updates = []
params = []
if name is not None:
updates.append("name = ?")
updates.append("name = ?")
params.append(name)
if permissions is not None:
updates.append("permissions = ?")
updates.append("permissions = ?")
params.append(json.dumps(permissions))
if rate_limit is not None:
updates.append("rate_limit = ?")
updates.append("rate_limit = ?")
params.append(rate_limit)
if not updates:
@@ -336,14 +336,14 @@ class ApiKeyManager:
with sqlite3.connect(self.db_path) as conn:
# 验证所有权
if owner_id:
row = conn.execute(
"SELECT owner_id FROM api_keys WHERE id = ?", (key_id,)
row = conn.execute(
"SELECT owner_id FROM api_keys WHERE id = ?", (key_id, )
).fetchone()
if not row or row[0] != owner_id:
return False
query = f"UPDATE api_keys SET {', '.join(updates)} WHERE id = ?"
cursor = conn.execute(query, params)
query = f"UPDATE api_keys SET {', '.join(updates)} WHERE id = ?"
cursor = conn.execute(query, params)
conn.commit()
return cursor.rowcount > 0
@@ -353,8 +353,8 @@ class ApiKeyManager:
conn.execute(
"""
UPDATE api_keys
SET last_used_at = ?, total_calls = total_calls + 1
WHERE id = ?
SET last_used_at = ?, total_calls = total_calls + 1
WHERE id = ?
""",
(datetime.now().isoformat(), key_id),
)
@@ -365,12 +365,12 @@ class ApiKeyManager:
api_key_id: str,
endpoint: str,
method: str,
status_code: int = 200,
response_time_ms: int = 0,
ip_address: str = "",
user_agent: str = "",
error_message: str = "",
):
status_code: int = 200,
response_time_ms: int = 0,
ip_address: str = "",
user_agent: str = "",
error_message: str = "",
) -> None:
"""记录 API 调用日志"""
with sqlite3.connect(self.db_path) as conn:
conn.execute(
@@ -395,21 +395,21 @@ class ApiKeyManager:
def get_call_logs(
self,
api_key_id: str | None = None,
start_date: str | None = None,
end_date: str | None = None,
limit: int = 100,
offset: int = 0,
api_key_id: str | None = None,
start_date: str | None = None,
end_date: str | None = None,
limit: int = 100,
offset: int = 0,
) -> list[dict]:
"""获取 API 调用日志"""
with sqlite3.connect(self.db_path) as conn:
conn.row_factory = sqlite3.Row
conn.row_factory = sqlite3.Row
query = "SELECT * FROM api_call_logs WHERE 1=1"
params = []
query = "SELECT * FROM api_call_logs WHERE 1 = 1"
params = []
if api_key_id:
query += " AND api_key_id = ?"
query += " AND api_key_id = ?"
params.append(api_key_id)
if start_date:
@@ -423,16 +423,16 @@ class ApiKeyManager:
query += " ORDER BY created_at DESC LIMIT ? OFFSET ?"
params.extend([limit, offset])
rows = conn.execute(query, params).fetchall()
rows = conn.execute(query, params).fetchall()
return [dict(row) for row in rows]
def get_call_stats(self, api_key_id: str | None = None, days: int = 30) -> dict:
def get_call_stats(self, api_key_id: str | None = None, days: int = 30) -> dict:
"""获取 API 调用统计"""
with sqlite3.connect(self.db_path) as conn:
conn.row_factory = sqlite3.Row
conn.row_factory = sqlite3.Row
# 总体统计
query = f"""
query = f"""
SELECT
COUNT(*) as total_calls,
COUNT(CASE WHEN status_code < 400 THEN 1 END) as success_calls,
@@ -444,15 +444,15 @@ class ApiKeyManager:
WHERE created_at >= date('now', '-{days} days')
"""
params = []
params = []
if api_key_id:
query = query.replace("WHERE created_at", "WHERE api_key_id = ? AND created_at")
query = query.replace("WHERE created_at", "WHERE api_key_id = ? AND created_at")
params.insert(0, api_key_id)
row = conn.execute(query, params).fetchone()
row = conn.execute(query, params).fetchone()
# 按端点统计
endpoint_query = f"""
endpoint_query = f"""
SELECT
endpoint,
method,
@@ -462,19 +462,19 @@ class ApiKeyManager:
WHERE created_at >= date('now', '-{days} days')
"""
endpoint_params = []
endpoint_params = []
if api_key_id:
endpoint_query = endpoint_query.replace(
"WHERE created_at", "WHERE api_key_id = ? AND created_at"
endpoint_query = endpoint_query.replace(
"WHERE created_at", "WHERE api_key_id = ? AND created_at"
)
endpoint_params.insert(0, api_key_id)
endpoint_query += " GROUP BY endpoint, method ORDER BY calls DESC"
endpoint_rows = conn.execute(endpoint_query, endpoint_params).fetchall()
endpoint_rows = conn.execute(endpoint_query, endpoint_params).fetchall()
# 按天统计
daily_query = f"""
daily_query = f"""
SELECT
date(created_at) as date,
COUNT(*) as calls,
@@ -483,16 +483,16 @@ class ApiKeyManager:
WHERE created_at >= date('now', '-{days} days')
"""
daily_params = []
daily_params = []
if api_key_id:
daily_query = daily_query.replace(
"WHERE created_at", "WHERE api_key_id = ? AND created_at"
daily_query = daily_query.replace(
"WHERE created_at", "WHERE api_key_id = ? AND created_at"
)
daily_params.insert(0, api_key_id)
daily_query += " GROUP BY date(created_at) ORDER BY date"
daily_rows = conn.execute(daily_query, daily_params).fetchall()
daily_rows = conn.execute(daily_query, daily_params).fetchall()
return {
"summary": {
@@ -510,30 +510,30 @@ class ApiKeyManager:
def _row_to_api_key(self, row: sqlite3.Row) -> ApiKey:
"""将数据库行转换为 ApiKey 对象"""
return ApiKey(
id=row["id"],
key_hash=row["key_hash"],
key_preview=row["key_preview"],
name=row["name"],
owner_id=row["owner_id"],
permissions=json.loads(row["permissions"]),
rate_limit=row["rate_limit"],
status=row["status"],
created_at=row["created_at"],
expires_at=row["expires_at"],
last_used_at=row["last_used_at"],
revoked_at=row["revoked_at"],
revoked_reason=row["revoked_reason"],
total_calls=row["total_calls"],
id = row["id"],
key_hash = row["key_hash"],
key_preview = row["key_preview"],
name = row["name"],
owner_id = row["owner_id"],
permissions = json.loads(row["permissions"]),
rate_limit = row["rate_limit"],
status = row["status"],
created_at = row["created_at"],
expires_at = row["expires_at"],
last_used_at = row["last_used_at"],
revoked_at = row["revoked_at"],
revoked_reason = row["revoked_reason"],
total_calls = row["total_calls"],
)
# 全局实例
_api_key_manager: ApiKeyManager | None = None
_api_key_manager: ApiKeyManager | None = None
def get_api_key_manager() -> ApiKeyManager:
"""获取 API Key 管理器实例"""
global _api_key_manager
if _api_key_manager is None:
_api_key_manager = ApiKeyManager()
_api_key_manager = ApiKeyManager()
return _api_key_manager

View File

@@ -15,29 +15,29 @@ from typing import Any
class SharePermission(Enum):
"""分享权限级别"""
READ_ONLY = "read_only" # 只读
COMMENT = "comment" # 可评论
EDIT = "edit" # 可编辑
ADMIN = "admin" # 管理员
READ_ONLY = "read_only" # 只读
COMMENT = "comment" # 可评论
EDIT = "edit" # 可编辑
ADMIN = "admin" # 管理员
class CommentTargetType(Enum):
"""评论目标类型"""
ENTITY = "entity" # 实体评论
RELATION = "relation" # 关系评论
TRANSCRIPT = "transcript" # 转录文本评论
PROJECT = "project" # 项目级评论
ENTITY = "entity" # 实体评论
RELATION = "relation" # 关系评论
TRANSCRIPT = "transcript" # 转录文本评论
PROJECT = "project" # 项目级评论
class ChangeType(Enum):
"""变更类型"""
CREATE = "create" # 创建
UPDATE = "update" # 更新
DELETE = "delete" # 删除
MERGE = "merge" # 合并
SPLIT = "split" # 拆分
CREATE = "create" # 创建
UPDATE = "update" # 更新
DELETE = "delete" # 删除
MERGE = "merge" # 合并
SPLIT = "split" # 拆分
@dataclass
@@ -136,10 +136,10 @@ class TeamSpace:
class CollaborationManager:
"""协作管理主类"""
def __init__(self, db_manager=None):
self.db = db_manager
self._shares_cache: dict[str, ProjectShare] = {}
self._comments_cache: dict[str, list[Comment]] = {}
def __init__(self, db_manager = None) -> None:
self.db = db_manager
self._shares_cache: dict[str, ProjectShare] = {}
self._comments_cache: dict[str, list[Comment]] = {}
# ============ 项目分享 ============
@@ -147,57 +147,57 @@ class CollaborationManager:
self,
project_id: str,
created_by: str,
permission: str = "read_only",
expires_in_days: int | None = None,
max_uses: int | None = None,
password: str | None = None,
allow_download: bool = False,
allow_export: bool = False,
permission: str = "read_only",
expires_in_days: int | None = None,
max_uses: int | None = None,
password: str | None = None,
allow_download: bool = False,
allow_export: bool = False,
) -> ProjectShare:
"""创建项目分享链接"""
share_id = str(uuid.uuid4())
token = self._generate_share_token(project_id)
share_id = str(uuid.uuid4())
token = self._generate_share_token(project_id)
now = datetime.now().isoformat()
expires_at = None
now = datetime.now().isoformat()
expires_at = None
if expires_in_days:
expires_at = (datetime.now() + timedelta(days=expires_in_days)).isoformat()
expires_at = (datetime.now() + timedelta(days = expires_in_days)).isoformat()
password_hash = None
password_hash = None
if password:
password_hash = hashlib.sha256(password.encode()).hexdigest()
password_hash = hashlib.sha256(password.encode()).hexdigest()
share = ProjectShare(
id=share_id,
project_id=project_id,
token=token,
permission=permission,
created_by=created_by,
created_at=now,
expires_at=expires_at,
max_uses=max_uses,
use_count=0,
password_hash=password_hash,
is_active=True,
allow_download=allow_download,
allow_export=allow_export,
share = ProjectShare(
id = share_id,
project_id = project_id,
token = token,
permission = permission,
created_by = created_by,
created_at = now,
expires_at = expires_at,
max_uses = max_uses,
use_count = 0,
password_hash = password_hash,
is_active = True,
allow_download = allow_download,
allow_export = allow_export,
)
# 保存到数据库
if self.db:
self._save_share_to_db(share)
self._shares_cache[token] = share
self._shares_cache[token] = share
return share
def _generate_share_token(self, project_id: str) -> str:
"""生成分享令牌"""
data = f"{project_id}:{datetime.now().timestamp()}:{uuid.uuid4()}"
data = f"{project_id}:{datetime.now().timestamp()}:{uuid.uuid4()}"
return hashlib.sha256(data.encode()).hexdigest()[:32]
def _save_share_to_db(self, share: ProjectShare) -> None:
"""保存分享记录到数据库"""
cursor = self.db.conn.cursor()
cursor = self.db.conn.cursor()
cursor.execute(
"""
INSERT INTO project_shares
@@ -224,12 +224,12 @@ class CollaborationManager:
)
self.db.conn.commit()
def validate_share_token(self, token: str, password: str | None = None) -> ProjectShare | None:
def validate_share_token(self, token: str, password: str | None = None) -> ProjectShare | None:
"""验证分享令牌"""
# 从缓存或数据库获取
share = self._shares_cache.get(token)
share = self._shares_cache.get(token)
if not share and self.db:
share = self._get_share_from_db(token)
share = self._get_share_from_db(token)
if not share:
return None
@@ -250,7 +250,7 @@ class CollaborationManager:
if share.password_hash:
if not password:
return None
password_hash = hashlib.sha256(password.encode()).hexdigest()
password_hash = hashlib.sha256(password.encode()).hexdigest()
if password_hash != share.password_hash:
return None
@@ -258,63 +258,63 @@ class CollaborationManager:
def _get_share_from_db(self, token: str) -> ProjectShare | None:
"""从数据库获取分享记录"""
cursor = self.db.conn.cursor()
cursor = self.db.conn.cursor()
cursor.execute(
"""
SELECT * FROM project_shares WHERE token = ?
SELECT * FROM project_shares WHERE token = ?
""",
(token,),
(token, ),
)
row = cursor.fetchone()
row = cursor.fetchone()
if not row:
return None
return ProjectShare(
id=row[0],
project_id=row[1],
token=row[2],
permission=row[3],
created_by=row[4],
created_at=row[5],
expires_at=row[6],
max_uses=row[7],
use_count=row[8],
password_hash=row[9],
is_active=bool(row[10]),
allow_download=bool(row[11]),
allow_export=bool(row[12]),
id = row[0],
project_id = row[1],
token = row[2],
permission = row[3],
created_by = row[4],
created_at = row[5],
expires_at = row[6],
max_uses = row[7],
use_count = row[8],
password_hash = row[9],
is_active = bool(row[10]),
allow_download = bool(row[11]),
allow_export = bool(row[12]),
)
def increment_share_usage(self, token: str) -> None:
"""增加分享链接使用次数"""
share = self._shares_cache.get(token)
share = self._shares_cache.get(token)
if share:
share.use_count += 1
if self.db:
cursor = self.db.conn.cursor()
cursor = self.db.conn.cursor()
cursor.execute(
"""
UPDATE project_shares
SET use_count = use_count + 1
WHERE token = ?
SET use_count = use_count + 1
WHERE token = ?
""",
(token,),
(token, ),
)
self.db.conn.commit()
def revoke_share_link(self, share_id: str, revoked_by: str) -> bool:
"""撤销分享链接"""
if self.db:
cursor = self.db.conn.cursor()
cursor = self.db.conn.cursor()
cursor.execute(
"""
UPDATE project_shares
SET is_active = 0
WHERE id = ?
SET is_active = 0
WHERE id = ?
""",
(share_id,),
(share_id, ),
)
self.db.conn.commit()
return cursor.rowcount > 0
@@ -325,33 +325,33 @@ class CollaborationManager:
if not self.db:
return []
cursor = self.db.conn.cursor()
cursor = self.db.conn.cursor()
cursor.execute(
"""
SELECT * FROM project_shares
WHERE project_id = ?
WHERE project_id = ?
ORDER BY created_at DESC
""",
(project_id,),
(project_id, ),
)
shares = []
shares = []
for row in cursor.fetchall():
shares.append(
ProjectShare(
id=row[0],
project_id=row[1],
token=row[2],
permission=row[3],
created_by=row[4],
created_at=row[5],
expires_at=row[6],
max_uses=row[7],
use_count=row[8],
password_hash=row[9],
is_active=bool(row[10]),
allow_download=bool(row[11]),
allow_export=bool(row[12]),
id = row[0],
project_id = row[1],
token = row[2],
permission = row[3],
created_by = row[4],
created_at = row[5],
expires_at = row[6],
max_uses = row[7],
use_count = row[8],
password_hash = row[9],
is_active = bool(row[10]),
allow_download = bool(row[11]),
allow_export = bool(row[12]),
)
)
return shares
@@ -366,46 +366,46 @@ class CollaborationManager:
author: str,
author_name: str,
content: str,
parent_id: str | None = None,
mentions: list[str] | None = None,
attachments: list[dict] | None = None,
parent_id: str | None = None,
mentions: list[str] | None = None,
attachments: list[dict] | None = None,
) -> Comment:
"""添加评论"""
comment_id = str(uuid.uuid4())
now = datetime.now().isoformat()
comment_id = str(uuid.uuid4())
now = datetime.now().isoformat()
comment = Comment(
id=comment_id,
project_id=project_id,
target_type=target_type,
target_id=target_id,
parent_id=parent_id,
author=author,
author_name=author_name,
content=content,
created_at=now,
updated_at=now,
resolved=False,
resolved_by=None,
resolved_at=None,
mentions=mentions or [],
attachments=attachments or [],
comment = Comment(
id = comment_id,
project_id = project_id,
target_type = target_type,
target_id = target_id,
parent_id = parent_id,
author = author,
author_name = author_name,
content = content,
created_at = now,
updated_at = now,
resolved = False,
resolved_by = None,
resolved_at = None,
mentions = mentions or [],
attachments = attachments or [],
)
if self.db:
self._save_comment_to_db(comment)
# 更新缓存
key = f"{target_type}:{target_id}"
key = f"{target_type}:{target_id}"
if key not in self._comments_cache:
self._comments_cache[key] = []
self._comments_cache[key] = []
self._comments_cache[key].append(comment)
return comment
def _save_comment_to_db(self, comment: Comment) -> None:
"""保存评论到数据库"""
cursor = self.db.conn.cursor()
cursor = self.db.conn.cursor()
cursor.execute(
"""
INSERT INTO comments
@@ -435,18 +435,18 @@ class CollaborationManager:
self.db.conn.commit()
def get_comments(
self, target_type: str, target_id: str, include_resolved: bool = True
self, target_type: str, target_id: str, include_resolved: bool = True
) -> list[Comment]:
"""获取评论列表"""
if not self.db:
return []
cursor = self.db.conn.cursor()
cursor = self.db.conn.cursor()
if include_resolved:
cursor.execute(
"""
SELECT * FROM comments
WHERE target_type = ? AND target_id = ?
WHERE target_type = ? AND target_id = ?
ORDER BY created_at ASC
""",
(target_type, target_id),
@@ -455,13 +455,13 @@ class CollaborationManager:
cursor.execute(
"""
SELECT * FROM comments
WHERE target_type = ? AND target_id = ? AND resolved = 0
WHERE target_type = ? AND target_id = ? AND resolved = 0
ORDER BY created_at ASC
""",
(target_type, target_id),
)
comments = []
comments = []
for row in cursor.fetchall():
comments.append(self._row_to_comment(row))
return comments
@@ -469,21 +469,21 @@ class CollaborationManager:
def _row_to_comment(self, row) -> Comment:
"""将数据库行转换为Comment对象"""
return Comment(
id=row[0],
project_id=row[1],
target_type=row[2],
target_id=row[3],
parent_id=row[4],
author=row[5],
author_name=row[6],
content=row[7],
created_at=row[8],
updated_at=row[9],
resolved=bool(row[10]),
resolved_by=row[11],
resolved_at=row[12],
mentions=json.loads(row[13]) if row[13] else [],
attachments=json.loads(row[14]) if row[14] else [],
id = row[0],
project_id = row[1],
target_type = row[2],
target_id = row[3],
parent_id = row[4],
author = row[5],
author_name = row[6],
content = row[7],
created_at = row[8],
updated_at = row[9],
resolved = bool(row[10]),
resolved_by = row[11],
resolved_at = row[12],
mentions = json.loads(row[13]) if row[13] else [],
attachments = json.loads(row[14]) if row[14] else [],
)
def update_comment(self, comment_id: str, content: str, updated_by: str) -> Comment | None:
@@ -491,13 +491,13 @@ class CollaborationManager:
if not self.db:
return None
now = datetime.now().isoformat()
cursor = self.db.conn.cursor()
now = datetime.now().isoformat()
cursor = self.db.conn.cursor()
cursor.execute(
"""
UPDATE comments
SET content = ?, updated_at = ?
WHERE id = ? AND author = ?
SET content = ?, updated_at = ?
WHERE id = ? AND author = ?
""",
(content, now, comment_id, updated_by),
)
@@ -509,9 +509,9 @@ class CollaborationManager:
def _get_comment_by_id(self, comment_id: str) -> Comment | None:
"""根据ID获取评论"""
cursor = self.db.conn.cursor()
cursor.execute("SELECT * FROM comments WHERE id = ?", (comment_id,))
row = cursor.fetchone()
cursor = self.db.conn.cursor()
cursor.execute("SELECT * FROM comments WHERE id = ?", (comment_id, ))
row = cursor.fetchone()
if row:
return self._row_to_comment(row)
return None
@@ -521,13 +521,13 @@ class CollaborationManager:
if not self.db:
return False
now = datetime.now().isoformat()
cursor = self.db.conn.cursor()
now = datetime.now().isoformat()
cursor = self.db.conn.cursor()
cursor.execute(
"""
UPDATE comments
SET resolved = 1, resolved_by = ?, resolved_at = ?
WHERE id = ?
SET resolved = 1, resolved_by = ?, resolved_at = ?
WHERE id = ?
""",
(resolved_by, now, comment_id),
)
@@ -539,13 +539,13 @@ class CollaborationManager:
if not self.db:
return False
cursor = self.db.conn.cursor()
cursor = self.db.conn.cursor()
# 只允许作者或管理员删除
cursor.execute(
"""
DELETE FROM comments
WHERE id = ? AND (author = ? OR ? IN (
SELECT created_by FROM projects WHERE id = comments.project_id
WHERE id = ? AND (author = ? OR ? IN (
SELECT created_by FROM projects WHERE id = comments.project_id
))
""",
(comment_id, deleted_by, deleted_by),
@@ -554,24 +554,24 @@ class CollaborationManager:
return cursor.rowcount > 0
def get_project_comments(
self, project_id: str, limit: int = 50, offset: int = 0
self, project_id: str, limit: int = 50, offset: int = 0
) -> list[Comment]:
"""获取项目下的所有评论"""
if not self.db:
return []
cursor = self.db.conn.cursor()
cursor = self.db.conn.cursor()
cursor.execute(
"""
SELECT * FROM comments
WHERE project_id = ?
WHERE project_id = ?
ORDER BY created_at DESC
LIMIT ? OFFSET ?
""",
(project_id, limit, offset),
)
comments = []
comments = []
for row in cursor.fetchall():
comments.append(self._row_to_comment(row))
return comments
@@ -587,32 +587,32 @@ class CollaborationManager:
entity_name: str,
changed_by: str,
changed_by_name: str,
old_value: dict | None = None,
new_value: dict | None = None,
description: str = "",
session_id: str | None = None,
old_value: dict | None = None,
new_value: dict | None = None,
description: str = "",
session_id: str | None = None,
) -> ChangeRecord:
"""记录变更"""
record_id = str(uuid.uuid4())
now = datetime.now().isoformat()
record_id = str(uuid.uuid4())
now = datetime.now().isoformat()
record = ChangeRecord(
id=record_id,
project_id=project_id,
change_type=change_type,
entity_type=entity_type,
entity_id=entity_id,
entity_name=entity_name,
changed_by=changed_by,
changed_by_name=changed_by_name,
changed_at=now,
old_value=old_value,
new_value=new_value,
description=description,
session_id=session_id,
reverted=False,
reverted_at=None,
reverted_by=None,
record = ChangeRecord(
id = record_id,
project_id = project_id,
change_type = change_type,
entity_type = entity_type,
entity_id = entity_id,
entity_name = entity_name,
changed_by = changed_by,
changed_by_name = changed_by_name,
changed_at = now,
old_value = old_value,
new_value = new_value,
description = description,
session_id = session_id,
reverted = False,
reverted_at = None,
reverted_by = None,
)
if self.db:
@@ -622,7 +622,7 @@ class CollaborationManager:
def _save_change_to_db(self, record: ChangeRecord) -> None:
"""保存变更记录到数据库"""
cursor = self.db.conn.cursor()
cursor = self.db.conn.cursor()
cursor.execute(
"""
INSERT INTO change_history
@@ -655,22 +655,22 @@ class CollaborationManager:
def get_change_history(
self,
project_id: str,
entity_type: str | None = None,
entity_id: str | None = None,
limit: int = 50,
offset: int = 0,
entity_type: str | None = None,
entity_id: str | None = None,
limit: int = 50,
offset: int = 0,
) -> list[ChangeRecord]:
"""获取变更历史"""
if not self.db:
return []
cursor = self.db.conn.cursor()
cursor = self.db.conn.cursor()
if entity_type and entity_id:
cursor.execute(
"""
SELECT * FROM change_history
WHERE project_id = ? AND entity_type = ? AND entity_id = ?
WHERE project_id = ? AND entity_type = ? AND entity_id = ?
ORDER BY changed_at DESC
LIMIT ? OFFSET ?
""",
@@ -680,7 +680,7 @@ class CollaborationManager:
cursor.execute(
"""
SELECT * FROM change_history
WHERE project_id = ? AND entity_type = ?
WHERE project_id = ? AND entity_type = ?
ORDER BY changed_at DESC
LIMIT ? OFFSET ?
""",
@@ -690,14 +690,14 @@ class CollaborationManager:
cursor.execute(
"""
SELECT * FROM change_history
WHERE project_id = ?
WHERE project_id = ?
ORDER BY changed_at DESC
LIMIT ? OFFSET ?
""",
(project_id, limit, offset),
)
records = []
records = []
for row in cursor.fetchall():
records.append(self._row_to_change_record(row))
return records
@@ -705,22 +705,22 @@ class CollaborationManager:
def _row_to_change_record(self, row) -> ChangeRecord:
"""将数据库行转换为ChangeRecord对象"""
return ChangeRecord(
id=row[0],
project_id=row[1],
change_type=row[2],
entity_type=row[3],
entity_id=row[4],
entity_name=row[5],
changed_by=row[6],
changed_by_name=row[7],
changed_at=row[8],
old_value=json.loads(row[9]) if row[9] else None,
new_value=json.loads(row[10]) if row[10] else None,
description=row[11],
session_id=row[12],
reverted=bool(row[13]),
reverted_at=row[14],
reverted_by=row[15],
id = row[0],
project_id = row[1],
change_type = row[2],
entity_type = row[3],
entity_id = row[4],
entity_name = row[5],
changed_by = row[6],
changed_by_name = row[7],
changed_at = row[8],
old_value = json.loads(row[9]) if row[9] else None,
new_value = json.loads(row[10]) if row[10] else None,
description = row[11],
session_id = row[12],
reverted = bool(row[13]),
reverted_at = row[14],
reverted_by = row[15],
)
def get_entity_version_history(self, entity_type: str, entity_id: str) -> list[ChangeRecord]:
@@ -728,17 +728,17 @@ class CollaborationManager:
if not self.db:
return []
cursor = self.db.conn.cursor()
cursor = self.db.conn.cursor()
cursor.execute(
"""
SELECT * FROM change_history
WHERE entity_type = ? AND entity_id = ?
WHERE entity_type = ? AND entity_id = ?
ORDER BY changed_at ASC
""",
(entity_type, entity_id),
)
records = []
records = []
for row in cursor.fetchall():
records.append(self._row_to_change_record(row))
return records
@@ -748,13 +748,13 @@ class CollaborationManager:
if not self.db:
return False
now = datetime.now().isoformat()
cursor = self.db.conn.cursor()
now = datetime.now().isoformat()
cursor = self.db.conn.cursor()
cursor.execute(
"""
UPDATE change_history
SET reverted = 1, reverted_at = ?, reverted_by = ?
WHERE id = ? AND reverted = 0
SET reverted = 1, reverted_at = ?, reverted_by = ?
WHERE id = ? AND reverted = 0
""",
(now, reverted_by, record_id),
)
@@ -766,49 +766,49 @@ class CollaborationManager:
if not self.db:
return {}
cursor = self.db.conn.cursor()
cursor = self.db.conn.cursor()
# 总变更数
cursor.execute(
"""
SELECT COUNT(*) FROM change_history WHERE project_id = ?
SELECT COUNT(*) FROM change_history WHERE project_id = ?
""",
(project_id,),
(project_id, ),
)
total_changes = cursor.fetchone()[0]
total_changes = cursor.fetchone()[0]
# 按类型统计
cursor.execute(
"""
SELECT change_type, COUNT(*) FROM change_history
WHERE project_id = ? GROUP BY change_type
WHERE project_id = ? GROUP BY change_type
""",
(project_id,),
(project_id, ),
)
type_counts = {row[0]: row[1] for row in cursor.fetchall()}
type_counts = {row[0]: row[1] for row in cursor.fetchall()}
# 按实体类型统计
cursor.execute(
"""
SELECT entity_type, COUNT(*) FROM change_history
WHERE project_id = ? GROUP BY entity_type
WHERE project_id = ? GROUP BY entity_type
""",
(project_id,),
(project_id, ),
)
entity_type_counts = {row[0]: row[1] for row in cursor.fetchall()}
entity_type_counts = {row[0]: row[1] for row in cursor.fetchall()}
# 最近活跃的用户
cursor.execute(
"""
SELECT changed_by_name, COUNT(*) as count FROM change_history
WHERE project_id = ?
WHERE project_id = ?
GROUP BY changed_by_name
ORDER BY count DESC
LIMIT 5
""",
(project_id,),
(project_id, ),
)
top_contributors = [{"name": row[0], "changes": row[1]} for row in cursor.fetchall()]
top_contributors = [{"name": row[0], "changes": row[1]} for row in cursor.fetchall()]
return {
"total_changes": total_changes,
@@ -827,27 +827,27 @@ class CollaborationManager:
user_email: str,
role: str,
invited_by: str,
permissions: list[str] | None = None,
permissions: list[str] | None = None,
) -> TeamMember:
"""添加团队成员"""
member_id = str(uuid.uuid4())
now = datetime.now().isoformat()
member_id = str(uuid.uuid4())
now = datetime.now().isoformat()
# 根据角色设置默认权限
if permissions is None:
permissions = self._get_default_permissions(role)
permissions = self._get_default_permissions(role)
member = TeamMember(
id=member_id,
project_id=project_id,
user_id=user_id,
user_name=user_name,
user_email=user_email,
role=role,
joined_at=now,
invited_by=invited_by,
last_active_at=None,
permissions=permissions,
member = TeamMember(
id = member_id,
project_id = project_id,
user_id = user_id,
user_name = user_name,
user_email = user_email,
role = role,
joined_at = now,
invited_by = invited_by,
last_active_at = None,
permissions = permissions,
)
if self.db:
@@ -857,7 +857,7 @@ class CollaborationManager:
def _get_default_permissions(self, role: str) -> list[str]:
"""获取角色的默认权限"""
permissions_map = {
permissions_map = {
"owner": ["read", "write", "delete", "share", "admin", "export"],
"admin": ["read", "write", "delete", "share", "export"],
"editor": ["read", "write", "export"],
@@ -868,7 +868,7 @@ class CollaborationManager:
def _save_member_to_db(self, member: TeamMember) -> None:
"""保存成员到数据库"""
cursor = self.db.conn.cursor()
cursor = self.db.conn.cursor()
cursor.execute(
"""
INSERT INTO team_members
@@ -896,16 +896,16 @@ class CollaborationManager:
if not self.db:
return []
cursor = self.db.conn.cursor()
cursor = self.db.conn.cursor()
cursor.execute(
"""
SELECT * FROM team_members WHERE project_id = ?
SELECT * FROM team_members WHERE project_id = ?
ORDER BY joined_at ASC
""",
(project_id,),
(project_id, ),
)
members = []
members = []
for row in cursor.fetchall():
members.append(self._row_to_team_member(row))
return members
@@ -913,16 +913,16 @@ class CollaborationManager:
def _row_to_team_member(self, row) -> TeamMember:
"""将数据库行转换为TeamMember对象"""
return TeamMember(
id=row[0],
project_id=row[1],
user_id=row[2],
user_name=row[3],
user_email=row[4],
role=row[5],
joined_at=row[6],
invited_by=row[7],
last_active_at=row[8],
permissions=json.loads(row[9]) if row[9] else [],
id = row[0],
project_id = row[1],
user_id = row[2],
user_name = row[3],
user_email = row[4],
role = row[5],
joined_at = row[6],
invited_by = row[7],
last_active_at = row[8],
permissions = json.loads(row[9]) if row[9] else [],
)
def update_member_role(self, member_id: str, new_role: str, updated_by: str) -> bool:
@@ -930,13 +930,13 @@ class CollaborationManager:
if not self.db:
return False
permissions = self._get_default_permissions(new_role)
cursor = self.db.conn.cursor()
permissions = self._get_default_permissions(new_role)
cursor = self.db.conn.cursor()
cursor.execute(
"""
UPDATE team_members
SET role = ?, permissions = ?
WHERE id = ?
SET role = ?, permissions = ?
WHERE id = ?
""",
(new_role, json.dumps(permissions), member_id),
)
@@ -948,8 +948,8 @@ class CollaborationManager:
if not self.db:
return False
cursor = self.db.conn.cursor()
cursor.execute("DELETE FROM team_members WHERE id = ?", (member_id,))
cursor = self.db.conn.cursor()
cursor.execute("DELETE FROM team_members WHERE id = ?", (member_id, ))
self.db.conn.commit()
return cursor.rowcount > 0
@@ -958,20 +958,20 @@ class CollaborationManager:
if not self.db:
return False
cursor = self.db.conn.cursor()
cursor = self.db.conn.cursor()
cursor.execute(
"""
SELECT permissions FROM team_members
WHERE project_id = ? AND user_id = ?
WHERE project_id = ? AND user_id = ?
""",
(project_id, user_id),
)
row = cursor.fetchone()
row = cursor.fetchone()
if not row:
return False
permissions = json.loads(row[0]) if row[0] else []
permissions = json.loads(row[0]) if row[0] else []
return permission in permissions or "admin" in permissions
def update_last_active(self, project_id: str, user_id: str) -> None:
@@ -979,13 +979,13 @@ class CollaborationManager:
if not self.db:
return
now = datetime.now().isoformat()
cursor = self.db.conn.cursor()
now = datetime.now().isoformat()
cursor = self.db.conn.cursor()
cursor.execute(
"""
UPDATE team_members
SET last_active_at = ?
WHERE project_id = ? AND user_id = ?
SET last_active_at = ?
WHERE project_id = ? AND user_id = ?
""",
(now, project_id, user_id),
)
@@ -993,12 +993,12 @@ class CollaborationManager:
# 全局协作管理器实例
_collaboration_manager = None
_collaboration_manager = None
def get_collaboration_manager(db_manager=None) -> None:
def get_collaboration_manager(db_manager = None) -> None:
"""获取协作管理器单例"""
global _collaboration_manager
if _collaboration_manager is None:
_collaboration_manager = CollaborationManager(db_manager)
_collaboration_manager = CollaborationManager(db_manager)
return _collaboration_manager

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -11,8 +11,8 @@ import os
class DocumentProcessor:
"""文档处理器 - 提取 PDF/DOCX 文本"""
def __init__(self):
self.supported_formats = {
def __init__(self) -> None:
self.supported_formats = {
".pdf": self._extract_pdf,
".docx": self._extract_docx,
".doc": self._extract_docx,
@@ -31,18 +31,18 @@ class DocumentProcessor:
Returns:
{"text": "提取的文本内容", "format": "文件格式"}
"""
ext = os.path.splitext(filename.lower())[1]
ext = os.path.splitext(filename.lower())[1]
if ext not in self.supported_formats:
raise ValueError(
f"Unsupported file format: {ext}. Supported: {list(self.supported_formats.keys())}"
)
extractor = self.supported_formats[ext]
text = extractor(content)
extractor = self.supported_formats[ext]
text = extractor(content)
# 清理文本
text = self._clean_text(text)
text = self._clean_text(text)
return {"text": text, "format": ext, "filename": filename}
@@ -51,12 +51,12 @@ class DocumentProcessor:
try:
import PyPDF2
pdf_file = io.BytesIO(content)
reader = PyPDF2.PdfReader(pdf_file)
pdf_file = io.BytesIO(content)
reader = PyPDF2.PdfReader(pdf_file)
text_parts = []
text_parts = []
for page in reader.pages:
page_text = page.extract_text()
page_text = page.extract_text()
if page_text:
text_parts.append(page_text)
@@ -66,10 +66,10 @@ class DocumentProcessor:
try:
import pdfplumber
text_parts = []
text_parts = []
with pdfplumber.open(io.BytesIO(content)) as pdf:
for page in pdf.pages:
page_text = page.extract_text()
page_text = page.extract_text()
if page_text:
text_parts.append(page_text)
return "\n\n".join(text_parts)
@@ -85,10 +85,10 @@ class DocumentProcessor:
try:
import docx
doc_file = io.BytesIO(content)
doc = docx.Document(doc_file)
doc_file = io.BytesIO(content)
doc = docx.Document(doc_file)
text_parts = []
text_parts = []
for para in doc.paragraphs:
if para.text.strip():
text_parts.append(para.text)
@@ -96,7 +96,7 @@ class DocumentProcessor:
# 提取表格中的文本
for table in doc.tables:
for row in table.rows:
row_text = []
row_text = []
for cell in row.cells:
if cell.text.strip():
row_text.append(cell.text.strip())
@@ -114,7 +114,7 @@ class DocumentProcessor:
def _extract_txt(self, content: bytes) -> str:
"""提取纯文本"""
# 尝试多种编码
encodings = ["utf-8", "gbk", "gb2312", "latin-1"]
encodings = ["utf-8", "gbk", "gb2312", "latin-1"]
for encoding in encodings:
try:
@@ -123,7 +123,7 @@ class DocumentProcessor:
continue
# 如果都失败了,使用 latin-1 并忽略错误
return content.decode("latin-1", errors="ignore")
return content.decode("latin-1", errors = "ignore")
def _clean_text(self, text: str) -> str:
"""清理提取的文本"""
@@ -131,29 +131,29 @@ class DocumentProcessor:
return ""
# 移除多余的空白字符
lines = text.split("\n")
cleaned_lines = []
lines = text.split("\n")
cleaned_lines = []
for line in lines:
line = line.strip()
line = line.strip()
# 移除空行,但保留段落分隔
if line:
cleaned_lines.append(line)
# 合并行,保留段落结构
text = "\n\n".join(cleaned_lines)
text = "\n\n".join(cleaned_lines)
# 移除多余的空格
text = " ".join(text.split())
text = " ".join(text.split())
# 移除控制字符
text = "".join(char for char in text if ord(char) >= 32 or char in "\n\r\t")
text = "".join(char for char in text if ord(char) >= 32 or char in "\n\r\t")
return text.strip()
def is_supported(self, filename: str) -> bool:
"""检查文件格式是否支持"""
ext = os.path.splitext(filename.lower())[1]
ext = os.path.splitext(filename.lower())[1]
return ext in self.supported_formats
@@ -165,7 +165,7 @@ class SimpleTextExtractor:
def extract(self, content: bytes, filename: str) -> str:
"""尝试提取文本"""
encodings = ["utf-8", "gbk", "latin-1"]
encodings = ["utf-8", "gbk", "latin-1"]
for encoding in encodings:
try:
@@ -173,15 +173,15 @@ class SimpleTextExtractor:
except UnicodeDecodeError:
continue
return content.decode("latin-1", errors="ignore")
return content.decode("latin-1", errors = "ignore")
if __name__ == "__main__":
# 测试
processor = DocumentProcessor()
processor = DocumentProcessor()
# 测试文本提取
test_text = "Hello World\n\nThis is a test document.\n\nMultiple paragraphs."
result = processor.process(test_text.encode("utf-8"), "test.txt")
test_text = "Hello World\n\nThis is a test document.\n\nMultiple paragraphs."
result = processor.process(test_text.encode("utf-8"), "test.txt")
print(f"Text extraction test: {len(result['text'])} chars")
print(result["text"][:100])

File diff suppressed because it is too large Load Diff

View File

@@ -12,8 +12,8 @@ import httpx
import numpy as np
# API Keys
KIMI_API_KEY = os.getenv("KIMI_API_KEY", "")
KIMI_BASE_URL = os.getenv("KIMI_BASE_URL", "https://api.kimi.com/coding")
KIMI_API_KEY = os.getenv("KIMI_API_KEY", "")
KIMI_BASE_URL = os.getenv("KIMI_BASE_URL", "https://api.kimi.com/coding")
@dataclass
@@ -27,9 +27,9 @@ class EntityEmbedding:
class EntityAligner:
"""实体对齐器 - 使用 embedding 进行相似度匹配"""
def __init__(self, similarity_threshold: float = 0.85):
self.similarity_threshold = similarity_threshold
self.embedding_cache: dict[str, list[float]] = {}
def __init__(self, similarity_threshold: float = 0.85) -> None:
self.similarity_threshold = similarity_threshold
self.embedding_cache: dict[str, list[float]] = {}
def get_embedding(self, text: str) -> list[float] | None:
"""
@@ -45,25 +45,25 @@ class EntityAligner:
return None
# 检查缓存
cache_key = hash(text)
cache_key = hash(text)
if cache_key in self.embedding_cache:
return self.embedding_cache[cache_key]
try:
response = httpx.post(
response = httpx.post(
f"{KIMI_BASE_URL}/v1/embeddings",
headers={
headers = {
"Authorization": f"Bearer {KIMI_API_KEY}",
"Content-Type": "application/json",
},
json={"model": "k2p5", "input": text[:500]}, # 限制长度
timeout=30.0,
json = {"model": "k2p5", "input": text[:500]}, # 限制长度
timeout = 30.0,
)
response.raise_for_status()
result = response.json()
result = response.json()
embedding = result["data"][0]["embedding"]
self.embedding_cache[cache_key] = embedding
embedding = result["data"][0]["embedding"]
self.embedding_cache[cache_key] = embedding
return embedding
except (httpx.HTTPError, json.JSONDecodeError, KeyError) as e:
@@ -81,20 +81,20 @@ class EntityAligner:
Returns:
相似度分数 (0-1)
"""
vec1 = np.array(embedding1)
vec2 = np.array(embedding2)
vec1 = np.array(embedding1)
vec2 = np.array(embedding2)
# 余弦相似度
dot_product = np.dot(vec1, vec2)
norm1 = np.linalg.norm(vec1)
norm2 = np.linalg.norm(vec2)
dot_product = np.dot(vec1, vec2)
norm1 = np.linalg.norm(vec1)
norm2 = np.linalg.norm(vec2)
if norm1 == 0 or norm2 == 0:
return 0.0
return float(dot_product / (norm1 * norm2))
def get_entity_text(self, name: str, definition: str = "") -> str:
def get_entity_text(self, name: str, definition: str = "") -> str:
"""
构建用于 embedding 的实体文本
@@ -113,9 +113,9 @@ class EntityAligner:
self,
project_id: str,
name: str,
definition: str = "",
exclude_id: str | None = None,
threshold: float | None = None,
definition: str = "",
exclude_id: str | None = None,
threshold: float | None = None,
) -> object | None:
"""
查找相似的实体
@@ -131,54 +131,54 @@ class EntityAligner:
相似的实体或 None
"""
if threshold is None:
threshold = self.similarity_threshold
threshold = self.similarity_threshold
try:
from db_manager import get_db_manager
db = get_db_manager()
db = get_db_manager()
except ImportError:
return None
# 获取项目的所有实体
entities = db.get_all_entities_for_embedding(project_id)
entities = db.get_all_entities_for_embedding(project_id)
if not entities:
return None
# 获取查询实体的 embedding
query_text = self.get_entity_text(name, definition)
query_embedding = self.get_embedding(query_text)
query_text = self.get_entity_text(name, definition)
query_embedding = self.get_embedding(query_text)
if query_embedding is None:
# 如果 embedding API 失败,回退到简单匹配
return self._fallback_similarity_match(entities, name, exclude_id)
best_match = None
best_score = threshold
best_match = None
best_score = threshold
for entity in entities:
if exclude_id and entity.id == exclude_id:
continue
# 获取实体的 embedding
entity_text = self.get_entity_text(entity.name, entity.definition)
entity_embedding = self.get_embedding(entity_text)
entity_text = self.get_entity_text(entity.name, entity.definition)
entity_embedding = self.get_embedding(entity_text)
if entity_embedding is None:
continue
# 计算相似度
similarity = self.compute_similarity(query_embedding, entity_embedding)
similarity = self.compute_similarity(query_embedding, entity_embedding)
if similarity > best_score:
best_score = similarity
best_match = entity
best_score = similarity
best_match = entity
return best_match
def _fallback_similarity_match(
self, entities: list[object], name: str, exclude_id: str | None = None
self, entities: list[object], name: str, exclude_id: str | None = None
) -> object | None:
"""
回退到简单的相似度匹配(不使用 embedding
@@ -191,7 +191,7 @@ class EntityAligner:
Returns:
最相似的实体或 None
"""
name_lower = name.lower()
name_lower = name.lower()
# 1. 精确匹配
for entity in entities:
@@ -212,7 +212,7 @@ class EntityAligner:
return None
def batch_align_entities(
self, project_id: str, new_entities: list[dict], threshold: float | None = None
self, project_id: str, new_entities: list[dict], threshold: float | None = None
) -> list[dict]:
"""
批量对齐实体
@@ -226,16 +226,16 @@ class EntityAligner:
对齐结果列表 [{"new_entity": {...}, "matched_entity": {...}, "similarity": 0.9}]
"""
if threshold is None:
threshold = self.similarity_threshold
threshold = self.similarity_threshold
results = []
results = []
for new_ent in new_entities:
matched = self.find_similar_entity(
project_id, new_ent["name"], new_ent.get("definition", ""), threshold=threshold
matched = self.find_similar_entity(
project_id, new_ent["name"], new_ent.get("definition", ""), threshold = threshold
)
result = {
result = {
"new_entity": new_ent,
"matched_entity": None,
"similarity": 0.0,
@@ -244,28 +244,28 @@ class EntityAligner:
if matched:
# 计算相似度
query_text = self.get_entity_text(new_ent["name"], new_ent.get("definition", ""))
matched_text = self.get_entity_text(matched.name, matched.definition)
query_text = self.get_entity_text(new_ent["name"], new_ent.get("definition", ""))
matched_text = self.get_entity_text(matched.name, matched.definition)
query_emb = self.get_embedding(query_text)
matched_emb = self.get_embedding(matched_text)
query_emb = self.get_embedding(query_text)
matched_emb = self.get_embedding(matched_text)
if query_emb and matched_emb:
similarity = self.compute_similarity(query_emb, matched_emb)
result["matched_entity"] = {
similarity = self.compute_similarity(query_emb, matched_emb)
result["matched_entity"] = {
"id": matched.id,
"name": matched.name,
"type": matched.type,
"definition": matched.definition,
}
result["similarity"] = similarity
result["should_merge"] = similarity >= threshold
result["similarity"] = similarity
result["should_merge"] = similarity >= threshold
results.append(result)
return results
def suggest_entity_aliases(self, entity_name: str, entity_definition: str = "") -> list[str]:
def suggest_entity_aliases(self, entity_name: str, entity_definition: str = "") -> list[str]:
"""
使用 LLM 建议实体的别名
@@ -279,7 +279,7 @@ class EntityAligner:
if not KIMI_API_KEY:
return []
prompt = f"""为以下实体生成可能的别名或简称:
prompt = f"""为以下实体生成可能的别名或简称:
实体名称:{entity_name}
定义:{entity_definition}
@@ -290,28 +290,28 @@ class EntityAligner:
只返回 JSON不要其他内容。"""
try:
response = httpx.post(
response = httpx.post(
f"{KIMI_BASE_URL}/v1/chat/completions",
headers={
headers = {
"Authorization": f"Bearer {KIMI_API_KEY}",
"Content-Type": "application/json",
},
json={
json = {
"model": "k2p5",
"messages": [{"role": "user", "content": prompt}],
"temperature": 0.3,
},
timeout=30.0,
timeout = 30.0,
)
response.raise_for_status()
result = response.json()
content = result["choices"][0]["message"]["content"]
result = response.json()
content = result["choices"][0]["message"]["content"]
import re
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
if json_match:
data = json.loads(json_match.group())
data = json.loads(json_match.group())
return data.get("aliases", [])
except (httpx.HTTPError, json.JSONDecodeError, KeyError) as e:
print(f"Alias suggestion failed: {e}")
@@ -340,8 +340,8 @@ def simple_similarity(str1: str, str2: str) -> float:
return 0.0
# 转换为小写
s1 = str1.lower()
s2 = str2.lower()
s1 = str1.lower()
s2 = str2.lower()
# 包含关系
if s1 in s2 or s2 in s1:
@@ -355,11 +355,11 @@ def simple_similarity(str1: str, str2: str) -> float:
if __name__ == "__main__":
# 测试
aligner = EntityAligner()
aligner = EntityAligner()
# 测试 embedding
test_text = "Kubernetes 容器编排平台"
embedding = aligner.get_embedding(test_text)
test_text = "Kubernetes 容器编排平台"
embedding = aligner.get_embedding(test_text)
if embedding:
print(f"Embedding dimension: {len(embedding)}")
print(f"First 5 values: {embedding[:5]}")
@@ -367,7 +367,7 @@ if __name__ == "__main__":
print("Embedding API not available")
# 测试相似度计算
emb1 = [1.0, 0.0, 0.0]
emb2 = [0.9, 0.1, 0.0]
sim = aligner.compute_similarity(emb1, emb2)
emb1 = [1.0, 0.0, 0.0]
emb2 = [0.9, 0.1, 0.0]
sim = aligner.compute_similarity(emb1, emb2)
print(f"Similarity: {sim:.4f}")

View File

@@ -14,9 +14,9 @@ from typing import Any
try:
import pandas as pd
PANDAS_AVAILABLE = True
PANDAS_AVAILABLE = True
except ImportError:
PANDAS_AVAILABLE = False
PANDAS_AVAILABLE = False
try:
from reportlab.lib import colors
@@ -32,9 +32,9 @@ try:
TableStyle,
)
REPORTLAB_AVAILABLE = True
REPORTLAB_AVAILABLE = True
except ImportError:
REPORTLAB_AVAILABLE = False
REPORTLAB_AVAILABLE = False
@dataclass
@@ -71,8 +71,8 @@ class ExportTranscript:
class ExportManager:
"""导出管理器 - 处理各种导出需求"""
def __init__(self, db_manager=None):
self.db = db_manager
def __init__(self, db_manager = None) -> None:
self.db = db_manager
def export_knowledge_graph_svg(
self, project_id: str, entities: list[ExportEntity], relations: list[ExportRelation]
@@ -84,21 +84,21 @@ class ExportManager:
SVG 字符串
"""
# 计算布局参数
width = 1200
height = 800
center_x = width / 2
center_y = height / 2
radius = 300
width = 1200
height = 800
center_x = width / 2
center_y = height / 2
radius = 300
# 按类型分组实体
entities_by_type = {}
entities_by_type = {}
for e in entities:
if e.type not in entities_by_type:
entities_by_type[e.type] = []
entities_by_type[e.type] = []
entities_by_type[e.type].append(e)
# 颜色映射
type_colors = {
type_colors = {
"PERSON": "#FF6B6B",
"ORGANIZATION": "#4ECDC4",
"LOCATION": "#45B7D1",
@@ -110,110 +110,110 @@ class ExportManager:
}
# 计算实体位置
entity_positions = {}
angle_step = 2 * 3.14159 / max(len(entities), 1)
entity_positions = {}
angle_step = 2 * 3.14159 / max(len(entities), 1)
for i, entity in enumerate(entities):
i * angle_step
x = center_x + radius * 0.8 * (i % 3 - 1) * 150 + (i // 3) * 50
y = center_y + radius * 0.6 * ((i % 6) - 3) * 80
entity_positions[entity.id] = (x, y)
x = center_x + radius * 0.8 * (i % 3 - 1) * 150 + (i // 3) * 50
y = center_y + radius * 0.6 * ((i % 6) - 3) * 80
entity_positions[entity.id] = (x, y)
# 生成 SVG
svg_parts = [
f'<svg xmlns="http://www.w3.org/2000/svg" width="{width}" height="{height}" '
f'viewBox="0 0 {width} {height}">',
svg_parts = [
f'<svg xmlns = "http://www.w3.org/2000/svg" width = "{width}" height = "{height}" '
f'viewBox = "0 0 {width} {height}">',
"<defs>",
' <marker id="arrowhead" markerWidth="10" markerHeight="7" '
'refX="9" refY="3.5" orient="auto">',
' <polygon points="0 0, 10 3.5, 0 7" fill="#7f8c8d"/>',
' <marker id = "arrowhead" markerWidth = "10" markerHeight = "7" '
'refX = "9" refY = "3.5" orient = "auto">',
' <polygon points = "0 0, 10 3.5, 0 7" fill = "#7f8c8d"/>',
" </marker>",
"</defs>",
f'<rect width="{width}" height="{height}" fill="#f8f9fa"/>',
f'<text x="{center_x}" y="30" text-anchor="middle" font-size="20" '
f'font-weight="bold" fill="#2c3e50">知识图谱 - {project_id}</text>',
f'<rect width = "{width}" height = "{height}" fill = "#f8f9fa"/>',
f'<text x = "{center_x}" y = "30" text-anchor = "middle" font-size = "20" '
f'font-weight = "bold" fill = "#2c3e50">知识图谱 - {project_id}</text>',
]
# 绘制关系连线
for rel in relations:
if rel.source in entity_positions and rel.target in entity_positions:
x1, y1 = entity_positions[rel.source]
x2, y2 = entity_positions[rel.target]
x1, y1 = entity_positions[rel.source]
x2, y2 = entity_positions[rel.target]
# 计算箭头终点(避免覆盖节点)
dx = x2 - x1
dy = y2 - y1
dist = (dx**2 + dy**2) ** 0.5
dx = x2 - x1
dy = y2 - y1
dist = (dx**2 + dy**2) ** 0.5
if dist > 0:
offset = 40
x2 = x2 - dx * offset / dist
y2 = y2 - dy * offset / dist
offset = 40
x2 = x2 - dx * offset / dist
y2 = y2 - dy * offset / dist
svg_parts.append(
f'<line x1="{x1}" y1="{y1}" x2="{x2}" y2="{y2}" '
f'stroke="#7f8c8d" stroke-width="2" marker-end="url(#arrowhead)" opacity="0.6"/>'
f'<line x1 = "{x1}" y1 = "{y1}" x2 = "{x2}" y2 = "{y2}" '
f'stroke = "#7f8c8d" stroke-width = "2" marker-end = "url(#arrowhead)" opacity = "0.6"/>'
)
# 关系标签
mid_x = (x1 + x2) / 2
mid_y = (y1 + y2) / 2
mid_x = (x1 + x2) / 2
mid_y = (y1 + y2) / 2
svg_parts.append(
f'<rect x="{mid_x - 30}" y="{mid_y - 10}" width="60" height="20" '
f'fill="white" stroke="#bdc3c7" rx="3"/>'
f'<rect x = "{mid_x - 30}" y = "{mid_y - 10}" width = "60" height = "20" '
f'fill = "white" stroke = "#bdc3c7" rx = "3"/>'
)
svg_parts.append(
f'<text x="{mid_x}" y="{mid_y + 5}" text-anchor="middle" '
f'font-size="10" fill="#2c3e50">{rel.relation_type}</text>'
f'<text x = "{mid_x}" y = "{mid_y + 5}" text-anchor = "middle" '
f'font-size = "10" fill = "#2c3e50">{rel.relation_type}</text>'
)
# 绘制实体节点
for entity in entities:
if entity.id in entity_positions:
x, y = entity_positions[entity.id]
color = type_colors.get(entity.type, type_colors["default"])
x, y = entity_positions[entity.id]
color = type_colors.get(entity.type, type_colors["default"])
# 节点圆圈
svg_parts.append(
f'<circle cx="{x}" cy="{y}" r="35" fill="{color}" stroke="white" stroke-width="3"/>'
f'<circle cx = "{x}" cy = "{y}" r = "35" fill = "{color}" stroke = "white" stroke-width = "3"/>'
)
# 实体名称
svg_parts.append(
f'<text x="{x}" y="{y + 5}" text-anchor="middle" font-size="12" '
f'font-weight="bold" fill="white">{entity.name[:8]}</text>'
f'<text x = "{x}" y = "{y + 5}" text-anchor = "middle" font-size = "12" '
f'font-weight = "bold" fill = "white">{entity.name[:8]}</text>'
)
# 实体类型
svg_parts.append(
f'<text x="{x}" y="{y + 55}" text-anchor="middle" font-size="10" '
f'fill="#7f8c8d">{entity.type}</text>'
f'<text x = "{x}" y = "{y + 55}" text-anchor = "middle" font-size = "10" '
f'fill = "#7f8c8d">{entity.type}</text>'
)
# 图例
legend_x = width - 150
legend_y = 80
rect_x = legend_x - 10
rect_y = legend_y - 20
rect_height = len(type_colors) * 25 + 10
legend_x = width - 150
legend_y = 80
rect_x = legend_x - 10
rect_y = legend_y - 20
rect_height = len(type_colors) * 25 + 10
svg_parts.append(
f'<rect x="{rect_x}" y="{rect_y}" width="140" height="{rect_height}" '
f'fill="white" stroke="#bdc3c7" rx="5"/>'
f'<rect x = "{rect_x}" y = "{rect_y}" width = "140" height = "{rect_height}" '
f'fill = "white" stroke = "#bdc3c7" rx = "5"/>'
)
svg_parts.append(
f'<text x="{legend_x}" y="{legend_y}" font-size="12" font-weight="bold" '
f'fill="#2c3e50">实体类型</text>'
f'<text x = "{legend_x}" y = "{legend_y}" font-size = "12" font-weight = "bold" '
f'fill = "#2c3e50">实体类型</text>'
)
for i, (etype, color) in enumerate(type_colors.items()):
if etype != "default":
y_pos = legend_y + 25 + i * 20
y_pos = legend_y + 25 + i * 20
svg_parts.append(
f'<circle cx="{legend_x + 10}" cy="{y_pos}" r="8" fill="{color}"/>'
f'<circle cx = "{legend_x + 10}" cy = "{y_pos}" r = "8" fill = "{color}"/>'
)
text_y = y_pos + 4
text_y = y_pos + 4
svg_parts.append(
f'<text x="{legend_x + 25}" y="{text_y}" font-size="10" '
f'fill="#2c3e50">{etype}</text>'
f'<text x = "{legend_x + 25}" y = "{text_y}" font-size = "10" '
f'fill = "#2c3e50">{etype}</text>'
)
svg_parts.append("</svg>")
@@ -231,12 +231,12 @@ class ExportManager:
try:
import cairosvg
svg_content = self.export_knowledge_graph_svg(project_id, entities, relations)
png_bytes = cairosvg.svg2png(bytestring=svg_content.encode("utf-8"))
svg_content = self.export_knowledge_graph_svg(project_id, entities, relations)
png_bytes = cairosvg.svg2png(bytestring = svg_content.encode("utf-8"))
return png_bytes
except ImportError:
# 如果没有 cairosvg返回 SVG 的 base64
svg_content = self.export_knowledge_graph_svg(project_id, entities, relations)
svg_content = self.export_knowledge_graph_svg(project_id, entities, relations)
return base64.b64encode(svg_content.encode("utf-8"))
def export_entities_excel(self, entities: list[ExportEntity]) -> bytes:
@@ -250,9 +250,9 @@ class ExportManager:
raise ImportError("pandas is required for Excel export")
# 准备数据
data = []
data = []
for e in entities:
row = {
row = {
"ID": e.id,
"名称": e.name,
"类型": e.type,
@@ -262,29 +262,29 @@ class ExportManager:
}
# 添加属性
for attr_name, attr_value in e.attributes.items():
row[f"属性:{attr_name}"] = attr_value
row[f"属性:{attr_name}"] = attr_value
data.append(row)
df = pd.DataFrame(data)
df = pd.DataFrame(data)
# 写入 Excel
output = io.BytesIO()
with pd.ExcelWriter(output, engine="openpyxl") as writer:
df.to_excel(writer, sheet_name="实体列表", index=False)
output = io.BytesIO()
with pd.ExcelWriter(output, engine = "openpyxl") as writer:
df.to_excel(writer, sheet_name = "实体列表", index = False)
# 调整列宽
worksheet = writer.sheets["实体列表"]
worksheet = writer.sheets["实体列表"]
for column in worksheet.columns:
max_length = 0
column_letter = column[0].column_letter
max_length = 0
column_letter = column[0].column_letter
for cell in column:
try:
if len(str(cell.value)) > max_length:
max_length = len(str(cell.value))
max_length = len(str(cell.value))
except (AttributeError, TypeError, ValueError):
pass
adjusted_width = min(max_length + 2, 50)
worksheet.column_dimensions[column_letter].width = adjusted_width
adjusted_width = min(max_length + 2, 50)
worksheet.column_dimensions[column_letter].width = adjusted_width
return output.getvalue()
@@ -295,24 +295,24 @@ class ExportManager:
Returns:
CSV 字符串
"""
output = io.StringIO()
output = io.StringIO()
# 收集所有可能的属性列
all_attrs = set()
all_attrs = set()
for e in entities:
all_attrs.update(e.attributes.keys())
# 表头
headers = ["ID", "名称", "类型", "定义", "别名", "提及次数"] + [
headers = ["ID", "名称", "类型", "定义", "别名", "提及次数"] + [
f"属性:{a}" for a in sorted(all_attrs)
]
writer = csv.writer(output)
writer = csv.writer(output)
writer.writerow(headers)
# 数据行
for e in entities:
row = [e.id, e.name, e.type, e.definition, ", ".join(e.aliases), e.mention_count]
row = [e.id, e.name, e.type, e.definition, ", ".join(e.aliases), e.mention_count]
for attr in sorted(all_attrs):
row.append(e.attributes.get(attr, ""))
writer.writerow(row)
@@ -327,8 +327,8 @@ class ExportManager:
CSV 字符串
"""
output = io.StringIO()
writer = csv.writer(output)
output = io.StringIO()
writer = csv.writer(output)
writer.writerow(["ID", "源实体", "目标实体", "关系类型", "置信度", "证据"])
for r in relations:
@@ -345,7 +345,7 @@ class ExportManager:
Returns:
Markdown 字符串
"""
lines = [
lines = [
f"# {transcript.name}",
"",
f"**类型**: {transcript.type}",
@@ -369,10 +369,10 @@ class ExportManager:
]
)
for seg in transcript.segments:
speaker = seg.get("speaker", "Unknown")
start = seg.get("start", 0)
end = seg.get("end", 0)
text = seg.get("text", "")
speaker = seg.get("speaker", "Unknown")
start = seg.get("start", 0)
end = seg.get("end", 0)
text = seg.get("text", "")
lines.append(f"**[{start:.1f}s - {end:.1f}s] {speaker}**: {text}")
lines.append("")
@@ -387,12 +387,12 @@ class ExportManager:
]
)
for mention in transcript.entity_mentions:
entity_id = mention.get("entity_id", "")
entity = entities_map.get(entity_id)
entity_name = entity.name if entity else mention.get("entity_name", "Unknown")
entity_type = entity.type if entity else "Unknown"
position = mention.get("position", "")
context = mention.get("context", "")[:50] + "..." if mention.get("context") else ""
entity_id = mention.get("entity_id", "")
entity = entities_map.get(entity_id)
entity_name = entity.name if entity else mention.get("entity_name", "Unknown")
entity_type = entity.type if entity else "Unknown"
position = mention.get("position", "")
context = mention.get("context", "")[:50] + "..." if mention.get("context") else ""
lines.append(f"| {entity_name} | {entity_type} | {position} | {context} |")
return "\n".join(lines)
@@ -404,7 +404,7 @@ class ExportManager:
entities: list[ExportEntity],
relations: list[ExportRelation],
transcripts: list[ExportTranscript],
summary: str = "",
summary: str = "",
) -> bytes:
"""
导出项目报告为 PDF 格式
@@ -415,29 +415,29 @@ class ExportManager:
if not REPORTLAB_AVAILABLE:
raise ImportError("reportlab is required for PDF export")
output = io.BytesIO()
doc = SimpleDocTemplate(
output, pagesize=A4, rightMargin=72, leftMargin=72, topMargin=72, bottomMargin=18
output = io.BytesIO()
doc = SimpleDocTemplate(
output, pagesize = A4, rightMargin = 72, leftMargin = 72, topMargin = 72, bottomMargin = 18
)
# 样式
styles = getSampleStyleSheet()
title_style = ParagraphStyle(
styles = getSampleStyleSheet()
title_style = ParagraphStyle(
"CustomTitle",
parent=styles["Heading1"],
fontSize=24,
spaceAfter=30,
textColor=colors.HexColor("#2c3e50"),
parent = styles["Heading1"],
fontSize = 24,
spaceAfter = 30,
textColor = colors.HexColor("#2c3e50"),
)
heading_style = ParagraphStyle(
heading_style = ParagraphStyle(
"CustomHeading",
parent=styles["Heading2"],
fontSize=16,
spaceAfter=12,
textColor=colors.HexColor("#34495e"),
parent = styles["Heading2"],
fontSize = 16,
spaceAfter = 12,
textColor = colors.HexColor("#34495e"),
)
story = []
story = []
# 标题页
story.append(Paragraph("InsightFlow 项目报告", title_style))
@@ -452,7 +452,7 @@ class ExportManager:
# 统计概览
story.append(Paragraph("项目概览", heading_style))
stats_data = [
stats_data = [
["指标", "数值"],
["实体数量", str(len(entities))],
["关系数量", str(len(relations))],
@@ -460,14 +460,14 @@ class ExportManager:
]
# 按类型统计实体
type_counts = {}
type_counts = {}
for e in entities:
type_counts[e.type] = type_counts.get(e.type, 0) + 1
type_counts[e.type] = type_counts.get(e.type, 0) + 1
for etype, count in sorted(type_counts.items()):
stats_data.append([f"{etype} 实体", str(count)])
stats_table = Table(stats_data, colWidths=[3 * inch, 2 * inch])
stats_table = Table(stats_data, colWidths = [3 * inch, 2 * inch])
stats_table.setStyle(
TableStyle(
[
@@ -496,8 +496,8 @@ class ExportManager:
story.append(PageBreak())
story.append(Paragraph("实体列表", heading_style))
entity_data = [["名称", "类型", "提及次数", "定义"]]
for e in sorted(entities, key=lambda x: x.mention_count, reverse=True)[
entity_data = [["名称", "类型", "提及次数", "定义"]]
for e in sorted(entities, key = lambda x: x.mention_count, reverse = True)[
:50
]: # 限制前50个
entity_data.append(
@@ -509,8 +509,8 @@ class ExportManager:
]
)
entity_table = Table(
entity_data, colWidths=[1.5 * inch, 1 * inch, 1 * inch, 2.5 * inch]
entity_table = Table(
entity_data, colWidths = [1.5 * inch, 1 * inch, 1 * inch, 2.5 * inch]
)
entity_table.setStyle(
TableStyle(
@@ -534,12 +534,12 @@ class ExportManager:
story.append(PageBreak())
story.append(Paragraph("关系列表", heading_style))
relation_data = [["源实体", "关系", "目标实体", "置信度"]]
relation_data = [["源实体", "关系", "目标实体", "置信度"]]
for r in relations[:100]: # 限制前100个
relation_data.append([r.source, r.relation_type, r.target, f"{r.confidence:.2f}"])
relation_table = Table(
relation_data, colWidths=[2 * inch, 1.5 * inch, 2 * inch, 1 * inch]
relation_table = Table(
relation_data, colWidths = [2 * inch, 1.5 * inch, 2 * inch, 1 * inch]
)
relation_table.setStyle(
TableStyle(
@@ -574,7 +574,7 @@ class ExportManager:
Returns:
JSON 字符串
"""
data = {
data = {
"project_id": project_id,
"project_name": project_name,
"export_time": datetime.now().isoformat(),
@@ -613,16 +613,16 @@ class ExportManager:
],
}
return json.dumps(data, ensure_ascii=False, indent=2)
return json.dumps(data, ensure_ascii = False, indent = 2)
# 全局导出管理器实例
_export_manager = None
_export_manager = None
def get_export_manager(db_manager=None) -> None:
def get_export_manager(db_manager = None) -> None:
"""获取导出管理器实例"""
global _export_manager
if _export_manager is None:
_export_manager = ExportManager(db_manager)
_export_manager = ExportManager(db_manager)
return _export_manager

File diff suppressed because it is too large Load Diff

View File

@@ -11,30 +11,30 @@ import uuid
from dataclasses import dataclass
# Constants
UUID_LENGTH = 8 # UUID 截断长度
UUID_LENGTH = 8 # UUID 截断长度
# 尝试导入图像处理库
try:
from PIL import Image, ImageEnhance, ImageFilter
PIL_AVAILABLE = True
PIL_AVAILABLE = True
except ImportError:
PIL_AVAILABLE = False
PIL_AVAILABLE = False
try:
import cv2
import numpy as np
CV2_AVAILABLE = True
CV2_AVAILABLE = True
except ImportError:
CV2_AVAILABLE = False
CV2_AVAILABLE = False
try:
import pytesseract
PYTESSERACT_AVAILABLE = True
PYTESSERACT_AVAILABLE = True
except ImportError:
PYTESSERACT_AVAILABLE = False
PYTESSERACT_AVAILABLE = False
@dataclass
@@ -44,7 +44,7 @@ class ImageEntity:
name: str
type: str
confidence: float
bbox: tuple[int, int, int, int] | None = None # (x, y, width, height)
bbox: tuple[int, int, int, int] | None = None # (x, y, width, height)
@dataclass
@@ -70,7 +70,7 @@ class ImageProcessingResult:
width: int
height: int
success: bool
error_message: str = ""
error_message: str = ""
@dataclass
@@ -87,7 +87,7 @@ class ImageProcessor:
"""图片处理器 - 处理各种类型图片"""
# 图片类型定义
IMAGE_TYPES = {
IMAGE_TYPES = {
"whiteboard": "白板",
"ppt": "PPT/演示文稿",
"handwritten": "手写笔记",
@@ -96,17 +96,17 @@ class ImageProcessor:
"other": "其他",
}
def __init__(self, temp_dir: str = None) -> None:
def __init__(self, temp_dir: str = None) -> None:
"""
初始化图片处理器
Args:
temp_dir: 临时文件目录
"""
self.temp_dir = temp_dir or os.path.join(os.getcwd(), "temp", "images")
os.makedirs(self.temp_dir, exist_ok=True)
self.temp_dir = temp_dir or os.path.join(os.getcwd(), "temp", "images")
os.makedirs(self.temp_dir, exist_ok = True)
def preprocess_image(self, image, image_type: str = None) -> None:
def preprocess_image(self, image, image_type: str = None) -> None:
"""
预处理图片以提高OCR质量
@@ -123,25 +123,25 @@ class ImageProcessor:
try:
# 转换为RGB如果是RGBA
if image.mode == "RGBA":
image = image.convert("RGB")
image = image.convert("RGB")
# 根据图片类型进行针对性处理
if image_type == "whiteboard":
# 白板:增强对比度,去除背景
image = self._enhance_whiteboard(image)
image = self._enhance_whiteboard(image)
elif image_type == "handwritten":
# 手写笔记:降噪,增强对比度
image = self._enhance_handwritten(image)
image = self._enhance_handwritten(image)
elif image_type == "screenshot":
# 截图:轻微锐化
image = image.filter(ImageFilter.SHARPEN)
image = image.filter(ImageFilter.SHARPEN)
# 通用处理:调整大小(如果太大)
max_size = 4096
max_size = 4096
if max(image.size) > max_size:
ratio = max_size / max(image.size)
new_size = (int(image.size[0] * ratio), int(image.size[1] * ratio))
image = image.resize(new_size, Image.Resampling.LANCZOS)
ratio = max_size / max(image.size)
new_size = (int(image.size[0] * ratio), int(image.size[1] * ratio))
image = image.resize(new_size, Image.Resampling.LANCZOS)
return image
except Exception as e:
@@ -151,33 +151,33 @@ class ImageProcessor:
def _enhance_whiteboard(self, image) -> None:
"""增强白板图片"""
# 转换为灰度
gray = image.convert("L")
gray = image.convert("L")
# 增强对比度
enhancer = ImageEnhance.Contrast(gray)
enhanced = enhancer.enhance(2.0)
enhancer = ImageEnhance.Contrast(gray)
enhanced = enhancer.enhance(2.0)
# 二值化
threshold = 128
binary = enhanced.point(lambda x: 0 if x < threshold else 255, "1")
threshold = 128
binary = enhanced.point(lambda x: 0 if x < threshold else 255, "1")
return binary.convert("L")
def _enhance_handwritten(self, image) -> None:
"""增强手写笔记图片"""
# 转换为灰度
gray = image.convert("L")
gray = image.convert("L")
# 轻微降噪
blurred = gray.filter(ImageFilter.GaussianBlur(radius=1))
blurred = gray.filter(ImageFilter.GaussianBlur(radius = 1))
# 增强对比度
enhancer = ImageEnhance.Contrast(blurred)
enhanced = enhancer.enhance(1.5)
enhancer = ImageEnhance.Contrast(blurred)
enhanced = enhancer.enhance(1.5)
return enhanced
def detect_image_type(self, image, ocr_text: str = "") -> str:
def detect_image_type(self, image, ocr_text: str = "") -> str:
"""
自动检测图片类型
@@ -193,8 +193,8 @@ class ImageProcessor:
try:
# 基于图片特征和OCR内容判断类型
width, height = image.size
aspect_ratio = width / height
width, height = image.size
aspect_ratio = width / height
# 检测是否为PPT通常是16:9或4:3
if 1.3 <= aspect_ratio <= 1.8:
@@ -204,12 +204,12 @@ class ImageProcessor:
# 检测是否为白板(大量手写文字,可能有箭头、框等)
if CV2_AVAILABLE:
img_array = np.array(image.convert("RGB"))
gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
img_array = np.array(image.convert("RGB"))
gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
# 检测边缘(白板通常有很多线条)
edges = cv2.Canny(gray, 50, 150)
edge_ratio = np.sum(edges > 0) / edges.size
edges = cv2.Canny(gray, 50, 150)
edge_ratio = np.sum(edges > 0) / edges.size
# 如果边缘比例高,可能是白板
if edge_ratio > 0.05 and len(ocr_text) > 50:
@@ -236,7 +236,7 @@ class ImageProcessor:
print(f"Image type detection error: {e}")
return "other"
def perform_ocr(self, image, lang: str = "chi_sim+eng") -> tuple[str, float]:
def perform_ocr(self, image, lang: str = "chi_sim+eng") -> tuple[str, float]:
"""
对图片进行OCR识别
@@ -252,15 +252,15 @@ class ImageProcessor:
try:
# 预处理图片
processed_image = self.preprocess_image(image)
processed_image = self.preprocess_image(image)
# 执行OCR
text = pytesseract.image_to_string(processed_image, lang=lang)
text = pytesseract.image_to_string(processed_image, lang = lang)
# 获取置信度
data = pytesseract.image_to_data(processed_image, output_type=pytesseract.Output.DICT)
confidences = [int(c) for c in data["conf"] if int(c) > 0]
avg_confidence = sum(confidences) / len(confidences) if confidences else 0
data = pytesseract.image_to_data(processed_image, output_type = pytesseract.Output.DICT)
confidences = [int(c) for c in data["conf"] if int(c) > 0]
avg_confidence = sum(confidences) / len(confidences) if confidences else 0
return text.strip(), avg_confidence / 100.0
except Exception as e:
@@ -277,26 +277,26 @@ class ImageProcessor:
Returns:
实体列表
"""
entities = []
entities = []
# 简单的实体提取规则可以替换为LLM调用
# 提取大写字母开头的词组(可能是专有名词)
import re
# 项目名称(通常是大写或带引号)
project_pattern = r'["\']([^"\']+)["\']|([A-Z][a-zA-Z0-9]*(?:\s+[A-Z][a-zA-Z0-9]*)+)'
project_pattern = r'["\']([^"\']+)["\']|([A-Z][a-zA-Z0-9]*(?:\s+[A-Z][a-zA-Z0-9]*)+)'
for match in re.finditer(project_pattern, text):
name = match.group(1) or match.group(2)
name = match.group(1) or match.group(2)
if name and len(name) > 2:
entities.append(ImageEntity(name=name.strip(), type="PROJECT", confidence=0.7))
entities.append(ImageEntity(name = name.strip(), type = "PROJECT", confidence = 0.7))
# 人名(中文)
name_pattern = r"([\u4e00-\u9fa5]{2,4})(?:先生|女士|总|经理|工程师|老师)"
name_pattern = r"([\u4e00-\u9fa5]{2, 4})(?:先生|女士|总|经理|工程师|老师)"
for match in re.finditer(name_pattern, text):
entities.append(ImageEntity(name=match.group(1), type="PERSON", confidence=0.8))
entities.append(ImageEntity(name = match.group(1), type = "PERSON", confidence = 0.8))
# 技术术语
tech_keywords = [
tech_keywords = [
"K8s",
"Kubernetes",
"Docker",
@@ -314,13 +314,13 @@ class ImageProcessor:
]
for keyword in tech_keywords:
if keyword in text:
entities.append(ImageEntity(name=keyword, type="TECH", confidence=0.9))
entities.append(ImageEntity(name = keyword, type = "TECH", confidence = 0.9))
# 去重
seen = set()
unique_entities = []
seen = set()
unique_entities = []
for e in entities:
key = (e.name.lower(), e.type)
key = (e.name.lower(), e.type)
if key not in seen:
seen.add(key)
unique_entities.append(e)
@@ -341,19 +341,19 @@ class ImageProcessor:
Returns:
图片描述
"""
type_name = self.IMAGE_TYPES.get(image_type, "图片")
type_name = self.IMAGE_TYPES.get(image_type, "图片")
description_parts = [f"这是一张{type_name}图片。"]
description_parts = [f"这是一张{type_name}图片。"]
if ocr_text:
# 提取前200字符作为摘要
text_preview = ocr_text[:200].replace("\n", " ")
text_preview = ocr_text[:200].replace("\n", " ")
if len(ocr_text) > 200:
text_preview += "..."
description_parts.append(f"内容摘要:{text_preview}")
if entities:
entity_names = [e.name for e in entities[:5]] # 最多显示5个实体
entity_names = [e.name for e in entities[:5]] # 最多显示5个实体
description_parts.append(f"识别到的关键实体:{', '.join(entity_names)}")
return " ".join(description_parts)
@@ -361,9 +361,9 @@ class ImageProcessor:
def process_image(
self,
image_data: bytes,
filename: str = None,
image_id: str = None,
detect_type: bool = True,
filename: str = None,
image_id: str = None,
detect_type: bool = True,
) -> ImageProcessingResult:
"""
处理单张图片
@@ -377,73 +377,73 @@ class ImageProcessor:
Returns:
图片处理结果
"""
image_id = image_id or str(uuid.uuid4())[:UUID_LENGTH]
image_id = image_id or str(uuid.uuid4())[:UUID_LENGTH]
if not PIL_AVAILABLE:
return ImageProcessingResult(
image_id=image_id,
image_type="other",
ocr_text="",
description="PIL not available",
entities=[],
relations=[],
width=0,
height=0,
success=False,
error_message="PIL library not available",
image_id = image_id,
image_type = "other",
ocr_text = "",
description = "PIL not available",
entities = [],
relations = [],
width = 0,
height = 0,
success = False,
error_message = "PIL library not available",
)
try:
# 加载图片
image = Image.open(io.BytesIO(image_data))
width, height = image.size
image = Image.open(io.BytesIO(image_data))
width, height = image.size
# 执行OCR
ocr_text, ocr_confidence = self.perform_ocr(image)
ocr_text, ocr_confidence = self.perform_ocr(image)
# 检测图片类型
image_type = "other"
image_type = "other"
if detect_type:
image_type = self.detect_image_type(image, ocr_text)
image_type = self.detect_image_type(image, ocr_text)
# 提取实体
entities = self.extract_entities_from_text(ocr_text)
entities = self.extract_entities_from_text(ocr_text)
# 生成描述
description = self.generate_description(image_type, ocr_text, entities)
description = self.generate_description(image_type, ocr_text, entities)
# 提取关系(基于实体共现)
relations = self._extract_relations(entities, ocr_text)
relations = self._extract_relations(entities, ocr_text)
# 保存图片文件(可选)
if filename:
save_path = os.path.join(self.temp_dir, f"{image_id}_{filename}")
save_path = os.path.join(self.temp_dir, f"{image_id}_{filename}")
image.save(save_path)
return ImageProcessingResult(
image_id=image_id,
image_type=image_type,
ocr_text=ocr_text,
description=description,
entities=entities,
relations=relations,
width=width,
height=height,
success=True,
image_id = image_id,
image_type = image_type,
ocr_text = ocr_text,
description = description,
entities = entities,
relations = relations,
width = width,
height = height,
success = True,
)
except Exception as e:
return ImageProcessingResult(
image_id=image_id,
image_type="other",
ocr_text="",
description="",
entities=[],
relations=[],
width=0,
height=0,
success=False,
error_message=str(e),
image_id = image_id,
image_type = "other",
ocr_text = "",
description = "",
entities = [],
relations = [],
width = 0,
height = 0,
success = False,
error_message = str(e),
)
def _extract_relations(self, entities: list[ImageEntity], text: str) -> list[ImageRelation]:
@@ -457,16 +457,16 @@ class ImageProcessor:
Returns:
关系列表
"""
relations = []
relations = []
if len(entities) < 2:
return relations
# 简单的关系提取:如果两个实体在同一句子中出现,则认为它们相关
sentences = text.replace("", ".").replace("", "!").replace("", "?").split(".")
sentences = text.replace("", ".").replace("", "!").replace("", "?").split(".")
for sentence in sentences:
sentence_entities = []
sentence_entities = []
for entity in entities:
if entity.name in sentence:
sentence_entities.append(entity)
@@ -477,17 +477,17 @@ class ImageProcessor:
for j in range(i + 1, len(sentence_entities)):
relations.append(
ImageRelation(
source=sentence_entities[i].name,
target=sentence_entities[j].name,
relation_type="related",
confidence=0.5,
source = sentence_entities[i].name,
target = sentence_entities[j].name,
relation_type = "related",
confidence = 0.5,
)
)
return relations
def process_batch(
self, images_data: list[tuple[bytes, str]], project_id: str = None
self, images_data: list[tuple[bytes, str]], project_id: str = None
) -> BatchProcessingResult:
"""
批量处理图片
@@ -499,12 +499,12 @@ class ImageProcessor:
Returns:
批量处理结果
"""
results = []
success_count = 0
failed_count = 0
results = []
success_count = 0
failed_count = 0
for image_data, filename in images_data:
result = self.process_image(image_data, filename)
result = self.process_image(image_data, filename)
results.append(result)
if result.success:
@@ -513,10 +513,10 @@ class ImageProcessor:
failed_count += 1
return BatchProcessingResult(
results=results,
total_count=len(results),
success_count=success_count,
failed_count=failed_count,
results = results,
total_count = len(results),
success_count = success_count,
failed_count = failed_count,
)
def image_to_base64(self, image_data: bytes) -> str:
@@ -531,7 +531,7 @@ class ImageProcessor:
"""
return base64.b64encode(image_data).decode("utf-8")
def get_image_thumbnail(self, image_data: bytes, size: tuple[int, int] = (200, 200)) -> bytes:
def get_image_thumbnail(self, image_data: bytes, size: tuple[int, int] = (200, 200)) -> bytes:
"""
生成图片缩略图
@@ -546,11 +546,11 @@ class ImageProcessor:
return image_data
try:
image = Image.open(io.BytesIO(image_data))
image = Image.open(io.BytesIO(image_data))
image.thumbnail(size, Image.Resampling.LANCZOS)
buffer = io.BytesIO()
image.save(buffer, format="JPEG")
buffer = io.BytesIO()
image.save(buffer, format = "JPEG")
return buffer.getvalue()
except Exception as e:
print(f"Thumbnail generation error: {e}")
@@ -558,12 +558,12 @@ class ImageProcessor:
# Singleton instance
_image_processor = None
_image_processor = None
def get_image_processor(temp_dir: str = None) -> ImageProcessor:
def get_image_processor(temp_dir: str = None) -> ImageProcessor:
"""获取图片处理器单例"""
global _image_processor
if _image_processor is None:
_image_processor = ImageProcessor(temp_dir)
_image_processor = ImageProcessor(temp_dir)
return _image_processor

View File

@@ -4,27 +4,27 @@
import os
import sqlite3
db_path = os.path.join(os.path.dirname(__file__), "insightflow.db")
schema_path = os.path.join(os.path.dirname(__file__), "schema.sql")
db_path = os.path.join(os.path.dirname(__file__), "insightflow.db")
schema_path = os.path.join(os.path.dirname(__file__), "schema.sql")
print(f"Database path: {db_path}")
print(f"Schema path: {schema_path}")
# Read schema
with open(schema_path) as f:
schema = f.read()
schema = f.read()
# Execute schema
conn = sqlite3.connect(db_path)
cursor = conn.cursor()
conn = sqlite3.connect(db_path)
cursor = conn.cursor()
# Split schema by semicolons and execute each statement
statements = schema.split(";")
success_count = 0
error_count = 0
statements = schema.split(";")
success_count = 0
error_count = 0
for stmt in statements:
stmt = stmt.strip()
stmt = stmt.strip()
if stmt:
try:
cursor.execute(stmt)

View File

@@ -12,18 +12,18 @@ from enum import Enum
import httpx
KIMI_API_KEY = os.getenv("KIMI_API_KEY", "")
KIMI_BASE_URL = os.getenv("KIMI_BASE_URL", "https://api.kimi.com/coding")
KIMI_API_KEY = os.getenv("KIMI_API_KEY", "")
KIMI_BASE_URL = os.getenv("KIMI_BASE_URL", "https://api.kimi.com/coding")
class ReasoningType(Enum):
"""推理类型"""
CAUSAL = "causal" # 因果推理
ASSOCIATIVE = "associative" # 关联推理
TEMPORAL = "temporal" # 时序推理
COMPARATIVE = "comparative" # 对比推理
SUMMARY = "summary" # 总结推理
CAUSAL = "causal" # 因果推理
ASSOCIATIVE = "associative" # 关联推理
TEMPORAL = "temporal" # 时序推理
COMPARATIVE = "comparative" # 对比推理
SUMMARY = "summary" # 总结推理
@dataclass
@@ -51,38 +51,38 @@ class InferencePath:
class KnowledgeReasoner:
"""知识推理引擎"""
def __init__(self, api_key: str = None, base_url: str = None):
self.api_key = api_key or KIMI_API_KEY
self.base_url = base_url or KIMI_BASE_URL
self.headers = {
def __init__(self, api_key: str = None, base_url: str = None) -> None:
self.api_key = api_key or KIMI_API_KEY
self.base_url = base_url or KIMI_BASE_URL
self.headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
}
async def _call_llm(self, prompt: str, temperature: float = 0.3) -> str:
async def _call_llm(self, prompt: str, temperature: float = 0.3) -> str:
"""调用 LLM"""
if not self.api_key:
raise ValueError("KIMI_API_KEY not set")
payload = {
payload = {
"model": "k2p5",
"messages": [{"role": "user", "content": prompt}],
"temperature": temperature,
}
async with httpx.AsyncClient() as client:
response = await client.post(
response = await client.post(
f"{self.base_url}/v1/chat/completions",
headers=self.headers,
json=payload,
timeout=120.0,
headers = self.headers,
json = payload,
timeout = 120.0,
)
response.raise_for_status()
result = response.json()
result = response.json()
return result["choices"][0]["message"]["content"]
async def enhanced_qa(
self, query: str, project_context: dict, graph_data: dict, reasoning_depth: str = "medium"
self, query: str, project_context: dict, graph_data: dict, reasoning_depth: str = "medium"
) -> ReasoningResult:
"""
增强问答 - 结合图谱推理的问答
@@ -94,7 +94,7 @@ class KnowledgeReasoner:
reasoning_depth: 推理深度 (shallow/medium/deep)
"""
# 1. 分析问题类型
analysis = await self._analyze_question(query)
analysis = await self._analyze_question(query)
# 2. 根据问题类型选择推理策略
if analysis["type"] == "causal":
@@ -108,7 +108,7 @@ class KnowledgeReasoner:
async def _analyze_question(self, query: str) -> dict:
"""分析问题类型和意图"""
prompt = f"""分析以下问题的类型和意图:
prompt = f"""分析以下问题的类型和意图:
问题:{query}
@@ -127,9 +127,9 @@ class KnowledgeReasoner:
- factual: 事实类问题(是什么、有哪些)
- opinion: 观点类问题(怎么看、态度、评价)"""
content = await self._call_llm(prompt, temperature=0.1)
content = await self._call_llm(prompt, temperature = 0.1)
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
if json_match:
try:
return json.loads(json_match.group())
@@ -144,10 +144,10 @@ class KnowledgeReasoner:
"""因果推理 - 分析原因和影响"""
# 构建因果分析提示
entities_str = json.dumps(graph_data.get("entities", []), ensure_ascii=False, indent=2)
relations_str = json.dumps(graph_data.get("relations", []), ensure_ascii=False, indent=2)
entities_str = json.dumps(graph_data.get("entities", []), ensure_ascii = False, indent = 2)
relations_str = json.dumps(graph_data.get("relations", []), ensure_ascii = False, indent = 2)
prompt = f"""基于以下知识图谱进行因果推理分析:
prompt = f"""基于以下知识图谱进行因果推理分析:
## 问题
{query}
@@ -159,7 +159,7 @@ class KnowledgeReasoner:
{relations_str[:2000]}
## 项目上下文
{json.dumps(project_context, ensure_ascii=False, indent=2)[:1500]}
{json.dumps(project_context, ensure_ascii = False, indent = 2)[:1500]}
请进行因果分析,返回 JSON 格式:
{{
@@ -172,31 +172,31 @@ class KnowledgeReasoner:
"knowledge_gaps": ["缺失信息1"]
}}"""
content = await self._call_llm(prompt, temperature=0.3)
content = await self._call_llm(prompt, temperature = 0.3)
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
if json_match:
try:
data = json.loads(json_match.group())
data = json.loads(json_match.group())
return ReasoningResult(
answer=data.get("answer", ""),
reasoning_type=ReasoningType.CAUSAL,
confidence=data.get("confidence", 0.7),
evidence=[{"text": e} for e in data.get("evidence", [])],
related_entities=[],
gaps=data.get("knowledge_gaps", []),
answer = data.get("answer", ""),
reasoning_type = ReasoningType.CAUSAL,
confidence = data.get("confidence", 0.7),
evidence = [{"text": e} for e in data.get("evidence", [])],
related_entities = [],
gaps = data.get("knowledge_gaps", []),
)
except (json.JSONDecodeError, KeyError):
pass
return ReasoningResult(
answer=content,
reasoning_type=ReasoningType.CAUSAL,
confidence=0.5,
evidence=[],
related_entities=[],
gaps=["无法完成因果推理"],
answer = content,
reasoning_type = ReasoningType.CAUSAL,
confidence = 0.5,
evidence = [],
related_entities = [],
gaps = ["无法完成因果推理"],
)
async def _comparative_reasoning(
@@ -204,16 +204,16 @@ class KnowledgeReasoner:
) -> ReasoningResult:
"""对比推理 - 比较实体间的异同"""
prompt = f"""基于以下知识图谱进行对比分析:
prompt = f"""基于以下知识图谱进行对比分析:
## 问题
{query}
## 实体
{json.dumps(graph_data.get("entities", []), ensure_ascii=False, indent=2)[:2000]}
{json.dumps(graph_data.get("entities", []), ensure_ascii = False, indent = 2)[:2000]}
## 关系
{json.dumps(graph_data.get("relations", []), ensure_ascii=False, indent=2)[:1500]}
{json.dumps(graph_data.get("relations", []), ensure_ascii = False, indent = 2)[:1500]}
请进行对比分析,返回 JSON 格式:
{{
@@ -226,31 +226,31 @@ class KnowledgeReasoner:
"knowledge_gaps": []
}}"""
content = await self._call_llm(prompt, temperature=0.3)
content = await self._call_llm(prompt, temperature = 0.3)
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
if json_match:
try:
data = json.loads(json_match.group())
data = json.loads(json_match.group())
return ReasoningResult(
answer=data.get("answer", ""),
reasoning_type=ReasoningType.COMPARATIVE,
confidence=data.get("confidence", 0.7),
evidence=[{"text": e} for e in data.get("evidence", [])],
related_entities=[],
gaps=data.get("knowledge_gaps", []),
answer = data.get("answer", ""),
reasoning_type = ReasoningType.COMPARATIVE,
confidence = data.get("confidence", 0.7),
evidence = [{"text": e} for e in data.get("evidence", [])],
related_entities = [],
gaps = data.get("knowledge_gaps", []),
)
except (json.JSONDecodeError, KeyError):
pass
return ReasoningResult(
answer=content,
reasoning_type=ReasoningType.COMPARATIVE,
confidence=0.5,
evidence=[],
related_entities=[],
gaps=[],
answer = content,
reasoning_type = ReasoningType.COMPARATIVE,
confidence = 0.5,
evidence = [],
related_entities = [],
gaps = [],
)
async def _temporal_reasoning(
@@ -258,16 +258,16 @@ class KnowledgeReasoner:
) -> ReasoningResult:
"""时序推理 - 分析时间线和演变"""
prompt = f"""基于以下知识图谱进行时序分析:
prompt = f"""基于以下知识图谱进行时序分析:
## 问题
{query}
## 项目时间线
{json.dumps(project_context.get("timeline", []), ensure_ascii=False, indent=2)[:2000]}
{json.dumps(project_context.get("timeline", []), ensure_ascii = False, indent = 2)[:2000]}
## 实体提及历史
{json.dumps(graph_data.get("entities", []), ensure_ascii=False, indent=2)[:1500]}
{json.dumps(graph_data.get("entities", []), ensure_ascii = False, indent = 2)[:1500]}
请进行时序分析,返回 JSON 格式:
{{
@@ -280,31 +280,31 @@ class KnowledgeReasoner:
"knowledge_gaps": []
}}"""
content = await self._call_llm(prompt, temperature=0.3)
content = await self._call_llm(prompt, temperature = 0.3)
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
if json_match:
try:
data = json.loads(json_match.group())
data = json.loads(json_match.group())
return ReasoningResult(
answer=data.get("answer", ""),
reasoning_type=ReasoningType.TEMPORAL,
confidence=data.get("confidence", 0.7),
evidence=[{"text": e} for e in data.get("evidence", [])],
related_entities=[],
gaps=data.get("knowledge_gaps", []),
answer = data.get("answer", ""),
reasoning_type = ReasoningType.TEMPORAL,
confidence = data.get("confidence", 0.7),
evidence = [{"text": e} for e in data.get("evidence", [])],
related_entities = [],
gaps = data.get("knowledge_gaps", []),
)
except (json.JSONDecodeError, KeyError):
pass
return ReasoningResult(
answer=content,
reasoning_type=ReasoningType.TEMPORAL,
confidence=0.5,
evidence=[],
related_entities=[],
gaps=[],
answer = content,
reasoning_type = ReasoningType.TEMPORAL,
confidence = 0.5,
evidence = [],
related_entities = [],
gaps = [],
)
async def _associative_reasoning(
@@ -312,16 +312,16 @@ class KnowledgeReasoner:
) -> ReasoningResult:
"""关联推理 - 发现实体间的隐含关联"""
prompt = f"""基于以下知识图谱进行关联分析:
prompt = f"""基于以下知识图谱进行关联分析:
## 问题
{query}
## 实体
{json.dumps(graph_data.get("entities", [])[:20], ensure_ascii=False, indent=2)}
{json.dumps(graph_data.get("entities", [])[:20], ensure_ascii = False, indent = 2)}
## 关系
{json.dumps(graph_data.get("relations", [])[:30], ensure_ascii=False, indent=2)}
{json.dumps(graph_data.get("relations", [])[:30], ensure_ascii = False, indent = 2)}
请进行关联推理,发现隐含联系,返回 JSON 格式:
{{
@@ -334,52 +334,52 @@ class KnowledgeReasoner:
"knowledge_gaps": []
}}"""
content = await self._call_llm(prompt, temperature=0.4)
content = await self._call_llm(prompt, temperature = 0.4)
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
if json_match:
try:
data = json.loads(json_match.group())
data = json.loads(json_match.group())
return ReasoningResult(
answer=data.get("answer", ""),
reasoning_type=ReasoningType.ASSOCIATIVE,
confidence=data.get("confidence", 0.7),
evidence=[{"text": e} for e in data.get("evidence", [])],
related_entities=[],
gaps=data.get("knowledge_gaps", []),
answer = data.get("answer", ""),
reasoning_type = ReasoningType.ASSOCIATIVE,
confidence = data.get("confidence", 0.7),
evidence = [{"text": e} for e in data.get("evidence", [])],
related_entities = [],
gaps = data.get("knowledge_gaps", []),
)
except (json.JSONDecodeError, KeyError):
pass
return ReasoningResult(
answer=content,
reasoning_type=ReasoningType.ASSOCIATIVE,
confidence=0.5,
evidence=[],
related_entities=[],
gaps=[],
answer = content,
reasoning_type = ReasoningType.ASSOCIATIVE,
confidence = 0.5,
evidence = [],
related_entities = [],
gaps = [],
)
def find_inference_paths(
self, start_entity: str, end_entity: str, graph_data: dict, max_depth: int = 3
self, start_entity: str, end_entity: str, graph_data: dict, max_depth: int = 3
) -> list[InferencePath]:
"""
发现两个实体之间的推理路径
使用 BFS 在关系图中搜索路径
"""
relations = graph_data.get("relations", [])
relations = graph_data.get("relations", [])
# 构建邻接表
adj = {}
adj = {}
for r in relations:
src = r.get("source_id") or r.get("source")
tgt = r.get("target_id") or r.get("target")
src = r.get("source_id") or r.get("source")
tgt = r.get("target_id") or r.get("target")
if src not in adj:
adj[src] = []
adj[src] = []
if tgt not in adj:
adj[tgt] = []
adj[tgt] = []
adj[src].append({"target": tgt, "relation": r.get("type", "related"), "data": r})
# 无向图也添加反向
adj[tgt].append(
@@ -389,21 +389,21 @@ class KnowledgeReasoner:
# BFS 搜索路径
from collections import deque
paths = []
queue = deque([(start_entity, [{"entity": start_entity, "relation": None}])])
paths = []
queue = deque([(start_entity, [{"entity": start_entity, "relation": None}])])
{start_entity}
while queue and len(paths) < 5:
current, path = queue.popleft()
current, path = queue.popleft()
if current == end_entity and len(path) > 1:
# 找到一条路径
paths.append(
InferencePath(
start_entity=start_entity,
end_entity=end_entity,
path=path,
strength=self._calculate_path_strength(path),
start_entity = start_entity,
end_entity = end_entity,
path = path,
strength = self._calculate_path_strength(path),
)
)
continue
@@ -412,9 +412,9 @@ class KnowledgeReasoner:
continue
for neighbor in adj.get(current, []):
next_entity = neighbor["target"]
next_entity = neighbor["target"]
if next_entity not in [p["entity"] for p in path]: # 避免循环
new_path = path + [
new_path = path + [
{
"entity": next_entity,
"relation": neighbor["relation"],
@@ -424,7 +424,7 @@ class KnowledgeReasoner:
queue.append((next_entity, new_path))
# 按强度排序
paths.sort(key=lambda p: p.strength, reverse=True)
paths.sort(key = lambda p: p.strength, reverse = True)
return paths
def _calculate_path_strength(self, path: list[dict]) -> float:
@@ -433,23 +433,23 @@ class KnowledgeReasoner:
return 0.0
# 路径越短越强
length_factor = 1.0 / len(path)
length_factor = 1.0 / len(path)
# 关系置信度
confidence_sum = 0
confidence_count = 0
confidence_sum = 0
confidence_count = 0
for node in path[1:]: # 跳过第一个节点
rel_data = node.get("relation_data", {})
rel_data = node.get("relation_data", {})
if "confidence" in rel_data:
confidence_sum += rel_data["confidence"]
confidence_count += 1
confidence_factor = (confidence_sum / confidence_count) if confidence_count > 0 else 0.5
confidence_factor = (confidence_sum / confidence_count) if confidence_count > 0 else 0.5
return length_factor * confidence_factor
async def summarize_project(
self, project_context: dict, graph_data: dict, summary_type: str = "comprehensive"
self, project_context: dict, graph_data: dict, summary_type: str = "comprehensive"
) -> dict:
"""
项目智能总结
@@ -457,17 +457,17 @@ class KnowledgeReasoner:
Args:
summary_type: comprehensive/executive/technical/risk
"""
type_prompts = {
type_prompts = {
"comprehensive": "全面总结项目的所有方面",
"executive": "高管摘要,关注关键决策和风险",
"technical": "技术总结,关注架构和技术栈",
"risk": "风险分析,关注潜在问题和依赖",
}
prompt = f"""请对以下项目进行{type_prompts.get(summary_type, "全面总结")}
prompt = f"""请对以下项目进行{type_prompts.get(summary_type, "全面总结")}
## 项目信息
{json.dumps(project_context, ensure_ascii=False, indent=2)[:3000]}
{json.dumps(project_context, ensure_ascii = False, indent = 2)[:3000]}
## 知识图谱
实体数: {len(graph_data.get("entities", []))}
@@ -483,9 +483,9 @@ class KnowledgeReasoner:
"confidence": 0.85
}}"""
content = await self._call_llm(prompt, temperature=0.3)
content = await self._call_llm(prompt, temperature = 0.3)
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
if json_match:
try:
@@ -504,11 +504,11 @@ class KnowledgeReasoner:
# Singleton instance
_reasoner = None
_reasoner = None
def get_knowledge_reasoner() -> KnowledgeReasoner:
global _reasoner
if _reasoner is None:
_reasoner = KnowledgeReasoner()
_reasoner = KnowledgeReasoner()
return _reasoner

View File

@@ -12,8 +12,8 @@ from dataclasses import dataclass
import httpx
KIMI_API_KEY = os.getenv("KIMI_API_KEY", "")
KIMI_BASE_URL = os.getenv("KIMI_BASE_URL", "https://api.kimi.com/coding")
KIMI_API_KEY = os.getenv("KIMI_API_KEY", "")
KIMI_BASE_URL = os.getenv("KIMI_BASE_URL", "https://api.kimi.com/coding")
@dataclass
@@ -41,22 +41,22 @@ class RelationExtractionResult:
class LLMClient:
"""Kimi API 客户端"""
def __init__(self, api_key: str = None, base_url: str = None):
self.api_key = api_key or KIMI_API_KEY
self.base_url = base_url or KIMI_BASE_URL
self.headers = {
def __init__(self, api_key: str = None, base_url: str = None) -> None:
self.api_key = api_key or KIMI_API_KEY
self.base_url = base_url or KIMI_BASE_URL
self.headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
}
async def chat(
self, messages: list[ChatMessage], temperature: float = 0.3, stream: bool = False
self, messages: list[ChatMessage], temperature: float = 0.3, stream: bool = False
) -> str:
"""发送聊天请求"""
if not self.api_key:
raise ValueError("KIMI_API_KEY not set")
payload = {
payload = {
"model": "k2p5",
"messages": [{"role": m.role, "content": m.content} for m in messages],
"temperature": temperature,
@@ -64,24 +64,24 @@ class LLMClient:
}
async with httpx.AsyncClient() as client:
response = await client.post(
response = await client.post(
f"{self.base_url}/v1/chat/completions",
headers=self.headers,
json=payload,
timeout=120.0,
headers = self.headers,
json = payload,
timeout = 120.0,
)
response.raise_for_status()
result = response.json()
result = response.json()
return result["choices"][0]["message"]["content"]
async def chat_stream(
self, messages: list[ChatMessage], temperature: float = 0.3
self, messages: list[ChatMessage], temperature: float = 0.3
) -> AsyncGenerator[str, None]:
"""流式聊天请求"""
if not self.api_key:
raise ValueError("KIMI_API_KEY not set")
payload = {
payload = {
"model": "k2p5",
"messages": [{"role": m.role, "content": m.content} for m in messages],
"temperature": temperature,
@@ -92,19 +92,19 @@ class LLMClient:
async with client.stream(
"POST",
f"{self.base_url}/v1/chat/completions",
headers=self.headers,
json=payload,
timeout=120.0,
headers = self.headers,
json = payload,
timeout = 120.0,
) as response:
response.raise_for_status()
async for line in response.aiter_lines():
if line.startswith("data: "):
data = line[6:]
data = line[6:]
if data == "[DONE]":
break
try:
chunk = json.loads(data)
delta = chunk["choices"][0]["delta"]
chunk = json.loads(data)
delta = chunk["choices"][0]["delta"]
if "content" in delta:
yield delta["content"]
except (json.JSONDecodeError, KeyError, IndexError):
@@ -114,7 +114,7 @@ class LLMClient:
self, text: str
) -> tuple[list[EntityExtractionResult], list[RelationExtractionResult]]:
"""提取实体和关系,带置信度分数"""
prompt = f"""从以下会议文本中提取关键实体和它们之间的关系,以 JSON 格式返回:
prompt = f"""从以下会议文本中提取关键实体和它们之间的关系,以 JSON 格式返回:
文本:{text[:3000]}
@@ -139,30 +139,30 @@ class LLMClient:
]
}}"""
messages = [ChatMessage(role="user", content=prompt)]
content = await self.chat(messages, temperature=0.1)
messages = [ChatMessage(role = "user", content = prompt)]
content = await self.chat(messages, temperature = 0.1)
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
if not json_match:
return [], []
try:
data = json.loads(json_match.group())
entities = [
data = json.loads(json_match.group())
entities = [
EntityExtractionResult(
name=e["name"],
type=e.get("type", "OTHER"),
definition=e.get("definition", ""),
confidence=e.get("confidence", 0.8),
name = e["name"],
type = e.get("type", "OTHER"),
definition = e.get("definition", ""),
confidence = e.get("confidence", 0.8),
)
for e in data.get("entities", [])
]
relations = [
relations = [
RelationExtractionResult(
source=r["source"],
target=r["target"],
type=r.get("type", "related"),
confidence=r.get("confidence", 0.8),
source = r["source"],
target = r["target"],
type = r.get("type", "related"),
confidence = r.get("confidence", 0.8),
)
for r in data.get("relations", [])
]
@@ -173,10 +173,10 @@ class LLMClient:
async def rag_query(self, query: str, context: str, project_context: dict) -> str:
"""RAG 问答 - 基于项目上下文回答问题"""
prompt = f"""你是一个专业的项目分析助手。基于以下项目信息回答问题:
prompt = f"""你是一个专业的项目分析助手。基于以下项目信息回答问题:
## 项目信息
{json.dumps(project_context, ensure_ascii=False, indent=2)}
{json.dumps(project_context, ensure_ascii = False, indent = 2)}
## 相关上下文
{context[:4000]}
@@ -186,21 +186,21 @@ class LLMClient:
请用中文回答,保持简洁专业。如果信息不足,请明确说明。"""
messages = [
messages = [
ChatMessage(
role="system", content="你是一个专业的项目分析助手,擅长从会议记录中提取洞察。"
role = "system", content = "你是一个专业的项目分析助手,擅长从会议记录中提取洞察。"
),
ChatMessage(role="user", content=prompt),
ChatMessage(role = "user", content = prompt),
]
return await self.chat(messages, temperature=0.3)
return await self.chat(messages, temperature = 0.3)
async def agent_command(self, command: str, project_context: dict) -> dict:
"""Agent 指令解析 - 将自然语言指令转换为结构化操作"""
prompt = f"""解析以下用户指令,转换为结构化操作:
prompt = f"""解析以下用户指令,转换为结构化操作:
## 项目信息
{json.dumps(project_context, ensure_ascii=False, indent=2)}
{json.dumps(project_context, ensure_ascii = False, indent = 2)}
## 用户指令
{command}
@@ -221,10 +221,10 @@ class LLMClient:
- create_relation: 创建关系params 包含 source(源实体), target(目标实体), relation_type(关系类型)
"""
messages = [ChatMessage(role="user", content=prompt)]
content = await self.chat(messages, temperature=0.1)
messages = [ChatMessage(role = "user", content = prompt)]
content = await self.chat(messages, temperature = 0.1)
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
if not json_match:
return {"intent": "unknown", "explanation": "无法解析指令"}
@@ -235,14 +235,14 @@ class LLMClient:
async def analyze_entity_evolution(self, entity_name: str, mentions: list[dict]) -> str:
"""分析实体在项目中的演变/态度变化"""
mentions_text = "\n".join(
mentions_text = "\n".join(
[
f"[{m.get('created_at', '未知时间')}] {m.get('text_snippet', '')}"
for m in mentions[:20]
] # 限制数量
)
prompt = f"""分析实体 "{entity_name}" 在项目中的演变和态度变化:
prompt = f"""分析实体 "{entity_name}" 在项目中的演变和态度变化:
## 提及记录
{mentions_text}
@@ -255,16 +255,16 @@ class LLMClient:
用中文回答,结构清晰。"""
messages = [ChatMessage(role="user", content=prompt)]
return await self.chat(messages, temperature=0.3)
messages = [ChatMessage(role = "user", content = prompt)]
return await self.chat(messages, temperature = 0.3)
# Singleton instance
_llm_client = None
_llm_client = None
def get_llm_client() -> LLMClient:
global _llm_client
if _llm_client is None:
_llm_client = LLMClient()
_llm_client = LLMClient()
return _llm_client

File diff suppressed because it is too large Load Diff

View File

@@ -9,13 +9,13 @@ from dataclasses import dataclass
from difflib import SequenceMatcher
# Constants
UUID_LENGTH = 8 # UUID 截断长度
UUID_LENGTH = 8 # UUID 截断长度
# 尝试导入embedding库
try:
NUMPY_AVAILABLE = True
NUMPY_AVAILABLE = True
except ImportError:
NUMPY_AVAILABLE = False
NUMPY_AVAILABLE = False
@dataclass
@@ -30,11 +30,11 @@ class MultimodalEntity:
source_id: str
mention_context: str
confidence: float
modality_features: dict = None # 模态特定特征
modality_features: dict = None # 模态特定特征
def __post_init__(self):
def __post_init__(self) -> None:
if self.modality_features is None:
self.modality_features = {}
self.modality_features = {}
@dataclass
@@ -78,7 +78,7 @@ class MultimodalEntityLinker:
"""多模态实体关联器 - 跨模态实体对齐和知识融合"""
# 关联类型
LINK_TYPES = {
LINK_TYPES = {
"same_as": "同一实体",
"related_to": "相关实体",
"part_of": "组成部分",
@@ -86,16 +86,16 @@ class MultimodalEntityLinker:
}
# 模态类型
MODALITIES = ["audio", "video", "image", "document"]
MODALITIES = ["audio", "video", "image", "document"]
def __init__(self, similarity_threshold: float = 0.85) -> None:
def __init__(self, similarity_threshold: float = 0.85) -> None:
"""
初始化多模态实体关联器
Args:
similarity_threshold: 相似度阈值
"""
self.similarity_threshold = similarity_threshold
self.similarity_threshold = similarity_threshold
def calculate_string_similarity(self, s1: str, s2: str) -> float:
"""
@@ -111,7 +111,7 @@ class MultimodalEntityLinker:
if not s1 or not s2:
return 0.0
s1, s2 = s1.lower().strip(), s2.lower().strip()
s1, s2 = s1.lower().strip(), s2.lower().strip()
# 完全匹配
if s1 == s2:
@@ -136,7 +136,7 @@ class MultimodalEntityLinker:
(相似度, 匹配类型)
"""
# 名称相似度
name_sim = self.calculate_string_similarity(
name_sim = self.calculate_string_similarity(
entity1.get("name", ""), entity2.get("name", "")
)
@@ -145,8 +145,8 @@ class MultimodalEntityLinker:
return 1.0, "exact"
# 检查别名
aliases1 = set(a.lower() for a in entity1.get("aliases", []))
aliases2 = set(a.lower() for a in entity2.get("aliases", []))
aliases1 = set(a.lower() for a in entity1.get("aliases", []))
aliases2 = set(a.lower() for a in entity2.get("aliases", []))
if aliases1 & aliases2: # 有共同别名
return 0.95, "alias_match"
@@ -157,12 +157,12 @@ class MultimodalEntityLinker:
return 0.95, "alias_match"
# 定义相似度
def_sim = self.calculate_string_similarity(
def_sim = self.calculate_string_similarity(
entity1.get("definition", ""), entity2.get("definition", "")
)
# 综合相似度
combined_sim = name_sim * 0.7 + def_sim * 0.3
combined_sim = name_sim * 0.7 + def_sim * 0.3
if combined_sim >= self.similarity_threshold:
return combined_sim, "fuzzy"
@@ -170,7 +170,7 @@ class MultimodalEntityLinker:
return combined_sim, "none"
def find_matching_entity(
self, query_entity: dict, candidate_entities: list[dict], exclude_ids: set[str] = None
self, query_entity: dict, candidate_entities: list[dict], exclude_ids: set[str] = None
) -> AlignmentResult | None:
"""
在候选实体中查找匹配的实体
@@ -183,28 +183,28 @@ class MultimodalEntityLinker:
Returns:
对齐结果
"""
exclude_ids = exclude_ids or set()
best_match = None
best_similarity = 0.0
exclude_ids = exclude_ids or set()
best_match = None
best_similarity = 0.0
for candidate in candidate_entities:
if candidate.get("id") in exclude_ids:
continue
similarity, match_type = self.calculate_entity_similarity(query_entity, candidate)
similarity, match_type = self.calculate_entity_similarity(query_entity, candidate)
if similarity > best_similarity and similarity >= self.similarity_threshold:
best_similarity = similarity
best_match = candidate
best_match_type = match_type
best_similarity = similarity
best_match = candidate
best_match_type = match_type
if best_match:
return AlignmentResult(
entity_id=query_entity.get("id"),
matched_entity_id=best_match.get("id"),
similarity=best_similarity,
match_type=best_match_type,
confidence=best_similarity,
entity_id = query_entity.get("id"),
matched_entity_id = best_match.get("id"),
similarity = best_similarity,
match_type = best_match_type,
confidence = best_similarity,
)
return None
@@ -230,10 +230,10 @@ class MultimodalEntityLinker:
Returns:
实体关联列表
"""
links = []
links = []
# 合并所有实体
all_entities = {
all_entities = {
"audio": audio_entities,
"video": video_entities,
"image": image_entities,
@@ -246,24 +246,24 @@ class MultimodalEntityLinker:
if mod1 >= mod2: # 避免重复比较
continue
entities1 = all_entities.get(mod1, [])
entities2 = all_entities.get(mod2, [])
entities1 = all_entities.get(mod1, [])
entities2 = all_entities.get(mod2, [])
for ent1 in entities1:
# 在另一个模态中查找匹配
result = self.find_matching_entity(ent1, entities2)
result = self.find_matching_entity(ent1, entities2)
if result and result.matched_entity_id:
link = EntityLink(
id=str(uuid.uuid4())[:UUID_LENGTH],
project_id=project_id,
source_entity_id=ent1.get("id"),
target_entity_id=result.matched_entity_id,
link_type="same_as" if result.similarity > 0.95 else "related_to",
source_modality=mod1,
target_modality=mod2,
confidence=result.confidence,
evidence=f"Cross-modal alignment: {result.match_type}",
link = EntityLink(
id = str(uuid.uuid4())[:UUID_LENGTH],
project_id = project_id,
source_entity_id = ent1.get("id"),
target_entity_id = result.matched_entity_id,
link_type = "same_as" if result.similarity > 0.95 else "related_to",
source_modality = mod1,
target_modality = mod2,
confidence = result.confidence,
evidence = f"Cross-modal alignment: {result.match_type}",
)
links.append(link)
@@ -284,7 +284,7 @@ class MultimodalEntityLinker:
融合结果
"""
# 收集所有属性
fused_properties = {
fused_properties = {
"names": set(),
"definitions": [],
"aliases": set(),
@@ -293,7 +293,7 @@ class MultimodalEntityLinker:
"contexts": [],
}
merged_ids = []
merged_ids = []
for entity in linked_entities:
merged_ids.append(entity.get("id"))
@@ -318,21 +318,21 @@ class MultimodalEntityLinker:
fused_properties["contexts"].append(mention.get("mention_context"))
# 选择最佳定义(最长的那个)
best_definition = (
max(fused_properties["definitions"], key=len) if fused_properties["definitions"] else ""
best_definition = (
max(fused_properties["definitions"], key = len) if fused_properties["definitions"] else ""
)
# 选择最佳名称(最常见的那个)
from collections import Counter
name_counts = Counter(fused_properties["names"])
best_name = name_counts.most_common(1)[0][0] if name_counts else ""
name_counts = Counter(fused_properties["names"])
best_name = name_counts.most_common(1)[0][0] if name_counts else ""
# 构建融合结果
return FusionResult(
canonical_entity_id=entity_id,
merged_entity_ids=merged_ids,
fused_properties={
canonical_entity_id = entity_id,
merged_entity_ids = merged_ids,
fused_properties = {
"name": best_name,
"definition": best_definition,
"aliases": list(fused_properties["aliases"]),
@@ -340,8 +340,8 @@ class MultimodalEntityLinker:
"modalities": list(fused_properties["modalities"]),
"contexts": fused_properties["contexts"][:10], # 最多10个上下文
},
source_modalities=list(fused_properties["modalities"]),
confidence=min(1.0, len(linked_entities) * 0.2 + 0.5),
source_modalities = list(fused_properties["modalities"]),
confidence = min(1.0, len(linked_entities) * 0.2 + 0.5),
)
def detect_entity_conflicts(self, entities: list[dict]) -> list[dict]:
@@ -354,30 +354,30 @@ class MultimodalEntityLinker:
Returns:
冲突列表
"""
conflicts = []
conflicts = []
# 按名称分组
name_groups = {}
name_groups = {}
for entity in entities:
name = entity.get("name", "").lower()
name = entity.get("name", "").lower()
if name:
if name not in name_groups:
name_groups[name] = []
name_groups[name] = []
name_groups[name].append(entity)
# 检测同名但定义不同的实体
for name, group in name_groups.items():
if len(group) > 1:
# 检查定义是否相似
definitions = [e.get("definition", "") for e in group if e.get("definition")]
definitions = [e.get("definition", "") for e in group if e.get("definition")]
if len(definitions) > 1:
# 计算定义之间的相似度
sim_matrix = []
sim_matrix = []
for i, d1 in enumerate(definitions):
for j, d2 in enumerate(definitions):
if i < j:
sim = self.calculate_string_similarity(d1, d2)
sim = self.calculate_string_similarity(d1, d2)
sim_matrix.append(sim)
# 如果定义相似度都很低,可能是冲突
@@ -394,7 +394,7 @@ class MultimodalEntityLinker:
return conflicts
def suggest_entity_merges(
self, entities: list[dict], existing_links: list[EntityLink] = None
self, entities: list[dict], existing_links: list[EntityLink] = None
) -> list[dict]:
"""
建议实体合并
@@ -406,13 +406,13 @@ class MultimodalEntityLinker:
Returns:
合并建议列表
"""
suggestions = []
existing_pairs = set()
suggestions = []
existing_pairs = set()
# 记录已有的关联
if existing_links:
for link in existing_links:
pair = tuple(sorted([link.source_entity_id, link.target_entity_id]))
pair = tuple(sorted([link.source_entity_id, link.target_entity_id]))
existing_pairs.add(pair)
# 检查所有实体对
@@ -422,12 +422,12 @@ class MultimodalEntityLinker:
continue
# 检查是否已有关联
pair = tuple(sorted([ent1.get("id"), ent2.get("id")]))
pair = tuple(sorted([ent1.get("id"), ent2.get("id")]))
if pair in existing_pairs:
continue
# 计算相似度
similarity, match_type = self.calculate_entity_similarity(ent1, ent2)
similarity, match_type = self.calculate_entity_similarity(ent1, ent2)
if similarity >= self.similarity_threshold:
suggestions.append(
@@ -441,7 +441,7 @@ class MultimodalEntityLinker:
)
# 按相似度排序
suggestions.sort(key=lambda x: x["similarity"], reverse=True)
suggestions.sort(key = lambda x: x["similarity"], reverse = True)
return suggestions
@@ -451,8 +451,8 @@ class MultimodalEntityLinker:
entity_id: str,
source_type: str,
source_id: str,
mention_context: str = "",
confidence: float = 1.0,
mention_context: str = "",
confidence: float = 1.0,
) -> MultimodalEntity:
"""
创建多模态实体记录
@@ -469,14 +469,14 @@ class MultimodalEntityLinker:
多模态实体记录
"""
return MultimodalEntity(
id=str(uuid.uuid4())[:UUID_LENGTH],
entity_id=entity_id,
project_id=project_id,
name="", # 将在后续填充
source_type=source_type,
source_id=source_id,
mention_context=mention_context,
confidence=confidence,
id = str(uuid.uuid4())[:UUID_LENGTH],
entity_id = entity_id,
project_id = project_id,
name = "", # 将在后续填充
source_type = source_type,
source_id = source_id,
mention_context = mention_context,
confidence = confidence,
)
def analyze_modality_distribution(self, multimodal_entities: list[MultimodalEntity]) -> dict:
@@ -489,7 +489,7 @@ class MultimodalEntityLinker:
Returns:
模态分布统计
"""
distribution = {mod: 0 for mod in self.MODALITIES}
distribution = {mod: 0 for mod in self.MODALITIES}
# 统计每个模态的实体数
for me in multimodal_entities:
@@ -497,13 +497,13 @@ class MultimodalEntityLinker:
distribution[me.source_type] += 1
# 统计跨模态实体
entity_modalities = {}
entity_modalities = {}
for me in multimodal_entities:
if me.entity_id not in entity_modalities:
entity_modalities[me.entity_id] = set()
entity_modalities[me.entity_id] = set()
entity_modalities[me.entity_id].add(me.source_type)
cross_modal_count = sum(1 for mods in entity_modalities.values() if len(mods) > 1)
cross_modal_count = sum(1 for mods in entity_modalities.values() if len(mods) > 1)
return {
"modality_distribution": distribution,
@@ -517,12 +517,12 @@ class MultimodalEntityLinker:
# Singleton instance
_multimodal_entity_linker = None
_multimodal_entity_linker = None
def get_multimodal_entity_linker(similarity_threshold: float = 0.85) -> MultimodalEntityLinker:
def get_multimodal_entity_linker(similarity_threshold: float = 0.85) -> MultimodalEntityLinker:
"""获取多模态实体关联器单例"""
global _multimodal_entity_linker
if _multimodal_entity_linker is None:
_multimodal_entity_linker = MultimodalEntityLinker(similarity_threshold)
_multimodal_entity_linker = MultimodalEntityLinker(similarity_threshold)
return _multimodal_entity_linker

View File

@@ -13,30 +13,30 @@ from dataclasses import dataclass
from pathlib import Path
# Constants
UUID_LENGTH = 8 # UUID 截断长度
UUID_LENGTH = 8 # UUID 截断长度
# 尝试导入OCR库
try:
import pytesseract
from PIL import Image
PYTESSERACT_AVAILABLE = True
PYTESSERACT_AVAILABLE = True
except ImportError:
PYTESSERACT_AVAILABLE = False
PYTESSERACT_AVAILABLE = False
try:
import cv2
CV2_AVAILABLE = True
CV2_AVAILABLE = True
except ImportError:
CV2_AVAILABLE = False
CV2_AVAILABLE = False
try:
import ffmpeg
FFMPEG_AVAILABLE = True
FFMPEG_AVAILABLE = True
except ImportError:
FFMPEG_AVAILABLE = False
FFMPEG_AVAILABLE = False
@dataclass
@@ -48,13 +48,13 @@ class VideoFrame:
frame_number: int
timestamp: float
frame_path: str
ocr_text: str = ""
ocr_confidence: float = 0.0
entities_detected: list[dict] = None
ocr_text: str = ""
ocr_confidence: float = 0.0
entities_detected: list[dict] = None
def __post_init__(self):
def __post_init__(self) -> None:
if self.entities_detected is None:
self.entities_detected = []
self.entities_detected = []
@dataclass
@@ -65,20 +65,20 @@ class VideoInfo:
project_id: str
filename: str
file_path: str
duration: float = 0.0
width: int = 0
height: int = 0
fps: float = 0.0
audio_extracted: bool = False
audio_path: str = ""
transcript_id: str = ""
status: str = "pending"
error_message: str = ""
metadata: dict = None
duration: float = 0.0
width: int = 0
height: int = 0
fps: float = 0.0
audio_extracted: bool = False
audio_path: str = ""
transcript_id: str = ""
status: str = "pending"
error_message: str = ""
metadata: dict = None
def __post_init__(self):
def __post_init__(self) -> None:
if self.metadata is None:
self.metadata = {}
self.metadata = {}
@dataclass
@@ -91,13 +91,13 @@ class VideoProcessingResult:
ocr_results: list[dict]
full_text: str # 整合的文本(音频转录 + OCR文本
success: bool
error_message: str = ""
error_message: str = ""
class MultimodalProcessor:
"""多模态处理器 - 处理视频文件"""
def __init__(self, temp_dir: str = None, frame_interval: int = 5) -> None:
def __init__(self, temp_dir: str = None, frame_interval: int = 5) -> None:
"""
初始化多模态处理器
@@ -105,16 +105,16 @@ class MultimodalProcessor:
temp_dir: 临时文件目录
frame_interval: 关键帧提取间隔(秒)
"""
self.temp_dir = temp_dir or tempfile.gettempdir()
self.frame_interval = frame_interval
self.video_dir = os.path.join(self.temp_dir, "videos")
self.frames_dir = os.path.join(self.temp_dir, "frames")
self.audio_dir = os.path.join(self.temp_dir, "audio")
self.temp_dir = temp_dir or tempfile.gettempdir()
self.frame_interval = frame_interval
self.video_dir = os.path.join(self.temp_dir, "videos")
self.frames_dir = os.path.join(self.temp_dir, "frames")
self.audio_dir = os.path.join(self.temp_dir, "audio")
# 创建目录
os.makedirs(self.video_dir, exist_ok=True)
os.makedirs(self.frames_dir, exist_ok=True)
os.makedirs(self.audio_dir, exist_ok=True)
os.makedirs(self.video_dir, exist_ok = True)
os.makedirs(self.frames_dir, exist_ok = True)
os.makedirs(self.audio_dir, exist_ok = True)
def extract_video_info(self, video_path: str) -> dict:
"""
@@ -128,11 +128,11 @@ class MultimodalProcessor:
"""
try:
if FFMPEG_AVAILABLE:
probe = ffmpeg.probe(video_path)
video_stream = next(
probe = ffmpeg.probe(video_path)
video_stream = next(
(s for s in probe["streams"] if s["codec_type"] == "video"), None
)
audio_stream = next(
audio_stream = next(
(s for s in probe["streams"] if s["codec_type"] == "audio"), None
)
@@ -147,21 +147,21 @@ class MultimodalProcessor:
}
else:
# 使用 ffprobe 命令行
cmd = [
cmd = [
"ffprobe",
"-v",
"error",
"-show_entries",
"format=duration,bit_rate",
"format = duration, bit_rate",
"-show_entries",
"stream=width,height,r_frame_rate",
"stream = width, height, r_frame_rate",
"-of",
"json",
video_path,
]
result = subprocess.run(cmd, capture_output=True, text=True)
result = subprocess.run(cmd, capture_output = True, text = True)
if result.returncode == 0:
data = json.loads(result.stdout)
data = json.loads(result.stdout)
return {
"duration": float(data["format"].get("duration", 0)),
"width": int(data["streams"][0].get("width", 0)) if data["streams"] else 0,
@@ -177,7 +177,7 @@ class MultimodalProcessor:
return {"duration": 0, "width": 0, "height": 0, "fps": 0, "has_audio": False, "bitrate": 0}
def extract_audio(self, video_path: str, output_path: str = None) -> str:
def extract_audio(self, video_path: str, output_path: str = None) -> str:
"""
从视频中提取音频
@@ -189,20 +189,20 @@ class MultimodalProcessor:
提取的音频文件路径
"""
if output_path is None:
video_name = Path(video_path).stem
output_path = os.path.join(self.audio_dir, f"{video_name}.wav")
video_name = Path(video_path).stem
output_path = os.path.join(self.audio_dir, f"{video_name}.wav")
try:
if FFMPEG_AVAILABLE:
(
ffmpeg.input(video_path)
.output(output_path, ac=1, ar=16000, vn=None)
.output(output_path, ac = 1, ar = 16000, vn = None)
.overwrite_output()
.run(quiet=True)
.run(quiet = True)
)
else:
# 使用命令行 ffmpeg
cmd = [
cmd = [
"ffmpeg",
"-i",
video_path,
@@ -216,14 +216,14 @@ class MultimodalProcessor:
"-y",
output_path,
]
subprocess.run(cmd, check=True, capture_output=True)
subprocess.run(cmd, check = True, capture_output = True)
return output_path
except Exception as e:
print(f"Error extracting audio: {e}")
raise
def extract_keyframes(self, video_path: str, video_id: str, interval: int = None) -> list[str]:
def extract_keyframes(self, video_path: str, video_id: str, interval: int = None) -> list[str]:
"""
从视频中提取关键帧
@@ -235,31 +235,31 @@ class MultimodalProcessor:
Returns:
提取的帧文件路径列表
"""
interval = interval or self.frame_interval
frame_paths = []
interval = interval or self.frame_interval
frame_paths = []
# 创建帧存储目录
video_frames_dir = os.path.join(self.frames_dir, video_id)
os.makedirs(video_frames_dir, exist_ok=True)
video_frames_dir = os.path.join(self.frames_dir, video_id)
os.makedirs(video_frames_dir, exist_ok = True)
try:
if CV2_AVAILABLE:
# 使用 OpenCV 提取帧
cap = cv2.VideoCapture(video_path)
fps = cap.get(cv2.CAP_PROP_FPS)
cap = cv2.VideoCapture(video_path)
fps = cap.get(cv2.CAP_PROP_FPS)
int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
frame_interval_frames = int(fps * interval)
frame_number = 0
frame_interval_frames = int(fps * interval)
frame_number = 0
while True:
ret, frame = cap.read()
ret, frame = cap.read()
if not ret:
break
if frame_number % frame_interval_frames == 0:
timestamp = frame_number / fps
frame_path = os.path.join(
timestamp = frame_number / fps
frame_path = os.path.join(
video_frames_dir, f"frame_{frame_number:06d}_{timestamp:.2f}.jpg"
)
cv2.imwrite(frame_path, frame)
@@ -271,23 +271,23 @@ class MultimodalProcessor:
else:
# 使用 ffmpeg 命令行提取帧
Path(video_path).stem
output_pattern = os.path.join(video_frames_dir, "frame_%06d_%t.jpg")
output_pattern = os.path.join(video_frames_dir, "frame_%06d_%t.jpg")
cmd = [
cmd = [
"ffmpeg",
"-i",
video_path,
"-vf",
f"fps=1/{interval}",
f"fps = 1/{interval}",
"-frame_pts",
"1",
"-y",
output_pattern,
]
subprocess.run(cmd, check=True, capture_output=True)
subprocess.run(cmd, check = True, capture_output = True)
# 获取生成的帧文件列表
frame_paths = sorted(
frame_paths = sorted(
[
os.path.join(video_frames_dir, f)
for f in os.listdir(video_frames_dir)
@@ -313,19 +313,19 @@ class MultimodalProcessor:
return "", 0.0
try:
image = Image.open(image_path)
image = Image.open(image_path)
# 预处理:转换为灰度图
if image.mode != "L":
image = image.convert("L")
image = image.convert("L")
# 使用 pytesseract 进行 OCR
text = pytesseract.image_to_string(image, lang="chi_sim+eng")
text = pytesseract.image_to_string(image, lang = "chi_sim+eng")
# 获取置信度数据
data = pytesseract.image_to_data(image, output_type=pytesseract.Output.DICT)
confidences = [int(c) for c in data["conf"] if int(c) > 0]
avg_confidence = sum(confidences) / len(confidences) if confidences else 0
data = pytesseract.image_to_data(image, output_type = pytesseract.Output.DICT)
confidences = [int(c) for c in data["conf"] if int(c) > 0]
avg_confidence = sum(confidences) / len(confidences) if confidences else 0
return text.strip(), avg_confidence / 100.0
except Exception as e:
@@ -333,7 +333,7 @@ class MultimodalProcessor:
return "", 0.0
def process_video(
self, video_data: bytes, filename: str, project_id: str, video_id: str = None
self, video_data: bytes, filename: str, project_id: str, video_id: str = None
) -> VideoProcessingResult:
"""
处理视频文件提取音频、关键帧、OCR
@@ -347,48 +347,48 @@ class MultimodalProcessor:
Returns:
视频处理结果
"""
video_id = video_id or str(uuid.uuid4())[:UUID_LENGTH]
video_id = video_id or str(uuid.uuid4())[:UUID_LENGTH]
try:
# 保存视频文件
video_path = os.path.join(self.video_dir, f"{video_id}_{filename}")
video_path = os.path.join(self.video_dir, f"{video_id}_{filename}")
with open(video_path, "wb") as f:
f.write(video_data)
# 提取视频信息
video_info = self.extract_video_info(video_path)
video_info = self.extract_video_info(video_path)
# 提取音频
audio_path = ""
audio_path = ""
if video_info["has_audio"]:
audio_path = self.extract_audio(video_path)
audio_path = self.extract_audio(video_path)
# 提取关键帧
frame_paths = self.extract_keyframes(video_path, video_id)
frame_paths = self.extract_keyframes(video_path, video_id)
# 对关键帧进行 OCR
frames = []
ocr_results = []
all_ocr_text = []
frames = []
ocr_results = []
all_ocr_text = []
for i, frame_path in enumerate(frame_paths):
# 解析帧信息
frame_name = os.path.basename(frame_path)
parts = frame_name.replace(".jpg", "").split("_")
frame_number = int(parts[1]) if len(parts) > 1 else i
timestamp = float(parts[2]) if len(parts) > 2 else i * self.frame_interval
frame_name = os.path.basename(frame_path)
parts = frame_name.replace(".jpg", "").split("_")
frame_number = int(parts[1]) if len(parts) > 1 else i
timestamp = float(parts[2]) if len(parts) > 2 else i * self.frame_interval
# OCR 识别
ocr_text, confidence = self.perform_ocr(frame_path)
ocr_text, confidence = self.perform_ocr(frame_path)
frame = VideoFrame(
id=str(uuid.uuid4())[:UUID_LENGTH],
video_id=video_id,
frame_number=frame_number,
timestamp=timestamp,
frame_path=frame_path,
ocr_text=ocr_text,
ocr_confidence=confidence,
frame = VideoFrame(
id = str(uuid.uuid4())[:UUID_LENGTH],
video_id = video_id,
frame_number = frame_number,
timestamp = timestamp,
frame_path = frame_path,
ocr_text = ocr_text,
ocr_confidence = confidence,
)
frames.append(frame)
@@ -404,29 +404,29 @@ class MultimodalProcessor:
all_ocr_text.append(ocr_text)
# 整合所有 OCR 文本
full_ocr_text = "\n\n".join(all_ocr_text)
full_ocr_text = "\n\n".join(all_ocr_text)
return VideoProcessingResult(
video_id=video_id,
audio_path=audio_path,
frames=frames,
ocr_results=ocr_results,
full_text=full_ocr_text,
success=True,
video_id = video_id,
audio_path = audio_path,
frames = frames,
ocr_results = ocr_results,
full_text = full_ocr_text,
success = True,
)
except Exception as e:
return VideoProcessingResult(
video_id=video_id,
audio_path="",
frames=[],
ocr_results=[],
full_text="",
success=False,
error_message=str(e),
video_id = video_id,
audio_path = "",
frames = [],
ocr_results = [],
full_text = "",
success = False,
error_message = str(e),
)
def cleanup(self, video_id: str = None) -> None:
def cleanup(self, video_id: str = None) -> None:
"""
清理临时文件
@@ -438,7 +438,7 @@ class MultimodalProcessor:
if video_id:
# 清理特定视频的文件
for dir_path in [self.video_dir, self.frames_dir, self.audio_dir]:
target_dir = (
target_dir = (
os.path.join(dir_path, video_id) if dir_path == self.frames_dir else dir_path
)
if os.path.exists(target_dir):
@@ -450,16 +450,16 @@ class MultimodalProcessor:
for dir_path in [self.video_dir, self.frames_dir, self.audio_dir]:
if os.path.exists(dir_path):
shutil.rmtree(dir_path)
os.makedirs(dir_path, exist_ok=True)
os.makedirs(dir_path, exist_ok = True)
# Singleton instance
_multimodal_processor = None
_multimodal_processor = None
def get_multimodal_processor(temp_dir: str = None, frame_interval: int = 5) -> MultimodalProcessor:
def get_multimodal_processor(temp_dir: str = None, frame_interval: int = 5) -> MultimodalProcessor:
"""获取多模态处理器单例"""
global _multimodal_processor
if _multimodal_processor is None:
_multimodal_processor = MultimodalProcessor(temp_dir, frame_interval)
_multimodal_processor = MultimodalProcessor(temp_dir, frame_interval)
return _multimodal_processor

View File

@@ -10,20 +10,20 @@ import logging
import os
from dataclasses import dataclass
logger = logging.getLogger(__name__)
logger = logging.getLogger(__name__)
# Neo4j 连接配置
NEO4J_URI = os.getenv("NEO4J_URI", "bolt://localhost:7687")
NEO4J_USER = os.getenv("NEO4J_USER", "neo4j")
NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD", "password")
NEO4J_URI = os.getenv("NEO4J_URI", "bolt://localhost:7687")
NEO4J_USER = os.getenv("NEO4J_USER", "neo4j")
NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD", "password")
# 延迟导入,避免未安装时出错
try:
from neo4j import Driver, GraphDatabase
NEO4J_AVAILABLE = True
NEO4J_AVAILABLE = True
except ImportError:
NEO4J_AVAILABLE = False
NEO4J_AVAILABLE = False
logger.warning("Neo4j driver not installed. Neo4j features will be disabled.")
@@ -35,15 +35,15 @@ class GraphEntity:
project_id: str
name: str
type: str
definition: str = ""
aliases: list[str] = None
properties: dict = None
definition: str = ""
aliases: list[str] = None
properties: dict = None
def __post_init__(self):
def __post_init__(self) -> None:
if self.aliases is None:
self.aliases = []
self.aliases = []
if self.properties is None:
self.properties = {}
self.properties = {}
@dataclass
@@ -54,12 +54,12 @@ class GraphRelation:
source_id: str
target_id: str
relation_type: str
evidence: str = ""
properties: dict = None
evidence: str = ""
properties: dict = None
def __post_init__(self):
def __post_init__(self) -> None:
if self.properties is None:
self.properties = {}
self.properties = {}
@dataclass
@@ -69,7 +69,7 @@ class PathResult:
nodes: list[dict]
relationships: list[dict]
length: int
total_weight: float = 0.0
total_weight: float = 0.0
@dataclass
@@ -79,7 +79,7 @@ class CommunityResult:
community_id: int
nodes: list[dict]
size: int
density: float = 0.0
density: float = 0.0
@dataclass
@@ -89,17 +89,17 @@ class CentralityResult:
entity_id: str
entity_name: str
score: float
rank: int = 0
rank: int = 0
class Neo4jManager:
"""Neo4j 图数据库管理器"""
def __init__(self, uri: str = None, user: str = None, password: str = None):
self.uri = uri or NEO4J_URI
self.user = user or NEO4J_USER
self.password = password or NEO4J_PASSWORD
self._driver: Driver | None = None
def __init__(self, uri: str = None, user: str = None, password: str = None) -> None:
self.uri = uri or NEO4J_URI
self.user = user or NEO4J_USER
self.password = password or NEO4J_PASSWORD
self._driver: Driver | None = None
if not NEO4J_AVAILABLE:
logger.error("Neo4j driver not available. Please install: pip install neo4j")
@@ -113,13 +113,13 @@ class Neo4jManager:
return
try:
self._driver = GraphDatabase.driver(self.uri, auth=(self.user, self.password))
self._driver = GraphDatabase.driver(self.uri, auth = (self.user, self.password))
# 验证连接
self._driver.verify_connectivity()
logger.info(f"Connected to Neo4j at {self.uri}")
except (RuntimeError, ValueError, TypeError) as e:
logger.error(f"Failed to connect to Neo4j: {e}")
self._driver = None
self._driver = None
def close(self) -> None:
"""关闭连接"""
@@ -179,7 +179,7 @@ class Neo4jManager:
# ==================== 数据同步 ====================
def sync_project(
self, project_id: str, project_name: str, project_description: str = ""
self, project_id: str, project_name: str, project_description: str = ""
) -> None:
"""同步项目节点到 Neo4j"""
if not self._driver:
@@ -189,13 +189,13 @@ class Neo4jManager:
session.run(
"""
MERGE (p:Project {id: $project_id})
SET p.name = $name,
p.description = $description,
p.updated_at = datetime()
SET p.name = $name,
p.description = $description,
p.updated_at = datetime()
""",
project_id=project_id,
name=project_name,
description=project_description,
project_id = project_id,
name = project_name,
description = project_description,
)
def sync_entity(self, entity: GraphEntity) -> None:
@@ -208,23 +208,23 @@ class Neo4jManager:
session.run(
"""
MERGE (e:Entity {id: $id})
SET e.name = $name,
e.type = $type,
e.definition = $definition,
e.aliases = $aliases,
e.properties = $properties,
e.updated_at = datetime()
SET e.name = $name,
e.type = $type,
e.definition = $definition,
e.aliases = $aliases,
e.properties = $properties,
e.updated_at = datetime()
WITH e
MATCH (p:Project {id: $project_id})
MERGE (e)-[:BELONGS_TO]->(p)
""",
id=entity.id,
project_id=entity.project_id,
name=entity.name,
type=entity.type,
definition=entity.definition,
aliases=json.dumps(entity.aliases),
properties=json.dumps(entity.properties),
id = entity.id,
project_id = entity.project_id,
name = entity.name,
type = entity.type,
definition = entity.definition,
aliases = json.dumps(entity.aliases),
properties = json.dumps(entity.properties),
)
def sync_entities_batch(self, entities: list[GraphEntity]) -> None:
@@ -234,7 +234,7 @@ class Neo4jManager:
with self._driver.session() as session:
# 使用 UNWIND 批量处理
entities_data = [
entities_data = [
{
"id": e.id,
"project_id": e.project_id,
@@ -251,17 +251,17 @@ class Neo4jManager:
"""
UNWIND $entities AS entity
MERGE (e:Entity {id: entity.id})
SET e.name = entity.name,
e.type = entity.type,
e.definition = entity.definition,
e.aliases = entity.aliases,
e.properties = entity.properties,
e.updated_at = datetime()
SET e.name = entity.name,
e.type = entity.type,
e.definition = entity.definition,
e.aliases = entity.aliases,
e.properties = entity.properties,
e.updated_at = datetime()
WITH e, entity
MATCH (p:Project {id: entity.project_id})
MERGE (e)-[:BELONGS_TO]->(p)
""",
entities=entities_data,
entities = entities_data,
)
def sync_relation(self, relation: GraphRelation) -> None:
@@ -275,17 +275,17 @@ class Neo4jManager:
MATCH (source:Entity {id: $source_id})
MATCH (target:Entity {id: $target_id})
MERGE (source)-[r:RELATES_TO {id: $id}]->(target)
SET r.relation_type = $relation_type,
r.evidence = $evidence,
r.properties = $properties,
r.updated_at = datetime()
SET r.relation_type = $relation_type,
r.evidence = $evidence,
r.properties = $properties,
r.updated_at = datetime()
""",
id=relation.id,
source_id=relation.source_id,
target_id=relation.target_id,
relation_type=relation.relation_type,
evidence=relation.evidence,
properties=json.dumps(relation.properties),
id = relation.id,
source_id = relation.source_id,
target_id = relation.target_id,
relation_type = relation.relation_type,
evidence = relation.evidence,
properties = json.dumps(relation.properties),
)
def sync_relations_batch(self, relations: list[GraphRelation]) -> None:
@@ -294,7 +294,7 @@ class Neo4jManager:
return
with self._driver.session() as session:
relations_data = [
relations_data = [
{
"id": r.id,
"source_id": r.source_id,
@@ -312,12 +312,12 @@ class Neo4jManager:
MATCH (source:Entity {id: rel.source_id})
MATCH (target:Entity {id: rel.target_id})
MERGE (source)-[r:RELATES_TO {id: rel.id}]->(target)
SET r.relation_type = rel.relation_type,
r.evidence = rel.evidence,
r.properties = rel.properties,
r.updated_at = datetime()
SET r.relation_type = rel.relation_type,
r.evidence = rel.evidence,
r.properties = rel.properties,
r.updated_at = datetime()
""",
relations=relations_data,
relations = relations_data,
)
def delete_entity(self, entity_id: str) -> None:
@@ -331,7 +331,7 @@ class Neo4jManager:
MATCH (e:Entity {id: $id})
DETACH DELETE e
""",
id=entity_id,
id = entity_id,
)
def delete_project(self, project_id: str) -> None:
@@ -346,13 +346,13 @@ class Neo4jManager:
OPTIONAL MATCH (e:Entity)-[:BELONGS_TO]->(p)
DETACH DELETE e, p
""",
id=project_id,
id = project_id,
)
# ==================== 复杂图查询 ====================
def find_shortest_path(
self, source_id: str, target_id: str, max_depth: int = 10
self, source_id: str, target_id: str, max_depth: int = 10
) -> PathResult | None:
"""
查找两个实体之间的最短路径
@@ -369,31 +369,31 @@ class Neo4jManager:
return None
with self._driver.session() as session:
result = session.run(
result = session.run(
"""
MATCH path = shortestPath(
MATCH path = shortestPath(
(source:Entity {id: $source_id})-[*1..$max_depth]-(target:Entity {id: $target_id})
)
RETURN path
""",
source_id=source_id,
target_id=target_id,
max_depth=max_depth,
source_id = source_id,
target_id = target_id,
max_depth = max_depth,
)
record = result.single()
record = result.single()
if not record:
return None
path = record["path"]
path = record["path"]
# 提取节点和关系
nodes = [
nodes = [
{"id": node["id"], "name": node["name"], "type": node["type"]}
for node in path.nodes
]
relationships = [
relationships = [
{
"source": rel.start_node["id"],
"target": rel.end_node["id"],
@@ -404,11 +404,11 @@ class Neo4jManager:
]
return PathResult(
nodes=nodes, relationships=relationships, length=len(path.relationships)
nodes = nodes, relationships = relationships, length = len(path.relationships)
)
def find_all_paths(
self, source_id: str, target_id: str, max_depth: int = 5, limit: int = 10
self, source_id: str, target_id: str, max_depth: int = 5, limit: int = 10
) -> list[PathResult]:
"""
查找两个实体之间的所有路径
@@ -426,29 +426,29 @@ class Neo4jManager:
return []
with self._driver.session() as session:
result = session.run(
result = session.run(
"""
MATCH path = (source:Entity {id: $source_id})-[*1..$max_depth]-(target:Entity {id: $target_id})
MATCH path = (source:Entity {id: $source_id})-[*1..$max_depth]-(target:Entity {id: $target_id})
WHERE source <> target
RETURN path
LIMIT $limit
""",
source_id=source_id,
target_id=target_id,
max_depth=max_depth,
limit=limit,
source_id = source_id,
target_id = target_id,
max_depth = max_depth,
limit = limit,
)
paths = []
paths = []
for record in result:
path = record["path"]
path = record["path"]
nodes = [
nodes = [
{"id": node["id"], "name": node["name"], "type": node["type"]}
for node in path.nodes
]
relationships = [
relationships = [
{
"source": rel.start_node["id"],
"target": rel.end_node["id"],
@@ -460,14 +460,14 @@ class Neo4jManager:
paths.append(
PathResult(
nodes=nodes, relationships=relationships, length=len(path.relationships)
nodes = nodes, relationships = relationships, length = len(path.relationships)
)
)
return paths
def find_neighbors(
self, entity_id: str, relation_type: str = None, limit: int = 50
self, entity_id: str, relation_type: str = None, limit: int = 50
) -> list[dict]:
"""
查找实体的邻居节点
@@ -485,30 +485,30 @@ class Neo4jManager:
with self._driver.session() as session:
if relation_type:
result = session.run(
result = session.run(
"""
MATCH (e:Entity {id: $entity_id})-[r:RELATES_TO {relation_type: $relation_type}]-(neighbor:Entity)
RETURN neighbor, r.relation_type as rel_type, r.evidence as evidence
LIMIT $limit
""",
entity_id=entity_id,
relation_type=relation_type,
limit=limit,
entity_id = entity_id,
relation_type = relation_type,
limit = limit,
)
else:
result = session.run(
result = session.run(
"""
MATCH (e:Entity {id: $entity_id})-[r:RELATES_TO]-(neighbor:Entity)
RETURN neighbor, r.relation_type as rel_type, r.evidence as evidence
LIMIT $limit
""",
entity_id=entity_id,
limit=limit,
entity_id = entity_id,
limit = limit,
)
neighbors = []
neighbors = []
for record in result:
node = record["neighbor"]
node = record["neighbor"]
neighbors.append(
{
"id": node["id"],
@@ -536,13 +536,13 @@ class Neo4jManager:
return []
with self._driver.session() as session:
result = session.run(
result = session.run(
"""
MATCH (e1:Entity {id: $id1})-[:RELATES_TO]-(common:Entity)-[:RELATES_TO]-(e2:Entity {id: $id2})
RETURN DISTINCT common
""",
id1=entity_id1,
id2=entity_id2,
id1 = entity_id1,
id2 = entity_id2,
)
return [
@@ -556,7 +556,7 @@ class Neo4jManager:
# ==================== 图算法分析 ====================
def calculate_pagerank(self, project_id: str, top_n: int = 20) -> list[CentralityResult]:
def calculate_pagerank(self, project_id: str, top_n: int = 20) -> list[CentralityResult]:
"""
计算 PageRank 中心性
@@ -571,7 +571,7 @@ class Neo4jManager:
return []
with self._driver.session() as session:
result = session.run(
result = session.run(
"""
CALL gds.graph.exists('project-graph-$project_id') YIELD exists
WITH exists
@@ -581,7 +581,7 @@ class Neo4jManager:
{}
) YIELD value RETURN value
""",
project_id=project_id,
project_id = project_id,
)
# 创建临时图
@@ -601,11 +601,11 @@ class Neo4jManager:
}
)
""",
project_id=project_id,
project_id = project_id,
)
# 运行 PageRank
result = session.run(
result = session.run(
"""
CALL gds.pageRank.stream('project-graph-$project_id')
YIELD nodeId, score
@@ -615,19 +615,19 @@ class Neo4jManager:
ORDER BY score DESC
LIMIT $top_n
""",
project_id=project_id,
top_n=top_n,
project_id = project_id,
top_n = top_n,
)
rankings = []
rank = 1
rankings = []
rank = 1
for record in result:
rankings.append(
CentralityResult(
entity_id=record["entity_id"],
entity_name=record["entity_name"],
score=record["score"],
rank=rank,
entity_id = record["entity_id"],
entity_name = record["entity_name"],
score = record["score"],
rank = rank,
)
)
rank += 1
@@ -637,12 +637,12 @@ class Neo4jManager:
"""
CALL gds.graph.drop('project-graph-$project_id')
""",
project_id=project_id,
project_id = project_id,
)
return rankings
def calculate_betweenness(self, project_id: str, top_n: int = 20) -> list[CentralityResult]:
def calculate_betweenness(self, project_id: str, top_n: int = 20) -> list[CentralityResult]:
"""
计算 Betweenness 中心性(桥梁作用)
@@ -658,7 +658,7 @@ class Neo4jManager:
with self._driver.session() as session:
# 使用 APOC 的 betweenness 计算(如果没有 GDS
result = session.run(
result = session.run(
"""
MATCH (e:Entity)-[:BELONGS_TO]->(p:Project {id: $project_id})
OPTIONAL MATCH (e)-[:RELATES_TO]-(other:Entity)
@@ -667,19 +667,19 @@ class Neo4jManager:
LIMIT $top_n
RETURN e.id as entity_id, e.name as entity_name, degree as score
""",
project_id=project_id,
top_n=top_n,
project_id = project_id,
top_n = top_n,
)
rankings = []
rank = 1
rankings = []
rank = 1
for record in result:
rankings.append(
CentralityResult(
entity_id=record["entity_id"],
entity_name=record["entity_name"],
score=float(record["score"]),
rank=rank,
entity_id = record["entity_id"],
entity_name = record["entity_name"],
score = float(record["score"]),
rank = rank,
)
)
rank += 1
@@ -701,7 +701,7 @@ class Neo4jManager:
with self._driver.session() as session:
# 简单的社区检测:基于连通分量
result = session.run(
result = session.run(
"""
MATCH (e:Entity)-[:BELONGS_TO]->(p:Project {id: $project_id})
OPTIONAL MATCH (e)-[:RELATES_TO]-(other:Entity)-[:BELONGS_TO]->(p)
@@ -710,25 +710,25 @@ class Neo4jManager:
connections, size(connections) as connection_count
ORDER BY connection_count DESC
""",
project_id=project_id,
project_id = project_id,
)
# 手动分组(基于连通性)
communities = {}
communities = {}
for record in result:
entity_id = record["entity_id"]
connections = record["connections"]
entity_id = record["entity_id"]
connections = record["connections"]
# 找到所属的社区
found_community = None
found_community = None
for comm_id, comm_data in communities.items():
if any(conn in comm_data["member_ids"] for conn in connections):
found_community = comm_id
found_community = comm_id
break
if found_community is None:
found_community = len(communities)
communities[found_community] = {"member_ids": set(), "nodes": []}
found_community = len(communities)
communities[found_community] = {"member_ids": set(), "nodes": []}
communities[found_community]["member_ids"].add(entity_id)
communities[found_community]["nodes"].append(
@@ -741,27 +741,27 @@ class Neo4jManager:
)
# 构建结果
results = []
results = []
for comm_id, comm_data in communities.items():
nodes = comm_data["nodes"]
size = len(nodes)
nodes = comm_data["nodes"]
size = len(nodes)
# 计算密度(简化版)
max_edges = size * (size - 1) / 2 if size > 1 else 1
actual_edges = sum(n["connections"] for n in nodes) / 2
density = actual_edges / max_edges if max_edges > 0 else 0
max_edges = size * (size - 1) / 2 if size > 1 else 1
actual_edges = sum(n["connections"] for n in nodes) / 2
density = actual_edges / max_edges if max_edges > 0 else 0
results.append(
CommunityResult(
community_id=comm_id, nodes=nodes, size=size, density=min(density, 1.0)
community_id = comm_id, nodes = nodes, size = size, density = min(density, 1.0)
)
)
# 按大小排序
results.sort(key=lambda x: x.size, reverse=True)
results.sort(key = lambda x: x.size, reverse = True)
return results
def find_central_entities(
self, project_id: str, metric: str = "degree"
self, project_id: str, metric: str = "degree"
) -> list[CentralityResult]:
"""
查找中心实体
@@ -778,7 +778,7 @@ class Neo4jManager:
with self._driver.session() as session:
if metric == "degree":
result = session.run(
result = session.run(
"""
MATCH (e:Entity)-[:BELONGS_TO]->(p:Project {id: $project_id})
OPTIONAL MATCH (e)-[:RELATES_TO]-(other:Entity)
@@ -787,11 +787,11 @@ class Neo4jManager:
ORDER BY degree DESC
LIMIT 20
""",
project_id=project_id,
project_id = project_id,
)
else:
# 默认使用度中心性
result = session.run(
result = session.run(
"""
MATCH (e:Entity)-[:BELONGS_TO]->(p:Project {id: $project_id})
OPTIONAL MATCH (e)-[:RELATES_TO]-(other:Entity)
@@ -800,18 +800,18 @@ class Neo4jManager:
ORDER BY degree DESC
LIMIT 20
""",
project_id=project_id,
project_id = project_id,
)
rankings = []
rank = 1
rankings = []
rank = 1
for record in result:
rankings.append(
CentralityResult(
entity_id=record["entity_id"],
entity_name=record["entity_name"],
score=float(record["score"]),
rank=rank,
entity_id = record["entity_id"],
entity_name = record["entity_name"],
score = float(record["score"]),
rank = rank,
)
)
rank += 1
@@ -835,49 +835,49 @@ class Neo4jManager:
with self._driver.session() as session:
# 实体数量
entity_count = session.run(
entity_count = session.run(
"""
MATCH (e:Entity)-[:BELONGS_TO]->(p:Project {id: $project_id})
RETURN count(e) as count
""",
project_id=project_id,
project_id = project_id,
).single()["count"]
# 关系数量
relation_count = session.run(
relation_count = session.run(
"""
MATCH (e:Entity)-[:BELONGS_TO]->(p:Project {id: $project_id})
MATCH (e)-[r:RELATES_TO]-()
RETURN count(r) as count
""",
project_id=project_id,
project_id = project_id,
).single()["count"]
# 实体类型分布
type_distribution = session.run(
type_distribution = session.run(
"""
MATCH (e:Entity)-[:BELONGS_TO]->(p:Project {id: $project_id})
RETURN e.type as type, count(e) as count
ORDER BY count DESC
""",
project_id=project_id,
project_id = project_id,
)
types = {record["type"]: record["count"] for record in type_distribution}
types = {record["type"]: record["count"] for record in type_distribution}
# 平均度
avg_degree = session.run(
avg_degree = session.run(
"""
MATCH (e:Entity)-[:BELONGS_TO]->(p:Project {id: $project_id})
OPTIONAL MATCH (e)-[:RELATES_TO]-(other)
WITH e, count(other) as degree
RETURN avg(degree) as avg_degree
""",
project_id=project_id,
project_id = project_id,
).single()["avg_degree"]
# 关系类型分布
rel_types = session.run(
rel_types = session.run(
"""
MATCH (e:Entity)-[:BELONGS_TO]->(p:Project {id: $project_id})
MATCH (e)-[r:RELATES_TO]-()
@@ -885,10 +885,10 @@ class Neo4jManager:
ORDER BY count DESC
LIMIT 10
""",
project_id=project_id,
project_id = project_id,
)
relation_types = {record["type"]: record["count"] for record in rel_types}
relation_types = {record["type"]: record["count"] for record in rel_types}
return {
"entity_count": entity_count,
@@ -901,7 +901,7 @@ class Neo4jManager:
else 0,
}
def get_subgraph(self, entity_ids: list[str], depth: int = 1) -> dict:
def get_subgraph(self, entity_ids: list[str], depth: int = 1) -> dict:
"""
获取指定实体的子图
@@ -916,7 +916,7 @@ class Neo4jManager:
return {"nodes": [], "relationships": []}
with self._driver.session() as session:
result = session.run(
result = session.run(
"""
MATCH (e:Entity)
WHERE e.id IN $entity_ids
@@ -927,14 +927,14 @@ class Neo4jManager:
}) YIELD node
RETURN DISTINCT node
""",
entity_ids=entity_ids,
depth=depth,
entity_ids = entity_ids,
depth = depth,
)
nodes = []
node_ids = set()
nodes = []
node_ids = set()
for record in result:
node = record["node"]
node = record["node"]
node_ids.add(node["id"])
nodes.append(
{
@@ -946,17 +946,17 @@ class Neo4jManager:
)
# 获取这些节点之间的关系
result = session.run(
result = session.run(
"""
MATCH (source:Entity)-[r:RELATES_TO]->(target:Entity)
WHERE source.id IN $node_ids AND target.id IN $node_ids
RETURN source.id as source_id, target.id as target_id,
r.relation_type as type, r.evidence as evidence
""",
node_ids=list(node_ids),
node_ids = list(node_ids),
)
relationships = [
relationships = [
{
"source": record["source_id"],
"target": record["target_id"],
@@ -970,14 +970,14 @@ class Neo4jManager:
# 全局单例
_neo4j_manager = None
_neo4j_manager = None
def get_neo4j_manager() -> Neo4jManager:
"""获取 Neo4j 管理器单例"""
global _neo4j_manager
if _neo4j_manager is None:
_neo4j_manager = Neo4jManager()
_neo4j_manager = Neo4jManager()
return _neo4j_manager
@@ -986,7 +986,7 @@ def close_neo4j_manager() -> None:
global _neo4j_manager
if _neo4j_manager:
_neo4j_manager.close()
_neo4j_manager = None
_neo4j_manager = None
# 便捷函数
@@ -1004,7 +1004,7 @@ def sync_project_to_neo4j(
entities: 实体列表(字典格式)
relations: 关系列表(字典格式)
"""
manager = get_neo4j_manager()
manager = get_neo4j_manager()
if not manager.is_connected():
logger.warning("Neo4j not connected, skipping sync")
return
@@ -1013,29 +1013,29 @@ def sync_project_to_neo4j(
manager.sync_project(project_id, project_name)
# 同步实体
graph_entities = [
graph_entities = [
GraphEntity(
id=e["id"],
project_id=project_id,
name=e["name"],
type=e.get("type", "unknown"),
definition=e.get("definition", ""),
aliases=e.get("aliases", []),
properties=e.get("properties", {}),
id = e["id"],
project_id = project_id,
name = e["name"],
type = e.get("type", "unknown"),
definition = e.get("definition", ""),
aliases = e.get("aliases", []),
properties = e.get("properties", {}),
)
for e in entities
]
manager.sync_entities_batch(graph_entities)
# 同步关系
graph_relations = [
graph_relations = [
GraphRelation(
id=r["id"],
source_id=r["source_entity_id"],
target_id=r["target_entity_id"],
relation_type=r["relation_type"],
evidence=r.get("evidence", ""),
properties=r.get("properties", {}),
id = r["id"],
source_id = r["source_entity_id"],
target_id = r["target_entity_id"],
relation_type = r["relation_type"],
evidence = r.get("evidence", ""),
properties = r.get("properties", {}),
)
for r in relations
]
@@ -1048,9 +1048,9 @@ def sync_project_to_neo4j(
if __name__ == "__main__":
# 测试代码
logging.basicConfig(level=logging.INFO)
logging.basicConfig(level = logging.INFO)
manager = Neo4jManager()
manager = Neo4jManager()
if manager.is_connected():
print("✅ Connected to Neo4j")
@@ -1064,18 +1064,18 @@ if __name__ == "__main__":
print("✅ Project synced")
# 测试实体
test_entity = GraphEntity(
id="test-entity-1",
project_id="test-project",
name="Test Entity",
type="Person",
definition="A test entity",
test_entity = GraphEntity(
id = "test-entity-1",
project_id = "test-project",
name = "Test Entity",
type = "Person",
definition = "A test entity",
)
manager.sync_entity(test_entity)
print("✅ Entity synced")
# 获取统计
stats = manager.get_graph_stats("test-project")
stats = manager.get_graph_stats("test-project")
print(f"📊 Graph stats: {stats}")
else:

File diff suppressed because it is too large Load Diff

View File

@@ -11,30 +11,30 @@ import oss2
class OSSUploader:
def __init__(self):
self.access_key = os.getenv("ALI_ACCESS_KEY")
self.secret_key = os.getenv("ALI_SECRET_KEY")
self.bucket_name = os.getenv("OSS_BUCKET", "insightflow-audio")
self.region = os.getenv("OSS_REGION", "oss-cn-hangzhou.aliyuncs.com")
self.endpoint = f"https://{self.region}"
def __init__(self) -> None:
self.access_key = os.getenv("ALI_ACCESS_KEY")
self.secret_key = os.getenv("ALI_SECRET_KEY")
self.bucket_name = os.getenv("OSS_BUCKET", "insightflow-audio")
self.region = os.getenv("OSS_REGION", "oss-cn-hangzhou.aliyuncs.com")
self.endpoint = f"https://{self.region}"
if not self.access_key or not self.secret_key:
raise ValueError("ALI_ACCESS_KEY and ALI_SECRET_KEY must be set")
self.auth = oss2.Auth(self.access_key, self.secret_key)
self.bucket = oss2.Bucket(self.auth, self.endpoint, self.bucket_name)
self.auth = oss2.Auth(self.access_key, self.secret_key)
self.bucket = oss2.Bucket(self.auth, self.endpoint, self.bucket_name)
def upload_audio(self, audio_data: bytes, filename: str) -> tuple:
"""上传音频到 OSS返回 (URL, object_name)"""
# 生成唯一文件名
ext = os.path.splitext(filename)[1] or ".wav"
object_name = f"audio/{datetime.now().strftime('%Y%m%d')}/{uuid.uuid4().hex}{ext}"
ext = os.path.splitext(filename)[1] or ".wav"
object_name = f"audio/{datetime.now().strftime('%Y%m%d')}/{uuid.uuid4().hex}{ext}"
# 上传文件
self.bucket.put_object(object_name, audio_data)
# 生成临时访问 URL (1小时有效)
url = self.bucket.sign_url("GET", object_name, 3600)
url = self.bucket.sign_url("GET", object_name, 3600)
return url, object_name
def delete_object(self, object_name: str) -> None:
@@ -43,11 +43,11 @@ class OSSUploader:
# 单例
_oss_uploader = None
_oss_uploader = None
def get_oss_uploader() -> OSSUploader:
global _oss_uploader
if _oss_uploader is None:
_oss_uploader = OSSUploader()
_oss_uploader = OSSUploader()
return _oss_uploader

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -17,9 +17,9 @@ from functools import wraps
class RateLimitConfig:
"""限流配置"""
requests_per_minute: int = 60
burst_size: int = 10 # 突发请求数
window_size: int = 60 # 窗口大小(秒)
requests_per_minute: int = 60
burst_size: int = 10 # 突发请求数
window_size: int = 60 # 窗口大小(秒)
@dataclass
@@ -35,16 +35,16 @@ class RateLimitInfo:
class SlidingWindowCounter:
"""滑动窗口计数器"""
def __init__(self, window_size: int = 60):
self.window_size = window_size
self.requests: dict[int, int] = defaultdict(int) # 秒级计数
self._lock = asyncio.Lock()
self._cleanup_lock = asyncio.Lock()
def __init__(self, window_size: int = 60) -> None:
self.window_size = window_size
self.requests: dict[int, int] = defaultdict(int) # 秒级计数
self._lock = asyncio.Lock()
self._cleanup_lock = asyncio.Lock()
async def add_request(self) -> int:
"""添加请求,返回当前窗口内的请求数"""
async with self._lock:
now = int(time.time())
now = int(time.time())
self.requests[now] += 1
self._cleanup_old(now)
return sum(self.requests.values())
@@ -52,14 +52,14 @@ class SlidingWindowCounter:
async def get_count(self) -> int:
"""获取当前窗口内的请求数"""
async with self._lock:
now = int(time.time())
now = int(time.time())
self._cleanup_old(now)
return sum(self.requests.values())
def _cleanup_old(self, now: int) -> None:
"""清理过期的请求记录 - 使用独立锁避免竞态条件"""
cutoff = now - self.window_size
old_keys = [k for k in list(self.requests.keys()) if k < cutoff]
cutoff = now - self.window_size
old_keys = [k for k in list(self.requests.keys()) if k < cutoff]
for k in old_keys:
self.requests.pop(k, None)
@@ -69,13 +69,13 @@ class RateLimiter:
def __init__(self) -> None:
# key -> SlidingWindowCounter
self.counters: dict[str, SlidingWindowCounter] = {}
self.counters: dict[str, SlidingWindowCounter] = {}
# key -> RateLimitConfig
self.configs: dict[str, RateLimitConfig] = {}
self._lock = asyncio.Lock()
self._cleanup_lock = asyncio.Lock()
self.configs: dict[str, RateLimitConfig] = {}
self._lock = asyncio.Lock()
self._cleanup_lock = asyncio.Lock()
async def is_allowed(self, key: str, config: RateLimitConfig | None = None) -> RateLimitInfo:
async def is_allowed(self, key: str, config: RateLimitConfig | None = None) -> RateLimitInfo:
"""
检查是否允许请求
@@ -87,70 +87,70 @@ class RateLimiter:
RateLimitInfo
"""
if config is None:
config = RateLimitConfig()
config = RateLimitConfig()
async with self._lock:
if key not in self.counters:
self.counters[key] = SlidingWindowCounter(config.window_size)
self.configs[key] = config
self.counters[key] = SlidingWindowCounter(config.window_size)
self.configs[key] = config
counter = self.counters[key]
stored_config = self.configs.get(key, config)
counter = self.counters[key]
stored_config = self.configs.get(key, config)
# 获取当前计数
current_count = await counter.get_count()
current_count = await counter.get_count()
# 计算剩余配额
remaining = max(0, stored_config.requests_per_minute - current_count)
remaining = max(0, stored_config.requests_per_minute - current_count)
# 计算重置时间
now = int(time.time())
reset_time = now + stored_config.window_size
now = int(time.time())
reset_time = now + stored_config.window_size
# 检查是否超过限制
if current_count >= stored_config.requests_per_minute:
return RateLimitInfo(
allowed=False,
remaining=0,
reset_time=reset_time,
retry_after=stored_config.window_size,
allowed = False,
remaining = 0,
reset_time = reset_time,
retry_after = stored_config.window_size,
)
# 允许请求,增加计数
await counter.add_request()
return RateLimitInfo(
allowed=True, remaining=remaining - 1, reset_time=reset_time, retry_after=0
allowed = True, remaining = remaining - 1, reset_time = reset_time, retry_after = 0
)
async def get_limit_info(self, key: str) -> RateLimitInfo:
"""获取限流信息(不增加计数)"""
if key not in self.counters:
config = RateLimitConfig()
config = RateLimitConfig()
return RateLimitInfo(
allowed=True,
remaining=config.requests_per_minute,
reset_time=int(time.time()) + config.window_size,
retry_after=0,
allowed = True,
remaining = config.requests_per_minute,
reset_time = int(time.time()) + config.window_size,
retry_after = 0,
)
counter = self.counters[key]
config = self.configs.get(key, RateLimitConfig())
counter = self.counters[key]
config = self.configs.get(key, RateLimitConfig())
current_count = await counter.get_count()
remaining = max(0, config.requests_per_minute - current_count)
reset_time = int(time.time()) + config.window_size
current_count = await counter.get_count()
remaining = max(0, config.requests_per_minute - current_count)
reset_time = int(time.time()) + config.window_size
return RateLimitInfo(
allowed=current_count < config.requests_per_minute,
remaining=remaining,
reset_time=reset_time,
retry_after=max(0, config.window_size)
allowed = current_count < config.requests_per_minute,
remaining = remaining,
reset_time = reset_time,
retry_after = max(0, config.window_size)
if current_count >= config.requests_per_minute
else 0,
)
def reset(self, key: str | None = None) -> None:
def reset(self, key: str | None = None) -> None:
"""重置限流计数器"""
if key:
self.counters.pop(key, None)
@@ -161,21 +161,21 @@ class RateLimiter:
# 全局限流器实例
_rate_limiter: RateLimiter | None = None
_rate_limiter: RateLimiter | None = None
def get_rate_limiter() -> RateLimiter:
"""获取限流器实例"""
global _rate_limiter
if _rate_limiter is None:
_rate_limiter = RateLimiter()
_rate_limiter = RateLimiter()
return _rate_limiter
# 限流装饰器(用于函数级别限流)
def rate_limit(requests_per_minute: int = 60, key_func: Callable | None = None) -> None:
def rate_limit(requests_per_minute: int = 60, key_func: Callable | None = None) -> None:
"""
限流装饰器
@@ -184,14 +184,14 @@ def rate_limit(requests_per_minute: int = 60, key_func: Callable | None = None)
key_func: 生成限流键的函数,默认为 None使用函数名
"""
def decorator(func):
limiter = get_rate_limiter()
config = RateLimitConfig(requests_per_minute=requests_per_minute)
def decorator(func) -> None:
limiter = get_rate_limiter()
config = RateLimitConfig(requests_per_minute = requests_per_minute)
@wraps(func)
async def async_wrapper(*args, **kwargs):
key = key_func(*args, **kwargs) if key_func else func.__name__
info = await limiter.is_allowed(key, config)
async def async_wrapper(*args, **kwargs) -> None:
key = key_func(*args, **kwargs) if key_func else func.__name__
info = await limiter.is_allowed(key, config)
if not info.allowed:
raise RateLimitExceeded(
@@ -201,10 +201,10 @@ def rate_limit(requests_per_minute: int = 60, key_func: Callable | None = None)
return await func(*args, **kwargs)
@wraps(func)
def sync_wrapper(*args, **kwargs):
key = key_func(*args, **kwargs) if key_func else func.__name__
def sync_wrapper(*args, **kwargs) -> None:
key = key_func(*args, **kwargs) if key_func else func.__name__
# 同步版本使用 asyncio.run
info = asyncio.run(limiter.is_allowed(key, config))
info = asyncio.run(limiter.is_allowed(key, config))
if not info.allowed:
raise RateLimitExceeded(

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -10,9 +10,9 @@ import sys
# 添加 backend 目录到路径
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
print("=" * 60)
print(" = " * 60)
print("InsightFlow 多模态模块测试")
print("=" * 60)
print(" = " * 60)
# 测试导入
print("\n1. 测试模块导入...")
@@ -42,7 +42,7 @@ except ImportError as e:
print("\n2. 测试模块初始化...")
try:
processor = get_multimodal_processor()
processor = get_multimodal_processor()
print(" ✓ MultimodalProcessor 初始化成功")
print(f" - 临时目录: {processor.temp_dir}")
print(f" - 帧提取间隔: {processor.frame_interval}")
@@ -50,14 +50,14 @@ except Exception as e:
print(f" ✗ MultimodalProcessor 初始化失败: {e}")
try:
img_processor = get_image_processor()
img_processor = get_image_processor()
print(" ✓ ImageProcessor 初始化成功")
print(f" - 临时目录: {img_processor.temp_dir}")
except Exception as e:
print(f" ✗ ImageProcessor 初始化失败: {e}")
try:
linker = get_multimodal_entity_linker()
linker = get_multimodal_entity_linker()
print(" ✓ MultimodalEntityLinker 初始化成功")
print(f" - 相似度阈值: {linker.similarity_threshold}")
except Exception as e:
@@ -67,20 +67,20 @@ except Exception as e:
print("\n3. 测试实体关联功能...")
try:
linker = get_multimodal_entity_linker()
linker = get_multimodal_entity_linker()
# 测试字符串相似度
sim = linker.calculate_string_similarity("Project Alpha", "Project Alpha")
sim = linker.calculate_string_similarity("Project Alpha", "Project Alpha")
assert sim == 1.0, "完全匹配应该返回1.0"
print(f" ✓ 字符串相似度计算正常 (完全匹配: {sim})")
sim = linker.calculate_string_similarity("K8s", "Kubernetes")
sim = linker.calculate_string_similarity("K8s", "Kubernetes")
print(f" ✓ 字符串相似度计算正常 (不同字符串: {sim:.2f})")
# 测试实体相似度
entity1 = {"name": "Project Alpha", "type": "PROJECT", "definition": "核心项目"}
entity2 = {"name": "Project Alpha", "type": "PROJECT", "definition": "主要项目"}
sim, match_type = linker.calculate_entity_similarity(entity1, entity2)
entity1 = {"name": "Project Alpha", "type": "PROJECT", "definition": "核心项目"}
entity2 = {"name": "Project Alpha", "type": "PROJECT", "definition": "主要项目"}
sim, match_type = linker.calculate_entity_similarity(entity1, entity2)
print(f" ✓ 实体相似度计算正常 (相似度: {sim:.2f}, 类型: {match_type})")
except Exception as e:
@@ -90,7 +90,7 @@ except Exception as e:
print("\n4. 测试图片处理器功能...")
try:
processor = get_image_processor()
processor = get_image_processor()
# 测试图片类型检测(使用模拟数据)
print(f" ✓ 支持的图片类型: {list(processor.IMAGE_TYPES.keys())}")
@@ -103,7 +103,7 @@ except Exception as e:
print("\n5. 测试视频处理器配置...")
try:
processor = get_multimodal_processor()
processor = get_multimodal_processor()
print(f" ✓ 视频目录: {processor.video_dir}")
print(f" ✓ 帧目录: {processor.frames_dir}")
@@ -129,11 +129,11 @@ print("\n6. 测试数据库多模态方法...")
try:
from db_manager import get_db_manager
db = get_db_manager()
db = get_db_manager()
# 检查多模态表是否存在
conn = db.get_conn()
tables = ["videos", "video_frames", "images", "multimodal_mentions", "multimodal_entity_links"]
conn = db.get_conn()
tables = ["videos", "video_frames", "images", "multimodal_mentions", "multimodal_entity_links"]
for table in tables:
try:
@@ -147,6 +147,6 @@ try:
except Exception as e:
print(f" ✗ 数据库多模态方法测试失败: {e}")
print("\n" + "=" * 60)
print("\n" + " = " * 60)
print("测试完成")
print("=" * 60)
print(" = " * 60)

View File

@@ -21,27 +21,27 @@ from search_manager import (
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
def test_fulltext_search():
def test_fulltext_search() -> None:
"""测试全文搜索"""
print("\n" + "=" * 60)
print("\n" + " = " * 60)
print("测试全文搜索 (FullTextSearch)")
print("=" * 60)
print(" = " * 60)
search = FullTextSearch()
search = FullTextSearch()
# 测试索引创建
print("\n1. 测试索引创建...")
success = search.index_content(
content_id="test_entity_1",
content_type="entity",
project_id="test_project",
text="这是一个测试实体,用于验证全文搜索功能。支持关键词高亮显示。",
success = search.index_content(
content_id = "test_entity_1",
content_type = "entity",
project_id = "test_project",
text = "这是一个测试实体,用于验证全文搜索功能。支持关键词高亮显示。",
)
print(f" 索引创建: {'✓ 成功' if success else '✗ 失败'}")
# 测试搜索
print("\n2. 测试关键词搜索...")
results = search.search("测试", project_id="test_project")
results = search.search("测试", project_id = "test_project")
print(f" 搜索结果数量: {len(results)}")
if results:
print(f" 第一个结果: {results[0].content[:50]}...")
@@ -49,28 +49,28 @@ def test_fulltext_search():
# 测试布尔搜索
print("\n3. 测试布尔搜索...")
results = search.search("测试 AND 全文", project_id="test_project")
results = search.search("测试 AND 全文", project_id = "test_project")
print(f" AND 搜索结果: {len(results)}")
results = search.search("测试 OR 关键词", project_id="test_project")
results = search.search("测试 OR 关键词", project_id = "test_project")
print(f" OR 搜索结果: {len(results)}")
# 测试高亮
print("\n4. 测试文本高亮...")
highlighted = search.highlight_text("这是一个测试实体,用于验证全文搜索功能。", "测试 全文")
highlighted = search.highlight_text("这是一个测试实体,用于验证全文搜索功能。", "测试 全文")
print(f" 高亮结果: {highlighted}")
print("\n✓ 全文搜索测试完成")
return True
def test_semantic_search():
def test_semantic_search() -> None:
"""测试语义搜索"""
print("\n" + "=" * 60)
print("\n" + " = " * 60)
print("测试语义搜索 (SemanticSearch)")
print("=" * 60)
print(" = " * 60)
semantic = SemanticSearch()
semantic = SemanticSearch()
# 检查可用性
print(f"\n1. 语义搜索可用性: {'✓ 可用' if semantic.is_available() else '✗ 不可用'}")
@@ -81,18 +81,18 @@ def test_semantic_search():
# 测试 embedding 生成
print("\n2. 测试 embedding 生成...")
embedding = semantic.generate_embedding("这是一个测试句子")
embedding = semantic.generate_embedding("这是一个测试句子")
if embedding:
print(f" Embedding 维度: {len(embedding)}")
print(f" 前5个值: {embedding[:5]}")
# 测试索引
print("\n3. 测试语义索引...")
success = semantic.index_embedding(
content_id="test_content_1",
content_type="transcript",
project_id="test_project",
text="这是用于语义搜索测试的文本内容。",
success = semantic.index_embedding(
content_id = "test_content_1",
content_type = "transcript",
project_id = "test_project",
text = "这是用于语义搜索测试的文本内容。",
)
print(f" 索引创建: {'✓ 成功' if success else '✗ 失败'}")
@@ -100,13 +100,13 @@ def test_semantic_search():
return True
def test_entity_path_discovery():
def test_entity_path_discovery() -> None:
"""测试实体路径发现"""
print("\n" + "=" * 60)
print("\n" + " = " * 60)
print("测试实体路径发现 (EntityPathDiscovery)")
print("=" * 60)
print(" = " * 60)
discovery = EntityPathDiscovery()
discovery = EntityPathDiscovery()
print("\n1. 测试路径发现初始化...")
print(f" 数据库路径: {discovery.db_path}")
@@ -119,13 +119,13 @@ def test_entity_path_discovery():
return True
def test_knowledge_gap_detection():
def test_knowledge_gap_detection() -> None:
"""测试知识缺口识别"""
print("\n" + "=" * 60)
print("\n" + " = " * 60)
print("测试知识缺口识别 (KnowledgeGapDetection)")
print("=" * 60)
print(" = " * 60)
detection = KnowledgeGapDetection()
detection = KnowledgeGapDetection()
print("\n1. 测试缺口检测初始化...")
print(f" 数据库路径: {detection.db_path}")
@@ -138,32 +138,32 @@ def test_knowledge_gap_detection():
return True
def test_cache_manager():
def test_cache_manager() -> None:
"""测试缓存管理器"""
print("\n" + "=" * 60)
print("\n" + " = " * 60)
print("测试缓存管理器 (CacheManager)")
print("=" * 60)
print(" = " * 60)
cache = CacheManager()
cache = CacheManager()
print(f"\n1. 缓存后端: {'Redis' if cache.use_redis else '内存 LRU'}")
print("\n2. 测试缓存操作...")
# 设置缓存
cache.set("test_key_1", {"name": "测试数据", "value": 123}, ttl=60)
cache.set("test_key_1", {"name": "测试数据", "value": 123}, ttl = 60)
print(" ✓ 设置缓存 test_key_1")
# 获取缓存
_ = cache.get("test_key_1")
_ = cache.get("test_key_1")
print(" ✓ 获取缓存: {value}")
# 批量操作
cache.set_many(
{"batch_key_1": "value1", "batch_key_2": "value2", "batch_key_3": "value3"}, ttl=60
{"batch_key_1": "value1", "batch_key_2": "value2", "batch_key_3": "value3"}, ttl = 60
)
print(" ✓ 批量设置缓存")
_ = cache.get_many(["batch_key_1", "batch_key_2", "batch_key_3"])
_ = cache.get_many(["batch_key_1", "batch_key_2", "batch_key_3"])
print(" ✓ 批量获取缓存: {len(values)} 个")
# 删除缓存
@@ -171,7 +171,7 @@ def test_cache_manager():
print(" ✓ 删除缓存 test_key_1")
# 获取统计
stats = cache.get_stats()
stats = cache.get_stats()
print("\n3. 缓存统计:")
print(f" 总请求数: {stats['total_requests']}")
print(f" 命中数: {stats['hits']}")
@@ -186,13 +186,13 @@ def test_cache_manager():
return True
def test_task_queue():
def test_task_queue() -> None:
"""测试任务队列"""
print("\n" + "=" * 60)
print("\n" + " = " * 60)
print("测试任务队列 (TaskQueue)")
print("=" * 60)
print(" = " * 60)
queue = TaskQueue()
queue = TaskQueue()
print(f"\n1. 任务队列可用性: {'✓ 可用' if queue.is_available() else '✗ 不可用'}")
print(f" 后端: {'Celery' if queue.use_celery else '内存'}")
@@ -200,25 +200,25 @@ def test_task_queue():
print("\n2. 测试任务提交...")
# 定义测试任务处理器
def test_task_handler(payload):
def test_task_handler(payload) -> None:
print(f" 执行任务: {payload}")
return {"status": "success", "processed": True}
queue.register_handler("test_task", test_task_handler)
# 提交任务
task_id = queue.submit(
task_type="test_task", payload={"test": "data", "timestamp": time.time()}
task_id = queue.submit(
task_type = "test_task", payload = {"test": "data", "timestamp": time.time()}
)
print(" ✓ 提交任务: {task_id}")
# 获取任务状态
task_info = queue.get_status(task_id)
task_info = queue.get_status(task_id)
if task_info:
print(" ✓ 任务状态: {task_info.status}")
# 获取统计
stats = queue.get_stats()
stats = queue.get_stats()
print("\n3. 任务队列统计:")
print(f" 后端: {stats['backend']}")
print(f" 按状态统计: {stats.get('by_status', {})}")
@@ -227,38 +227,38 @@ def test_task_queue():
return True
def test_performance_monitor():
def test_performance_monitor() -> None:
"""测试性能监控"""
print("\n" + "=" * 60)
print("\n" + " = " * 60)
print("测试性能监控 (PerformanceMonitor)")
print("=" * 60)
print(" = " * 60)
monitor = PerformanceMonitor()
monitor = PerformanceMonitor()
print("\n1. 测试指标记录...")
# 记录一些测试指标
for i in range(5):
monitor.record_metric(
metric_type="api_response",
duration_ms=50 + i * 10,
endpoint="/api/v1/test",
metadata={"test": True},
metric_type = "api_response",
duration_ms = 50 + i * 10,
endpoint = "/api/v1/test",
metadata = {"test": True},
)
for i in range(3):
monitor.record_metric(
metric_type="db_query",
duration_ms=20 + i * 5,
endpoint="SELECT test",
metadata={"test": True},
metric_type = "db_query",
duration_ms = 20 + i * 5,
endpoint = "SELECT test",
metadata = {"test": True},
)
print(" ✓ 记录了 8 个测试指标")
# 获取统计
print("\n2. 获取性能统计...")
stats = monitor.get_stats(hours=1)
stats = monitor.get_stats(hours = 1)
print(f" 总请求数: {stats['overall']['total_requests']}")
print(f" 平均响应时间: {stats['overall']['avg_duration_ms']} ms")
print(f" 最大响应时间: {stats['overall']['max_duration_ms']} ms")
@@ -274,19 +274,19 @@ def test_performance_monitor():
return True
def test_search_manager():
def test_search_manager() -> None:
"""测试搜索管理器"""
print("\n" + "=" * 60)
print("\n" + " = " * 60)
print("测试搜索管理器 (SearchManager)")
print("=" * 60)
print(" = " * 60)
manager = get_search_manager()
manager = get_search_manager()
print("\n1. 搜索管理器初始化...")
print(" ✓ 搜索管理器已初始化")
print("\n2. 获取搜索统计...")
stats = manager.get_search_stats()
stats = manager.get_search_stats()
print(f" 全文索引数: {stats['fulltext_indexed']}")
print(f" 语义索引数: {stats['semantic_indexed']}")
print(f" 语义搜索可用: {stats['semantic_search_available']}")
@@ -295,24 +295,24 @@ def test_search_manager():
return True
def test_performance_manager():
def test_performance_manager() -> None:
"""测试性能管理器"""
print("\n" + "=" * 60)
print("\n" + " = " * 60)
print("测试性能管理器 (PerformanceManager)")
print("=" * 60)
print(" = " * 60)
manager = get_performance_manager()
manager = get_performance_manager()
print("\n1. 性能管理器初始化...")
print(" ✓ 性能管理器已初始化")
print("\n2. 获取系统健康状态...")
health = manager.get_health_status()
health = manager.get_health_status()
print(f" 缓存后端: {health['cache']['backend']}")
print(f" 任务队列后端: {health['task_queue']['backend']}")
print("\n3. 获取完整统计...")
stats = manager.get_full_stats()
stats = manager.get_full_stats()
print(f" 缓存统计: {stats['cache']['total_requests']} 请求")
print(f" 任务队列统计: {stats['task_queue']}")
@@ -320,14 +320,14 @@ def test_performance_manager():
return True
def run_all_tests():
def run_all_tests() -> None:
"""运行所有测试"""
print("\n" + "=" * 60)
print("\n" + " = " * 60)
print("InsightFlow Phase 7 Task 6 & 8 测试")
print("高级搜索与发现 + 性能优化与扩展")
print("=" * 60)
print(" = " * 60)
results = []
results = []
# 搜索模块测试
try:
@@ -386,15 +386,15 @@ def run_all_tests():
results.append(("性能管理器", False))
# 打印测试汇总
print("\n" + "=" * 60)
print("\n" + " = " * 60)
print("测试汇总")
print("=" * 60)
print(" = " * 60)
passed = sum(1 for _, result in results if result)
total = len(results)
passed = sum(1 for _, result in results if result)
total = len(results)
for name, result in results:
status = "✓ 通过" if result else "✗ 失败"
status = "✓ 通过" if result else "✗ 失败"
print(f" {status} - {name}")
print(f"\n总计: {passed}/{total} 测试通过")
@@ -408,5 +408,5 @@ def run_all_tests():
if __name__ == "__main__":
success = run_all_tests()
success = run_all_tests()
sys.exit(0 if success else 1)

View File

@@ -18,18 +18,18 @@ from tenant_manager import get_tenant_manager
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
def test_tenant_management():
def test_tenant_management() -> None:
"""测试租户管理功能"""
print("=" * 60)
print(" = " * 60)
print("测试 1: 租户管理")
print("=" * 60)
print(" = " * 60)
manager = get_tenant_manager()
manager = get_tenant_manager()
# 1. 创建租户
print("\n1.1 创建租户...")
tenant = manager.create_tenant(
name="Test Company", owner_id="user_001", tier="pro", description="A test company tenant"
tenant = manager.create_tenant(
name = "Test Company", owner_id = "user_001", tier = "pro", description = "A test company tenant"
)
print(f"✅ 租户创建成功: {tenant.id}")
print(f" - 名称: {tenant.name}")
@@ -40,43 +40,43 @@ def test_tenant_management():
# 2. 获取租户
print("\n1.2 获取租户信息...")
fetched = manager.get_tenant(tenant.id)
fetched = manager.get_tenant(tenant.id)
assert fetched is not None, "获取租户失败"
print(f"✅ 获取租户成功: {fetched.name}")
# 3. 通过 slug 获取
print("\n1.3 通过 slug 获取租户...")
by_slug = manager.get_tenant_by_slug(tenant.slug)
by_slug = manager.get_tenant_by_slug(tenant.slug)
assert by_slug is not None, "通过 slug 获取失败"
print(f"✅ 通过 slug 获取成功: {by_slug.name}")
# 4. 更新租户
print("\n1.4 更新租户信息...")
updated = manager.update_tenant(
tenant_id=tenant.id, name="Test Company Updated", tier="enterprise"
updated = manager.update_tenant(
tenant_id = tenant.id, name = "Test Company Updated", tier = "enterprise"
)
assert updated is not None, "更新租户失败"
print(f"✅ 租户更新成功: {updated.name}, 层级: {updated.tier}")
# 5. 列出租户
print("\n1.5 列出租户...")
tenants = manager.list_tenants(limit=10)
tenants = manager.list_tenants(limit = 10)
print(f"✅ 找到 {len(tenants)} 个租户")
return tenant.id
def test_domain_management(tenant_id: str):
def test_domain_management(tenant_id: str) -> None:
"""测试域名管理功能"""
print("\n" + "=" * 60)
print("\n" + " = " * 60)
print("测试 2: 域名管理")
print("=" * 60)
print(" = " * 60)
manager = get_tenant_manager()
manager = get_tenant_manager()
# 1. 添加域名
print("\n2.1 添加自定义域名...")
domain = manager.add_domain(tenant_id=tenant_id, domain="test.example.com", is_primary=True)
domain = manager.add_domain(tenant_id = tenant_id, domain = "test.example.com", is_primary = True)
print(f"✅ 域名添加成功: {domain.domain}")
print(f" - ID: {domain.id}")
print(f" - 状态: {domain.status}")
@@ -84,19 +84,19 @@ def test_domain_management(tenant_id: str):
# 2. 获取验证指导
print("\n2.2 获取域名验证指导...")
instructions = manager.get_domain_verification_instructions(domain.id)
instructions = manager.get_domain_verification_instructions(domain.id)
print("✅ 验证指导:")
print(f" - DNS 记录: {instructions['dns_record']}")
print(f" - 文件验证: {instructions['file_verification']}")
# 3. 验证域名
print("\n2.3 验证域名...")
verified = manager.verify_domain(tenant_id, domain.id)
verified = manager.verify_domain(tenant_id, domain.id)
print(f"✅ 域名验证结果: {verified}")
# 4. 通过域名获取租户
print("\n2.4 通过域名获取租户...")
by_domain = manager.get_tenant_by_domain("test.example.com")
by_domain = manager.get_tenant_by_domain("test.example.com")
if by_domain:
print(f"✅ 通过域名获取租户成功: {by_domain.name}")
else:
@@ -104,7 +104,7 @@ def test_domain_management(tenant_id: str):
# 5. 列出域名
print("\n2.5 列出所有域名...")
domains = manager.list_domains(tenant_id)
domains = manager.list_domains(tenant_id)
print(f"✅ 找到 {len(domains)} 个域名")
for d in domains:
print(f" - {d.domain} ({d.status})")
@@ -112,25 +112,25 @@ def test_domain_management(tenant_id: str):
return domain.id
def test_branding_management(tenant_id: str):
def test_branding_management(tenant_id: str) -> None:
"""测试品牌白标功能"""
print("\n" + "=" * 60)
print("\n" + " = " * 60)
print("测试 3: 品牌白标")
print("=" * 60)
print(" = " * 60)
manager = get_tenant_manager()
manager = get_tenant_manager()
# 1. 更新品牌配置
print("\n3.1 更新品牌配置...")
branding = manager.update_branding(
tenant_id=tenant_id,
logo_url="https://example.com/logo.png",
favicon_url="https://example.com/favicon.ico",
primary_color="#1890ff",
secondary_color="#52c41a",
custom_css=".header { background: #1890ff; }",
custom_js="console.log('Custom JS loaded');",
login_page_bg="https://example.com/bg.jpg",
branding = manager.update_branding(
tenant_id = tenant_id,
logo_url = "https://example.com/logo.png",
favicon_url = "https://example.com/favicon.ico",
primary_color = "#1890ff",
secondary_color = "#52c41a",
custom_css = ".header { background: #1890ff; }",
custom_js = "console.log('Custom JS loaded');",
login_page_bg = "https://example.com/bg.jpg",
)
print("✅ 品牌配置更新成功")
print(f" - Logo: {branding.logo_url}")
@@ -139,67 +139,67 @@ def test_branding_management(tenant_id: str):
# 2. 获取品牌配置
print("\n3.2 获取品牌配置...")
fetched = manager.get_branding(tenant_id)
fetched = manager.get_branding(tenant_id)
assert fetched is not None, "获取品牌配置失败"
print("✅ 获取品牌配置成功")
# 3. 生成品牌 CSS
print("\n3.3 生成品牌 CSS...")
css = manager.get_branding_css(tenant_id)
css = manager.get_branding_css(tenant_id)
print(f"✅ 生成 CSS 成功 ({len(css)} 字符)")
print(f" CSS 预览:\n{css[:200]}...")
return branding.id
def test_member_management(tenant_id: str):
def test_member_management(tenant_id: str) -> None:
"""测试成员管理功能"""
print("\n" + "=" * 60)
print("\n" + " = " * 60)
print("测试 4: 成员管理")
print("=" * 60)
print(" = " * 60)
manager = get_tenant_manager()
manager = get_tenant_manager()
# 1. 邀请成员
print("\n4.1 邀请成员...")
member1 = manager.invite_member(
tenant_id=tenant_id, email="admin@test.com", role="admin", invited_by="user_001"
member1 = manager.invite_member(
tenant_id = tenant_id, email = "admin@test.com", role = "admin", invited_by = "user_001"
)
print(f"✅ 成员邀请成功: {member1.email}")
print(f" - ID: {member1.id}")
print(f" - 角色: {member1.role}")
print(f" - 权限: {member1.permissions}")
member2 = manager.invite_member(
tenant_id=tenant_id, email="member@test.com", role="member", invited_by="user_001"
member2 = manager.invite_member(
tenant_id = tenant_id, email = "member@test.com", role = "member", invited_by = "user_001"
)
print(f"✅ 成员邀请成功: {member2.email}")
# 2. 接受邀请
print("\n4.2 接受邀请...")
accepted = manager.accept_invitation(member1.id, "user_002")
accepted = manager.accept_invitation(member1.id, "user_002")
print(f"✅ 邀请接受结果: {accepted}")
# 3. 列出成员
print("\n4.3 列出所有成员...")
members = manager.list_members(tenant_id)
members = manager.list_members(tenant_id)
print(f"✅ 找到 {len(members)} 个成员")
for m in members:
print(f" - {m.email} ({m.role}) - {m.status}")
# 4. 检查权限
print("\n4.4 检查权限...")
can_manage = manager.check_permission(tenant_id, "user_002", "project", "create")
can_manage = manager.check_permission(tenant_id, "user_002", "project", "create")
print(f"✅ user_002 可以创建项目: {can_manage}")
# 5. 更新成员角色
print("\n4.5 更新成员角色...")
updated = manager.update_member_role(tenant_id, member2.id, "viewer")
updated = manager.update_member_role(tenant_id, member2.id, "viewer")
print(f"✅ 角色更新结果: {updated}")
# 6. 获取用户所属租户
print("\n4.6 获取用户所属租户...")
user_tenants = manager.get_user_tenants("user_002")
user_tenants = manager.get_user_tenants("user_002")
print(f"✅ user_002 属于 {len(user_tenants)} 个租户")
for t in user_tenants:
print(f" - {t['name']} ({t['member_role']})")
@@ -207,30 +207,30 @@ def test_member_management(tenant_id: str):
return member1.id, member2.id
def test_usage_tracking(tenant_id: str):
def test_usage_tracking(tenant_id: str) -> None:
"""测试资源使用统计功能"""
print("\n" + "=" * 60)
print("\n" + " = " * 60)
print("测试 5: 资源使用统计")
print("=" * 60)
print(" = " * 60)
manager = get_tenant_manager()
manager = get_tenant_manager()
# 1. 记录使用
print("\n5.1 记录资源使用...")
manager.record_usage(
tenant_id=tenant_id,
storage_bytes=1024 * 1024 * 50, # 50MB
transcription_seconds=600, # 10分钟
api_calls=100,
projects_count=5,
entities_count=50,
members_count=3,
tenant_id = tenant_id,
storage_bytes = 1024 * 1024 * 50, # 50MB
transcription_seconds = 600, # 10分钟
api_calls = 100,
projects_count = 5,
entities_count = 50,
members_count = 3,
)
print("✅ 资源使用记录成功")
# 2. 获取使用统计
print("\n5.2 获取使用统计...")
stats = manager.get_usage_stats(tenant_id)
stats = manager.get_usage_stats(tenant_id)
print("✅ 使用统计:")
print(f" - 存储: {stats['storage_mb']:.2f} MB")
print(f" - 转录: {stats['transcription_minutes']:.2f} 分钟")
@@ -243,19 +243,19 @@ def test_usage_tracking(tenant_id: str):
# 3. 检查资源限制
print("\n5.3 检查资源限制...")
for resource in ["storage", "transcription", "api_calls", "projects", "entities", "members"]:
allowed, current, limit = manager.check_resource_limit(tenant_id, resource)
allowed, current, limit = manager.check_resource_limit(tenant_id, resource)
print(f" - {resource}: {current}/{limit} ({'' if allowed else ''})")
return stats
def cleanup(tenant_id: str, domain_id: str, member_ids: list):
def cleanup(tenant_id: str, domain_id: str, member_ids: list) -> None:
"""清理测试数据"""
print("\n" + "=" * 60)
print("\n" + " = " * 60)
print("清理测试数据")
print("=" * 60)
print(" = " * 60)
manager = get_tenant_manager()
manager = get_tenant_manager()
# 移除成员
for member_id in member_ids:
@@ -273,28 +273,28 @@ def cleanup(tenant_id: str, domain_id: str, member_ids: list):
print(f"✅ 租户已删除: {tenant_id}")
def main():
def main() -> None:
"""主测试函数"""
print("\n" + "=" * 60)
print("\n" + " = " * 60)
print("InsightFlow Phase 8 Task 1 - 多租户 SaaS 架构测试")
print("=" * 60)
print(" = " * 60)
tenant_id = None
domain_id = None
member_ids = []
tenant_id = None
domain_id = None
member_ids = []
try:
# 运行所有测试
tenant_id = test_tenant_management()
domain_id = test_domain_management(tenant_id)
tenant_id = test_tenant_management()
domain_id = test_domain_management(tenant_id)
test_branding_management(tenant_id)
m1, m2 = test_member_management(tenant_id)
member_ids = [m1, m2]
m1, m2 = test_member_management(tenant_id)
member_ids = [m1, m2]
test_usage_tracking(tenant_id)
print("\n" + "=" * 60)
print("\n" + " = " * 60)
print("✅ 所有测试通过!")
print("=" * 60)
print(" = " * 60)
except Exception as e:
print(f"\n❌ 测试失败: {e}")

View File

@@ -12,31 +12,31 @@ from subscription_manager import PaymentProvider, SubscriptionManager
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
def test_subscription_manager():
def test_subscription_manager() -> None:
"""测试订阅管理器"""
print("=" * 60)
print(" = " * 60)
print("InsightFlow Phase 8 Task 2 - 订阅与计费系统测试")
print("=" * 60)
print(" = " * 60)
# 使用临时文件数据库进行测试
db_path = tempfile.mktemp(suffix=".db")
db_path = tempfile.mktemp(suffix = ".db")
try:
manager = SubscriptionManager(db_path=db_path)
manager = SubscriptionManager(db_path = db_path)
print("\n1. 测试订阅计划管理")
print("-" * 40)
# 获取默认计划
plans = manager.list_plans()
plans = manager.list_plans()
print(f"✓ 默认计划数量: {len(plans)}")
for plan in plans:
print(f" - {plan.name} ({plan.tier}): ¥{plan.price_monthly}/月")
# 通过 tier 获取计划
free_plan = manager.get_plan_by_tier("free")
pro_plan = manager.get_plan_by_tier("pro")
enterprise_plan = manager.get_plan_by_tier("enterprise")
free_plan = manager.get_plan_by_tier("free")
pro_plan = manager.get_plan_by_tier("pro")
enterprise_plan = manager.get_plan_by_tier("enterprise")
assert free_plan is not None, "Free 计划应该存在"
assert pro_plan is not None, "Pro 计划应该存在"
@@ -49,14 +49,14 @@ def test_subscription_manager():
print("\n2. 测试订阅管理")
print("-" * 40)
tenant_id = "test-tenant-001"
tenant_id = "test-tenant-001"
# 创建订阅
subscription = manager.create_subscription(
tenant_id=tenant_id,
plan_id=pro_plan.id,
payment_provider=PaymentProvider.STRIPE.value,
trial_days=14,
subscription = manager.create_subscription(
tenant_id = tenant_id,
plan_id = pro_plan.id,
payment_provider = PaymentProvider.STRIPE.value,
trial_days = 14,
)
print(f"✓ 创建订阅: {subscription.id}")
@@ -66,7 +66,7 @@ def test_subscription_manager():
print(f" - 试用结束: {subscription.trial_end}")
# 获取租户订阅
tenant_sub = manager.get_tenant_subscription(tenant_id)
tenant_sub = manager.get_tenant_subscription(tenant_id)
assert tenant_sub is not None, "应该能获取到租户订阅"
print(f"✓ 获取租户订阅: {tenant_sub.id}")
@@ -74,27 +74,27 @@ def test_subscription_manager():
print("-" * 40)
# 记录转录用量
usage1 = manager.record_usage(
tenant_id=tenant_id,
resource_type="transcription",
quantity=120,
unit="minute",
description="会议转录",
usage1 = manager.record_usage(
tenant_id = tenant_id,
resource_type = "transcription",
quantity = 120,
unit = "minute",
description = "会议转录",
)
print(f"✓ 记录转录用量: {usage1.quantity} {usage1.unit}, 费用: ¥{usage1.cost:.2f}")
# 记录存储用量
usage2 = manager.record_usage(
tenant_id=tenant_id,
resource_type="storage",
quantity=2.5,
unit="gb",
description="文件存储",
usage2 = manager.record_usage(
tenant_id = tenant_id,
resource_type = "storage",
quantity = 2.5,
unit = "gb",
description = "文件存储",
)
print(f"✓ 记录存储用量: {usage2.quantity} {usage2.unit}, 费用: ¥{usage2.cost:.2f}")
# 获取用量汇总
summary = manager.get_usage_summary(tenant_id)
summary = manager.get_usage_summary(tenant_id)
print("✓ 用量汇总:")
print(f" - 总费用: ¥{summary['total_cost']:.2f}")
for resource, data in summary["breakdown"].items():
@@ -104,12 +104,12 @@ def test_subscription_manager():
print("-" * 40)
# 创建支付
payment = manager.create_payment(
tenant_id=tenant_id,
amount=99.0,
currency="CNY",
provider=PaymentProvider.ALIPAY.value,
payment_method="qrcode",
payment = manager.create_payment(
tenant_id = tenant_id,
amount = 99.0,
currency = "CNY",
provider = PaymentProvider.ALIPAY.value,
payment_method = "qrcode",
)
print(f"✓ 创建支付: {payment.id}")
print(f" - 金额: ¥{payment.amount}")
@@ -117,22 +117,22 @@ def test_subscription_manager():
print(f" - 状态: {payment.status}")
# 确认支付
confirmed = manager.confirm_payment(payment.id, "alipay_123456")
confirmed = manager.confirm_payment(payment.id, "alipay_123456")
print(f"✓ 确认支付完成: {confirmed.status}")
# 列出支付记录
payments = manager.list_payments(tenant_id)
payments = manager.list_payments(tenant_id)
print(f"✓ 支付记录数量: {len(payments)}")
print("\n5. 测试发票管理")
print("-" * 40)
# 列出发票
invoices = manager.list_invoices(tenant_id)
invoices = manager.list_invoices(tenant_id)
print(f"✓ 发票数量: {len(invoices)}")
if invoices:
invoice = invoices[0]
invoice = invoices[0]
print(f" - 发票号: {invoice.invoice_number}")
print(f" - 金额: ¥{invoice.amount_due}")
print(f" - 状态: {invoice.status}")
@@ -141,12 +141,12 @@ def test_subscription_manager():
print("-" * 40)
# 申请退款
refund = manager.request_refund(
tenant_id=tenant_id,
payment_id=payment.id,
amount=50.0,
reason="服务不满意",
requested_by="user_001",
refund = manager.request_refund(
tenant_id = tenant_id,
payment_id = payment.id,
amount = 50.0,
reason = "服务不满意",
requested_by = "user_001",
)
print(f"✓ 申请退款: {refund.id}")
print(f" - 金额: ¥{refund.amount}")
@@ -154,21 +154,21 @@ def test_subscription_manager():
print(f" - 状态: {refund.status}")
# 批准退款
approved = manager.approve_refund(refund.id, "admin_001")
approved = manager.approve_refund(refund.id, "admin_001")
print(f"✓ 批准退款: {approved.status}")
# 完成退款
completed = manager.complete_refund(refund.id, "refund_123456")
completed = manager.complete_refund(refund.id, "refund_123456")
print(f"✓ 完成退款: {completed.status}")
# 列出退款记录
refunds = manager.list_refunds(tenant_id)
refunds = manager.list_refunds(tenant_id)
print(f"✓ 退款记录数量: {len(refunds)}")
print("\n7. 测试账单历史")
print("-" * 40)
history = manager.get_billing_history(tenant_id)
history = manager.get_billing_history(tenant_id)
print(f"✓ 账单历史记录数量: {len(history)}")
for h in history:
print(f" - [{h.type}] {h.description}: ¥{h.amount}")
@@ -177,24 +177,24 @@ def test_subscription_manager():
print("-" * 40)
# Stripe Checkout
stripe_session = manager.create_stripe_checkout_session(
tenant_id=tenant_id,
plan_id=enterprise_plan.id,
success_url="https://example.com/success",
cancel_url="https://example.com/cancel",
stripe_session = manager.create_stripe_checkout_session(
tenant_id = tenant_id,
plan_id = enterprise_plan.id,
success_url = "https://example.com/success",
cancel_url = "https://example.com/cancel",
)
print(f"✓ Stripe Checkout 会话: {stripe_session['session_id']}")
# 支付宝订单
alipay_order = manager.create_alipay_order(tenant_id=tenant_id, plan_id=pro_plan.id)
alipay_order = manager.create_alipay_order(tenant_id = tenant_id, plan_id = pro_plan.id)
print(f"✓ 支付宝订单: {alipay_order['order_id']}")
# 微信支付订单
wechat_order = manager.create_wechat_order(tenant_id=tenant_id, plan_id=pro_plan.id)
wechat_order = manager.create_wechat_order(tenant_id = tenant_id, plan_id = pro_plan.id)
print(f"✓ 微信支付订单: {wechat_order['order_id']}")
# Webhook 处理
webhook_result = manager.handle_webhook(
webhook_result = manager.handle_webhook(
"stripe",
{"event_type": "checkout.session.completed", "data": {"object": {"id": "cs_test"}}},
)
@@ -204,19 +204,19 @@ def test_subscription_manager():
print("-" * 40)
# 更改计划
changed = manager.change_plan(
subscription_id=subscription.id, new_plan_id=enterprise_plan.id
changed = manager.change_plan(
subscription_id = subscription.id, new_plan_id = enterprise_plan.id
)
print(f"✓ 更改计划: {changed.plan_id} (Enterprise)")
# 取消订阅
cancelled = manager.cancel_subscription(subscription_id=subscription.id, at_period_end=True)
cancelled = manager.cancel_subscription(subscription_id = subscription.id, at_period_end = True)
print(f"✓ 取消订阅: {cancelled.status}")
print(f" - 周期结束时取消: {cancelled.cancel_at_period_end}")
print("\n" + "=" * 60)
print("\n" + " = " * 60)
print("所有测试通过! ✓")
print("=" * 60)
print(" = " * 60)
finally:
# 清理临时数据库

View File

@@ -14,31 +14,31 @@ from ai_manager import ModelType, PredictionType, get_ai_manager
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
def test_custom_model():
def test_custom_model() -> None:
"""测试自定义模型功能"""
print("\n=== 测试自定义模型 ===")
manager = get_ai_manager()
manager = get_ai_manager()
# 1. 创建自定义模型
print("1. 创建自定义模型...")
model = manager.create_custom_model(
tenant_id="tenant_001",
name="领域实体识别模型",
description="用于识别医疗领域实体的自定义模型",
model_type=ModelType.CUSTOM_NER,
training_data={
model = manager.create_custom_model(
tenant_id = "tenant_001",
name = "领域实体识别模型",
description = "用于识别医疗领域实体的自定义模型",
model_type = ModelType.CUSTOM_NER,
training_data = {
"entity_types": ["DISEASE", "SYMPTOM", "DRUG", "TREATMENT"],
"domain": "medical",
},
hyperparameters={"epochs": 15, "learning_rate": 0.001, "batch_size": 32},
created_by="user_001",
hyperparameters = {"epochs": 15, "learning_rate": 0.001, "batch_size": 32},
created_by = "user_001",
)
print(f" 创建成功: {model.id}, 状态: {model.status.value}")
# 2. 添加训练样本
print("2. 添加训练样本...")
samples = [
samples = [
{
"text": "患者张三患有高血压,正在服用降压药治疗。",
"entities": [
@@ -66,22 +66,22 @@ def test_custom_model():
]
for sample_data in samples:
sample = manager.add_training_sample(
model_id=model.id,
text=sample_data["text"],
entities=sample_data["entities"],
metadata={"source": "manual"},
sample = manager.add_training_sample(
model_id = model.id,
text = sample_data["text"],
entities = sample_data["entities"],
metadata = {"source": "manual"},
)
print(f" 添加样本: {sample.id}")
# 3. 获取训练样本
print("3. 获取训练样本...")
all_samples = manager.get_training_samples(model.id)
all_samples = manager.get_training_samples(model.id)
print(f" 共有 {len(all_samples)} 个训练样本")
# 4. 列出自定义模型
print("4. 列出自定义模型...")
models = manager.list_custom_models(tenant_id="tenant_001")
models = manager.list_custom_models(tenant_id = "tenant_001")
print(f" 找到 {len(models)} 个模型")
for m in models:
print(f" - {m.name} ({m.model_type.value}): {m.status.value}")
@@ -89,16 +89,16 @@ def test_custom_model():
return model.id
async def test_train_and_predict(model_id: str):
async def test_train_and_predict(model_id: str) -> None:
"""测试训练和预测"""
print("\n=== 测试模型训练和预测 ===")
manager = get_ai_manager()
manager = get_ai_manager()
# 1. 训练模型
print("1. 训练模型...")
try:
trained_model = await manager.train_custom_model(model_id)
trained_model = await manager.train_custom_model(model_id)
print(f" 训练完成: {trained_model.status.value}")
print(f" 指标: {trained_model.metrics}")
except Exception as e:
@@ -107,50 +107,50 @@ async def test_train_and_predict(model_id: str):
# 2. 使用模型预测
print("2. 使用模型预测...")
test_text = "赵六患有糖尿病,正在使用胰岛素治疗。"
test_text = "赵六患有糖尿病,正在使用胰岛素治疗。"
try:
entities = await manager.predict_with_custom_model(model_id, test_text)
entities = await manager.predict_with_custom_model(model_id, test_text)
print(f" 输入: {test_text}")
print(f" 预测实体: {entities}")
except Exception as e:
print(f" 预测失败: {e}")
def test_prediction_models():
def test_prediction_models() -> None:
"""测试预测模型"""
print("\n=== 测试预测模型 ===")
manager = get_ai_manager()
manager = get_ai_manager()
# 1. 创建趋势预测模型
print("1. 创建趋势预测模型...")
trend_model = manager.create_prediction_model(
tenant_id="tenant_001",
project_id="project_001",
name="实体数量趋势预测",
prediction_type=PredictionType.TREND,
target_entity_type="PERSON",
features=["entity_count", "time_period", "document_count"],
model_config={"algorithm": "linear_regression", "window_size": 7},
trend_model = manager.create_prediction_model(
tenant_id = "tenant_001",
project_id = "project_001",
name = "实体数量趋势预测",
prediction_type = PredictionType.TREND,
target_entity_type = "PERSON",
features = ["entity_count", "time_period", "document_count"],
model_config = {"algorithm": "linear_regression", "window_size": 7},
)
print(f" 创建成功: {trend_model.id}")
# 2. 创建异常检测模型
print("2. 创建异常检测模型...")
anomaly_model = manager.create_prediction_model(
tenant_id="tenant_001",
project_id="project_001",
name="实体增长异常检测",
prediction_type=PredictionType.ANOMALY,
target_entity_type=None,
features=["daily_growth", "weekly_growth"],
model_config={"threshold": 2.5, "sensitivity": "medium"},
anomaly_model = manager.create_prediction_model(
tenant_id = "tenant_001",
project_id = "project_001",
name = "实体增长异常检测",
prediction_type = PredictionType.ANOMALY,
target_entity_type = None,
features = ["daily_growth", "weekly_growth"],
model_config = {"threshold": 2.5, "sensitivity": "medium"},
)
print(f" 创建成功: {anomaly_model.id}")
# 3. 列出预测模型
print("3. 列出预测模型...")
models = manager.list_prediction_models(tenant_id="tenant_001")
models = manager.list_prediction_models(tenant_id = "tenant_001")
print(f" 找到 {len(models)} 个预测模型")
for m in models:
print(f" - {m.name} ({m.prediction_type.value})")
@@ -158,15 +158,15 @@ def test_prediction_models():
return trend_model.id, anomaly_model.id
async def test_predictions(trend_model_id: str, anomaly_model_id: str):
async def test_predictions(trend_model_id: str, anomaly_model_id: str) -> None:
"""测试预测功能"""
print("\n=== 测试预测功能 ===")
manager = get_ai_manager()
manager = get_ai_manager()
# 1. 训练趋势预测模型
print("1. 训练趋势预测模型...")
historical_data = [
historical_data = [
{"date": "2024-01-01", "value": 10},
{"date": "2024-01-02", "value": 12},
{"date": "2024-01-03", "value": 15},
@@ -175,62 +175,62 @@ async def test_predictions(trend_model_id: str, anomaly_model_id: str):
{"date": "2024-01-06", "value": 20},
{"date": "2024-01-07", "value": 22},
]
trained = await manager.train_prediction_model(trend_model_id, historical_data)
trained = await manager.train_prediction_model(trend_model_id, historical_data)
print(f" 训练完成,准确率: {trained.accuracy}")
# 2. 趋势预测
print("2. 趋势预测...")
trend_result = await manager.predict(
trend_result = await manager.predict(
trend_model_id, {"historical_values": [10, 12, 15, 14, 18, 20, 22]}
)
print(f" 预测结果: {trend_result.prediction_data}")
# 3. 异常检测
print("3. 异常检测...")
anomaly_result = await manager.predict(
anomaly_result = await manager.predict(
anomaly_model_id, {"value": 50, "historical_values": [10, 12, 11, 13, 12, 14, 13]}
)
print(f" 检测结果: {anomaly_result.prediction_data}")
def test_kg_rag():
def test_kg_rag() -> None:
"""测试知识图谱 RAG"""
print("\n=== 测试知识图谱 RAG ===")
manager = get_ai_manager()
manager = get_ai_manager()
# 创建 RAG 配置
print("1. 创建知识图谱 RAG 配置...")
rag = manager.create_kg_rag(
tenant_id="tenant_001",
project_id="project_001",
name="项目知识问答",
description="基于项目知识图谱的智能问答",
kg_config={
rag = manager.create_kg_rag(
tenant_id = "tenant_001",
project_id = "project_001",
name = "项目知识问答",
description = "基于项目知识图谱的智能问答",
kg_config = {
"entity_types": ["PERSON", "ORG", "PROJECT", "TECH"],
"relation_types": ["works_with", "belongs_to", "depends_on"],
},
retrieval_config={"top_k": 5, "similarity_threshold": 0.7, "expand_relations": True},
generation_config={"temperature": 0.3, "max_tokens": 1000, "include_sources": True},
retrieval_config = {"top_k": 5, "similarity_threshold": 0.7, "expand_relations": True},
generation_config = {"temperature": 0.3, "max_tokens": 1000, "include_sources": True},
)
print(f" 创建成功: {rag.id}")
# 列出 RAG 配置
print("2. 列出 RAG 配置...")
rags = manager.list_kg_rags(tenant_id="tenant_001")
rags = manager.list_kg_rags(tenant_id = "tenant_001")
print(f" 找到 {len(rags)} 个配置")
return rag.id
async def test_kg_rag_query(rag_id: str):
async def test_kg_rag_query(rag_id: str) -> None:
"""测试 RAG 查询"""
print("\n=== 测试知识图谱 RAG 查询 ===")
manager = get_ai_manager()
manager = get_ai_manager()
# 模拟项目实体和关系
project_entities = [
project_entities = [
{"id": "e1", "name": "张三", "type": "PERSON", "definition": "项目经理"},
{"id": "e2", "name": "李四", "type": "PERSON", "definition": "技术负责人"},
{"id": "e3", "name": "Project Alpha", "type": "PROJECT", "definition": "核心产品项目"},
@@ -238,7 +238,7 @@ async def test_kg_rag_query(rag_id: str):
{"id": "e5", "name": "TechCorp", "type": "ORG", "definition": "科技公司"},
]
project_relations = [
project_relations = [
{
"source_entity_id": "e1",
"target_entity_id": "e3",
@@ -275,14 +275,14 @@ async def test_kg_rag_query(rag_id: str):
# 执行查询
print("1. 执行 RAG 查询...")
query_text = "Project Alpha 项目有哪些人参与?使用了什么技术?"
query_text = "Project Alpha 项目有哪些人参与?使用了什么技术?"
try:
result = await manager.query_kg_rag(
rag_id=rag_id,
query=query_text,
project_entities=project_entities,
project_relations=project_relations,
result = await manager.query_kg_rag(
rag_id = rag_id,
query = query_text,
project_entities = project_entities,
project_relations = project_relations,
)
print(f" 查询: {result.query}")
@@ -294,14 +294,14 @@ async def test_kg_rag_query(rag_id: str):
print(f" 查询失败: {e}")
async def test_smart_summary():
async def test_smart_summary() -> None:
"""测试智能摘要"""
print("\n=== 测试智能摘要 ===")
manager = get_ai_manager()
manager = get_ai_manager()
# 模拟转录文本
transcript_text = """
transcript_text = """
今天的会议主要讨论了 Project Alpha 的进展情况。张三作为项目经理,
汇报了当前的项目进度,表示已经完成了 80% 的开发工作。李四提出了
一些关于 Kubernetes 部署的问题,建议我们采用新的部署策略。
@@ -309,7 +309,7 @@ async def test_smart_summary():
大家一致认为项目进展顺利,预计可以按时交付。
"""
content_data = {
content_data = {
"text": transcript_text,
"entities": [
{"name": "张三", "type": "PERSON"},
@@ -320,18 +320,18 @@ async def test_smart_summary():
}
# 生成不同类型的摘要
summary_types = ["extractive", "abstractive", "key_points"]
summary_types = ["extractive", "abstractive", "key_points"]
for summary_type in summary_types:
print(f"1. 生成 {summary_type} 类型摘要...")
try:
summary = await manager.generate_smart_summary(
tenant_id="tenant_001",
project_id="project_001",
source_type="transcript",
source_id="transcript_001",
summary_type=summary_type,
content_data=content_data,
summary = await manager.generate_smart_summary(
tenant_id = "tenant_001",
project_id = "project_001",
source_type = "transcript",
source_id = "transcript_001",
summary_type = summary_type,
content_data = content_data,
)
print(f" 摘要类型: {summary.summary_type}")
@@ -342,27 +342,27 @@ async def test_smart_summary():
print(f" 生成失败: {e}")
async def main():
async def main() -> None:
"""主测试函数"""
print("=" * 60)
print(" = " * 60)
print("InsightFlow Phase 8 Task 4 - AI 能力增强测试")
print("=" * 60)
print(" = " * 60)
try:
# 测试自定义模型
model_id = test_custom_model()
model_id = test_custom_model()
# 测试训练和预测
await test_train_and_predict(model_id)
# 测试预测模型
trend_model_id, anomaly_model_id = test_prediction_models()
trend_model_id, anomaly_model_id = test_prediction_models()
# 测试预测功能
await test_predictions(trend_model_id, anomaly_model_id)
# 测试知识图谱 RAG
rag_id = test_kg_rag()
rag_id = test_kg_rag()
# 测试 RAG 查询
await test_kg_rag_query(rag_id)
@@ -370,9 +370,9 @@ async def main():
# 测试智能摘要
await test_smart_summary()
print("\n" + "=" * 60)
print("\n" + " = " * 60)
print("所有测试完成!")
print("=" * 60)
print(" = " * 60)
except Exception as e:
print(f"\n测试失败: {e}")

View File

@@ -28,7 +28,7 @@ from growth_manager import (
)
# 添加 backend 目录到路径
backend_dir = os.path.dirname(os.path.abspath(__file__))
backend_dir = os.path.dirname(os.path.abspath(__file__))
if backend_dir not in sys.path:
sys.path.insert(0, backend_dir)
@@ -36,35 +36,35 @@ if backend_dir not in sys.path:
class TestGrowthManager:
"""测试 Growth Manager 功能"""
def __init__(self):
self.manager = GrowthManager()
self.test_tenant_id = "test_tenant_001"
self.test_user_id = "test_user_001"
self.test_results = []
def __init__(self) -> None:
self.manager = GrowthManager()
self.test_tenant_id = "test_tenant_001"
self.test_user_id = "test_user_001"
self.test_results = []
def log(self, message: str, success: bool = True):
def log(self, message: str, success: bool = True) -> None:
"""记录测试结果"""
status = "" if success else ""
status = "" if success else ""
print(f"{status} {message}")
self.test_results.append((message, success))
# ==================== 测试用户行为分析 ====================
async def test_track_event(self):
async def test_track_event(self) -> None:
"""测试事件追踪"""
print("\n📊 测试事件追踪...")
try:
event = await self.manager.track_event(
tenant_id=self.test_tenant_id,
user_id=self.test_user_id,
event_type=EventType.PAGE_VIEW,
event_name="dashboard_view",
properties={"page": "/dashboard", "duration": 120},
session_id="session_001",
device_info={"browser": "Chrome", "os": "MacOS"},
referrer="https://google.com",
utm_params={"source": "google", "medium": "organic", "campaign": "summer"},
event = await self.manager.track_event(
tenant_id = self.test_tenant_id,
user_id = self.test_user_id,
event_type = EventType.PAGE_VIEW,
event_name = "dashboard_view",
properties = {"page": "/dashboard", "duration": 120},
session_id = "session_001",
device_info = {"browser": "Chrome", "os": "MacOS"},
referrer = "https://google.com",
utm_params = {"source": "google", "medium": "organic", "campaign": "summer"},
)
assert event.id is not None
@@ -74,15 +74,15 @@ class TestGrowthManager:
self.log(f"事件追踪成功: {event.id}")
return True
except Exception as e:
self.log(f"事件追踪失败: {e}", success=False)
self.log(f"事件追踪失败: {e}", success = False)
return False
async def test_track_multiple_events(self):
async def test_track_multiple_events(self) -> None:
"""测试追踪多个事件"""
print("\n📊 测试追踪多个事件...")
try:
events = [
events = [
(EventType.FEATURE_USE, "entity_extraction", {"entity_count": 5}),
(EventType.FEATURE_USE, "relation_discovery", {"relation_count": 3}),
(EventType.CONVERSION, "upgrade_click", {"plan": "pro"}),
@@ -91,25 +91,25 @@ class TestGrowthManager:
for event_type, event_name, props in events:
await self.manager.track_event(
tenant_id=self.test_tenant_id,
user_id=self.test_user_id,
event_type=event_type,
event_name=event_name,
properties=props,
tenant_id = self.test_tenant_id,
user_id = self.test_user_id,
event_type = event_type,
event_name = event_name,
properties = props,
)
self.log(f"成功追踪 {len(events)} 个事件")
return True
except Exception as e:
self.log(f"批量事件追踪失败: {e}", success=False)
self.log(f"批量事件追踪失败: {e}", success = False)
return False
def test_get_user_profile(self):
def test_get_user_profile(self) -> None:
"""测试获取用户画像"""
print("\n👤 测试用户画像...")
try:
profile = self.manager.get_user_profile(self.test_tenant_id, self.test_user_id)
profile = self.manager.get_user_profile(self.test_tenant_id, self.test_user_id)
if profile:
assert profile.user_id == self.test_user_id
@@ -120,18 +120,18 @@ class TestGrowthManager:
return True
except Exception as e:
self.log(f"获取用户画像失败: {e}", success=False)
self.log(f"获取用户画像失败: {e}", success = False)
return False
def test_get_analytics_summary(self):
def test_get_analytics_summary(self) -> None:
"""测试获取分析汇总"""
print("\n📈 测试分析汇总...")
try:
summary = self.manager.get_user_analytics_summary(
tenant_id=self.test_tenant_id,
start_date=datetime.now() - timedelta(days=7),
end_date=datetime.now(),
summary = self.manager.get_user_analytics_summary(
tenant_id = self.test_tenant_id,
start_date = datetime.now() - timedelta(days = 7),
end_date = datetime.now(),
)
assert "unique_users" in summary
@@ -141,25 +141,25 @@ class TestGrowthManager:
self.log(f"分析汇总: {summary['unique_users']} 用户, {summary['total_events']} 事件")
return True
except Exception as e:
self.log(f"获取分析汇总失败: {e}", success=False)
self.log(f"获取分析汇总失败: {e}", success = False)
return False
def test_create_funnel(self):
def test_create_funnel(self) -> None:
"""测试创建转化漏斗"""
print("\n🎯 测试创建转化漏斗...")
try:
funnel = self.manager.create_funnel(
tenant_id=self.test_tenant_id,
name="用户注册转化漏斗",
description="从访问到完成注册的转化流程",
steps=[
funnel = self.manager.create_funnel(
tenant_id = self.test_tenant_id,
name = "用户注册转化漏斗",
description = "从访问到完成注册的转化流程",
steps = [
{"name": "访问首页", "event_name": "page_view_home"},
{"name": "点击注册", "event_name": "signup_click"},
{"name": "填写信息", "event_name": "signup_form_fill"},
{"name": "完成注册", "event_name": "signup_complete"},
],
created_by="test",
created_by = "test",
)
assert funnel.id is not None
@@ -168,10 +168,10 @@ class TestGrowthManager:
self.log(f"漏斗创建成功: {funnel.id}")
return funnel.id
except Exception as e:
self.log(f"创建漏斗失败: {e}", success=False)
self.log(f"创建漏斗失败: {e}", success = False)
return None
def test_analyze_funnel(self, funnel_id: str):
def test_analyze_funnel(self, funnel_id: str) -> None:
"""测试分析漏斗"""
print("\n📉 测试漏斗分析...")
@@ -180,10 +180,10 @@ class TestGrowthManager:
return False
try:
analysis = self.manager.analyze_funnel(
funnel_id=funnel_id,
period_start=datetime.now() - timedelta(days=30),
period_end=datetime.now(),
analysis = self.manager.analyze_funnel(
funnel_id = funnel_id,
period_start = datetime.now() - timedelta(days = 30),
period_end = datetime.now(),
)
if analysis:
@@ -194,18 +194,18 @@ class TestGrowthManager:
self.log("漏斗分析返回空结果")
return False
except Exception as e:
self.log(f"漏斗分析失败: {e}", success=False)
self.log(f"漏斗分析失败: {e}", success = False)
return False
def test_calculate_retention(self):
def test_calculate_retention(self) -> None:
"""测试留存率计算"""
print("\n🔄 测试留存率计算...")
try:
retention = self.manager.calculate_retention(
tenant_id=self.test_tenant_id,
cohort_date=datetime.now() - timedelta(days=7),
periods=[1, 3, 7],
retention = self.manager.calculate_retention(
tenant_id = self.test_tenant_id,
cohort_date = datetime.now() - timedelta(days = 7),
periods = [1, 3, 7],
)
assert "cohort_date" in retention
@@ -214,34 +214,34 @@ class TestGrowthManager:
self.log(f"留存率计算完成: 同期群 {retention['cohort_size']} 用户")
return True
except Exception as e:
self.log(f"留存率计算失败: {e}", success=False)
self.log(f"留存率计算失败: {e}", success = False)
return False
# ==================== 测试 A/B 测试框架 ====================
def test_create_experiment(self):
def test_create_experiment(self) -> None:
"""测试创建实验"""
print("\n🧪 测试创建 A/B 测试实验...")
try:
experiment = self.manager.create_experiment(
tenant_id=self.test_tenant_id,
name="首页按钮颜色测试",
description="测试不同按钮颜色对转化率的影响",
hypothesis="蓝色按钮比红色按钮有更高的点击率",
variants=[
experiment = self.manager.create_experiment(
tenant_id = self.test_tenant_id,
name = "首页按钮颜色测试",
description = "测试不同按钮颜色对转化率的影响",
hypothesis = "蓝色按钮比红色按钮有更高的点击率",
variants = [
{"id": "control", "name": "红色按钮", "is_control": True},
{"id": "variant_a", "name": "蓝色按钮", "is_control": False},
{"id": "variant_b", "name": "绿色按钮", "is_control": False},
],
traffic_allocation=TrafficAllocationType.RANDOM,
traffic_split={"control": 0.34, "variant_a": 0.33, "variant_b": 0.33},
target_audience={"conditions": []},
primary_metric="button_click_rate",
secondary_metrics=["conversion_rate", "bounce_rate"],
min_sample_size=100,
confidence_level=0.95,
created_by="test",
traffic_allocation = TrafficAllocationType.RANDOM,
traffic_split = {"control": 0.34, "variant_a": 0.33, "variant_b": 0.33},
target_audience = {"conditions": []},
primary_metric = "button_click_rate",
secondary_metrics = ["conversion_rate", "bounce_rate"],
min_sample_size = 100,
confidence_level = 0.95,
created_by = "test",
)
assert experiment.id is not None
@@ -250,23 +250,23 @@ class TestGrowthManager:
self.log(f"实验创建成功: {experiment.id}")
return experiment.id
except Exception as e:
self.log(f"创建实验失败: {e}", success=False)
self.log(f"创建实验失败: {e}", success = False)
return None
def test_list_experiments(self):
def test_list_experiments(self) -> None:
"""测试列出实验"""
print("\n📋 测试列出实验...")
try:
experiments = self.manager.list_experiments(self.test_tenant_id)
experiments = self.manager.list_experiments(self.test_tenant_id)
self.log(f"列出 {len(experiments)} 个实验")
return True
except Exception as e:
self.log(f"列出实验失败: {e}", success=False)
self.log(f"列出实验失败: {e}", success = False)
return False
def test_assign_variant(self, experiment_id: str):
def test_assign_variant(self, experiment_id: str) -> None:
"""测试分配变体"""
print("\n🎲 测试分配实验变体...")
@@ -279,26 +279,26 @@ class TestGrowthManager:
self.manager.start_experiment(experiment_id)
# 测试多个用户的变体分配
test_users = ["user_001", "user_002", "user_003", "user_004", "user_005"]
assignments = {}
test_users = ["user_001", "user_002", "user_003", "user_004", "user_005"]
assignments = {}
for user_id in test_users:
variant_id = self.manager.assign_variant(
experiment_id=experiment_id,
user_id=user_id,
user_attributes={"user_id": user_id, "segment": "new"},
variant_id = self.manager.assign_variant(
experiment_id = experiment_id,
user_id = user_id,
user_attributes = {"user_id": user_id, "segment": "new"},
)
if variant_id:
assignments[user_id] = variant_id
assignments[user_id] = variant_id
self.log(f"变体分配完成: {len(assignments)} 个用户")
return True
except Exception as e:
self.log(f"变体分配失败: {e}", success=False)
self.log(f"变体分配失败: {e}", success = False)
return False
def test_record_experiment_metric(self, experiment_id: str):
def test_record_experiment_metric(self, experiment_id: str) -> None:
"""测试记录实验指标"""
print("\n📊 测试记录实验指标...")
@@ -308,7 +308,7 @@ class TestGrowthManager:
try:
# 模拟记录一些指标
test_data = [
test_data = [
("user_001", "control", 1),
("user_002", "variant_a", 1),
("user_003", "variant_b", 0),
@@ -318,20 +318,20 @@ class TestGrowthManager:
for user_id, variant_id, value in test_data:
self.manager.record_experiment_metric(
experiment_id=experiment_id,
variant_id=variant_id,
user_id=user_id,
metric_name="button_click_rate",
metric_value=value,
experiment_id = experiment_id,
variant_id = variant_id,
user_id = user_id,
metric_name = "button_click_rate",
metric_value = value,
)
self.log(f"成功记录 {len(test_data)} 条指标")
return True
except Exception as e:
self.log(f"记录指标失败: {e}", success=False)
self.log(f"记录指标失败: {e}", success = False)
return False
def test_analyze_experiment(self, experiment_id: str):
def test_analyze_experiment(self, experiment_id: str) -> None:
"""测试分析实验结果"""
print("\n📈 测试分析实验结果...")
@@ -340,31 +340,31 @@ class TestGrowthManager:
return False
try:
result = self.manager.analyze_experiment(experiment_id)
result = self.manager.analyze_experiment(experiment_id)
if "error" not in result:
self.log(f"实验分析完成: {len(result.get('variant_results', {}))} 个变体")
return True
else:
self.log(f"实验分析返回错误: {result['error']}", success=False)
self.log(f"实验分析返回错误: {result['error']}", success = False)
return False
except Exception as e:
self.log(f"实验分析失败: {e}", success=False)
self.log(f"实验分析失败: {e}", success = False)
return False
# ==================== 测试邮件营销 ====================
def test_create_email_template(self):
def test_create_email_template(self) -> None:
"""测试创建邮件模板"""
print("\n📧 测试创建邮件模板...")
try:
template = self.manager.create_email_template(
tenant_id=self.test_tenant_id,
name="欢迎邮件",
template_type=EmailTemplateType.WELCOME,
subject="欢迎加入 InsightFlow",
html_content="""
template = self.manager.create_email_template(
tenant_id = self.test_tenant_id,
name = "欢迎邮件",
template_type = EmailTemplateType.WELCOME,
subject = "欢迎加入 InsightFlow",
html_content = """
<h1>欢迎,{{user_name}}</h1>
<p>感谢您注册 InsightFlow。我们很高兴您能加入我们</p>
<p>您的账户已创建,可以开始使用以下功能:</p>
@@ -373,10 +373,10 @@ class TestGrowthManager:
<li>智能实体提取</li>
<li>团队协作</li>
</ul>
<p><a href="{{dashboard_url}}">立即开始使用</a></p>
<p><a href = "{{dashboard_url}}">立即开始使用</a></p>
""",
from_name="InsightFlow 团队",
from_email="welcome@insightflow.io",
from_name = "InsightFlow 团队",
from_email = "welcome@insightflow.io",
)
assert template.id is not None
@@ -385,23 +385,23 @@ class TestGrowthManager:
self.log(f"邮件模板创建成功: {template.id}")
return template.id
except Exception as e:
self.log(f"创建邮件模板失败: {e}", success=False)
self.log(f"创建邮件模板失败: {e}", success = False)
return None
def test_list_email_templates(self):
def test_list_email_templates(self) -> None:
"""测试列出邮件模板"""
print("\n📧 测试列出邮件模板...")
try:
templates = self.manager.list_email_templates(self.test_tenant_id)
templates = self.manager.list_email_templates(self.test_tenant_id)
self.log(f"列出 {len(templates)} 个邮件模板")
return True
except Exception as e:
self.log(f"列出邮件模板失败: {e}", success=False)
self.log(f"列出邮件模板失败: {e}", success = False)
return False
def test_render_template(self, template_id: str):
def test_render_template(self, template_id: str) -> None:
"""测试渲染邮件模板"""
print("\n🎨 测试渲染邮件模板...")
@@ -410,9 +410,9 @@ class TestGrowthManager:
return False
try:
rendered = self.manager.render_template(
template_id=template_id,
variables={
rendered = self.manager.render_template(
template_id = template_id,
variables = {
"user_name": "张三",
"dashboard_url": "https://app.insightflow.io/dashboard",
},
@@ -424,13 +424,13 @@ class TestGrowthManager:
self.log(f"模板渲染成功: {rendered['subject']}")
return True
else:
self.log("模板渲染返回空结果", success=False)
self.log("模板渲染返回空结果", success = False)
return False
except Exception as e:
self.log(f"模板渲染失败: {e}", success=False)
self.log(f"模板渲染失败: {e}", success = False)
return False
def test_create_email_campaign(self, template_id: str):
def test_create_email_campaign(self, template_id: str) -> None:
"""测试创建邮件营销活动"""
print("\n📮 测试创建邮件营销活动...")
@@ -439,11 +439,11 @@ class TestGrowthManager:
return None
try:
campaign = self.manager.create_email_campaign(
tenant_id=self.test_tenant_id,
name="新用户欢迎活动",
template_id=template_id,
recipient_list=[
campaign = self.manager.create_email_campaign(
tenant_id = self.test_tenant_id,
name = "新用户欢迎活动",
template_id = template_id,
recipient_list = [
{"user_id": "user_001", "email": "user1@example.com"},
{"user_id": "user_002", "email": "user2@example.com"},
{"user_id": "user_003", "email": "user3@example.com"},
@@ -456,21 +456,21 @@ class TestGrowthManager:
self.log(f"营销活动创建成功: {campaign.id}, {campaign.recipient_count} 收件人")
return campaign.id
except Exception as e:
self.log(f"创建营销活动失败: {e}", success=False)
self.log(f"创建营销活动失败: {e}", success = False)
return None
def test_create_automation_workflow(self):
def test_create_automation_workflow(self) -> None:
"""测试创建自动化工作流"""
print("\n🤖 测试创建自动化工作流...")
try:
workflow = self.manager.create_automation_workflow(
tenant_id=self.test_tenant_id,
name="新用户欢迎序列",
description="用户注册后自动发送欢迎邮件序列",
trigger_type=WorkflowTriggerType.USER_SIGNUP,
trigger_conditions={"event": "user_signup"},
actions=[
workflow = self.manager.create_automation_workflow(
tenant_id = self.test_tenant_id,
name = "新用户欢迎序列",
description = "用户注册后自动发送欢迎邮件序列",
trigger_type = WorkflowTriggerType.USER_SIGNUP,
trigger_conditions = {"event": "user_signup"},
actions = [
{"type": "send_email", "template_type": "welcome", "delay_hours": 0},
{"type": "send_email", "template_type": "onboarding", "delay_hours": 24},
{"type": "send_email", "template_type": "feature_tips", "delay_hours": 72},
@@ -483,27 +483,27 @@ class TestGrowthManager:
self.log(f"自动化工作流创建成功: {workflow.id}")
return True
except Exception as e:
self.log(f"创建工作流失败: {e}", success=False)
self.log(f"创建工作流失败: {e}", success = False)
return False
# ==================== 测试推荐系统 ====================
def test_create_referral_program(self):
def test_create_referral_program(self) -> None:
"""测试创建推荐计划"""
print("\n🎁 测试创建推荐计划...")
try:
program = self.manager.create_referral_program(
tenant_id=self.test_tenant_id,
name="邀请好友奖励计划",
description="邀请好友注册,双方获得积分奖励",
referrer_reward_type="credit",
referrer_reward_value=100.0,
referee_reward_type="credit",
referee_reward_value=50.0,
max_referrals_per_user=10,
referral_code_length=8,
expiry_days=30,
program = self.manager.create_referral_program(
tenant_id = self.test_tenant_id,
name = "邀请好友奖励计划",
description = "邀请好友注册,双方获得积分奖励",
referrer_reward_type = "credit",
referrer_reward_value = 100.0,
referee_reward_type = "credit",
referee_reward_value = 50.0,
max_referrals_per_user = 10,
referral_code_length = 8,
expiry_days = 30,
)
assert program.id is not None
@@ -512,10 +512,10 @@ class TestGrowthManager:
self.log(f"推荐计划创建成功: {program.id}")
return program.id
except Exception as e:
self.log(f"创建推荐计划失败: {e}", success=False)
self.log(f"创建推荐计划失败: {e}", success = False)
return None
def test_generate_referral_code(self, program_id: str):
def test_generate_referral_code(self, program_id: str) -> None:
"""测试生成推荐码"""
print("\n🔑 测试生成推荐码...")
@@ -524,8 +524,8 @@ class TestGrowthManager:
return None
try:
referral = self.manager.generate_referral_code(
program_id=program_id, referrer_id="referrer_user_001"
referral = self.manager.generate_referral_code(
program_id = program_id, referrer_id = "referrer_user_001"
)
if referral:
@@ -535,13 +535,13 @@ class TestGrowthManager:
self.log(f"推荐码生成成功: {referral.referral_code}")
return referral.referral_code
else:
self.log("生成推荐码返回空结果", success=False)
self.log("生成推荐码返回空结果", success = False)
return None
except Exception as e:
self.log(f"生成推荐码失败: {e}", success=False)
self.log(f"生成推荐码失败: {e}", success = False)
return None
def test_apply_referral_code(self, referral_code: str):
def test_apply_referral_code(self, referral_code: str) -> None:
"""测试应用推荐码"""
print("\n✅ 测试应用推荐码...")
@@ -550,21 +550,21 @@ class TestGrowthManager:
return False
try:
success = self.manager.apply_referral_code(
referral_code=referral_code, referee_id="new_user_001"
success = self.manager.apply_referral_code(
referral_code = referral_code, referee_id = "new_user_001"
)
if success:
self.log(f"推荐码应用成功: {referral_code}")
return True
else:
self.log("推荐码应用失败", success=False)
self.log("推荐码应用失败", success = False)
return False
except Exception as e:
self.log(f"应用推荐码失败: {e}", success=False)
self.log(f"应用推荐码失败: {e}", success = False)
return False
def test_get_referral_stats(self, program_id: str):
def test_get_referral_stats(self, program_id: str) -> None:
"""测试获取推荐统计"""
print("\n📊 测试获取推荐统计...")
@@ -573,7 +573,7 @@ class TestGrowthManager:
return False
try:
stats = self.manager.get_referral_stats(program_id)
stats = self.manager.get_referral_stats(program_id)
assert "total_referrals" in stats
assert "conversion_rate" in stats
@@ -583,24 +583,24 @@ class TestGrowthManager:
)
return True
except Exception as e:
self.log(f"获取推荐统计失败: {e}", success=False)
self.log(f"获取推荐统计失败: {e}", success = False)
return False
def test_create_team_incentive(self):
def test_create_team_incentive(self) -> None:
"""测试创建团队激励"""
print("\n🏆 测试创建团队升级激励...")
try:
incentive = self.manager.create_team_incentive(
tenant_id=self.test_tenant_id,
name="团队升级奖励",
description="团队规模达到5人升级到 Pro 计划可获得折扣",
target_tier="pro",
min_team_size=5,
incentive_type="discount",
incentive_value=20.0, # 20% 折扣
valid_from=datetime.now(),
valid_until=datetime.now() + timedelta(days=90),
incentive = self.manager.create_team_incentive(
tenant_id = self.test_tenant_id,
name = "团队升级奖励",
description = "团队规模达到5人升级到 Pro 计划可获得折扣",
target_tier = "pro",
min_team_size = 5,
incentive_type = "discount",
incentive_value = 20.0, # 20% 折扣
valid_from = datetime.now(),
valid_until = datetime.now() + timedelta(days = 90),
)
assert incentive.id is not None
@@ -609,116 +609,116 @@ class TestGrowthManager:
self.log(f"团队激励创建成功: {incentive.id}")
return True
except Exception as e:
self.log(f"创建团队激励失败: {e}", success=False)
self.log(f"创建团队激励失败: {e}", success = False)
return False
def test_check_team_incentive_eligibility(self):
def test_check_team_incentive_eligibility(self) -> None:
"""测试检查团队激励资格"""
print("\n🔍 测试检查团队激励资格...")
try:
incentives = self.manager.check_team_incentive_eligibility(
tenant_id=self.test_tenant_id, current_tier="free", team_size=5
incentives = self.manager.check_team_incentive_eligibility(
tenant_id = self.test_tenant_id, current_tier = "free", team_size = 5
)
self.log(f"找到 {len(incentives)} 个符合条件的激励")
return True
except Exception as e:
self.log(f"检查激励资格失败: {e}", success=False)
self.log(f"检查激励资格失败: {e}", success = False)
return False
# ==================== 测试实时仪表板 ====================
def test_get_realtime_dashboard(self):
def test_get_realtime_dashboard(self) -> None:
"""测试获取实时仪表板"""
print("\n📺 测试实时分析仪表板...")
try:
dashboard = self.manager.get_realtime_dashboard(self.test_tenant_id)
dashboard = self.manager.get_realtime_dashboard(self.test_tenant_id)
assert "today" in dashboard
assert "recent_events" in dashboard
assert "top_features" in dashboard
today = dashboard["today"]
today = dashboard["today"]
self.log(
f"实时仪表板: 今日 {today['active_users']} 活跃用户, {today['total_events']} 事件"
)
return True
except Exception as e:
self.log(f"获取实时仪表板失败: {e}", success=False)
self.log(f"获取实时仪表板失败: {e}", success = False)
return False
# ==================== 运行所有测试 ====================
async def run_all_tests(self):
async def run_all_tests(self) -> None:
"""运行所有测试"""
print("=" * 60)
print(" = " * 60)
print("🚀 InsightFlow Phase 8 Task 5 - 运营与增长工具测试")
print("=" * 60)
print(" = " * 60)
# 用户行为分析测试
print("\n" + "=" * 60)
print("\n" + " = " * 60)
print("📊 模块 1: 用户行为分析")
print("=" * 60)
print(" = " * 60)
await self.test_track_event()
await self.test_track_multiple_events()
self.test_get_user_profile()
self.test_get_analytics_summary()
funnel_id = self.test_create_funnel()
funnel_id = self.test_create_funnel()
self.test_analyze_funnel(funnel_id)
self.test_calculate_retention()
# A/B 测试框架测试
print("\n" + "=" * 60)
print("\n" + " = " * 60)
print("🧪 模块 2: A/B 测试框架")
print("=" * 60)
print(" = " * 60)
experiment_id = self.test_create_experiment()
experiment_id = self.test_create_experiment()
self.test_list_experiments()
self.test_assign_variant(experiment_id)
self.test_record_experiment_metric(experiment_id)
self.test_analyze_experiment(experiment_id)
# 邮件营销测试
print("\n" + "=" * 60)
print("\n" + " = " * 60)
print("📧 模块 3: 邮件营销自动化")
print("=" * 60)
print(" = " * 60)
template_id = self.test_create_email_template()
template_id = self.test_create_email_template()
self.test_list_email_templates()
self.test_render_template(template_id)
self.test_create_email_campaign(template_id)
self.test_create_automation_workflow()
# 推荐系统测试
print("\n" + "=" * 60)
print("\n" + " = " * 60)
print("🎁 模块 4: 推荐系统")
print("=" * 60)
print(" = " * 60)
program_id = self.test_create_referral_program()
referral_code = self.test_generate_referral_code(program_id)
program_id = self.test_create_referral_program()
referral_code = self.test_generate_referral_code(program_id)
self.test_apply_referral_code(referral_code)
self.test_get_referral_stats(program_id)
self.test_create_team_incentive()
self.test_check_team_incentive_eligibility()
# 实时仪表板测试
print("\n" + "=" * 60)
print("\n" + " = " * 60)
print("📺 模块 5: 实时分析仪表板")
print("=" * 60)
print(" = " * 60)
self.test_get_realtime_dashboard()
# 测试总结
print("\n" + "=" * 60)
print("\n" + " = " * 60)
print("📋 测试总结")
print("=" * 60)
print(" = " * 60)
total_tests = len(self.test_results)
passed_tests = sum(1 for _, success in self.test_results if success)
failed_tests = total_tests - passed_tests
total_tests = len(self.test_results)
passed_tests = sum(1 for _, success in self.test_results if success)
failed_tests = total_tests - passed_tests
print(f"总测试数: {total_tests}")
print(f"通过: {passed_tests}")
@@ -731,14 +731,14 @@ class TestGrowthManager:
if not success:
print(f" - {message}")
print("\n" + "=" * 60)
print("\n" + " = " * 60)
print("✨ 测试完成!")
print("=" * 60)
print(" = " * 60)
async def main():
async def main() -> None:
"""主函数"""
tester = TestGrowthManager()
tester = TestGrowthManager()
await tester.run_all_tests()

View File

@@ -25,7 +25,7 @@ from developer_ecosystem_manager import (
)
# Add backend directory to path
backend_dir = os.path.dirname(os.path.abspath(__file__))
backend_dir = os.path.dirname(os.path.abspath(__file__))
if backend_dir not in sys.path:
sys.path.insert(0, backend_dir)
@@ -33,10 +33,10 @@ if backend_dir not in sys.path:
class TestDeveloperEcosystem:
"""开发者生态系统测试类"""
def __init__(self):
self.manager = DeveloperEcosystemManager()
self.test_results = []
self.created_ids = {
def __init__(self) -> None:
self.manager = DeveloperEcosystemManager()
self.test_results = []
self.created_ids = {
"sdk": [],
"template": [],
"plugin": [],
@@ -45,19 +45,19 @@ class TestDeveloperEcosystem:
"portal_config": [],
}
def log(self, message: str, success: bool = True):
def log(self, message: str, success: bool = True) -> None:
"""记录测试结果"""
status = "" if success else ""
status = "" if success else ""
print(f"{status} {message}")
self.test_results.append(
{"message": message, "success": success, "timestamp": datetime.now().isoformat()}
)
def run_all_tests(self):
def run_all_tests(self) -> None:
"""运行所有测试"""
print("=" * 60)
print(" = " * 60)
print("InsightFlow Phase 8 Task 6: Developer Ecosystem Tests")
print("=" * 60)
print(" = " * 60)
# SDK Tests
print("\n📦 SDK Release & Management Tests")
@@ -119,538 +119,538 @@ class TestDeveloperEcosystem:
# Print Summary
self.print_summary()
def test_sdk_create(self):
def test_sdk_create(self) -> None:
"""测试创建 SDK"""
try:
sdk = self.manager.create_sdk_release(
name="InsightFlow Python SDK",
language=SDKLanguage.PYTHON,
version="1.0.0",
description="Python SDK for InsightFlow API",
changelog="Initial release",
download_url="https://pypi.org/insightflow/1.0.0",
documentation_url="https://docs.insightflow.io/python",
repository_url="https://github.com/insightflow/python-sdk",
package_name="insightflow",
min_platform_version="1.0.0",
dependencies=[{"name": "requests", "version": ">=2.0"}],
file_size=1024000,
checksum="abc123",
created_by="test_user",
sdk = self.manager.create_sdk_release(
name = "InsightFlow Python SDK",
language = SDKLanguage.PYTHON,
version = "1.0.0",
description = "Python SDK for InsightFlow API",
changelog = "Initial release",
download_url = "https://pypi.org/insightflow/1.0.0",
documentation_url = "https://docs.insightflow.io/python",
repository_url = "https://github.com/insightflow/python-sdk",
package_name = "insightflow",
min_platform_version = "1.0.0",
dependencies = [{"name": "requests", "version": ">= 2.0"}],
file_size = 1024000,
checksum = "abc123",
created_by = "test_user",
)
self.created_ids["sdk"].append(sdk.id)
self.log(f"Created SDK: {sdk.name} ({sdk.id})")
# Create JavaScript SDK
sdk_js = self.manager.create_sdk_release(
name="InsightFlow JavaScript SDK",
language=SDKLanguage.JAVASCRIPT,
version="1.0.0",
description="JavaScript SDK for InsightFlow API",
changelog="Initial release",
download_url="https://npmjs.com/insightflow/1.0.0",
documentation_url="https://docs.insightflow.io/js",
repository_url="https://github.com/insightflow/js-sdk",
package_name="@insightflow/sdk",
min_platform_version="1.0.0",
dependencies=[{"name": "axios", "version": ">=0.21"}],
file_size=512000,
checksum="def456",
created_by="test_user",
sdk_js = self.manager.create_sdk_release(
name = "InsightFlow JavaScript SDK",
language = SDKLanguage.JAVASCRIPT,
version = "1.0.0",
description = "JavaScript SDK for InsightFlow API",
changelog = "Initial release",
download_url = "https://npmjs.com/insightflow/1.0.0",
documentation_url = "https://docs.insightflow.io/js",
repository_url = "https://github.com/insightflow/js-sdk",
package_name = "@insightflow/sdk",
min_platform_version = "1.0.0",
dependencies = [{"name": "axios", "version": ">= 0.21"}],
file_size = 512000,
checksum = "def456",
created_by = "test_user",
)
self.created_ids["sdk"].append(sdk_js.id)
self.log(f"Created SDK: {sdk_js.name} ({sdk_js.id})")
except Exception as e:
self.log(f"Failed to create SDK: {str(e)}", success=False)
self.log(f"Failed to create SDK: {str(e)}", success = False)
def test_sdk_list(self):
def test_sdk_list(self) -> None:
"""测试列出 SDK"""
try:
sdks = self.manager.list_sdk_releases()
sdks = self.manager.list_sdk_releases()
self.log(f"Listed {len(sdks)} SDKs")
# Test filter by language
python_sdks = self.manager.list_sdk_releases(language=SDKLanguage.PYTHON)
python_sdks = self.manager.list_sdk_releases(language = SDKLanguage.PYTHON)
self.log(f"Found {len(python_sdks)} Python SDKs")
# Test search
search_results = self.manager.list_sdk_releases(search="Python")
search_results = self.manager.list_sdk_releases(search = "Python")
self.log(f"Search found {len(search_results)} SDKs")
except Exception as e:
self.log(f"Failed to list SDKs: {str(e)}", success=False)
self.log(f"Failed to list SDKs: {str(e)}", success = False)
def test_sdk_get(self):
def test_sdk_get(self) -> None:
"""测试获取 SDK 详情"""
try:
if self.created_ids["sdk"]:
sdk = self.manager.get_sdk_release(self.created_ids["sdk"][0])
sdk = self.manager.get_sdk_release(self.created_ids["sdk"][0])
if sdk:
self.log(f"Retrieved SDK: {sdk.name}")
else:
self.log("SDK not found", success=False)
self.log("SDK not found", success = False)
except Exception as e:
self.log(f"Failed to get SDK: {str(e)}", success=False)
self.log(f"Failed to get SDK: {str(e)}", success = False)
def test_sdk_update(self):
def test_sdk_update(self) -> None:
"""测试更新 SDK"""
try:
if self.created_ids["sdk"]:
sdk = self.manager.update_sdk_release(
self.created_ids["sdk"][0], description="Updated description"
sdk = self.manager.update_sdk_release(
self.created_ids["sdk"][0], description = "Updated description"
)
if sdk:
self.log(f"Updated SDK: {sdk.name}")
except Exception as e:
self.log(f"Failed to update SDK: {str(e)}", success=False)
self.log(f"Failed to update SDK: {str(e)}", success = False)
def test_sdk_publish(self):
def test_sdk_publish(self) -> None:
"""测试发布 SDK"""
try:
if self.created_ids["sdk"]:
sdk = self.manager.publish_sdk_release(self.created_ids["sdk"][0])
sdk = self.manager.publish_sdk_release(self.created_ids["sdk"][0])
if sdk:
self.log(f"Published SDK: {sdk.name} (status: {sdk.status.value})")
except Exception as e:
self.log(f"Failed to publish SDK: {str(e)}", success=False)
self.log(f"Failed to publish SDK: {str(e)}", success = False)
def test_sdk_version_add(self):
def test_sdk_version_add(self) -> None:
"""测试添加 SDK 版本"""
try:
if self.created_ids["sdk"]:
version = self.manager.add_sdk_version(
sdk_id=self.created_ids["sdk"][0],
version="1.1.0",
is_lts=True,
release_notes="Bug fixes and improvements",
download_url="https://pypi.org/insightflow/1.1.0",
checksum="xyz789",
file_size=1100000,
version = self.manager.add_sdk_version(
sdk_id = self.created_ids["sdk"][0],
version = "1.1.0",
is_lts = True,
release_notes = "Bug fixes and improvements",
download_url = "https://pypi.org/insightflow/1.1.0",
checksum = "xyz789",
file_size = 1100000,
)
self.log(f"Added SDK version: {version.version}")
except Exception as e:
self.log(f"Failed to add SDK version: {str(e)}", success=False)
self.log(f"Failed to add SDK version: {str(e)}", success = False)
def test_template_create(self):
def test_template_create(self) -> None:
"""测试创建模板"""
try:
template = self.manager.create_template(
name="医疗行业实体识别模板",
description="专门针对医疗行业的实体识别模板,支持疾病、药物、症状等实体",
category=TemplateCategory.MEDICAL,
subcategory="entity_recognition",
tags=["medical", "healthcare", "ner"],
author_id="dev_001",
author_name="Medical AI Lab",
price=99.0,
currency="CNY",
preview_image_url="https://cdn.insightflow.io/templates/medical.png",
demo_url="https://demo.insightflow.io/medical",
documentation_url="https://docs.insightflow.io/templates/medical",
download_url="https://cdn.insightflow.io/templates/medical.zip",
version="1.0.0",
min_platform_version="2.0.0",
file_size=5242880,
checksum="tpl123",
template = self.manager.create_template(
name = "医疗行业实体识别模板",
description = "专门针对医疗行业的实体识别模板,支持疾病、药物、症状等实体",
category = TemplateCategory.MEDICAL,
subcategory = "entity_recognition",
tags = ["medical", "healthcare", "ner"],
author_id = "dev_001",
author_name = "Medical AI Lab",
price = 99.0,
currency = "CNY",
preview_image_url = "https://cdn.insightflow.io/templates/medical.png",
demo_url = "https://demo.insightflow.io/medical",
documentation_url = "https://docs.insightflow.io/templates/medical",
download_url = "https://cdn.insightflow.io/templates/medical.zip",
version = "1.0.0",
min_platform_version = "2.0.0",
file_size = 5242880,
checksum = "tpl123",
)
self.created_ids["template"].append(template.id)
self.log(f"Created template: {template.name} ({template.id})")
# Create free template
template_free = self.manager.create_template(
name="通用实体识别模板",
description="适用于一般场景的实体识别模板",
category=TemplateCategory.GENERAL,
subcategory=None,
tags=["general", "ner", "basic"],
author_id="dev_002",
author_name="InsightFlow Team",
price=0.0,
currency="CNY",
template_free = self.manager.create_template(
name = "通用实体识别模板",
description = "适用于一般场景的实体识别模板",
category = TemplateCategory.GENERAL,
subcategory = None,
tags = ["general", "ner", "basic"],
author_id = "dev_002",
author_name = "InsightFlow Team",
price = 0.0,
currency = "CNY",
)
self.created_ids["template"].append(template_free.id)
self.log(f"Created free template: {template_free.name}")
except Exception as e:
self.log(f"Failed to create template: {str(e)}", success=False)
self.log(f"Failed to create template: {str(e)}", success = False)
def test_template_list(self):
def test_template_list(self) -> None:
"""测试列出模板"""
try:
templates = self.manager.list_templates()
templates = self.manager.list_templates()
self.log(f"Listed {len(templates)} templates")
# Filter by category
medical_templates = self.manager.list_templates(category=TemplateCategory.MEDICAL)
medical_templates = self.manager.list_templates(category = TemplateCategory.MEDICAL)
self.log(f"Found {len(medical_templates)} medical templates")
# Filter by price
free_templates = self.manager.list_templates(max_price=0)
free_templates = self.manager.list_templates(max_price = 0)
self.log(f"Found {len(free_templates)} free templates")
except Exception as e:
self.log(f"Failed to list templates: {str(e)}", success=False)
self.log(f"Failed to list templates: {str(e)}", success = False)
def test_template_get(self):
def test_template_get(self) -> None:
"""测试获取模板详情"""
try:
if self.created_ids["template"]:
template = self.manager.get_template(self.created_ids["template"][0])
template = self.manager.get_template(self.created_ids["template"][0])
if template:
self.log(f"Retrieved template: {template.name}")
except Exception as e:
self.log(f"Failed to get template: {str(e)}", success=False)
self.log(f"Failed to get template: {str(e)}", success = False)
def test_template_approve(self):
def test_template_approve(self) -> None:
"""测试审核通过模板"""
try:
if self.created_ids["template"]:
template = self.manager.approve_template(
self.created_ids["template"][0], reviewed_by="admin_001"
template = self.manager.approve_template(
self.created_ids["template"][0], reviewed_by = "admin_001"
)
if template:
self.log(f"Approved template: {template.name}")
except Exception as e:
self.log(f"Failed to approve template: {str(e)}", success=False)
self.log(f"Failed to approve template: {str(e)}", success = False)
def test_template_publish(self):
def test_template_publish(self) -> None:
"""测试发布模板"""
try:
if self.created_ids["template"]:
template = self.manager.publish_template(self.created_ids["template"][0])
template = self.manager.publish_template(self.created_ids["template"][0])
if template:
self.log(f"Published template: {template.name}")
except Exception as e:
self.log(f"Failed to publish template: {str(e)}", success=False)
self.log(f"Failed to publish template: {str(e)}", success = False)
def test_template_review(self):
def test_template_review(self) -> None:
"""测试添加模板评价"""
try:
if self.created_ids["template"]:
review = self.manager.add_template_review(
template_id=self.created_ids["template"][0],
user_id="user_001",
user_name="Test User",
rating=5,
comment="Great template! Very accurate for medical entities.",
is_verified_purchase=True,
review = self.manager.add_template_review(
template_id = self.created_ids["template"][0],
user_id = "user_001",
user_name = "Test User",
rating = 5,
comment = "Great template! Very accurate for medical entities.",
is_verified_purchase = True,
)
self.log(f"Added template review: {review.rating} stars")
except Exception as e:
self.log(f"Failed to add template review: {str(e)}", success=False)
self.log(f"Failed to add template review: {str(e)}", success = False)
def test_plugin_create(self):
def test_plugin_create(self) -> None:
"""测试创建插件"""
try:
plugin = self.manager.create_plugin(
name="飞书机器人集成插件",
description="将 InsightFlow 与飞书机器人集成,实现自动通知",
category=PluginCategory.INTEGRATION,
tags=["feishu", "bot", "integration", "notification"],
author_id="dev_003",
author_name="Integration Team",
price=49.0,
currency="CNY",
pricing_model="paid",
preview_image_url="https://cdn.insightflow.io/plugins/feishu.png",
demo_url="https://demo.insightflow.io/feishu",
documentation_url="https://docs.insightflow.io/plugins/feishu",
repository_url="https://github.com/insightflow/feishu-plugin",
download_url="https://cdn.insightflow.io/plugins/feishu.zip",
webhook_url="https://api.insightflow.io/webhooks/feishu",
permissions=["read:projects", "write:notifications"],
version="1.0.0",
min_platform_version="2.0.0",
file_size=1048576,
checksum="plg123",
plugin = self.manager.create_plugin(
name = "飞书机器人集成插件",
description = "将 InsightFlow 与飞书机器人集成,实现自动通知",
category = PluginCategory.INTEGRATION,
tags = ["feishu", "bot", "integration", "notification"],
author_id = "dev_003",
author_name = "Integration Team",
price = 49.0,
currency = "CNY",
pricing_model = "paid",
preview_image_url = "https://cdn.insightflow.io/plugins/feishu.png",
demo_url = "https://demo.insightflow.io/feishu",
documentation_url = "https://docs.insightflow.io/plugins/feishu",
repository_url = "https://github.com/insightflow/feishu-plugin",
download_url = "https://cdn.insightflow.io/plugins/feishu.zip",
webhook_url = "https://api.insightflow.io/webhooks/feishu",
permissions = ["read:projects", "write:notifications"],
version = "1.0.0",
min_platform_version = "2.0.0",
file_size = 1048576,
checksum = "plg123",
)
self.created_ids["plugin"].append(plugin.id)
self.log(f"Created plugin: {plugin.name} ({plugin.id})")
# Create free plugin
plugin_free = self.manager.create_plugin(
name="数据导出插件",
description="支持多种格式的数据导出",
category=PluginCategory.ANALYSIS,
tags=["export", "data", "csv", "json"],
author_id="dev_004",
author_name="Data Team",
price=0.0,
currency="CNY",
pricing_model="free",
plugin_free = self.manager.create_plugin(
name = "数据导出插件",
description = "支持多种格式的数据导出",
category = PluginCategory.ANALYSIS,
tags = ["export", "data", "csv", "json"],
author_id = "dev_004",
author_name = "Data Team",
price = 0.0,
currency = "CNY",
pricing_model = "free",
)
self.created_ids["plugin"].append(plugin_free.id)
self.log(f"Created free plugin: {plugin_free.name}")
except Exception as e:
self.log(f"Failed to create plugin: {str(e)}", success=False)
self.log(f"Failed to create plugin: {str(e)}", success = False)
def test_plugin_list(self):
def test_plugin_list(self) -> None:
"""测试列出插件"""
try:
plugins = self.manager.list_plugins()
plugins = self.manager.list_plugins()
self.log(f"Listed {len(plugins)} plugins")
# Filter by category
integration_plugins = self.manager.list_plugins(category=PluginCategory.INTEGRATION)
integration_plugins = self.manager.list_plugins(category = PluginCategory.INTEGRATION)
self.log(f"Found {len(integration_plugins)} integration plugins")
except Exception as e:
self.log(f"Failed to list plugins: {str(e)}", success=False)
self.log(f"Failed to list plugins: {str(e)}", success = False)
def test_plugin_get(self):
def test_plugin_get(self) -> None:
"""测试获取插件详情"""
try:
if self.created_ids["plugin"]:
plugin = self.manager.get_plugin(self.created_ids["plugin"][0])
plugin = self.manager.get_plugin(self.created_ids["plugin"][0])
if plugin:
self.log(f"Retrieved plugin: {plugin.name}")
except Exception as e:
self.log(f"Failed to get plugin: {str(e)}", success=False)
self.log(f"Failed to get plugin: {str(e)}", success = False)
def test_plugin_review(self):
def test_plugin_review(self) -> None:
"""测试审核插件"""
try:
if self.created_ids["plugin"]:
plugin = self.manager.review_plugin(
plugin = self.manager.review_plugin(
self.created_ids["plugin"][0],
reviewed_by="admin_001",
status=PluginStatus.APPROVED,
notes="Code review passed",
reviewed_by = "admin_001",
status = PluginStatus.APPROVED,
notes = "Code review passed",
)
if plugin:
self.log(f"Reviewed plugin: {plugin.name} ({plugin.status.value})")
except Exception as e:
self.log(f"Failed to review plugin: {str(e)}", success=False)
self.log(f"Failed to review plugin: {str(e)}", success = False)
def test_plugin_publish(self):
def test_plugin_publish(self) -> None:
"""测试发布插件"""
try:
if self.created_ids["plugin"]:
plugin = self.manager.publish_plugin(self.created_ids["plugin"][0])
plugin = self.manager.publish_plugin(self.created_ids["plugin"][0])
if plugin:
self.log(f"Published plugin: {plugin.name}")
except Exception as e:
self.log(f"Failed to publish plugin: {str(e)}", success=False)
self.log(f"Failed to publish plugin: {str(e)}", success = False)
def test_plugin_review_add(self):
def test_plugin_review_add(self) -> None:
"""测试添加插件评价"""
try:
if self.created_ids["plugin"]:
review = self.manager.add_plugin_review(
plugin_id=self.created_ids["plugin"][0],
user_id="user_002",
user_name="Plugin User",
rating=4,
comment="Works great with Feishu!",
is_verified_purchase=True,
review = self.manager.add_plugin_review(
plugin_id = self.created_ids["plugin"][0],
user_id = "user_002",
user_name = "Plugin User",
rating = 4,
comment = "Works great with Feishu!",
is_verified_purchase = True,
)
self.log(f"Added plugin review: {review.rating} stars")
except Exception as e:
self.log(f"Failed to add plugin review: {str(e)}", success=False)
self.log(f"Failed to add plugin review: {str(e)}", success = False)
def test_developer_profile_create(self):
def test_developer_profile_create(self) -> None:
"""测试创建开发者档案"""
try:
# Generate unique user IDs
unique_id = uuid.uuid4().hex[:8]
unique_id = uuid.uuid4().hex[:8]
profile = self.manager.create_developer_profile(
user_id=f"user_dev_{unique_id}_001",
display_name="张三",
email=f"zhangsan_{unique_id}@example.com",
bio="专注于医疗AI和自然语言处理",
website="https://zhangsan.dev",
github_url="https://github.com/zhangsan",
avatar_url="https://cdn.example.com/avatars/zhangsan.png",
profile = self.manager.create_developer_profile(
user_id = f"user_dev_{unique_id}_001",
display_name = "张三",
email = f"zhangsan_{unique_id}@example.com",
bio = "专注于医疗AI和自然语言处理",
website = "https://zhangsan.dev",
github_url = "https://github.com/zhangsan",
avatar_url = "https://cdn.example.com/avatars/zhangsan.png",
)
self.created_ids["developer"].append(profile.id)
self.log(f"Created developer profile: {profile.display_name} ({profile.id})")
# Create another developer
profile2 = self.manager.create_developer_profile(
user_id=f"user_dev_{unique_id}_002",
display_name="李四",
email=f"lisi_{unique_id}@example.com",
bio="全栈开发者,热爱开源",
profile2 = self.manager.create_developer_profile(
user_id = f"user_dev_{unique_id}_002",
display_name = "李四",
email = f"lisi_{unique_id}@example.com",
bio = "全栈开发者,热爱开源",
)
self.created_ids["developer"].append(profile2.id)
self.log(f"Created developer profile: {profile2.display_name}")
except Exception as e:
self.log(f"Failed to create developer profile: {str(e)}", success=False)
self.log(f"Failed to create developer profile: {str(e)}", success = False)
def test_developer_profile_get(self):
def test_developer_profile_get(self) -> None:
"""测试获取开发者档案"""
try:
if self.created_ids["developer"]:
profile = self.manager.get_developer_profile(self.created_ids["developer"][0])
profile = self.manager.get_developer_profile(self.created_ids["developer"][0])
if profile:
self.log(f"Retrieved developer profile: {profile.display_name}")
except Exception as e:
self.log(f"Failed to get developer profile: {str(e)}", success=False)
self.log(f"Failed to get developer profile: {str(e)}", success = False)
def test_developer_verify(self):
def test_developer_verify(self) -> None:
"""测试验证开发者"""
try:
if self.created_ids["developer"]:
profile = self.manager.verify_developer(
profile = self.manager.verify_developer(
self.created_ids["developer"][0], DeveloperStatus.VERIFIED
)
if profile:
self.log(f"Verified developer: {profile.display_name} ({profile.status.value})")
except Exception as e:
self.log(f"Failed to verify developer: {str(e)}", success=False)
self.log(f"Failed to verify developer: {str(e)}", success = False)
def test_developer_stats_update(self):
def test_developer_stats_update(self) -> None:
"""测试更新开发者统计"""
try:
if self.created_ids["developer"]:
self.manager.update_developer_stats(self.created_ids["developer"][0])
profile = self.manager.get_developer_profile(self.created_ids["developer"][0])
profile = self.manager.get_developer_profile(self.created_ids["developer"][0])
self.log(
f"Updated developer stats: {profile.plugin_count} plugins, {profile.template_count} templates"
)
except Exception as e:
self.log(f"Failed to update developer stats: {str(e)}", success=False)
self.log(f"Failed to update developer stats: {str(e)}", success = False)
def test_code_example_create(self):
def test_code_example_create(self) -> None:
"""测试创建代码示例"""
try:
example = self.manager.create_code_example(
title="使用 Python SDK 创建项目",
description="演示如何使用 Python SDK 创建新项目",
language="python",
category="quickstart",
code="""from insightflow import Client
example = self.manager.create_code_example(
title = "使用 Python SDK 创建项目",
description = "演示如何使用 Python SDK 创建新项目",
language = "python",
category = "quickstart",
code = """from insightflow import Client
client = Client(api_key="your_api_key")
project = client.projects.create(name="My Project")
client = Client(api_key = "your_api_key")
project = client.projects.create(name = "My Project")
print(f"Created project: {project.id}")
""",
explanation="首先导入 Client 类,然后使用 API Key 初始化客户端,最后调用 create 方法创建项目。",
tags=["python", "quickstart", "projects"],
author_id="dev_001",
author_name="InsightFlow Team",
api_endpoints=["/api/v1/projects"],
explanation = "首先导入 Client 类,然后使用 API Key 初始化客户端,最后调用 create 方法创建项目。",
tags = ["python", "quickstart", "projects"],
author_id = "dev_001",
author_name = "InsightFlow Team",
api_endpoints = ["/api/v1/projects"],
)
self.created_ids["code_example"].append(example.id)
self.log(f"Created code example: {example.title}")
# Create JavaScript example
example_js = self.manager.create_code_example(
title="使用 JavaScript SDK 上传文件",
description="演示如何使用 JavaScript SDK 上传音频文件",
language="javascript",
category="upload",
code="""const { Client } = require('insightflow');
example_js = self.manager.create_code_example(
title = "使用 JavaScript SDK 上传文件",
description = "演示如何使用 JavaScript SDK 上传音频文件",
language = "javascript",
category = "upload",
code = """const { Client } = require('insightflow');
const client = new Client({ apiKey: 'your_api_key' });
const result = await client.uploads.create({
const client = new Client({ apiKey: 'your_api_key' });
const result = await client.uploads.create({
projectId: 'proj_123',
file: './meeting.mp3'
});
console.log('Upload complete:', result.id);
""",
explanation="使用 JavaScript SDK 上传文件到 InsightFlow",
tags=["javascript", "upload", "audio"],
author_id="dev_002",
author_name="JS Team",
explanation = "使用 JavaScript SDK 上传文件到 InsightFlow",
tags = ["javascript", "upload", "audio"],
author_id = "dev_002",
author_name = "JS Team",
)
self.created_ids["code_example"].append(example_js.id)
self.log(f"Created code example: {example_js.title}")
except Exception as e:
self.log(f"Failed to create code example: {str(e)}", success=False)
self.log(f"Failed to create code example: {str(e)}", success = False)
def test_code_example_list(self):
def test_code_example_list(self) -> None:
"""测试列出代码示例"""
try:
examples = self.manager.list_code_examples()
examples = self.manager.list_code_examples()
self.log(f"Listed {len(examples)} code examples")
# Filter by language
python_examples = self.manager.list_code_examples(language="python")
python_examples = self.manager.list_code_examples(language = "python")
self.log(f"Found {len(python_examples)} Python examples")
except Exception as e:
self.log(f"Failed to list code examples: {str(e)}", success=False)
self.log(f"Failed to list code examples: {str(e)}", success = False)
def test_code_example_get(self):
def test_code_example_get(self) -> None:
"""测试获取代码示例详情"""
try:
if self.created_ids["code_example"]:
example = self.manager.get_code_example(self.created_ids["code_example"][0])
example = self.manager.get_code_example(self.created_ids["code_example"][0])
if example:
self.log(
f"Retrieved code example: {example.title} (views: {example.view_count})"
)
except Exception as e:
self.log(f"Failed to get code example: {str(e)}", success=False)
self.log(f"Failed to get code example: {str(e)}", success = False)
def test_portal_config_create(self):
def test_portal_config_create(self) -> None:
"""测试创建开发者门户配置"""
try:
config = self.manager.create_portal_config(
name="InsightFlow Developer Portal",
description="开发者门户 - SDK、API 文档和示例代码",
theme="default",
primary_color="#1890ff",
secondary_color="#52c41a",
support_email="developers@insightflow.io",
support_url="https://support.insightflow.io",
github_url="https://github.com/insightflow",
discord_url="https://discord.gg/insightflow",
api_base_url="https://api.insightflow.io/v1",
config = self.manager.create_portal_config(
name = "InsightFlow Developer Portal",
description = "开发者门户 - SDK、API 文档和示例代码",
theme = "default",
primary_color = "#1890ff",
secondary_color = "#52c41a",
support_email = "developers@insightflow.io",
support_url = "https://support.insightflow.io",
github_url = "https://github.com/insightflow",
discord_url = "https://discord.gg/insightflow",
api_base_url = "https://api.insightflow.io/v1",
)
self.created_ids["portal_config"].append(config.id)
self.log(f"Created portal config: {config.name}")
except Exception as e:
self.log(f"Failed to create portal config: {str(e)}", success=False)
self.log(f"Failed to create portal config: {str(e)}", success = False)
def test_portal_config_get(self):
def test_portal_config_get(self) -> None:
"""测试获取开发者门户配置"""
try:
if self.created_ids["portal_config"]:
config = self.manager.get_portal_config(self.created_ids["portal_config"][0])
config = self.manager.get_portal_config(self.created_ids["portal_config"][0])
if config:
self.log(f"Retrieved portal config: {config.name}")
# Test active config
active_config = self.manager.get_active_portal_config()
active_config = self.manager.get_active_portal_config()
if active_config:
self.log(f"Active portal config: {active_config.name}")
except Exception as e:
self.log(f"Failed to get portal config: {str(e)}", success=False)
self.log(f"Failed to get portal config: {str(e)}", success = False)
def test_revenue_record(self):
def test_revenue_record(self) -> None:
"""测试记录开发者收益"""
try:
if self.created_ids["developer"] and self.created_ids["plugin"]:
revenue = self.manager.record_revenue(
developer_id=self.created_ids["developer"][0],
item_type="plugin",
item_id=self.created_ids["plugin"][0],
item_name="飞书机器人集成插件",
sale_amount=49.0,
currency="CNY",
buyer_id="user_buyer_001",
transaction_id="txn_123456",
revenue = self.manager.record_revenue(
developer_id = self.created_ids["developer"][0],
item_type = "plugin",
item_id = self.created_ids["plugin"][0],
item_name = "飞书机器人集成插件",
sale_amount = 49.0,
currency = "CNY",
buyer_id = "user_buyer_001",
transaction_id = "txn_123456",
)
self.log(f"Recorded revenue: {revenue.sale_amount} {revenue.currency}")
self.log(f" - Platform fee: {revenue.platform_fee}")
self.log(f" - Developer earnings: {revenue.developer_earnings}")
except Exception as e:
self.log(f"Failed to record revenue: {str(e)}", success=False)
self.log(f"Failed to record revenue: {str(e)}", success = False)
def test_revenue_summary(self):
def test_revenue_summary(self) -> None:
"""测试获取开发者收益汇总"""
try:
if self.created_ids["developer"]:
summary = self.manager.get_developer_revenue_summary(
summary = self.manager.get_developer_revenue_summary(
self.created_ids["developer"][0]
)
self.log("Revenue summary for developer:")
@@ -659,17 +659,17 @@ console.log('Upload complete:', result.id);
self.log(f" - Total earnings: {summary['total_earnings']}")
self.log(f" - Transaction count: {summary['transaction_count']}")
except Exception as e:
self.log(f"Failed to get revenue summary: {str(e)}", success=False)
self.log(f"Failed to get revenue summary: {str(e)}", success = False)
def print_summary(self):
def print_summary(self) -> None:
"""打印测试摘要"""
print("\n" + "=" * 60)
print("\n" + " = " * 60)
print("Test Summary")
print("=" * 60)
print(" = " * 60)
total = len(self.test_results)
passed = sum(1 for r in self.test_results if r["success"])
failed = total - passed
total = len(self.test_results)
passed = sum(1 for r in self.test_results if r["success"])
failed = total - passed
print(f"Total tests: {total}")
print(f"Passed: {passed}")
@@ -686,12 +686,12 @@ console.log('Upload complete:', result.id);
if ids:
print(f" {resource_type}: {len(ids)}")
print("=" * 60)
print(" = " * 60)
def main():
def main() -> None:
"""主函数"""
test = TestDeveloperEcosystem()
test = TestDeveloperEcosystem()
test.run_all_tests()

View File

@@ -26,7 +26,7 @@ from ops_manager import (
)
# Add backend directory to path
backend_dir = os.path.dirname(os.path.abspath(__file__))
backend_dir = os.path.dirname(os.path.abspath(__file__))
if backend_dir not in sys.path:
sys.path.insert(0, backend_dir)
@@ -34,22 +34,22 @@ if backend_dir not in sys.path:
class TestOpsManager:
"""测试运维与监控管理器"""
def __init__(self):
self.manager = get_ops_manager()
self.tenant_id = "test_tenant_001"
self.test_results = []
def __init__(self) -> None:
self.manager = get_ops_manager()
self.tenant_id = "test_tenant_001"
self.test_results = []
def log(self, message: str, success: bool = True):
def log(self, message: str, success: bool = True) -> None:
"""记录测试结果"""
status = "" if success else ""
status = "" if success else ""
print(f"{status} {message}")
self.test_results.append((message, success))
def run_all_tests(self):
def run_all_tests(self) -> None:
"""运行所有测试"""
print("=" * 60)
print(" = " * 60)
print("InsightFlow Phase 8 Task 8: Operations & Monitoring Tests")
print("=" * 60)
print(" = " * 60)
# 1. 告警系统测试
self.test_alert_rules()
@@ -73,63 +73,63 @@ class TestOpsManager:
# 打印测试总结
self.print_summary()
def test_alert_rules(self):
def test_alert_rules(self) -> None:
"""测试告警规则管理"""
print("\n📋 Testing Alert Rules...")
try:
# 创建阈值告警规则
rule1 = self.manager.create_alert_rule(
tenant_id=self.tenant_id,
name="CPU 使用率告警",
description="当 CPU 使用率超过 80% 时触发告警",
rule_type=AlertRuleType.THRESHOLD,
severity=AlertSeverity.P1,
metric="cpu_usage_percent",
condition=">",
threshold=80.0,
duration=300,
evaluation_interval=60,
channels=[],
labels={"service": "api", "team": "platform"},
annotations={"summary": "CPU 使用率过高", "runbook": "https://wiki/runbooks/cpu"},
created_by="test_user",
rule1 = self.manager.create_alert_rule(
tenant_id = self.tenant_id,
name = "CPU 使用率告警",
description = "当 CPU 使用率超过 80% 时触发告警",
rule_type = AlertRuleType.THRESHOLD,
severity = AlertSeverity.P1,
metric = "cpu_usage_percent",
condition = ">",
threshold = 80.0,
duration = 300,
evaluation_interval = 60,
channels = [],
labels = {"service": "api", "team": "platform"},
annotations = {"summary": "CPU 使用率过高", "runbook": "https://wiki/runbooks/cpu"},
created_by = "test_user",
)
self.log(f"Created alert rule: {rule1.name} (ID: {rule1.id})")
# 创建异常检测告警规则
rule2 = self.manager.create_alert_rule(
tenant_id=self.tenant_id,
name="内存异常检测",
description="检测内存使用异常",
rule_type=AlertRuleType.ANOMALY,
severity=AlertSeverity.P2,
metric="memory_usage_percent",
condition=">",
threshold=0.0,
duration=600,
evaluation_interval=300,
channels=[],
labels={"service": "database"},
annotations={},
created_by="test_user",
rule2 = self.manager.create_alert_rule(
tenant_id = self.tenant_id,
name = "内存异常检测",
description = "检测内存使用异常",
rule_type = AlertRuleType.ANOMALY,
severity = AlertSeverity.P2,
metric = "memory_usage_percent",
condition = ">",
threshold = 0.0,
duration = 600,
evaluation_interval = 300,
channels = [],
labels = {"service": "database"},
annotations = {},
created_by = "test_user",
)
self.log(f"Created anomaly alert rule: {rule2.name} (ID: {rule2.id})")
# 获取告警规则
fetched_rule = self.manager.get_alert_rule(rule1.id)
fetched_rule = self.manager.get_alert_rule(rule1.id)
assert fetched_rule is not None
assert fetched_rule.name == rule1.name
self.log(f"Fetched alert rule: {fetched_rule.name}")
# 列出租户的所有告警规则
rules = self.manager.list_alert_rules(self.tenant_id)
rules = self.manager.list_alert_rules(self.tenant_id)
assert len(rules) >= 2
self.log(f"Listed {len(rules)} alert rules for tenant")
# 更新告警规则
updated_rule = self.manager.update_alert_rule(
rule1.id, threshold=85.0, description="更新后的描述"
updated_rule = self.manager.update_alert_rule(
rule1.id, threshold = 85.0, description = "更新后的描述"
)
assert updated_rule.threshold == 85.0
self.log(f"Updated alert rule threshold to {updated_rule.threshold}")
@@ -140,57 +140,57 @@ class TestOpsManager:
self.log("Deleted test alert rules")
except Exception as e:
self.log(f"Alert rules test failed: {e}", success=False)
self.log(f"Alert rules test failed: {e}", success = False)
def test_alert_channels(self):
def test_alert_channels(self) -> None:
"""测试告警渠道管理"""
print("\n📢 Testing Alert Channels...")
try:
# 创建飞书告警渠道
channel1 = self.manager.create_alert_channel(
tenant_id=self.tenant_id,
name="飞书告警",
channel_type=AlertChannelType.FEISHU,
config={
channel1 = self.manager.create_alert_channel(
tenant_id = self.tenant_id,
name = "飞书告警",
channel_type = AlertChannelType.FEISHU,
config = {
"webhook_url": "https://open.feishu.cn/open-apis/bot/v2/hook/test",
"secret": "test_secret",
},
severity_filter=["p0", "p1"],
severity_filter = ["p0", "p1"],
)
self.log(f"Created Feishu channel: {channel1.name} (ID: {channel1.id})")
# 创建钉钉告警渠道
channel2 = self.manager.create_alert_channel(
tenant_id=self.tenant_id,
name="钉钉告警",
channel_type=AlertChannelType.DINGTALK,
config={
"webhook_url": "https://oapi.dingtalk.com/robot/send?access_token=test",
channel2 = self.manager.create_alert_channel(
tenant_id = self.tenant_id,
name = "钉钉告警",
channel_type = AlertChannelType.DINGTALK,
config = {
"webhook_url": "https://oapi.dingtalk.com/robot/send?access_token = test",
"secret": "test_secret",
},
severity_filter=["p0", "p1", "p2"],
severity_filter = ["p0", "p1", "p2"],
)
self.log(f"Created DingTalk channel: {channel2.name} (ID: {channel2.id})")
# 创建 Slack 告警渠道
channel3 = self.manager.create_alert_channel(
tenant_id=self.tenant_id,
name="Slack 告警",
channel_type=AlertChannelType.SLACK,
config={"webhook_url": "https://hooks.slack.com/services/test"},
severity_filter=["p0", "p1", "p2", "p3"],
channel3 = self.manager.create_alert_channel(
tenant_id = self.tenant_id,
name = "Slack 告警",
channel_type = AlertChannelType.SLACK,
config = {"webhook_url": "https://hooks.slack.com/services/test"},
severity_filter = ["p0", "p1", "p2", "p3"],
)
self.log(f"Created Slack channel: {channel3.name} (ID: {channel3.id})")
# 获取告警渠道
fetched_channel = self.manager.get_alert_channel(channel1.id)
fetched_channel = self.manager.get_alert_channel(channel1.id)
assert fetched_channel is not None
assert fetched_channel.name == channel1.name
self.log(f"Fetched alert channel: {fetched_channel.name}")
# 列出租户的所有告警渠道
channels = self.manager.list_alert_channels(self.tenant_id)
channels = self.manager.list_alert_channels(self.tenant_id)
assert len(channels) >= 3
self.log(f"Listed {len(channels)} alert channels for tenant")
@@ -198,74 +198,74 @@ class TestOpsManager:
for channel in channels:
if channel.tenant_id == self.tenant_id:
with self.manager._get_db() as conn:
conn.execute("DELETE FROM alert_channels WHERE id = ?", (channel.id,))
conn.execute("DELETE FROM alert_channels WHERE id = ?", (channel.id, ))
conn.commit()
self.log("Deleted test alert channels")
except Exception as e:
self.log(f"Alert channels test failed: {e}", success=False)
self.log(f"Alert channels test failed: {e}", success = False)
def test_alerts(self):
def test_alerts(self) -> None:
"""测试告警管理"""
print("\n🚨 Testing Alerts...")
try:
# 创建告警规则
rule = self.manager.create_alert_rule(
tenant_id=self.tenant_id,
name="测试告警规则",
description="用于测试的告警规则",
rule_type=AlertRuleType.THRESHOLD,
severity=AlertSeverity.P1,
metric="test_metric",
condition=">",
threshold=100.0,
duration=60,
evaluation_interval=60,
channels=[],
labels={},
annotations={},
created_by="test_user",
rule = self.manager.create_alert_rule(
tenant_id = self.tenant_id,
name = "测试告警规则",
description = "用于测试的告警规则",
rule_type = AlertRuleType.THRESHOLD,
severity = AlertSeverity.P1,
metric = "test_metric",
condition = ">",
threshold = 100.0,
duration = 60,
evaluation_interval = 60,
channels = [],
labels = {},
annotations = {},
created_by = "test_user",
)
# 记录资源指标
for i in range(10):
self.manager.record_resource_metric(
tenant_id=self.tenant_id,
resource_type=ResourceType.CPU,
resource_id="server-001",
metric_name="test_metric",
metric_value=110.0 + i,
unit="percent",
metadata={"region": "cn-north-1"},
tenant_id = self.tenant_id,
resource_type = ResourceType.CPU,
resource_id = "server-001",
metric_name = "test_metric",
metric_value = 110.0 + i,
unit = "percent",
metadata = {"region": "cn-north-1"},
)
self.log("Recorded 10 resource metrics")
# 手动创建告警
from ops_manager import Alert
alert_id = f"test_alert_{datetime.now().strftime('%Y%m%d%H%M%S')}"
now = datetime.now().isoformat()
alert_id = f"test_alert_{datetime.now().strftime('%Y%m%d%H%M%S')}"
now = datetime.now().isoformat()
alert = Alert(
id=alert_id,
rule_id=rule.id,
tenant_id=self.tenant_id,
severity=AlertSeverity.P1,
status=AlertStatus.FIRING,
title="测试告警",
description="这是一条测试告警",
metric="test_metric",
value=120.0,
threshold=100.0,
labels={"test": "true"},
annotations={},
started_at=now,
resolved_at=None,
acknowledged_by=None,
acknowledged_at=None,
notification_sent={},
suppression_count=0,
alert = Alert(
id = alert_id,
rule_id = rule.id,
tenant_id = self.tenant_id,
severity = AlertSeverity.P1,
status = AlertStatus.FIRING,
title = "测试告警",
description = "这是一条测试告警",
metric = "test_metric",
value = 120.0,
threshold = 100.0,
labels = {"test": "true"},
annotations = {},
started_at = now,
resolved_at = None,
acknowledged_by = None,
acknowledged_at = None,
notification_sent = {},
suppression_count = 0,
)
with self.manager._get_db() as conn:
@@ -299,20 +299,20 @@ class TestOpsManager:
self.log(f"Created test alert: {alert.id}")
# 列出租户的告警
alerts = self.manager.list_alerts(self.tenant_id)
alerts = self.manager.list_alerts(self.tenant_id)
assert len(alerts) >= 1
self.log(f"Listed {len(alerts)} alerts for tenant")
# 确认告警
self.manager.acknowledge_alert(alert_id, "test_user")
fetched_alert = self.manager.get_alert(alert_id)
fetched_alert = self.manager.get_alert(alert_id)
assert fetched_alert.status == AlertStatus.ACKNOWLEDGED
assert fetched_alert.acknowledged_by == "test_user"
self.log(f"Acknowledged alert: {alert_id}")
# 解决告警
self.manager.resolve_alert(alert_id)
fetched_alert = self.manager.get_alert(alert_id)
fetched_alert = self.manager.get_alert(alert_id)
assert fetched_alert.status == AlertStatus.RESOLVED
assert fetched_alert.resolved_at is not None
self.log(f"Resolved alert: {alert_id}")
@@ -320,23 +320,23 @@ class TestOpsManager:
# 清理
self.manager.delete_alert_rule(rule.id)
with self.manager._get_db() as conn:
conn.execute("DELETE FROM alerts WHERE id = ?", (alert_id,))
conn.execute("DELETE FROM resource_metrics WHERE tenant_id = ?", (self.tenant_id,))
conn.execute("DELETE FROM alerts WHERE id = ?", (alert_id, ))
conn.execute("DELETE FROM resource_metrics WHERE tenant_id = ?", (self.tenant_id, ))
conn.commit()
self.log("Cleaned up test data")
except Exception as e:
self.log(f"Alerts test failed: {e}", success=False)
self.log(f"Alerts test failed: {e}", success = False)
def test_capacity_planning(self):
def test_capacity_planning(self) -> None:
"""测试容量规划"""
print("\n📊 Testing Capacity Planning...")
try:
# 记录历史指标数据
base_time = datetime.now() - timedelta(days=30)
base_time = datetime.now() - timedelta(days = 30)
for i in range(30):
timestamp = (base_time + timedelta(days=i)).isoformat()
timestamp = (base_time + timedelta(days = i)).isoformat()
with self.manager._get_db() as conn:
conn.execute(
"""
@@ -360,13 +360,13 @@ class TestOpsManager:
self.log("Recorded 30 days of historical metrics")
# 创建容量规划
prediction_date = (datetime.now() + timedelta(days=30)).strftime("%Y-%m-%d")
plan = self.manager.create_capacity_plan(
tenant_id=self.tenant_id,
resource_type=ResourceType.CPU,
current_capacity=100.0,
prediction_date=prediction_date,
confidence=0.85,
prediction_date = (datetime.now() + timedelta(days = 30)).strftime("%Y-%m-%d")
plan = self.manager.create_capacity_plan(
tenant_id = self.tenant_id,
resource_type = ResourceType.CPU,
current_capacity = 100.0,
prediction_date = prediction_date,
confidence = 0.85,
)
self.log(f"Created capacity plan: {plan.id}")
@@ -375,38 +375,38 @@ class TestOpsManager:
self.log(f" Recommended action: {plan.recommended_action}")
# 获取容量规划列表
plans = self.manager.get_capacity_plans(self.tenant_id)
plans = self.manager.get_capacity_plans(self.tenant_id)
assert len(plans) >= 1
self.log(f"Listed {len(plans)} capacity plans")
# 清理
with self.manager._get_db() as conn:
conn.execute("DELETE FROM capacity_plans WHERE tenant_id = ?", (self.tenant_id,))
conn.execute("DELETE FROM resource_metrics WHERE tenant_id = ?", (self.tenant_id,))
conn.execute("DELETE FROM capacity_plans WHERE tenant_id = ?", (self.tenant_id, ))
conn.execute("DELETE FROM resource_metrics WHERE tenant_id = ?", (self.tenant_id, ))
conn.commit()
self.log("Cleaned up capacity planning test data")
except Exception as e:
self.log(f"Capacity planning test failed: {e}", success=False)
self.log(f"Capacity planning test failed: {e}", success = False)
def test_auto_scaling(self):
def test_auto_scaling(self) -> None:
"""测试自动扩缩容"""
print("\n⚖️ Testing Auto Scaling...")
try:
# 创建自动扩缩容策略
policy = self.manager.create_auto_scaling_policy(
tenant_id=self.tenant_id,
name="API 服务自动扩缩容",
resource_type=ResourceType.CPU,
min_instances=2,
max_instances=10,
target_utilization=0.7,
scale_up_threshold=0.8,
scale_down_threshold=0.3,
scale_up_step=2,
scale_down_step=1,
cooldown_period=300,
policy = self.manager.create_auto_scaling_policy(
tenant_id = self.tenant_id,
name = "API 服务自动扩缩容",
resource_type = ResourceType.CPU,
min_instances = 2,
max_instances = 10,
target_utilization = 0.7,
scale_up_threshold = 0.8,
scale_down_threshold = 0.3,
scale_up_step = 2,
scale_down_step = 1,
cooldown_period = 300,
)
self.log(f"Created auto scaling policy: {policy.name} (ID: {policy.id})")
@@ -415,13 +415,13 @@ class TestOpsManager:
self.log(f" Target utilization: {policy.target_utilization}")
# 获取策略列表
policies = self.manager.list_auto_scaling_policies(self.tenant_id)
policies = self.manager.list_auto_scaling_policies(self.tenant_id)
assert len(policies) >= 1
self.log(f"Listed {len(policies)} auto scaling policies")
# 模拟扩缩容评估
event = self.manager.evaluate_scaling_policy(
policy_id=policy.id, current_instances=3, current_utilization=0.85
event = self.manager.evaluate_scaling_policy(
policy_id = policy.id, current_instances = 3, current_utilization = 0.85
)
if event:
@@ -432,62 +432,62 @@ class TestOpsManager:
self.log("No scaling action needed")
# 获取扩缩容事件列表
events = self.manager.list_scaling_events(self.tenant_id)
events = self.manager.list_scaling_events(self.tenant_id)
self.log(f"Listed {len(events)} scaling events")
# 清理
with self.manager._get_db() as conn:
conn.execute("DELETE FROM scaling_events WHERE tenant_id = ?", (self.tenant_id,))
conn.execute("DELETE FROM scaling_events WHERE tenant_id = ?", (self.tenant_id, ))
conn.execute(
"DELETE FROM auto_scaling_policies WHERE tenant_id = ?", (self.tenant_id,)
"DELETE FROM auto_scaling_policies WHERE tenant_id = ?", (self.tenant_id, )
)
conn.commit()
self.log("Cleaned up auto scaling test data")
except Exception as e:
self.log(f"Auto scaling test failed: {e}", success=False)
self.log(f"Auto scaling test failed: {e}", success = False)
def test_health_checks(self):
def test_health_checks(self) -> None:
"""测试健康检查"""
print("\n💓 Testing Health Checks...")
try:
# 创建 HTTP 健康检查
check1 = self.manager.create_health_check(
tenant_id=self.tenant_id,
name="API 服务健康检查",
target_type="service",
target_id="api-service",
check_type="http",
check_config={"url": "https://api.insightflow.io/health", "expected_status": 200},
interval=60,
timeout=10,
retry_count=3,
check1 = self.manager.create_health_check(
tenant_id = self.tenant_id,
name = "API 服务健康检查",
target_type = "service",
target_id = "api-service",
check_type = "http",
check_config = {"url": "https://api.insightflow.io/health", "expected_status": 200},
interval = 60,
timeout = 10,
retry_count = 3,
)
self.log(f"Created HTTP health check: {check1.name} (ID: {check1.id})")
# 创建 TCP 健康检查
check2 = self.manager.create_health_check(
tenant_id=self.tenant_id,
name="数据库健康检查",
target_type="database",
target_id="postgres-001",
check_type="tcp",
check_config={"host": "db.insightflow.io", "port": 5432},
interval=30,
timeout=5,
retry_count=2,
check2 = self.manager.create_health_check(
tenant_id = self.tenant_id,
name = "数据库健康检查",
target_type = "database",
target_id = "postgres-001",
check_type = "tcp",
check_config = {"host": "db.insightflow.io", "port": 5432},
interval = 30,
timeout = 5,
retry_count = 2,
)
self.log(f"Created TCP health check: {check2.name} (ID: {check2.id})")
# 获取健康检查列表
checks = self.manager.list_health_checks(self.tenant_id)
checks = self.manager.list_health_checks(self.tenant_id)
assert len(checks) >= 2
self.log(f"Listed {len(checks)} health checks")
# 执行健康检查(异步)
async def run_health_check():
result = await self.manager.execute_health_check(check1.id)
async def run_health_check() -> None:
result = await self.manager.execute_health_check(check1.id)
return result
# 由于健康检查需要网络,这里只验证方法存在
@@ -495,28 +495,28 @@ class TestOpsManager:
# 清理
with self.manager._get_db() as conn:
conn.execute("DELETE FROM health_checks WHERE tenant_id = ?", (self.tenant_id,))
conn.execute("DELETE FROM health_checks WHERE tenant_id = ?", (self.tenant_id, ))
conn.commit()
self.log("Cleaned up health check test data")
except Exception as e:
self.log(f"Health checks test failed: {e}", success=False)
self.log(f"Health checks test failed: {e}", success = False)
def test_failover(self):
def test_failover(self) -> None:
"""测试故障转移"""
print("\n🔄 Testing Failover...")
try:
# 创建故障转移配置
config = self.manager.create_failover_config(
tenant_id=self.tenant_id,
name="主备数据中心故障转移",
primary_region="cn-north-1",
secondary_regions=["cn-south-1", "cn-east-1"],
failover_trigger="health_check_failed",
auto_failover=False,
failover_timeout=300,
health_check_id=None,
config = self.manager.create_failover_config(
tenant_id = self.tenant_id,
name = "主备数据中心故障转移",
primary_region = "cn-north-1",
secondary_regions = ["cn-south-1", "cn-east-1"],
failover_trigger = "health_check_failed",
auto_failover = False,
failover_timeout = 300,
health_check_id = None,
)
self.log(f"Created failover config: {config.name} (ID: {config.id})")
@@ -524,13 +524,13 @@ class TestOpsManager:
self.log(f" Secondary regions: {config.secondary_regions}")
# 获取故障转移配置列表
configs = self.manager.list_failover_configs(self.tenant_id)
configs = self.manager.list_failover_configs(self.tenant_id)
assert len(configs) >= 1
self.log(f"Listed {len(configs)} failover configs")
# 发起故障转移
event = self.manager.initiate_failover(
config_id=config.id, reason="Primary region health check failed"
event = self.manager.initiate_failover(
config_id = config.id, reason = "Primary region health check failed"
)
if event:
@@ -540,41 +540,41 @@ class TestOpsManager:
# 更新故障转移状态
self.manager.update_failover_status(event.id, "completed")
updated_event = self.manager.get_failover_event(event.id)
updated_event = self.manager.get_failover_event(event.id)
assert updated_event.status == "completed"
self.log("Failover completed")
# 获取故障转移事件列表
events = self.manager.list_failover_events(self.tenant_id)
events = self.manager.list_failover_events(self.tenant_id)
self.log(f"Listed {len(events)} failover events")
# 清理
with self.manager._get_db() as conn:
conn.execute("DELETE FROM failover_events WHERE tenant_id = ?", (self.tenant_id,))
conn.execute("DELETE FROM failover_configs WHERE tenant_id = ?", (self.tenant_id,))
conn.execute("DELETE FROM failover_events WHERE tenant_id = ?", (self.tenant_id, ))
conn.execute("DELETE FROM failover_configs WHERE tenant_id = ?", (self.tenant_id, ))
conn.commit()
self.log("Cleaned up failover test data")
except Exception as e:
self.log(f"Failover test failed: {e}", success=False)
self.log(f"Failover test failed: {e}", success = False)
def test_backup(self):
def test_backup(self) -> None:
"""测试备份与恢复"""
print("\n💾 Testing Backup & Recovery...")
try:
# 创建备份任务
job = self.manager.create_backup_job(
tenant_id=self.tenant_id,
name="每日数据库备份",
backup_type="full",
target_type="database",
target_id="postgres-main",
schedule="0 2 * * *", # 每天凌晨2点
retention_days=30,
encryption_enabled=True,
compression_enabled=True,
storage_location="s3://insightflow-backups/",
job = self.manager.create_backup_job(
tenant_id = self.tenant_id,
name = "每日数据库备份",
backup_type = "full",
target_type = "database",
target_id = "postgres-main",
schedule = "0 2 * * *", # 每天凌晨2点
retention_days = 30,
encryption_enabled = True,
compression_enabled = True,
storage_location = "s3://insightflow-backups/",
)
self.log(f"Created backup job: {job.name} (ID: {job.id})")
@@ -582,12 +582,12 @@ class TestOpsManager:
self.log(f" Retention: {job.retention_days} days")
# 获取备份任务列表
jobs = self.manager.list_backup_jobs(self.tenant_id)
jobs = self.manager.list_backup_jobs(self.tenant_id)
assert len(jobs) >= 1
self.log(f"Listed {len(jobs)} backup jobs")
# 执行备份
record = self.manager.execute_backup(job.id)
record = self.manager.execute_backup(job.id)
if record:
self.log(f"Executed backup: {record.id}")
@@ -595,50 +595,50 @@ class TestOpsManager:
self.log(f" Storage: {record.storage_path}")
# 获取备份记录列表
records = self.manager.list_backup_records(self.tenant_id)
records = self.manager.list_backup_records(self.tenant_id)
self.log(f"Listed {len(records)} backup records")
# 测试恢复(模拟)
restore_result = self.manager.restore_from_backup(record.id)
restore_result = self.manager.restore_from_backup(record.id)
self.log(f"Restore test result: {restore_result}")
# 清理
with self.manager._get_db() as conn:
conn.execute("DELETE FROM backup_records WHERE tenant_id = ?", (self.tenant_id,))
conn.execute("DELETE FROM backup_jobs WHERE tenant_id = ?", (self.tenant_id,))
conn.execute("DELETE FROM backup_records WHERE tenant_id = ?", (self.tenant_id, ))
conn.execute("DELETE FROM backup_jobs WHERE tenant_id = ?", (self.tenant_id, ))
conn.commit()
self.log("Cleaned up backup test data")
except Exception as e:
self.log(f"Backup test failed: {e}", success=False)
self.log(f"Backup test failed: {e}", success = False)
def test_cost_optimization(self):
def test_cost_optimization(self) -> None:
"""测试成本优化"""
print("\n💰 Testing Cost Optimization...")
try:
# 记录资源利用率数据
report_date = datetime.now().strftime("%Y-%m-%d")
report_date = datetime.now().strftime("%Y-%m-%d")
for i in range(5):
self.manager.record_resource_utilization(
tenant_id=self.tenant_id,
resource_type=ResourceType.CPU,
resource_id=f"server-{i:03d}",
utilization_rate=0.05 + random.random() * 0.1, # 低利用率
peak_utilization=0.15,
avg_utilization=0.08,
idle_time_percent=0.85,
report_date=report_date,
recommendations=["Consider downsizing this resource"],
tenant_id = self.tenant_id,
resource_type = ResourceType.CPU,
resource_id = f"server-{i:03d}",
utilization_rate = 0.05 + random.random() * 0.1, # 低利用率
peak_utilization = 0.15,
avg_utilization = 0.08,
idle_time_percent = 0.85,
report_date = report_date,
recommendations = ["Consider downsizing this resource"],
)
self.log("Recorded 5 resource utilization records")
# 生成成本报告
now = datetime.now()
report = self.manager.generate_cost_report(
tenant_id=self.tenant_id, year=now.year, month=now.month
now = datetime.now()
report = self.manager.generate_cost_report(
tenant_id = self.tenant_id, year = now.year, month = now.month
)
self.log(f"Generated cost report: {report.id}")
@@ -647,11 +647,11 @@ class TestOpsManager:
self.log(f" Anomalies detected: {len(report.anomalies)}")
# 检测闲置资源
idle_resources = self.manager.detect_idle_resources(self.tenant_id)
idle_resources = self.manager.detect_idle_resources(self.tenant_id)
self.log(f"Detected {len(idle_resources)} idle resources")
# 获取闲置资源列表
idle_list = self.manager.get_idle_resources(self.tenant_id)
idle_list = self.manager.get_idle_resources(self.tenant_id)
for resource in idle_list:
self.log(
f" Idle resource: {resource.resource_name} (est. cost: {
@@ -660,7 +660,7 @@ class TestOpsManager:
)
# 生成成本优化建议
suggestions = self.manager.generate_cost_optimization_suggestions(self.tenant_id)
suggestions = self.manager.generate_cost_optimization_suggestions(self.tenant_id)
self.log(f"Generated {len(suggestions)} cost optimization suggestions")
for suggestion in suggestions:
@@ -672,12 +672,12 @@ class TestOpsManager:
self.log(f" Difficulty: {suggestion.difficulty}")
# 获取优化建议列表
all_suggestions = self.manager.get_cost_optimization_suggestions(self.tenant_id)
all_suggestions = self.manager.get_cost_optimization_suggestions(self.tenant_id)
self.log(f"Listed {len(all_suggestions)} optimization suggestions")
# 应用优化建议
if all_suggestions:
applied = self.manager.apply_cost_optimization_suggestion(all_suggestions[0].id)
applied = self.manager.apply_cost_optimization_suggestion(all_suggestions[0].id)
if applied:
self.log(f"Applied optimization suggestion: {applied.title}")
assert applied.is_applied
@@ -686,29 +686,29 @@ class TestOpsManager:
# 清理
with self.manager._get_db() as conn:
conn.execute(
"DELETE FROM cost_optimization_suggestions WHERE tenant_id = ?",
(self.tenant_id,),
"DELETE FROM cost_optimization_suggestions WHERE tenant_id = ?",
(self.tenant_id, ),
)
conn.execute("DELETE FROM idle_resources WHERE tenant_id = ?", (self.tenant_id,))
conn.execute("DELETE FROM idle_resources WHERE tenant_id = ?", (self.tenant_id, ))
conn.execute(
"DELETE FROM resource_utilizations WHERE tenant_id = ?", (self.tenant_id,)
"DELETE FROM resource_utilizations WHERE tenant_id = ?", (self.tenant_id, )
)
conn.execute("DELETE FROM cost_reports WHERE tenant_id = ?", (self.tenant_id,))
conn.execute("DELETE FROM cost_reports WHERE tenant_id = ?", (self.tenant_id, ))
conn.commit()
self.log("Cleaned up cost optimization test data")
except Exception as e:
self.log(f"Cost optimization test failed: {e}", success=False)
self.log(f"Cost optimization test failed: {e}", success = False)
def print_summary(self):
def print_summary(self) -> None:
"""打印测试总结"""
print("\n" + "=" * 60)
print("\n" + " = " * 60)
print("Test Summary")
print("=" * 60)
print(" = " * 60)
total = len(self.test_results)
passed = sum(1 for _, success in self.test_results if success)
failed = total - passed
total = len(self.test_results)
passed = sum(1 for _, success in self.test_results if success)
failed = total - passed
print(f"Total tests: {total}")
print(f"Passed: {passed}")
@@ -720,12 +720,12 @@ class TestOpsManager:
if not success:
print(f"{message}")
print("=" * 60)
print(" = " * 60)
def main():
def main() -> None:
"""主函数"""
test = TestOpsManager()
test = TestOpsManager()
test.run_all_tests()

View File

@@ -10,19 +10,19 @@ from typing import Any
class TingwuClient:
def __init__(self):
self.access_key = os.getenv("ALI_ACCESS_KEY", "")
self.secret_key = os.getenv("ALI_SECRET_KEY", "")
self.endpoint = "https://tingwu.cn-beijing.aliyuncs.com"
def __init__(self) -> None:
self.access_key = os.getenv("ALI_ACCESS_KEY", "")
self.secret_key = os.getenv("ALI_SECRET_KEY", "")
self.endpoint = "https://tingwu.cn-beijing.aliyuncs.com"
if not self.access_key or not self.secret_key:
raise ValueError("ALI_ACCESS_KEY and ALI_SECRET_KEY required")
def _sign_request(
self, method: str, uri: str, query: str = "", body: str = ""
self, method: str, uri: str, query: str = "", body: str = ""
) -> dict[str, str]:
"""阿里云签名 V3"""
timestamp = datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%SZ")
timestamp = datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%SZ")
# 简化签名,实际生产需要完整实现
# 这里使用基础认证头
@@ -31,10 +31,10 @@ class TingwuClient:
"x-acs-action": "CreateTask",
"x-acs-version": "2023-09-30",
"x-acs-date": timestamp,
"Authorization": f"ACS3-HMAC-SHA256 Credential={self.access_key}/acs/tingwu/cn-beijing",
"Authorization": f"ACS3-HMAC-SHA256 Credential = {self.access_key}/acs/tingwu/cn-beijing",
}
def create_task(self, audio_url: str, language: str = "zh") -> str:
def create_task(self, audio_url: str, language: str = "zh") -> str:
"""创建听悟任务"""
try:
# 导入移到文件顶部会导致循环导入,保持在这里
@@ -42,23 +42,23 @@ class TingwuClient:
from alibabacloud_tingwu20230930 import models as tingwu_models
from alibabacloud_tingwu20230930.client import Client as TingwuSDKClient
config = open_api_models.Config(
access_key_id=self.access_key, access_key_secret=self.secret_key
config = open_api_models.Config(
access_key_id = self.access_key, access_key_secret = self.secret_key
)
config.endpoint = "tingwu.cn-beijing.aliyuncs.com"
client = TingwuSDKClient(config)
config.endpoint = "tingwu.cn-beijing.aliyuncs.com"
client = TingwuSDKClient(config)
request = tingwu_models.CreateTaskRequest(
type="offline",
input=tingwu_models.Input(source="OSS", file_url=audio_url),
parameters=tingwu_models.Parameters(
transcription=tingwu_models.Transcription(
diarization_enabled=True, sentence_max_length=20
request = tingwu_models.CreateTaskRequest(
type = "offline",
input = tingwu_models.Input(source = "OSS", file_url = audio_url),
parameters = tingwu_models.Parameters(
transcription = tingwu_models.Transcription(
diarization_enabled = True, sentence_max_length = 20
)
),
)
response = client.create_task(request)
response = client.create_task(request)
if response.body.code == "0":
return response.body.data.task_id
else:
@@ -73,29 +73,26 @@ class TingwuClient:
return f"mock_task_{int(time.time())}"
def get_task_result(
self, task_id: str, max_retries: int = 60, interval: int = 5
self, task_id: str, max_retries: int = 60, interval: int = 5
) -> dict[str, Any]:
"""获取任务结果"""
try:
# 导入移到文件顶部会导致循环导入,保持在这里
from alibabacloud_tea_openapi import models as open_api_models
from alibabacloud_tingwu20230930 import models as tingwu_models
from alibabacloud_tingwu20230930.client import Client as TingwuSDKClient
config = open_api_models.Config(
access_key_id=self.access_key, access_key_secret=self.secret_key
config = open_api_models.Config(
access_key_id = self.access_key, access_key_secret = self.secret_key
)
config.endpoint = "tingwu.cn-beijing.aliyuncs.com"
client = TingwuSDKClient(config)
config.endpoint = "tingwu.cn-beijing.aliyuncs.com"
client = TingwuSDKClient(config)
for i in range(max_retries):
request = tingwu_models.GetTaskInfoRequest()
response = client.get_task_info(task_id, request)
request = tingwu_models.GetTaskInfoRequest()
response = client.get_task_info(task_id, request)
if response.body.code != "0":
raise RuntimeError(f"Query failed: {response.body.message}")
status = response.body.data.task_status
status = response.body.data.task_status
if status == "SUCCESS":
return self._parse_result(response.body.data)
@@ -116,11 +113,11 @@ class TingwuClient:
def _parse_result(self, data) -> dict[str, Any]:
"""解析结果"""
result = data.result
transcription = result.transcription
result = data.result
transcription = result.transcription
full_text = ""
segments = []
full_text = ""
segments = []
if transcription.paragraphs:
for para in transcription.paragraphs:
@@ -153,8 +150,8 @@ class TingwuClient:
],
}
def transcribe(self, audio_url: str, language: str = "zh") -> dict[str, Any]:
def transcribe(self, audio_url: str, language: str = "zh") -> dict[str, Any]:
"""一键转录"""
task_id = self.create_task(audio_url, language)
task_id = self.create_task(audio_url, language)
print(f"Tingwu task: {task_id}")
return self.get_task_result(task_id)

File diff suppressed because it is too large Load Diff

View File

@@ -11,10 +11,10 @@ from pathlib import Path
from typing import Any
# 项目路径
PROJECT_PATH = Path("/root/.openclaw/workspace/projects/insightflow")
PROJECT_PATH = Path("/root/.openclaw/workspace/projects/insightflow")
# 修复报告
report = {
report = {
"fixed": [],
"manual_review": [],
"errors": []
@@ -22,7 +22,7 @@ report = {
def find_python_files() -> list[Path]:
"""查找所有 Python 文件"""
py_files = []
py_files = []
for py_file in PROJECT_PATH.rglob("*.py"):
if "__pycache__" not in str(py_file):
py_files.append(py_file)
@@ -30,12 +30,12 @@ def find_python_files() -> list[Path]:
def check_duplicate_imports(content: str, file_path: Path) -> list[dict]:
"""检查重复导入"""
issues = []
lines = content.split('\n')
imports = {}
issues = []
lines = content.split('\n')
imports = {}
for i, line in enumerate(lines, 1):
line_stripped = line.strip()
line_stripped = line.strip()
if line_stripped.startswith('import ') or line_stripped.startswith('from '):
if line_stripped in imports:
issues.append({
@@ -45,17 +45,17 @@ def check_duplicate_imports(content: str, file_path: Path) -> list[dict]:
"original_line": imports[line_stripped]
})
else:
imports[line_stripped] = i
imports[line_stripped] = i
return issues
def check_bare_excepts(content: str, file_path: Path) -> list[dict]:
"""检查裸异常捕获"""
issues = []
lines = content.split('\n')
issues = []
lines = content.split('\n')
for i, line in enumerate(lines, 1):
stripped = line.strip()
# 检查 except Exception: 或 except :
stripped = line.strip()
# 检查 except Exception: 或 except Exception:
if re.match(r'^except\s*:', stripped):
issues.append({
"line": i,
@@ -66,9 +66,9 @@ def check_bare_excepts(content: str, file_path: Path) -> list[dict]:
def check_line_length(content: str, file_path: Path) -> list[dict]:
"""检查行长度PEP8: 79字符这里放宽到 100"""
issues = []
lines = content.split('\n')
issues = []
lines = content.split('\n')
for i, line in enumerate(lines, 1):
if len(line) > 100:
issues.append({
@@ -81,24 +81,24 @@ def check_line_length(content: str, file_path: Path) -> list[dict]:
def check_unused_imports(content: str, file_path: Path) -> list[dict]:
"""检查未使用的导入"""
issues = []
issues = []
try:
tree = ast.parse(content)
imports = {}
used_names = set()
tree = ast.parse(content)
imports = {}
used_names = set()
for node in ast.walk(tree):
if isinstance(node, ast.Import):
for alias in node.names:
imports[alias.asname or alias.name] = node
imports[alias.asname or alias.name] = node
elif isinstance(node, ast.ImportFrom):
for alias in node.names:
name = alias.asname or alias.name
name = alias.asname or alias.name
if name != '*':
imports[name] = node
imports[name] = node
elif isinstance(node, ast.Name):
used_names.add(node.id)
for name, node in imports.items():
if name not in used_names and not name.startswith('_'):
issues.append({
@@ -112,9 +112,9 @@ def check_unused_imports(content: str, file_path: Path) -> list[dict]:
def check_string_formatting(content: str, file_path: Path) -> list[dict]:
"""检查混合字符串格式化(建议使用 f-string"""
issues = []
lines = content.split('\n')
issues = []
lines = content.split('\n')
for i, line in enumerate(lines, 1):
# 检查 % 格式化
if re.search(r'["\'].*%\s*\w+', line) and '%' in line:
@@ -136,18 +136,18 @@ def check_string_formatting(content: str, file_path: Path) -> list[dict]:
def check_magic_numbers(content: str, file_path: Path) -> list[dict]:
"""检查魔法数字"""
issues = []
lines = content.split('\n')
issues = []
lines = content.split('\n')
# 常见魔法数字模式(排除常见索引和简单值)
magic_pattern = re.compile(r'(?<![\w\d_])(\d{3,})(?![\w\d_])')
magic_pattern = re.compile(r'(?<![\w\d_])(\d{3, })(?![\w\d_])')
for i, line in enumerate(lines, 1):
if line.strip().startswith('#'):
continue
matches = magic_pattern.findall(line)
matches = magic_pattern.findall(line)
for match in matches:
num = int(match)
num = int(match)
# 排除常见值
if num not in [200, 201, 204, 301, 302, 400, 401, 403, 404, 429, 500, 502, 503, 3600, 86400]:
issues.append({
@@ -160,9 +160,9 @@ def check_magic_numbers(content: str, file_path: Path) -> list[dict]:
def check_sql_injection(content: str, file_path: Path) -> list[dict]:
"""检查 SQL 注入风险"""
issues = []
lines = content.split('\n')
issues = []
lines = content.split('\n')
for i, line in enumerate(lines, 1):
# 检查字符串拼接的 SQL
if 'execute(' in line or 'executescript(' in line or 'executemany(' in line:
@@ -179,9 +179,9 @@ def check_sql_injection(content: str, file_path: Path) -> list[dict]:
def check_cors_config(content: str, file_path: Path) -> list[dict]:
"""检查 CORS 配置"""
issues = []
lines = content.split('\n')
issues = []
lines = content.split('\n')
for i, line in enumerate(lines, 1):
if 'allow_origins' in line and '["*"]' in line:
issues.append({
@@ -194,48 +194,48 @@ def check_cors_config(content: str, file_path: Path) -> list[dict]:
def fix_bare_excepts(content: str) -> str:
"""修复裸异常捕获"""
lines = content.split('\n')
new_lines = []
lines = content.split('\n')
new_lines = []
for line in lines:
stripped = line.strip()
stripped = line.strip()
if re.match(r'^except\s*:', stripped):
# 替换为具体异常
indent = len(line) - len(line.lstrip())
new_line = ' ' * indent + 'except (RuntimeError, ValueError, TypeError):'
indent = len(line) - len(line.lstrip())
new_line = ' ' * indent + 'except (RuntimeError, ValueError, TypeError):'
new_lines.append(new_line)
else:
new_lines.append(line)
return '\n'.join(new_lines)
def fix_line_length(content: str) -> str:
"""修复行长度问题(简单折行)"""
lines = content.split('\n')
new_lines = []
lines = content.split('\n')
new_lines = []
for line in lines:
if len(line) > 100:
# 尝试在逗号或运算符处折行
if ',' in line[80:]:
if ', ' in line[80:]:
# 简单处理:截断并添加续行
indent = len(line) - len(line.lstrip())
indent = len(line) - len(line.lstrip())
new_lines.append(line)
else:
new_lines.append(line)
else:
new_lines.append(line)
return '\n'.join(new_lines)
def analyze_file(file_path: Path) -> dict:
"""分析单个文件"""
try:
content = file_path.read_text(encoding='utf-8')
content = file_path.read_text(encoding = 'utf-8')
except Exception as e:
return {"error": str(e)}
issues = {
issues = {
"duplicate_imports": check_duplicate_imports(content, file_path),
"bare_excepts": check_bare_excepts(content, file_path),
"line_length": check_line_length(content, file_path),
@@ -245,22 +245,22 @@ def analyze_file(file_path: Path) -> dict:
"sql_injection": check_sql_injection(content, file_path),
"cors_config": check_cors_config(content, file_path),
}
return issues
def fix_file(file_path: Path, issues: dict) -> bool:
"""自动修复文件问题"""
try:
content = file_path.read_text(encoding='utf-8')
original_content = content
content = file_path.read_text(encoding = 'utf-8')
original_content = content
# 修复裸异常
if issues.get("bare_excepts"):
content = fix_bare_excepts(content)
content = fix_bare_excepts(content)
# 如果有修改,写回文件
if content != original_content:
file_path.write_text(content, encoding='utf-8')
file_path.write_text(content, encoding = 'utf-8')
return True
return False
except Exception as e:
@@ -269,88 +269,88 @@ def fix_file(file_path: Path, issues: dict) -> bool:
def generate_report(all_issues: dict) -> str:
"""生成修复报告"""
lines = []
lines = []
lines.append("# InsightFlow 代码审查报告")
lines.append(f"\n生成时间: {__import__('datetime').datetime.now().isoformat()}")
lines.append("\n## 自动修复的问题\n")
total_fixed = 0
total_fixed = 0
for file_path, issues in all_issues.items():
fixed_count = 0
fixed_count = 0
for issue_type, issue_list in issues.items():
if issue_type in ["bare_excepts"] and issue_list:
fixed_count += len(issue_list)
if fixed_count > 0:
lines.append(f"### {file_path}")
lines.append(f"- 修复裸异常捕获: {fixed_count}")
total_fixed += fixed_count
if total_fixed == 0:
lines.append("未发现需要自动修复的问题。")
lines.append(f"\n**总计自动修复: {total_fixed} 处**")
lines.append("\n## 需要人工确认的问题\n")
total_manual = 0
total_manual = 0
for file_path, issues in all_issues.items():
manual_issues = []
manual_issues = []
if issues.get("sql_injection"):
manual_issues.extend(issues["sql_injection"])
if issues.get("cors_config"):
manual_issues.extend(issues["cors_config"])
if manual_issues:
lines.append(f"### {file_path}")
for issue in manual_issues:
lines.append(f"- **{issue['type']}** (第 {issue['line']} 行): {issue.get('content', '')}")
total_manual += len(manual_issues)
if total_manual == 0:
lines.append("未发现需要人工确认的问题。")
lines.append(f"\n**总计待确认: {total_manual} 处**")
lines.append("\n## 代码风格建议\n")
for file_path, issues in all_issues.items():
style_issues = []
style_issues = []
if issues.get("line_length"):
style_issues.extend(issues["line_length"])
if issues.get("string_formatting"):
style_issues.extend(issues["string_formatting"])
if issues.get("magic_numbers"):
style_issues.extend(issues["magic_numbers"])
if style_issues:
lines.append(f"### {file_path}")
for issue in style_issues[:5]: # 只显示前5个
lines.append(f"- 第 {issue['line']} 行: {issue['type']}")
if len(style_issues) > 5:
lines.append(f"- ... 还有 {len(style_issues) - 5} 个类似问题")
return '\n'.join(lines)
def git_commit_and_push():
def git_commit_and_push() -> None:
"""提交并推送代码"""
try:
os.chdir(PROJECT_PATH)
# 检查是否有修改
result = subprocess.run(
result = subprocess.run(
["git", "status", "--porcelain"],
capture_output=True,
text=True
capture_output = True,
text = True
)
if not result.stdout.strip():
return "没有需要提交的更改"
# 添加所有修改
subprocess.run(["git", "add", "-A"], check=True)
subprocess.run(["git", "add", "-A"], check = True)
# 提交
subprocess.run(
["git", "commit", "-m", """fix: auto-fix code issues (cron)
@@ -359,52 +359,52 @@ def git_commit_and_push():
- 修复异常处理
- 修复PEP8格式问题
- 添加类型注解"""],
check=True
check = True
)
# 推送
subprocess.run(["git", "push"], check=True)
subprocess.run(["git", "push"], check = True)
return "✅ 提交并推送成功"
except subprocess.CalledProcessError as e:
return f"❌ Git 操作失败: {e}"
except Exception as e:
return f"❌ 错误: {e}"
def main():
def main() -> None:
"""主函数"""
print("🔍 开始代码审查...")
py_files = find_python_files()
py_files = find_python_files()
print(f"📁 找到 {len(py_files)} 个 Python 文件")
all_issues = {}
all_issues = {}
for py_file in py_files:
print(f" 分析: {py_file.name}")
issues = analyze_file(py_file)
all_issues[py_file] = issues
issues = analyze_file(py_file)
all_issues[py_file] = issues
# 自动修复
if fix_file(py_file, issues):
report["fixed"].append(str(py_file))
# 生成报告
report_content = generate_report(all_issues)
report_path = PROJECT_PATH / "AUTO_CODE_REVIEW_REPORT.md"
report_path.write_text(report_content, encoding='utf-8')
report_content = generate_report(all_issues)
report_path = PROJECT_PATH / "AUTO_CODE_REVIEW_REPORT.md"
report_path.write_text(report_content, encoding = 'utf-8')
print("\n📄 报告已生成:", report_path)
# Git 提交
print("\n🚀 提交代码...")
git_result = git_commit_and_push()
git_result = git_commit_and_push()
print(git_result)
# 追加提交结果到报告
with open(report_path, 'a', encoding='utf-8') as f:
with open(report_path, 'a', encoding = 'utf-8') as f:
f.write(f"\n\n## Git 提交结果\n\n{git_result}\n")
print("\n✅ 代码审查完成!")
return report_content

View File

@@ -15,25 +15,25 @@ class CodeIssue:
line_no: int,
issue_type: str,
message: str,
severity: str = "info",
):
self.file_path = file_path
self.line_no = line_no
self.issue_type = issue_type
self.message = message
self.severity = severity # info, warning, error
self.fixed = False
severity: str = "info",
) -> None:
self.file_path = file_path
self.line_no = line_no
self.issue_type = issue_type
self.message = message
self.severity = severity # info, warning, error
self.fixed = False
def __repr__(self):
def __repr__(self) -> None:
return f"{self.severity.upper()}: {self.file_path}:{self.line_no} - {self.issue_type}: {self.message}"
class CodeReviewer:
def __init__(self, base_path: str):
self.base_path = Path(base_path)
self.issues: list[CodeIssue] = []
self.fixed_issues: list[CodeIssue] = []
self.manual_review_issues: list[CodeIssue] = []
def __init__(self, base_path: str) -> None:
self.base_path = Path(base_path)
self.issues: list[CodeIssue] = []
self.fixed_issues: list[CodeIssue] = []
self.manual_review_issues: list[CodeIssue] = []
def scan_all(self) -> None:
"""扫描所有 Python 文件"""
@@ -45,14 +45,14 @@ class CodeReviewer:
def scan_file(self, file_path: Path) -> None:
"""扫描单个文件"""
try:
with open(file_path, "r", encoding="utf-8") as f:
content = f.read()
lines = content.split("\n")
with open(file_path, "r", encoding = "utf-8") as f:
content = f.read()
lines = content.split("\n")
except Exception as e:
print(f"Error reading {file_path}: {e}")
return
rel_path = str(file_path.relative_to(self.base_path))
rel_path = str(file_path.relative_to(self.base_path))
# 1. 检查裸异常捕获
self._check_bare_exceptions(content, lines, rel_path)
@@ -92,7 +92,7 @@ class CodeReviewer:
# 跳过有注释说明的情况
if "# noqa" in line or "# intentional" in line.lower():
continue
issue = CodeIssue(
issue = CodeIssue(
file_path,
i,
"bare_exception",
@@ -105,17 +105,17 @@ class CodeReviewer:
self, content: str, lines: list[str], file_path: str
) -> None:
"""检查重复导入"""
imports = {}
imports = {}
for i, line in enumerate(lines, 1):
match = re.match(r"^(?:from\s+(\S+)\s+)?import\s+(.+)$", line.strip())
match = re.match(r"^(?:from\s+(\S+)\s+)?import\s+(.+)$", line.strip())
if match:
module = match.group(1) or ""
names = match.group(2).split(",")
module = match.group(1) or ""
names = match.group(2).split(", ")
for name in names:
name = name.strip().split()[0] # 处理 'as' 别名
key = f"{module}.{name}" if module else name
name = name.strip().split()[0] # 处理 'as' 别名
key = f"{module}.{name}" if module else name
if key in imports:
issue = CodeIssue(
issue = CodeIssue(
file_path,
i,
"duplicate_import",
@@ -123,7 +123,7 @@ class CodeReviewer:
"warning",
)
self.issues.append(issue)
imports[key] = i
imports[key] = i
def _check_pep8_issues(
self, content: str, lines: list[str], file_path: str
@@ -132,7 +132,7 @@ class CodeReviewer:
for i, line in enumerate(lines, 1):
# 行长度超过 120
if len(line) > 120:
issue = CodeIssue(
issue = CodeIssue(
file_path,
i,
"line_too_long",
@@ -143,7 +143,7 @@ class CodeReviewer:
# 行尾空格
if line.rstrip() != line:
issue = CodeIssue(
issue = CodeIssue(
file_path, i, "trailing_whitespace", "行尾有空格", "info"
)
self.issues.append(issue)
@@ -151,7 +151,7 @@ class CodeReviewer:
# 多余的空行
if i > 1 and line.strip() == "" and lines[i - 2].strip() == "":
if i < len(lines) and lines[i].strip() == "":
issue = CodeIssue(
issue = CodeIssue(
file_path, i, "extra_blank_line", "多余的空行", "info"
)
self.issues.append(issue)
@@ -161,23 +161,23 @@ class CodeReviewer:
) -> None:
"""检查未使用的导入"""
try:
tree = ast.parse(content)
tree = ast.parse(content)
except SyntaxError:
return
imported_names = {}
used_names = set()
imported_names = {}
used_names = set()
for node in ast.walk(tree):
if isinstance(node, ast.Import):
for alias in node.names:
name = alias.asname if alias.asname else alias.name
imported_names[name] = node.lineno
name = alias.asname if alias.asname else alias.name
imported_names[name] = node.lineno
elif isinstance(node, ast.ImportFrom):
for alias in node.names:
name = alias.asname if alias.asname else alias.name
name = alias.asname if alias.asname else alias.name
if name != "*":
imported_names[name] = node.lineno
imported_names[name] = node.lineno
elif isinstance(node, ast.Name):
used_names.add(node.id)
@@ -186,7 +186,7 @@ class CodeReviewer:
# 排除一些常见例外
if name in ["annotations", "TYPE_CHECKING"]:
continue
issue = CodeIssue(
issue = CodeIssue(
file_path, lineno, "unused_import", f"未使用的导入: {name}", "info"
)
self.issues.append(issue)
@@ -195,20 +195,20 @@ class CodeReviewer:
self, content: str, lines: list[str], file_path: str
) -> None:
"""检查混合字符串格式化"""
has_fstring = False
has_percent = False
has_format = False
has_fstring = False
has_percent = False
has_format = False
for i, line in enumerate(lines, 1):
if re.search(r'f["\']', line):
has_fstring = True
has_fstring = True
if re.search(r"%[sdfr]", line) and not re.search(r"\d+%", line):
has_percent = True
has_percent = True
if ".format(" in line:
has_format = True
has_format = True
if has_fstring and (has_percent or has_format):
issue = CodeIssue(
issue = CodeIssue(
file_path,
0,
"mixed_formatting",
@@ -222,25 +222,25 @@ class CodeReviewer:
) -> None:
"""检查魔法数字"""
# 常见的魔法数字模式
magic_patterns = [
(r"=\s*(\d{3,})\s*[^:]", "可能的魔法数字"),
(r"timeout\s*=\s*(\d+)", "timeout 魔法数字"),
(r"limit\s*=\s*(\d+)", "limit 魔法数字"),
(r"port\s*=\s*(\d+)", "port 魔法数字"),
magic_patterns = [
(r" = \s*(\d{3, })\s*[^:]", "可能的魔法数字"),
(r"timeout\s* = \s*(\d+)", "timeout 魔法数字"),
(r"limit\s* = \s*(\d+)", "limit 魔法数字"),
(r"port\s* = \s*(\d+)", "port 魔法数字"),
]
for i, line in enumerate(lines, 1):
# 跳过注释和字符串
code_part = line.split("#")[0]
code_part = line.split("#")[0]
if not code_part.strip():
continue
for pattern, msg in magic_patterns:
if re.search(pattern, code_part, re.IGNORECASE):
# 排除常见的合理数字
match = re.search(r"(\d{3,})", code_part)
match = re.search(r"(\d{3, })", code_part)
if match:
num = int(match.group(1))
num = int(match.group(1))
if num in [
200,
404,
@@ -257,7 +257,7 @@ class CodeReviewer:
8000,
]:
continue
issue = CodeIssue(
issue = CodeIssue(
file_path, i, "magic_number", f"{msg}: {num}", "info"
)
self.issues.append(issue)
@@ -272,7 +272,7 @@ class CodeReviewer:
r'execute\s*\(\s*f["\']', line
):
if "?" not in line and "%s" in line:
issue = CodeIssue(
issue = CodeIssue(
file_path,
i,
"sql_injection_risk",
@@ -287,7 +287,7 @@ class CodeReviewer:
"""检查 CORS 配置"""
for i, line in enumerate(lines, 1):
if "allow_origins" in line and '["*"]' in line:
issue = CodeIssue(
issue = CodeIssue(
file_path,
i,
"cors_wildcard",
@@ -303,7 +303,7 @@ class CodeReviewer:
for i, line in enumerate(lines, 1):
# 检查硬编码密钥
if re.search(
r'(password|secret|key|token)\s*=\s*["\'][^"\']+["\']',
r'(password|secret|key|token)\s* = \s*["\'][^"\']+["\']',
line,
re.IGNORECASE,
):
@@ -316,7 +316,7 @@ class CodeReviewer:
if not re.search(r'["\']\*+["\']', line) and not re.search(
r'["\']<[^"\']*>["\']', line
):
issue = CodeIssue(
issue = CodeIssue(
file_path,
i,
"hardcoded_secret",
@@ -328,62 +328,62 @@ class CodeReviewer:
def auto_fix(self) -> None:
"""自动修复问题"""
# 按文件分组问题
issues_by_file: dict[str, list[CodeIssue]] = {}
issues_by_file: dict[str, list[CodeIssue]] = {}
for issue in self.issues:
if issue.file_path not in issues_by_file:
issues_by_file[issue.file_path] = []
issues_by_file[issue.file_path] = []
issues_by_file[issue.file_path].append(issue)
for file_path, issues in issues_by_file.items():
full_path = self.base_path / file_path
full_path = self.base_path / file_path
if not full_path.exists():
continue
try:
with open(full_path, "r", encoding="utf-8") as f:
content = f.read()
lines = content.split("\n")
with open(full_path, "r", encoding = "utf-8") as f:
content = f.read()
lines = content.split("\n")
except Exception as e:
print(f"Error reading {full_path}: {e}")
continue
original_lines = lines.copy()
original_lines = lines.copy()
# 修复行尾空格
for issue in issues:
if issue.issue_type == "trailing_whitespace":
idx = issue.line_no - 1
idx = issue.line_no - 1
if 0 <= idx < len(lines):
lines[idx] = lines[idx].rstrip()
issue.fixed = True
lines[idx] = lines[idx].rstrip()
issue.fixed = True
# 修复裸异常
for issue in issues:
if issue.issue_type == "bare_exception":
idx = issue.line_no - 1
idx = issue.line_no - 1
if 0 <= idx < len(lines):
line = lines[idx]
# 将 except: 改为 except Exception:
line = lines[idx]
# 将 except Exception: 改为 except Exception:
if re.search(r"except\s*:\s*$", line.strip()):
lines[idx] = line.replace("except:", "except Exception:")
issue.fixed = True
lines[idx] = line.replace("except Exception:", "except Exception:")
issue.fixed = True
elif re.search(r"except\s+Exception\s*:\s*$", line.strip()):
# 已经是 Exception但可能需要更具体
pass
# 如果文件有修改,写回
if lines != original_lines:
with open(full_path, "w", encoding="utf-8") as f:
with open(full_path, "w", encoding = "utf-8") as f:
f.write("\n".join(lines))
print(f"Fixed issues in {file_path}")
# 移动到已修复列表
self.fixed_issues = [i for i in self.issues if i.fixed]
self.issues = [i for i in self.issues if not i.fixed]
self.fixed_issues = [i for i in self.issues if i.fixed]
self.issues = [i for i in self.issues if not i.fixed]
def generate_report(self) -> str:
"""生成审查报告"""
report = []
report = []
report.append("# InsightFlow 代码审查报告")
report.append(f"\n扫描路径: {self.base_path}")
report.append(f"扫描时间: {__import__('datetime').datetime.now().isoformat()}")
@@ -421,9 +421,9 @@ class CodeReviewer:
return "\n".join(report)
def main():
base_path = "/root/.openclaw/workspace/projects/insightflow/backend"
reviewer = CodeReviewer(base_path)
def main() -> None:
base_path = "/root/.openclaw/workspace/projects/insightflow/backend"
reviewer = CodeReviewer(base_path)
print("开始扫描代码...")
reviewer.scan_all()
@@ -437,9 +437,9 @@ def main():
print(f"\n已修复 {len(reviewer.fixed_issues)} 个问题")
# 生成报告
report = reviewer.generate_report()
report_path = Path(base_path).parent / "CODE_REVIEW_REPORT.md"
with open(report_path, "w", encoding="utf-8") as f:
report = reviewer.generate_report()
report_path = Path(base_path).parent / "CODE_REVIEW_REPORT.md"
with open(report_path, "w", encoding = "utf-8") as f:
f.write(report)
print(f"\n报告已保存到: {report_path}")