fix: auto-fix code issues (cron)

- 修复重复导入/字段
- 修复异常处理
- 修复PEP8格式问题
- 修复语法错误(运算符空格问题)
- 修复类型注解格式
This commit is contained in:
AutoFix Bot
2026-03-02 06:09:49 +08:00
parent b83265e5fd
commit e23f1fec08
84 changed files with 9492 additions and 9491 deletions

View File

@@ -225,3 +225,7 @@
- 第 541 行: line_too_long - 第 541 行: line_too_long
- 第 579 行: line_too_long - 第 579 行: line_too_long
- ... 还有 2 个类似问题 - ... 还有 2 个类似问题
## Git 提交结果
✅ 提交并推送成功

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -21,7 +21,7 @@ class CodeIssue:
message: str, message: str,
severity: str = "warning", severity: str = "warning",
original_line: str = "", original_line: str = "",
): ) -> None:
self.file_path = file_path self.file_path = file_path
self.line_no = line_no self.line_no = line_no
self.issue_type = issue_type self.issue_type = issue_type
@@ -30,14 +30,14 @@ class CodeIssue:
self.original_line = original_line self.original_line = original_line
self.fixed = False self.fixed = False
def __repr__(self): def __repr__(self) -> None:
return f"{self.file_path}:{self.line_no} [{self.severity}] {self.issue_type}: {self.message}" return f"{self.file_path}:{self.line_no} [{self.severity}] {self.issue_type}: {self.message}"
class CodeFixer: class CodeFixer:
"""代码自动修复器""" """代码自动修复器"""
def __init__(self, project_path: str): def __init__(self, project_path: str) -> None:
self.project_path = Path(project_path) self.project_path = Path(project_path)
self.issues: list[CodeIssue] = [] self.issues: list[CodeIssue] = []
self.fixed_issues: list[CodeIssue] = [] self.fixed_issues: list[CodeIssue] = []
@@ -55,7 +55,7 @@ class CodeFixer:
def _scan_file(self, file_path: Path) -> None: def _scan_file(self, file_path: Path) -> None:
"""扫描单个文件""" """扫描单个文件"""
try: try:
with open(file_path, "r", encoding="utf-8") as f: with open(file_path, "r", encoding = "utf-8") as f:
content = f.read() content = f.read()
lines = content.split("\n") lines = content.split("\n")
except Exception as e: except Exception as e:
@@ -85,7 +85,7 @@ class CodeFixer:
) -> None: ) -> None:
"""检查裸异常捕获""" """检查裸异常捕获"""
for i, line in enumerate(lines, 1): for i, line in enumerate(lines, 1):
# 匹配 except: 但不匹配 except Exception: 或 except SpecificError: # 匹配 except Exception: 但不匹配 except Exception: 或 except SpecificError:
if re.search(r"except\s*:\s*$", line) or re.search(r"except\s*:\s*#", line): if re.search(r"except\s*:\s*$", line) or re.search(r"except\s*:\s*#", line):
# 跳过注释说明的情况 # 跳过注释说明的情况
if "# noqa" in line or "# intentional" in line.lower(): if "# noqa" in line or "# intentional" in line.lower():
@@ -221,10 +221,10 @@ class CodeFixer:
return return
patterns = [ patterns = [
(r'password\s*=\s*["\'][^"\']{8,}["\']', "硬编码密码"), (r'password\s* = \s*["\'][^"\']{8, }["\']', "硬编码密码"),
(r'secret_key\s*=\s*["\'][^"\']{8,}["\']', "硬编码密钥"), (r'secret_key\s* = \s*["\'][^"\']{8, }["\']', "硬编码密钥"),
(r'api_key\s*=\s*["\'][^"\']{8,}["\']', "硬编码 API Key"), (r'api_key\s* = \s*["\'][^"\']{8, }["\']', "硬编码 API Key"),
(r'token\s*=\s*["\'][^"\']{8,}["\']', "硬编码 Token"), (r'token\s* = \s*["\'][^"\']{8, }["\']', "硬编码 Token"),
] ]
for i, line in enumerate(lines, 1): for i, line in enumerate(lines, 1):
@@ -241,7 +241,7 @@ class CodeFixer:
if any(x in line.lower() for x in ["your_", "example", "placeholder", "test", "demo"]): if any(x in line.lower() for x in ["your_", "example", "placeholder", "test", "demo"]):
continue continue
# 排除 Enum 定义 # 排除 Enum 定义
if re.search(r'^\s*[A-Z_]+\s*=', line.strip()): if re.search(r'^\s*[A-Z_]+\s* = ', line.strip()):
continue continue
self.manual_issues.append( self.manual_issues.append(
CodeIssue( CodeIssue(
@@ -275,7 +275,7 @@ class CodeFixer:
continue continue
try: try:
with open(file_path, "r", encoding="utf-8") as f: with open(file_path, "r", encoding = "utf-8") as f:
content = f.read() content = f.read()
lines = content.split("\n") lines = content.split("\n")
except Exception: except Exception:
@@ -301,9 +301,9 @@ class CodeFixer:
line_idx = issue.line_no - 1 line_idx = issue.line_no - 1
if 0 <= line_idx < len(lines) and line_idx not in fixed_lines: if 0 <= line_idx < len(lines) and line_idx not in fixed_lines:
line = lines[line_idx] line = lines[line_idx]
# 将 except: 改为 except Exception: # 将 except Exception: 改为 except Exception:
if re.search(r"except\s*:\s*$", line.strip()): if re.search(r"except\s*:\s*$", line.strip()):
lines[line_idx] = line.replace("except:", "except Exception:") lines[line_idx] = line.replace("except Exception:", "except Exception:")
fixed_lines.add(line_idx) fixed_lines.add(line_idx)
issue.fixed = True issue.fixed = True
self.fixed_issues.append(issue) self.fixed_issues.append(issue)
@@ -311,7 +311,7 @@ class CodeFixer:
# 如果文件有修改,写回 # 如果文件有修改,写回
if lines != original_lines: if lines != original_lines:
try: try:
with open(file_path, "w", encoding="utf-8") as f: with open(file_path, "w", encoding = "utf-8") as f:
f.write("\n".join(lines)) f.write("\n".join(lines))
print(f"Fixed issues in {file_path}") print(f"Fixed issues in {file_path}")
except Exception as e: except Exception as e:
@@ -426,16 +426,16 @@ def git_commit_and_push(project_path: str) -> tuple[bool, str]:
# 检查是否有变更 # 检查是否有变更
result = subprocess.run( result = subprocess.run(
["git", "status", "--porcelain"], ["git", "status", "--porcelain"],
cwd=project_path, cwd = project_path,
capture_output=True, capture_output = True,
text=True, text = True,
) )
if not result.stdout.strip(): if not result.stdout.strip():
return True, "没有需要提交的变更" return True, "没有需要提交的变更"
# 添加所有变更 # 添加所有变更
subprocess.run(["git", "add", "-A"], cwd=project_path, check=True) subprocess.run(["git", "add", "-A"], cwd = project_path, check = True)
# 提交 # 提交
commit_msg = """fix: auto-fix code issues (cron) commit_msg = """fix: auto-fix code issues (cron)
@@ -446,11 +446,11 @@ def git_commit_and_push(project_path: str) -> tuple[bool, str]:
- 添加类型注解""" - 添加类型注解"""
subprocess.run( subprocess.run(
["git", "commit", "-m", commit_msg], cwd=project_path, check=True ["git", "commit", "-m", commit_msg], cwd = project_path, check = True
) )
# 推送 # 推送
subprocess.run(["git", "push"], cwd=project_path, check=True) subprocess.run(["git", "push"], cwd = project_path, check = True)
return True, "提交并推送成功" return True, "提交并推送成功"
except subprocess.CalledProcessError as e: except subprocess.CalledProcessError as e:
@@ -459,7 +459,7 @@ def git_commit_and_push(project_path: str) -> tuple[bool, str]:
return False, f"Git 操作异常: {e}" return False, f"Git 操作异常: {e}"
def main(): def main() -> None:
project_path = "/root/.openclaw/workspace/projects/insightflow" project_path = "/root/.openclaw/workspace/projects/insightflow"
print("🔍 开始扫描代码...") print("🔍 开始扫描代码...")
@@ -479,7 +479,7 @@ def main():
# 保存报告 # 保存报告
report_path = Path(project_path) / "AUTO_CODE_REVIEW_REPORT.md" report_path = Path(project_path) / "AUTO_CODE_REVIEW_REPORT.md"
with open(report_path, "w", encoding="utf-8") as f: with open(report_path, "w", encoding = "utf-8") as f:
f.write(report) f.write(report)
print(f"📝 报告已保存到: {report_path}") print(f"📝 报告已保存到: {report_path}")
@@ -493,12 +493,12 @@ def main():
report += f"\n\n## Git 提交结果\n\n{'' if success else ''} {msg}\n" report += f"\n\n## Git 提交结果\n\n{'' if success else ''} {msg}\n"
# 重新保存完整报告 # 重新保存完整报告
with open(report_path, "w", encoding="utf-8") as f: with open(report_path, "w", encoding = "utf-8") as f:
f.write(report) f.write(report)
print("\n" + "=" * 60) print("\n" + " = " * 60)
print(report) print(report)
print("=" * 60) print(" = " * 60)
return report return report

View File

@@ -15,10 +15,10 @@ def run_ruff_check(directory: str) -> list[dict]:
"""运行 ruff 检查并返回问题列表""" """运行 ruff 检查并返回问题列表"""
try: try:
result = subprocess.run( result = subprocess.run(
["ruff", "check", "--select=E,W,F,I", "--output-format=json", directory], ["ruff", "check", "--select = E, W, F, I", "--output-format = json", directory],
capture_output=True, capture_output = True,
text=True, text = True,
check=False, check = False,
) )
if result.stdout: if result.stdout:
return json.loads(result.stdout) return json.loads(result.stdout)
@@ -29,7 +29,7 @@ def run_ruff_check(directory: str) -> list[dict]:
def fix_bare_except(content: str) -> str: def fix_bare_except(content: str) -> str:
"""修复裸异常捕获 - 将 bare except: 改为 except Exception:""" """修复裸异常捕获 - 将 bare except Exception: 改为 except Exception:"""
pattern = r'except\s*:\s*\n' pattern = r'except\s*:\s*\n'
replacement = 'except Exception:\n' replacement = 'except Exception:\n'
return re.sub(pattern, replacement, content) return re.sub(pattern, replacement, content)
@@ -73,7 +73,7 @@ def fix_undefined_names(content: str, filepath: str) -> str:
def fix_file(filepath: str, issues: list[dict]) -> tuple[bool, list[str], list[str]]: def fix_file(filepath: str, issues: list[dict]) -> tuple[bool, list[str], list[str]]:
"""修复单个文件的问题""" """修复单个文件的问题"""
with open(filepath, 'r', encoding='utf-8') as f: with open(filepath, 'r', encoding = 'utf-8') as f:
original_content = f.read() original_content = f.read()
content = original_content content = original_content
@@ -97,20 +97,20 @@ def fix_file(filepath: str, issues: list[dict]) -> tuple[bool, list[str], list[s
content = fix_bare_except(content) content = fix_bare_except(content)
if content != original_content: if content != original_content:
with open(filepath, 'w', encoding='utf-8') as f: with open(filepath, 'w', encoding = 'utf-8') as f:
f.write(content) f.write(content)
return True, fixed_issues, manual_fix_needed return True, fixed_issues, manual_fix_needed
return False, fixed_issues, manual_fix_needed return False, fixed_issues, manual_fix_needed
def main(): def main() -> None:
base_dir = Path("/root/.openclaw/workspace/projects/insightflow") base_dir = Path("/root/.openclaw/workspace/projects/insightflow")
backend_dir = base_dir / "backend" backend_dir = base_dir / "backend"
print("=" * 60) print(" = " * 60)
print("InsightFlow 代码自动修复") print("InsightFlow 代码自动修复")
print("=" * 60) print(" = " * 60)
print("\n1. 扫描代码问题...") print("\n1. 扫描代码问题...")
issues = run_ruff_check(str(backend_dir)) issues = run_ruff_check(str(backend_dir))
@@ -130,7 +130,7 @@ def main():
issue_types[code] = issue_types.get(code, 0) + 1 issue_types[code] = issue_types.get(code, 0) + 1
print("\n2. 问题类型统计:") print("\n2. 问题类型统计:")
for code, count in sorted(issue_types.items(), key=lambda x: -x[1]): for code, count in sorted(issue_types.items(), key = lambda x: -x[1]):
print(f" - {code}: {count}") print(f" - {code}: {count}")
print("\n3. 尝试自动修复...") print("\n3. 尝试自动修复...")
@@ -155,8 +155,8 @@ def main():
try: try:
subprocess.run( subprocess.run(
["ruff", "format", str(backend_dir)], ["ruff", "format", str(backend_dir)],
capture_output=True, capture_output = True,
check=False, check = False,
) )
print(" 格式化完成") print(" 格式化完成")
except Exception as e: except Exception as e:
@@ -180,9 +180,9 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
report = main() report = main()
print("\n" + "=" * 60) print("\n" + " = " * 60)
print("修复报告") print("修复报告")
print("=" * 60) print(" = " * 60)
print(f"总问题数: {report['total_issues']}") print(f"总问题数: {report['total_issues']}")
print(f"修复文件数: {report['fixed_files']}") print(f"修复文件数: {report['fixed_files']}")
print(f"自动修复问题数: {report['fixed_issues']}") print(f"自动修复问题数: {report['fixed_issues']}")

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -234,20 +234,20 @@ class AIManager:
now = datetime.now().isoformat() now = datetime.now().isoformat()
model = CustomModel( model = CustomModel(
id=model_id, id = model_id,
tenant_id=tenant_id, tenant_id = tenant_id,
name=name, name = name,
description=description, description = description,
model_type=model_type, model_type = model_type,
status=ModelStatus.PENDING, status = ModelStatus.PENDING,
training_data=training_data, training_data = training_data,
hyperparameters=hyperparameters, hyperparameters = hyperparameters,
metrics={}, metrics = {},
model_path=None, model_path = None,
created_at=now, created_at = now,
updated_at=now, updated_at = now,
trained_at=None, trained_at = None,
created_by=created_by, created_by = created_by,
) )
with self._get_db() as conn: with self._get_db() as conn:
@@ -283,7 +283,7 @@ class AIManager:
def get_custom_model(self, model_id: str) -> CustomModel | None: def get_custom_model(self, model_id: str) -> CustomModel | None:
"""获取自定义模型""" """获取自定义模型"""
with self._get_db() as conn: with self._get_db() as conn:
row = conn.execute("SELECT * FROM custom_models WHERE id = ?", (model_id,)).fetchone() row = conn.execute("SELECT * FROM custom_models WHERE id = ?", (model_id, )).fetchone()
if not row: if not row:
return None return None
@@ -318,12 +318,12 @@ class AIManager:
now = datetime.now().isoformat() now = datetime.now().isoformat()
sample = TrainingSample( sample = TrainingSample(
id=sample_id, id = sample_id,
model_id=model_id, model_id = model_id,
text=text, text = text,
entities=entities, entities = entities,
metadata=metadata or {}, metadata = metadata or {},
created_at=now, created_at = now,
) )
with self._get_db() as conn: with self._get_db() as conn:
@@ -350,7 +350,7 @@ class AIManager:
"""获取训练样本""" """获取训练样本"""
with self._get_db() as conn: with self._get_db() as conn:
rows = conn.execute( rows = conn.execute(
"SELECT * FROM training_samples WHERE model_id = ? ORDER BY created_at", (model_id,) "SELECT * FROM training_samples WHERE model_id = ? ORDER BY created_at", (model_id, )
).fetchall() ).fetchall()
return [self._row_to_training_sample(row) for row in rows] return [self._row_to_training_sample(row) for row in rows]
@@ -392,7 +392,7 @@ class AIManager:
# 保存模型(模拟) # 保存模型(模拟)
model_path = f"models/{model_id}.bin" model_path = f"models/{model_id}.bin"
os.makedirs("models", exist_ok=True) os.makedirs("models", exist_ok = True)
now = datetime.now().isoformat() now = datetime.now().isoformat()
@@ -450,9 +450,9 @@ class AIManager:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.post( response = await client.post(
f"{self.kimi_base_url}/v1/chat/completions", f"{self.kimi_base_url}/v1/chat/completions",
headers=headers, headers = headers,
json=payload, json = payload,
timeout=60.0, timeout = 60.0,
) )
response.raise_for_status() response.raise_for_status()
result = response.json() result = response.json()
@@ -494,17 +494,17 @@ class AIManager:
result = await self._call_kimi_multimodal(input_urls, prompt) result = await self._call_kimi_multimodal(input_urls, prompt)
analysis = MultimodalAnalysis( analysis = MultimodalAnalysis(
id=analysis_id, id = analysis_id,
tenant_id=tenant_id, tenant_id = tenant_id,
project_id=project_id, project_id = project_id,
provider=provider, provider = provider,
input_type=input_type, input_type = input_type,
input_urls=input_urls, input_urls = input_urls,
prompt=prompt, prompt = prompt,
result=result, result = result,
tokens_used=result.get("tokens_used", 0), tokens_used = result.get("tokens_used", 0),
cost=result.get("cost", 0.0), cost = result.get("cost", 0.0),
created_at=now, created_at = now,
) )
with self._get_db() as conn: with self._get_db() as conn:
@@ -553,9 +553,9 @@ class AIManager:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.post( response = await client.post(
"https://api.openai.com/v1/chat/completions", "https://api.openai.com/v1/chat/completions",
headers=headers, headers = headers,
json=payload, json = payload,
timeout=120.0, timeout = 120.0,
) )
response.raise_for_status() response.raise_for_status()
result = response.json() result = response.json()
@@ -588,9 +588,9 @@ class AIManager:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.post( response = await client.post(
"https://api.anthropic.com/v1/messages", "https://api.anthropic.com/v1/messages",
headers=headers, headers = headers,
json=payload, json = payload,
timeout=120.0, timeout = 120.0,
) )
response.raise_for_status() response.raise_for_status()
result = response.json() result = response.json()
@@ -623,9 +623,9 @@ class AIManager:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.post( response = await client.post(
f"{self.kimi_base_url}/v1/chat/completions", f"{self.kimi_base_url}/v1/chat/completions",
headers=headers, headers = headers,
json=payload, json = payload,
timeout=60.0, timeout = 60.0,
) )
response.raise_for_status() response.raise_for_status()
result = response.json() result = response.json()
@@ -670,17 +670,17 @@ class AIManager:
now = datetime.now().isoformat() now = datetime.now().isoformat()
rag = KnowledgeGraphRAG( rag = KnowledgeGraphRAG(
id=rag_id, id = rag_id,
tenant_id=tenant_id, tenant_id = tenant_id,
project_id=project_id, project_id = project_id,
name=name, name = name,
description=description, description = description,
kg_config=kg_config, kg_config = kg_config,
retrieval_config=retrieval_config, retrieval_config = retrieval_config,
generation_config=generation_config, generation_config = generation_config,
is_active=True, is_active = True,
created_at=now, created_at = now,
updated_at=now, updated_at = now,
) )
with self._get_db() as conn: with self._get_db() as conn:
@@ -712,7 +712,7 @@ class AIManager:
def get_kg_rag(self, rag_id: str) -> KnowledgeGraphRAG | None: def get_kg_rag(self, rag_id: str) -> KnowledgeGraphRAG | None:
"""获取知识图谱 RAG 配置""" """获取知识图谱 RAG 配置"""
with self._get_db() as conn: with self._get_db() as conn:
row = conn.execute("SELECT * FROM kg_rag_configs WHERE id = ?", (rag_id,)).fetchone() row = conn.execute("SELECT * FROM kg_rag_configs WHERE id = ?", (rag_id, )).fetchone()
if not row: if not row:
return None return None
@@ -766,7 +766,7 @@ class AIManager:
if score > 0: if score > 0:
relevant_entities.append({**entity, "relevance_score": score}) relevant_entities.append({**entity, "relevance_score": score})
relevant_entities.sort(key=lambda x: x["relevance_score"], reverse=True) relevant_entities.sort(key = lambda x: x["relevance_score"], reverse = True)
relevant_entities = relevant_entities[:top_k] relevant_entities = relevant_entities[:top_k]
# 检索相关关系 # 检索相关关系
@@ -818,9 +818,9 @@ class AIManager:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.post( response = await client.post(
f"{self.kimi_base_url}/v1/chat/completions", f"{self.kimi_base_url}/v1/chat/completions",
headers=headers, headers = headers,
json=payload, json = payload,
timeout=60.0, timeout = 60.0,
) )
response.raise_for_status() response.raise_for_status()
result = response.json() result = response.json()
@@ -840,20 +840,20 @@ class AIManager:
] ]
rag_query = RAGQuery( rag_query = RAGQuery(
id=query_id, id = query_id,
rag_id=rag_id, rag_id = rag_id,
query=query, query = query,
context=context, context = context,
answer=answer, answer = answer,
sources=sources, sources = sources,
confidence=( confidence = (
sum(e["relevance_score"] for e in relevant_entities) / len(relevant_entities) sum(e["relevance_score"] for e in relevant_entities) / len(relevant_entities)
if relevant_entities if relevant_entities
else 0 else 0
), ),
tokens_used=tokens_used, tokens_used = tokens_used,
latency_ms=latency_ms, latency_ms = latency_ms,
created_at=now, created_at = now,
) )
with self._get_db() as conn: with self._get_db() as conn:
@@ -974,9 +974,9 @@ class AIManager:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.post( response = await client.post(
f"{self.kimi_base_url}/v1/chat/completions", f"{self.kimi_base_url}/v1/chat/completions",
headers=headers, headers = headers,
json=payload, json = payload,
timeout=60.0, timeout = 60.0,
) )
response.raise_for_status() response.raise_for_status()
result = response.json() result = response.json()
@@ -1014,18 +1014,18 @@ class AIManager:
entity_names = [e.get("name", "") for e in entities_mentioned[:10]] entity_names = [e.get("name", "") for e in entities_mentioned[:10]]
summary = SmartSummary( summary = SmartSummary(
id=summary_id, id = summary_id,
tenant_id=tenant_id, tenant_id = tenant_id,
project_id=project_id, project_id = project_id,
source_type=source_type, source_type = source_type,
source_id=source_id, source_id = source_id,
summary_type=summary_type, summary_type = summary_type,
content=content, content = content,
key_points=key_points[:8], key_points = key_points[:8],
entities_mentioned=entity_names, entities_mentioned = entity_names,
confidence=0.85, confidence = 0.85,
tokens_used=tokens_used, tokens_used = tokens_used,
created_at=now, created_at = now,
) )
with self._get_db() as conn: with self._get_db() as conn:
@@ -1072,20 +1072,20 @@ class AIManager:
now = datetime.now().isoformat() now = datetime.now().isoformat()
model = PredictionModel( model = PredictionModel(
id=model_id, id = model_id,
tenant_id=tenant_id, tenant_id = tenant_id,
project_id=project_id, project_id = project_id,
name=name, name = name,
prediction_type=prediction_type, prediction_type = prediction_type,
target_entity_type=target_entity_type, target_entity_type = target_entity_type,
features=features, features = features,
model_config=model_config, model_config = model_config,
accuracy=None, accuracy = None,
last_trained_at=None, last_trained_at = None,
prediction_count=0, prediction_count = 0,
is_active=True, is_active = True,
created_at=now, created_at = now,
updated_at=now, updated_at = now,
) )
with self._get_db() as conn: with self._get_db() as conn:
@@ -1122,7 +1122,7 @@ class AIManager:
"""获取预测模型""" """获取预测模型"""
with self._get_db() as conn: with self._get_db() as conn:
row = conn.execute( row = conn.execute(
"SELECT * FROM prediction_models WHERE id = ?", (model_id,) "SELECT * FROM prediction_models WHERE id = ?", (model_id, )
).fetchone() ).fetchone()
if not row: if not row:
@@ -1201,16 +1201,16 @@ class AIManager:
explanation = prediction_data.get("explanation", "基于历史数据模式预测") explanation = prediction_data.get("explanation", "基于历史数据模式预测")
result = PredictionResult( result = PredictionResult(
id=prediction_id, id = prediction_id,
model_id=model_id, model_id = model_id,
prediction_type=model.prediction_type, prediction_type = model.prediction_type,
target_id=input_data.get("target_id"), target_id = input_data.get("target_id"),
prediction_data=prediction_data, prediction_data = prediction_data,
confidence=confidence, confidence = confidence,
explanation=explanation, explanation = explanation,
actual_value=None, actual_value = None,
is_correct=None, is_correct = None,
created_at=now, created_at = now,
) )
with self._get_db() as conn: with self._get_db() as conn:
@@ -1238,7 +1238,7 @@ class AIManager:
# 更新预测计数 # 更新预测计数
conn.execute( conn.execute(
"UPDATE prediction_models SET prediction_count = prediction_count + 1 WHERE id = ?", "UPDATE prediction_models SET prediction_count = prediction_count + 1 WHERE id = ?",
(model_id,), (model_id, ),
) )
conn.commit() conn.commit()
@@ -1368,7 +1368,7 @@ class AIManager:
predicted_relations = [ predicted_relations = [
{"type": rel_type, "likelihood": min(count / len(relation_history), 0.95)} {"type": rel_type, "likelihood": min(count / len(relation_history), 0.95)}
for rel_type, count in sorted( for rel_type, count in sorted(
relation_counts.items(), key=lambda x: x[1], reverse=True relation_counts.items(), key = lambda x: x[1], reverse = True
)[:5] )[:5]
] ]
@@ -1410,97 +1410,97 @@ class AIManager:
def _row_to_custom_model(self, row) -> CustomModel: def _row_to_custom_model(self, row) -> CustomModel:
"""将数据库行转换为 CustomModel""" """将数据库行转换为 CustomModel"""
return CustomModel( return CustomModel(
id=row["id"], id = row["id"],
tenant_id=row["tenant_id"], tenant_id = row["tenant_id"],
name=row["name"], name = row["name"],
description=row["description"], description = row["description"],
model_type=ModelType(row["model_type"]), model_type = ModelType(row["model_type"]),
status=ModelStatus(row["status"]), status = ModelStatus(row["status"]),
training_data=json.loads(row["training_data"]), training_data = json.loads(row["training_data"]),
hyperparameters=json.loads(row["hyperparameters"]), hyperparameters = json.loads(row["hyperparameters"]),
metrics=json.loads(row["metrics"]), metrics = json.loads(row["metrics"]),
model_path=row["model_path"], model_path = row["model_path"],
created_at=row["created_at"], created_at = row["created_at"],
updated_at=row["updated_at"], updated_at = row["updated_at"],
trained_at=row["trained_at"], trained_at = row["trained_at"],
created_by=row["created_by"], created_by = row["created_by"],
) )
def _row_to_training_sample(self, row) -> TrainingSample: def _row_to_training_sample(self, row) -> TrainingSample:
"""将数据库行转换为 TrainingSample""" """将数据库行转换为 TrainingSample"""
return TrainingSample( return TrainingSample(
id=row["id"], id = row["id"],
model_id=row["model_id"], model_id = row["model_id"],
text=row["text"], text = row["text"],
entities=json.loads(row["entities"]), entities = json.loads(row["entities"]),
metadata=json.loads(row["metadata"]), metadata = json.loads(row["metadata"]),
created_at=row["created_at"], created_at = row["created_at"],
) )
def _row_to_multimodal_analysis(self, row) -> MultimodalAnalysis: def _row_to_multimodal_analysis(self, row) -> MultimodalAnalysis:
"""将数据库行转换为 MultimodalAnalysis""" """将数据库行转换为 MultimodalAnalysis"""
return MultimodalAnalysis( return MultimodalAnalysis(
id=row["id"], id = row["id"],
tenant_id=row["tenant_id"], tenant_id = row["tenant_id"],
project_id=row["project_id"], project_id = row["project_id"],
provider=MultimodalProvider(row["provider"]), provider = MultimodalProvider(row["provider"]),
input_type=row["input_type"], input_type = row["input_type"],
input_urls=json.loads(row["input_urls"]), input_urls = json.loads(row["input_urls"]),
prompt=row["prompt"], prompt = row["prompt"],
result=json.loads(row["result"]), result = json.loads(row["result"]),
tokens_used=row["tokens_used"], tokens_used = row["tokens_used"],
cost=row["cost"], cost = row["cost"],
created_at=row["created_at"], created_at = row["created_at"],
) )
def _row_to_kg_rag(self, row) -> KnowledgeGraphRAG: def _row_to_kg_rag(self, row) -> KnowledgeGraphRAG:
"""将数据库行转换为 KnowledgeGraphRAG""" """将数据库行转换为 KnowledgeGraphRAG"""
return KnowledgeGraphRAG( return KnowledgeGraphRAG(
id=row["id"], id = row["id"],
tenant_id=row["tenant_id"], tenant_id = row["tenant_id"],
project_id=row["project_id"], project_id = row["project_id"],
name=row["name"], name = row["name"],
description=row["description"], description = row["description"],
kg_config=json.loads(row["kg_config"]), kg_config = json.loads(row["kg_config"]),
retrieval_config=json.loads(row["retrieval_config"]), retrieval_config = json.loads(row["retrieval_config"]),
generation_config=json.loads(row["generation_config"]), generation_config = json.loads(row["generation_config"]),
is_active=bool(row["is_active"]), is_active = bool(row["is_active"]),
created_at=row["created_at"], created_at = row["created_at"],
updated_at=row["updated_at"], updated_at = row["updated_at"],
) )
def _row_to_prediction_model(self, row) -> PredictionModel: def _row_to_prediction_model(self, row) -> PredictionModel:
"""将数据库行转换为 PredictionModel""" """将数据库行转换为 PredictionModel"""
return PredictionModel( return PredictionModel(
id=row["id"], id = row["id"],
tenant_id=row["tenant_id"], tenant_id = row["tenant_id"],
project_id=row["project_id"], project_id = row["project_id"],
name=row["name"], name = row["name"],
prediction_type=PredictionType(row["prediction_type"]), prediction_type = PredictionType(row["prediction_type"]),
target_entity_type=row["target_entity_type"], target_entity_type = row["target_entity_type"],
features=json.loads(row["features"]), features = json.loads(row["features"]),
model_config=json.loads(row["model_config"]), model_config = json.loads(row["model_config"]),
accuracy=row["accuracy"], accuracy = row["accuracy"],
last_trained_at=row["last_trained_at"], last_trained_at = row["last_trained_at"],
prediction_count=row["prediction_count"], prediction_count = row["prediction_count"],
is_active=bool(row["is_active"]), is_active = bool(row["is_active"]),
created_at=row["created_at"], created_at = row["created_at"],
updated_at=row["updated_at"], updated_at = row["updated_at"],
) )
def _row_to_prediction_result(self, row) -> PredictionResult: def _row_to_prediction_result(self, row) -> PredictionResult:
"""将数据库行转换为 PredictionResult""" """将数据库行转换为 PredictionResult"""
return PredictionResult( return PredictionResult(
id=row["id"], id = row["id"],
model_id=row["model_id"], model_id = row["model_id"],
prediction_type=PredictionType(row["prediction_type"]), prediction_type = PredictionType(row["prediction_type"]),
target_id=row["target_id"], target_id = row["target_id"],
prediction_data=json.loads(row["prediction_data"]), prediction_data = json.loads(row["prediction_data"]),
confidence=row["confidence"], confidence = row["confidence"],
explanation=row["explanation"], explanation = row["explanation"],
actual_value=row["actual_value"], actual_value = row["actual_value"],
is_correct=row["is_correct"], is_correct = row["is_correct"],
created_at=row["created_at"], created_at = row["created_at"],
) )

View File

@@ -152,23 +152,23 @@ class ApiKeyManager:
expires_at = None expires_at = None
if expires_days: if expires_days:
expires_at = (datetime.now() + timedelta(days=expires_days)).isoformat() expires_at = (datetime.now() + timedelta(days = expires_days)).isoformat()
api_key = ApiKey( api_key = ApiKey(
id=key_id, id = key_id,
key_hash=key_hash, key_hash = key_hash,
key_preview=key_preview, key_preview = key_preview,
name=name, name = name,
owner_id=owner_id, owner_id = owner_id,
permissions=permissions, permissions = permissions,
rate_limit=rate_limit, rate_limit = rate_limit,
status=ApiKeyStatus.ACTIVE.value, status = ApiKeyStatus.ACTIVE.value,
created_at=datetime.now().isoformat(), created_at = datetime.now().isoformat(),
expires_at=expires_at, expires_at = expires_at,
last_used_at=None, last_used_at = None,
revoked_at=None, revoked_at = None,
revoked_reason=None, revoked_reason = None,
total_calls=0, total_calls = 0,
) )
with sqlite3.connect(self.db_path) as conn: with sqlite3.connect(self.db_path) as conn:
@@ -207,7 +207,7 @@ class ApiKeyManager:
with sqlite3.connect(self.db_path) as conn: with sqlite3.connect(self.db_path) as conn:
conn.row_factory = sqlite3.Row conn.row_factory = sqlite3.Row
row = conn.execute("SELECT * FROM api_keys WHERE key_hash = ?", (key_hash,)).fetchone() row = conn.execute("SELECT * FROM api_keys WHERE key_hash = ?", (key_hash, )).fetchone()
if not row: if not row:
return None return None
@@ -238,7 +238,7 @@ class ApiKeyManager:
# 验证所有权(如果提供了 owner_id # 验证所有权(如果提供了 owner_id
if owner_id: if owner_id:
row = conn.execute( row = conn.execute(
"SELECT owner_id FROM api_keys WHERE id = ?", (key_id,) "SELECT owner_id FROM api_keys WHERE id = ?", (key_id, )
).fetchone() ).fetchone()
if not row or row[0] != owner_id: if not row or row[0] != owner_id:
return False return False
@@ -270,7 +270,7 @@ class ApiKeyManager:
"SELECT * FROM api_keys WHERE id = ? AND owner_id = ?", (key_id, owner_id) "SELECT * FROM api_keys WHERE id = ? AND owner_id = ?", (key_id, owner_id)
).fetchone() ).fetchone()
else: else:
row = conn.execute("SELECT * FROM api_keys WHERE id = ?", (key_id,)).fetchone() row = conn.execute("SELECT * FROM api_keys WHERE id = ?", (key_id, )).fetchone()
if row: if row:
return self._row_to_api_key(row) return self._row_to_api_key(row)
@@ -287,7 +287,7 @@ class ApiKeyManager:
with sqlite3.connect(self.db_path) as conn: with sqlite3.connect(self.db_path) as conn:
conn.row_factory = sqlite3.Row conn.row_factory = sqlite3.Row
query = "SELECT * FROM api_keys WHERE 1=1" query = "SELECT * FROM api_keys WHERE 1 = 1"
params = [] params = []
if owner_id: if owner_id:
@@ -337,7 +337,7 @@ class ApiKeyManager:
# 验证所有权 # 验证所有权
if owner_id: if owner_id:
row = conn.execute( row = conn.execute(
"SELECT owner_id FROM api_keys WHERE id = ?", (key_id,) "SELECT owner_id FROM api_keys WHERE id = ?", (key_id, )
).fetchone() ).fetchone()
if not row or row[0] != owner_id: if not row or row[0] != owner_id:
return False return False
@@ -370,7 +370,7 @@ class ApiKeyManager:
ip_address: str = "", ip_address: str = "",
user_agent: str = "", user_agent: str = "",
error_message: str = "", error_message: str = "",
): ) -> None:
"""记录 API 调用日志""" """记录 API 调用日志"""
with sqlite3.connect(self.db_path) as conn: with sqlite3.connect(self.db_path) as conn:
conn.execute( conn.execute(
@@ -405,7 +405,7 @@ class ApiKeyManager:
with sqlite3.connect(self.db_path) as conn: with sqlite3.connect(self.db_path) as conn:
conn.row_factory = sqlite3.Row conn.row_factory = sqlite3.Row
query = "SELECT * FROM api_call_logs WHERE 1=1" query = "SELECT * FROM api_call_logs WHERE 1 = 1"
params = [] params = []
if api_key_id: if api_key_id:
@@ -510,20 +510,20 @@ class ApiKeyManager:
def _row_to_api_key(self, row: sqlite3.Row) -> ApiKey: def _row_to_api_key(self, row: sqlite3.Row) -> ApiKey:
"""将数据库行转换为 ApiKey 对象""" """将数据库行转换为 ApiKey 对象"""
return ApiKey( return ApiKey(
id=row["id"], id = row["id"],
key_hash=row["key_hash"], key_hash = row["key_hash"],
key_preview=row["key_preview"], key_preview = row["key_preview"],
name=row["name"], name = row["name"],
owner_id=row["owner_id"], owner_id = row["owner_id"],
permissions=json.loads(row["permissions"]), permissions = json.loads(row["permissions"]),
rate_limit=row["rate_limit"], rate_limit = row["rate_limit"],
status=row["status"], status = row["status"],
created_at=row["created_at"], created_at = row["created_at"],
expires_at=row["expires_at"], expires_at = row["expires_at"],
last_used_at=row["last_used_at"], last_used_at = row["last_used_at"],
revoked_at=row["revoked_at"], revoked_at = row["revoked_at"],
revoked_reason=row["revoked_reason"], revoked_reason = row["revoked_reason"],
total_calls=row["total_calls"], total_calls = row["total_calls"],
) )

View File

@@ -136,7 +136,7 @@ class TeamSpace:
class CollaborationManager: class CollaborationManager:
"""协作管理主类""" """协作管理主类"""
def __init__(self, db_manager=None): def __init__(self, db_manager = None) -> None:
self.db = db_manager self.db = db_manager
self._shares_cache: dict[str, ProjectShare] = {} self._shares_cache: dict[str, ProjectShare] = {}
self._comments_cache: dict[str, list[Comment]] = {} self._comments_cache: dict[str, list[Comment]] = {}
@@ -161,26 +161,26 @@ class CollaborationManager:
now = datetime.now().isoformat() now = datetime.now().isoformat()
expires_at = None expires_at = None
if expires_in_days: if expires_in_days:
expires_at = (datetime.now() + timedelta(days=expires_in_days)).isoformat() expires_at = (datetime.now() + timedelta(days = expires_in_days)).isoformat()
password_hash = None password_hash = None
if password: if password:
password_hash = hashlib.sha256(password.encode()).hexdigest() password_hash = hashlib.sha256(password.encode()).hexdigest()
share = ProjectShare( share = ProjectShare(
id=share_id, id = share_id,
project_id=project_id, project_id = project_id,
token=token, token = token,
permission=permission, permission = permission,
created_by=created_by, created_by = created_by,
created_at=now, created_at = now,
expires_at=expires_at, expires_at = expires_at,
max_uses=max_uses, max_uses = max_uses,
use_count=0, use_count = 0,
password_hash=password_hash, password_hash = password_hash,
is_active=True, is_active = True,
allow_download=allow_download, allow_download = allow_download,
allow_export=allow_export, allow_export = allow_export,
) )
# 保存到数据库 # 保存到数据库
@@ -263,7 +263,7 @@ class CollaborationManager:
""" """
SELECT * FROM project_shares WHERE token = ? SELECT * FROM project_shares WHERE token = ?
""", """,
(token,), (token, ),
) )
row = cursor.fetchone() row = cursor.fetchone()
@@ -271,19 +271,19 @@ class CollaborationManager:
return None return None
return ProjectShare( return ProjectShare(
id=row[0], id = row[0],
project_id=row[1], project_id = row[1],
token=row[2], token = row[2],
permission=row[3], permission = row[3],
created_by=row[4], created_by = row[4],
created_at=row[5], created_at = row[5],
expires_at=row[6], expires_at = row[6],
max_uses=row[7], max_uses = row[7],
use_count=row[8], use_count = row[8],
password_hash=row[9], password_hash = row[9],
is_active=bool(row[10]), is_active = bool(row[10]),
allow_download=bool(row[11]), allow_download = bool(row[11]),
allow_export=bool(row[12]), allow_export = bool(row[12]),
) )
def increment_share_usage(self, token: str) -> None: def increment_share_usage(self, token: str) -> None:
@@ -300,7 +300,7 @@ class CollaborationManager:
SET use_count = use_count + 1 SET use_count = use_count + 1
WHERE token = ? WHERE token = ?
""", """,
(token,), (token, ),
) )
self.db.conn.commit() self.db.conn.commit()
@@ -314,7 +314,7 @@ class CollaborationManager:
SET is_active = 0 SET is_active = 0
WHERE id = ? WHERE id = ?
""", """,
(share_id,), (share_id, ),
) )
self.db.conn.commit() self.db.conn.commit()
return cursor.rowcount > 0 return cursor.rowcount > 0
@@ -332,26 +332,26 @@ class CollaborationManager:
WHERE project_id = ? WHERE project_id = ?
ORDER BY created_at DESC ORDER BY created_at DESC
""", """,
(project_id,), (project_id, ),
) )
shares = [] shares = []
for row in cursor.fetchall(): for row in cursor.fetchall():
shares.append( shares.append(
ProjectShare( ProjectShare(
id=row[0], id = row[0],
project_id=row[1], project_id = row[1],
token=row[2], token = row[2],
permission=row[3], permission = row[3],
created_by=row[4], created_by = row[4],
created_at=row[5], created_at = row[5],
expires_at=row[6], expires_at = row[6],
max_uses=row[7], max_uses = row[7],
use_count=row[8], use_count = row[8],
password_hash=row[9], password_hash = row[9],
is_active=bool(row[10]), is_active = bool(row[10]),
allow_download=bool(row[11]), allow_download = bool(row[11]),
allow_export=bool(row[12]), allow_export = bool(row[12]),
) )
) )
return shares return shares
@@ -375,21 +375,21 @@ class CollaborationManager:
now = datetime.now().isoformat() now = datetime.now().isoformat()
comment = Comment( comment = Comment(
id=comment_id, id = comment_id,
project_id=project_id, project_id = project_id,
target_type=target_type, target_type = target_type,
target_id=target_id, target_id = target_id,
parent_id=parent_id, parent_id = parent_id,
author=author, author = author,
author_name=author_name, author_name = author_name,
content=content, content = content,
created_at=now, created_at = now,
updated_at=now, updated_at = now,
resolved=False, resolved = False,
resolved_by=None, resolved_by = None,
resolved_at=None, resolved_at = None,
mentions=mentions or [], mentions = mentions or [],
attachments=attachments or [], attachments = attachments or [],
) )
if self.db: if self.db:
@@ -469,21 +469,21 @@ class CollaborationManager:
def _row_to_comment(self, row) -> Comment: def _row_to_comment(self, row) -> Comment:
"""将数据库行转换为Comment对象""" """将数据库行转换为Comment对象"""
return Comment( return Comment(
id=row[0], id = row[0],
project_id=row[1], project_id = row[1],
target_type=row[2], target_type = row[2],
target_id=row[3], target_id = row[3],
parent_id=row[4], parent_id = row[4],
author=row[5], author = row[5],
author_name=row[6], author_name = row[6],
content=row[7], content = row[7],
created_at=row[8], created_at = row[8],
updated_at=row[9], updated_at = row[9],
resolved=bool(row[10]), resolved = bool(row[10]),
resolved_by=row[11], resolved_by = row[11],
resolved_at=row[12], resolved_at = row[12],
mentions=json.loads(row[13]) if row[13] else [], mentions = json.loads(row[13]) if row[13] else [],
attachments=json.loads(row[14]) if row[14] else [], attachments = json.loads(row[14]) if row[14] else [],
) )
def update_comment(self, comment_id: str, content: str, updated_by: str) -> Comment | None: def update_comment(self, comment_id: str, content: str, updated_by: str) -> Comment | None:
@@ -510,7 +510,7 @@ class CollaborationManager:
def _get_comment_by_id(self, comment_id: str) -> Comment | None: def _get_comment_by_id(self, comment_id: str) -> Comment | None:
"""根据ID获取评论""" """根据ID获取评论"""
cursor = self.db.conn.cursor() cursor = self.db.conn.cursor()
cursor.execute("SELECT * FROM comments WHERE id = ?", (comment_id,)) cursor.execute("SELECT * FROM comments WHERE id = ?", (comment_id, ))
row = cursor.fetchone() row = cursor.fetchone()
if row: if row:
return self._row_to_comment(row) return self._row_to_comment(row)
@@ -597,22 +597,22 @@ class CollaborationManager:
now = datetime.now().isoformat() now = datetime.now().isoformat()
record = ChangeRecord( record = ChangeRecord(
id=record_id, id = record_id,
project_id=project_id, project_id = project_id,
change_type=change_type, change_type = change_type,
entity_type=entity_type, entity_type = entity_type,
entity_id=entity_id, entity_id = entity_id,
entity_name=entity_name, entity_name = entity_name,
changed_by=changed_by, changed_by = changed_by,
changed_by_name=changed_by_name, changed_by_name = changed_by_name,
changed_at=now, changed_at = now,
old_value=old_value, old_value = old_value,
new_value=new_value, new_value = new_value,
description=description, description = description,
session_id=session_id, session_id = session_id,
reverted=False, reverted = False,
reverted_at=None, reverted_at = None,
reverted_by=None, reverted_by = None,
) )
if self.db: if self.db:
@@ -705,22 +705,22 @@ class CollaborationManager:
def _row_to_change_record(self, row) -> ChangeRecord: def _row_to_change_record(self, row) -> ChangeRecord:
"""将数据库行转换为ChangeRecord对象""" """将数据库行转换为ChangeRecord对象"""
return ChangeRecord( return ChangeRecord(
id=row[0], id = row[0],
project_id=row[1], project_id = row[1],
change_type=row[2], change_type = row[2],
entity_type=row[3], entity_type = row[3],
entity_id=row[4], entity_id = row[4],
entity_name=row[5], entity_name = row[5],
changed_by=row[6], changed_by = row[6],
changed_by_name=row[7], changed_by_name = row[7],
changed_at=row[8], changed_at = row[8],
old_value=json.loads(row[9]) if row[9] else None, old_value = json.loads(row[9]) if row[9] else None,
new_value=json.loads(row[10]) if row[10] else None, new_value = json.loads(row[10]) if row[10] else None,
description=row[11], description = row[11],
session_id=row[12], session_id = row[12],
reverted=bool(row[13]), reverted = bool(row[13]),
reverted_at=row[14], reverted_at = row[14],
reverted_by=row[15], reverted_by = row[15],
) )
def get_entity_version_history(self, entity_type: str, entity_id: str) -> list[ChangeRecord]: def get_entity_version_history(self, entity_type: str, entity_id: str) -> list[ChangeRecord]:
@@ -773,7 +773,7 @@ class CollaborationManager:
""" """
SELECT COUNT(*) FROM change_history WHERE project_id = ? SELECT COUNT(*) FROM change_history WHERE project_id = ?
""", """,
(project_id,), (project_id, ),
) )
total_changes = cursor.fetchone()[0] total_changes = cursor.fetchone()[0]
@@ -783,7 +783,7 @@ class CollaborationManager:
SELECT change_type, COUNT(*) FROM change_history SELECT change_type, COUNT(*) FROM change_history
WHERE project_id = ? GROUP BY change_type WHERE project_id = ? GROUP BY change_type
""", """,
(project_id,), (project_id, ),
) )
type_counts = {row[0]: row[1] for row in cursor.fetchall()} type_counts = {row[0]: row[1] for row in cursor.fetchall()}
@@ -793,7 +793,7 @@ class CollaborationManager:
SELECT entity_type, COUNT(*) FROM change_history SELECT entity_type, COUNT(*) FROM change_history
WHERE project_id = ? GROUP BY entity_type WHERE project_id = ? GROUP BY entity_type
""", """,
(project_id,), (project_id, ),
) )
entity_type_counts = {row[0]: row[1] for row in cursor.fetchall()} entity_type_counts = {row[0]: row[1] for row in cursor.fetchall()}
@@ -806,7 +806,7 @@ class CollaborationManager:
ORDER BY count DESC ORDER BY count DESC
LIMIT 5 LIMIT 5
""", """,
(project_id,), (project_id, ),
) )
top_contributors = [{"name": row[0], "changes": row[1]} for row in cursor.fetchall()] top_contributors = [{"name": row[0], "changes": row[1]} for row in cursor.fetchall()]
@@ -838,16 +838,16 @@ class CollaborationManager:
permissions = self._get_default_permissions(role) permissions = self._get_default_permissions(role)
member = TeamMember( member = TeamMember(
id=member_id, id = member_id,
project_id=project_id, project_id = project_id,
user_id=user_id, user_id = user_id,
user_name=user_name, user_name = user_name,
user_email=user_email, user_email = user_email,
role=role, role = role,
joined_at=now, joined_at = now,
invited_by=invited_by, invited_by = invited_by,
last_active_at=None, last_active_at = None,
permissions=permissions, permissions = permissions,
) )
if self.db: if self.db:
@@ -902,7 +902,7 @@ class CollaborationManager:
SELECT * FROM team_members WHERE project_id = ? SELECT * FROM team_members WHERE project_id = ?
ORDER BY joined_at ASC ORDER BY joined_at ASC
""", """,
(project_id,), (project_id, ),
) )
members = [] members = []
@@ -913,16 +913,16 @@ class CollaborationManager:
def _row_to_team_member(self, row) -> TeamMember: def _row_to_team_member(self, row) -> TeamMember:
"""将数据库行转换为TeamMember对象""" """将数据库行转换为TeamMember对象"""
return TeamMember( return TeamMember(
id=row[0], id = row[0],
project_id=row[1], project_id = row[1],
user_id=row[2], user_id = row[2],
user_name=row[3], user_name = row[3],
user_email=row[4], user_email = row[4],
role=row[5], role = row[5],
joined_at=row[6], joined_at = row[6],
invited_by=row[7], invited_by = row[7],
last_active_at=row[8], last_active_at = row[8],
permissions=json.loads(row[9]) if row[9] else [], permissions = json.loads(row[9]) if row[9] else [],
) )
def update_member_role(self, member_id: str, new_role: str, updated_by: str) -> bool: def update_member_role(self, member_id: str, new_role: str, updated_by: str) -> bool:
@@ -949,7 +949,7 @@ class CollaborationManager:
return False return False
cursor = self.db.conn.cursor() cursor = self.db.conn.cursor()
cursor.execute("DELETE FROM team_members WHERE id = ?", (member_id,)) cursor.execute("DELETE FROM team_members WHERE id = ?", (member_id, ))
self.db.conn.commit() self.db.conn.commit()
return cursor.rowcount > 0 return cursor.rowcount > 0
@@ -996,7 +996,7 @@ class CollaborationManager:
_collaboration_manager = None _collaboration_manager = None
def get_collaboration_manager(db_manager=None) -> None: def get_collaboration_manager(db_manager = None) -> None:
"""获取协作管理器单例""" """获取协作管理器单例"""
global _collaboration_manager global _collaboration_manager
if _collaboration_manager is None: if _collaboration_manager is None:

View File

@@ -41,7 +41,7 @@ class Entity:
created_at: str = "" created_at: str = ""
updated_at: str = "" updated_at: str = ""
def __post_init__(self): def __post_init__(self) -> None:
if self.aliases is None: if self.aliases is None:
self.aliases = [] self.aliases = []
if self.attributes is None: if self.attributes is None:
@@ -64,7 +64,7 @@ class AttributeTemplate:
created_at: str = "" created_at: str = ""
updated_at: str = "" updated_at: str = ""
def __post_init__(self): def __post_init__(self) -> None:
if self.options is None: if self.options is None:
self.options = [] self.options = []
@@ -85,7 +85,7 @@ class EntityAttribute:
created_at: str = "" created_at: str = ""
updated_at: str = "" updated_at: str = ""
def __post_init__(self): def __post_init__(self) -> None:
if self.options is None: if self.options is None:
self.options = [] self.options = []
@@ -116,12 +116,12 @@ class EntityMention:
class DatabaseManager: class DatabaseManager:
def __init__(self, db_path: str = DB_PATH): def __init__(self, db_path: str = DB_PATH) -> None:
self.db_path = db_path self.db_path = db_path
os.makedirs(os.path.dirname(db_path), exist_ok=True) os.makedirs(os.path.dirname(db_path), exist_ok = True)
self.init_db() self.init_db()
def get_conn(self): def get_conn(self) -> None:
conn = sqlite3.connect(self.db_path) conn = sqlite3.connect(self.db_path)
conn.row_factory = sqlite3.Row conn.row_factory = sqlite3.Row
return conn return conn
@@ -149,12 +149,12 @@ class DatabaseManager:
conn.commit() conn.commit()
conn.close() conn.close()
return Project( return Project(
id=project_id, name=name, description=description, created_at=now, updated_at=now id = project_id, name = name, description = description, created_at = now, updated_at = now
) )
def get_project(self, project_id: str) -> Project | None: def get_project(self, project_id: str) -> Project | None:
conn = self.get_conn() conn = self.get_conn()
row = conn.execute("SELECT * FROM projects WHERE id = ?", (project_id,)).fetchone() row = conn.execute("SELECT * FROM projects WHERE id = ?", (project_id, )).fetchone()
conn.close() conn.close()
if row: if row:
return Project(**dict(row)) return Project(**dict(row))
@@ -226,8 +226,8 @@ class DatabaseManager:
"""合并两个实体""" """合并两个实体"""
conn = self.get_conn() conn = self.get_conn()
target = conn.execute("SELECT * FROM entities WHERE id = ?", (target_id,)).fetchone() target = conn.execute("SELECT * FROM entities WHERE id = ?", (target_id, )).fetchone()
source = conn.execute("SELECT * FROM entities WHERE id = ?", (source_id,)).fetchone() source = conn.execute("SELECT * FROM entities WHERE id = ?", (source_id, )).fetchone()
if not target or not source: if not target or not source:
conn.close() conn.close()
@@ -252,7 +252,7 @@ class DatabaseManager:
"UPDATE entity_relations SET target_entity_id = ? WHERE target_entity_id = ?", "UPDATE entity_relations SET target_entity_id = ? WHERE target_entity_id = ?",
(target_id, source_id), (target_id, source_id),
) )
conn.execute("DELETE FROM entities WHERE id = ?", (source_id,)) conn.execute("DELETE FROM entities WHERE id = ?", (source_id, ))
conn.commit() conn.commit()
conn.close() conn.close()
@@ -260,7 +260,7 @@ class DatabaseManager:
def get_entity(self, entity_id: str) -> Entity | None: def get_entity(self, entity_id: str) -> Entity | None:
conn = self.get_conn() conn = self.get_conn()
row = conn.execute("SELECT * FROM entities WHERE id = ?", (entity_id,)).fetchone() row = conn.execute("SELECT * FROM entities WHERE id = ?", (entity_id, )).fetchone()
conn.close() conn.close()
if row: if row:
data = dict(row) data = dict(row)
@@ -271,7 +271,7 @@ class DatabaseManager:
def list_project_entities(self, project_id: str) -> list[Entity]: def list_project_entities(self, project_id: str) -> list[Entity]:
conn = self.get_conn() conn = self.get_conn()
rows = conn.execute( rows = conn.execute(
"SELECT * FROM entities WHERE project_id = ? ORDER BY updated_at DESC", (project_id,) "SELECT * FROM entities WHERE project_id = ? ORDER BY updated_at DESC", (project_id, )
).fetchall() ).fetchall()
conn.close() conn.close()
@@ -316,13 +316,13 @@ class DatabaseManager:
def delete_entity(self, entity_id: str) -> None: def delete_entity(self, entity_id: str) -> None:
"""删除实体及其关联数据""" """删除实体及其关联数据"""
conn = self.get_conn() conn = self.get_conn()
conn.execute("DELETE FROM entity_mentions WHERE entity_id = ?", (entity_id,)) conn.execute("DELETE FROM entity_mentions WHERE entity_id = ?", (entity_id, ))
conn.execute( conn.execute(
"DELETE FROM entity_relations WHERE source_entity_id = ? OR target_entity_id = ?", "DELETE FROM entity_relations WHERE source_entity_id = ? OR target_entity_id = ?",
(entity_id, entity_id), (entity_id, entity_id),
) )
conn.execute("DELETE FROM entity_attributes WHERE entity_id = ?", (entity_id,)) conn.execute("DELETE FROM entity_attributes WHERE entity_id = ?", (entity_id, ))
conn.execute("DELETE FROM entities WHERE id = ?", (entity_id,)) conn.execute("DELETE FROM entities WHERE id = ?", (entity_id, ))
conn.commit() conn.commit()
conn.close() conn.close()
@@ -352,7 +352,7 @@ class DatabaseManager:
conn = self.get_conn() conn = self.get_conn()
rows = conn.execute( rows = conn.execute(
"SELECT * FROM entity_mentions WHERE entity_id = ? ORDER BY transcript_id, start_pos", "SELECT * FROM entity_mentions WHERE entity_id = ? ORDER BY transcript_id, start_pos",
(entity_id,), (entity_id, ),
).fetchall() ).fetchall()
conn.close() conn.close()
return [EntityMention(**dict(r)) for r in rows] return [EntityMention(**dict(r)) for r in rows]
@@ -366,7 +366,7 @@ class DatabaseManager:
filename: str, filename: str,
full_text: str, full_text: str,
transcript_type: str = "audio", transcript_type: str = "audio",
): ) -> None:
conn = self.get_conn() conn = self.get_conn()
now = datetime.now().isoformat() now = datetime.now().isoformat()
conn.execute( conn.execute(
@@ -380,14 +380,14 @@ class DatabaseManager:
def get_transcript(self, transcript_id: str) -> dict | None: def get_transcript(self, transcript_id: str) -> dict | None:
conn = self.get_conn() conn = self.get_conn()
row = conn.execute("SELECT * FROM transcripts WHERE id = ?", (transcript_id,)).fetchone() row = conn.execute("SELECT * FROM transcripts WHERE id = ?", (transcript_id, )).fetchone()
conn.close() conn.close()
return dict(row) if row else None return dict(row) if row else None
def list_project_transcripts(self, project_id: str) -> list[dict]: def list_project_transcripts(self, project_id: str) -> list[dict]:
conn = self.get_conn() conn = self.get_conn()
rows = conn.execute( rows = conn.execute(
"SELECT * FROM transcripts WHERE project_id = ? ORDER BY created_at DESC", (project_id,) "SELECT * FROM transcripts WHERE project_id = ? ORDER BY created_at DESC", (project_id, )
).fetchall() ).fetchall()
conn.close() conn.close()
return [dict(r) for r in rows] return [dict(r) for r in rows]
@@ -400,7 +400,7 @@ class DatabaseManager:
(full_text, now, transcript_id), (full_text, now, transcript_id),
) )
conn.commit() conn.commit()
row = conn.execute("SELECT * FROM transcripts WHERE id = ?", (transcript_id,)).fetchone() row = conn.execute("SELECT * FROM transcripts WHERE id = ?", (transcript_id, )).fetchone()
conn.close() conn.close()
return dict(row) if row else None return dict(row) if row else None
@@ -414,7 +414,7 @@ class DatabaseManager:
relation_type: str = "related", relation_type: str = "related",
evidence: str = "", evidence: str = "",
transcript_id: str = "", transcript_id: str = "",
): ) -> None:
conn = self.get_conn() conn = self.get_conn()
relation_id = str(uuid.uuid4())[:UUID_LENGTH] relation_id = str(uuid.uuid4())[:UUID_LENGTH]
now = datetime.now().isoformat() now = datetime.now().isoformat()
@@ -453,7 +453,7 @@ class DatabaseManager:
conn = self.get_conn() conn = self.get_conn()
rows = conn.execute( rows = conn.execute(
"SELECT * FROM entity_relations WHERE project_id = ? ORDER BY created_at DESC", "SELECT * FROM entity_relations WHERE project_id = ? ORDER BY created_at DESC",
(project_id,), (project_id, ),
).fetchall() ).fetchall()
conn.close() conn.close()
return [dict(r) for r in rows] return [dict(r) for r in rows]
@@ -475,13 +475,13 @@ class DatabaseManager:
conn.execute(query, values) conn.execute(query, values)
conn.commit() conn.commit()
row = conn.execute("SELECT * FROM entity_relations WHERE id = ?", (relation_id,)).fetchone() row = conn.execute("SELECT * FROM entity_relations WHERE id = ?", (relation_id, )).fetchone()
conn.close() conn.close()
return dict(row) if row else None return dict(row) if row else None
def delete_relation(self, relation_id: str): def delete_relation(self, relation_id: str) -> None:
conn = self.get_conn() conn = self.get_conn()
conn.execute("DELETE FROM entity_relations WHERE id = ?", (relation_id,)) conn.execute("DELETE FROM entity_relations WHERE id = ?", (relation_id, ))
conn.commit() conn.commit()
conn.close() conn.close()
@@ -495,7 +495,7 @@ class DatabaseManager:
if existing: if existing:
conn.execute( conn.execute(
"UPDATE glossary SET frequency = frequency + 1 WHERE id = ?", (existing["id"],) "UPDATE glossary SET frequency = frequency + 1 WHERE id = ?", (existing["id"], )
) )
conn.commit() conn.commit()
conn.close() conn.close()
@@ -515,14 +515,14 @@ class DatabaseManager:
def list_glossary(self, project_id: str) -> list[dict]: def list_glossary(self, project_id: str) -> list[dict]:
conn = self.get_conn() conn = self.get_conn()
rows = conn.execute( rows = conn.execute(
"SELECT * FROM glossary WHERE project_id = ? ORDER BY frequency DESC", (project_id,) "SELECT * FROM glossary WHERE project_id = ? ORDER BY frequency DESC", (project_id, )
).fetchall() ).fetchall()
conn.close() conn.close()
return [dict(r) for r in rows] return [dict(r) for r in rows]
def delete_glossary_term(self, term_id: str): def delete_glossary_term(self, term_id: str) -> None:
conn = self.get_conn() conn = self.get_conn()
conn.execute("DELETE FROM glossary WHERE id = ?", (term_id,)) conn.execute("DELETE FROM glossary WHERE id = ?", (term_id, ))
conn.commit() conn.commit()
conn.close() conn.close()
@@ -539,14 +539,14 @@ class DatabaseManager:
JOIN entities t ON r.target_entity_id = t.id JOIN entities t ON r.target_entity_id = t.id
LEFT JOIN transcripts tr ON r.transcript_id = tr.id LEFT JOIN transcripts tr ON r.transcript_id = tr.id
WHERE r.id = ?""", WHERE r.id = ?""",
(relation_id,), (relation_id, ),
).fetchone() ).fetchone()
conn.close() conn.close()
return dict(row) if row else None return dict(row) if row else None
def get_entity_with_mentions(self, entity_id: str) -> dict | None: def get_entity_with_mentions(self, entity_id: str) -> dict | None:
conn = self.get_conn() conn = self.get_conn()
entity_row = conn.execute("SELECT * FROM entities WHERE id = ?", (entity_id,)).fetchone() entity_row = conn.execute("SELECT * FROM entities WHERE id = ?", (entity_id, )).fetchone()
if not entity_row: if not entity_row:
conn.close() conn.close()
return None return None
@@ -559,7 +559,7 @@ class DatabaseManager:
FROM entity_mentions m FROM entity_mentions m
JOIN transcripts t ON m.transcript_id = t.id JOIN transcripts t ON m.transcript_id = t.id
WHERE m.entity_id = ? ORDER BY t.created_at, m.start_pos""", WHERE m.entity_id = ? ORDER BY t.created_at, m.start_pos""",
(entity_id,), (entity_id, ),
).fetchall() ).fetchall()
entity["mentions"] = [dict(m) for m in mentions] entity["mentions"] = [dict(m) for m in mentions]
entity["mention_count"] = len(mentions) entity["mention_count"] = len(mentions)
@@ -598,24 +598,24 @@ class DatabaseManager:
def get_project_summary(self, project_id: str) -> dict: def get_project_summary(self, project_id: str) -> dict:
conn = self.get_conn() conn = self.get_conn()
project = conn.execute("SELECT * FROM projects WHERE id = ?", (project_id,)).fetchone() project = conn.execute("SELECT * FROM projects WHERE id = ?", (project_id, )).fetchone()
entity_count = conn.execute( entity_count = conn.execute(
"SELECT COUNT(*) as count FROM entities WHERE project_id = ?", (project_id,) "SELECT COUNT(*) as count FROM entities WHERE project_id = ?", (project_id, )
).fetchone()["count"] ).fetchone()["count"]
transcript_count = conn.execute( transcript_count = conn.execute(
"SELECT COUNT(*) as count FROM transcripts WHERE project_id = ?", (project_id,) "SELECT COUNT(*) as count FROM transcripts WHERE project_id = ?", (project_id, )
).fetchone()["count"] ).fetchone()["count"]
relation_count = conn.execute( relation_count = conn.execute(
"SELECT COUNT(*) as count FROM entity_relations WHERE project_id = ?", (project_id,) "SELECT COUNT(*) as count FROM entity_relations WHERE project_id = ?", (project_id, )
).fetchone()["count"] ).fetchone()["count"]
recent_transcripts = conn.execute( recent_transcripts = conn.execute(
"""SELECT filename, full_text, created_at FROM transcripts """SELECT filename, full_text, created_at FROM transcripts
WHERE project_id = ? ORDER BY created_at DESC LIMIT 5""", WHERE project_id = ? ORDER BY created_at DESC LIMIT 5""",
(project_id,), (project_id, ),
).fetchall() ).fetchall()
top_entities = conn.execute( top_entities = conn.execute(
@@ -624,7 +624,7 @@ class DatabaseManager:
LEFT JOIN entity_mentions m ON e.id = m.entity_id LEFT JOIN entity_mentions m ON e.id = m.entity_id
WHERE e.project_id = ? WHERE e.project_id = ?
GROUP BY e.id ORDER BY mention_count DESC LIMIT 10""", GROUP BY e.id ORDER BY mention_count DESC LIMIT 10""",
(project_id,), (project_id, ),
).fetchall() ).fetchall()
conn.close() conn.close()
@@ -645,7 +645,7 @@ class DatabaseManager:
) -> str: ) -> str:
conn = self.get_conn() conn = self.get_conn()
row = conn.execute( row = conn.execute(
"SELECT full_text FROM transcripts WHERE id = ?", (transcript_id,) "SELECT full_text FROM transcripts WHERE id = ?", (transcript_id, )
).fetchone() ).fetchone()
conn.close() conn.close()
if not row: if not row:
@@ -708,7 +708,7 @@ class DatabaseManager:
) )
conn.close() conn.close()
timeline_events.sort(key=lambda x: x["event_date"]) timeline_events.sort(key = lambda x: x["event_date"])
return timeline_events return timeline_events
def get_entity_timeline_summary(self, project_id: str) -> dict: def get_entity_timeline_summary(self, project_id: str) -> dict:
@@ -719,7 +719,7 @@ class DatabaseManager:
FROM entity_mentions m FROM entity_mentions m
JOIN transcripts t ON m.transcript_id = t.id JOIN transcripts t ON m.transcript_id = t.id
WHERE t.project_id = ? GROUP BY DATE(t.created_at) ORDER BY date""", WHERE t.project_id = ? GROUP BY DATE(t.created_at) ORDER BY date""",
(project_id,), (project_id, ),
).fetchall() ).fetchall()
entity_stats = conn.execute( entity_stats = conn.execute(
@@ -731,7 +731,7 @@ class DatabaseManager:
LEFT JOIN transcripts t ON m.transcript_id = t.id LEFT JOIN transcripts t ON m.transcript_id = t.id
WHERE e.project_id = ? WHERE e.project_id = ?
GROUP BY e.id ORDER BY mention_count DESC LIMIT 20""", GROUP BY e.id ORDER BY mention_count DESC LIMIT 20""",
(project_id,), (project_id, ),
).fetchall() ).fetchall()
conn.close() conn.close()
@@ -772,7 +772,7 @@ class DatabaseManager:
def get_attribute_template(self, template_id: str) -> AttributeTemplate | None: def get_attribute_template(self, template_id: str) -> AttributeTemplate | None:
conn = self.get_conn() conn = self.get_conn()
row = conn.execute( row = conn.execute(
"SELECT * FROM attribute_templates WHERE id = ?", (template_id,) "SELECT * FROM attribute_templates WHERE id = ?", (template_id, )
).fetchone() ).fetchone()
conn.close() conn.close()
if row: if row:
@@ -786,7 +786,7 @@ class DatabaseManager:
rows = conn.execute( rows = conn.execute(
"""SELECT * FROM attribute_templates WHERE project_id = ? """SELECT * FROM attribute_templates WHERE project_id = ?
ORDER BY sort_order, created_at""", ORDER BY sort_order, created_at""",
(project_id,), (project_id, ),
).fetchall() ).fetchall()
conn.close() conn.close()
@@ -830,9 +830,9 @@ class DatabaseManager:
conn.close() conn.close()
return self.get_attribute_template(template_id) return self.get_attribute_template(template_id)
def delete_attribute_template(self, template_id: str): def delete_attribute_template(self, template_id: str) -> None:
conn = self.get_conn() conn = self.get_conn()
conn.execute("DELETE FROM attribute_templates WHERE id = ?", (template_id,)) conn.execute("DELETE FROM attribute_templates WHERE id = ?", (template_id, ))
conn.commit() conn.commit()
conn.close() conn.close()
@@ -905,7 +905,7 @@ class DatabaseManager:
FROM entity_attributes ea FROM entity_attributes ea
LEFT JOIN attribute_templates at ON ea.template_id = at.id LEFT JOIN attribute_templates at ON ea.template_id = at.id
WHERE ea.entity_id = ? ORDER BY ea.created_at""", WHERE ea.entity_id = ? ORDER BY ea.created_at""",
(entity_id,), (entity_id, ),
).fetchall() ).fetchall()
conn.close() conn.close()
return [EntityAttribute(**dict(r)) for r in rows] return [EntityAttribute(**dict(r)) for r in rows]
@@ -927,7 +927,7 @@ class DatabaseManager:
def delete_entity_attribute( def delete_entity_attribute(
self, entity_id: str, template_id: str, changed_by: str = "system", change_reason: str = "" self, entity_id: str, template_id: str, changed_by: str = "system", change_reason: str = ""
): ) -> None:
conn = self.get_conn() conn = self.get_conn()
old_row = conn.execute( old_row = conn.execute(
"""SELECT value FROM entity_attributes """SELECT value FROM entity_attributes
@@ -973,7 +973,7 @@ class DatabaseManager:
conditions.append("ah.template_id = ?") conditions.append("ah.template_id = ?")
params.append(template_id) params.append(template_id)
where_clause = " AND ".join(conditions) if conditions else "1=1" where_clause = " AND ".join(conditions) if conditions else "1 = 1"
rows = conn.execute( rows = conn.execute(
f"""SELECT ah.* f"""SELECT ah.*
@@ -997,7 +997,7 @@ class DatabaseManager:
return [] return []
conn = self.get_conn() conn = self.get_conn()
placeholders = ",".join(["?" for _ in entity_ids]) placeholders = ", ".join(["?" for _ in entity_ids])
rows = conn.execute( rows = conn.execute(
f"""SELECT ea.*, at.name as template_name f"""SELECT ea.*, at.name as template_name
FROM entity_attributes ea FROM entity_attributes ea
@@ -1075,7 +1075,7 @@ class DatabaseManager:
def get_video(self, video_id: str) -> dict | None: def get_video(self, video_id: str) -> dict | None:
"""获取视频信息""" """获取视频信息"""
conn = self.get_conn() conn = self.get_conn()
row = conn.execute("SELECT * FROM videos WHERE id = ?", (video_id,)).fetchone() row = conn.execute("SELECT * FROM videos WHERE id = ?", (video_id, )).fetchone()
conn.close() conn.close()
if row: if row:
@@ -1094,7 +1094,7 @@ class DatabaseManager:
"""获取项目的所有视频""" """获取项目的所有视频"""
conn = self.get_conn() conn = self.get_conn()
rows = conn.execute( rows = conn.execute(
"SELECT * FROM videos WHERE project_id = ? ORDER BY created_at DESC", (project_id,) "SELECT * FROM videos WHERE project_id = ? ORDER BY created_at DESC", (project_id, )
).fetchall() ).fetchall()
conn.close() conn.close()
@@ -1149,7 +1149,7 @@ class DatabaseManager:
"""获取视频的所有帧""" """获取视频的所有帧"""
conn = self.get_conn() conn = self.get_conn()
rows = conn.execute( rows = conn.execute(
"""SELECT * FROM video_frames WHERE video_id = ? ORDER BY timestamp""", (video_id,) """SELECT * FROM video_frames WHERE video_id = ? ORDER BY timestamp""", (video_id, )
).fetchall() ).fetchall()
conn.close() conn.close()
@@ -1201,7 +1201,7 @@ class DatabaseManager:
def get_image(self, image_id: str) -> dict | None: def get_image(self, image_id: str) -> dict | None:
"""获取图片信息""" """获取图片信息"""
conn = self.get_conn() conn = self.get_conn()
row = conn.execute("SELECT * FROM images WHERE id = ?", (image_id,)).fetchone() row = conn.execute("SELECT * FROM images WHERE id = ?", (image_id, )).fetchone()
conn.close() conn.close()
if row: if row:
@@ -1219,7 +1219,7 @@ class DatabaseManager:
"""获取项目的所有图片""" """获取项目的所有图片"""
conn = self.get_conn() conn = self.get_conn()
rows = conn.execute( rows = conn.execute(
"SELECT * FROM images WHERE project_id = ? ORDER BY created_at DESC", (project_id,) "SELECT * FROM images WHERE project_id = ? ORDER BY created_at DESC", (project_id, )
).fetchall() ).fetchall()
conn.close() conn.close()
@@ -1279,7 +1279,7 @@ class DatabaseManager:
FROM multimodal_mentions m FROM multimodal_mentions m
JOIN entities e ON m.entity_id = e.id JOIN entities e ON m.entity_id = e.id
WHERE m.entity_id = ? ORDER BY m.created_at DESC""", WHERE m.entity_id = ? ORDER BY m.created_at DESC""",
(entity_id,), (entity_id, ),
).fetchall() ).fetchall()
conn.close() conn.close()
return [dict(r) for r in rows] return [dict(r) for r in rows]
@@ -1303,7 +1303,7 @@ class DatabaseManager:
FROM multimodal_mentions m FROM multimodal_mentions m
JOIN entities e ON m.entity_id = e.id JOIN entities e ON m.entity_id = e.id
WHERE m.project_id = ? ORDER BY m.created_at DESC""", WHERE m.project_id = ? ORDER BY m.created_at DESC""",
(project_id,), (project_id, ),
).fetchall() ).fetchall()
conn.close() conn.close()
@@ -1377,13 +1377,13 @@ class DatabaseManager:
# 视频数量 # 视频数量
row = conn.execute( row = conn.execute(
"SELECT COUNT(*) as count FROM videos WHERE project_id = ?", (project_id,) "SELECT COUNT(*) as count FROM videos WHERE project_id = ?", (project_id, )
).fetchone() ).fetchone()
stats["video_count"] = row["count"] stats["video_count"] = row["count"]
# 图片数量 # 图片数量
row = conn.execute( row = conn.execute(
"SELECT COUNT(*) as count FROM images WHERE project_id = ?", (project_id,) "SELECT COUNT(*) as count FROM images WHERE project_id = ?", (project_id, )
).fetchone() ).fetchone()
stats["image_count"] = row["count"] stats["image_count"] = row["count"]
@@ -1391,7 +1391,7 @@ class DatabaseManager:
row = conn.execute( row = conn.execute(
"""SELECT COUNT(DISTINCT entity_id) as count """SELECT COUNT(DISTINCT entity_id) as count
FROM multimodal_mentions WHERE project_id = ?""", FROM multimodal_mentions WHERE project_id = ?""",
(project_id,), (project_id, ),
).fetchone() ).fetchone()
stats["multimodal_entity_count"] = row["count"] stats["multimodal_entity_count"] = row["count"]
@@ -1399,7 +1399,7 @@ class DatabaseManager:
row = conn.execute( row = conn.execute(
"""SELECT COUNT(*) as count FROM multimodal_entity_links """SELECT COUNT(*) as count FROM multimodal_entity_links
WHERE entity_id IN (SELECT id FROM entities WHERE project_id = ?)""", WHERE entity_id IN (SELECT id FROM entities WHERE project_id = ?)""",
(project_id,), (project_id, ),
).fetchone() ).fetchone()
stats["cross_modal_links"] = row["count"] stats["cross_modal_links"] = row["count"]

File diff suppressed because it is too large Load Diff

View File

@@ -11,7 +11,7 @@ import os
class DocumentProcessor: class DocumentProcessor:
"""文档处理器 - 提取 PDF/DOCX 文本""" """文档处理器 - 提取 PDF/DOCX 文本"""
def __init__(self): def __init__(self) -> None:
self.supported_formats = { self.supported_formats = {
".pdf": self._extract_pdf, ".pdf": self._extract_pdf,
".docx": self._extract_docx, ".docx": self._extract_docx,
@@ -123,7 +123,7 @@ class DocumentProcessor:
continue continue
# 如果都失败了,使用 latin-1 并忽略错误 # 如果都失败了,使用 latin-1 并忽略错误
return content.decode("latin-1", errors="ignore") return content.decode("latin-1", errors = "ignore")
def _clean_text(self, text: str) -> str: def _clean_text(self, text: str) -> str:
"""清理提取的文本""" """清理提取的文本"""
@@ -173,7 +173,7 @@ class SimpleTextExtractor:
except UnicodeDecodeError: except UnicodeDecodeError:
continue continue
return content.decode("latin-1", errors="ignore") return content.decode("latin-1", errors = "ignore")
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -329,7 +329,7 @@ class EnterpriseManager:
], ],
} }
def __init__(self, db_path: str = "insightflow.db"): def __init__(self, db_path: str = "insightflow.db") -> None:
self.db_path = db_path self.db_path = db_path
self._init_db() self._init_db()
@@ -610,30 +610,30 @@ class EnterpriseManager:
attribute_mapping = self.DEFAULT_ATTRIBUTE_MAPPING[SSOProvider(provider)] attribute_mapping = self.DEFAULT_ATTRIBUTE_MAPPING[SSOProvider(provider)]
config = SSOConfig( config = SSOConfig(
id=config_id, id = config_id,
tenant_id=tenant_id, tenant_id = tenant_id,
provider=provider, provider = provider,
status=SSOStatus.PENDING.value, status = SSOStatus.PENDING.value,
entity_id=entity_id, entity_id = entity_id,
sso_url=sso_url, sso_url = sso_url,
slo_url=slo_url, slo_url = slo_url,
certificate=certificate, certificate = certificate,
metadata_url=metadata_url, metadata_url = metadata_url,
metadata_xml=metadata_xml, metadata_xml = metadata_xml,
client_id=client_id, client_id = client_id,
client_secret=client_secret, client_secret = client_secret,
authorization_url=authorization_url, authorization_url = authorization_url,
token_url=token_url, token_url = token_url,
userinfo_url=userinfo_url, userinfo_url = userinfo_url,
scopes=scopes or ["openid", "email", "profile"], scopes = scopes or ["openid", "email", "profile"],
attribute_mapping=attribute_mapping or {}, attribute_mapping = attribute_mapping or {},
auto_provision=auto_provision, auto_provision = auto_provision,
default_role=default_role, default_role = default_role,
domain_restriction=domain_restriction or [], domain_restriction = domain_restriction or [],
created_at=now, created_at = now,
updated_at=now, updated_at = now,
last_tested_at=None, last_tested_at = None,
last_error=None, last_error = None,
) )
cursor = conn.cursor() cursor = conn.cursor()
@@ -688,7 +688,7 @@ class EnterpriseManager:
conn = self._get_connection() conn = self._get_connection()
try: try:
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute("SELECT * FROM sso_configs WHERE id = ?", (config_id,)) cursor.execute("SELECT * FROM sso_configs WHERE id = ?", (config_id, ))
row = cursor.fetchone() row = cursor.fetchone()
if row: if row:
@@ -722,7 +722,7 @@ class EnterpriseManager:
WHERE tenant_id = ? AND status = 'active' WHERE tenant_id = ? AND status = 'active'
ORDER BY created_at DESC LIMIT 1 ORDER BY created_at DESC LIMIT 1
""", """,
(tenant_id,), (tenant_id, ),
) )
row = cursor.fetchone() row = cursor.fetchone()
@@ -802,7 +802,7 @@ class EnterpriseManager:
conn = self._get_connection() conn = self._get_connection()
try: try:
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute("DELETE FROM sso_configs WHERE id = ?", (config_id,)) cursor.execute("DELETE FROM sso_configs WHERE id = ?", (config_id, ))
conn.commit() conn.commit()
return cursor.rowcount > 0 return cursor.rowcount > 0
finally: finally:
@@ -818,7 +818,7 @@ class EnterpriseManager:
SELECT * FROM sso_configs WHERE tenant_id = ? SELECT * FROM sso_configs WHERE tenant_id = ?
ORDER BY created_at DESC ORDER BY created_at DESC
""", """,
(tenant_id,), (tenant_id, ),
) )
rows = cursor.fetchall() rows = cursor.fetchall()
@@ -841,30 +841,30 @@ class EnterpriseManager:
# 生成 X.509 证书(简化实现,实际应该生成真实的密钥对) # 生成 X.509 证书(简化实现,实际应该生成真实的密钥对)
cert = config.certificate or self._generate_self_signed_cert() cert = config.certificate or self._generate_self_signed_cert()
metadata = f"""<?xml version="1.0" encoding="UTF-8"?> metadata = f"""<?xml version = "1.0" encoding = "UTF-8"?>
<md:EntityDescriptor xmlns:md="urn:oasis:names:tc:SAML:2.0:metadata" <md:EntityDescriptor xmlns:md = "urn:oasis:names:tc:SAML:2.0:metadata"
entityID="{sp_entity_id}"> entityID = "{sp_entity_id}">
<md:SPSSODescriptor AuthnRequestsSigned="true" <md:SPSSODescriptor AuthnRequestsSigned = "true"
WantAssertionsSigned="true" WantAssertionsSigned = "true"
protocolSupportEnumeration="urn:oasis:names:tc:SAML:2.0:protocol"> protocolSupportEnumeration = "urn:oasis:names:tc:SAML:2.0:protocol">
<md:KeyDescriptor use="signing"> <md:KeyDescriptor use = "signing">
<ds:KeyInfo xmlns:ds="http://www.w3.org/2000/09/xmldsig#"> <ds:KeyInfo xmlns:ds = "http://www.w3.org/2000/09/xmldsig#">
<ds:X509Data> <ds:X509Data>
<ds:X509Certificate>{cert}</ds:X509Certificate> <ds:X509Certificate>{cert}</ds:X509Certificate>
</ds:X509Data> </ds:X509Data>
</ds:KeyInfo> </ds:KeyInfo>
</md:KeyDescriptor> </md:KeyDescriptor>
<md:SingleLogoutService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect" <md:SingleLogoutService Binding = "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect"
Location="{slo_url}"/> Location = "{slo_url}"/>
<md:AssertionConsumerService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST" <md:AssertionConsumerService Binding = "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST"
Location="{acs_url}" Location = "{acs_url}"
index="0" index = "0"
isDefault="true"/> isDefault = "true"/>
</md:SPSSODescriptor> </md:SPSSODescriptor>
<md:Organization> <md:Organization>
<md:OrganizationName xml:lang="en">InsightFlow</md:OrganizationName> <md:OrganizationName xml:lang = "en">InsightFlow</md:OrganizationName>
<md:OrganizationDisplayName xml:lang="en">InsightFlow</md:OrganizationDisplayName> <md:OrganizationDisplayName xml:lang = "en">InsightFlow</md:OrganizationDisplayName>
<md:OrganizationURL xml:lang="en">{base_url}</md:OrganizationURL> <md:OrganizationURL xml:lang = "en">{base_url}</md:OrganizationURL>
</md:Organization> </md:Organization>
</md:EntityDescriptor>""" </md:EntityDescriptor>"""
@@ -878,18 +878,18 @@ class EnterpriseManager:
try: try:
request_id = f"_{uuid.uuid4().hex}" request_id = f"_{uuid.uuid4().hex}"
now = datetime.now() now = datetime.now()
expires = now + timedelta(minutes=10) expires = now + timedelta(minutes = 10)
auth_request = SAMLAuthRequest( auth_request = SAMLAuthRequest(
id=str(uuid.uuid4()), id = str(uuid.uuid4()),
tenant_id=tenant_id, tenant_id = tenant_id,
sso_config_id=config_id, sso_config_id = config_id,
request_id=request_id, request_id = request_id,
relay_state=relay_state, relay_state = relay_state,
created_at=now, created_at = now,
expires_at=expires, expires_at = expires,
used=False, used = False,
used_at=None, used_at = None,
) )
cursor = conn.cursor() cursor = conn.cursor()
@@ -926,7 +926,7 @@ class EnterpriseManager:
""" """
SELECT * FROM saml_auth_requests WHERE request_id = ? SELECT * FROM saml_auth_requests WHERE request_id = ?
""", """,
(request_id,), (request_id, ),
) )
row = cursor.fetchone() row = cursor.fetchone()
@@ -949,17 +949,17 @@ class EnterpriseManager:
attributes = self._parse_saml_response(saml_response) attributes = self._parse_saml_response(saml_response)
auth_response = SAMLAuthResponse( auth_response = SAMLAuthResponse(
id=str(uuid.uuid4()), id = str(uuid.uuid4()),
request_id=request_id, request_id = request_id,
tenant_id="", # 从 request 获取 tenant_id = "", # 从 request 获取
user_id=None, user_id = None,
email=attributes.get("email"), email = attributes.get("email"),
name=attributes.get("name"), name = attributes.get("name"),
attributes=attributes, attributes = attributes,
session_index=attributes.get("session_index"), session_index = attributes.get("session_index"),
processed=False, processed = False,
processed_at=None, processed_at = None,
created_at=datetime.now(), created_at = datetime.now(),
) )
cursor = conn.cursor() cursor = conn.cursor()
@@ -1028,21 +1028,21 @@ class EnterpriseManager:
now = datetime.now() now = datetime.now()
config = SCIMConfig( config = SCIMConfig(
id=config_id, id = config_id,
tenant_id=tenant_id, tenant_id = tenant_id,
provider=provider, provider = provider,
status="disabled", status = "disabled",
scim_base_url=scim_base_url, scim_base_url = scim_base_url,
scim_token=scim_token, scim_token = scim_token,
sync_interval_minutes=sync_interval_minutes, sync_interval_minutes = sync_interval_minutes,
last_sync_at=None, last_sync_at = None,
last_sync_status=None, last_sync_status = None,
last_sync_error=None, last_sync_error = None,
last_sync_users_count=0, last_sync_users_count = 0,
attribute_mapping=attribute_mapping or {}, attribute_mapping = attribute_mapping or {},
sync_rules=sync_rules or {}, sync_rules = sync_rules or {},
created_at=now, created_at = now,
updated_at=now, updated_at = now,
) )
cursor = conn.cursor() cursor = conn.cursor()
@@ -1084,7 +1084,7 @@ class EnterpriseManager:
conn = self._get_connection() conn = self._get_connection()
try: try:
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute("SELECT * FROM scim_configs WHERE id = ?", (config_id,)) cursor.execute("SELECT * FROM scim_configs WHERE id = ?", (config_id, ))
row = cursor.fetchone() row = cursor.fetchone()
if row: if row:
@@ -1104,7 +1104,7 @@ class EnterpriseManager:
SELECT * FROM scim_configs WHERE tenant_id = ? SELECT * FROM scim_configs WHERE tenant_id = ?
ORDER BY created_at DESC LIMIT 1 ORDER BY created_at DESC LIMIT 1
""", """,
(tenant_id,), (tenant_id, ),
) )
row = cursor.fetchone() row = cursor.fetchone()
@@ -1325,28 +1325,28 @@ class EnterpriseManager:
now = datetime.now() now = datetime.now()
# 默认7天后过期 # 默认7天后过期
expires_at = now + timedelta(days=7) expires_at = now + timedelta(days = 7)
export = AuditLogExport( export = AuditLogExport(
id=export_id, id = export_id,
tenant_id=tenant_id, tenant_id = tenant_id,
export_format=export_format, export_format = export_format,
start_date=start_date, start_date = start_date,
end_date=end_date, end_date = end_date,
filters=filters or {}, filters = filters or {},
compliance_standard=compliance_standard, compliance_standard = compliance_standard,
status="pending", status = "pending",
file_path=None, file_path = None,
file_size=None, file_size = None,
record_count=None, record_count = None,
checksum=None, checksum = None,
downloaded_by=None, downloaded_by = None,
downloaded_at=None, downloaded_at = None,
expires_at=expires_at, expires_at = expires_at,
created_by=created_by, created_by = created_by,
created_at=now, created_at = now,
completed_at=None, completed_at = None,
error_message=None, error_message = None,
) )
cursor = conn.cursor() cursor = conn.cursor()
@@ -1383,7 +1383,7 @@ class EnterpriseManager:
finally: finally:
conn.close() conn.close()
def process_audit_export(self, export_id: str, db_manager=None) -> AuditLogExport | None: def process_audit_export(self, export_id: str, db_manager = None) -> AuditLogExport | None:
"""处理审计日志导出任务""" """处理审计日志导出任务"""
export = self.get_audit_export(export_id) export = self.get_audit_export(export_id)
if not export: if not export:
@@ -1398,7 +1398,7 @@ class EnterpriseManager:
UPDATE audit_log_exports SET status = 'processing' UPDATE audit_log_exports SET status = 'processing'
WHERE id = ? WHERE id = ?
""", """,
(export_id,), (export_id, ),
) )
conn.commit() conn.commit()
@@ -1454,7 +1454,7 @@ class EnterpriseManager:
start_date: datetime, start_date: datetime,
end_date: datetime, end_date: datetime,
filters: dict[str, Any], filters: dict[str, Any],
db_manager=None, db_manager = None,
) -> list[dict[str, Any]]: ) -> list[dict[str, Any]]:
"""获取审计日志数据""" """获取审计日志数据"""
if db_manager is None: if db_manager is None:
@@ -1488,26 +1488,26 @@ class EnterpriseManager:
import os import os
export_dir = "/tmp/insightflow/exports" export_dir = "/tmp/insightflow/exports"
os.makedirs(export_dir, exist_ok=True) os.makedirs(export_dir, exist_ok = True)
file_path = f"{export_dir}/audit_export_{export_id}.{format}" file_path = f"{export_dir}/audit_export_{export_id}.{format}"
if format == "json": if format == "json":
content = json.dumps(logs, ensure_ascii=False, indent=2) content = json.dumps(logs, ensure_ascii = False, indent = 2)
with open(file_path, "w", encoding="utf-8") as f: with open(file_path, "w", encoding = "utf-8") as f:
f.write(content) f.write(content)
elif format == "csv": elif format == "csv":
import csv import csv
if logs: if logs:
with open(file_path, "w", newline="", encoding="utf-8") as f: with open(file_path, "w", newline = "", encoding = "utf-8") as f:
writer = csv.DictWriter(f, fieldnames=logs[0].keys()) writer = csv.DictWriter(f, fieldnames = logs[0].keys())
writer.writeheader() writer.writeheader()
writer.writerows(logs) writer.writerows(logs)
else: else:
# 其他格式暂不支持 # 其他格式暂不支持
content = json.dumps(logs, ensure_ascii=False) content = json.dumps(logs, ensure_ascii = False)
with open(file_path, "w", encoding="utf-8") as f: with open(file_path, "w", encoding = "utf-8") as f:
f.write(content) f.write(content)
file_size = os.path.getsize(file_path) file_size = os.path.getsize(file_path)
@@ -1523,7 +1523,7 @@ class EnterpriseManager:
conn = self._get_connection() conn = self._get_connection()
try: try:
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute("SELECT * FROM audit_log_exports WHERE id = ?", (export_id,)) cursor.execute("SELECT * FROM audit_log_exports WHERE id = ?", (export_id, ))
row = cursor.fetchone() row = cursor.fetchone()
if row: if row:
@@ -1596,24 +1596,24 @@ class EnterpriseManager:
now = datetime.now() now = datetime.now()
policy = DataRetentionPolicy( policy = DataRetentionPolicy(
id=policy_id, id = policy_id,
tenant_id=tenant_id, tenant_id = tenant_id,
name=name, name = name,
description=description, description = description,
resource_type=resource_type, resource_type = resource_type,
retention_days=retention_days, retention_days = retention_days,
action=action, action = action,
conditions=conditions or {}, conditions = conditions or {},
auto_execute=auto_execute, auto_execute = auto_execute,
execute_at=execute_at, execute_at = execute_at,
notify_before_days=notify_before_days, notify_before_days = notify_before_days,
archive_location=archive_location, archive_location = archive_location,
archive_encryption=archive_encryption, archive_encryption = archive_encryption,
is_active=True, is_active = True,
last_executed_at=None, last_executed_at = None,
last_execution_result=None, last_execution_result = None,
created_at=now, created_at = now,
updated_at=now, updated_at = now,
) )
cursor = conn.cursor() cursor = conn.cursor()
@@ -1661,7 +1661,7 @@ class EnterpriseManager:
conn = self._get_connection() conn = self._get_connection()
try: try:
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute("SELECT * FROM data_retention_policies WHERE id = ?", (policy_id,)) cursor.execute("SELECT * FROM data_retention_policies WHERE id = ?", (policy_id, ))
row = cursor.fetchone() row = cursor.fetchone()
if row: if row:
@@ -1758,7 +1758,7 @@ class EnterpriseManager:
conn = self._get_connection() conn = self._get_connection()
try: try:
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute("DELETE FROM data_retention_policies WHERE id = ?", (policy_id,)) cursor.execute("DELETE FROM data_retention_policies WHERE id = ?", (policy_id, ))
conn.commit() conn.commit()
return cursor.rowcount > 0 return cursor.rowcount > 0
finally: finally:
@@ -1776,18 +1776,18 @@ class EnterpriseManager:
now = datetime.now() now = datetime.now()
job = DataRetentionJob( job = DataRetentionJob(
id=job_id, id = job_id,
policy_id=policy_id, policy_id = policy_id,
tenant_id=policy.tenant_id, tenant_id = policy.tenant_id,
status="running", status = "running",
started_at=now, started_at = now,
completed_at=None, completed_at = None,
affected_records=0, affected_records = 0,
archived_records=0, archived_records = 0,
deleted_records=0, deleted_records = 0,
error_count=0, error_count = 0,
details={}, details = {},
created_at=now, created_at = now,
) )
cursor = conn.cursor() cursor = conn.cursor()
@@ -1804,7 +1804,7 @@ class EnterpriseManager:
try: try:
# 计算截止日期 # 计算截止日期
cutoff_date = now - timedelta(days=policy.retention_days) cutoff_date = now - timedelta(days = policy.retention_days)
# 根据资源类型执行不同的处理 # 根据资源类型执行不同的处理
if policy.resource_type == "audit_log": if policy.resource_type == "audit_log":
@@ -1887,7 +1887,7 @@ class EnterpriseManager:
SELECT COUNT(*) as count FROM audit_logs SELECT COUNT(*) as count FROM audit_logs
WHERE created_at < ? WHERE created_at < ?
""", """,
(cutoff_date,), (cutoff_date, ),
) )
count = cursor.fetchone()["count"] count = cursor.fetchone()["count"]
@@ -1896,7 +1896,7 @@ class EnterpriseManager:
""" """
DELETE FROM audit_logs WHERE created_at < ? DELETE FROM audit_logs WHERE created_at < ?
""", """,
(cutoff_date,), (cutoff_date, ),
) )
deleted = cursor.rowcount deleted = cursor.rowcount
return {"affected": count, "archived": 0, "deleted": deleted, "errors": 0} return {"affected": count, "archived": 0, "deleted": deleted, "errors": 0}
@@ -1927,7 +1927,7 @@ class EnterpriseManager:
conn = self._get_connection() conn = self._get_connection()
try: try:
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute("SELECT * FROM data_retention_jobs WHERE id = ?", (job_id,)) cursor.execute("SELECT * FROM data_retention_jobs WHERE id = ?", (job_id, ))
row = cursor.fetchone() row = cursor.fetchone()
if row: if row:
@@ -1963,64 +1963,64 @@ class EnterpriseManager:
def _row_to_sso_config(self, row: sqlite3.Row) -> SSOConfig: def _row_to_sso_config(self, row: sqlite3.Row) -> SSOConfig:
"""数据库行转换为 SSOConfig 对象""" """数据库行转换为 SSOConfig 对象"""
return SSOConfig( return SSOConfig(
id=row["id"], id = row["id"],
tenant_id=row["tenant_id"], tenant_id = row["tenant_id"],
provider=row["provider"], provider = row["provider"],
status=row["status"], status = row["status"],
entity_id=row["entity_id"], entity_id = row["entity_id"],
sso_url=row["sso_url"], sso_url = row["sso_url"],
slo_url=row["slo_url"], slo_url = row["slo_url"],
certificate=row["certificate"], certificate = row["certificate"],
metadata_url=row["metadata_url"], metadata_url = row["metadata_url"],
metadata_xml=row["metadata_xml"], metadata_xml = row["metadata_xml"],
client_id=row["client_id"], client_id = row["client_id"],
client_secret=row["client_secret"], client_secret = row["client_secret"],
authorization_url=row["authorization_url"], authorization_url = row["authorization_url"],
token_url=row["token_url"], token_url = row["token_url"],
userinfo_url=row["userinfo_url"], userinfo_url = row["userinfo_url"],
scopes=json.loads(row["scopes"] or '["openid", "email", "profile"]'), scopes = json.loads(row["scopes"] or '["openid", "email", "profile"]'),
attribute_mapping=json.loads(row["attribute_mapping"] or "{}"), attribute_mapping = json.loads(row["attribute_mapping"] or "{}"),
auto_provision=bool(row["auto_provision"]), auto_provision = bool(row["auto_provision"]),
default_role=row["default_role"], default_role = row["default_role"],
domain_restriction=json.loads(row["domain_restriction"] or "[]"), domain_restriction = json.loads(row["domain_restriction"] or "[]"),
created_at=( created_at = (
datetime.fromisoformat(row["created_at"]) datetime.fromisoformat(row["created_at"])
if isinstance(row["created_at"], str) if isinstance(row["created_at"], str)
else row["created_at"] else row["created_at"]
), ),
updated_at=( updated_at = (
datetime.fromisoformat(row["updated_at"]) datetime.fromisoformat(row["updated_at"])
if isinstance(row["updated_at"], str) if isinstance(row["updated_at"], str)
else row["updated_at"] else row["updated_at"]
), ),
last_tested_at=( last_tested_at = (
datetime.fromisoformat(row["last_tested_at"]) datetime.fromisoformat(row["last_tested_at"])
if row["last_tested_at"] and isinstance(row["last_tested_at"], str) if row["last_tested_at"] and isinstance(row["last_tested_at"], str)
else row["last_tested_at"] else row["last_tested_at"]
), ),
last_error=row["last_error"], last_error = row["last_error"],
) )
def _row_to_saml_request(self, row: sqlite3.Row) -> SAMLAuthRequest: def _row_to_saml_request(self, row: sqlite3.Row) -> SAMLAuthRequest:
"""数据库行转换为 SAMLAuthRequest 对象""" """数据库行转换为 SAMLAuthRequest 对象"""
return SAMLAuthRequest( return SAMLAuthRequest(
id=row["id"], id = row["id"],
tenant_id=row["tenant_id"], tenant_id = row["tenant_id"],
sso_config_id=row["sso_config_id"], sso_config_id = row["sso_config_id"],
request_id=row["request_id"], request_id = row["request_id"],
relay_state=row["relay_state"], relay_state = row["relay_state"],
created_at=( created_at = (
datetime.fromisoformat(row["created_at"]) datetime.fromisoformat(row["created_at"])
if isinstance(row["created_at"], str) if isinstance(row["created_at"], str)
else row["created_at"] else row["created_at"]
), ),
expires_at=( expires_at = (
datetime.fromisoformat(row["expires_at"]) datetime.fromisoformat(row["expires_at"])
if isinstance(row["expires_at"], str) if isinstance(row["expires_at"], str)
else row["expires_at"] else row["expires_at"]
), ),
used=bool(row["used"]), used = bool(row["used"]),
used_at=( used_at = (
datetime.fromisoformat(row["used_at"]) datetime.fromisoformat(row["used_at"])
if row["used_at"] and isinstance(row["used_at"], str) if row["used_at"] and isinstance(row["used_at"], str)
else row["used_at"] else row["used_at"]
@@ -2030,29 +2030,29 @@ class EnterpriseManager:
def _row_to_scim_config(self, row: sqlite3.Row) -> SCIMConfig: def _row_to_scim_config(self, row: sqlite3.Row) -> SCIMConfig:
"""数据库行转换为 SCIMConfig 对象""" """数据库行转换为 SCIMConfig 对象"""
return SCIMConfig( return SCIMConfig(
id=row["id"], id = row["id"],
tenant_id=row["tenant_id"], tenant_id = row["tenant_id"],
provider=row["provider"], provider = row["provider"],
status=row["status"], status = row["status"],
scim_base_url=row["scim_base_url"], scim_base_url = row["scim_base_url"],
scim_token=row["scim_token"], scim_token = row["scim_token"],
sync_interval_minutes=row["sync_interval_minutes"], sync_interval_minutes = row["sync_interval_minutes"],
last_sync_at=( last_sync_at = (
datetime.fromisoformat(row["last_sync_at"]) datetime.fromisoformat(row["last_sync_at"])
if row["last_sync_at"] and isinstance(row["last_sync_at"], str) if row["last_sync_at"] and isinstance(row["last_sync_at"], str)
else row["last_sync_at"] else row["last_sync_at"]
), ),
last_sync_status=row["last_sync_status"], last_sync_status = row["last_sync_status"],
last_sync_error=row["last_sync_error"], last_sync_error = row["last_sync_error"],
last_sync_users_count=row["last_sync_users_count"], last_sync_users_count = row["last_sync_users_count"],
attribute_mapping=json.loads(row["attribute_mapping"] or "{}"), attribute_mapping = json.loads(row["attribute_mapping"] or "{}"),
sync_rules=json.loads(row["sync_rules"] or "{}"), sync_rules = json.loads(row["sync_rules"] or "{}"),
created_at=( created_at = (
datetime.fromisoformat(row["created_at"]) datetime.fromisoformat(row["created_at"])
if isinstance(row["created_at"], str) if isinstance(row["created_at"], str)
else row["created_at"] else row["created_at"]
), ),
updated_at=( updated_at = (
datetime.fromisoformat(row["updated_at"]) datetime.fromisoformat(row["updated_at"])
if isinstance(row["updated_at"], str) if isinstance(row["updated_at"], str)
else row["updated_at"] else row["updated_at"]
@@ -2062,28 +2062,28 @@ class EnterpriseManager:
def _row_to_scim_user(self, row: sqlite3.Row) -> SCIMUser: def _row_to_scim_user(self, row: sqlite3.Row) -> SCIMUser:
"""数据库行转换为 SCIMUser 对象""" """数据库行转换为 SCIMUser 对象"""
return SCIMUser( return SCIMUser(
id=row["id"], id = row["id"],
tenant_id=row["tenant_id"], tenant_id = row["tenant_id"],
external_id=row["external_id"], external_id = row["external_id"],
user_name=row["user_name"], user_name = row["user_name"],
email=row["email"], email = row["email"],
display_name=row["display_name"], display_name = row["display_name"],
given_name=row["given_name"], given_name = row["given_name"],
family_name=row["family_name"], family_name = row["family_name"],
active=bool(row["active"]), active = bool(row["active"]),
groups=json.loads(row["groups"] or "[]"), groups = json.loads(row["groups"] or "[]"),
raw_data=json.loads(row["raw_data"] or "{}"), raw_data = json.loads(row["raw_data"] or "{}"),
synced_at=( synced_at = (
datetime.fromisoformat(row["synced_at"]) datetime.fromisoformat(row["synced_at"])
if isinstance(row["synced_at"], str) if isinstance(row["synced_at"], str)
else row["synced_at"] else row["synced_at"]
), ),
created_at=( created_at = (
datetime.fromisoformat(row["created_at"]) datetime.fromisoformat(row["created_at"])
if isinstance(row["created_at"], str) if isinstance(row["created_at"], str)
else row["created_at"] else row["created_at"]
), ),
updated_at=( updated_at = (
datetime.fromisoformat(row["updated_at"]) datetime.fromisoformat(row["updated_at"])
if isinstance(row["updated_at"], str) if isinstance(row["updated_at"], str)
else row["updated_at"] else row["updated_at"]
@@ -2093,78 +2093,78 @@ class EnterpriseManager:
def _row_to_audit_export(self, row: sqlite3.Row) -> AuditLogExport: def _row_to_audit_export(self, row: sqlite3.Row) -> AuditLogExport:
"""数据库行转换为 AuditLogExport 对象""" """数据库行转换为 AuditLogExport 对象"""
return AuditLogExport( return AuditLogExport(
id=row["id"], id = row["id"],
tenant_id=row["tenant_id"], tenant_id = row["tenant_id"],
export_format=row["export_format"], export_format = row["export_format"],
start_date=( start_date = (
datetime.fromisoformat(row["start_date"]) datetime.fromisoformat(row["start_date"])
if isinstance(row["start_date"], str) if isinstance(row["start_date"], str)
else row["start_date"] else row["start_date"]
), ),
end_date=datetime.fromisoformat(row["end_date"]) end_date = datetime.fromisoformat(row["end_date"])
if isinstance(row["end_date"], str) if isinstance(row["end_date"], str)
else row["end_date"], else row["end_date"],
filters=json.loads(row["filters"] or "{}"), filters = json.loads(row["filters"] or "{}"),
compliance_standard=row["compliance_standard"], compliance_standard = row["compliance_standard"],
status=row["status"], status = row["status"],
file_path=row["file_path"], file_path = row["file_path"],
file_size=row["file_size"], file_size = row["file_size"],
record_count=row["record_count"], record_count = row["record_count"],
checksum=row["checksum"], checksum = row["checksum"],
downloaded_by=row["downloaded_by"], downloaded_by = row["downloaded_by"],
downloaded_at=( downloaded_at = (
datetime.fromisoformat(row["downloaded_at"]) datetime.fromisoformat(row["downloaded_at"])
if row["downloaded_at"] and isinstance(row["downloaded_at"], str) if row["downloaded_at"] and isinstance(row["downloaded_at"], str)
else row["downloaded_at"] else row["downloaded_at"]
), ),
expires_at=( expires_at = (
datetime.fromisoformat(row["expires_at"]) datetime.fromisoformat(row["expires_at"])
if isinstance(row["expires_at"], str) if isinstance(row["expires_at"], str)
else row["expires_at"] else row["expires_at"]
), ),
created_by=row["created_by"], created_by = row["created_by"],
created_at=( created_at = (
datetime.fromisoformat(row["created_at"]) datetime.fromisoformat(row["created_at"])
if isinstance(row["created_at"], str) if isinstance(row["created_at"], str)
else row["created_at"] else row["created_at"]
), ),
completed_at=( completed_at = (
datetime.fromisoformat(row["completed_at"]) datetime.fromisoformat(row["completed_at"])
if row["completed_at"] and isinstance(row["completed_at"], str) if row["completed_at"] and isinstance(row["completed_at"], str)
else row["completed_at"] else row["completed_at"]
), ),
error_message=row["error_message"], error_message = row["error_message"],
) )
def _row_to_retention_policy(self, row: sqlite3.Row) -> DataRetentionPolicy: def _row_to_retention_policy(self, row: sqlite3.Row) -> DataRetentionPolicy:
"""数据库行转换为 DataRetentionPolicy 对象""" """数据库行转换为 DataRetentionPolicy 对象"""
return DataRetentionPolicy( return DataRetentionPolicy(
id=row["id"], id = row["id"],
tenant_id=row["tenant_id"], tenant_id = row["tenant_id"],
name=row["name"], name = row["name"],
description=row["description"], description = row["description"],
resource_type=row["resource_type"], resource_type = row["resource_type"],
retention_days=row["retention_days"], retention_days = row["retention_days"],
action=row["action"], action = row["action"],
conditions=json.loads(row["conditions"] or "{}"), conditions = json.loads(row["conditions"] or "{}"),
auto_execute=bool(row["auto_execute"]), auto_execute = bool(row["auto_execute"]),
execute_at=row["execute_at"], execute_at = row["execute_at"],
notify_before_days=row["notify_before_days"], notify_before_days = row["notify_before_days"],
archive_location=row["archive_location"], archive_location = row["archive_location"],
archive_encryption=bool(row["archive_encryption"]), archive_encryption = bool(row["archive_encryption"]),
is_active=bool(row["is_active"]), is_active = bool(row["is_active"]),
last_executed_at=( last_executed_at = (
datetime.fromisoformat(row["last_executed_at"]) datetime.fromisoformat(row["last_executed_at"])
if row["last_executed_at"] and isinstance(row["last_executed_at"], str) if row["last_executed_at"] and isinstance(row["last_executed_at"], str)
else row["last_executed_at"] else row["last_executed_at"]
), ),
last_execution_result=row["last_execution_result"], last_execution_result = row["last_execution_result"],
created_at=( created_at = (
datetime.fromisoformat(row["created_at"]) datetime.fromisoformat(row["created_at"])
if isinstance(row["created_at"], str) if isinstance(row["created_at"], str)
else row["created_at"] else row["created_at"]
), ),
updated_at=( updated_at = (
datetime.fromisoformat(row["updated_at"]) datetime.fromisoformat(row["updated_at"])
if isinstance(row["updated_at"], str) if isinstance(row["updated_at"], str)
else row["updated_at"] else row["updated_at"]
@@ -2174,26 +2174,26 @@ class EnterpriseManager:
def _row_to_retention_job(self, row: sqlite3.Row) -> DataRetentionJob: def _row_to_retention_job(self, row: sqlite3.Row) -> DataRetentionJob:
"""数据库行转换为 DataRetentionJob 对象""" """数据库行转换为 DataRetentionJob 对象"""
return DataRetentionJob( return DataRetentionJob(
id=row["id"], id = row["id"],
policy_id=row["policy_id"], policy_id = row["policy_id"],
tenant_id=row["tenant_id"], tenant_id = row["tenant_id"],
status=row["status"], status = row["status"],
started_at=( started_at = (
datetime.fromisoformat(row["started_at"]) datetime.fromisoformat(row["started_at"])
if row["started_at"] and isinstance(row["started_at"], str) if row["started_at"] and isinstance(row["started_at"], str)
else row["started_at"] else row["started_at"]
), ),
completed_at=( completed_at = (
datetime.fromisoformat(row["completed_at"]) datetime.fromisoformat(row["completed_at"])
if row["completed_at"] and isinstance(row["completed_at"], str) if row["completed_at"] and isinstance(row["completed_at"], str)
else row["completed_at"] else row["completed_at"]
), ),
affected_records=row["affected_records"], affected_records = row["affected_records"],
archived_records=row["archived_records"], archived_records = row["archived_records"],
deleted_records=row["deleted_records"], deleted_records = row["deleted_records"],
error_count=row["error_count"], error_count = row["error_count"],
details=json.loads(row["details"] or "{}"), details = json.loads(row["details"] or "{}"),
created_at=( created_at = (
datetime.fromisoformat(row["created_at"]) datetime.fromisoformat(row["created_at"])
if isinstance(row["created_at"], str) if isinstance(row["created_at"], str)
else row["created_at"] else row["created_at"]

View File

@@ -27,7 +27,7 @@ class EntityEmbedding:
class EntityAligner: class EntityAligner:
"""实体对齐器 - 使用 embedding 进行相似度匹配""" """实体对齐器 - 使用 embedding 进行相似度匹配"""
def __init__(self, similarity_threshold: float = 0.85): def __init__(self, similarity_threshold: float = 0.85) -> None:
self.similarity_threshold = similarity_threshold self.similarity_threshold = similarity_threshold
self.embedding_cache: dict[str, list[float]] = {} self.embedding_cache: dict[str, list[float]] = {}
@@ -52,12 +52,12 @@ class EntityAligner:
try: try:
response = httpx.post( response = httpx.post(
f"{KIMI_BASE_URL}/v1/embeddings", f"{KIMI_BASE_URL}/v1/embeddings",
headers={ headers = {
"Authorization": f"Bearer {KIMI_API_KEY}", "Authorization": f"Bearer {KIMI_API_KEY}",
"Content-Type": "application/json", "Content-Type": "application/json",
}, },
json={"model": "k2p5", "input": text[:500]}, # 限制长度 json = {"model": "k2p5", "input": text[:500]}, # 限制长度
timeout=30.0, timeout = 30.0,
) )
response.raise_for_status() response.raise_for_status()
result = response.json() result = response.json()
@@ -232,7 +232,7 @@ class EntityAligner:
for new_ent in new_entities: for new_ent in new_entities:
matched = self.find_similar_entity( matched = self.find_similar_entity(
project_id, new_ent["name"], new_ent.get("definition", ""), threshold=threshold project_id, new_ent["name"], new_ent.get("definition", ""), threshold = threshold
) )
result = { result = {
@@ -292,16 +292,16 @@ class EntityAligner:
try: try:
response = httpx.post( response = httpx.post(
f"{KIMI_BASE_URL}/v1/chat/completions", f"{KIMI_BASE_URL}/v1/chat/completions",
headers={ headers = {
"Authorization": f"Bearer {KIMI_API_KEY}", "Authorization": f"Bearer {KIMI_API_KEY}",
"Content-Type": "application/json", "Content-Type": "application/json",
}, },
json={ json = {
"model": "k2p5", "model": "k2p5",
"messages": [{"role": "user", "content": prompt}], "messages": [{"role": "user", "content": prompt}],
"temperature": 0.3, "temperature": 0.3,
}, },
timeout=30.0, timeout = 30.0,
) )
response.raise_for_status() response.raise_for_status()
result = response.json() result = response.json()

View File

@@ -71,7 +71,7 @@ class ExportTranscript:
class ExportManager: class ExportManager:
"""导出管理器 - 处理各种导出需求""" """导出管理器 - 处理各种导出需求"""
def __init__(self, db_manager=None): def __init__(self, db_manager = None) -> None:
self.db = db_manager self.db = db_manager
def export_knowledge_graph_svg( def export_knowledge_graph_svg(
@@ -121,17 +121,17 @@ class ExportManager:
# 生成 SVG # 生成 SVG
svg_parts = [ svg_parts = [
f'<svg xmlns="http://www.w3.org/2000/svg" width="{width}" height="{height}" ' f'<svg xmlns = "http://www.w3.org/2000/svg" width = "{width}" height = "{height}" '
f'viewBox="0 0 {width} {height}">', f'viewBox = "0 0 {width} {height}">',
"<defs>", "<defs>",
' <marker id="arrowhead" markerWidth="10" markerHeight="7" ' ' <marker id = "arrowhead" markerWidth = "10" markerHeight = "7" '
'refX="9" refY="3.5" orient="auto">', 'refX = "9" refY = "3.5" orient = "auto">',
' <polygon points="0 0, 10 3.5, 0 7" fill="#7f8c8d"/>', ' <polygon points = "0 0, 10 3.5, 0 7" fill = "#7f8c8d"/>',
" </marker>", " </marker>",
"</defs>", "</defs>",
f'<rect width="{width}" height="{height}" fill="#f8f9fa"/>', f'<rect width = "{width}" height = "{height}" fill = "#f8f9fa"/>',
f'<text x="{center_x}" y="30" text-anchor="middle" font-size="20" ' f'<text x = "{center_x}" y = "30" text-anchor = "middle" font-size = "20" '
f'font-weight="bold" fill="#2c3e50">知识图谱 - {project_id}</text>', f'font-weight = "bold" fill = "#2c3e50">知识图谱 - {project_id}</text>',
] ]
# 绘制关系连线 # 绘制关系连线
@@ -150,20 +150,20 @@ class ExportManager:
y2 = y2 - dy * offset / dist y2 = y2 - dy * offset / dist
svg_parts.append( svg_parts.append(
f'<line x1="{x1}" y1="{y1}" x2="{x2}" y2="{y2}" ' f'<line x1 = "{x1}" y1 = "{y1}" x2 = "{x2}" y2 = "{y2}" '
f'stroke="#7f8c8d" stroke-width="2" marker-end="url(#arrowhead)" opacity="0.6"/>' f'stroke = "#7f8c8d" stroke-width = "2" marker-end = "url(#arrowhead)" opacity = "0.6"/>'
) )
# 关系标签 # 关系标签
mid_x = (x1 + x2) / 2 mid_x = (x1 + x2) / 2
mid_y = (y1 + y2) / 2 mid_y = (y1 + y2) / 2
svg_parts.append( svg_parts.append(
f'<rect x="{mid_x - 30}" y="{mid_y - 10}" width="60" height="20" ' f'<rect x = "{mid_x - 30}" y = "{mid_y - 10}" width = "60" height = "20" '
f'fill="white" stroke="#bdc3c7" rx="3"/>' f'fill = "white" stroke = "#bdc3c7" rx = "3"/>'
) )
svg_parts.append( svg_parts.append(
f'<text x="{mid_x}" y="{mid_y + 5}" text-anchor="middle" ' f'<text x = "{mid_x}" y = "{mid_y + 5}" text-anchor = "middle" '
f'font-size="10" fill="#2c3e50">{rel.relation_type}</text>' f'font-size = "10" fill = "#2c3e50">{rel.relation_type}</text>'
) )
# 绘制实体节点 # 绘制实体节点
@@ -174,19 +174,19 @@ class ExportManager:
# 节点圆圈 # 节点圆圈
svg_parts.append( svg_parts.append(
f'<circle cx="{x}" cy="{y}" r="35" fill="{color}" stroke="white" stroke-width="3"/>' f'<circle cx = "{x}" cy = "{y}" r = "35" fill = "{color}" stroke = "white" stroke-width = "3"/>'
) )
# 实体名称 # 实体名称
svg_parts.append( svg_parts.append(
f'<text x="{x}" y="{y + 5}" text-anchor="middle" font-size="12" ' f'<text x = "{x}" y = "{y + 5}" text-anchor = "middle" font-size = "12" '
f'font-weight="bold" fill="white">{entity.name[:8]}</text>' f'font-weight = "bold" fill = "white">{entity.name[:8]}</text>'
) )
# 实体类型 # 实体类型
svg_parts.append( svg_parts.append(
f'<text x="{x}" y="{y + 55}" text-anchor="middle" font-size="10" ' f'<text x = "{x}" y = "{y + 55}" text-anchor = "middle" font-size = "10" '
f'fill="#7f8c8d">{entity.type}</text>' f'fill = "#7f8c8d">{entity.type}</text>'
) )
# 图例 # 图例
@@ -196,24 +196,24 @@ class ExportManager:
rect_y = legend_y - 20 rect_y = legend_y - 20
rect_height = len(type_colors) * 25 + 10 rect_height = len(type_colors) * 25 + 10
svg_parts.append( svg_parts.append(
f'<rect x="{rect_x}" y="{rect_y}" width="140" height="{rect_height}" ' f'<rect x = "{rect_x}" y = "{rect_y}" width = "140" height = "{rect_height}" '
f'fill="white" stroke="#bdc3c7" rx="5"/>' f'fill = "white" stroke = "#bdc3c7" rx = "5"/>'
) )
svg_parts.append( svg_parts.append(
f'<text x="{legend_x}" y="{legend_y}" font-size="12" font-weight="bold" ' f'<text x = "{legend_x}" y = "{legend_y}" font-size = "12" font-weight = "bold" '
f'fill="#2c3e50">实体类型</text>' f'fill = "#2c3e50">实体类型</text>'
) )
for i, (etype, color) in enumerate(type_colors.items()): for i, (etype, color) in enumerate(type_colors.items()):
if etype != "default": if etype != "default":
y_pos = legend_y + 25 + i * 20 y_pos = legend_y + 25 + i * 20
svg_parts.append( svg_parts.append(
f'<circle cx="{legend_x + 10}" cy="{y_pos}" r="8" fill="{color}"/>' f'<circle cx = "{legend_x + 10}" cy = "{y_pos}" r = "8" fill = "{color}"/>'
) )
text_y = y_pos + 4 text_y = y_pos + 4
svg_parts.append( svg_parts.append(
f'<text x="{legend_x + 25}" y="{text_y}" font-size="10" ' f'<text x = "{legend_x + 25}" y = "{text_y}" font-size = "10" '
f'fill="#2c3e50">{etype}</text>' f'fill = "#2c3e50">{etype}</text>'
) )
svg_parts.append("</svg>") svg_parts.append("</svg>")
@@ -232,7 +232,7 @@ class ExportManager:
import cairosvg import cairosvg
svg_content = self.export_knowledge_graph_svg(project_id, entities, relations) svg_content = self.export_knowledge_graph_svg(project_id, entities, relations)
png_bytes = cairosvg.svg2png(bytestring=svg_content.encode("utf-8")) png_bytes = cairosvg.svg2png(bytestring = svg_content.encode("utf-8"))
return png_bytes return png_bytes
except ImportError: except ImportError:
# 如果没有 cairosvg返回 SVG 的 base64 # 如果没有 cairosvg返回 SVG 的 base64
@@ -269,8 +269,8 @@ class ExportManager:
# 写入 Excel # 写入 Excel
output = io.BytesIO() output = io.BytesIO()
with pd.ExcelWriter(output, engine="openpyxl") as writer: with pd.ExcelWriter(output, engine = "openpyxl") as writer:
df.to_excel(writer, sheet_name="实体列表", index=False) df.to_excel(writer, sheet_name = "实体列表", index = False)
# 调整列宽 # 调整列宽
worksheet = writer.sheets["实体列表"] worksheet = writer.sheets["实体列表"]
@@ -417,24 +417,24 @@ class ExportManager:
output = io.BytesIO() output = io.BytesIO()
doc = SimpleDocTemplate( doc = SimpleDocTemplate(
output, pagesize=A4, rightMargin=72, leftMargin=72, topMargin=72, bottomMargin=18 output, pagesize = A4, rightMargin = 72, leftMargin = 72, topMargin = 72, bottomMargin = 18
) )
# 样式 # 样式
styles = getSampleStyleSheet() styles = getSampleStyleSheet()
title_style = ParagraphStyle( title_style = ParagraphStyle(
"CustomTitle", "CustomTitle",
parent=styles["Heading1"], parent = styles["Heading1"],
fontSize=24, fontSize = 24,
spaceAfter=30, spaceAfter = 30,
textColor=colors.HexColor("#2c3e50"), textColor = colors.HexColor("#2c3e50"),
) )
heading_style = ParagraphStyle( heading_style = ParagraphStyle(
"CustomHeading", "CustomHeading",
parent=styles["Heading2"], parent = styles["Heading2"],
fontSize=16, fontSize = 16,
spaceAfter=12, spaceAfter = 12,
textColor=colors.HexColor("#34495e"), textColor = colors.HexColor("#34495e"),
) )
story = [] story = []
@@ -467,7 +467,7 @@ class ExportManager:
for etype, count in sorted(type_counts.items()): for etype, count in sorted(type_counts.items()):
stats_data.append([f"{etype} 实体", str(count)]) stats_data.append([f"{etype} 实体", str(count)])
stats_table = Table(stats_data, colWidths=[3 * inch, 2 * inch]) stats_table = Table(stats_data, colWidths = [3 * inch, 2 * inch])
stats_table.setStyle( stats_table.setStyle(
TableStyle( TableStyle(
[ [
@@ -497,7 +497,7 @@ class ExportManager:
story.append(Paragraph("实体列表", heading_style)) story.append(Paragraph("实体列表", heading_style))
entity_data = [["名称", "类型", "提及次数", "定义"]] entity_data = [["名称", "类型", "提及次数", "定义"]]
for e in sorted(entities, key=lambda x: x.mention_count, reverse=True)[ for e in sorted(entities, key = lambda x: x.mention_count, reverse = True)[
:50 :50
]: # 限制前50个 ]: # 限制前50个
entity_data.append( entity_data.append(
@@ -510,7 +510,7 @@ class ExportManager:
) )
entity_table = Table( entity_table = Table(
entity_data, colWidths=[1.5 * inch, 1 * inch, 1 * inch, 2.5 * inch] entity_data, colWidths = [1.5 * inch, 1 * inch, 1 * inch, 2.5 * inch]
) )
entity_table.setStyle( entity_table.setStyle(
TableStyle( TableStyle(
@@ -539,7 +539,7 @@ class ExportManager:
relation_data.append([r.source, r.relation_type, r.target, f"{r.confidence:.2f}"]) relation_data.append([r.source, r.relation_type, r.target, f"{r.confidence:.2f}"])
relation_table = Table( relation_table = Table(
relation_data, colWidths=[2 * inch, 1.5 * inch, 2 * inch, 1 * inch] relation_data, colWidths = [2 * inch, 1.5 * inch, 2 * inch, 1 * inch]
) )
relation_table.setStyle( relation_table.setStyle(
TableStyle( TableStyle(
@@ -613,14 +613,14 @@ class ExportManager:
], ],
} }
return json.dumps(data, ensure_ascii=False, indent=2) return json.dumps(data, ensure_ascii = False, indent = 2)
# 全局导出管理器实例 # 全局导出管理器实例
_export_manager = None _export_manager = None
def get_export_manager(db_manager=None) -> None: def get_export_manager(db_manager = None) -> None:
"""获取导出管理器实例""" """获取导出管理器实例"""
global _export_manager global _export_manager
if _export_manager is None: if _export_manager is None:

View File

@@ -362,7 +362,7 @@ class TeamIncentive:
class GrowthManager: class GrowthManager:
"""运营与增长管理主类""" """运营与增长管理主类"""
def __init__(self, db_path: str = DB_PATH): def __init__(self, db_path: str = DB_PATH) -> None:
self.db_path = db_path self.db_path = db_path
self.mixpanel_token = os.getenv("MIXPANEL_TOKEN", "") self.mixpanel_token = os.getenv("MIXPANEL_TOKEN", "")
self.amplitude_api_key = os.getenv("AMPLITUDE_API_KEY", "") self.amplitude_api_key = os.getenv("AMPLITUDE_API_KEY", "")
@@ -394,19 +394,19 @@ class GrowthManager:
now = datetime.now() now = datetime.now()
event = AnalyticsEvent( event = AnalyticsEvent(
id=event_id, id = event_id,
tenant_id=tenant_id, tenant_id = tenant_id,
user_id=user_id, user_id = user_id,
event_type=event_type, event_type = event_type,
event_name=event_name, event_name = event_name,
properties=properties or {}, properties = properties or {},
timestamp=now, timestamp = now,
session_id=session_id, session_id = session_id,
device_info=device_info or {}, device_info = device_info or {},
referrer=referrer, referrer = referrer,
utm_source=utm_params.get("source") if utm_params else None, utm_source = utm_params.get("source") if utm_params else None,
utm_medium=utm_params.get("medium") if utm_params else None, utm_medium = utm_params.get("medium") if utm_params else None,
utm_campaign=utm_params.get("campaign") if utm_params else None, utm_campaign = utm_params.get("campaign") if utm_params else None,
) )
with self._get_db() as conn: with self._get_db() as conn:
@@ -443,7 +443,7 @@ class GrowthManager:
return event return event
async def _send_to_analytics_platforms(self, event: AnalyticsEvent): async def _send_to_analytics_platforms(self, event: AnalyticsEvent) -> None:
"""发送事件到第三方分析平台""" """发送事件到第三方分析平台"""
tasks = [] tasks = []
@@ -453,9 +453,9 @@ class GrowthManager:
tasks.append(self._send_to_amplitude(event)) tasks.append(self._send_to_amplitude(event))
if tasks: if tasks:
await asyncio.gather(*tasks, return_exceptions=True) await asyncio.gather(*tasks, return_exceptions = True)
async def _send_to_mixpanel(self, event: AnalyticsEvent): async def _send_to_mixpanel(self, event: AnalyticsEvent) -> None:
"""发送事件到 Mixpanel""" """发送事件到 Mixpanel"""
try: try:
headers = { headers = {
@@ -475,12 +475,12 @@ class GrowthManager:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
await client.post( await client.post(
"https://api.mixpanel.com/track", headers=headers, json=[payload], timeout=10.0 "https://api.mixpanel.com/track", headers = headers, json = [payload], timeout = 10.0
) )
except (RuntimeError, ValueError, TypeError) as e: except (RuntimeError, ValueError, TypeError) as e:
print(f"Failed to send to Mixpanel: {e}") print(f"Failed to send to Mixpanel: {e}")
async def _send_to_amplitude(self, event: AnalyticsEvent): async def _send_to_amplitude(self, event: AnalyticsEvent) -> None:
"""发送事件到 Amplitude""" """发送事件到 Amplitude"""
try: try:
headers = {"Content-Type": "application/json"} headers = {"Content-Type": "application/json"}
@@ -501,16 +501,16 @@ class GrowthManager:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
await client.post( await client.post(
"https://api.amplitude.com/2/httpapi", "https://api.amplitude.com/2/httpapi",
headers=headers, headers = headers,
json=payload, json = payload,
timeout=10.0, timeout = 10.0,
) )
except (RuntimeError, ValueError, TypeError) as e: except (RuntimeError, ValueError, TypeError) as e:
print(f"Failed to send to Amplitude: {e}") print(f"Failed to send to Amplitude: {e}")
async def _update_user_profile( async def _update_user_profile(
self, tenant_id: str, user_id: str, event_type: EventType, event_name: str self, tenant_id: str, user_id: str, event_type: EventType, event_name: str
): ) -> None:
"""更新用户画像""" """更新用户画像"""
with self._get_db() as conn: with self._get_db() as conn:
# 检查用户画像是否存在 # 检查用户画像是否存在
@@ -642,13 +642,13 @@ class GrowthManager:
now = datetime.now().isoformat() now = datetime.now().isoformat()
funnel = Funnel( funnel = Funnel(
id=funnel_id, id = funnel_id,
tenant_id=tenant_id, tenant_id = tenant_id,
name=name, name = name,
description=description, description = description,
steps=steps, steps = steps,
created_at=now, created_at = now,
updated_at=now, updated_at = now,
) )
with self._get_db() as conn: with self._get_db() as conn:
@@ -677,7 +677,7 @@ class GrowthManager:
) -> FunnelAnalysis | None: ) -> FunnelAnalysis | None:
"""分析漏斗转化率""" """分析漏斗转化率"""
with self._get_db() as conn: with self._get_db() as conn:
funnel_row = conn.execute("SELECT * FROM funnels WHERE id = ?", (funnel_id,)).fetchone() funnel_row = conn.execute("SELECT * FROM funnels WHERE id = ?", (funnel_id, )).fetchone()
if not funnel_row: if not funnel_row:
return None return None
@@ -685,7 +685,7 @@ class GrowthManager:
steps = json.loads(funnel_row["steps"]) steps = json.loads(funnel_row["steps"])
if not period_start: if not period_start:
period_start = datetime.now() - timedelta(days=30) period_start = datetime.now() - timedelta(days = 30)
if not period_end: if not period_end:
period_end = datetime.now() period_end = datetime.now()
@@ -740,13 +740,13 @@ class GrowthManager:
] ]
return FunnelAnalysis( return FunnelAnalysis(
funnel_id=funnel_id, funnel_id = funnel_id,
period_start=period_start, period_start = period_start,
period_end=period_end, period_end = period_end,
total_users=step_conversions[0]["user_count"] if step_conversions else 0, total_users = step_conversions[0]["user_count"] if step_conversions else 0,
step_conversions=step_conversions, step_conversions = step_conversions,
overall_conversion=round(overall_conversion, 4), overall_conversion = round(overall_conversion, 4),
drop_off_points=drop_off_points, drop_off_points = drop_off_points,
) )
def calculate_retention( def calculate_retention(
@@ -781,14 +781,14 @@ class GrowthManager:
retention_rates = {} retention_rates = {}
for period in periods: for period in periods:
period_date = cohort_date + timedelta(days=period) period_date = cohort_date + timedelta(days = period)
active_query = """ active_query = """
SELECT COUNT(DISTINCT user_id) as active_count SELECT COUNT(DISTINCT user_id) as active_count
FROM analytics_events FROM analytics_events
WHERE tenant_id = ? AND date(timestamp) = date(?) WHERE tenant_id = ? AND date(timestamp) = date(?)
AND user_id IN ({}) AND user_id IN ({})
""".format(",".join(["?" for _ in cohort_users])) """.format(", ".join(["?" for _ in cohort_users]))
params = [tenant_id, period_date.isoformat()] + list(cohort_users) params = [tenant_id, period_date.isoformat()] + list(cohort_users)
row = conn.execute(active_query, params).fetchone() row = conn.execute(active_query, params).fetchone()
@@ -830,25 +830,25 @@ class GrowthManager:
now = datetime.now().isoformat() now = datetime.now().isoformat()
experiment = Experiment( experiment = Experiment(
id=experiment_id, id = experiment_id,
tenant_id=tenant_id, tenant_id = tenant_id,
name=name, name = name,
description=description, description = description,
hypothesis=hypothesis, hypothesis = hypothesis,
status=ExperimentStatus.DRAFT, status = ExperimentStatus.DRAFT,
variants=variants, variants = variants,
traffic_allocation=traffic_allocation, traffic_allocation = traffic_allocation,
traffic_split=traffic_split, traffic_split = traffic_split,
target_audience=target_audience, target_audience = target_audience,
primary_metric=primary_metric, primary_metric = primary_metric,
secondary_metrics=secondary_metrics, secondary_metrics = secondary_metrics,
start_date=None, start_date = None,
end_date=None, end_date = None,
min_sample_size=min_sample_size, min_sample_size = min_sample_size,
confidence_level=confidence_level, confidence_level = confidence_level,
created_at=now, created_at = now,
updated_at=now, updated_at = now,
created_by=created_by or "system", created_by = created_by or "system",
) )
with self._get_db() as conn: with self._get_db() as conn:
@@ -891,7 +891,7 @@ class GrowthManager:
"""获取实验详情""" """获取实验详情"""
with self._get_db() as conn: with self._get_db() as conn:
row = conn.execute( row = conn.execute(
"SELECT * FROM experiments WHERE id = ?", (experiment_id,) "SELECT * FROM experiments WHERE id = ?", (experiment_id, )
).fetchone() ).fetchone()
if row: if row:
@@ -973,7 +973,7 @@ class GrowthManager:
total = sum(weights) total = sum(weights)
normalized_weights = [w / total for w in weights] normalized_weights = [w / total for w in weights]
return random.choices(variant_ids, weights=normalized_weights, k=1)[0] return random.choices(variant_ids, weights = normalized_weights, k = 1)[0]
def _stratified_allocation( def _stratified_allocation(
self, variants: list[dict], traffic_split: dict[str, float], user_attributes: dict self, variants: list[dict], traffic_split: dict[str, float], user_attributes: dict
@@ -1027,7 +1027,7 @@ class GrowthManager:
user_id: str, user_id: str,
metric_name: str, metric_name: str,
metric_value: float, metric_value: float,
): ) -> None:
"""记录实验指标""" """记录实验指标"""
with self._get_db() as conn: with self._get_db() as conn:
conn.execute( conn.execute(
@@ -1196,21 +1196,21 @@ class GrowthManager:
variables = re.findall(r"\{\{(\w+)\}\}", html_content) variables = re.findall(r"\{\{(\w+)\}\}", html_content)
template = EmailTemplate( template = EmailTemplate(
id=template_id, id = template_id,
tenant_id=tenant_id, tenant_id = tenant_id,
name=name, name = name,
template_type=template_type, template_type = template_type,
subject=subject, subject = subject,
html_content=html_content, html_content = html_content,
text_content=text_content or re.sub(r"<[^>]+>", "", html_content), text_content = text_content or re.sub(r"<[^>]+>", "", html_content),
variables=variables, variables = variables,
preview_text=None, preview_text = None,
from_name=from_name or "InsightFlow", from_name = from_name or "InsightFlow",
from_email=from_email or "noreply@insightflow.io", from_email = from_email or "noreply@insightflow.io",
reply_to=reply_to, reply_to = reply_to,
is_active=True, is_active = True,
created_at=now, created_at = now,
updated_at=now, updated_at = now,
) )
with self._get_db() as conn: with self._get_db() as conn:
@@ -1246,7 +1246,7 @@ class GrowthManager:
"""获取邮件模板""" """获取邮件模板"""
with self._get_db() as conn: with self._get_db() as conn:
row = conn.execute( row = conn.execute(
"SELECT * FROM email_templates WHERE id = ?", (template_id,) "SELECT * FROM email_templates WHERE id = ?", (template_id, )
).fetchone() ).fetchone()
if row: if row:
@@ -1308,22 +1308,22 @@ class GrowthManager:
now = datetime.now().isoformat() now = datetime.now().isoformat()
campaign = EmailCampaign( campaign = EmailCampaign(
id=campaign_id, id = campaign_id,
tenant_id=tenant_id, tenant_id = tenant_id,
name=name, name = name,
template_id=template_id, template_id = template_id,
status="draft", status = "draft",
recipient_count=len(recipient_list), recipient_count = len(recipient_list),
sent_count=0, sent_count = 0,
delivered_count=0, delivered_count = 0,
opened_count=0, opened_count = 0,
clicked_count=0, clicked_count = 0,
bounced_count=0, bounced_count = 0,
failed_count=0, failed_count = 0,
scheduled_at=scheduled_at.isoformat() if scheduled_at else None, scheduled_at = scheduled_at.isoformat() if scheduled_at else None,
started_at=None, started_at = None,
completed_at=None, completed_at = None,
created_at=now, created_at = now,
) )
with self._get_db() as conn: with self._get_db() as conn:
@@ -1452,7 +1452,7 @@ class GrowthManager:
"""发送整个营销活动""" """发送整个营销活动"""
with self._get_db() as conn: with self._get_db() as conn:
campaign_row = conn.execute( campaign_row = conn.execute(
"SELECT * FROM email_campaigns WHERE id = ?", (campaign_id,) "SELECT * FROM email_campaigns WHERE id = ?", (campaign_id, )
).fetchone() ).fetchone()
if not campaign_row: if not campaign_row:
@@ -1530,17 +1530,17 @@ class GrowthManager:
now = datetime.now().isoformat() now = datetime.now().isoformat()
workflow = AutomationWorkflow( workflow = AutomationWorkflow(
id=workflow_id, id = workflow_id,
tenant_id=tenant_id, tenant_id = tenant_id,
name=name, name = name,
description=description, description = description,
trigger_type=trigger_type, trigger_type = trigger_type,
trigger_conditions=trigger_conditions, trigger_conditions = trigger_conditions,
actions=actions, actions = actions,
is_active=True, is_active = True,
execution_count=0, execution_count = 0,
created_at=now, created_at = now,
updated_at=now, updated_at = now,
) )
with self._get_db() as conn: with self._get_db() as conn:
@@ -1569,11 +1569,11 @@ class GrowthManager:
return workflow return workflow
async def trigger_workflow(self, workflow_id: str, event_data: dict): async def trigger_workflow(self, workflow_id: str, event_data: dict) -> None:
"""触发自动化工作流""" """触发自动化工作流"""
with self._get_db() as conn: with self._get_db() as conn:
row = conn.execute( row = conn.execute(
"SELECT * FROM automation_workflows WHERE id = ? AND is_active = 1", (workflow_id,) "SELECT * FROM automation_workflows WHERE id = ? AND is_active = 1", (workflow_id, )
).fetchone() ).fetchone()
if not row: if not row:
@@ -1592,7 +1592,7 @@ class GrowthManager:
# 更新执行计数 # 更新执行计数
conn.execute( conn.execute(
"UPDATE automation_workflows SET execution_count = execution_count + 1 WHERE id = ?", "UPDATE automation_workflows SET execution_count = execution_count + 1 WHERE id = ?",
(workflow_id,), (workflow_id, ),
) )
conn.commit() conn.commit()
@@ -1606,7 +1606,7 @@ class GrowthManager:
return False return False
return True return True
async def _execute_action(self, action: dict, event_data: dict): async def _execute_action(self, action: dict, event_data: dict) -> None:
"""执行工作流动作""" """执行工作流动作"""
action_type = action.get("type") action_type = action.get("type")
@@ -1640,20 +1640,20 @@ class GrowthManager:
now = datetime.now().isoformat() now = datetime.now().isoformat()
program = ReferralProgram( program = ReferralProgram(
id=program_id, id = program_id,
tenant_id=tenant_id, tenant_id = tenant_id,
name=name, name = name,
description=description, description = description,
referrer_reward_type=referrer_reward_type, referrer_reward_type = referrer_reward_type,
referrer_reward_value=referrer_reward_value, referrer_reward_value = referrer_reward_value,
referee_reward_type=referee_reward_type, referee_reward_type = referee_reward_type,
referee_reward_value=referee_reward_value, referee_reward_value = referee_reward_value,
max_referrals_per_user=max_referrals_per_user, max_referrals_per_user = max_referrals_per_user,
referral_code_length=referral_code_length, referral_code_length = referral_code_length,
expiry_days=expiry_days, expiry_days = expiry_days,
is_active=True, is_active = True,
created_at=now, created_at = now,
updated_at=now, updated_at = now,
) )
with self._get_db() as conn: with self._get_db() as conn:
@@ -1708,24 +1708,24 @@ class GrowthManager:
referral_id = f"ref_{uuid.uuid4().hex[:16]}" referral_id = f"ref_{uuid.uuid4().hex[:16]}"
now = datetime.now() now = datetime.now()
expires_at = now + timedelta(days=program.expiry_days) expires_at = now + timedelta(days = program.expiry_days)
referral = Referral( referral = Referral(
id=referral_id, id = referral_id,
program_id=program_id, program_id = program_id,
tenant_id=program.tenant_id, tenant_id = program.tenant_id,
referrer_id=referrer_id, referrer_id = referrer_id,
referee_id=None, referee_id = None,
referral_code=referral_code, referral_code = referral_code,
status=ReferralStatus.PENDING, status = ReferralStatus.PENDING,
referrer_rewarded=False, referrer_rewarded = False,
referee_rewarded=False, referee_rewarded = False,
referrer_reward_value=program.referrer_reward_value, referrer_reward_value = program.referrer_reward_value,
referee_reward_value=program.referee_reward_value, referee_reward_value = program.referee_reward_value,
converted_at=None, converted_at = None,
rewarded_at=None, rewarded_at = None,
expires_at=expires_at, expires_at = expires_at,
created_at=now, created_at = now,
) )
conn.execute( conn.execute(
@@ -1762,11 +1762,11 @@ class GrowthManager:
"""生成唯一推荐码""" """生成唯一推荐码"""
chars = "ABCDEFGHJKLMNPQRSTUVWXYZ23456789" # 排除易混淆字符 chars = "ABCDEFGHJKLMNPQRSTUVWXYZ23456789" # 排除易混淆字符
while True: while True:
code = "".join(random.choices(chars, k=length)) code = "".join(random.choices(chars, k = length))
with self._get_db() as conn: with self._get_db() as conn:
row = conn.execute( row = conn.execute(
"SELECT 1 FROM referrals WHERE referral_code = ?", (code,) "SELECT 1 FROM referrals WHERE referral_code = ?", (code, )
).fetchone() ).fetchone()
if not row: if not row:
@@ -1776,7 +1776,7 @@ class GrowthManager:
"""获取推荐计划""" """获取推荐计划"""
with self._get_db() as conn: with self._get_db() as conn:
row = conn.execute( row = conn.execute(
"SELECT * FROM referral_programs WHERE id = ?", (program_id,) "SELECT * FROM referral_programs WHERE id = ?", (program_id, )
).fetchone() ).fetchone()
if row: if row:
@@ -1811,7 +1811,7 @@ class GrowthManager:
def reward_referral(self, referral_id: str) -> bool: def reward_referral(self, referral_id: str) -> bool:
"""发放推荐奖励""" """发放推荐奖励"""
with self._get_db() as conn: with self._get_db() as conn:
row = conn.execute("SELECT * FROM referrals WHERE id = ?", (referral_id,)).fetchone() row = conn.execute("SELECT * FROM referrals WHERE id = ?", (referral_id, )).fetchone()
if not row or row["status"] != ReferralStatus.CONVERTED.value: if not row or row["status"] != ReferralStatus.CONVERTED.value:
return False return False
@@ -1883,18 +1883,18 @@ class GrowthManager:
now = datetime.now().isoformat() now = datetime.now().isoformat()
incentive = TeamIncentive( incentive = TeamIncentive(
id=incentive_id, id = incentive_id,
tenant_id=tenant_id, tenant_id = tenant_id,
name=name, name = name,
description=description, description = description,
target_tier=target_tier, target_tier = target_tier,
min_team_size=min_team_size, min_team_size = min_team_size,
incentive_type=incentive_type, incentive_type = incentive_type,
incentive_value=incentive_value, incentive_value = incentive_value,
valid_from=valid_from.isoformat(), valid_from = valid_from.isoformat(),
valid_until=valid_until.isoformat(), valid_until = valid_until.isoformat(),
is_active=True, is_active = True,
created_at=now, created_at = now,
) )
with self._get_db() as conn: with self._get_db() as conn:
@@ -1947,7 +1947,7 @@ class GrowthManager:
def get_realtime_dashboard(self, tenant_id: str) -> dict: def get_realtime_dashboard(self, tenant_id: str) -> dict:
"""获取实时分析仪表板数据""" """获取实时分析仪表板数据"""
now = datetime.now() now = datetime.now()
today_start = now.replace(hour=0, minute=0, second=0, microsecond=0) today_start = now.replace(hour = 0, minute = 0, second = 0, microsecond = 0)
with self._get_db() as conn: with self._get_db() as conn:
# 今日统计 # 今日统计
@@ -1972,7 +1972,7 @@ class GrowthManager:
ORDER BY timestamp DESC ORDER BY timestamp DESC
LIMIT 20 LIMIT 20
""", """,
(tenant_id,), (tenant_id, ),
).fetchall() ).fetchall()
# 热门功能 # 热门功能
@@ -1991,8 +1991,8 @@ class GrowthManager:
# 活跃用户趋势最近24小时每小时 # 活跃用户趋势最近24小时每小时
hourly_trend = [] hourly_trend = []
for i in range(24): for i in range(24):
hour_start = now - timedelta(hours=i + 1) hour_start = now - timedelta(hours = i + 1)
hour_end = now - timedelta(hours=i) hour_end = now - timedelta(hours = i)
row = conn.execute( row = conn.execute(
""" """
@@ -2035,116 +2035,116 @@ class GrowthManager:
def _row_to_user_profile(self, row) -> UserProfile: def _row_to_user_profile(self, row) -> UserProfile:
"""将数据库行转换为 UserProfile""" """将数据库行转换为 UserProfile"""
return UserProfile( return UserProfile(
id=row["id"], id = row["id"],
tenant_id=row["tenant_id"], tenant_id = row["tenant_id"],
user_id=row["user_id"], user_id = row["user_id"],
first_seen=datetime.fromisoformat(row["first_seen"]), first_seen = datetime.fromisoformat(row["first_seen"]),
last_seen=datetime.fromisoformat(row["last_seen"]), last_seen = datetime.fromisoformat(row["last_seen"]),
total_sessions=row["total_sessions"], total_sessions = row["total_sessions"],
total_events=row["total_events"], total_events = row["total_events"],
feature_usage=json.loads(row["feature_usage"]), feature_usage = json.loads(row["feature_usage"]),
subscription_history=json.loads(row["subscription_history"]), subscription_history = json.loads(row["subscription_history"]),
ltv=row["ltv"], ltv = row["ltv"],
churn_risk_score=row["churn_risk_score"], churn_risk_score = row["churn_risk_score"],
engagement_score=row["engagement_score"], engagement_score = row["engagement_score"],
created_at=datetime.fromisoformat(row["created_at"]), created_at = datetime.fromisoformat(row["created_at"]),
updated_at=datetime.fromisoformat(row["updated_at"]), updated_at = datetime.fromisoformat(row["updated_at"]),
) )
def _row_to_experiment(self, row) -> Experiment: def _row_to_experiment(self, row) -> Experiment:
"""将数据库行转换为 Experiment""" """将数据库行转换为 Experiment"""
return Experiment( return Experiment(
id=row["id"], id = row["id"],
tenant_id=row["tenant_id"], tenant_id = row["tenant_id"],
name=row["name"], name = row["name"],
description=row["description"], description = row["description"],
hypothesis=row["hypothesis"], hypothesis = row["hypothesis"],
status=ExperimentStatus(row["status"]), status = ExperimentStatus(row["status"]),
variants=json.loads(row["variants"]), variants = json.loads(row["variants"]),
traffic_allocation=TrafficAllocationType(row["traffic_allocation"]), traffic_allocation = TrafficAllocationType(row["traffic_allocation"]),
traffic_split=json.loads(row["traffic_split"]), traffic_split = json.loads(row["traffic_split"]),
target_audience=json.loads(row["target_audience"]), target_audience = json.loads(row["target_audience"]),
primary_metric=row["primary_metric"], primary_metric = row["primary_metric"],
secondary_metrics=json.loads(row["secondary_metrics"]), secondary_metrics = json.loads(row["secondary_metrics"]),
start_date=datetime.fromisoformat(row["start_date"]) if row["start_date"] else None, start_date = datetime.fromisoformat(row["start_date"]) if row["start_date"] else None,
end_date=datetime.fromisoformat(row["end_date"]) if row["end_date"] else None, end_date = datetime.fromisoformat(row["end_date"]) if row["end_date"] else None,
min_sample_size=row["min_sample_size"], min_sample_size = row["min_sample_size"],
confidence_level=row["confidence_level"], confidence_level = row["confidence_level"],
created_at=row["created_at"], created_at = row["created_at"],
updated_at=row["updated_at"], updated_at = row["updated_at"],
created_by=row["created_by"], created_by = row["created_by"],
) )
def _row_to_email_template(self, row) -> EmailTemplate: def _row_to_email_template(self, row) -> EmailTemplate:
"""将数据库行转换为 EmailTemplate""" """将数据库行转换为 EmailTemplate"""
return EmailTemplate( return EmailTemplate(
id=row["id"], id = row["id"],
tenant_id=row["tenant_id"], tenant_id = row["tenant_id"],
name=row["name"], name = row["name"],
template_type=EmailTemplateType(row["template_type"]), template_type = EmailTemplateType(row["template_type"]),
subject=row["subject"], subject = row["subject"],
html_content=row["html_content"], html_content = row["html_content"],
text_content=row["text_content"], text_content = row["text_content"],
variables=json.loads(row["variables"]), variables = json.loads(row["variables"]),
preview_text=row["preview_text"], preview_text = row["preview_text"],
from_name=row["from_name"], from_name = row["from_name"],
from_email=row["from_email"], from_email = row["from_email"],
reply_to=row["reply_to"], reply_to = row["reply_to"],
is_active=bool(row["is_active"]), is_active = bool(row["is_active"]),
created_at=row["created_at"], created_at = row["created_at"],
updated_at=row["updated_at"], updated_at = row["updated_at"],
) )
def _row_to_automation_workflow(self, row) -> AutomationWorkflow: def _row_to_automation_workflow(self, row) -> AutomationWorkflow:
"""将数据库行转换为 AutomationWorkflow""" """将数据库行转换为 AutomationWorkflow"""
return AutomationWorkflow( return AutomationWorkflow(
id=row["id"], id = row["id"],
tenant_id=row["tenant_id"], tenant_id = row["tenant_id"],
name=row["name"], name = row["name"],
description=row["description"], description = row["description"],
trigger_type=WorkflowTriggerType(row["trigger_type"]), trigger_type = WorkflowTriggerType(row["trigger_type"]),
trigger_conditions=json.loads(row["trigger_conditions"]), trigger_conditions = json.loads(row["trigger_conditions"]),
actions=json.loads(row["actions"]), actions = json.loads(row["actions"]),
is_active=bool(row["is_active"]), is_active = bool(row["is_active"]),
execution_count=row["execution_count"], execution_count = row["execution_count"],
created_at=row["created_at"], created_at = row["created_at"],
updated_at=row["updated_at"], updated_at = row["updated_at"],
) )
def _row_to_referral_program(self, row) -> ReferralProgram: def _row_to_referral_program(self, row) -> ReferralProgram:
"""将数据库行转换为 ReferralProgram""" """将数据库行转换为 ReferralProgram"""
return ReferralProgram( return ReferralProgram(
id=row["id"], id = row["id"],
tenant_id=row["tenant_id"], tenant_id = row["tenant_id"],
name=row["name"], name = row["name"],
description=row["description"], description = row["description"],
referrer_reward_type=row["referrer_reward_type"], referrer_reward_type = row["referrer_reward_type"],
referrer_reward_value=row["referrer_reward_value"], referrer_reward_value = row["referrer_reward_value"],
referee_reward_type=row["referee_reward_type"], referee_reward_type = row["referee_reward_type"],
referee_reward_value=row["referee_reward_value"], referee_reward_value = row["referee_reward_value"],
max_referrals_per_user=row["max_referrals_per_user"], max_referrals_per_user = row["max_referrals_per_user"],
referral_code_length=row["referral_code_length"], referral_code_length = row["referral_code_length"],
expiry_days=row["expiry_days"], expiry_days = row["expiry_days"],
is_active=bool(row["is_active"]), is_active = bool(row["is_active"]),
created_at=row["created_at"], created_at = row["created_at"],
updated_at=row["updated_at"], updated_at = row["updated_at"],
) )
def _row_to_team_incentive(self, row) -> TeamIncentive: def _row_to_team_incentive(self, row) -> TeamIncentive:
"""将数据库行转换为 TeamIncentive""" """将数据库行转换为 TeamIncentive"""
return TeamIncentive( return TeamIncentive(
id=row["id"], id = row["id"],
tenant_id=row["tenant_id"], tenant_id = row["tenant_id"],
name=row["name"], name = row["name"],
description=row["description"], description = row["description"],
target_tier=row["target_tier"], target_tier = row["target_tier"],
min_team_size=row["min_team_size"], min_team_size = row["min_team_size"],
incentive_type=row["incentive_type"], incentive_type = row["incentive_type"],
incentive_value=row["incentive_value"], incentive_value = row["incentive_value"],
valid_from=datetime.fromisoformat(row["valid_from"]), valid_from = datetime.fromisoformat(row["valid_from"]),
valid_until=datetime.fromisoformat(row["valid_until"]), valid_until = datetime.fromisoformat(row["valid_until"]),
is_active=bool(row["is_active"]), is_active = bool(row["is_active"]),
created_at=row["created_at"], created_at = row["created_at"],
) )

View File

@@ -104,7 +104,7 @@ class ImageProcessor:
temp_dir: 临时文件目录 temp_dir: 临时文件目录
""" """
self.temp_dir = temp_dir or os.path.join(os.getcwd(), "temp", "images") self.temp_dir = temp_dir or os.path.join(os.getcwd(), "temp", "images")
os.makedirs(self.temp_dir, exist_ok=True) os.makedirs(self.temp_dir, exist_ok = True)
def preprocess_image(self, image, image_type: str = None) -> None: def preprocess_image(self, image, image_type: str = None) -> None:
""" """
@@ -169,7 +169,7 @@ class ImageProcessor:
gray = image.convert("L") gray = image.convert("L")
# 轻微降噪 # 轻微降噪
blurred = gray.filter(ImageFilter.GaussianBlur(radius=1)) blurred = gray.filter(ImageFilter.GaussianBlur(radius = 1))
# 增强对比度 # 增强对比度
enhancer = ImageEnhance.Contrast(blurred) enhancer = ImageEnhance.Contrast(blurred)
@@ -255,10 +255,10 @@ class ImageProcessor:
processed_image = self.preprocess_image(image) processed_image = self.preprocess_image(image)
# 执行OCR # 执行OCR
text = pytesseract.image_to_string(processed_image, lang=lang) text = pytesseract.image_to_string(processed_image, lang = lang)
# 获取置信度 # 获取置信度
data = pytesseract.image_to_data(processed_image, output_type=pytesseract.Output.DICT) data = pytesseract.image_to_data(processed_image, output_type = pytesseract.Output.DICT)
confidences = [int(c) for c in data["conf"] if int(c) > 0] confidences = [int(c) for c in data["conf"] if int(c) > 0]
avg_confidence = sum(confidences) / len(confidences) if confidences else 0 avg_confidence = sum(confidences) / len(confidences) if confidences else 0
@@ -288,12 +288,12 @@ class ImageProcessor:
for match in re.finditer(project_pattern, text): for match in re.finditer(project_pattern, text):
name = match.group(1) or match.group(2) name = match.group(1) or match.group(2)
if name and len(name) > 2: if name and len(name) > 2:
entities.append(ImageEntity(name=name.strip(), type="PROJECT", confidence=0.7)) entities.append(ImageEntity(name = name.strip(), type = "PROJECT", confidence = 0.7))
# 人名(中文) # 人名(中文)
name_pattern = r"([\u4e00-\u9fa5]{2,4})(?:先生|女士|总|经理|工程师|老师)" name_pattern = r"([\u4e00-\u9fa5]{2, 4})(?:先生|女士|总|经理|工程师|老师)"
for match in re.finditer(name_pattern, text): for match in re.finditer(name_pattern, text):
entities.append(ImageEntity(name=match.group(1), type="PERSON", confidence=0.8)) entities.append(ImageEntity(name = match.group(1), type = "PERSON", confidence = 0.8))
# 技术术语 # 技术术语
tech_keywords = [ tech_keywords = [
@@ -314,7 +314,7 @@ class ImageProcessor:
] ]
for keyword in tech_keywords: for keyword in tech_keywords:
if keyword in text: if keyword in text:
entities.append(ImageEntity(name=keyword, type="TECH", confidence=0.9)) entities.append(ImageEntity(name = keyword, type = "TECH", confidence = 0.9))
# 去重 # 去重
seen = set() seen = set()
@@ -381,16 +381,16 @@ class ImageProcessor:
if not PIL_AVAILABLE: if not PIL_AVAILABLE:
return ImageProcessingResult( return ImageProcessingResult(
image_id=image_id, image_id = image_id,
image_type="other", image_type = "other",
ocr_text="", ocr_text = "",
description="PIL not available", description = "PIL not available",
entities=[], entities = [],
relations=[], relations = [],
width=0, width = 0,
height=0, height = 0,
success=False, success = False,
error_message="PIL library not available", error_message = "PIL library not available",
) )
try: try:
@@ -421,29 +421,29 @@ class ImageProcessor:
image.save(save_path) image.save(save_path)
return ImageProcessingResult( return ImageProcessingResult(
image_id=image_id, image_id = image_id,
image_type=image_type, image_type = image_type,
ocr_text=ocr_text, ocr_text = ocr_text,
description=description, description = description,
entities=entities, entities = entities,
relations=relations, relations = relations,
width=width, width = width,
height=height, height = height,
success=True, success = True,
) )
except Exception as e: except Exception as e:
return ImageProcessingResult( return ImageProcessingResult(
image_id=image_id, image_id = image_id,
image_type="other", image_type = "other",
ocr_text="", ocr_text = "",
description="", description = "",
entities=[], entities = [],
relations=[], relations = [],
width=0, width = 0,
height=0, height = 0,
success=False, success = False,
error_message=str(e), error_message = str(e),
) )
def _extract_relations(self, entities: list[ImageEntity], text: str) -> list[ImageRelation]: def _extract_relations(self, entities: list[ImageEntity], text: str) -> list[ImageRelation]:
@@ -477,10 +477,10 @@ class ImageProcessor:
for j in range(i + 1, len(sentence_entities)): for j in range(i + 1, len(sentence_entities)):
relations.append( relations.append(
ImageRelation( ImageRelation(
source=sentence_entities[i].name, source = sentence_entities[i].name,
target=sentence_entities[j].name, target = sentence_entities[j].name,
relation_type="related", relation_type = "related",
confidence=0.5, confidence = 0.5,
) )
) )
@@ -513,10 +513,10 @@ class ImageProcessor:
failed_count += 1 failed_count += 1
return BatchProcessingResult( return BatchProcessingResult(
results=results, results = results,
total_count=len(results), total_count = len(results),
success_count=success_count, success_count = success_count,
failed_count=failed_count, failed_count = failed_count,
) )
def image_to_base64(self, image_data: bytes) -> str: def image_to_base64(self, image_data: bytes) -> str:
@@ -550,7 +550,7 @@ class ImageProcessor:
image.thumbnail(size, Image.Resampling.LANCZOS) image.thumbnail(size, Image.Resampling.LANCZOS)
buffer = io.BytesIO() buffer = io.BytesIO()
image.save(buffer, format="JPEG") image.save(buffer, format = "JPEG")
return buffer.getvalue() return buffer.getvalue()
except Exception as e: except Exception as e:
print(f"Thumbnail generation error: {e}") print(f"Thumbnail generation error: {e}")

View File

@@ -51,7 +51,7 @@ class InferencePath:
class KnowledgeReasoner: class KnowledgeReasoner:
"""知识推理引擎""" """知识推理引擎"""
def __init__(self, api_key: str = None, base_url: str = None): def __init__(self, api_key: str = None, base_url: str = None) -> None:
self.api_key = api_key or KIMI_API_KEY self.api_key = api_key or KIMI_API_KEY
self.base_url = base_url or KIMI_BASE_URL self.base_url = base_url or KIMI_BASE_URL
self.headers = { self.headers = {
@@ -73,9 +73,9 @@ class KnowledgeReasoner:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.post( response = await client.post(
f"{self.base_url}/v1/chat/completions", f"{self.base_url}/v1/chat/completions",
headers=self.headers, headers = self.headers,
json=payload, json = payload,
timeout=120.0, timeout = 120.0,
) )
response.raise_for_status() response.raise_for_status()
result = response.json() result = response.json()
@@ -127,7 +127,7 @@ class KnowledgeReasoner:
- factual: 事实类问题(是什么、有哪些) - factual: 事实类问题(是什么、有哪些)
- opinion: 观点类问题(怎么看、态度、评价)""" - opinion: 观点类问题(怎么看、态度、评价)"""
content = await self._call_llm(prompt, temperature=0.1) content = await self._call_llm(prompt, temperature = 0.1)
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL) json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
if json_match: if json_match:
@@ -144,8 +144,8 @@ class KnowledgeReasoner:
"""因果推理 - 分析原因和影响""" """因果推理 - 分析原因和影响"""
# 构建因果分析提示 # 构建因果分析提示
entities_str = json.dumps(graph_data.get("entities", []), ensure_ascii=False, indent=2) entities_str = json.dumps(graph_data.get("entities", []), ensure_ascii = False, indent = 2)
relations_str = json.dumps(graph_data.get("relations", []), ensure_ascii=False, indent=2) relations_str = json.dumps(graph_data.get("relations", []), ensure_ascii = False, indent = 2)
prompt = f"""基于以下知识图谱进行因果推理分析: prompt = f"""基于以下知识图谱进行因果推理分析:
@@ -159,7 +159,7 @@ class KnowledgeReasoner:
{relations_str[:2000]} {relations_str[:2000]}
## 项目上下文 ## 项目上下文
{json.dumps(project_context, ensure_ascii=False, indent=2)[:1500]} {json.dumps(project_context, ensure_ascii = False, indent = 2)[:1500]}
请进行因果分析,返回 JSON 格式: 请进行因果分析,返回 JSON 格式:
{{ {{
@@ -172,7 +172,7 @@ class KnowledgeReasoner:
"knowledge_gaps": ["缺失信息1"] "knowledge_gaps": ["缺失信息1"]
}}""" }}"""
content = await self._call_llm(prompt, temperature=0.3) content = await self._call_llm(prompt, temperature = 0.3)
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL) json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
@@ -180,23 +180,23 @@ class KnowledgeReasoner:
try: try:
data = json.loads(json_match.group()) data = json.loads(json_match.group())
return ReasoningResult( return ReasoningResult(
answer=data.get("answer", ""), answer = data.get("answer", ""),
reasoning_type=ReasoningType.CAUSAL, reasoning_type = ReasoningType.CAUSAL,
confidence=data.get("confidence", 0.7), confidence = data.get("confidence", 0.7),
evidence=[{"text": e} for e in data.get("evidence", [])], evidence = [{"text": e} for e in data.get("evidence", [])],
related_entities=[], related_entities = [],
gaps=data.get("knowledge_gaps", []), gaps = data.get("knowledge_gaps", []),
) )
except (json.JSONDecodeError, KeyError): except (json.JSONDecodeError, KeyError):
pass pass
return ReasoningResult( return ReasoningResult(
answer=content, answer = content,
reasoning_type=ReasoningType.CAUSAL, reasoning_type = ReasoningType.CAUSAL,
confidence=0.5, confidence = 0.5,
evidence=[], evidence = [],
related_entities=[], related_entities = [],
gaps=["无法完成因果推理"], gaps = ["无法完成因果推理"],
) )
async def _comparative_reasoning( async def _comparative_reasoning(
@@ -210,10 +210,10 @@ class KnowledgeReasoner:
{query} {query}
## 实体 ## 实体
{json.dumps(graph_data.get("entities", []), ensure_ascii=False, indent=2)[:2000]} {json.dumps(graph_data.get("entities", []), ensure_ascii = False, indent = 2)[:2000]}
## 关系 ## 关系
{json.dumps(graph_data.get("relations", []), ensure_ascii=False, indent=2)[:1500]} {json.dumps(graph_data.get("relations", []), ensure_ascii = False, indent = 2)[:1500]}
请进行对比分析,返回 JSON 格式: 请进行对比分析,返回 JSON 格式:
{{ {{
@@ -226,7 +226,7 @@ class KnowledgeReasoner:
"knowledge_gaps": [] "knowledge_gaps": []
}}""" }}"""
content = await self._call_llm(prompt, temperature=0.3) content = await self._call_llm(prompt, temperature = 0.3)
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL) json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
@@ -234,23 +234,23 @@ class KnowledgeReasoner:
try: try:
data = json.loads(json_match.group()) data = json.loads(json_match.group())
return ReasoningResult( return ReasoningResult(
answer=data.get("answer", ""), answer = data.get("answer", ""),
reasoning_type=ReasoningType.COMPARATIVE, reasoning_type = ReasoningType.COMPARATIVE,
confidence=data.get("confidence", 0.7), confidence = data.get("confidence", 0.7),
evidence=[{"text": e} for e in data.get("evidence", [])], evidence = [{"text": e} for e in data.get("evidence", [])],
related_entities=[], related_entities = [],
gaps=data.get("knowledge_gaps", []), gaps = data.get("knowledge_gaps", []),
) )
except (json.JSONDecodeError, KeyError): except (json.JSONDecodeError, KeyError):
pass pass
return ReasoningResult( return ReasoningResult(
answer=content, answer = content,
reasoning_type=ReasoningType.COMPARATIVE, reasoning_type = ReasoningType.COMPARATIVE,
confidence=0.5, confidence = 0.5,
evidence=[], evidence = [],
related_entities=[], related_entities = [],
gaps=[], gaps = [],
) )
async def _temporal_reasoning( async def _temporal_reasoning(
@@ -264,10 +264,10 @@ class KnowledgeReasoner:
{query} {query}
## 项目时间线 ## 项目时间线
{json.dumps(project_context.get("timeline", []), ensure_ascii=False, indent=2)[:2000]} {json.dumps(project_context.get("timeline", []), ensure_ascii = False, indent = 2)[:2000]}
## 实体提及历史 ## 实体提及历史
{json.dumps(graph_data.get("entities", []), ensure_ascii=False, indent=2)[:1500]} {json.dumps(graph_data.get("entities", []), ensure_ascii = False, indent = 2)[:1500]}
请进行时序分析,返回 JSON 格式: 请进行时序分析,返回 JSON 格式:
{{ {{
@@ -280,7 +280,7 @@ class KnowledgeReasoner:
"knowledge_gaps": [] "knowledge_gaps": []
}}""" }}"""
content = await self._call_llm(prompt, temperature=0.3) content = await self._call_llm(prompt, temperature = 0.3)
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL) json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
@@ -288,23 +288,23 @@ class KnowledgeReasoner:
try: try:
data = json.loads(json_match.group()) data = json.loads(json_match.group())
return ReasoningResult( return ReasoningResult(
answer=data.get("answer", ""), answer = data.get("answer", ""),
reasoning_type=ReasoningType.TEMPORAL, reasoning_type = ReasoningType.TEMPORAL,
confidence=data.get("confidence", 0.7), confidence = data.get("confidence", 0.7),
evidence=[{"text": e} for e in data.get("evidence", [])], evidence = [{"text": e} for e in data.get("evidence", [])],
related_entities=[], related_entities = [],
gaps=data.get("knowledge_gaps", []), gaps = data.get("knowledge_gaps", []),
) )
except (json.JSONDecodeError, KeyError): except (json.JSONDecodeError, KeyError):
pass pass
return ReasoningResult( return ReasoningResult(
answer=content, answer = content,
reasoning_type=ReasoningType.TEMPORAL, reasoning_type = ReasoningType.TEMPORAL,
confidence=0.5, confidence = 0.5,
evidence=[], evidence = [],
related_entities=[], related_entities = [],
gaps=[], gaps = [],
) )
async def _associative_reasoning( async def _associative_reasoning(
@@ -318,10 +318,10 @@ class KnowledgeReasoner:
{query} {query}
## 实体 ## 实体
{json.dumps(graph_data.get("entities", [])[:20], ensure_ascii=False, indent=2)} {json.dumps(graph_data.get("entities", [])[:20], ensure_ascii = False, indent = 2)}
## 关系 ## 关系
{json.dumps(graph_data.get("relations", [])[:30], ensure_ascii=False, indent=2)} {json.dumps(graph_data.get("relations", [])[:30], ensure_ascii = False, indent = 2)}
请进行关联推理,发现隐含联系,返回 JSON 格式: 请进行关联推理,发现隐含联系,返回 JSON 格式:
{{ {{
@@ -334,7 +334,7 @@ class KnowledgeReasoner:
"knowledge_gaps": [] "knowledge_gaps": []
}}""" }}"""
content = await self._call_llm(prompt, temperature=0.4) content = await self._call_llm(prompt, temperature = 0.4)
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL) json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
@@ -342,23 +342,23 @@ class KnowledgeReasoner:
try: try:
data = json.loads(json_match.group()) data = json.loads(json_match.group())
return ReasoningResult( return ReasoningResult(
answer=data.get("answer", ""), answer = data.get("answer", ""),
reasoning_type=ReasoningType.ASSOCIATIVE, reasoning_type = ReasoningType.ASSOCIATIVE,
confidence=data.get("confidence", 0.7), confidence = data.get("confidence", 0.7),
evidence=[{"text": e} for e in data.get("evidence", [])], evidence = [{"text": e} for e in data.get("evidence", [])],
related_entities=[], related_entities = [],
gaps=data.get("knowledge_gaps", []), gaps = data.get("knowledge_gaps", []),
) )
except (json.JSONDecodeError, KeyError): except (json.JSONDecodeError, KeyError):
pass pass
return ReasoningResult( return ReasoningResult(
answer=content, answer = content,
reasoning_type=ReasoningType.ASSOCIATIVE, reasoning_type = ReasoningType.ASSOCIATIVE,
confidence=0.5, confidence = 0.5,
evidence=[], evidence = [],
related_entities=[], related_entities = [],
gaps=[], gaps = [],
) )
def find_inference_paths( def find_inference_paths(
@@ -400,10 +400,10 @@ class KnowledgeReasoner:
# 找到一条路径 # 找到一条路径
paths.append( paths.append(
InferencePath( InferencePath(
start_entity=start_entity, start_entity = start_entity,
end_entity=end_entity, end_entity = end_entity,
path=path, path = path,
strength=self._calculate_path_strength(path), strength = self._calculate_path_strength(path),
) )
) )
continue continue
@@ -424,7 +424,7 @@ class KnowledgeReasoner:
queue.append((next_entity, new_path)) queue.append((next_entity, new_path))
# 按强度排序 # 按强度排序
paths.sort(key=lambda p: p.strength, reverse=True) paths.sort(key = lambda p: p.strength, reverse = True)
return paths return paths
def _calculate_path_strength(self, path: list[dict]) -> float: def _calculate_path_strength(self, path: list[dict]) -> float:
@@ -467,7 +467,7 @@ class KnowledgeReasoner:
prompt = f"""请对以下项目进行{type_prompts.get(summary_type, "全面总结")} prompt = f"""请对以下项目进行{type_prompts.get(summary_type, "全面总结")}
## 项目信息 ## 项目信息
{json.dumps(project_context, ensure_ascii=False, indent=2)[:3000]} {json.dumps(project_context, ensure_ascii = False, indent = 2)[:3000]}
## 知识图谱 ## 知识图谱
实体数: {len(graph_data.get("entities", []))} 实体数: {len(graph_data.get("entities", []))}
@@ -483,7 +483,7 @@ class KnowledgeReasoner:
"confidence": 0.85 "confidence": 0.85
}}""" }}"""
content = await self._call_llm(prompt, temperature=0.3) content = await self._call_llm(prompt, temperature = 0.3)
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL) json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)

View File

@@ -41,7 +41,7 @@ class RelationExtractionResult:
class LLMClient: class LLMClient:
"""Kimi API 客户端""" """Kimi API 客户端"""
def __init__(self, api_key: str = None, base_url: str = None): def __init__(self, api_key: str = None, base_url: str = None) -> None:
self.api_key = api_key or KIMI_API_KEY self.api_key = api_key or KIMI_API_KEY
self.base_url = base_url or KIMI_BASE_URL self.base_url = base_url or KIMI_BASE_URL
self.headers = { self.headers = {
@@ -66,9 +66,9 @@ class LLMClient:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.post( response = await client.post(
f"{self.base_url}/v1/chat/completions", f"{self.base_url}/v1/chat/completions",
headers=self.headers, headers = self.headers,
json=payload, json = payload,
timeout=120.0, timeout = 120.0,
) )
response.raise_for_status() response.raise_for_status()
result = response.json() result = response.json()
@@ -92,9 +92,9 @@ class LLMClient:
async with client.stream( async with client.stream(
"POST", "POST",
f"{self.base_url}/v1/chat/completions", f"{self.base_url}/v1/chat/completions",
headers=self.headers, headers = self.headers,
json=payload, json = payload,
timeout=120.0, timeout = 120.0,
) as response: ) as response:
response.raise_for_status() response.raise_for_status()
async for line in response.aiter_lines(): async for line in response.aiter_lines():
@@ -139,8 +139,8 @@ class LLMClient:
] ]
}}""" }}"""
messages = [ChatMessage(role="user", content=prompt)] messages = [ChatMessage(role = "user", content = prompt)]
content = await self.chat(messages, temperature=0.1) content = await self.chat(messages, temperature = 0.1)
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL) json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
if not json_match: if not json_match:
@@ -150,19 +150,19 @@ class LLMClient:
data = json.loads(json_match.group()) data = json.loads(json_match.group())
entities = [ entities = [
EntityExtractionResult( EntityExtractionResult(
name=e["name"], name = e["name"],
type=e.get("type", "OTHER"), type = e.get("type", "OTHER"),
definition=e.get("definition", ""), definition = e.get("definition", ""),
confidence=e.get("confidence", 0.8), confidence = e.get("confidence", 0.8),
) )
for e in data.get("entities", []) for e in data.get("entities", [])
] ]
relations = [ relations = [
RelationExtractionResult( RelationExtractionResult(
source=r["source"], source = r["source"],
target=r["target"], target = r["target"],
type=r.get("type", "related"), type = r.get("type", "related"),
confidence=r.get("confidence", 0.8), confidence = r.get("confidence", 0.8),
) )
for r in data.get("relations", []) for r in data.get("relations", [])
] ]
@@ -176,7 +176,7 @@ class LLMClient:
prompt = f"""你是一个专业的项目分析助手。基于以下项目信息回答问题: prompt = f"""你是一个专业的项目分析助手。基于以下项目信息回答问题:
## 项目信息 ## 项目信息
{json.dumps(project_context, ensure_ascii=False, indent=2)} {json.dumps(project_context, ensure_ascii = False, indent = 2)}
## 相关上下文 ## 相关上下文
{context[:4000]} {context[:4000]}
@@ -188,19 +188,19 @@ class LLMClient:
messages = [ messages = [
ChatMessage( ChatMessage(
role="system", content="你是一个专业的项目分析助手,擅长从会议记录中提取洞察。" role = "system", content = "你是一个专业的项目分析助手,擅长从会议记录中提取洞察。"
), ),
ChatMessage(role="user", content=prompt), ChatMessage(role = "user", content = prompt),
] ]
return await self.chat(messages, temperature=0.3) return await self.chat(messages, temperature = 0.3)
async def agent_command(self, command: str, project_context: dict) -> dict: async def agent_command(self, command: str, project_context: dict) -> dict:
"""Agent 指令解析 - 将自然语言指令转换为结构化操作""" """Agent 指令解析 - 将自然语言指令转换为结构化操作"""
prompt = f"""解析以下用户指令,转换为结构化操作: prompt = f"""解析以下用户指令,转换为结构化操作:
## 项目信息 ## 项目信息
{json.dumps(project_context, ensure_ascii=False, indent=2)} {json.dumps(project_context, ensure_ascii = False, indent = 2)}
## 用户指令 ## 用户指令
{command} {command}
@@ -221,8 +221,8 @@ class LLMClient:
- create_relation: 创建关系params 包含 source(源实体), target(目标实体), relation_type(关系类型) - create_relation: 创建关系params 包含 source(源实体), target(目标实体), relation_type(关系类型)
""" """
messages = [ChatMessage(role="user", content=prompt)] messages = [ChatMessage(role = "user", content = prompt)]
content = await self.chat(messages, temperature=0.1) content = await self.chat(messages, temperature = 0.1)
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL) json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
if not json_match: if not json_match:
@@ -255,8 +255,8 @@ class LLMClient:
用中文回答,结构清晰。""" 用中文回答,结构清晰。"""
messages = [ChatMessage(role="user", content=prompt)] messages = [ChatMessage(role = "user", content = prompt)]
return await self.chat(messages, temperature=0.3) return await self.chat(messages, temperature = 0.3)
# Singleton instance # Singleton instance

View File

@@ -260,8 +260,8 @@ class LocalizationManager:
"date_format": "MM/dd/yyyy", "date_format": "MM/dd/yyyy",
"time_format": "h:mm a", "time_format": "h:mm a",
"datetime_format": "MM/dd/yyyy h:mm a", "datetime_format": "MM/dd/yyyy h:mm a",
"number_format": "#,##0.##", "number_format": "#, ##0.##",
"currency_format": "$#,##0.00", "currency_format": "$#, ##0.00",
"first_day_of_week": 0, "first_day_of_week": 0,
"calendar_type": CalendarType.GREGORIAN.value, "calendar_type": CalendarType.GREGORIAN.value,
}, },
@@ -272,8 +272,8 @@ class LocalizationManager:
"date_format": "yyyy-MM-dd", "date_format": "yyyy-MM-dd",
"time_format": "HH:mm", "time_format": "HH:mm",
"datetime_format": "yyyy-MM-dd HH:mm", "datetime_format": "yyyy-MM-dd HH:mm",
"number_format": "#,##0.##", "number_format": "#, ##0.##",
"currency_format": "¥#,##0.00", "currency_format": "¥#, ##0.00",
"first_day_of_week": 1, "first_day_of_week": 1,
"calendar_type": CalendarType.GREGORIAN.value, "calendar_type": CalendarType.GREGORIAN.value,
}, },
@@ -284,8 +284,8 @@ class LocalizationManager:
"date_format": "yyyy/MM/dd", "date_format": "yyyy/MM/dd",
"time_format": "HH:mm", "time_format": "HH:mm",
"datetime_format": "yyyy/MM/dd HH:mm", "datetime_format": "yyyy/MM/dd HH:mm",
"number_format": "#,##0.##", "number_format": "#, ##0.##",
"currency_format": "NT$#,##0.00", "currency_format": "NT$#, ##0.00",
"first_day_of_week": 0, "first_day_of_week": 0,
"calendar_type": CalendarType.GREGORIAN.value, "calendar_type": CalendarType.GREGORIAN.value,
}, },
@@ -296,8 +296,8 @@ class LocalizationManager:
"date_format": "yyyy/MM/dd", "date_format": "yyyy/MM/dd",
"time_format": "HH:mm", "time_format": "HH:mm",
"datetime_format": "yyyy/MM/dd HH:mm", "datetime_format": "yyyy/MM/dd HH:mm",
"number_format": "#,##0.##", "number_format": "#, ##0.##",
"currency_format": "¥#,##0", "currency_format": "¥#, ##0",
"first_day_of_week": 0, "first_day_of_week": 0,
"calendar_type": CalendarType.GREGORIAN.value, "calendar_type": CalendarType.GREGORIAN.value,
}, },
@@ -308,8 +308,8 @@ class LocalizationManager:
"date_format": "yyyy. MM. dd", "date_format": "yyyy. MM. dd",
"time_format": "HH:mm", "time_format": "HH:mm",
"datetime_format": "yyyy. MM. dd HH:mm", "datetime_format": "yyyy. MM. dd HH:mm",
"number_format": "#,##0.##", "number_format": "#, ##0.##",
"currency_format": "₩#,##0", "currency_format": "₩#, ##0",
"first_day_of_week": 0, "first_day_of_week": 0,
"calendar_type": CalendarType.GREGORIAN.value, "calendar_type": CalendarType.GREGORIAN.value,
}, },
@@ -320,8 +320,8 @@ class LocalizationManager:
"date_format": "dd.MM.yyyy", "date_format": "dd.MM.yyyy",
"time_format": "HH:mm", "time_format": "HH:mm",
"datetime_format": "dd.MM.yyyy HH:mm", "datetime_format": "dd.MM.yyyy HH:mm",
"number_format": "#,##0.##", "number_format": "#, ##0.##",
"currency_format": "#,##0.00 €", "currency_format": "#, ##0.00 €",
"first_day_of_week": 1, "first_day_of_week": 1,
"calendar_type": CalendarType.GREGORIAN.value, "calendar_type": CalendarType.GREGORIAN.value,
}, },
@@ -332,8 +332,8 @@ class LocalizationManager:
"date_format": "dd/MM/yyyy", "date_format": "dd/MM/yyyy",
"time_format": "HH:mm", "time_format": "HH:mm",
"datetime_format": "dd/MM/yyyy HH:mm", "datetime_format": "dd/MM/yyyy HH:mm",
"number_format": "#,##0.##", "number_format": "#, ##0.##",
"currency_format": "#,##0.00 €", "currency_format": "#, ##0.00 €",
"first_day_of_week": 1, "first_day_of_week": 1,
"calendar_type": CalendarType.GREGORIAN.value, "calendar_type": CalendarType.GREGORIAN.value,
}, },
@@ -344,8 +344,8 @@ class LocalizationManager:
"date_format": "dd/MM/yyyy", "date_format": "dd/MM/yyyy",
"time_format": "HH:mm", "time_format": "HH:mm",
"datetime_format": "dd/MM/yyyy HH:mm", "datetime_format": "dd/MM/yyyy HH:mm",
"number_format": "#,##0.##", "number_format": "#, ##0.##",
"currency_format": "#,##0.00 €", "currency_format": "#, ##0.00 €",
"first_day_of_week": 1, "first_day_of_week": 1,
"calendar_type": CalendarType.GREGORIAN.value, "calendar_type": CalendarType.GREGORIAN.value,
}, },
@@ -356,8 +356,8 @@ class LocalizationManager:
"date_format": "dd/MM/yyyy", "date_format": "dd/MM/yyyy",
"time_format": "HH:mm", "time_format": "HH:mm",
"datetime_format": "dd/MM/yyyy HH:mm", "datetime_format": "dd/MM/yyyy HH:mm",
"number_format": "#,##0.##", "number_format": "#, ##0.##",
"currency_format": "R$#,##0.00", "currency_format": "R$#, ##0.00",
"first_day_of_week": 0, "first_day_of_week": 0,
"calendar_type": CalendarType.GREGORIAN.value, "calendar_type": CalendarType.GREGORIAN.value,
}, },
@@ -368,8 +368,8 @@ class LocalizationManager:
"date_format": "dd.MM.yyyy", "date_format": "dd.MM.yyyy",
"time_format": "HH:mm", "time_format": "HH:mm",
"datetime_format": "dd.MM.yyyy HH:mm", "datetime_format": "dd.MM.yyyy HH:mm",
"number_format": "#,##0.##", "number_format": "#, ##0.##",
"currency_format": "#,##0.00 ₽", "currency_format": "#, ##0.00 ₽",
"first_day_of_week": 1, "first_day_of_week": 1,
"calendar_type": CalendarType.GREGORIAN.value, "calendar_type": CalendarType.GREGORIAN.value,
}, },
@@ -380,8 +380,8 @@ class LocalizationManager:
"date_format": "dd/MM/yyyy", "date_format": "dd/MM/yyyy",
"time_format": "hh:mm a", "time_format": "hh:mm a",
"datetime_format": "dd/MM/yyyy hh:mm a", "datetime_format": "dd/MM/yyyy hh:mm a",
"number_format": "#,##0.##", "number_format": "#, ##0.##",
"currency_format": "#,##0.00 ر.س", "currency_format": "#, ##0.00 ر.س",
"first_day_of_week": 6, "first_day_of_week": 6,
"calendar_type": CalendarType.ISLAMIC.value, "calendar_type": CalendarType.ISLAMIC.value,
}, },
@@ -392,8 +392,8 @@ class LocalizationManager:
"date_format": "dd/MM/yyyy", "date_format": "dd/MM/yyyy",
"time_format": "hh:mm a", "time_format": "hh:mm a",
"datetime_format": "dd/MM/yyyy hh:mm a", "datetime_format": "dd/MM/yyyy hh:mm a",
"number_format": "#,##0.##", "number_format": "#, ##0.##",
"currency_format": "₹#,##0.00", "currency_format": "₹#, ##0.00",
"first_day_of_week": 0, "first_day_of_week": 0,
"calendar_type": CalendarType.INDIAN.value, "calendar_type": CalendarType.INDIAN.value,
}, },
@@ -719,7 +719,7 @@ class LocalizationManager:
}, },
} }
def __init__(self, db_path: str = "insightflow.db"): def __init__(self, db_path: str = "insightflow.db") -> None:
self.db_path = db_path self.db_path = db_path
self._is_memory_db = db_path == ":memory:" self._is_memory_db = db_path == ":memory:"
self._conn = None self._conn = None
@@ -736,11 +736,11 @@ class LocalizationManager:
conn.row_factory = sqlite3.Row conn.row_factory = sqlite3.Row
return conn return conn
def _close_if_file_db(self, conn): def _close_if_file_db(self, conn) -> None:
if not self._is_memory_db: if not self._is_memory_db:
conn.close() conn.close()
def _init_db(self): def _init_db(self) -> None:
conn = self._get_connection() conn = self._get_connection()
try: try:
cursor = conn.cursor() cursor = conn.cursor()
@@ -813,7 +813,7 @@ class LocalizationManager:
CREATE TABLE IF NOT EXISTS currency_configs ( CREATE TABLE IF NOT EXISTS currency_configs (
code TEXT PRIMARY KEY, name TEXT NOT NULL, name_local TEXT DEFAULT '{}', symbol TEXT NOT NULL, code TEXT PRIMARY KEY, name TEXT NOT NULL, name_local TEXT DEFAULT '{}', symbol TEXT NOT NULL,
decimal_places INTEGER DEFAULT 2, decimal_separator TEXT DEFAULT '.', decimal_places INTEGER DEFAULT 2, decimal_separator TEXT DEFAULT '.',
thousands_separator TEXT DEFAULT ',', is_active INTEGER DEFAULT 1 thousands_separator TEXT DEFAULT ', ', is_active INTEGER DEFAULT 1
) )
""") """)
cursor.execute(""" cursor.execute("""
@@ -863,7 +863,7 @@ class LocalizationManager:
finally: finally:
self._close_if_file_db(conn) self._close_if_file_db(conn)
def _init_default_data(self): def _init_default_data(self) -> None:
conn = self._get_connection() conn = self._get_connection()
try: try:
cursor = conn.cursor() cursor = conn.cursor()
@@ -1054,7 +1054,7 @@ class LocalizationManager:
conn = self._get_connection() conn = self._get_connection()
try: try:
cursor = conn.cursor() cursor = conn.cursor()
query = "SELECT * FROM translations WHERE 1=1" query = "SELECT * FROM translations WHERE 1 = 1"
params = [] params = []
if language: if language:
query += " AND language = ?" query += " AND language = ?"
@@ -1074,7 +1074,7 @@ class LocalizationManager:
conn = self._get_connection() conn = self._get_connection()
try: try:
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute("SELECT * FROM language_configs WHERE code = ?", (code,)) cursor.execute("SELECT * FROM language_configs WHERE code = ?", (code, ))
row = cursor.fetchone() row = cursor.fetchone()
if row: if row:
return self._row_to_language_config(row) return self._row_to_language_config(row)
@@ -1100,7 +1100,7 @@ class LocalizationManager:
conn = self._get_connection() conn = self._get_connection()
try: try:
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute("SELECT * FROM data_centers WHERE id = ?", (dc_id,)) cursor.execute("SELECT * FROM data_centers WHERE id = ?", (dc_id, ))
row = cursor.fetchone() row = cursor.fetchone()
if row: if row:
return self._row_to_data_center(row) return self._row_to_data_center(row)
@@ -1112,7 +1112,7 @@ class LocalizationManager:
conn = self._get_connection() conn = self._get_connection()
try: try:
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute("SELECT * FROM data_centers WHERE region_code = ?", (region_code,)) cursor.execute("SELECT * FROM data_centers WHERE region_code = ?", (region_code, ))
row = cursor.fetchone() row = cursor.fetchone()
if row: if row:
return self._row_to_data_center(row) return self._row_to_data_center(row)
@@ -1126,7 +1126,7 @@ class LocalizationManager:
conn = self._get_connection() conn = self._get_connection()
try: try:
cursor = conn.cursor() cursor = conn.cursor()
query = "SELECT * FROM data_centers WHERE 1=1" query = "SELECT * FROM data_centers WHERE 1 = 1"
params = [] params = []
if status: if status:
query += " AND status = ?" query += " AND status = ?"
@@ -1146,7 +1146,7 @@ class LocalizationManager:
try: try:
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute( cursor.execute(
"SELECT * FROM tenant_data_center_mappings WHERE tenant_id = ?", (tenant_id,) "SELECT * FROM tenant_data_center_mappings WHERE tenant_id = ?", (tenant_id, )
) )
row = cursor.fetchone() row = cursor.fetchone()
if row: if row:
@@ -1166,7 +1166,7 @@ class LocalizationManager:
SELECT * FROM data_centers WHERE supported_regions LIKE ? AND status = 'active' SELECT * FROM data_centers WHERE supported_regions LIKE ? AND status = 'active'
ORDER BY priority LIMIT 1 ORDER BY priority LIMIT 1
""", """,
(f'%"{region_code}"%',), (f'%"{region_code}"%', ),
) )
row = cursor.fetchone() row = cursor.fetchone()
if not row: if not row:
@@ -1182,7 +1182,7 @@ class LocalizationManager:
""" """
SELECT * FROM data_centers WHERE id != ? AND status = 'active' ORDER BY priority LIMIT 1 SELECT * FROM data_centers WHERE id != ? AND status = 'active' ORDER BY priority LIMIT 1
""", """,
(primary_dc_id,), (primary_dc_id, ),
) )
secondary_row = cursor.fetchone() secondary_row = cursor.fetchone()
secondary_dc_id = secondary_row["id"] if secondary_row else None secondary_dc_id = secondary_row["id"] if secondary_row else None
@@ -1222,7 +1222,7 @@ class LocalizationManager:
try: try:
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute( cursor.execute(
"SELECT * FROM localized_payment_methods WHERE provider = ?", (provider,) "SELECT * FROM localized_payment_methods WHERE provider = ?", (provider, )
) )
row = cursor.fetchone() row = cursor.fetchone()
if row: if row:
@@ -1237,7 +1237,7 @@ class LocalizationManager:
conn = self._get_connection() conn = self._get_connection()
try: try:
cursor = conn.cursor() cursor = conn.cursor()
query = "SELECT * FROM localized_payment_methods WHERE 1=1" query = "SELECT * FROM localized_payment_methods WHERE 1 = 1"
params = [] params = []
if active_only: if active_only:
query += " AND is_active = 1" query += " AND is_active = 1"
@@ -1257,7 +1257,7 @@ class LocalizationManager:
def get_localized_payment_methods( def get_localized_payment_methods(
self, country_code: str, language: str = "en" self, country_code: str, language: str = "en"
) -> list[dict[str, Any]]: ) -> list[dict[str, Any]]:
methods = self.list_payment_methods(country_code=country_code) methods = self.list_payment_methods(country_code = country_code)
result = [] result = []
for method in methods: for method in methods:
name_local = method.name_local.get(language, method.name) name_local = method.name_local.get(language, method.name)
@@ -1278,7 +1278,7 @@ class LocalizationManager:
conn = self._get_connection() conn = self._get_connection()
try: try:
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute("SELECT * FROM country_configs WHERE code = ?", (code,)) cursor.execute("SELECT * FROM country_configs WHERE code = ?", (code, ))
row = cursor.fetchone() row = cursor.fetchone()
if row: if row:
return self._row_to_country_config(row) return self._row_to_country_config(row)
@@ -1292,7 +1292,7 @@ class LocalizationManager:
conn = self._get_connection() conn = self._get_connection()
try: try:
cursor = conn.cursor() cursor = conn.cursor()
query = "SELECT * FROM country_configs WHERE 1=1" query = "SELECT * FROM country_configs WHERE 1 = 1"
params = [] params = []
if active_only: if active_only:
query += " AND is_active = 1" query += " AND is_active = 1"
@@ -1332,11 +1332,11 @@ class LocalizationManager:
try: try:
locale = Locale.parse(language.replace("_", "-")) locale = Locale.parse(language.replace("_", "-"))
if format_type == "date": if format_type == "date":
return dates.format_date(dt, locale=locale) return dates.format_date(dt, locale = locale)
elif format_type == "time": elif format_type == "time":
return dates.format_time(dt, locale=locale) return dates.format_time(dt, locale = locale)
else: else:
return dates.format_datetime(dt, locale=locale) return dates.format_datetime(dt, locale = locale)
except (ValueError, AttributeError): except (ValueError, AttributeError):
pass pass
return dt.strftime(fmt) return dt.strftime(fmt)
@@ -1352,13 +1352,13 @@ class LocalizationManager:
try: try:
locale = Locale.parse(language.replace("_", "-")) locale = Locale.parse(language.replace("_", "-"))
return numbers.format_decimal( return numbers.format_decimal(
number, locale=locale, decimal_quantization=(decimal_places is not None) number, locale = locale, decimal_quantization = (decimal_places is not None)
) )
except (ValueError, AttributeError): except (ValueError, AttributeError):
pass pass
if decimal_places is not None: if decimal_places is not None:
return f"{number:,.{decimal_places}f}" return f"{number:, .{decimal_places}f}"
return f"{number:,}" return f"{number:, }"
except Exception as e: except Exception as e:
logger.error(f"Error formatting number: {e}") logger.error(f"Error formatting number: {e}")
return str(number) return str(number)
@@ -1368,10 +1368,10 @@ class LocalizationManager:
if BABEL_AVAILABLE: if BABEL_AVAILABLE:
try: try:
locale = Locale.parse(language.replace("_", "-")) locale = Locale.parse(language.replace("_", "-"))
return numbers.format_currency(amount, currency, locale=locale) return numbers.format_currency(amount, currency, locale = locale)
except (ValueError, AttributeError): except (ValueError, AttributeError):
pass pass
return f"{currency} {amount:,.2f}" return f"{currency} {amount:, .2f}"
except Exception as e: except Exception as e:
logger.error(f"Error formatting currency: {e}") logger.error(f"Error formatting currency: {e}")
return f"{currency} {amount:.2f}" return f"{currency} {amount:.2f}"
@@ -1408,7 +1408,7 @@ class LocalizationManager:
conn = self._get_connection() conn = self._get_connection()
try: try:
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute("SELECT * FROM localization_settings WHERE tenant_id = ?", (tenant_id,)) cursor.execute("SELECT * FROM localization_settings WHERE tenant_id = ?", (tenant_id, ))
row = cursor.fetchone() row = cursor.fetchone()
if row: if row:
return self._row_to_localization_settings(row) return self._row_to_localization_settings(row)
@@ -1453,7 +1453,7 @@ class LocalizationManager:
default_timezone, default_timezone,
lang_config.date_format if lang_config else "%Y-%m-%d", lang_config.date_format if lang_config else "%Y-%m-%d",
lang_config.time_format if lang_config else "%H:%M", lang_config.time_format if lang_config else "%H:%M",
lang_config.number_format if lang_config else "#,##0.##", lang_config.number_format if lang_config else "#, ##0.##",
lang_config.calendar_type if lang_config else CalendarType.GREGORIAN.value, lang_config.calendar_type if lang_config else CalendarType.GREGORIAN.value,
lang_config.first_day_of_week if lang_config else 1, lang_config.first_day_of_week if lang_config else 1,
region_code, region_code,
@@ -1517,7 +1517,7 @@ class LocalizationManager:
) -> dict[str, str]: ) -> dict[str, str]:
preferences = {"language": "en", "country": "US", "timezone": "UTC", "currency": "USD"} preferences = {"language": "en", "country": "US", "timezone": "UTC", "currency": "USD"}
if accept_language: if accept_language:
langs = accept_language.split(",") langs = accept_language.split(", ")
for lang in langs: for lang in langs:
lang_code = lang.split(";")[0].strip().replace("-", "_") lang_code = lang.split(";")[0].strip().replace("-", "_")
lang_config = self.get_language_config(lang_code) lang_config = self.get_language_config(lang_code)
@@ -1536,25 +1536,25 @@ class LocalizationManager:
def _row_to_translation(self, row: sqlite3.Row) -> Translation: def _row_to_translation(self, row: sqlite3.Row) -> Translation:
return Translation( return Translation(
id=row["id"], id = row["id"],
key=row["key"], key = row["key"],
language=row["language"], language = row["language"],
value=row["value"], value = row["value"],
namespace=row["namespace"], namespace = row["namespace"],
context=row["context"], context = row["context"],
created_at=( created_at = (
datetime.fromisoformat(row["created_at"]) datetime.fromisoformat(row["created_at"])
if isinstance(row["created_at"], str) if isinstance(row["created_at"], str)
else row["created_at"] else row["created_at"]
), ),
updated_at=( updated_at = (
datetime.fromisoformat(row["updated_at"]) datetime.fromisoformat(row["updated_at"])
if isinstance(row["updated_at"], str) if isinstance(row["updated_at"], str)
else row["updated_at"] else row["updated_at"]
), ),
is_reviewed=bool(row["is_reviewed"]), is_reviewed = bool(row["is_reviewed"]),
reviewed_by=row["reviewed_by"], reviewed_by = row["reviewed_by"],
reviewed_at=( reviewed_at = (
datetime.fromisoformat(row["reviewed_at"]) datetime.fromisoformat(row["reviewed_at"])
if row["reviewed_at"] and isinstance(row["reviewed_at"], str) if row["reviewed_at"] and isinstance(row["reviewed_at"], str)
else row["reviewed_at"] else row["reviewed_at"]
@@ -1563,39 +1563,39 @@ class LocalizationManager:
def _row_to_language_config(self, row: sqlite3.Row) -> LanguageConfig: def _row_to_language_config(self, row: sqlite3.Row) -> LanguageConfig:
return LanguageConfig( return LanguageConfig(
code=row["code"], code = row["code"],
name=row["name"], name = row["name"],
name_local=row["name_local"], name_local = row["name_local"],
is_rtl=bool(row["is_rtl"]), is_rtl = bool(row["is_rtl"]),
is_active=bool(row["is_active"]), is_active = bool(row["is_active"]),
is_default=bool(row["is_default"]), is_default = bool(row["is_default"]),
fallback_language=row["fallback_language"], fallback_language = row["fallback_language"],
date_format=row["date_format"], date_format = row["date_format"],
time_format=row["time_format"], time_format = row["time_format"],
datetime_format=row["datetime_format"], datetime_format = row["datetime_format"],
number_format=row["number_format"], number_format = row["number_format"],
currency_format=row["currency_format"], currency_format = row["currency_format"],
first_day_of_week=row["first_day_of_week"], first_day_of_week = row["first_day_of_week"],
calendar_type=row["calendar_type"], calendar_type = row["calendar_type"],
) )
def _row_to_data_center(self, row: sqlite3.Row) -> DataCenter: def _row_to_data_center(self, row: sqlite3.Row) -> DataCenter:
return DataCenter( return DataCenter(
id=row["id"], id = row["id"],
region_code=row["region_code"], region_code = row["region_code"],
name=row["name"], name = row["name"],
location=row["location"], location = row["location"],
endpoint=row["endpoint"], endpoint = row["endpoint"],
status=row["status"], status = row["status"],
priority=row["priority"], priority = row["priority"],
supported_regions=json.loads(row["supported_regions"] or "[]"), supported_regions = json.loads(row["supported_regions"] or "[]"),
capabilities=json.loads(row["capabilities"] or "{}"), capabilities = json.loads(row["capabilities"] or "{}"),
created_at=( created_at = (
datetime.fromisoformat(row["created_at"]) datetime.fromisoformat(row["created_at"])
if isinstance(row["created_at"], str) if isinstance(row["created_at"], str)
else row["created_at"] else row["created_at"]
), ),
updated_at=( updated_at = (
datetime.fromisoformat(row["updated_at"]) datetime.fromisoformat(row["updated_at"])
if isinstance(row["updated_at"], str) if isinstance(row["updated_at"], str)
else row["updated_at"] else row["updated_at"]
@@ -1604,18 +1604,18 @@ class LocalizationManager:
def _row_to_tenant_dc_mapping(self, row: sqlite3.Row) -> TenantDataCenterMapping: def _row_to_tenant_dc_mapping(self, row: sqlite3.Row) -> TenantDataCenterMapping:
return TenantDataCenterMapping( return TenantDataCenterMapping(
id=row["id"], id = row["id"],
tenant_id=row["tenant_id"], tenant_id = row["tenant_id"],
primary_dc_id=row["primary_dc_id"], primary_dc_id = row["primary_dc_id"],
secondary_dc_id=row["secondary_dc_id"], secondary_dc_id = row["secondary_dc_id"],
region_code=row["region_code"], region_code = row["region_code"],
data_residency=row["data_residency"], data_residency = row["data_residency"],
created_at=( created_at = (
datetime.fromisoformat(row["created_at"]) datetime.fromisoformat(row["created_at"])
if isinstance(row["created_at"], str) if isinstance(row["created_at"], str)
else row["created_at"] else row["created_at"]
), ),
updated_at=( updated_at = (
datetime.fromisoformat(row["updated_at"]) datetime.fromisoformat(row["updated_at"])
if isinstance(row["updated_at"], str) if isinstance(row["updated_at"], str)
else row["updated_at"] else row["updated_at"]
@@ -1624,24 +1624,24 @@ class LocalizationManager:
def _row_to_payment_method(self, row: sqlite3.Row) -> LocalizedPaymentMethod: def _row_to_payment_method(self, row: sqlite3.Row) -> LocalizedPaymentMethod:
return LocalizedPaymentMethod( return LocalizedPaymentMethod(
id=row["id"], id = row["id"],
provider=row["provider"], provider = row["provider"],
name=row["name"], name = row["name"],
name_local=json.loads(row["name_local"] or "{}"), name_local = json.loads(row["name_local"] or "{}"),
supported_countries=json.loads(row["supported_countries"] or "[]"), supported_countries = json.loads(row["supported_countries"] or "[]"),
supported_currencies=json.loads(row["supported_currencies"] or "[]"), supported_currencies = json.loads(row["supported_currencies"] or "[]"),
is_active=bool(row["is_active"]), is_active = bool(row["is_active"]),
config=json.loads(row["config"] or "{}"), config = json.loads(row["config"] or "{}"),
icon_url=row["icon_url"], icon_url = row["icon_url"],
display_order=row["display_order"], display_order = row["display_order"],
min_amount=row["min_amount"], min_amount = row["min_amount"],
max_amount=row["max_amount"], max_amount = row["max_amount"],
created_at=( created_at = (
datetime.fromisoformat(row["created_at"]) datetime.fromisoformat(row["created_at"])
if isinstance(row["created_at"], str) if isinstance(row["created_at"], str)
else row["created_at"] else row["created_at"]
), ),
updated_at=( updated_at = (
datetime.fromisoformat(row["updated_at"]) datetime.fromisoformat(row["updated_at"])
if isinstance(row["updated_at"], str) if isinstance(row["updated_at"], str)
else row["updated_at"] else row["updated_at"]
@@ -1650,48 +1650,48 @@ class LocalizationManager:
def _row_to_country_config(self, row: sqlite3.Row) -> CountryConfig: def _row_to_country_config(self, row: sqlite3.Row) -> CountryConfig:
return CountryConfig( return CountryConfig(
code=row["code"], code = row["code"],
code3=row["code3"], code3 = row["code3"],
name=row["name"], name = row["name"],
name_local=json.loads(row["name_local"] or "{}"), name_local = json.loads(row["name_local"] or "{}"),
region=row["region"], region = row["region"],
default_language=row["default_language"], default_language = row["default_language"],
supported_languages=json.loads(row["supported_languages"] or "[]"), supported_languages = json.loads(row["supported_languages"] or "[]"),
default_currency=row["default_currency"], default_currency = row["default_currency"],
supported_currencies=json.loads(row["supported_currencies"] or "[]"), supported_currencies = json.loads(row["supported_currencies"] or "[]"),
timezone=row["timezone"], timezone = row["timezone"],
calendar_type=row["calendar_type"], calendar_type = row["calendar_type"],
date_format=row["date_format"], date_format = row["date_format"],
time_format=row["time_format"], time_format = row["time_format"],
number_format=row["number_format"], number_format = row["number_format"],
address_format=row["address_format"], address_format = row["address_format"],
phone_format=row["phone_format"], phone_format = row["phone_format"],
vat_rate=row["vat_rate"], vat_rate = row["vat_rate"],
is_active=bool(row["is_active"]), is_active = bool(row["is_active"]),
) )
def _row_to_localization_settings(self, row: sqlite3.Row) -> LocalizationSettings: def _row_to_localization_settings(self, row: sqlite3.Row) -> LocalizationSettings:
return LocalizationSettings( return LocalizationSettings(
id=row["id"], id = row["id"],
tenant_id=row["tenant_id"], tenant_id = row["tenant_id"],
default_language=row["default_language"], default_language = row["default_language"],
supported_languages=json.loads(row["supported_languages"] or '["en"]'), supported_languages = json.loads(row["supported_languages"] or '["en"]'),
default_currency=row["default_currency"], default_currency = row["default_currency"],
supported_currencies=json.loads(row["supported_currencies"] or '["USD"]'), supported_currencies = json.loads(row["supported_currencies"] or '["USD"]'),
default_timezone=row["default_timezone"], default_timezone = row["default_timezone"],
default_date_format=row["default_date_format"], default_date_format = row["default_date_format"],
default_time_format=row["default_time_format"], default_time_format = row["default_time_format"],
default_number_format=row["default_number_format"], default_number_format = row["default_number_format"],
calendar_type=row["calendar_type"], calendar_type = row["calendar_type"],
first_day_of_week=row["first_day_of_week"], first_day_of_week = row["first_day_of_week"],
region_code=row["region_code"], region_code = row["region_code"],
data_residency=row["data_residency"], data_residency = row["data_residency"],
created_at=( created_at = (
datetime.fromisoformat(row["created_at"]) datetime.fromisoformat(row["created_at"])
if isinstance(row["created_at"], str) if isinstance(row["created_at"], str)
else row["created_at"] else row["created_at"]
), ),
updated_at=( updated_at = (
datetime.fromisoformat(row["updated_at"]) datetime.fromisoformat(row["updated_at"])
if isinstance(row["updated_at"], str) if isinstance(row["updated_at"], str)
else row["updated_at"] else row["updated_at"]

View File

@@ -32,7 +32,7 @@ class MultimodalEntity:
confidence: float confidence: float
modality_features: dict = None # 模态特定特征 modality_features: dict = None # 模态特定特征
def __post_init__(self): def __post_init__(self) -> None:
if self.modality_features is None: if self.modality_features is None:
self.modality_features = {} self.modality_features = {}
@@ -200,11 +200,11 @@ class MultimodalEntityLinker:
if best_match: if best_match:
return AlignmentResult( return AlignmentResult(
entity_id=query_entity.get("id"), entity_id = query_entity.get("id"),
matched_entity_id=best_match.get("id"), matched_entity_id = best_match.get("id"),
similarity=best_similarity, similarity = best_similarity,
match_type=best_match_type, match_type = best_match_type,
confidence=best_similarity, confidence = best_similarity,
) )
return None return None
@@ -255,15 +255,15 @@ class MultimodalEntityLinker:
if result and result.matched_entity_id: if result and result.matched_entity_id:
link = EntityLink( link = EntityLink(
id=str(uuid.uuid4())[:UUID_LENGTH], id = str(uuid.uuid4())[:UUID_LENGTH],
project_id=project_id, project_id = project_id,
source_entity_id=ent1.get("id"), source_entity_id = ent1.get("id"),
target_entity_id=result.matched_entity_id, target_entity_id = result.matched_entity_id,
link_type="same_as" if result.similarity > 0.95 else "related_to", link_type = "same_as" if result.similarity > 0.95 else "related_to",
source_modality=mod1, source_modality = mod1,
target_modality=mod2, target_modality = mod2,
confidence=result.confidence, confidence = result.confidence,
evidence=f"Cross-modal alignment: {result.match_type}", evidence = f"Cross-modal alignment: {result.match_type}",
) )
links.append(link) links.append(link)
@@ -319,7 +319,7 @@ class MultimodalEntityLinker:
# 选择最佳定义(最长的那个) # 选择最佳定义(最长的那个)
best_definition = ( best_definition = (
max(fused_properties["definitions"], key=len) if fused_properties["definitions"] else "" max(fused_properties["definitions"], key = len) if fused_properties["definitions"] else ""
) )
# 选择最佳名称(最常见的那个) # 选择最佳名称(最常见的那个)
@@ -330,9 +330,9 @@ class MultimodalEntityLinker:
# 构建融合结果 # 构建融合结果
return FusionResult( return FusionResult(
canonical_entity_id=entity_id, canonical_entity_id = entity_id,
merged_entity_ids=merged_ids, merged_entity_ids = merged_ids,
fused_properties={ fused_properties = {
"name": best_name, "name": best_name,
"definition": best_definition, "definition": best_definition,
"aliases": list(fused_properties["aliases"]), "aliases": list(fused_properties["aliases"]),
@@ -340,8 +340,8 @@ class MultimodalEntityLinker:
"modalities": list(fused_properties["modalities"]), "modalities": list(fused_properties["modalities"]),
"contexts": fused_properties["contexts"][:10], # 最多10个上下文 "contexts": fused_properties["contexts"][:10], # 最多10个上下文
}, },
source_modalities=list(fused_properties["modalities"]), source_modalities = list(fused_properties["modalities"]),
confidence=min(1.0, len(linked_entities) * 0.2 + 0.5), confidence = min(1.0, len(linked_entities) * 0.2 + 0.5),
) )
def detect_entity_conflicts(self, entities: list[dict]) -> list[dict]: def detect_entity_conflicts(self, entities: list[dict]) -> list[dict]:
@@ -441,7 +441,7 @@ class MultimodalEntityLinker:
) )
# 按相似度排序 # 按相似度排序
suggestions.sort(key=lambda x: x["similarity"], reverse=True) suggestions.sort(key = lambda x: x["similarity"], reverse = True)
return suggestions return suggestions
@@ -469,14 +469,14 @@ class MultimodalEntityLinker:
多模态实体记录 多模态实体记录
""" """
return MultimodalEntity( return MultimodalEntity(
id=str(uuid.uuid4())[:UUID_LENGTH], id = str(uuid.uuid4())[:UUID_LENGTH],
entity_id=entity_id, entity_id = entity_id,
project_id=project_id, project_id = project_id,
name="", # 将在后续填充 name = "", # 将在后续填充
source_type=source_type, source_type = source_type,
source_id=source_id, source_id = source_id,
mention_context=mention_context, mention_context = mention_context,
confidence=confidence, confidence = confidence,
) )
def analyze_modality_distribution(self, multimodal_entities: list[MultimodalEntity]) -> dict: def analyze_modality_distribution(self, multimodal_entities: list[MultimodalEntity]) -> dict:

View File

@@ -52,7 +52,7 @@ class VideoFrame:
ocr_confidence: float = 0.0 ocr_confidence: float = 0.0
entities_detected: list[dict] = None entities_detected: list[dict] = None
def __post_init__(self): def __post_init__(self) -> None:
if self.entities_detected is None: if self.entities_detected is None:
self.entities_detected = [] self.entities_detected = []
@@ -76,7 +76,7 @@ class VideoInfo:
error_message: str = "" error_message: str = ""
metadata: dict = None metadata: dict = None
def __post_init__(self): def __post_init__(self) -> None:
if self.metadata is None: if self.metadata is None:
self.metadata = {} self.metadata = {}
@@ -112,9 +112,9 @@ class MultimodalProcessor:
self.audio_dir = os.path.join(self.temp_dir, "audio") self.audio_dir = os.path.join(self.temp_dir, "audio")
# 创建目录 # 创建目录
os.makedirs(self.video_dir, exist_ok=True) os.makedirs(self.video_dir, exist_ok = True)
os.makedirs(self.frames_dir, exist_ok=True) os.makedirs(self.frames_dir, exist_ok = True)
os.makedirs(self.audio_dir, exist_ok=True) os.makedirs(self.audio_dir, exist_ok = True)
def extract_video_info(self, video_path: str) -> dict: def extract_video_info(self, video_path: str) -> dict:
""" """
@@ -152,14 +152,14 @@ class MultimodalProcessor:
"-v", "-v",
"error", "error",
"-show_entries", "-show_entries",
"format=duration,bit_rate", "format = duration, bit_rate",
"-show_entries", "-show_entries",
"stream=width,height,r_frame_rate", "stream = width, height, r_frame_rate",
"-of", "-of",
"json", "json",
video_path, video_path,
] ]
result = subprocess.run(cmd, capture_output=True, text=True) result = subprocess.run(cmd, capture_output = True, text = True)
if result.returncode == 0: if result.returncode == 0:
data = json.loads(result.stdout) data = json.loads(result.stdout)
return { return {
@@ -196,9 +196,9 @@ class MultimodalProcessor:
if FFMPEG_AVAILABLE: if FFMPEG_AVAILABLE:
( (
ffmpeg.input(video_path) ffmpeg.input(video_path)
.output(output_path, ac=1, ar=16000, vn=None) .output(output_path, ac = 1, ar = 16000, vn = None)
.overwrite_output() .overwrite_output()
.run(quiet=True) .run(quiet = True)
) )
else: else:
# 使用命令行 ffmpeg # 使用命令行 ffmpeg
@@ -216,7 +216,7 @@ class MultimodalProcessor:
"-y", "-y",
output_path, output_path,
] ]
subprocess.run(cmd, check=True, capture_output=True) subprocess.run(cmd, check = True, capture_output = True)
return output_path return output_path
except Exception as e: except Exception as e:
@@ -240,7 +240,7 @@ class MultimodalProcessor:
# 创建帧存储目录 # 创建帧存储目录
video_frames_dir = os.path.join(self.frames_dir, video_id) video_frames_dir = os.path.join(self.frames_dir, video_id)
os.makedirs(video_frames_dir, exist_ok=True) os.makedirs(video_frames_dir, exist_ok = True)
try: try:
if CV2_AVAILABLE: if CV2_AVAILABLE:
@@ -278,13 +278,13 @@ class MultimodalProcessor:
"-i", "-i",
video_path, video_path,
"-vf", "-vf",
f"fps=1/{interval}", f"fps = 1/{interval}",
"-frame_pts", "-frame_pts",
"1", "1",
"-y", "-y",
output_pattern, output_pattern,
] ]
subprocess.run(cmd, check=True, capture_output=True) subprocess.run(cmd, check = True, capture_output = True)
# 获取生成的帧文件列表 # 获取生成的帧文件列表
frame_paths = sorted( frame_paths = sorted(
@@ -320,10 +320,10 @@ class MultimodalProcessor:
image = image.convert("L") image = image.convert("L")
# 使用 pytesseract 进行 OCR # 使用 pytesseract 进行 OCR
text = pytesseract.image_to_string(image, lang="chi_sim+eng") text = pytesseract.image_to_string(image, lang = "chi_sim+eng")
# 获取置信度数据 # 获取置信度数据
data = pytesseract.image_to_data(image, output_type=pytesseract.Output.DICT) data = pytesseract.image_to_data(image, output_type = pytesseract.Output.DICT)
confidences = [int(c) for c in data["conf"] if int(c) > 0] confidences = [int(c) for c in data["conf"] if int(c) > 0]
avg_confidence = sum(confidences) / len(confidences) if confidences else 0 avg_confidence = sum(confidences) / len(confidences) if confidences else 0
@@ -382,13 +382,13 @@ class MultimodalProcessor:
ocr_text, confidence = self.perform_ocr(frame_path) ocr_text, confidence = self.perform_ocr(frame_path)
frame = VideoFrame( frame = VideoFrame(
id=str(uuid.uuid4())[:UUID_LENGTH], id = str(uuid.uuid4())[:UUID_LENGTH],
video_id=video_id, video_id = video_id,
frame_number=frame_number, frame_number = frame_number,
timestamp=timestamp, timestamp = timestamp,
frame_path=frame_path, frame_path = frame_path,
ocr_text=ocr_text, ocr_text = ocr_text,
ocr_confidence=confidence, ocr_confidence = confidence,
) )
frames.append(frame) frames.append(frame)
@@ -407,23 +407,23 @@ class MultimodalProcessor:
full_ocr_text = "\n\n".join(all_ocr_text) full_ocr_text = "\n\n".join(all_ocr_text)
return VideoProcessingResult( return VideoProcessingResult(
video_id=video_id, video_id = video_id,
audio_path=audio_path, audio_path = audio_path,
frames=frames, frames = frames,
ocr_results=ocr_results, ocr_results = ocr_results,
full_text=full_ocr_text, full_text = full_ocr_text,
success=True, success = True,
) )
except Exception as e: except Exception as e:
return VideoProcessingResult( return VideoProcessingResult(
video_id=video_id, video_id = video_id,
audio_path="", audio_path = "",
frames=[], frames = [],
ocr_results=[], ocr_results = [],
full_text="", full_text = "",
success=False, success = False,
error_message=str(e), error_message = str(e),
) )
def cleanup(self, video_id: str = None) -> None: def cleanup(self, video_id: str = None) -> None:
@@ -450,7 +450,7 @@ class MultimodalProcessor:
for dir_path in [self.video_dir, self.frames_dir, self.audio_dir]: for dir_path in [self.video_dir, self.frames_dir, self.audio_dir]:
if os.path.exists(dir_path): if os.path.exists(dir_path):
shutil.rmtree(dir_path) shutil.rmtree(dir_path)
os.makedirs(dir_path, exist_ok=True) os.makedirs(dir_path, exist_ok = True)
# Singleton instance # Singleton instance

View File

@@ -39,7 +39,7 @@ class GraphEntity:
aliases: list[str] = None aliases: list[str] = None
properties: dict = None properties: dict = None
def __post_init__(self): def __post_init__(self) -> None:
if self.aliases is None: if self.aliases is None:
self.aliases = [] self.aliases = []
if self.properties is None: if self.properties is None:
@@ -57,7 +57,7 @@ class GraphRelation:
evidence: str = "" evidence: str = ""
properties: dict = None properties: dict = None
def __post_init__(self): def __post_init__(self) -> None:
if self.properties is None: if self.properties is None:
self.properties = {} self.properties = {}
@@ -95,7 +95,7 @@ class CentralityResult:
class Neo4jManager: class Neo4jManager:
"""Neo4j 图数据库管理器""" """Neo4j 图数据库管理器"""
def __init__(self, uri: str = None, user: str = None, password: str = None): def __init__(self, uri: str = None, user: str = None, password: str = None) -> None:
self.uri = uri or NEO4J_URI self.uri = uri or NEO4J_URI
self.user = user or NEO4J_USER self.user = user or NEO4J_USER
self.password = password or NEO4J_PASSWORD self.password = password or NEO4J_PASSWORD
@@ -113,7 +113,7 @@ class Neo4jManager:
return return
try: try:
self._driver = GraphDatabase.driver(self.uri, auth=(self.user, self.password)) self._driver = GraphDatabase.driver(self.uri, auth = (self.user, self.password))
# 验证连接 # 验证连接
self._driver.verify_connectivity() self._driver.verify_connectivity()
logger.info(f"Connected to Neo4j at {self.uri}") logger.info(f"Connected to Neo4j at {self.uri}")
@@ -193,9 +193,9 @@ class Neo4jManager:
p.description = $description, p.description = $description,
p.updated_at = datetime() p.updated_at = datetime()
""", """,
project_id=project_id, project_id = project_id,
name=project_name, name = project_name,
description=project_description, description = project_description,
) )
def sync_entity(self, entity: GraphEntity) -> None: def sync_entity(self, entity: GraphEntity) -> None:
@@ -218,13 +218,13 @@ class Neo4jManager:
MATCH (p:Project {id: $project_id}) MATCH (p:Project {id: $project_id})
MERGE (e)-[:BELONGS_TO]->(p) MERGE (e)-[:BELONGS_TO]->(p)
""", """,
id=entity.id, id = entity.id,
project_id=entity.project_id, project_id = entity.project_id,
name=entity.name, name = entity.name,
type=entity.type, type = entity.type,
definition=entity.definition, definition = entity.definition,
aliases=json.dumps(entity.aliases), aliases = json.dumps(entity.aliases),
properties=json.dumps(entity.properties), properties = json.dumps(entity.properties),
) )
def sync_entities_batch(self, entities: list[GraphEntity]) -> None: def sync_entities_batch(self, entities: list[GraphEntity]) -> None:
@@ -261,7 +261,7 @@ class Neo4jManager:
MATCH (p:Project {id: entity.project_id}) MATCH (p:Project {id: entity.project_id})
MERGE (e)-[:BELONGS_TO]->(p) MERGE (e)-[:BELONGS_TO]->(p)
""", """,
entities=entities_data, entities = entities_data,
) )
def sync_relation(self, relation: GraphRelation) -> None: def sync_relation(self, relation: GraphRelation) -> None:
@@ -280,12 +280,12 @@ class Neo4jManager:
r.properties = $properties, r.properties = $properties,
r.updated_at = datetime() r.updated_at = datetime()
""", """,
id=relation.id, id = relation.id,
source_id=relation.source_id, source_id = relation.source_id,
target_id=relation.target_id, target_id = relation.target_id,
relation_type=relation.relation_type, relation_type = relation.relation_type,
evidence=relation.evidence, evidence = relation.evidence,
properties=json.dumps(relation.properties), properties = json.dumps(relation.properties),
) )
def sync_relations_batch(self, relations: list[GraphRelation]) -> None: def sync_relations_batch(self, relations: list[GraphRelation]) -> None:
@@ -317,7 +317,7 @@ class Neo4jManager:
r.properties = rel.properties, r.properties = rel.properties,
r.updated_at = datetime() r.updated_at = datetime()
""", """,
relations=relations_data, relations = relations_data,
) )
def delete_entity(self, entity_id: str) -> None: def delete_entity(self, entity_id: str) -> None:
@@ -331,7 +331,7 @@ class Neo4jManager:
MATCH (e:Entity {id: $id}) MATCH (e:Entity {id: $id})
DETACH DELETE e DETACH DELETE e
""", """,
id=entity_id, id = entity_id,
) )
def delete_project(self, project_id: str) -> None: def delete_project(self, project_id: str) -> None:
@@ -346,7 +346,7 @@ class Neo4jManager:
OPTIONAL MATCH (e:Entity)-[:BELONGS_TO]->(p) OPTIONAL MATCH (e:Entity)-[:BELONGS_TO]->(p)
DETACH DELETE e, p DETACH DELETE e, p
""", """,
id=project_id, id = project_id,
) )
# ==================== 复杂图查询 ==================== # ==================== 复杂图查询 ====================
@@ -376,9 +376,9 @@ class Neo4jManager:
) )
RETURN path RETURN path
""", """,
source_id=source_id, source_id = source_id,
target_id=target_id, target_id = target_id,
max_depth=max_depth, max_depth = max_depth,
) )
record = result.single() record = result.single()
@@ -404,7 +404,7 @@ class Neo4jManager:
] ]
return PathResult( return PathResult(
nodes=nodes, relationships=relationships, length=len(path.relationships) nodes = nodes, relationships = relationships, length = len(path.relationships)
) )
def find_all_paths( def find_all_paths(
@@ -433,10 +433,10 @@ class Neo4jManager:
RETURN path RETURN path
LIMIT $limit LIMIT $limit
""", """,
source_id=source_id, source_id = source_id,
target_id=target_id, target_id = target_id,
max_depth=max_depth, max_depth = max_depth,
limit=limit, limit = limit,
) )
paths = [] paths = []
@@ -460,7 +460,7 @@ class Neo4jManager:
paths.append( paths.append(
PathResult( PathResult(
nodes=nodes, relationships=relationships, length=len(path.relationships) nodes = nodes, relationships = relationships, length = len(path.relationships)
) )
) )
@@ -491,9 +491,9 @@ class Neo4jManager:
RETURN neighbor, r.relation_type as rel_type, r.evidence as evidence RETURN neighbor, r.relation_type as rel_type, r.evidence as evidence
LIMIT $limit LIMIT $limit
""", """,
entity_id=entity_id, entity_id = entity_id,
relation_type=relation_type, relation_type = relation_type,
limit=limit, limit = limit,
) )
else: else:
result = session.run( result = session.run(
@@ -502,8 +502,8 @@ class Neo4jManager:
RETURN neighbor, r.relation_type as rel_type, r.evidence as evidence RETURN neighbor, r.relation_type as rel_type, r.evidence as evidence
LIMIT $limit LIMIT $limit
""", """,
entity_id=entity_id, entity_id = entity_id,
limit=limit, limit = limit,
) )
neighbors = [] neighbors = []
@@ -541,8 +541,8 @@ class Neo4jManager:
MATCH (e1:Entity {id: $id1})-[:RELATES_TO]-(common:Entity)-[:RELATES_TO]-(e2:Entity {id: $id2}) MATCH (e1:Entity {id: $id1})-[:RELATES_TO]-(common:Entity)-[:RELATES_TO]-(e2:Entity {id: $id2})
RETURN DISTINCT common RETURN DISTINCT common
""", """,
id1=entity_id1, id1 = entity_id1,
id2=entity_id2, id2 = entity_id2,
) )
return [ return [
@@ -581,7 +581,7 @@ class Neo4jManager:
{} {}
) YIELD value RETURN value ) YIELD value RETURN value
""", """,
project_id=project_id, project_id = project_id,
) )
# 创建临时图 # 创建临时图
@@ -601,7 +601,7 @@ class Neo4jManager:
} }
) )
""", """,
project_id=project_id, project_id = project_id,
) )
# 运行 PageRank # 运行 PageRank
@@ -615,8 +615,8 @@ class Neo4jManager:
ORDER BY score DESC ORDER BY score DESC
LIMIT $top_n LIMIT $top_n
""", """,
project_id=project_id, project_id = project_id,
top_n=top_n, top_n = top_n,
) )
rankings = [] rankings = []
@@ -624,10 +624,10 @@ class Neo4jManager:
for record in result: for record in result:
rankings.append( rankings.append(
CentralityResult( CentralityResult(
entity_id=record["entity_id"], entity_id = record["entity_id"],
entity_name=record["entity_name"], entity_name = record["entity_name"],
score=record["score"], score = record["score"],
rank=rank, rank = rank,
) )
) )
rank += 1 rank += 1
@@ -637,7 +637,7 @@ class Neo4jManager:
""" """
CALL gds.graph.drop('project-graph-$project_id') CALL gds.graph.drop('project-graph-$project_id')
""", """,
project_id=project_id, project_id = project_id,
) )
return rankings return rankings
@@ -667,8 +667,8 @@ class Neo4jManager:
LIMIT $top_n LIMIT $top_n
RETURN e.id as entity_id, e.name as entity_name, degree as score RETURN e.id as entity_id, e.name as entity_name, degree as score
""", """,
project_id=project_id, project_id = project_id,
top_n=top_n, top_n = top_n,
) )
rankings = [] rankings = []
@@ -676,10 +676,10 @@ class Neo4jManager:
for record in result: for record in result:
rankings.append( rankings.append(
CentralityResult( CentralityResult(
entity_id=record["entity_id"], entity_id = record["entity_id"],
entity_name=record["entity_name"], entity_name = record["entity_name"],
score=float(record["score"]), score = float(record["score"]),
rank=rank, rank = rank,
) )
) )
rank += 1 rank += 1
@@ -710,7 +710,7 @@ class Neo4jManager:
connections, size(connections) as connection_count connections, size(connections) as connection_count
ORDER BY connection_count DESC ORDER BY connection_count DESC
""", """,
project_id=project_id, project_id = project_id,
) )
# 手动分组(基于连通性) # 手动分组(基于连通性)
@@ -752,12 +752,12 @@ class Neo4jManager:
results.append( results.append(
CommunityResult( CommunityResult(
community_id=comm_id, nodes=nodes, size=size, density=min(density, 1.0) community_id = comm_id, nodes = nodes, size = size, density = min(density, 1.0)
) )
) )
# 按大小排序 # 按大小排序
results.sort(key=lambda x: x.size, reverse=True) results.sort(key = lambda x: x.size, reverse = True)
return results return results
def find_central_entities( def find_central_entities(
@@ -787,7 +787,7 @@ class Neo4jManager:
ORDER BY degree DESC ORDER BY degree DESC
LIMIT 20 LIMIT 20
""", """,
project_id=project_id, project_id = project_id,
) )
else: else:
# 默认使用度中心性 # 默认使用度中心性
@@ -800,7 +800,7 @@ class Neo4jManager:
ORDER BY degree DESC ORDER BY degree DESC
LIMIT 20 LIMIT 20
""", """,
project_id=project_id, project_id = project_id,
) )
rankings = [] rankings = []
@@ -808,10 +808,10 @@ class Neo4jManager:
for record in result: for record in result:
rankings.append( rankings.append(
CentralityResult( CentralityResult(
entity_id=record["entity_id"], entity_id = record["entity_id"],
entity_name=record["entity_name"], entity_name = record["entity_name"],
score=float(record["score"]), score = float(record["score"]),
rank=rank, rank = rank,
) )
) )
rank += 1 rank += 1
@@ -840,7 +840,7 @@ class Neo4jManager:
MATCH (e:Entity)-[:BELONGS_TO]->(p:Project {id: $project_id}) MATCH (e:Entity)-[:BELONGS_TO]->(p:Project {id: $project_id})
RETURN count(e) as count RETURN count(e) as count
""", """,
project_id=project_id, project_id = project_id,
).single()["count"] ).single()["count"]
# 关系数量 # 关系数量
@@ -850,7 +850,7 @@ class Neo4jManager:
MATCH (e)-[r:RELATES_TO]-() MATCH (e)-[r:RELATES_TO]-()
RETURN count(r) as count RETURN count(r) as count
""", """,
project_id=project_id, project_id = project_id,
).single()["count"] ).single()["count"]
# 实体类型分布 # 实体类型分布
@@ -860,7 +860,7 @@ class Neo4jManager:
RETURN e.type as type, count(e) as count RETURN e.type as type, count(e) as count
ORDER BY count DESC ORDER BY count DESC
""", """,
project_id=project_id, project_id = project_id,
) )
types = {record["type"]: record["count"] for record in type_distribution} types = {record["type"]: record["count"] for record in type_distribution}
@@ -873,7 +873,7 @@ class Neo4jManager:
WITH e, count(other) as degree WITH e, count(other) as degree
RETURN avg(degree) as avg_degree RETURN avg(degree) as avg_degree
""", """,
project_id=project_id, project_id = project_id,
).single()["avg_degree"] ).single()["avg_degree"]
# 关系类型分布 # 关系类型分布
@@ -885,7 +885,7 @@ class Neo4jManager:
ORDER BY count DESC ORDER BY count DESC
LIMIT 10 LIMIT 10
""", """,
project_id=project_id, project_id = project_id,
) )
relation_types = {record["type"]: record["count"] for record in rel_types} relation_types = {record["type"]: record["count"] for record in rel_types}
@@ -927,8 +927,8 @@ class Neo4jManager:
}) YIELD node }) YIELD node
RETURN DISTINCT node RETURN DISTINCT node
""", """,
entity_ids=entity_ids, entity_ids = entity_ids,
depth=depth, depth = depth,
) )
nodes = [] nodes = []
@@ -953,7 +953,7 @@ class Neo4jManager:
RETURN source.id as source_id, target.id as target_id, RETURN source.id as source_id, target.id as target_id,
r.relation_type as type, r.evidence as evidence r.relation_type as type, r.evidence as evidence
""", """,
node_ids=list(node_ids), node_ids = list(node_ids),
) )
relationships = [ relationships = [
@@ -1015,13 +1015,13 @@ def sync_project_to_neo4j(
# 同步实体 # 同步实体
graph_entities = [ graph_entities = [
GraphEntity( GraphEntity(
id=e["id"], id = e["id"],
project_id=project_id, project_id = project_id,
name=e["name"], name = e["name"],
type=e.get("type", "unknown"), type = e.get("type", "unknown"),
definition=e.get("definition", ""), definition = e.get("definition", ""),
aliases=e.get("aliases", []), aliases = e.get("aliases", []),
properties=e.get("properties", {}), properties = e.get("properties", {}),
) )
for e in entities for e in entities
] ]
@@ -1030,12 +1030,12 @@ def sync_project_to_neo4j(
# 同步关系 # 同步关系
graph_relations = [ graph_relations = [
GraphRelation( GraphRelation(
id=r["id"], id = r["id"],
source_id=r["source_entity_id"], source_id = r["source_entity_id"],
target_id=r["target_entity_id"], target_id = r["target_entity_id"],
relation_type=r["relation_type"], relation_type = r["relation_type"],
evidence=r.get("evidence", ""), evidence = r.get("evidence", ""),
properties=r.get("properties", {}), properties = r.get("properties", {}),
) )
for r in relations for r in relations
] ]
@@ -1048,7 +1048,7 @@ def sync_project_to_neo4j(
if __name__ == "__main__": if __name__ == "__main__":
# 测试代码 # 测试代码
logging.basicConfig(level=logging.INFO) logging.basicConfig(level = logging.INFO)
manager = Neo4jManager() manager = Neo4jManager()
@@ -1065,11 +1065,11 @@ if __name__ == "__main__":
# 测试实体 # 测试实体
test_entity = GraphEntity( test_entity = GraphEntity(
id="test-entity-1", id = "test-entity-1",
project_id="test-project", project_id = "test-project",
name="Test Entity", name = "Test Entity",
type="Person", type = "Person",
definition="A test entity", definition = "A test entity",
) )
manager.sync_entity(test_entity) manager.sync_entity(test_entity)
print("✅ Entity synced") print("✅ Entity synced")

File diff suppressed because it is too large Load Diff

View File

@@ -11,7 +11,7 @@ import oss2
class OSSUploader: class OSSUploader:
def __init__(self): def __init__(self) -> None:
self.access_key = os.getenv("ALI_ACCESS_KEY") self.access_key = os.getenv("ALI_ACCESS_KEY")
self.secret_key = os.getenv("ALI_SECRET_KEY") self.secret_key = os.getenv("ALI_SECRET_KEY")
self.bucket_name = os.getenv("OSS_BUCKET", "insightflow-audio") self.bucket_name = os.getenv("OSS_BUCKET", "insightflow-audio")

View File

@@ -82,7 +82,7 @@ class PerformanceMetric:
endpoint: str | None endpoint: str | None
duration_ms: float duration_ms: float
timestamp: str timestamp: str
metadata: dict = field(default_factory=dict) metadata: dict = field(default_factory = dict)
def to_dict(self) -> dict: def to_dict(self) -> dict:
return { return {
@@ -164,7 +164,7 @@ class CacheManager:
max_memory_size: int = 100 * 1024 * 1024, # 100MB max_memory_size: int = 100 * 1024 * 1024, # 100MB
default_ttl: int = 3600, # 1小时 default_ttl: int = 3600, # 1小时
db_path: str = "insightflow.db", db_path: str = "insightflow.db",
): ) -> None:
self.db_path = db_path self.db_path = db_path
self.default_ttl = default_ttl self.default_ttl = default_ttl
self.max_memory_size = max_memory_size self.max_memory_size = max_memory_size
@@ -176,7 +176,7 @@ class CacheManager:
if REDIS_AVAILABLE and redis_url: if REDIS_AVAILABLE and redis_url:
try: try:
self.redis_client = redis.from_url(redis_url, decode_responses=True) self.redis_client = redis.from_url(redis_url, decode_responses = True)
self.redis_client.ping() self.redis_client.ping()
self.use_redis = True self.use_redis = True
print(f"Redis 缓存已连接: {redis_url}") print(f"Redis 缓存已连接: {redis_url}")
@@ -233,7 +233,7 @@ class CacheManager:
def _get_entry_size(self, value: Any) -> int: def _get_entry_size(self, value: Any) -> int:
"""估算缓存条目大小""" """估算缓存条目大小"""
try: try:
return len(json.dumps(value, ensure_ascii=False).encode("utf-8")) return len(json.dumps(value, ensure_ascii = False).encode("utf-8"))
except (TypeError, ValueError): except (TypeError, ValueError):
return 1024 # 默认估算 return 1024 # 默认估算
@@ -245,7 +245,7 @@ class CacheManager:
and self.memory_cache and self.memory_cache
): ):
# 移除最久未访问的 # 移除最久未访问的
oldest_key, oldest_entry = self.memory_cache.popitem(last=False) oldest_key, oldest_entry = self.memory_cache.popitem(last = False)
self.current_memory_size -= oldest_entry.size_bytes self.current_memory_size -= oldest_entry.size_bytes
self.stats.evictions += 1 self.stats.evictions += 1
@@ -314,7 +314,7 @@ class CacheManager:
if self.use_redis: if self.use_redis:
try: try:
serialized = json.dumps(value, ensure_ascii=False) serialized = json.dumps(value, ensure_ascii = False)
self.redis_client.setex(key, ttl, serialized) self.redis_client.setex(key, ttl, serialized)
return True return True
except Exception as e: except Exception as e:
@@ -331,12 +331,12 @@ class CacheManager:
now = time.time() now = time.time()
entry = CacheEntry( entry = CacheEntry(
key=key, key = key,
value=value, value = value,
created_at=now, created_at = now,
expires_at=now + ttl if ttl > 0 else None, expires_at = now + ttl if ttl > 0 else None,
size_bytes=size, size_bytes = size,
last_accessed=now, last_accessed = now,
) )
# 如果已存在,更新大小 # 如果已存在,更新大小
@@ -412,7 +412,7 @@ class CacheManager:
try: try:
pipe = self.redis_client.pipeline() pipe = self.redis_client.pipeline()
for key, value in mapping.items(): for key, value in mapping.items():
serialized = json.dumps(value, ensure_ascii=False) serialized = json.dumps(value, ensure_ascii = False)
pipe.setex(key, ttl, serialized) pipe.setex(key, ttl, serialized)
pipe.execute() pipe.execute()
return True return True
@@ -500,12 +500,12 @@ class CacheManager:
WHERE e.project_id = ? WHERE e.project_id = ?
ORDER BY mention_count DESC ORDER BY mention_count DESC
LIMIT 100""", LIMIT 100""",
(project_id,), (project_id, ),
).fetchall() ).fetchall()
for entity in entities: for entity in entities:
key = f"entity:{entity['id']}" key = f"entity:{entity['id']}"
self.set(key, dict(entity), ttl=7200) # 2小时 self.set(key, dict(entity), ttl = 7200) # 2小时
stats["entities"] += 1 stats["entities"] += 1
# 预热关系数据 # 预热关系数据
@@ -517,12 +517,12 @@ class CacheManager:
JOIN entities e2 ON r.target_entity_id = e2.id JOIN entities e2 ON r.target_entity_id = e2.id
WHERE r.project_id = ? WHERE r.project_id = ?
LIMIT 200""", LIMIT 200""",
(project_id,), (project_id, ),
).fetchall() ).fetchall()
for relation in relations: for relation in relations:
key = f"relation:{relation['id']}" key = f"relation:{relation['id']}"
self.set(key, dict(relation), ttl=3600) self.set(key, dict(relation), ttl = 3600)
stats["relations"] += 1 stats["relations"] += 1
# 预热最近的转录 # 预热最近的转录
@@ -531,7 +531,7 @@ class CacheManager:
WHERE project_id = ? WHERE project_id = ?
ORDER BY created_at DESC ORDER BY created_at DESC
LIMIT 10""", LIMIT 10""",
(project_id,), (project_id, ),
).fetchall() ).fetchall()
for transcript in transcripts: for transcript in transcripts:
@@ -543,16 +543,16 @@ class CacheManager:
"type": transcript.get("type", "audio"), "type": transcript.get("type", "audio"),
"created_at": transcript["created_at"], "created_at": transcript["created_at"],
} }
self.set(key, meta, ttl=1800) # 30分钟 self.set(key, meta, ttl = 1800) # 30分钟
stats["transcripts"] += 1 stats["transcripts"] += 1
# 预热项目知识库摘要 # 预热项目知识库摘要
entity_count = conn.execute( entity_count = conn.execute(
"SELECT COUNT(*) FROM entities WHERE project_id = ?", (project_id,) "SELECT COUNT(*) FROM entities WHERE project_id = ?", (project_id, )
).fetchone()[0] ).fetchone()[0]
relation_count = conn.execute( relation_count = conn.execute(
"SELECT COUNT(*) FROM entity_relations WHERE project_id = ?", (project_id,) "SELECT COUNT(*) FROM entity_relations WHERE project_id = ?", (project_id, )
).fetchone()[0] ).fetchone()[0]
summary = { summary = {
@@ -561,7 +561,7 @@ class CacheManager:
"relation_count": relation_count, "relation_count": relation_count,
"cached_at": datetime.now().isoformat(), "cached_at": datetime.now().isoformat(),
} }
self.set(f"project_summary:{project_id}", summary, ttl=3600) self.set(f"project_summary:{project_id}", summary, ttl = 3600)
conn.close() conn.close()
@@ -583,7 +583,7 @@ class CacheManager:
try: try:
# 使用 Redis 的 scan 查找相关 key # 使用 Redis 的 scan 查找相关 key
pattern = f"*:{project_id}:*" pattern = f"*:{project_id}:*"
for key in self.redis_client.scan_iter(match=pattern): for key in self.redis_client.scan_iter(match = pattern):
self.redis_client.delete(key) self.redis_client.delete(key)
count += 1 count += 1
except Exception as e: except Exception as e:
@@ -619,13 +619,13 @@ class DatabaseSharding:
base_db_path: str = "insightflow.db", base_db_path: str = "insightflow.db",
shard_db_dir: str = "./shards", shard_db_dir: str = "./shards",
shards_count: int = 4, shards_count: int = 4,
): ) -> None:
self.base_db_path = base_db_path self.base_db_path = base_db_path
self.shard_db_dir = shard_db_dir self.shard_db_dir = shard_db_dir
self.shards_count = shards_count self.shards_count = shards_count
# 确保分片目录存在 # 确保分片目录存在
os.makedirs(shard_db_dir, exist_ok=True) os.makedirs(shard_db_dir, exist_ok = True)
# 分片映射 # 分片映射
self.shard_map: dict[str, ShardInfo] = {} self.shard_map: dict[str, ShardInfo] = {}
@@ -650,10 +650,10 @@ class DatabaseSharding:
db_path = os.path.join(self.shard_db_dir, f"{shard_id}.db") db_path = os.path.join(self.shard_db_dir, f"{shard_id}.db")
self.shard_map[shard_id] = ShardInfo( self.shard_map[shard_id] = ShardInfo(
shard_id=shard_id, shard_id = shard_id,
shard_key_range=(start_char, end_char), shard_key_range = (start_char, end_char),
db_path=db_path, db_path = db_path,
created_at=datetime.now().isoformat(), created_at = datetime.now().isoformat(),
) )
# 确保分片数据库存在 # 确保分片数据库存在
@@ -757,11 +757,11 @@ class DatabaseSharding:
source_conn.row_factory = sqlite3.Row source_conn.row_factory = sqlite3.Row
entities = source_conn.execute( entities = source_conn.execute(
"SELECT * FROM entities WHERE project_id = ?", (project_id,) "SELECT * FROM entities WHERE project_id = ?", (project_id, )
).fetchall() ).fetchall()
relations = source_conn.execute( relations = source_conn.execute(
"SELECT * FROM entity_relations WHERE project_id = ?", (project_id,) "SELECT * FROM entity_relations WHERE project_id = ?", (project_id, )
).fetchall() ).fetchall()
source_conn.close() source_conn.close()
@@ -794,8 +794,8 @@ class DatabaseSharding:
# 从源分片删除数据 # 从源分片删除数据
source_conn = sqlite3.connect(source_info.db_path) source_conn = sqlite3.connect(source_info.db_path)
source_conn.execute("DELETE FROM entities WHERE project_id = ?", (project_id,)) source_conn.execute("DELETE FROM entities WHERE project_id = ?", (project_id, ))
source_conn.execute("DELETE FROM entity_relations WHERE project_id = ?", (project_id,)) source_conn.execute("DELETE FROM entity_relations WHERE project_id = ?", (project_id, ))
source_conn.commit() source_conn.commit()
source_conn.close() source_conn.close()
@@ -917,7 +917,7 @@ class TaskQueue:
- 任务状态追踪和重试机制 - 任务状态追踪和重试机制
""" """
def __init__(self, redis_url: str | None = None, db_path: str = "insightflow.db"): def __init__(self, redis_url: str | None = None, db_path: str = "insightflow.db") -> None:
self.db_path = db_path self.db_path = db_path
self.redis_url = redis_url self.redis_url = redis_url
self.celery_app = None self.celery_app = None
@@ -934,7 +934,7 @@ class TaskQueue:
# 初始化 Celery # 初始化 Celery
if CELERY_AVAILABLE and redis_url: if CELERY_AVAILABLE and redis_url:
try: try:
self.celery_app = Celery("insightflow", broker=redis_url, backend=redis_url) self.celery_app = Celery("insightflow", broker = redis_url, backend = redis_url)
self.use_celery = True self.use_celery = True
print("Celery 任务队列已初始化") print("Celery 任务队列已初始化")
except Exception as e: except Exception as e:
@@ -989,12 +989,12 @@ class TaskQueue:
task_id = str(uuid.uuid4())[:16] task_id = str(uuid.uuid4())[:16]
task = TaskInfo( task = TaskInfo(
id=task_id, id = task_id,
task_type=task_type, task_type = task_type,
status="pending", status = "pending",
payload=payload, payload = payload,
created_at=datetime.now().isoformat(), created_at = datetime.now().isoformat(),
max_retries=max_retries, max_retries = max_retries,
) )
if self.use_celery: if self.use_celery:
@@ -1003,10 +1003,10 @@ class TaskQueue:
# 这里简化处理,实际应该定义具体的 Celery 任务 # 这里简化处理,实际应该定义具体的 Celery 任务
result = self.celery_app.send_task( result = self.celery_app.send_task(
f"insightflow.tasks.{task_type}", f"insightflow.tasks.{task_type}",
args=[payload], args = [payload],
task_id=task_id, task_id = task_id,
retry=True, retry = True,
retry_policy={ retry_policy = {
"max_retries": max_retries, "max_retries": max_retries,
"interval_start": 10, "interval_start": 10,
"interval_step": 10, "interval_step": 10,
@@ -1024,7 +1024,7 @@ class TaskQueue:
with self.task_lock: with self.task_lock:
self.tasks[task_id] = task self.tasks[task_id] = task
# 异步执行 # 异步执行
threading.Thread(target=self._execute_task, args=(task_id,), daemon=True).start() threading.Thread(target = self._execute_task, args = (task_id, ), daemon = True).start()
# 保存到数据库 # 保存到数据库
self._save_task(task) self._save_task(task)
@@ -1061,7 +1061,7 @@ class TaskQueue:
task.status = "retrying" task.status = "retrying"
# 延迟重试 # 延迟重试
threading.Timer( threading.Timer(
10 * task.retry_count, self._execute_task, args=(task_id,) 10 * task.retry_count, self._execute_task, args = (task_id, )
).start() ).start()
else: else:
task.status = "failed" task.status = "failed"
@@ -1089,8 +1089,8 @@ class TaskQueue:
task.id, task.id,
task.task_type, task.task_type,
task.status, task.status,
json.dumps(task.payload, ensure_ascii=False), json.dumps(task.payload, ensure_ascii = False),
json.dumps(task.result, ensure_ascii=False) if task.result else None, json.dumps(task.result, ensure_ascii = False) if task.result else None,
task.error_message, task.error_message,
task.retry_count, task.retry_count,
task.max_retries, task.max_retries,
@@ -1120,7 +1120,7 @@ class TaskQueue:
""", """,
( (
task.status, task.status,
json.dumps(task.result, ensure_ascii=False) if task.result else None, json.dumps(task.result, ensure_ascii = False) if task.result else None,
task.error_message, task.error_message,
task.retry_count, task.retry_count,
task.started_at, task.started_at,
@@ -1136,7 +1136,7 @@ class TaskQueue:
"""获取任务状态""" """获取任务状态"""
if self.use_celery: if self.use_celery:
try: try:
result = AsyncResult(task_id, app=self.celery_app) result = AsyncResult(task_id, app = self.celery_app)
status_map = { status_map = {
"PENDING": "pending", "PENDING": "pending",
@@ -1147,13 +1147,13 @@ class TaskQueue:
} }
return TaskInfo( return TaskInfo(
id=task_id, id = task_id,
task_type="celery_task", task_type = "celery_task",
status=status_map.get(result.status, "unknown"), status = status_map.get(result.status, "unknown"),
payload={}, payload = {},
created_at="", created_at = "",
result=result.result if result.successful() else None, result = result.result if result.successful() else None,
error_message=str(result.result) if result.failed() else None, error_message = str(result.result) if result.failed() else None,
) )
except Exception as e: except Exception as e:
print(f"获取 Celery 任务状态失败: {e}") print(f"获取 Celery 任务状态失败: {e}")
@@ -1180,7 +1180,7 @@ class TaskQueue:
where_clauses.append("task_type = ?") where_clauses.append("task_type = ?")
params.append(task_type) params.append(task_type)
where_str = " AND ".join(where_clauses) if where_clauses else "1=1" where_str = " AND ".join(where_clauses) if where_clauses else "1 = 1"
rows = conn.execute( rows = conn.execute(
f""" f"""
@@ -1198,17 +1198,17 @@ class TaskQueue:
for row in rows: for row in rows:
tasks.append( tasks.append(
TaskInfo( TaskInfo(
id=row["id"], id = row["id"],
task_type=row["task_type"], task_type = row["task_type"],
status=row["status"], status = row["status"],
payload=json.loads(row["payload"]) if row["payload"] else {}, payload = json.loads(row["payload"]) if row["payload"] else {},
created_at=row["created_at"], created_at = row["created_at"],
started_at=row["started_at"], started_at = row["started_at"],
completed_at=row["completed_at"], completed_at = row["completed_at"],
result=json.loads(row["result"]) if row["result"] else None, result = json.loads(row["result"]) if row["result"] else None,
error_message=row["error_message"], error_message = row["error_message"],
retry_count=row["retry_count"], retry_count = row["retry_count"],
max_retries=row["max_retries"], max_retries = row["max_retries"],
) )
) )
@@ -1218,7 +1218,7 @@ class TaskQueue:
"""取消任务""" """取消任务"""
if self.use_celery: if self.use_celery:
try: try:
self.celery_app.control.revoke(task_id, terminate=True) self.celery_app.control.revoke(task_id, terminate = True)
return True return True
except Exception as e: except Exception as e:
print(f"取消 Celery 任务失败: {e}") print(f"取消 Celery 任务失败: {e}")
@@ -1248,7 +1248,7 @@ class TaskQueue:
if not self.use_celery: if not self.use_celery:
with self.task_lock: with self.task_lock:
self.tasks[task_id] = task self.tasks[task_id] = task
threading.Thread(target=self._execute_task, args=(task_id,), daemon=True).start() threading.Thread(target = self._execute_task, args = (task_id, ), daemon = True).start()
self._update_task_status(task) self._update_task_status(task)
return True return True
@@ -1307,7 +1307,7 @@ class PerformanceMonitor:
db_path: str = "insightflow.db", db_path: str = "insightflow.db",
slow_query_threshold: int = 1000, slow_query_threshold: int = 1000,
alert_threshold: int = 5000, # 毫秒 alert_threshold: int = 5000, # 毫秒
): # 毫秒 ) -> None: # 毫秒
self.db_path = db_path self.db_path = db_path
self.slow_query_threshold = slow_query_threshold self.slow_query_threshold = slow_query_threshold
self.alert_threshold = alert_threshold self.alert_threshold = alert_threshold
@@ -1326,7 +1326,7 @@ class PerformanceMonitor:
duration_ms: float, duration_ms: float,
endpoint: str | None = None, endpoint: str | None = None,
metadata: dict | None = None, metadata: dict | None = None,
): ) -> None:
""" """
记录性能指标 记录性能指标
@@ -1337,12 +1337,12 @@ class PerformanceMonitor:
metadata: 额外元数据 metadata: 额外元数据
""" """
metric = PerformanceMetric( metric = PerformanceMetric(
id=str(uuid.uuid4())[:16], id = str(uuid.uuid4())[:16],
metric_type=metric_type, metric_type = metric_type,
endpoint=endpoint, endpoint = endpoint,
duration_ms=duration_ms, duration_ms = duration_ms,
timestamp=datetime.now().isoformat(), timestamp = datetime.now().isoformat(),
metadata=metadata or {}, metadata = metadata or {},
) )
# 添加到缓冲区 # 添加到缓冲区
@@ -1379,7 +1379,7 @@ class PerformanceMonitor:
metric.endpoint, metric.endpoint,
metric.duration_ms, metric.duration_ms,
metric.timestamp, metric.timestamp,
json.dumps(metric.metadata, ensure_ascii=False), json.dumps(metric.metadata, ensure_ascii = False),
), ),
) )
@@ -1439,7 +1439,7 @@ class PerformanceMonitor:
FROM performance_metrics FROM performance_metrics
WHERE timestamp > datetime('now', ?) WHERE timestamp > datetime('now', ?)
""", """,
(f"-{hours} hours",), (f"-{hours} hours", ),
).fetchone() ).fetchone()
# 按类型统计 # 按类型统计
@@ -1454,7 +1454,7 @@ class PerformanceMonitor:
WHERE timestamp > datetime('now', ?) WHERE timestamp > datetime('now', ?)
GROUP BY metric_type GROUP BY metric_type
""", """,
(f"-{hours} hours",), (f"-{hours} hours", ),
).fetchall() ).fetchall()
# 按端点统计API # 按端点统计API
@@ -1472,7 +1472,7 @@ class PerformanceMonitor:
ORDER BY avg_duration DESC ORDER BY avg_duration DESC
LIMIT 20 LIMIT 20
""", """,
(f"-{hours} hours",), (f"-{hours} hours", ),
).fetchall() ).fetchall()
# 慢查询统计 # 慢查询统计
@@ -1597,7 +1597,7 @@ class PerformanceMonitor:
DELETE FROM performance_metrics DELETE FROM performance_metrics
WHERE timestamp < datetime('now', ?) WHERE timestamp < datetime('now', ?)
""", """,
(f"-{days} days",), (f"-{days} days", ),
) )
deleted = cursor.rowcount deleted = cursor.rowcount
@@ -1668,7 +1668,7 @@ def monitored(monitor: PerformanceMonitor, metric_type: str, endpoint: str | Non
def decorator(func: Callable) -> Callable: def decorator(func: Callable) -> Callable:
@wraps(func) @wraps(func)
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs) -> None:
start_time = time.time() start_time = time.time()
try: try:
@@ -1699,17 +1699,17 @@ class PerformanceManager:
db_path: str = "insightflow.db", db_path: str = "insightflow.db",
redis_url: str | None = None, redis_url: str | None = None,
enable_sharding: bool = False, enable_sharding: bool = False,
): ) -> None:
self.db_path = db_path self.db_path = db_path
# 初始化各模块 # 初始化各模块
self.cache = CacheManager(redis_url=redis_url, db_path=db_path) self.cache = CacheManager(redis_url = redis_url, db_path = db_path)
self.sharding = DatabaseSharding(base_db_path=db_path) if enable_sharding else None self.sharding = DatabaseSharding(base_db_path = db_path) if enable_sharding else None
self.task_queue = TaskQueue(redis_url=redis_url, db_path=db_path) self.task_queue = TaskQueue(redis_url = redis_url, db_path = db_path)
self.monitor = PerformanceMonitor(db_path=db_path) self.monitor = PerformanceMonitor(db_path = db_path)
def get_health_status(self) -> dict: def get_health_status(self) -> dict:
"""获取系统健康状态""" """获取系统健康状态"""
@@ -1760,6 +1760,6 @@ def get_performance_manager(
global _performance_manager global _performance_manager
if _performance_manager is None: if _performance_manager is None:
_performance_manager = PerformanceManager( _performance_manager = PerformanceManager(
db_path=db_path, redis_url=redis_url, enable_sharding=enable_sharding db_path = db_path, redis_url = redis_url, enable_sharding = enable_sharding
) )
return _performance_manager return _performance_manager

View File

@@ -63,7 +63,7 @@ class Plugin:
plugin_type: str plugin_type: str
project_id: str project_id: str
status: str = "active" status: str = "active"
config: dict = field(default_factory=dict) config: dict = field(default_factory = dict)
created_at: str = "" created_at: str = ""
updated_at: str = "" updated_at: str = ""
last_used_at: str | None = None last_used_at: str | None = None
@@ -111,8 +111,8 @@ class WebhookEndpoint:
endpoint_url: str endpoint_url: str
project_id: str | None = None project_id: str | None = None
auth_type: str = "none" # none, api_key, oauth, custom auth_type: str = "none" # none, api_key, oauth, custom
auth_config: dict = field(default_factory=dict) auth_config: dict = field(default_factory = dict)
trigger_events: list[str] = field(default_factory=list) trigger_events: list[str] = field(default_factory = list)
is_active: bool = True is_active: bool = True
created_at: str = "" created_at: str = ""
updated_at: str = "" updated_at: str = ""
@@ -151,7 +151,7 @@ class ChromeExtensionToken:
user_id: str | None = None user_id: str | None = None
project_id: str | None = None project_id: str | None = None
name: str = "" name: str = ""
permissions: list[str] = field(default_factory=lambda: ["read", "write"]) permissions: list[str] = field(default_factory = lambda: ["read", "write"])
expires_at: str | None = None expires_at: str | None = None
created_at: str = "" created_at: str = ""
last_used_at: str | None = None last_used_at: str | None = None
@@ -162,7 +162,7 @@ class ChromeExtensionToken:
class PluginManager: class PluginManager:
"""插件管理主类""" """插件管理主类"""
def __init__(self, db_manager=None): def __init__(self, db_manager = None) -> None:
self.db = db_manager self.db = db_manager
self._handlers = {} self._handlers = {}
self._register_default_handlers() self._register_default_handlers()
@@ -213,7 +213,7 @@ class PluginManager:
def get_plugin(self, plugin_id: str) -> Plugin | None: def get_plugin(self, plugin_id: str) -> Plugin | None:
"""获取插件""" """获取插件"""
conn = self.db.get_conn() conn = self.db.get_conn()
row = conn.execute("SELECT * FROM plugins WHERE id = ?", (plugin_id,)).fetchone() row = conn.execute("SELECT * FROM plugins WHERE id = ?", (plugin_id, )).fetchone()
conn.close() conn.close()
if row: if row:
@@ -239,7 +239,7 @@ class PluginManager:
conditions.append("status = ?") conditions.append("status = ?")
params.append(status) params.append(status)
where_clause = " AND ".join(conditions) if conditions else "1=1" where_clause = " AND ".join(conditions) if conditions else "1 = 1"
rows = conn.execute( rows = conn.execute(
f"SELECT * FROM plugins WHERE {where_clause} ORDER BY created_at DESC", params f"SELECT * FROM plugins WHERE {where_clause} ORDER BY created_at DESC", params
@@ -284,10 +284,10 @@ class PluginManager:
conn = self.db.get_conn() conn = self.db.get_conn()
# 删除关联的配置 # 删除关联的配置
conn.execute("DELETE FROM plugin_configs WHERE plugin_id = ?", (plugin_id,)) conn.execute("DELETE FROM plugin_configs WHERE plugin_id = ?", (plugin_id, ))
# 删除插件 # 删除插件
cursor = conn.execute("DELETE FROM plugins WHERE id = ?", (plugin_id,)) cursor = conn.execute("DELETE FROM plugins WHERE id = ?", (plugin_id, ))
conn.commit() conn.commit()
conn.close() conn.close()
@@ -296,16 +296,16 @@ class PluginManager:
def _row_to_plugin(self, row: sqlite3.Row) -> Plugin: def _row_to_plugin(self, row: sqlite3.Row) -> Plugin:
"""将数据库行转换为 Plugin 对象""" """将数据库行转换为 Plugin 对象"""
return Plugin( return Plugin(
id=row["id"], id = row["id"],
name=row["name"], name = row["name"],
plugin_type=row["plugin_type"], plugin_type = row["plugin_type"],
project_id=row["project_id"], project_id = row["project_id"],
status=row["status"], status = row["status"],
config=json.loads(row["config"]) if row["config"] else {}, config = json.loads(row["config"]) if row["config"] else {},
created_at=row["created_at"], created_at = row["created_at"],
updated_at=row["updated_at"], updated_at = row["updated_at"],
last_used_at=row["last_used_at"], last_used_at = row["last_used_at"],
use_count=row["use_count"], use_count = row["use_count"],
) )
# ==================== Plugin Config ==================== # ==================== Plugin Config ====================
@@ -343,13 +343,13 @@ class PluginManager:
conn.close() conn.close()
return PluginConfig( return PluginConfig(
id=config_id, id = config_id,
plugin_id=plugin_id, plugin_id = plugin_id,
config_key=key, config_key = key,
config_value=value, config_value = value,
is_encrypted=is_encrypted, is_encrypted = is_encrypted,
created_at=now, created_at = now,
updated_at=now, updated_at = now,
) )
def get_plugin_config(self, plugin_id: str, key: str) -> str | None: def get_plugin_config(self, plugin_id: str, key: str) -> str | None:
@@ -367,7 +367,7 @@ class PluginManager:
"""获取插件所有配置""" """获取插件所有配置"""
conn = self.db.get_conn() conn = self.db.get_conn()
rows = conn.execute( rows = conn.execute(
"SELECT config_key, config_value FROM plugin_configs WHERE plugin_id = ?", (plugin_id,) "SELECT config_key, config_value FROM plugin_configs WHERE plugin_id = ?", (plugin_id, )
).fetchall() ).fetchall()
conn.close() conn.close()
@@ -402,7 +402,7 @@ class PluginManager:
class ChromeExtensionHandler: class ChromeExtensionHandler:
"""Chrome 扩展处理器""" """Chrome 扩展处理器"""
def __init__(self, plugin_manager: PluginManager): def __init__(self, plugin_manager: PluginManager) -> None:
self.pm = plugin_manager self.pm = plugin_manager
def create_token( def create_token(
@@ -417,7 +417,7 @@ class ChromeExtensionHandler:
token_id = str(uuid.uuid4())[:UUID_LENGTH] token_id = str(uuid.uuid4())[:UUID_LENGTH]
# 生成随机令牌 # 生成随机令牌
raw_token = f"if_ext_{base64.urlsafe_b64encode(os.urandom(32)).decode('utf-8').rstrip('=')}" raw_token = f"if_ext_{base64.urlsafe_b64encode(os.urandom(32)).decode('utf-8').rstrip(' = ')}"
# 哈希存储 # 哈希存储
token_hash = hashlib.sha256(raw_token.encode()).hexdigest() token_hash = hashlib.sha256(raw_token.encode()).hexdigest()
@@ -427,7 +427,7 @@ class ChromeExtensionHandler:
if expires_days: if expires_days:
from datetime import timedelta from datetime import timedelta
expires_at = (datetime.now() + timedelta(days=expires_days)).isoformat() expires_at = (datetime.now() + timedelta(days = expires_days)).isoformat()
conn = self.pm.db.get_conn() conn = self.pm.db.get_conn()
conn.execute( conn.execute(
@@ -452,14 +452,14 @@ class ChromeExtensionHandler:
conn.close() conn.close()
return ChromeExtensionToken( return ChromeExtensionToken(
id=token_id, id = token_id,
token=raw_token, # 仅返回一次 token = raw_token, # 仅返回一次
user_id=user_id, user_id = user_id,
project_id=project_id, project_id = project_id,
name=name, name = name,
permissions=permissions or ["read"], permissions = permissions or ["read"],
expires_at=expires_at, expires_at = expires_at,
created_at=now, created_at = now,
) )
def validate_token(self, token: str) -> ChromeExtensionToken | None: def validate_token(self, token: str) -> ChromeExtensionToken | None:
@@ -470,7 +470,7 @@ class ChromeExtensionHandler:
row = conn.execute( row = conn.execute(
"""SELECT * FROM chrome_extension_tokens """SELECT * FROM chrome_extension_tokens
WHERE token_hash = ? AND is_revoked = 0""", WHERE token_hash = ? AND is_revoked = 0""",
(token_hash,), (token_hash, ),
).fetchone() ).fetchone()
conn.close() conn.close()
@@ -494,23 +494,23 @@ class ChromeExtensionHandler:
conn.close() conn.close()
return ChromeExtensionToken( return ChromeExtensionToken(
id=row["id"], id = row["id"],
token="", # 不返回实际令牌 token = "", # 不返回实际令牌
user_id=row["user_id"], user_id = row["user_id"],
project_id=row["project_id"], project_id = row["project_id"],
name=row["name"], name = row["name"],
permissions=json.loads(row["permissions"]), permissions = json.loads(row["permissions"]),
expires_at=row["expires_at"], expires_at = row["expires_at"],
created_at=row["created_at"], created_at = row["created_at"],
last_used_at=now, last_used_at = now,
use_count=row["use_count"] + 1, use_count = row["use_count"] + 1,
) )
def revoke_token(self, token_id: str) -> bool: def revoke_token(self, token_id: str) -> bool:
"""撤销令牌""" """撤销令牌"""
conn = self.pm.db.get_conn() conn = self.pm.db.get_conn()
cursor = conn.execute( cursor = conn.execute(
"UPDATE chrome_extension_tokens SET is_revoked = 1 WHERE id = ?", (token_id,) "UPDATE chrome_extension_tokens SET is_revoked = 1 WHERE id = ?", (token_id, )
) )
conn.commit() conn.commit()
conn.close() conn.close()
@@ -545,17 +545,17 @@ class ChromeExtensionHandler:
for row in rows: for row in rows:
tokens.append( tokens.append(
ChromeExtensionToken( ChromeExtensionToken(
id=row["id"], id = row["id"],
token="", # 不返回实际令牌 token = "", # 不返回实际令牌
user_id=row["user_id"], user_id = row["user_id"],
project_id=row["project_id"], project_id = row["project_id"],
name=row["name"], name = row["name"],
permissions=json.loads(row["permissions"]), permissions = json.loads(row["permissions"]),
expires_at=row["expires_at"], expires_at = row["expires_at"],
created_at=row["created_at"], created_at = row["created_at"],
last_used_at=row["last_used_at"], last_used_at = row["last_used_at"],
use_count=row["use_count"], use_count = row["use_count"],
is_revoked=bool(row["is_revoked"]), is_revoked = bool(row["is_revoked"]),
) )
) )
@@ -606,7 +606,7 @@ class ChromeExtensionHandler:
class BotHandler: class BotHandler:
"""飞书/钉钉机器人处理器""" """飞书/钉钉机器人处理器"""
def __init__(self, plugin_manager: PluginManager, bot_type: str): def __init__(self, plugin_manager: PluginManager, bot_type: str) -> None:
self.pm = plugin_manager self.pm = plugin_manager
self.bot_type = bot_type self.bot_type = bot_type
@@ -646,16 +646,16 @@ class BotHandler:
conn.close() conn.close()
return BotSession( return BotSession(
id=bot_id, id = bot_id,
bot_type=self.bot_type, bot_type = self.bot_type,
session_id=session_id, session_id = session_id,
session_name=session_name, session_name = session_name,
project_id=project_id, project_id = project_id,
webhook_url=webhook_url, webhook_url = webhook_url,
secret=secret, secret = secret,
is_active=True, is_active = True,
created_at=now, created_at = now,
updated_at=now, updated_at = now,
) )
def get_session(self, session_id: str) -> BotSession | None: def get_session(self, session_id: str) -> BotSession | None:
@@ -686,7 +686,7 @@ class BotHandler:
rows = conn.execute( rows = conn.execute(
"""SELECT * FROM bot_sessions """SELECT * FROM bot_sessions
WHERE bot_type = ? ORDER BY created_at DESC""", WHERE bot_type = ? ORDER BY created_at DESC""",
(self.bot_type,), (self.bot_type, ),
).fetchall() ).fetchall()
conn.close() conn.close()
@@ -739,18 +739,18 @@ class BotHandler:
def _row_to_session(self, row: sqlite3.Row) -> BotSession: def _row_to_session(self, row: sqlite3.Row) -> BotSession:
"""将数据库行转换为 BotSession 对象""" """将数据库行转换为 BotSession 对象"""
return BotSession( return BotSession(
id=row["id"], id = row["id"],
bot_type=row["bot_type"], bot_type = row["bot_type"],
session_id=row["session_id"], session_id = row["session_id"],
session_name=row["session_name"], session_name = row["session_name"],
project_id=row["project_id"], project_id = row["project_id"],
webhook_url=row["webhook_url"], webhook_url = row["webhook_url"],
secret=row["secret"], secret = row["secret"],
is_active=bool(row["is_active"]), is_active = bool(row["is_active"]),
created_at=row["created_at"], created_at = row["created_at"],
updated_at=row["updated_at"], updated_at = row["updated_at"],
last_message_at=row["last_message_at"], last_message_at = row["last_message_at"],
message_count=row["message_count"], message_count = row["message_count"],
) )
async def handle_message(self, session: BotSession, message: dict) -> dict: async def handle_message(self, session: BotSession, message: dict) -> dict:
@@ -880,7 +880,7 @@ class BotHandler:
hmac_code = hmac.new( hmac_code = hmac.new(
session.secret.encode("utf-8"), session.secret.encode("utf-8"),
string_to_sign.encode("utf-8"), string_to_sign.encode("utf-8"),
digestmod=hashlib.sha256, digestmod = hashlib.sha256,
).digest() ).digest()
sign = base64.b64encode(hmac_code).decode("utf-8") sign = base64.b64encode(hmac_code).decode("utf-8")
else: else:
@@ -895,7 +895,7 @@ class BotHandler:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.post( response = await client.post(
session.webhook_url, json=payload, headers={"Content-Type": "application/json"} session.webhook_url, json = payload, headers = {"Content-Type": "application/json"}
) )
return response.status_code == 200 return response.status_code == 200
@@ -911,7 +911,7 @@ class BotHandler:
hmac_code = hmac.new( hmac_code = hmac.new(
session.secret.encode("utf-8"), session.secret.encode("utf-8"),
string_to_sign.encode("utf-8"), string_to_sign.encode("utf-8"),
digestmod=hashlib.sha256, digestmod = hashlib.sha256,
).digest() ).digest()
sign = base64.b64encode(hmac_code).decode("utf-8") sign = base64.b64encode(hmac_code).decode("utf-8")
sign = urllib.parse.quote(sign) sign = urllib.parse.quote(sign)
@@ -922,11 +922,11 @@ class BotHandler:
url = session.webhook_url url = session.webhook_url
if sign: if sign:
url = f"{url}&timestamp={timestamp}&sign={sign}" url = f"{url}&timestamp = {timestamp}&sign = {sign}"
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.post( response = await client.post(
url, json=payload, headers={"Content-Type": "application/json"} url, json = payload, headers = {"Content-Type": "application/json"}
) )
return response.status_code == 200 return response.status_code == 200
@@ -934,7 +934,7 @@ class BotHandler:
class WebhookIntegration: class WebhookIntegration:
"""Zapier/Make Webhook 集成""" """Zapier/Make Webhook 集成"""
def __init__(self, plugin_manager: PluginManager, endpoint_type: str): def __init__(self, plugin_manager: PluginManager, endpoint_type: str) -> None:
self.pm = plugin_manager self.pm = plugin_manager
self.endpoint_type = endpoint_type self.endpoint_type = endpoint_type
@@ -976,17 +976,17 @@ class WebhookIntegration:
conn.close() conn.close()
return WebhookEndpoint( return WebhookEndpoint(
id=endpoint_id, id = endpoint_id,
name=name, name = name,
endpoint_type=self.endpoint_type, endpoint_type = self.endpoint_type,
endpoint_url=endpoint_url, endpoint_url = endpoint_url,
project_id=project_id, project_id = project_id,
auth_type=auth_type, auth_type = auth_type,
auth_config=auth_config or {}, auth_config = auth_config or {},
trigger_events=trigger_events or [], trigger_events = trigger_events or [],
is_active=True, is_active = True,
created_at=now, created_at = now,
updated_at=now, updated_at = now,
) )
def get_endpoint(self, endpoint_id: str) -> WebhookEndpoint | None: def get_endpoint(self, endpoint_id: str) -> WebhookEndpoint | None:
@@ -1016,7 +1016,7 @@ class WebhookIntegration:
rows = conn.execute( rows = conn.execute(
"""SELECT * FROM webhook_endpoints """SELECT * FROM webhook_endpoints
WHERE endpoint_type = ? ORDER BY created_at DESC""", WHERE endpoint_type = ? ORDER BY created_at DESC""",
(self.endpoint_type,), (self.endpoint_type, ),
).fetchall() ).fetchall()
conn.close() conn.close()
@@ -1065,7 +1065,7 @@ class WebhookIntegration:
def delete_endpoint(self, endpoint_id: str) -> bool: def delete_endpoint(self, endpoint_id: str) -> bool:
"""删除端点""" """删除端点"""
conn = self.pm.db.get_conn() conn = self.pm.db.get_conn()
cursor = conn.execute("DELETE FROM webhook_endpoints WHERE id = ?", (endpoint_id,)) cursor = conn.execute("DELETE FROM webhook_endpoints WHERE id = ?", (endpoint_id, ))
conn.commit() conn.commit()
conn.close() conn.close()
@@ -1074,19 +1074,19 @@ class WebhookIntegration:
def _row_to_endpoint(self, row: sqlite3.Row) -> WebhookEndpoint: def _row_to_endpoint(self, row: sqlite3.Row) -> WebhookEndpoint:
"""将数据库行转换为 WebhookEndpoint 对象""" """将数据库行转换为 WebhookEndpoint 对象"""
return WebhookEndpoint( return WebhookEndpoint(
id=row["id"], id = row["id"],
name=row["name"], name = row["name"],
endpoint_type=row["endpoint_type"], endpoint_type = row["endpoint_type"],
endpoint_url=row["endpoint_url"], endpoint_url = row["endpoint_url"],
project_id=row["project_id"], project_id = row["project_id"],
auth_type=row["auth_type"], auth_type = row["auth_type"],
auth_config=json.loads(row["auth_config"]) if row["auth_config"] else {}, auth_config = json.loads(row["auth_config"]) if row["auth_config"] else {},
trigger_events=json.loads(row["trigger_events"]) if row["trigger_events"] else [], trigger_events = json.loads(row["trigger_events"]) if row["trigger_events"] else [],
is_active=bool(row["is_active"]), is_active = bool(row["is_active"]),
created_at=row["created_at"], created_at = row["created_at"],
updated_at=row["updated_at"], updated_at = row["updated_at"],
last_triggered_at=row["last_triggered_at"], last_triggered_at = row["last_triggered_at"],
trigger_count=row["trigger_count"], trigger_count = row["trigger_count"],
) )
async def trigger(self, endpoint: WebhookEndpoint, event_type: str, data: dict) -> bool: async def trigger(self, endpoint: WebhookEndpoint, event_type: str, data: dict) -> bool:
@@ -1113,7 +1113,7 @@ class WebhookIntegration:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.post( response = await client.post(
endpoint.endpoint_url, json=payload, headers=headers, timeout=30.0 endpoint.endpoint_url, json = payload, headers = headers, timeout = 30.0
) )
success = response.status_code in [200, 201, 202] success = response.status_code in [200, 201, 202]
@@ -1157,7 +1157,7 @@ class WebhookIntegration:
class WebDAVSyncManager: class WebDAVSyncManager:
"""WebDAV 同步管理""" """WebDAV 同步管理"""
def __init__(self, plugin_manager: PluginManager): def __init__(self, plugin_manager: PluginManager) -> None:
self.pm = plugin_manager self.pm = plugin_manager
def create_sync( def create_sync(
@@ -1202,25 +1202,25 @@ class WebDAVSyncManager:
conn.close() conn.close()
return WebDAVSync( return WebDAVSync(
id=sync_id, id = sync_id,
name=name, name = name,
project_id=project_id, project_id = project_id,
server_url=server_url, server_url = server_url,
username=username, username = username,
password=password, password = password,
remote_path=remote_path, remote_path = remote_path,
sync_mode=sync_mode, sync_mode = sync_mode,
sync_interval=sync_interval, sync_interval = sync_interval,
last_sync_status="pending", last_sync_status = "pending",
is_active=True, is_active = True,
created_at=now, created_at = now,
updated_at=now, updated_at = now,
) )
def get_sync(self, sync_id: str) -> WebDAVSync | None: def get_sync(self, sync_id: str) -> WebDAVSync | None:
"""获取同步配置""" """获取同步配置"""
conn = self.pm.db.get_conn() conn = self.pm.db.get_conn()
row = conn.execute("SELECT * FROM webdav_syncs WHERE id = ?", (sync_id,)).fetchone() row = conn.execute("SELECT * FROM webdav_syncs WHERE id = ?", (sync_id, )).fetchone()
conn.close() conn.close()
if row: if row:
@@ -1234,7 +1234,7 @@ class WebDAVSyncManager:
if project_id: if project_id:
rows = conn.execute( rows = conn.execute(
"SELECT * FROM webdav_syncs WHERE project_id = ? ORDER BY created_at DESC", "SELECT * FROM webdav_syncs WHERE project_id = ? ORDER BY created_at DESC",
(project_id,), (project_id, ),
).fetchall() ).fetchall()
else: else:
rows = conn.execute("SELECT * FROM webdav_syncs ORDER BY created_at DESC").fetchall() rows = conn.execute("SELECT * FROM webdav_syncs ORDER BY created_at DESC").fetchall()
@@ -1283,7 +1283,7 @@ class WebDAVSyncManager:
def delete_sync(self, sync_id: str) -> bool: def delete_sync(self, sync_id: str) -> bool:
"""删除同步配置""" """删除同步配置"""
conn = self.pm.db.get_conn() conn = self.pm.db.get_conn()
cursor = conn.execute("DELETE FROM webdav_syncs WHERE id = ?", (sync_id,)) cursor = conn.execute("DELETE FROM webdav_syncs WHERE id = ?", (sync_id, ))
conn.commit() conn.commit()
conn.close() conn.close()
@@ -1292,22 +1292,22 @@ class WebDAVSyncManager:
def _row_to_sync(self, row: sqlite3.Row) -> WebDAVSync: def _row_to_sync(self, row: sqlite3.Row) -> WebDAVSync:
"""将数据库行转换为 WebDAVSync 对象""" """将数据库行转换为 WebDAVSync 对象"""
return WebDAVSync( return WebDAVSync(
id=row["id"], id = row["id"],
name=row["name"], name = row["name"],
project_id=row["project_id"], project_id = row["project_id"],
server_url=row["server_url"], server_url = row["server_url"],
username=row["username"], username = row["username"],
password=row["password"], password = row["password"],
remote_path=row["remote_path"], remote_path = row["remote_path"],
sync_mode=row["sync_mode"], sync_mode = row["sync_mode"],
sync_interval=row["sync_interval"], sync_interval = row["sync_interval"],
last_sync_at=row["last_sync_at"], last_sync_at = row["last_sync_at"],
last_sync_status=row["last_sync_status"], last_sync_status = row["last_sync_status"],
last_sync_error=row["last_sync_error"] or "", last_sync_error = row["last_sync_error"] or "",
is_active=bool(row["is_active"]), is_active = bool(row["is_active"]),
created_at=row["created_at"], created_at = row["created_at"],
updated_at=row["updated_at"], updated_at = row["updated_at"],
sync_count=row["sync_count"], sync_count = row["sync_count"],
) )
async def test_connection(self, sync: WebDAVSync) -> dict: async def test_connection(self, sync: WebDAVSync) -> dict:
@@ -1316,7 +1316,7 @@ class WebDAVSyncManager:
return {"success": False, "error": "WebDAV library not available"} return {"success": False, "error": "WebDAV library not available"}
try: try:
client = webdav_client.Client(sync.server_url, auth=(sync.username, sync.password)) client = webdav_client.Client(sync.server_url, auth = (sync.username, sync.password))
# 尝试列出根目录 # 尝试列出根目录
client.list("/") client.list("/")
@@ -1335,7 +1335,7 @@ class WebDAVSyncManager:
return {"success": False, "error": "Sync is not active"} return {"success": False, "error": "Sync is not active"}
try: try:
client = webdav_client.Client(sync.server_url, auth=(sync.username, sync.password)) client = webdav_client.Client(sync.server_url, auth = (sync.username, sync.password))
# 确保远程目录存在 # 确保远程目录存在
remote_project_path = f"{sync.remote_path}/{sync.project_id}" remote_project_path = f"{sync.remote_path}/{sync.project_id}"
@@ -1367,13 +1367,13 @@ class WebDAVSyncManager:
} }
# 上传 JSON 文件 # 上传 JSON 文件
json_content = json.dumps(export_data, ensure_ascii=False, indent=2) json_content = json.dumps(export_data, ensure_ascii = False, indent = 2)
json_path = f"{remote_project_path}/project_export.json" json_path = f"{remote_project_path}/project_export.json"
# 使用临时文件上传 # 使用临时文件上传
import tempfile import tempfile
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: with tempfile.NamedTemporaryFile(mode = "w", suffix = ".json", delete = False) as f:
f.write(json_content) f.write(json_content)
temp_path = f.name temp_path = f.name
@@ -1419,7 +1419,7 @@ class WebDAVSyncManager:
_plugin_manager = None _plugin_manager = None
def get_plugin_manager(db_manager=None) -> None: def get_plugin_manager(db_manager = None) -> None:
"""获取 PluginManager 单例""" """获取 PluginManager 单例"""
global _plugin_manager global _plugin_manager
if _plugin_manager is None: if _plugin_manager is None:

View File

@@ -35,7 +35,7 @@ class RateLimitInfo:
class SlidingWindowCounter: class SlidingWindowCounter:
"""滑动窗口计数器""" """滑动窗口计数器"""
def __init__(self, window_size: int = 60): def __init__(self, window_size: int = 60) -> None:
self.window_size = window_size self.window_size = window_size
self.requests: dict[int, int] = defaultdict(int) # 秒级计数 self.requests: dict[int, int] = defaultdict(int) # 秒级计数
self._lock = asyncio.Lock() self._lock = asyncio.Lock()
@@ -110,17 +110,17 @@ class RateLimiter:
# 检查是否超过限制 # 检查是否超过限制
if current_count >= stored_config.requests_per_minute: if current_count >= stored_config.requests_per_minute:
return RateLimitInfo( return RateLimitInfo(
allowed=False, allowed = False,
remaining=0, remaining = 0,
reset_time=reset_time, reset_time = reset_time,
retry_after=stored_config.window_size, retry_after = stored_config.window_size,
) )
# 允许请求,增加计数 # 允许请求,增加计数
await counter.add_request() await counter.add_request()
return RateLimitInfo( return RateLimitInfo(
allowed=True, remaining=remaining - 1, reset_time=reset_time, retry_after=0 allowed = True, remaining = remaining - 1, reset_time = reset_time, retry_after = 0
) )
async def get_limit_info(self, key: str) -> RateLimitInfo: async def get_limit_info(self, key: str) -> RateLimitInfo:
@@ -128,10 +128,10 @@ class RateLimiter:
if key not in self.counters: if key not in self.counters:
config = RateLimitConfig() config = RateLimitConfig()
return RateLimitInfo( return RateLimitInfo(
allowed=True, allowed = True,
remaining=config.requests_per_minute, remaining = config.requests_per_minute,
reset_time=int(time.time()) + config.window_size, reset_time = int(time.time()) + config.window_size,
retry_after=0, retry_after = 0,
) )
counter = self.counters[key] counter = self.counters[key]
@@ -142,10 +142,10 @@ class RateLimiter:
reset_time = int(time.time()) + config.window_size reset_time = int(time.time()) + config.window_size
return RateLimitInfo( return RateLimitInfo(
allowed=current_count < config.requests_per_minute, allowed = current_count < config.requests_per_minute,
remaining=remaining, remaining = remaining,
reset_time=reset_time, reset_time = reset_time,
retry_after=max(0, config.window_size) retry_after = max(0, config.window_size)
if current_count >= config.requests_per_minute if current_count >= config.requests_per_minute
else 0, else 0,
) )
@@ -184,12 +184,12 @@ def rate_limit(requests_per_minute: int = 60, key_func: Callable | None = None)
key_func: 生成限流键的函数,默认为 None使用函数名 key_func: 生成限流键的函数,默认为 None使用函数名
""" """
def decorator(func): def decorator(func) -> None:
limiter = get_rate_limiter() limiter = get_rate_limiter()
config = RateLimitConfig(requests_per_minute=requests_per_minute) config = RateLimitConfig(requests_per_minute = requests_per_minute)
@wraps(func) @wraps(func)
async def async_wrapper(*args, **kwargs): async def async_wrapper(*args, **kwargs) -> None:
key = key_func(*args, **kwargs) if key_func else func.__name__ key = key_func(*args, **kwargs) if key_func else func.__name__
info = await limiter.is_allowed(key, config) info = await limiter.is_allowed(key, config)
@@ -201,7 +201,7 @@ def rate_limit(requests_per_minute: int = 60, key_func: Callable | None = None)
return await func(*args, **kwargs) return await func(*args, **kwargs)
@wraps(func) @wraps(func)
def sync_wrapper(*args, **kwargs): def sync_wrapper(*args, **kwargs) -> None:
key = key_func(*args, **kwargs) if key_func else func.__name__ key = key_func(*args, **kwargs) if key_func else func.__name__
# 同步版本使用 asyncio.run # 同步版本使用 asyncio.run
info = asyncio.run(limiter.is_allowed(key, config)) info = asyncio.run(limiter.is_allowed(key, config))

View File

@@ -49,8 +49,8 @@ class SearchResult:
content_type: str # transcript, entity, relation content_type: str # transcript, entity, relation
project_id: str project_id: str
score: float score: float
highlights: list[tuple[int, int]] = field(default_factory=list) # 高亮位置 highlights: list[tuple[int, int]] = field(default_factory = list) # 高亮位置
metadata: dict = field(default_factory=dict) metadata: dict = field(default_factory = dict)
def to_dict(self) -> dict: def to_dict(self) -> dict:
return { return {
@@ -74,7 +74,7 @@ class SemanticSearchResult:
project_id: str project_id: str
similarity: float similarity: float
embedding: list[float] | None = None embedding: list[float] | None = None
metadata: dict = field(default_factory=dict) metadata: dict = field(default_factory = dict)
def to_dict(self) -> dict: def to_dict(self) -> dict:
result = { result = {
@@ -132,7 +132,7 @@ class KnowledgeGap:
severity: str # high, medium, low severity: str # high, medium, low
suggestions: list[str] suggestions: list[str]
related_entities: list[str] related_entities: list[str]
metadata: dict = field(default_factory=dict) metadata: dict = field(default_factory = dict)
def to_dict(self) -> dict: def to_dict(self) -> dict:
return { return {
@@ -189,7 +189,7 @@ class FullTextSearch:
- 支持布尔搜索AND/OR/NOT - 支持布尔搜索AND/OR/NOT
""" """
def __init__(self, db_path: str = "insightflow.db"): def __init__(self, db_path: str = "insightflow.db") -> None:
self.db_path = db_path self.db_path = db_path
self._init_search_tables() self._init_search_tables()
@@ -318,8 +318,8 @@ class FullTextSearch:
content_id, content_id,
content_type, content_type,
project_id, project_id,
json.dumps(tokens, ensure_ascii=False), json.dumps(tokens, ensure_ascii = False),
json.dumps(token_positions, ensure_ascii=False), json.dumps(token_positions, ensure_ascii = False),
now, now,
now, now,
), ),
@@ -340,7 +340,7 @@ class FullTextSearch:
content_type, content_type,
project_id, project_id,
freq, freq,
json.dumps(positions, ensure_ascii=False), json.dumps(positions, ensure_ascii = False),
), ),
) )
@@ -383,7 +383,7 @@ class FullTextSearch:
scored_results = self._score_results(results, parsed_query) scored_results = self._score_results(results, parsed_query)
# 排序和分页 # 排序和分页
scored_results.sort(key=lambda x: x.score, reverse=True) scored_results.sort(key = lambda x: x.score, reverse = True)
return scored_results[offset : offset + limit] return scored_results[offset : offset + limit]
@@ -412,10 +412,10 @@ class FullTextSearch:
not_pattern = r"(?:NOT\s+|\-)(\w+)" not_pattern = r"(?:NOT\s+|\-)(\w+)"
not_matches = re.findall(not_pattern, query_without_phrases, re.IGNORECASE) not_matches = re.findall(not_pattern, query_without_phrases, re.IGNORECASE)
not_terms.extend(not_matches) not_terms.extend(not_matches)
query_without_phrases = re.sub(not_pattern, "", query_without_phrases, flags=re.IGNORECASE) query_without_phrases = re.sub(not_pattern, "", query_without_phrases, flags = re.IGNORECASE)
# 处理 OR # 处理 OR
or_parts = re.split(r"\s+OR\s+", query_without_phrases, flags=re.IGNORECASE) or_parts = re.split(r"\s+OR\s+", query_without_phrases, flags = re.IGNORECASE)
if len(or_parts) > 1: if len(or_parts) > 1:
or_terms = [p.strip() for p in or_parts[1:] if p.strip()] or_terms = [p.strip() for p in or_parts[1:] if p.strip()]
query_without_phrases = or_parts[0] query_without_phrases = or_parts[0]
@@ -443,11 +443,11 @@ class FullTextSearch:
params.append(project_id) params.append(project_id)
if content_types: if content_types:
placeholders = ",".join(["?" for _ in content_types]) placeholders = ", ".join(["?" for _ in content_types])
base_where.append(f"content_type IN ({placeholders})") base_where.append(f"content_type IN ({placeholders})")
params.extend(content_types) params.extend(content_types)
base_where_str = " AND ".join(base_where) if base_where else "1=1" base_where_str = " AND ".join(base_where) if base_where else "1 = 1"
# 获取候选结果 # 获取候选结果
candidates = set() candidates = set()
@@ -551,13 +551,13 @@ class FullTextSearch:
try: try:
if content_type == "transcript": if content_type == "transcript":
row = conn.execute( row = conn.execute(
"SELECT full_text FROM transcripts WHERE id = ?", (content_id,) "SELECT full_text FROM transcripts WHERE id = ?", (content_id, )
).fetchone() ).fetchone()
return row["full_text"] if row else None return row["full_text"] if row else None
elif content_type == "entity": elif content_type == "entity":
row = conn.execute( row = conn.execute(
"SELECT name, definition FROM entities WHERE id = ?", (content_id,) "SELECT name, definition FROM entities WHERE id = ?", (content_id, )
).fetchone() ).fetchone()
if row: if row:
return f"{row['name']} {row['definition'] or ''}" return f"{row['name']} {row['definition'] or ''}"
@@ -571,7 +571,7 @@ class FullTextSearch:
JOIN entities e1 ON r.source_entity_id = e1.id JOIN entities e1 ON r.source_entity_id = e1.id
JOIN entities e2 ON r.target_entity_id = e2.id JOIN entities e2 ON r.target_entity_id = e2.id
WHERE r.id = ?""", WHERE r.id = ?""",
(content_id,), (content_id, ),
).fetchone() ).fetchone()
if row: if row:
return f"{row['source_name']} {row['relation_type']} {row['target_name']} {row['evidence'] or ''}" return f"{row['source_name']} {row['relation_type']} {row['target_name']} {row['evidence'] or ''}"
@@ -589,15 +589,15 @@ class FullTextSearch:
try: try:
if content_type == "transcript": if content_type == "transcript":
row = conn.execute( row = conn.execute(
"SELECT project_id FROM transcripts WHERE id = ?", (content_id,) "SELECT project_id FROM transcripts WHERE id = ?", (content_id, )
).fetchone() ).fetchone()
elif content_type == "entity": elif content_type == "entity":
row = conn.execute( row = conn.execute(
"SELECT project_id FROM entities WHERE id = ?", (content_id,) "SELECT project_id FROM entities WHERE id = ?", (content_id, )
).fetchone() ).fetchone()
elif content_type == "relation": elif content_type == "relation":
row = conn.execute( row = conn.execute(
"SELECT project_id FROM entity_relations WHERE id = ?", (content_id,) "SELECT project_id FROM entity_relations WHERE id = ?", (content_id, )
).fetchone() ).fetchone()
else: else:
return None return None
@@ -654,13 +654,13 @@ class FullTextSearch:
scored.append( scored.append(
SearchResult( SearchResult(
id=result["id"], id = result["id"],
content=result["content"], content = result["content"],
content_type=result["content_type"], content_type = result["content_type"],
project_id=result["project_id"], project_id = result["project_id"],
score=round(score, 4), score = round(score, 4),
highlights=highlights[:10], # 限制高亮数量 highlights = highlights[:10], # 限制高亮数量
metadata={}, metadata = {},
) )
) )
@@ -699,7 +699,7 @@ class FullTextSearch:
snippet = snippet + "..." snippet = snippet + "..."
# 添加高亮标记 # 添加高亮标记
for term in sorted(all_terms, key=len, reverse=True): # 长的先替换 for term in sorted(all_terms, key = len, reverse = True): # 长的先替换
pattern = re.compile(re.escape(term), re.IGNORECASE) pattern = re.compile(re.escape(term), re.IGNORECASE)
snippet = pattern.sub(f"**{term}**", snippet) snippet = pattern.sub(f"**{term}**", snippet)
@@ -738,7 +738,7 @@ class FullTextSearch:
# 索引转录文本 # 索引转录文本
transcripts = conn.execute( transcripts = conn.execute(
"SELECT id, project_id, full_text FROM transcripts WHERE project_id = ?", "SELECT id, project_id, full_text FROM transcripts WHERE project_id = ?",
(project_id,), (project_id, ),
).fetchall() ).fetchall()
for t in transcripts: for t in transcripts:
@@ -751,7 +751,7 @@ class FullTextSearch:
# 索引实体 # 索引实体
entities = conn.execute( entities = conn.execute(
"SELECT id, project_id, name, definition FROM entities WHERE project_id = ?", "SELECT id, project_id, name, definition FROM entities WHERE project_id = ?",
(project_id,), (project_id, ),
).fetchall() ).fetchall()
for e in entities: for e in entities:
@@ -769,7 +769,7 @@ class FullTextSearch:
JOIN entities e1 ON r.source_entity_id = e1.id JOIN entities e1 ON r.source_entity_id = e1.id
JOIN entities e2 ON r.target_entity_id = e2.id JOIN entities e2 ON r.target_entity_id = e2.id
WHERE r.project_id = ?""", WHERE r.project_id = ?""",
(project_id,), (project_id, ),
).fetchall() ).fetchall()
for r in relations: for r in relations:
@@ -805,7 +805,7 @@ class SemanticSearch:
self, self,
db_path: str = "insightflow.db", db_path: str = "insightflow.db",
model_name: str = "paraphrase-multilingual-MiniLM-L12-v2", model_name: str = "paraphrase-multilingual-MiniLM-L12-v2",
): ) -> None:
self.db_path = db_path self.db_path = db_path
self.model_name = model_name self.model_name = model_name
self.model = None self.model = None
@@ -873,7 +873,7 @@ class SemanticSearch:
if len(text) > max_chars: if len(text) > max_chars:
text = text[:max_chars] text = text[:max_chars]
embedding = self.model.encode(text, convert_to_list=True) embedding = self.model.encode(text, convert_to_list = True)
return embedding return embedding
except Exception as e: except Exception as e:
print(f"生成 embedding 失败: {e}") print(f"生成 embedding 失败: {e}")
@@ -971,11 +971,11 @@ class SemanticSearch:
params.append(project_id) params.append(project_id)
if content_types: if content_types:
placeholders = ",".join(["?" for _ in content_types]) placeholders = ", ".join(["?" for _ in content_types])
where_clauses.append(f"content_type IN ({placeholders})") where_clauses.append(f"content_type IN ({placeholders})")
params.extend(content_types) params.extend(content_types)
where_str = " AND ".join(where_clauses) if where_clauses else "1=1" where_str = " AND ".join(where_clauses) if where_clauses else "1 = 1"
rows = conn.execute( rows = conn.execute(
f""" f"""
@@ -1005,13 +1005,13 @@ class SemanticSearch:
results.append( results.append(
SemanticSearchResult( SemanticSearchResult(
id=row["content_id"], id = row["content_id"],
content=content or "", content = content or "",
content_type=row["content_type"], content_type = row["content_type"],
project_id=row["project_id"], project_id = row["project_id"],
similarity=float(similarity), similarity = float(similarity),
embedding=None, # 不返回 embedding 以节省带宽 embedding = None, # 不返回 embedding 以节省带宽
metadata={}, metadata = {},
) )
) )
except Exception as e: except Exception as e:
@@ -1019,7 +1019,7 @@ class SemanticSearch:
continue continue
# 排序并返回 top_k # 排序并返回 top_k
results.sort(key=lambda x: x.similarity, reverse=True) results.sort(key = lambda x: x.similarity, reverse = True)
return results[:top_k] return results[:top_k]
def _get_content_text(self, content_id: str, content_type: str) -> str | None: def _get_content_text(self, content_id: str, content_type: str) -> str | None:
@@ -1029,13 +1029,13 @@ class SemanticSearch:
try: try:
if content_type == "transcript": if content_type == "transcript":
row = conn.execute( row = conn.execute(
"SELECT full_text FROM transcripts WHERE id = ?", (content_id,) "SELECT full_text FROM transcripts WHERE id = ?", (content_id, )
).fetchone() ).fetchone()
result = row["full_text"] if row else None result = row["full_text"] if row else None
elif content_type == "entity": elif content_type == "entity":
row = conn.execute( row = conn.execute(
"SELECT name, definition FROM entities WHERE id = ?", (content_id,) "SELECT name, definition FROM entities WHERE id = ?", (content_id, )
).fetchone() ).fetchone()
result = f"{row['name']}: {row['definition']}" if row else None result = f"{row['name']}: {row['definition']}" if row else None
@@ -1047,7 +1047,7 @@ class SemanticSearch:
JOIN entities e1 ON r.source_entity_id = e1.id JOIN entities e1 ON r.source_entity_id = e1.id
JOIN entities e2 ON r.target_entity_id = e2.id JOIN entities e2 ON r.target_entity_id = e2.id
WHERE r.id = ?""", WHERE r.id = ?""",
(content_id,), (content_id, ),
).fetchone() ).fetchone()
result = ( result = (
f"{row['source_name']} {row['relation_type']} {row['target_name']}" f"{row['source_name']} {row['relation_type']} {row['target_name']}"
@@ -1121,18 +1121,18 @@ class SemanticSearch:
results.append( results.append(
SemanticSearchResult( SemanticSearchResult(
id=row["content_id"], id = row["content_id"],
content=content or "", content = content or "",
content_type=row["content_type"], content_type = row["content_type"],
project_id=row["project_id"], project_id = row["project_id"],
similarity=float(similarity), similarity = float(similarity),
metadata={}, metadata = {},
) )
) )
except (KeyError, ValueError): except (KeyError, ValueError):
continue continue
results.sort(key=lambda x: x.similarity, reverse=True) results.sort(key = lambda x: x.similarity, reverse = True)
return results[:top_k] return results[:top_k]
def delete_embedding(self, content_id: str, content_type: str) -> bool: def delete_embedding(self, content_id: str, content_type: str) -> bool:
@@ -1165,7 +1165,7 @@ class EntityPathDiscovery:
- 路径可视化数据生成 - 路径可视化数据生成
""" """
def __init__(self, db_path: str = "insightflow.db"): def __init__(self, db_path: str = "insightflow.db") -> None:
self.db_path = db_path self.db_path = db_path
def _get_conn(self) -> sqlite3.Connection: def _get_conn(self) -> sqlite3.Connection:
@@ -1192,7 +1192,7 @@ class EntityPathDiscovery:
# 获取项目ID # 获取项目ID
row = conn.execute( row = conn.execute(
"SELECT project_id FROM entities WHERE id = ?", (source_entity_id,) "SELECT project_id FROM entities WHERE id = ?", (source_entity_id, )
).fetchone() ).fetchone()
if not row: if not row:
@@ -1267,7 +1267,7 @@ class EntityPathDiscovery:
# 获取项目ID # 获取项目ID
row = conn.execute( row = conn.execute(
"SELECT project_id FROM entities WHERE id = ?", (source_entity_id,) "SELECT project_id FROM entities WHERE id = ?", (source_entity_id, )
).fetchone() ).fetchone()
if not row: if not row:
@@ -1278,7 +1278,7 @@ class EntityPathDiscovery:
paths = [] paths = []
def dfs(current_id: str, target_id: str, path: list[str], visited: set[str], depth: int): def dfs(current_id: str, target_id: str, path: list[str], visited: set[str], depth: int) -> None:
if depth > max_depth: if depth > max_depth:
return return
@@ -1325,7 +1325,7 @@ class EntityPathDiscovery:
nodes = [] nodes = []
for entity_id in entity_ids: for entity_id in entity_ids:
row = conn.execute( row = conn.execute(
"SELECT id, name, type FROM entities WHERE id = ?", (entity_id,) "SELECT id, name, type FROM entities WHERE id = ?", (entity_id, )
).fetchone() ).fetchone()
if row: if row:
nodes.append({"id": row["id"], "name": row["name"], "type": row["type"]}) nodes.append({"id": row["id"], "name": row["name"], "type": row["type"]})
@@ -1368,16 +1368,16 @@ class EntityPathDiscovery:
confidence = 1.0 / (len(entity_ids) - 1) if len(entity_ids) > 1 else 1.0 confidence = 1.0 / (len(entity_ids) - 1) if len(entity_ids) > 1 else 1.0
return EntityPath( return EntityPath(
path_id=f"path_{entity_ids[0]}_{entity_ids[-1]}_{hash(tuple(entity_ids))}", path_id = f"path_{entity_ids[0]}_{entity_ids[-1]}_{hash(tuple(entity_ids))}",
source_entity_id=entity_ids[0], source_entity_id = entity_ids[0],
source_entity_name=nodes[0]["name"] if nodes else "", source_entity_name = nodes[0]["name"] if nodes else "",
target_entity_id=entity_ids[-1], target_entity_id = entity_ids[-1],
target_entity_name=nodes[-1]["name"] if nodes else "", target_entity_name = nodes[-1]["name"] if nodes else "",
path_length=len(entity_ids) - 1, path_length = len(entity_ids) - 1,
nodes=nodes, nodes = nodes,
edges=edges, edges = edges,
confidence=round(confidence, 4), confidence = round(confidence, 4),
path_description=path_desc, path_description = path_desc,
) )
def find_multi_hop_relations(self, entity_id: str, max_hops: int = 3) -> list[dict]: def find_multi_hop_relations(self, entity_id: str, max_hops: int = 3) -> list[dict]:
@@ -1395,7 +1395,7 @@ class EntityPathDiscovery:
# 获取项目ID # 获取项目ID
row = conn.execute( row = conn.execute(
"SELECT project_id, name FROM entities WHERE id = ?", (entity_id,) "SELECT project_id, name FROM entities WHERE id = ?", (entity_id, )
).fetchone() ).fetchone()
if not row: if not row:
@@ -1442,7 +1442,7 @@ class EntityPathDiscovery:
# 获取邻居信息 # 获取邻居信息
neighbor_info = conn.execute( neighbor_info = conn.execute(
"SELECT name, type FROM entities WHERE id = ?", (neighbor_id,) "SELECT name, type FROM entities WHERE id = ?", (neighbor_id, )
).fetchone() ).fetchone()
if neighbor_info: if neighbor_info:
@@ -1463,7 +1463,7 @@ class EntityPathDiscovery:
conn.close() conn.close()
# 按跳数排序 # 按跳数排序
relations.sort(key=lambda x: x["hops"]) relations.sort(key = lambda x: x["hops"])
return relations return relations
def _get_path_to_entity( def _get_path_to_entity(
@@ -1562,7 +1562,7 @@ class EntityPathDiscovery:
# 获取所有实体 # 获取所有实体
entities = conn.execute( entities = conn.execute(
"SELECT id, name FROM entities WHERE project_id = ?", (project_id,) "SELECT id, name FROM entities WHERE project_id = ?", (project_id, )
).fetchall() ).fetchall()
# 计算每个实体作为桥梁的次数 # 计算每个实体作为桥梁的次数
@@ -1594,10 +1594,10 @@ class EntityPathDiscovery:
f""" f"""
SELECT COUNT(*) as count SELECT COUNT(*) as count
FROM entity_relations FROM entity_relations
WHERE ((source_entity_id IN ({",".join(["?" for _ in neighbor_ids])}) WHERE ((source_entity_id IN ({", ".join(["?" for _ in neighbor_ids])})
AND target_entity_id IN ({",".join(["?" for _ in neighbor_ids])})) AND target_entity_id IN ({", ".join(["?" for _ in neighbor_ids])}))
OR (target_entity_id IN ({",".join(["?" for _ in neighbor_ids])}) OR (target_entity_id IN ({", ".join(["?" for _ in neighbor_ids])})
AND source_entity_id IN ({",".join(["?" for _ in neighbor_ids])}))) AND source_entity_id IN ({", ".join(["?" for _ in neighbor_ids])})))
AND project_id = ? AND project_id = ?
""", """,
list(neighbor_ids) * 4 + [project_id], list(neighbor_ids) * 4 + [project_id],
@@ -1620,7 +1620,7 @@ class EntityPathDiscovery:
conn.close() conn.close()
# 按桥接分数排序 # 按桥接分数排序
bridge_scores.sort(key=lambda x: x["bridge_score"], reverse=True) bridge_scores.sort(key = lambda x: x["bridge_score"], reverse = True)
return bridge_scores[:20] # 返回前20 return bridge_scores[:20] # 返回前20
@@ -1638,7 +1638,7 @@ class KnowledgeGapDetection:
- 生成知识补全建议 - 生成知识补全建议
""" """
def __init__(self, db_path: str = "insightflow.db"): def __init__(self, db_path: str = "insightflow.db") -> None:
self.db_path = db_path self.db_path = db_path
def _get_conn(self) -> sqlite3.Connection: def _get_conn(self) -> sqlite3.Connection:
@@ -1676,7 +1676,7 @@ class KnowledgeGapDetection:
# 按严重程度排序 # 按严重程度排序
severity_order = {"high": 0, "medium": 1, "low": 2} severity_order = {"high": 0, "medium": 1, "low": 2}
gaps.sort(key=lambda x: severity_order.get(x.severity, 3)) gaps.sort(key = lambda x: severity_order.get(x.severity, 3))
return gaps return gaps
@@ -1688,7 +1688,7 @@ class KnowledgeGapDetection:
# 获取项目的属性模板 # 获取项目的属性模板
templates = conn.execute( templates = conn.execute(
"SELECT id, name, type, is_required FROM attribute_templates WHERE project_id = ?", "SELECT id, name, type, is_required FROM attribute_templates WHERE project_id = ?",
(project_id,), (project_id, ),
).fetchall() ).fetchall()
if not templates: if not templates:
@@ -1703,7 +1703,7 @@ class KnowledgeGapDetection:
# 检查每个实体的属性完整性 # 检查每个实体的属性完整性
entities = conn.execute( entities = conn.execute(
"SELECT id, name FROM entities WHERE project_id = ?", (project_id,) "SELECT id, name FROM entities WHERE project_id = ?", (project_id, )
).fetchall() ).fetchall()
for entity in entities: for entity in entities:
@@ -1711,7 +1711,7 @@ class KnowledgeGapDetection:
# 获取实体已有的属性 # 获取实体已有的属性
existing_attrs = conn.execute( existing_attrs = conn.execute(
"SELECT template_id FROM entity_attributes WHERE entity_id = ?", (entity_id,) "SELECT template_id FROM entity_attributes WHERE entity_id = ?", (entity_id, )
).fetchall() ).fetchall()
existing_template_ids = {a["template_id"] for a in existing_attrs} existing_template_ids = {a["template_id"] for a in existing_attrs}
@@ -1723,7 +1723,7 @@ class KnowledgeGapDetection:
missing_names = [] missing_names = []
for template_id in missing_templates: for template_id in missing_templates:
template = conn.execute( template = conn.execute(
"SELECT name FROM attribute_templates WHERE id = ?", (template_id,) "SELECT name FROM attribute_templates WHERE id = ?", (template_id, )
).fetchone() ).fetchone()
if template: if template:
missing_names.append(template["name"]) missing_names.append(template["name"])
@@ -1731,18 +1731,18 @@ class KnowledgeGapDetection:
if missing_names: if missing_names:
gaps.append( gaps.append(
KnowledgeGap( KnowledgeGap(
gap_id=f"gap_attr_{entity_id}", gap_id = f"gap_attr_{entity_id}",
gap_type="missing_attribute", gap_type = "missing_attribute",
entity_id=entity_id, entity_id = entity_id,
entity_name=entity["name"], entity_name = entity["name"],
description=f"实体 '{entity['name']}' 缺少必需属性: {', '.join(missing_names)}", description = f"实体 '{entity['name']}' 缺少必需属性: {', '.join(missing_names)}",
severity="medium", severity = "medium",
suggestions=[ suggestions = [
f"为实体 '{entity['name']}' 补充以下属性: {', '.join(missing_names)}", f"为实体 '{entity['name']}' 补充以下属性: {', '.join(missing_names)}",
"检查属性模板定义是否合理", "检查属性模板定义是否合理",
], ],
related_entities=[], related_entities = [],
metadata={"missing_attributes": missing_names}, metadata = {"missing_attributes": missing_names},
) )
) )
@@ -1756,7 +1756,7 @@ class KnowledgeGapDetection:
# 获取所有实体及其关系数量 # 获取所有实体及其关系数量
entities = conn.execute( entities = conn.execute(
"SELECT id, name, type FROM entities WHERE project_id = ?", (project_id,) "SELECT id, name, type FROM entities WHERE project_id = ?", (project_id, )
).fetchall() ).fetchall()
for entity in entities: for entity in entities:
@@ -1793,19 +1793,19 @@ class KnowledgeGapDetection:
gaps.append( gaps.append(
KnowledgeGap( KnowledgeGap(
gap_id=f"gap_sparse_{entity_id}", gap_id = f"gap_sparse_{entity_id}",
gap_type="sparse_relation", gap_type = "sparse_relation",
entity_id=entity_id, entity_id = entity_id,
entity_name=entity["name"], entity_name = entity["name"],
description=f"实体 '{entity['name']}' 关系稀疏(仅有 {relation_count} 个关系)", description = f"实体 '{entity['name']}' 关系稀疏(仅有 {relation_count} 个关系)",
severity="medium" if relation_count == 0 else "low", severity = "medium" if relation_count == 0 else "low",
suggestions=[ suggestions = [
f"检查转录文本中提及 '{entity['name']}' 的其他实体", f"检查转录文本中提及 '{entity['name']}' 的其他实体",
f"手动添加 '{entity['name']}' 与其他实体的关系", f"手动添加 '{entity['name']}' 与其他实体的关系",
"使用实体对齐功能合并相似实体", "使用实体对齐功能合并相似实体",
], ],
related_entities=[r["id"] for r in potential_related], related_entities = [r["id"] for r in potential_related],
metadata={ metadata = {
"relation_count": relation_count, "relation_count": relation_count,
"potential_related": [r["name"] for r in potential_related], "potential_related": [r["name"] for r in potential_related],
}, },
@@ -1831,25 +1831,25 @@ class KnowledgeGapDetection:
AND r1.id IS NULL AND r1.id IS NULL
AND r2.id IS NULL AND r2.id IS NULL
""", """,
(project_id,), (project_id, ),
).fetchall() ).fetchall()
for entity in isolated: for entity in isolated:
gaps.append( gaps.append(
KnowledgeGap( KnowledgeGap(
gap_id=f"gap_iso_{entity['id']}", gap_id = f"gap_iso_{entity['id']}",
gap_type="isolated_entity", gap_type = "isolated_entity",
entity_id=entity["id"], entity_id = entity["id"],
entity_name=entity["name"], entity_name = entity["name"],
description=f"实体 '{entity['name']}' 是孤立实体(没有任何关系)", description = f"实体 '{entity['name']}' 是孤立实体(没有任何关系)",
severity="high", severity = "high",
suggestions=[ suggestions = [
f"检查 '{entity['name']}' 是否应该与其他实体建立关系", f"检查 '{entity['name']}' 是否应该与其他实体建立关系",
f"考虑删除不相关的实体 '{entity['name']}'", f"考虑删除不相关的实体 '{entity['name']}'",
"运行关系发现算法自动识别潜在关系", "运行关系发现算法自动识别潜在关系",
], ],
related_entities=[], related_entities = [],
metadata={"entity_type": entity["type"]}, metadata = {"entity_type": entity["type"]},
) )
) )
@@ -1869,21 +1869,21 @@ class KnowledgeGapDetection:
WHERE project_id = ? WHERE project_id = ?
AND (definition IS NULL OR definition = '') AND (definition IS NULL OR definition = '')
""", """,
(project_id,), (project_id, ),
).fetchall() ).fetchall()
for entity in incomplete: for entity in incomplete:
gaps.append( gaps.append(
KnowledgeGap( KnowledgeGap(
gap_id=f"gap_inc_{entity['id']}", gap_id = f"gap_inc_{entity['id']}",
gap_type="incomplete_entity", gap_type = "incomplete_entity",
entity_id=entity["id"], entity_id = entity["id"],
entity_name=entity["name"], entity_name = entity["name"],
description=f"实体 '{entity['name']}' 缺少定义", description = f"实体 '{entity['name']}' 缺少定义",
severity="low", severity = "low",
suggestions=[f"'{entity['name']}' 添加定义", "从转录文本中提取定义信息"], suggestions = [f"'{entity['name']}' 添加定义", "从转录文本中提取定义信息"],
related_entities=[], related_entities = [],
metadata={"entity_type": entity["type"]}, metadata = {"entity_type": entity["type"]},
) )
) )
@@ -1897,7 +1897,7 @@ class KnowledgeGapDetection:
# 分析转录文本中频繁提及但未提取为实体的词 # 分析转录文本中频繁提及但未提取为实体的词
transcripts = conn.execute( transcripts = conn.execute(
"SELECT full_text FROM transcripts WHERE project_id = ?", (project_id,) "SELECT full_text FROM transcripts WHERE project_id = ?", (project_id, )
).fetchall() ).fetchall()
# 合并所有文本 # 合并所有文本
@@ -1905,7 +1905,7 @@ class KnowledgeGapDetection:
# 获取现有实体名称 # 获取现有实体名称
existing_entities = conn.execute( existing_entities = conn.execute(
"SELECT name FROM entities WHERE project_id = ?", (project_id,) "SELECT name FROM entities WHERE project_id = ?", (project_id, )
).fetchall() ).fetchall()
existing_names = {e["name"].lower() for e in existing_entities} existing_names = {e["name"].lower() for e in existing_entities}
@@ -1925,18 +1925,18 @@ class KnowledgeGapDetection:
if count >= 3: # 出现3次以上 if count >= 3: # 出现3次以上
gaps.append( gaps.append(
KnowledgeGap( KnowledgeGap(
gap_id=f"gap_missing_{hash(entity) % 10000}", gap_id = f"gap_missing_{hash(entity) % 10000}",
gap_type="missing_key_entity", gap_type = "missing_key_entity",
entity_id=None, entity_id = None,
entity_name=None, entity_name = None,
description=f"文本中频繁提及 '{entity}' 但未提取为实体(出现 {count} 次)", description = f"文本中频繁提及 '{entity}' 但未提取为实体(出现 {count} 次)",
severity="low", severity = "low",
suggestions=[ suggestions = [
f"考虑将 '{entity}' 添加为实体", f"考虑将 '{entity}' 添加为实体",
"检查实体提取算法是否需要优化", "检查实体提取算法是否需要优化",
], ],
related_entities=[], related_entities = [],
metadata={"mention_count": count}, metadata = {"mention_count": count},
) )
) )
@@ -2040,7 +2040,7 @@ class SearchManager:
整合全文搜索、语义搜索、实体路径发现和知识缺口识别功能 整合全文搜索、语义搜索、实体路径发现和知识缺口识别功能
""" """
def __init__(self, db_path: str = "insightflow.db"): def __init__(self, db_path: str = "insightflow.db") -> None:
self.db_path = db_path self.db_path = db_path
self.fulltext_search = FullTextSearch(db_path) self.fulltext_search = FullTextSearch(db_path)
self.semantic_search = SemanticSearch(db_path) self.semantic_search = SemanticSearch(db_path)
@@ -2060,12 +2060,12 @@ class SearchManager:
Dict: 混合搜索结果 Dict: 混合搜索结果
""" """
# 全文搜索 # 全文搜索
fulltext_results = self.fulltext_search.search(query, project_id, limit=limit) fulltext_results = self.fulltext_search.search(query, project_id, limit = limit)
# 语义搜索 # 语义搜索
semantic_results = [] semantic_results = []
if self.semantic_search.is_available(): if self.semantic_search.is_available():
semantic_results = self.semantic_search.search(query, project_id, top_k=limit) semantic_results = self.semantic_search.search(query, project_id, top_k = limit)
# 合并结果(去重并加权) # 合并结果(去重并加权)
combined = {} combined = {}
@@ -2104,7 +2104,7 @@ class SearchManager:
# 排序 # 排序
results = list(combined.values()) results = list(combined.values())
results.sort(key=lambda x: x["combined_score"], reverse=True) results.sort(key = lambda x: x["combined_score"], reverse = True)
return { return {
"query": query, "query": query,
@@ -2138,7 +2138,7 @@ class SearchManager:
# 索引转录文本 # 索引转录文本
transcripts = conn.execute( transcripts = conn.execute(
"SELECT id, project_id, full_text FROM transcripts WHERE project_id = ?", "SELECT id, project_id, full_text FROM transcripts WHERE project_id = ?",
(project_id,), (project_id, ),
).fetchall() ).fetchall()
for t in transcripts: for t in transcripts:
@@ -2152,7 +2152,7 @@ class SearchManager:
# 索引实体 # 索引实体
entities = conn.execute( entities = conn.execute(
"SELECT id, project_id, name, definition FROM entities WHERE project_id = ?", "SELECT id, project_id, name, definition FROM entities WHERE project_id = ?",
(project_id,), (project_id, ),
).fetchall() ).fetchall()
for e in entities: for e in entities:
@@ -2191,7 +2191,7 @@ class SearchManager:
"""SELECT content_type, COUNT(*) as count """SELECT content_type, COUNT(*) as count
FROM search_indexes WHERE project_id = ? FROM search_indexes WHERE project_id = ?
GROUP BY content_type""", GROUP BY content_type""",
(project_id,), (project_id, ),
).fetchall() ).fetchall()
type_stats = {r["content_type"]: r["count"] for r in rows} type_stats = {r["content_type"]: r["count"] for r in rows}
@@ -2226,7 +2226,7 @@ def fulltext_search(
) -> list[SearchResult]: ) -> list[SearchResult]:
"""全文搜索便捷函数""" """全文搜索便捷函数"""
manager = get_search_manager() manager = get_search_manager()
return manager.fulltext_search.search(query, project_id, limit=limit) return manager.fulltext_search.search(query, project_id, limit = limit)
def semantic_search( def semantic_search(
@@ -2234,7 +2234,7 @@ def semantic_search(
) -> list[SemanticSearchResult]: ) -> list[SemanticSearchResult]:
"""语义搜索便捷函数""" """语义搜索便捷函数"""
manager = get_search_manager() manager = get_search_manager()
return manager.semantic_search.search(query, project_id, top_k=top_k) return manager.semantic_search.search(query, project_id, top_k = top_k)
def find_entity_path(source_id: str, target_id: str, max_depth: int = 5) -> EntityPath | None: def find_entity_path(source_id: str, target_id: str, max_depth: int = 5) -> EntityPath | None:

View File

@@ -86,7 +86,7 @@ class AuditLog:
after_value: str | None = None after_value: str | None = None
success: bool = True success: bool = True
error_message: str | None = None error_message: str | None = None
created_at: str = field(default_factory=lambda: datetime.now().isoformat()) created_at: str = field(default_factory = lambda: datetime.now().isoformat())
def to_dict(self) -> dict[str, Any]: def to_dict(self) -> dict[str, Any]:
return asdict(self) return asdict(self)
@@ -103,8 +103,8 @@ class EncryptionConfig:
key_derivation: str = "pbkdf2" # pbkdf2, argon2 key_derivation: str = "pbkdf2" # pbkdf2, argon2
master_key_hash: str | None = None # 主密钥哈希(用于验证) master_key_hash: str | None = None # 主密钥哈希(用于验证)
salt: str | None = None salt: str | None = None
created_at: str = field(default_factory=lambda: datetime.now().isoformat()) created_at: str = field(default_factory = lambda: datetime.now().isoformat())
updated_at: str = field(default_factory=lambda: datetime.now().isoformat()) updated_at: str = field(default_factory = lambda: datetime.now().isoformat())
def to_dict(self) -> dict[str, Any]: def to_dict(self) -> dict[str, Any]:
return asdict(self) return asdict(self)
@@ -123,8 +123,8 @@ class MaskingRule:
is_active: bool = True is_active: bool = True
priority: int = 0 priority: int = 0
description: str | None = None description: str | None = None
created_at: str = field(default_factory=lambda: datetime.now().isoformat()) created_at: str = field(default_factory = lambda: datetime.now().isoformat())
updated_at: str = field(default_factory=lambda: datetime.now().isoformat()) updated_at: str = field(default_factory = lambda: datetime.now().isoformat())
def to_dict(self) -> dict[str, Any]: def to_dict(self) -> dict[str, Any]:
return asdict(self) return asdict(self)
@@ -145,8 +145,8 @@ class DataAccessPolicy:
max_access_count: int | None = None # 最大访问次数 max_access_count: int | None = None # 最大访问次数
require_approval: bool = False require_approval: bool = False
is_active: bool = True is_active: bool = True
created_at: str = field(default_factory=lambda: datetime.now().isoformat()) created_at: str = field(default_factory = lambda: datetime.now().isoformat())
updated_at: str = field(default_factory=lambda: datetime.now().isoformat()) updated_at: str = field(default_factory = lambda: datetime.now().isoformat())
def to_dict(self) -> dict[str, Any]: def to_dict(self) -> dict[str, Any]:
return asdict(self) return asdict(self)
@@ -164,7 +164,7 @@ class AccessRequest:
approved_by: str | None = None approved_by: str | None = None
approved_at: str | None = None approved_at: str | None = None
expires_at: str | None = None expires_at: str | None = None
created_at: str = field(default_factory=lambda: datetime.now().isoformat()) created_at: str = field(default_factory = lambda: datetime.now().isoformat())
def to_dict(self) -> dict[str, Any]: def to_dict(self) -> dict[str, Any]:
return asdict(self) return asdict(self)
@@ -176,7 +176,7 @@ class SecurityManager:
# 预定义脱敏规则 # 预定义脱敏规则
DEFAULT_MASKING_RULES = { DEFAULT_MASKING_RULES = {
MaskingRuleType.PHONE: {"pattern": r"(\d{3})\d{4}(\d{4})", "replacement": r"\1****\2"}, MaskingRuleType.PHONE: {"pattern": r"(\d{3})\d{4}(\d{4})", "replacement": r"\1****\2"},
MaskingRuleType.EMAIL: {"pattern": r"(\w{1,3})\w+(@\w+\.\w+)", "replacement": r"\1***\2"}, MaskingRuleType.EMAIL: {"pattern": r"(\w{1, 3})\w+(@\w+\.\w+)", "replacement": r"\1***\2"},
MaskingRuleType.ID_CARD: { MaskingRuleType.ID_CARD: {
"pattern": r"(\d{6})\d{8}(\d{4})", "pattern": r"(\d{6})\d{8}(\d{4})",
"replacement": r"\1********\2", "replacement": r"\1********\2",
@@ -190,12 +190,12 @@ class SecurityManager:
"replacement": r"\1**", "replacement": r"\1**",
}, },
MaskingRuleType.ADDRESS: { MaskingRuleType.ADDRESS: {
"pattern": r"([\u4e00-\u9fa5]{2,})([\u4e00-\u9fa5]+路|街|巷|号)(.+)", "pattern": r"([\u4e00-\u9fa5]{2, })([\u4e00-\u9fa5]+路|街|巷|号)(.+)",
"replacement": r"\1\2***", "replacement": r"\1\2***",
}, },
} }
def __init__(self, db_path: str = "insightflow.db"): def __init__(self, db_path: str = "insightflow.db") -> None:
self.db_path = db_path self.db_path = db_path
# 预编译正则缓存 # 预编译正则缓存
self._compiled_patterns: dict[str, re.Pattern] = {} self._compiled_patterns: dict[str, re.Pattern] = {}
@@ -345,18 +345,18 @@ class SecurityManager:
) -> AuditLog: ) -> AuditLog:
"""记录审计日志""" """记录审计日志"""
log = AuditLog( log = AuditLog(
id=self._generate_id(), id = self._generate_id(),
action_type=action_type.value, action_type = action_type.value,
user_id=user_id, user_id = user_id,
user_ip=user_ip, user_ip = user_ip,
user_agent=user_agent, user_agent = user_agent,
resource_type=resource_type, resource_type = resource_type,
resource_id=resource_id, resource_id = resource_id,
action_details=json.dumps(action_details) if action_details else None, action_details = json.dumps(action_details) if action_details else None,
before_value=before_value, before_value = before_value,
after_value=after_value, after_value = after_value,
success=success, success = success,
error_message=error_message, error_message = error_message,
) )
conn = sqlite3.connect(self.db_path) conn = sqlite3.connect(self.db_path)
@@ -405,7 +405,7 @@ class SecurityManager:
conn = sqlite3.connect(self.db_path) conn = sqlite3.connect(self.db_path)
cursor = conn.cursor() cursor = conn.cursor()
query = "SELECT * FROM audit_logs WHERE 1=1" query = "SELECT * FROM audit_logs WHERE 1 = 1"
params = [] params = []
if user_id: if user_id:
@@ -444,19 +444,19 @@ class SecurityManager:
for row in rows: for row in rows:
log = AuditLog( log = AuditLog(
id=row[0], id = row[0],
action_type=row[1], action_type = row[1],
user_id=row[2], user_id = row[2],
user_ip=row[3], user_ip = row[3],
user_agent=row[4], user_agent = row[4],
resource_type=row[5], resource_type = row[5],
resource_id=row[6], resource_id = row[6],
action_details=row[7], action_details = row[7],
before_value=row[8], before_value = row[8],
after_value=row[9], after_value = row[9],
success=bool(row[10]), success = bool(row[10]),
error_message=row[11], error_message = row[11],
created_at=row[12], created_at = row[12],
) )
logs.append(log) logs.append(log)
@@ -470,7 +470,7 @@ class SecurityManager:
conn = sqlite3.connect(self.db_path) conn = sqlite3.connect(self.db_path)
cursor = conn.cursor() cursor = conn.cursor()
query = "SELECT action_type, success, COUNT(*) FROM audit_logs WHERE 1=1" query = "SELECT action_type, success, COUNT(*) FROM audit_logs WHERE 1 = 1"
params = [] params = []
if start_time: if start_time:
@@ -513,10 +513,10 @@ class SecurityManager:
raise RuntimeError("cryptography library not available") raise RuntimeError("cryptography library not available")
kdf = PBKDF2HMAC( kdf = PBKDF2HMAC(
algorithm=hashes.SHA256(), algorithm = hashes.SHA256(),
length=32, length = 32,
salt=salt, salt = salt,
iterations=100000, iterations = 100000,
) )
return base64.urlsafe_b64encode(kdf.derive(password.encode())) return base64.urlsafe_b64encode(kdf.derive(password.encode()))
@@ -533,20 +533,20 @@ class SecurityManager:
key_hash = hashlib.sha256(key).hexdigest() key_hash = hashlib.sha256(key).hexdigest()
config = EncryptionConfig( config = EncryptionConfig(
id=self._generate_id(), id = self._generate_id(),
project_id=project_id, project_id = project_id,
is_enabled=True, is_enabled = True,
encryption_type="aes-256-gcm", encryption_type = "aes-256-gcm",
key_derivation="pbkdf2", key_derivation = "pbkdf2",
master_key_hash=key_hash, master_key_hash = key_hash,
salt=salt, salt = salt,
) )
conn = sqlite3.connect(self.db_path) conn = sqlite3.connect(self.db_path)
cursor = conn.cursor() cursor = conn.cursor()
# 检查是否已存在配置 # 检查是否已存在配置
cursor.execute("SELECT id FROM encryption_configs WHERE project_id = ?", (project_id,)) cursor.execute("SELECT id FROM encryption_configs WHERE project_id = ?", (project_id, ))
existing = cursor.fetchone() existing = cursor.fetchone()
if existing: if existing:
@@ -593,10 +593,10 @@ class SecurityManager:
# 记录审计日志 # 记录审计日志
self.log_audit( self.log_audit(
action_type=AuditActionType.ENCRYPTION_ENABLE, action_type = AuditActionType.ENCRYPTION_ENABLE,
resource_type="project", resource_type = "project",
resource_id=project_id, resource_id = project_id,
action_details={"encryption_type": config.encryption_type}, action_details = {"encryption_type": config.encryption_type},
) )
return config return config
@@ -624,9 +624,9 @@ class SecurityManager:
# 记录审计日志 # 记录审计日志
self.log_audit( self.log_audit(
action_type=AuditActionType.ENCRYPTION_DISABLE, action_type = AuditActionType.ENCRYPTION_DISABLE,
resource_type="project", resource_type = "project",
resource_id=project_id, resource_id = project_id,
) )
return True return True
@@ -641,7 +641,7 @@ class SecurityManager:
cursor.execute( cursor.execute(
"SELECT master_key_hash, salt FROM encryption_configs WHERE project_id = ?", "SELECT master_key_hash, salt FROM encryption_configs WHERE project_id = ?",
(project_id,), (project_id, ),
) )
row = cursor.fetchone() row = cursor.fetchone()
conn.close() conn.close()
@@ -660,7 +660,7 @@ class SecurityManager:
conn = sqlite3.connect(self.db_path) conn = sqlite3.connect(self.db_path)
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute("SELECT * FROM encryption_configs WHERE project_id = ?", (project_id,)) cursor.execute("SELECT * FROM encryption_configs WHERE project_id = ?", (project_id, ))
row = cursor.fetchone() row = cursor.fetchone()
conn.close() conn.close()
@@ -668,15 +668,15 @@ class SecurityManager:
return None return None
return EncryptionConfig( return EncryptionConfig(
id=row[0], id = row[0],
project_id=row[1], project_id = row[1],
is_enabled=bool(row[2]), is_enabled = bool(row[2]),
encryption_type=row[3], encryption_type = row[3],
key_derivation=row[4], key_derivation = row[4],
master_key_hash=row[5], master_key_hash = row[5],
salt=row[6], salt = row[6],
created_at=row[7], created_at = row[7],
updated_at=row[8], updated_at = row[8],
) )
def encrypt_data(self, data: str, password: str, salt: str | None = None) -> tuple[str, str]: def encrypt_data(self, data: str, password: str, salt: str | None = None) -> tuple[str, str]:
@@ -724,14 +724,14 @@ class SecurityManager:
replacement = replacement or default["replacement"] replacement = replacement or default["replacement"]
rule = MaskingRule( rule = MaskingRule(
id=self._generate_id(), id = self._generate_id(),
project_id=project_id, project_id = project_id,
name=name, name = name,
rule_type=rule_type.value, rule_type = rule_type.value,
pattern=pattern or "", pattern = pattern or "",
replacement=replacement or "****", replacement = replacement or "****",
description=description, description = description,
priority=priority, priority = priority,
) )
conn = sqlite3.connect(self.db_path) conn = sqlite3.connect(self.db_path)
@@ -764,10 +764,10 @@ class SecurityManager:
# 记录审计日志 # 记录审计日志
self.log_audit( self.log_audit(
action_type=AuditActionType.DATA_MASKING, action_type = AuditActionType.DATA_MASKING,
resource_type="project", resource_type = "project",
resource_id=project_id, resource_id = project_id,
action_details={"action": "create_rule", "rule_name": name}, action_details = {"action": "create_rule", "rule_name": name},
) )
return rule return rule
@@ -793,17 +793,17 @@ class SecurityManager:
for row in rows: for row in rows:
rules.append( rules.append(
MaskingRule( MaskingRule(
id=row[0], id = row[0],
project_id=row[1], project_id = row[1],
name=row[2], name = row[2],
rule_type=row[3], rule_type = row[3],
pattern=row[4], pattern = row[4],
replacement=row[5], replacement = row[5],
is_active=bool(row[6]), is_active = bool(row[6]),
priority=row[7], priority = row[7],
description=row[8], description = row[8],
created_at=row[9], created_at = row[9],
updated_at=row[10], updated_at = row[10],
) )
) )
@@ -847,7 +847,7 @@ class SecurityManager:
# 获取更新后的规则 # 获取更新后的规则
conn = sqlite3.connect(self.db_path) conn = sqlite3.connect(self.db_path)
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute("SELECT * FROM masking_rules WHERE id = ?", (rule_id,)) cursor.execute("SELECT * FROM masking_rules WHERE id = ?", (rule_id, ))
row = cursor.fetchone() row = cursor.fetchone()
conn.close() conn.close()
@@ -855,17 +855,17 @@ class SecurityManager:
return None return None
return MaskingRule( return MaskingRule(
id=row[0], id = row[0],
project_id=row[1], project_id = row[1],
name=row[2], name = row[2],
rule_type=row[3], rule_type = row[3],
pattern=row[4], pattern = row[4],
replacement=row[5], replacement = row[5],
is_active=bool(row[6]), is_active = bool(row[6]),
priority=row[7], priority = row[7],
description=row[8], description = row[8],
created_at=row[9], created_at = row[9],
updated_at=row[10], updated_at = row[10],
) )
def delete_masking_rule(self, rule_id: str) -> bool: def delete_masking_rule(self, rule_id: str) -> bool:
@@ -873,7 +873,7 @@ class SecurityManager:
conn = sqlite3.connect(self.db_path) conn = sqlite3.connect(self.db_path)
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute("DELETE FROM masking_rules WHERE id = ?", (rule_id,)) cursor.execute("DELETE FROM masking_rules WHERE id = ?", (rule_id, ))
success = cursor.rowcount > 0 success = cursor.rowcount > 0
conn.commit() conn.commit()
@@ -936,16 +936,16 @@ class SecurityManager:
) -> DataAccessPolicy: ) -> DataAccessPolicy:
"""创建数据访问策略""" """创建数据访问策略"""
policy = DataAccessPolicy( policy = DataAccessPolicy(
id=self._generate_id(), id = self._generate_id(),
project_id=project_id, project_id = project_id,
name=name, name = name,
description=description, description = description,
allowed_users=json.dumps(allowed_users) if allowed_users else None, allowed_users = json.dumps(allowed_users) if allowed_users else None,
allowed_roles=json.dumps(allowed_roles) if allowed_roles else None, allowed_roles = json.dumps(allowed_roles) if allowed_roles else None,
allowed_ips=json.dumps(allowed_ips) if allowed_ips else None, allowed_ips = json.dumps(allowed_ips) if allowed_ips else None,
time_restrictions=json.dumps(time_restrictions) if time_restrictions else None, time_restrictions = json.dumps(time_restrictions) if time_restrictions else None,
max_access_count=max_access_count, max_access_count = max_access_count,
require_approval=require_approval, require_approval = require_approval,
) )
conn = sqlite3.connect(self.db_path) conn = sqlite3.connect(self.db_path)
@@ -1002,19 +1002,19 @@ class SecurityManager:
for row in rows: for row in rows:
policies.append( policies.append(
DataAccessPolicy( DataAccessPolicy(
id=row[0], id = row[0],
project_id=row[1], project_id = row[1],
name=row[2], name = row[2],
description=row[3], description = row[3],
allowed_users=row[4], allowed_users = row[4],
allowed_roles=row[5], allowed_roles = row[5],
allowed_ips=row[6], allowed_ips = row[6],
time_restrictions=row[7], time_restrictions = row[7],
max_access_count=row[8], max_access_count = row[8],
require_approval=bool(row[9]), require_approval = bool(row[9]),
is_active=bool(row[10]), is_active = bool(row[10]),
created_at=row[11], created_at = row[11],
updated_at=row[12], updated_at = row[12],
) )
) )
@@ -1028,7 +1028,7 @@ class SecurityManager:
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute( cursor.execute(
"SELECT * FROM data_access_policies WHERE id = ? AND is_active = 1", (policy_id,) "SELECT * FROM data_access_policies WHERE id = ? AND is_active = 1", (policy_id, )
) )
row = cursor.fetchone() row = cursor.fetchone()
conn.close() conn.close()
@@ -1037,19 +1037,19 @@ class SecurityManager:
return False, "Policy not found or inactive" return False, "Policy not found or inactive"
policy = DataAccessPolicy( policy = DataAccessPolicy(
id=row[0], id = row[0],
project_id=row[1], project_id = row[1],
name=row[2], name = row[2],
description=row[3], description = row[3],
allowed_users=row[4], allowed_users = row[4],
allowed_roles=row[5], allowed_roles = row[5],
allowed_ips=row[6], allowed_ips = row[6],
time_restrictions=row[7], time_restrictions = row[7],
max_access_count=row[8], max_access_count = row[8],
require_approval=bool(row[9]), require_approval = bool(row[9]),
is_active=bool(row[10]), is_active = bool(row[10]),
created_at=row[11], created_at = row[11],
updated_at=row[12], updated_at = row[12],
) )
# 检查用户白名单 # 检查用户白名单
@@ -1113,7 +1113,7 @@ class SecurityManager:
try: try:
if "/" in pattern: if "/" in pattern:
# CIDR 表示法 # CIDR 表示法
network = ipaddress.ip_network(pattern, strict=False) network = ipaddress.ip_network(pattern, strict = False)
return ipaddress.ip_address(ip) in network return ipaddress.ip_address(ip) in network
else: else:
# 精确匹配 # 精确匹配
@@ -1130,11 +1130,11 @@ class SecurityManager:
) -> AccessRequest: ) -> AccessRequest:
"""创建访问请求""" """创建访问请求"""
request = AccessRequest( request = AccessRequest(
id=self._generate_id(), id = self._generate_id(),
policy_id=policy_id, policy_id = policy_id,
user_id=user_id, user_id = user_id,
request_reason=request_reason, request_reason = request_reason,
expires_at=(datetime.now() + timedelta(hours=expires_hours)).isoformat(), expires_at = (datetime.now() + timedelta(hours = expires_hours)).isoformat(),
) )
conn = sqlite3.connect(self.db_path) conn = sqlite3.connect(self.db_path)
@@ -1169,7 +1169,7 @@ class SecurityManager:
conn = sqlite3.connect(self.db_path) conn = sqlite3.connect(self.db_path)
cursor = conn.cursor() cursor = conn.cursor()
expires_at = (datetime.now() + timedelta(hours=expires_hours)).isoformat() expires_at = (datetime.now() + timedelta(hours = expires_hours)).isoformat()
approved_at = datetime.now().isoformat() approved_at = datetime.now().isoformat()
cursor.execute( cursor.execute(
@@ -1184,7 +1184,7 @@ class SecurityManager:
conn.commit() conn.commit()
# 获取更新后的请求 # 获取更新后的请求
cursor.execute("SELECT * FROM access_requests WHERE id = ?", (request_id,)) cursor.execute("SELECT * FROM access_requests WHERE id = ?", (request_id, ))
row = cursor.fetchone() row = cursor.fetchone()
conn.close() conn.close()
@@ -1192,15 +1192,15 @@ class SecurityManager:
return None return None
return AccessRequest( return AccessRequest(
id=row[0], id = row[0],
policy_id=row[1], policy_id = row[1],
user_id=row[2], user_id = row[2],
request_reason=row[3], request_reason = row[3],
status=row[4], status = row[4],
approved_by=row[5], approved_by = row[5],
approved_at=row[6], approved_at = row[6],
expires_at=row[7], expires_at = row[7],
created_at=row[8], created_at = row[8],
) )
def reject_access_request(self, request_id: str, rejected_by: str) -> AccessRequest | None: def reject_access_request(self, request_id: str, rejected_by: str) -> AccessRequest | None:
@@ -1219,7 +1219,7 @@ class SecurityManager:
conn.commit() conn.commit()
cursor.execute("SELECT * FROM access_requests WHERE id = ?", (request_id,)) cursor.execute("SELECT * FROM access_requests WHERE id = ?", (request_id, ))
row = cursor.fetchone() row = cursor.fetchone()
conn.close() conn.close()
@@ -1227,15 +1227,15 @@ class SecurityManager:
return None return None
return AccessRequest( return AccessRequest(
id=row[0], id = row[0],
policy_id=row[1], policy_id = row[1],
user_id=row[2], user_id = row[2],
request_reason=row[3], request_reason = row[3],
status=row[4], status = row[4],
approved_by=row[5], approved_by = row[5],
approved_at=row[6], approved_at = row[6],
expires_at=row[7], expires_at = row[7],
created_at=row[8], created_at = row[8],
) )

View File

@@ -313,7 +313,7 @@ class SubscriptionManager:
"export": {"unit": "page", "price": 0.1, "free_quota": 100}, # 0.1元/页PDF导出 "export": {"unit": "page", "price": 0.1, "free_quota": 100}, # 0.1元/页PDF导出
} }
def __init__(self, db_path: str = "insightflow.db"): def __init__(self, db_path: str = "insightflow.db") -> None:
self.db_path = db_path self.db_path = db_path
self._init_db() self._init_db()
self._init_default_plans() self._init_default_plans()
@@ -572,7 +572,7 @@ class SubscriptionManager:
conn = self._get_connection() conn = self._get_connection()
try: try:
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute("SELECT * FROM subscription_plans WHERE id = ?", (plan_id,)) cursor.execute("SELECT * FROM subscription_plans WHERE id = ?", (plan_id, ))
row = cursor.fetchone() row = cursor.fetchone()
if row: if row:
@@ -588,7 +588,7 @@ class SubscriptionManager:
try: try:
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute( cursor.execute(
"SELECT * FROM subscription_plans WHERE tier = ? AND is_active = 1", (tier,) "SELECT * FROM subscription_plans WHERE tier = ? AND is_active = 1", (tier, )
) )
row = cursor.fetchone() row = cursor.fetchone()
@@ -635,19 +635,19 @@ class SubscriptionManager:
plan_id = str(uuid.uuid4()) plan_id = str(uuid.uuid4())
plan = SubscriptionPlan( plan = SubscriptionPlan(
id=plan_id, id = plan_id,
name=name, name = name,
tier=tier, tier = tier,
description=description, description = description,
price_monthly=price_monthly, price_monthly = price_monthly,
price_yearly=price_yearly, price_yearly = price_yearly,
currency=currency, currency = currency,
features=features or [], features = features or [],
limits=limits or {}, limits = limits or {},
is_active=True, is_active = True,
created_at=datetime.now(), created_at = datetime.now(),
updated_at=datetime.now(), updated_at = datetime.now(),
metadata={}, metadata = {},
) )
cursor = conn.cursor() cursor = conn.cursor()
@@ -760,7 +760,7 @@ class SubscriptionManager:
SELECT * FROM subscriptions SELECT * FROM subscriptions
WHERE tenant_id = ? AND status IN ('active', 'trial', 'pending') WHERE tenant_id = ? AND status IN ('active', 'trial', 'pending')
""", """,
(tenant_id,), (tenant_id, ),
) )
existing = cursor.fetchone() existing = cursor.fetchone()
@@ -777,36 +777,36 @@ class SubscriptionManager:
# 计算周期 # 计算周期
if billing_cycle == "yearly": if billing_cycle == "yearly":
period_end = now + timedelta(days=365) period_end = now + timedelta(days = 365)
else: else:
period_end = now + timedelta(days=30) period_end = now + timedelta(days = 30)
# 试用处理 # 试用处理
trial_start = None trial_start = None
trial_end = None trial_end = None
if trial_days > 0: if trial_days > 0:
trial_start = now trial_start = now
trial_end = now + timedelta(days=trial_days) trial_end = now + timedelta(days = trial_days)
status = SubscriptionStatus.TRIAL.value status = SubscriptionStatus.TRIAL.value
else: else:
status = SubscriptionStatus.PENDING.value status = SubscriptionStatus.PENDING.value
subscription = Subscription( subscription = Subscription(
id=subscription_id, id = subscription_id,
tenant_id=tenant_id, tenant_id = tenant_id,
plan_id=plan_id, plan_id = plan_id,
status=status, status = status,
current_period_start=now, current_period_start = now,
current_period_end=period_end, current_period_end = period_end,
cancel_at_period_end=False, cancel_at_period_end = False,
canceled_at=None, canceled_at = None,
trial_start=trial_start, trial_start = trial_start,
trial_end=trial_end, trial_end = trial_end,
payment_provider=payment_provider, payment_provider = payment_provider,
provider_subscription_id=None, provider_subscription_id = None,
created_at=now, created_at = now,
updated_at=now, updated_at = now,
metadata={"billing_cycle": billing_cycle}, metadata = {"billing_cycle": billing_cycle},
) )
cursor.execute( cursor.execute(
@@ -878,7 +878,7 @@ class SubscriptionManager:
conn = self._get_connection() conn = self._get_connection()
try: try:
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute("SELECT * FROM subscriptions WHERE id = ?", (subscription_id,)) cursor.execute("SELECT * FROM subscriptions WHERE id = ?", (subscription_id, ))
row = cursor.fetchone() row = cursor.fetchone()
if row: if row:
@@ -899,7 +899,7 @@ class SubscriptionManager:
WHERE tenant_id = ? AND status IN ('active', 'trial', 'past_due', 'pending') WHERE tenant_id = ? AND status IN ('active', 'trial', 'past_due', 'pending')
ORDER BY created_at DESC LIMIT 1 ORDER BY created_at DESC LIMIT 1
""", """,
(tenant_id,), (tenant_id, ),
) )
row = cursor.fetchone() row = cursor.fetchone()
@@ -1087,15 +1087,15 @@ class SubscriptionManager:
record_id = str(uuid.uuid4()) record_id = str(uuid.uuid4())
record = UsageRecord( record = UsageRecord(
id=record_id, id = record_id,
tenant_id=tenant_id, tenant_id = tenant_id,
resource_type=resource_type, resource_type = resource_type,
quantity=quantity, quantity = quantity,
unit=unit, unit = unit,
recorded_at=datetime.now(), recorded_at = datetime.now(),
cost=cost, cost = cost,
description=description, description = description,
metadata=metadata or {}, metadata = metadata or {},
) )
cursor = conn.cursor() cursor = conn.cursor()
@@ -1214,22 +1214,22 @@ class SubscriptionManager:
now = datetime.now() now = datetime.now()
payment = Payment( payment = Payment(
id=payment_id, id = payment_id,
tenant_id=tenant_id, tenant_id = tenant_id,
subscription_id=subscription_id, subscription_id = subscription_id,
invoice_id=invoice_id, invoice_id = invoice_id,
amount=amount, amount = amount,
currency=currency, currency = currency,
provider=provider, provider = provider,
provider_payment_id=None, provider_payment_id = None,
status=PaymentStatus.PENDING.value, status = PaymentStatus.PENDING.value,
payment_method=payment_method, payment_method = payment_method,
payment_details=payment_details or {}, payment_details = payment_details or {},
paid_at=None, paid_at = None,
failed_at=None, failed_at = None,
failure_reason=None, failure_reason = None,
created_at=now, created_at = now,
updated_at=now, updated_at = now,
) )
cursor = conn.cursor() cursor = conn.cursor()
@@ -1389,7 +1389,7 @@ class SubscriptionManager:
def _get_payment_internal(self, conn: sqlite3.Connection, payment_id: str) -> Payment | None: def _get_payment_internal(self, conn: sqlite3.Connection, payment_id: str) -> Payment | None:
"""内部方法:获取支付记录""" """内部方法:获取支付记录"""
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute("SELECT * FROM payments WHERE id = ?", (payment_id,)) cursor.execute("SELECT * FROM payments WHERE id = ?", (payment_id, ))
row = cursor.fetchone() row = cursor.fetchone()
if row: if row:
@@ -1414,27 +1414,27 @@ class SubscriptionManager:
invoice_id = str(uuid.uuid4()) invoice_id = str(uuid.uuid4())
invoice_number = self._generate_invoice_number() invoice_number = self._generate_invoice_number()
now = datetime.now() now = datetime.now()
due_date = now + timedelta(days=7) # 7天付款期限 due_date = now + timedelta(days = 7) # 7天付款期限
invoice = Invoice( invoice = Invoice(
id=invoice_id, id = invoice_id,
tenant_id=tenant_id, tenant_id = tenant_id,
subscription_id=subscription_id, subscription_id = subscription_id,
invoice_number=invoice_number, invoice_number = invoice_number,
status=InvoiceStatus.DRAFT.value, status = InvoiceStatus.DRAFT.value,
amount_due=amount, amount_due = amount,
amount_paid=0, amount_paid = 0,
currency=currency, currency = currency,
period_start=period_start, period_start = period_start,
period_end=period_end, period_end = period_end,
description=description, description = description,
line_items=line_items or [{"description": description, "amount": amount}], line_items = line_items or [{"description": description, "amount": amount}],
due_date=due_date, due_date = due_date,
paid_at=None, paid_at = None,
voided_at=None, voided_at = None,
void_reason=None, void_reason = None,
created_at=now, created_at = now,
updated_at=now, updated_at = now,
) )
cursor = conn.cursor() cursor = conn.cursor()
@@ -1475,7 +1475,7 @@ class SubscriptionManager:
conn = self._get_connection() conn = self._get_connection()
try: try:
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute("SELECT * FROM invoices WHERE id = ?", (invoice_id,)) cursor.execute("SELECT * FROM invoices WHERE id = ?", (invoice_id, ))
row = cursor.fetchone() row = cursor.fetchone()
if row: if row:
@@ -1490,7 +1490,7 @@ class SubscriptionManager:
conn = self._get_connection() conn = self._get_connection()
try: try:
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute("SELECT * FROM invoices WHERE invoice_number = ?", (invoice_number,)) cursor.execute("SELECT * FROM invoices WHERE invoice_number = ?", (invoice_number, ))
row = cursor.fetchone() row = cursor.fetchone()
if row: if row:
@@ -1568,7 +1568,7 @@ class SubscriptionManager:
SELECT COUNT(*) as count FROM invoices SELECT COUNT(*) as count FROM invoices
WHERE invoice_number LIKE ? WHERE invoice_number LIKE ?
""", """,
(f"{prefix}%",), (f"{prefix}%", ),
) )
row = cursor.fetchone() row = cursor.fetchone()
count = row["count"] + 1 count = row["count"] + 1
@@ -1604,23 +1604,23 @@ class SubscriptionManager:
now = datetime.now() now = datetime.now()
refund = Refund( refund = Refund(
id=refund_id, id = refund_id,
tenant_id=tenant_id, tenant_id = tenant_id,
payment_id=payment_id, payment_id = payment_id,
invoice_id=payment.invoice_id, invoice_id = payment.invoice_id,
amount=amount, amount = amount,
currency=payment.currency, currency = payment.currency,
reason=reason, reason = reason,
status=RefundStatus.PENDING.value, status = RefundStatus.PENDING.value,
requested_by=requested_by, requested_by = requested_by,
requested_at=now, requested_at = now,
approved_by=None, approved_by = None,
approved_at=None, approved_at = None,
completed_at=None, completed_at = None,
provider_refund_id=None, provider_refund_id = None,
metadata={}, metadata = {},
created_at=now, created_at = now,
updated_at=now, updated_at = now,
) )
cursor = conn.cursor() cursor = conn.cursor()
@@ -1803,7 +1803,7 @@ class SubscriptionManager:
def _get_refund_internal(self, conn: sqlite3.Connection, refund_id: str) -> Refund | None: def _get_refund_internal(self, conn: sqlite3.Connection, refund_id: str) -> Refund | None:
"""内部方法:获取退款记录""" """内部方法:获取退款记录"""
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute("SELECT * FROM refunds WHERE id = ?", (refund_id,)) cursor.execute("SELECT * FROM refunds WHERE id = ?", (refund_id, ))
row = cursor.fetchone() row = cursor.fetchone()
if row: if row:
@@ -1822,7 +1822,7 @@ class SubscriptionManager:
description: str, description: str,
reference_id: str, reference_id: str,
balance_after: float, balance_after: float,
): ) -> None:
"""内部方法:添加账单历史""" """内部方法:添加账单历史"""
history_id = str(uuid.uuid4()) history_id = str(uuid.uuid4())
@@ -1962,126 +1962,126 @@ class SubscriptionManager:
def _row_to_plan(self, row: sqlite3.Row) -> SubscriptionPlan: def _row_to_plan(self, row: sqlite3.Row) -> SubscriptionPlan:
"""数据库行转换为 SubscriptionPlan 对象""" """数据库行转换为 SubscriptionPlan 对象"""
return SubscriptionPlan( return SubscriptionPlan(
id=row["id"], id = row["id"],
name=row["name"], name = row["name"],
tier=row["tier"], tier = row["tier"],
description=row["description"] or "", description = row["description"] or "",
price_monthly=row["price_monthly"], price_monthly = row["price_monthly"],
price_yearly=row["price_yearly"], price_yearly = row["price_yearly"],
currency=row["currency"], currency = row["currency"],
features=json.loads(row["features"] or "[]"), features = json.loads(row["features"] or "[]"),
limits=json.loads(row["limits"] or "{}"), limits = json.loads(row["limits"] or "{}"),
is_active=bool(row["is_active"]), is_active = bool(row["is_active"]),
created_at=( created_at = (
datetime.fromisoformat(row["created_at"]) datetime.fromisoformat(row["created_at"])
if isinstance(row["created_at"], str) if isinstance(row["created_at"], str)
else row["created_at"] else row["created_at"]
), ),
updated_at=( updated_at = (
datetime.fromisoformat(row["updated_at"]) datetime.fromisoformat(row["updated_at"])
if isinstance(row["updated_at"], str) if isinstance(row["updated_at"], str)
else row["updated_at"] else row["updated_at"]
), ),
metadata=json.loads(row["metadata"] or "{}"), metadata = json.loads(row["metadata"] or "{}"),
) )
def _row_to_subscription(self, row: sqlite3.Row) -> Subscription: def _row_to_subscription(self, row: sqlite3.Row) -> Subscription:
"""数据库行转换为 Subscription 对象""" """数据库行转换为 Subscription 对象"""
return Subscription( return Subscription(
id=row["id"], id = row["id"],
tenant_id=row["tenant_id"], tenant_id = row["tenant_id"],
plan_id=row["plan_id"], plan_id = row["plan_id"],
status=row["status"], status = row["status"],
current_period_start=( current_period_start = (
datetime.fromisoformat(row["current_period_start"]) datetime.fromisoformat(row["current_period_start"])
if row["current_period_start"] and isinstance(row["current_period_start"], str) if row["current_period_start"] and isinstance(row["current_period_start"], str)
else row["current_period_start"] else row["current_period_start"]
), ),
current_period_end=( current_period_end = (
datetime.fromisoformat(row["current_period_end"]) datetime.fromisoformat(row["current_period_end"])
if row["current_period_end"] and isinstance(row["current_period_end"], str) if row["current_period_end"] and isinstance(row["current_period_end"], str)
else row["current_period_end"] else row["current_period_end"]
), ),
cancel_at_period_end=bool(row["cancel_at_period_end"]), cancel_at_period_end = bool(row["cancel_at_period_end"]),
canceled_at=( canceled_at = (
datetime.fromisoformat(row["canceled_at"]) datetime.fromisoformat(row["canceled_at"])
if row["canceled_at"] and isinstance(row["canceled_at"], str) if row["canceled_at"] and isinstance(row["canceled_at"], str)
else row["canceled_at"] else row["canceled_at"]
), ),
trial_start=( trial_start = (
datetime.fromisoformat(row["trial_start"]) datetime.fromisoformat(row["trial_start"])
if row["trial_start"] and isinstance(row["trial_start"], str) if row["trial_start"] and isinstance(row["trial_start"], str)
else row["trial_start"] else row["trial_start"]
), ),
trial_end=( trial_end = (
datetime.fromisoformat(row["trial_end"]) datetime.fromisoformat(row["trial_end"])
if row["trial_end"] and isinstance(row["trial_end"], str) if row["trial_end"] and isinstance(row["trial_end"], str)
else row["trial_end"] else row["trial_end"]
), ),
payment_provider=row["payment_provider"], payment_provider = row["payment_provider"],
provider_subscription_id=row["provider_subscription_id"], provider_subscription_id = row["provider_subscription_id"],
created_at=( created_at = (
datetime.fromisoformat(row["created_at"]) datetime.fromisoformat(row["created_at"])
if isinstance(row["created_at"], str) if isinstance(row["created_at"], str)
else row["created_at"] else row["created_at"]
), ),
updated_at=( updated_at = (
datetime.fromisoformat(row["updated_at"]) datetime.fromisoformat(row["updated_at"])
if isinstance(row["updated_at"], str) if isinstance(row["updated_at"], str)
else row["updated_at"] else row["updated_at"]
), ),
metadata=json.loads(row["metadata"] or "{}"), metadata = json.loads(row["metadata"] or "{}"),
) )
def _row_to_usage(self, row: sqlite3.Row) -> UsageRecord: def _row_to_usage(self, row: sqlite3.Row) -> UsageRecord:
"""数据库行转换为 UsageRecord 对象""" """数据库行转换为 UsageRecord 对象"""
return UsageRecord( return UsageRecord(
id=row["id"], id = row["id"],
tenant_id=row["tenant_id"], tenant_id = row["tenant_id"],
resource_type=row["resource_type"], resource_type = row["resource_type"],
quantity=row["quantity"], quantity = row["quantity"],
unit=row["unit"], unit = row["unit"],
recorded_at=( recorded_at = (
datetime.fromisoformat(row["recorded_at"]) datetime.fromisoformat(row["recorded_at"])
if isinstance(row["recorded_at"], str) if isinstance(row["recorded_at"], str)
else row["recorded_at"] else row["recorded_at"]
), ),
cost=row["cost"], cost = row["cost"],
description=row["description"], description = row["description"],
metadata=json.loads(row["metadata"] or "{}"), metadata = json.loads(row["metadata"] or "{}"),
) )
def _row_to_payment(self, row: sqlite3.Row) -> Payment: def _row_to_payment(self, row: sqlite3.Row) -> Payment:
"""数据库行转换为 Payment 对象""" """数据库行转换为 Payment 对象"""
return Payment( return Payment(
id=row["id"], id = row["id"],
tenant_id=row["tenant_id"], tenant_id = row["tenant_id"],
subscription_id=row["subscription_id"], subscription_id = row["subscription_id"],
invoice_id=row["invoice_id"], invoice_id = row["invoice_id"],
amount=row["amount"], amount = row["amount"],
currency=row["currency"], currency = row["currency"],
provider=row["provider"], provider = row["provider"],
provider_payment_id=row["provider_payment_id"], provider_payment_id = row["provider_payment_id"],
status=row["status"], status = row["status"],
payment_method=row["payment_method"], payment_method = row["payment_method"],
payment_details=json.loads(row["payment_details"] or "{}"), payment_details = json.loads(row["payment_details"] or "{}"),
paid_at=( paid_at = (
datetime.fromisoformat(row["paid_at"]) datetime.fromisoformat(row["paid_at"])
if row["paid_at"] and isinstance(row["paid_at"], str) if row["paid_at"] and isinstance(row["paid_at"], str)
else row["paid_at"] else row["paid_at"]
), ),
failed_at=( failed_at = (
datetime.fromisoformat(row["failed_at"]) datetime.fromisoformat(row["failed_at"])
if row["failed_at"] and isinstance(row["failed_at"], str) if row["failed_at"] and isinstance(row["failed_at"], str)
else row["failed_at"] else row["failed_at"]
), ),
failure_reason=row["failure_reason"], failure_reason = row["failure_reason"],
created_at=( created_at = (
datetime.fromisoformat(row["created_at"]) datetime.fromisoformat(row["created_at"])
if isinstance(row["created_at"], str) if isinstance(row["created_at"], str)
else row["created_at"] else row["created_at"]
), ),
updated_at=( updated_at = (
datetime.fromisoformat(row["updated_at"]) datetime.fromisoformat(row["updated_at"])
if isinstance(row["updated_at"], str) if isinstance(row["updated_at"], str)
else row["updated_at"] else row["updated_at"]
@@ -2091,48 +2091,48 @@ class SubscriptionManager:
def _row_to_invoice(self, row: sqlite3.Row) -> Invoice: def _row_to_invoice(self, row: sqlite3.Row) -> Invoice:
"""数据库行转换为 Invoice 对象""" """数据库行转换为 Invoice 对象"""
return Invoice( return Invoice(
id=row["id"], id = row["id"],
tenant_id=row["tenant_id"], tenant_id = row["tenant_id"],
subscription_id=row["subscription_id"], subscription_id = row["subscription_id"],
invoice_number=row["invoice_number"], invoice_number = row["invoice_number"],
status=row["status"], status = row["status"],
amount_due=row["amount_due"], amount_due = row["amount_due"],
amount_paid=row["amount_paid"], amount_paid = row["amount_paid"],
currency=row["currency"], currency = row["currency"],
period_start=( period_start = (
datetime.fromisoformat(row["period_start"]) datetime.fromisoformat(row["period_start"])
if row["period_start"] and isinstance(row["period_start"], str) if row["period_start"] and isinstance(row["period_start"], str)
else row["period_start"] else row["period_start"]
), ),
period_end=( period_end = (
datetime.fromisoformat(row["period_end"]) datetime.fromisoformat(row["period_end"])
if row["period_end"] and isinstance(row["period_end"], str) if row["period_end"] and isinstance(row["period_end"], str)
else row["period_end"] else row["period_end"]
), ),
description=row["description"], description = row["description"],
line_items=json.loads(row["line_items"] or "[]"), line_items = json.loads(row["line_items"] or "[]"),
due_date=( due_date = (
datetime.fromisoformat(row["due_date"]) datetime.fromisoformat(row["due_date"])
if row["due_date"] and isinstance(row["due_date"], str) if row["due_date"] and isinstance(row["due_date"], str)
else row["due_date"] else row["due_date"]
), ),
paid_at=( paid_at = (
datetime.fromisoformat(row["paid_at"]) datetime.fromisoformat(row["paid_at"])
if row["paid_at"] and isinstance(row["paid_at"], str) if row["paid_at"] and isinstance(row["paid_at"], str)
else row["paid_at"] else row["paid_at"]
), ),
voided_at=( voided_at = (
datetime.fromisoformat(row["voided_at"]) datetime.fromisoformat(row["voided_at"])
if row["voided_at"] and isinstance(row["voided_at"], str) if row["voided_at"] and isinstance(row["voided_at"], str)
else row["voided_at"] else row["voided_at"]
), ),
void_reason=row["void_reason"], void_reason = row["void_reason"],
created_at=( created_at = (
datetime.fromisoformat(row["created_at"]) datetime.fromisoformat(row["created_at"])
if isinstance(row["created_at"], str) if isinstance(row["created_at"], str)
else row["created_at"] else row["created_at"]
), ),
updated_at=( updated_at = (
datetime.fromisoformat(row["updated_at"]) datetime.fromisoformat(row["updated_at"])
if isinstance(row["updated_at"], str) if isinstance(row["updated_at"], str)
else row["updated_at"] else row["updated_at"]
@@ -2142,39 +2142,39 @@ class SubscriptionManager:
def _row_to_refund(self, row: sqlite3.Row) -> Refund: def _row_to_refund(self, row: sqlite3.Row) -> Refund:
"""数据库行转换为 Refund 对象""" """数据库行转换为 Refund 对象"""
return Refund( return Refund(
id=row["id"], id = row["id"],
tenant_id=row["tenant_id"], tenant_id = row["tenant_id"],
payment_id=row["payment_id"], payment_id = row["payment_id"],
invoice_id=row["invoice_id"], invoice_id = row["invoice_id"],
amount=row["amount"], amount = row["amount"],
currency=row["currency"], currency = row["currency"],
reason=row["reason"], reason = row["reason"],
status=row["status"], status = row["status"],
requested_by=row["requested_by"], requested_by = row["requested_by"],
requested_at=( requested_at = (
datetime.fromisoformat(row["requested_at"]) datetime.fromisoformat(row["requested_at"])
if isinstance(row["requested_at"], str) if isinstance(row["requested_at"], str)
else row["requested_at"] else row["requested_at"]
), ),
approved_by=row["approved_by"], approved_by = row["approved_by"],
approved_at=( approved_at = (
datetime.fromisoformat(row["approved_at"]) datetime.fromisoformat(row["approved_at"])
if row["approved_at"] and isinstance(row["approved_at"], str) if row["approved_at"] and isinstance(row["approved_at"], str)
else row["approved_at"] else row["approved_at"]
), ),
completed_at=( completed_at = (
datetime.fromisoformat(row["completed_at"]) datetime.fromisoformat(row["completed_at"])
if row["completed_at"] and isinstance(row["completed_at"], str) if row["completed_at"] and isinstance(row["completed_at"], str)
else row["completed_at"] else row["completed_at"]
), ),
provider_refund_id=row["provider_refund_id"], provider_refund_id = row["provider_refund_id"],
metadata=json.loads(row["metadata"] or "{}"), metadata = json.loads(row["metadata"] or "{}"),
created_at=( created_at = (
datetime.fromisoformat(row["created_at"]) datetime.fromisoformat(row["created_at"])
if isinstance(row["created_at"], str) if isinstance(row["created_at"], str)
else row["created_at"] else row["created_at"]
), ),
updated_at=( updated_at = (
datetime.fromisoformat(row["updated_at"]) datetime.fromisoformat(row["updated_at"])
if isinstance(row["updated_at"], str) if isinstance(row["updated_at"], str)
else row["updated_at"] else row["updated_at"]
@@ -2184,20 +2184,20 @@ class SubscriptionManager:
def _row_to_billing_history(self, row: sqlite3.Row) -> BillingHistory: def _row_to_billing_history(self, row: sqlite3.Row) -> BillingHistory:
"""数据库行转换为 BillingHistory 对象""" """数据库行转换为 BillingHistory 对象"""
return BillingHistory( return BillingHistory(
id=row["id"], id = row["id"],
tenant_id=row["tenant_id"], tenant_id = row["tenant_id"],
type=row["type"], type = row["type"],
amount=row["amount"], amount = row["amount"],
currency=row["currency"], currency = row["currency"],
description=row["description"], description = row["description"],
reference_id=row["reference_id"], reference_id = row["reference_id"],
balance_after=row["balance_after"], balance_after = row["balance_after"],
created_at=( created_at = (
datetime.fromisoformat(row["created_at"]) datetime.fromisoformat(row["created_at"])
if isinstance(row["created_at"], str) if isinstance(row["created_at"], str)
else row["created_at"] else row["created_at"]
), ),
metadata=json.loads(row["metadata"] or "{}"), metadata = json.loads(row["metadata"] or "{}"),
) )

View File

@@ -257,7 +257,7 @@ class TenantManager:
"export:basic": "基础导出", "export:basic": "基础导出",
} }
def __init__(self, db_path: str = "insightflow.db"): def __init__(self, db_path: str = "insightflow.db") -> None:
self.db_path = db_path self.db_path = db_path
self._init_db() self._init_db()
@@ -437,19 +437,19 @@ class TenantManager:
) )
tenant = Tenant( tenant = Tenant(
id=tenant_id, id = tenant_id,
name=name, name = name,
slug=slug, slug = slug,
description=description, description = description,
tier=tier, tier = tier,
status=TenantStatus.PENDING.value, status = TenantStatus.PENDING.value,
owner_id=owner_id, owner_id = owner_id,
created_at=datetime.now(), created_at = datetime.now(),
updated_at=datetime.now(), updated_at = datetime.now(),
expires_at=None, expires_at = None,
settings=settings or {}, settings = settings or {},
resource_limits=resource_limits, resource_limits = resource_limits,
metadata={}, metadata = {},
) )
cursor = conn.cursor() cursor = conn.cursor()
@@ -495,7 +495,7 @@ class TenantManager:
conn = self._get_connection() conn = self._get_connection()
try: try:
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute("SELECT * FROM tenants WHERE id = ?", (tenant_id,)) cursor.execute("SELECT * FROM tenants WHERE id = ?", (tenant_id, ))
row = cursor.fetchone() row = cursor.fetchone()
if row: if row:
@@ -510,7 +510,7 @@ class TenantManager:
conn = self._get_connection() conn = self._get_connection()
try: try:
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute("SELECT * FROM tenants WHERE slug = ?", (slug,)) cursor.execute("SELECT * FROM tenants WHERE slug = ?", (slug, ))
row = cursor.fetchone() row = cursor.fetchone()
if row: if row:
@@ -531,7 +531,7 @@ class TenantManager:
JOIN tenant_domains d ON t.id = d.tenant_id JOIN tenant_domains d ON t.id = d.tenant_id
WHERE d.domain = ? AND d.status = 'verified' WHERE d.domain = ? AND d.status = 'verified'
""", """,
(domain,), (domain, ),
) )
row = cursor.fetchone() row = cursor.fetchone()
@@ -605,7 +605,7 @@ class TenantManager:
conn = self._get_connection() conn = self._get_connection()
try: try:
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute("DELETE FROM tenants WHERE id = ?", (tenant_id,)) cursor.execute("DELETE FROM tenants WHERE id = ?", (tenant_id, ))
conn.commit() conn.commit()
return cursor.rowcount > 0 return cursor.rowcount > 0
finally: finally:
@@ -619,7 +619,7 @@ class TenantManager:
try: try:
cursor = conn.cursor() cursor = conn.cursor()
query = "SELECT * FROM tenants WHERE 1=1" query = "SELECT * FROM tenants WHERE 1 = 1"
params = [] params = []
if status: if status:
@@ -661,18 +661,18 @@ class TenantManager:
domain_id = str(uuid.uuid4()) domain_id = str(uuid.uuid4())
tenant_domain = TenantDomain( tenant_domain = TenantDomain(
id=domain_id, id = domain_id,
tenant_id=tenant_id, tenant_id = tenant_id,
domain=domain.lower(), domain = domain.lower(),
status=DomainStatus.PENDING.value, status = DomainStatus.PENDING.value,
verification_token=verification_token, verification_token = verification_token,
verification_method=verification_method, verification_method = verification_method,
verified_at=None, verified_at = None,
created_at=datetime.now(), created_at = datetime.now(),
updated_at=datetime.now(), updated_at = datetime.now(),
is_primary=is_primary, is_primary = is_primary,
ssl_enabled=False, ssl_enabled = False,
ssl_expires_at=None, ssl_expires_at = None,
) )
cursor = conn.cursor() cursor = conn.cursor()
@@ -684,7 +684,7 @@ class TenantManager:
UPDATE tenant_domains SET is_primary = 0 UPDATE tenant_domains SET is_primary = 0
WHERE tenant_id = ? WHERE tenant_id = ?
""", """,
(tenant_id,), (tenant_id, ),
) )
cursor.execute( cursor.execute(
@@ -782,7 +782,7 @@ class TenantManager:
conn = self._get_connection() conn = self._get_connection()
try: try:
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute("SELECT * FROM tenant_domains WHERE id = ?", (domain_id,)) cursor.execute("SELECT * FROM tenant_domains WHERE id = ?", (domain_id, ))
row = cursor.fetchone() row = cursor.fetchone()
if not row: if not row:
@@ -797,7 +797,7 @@ class TenantManager:
"dns_record": { "dns_record": {
"type": "TXT", "type": "TXT",
"name": "_insightflow", "name": "_insightflow",
"value": f"insightflow-verify={token}", "value": f"insightflow-verify = {token}",
"ttl": 3600, "ttl": 3600,
}, },
"file_verification": { "file_verification": {
@@ -805,7 +805,7 @@ class TenantManager:
"content": token, "content": token,
}, },
"instructions": [ "instructions": [
f"DNS 验证: 添加 TXT 记录 _insightflow.{domain},值为 insightflow-verify={token}", f"DNS 验证: 添加 TXT 记录 _insightflow.{domain},值为 insightflow-verify = {token}",
f"文件验证: 在网站根目录创建 .well-known/insightflow-verify.txt内容为 {token}", f"文件验证: 在网站根目录创建 .well-known/insightflow-verify.txt内容为 {token}",
], ],
} }
@@ -841,7 +841,7 @@ class TenantManager:
WHERE tenant_id = ? WHERE tenant_id = ?
ORDER BY is_primary DESC, created_at DESC ORDER BY is_primary DESC, created_at DESC
""", """,
(tenant_id,), (tenant_id, ),
) )
rows = cursor.fetchall() rows = cursor.fetchall()
@@ -857,7 +857,7 @@ class TenantManager:
conn = self._get_connection() conn = self._get_connection()
try: try:
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute("SELECT * FROM tenant_branding WHERE tenant_id = ?", (tenant_id,)) cursor.execute("SELECT * FROM tenant_branding WHERE tenant_id = ?", (tenant_id, ))
row = cursor.fetchone() row = cursor.fetchone()
if row: if row:
@@ -885,7 +885,7 @@ class TenantManager:
cursor = conn.cursor() cursor = conn.cursor()
# 检查是否已存在 # 检查是否已存在
cursor.execute("SELECT id FROM tenant_branding WHERE tenant_id = ?", (tenant_id,)) cursor.execute("SELECT id FROM tenant_branding WHERE tenant_id = ?", (tenant_id, ))
existing = cursor.fetchone() existing = cursor.fetchone()
if existing: if existing:
@@ -1022,17 +1022,17 @@ class TenantManager:
final_permissions = permissions or default_permissions final_permissions = permissions or default_permissions
member = TenantMember( member = TenantMember(
id=member_id, id = member_id,
tenant_id=tenant_id, tenant_id = tenant_id,
user_id="pending", # 临时值,待用户接受邀请后更新 user_id = "pending", # 临时值,待用户接受邀请后更新
email=email, email = email,
role=role, role = role,
permissions=final_permissions, permissions = final_permissions,
invited_by=invited_by, invited_by = invited_by,
invited_at=datetime.now(), invited_at = datetime.now(),
joined_at=None, joined_at = None,
last_active_at=None, last_active_at = None,
status="pending", status = "pending",
) )
cursor = conn.cursor() cursor = conn.cursor()
@@ -1197,7 +1197,7 @@ class TenantManager:
WHERE m.user_id = ? AND m.status = 'active' WHERE m.user_id = ? AND m.status = 'active'
ORDER BY t.created_at DESC ORDER BY t.created_at DESC
""", """,
(user_id,), (user_id, ),
) )
rows = cursor.fetchall() rows = cursor.fetchall()
@@ -1227,7 +1227,7 @@ class TenantManager:
projects_count: int = 0, projects_count: int = 0,
entities_count: int = 0, entities_count: int = 0,
members_count: int = 0, members_count: int = 0,
): ) -> None:
"""记录资源使用""" """记录资源使用"""
conn = self._get_connection() conn = self._get_connection()
try: try:
@@ -1388,7 +1388,7 @@ class TenantManager:
counter = 1 counter = 1
while True: while True:
cursor.execute("SELECT id FROM tenants WHERE slug = ?", (slug,)) cursor.execute("SELECT id FROM tenants WHERE slug = ?", (slug, ))
if not cursor.fetchone(): if not cursor.fetchone():
break break
slug = f"{base_slug}-{counter}" slug = f"{base_slug}-{counter}"
@@ -1406,7 +1406,7 @@ class TenantManager:
def _validate_domain(self, domain: str) -> bool: def _validate_domain(self, domain: str) -> bool:
"""验证域名格式""" """验证域名格式"""
pattern = r"^(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)*[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])$" pattern = r"^(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0, 61}[a-zA-Z0-9])?\.)*[a-zA-Z0-9](?:[a-zA-Z0-9-]{0, 61}[a-zA-Z0-9])$"
return bool(re.match(pattern, domain)) return bool(re.match(pattern, domain))
def _check_domain_verification(self, domain: str, token: str, method: str) -> bool: def _check_domain_verification(self, domain: str, token: str, method: str) -> bool:
@@ -1431,7 +1431,7 @@ class TenantManager:
# TODO: 实现 HTTP 文件验证 # TODO: 实现 HTTP 文件验证
# import requests # import requests
# try: # try:
# response = requests.get(f"http://{domain}/.well-known/insightflow-verify.txt", timeout=10) # response = requests.get(f"http://{domain}/.well-known/insightflow-verify.txt", timeout = 10)
# if response.status_code == 200 and token in response.text: # if response.status_code == 200 and token in response.text:
# return True # return True
# except (ImportError, Exception): # except (ImportError, Exception):
@@ -1467,7 +1467,7 @@ class TenantManager:
email: str, email: str,
role: TenantRole, role: TenantRole,
invited_by: str | None, invited_by: str | None,
): ) -> None:
"""内部方法:添加成员""" """内部方法:添加成员"""
cursor = conn.cursor() cursor = conn.cursor()
member_id = str(uuid.uuid4()) member_id = str(uuid.uuid4())
@@ -1497,60 +1497,60 @@ class TenantManager:
def _row_to_tenant(self, row: sqlite3.Row) -> Tenant: def _row_to_tenant(self, row: sqlite3.Row) -> Tenant:
"""数据库行转换为 Tenant 对象""" """数据库行转换为 Tenant 对象"""
return Tenant( return Tenant(
id=row["id"], id = row["id"],
name=row["name"], name = row["name"],
slug=row["slug"], slug = row["slug"],
description=row["description"], description = row["description"],
tier=row["tier"], tier = row["tier"],
status=row["status"], status = row["status"],
owner_id=row["owner_id"], owner_id = row["owner_id"],
created_at=( created_at = (
datetime.fromisoformat(row["created_at"]) datetime.fromisoformat(row["created_at"])
if isinstance(row["created_at"], str) if isinstance(row["created_at"], str)
else row["created_at"] else row["created_at"]
), ),
updated_at=( updated_at = (
datetime.fromisoformat(row["updated_at"]) datetime.fromisoformat(row["updated_at"])
if isinstance(row["updated_at"], str) if isinstance(row["updated_at"], str)
else row["updated_at"] else row["updated_at"]
), ),
expires_at=( expires_at = (
datetime.fromisoformat(row["expires_at"]) datetime.fromisoformat(row["expires_at"])
if row["expires_at"] and isinstance(row["expires_at"], str) if row["expires_at"] and isinstance(row["expires_at"], str)
else row["expires_at"] else row["expires_at"]
), ),
settings=json.loads(row["settings"] or "{}"), settings = json.loads(row["settings"] or "{}"),
resource_limits=json.loads(row["resource_limits"] or "{}"), resource_limits = json.loads(row["resource_limits"] or "{}"),
metadata=json.loads(row["metadata"] or "{}"), metadata = json.loads(row["metadata"] or "{}"),
) )
def _row_to_domain(self, row: sqlite3.Row) -> TenantDomain: def _row_to_domain(self, row: sqlite3.Row) -> TenantDomain:
"""数据库行转换为 TenantDomain 对象""" """数据库行转换为 TenantDomain 对象"""
return TenantDomain( return TenantDomain(
id=row["id"], id = row["id"],
tenant_id=row["tenant_id"], tenant_id = row["tenant_id"],
domain=row["domain"], domain = row["domain"],
status=row["status"], status = row["status"],
verification_token=row["verification_token"], verification_token = row["verification_token"],
verification_method=row["verification_method"], verification_method = row["verification_method"],
verified_at=( verified_at = (
datetime.fromisoformat(row["verified_at"]) datetime.fromisoformat(row["verified_at"])
if row["verified_at"] and isinstance(row["verified_at"], str) if row["verified_at"] and isinstance(row["verified_at"], str)
else row["verified_at"] else row["verified_at"]
), ),
created_at=( created_at = (
datetime.fromisoformat(row["created_at"]) datetime.fromisoformat(row["created_at"])
if isinstance(row["created_at"], str) if isinstance(row["created_at"], str)
else row["created_at"] else row["created_at"]
), ),
updated_at=( updated_at = (
datetime.fromisoformat(row["updated_at"]) datetime.fromisoformat(row["updated_at"])
if isinstance(row["updated_at"], str) if isinstance(row["updated_at"], str)
else row["updated_at"] else row["updated_at"]
), ),
is_primary=bool(row["is_primary"]), is_primary = bool(row["is_primary"]),
ssl_enabled=bool(row["ssl_enabled"]), ssl_enabled = bool(row["ssl_enabled"]),
ssl_expires_at=( ssl_expires_at = (
datetime.fromisoformat(row["ssl_expires_at"]) datetime.fromisoformat(row["ssl_expires_at"])
if row["ssl_expires_at"] and isinstance(row["ssl_expires_at"], str) if row["ssl_expires_at"] and isinstance(row["ssl_expires_at"], str)
else row["ssl_expires_at"] else row["ssl_expires_at"]
@@ -1560,22 +1560,22 @@ class TenantManager:
def _row_to_branding(self, row: sqlite3.Row) -> TenantBranding: def _row_to_branding(self, row: sqlite3.Row) -> TenantBranding:
"""数据库行转换为 TenantBranding 对象""" """数据库行转换为 TenantBranding 对象"""
return TenantBranding( return TenantBranding(
id=row["id"], id = row["id"],
tenant_id=row["tenant_id"], tenant_id = row["tenant_id"],
logo_url=row["logo_url"], logo_url = row["logo_url"],
favicon_url=row["favicon_url"], favicon_url = row["favicon_url"],
primary_color=row["primary_color"], primary_color = row["primary_color"],
secondary_color=row["secondary_color"], secondary_color = row["secondary_color"],
custom_css=row["custom_css"], custom_css = row["custom_css"],
custom_js=row["custom_js"], custom_js = row["custom_js"],
login_page_bg=row["login_page_bg"], login_page_bg = row["login_page_bg"],
email_template=row["email_template"], email_template = row["email_template"],
created_at=( created_at = (
datetime.fromisoformat(row["created_at"]) datetime.fromisoformat(row["created_at"])
if isinstance(row["created_at"], str) if isinstance(row["created_at"], str)
else row["created_at"] else row["created_at"]
), ),
updated_at=( updated_at = (
datetime.fromisoformat(row["updated_at"]) datetime.fromisoformat(row["updated_at"])
if isinstance(row["updated_at"], str) if isinstance(row["updated_at"], str)
else row["updated_at"] else row["updated_at"]
@@ -1585,29 +1585,29 @@ class TenantManager:
def _row_to_member(self, row: sqlite3.Row) -> TenantMember: def _row_to_member(self, row: sqlite3.Row) -> TenantMember:
"""数据库行转换为 TenantMember 对象""" """数据库行转换为 TenantMember 对象"""
return TenantMember( return TenantMember(
id=row["id"], id = row["id"],
tenant_id=row["tenant_id"], tenant_id = row["tenant_id"],
user_id=row["user_id"], user_id = row["user_id"],
email=row["email"], email = row["email"],
role=row["role"], role = row["role"],
permissions=json.loads(row["permissions"] or "[]"), permissions = json.loads(row["permissions"] or "[]"),
invited_by=row["invited_by"], invited_by = row["invited_by"],
invited_at=( invited_at = (
datetime.fromisoformat(row["invited_at"]) datetime.fromisoformat(row["invited_at"])
if isinstance(row["invited_at"], str) if isinstance(row["invited_at"], str)
else row["invited_at"] else row["invited_at"]
), ),
joined_at=( joined_at = (
datetime.fromisoformat(row["joined_at"]) datetime.fromisoformat(row["joined_at"])
if row["joined_at"] and isinstance(row["joined_at"], str) if row["joined_at"] and isinstance(row["joined_at"], str)
else row["joined_at"] else row["joined_at"]
), ),
last_active_at=( last_active_at = (
datetime.fromisoformat(row["last_active_at"]) datetime.fromisoformat(row["last_active_at"])
if row["last_active_at"] and isinstance(row["last_active_at"], str) if row["last_active_at"] and isinstance(row["last_active_at"], str)
else row["last_active_at"] else row["last_active_at"]
), ),
status=row["status"], status = row["status"],
) )

View File

@@ -10,9 +10,9 @@ import sys
# 添加 backend 目录到路径 # 添加 backend 目录到路径
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
print("=" * 60) print(" = " * 60)
print("InsightFlow 多模态模块测试") print("InsightFlow 多模态模块测试")
print("=" * 60) print(" = " * 60)
# 测试导入 # 测试导入
print("\n1. 测试模块导入...") print("\n1. 测试模块导入...")
@@ -147,6 +147,6 @@ try:
except Exception as e: except Exception as e:
print(f" ✗ 数据库多模态方法测试失败: {e}") print(f" ✗ 数据库多模态方法测试失败: {e}")
print("\n" + "=" * 60) print("\n" + " = " * 60)
print("测试完成") print("测试完成")
print("=" * 60) print(" = " * 60)

View File

@@ -21,27 +21,27 @@ from search_manager import (
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
def test_fulltext_search(): def test_fulltext_search() -> None:
"""测试全文搜索""" """测试全文搜索"""
print("\n" + "=" * 60) print("\n" + " = " * 60)
print("测试全文搜索 (FullTextSearch)") print("测试全文搜索 (FullTextSearch)")
print("=" * 60) print(" = " * 60)
search = FullTextSearch() search = FullTextSearch()
# 测试索引创建 # 测试索引创建
print("\n1. 测试索引创建...") print("\n1. 测试索引创建...")
success = search.index_content( success = search.index_content(
content_id="test_entity_1", content_id = "test_entity_1",
content_type="entity", content_type = "entity",
project_id="test_project", project_id = "test_project",
text="这是一个测试实体,用于验证全文搜索功能。支持关键词高亮显示。", text = "这是一个测试实体,用于验证全文搜索功能。支持关键词高亮显示。",
) )
print(f" 索引创建: {'✓ 成功' if success else '✗ 失败'}") print(f" 索引创建: {'✓ 成功' if success else '✗ 失败'}")
# 测试搜索 # 测试搜索
print("\n2. 测试关键词搜索...") print("\n2. 测试关键词搜索...")
results = search.search("测试", project_id="test_project") results = search.search("测试", project_id = "test_project")
print(f" 搜索结果数量: {len(results)}") print(f" 搜索结果数量: {len(results)}")
if results: if results:
print(f" 第一个结果: {results[0].content[:50]}...") print(f" 第一个结果: {results[0].content[:50]}...")
@@ -49,10 +49,10 @@ def test_fulltext_search():
# 测试布尔搜索 # 测试布尔搜索
print("\n3. 测试布尔搜索...") print("\n3. 测试布尔搜索...")
results = search.search("测试 AND 全文", project_id="test_project") results = search.search("测试 AND 全文", project_id = "test_project")
print(f" AND 搜索结果: {len(results)}") print(f" AND 搜索结果: {len(results)}")
results = search.search("测试 OR 关键词", project_id="test_project") results = search.search("测试 OR 关键词", project_id = "test_project")
print(f" OR 搜索结果: {len(results)}") print(f" OR 搜索结果: {len(results)}")
# 测试高亮 # 测试高亮
@@ -64,11 +64,11 @@ def test_fulltext_search():
return True return True
def test_semantic_search(): def test_semantic_search() -> None:
"""测试语义搜索""" """测试语义搜索"""
print("\n" + "=" * 60) print("\n" + " = " * 60)
print("测试语义搜索 (SemanticSearch)") print("测试语义搜索 (SemanticSearch)")
print("=" * 60) print(" = " * 60)
semantic = SemanticSearch() semantic = SemanticSearch()
@@ -89,10 +89,10 @@ def test_semantic_search():
# 测试索引 # 测试索引
print("\n3. 测试语义索引...") print("\n3. 测试语义索引...")
success = semantic.index_embedding( success = semantic.index_embedding(
content_id="test_content_1", content_id = "test_content_1",
content_type="transcript", content_type = "transcript",
project_id="test_project", project_id = "test_project",
text="这是用于语义搜索测试的文本内容。", text = "这是用于语义搜索测试的文本内容。",
) )
print(f" 索引创建: {'✓ 成功' if success else '✗ 失败'}") print(f" 索引创建: {'✓ 成功' if success else '✗ 失败'}")
@@ -100,11 +100,11 @@ def test_semantic_search():
return True return True
def test_entity_path_discovery(): def test_entity_path_discovery() -> None:
"""测试实体路径发现""" """测试实体路径发现"""
print("\n" + "=" * 60) print("\n" + " = " * 60)
print("测试实体路径发现 (EntityPathDiscovery)") print("测试实体路径发现 (EntityPathDiscovery)")
print("=" * 60) print(" = " * 60)
discovery = EntityPathDiscovery() discovery = EntityPathDiscovery()
@@ -119,11 +119,11 @@ def test_entity_path_discovery():
return True return True
def test_knowledge_gap_detection(): def test_knowledge_gap_detection() -> None:
"""测试知识缺口识别""" """测试知识缺口识别"""
print("\n" + "=" * 60) print("\n" + " = " * 60)
print("测试知识缺口识别 (KnowledgeGapDetection)") print("测试知识缺口识别 (KnowledgeGapDetection)")
print("=" * 60) print(" = " * 60)
detection = KnowledgeGapDetection() detection = KnowledgeGapDetection()
@@ -138,11 +138,11 @@ def test_knowledge_gap_detection():
return True return True
def test_cache_manager(): def test_cache_manager() -> None:
"""测试缓存管理器""" """测试缓存管理器"""
print("\n" + "=" * 60) print("\n" + " = " * 60)
print("测试缓存管理器 (CacheManager)") print("测试缓存管理器 (CacheManager)")
print("=" * 60) print(" = " * 60)
cache = CacheManager() cache = CacheManager()
@@ -150,7 +150,7 @@ def test_cache_manager():
print("\n2. 测试缓存操作...") print("\n2. 测试缓存操作...")
# 设置缓存 # 设置缓存
cache.set("test_key_1", {"name": "测试数据", "value": 123}, ttl=60) cache.set("test_key_1", {"name": "测试数据", "value": 123}, ttl = 60)
print(" ✓ 设置缓存 test_key_1") print(" ✓ 设置缓存 test_key_1")
# 获取缓存 # 获取缓存
@@ -159,7 +159,7 @@ def test_cache_manager():
# 批量操作 # 批量操作
cache.set_many( cache.set_many(
{"batch_key_1": "value1", "batch_key_2": "value2", "batch_key_3": "value3"}, ttl=60 {"batch_key_1": "value1", "batch_key_2": "value2", "batch_key_3": "value3"}, ttl = 60
) )
print(" ✓ 批量设置缓存") print(" ✓ 批量设置缓存")
@@ -186,11 +186,11 @@ def test_cache_manager():
return True return True
def test_task_queue(): def test_task_queue() -> None:
"""测试任务队列""" """测试任务队列"""
print("\n" + "=" * 60) print("\n" + " = " * 60)
print("测试任务队列 (TaskQueue)") print("测试任务队列 (TaskQueue)")
print("=" * 60) print(" = " * 60)
queue = TaskQueue() queue = TaskQueue()
@@ -200,7 +200,7 @@ def test_task_queue():
print("\n2. 测试任务提交...") print("\n2. 测试任务提交...")
# 定义测试任务处理器 # 定义测试任务处理器
def test_task_handler(payload): def test_task_handler(payload) -> None:
print(f" 执行任务: {payload}") print(f" 执行任务: {payload}")
return {"status": "success", "processed": True} return {"status": "success", "processed": True}
@@ -208,7 +208,7 @@ def test_task_queue():
# 提交任务 # 提交任务
task_id = queue.submit( task_id = queue.submit(
task_type="test_task", payload={"test": "data", "timestamp": time.time()} task_type = "test_task", payload = {"test": "data", "timestamp": time.time()}
) )
print(" ✓ 提交任务: {task_id}") print(" ✓ 提交任务: {task_id}")
@@ -227,11 +227,11 @@ def test_task_queue():
return True return True
def test_performance_monitor(): def test_performance_monitor() -> None:
"""测试性能监控""" """测试性能监控"""
print("\n" + "=" * 60) print("\n" + " = " * 60)
print("测试性能监控 (PerformanceMonitor)") print("测试性能监控 (PerformanceMonitor)")
print("=" * 60) print(" = " * 60)
monitor = PerformanceMonitor() monitor = PerformanceMonitor()
@@ -240,25 +240,25 @@ def test_performance_monitor():
# 记录一些测试指标 # 记录一些测试指标
for i in range(5): for i in range(5):
monitor.record_metric( monitor.record_metric(
metric_type="api_response", metric_type = "api_response",
duration_ms=50 + i * 10, duration_ms = 50 + i * 10,
endpoint="/api/v1/test", endpoint = "/api/v1/test",
metadata={"test": True}, metadata = {"test": True},
) )
for i in range(3): for i in range(3):
monitor.record_metric( monitor.record_metric(
metric_type="db_query", metric_type = "db_query",
duration_ms=20 + i * 5, duration_ms = 20 + i * 5,
endpoint="SELECT test", endpoint = "SELECT test",
metadata={"test": True}, metadata = {"test": True},
) )
print(" ✓ 记录了 8 个测试指标") print(" ✓ 记录了 8 个测试指标")
# 获取统计 # 获取统计
print("\n2. 获取性能统计...") print("\n2. 获取性能统计...")
stats = monitor.get_stats(hours=1) stats = monitor.get_stats(hours = 1)
print(f" 总请求数: {stats['overall']['total_requests']}") print(f" 总请求数: {stats['overall']['total_requests']}")
print(f" 平均响应时间: {stats['overall']['avg_duration_ms']} ms") print(f" 平均响应时间: {stats['overall']['avg_duration_ms']} ms")
print(f" 最大响应时间: {stats['overall']['max_duration_ms']} ms") print(f" 最大响应时间: {stats['overall']['max_duration_ms']} ms")
@@ -274,11 +274,11 @@ def test_performance_monitor():
return True return True
def test_search_manager(): def test_search_manager() -> None:
"""测试搜索管理器""" """测试搜索管理器"""
print("\n" + "=" * 60) print("\n" + " = " * 60)
print("测试搜索管理器 (SearchManager)") print("测试搜索管理器 (SearchManager)")
print("=" * 60) print(" = " * 60)
manager = get_search_manager() manager = get_search_manager()
@@ -295,11 +295,11 @@ def test_search_manager():
return True return True
def test_performance_manager(): def test_performance_manager() -> None:
"""测试性能管理器""" """测试性能管理器"""
print("\n" + "=" * 60) print("\n" + " = " * 60)
print("测试性能管理器 (PerformanceManager)") print("测试性能管理器 (PerformanceManager)")
print("=" * 60) print(" = " * 60)
manager = get_performance_manager() manager = get_performance_manager()
@@ -320,12 +320,12 @@ def test_performance_manager():
return True return True
def run_all_tests(): def run_all_tests() -> None:
"""运行所有测试""" """运行所有测试"""
print("\n" + "=" * 60) print("\n" + " = " * 60)
print("InsightFlow Phase 7 Task 6 & 8 测试") print("InsightFlow Phase 7 Task 6 & 8 测试")
print("高级搜索与发现 + 性能优化与扩展") print("高级搜索与发现 + 性能优化与扩展")
print("=" * 60) print(" = " * 60)
results = [] results = []
@@ -386,9 +386,9 @@ def run_all_tests():
results.append(("性能管理器", False)) results.append(("性能管理器", False))
# 打印测试汇总 # 打印测试汇总
print("\n" + "=" * 60) print("\n" + " = " * 60)
print("测试汇总") print("测试汇总")
print("=" * 60) print(" = " * 60)
passed = sum(1 for _, result in results if result) passed = sum(1 for _, result in results if result)
total = len(results) total = len(results)

View File

@@ -18,18 +18,18 @@ from tenant_manager import get_tenant_manager
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
def test_tenant_management(): def test_tenant_management() -> None:
"""测试租户管理功能""" """测试租户管理功能"""
print("=" * 60) print(" = " * 60)
print("测试 1: 租户管理") print("测试 1: 租户管理")
print("=" * 60) print(" = " * 60)
manager = get_tenant_manager() manager = get_tenant_manager()
# 1. 创建租户 # 1. 创建租户
print("\n1.1 创建租户...") print("\n1.1 创建租户...")
tenant = manager.create_tenant( tenant = manager.create_tenant(
name="Test Company", owner_id="user_001", tier="pro", description="A test company tenant" name = "Test Company", owner_id = "user_001", tier = "pro", description = "A test company tenant"
) )
print(f"✅ 租户创建成功: {tenant.id}") print(f"✅ 租户创建成功: {tenant.id}")
print(f" - 名称: {tenant.name}") print(f" - 名称: {tenant.name}")
@@ -53,30 +53,30 @@ def test_tenant_management():
# 4. 更新租户 # 4. 更新租户
print("\n1.4 更新租户信息...") print("\n1.4 更新租户信息...")
updated = manager.update_tenant( updated = manager.update_tenant(
tenant_id=tenant.id, name="Test Company Updated", tier="enterprise" tenant_id = tenant.id, name = "Test Company Updated", tier = "enterprise"
) )
assert updated is not None, "更新租户失败" assert updated is not None, "更新租户失败"
print(f"✅ 租户更新成功: {updated.name}, 层级: {updated.tier}") print(f"✅ 租户更新成功: {updated.name}, 层级: {updated.tier}")
# 5. 列出租户 # 5. 列出租户
print("\n1.5 列出租户...") print("\n1.5 列出租户...")
tenants = manager.list_tenants(limit=10) tenants = manager.list_tenants(limit = 10)
print(f"✅ 找到 {len(tenants)} 个租户") print(f"✅ 找到 {len(tenants)} 个租户")
return tenant.id return tenant.id
def test_domain_management(tenant_id: str): def test_domain_management(tenant_id: str) -> None:
"""测试域名管理功能""" """测试域名管理功能"""
print("\n" + "=" * 60) print("\n" + " = " * 60)
print("测试 2: 域名管理") print("测试 2: 域名管理")
print("=" * 60) print(" = " * 60)
manager = get_tenant_manager() manager = get_tenant_manager()
# 1. 添加域名 # 1. 添加域名
print("\n2.1 添加自定义域名...") print("\n2.1 添加自定义域名...")
domain = manager.add_domain(tenant_id=tenant_id, domain="test.example.com", is_primary=True) domain = manager.add_domain(tenant_id = tenant_id, domain = "test.example.com", is_primary = True)
print(f"✅ 域名添加成功: {domain.domain}") print(f"✅ 域名添加成功: {domain.domain}")
print(f" - ID: {domain.id}") print(f" - ID: {domain.id}")
print(f" - 状态: {domain.status}") print(f" - 状态: {domain.status}")
@@ -112,25 +112,25 @@ def test_domain_management(tenant_id: str):
return domain.id return domain.id
def test_branding_management(tenant_id: str): def test_branding_management(tenant_id: str) -> None:
"""测试品牌白标功能""" """测试品牌白标功能"""
print("\n" + "=" * 60) print("\n" + " = " * 60)
print("测试 3: 品牌白标") print("测试 3: 品牌白标")
print("=" * 60) print(" = " * 60)
manager = get_tenant_manager() manager = get_tenant_manager()
# 1. 更新品牌配置 # 1. 更新品牌配置
print("\n3.1 更新品牌配置...") print("\n3.1 更新品牌配置...")
branding = manager.update_branding( branding = manager.update_branding(
tenant_id=tenant_id, tenant_id = tenant_id,
logo_url="https://example.com/logo.png", logo_url = "https://example.com/logo.png",
favicon_url="https://example.com/favicon.ico", favicon_url = "https://example.com/favicon.ico",
primary_color="#1890ff", primary_color = "#1890ff",
secondary_color="#52c41a", secondary_color = "#52c41a",
custom_css=".header { background: #1890ff; }", custom_css = ".header { background: #1890ff; }",
custom_js="console.log('Custom JS loaded');", custom_js = "console.log('Custom JS loaded');",
login_page_bg="https://example.com/bg.jpg", login_page_bg = "https://example.com/bg.jpg",
) )
print("✅ 品牌配置更新成功") print("✅ 品牌配置更新成功")
print(f" - Logo: {branding.logo_url}") print(f" - Logo: {branding.logo_url}")
@@ -152,18 +152,18 @@ def test_branding_management(tenant_id: str):
return branding.id return branding.id
def test_member_management(tenant_id: str): def test_member_management(tenant_id: str) -> None:
"""测试成员管理功能""" """测试成员管理功能"""
print("\n" + "=" * 60) print("\n" + " = " * 60)
print("测试 4: 成员管理") print("测试 4: 成员管理")
print("=" * 60) print(" = " * 60)
manager = get_tenant_manager() manager = get_tenant_manager()
# 1. 邀请成员 # 1. 邀请成员
print("\n4.1 邀请成员...") print("\n4.1 邀请成员...")
member1 = manager.invite_member( member1 = manager.invite_member(
tenant_id=tenant_id, email="admin@test.com", role="admin", invited_by="user_001" tenant_id = tenant_id, email = "admin@test.com", role = "admin", invited_by = "user_001"
) )
print(f"✅ 成员邀请成功: {member1.email}") print(f"✅ 成员邀请成功: {member1.email}")
print(f" - ID: {member1.id}") print(f" - ID: {member1.id}")
@@ -171,7 +171,7 @@ def test_member_management(tenant_id: str):
print(f" - 权限: {member1.permissions}") print(f" - 权限: {member1.permissions}")
member2 = manager.invite_member( member2 = manager.invite_member(
tenant_id=tenant_id, email="member@test.com", role="member", invited_by="user_001" tenant_id = tenant_id, email = "member@test.com", role = "member", invited_by = "user_001"
) )
print(f"✅ 成员邀请成功: {member2.email}") print(f"✅ 成员邀请成功: {member2.email}")
@@ -207,24 +207,24 @@ def test_member_management(tenant_id: str):
return member1.id, member2.id return member1.id, member2.id
def test_usage_tracking(tenant_id: str): def test_usage_tracking(tenant_id: str) -> None:
"""测试资源使用统计功能""" """测试资源使用统计功能"""
print("\n" + "=" * 60) print("\n" + " = " * 60)
print("测试 5: 资源使用统计") print("测试 5: 资源使用统计")
print("=" * 60) print(" = " * 60)
manager = get_tenant_manager() manager = get_tenant_manager()
# 1. 记录使用 # 1. 记录使用
print("\n5.1 记录资源使用...") print("\n5.1 记录资源使用...")
manager.record_usage( manager.record_usage(
tenant_id=tenant_id, tenant_id = tenant_id,
storage_bytes=1024 * 1024 * 50, # 50MB storage_bytes = 1024 * 1024 * 50, # 50MB
transcription_seconds=600, # 10分钟 transcription_seconds = 600, # 10分钟
api_calls=100, api_calls = 100,
projects_count=5, projects_count = 5,
entities_count=50, entities_count = 50,
members_count=3, members_count = 3,
) )
print("✅ 资源使用记录成功") print("✅ 资源使用记录成功")
@@ -249,11 +249,11 @@ def test_usage_tracking(tenant_id: str):
return stats return stats
def cleanup(tenant_id: str, domain_id: str, member_ids: list): def cleanup(tenant_id: str, domain_id: str, member_ids: list) -> None:
"""清理测试数据""" """清理测试数据"""
print("\n" + "=" * 60) print("\n" + " = " * 60)
print("清理测试数据") print("清理测试数据")
print("=" * 60) print(" = " * 60)
manager = get_tenant_manager() manager = get_tenant_manager()
@@ -273,11 +273,11 @@ def cleanup(tenant_id: str, domain_id: str, member_ids: list):
print(f"✅ 租户已删除: {tenant_id}") print(f"✅ 租户已删除: {tenant_id}")
def main(): def main() -> None:
"""主测试函数""" """主测试函数"""
print("\n" + "=" * 60) print("\n" + " = " * 60)
print("InsightFlow Phase 8 Task 1 - 多租户 SaaS 架构测试") print("InsightFlow Phase 8 Task 1 - 多租户 SaaS 架构测试")
print("=" * 60) print(" = " * 60)
tenant_id = None tenant_id = None
domain_id = None domain_id = None
@@ -292,9 +292,9 @@ def main():
member_ids = [m1, m2] member_ids = [m1, m2]
test_usage_tracking(tenant_id) test_usage_tracking(tenant_id)
print("\n" + "=" * 60) print("\n" + " = " * 60)
print("✅ 所有测试通过!") print("✅ 所有测试通过!")
print("=" * 60) print(" = " * 60)
except Exception as e: except Exception as e:
print(f"\n❌ 测试失败: {e}") print(f"\n❌ 测试失败: {e}")

View File

@@ -12,17 +12,17 @@ from subscription_manager import PaymentProvider, SubscriptionManager
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
def test_subscription_manager(): def test_subscription_manager() -> None:
"""测试订阅管理器""" """测试订阅管理器"""
print("=" * 60) print(" = " * 60)
print("InsightFlow Phase 8 Task 2 - 订阅与计费系统测试") print("InsightFlow Phase 8 Task 2 - 订阅与计费系统测试")
print("=" * 60) print(" = " * 60)
# 使用临时文件数据库进行测试 # 使用临时文件数据库进行测试
db_path = tempfile.mktemp(suffix=".db") db_path = tempfile.mktemp(suffix = ".db")
try: try:
manager = SubscriptionManager(db_path=db_path) manager = SubscriptionManager(db_path = db_path)
print("\n1. 测试订阅计划管理") print("\n1. 测试订阅计划管理")
print("-" * 40) print("-" * 40)
@@ -53,10 +53,10 @@ def test_subscription_manager():
# 创建订阅 # 创建订阅
subscription = manager.create_subscription( subscription = manager.create_subscription(
tenant_id=tenant_id, tenant_id = tenant_id,
plan_id=pro_plan.id, plan_id = pro_plan.id,
payment_provider=PaymentProvider.STRIPE.value, payment_provider = PaymentProvider.STRIPE.value,
trial_days=14, trial_days = 14,
) )
print(f"✓ 创建订阅: {subscription.id}") print(f"✓ 创建订阅: {subscription.id}")
@@ -75,21 +75,21 @@ def test_subscription_manager():
# 记录转录用量 # 记录转录用量
usage1 = manager.record_usage( usage1 = manager.record_usage(
tenant_id=tenant_id, tenant_id = tenant_id,
resource_type="transcription", resource_type = "transcription",
quantity=120, quantity = 120,
unit="minute", unit = "minute",
description="会议转录", description = "会议转录",
) )
print(f"✓ 记录转录用量: {usage1.quantity} {usage1.unit}, 费用: ¥{usage1.cost:.2f}") print(f"✓ 记录转录用量: {usage1.quantity} {usage1.unit}, 费用: ¥{usage1.cost:.2f}")
# 记录存储用量 # 记录存储用量
usage2 = manager.record_usage( usage2 = manager.record_usage(
tenant_id=tenant_id, tenant_id = tenant_id,
resource_type="storage", resource_type = "storage",
quantity=2.5, quantity = 2.5,
unit="gb", unit = "gb",
description="文件存储", description = "文件存储",
) )
print(f"✓ 记录存储用量: {usage2.quantity} {usage2.unit}, 费用: ¥{usage2.cost:.2f}") print(f"✓ 记录存储用量: {usage2.quantity} {usage2.unit}, 费用: ¥{usage2.cost:.2f}")
@@ -105,11 +105,11 @@ def test_subscription_manager():
# 创建支付 # 创建支付
payment = manager.create_payment( payment = manager.create_payment(
tenant_id=tenant_id, tenant_id = tenant_id,
amount=99.0, amount = 99.0,
currency="CNY", currency = "CNY",
provider=PaymentProvider.ALIPAY.value, provider = PaymentProvider.ALIPAY.value,
payment_method="qrcode", payment_method = "qrcode",
) )
print(f"✓ 创建支付: {payment.id}") print(f"✓ 创建支付: {payment.id}")
print(f" - 金额: ¥{payment.amount}") print(f" - 金额: ¥{payment.amount}")
@@ -142,11 +142,11 @@ def test_subscription_manager():
# 申请退款 # 申请退款
refund = manager.request_refund( refund = manager.request_refund(
tenant_id=tenant_id, tenant_id = tenant_id,
payment_id=payment.id, payment_id = payment.id,
amount=50.0, amount = 50.0,
reason="服务不满意", reason = "服务不满意",
requested_by="user_001", requested_by = "user_001",
) )
print(f"✓ 申请退款: {refund.id}") print(f"✓ 申请退款: {refund.id}")
print(f" - 金额: ¥{refund.amount}") print(f" - 金额: ¥{refund.amount}")
@@ -178,19 +178,19 @@ def test_subscription_manager():
# Stripe Checkout # Stripe Checkout
stripe_session = manager.create_stripe_checkout_session( stripe_session = manager.create_stripe_checkout_session(
tenant_id=tenant_id, tenant_id = tenant_id,
plan_id=enterprise_plan.id, plan_id = enterprise_plan.id,
success_url="https://example.com/success", success_url = "https://example.com/success",
cancel_url="https://example.com/cancel", cancel_url = "https://example.com/cancel",
) )
print(f"✓ Stripe Checkout 会话: {stripe_session['session_id']}") print(f"✓ Stripe Checkout 会话: {stripe_session['session_id']}")
# 支付宝订单 # 支付宝订单
alipay_order = manager.create_alipay_order(tenant_id=tenant_id, plan_id=pro_plan.id) alipay_order = manager.create_alipay_order(tenant_id = tenant_id, plan_id = pro_plan.id)
print(f"✓ 支付宝订单: {alipay_order['order_id']}") print(f"✓ 支付宝订单: {alipay_order['order_id']}")
# 微信支付订单 # 微信支付订单
wechat_order = manager.create_wechat_order(tenant_id=tenant_id, plan_id=pro_plan.id) wechat_order = manager.create_wechat_order(tenant_id = tenant_id, plan_id = pro_plan.id)
print(f"✓ 微信支付订单: {wechat_order['order_id']}") print(f"✓ 微信支付订单: {wechat_order['order_id']}")
# Webhook 处理 # Webhook 处理
@@ -205,18 +205,18 @@ def test_subscription_manager():
# 更改计划 # 更改计划
changed = manager.change_plan( changed = manager.change_plan(
subscription_id=subscription.id, new_plan_id=enterprise_plan.id subscription_id = subscription.id, new_plan_id = enterprise_plan.id
) )
print(f"✓ 更改计划: {changed.plan_id} (Enterprise)") print(f"✓ 更改计划: {changed.plan_id} (Enterprise)")
# 取消订阅 # 取消订阅
cancelled = manager.cancel_subscription(subscription_id=subscription.id, at_period_end=True) cancelled = manager.cancel_subscription(subscription_id = subscription.id, at_period_end = True)
print(f"✓ 取消订阅: {cancelled.status}") print(f"✓ 取消订阅: {cancelled.status}")
print(f" - 周期结束时取消: {cancelled.cancel_at_period_end}") print(f" - 周期结束时取消: {cancelled.cancel_at_period_end}")
print("\n" + "=" * 60) print("\n" + " = " * 60)
print("所有测试通过! ✓") print("所有测试通过! ✓")
print("=" * 60) print(" = " * 60)
finally: finally:
# 清理临时数据库 # 清理临时数据库

View File

@@ -14,7 +14,7 @@ from ai_manager import ModelType, PredictionType, get_ai_manager
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
def test_custom_model(): def test_custom_model() -> None:
"""测试自定义模型功能""" """测试自定义模型功能"""
print("\n=== 测试自定义模型 ===") print("\n=== 测试自定义模型 ===")
@@ -23,16 +23,16 @@ def test_custom_model():
# 1. 创建自定义模型 # 1. 创建自定义模型
print("1. 创建自定义模型...") print("1. 创建自定义模型...")
model = manager.create_custom_model( model = manager.create_custom_model(
tenant_id="tenant_001", tenant_id = "tenant_001",
name="领域实体识别模型", name = "领域实体识别模型",
description="用于识别医疗领域实体的自定义模型", description = "用于识别医疗领域实体的自定义模型",
model_type=ModelType.CUSTOM_NER, model_type = ModelType.CUSTOM_NER,
training_data={ training_data = {
"entity_types": ["DISEASE", "SYMPTOM", "DRUG", "TREATMENT"], "entity_types": ["DISEASE", "SYMPTOM", "DRUG", "TREATMENT"],
"domain": "medical", "domain": "medical",
}, },
hyperparameters={"epochs": 15, "learning_rate": 0.001, "batch_size": 32}, hyperparameters = {"epochs": 15, "learning_rate": 0.001, "batch_size": 32},
created_by="user_001", created_by = "user_001",
) )
print(f" 创建成功: {model.id}, 状态: {model.status.value}") print(f" 创建成功: {model.id}, 状态: {model.status.value}")
@@ -67,10 +67,10 @@ def test_custom_model():
for sample_data in samples: for sample_data in samples:
sample = manager.add_training_sample( sample = manager.add_training_sample(
model_id=model.id, model_id = model.id,
text=sample_data["text"], text = sample_data["text"],
entities=sample_data["entities"], entities = sample_data["entities"],
metadata={"source": "manual"}, metadata = {"source": "manual"},
) )
print(f" 添加样本: {sample.id}") print(f" 添加样本: {sample.id}")
@@ -81,7 +81,7 @@ def test_custom_model():
# 4. 列出自定义模型 # 4. 列出自定义模型
print("4. 列出自定义模型...") print("4. 列出自定义模型...")
models = manager.list_custom_models(tenant_id="tenant_001") models = manager.list_custom_models(tenant_id = "tenant_001")
print(f" 找到 {len(models)} 个模型") print(f" 找到 {len(models)} 个模型")
for m in models: for m in models:
print(f" - {m.name} ({m.model_type.value}): {m.status.value}") print(f" - {m.name} ({m.model_type.value}): {m.status.value}")
@@ -89,7 +89,7 @@ def test_custom_model():
return model.id return model.id
async def test_train_and_predict(model_id: str): async def test_train_and_predict(model_id: str) -> None:
"""测试训练和预测""" """测试训练和预测"""
print("\n=== 测试模型训练和预测 ===") print("\n=== 测试模型训练和预测 ===")
@@ -116,7 +116,7 @@ async def test_train_and_predict(model_id: str):
print(f" 预测失败: {e}") print(f" 预测失败: {e}")
def test_prediction_models(): def test_prediction_models() -> None:
"""测试预测模型""" """测试预测模型"""
print("\n=== 测试预测模型 ===") print("\n=== 测试预测模型 ===")
@@ -125,32 +125,32 @@ def test_prediction_models():
# 1. 创建趋势预测模型 # 1. 创建趋势预测模型
print("1. 创建趋势预测模型...") print("1. 创建趋势预测模型...")
trend_model = manager.create_prediction_model( trend_model = manager.create_prediction_model(
tenant_id="tenant_001", tenant_id = "tenant_001",
project_id="project_001", project_id = "project_001",
name="实体数量趋势预测", name = "实体数量趋势预测",
prediction_type=PredictionType.TREND, prediction_type = PredictionType.TREND,
target_entity_type="PERSON", target_entity_type = "PERSON",
features=["entity_count", "time_period", "document_count"], features = ["entity_count", "time_period", "document_count"],
model_config={"algorithm": "linear_regression", "window_size": 7}, model_config = {"algorithm": "linear_regression", "window_size": 7},
) )
print(f" 创建成功: {trend_model.id}") print(f" 创建成功: {trend_model.id}")
# 2. 创建异常检测模型 # 2. 创建异常检测模型
print("2. 创建异常检测模型...") print("2. 创建异常检测模型...")
anomaly_model = manager.create_prediction_model( anomaly_model = manager.create_prediction_model(
tenant_id="tenant_001", tenant_id = "tenant_001",
project_id="project_001", project_id = "project_001",
name="实体增长异常检测", name = "实体增长异常检测",
prediction_type=PredictionType.ANOMALY, prediction_type = PredictionType.ANOMALY,
target_entity_type=None, target_entity_type = None,
features=["daily_growth", "weekly_growth"], features = ["daily_growth", "weekly_growth"],
model_config={"threshold": 2.5, "sensitivity": "medium"}, model_config = {"threshold": 2.5, "sensitivity": "medium"},
) )
print(f" 创建成功: {anomaly_model.id}") print(f" 创建成功: {anomaly_model.id}")
# 3. 列出预测模型 # 3. 列出预测模型
print("3. 列出预测模型...") print("3. 列出预测模型...")
models = manager.list_prediction_models(tenant_id="tenant_001") models = manager.list_prediction_models(tenant_id = "tenant_001")
print(f" 找到 {len(models)} 个预测模型") print(f" 找到 {len(models)} 个预测模型")
for m in models: for m in models:
print(f" - {m.name} ({m.prediction_type.value})") print(f" - {m.name} ({m.prediction_type.value})")
@@ -158,7 +158,7 @@ def test_prediction_models():
return trend_model.id, anomaly_model.id return trend_model.id, anomaly_model.id
async def test_predictions(trend_model_id: str, anomaly_model_id: str): async def test_predictions(trend_model_id: str, anomaly_model_id: str) -> None:
"""测试预测功能""" """测试预测功能"""
print("\n=== 测试预测功能 ===") print("\n=== 测试预测功能 ===")
@@ -193,7 +193,7 @@ async def test_predictions(trend_model_id: str, anomaly_model_id: str):
print(f" 检测结果: {anomaly_result.prediction_data}") print(f" 检测结果: {anomaly_result.prediction_data}")
def test_kg_rag(): def test_kg_rag() -> None:
"""测试知识图谱 RAG""" """测试知识图谱 RAG"""
print("\n=== 测试知识图谱 RAG ===") print("\n=== 测试知识图谱 RAG ===")
@@ -202,28 +202,28 @@ def test_kg_rag():
# 创建 RAG 配置 # 创建 RAG 配置
print("1. 创建知识图谱 RAG 配置...") print("1. 创建知识图谱 RAG 配置...")
rag = manager.create_kg_rag( rag = manager.create_kg_rag(
tenant_id="tenant_001", tenant_id = "tenant_001",
project_id="project_001", project_id = "project_001",
name="项目知识问答", name = "项目知识问答",
description="基于项目知识图谱的智能问答", description = "基于项目知识图谱的智能问答",
kg_config={ kg_config = {
"entity_types": ["PERSON", "ORG", "PROJECT", "TECH"], "entity_types": ["PERSON", "ORG", "PROJECT", "TECH"],
"relation_types": ["works_with", "belongs_to", "depends_on"], "relation_types": ["works_with", "belongs_to", "depends_on"],
}, },
retrieval_config={"top_k": 5, "similarity_threshold": 0.7, "expand_relations": True}, retrieval_config = {"top_k": 5, "similarity_threshold": 0.7, "expand_relations": True},
generation_config={"temperature": 0.3, "max_tokens": 1000, "include_sources": True}, generation_config = {"temperature": 0.3, "max_tokens": 1000, "include_sources": True},
) )
print(f" 创建成功: {rag.id}") print(f" 创建成功: {rag.id}")
# 列出 RAG 配置 # 列出 RAG 配置
print("2. 列出 RAG 配置...") print("2. 列出 RAG 配置...")
rags = manager.list_kg_rags(tenant_id="tenant_001") rags = manager.list_kg_rags(tenant_id = "tenant_001")
print(f" 找到 {len(rags)} 个配置") print(f" 找到 {len(rags)} 个配置")
return rag.id return rag.id
async def test_kg_rag_query(rag_id: str): async def test_kg_rag_query(rag_id: str) -> None:
"""测试 RAG 查询""" """测试 RAG 查询"""
print("\n=== 测试知识图谱 RAG 查询 ===") print("\n=== 测试知识图谱 RAG 查询 ===")
@@ -279,10 +279,10 @@ async def test_kg_rag_query(rag_id: str):
try: try:
result = await manager.query_kg_rag( result = await manager.query_kg_rag(
rag_id=rag_id, rag_id = rag_id,
query=query_text, query = query_text,
project_entities=project_entities, project_entities = project_entities,
project_relations=project_relations, project_relations = project_relations,
) )
print(f" 查询: {result.query}") print(f" 查询: {result.query}")
@@ -294,7 +294,7 @@ async def test_kg_rag_query(rag_id: str):
print(f" 查询失败: {e}") print(f" 查询失败: {e}")
async def test_smart_summary(): async def test_smart_summary() -> None:
"""测试智能摘要""" """测试智能摘要"""
print("\n=== 测试智能摘要 ===") print("\n=== 测试智能摘要 ===")
@@ -326,12 +326,12 @@ async def test_smart_summary():
print(f"1. 生成 {summary_type} 类型摘要...") print(f"1. 生成 {summary_type} 类型摘要...")
try: try:
summary = await manager.generate_smart_summary( summary = await manager.generate_smart_summary(
tenant_id="tenant_001", tenant_id = "tenant_001",
project_id="project_001", project_id = "project_001",
source_type="transcript", source_type = "transcript",
source_id="transcript_001", source_id = "transcript_001",
summary_type=summary_type, summary_type = summary_type,
content_data=content_data, content_data = content_data,
) )
print(f" 摘要类型: {summary.summary_type}") print(f" 摘要类型: {summary.summary_type}")
@@ -342,11 +342,11 @@ async def test_smart_summary():
print(f" 生成失败: {e}") print(f" 生成失败: {e}")
async def main(): async def main() -> None:
"""主测试函数""" """主测试函数"""
print("=" * 60) print(" = " * 60)
print("InsightFlow Phase 8 Task 4 - AI 能力增强测试") print("InsightFlow Phase 8 Task 4 - AI 能力增强测试")
print("=" * 60) print(" = " * 60)
try: try:
# 测试自定义模型 # 测试自定义模型
@@ -370,9 +370,9 @@ async def main():
# 测试智能摘要 # 测试智能摘要
await test_smart_summary() await test_smart_summary()
print("\n" + "=" * 60) print("\n" + " = " * 60)
print("所有测试完成!") print("所有测试完成!")
print("=" * 60) print(" = " * 60)
except Exception as e: except Exception as e:
print(f"\n测试失败: {e}") print(f"\n测试失败: {e}")

View File

@@ -36,13 +36,13 @@ if backend_dir not in sys.path:
class TestGrowthManager: class TestGrowthManager:
"""测试 Growth Manager 功能""" """测试 Growth Manager 功能"""
def __init__(self): def __init__(self) -> None:
self.manager = GrowthManager() self.manager = GrowthManager()
self.test_tenant_id = "test_tenant_001" self.test_tenant_id = "test_tenant_001"
self.test_user_id = "test_user_001" self.test_user_id = "test_user_001"
self.test_results = [] self.test_results = []
def log(self, message: str, success: bool = True): def log(self, message: str, success: bool = True) -> None:
"""记录测试结果""" """记录测试结果"""
status = "" if success else "" status = "" if success else ""
print(f"{status} {message}") print(f"{status} {message}")
@@ -50,21 +50,21 @@ class TestGrowthManager:
# ==================== 测试用户行为分析 ==================== # ==================== 测试用户行为分析 ====================
async def test_track_event(self): async def test_track_event(self) -> None:
"""测试事件追踪""" """测试事件追踪"""
print("\n📊 测试事件追踪...") print("\n📊 测试事件追踪...")
try: try:
event = await self.manager.track_event( event = await self.manager.track_event(
tenant_id=self.test_tenant_id, tenant_id = self.test_tenant_id,
user_id=self.test_user_id, user_id = self.test_user_id,
event_type=EventType.PAGE_VIEW, event_type = EventType.PAGE_VIEW,
event_name="dashboard_view", event_name = "dashboard_view",
properties={"page": "/dashboard", "duration": 120}, properties = {"page": "/dashboard", "duration": 120},
session_id="session_001", session_id = "session_001",
device_info={"browser": "Chrome", "os": "MacOS"}, device_info = {"browser": "Chrome", "os": "MacOS"},
referrer="https://google.com", referrer = "https://google.com",
utm_params={"source": "google", "medium": "organic", "campaign": "summer"}, utm_params = {"source": "google", "medium": "organic", "campaign": "summer"},
) )
assert event.id is not None assert event.id is not None
@@ -74,10 +74,10 @@ class TestGrowthManager:
self.log(f"事件追踪成功: {event.id}") self.log(f"事件追踪成功: {event.id}")
return True return True
except Exception as e: except Exception as e:
self.log(f"事件追踪失败: {e}", success=False) self.log(f"事件追踪失败: {e}", success = False)
return False return False
async def test_track_multiple_events(self): async def test_track_multiple_events(self) -> None:
"""测试追踪多个事件""" """测试追踪多个事件"""
print("\n📊 测试追踪多个事件...") print("\n📊 测试追踪多个事件...")
@@ -91,20 +91,20 @@ class TestGrowthManager:
for event_type, event_name, props in events: for event_type, event_name, props in events:
await self.manager.track_event( await self.manager.track_event(
tenant_id=self.test_tenant_id, tenant_id = self.test_tenant_id,
user_id=self.test_user_id, user_id = self.test_user_id,
event_type=event_type, event_type = event_type,
event_name=event_name, event_name = event_name,
properties=props, properties = props,
) )
self.log(f"成功追踪 {len(events)} 个事件") self.log(f"成功追踪 {len(events)} 个事件")
return True return True
except Exception as e: except Exception as e:
self.log(f"批量事件追踪失败: {e}", success=False) self.log(f"批量事件追踪失败: {e}", success = False)
return False return False
def test_get_user_profile(self): def test_get_user_profile(self) -> None:
"""测试获取用户画像""" """测试获取用户画像"""
print("\n👤 测试用户画像...") print("\n👤 测试用户画像...")
@@ -120,18 +120,18 @@ class TestGrowthManager:
return True return True
except Exception as e: except Exception as e:
self.log(f"获取用户画像失败: {e}", success=False) self.log(f"获取用户画像失败: {e}", success = False)
return False return False
def test_get_analytics_summary(self): def test_get_analytics_summary(self) -> None:
"""测试获取分析汇总""" """测试获取分析汇总"""
print("\n📈 测试分析汇总...") print("\n📈 测试分析汇总...")
try: try:
summary = self.manager.get_user_analytics_summary( summary = self.manager.get_user_analytics_summary(
tenant_id=self.test_tenant_id, tenant_id = self.test_tenant_id,
start_date=datetime.now() - timedelta(days=7), start_date = datetime.now() - timedelta(days = 7),
end_date=datetime.now(), end_date = datetime.now(),
) )
assert "unique_users" in summary assert "unique_users" in summary
@@ -141,25 +141,25 @@ class TestGrowthManager:
self.log(f"分析汇总: {summary['unique_users']} 用户, {summary['total_events']} 事件") self.log(f"分析汇总: {summary['unique_users']} 用户, {summary['total_events']} 事件")
return True return True
except Exception as e: except Exception as e:
self.log(f"获取分析汇总失败: {e}", success=False) self.log(f"获取分析汇总失败: {e}", success = False)
return False return False
def test_create_funnel(self): def test_create_funnel(self) -> None:
"""测试创建转化漏斗""" """测试创建转化漏斗"""
print("\n🎯 测试创建转化漏斗...") print("\n🎯 测试创建转化漏斗...")
try: try:
funnel = self.manager.create_funnel( funnel = self.manager.create_funnel(
tenant_id=self.test_tenant_id, tenant_id = self.test_tenant_id,
name="用户注册转化漏斗", name = "用户注册转化漏斗",
description="从访问到完成注册的转化流程", description = "从访问到完成注册的转化流程",
steps=[ steps = [
{"name": "访问首页", "event_name": "page_view_home"}, {"name": "访问首页", "event_name": "page_view_home"},
{"name": "点击注册", "event_name": "signup_click"}, {"name": "点击注册", "event_name": "signup_click"},
{"name": "填写信息", "event_name": "signup_form_fill"}, {"name": "填写信息", "event_name": "signup_form_fill"},
{"name": "完成注册", "event_name": "signup_complete"}, {"name": "完成注册", "event_name": "signup_complete"},
], ],
created_by="test", created_by = "test",
) )
assert funnel.id is not None assert funnel.id is not None
@@ -168,10 +168,10 @@ class TestGrowthManager:
self.log(f"漏斗创建成功: {funnel.id}") self.log(f"漏斗创建成功: {funnel.id}")
return funnel.id return funnel.id
except Exception as e: except Exception as e:
self.log(f"创建漏斗失败: {e}", success=False) self.log(f"创建漏斗失败: {e}", success = False)
return None return None
def test_analyze_funnel(self, funnel_id: str): def test_analyze_funnel(self, funnel_id: str) -> None:
"""测试分析漏斗""" """测试分析漏斗"""
print("\n📉 测试漏斗分析...") print("\n📉 测试漏斗分析...")
@@ -181,9 +181,9 @@ class TestGrowthManager:
try: try:
analysis = self.manager.analyze_funnel( analysis = self.manager.analyze_funnel(
funnel_id=funnel_id, funnel_id = funnel_id,
period_start=datetime.now() - timedelta(days=30), period_start = datetime.now() - timedelta(days = 30),
period_end=datetime.now(), period_end = datetime.now(),
) )
if analysis: if analysis:
@@ -194,18 +194,18 @@ class TestGrowthManager:
self.log("漏斗分析返回空结果") self.log("漏斗分析返回空结果")
return False return False
except Exception as e: except Exception as e:
self.log(f"漏斗分析失败: {e}", success=False) self.log(f"漏斗分析失败: {e}", success = False)
return False return False
def test_calculate_retention(self): def test_calculate_retention(self) -> None:
"""测试留存率计算""" """测试留存率计算"""
print("\n🔄 测试留存率计算...") print("\n🔄 测试留存率计算...")
try: try:
retention = self.manager.calculate_retention( retention = self.manager.calculate_retention(
tenant_id=self.test_tenant_id, tenant_id = self.test_tenant_id,
cohort_date=datetime.now() - timedelta(days=7), cohort_date = datetime.now() - timedelta(days = 7),
periods=[1, 3, 7], periods = [1, 3, 7],
) )
assert "cohort_date" in retention assert "cohort_date" in retention
@@ -214,34 +214,34 @@ class TestGrowthManager:
self.log(f"留存率计算完成: 同期群 {retention['cohort_size']} 用户") self.log(f"留存率计算完成: 同期群 {retention['cohort_size']} 用户")
return True return True
except Exception as e: except Exception as e:
self.log(f"留存率计算失败: {e}", success=False) self.log(f"留存率计算失败: {e}", success = False)
return False return False
# ==================== 测试 A/B 测试框架 ==================== # ==================== 测试 A/B 测试框架 ====================
def test_create_experiment(self): def test_create_experiment(self) -> None:
"""测试创建实验""" """测试创建实验"""
print("\n🧪 测试创建 A/B 测试实验...") print("\n🧪 测试创建 A/B 测试实验...")
try: try:
experiment = self.manager.create_experiment( experiment = self.manager.create_experiment(
tenant_id=self.test_tenant_id, tenant_id = self.test_tenant_id,
name="首页按钮颜色测试", name = "首页按钮颜色测试",
description="测试不同按钮颜色对转化率的影响", description = "测试不同按钮颜色对转化率的影响",
hypothesis="蓝色按钮比红色按钮有更高的点击率", hypothesis = "蓝色按钮比红色按钮有更高的点击率",
variants=[ variants = [
{"id": "control", "name": "红色按钮", "is_control": True}, {"id": "control", "name": "红色按钮", "is_control": True},
{"id": "variant_a", "name": "蓝色按钮", "is_control": False}, {"id": "variant_a", "name": "蓝色按钮", "is_control": False},
{"id": "variant_b", "name": "绿色按钮", "is_control": False}, {"id": "variant_b", "name": "绿色按钮", "is_control": False},
], ],
traffic_allocation=TrafficAllocationType.RANDOM, traffic_allocation = TrafficAllocationType.RANDOM,
traffic_split={"control": 0.34, "variant_a": 0.33, "variant_b": 0.33}, traffic_split = {"control": 0.34, "variant_a": 0.33, "variant_b": 0.33},
target_audience={"conditions": []}, target_audience = {"conditions": []},
primary_metric="button_click_rate", primary_metric = "button_click_rate",
secondary_metrics=["conversion_rate", "bounce_rate"], secondary_metrics = ["conversion_rate", "bounce_rate"],
min_sample_size=100, min_sample_size = 100,
confidence_level=0.95, confidence_level = 0.95,
created_by="test", created_by = "test",
) )
assert experiment.id is not None assert experiment.id is not None
@@ -250,10 +250,10 @@ class TestGrowthManager:
self.log(f"实验创建成功: {experiment.id}") self.log(f"实验创建成功: {experiment.id}")
return experiment.id return experiment.id
except Exception as e: except Exception as e:
self.log(f"创建实验失败: {e}", success=False) self.log(f"创建实验失败: {e}", success = False)
return None return None
def test_list_experiments(self): def test_list_experiments(self) -> None:
"""测试列出实验""" """测试列出实验"""
print("\n📋 测试列出实验...") print("\n📋 测试列出实验...")
@@ -263,10 +263,10 @@ class TestGrowthManager:
self.log(f"列出 {len(experiments)} 个实验") self.log(f"列出 {len(experiments)} 个实验")
return True return True
except Exception as e: except Exception as e:
self.log(f"列出实验失败: {e}", success=False) self.log(f"列出实验失败: {e}", success = False)
return False return False
def test_assign_variant(self, experiment_id: str): def test_assign_variant(self, experiment_id: str) -> None:
"""测试分配变体""" """测试分配变体"""
print("\n🎲 测试分配实验变体...") print("\n🎲 测试分配实验变体...")
@@ -284,9 +284,9 @@ class TestGrowthManager:
for user_id in test_users: for user_id in test_users:
variant_id = self.manager.assign_variant( variant_id = self.manager.assign_variant(
experiment_id=experiment_id, experiment_id = experiment_id,
user_id=user_id, user_id = user_id,
user_attributes={"user_id": user_id, "segment": "new"}, user_attributes = {"user_id": user_id, "segment": "new"},
) )
if variant_id: if variant_id:
@@ -295,10 +295,10 @@ class TestGrowthManager:
self.log(f"变体分配完成: {len(assignments)} 个用户") self.log(f"变体分配完成: {len(assignments)} 个用户")
return True return True
except Exception as e: except Exception as e:
self.log(f"变体分配失败: {e}", success=False) self.log(f"变体分配失败: {e}", success = False)
return False return False
def test_record_experiment_metric(self, experiment_id: str): def test_record_experiment_metric(self, experiment_id: str) -> None:
"""测试记录实验指标""" """测试记录实验指标"""
print("\n📊 测试记录实验指标...") print("\n📊 测试记录实验指标...")
@@ -318,20 +318,20 @@ class TestGrowthManager:
for user_id, variant_id, value in test_data: for user_id, variant_id, value in test_data:
self.manager.record_experiment_metric( self.manager.record_experiment_metric(
experiment_id=experiment_id, experiment_id = experiment_id,
variant_id=variant_id, variant_id = variant_id,
user_id=user_id, user_id = user_id,
metric_name="button_click_rate", metric_name = "button_click_rate",
metric_value=value, metric_value = value,
) )
self.log(f"成功记录 {len(test_data)} 条指标") self.log(f"成功记录 {len(test_data)} 条指标")
return True return True
except Exception as e: except Exception as e:
self.log(f"记录指标失败: {e}", success=False) self.log(f"记录指标失败: {e}", success = False)
return False return False
def test_analyze_experiment(self, experiment_id: str): def test_analyze_experiment(self, experiment_id: str) -> None:
"""测试分析实验结果""" """测试分析实验结果"""
print("\n📈 测试分析实验结果...") print("\n📈 测试分析实验结果...")
@@ -346,25 +346,25 @@ class TestGrowthManager:
self.log(f"实验分析完成: {len(result.get('variant_results', {}))} 个变体") self.log(f"实验分析完成: {len(result.get('variant_results', {}))} 个变体")
return True return True
else: else:
self.log(f"实验分析返回错误: {result['error']}", success=False) self.log(f"实验分析返回错误: {result['error']}", success = False)
return False return False
except Exception as e: except Exception as e:
self.log(f"实验分析失败: {e}", success=False) self.log(f"实验分析失败: {e}", success = False)
return False return False
# ==================== 测试邮件营销 ==================== # ==================== 测试邮件营销 ====================
def test_create_email_template(self): def test_create_email_template(self) -> None:
"""测试创建邮件模板""" """测试创建邮件模板"""
print("\n📧 测试创建邮件模板...") print("\n📧 测试创建邮件模板...")
try: try:
template = self.manager.create_email_template( template = self.manager.create_email_template(
tenant_id=self.test_tenant_id, tenant_id = self.test_tenant_id,
name="欢迎邮件", name = "欢迎邮件",
template_type=EmailTemplateType.WELCOME, template_type = EmailTemplateType.WELCOME,
subject="欢迎加入 InsightFlow", subject = "欢迎加入 InsightFlow",
html_content=""" html_content = """
<h1>欢迎,{{user_name}}</h1> <h1>欢迎,{{user_name}}</h1>
<p>感谢您注册 InsightFlow。我们很高兴您能加入我们</p> <p>感谢您注册 InsightFlow。我们很高兴您能加入我们</p>
<p>您的账户已创建,可以开始使用以下功能:</p> <p>您的账户已创建,可以开始使用以下功能:</p>
@@ -373,10 +373,10 @@ class TestGrowthManager:
<li>智能实体提取</li> <li>智能实体提取</li>
<li>团队协作</li> <li>团队协作</li>
</ul> </ul>
<p><a href="{{dashboard_url}}">立即开始使用</a></p> <p><a href = "{{dashboard_url}}">立即开始使用</a></p>
""", """,
from_name="InsightFlow 团队", from_name = "InsightFlow 团队",
from_email="welcome@insightflow.io", from_email = "welcome@insightflow.io",
) )
assert template.id is not None assert template.id is not None
@@ -385,10 +385,10 @@ class TestGrowthManager:
self.log(f"邮件模板创建成功: {template.id}") self.log(f"邮件模板创建成功: {template.id}")
return template.id return template.id
except Exception as e: except Exception as e:
self.log(f"创建邮件模板失败: {e}", success=False) self.log(f"创建邮件模板失败: {e}", success = False)
return None return None
def test_list_email_templates(self): def test_list_email_templates(self) -> None:
"""测试列出邮件模板""" """测试列出邮件模板"""
print("\n📧 测试列出邮件模板...") print("\n📧 测试列出邮件模板...")
@@ -398,10 +398,10 @@ class TestGrowthManager:
self.log(f"列出 {len(templates)} 个邮件模板") self.log(f"列出 {len(templates)} 个邮件模板")
return True return True
except Exception as e: except Exception as e:
self.log(f"列出邮件模板失败: {e}", success=False) self.log(f"列出邮件模板失败: {e}", success = False)
return False return False
def test_render_template(self, template_id: str): def test_render_template(self, template_id: str) -> None:
"""测试渲染邮件模板""" """测试渲染邮件模板"""
print("\n🎨 测试渲染邮件模板...") print("\n🎨 测试渲染邮件模板...")
@@ -411,8 +411,8 @@ class TestGrowthManager:
try: try:
rendered = self.manager.render_template( rendered = self.manager.render_template(
template_id=template_id, template_id = template_id,
variables={ variables = {
"user_name": "张三", "user_name": "张三",
"dashboard_url": "https://app.insightflow.io/dashboard", "dashboard_url": "https://app.insightflow.io/dashboard",
}, },
@@ -424,13 +424,13 @@ class TestGrowthManager:
self.log(f"模板渲染成功: {rendered['subject']}") self.log(f"模板渲染成功: {rendered['subject']}")
return True return True
else: else:
self.log("模板渲染返回空结果", success=False) self.log("模板渲染返回空结果", success = False)
return False return False
except Exception as e: except Exception as e:
self.log(f"模板渲染失败: {e}", success=False) self.log(f"模板渲染失败: {e}", success = False)
return False return False
def test_create_email_campaign(self, template_id: str): def test_create_email_campaign(self, template_id: str) -> None:
"""测试创建邮件营销活动""" """测试创建邮件营销活动"""
print("\n📮 测试创建邮件营销活动...") print("\n📮 测试创建邮件营销活动...")
@@ -440,10 +440,10 @@ class TestGrowthManager:
try: try:
campaign = self.manager.create_email_campaign( campaign = self.manager.create_email_campaign(
tenant_id=self.test_tenant_id, tenant_id = self.test_tenant_id,
name="新用户欢迎活动", name = "新用户欢迎活动",
template_id=template_id, template_id = template_id,
recipient_list=[ recipient_list = [
{"user_id": "user_001", "email": "user1@example.com"}, {"user_id": "user_001", "email": "user1@example.com"},
{"user_id": "user_002", "email": "user2@example.com"}, {"user_id": "user_002", "email": "user2@example.com"},
{"user_id": "user_003", "email": "user3@example.com"}, {"user_id": "user_003", "email": "user3@example.com"},
@@ -456,21 +456,21 @@ class TestGrowthManager:
self.log(f"营销活动创建成功: {campaign.id}, {campaign.recipient_count} 收件人") self.log(f"营销活动创建成功: {campaign.id}, {campaign.recipient_count} 收件人")
return campaign.id return campaign.id
except Exception as e: except Exception as e:
self.log(f"创建营销活动失败: {e}", success=False) self.log(f"创建营销活动失败: {e}", success = False)
return None return None
def test_create_automation_workflow(self): def test_create_automation_workflow(self) -> None:
"""测试创建自动化工作流""" """测试创建自动化工作流"""
print("\n🤖 测试创建自动化工作流...") print("\n🤖 测试创建自动化工作流...")
try: try:
workflow = self.manager.create_automation_workflow( workflow = self.manager.create_automation_workflow(
tenant_id=self.test_tenant_id, tenant_id = self.test_tenant_id,
name="新用户欢迎序列", name = "新用户欢迎序列",
description="用户注册后自动发送欢迎邮件序列", description = "用户注册后自动发送欢迎邮件序列",
trigger_type=WorkflowTriggerType.USER_SIGNUP, trigger_type = WorkflowTriggerType.USER_SIGNUP,
trigger_conditions={"event": "user_signup"}, trigger_conditions = {"event": "user_signup"},
actions=[ actions = [
{"type": "send_email", "template_type": "welcome", "delay_hours": 0}, {"type": "send_email", "template_type": "welcome", "delay_hours": 0},
{"type": "send_email", "template_type": "onboarding", "delay_hours": 24}, {"type": "send_email", "template_type": "onboarding", "delay_hours": 24},
{"type": "send_email", "template_type": "feature_tips", "delay_hours": 72}, {"type": "send_email", "template_type": "feature_tips", "delay_hours": 72},
@@ -483,27 +483,27 @@ class TestGrowthManager:
self.log(f"自动化工作流创建成功: {workflow.id}") self.log(f"自动化工作流创建成功: {workflow.id}")
return True return True
except Exception as e: except Exception as e:
self.log(f"创建工作流失败: {e}", success=False) self.log(f"创建工作流失败: {e}", success = False)
return False return False
# ==================== 测试推荐系统 ==================== # ==================== 测试推荐系统 ====================
def test_create_referral_program(self): def test_create_referral_program(self) -> None:
"""测试创建推荐计划""" """测试创建推荐计划"""
print("\n🎁 测试创建推荐计划...") print("\n🎁 测试创建推荐计划...")
try: try:
program = self.manager.create_referral_program( program = self.manager.create_referral_program(
tenant_id=self.test_tenant_id, tenant_id = self.test_tenant_id,
name="邀请好友奖励计划", name = "邀请好友奖励计划",
description="邀请好友注册,双方获得积分奖励", description = "邀请好友注册,双方获得积分奖励",
referrer_reward_type="credit", referrer_reward_type = "credit",
referrer_reward_value=100.0, referrer_reward_value = 100.0,
referee_reward_type="credit", referee_reward_type = "credit",
referee_reward_value=50.0, referee_reward_value = 50.0,
max_referrals_per_user=10, max_referrals_per_user = 10,
referral_code_length=8, referral_code_length = 8,
expiry_days=30, expiry_days = 30,
) )
assert program.id is not None assert program.id is not None
@@ -512,10 +512,10 @@ class TestGrowthManager:
self.log(f"推荐计划创建成功: {program.id}") self.log(f"推荐计划创建成功: {program.id}")
return program.id return program.id
except Exception as e: except Exception as e:
self.log(f"创建推荐计划失败: {e}", success=False) self.log(f"创建推荐计划失败: {e}", success = False)
return None return None
def test_generate_referral_code(self, program_id: str): def test_generate_referral_code(self, program_id: str) -> None:
"""测试生成推荐码""" """测试生成推荐码"""
print("\n🔑 测试生成推荐码...") print("\n🔑 测试生成推荐码...")
@@ -525,7 +525,7 @@ class TestGrowthManager:
try: try:
referral = self.manager.generate_referral_code( referral = self.manager.generate_referral_code(
program_id=program_id, referrer_id="referrer_user_001" program_id = program_id, referrer_id = "referrer_user_001"
) )
if referral: if referral:
@@ -535,13 +535,13 @@ class TestGrowthManager:
self.log(f"推荐码生成成功: {referral.referral_code}") self.log(f"推荐码生成成功: {referral.referral_code}")
return referral.referral_code return referral.referral_code
else: else:
self.log("生成推荐码返回空结果", success=False) self.log("生成推荐码返回空结果", success = False)
return None return None
except Exception as e: except Exception as e:
self.log(f"生成推荐码失败: {e}", success=False) self.log(f"生成推荐码失败: {e}", success = False)
return None return None
def test_apply_referral_code(self, referral_code: str): def test_apply_referral_code(self, referral_code: str) -> None:
"""测试应用推荐码""" """测试应用推荐码"""
print("\n✅ 测试应用推荐码...") print("\n✅ 测试应用推荐码...")
@@ -551,20 +551,20 @@ class TestGrowthManager:
try: try:
success = self.manager.apply_referral_code( success = self.manager.apply_referral_code(
referral_code=referral_code, referee_id="new_user_001" referral_code = referral_code, referee_id = "new_user_001"
) )
if success: if success:
self.log(f"推荐码应用成功: {referral_code}") self.log(f"推荐码应用成功: {referral_code}")
return True return True
else: else:
self.log("推荐码应用失败", success=False) self.log("推荐码应用失败", success = False)
return False return False
except Exception as e: except Exception as e:
self.log(f"应用推荐码失败: {e}", success=False) self.log(f"应用推荐码失败: {e}", success = False)
return False return False
def test_get_referral_stats(self, program_id: str): def test_get_referral_stats(self, program_id: str) -> None:
"""测试获取推荐统计""" """测试获取推荐统计"""
print("\n📊 测试获取推荐统计...") print("\n📊 测试获取推荐统计...")
@@ -583,24 +583,24 @@ class TestGrowthManager:
) )
return True return True
except Exception as e: except Exception as e:
self.log(f"获取推荐统计失败: {e}", success=False) self.log(f"获取推荐统计失败: {e}", success = False)
return False return False
def test_create_team_incentive(self): def test_create_team_incentive(self) -> None:
"""测试创建团队激励""" """测试创建团队激励"""
print("\n🏆 测试创建团队升级激励...") print("\n🏆 测试创建团队升级激励...")
try: try:
incentive = self.manager.create_team_incentive( incentive = self.manager.create_team_incentive(
tenant_id=self.test_tenant_id, tenant_id = self.test_tenant_id,
name="团队升级奖励", name = "团队升级奖励",
description="团队规模达到5人升级到 Pro 计划可获得折扣", description = "团队规模达到5人升级到 Pro 计划可获得折扣",
target_tier="pro", target_tier = "pro",
min_team_size=5, min_team_size = 5,
incentive_type="discount", incentive_type = "discount",
incentive_value=20.0, # 20% 折扣 incentive_value = 20.0, # 20% 折扣
valid_from=datetime.now(), valid_from = datetime.now(),
valid_until=datetime.now() + timedelta(days=90), valid_until = datetime.now() + timedelta(days = 90),
) )
assert incentive.id is not None assert incentive.id is not None
@@ -609,27 +609,27 @@ class TestGrowthManager:
self.log(f"团队激励创建成功: {incentive.id}") self.log(f"团队激励创建成功: {incentive.id}")
return True return True
except Exception as e: except Exception as e:
self.log(f"创建团队激励失败: {e}", success=False) self.log(f"创建团队激励失败: {e}", success = False)
return False return False
def test_check_team_incentive_eligibility(self): def test_check_team_incentive_eligibility(self) -> None:
"""测试检查团队激励资格""" """测试检查团队激励资格"""
print("\n🔍 测试检查团队激励资格...") print("\n🔍 测试检查团队激励资格...")
try: try:
incentives = self.manager.check_team_incentive_eligibility( incentives = self.manager.check_team_incentive_eligibility(
tenant_id=self.test_tenant_id, current_tier="free", team_size=5 tenant_id = self.test_tenant_id, current_tier = "free", team_size = 5
) )
self.log(f"找到 {len(incentives)} 个符合条件的激励") self.log(f"找到 {len(incentives)} 个符合条件的激励")
return True return True
except Exception as e: except Exception as e:
self.log(f"检查激励资格失败: {e}", success=False) self.log(f"检查激励资格失败: {e}", success = False)
return False return False
# ==================== 测试实时仪表板 ==================== # ==================== 测试实时仪表板 ====================
def test_get_realtime_dashboard(self): def test_get_realtime_dashboard(self) -> None:
"""测试获取实时仪表板""" """测试获取实时仪表板"""
print("\n📺 测试实时分析仪表板...") print("\n📺 测试实时分析仪表板...")
@@ -646,21 +646,21 @@ class TestGrowthManager:
) )
return True return True
except Exception as e: except Exception as e:
self.log(f"获取实时仪表板失败: {e}", success=False) self.log(f"获取实时仪表板失败: {e}", success = False)
return False return False
# ==================== 运行所有测试 ==================== # ==================== 运行所有测试 ====================
async def run_all_tests(self): async def run_all_tests(self) -> None:
"""运行所有测试""" """运行所有测试"""
print("=" * 60) print(" = " * 60)
print("🚀 InsightFlow Phase 8 Task 5 - 运营与增长工具测试") print("🚀 InsightFlow Phase 8 Task 5 - 运营与增长工具测试")
print("=" * 60) print(" = " * 60)
# 用户行为分析测试 # 用户行为分析测试
print("\n" + "=" * 60) print("\n" + " = " * 60)
print("📊 模块 1: 用户行为分析") print("📊 模块 1: 用户行为分析")
print("=" * 60) print(" = " * 60)
await self.test_track_event() await self.test_track_event()
await self.test_track_multiple_events() await self.test_track_multiple_events()
@@ -671,9 +671,9 @@ class TestGrowthManager:
self.test_calculate_retention() self.test_calculate_retention()
# A/B 测试框架测试 # A/B 测试框架测试
print("\n" + "=" * 60) print("\n" + " = " * 60)
print("🧪 模块 2: A/B 测试框架") print("🧪 模块 2: A/B 测试框架")
print("=" * 60) print(" = " * 60)
experiment_id = self.test_create_experiment() experiment_id = self.test_create_experiment()
self.test_list_experiments() self.test_list_experiments()
@@ -682,9 +682,9 @@ class TestGrowthManager:
self.test_analyze_experiment(experiment_id) self.test_analyze_experiment(experiment_id)
# 邮件营销测试 # 邮件营销测试
print("\n" + "=" * 60) print("\n" + " = " * 60)
print("📧 模块 3: 邮件营销自动化") print("📧 模块 3: 邮件营销自动化")
print("=" * 60) print(" = " * 60)
template_id = self.test_create_email_template() template_id = self.test_create_email_template()
self.test_list_email_templates() self.test_list_email_templates()
@@ -693,9 +693,9 @@ class TestGrowthManager:
self.test_create_automation_workflow() self.test_create_automation_workflow()
# 推荐系统测试 # 推荐系统测试
print("\n" + "=" * 60) print("\n" + " = " * 60)
print("🎁 模块 4: 推荐系统") print("🎁 模块 4: 推荐系统")
print("=" * 60) print(" = " * 60)
program_id = self.test_create_referral_program() program_id = self.test_create_referral_program()
referral_code = self.test_generate_referral_code(program_id) referral_code = self.test_generate_referral_code(program_id)
@@ -705,16 +705,16 @@ class TestGrowthManager:
self.test_check_team_incentive_eligibility() self.test_check_team_incentive_eligibility()
# 实时仪表板测试 # 实时仪表板测试
print("\n" + "=" * 60) print("\n" + " = " * 60)
print("📺 模块 5: 实时分析仪表板") print("📺 模块 5: 实时分析仪表板")
print("=" * 60) print(" = " * 60)
self.test_get_realtime_dashboard() self.test_get_realtime_dashboard()
# 测试总结 # 测试总结
print("\n" + "=" * 60) print("\n" + " = " * 60)
print("📋 测试总结") print("📋 测试总结")
print("=" * 60) print(" = " * 60)
total_tests = len(self.test_results) total_tests = len(self.test_results)
passed_tests = sum(1 for _, success in self.test_results if success) passed_tests = sum(1 for _, success in self.test_results if success)
@@ -731,12 +731,12 @@ class TestGrowthManager:
if not success: if not success:
print(f" - {message}") print(f" - {message}")
print("\n" + "=" * 60) print("\n" + " = " * 60)
print("✨ 测试完成!") print("✨ 测试完成!")
print("=" * 60) print(" = " * 60)
async def main(): async def main() -> None:
"""主函数""" """主函数"""
tester = TestGrowthManager() tester = TestGrowthManager()
await tester.run_all_tests() await tester.run_all_tests()

View File

@@ -33,7 +33,7 @@ if backend_dir not in sys.path:
class TestDeveloperEcosystem: class TestDeveloperEcosystem:
"""开发者生态系统测试类""" """开发者生态系统测试类"""
def __init__(self): def __init__(self) -> None:
self.manager = DeveloperEcosystemManager() self.manager = DeveloperEcosystemManager()
self.test_results = [] self.test_results = []
self.created_ids = { self.created_ids = {
@@ -45,7 +45,7 @@ class TestDeveloperEcosystem:
"portal_config": [], "portal_config": [],
} }
def log(self, message: str, success: bool = True): def log(self, message: str, success: bool = True) -> None:
"""记录测试结果""" """记录测试结果"""
status = "" if success else "" status = "" if success else ""
print(f"{status} {message}") print(f"{status} {message}")
@@ -53,11 +53,11 @@ class TestDeveloperEcosystem:
{"message": message, "success": success, "timestamp": datetime.now().isoformat()} {"message": message, "success": success, "timestamp": datetime.now().isoformat()}
) )
def run_all_tests(self): def run_all_tests(self) -> None:
"""运行所有测试""" """运行所有测试"""
print("=" * 60) print(" = " * 60)
print("InsightFlow Phase 8 Task 6: Developer Ecosystem Tests") print("InsightFlow Phase 8 Task 6: Developer Ecosystem Tests")
print("=" * 60) print(" = " * 60)
# SDK Tests # SDK Tests
print("\n📦 SDK Release & Management Tests") print("\n📦 SDK Release & Management Tests")
@@ -119,69 +119,69 @@ class TestDeveloperEcosystem:
# Print Summary # Print Summary
self.print_summary() self.print_summary()
def test_sdk_create(self): def test_sdk_create(self) -> None:
"""测试创建 SDK""" """测试创建 SDK"""
try: try:
sdk = self.manager.create_sdk_release( sdk = self.manager.create_sdk_release(
name="InsightFlow Python SDK", name = "InsightFlow Python SDK",
language=SDKLanguage.PYTHON, language = SDKLanguage.PYTHON,
version="1.0.0", version = "1.0.0",
description="Python SDK for InsightFlow API", description = "Python SDK for InsightFlow API",
changelog="Initial release", changelog = "Initial release",
download_url="https://pypi.org/insightflow/1.0.0", download_url = "https://pypi.org/insightflow/1.0.0",
documentation_url="https://docs.insightflow.io/python", documentation_url = "https://docs.insightflow.io/python",
repository_url="https://github.com/insightflow/python-sdk", repository_url = "https://github.com/insightflow/python-sdk",
package_name="insightflow", package_name = "insightflow",
min_platform_version="1.0.0", min_platform_version = "1.0.0",
dependencies=[{"name": "requests", "version": ">=2.0"}], dependencies = [{"name": "requests", "version": ">= 2.0"}],
file_size=1024000, file_size = 1024000,
checksum="abc123", checksum = "abc123",
created_by="test_user", created_by = "test_user",
) )
self.created_ids["sdk"].append(sdk.id) self.created_ids["sdk"].append(sdk.id)
self.log(f"Created SDK: {sdk.name} ({sdk.id})") self.log(f"Created SDK: {sdk.name} ({sdk.id})")
# Create JavaScript SDK # Create JavaScript SDK
sdk_js = self.manager.create_sdk_release( sdk_js = self.manager.create_sdk_release(
name="InsightFlow JavaScript SDK", name = "InsightFlow JavaScript SDK",
language=SDKLanguage.JAVASCRIPT, language = SDKLanguage.JAVASCRIPT,
version="1.0.0", version = "1.0.0",
description="JavaScript SDK for InsightFlow API", description = "JavaScript SDK for InsightFlow API",
changelog="Initial release", changelog = "Initial release",
download_url="https://npmjs.com/insightflow/1.0.0", download_url = "https://npmjs.com/insightflow/1.0.0",
documentation_url="https://docs.insightflow.io/js", documentation_url = "https://docs.insightflow.io/js",
repository_url="https://github.com/insightflow/js-sdk", repository_url = "https://github.com/insightflow/js-sdk",
package_name="@insightflow/sdk", package_name = "@insightflow/sdk",
min_platform_version="1.0.0", min_platform_version = "1.0.0",
dependencies=[{"name": "axios", "version": ">=0.21"}], dependencies = [{"name": "axios", "version": ">= 0.21"}],
file_size=512000, file_size = 512000,
checksum="def456", checksum = "def456",
created_by="test_user", created_by = "test_user",
) )
self.created_ids["sdk"].append(sdk_js.id) self.created_ids["sdk"].append(sdk_js.id)
self.log(f"Created SDK: {sdk_js.name} ({sdk_js.id})") self.log(f"Created SDK: {sdk_js.name} ({sdk_js.id})")
except Exception as e: except Exception as e:
self.log(f"Failed to create SDK: {str(e)}", success=False) self.log(f"Failed to create SDK: {str(e)}", success = False)
def test_sdk_list(self): def test_sdk_list(self) -> None:
"""测试列出 SDK""" """测试列出 SDK"""
try: try:
sdks = self.manager.list_sdk_releases() sdks = self.manager.list_sdk_releases()
self.log(f"Listed {len(sdks)} SDKs") self.log(f"Listed {len(sdks)} SDKs")
# Test filter by language # Test filter by language
python_sdks = self.manager.list_sdk_releases(language=SDKLanguage.PYTHON) python_sdks = self.manager.list_sdk_releases(language = SDKLanguage.PYTHON)
self.log(f"Found {len(python_sdks)} Python SDKs") self.log(f"Found {len(python_sdks)} Python SDKs")
# Test search # Test search
search_results = self.manager.list_sdk_releases(search="Python") search_results = self.manager.list_sdk_releases(search = "Python")
self.log(f"Search found {len(search_results)} SDKs") self.log(f"Search found {len(search_results)} SDKs")
except Exception as e: except Exception as e:
self.log(f"Failed to list SDKs: {str(e)}", success=False) self.log(f"Failed to list SDKs: {str(e)}", success = False)
def test_sdk_get(self): def test_sdk_get(self) -> None:
"""测试获取 SDK 详情""" """测试获取 SDK 详情"""
try: try:
if self.created_ids["sdk"]: if self.created_ids["sdk"]:
@@ -189,23 +189,23 @@ class TestDeveloperEcosystem:
if sdk: if sdk:
self.log(f"Retrieved SDK: {sdk.name}") self.log(f"Retrieved SDK: {sdk.name}")
else: else:
self.log("SDK not found", success=False) self.log("SDK not found", success = False)
except Exception as e: except Exception as e:
self.log(f"Failed to get SDK: {str(e)}", success=False) self.log(f"Failed to get SDK: {str(e)}", success = False)
def test_sdk_update(self): def test_sdk_update(self) -> None:
"""测试更新 SDK""" """测试更新 SDK"""
try: try:
if self.created_ids["sdk"]: if self.created_ids["sdk"]:
sdk = self.manager.update_sdk_release( sdk = self.manager.update_sdk_release(
self.created_ids["sdk"][0], description="Updated description" self.created_ids["sdk"][0], description = "Updated description"
) )
if sdk: if sdk:
self.log(f"Updated SDK: {sdk.name}") self.log(f"Updated SDK: {sdk.name}")
except Exception as e: except Exception as e:
self.log(f"Failed to update SDK: {str(e)}", success=False) self.log(f"Failed to update SDK: {str(e)}", success = False)
def test_sdk_publish(self): def test_sdk_publish(self) -> None:
"""测试发布 SDK""" """测试发布 SDK"""
try: try:
if self.created_ids["sdk"]: if self.created_ids["sdk"]:
@@ -213,86 +213,86 @@ class TestDeveloperEcosystem:
if sdk: if sdk:
self.log(f"Published SDK: {sdk.name} (status: {sdk.status.value})") self.log(f"Published SDK: {sdk.name} (status: {sdk.status.value})")
except Exception as e: except Exception as e:
self.log(f"Failed to publish SDK: {str(e)}", success=False) self.log(f"Failed to publish SDK: {str(e)}", success = False)
def test_sdk_version_add(self): def test_sdk_version_add(self) -> None:
"""测试添加 SDK 版本""" """测试添加 SDK 版本"""
try: try:
if self.created_ids["sdk"]: if self.created_ids["sdk"]:
version = self.manager.add_sdk_version( version = self.manager.add_sdk_version(
sdk_id=self.created_ids["sdk"][0], sdk_id = self.created_ids["sdk"][0],
version="1.1.0", version = "1.1.0",
is_lts=True, is_lts = True,
release_notes="Bug fixes and improvements", release_notes = "Bug fixes and improvements",
download_url="https://pypi.org/insightflow/1.1.0", download_url = "https://pypi.org/insightflow/1.1.0",
checksum="xyz789", checksum = "xyz789",
file_size=1100000, file_size = 1100000,
) )
self.log(f"Added SDK version: {version.version}") self.log(f"Added SDK version: {version.version}")
except Exception as e: except Exception as e:
self.log(f"Failed to add SDK version: {str(e)}", success=False) self.log(f"Failed to add SDK version: {str(e)}", success = False)
def test_template_create(self): def test_template_create(self) -> None:
"""测试创建模板""" """测试创建模板"""
try: try:
template = self.manager.create_template( template = self.manager.create_template(
name="医疗行业实体识别模板", name = "医疗行业实体识别模板",
description="专门针对医疗行业的实体识别模板,支持疾病、药物、症状等实体", description = "专门针对医疗行业的实体识别模板,支持疾病、药物、症状等实体",
category=TemplateCategory.MEDICAL, category = TemplateCategory.MEDICAL,
subcategory="entity_recognition", subcategory = "entity_recognition",
tags=["medical", "healthcare", "ner"], tags = ["medical", "healthcare", "ner"],
author_id="dev_001", author_id = "dev_001",
author_name="Medical AI Lab", author_name = "Medical AI Lab",
price=99.0, price = 99.0,
currency="CNY", currency = "CNY",
preview_image_url="https://cdn.insightflow.io/templates/medical.png", preview_image_url = "https://cdn.insightflow.io/templates/medical.png",
demo_url="https://demo.insightflow.io/medical", demo_url = "https://demo.insightflow.io/medical",
documentation_url="https://docs.insightflow.io/templates/medical", documentation_url = "https://docs.insightflow.io/templates/medical",
download_url="https://cdn.insightflow.io/templates/medical.zip", download_url = "https://cdn.insightflow.io/templates/medical.zip",
version="1.0.0", version = "1.0.0",
min_platform_version="2.0.0", min_platform_version = "2.0.0",
file_size=5242880, file_size = 5242880,
checksum="tpl123", checksum = "tpl123",
) )
self.created_ids["template"].append(template.id) self.created_ids["template"].append(template.id)
self.log(f"Created template: {template.name} ({template.id})") self.log(f"Created template: {template.name} ({template.id})")
# Create free template # Create free template
template_free = self.manager.create_template( template_free = self.manager.create_template(
name="通用实体识别模板", name = "通用实体识别模板",
description="适用于一般场景的实体识别模板", description = "适用于一般场景的实体识别模板",
category=TemplateCategory.GENERAL, category = TemplateCategory.GENERAL,
subcategory=None, subcategory = None,
tags=["general", "ner", "basic"], tags = ["general", "ner", "basic"],
author_id="dev_002", author_id = "dev_002",
author_name="InsightFlow Team", author_name = "InsightFlow Team",
price=0.0, price = 0.0,
currency="CNY", currency = "CNY",
) )
self.created_ids["template"].append(template_free.id) self.created_ids["template"].append(template_free.id)
self.log(f"Created free template: {template_free.name}") self.log(f"Created free template: {template_free.name}")
except Exception as e: except Exception as e:
self.log(f"Failed to create template: {str(e)}", success=False) self.log(f"Failed to create template: {str(e)}", success = False)
def test_template_list(self): def test_template_list(self) -> None:
"""测试列出模板""" """测试列出模板"""
try: try:
templates = self.manager.list_templates() templates = self.manager.list_templates()
self.log(f"Listed {len(templates)} templates") self.log(f"Listed {len(templates)} templates")
# Filter by category # Filter by category
medical_templates = self.manager.list_templates(category=TemplateCategory.MEDICAL) medical_templates = self.manager.list_templates(category = TemplateCategory.MEDICAL)
self.log(f"Found {len(medical_templates)} medical templates") self.log(f"Found {len(medical_templates)} medical templates")
# Filter by price # Filter by price
free_templates = self.manager.list_templates(max_price=0) free_templates = self.manager.list_templates(max_price = 0)
self.log(f"Found {len(free_templates)} free templates") self.log(f"Found {len(free_templates)} free templates")
except Exception as e: except Exception as e:
self.log(f"Failed to list templates: {str(e)}", success=False) self.log(f"Failed to list templates: {str(e)}", success = False)
def test_template_get(self): def test_template_get(self) -> None:
"""测试获取模板详情""" """测试获取模板详情"""
try: try:
if self.created_ids["template"]: if self.created_ids["template"]:
@@ -300,21 +300,21 @@ class TestDeveloperEcosystem:
if template: if template:
self.log(f"Retrieved template: {template.name}") self.log(f"Retrieved template: {template.name}")
except Exception as e: except Exception as e:
self.log(f"Failed to get template: {str(e)}", success=False) self.log(f"Failed to get template: {str(e)}", success = False)
def test_template_approve(self): def test_template_approve(self) -> None:
"""测试审核通过模板""" """测试审核通过模板"""
try: try:
if self.created_ids["template"]: if self.created_ids["template"]:
template = self.manager.approve_template( template = self.manager.approve_template(
self.created_ids["template"][0], reviewed_by="admin_001" self.created_ids["template"][0], reviewed_by = "admin_001"
) )
if template: if template:
self.log(f"Approved template: {template.name}") self.log(f"Approved template: {template.name}")
except Exception as e: except Exception as e:
self.log(f"Failed to approve template: {str(e)}", success=False) self.log(f"Failed to approve template: {str(e)}", success = False)
def test_template_publish(self): def test_template_publish(self) -> None:
"""测试发布模板""" """测试发布模板"""
try: try:
if self.created_ids["template"]: if self.created_ids["template"]:
@@ -322,84 +322,84 @@ class TestDeveloperEcosystem:
if template: if template:
self.log(f"Published template: {template.name}") self.log(f"Published template: {template.name}")
except Exception as e: except Exception as e:
self.log(f"Failed to publish template: {str(e)}", success=False) self.log(f"Failed to publish template: {str(e)}", success = False)
def test_template_review(self): def test_template_review(self) -> None:
"""测试添加模板评价""" """测试添加模板评价"""
try: try:
if self.created_ids["template"]: if self.created_ids["template"]:
review = self.manager.add_template_review( review = self.manager.add_template_review(
template_id=self.created_ids["template"][0], template_id = self.created_ids["template"][0],
user_id="user_001", user_id = "user_001",
user_name="Test User", user_name = "Test User",
rating=5, rating = 5,
comment="Great template! Very accurate for medical entities.", comment = "Great template! Very accurate for medical entities.",
is_verified_purchase=True, is_verified_purchase = True,
) )
self.log(f"Added template review: {review.rating} stars") self.log(f"Added template review: {review.rating} stars")
except Exception as e: except Exception as e:
self.log(f"Failed to add template review: {str(e)}", success=False) self.log(f"Failed to add template review: {str(e)}", success = False)
def test_plugin_create(self): def test_plugin_create(self) -> None:
"""测试创建插件""" """测试创建插件"""
try: try:
plugin = self.manager.create_plugin( plugin = self.manager.create_plugin(
name="飞书机器人集成插件", name = "飞书机器人集成插件",
description="将 InsightFlow 与飞书机器人集成,实现自动通知", description = "将 InsightFlow 与飞书机器人集成,实现自动通知",
category=PluginCategory.INTEGRATION, category = PluginCategory.INTEGRATION,
tags=["feishu", "bot", "integration", "notification"], tags = ["feishu", "bot", "integration", "notification"],
author_id="dev_003", author_id = "dev_003",
author_name="Integration Team", author_name = "Integration Team",
price=49.0, price = 49.0,
currency="CNY", currency = "CNY",
pricing_model="paid", pricing_model = "paid",
preview_image_url="https://cdn.insightflow.io/plugins/feishu.png", preview_image_url = "https://cdn.insightflow.io/plugins/feishu.png",
demo_url="https://demo.insightflow.io/feishu", demo_url = "https://demo.insightflow.io/feishu",
documentation_url="https://docs.insightflow.io/plugins/feishu", documentation_url = "https://docs.insightflow.io/plugins/feishu",
repository_url="https://github.com/insightflow/feishu-plugin", repository_url = "https://github.com/insightflow/feishu-plugin",
download_url="https://cdn.insightflow.io/plugins/feishu.zip", download_url = "https://cdn.insightflow.io/plugins/feishu.zip",
webhook_url="https://api.insightflow.io/webhooks/feishu", webhook_url = "https://api.insightflow.io/webhooks/feishu",
permissions=["read:projects", "write:notifications"], permissions = ["read:projects", "write:notifications"],
version="1.0.0", version = "1.0.0",
min_platform_version="2.0.0", min_platform_version = "2.0.0",
file_size=1048576, file_size = 1048576,
checksum="plg123", checksum = "plg123",
) )
self.created_ids["plugin"].append(plugin.id) self.created_ids["plugin"].append(plugin.id)
self.log(f"Created plugin: {plugin.name} ({plugin.id})") self.log(f"Created plugin: {plugin.name} ({plugin.id})")
# Create free plugin # Create free plugin
plugin_free = self.manager.create_plugin( plugin_free = self.manager.create_plugin(
name="数据导出插件", name = "数据导出插件",
description="支持多种格式的数据导出", description = "支持多种格式的数据导出",
category=PluginCategory.ANALYSIS, category = PluginCategory.ANALYSIS,
tags=["export", "data", "csv", "json"], tags = ["export", "data", "csv", "json"],
author_id="dev_004", author_id = "dev_004",
author_name="Data Team", author_name = "Data Team",
price=0.0, price = 0.0,
currency="CNY", currency = "CNY",
pricing_model="free", pricing_model = "free",
) )
self.created_ids["plugin"].append(plugin_free.id) self.created_ids["plugin"].append(plugin_free.id)
self.log(f"Created free plugin: {plugin_free.name}") self.log(f"Created free plugin: {plugin_free.name}")
except Exception as e: except Exception as e:
self.log(f"Failed to create plugin: {str(e)}", success=False) self.log(f"Failed to create plugin: {str(e)}", success = False)
def test_plugin_list(self): def test_plugin_list(self) -> None:
"""测试列出插件""" """测试列出插件"""
try: try:
plugins = self.manager.list_plugins() plugins = self.manager.list_plugins()
self.log(f"Listed {len(plugins)} plugins") self.log(f"Listed {len(plugins)} plugins")
# Filter by category # Filter by category
integration_plugins = self.manager.list_plugins(category=PluginCategory.INTEGRATION) integration_plugins = self.manager.list_plugins(category = PluginCategory.INTEGRATION)
self.log(f"Found {len(integration_plugins)} integration plugins") self.log(f"Found {len(integration_plugins)} integration plugins")
except Exception as e: except Exception as e:
self.log(f"Failed to list plugins: {str(e)}", success=False) self.log(f"Failed to list plugins: {str(e)}", success = False)
def test_plugin_get(self): def test_plugin_get(self) -> None:
"""测试获取插件详情""" """测试获取插件详情"""
try: try:
if self.created_ids["plugin"]: if self.created_ids["plugin"]:
@@ -407,24 +407,24 @@ class TestDeveloperEcosystem:
if plugin: if plugin:
self.log(f"Retrieved plugin: {plugin.name}") self.log(f"Retrieved plugin: {plugin.name}")
except Exception as e: except Exception as e:
self.log(f"Failed to get plugin: {str(e)}", success=False) self.log(f"Failed to get plugin: {str(e)}", success = False)
def test_plugin_review(self): def test_plugin_review(self) -> None:
"""测试审核插件""" """测试审核插件"""
try: try:
if self.created_ids["plugin"]: if self.created_ids["plugin"]:
plugin = self.manager.review_plugin( plugin = self.manager.review_plugin(
self.created_ids["plugin"][0], self.created_ids["plugin"][0],
reviewed_by="admin_001", reviewed_by = "admin_001",
status=PluginStatus.APPROVED, status = PluginStatus.APPROVED,
notes="Code review passed", notes = "Code review passed",
) )
if plugin: if plugin:
self.log(f"Reviewed plugin: {plugin.name} ({plugin.status.value})") self.log(f"Reviewed plugin: {plugin.name} ({plugin.status.value})")
except Exception as e: except Exception as e:
self.log(f"Failed to review plugin: {str(e)}", success=False) self.log(f"Failed to review plugin: {str(e)}", success = False)
def test_plugin_publish(self): def test_plugin_publish(self) -> None:
"""测试发布插件""" """测试发布插件"""
try: try:
if self.created_ids["plugin"]: if self.created_ids["plugin"]:
@@ -432,56 +432,56 @@ class TestDeveloperEcosystem:
if plugin: if plugin:
self.log(f"Published plugin: {plugin.name}") self.log(f"Published plugin: {plugin.name}")
except Exception as e: except Exception as e:
self.log(f"Failed to publish plugin: {str(e)}", success=False) self.log(f"Failed to publish plugin: {str(e)}", success = False)
def test_plugin_review_add(self): def test_plugin_review_add(self) -> None:
"""测试添加插件评价""" """测试添加插件评价"""
try: try:
if self.created_ids["plugin"]: if self.created_ids["plugin"]:
review = self.manager.add_plugin_review( review = self.manager.add_plugin_review(
plugin_id=self.created_ids["plugin"][0], plugin_id = self.created_ids["plugin"][0],
user_id="user_002", user_id = "user_002",
user_name="Plugin User", user_name = "Plugin User",
rating=4, rating = 4,
comment="Works great with Feishu!", comment = "Works great with Feishu!",
is_verified_purchase=True, is_verified_purchase = True,
) )
self.log(f"Added plugin review: {review.rating} stars") self.log(f"Added plugin review: {review.rating} stars")
except Exception as e: except Exception as e:
self.log(f"Failed to add plugin review: {str(e)}", success=False) self.log(f"Failed to add plugin review: {str(e)}", success = False)
def test_developer_profile_create(self): def test_developer_profile_create(self) -> None:
"""测试创建开发者档案""" """测试创建开发者档案"""
try: try:
# Generate unique user IDs # Generate unique user IDs
unique_id = uuid.uuid4().hex[:8] unique_id = uuid.uuid4().hex[:8]
profile = self.manager.create_developer_profile( profile = self.manager.create_developer_profile(
user_id=f"user_dev_{unique_id}_001", user_id = f"user_dev_{unique_id}_001",
display_name="张三", display_name = "张三",
email=f"zhangsan_{unique_id}@example.com", email = f"zhangsan_{unique_id}@example.com",
bio="专注于医疗AI和自然语言处理", bio = "专注于医疗AI和自然语言处理",
website="https://zhangsan.dev", website = "https://zhangsan.dev",
github_url="https://github.com/zhangsan", github_url = "https://github.com/zhangsan",
avatar_url="https://cdn.example.com/avatars/zhangsan.png", avatar_url = "https://cdn.example.com/avatars/zhangsan.png",
) )
self.created_ids["developer"].append(profile.id) self.created_ids["developer"].append(profile.id)
self.log(f"Created developer profile: {profile.display_name} ({profile.id})") self.log(f"Created developer profile: {profile.display_name} ({profile.id})")
# Create another developer # Create another developer
profile2 = self.manager.create_developer_profile( profile2 = self.manager.create_developer_profile(
user_id=f"user_dev_{unique_id}_002", user_id = f"user_dev_{unique_id}_002",
display_name="李四", display_name = "李四",
email=f"lisi_{unique_id}@example.com", email = f"lisi_{unique_id}@example.com",
bio="全栈开发者,热爱开源", bio = "全栈开发者,热爱开源",
) )
self.created_ids["developer"].append(profile2.id) self.created_ids["developer"].append(profile2.id)
self.log(f"Created developer profile: {profile2.display_name}") self.log(f"Created developer profile: {profile2.display_name}")
except Exception as e: except Exception as e:
self.log(f"Failed to create developer profile: {str(e)}", success=False) self.log(f"Failed to create developer profile: {str(e)}", success = False)
def test_developer_profile_get(self): def test_developer_profile_get(self) -> None:
"""测试获取开发者档案""" """测试获取开发者档案"""
try: try:
if self.created_ids["developer"]: if self.created_ids["developer"]:
@@ -489,9 +489,9 @@ class TestDeveloperEcosystem:
if profile: if profile:
self.log(f"Retrieved developer profile: {profile.display_name}") self.log(f"Retrieved developer profile: {profile.display_name}")
except Exception as e: except Exception as e:
self.log(f"Failed to get developer profile: {str(e)}", success=False) self.log(f"Failed to get developer profile: {str(e)}", success = False)
def test_developer_verify(self): def test_developer_verify(self) -> None:
"""测试验证开发者""" """测试验证开发者"""
try: try:
if self.created_ids["developer"]: if self.created_ids["developer"]:
@@ -501,9 +501,9 @@ class TestDeveloperEcosystem:
if profile: if profile:
self.log(f"Verified developer: {profile.display_name} ({profile.status.value})") self.log(f"Verified developer: {profile.display_name} ({profile.status.value})")
except Exception as e: except Exception as e:
self.log(f"Failed to verify developer: {str(e)}", success=False) self.log(f"Failed to verify developer: {str(e)}", success = False)
def test_developer_stats_update(self): def test_developer_stats_update(self) -> None:
"""测试更新开发者统计""" """测试更新开发者统计"""
try: try:
if self.created_ids["developer"]: if self.created_ids["developer"]:
@@ -513,38 +513,38 @@ class TestDeveloperEcosystem:
f"Updated developer stats: {profile.plugin_count} plugins, {profile.template_count} templates" f"Updated developer stats: {profile.plugin_count} plugins, {profile.template_count} templates"
) )
except Exception as e: except Exception as e:
self.log(f"Failed to update developer stats: {str(e)}", success=False) self.log(f"Failed to update developer stats: {str(e)}", success = False)
def test_code_example_create(self): def test_code_example_create(self) -> None:
"""测试创建代码示例""" """测试创建代码示例"""
try: try:
example = self.manager.create_code_example( example = self.manager.create_code_example(
title="使用 Python SDK 创建项目", title = "使用 Python SDK 创建项目",
description="演示如何使用 Python SDK 创建新项目", description = "演示如何使用 Python SDK 创建新项目",
language="python", language = "python",
category="quickstart", category = "quickstart",
code="""from insightflow import Client code = """from insightflow import Client
client = Client(api_key="your_api_key") client = Client(api_key = "your_api_key")
project = client.projects.create(name="My Project") project = client.projects.create(name = "My Project")
print(f"Created project: {project.id}") print(f"Created project: {project.id}")
""", """,
explanation="首先导入 Client 类,然后使用 API Key 初始化客户端,最后调用 create 方法创建项目。", explanation = "首先导入 Client 类,然后使用 API Key 初始化客户端,最后调用 create 方法创建项目。",
tags=["python", "quickstart", "projects"], tags = ["python", "quickstart", "projects"],
author_id="dev_001", author_id = "dev_001",
author_name="InsightFlow Team", author_name = "InsightFlow Team",
api_endpoints=["/api/v1/projects"], api_endpoints = ["/api/v1/projects"],
) )
self.created_ids["code_example"].append(example.id) self.created_ids["code_example"].append(example.id)
self.log(f"Created code example: {example.title}") self.log(f"Created code example: {example.title}")
# Create JavaScript example # Create JavaScript example
example_js = self.manager.create_code_example( example_js = self.manager.create_code_example(
title="使用 JavaScript SDK 上传文件", title = "使用 JavaScript SDK 上传文件",
description="演示如何使用 JavaScript SDK 上传音频文件", description = "演示如何使用 JavaScript SDK 上传音频文件",
language="javascript", language = "javascript",
category="upload", category = "upload",
code="""const { Client } = require('insightflow'); code = """const { Client } = require('insightflow');
const client = new Client({ apiKey: 'your_api_key' }); const client = new Client({ apiKey: 'your_api_key' });
const result = await client.uploads.create({ const result = await client.uploads.create({
@@ -553,31 +553,31 @@ const result = await client.uploads.create({
}); });
console.log('Upload complete:', result.id); console.log('Upload complete:', result.id);
""", """,
explanation="使用 JavaScript SDK 上传文件到 InsightFlow", explanation = "使用 JavaScript SDK 上传文件到 InsightFlow",
tags=["javascript", "upload", "audio"], tags = ["javascript", "upload", "audio"],
author_id="dev_002", author_id = "dev_002",
author_name="JS Team", author_name = "JS Team",
) )
self.created_ids["code_example"].append(example_js.id) self.created_ids["code_example"].append(example_js.id)
self.log(f"Created code example: {example_js.title}") self.log(f"Created code example: {example_js.title}")
except Exception as e: except Exception as e:
self.log(f"Failed to create code example: {str(e)}", success=False) self.log(f"Failed to create code example: {str(e)}", success = False)
def test_code_example_list(self): def test_code_example_list(self) -> None:
"""测试列出代码示例""" """测试列出代码示例"""
try: try:
examples = self.manager.list_code_examples() examples = self.manager.list_code_examples()
self.log(f"Listed {len(examples)} code examples") self.log(f"Listed {len(examples)} code examples")
# Filter by language # Filter by language
python_examples = self.manager.list_code_examples(language="python") python_examples = self.manager.list_code_examples(language = "python")
self.log(f"Found {len(python_examples)} Python examples") self.log(f"Found {len(python_examples)} Python examples")
except Exception as e: except Exception as e:
self.log(f"Failed to list code examples: {str(e)}", success=False) self.log(f"Failed to list code examples: {str(e)}", success = False)
def test_code_example_get(self): def test_code_example_get(self) -> None:
"""测试获取代码示例详情""" """测试获取代码示例详情"""
try: try:
if self.created_ids["code_example"]: if self.created_ids["code_example"]:
@@ -587,30 +587,30 @@ console.log('Upload complete:', result.id);
f"Retrieved code example: {example.title} (views: {example.view_count})" f"Retrieved code example: {example.title} (views: {example.view_count})"
) )
except Exception as e: except Exception as e:
self.log(f"Failed to get code example: {str(e)}", success=False) self.log(f"Failed to get code example: {str(e)}", success = False)
def test_portal_config_create(self): def test_portal_config_create(self) -> None:
"""测试创建开发者门户配置""" """测试创建开发者门户配置"""
try: try:
config = self.manager.create_portal_config( config = self.manager.create_portal_config(
name="InsightFlow Developer Portal", name = "InsightFlow Developer Portal",
description="开发者门户 - SDK、API 文档和示例代码", description = "开发者门户 - SDK、API 文档和示例代码",
theme="default", theme = "default",
primary_color="#1890ff", primary_color = "#1890ff",
secondary_color="#52c41a", secondary_color = "#52c41a",
support_email="developers@insightflow.io", support_email = "developers@insightflow.io",
support_url="https://support.insightflow.io", support_url = "https://support.insightflow.io",
github_url="https://github.com/insightflow", github_url = "https://github.com/insightflow",
discord_url="https://discord.gg/insightflow", discord_url = "https://discord.gg/insightflow",
api_base_url="https://api.insightflow.io/v1", api_base_url = "https://api.insightflow.io/v1",
) )
self.created_ids["portal_config"].append(config.id) self.created_ids["portal_config"].append(config.id)
self.log(f"Created portal config: {config.name}") self.log(f"Created portal config: {config.name}")
except Exception as e: except Exception as e:
self.log(f"Failed to create portal config: {str(e)}", success=False) self.log(f"Failed to create portal config: {str(e)}", success = False)
def test_portal_config_get(self): def test_portal_config_get(self) -> None:
"""测试获取开发者门户配置""" """测试获取开发者门户配置"""
try: try:
if self.created_ids["portal_config"]: if self.created_ids["portal_config"]:
@@ -624,29 +624,29 @@ console.log('Upload complete:', result.id);
self.log(f"Active portal config: {active_config.name}") self.log(f"Active portal config: {active_config.name}")
except Exception as e: except Exception as e:
self.log(f"Failed to get portal config: {str(e)}", success=False) self.log(f"Failed to get portal config: {str(e)}", success = False)
def test_revenue_record(self): def test_revenue_record(self) -> None:
"""测试记录开发者收益""" """测试记录开发者收益"""
try: try:
if self.created_ids["developer"] and self.created_ids["plugin"]: if self.created_ids["developer"] and self.created_ids["plugin"]:
revenue = self.manager.record_revenue( revenue = self.manager.record_revenue(
developer_id=self.created_ids["developer"][0], developer_id = self.created_ids["developer"][0],
item_type="plugin", item_type = "plugin",
item_id=self.created_ids["plugin"][0], item_id = self.created_ids["plugin"][0],
item_name="飞书机器人集成插件", item_name = "飞书机器人集成插件",
sale_amount=49.0, sale_amount = 49.0,
currency="CNY", currency = "CNY",
buyer_id="user_buyer_001", buyer_id = "user_buyer_001",
transaction_id="txn_123456", transaction_id = "txn_123456",
) )
self.log(f"Recorded revenue: {revenue.sale_amount} {revenue.currency}") self.log(f"Recorded revenue: {revenue.sale_amount} {revenue.currency}")
self.log(f" - Platform fee: {revenue.platform_fee}") self.log(f" - Platform fee: {revenue.platform_fee}")
self.log(f" - Developer earnings: {revenue.developer_earnings}") self.log(f" - Developer earnings: {revenue.developer_earnings}")
except Exception as e: except Exception as e:
self.log(f"Failed to record revenue: {str(e)}", success=False) self.log(f"Failed to record revenue: {str(e)}", success = False)
def test_revenue_summary(self): def test_revenue_summary(self) -> None:
"""测试获取开发者收益汇总""" """测试获取开发者收益汇总"""
try: try:
if self.created_ids["developer"]: if self.created_ids["developer"]:
@@ -659,13 +659,13 @@ console.log('Upload complete:', result.id);
self.log(f" - Total earnings: {summary['total_earnings']}") self.log(f" - Total earnings: {summary['total_earnings']}")
self.log(f" - Transaction count: {summary['transaction_count']}") self.log(f" - Transaction count: {summary['transaction_count']}")
except Exception as e: except Exception as e:
self.log(f"Failed to get revenue summary: {str(e)}", success=False) self.log(f"Failed to get revenue summary: {str(e)}", success = False)
def print_summary(self): def print_summary(self) -> None:
"""打印测试摘要""" """打印测试摘要"""
print("\n" + "=" * 60) print("\n" + " = " * 60)
print("Test Summary") print("Test Summary")
print("=" * 60) print(" = " * 60)
total = len(self.test_results) total = len(self.test_results)
passed = sum(1 for r in self.test_results if r["success"]) passed = sum(1 for r in self.test_results if r["success"])
@@ -686,10 +686,10 @@ console.log('Upload complete:', result.id);
if ids: if ids:
print(f" {resource_type}: {len(ids)}") print(f" {resource_type}: {len(ids)}")
print("=" * 60) print(" = " * 60)
def main(): def main() -> None:
"""主函数""" """主函数"""
test = TestDeveloperEcosystem() test = TestDeveloperEcosystem()
test.run_all_tests() test.run_all_tests()

View File

@@ -34,22 +34,22 @@ if backend_dir not in sys.path:
class TestOpsManager: class TestOpsManager:
"""测试运维与监控管理器""" """测试运维与监控管理器"""
def __init__(self): def __init__(self) -> None:
self.manager = get_ops_manager() self.manager = get_ops_manager()
self.tenant_id = "test_tenant_001" self.tenant_id = "test_tenant_001"
self.test_results = [] self.test_results = []
def log(self, message: str, success: bool = True): def log(self, message: str, success: bool = True) -> None:
"""记录测试结果""" """记录测试结果"""
status = "" if success else "" status = "" if success else ""
print(f"{status} {message}") print(f"{status} {message}")
self.test_results.append((message, success)) self.test_results.append((message, success))
def run_all_tests(self): def run_all_tests(self) -> None:
"""运行所有测试""" """运行所有测试"""
print("=" * 60) print(" = " * 60)
print("InsightFlow Phase 8 Task 8: Operations & Monitoring Tests") print("InsightFlow Phase 8 Task 8: Operations & Monitoring Tests")
print("=" * 60) print(" = " * 60)
# 1. 告警系统测试 # 1. 告警系统测试
self.test_alert_rules() self.test_alert_rules()
@@ -73,46 +73,46 @@ class TestOpsManager:
# 打印测试总结 # 打印测试总结
self.print_summary() self.print_summary()
def test_alert_rules(self): def test_alert_rules(self) -> None:
"""测试告警规则管理""" """测试告警规则管理"""
print("\n📋 Testing Alert Rules...") print("\n📋 Testing Alert Rules...")
try: try:
# 创建阈值告警规则 # 创建阈值告警规则
rule1 = self.manager.create_alert_rule( rule1 = self.manager.create_alert_rule(
tenant_id=self.tenant_id, tenant_id = self.tenant_id,
name="CPU 使用率告警", name = "CPU 使用率告警",
description="当 CPU 使用率超过 80% 时触发告警", description = "当 CPU 使用率超过 80% 时触发告警",
rule_type=AlertRuleType.THRESHOLD, rule_type = AlertRuleType.THRESHOLD,
severity=AlertSeverity.P1, severity = AlertSeverity.P1,
metric="cpu_usage_percent", metric = "cpu_usage_percent",
condition=">", condition = ">",
threshold=80.0, threshold = 80.0,
duration=300, duration = 300,
evaluation_interval=60, evaluation_interval = 60,
channels=[], channels = [],
labels={"service": "api", "team": "platform"}, labels = {"service": "api", "team": "platform"},
annotations={"summary": "CPU 使用率过高", "runbook": "https://wiki/runbooks/cpu"}, annotations = {"summary": "CPU 使用率过高", "runbook": "https://wiki/runbooks/cpu"},
created_by="test_user", created_by = "test_user",
) )
self.log(f"Created alert rule: {rule1.name} (ID: {rule1.id})") self.log(f"Created alert rule: {rule1.name} (ID: {rule1.id})")
# 创建异常检测告警规则 # 创建异常检测告警规则
rule2 = self.manager.create_alert_rule( rule2 = self.manager.create_alert_rule(
tenant_id=self.tenant_id, tenant_id = self.tenant_id,
name="内存异常检测", name = "内存异常检测",
description="检测内存使用异常", description = "检测内存使用异常",
rule_type=AlertRuleType.ANOMALY, rule_type = AlertRuleType.ANOMALY,
severity=AlertSeverity.P2, severity = AlertSeverity.P2,
metric="memory_usage_percent", metric = "memory_usage_percent",
condition=">", condition = ">",
threshold=0.0, threshold = 0.0,
duration=600, duration = 600,
evaluation_interval=300, evaluation_interval = 300,
channels=[], channels = [],
labels={"service": "database"}, labels = {"service": "database"},
annotations={}, annotations = {},
created_by="test_user", created_by = "test_user",
) )
self.log(f"Created anomaly alert rule: {rule2.name} (ID: {rule2.id})") self.log(f"Created anomaly alert rule: {rule2.name} (ID: {rule2.id})")
@@ -129,7 +129,7 @@ class TestOpsManager:
# 更新告警规则 # 更新告警规则
updated_rule = self.manager.update_alert_rule( updated_rule = self.manager.update_alert_rule(
rule1.id, threshold=85.0, description="更新后的描述" rule1.id, threshold = 85.0, description = "更新后的描述"
) )
assert updated_rule.threshold == 85.0 assert updated_rule.threshold == 85.0
self.log(f"Updated alert rule threshold to {updated_rule.threshold}") self.log(f"Updated alert rule threshold to {updated_rule.threshold}")
@@ -140,46 +140,46 @@ class TestOpsManager:
self.log("Deleted test alert rules") self.log("Deleted test alert rules")
except Exception as e: except Exception as e:
self.log(f"Alert rules test failed: {e}", success=False) self.log(f"Alert rules test failed: {e}", success = False)
def test_alert_channels(self): def test_alert_channels(self) -> None:
"""测试告警渠道管理""" """测试告警渠道管理"""
print("\n📢 Testing Alert Channels...") print("\n📢 Testing Alert Channels...")
try: try:
# 创建飞书告警渠道 # 创建飞书告警渠道
channel1 = self.manager.create_alert_channel( channel1 = self.manager.create_alert_channel(
tenant_id=self.tenant_id, tenant_id = self.tenant_id,
name="飞书告警", name = "飞书告警",
channel_type=AlertChannelType.FEISHU, channel_type = AlertChannelType.FEISHU,
config={ config = {
"webhook_url": "https://open.feishu.cn/open-apis/bot/v2/hook/test", "webhook_url": "https://open.feishu.cn/open-apis/bot/v2/hook/test",
"secret": "test_secret", "secret": "test_secret",
}, },
severity_filter=["p0", "p1"], severity_filter = ["p0", "p1"],
) )
self.log(f"Created Feishu channel: {channel1.name} (ID: {channel1.id})") self.log(f"Created Feishu channel: {channel1.name} (ID: {channel1.id})")
# 创建钉钉告警渠道 # 创建钉钉告警渠道
channel2 = self.manager.create_alert_channel( channel2 = self.manager.create_alert_channel(
tenant_id=self.tenant_id, tenant_id = self.tenant_id,
name="钉钉告警", name = "钉钉告警",
channel_type=AlertChannelType.DINGTALK, channel_type = AlertChannelType.DINGTALK,
config={ config = {
"webhook_url": "https://oapi.dingtalk.com/robot/send?access_token=test", "webhook_url": "https://oapi.dingtalk.com/robot/send?access_token = test",
"secret": "test_secret", "secret": "test_secret",
}, },
severity_filter=["p0", "p1", "p2"], severity_filter = ["p0", "p1", "p2"],
) )
self.log(f"Created DingTalk channel: {channel2.name} (ID: {channel2.id})") self.log(f"Created DingTalk channel: {channel2.name} (ID: {channel2.id})")
# 创建 Slack 告警渠道 # 创建 Slack 告警渠道
channel3 = self.manager.create_alert_channel( channel3 = self.manager.create_alert_channel(
tenant_id=self.tenant_id, tenant_id = self.tenant_id,
name="Slack 告警", name = "Slack 告警",
channel_type=AlertChannelType.SLACK, channel_type = AlertChannelType.SLACK,
config={"webhook_url": "https://hooks.slack.com/services/test"}, config = {"webhook_url": "https://hooks.slack.com/services/test"},
severity_filter=["p0", "p1", "p2", "p3"], severity_filter = ["p0", "p1", "p2", "p3"],
) )
self.log(f"Created Slack channel: {channel3.name} (ID: {channel3.id})") self.log(f"Created Slack channel: {channel3.name} (ID: {channel3.id})")
@@ -198,46 +198,46 @@ class TestOpsManager:
for channel in channels: for channel in channels:
if channel.tenant_id == self.tenant_id: if channel.tenant_id == self.tenant_id:
with self.manager._get_db() as conn: with self.manager._get_db() as conn:
conn.execute("DELETE FROM alert_channels WHERE id = ?", (channel.id,)) conn.execute("DELETE FROM alert_channels WHERE id = ?", (channel.id, ))
conn.commit() conn.commit()
self.log("Deleted test alert channels") self.log("Deleted test alert channels")
except Exception as e: except Exception as e:
self.log(f"Alert channels test failed: {e}", success=False) self.log(f"Alert channels test failed: {e}", success = False)
def test_alerts(self): def test_alerts(self) -> None:
"""测试告警管理""" """测试告警管理"""
print("\n🚨 Testing Alerts...") print("\n🚨 Testing Alerts...")
try: try:
# 创建告警规则 # 创建告警规则
rule = self.manager.create_alert_rule( rule = self.manager.create_alert_rule(
tenant_id=self.tenant_id, tenant_id = self.tenant_id,
name="测试告警规则", name = "测试告警规则",
description="用于测试的告警规则", description = "用于测试的告警规则",
rule_type=AlertRuleType.THRESHOLD, rule_type = AlertRuleType.THRESHOLD,
severity=AlertSeverity.P1, severity = AlertSeverity.P1,
metric="test_metric", metric = "test_metric",
condition=">", condition = ">",
threshold=100.0, threshold = 100.0,
duration=60, duration = 60,
evaluation_interval=60, evaluation_interval = 60,
channels=[], channels = [],
labels={}, labels = {},
annotations={}, annotations = {},
created_by="test_user", created_by = "test_user",
) )
# 记录资源指标 # 记录资源指标
for i in range(10): for i in range(10):
self.manager.record_resource_metric( self.manager.record_resource_metric(
tenant_id=self.tenant_id, tenant_id = self.tenant_id,
resource_type=ResourceType.CPU, resource_type = ResourceType.CPU,
resource_id="server-001", resource_id = "server-001",
metric_name="test_metric", metric_name = "test_metric",
metric_value=110.0 + i, metric_value = 110.0 + i,
unit="percent", unit = "percent",
metadata={"region": "cn-north-1"}, metadata = {"region": "cn-north-1"},
) )
self.log("Recorded 10 resource metrics") self.log("Recorded 10 resource metrics")
@@ -248,24 +248,24 @@ class TestOpsManager:
now = datetime.now().isoformat() now = datetime.now().isoformat()
alert = Alert( alert = Alert(
id=alert_id, id = alert_id,
rule_id=rule.id, rule_id = rule.id,
tenant_id=self.tenant_id, tenant_id = self.tenant_id,
severity=AlertSeverity.P1, severity = AlertSeverity.P1,
status=AlertStatus.FIRING, status = AlertStatus.FIRING,
title="测试告警", title = "测试告警",
description="这是一条测试告警", description = "这是一条测试告警",
metric="test_metric", metric = "test_metric",
value=120.0, value = 120.0,
threshold=100.0, threshold = 100.0,
labels={"test": "true"}, labels = {"test": "true"},
annotations={}, annotations = {},
started_at=now, started_at = now,
resolved_at=None, resolved_at = None,
acknowledged_by=None, acknowledged_by = None,
acknowledged_at=None, acknowledged_at = None,
notification_sent={}, notification_sent = {},
suppression_count=0, suppression_count = 0,
) )
with self.manager._get_db() as conn: with self.manager._get_db() as conn:
@@ -320,23 +320,23 @@ class TestOpsManager:
# 清理 # 清理
self.manager.delete_alert_rule(rule.id) self.manager.delete_alert_rule(rule.id)
with self.manager._get_db() as conn: with self.manager._get_db() as conn:
conn.execute("DELETE FROM alerts WHERE id = ?", (alert_id,)) conn.execute("DELETE FROM alerts WHERE id = ?", (alert_id, ))
conn.execute("DELETE FROM resource_metrics WHERE tenant_id = ?", (self.tenant_id,)) conn.execute("DELETE FROM resource_metrics WHERE tenant_id = ?", (self.tenant_id, ))
conn.commit() conn.commit()
self.log("Cleaned up test data") self.log("Cleaned up test data")
except Exception as e: except Exception as e:
self.log(f"Alerts test failed: {e}", success=False) self.log(f"Alerts test failed: {e}", success = False)
def test_capacity_planning(self): def test_capacity_planning(self) -> None:
"""测试容量规划""" """测试容量规划"""
print("\n📊 Testing Capacity Planning...") print("\n📊 Testing Capacity Planning...")
try: try:
# 记录历史指标数据 # 记录历史指标数据
base_time = datetime.now() - timedelta(days=30) base_time = datetime.now() - timedelta(days = 30)
for i in range(30): for i in range(30):
timestamp = (base_time + timedelta(days=i)).isoformat() timestamp = (base_time + timedelta(days = i)).isoformat()
with self.manager._get_db() as conn: with self.manager._get_db() as conn:
conn.execute( conn.execute(
""" """
@@ -360,13 +360,13 @@ class TestOpsManager:
self.log("Recorded 30 days of historical metrics") self.log("Recorded 30 days of historical metrics")
# 创建容量规划 # 创建容量规划
prediction_date = (datetime.now() + timedelta(days=30)).strftime("%Y-%m-%d") prediction_date = (datetime.now() + timedelta(days = 30)).strftime("%Y-%m-%d")
plan = self.manager.create_capacity_plan( plan = self.manager.create_capacity_plan(
tenant_id=self.tenant_id, tenant_id = self.tenant_id,
resource_type=ResourceType.CPU, resource_type = ResourceType.CPU,
current_capacity=100.0, current_capacity = 100.0,
prediction_date=prediction_date, prediction_date = prediction_date,
confidence=0.85, confidence = 0.85,
) )
self.log(f"Created capacity plan: {plan.id}") self.log(f"Created capacity plan: {plan.id}")
@@ -381,32 +381,32 @@ class TestOpsManager:
# 清理 # 清理
with self.manager._get_db() as conn: with self.manager._get_db() as conn:
conn.execute("DELETE FROM capacity_plans WHERE tenant_id = ?", (self.tenant_id,)) conn.execute("DELETE FROM capacity_plans WHERE tenant_id = ?", (self.tenant_id, ))
conn.execute("DELETE FROM resource_metrics WHERE tenant_id = ?", (self.tenant_id,)) conn.execute("DELETE FROM resource_metrics WHERE tenant_id = ?", (self.tenant_id, ))
conn.commit() conn.commit()
self.log("Cleaned up capacity planning test data") self.log("Cleaned up capacity planning test data")
except Exception as e: except Exception as e:
self.log(f"Capacity planning test failed: {e}", success=False) self.log(f"Capacity planning test failed: {e}", success = False)
def test_auto_scaling(self): def test_auto_scaling(self) -> None:
"""测试自动扩缩容""" """测试自动扩缩容"""
print("\n⚖️ Testing Auto Scaling...") print("\n⚖️ Testing Auto Scaling...")
try: try:
# 创建自动扩缩容策略 # 创建自动扩缩容策略
policy = self.manager.create_auto_scaling_policy( policy = self.manager.create_auto_scaling_policy(
tenant_id=self.tenant_id, tenant_id = self.tenant_id,
name="API 服务自动扩缩容", name = "API 服务自动扩缩容",
resource_type=ResourceType.CPU, resource_type = ResourceType.CPU,
min_instances=2, min_instances = 2,
max_instances=10, max_instances = 10,
target_utilization=0.7, target_utilization = 0.7,
scale_up_threshold=0.8, scale_up_threshold = 0.8,
scale_down_threshold=0.3, scale_down_threshold = 0.3,
scale_up_step=2, scale_up_step = 2,
scale_down_step=1, scale_down_step = 1,
cooldown_period=300, cooldown_period = 300,
) )
self.log(f"Created auto scaling policy: {policy.name} (ID: {policy.id})") self.log(f"Created auto scaling policy: {policy.name} (ID: {policy.id})")
@@ -421,7 +421,7 @@ class TestOpsManager:
# 模拟扩缩容评估 # 模拟扩缩容评估
event = self.manager.evaluate_scaling_policy( event = self.manager.evaluate_scaling_policy(
policy_id=policy.id, current_instances=3, current_utilization=0.85 policy_id = policy.id, current_instances = 3, current_utilization = 0.85
) )
if event: if event:
@@ -437,46 +437,46 @@ class TestOpsManager:
# 清理 # 清理
with self.manager._get_db() as conn: with self.manager._get_db() as conn:
conn.execute("DELETE FROM scaling_events WHERE tenant_id = ?", (self.tenant_id,)) conn.execute("DELETE FROM scaling_events WHERE tenant_id = ?", (self.tenant_id, ))
conn.execute( conn.execute(
"DELETE FROM auto_scaling_policies WHERE tenant_id = ?", (self.tenant_id,) "DELETE FROM auto_scaling_policies WHERE tenant_id = ?", (self.tenant_id, )
) )
conn.commit() conn.commit()
self.log("Cleaned up auto scaling test data") self.log("Cleaned up auto scaling test data")
except Exception as e: except Exception as e:
self.log(f"Auto scaling test failed: {e}", success=False) self.log(f"Auto scaling test failed: {e}", success = False)
def test_health_checks(self): def test_health_checks(self) -> None:
"""测试健康检查""" """测试健康检查"""
print("\n💓 Testing Health Checks...") print("\n💓 Testing Health Checks...")
try: try:
# 创建 HTTP 健康检查 # 创建 HTTP 健康检查
check1 = self.manager.create_health_check( check1 = self.manager.create_health_check(
tenant_id=self.tenant_id, tenant_id = self.tenant_id,
name="API 服务健康检查", name = "API 服务健康检查",
target_type="service", target_type = "service",
target_id="api-service", target_id = "api-service",
check_type="http", check_type = "http",
check_config={"url": "https://api.insightflow.io/health", "expected_status": 200}, check_config = {"url": "https://api.insightflow.io/health", "expected_status": 200},
interval=60, interval = 60,
timeout=10, timeout = 10,
retry_count=3, retry_count = 3,
) )
self.log(f"Created HTTP health check: {check1.name} (ID: {check1.id})") self.log(f"Created HTTP health check: {check1.name} (ID: {check1.id})")
# 创建 TCP 健康检查 # 创建 TCP 健康检查
check2 = self.manager.create_health_check( check2 = self.manager.create_health_check(
tenant_id=self.tenant_id, tenant_id = self.tenant_id,
name="数据库健康检查", name = "数据库健康检查",
target_type="database", target_type = "database",
target_id="postgres-001", target_id = "postgres-001",
check_type="tcp", check_type = "tcp",
check_config={"host": "db.insightflow.io", "port": 5432}, check_config = {"host": "db.insightflow.io", "port": 5432},
interval=30, interval = 30,
timeout=5, timeout = 5,
retry_count=2, retry_count = 2,
) )
self.log(f"Created TCP health check: {check2.name} (ID: {check2.id})") self.log(f"Created TCP health check: {check2.name} (ID: {check2.id})")
@@ -486,7 +486,7 @@ class TestOpsManager:
self.log(f"Listed {len(checks)} health checks") self.log(f"Listed {len(checks)} health checks")
# 执行健康检查(异步) # 执行健康检查(异步)
async def run_health_check(): async def run_health_check() -> None:
result = await self.manager.execute_health_check(check1.id) result = await self.manager.execute_health_check(check1.id)
return result return result
@@ -495,28 +495,28 @@ class TestOpsManager:
# 清理 # 清理
with self.manager._get_db() as conn: with self.manager._get_db() as conn:
conn.execute("DELETE FROM health_checks WHERE tenant_id = ?", (self.tenant_id,)) conn.execute("DELETE FROM health_checks WHERE tenant_id = ?", (self.tenant_id, ))
conn.commit() conn.commit()
self.log("Cleaned up health check test data") self.log("Cleaned up health check test data")
except Exception as e: except Exception as e:
self.log(f"Health checks test failed: {e}", success=False) self.log(f"Health checks test failed: {e}", success = False)
def test_failover(self): def test_failover(self) -> None:
"""测试故障转移""" """测试故障转移"""
print("\n🔄 Testing Failover...") print("\n🔄 Testing Failover...")
try: try:
# 创建故障转移配置 # 创建故障转移配置
config = self.manager.create_failover_config( config = self.manager.create_failover_config(
tenant_id=self.tenant_id, tenant_id = self.tenant_id,
name="主备数据中心故障转移", name = "主备数据中心故障转移",
primary_region="cn-north-1", primary_region = "cn-north-1",
secondary_regions=["cn-south-1", "cn-east-1"], secondary_regions = ["cn-south-1", "cn-east-1"],
failover_trigger="health_check_failed", failover_trigger = "health_check_failed",
auto_failover=False, auto_failover = False,
failover_timeout=300, failover_timeout = 300,
health_check_id=None, health_check_id = None,
) )
self.log(f"Created failover config: {config.name} (ID: {config.id})") self.log(f"Created failover config: {config.name} (ID: {config.id})")
@@ -530,7 +530,7 @@ class TestOpsManager:
# 发起故障转移 # 发起故障转移
event = self.manager.initiate_failover( event = self.manager.initiate_failover(
config_id=config.id, reason="Primary region health check failed" config_id = config.id, reason = "Primary region health check failed"
) )
if event: if event:
@@ -550,31 +550,31 @@ class TestOpsManager:
# 清理 # 清理
with self.manager._get_db() as conn: with self.manager._get_db() as conn:
conn.execute("DELETE FROM failover_events WHERE tenant_id = ?", (self.tenant_id,)) conn.execute("DELETE FROM failover_events WHERE tenant_id = ?", (self.tenant_id, ))
conn.execute("DELETE FROM failover_configs WHERE tenant_id = ?", (self.tenant_id,)) conn.execute("DELETE FROM failover_configs WHERE tenant_id = ?", (self.tenant_id, ))
conn.commit() conn.commit()
self.log("Cleaned up failover test data") self.log("Cleaned up failover test data")
except Exception as e: except Exception as e:
self.log(f"Failover test failed: {e}", success=False) self.log(f"Failover test failed: {e}", success = False)
def test_backup(self): def test_backup(self) -> None:
"""测试备份与恢复""" """测试备份与恢复"""
print("\n💾 Testing Backup & Recovery...") print("\n💾 Testing Backup & Recovery...")
try: try:
# 创建备份任务 # 创建备份任务
job = self.manager.create_backup_job( job = self.manager.create_backup_job(
tenant_id=self.tenant_id, tenant_id = self.tenant_id,
name="每日数据库备份", name = "每日数据库备份",
backup_type="full", backup_type = "full",
target_type="database", target_type = "database",
target_id="postgres-main", target_id = "postgres-main",
schedule="0 2 * * *", # 每天凌晨2点 schedule = "0 2 * * *", # 每天凌晨2点
retention_days=30, retention_days = 30,
encryption_enabled=True, encryption_enabled = True,
compression_enabled=True, compression_enabled = True,
storage_location="s3://insightflow-backups/", storage_location = "s3://insightflow-backups/",
) )
self.log(f"Created backup job: {job.name} (ID: {job.id})") self.log(f"Created backup job: {job.name} (ID: {job.id})")
@@ -604,15 +604,15 @@ class TestOpsManager:
# 清理 # 清理
with self.manager._get_db() as conn: with self.manager._get_db() as conn:
conn.execute("DELETE FROM backup_records WHERE tenant_id = ?", (self.tenant_id,)) conn.execute("DELETE FROM backup_records WHERE tenant_id = ?", (self.tenant_id, ))
conn.execute("DELETE FROM backup_jobs WHERE tenant_id = ?", (self.tenant_id,)) conn.execute("DELETE FROM backup_jobs WHERE tenant_id = ?", (self.tenant_id, ))
conn.commit() conn.commit()
self.log("Cleaned up backup test data") self.log("Cleaned up backup test data")
except Exception as e: except Exception as e:
self.log(f"Backup test failed: {e}", success=False) self.log(f"Backup test failed: {e}", success = False)
def test_cost_optimization(self): def test_cost_optimization(self) -> None:
"""测试成本优化""" """测试成本优化"""
print("\n💰 Testing Cost Optimization...") print("\n💰 Testing Cost Optimization...")
@@ -622,15 +622,15 @@ class TestOpsManager:
for i in range(5): for i in range(5):
self.manager.record_resource_utilization( self.manager.record_resource_utilization(
tenant_id=self.tenant_id, tenant_id = self.tenant_id,
resource_type=ResourceType.CPU, resource_type = ResourceType.CPU,
resource_id=f"server-{i:03d}", resource_id = f"server-{i:03d}",
utilization_rate=0.05 + random.random() * 0.1, # 低利用率 utilization_rate = 0.05 + random.random() * 0.1, # 低利用率
peak_utilization=0.15, peak_utilization = 0.15,
avg_utilization=0.08, avg_utilization = 0.08,
idle_time_percent=0.85, idle_time_percent = 0.85,
report_date=report_date, report_date = report_date,
recommendations=["Consider downsizing this resource"], recommendations = ["Consider downsizing this resource"],
) )
self.log("Recorded 5 resource utilization records") self.log("Recorded 5 resource utilization records")
@@ -638,7 +638,7 @@ class TestOpsManager:
# 生成成本报告 # 生成成本报告
now = datetime.now() now = datetime.now()
report = self.manager.generate_cost_report( report = self.manager.generate_cost_report(
tenant_id=self.tenant_id, year=now.year, month=now.month tenant_id = self.tenant_id, year = now.year, month = now.month
) )
self.log(f"Generated cost report: {report.id}") self.log(f"Generated cost report: {report.id}")
@@ -687,24 +687,24 @@ class TestOpsManager:
with self.manager._get_db() as conn: with self.manager._get_db() as conn:
conn.execute( conn.execute(
"DELETE FROM cost_optimization_suggestions WHERE tenant_id = ?", "DELETE FROM cost_optimization_suggestions WHERE tenant_id = ?",
(self.tenant_id,), (self.tenant_id, ),
) )
conn.execute("DELETE FROM idle_resources WHERE tenant_id = ?", (self.tenant_id,)) conn.execute("DELETE FROM idle_resources WHERE tenant_id = ?", (self.tenant_id, ))
conn.execute( conn.execute(
"DELETE FROM resource_utilizations WHERE tenant_id = ?", (self.tenant_id,) "DELETE FROM resource_utilizations WHERE tenant_id = ?", (self.tenant_id, )
) )
conn.execute("DELETE FROM cost_reports WHERE tenant_id = ?", (self.tenant_id,)) conn.execute("DELETE FROM cost_reports WHERE tenant_id = ?", (self.tenant_id, ))
conn.commit() conn.commit()
self.log("Cleaned up cost optimization test data") self.log("Cleaned up cost optimization test data")
except Exception as e: except Exception as e:
self.log(f"Cost optimization test failed: {e}", success=False) self.log(f"Cost optimization test failed: {e}", success = False)
def print_summary(self): def print_summary(self) -> None:
"""打印测试总结""" """打印测试总结"""
print("\n" + "=" * 60) print("\n" + " = " * 60)
print("Test Summary") print("Test Summary")
print("=" * 60) print(" = " * 60)
total = len(self.test_results) total = len(self.test_results)
passed = sum(1 for _, success in self.test_results if success) passed = sum(1 for _, success in self.test_results if success)
@@ -720,10 +720,10 @@ class TestOpsManager:
if not success: if not success:
print(f"{message}") print(f"{message}")
print("=" * 60) print(" = " * 60)
def main(): def main() -> None:
"""主函数""" """主函数"""
test = TestOpsManager() test = TestOpsManager()
test.run_all_tests() test.run_all_tests()

View File

@@ -10,7 +10,7 @@ from typing import Any
class TingwuClient: class TingwuClient:
def __init__(self): def __init__(self) -> None:
self.access_key = os.getenv("ALI_ACCESS_KEY", "") self.access_key = os.getenv("ALI_ACCESS_KEY", "")
self.secret_key = os.getenv("ALI_SECRET_KEY", "") self.secret_key = os.getenv("ALI_SECRET_KEY", "")
self.endpoint = "https://tingwu.cn-beijing.aliyuncs.com" self.endpoint = "https://tingwu.cn-beijing.aliyuncs.com"
@@ -31,7 +31,7 @@ class TingwuClient:
"x-acs-action": "CreateTask", "x-acs-action": "CreateTask",
"x-acs-version": "2023-09-30", "x-acs-version": "2023-09-30",
"x-acs-date": timestamp, "x-acs-date": timestamp,
"Authorization": f"ACS3-HMAC-SHA256 Credential={self.access_key}/acs/tingwu/cn-beijing", "Authorization": f"ACS3-HMAC-SHA256 Credential = {self.access_key}/acs/tingwu/cn-beijing",
} }
def create_task(self, audio_url: str, language: str = "zh") -> str: def create_task(self, audio_url: str, language: str = "zh") -> str:
@@ -43,17 +43,17 @@ class TingwuClient:
from alibabacloud_tingwu20230930.client import Client as TingwuSDKClient from alibabacloud_tingwu20230930.client import Client as TingwuSDKClient
config = open_api_models.Config( config = open_api_models.Config(
access_key_id=self.access_key, access_key_secret=self.secret_key access_key_id = self.access_key, access_key_secret = self.secret_key
) )
config.endpoint = "tingwu.cn-beijing.aliyuncs.com" config.endpoint = "tingwu.cn-beijing.aliyuncs.com"
client = TingwuSDKClient(config) client = TingwuSDKClient(config)
request = tingwu_models.CreateTaskRequest( request = tingwu_models.CreateTaskRequest(
type="offline", type = "offline",
input=tingwu_models.Input(source="OSS", file_url=audio_url), input = tingwu_models.Input(source = "OSS", file_url = audio_url),
parameters=tingwu_models.Parameters( parameters = tingwu_models.Parameters(
transcription=tingwu_models.Transcription( transcription = tingwu_models.Transcription(
diarization_enabled=True, sentence_max_length=20 diarization_enabled = True, sentence_max_length = 20
) )
), ),
) )
@@ -78,12 +78,9 @@ class TingwuClient:
"""获取任务结果""" """获取任务结果"""
try: try:
# 导入移到文件顶部会导致循环导入,保持在这里 # 导入移到文件顶部会导致循环导入,保持在这里
from alibabacloud_tea_openapi import models as open_api_models
from alibabacloud_tingwu20230930 import models as tingwu_models
from alibabacloud_tingwu20230930.client import Client as TingwuSDKClient
config = open_api_models.Config( config = open_api_models.Config(
access_key_id=self.access_key, access_key_secret=self.secret_key access_key_id = self.access_key, access_key_secret = self.secret_key
) )
config.endpoint = "tingwu.cn-beijing.aliyuncs.com" config.endpoint = "tingwu.cn-beijing.aliyuncs.com"
client = TingwuSDKClient(config) client = TingwuSDKClient(config)

View File

@@ -37,7 +37,7 @@ DEFAULT_RETRY_COUNT = 3 # 默认重试次数
DEFAULT_RETRY_DELAY = 5 # 默认重试延迟(秒) DEFAULT_RETRY_DELAY = 5 # 默认重试延迟(秒)
# Configure logging # Configure logging
logging.basicConfig(level=logging.INFO) logging.basicConfig(level = logging.INFO)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -87,16 +87,16 @@ class WorkflowTask:
workflow_id: str workflow_id: str
name: str name: str
task_type: str # analyze, align, discover_relations, notify, custom task_type: str # analyze, align, discover_relations, notify, custom
config: dict = field(default_factory=dict) config: dict = field(default_factory = dict)
order: int = 0 order: int = 0
depends_on: list[str] = field(default_factory=list) depends_on: list[str] = field(default_factory = list)
timeout_seconds: int = 300 timeout_seconds: int = 300
retry_count: int = 3 retry_count: int = 3
retry_delay: int = 5 retry_delay: int = 5
created_at: str = "" created_at: str = ""
updated_at: str = "" updated_at: str = ""
def __post_init__(self): def __post_init__(self) -> None:
if not self.created_at: if not self.created_at:
self.created_at = datetime.now().isoformat() self.created_at = datetime.now().isoformat()
if not self.updated_at: if not self.updated_at:
@@ -112,7 +112,7 @@ class WebhookConfig:
webhook_type: str # feishu, dingtalk, slack, custom webhook_type: str # feishu, dingtalk, slack, custom
url: str url: str
secret: str = "" # 用于签名验证 secret: str = "" # 用于签名验证
headers: dict = field(default_factory=dict) headers: dict = field(default_factory = dict)
template: str = "" # 消息模板 template: str = "" # 消息模板
is_active: bool = True is_active: bool = True
created_at: str = "" created_at: str = ""
@@ -121,7 +121,7 @@ class WebhookConfig:
success_count: int = 0 success_count: int = 0
fail_count: int = 0 fail_count: int = 0
def __post_init__(self): def __post_init__(self) -> None:
if not self.created_at: if not self.created_at:
self.created_at = datetime.now().isoformat() self.created_at = datetime.now().isoformat()
if not self.updated_at: if not self.updated_at:
@@ -140,8 +140,8 @@ class Workflow:
status: str = "active" status: str = "active"
schedule: str | None = None # cron expression or interval schedule: str | None = None # cron expression or interval
schedule_type: str = "manual" # manual, cron, interval schedule_type: str = "manual" # manual, cron, interval
config: dict = field(default_factory=dict) config: dict = field(default_factory = dict)
webhook_ids: list[str] = field(default_factory=list) webhook_ids: list[str] = field(default_factory = list)
is_active: bool = True is_active: bool = True
created_at: str = "" created_at: str = ""
updated_at: str = "" updated_at: str = ""
@@ -151,7 +151,7 @@ class Workflow:
success_count: int = 0 success_count: int = 0
fail_count: int = 0 fail_count: int = 0
def __post_init__(self): def __post_init__(self) -> None:
if not self.created_at: if not self.created_at:
self.created_at = datetime.now().isoformat() self.created_at = datetime.now().isoformat()
if not self.updated_at: if not self.updated_at:
@@ -169,12 +169,12 @@ class WorkflowLog:
start_time: str | None = None start_time: str | None = None
end_time: str | None = None end_time: str | None = None
duration_ms: int = 0 duration_ms: int = 0
input_data: dict = field(default_factory=dict) input_data: dict = field(default_factory = dict)
output_data: dict = field(default_factory=dict) output_data: dict = field(default_factory = dict)
error_message: str = "" error_message: str = ""
created_at: str = "" created_at: str = ""
def __post_init__(self): def __post_init__(self) -> None:
if not self.created_at: if not self.created_at:
self.created_at = datetime.now().isoformat() self.created_at = datetime.now().isoformat()
@@ -182,8 +182,8 @@ class WorkflowLog:
class WebhookNotifier: class WebhookNotifier:
"""Webhook 通知器 - 支持飞书、钉钉、Slack""" """Webhook 通知器 - 支持飞书、钉钉、Slack"""
def __init__(self): def __init__(self) -> None:
self.http_client = httpx.AsyncClient(timeout=30.0) self.http_client = httpx.AsyncClient(timeout = 30.0)
async def send(self, config: WebhookConfig, message: dict) -> bool: async def send(self, config: WebhookConfig, message: dict) -> bool:
"""发送 Webhook 通知""" """发送 Webhook 通知"""
@@ -210,7 +210,7 @@ class WebhookNotifier:
# 签名计算 # 签名计算
if config.secret: if config.secret:
string_to_sign = f"{timestamp}\n{config.secret}" string_to_sign = f"{timestamp}\n{config.secret}"
hmac_code = hmac.new(string_to_sign.encode("utf-8"), digestmod=hashlib.sha256).digest() hmac_code = hmac.new(string_to_sign.encode("utf-8"), digestmod = hashlib.sha256).digest()
sign = base64.b64encode(hmac_code).decode("utf-8") sign = base64.b64encode(hmac_code).decode("utf-8")
else: else:
sign = "" sign = ""
@@ -250,7 +250,7 @@ class WebhookNotifier:
headers = {"Content-Type": "application/json", **config.headers} headers = {"Content-Type": "application/json", **config.headers}
response = await self.http_client.post(config.url, json=payload, headers=headers) response = await self.http_client.post(config.url, json = payload, headers = headers)
response.raise_for_status() response.raise_for_status()
result = response.json() result = response.json()
@@ -265,10 +265,10 @@ class WebhookNotifier:
secret_enc = config.secret.encode("utf-8") secret_enc = config.secret.encode("utf-8")
string_to_sign = f"{timestamp}\n{config.secret}" string_to_sign = f"{timestamp}\n{config.secret}"
hmac_code = hmac.new( hmac_code = hmac.new(
secret_enc, string_to_sign.encode("utf-8"), digestmod=hashlib.sha256 secret_enc, string_to_sign.encode("utf-8"), digestmod = hashlib.sha256
).digest() ).digest()
sign = urllib.parse.quote_plus(base64.b64encode(hmac_code)) sign = urllib.parse.quote_plus(base64.b64encode(hmac_code))
url = f"{config.url}&timestamp={timestamp}&sign={sign}" url = f"{config.url}&timestamp = {timestamp}&sign = {sign}"
else: else:
url = config.url url = config.url
@@ -295,7 +295,7 @@ class WebhookNotifier:
headers = {"Content-Type": "application/json", **config.headers} headers = {"Content-Type": "application/json", **config.headers}
response = await self.http_client.post(url, json=payload, headers=headers) response = await self.http_client.post(url, json = payload, headers = headers)
response.raise_for_status() response.raise_for_status()
result = response.json() result = response.json()
@@ -316,7 +316,7 @@ class WebhookNotifier:
headers = {"Content-Type": "application/json", **config.headers} headers = {"Content-Type": "application/json", **config.headers}
response = await self.http_client.post(config.url, json=payload, headers=headers) response = await self.http_client.post(config.url, json = payload, headers = headers)
response.raise_for_status() response.raise_for_status()
return response.text == "ok" return response.text == "ok"
@@ -325,12 +325,12 @@ class WebhookNotifier:
"""发送自定义 Webhook 通知""" """发送自定义 Webhook 通知"""
headers = {"Content-Type": "application/json", **config.headers} headers = {"Content-Type": "application/json", **config.headers}
response = await self.http_client.post(config.url, json=message, headers=headers) response = await self.http_client.post(config.url, json = message, headers = headers)
response.raise_for_status() response.raise_for_status()
return True return True
async def close(self): async def close(self) -> None:
"""关闭 HTTP 客户端""" """关闭 HTTP 客户端"""
await self.http_client.aclose() await self.http_client.aclose()
@@ -343,7 +343,7 @@ class WorkflowManager:
DEFAULT_RETRY_COUNT: int = 3 DEFAULT_RETRY_COUNT: int = 3
DEFAULT_RETRY_DELAY: int = 5 DEFAULT_RETRY_DELAY: int = 5
def __init__(self, db_manager=None): def __init__(self, db_manager = None) -> None:
self.db = db_manager self.db = db_manager
self.scheduler = AsyncIOScheduler() self.scheduler = AsyncIOScheduler()
self.notifier = WebhookNotifier() self.notifier = WebhookNotifier()
@@ -381,13 +381,13 @@ class WorkflowManager:
def stop(self) -> None: def stop(self) -> None:
"""停止工作流管理器""" """停止工作流管理器"""
if self.scheduler.running: if self.scheduler.running:
self.scheduler.shutdown(wait=True) self.scheduler.shutdown(wait = True)
logger.info("Workflow scheduler stopped") logger.info("Workflow scheduler stopped")
async def _load_and_schedule_workflows(self): async def _load_and_schedule_workflows(self) -> None:
"""从数据库加载并调度所有活跃工作流""" """从数据库加载并调度所有活跃工作流"""
try: try:
workflows = self.list_workflows(status="active") workflows = self.list_workflows(status = "active")
for workflow in workflows: for workflow in workflows:
if workflow.schedule and workflow.is_active: if workflow.schedule and workflow.is_active:
self._schedule_workflow(workflow) self._schedule_workflow(workflow)
@@ -408,25 +408,25 @@ class WorkflowManager:
elif workflow.schedule_type == "interval": elif workflow.schedule_type == "interval":
# 间隔调度 # 间隔调度
interval_minutes = int(workflow.schedule) interval_minutes = int(workflow.schedule)
trigger = IntervalTrigger(minutes=interval_minutes) trigger = IntervalTrigger(minutes = interval_minutes)
else: else:
return return
self.scheduler.add_job( self.scheduler.add_job(
func=self._execute_workflow_job, func = self._execute_workflow_job,
trigger=trigger, trigger = trigger,
id=job_id, id = job_id,
args=[workflow.id], args = [workflow.id],
replace_existing=True, replace_existing = True,
max_instances=1, max_instances = 1,
coalesce=True, coalesce = True,
) )
logger.info( logger.info(
f"Scheduled workflow {workflow.id} ({workflow.name}) with {workflow.schedule_type}" f"Scheduled workflow {workflow.id} ({workflow.name}) with {workflow.schedule_type}"
) )
async def _execute_workflow_job(self, workflow_id: str): async def _execute_workflow_job(self, workflow_id: str) -> None:
"""调度器调用的工作流执行函数""" """调度器调用的工作流执行函数"""
try: try:
await self.execute_workflow(workflow_id) await self.execute_workflow(workflow_id)
@@ -488,7 +488,7 @@ class WorkflowManager:
"""获取工作流""" """获取工作流"""
conn = self.db.get_conn() conn = self.db.get_conn()
try: try:
row = conn.execute("SELECT * FROM workflows WHERE id = ?", (workflow_id,)).fetchone() row = conn.execute("SELECT * FROM workflows WHERE id = ?", (workflow_id, )).fetchone()
if not row: if not row:
return None return None
@@ -516,7 +516,7 @@ class WorkflowManager:
conditions.append("workflow_type = ?") conditions.append("workflow_type = ?")
params.append(workflow_type) params.append(workflow_type)
where_clause = " AND ".join(conditions) if conditions else "1=1" where_clause = " AND ".join(conditions) if conditions else "1 = 1"
rows = conn.execute( rows = conn.execute(
f"SELECT * FROM workflows WHERE {where_clause} ORDER BY created_at DESC", params f"SELECT * FROM workflows WHERE {where_clause} ORDER BY created_at DESC", params
@@ -585,10 +585,10 @@ class WorkflowManager:
self.scheduler.remove_job(job_id) self.scheduler.remove_job(job_id)
# 删除相关任务 # 删除相关任务
conn.execute("DELETE FROM workflow_tasks WHERE workflow_id = ?", (workflow_id,)) conn.execute("DELETE FROM workflow_tasks WHERE workflow_id = ?", (workflow_id, ))
# 删除工作流 # 删除工作流
conn.execute("DELETE FROM workflows WHERE id = ?", (workflow_id,)) conn.execute("DELETE FROM workflows WHERE id = ?", (workflow_id, ))
conn.commit() conn.commit()
return True return True
@@ -598,24 +598,24 @@ class WorkflowManager:
def _row_to_workflow(self, row) -> Workflow: def _row_to_workflow(self, row) -> Workflow:
"""将数据库行转换为 Workflow 对象""" """将数据库行转换为 Workflow 对象"""
return Workflow( return Workflow(
id=row["id"], id = row["id"],
name=row["name"], name = row["name"],
description=row["description"] or "", description = row["description"] or "",
workflow_type=row["workflow_type"], workflow_type = row["workflow_type"],
project_id=row["project_id"], project_id = row["project_id"],
status=row["status"], status = row["status"],
schedule=row["schedule"], schedule = row["schedule"],
schedule_type=row["schedule_type"], schedule_type = row["schedule_type"],
config=json.loads(row["config"]) if row["config"] else {}, config = json.loads(row["config"]) if row["config"] else {},
webhook_ids=json.loads(row["webhook_ids"]) if row["webhook_ids"] else [], webhook_ids = json.loads(row["webhook_ids"]) if row["webhook_ids"] else [],
is_active=bool(row["is_active"]), is_active = bool(row["is_active"]),
created_at=row["created_at"], created_at = row["created_at"],
updated_at=row["updated_at"], updated_at = row["updated_at"],
last_run_at=row["last_run_at"], last_run_at = row["last_run_at"],
next_run_at=row["next_run_at"], next_run_at = row["next_run_at"],
run_count=row["run_count"] or 0, run_count = row["run_count"] or 0,
success_count=row["success_count"] or 0, success_count = row["success_count"] or 0,
fail_count=row["fail_count"] or 0, fail_count = row["fail_count"] or 0,
) )
# ==================== Workflow Task CRUD ==================== # ==================== Workflow Task CRUD ====================
@@ -654,7 +654,7 @@ class WorkflowManager:
"""获取任务""" """获取任务"""
conn = self.db.get_conn() conn = self.db.get_conn()
try: try:
row = conn.execute("SELECT * FROM workflow_tasks WHERE id = ?", (task_id,)).fetchone() row = conn.execute("SELECT * FROM workflow_tasks WHERE id = ?", (task_id, )).fetchone()
if not row: if not row:
return None return None
@@ -669,7 +669,7 @@ class WorkflowManager:
try: try:
rows = conn.execute( rows = conn.execute(
"SELECT * FROM workflow_tasks WHERE workflow_id = ? ORDER BY task_order", "SELECT * FROM workflow_tasks WHERE workflow_id = ? ORDER BY task_order",
(workflow_id,), (workflow_id, ),
).fetchall() ).fetchall()
return [self._row_to_task(row) for row in rows] return [self._row_to_task(row) for row in rows]
@@ -720,7 +720,7 @@ class WorkflowManager:
"""删除任务""" """删除任务"""
conn = self.db.get_conn() conn = self.db.get_conn()
try: try:
conn.execute("DELETE FROM workflow_tasks WHERE id = ?", (task_id,)) conn.execute("DELETE FROM workflow_tasks WHERE id = ?", (task_id, ))
conn.commit() conn.commit()
return True return True
finally: finally:
@@ -729,18 +729,18 @@ class WorkflowManager:
def _row_to_task(self, row) -> WorkflowTask: def _row_to_task(self, row) -> WorkflowTask:
"""将数据库行转换为 WorkflowTask 对象""" """将数据库行转换为 WorkflowTask 对象"""
return WorkflowTask( return WorkflowTask(
id=row["id"], id = row["id"],
workflow_id=row["workflow_id"], workflow_id = row["workflow_id"],
name=row["name"], name = row["name"],
task_type=row["task_type"], task_type = row["task_type"],
config=json.loads(row["config"]) if row["config"] else {}, config = json.loads(row["config"]) if row["config"] else {},
order=row["task_order"] or 0, order = row["task_order"] or 0,
depends_on=json.loads(row["depends_on"]) if row["depends_on"] else [], depends_on = json.loads(row["depends_on"]) if row["depends_on"] else [],
timeout_seconds=row["timeout_seconds"] or 300, timeout_seconds = row["timeout_seconds"] or 300,
retry_count=row["retry_count"] or 3, retry_count = row["retry_count"] or 3,
retry_delay=row["retry_delay"] or 5, retry_delay = row["retry_delay"] or 5,
created_at=row["created_at"], created_at = row["created_at"],
updated_at=row["updated_at"], updated_at = row["updated_at"],
) )
# ==================== Webhook Config CRUD ==================== # ==================== Webhook Config CRUD ====================
@@ -781,7 +781,7 @@ class WorkflowManager:
conn = self.db.get_conn() conn = self.db.get_conn()
try: try:
row = conn.execute( row = conn.execute(
"SELECT * FROM webhook_configs WHERE id = ?", (webhook_id,) "SELECT * FROM webhook_configs WHERE id = ?", (webhook_id, )
).fetchone() ).fetchone()
if not row: if not row:
@@ -844,7 +844,7 @@ class WorkflowManager:
"""删除 Webhook 配置""" """删除 Webhook 配置"""
conn = self.db.get_conn() conn = self.db.get_conn()
try: try:
conn.execute("DELETE FROM webhook_configs WHERE id = ?", (webhook_id,)) conn.execute("DELETE FROM webhook_configs WHERE id = ?", (webhook_id, ))
conn.commit() conn.commit()
return True return True
finally: finally:
@@ -875,19 +875,19 @@ class WorkflowManager:
def _row_to_webhook(self, row) -> WebhookConfig: def _row_to_webhook(self, row) -> WebhookConfig:
"""将数据库行转换为 WebhookConfig 对象""" """将数据库行转换为 WebhookConfig 对象"""
return WebhookConfig( return WebhookConfig(
id=row["id"], id = row["id"],
name=row["name"], name = row["name"],
webhook_type=row["webhook_type"], webhook_type = row["webhook_type"],
url=row["url"], url = row["url"],
secret=row["secret"] or "", secret = row["secret"] or "",
headers=json.loads(row["headers"]) if row["headers"] else {}, headers = json.loads(row["headers"]) if row["headers"] else {},
template=row["template"] or "", template = row["template"] or "",
is_active=bool(row["is_active"]), is_active = bool(row["is_active"]),
created_at=row["created_at"], created_at = row["created_at"],
updated_at=row["updated_at"], updated_at = row["updated_at"],
last_used_at=row["last_used_at"], last_used_at = row["last_used_at"],
success_count=row["success_count"] or 0, success_count = row["success_count"] or 0,
fail_count=row["fail_count"] or 0, fail_count = row["fail_count"] or 0,
) )
# ==================== Workflow Log ==================== # ==================== Workflow Log ====================
@@ -952,7 +952,7 @@ class WorkflowManager:
"""获取日志""" """获取日志"""
conn = self.db.get_conn() conn = self.db.get_conn()
try: try:
row = conn.execute("SELECT * FROM workflow_logs WHERE id = ?", (log_id,)).fetchone() row = conn.execute("SELECT * FROM workflow_logs WHERE id = ?", (log_id, )).fetchone()
if not row: if not row:
return None return None
@@ -985,7 +985,7 @@ class WorkflowManager:
conditions.append("status = ?") conditions.append("status = ?")
params.append(status) params.append(status)
where_clause = " AND ".join(conditions) if conditions else "1=1" where_clause = " AND ".join(conditions) if conditions else "1 = 1"
rows = conn.execute( rows = conn.execute(
f"""SELECT * FROM workflow_logs f"""SELECT * FROM workflow_logs
@@ -1003,7 +1003,7 @@ class WorkflowManager:
"""获取工作流统计""" """获取工作流统计"""
conn = self.db.get_conn() conn = self.db.get_conn()
try: try:
since = (datetime.now() - timedelta(days=days)).isoformat() since = (datetime.now() - timedelta(days = days)).isoformat()
# 总执行次数 # 总执行次数
total = conn.execute( total = conn.execute(
@@ -1060,17 +1060,17 @@ class WorkflowManager:
def _row_to_log(self, row) -> WorkflowLog: def _row_to_log(self, row) -> WorkflowLog:
"""将数据库行转换为 WorkflowLog 对象""" """将数据库行转换为 WorkflowLog 对象"""
return WorkflowLog( return WorkflowLog(
id=row["id"], id = row["id"],
workflow_id=row["workflow_id"], workflow_id = row["workflow_id"],
task_id=row["task_id"], task_id = row["task_id"],
status=row["status"], status = row["status"],
start_time=row["start_time"], start_time = row["start_time"],
end_time=row["end_time"], end_time = row["end_time"],
duration_ms=row["duration_ms"] or 0, duration_ms = row["duration_ms"] or 0,
input_data=json.loads(row["input_data"]) if row["input_data"] else {}, input_data = json.loads(row["input_data"]) if row["input_data"] else {},
output_data=json.loads(row["output_data"]) if row["output_data"] else {}, output_data = json.loads(row["output_data"]) if row["output_data"] else {},
error_message=row["error_message"] or "", error_message = row["error_message"] or "",
created_at=row["created_at"], created_at = row["created_at"],
) )
# ==================== Workflow Execution ==================== # ==================== Workflow Execution ====================
@@ -1086,15 +1086,15 @@ class WorkflowManager:
# 更新最后运行时间 # 更新最后运行时间
now = datetime.now().isoformat() now = datetime.now().isoformat()
self.update_workflow(workflow_id, last_run_at=now, run_count=workflow.run_count + 1) self.update_workflow(workflow_id, last_run_at = now, run_count = workflow.run_count + 1)
# 创建工作流执行日志 # 创建工作流执行日志
log = WorkflowLog( log = WorkflowLog(
id=str(uuid.uuid4())[:UUID_LENGTH], id = str(uuid.uuid4())[:UUID_LENGTH],
workflow_id=workflow_id, workflow_id = workflow_id,
status=TaskStatus.RUNNING.value, status = TaskStatus.RUNNING.value,
start_time=now, start_time = now,
input_data=input_data or {}, input_data = input_data or {},
) )
self.create_log(log) self.create_log(log)
@@ -1113,21 +1113,21 @@ class WorkflowManager:
results = await self._execute_tasks_with_deps(tasks, input_data, log.id) results = await self._execute_tasks_with_deps(tasks, input_data, log.id)
# 发送通知 # 发送通知
await self._send_workflow_notification(workflow, results, success=True) await self._send_workflow_notification(workflow, results, success = True)
# 更新日志为成功 # 更新日志为成功
end_time = datetime.now() end_time = datetime.now()
duration = int((end_time - start_time).total_seconds() * 1000) duration = int((end_time - start_time).total_seconds() * 1000)
self.update_log( self.update_log(
log.id, log.id,
status=TaskStatus.SUCCESS.value, status = TaskStatus.SUCCESS.value,
end_time=end_time.isoformat(), end_time = end_time.isoformat(),
duration_ms=duration, duration_ms = duration,
output_data=results, output_data = results,
) )
# 更新成功计数 # 更新成功计数
self.update_workflow(workflow_id, success_count=workflow.success_count + 1) self.update_workflow(workflow_id, success_count = workflow.success_count + 1)
return { return {
"success": True, "success": True,
@@ -1145,17 +1145,17 @@ class WorkflowManager:
duration = int((end_time - start_time).total_seconds() * 1000) duration = int((end_time - start_time).total_seconds() * 1000)
self.update_log( self.update_log(
log.id, log.id,
status=TaskStatus.FAILED.value, status = TaskStatus.FAILED.value,
end_time=end_time.isoformat(), end_time = end_time.isoformat(),
duration_ms=duration, duration_ms = duration,
error_message=str(e), error_message = str(e),
) )
# 更新失败计数 # 更新失败计数
self.update_workflow(workflow_id, fail_count=workflow.fail_count + 1) self.update_workflow(workflow_id, fail_count = workflow.fail_count + 1)
# 发送失败通知 # 发送失败通知
await self._send_workflow_notification(workflow, {"error": str(e)}, success=False) await self._send_workflow_notification(workflow, {"error": str(e)}, success = False)
raise raise
@@ -1185,7 +1185,7 @@ class WorkflowManager:
task_input = {**input_data, **results} task_input = {**input_data, **results}
task_coros.append(self._execute_single_task(task, task_input, log_id)) task_coros.append(self._execute_single_task(task, task_input, log_id))
task_results = await asyncio.gather(*task_coros, return_exceptions=True) task_results = await asyncio.gather(*task_coros, return_exceptions = True)
for task, result in zip(ready_tasks, task_results): for task, result in zip(ready_tasks, task_results):
if isinstance(result, Exception): if isinstance(result, Exception):
@@ -1217,25 +1217,25 @@ class WorkflowManager:
# 创建任务日志 # 创建任务日志
task_log = WorkflowLog( task_log = WorkflowLog(
id=str(uuid.uuid4())[:UUID_LENGTH], id = str(uuid.uuid4())[:UUID_LENGTH],
workflow_id=task.workflow_id, workflow_id = task.workflow_id,
task_id=task.id, task_id = task.id,
status=TaskStatus.RUNNING.value, status = TaskStatus.RUNNING.value,
start_time=datetime.now().isoformat(), start_time = datetime.now().isoformat(),
input_data=input_data, input_data = input_data,
) )
self.create_log(task_log) self.create_log(task_log)
try: try:
# 设置超时 # 设置超时
result = await asyncio.wait_for(handler(task, input_data), timeout=task.timeout_seconds) result = await asyncio.wait_for(handler(task, input_data), timeout = task.timeout_seconds)
# 更新任务日志为成功 # 更新任务日志为成功
self.update_log( self.update_log(
task_log.id, task_log.id,
status=TaskStatus.SUCCESS.value, status = TaskStatus.SUCCESS.value,
end_time=datetime.now().isoformat(), end_time = datetime.now().isoformat(),
output_data={"result": result} if not isinstance(result, dict) else result, output_data = {"result": result} if not isinstance(result, dict) else result,
) )
return result return result
@@ -1243,18 +1243,18 @@ class WorkflowManager:
except TimeoutError: except TimeoutError:
self.update_log( self.update_log(
task_log.id, task_log.id,
status=TaskStatus.FAILED.value, status = TaskStatus.FAILED.value,
end_time=datetime.now().isoformat(), end_time = datetime.now().isoformat(),
error_message="Task timeout", error_message = "Task timeout",
) )
raise TimeoutError(f"Task {task.id} timed out after {task.timeout_seconds}s") raise TimeoutError(f"Task {task.id} timed out after {task.timeout_seconds}s")
except Exception as e: except Exception as e:
self.update_log( self.update_log(
task_log.id, task_log.id,
status=TaskStatus.FAILED.value, status = TaskStatus.FAILED.value,
end_time=datetime.now().isoformat(), end_time = datetime.now().isoformat(),
error_message=str(e), error_message = str(e),
) )
raise raise
@@ -1415,7 +1415,7 @@ class WorkflowManager:
async def _send_workflow_notification( async def _send_workflow_notification(
self, workflow: Workflow, results: dict, success: bool = True self, workflow: Workflow, results: dict, success: bool = True
): ) -> None:
"""发送工作流执行通知""" """发送工作流执行通知"""
if not workflow.webhook_ids: if not workflow.webhook_ids:
return return
@@ -1476,7 +1476,7 @@ class WorkflowManager:
**结果:** **结果:**
```json ```json
{json.dumps(results, ensure_ascii=False, indent=2)} {json.dumps(results, ensure_ascii = False, indent = 2)}
``` ```
""", """,
} }
@@ -1510,7 +1510,7 @@ class WorkflowManager:
_workflow_manager = None _workflow_manager = None
def get_workflow_manager(db_manager=None) -> WorkflowManager: def get_workflow_manager(db_manager = None) -> WorkflowManager:
"""获取 WorkflowManager 单例""" """获取 WorkflowManager 单例"""
global _workflow_manager global _workflow_manager
if _workflow_manager is None: if _workflow_manager is None:

View File

@@ -55,7 +55,7 @@ def check_bare_excepts(content: str, file_path: Path) -> list[dict]:
for i, line in enumerate(lines, 1): for i, line in enumerate(lines, 1):
stripped = line.strip() stripped = line.strip()
# 检查 except Exception: 或 except : # 检查 except Exception: 或 except Exception:
if re.match(r'^except\s*:', stripped): if re.match(r'^except\s*:', stripped):
issues.append({ issues.append({
"line": i, "line": i,
@@ -140,7 +140,7 @@ def check_magic_numbers(content: str, file_path: Path) -> list[dict]:
lines = content.split('\n') lines = content.split('\n')
# 常见魔法数字模式(排除常见索引和简单值) # 常见魔法数字模式(排除常见索引和简单值)
magic_pattern = re.compile(r'(?<![\w\d_])(\d{3,})(?![\w\d_])') magic_pattern = re.compile(r'(?<![\w\d_])(\d{3, })(?![\w\d_])')
for i, line in enumerate(lines, 1): for i, line in enumerate(lines, 1):
if line.strip().startswith('#'): if line.strip().startswith('#'):
@@ -217,7 +217,7 @@ def fix_line_length(content: str) -> str:
for line in lines: for line in lines:
if len(line) > 100: if len(line) > 100:
# 尝试在逗号或运算符处折行 # 尝试在逗号或运算符处折行
if ',' in line[80:]: if ', ' in line[80:]:
# 简单处理:截断并添加续行 # 简单处理:截断并添加续行
indent = len(line) - len(line.lstrip()) indent = len(line) - len(line.lstrip())
new_lines.append(line) new_lines.append(line)
@@ -231,7 +231,7 @@ def fix_line_length(content: str) -> str:
def analyze_file(file_path: Path) -> dict: def analyze_file(file_path: Path) -> dict:
"""分析单个文件""" """分析单个文件"""
try: try:
content = file_path.read_text(encoding='utf-8') content = file_path.read_text(encoding = 'utf-8')
except Exception as e: except Exception as e:
return {"error": str(e)} return {"error": str(e)}
@@ -251,7 +251,7 @@ def analyze_file(file_path: Path) -> dict:
def fix_file(file_path: Path, issues: dict) -> bool: def fix_file(file_path: Path, issues: dict) -> bool:
"""自动修复文件问题""" """自动修复文件问题"""
try: try:
content = file_path.read_text(encoding='utf-8') content = file_path.read_text(encoding = 'utf-8')
original_content = content original_content = content
# 修复裸异常 # 修复裸异常
@@ -260,7 +260,7 @@ def fix_file(file_path: Path, issues: dict) -> bool:
# 如果有修改,写回文件 # 如果有修改,写回文件
if content != original_content: if content != original_content:
file_path.write_text(content, encoding='utf-8') file_path.write_text(content, encoding = 'utf-8')
return True return True
return False return False
except Exception as e: except Exception as e:
@@ -333,7 +333,7 @@ def generate_report(all_issues: dict) -> str:
return '\n'.join(lines) return '\n'.join(lines)
def git_commit_and_push(): def git_commit_and_push() -> None:
"""提交并推送代码""" """提交并推送代码"""
try: try:
os.chdir(PROJECT_PATH) os.chdir(PROJECT_PATH)
@@ -341,15 +341,15 @@ def git_commit_and_push():
# 检查是否有修改 # 检查是否有修改
result = subprocess.run( result = subprocess.run(
["git", "status", "--porcelain"], ["git", "status", "--porcelain"],
capture_output=True, capture_output = True,
text=True text = True
) )
if not result.stdout.strip(): if not result.stdout.strip():
return "没有需要提交的更改" return "没有需要提交的更改"
# 添加所有修改 # 添加所有修改
subprocess.run(["git", "add", "-A"], check=True) subprocess.run(["git", "add", "-A"], check = True)
# 提交 # 提交
subprocess.run( subprocess.run(
@@ -359,11 +359,11 @@ def git_commit_and_push():
- 修复异常处理 - 修复异常处理
- 修复PEP8格式问题 - 修复PEP8格式问题
- 添加类型注解"""], - 添加类型注解"""],
check=True check = True
) )
# 推送 # 推送
subprocess.run(["git", "push"], check=True) subprocess.run(["git", "push"], check = True)
return "✅ 提交并推送成功" return "✅ 提交并推送成功"
except subprocess.CalledProcessError as e: except subprocess.CalledProcessError as e:
@@ -371,7 +371,7 @@ def git_commit_and_push():
except Exception as e: except Exception as e:
return f"❌ 错误: {e}" return f"❌ 错误: {e}"
def main(): def main() -> None:
"""主函数""" """主函数"""
print("🔍 开始代码审查...") print("🔍 开始代码审查...")
@@ -392,7 +392,7 @@ def main():
# 生成报告 # 生成报告
report_content = generate_report(all_issues) report_content = generate_report(all_issues)
report_path = PROJECT_PATH / "AUTO_CODE_REVIEW_REPORT.md" report_path = PROJECT_PATH / "AUTO_CODE_REVIEW_REPORT.md"
report_path.write_text(report_content, encoding='utf-8') report_path.write_text(report_content, encoding = 'utf-8')
print("\n📄 报告已生成:", report_path) print("\n📄 报告已生成:", report_path)
@@ -402,7 +402,7 @@ def main():
print(git_result) print(git_result)
# 追加提交结果到报告 # 追加提交结果到报告
with open(report_path, 'a', encoding='utf-8') as f: with open(report_path, 'a', encoding = 'utf-8') as f:
f.write(f"\n\n## Git 提交结果\n\n{git_result}\n") f.write(f"\n\n## Git 提交结果\n\n{git_result}\n")
print("\n✅ 代码审查完成!") print("\n✅ 代码审查完成!")

View File

@@ -16,7 +16,7 @@ class CodeIssue:
issue_type: str, issue_type: str,
message: str, message: str,
severity: str = "info", severity: str = "info",
): ) -> None:
self.file_path = file_path self.file_path = file_path
self.line_no = line_no self.line_no = line_no
self.issue_type = issue_type self.issue_type = issue_type
@@ -24,12 +24,12 @@ class CodeIssue:
self.severity = severity # info, warning, error self.severity = severity # info, warning, error
self.fixed = False self.fixed = False
def __repr__(self): def __repr__(self) -> None:
return f"{self.severity.upper()}: {self.file_path}:{self.line_no} - {self.issue_type}: {self.message}" return f"{self.severity.upper()}: {self.file_path}:{self.line_no} - {self.issue_type}: {self.message}"
class CodeReviewer: class CodeReviewer:
def __init__(self, base_path: str): def __init__(self, base_path: str) -> None:
self.base_path = Path(base_path) self.base_path = Path(base_path)
self.issues: list[CodeIssue] = [] self.issues: list[CodeIssue] = []
self.fixed_issues: list[CodeIssue] = [] self.fixed_issues: list[CodeIssue] = []
@@ -45,7 +45,7 @@ class CodeReviewer:
def scan_file(self, file_path: Path) -> None: def scan_file(self, file_path: Path) -> None:
"""扫描单个文件""" """扫描单个文件"""
try: try:
with open(file_path, "r", encoding="utf-8") as f: with open(file_path, "r", encoding = "utf-8") as f:
content = f.read() content = f.read()
lines = content.split("\n") lines = content.split("\n")
except Exception as e: except Exception as e:
@@ -110,7 +110,7 @@ class CodeReviewer:
match = re.match(r"^(?:from\s+(\S+)\s+)?import\s+(.+)$", line.strip()) match = re.match(r"^(?:from\s+(\S+)\s+)?import\s+(.+)$", line.strip())
if match: if match:
module = match.group(1) or "" module = match.group(1) or ""
names = match.group(2).split(",") names = match.group(2).split(", ")
for name in names: for name in names:
name = name.strip().split()[0] # 处理 'as' 别名 name = name.strip().split()[0] # 处理 'as' 别名
key = f"{module}.{name}" if module else name key = f"{module}.{name}" if module else name
@@ -223,10 +223,10 @@ class CodeReviewer:
"""检查魔法数字""" """检查魔法数字"""
# 常见的魔法数字模式 # 常见的魔法数字模式
magic_patterns = [ magic_patterns = [
(r"=\s*(\d{3,})\s*[^:]", "可能的魔法数字"), (r" = \s*(\d{3, })\s*[^:]", "可能的魔法数字"),
(r"timeout\s*=\s*(\d+)", "timeout 魔法数字"), (r"timeout\s* = \s*(\d+)", "timeout 魔法数字"),
(r"limit\s*=\s*(\d+)", "limit 魔法数字"), (r"limit\s* = \s*(\d+)", "limit 魔法数字"),
(r"port\s*=\s*(\d+)", "port 魔法数字"), (r"port\s* = \s*(\d+)", "port 魔法数字"),
] ]
for i, line in enumerate(lines, 1): for i, line in enumerate(lines, 1):
@@ -238,7 +238,7 @@ class CodeReviewer:
for pattern, msg in magic_patterns: for pattern, msg in magic_patterns:
if re.search(pattern, code_part, re.IGNORECASE): if re.search(pattern, code_part, re.IGNORECASE):
# 排除常见的合理数字 # 排除常见的合理数字
match = re.search(r"(\d{3,})", code_part) match = re.search(r"(\d{3, })", code_part)
if match: if match:
num = int(match.group(1)) num = int(match.group(1))
if num in [ if num in [
@@ -303,7 +303,7 @@ class CodeReviewer:
for i, line in enumerate(lines, 1): for i, line in enumerate(lines, 1):
# 检查硬编码密钥 # 检查硬编码密钥
if re.search( if re.search(
r'(password|secret|key|token)\s*=\s*["\'][^"\']+["\']', r'(password|secret|key|token)\s* = \s*["\'][^"\']+["\']',
line, line,
re.IGNORECASE, re.IGNORECASE,
): ):
@@ -340,7 +340,7 @@ class CodeReviewer:
continue continue
try: try:
with open(full_path, "r", encoding="utf-8") as f: with open(full_path, "r", encoding = "utf-8") as f:
content = f.read() content = f.read()
lines = content.split("\n") lines = content.split("\n")
except Exception as e: except Exception as e:
@@ -363,9 +363,9 @@ class CodeReviewer:
idx = issue.line_no - 1 idx = issue.line_no - 1
if 0 <= idx < len(lines): if 0 <= idx < len(lines):
line = lines[idx] line = lines[idx]
# 将 except: 改为 except Exception: # 将 except Exception: 改为 except Exception:
if re.search(r"except\s*:\s*$", line.strip()): if re.search(r"except\s*:\s*$", line.strip()):
lines[idx] = line.replace("except:", "except Exception:") lines[idx] = line.replace("except Exception:", "except Exception:")
issue.fixed = True issue.fixed = True
elif re.search(r"except\s+Exception\s*:\s*$", line.strip()): elif re.search(r"except\s+Exception\s*:\s*$", line.strip()):
# 已经是 Exception但可能需要更具体 # 已经是 Exception但可能需要更具体
@@ -373,7 +373,7 @@ class CodeReviewer:
# 如果文件有修改,写回 # 如果文件有修改,写回
if lines != original_lines: if lines != original_lines:
with open(full_path, "w", encoding="utf-8") as f: with open(full_path, "w", encoding = "utf-8") as f:
f.write("\n".join(lines)) f.write("\n".join(lines))
print(f"Fixed issues in {file_path}") print(f"Fixed issues in {file_path}")
@@ -421,7 +421,7 @@ class CodeReviewer:
return "\n".join(report) return "\n".join(report)
def main(): def main() -> None:
base_path = "/root/.openclaw/workspace/projects/insightflow/backend" base_path = "/root/.openclaw/workspace/projects/insightflow/backend"
reviewer = CodeReviewer(base_path) reviewer = CodeReviewer(base_path)
@@ -439,7 +439,7 @@ def main():
# 生成报告 # 生成报告
report = reviewer.generate_report() report = reviewer.generate_report()
report_path = Path(base_path).parent / "CODE_REVIEW_REPORT.md" report_path = Path(base_path).parent / "CODE_REVIEW_REPORT.md"
with open(report_path, "w", encoding="utf-8") as f: with open(report_path, "w", encoding = "utf-8") as f:
f.write(report) f.write(report)
print(f"\n报告已保存到: {report_path}") print(f"\n报告已保存到: {report_path}")