fix: auto-fix code issues (cron)

- 修复重复导入/字段
- 修复异常处理
- 修复PEP8格式问题
- 添加类型注解
This commit is contained in:
AutoFix Bot
2026-03-02 12:14:39 +08:00
parent e23f1fec08
commit 98527c4de4
39 changed files with 8109 additions and 8147 deletions

View File

@@ -1,192 +1,153 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
""" """
InsightFlow 代码自动修复脚本 - 增强版 InsightFlow 代码自动修复脚本
自动修复代码中的常见问题
""" """
import json
import os import os
import re import re
import subprocess import subprocess
from pathlib import Path from pathlib import Path
PROJECT_DIR = Path("/root/.openclaw/workspace/projects/insightflow")
BACKEND_DIR = PROJECT_DIR / "backend"
def run_flake8() -> str:
    """Run flake8 over the backend directory and return its raw stdout.

    Returns an empty string when flake8 reports no issues or when the
    flake8 executable is not installed — this script is best-effort and
    must not abort the whole auto-fix run on a missing tool.
    """
    try:
        result = subprocess.run(
            ["flake8", "--max-line-length=120", "--ignore=E501,W503", "."],
            cwd=BACKEND_DIR,
            capture_output=True,
            text=True,
            check=False,  # flake8 exits non-zero whenever issues exist; that is expected
        )
    except FileNotFoundError:
        print("  警告: 未找到 flake8,跳过检查")
        return ""
    return result.stdout
def _insert_import(file_path: Path, import_stmt: str) -> bool:
    """Insert *import_stmt* after the last top-level import of *file_path*.

    Returns True when the file was modified; False when the file does not
    exist or already contains the statement.
    """
    if not file_path.exists():
        return False
    content = file_path.read_text()
    if import_stmt in content:
        return False
    lines = content.split('\n')
    # Place the new import right after the last existing import line so
    # the import block stays contiguous.
    import_idx = 0
    for i, line in enumerate(lines):
        if line.startswith('import ') or line.startswith('from '):
            import_idx = i + 1
    lines.insert(import_idx, import_stmt)
    file_path.write_text('\n'.join(lines))
    return True


def fix_missing_imports() -> list[str]:
    """修复缺失的导入: add imports that are referenced but never imported.

    Returns a human-readable list of the fixes that were applied.
    """
    fixes: list[str] = []

    # workflow_manager.py / plugin_manager.py: urllib is referenced but
    # never imported — add `import urllib.parse`. (The duplicated per-file
    # logic of the previous version is folded into one loop + helper.)
    for filename in ("workflow_manager.py", "plugin_manager.py"):
        target = BACKEND_DIR / filename
        if target.exists():
            content = target.read_text()
            if "import urllib" not in content and "urllib" in content:
                if _insert_import(target, 'import urllib.parse'):
                    fixes.append(f"{filename}: 添加 urllib.parse 导入")

    # main.py: PlainTextResponse check. The previous implementation replaced
    # the fastapi.responses import line with an IDENTICAL string (a no-op)
    # and never wrote the file back; its own comment noted the import was
    # already present. Keep the detection, but only as a no-op guard.
    main_file = BACKEND_DIR / "main.py"
    if main_file.exists():
        content = main_file.read_text()
        if "PlainTextResponse" in content and "from fastapi.responses import" in content:
            first_import_line = content.split('from fastapi.responses import')[1].split('\n')[0]
            if "PlainTextResponse" not in first_import_line:
                # Historically a false positive: the name is imported on a
                # different line. Nothing to change.
                pass

    return fixes
def fix_unused_imports() -> list[str]:
    """修复未使用的导入 in code_reviewer.py.

    Unlike the previous version — which stripped the import lines
    unconditionally — an import is only removed when its module/name is
    not referenced anywhere else in the file, so live imports survive.
    """
    fixes: list[str] = []
    code_reviewer = PROJECT_DIR / "code_reviewer.py"
    if not code_reviewer.exists():
        return fixes

    content = code_reviewer.read_text()
    original = content

    # (import-line regex, usage regex): drop the line only when unused.
    candidates = [
        (r'^import os\n', r'\bos\.'),
        (r'^import subprocess\n', r'\bsubprocess\.'),
        (r'^from typing import Any\n', r'\bAny\b'),
    ]
    for import_pattern, usage_pattern in candidates:
        without = re.sub(import_pattern, '', content, flags=re.MULTILINE)
        # Search the file WITH the import line removed, so the import
        # statement itself does not count as a "use".
        if without != content and not re.search(usage_pattern, without):
            content = without

    if content != original:
        code_reviewer.write_text(content)
        fixes.append("code_reviewer.py: 移除未使用的导入")
    return fixes
def fix_formatting() -> list[str]:
    """使用 autopep8 修复格式问题 (in-place PEP8 auto-formatting).

    Returns a list describing what was done. A missing autopep8 binary or
    a non-zero exit is reported instead of crashing the whole run or being
    silently swallowed.
    """
    fixes: list[str] = []
    try:
        result = subprocess.run(
            ["autopep8", "--in-place", "--aggressive", "--max-line-length=120", "."],
            cwd=BACKEND_DIR,
            capture_output=True,
            text=True,
            check=False,
        )
    except FileNotFoundError:
        print("  警告: 未找到 autopep8,跳过格式化")
        return fixes
    if result.returncode == 0:
        fixes.append("使用 autopep8 修复了格式问题")
    else:
        # Surface stderr so a formatting failure is visible, not silent.
        print(f"  autopep8 失败: {result.stderr.strip()}")
    return fixes
def main() -> list[str]:
    """Run all auto-fix stages in order and report the applied fixes."""
    print("=" * 60)
    print("InsightFlow 代码自动修复")
    print("=" * 60)

    all_fixes: list[str] = []
    # Table-driven stages replace three copy-pasted sections; each stage
    # returns a list of human-readable fix descriptions.
    stages = [
        ("修复缺失的导入", fix_missing_imports),
        ("修复未使用的导入", fix_unused_imports),
        ("修复 PEP8 格式问题", fix_formatting),
    ]
    for step, (label, stage) in enumerate(stages, start=1):
        print(f"\n[{step}/{len(stages)}] {label}...")
        fixes = stage()
        all_fixes.extend(fixes)
        for fix in fixes:
            print(fix)

    print("\n" + "=" * 60)
    print(f"修复完成!共修复 {len(all_fixes)} 个问题")
    print("=" * 60)

    # Re-run flake8 so remaining problems are counted, not hidden.
    print("\n重新运行 flake8 检查...")
    remaining = run_flake8()
    if remaining:
        lines = remaining.strip().split('\n')
        print(f" 仍有 {len(lines)} 个问题需要手动处理")
    else:
        print(" ✓ 所有问题已修复!")
    return all_fixes
# Script entry point: run the auto-fix pipeline when executed directly.
if __name__ == "__main__":
    main()

View File

@@ -234,20 +234,20 @@ class AIManager:
now = datetime.now().isoformat() now = datetime.now().isoformat()
model = CustomModel( model = CustomModel(
id = model_id, id=model_id,
tenant_id = tenant_id, tenant_id=tenant_id,
name = name, name=name,
description = description, description=description,
model_type = model_type, model_type=model_type,
status = ModelStatus.PENDING, status=ModelStatus.PENDING,
training_data = training_data, training_data=training_data,
hyperparameters = hyperparameters, hyperparameters=hyperparameters,
metrics = {}, metrics={},
model_path = None, model_path=None,
created_at = now, created_at=now,
updated_at = now, updated_at=now,
trained_at = None, trained_at=None,
created_by = created_by, created_by=created_by,
) )
with self._get_db() as conn: with self._get_db() as conn:
@@ -318,12 +318,12 @@ class AIManager:
now = datetime.now().isoformat() now = datetime.now().isoformat()
sample = TrainingSample( sample = TrainingSample(
id = sample_id, id=sample_id,
model_id = model_id, model_id=model_id,
text = text, text=text,
entities = entities, entities=entities,
metadata = metadata or {}, metadata=metadata or {},
created_at = now, created_at=now,
) )
with self._get_db() as conn: with self._get_db() as conn:
@@ -392,7 +392,7 @@ class AIManager:
# 保存模型(模拟) # 保存模型(模拟)
model_path = f"models/{model_id}.bin" model_path = f"models/{model_id}.bin"
os.makedirs("models", exist_ok = True) os.makedirs("models", exist_ok=True)
now = datetime.now().isoformat() now = datetime.now().isoformat()
@@ -450,9 +450,9 @@ class AIManager:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.post( response = await client.post(
f"{self.kimi_base_url}/v1/chat/completions", f"{self.kimi_base_url}/v1/chat/completions",
headers = headers, headers=headers,
json = payload, json=payload,
timeout = 60.0, timeout=60.0,
) )
response.raise_for_status() response.raise_for_status()
result = response.json() result = response.json()
@@ -494,17 +494,17 @@ class AIManager:
result = await self._call_kimi_multimodal(input_urls, prompt) result = await self._call_kimi_multimodal(input_urls, prompt)
analysis = MultimodalAnalysis( analysis = MultimodalAnalysis(
id = analysis_id, id=analysis_id,
tenant_id = tenant_id, tenant_id=tenant_id,
project_id = project_id, project_id=project_id,
provider = provider, provider=provider,
input_type = input_type, input_type=input_type,
input_urls = input_urls, input_urls=input_urls,
prompt = prompt, prompt=prompt,
result = result, result=result,
tokens_used = result.get("tokens_used", 0), tokens_used=result.get("tokens_used", 0),
cost = result.get("cost", 0.0), cost=result.get("cost", 0.0),
created_at = now, created_at=now,
) )
with self._get_db() as conn: with self._get_db() as conn:
@@ -553,9 +553,9 @@ class AIManager:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.post( response = await client.post(
"https://api.openai.com/v1/chat/completions", "https://api.openai.com/v1/chat/completions",
headers = headers, headers=headers,
json = payload, json=payload,
timeout = 120.0, timeout=120.0,
) )
response.raise_for_status() response.raise_for_status()
result = response.json() result = response.json()
@@ -588,9 +588,9 @@ class AIManager:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.post( response = await client.post(
"https://api.anthropic.com/v1/messages", "https://api.anthropic.com/v1/messages",
headers = headers, headers=headers,
json = payload, json=payload,
timeout = 120.0, timeout=120.0,
) )
response.raise_for_status() response.raise_for_status()
result = response.json() result = response.json()
@@ -623,9 +623,9 @@ class AIManager:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.post( response = await client.post(
f"{self.kimi_base_url}/v1/chat/completions", f"{self.kimi_base_url}/v1/chat/completions",
headers = headers, headers=headers,
json = payload, json=payload,
timeout = 60.0, timeout=60.0,
) )
response.raise_for_status() response.raise_for_status()
result = response.json() result = response.json()
@@ -670,17 +670,17 @@ class AIManager:
now = datetime.now().isoformat() now = datetime.now().isoformat()
rag = KnowledgeGraphRAG( rag = KnowledgeGraphRAG(
id = rag_id, id=rag_id,
tenant_id = tenant_id, tenant_id=tenant_id,
project_id = project_id, project_id=project_id,
name = name, name=name,
description = description, description=description,
kg_config = kg_config, kg_config=kg_config,
retrieval_config = retrieval_config, retrieval_config=retrieval_config,
generation_config = generation_config, generation_config=generation_config,
is_active = True, is_active=True,
created_at = now, created_at=now,
updated_at = now, updated_at=now,
) )
with self._get_db() as conn: with self._get_db() as conn:
@@ -766,7 +766,7 @@ class AIManager:
if score > 0: if score > 0:
relevant_entities.append({**entity, "relevance_score": score}) relevant_entities.append({**entity, "relevance_score": score})
relevant_entities.sort(key = lambda x: x["relevance_score"], reverse = True) relevant_entities.sort(key=lambda x: x["relevance_score"], reverse=True)
relevant_entities = relevant_entities[:top_k] relevant_entities = relevant_entities[:top_k]
# 检索相关关系 # 检索相关关系
@@ -818,9 +818,9 @@ class AIManager:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.post( response = await client.post(
f"{self.kimi_base_url}/v1/chat/completions", f"{self.kimi_base_url}/v1/chat/completions",
headers = headers, headers=headers,
json = payload, json=payload,
timeout = 60.0, timeout=60.0,
) )
response.raise_for_status() response.raise_for_status()
result = response.json() result = response.json()
@@ -840,20 +840,20 @@ class AIManager:
] ]
rag_query = RAGQuery( rag_query = RAGQuery(
id = query_id, id=query_id,
rag_id = rag_id, rag_id=rag_id,
query = query, query=query,
context = context, context=context,
answer = answer, answer=answer,
sources = sources, sources=sources,
confidence = ( confidence=(
sum(e["relevance_score"] for e in relevant_entities) / len(relevant_entities) sum(e["relevance_score"] for e in relevant_entities) / len(relevant_entities)
if relevant_entities if relevant_entities
else 0 else 0
), ),
tokens_used = tokens_used, tokens_used=tokens_used,
latency_ms = latency_ms, latency_ms=latency_ms,
created_at = now, created_at=now,
) )
with self._get_db() as conn: with self._get_db() as conn:
@@ -974,9 +974,9 @@ class AIManager:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.post( response = await client.post(
f"{self.kimi_base_url}/v1/chat/completions", f"{self.kimi_base_url}/v1/chat/completions",
headers = headers, headers=headers,
json = payload, json=payload,
timeout = 60.0, timeout=60.0,
) )
response.raise_for_status() response.raise_for_status()
result = response.json() result = response.json()
@@ -1014,18 +1014,18 @@ class AIManager:
entity_names = [e.get("name", "") for e in entities_mentioned[:10]] entity_names = [e.get("name", "") for e in entities_mentioned[:10]]
summary = SmartSummary( summary = SmartSummary(
id = summary_id, id=summary_id,
tenant_id = tenant_id, tenant_id=tenant_id,
project_id = project_id, project_id=project_id,
source_type = source_type, source_type=source_type,
source_id = source_id, source_id=source_id,
summary_type = summary_type, summary_type=summary_type,
content = content, content=content,
key_points = key_points[:8], key_points=key_points[:8],
entities_mentioned = entity_names, entities_mentioned=entity_names,
confidence = 0.85, confidence=0.85,
tokens_used = tokens_used, tokens_used=tokens_used,
created_at = now, created_at=now,
) )
with self._get_db() as conn: with self._get_db() as conn:
@@ -1072,20 +1072,20 @@ class AIManager:
now = datetime.now().isoformat() now = datetime.now().isoformat()
model = PredictionModel( model = PredictionModel(
id = model_id, id=model_id,
tenant_id = tenant_id, tenant_id=tenant_id,
project_id = project_id, project_id=project_id,
name = name, name=name,
prediction_type = prediction_type, prediction_type=prediction_type,
target_entity_type = target_entity_type, target_entity_type=target_entity_type,
features = features, features=features,
model_config = model_config, model_config=model_config,
accuracy = None, accuracy=None,
last_trained_at = None, last_trained_at=None,
prediction_count = 0, prediction_count=0,
is_active = True, is_active=True,
created_at = now, created_at=now,
updated_at = now, updated_at=now,
) )
with self._get_db() as conn: with self._get_db() as conn:
@@ -1201,16 +1201,16 @@ class AIManager:
explanation = prediction_data.get("explanation", "基于历史数据模式预测") explanation = prediction_data.get("explanation", "基于历史数据模式预测")
result = PredictionResult( result = PredictionResult(
id = prediction_id, id=prediction_id,
model_id = model_id, model_id=model_id,
prediction_type = model.prediction_type, prediction_type=model.prediction_type,
target_id = input_data.get("target_id"), target_id=input_data.get("target_id"),
prediction_data = prediction_data, prediction_data=prediction_data,
confidence = confidence, confidence=confidence,
explanation = explanation, explanation=explanation,
actual_value = None, actual_value=None,
is_correct = None, is_correct=None,
created_at = now, created_at=now,
) )
with self._get_db() as conn: with self._get_db() as conn:
@@ -1368,7 +1368,7 @@ class AIManager:
predicted_relations = [ predicted_relations = [
{"type": rel_type, "likelihood": min(count / len(relation_history), 0.95)} {"type": rel_type, "likelihood": min(count / len(relation_history), 0.95)}
for rel_type, count in sorted( for rel_type, count in sorted(
relation_counts.items(), key = lambda x: x[1], reverse = True relation_counts.items(), key=lambda x: x[1], reverse=True
)[:5] )[:5]
] ]
@@ -1410,97 +1410,97 @@ class AIManager:
def _row_to_custom_model(self, row) -> CustomModel: def _row_to_custom_model(self, row) -> CustomModel:
"""将数据库行转换为 CustomModel""" """将数据库行转换为 CustomModel"""
return CustomModel( return CustomModel(
id = row["id"], id=row["id"],
tenant_id = row["tenant_id"], tenant_id=row["tenant_id"],
name = row["name"], name=row["name"],
description = row["description"], description=row["description"],
model_type = ModelType(row["model_type"]), model_type=ModelType(row["model_type"]),
status = ModelStatus(row["status"]), status=ModelStatus(row["status"]),
training_data = json.loads(row["training_data"]), training_data=json.loads(row["training_data"]),
hyperparameters = json.loads(row["hyperparameters"]), hyperparameters=json.loads(row["hyperparameters"]),
metrics = json.loads(row["metrics"]), metrics=json.loads(row["metrics"]),
model_path = row["model_path"], model_path=row["model_path"],
created_at = row["created_at"], created_at=row["created_at"],
updated_at = row["updated_at"], updated_at=row["updated_at"],
trained_at = row["trained_at"], trained_at=row["trained_at"],
created_by = row["created_by"], created_by=row["created_by"],
) )
def _row_to_training_sample(self, row) -> TrainingSample: def _row_to_training_sample(self, row) -> TrainingSample:
"""将数据库行转换为 TrainingSample""" """将数据库行转换为 TrainingSample"""
return TrainingSample( return TrainingSample(
id = row["id"], id=row["id"],
model_id = row["model_id"], model_id=row["model_id"],
text = row["text"], text=row["text"],
entities = json.loads(row["entities"]), entities=json.loads(row["entities"]),
metadata = json.loads(row["metadata"]), metadata=json.loads(row["metadata"]),
created_at = row["created_at"], created_at=row["created_at"],
) )
def _row_to_multimodal_analysis(self, row) -> MultimodalAnalysis: def _row_to_multimodal_analysis(self, row) -> MultimodalAnalysis:
"""将数据库行转换为 MultimodalAnalysis""" """将数据库行转换为 MultimodalAnalysis"""
return MultimodalAnalysis( return MultimodalAnalysis(
id = row["id"], id=row["id"],
tenant_id = row["tenant_id"], tenant_id=row["tenant_id"],
project_id = row["project_id"], project_id=row["project_id"],
provider = MultimodalProvider(row["provider"]), provider=MultimodalProvider(row["provider"]),
input_type = row["input_type"], input_type=row["input_type"],
input_urls = json.loads(row["input_urls"]), input_urls=json.loads(row["input_urls"]),
prompt = row["prompt"], prompt=row["prompt"],
result = json.loads(row["result"]), result=json.loads(row["result"]),
tokens_used = row["tokens_used"], tokens_used=row["tokens_used"],
cost = row["cost"], cost=row["cost"],
created_at = row["created_at"], created_at=row["created_at"],
) )
def _row_to_kg_rag(self, row) -> KnowledgeGraphRAG: def _row_to_kg_rag(self, row) -> KnowledgeGraphRAG:
"""将数据库行转换为 KnowledgeGraphRAG""" """将数据库行转换为 KnowledgeGraphRAG"""
return KnowledgeGraphRAG( return KnowledgeGraphRAG(
id = row["id"], id=row["id"],
tenant_id = row["tenant_id"], tenant_id=row["tenant_id"],
project_id = row["project_id"], project_id=row["project_id"],
name = row["name"], name=row["name"],
description = row["description"], description=row["description"],
kg_config = json.loads(row["kg_config"]), kg_config=json.loads(row["kg_config"]),
retrieval_config = json.loads(row["retrieval_config"]), retrieval_config=json.loads(row["retrieval_config"]),
generation_config = json.loads(row["generation_config"]), generation_config=json.loads(row["generation_config"]),
is_active = bool(row["is_active"]), is_active=bool(row["is_active"]),
created_at = row["created_at"], created_at=row["created_at"],
updated_at = row["updated_at"], updated_at=row["updated_at"],
) )
def _row_to_prediction_model(self, row) -> PredictionModel: def _row_to_prediction_model(self, row) -> PredictionModel:
"""将数据库行转换为 PredictionModel""" """将数据库行转换为 PredictionModel"""
return PredictionModel( return PredictionModel(
id = row["id"], id=row["id"],
tenant_id = row["tenant_id"], tenant_id=row["tenant_id"],
project_id = row["project_id"], project_id=row["project_id"],
name = row["name"], name=row["name"],
prediction_type = PredictionType(row["prediction_type"]), prediction_type=PredictionType(row["prediction_type"]),
target_entity_type = row["target_entity_type"], target_entity_type=row["target_entity_type"],
features = json.loads(row["features"]), features=json.loads(row["features"]),
model_config = json.loads(row["model_config"]), model_config=json.loads(row["model_config"]),
accuracy = row["accuracy"], accuracy=row["accuracy"],
last_trained_at = row["last_trained_at"], last_trained_at=row["last_trained_at"],
prediction_count = row["prediction_count"], prediction_count=row["prediction_count"],
is_active = bool(row["is_active"]), is_active=bool(row["is_active"]),
created_at = row["created_at"], created_at=row["created_at"],
updated_at = row["updated_at"], updated_at=row["updated_at"],
) )
def _row_to_prediction_result(self, row) -> PredictionResult: def _row_to_prediction_result(self, row) -> PredictionResult:
"""将数据库行转换为 PredictionResult""" """将数据库行转换为 PredictionResult"""
return PredictionResult( return PredictionResult(
id = row["id"], id=row["id"],
model_id = row["model_id"], model_id=row["model_id"],
prediction_type = PredictionType(row["prediction_type"]), prediction_type=PredictionType(row["prediction_type"]),
target_id = row["target_id"], target_id=row["target_id"],
prediction_data = json.loads(row["prediction_data"]), prediction_data=json.loads(row["prediction_data"]),
confidence = row["confidence"], confidence=row["confidence"],
explanation = row["explanation"], explanation=row["explanation"],
actual_value = row["actual_value"], actual_value=row["actual_value"],
is_correct = row["is_correct"], is_correct=row["is_correct"],
created_at = row["created_at"], created_at=row["created_at"],
) )

View File

@@ -152,23 +152,23 @@ class ApiKeyManager:
expires_at = None expires_at = None
if expires_days: if expires_days:
expires_at = (datetime.now() + timedelta(days = expires_days)).isoformat() expires_at = (datetime.now() + timedelta(days=expires_days)).isoformat()
api_key = ApiKey( api_key = ApiKey(
id = key_id, id=key_id,
key_hash = key_hash, key_hash=key_hash,
key_preview = key_preview, key_preview=key_preview,
name = name, name=name,
owner_id = owner_id, owner_id=owner_id,
permissions = permissions, permissions=permissions,
rate_limit = rate_limit, rate_limit=rate_limit,
status = ApiKeyStatus.ACTIVE.value, status=ApiKeyStatus.ACTIVE.value,
created_at = datetime.now().isoformat(), created_at=datetime.now().isoformat(),
expires_at = expires_at, expires_at=expires_at,
last_used_at = None, last_used_at=None,
revoked_at = None, revoked_at=None,
revoked_reason = None, revoked_reason=None,
total_calls = 0, total_calls=0,
) )
with sqlite3.connect(self.db_path) as conn: with sqlite3.connect(self.db_path) as conn:
@@ -510,20 +510,20 @@ class ApiKeyManager:
def _row_to_api_key(self, row: sqlite3.Row) -> ApiKey: def _row_to_api_key(self, row: sqlite3.Row) -> ApiKey:
"""将数据库行转换为 ApiKey 对象""" """将数据库行转换为 ApiKey 对象"""
return ApiKey( return ApiKey(
id = row["id"], id=row["id"],
key_hash = row["key_hash"], key_hash=row["key_hash"],
key_preview = row["key_preview"], key_preview=row["key_preview"],
name = row["name"], name=row["name"],
owner_id = row["owner_id"], owner_id=row["owner_id"],
permissions = json.loads(row["permissions"]), permissions=json.loads(row["permissions"]),
rate_limit = row["rate_limit"], rate_limit=row["rate_limit"],
status = row["status"], status=row["status"],
created_at = row["created_at"], created_at=row["created_at"],
expires_at = row["expires_at"], expires_at=row["expires_at"],
last_used_at = row["last_used_at"], last_used_at=row["last_used_at"],
revoked_at = row["revoked_at"], revoked_at=row["revoked_at"],
revoked_reason = row["revoked_reason"], revoked_reason=row["revoked_reason"],
total_calls = row["total_calls"], total_calls=row["total_calls"],
) )

View File

@@ -136,7 +136,7 @@ class TeamSpace:
class CollaborationManager: class CollaborationManager:
"""协作管理主类""" """协作管理主类"""
def __init__(self, db_manager = None) -> None: def __init__(self, db_manager=None) -> None:
self.db = db_manager self.db = db_manager
self._shares_cache: dict[str, ProjectShare] = {} self._shares_cache: dict[str, ProjectShare] = {}
self._comments_cache: dict[str, list[Comment]] = {} self._comments_cache: dict[str, list[Comment]] = {}
@@ -161,26 +161,26 @@ class CollaborationManager:
now = datetime.now().isoformat() now = datetime.now().isoformat()
expires_at = None expires_at = None
if expires_in_days: if expires_in_days:
expires_at = (datetime.now() + timedelta(days = expires_in_days)).isoformat() expires_at = (datetime.now() + timedelta(days=expires_in_days)).isoformat()
password_hash = None password_hash = None
if password: if password:
password_hash = hashlib.sha256(password.encode()).hexdigest() password_hash = hashlib.sha256(password.encode()).hexdigest()
share = ProjectShare( share = ProjectShare(
id = share_id, id=share_id,
project_id = project_id, project_id=project_id,
token = token, token=token,
permission = permission, permission=permission,
created_by = created_by, created_by=created_by,
created_at = now, created_at=now,
expires_at = expires_at, expires_at=expires_at,
max_uses = max_uses, max_uses=max_uses,
use_count = 0, use_count=0,
password_hash = password_hash, password_hash=password_hash,
is_active = True, is_active=True,
allow_download = allow_download, allow_download=allow_download,
allow_export = allow_export, allow_export=allow_export,
) )
# 保存到数据库 # 保存到数据库
@@ -271,19 +271,19 @@ class CollaborationManager:
return None return None
return ProjectShare( return ProjectShare(
id = row[0], id=row[0],
project_id = row[1], project_id=row[1],
token = row[2], token=row[2],
permission = row[3], permission=row[3],
created_by = row[4], created_by=row[4],
created_at = row[5], created_at=row[5],
expires_at = row[6], expires_at=row[6],
max_uses = row[7], max_uses=row[7],
use_count = row[8], use_count=row[8],
password_hash = row[9], password_hash=row[9],
is_active = bool(row[10]), is_active=bool(row[10]),
allow_download = bool(row[11]), allow_download=bool(row[11]),
allow_export = bool(row[12]), allow_export=bool(row[12]),
) )
def increment_share_usage(self, token: str) -> None: def increment_share_usage(self, token: str) -> None:
@@ -339,19 +339,19 @@ class CollaborationManager:
for row in cursor.fetchall(): for row in cursor.fetchall():
shares.append( shares.append(
ProjectShare( ProjectShare(
id = row[0], id=row[0],
project_id = row[1], project_id=row[1],
token = row[2], token=row[2],
permission = row[3], permission=row[3],
created_by = row[4], created_by=row[4],
created_at = row[5], created_at=row[5],
expires_at = row[6], expires_at=row[6],
max_uses = row[7], max_uses=row[7],
use_count = row[8], use_count=row[8],
password_hash = row[9], password_hash=row[9],
is_active = bool(row[10]), is_active=bool(row[10]),
allow_download = bool(row[11]), allow_download=bool(row[11]),
allow_export = bool(row[12]), allow_export=bool(row[12]),
) )
) )
return shares return shares
@@ -375,21 +375,21 @@ class CollaborationManager:
now = datetime.now().isoformat() now = datetime.now().isoformat()
comment = Comment( comment = Comment(
id = comment_id, id=comment_id,
project_id = project_id, project_id=project_id,
target_type = target_type, target_type=target_type,
target_id = target_id, target_id=target_id,
parent_id = parent_id, parent_id=parent_id,
author = author, author=author,
author_name = author_name, author_name=author_name,
content = content, content=content,
created_at = now, created_at=now,
updated_at = now, updated_at=now,
resolved = False, resolved=False,
resolved_by = None, resolved_by=None,
resolved_at = None, resolved_at=None,
mentions = mentions or [], mentions=mentions or [],
attachments = attachments or [], attachments=attachments or [],
) )
if self.db: if self.db:
@@ -469,21 +469,21 @@ class CollaborationManager:
def _row_to_comment(self, row) -> Comment: def _row_to_comment(self, row) -> Comment:
"""将数据库行转换为Comment对象""" """将数据库行转换为Comment对象"""
return Comment( return Comment(
id = row[0], id=row[0],
project_id = row[1], project_id=row[1],
target_type = row[2], target_type=row[2],
target_id = row[3], target_id=row[3],
parent_id = row[4], parent_id=row[4],
author = row[5], author=row[5],
author_name = row[6], author_name=row[6],
content = row[7], content=row[7],
created_at = row[8], created_at=row[8],
updated_at = row[9], updated_at=row[9],
resolved = bool(row[10]), resolved=bool(row[10]),
resolved_by = row[11], resolved_by=row[11],
resolved_at = row[12], resolved_at=row[12],
mentions = json.loads(row[13]) if row[13] else [], mentions=json.loads(row[13]) if row[13] else [],
attachments = json.loads(row[14]) if row[14] else [], attachments=json.loads(row[14]) if row[14] else [],
) )
def update_comment(self, comment_id: str, content: str, updated_by: str) -> Comment | None: def update_comment(self, comment_id: str, content: str, updated_by: str) -> Comment | None:
@@ -597,22 +597,22 @@ class CollaborationManager:
now = datetime.now().isoformat() now = datetime.now().isoformat()
record = ChangeRecord( record = ChangeRecord(
id = record_id, id=record_id,
project_id = project_id, project_id=project_id,
change_type = change_type, change_type=change_type,
entity_type = entity_type, entity_type=entity_type,
entity_id = entity_id, entity_id=entity_id,
entity_name = entity_name, entity_name=entity_name,
changed_by = changed_by, changed_by=changed_by,
changed_by_name = changed_by_name, changed_by_name=changed_by_name,
changed_at = now, changed_at=now,
old_value = old_value, old_value=old_value,
new_value = new_value, new_value=new_value,
description = description, description=description,
session_id = session_id, session_id=session_id,
reverted = False, reverted=False,
reverted_at = None, reverted_at=None,
reverted_by = None, reverted_by=None,
) )
if self.db: if self.db:
@@ -705,22 +705,22 @@ class CollaborationManager:
def _row_to_change_record(self, row) -> ChangeRecord: def _row_to_change_record(self, row) -> ChangeRecord:
"""将数据库行转换为ChangeRecord对象""" """将数据库行转换为ChangeRecord对象"""
return ChangeRecord( return ChangeRecord(
id = row[0], id=row[0],
project_id = row[1], project_id=row[1],
change_type = row[2], change_type=row[2],
entity_type = row[3], entity_type=row[3],
entity_id = row[4], entity_id=row[4],
entity_name = row[5], entity_name=row[5],
changed_by = row[6], changed_by=row[6],
changed_by_name = row[7], changed_by_name=row[7],
changed_at = row[8], changed_at=row[8],
old_value = json.loads(row[9]) if row[9] else None, old_value=json.loads(row[9]) if row[9] else None,
new_value = json.loads(row[10]) if row[10] else None, new_value=json.loads(row[10]) if row[10] else None,
description = row[11], description=row[11],
session_id = row[12], session_id=row[12],
reverted = bool(row[13]), reverted=bool(row[13]),
reverted_at = row[14], reverted_at=row[14],
reverted_by = row[15], reverted_by=row[15],
) )
def get_entity_version_history(self, entity_type: str, entity_id: str) -> list[ChangeRecord]: def get_entity_version_history(self, entity_type: str, entity_id: str) -> list[ChangeRecord]:
@@ -838,16 +838,16 @@ class CollaborationManager:
permissions = self._get_default_permissions(role) permissions = self._get_default_permissions(role)
member = TeamMember( member = TeamMember(
id = member_id, id=member_id,
project_id = project_id, project_id=project_id,
user_id = user_id, user_id=user_id,
user_name = user_name, user_name=user_name,
user_email = user_email, user_email=user_email,
role = role, role=role,
joined_at = now, joined_at=now,
invited_by = invited_by, invited_by=invited_by,
last_active_at = None, last_active_at=None,
permissions = permissions, permissions=permissions,
) )
if self.db: if self.db:
@@ -913,16 +913,16 @@ class CollaborationManager:
def _row_to_team_member(self, row) -> TeamMember: def _row_to_team_member(self, row) -> TeamMember:
"""将数据库行转换为TeamMember对象""" """将数据库行转换为TeamMember对象"""
return TeamMember( return TeamMember(
id = row[0], id=row[0],
project_id = row[1], project_id=row[1],
user_id = row[2], user_id=row[2],
user_name = row[3], user_name=row[3],
user_email = row[4], user_email=row[4],
role = row[5], role=row[5],
joined_at = row[6], joined_at=row[6],
invited_by = row[7], invited_by=row[7],
last_active_at = row[8], last_active_at=row[8],
permissions = json.loads(row[9]) if row[9] else [], permissions=json.loads(row[9]) if row[9] else [],
) )
def update_member_role(self, member_id: str, new_role: str, updated_by: str) -> bool: def update_member_role(self, member_id: str, new_role: str, updated_by: str) -> bool:
@@ -996,7 +996,7 @@ class CollaborationManager:
_collaboration_manager = None _collaboration_manager = None
def get_collaboration_manager(db_manager = None) -> None: def get_collaboration_manager(db_manager=None) -> None:
"""获取协作管理器单例""" """获取协作管理器单例"""
global _collaboration_manager global _collaboration_manager
if _collaboration_manager is None: if _collaboration_manager is None:

View File

@@ -118,7 +118,7 @@ class EntityMention:
class DatabaseManager: class DatabaseManager:
def __init__(self, db_path: str = DB_PATH) -> None: def __init__(self, db_path: str = DB_PATH) -> None:
self.db_path = db_path self.db_path = db_path
os.makedirs(os.path.dirname(db_path), exist_ok = True) os.makedirs(os.path.dirname(db_path), exist_ok=True)
self.init_db() self.init_db()
def get_conn(self) -> None: def get_conn(self) -> None:
@@ -149,7 +149,7 @@ class DatabaseManager:
conn.commit() conn.commit()
conn.close() conn.close()
return Project( return Project(
id = project_id, name = name, description = description, created_at = now, updated_at = now id=project_id, name=name, description=description, created_at=now, updated_at=now
) )
def get_project(self, project_id: str) -> Project | None: def get_project(self, project_id: str) -> Project | None:
@@ -708,7 +708,7 @@ class DatabaseManager:
) )
conn.close() conn.close()
timeline_events.sort(key = lambda x: x["event_date"]) timeline_events.sort(key=lambda x: x["event_date"])
return timeline_events return timeline_events
def get_entity_timeline_summary(self, project_id: str) -> dict: def get_entity_timeline_summary(self, project_id: str) -> dict:

View File

@@ -382,26 +382,26 @@ class DeveloperEcosystemManager:
now = datetime.now().isoformat() now = datetime.now().isoformat()
sdk = SDKRelease( sdk = SDKRelease(
id = sdk_id, id=sdk_id,
name = name, name=name,
language = language, language=language,
version = version, version=version,
description = description, description=description,
changelog = changelog, changelog=changelog,
download_url = download_url, download_url=download_url,
documentation_url = documentation_url, documentation_url=documentation_url,
repository_url = repository_url, repository_url=repository_url,
package_name = package_name, package_name=package_name,
status = SDKStatus.DRAFT, status=SDKStatus.DRAFT,
min_platform_version = min_platform_version, min_platform_version=min_platform_version,
dependencies = dependencies, dependencies=dependencies,
file_size = file_size, file_size=file_size,
checksum = checksum, checksum=checksum,
download_count = 0, download_count=0,
created_at = now, created_at=now,
updated_at = now, updated_at=now,
published_at = None, published_at=None,
created_by = created_by, created_by=created_by,
) )
with self._get_db() as conn: with self._get_db() as conn:
@@ -585,17 +585,17 @@ class DeveloperEcosystemManager:
conn.commit() conn.commit()
return SDKVersion( return SDKVersion(
id = version_id, id=version_id,
sdk_id = sdk_id, sdk_id=sdk_id,
version = version, version=version,
is_latest = True, is_latest=True,
is_lts = is_lts, is_lts=is_lts,
release_notes = release_notes, release_notes=release_notes,
download_url = download_url, download_url=download_url,
checksum = checksum, checksum=checksum,
file_size = file_size, file_size=file_size,
download_count = 0, download_count=0,
created_at = now, created_at=now,
) )
# ==================== 模板市场 ==================== # ==================== 模板市场 ====================
@@ -625,32 +625,32 @@ class DeveloperEcosystemManager:
now = datetime.now().isoformat() now = datetime.now().isoformat()
template = TemplateMarketItem( template = TemplateMarketItem(
id = template_id, id=template_id,
name = name, name=name,
description = description, description=description,
category = category, category=category,
subcategory = subcategory, subcategory=subcategory,
tags = tags, tags=tags,
author_id = author_id, author_id=author_id,
author_name = author_name, author_name=author_name,
status = TemplateStatus.PENDING, status=TemplateStatus.PENDING,
price = price, price=price,
currency = currency, currency=currency,
preview_image_url = preview_image_url, preview_image_url=preview_image_url,
demo_url = demo_url, demo_url=demo_url,
documentation_url = documentation_url, documentation_url=documentation_url,
download_url = download_url, download_url=download_url,
install_count = 0, install_count=0,
rating = 0.0, rating=0.0,
rating_count = 0, rating_count=0,
review_count = 0, review_count=0,
version = version, version=version,
min_platform_version = min_platform_version, min_platform_version=min_platform_version,
file_size = file_size, file_size=file_size,
checksum = checksum, checksum=checksum,
created_at = now, created_at=now,
updated_at = now, updated_at=now,
published_at = None, published_at=None,
) )
with self._get_db() as conn: with self._get_db() as conn:
@@ -832,16 +832,16 @@ class DeveloperEcosystemManager:
now = datetime.now().isoformat() now = datetime.now().isoformat()
review = TemplateReview( review = TemplateReview(
id = review_id, id=review_id,
template_id = template_id, template_id=template_id,
user_id = user_id, user_id=user_id,
user_name = user_name, user_name=user_name,
rating = rating, rating=rating,
comment = comment, comment=comment,
is_verified_purchase = is_verified_purchase, is_verified_purchase=is_verified_purchase,
helpful_count = 0, helpful_count=0,
created_at = now, created_at=now,
updated_at = now, updated_at=now,
) )
with self._get_db() as conn: with self._get_db() as conn:
@@ -940,39 +940,39 @@ class DeveloperEcosystemManager:
now = datetime.now().isoformat() now = datetime.now().isoformat()
plugin = PluginMarketItem( plugin = PluginMarketItem(
id = plugin_id, id=plugin_id,
name = name, name=name,
description = description, description=description,
category = category, category=category,
tags = tags, tags=tags,
author_id = author_id, author_id=author_id,
author_name = author_name, author_name=author_name,
status = PluginStatus.PENDING, status=PluginStatus.PENDING,
price = price, price=price,
currency = currency, currency=currency,
pricing_model = pricing_model, pricing_model=pricing_model,
preview_image_url = preview_image_url, preview_image_url=preview_image_url,
demo_url = demo_url, demo_url=demo_url,
documentation_url = documentation_url, documentation_url=documentation_url,
repository_url = repository_url, repository_url=repository_url,
download_url = download_url, download_url=download_url,
webhook_url = webhook_url, webhook_url=webhook_url,
permissions = permissions or [], permissions=permissions or [],
install_count = 0, install_count=0,
active_install_count = 0, active_install_count=0,
rating = 0.0, rating=0.0,
rating_count = 0, rating_count=0,
review_count = 0, review_count=0,
version = version, version=version,
min_platform_version = min_platform_version, min_platform_version=min_platform_version,
file_size = file_size, file_size=file_size,
checksum = checksum, checksum=checksum,
created_at = now, created_at=now,
updated_at = now, updated_at=now,
published_at = None, published_at=None,
reviewed_by = None, reviewed_by=None,
reviewed_at = None, reviewed_at=None,
review_notes = None, review_notes=None,
) )
with self._get_db() as conn: with self._get_db() as conn:
@@ -1148,16 +1148,16 @@ class DeveloperEcosystemManager:
now = datetime.now().isoformat() now = datetime.now().isoformat()
review = PluginReview( review = PluginReview(
id = review_id, id=review_id,
plugin_id = plugin_id, plugin_id=plugin_id,
user_id = user_id, user_id=user_id,
user_name = user_name, user_name=user_name,
rating = rating, rating=rating,
comment = comment, comment=comment,
is_verified_purchase = is_verified_purchase, is_verified_purchase=is_verified_purchase,
helpful_count = 0, helpful_count=0,
created_at = now, created_at=now,
updated_at = now, updated_at=now,
) )
with self._get_db() as conn: with self._get_db() as conn:
@@ -1246,18 +1246,18 @@ class DeveloperEcosystemManager:
developer_earnings = sale_amount - platform_fee developer_earnings = sale_amount - platform_fee
revenue = DeveloperRevenue( revenue = DeveloperRevenue(
id = revenue_id, id=revenue_id,
developer_id = developer_id, developer_id=developer_id,
item_type = item_type, item_type=item_type,
item_id = item_id, item_id=item_id,
item_name = item_name, item_name=item_name,
sale_amount = sale_amount, sale_amount=sale_amount,
platform_fee = platform_fee, platform_fee=platform_fee,
developer_earnings = developer_earnings, developer_earnings=developer_earnings,
currency = currency, currency=currency,
buyer_id = buyer_id, buyer_id=buyer_id,
transaction_id = transaction_id, transaction_id=transaction_id,
created_at = now, created_at=now,
) )
with self._get_db() as conn: with self._get_db() as conn:
@@ -1362,24 +1362,24 @@ class DeveloperEcosystemManager:
now = datetime.now().isoformat() now = datetime.now().isoformat()
profile = DeveloperProfile( profile = DeveloperProfile(
id = profile_id, id=profile_id,
user_id = user_id, user_id=user_id,
display_name = display_name, display_name=display_name,
email = email, email=email,
bio = bio, bio=bio,
website = website, website=website,
github_url = github_url, github_url=github_url,
avatar_url = avatar_url, avatar_url=avatar_url,
status = DeveloperStatus.UNVERIFIED, status=DeveloperStatus.UNVERIFIED,
verification_documents = {}, verification_documents={},
total_sales = 0.0, total_sales=0.0,
total_downloads = 0, total_downloads=0,
plugin_count = 0, plugin_count=0,
template_count = 0, template_count=0,
rating_average = 0.0, rating_average=0.0,
created_at = now, created_at=now,
updated_at = now, updated_at=now,
verified_at = None, verified_at=None,
) )
with self._get_db() as conn: with self._get_db() as conn:
@@ -1526,23 +1526,23 @@ class DeveloperEcosystemManager:
now = datetime.now().isoformat() now = datetime.now().isoformat()
example = CodeExample( example = CodeExample(
id = example_id, id=example_id,
title = title, title=title,
description = description, description=description,
language = language, language=language,
category = category, category=category,
code = code, code=code,
explanation = explanation, explanation=explanation,
tags = tags, tags=tags,
author_id = author_id, author_id=author_id,
author_name = author_name, author_name=author_name,
sdk_id = sdk_id, sdk_id=sdk_id,
api_endpoints = api_endpoints or [], api_endpoints=api_endpoints or [],
view_count = 0, view_count=0,
copy_count = 0, copy_count=0,
rating = 0.0, rating=0.0,
created_at = now, created_at=now,
updated_at = now, updated_at=now,
) )
with self._get_db() as conn: with self._get_db() as conn:
@@ -1659,14 +1659,14 @@ class DeveloperEcosystemManager:
now = datetime.now().isoformat() now = datetime.now().isoformat()
doc = APIDocumentation( doc = APIDocumentation(
id = doc_id, id=doc_id,
version = version, version=version,
openapi_spec = openapi_spec, openapi_spec=openapi_spec,
markdown_content = markdown_content, markdown_content=markdown_content,
html_content = html_content, html_content=html_content,
changelog = changelog, changelog=changelog,
generated_at = now, generated_at=now,
generated_by = generated_by, generated_by=generated_by,
) )
with self._get_db() as conn: with self._get_db() as conn:
@@ -1736,24 +1736,24 @@ class DeveloperEcosystemManager:
now = datetime.now().isoformat() now = datetime.now().isoformat()
config = DeveloperPortalConfig( config = DeveloperPortalConfig(
id = config_id, id=config_id,
name = name, name=name,
description = description, description=description,
theme = theme, theme=theme,
custom_css = custom_css, custom_css=custom_css,
custom_js = custom_js, custom_js=custom_js,
logo_url = logo_url, logo_url=logo_url,
favicon_url = favicon_url, favicon_url=favicon_url,
primary_color = primary_color, primary_color=primary_color,
secondary_color = secondary_color, secondary_color=secondary_color,
support_email = support_email, support_email=support_email,
support_url = support_url, support_url=support_url,
github_url = github_url, github_url=github_url,
discord_url = discord_url, discord_url=discord_url,
api_base_url = api_base_url, api_base_url=api_base_url,
is_active = True, is_active=True,
created_at = now, created_at=now,
updated_at = now, updated_at=now,
) )
with self._get_db() as conn: with self._get_db() as conn:
@@ -1817,239 +1817,239 @@ class DeveloperEcosystemManager:
def _row_to_sdk_release(self, row) -> SDKRelease: def _row_to_sdk_release(self, row) -> SDKRelease:
"""将数据库行转换为 SDKRelease""" """将数据库行转换为 SDKRelease"""
return SDKRelease( return SDKRelease(
id = row["id"], id=row["id"],
name = row["name"], name=row["name"],
language = SDKLanguage(row["language"]), language=SDKLanguage(row["language"]),
version = row["version"], version=row["version"],
description = row["description"], description=row["description"],
changelog = row["changelog"], changelog=row["changelog"],
download_url = row["download_url"], download_url=row["download_url"],
documentation_url = row["documentation_url"], documentation_url=row["documentation_url"],
repository_url = row["repository_url"], repository_url=row["repository_url"],
package_name = row["package_name"], package_name=row["package_name"],
status = SDKStatus(row["status"]), status=SDKStatus(row["status"]),
min_platform_version = row["min_platform_version"], min_platform_version=row["min_platform_version"],
dependencies = json.loads(row["dependencies"]), dependencies=json.loads(row["dependencies"]),
file_size = row["file_size"], file_size=row["file_size"],
checksum = row["checksum"], checksum=row["checksum"],
download_count = row["download_count"], download_count=row["download_count"],
created_at = row["created_at"], created_at=row["created_at"],
updated_at = row["updated_at"], updated_at=row["updated_at"],
published_at = row["published_at"], published_at=row["published_at"],
created_by = row["created_by"], created_by=row["created_by"],
) )
def _row_to_sdk_version(self, row) -> SDKVersion: def _row_to_sdk_version(self, row) -> SDKVersion:
"""将数据库行转换为 SDKVersion""" """将数据库行转换为 SDKVersion"""
return SDKVersion( return SDKVersion(
id = row["id"], id=row["id"],
sdk_id = row["sdk_id"], sdk_id=row["sdk_id"],
version = row["version"], version=row["version"],
is_latest = bool(row["is_latest"]), is_latest=bool(row["is_latest"]),
is_lts = bool(row["is_lts"]), is_lts=bool(row["is_lts"]),
release_notes = row["release_notes"], release_notes=row["release_notes"],
download_url = row["download_url"], download_url=row["download_url"],
checksum = row["checksum"], checksum=row["checksum"],
file_size = row["file_size"], file_size=row["file_size"],
download_count = row["download_count"], download_count=row["download_count"],
created_at = row["created_at"], created_at=row["created_at"],
) )
def _row_to_template(self, row) -> TemplateMarketItem: def _row_to_template(self, row) -> TemplateMarketItem:
"""将数据库行转换为 TemplateMarketItem""" """将数据库行转换为 TemplateMarketItem"""
return TemplateMarketItem( return TemplateMarketItem(
id = row["id"], id=row["id"],
name = row["name"], name=row["name"],
description = row["description"], description=row["description"],
category = TemplateCategory(row["category"]), category=TemplateCategory(row["category"]),
subcategory = row["subcategory"], subcategory=row["subcategory"],
tags = json.loads(row["tags"]), tags=json.loads(row["tags"]),
author_id = row["author_id"], author_id=row["author_id"],
author_name = row["author_name"], author_name=row["author_name"],
status = TemplateStatus(row["status"]), status=TemplateStatus(row["status"]),
price = row["price"], price=row["price"],
currency = row["currency"], currency=row["currency"],
preview_image_url = row["preview_image_url"], preview_image_url=row["preview_image_url"],
demo_url = row["demo_url"], demo_url=row["demo_url"],
documentation_url = row["documentation_url"], documentation_url=row["documentation_url"],
download_url = row["download_url"], download_url=row["download_url"],
install_count = row["install_count"], install_count=row["install_count"],
rating = row["rating"], rating=row["rating"],
rating_count = row["rating_count"], rating_count=row["rating_count"],
review_count = row["review_count"], review_count=row["review_count"],
version = row["version"], version=row["version"],
min_platform_version = row["min_platform_version"], min_platform_version=row["min_platform_version"],
file_size = row["file_size"], file_size=row["file_size"],
checksum = row["checksum"], checksum=row["checksum"],
created_at = row["created_at"], created_at=row["created_at"],
updated_at = row["updated_at"], updated_at=row["updated_at"],
published_at = row["published_at"], published_at=row["published_at"],
) )
def _row_to_template_review(self, row) -> TemplateReview: def _row_to_template_review(self, row) -> TemplateReview:
"""将数据库行转换为 TemplateReview""" """将数据库行转换为 TemplateReview"""
return TemplateReview( return TemplateReview(
id = row["id"], id=row["id"],
template_id = row["template_id"], template_id=row["template_id"],
user_id = row["user_id"], user_id=row["user_id"],
user_name = row["user_name"], user_name=row["user_name"],
rating = row["rating"], rating=row["rating"],
comment = row["comment"], comment=row["comment"],
is_verified_purchase = bool(row["is_verified_purchase"]), is_verified_purchase=bool(row["is_verified_purchase"]),
helpful_count = row["helpful_count"], helpful_count=row["helpful_count"],
created_at = row["created_at"], created_at=row["created_at"],
updated_at = row["updated_at"], updated_at=row["updated_at"],
) )
def _row_to_plugin(self, row) -> PluginMarketItem: def _row_to_plugin(self, row) -> PluginMarketItem:
"""将数据库行转换为 PluginMarketItem""" """将数据库行转换为 PluginMarketItem"""
return PluginMarketItem( return PluginMarketItem(
id = row["id"], id=row["id"],
name = row["name"], name=row["name"],
description = row["description"], description=row["description"],
category = PluginCategory(row["category"]), category=PluginCategory(row["category"]),
tags = json.loads(row["tags"]), tags=json.loads(row["tags"]),
author_id = row["author_id"], author_id=row["author_id"],
author_name = row["author_name"], author_name=row["author_name"],
status = PluginStatus(row["status"]), status=PluginStatus(row["status"]),
price = row["price"], price=row["price"],
currency = row["currency"], currency=row["currency"],
pricing_model = row["pricing_model"], pricing_model=row["pricing_model"],
preview_image_url = row["preview_image_url"], preview_image_url=row["preview_image_url"],
demo_url = row["demo_url"], demo_url=row["demo_url"],
documentation_url = row["documentation_url"], documentation_url=row["documentation_url"],
repository_url = row["repository_url"], repository_url=row["repository_url"],
download_url = row["download_url"], download_url=row["download_url"],
webhook_url = row["webhook_url"], webhook_url=row["webhook_url"],
permissions = json.loads(row["permissions"]), permissions=json.loads(row["permissions"]),
install_count = row["install_count"], install_count=row["install_count"],
active_install_count = row["active_install_count"], active_install_count=row["active_install_count"],
rating = row["rating"], rating=row["rating"],
rating_count = row["rating_count"], rating_count=row["rating_count"],
review_count = row["review_count"], review_count=row["review_count"],
version = row["version"], version=row["version"],
min_platform_version = row["min_platform_version"], min_platform_version=row["min_platform_version"],
file_size = row["file_size"], file_size=row["file_size"],
checksum = row["checksum"], checksum=row["checksum"],
created_at = row["created_at"], created_at=row["created_at"],
updated_at = row["updated_at"], updated_at=row["updated_at"],
published_at = row["published_at"], published_at=row["published_at"],
reviewed_by = row["reviewed_by"], reviewed_by=row["reviewed_by"],
reviewed_at = row["reviewed_at"], reviewed_at=row["reviewed_at"],
review_notes = row["review_notes"], review_notes=row["review_notes"],
) )
def _row_to_plugin_review(self, row) -> PluginReview: def _row_to_plugin_review(self, row) -> PluginReview:
"""将数据库行转换为 PluginReview""" """将数据库行转换为 PluginReview"""
return PluginReview( return PluginReview(
id = row["id"], id=row["id"],
plugin_id = row["plugin_id"], plugin_id=row["plugin_id"],
user_id = row["user_id"], user_id=row["user_id"],
user_name = row["user_name"], user_name=row["user_name"],
rating = row["rating"], rating=row["rating"],
comment = row["comment"], comment=row["comment"],
is_verified_purchase = bool(row["is_verified_purchase"]), is_verified_purchase=bool(row["is_verified_purchase"]),
helpful_count = row["helpful_count"], helpful_count=row["helpful_count"],
created_at = row["created_at"], created_at=row["created_at"],
updated_at = row["updated_at"], updated_at=row["updated_at"],
) )
def _row_to_developer_profile(self, row) -> DeveloperProfile: def _row_to_developer_profile(self, row) -> DeveloperProfile:
"""将数据库行转换为 DeveloperProfile""" """将数据库行转换为 DeveloperProfile"""
return DeveloperProfile( return DeveloperProfile(
id = row["id"], id=row["id"],
user_id = row["user_id"], user_id=row["user_id"],
display_name = row["display_name"], display_name=row["display_name"],
email = row["email"], email=row["email"],
bio = row["bio"], bio=row["bio"],
website = row["website"], website=row["website"],
github_url = row["github_url"], github_url=row["github_url"],
avatar_url = row["avatar_url"], avatar_url=row["avatar_url"],
status = DeveloperStatus(row["status"]), status=DeveloperStatus(row["status"]),
verification_documents = json.loads(row["verification_documents"]), verification_documents=json.loads(row["verification_documents"]),
total_sales = row["total_sales"], total_sales=row["total_sales"],
total_downloads = row["total_downloads"], total_downloads=row["total_downloads"],
plugin_count = row["plugin_count"], plugin_count=row["plugin_count"],
template_count = row["template_count"], template_count=row["template_count"],
rating_average = row["rating_average"], rating_average=row["rating_average"],
created_at = row["created_at"], created_at=row["created_at"],
updated_at = row["updated_at"], updated_at=row["updated_at"],
verified_at = row["verified_at"], verified_at=row["verified_at"],
) )
def _row_to_developer_revenue(self, row) -> DeveloperRevenue: def _row_to_developer_revenue(self, row) -> DeveloperRevenue:
"""将数据库行转换为 DeveloperRevenue""" """将数据库行转换为 DeveloperRevenue"""
return DeveloperRevenue( return DeveloperRevenue(
id = row["id"], id=row["id"],
developer_id = row["developer_id"], developer_id=row["developer_id"],
item_type = row["item_type"], item_type=row["item_type"],
item_id = row["item_id"], item_id=row["item_id"],
item_name = row["item_name"], item_name=row["item_name"],
sale_amount = row["sale_amount"], sale_amount=row["sale_amount"],
platform_fee = row["platform_fee"], platform_fee=row["platform_fee"],
developer_earnings = row["developer_earnings"], developer_earnings=row["developer_earnings"],
currency = row["currency"], currency=row["currency"],
buyer_id = row["buyer_id"], buyer_id=row["buyer_id"],
transaction_id = row["transaction_id"], transaction_id=row["transaction_id"],
created_at = row["created_at"], created_at=row["created_at"],
) )
def _row_to_code_example(self, row) -> CodeExample: def _row_to_code_example(self, row) -> CodeExample:
"""将数据库行转换为 CodeExample""" """将数据库行转换为 CodeExample"""
return CodeExample( return CodeExample(
id = row["id"], id=row["id"],
title = row["title"], title=row["title"],
description = row["description"], description=row["description"],
language = row["language"], language=row["language"],
category = row["category"], category=row["category"],
code = row["code"], code=row["code"],
explanation = row["explanation"], explanation=row["explanation"],
tags = json.loads(row["tags"]), tags=json.loads(row["tags"]),
author_id = row["author_id"], author_id=row["author_id"],
author_name = row["author_name"], author_name=row["author_name"],
sdk_id = row["sdk_id"], sdk_id=row["sdk_id"],
api_endpoints = json.loads(row["api_endpoints"]), api_endpoints=json.loads(row["api_endpoints"]),
view_count = row["view_count"], view_count=row["view_count"],
copy_count = row["copy_count"], copy_count=row["copy_count"],
rating = row["rating"], rating=row["rating"],
created_at = row["created_at"], created_at=row["created_at"],
updated_at = row["updated_at"], updated_at=row["updated_at"],
) )
def _row_to_api_documentation(self, row) -> APIDocumentation: def _row_to_api_documentation(self, row) -> APIDocumentation:
"""将数据库行转换为 APIDocumentation""" """将数据库行转换为 APIDocumentation"""
return APIDocumentation( return APIDocumentation(
id = row["id"], id=row["id"],
version = row["version"], version=row["version"],
openapi_spec = row["openapi_spec"], openapi_spec=row["openapi_spec"],
markdown_content = row["markdown_content"], markdown_content=row["markdown_content"],
html_content = row["html_content"], html_content=row["html_content"],
changelog = row["changelog"], changelog=row["changelog"],
generated_at = row["generated_at"], generated_at=row["generated_at"],
generated_by = row["generated_by"], generated_by=row["generated_by"],
) )
def _row_to_portal_config(self, row) -> DeveloperPortalConfig: def _row_to_portal_config(self, row) -> DeveloperPortalConfig:
"""将数据库行转换为 DeveloperPortalConfig""" """将数据库行转换为 DeveloperPortalConfig"""
return DeveloperPortalConfig( return DeveloperPortalConfig(
id = row["id"], id=row["id"],
name = row["name"], name=row["name"],
description = row["description"], description=row["description"],
theme = row["theme"], theme=row["theme"],
custom_css = row["custom_css"], custom_css=row["custom_css"],
custom_js = row["custom_js"], custom_js=row["custom_js"],
logo_url = row["logo_url"], logo_url=row["logo_url"],
favicon_url = row["favicon_url"], favicon_url=row["favicon_url"],
primary_color = row["primary_color"], primary_color=row["primary_color"],
secondary_color = row["secondary_color"], secondary_color=row["secondary_color"],
support_email = row["support_email"], support_email=row["support_email"],
support_url = row["support_url"], support_url=row["support_url"],
github_url = row["github_url"], github_url=row["github_url"],
discord_url = row["discord_url"], discord_url=row["discord_url"],
api_base_url = row["api_base_url"], api_base_url=row["api_base_url"],
is_active = bool(row["is_active"]), is_active=bool(row["is_active"]),
created_at = row["created_at"], created_at=row["created_at"],
updated_at = row["updated_at"], updated_at=row["updated_at"],
) )

View File

@@ -123,7 +123,7 @@ class DocumentProcessor:
continue continue
# 如果都失败了,使用 latin-1 并忽略错误 # 如果都失败了,使用 latin-1 并忽略错误
return content.decode("latin-1", errors = "ignore") return content.decode("latin-1", errors="ignore")
def _clean_text(self, text: str) -> str: def _clean_text(self, text: str) -> str:
"""清理提取的文本""" """清理提取的文本"""
@@ -173,7 +173,7 @@ class SimpleTextExtractor:
except UnicodeDecodeError: except UnicodeDecodeError:
continue continue
return content.decode("latin-1", errors = "ignore") return content.decode("latin-1", errors="ignore")
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -610,30 +610,30 @@ class EnterpriseManager:
attribute_mapping = self.DEFAULT_ATTRIBUTE_MAPPING[SSOProvider(provider)] attribute_mapping = self.DEFAULT_ATTRIBUTE_MAPPING[SSOProvider(provider)]
config = SSOConfig( config = SSOConfig(
id = config_id, id=config_id,
tenant_id = tenant_id, tenant_id=tenant_id,
provider = provider, provider=provider,
status = SSOStatus.PENDING.value, status=SSOStatus.PENDING.value,
entity_id = entity_id, entity_id=entity_id,
sso_url = sso_url, sso_url=sso_url,
slo_url = slo_url, slo_url=slo_url,
certificate = certificate, certificate=certificate,
metadata_url = metadata_url, metadata_url=metadata_url,
metadata_xml = metadata_xml, metadata_xml=metadata_xml,
client_id = client_id, client_id=client_id,
client_secret = client_secret, client_secret=client_secret,
authorization_url = authorization_url, authorization_url=authorization_url,
token_url = token_url, token_url=token_url,
userinfo_url = userinfo_url, userinfo_url=userinfo_url,
scopes = scopes or ["openid", "email", "profile"], scopes=scopes or ["openid", "email", "profile"],
attribute_mapping = attribute_mapping or {}, attribute_mapping=attribute_mapping or {},
auto_provision = auto_provision, auto_provision=auto_provision,
default_role = default_role, default_role=default_role,
domain_restriction = domain_restriction or [], domain_restriction=domain_restriction or [],
created_at = now, created_at=now,
updated_at = now, updated_at=now,
last_tested_at = None, last_tested_at=None,
last_error = None, last_error=None,
) )
cursor = conn.cursor() cursor = conn.cursor()
@@ -878,18 +878,18 @@ class EnterpriseManager:
try: try:
request_id = f"_{uuid.uuid4().hex}" request_id = f"_{uuid.uuid4().hex}"
now = datetime.now() now = datetime.now()
expires = now + timedelta(minutes = 10) expires = now + timedelta(minutes=10)
auth_request = SAMLAuthRequest( auth_request = SAMLAuthRequest(
id = str(uuid.uuid4()), id=str(uuid.uuid4()),
tenant_id = tenant_id, tenant_id=tenant_id,
sso_config_id = config_id, sso_config_id=config_id,
request_id = request_id, request_id=request_id,
relay_state = relay_state, relay_state=relay_state,
created_at = now, created_at=now,
expires_at = expires, expires_at=expires,
used = False, used=False,
used_at = None, used_at=None,
) )
cursor = conn.cursor() cursor = conn.cursor()
@@ -949,17 +949,17 @@ class EnterpriseManager:
attributes = self._parse_saml_response(saml_response) attributes = self._parse_saml_response(saml_response)
auth_response = SAMLAuthResponse( auth_response = SAMLAuthResponse(
id = str(uuid.uuid4()), id=str(uuid.uuid4()),
request_id = request_id, request_id=request_id,
tenant_id = "", # 从 request 获取 tenant_id="", # 从 request 获取
user_id = None, user_id=None,
email = attributes.get("email"), email=attributes.get("email"),
name = attributes.get("name"), name=attributes.get("name"),
attributes = attributes, attributes=attributes,
session_index = attributes.get("session_index"), session_index=attributes.get("session_index"),
processed = False, processed=False,
processed_at = None, processed_at=None,
created_at = datetime.now(), created_at=datetime.now(),
) )
cursor = conn.cursor() cursor = conn.cursor()
@@ -1028,21 +1028,21 @@ class EnterpriseManager:
now = datetime.now() now = datetime.now()
config = SCIMConfig( config = SCIMConfig(
id = config_id, id=config_id,
tenant_id = tenant_id, tenant_id=tenant_id,
provider = provider, provider=provider,
status = "disabled", status="disabled",
scim_base_url = scim_base_url, scim_base_url=scim_base_url,
scim_token = scim_token, scim_token=scim_token,
sync_interval_minutes = sync_interval_minutes, sync_interval_minutes=sync_interval_minutes,
last_sync_at = None, last_sync_at=None,
last_sync_status = None, last_sync_status=None,
last_sync_error = None, last_sync_error=None,
last_sync_users_count = 0, last_sync_users_count=0,
attribute_mapping = attribute_mapping or {}, attribute_mapping=attribute_mapping or {},
sync_rules = sync_rules or {}, sync_rules=sync_rules or {},
created_at = now, created_at=now,
updated_at = now, updated_at=now,
) )
cursor = conn.cursor() cursor = conn.cursor()
@@ -1325,28 +1325,28 @@ class EnterpriseManager:
now = datetime.now() now = datetime.now()
# 默认7天后过期 # 默认7天后过期
expires_at = now + timedelta(days = 7) expires_at = now + timedelta(days=7)
export = AuditLogExport( export = AuditLogExport(
id = export_id, id=export_id,
tenant_id = tenant_id, tenant_id=tenant_id,
export_format = export_format, export_format=export_format,
start_date = start_date, start_date=start_date,
end_date = end_date, end_date=end_date,
filters = filters or {}, filters=filters or {},
compliance_standard = compliance_standard, compliance_standard=compliance_standard,
status = "pending", status="pending",
file_path = None, file_path=None,
file_size = None, file_size=None,
record_count = None, record_count=None,
checksum = None, checksum=None,
downloaded_by = None, downloaded_by=None,
downloaded_at = None, downloaded_at=None,
expires_at = expires_at, expires_at=expires_at,
created_by = created_by, created_by=created_by,
created_at = now, created_at=now,
completed_at = None, completed_at=None,
error_message = None, error_message=None,
) )
cursor = conn.cursor() cursor = conn.cursor()
@@ -1383,7 +1383,7 @@ class EnterpriseManager:
finally: finally:
conn.close() conn.close()
def process_audit_export(self, export_id: str, db_manager = None) -> AuditLogExport | None: def process_audit_export(self, export_id: str, db_manager=None) -> AuditLogExport | None:
"""处理审计日志导出任务""" """处理审计日志导出任务"""
export = self.get_audit_export(export_id) export = self.get_audit_export(export_id)
if not export: if not export:
@@ -1454,7 +1454,7 @@ class EnterpriseManager:
start_date: datetime, start_date: datetime,
end_date: datetime, end_date: datetime,
filters: dict[str, Any], filters: dict[str, Any],
db_manager = None, db_manager=None,
) -> list[dict[str, Any]]: ) -> list[dict[str, Any]]:
"""获取审计日志数据""" """获取审计日志数据"""
if db_manager is None: if db_manager is None:
@@ -1488,26 +1488,26 @@ class EnterpriseManager:
import os import os
export_dir = "/tmp/insightflow/exports" export_dir = "/tmp/insightflow/exports"
os.makedirs(export_dir, exist_ok = True) os.makedirs(export_dir, exist_ok=True)
file_path = f"{export_dir}/audit_export_{export_id}.{format}" file_path = f"{export_dir}/audit_export_{export_id}.{format}"
if format == "json": if format == "json":
content = json.dumps(logs, ensure_ascii = False, indent = 2) content = json.dumps(logs, ensure_ascii=False, indent=2)
with open(file_path, "w", encoding = "utf-8") as f: with open(file_path, "w", encoding="utf-8") as f:
f.write(content) f.write(content)
elif format == "csv": elif format == "csv":
import csv import csv
if logs: if logs:
with open(file_path, "w", newline = "", encoding = "utf-8") as f: with open(file_path, "w", newline="", encoding="utf-8") as f:
writer = csv.DictWriter(f, fieldnames = logs[0].keys()) writer = csv.DictWriter(f, fieldnames=logs[0].keys())
writer.writeheader() writer.writeheader()
writer.writerows(logs) writer.writerows(logs)
else: else:
# 其他格式暂不支持 # 其他格式暂不支持
content = json.dumps(logs, ensure_ascii = False) content = json.dumps(logs, ensure_ascii=False)
with open(file_path, "w", encoding = "utf-8") as f: with open(file_path, "w", encoding="utf-8") as f:
f.write(content) f.write(content)
file_size = os.path.getsize(file_path) file_size = os.path.getsize(file_path)
@@ -1596,24 +1596,24 @@ class EnterpriseManager:
now = datetime.now() now = datetime.now()
policy = DataRetentionPolicy( policy = DataRetentionPolicy(
id = policy_id, id=policy_id,
tenant_id = tenant_id, tenant_id=tenant_id,
name = name, name=name,
description = description, description=description,
resource_type = resource_type, resource_type=resource_type,
retention_days = retention_days, retention_days=retention_days,
action = action, action=action,
conditions = conditions or {}, conditions=conditions or {},
auto_execute = auto_execute, auto_execute=auto_execute,
execute_at = execute_at, execute_at=execute_at,
notify_before_days = notify_before_days, notify_before_days=notify_before_days,
archive_location = archive_location, archive_location=archive_location,
archive_encryption = archive_encryption, archive_encryption=archive_encryption,
is_active = True, is_active=True,
last_executed_at = None, last_executed_at=None,
last_execution_result = None, last_execution_result=None,
created_at = now, created_at=now,
updated_at = now, updated_at=now,
) )
cursor = conn.cursor() cursor = conn.cursor()
@@ -1776,18 +1776,18 @@ class EnterpriseManager:
now = datetime.now() now = datetime.now()
job = DataRetentionJob( job = DataRetentionJob(
id = job_id, id=job_id,
policy_id = policy_id, policy_id=policy_id,
tenant_id = policy.tenant_id, tenant_id=policy.tenant_id,
status = "running", status="running",
started_at = now, started_at=now,
completed_at = None, completed_at=None,
affected_records = 0, affected_records=0,
archived_records = 0, archived_records=0,
deleted_records = 0, deleted_records=0,
error_count = 0, error_count=0,
details = {}, details={},
created_at = now, created_at=now,
) )
cursor = conn.cursor() cursor = conn.cursor()
@@ -1804,7 +1804,7 @@ class EnterpriseManager:
try: try:
# 计算截止日期 # 计算截止日期
cutoff_date = now - timedelta(days = policy.retention_days) cutoff_date = now - timedelta(days=policy.retention_days)
# 根据资源类型执行不同的处理 # 根据资源类型执行不同的处理
if policy.resource_type == "audit_log": if policy.resource_type == "audit_log":
@@ -1963,64 +1963,64 @@ class EnterpriseManager:
def _row_to_sso_config(self, row: sqlite3.Row) -> SSOConfig: def _row_to_sso_config(self, row: sqlite3.Row) -> SSOConfig:
"""数据库行转换为 SSOConfig 对象""" """数据库行转换为 SSOConfig 对象"""
return SSOConfig( return SSOConfig(
id = row["id"], id=row["id"],
tenant_id = row["tenant_id"], tenant_id=row["tenant_id"],
provider = row["provider"], provider=row["provider"],
status = row["status"], status=row["status"],
entity_id = row["entity_id"], entity_id=row["entity_id"],
sso_url = row["sso_url"], sso_url=row["sso_url"],
slo_url = row["slo_url"], slo_url=row["slo_url"],
certificate = row["certificate"], certificate=row["certificate"],
metadata_url = row["metadata_url"], metadata_url=row["metadata_url"],
metadata_xml = row["metadata_xml"], metadata_xml=row["metadata_xml"],
client_id = row["client_id"], client_id=row["client_id"],
client_secret = row["client_secret"], client_secret=row["client_secret"],
authorization_url = row["authorization_url"], authorization_url=row["authorization_url"],
token_url = row["token_url"], token_url=row["token_url"],
userinfo_url = row["userinfo_url"], userinfo_url=row["userinfo_url"],
scopes = json.loads(row["scopes"] or '["openid", "email", "profile"]'), scopes=json.loads(row["scopes"] or '["openid", "email", "profile"]'),
attribute_mapping = json.loads(row["attribute_mapping"] or "{}"), attribute_mapping=json.loads(row["attribute_mapping"] or "{}"),
auto_provision = bool(row["auto_provision"]), auto_provision=bool(row["auto_provision"]),
default_role = row["default_role"], default_role=row["default_role"],
domain_restriction = json.loads(row["domain_restriction"] or "[]"), domain_restriction=json.loads(row["domain_restriction"] or "[]"),
created_at = ( created_at=(
datetime.fromisoformat(row["created_at"]) datetime.fromisoformat(row["created_at"])
if isinstance(row["created_at"], str) if isinstance(row["created_at"], str)
else row["created_at"] else row["created_at"]
), ),
updated_at = ( updated_at=(
datetime.fromisoformat(row["updated_at"]) datetime.fromisoformat(row["updated_at"])
if isinstance(row["updated_at"], str) if isinstance(row["updated_at"], str)
else row["updated_at"] else row["updated_at"]
), ),
last_tested_at = ( last_tested_at=(
datetime.fromisoformat(row["last_tested_at"]) datetime.fromisoformat(row["last_tested_at"])
if row["last_tested_at"] and isinstance(row["last_tested_at"], str) if row["last_tested_at"] and isinstance(row["last_tested_at"], str)
else row["last_tested_at"] else row["last_tested_at"]
), ),
last_error = row["last_error"], last_error=row["last_error"],
) )
def _row_to_saml_request(self, row: sqlite3.Row) -> SAMLAuthRequest: def _row_to_saml_request(self, row: sqlite3.Row) -> SAMLAuthRequest:
"""数据库行转换为 SAMLAuthRequest 对象""" """数据库行转换为 SAMLAuthRequest 对象"""
return SAMLAuthRequest( return SAMLAuthRequest(
id = row["id"], id=row["id"],
tenant_id = row["tenant_id"], tenant_id=row["tenant_id"],
sso_config_id = row["sso_config_id"], sso_config_id=row["sso_config_id"],
request_id = row["request_id"], request_id=row["request_id"],
relay_state = row["relay_state"], relay_state=row["relay_state"],
created_at = ( created_at=(
datetime.fromisoformat(row["created_at"]) datetime.fromisoformat(row["created_at"])
if isinstance(row["created_at"], str) if isinstance(row["created_at"], str)
else row["created_at"] else row["created_at"]
), ),
expires_at = ( expires_at=(
datetime.fromisoformat(row["expires_at"]) datetime.fromisoformat(row["expires_at"])
if isinstance(row["expires_at"], str) if isinstance(row["expires_at"], str)
else row["expires_at"] else row["expires_at"]
), ),
used = bool(row["used"]), used=bool(row["used"]),
used_at = ( used_at=(
datetime.fromisoformat(row["used_at"]) datetime.fromisoformat(row["used_at"])
if row["used_at"] and isinstance(row["used_at"], str) if row["used_at"] and isinstance(row["used_at"], str)
else row["used_at"] else row["used_at"]
@@ -2030,29 +2030,29 @@ class EnterpriseManager:
def _row_to_scim_config(self, row: sqlite3.Row) -> SCIMConfig: def _row_to_scim_config(self, row: sqlite3.Row) -> SCIMConfig:
"""数据库行转换为 SCIMConfig 对象""" """数据库行转换为 SCIMConfig 对象"""
return SCIMConfig( return SCIMConfig(
id = row["id"], id=row["id"],
tenant_id = row["tenant_id"], tenant_id=row["tenant_id"],
provider = row["provider"], provider=row["provider"],
status = row["status"], status=row["status"],
scim_base_url = row["scim_base_url"], scim_base_url=row["scim_base_url"],
scim_token = row["scim_token"], scim_token=row["scim_token"],
sync_interval_minutes = row["sync_interval_minutes"], sync_interval_minutes=row["sync_interval_minutes"],
last_sync_at = ( last_sync_at=(
datetime.fromisoformat(row["last_sync_at"]) datetime.fromisoformat(row["last_sync_at"])
if row["last_sync_at"] and isinstance(row["last_sync_at"], str) if row["last_sync_at"] and isinstance(row["last_sync_at"], str)
else row["last_sync_at"] else row["last_sync_at"]
), ),
last_sync_status = row["last_sync_status"], last_sync_status=row["last_sync_status"],
last_sync_error = row["last_sync_error"], last_sync_error=row["last_sync_error"],
last_sync_users_count = row["last_sync_users_count"], last_sync_users_count=row["last_sync_users_count"],
attribute_mapping = json.loads(row["attribute_mapping"] or "{}"), attribute_mapping=json.loads(row["attribute_mapping"] or "{}"),
sync_rules = json.loads(row["sync_rules"] or "{}"), sync_rules=json.loads(row["sync_rules"] or "{}"),
created_at = ( created_at=(
datetime.fromisoformat(row["created_at"]) datetime.fromisoformat(row["created_at"])
if isinstance(row["created_at"], str) if isinstance(row["created_at"], str)
else row["created_at"] else row["created_at"]
), ),
updated_at = ( updated_at=(
datetime.fromisoformat(row["updated_at"]) datetime.fromisoformat(row["updated_at"])
if isinstance(row["updated_at"], str) if isinstance(row["updated_at"], str)
else row["updated_at"] else row["updated_at"]
@@ -2062,28 +2062,28 @@ class EnterpriseManager:
def _row_to_scim_user(self, row: sqlite3.Row) -> SCIMUser: def _row_to_scim_user(self, row: sqlite3.Row) -> SCIMUser:
"""数据库行转换为 SCIMUser 对象""" """数据库行转换为 SCIMUser 对象"""
return SCIMUser( return SCIMUser(
id = row["id"], id=row["id"],
tenant_id = row["tenant_id"], tenant_id=row["tenant_id"],
external_id = row["external_id"], external_id=row["external_id"],
user_name = row["user_name"], user_name=row["user_name"],
email = row["email"], email=row["email"],
display_name = row["display_name"], display_name=row["display_name"],
given_name = row["given_name"], given_name=row["given_name"],
family_name = row["family_name"], family_name=row["family_name"],
active = bool(row["active"]), active=bool(row["active"]),
groups = json.loads(row["groups"] or "[]"), groups=json.loads(row["groups"] or "[]"),
raw_data = json.loads(row["raw_data"] or "{}"), raw_data=json.loads(row["raw_data"] or "{}"),
synced_at = ( synced_at=(
datetime.fromisoformat(row["synced_at"]) datetime.fromisoformat(row["synced_at"])
if isinstance(row["synced_at"], str) if isinstance(row["synced_at"], str)
else row["synced_at"] else row["synced_at"]
), ),
created_at = ( created_at=(
datetime.fromisoformat(row["created_at"]) datetime.fromisoformat(row["created_at"])
if isinstance(row["created_at"], str) if isinstance(row["created_at"], str)
else row["created_at"] else row["created_at"]
), ),
updated_at = ( updated_at=(
datetime.fromisoformat(row["updated_at"]) datetime.fromisoformat(row["updated_at"])
if isinstance(row["updated_at"], str) if isinstance(row["updated_at"], str)
else row["updated_at"] else row["updated_at"]
@@ -2093,78 +2093,78 @@ class EnterpriseManager:
def _row_to_audit_export(self, row: sqlite3.Row) -> AuditLogExport: def _row_to_audit_export(self, row: sqlite3.Row) -> AuditLogExport:
"""数据库行转换为 AuditLogExport 对象""" """数据库行转换为 AuditLogExport 对象"""
return AuditLogExport( return AuditLogExport(
id = row["id"], id=row["id"],
tenant_id = row["tenant_id"], tenant_id=row["tenant_id"],
export_format = row["export_format"], export_format=row["export_format"],
start_date = ( start_date=(
datetime.fromisoformat(row["start_date"]) datetime.fromisoformat(row["start_date"])
if isinstance(row["start_date"], str) if isinstance(row["start_date"], str)
else row["start_date"] else row["start_date"]
), ),
end_date = datetime.fromisoformat(row["end_date"]) end_date=datetime.fromisoformat(row["end_date"])
if isinstance(row["end_date"], str) if isinstance(row["end_date"], str)
else row["end_date"], else row["end_date"],
filters = json.loads(row["filters"] or "{}"), filters=json.loads(row["filters"] or "{}"),
compliance_standard = row["compliance_standard"], compliance_standard=row["compliance_standard"],
status = row["status"], status=row["status"],
file_path = row["file_path"], file_path=row["file_path"],
file_size = row["file_size"], file_size=row["file_size"],
record_count = row["record_count"], record_count=row["record_count"],
checksum = row["checksum"], checksum=row["checksum"],
downloaded_by = row["downloaded_by"], downloaded_by=row["downloaded_by"],
downloaded_at = ( downloaded_at=(
datetime.fromisoformat(row["downloaded_at"]) datetime.fromisoformat(row["downloaded_at"])
if row["downloaded_at"] and isinstance(row["downloaded_at"], str) if row["downloaded_at"] and isinstance(row["downloaded_at"], str)
else row["downloaded_at"] else row["downloaded_at"]
), ),
expires_at = ( expires_at=(
datetime.fromisoformat(row["expires_at"]) datetime.fromisoformat(row["expires_at"])
if isinstance(row["expires_at"], str) if isinstance(row["expires_at"], str)
else row["expires_at"] else row["expires_at"]
), ),
created_by = row["created_by"], created_by=row["created_by"],
created_at = ( created_at=(
datetime.fromisoformat(row["created_at"]) datetime.fromisoformat(row["created_at"])
if isinstance(row["created_at"], str) if isinstance(row["created_at"], str)
else row["created_at"] else row["created_at"]
), ),
completed_at = ( completed_at=(
datetime.fromisoformat(row["completed_at"]) datetime.fromisoformat(row["completed_at"])
if row["completed_at"] and isinstance(row["completed_at"], str) if row["completed_at"] and isinstance(row["completed_at"], str)
else row["completed_at"] else row["completed_at"]
), ),
error_message = row["error_message"], error_message=row["error_message"],
) )
def _row_to_retention_policy(self, row: sqlite3.Row) -> DataRetentionPolicy: def _row_to_retention_policy(self, row: sqlite3.Row) -> DataRetentionPolicy:
"""数据库行转换为 DataRetentionPolicy 对象""" """数据库行转换为 DataRetentionPolicy 对象"""
return DataRetentionPolicy( return DataRetentionPolicy(
id = row["id"], id=row["id"],
tenant_id = row["tenant_id"], tenant_id=row["tenant_id"],
name = row["name"], name=row["name"],
description = row["description"], description=row["description"],
resource_type = row["resource_type"], resource_type=row["resource_type"],
retention_days = row["retention_days"], retention_days=row["retention_days"],
action = row["action"], action=row["action"],
conditions = json.loads(row["conditions"] or "{}"), conditions=json.loads(row["conditions"] or "{}"),
auto_execute = bool(row["auto_execute"]), auto_execute=bool(row["auto_execute"]),
execute_at = row["execute_at"], execute_at=row["execute_at"],
notify_before_days = row["notify_before_days"], notify_before_days=row["notify_before_days"],
archive_location = row["archive_location"], archive_location=row["archive_location"],
archive_encryption = bool(row["archive_encryption"]), archive_encryption=bool(row["archive_encryption"]),
is_active = bool(row["is_active"]), is_active=bool(row["is_active"]),
last_executed_at = ( last_executed_at=(
datetime.fromisoformat(row["last_executed_at"]) datetime.fromisoformat(row["last_executed_at"])
if row["last_executed_at"] and isinstance(row["last_executed_at"], str) if row["last_executed_at"] and isinstance(row["last_executed_at"], str)
else row["last_executed_at"] else row["last_executed_at"]
), ),
last_execution_result = row["last_execution_result"], last_execution_result=row["last_execution_result"],
created_at = ( created_at=(
datetime.fromisoformat(row["created_at"]) datetime.fromisoformat(row["created_at"])
if isinstance(row["created_at"], str) if isinstance(row["created_at"], str)
else row["created_at"] else row["created_at"]
), ),
updated_at = ( updated_at=(
datetime.fromisoformat(row["updated_at"]) datetime.fromisoformat(row["updated_at"])
if isinstance(row["updated_at"], str) if isinstance(row["updated_at"], str)
else row["updated_at"] else row["updated_at"]
@@ -2174,26 +2174,26 @@ class EnterpriseManager:
def _row_to_retention_job(self, row: sqlite3.Row) -> DataRetentionJob: def _row_to_retention_job(self, row: sqlite3.Row) -> DataRetentionJob:
"""数据库行转换为 DataRetentionJob 对象""" """数据库行转换为 DataRetentionJob 对象"""
return DataRetentionJob( return DataRetentionJob(
id = row["id"], id=row["id"],
policy_id = row["policy_id"], policy_id=row["policy_id"],
tenant_id = row["tenant_id"], tenant_id=row["tenant_id"],
status = row["status"], status=row["status"],
started_at = ( started_at=(
datetime.fromisoformat(row["started_at"]) datetime.fromisoformat(row["started_at"])
if row["started_at"] and isinstance(row["started_at"], str) if row["started_at"] and isinstance(row["started_at"], str)
else row["started_at"] else row["started_at"]
), ),
completed_at = ( completed_at=(
datetime.fromisoformat(row["completed_at"]) datetime.fromisoformat(row["completed_at"])
if row["completed_at"] and isinstance(row["completed_at"], str) if row["completed_at"] and isinstance(row["completed_at"], str)
else row["completed_at"] else row["completed_at"]
), ),
affected_records = row["affected_records"], affected_records=row["affected_records"],
archived_records = row["archived_records"], archived_records=row["archived_records"],
deleted_records = row["deleted_records"], deleted_records=row["deleted_records"],
error_count = row["error_count"], error_count=row["error_count"],
details = json.loads(row["details"] or "{}"), details=json.loads(row["details"] or "{}"),
created_at = ( created_at=(
datetime.fromisoformat(row["created_at"]) datetime.fromisoformat(row["created_at"])
if isinstance(row["created_at"], str) if isinstance(row["created_at"], str)
else row["created_at"] else row["created_at"]

View File

@@ -52,12 +52,12 @@ class EntityAligner:
try: try:
response = httpx.post( response = httpx.post(
f"{KIMI_BASE_URL}/v1/embeddings", f"{KIMI_BASE_URL}/v1/embeddings",
headers = { headers={
"Authorization": f"Bearer {KIMI_API_KEY}", "Authorization": f"Bearer {KIMI_API_KEY}",
"Content-Type": "application/json", "Content-Type": "application/json",
}, },
json = {"model": "k2p5", "input": text[:500]}, # 限制长度 json={"model": "k2p5", "input": text[:500]}, # 限制长度
timeout = 30.0, timeout=30.0,
) )
response.raise_for_status() response.raise_for_status()
result = response.json() result = response.json()
@@ -232,7 +232,7 @@ class EntityAligner:
for new_ent in new_entities: for new_ent in new_entities:
matched = self.find_similar_entity( matched = self.find_similar_entity(
project_id, new_ent["name"], new_ent.get("definition", ""), threshold = threshold project_id, new_ent["name"], new_ent.get("definition", ""), threshold=threshold
) )
result = { result = {
@@ -292,16 +292,16 @@ class EntityAligner:
try: try:
response = httpx.post( response = httpx.post(
f"{KIMI_BASE_URL}/v1/chat/completions", f"{KIMI_BASE_URL}/v1/chat/completions",
headers = { headers={
"Authorization": f"Bearer {KIMI_API_KEY}", "Authorization": f"Bearer {KIMI_API_KEY}",
"Content-Type": "application/json", "Content-Type": "application/json",
}, },
json = { json={
"model": "k2p5", "model": "k2p5",
"messages": [{"role": "user", "content": prompt}], "messages": [{"role": "user", "content": prompt}],
"temperature": 0.3, "temperature": 0.3,
}, },
timeout = 30.0, timeout=30.0,
) )
response.raise_for_status() response.raise_for_status()
result = response.json() result = response.json()

View File

@@ -71,7 +71,7 @@ class ExportTranscript:
class ExportManager: class ExportManager:
"""导出管理器 - 处理各种导出需求""" """导出管理器 - 处理各种导出需求"""
def __init__(self, db_manager = None) -> None: def __init__(self, db_manager=None) -> None:
self.db = db_manager self.db = db_manager
def export_knowledge_graph_svg( def export_knowledge_graph_svg(
@@ -232,7 +232,7 @@ class ExportManager:
import cairosvg import cairosvg
svg_content = self.export_knowledge_graph_svg(project_id, entities, relations) svg_content = self.export_knowledge_graph_svg(project_id, entities, relations)
png_bytes = cairosvg.svg2png(bytestring = svg_content.encode("utf-8")) png_bytes = cairosvg.svg2png(bytestring=svg_content.encode("utf-8"))
return png_bytes return png_bytes
except ImportError: except ImportError:
# 如果没有 cairosvg返回 SVG 的 base64 # 如果没有 cairosvg返回 SVG 的 base64
@@ -269,8 +269,8 @@ class ExportManager:
# 写入 Excel # 写入 Excel
output = io.BytesIO() output = io.BytesIO()
with pd.ExcelWriter(output, engine = "openpyxl") as writer: with pd.ExcelWriter(output, engine="openpyxl") as writer:
df.to_excel(writer, sheet_name = "实体列表", index = False) df.to_excel(writer, sheet_name="实体列表", index=False)
# 调整列宽 # 调整列宽
worksheet = writer.sheets["实体列表"] worksheet = writer.sheets["实体列表"]
@@ -417,24 +417,24 @@ class ExportManager:
output = io.BytesIO() output = io.BytesIO()
doc = SimpleDocTemplate( doc = SimpleDocTemplate(
output, pagesize = A4, rightMargin = 72, leftMargin = 72, topMargin = 72, bottomMargin = 18 output, pagesize=A4, rightMargin=72, leftMargin=72, topMargin=72, bottomMargin=18
) )
# 样式 # 样式
styles = getSampleStyleSheet() styles = getSampleStyleSheet()
title_style = ParagraphStyle( title_style = ParagraphStyle(
"CustomTitle", "CustomTitle",
parent = styles["Heading1"], parent=styles["Heading1"],
fontSize = 24, fontSize=24,
spaceAfter = 30, spaceAfter=30,
textColor = colors.HexColor("#2c3e50"), textColor=colors.HexColor("#2c3e50"),
) )
heading_style = ParagraphStyle( heading_style = ParagraphStyle(
"CustomHeading", "CustomHeading",
parent = styles["Heading2"], parent=styles["Heading2"],
fontSize = 16, fontSize=16,
spaceAfter = 12, spaceAfter=12,
textColor = colors.HexColor("#34495e"), textColor=colors.HexColor("#34495e"),
) )
story = [] story = []
@@ -467,7 +467,7 @@ class ExportManager:
for etype, count in sorted(type_counts.items()): for etype, count in sorted(type_counts.items()):
stats_data.append([f"{etype} 实体", str(count)]) stats_data.append([f"{etype} 实体", str(count)])
stats_table = Table(stats_data, colWidths = [3 * inch, 2 * inch]) stats_table = Table(stats_data, colWidths=[3 * inch, 2 * inch])
stats_table.setStyle( stats_table.setStyle(
TableStyle( TableStyle(
[ [
@@ -497,7 +497,7 @@ class ExportManager:
story.append(Paragraph("实体列表", heading_style)) story.append(Paragraph("实体列表", heading_style))
entity_data = [["名称", "类型", "提及次数", "定义"]] entity_data = [["名称", "类型", "提及次数", "定义"]]
for e in sorted(entities, key = lambda x: x.mention_count, reverse = True)[ for e in sorted(entities, key=lambda x: x.mention_count, reverse=True)[
:50 :50
]: # 限制前50个 ]: # 限制前50个
entity_data.append( entity_data.append(
@@ -510,7 +510,7 @@ class ExportManager:
) )
entity_table = Table( entity_table = Table(
entity_data, colWidths = [1.5 * inch, 1 * inch, 1 * inch, 2.5 * inch] entity_data, colWidths=[1.5 * inch, 1 * inch, 1 * inch, 2.5 * inch]
) )
entity_table.setStyle( entity_table.setStyle(
TableStyle( TableStyle(
@@ -539,7 +539,7 @@ class ExportManager:
relation_data.append([r.source, r.relation_type, r.target, f"{r.confidence:.2f}"]) relation_data.append([r.source, r.relation_type, r.target, f"{r.confidence:.2f}"])
relation_table = Table( relation_table = Table(
relation_data, colWidths = [2 * inch, 1.5 * inch, 2 * inch, 1 * inch] relation_data, colWidths=[2 * inch, 1.5 * inch, 2 * inch, 1 * inch]
) )
relation_table.setStyle( relation_table.setStyle(
TableStyle( TableStyle(
@@ -613,14 +613,14 @@ class ExportManager:
], ],
} }
return json.dumps(data, ensure_ascii = False, indent = 2) return json.dumps(data, ensure_ascii=False, indent=2)
# 全局导出管理器实例 # 全局导出管理器实例
_export_manager = None _export_manager = None
def get_export_manager(db_manager = None) -> None: def get_export_manager(db_manager=None) -> None:
"""获取导出管理器实例""" """获取导出管理器实例"""
global _export_manager global _export_manager
if _export_manager is None: if _export_manager is None:

View File

@@ -394,19 +394,19 @@ class GrowthManager:
now = datetime.now() now = datetime.now()
event = AnalyticsEvent( event = AnalyticsEvent(
id = event_id, id=event_id,
tenant_id = tenant_id, tenant_id=tenant_id,
user_id = user_id, user_id=user_id,
event_type = event_type, event_type=event_type,
event_name = event_name, event_name=event_name,
properties = properties or {}, properties=properties or {},
timestamp = now, timestamp=now,
session_id = session_id, session_id=session_id,
device_info = device_info or {}, device_info=device_info or {},
referrer = referrer, referrer=referrer,
utm_source = utm_params.get("source") if utm_params else None, utm_source=utm_params.get("source") if utm_params else None,
utm_medium = utm_params.get("medium") if utm_params else None, utm_medium=utm_params.get("medium") if utm_params else None,
utm_campaign = utm_params.get("campaign") if utm_params else None, utm_campaign=utm_params.get("campaign") if utm_params else None,
) )
with self._get_db() as conn: with self._get_db() as conn:
@@ -453,7 +453,7 @@ class GrowthManager:
tasks.append(self._send_to_amplitude(event)) tasks.append(self._send_to_amplitude(event))
if tasks: if tasks:
await asyncio.gather(*tasks, return_exceptions = True) await asyncio.gather(*tasks, return_exceptions=True)
async def _send_to_mixpanel(self, event: AnalyticsEvent) -> None: async def _send_to_mixpanel(self, event: AnalyticsEvent) -> None:
"""发送事件到 Mixpanel""" """发送事件到 Mixpanel"""
@@ -475,7 +475,7 @@ class GrowthManager:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
await client.post( await client.post(
"https://api.mixpanel.com/track", headers = headers, json = [payload], timeout = 10.0 "https://api.mixpanel.com/track", headers=headers, json=[payload], timeout=10.0
) )
except (RuntimeError, ValueError, TypeError) as e: except (RuntimeError, ValueError, TypeError) as e:
print(f"Failed to send to Mixpanel: {e}") print(f"Failed to send to Mixpanel: {e}")
@@ -501,9 +501,9 @@ class GrowthManager:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
await client.post( await client.post(
"https://api.amplitude.com/2/httpapi", "https://api.amplitude.com/2/httpapi",
headers = headers, headers=headers,
json = payload, json=payload,
timeout = 10.0, timeout=10.0,
) )
except (RuntimeError, ValueError, TypeError) as e: except (RuntimeError, ValueError, TypeError) as e:
print(f"Failed to send to Amplitude: {e}") print(f"Failed to send to Amplitude: {e}")
@@ -642,13 +642,13 @@ class GrowthManager:
now = datetime.now().isoformat() now = datetime.now().isoformat()
funnel = Funnel( funnel = Funnel(
id = funnel_id, id=funnel_id,
tenant_id = tenant_id, tenant_id=tenant_id,
name = name, name=name,
description = description, description=description,
steps = steps, steps=steps,
created_at = now, created_at=now,
updated_at = now, updated_at=now,
) )
with self._get_db() as conn: with self._get_db() as conn:
@@ -685,7 +685,7 @@ class GrowthManager:
steps = json.loads(funnel_row["steps"]) steps = json.loads(funnel_row["steps"])
if not period_start: if not period_start:
period_start = datetime.now() - timedelta(days = 30) period_start = datetime.now() - timedelta(days=30)
if not period_end: if not period_end:
period_end = datetime.now() period_end = datetime.now()
@@ -740,13 +740,13 @@ class GrowthManager:
] ]
return FunnelAnalysis( return FunnelAnalysis(
funnel_id = funnel_id, funnel_id=funnel_id,
period_start = period_start, period_start=period_start,
period_end = period_end, period_end=period_end,
total_users = step_conversions[0]["user_count"] if step_conversions else 0, total_users=step_conversions[0]["user_count"] if step_conversions else 0,
step_conversions = step_conversions, step_conversions=step_conversions,
overall_conversion = round(overall_conversion, 4), overall_conversion=round(overall_conversion, 4),
drop_off_points = drop_off_points, drop_off_points=drop_off_points,
) )
def calculate_retention( def calculate_retention(
@@ -781,7 +781,7 @@ class GrowthManager:
retention_rates = {} retention_rates = {}
for period in periods: for period in periods:
period_date = cohort_date + timedelta(days = period) period_date = cohort_date + timedelta(days=period)
active_query = """ active_query = """
SELECT COUNT(DISTINCT user_id) as active_count SELECT COUNT(DISTINCT user_id) as active_count
@@ -830,25 +830,25 @@ class GrowthManager:
now = datetime.now().isoformat() now = datetime.now().isoformat()
experiment = Experiment( experiment = Experiment(
id = experiment_id, id=experiment_id,
tenant_id = tenant_id, tenant_id=tenant_id,
name = name, name=name,
description = description, description=description,
hypothesis = hypothesis, hypothesis=hypothesis,
status = ExperimentStatus.DRAFT, status=ExperimentStatus.DRAFT,
variants = variants, variants=variants,
traffic_allocation = traffic_allocation, traffic_allocation=traffic_allocation,
traffic_split = traffic_split, traffic_split=traffic_split,
target_audience = target_audience, target_audience=target_audience,
primary_metric = primary_metric, primary_metric=primary_metric,
secondary_metrics = secondary_metrics, secondary_metrics=secondary_metrics,
start_date = None, start_date=None,
end_date = None, end_date=None,
min_sample_size = min_sample_size, min_sample_size=min_sample_size,
confidence_level = confidence_level, confidence_level=confidence_level,
created_at = now, created_at=now,
updated_at = now, updated_at=now,
created_by = created_by or "system", created_by=created_by or "system",
) )
with self._get_db() as conn: with self._get_db() as conn:
@@ -973,7 +973,7 @@ class GrowthManager:
total = sum(weights) total = sum(weights)
normalized_weights = [w / total for w in weights] normalized_weights = [w / total for w in weights]
return random.choices(variant_ids, weights = normalized_weights, k = 1)[0] return random.choices(variant_ids, weights=normalized_weights, k=1)[0]
def _stratified_allocation( def _stratified_allocation(
self, variants: list[dict], traffic_split: dict[str, float], user_attributes: dict self, variants: list[dict], traffic_split: dict[str, float], user_attributes: dict
@@ -1196,21 +1196,21 @@ class GrowthManager:
variables = re.findall(r"\{\{(\w+)\}\}", html_content) variables = re.findall(r"\{\{(\w+)\}\}", html_content)
template = EmailTemplate( template = EmailTemplate(
id = template_id, id=template_id,
tenant_id = tenant_id, tenant_id=tenant_id,
name = name, name=name,
template_type = template_type, template_type=template_type,
subject = subject, subject=subject,
html_content = html_content, html_content=html_content,
text_content = text_content or re.sub(r"<[^>]+>", "", html_content), text_content=text_content or re.sub(r"<[^>]+>", "", html_content),
variables = variables, variables=variables,
preview_text = None, preview_text=None,
from_name = from_name or "InsightFlow", from_name=from_name or "InsightFlow",
from_email = from_email or "noreply@insightflow.io", from_email=from_email or "noreply@insightflow.io",
reply_to = reply_to, reply_to=reply_to,
is_active = True, is_active=True,
created_at = now, created_at=now,
updated_at = now, updated_at=now,
) )
with self._get_db() as conn: with self._get_db() as conn:
@@ -1308,22 +1308,22 @@ class GrowthManager:
now = datetime.now().isoformat() now = datetime.now().isoformat()
campaign = EmailCampaign( campaign = EmailCampaign(
id = campaign_id, id=campaign_id,
tenant_id = tenant_id, tenant_id=tenant_id,
name = name, name=name,
template_id = template_id, template_id=template_id,
status = "draft", status="draft",
recipient_count = len(recipient_list), recipient_count=len(recipient_list),
sent_count = 0, sent_count=0,
delivered_count = 0, delivered_count=0,
opened_count = 0, opened_count=0,
clicked_count = 0, clicked_count=0,
bounced_count = 0, bounced_count=0,
failed_count = 0, failed_count=0,
scheduled_at = scheduled_at.isoformat() if scheduled_at else None, scheduled_at=scheduled_at.isoformat() if scheduled_at else None,
started_at = None, started_at=None,
completed_at = None, completed_at=None,
created_at = now, created_at=now,
) )
with self._get_db() as conn: with self._get_db() as conn:
@@ -1530,17 +1530,17 @@ class GrowthManager:
now = datetime.now().isoformat() now = datetime.now().isoformat()
workflow = AutomationWorkflow( workflow = AutomationWorkflow(
id = workflow_id, id=workflow_id,
tenant_id = tenant_id, tenant_id=tenant_id,
name = name, name=name,
description = description, description=description,
trigger_type = trigger_type, trigger_type=trigger_type,
trigger_conditions = trigger_conditions, trigger_conditions=trigger_conditions,
actions = actions, actions=actions,
is_active = True, is_active=True,
execution_count = 0, execution_count=0,
created_at = now, created_at=now,
updated_at = now, updated_at=now,
) )
with self._get_db() as conn: with self._get_db() as conn:
@@ -1640,20 +1640,20 @@ class GrowthManager:
now = datetime.now().isoformat() now = datetime.now().isoformat()
program = ReferralProgram( program = ReferralProgram(
id = program_id, id=program_id,
tenant_id = tenant_id, tenant_id=tenant_id,
name = name, name=name,
description = description, description=description,
referrer_reward_type = referrer_reward_type, referrer_reward_type=referrer_reward_type,
referrer_reward_value = referrer_reward_value, referrer_reward_value=referrer_reward_value,
referee_reward_type = referee_reward_type, referee_reward_type=referee_reward_type,
referee_reward_value = referee_reward_value, referee_reward_value=referee_reward_value,
max_referrals_per_user = max_referrals_per_user, max_referrals_per_user=max_referrals_per_user,
referral_code_length = referral_code_length, referral_code_length=referral_code_length,
expiry_days = expiry_days, expiry_days=expiry_days,
is_active = True, is_active=True,
created_at = now, created_at=now,
updated_at = now, updated_at=now,
) )
with self._get_db() as conn: with self._get_db() as conn:
@@ -1708,24 +1708,24 @@ class GrowthManager:
referral_id = f"ref_{uuid.uuid4().hex[:16]}" referral_id = f"ref_{uuid.uuid4().hex[:16]}"
now = datetime.now() now = datetime.now()
expires_at = now + timedelta(days = program.expiry_days) expires_at = now + timedelta(days=program.expiry_days)
referral = Referral( referral = Referral(
id = referral_id, id=referral_id,
program_id = program_id, program_id=program_id,
tenant_id = program.tenant_id, tenant_id=program.tenant_id,
referrer_id = referrer_id, referrer_id=referrer_id,
referee_id = None, referee_id=None,
referral_code = referral_code, referral_code=referral_code,
status = ReferralStatus.PENDING, status=ReferralStatus.PENDING,
referrer_rewarded = False, referrer_rewarded=False,
referee_rewarded = False, referee_rewarded=False,
referrer_reward_value = program.referrer_reward_value, referrer_reward_value=program.referrer_reward_value,
referee_reward_value = program.referee_reward_value, referee_reward_value=program.referee_reward_value,
converted_at = None, converted_at=None,
rewarded_at = None, rewarded_at=None,
expires_at = expires_at, expires_at=expires_at,
created_at = now, created_at=now,
) )
conn.execute( conn.execute(
@@ -1762,7 +1762,7 @@ class GrowthManager:
"""生成唯一推荐码""" """生成唯一推荐码"""
chars = "ABCDEFGHJKLMNPQRSTUVWXYZ23456789" # 排除易混淆字符 chars = "ABCDEFGHJKLMNPQRSTUVWXYZ23456789" # 排除易混淆字符
while True: while True:
code = "".join(random.choices(chars, k = length)) code = "".join(random.choices(chars, k=length))
with self._get_db() as conn: with self._get_db() as conn:
row = conn.execute( row = conn.execute(
@@ -1883,18 +1883,18 @@ class GrowthManager:
now = datetime.now().isoformat() now = datetime.now().isoformat()
incentive = TeamIncentive( incentive = TeamIncentive(
id = incentive_id, id=incentive_id,
tenant_id = tenant_id, tenant_id=tenant_id,
name = name, name=name,
description = description, description=description,
target_tier = target_tier, target_tier=target_tier,
min_team_size = min_team_size, min_team_size=min_team_size,
incentive_type = incentive_type, incentive_type=incentive_type,
incentive_value = incentive_value, incentive_value=incentive_value,
valid_from = valid_from.isoformat(), valid_from=valid_from.isoformat(),
valid_until = valid_until.isoformat(), valid_until=valid_until.isoformat(),
is_active = True, is_active=True,
created_at = now, created_at=now,
) )
with self._get_db() as conn: with self._get_db() as conn:
@@ -1947,7 +1947,7 @@ class GrowthManager:
def get_realtime_dashboard(self, tenant_id: str) -> dict: def get_realtime_dashboard(self, tenant_id: str) -> dict:
"""获取实时分析仪表板数据""" """获取实时分析仪表板数据"""
now = datetime.now() now = datetime.now()
today_start = now.replace(hour = 0, minute = 0, second = 0, microsecond = 0) today_start = now.replace(hour=0, minute=0, second=0, microsecond=0)
with self._get_db() as conn: with self._get_db() as conn:
# 今日统计 # 今日统计
@@ -1991,8 +1991,8 @@ class GrowthManager:
# 活跃用户趋势最近24小时每小时 # 活跃用户趋势最近24小时每小时
hourly_trend = [] hourly_trend = []
for i in range(24): for i in range(24):
hour_start = now - timedelta(hours = i + 1) hour_start = now - timedelta(hours=i + 1)
hour_end = now - timedelta(hours = i) hour_end = now - timedelta(hours=i)
row = conn.execute( row = conn.execute(
""" """
@@ -2035,116 +2035,116 @@ class GrowthManager:
def _row_to_user_profile(self, row) -> UserProfile: def _row_to_user_profile(self, row) -> UserProfile:
"""将数据库行转换为 UserProfile""" """将数据库行转换为 UserProfile"""
return UserProfile( return UserProfile(
id = row["id"], id=row["id"],
tenant_id = row["tenant_id"], tenant_id=row["tenant_id"],
user_id = row["user_id"], user_id=row["user_id"],
first_seen = datetime.fromisoformat(row["first_seen"]), first_seen=datetime.fromisoformat(row["first_seen"]),
last_seen = datetime.fromisoformat(row["last_seen"]), last_seen=datetime.fromisoformat(row["last_seen"]),
total_sessions = row["total_sessions"], total_sessions=row["total_sessions"],
total_events = row["total_events"], total_events=row["total_events"],
feature_usage = json.loads(row["feature_usage"]), feature_usage=json.loads(row["feature_usage"]),
subscription_history = json.loads(row["subscription_history"]), subscription_history=json.loads(row["subscription_history"]),
ltv = row["ltv"], ltv=row["ltv"],
churn_risk_score = row["churn_risk_score"], churn_risk_score=row["churn_risk_score"],
engagement_score = row["engagement_score"], engagement_score=row["engagement_score"],
created_at = datetime.fromisoformat(row["created_at"]), created_at=datetime.fromisoformat(row["created_at"]),
updated_at = datetime.fromisoformat(row["updated_at"]), updated_at=datetime.fromisoformat(row["updated_at"]),
) )
def _row_to_experiment(self, row) -> Experiment: def _row_to_experiment(self, row) -> Experiment:
"""将数据库行转换为 Experiment""" """将数据库行转换为 Experiment"""
return Experiment( return Experiment(
id = row["id"], id=row["id"],
tenant_id = row["tenant_id"], tenant_id=row["tenant_id"],
name = row["name"], name=row["name"],
description = row["description"], description=row["description"],
hypothesis = row["hypothesis"], hypothesis=row["hypothesis"],
status = ExperimentStatus(row["status"]), status=ExperimentStatus(row["status"]),
variants = json.loads(row["variants"]), variants=json.loads(row["variants"]),
traffic_allocation = TrafficAllocationType(row["traffic_allocation"]), traffic_allocation=TrafficAllocationType(row["traffic_allocation"]),
traffic_split = json.loads(row["traffic_split"]), traffic_split=json.loads(row["traffic_split"]),
target_audience = json.loads(row["target_audience"]), target_audience=json.loads(row["target_audience"]),
primary_metric = row["primary_metric"], primary_metric=row["primary_metric"],
secondary_metrics = json.loads(row["secondary_metrics"]), secondary_metrics=json.loads(row["secondary_metrics"]),
start_date = datetime.fromisoformat(row["start_date"]) if row["start_date"] else None, start_date=datetime.fromisoformat(row["start_date"]) if row["start_date"] else None,
end_date = datetime.fromisoformat(row["end_date"]) if row["end_date"] else None, end_date=datetime.fromisoformat(row["end_date"]) if row["end_date"] else None,
min_sample_size = row["min_sample_size"], min_sample_size=row["min_sample_size"],
confidence_level = row["confidence_level"], confidence_level=row["confidence_level"],
created_at = row["created_at"], created_at=row["created_at"],
updated_at = row["updated_at"], updated_at=row["updated_at"],
created_by = row["created_by"], created_by=row["created_by"],
) )
def _row_to_email_template(self, row) -> EmailTemplate: def _row_to_email_template(self, row) -> EmailTemplate:
"""将数据库行转换为 EmailTemplate""" """将数据库行转换为 EmailTemplate"""
return EmailTemplate( return EmailTemplate(
id = row["id"], id=row["id"],
tenant_id = row["tenant_id"], tenant_id=row["tenant_id"],
name = row["name"], name=row["name"],
template_type = EmailTemplateType(row["template_type"]), template_type=EmailTemplateType(row["template_type"]),
subject = row["subject"], subject=row["subject"],
html_content = row["html_content"], html_content=row["html_content"],
text_content = row["text_content"], text_content=row["text_content"],
variables = json.loads(row["variables"]), variables=json.loads(row["variables"]),
preview_text = row["preview_text"], preview_text=row["preview_text"],
from_name = row["from_name"], from_name=row["from_name"],
from_email = row["from_email"], from_email=row["from_email"],
reply_to = row["reply_to"], reply_to=row["reply_to"],
is_active = bool(row["is_active"]), is_active=bool(row["is_active"]),
created_at = row["created_at"], created_at=row["created_at"],
updated_at = row["updated_at"], updated_at=row["updated_at"],
) )
def _row_to_automation_workflow(self, row) -> AutomationWorkflow: def _row_to_automation_workflow(self, row) -> AutomationWorkflow:
"""将数据库行转换为 AutomationWorkflow""" """将数据库行转换为 AutomationWorkflow"""
return AutomationWorkflow( return AutomationWorkflow(
id = row["id"], id=row["id"],
tenant_id = row["tenant_id"], tenant_id=row["tenant_id"],
name = row["name"], name=row["name"],
description = row["description"], description=row["description"],
trigger_type = WorkflowTriggerType(row["trigger_type"]), trigger_type=WorkflowTriggerType(row["trigger_type"]),
trigger_conditions = json.loads(row["trigger_conditions"]), trigger_conditions=json.loads(row["trigger_conditions"]),
actions = json.loads(row["actions"]), actions=json.loads(row["actions"]),
is_active = bool(row["is_active"]), is_active=bool(row["is_active"]),
execution_count = row["execution_count"], execution_count=row["execution_count"],
created_at = row["created_at"], created_at=row["created_at"],
updated_at = row["updated_at"], updated_at=row["updated_at"],
) )
def _row_to_referral_program(self, row) -> ReferralProgram: def _row_to_referral_program(self, row) -> ReferralProgram:
"""将数据库行转换为 ReferralProgram""" """将数据库行转换为 ReferralProgram"""
return ReferralProgram( return ReferralProgram(
id = row["id"], id=row["id"],
tenant_id = row["tenant_id"], tenant_id=row["tenant_id"],
name = row["name"], name=row["name"],
description = row["description"], description=row["description"],
referrer_reward_type = row["referrer_reward_type"], referrer_reward_type=row["referrer_reward_type"],
referrer_reward_value = row["referrer_reward_value"], referrer_reward_value=row["referrer_reward_value"],
referee_reward_type = row["referee_reward_type"], referee_reward_type=row["referee_reward_type"],
referee_reward_value = row["referee_reward_value"], referee_reward_value=row["referee_reward_value"],
max_referrals_per_user = row["max_referrals_per_user"], max_referrals_per_user=row["max_referrals_per_user"],
referral_code_length = row["referral_code_length"], referral_code_length=row["referral_code_length"],
expiry_days = row["expiry_days"], expiry_days=row["expiry_days"],
is_active = bool(row["is_active"]), is_active=bool(row["is_active"]),
created_at = row["created_at"], created_at=row["created_at"],
updated_at = row["updated_at"], updated_at=row["updated_at"],
) )
def _row_to_team_incentive(self, row) -> TeamIncentive: def _row_to_team_incentive(self, row) -> TeamIncentive:
"""将数据库行转换为 TeamIncentive""" """将数据库行转换为 TeamIncentive"""
return TeamIncentive( return TeamIncentive(
id = row["id"], id=row["id"],
tenant_id = row["tenant_id"], tenant_id=row["tenant_id"],
name = row["name"], name=row["name"],
description = row["description"], description=row["description"],
target_tier = row["target_tier"], target_tier=row["target_tier"],
min_team_size = row["min_team_size"], min_team_size=row["min_team_size"],
incentive_type = row["incentive_type"], incentive_type=row["incentive_type"],
incentive_value = row["incentive_value"], incentive_value=row["incentive_value"],
valid_from = datetime.fromisoformat(row["valid_from"]), valid_from=datetime.fromisoformat(row["valid_from"]),
valid_until = datetime.fromisoformat(row["valid_until"]), valid_until=datetime.fromisoformat(row["valid_until"]),
is_active = bool(row["is_active"]), is_active=bool(row["is_active"]),
created_at = row["created_at"], created_at=row["created_at"],
) )

View File

@@ -104,7 +104,7 @@ class ImageProcessor:
temp_dir: 临时文件目录 temp_dir: 临时文件目录
""" """
self.temp_dir = temp_dir or os.path.join(os.getcwd(), "temp", "images") self.temp_dir = temp_dir or os.path.join(os.getcwd(), "temp", "images")
os.makedirs(self.temp_dir, exist_ok = True) os.makedirs(self.temp_dir, exist_ok=True)
def preprocess_image(self, image, image_type: str = None) -> None: def preprocess_image(self, image, image_type: str = None) -> None:
""" """
@@ -169,7 +169,7 @@ class ImageProcessor:
gray = image.convert("L") gray = image.convert("L")
# 轻微降噪 # 轻微降噪
blurred = gray.filter(ImageFilter.GaussianBlur(radius = 1)) blurred = gray.filter(ImageFilter.GaussianBlur(radius=1))
# 增强对比度 # 增强对比度
enhancer = ImageEnhance.Contrast(blurred) enhancer = ImageEnhance.Contrast(blurred)
@@ -255,10 +255,10 @@ class ImageProcessor:
processed_image = self.preprocess_image(image) processed_image = self.preprocess_image(image)
# 执行OCR # 执行OCR
text = pytesseract.image_to_string(processed_image, lang = lang) text = pytesseract.image_to_string(processed_image, lang=lang)
# 获取置信度 # 获取置信度
data = pytesseract.image_to_data(processed_image, output_type = pytesseract.Output.DICT) data = pytesseract.image_to_data(processed_image, output_type=pytesseract.Output.DICT)
confidences = [int(c) for c in data["conf"] if int(c) > 0] confidences = [int(c) for c in data["conf"] if int(c) > 0]
avg_confidence = sum(confidences) / len(confidences) if confidences else 0 avg_confidence = sum(confidences) / len(confidences) if confidences else 0
@@ -288,12 +288,12 @@ class ImageProcessor:
for match in re.finditer(project_pattern, text): for match in re.finditer(project_pattern, text):
name = match.group(1) or match.group(2) name = match.group(1) or match.group(2)
if name and len(name) > 2: if name and len(name) > 2:
entities.append(ImageEntity(name = name.strip(), type = "PROJECT", confidence = 0.7)) entities.append(ImageEntity(name=name.strip(), type="PROJECT", confidence=0.7))
# 人名(中文) # 人名(中文)
name_pattern = r"([\u4e00-\u9fa5]{2, 4})(?:先生|女士|总|经理|工程师|老师)" name_pattern = r"([\u4e00-\u9fa5]{2, 4})(?:先生|女士|总|经理|工程师|老师)"
for match in re.finditer(name_pattern, text): for match in re.finditer(name_pattern, text):
entities.append(ImageEntity(name = match.group(1), type = "PERSON", confidence = 0.8)) entities.append(ImageEntity(name=match.group(1), type="PERSON", confidence=0.8))
# 技术术语 # 技术术语
tech_keywords = [ tech_keywords = [
@@ -314,7 +314,7 @@ class ImageProcessor:
] ]
for keyword in tech_keywords: for keyword in tech_keywords:
if keyword in text: if keyword in text:
entities.append(ImageEntity(name = keyword, type = "TECH", confidence = 0.9)) entities.append(ImageEntity(name=keyword, type="TECH", confidence=0.9))
# 去重 # 去重
seen = set() seen = set()
@@ -381,16 +381,16 @@ class ImageProcessor:
if not PIL_AVAILABLE: if not PIL_AVAILABLE:
return ImageProcessingResult( return ImageProcessingResult(
image_id = image_id, image_id=image_id,
image_type = "other", image_type="other",
ocr_text = "", ocr_text="",
description = "PIL not available", description="PIL not available",
entities = [], entities=[],
relations = [], relations=[],
width = 0, width=0,
height = 0, height=0,
success = False, success=False,
error_message = "PIL library not available", error_message="PIL library not available",
) )
try: try:
@@ -421,29 +421,29 @@ class ImageProcessor:
image.save(save_path) image.save(save_path)
return ImageProcessingResult( return ImageProcessingResult(
image_id = image_id, image_id=image_id,
image_type = image_type, image_type=image_type,
ocr_text = ocr_text, ocr_text=ocr_text,
description = description, description=description,
entities = entities, entities=entities,
relations = relations, relations=relations,
width = width, width=width,
height = height, height=height,
success = True, success=True,
) )
except Exception as e: except Exception as e:
return ImageProcessingResult( return ImageProcessingResult(
image_id = image_id, image_id=image_id,
image_type = "other", image_type="other",
ocr_text = "", ocr_text="",
description = "", description="",
entities = [], entities=[],
relations = [], relations=[],
width = 0, width=0,
height = 0, height=0,
success = False, success=False,
error_message = str(e), error_message=str(e),
) )
def _extract_relations(self, entities: list[ImageEntity], text: str) -> list[ImageRelation]: def _extract_relations(self, entities: list[ImageEntity], text: str) -> list[ImageRelation]:
@@ -477,10 +477,10 @@ class ImageProcessor:
for j in range(i + 1, len(sentence_entities)): for j in range(i + 1, len(sentence_entities)):
relations.append( relations.append(
ImageRelation( ImageRelation(
source = sentence_entities[i].name, source=sentence_entities[i].name,
target = sentence_entities[j].name, target=sentence_entities[j].name,
relation_type = "related", relation_type="related",
confidence = 0.5, confidence=0.5,
) )
) )
@@ -513,10 +513,10 @@ class ImageProcessor:
failed_count += 1 failed_count += 1
return BatchProcessingResult( return BatchProcessingResult(
results = results, results=results,
total_count = len(results), total_count=len(results),
success_count = success_count, success_count=success_count,
failed_count = failed_count, failed_count=failed_count,
) )
def image_to_base64(self, image_data: bytes) -> str: def image_to_base64(self, image_data: bytes) -> str:
@@ -550,7 +550,7 @@ class ImageProcessor:
image.thumbnail(size, Image.Resampling.LANCZOS) image.thumbnail(size, Image.Resampling.LANCZOS)
buffer = io.BytesIO() buffer = io.BytesIO()
image.save(buffer, format = "JPEG") image.save(buffer, format="JPEG")
return buffer.getvalue() return buffer.getvalue()
except Exception as e: except Exception as e:
print(f"Thumbnail generation error: {e}") print(f"Thumbnail generation error: {e}")

View File

@@ -73,9 +73,9 @@ class KnowledgeReasoner:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.post( response = await client.post(
f"{self.base_url}/v1/chat/completions", f"{self.base_url}/v1/chat/completions",
headers = self.headers, headers=self.headers,
json = payload, json=payload,
timeout = 120.0, timeout=120.0,
) )
response.raise_for_status() response.raise_for_status()
result = response.json() result = response.json()
@@ -127,7 +127,7 @@ class KnowledgeReasoner:
- factual: 事实类问题(是什么、有哪些) - factual: 事实类问题(是什么、有哪些)
- opinion: 观点类问题(怎么看、态度、评价)""" - opinion: 观点类问题(怎么看、态度、评价)"""
content = await self._call_llm(prompt, temperature = 0.1) content = await self._call_llm(prompt, temperature=0.1)
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL) json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
if json_match: if json_match:
@@ -144,8 +144,8 @@ class KnowledgeReasoner:
"""因果推理 - 分析原因和影响""" """因果推理 - 分析原因和影响"""
# 构建因果分析提示 # 构建因果分析提示
entities_str = json.dumps(graph_data.get("entities", []), ensure_ascii = False, indent = 2) entities_str = json.dumps(graph_data.get("entities", []), ensure_ascii=False, indent=2)
relations_str = json.dumps(graph_data.get("relations", []), ensure_ascii = False, indent = 2) relations_str = json.dumps(graph_data.get("relations", []), ensure_ascii=False, indent=2)
prompt = f"""基于以下知识图谱进行因果推理分析: prompt = f"""基于以下知识图谱进行因果推理分析:
@@ -159,7 +159,7 @@ class KnowledgeReasoner:
{relations_str[:2000]} {relations_str[:2000]}
## 项目上下文 ## 项目上下文
{json.dumps(project_context, ensure_ascii = False, indent = 2)[:1500]} {json.dumps(project_context, ensure_ascii=False, indent=2)[:1500]}
请进行因果分析,返回 JSON 格式: 请进行因果分析,返回 JSON 格式:
{{ {{
@@ -172,7 +172,7 @@ class KnowledgeReasoner:
"knowledge_gaps": ["缺失信息1"] "knowledge_gaps": ["缺失信息1"]
}}""" }}"""
content = await self._call_llm(prompt, temperature = 0.3) content = await self._call_llm(prompt, temperature=0.3)
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL) json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
@@ -180,23 +180,23 @@ class KnowledgeReasoner:
try: try:
data = json.loads(json_match.group()) data = json.loads(json_match.group())
return ReasoningResult( return ReasoningResult(
answer = data.get("answer", ""), answer=data.get("answer", ""),
reasoning_type = ReasoningType.CAUSAL, reasoning_type=ReasoningType.CAUSAL,
confidence = data.get("confidence", 0.7), confidence=data.get("confidence", 0.7),
evidence = [{"text": e} for e in data.get("evidence", [])], evidence=[{"text": e} for e in data.get("evidence", [])],
related_entities = [], related_entities=[],
gaps = data.get("knowledge_gaps", []), gaps=data.get("knowledge_gaps", []),
) )
except (json.JSONDecodeError, KeyError): except (json.JSONDecodeError, KeyError):
pass pass
return ReasoningResult( return ReasoningResult(
answer = content, answer=content,
reasoning_type = ReasoningType.CAUSAL, reasoning_type=ReasoningType.CAUSAL,
confidence = 0.5, confidence=0.5,
evidence = [], evidence=[],
related_entities = [], related_entities=[],
gaps = ["无法完成因果推理"], gaps=["无法完成因果推理"],
) )
async def _comparative_reasoning( async def _comparative_reasoning(
@@ -210,10 +210,10 @@ class KnowledgeReasoner:
{query} {query}
## 实体 ## 实体
{json.dumps(graph_data.get("entities", []), ensure_ascii = False, indent = 2)[:2000]} {json.dumps(graph_data.get("entities", []), ensure_ascii=False, indent=2)[:2000]}
## 关系 ## 关系
{json.dumps(graph_data.get("relations", []), ensure_ascii = False, indent = 2)[:1500]} {json.dumps(graph_data.get("relations", []), ensure_ascii=False, indent=2)[:1500]}
请进行对比分析,返回 JSON 格式: 请进行对比分析,返回 JSON 格式:
{{ {{
@@ -226,7 +226,7 @@ class KnowledgeReasoner:
"knowledge_gaps": [] "knowledge_gaps": []
}}""" }}"""
content = await self._call_llm(prompt, temperature = 0.3) content = await self._call_llm(prompt, temperature=0.3)
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL) json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
@@ -234,23 +234,23 @@ class KnowledgeReasoner:
try: try:
data = json.loads(json_match.group()) data = json.loads(json_match.group())
return ReasoningResult( return ReasoningResult(
answer = data.get("answer", ""), answer=data.get("answer", ""),
reasoning_type = ReasoningType.COMPARATIVE, reasoning_type=ReasoningType.COMPARATIVE,
confidence = data.get("confidence", 0.7), confidence=data.get("confidence", 0.7),
evidence = [{"text": e} for e in data.get("evidence", [])], evidence=[{"text": e} for e in data.get("evidence", [])],
related_entities = [], related_entities=[],
gaps = data.get("knowledge_gaps", []), gaps=data.get("knowledge_gaps", []),
) )
except (json.JSONDecodeError, KeyError): except (json.JSONDecodeError, KeyError):
pass pass
return ReasoningResult( return ReasoningResult(
answer = content, answer=content,
reasoning_type = ReasoningType.COMPARATIVE, reasoning_type=ReasoningType.COMPARATIVE,
confidence = 0.5, confidence=0.5,
evidence = [], evidence=[],
related_entities = [], related_entities=[],
gaps = [], gaps=[],
) )
async def _temporal_reasoning( async def _temporal_reasoning(
@@ -264,10 +264,10 @@ class KnowledgeReasoner:
{query} {query}
## 项目时间线 ## 项目时间线
{json.dumps(project_context.get("timeline", []), ensure_ascii = False, indent = 2)[:2000]} {json.dumps(project_context.get("timeline", []), ensure_ascii=False, indent=2)[:2000]}
## 实体提及历史 ## 实体提及历史
{json.dumps(graph_data.get("entities", []), ensure_ascii = False, indent = 2)[:1500]} {json.dumps(graph_data.get("entities", []), ensure_ascii=False, indent=2)[:1500]}
请进行时序分析,返回 JSON 格式: 请进行时序分析,返回 JSON 格式:
{{ {{
@@ -280,7 +280,7 @@ class KnowledgeReasoner:
"knowledge_gaps": [] "knowledge_gaps": []
}}""" }}"""
content = await self._call_llm(prompt, temperature = 0.3) content = await self._call_llm(prompt, temperature=0.3)
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL) json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
@@ -288,23 +288,23 @@ class KnowledgeReasoner:
try: try:
data = json.loads(json_match.group()) data = json.loads(json_match.group())
return ReasoningResult( return ReasoningResult(
answer = data.get("answer", ""), answer=data.get("answer", ""),
reasoning_type = ReasoningType.TEMPORAL, reasoning_type=ReasoningType.TEMPORAL,
confidence = data.get("confidence", 0.7), confidence=data.get("confidence", 0.7),
evidence = [{"text": e} for e in data.get("evidence", [])], evidence=[{"text": e} for e in data.get("evidence", [])],
related_entities = [], related_entities=[],
gaps = data.get("knowledge_gaps", []), gaps=data.get("knowledge_gaps", []),
) )
except (json.JSONDecodeError, KeyError): except (json.JSONDecodeError, KeyError):
pass pass
return ReasoningResult( return ReasoningResult(
answer = content, answer=content,
reasoning_type = ReasoningType.TEMPORAL, reasoning_type=ReasoningType.TEMPORAL,
confidence = 0.5, confidence=0.5,
evidence = [], evidence=[],
related_entities = [], related_entities=[],
gaps = [], gaps=[],
) )
async def _associative_reasoning( async def _associative_reasoning(
@@ -318,10 +318,10 @@ class KnowledgeReasoner:
{query} {query}
## 实体 ## 实体
{json.dumps(graph_data.get("entities", [])[:20], ensure_ascii = False, indent = 2)} {json.dumps(graph_data.get("entities", [])[:20], ensure_ascii=False, indent=2)}
## 关系 ## 关系
{json.dumps(graph_data.get("relations", [])[:30], ensure_ascii = False, indent = 2)} {json.dumps(graph_data.get("relations", [])[:30], ensure_ascii=False, indent=2)}
请进行关联推理,发现隐含联系,返回 JSON 格式: 请进行关联推理,发现隐含联系,返回 JSON 格式:
{{ {{
@@ -334,7 +334,7 @@ class KnowledgeReasoner:
"knowledge_gaps": [] "knowledge_gaps": []
}}""" }}"""
content = await self._call_llm(prompt, temperature = 0.4) content = await self._call_llm(prompt, temperature=0.4)
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL) json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
@@ -342,23 +342,23 @@ class KnowledgeReasoner:
try: try:
data = json.loads(json_match.group()) data = json.loads(json_match.group())
return ReasoningResult( return ReasoningResult(
answer = data.get("answer", ""), answer=data.get("answer", ""),
reasoning_type = ReasoningType.ASSOCIATIVE, reasoning_type=ReasoningType.ASSOCIATIVE,
confidence = data.get("confidence", 0.7), confidence=data.get("confidence", 0.7),
evidence = [{"text": e} for e in data.get("evidence", [])], evidence=[{"text": e} for e in data.get("evidence", [])],
related_entities = [], related_entities=[],
gaps = data.get("knowledge_gaps", []), gaps=data.get("knowledge_gaps", []),
) )
except (json.JSONDecodeError, KeyError): except (json.JSONDecodeError, KeyError):
pass pass
return ReasoningResult( return ReasoningResult(
answer = content, answer=content,
reasoning_type = ReasoningType.ASSOCIATIVE, reasoning_type=ReasoningType.ASSOCIATIVE,
confidence = 0.5, confidence=0.5,
evidence = [], evidence=[],
related_entities = [], related_entities=[],
gaps = [], gaps=[],
) )
def find_inference_paths( def find_inference_paths(
@@ -400,10 +400,10 @@ class KnowledgeReasoner:
# 找到一条路径 # 找到一条路径
paths.append( paths.append(
InferencePath( InferencePath(
start_entity = start_entity, start_entity=start_entity,
end_entity = end_entity, end_entity=end_entity,
path = path, path=path,
strength = self._calculate_path_strength(path), strength=self._calculate_path_strength(path),
) )
) )
continue continue
@@ -424,7 +424,7 @@ class KnowledgeReasoner:
queue.append((next_entity, new_path)) queue.append((next_entity, new_path))
# 按强度排序 # 按强度排序
paths.sort(key = lambda p: p.strength, reverse = True) paths.sort(key=lambda p: p.strength, reverse=True)
return paths return paths
def _calculate_path_strength(self, path: list[dict]) -> float: def _calculate_path_strength(self, path: list[dict]) -> float:
@@ -467,7 +467,7 @@ class KnowledgeReasoner:
prompt = f"""请对以下项目进行{type_prompts.get(summary_type, "全面总结")} prompt = f"""请对以下项目进行{type_prompts.get(summary_type, "全面总结")}
## 项目信息 ## 项目信息
{json.dumps(project_context, ensure_ascii = False, indent = 2)[:3000]} {json.dumps(project_context, ensure_ascii=False, indent=2)[:3000]}
## 知识图谱 ## 知识图谱
实体数: {len(graph_data.get("entities", []))} 实体数: {len(graph_data.get("entities", []))}
@@ -483,7 +483,7 @@ class KnowledgeReasoner:
"confidence": 0.85 "confidence": 0.85
}}""" }}"""
content = await self._call_llm(prompt, temperature = 0.3) content = await self._call_llm(prompt, temperature=0.3)
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL) json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)

View File

@@ -66,9 +66,9 @@ class LLMClient:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.post( response = await client.post(
f"{self.base_url}/v1/chat/completions", f"{self.base_url}/v1/chat/completions",
headers = self.headers, headers=self.headers,
json = payload, json=payload,
timeout = 120.0, timeout=120.0,
) )
response.raise_for_status() response.raise_for_status()
result = response.json() result = response.json()
@@ -92,9 +92,9 @@ class LLMClient:
async with client.stream( async with client.stream(
"POST", "POST",
f"{self.base_url}/v1/chat/completions", f"{self.base_url}/v1/chat/completions",
headers = self.headers, headers=self.headers,
json = payload, json=payload,
timeout = 120.0, timeout=120.0,
) as response: ) as response:
response.raise_for_status() response.raise_for_status()
async for line in response.aiter_lines(): async for line in response.aiter_lines():
@@ -139,8 +139,8 @@ class LLMClient:
] ]
}}""" }}"""
messages = [ChatMessage(role = "user", content = prompt)] messages = [ChatMessage(role="user", content=prompt)]
content = await self.chat(messages, temperature = 0.1) content = await self.chat(messages, temperature=0.1)
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL) json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
if not json_match: if not json_match:
@@ -150,19 +150,19 @@ class LLMClient:
data = json.loads(json_match.group()) data = json.loads(json_match.group())
entities = [ entities = [
EntityExtractionResult( EntityExtractionResult(
name = e["name"], name=e["name"],
type = e.get("type", "OTHER"), type=e.get("type", "OTHER"),
definition = e.get("definition", ""), definition=e.get("definition", ""),
confidence = e.get("confidence", 0.8), confidence=e.get("confidence", 0.8),
) )
for e in data.get("entities", []) for e in data.get("entities", [])
] ]
relations = [ relations = [
RelationExtractionResult( RelationExtractionResult(
source = r["source"], source=r["source"],
target = r["target"], target=r["target"],
type = r.get("type", "related"), type=r.get("type", "related"),
confidence = r.get("confidence", 0.8), confidence=r.get("confidence", 0.8),
) )
for r in data.get("relations", []) for r in data.get("relations", [])
] ]
@@ -176,7 +176,7 @@ class LLMClient:
prompt = f"""你是一个专业的项目分析助手。基于以下项目信息回答问题: prompt = f"""你是一个专业的项目分析助手。基于以下项目信息回答问题:
## 项目信息 ## 项目信息
{json.dumps(project_context, ensure_ascii = False, indent = 2)} {json.dumps(project_context, ensure_ascii=False, indent=2)}
## 相关上下文 ## 相关上下文
{context[:4000]} {context[:4000]}
@@ -188,19 +188,19 @@ class LLMClient:
messages = [ messages = [
ChatMessage( ChatMessage(
role = "system", content = "你是一个专业的项目分析助手,擅长从会议记录中提取洞察。" role="system", content="你是一个专业的项目分析助手,擅长从会议记录中提取洞察。"
), ),
ChatMessage(role = "user", content = prompt), ChatMessage(role="user", content=prompt),
] ]
return await self.chat(messages, temperature = 0.3) return await self.chat(messages, temperature=0.3)
async def agent_command(self, command: str, project_context: dict) -> dict: async def agent_command(self, command: str, project_context: dict) -> dict:
"""Agent 指令解析 - 将自然语言指令转换为结构化操作""" """Agent 指令解析 - 将自然语言指令转换为结构化操作"""
prompt = f"""解析以下用户指令,转换为结构化操作: prompt = f"""解析以下用户指令,转换为结构化操作:
## 项目信息 ## 项目信息
{json.dumps(project_context, ensure_ascii = False, indent = 2)} {json.dumps(project_context, ensure_ascii=False, indent=2)}
## 用户指令 ## 用户指令
{command} {command}
@@ -221,8 +221,8 @@ class LLMClient:
- create_relation: 创建关系params 包含 source(源实体), target(目标实体), relation_type(关系类型) - create_relation: 创建关系params 包含 source(源实体), target(目标实体), relation_type(关系类型)
""" """
messages = [ChatMessage(role = "user", content = prompt)] messages = [ChatMessage(role="user", content=prompt)]
content = await self.chat(messages, temperature = 0.1) content = await self.chat(messages, temperature=0.1)
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL) json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
if not json_match: if not json_match:
@@ -255,8 +255,8 @@ class LLMClient:
用中文回答,结构清晰。""" 用中文回答,结构清晰。"""
messages = [ChatMessage(role = "user", content = prompt)] messages = [ChatMessage(role="user", content=prompt)]
return await self.chat(messages, temperature = 0.3) return await self.chat(messages, temperature=0.3)
# Singleton instance # Singleton instance

View File

@@ -1257,7 +1257,7 @@ class LocalizationManager:
def get_localized_payment_methods( def get_localized_payment_methods(
self, country_code: str, language: str = "en" self, country_code: str, language: str = "en"
) -> list[dict[str, Any]]: ) -> list[dict[str, Any]]:
methods = self.list_payment_methods(country_code = country_code) methods = self.list_payment_methods(country_code=country_code)
result = [] result = []
for method in methods: for method in methods:
name_local = method.name_local.get(language, method.name) name_local = method.name_local.get(language, method.name)
@@ -1332,11 +1332,11 @@ class LocalizationManager:
try: try:
locale = Locale.parse(language.replace("_", "-")) locale = Locale.parse(language.replace("_", "-"))
if format_type == "date": if format_type == "date":
return dates.format_date(dt, locale = locale) return dates.format_date(dt, locale=locale)
elif format_type == "time": elif format_type == "time":
return dates.format_time(dt, locale = locale) return dates.format_time(dt, locale=locale)
else: else:
return dates.format_datetime(dt, locale = locale) return dates.format_datetime(dt, locale=locale)
except (ValueError, AttributeError): except (ValueError, AttributeError):
pass pass
return dt.strftime(fmt) return dt.strftime(fmt)
@@ -1352,7 +1352,7 @@ class LocalizationManager:
try: try:
locale = Locale.parse(language.replace("_", "-")) locale = Locale.parse(language.replace("_", "-"))
return numbers.format_decimal( return numbers.format_decimal(
number, locale = locale, decimal_quantization = (decimal_places is not None) number, locale=locale, decimal_quantization=(decimal_places is not None)
) )
except (ValueError, AttributeError): except (ValueError, AttributeError):
pass pass
@@ -1368,7 +1368,7 @@ class LocalizationManager:
if BABEL_AVAILABLE: if BABEL_AVAILABLE:
try: try:
locale = Locale.parse(language.replace("_", "-")) locale = Locale.parse(language.replace("_", "-"))
return numbers.format_currency(amount, currency, locale = locale) return numbers.format_currency(amount, currency, locale=locale)
except (ValueError, AttributeError): except (ValueError, AttributeError):
pass pass
return f"{currency} {amount:, .2f}" return f"{currency} {amount:, .2f}"
@@ -1536,25 +1536,25 @@ class LocalizationManager:
def _row_to_translation(self, row: sqlite3.Row) -> Translation: def _row_to_translation(self, row: sqlite3.Row) -> Translation:
return Translation( return Translation(
id = row["id"], id=row["id"],
key = row["key"], key=row["key"],
language = row["language"], language=row["language"],
value = row["value"], value=row["value"],
namespace = row["namespace"], namespace=row["namespace"],
context = row["context"], context=row["context"],
created_at = ( created_at=(
datetime.fromisoformat(row["created_at"]) datetime.fromisoformat(row["created_at"])
if isinstance(row["created_at"], str) if isinstance(row["created_at"], str)
else row["created_at"] else row["created_at"]
), ),
updated_at = ( updated_at=(
datetime.fromisoformat(row["updated_at"]) datetime.fromisoformat(row["updated_at"])
if isinstance(row["updated_at"], str) if isinstance(row["updated_at"], str)
else row["updated_at"] else row["updated_at"]
), ),
is_reviewed = bool(row["is_reviewed"]), is_reviewed=bool(row["is_reviewed"]),
reviewed_by = row["reviewed_by"], reviewed_by=row["reviewed_by"],
reviewed_at = ( reviewed_at=(
datetime.fromisoformat(row["reviewed_at"]) datetime.fromisoformat(row["reviewed_at"])
if row["reviewed_at"] and isinstance(row["reviewed_at"], str) if row["reviewed_at"] and isinstance(row["reviewed_at"], str)
else row["reviewed_at"] else row["reviewed_at"]
@@ -1563,39 +1563,39 @@ class LocalizationManager:
def _row_to_language_config(self, row: sqlite3.Row) -> LanguageConfig: def _row_to_language_config(self, row: sqlite3.Row) -> LanguageConfig:
return LanguageConfig( return LanguageConfig(
code = row["code"], code=row["code"],
name = row["name"], name=row["name"],
name_local = row["name_local"], name_local=row["name_local"],
is_rtl = bool(row["is_rtl"]), is_rtl=bool(row["is_rtl"]),
is_active = bool(row["is_active"]), is_active=bool(row["is_active"]),
is_default = bool(row["is_default"]), is_default=bool(row["is_default"]),
fallback_language = row["fallback_language"], fallback_language=row["fallback_language"],
date_format = row["date_format"], date_format=row["date_format"],
time_format = row["time_format"], time_format=row["time_format"],
datetime_format = row["datetime_format"], datetime_format=row["datetime_format"],
number_format = row["number_format"], number_format=row["number_format"],
currency_format = row["currency_format"], currency_format=row["currency_format"],
first_day_of_week = row["first_day_of_week"], first_day_of_week=row["first_day_of_week"],
calendar_type = row["calendar_type"], calendar_type=row["calendar_type"],
) )
def _row_to_data_center(self, row: sqlite3.Row) -> DataCenter: def _row_to_data_center(self, row: sqlite3.Row) -> DataCenter:
return DataCenter( return DataCenter(
id = row["id"], id=row["id"],
region_code = row["region_code"], region_code=row["region_code"],
name = row["name"], name=row["name"],
location = row["location"], location=row["location"],
endpoint = row["endpoint"], endpoint=row["endpoint"],
status = row["status"], status=row["status"],
priority = row["priority"], priority=row["priority"],
supported_regions = json.loads(row["supported_regions"] or "[]"), supported_regions=json.loads(row["supported_regions"] or "[]"),
capabilities = json.loads(row["capabilities"] or "{}"), capabilities=json.loads(row["capabilities"] or "{}"),
created_at = ( created_at=(
datetime.fromisoformat(row["created_at"]) datetime.fromisoformat(row["created_at"])
if isinstance(row["created_at"], str) if isinstance(row["created_at"], str)
else row["created_at"] else row["created_at"]
), ),
updated_at = ( updated_at=(
datetime.fromisoformat(row["updated_at"]) datetime.fromisoformat(row["updated_at"])
if isinstance(row["updated_at"], str) if isinstance(row["updated_at"], str)
else row["updated_at"] else row["updated_at"]
@@ -1604,18 +1604,18 @@ class LocalizationManager:
def _row_to_tenant_dc_mapping(self, row: sqlite3.Row) -> TenantDataCenterMapping: def _row_to_tenant_dc_mapping(self, row: sqlite3.Row) -> TenantDataCenterMapping:
return TenantDataCenterMapping( return TenantDataCenterMapping(
id = row["id"], id=row["id"],
tenant_id = row["tenant_id"], tenant_id=row["tenant_id"],
primary_dc_id = row["primary_dc_id"], primary_dc_id=row["primary_dc_id"],
secondary_dc_id = row["secondary_dc_id"], secondary_dc_id=row["secondary_dc_id"],
region_code = row["region_code"], region_code=row["region_code"],
data_residency = row["data_residency"], data_residency=row["data_residency"],
created_at = ( created_at=(
datetime.fromisoformat(row["created_at"]) datetime.fromisoformat(row["created_at"])
if isinstance(row["created_at"], str) if isinstance(row["created_at"], str)
else row["created_at"] else row["created_at"]
), ),
updated_at = ( updated_at=(
datetime.fromisoformat(row["updated_at"]) datetime.fromisoformat(row["updated_at"])
if isinstance(row["updated_at"], str) if isinstance(row["updated_at"], str)
else row["updated_at"] else row["updated_at"]
@@ -1624,24 +1624,24 @@ class LocalizationManager:
def _row_to_payment_method(self, row: sqlite3.Row) -> LocalizedPaymentMethod: def _row_to_payment_method(self, row: sqlite3.Row) -> LocalizedPaymentMethod:
return LocalizedPaymentMethod( return LocalizedPaymentMethod(
id = row["id"], id=row["id"],
provider = row["provider"], provider=row["provider"],
name = row["name"], name=row["name"],
name_local = json.loads(row["name_local"] or "{}"), name_local=json.loads(row["name_local"] or "{}"),
supported_countries = json.loads(row["supported_countries"] or "[]"), supported_countries=json.loads(row["supported_countries"] or "[]"),
supported_currencies = json.loads(row["supported_currencies"] or "[]"), supported_currencies=json.loads(row["supported_currencies"] or "[]"),
is_active = bool(row["is_active"]), is_active=bool(row["is_active"]),
config = json.loads(row["config"] or "{}"), config=json.loads(row["config"] or "{}"),
icon_url = row["icon_url"], icon_url=row["icon_url"],
display_order = row["display_order"], display_order=row["display_order"],
min_amount = row["min_amount"], min_amount=row["min_amount"],
max_amount = row["max_amount"], max_amount=row["max_amount"],
created_at = ( created_at=(
datetime.fromisoformat(row["created_at"]) datetime.fromisoformat(row["created_at"])
if isinstance(row["created_at"], str) if isinstance(row["created_at"], str)
else row["created_at"] else row["created_at"]
), ),
updated_at = ( updated_at=(
datetime.fromisoformat(row["updated_at"]) datetime.fromisoformat(row["updated_at"])
if isinstance(row["updated_at"], str) if isinstance(row["updated_at"], str)
else row["updated_at"] else row["updated_at"]
@@ -1650,48 +1650,48 @@ class LocalizationManager:
def _row_to_country_config(self, row: sqlite3.Row) -> CountryConfig: def _row_to_country_config(self, row: sqlite3.Row) -> CountryConfig:
return CountryConfig( return CountryConfig(
code = row["code"], code=row["code"],
code3 = row["code3"], code3=row["code3"],
name = row["name"], name=row["name"],
name_local = json.loads(row["name_local"] or "{}"), name_local=json.loads(row["name_local"] or "{}"),
region = row["region"], region=row["region"],
default_language = row["default_language"], default_language=row["default_language"],
supported_languages = json.loads(row["supported_languages"] or "[]"), supported_languages=json.loads(row["supported_languages"] or "[]"),
default_currency = row["default_currency"], default_currency=row["default_currency"],
supported_currencies = json.loads(row["supported_currencies"] or "[]"), supported_currencies=json.loads(row["supported_currencies"] or "[]"),
timezone = row["timezone"], timezone=row["timezone"],
calendar_type = row["calendar_type"], calendar_type=row["calendar_type"],
date_format = row["date_format"], date_format=row["date_format"],
time_format = row["time_format"], time_format=row["time_format"],
number_format = row["number_format"], number_format=row["number_format"],
address_format = row["address_format"], address_format=row["address_format"],
phone_format = row["phone_format"], phone_format=row["phone_format"],
vat_rate = row["vat_rate"], vat_rate=row["vat_rate"],
is_active = bool(row["is_active"]), is_active=bool(row["is_active"]),
) )
def _row_to_localization_settings(self, row: sqlite3.Row) -> LocalizationSettings: def _row_to_localization_settings(self, row: sqlite3.Row) -> LocalizationSettings:
return LocalizationSettings( return LocalizationSettings(
id = row["id"], id=row["id"],
tenant_id = row["tenant_id"], tenant_id=row["tenant_id"],
default_language = row["default_language"], default_language=row["default_language"],
supported_languages = json.loads(row["supported_languages"] or '["en"]'), supported_languages=json.loads(row["supported_languages"] or '["en"]'),
default_currency = row["default_currency"], default_currency=row["default_currency"],
supported_currencies = json.loads(row["supported_currencies"] or '["USD"]'), supported_currencies=json.loads(row["supported_currencies"] or '["USD"]'),
default_timezone = row["default_timezone"], default_timezone=row["default_timezone"],
default_date_format = row["default_date_format"], default_date_format=row["default_date_format"],
default_time_format = row["default_time_format"], default_time_format=row["default_time_format"],
default_number_format = row["default_number_format"], default_number_format=row["default_number_format"],
calendar_type = row["calendar_type"], calendar_type=row["calendar_type"],
first_day_of_week = row["first_day_of_week"], first_day_of_week=row["first_day_of_week"],
region_code = row["region_code"], region_code=row["region_code"],
data_residency = row["data_residency"], data_residency=row["data_residency"],
created_at = ( created_at=(
datetime.fromisoformat(row["created_at"]) datetime.fromisoformat(row["created_at"])
if isinstance(row["created_at"], str) if isinstance(row["created_at"], str)
else row["created_at"] else row["created_at"]
), ),
updated_at = ( updated_at=(
datetime.fromisoformat(row["updated_at"]) datetime.fromisoformat(row["updated_at"])
if isinstance(row["updated_at"], str) if isinstance(row["updated_at"], str)
else row["updated_at"] else row["updated_at"]

View File

@@ -1166,7 +1166,7 @@ async def create_manual_entity(
start_pos=entity.start_pos, start_pos=entity.start_pos,
end_pos=entity.end_pos, end_pos=entity.end_pos,
text_snippet=text[ text_snippet=text[
max(0, entity.start_pos - 20) : min(len(text), entity.end_pos + 20) max(0, entity.start_pos - 20): min(len(text), entity.end_pos + 20)
], ],
confidence=1.0, confidence=1.0,
) )
@@ -1408,7 +1408,7 @@ async def upload_audio(project_id: str, file: UploadFile = File(...), _=Depends(
start_pos=pos, start_pos=pos,
end_pos=pos + len(name), end_pos=pos + len(name),
text_snippet=full_text[ text_snippet=full_text[
max(0, pos - 20) : min(len(full_text), pos + len(name) + 20) max(0, pos - 20): min(len(full_text), pos + len(name) + 20)
], ],
confidence=1.0, confidence=1.0,
) )
@@ -1534,7 +1534,7 @@ async def upload_document(project_id: str, file: UploadFile = File(...), _=Depen
start_pos=pos, start_pos=pos,
end_pos=pos + len(name), end_pos=pos + len(name),
text_snippet=full_text[ text_snippet=full_text[
max(0, pos - 20) : min(len(full_text), pos + len(name) + 20) max(0, pos - 20): min(len(full_text), pos + len(name) + 20)
], ],
confidence=1.0, confidence=1.0,
) )
@@ -3804,10 +3804,10 @@ async def system_status():
# ==================== Phase 7: Workflow Automation Endpoints ==================== # ==================== Phase 7: Workflow Automation Endpoints ====================
# Workflow Manager singleton # Workflow Manager singleton
_workflow_manager: "WorkflowManager | None" = None _workflow_manager: Any = None
def get_workflow_manager_instance() -> "WorkflowManager | None": def get_workflow_manager_instance() -> Any:
global _workflow_manager global _workflow_manager
if _workflow_manager is None and WORKFLOW_AVAILABLE and DB_AVAILABLE: if _workflow_manager is None and WORKFLOW_AVAILABLE and DB_AVAILABLE:
from workflow_manager import WorkflowManager from workflow_manager import WorkflowManager

View File

@@ -200,11 +200,11 @@ class MultimodalEntityLinker:
if best_match: if best_match:
return AlignmentResult( return AlignmentResult(
entity_id = query_entity.get("id"), entity_id=query_entity.get("id"),
matched_entity_id = best_match.get("id"), matched_entity_id=best_match.get("id"),
similarity = best_similarity, similarity=best_similarity,
match_type = best_match_type, match_type=best_match_type,
confidence = best_similarity, confidence=best_similarity,
) )
return None return None
@@ -255,15 +255,15 @@ class MultimodalEntityLinker:
if result and result.matched_entity_id: if result and result.matched_entity_id:
link = EntityLink( link = EntityLink(
id = str(uuid.uuid4())[:UUID_LENGTH], id=str(uuid.uuid4())[:UUID_LENGTH],
project_id = project_id, project_id=project_id,
source_entity_id = ent1.get("id"), source_entity_id=ent1.get("id"),
target_entity_id = result.matched_entity_id, target_entity_id=result.matched_entity_id,
link_type = "same_as" if result.similarity > 0.95 else "related_to", link_type="same_as" if result.similarity > 0.95 else "related_to",
source_modality = mod1, source_modality=mod1,
target_modality = mod2, target_modality=mod2,
confidence = result.confidence, confidence=result.confidence,
evidence = f"Cross-modal alignment: {result.match_type}", evidence=f"Cross-modal alignment: {result.match_type}",
) )
links.append(link) links.append(link)
@@ -319,7 +319,7 @@ class MultimodalEntityLinker:
# 选择最佳定义(最长的那个) # 选择最佳定义(最长的那个)
best_definition = ( best_definition = (
max(fused_properties["definitions"], key = len) if fused_properties["definitions"] else "" max(fused_properties["definitions"], key=len) if fused_properties["definitions"] else ""
) )
# 选择最佳名称(最常见的那个) # 选择最佳名称(最常见的那个)
@@ -330,9 +330,9 @@ class MultimodalEntityLinker:
# 构建融合结果 # 构建融合结果
return FusionResult( return FusionResult(
canonical_entity_id = entity_id, canonical_entity_id=entity_id,
merged_entity_ids = merged_ids, merged_entity_ids=merged_ids,
fused_properties = { fused_properties={
"name": best_name, "name": best_name,
"definition": best_definition, "definition": best_definition,
"aliases": list(fused_properties["aliases"]), "aliases": list(fused_properties["aliases"]),
@@ -340,8 +340,8 @@ class MultimodalEntityLinker:
"modalities": list(fused_properties["modalities"]), "modalities": list(fused_properties["modalities"]),
"contexts": fused_properties["contexts"][:10], # 最多10个上下文 "contexts": fused_properties["contexts"][:10], # 最多10个上下文
}, },
source_modalities = list(fused_properties["modalities"]), source_modalities=list(fused_properties["modalities"]),
confidence = min(1.0, len(linked_entities) * 0.2 + 0.5), confidence=min(1.0, len(linked_entities) * 0.2 + 0.5),
) )
def detect_entity_conflicts(self, entities: list[dict]) -> list[dict]: def detect_entity_conflicts(self, entities: list[dict]) -> list[dict]:
@@ -441,7 +441,7 @@ class MultimodalEntityLinker:
) )
# 按相似度排序 # 按相似度排序
suggestions.sort(key = lambda x: x["similarity"], reverse = True) suggestions.sort(key=lambda x: x["similarity"], reverse=True)
return suggestions return suggestions
@@ -469,14 +469,14 @@ class MultimodalEntityLinker:
多模态实体记录 多模态实体记录
""" """
return MultimodalEntity( return MultimodalEntity(
id = str(uuid.uuid4())[:UUID_LENGTH], id=str(uuid.uuid4())[:UUID_LENGTH],
entity_id = entity_id, entity_id=entity_id,
project_id = project_id, project_id=project_id,
name = "", # 将在后续填充 name="", # 将在后续填充
source_type = source_type, source_type=source_type,
source_id = source_id, source_id=source_id,
mention_context = mention_context, mention_context=mention_context,
confidence = confidence, confidence=confidence,
) )
def analyze_modality_distribution(self, multimodal_entities: list[MultimodalEntity]) -> dict: def analyze_modality_distribution(self, multimodal_entities: list[MultimodalEntity]) -> dict:

View File

@@ -112,9 +112,9 @@ class MultimodalProcessor:
self.audio_dir = os.path.join(self.temp_dir, "audio") self.audio_dir = os.path.join(self.temp_dir, "audio")
# 创建目录 # 创建目录
os.makedirs(self.video_dir, exist_ok = True) os.makedirs(self.video_dir, exist_ok=True)
os.makedirs(self.frames_dir, exist_ok = True) os.makedirs(self.frames_dir, exist_ok=True)
os.makedirs(self.audio_dir, exist_ok = True) os.makedirs(self.audio_dir, exist_ok=True)
def extract_video_info(self, video_path: str) -> dict: def extract_video_info(self, video_path: str) -> dict:
""" """
@@ -159,7 +159,7 @@ class MultimodalProcessor:
"json", "json",
video_path, video_path,
] ]
result = subprocess.run(cmd, capture_output = True, text = True) result = subprocess.run(cmd, capture_output=True, text=True)
if result.returncode == 0: if result.returncode == 0:
data = json.loads(result.stdout) data = json.loads(result.stdout)
return { return {
@@ -196,9 +196,9 @@ class MultimodalProcessor:
if FFMPEG_AVAILABLE: if FFMPEG_AVAILABLE:
( (
ffmpeg.input(video_path) ffmpeg.input(video_path)
.output(output_path, ac = 1, ar = 16000, vn = None) .output(output_path, ac=1, ar=16000, vn=None)
.overwrite_output() .overwrite_output()
.run(quiet = True) .run(quiet=True)
) )
else: else:
# 使用命令行 ffmpeg # 使用命令行 ffmpeg
@@ -216,7 +216,7 @@ class MultimodalProcessor:
"-y", "-y",
output_path, output_path,
] ]
subprocess.run(cmd, check = True, capture_output = True) subprocess.run(cmd, check=True, capture_output=True)
return output_path return output_path
except Exception as e: except Exception as e:
@@ -240,7 +240,7 @@ class MultimodalProcessor:
# 创建帧存储目录 # 创建帧存储目录
video_frames_dir = os.path.join(self.frames_dir, video_id) video_frames_dir = os.path.join(self.frames_dir, video_id)
os.makedirs(video_frames_dir, exist_ok = True) os.makedirs(video_frames_dir, exist_ok=True)
try: try:
if CV2_AVAILABLE: if CV2_AVAILABLE:
@@ -284,7 +284,7 @@ class MultimodalProcessor:
"-y", "-y",
output_pattern, output_pattern,
] ]
subprocess.run(cmd, check = True, capture_output = True) subprocess.run(cmd, check=True, capture_output=True)
# 获取生成的帧文件列表 # 获取生成的帧文件列表
frame_paths = sorted( frame_paths = sorted(
@@ -320,10 +320,10 @@ class MultimodalProcessor:
image = image.convert("L") image = image.convert("L")
# 使用 pytesseract 进行 OCR # 使用 pytesseract 进行 OCR
text = pytesseract.image_to_string(image, lang = "chi_sim+eng") text = pytesseract.image_to_string(image, lang="chi_sim+eng")
# 获取置信度数据 # 获取置信度数据
data = pytesseract.image_to_data(image, output_type = pytesseract.Output.DICT) data = pytesseract.image_to_data(image, output_type=pytesseract.Output.DICT)
confidences = [int(c) for c in data["conf"] if int(c) > 0] confidences = [int(c) for c in data["conf"] if int(c) > 0]
avg_confidence = sum(confidences) / len(confidences) if confidences else 0 avg_confidence = sum(confidences) / len(confidences) if confidences else 0
@@ -382,13 +382,13 @@ class MultimodalProcessor:
ocr_text, confidence = self.perform_ocr(frame_path) ocr_text, confidence = self.perform_ocr(frame_path)
frame = VideoFrame( frame = VideoFrame(
id = str(uuid.uuid4())[:UUID_LENGTH], id=str(uuid.uuid4())[:UUID_LENGTH],
video_id = video_id, video_id=video_id,
frame_number = frame_number, frame_number=frame_number,
timestamp = timestamp, timestamp=timestamp,
frame_path = frame_path, frame_path=frame_path,
ocr_text = ocr_text, ocr_text=ocr_text,
ocr_confidence = confidence, ocr_confidence=confidence,
) )
frames.append(frame) frames.append(frame)
@@ -407,23 +407,23 @@ class MultimodalProcessor:
full_ocr_text = "\n\n".join(all_ocr_text) full_ocr_text = "\n\n".join(all_ocr_text)
return VideoProcessingResult( return VideoProcessingResult(
video_id = video_id, video_id=video_id,
audio_path = audio_path, audio_path=audio_path,
frames = frames, frames=frames,
ocr_results = ocr_results, ocr_results=ocr_results,
full_text = full_ocr_text, full_text=full_ocr_text,
success = True, success=True,
) )
except Exception as e: except Exception as e:
return VideoProcessingResult( return VideoProcessingResult(
video_id = video_id, video_id=video_id,
audio_path = "", audio_path="",
frames = [], frames=[],
ocr_results = [], ocr_results=[],
full_text = "", full_text="",
success = False, success=False,
error_message = str(e), error_message=str(e),
) )
def cleanup(self, video_id: str = None) -> None: def cleanup(self, video_id: str = None) -> None:
@@ -450,7 +450,7 @@ class MultimodalProcessor:
for dir_path in [self.video_dir, self.frames_dir, self.audio_dir]: for dir_path in [self.video_dir, self.frames_dir, self.audio_dir]:
if os.path.exists(dir_path): if os.path.exists(dir_path):
shutil.rmtree(dir_path) shutil.rmtree(dir_path)
os.makedirs(dir_path, exist_ok = True) os.makedirs(dir_path, exist_ok=True)
# Singleton instance # Singleton instance

View File

@@ -113,7 +113,7 @@ class Neo4jManager:
return return
try: try:
self._driver = GraphDatabase.driver(self.uri, auth = (self.user, self.password)) self._driver = GraphDatabase.driver(self.uri, auth=(self.user, self.password))
# 验证连接 # 验证连接
self._driver.verify_connectivity() self._driver.verify_connectivity()
logger.info(f"Connected to Neo4j at {self.uri}") logger.info(f"Connected to Neo4j at {self.uri}")
@@ -193,9 +193,9 @@ class Neo4jManager:
p.description = $description, p.description = $description,
p.updated_at = datetime() p.updated_at = datetime()
""", """,
project_id = project_id, project_id=project_id,
name = project_name, name=project_name,
description = project_description, description=project_description,
) )
def sync_entity(self, entity: GraphEntity) -> None: def sync_entity(self, entity: GraphEntity) -> None:
@@ -218,13 +218,13 @@ class Neo4jManager:
MATCH (p:Project {id: $project_id}) MATCH (p:Project {id: $project_id})
MERGE (e)-[:BELONGS_TO]->(p) MERGE (e)-[:BELONGS_TO]->(p)
""", """,
id = entity.id, id=entity.id,
project_id = entity.project_id, project_id=entity.project_id,
name = entity.name, name=entity.name,
type = entity.type, type=entity.type,
definition = entity.definition, definition=entity.definition,
aliases = json.dumps(entity.aliases), aliases=json.dumps(entity.aliases),
properties = json.dumps(entity.properties), properties=json.dumps(entity.properties),
) )
def sync_entities_batch(self, entities: list[GraphEntity]) -> None: def sync_entities_batch(self, entities: list[GraphEntity]) -> None:
@@ -261,7 +261,7 @@ class Neo4jManager:
MATCH (p:Project {id: entity.project_id}) MATCH (p:Project {id: entity.project_id})
MERGE (e)-[:BELONGS_TO]->(p) MERGE (e)-[:BELONGS_TO]->(p)
""", """,
entities = entities_data, entities=entities_data,
) )
def sync_relation(self, relation: GraphRelation) -> None: def sync_relation(self, relation: GraphRelation) -> None:
@@ -280,12 +280,12 @@ class Neo4jManager:
r.properties = $properties, r.properties = $properties,
r.updated_at = datetime() r.updated_at = datetime()
""", """,
id = relation.id, id=relation.id,
source_id = relation.source_id, source_id=relation.source_id,
target_id = relation.target_id, target_id=relation.target_id,
relation_type = relation.relation_type, relation_type=relation.relation_type,
evidence = relation.evidence, evidence=relation.evidence,
properties = json.dumps(relation.properties), properties=json.dumps(relation.properties),
) )
def sync_relations_batch(self, relations: list[GraphRelation]) -> None: def sync_relations_batch(self, relations: list[GraphRelation]) -> None:
@@ -317,7 +317,7 @@ class Neo4jManager:
r.properties = rel.properties, r.properties = rel.properties,
r.updated_at = datetime() r.updated_at = datetime()
""", """,
relations = relations_data, relations=relations_data,
) )
def delete_entity(self, entity_id: str) -> None: def delete_entity(self, entity_id: str) -> None:
@@ -331,7 +331,7 @@ class Neo4jManager:
MATCH (e:Entity {id: $id}) MATCH (e:Entity {id: $id})
DETACH DELETE e DETACH DELETE e
""", """,
id = entity_id, id=entity_id,
) )
def delete_project(self, project_id: str) -> None: def delete_project(self, project_id: str) -> None:
@@ -346,7 +346,7 @@ class Neo4jManager:
OPTIONAL MATCH (e:Entity)-[:BELONGS_TO]->(p) OPTIONAL MATCH (e:Entity)-[:BELONGS_TO]->(p)
DETACH DELETE e, p DETACH DELETE e, p
""", """,
id = project_id, id=project_id,
) )
# ==================== 复杂图查询 ==================== # ==================== 复杂图查询 ====================
@@ -376,9 +376,9 @@ class Neo4jManager:
) )
RETURN path RETURN path
""", """,
source_id = source_id, source_id=source_id,
target_id = target_id, target_id=target_id,
max_depth = max_depth, max_depth=max_depth,
) )
record = result.single() record = result.single()
@@ -404,7 +404,7 @@ class Neo4jManager:
] ]
return PathResult( return PathResult(
nodes = nodes, relationships = relationships, length = len(path.relationships) nodes=nodes, relationships=relationships, length=len(path.relationships)
) )
def find_all_paths( def find_all_paths(
@@ -433,10 +433,10 @@ class Neo4jManager:
RETURN path RETURN path
LIMIT $limit LIMIT $limit
""", """,
source_id = source_id, source_id=source_id,
target_id = target_id, target_id=target_id,
max_depth = max_depth, max_depth=max_depth,
limit = limit, limit=limit,
) )
paths = [] paths = []
@@ -460,7 +460,7 @@ class Neo4jManager:
paths.append( paths.append(
PathResult( PathResult(
nodes = nodes, relationships = relationships, length = len(path.relationships) nodes=nodes, relationships=relationships, length=len(path.relationships)
) )
) )
@@ -491,9 +491,9 @@ class Neo4jManager:
RETURN neighbor, r.relation_type as rel_type, r.evidence as evidence RETURN neighbor, r.relation_type as rel_type, r.evidence as evidence
LIMIT $limit LIMIT $limit
""", """,
entity_id = entity_id, entity_id=entity_id,
relation_type = relation_type, relation_type=relation_type,
limit = limit, limit=limit,
) )
else: else:
result = session.run( result = session.run(
@@ -502,8 +502,8 @@ class Neo4jManager:
RETURN neighbor, r.relation_type as rel_type, r.evidence as evidence RETURN neighbor, r.relation_type as rel_type, r.evidence as evidence
LIMIT $limit LIMIT $limit
""", """,
entity_id = entity_id, entity_id=entity_id,
limit = limit, limit=limit,
) )
neighbors = [] neighbors = []
@@ -541,8 +541,8 @@ class Neo4jManager:
MATCH (e1:Entity {id: $id1})-[:RELATES_TO]-(common:Entity)-[:RELATES_TO]-(e2:Entity {id: $id2}) MATCH (e1:Entity {id: $id1})-[:RELATES_TO]-(common:Entity)-[:RELATES_TO]-(e2:Entity {id: $id2})
RETURN DISTINCT common RETURN DISTINCT common
""", """,
id1 = entity_id1, id1=entity_id1,
id2 = entity_id2, id2=entity_id2,
) )
return [ return [
@@ -581,7 +581,7 @@ class Neo4jManager:
{} {}
) YIELD value RETURN value ) YIELD value RETURN value
""", """,
project_id = project_id, project_id=project_id,
) )
# 创建临时图 # 创建临时图
@@ -601,7 +601,7 @@ class Neo4jManager:
} }
) )
""", """,
project_id = project_id, project_id=project_id,
) )
# 运行 PageRank # 运行 PageRank
@@ -615,8 +615,8 @@ class Neo4jManager:
ORDER BY score DESC ORDER BY score DESC
LIMIT $top_n LIMIT $top_n
""", """,
project_id = project_id, project_id=project_id,
top_n = top_n, top_n=top_n,
) )
rankings = [] rankings = []
@@ -624,10 +624,10 @@ class Neo4jManager:
for record in result: for record in result:
rankings.append( rankings.append(
CentralityResult( CentralityResult(
entity_id = record["entity_id"], entity_id=record["entity_id"],
entity_name = record["entity_name"], entity_name=record["entity_name"],
score = record["score"], score=record["score"],
rank = rank, rank=rank,
) )
) )
rank += 1 rank += 1
@@ -637,7 +637,7 @@ class Neo4jManager:
""" """
CALL gds.graph.drop('project-graph-$project_id') CALL gds.graph.drop('project-graph-$project_id')
""", """,
project_id = project_id, project_id=project_id,
) )
return rankings return rankings
@@ -667,8 +667,8 @@ class Neo4jManager:
LIMIT $top_n LIMIT $top_n
RETURN e.id as entity_id, e.name as entity_name, degree as score RETURN e.id as entity_id, e.name as entity_name, degree as score
""", """,
project_id = project_id, project_id=project_id,
top_n = top_n, top_n=top_n,
) )
rankings = [] rankings = []
@@ -676,10 +676,10 @@ class Neo4jManager:
for record in result: for record in result:
rankings.append( rankings.append(
CentralityResult( CentralityResult(
entity_id = record["entity_id"], entity_id=record["entity_id"],
entity_name = record["entity_name"], entity_name=record["entity_name"],
score = float(record["score"]), score=float(record["score"]),
rank = rank, rank=rank,
) )
) )
rank += 1 rank += 1
@@ -710,7 +710,7 @@ class Neo4jManager:
connections, size(connections) as connection_count connections, size(connections) as connection_count
ORDER BY connection_count DESC ORDER BY connection_count DESC
""", """,
project_id = project_id, project_id=project_id,
) )
# 手动分组(基于连通性) # 手动分组(基于连通性)
@@ -752,12 +752,12 @@ class Neo4jManager:
results.append( results.append(
CommunityResult( CommunityResult(
community_id = comm_id, nodes = nodes, size = size, density = min(density, 1.0) community_id=comm_id, nodes=nodes, size=size, density=min(density, 1.0)
) )
) )
# 按大小排序 # 按大小排序
results.sort(key = lambda x: x.size, reverse = True) results.sort(key=lambda x: x.size, reverse=True)
return results return results
def find_central_entities( def find_central_entities(
@@ -787,7 +787,7 @@ class Neo4jManager:
ORDER BY degree DESC ORDER BY degree DESC
LIMIT 20 LIMIT 20
""", """,
project_id = project_id, project_id=project_id,
) )
else: else:
# 默认使用度中心性 # 默认使用度中心性
@@ -800,7 +800,7 @@ class Neo4jManager:
ORDER BY degree DESC ORDER BY degree DESC
LIMIT 20 LIMIT 20
""", """,
project_id = project_id, project_id=project_id,
) )
rankings = [] rankings = []
@@ -808,10 +808,10 @@ class Neo4jManager:
for record in result: for record in result:
rankings.append( rankings.append(
CentralityResult( CentralityResult(
entity_id = record["entity_id"], entity_id=record["entity_id"],
entity_name = record["entity_name"], entity_name=record["entity_name"],
score = float(record["score"]), score=float(record["score"]),
rank = rank, rank=rank,
) )
) )
rank += 1 rank += 1
@@ -840,7 +840,7 @@ class Neo4jManager:
MATCH (e:Entity)-[:BELONGS_TO]->(p:Project {id: $project_id}) MATCH (e:Entity)-[:BELONGS_TO]->(p:Project {id: $project_id})
RETURN count(e) as count RETURN count(e) as count
""", """,
project_id = project_id, project_id=project_id,
).single()["count"] ).single()["count"]
# 关系数量 # 关系数量
@@ -850,7 +850,7 @@ class Neo4jManager:
MATCH (e)-[r:RELATES_TO]-() MATCH (e)-[r:RELATES_TO]-()
RETURN count(r) as count RETURN count(r) as count
""", """,
project_id = project_id, project_id=project_id,
).single()["count"] ).single()["count"]
# 实体类型分布 # 实体类型分布
@@ -860,7 +860,7 @@ class Neo4jManager:
RETURN e.type as type, count(e) as count RETURN e.type as type, count(e) as count
ORDER BY count DESC ORDER BY count DESC
""", """,
project_id = project_id, project_id=project_id,
) )
types = {record["type"]: record["count"] for record in type_distribution} types = {record["type"]: record["count"] for record in type_distribution}
@@ -873,7 +873,7 @@ class Neo4jManager:
WITH e, count(other) as degree WITH e, count(other) as degree
RETURN avg(degree) as avg_degree RETURN avg(degree) as avg_degree
""", """,
project_id = project_id, project_id=project_id,
).single()["avg_degree"] ).single()["avg_degree"]
# 关系类型分布 # 关系类型分布
@@ -885,7 +885,7 @@ class Neo4jManager:
ORDER BY count DESC ORDER BY count DESC
LIMIT 10 LIMIT 10
""", """,
project_id = project_id, project_id=project_id,
) )
relation_types = {record["type"]: record["count"] for record in rel_types} relation_types = {record["type"]: record["count"] for record in rel_types}
@@ -927,8 +927,8 @@ class Neo4jManager:
}) YIELD node }) YIELD node
RETURN DISTINCT node RETURN DISTINCT node
""", """,
entity_ids = entity_ids, entity_ids=entity_ids,
depth = depth, depth=depth,
) )
nodes = [] nodes = []
@@ -953,7 +953,7 @@ class Neo4jManager:
RETURN source.id as source_id, target.id as target_id, RETURN source.id as source_id, target.id as target_id,
r.relation_type as type, r.evidence as evidence r.relation_type as type, r.evidence as evidence
""", """,
node_ids = list(node_ids), node_ids=list(node_ids),
) )
relationships = [ relationships = [
@@ -1015,13 +1015,13 @@ def sync_project_to_neo4j(
# 同步实体 # 同步实体
graph_entities = [ graph_entities = [
GraphEntity( GraphEntity(
id = e["id"], id=e["id"],
project_id = project_id, project_id=project_id,
name = e["name"], name=e["name"],
type = e.get("type", "unknown"), type=e.get("type", "unknown"),
definition = e.get("definition", ""), definition=e.get("definition", ""),
aliases = e.get("aliases", []), aliases=e.get("aliases", []),
properties = e.get("properties", {}), properties=e.get("properties", {}),
) )
for e in entities for e in entities
] ]
@@ -1030,12 +1030,12 @@ def sync_project_to_neo4j(
# 同步关系 # 同步关系
graph_relations = [ graph_relations = [
GraphRelation( GraphRelation(
id = r["id"], id=r["id"],
source_id = r["source_entity_id"], source_id=r["source_entity_id"],
target_id = r["target_entity_id"], target_id=r["target_entity_id"],
relation_type = r["relation_type"], relation_type=r["relation_type"],
evidence = r.get("evidence", ""), evidence=r.get("evidence", ""),
properties = r.get("properties", {}), properties=r.get("properties", {}),
) )
for r in relations for r in relations
] ]
@@ -1048,7 +1048,7 @@ def sync_project_to_neo4j(
if __name__ == "__main__": if __name__ == "__main__":
# 测试代码 # 测试代码
logging.basicConfig(level = logging.INFO) logging.basicConfig(level=logging.INFO)
manager = Neo4jManager() manager = Neo4jManager()
@@ -1065,11 +1065,11 @@ if __name__ == "__main__":
# 测试实体 # 测试实体
test_entity = GraphEntity( test_entity = GraphEntity(
id = "test-entity-1", id="test-entity-1",
project_id = "test-project", project_id="test-project",
name = "Test Entity", name="Test Entity",
type = "Person", type="Person",
definition = "A test entity", definition="A test entity",
) )
manager.sync_entity(test_entity) manager.sync_entity(test_entity)
print("✅ Entity synced") print("✅ Entity synced")

File diff suppressed because it is too large Load Diff

View File

@@ -82,7 +82,7 @@ class PerformanceMetric:
endpoint: str | None endpoint: str | None
duration_ms: float duration_ms: float
timestamp: str timestamp: str
metadata: dict = field(default_factory = dict) metadata: dict = field(default_factory=dict)
def to_dict(self) -> dict: def to_dict(self) -> dict:
return { return {
@@ -176,7 +176,7 @@ class CacheManager:
if REDIS_AVAILABLE and redis_url: if REDIS_AVAILABLE and redis_url:
try: try:
self.redis_client = redis.from_url(redis_url, decode_responses = True) self.redis_client = redis.from_url(redis_url, decode_responses=True)
self.redis_client.ping() self.redis_client.ping()
self.use_redis = True self.use_redis = True
print(f"Redis 缓存已连接: {redis_url}") print(f"Redis 缓存已连接: {redis_url}")
@@ -233,7 +233,7 @@ class CacheManager:
def _get_entry_size(self, value: Any) -> int: def _get_entry_size(self, value: Any) -> int:
"""估算缓存条目大小""" """估算缓存条目大小"""
try: try:
return len(json.dumps(value, ensure_ascii = False).encode("utf-8")) return len(json.dumps(value, ensure_ascii=False).encode("utf-8"))
except (TypeError, ValueError): except (TypeError, ValueError):
return 1024 # 默认估算 return 1024 # 默认估算
@@ -245,7 +245,7 @@ class CacheManager:
and self.memory_cache and self.memory_cache
): ):
# 移除最久未访问的 # 移除最久未访问的
oldest_key, oldest_entry = self.memory_cache.popitem(last = False) oldest_key, oldest_entry = self.memory_cache.popitem(last=False)
self.current_memory_size -= oldest_entry.size_bytes self.current_memory_size -= oldest_entry.size_bytes
self.stats.evictions += 1 self.stats.evictions += 1
@@ -314,7 +314,7 @@ class CacheManager:
if self.use_redis: if self.use_redis:
try: try:
serialized = json.dumps(value, ensure_ascii = False) serialized = json.dumps(value, ensure_ascii=False)
self.redis_client.setex(key, ttl, serialized) self.redis_client.setex(key, ttl, serialized)
return True return True
except Exception as e: except Exception as e:
@@ -331,12 +331,12 @@ class CacheManager:
now = time.time() now = time.time()
entry = CacheEntry( entry = CacheEntry(
key = key, key=key,
value = value, value=value,
created_at = now, created_at=now,
expires_at = now + ttl if ttl > 0 else None, expires_at=now + ttl if ttl > 0 else None,
size_bytes = size, size_bytes=size,
last_accessed = now, last_accessed=now,
) )
# 如果已存在,更新大小 # 如果已存在,更新大小
@@ -412,7 +412,7 @@ class CacheManager:
try: try:
pipe = self.redis_client.pipeline() pipe = self.redis_client.pipeline()
for key, value in mapping.items(): for key, value in mapping.items():
serialized = json.dumps(value, ensure_ascii = False) serialized = json.dumps(value, ensure_ascii=False)
pipe.setex(key, ttl, serialized) pipe.setex(key, ttl, serialized)
pipe.execute() pipe.execute()
return True return True
@@ -505,7 +505,7 @@ class CacheManager:
for entity in entities: for entity in entities:
key = f"entity:{entity['id']}" key = f"entity:{entity['id']}"
self.set(key, dict(entity), ttl = 7200) # 2小时 self.set(key, dict(entity), ttl=7200) # 2小时
stats["entities"] += 1 stats["entities"] += 1
# 预热关系数据 # 预热关系数据
@@ -522,7 +522,7 @@ class CacheManager:
for relation in relations: for relation in relations:
key = f"relation:{relation['id']}" key = f"relation:{relation['id']}"
self.set(key, dict(relation), ttl = 3600) self.set(key, dict(relation), ttl=3600)
stats["relations"] += 1 stats["relations"] += 1
# 预热最近的转录 # 预热最近的转录
@@ -543,7 +543,7 @@ class CacheManager:
"type": transcript.get("type", "audio"), "type": transcript.get("type", "audio"),
"created_at": transcript["created_at"], "created_at": transcript["created_at"],
} }
self.set(key, meta, ttl = 1800) # 30分钟 self.set(key, meta, ttl=1800) # 30分钟
stats["transcripts"] += 1 stats["transcripts"] += 1
# 预热项目知识库摘要 # 预热项目知识库摘要
@@ -561,7 +561,7 @@ class CacheManager:
"relation_count": relation_count, "relation_count": relation_count,
"cached_at": datetime.now().isoformat(), "cached_at": datetime.now().isoformat(),
} }
self.set(f"project_summary:{project_id}", summary, ttl = 3600) self.set(f"project_summary:{project_id}", summary, ttl=3600)
conn.close() conn.close()
@@ -583,7 +583,7 @@ class CacheManager:
try: try:
# 使用 Redis 的 scan 查找相关 key # 使用 Redis 的 scan 查找相关 key
pattern = f"*:{project_id}:*" pattern = f"*:{project_id}:*"
for key in self.redis_client.scan_iter(match = pattern): for key in self.redis_client.scan_iter(match=pattern):
self.redis_client.delete(key) self.redis_client.delete(key)
count += 1 count += 1
except Exception as e: except Exception as e:
@@ -625,7 +625,7 @@ class DatabaseSharding:
self.shards_count = shards_count self.shards_count = shards_count
# 确保分片目录存在 # 确保分片目录存在
os.makedirs(shard_db_dir, exist_ok = True) os.makedirs(shard_db_dir, exist_ok=True)
# 分片映射 # 分片映射
self.shard_map: dict[str, ShardInfo] = {} self.shard_map: dict[str, ShardInfo] = {}
@@ -650,10 +650,10 @@ class DatabaseSharding:
db_path = os.path.join(self.shard_db_dir, f"{shard_id}.db") db_path = os.path.join(self.shard_db_dir, f"{shard_id}.db")
self.shard_map[shard_id] = ShardInfo( self.shard_map[shard_id] = ShardInfo(
shard_id = shard_id, shard_id=shard_id,
shard_key_range = (start_char, end_char), shard_key_range=(start_char, end_char),
db_path = db_path, db_path=db_path,
created_at = datetime.now().isoformat(), created_at=datetime.now().isoformat(),
) )
# 确保分片数据库存在 # 确保分片数据库存在
@@ -934,7 +934,7 @@ class TaskQueue:
# 初始化 Celery # 初始化 Celery
if CELERY_AVAILABLE and redis_url: if CELERY_AVAILABLE and redis_url:
try: try:
self.celery_app = Celery("insightflow", broker = redis_url, backend = redis_url) self.celery_app = Celery("insightflow", broker=redis_url, backend=redis_url)
self.use_celery = True self.use_celery = True
print("Celery 任务队列已初始化") print("Celery 任务队列已初始化")
except Exception as e: except Exception as e:
@@ -989,12 +989,12 @@ class TaskQueue:
task_id = str(uuid.uuid4())[:16] task_id = str(uuid.uuid4())[:16]
task = TaskInfo( task = TaskInfo(
id = task_id, id=task_id,
task_type = task_type, task_type=task_type,
status = "pending", status="pending",
payload = payload, payload=payload,
created_at = datetime.now().isoformat(), created_at=datetime.now().isoformat(),
max_retries = max_retries, max_retries=max_retries,
) )
if self.use_celery: if self.use_celery:
@@ -1003,10 +1003,10 @@ class TaskQueue:
# 这里简化处理,实际应该定义具体的 Celery 任务 # 这里简化处理,实际应该定义具体的 Celery 任务
result = self.celery_app.send_task( result = self.celery_app.send_task(
f"insightflow.tasks.{task_type}", f"insightflow.tasks.{task_type}",
args = [payload], args=[payload],
task_id = task_id, task_id=task_id,
retry = True, retry=True,
retry_policy = { retry_policy={
"max_retries": max_retries, "max_retries": max_retries,
"interval_start": 10, "interval_start": 10,
"interval_step": 10, "interval_step": 10,
@@ -1024,7 +1024,7 @@ class TaskQueue:
with self.task_lock: with self.task_lock:
self.tasks[task_id] = task self.tasks[task_id] = task
# 异步执行 # 异步执行
threading.Thread(target = self._execute_task, args = (task_id, ), daemon = True).start() threading.Thread(target=self._execute_task, args=(task_id, ), daemon=True).start()
# 保存到数据库 # 保存到数据库
self._save_task(task) self._save_task(task)
@@ -1061,7 +1061,7 @@ class TaskQueue:
task.status = "retrying" task.status = "retrying"
# 延迟重试 # 延迟重试
threading.Timer( threading.Timer(
10 * task.retry_count, self._execute_task, args = (task_id, ) 10 * task.retry_count, self._execute_task, args=(task_id, )
).start() ).start()
else: else:
task.status = "failed" task.status = "failed"
@@ -1089,8 +1089,8 @@ class TaskQueue:
task.id, task.id,
task.task_type, task.task_type,
task.status, task.status,
json.dumps(task.payload, ensure_ascii = False), json.dumps(task.payload, ensure_ascii=False),
json.dumps(task.result, ensure_ascii = False) if task.result else None, json.dumps(task.result, ensure_ascii=False) if task.result else None,
task.error_message, task.error_message,
task.retry_count, task.retry_count,
task.max_retries, task.max_retries,
@@ -1120,7 +1120,7 @@ class TaskQueue:
""", """,
( (
task.status, task.status,
json.dumps(task.result, ensure_ascii = False) if task.result else None, json.dumps(task.result, ensure_ascii=False) if task.result else None,
task.error_message, task.error_message,
task.retry_count, task.retry_count,
task.started_at, task.started_at,
@@ -1136,7 +1136,7 @@ class TaskQueue:
"""获取任务状态""" """获取任务状态"""
if self.use_celery: if self.use_celery:
try: try:
result = AsyncResult(task_id, app = self.celery_app) result = AsyncResult(task_id, app=self.celery_app)
status_map = { status_map = {
"PENDING": "pending", "PENDING": "pending",
@@ -1147,13 +1147,13 @@ class TaskQueue:
} }
return TaskInfo( return TaskInfo(
id = task_id, id=task_id,
task_type = "celery_task", task_type="celery_task",
status = status_map.get(result.status, "unknown"), status=status_map.get(result.status, "unknown"),
payload = {}, payload={},
created_at = "", created_at="",
result = result.result if result.successful() else None, result=result.result if result.successful() else None,
error_message = str(result.result) if result.failed() else None, error_message=str(result.result) if result.failed() else None,
) )
except Exception as e: except Exception as e:
print(f"获取 Celery 任务状态失败: {e}") print(f"获取 Celery 任务状态失败: {e}")
@@ -1198,17 +1198,17 @@ class TaskQueue:
for row in rows: for row in rows:
tasks.append( tasks.append(
TaskInfo( TaskInfo(
id = row["id"], id=row["id"],
task_type = row["task_type"], task_type=row["task_type"],
status = row["status"], status=row["status"],
payload = json.loads(row["payload"]) if row["payload"] else {}, payload=json.loads(row["payload"]) if row["payload"] else {},
created_at = row["created_at"], created_at=row["created_at"],
started_at = row["started_at"], started_at=row["started_at"],
completed_at = row["completed_at"], completed_at=row["completed_at"],
result = json.loads(row["result"]) if row["result"] else None, result=json.loads(row["result"]) if row["result"] else None,
error_message = row["error_message"], error_message=row["error_message"],
retry_count = row["retry_count"], retry_count=row["retry_count"],
max_retries = row["max_retries"], max_retries=row["max_retries"],
) )
) )
@@ -1218,7 +1218,7 @@ class TaskQueue:
"""取消任务""" """取消任务"""
if self.use_celery: if self.use_celery:
try: try:
self.celery_app.control.revoke(task_id, terminate = True) self.celery_app.control.revoke(task_id, terminate=True)
return True return True
except Exception as e: except Exception as e:
print(f"取消 Celery 任务失败: {e}") print(f"取消 Celery 任务失败: {e}")
@@ -1248,7 +1248,7 @@ class TaskQueue:
if not self.use_celery: if not self.use_celery:
with self.task_lock: with self.task_lock:
self.tasks[task_id] = task self.tasks[task_id] = task
threading.Thread(target = self._execute_task, args = (task_id, ), daemon = True).start() threading.Thread(target=self._execute_task, args=(task_id, ), daemon=True).start()
self._update_task_status(task) self._update_task_status(task)
return True return True
@@ -1337,12 +1337,12 @@ class PerformanceMonitor:
metadata: 额外元数据 metadata: 额外元数据
""" """
metric = PerformanceMetric( metric = PerformanceMetric(
id = str(uuid.uuid4())[:16], id=str(uuid.uuid4())[:16],
metric_type = metric_type, metric_type=metric_type,
endpoint = endpoint, endpoint=endpoint,
duration_ms = duration_ms, duration_ms=duration_ms,
timestamp = datetime.now().isoformat(), timestamp=datetime.now().isoformat(),
metadata = metadata or {}, metadata=metadata or {},
) )
# 添加到缓冲区 # 添加到缓冲区
@@ -1379,7 +1379,7 @@ class PerformanceMonitor:
metric.endpoint, metric.endpoint,
metric.duration_ms, metric.duration_ms,
metric.timestamp, metric.timestamp,
json.dumps(metric.metadata, ensure_ascii = False), json.dumps(metric.metadata, ensure_ascii=False),
), ),
) )
@@ -1703,13 +1703,13 @@ class PerformanceManager:
self.db_path = db_path self.db_path = db_path
# 初始化各模块 # 初始化各模块
self.cache = CacheManager(redis_url = redis_url, db_path = db_path) self.cache = CacheManager(redis_url=redis_url, db_path=db_path)
self.sharding = DatabaseSharding(base_db_path = db_path) if enable_sharding else None self.sharding = DatabaseSharding(base_db_path=db_path) if enable_sharding else None
self.task_queue = TaskQueue(redis_url = redis_url, db_path = db_path) self.task_queue = TaskQueue(redis_url=redis_url, db_path=db_path)
self.monitor = PerformanceMonitor(db_path = db_path) self.monitor = PerformanceMonitor(db_path=db_path)
def get_health_status(self) -> dict: def get_health_status(self) -> dict:
"""获取系统健康状态""" """获取系统健康状态"""
@@ -1760,6 +1760,6 @@ def get_performance_manager(
global _performance_manager global _performance_manager
if _performance_manager is None: if _performance_manager is None:
_performance_manager = PerformanceManager( _performance_manager = PerformanceManager(
db_path = db_path, redis_url = redis_url, enable_sharding = enable_sharding db_path=db_path, redis_url=redis_url, enable_sharding=enable_sharding
) )
return _performance_manager return _performance_manager

View File

@@ -18,7 +18,6 @@ from enum import Enum
from typing import Any from typing import Any
import httpx import httpx
from plugin_manager import PluginManager
import urllib.parse import urllib.parse
# Constants # Constants
@@ -63,7 +62,7 @@ class Plugin:
plugin_type: str plugin_type: str
project_id: str project_id: str
status: str = "active" status: str = "active"
config: dict = field(default_factory = dict) config: dict = field(default_factory=dict)
created_at: str = "" created_at: str = ""
updated_at: str = "" updated_at: str = ""
last_used_at: str | None = None last_used_at: str | None = None
@@ -111,8 +110,8 @@ class WebhookEndpoint:
endpoint_url: str endpoint_url: str
project_id: str | None = None project_id: str | None = None
auth_type: str = "none" # none, api_key, oauth, custom auth_type: str = "none" # none, api_key, oauth, custom
auth_config: dict = field(default_factory = dict) auth_config: dict = field(default_factory=dict)
trigger_events: list[str] = field(default_factory = list) trigger_events: list[str] = field(default_factory=list)
is_active: bool = True is_active: bool = True
created_at: str = "" created_at: str = ""
updated_at: str = "" updated_at: str = ""
@@ -151,7 +150,7 @@ class ChromeExtensionToken:
user_id: str | None = None user_id: str | None = None
project_id: str | None = None project_id: str | None = None
name: str = "" name: str = ""
permissions: list[str] = field(default_factory = lambda: ["read", "write"]) permissions: list[str] = field(default_factory=lambda: ["read", "write"])
expires_at: str | None = None expires_at: str | None = None
created_at: str = "" created_at: str = ""
last_used_at: str | None = None last_used_at: str | None = None
@@ -162,7 +161,7 @@ class ChromeExtensionToken:
class PluginManager: class PluginManager:
"""插件管理主类""" """插件管理主类"""
def __init__(self, db_manager = None) -> None: def __init__(self, db_manager=None) -> None:
self.db = db_manager self.db = db_manager
self._handlers = {} self._handlers = {}
self._register_default_handlers() self._register_default_handlers()
@@ -296,16 +295,16 @@ class PluginManager:
def _row_to_plugin(self, row: sqlite3.Row) -> Plugin: def _row_to_plugin(self, row: sqlite3.Row) -> Plugin:
"""将数据库行转换为 Plugin 对象""" """将数据库行转换为 Plugin 对象"""
return Plugin( return Plugin(
id = row["id"], id=row["id"],
name = row["name"], name=row["name"],
plugin_type = row["plugin_type"], plugin_type=row["plugin_type"],
project_id = row["project_id"], project_id=row["project_id"],
status = row["status"], status=row["status"],
config = json.loads(row["config"]) if row["config"] else {}, config=json.loads(row["config"]) if row["config"] else {},
created_at = row["created_at"], created_at=row["created_at"],
updated_at = row["updated_at"], updated_at=row["updated_at"],
last_used_at = row["last_used_at"], last_used_at=row["last_used_at"],
use_count = row["use_count"], use_count=row["use_count"],
) )
# ==================== Plugin Config ==================== # ==================== Plugin Config ====================
@@ -343,13 +342,13 @@ class PluginManager:
conn.close() conn.close()
return PluginConfig( return PluginConfig(
id = config_id, id=config_id,
plugin_id = plugin_id, plugin_id=plugin_id,
config_key = key, config_key=key,
config_value = value, config_value=value,
is_encrypted = is_encrypted, is_encrypted=is_encrypted,
created_at = now, created_at=now,
updated_at = now, updated_at=now,
) )
def get_plugin_config(self, plugin_id: str, key: str) -> str | None: def get_plugin_config(self, plugin_id: str, key: str) -> str | None:
@@ -427,7 +426,7 @@ class ChromeExtensionHandler:
if expires_days: if expires_days:
from datetime import timedelta from datetime import timedelta
expires_at = (datetime.now() + timedelta(days = expires_days)).isoformat() expires_at = (datetime.now() + timedelta(days=expires_days)).isoformat()
conn = self.pm.db.get_conn() conn = self.pm.db.get_conn()
conn.execute( conn.execute(
@@ -452,14 +451,14 @@ class ChromeExtensionHandler:
conn.close() conn.close()
return ChromeExtensionToken( return ChromeExtensionToken(
id = token_id, id=token_id,
token = raw_token, # 仅返回一次 token=raw_token, # 仅返回一次
user_id = user_id, user_id=user_id,
project_id = project_id, project_id=project_id,
name = name, name=name,
permissions = permissions or ["read"], permissions=permissions or ["read"],
expires_at = expires_at, expires_at=expires_at,
created_at = now, created_at=now,
) )
def validate_token(self, token: str) -> ChromeExtensionToken | None: def validate_token(self, token: str) -> ChromeExtensionToken | None:
@@ -494,16 +493,16 @@ class ChromeExtensionHandler:
conn.close() conn.close()
return ChromeExtensionToken( return ChromeExtensionToken(
id = row["id"], id=row["id"],
token = "", # 不返回实际令牌 token="", # 不返回实际令牌
user_id = row["user_id"], user_id=row["user_id"],
project_id = row["project_id"], project_id=row["project_id"],
name = row["name"], name=row["name"],
permissions = json.loads(row["permissions"]), permissions=json.loads(row["permissions"]),
expires_at = row["expires_at"], expires_at=row["expires_at"],
created_at = row["created_at"], created_at=row["created_at"],
last_used_at = now, last_used_at=now,
use_count = row["use_count"] + 1, use_count=row["use_count"] + 1,
) )
def revoke_token(self, token_id: str) -> bool: def revoke_token(self, token_id: str) -> bool:
@@ -545,17 +544,17 @@ class ChromeExtensionHandler:
for row in rows: for row in rows:
tokens.append( tokens.append(
ChromeExtensionToken( ChromeExtensionToken(
id = row["id"], id=row["id"],
token = "", # 不返回实际令牌 token="", # 不返回实际令牌
user_id = row["user_id"], user_id=row["user_id"],
project_id = row["project_id"], project_id=row["project_id"],
name = row["name"], name=row["name"],
permissions = json.loads(row["permissions"]), permissions=json.loads(row["permissions"]),
expires_at = row["expires_at"], expires_at=row["expires_at"],
created_at = row["created_at"], created_at=row["created_at"],
last_used_at = row["last_used_at"], last_used_at=row["last_used_at"],
use_count = row["use_count"], use_count=row["use_count"],
is_revoked = bool(row["is_revoked"]), is_revoked=bool(row["is_revoked"]),
) )
) )
@@ -646,16 +645,16 @@ class BotHandler:
conn.close() conn.close()
return BotSession( return BotSession(
id = bot_id, id=bot_id,
bot_type = self.bot_type, bot_type=self.bot_type,
session_id = session_id, session_id=session_id,
session_name = session_name, session_name=session_name,
project_id = project_id, project_id=project_id,
webhook_url = webhook_url, webhook_url=webhook_url,
secret = secret, secret=secret,
is_active = True, is_active=True,
created_at = now, created_at=now,
updated_at = now, updated_at=now,
) )
def get_session(self, session_id: str) -> BotSession | None: def get_session(self, session_id: str) -> BotSession | None:
@@ -739,18 +738,18 @@ class BotHandler:
def _row_to_session(self, row: sqlite3.Row) -> BotSession: def _row_to_session(self, row: sqlite3.Row) -> BotSession:
"""将数据库行转换为 BotSession 对象""" """将数据库行转换为 BotSession 对象"""
return BotSession( return BotSession(
id = row["id"], id=row["id"],
bot_type = row["bot_type"], bot_type=row["bot_type"],
session_id = row["session_id"], session_id=row["session_id"],
session_name = row["session_name"], session_name=row["session_name"],
project_id = row["project_id"], project_id=row["project_id"],
webhook_url = row["webhook_url"], webhook_url=row["webhook_url"],
secret = row["secret"], secret=row["secret"],
is_active = bool(row["is_active"]), is_active=bool(row["is_active"]),
created_at = row["created_at"], created_at=row["created_at"],
updated_at = row["updated_at"], updated_at=row["updated_at"],
last_message_at = row["last_message_at"], last_message_at=row["last_message_at"],
message_count = row["message_count"], message_count=row["message_count"],
) )
async def handle_message(self, session: BotSession, message: dict) -> dict: async def handle_message(self, session: BotSession, message: dict) -> dict:
@@ -880,7 +879,7 @@ class BotHandler:
hmac_code = hmac.new( hmac_code = hmac.new(
session.secret.encode("utf-8"), session.secret.encode("utf-8"),
string_to_sign.encode("utf-8"), string_to_sign.encode("utf-8"),
digestmod = hashlib.sha256, digestmod=hashlib.sha256,
).digest() ).digest()
sign = base64.b64encode(hmac_code).decode("utf-8") sign = base64.b64encode(hmac_code).decode("utf-8")
else: else:
@@ -895,7 +894,7 @@ class BotHandler:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.post( response = await client.post(
session.webhook_url, json = payload, headers = {"Content-Type": "application/json"} session.webhook_url, json=payload, headers={"Content-Type": "application/json"}
) )
return response.status_code == 200 return response.status_code == 200
@@ -911,7 +910,7 @@ class BotHandler:
hmac_code = hmac.new( hmac_code = hmac.new(
session.secret.encode("utf-8"), session.secret.encode("utf-8"),
string_to_sign.encode("utf-8"), string_to_sign.encode("utf-8"),
digestmod = hashlib.sha256, digestmod=hashlib.sha256,
).digest() ).digest()
sign = base64.b64encode(hmac_code).decode("utf-8") sign = base64.b64encode(hmac_code).decode("utf-8")
sign = urllib.parse.quote(sign) sign = urllib.parse.quote(sign)
@@ -926,7 +925,7 @@ class BotHandler:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.post( response = await client.post(
url, json = payload, headers = {"Content-Type": "application/json"} url, json=payload, headers={"Content-Type": "application/json"}
) )
return response.status_code == 200 return response.status_code == 200
@@ -976,17 +975,17 @@ class WebhookIntegration:
conn.close() conn.close()
return WebhookEndpoint( return WebhookEndpoint(
id = endpoint_id, id=endpoint_id,
name = name, name=name,
endpoint_type = self.endpoint_type, endpoint_type=self.endpoint_type,
endpoint_url = endpoint_url, endpoint_url=endpoint_url,
project_id = project_id, project_id=project_id,
auth_type = auth_type, auth_type=auth_type,
auth_config = auth_config or {}, auth_config=auth_config or {},
trigger_events = trigger_events or [], trigger_events=trigger_events or [],
is_active = True, is_active=True,
created_at = now, created_at=now,
updated_at = now, updated_at=now,
) )
def get_endpoint(self, endpoint_id: str) -> WebhookEndpoint | None: def get_endpoint(self, endpoint_id: str) -> WebhookEndpoint | None:
@@ -1074,19 +1073,19 @@ class WebhookIntegration:
def _row_to_endpoint(self, row: sqlite3.Row) -> WebhookEndpoint: def _row_to_endpoint(self, row: sqlite3.Row) -> WebhookEndpoint:
"""将数据库行转换为 WebhookEndpoint 对象""" """将数据库行转换为 WebhookEndpoint 对象"""
return WebhookEndpoint( return WebhookEndpoint(
id = row["id"], id=row["id"],
name = row["name"], name=row["name"],
endpoint_type = row["endpoint_type"], endpoint_type=row["endpoint_type"],
endpoint_url = row["endpoint_url"], endpoint_url=row["endpoint_url"],
project_id = row["project_id"], project_id=row["project_id"],
auth_type = row["auth_type"], auth_type=row["auth_type"],
auth_config = json.loads(row["auth_config"]) if row["auth_config"] else {}, auth_config=json.loads(row["auth_config"]) if row["auth_config"] else {},
trigger_events = json.loads(row["trigger_events"]) if row["trigger_events"] else [], trigger_events=json.loads(row["trigger_events"]) if row["trigger_events"] else [],
is_active = bool(row["is_active"]), is_active=bool(row["is_active"]),
created_at = row["created_at"], created_at=row["created_at"],
updated_at = row["updated_at"], updated_at=row["updated_at"],
last_triggered_at = row["last_triggered_at"], last_triggered_at=row["last_triggered_at"],
trigger_count = row["trigger_count"], trigger_count=row["trigger_count"],
) )
async def trigger(self, endpoint: WebhookEndpoint, event_type: str, data: dict) -> bool: async def trigger(self, endpoint: WebhookEndpoint, event_type: str, data: dict) -> bool:
@@ -1113,7 +1112,7 @@ class WebhookIntegration:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.post( response = await client.post(
endpoint.endpoint_url, json = payload, headers = headers, timeout = 30.0 endpoint.endpoint_url, json=payload, headers=headers, timeout=30.0
) )
success = response.status_code in [200, 201, 202] success = response.status_code in [200, 201, 202]
@@ -1202,19 +1201,19 @@ class WebDAVSyncManager:
conn.close() conn.close()
return WebDAVSync( return WebDAVSync(
id = sync_id, id=sync_id,
name = name, name=name,
project_id = project_id, project_id=project_id,
server_url = server_url, server_url=server_url,
username = username, username=username,
password = password, password=password,
remote_path = remote_path, remote_path=remote_path,
sync_mode = sync_mode, sync_mode=sync_mode,
sync_interval = sync_interval, sync_interval=sync_interval,
last_sync_status = "pending", last_sync_status="pending",
is_active = True, is_active=True,
created_at = now, created_at=now,
updated_at = now, updated_at=now,
) )
def get_sync(self, sync_id: str) -> WebDAVSync | None: def get_sync(self, sync_id: str) -> WebDAVSync | None:
@@ -1292,22 +1291,22 @@ class WebDAVSyncManager:
def _row_to_sync(self, row: sqlite3.Row) -> WebDAVSync: def _row_to_sync(self, row: sqlite3.Row) -> WebDAVSync:
"""将数据库行转换为 WebDAVSync 对象""" """将数据库行转换为 WebDAVSync 对象"""
return WebDAVSync( return WebDAVSync(
id = row["id"], id=row["id"],
name = row["name"], name=row["name"],
project_id = row["project_id"], project_id=row["project_id"],
server_url = row["server_url"], server_url=row["server_url"],
username = row["username"], username=row["username"],
password = row["password"], password=row["password"],
remote_path = row["remote_path"], remote_path=row["remote_path"],
sync_mode = row["sync_mode"], sync_mode=row["sync_mode"],
sync_interval = row["sync_interval"], sync_interval=row["sync_interval"],
last_sync_at = row["last_sync_at"], last_sync_at=row["last_sync_at"],
last_sync_status = row["last_sync_status"], last_sync_status=row["last_sync_status"],
last_sync_error = row["last_sync_error"] or "", last_sync_error=row["last_sync_error"] or "",
is_active = bool(row["is_active"]), is_active=bool(row["is_active"]),
created_at = row["created_at"], created_at=row["created_at"],
updated_at = row["updated_at"], updated_at=row["updated_at"],
sync_count = row["sync_count"], sync_count=row["sync_count"],
) )
async def test_connection(self, sync: WebDAVSync) -> dict: async def test_connection(self, sync: WebDAVSync) -> dict:
@@ -1316,7 +1315,7 @@ class WebDAVSyncManager:
return {"success": False, "error": "WebDAV library not available"} return {"success": False, "error": "WebDAV library not available"}
try: try:
client = webdav_client.Client(sync.server_url, auth = (sync.username, sync.password)) client = webdav_client.Client(sync.server_url, auth=(sync.username, sync.password))
# 尝试列出根目录 # 尝试列出根目录
client.list("/") client.list("/")
@@ -1335,7 +1334,7 @@ class WebDAVSyncManager:
return {"success": False, "error": "Sync is not active"} return {"success": False, "error": "Sync is not active"}
try: try:
client = webdav_client.Client(sync.server_url, auth = (sync.username, sync.password)) client = webdav_client.Client(sync.server_url, auth=(sync.username, sync.password))
# 确保远程目录存在 # 确保远程目录存在
remote_project_path = f"{sync.remote_path}/{sync.project_id}" remote_project_path = f"{sync.remote_path}/{sync.project_id}"
@@ -1367,13 +1366,13 @@ class WebDAVSyncManager:
} }
# 上传 JSON 文件 # 上传 JSON 文件
json_content = json.dumps(export_data, ensure_ascii = False, indent = 2) json_content = json.dumps(export_data, ensure_ascii=False, indent=2)
json_path = f"{remote_project_path}/project_export.json" json_path = f"{remote_project_path}/project_export.json"
# 使用临时文件上传 # 使用临时文件上传
import tempfile import tempfile
with tempfile.NamedTemporaryFile(mode = "w", suffix = ".json", delete = False) as f: with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
f.write(json_content) f.write(json_content)
temp_path = f.name temp_path = f.name
@@ -1419,7 +1418,7 @@ class WebDAVSyncManager:
_plugin_manager = None _plugin_manager = None
def get_plugin_manager(db_manager = None) -> None: def get_plugin_manager(db_manager=None) -> None:
"""获取 PluginManager 单例""" """获取 PluginManager 单例"""
global _plugin_manager global _plugin_manager
if _plugin_manager is None: if _plugin_manager is None:

View File

@@ -110,17 +110,17 @@ class RateLimiter:
# 检查是否超过限制 # 检查是否超过限制
if current_count >= stored_config.requests_per_minute: if current_count >= stored_config.requests_per_minute:
return RateLimitInfo( return RateLimitInfo(
allowed = False, allowed=False,
remaining = 0, remaining=0,
reset_time = reset_time, reset_time=reset_time,
retry_after = stored_config.window_size, retry_after=stored_config.window_size,
) )
# 允许请求,增加计数 # 允许请求,增加计数
await counter.add_request() await counter.add_request()
return RateLimitInfo( return RateLimitInfo(
allowed = True, remaining = remaining - 1, reset_time = reset_time, retry_after = 0 allowed=True, remaining=remaining - 1, reset_time=reset_time, retry_after=0
) )
async def get_limit_info(self, key: str) -> RateLimitInfo: async def get_limit_info(self, key: str) -> RateLimitInfo:
@@ -128,10 +128,10 @@ class RateLimiter:
if key not in self.counters: if key not in self.counters:
config = RateLimitConfig() config = RateLimitConfig()
return RateLimitInfo( return RateLimitInfo(
allowed = True, allowed=True,
remaining = config.requests_per_minute, remaining=config.requests_per_minute,
reset_time = int(time.time()) + config.window_size, reset_time=int(time.time()) + config.window_size,
retry_after = 0, retry_after=0,
) )
counter = self.counters[key] counter = self.counters[key]
@@ -142,10 +142,10 @@ class RateLimiter:
reset_time = int(time.time()) + config.window_size reset_time = int(time.time()) + config.window_size
return RateLimitInfo( return RateLimitInfo(
allowed = current_count < config.requests_per_minute, allowed=current_count < config.requests_per_minute,
remaining = remaining, remaining=remaining,
reset_time = reset_time, reset_time=reset_time,
retry_after = max(0, config.window_size) retry_after=max(0, config.window_size)
if current_count >= config.requests_per_minute if current_count >= config.requests_per_minute
else 0, else 0,
) )
@@ -186,7 +186,7 @@ def rate_limit(requests_per_minute: int = 60, key_func: Callable | None = None
def decorator(func) -> None: def decorator(func) -> None:
limiter = get_rate_limiter() limiter = get_rate_limiter()
config = RateLimitConfig(requests_per_minute = requests_per_minute) config = RateLimitConfig(requests_per_minute=requests_per_minute)
@wraps(func) @wraps(func)
async def async_wrapper(*args, **kwargs) -> None: async def async_wrapper(*args, **kwargs) -> None:

View File

@@ -49,8 +49,8 @@ class SearchResult:
content_type: str # transcript, entity, relation content_type: str # transcript, entity, relation
project_id: str project_id: str
score: float score: float
highlights: list[tuple[int, int]] = field(default_factory = list) # 高亮位置 highlights: list[tuple[int, int]] = field(default_factory=list) # 高亮位置
metadata: dict = field(default_factory = dict) metadata: dict = field(default_factory=dict)
def to_dict(self) -> dict: def to_dict(self) -> dict:
return { return {
@@ -74,7 +74,7 @@ class SemanticSearchResult:
project_id: str project_id: str
similarity: float similarity: float
embedding: list[float] | None = None embedding: list[float] | None = None
metadata: dict = field(default_factory = dict) metadata: dict = field(default_factory=dict)
def to_dict(self) -> dict: def to_dict(self) -> dict:
result = { result = {
@@ -132,7 +132,7 @@ class KnowledgeGap:
severity: str # high, medium, low severity: str # high, medium, low
suggestions: list[str] suggestions: list[str]
related_entities: list[str] related_entities: list[str]
metadata: dict = field(default_factory = dict) metadata: dict = field(default_factory=dict)
def to_dict(self) -> dict: def to_dict(self) -> dict:
return { return {
@@ -318,8 +318,8 @@ class FullTextSearch:
content_id, content_id,
content_type, content_type,
project_id, project_id,
json.dumps(tokens, ensure_ascii = False), json.dumps(tokens, ensure_ascii=False),
json.dumps(token_positions, ensure_ascii = False), json.dumps(token_positions, ensure_ascii=False),
now, now,
now, now,
), ),
@@ -340,7 +340,7 @@ class FullTextSearch:
content_type, content_type,
project_id, project_id,
freq, freq,
json.dumps(positions, ensure_ascii = False), json.dumps(positions, ensure_ascii=False),
), ),
) )
@@ -383,9 +383,9 @@ class FullTextSearch:
scored_results = self._score_results(results, parsed_query) scored_results = self._score_results(results, parsed_query)
# 排序和分页 # 排序和分页
scored_results.sort(key = lambda x: x.score, reverse = True) scored_results.sort(key=lambda x: x.score, reverse=True)
return scored_results[offset : offset + limit] return scored_results[offset: offset + limit]
def _parse_boolean_query(self, query: str) -> dict: def _parse_boolean_query(self, query: str) -> dict:
""" """
@@ -412,10 +412,10 @@ class FullTextSearch:
not_pattern = r"(?:NOT\s+|\-)(\w+)" not_pattern = r"(?:NOT\s+|\-)(\w+)"
not_matches = re.findall(not_pattern, query_without_phrases, re.IGNORECASE) not_matches = re.findall(not_pattern, query_without_phrases, re.IGNORECASE)
not_terms.extend(not_matches) not_terms.extend(not_matches)
query_without_phrases = re.sub(not_pattern, "", query_without_phrases, flags = re.IGNORECASE) query_without_phrases = re.sub(not_pattern, "", query_without_phrases, flags=re.IGNORECASE)
# 处理 OR # 处理 OR
or_parts = re.split(r"\s+OR\s+", query_without_phrases, flags = re.IGNORECASE) or_parts = re.split(r"\s+OR\s+", query_without_phrases, flags=re.IGNORECASE)
if len(or_parts) > 1: if len(or_parts) > 1:
or_terms = [p.strip() for p in or_parts[1:] if p.strip()] or_terms = [p.strip() for p in or_parts[1:] if p.strip()]
query_without_phrases = or_parts[0] query_without_phrases = or_parts[0]
@@ -654,13 +654,13 @@ class FullTextSearch:
scored.append( scored.append(
SearchResult( SearchResult(
id = result["id"], id=result["id"],
content = result["content"], content=result["content"],
content_type = result["content_type"], content_type=result["content_type"],
project_id = result["project_id"], project_id=result["project_id"],
score = round(score, 4), score=round(score, 4),
highlights = highlights[:10], # 限制高亮数量 highlights=highlights[:10], # 限制高亮数量
metadata = {}, metadata={},
) )
) )
@@ -699,7 +699,7 @@ class FullTextSearch:
snippet = snippet + "..." snippet = snippet + "..."
# 添加高亮标记 # 添加高亮标记
for term in sorted(all_terms, key = len, reverse = True): # 长的先替换 for term in sorted(all_terms, key=len, reverse=True): # 长的先替换
pattern = re.compile(re.escape(term), re.IGNORECASE) pattern = re.compile(re.escape(term), re.IGNORECASE)
snippet = pattern.sub(f"**{term}**", snippet) snippet = pattern.sub(f"**{term}**", snippet)
@@ -873,7 +873,7 @@ class SemanticSearch:
if len(text) > max_chars: if len(text) > max_chars:
text = text[:max_chars] text = text[:max_chars]
embedding = self.model.encode(text, convert_to_list = True) embedding = self.model.encode(text, convert_to_list=True)
return embedding return embedding
except Exception as e: except Exception as e:
print(f"生成 embedding 失败: {e}") print(f"生成 embedding 失败: {e}")
@@ -1005,13 +1005,13 @@ class SemanticSearch:
results.append( results.append(
SemanticSearchResult( SemanticSearchResult(
id = row["content_id"], id=row["content_id"],
content = content or "", content=content or "",
content_type = row["content_type"], content_type=row["content_type"],
project_id = row["project_id"], project_id=row["project_id"],
similarity = float(similarity), similarity=float(similarity),
embedding = None, # 不返回 embedding 以节省带宽 embedding=None, # 不返回 embedding 以节省带宽
metadata = {}, metadata={},
) )
) )
except Exception as e: except Exception as e:
@@ -1019,7 +1019,7 @@ class SemanticSearch:
continue continue
# 排序并返回 top_k # 排序并返回 top_k
results.sort(key = lambda x: x.similarity, reverse = True) results.sort(key=lambda x: x.similarity, reverse=True)
return results[:top_k] return results[:top_k]
def _get_content_text(self, content_id: str, content_type: str) -> str | None: def _get_content_text(self, content_id: str, content_type: str) -> str | None:
@@ -1121,18 +1121,18 @@ class SemanticSearch:
results.append( results.append(
SemanticSearchResult( SemanticSearchResult(
id = row["content_id"], id=row["content_id"],
content = content or "", content=content or "",
content_type = row["content_type"], content_type=row["content_type"],
project_id = row["project_id"], project_id=row["project_id"],
similarity = float(similarity), similarity=float(similarity),
metadata = {}, metadata={},
) )
) )
except (KeyError, ValueError): except (KeyError, ValueError):
continue continue
results.sort(key = lambda x: x.similarity, reverse = True) results.sort(key=lambda x: x.similarity, reverse=True)
return results[:top_k] return results[:top_k]
def delete_embedding(self, content_id: str, content_type: str) -> bool: def delete_embedding(self, content_id: str, content_type: str) -> bool:
@@ -1368,16 +1368,16 @@ class EntityPathDiscovery:
confidence = 1.0 / (len(entity_ids) - 1) if len(entity_ids) > 1 else 1.0 confidence = 1.0 / (len(entity_ids) - 1) if len(entity_ids) > 1 else 1.0
return EntityPath( return EntityPath(
path_id = f"path_{entity_ids[0]}_{entity_ids[-1]}_{hash(tuple(entity_ids))}", path_id=f"path_{entity_ids[0]}_{entity_ids[-1]}_{hash(tuple(entity_ids))}",
source_entity_id = entity_ids[0], source_entity_id=entity_ids[0],
source_entity_name = nodes[0]["name"] if nodes else "", source_entity_name=nodes[0]["name"] if nodes else "",
target_entity_id = entity_ids[-1], target_entity_id=entity_ids[-1],
target_entity_name = nodes[-1]["name"] if nodes else "", target_entity_name=nodes[-1]["name"] if nodes else "",
path_length = len(entity_ids) - 1, path_length=len(entity_ids) - 1,
nodes = nodes, nodes=nodes,
edges = edges, edges=edges,
confidence = round(confidence, 4), confidence=round(confidence, 4),
path_description = path_desc, path_description=path_desc,
) )
def find_multi_hop_relations(self, entity_id: str, max_hops: int = 3) -> list[dict]: def find_multi_hop_relations(self, entity_id: str, max_hops: int = 3) -> list[dict]:
@@ -1463,7 +1463,7 @@ class EntityPathDiscovery:
conn.close() conn.close()
# 按跳数排序 # 按跳数排序
relations.sort(key = lambda x: x["hops"]) relations.sort(key=lambda x: x["hops"])
return relations return relations
def _get_path_to_entity( def _get_path_to_entity(
@@ -1620,7 +1620,7 @@ class EntityPathDiscovery:
conn.close() conn.close()
# 按桥接分数排序 # 按桥接分数排序
bridge_scores.sort(key = lambda x: x["bridge_score"], reverse = True) bridge_scores.sort(key=lambda x: x["bridge_score"], reverse=True)
return bridge_scores[:20] # 返回前20 return bridge_scores[:20] # 返回前20
@@ -1676,7 +1676,7 @@ class KnowledgeGapDetection:
# 按严重程度排序 # 按严重程度排序
severity_order = {"high": 0, "medium": 1, "low": 2} severity_order = {"high": 0, "medium": 1, "low": 2}
gaps.sort(key = lambda x: severity_order.get(x.severity, 3)) gaps.sort(key=lambda x: severity_order.get(x.severity, 3))
return gaps return gaps
@@ -1731,18 +1731,18 @@ class KnowledgeGapDetection:
if missing_names: if missing_names:
gaps.append( gaps.append(
KnowledgeGap( KnowledgeGap(
gap_id = f"gap_attr_{entity_id}", gap_id=f"gap_attr_{entity_id}",
gap_type = "missing_attribute", gap_type="missing_attribute",
entity_id = entity_id, entity_id=entity_id,
entity_name = entity["name"], entity_name=entity["name"],
description = f"实体 '{entity['name']}' 缺少必需属性: {', '.join(missing_names)}", description=f"实体 '{entity['name']}' 缺少必需属性: {', '.join(missing_names)}",
severity = "medium", severity="medium",
suggestions = [ suggestions=[
f"为实体 '{entity['name']}' 补充以下属性: {', '.join(missing_names)}", f"为实体 '{entity['name']}' 补充以下属性: {', '.join(missing_names)}",
"检查属性模板定义是否合理", "检查属性模板定义是否合理",
], ],
related_entities = [], related_entities=[],
metadata = {"missing_attributes": missing_names}, metadata={"missing_attributes": missing_names},
) )
) )
@@ -1793,19 +1793,19 @@ class KnowledgeGapDetection:
gaps.append( gaps.append(
KnowledgeGap( KnowledgeGap(
gap_id = f"gap_sparse_{entity_id}", gap_id=f"gap_sparse_{entity_id}",
gap_type = "sparse_relation", gap_type="sparse_relation",
entity_id = entity_id, entity_id=entity_id,
entity_name = entity["name"], entity_name=entity["name"],
description = f"实体 '{entity['name']}' 关系稀疏(仅有 {relation_count} 个关系)", description=f"实体 '{entity['name']}' 关系稀疏(仅有 {relation_count} 个关系)",
severity = "medium" if relation_count == 0 else "low", severity="medium" if relation_count == 0 else "low",
suggestions = [ suggestions=[
f"检查转录文本中提及 '{entity['name']}' 的其他实体", f"检查转录文本中提及 '{entity['name']}' 的其他实体",
f"手动添加 '{entity['name']}' 与其他实体的关系", f"手动添加 '{entity['name']}' 与其他实体的关系",
"使用实体对齐功能合并相似实体", "使用实体对齐功能合并相似实体",
], ],
related_entities = [r["id"] for r in potential_related], related_entities=[r["id"] for r in potential_related],
metadata = { metadata={
"relation_count": relation_count, "relation_count": relation_count,
"potential_related": [r["name"] for r in potential_related], "potential_related": [r["name"] for r in potential_related],
}, },
@@ -1837,19 +1837,19 @@ class KnowledgeGapDetection:
for entity in isolated: for entity in isolated:
gaps.append( gaps.append(
KnowledgeGap( KnowledgeGap(
gap_id = f"gap_iso_{entity['id']}", gap_id=f"gap_iso_{entity['id']}",
gap_type = "isolated_entity", gap_type="isolated_entity",
entity_id = entity["id"], entity_id=entity["id"],
entity_name = entity["name"], entity_name=entity["name"],
description = f"实体 '{entity['name']}' 是孤立实体(没有任何关系)", description=f"实体 '{entity['name']}' 是孤立实体(没有任何关系)",
severity = "high", severity="high",
suggestions = [ suggestions=[
f"检查 '{entity['name']}' 是否应该与其他实体建立关系", f"检查 '{entity['name']}' 是否应该与其他实体建立关系",
f"考虑删除不相关的实体 '{entity['name']}'", f"考虑删除不相关的实体 '{entity['name']}'",
"运行关系发现算法自动识别潜在关系", "运行关系发现算法自动识别潜在关系",
], ],
related_entities = [], related_entities=[],
metadata = {"entity_type": entity["type"]}, metadata={"entity_type": entity["type"]},
) )
) )
@@ -1875,15 +1875,15 @@ class KnowledgeGapDetection:
for entity in incomplete: for entity in incomplete:
gaps.append( gaps.append(
KnowledgeGap( KnowledgeGap(
gap_id = f"gap_inc_{entity['id']}", gap_id=f"gap_inc_{entity['id']}",
gap_type = "incomplete_entity", gap_type="incomplete_entity",
entity_id = entity["id"], entity_id=entity["id"],
entity_name = entity["name"], entity_name=entity["name"],
description = f"实体 '{entity['name']}' 缺少定义", description=f"实体 '{entity['name']}' 缺少定义",
severity = "low", severity="low",
suggestions = [f"'{entity['name']}' 添加定义", "从转录文本中提取定义信息"], suggestions=[f"'{entity['name']}' 添加定义", "从转录文本中提取定义信息"],
related_entities = [], related_entities=[],
metadata = {"entity_type": entity["type"]}, metadata={"entity_type": entity["type"]},
) )
) )
@@ -1925,18 +1925,18 @@ class KnowledgeGapDetection:
if count >= 3: # 出现3次以上 if count >= 3: # 出现3次以上
gaps.append( gaps.append(
KnowledgeGap( KnowledgeGap(
gap_id = f"gap_missing_{hash(entity) % 10000}", gap_id=f"gap_missing_{hash(entity) % 10000}",
gap_type = "missing_key_entity", gap_type="missing_key_entity",
entity_id = None, entity_id=None,
entity_name = None, entity_name=None,
description = f"文本中频繁提及 '{entity}' 但未提取为实体(出现 {count} 次)", description=f"文本中频繁提及 '{entity}' 但未提取为实体(出现 {count} 次)",
severity = "low", severity="low",
suggestions = [ suggestions=[
f"考虑将 '{entity}' 添加为实体", f"考虑将 '{entity}' 添加为实体",
"检查实体提取算法是否需要优化", "检查实体提取算法是否需要优化",
], ],
related_entities = [], related_entities=[],
metadata = {"mention_count": count}, metadata={"mention_count": count},
) )
) )
@@ -2060,12 +2060,12 @@ class SearchManager:
Dict: 混合搜索结果 Dict: 混合搜索结果
""" """
# 全文搜索 # 全文搜索
fulltext_results = self.fulltext_search.search(query, project_id, limit = limit) fulltext_results = self.fulltext_search.search(query, project_id, limit=limit)
# 语义搜索 # 语义搜索
semantic_results = [] semantic_results = []
if self.semantic_search.is_available(): if self.semantic_search.is_available():
semantic_results = self.semantic_search.search(query, project_id, top_k = limit) semantic_results = self.semantic_search.search(query, project_id, top_k=limit)
# 合并结果(去重并加权) # 合并结果(去重并加权)
combined = {} combined = {}
@@ -2104,7 +2104,7 @@ class SearchManager:
# 排序 # 排序
results = list(combined.values()) results = list(combined.values())
results.sort(key = lambda x: x["combined_score"], reverse = True) results.sort(key=lambda x: x["combined_score"], reverse=True)
return { return {
"query": query, "query": query,
@@ -2226,7 +2226,7 @@ def fulltext_search(
) -> list[SearchResult]: ) -> list[SearchResult]:
"""全文搜索便捷函数""" """全文搜索便捷函数"""
manager = get_search_manager() manager = get_search_manager()
return manager.fulltext_search.search(query, project_id, limit = limit) return manager.fulltext_search.search(query, project_id, limit=limit)
def semantic_search( def semantic_search(
@@ -2234,7 +2234,7 @@ def semantic_search(
) -> list[SemanticSearchResult]: ) -> list[SemanticSearchResult]:
"""语义搜索便捷函数""" """语义搜索便捷函数"""
manager = get_search_manager() manager = get_search_manager()
return manager.semantic_search.search(query, project_id, top_k = top_k) return manager.semantic_search.search(query, project_id, top_k=top_k)
def find_entity_path(source_id: str, target_id: str, max_depth: int = 5) -> EntityPath | None: def find_entity_path(source_id: str, target_id: str, max_depth: int = 5) -> EntityPath | None:

View File

@@ -86,7 +86,7 @@ class AuditLog:
after_value: str | None = None after_value: str | None = None
success: bool = True success: bool = True
error_message: str | None = None error_message: str | None = None
created_at: str = field(default_factory = lambda: datetime.now().isoformat()) created_at: str = field(default_factory=lambda: datetime.now().isoformat())
def to_dict(self) -> dict[str, Any]: def to_dict(self) -> dict[str, Any]:
return asdict(self) return asdict(self)
@@ -103,8 +103,8 @@ class EncryptionConfig:
key_derivation: str = "pbkdf2" # pbkdf2, argon2 key_derivation: str = "pbkdf2" # pbkdf2, argon2
master_key_hash: str | None = None # 主密钥哈希(用于验证) master_key_hash: str | None = None # 主密钥哈希(用于验证)
salt: str | None = None salt: str | None = None
created_at: str = field(default_factory = lambda: datetime.now().isoformat()) created_at: str = field(default_factory=lambda: datetime.now().isoformat())
updated_at: str = field(default_factory = lambda: datetime.now().isoformat()) updated_at: str = field(default_factory=lambda: datetime.now().isoformat())
def to_dict(self) -> dict[str, Any]: def to_dict(self) -> dict[str, Any]:
return asdict(self) return asdict(self)
@@ -123,8 +123,8 @@ class MaskingRule:
is_active: bool = True is_active: bool = True
priority: int = 0 priority: int = 0
description: str | None = None description: str | None = None
created_at: str = field(default_factory = lambda: datetime.now().isoformat()) created_at: str = field(default_factory=lambda: datetime.now().isoformat())
updated_at: str = field(default_factory = lambda: datetime.now().isoformat()) updated_at: str = field(default_factory=lambda: datetime.now().isoformat())
def to_dict(self) -> dict[str, Any]: def to_dict(self) -> dict[str, Any]:
return asdict(self) return asdict(self)
@@ -145,8 +145,8 @@ class DataAccessPolicy:
max_access_count: int | None = None # 最大访问次数 max_access_count: int | None = None # 最大访问次数
require_approval: bool = False require_approval: bool = False
is_active: bool = True is_active: bool = True
created_at: str = field(default_factory = lambda: datetime.now().isoformat()) created_at: str = field(default_factory=lambda: datetime.now().isoformat())
updated_at: str = field(default_factory = lambda: datetime.now().isoformat()) updated_at: str = field(default_factory=lambda: datetime.now().isoformat())
def to_dict(self) -> dict[str, Any]: def to_dict(self) -> dict[str, Any]:
return asdict(self) return asdict(self)
@@ -164,7 +164,7 @@ class AccessRequest:
approved_by: str | None = None approved_by: str | None = None
approved_at: str | None = None approved_at: str | None = None
expires_at: str | None = None expires_at: str | None = None
created_at: str = field(default_factory = lambda: datetime.now().isoformat()) created_at: str = field(default_factory=lambda: datetime.now().isoformat())
def to_dict(self) -> dict[str, Any]: def to_dict(self) -> dict[str, Any]:
return asdict(self) return asdict(self)
@@ -345,18 +345,18 @@ class SecurityManager:
) -> AuditLog: ) -> AuditLog:
"""记录审计日志""" """记录审计日志"""
log = AuditLog( log = AuditLog(
id = self._generate_id(), id=self._generate_id(),
action_type = action_type.value, action_type=action_type.value,
user_id = user_id, user_id=user_id,
user_ip = user_ip, user_ip=user_ip,
user_agent = user_agent, user_agent=user_agent,
resource_type = resource_type, resource_type=resource_type,
resource_id = resource_id, resource_id=resource_id,
action_details = json.dumps(action_details) if action_details else None, action_details=json.dumps(action_details) if action_details else None,
before_value = before_value, before_value=before_value,
after_value = after_value, after_value=after_value,
success = success, success=success,
error_message = error_message, error_message=error_message,
) )
conn = sqlite3.connect(self.db_path) conn = sqlite3.connect(self.db_path)
@@ -444,19 +444,19 @@ class SecurityManager:
for row in rows: for row in rows:
log = AuditLog( log = AuditLog(
id = row[0], id=row[0],
action_type = row[1], action_type=row[1],
user_id = row[2], user_id=row[2],
user_ip = row[3], user_ip=row[3],
user_agent = row[4], user_agent=row[4],
resource_type = row[5], resource_type=row[5],
resource_id = row[6], resource_id=row[6],
action_details = row[7], action_details=row[7],
before_value = row[8], before_value=row[8],
after_value = row[9], after_value=row[9],
success = bool(row[10]), success=bool(row[10]),
error_message = row[11], error_message=row[11],
created_at = row[12], created_at=row[12],
) )
logs.append(log) logs.append(log)
@@ -513,10 +513,10 @@ class SecurityManager:
raise RuntimeError("cryptography library not available") raise RuntimeError("cryptography library not available")
kdf = PBKDF2HMAC( kdf = PBKDF2HMAC(
algorithm = hashes.SHA256(), algorithm=hashes.SHA256(),
length = 32, length=32,
salt = salt, salt=salt,
iterations = 100000, iterations=100000,
) )
return base64.urlsafe_b64encode(kdf.derive(password.encode())) return base64.urlsafe_b64encode(kdf.derive(password.encode()))
@@ -533,13 +533,13 @@ class SecurityManager:
key_hash = hashlib.sha256(key).hexdigest() key_hash = hashlib.sha256(key).hexdigest()
config = EncryptionConfig( config = EncryptionConfig(
id = self._generate_id(), id=self._generate_id(),
project_id = project_id, project_id=project_id,
is_enabled = True, is_enabled=True,
encryption_type = "aes-256-gcm", encryption_type="aes-256-gcm",
key_derivation = "pbkdf2", key_derivation="pbkdf2",
master_key_hash = key_hash, master_key_hash=key_hash,
salt = salt, salt=salt,
) )
conn = sqlite3.connect(self.db_path) conn = sqlite3.connect(self.db_path)
@@ -593,10 +593,10 @@ class SecurityManager:
# 记录审计日志 # 记录审计日志
self.log_audit( self.log_audit(
action_type = AuditActionType.ENCRYPTION_ENABLE, action_type=AuditActionType.ENCRYPTION_ENABLE,
resource_type = "project", resource_type="project",
resource_id = project_id, resource_id=project_id,
action_details = {"encryption_type": config.encryption_type}, action_details={"encryption_type": config.encryption_type},
) )
return config return config
@@ -624,9 +624,9 @@ class SecurityManager:
# 记录审计日志 # 记录审计日志
self.log_audit( self.log_audit(
action_type = AuditActionType.ENCRYPTION_DISABLE, action_type=AuditActionType.ENCRYPTION_DISABLE,
resource_type = "project", resource_type="project",
resource_id = project_id, resource_id=project_id,
) )
return True return True
@@ -668,15 +668,15 @@ class SecurityManager:
return None return None
return EncryptionConfig( return EncryptionConfig(
id = row[0], id=row[0],
project_id = row[1], project_id=row[1],
is_enabled = bool(row[2]), is_enabled=bool(row[2]),
encryption_type = row[3], encryption_type=row[3],
key_derivation = row[4], key_derivation=row[4],
master_key_hash = row[5], master_key_hash=row[5],
salt = row[6], salt=row[6],
created_at = row[7], created_at=row[7],
updated_at = row[8], updated_at=row[8],
) )
def encrypt_data(self, data: str, password: str, salt: str | None = None) -> tuple[str, str]: def encrypt_data(self, data: str, password: str, salt: str | None = None) -> tuple[str, str]:
@@ -724,14 +724,14 @@ class SecurityManager:
replacement = replacement or default["replacement"] replacement = replacement or default["replacement"]
rule = MaskingRule( rule = MaskingRule(
id = self._generate_id(), id=self._generate_id(),
project_id = project_id, project_id=project_id,
name = name, name=name,
rule_type = rule_type.value, rule_type=rule_type.value,
pattern = pattern or "", pattern=pattern or "",
replacement = replacement or "****", replacement=replacement or "****",
description = description, description=description,
priority = priority, priority=priority,
) )
conn = sqlite3.connect(self.db_path) conn = sqlite3.connect(self.db_path)
@@ -764,10 +764,10 @@ class SecurityManager:
# 记录审计日志 # 记录审计日志
self.log_audit( self.log_audit(
action_type = AuditActionType.DATA_MASKING, action_type=AuditActionType.DATA_MASKING,
resource_type = "project", resource_type="project",
resource_id = project_id, resource_id=project_id,
action_details = {"action": "create_rule", "rule_name": name}, action_details={"action": "create_rule", "rule_name": name},
) )
return rule return rule
@@ -793,17 +793,17 @@ class SecurityManager:
for row in rows: for row in rows:
rules.append( rules.append(
MaskingRule( MaskingRule(
id = row[0], id=row[0],
project_id = row[1], project_id=row[1],
name = row[2], name=row[2],
rule_type = row[3], rule_type=row[3],
pattern = row[4], pattern=row[4],
replacement = row[5], replacement=row[5],
is_active = bool(row[6]), is_active=bool(row[6]),
priority = row[7], priority=row[7],
description = row[8], description=row[8],
created_at = row[9], created_at=row[9],
updated_at = row[10], updated_at=row[10],
) )
) )
@@ -855,17 +855,17 @@ class SecurityManager:
return None return None
return MaskingRule( return MaskingRule(
id = row[0], id=row[0],
project_id = row[1], project_id=row[1],
name = row[2], name=row[2],
rule_type = row[3], rule_type=row[3],
pattern = row[4], pattern=row[4],
replacement = row[5], replacement=row[5],
is_active = bool(row[6]), is_active=bool(row[6]),
priority = row[7], priority=row[7],
description = row[8], description=row[8],
created_at = row[9], created_at=row[9],
updated_at = row[10], updated_at=row[10],
) )
def delete_masking_rule(self, rule_id: str) -> bool: def delete_masking_rule(self, rule_id: str) -> bool:
@@ -936,16 +936,16 @@ class SecurityManager:
) -> DataAccessPolicy: ) -> DataAccessPolicy:
"""创建数据访问策略""" """创建数据访问策略"""
policy = DataAccessPolicy( policy = DataAccessPolicy(
id = self._generate_id(), id=self._generate_id(),
project_id = project_id, project_id=project_id,
name = name, name=name,
description = description, description=description,
allowed_users = json.dumps(allowed_users) if allowed_users else None, allowed_users=json.dumps(allowed_users) if allowed_users else None,
allowed_roles = json.dumps(allowed_roles) if allowed_roles else None, allowed_roles=json.dumps(allowed_roles) if allowed_roles else None,
allowed_ips = json.dumps(allowed_ips) if allowed_ips else None, allowed_ips=json.dumps(allowed_ips) if allowed_ips else None,
time_restrictions = json.dumps(time_restrictions) if time_restrictions else None, time_restrictions=json.dumps(time_restrictions) if time_restrictions else None,
max_access_count = max_access_count, max_access_count=max_access_count,
require_approval = require_approval, require_approval=require_approval,
) )
conn = sqlite3.connect(self.db_path) conn = sqlite3.connect(self.db_path)
@@ -1002,19 +1002,19 @@ class SecurityManager:
for row in rows: for row in rows:
policies.append( policies.append(
DataAccessPolicy( DataAccessPolicy(
id = row[0], id=row[0],
project_id = row[1], project_id=row[1],
name = row[2], name=row[2],
description = row[3], description=row[3],
allowed_users = row[4], allowed_users=row[4],
allowed_roles = row[5], allowed_roles=row[5],
allowed_ips = row[6], allowed_ips=row[6],
time_restrictions = row[7], time_restrictions=row[7],
max_access_count = row[8], max_access_count=row[8],
require_approval = bool(row[9]), require_approval=bool(row[9]),
is_active = bool(row[10]), is_active=bool(row[10]),
created_at = row[11], created_at=row[11],
updated_at = row[12], updated_at=row[12],
) )
) )
@@ -1037,19 +1037,19 @@ class SecurityManager:
return False, "Policy not found or inactive" return False, "Policy not found or inactive"
policy = DataAccessPolicy( policy = DataAccessPolicy(
id = row[0], id=row[0],
project_id = row[1], project_id=row[1],
name = row[2], name=row[2],
description = row[3], description=row[3],
allowed_users = row[4], allowed_users=row[4],
allowed_roles = row[5], allowed_roles=row[5],
allowed_ips = row[6], allowed_ips=row[6],
time_restrictions = row[7], time_restrictions=row[7],
max_access_count = row[8], max_access_count=row[8],
require_approval = bool(row[9]), require_approval=bool(row[9]),
is_active = bool(row[10]), is_active=bool(row[10]),
created_at = row[11], created_at=row[11],
updated_at = row[12], updated_at=row[12],
) )
# 检查用户白名单 # 检查用户白名单
@@ -1113,7 +1113,7 @@ class SecurityManager:
try: try:
if "/" in pattern: if "/" in pattern:
# CIDR 表示法 # CIDR 表示法
network = ipaddress.ip_network(pattern, strict = False) network = ipaddress.ip_network(pattern, strict=False)
return ipaddress.ip_address(ip) in network return ipaddress.ip_address(ip) in network
else: else:
# 精确匹配 # 精确匹配
@@ -1130,11 +1130,11 @@ class SecurityManager:
) -> AccessRequest: ) -> AccessRequest:
"""创建访问请求""" """创建访问请求"""
request = AccessRequest( request = AccessRequest(
id = self._generate_id(), id=self._generate_id(),
policy_id = policy_id, policy_id=policy_id,
user_id = user_id, user_id=user_id,
request_reason = request_reason, request_reason=request_reason,
expires_at = (datetime.now() + timedelta(hours = expires_hours)).isoformat(), expires_at=(datetime.now() + timedelta(hours=expires_hours)).isoformat(),
) )
conn = sqlite3.connect(self.db_path) conn = sqlite3.connect(self.db_path)
@@ -1169,7 +1169,7 @@ class SecurityManager:
conn = sqlite3.connect(self.db_path) conn = sqlite3.connect(self.db_path)
cursor = conn.cursor() cursor = conn.cursor()
expires_at = (datetime.now() + timedelta(hours = expires_hours)).isoformat() expires_at = (datetime.now() + timedelta(hours=expires_hours)).isoformat()
approved_at = datetime.now().isoformat() approved_at = datetime.now().isoformat()
cursor.execute( cursor.execute(
@@ -1192,15 +1192,15 @@ class SecurityManager:
return None return None
return AccessRequest( return AccessRequest(
id = row[0], id=row[0],
policy_id = row[1], policy_id=row[1],
user_id = row[2], user_id=row[2],
request_reason = row[3], request_reason=row[3],
status = row[4], status=row[4],
approved_by = row[5], approved_by=row[5],
approved_at = row[6], approved_at=row[6],
expires_at = row[7], expires_at=row[7],
created_at = row[8], created_at=row[8],
) )
def reject_access_request(self, request_id: str, rejected_by: str) -> AccessRequest | None: def reject_access_request(self, request_id: str, rejected_by: str) -> AccessRequest | None:
@@ -1227,15 +1227,15 @@ class SecurityManager:
return None return None
return AccessRequest( return AccessRequest(
id = row[0], id=row[0],
policy_id = row[1], policy_id=row[1],
user_id = row[2], user_id=row[2],
request_reason = row[3], request_reason=row[3],
status = row[4], status=row[4],
approved_by = row[5], approved_by=row[5],
approved_at = row[6], approved_at=row[6],
expires_at = row[7], expires_at=row[7],
created_at = row[8], created_at=row[8],
) )

View File

@@ -635,19 +635,19 @@ class SubscriptionManager:
plan_id = str(uuid.uuid4()) plan_id = str(uuid.uuid4())
plan = SubscriptionPlan( plan = SubscriptionPlan(
id = plan_id, id=plan_id,
name = name, name=name,
tier = tier, tier=tier,
description = description, description=description,
price_monthly = price_monthly, price_monthly=price_monthly,
price_yearly = price_yearly, price_yearly=price_yearly,
currency = currency, currency=currency,
features = features or [], features=features or [],
limits = limits or {}, limits=limits or {},
is_active = True, is_active=True,
created_at = datetime.now(), created_at=datetime.now(),
updated_at = datetime.now(), updated_at=datetime.now(),
metadata = {}, metadata={},
) )
cursor = conn.cursor() cursor = conn.cursor()
@@ -777,36 +777,36 @@ class SubscriptionManager:
# 计算周期 # 计算周期
if billing_cycle == "yearly": if billing_cycle == "yearly":
period_end = now + timedelta(days = 365) period_end = now + timedelta(days=365)
else: else:
period_end = now + timedelta(days = 30) period_end = now + timedelta(days=30)
# 试用处理 # 试用处理
trial_start = None trial_start = None
trial_end = None trial_end = None
if trial_days > 0: if trial_days > 0:
trial_start = now trial_start = now
trial_end = now + timedelta(days = trial_days) trial_end = now + timedelta(days=trial_days)
status = SubscriptionStatus.TRIAL.value status = SubscriptionStatus.TRIAL.value
else: else:
status = SubscriptionStatus.PENDING.value status = SubscriptionStatus.PENDING.value
subscription = Subscription( subscription = Subscription(
id = subscription_id, id=subscription_id,
tenant_id = tenant_id, tenant_id=tenant_id,
plan_id = plan_id, plan_id=plan_id,
status = status, status=status,
current_period_start = now, current_period_start=now,
current_period_end = period_end, current_period_end=period_end,
cancel_at_period_end = False, cancel_at_period_end=False,
canceled_at = None, canceled_at=None,
trial_start = trial_start, trial_start=trial_start,
trial_end = trial_end, trial_end=trial_end,
payment_provider = payment_provider, payment_provider=payment_provider,
provider_subscription_id = None, provider_subscription_id=None,
created_at = now, created_at=now,
updated_at = now, updated_at=now,
metadata = {"billing_cycle": billing_cycle}, metadata={"billing_cycle": billing_cycle},
) )
cursor.execute( cursor.execute(
@@ -1087,15 +1087,15 @@ class SubscriptionManager:
record_id = str(uuid.uuid4()) record_id = str(uuid.uuid4())
record = UsageRecord( record = UsageRecord(
id = record_id, id=record_id,
tenant_id = tenant_id, tenant_id=tenant_id,
resource_type = resource_type, resource_type=resource_type,
quantity = quantity, quantity=quantity,
unit = unit, unit=unit,
recorded_at = datetime.now(), recorded_at=datetime.now(),
cost = cost, cost=cost,
description = description, description=description,
metadata = metadata or {}, metadata=metadata or {},
) )
cursor = conn.cursor() cursor = conn.cursor()
@@ -1214,22 +1214,22 @@ class SubscriptionManager:
now = datetime.now() now = datetime.now()
payment = Payment( payment = Payment(
id = payment_id, id=payment_id,
tenant_id = tenant_id, tenant_id=tenant_id,
subscription_id = subscription_id, subscription_id=subscription_id,
invoice_id = invoice_id, invoice_id=invoice_id,
amount = amount, amount=amount,
currency = currency, currency=currency,
provider = provider, provider=provider,
provider_payment_id = None, provider_payment_id=None,
status = PaymentStatus.PENDING.value, status=PaymentStatus.PENDING.value,
payment_method = payment_method, payment_method=payment_method,
payment_details = payment_details or {}, payment_details=payment_details or {},
paid_at = None, paid_at=None,
failed_at = None, failed_at=None,
failure_reason = None, failure_reason=None,
created_at = now, created_at=now,
updated_at = now, updated_at=now,
) )
cursor = conn.cursor() cursor = conn.cursor()
@@ -1414,27 +1414,27 @@ class SubscriptionManager:
invoice_id = str(uuid.uuid4()) invoice_id = str(uuid.uuid4())
invoice_number = self._generate_invoice_number() invoice_number = self._generate_invoice_number()
now = datetime.now() now = datetime.now()
due_date = now + timedelta(days = 7) # 7天付款期限 due_date = now + timedelta(days=7) # 7天付款期限
invoice = Invoice( invoice = Invoice(
id = invoice_id, id=invoice_id,
tenant_id = tenant_id, tenant_id=tenant_id,
subscription_id = subscription_id, subscription_id=subscription_id,
invoice_number = invoice_number, invoice_number=invoice_number,
status = InvoiceStatus.DRAFT.value, status=InvoiceStatus.DRAFT.value,
amount_due = amount, amount_due=amount,
amount_paid = 0, amount_paid=0,
currency = currency, currency=currency,
period_start = period_start, period_start=period_start,
period_end = period_end, period_end=period_end,
description = description, description=description,
line_items = line_items or [{"description": description, "amount": amount}], line_items=line_items or [{"description": description, "amount": amount}],
due_date = due_date, due_date=due_date,
paid_at = None, paid_at=None,
voided_at = None, voided_at=None,
void_reason = None, void_reason=None,
created_at = now, created_at=now,
updated_at = now, updated_at=now,
) )
cursor = conn.cursor() cursor = conn.cursor()
@@ -1604,23 +1604,23 @@ class SubscriptionManager:
now = datetime.now() now = datetime.now()
refund = Refund( refund = Refund(
id = refund_id, id=refund_id,
tenant_id = tenant_id, tenant_id=tenant_id,
payment_id = payment_id, payment_id=payment_id,
invoice_id = payment.invoice_id, invoice_id=payment.invoice_id,
amount = amount, amount=amount,
currency = payment.currency, currency=payment.currency,
reason = reason, reason=reason,
status = RefundStatus.PENDING.value, status=RefundStatus.PENDING.value,
requested_by = requested_by, requested_by=requested_by,
requested_at = now, requested_at=now,
approved_by = None, approved_by=None,
approved_at = None, approved_at=None,
completed_at = None, completed_at=None,
provider_refund_id = None, provider_refund_id=None,
metadata = {}, metadata={},
created_at = now, created_at=now,
updated_at = now, updated_at=now,
) )
cursor = conn.cursor() cursor = conn.cursor()
@@ -1962,126 +1962,126 @@ class SubscriptionManager:
def _row_to_plan(self, row: sqlite3.Row) -> SubscriptionPlan: def _row_to_plan(self, row: sqlite3.Row) -> SubscriptionPlan:
"""数据库行转换为 SubscriptionPlan 对象""" """数据库行转换为 SubscriptionPlan 对象"""
return SubscriptionPlan( return SubscriptionPlan(
id = row["id"], id=row["id"],
name = row["name"], name=row["name"],
tier = row["tier"], tier=row["tier"],
description = row["description"] or "", description=row["description"] or "",
price_monthly = row["price_monthly"], price_monthly=row["price_monthly"],
price_yearly = row["price_yearly"], price_yearly=row["price_yearly"],
currency = row["currency"], currency=row["currency"],
features = json.loads(row["features"] or "[]"), features=json.loads(row["features"] or "[]"),
limits = json.loads(row["limits"] or "{}"), limits=json.loads(row["limits"] or "{}"),
is_active = bool(row["is_active"]), is_active=bool(row["is_active"]),
created_at = ( created_at=(
datetime.fromisoformat(row["created_at"]) datetime.fromisoformat(row["created_at"])
if isinstance(row["created_at"], str) if isinstance(row["created_at"], str)
else row["created_at"] else row["created_at"]
), ),
updated_at = ( updated_at=(
datetime.fromisoformat(row["updated_at"]) datetime.fromisoformat(row["updated_at"])
if isinstance(row["updated_at"], str) if isinstance(row["updated_at"], str)
else row["updated_at"] else row["updated_at"]
), ),
metadata = json.loads(row["metadata"] or "{}"), metadata=json.loads(row["metadata"] or "{}"),
) )
def _row_to_subscription(self, row: sqlite3.Row) -> Subscription: def _row_to_subscription(self, row: sqlite3.Row) -> Subscription:
"""数据库行转换为 Subscription 对象""" """数据库行转换为 Subscription 对象"""
return Subscription( return Subscription(
id = row["id"], id=row["id"],
tenant_id = row["tenant_id"], tenant_id=row["tenant_id"],
plan_id = row["plan_id"], plan_id=row["plan_id"],
status = row["status"], status=row["status"],
current_period_start = ( current_period_start=(
datetime.fromisoformat(row["current_period_start"]) datetime.fromisoformat(row["current_period_start"])
if row["current_period_start"] and isinstance(row["current_period_start"], str) if row["current_period_start"] and isinstance(row["current_period_start"], str)
else row["current_period_start"] else row["current_period_start"]
), ),
current_period_end = ( current_period_end=(
datetime.fromisoformat(row["current_period_end"]) datetime.fromisoformat(row["current_period_end"])
if row["current_period_end"] and isinstance(row["current_period_end"], str) if row["current_period_end"] and isinstance(row["current_period_end"], str)
else row["current_period_end"] else row["current_period_end"]
), ),
cancel_at_period_end = bool(row["cancel_at_period_end"]), cancel_at_period_end=bool(row["cancel_at_period_end"]),
canceled_at = ( canceled_at=(
datetime.fromisoformat(row["canceled_at"]) datetime.fromisoformat(row["canceled_at"])
if row["canceled_at"] and isinstance(row["canceled_at"], str) if row["canceled_at"] and isinstance(row["canceled_at"], str)
else row["canceled_at"] else row["canceled_at"]
), ),
trial_start = ( trial_start=(
datetime.fromisoformat(row["trial_start"]) datetime.fromisoformat(row["trial_start"])
if row["trial_start"] and isinstance(row["trial_start"], str) if row["trial_start"] and isinstance(row["trial_start"], str)
else row["trial_start"] else row["trial_start"]
), ),
trial_end = ( trial_end=(
datetime.fromisoformat(row["trial_end"]) datetime.fromisoformat(row["trial_end"])
if row["trial_end"] and isinstance(row["trial_end"], str) if row["trial_end"] and isinstance(row["trial_end"], str)
else row["trial_end"] else row["trial_end"]
), ),
payment_provider = row["payment_provider"], payment_provider=row["payment_provider"],
provider_subscription_id = row["provider_subscription_id"], provider_subscription_id=row["provider_subscription_id"],
created_at = ( created_at=(
datetime.fromisoformat(row["created_at"]) datetime.fromisoformat(row["created_at"])
if isinstance(row["created_at"], str) if isinstance(row["created_at"], str)
else row["created_at"] else row["created_at"]
), ),
updated_at = ( updated_at=(
datetime.fromisoformat(row["updated_at"]) datetime.fromisoformat(row["updated_at"])
if isinstance(row["updated_at"], str) if isinstance(row["updated_at"], str)
else row["updated_at"] else row["updated_at"]
), ),
metadata = json.loads(row["metadata"] or "{}"), metadata=json.loads(row["metadata"] or "{}"),
) )
def _row_to_usage(self, row: sqlite3.Row) -> UsageRecord: def _row_to_usage(self, row: sqlite3.Row) -> UsageRecord:
"""数据库行转换为 UsageRecord 对象""" """数据库行转换为 UsageRecord 对象"""
return UsageRecord( return UsageRecord(
id = row["id"], id=row["id"],
tenant_id = row["tenant_id"], tenant_id=row["tenant_id"],
resource_type = row["resource_type"], resource_type=row["resource_type"],
quantity = row["quantity"], quantity=row["quantity"],
unit = row["unit"], unit=row["unit"],
recorded_at = ( recorded_at=(
datetime.fromisoformat(row["recorded_at"]) datetime.fromisoformat(row["recorded_at"])
if isinstance(row["recorded_at"], str) if isinstance(row["recorded_at"], str)
else row["recorded_at"] else row["recorded_at"]
), ),
cost = row["cost"], cost=row["cost"],
description = row["description"], description=row["description"],
metadata = json.loads(row["metadata"] or "{}"), metadata=json.loads(row["metadata"] or "{}"),
) )
def _row_to_payment(self, row: sqlite3.Row) -> Payment: def _row_to_payment(self, row: sqlite3.Row) -> Payment:
"""数据库行转换为 Payment 对象""" """数据库行转换为 Payment 对象"""
return Payment( return Payment(
id = row["id"], id=row["id"],
tenant_id = row["tenant_id"], tenant_id=row["tenant_id"],
subscription_id = row["subscription_id"], subscription_id=row["subscription_id"],
invoice_id = row["invoice_id"], invoice_id=row["invoice_id"],
amount = row["amount"], amount=row["amount"],
currency = row["currency"], currency=row["currency"],
provider = row["provider"], provider=row["provider"],
provider_payment_id = row["provider_payment_id"], provider_payment_id=row["provider_payment_id"],
status = row["status"], status=row["status"],
payment_method = row["payment_method"], payment_method=row["payment_method"],
payment_details = json.loads(row["payment_details"] or "{}"), payment_details=json.loads(row["payment_details"] or "{}"),
paid_at = ( paid_at=(
datetime.fromisoformat(row["paid_at"]) datetime.fromisoformat(row["paid_at"])
if row["paid_at"] and isinstance(row["paid_at"], str) if row["paid_at"] and isinstance(row["paid_at"], str)
else row["paid_at"] else row["paid_at"]
), ),
failed_at = ( failed_at=(
datetime.fromisoformat(row["failed_at"]) datetime.fromisoformat(row["failed_at"])
if row["failed_at"] and isinstance(row["failed_at"], str) if row["failed_at"] and isinstance(row["failed_at"], str)
else row["failed_at"] else row["failed_at"]
), ),
failure_reason = row["failure_reason"], failure_reason=row["failure_reason"],
created_at = ( created_at=(
datetime.fromisoformat(row["created_at"]) datetime.fromisoformat(row["created_at"])
if isinstance(row["created_at"], str) if isinstance(row["created_at"], str)
else row["created_at"] else row["created_at"]
), ),
updated_at = ( updated_at=(
datetime.fromisoformat(row["updated_at"]) datetime.fromisoformat(row["updated_at"])
if isinstance(row["updated_at"], str) if isinstance(row["updated_at"], str)
else row["updated_at"] else row["updated_at"]
@@ -2091,48 +2091,48 @@ class SubscriptionManager:
def _row_to_invoice(self, row: sqlite3.Row) -> Invoice: def _row_to_invoice(self, row: sqlite3.Row) -> Invoice:
"""数据库行转换为 Invoice 对象""" """数据库行转换为 Invoice 对象"""
return Invoice( return Invoice(
id = row["id"], id=row["id"],
tenant_id = row["tenant_id"], tenant_id=row["tenant_id"],
subscription_id = row["subscription_id"], subscription_id=row["subscription_id"],
invoice_number = row["invoice_number"], invoice_number=row["invoice_number"],
status = row["status"], status=row["status"],
amount_due = row["amount_due"], amount_due=row["amount_due"],
amount_paid = row["amount_paid"], amount_paid=row["amount_paid"],
currency = row["currency"], currency=row["currency"],
period_start = ( period_start=(
datetime.fromisoformat(row["period_start"]) datetime.fromisoformat(row["period_start"])
if row["period_start"] and isinstance(row["period_start"], str) if row["period_start"] and isinstance(row["period_start"], str)
else row["period_start"] else row["period_start"]
), ),
period_end = ( period_end=(
datetime.fromisoformat(row["period_end"]) datetime.fromisoformat(row["period_end"])
if row["period_end"] and isinstance(row["period_end"], str) if row["period_end"] and isinstance(row["period_end"], str)
else row["period_end"] else row["period_end"]
), ),
description = row["description"], description=row["description"],
line_items = json.loads(row["line_items"] or "[]"), line_items=json.loads(row["line_items"] or "[]"),
due_date = ( due_date=(
datetime.fromisoformat(row["due_date"]) datetime.fromisoformat(row["due_date"])
if row["due_date"] and isinstance(row["due_date"], str) if row["due_date"] and isinstance(row["due_date"], str)
else row["due_date"] else row["due_date"]
), ),
paid_at = ( paid_at=(
datetime.fromisoformat(row["paid_at"]) datetime.fromisoformat(row["paid_at"])
if row["paid_at"] and isinstance(row["paid_at"], str) if row["paid_at"] and isinstance(row["paid_at"], str)
else row["paid_at"] else row["paid_at"]
), ),
voided_at = ( voided_at=(
datetime.fromisoformat(row["voided_at"]) datetime.fromisoformat(row["voided_at"])
if row["voided_at"] and isinstance(row["voided_at"], str) if row["voided_at"] and isinstance(row["voided_at"], str)
else row["voided_at"] else row["voided_at"]
), ),
void_reason = row["void_reason"], void_reason=row["void_reason"],
created_at = ( created_at=(
datetime.fromisoformat(row["created_at"]) datetime.fromisoformat(row["created_at"])
if isinstance(row["created_at"], str) if isinstance(row["created_at"], str)
else row["created_at"] else row["created_at"]
), ),
updated_at = ( updated_at=(
datetime.fromisoformat(row["updated_at"]) datetime.fromisoformat(row["updated_at"])
if isinstance(row["updated_at"], str) if isinstance(row["updated_at"], str)
else row["updated_at"] else row["updated_at"]
@@ -2142,39 +2142,39 @@ class SubscriptionManager:
def _row_to_refund(self, row: sqlite3.Row) -> Refund: def _row_to_refund(self, row: sqlite3.Row) -> Refund:
"""数据库行转换为 Refund 对象""" """数据库行转换为 Refund 对象"""
return Refund( return Refund(
id = row["id"], id=row["id"],
tenant_id = row["tenant_id"], tenant_id=row["tenant_id"],
payment_id = row["payment_id"], payment_id=row["payment_id"],
invoice_id = row["invoice_id"], invoice_id=row["invoice_id"],
amount = row["amount"], amount=row["amount"],
currency = row["currency"], currency=row["currency"],
reason = row["reason"], reason=row["reason"],
status = row["status"], status=row["status"],
requested_by = row["requested_by"], requested_by=row["requested_by"],
requested_at = ( requested_at=(
datetime.fromisoformat(row["requested_at"]) datetime.fromisoformat(row["requested_at"])
if isinstance(row["requested_at"], str) if isinstance(row["requested_at"], str)
else row["requested_at"] else row["requested_at"]
), ),
approved_by = row["approved_by"], approved_by=row["approved_by"],
approved_at = ( approved_at=(
datetime.fromisoformat(row["approved_at"]) datetime.fromisoformat(row["approved_at"])
if row["approved_at"] and isinstance(row["approved_at"], str) if row["approved_at"] and isinstance(row["approved_at"], str)
else row["approved_at"] else row["approved_at"]
), ),
completed_at = ( completed_at=(
datetime.fromisoformat(row["completed_at"]) datetime.fromisoformat(row["completed_at"])
if row["completed_at"] and isinstance(row["completed_at"], str) if row["completed_at"] and isinstance(row["completed_at"], str)
else row["completed_at"] else row["completed_at"]
), ),
provider_refund_id = row["provider_refund_id"], provider_refund_id=row["provider_refund_id"],
metadata = json.loads(row["metadata"] or "{}"), metadata=json.loads(row["metadata"] or "{}"),
created_at = ( created_at=(
datetime.fromisoformat(row["created_at"]) datetime.fromisoformat(row["created_at"])
if isinstance(row["created_at"], str) if isinstance(row["created_at"], str)
else row["created_at"] else row["created_at"]
), ),
updated_at = ( updated_at=(
datetime.fromisoformat(row["updated_at"]) datetime.fromisoformat(row["updated_at"])
if isinstance(row["updated_at"], str) if isinstance(row["updated_at"], str)
else row["updated_at"] else row["updated_at"]
@@ -2184,20 +2184,20 @@ class SubscriptionManager:
def _row_to_billing_history(self, row: sqlite3.Row) -> BillingHistory: def _row_to_billing_history(self, row: sqlite3.Row) -> BillingHistory:
"""数据库行转换为 BillingHistory 对象""" """数据库行转换为 BillingHistory 对象"""
return BillingHistory( return BillingHistory(
id = row["id"], id=row["id"],
tenant_id = row["tenant_id"], tenant_id=row["tenant_id"],
type = row["type"], type=row["type"],
amount = row["amount"], amount=row["amount"],
currency = row["currency"], currency=row["currency"],
description = row["description"], description=row["description"],
reference_id = row["reference_id"], reference_id=row["reference_id"],
balance_after = row["balance_after"], balance_after=row["balance_after"],
created_at = ( created_at=(
datetime.fromisoformat(row["created_at"]) datetime.fromisoformat(row["created_at"])
if isinstance(row["created_at"], str) if isinstance(row["created_at"], str)
else row["created_at"] else row["created_at"]
), ),
metadata = json.loads(row["metadata"] or "{}"), metadata=json.loads(row["metadata"] or "{}"),
) )

View File

@@ -437,19 +437,19 @@ class TenantManager:
) )
tenant = Tenant( tenant = Tenant(
id = tenant_id, id=tenant_id,
name = name, name=name,
slug = slug, slug=slug,
description = description, description=description,
tier = tier, tier=tier,
status = TenantStatus.PENDING.value, status=TenantStatus.PENDING.value,
owner_id = owner_id, owner_id=owner_id,
created_at = datetime.now(), created_at=datetime.now(),
updated_at = datetime.now(), updated_at=datetime.now(),
expires_at = None, expires_at=None,
settings = settings or {}, settings=settings or {},
resource_limits = resource_limits, resource_limits=resource_limits,
metadata = {}, metadata={},
) )
cursor = conn.cursor() cursor = conn.cursor()
@@ -661,18 +661,18 @@ class TenantManager:
domain_id = str(uuid.uuid4()) domain_id = str(uuid.uuid4())
tenant_domain = TenantDomain( tenant_domain = TenantDomain(
id = domain_id, id=domain_id,
tenant_id = tenant_id, tenant_id=tenant_id,
domain = domain.lower(), domain=domain.lower(),
status = DomainStatus.PENDING.value, status=DomainStatus.PENDING.value,
verification_token = verification_token, verification_token=verification_token,
verification_method = verification_method, verification_method=verification_method,
verified_at = None, verified_at=None,
created_at = datetime.now(), created_at=datetime.now(),
updated_at = datetime.now(), updated_at=datetime.now(),
is_primary = is_primary, is_primary=is_primary,
ssl_enabled = False, ssl_enabled=False,
ssl_expires_at = None, ssl_expires_at=None,
) )
cursor = conn.cursor() cursor = conn.cursor()
@@ -1022,17 +1022,17 @@ class TenantManager:
final_permissions = permissions or default_permissions final_permissions = permissions or default_permissions
member = TenantMember( member = TenantMember(
id = member_id, id=member_id,
tenant_id = tenant_id, tenant_id=tenant_id,
user_id = "pending", # 临时值,待用户接受邀请后更新 user_id="pending", # 临时值,待用户接受邀请后更新
email = email, email=email,
role = role, role=role,
permissions = final_permissions, permissions=final_permissions,
invited_by = invited_by, invited_by=invited_by,
invited_at = datetime.now(), invited_at=datetime.now(),
joined_at = None, joined_at=None,
last_active_at = None, last_active_at=None,
status = "pending", status="pending",
) )
cursor = conn.cursor() cursor = conn.cursor()
@@ -1497,60 +1497,60 @@ class TenantManager:
def _row_to_tenant(self, row: sqlite3.Row) -> Tenant: def _row_to_tenant(self, row: sqlite3.Row) -> Tenant:
"""数据库行转换为 Tenant 对象""" """数据库行转换为 Tenant 对象"""
return Tenant( return Tenant(
id = row["id"], id=row["id"],
name = row["name"], name=row["name"],
slug = row["slug"], slug=row["slug"],
description = row["description"], description=row["description"],
tier = row["tier"], tier=row["tier"],
status = row["status"], status=row["status"],
owner_id = row["owner_id"], owner_id=row["owner_id"],
created_at = ( created_at=(
datetime.fromisoformat(row["created_at"]) datetime.fromisoformat(row["created_at"])
if isinstance(row["created_at"], str) if isinstance(row["created_at"], str)
else row["created_at"] else row["created_at"]
), ),
updated_at = ( updated_at=(
datetime.fromisoformat(row["updated_at"]) datetime.fromisoformat(row["updated_at"])
if isinstance(row["updated_at"], str) if isinstance(row["updated_at"], str)
else row["updated_at"] else row["updated_at"]
), ),
expires_at = ( expires_at=(
datetime.fromisoformat(row["expires_at"]) datetime.fromisoformat(row["expires_at"])
if row["expires_at"] and isinstance(row["expires_at"], str) if row["expires_at"] and isinstance(row["expires_at"], str)
else row["expires_at"] else row["expires_at"]
), ),
settings = json.loads(row["settings"] or "{}"), settings=json.loads(row["settings"] or "{}"),
resource_limits = json.loads(row["resource_limits"] or "{}"), resource_limits=json.loads(row["resource_limits"] or "{}"),
metadata = json.loads(row["metadata"] or "{}"), metadata=json.loads(row["metadata"] or "{}"),
) )
def _row_to_domain(self, row: sqlite3.Row) -> TenantDomain: def _row_to_domain(self, row: sqlite3.Row) -> TenantDomain:
"""数据库行转换为 TenantDomain 对象""" """数据库行转换为 TenantDomain 对象"""
return TenantDomain( return TenantDomain(
id = row["id"], id=row["id"],
tenant_id = row["tenant_id"], tenant_id=row["tenant_id"],
domain = row["domain"], domain=row["domain"],
status = row["status"], status=row["status"],
verification_token = row["verification_token"], verification_token=row["verification_token"],
verification_method = row["verification_method"], verification_method=row["verification_method"],
verified_at = ( verified_at=(
datetime.fromisoformat(row["verified_at"]) datetime.fromisoformat(row["verified_at"])
if row["verified_at"] and isinstance(row["verified_at"], str) if row["verified_at"] and isinstance(row["verified_at"], str)
else row["verified_at"] else row["verified_at"]
), ),
created_at = ( created_at=(
datetime.fromisoformat(row["created_at"]) datetime.fromisoformat(row["created_at"])
if isinstance(row["created_at"], str) if isinstance(row["created_at"], str)
else row["created_at"] else row["created_at"]
), ),
updated_at = ( updated_at=(
datetime.fromisoformat(row["updated_at"]) datetime.fromisoformat(row["updated_at"])
if isinstance(row["updated_at"], str) if isinstance(row["updated_at"], str)
else row["updated_at"] else row["updated_at"]
), ),
is_primary = bool(row["is_primary"]), is_primary=bool(row["is_primary"]),
ssl_enabled = bool(row["ssl_enabled"]), ssl_enabled=bool(row["ssl_enabled"]),
ssl_expires_at = ( ssl_expires_at=(
datetime.fromisoformat(row["ssl_expires_at"]) datetime.fromisoformat(row["ssl_expires_at"])
if row["ssl_expires_at"] and isinstance(row["ssl_expires_at"], str) if row["ssl_expires_at"] and isinstance(row["ssl_expires_at"], str)
else row["ssl_expires_at"] else row["ssl_expires_at"]
@@ -1560,22 +1560,22 @@ class TenantManager:
def _row_to_branding(self, row: sqlite3.Row) -> TenantBranding: def _row_to_branding(self, row: sqlite3.Row) -> TenantBranding:
"""数据库行转换为 TenantBranding 对象""" """数据库行转换为 TenantBranding 对象"""
return TenantBranding( return TenantBranding(
id = row["id"], id=row["id"],
tenant_id = row["tenant_id"], tenant_id=row["tenant_id"],
logo_url = row["logo_url"], logo_url=row["logo_url"],
favicon_url = row["favicon_url"], favicon_url=row["favicon_url"],
primary_color = row["primary_color"], primary_color=row["primary_color"],
secondary_color = row["secondary_color"], secondary_color=row["secondary_color"],
custom_css = row["custom_css"], custom_css=row["custom_css"],
custom_js = row["custom_js"], custom_js=row["custom_js"],
login_page_bg = row["login_page_bg"], login_page_bg=row["login_page_bg"],
email_template = row["email_template"], email_template=row["email_template"],
created_at = ( created_at=(
datetime.fromisoformat(row["created_at"]) datetime.fromisoformat(row["created_at"])
if isinstance(row["created_at"], str) if isinstance(row["created_at"], str)
else row["created_at"] else row["created_at"]
), ),
updated_at = ( updated_at=(
datetime.fromisoformat(row["updated_at"]) datetime.fromisoformat(row["updated_at"])
if isinstance(row["updated_at"], str) if isinstance(row["updated_at"], str)
else row["updated_at"] else row["updated_at"]
@@ -1585,29 +1585,29 @@ class TenantManager:
def _row_to_member(self, row: sqlite3.Row) -> TenantMember: def _row_to_member(self, row: sqlite3.Row) -> TenantMember:
"""数据库行转换为 TenantMember 对象""" """数据库行转换为 TenantMember 对象"""
return TenantMember( return TenantMember(
id = row["id"], id=row["id"],
tenant_id = row["tenant_id"], tenant_id=row["tenant_id"],
user_id = row["user_id"], user_id=row["user_id"],
email = row["email"], email=row["email"],
role = row["role"], role=row["role"],
permissions = json.loads(row["permissions"] or "[]"), permissions=json.loads(row["permissions"] or "[]"),
invited_by = row["invited_by"], invited_by=row["invited_by"],
invited_at = ( invited_at=(
datetime.fromisoformat(row["invited_at"]) datetime.fromisoformat(row["invited_at"])
if isinstance(row["invited_at"], str) if isinstance(row["invited_at"], str)
else row["invited_at"] else row["invited_at"]
), ),
joined_at = ( joined_at=(
datetime.fromisoformat(row["joined_at"]) datetime.fromisoformat(row["joined_at"])
if row["joined_at"] and isinstance(row["joined_at"], str) if row["joined_at"] and isinstance(row["joined_at"], str)
else row["joined_at"] else row["joined_at"]
), ),
last_active_at = ( last_active_at=(
datetime.fromisoformat(row["last_active_at"]) datetime.fromisoformat(row["last_active_at"])
if row["last_active_at"] and isinstance(row["last_active_at"], str) if row["last_active_at"] and isinstance(row["last_active_at"], str)
else row["last_active_at"] else row["last_active_at"]
), ),
status = row["status"], status=row["status"],
) )

View File

@@ -32,16 +32,16 @@ def test_fulltext_search() -> None:
# 测试索引创建 # 测试索引创建
print("\n1. 测试索引创建...") print("\n1. 测试索引创建...")
success = search.index_content( success = search.index_content(
content_id = "test_entity_1", content_id="test_entity_1",
content_type = "entity", content_type="entity",
project_id = "test_project", project_id="test_project",
text = "这是一个测试实体,用于验证全文搜索功能。支持关键词高亮显示。", text="这是一个测试实体,用于验证全文搜索功能。支持关键词高亮显示。",
) )
print(f" 索引创建: {'✓ 成功' if success else '✗ 失败'}") print(f" 索引创建: {'✓ 成功' if success else '✗ 失败'}")
# 测试搜索 # 测试搜索
print("\n2. 测试关键词搜索...") print("\n2. 测试关键词搜索...")
results = search.search("测试", project_id = "test_project") results = search.search("测试", project_id="test_project")
print(f" 搜索结果数量: {len(results)}") print(f" 搜索结果数量: {len(results)}")
if results: if results:
print(f" 第一个结果: {results[0].content[:50]}...") print(f" 第一个结果: {results[0].content[:50]}...")
@@ -49,10 +49,10 @@ def test_fulltext_search() -> None:
# 测试布尔搜索 # 测试布尔搜索
print("\n3. 测试布尔搜索...") print("\n3. 测试布尔搜索...")
results = search.search("测试 AND 全文", project_id = "test_project") results = search.search("测试 AND 全文", project_id="test_project")
print(f" AND 搜索结果: {len(results)}") print(f" AND 搜索结果: {len(results)}")
results = search.search("测试 OR 关键词", project_id = "test_project") results = search.search("测试 OR 关键词", project_id="test_project")
print(f" OR 搜索结果: {len(results)}") print(f" OR 搜索结果: {len(results)}")
# 测试高亮 # 测试高亮
@@ -89,10 +89,10 @@ def test_semantic_search() -> None:
# 测试索引 # 测试索引
print("\n3. 测试语义索引...") print("\n3. 测试语义索引...")
success = semantic.index_embedding( success = semantic.index_embedding(
content_id = "test_content_1", content_id="test_content_1",
content_type = "transcript", content_type="transcript",
project_id = "test_project", project_id="test_project",
text = "这是用于语义搜索测试的文本内容。", text="这是用于语义搜索测试的文本内容。",
) )
print(f" 索引创建: {'✓ 成功' if success else '✗ 失败'}") print(f" 索引创建: {'✓ 成功' if success else '✗ 失败'}")
@@ -150,7 +150,7 @@ def test_cache_manager() -> None:
print("\n2. 测试缓存操作...") print("\n2. 测试缓存操作...")
# 设置缓存 # 设置缓存
cache.set("test_key_1", {"name": "测试数据", "value": 123}, ttl = 60) cache.set("test_key_1", {"name": "测试数据", "value": 123}, ttl=60)
print(" ✓ 设置缓存 test_key_1") print(" ✓ 设置缓存 test_key_1")
# 获取缓存 # 获取缓存
@@ -159,7 +159,7 @@ def test_cache_manager() -> None:
# 批量操作 # 批量操作
cache.set_many( cache.set_many(
{"batch_key_1": "value1", "batch_key_2": "value2", "batch_key_3": "value3"}, ttl = 60 {"batch_key_1": "value1", "batch_key_2": "value2", "batch_key_3": "value3"}, ttl=60
) )
print(" ✓ 批量设置缓存") print(" ✓ 批量设置缓存")
@@ -208,7 +208,7 @@ def test_task_queue() -> None:
# 提交任务 # 提交任务
task_id = queue.submit( task_id = queue.submit(
task_type = "test_task", payload = {"test": "data", "timestamp": time.time()} task_type="test_task", payload={"test": "data", "timestamp": time.time()}
) )
print(" ✓ 提交任务: {task_id}") print(" ✓ 提交任务: {task_id}")
@@ -240,25 +240,25 @@ def test_performance_monitor() -> None:
# 记录一些测试指标 # 记录一些测试指标
for i in range(5): for i in range(5):
monitor.record_metric( monitor.record_metric(
metric_type = "api_response", metric_type="api_response",
duration_ms = 50 + i * 10, duration_ms=50 + i * 10,
endpoint = "/api/v1/test", endpoint="/api/v1/test",
metadata = {"test": True}, metadata={"test": True},
) )
for i in range(3): for i in range(3):
monitor.record_metric( monitor.record_metric(
metric_type = "db_query", metric_type="db_query",
duration_ms = 20 + i * 5, duration_ms=20 + i * 5,
endpoint = "SELECT test", endpoint="SELECT test",
metadata = {"test": True}, metadata={"test": True},
) )
print(" ✓ 记录了 8 个测试指标") print(" ✓ 记录了 8 个测试指标")
# 获取统计 # 获取统计
print("\n2. 获取性能统计...") print("\n2. 获取性能统计...")
stats = monitor.get_stats(hours = 1) stats = monitor.get_stats(hours=1)
print(f" 总请求数: {stats['overall']['total_requests']}") print(f" 总请求数: {stats['overall']['total_requests']}")
print(f" 平均响应时间: {stats['overall']['avg_duration_ms']} ms") print(f" 平均响应时间: {stats['overall']['avg_duration_ms']} ms")
print(f" 最大响应时间: {stats['overall']['max_duration_ms']} ms") print(f" 最大响应时间: {stats['overall']['max_duration_ms']} ms")

View File

@@ -29,7 +29,7 @@ def test_tenant_management() -> None:
# 1. 创建租户 # 1. 创建租户
print("\n1.1 创建租户...") print("\n1.1 创建租户...")
tenant = manager.create_tenant( tenant = manager.create_tenant(
name = "Test Company", owner_id = "user_001", tier = "pro", description = "A test company tenant" name="Test Company", owner_id="user_001", tier="pro", description="A test company tenant"
) )
print(f"✅ 租户创建成功: {tenant.id}") print(f"✅ 租户创建成功: {tenant.id}")
print(f" - 名称: {tenant.name}") print(f" - 名称: {tenant.name}")
@@ -53,14 +53,14 @@ def test_tenant_management() -> None:
# 4. 更新租户 # 4. 更新租户
print("\n1.4 更新租户信息...") print("\n1.4 更新租户信息...")
updated = manager.update_tenant( updated = manager.update_tenant(
tenant_id = tenant.id, name = "Test Company Updated", tier = "enterprise" tenant_id=tenant.id, name="Test Company Updated", tier="enterprise"
) )
assert updated is not None, "更新租户失败" assert updated is not None, "更新租户失败"
print(f"✅ 租户更新成功: {updated.name}, 层级: {updated.tier}") print(f"✅ 租户更新成功: {updated.name}, 层级: {updated.tier}")
# 5. 列出租户 # 5. 列出租户
print("\n1.5 列出租户...") print("\n1.5 列出租户...")
tenants = manager.list_tenants(limit = 10) tenants = manager.list_tenants(limit=10)
print(f"✅ 找到 {len(tenants)} 个租户") print(f"✅ 找到 {len(tenants)} 个租户")
return tenant.id return tenant.id
@@ -76,7 +76,7 @@ def test_domain_management(tenant_id: str) -> None:
# 1. 添加域名 # 1. 添加域名
print("\n2.1 添加自定义域名...") print("\n2.1 添加自定义域名...")
domain = manager.add_domain(tenant_id = tenant_id, domain = "test.example.com", is_primary = True) domain = manager.add_domain(tenant_id=tenant_id, domain="test.example.com", is_primary=True)
print(f"✅ 域名添加成功: {domain.domain}") print(f"✅ 域名添加成功: {domain.domain}")
print(f" - ID: {domain.id}") print(f" - ID: {domain.id}")
print(f" - 状态: {domain.status}") print(f" - 状态: {domain.status}")
@@ -123,14 +123,14 @@ def test_branding_management(tenant_id: str) -> None:
# 1. 更新品牌配置 # 1. 更新品牌配置
print("\n3.1 更新品牌配置...") print("\n3.1 更新品牌配置...")
branding = manager.update_branding( branding = manager.update_branding(
tenant_id = tenant_id, tenant_id=tenant_id,
logo_url = "https://example.com/logo.png", logo_url="https://example.com/logo.png",
favicon_url = "https://example.com/favicon.ico", favicon_url="https://example.com/favicon.ico",
primary_color = "#1890ff", primary_color="#1890ff",
secondary_color = "#52c41a", secondary_color="#52c41a",
custom_css = ".header { background: #1890ff; }", custom_css=".header { background: #1890ff; }",
custom_js = "console.log('Custom JS loaded');", custom_js="console.log('Custom JS loaded');",
login_page_bg = "https://example.com/bg.jpg", login_page_bg="https://example.com/bg.jpg",
) )
print("✅ 品牌配置更新成功") print("✅ 品牌配置更新成功")
print(f" - Logo: {branding.logo_url}") print(f" - Logo: {branding.logo_url}")
@@ -163,7 +163,7 @@ def test_member_management(tenant_id: str) -> None:
# 1. 邀请成员 # 1. 邀请成员
print("\n4.1 邀请成员...") print("\n4.1 邀请成员...")
member1 = manager.invite_member( member1 = manager.invite_member(
tenant_id = tenant_id, email = "admin@test.com", role = "admin", invited_by = "user_001" tenant_id=tenant_id, email="admin@test.com", role="admin", invited_by="user_001"
) )
print(f"✅ 成员邀请成功: {member1.email}") print(f"✅ 成员邀请成功: {member1.email}")
print(f" - ID: {member1.id}") print(f" - ID: {member1.id}")
@@ -171,7 +171,7 @@ def test_member_management(tenant_id: str) -> None:
print(f" - 权限: {member1.permissions}") print(f" - 权限: {member1.permissions}")
member2 = manager.invite_member( member2 = manager.invite_member(
tenant_id = tenant_id, email = "member@test.com", role = "member", invited_by = "user_001" tenant_id=tenant_id, email="member@test.com", role="member", invited_by="user_001"
) )
print(f"✅ 成员邀请成功: {member2.email}") print(f"✅ 成员邀请成功: {member2.email}")
@@ -218,13 +218,13 @@ def test_usage_tracking(tenant_id: str) -> None:
# 1. 记录使用 # 1. 记录使用
print("\n5.1 记录资源使用...") print("\n5.1 记录资源使用...")
manager.record_usage( manager.record_usage(
tenant_id = tenant_id, tenant_id=tenant_id,
storage_bytes = 1024 * 1024 * 50, # 50MB storage_bytes=1024 * 1024 * 50, # 50MB
transcription_seconds = 600, # 10分钟 transcription_seconds=600, # 10分钟
api_calls = 100, api_calls=100,
projects_count = 5, projects_count=5,
entities_count = 50, entities_count=50,
members_count = 3, members_count=3,
) )
print("✅ 资源使用记录成功") print("✅ 资源使用记录成功")

View File

@@ -19,10 +19,10 @@ def test_subscription_manager() -> None:
print(" = " * 60) print(" = " * 60)
# 使用临时文件数据库进行测试 # 使用临时文件数据库进行测试
db_path = tempfile.mktemp(suffix = ".db") db_path = tempfile.mktemp(suffix=".db")
try: try:
manager = SubscriptionManager(db_path = db_path) manager = SubscriptionManager(db_path=db_path)
print("\n1. 测试订阅计划管理") print("\n1. 测试订阅计划管理")
print("-" * 40) print("-" * 40)
@@ -53,10 +53,10 @@ def test_subscription_manager() -> None:
# 创建订阅 # 创建订阅
subscription = manager.create_subscription( subscription = manager.create_subscription(
tenant_id = tenant_id, tenant_id=tenant_id,
plan_id = pro_plan.id, plan_id=pro_plan.id,
payment_provider = PaymentProvider.STRIPE.value, payment_provider=PaymentProvider.STRIPE.value,
trial_days = 14, trial_days=14,
) )
print(f"✓ 创建订阅: {subscription.id}") print(f"✓ 创建订阅: {subscription.id}")
@@ -75,21 +75,21 @@ def test_subscription_manager() -> None:
# 记录转录用量 # 记录转录用量
usage1 = manager.record_usage( usage1 = manager.record_usage(
tenant_id = tenant_id, tenant_id=tenant_id,
resource_type = "transcription", resource_type="transcription",
quantity = 120, quantity=120,
unit = "minute", unit="minute",
description = "会议转录", description="会议转录",
) )
print(f"✓ 记录转录用量: {usage1.quantity} {usage1.unit}, 费用: ¥{usage1.cost:.2f}") print(f"✓ 记录转录用量: {usage1.quantity} {usage1.unit}, 费用: ¥{usage1.cost:.2f}")
# 记录存储用量 # 记录存储用量
usage2 = manager.record_usage( usage2 = manager.record_usage(
tenant_id = tenant_id, tenant_id=tenant_id,
resource_type = "storage", resource_type="storage",
quantity = 2.5, quantity=2.5,
unit = "gb", unit="gb",
description = "文件存储", description="文件存储",
) )
print(f"✓ 记录存储用量: {usage2.quantity} {usage2.unit}, 费用: ¥{usage2.cost:.2f}") print(f"✓ 记录存储用量: {usage2.quantity} {usage2.unit}, 费用: ¥{usage2.cost:.2f}")
@@ -105,11 +105,11 @@ def test_subscription_manager() -> None:
# 创建支付 # 创建支付
payment = manager.create_payment( payment = manager.create_payment(
tenant_id = tenant_id, tenant_id=tenant_id,
amount = 99.0, amount=99.0,
currency = "CNY", currency="CNY",
provider = PaymentProvider.ALIPAY.value, provider=PaymentProvider.ALIPAY.value,
payment_method = "qrcode", payment_method="qrcode",
) )
print(f"✓ 创建支付: {payment.id}") print(f"✓ 创建支付: {payment.id}")
print(f" - 金额: ¥{payment.amount}") print(f" - 金额: ¥{payment.amount}")
@@ -142,11 +142,11 @@ def test_subscription_manager() -> None:
# 申请退款 # 申请退款
refund = manager.request_refund( refund = manager.request_refund(
tenant_id = tenant_id, tenant_id=tenant_id,
payment_id = payment.id, payment_id=payment.id,
amount = 50.0, amount=50.0,
reason = "服务不满意", reason="服务不满意",
requested_by = "user_001", requested_by="user_001",
) )
print(f"✓ 申请退款: {refund.id}") print(f"✓ 申请退款: {refund.id}")
print(f" - 金额: ¥{refund.amount}") print(f" - 金额: ¥{refund.amount}")
@@ -178,19 +178,19 @@ def test_subscription_manager() -> None:
# Stripe Checkout # Stripe Checkout
stripe_session = manager.create_stripe_checkout_session( stripe_session = manager.create_stripe_checkout_session(
tenant_id = tenant_id, tenant_id=tenant_id,
plan_id = enterprise_plan.id, plan_id=enterprise_plan.id,
success_url = "https://example.com/success", success_url="https://example.com/success",
cancel_url = "https://example.com/cancel", cancel_url="https://example.com/cancel",
) )
print(f"✓ Stripe Checkout 会话: {stripe_session['session_id']}") print(f"✓ Stripe Checkout 会话: {stripe_session['session_id']}")
# 支付宝订单 # 支付宝订单
alipay_order = manager.create_alipay_order(tenant_id = tenant_id, plan_id = pro_plan.id) alipay_order = manager.create_alipay_order(tenant_id=tenant_id, plan_id=pro_plan.id)
print(f"✓ 支付宝订单: {alipay_order['order_id']}") print(f"✓ 支付宝订单: {alipay_order['order_id']}")
# 微信支付订单 # 微信支付订单
wechat_order = manager.create_wechat_order(tenant_id = tenant_id, plan_id = pro_plan.id) wechat_order = manager.create_wechat_order(tenant_id=tenant_id, plan_id=pro_plan.id)
print(f"✓ 微信支付订单: {wechat_order['order_id']}") print(f"✓ 微信支付订单: {wechat_order['order_id']}")
# Webhook 处理 # Webhook 处理
@@ -205,12 +205,12 @@ def test_subscription_manager() -> None:
# 更改计划 # 更改计划
changed = manager.change_plan( changed = manager.change_plan(
subscription_id = subscription.id, new_plan_id = enterprise_plan.id subscription_id=subscription.id, new_plan_id=enterprise_plan.id
) )
print(f"✓ 更改计划: {changed.plan_id} (Enterprise)") print(f"✓ 更改计划: {changed.plan_id} (Enterprise)")
# 取消订阅 # 取消订阅
cancelled = manager.cancel_subscription(subscription_id = subscription.id, at_period_end = True) cancelled = manager.cancel_subscription(subscription_id=subscription.id, at_period_end=True)
print(f"✓ 取消订阅: {cancelled.status}") print(f"✓ 取消订阅: {cancelled.status}")
print(f" - 周期结束时取消: {cancelled.cancel_at_period_end}") print(f" - 周期结束时取消: {cancelled.cancel_at_period_end}")

View File

@@ -23,16 +23,16 @@ def test_custom_model() -> None:
# 1. 创建自定义模型 # 1. 创建自定义模型
print("1. 创建自定义模型...") print("1. 创建自定义模型...")
model = manager.create_custom_model( model = manager.create_custom_model(
tenant_id = "tenant_001", tenant_id="tenant_001",
name = "领域实体识别模型", name="领域实体识别模型",
description = "用于识别医疗领域实体的自定义模型", description="用于识别医疗领域实体的自定义模型",
model_type = ModelType.CUSTOM_NER, model_type=ModelType.CUSTOM_NER,
training_data = { training_data={
"entity_types": ["DISEASE", "SYMPTOM", "DRUG", "TREATMENT"], "entity_types": ["DISEASE", "SYMPTOM", "DRUG", "TREATMENT"],
"domain": "medical", "domain": "medical",
}, },
hyperparameters = {"epochs": 15, "learning_rate": 0.001, "batch_size": 32}, hyperparameters={"epochs": 15, "learning_rate": 0.001, "batch_size": 32},
created_by = "user_001", created_by="user_001",
) )
print(f" 创建成功: {model.id}, 状态: {model.status.value}") print(f" 创建成功: {model.id}, 状态: {model.status.value}")
@@ -67,10 +67,10 @@ def test_custom_model() -> None:
for sample_data in samples: for sample_data in samples:
sample = manager.add_training_sample( sample = manager.add_training_sample(
model_id = model.id, model_id=model.id,
text = sample_data["text"], text=sample_data["text"],
entities = sample_data["entities"], entities=sample_data["entities"],
metadata = {"source": "manual"}, metadata={"source": "manual"},
) )
print(f" 添加样本: {sample.id}") print(f" 添加样本: {sample.id}")
@@ -81,7 +81,7 @@ def test_custom_model() -> None:
# 4. 列出自定义模型 # 4. 列出自定义模型
print("4. 列出自定义模型...") print("4. 列出自定义模型...")
models = manager.list_custom_models(tenant_id = "tenant_001") models = manager.list_custom_models(tenant_id="tenant_001")
print(f" 找到 {len(models)} 个模型") print(f" 找到 {len(models)} 个模型")
for m in models: for m in models:
print(f" - {m.name} ({m.model_type.value}): {m.status.value}") print(f" - {m.name} ({m.model_type.value}): {m.status.value}")
@@ -125,32 +125,32 @@ def test_prediction_models() -> None:
# 1. 创建趋势预测模型 # 1. 创建趋势预测模型
print("1. 创建趋势预测模型...") print("1. 创建趋势预测模型...")
trend_model = manager.create_prediction_model( trend_model = manager.create_prediction_model(
tenant_id = "tenant_001", tenant_id="tenant_001",
project_id = "project_001", project_id="project_001",
name = "实体数量趋势预测", name="实体数量趋势预测",
prediction_type = PredictionType.TREND, prediction_type=PredictionType.TREND,
target_entity_type = "PERSON", target_entity_type="PERSON",
features = ["entity_count", "time_period", "document_count"], features=["entity_count", "time_period", "document_count"],
model_config = {"algorithm": "linear_regression", "window_size": 7}, model_config={"algorithm": "linear_regression", "window_size": 7},
) )
print(f" 创建成功: {trend_model.id}") print(f" 创建成功: {trend_model.id}")
# 2. 创建异常检测模型 # 2. 创建异常检测模型
print("2. 创建异常检测模型...") print("2. 创建异常检测模型...")
anomaly_model = manager.create_prediction_model( anomaly_model = manager.create_prediction_model(
tenant_id = "tenant_001", tenant_id="tenant_001",
project_id = "project_001", project_id="project_001",
name = "实体增长异常检测", name="实体增长异常检测",
prediction_type = PredictionType.ANOMALY, prediction_type=PredictionType.ANOMALY,
target_entity_type = None, target_entity_type=None,
features = ["daily_growth", "weekly_growth"], features=["daily_growth", "weekly_growth"],
model_config = {"threshold": 2.5, "sensitivity": "medium"}, model_config={"threshold": 2.5, "sensitivity": "medium"},
) )
print(f" 创建成功: {anomaly_model.id}") print(f" 创建成功: {anomaly_model.id}")
# 3. 列出预测模型 # 3. 列出预测模型
print("3. 列出预测模型...") print("3. 列出预测模型...")
models = manager.list_prediction_models(tenant_id = "tenant_001") models = manager.list_prediction_models(tenant_id="tenant_001")
print(f" 找到 {len(models)} 个预测模型") print(f" 找到 {len(models)} 个预测模型")
for m in models: for m in models:
print(f" - {m.name} ({m.prediction_type.value})") print(f" - {m.name} ({m.prediction_type.value})")
@@ -202,22 +202,22 @@ def test_kg_rag() -> None:
# 创建 RAG 配置 # 创建 RAG 配置
print("1. 创建知识图谱 RAG 配置...") print("1. 创建知识图谱 RAG 配置...")
rag = manager.create_kg_rag( rag = manager.create_kg_rag(
tenant_id = "tenant_001", tenant_id="tenant_001",
project_id = "project_001", project_id="project_001",
name = "项目知识问答", name="项目知识问答",
description = "基于项目知识图谱的智能问答", description="基于项目知识图谱的智能问答",
kg_config = { kg_config={
"entity_types": ["PERSON", "ORG", "PROJECT", "TECH"], "entity_types": ["PERSON", "ORG", "PROJECT", "TECH"],
"relation_types": ["works_with", "belongs_to", "depends_on"], "relation_types": ["works_with", "belongs_to", "depends_on"],
}, },
retrieval_config = {"top_k": 5, "similarity_threshold": 0.7, "expand_relations": True}, retrieval_config={"top_k": 5, "similarity_threshold": 0.7, "expand_relations": True},
generation_config = {"temperature": 0.3, "max_tokens": 1000, "include_sources": True}, generation_config={"temperature": 0.3, "max_tokens": 1000, "include_sources": True},
) )
print(f" 创建成功: {rag.id}") print(f" 创建成功: {rag.id}")
# 列出 RAG 配置 # 列出 RAG 配置
print("2. 列出 RAG 配置...") print("2. 列出 RAG 配置...")
rags = manager.list_kg_rags(tenant_id = "tenant_001") rags = manager.list_kg_rags(tenant_id="tenant_001")
print(f" 找到 {len(rags)} 个配置") print(f" 找到 {len(rags)} 个配置")
return rag.id return rag.id
@@ -279,10 +279,10 @@ async def test_kg_rag_query(rag_id: str) -> None:
try: try:
result = await manager.query_kg_rag( result = await manager.query_kg_rag(
rag_id = rag_id, rag_id=rag_id,
query = query_text, query=query_text,
project_entities = project_entities, project_entities=project_entities,
project_relations = project_relations, project_relations=project_relations,
) )
print(f" 查询: {result.query}") print(f" 查询: {result.query}")
@@ -326,12 +326,12 @@ async def test_smart_summary() -> None:
print(f"1. 生成 {summary_type} 类型摘要...") print(f"1. 生成 {summary_type} 类型摘要...")
try: try:
summary = await manager.generate_smart_summary( summary = await manager.generate_smart_summary(
tenant_id = "tenant_001", tenant_id="tenant_001",
project_id = "project_001", project_id="project_001",
source_type = "transcript", source_type="transcript",
source_id = "transcript_001", source_id="transcript_001",
summary_type = summary_type, summary_type=summary_type,
content_data = content_data, content_data=content_data,
) )
print(f" 摘要类型: {summary.summary_type}") print(f" 摘要类型: {summary.summary_type}")

View File

@@ -56,15 +56,15 @@ class TestGrowthManager:
try: try:
event = await self.manager.track_event( event = await self.manager.track_event(
tenant_id = self.test_tenant_id, tenant_id=self.test_tenant_id,
user_id = self.test_user_id, user_id=self.test_user_id,
event_type = EventType.PAGE_VIEW, event_type=EventType.PAGE_VIEW,
event_name = "dashboard_view", event_name="dashboard_view",
properties = {"page": "/dashboard", "duration": 120}, properties={"page": "/dashboard", "duration": 120},
session_id = "session_001", session_id="session_001",
device_info = {"browser": "Chrome", "os": "MacOS"}, device_info={"browser": "Chrome", "os": "MacOS"},
referrer = "https://google.com", referrer="https://google.com",
utm_params = {"source": "google", "medium": "organic", "campaign": "summer"}, utm_params={"source": "google", "medium": "organic", "campaign": "summer"},
) )
assert event.id is not None assert event.id is not None
@@ -74,7 +74,7 @@ class TestGrowthManager:
self.log(f"事件追踪成功: {event.id}") self.log(f"事件追踪成功: {event.id}")
return True return True
except Exception as e: except Exception as e:
self.log(f"事件追踪失败: {e}", success = False) self.log(f"事件追踪失败: {e}", success=False)
return False return False
async def test_track_multiple_events(self) -> None: async def test_track_multiple_events(self) -> None:
@@ -91,17 +91,17 @@ class TestGrowthManager:
for event_type, event_name, props in events: for event_type, event_name, props in events:
await self.manager.track_event( await self.manager.track_event(
tenant_id = self.test_tenant_id, tenant_id=self.test_tenant_id,
user_id = self.test_user_id, user_id=self.test_user_id,
event_type = event_type, event_type=event_type,
event_name = event_name, event_name=event_name,
properties = props, properties=props,
) )
self.log(f"成功追踪 {len(events)} 个事件") self.log(f"成功追踪 {len(events)} 个事件")
return True return True
except Exception as e: except Exception as e:
self.log(f"批量事件追踪失败: {e}", success = False) self.log(f"批量事件追踪失败: {e}", success=False)
return False return False
def test_get_user_profile(self) -> None: def test_get_user_profile(self) -> None:
@@ -120,7 +120,7 @@ class TestGrowthManager:
return True return True
except Exception as e: except Exception as e:
self.log(f"获取用户画像失败: {e}", success = False) self.log(f"获取用户画像失败: {e}", success=False)
return False return False
def test_get_analytics_summary(self) -> None: def test_get_analytics_summary(self) -> None:
@@ -129,9 +129,9 @@ class TestGrowthManager:
try: try:
summary = self.manager.get_user_analytics_summary( summary = self.manager.get_user_analytics_summary(
tenant_id = self.test_tenant_id, tenant_id=self.test_tenant_id,
start_date = datetime.now() - timedelta(days = 7), start_date=datetime.now() - timedelta(days=7),
end_date = datetime.now(), end_date=datetime.now(),
) )
assert "unique_users" in summary assert "unique_users" in summary
@@ -141,7 +141,7 @@ class TestGrowthManager:
self.log(f"分析汇总: {summary['unique_users']} 用户, {summary['total_events']} 事件") self.log(f"分析汇总: {summary['unique_users']} 用户, {summary['total_events']} 事件")
return True return True
except Exception as e: except Exception as e:
self.log(f"获取分析汇总失败: {e}", success = False) self.log(f"获取分析汇总失败: {e}", success=False)
return False return False
def test_create_funnel(self) -> None: def test_create_funnel(self) -> None:
@@ -150,16 +150,16 @@ class TestGrowthManager:
try: try:
funnel = self.manager.create_funnel( funnel = self.manager.create_funnel(
tenant_id = self.test_tenant_id, tenant_id=self.test_tenant_id,
name = "用户注册转化漏斗", name="用户注册转化漏斗",
description = "从访问到完成注册的转化流程", description="从访问到完成注册的转化流程",
steps = [ steps=[
{"name": "访问首页", "event_name": "page_view_home"}, {"name": "访问首页", "event_name": "page_view_home"},
{"name": "点击注册", "event_name": "signup_click"}, {"name": "点击注册", "event_name": "signup_click"},
{"name": "填写信息", "event_name": "signup_form_fill"}, {"name": "填写信息", "event_name": "signup_form_fill"},
{"name": "完成注册", "event_name": "signup_complete"}, {"name": "完成注册", "event_name": "signup_complete"},
], ],
created_by = "test", created_by="test",
) )
assert funnel.id is not None assert funnel.id is not None
@@ -168,7 +168,7 @@ class TestGrowthManager:
self.log(f"漏斗创建成功: {funnel.id}") self.log(f"漏斗创建成功: {funnel.id}")
return funnel.id return funnel.id
except Exception as e: except Exception as e:
self.log(f"创建漏斗失败: {e}", success = False) self.log(f"创建漏斗失败: {e}", success=False)
return None return None
def test_analyze_funnel(self, funnel_id: str) -> None: def test_analyze_funnel(self, funnel_id: str) -> None:
@@ -181,9 +181,9 @@ class TestGrowthManager:
try: try:
analysis = self.manager.analyze_funnel( analysis = self.manager.analyze_funnel(
funnel_id = funnel_id, funnel_id=funnel_id,
period_start = datetime.now() - timedelta(days = 30), period_start=datetime.now() - timedelta(days=30),
period_end = datetime.now(), period_end=datetime.now(),
) )
if analysis: if analysis:
@@ -194,7 +194,7 @@ class TestGrowthManager:
self.log("漏斗分析返回空结果") self.log("漏斗分析返回空结果")
return False return False
except Exception as e: except Exception as e:
self.log(f"漏斗分析失败: {e}", success = False) self.log(f"漏斗分析失败: {e}", success=False)
return False return False
def test_calculate_retention(self) -> None: def test_calculate_retention(self) -> None:
@@ -203,9 +203,9 @@ class TestGrowthManager:
try: try:
retention = self.manager.calculate_retention( retention = self.manager.calculate_retention(
tenant_id = self.test_tenant_id, tenant_id=self.test_tenant_id,
cohort_date = datetime.now() - timedelta(days = 7), cohort_date=datetime.now() - timedelta(days=7),
periods = [1, 3, 7], periods=[1, 3, 7],
) )
assert "cohort_date" in retention assert "cohort_date" in retention
@@ -214,7 +214,7 @@ class TestGrowthManager:
self.log(f"留存率计算完成: 同期群 {retention['cohort_size']} 用户") self.log(f"留存率计算完成: 同期群 {retention['cohort_size']} 用户")
return True return True
except Exception as e: except Exception as e:
self.log(f"留存率计算失败: {e}", success = False) self.log(f"留存率计算失败: {e}", success=False)
return False return False
# ==================== 测试 A/B 测试框架 ==================== # ==================== 测试 A/B 测试框架 ====================
@@ -225,23 +225,23 @@ class TestGrowthManager:
try: try:
experiment = self.manager.create_experiment( experiment = self.manager.create_experiment(
tenant_id = self.test_tenant_id, tenant_id=self.test_tenant_id,
name = "首页按钮颜色测试", name="首页按钮颜色测试",
description = "测试不同按钮颜色对转化率的影响", description="测试不同按钮颜色对转化率的影响",
hypothesis = "蓝色按钮比红色按钮有更高的点击率", hypothesis="蓝色按钮比红色按钮有更高的点击率",
variants = [ variants=[
{"id": "control", "name": "红色按钮", "is_control": True}, {"id": "control", "name": "红色按钮", "is_control": True},
{"id": "variant_a", "name": "蓝色按钮", "is_control": False}, {"id": "variant_a", "name": "蓝色按钮", "is_control": False},
{"id": "variant_b", "name": "绿色按钮", "is_control": False}, {"id": "variant_b", "name": "绿色按钮", "is_control": False},
], ],
traffic_allocation = TrafficAllocationType.RANDOM, traffic_allocation=TrafficAllocationType.RANDOM,
traffic_split = {"control": 0.34, "variant_a": 0.33, "variant_b": 0.33}, traffic_split={"control": 0.34, "variant_a": 0.33, "variant_b": 0.33},
target_audience = {"conditions": []}, target_audience={"conditions": []},
primary_metric = "button_click_rate", primary_metric="button_click_rate",
secondary_metrics = ["conversion_rate", "bounce_rate"], secondary_metrics=["conversion_rate", "bounce_rate"],
min_sample_size = 100, min_sample_size=100,
confidence_level = 0.95, confidence_level=0.95,
created_by = "test", created_by="test",
) )
assert experiment.id is not None assert experiment.id is not None
@@ -250,7 +250,7 @@ class TestGrowthManager:
self.log(f"实验创建成功: {experiment.id}") self.log(f"实验创建成功: {experiment.id}")
return experiment.id return experiment.id
except Exception as e: except Exception as e:
self.log(f"创建实验失败: {e}", success = False) self.log(f"创建实验失败: {e}", success=False)
return None return None
def test_list_experiments(self) -> None: def test_list_experiments(self) -> None:
@@ -263,7 +263,7 @@ class TestGrowthManager:
self.log(f"列出 {len(experiments)} 个实验") self.log(f"列出 {len(experiments)} 个实验")
return True return True
except Exception as e: except Exception as e:
self.log(f"列出实验失败: {e}", success = False) self.log(f"列出实验失败: {e}", success=False)
return False return False
def test_assign_variant(self, experiment_id: str) -> None: def test_assign_variant(self, experiment_id: str) -> None:
@@ -284,9 +284,9 @@ class TestGrowthManager:
for user_id in test_users: for user_id in test_users:
variant_id = self.manager.assign_variant( variant_id = self.manager.assign_variant(
experiment_id = experiment_id, experiment_id=experiment_id,
user_id = user_id, user_id=user_id,
user_attributes = {"user_id": user_id, "segment": "new"}, user_attributes={"user_id": user_id, "segment": "new"},
) )
if variant_id: if variant_id:
@@ -295,7 +295,7 @@ class TestGrowthManager:
self.log(f"变体分配完成: {len(assignments)} 个用户") self.log(f"变体分配完成: {len(assignments)} 个用户")
return True return True
except Exception as e: except Exception as e:
self.log(f"变体分配失败: {e}", success = False) self.log(f"变体分配失败: {e}", success=False)
return False return False
def test_record_experiment_metric(self, experiment_id: str) -> None: def test_record_experiment_metric(self, experiment_id: str) -> None:
@@ -318,17 +318,17 @@ class TestGrowthManager:
for user_id, variant_id, value in test_data: for user_id, variant_id, value in test_data:
self.manager.record_experiment_metric( self.manager.record_experiment_metric(
experiment_id = experiment_id, experiment_id=experiment_id,
variant_id = variant_id, variant_id=variant_id,
user_id = user_id, user_id=user_id,
metric_name = "button_click_rate", metric_name="button_click_rate",
metric_value = value, metric_value=value,
) )
self.log(f"成功记录 {len(test_data)} 条指标") self.log(f"成功记录 {len(test_data)} 条指标")
return True return True
except Exception as e: except Exception as e:
self.log(f"记录指标失败: {e}", success = False) self.log(f"记录指标失败: {e}", success=False)
return False return False
def test_analyze_experiment(self, experiment_id: str) -> None: def test_analyze_experiment(self, experiment_id: str) -> None:
@@ -346,10 +346,10 @@ class TestGrowthManager:
self.log(f"实验分析完成: {len(result.get('variant_results', {}))} 个变体") self.log(f"实验分析完成: {len(result.get('variant_results', {}))} 个变体")
return True return True
else: else:
self.log(f"实验分析返回错误: {result['error']}", success = False) self.log(f"实验分析返回错误: {result['error']}", success=False)
return False return False
except Exception as e: except Exception as e:
self.log(f"实验分析失败: {e}", success = False) self.log(f"实验分析失败: {e}", success=False)
return False return False
# ==================== 测试邮件营销 ==================== # ==================== 测试邮件营销 ====================
@@ -360,11 +360,11 @@ class TestGrowthManager:
try: try:
template = self.manager.create_email_template( template = self.manager.create_email_template(
tenant_id = self.test_tenant_id, tenant_id=self.test_tenant_id,
name = "欢迎邮件", name="欢迎邮件",
template_type = EmailTemplateType.WELCOME, template_type=EmailTemplateType.WELCOME,
subject = "欢迎加入 InsightFlow", subject="欢迎加入 InsightFlow",
html_content = """ html_content="""
<h1>欢迎,{{user_name}}</h1> <h1>欢迎,{{user_name}}</h1>
<p>感谢您注册 InsightFlow。我们很高兴您能加入我们</p> <p>感谢您注册 InsightFlow。我们很高兴您能加入我们</p>
<p>您的账户已创建,可以开始使用以下功能:</p> <p>您的账户已创建,可以开始使用以下功能:</p>
@@ -375,8 +375,8 @@ class TestGrowthManager:
</ul> </ul>
<p><a href = "{{dashboard_url}}">立即开始使用</a></p> <p><a href = "{{dashboard_url}}">立即开始使用</a></p>
""", """,
from_name = "InsightFlow 团队", from_name="InsightFlow 团队",
from_email = "welcome@insightflow.io", from_email="welcome@insightflow.io",
) )
assert template.id is not None assert template.id is not None
@@ -385,7 +385,7 @@ class TestGrowthManager:
self.log(f"邮件模板创建成功: {template.id}") self.log(f"邮件模板创建成功: {template.id}")
return template.id return template.id
except Exception as e: except Exception as e:
self.log(f"创建邮件模板失败: {e}", success = False) self.log(f"创建邮件模板失败: {e}", success=False)
return None return None
def test_list_email_templates(self) -> None: def test_list_email_templates(self) -> None:
@@ -398,7 +398,7 @@ class TestGrowthManager:
self.log(f"列出 {len(templates)} 个邮件模板") self.log(f"列出 {len(templates)} 个邮件模板")
return True return True
except Exception as e: except Exception as e:
self.log(f"列出邮件模板失败: {e}", success = False) self.log(f"列出邮件模板失败: {e}", success=False)
return False return False
def test_render_template(self, template_id: str) -> None: def test_render_template(self, template_id: str) -> None:
@@ -411,8 +411,8 @@ class TestGrowthManager:
try: try:
rendered = self.manager.render_template( rendered = self.manager.render_template(
template_id = template_id, template_id=template_id,
variables = { variables={
"user_name": "张三", "user_name": "张三",
"dashboard_url": "https://app.insightflow.io/dashboard", "dashboard_url": "https://app.insightflow.io/dashboard",
}, },
@@ -424,10 +424,10 @@ class TestGrowthManager:
self.log(f"模板渲染成功: {rendered['subject']}") self.log(f"模板渲染成功: {rendered['subject']}")
return True return True
else: else:
self.log("模板渲染返回空结果", success = False) self.log("模板渲染返回空结果", success=False)
return False return False
except Exception as e: except Exception as e:
self.log(f"模板渲染失败: {e}", success = False) self.log(f"模板渲染失败: {e}", success=False)
return False return False
def test_create_email_campaign(self, template_id: str) -> None: def test_create_email_campaign(self, template_id: str) -> None:
@@ -440,10 +440,10 @@ class TestGrowthManager:
try: try:
campaign = self.manager.create_email_campaign( campaign = self.manager.create_email_campaign(
tenant_id = self.test_tenant_id, tenant_id=self.test_tenant_id,
name = "新用户欢迎活动", name="新用户欢迎活动",
template_id = template_id, template_id=template_id,
recipient_list = [ recipient_list=[
{"user_id": "user_001", "email": "user1@example.com"}, {"user_id": "user_001", "email": "user1@example.com"},
{"user_id": "user_002", "email": "user2@example.com"}, {"user_id": "user_002", "email": "user2@example.com"},
{"user_id": "user_003", "email": "user3@example.com"}, {"user_id": "user_003", "email": "user3@example.com"},
@@ -456,7 +456,7 @@ class TestGrowthManager:
self.log(f"营销活动创建成功: {campaign.id}, {campaign.recipient_count} 收件人") self.log(f"营销活动创建成功: {campaign.id}, {campaign.recipient_count} 收件人")
return campaign.id return campaign.id
except Exception as e: except Exception as e:
self.log(f"创建营销活动失败: {e}", success = False) self.log(f"创建营销活动失败: {e}", success=False)
return None return None
def test_create_automation_workflow(self) -> None: def test_create_automation_workflow(self) -> None:
@@ -465,12 +465,12 @@ class TestGrowthManager:
try: try:
workflow = self.manager.create_automation_workflow( workflow = self.manager.create_automation_workflow(
tenant_id = self.test_tenant_id, tenant_id=self.test_tenant_id,
name = "新用户欢迎序列", name="新用户欢迎序列",
description = "用户注册后自动发送欢迎邮件序列", description="用户注册后自动发送欢迎邮件序列",
trigger_type = WorkflowTriggerType.USER_SIGNUP, trigger_type=WorkflowTriggerType.USER_SIGNUP,
trigger_conditions = {"event": "user_signup"}, trigger_conditions={"event": "user_signup"},
actions = [ actions=[
{"type": "send_email", "template_type": "welcome", "delay_hours": 0}, {"type": "send_email", "template_type": "welcome", "delay_hours": 0},
{"type": "send_email", "template_type": "onboarding", "delay_hours": 24}, {"type": "send_email", "template_type": "onboarding", "delay_hours": 24},
{"type": "send_email", "template_type": "feature_tips", "delay_hours": 72}, {"type": "send_email", "template_type": "feature_tips", "delay_hours": 72},
@@ -483,7 +483,7 @@ class TestGrowthManager:
self.log(f"自动化工作流创建成功: {workflow.id}") self.log(f"自动化工作流创建成功: {workflow.id}")
return True return True
except Exception as e: except Exception as e:
self.log(f"创建工作流失败: {e}", success = False) self.log(f"创建工作流失败: {e}", success=False)
return False return False
# ==================== 测试推荐系统 ==================== # ==================== 测试推荐系统 ====================
@@ -494,16 +494,16 @@ class TestGrowthManager:
try: try:
program = self.manager.create_referral_program( program = self.manager.create_referral_program(
tenant_id = self.test_tenant_id, tenant_id=self.test_tenant_id,
name = "邀请好友奖励计划", name="邀请好友奖励计划",
description = "邀请好友注册,双方获得积分奖励", description="邀请好友注册,双方获得积分奖励",
referrer_reward_type = "credit", referrer_reward_type="credit",
referrer_reward_value = 100.0, referrer_reward_value=100.0,
referee_reward_type = "credit", referee_reward_type="credit",
referee_reward_value = 50.0, referee_reward_value=50.0,
max_referrals_per_user = 10, max_referrals_per_user=10,
referral_code_length = 8, referral_code_length=8,
expiry_days = 30, expiry_days=30,
) )
assert program.id is not None assert program.id is not None
@@ -512,7 +512,7 @@ class TestGrowthManager:
self.log(f"推荐计划创建成功: {program.id}") self.log(f"推荐计划创建成功: {program.id}")
return program.id return program.id
except Exception as e: except Exception as e:
self.log(f"创建推荐计划失败: {e}", success = False) self.log(f"创建推荐计划失败: {e}", success=False)
return None return None
def test_generate_referral_code(self, program_id: str) -> None: def test_generate_referral_code(self, program_id: str) -> None:
@@ -525,7 +525,7 @@ class TestGrowthManager:
try: try:
referral = self.manager.generate_referral_code( referral = self.manager.generate_referral_code(
program_id = program_id, referrer_id = "referrer_user_001" program_id=program_id, referrer_id="referrer_user_001"
) )
if referral: if referral:
@@ -535,10 +535,10 @@ class TestGrowthManager:
self.log(f"推荐码生成成功: {referral.referral_code}") self.log(f"推荐码生成成功: {referral.referral_code}")
return referral.referral_code return referral.referral_code
else: else:
self.log("生成推荐码返回空结果", success = False) self.log("生成推荐码返回空结果", success=False)
return None return None
except Exception as e: except Exception as e:
self.log(f"生成推荐码失败: {e}", success = False) self.log(f"生成推荐码失败: {e}", success=False)
return None return None
def test_apply_referral_code(self, referral_code: str) -> None: def test_apply_referral_code(self, referral_code: str) -> None:
@@ -551,17 +551,17 @@ class TestGrowthManager:
try: try:
success = self.manager.apply_referral_code( success = self.manager.apply_referral_code(
referral_code = referral_code, referee_id = "new_user_001" referral_code=referral_code, referee_id="new_user_001"
) )
if success: if success:
self.log(f"推荐码应用成功: {referral_code}") self.log(f"推荐码应用成功: {referral_code}")
return True return True
else: else:
self.log("推荐码应用失败", success = False) self.log("推荐码应用失败", success=False)
return False return False
except Exception as e: except Exception as e:
self.log(f"应用推荐码失败: {e}", success = False) self.log(f"应用推荐码失败: {e}", success=False)
return False return False
def test_get_referral_stats(self, program_id: str) -> None: def test_get_referral_stats(self, program_id: str) -> None:
@@ -583,7 +583,7 @@ class TestGrowthManager:
) )
return True return True
except Exception as e: except Exception as e:
self.log(f"获取推荐统计失败: {e}", success = False) self.log(f"获取推荐统计失败: {e}", success=False)
return False return False
def test_create_team_incentive(self) -> None: def test_create_team_incentive(self) -> None:
@@ -592,15 +592,15 @@ class TestGrowthManager:
try: try:
incentive = self.manager.create_team_incentive( incentive = self.manager.create_team_incentive(
tenant_id = self.test_tenant_id, tenant_id=self.test_tenant_id,
name = "团队升级奖励", name="团队升级奖励",
description = "团队规模达到5人升级到 Pro 计划可获得折扣", description="团队规模达到5人升级到 Pro 计划可获得折扣",
target_tier = "pro", target_tier="pro",
min_team_size = 5, min_team_size=5,
incentive_type = "discount", incentive_type="discount",
incentive_value = 20.0, # 20% 折扣 incentive_value=20.0, # 20% 折扣
valid_from = datetime.now(), valid_from=datetime.now(),
valid_until = datetime.now() + timedelta(days = 90), valid_until=datetime.now() + timedelta(days=90),
) )
assert incentive.id is not None assert incentive.id is not None
@@ -609,7 +609,7 @@ class TestGrowthManager:
self.log(f"团队激励创建成功: {incentive.id}") self.log(f"团队激励创建成功: {incentive.id}")
return True return True
except Exception as e: except Exception as e:
self.log(f"创建团队激励失败: {e}", success = False) self.log(f"创建团队激励失败: {e}", success=False)
return False return False
def test_check_team_incentive_eligibility(self) -> None: def test_check_team_incentive_eligibility(self) -> None:
@@ -618,13 +618,13 @@ class TestGrowthManager:
try: try:
incentives = self.manager.check_team_incentive_eligibility( incentives = self.manager.check_team_incentive_eligibility(
tenant_id = self.test_tenant_id, current_tier = "free", team_size = 5 tenant_id=self.test_tenant_id, current_tier="free", team_size=5
) )
self.log(f"找到 {len(incentives)} 个符合条件的激励") self.log(f"找到 {len(incentives)} 个符合条件的激励")
return True return True
except Exception as e: except Exception as e:
self.log(f"检查激励资格失败: {e}", success = False) self.log(f"检查激励资格失败: {e}", success=False)
return False return False
# ==================== 测试实时仪表板 ==================== # ==================== 测试实时仪表板 ====================
@@ -646,7 +646,7 @@ class TestGrowthManager:
) )
return True return True
except Exception as e: except Exception as e:
self.log(f"获取实时仪表板失败: {e}", success = False) self.log(f"获取实时仪表板失败: {e}", success=False)
return False return False
# ==================== 运行所有测试 ==================== # ==================== 运行所有测试 ====================

View File

@@ -123,46 +123,46 @@ class TestDeveloperEcosystem:
"""测试创建 SDK""" """测试创建 SDK"""
try: try:
sdk = self.manager.create_sdk_release( sdk = self.manager.create_sdk_release(
name = "InsightFlow Python SDK", name="InsightFlow Python SDK",
language = SDKLanguage.PYTHON, language=SDKLanguage.PYTHON,
version = "1.0.0", version="1.0.0",
description = "Python SDK for InsightFlow API", description="Python SDK for InsightFlow API",
changelog = "Initial release", changelog="Initial release",
download_url = "https://pypi.org/insightflow/1.0.0", download_url="https://pypi.org/insightflow/1.0.0",
documentation_url = "https://docs.insightflow.io/python", documentation_url="https://docs.insightflow.io/python",
repository_url = "https://github.com/insightflow/python-sdk", repository_url="https://github.com/insightflow/python-sdk",
package_name = "insightflow", package_name="insightflow",
min_platform_version = "1.0.0", min_platform_version="1.0.0",
dependencies = [{"name": "requests", "version": ">= 2.0"}], dependencies=[{"name": "requests", "version": ">= 2.0"}],
file_size = 1024000, file_size=1024000,
checksum = "abc123", checksum="abc123",
created_by = "test_user", created_by="test_user",
) )
self.created_ids["sdk"].append(sdk.id) self.created_ids["sdk"].append(sdk.id)
self.log(f"Created SDK: {sdk.name} ({sdk.id})") self.log(f"Created SDK: {sdk.name} ({sdk.id})")
# Create JavaScript SDK # Create JavaScript SDK
sdk_js = self.manager.create_sdk_release( sdk_js = self.manager.create_sdk_release(
name = "InsightFlow JavaScript SDK", name="InsightFlow JavaScript SDK",
language = SDKLanguage.JAVASCRIPT, language=SDKLanguage.JAVASCRIPT,
version = "1.0.0", version="1.0.0",
description = "JavaScript SDK for InsightFlow API", description="JavaScript SDK for InsightFlow API",
changelog = "Initial release", changelog="Initial release",
download_url = "https://npmjs.com/insightflow/1.0.0", download_url="https://npmjs.com/insightflow/1.0.0",
documentation_url = "https://docs.insightflow.io/js", documentation_url="https://docs.insightflow.io/js",
repository_url = "https://github.com/insightflow/js-sdk", repository_url="https://github.com/insightflow/js-sdk",
package_name = "@insightflow/sdk", package_name="@insightflow/sdk",
min_platform_version = "1.0.0", min_platform_version="1.0.0",
dependencies = [{"name": "axios", "version": ">= 0.21"}], dependencies=[{"name": "axios", "version": ">= 0.21"}],
file_size = 512000, file_size=512000,
checksum = "def456", checksum="def456",
created_by = "test_user", created_by="test_user",
) )
self.created_ids["sdk"].append(sdk_js.id) self.created_ids["sdk"].append(sdk_js.id)
self.log(f"Created SDK: {sdk_js.name} ({sdk_js.id})") self.log(f"Created SDK: {sdk_js.name} ({sdk_js.id})")
except Exception as e: except Exception as e:
self.log(f"Failed to create SDK: {str(e)}", success = False) self.log(f"Failed to create SDK: {str(e)}", success=False)
def test_sdk_list(self) -> None: def test_sdk_list(self) -> None:
"""测试列出 SDK""" """测试列出 SDK"""
@@ -171,15 +171,15 @@ class TestDeveloperEcosystem:
self.log(f"Listed {len(sdks)} SDKs") self.log(f"Listed {len(sdks)} SDKs")
# Test filter by language # Test filter by language
python_sdks = self.manager.list_sdk_releases(language = SDKLanguage.PYTHON) python_sdks = self.manager.list_sdk_releases(language=SDKLanguage.PYTHON)
self.log(f"Found {len(python_sdks)} Python SDKs") self.log(f"Found {len(python_sdks)} Python SDKs")
# Test search # Test search
search_results = self.manager.list_sdk_releases(search = "Python") search_results = self.manager.list_sdk_releases(search="Python")
self.log(f"Search found {len(search_results)} SDKs") self.log(f"Search found {len(search_results)} SDKs")
except Exception as e: except Exception as e:
self.log(f"Failed to list SDKs: {str(e)}", success = False) self.log(f"Failed to list SDKs: {str(e)}", success=False)
def test_sdk_get(self) -> None: def test_sdk_get(self) -> None:
"""测试获取 SDK 详情""" """测试获取 SDK 详情"""
@@ -189,21 +189,21 @@ class TestDeveloperEcosystem:
if sdk: if sdk:
self.log(f"Retrieved SDK: {sdk.name}") self.log(f"Retrieved SDK: {sdk.name}")
else: else:
self.log("SDK not found", success = False) self.log("SDK not found", success=False)
except Exception as e: except Exception as e:
self.log(f"Failed to get SDK: {str(e)}", success = False) self.log(f"Failed to get SDK: {str(e)}", success=False)
def test_sdk_update(self) -> None: def test_sdk_update(self) -> None:
"""测试更新 SDK""" """测试更新 SDK"""
try: try:
if self.created_ids["sdk"]: if self.created_ids["sdk"]:
sdk = self.manager.update_sdk_release( sdk = self.manager.update_sdk_release(
self.created_ids["sdk"][0], description = "Updated description" self.created_ids["sdk"][0], description="Updated description"
) )
if sdk: if sdk:
self.log(f"Updated SDK: {sdk.name}") self.log(f"Updated SDK: {sdk.name}")
except Exception as e: except Exception as e:
self.log(f"Failed to update SDK: {str(e)}", success = False) self.log(f"Failed to update SDK: {str(e)}", success=False)
def test_sdk_publish(self) -> None: def test_sdk_publish(self) -> None:
"""测试发布 SDK""" """测试发布 SDK"""
@@ -213,67 +213,67 @@ class TestDeveloperEcosystem:
if sdk: if sdk:
self.log(f"Published SDK: {sdk.name} (status: {sdk.status.value})") self.log(f"Published SDK: {sdk.name} (status: {sdk.status.value})")
except Exception as e: except Exception as e:
self.log(f"Failed to publish SDK: {str(e)}", success = False) self.log(f"Failed to publish SDK: {str(e)}", success=False)
def test_sdk_version_add(self) -> None: def test_sdk_version_add(self) -> None:
"""测试添加 SDK 版本""" """测试添加 SDK 版本"""
try: try:
if self.created_ids["sdk"]: if self.created_ids["sdk"]:
version = self.manager.add_sdk_version( version = self.manager.add_sdk_version(
sdk_id = self.created_ids["sdk"][0], sdk_id=self.created_ids["sdk"][0],
version = "1.1.0", version="1.1.0",
is_lts = True, is_lts=True,
release_notes = "Bug fixes and improvements", release_notes="Bug fixes and improvements",
download_url = "https://pypi.org/insightflow/1.1.0", download_url="https://pypi.org/insightflow/1.1.0",
checksum = "xyz789", checksum="xyz789",
file_size = 1100000, file_size=1100000,
) )
self.log(f"Added SDK version: {version.version}") self.log(f"Added SDK version: {version.version}")
except Exception as e: except Exception as e:
self.log(f"Failed to add SDK version: {str(e)}", success = False) self.log(f"Failed to add SDK version: {str(e)}", success=False)
def test_template_create(self) -> None: def test_template_create(self) -> None:
"""测试创建模板""" """测试创建模板"""
try: try:
template = self.manager.create_template( template = self.manager.create_template(
name = "医疗行业实体识别模板", name="医疗行业实体识别模板",
description = "专门针对医疗行业的实体识别模板,支持疾病、药物、症状等实体", description="专门针对医疗行业的实体识别模板,支持疾病、药物、症状等实体",
category = TemplateCategory.MEDICAL, category=TemplateCategory.MEDICAL,
subcategory = "entity_recognition", subcategory="entity_recognition",
tags = ["medical", "healthcare", "ner"], tags=["medical", "healthcare", "ner"],
author_id = "dev_001", author_id="dev_001",
author_name = "Medical AI Lab", author_name="Medical AI Lab",
price = 99.0, price=99.0,
currency = "CNY", currency="CNY",
preview_image_url = "https://cdn.insightflow.io/templates/medical.png", preview_image_url="https://cdn.insightflow.io/templates/medical.png",
demo_url = "https://demo.insightflow.io/medical", demo_url="https://demo.insightflow.io/medical",
documentation_url = "https://docs.insightflow.io/templates/medical", documentation_url="https://docs.insightflow.io/templates/medical",
download_url = "https://cdn.insightflow.io/templates/medical.zip", download_url="https://cdn.insightflow.io/templates/medical.zip",
version = "1.0.0", version="1.0.0",
min_platform_version = "2.0.0", min_platform_version="2.0.0",
file_size = 5242880, file_size=5242880,
checksum = "tpl123", checksum="tpl123",
) )
self.created_ids["template"].append(template.id) self.created_ids["template"].append(template.id)
self.log(f"Created template: {template.name} ({template.id})") self.log(f"Created template: {template.name} ({template.id})")
# Create free template # Create free template
template_free = self.manager.create_template( template_free = self.manager.create_template(
name = "通用实体识别模板", name="通用实体识别模板",
description = "适用于一般场景的实体识别模板", description="适用于一般场景的实体识别模板",
category = TemplateCategory.GENERAL, category=TemplateCategory.GENERAL,
subcategory = None, subcategory=None,
tags = ["general", "ner", "basic"], tags=["general", "ner", "basic"],
author_id = "dev_002", author_id="dev_002",
author_name = "InsightFlow Team", author_name="InsightFlow Team",
price = 0.0, price=0.0,
currency = "CNY", currency="CNY",
) )
self.created_ids["template"].append(template_free.id) self.created_ids["template"].append(template_free.id)
self.log(f"Created free template: {template_free.name}") self.log(f"Created free template: {template_free.name}")
except Exception as e: except Exception as e:
self.log(f"Failed to create template: {str(e)}", success = False) self.log(f"Failed to create template: {str(e)}", success=False)
def test_template_list(self) -> None: def test_template_list(self) -> None:
"""测试列出模板""" """测试列出模板"""
@@ -282,15 +282,15 @@ class TestDeveloperEcosystem:
self.log(f"Listed {len(templates)} templates") self.log(f"Listed {len(templates)} templates")
# Filter by category # Filter by category
medical_templates = self.manager.list_templates(category = TemplateCategory.MEDICAL) medical_templates = self.manager.list_templates(category=TemplateCategory.MEDICAL)
self.log(f"Found {len(medical_templates)} medical templates") self.log(f"Found {len(medical_templates)} medical templates")
# Filter by price # Filter by price
free_templates = self.manager.list_templates(max_price = 0) free_templates = self.manager.list_templates(max_price=0)
self.log(f"Found {len(free_templates)} free templates") self.log(f"Found {len(free_templates)} free templates")
except Exception as e: except Exception as e:
self.log(f"Failed to list templates: {str(e)}", success = False) self.log(f"Failed to list templates: {str(e)}", success=False)
def test_template_get(self) -> None: def test_template_get(self) -> None:
"""测试获取模板详情""" """测试获取模板详情"""
@@ -300,19 +300,19 @@ class TestDeveloperEcosystem:
if template: if template:
self.log(f"Retrieved template: {template.name}") self.log(f"Retrieved template: {template.name}")
except Exception as e: except Exception as e:
self.log(f"Failed to get template: {str(e)}", success = False) self.log(f"Failed to get template: {str(e)}", success=False)
def test_template_approve(self) -> None: def test_template_approve(self) -> None:
"""测试审核通过模板""" """测试审核通过模板"""
try: try:
if self.created_ids["template"]: if self.created_ids["template"]:
template = self.manager.approve_template( template = self.manager.approve_template(
self.created_ids["template"][0], reviewed_by = "admin_001" self.created_ids["template"][0], reviewed_by="admin_001"
) )
if template: if template:
self.log(f"Approved template: {template.name}") self.log(f"Approved template: {template.name}")
except Exception as e: except Exception as e:
self.log(f"Failed to approve template: {str(e)}", success = False) self.log(f"Failed to approve template: {str(e)}", success=False)
def test_template_publish(self) -> None: def test_template_publish(self) -> None:
"""测试发布模板""" """测试发布模板"""
@@ -322,69 +322,69 @@ class TestDeveloperEcosystem:
if template: if template:
self.log(f"Published template: {template.name}") self.log(f"Published template: {template.name}")
except Exception as e: except Exception as e:
self.log(f"Failed to publish template: {str(e)}", success = False) self.log(f"Failed to publish template: {str(e)}", success=False)
def test_template_review(self) -> None: def test_template_review(self) -> None:
"""测试添加模板评价""" """测试添加模板评价"""
try: try:
if self.created_ids["template"]: if self.created_ids["template"]:
review = self.manager.add_template_review( review = self.manager.add_template_review(
template_id = self.created_ids["template"][0], template_id=self.created_ids["template"][0],
user_id = "user_001", user_id="user_001",
user_name = "Test User", user_name="Test User",
rating = 5, rating=5,
comment = "Great template! Very accurate for medical entities.", comment="Great template! Very accurate for medical entities.",
is_verified_purchase = True, is_verified_purchase=True,
) )
self.log(f"Added template review: {review.rating} stars") self.log(f"Added template review: {review.rating} stars")
except Exception as e: except Exception as e:
self.log(f"Failed to add template review: {str(e)}", success = False) self.log(f"Failed to add template review: {str(e)}", success=False)
def test_plugin_create(self) -> None: def test_plugin_create(self) -> None:
"""测试创建插件""" """测试创建插件"""
try: try:
plugin = self.manager.create_plugin( plugin = self.manager.create_plugin(
name = "飞书机器人集成插件", name="飞书机器人集成插件",
description = "将 InsightFlow 与飞书机器人集成,实现自动通知", description="将 InsightFlow 与飞书机器人集成,实现自动通知",
category = PluginCategory.INTEGRATION, category=PluginCategory.INTEGRATION,
tags = ["feishu", "bot", "integration", "notification"], tags=["feishu", "bot", "integration", "notification"],
author_id = "dev_003", author_id="dev_003",
author_name = "Integration Team", author_name="Integration Team",
price = 49.0, price=49.0,
currency = "CNY", currency="CNY",
pricing_model = "paid", pricing_model="paid",
preview_image_url = "https://cdn.insightflow.io/plugins/feishu.png", preview_image_url="https://cdn.insightflow.io/plugins/feishu.png",
demo_url = "https://demo.insightflow.io/feishu", demo_url="https://demo.insightflow.io/feishu",
documentation_url = "https://docs.insightflow.io/plugins/feishu", documentation_url="https://docs.insightflow.io/plugins/feishu",
repository_url = "https://github.com/insightflow/feishu-plugin", repository_url="https://github.com/insightflow/feishu-plugin",
download_url = "https://cdn.insightflow.io/plugins/feishu.zip", download_url="https://cdn.insightflow.io/plugins/feishu.zip",
webhook_url = "https://api.insightflow.io/webhooks/feishu", webhook_url="https://api.insightflow.io/webhooks/feishu",
permissions = ["read:projects", "write:notifications"], permissions=["read:projects", "write:notifications"],
version = "1.0.0", version="1.0.0",
min_platform_version = "2.0.0", min_platform_version="2.0.0",
file_size = 1048576, file_size=1048576,
checksum = "plg123", checksum="plg123",
) )
self.created_ids["plugin"].append(plugin.id) self.created_ids["plugin"].append(plugin.id)
self.log(f"Created plugin: {plugin.name} ({plugin.id})") self.log(f"Created plugin: {plugin.name} ({plugin.id})")
# Create free plugin # Create free plugin
plugin_free = self.manager.create_plugin( plugin_free = self.manager.create_plugin(
name = "数据导出插件", name="数据导出插件",
description = "支持多种格式的数据导出", description="支持多种格式的数据导出",
category = PluginCategory.ANALYSIS, category=PluginCategory.ANALYSIS,
tags = ["export", "data", "csv", "json"], tags=["export", "data", "csv", "json"],
author_id = "dev_004", author_id="dev_004",
author_name = "Data Team", author_name="Data Team",
price = 0.0, price=0.0,
currency = "CNY", currency="CNY",
pricing_model = "free", pricing_model="free",
) )
self.created_ids["plugin"].append(plugin_free.id) self.created_ids["plugin"].append(plugin_free.id)
self.log(f"Created free plugin: {plugin_free.name}") self.log(f"Created free plugin: {plugin_free.name}")
except Exception as e: except Exception as e:
self.log(f"Failed to create plugin: {str(e)}", success = False) self.log(f"Failed to create plugin: {str(e)}", success=False)
def test_plugin_list(self) -> None: def test_plugin_list(self) -> None:
"""测试列出插件""" """测试列出插件"""
@@ -393,11 +393,11 @@ class TestDeveloperEcosystem:
self.log(f"Listed {len(plugins)} plugins") self.log(f"Listed {len(plugins)} plugins")
# Filter by category # Filter by category
integration_plugins = self.manager.list_plugins(category = PluginCategory.INTEGRATION) integration_plugins = self.manager.list_plugins(category=PluginCategory.INTEGRATION)
self.log(f"Found {len(integration_plugins)} integration plugins") self.log(f"Found {len(integration_plugins)} integration plugins")
except Exception as e: except Exception as e:
self.log(f"Failed to list plugins: {str(e)}", success = False) self.log(f"Failed to list plugins: {str(e)}", success=False)
def test_plugin_get(self) -> None: def test_plugin_get(self) -> None:
"""测试获取插件详情""" """测试获取插件详情"""
@@ -407,7 +407,7 @@ class TestDeveloperEcosystem:
if plugin: if plugin:
self.log(f"Retrieved plugin: {plugin.name}") self.log(f"Retrieved plugin: {plugin.name}")
except Exception as e: except Exception as e:
self.log(f"Failed to get plugin: {str(e)}", success = False) self.log(f"Failed to get plugin: {str(e)}", success=False)
def test_plugin_review(self) -> None: def test_plugin_review(self) -> None:
"""测试审核插件""" """测试审核插件"""
@@ -415,14 +415,14 @@ class TestDeveloperEcosystem:
if self.created_ids["plugin"]: if self.created_ids["plugin"]:
plugin = self.manager.review_plugin( plugin = self.manager.review_plugin(
self.created_ids["plugin"][0], self.created_ids["plugin"][0],
reviewed_by = "admin_001", reviewed_by="admin_001",
status = PluginStatus.APPROVED, status=PluginStatus.APPROVED,
notes = "Code review passed", notes="Code review passed",
) )
if plugin: if plugin:
self.log(f"Reviewed plugin: {plugin.name} ({plugin.status.value})") self.log(f"Reviewed plugin: {plugin.name} ({plugin.status.value})")
except Exception as e: except Exception as e:
self.log(f"Failed to review plugin: {str(e)}", success = False) self.log(f"Failed to review plugin: {str(e)}", success=False)
def test_plugin_publish(self) -> None: def test_plugin_publish(self) -> None:
"""测试发布插件""" """测试发布插件"""
@@ -432,23 +432,23 @@ class TestDeveloperEcosystem:
if plugin: if plugin:
self.log(f"Published plugin: {plugin.name}") self.log(f"Published plugin: {plugin.name}")
except Exception as e: except Exception as e:
self.log(f"Failed to publish plugin: {str(e)}", success = False) self.log(f"Failed to publish plugin: {str(e)}", success=False)
def test_plugin_review_add(self) -> None: def test_plugin_review_add(self) -> None:
"""测试添加插件评价""" """测试添加插件评价"""
try: try:
if self.created_ids["plugin"]: if self.created_ids["plugin"]:
review = self.manager.add_plugin_review( review = self.manager.add_plugin_review(
plugin_id = self.created_ids["plugin"][0], plugin_id=self.created_ids["plugin"][0],
user_id = "user_002", user_id="user_002",
user_name = "Plugin User", user_name="Plugin User",
rating = 4, rating=4,
comment = "Works great with Feishu!", comment="Works great with Feishu!",
is_verified_purchase = True, is_verified_purchase=True,
) )
self.log(f"Added plugin review: {review.rating} stars") self.log(f"Added plugin review: {review.rating} stars")
except Exception as e: except Exception as e:
self.log(f"Failed to add plugin review: {str(e)}", success = False) self.log(f"Failed to add plugin review: {str(e)}", success=False)
def test_developer_profile_create(self) -> None: def test_developer_profile_create(self) -> None:
"""测试创建开发者档案""" """测试创建开发者档案"""
@@ -457,29 +457,29 @@ class TestDeveloperEcosystem:
unique_id = uuid.uuid4().hex[:8] unique_id = uuid.uuid4().hex[:8]
profile = self.manager.create_developer_profile( profile = self.manager.create_developer_profile(
user_id = f"user_dev_{unique_id}_001", user_id=f"user_dev_{unique_id}_001",
display_name = "张三", display_name="张三",
email = f"zhangsan_{unique_id}@example.com", email=f"zhangsan_{unique_id}@example.com",
bio = "专注于医疗AI和自然语言处理", bio="专注于医疗AI和自然语言处理",
website = "https://zhangsan.dev", website="https://zhangsan.dev",
github_url = "https://github.com/zhangsan", github_url="https://github.com/zhangsan",
avatar_url = "https://cdn.example.com/avatars/zhangsan.png", avatar_url="https://cdn.example.com/avatars/zhangsan.png",
) )
self.created_ids["developer"].append(profile.id) self.created_ids["developer"].append(profile.id)
self.log(f"Created developer profile: {profile.display_name} ({profile.id})") self.log(f"Created developer profile: {profile.display_name} ({profile.id})")
# Create another developer # Create another developer
profile2 = self.manager.create_developer_profile( profile2 = self.manager.create_developer_profile(
user_id = f"user_dev_{unique_id}_002", user_id=f"user_dev_{unique_id}_002",
display_name = "李四", display_name="李四",
email = f"lisi_{unique_id}@example.com", email=f"lisi_{unique_id}@example.com",
bio = "全栈开发者,热爱开源", bio="全栈开发者,热爱开源",
) )
self.created_ids["developer"].append(profile2.id) self.created_ids["developer"].append(profile2.id)
self.log(f"Created developer profile: {profile2.display_name}") self.log(f"Created developer profile: {profile2.display_name}")
except Exception as e: except Exception as e:
self.log(f"Failed to create developer profile: {str(e)}", success = False) self.log(f"Failed to create developer profile: {str(e)}", success=False)
def test_developer_profile_get(self) -> None: def test_developer_profile_get(self) -> None:
"""测试获取开发者档案""" """测试获取开发者档案"""
@@ -489,7 +489,7 @@ class TestDeveloperEcosystem:
if profile: if profile:
self.log(f"Retrieved developer profile: {profile.display_name}") self.log(f"Retrieved developer profile: {profile.display_name}")
except Exception as e: except Exception as e:
self.log(f"Failed to get developer profile: {str(e)}", success = False) self.log(f"Failed to get developer profile: {str(e)}", success=False)
def test_developer_verify(self) -> None: def test_developer_verify(self) -> None:
"""测试验证开发者""" """测试验证开发者"""
@@ -501,7 +501,7 @@ class TestDeveloperEcosystem:
if profile: if profile:
self.log(f"Verified developer: {profile.display_name} ({profile.status.value})") self.log(f"Verified developer: {profile.display_name} ({profile.status.value})")
except Exception as e: except Exception as e:
self.log(f"Failed to verify developer: {str(e)}", success = False) self.log(f"Failed to verify developer: {str(e)}", success=False)
def test_developer_stats_update(self) -> None: def test_developer_stats_update(self) -> None:
"""测试更新开发者统计""" """测试更新开发者统计"""
@@ -513,38 +513,38 @@ class TestDeveloperEcosystem:
f"Updated developer stats: {profile.plugin_count} plugins, {profile.template_count} templates" f"Updated developer stats: {profile.plugin_count} plugins, {profile.template_count} templates"
) )
except Exception as e: except Exception as e:
self.log(f"Failed to update developer stats: {str(e)}", success = False) self.log(f"Failed to update developer stats: {str(e)}", success=False)
def test_code_example_create(self) -> None: def test_code_example_create(self) -> None:
"""测试创建代码示例""" """测试创建代码示例"""
try: try:
example = self.manager.create_code_example( example = self.manager.create_code_example(
title = "使用 Python SDK 创建项目", title="使用 Python SDK 创建项目",
description = "演示如何使用 Python SDK 创建新项目", description="演示如何使用 Python SDK 创建新项目",
language = "python", language="python",
category = "quickstart", category="quickstart",
code = """from insightflow import Client code="""from insightflow import Client
client = Client(api_key = "your_api_key") client = Client(api_key = "your_api_key")
project = client.projects.create(name = "My Project") project = client.projects.create(name = "My Project")
print(f"Created project: {project.id}") print(f"Created project: {project.id}")
""", """,
explanation = "首先导入 Client 类,然后使用 API Key 初始化客户端,最后调用 create 方法创建项目。", explanation="首先导入 Client 类,然后使用 API Key 初始化客户端,最后调用 create 方法创建项目。",
tags = ["python", "quickstart", "projects"], tags=["python", "quickstart", "projects"],
author_id = "dev_001", author_id="dev_001",
author_name = "InsightFlow Team", author_name="InsightFlow Team",
api_endpoints = ["/api/v1/projects"], api_endpoints=["/api/v1/projects"],
) )
self.created_ids["code_example"].append(example.id) self.created_ids["code_example"].append(example.id)
self.log(f"Created code example: {example.title}") self.log(f"Created code example: {example.title}")
# Create JavaScript example # Create JavaScript example
example_js = self.manager.create_code_example( example_js = self.manager.create_code_example(
title = "使用 JavaScript SDK 上传文件", title="使用 JavaScript SDK 上传文件",
description = "演示如何使用 JavaScript SDK 上传音频文件", description="演示如何使用 JavaScript SDK 上传音频文件",
language = "javascript", language="javascript",
category = "upload", category="upload",
code = """const { Client } = require('insightflow'); code="""const { Client } = require('insightflow');
const client = new Client({ apiKey: 'your_api_key' }); const client = new Client({ apiKey: 'your_api_key' });
const result = await client.uploads.create({ const result = await client.uploads.create({
@@ -553,16 +553,16 @@ const result = await client.uploads.create({
}); });
console.log('Upload complete:', result.id); console.log('Upload complete:', result.id);
""", """,
explanation = "使用 JavaScript SDK 上传文件到 InsightFlow", explanation="使用 JavaScript SDK 上传文件到 InsightFlow",
tags = ["javascript", "upload", "audio"], tags=["javascript", "upload", "audio"],
author_id = "dev_002", author_id="dev_002",
author_name = "JS Team", author_name="JS Team",
) )
self.created_ids["code_example"].append(example_js.id) self.created_ids["code_example"].append(example_js.id)
self.log(f"Created code example: {example_js.title}") self.log(f"Created code example: {example_js.title}")
except Exception as e: except Exception as e:
self.log(f"Failed to create code example: {str(e)}", success = False) self.log(f"Failed to create code example: {str(e)}", success=False)
def test_code_example_list(self) -> None: def test_code_example_list(self) -> None:
"""测试列出代码示例""" """测试列出代码示例"""
@@ -571,11 +571,11 @@ console.log('Upload complete:', result.id);
self.log(f"Listed {len(examples)} code examples") self.log(f"Listed {len(examples)} code examples")
# Filter by language # Filter by language
python_examples = self.manager.list_code_examples(language = "python") python_examples = self.manager.list_code_examples(language="python")
self.log(f"Found {len(python_examples)} Python examples") self.log(f"Found {len(python_examples)} Python examples")
except Exception as e: except Exception as e:
self.log(f"Failed to list code examples: {str(e)}", success = False) self.log(f"Failed to list code examples: {str(e)}", success=False)
def test_code_example_get(self) -> None: def test_code_example_get(self) -> None:
"""测试获取代码示例详情""" """测试获取代码示例详情"""
@@ -587,28 +587,28 @@ console.log('Upload complete:', result.id);
f"Retrieved code example: {example.title} (views: {example.view_count})" f"Retrieved code example: {example.title} (views: {example.view_count})"
) )
except Exception as e: except Exception as e:
self.log(f"Failed to get code example: {str(e)}", success = False) self.log(f"Failed to get code example: {str(e)}", success=False)
def test_portal_config_create(self) -> None: def test_portal_config_create(self) -> None:
"""测试创建开发者门户配置""" """测试创建开发者门户配置"""
try: try:
config = self.manager.create_portal_config( config = self.manager.create_portal_config(
name = "InsightFlow Developer Portal", name="InsightFlow Developer Portal",
description = "开发者门户 - SDK、API 文档和示例代码", description="开发者门户 - SDK、API 文档和示例代码",
theme = "default", theme="default",
primary_color = "#1890ff", primary_color="#1890ff",
secondary_color = "#52c41a", secondary_color="#52c41a",
support_email = "developers@insightflow.io", support_email="developers@insightflow.io",
support_url = "https://support.insightflow.io", support_url="https://support.insightflow.io",
github_url = "https://github.com/insightflow", github_url="https://github.com/insightflow",
discord_url = "https://discord.gg/insightflow", discord_url="https://discord.gg/insightflow",
api_base_url = "https://api.insightflow.io/v1", api_base_url="https://api.insightflow.io/v1",
) )
self.created_ids["portal_config"].append(config.id) self.created_ids["portal_config"].append(config.id)
self.log(f"Created portal config: {config.name}") self.log(f"Created portal config: {config.name}")
except Exception as e: except Exception as e:
self.log(f"Failed to create portal config: {str(e)}", success = False) self.log(f"Failed to create portal config: {str(e)}", success=False)
def test_portal_config_get(self) -> None: def test_portal_config_get(self) -> None:
"""测试获取开发者门户配置""" """测试获取开发者门户配置"""
@@ -624,27 +624,27 @@ console.log('Upload complete:', result.id);
self.log(f"Active portal config: {active_config.name}") self.log(f"Active portal config: {active_config.name}")
except Exception as e: except Exception as e:
self.log(f"Failed to get portal config: {str(e)}", success = False) self.log(f"Failed to get portal config: {str(e)}", success=False)
def test_revenue_record(self) -> None: def test_revenue_record(self) -> None:
"""测试记录开发者收益""" """测试记录开发者收益"""
try: try:
if self.created_ids["developer"] and self.created_ids["plugin"]: if self.created_ids["developer"] and self.created_ids["plugin"]:
revenue = self.manager.record_revenue( revenue = self.manager.record_revenue(
developer_id = self.created_ids["developer"][0], developer_id=self.created_ids["developer"][0],
item_type = "plugin", item_type="plugin",
item_id = self.created_ids["plugin"][0], item_id=self.created_ids["plugin"][0],
item_name = "飞书机器人集成插件", item_name="飞书机器人集成插件",
sale_amount = 49.0, sale_amount=49.0,
currency = "CNY", currency="CNY",
buyer_id = "user_buyer_001", buyer_id="user_buyer_001",
transaction_id = "txn_123456", transaction_id="txn_123456",
) )
self.log(f"Recorded revenue: {revenue.sale_amount} {revenue.currency}") self.log(f"Recorded revenue: {revenue.sale_amount} {revenue.currency}")
self.log(f" - Platform fee: {revenue.platform_fee}") self.log(f" - Platform fee: {revenue.platform_fee}")
self.log(f" - Developer earnings: {revenue.developer_earnings}") self.log(f" - Developer earnings: {revenue.developer_earnings}")
except Exception as e: except Exception as e:
self.log(f"Failed to record revenue: {str(e)}", success = False) self.log(f"Failed to record revenue: {str(e)}", success=False)
def test_revenue_summary(self) -> None: def test_revenue_summary(self) -> None:
"""测试获取开发者收益汇总""" """测试获取开发者收益汇总"""
@@ -659,7 +659,7 @@ console.log('Upload complete:', result.id);
self.log(f" - Total earnings: {summary['total_earnings']}") self.log(f" - Total earnings: {summary['total_earnings']}")
self.log(f" - Transaction count: {summary['transaction_count']}") self.log(f" - Transaction count: {summary['transaction_count']}")
except Exception as e: except Exception as e:
self.log(f"Failed to get revenue summary: {str(e)}", success = False) self.log(f"Failed to get revenue summary: {str(e)}", success=False)
def print_summary(self) -> None: def print_summary(self) -> None:
"""打印测试摘要""" """打印测试摘要"""

View File

@@ -80,39 +80,39 @@ class TestOpsManager:
try: try:
# 创建阈值告警规则 # 创建阈值告警规则
rule1 = self.manager.create_alert_rule( rule1 = self.manager.create_alert_rule(
tenant_id = self.tenant_id, tenant_id=self.tenant_id,
name = "CPU 使用率告警", name="CPU 使用率告警",
description = "当 CPU 使用率超过 80% 时触发告警", description="当 CPU 使用率超过 80% 时触发告警",
rule_type = AlertRuleType.THRESHOLD, rule_type=AlertRuleType.THRESHOLD,
severity = AlertSeverity.P1, severity=AlertSeverity.P1,
metric = "cpu_usage_percent", metric="cpu_usage_percent",
condition = ">", condition=">",
threshold = 80.0, threshold=80.0,
duration = 300, duration=300,
evaluation_interval = 60, evaluation_interval=60,
channels = [], channels=[],
labels = {"service": "api", "team": "platform"}, labels={"service": "api", "team": "platform"},
annotations = {"summary": "CPU 使用率过高", "runbook": "https://wiki/runbooks/cpu"}, annotations={"summary": "CPU 使用率过高", "runbook": "https://wiki/runbooks/cpu"},
created_by = "test_user", created_by="test_user",
) )
self.log(f"Created alert rule: {rule1.name} (ID: {rule1.id})") self.log(f"Created alert rule: {rule1.name} (ID: {rule1.id})")
# 创建异常检测告警规则 # 创建异常检测告警规则
rule2 = self.manager.create_alert_rule( rule2 = self.manager.create_alert_rule(
tenant_id = self.tenant_id, tenant_id=self.tenant_id,
name = "内存异常检测", name="内存异常检测",
description = "检测内存使用异常", description="检测内存使用异常",
rule_type = AlertRuleType.ANOMALY, rule_type=AlertRuleType.ANOMALY,
severity = AlertSeverity.P2, severity=AlertSeverity.P2,
metric = "memory_usage_percent", metric="memory_usage_percent",
condition = ">", condition=">",
threshold = 0.0, threshold=0.0,
duration = 600, duration=600,
evaluation_interval = 300, evaluation_interval=300,
channels = [], channels=[],
labels = {"service": "database"}, labels={"service": "database"},
annotations = {}, annotations={},
created_by = "test_user", created_by="test_user",
) )
self.log(f"Created anomaly alert rule: {rule2.name} (ID: {rule2.id})") self.log(f"Created anomaly alert rule: {rule2.name} (ID: {rule2.id})")
@@ -129,7 +129,7 @@ class TestOpsManager:
# 更新告警规则 # 更新告警规则
updated_rule = self.manager.update_alert_rule( updated_rule = self.manager.update_alert_rule(
rule1.id, threshold = 85.0, description = "更新后的描述" rule1.id, threshold=85.0, description="更新后的描述"
) )
assert updated_rule.threshold == 85.0 assert updated_rule.threshold == 85.0
self.log(f"Updated alert rule threshold to {updated_rule.threshold}") self.log(f"Updated alert rule threshold to {updated_rule.threshold}")
@@ -140,7 +140,7 @@ class TestOpsManager:
self.log("Deleted test alert rules") self.log("Deleted test alert rules")
except Exception as e: except Exception as e:
self.log(f"Alert rules test failed: {e}", success = False) self.log(f"Alert rules test failed: {e}", success=False)
def test_alert_channels(self) -> None: def test_alert_channels(self) -> None:
"""测试告警渠道管理""" """测试告警渠道管理"""
@@ -149,37 +149,37 @@ class TestOpsManager:
try: try:
# 创建飞书告警渠道 # 创建飞书告警渠道
channel1 = self.manager.create_alert_channel( channel1 = self.manager.create_alert_channel(
tenant_id = self.tenant_id, tenant_id=self.tenant_id,
name = "飞书告警", name="飞书告警",
channel_type = AlertChannelType.FEISHU, channel_type=AlertChannelType.FEISHU,
config = { config={
"webhook_url": "https://open.feishu.cn/open-apis/bot/v2/hook/test", "webhook_url": "https://open.feishu.cn/open-apis/bot/v2/hook/test",
"secret": "test_secret", "secret": "test_secret",
}, },
severity_filter = ["p0", "p1"], severity_filter=["p0", "p1"],
) )
self.log(f"Created Feishu channel: {channel1.name} (ID: {channel1.id})") self.log(f"Created Feishu channel: {channel1.name} (ID: {channel1.id})")
# 创建钉钉告警渠道 # 创建钉钉告警渠道
channel2 = self.manager.create_alert_channel( channel2 = self.manager.create_alert_channel(
tenant_id = self.tenant_id, tenant_id=self.tenant_id,
name = "钉钉告警", name="钉钉告警",
channel_type = AlertChannelType.DINGTALK, channel_type=AlertChannelType.DINGTALK,
config = { config={
"webhook_url": "https://oapi.dingtalk.com/robot/send?access_token = test", "webhook_url": "https://oapi.dingtalk.com/robot/send?access_token = test",
"secret": "test_secret", "secret": "test_secret",
}, },
severity_filter = ["p0", "p1", "p2"], severity_filter=["p0", "p1", "p2"],
) )
self.log(f"Created DingTalk channel: {channel2.name} (ID: {channel2.id})") self.log(f"Created DingTalk channel: {channel2.name} (ID: {channel2.id})")
# 创建 Slack 告警渠道 # 创建 Slack 告警渠道
channel3 = self.manager.create_alert_channel( channel3 = self.manager.create_alert_channel(
tenant_id = self.tenant_id, tenant_id=self.tenant_id,
name = "Slack 告警", name="Slack 告警",
channel_type = AlertChannelType.SLACK, channel_type=AlertChannelType.SLACK,
config = {"webhook_url": "https://hooks.slack.com/services/test"}, config={"webhook_url": "https://hooks.slack.com/services/test"},
severity_filter = ["p0", "p1", "p2", "p3"], severity_filter=["p0", "p1", "p2", "p3"],
) )
self.log(f"Created Slack channel: {channel3.name} (ID: {channel3.id})") self.log(f"Created Slack channel: {channel3.name} (ID: {channel3.id})")
@@ -203,7 +203,7 @@ class TestOpsManager:
self.log("Deleted test alert channels") self.log("Deleted test alert channels")
except Exception as e: except Exception as e:
self.log(f"Alert channels test failed: {e}", success = False) self.log(f"Alert channels test failed: {e}", success=False)
def test_alerts(self) -> None: def test_alerts(self) -> None:
"""测试告警管理""" """测试告警管理"""
@@ -212,32 +212,32 @@ class TestOpsManager:
try: try:
# 创建告警规则 # 创建告警规则
rule = self.manager.create_alert_rule( rule = self.manager.create_alert_rule(
tenant_id = self.tenant_id, tenant_id=self.tenant_id,
name = "测试告警规则", name="测试告警规则",
description = "用于测试的告警规则", description="用于测试的告警规则",
rule_type = AlertRuleType.THRESHOLD, rule_type=AlertRuleType.THRESHOLD,
severity = AlertSeverity.P1, severity=AlertSeverity.P1,
metric = "test_metric", metric="test_metric",
condition = ">", condition=">",
threshold = 100.0, threshold=100.0,
duration = 60, duration=60,
evaluation_interval = 60, evaluation_interval=60,
channels = [], channels=[],
labels = {}, labels={},
annotations = {}, annotations={},
created_by = "test_user", created_by="test_user",
) )
# 记录资源指标 # 记录资源指标
for i in range(10): for i in range(10):
self.manager.record_resource_metric( self.manager.record_resource_metric(
tenant_id = self.tenant_id, tenant_id=self.tenant_id,
resource_type = ResourceType.CPU, resource_type=ResourceType.CPU,
resource_id = "server-001", resource_id="server-001",
metric_name = "test_metric", metric_name="test_metric",
metric_value = 110.0 + i, metric_value=110.0 + i,
unit = "percent", unit="percent",
metadata = {"region": "cn-north-1"}, metadata={"region": "cn-north-1"},
) )
self.log("Recorded 10 resource metrics") self.log("Recorded 10 resource metrics")
@@ -248,24 +248,24 @@ class TestOpsManager:
now = datetime.now().isoformat() now = datetime.now().isoformat()
alert = Alert( alert = Alert(
id = alert_id, id=alert_id,
rule_id = rule.id, rule_id=rule.id,
tenant_id = self.tenant_id, tenant_id=self.tenant_id,
severity = AlertSeverity.P1, severity=AlertSeverity.P1,
status = AlertStatus.FIRING, status=AlertStatus.FIRING,
title = "测试告警", title="测试告警",
description = "这是一条测试告警", description="这是一条测试告警",
metric = "test_metric", metric="test_metric",
value = 120.0, value=120.0,
threshold = 100.0, threshold=100.0,
labels = {"test": "true"}, labels={"test": "true"},
annotations = {}, annotations={},
started_at = now, started_at=now,
resolved_at = None, resolved_at=None,
acknowledged_by = None, acknowledged_by=None,
acknowledged_at = None, acknowledged_at=None,
notification_sent = {}, notification_sent={},
suppression_count = 0, suppression_count=0,
) )
with self.manager._get_db() as conn: with self.manager._get_db() as conn:
@@ -326,7 +326,7 @@ class TestOpsManager:
self.log("Cleaned up test data") self.log("Cleaned up test data")
except Exception as e: except Exception as e:
self.log(f"Alerts test failed: {e}", success = False) self.log(f"Alerts test failed: {e}", success=False)
def test_capacity_planning(self) -> None: def test_capacity_planning(self) -> None:
"""测试容量规划""" """测试容量规划"""
@@ -334,9 +334,9 @@ class TestOpsManager:
try: try:
# 记录历史指标数据 # 记录历史指标数据
base_time = datetime.now() - timedelta(days = 30) base_time = datetime.now() - timedelta(days=30)
for i in range(30): for i in range(30):
timestamp = (base_time + timedelta(days = i)).isoformat() timestamp = (base_time + timedelta(days=i)).isoformat()
with self.manager._get_db() as conn: with self.manager._get_db() as conn:
conn.execute( conn.execute(
""" """
@@ -360,13 +360,13 @@ class TestOpsManager:
self.log("Recorded 30 days of historical metrics") self.log("Recorded 30 days of historical metrics")
# 创建容量规划 # 创建容量规划
prediction_date = (datetime.now() + timedelta(days = 30)).strftime("%Y-%m-%d") prediction_date = (datetime.now() + timedelta(days=30)).strftime("%Y-%m-%d")
plan = self.manager.create_capacity_plan( plan = self.manager.create_capacity_plan(
tenant_id = self.tenant_id, tenant_id=self.tenant_id,
resource_type = ResourceType.CPU, resource_type=ResourceType.CPU,
current_capacity = 100.0, current_capacity=100.0,
prediction_date = prediction_date, prediction_date=prediction_date,
confidence = 0.85, confidence=0.85,
) )
self.log(f"Created capacity plan: {plan.id}") self.log(f"Created capacity plan: {plan.id}")
@@ -387,7 +387,7 @@ class TestOpsManager:
self.log("Cleaned up capacity planning test data") self.log("Cleaned up capacity planning test data")
except Exception as e: except Exception as e:
self.log(f"Capacity planning test failed: {e}", success = False) self.log(f"Capacity planning test failed: {e}", success=False)
def test_auto_scaling(self) -> None: def test_auto_scaling(self) -> None:
"""测试自动扩缩容""" """测试自动扩缩容"""
@@ -396,17 +396,17 @@ class TestOpsManager:
try: try:
# 创建自动扩缩容策略 # 创建自动扩缩容策略
policy = self.manager.create_auto_scaling_policy( policy = self.manager.create_auto_scaling_policy(
tenant_id = self.tenant_id, tenant_id=self.tenant_id,
name = "API 服务自动扩缩容", name="API 服务自动扩缩容",
resource_type = ResourceType.CPU, resource_type=ResourceType.CPU,
min_instances = 2, min_instances=2,
max_instances = 10, max_instances=10,
target_utilization = 0.7, target_utilization=0.7,
scale_up_threshold = 0.8, scale_up_threshold=0.8,
scale_down_threshold = 0.3, scale_down_threshold=0.3,
scale_up_step = 2, scale_up_step=2,
scale_down_step = 1, scale_down_step=1,
cooldown_period = 300, cooldown_period=300,
) )
self.log(f"Created auto scaling policy: {policy.name} (ID: {policy.id})") self.log(f"Created auto scaling policy: {policy.name} (ID: {policy.id})")
@@ -421,7 +421,7 @@ class TestOpsManager:
# 模拟扩缩容评估 # 模拟扩缩容评估
event = self.manager.evaluate_scaling_policy( event = self.manager.evaluate_scaling_policy(
policy_id = policy.id, current_instances = 3, current_utilization = 0.85 policy_id=policy.id, current_instances=3, current_utilization=0.85
) )
if event: if event:
@@ -445,7 +445,7 @@ class TestOpsManager:
self.log("Cleaned up auto scaling test data") self.log("Cleaned up auto scaling test data")
except Exception as e: except Exception as e:
self.log(f"Auto scaling test failed: {e}", success = False) self.log(f"Auto scaling test failed: {e}", success=False)
def test_health_checks(self) -> None: def test_health_checks(self) -> None:
"""测试健康检查""" """测试健康检查"""
@@ -454,29 +454,29 @@ class TestOpsManager:
try: try:
# 创建 HTTP 健康检查 # 创建 HTTP 健康检查
check1 = self.manager.create_health_check( check1 = self.manager.create_health_check(
tenant_id = self.tenant_id, tenant_id=self.tenant_id,
name = "API 服务健康检查", name="API 服务健康检查",
target_type = "service", target_type="service",
target_id = "api-service", target_id="api-service",
check_type = "http", check_type="http",
check_config = {"url": "https://api.insightflow.io/health", "expected_status": 200}, check_config={"url": "https://api.insightflow.io/health", "expected_status": 200},
interval = 60, interval=60,
timeout = 10, timeout=10,
retry_count = 3, retry_count=3,
) )
self.log(f"Created HTTP health check: {check1.name} (ID: {check1.id})") self.log(f"Created HTTP health check: {check1.name} (ID: {check1.id})")
# 创建 TCP 健康检查 # 创建 TCP 健康检查
check2 = self.manager.create_health_check( check2 = self.manager.create_health_check(
tenant_id = self.tenant_id, tenant_id=self.tenant_id,
name = "数据库健康检查", name="数据库健康检查",
target_type = "database", target_type="database",
target_id = "postgres-001", target_id="postgres-001",
check_type = "tcp", check_type="tcp",
check_config = {"host": "db.insightflow.io", "port": 5432}, check_config={"host": "db.insightflow.io", "port": 5432},
interval = 30, interval=30,
timeout = 5, timeout=5,
retry_count = 2, retry_count=2,
) )
self.log(f"Created TCP health check: {check2.name} (ID: {check2.id})") self.log(f"Created TCP health check: {check2.name} (ID: {check2.id})")
@@ -500,7 +500,7 @@ class TestOpsManager:
self.log("Cleaned up health check test data") self.log("Cleaned up health check test data")
except Exception as e: except Exception as e:
self.log(f"Health checks test failed: {e}", success = False) self.log(f"Health checks test failed: {e}", success=False)
def test_failover(self) -> None: def test_failover(self) -> None:
"""测试故障转移""" """测试故障转移"""
@@ -509,14 +509,14 @@ class TestOpsManager:
try: try:
# 创建故障转移配置 # 创建故障转移配置
config = self.manager.create_failover_config( config = self.manager.create_failover_config(
tenant_id = self.tenant_id, tenant_id=self.tenant_id,
name = "主备数据中心故障转移", name="主备数据中心故障转移",
primary_region = "cn-north-1", primary_region="cn-north-1",
secondary_regions = ["cn-south-1", "cn-east-1"], secondary_regions=["cn-south-1", "cn-east-1"],
failover_trigger = "health_check_failed", failover_trigger="health_check_failed",
auto_failover = False, auto_failover=False,
failover_timeout = 300, failover_timeout=300,
health_check_id = None, health_check_id=None,
) )
self.log(f"Created failover config: {config.name} (ID: {config.id})") self.log(f"Created failover config: {config.name} (ID: {config.id})")
@@ -530,7 +530,7 @@ class TestOpsManager:
# 发起故障转移 # 发起故障转移
event = self.manager.initiate_failover( event = self.manager.initiate_failover(
config_id = config.id, reason = "Primary region health check failed" config_id=config.id, reason="Primary region health check failed"
) )
if event: if event:
@@ -556,7 +556,7 @@ class TestOpsManager:
self.log("Cleaned up failover test data") self.log("Cleaned up failover test data")
except Exception as e: except Exception as e:
self.log(f"Failover test failed: {e}", success = False) self.log(f"Failover test failed: {e}", success=False)
def test_backup(self) -> None: def test_backup(self) -> None:
"""测试备份与恢复""" """测试备份与恢复"""
@@ -565,16 +565,16 @@ class TestOpsManager:
try: try:
# 创建备份任务 # 创建备份任务
job = self.manager.create_backup_job( job = self.manager.create_backup_job(
tenant_id = self.tenant_id, tenant_id=self.tenant_id,
name = "每日数据库备份", name="每日数据库备份",
backup_type = "full", backup_type="full",
target_type = "database", target_type="database",
target_id = "postgres-main", target_id="postgres-main",
schedule = "0 2 * * *", # 每天凌晨2点 schedule="0 2 * * *", # 每天凌晨2点
retention_days = 30, retention_days=30,
encryption_enabled = True, encryption_enabled=True,
compression_enabled = True, compression_enabled=True,
storage_location = "s3://insightflow-backups/", storage_location="s3://insightflow-backups/",
) )
self.log(f"Created backup job: {job.name} (ID: {job.id})") self.log(f"Created backup job: {job.name} (ID: {job.id})")
@@ -610,7 +610,7 @@ class TestOpsManager:
self.log("Cleaned up backup test data") self.log("Cleaned up backup test data")
except Exception as e: except Exception as e:
self.log(f"Backup test failed: {e}", success = False) self.log(f"Backup test failed: {e}", success=False)
def test_cost_optimization(self) -> None: def test_cost_optimization(self) -> None:
"""测试成本优化""" """测试成本优化"""
@@ -622,15 +622,15 @@ class TestOpsManager:
for i in range(5): for i in range(5):
self.manager.record_resource_utilization( self.manager.record_resource_utilization(
tenant_id = self.tenant_id, tenant_id=self.tenant_id,
resource_type = ResourceType.CPU, resource_type=ResourceType.CPU,
resource_id = f"server-{i:03d}", resource_id=f"server-{i:03d}",
utilization_rate = 0.05 + random.random() * 0.1, # 低利用率 utilization_rate=0.05 + random.random() * 0.1, # 低利用率
peak_utilization = 0.15, peak_utilization=0.15,
avg_utilization = 0.08, avg_utilization=0.08,
idle_time_percent = 0.85, idle_time_percent=0.85,
report_date = report_date, report_date=report_date,
recommendations = ["Consider downsizing this resource"], recommendations=["Consider downsizing this resource"],
) )
self.log("Recorded 5 resource utilization records") self.log("Recorded 5 resource utilization records")
@@ -638,7 +638,7 @@ class TestOpsManager:
# 生成成本报告 # 生成成本报告
now = datetime.now() now = datetime.now()
report = self.manager.generate_cost_report( report = self.manager.generate_cost_report(
tenant_id = self.tenant_id, year = now.year, month = now.month tenant_id=self.tenant_id, year=now.year, month=now.month
) )
self.log(f"Generated cost report: {report.id}") self.log(f"Generated cost report: {report.id}")
@@ -698,7 +698,7 @@ class TestOpsManager:
self.log("Cleaned up cost optimization test data") self.log("Cleaned up cost optimization test data")
except Exception as e: except Exception as e:
self.log(f"Cost optimization test failed: {e}", success = False) self.log(f"Cost optimization test failed: {e}", success=False)
def print_summary(self) -> None: def print_summary(self) -> None:
"""打印测试总结""" """打印测试总结"""

View File

@@ -43,17 +43,17 @@ class TingwuClient:
from alibabacloud_tingwu20230930.client import Client as TingwuSDKClient from alibabacloud_tingwu20230930.client import Client as TingwuSDKClient
config = open_api_models.Config( config = open_api_models.Config(
access_key_id = self.access_key, access_key_secret = self.secret_key access_key_id=self.access_key, access_key_secret=self.secret_key
) )
config.endpoint = "tingwu.cn-beijing.aliyuncs.com" config.endpoint = "tingwu.cn-beijing.aliyuncs.com"
client = TingwuSDKClient(config) client = TingwuSDKClient(config)
request = tingwu_models.CreateTaskRequest( request = tingwu_models.CreateTaskRequest(
type = "offline", type="offline",
input = tingwu_models.Input(source = "OSS", file_url = audio_url), input=tingwu_models.Input(source="OSS", file_url=audio_url),
parameters = tingwu_models.Parameters( parameters=tingwu_models.Parameters(
transcription = tingwu_models.Transcription( transcription=tingwu_models.Transcription(
diarization_enabled = True, sentence_max_length = 20 diarization_enabled=True, sentence_max_length=20
) )
), ),
) )
@@ -78,9 +78,12 @@ class TingwuClient:
"""获取任务结果""" """获取任务结果"""
try: try:
# 导入移到文件顶部会导致循环导入,保持在这里 # 导入移到文件顶部会导致循环导入,保持在这里
from alibabacloud_tingwu20230930 import models as tingwu_models
from alibabacloud_tingwu20230930.client import Client as TingwuSDKClient
from alibabacloud_openapi_util import models as open_api_models
config = open_api_models.Config( config = open_api_models.Config(
access_key_id = self.access_key, access_key_secret = self.secret_key access_key_id=self.access_key, access_key_secret=self.secret_key
) )
config.endpoint = "tingwu.cn-beijing.aliyuncs.com" config.endpoint = "tingwu.cn-beijing.aliyuncs.com"
client = TingwuSDKClient(config) client = TingwuSDKClient(config)

View File

@@ -27,7 +27,6 @@ from apscheduler.events import EVENT_JOB_ERROR, EVENT_JOB_EXECUTED
from apscheduler.schedulers.asyncio import AsyncIOScheduler from apscheduler.schedulers.asyncio import AsyncIOScheduler
from apscheduler.triggers.cron import CronTrigger from apscheduler.triggers.cron import CronTrigger
from apscheduler.triggers.interval import IntervalTrigger from apscheduler.triggers.interval import IntervalTrigger
from workflow_manager import WorkflowManager
import urllib.parse import urllib.parse
# Constants # Constants
@@ -37,7 +36,7 @@ DEFAULT_RETRY_COUNT = 3 # 默认重试次数
DEFAULT_RETRY_DELAY = 5 # 默认重试延迟(秒) DEFAULT_RETRY_DELAY = 5 # 默认重试延迟(秒)
# Configure logging # Configure logging
logging.basicConfig(level = logging.INFO) logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -87,9 +86,9 @@ class WorkflowTask:
workflow_id: str workflow_id: str
name: str name: str
task_type: str # analyze, align, discover_relations, notify, custom task_type: str # analyze, align, discover_relations, notify, custom
config: dict = field(default_factory = dict) config: dict = field(default_factory=dict)
order: int = 0 order: int = 0
depends_on: list[str] = field(default_factory = list) depends_on: list[str] = field(default_factory=list)
timeout_seconds: int = 300 timeout_seconds: int = 300
retry_count: int = 3 retry_count: int = 3
retry_delay: int = 5 retry_delay: int = 5
@@ -112,7 +111,7 @@ class WebhookConfig:
webhook_type: str # feishu, dingtalk, slack, custom webhook_type: str # feishu, dingtalk, slack, custom
url: str url: str
secret: str = "" # 用于签名验证 secret: str = "" # 用于签名验证
headers: dict = field(default_factory = dict) headers: dict = field(default_factory=dict)
template: str = "" # 消息模板 template: str = "" # 消息模板
is_active: bool = True is_active: bool = True
created_at: str = "" created_at: str = ""
@@ -140,8 +139,8 @@ class Workflow:
status: str = "active" status: str = "active"
schedule: str | None = None # cron expression or interval schedule: str | None = None # cron expression or interval
schedule_type: str = "manual" # manual, cron, interval schedule_type: str = "manual" # manual, cron, interval
config: dict = field(default_factory = dict) config: dict = field(default_factory=dict)
webhook_ids: list[str] = field(default_factory = list) webhook_ids: list[str] = field(default_factory=list)
is_active: bool = True is_active: bool = True
created_at: str = "" created_at: str = ""
updated_at: str = "" updated_at: str = ""
@@ -169,8 +168,8 @@ class WorkflowLog:
start_time: str | None = None start_time: str | None = None
end_time: str | None = None end_time: str | None = None
duration_ms: int = 0 duration_ms: int = 0
input_data: dict = field(default_factory = dict) input_data: dict = field(default_factory=dict)
output_data: dict = field(default_factory = dict) output_data: dict = field(default_factory=dict)
error_message: str = "" error_message: str = ""
created_at: str = "" created_at: str = ""
@@ -183,7 +182,7 @@ class WebhookNotifier:
"""Webhook 通知器 - 支持飞书、钉钉、Slack""" """Webhook 通知器 - 支持飞书、钉钉、Slack"""
def __init__(self) -> None: def __init__(self) -> None:
self.http_client = httpx.AsyncClient(timeout = 30.0) self.http_client = httpx.AsyncClient(timeout=30.0)
async def send(self, config: WebhookConfig, message: dict) -> bool: async def send(self, config: WebhookConfig, message: dict) -> bool:
"""发送 Webhook 通知""" """发送 Webhook 通知"""
@@ -210,7 +209,7 @@ class WebhookNotifier:
# 签名计算 # 签名计算
if config.secret: if config.secret:
string_to_sign = f"{timestamp}\n{config.secret}" string_to_sign = f"{timestamp}\n{config.secret}"
hmac_code = hmac.new(string_to_sign.encode("utf-8"), digestmod = hashlib.sha256).digest() hmac_code = hmac.new(string_to_sign.encode("utf-8"), digestmod=hashlib.sha256).digest()
sign = base64.b64encode(hmac_code).decode("utf-8") sign = base64.b64encode(hmac_code).decode("utf-8")
else: else:
sign = "" sign = ""
@@ -250,7 +249,7 @@ class WebhookNotifier:
headers = {"Content-Type": "application/json", **config.headers} headers = {"Content-Type": "application/json", **config.headers}
response = await self.http_client.post(config.url, json = payload, headers = headers) response = await self.http_client.post(config.url, json=payload, headers=headers)
response.raise_for_status() response.raise_for_status()
result = response.json() result = response.json()
@@ -265,7 +264,7 @@ class WebhookNotifier:
secret_enc = config.secret.encode("utf-8") secret_enc = config.secret.encode("utf-8")
string_to_sign = f"{timestamp}\n{config.secret}" string_to_sign = f"{timestamp}\n{config.secret}"
hmac_code = hmac.new( hmac_code = hmac.new(
secret_enc, string_to_sign.encode("utf-8"), digestmod = hashlib.sha256 secret_enc, string_to_sign.encode("utf-8"), digestmod=hashlib.sha256
).digest() ).digest()
sign = urllib.parse.quote_plus(base64.b64encode(hmac_code)) sign = urllib.parse.quote_plus(base64.b64encode(hmac_code))
url = f"{config.url}&timestamp = {timestamp}&sign = {sign}" url = f"{config.url}&timestamp = {timestamp}&sign = {sign}"
@@ -295,7 +294,7 @@ class WebhookNotifier:
headers = {"Content-Type": "application/json", **config.headers} headers = {"Content-Type": "application/json", **config.headers}
response = await self.http_client.post(url, json = payload, headers = headers) response = await self.http_client.post(url, json=payload, headers=headers)
response.raise_for_status() response.raise_for_status()
result = response.json() result = response.json()
@@ -316,7 +315,7 @@ class WebhookNotifier:
headers = {"Content-Type": "application/json", **config.headers} headers = {"Content-Type": "application/json", **config.headers}
response = await self.http_client.post(config.url, json = payload, headers = headers) response = await self.http_client.post(config.url, json=payload, headers=headers)
response.raise_for_status() response.raise_for_status()
return response.text == "ok" return response.text == "ok"
@@ -325,7 +324,7 @@ class WebhookNotifier:
"""发送自定义 Webhook 通知""" """发送自定义 Webhook 通知"""
headers = {"Content-Type": "application/json", **config.headers} headers = {"Content-Type": "application/json", **config.headers}
response = await self.http_client.post(config.url, json = message, headers = headers) response = await self.http_client.post(config.url, json=message, headers=headers)
response.raise_for_status() response.raise_for_status()
return True return True
@@ -343,7 +342,7 @@ class WorkflowManager:
DEFAULT_RETRY_COUNT: int = 3 DEFAULT_RETRY_COUNT: int = 3
DEFAULT_RETRY_DELAY: int = 5 DEFAULT_RETRY_DELAY: int = 5
def __init__(self, db_manager = None) -> None: def __init__(self, db_manager=None) -> None:
self.db = db_manager self.db = db_manager
self.scheduler = AsyncIOScheduler() self.scheduler = AsyncIOScheduler()
self.notifier = WebhookNotifier() self.notifier = WebhookNotifier()
@@ -381,13 +380,13 @@ class WorkflowManager:
def stop(self) -> None: def stop(self) -> None:
"""停止工作流管理器""" """停止工作流管理器"""
if self.scheduler.running: if self.scheduler.running:
self.scheduler.shutdown(wait = True) self.scheduler.shutdown(wait=True)
logger.info("Workflow scheduler stopped") logger.info("Workflow scheduler stopped")
async def _load_and_schedule_workflows(self) -> None: async def _load_and_schedule_workflows(self) -> None:
"""从数据库加载并调度所有活跃工作流""" """从数据库加载并调度所有活跃工作流"""
try: try:
workflows = self.list_workflows(status = "active") workflows = self.list_workflows(status="active")
for workflow in workflows: for workflow in workflows:
if workflow.schedule and workflow.is_active: if workflow.schedule and workflow.is_active:
self._schedule_workflow(workflow) self._schedule_workflow(workflow)
@@ -408,18 +407,18 @@ class WorkflowManager:
elif workflow.schedule_type == "interval": elif workflow.schedule_type == "interval":
# 间隔调度 # 间隔调度
interval_minutes = int(workflow.schedule) interval_minutes = int(workflow.schedule)
trigger = IntervalTrigger(minutes = interval_minutes) trigger = IntervalTrigger(minutes=interval_minutes)
else: else:
return return
self.scheduler.add_job( self.scheduler.add_job(
func = self._execute_workflow_job, func=self._execute_workflow_job,
trigger = trigger, trigger=trigger,
id = job_id, id=job_id,
args = [workflow.id], args=[workflow.id],
replace_existing = True, replace_existing=True,
max_instances = 1, max_instances=1,
coalesce = True, coalesce=True,
) )
logger.info( logger.info(
@@ -598,24 +597,24 @@ class WorkflowManager:
def _row_to_workflow(self, row) -> Workflow: def _row_to_workflow(self, row) -> Workflow:
"""将数据库行转换为 Workflow 对象""" """将数据库行转换为 Workflow 对象"""
return Workflow( return Workflow(
id = row["id"], id=row["id"],
name = row["name"], name=row["name"],
description = row["description"] or "", description=row["description"] or "",
workflow_type = row["workflow_type"], workflow_type=row["workflow_type"],
project_id = row["project_id"], project_id=row["project_id"],
status = row["status"], status=row["status"],
schedule = row["schedule"], schedule=row["schedule"],
schedule_type = row["schedule_type"], schedule_type=row["schedule_type"],
config = json.loads(row["config"]) if row["config"] else {}, config=json.loads(row["config"]) if row["config"] else {},
webhook_ids = json.loads(row["webhook_ids"]) if row["webhook_ids"] else [], webhook_ids=json.loads(row["webhook_ids"]) if row["webhook_ids"] else [],
is_active = bool(row["is_active"]), is_active=bool(row["is_active"]),
created_at = row["created_at"], created_at=row["created_at"],
updated_at = row["updated_at"], updated_at=row["updated_at"],
last_run_at = row["last_run_at"], last_run_at=row["last_run_at"],
next_run_at = row["next_run_at"], next_run_at=row["next_run_at"],
run_count = row["run_count"] or 0, run_count=row["run_count"] or 0,
success_count = row["success_count"] or 0, success_count=row["success_count"] or 0,
fail_count = row["fail_count"] or 0, fail_count=row["fail_count"] or 0,
) )
# ==================== Workflow Task CRUD ==================== # ==================== Workflow Task CRUD ====================
@@ -729,18 +728,18 @@ class WorkflowManager:
def _row_to_task(self, row) -> WorkflowTask: def _row_to_task(self, row) -> WorkflowTask:
"""将数据库行转换为 WorkflowTask 对象""" """将数据库行转换为 WorkflowTask 对象"""
return WorkflowTask( return WorkflowTask(
id = row["id"], id=row["id"],
workflow_id = row["workflow_id"], workflow_id=row["workflow_id"],
name = row["name"], name=row["name"],
task_type = row["task_type"], task_type=row["task_type"],
config = json.loads(row["config"]) if row["config"] else {}, config=json.loads(row["config"]) if row["config"] else {},
order = row["task_order"] or 0, order=row["task_order"] or 0,
depends_on = json.loads(row["depends_on"]) if row["depends_on"] else [], depends_on=json.loads(row["depends_on"]) if row["depends_on"] else [],
timeout_seconds = row["timeout_seconds"] or 300, timeout_seconds=row["timeout_seconds"] or 300,
retry_count = row["retry_count"] or 3, retry_count=row["retry_count"] or 3,
retry_delay = row["retry_delay"] or 5, retry_delay=row["retry_delay"] or 5,
created_at = row["created_at"], created_at=row["created_at"],
updated_at = row["updated_at"], updated_at=row["updated_at"],
) )
# ==================== Webhook Config CRUD ==================== # ==================== Webhook Config CRUD ====================
@@ -875,19 +874,19 @@ class WorkflowManager:
def _row_to_webhook(self, row) -> WebhookConfig: def _row_to_webhook(self, row) -> WebhookConfig:
"""将数据库行转换为 WebhookConfig 对象""" """将数据库行转换为 WebhookConfig 对象"""
return WebhookConfig( return WebhookConfig(
id = row["id"], id=row["id"],
name = row["name"], name=row["name"],
webhook_type = row["webhook_type"], webhook_type=row["webhook_type"],
url = row["url"], url=row["url"],
secret = row["secret"] or "", secret=row["secret"] or "",
headers = json.loads(row["headers"]) if row["headers"] else {}, headers=json.loads(row["headers"]) if row["headers"] else {},
template = row["template"] or "", template=row["template"] or "",
is_active = bool(row["is_active"]), is_active=bool(row["is_active"]),
created_at = row["created_at"], created_at=row["created_at"],
updated_at = row["updated_at"], updated_at=row["updated_at"],
last_used_at = row["last_used_at"], last_used_at=row["last_used_at"],
success_count = row["success_count"] or 0, success_count=row["success_count"] or 0,
fail_count = row["fail_count"] or 0, fail_count=row["fail_count"] or 0,
) )
# ==================== Workflow Log ==================== # ==================== Workflow Log ====================
@@ -1003,7 +1002,7 @@ class WorkflowManager:
"""获取工作流统计""" """获取工作流统计"""
conn = self.db.get_conn() conn = self.db.get_conn()
try: try:
since = (datetime.now() - timedelta(days = days)).isoformat() since = (datetime.now() - timedelta(days=days)).isoformat()
# 总执行次数 # 总执行次数
total = conn.execute( total = conn.execute(
@@ -1060,17 +1059,17 @@ class WorkflowManager:
def _row_to_log(self, row) -> WorkflowLog: def _row_to_log(self, row) -> WorkflowLog:
"""将数据库行转换为 WorkflowLog 对象""" """将数据库行转换为 WorkflowLog 对象"""
return WorkflowLog( return WorkflowLog(
id = row["id"], id=row["id"],
workflow_id = row["workflow_id"], workflow_id=row["workflow_id"],
task_id = row["task_id"], task_id=row["task_id"],
status = row["status"], status=row["status"],
start_time = row["start_time"], start_time=row["start_time"],
end_time = row["end_time"], end_time=row["end_time"],
duration_ms = row["duration_ms"] or 0, duration_ms=row["duration_ms"] or 0,
input_data = json.loads(row["input_data"]) if row["input_data"] else {}, input_data=json.loads(row["input_data"]) if row["input_data"] else {},
output_data = json.loads(row["output_data"]) if row["output_data"] else {}, output_data=json.loads(row["output_data"]) if row["output_data"] else {},
error_message = row["error_message"] or "", error_message=row["error_message"] or "",
created_at = row["created_at"], created_at=row["created_at"],
) )
# ==================== Workflow Execution ==================== # ==================== Workflow Execution ====================
@@ -1086,15 +1085,15 @@ class WorkflowManager:
# 更新最后运行时间 # 更新最后运行时间
now = datetime.now().isoformat() now = datetime.now().isoformat()
self.update_workflow(workflow_id, last_run_at = now, run_count = workflow.run_count + 1) self.update_workflow(workflow_id, last_run_at=now, run_count=workflow.run_count + 1)
# 创建工作流执行日志 # 创建工作流执行日志
log = WorkflowLog( log = WorkflowLog(
id = str(uuid.uuid4())[:UUID_LENGTH], id=str(uuid.uuid4())[:UUID_LENGTH],
workflow_id = workflow_id, workflow_id=workflow_id,
status = TaskStatus.RUNNING.value, status=TaskStatus.RUNNING.value,
start_time = now, start_time=now,
input_data = input_data or {}, input_data=input_data or {},
) )
self.create_log(log) self.create_log(log)
@@ -1113,21 +1112,21 @@ class WorkflowManager:
results = await self._execute_tasks_with_deps(tasks, input_data, log.id) results = await self._execute_tasks_with_deps(tasks, input_data, log.id)
# 发送通知 # 发送通知
await self._send_workflow_notification(workflow, results, success = True) await self._send_workflow_notification(workflow, results, success=True)
# 更新日志为成功 # 更新日志为成功
end_time = datetime.now() end_time = datetime.now()
duration = int((end_time - start_time).total_seconds() * 1000) duration = int((end_time - start_time).total_seconds() * 1000)
self.update_log( self.update_log(
log.id, log.id,
status = TaskStatus.SUCCESS.value, status=TaskStatus.SUCCESS.value,
end_time = end_time.isoformat(), end_time=end_time.isoformat(),
duration_ms = duration, duration_ms=duration,
output_data = results, output_data=results,
) )
# 更新成功计数 # 更新成功计数
self.update_workflow(workflow_id, success_count = workflow.success_count + 1) self.update_workflow(workflow_id, success_count=workflow.success_count + 1)
return { return {
"success": True, "success": True,
@@ -1145,17 +1144,17 @@ class WorkflowManager:
duration = int((end_time - start_time).total_seconds() * 1000) duration = int((end_time - start_time).total_seconds() * 1000)
self.update_log( self.update_log(
log.id, log.id,
status = TaskStatus.FAILED.value, status=TaskStatus.FAILED.value,
end_time = end_time.isoformat(), end_time=end_time.isoformat(),
duration_ms = duration, duration_ms=duration,
error_message = str(e), error_message=str(e),
) )
# 更新失败计数 # 更新失败计数
self.update_workflow(workflow_id, fail_count = workflow.fail_count + 1) self.update_workflow(workflow_id, fail_count=workflow.fail_count + 1)
# 发送失败通知 # 发送失败通知
await self._send_workflow_notification(workflow, {"error": str(e)}, success = False) await self._send_workflow_notification(workflow, {"error": str(e)}, success=False)
raise raise
@@ -1185,7 +1184,7 @@ class WorkflowManager:
task_input = {**input_data, **results} task_input = {**input_data, **results}
task_coros.append(self._execute_single_task(task, task_input, log_id)) task_coros.append(self._execute_single_task(task, task_input, log_id))
task_results = await asyncio.gather(*task_coros, return_exceptions = True) task_results = await asyncio.gather(*task_coros, return_exceptions=True)
for task, result in zip(ready_tasks, task_results): for task, result in zip(ready_tasks, task_results):
if isinstance(result, Exception): if isinstance(result, Exception):
@@ -1217,25 +1216,25 @@ class WorkflowManager:
# 创建任务日志 # 创建任务日志
task_log = WorkflowLog( task_log = WorkflowLog(
id = str(uuid.uuid4())[:UUID_LENGTH], id=str(uuid.uuid4())[:UUID_LENGTH],
workflow_id = task.workflow_id, workflow_id=task.workflow_id,
task_id = task.id, task_id=task.id,
status = TaskStatus.RUNNING.value, status=TaskStatus.RUNNING.value,
start_time = datetime.now().isoformat(), start_time=datetime.now().isoformat(),
input_data = input_data, input_data=input_data,
) )
self.create_log(task_log) self.create_log(task_log)
try: try:
# 设置超时 # 设置超时
result = await asyncio.wait_for(handler(task, input_data), timeout = task.timeout_seconds) result = await asyncio.wait_for(handler(task, input_data), timeout=task.timeout_seconds)
# 更新任务日志为成功 # 更新任务日志为成功
self.update_log( self.update_log(
task_log.id, task_log.id,
status = TaskStatus.SUCCESS.value, status=TaskStatus.SUCCESS.value,
end_time = datetime.now().isoformat(), end_time=datetime.now().isoformat(),
output_data = {"result": result} if not isinstance(result, dict) else result, output_data={"result": result} if not isinstance(result, dict) else result,
) )
return result return result
@@ -1243,18 +1242,18 @@ class WorkflowManager:
except TimeoutError: except TimeoutError:
self.update_log( self.update_log(
task_log.id, task_log.id,
status = TaskStatus.FAILED.value, status=TaskStatus.FAILED.value,
end_time = datetime.now().isoformat(), end_time=datetime.now().isoformat(),
error_message = "Task timeout", error_message="Task timeout",
) )
raise TimeoutError(f"Task {task.id} timed out after {task.timeout_seconds}s") raise TimeoutError(f"Task {task.id} timed out after {task.timeout_seconds}s")
except Exception as e: except Exception as e:
self.update_log( self.update_log(
task_log.id, task_log.id,
status = TaskStatus.FAILED.value, status=TaskStatus.FAILED.value,
end_time = datetime.now().isoformat(), end_time=datetime.now().isoformat(),
error_message = str(e), error_message=str(e),
) )
raise raise
@@ -1476,7 +1475,7 @@ class WorkflowManager:
**结果:** **结果:**
```json ```json
{json.dumps(results, ensure_ascii = False, indent = 2)} {json.dumps(results, ensure_ascii=False, indent=2)}
``` ```
""", """,
} }
@@ -1510,7 +1509,7 @@ class WorkflowManager:
_workflow_manager = None _workflow_manager = None
def get_workflow_manager(db_manager = None) -> WorkflowManager: def get_workflow_manager(db_manager=None) -> WorkflowManager:
"""获取 WorkflowManager 单例""" """获取 WorkflowManager 单例"""
global _workflow_manager global _workflow_manager
if _workflow_manager is None: if _workflow_manager is None: