fix: auto-fix code issues (cron)

- 修复重复导入/字段
- 修复异常处理
- 修复PEP8格式问题
- 修复语法错误(运算符空格问题)
- 修复类型注解格式
This commit is contained in:
AutoFix Bot
2026-03-02 06:09:49 +08:00
parent b83265e5fd
commit e23f1fec08
84 changed files with 9492 additions and 9491 deletions

View File

@@ -225,3 +225,7 @@
- 第 541 行: line_too_long
- 第 579 行: line_too_long
- ... 还有 2 个类似问题
## Git 提交结果
✅ 提交并推送成功

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -21,7 +21,7 @@ class CodeIssue:
message: str,
severity: str = "warning",
original_line: str = "",
):
) -> None:
self.file_path = file_path
self.line_no = line_no
self.issue_type = issue_type
@@ -30,14 +30,14 @@ class CodeIssue:
self.original_line = original_line
self.fixed = False
def __repr__(self):
def __repr__(self) -> None:
return f"{self.file_path}:{self.line_no} [{self.severity}] {self.issue_type}: {self.message}"
class CodeFixer:
"""代码自动修复器"""
def __init__(self, project_path: str):
def __init__(self, project_path: str) -> None:
self.project_path = Path(project_path)
self.issues: list[CodeIssue] = []
self.fixed_issues: list[CodeIssue] = []
@@ -55,7 +55,7 @@ class CodeFixer:
def _scan_file(self, file_path: Path) -> None:
"""扫描单个文件"""
try:
with open(file_path, "r", encoding="utf-8") as f:
with open(file_path, "r", encoding = "utf-8") as f:
content = f.read()
lines = content.split("\n")
except Exception as e:
@@ -85,7 +85,7 @@ class CodeFixer:
) -> None:
"""检查裸异常捕获"""
for i, line in enumerate(lines, 1):
# 匹配 except: 但不匹配 except Exception: 或 except SpecificError:
# 匹配 except Exception: 但不匹配 except Exception: 或 except SpecificError:
if re.search(r"except\s*:\s*$", line) or re.search(r"except\s*:\s*#", line):
# 跳过注释说明的情况
if "# noqa" in line or "# intentional" in line.lower():
@@ -221,10 +221,10 @@ class CodeFixer:
return
patterns = [
(r'password\s*=\s*["\'][^"\']{8,}["\']', "硬编码密码"),
(r'secret_key\s*=\s*["\'][^"\']{8,}["\']', "硬编码密钥"),
(r'api_key\s*=\s*["\'][^"\']{8,}["\']', "硬编码 API Key"),
(r'token\s*=\s*["\'][^"\']{8,}["\']', "硬编码 Token"),
(r'password\s* = \s*["\'][^"\']{8, }["\']', "硬编码密码"),
(r'secret_key\s* = \s*["\'][^"\']{8, }["\']', "硬编码密钥"),
(r'api_key\s* = \s*["\'][^"\']{8, }["\']', "硬编码 API Key"),
(r'token\s* = \s*["\'][^"\']{8, }["\']', "硬编码 Token"),
]
for i, line in enumerate(lines, 1):
@@ -241,7 +241,7 @@ class CodeFixer:
if any(x in line.lower() for x in ["your_", "example", "placeholder", "test", "demo"]):
continue
# 排除 Enum 定义
if re.search(r'^\s*[A-Z_]+\s*=', line.strip()):
if re.search(r'^\s*[A-Z_]+\s* = ', line.strip()):
continue
self.manual_issues.append(
CodeIssue(
@@ -275,7 +275,7 @@ class CodeFixer:
continue
try:
with open(file_path, "r", encoding="utf-8") as f:
with open(file_path, "r", encoding = "utf-8") as f:
content = f.read()
lines = content.split("\n")
except Exception:
@@ -301,9 +301,9 @@ class CodeFixer:
line_idx = issue.line_no - 1
if 0 <= line_idx < len(lines) and line_idx not in fixed_lines:
line = lines[line_idx]
# 将 except: 改为 except Exception:
# 将 except Exception: 改为 except Exception:
if re.search(r"except\s*:\s*$", line.strip()):
lines[line_idx] = line.replace("except:", "except Exception:")
lines[line_idx] = line.replace("except Exception:", "except Exception:")
fixed_lines.add(line_idx)
issue.fixed = True
self.fixed_issues.append(issue)
@@ -311,7 +311,7 @@ class CodeFixer:
# 如果文件有修改,写回
if lines != original_lines:
try:
with open(file_path, "w", encoding="utf-8") as f:
with open(file_path, "w", encoding = "utf-8") as f:
f.write("\n".join(lines))
print(f"Fixed issues in {file_path}")
except Exception as e:
@@ -426,16 +426,16 @@ def git_commit_and_push(project_path: str) -> tuple[bool, str]:
# 检查是否有变更
result = subprocess.run(
["git", "status", "--porcelain"],
cwd=project_path,
capture_output=True,
text=True,
cwd = project_path,
capture_output = True,
text = True,
)
if not result.stdout.strip():
return True, "没有需要提交的变更"
# 添加所有变更
subprocess.run(["git", "add", "-A"], cwd=project_path, check=True)
subprocess.run(["git", "add", "-A"], cwd = project_path, check = True)
# 提交
commit_msg = """fix: auto-fix code issues (cron)
@@ -446,11 +446,11 @@ def git_commit_and_push(project_path: str) -> tuple[bool, str]:
- 添加类型注解"""
subprocess.run(
["git", "commit", "-m", commit_msg], cwd=project_path, check=True
["git", "commit", "-m", commit_msg], cwd = project_path, check = True
)
# 推送
subprocess.run(["git", "push"], cwd=project_path, check=True)
subprocess.run(["git", "push"], cwd = project_path, check = True)
return True, "提交并推送成功"
except subprocess.CalledProcessError as e:
@@ -459,7 +459,7 @@ def git_commit_and_push(project_path: str) -> tuple[bool, str]:
return False, f"Git 操作异常: {e}"
def main():
def main() -> None:
project_path = "/root/.openclaw/workspace/projects/insightflow"
print("🔍 开始扫描代码...")
@@ -479,7 +479,7 @@ def main():
# 保存报告
report_path = Path(project_path) / "AUTO_CODE_REVIEW_REPORT.md"
with open(report_path, "w", encoding="utf-8") as f:
with open(report_path, "w", encoding = "utf-8") as f:
f.write(report)
print(f"📝 报告已保存到: {report_path}")
@@ -493,12 +493,12 @@ def main():
report += f"\n\n## Git 提交结果\n\n{'' if success else ''} {msg}\n"
# 重新保存完整报告
with open(report_path, "w", encoding="utf-8") as f:
with open(report_path, "w", encoding = "utf-8") as f:
f.write(report)
print("\n" + "=" * 60)
print("\n" + " = " * 60)
print(report)
print("=" * 60)
print(" = " * 60)
return report

View File

@@ -15,10 +15,10 @@ def run_ruff_check(directory: str) -> list[dict]:
"""运行 ruff 检查并返回问题列表"""
try:
result = subprocess.run(
["ruff", "check", "--select=E,W,F,I", "--output-format=json", directory],
capture_output=True,
text=True,
check=False,
["ruff", "check", "--select = E, W, F, I", "--output-format = json", directory],
capture_output = True,
text = True,
check = False,
)
if result.stdout:
return json.loads(result.stdout)
@@ -29,7 +29,7 @@ def run_ruff_check(directory: str) -> list[dict]:
def fix_bare_except(content: str) -> str:
"""修复裸异常捕获 - 将 bare except: 改为 except Exception:"""
"""修复裸异常捕获 - 将 bare except Exception: 改为 except Exception:"""
pattern = r'except\s*:\s*\n'
replacement = 'except Exception:\n'
return re.sub(pattern, replacement, content)
@@ -73,7 +73,7 @@ def fix_undefined_names(content: str, filepath: str) -> str:
def fix_file(filepath: str, issues: list[dict]) -> tuple[bool, list[str], list[str]]:
"""修复单个文件的问题"""
with open(filepath, 'r', encoding='utf-8') as f:
with open(filepath, 'r', encoding = 'utf-8') as f:
original_content = f.read()
content = original_content
@@ -97,20 +97,20 @@ def fix_file(filepath: str, issues: list[dict]) -> tuple[bool, list[str], list[s
content = fix_bare_except(content)
if content != original_content:
with open(filepath, 'w', encoding='utf-8') as f:
with open(filepath, 'w', encoding = 'utf-8') as f:
f.write(content)
return True, fixed_issues, manual_fix_needed
return False, fixed_issues, manual_fix_needed
def main():
def main() -> None:
base_dir = Path("/root/.openclaw/workspace/projects/insightflow")
backend_dir = base_dir / "backend"
print("=" * 60)
print(" = " * 60)
print("InsightFlow 代码自动修复")
print("=" * 60)
print(" = " * 60)
print("\n1. 扫描代码问题...")
issues = run_ruff_check(str(backend_dir))
@@ -130,7 +130,7 @@ def main():
issue_types[code] = issue_types.get(code, 0) + 1
print("\n2. 问题类型统计:")
for code, count in sorted(issue_types.items(), key=lambda x: -x[1]):
for code, count in sorted(issue_types.items(), key = lambda x: -x[1]):
print(f" - {code}: {count}")
print("\n3. 尝试自动修复...")
@@ -155,8 +155,8 @@ def main():
try:
subprocess.run(
["ruff", "format", str(backend_dir)],
capture_output=True,
check=False,
capture_output = True,
check = False,
)
print(" 格式化完成")
except Exception as e:
@@ -180,9 +180,9 @@ def main():
if __name__ == "__main__":
report = main()
print("\n" + "=" * 60)
print("\n" + " = " * 60)
print("修复报告")
print("=" * 60)
print(" = " * 60)
print(f"总问题数: {report['total_issues']}")
print(f"修复文件数: {report['fixed_files']}")
print(f"自动修复问题数: {report['fixed_issues']}")

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -234,20 +234,20 @@ class AIManager:
now = datetime.now().isoformat()
model = CustomModel(
id=model_id,
tenant_id=tenant_id,
name=name,
description=description,
model_type=model_type,
status=ModelStatus.PENDING,
training_data=training_data,
hyperparameters=hyperparameters,
metrics={},
model_path=None,
created_at=now,
updated_at=now,
trained_at=None,
created_by=created_by,
id = model_id,
tenant_id = tenant_id,
name = name,
description = description,
model_type = model_type,
status = ModelStatus.PENDING,
training_data = training_data,
hyperparameters = hyperparameters,
metrics = {},
model_path = None,
created_at = now,
updated_at = now,
trained_at = None,
created_by = created_by,
)
with self._get_db() as conn:
@@ -283,7 +283,7 @@ class AIManager:
def get_custom_model(self, model_id: str) -> CustomModel | None:
"""获取自定义模型"""
with self._get_db() as conn:
row = conn.execute("SELECT * FROM custom_models WHERE id = ?", (model_id,)).fetchone()
row = conn.execute("SELECT * FROM custom_models WHERE id = ?", (model_id, )).fetchone()
if not row:
return None
@@ -318,12 +318,12 @@ class AIManager:
now = datetime.now().isoformat()
sample = TrainingSample(
id=sample_id,
model_id=model_id,
text=text,
entities=entities,
metadata=metadata or {},
created_at=now,
id = sample_id,
model_id = model_id,
text = text,
entities = entities,
metadata = metadata or {},
created_at = now,
)
with self._get_db() as conn:
@@ -350,7 +350,7 @@ class AIManager:
"""获取训练样本"""
with self._get_db() as conn:
rows = conn.execute(
"SELECT * FROM training_samples WHERE model_id = ? ORDER BY created_at", (model_id,)
"SELECT * FROM training_samples WHERE model_id = ? ORDER BY created_at", (model_id, )
).fetchall()
return [self._row_to_training_sample(row) for row in rows]
@@ -392,7 +392,7 @@ class AIManager:
# 保存模型(模拟)
model_path = f"models/{model_id}.bin"
os.makedirs("models", exist_ok=True)
os.makedirs("models", exist_ok = True)
now = datetime.now().isoformat()
@@ -450,9 +450,9 @@ class AIManager:
async with httpx.AsyncClient() as client:
response = await client.post(
f"{self.kimi_base_url}/v1/chat/completions",
headers=headers,
json=payload,
timeout=60.0,
headers = headers,
json = payload,
timeout = 60.0,
)
response.raise_for_status()
result = response.json()
@@ -494,17 +494,17 @@ class AIManager:
result = await self._call_kimi_multimodal(input_urls, prompt)
analysis = MultimodalAnalysis(
id=analysis_id,
tenant_id=tenant_id,
project_id=project_id,
provider=provider,
input_type=input_type,
input_urls=input_urls,
prompt=prompt,
result=result,
tokens_used=result.get("tokens_used", 0),
cost=result.get("cost", 0.0),
created_at=now,
id = analysis_id,
tenant_id = tenant_id,
project_id = project_id,
provider = provider,
input_type = input_type,
input_urls = input_urls,
prompt = prompt,
result = result,
tokens_used = result.get("tokens_used", 0),
cost = result.get("cost", 0.0),
created_at = now,
)
with self._get_db() as conn:
@@ -553,9 +553,9 @@ class AIManager:
async with httpx.AsyncClient() as client:
response = await client.post(
"https://api.openai.com/v1/chat/completions",
headers=headers,
json=payload,
timeout=120.0,
headers = headers,
json = payload,
timeout = 120.0,
)
response.raise_for_status()
result = response.json()
@@ -588,9 +588,9 @@ class AIManager:
async with httpx.AsyncClient() as client:
response = await client.post(
"https://api.anthropic.com/v1/messages",
headers=headers,
json=payload,
timeout=120.0,
headers = headers,
json = payload,
timeout = 120.0,
)
response.raise_for_status()
result = response.json()
@@ -623,9 +623,9 @@ class AIManager:
async with httpx.AsyncClient() as client:
response = await client.post(
f"{self.kimi_base_url}/v1/chat/completions",
headers=headers,
json=payload,
timeout=60.0,
headers = headers,
json = payload,
timeout = 60.0,
)
response.raise_for_status()
result = response.json()
@@ -670,17 +670,17 @@ class AIManager:
now = datetime.now().isoformat()
rag = KnowledgeGraphRAG(
id=rag_id,
tenant_id=tenant_id,
project_id=project_id,
name=name,
description=description,
kg_config=kg_config,
retrieval_config=retrieval_config,
generation_config=generation_config,
is_active=True,
created_at=now,
updated_at=now,
id = rag_id,
tenant_id = tenant_id,
project_id = project_id,
name = name,
description = description,
kg_config = kg_config,
retrieval_config = retrieval_config,
generation_config = generation_config,
is_active = True,
created_at = now,
updated_at = now,
)
with self._get_db() as conn:
@@ -712,7 +712,7 @@ class AIManager:
def get_kg_rag(self, rag_id: str) -> KnowledgeGraphRAG | None:
"""获取知识图谱 RAG 配置"""
with self._get_db() as conn:
row = conn.execute("SELECT * FROM kg_rag_configs WHERE id = ?", (rag_id,)).fetchone()
row = conn.execute("SELECT * FROM kg_rag_configs WHERE id = ?", (rag_id, )).fetchone()
if not row:
return None
@@ -766,7 +766,7 @@ class AIManager:
if score > 0:
relevant_entities.append({**entity, "relevance_score": score})
relevant_entities.sort(key=lambda x: x["relevance_score"], reverse=True)
relevant_entities.sort(key = lambda x: x["relevance_score"], reverse = True)
relevant_entities = relevant_entities[:top_k]
# 检索相关关系
@@ -818,9 +818,9 @@ class AIManager:
async with httpx.AsyncClient() as client:
response = await client.post(
f"{self.kimi_base_url}/v1/chat/completions",
headers=headers,
json=payload,
timeout=60.0,
headers = headers,
json = payload,
timeout = 60.0,
)
response.raise_for_status()
result = response.json()
@@ -840,20 +840,20 @@ class AIManager:
]
rag_query = RAGQuery(
id=query_id,
rag_id=rag_id,
query=query,
context=context,
answer=answer,
sources=sources,
confidence=(
id = query_id,
rag_id = rag_id,
query = query,
context = context,
answer = answer,
sources = sources,
confidence = (
sum(e["relevance_score"] for e in relevant_entities) / len(relevant_entities)
if relevant_entities
else 0
),
tokens_used=tokens_used,
latency_ms=latency_ms,
created_at=now,
tokens_used = tokens_used,
latency_ms = latency_ms,
created_at = now,
)
with self._get_db() as conn:
@@ -974,9 +974,9 @@ class AIManager:
async with httpx.AsyncClient() as client:
response = await client.post(
f"{self.kimi_base_url}/v1/chat/completions",
headers=headers,
json=payload,
timeout=60.0,
headers = headers,
json = payload,
timeout = 60.0,
)
response.raise_for_status()
result = response.json()
@@ -1014,18 +1014,18 @@ class AIManager:
entity_names = [e.get("name", "") for e in entities_mentioned[:10]]
summary = SmartSummary(
id=summary_id,
tenant_id=tenant_id,
project_id=project_id,
source_type=source_type,
source_id=source_id,
summary_type=summary_type,
content=content,
key_points=key_points[:8],
entities_mentioned=entity_names,
confidence=0.85,
tokens_used=tokens_used,
created_at=now,
id = summary_id,
tenant_id = tenant_id,
project_id = project_id,
source_type = source_type,
source_id = source_id,
summary_type = summary_type,
content = content,
key_points = key_points[:8],
entities_mentioned = entity_names,
confidence = 0.85,
tokens_used = tokens_used,
created_at = now,
)
with self._get_db() as conn:
@@ -1072,20 +1072,20 @@ class AIManager:
now = datetime.now().isoformat()
model = PredictionModel(
id=model_id,
tenant_id=tenant_id,
project_id=project_id,
name=name,
prediction_type=prediction_type,
target_entity_type=target_entity_type,
features=features,
model_config=model_config,
accuracy=None,
last_trained_at=None,
prediction_count=0,
is_active=True,
created_at=now,
updated_at=now,
id = model_id,
tenant_id = tenant_id,
project_id = project_id,
name = name,
prediction_type = prediction_type,
target_entity_type = target_entity_type,
features = features,
model_config = model_config,
accuracy = None,
last_trained_at = None,
prediction_count = 0,
is_active = True,
created_at = now,
updated_at = now,
)
with self._get_db() as conn:
@@ -1122,7 +1122,7 @@ class AIManager:
"""获取预测模型"""
with self._get_db() as conn:
row = conn.execute(
"SELECT * FROM prediction_models WHERE id = ?", (model_id,)
"SELECT * FROM prediction_models WHERE id = ?", (model_id, )
).fetchone()
if not row:
@@ -1201,16 +1201,16 @@ class AIManager:
explanation = prediction_data.get("explanation", "基于历史数据模式预测")
result = PredictionResult(
id=prediction_id,
model_id=model_id,
prediction_type=model.prediction_type,
target_id=input_data.get("target_id"),
prediction_data=prediction_data,
confidence=confidence,
explanation=explanation,
actual_value=None,
is_correct=None,
created_at=now,
id = prediction_id,
model_id = model_id,
prediction_type = model.prediction_type,
target_id = input_data.get("target_id"),
prediction_data = prediction_data,
confidence = confidence,
explanation = explanation,
actual_value = None,
is_correct = None,
created_at = now,
)
with self._get_db() as conn:
@@ -1238,7 +1238,7 @@ class AIManager:
# 更新预测计数
conn.execute(
"UPDATE prediction_models SET prediction_count = prediction_count + 1 WHERE id = ?",
(model_id,),
(model_id, ),
)
conn.commit()
@@ -1368,7 +1368,7 @@ class AIManager:
predicted_relations = [
{"type": rel_type, "likelihood": min(count / len(relation_history), 0.95)}
for rel_type, count in sorted(
relation_counts.items(), key=lambda x: x[1], reverse=True
relation_counts.items(), key = lambda x: x[1], reverse = True
)[:5]
]
@@ -1410,97 +1410,97 @@ class AIManager:
def _row_to_custom_model(self, row) -> CustomModel:
"""将数据库行转换为 CustomModel"""
return CustomModel(
id=row["id"],
tenant_id=row["tenant_id"],
name=row["name"],
description=row["description"],
model_type=ModelType(row["model_type"]),
status=ModelStatus(row["status"]),
training_data=json.loads(row["training_data"]),
hyperparameters=json.loads(row["hyperparameters"]),
metrics=json.loads(row["metrics"]),
model_path=row["model_path"],
created_at=row["created_at"],
updated_at=row["updated_at"],
trained_at=row["trained_at"],
created_by=row["created_by"],
id = row["id"],
tenant_id = row["tenant_id"],
name = row["name"],
description = row["description"],
model_type = ModelType(row["model_type"]),
status = ModelStatus(row["status"]),
training_data = json.loads(row["training_data"]),
hyperparameters = json.loads(row["hyperparameters"]),
metrics = json.loads(row["metrics"]),
model_path = row["model_path"],
created_at = row["created_at"],
updated_at = row["updated_at"],
trained_at = row["trained_at"],
created_by = row["created_by"],
)
def _row_to_training_sample(self, row) -> TrainingSample:
"""将数据库行转换为 TrainingSample"""
return TrainingSample(
id=row["id"],
model_id=row["model_id"],
text=row["text"],
entities=json.loads(row["entities"]),
metadata=json.loads(row["metadata"]),
created_at=row["created_at"],
id = row["id"],
model_id = row["model_id"],
text = row["text"],
entities = json.loads(row["entities"]),
metadata = json.loads(row["metadata"]),
created_at = row["created_at"],
)
def _row_to_multimodal_analysis(self, row) -> MultimodalAnalysis:
"""将数据库行转换为 MultimodalAnalysis"""
return MultimodalAnalysis(
id=row["id"],
tenant_id=row["tenant_id"],
project_id=row["project_id"],
provider=MultimodalProvider(row["provider"]),
input_type=row["input_type"],
input_urls=json.loads(row["input_urls"]),
prompt=row["prompt"],
result=json.loads(row["result"]),
tokens_used=row["tokens_used"],
cost=row["cost"],
created_at=row["created_at"],
id = row["id"],
tenant_id = row["tenant_id"],
project_id = row["project_id"],
provider = MultimodalProvider(row["provider"]),
input_type = row["input_type"],
input_urls = json.loads(row["input_urls"]),
prompt = row["prompt"],
result = json.loads(row["result"]),
tokens_used = row["tokens_used"],
cost = row["cost"],
created_at = row["created_at"],
)
def _row_to_kg_rag(self, row) -> KnowledgeGraphRAG:
"""将数据库行转换为 KnowledgeGraphRAG"""
return KnowledgeGraphRAG(
id=row["id"],
tenant_id=row["tenant_id"],
project_id=row["project_id"],
name=row["name"],
description=row["description"],
kg_config=json.loads(row["kg_config"]),
retrieval_config=json.loads(row["retrieval_config"]),
generation_config=json.loads(row["generation_config"]),
is_active=bool(row["is_active"]),
created_at=row["created_at"],
updated_at=row["updated_at"],
id = row["id"],
tenant_id = row["tenant_id"],
project_id = row["project_id"],
name = row["name"],
description = row["description"],
kg_config = json.loads(row["kg_config"]),
retrieval_config = json.loads(row["retrieval_config"]),
generation_config = json.loads(row["generation_config"]),
is_active = bool(row["is_active"]),
created_at = row["created_at"],
updated_at = row["updated_at"],
)
def _row_to_prediction_model(self, row) -> PredictionModel:
"""将数据库行转换为 PredictionModel"""
return PredictionModel(
id=row["id"],
tenant_id=row["tenant_id"],
project_id=row["project_id"],
name=row["name"],
prediction_type=PredictionType(row["prediction_type"]),
target_entity_type=row["target_entity_type"],
features=json.loads(row["features"]),
model_config=json.loads(row["model_config"]),
accuracy=row["accuracy"],
last_trained_at=row["last_trained_at"],
prediction_count=row["prediction_count"],
is_active=bool(row["is_active"]),
created_at=row["created_at"],
updated_at=row["updated_at"],
id = row["id"],
tenant_id = row["tenant_id"],
project_id = row["project_id"],
name = row["name"],
prediction_type = PredictionType(row["prediction_type"]),
target_entity_type = row["target_entity_type"],
features = json.loads(row["features"]),
model_config = json.loads(row["model_config"]),
accuracy = row["accuracy"],
last_trained_at = row["last_trained_at"],
prediction_count = row["prediction_count"],
is_active = bool(row["is_active"]),
created_at = row["created_at"],
updated_at = row["updated_at"],
)
def _row_to_prediction_result(self, row) -> PredictionResult:
"""将数据库行转换为 PredictionResult"""
return PredictionResult(
id=row["id"],
model_id=row["model_id"],
prediction_type=PredictionType(row["prediction_type"]),
target_id=row["target_id"],
prediction_data=json.loads(row["prediction_data"]),
confidence=row["confidence"],
explanation=row["explanation"],
actual_value=row["actual_value"],
is_correct=row["is_correct"],
created_at=row["created_at"],
id = row["id"],
model_id = row["model_id"],
prediction_type = PredictionType(row["prediction_type"]),
target_id = row["target_id"],
prediction_data = json.loads(row["prediction_data"]),
confidence = row["confidence"],
explanation = row["explanation"],
actual_value = row["actual_value"],
is_correct = row["is_correct"],
created_at = row["created_at"],
)

View File

@@ -152,23 +152,23 @@ class ApiKeyManager:
expires_at = None
if expires_days:
expires_at = (datetime.now() + timedelta(days=expires_days)).isoformat()
expires_at = (datetime.now() + timedelta(days = expires_days)).isoformat()
api_key = ApiKey(
id=key_id,
key_hash=key_hash,
key_preview=key_preview,
name=name,
owner_id=owner_id,
permissions=permissions,
rate_limit=rate_limit,
status=ApiKeyStatus.ACTIVE.value,
created_at=datetime.now().isoformat(),
expires_at=expires_at,
last_used_at=None,
revoked_at=None,
revoked_reason=None,
total_calls=0,
id = key_id,
key_hash = key_hash,
key_preview = key_preview,
name = name,
owner_id = owner_id,
permissions = permissions,
rate_limit = rate_limit,
status = ApiKeyStatus.ACTIVE.value,
created_at = datetime.now().isoformat(),
expires_at = expires_at,
last_used_at = None,
revoked_at = None,
revoked_reason = None,
total_calls = 0,
)
with sqlite3.connect(self.db_path) as conn:
@@ -207,7 +207,7 @@ class ApiKeyManager:
with sqlite3.connect(self.db_path) as conn:
conn.row_factory = sqlite3.Row
row = conn.execute("SELECT * FROM api_keys WHERE key_hash = ?", (key_hash,)).fetchone()
row = conn.execute("SELECT * FROM api_keys WHERE key_hash = ?", (key_hash, )).fetchone()
if not row:
return None
@@ -238,7 +238,7 @@ class ApiKeyManager:
# 验证所有权(如果提供了 owner_id
if owner_id:
row = conn.execute(
"SELECT owner_id FROM api_keys WHERE id = ?", (key_id,)
"SELECT owner_id FROM api_keys WHERE id = ?", (key_id, )
).fetchone()
if not row or row[0] != owner_id:
return False
@@ -270,7 +270,7 @@ class ApiKeyManager:
"SELECT * FROM api_keys WHERE id = ? AND owner_id = ?", (key_id, owner_id)
).fetchone()
else:
row = conn.execute("SELECT * FROM api_keys WHERE id = ?", (key_id,)).fetchone()
row = conn.execute("SELECT * FROM api_keys WHERE id = ?", (key_id, )).fetchone()
if row:
return self._row_to_api_key(row)
@@ -287,7 +287,7 @@ class ApiKeyManager:
with sqlite3.connect(self.db_path) as conn:
conn.row_factory = sqlite3.Row
query = "SELECT * FROM api_keys WHERE 1=1"
query = "SELECT * FROM api_keys WHERE 1 = 1"
params = []
if owner_id:
@@ -337,7 +337,7 @@ class ApiKeyManager:
# 验证所有权
if owner_id:
row = conn.execute(
"SELECT owner_id FROM api_keys WHERE id = ?", (key_id,)
"SELECT owner_id FROM api_keys WHERE id = ?", (key_id, )
).fetchone()
if not row or row[0] != owner_id:
return False
@@ -370,7 +370,7 @@ class ApiKeyManager:
ip_address: str = "",
user_agent: str = "",
error_message: str = "",
):
) -> None:
"""记录 API 调用日志"""
with sqlite3.connect(self.db_path) as conn:
conn.execute(
@@ -405,7 +405,7 @@ class ApiKeyManager:
with sqlite3.connect(self.db_path) as conn:
conn.row_factory = sqlite3.Row
query = "SELECT * FROM api_call_logs WHERE 1=1"
query = "SELECT * FROM api_call_logs WHERE 1 = 1"
params = []
if api_key_id:
@@ -510,20 +510,20 @@ class ApiKeyManager:
def _row_to_api_key(self, row: sqlite3.Row) -> ApiKey:
"""将数据库行转换为 ApiKey 对象"""
return ApiKey(
id=row["id"],
key_hash=row["key_hash"],
key_preview=row["key_preview"],
name=row["name"],
owner_id=row["owner_id"],
permissions=json.loads(row["permissions"]),
rate_limit=row["rate_limit"],
status=row["status"],
created_at=row["created_at"],
expires_at=row["expires_at"],
last_used_at=row["last_used_at"],
revoked_at=row["revoked_at"],
revoked_reason=row["revoked_reason"],
total_calls=row["total_calls"],
id = row["id"],
key_hash = row["key_hash"],
key_preview = row["key_preview"],
name = row["name"],
owner_id = row["owner_id"],
permissions = json.loads(row["permissions"]),
rate_limit = row["rate_limit"],
status = row["status"],
created_at = row["created_at"],
expires_at = row["expires_at"],
last_used_at = row["last_used_at"],
revoked_at = row["revoked_at"],
revoked_reason = row["revoked_reason"],
total_calls = row["total_calls"],
)

View File

@@ -136,7 +136,7 @@ class TeamSpace:
class CollaborationManager:
"""协作管理主类"""
def __init__(self, db_manager=None):
def __init__(self, db_manager = None) -> None:
self.db = db_manager
self._shares_cache: dict[str, ProjectShare] = {}
self._comments_cache: dict[str, list[Comment]] = {}
@@ -161,26 +161,26 @@ class CollaborationManager:
now = datetime.now().isoformat()
expires_at = None
if expires_in_days:
expires_at = (datetime.now() + timedelta(days=expires_in_days)).isoformat()
expires_at = (datetime.now() + timedelta(days = expires_in_days)).isoformat()
password_hash = None
if password:
password_hash = hashlib.sha256(password.encode()).hexdigest()
share = ProjectShare(
id=share_id,
project_id=project_id,
token=token,
permission=permission,
created_by=created_by,
created_at=now,
expires_at=expires_at,
max_uses=max_uses,
use_count=0,
password_hash=password_hash,
is_active=True,
allow_download=allow_download,
allow_export=allow_export,
id = share_id,
project_id = project_id,
token = token,
permission = permission,
created_by = created_by,
created_at = now,
expires_at = expires_at,
max_uses = max_uses,
use_count = 0,
password_hash = password_hash,
is_active = True,
allow_download = allow_download,
allow_export = allow_export,
)
# 保存到数据库
@@ -263,7 +263,7 @@ class CollaborationManager:
"""
SELECT * FROM project_shares WHERE token = ?
""",
(token,),
(token, ),
)
row = cursor.fetchone()
@@ -271,19 +271,19 @@ class CollaborationManager:
return None
return ProjectShare(
id=row[0],
project_id=row[1],
token=row[2],
permission=row[3],
created_by=row[4],
created_at=row[5],
expires_at=row[6],
max_uses=row[7],
use_count=row[8],
password_hash=row[9],
is_active=bool(row[10]),
allow_download=bool(row[11]),
allow_export=bool(row[12]),
id = row[0],
project_id = row[1],
token = row[2],
permission = row[3],
created_by = row[4],
created_at = row[5],
expires_at = row[6],
max_uses = row[7],
use_count = row[8],
password_hash = row[9],
is_active = bool(row[10]),
allow_download = bool(row[11]),
allow_export = bool(row[12]),
)
def increment_share_usage(self, token: str) -> None:
@@ -300,7 +300,7 @@ class CollaborationManager:
SET use_count = use_count + 1
WHERE token = ?
""",
(token,),
(token, ),
)
self.db.conn.commit()
@@ -314,7 +314,7 @@ class CollaborationManager:
SET is_active = 0
WHERE id = ?
""",
(share_id,),
(share_id, ),
)
self.db.conn.commit()
return cursor.rowcount > 0
@@ -332,26 +332,26 @@ class CollaborationManager:
WHERE project_id = ?
ORDER BY created_at DESC
""",
(project_id,),
(project_id, ),
)
shares = []
for row in cursor.fetchall():
shares.append(
ProjectShare(
id=row[0],
project_id=row[1],
token=row[2],
permission=row[3],
created_by=row[4],
created_at=row[5],
expires_at=row[6],
max_uses=row[7],
use_count=row[8],
password_hash=row[9],
is_active=bool(row[10]),
allow_download=bool(row[11]),
allow_export=bool(row[12]),
id = row[0],
project_id = row[1],
token = row[2],
permission = row[3],
created_by = row[4],
created_at = row[5],
expires_at = row[6],
max_uses = row[7],
use_count = row[8],
password_hash = row[9],
is_active = bool(row[10]),
allow_download = bool(row[11]),
allow_export = bool(row[12]),
)
)
return shares
@@ -375,21 +375,21 @@ class CollaborationManager:
now = datetime.now().isoformat()
comment = Comment(
id=comment_id,
project_id=project_id,
target_type=target_type,
target_id=target_id,
parent_id=parent_id,
author=author,
author_name=author_name,
content=content,
created_at=now,
updated_at=now,
resolved=False,
resolved_by=None,
resolved_at=None,
mentions=mentions or [],
attachments=attachments or [],
id = comment_id,
project_id = project_id,
target_type = target_type,
target_id = target_id,
parent_id = parent_id,
author = author,
author_name = author_name,
content = content,
created_at = now,
updated_at = now,
resolved = False,
resolved_by = None,
resolved_at = None,
mentions = mentions or [],
attachments = attachments or [],
)
if self.db:
@@ -469,21 +469,21 @@ class CollaborationManager:
def _row_to_comment(self, row) -> Comment:
"""将数据库行转换为Comment对象"""
return Comment(
id=row[0],
project_id=row[1],
target_type=row[2],
target_id=row[3],
parent_id=row[4],
author=row[5],
author_name=row[6],
content=row[7],
created_at=row[8],
updated_at=row[9],
resolved=bool(row[10]),
resolved_by=row[11],
resolved_at=row[12],
mentions=json.loads(row[13]) if row[13] else [],
attachments=json.loads(row[14]) if row[14] else [],
id = row[0],
project_id = row[1],
target_type = row[2],
target_id = row[3],
parent_id = row[4],
author = row[5],
author_name = row[6],
content = row[7],
created_at = row[8],
updated_at = row[9],
resolved = bool(row[10]),
resolved_by = row[11],
resolved_at = row[12],
mentions = json.loads(row[13]) if row[13] else [],
attachments = json.loads(row[14]) if row[14] else [],
)
def update_comment(self, comment_id: str, content: str, updated_by: str) -> Comment | None:
@@ -510,7 +510,7 @@ class CollaborationManager:
def _get_comment_by_id(self, comment_id: str) -> Comment | None:
"""根据ID获取评论"""
cursor = self.db.conn.cursor()
cursor.execute("SELECT * FROM comments WHERE id = ?", (comment_id,))
cursor.execute("SELECT * FROM comments WHERE id = ?", (comment_id, ))
row = cursor.fetchone()
if row:
return self._row_to_comment(row)
@@ -597,22 +597,22 @@ class CollaborationManager:
now = datetime.now().isoformat()
record = ChangeRecord(
id=record_id,
project_id=project_id,
change_type=change_type,
entity_type=entity_type,
entity_id=entity_id,
entity_name=entity_name,
changed_by=changed_by,
changed_by_name=changed_by_name,
changed_at=now,
old_value=old_value,
new_value=new_value,
description=description,
session_id=session_id,
reverted=False,
reverted_at=None,
reverted_by=None,
id = record_id,
project_id = project_id,
change_type = change_type,
entity_type = entity_type,
entity_id = entity_id,
entity_name = entity_name,
changed_by = changed_by,
changed_by_name = changed_by_name,
changed_at = now,
old_value = old_value,
new_value = new_value,
description = description,
session_id = session_id,
reverted = False,
reverted_at = None,
reverted_by = None,
)
if self.db:
@@ -705,22 +705,22 @@ class CollaborationManager:
def _row_to_change_record(self, row) -> ChangeRecord:
"""将数据库行转换为ChangeRecord对象"""
return ChangeRecord(
id=row[0],
project_id=row[1],
change_type=row[2],
entity_type=row[3],
entity_id=row[4],
entity_name=row[5],
changed_by=row[6],
changed_by_name=row[7],
changed_at=row[8],
old_value=json.loads(row[9]) if row[9] else None,
new_value=json.loads(row[10]) if row[10] else None,
description=row[11],
session_id=row[12],
reverted=bool(row[13]),
reverted_at=row[14],
reverted_by=row[15],
id = row[0],
project_id = row[1],
change_type = row[2],
entity_type = row[3],
entity_id = row[4],
entity_name = row[5],
changed_by = row[6],
changed_by_name = row[7],
changed_at = row[8],
old_value = json.loads(row[9]) if row[9] else None,
new_value = json.loads(row[10]) if row[10] else None,
description = row[11],
session_id = row[12],
reverted = bool(row[13]),
reverted_at = row[14],
reverted_by = row[15],
)
def get_entity_version_history(self, entity_type: str, entity_id: str) -> list[ChangeRecord]:
@@ -773,7 +773,7 @@ class CollaborationManager:
"""
SELECT COUNT(*) FROM change_history WHERE project_id = ?
""",
(project_id,),
(project_id, ),
)
total_changes = cursor.fetchone()[0]
@@ -783,7 +783,7 @@ class CollaborationManager:
SELECT change_type, COUNT(*) FROM change_history
WHERE project_id = ? GROUP BY change_type
""",
(project_id,),
(project_id, ),
)
type_counts = {row[0]: row[1] for row in cursor.fetchall()}
@@ -793,7 +793,7 @@ class CollaborationManager:
SELECT entity_type, COUNT(*) FROM change_history
WHERE project_id = ? GROUP BY entity_type
""",
(project_id,),
(project_id, ),
)
entity_type_counts = {row[0]: row[1] for row in cursor.fetchall()}
@@ -806,7 +806,7 @@ class CollaborationManager:
ORDER BY count DESC
LIMIT 5
""",
(project_id,),
(project_id, ),
)
top_contributors = [{"name": row[0], "changes": row[1]} for row in cursor.fetchall()]
@@ -838,16 +838,16 @@ class CollaborationManager:
permissions = self._get_default_permissions(role)
member = TeamMember(
id=member_id,
project_id=project_id,
user_id=user_id,
user_name=user_name,
user_email=user_email,
role=role,
joined_at=now,
invited_by=invited_by,
last_active_at=None,
permissions=permissions,
id = member_id,
project_id = project_id,
user_id = user_id,
user_name = user_name,
user_email = user_email,
role = role,
joined_at = now,
invited_by = invited_by,
last_active_at = None,
permissions = permissions,
)
if self.db:
@@ -902,7 +902,7 @@ class CollaborationManager:
SELECT * FROM team_members WHERE project_id = ?
ORDER BY joined_at ASC
""",
(project_id,),
(project_id, ),
)
members = []
@@ -913,16 +913,16 @@ class CollaborationManager:
def _row_to_team_member(self, row) -> TeamMember:
"""将数据库行转换为TeamMember对象"""
return TeamMember(
id=row[0],
project_id=row[1],
user_id=row[2],
user_name=row[3],
user_email=row[4],
role=row[5],
joined_at=row[6],
invited_by=row[7],
last_active_at=row[8],
permissions=json.loads(row[9]) if row[9] else [],
id = row[0],
project_id = row[1],
user_id = row[2],
user_name = row[3],
user_email = row[4],
role = row[5],
joined_at = row[6],
invited_by = row[7],
last_active_at = row[8],
permissions = json.loads(row[9]) if row[9] else [],
)
def update_member_role(self, member_id: str, new_role: str, updated_by: str) -> bool:
@@ -949,7 +949,7 @@ class CollaborationManager:
return False
cursor = self.db.conn.cursor()
cursor.execute("DELETE FROM team_members WHERE id = ?", (member_id,))
cursor.execute("DELETE FROM team_members WHERE id = ?", (member_id, ))
self.db.conn.commit()
return cursor.rowcount > 0
@@ -996,7 +996,7 @@ class CollaborationManager:
_collaboration_manager = None
def get_collaboration_manager(db_manager=None) -> None:
def get_collaboration_manager(db_manager = None) -> None:
"""获取协作管理器单例"""
global _collaboration_manager
if _collaboration_manager is None:

View File

@@ -41,7 +41,7 @@ class Entity:
created_at: str = ""
updated_at: str = ""
def __post_init__(self):
def __post_init__(self) -> None:
if self.aliases is None:
self.aliases = []
if self.attributes is None:
@@ -64,7 +64,7 @@ class AttributeTemplate:
created_at: str = ""
updated_at: str = ""
def __post_init__(self):
def __post_init__(self) -> None:
if self.options is None:
self.options = []
@@ -85,7 +85,7 @@ class EntityAttribute:
created_at: str = ""
updated_at: str = ""
def __post_init__(self):
def __post_init__(self) -> None:
if self.options is None:
self.options = []
@@ -116,12 +116,12 @@ class EntityMention:
class DatabaseManager:
def __init__(self, db_path: str = DB_PATH):
def __init__(self, db_path: str = DB_PATH) -> None:
self.db_path = db_path
os.makedirs(os.path.dirname(db_path), exist_ok=True)
os.makedirs(os.path.dirname(db_path), exist_ok = True)
self.init_db()
def get_conn(self):
def get_conn(self) -> None:
conn = sqlite3.connect(self.db_path)
conn.row_factory = sqlite3.Row
return conn
@@ -149,12 +149,12 @@ class DatabaseManager:
conn.commit()
conn.close()
return Project(
id=project_id, name=name, description=description, created_at=now, updated_at=now
id = project_id, name = name, description = description, created_at = now, updated_at = now
)
def get_project(self, project_id: str) -> Project | None:
conn = self.get_conn()
row = conn.execute("SELECT * FROM projects WHERE id = ?", (project_id,)).fetchone()
row = conn.execute("SELECT * FROM projects WHERE id = ?", (project_id, )).fetchone()
conn.close()
if row:
return Project(**dict(row))
@@ -226,8 +226,8 @@ class DatabaseManager:
"""合并两个实体"""
conn = self.get_conn()
target = conn.execute("SELECT * FROM entities WHERE id = ?", (target_id,)).fetchone()
source = conn.execute("SELECT * FROM entities WHERE id = ?", (source_id,)).fetchone()
target = conn.execute("SELECT * FROM entities WHERE id = ?", (target_id, )).fetchone()
source = conn.execute("SELECT * FROM entities WHERE id = ?", (source_id, )).fetchone()
if not target or not source:
conn.close()
@@ -252,7 +252,7 @@ class DatabaseManager:
"UPDATE entity_relations SET target_entity_id = ? WHERE target_entity_id = ?",
(target_id, source_id),
)
conn.execute("DELETE FROM entities WHERE id = ?", (source_id,))
conn.execute("DELETE FROM entities WHERE id = ?", (source_id, ))
conn.commit()
conn.close()
@@ -260,7 +260,7 @@ class DatabaseManager:
def get_entity(self, entity_id: str) -> Entity | None:
conn = self.get_conn()
row = conn.execute("SELECT * FROM entities WHERE id = ?", (entity_id,)).fetchone()
row = conn.execute("SELECT * FROM entities WHERE id = ?", (entity_id, )).fetchone()
conn.close()
if row:
data = dict(row)
@@ -271,7 +271,7 @@ class DatabaseManager:
def list_project_entities(self, project_id: str) -> list[Entity]:
conn = self.get_conn()
rows = conn.execute(
"SELECT * FROM entities WHERE project_id = ? ORDER BY updated_at DESC", (project_id,)
"SELECT * FROM entities WHERE project_id = ? ORDER BY updated_at DESC", (project_id, )
).fetchall()
conn.close()
@@ -316,13 +316,13 @@ class DatabaseManager:
def delete_entity(self, entity_id: str) -> None:
"""删除实体及其关联数据"""
conn = self.get_conn()
conn.execute("DELETE FROM entity_mentions WHERE entity_id = ?", (entity_id,))
conn.execute("DELETE FROM entity_mentions WHERE entity_id = ?", (entity_id, ))
conn.execute(
"DELETE FROM entity_relations WHERE source_entity_id = ? OR target_entity_id = ?",
(entity_id, entity_id),
)
conn.execute("DELETE FROM entity_attributes WHERE entity_id = ?", (entity_id,))
conn.execute("DELETE FROM entities WHERE id = ?", (entity_id,))
conn.execute("DELETE FROM entity_attributes WHERE entity_id = ?", (entity_id, ))
conn.execute("DELETE FROM entities WHERE id = ?", (entity_id, ))
conn.commit()
conn.close()
@@ -352,7 +352,7 @@ class DatabaseManager:
conn = self.get_conn()
rows = conn.execute(
"SELECT * FROM entity_mentions WHERE entity_id = ? ORDER BY transcript_id, start_pos",
(entity_id,),
(entity_id, ),
).fetchall()
conn.close()
return [EntityMention(**dict(r)) for r in rows]
@@ -366,7 +366,7 @@ class DatabaseManager:
filename: str,
full_text: str,
transcript_type: str = "audio",
):
) -> None:
conn = self.get_conn()
now = datetime.now().isoformat()
conn.execute(
@@ -380,14 +380,14 @@ class DatabaseManager:
def get_transcript(self, transcript_id: str) -> dict | None:
conn = self.get_conn()
row = conn.execute("SELECT * FROM transcripts WHERE id = ?", (transcript_id,)).fetchone()
row = conn.execute("SELECT * FROM transcripts WHERE id = ?", (transcript_id, )).fetchone()
conn.close()
return dict(row) if row else None
def list_project_transcripts(self, project_id: str) -> list[dict]:
conn = self.get_conn()
rows = conn.execute(
"SELECT * FROM transcripts WHERE project_id = ? ORDER BY created_at DESC", (project_id,)
"SELECT * FROM transcripts WHERE project_id = ? ORDER BY created_at DESC", (project_id, )
).fetchall()
conn.close()
return [dict(r) for r in rows]
@@ -400,7 +400,7 @@ class DatabaseManager:
(full_text, now, transcript_id),
)
conn.commit()
row = conn.execute("SELECT * FROM transcripts WHERE id = ?", (transcript_id,)).fetchone()
row = conn.execute("SELECT * FROM transcripts WHERE id = ?", (transcript_id, )).fetchone()
conn.close()
return dict(row) if row else None
@@ -414,7 +414,7 @@ class DatabaseManager:
relation_type: str = "related",
evidence: str = "",
transcript_id: str = "",
):
) -> None:
conn = self.get_conn()
relation_id = str(uuid.uuid4())[:UUID_LENGTH]
now = datetime.now().isoformat()
@@ -453,7 +453,7 @@ class DatabaseManager:
conn = self.get_conn()
rows = conn.execute(
"SELECT * FROM entity_relations WHERE project_id = ? ORDER BY created_at DESC",
(project_id,),
(project_id, ),
).fetchall()
conn.close()
return [dict(r) for r in rows]
@@ -475,13 +475,13 @@ class DatabaseManager:
conn.execute(query, values)
conn.commit()
row = conn.execute("SELECT * FROM entity_relations WHERE id = ?", (relation_id,)).fetchone()
row = conn.execute("SELECT * FROM entity_relations WHERE id = ?", (relation_id, )).fetchone()
conn.close()
return dict(row) if row else None
def delete_relation(self, relation_id: str):
def delete_relation(self, relation_id: str) -> None:
conn = self.get_conn()
conn.execute("DELETE FROM entity_relations WHERE id = ?", (relation_id,))
conn.execute("DELETE FROM entity_relations WHERE id = ?", (relation_id, ))
conn.commit()
conn.close()
@@ -495,7 +495,7 @@ class DatabaseManager:
if existing:
conn.execute(
"UPDATE glossary SET frequency = frequency + 1 WHERE id = ?", (existing["id"],)
"UPDATE glossary SET frequency = frequency + 1 WHERE id = ?", (existing["id"], )
)
conn.commit()
conn.close()
@@ -515,14 +515,14 @@ class DatabaseManager:
def list_glossary(self, project_id: str) -> list[dict]:
conn = self.get_conn()
rows = conn.execute(
"SELECT * FROM glossary WHERE project_id = ? ORDER BY frequency DESC", (project_id,)
"SELECT * FROM glossary WHERE project_id = ? ORDER BY frequency DESC", (project_id, )
).fetchall()
conn.close()
return [dict(r) for r in rows]
def delete_glossary_term(self, term_id: str):
def delete_glossary_term(self, term_id: str) -> None:
conn = self.get_conn()
conn.execute("DELETE FROM glossary WHERE id = ?", (term_id,))
conn.execute("DELETE FROM glossary WHERE id = ?", (term_id, ))
conn.commit()
conn.close()
@@ -539,14 +539,14 @@ class DatabaseManager:
JOIN entities t ON r.target_entity_id = t.id
LEFT JOIN transcripts tr ON r.transcript_id = tr.id
WHERE r.id = ?""",
(relation_id,),
(relation_id, ),
).fetchone()
conn.close()
return dict(row) if row else None
def get_entity_with_mentions(self, entity_id: str) -> dict | None:
conn = self.get_conn()
entity_row = conn.execute("SELECT * FROM entities WHERE id = ?", (entity_id,)).fetchone()
entity_row = conn.execute("SELECT * FROM entities WHERE id = ?", (entity_id, )).fetchone()
if not entity_row:
conn.close()
return None
@@ -559,7 +559,7 @@ class DatabaseManager:
FROM entity_mentions m
JOIN transcripts t ON m.transcript_id = t.id
WHERE m.entity_id = ? ORDER BY t.created_at, m.start_pos""",
(entity_id,),
(entity_id, ),
).fetchall()
entity["mentions"] = [dict(m) for m in mentions]
entity["mention_count"] = len(mentions)
@@ -598,24 +598,24 @@ class DatabaseManager:
def get_project_summary(self, project_id: str) -> dict:
conn = self.get_conn()
project = conn.execute("SELECT * FROM projects WHERE id = ?", (project_id,)).fetchone()
project = conn.execute("SELECT * FROM projects WHERE id = ?", (project_id, )).fetchone()
entity_count = conn.execute(
"SELECT COUNT(*) as count FROM entities WHERE project_id = ?", (project_id,)
"SELECT COUNT(*) as count FROM entities WHERE project_id = ?", (project_id, )
).fetchone()["count"]
transcript_count = conn.execute(
"SELECT COUNT(*) as count FROM transcripts WHERE project_id = ?", (project_id,)
"SELECT COUNT(*) as count FROM transcripts WHERE project_id = ?", (project_id, )
).fetchone()["count"]
relation_count = conn.execute(
"SELECT COUNT(*) as count FROM entity_relations WHERE project_id = ?", (project_id,)
"SELECT COUNT(*) as count FROM entity_relations WHERE project_id = ?", (project_id, )
).fetchone()["count"]
recent_transcripts = conn.execute(
"""SELECT filename, full_text, created_at FROM transcripts
WHERE project_id = ? ORDER BY created_at DESC LIMIT 5""",
(project_id,),
(project_id, ),
).fetchall()
top_entities = conn.execute(
@@ -624,7 +624,7 @@ class DatabaseManager:
LEFT JOIN entity_mentions m ON e.id = m.entity_id
WHERE e.project_id = ?
GROUP BY e.id ORDER BY mention_count DESC LIMIT 10""",
(project_id,),
(project_id, ),
).fetchall()
conn.close()
@@ -645,7 +645,7 @@ class DatabaseManager:
) -> str:
conn = self.get_conn()
row = conn.execute(
"SELECT full_text FROM transcripts WHERE id = ?", (transcript_id,)
"SELECT full_text FROM transcripts WHERE id = ?", (transcript_id, )
).fetchone()
conn.close()
if not row:
@@ -708,7 +708,7 @@ class DatabaseManager:
)
conn.close()
timeline_events.sort(key=lambda x: x["event_date"])
timeline_events.sort(key = lambda x: x["event_date"])
return timeline_events
def get_entity_timeline_summary(self, project_id: str) -> dict:
@@ -719,7 +719,7 @@ class DatabaseManager:
FROM entity_mentions m
JOIN transcripts t ON m.transcript_id = t.id
WHERE t.project_id = ? GROUP BY DATE(t.created_at) ORDER BY date""",
(project_id,),
(project_id, ),
).fetchall()
entity_stats = conn.execute(
@@ -731,7 +731,7 @@ class DatabaseManager:
LEFT JOIN transcripts t ON m.transcript_id = t.id
WHERE e.project_id = ?
GROUP BY e.id ORDER BY mention_count DESC LIMIT 20""",
(project_id,),
(project_id, ),
).fetchall()
conn.close()
@@ -772,7 +772,7 @@ class DatabaseManager:
def get_attribute_template(self, template_id: str) -> AttributeTemplate | None:
conn = self.get_conn()
row = conn.execute(
"SELECT * FROM attribute_templates WHERE id = ?", (template_id,)
"SELECT * FROM attribute_templates WHERE id = ?", (template_id, )
).fetchone()
conn.close()
if row:
@@ -786,7 +786,7 @@ class DatabaseManager:
rows = conn.execute(
"""SELECT * FROM attribute_templates WHERE project_id = ?
ORDER BY sort_order, created_at""",
(project_id,),
(project_id, ),
).fetchall()
conn.close()
@@ -830,9 +830,9 @@ class DatabaseManager:
conn.close()
return self.get_attribute_template(template_id)
def delete_attribute_template(self, template_id: str):
def delete_attribute_template(self, template_id: str) -> None:
conn = self.get_conn()
conn.execute("DELETE FROM attribute_templates WHERE id = ?", (template_id,))
conn.execute("DELETE FROM attribute_templates WHERE id = ?", (template_id, ))
conn.commit()
conn.close()
@@ -905,7 +905,7 @@ class DatabaseManager:
FROM entity_attributes ea
LEFT JOIN attribute_templates at ON ea.template_id = at.id
WHERE ea.entity_id = ? ORDER BY ea.created_at""",
(entity_id,),
(entity_id, ),
).fetchall()
conn.close()
return [EntityAttribute(**dict(r)) for r in rows]
@@ -927,7 +927,7 @@ class DatabaseManager:
def delete_entity_attribute(
self, entity_id: str, template_id: str, changed_by: str = "system", change_reason: str = ""
):
) -> None:
conn = self.get_conn()
old_row = conn.execute(
"""SELECT value FROM entity_attributes
@@ -973,7 +973,7 @@ class DatabaseManager:
conditions.append("ah.template_id = ?")
params.append(template_id)
where_clause = " AND ".join(conditions) if conditions else "1=1"
where_clause = " AND ".join(conditions) if conditions else "1 = 1"
rows = conn.execute(
f"""SELECT ah.*
@@ -997,7 +997,7 @@ class DatabaseManager:
return []
conn = self.get_conn()
placeholders = ",".join(["?" for _ in entity_ids])
placeholders = ", ".join(["?" for _ in entity_ids])
rows = conn.execute(
f"""SELECT ea.*, at.name as template_name
FROM entity_attributes ea
@@ -1075,7 +1075,7 @@ class DatabaseManager:
def get_video(self, video_id: str) -> dict | None:
"""获取视频信息"""
conn = self.get_conn()
row = conn.execute("SELECT * FROM videos WHERE id = ?", (video_id,)).fetchone()
row = conn.execute("SELECT * FROM videos WHERE id = ?", (video_id, )).fetchone()
conn.close()
if row:
@@ -1094,7 +1094,7 @@ class DatabaseManager:
"""获取项目的所有视频"""
conn = self.get_conn()
rows = conn.execute(
"SELECT * FROM videos WHERE project_id = ? ORDER BY created_at DESC", (project_id,)
"SELECT * FROM videos WHERE project_id = ? ORDER BY created_at DESC", (project_id, )
).fetchall()
conn.close()
@@ -1149,7 +1149,7 @@ class DatabaseManager:
"""获取视频的所有帧"""
conn = self.get_conn()
rows = conn.execute(
"""SELECT * FROM video_frames WHERE video_id = ? ORDER BY timestamp""", (video_id,)
"""SELECT * FROM video_frames WHERE video_id = ? ORDER BY timestamp""", (video_id, )
).fetchall()
conn.close()
@@ -1201,7 +1201,7 @@ class DatabaseManager:
def get_image(self, image_id: str) -> dict | None:
"""获取图片信息"""
conn = self.get_conn()
row = conn.execute("SELECT * FROM images WHERE id = ?", (image_id,)).fetchone()
row = conn.execute("SELECT * FROM images WHERE id = ?", (image_id, )).fetchone()
conn.close()
if row:
@@ -1219,7 +1219,7 @@ class DatabaseManager:
"""获取项目的所有图片"""
conn = self.get_conn()
rows = conn.execute(
"SELECT * FROM images WHERE project_id = ? ORDER BY created_at DESC", (project_id,)
"SELECT * FROM images WHERE project_id = ? ORDER BY created_at DESC", (project_id, )
).fetchall()
conn.close()
@@ -1279,7 +1279,7 @@ class DatabaseManager:
FROM multimodal_mentions m
JOIN entities e ON m.entity_id = e.id
WHERE m.entity_id = ? ORDER BY m.created_at DESC""",
(entity_id,),
(entity_id, ),
).fetchall()
conn.close()
return [dict(r) for r in rows]
@@ -1303,7 +1303,7 @@ class DatabaseManager:
FROM multimodal_mentions m
JOIN entities e ON m.entity_id = e.id
WHERE m.project_id = ? ORDER BY m.created_at DESC""",
(project_id,),
(project_id, ),
).fetchall()
conn.close()
@@ -1377,13 +1377,13 @@ class DatabaseManager:
# 视频数量
row = conn.execute(
"SELECT COUNT(*) as count FROM videos WHERE project_id = ?", (project_id,)
"SELECT COUNT(*) as count FROM videos WHERE project_id = ?", (project_id, )
).fetchone()
stats["video_count"] = row["count"]
# 图片数量
row = conn.execute(
"SELECT COUNT(*) as count FROM images WHERE project_id = ?", (project_id,)
"SELECT COUNT(*) as count FROM images WHERE project_id = ?", (project_id, )
).fetchone()
stats["image_count"] = row["count"]
@@ -1391,7 +1391,7 @@ class DatabaseManager:
row = conn.execute(
"""SELECT COUNT(DISTINCT entity_id) as count
FROM multimodal_mentions WHERE project_id = ?""",
(project_id,),
(project_id, ),
).fetchone()
stats["multimodal_entity_count"] = row["count"]
@@ -1399,7 +1399,7 @@ class DatabaseManager:
row = conn.execute(
"""SELECT COUNT(*) as count FROM multimodal_entity_links
WHERE entity_id IN (SELECT id FROM entities WHERE project_id = ?)""",
(project_id,),
(project_id, ),
).fetchone()
stats["cross_modal_links"] = row["count"]

File diff suppressed because it is too large Load Diff

View File

@@ -11,7 +11,7 @@ import os
class DocumentProcessor:
"""文档处理器 - 提取 PDF/DOCX 文本"""
def __init__(self):
def __init__(self) -> None:
self.supported_formats = {
".pdf": self._extract_pdf,
".docx": self._extract_docx,
@@ -123,7 +123,7 @@ class DocumentProcessor:
continue
# 如果都失败了,使用 latin-1 并忽略错误
return content.decode("latin-1", errors="ignore")
return content.decode("latin-1", errors = "ignore")
def _clean_text(self, text: str) -> str:
"""清理提取的文本"""
@@ -173,7 +173,7 @@ class SimpleTextExtractor:
except UnicodeDecodeError:
continue
return content.decode("latin-1", errors="ignore")
return content.decode("latin-1", errors = "ignore")
if __name__ == "__main__":

View File

@@ -329,7 +329,7 @@ class EnterpriseManager:
],
}
def __init__(self, db_path: str = "insightflow.db"):
def __init__(self, db_path: str = "insightflow.db") -> None:
self.db_path = db_path
self._init_db()
@@ -610,30 +610,30 @@ class EnterpriseManager:
attribute_mapping = self.DEFAULT_ATTRIBUTE_MAPPING[SSOProvider(provider)]
config = SSOConfig(
id=config_id,
tenant_id=tenant_id,
provider=provider,
status=SSOStatus.PENDING.value,
entity_id=entity_id,
sso_url=sso_url,
slo_url=slo_url,
certificate=certificate,
metadata_url=metadata_url,
metadata_xml=metadata_xml,
client_id=client_id,
client_secret=client_secret,
authorization_url=authorization_url,
token_url=token_url,
userinfo_url=userinfo_url,
scopes=scopes or ["openid", "email", "profile"],
attribute_mapping=attribute_mapping or {},
auto_provision=auto_provision,
default_role=default_role,
domain_restriction=domain_restriction or [],
created_at=now,
updated_at=now,
last_tested_at=None,
last_error=None,
id = config_id,
tenant_id = tenant_id,
provider = provider,
status = SSOStatus.PENDING.value,
entity_id = entity_id,
sso_url = sso_url,
slo_url = slo_url,
certificate = certificate,
metadata_url = metadata_url,
metadata_xml = metadata_xml,
client_id = client_id,
client_secret = client_secret,
authorization_url = authorization_url,
token_url = token_url,
userinfo_url = userinfo_url,
scopes = scopes or ["openid", "email", "profile"],
attribute_mapping = attribute_mapping or {},
auto_provision = auto_provision,
default_role = default_role,
domain_restriction = domain_restriction or [],
created_at = now,
updated_at = now,
last_tested_at = None,
last_error = None,
)
cursor = conn.cursor()
@@ -688,7 +688,7 @@ class EnterpriseManager:
conn = self._get_connection()
try:
cursor = conn.cursor()
cursor.execute("SELECT * FROM sso_configs WHERE id = ?", (config_id,))
cursor.execute("SELECT * FROM sso_configs WHERE id = ?", (config_id, ))
row = cursor.fetchone()
if row:
@@ -722,7 +722,7 @@ class EnterpriseManager:
WHERE tenant_id = ? AND status = 'active'
ORDER BY created_at DESC LIMIT 1
""",
(tenant_id,),
(tenant_id, ),
)
row = cursor.fetchone()
@@ -802,7 +802,7 @@ class EnterpriseManager:
conn = self._get_connection()
try:
cursor = conn.cursor()
cursor.execute("DELETE FROM sso_configs WHERE id = ?", (config_id,))
cursor.execute("DELETE FROM sso_configs WHERE id = ?", (config_id, ))
conn.commit()
return cursor.rowcount > 0
finally:
@@ -818,7 +818,7 @@ class EnterpriseManager:
SELECT * FROM sso_configs WHERE tenant_id = ?
ORDER BY created_at DESC
""",
(tenant_id,),
(tenant_id, ),
)
rows = cursor.fetchall()
@@ -841,30 +841,30 @@ class EnterpriseManager:
# 生成 X.509 证书(简化实现,实际应该生成真实的密钥对)
cert = config.certificate or self._generate_self_signed_cert()
metadata = f"""<?xml version="1.0" encoding="UTF-8"?>
<md:EntityDescriptor xmlns:md="urn:oasis:names:tc:SAML:2.0:metadata"
entityID="{sp_entity_id}">
<md:SPSSODescriptor AuthnRequestsSigned="true"
WantAssertionsSigned="true"
protocolSupportEnumeration="urn:oasis:names:tc:SAML:2.0:protocol">
<md:KeyDescriptor use="signing">
<ds:KeyInfo xmlns:ds="http://www.w3.org/2000/09/xmldsig#">
metadata = f"""<?xml version = "1.0" encoding = "UTF-8"?>
<md:EntityDescriptor xmlns:md = "urn:oasis:names:tc:SAML:2.0:metadata"
entityID = "{sp_entity_id}">
<md:SPSSODescriptor AuthnRequestsSigned = "true"
WantAssertionsSigned = "true"
protocolSupportEnumeration = "urn:oasis:names:tc:SAML:2.0:protocol">
<md:KeyDescriptor use = "signing">
<ds:KeyInfo xmlns:ds = "http://www.w3.org/2000/09/xmldsig#">
<ds:X509Data>
<ds:X509Certificate>{cert}</ds:X509Certificate>
</ds:X509Data>
</ds:KeyInfo>
</md:KeyDescriptor>
<md:SingleLogoutService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect"
Location="{slo_url}"/>
<md:AssertionConsumerService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST"
Location="{acs_url}"
index="0"
isDefault="true"/>
<md:SingleLogoutService Binding = "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect"
Location = "{slo_url}"/>
<md:AssertionConsumerService Binding = "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST"
Location = "{acs_url}"
index = "0"
isDefault = "true"/>
</md:SPSSODescriptor>
<md:Organization>
<md:OrganizationName xml:lang="en">InsightFlow</md:OrganizationName>
<md:OrganizationDisplayName xml:lang="en">InsightFlow</md:OrganizationDisplayName>
<md:OrganizationURL xml:lang="en">{base_url}</md:OrganizationURL>
<md:OrganizationName xml:lang = "en">InsightFlow</md:OrganizationName>
<md:OrganizationDisplayName xml:lang = "en">InsightFlow</md:OrganizationDisplayName>
<md:OrganizationURL xml:lang = "en">{base_url}</md:OrganizationURL>
</md:Organization>
</md:EntityDescriptor>"""
@@ -878,18 +878,18 @@ class EnterpriseManager:
try:
request_id = f"_{uuid.uuid4().hex}"
now = datetime.now()
expires = now + timedelta(minutes=10)
expires = now + timedelta(minutes = 10)
auth_request = SAMLAuthRequest(
id=str(uuid.uuid4()),
tenant_id=tenant_id,
sso_config_id=config_id,
request_id=request_id,
relay_state=relay_state,
created_at=now,
expires_at=expires,
used=False,
used_at=None,
id = str(uuid.uuid4()),
tenant_id = tenant_id,
sso_config_id = config_id,
request_id = request_id,
relay_state = relay_state,
created_at = now,
expires_at = expires,
used = False,
used_at = None,
)
cursor = conn.cursor()
@@ -926,7 +926,7 @@ class EnterpriseManager:
"""
SELECT * FROM saml_auth_requests WHERE request_id = ?
""",
(request_id,),
(request_id, ),
)
row = cursor.fetchone()
@@ -949,17 +949,17 @@ class EnterpriseManager:
attributes = self._parse_saml_response(saml_response)
auth_response = SAMLAuthResponse(
id=str(uuid.uuid4()),
request_id=request_id,
tenant_id="", # 从 request 获取
user_id=None,
email=attributes.get("email"),
name=attributes.get("name"),
attributes=attributes,
session_index=attributes.get("session_index"),
processed=False,
processed_at=None,
created_at=datetime.now(),
id = str(uuid.uuid4()),
request_id = request_id,
tenant_id = "", # 从 request 获取
user_id = None,
email = attributes.get("email"),
name = attributes.get("name"),
attributes = attributes,
session_index = attributes.get("session_index"),
processed = False,
processed_at = None,
created_at = datetime.now(),
)
cursor = conn.cursor()
@@ -1028,21 +1028,21 @@ class EnterpriseManager:
now = datetime.now()
config = SCIMConfig(
id=config_id,
tenant_id=tenant_id,
provider=provider,
status="disabled",
scim_base_url=scim_base_url,
scim_token=scim_token,
sync_interval_minutes=sync_interval_minutes,
last_sync_at=None,
last_sync_status=None,
last_sync_error=None,
last_sync_users_count=0,
attribute_mapping=attribute_mapping or {},
sync_rules=sync_rules or {},
created_at=now,
updated_at=now,
id = config_id,
tenant_id = tenant_id,
provider = provider,
status = "disabled",
scim_base_url = scim_base_url,
scim_token = scim_token,
sync_interval_minutes = sync_interval_minutes,
last_sync_at = None,
last_sync_status = None,
last_sync_error = None,
last_sync_users_count = 0,
attribute_mapping = attribute_mapping or {},
sync_rules = sync_rules or {},
created_at = now,
updated_at = now,
)
cursor = conn.cursor()
@@ -1084,7 +1084,7 @@ class EnterpriseManager:
conn = self._get_connection()
try:
cursor = conn.cursor()
cursor.execute("SELECT * FROM scim_configs WHERE id = ?", (config_id,))
cursor.execute("SELECT * FROM scim_configs WHERE id = ?", (config_id, ))
row = cursor.fetchone()
if row:
@@ -1104,7 +1104,7 @@ class EnterpriseManager:
SELECT * FROM scim_configs WHERE tenant_id = ?
ORDER BY created_at DESC LIMIT 1
""",
(tenant_id,),
(tenant_id, ),
)
row = cursor.fetchone()
@@ -1325,28 +1325,28 @@ class EnterpriseManager:
now = datetime.now()
# 默认7天后过期
expires_at = now + timedelta(days=7)
expires_at = now + timedelta(days = 7)
export = AuditLogExport(
id=export_id,
tenant_id=tenant_id,
export_format=export_format,
start_date=start_date,
end_date=end_date,
filters=filters or {},
compliance_standard=compliance_standard,
status="pending",
file_path=None,
file_size=None,
record_count=None,
checksum=None,
downloaded_by=None,
downloaded_at=None,
expires_at=expires_at,
created_by=created_by,
created_at=now,
completed_at=None,
error_message=None,
id = export_id,
tenant_id = tenant_id,
export_format = export_format,
start_date = start_date,
end_date = end_date,
filters = filters or {},
compliance_standard = compliance_standard,
status = "pending",
file_path = None,
file_size = None,
record_count = None,
checksum = None,
downloaded_by = None,
downloaded_at = None,
expires_at = expires_at,
created_by = created_by,
created_at = now,
completed_at = None,
error_message = None,
)
cursor = conn.cursor()
@@ -1383,7 +1383,7 @@ class EnterpriseManager:
finally:
conn.close()
def process_audit_export(self, export_id: str, db_manager=None) -> AuditLogExport | None:
def process_audit_export(self, export_id: str, db_manager = None) -> AuditLogExport | None:
"""处理审计日志导出任务"""
export = self.get_audit_export(export_id)
if not export:
@@ -1398,7 +1398,7 @@ class EnterpriseManager:
UPDATE audit_log_exports SET status = 'processing'
WHERE id = ?
""",
(export_id,),
(export_id, ),
)
conn.commit()
@@ -1454,7 +1454,7 @@ class EnterpriseManager:
start_date: datetime,
end_date: datetime,
filters: dict[str, Any],
db_manager=None,
db_manager = None,
) -> list[dict[str, Any]]:
"""获取审计日志数据"""
if db_manager is None:
@@ -1488,26 +1488,26 @@ class EnterpriseManager:
import os
export_dir = "/tmp/insightflow/exports"
os.makedirs(export_dir, exist_ok=True)
os.makedirs(export_dir, exist_ok = True)
file_path = f"{export_dir}/audit_export_{export_id}.{format}"
if format == "json":
content = json.dumps(logs, ensure_ascii=False, indent=2)
with open(file_path, "w", encoding="utf-8") as f:
content = json.dumps(logs, ensure_ascii = False, indent = 2)
with open(file_path, "w", encoding = "utf-8") as f:
f.write(content)
elif format == "csv":
import csv
if logs:
with open(file_path, "w", newline="", encoding="utf-8") as f:
writer = csv.DictWriter(f, fieldnames=logs[0].keys())
with open(file_path, "w", newline = "", encoding = "utf-8") as f:
writer = csv.DictWriter(f, fieldnames = logs[0].keys())
writer.writeheader()
writer.writerows(logs)
else:
# 其他格式暂不支持
content = json.dumps(logs, ensure_ascii=False)
with open(file_path, "w", encoding="utf-8") as f:
content = json.dumps(logs, ensure_ascii = False)
with open(file_path, "w", encoding = "utf-8") as f:
f.write(content)
file_size = os.path.getsize(file_path)
@@ -1523,7 +1523,7 @@ class EnterpriseManager:
conn = self._get_connection()
try:
cursor = conn.cursor()
cursor.execute("SELECT * FROM audit_log_exports WHERE id = ?", (export_id,))
cursor.execute("SELECT * FROM audit_log_exports WHERE id = ?", (export_id, ))
row = cursor.fetchone()
if row:
@@ -1596,24 +1596,24 @@ class EnterpriseManager:
now = datetime.now()
policy = DataRetentionPolicy(
id=policy_id,
tenant_id=tenant_id,
name=name,
description=description,
resource_type=resource_type,
retention_days=retention_days,
action=action,
conditions=conditions or {},
auto_execute=auto_execute,
execute_at=execute_at,
notify_before_days=notify_before_days,
archive_location=archive_location,
archive_encryption=archive_encryption,
is_active=True,
last_executed_at=None,
last_execution_result=None,
created_at=now,
updated_at=now,
id = policy_id,
tenant_id = tenant_id,
name = name,
description = description,
resource_type = resource_type,
retention_days = retention_days,
action = action,
conditions = conditions or {},
auto_execute = auto_execute,
execute_at = execute_at,
notify_before_days = notify_before_days,
archive_location = archive_location,
archive_encryption = archive_encryption,
is_active = True,
last_executed_at = None,
last_execution_result = None,
created_at = now,
updated_at = now,
)
cursor = conn.cursor()
@@ -1661,7 +1661,7 @@ class EnterpriseManager:
conn = self._get_connection()
try:
cursor = conn.cursor()
cursor.execute("SELECT * FROM data_retention_policies WHERE id = ?", (policy_id,))
cursor.execute("SELECT * FROM data_retention_policies WHERE id = ?", (policy_id, ))
row = cursor.fetchone()
if row:
@@ -1758,7 +1758,7 @@ class EnterpriseManager:
conn = self._get_connection()
try:
cursor = conn.cursor()
cursor.execute("DELETE FROM data_retention_policies WHERE id = ?", (policy_id,))
cursor.execute("DELETE FROM data_retention_policies WHERE id = ?", (policy_id, ))
conn.commit()
return cursor.rowcount > 0
finally:
@@ -1776,18 +1776,18 @@ class EnterpriseManager:
now = datetime.now()
job = DataRetentionJob(
id=job_id,
policy_id=policy_id,
tenant_id=policy.tenant_id,
status="running",
started_at=now,
completed_at=None,
affected_records=0,
archived_records=0,
deleted_records=0,
error_count=0,
details={},
created_at=now,
id = job_id,
policy_id = policy_id,
tenant_id = policy.tenant_id,
status = "running",
started_at = now,
completed_at = None,
affected_records = 0,
archived_records = 0,
deleted_records = 0,
error_count = 0,
details = {},
created_at = now,
)
cursor = conn.cursor()
@@ -1804,7 +1804,7 @@ class EnterpriseManager:
try:
# 计算截止日期
cutoff_date = now - timedelta(days=policy.retention_days)
cutoff_date = now - timedelta(days = policy.retention_days)
# 根据资源类型执行不同的处理
if policy.resource_type == "audit_log":
@@ -1887,7 +1887,7 @@ class EnterpriseManager:
SELECT COUNT(*) as count FROM audit_logs
WHERE created_at < ?
""",
(cutoff_date,),
(cutoff_date, ),
)
count = cursor.fetchone()["count"]
@@ -1896,7 +1896,7 @@ class EnterpriseManager:
"""
DELETE FROM audit_logs WHERE created_at < ?
""",
(cutoff_date,),
(cutoff_date, ),
)
deleted = cursor.rowcount
return {"affected": count, "archived": 0, "deleted": deleted, "errors": 0}
@@ -1927,7 +1927,7 @@ class EnterpriseManager:
conn = self._get_connection()
try:
cursor = conn.cursor()
cursor.execute("SELECT * FROM data_retention_jobs WHERE id = ?", (job_id,))
cursor.execute("SELECT * FROM data_retention_jobs WHERE id = ?", (job_id, ))
row = cursor.fetchone()
if row:
@@ -1963,64 +1963,64 @@ class EnterpriseManager:
def _row_to_sso_config(self, row: sqlite3.Row) -> SSOConfig:
"""数据库行转换为 SSOConfig 对象"""
return SSOConfig(
id=row["id"],
tenant_id=row["tenant_id"],
provider=row["provider"],
status=row["status"],
entity_id=row["entity_id"],
sso_url=row["sso_url"],
slo_url=row["slo_url"],
certificate=row["certificate"],
metadata_url=row["metadata_url"],
metadata_xml=row["metadata_xml"],
client_id=row["client_id"],
client_secret=row["client_secret"],
authorization_url=row["authorization_url"],
token_url=row["token_url"],
userinfo_url=row["userinfo_url"],
scopes=json.loads(row["scopes"] or '["openid", "email", "profile"]'),
attribute_mapping=json.loads(row["attribute_mapping"] or "{}"),
auto_provision=bool(row["auto_provision"]),
default_role=row["default_role"],
domain_restriction=json.loads(row["domain_restriction"] or "[]"),
created_at=(
id = row["id"],
tenant_id = row["tenant_id"],
provider = row["provider"],
status = row["status"],
entity_id = row["entity_id"],
sso_url = row["sso_url"],
slo_url = row["slo_url"],
certificate = row["certificate"],
metadata_url = row["metadata_url"],
metadata_xml = row["metadata_xml"],
client_id = row["client_id"],
client_secret = row["client_secret"],
authorization_url = row["authorization_url"],
token_url = row["token_url"],
userinfo_url = row["userinfo_url"],
scopes = json.loads(row["scopes"] or '["openid", "email", "profile"]'),
attribute_mapping = json.loads(row["attribute_mapping"] or "{}"),
auto_provision = bool(row["auto_provision"]),
default_role = row["default_role"],
domain_restriction = json.loads(row["domain_restriction"] or "[]"),
created_at = (
datetime.fromisoformat(row["created_at"])
if isinstance(row["created_at"], str)
else row["created_at"]
),
updated_at=(
updated_at = (
datetime.fromisoformat(row["updated_at"])
if isinstance(row["updated_at"], str)
else row["updated_at"]
),
last_tested_at=(
last_tested_at = (
datetime.fromisoformat(row["last_tested_at"])
if row["last_tested_at"] and isinstance(row["last_tested_at"], str)
else row["last_tested_at"]
),
last_error=row["last_error"],
last_error = row["last_error"],
)
def _row_to_saml_request(self, row: sqlite3.Row) -> SAMLAuthRequest:
"""数据库行转换为 SAMLAuthRequest 对象"""
return SAMLAuthRequest(
id=row["id"],
tenant_id=row["tenant_id"],
sso_config_id=row["sso_config_id"],
request_id=row["request_id"],
relay_state=row["relay_state"],
created_at=(
id = row["id"],
tenant_id = row["tenant_id"],
sso_config_id = row["sso_config_id"],
request_id = row["request_id"],
relay_state = row["relay_state"],
created_at = (
datetime.fromisoformat(row["created_at"])
if isinstance(row["created_at"], str)
else row["created_at"]
),
expires_at=(
expires_at = (
datetime.fromisoformat(row["expires_at"])
if isinstance(row["expires_at"], str)
else row["expires_at"]
),
used=bool(row["used"]),
used_at=(
used = bool(row["used"]),
used_at = (
datetime.fromisoformat(row["used_at"])
if row["used_at"] and isinstance(row["used_at"], str)
else row["used_at"]
@@ -2030,29 +2030,29 @@ class EnterpriseManager:
def _row_to_scim_config(self, row: sqlite3.Row) -> SCIMConfig:
"""数据库行转换为 SCIMConfig 对象"""
return SCIMConfig(
id=row["id"],
tenant_id=row["tenant_id"],
provider=row["provider"],
status=row["status"],
scim_base_url=row["scim_base_url"],
scim_token=row["scim_token"],
sync_interval_minutes=row["sync_interval_minutes"],
last_sync_at=(
id = row["id"],
tenant_id = row["tenant_id"],
provider = row["provider"],
status = row["status"],
scim_base_url = row["scim_base_url"],
scim_token = row["scim_token"],
sync_interval_minutes = row["sync_interval_minutes"],
last_sync_at = (
datetime.fromisoformat(row["last_sync_at"])
if row["last_sync_at"] and isinstance(row["last_sync_at"], str)
else row["last_sync_at"]
),
last_sync_status=row["last_sync_status"],
last_sync_error=row["last_sync_error"],
last_sync_users_count=row["last_sync_users_count"],
attribute_mapping=json.loads(row["attribute_mapping"] or "{}"),
sync_rules=json.loads(row["sync_rules"] or "{}"),
created_at=(
last_sync_status = row["last_sync_status"],
last_sync_error = row["last_sync_error"],
last_sync_users_count = row["last_sync_users_count"],
attribute_mapping = json.loads(row["attribute_mapping"] or "{}"),
sync_rules = json.loads(row["sync_rules"] or "{}"),
created_at = (
datetime.fromisoformat(row["created_at"])
if isinstance(row["created_at"], str)
else row["created_at"]
),
updated_at=(
updated_at = (
datetime.fromisoformat(row["updated_at"])
if isinstance(row["updated_at"], str)
else row["updated_at"]
@@ -2062,28 +2062,28 @@ class EnterpriseManager:
def _row_to_scim_user(self, row: sqlite3.Row) -> SCIMUser:
"""数据库行转换为 SCIMUser 对象"""
return SCIMUser(
id=row["id"],
tenant_id=row["tenant_id"],
external_id=row["external_id"],
user_name=row["user_name"],
email=row["email"],
display_name=row["display_name"],
given_name=row["given_name"],
family_name=row["family_name"],
active=bool(row["active"]),
groups=json.loads(row["groups"] or "[]"),
raw_data=json.loads(row["raw_data"] or "{}"),
synced_at=(
id = row["id"],
tenant_id = row["tenant_id"],
external_id = row["external_id"],
user_name = row["user_name"],
email = row["email"],
display_name = row["display_name"],
given_name = row["given_name"],
family_name = row["family_name"],
active = bool(row["active"]),
groups = json.loads(row["groups"] or "[]"),
raw_data = json.loads(row["raw_data"] or "{}"),
synced_at = (
datetime.fromisoformat(row["synced_at"])
if isinstance(row["synced_at"], str)
else row["synced_at"]
),
created_at=(
created_at = (
datetime.fromisoformat(row["created_at"])
if isinstance(row["created_at"], str)
else row["created_at"]
),
updated_at=(
updated_at = (
datetime.fromisoformat(row["updated_at"])
if isinstance(row["updated_at"], str)
else row["updated_at"]
@@ -2093,78 +2093,78 @@ class EnterpriseManager:
def _row_to_audit_export(self, row: sqlite3.Row) -> AuditLogExport:
"""数据库行转换为 AuditLogExport 对象"""
return AuditLogExport(
id=row["id"],
tenant_id=row["tenant_id"],
export_format=row["export_format"],
start_date=(
id = row["id"],
tenant_id = row["tenant_id"],
export_format = row["export_format"],
start_date = (
datetime.fromisoformat(row["start_date"])
if isinstance(row["start_date"], str)
else row["start_date"]
),
end_date=datetime.fromisoformat(row["end_date"])
end_date = datetime.fromisoformat(row["end_date"])
if isinstance(row["end_date"], str)
else row["end_date"],
filters=json.loads(row["filters"] or "{}"),
compliance_standard=row["compliance_standard"],
status=row["status"],
file_path=row["file_path"],
file_size=row["file_size"],
record_count=row["record_count"],
checksum=row["checksum"],
downloaded_by=row["downloaded_by"],
downloaded_at=(
filters = json.loads(row["filters"] or "{}"),
compliance_standard = row["compliance_standard"],
status = row["status"],
file_path = row["file_path"],
file_size = row["file_size"],
record_count = row["record_count"],
checksum = row["checksum"],
downloaded_by = row["downloaded_by"],
downloaded_at = (
datetime.fromisoformat(row["downloaded_at"])
if row["downloaded_at"] and isinstance(row["downloaded_at"], str)
else row["downloaded_at"]
),
expires_at=(
expires_at = (
datetime.fromisoformat(row["expires_at"])
if isinstance(row["expires_at"], str)
else row["expires_at"]
),
created_by=row["created_by"],
created_at=(
created_by = row["created_by"],
created_at = (
datetime.fromisoformat(row["created_at"])
if isinstance(row["created_at"], str)
else row["created_at"]
),
completed_at=(
completed_at = (
datetime.fromisoformat(row["completed_at"])
if row["completed_at"] and isinstance(row["completed_at"], str)
else row["completed_at"]
),
error_message=row["error_message"],
error_message = row["error_message"],
)
def _row_to_retention_policy(self, row: sqlite3.Row) -> DataRetentionPolicy:
"""数据库行转换为 DataRetentionPolicy 对象"""
return DataRetentionPolicy(
id=row["id"],
tenant_id=row["tenant_id"],
name=row["name"],
description=row["description"],
resource_type=row["resource_type"],
retention_days=row["retention_days"],
action=row["action"],
conditions=json.loads(row["conditions"] or "{}"),
auto_execute=bool(row["auto_execute"]),
execute_at=row["execute_at"],
notify_before_days=row["notify_before_days"],
archive_location=row["archive_location"],
archive_encryption=bool(row["archive_encryption"]),
is_active=bool(row["is_active"]),
last_executed_at=(
id = row["id"],
tenant_id = row["tenant_id"],
name = row["name"],
description = row["description"],
resource_type = row["resource_type"],
retention_days = row["retention_days"],
action = row["action"],
conditions = json.loads(row["conditions"] or "{}"),
auto_execute = bool(row["auto_execute"]),
execute_at = row["execute_at"],
notify_before_days = row["notify_before_days"],
archive_location = row["archive_location"],
archive_encryption = bool(row["archive_encryption"]),
is_active = bool(row["is_active"]),
last_executed_at = (
datetime.fromisoformat(row["last_executed_at"])
if row["last_executed_at"] and isinstance(row["last_executed_at"], str)
else row["last_executed_at"]
),
last_execution_result=row["last_execution_result"],
created_at=(
last_execution_result = row["last_execution_result"],
created_at = (
datetime.fromisoformat(row["created_at"])
if isinstance(row["created_at"], str)
else row["created_at"]
),
updated_at=(
updated_at = (
datetime.fromisoformat(row["updated_at"])
if isinstance(row["updated_at"], str)
else row["updated_at"]
@@ -2174,26 +2174,26 @@ class EnterpriseManager:
def _row_to_retention_job(self, row: sqlite3.Row) -> DataRetentionJob:
"""数据库行转换为 DataRetentionJob 对象"""
return DataRetentionJob(
id=row["id"],
policy_id=row["policy_id"],
tenant_id=row["tenant_id"],
status=row["status"],
started_at=(
id = row["id"],
policy_id = row["policy_id"],
tenant_id = row["tenant_id"],
status = row["status"],
started_at = (
datetime.fromisoformat(row["started_at"])
if row["started_at"] and isinstance(row["started_at"], str)
else row["started_at"]
),
completed_at=(
completed_at = (
datetime.fromisoformat(row["completed_at"])
if row["completed_at"] and isinstance(row["completed_at"], str)
else row["completed_at"]
),
affected_records=row["affected_records"],
archived_records=row["archived_records"],
deleted_records=row["deleted_records"],
error_count=row["error_count"],
details=json.loads(row["details"] or "{}"),
created_at=(
affected_records = row["affected_records"],
archived_records = row["archived_records"],
deleted_records = row["deleted_records"],
error_count = row["error_count"],
details = json.loads(row["details"] or "{}"),
created_at = (
datetime.fromisoformat(row["created_at"])
if isinstance(row["created_at"], str)
else row["created_at"]

View File

@@ -27,7 +27,7 @@ class EntityEmbedding:
class EntityAligner:
"""实体对齐器 - 使用 embedding 进行相似度匹配"""
def __init__(self, similarity_threshold: float = 0.85):
def __init__(self, similarity_threshold: float = 0.85) -> None:
self.similarity_threshold = similarity_threshold
self.embedding_cache: dict[str, list[float]] = {}
@@ -52,12 +52,12 @@ class EntityAligner:
try:
response = httpx.post(
f"{KIMI_BASE_URL}/v1/embeddings",
headers={
headers = {
"Authorization": f"Bearer {KIMI_API_KEY}",
"Content-Type": "application/json",
},
json={"model": "k2p5", "input": text[:500]}, # 限制长度
timeout=30.0,
json = {"model": "k2p5", "input": text[:500]}, # 限制长度
timeout = 30.0,
)
response.raise_for_status()
result = response.json()
@@ -232,7 +232,7 @@ class EntityAligner:
for new_ent in new_entities:
matched = self.find_similar_entity(
project_id, new_ent["name"], new_ent.get("definition", ""), threshold=threshold
project_id, new_ent["name"], new_ent.get("definition", ""), threshold = threshold
)
result = {
@@ -292,16 +292,16 @@ class EntityAligner:
try:
response = httpx.post(
f"{KIMI_BASE_URL}/v1/chat/completions",
headers={
headers = {
"Authorization": f"Bearer {KIMI_API_KEY}",
"Content-Type": "application/json",
},
json={
json = {
"model": "k2p5",
"messages": [{"role": "user", "content": prompt}],
"temperature": 0.3,
},
timeout=30.0,
timeout = 30.0,
)
response.raise_for_status()
result = response.json()

View File

@@ -71,7 +71,7 @@ class ExportTranscript:
class ExportManager:
"""导出管理器 - 处理各种导出需求"""
def __init__(self, db_manager=None):
def __init__(self, db_manager = None) -> None:
self.db = db_manager
def export_knowledge_graph_svg(
@@ -121,17 +121,17 @@ class ExportManager:
# 生成 SVG
svg_parts = [
f'<svg xmlns="http://www.w3.org/2000/svg" width="{width}" height="{height}" '
f'viewBox="0 0 {width} {height}">',
f'<svg xmlns = "http://www.w3.org/2000/svg" width = "{width}" height = "{height}" '
f'viewBox = "0 0 {width} {height}">',
"<defs>",
' <marker id="arrowhead" markerWidth="10" markerHeight="7" '
'refX="9" refY="3.5" orient="auto">',
' <polygon points="0 0, 10 3.5, 0 7" fill="#7f8c8d"/>',
' <marker id = "arrowhead" markerWidth = "10" markerHeight = "7" '
'refX = "9" refY = "3.5" orient = "auto">',
' <polygon points = "0 0, 10 3.5, 0 7" fill = "#7f8c8d"/>',
" </marker>",
"</defs>",
f'<rect width="{width}" height="{height}" fill="#f8f9fa"/>',
f'<text x="{center_x}" y="30" text-anchor="middle" font-size="20" '
f'font-weight="bold" fill="#2c3e50">知识图谱 - {project_id}</text>',
f'<rect width = "{width}" height = "{height}" fill = "#f8f9fa"/>',
f'<text x = "{center_x}" y = "30" text-anchor = "middle" font-size = "20" '
f'font-weight = "bold" fill = "#2c3e50">知识图谱 - {project_id}</text>',
]
# 绘制关系连线
@@ -150,20 +150,20 @@ class ExportManager:
y2 = y2 - dy * offset / dist
svg_parts.append(
f'<line x1="{x1}" y1="{y1}" x2="{x2}" y2="{y2}" '
f'stroke="#7f8c8d" stroke-width="2" marker-end="url(#arrowhead)" opacity="0.6"/>'
f'<line x1 = "{x1}" y1 = "{y1}" x2 = "{x2}" y2 = "{y2}" '
f'stroke = "#7f8c8d" stroke-width = "2" marker-end = "url(#arrowhead)" opacity = "0.6"/>'
)
# 关系标签
mid_x = (x1 + x2) / 2
mid_y = (y1 + y2) / 2
svg_parts.append(
f'<rect x="{mid_x - 30}" y="{mid_y - 10}" width="60" height="20" '
f'fill="white" stroke="#bdc3c7" rx="3"/>'
f'<rect x = "{mid_x - 30}" y = "{mid_y - 10}" width = "60" height = "20" '
f'fill = "white" stroke = "#bdc3c7" rx = "3"/>'
)
svg_parts.append(
f'<text x="{mid_x}" y="{mid_y + 5}" text-anchor="middle" '
f'font-size="10" fill="#2c3e50">{rel.relation_type}</text>'
f'<text x = "{mid_x}" y = "{mid_y + 5}" text-anchor = "middle" '
f'font-size = "10" fill = "#2c3e50">{rel.relation_type}</text>'
)
# 绘制实体节点
@@ -174,19 +174,19 @@ class ExportManager:
# 节点圆圈
svg_parts.append(
f'<circle cx="{x}" cy="{y}" r="35" fill="{color}" stroke="white" stroke-width="3"/>'
f'<circle cx = "{x}" cy = "{y}" r = "35" fill = "{color}" stroke = "white" stroke-width = "3"/>'
)
# 实体名称
svg_parts.append(
f'<text x="{x}" y="{y + 5}" text-anchor="middle" font-size="12" '
f'font-weight="bold" fill="white">{entity.name[:8]}</text>'
f'<text x = "{x}" y = "{y + 5}" text-anchor = "middle" font-size = "12" '
f'font-weight = "bold" fill = "white">{entity.name[:8]}</text>'
)
# 实体类型
svg_parts.append(
f'<text x="{x}" y="{y + 55}" text-anchor="middle" font-size="10" '
f'fill="#7f8c8d">{entity.type}</text>'
f'<text x = "{x}" y = "{y + 55}" text-anchor = "middle" font-size = "10" '
f'fill = "#7f8c8d">{entity.type}</text>'
)
# 图例
@@ -196,24 +196,24 @@ class ExportManager:
rect_y = legend_y - 20
rect_height = len(type_colors) * 25 + 10
svg_parts.append(
f'<rect x="{rect_x}" y="{rect_y}" width="140" height="{rect_height}" '
f'fill="white" stroke="#bdc3c7" rx="5"/>'
f'<rect x = "{rect_x}" y = "{rect_y}" width = "140" height = "{rect_height}" '
f'fill = "white" stroke = "#bdc3c7" rx = "5"/>'
)
svg_parts.append(
f'<text x="{legend_x}" y="{legend_y}" font-size="12" font-weight="bold" '
f'fill="#2c3e50">实体类型</text>'
f'<text x = "{legend_x}" y = "{legend_y}" font-size = "12" font-weight = "bold" '
f'fill = "#2c3e50">实体类型</text>'
)
for i, (etype, color) in enumerate(type_colors.items()):
if etype != "default":
y_pos = legend_y + 25 + i * 20
svg_parts.append(
f'<circle cx="{legend_x + 10}" cy="{y_pos}" r="8" fill="{color}"/>'
f'<circle cx = "{legend_x + 10}" cy = "{y_pos}" r = "8" fill = "{color}"/>'
)
text_y = y_pos + 4
svg_parts.append(
f'<text x="{legend_x + 25}" y="{text_y}" font-size="10" '
f'fill="#2c3e50">{etype}</text>'
f'<text x = "{legend_x + 25}" y = "{text_y}" font-size = "10" '
f'fill = "#2c3e50">{etype}</text>'
)
svg_parts.append("</svg>")
@@ -232,7 +232,7 @@ class ExportManager:
import cairosvg
svg_content = self.export_knowledge_graph_svg(project_id, entities, relations)
png_bytes = cairosvg.svg2png(bytestring=svg_content.encode("utf-8"))
png_bytes = cairosvg.svg2png(bytestring = svg_content.encode("utf-8"))
return png_bytes
except ImportError:
# 如果没有 cairosvg返回 SVG 的 base64
@@ -269,8 +269,8 @@ class ExportManager:
# 写入 Excel
output = io.BytesIO()
with pd.ExcelWriter(output, engine="openpyxl") as writer:
df.to_excel(writer, sheet_name="实体列表", index=False)
with pd.ExcelWriter(output, engine = "openpyxl") as writer:
df.to_excel(writer, sheet_name = "实体列表", index = False)
# 调整列宽
worksheet = writer.sheets["实体列表"]
@@ -417,24 +417,24 @@ class ExportManager:
output = io.BytesIO()
doc = SimpleDocTemplate(
output, pagesize=A4, rightMargin=72, leftMargin=72, topMargin=72, bottomMargin=18
output, pagesize = A4, rightMargin = 72, leftMargin = 72, topMargin = 72, bottomMargin = 18
)
# 样式
styles = getSampleStyleSheet()
title_style = ParagraphStyle(
"CustomTitle",
parent=styles["Heading1"],
fontSize=24,
spaceAfter=30,
textColor=colors.HexColor("#2c3e50"),
parent = styles["Heading1"],
fontSize = 24,
spaceAfter = 30,
textColor = colors.HexColor("#2c3e50"),
)
heading_style = ParagraphStyle(
"CustomHeading",
parent=styles["Heading2"],
fontSize=16,
spaceAfter=12,
textColor=colors.HexColor("#34495e"),
parent = styles["Heading2"],
fontSize = 16,
spaceAfter = 12,
textColor = colors.HexColor("#34495e"),
)
story = []
@@ -467,7 +467,7 @@ class ExportManager:
for etype, count in sorted(type_counts.items()):
stats_data.append([f"{etype} 实体", str(count)])
stats_table = Table(stats_data, colWidths=[3 * inch, 2 * inch])
stats_table = Table(stats_data, colWidths = [3 * inch, 2 * inch])
stats_table.setStyle(
TableStyle(
[
@@ -497,7 +497,7 @@ class ExportManager:
story.append(Paragraph("实体列表", heading_style))
entity_data = [["名称", "类型", "提及次数", "定义"]]
for e in sorted(entities, key=lambda x: x.mention_count, reverse=True)[
for e in sorted(entities, key = lambda x: x.mention_count, reverse = True)[
:50
]: # 限制前50个
entity_data.append(
@@ -510,7 +510,7 @@ class ExportManager:
)
entity_table = Table(
entity_data, colWidths=[1.5 * inch, 1 * inch, 1 * inch, 2.5 * inch]
entity_data, colWidths = [1.5 * inch, 1 * inch, 1 * inch, 2.5 * inch]
)
entity_table.setStyle(
TableStyle(
@@ -539,7 +539,7 @@ class ExportManager:
relation_data.append([r.source, r.relation_type, r.target, f"{r.confidence:.2f}"])
relation_table = Table(
relation_data, colWidths=[2 * inch, 1.5 * inch, 2 * inch, 1 * inch]
relation_data, colWidths = [2 * inch, 1.5 * inch, 2 * inch, 1 * inch]
)
relation_table.setStyle(
TableStyle(
@@ -613,14 +613,14 @@ class ExportManager:
],
}
return json.dumps(data, ensure_ascii=False, indent=2)
return json.dumps(data, ensure_ascii = False, indent = 2)
# 全局导出管理器实例
_export_manager = None
def get_export_manager(db_manager=None) -> None:
def get_export_manager(db_manager = None) -> None:
"""获取导出管理器实例"""
global _export_manager
if _export_manager is None:

View File

@@ -362,7 +362,7 @@ class TeamIncentive:
class GrowthManager:
"""运营与增长管理主类"""
def __init__(self, db_path: str = DB_PATH):
def __init__(self, db_path: str = DB_PATH) -> None:
self.db_path = db_path
self.mixpanel_token = os.getenv("MIXPANEL_TOKEN", "")
self.amplitude_api_key = os.getenv("AMPLITUDE_API_KEY", "")
@@ -394,19 +394,19 @@ class GrowthManager:
now = datetime.now()
event = AnalyticsEvent(
id=event_id,
tenant_id=tenant_id,
user_id=user_id,
event_type=event_type,
event_name=event_name,
properties=properties or {},
timestamp=now,
session_id=session_id,
device_info=device_info or {},
referrer=referrer,
utm_source=utm_params.get("source") if utm_params else None,
utm_medium=utm_params.get("medium") if utm_params else None,
utm_campaign=utm_params.get("campaign") if utm_params else None,
id = event_id,
tenant_id = tenant_id,
user_id = user_id,
event_type = event_type,
event_name = event_name,
properties = properties or {},
timestamp = now,
session_id = session_id,
device_info = device_info or {},
referrer = referrer,
utm_source = utm_params.get("source") if utm_params else None,
utm_medium = utm_params.get("medium") if utm_params else None,
utm_campaign = utm_params.get("campaign") if utm_params else None,
)
with self._get_db() as conn:
@@ -443,7 +443,7 @@ class GrowthManager:
return event
async def _send_to_analytics_platforms(self, event: AnalyticsEvent):
async def _send_to_analytics_platforms(self, event: AnalyticsEvent) -> None:
"""发送事件到第三方分析平台"""
tasks = []
@@ -453,9 +453,9 @@ class GrowthManager:
tasks.append(self._send_to_amplitude(event))
if tasks:
await asyncio.gather(*tasks, return_exceptions=True)
await asyncio.gather(*tasks, return_exceptions = True)
async def _send_to_mixpanel(self, event: AnalyticsEvent):
async def _send_to_mixpanel(self, event: AnalyticsEvent) -> None:
"""发送事件到 Mixpanel"""
try:
headers = {
@@ -475,12 +475,12 @@ class GrowthManager:
async with httpx.AsyncClient() as client:
await client.post(
"https://api.mixpanel.com/track", headers=headers, json=[payload], timeout=10.0
"https://api.mixpanel.com/track", headers = headers, json = [payload], timeout = 10.0
)
except (RuntimeError, ValueError, TypeError) as e:
print(f"Failed to send to Mixpanel: {e}")
async def _send_to_amplitude(self, event: AnalyticsEvent):
async def _send_to_amplitude(self, event: AnalyticsEvent) -> None:
"""发送事件到 Amplitude"""
try:
headers = {"Content-Type": "application/json"}
@@ -501,16 +501,16 @@ class GrowthManager:
async with httpx.AsyncClient() as client:
await client.post(
"https://api.amplitude.com/2/httpapi",
headers=headers,
json=payload,
timeout=10.0,
headers = headers,
json = payload,
timeout = 10.0,
)
except (RuntimeError, ValueError, TypeError) as e:
print(f"Failed to send to Amplitude: {e}")
async def _update_user_profile(
self, tenant_id: str, user_id: str, event_type: EventType, event_name: str
):
) -> None:
"""更新用户画像"""
with self._get_db() as conn:
# 检查用户画像是否存在
@@ -642,13 +642,13 @@ class GrowthManager:
now = datetime.now().isoformat()
funnel = Funnel(
id=funnel_id,
tenant_id=tenant_id,
name=name,
description=description,
steps=steps,
created_at=now,
updated_at=now,
id = funnel_id,
tenant_id = tenant_id,
name = name,
description = description,
steps = steps,
created_at = now,
updated_at = now,
)
with self._get_db() as conn:
@@ -677,7 +677,7 @@ class GrowthManager:
) -> FunnelAnalysis | None:
"""分析漏斗转化率"""
with self._get_db() as conn:
funnel_row = conn.execute("SELECT * FROM funnels WHERE id = ?", (funnel_id,)).fetchone()
funnel_row = conn.execute("SELECT * FROM funnels WHERE id = ?", (funnel_id, )).fetchone()
if not funnel_row:
return None
@@ -685,7 +685,7 @@ class GrowthManager:
steps = json.loads(funnel_row["steps"])
if not period_start:
period_start = datetime.now() - timedelta(days=30)
period_start = datetime.now() - timedelta(days = 30)
if not period_end:
period_end = datetime.now()
@@ -740,13 +740,13 @@ class GrowthManager:
]
return FunnelAnalysis(
funnel_id=funnel_id,
period_start=period_start,
period_end=period_end,
total_users=step_conversions[0]["user_count"] if step_conversions else 0,
step_conversions=step_conversions,
overall_conversion=round(overall_conversion, 4),
drop_off_points=drop_off_points,
funnel_id = funnel_id,
period_start = period_start,
period_end = period_end,
total_users = step_conversions[0]["user_count"] if step_conversions else 0,
step_conversions = step_conversions,
overall_conversion = round(overall_conversion, 4),
drop_off_points = drop_off_points,
)
def calculate_retention(
@@ -781,14 +781,14 @@ class GrowthManager:
retention_rates = {}
for period in periods:
period_date = cohort_date + timedelta(days=period)
period_date = cohort_date + timedelta(days = period)
active_query = """
SELECT COUNT(DISTINCT user_id) as active_count
FROM analytics_events
WHERE tenant_id = ? AND date(timestamp) = date(?)
AND user_id IN ({})
""".format(",".join(["?" for _ in cohort_users]))
""".format(", ".join(["?" for _ in cohort_users]))
params = [tenant_id, period_date.isoformat()] + list(cohort_users)
row = conn.execute(active_query, params).fetchone()
@@ -830,25 +830,25 @@ class GrowthManager:
now = datetime.now().isoformat()
experiment = Experiment(
id=experiment_id,
tenant_id=tenant_id,
name=name,
description=description,
hypothesis=hypothesis,
status=ExperimentStatus.DRAFT,
variants=variants,
traffic_allocation=traffic_allocation,
traffic_split=traffic_split,
target_audience=target_audience,
primary_metric=primary_metric,
secondary_metrics=secondary_metrics,
start_date=None,
end_date=None,
min_sample_size=min_sample_size,
confidence_level=confidence_level,
created_at=now,
updated_at=now,
created_by=created_by or "system",
id = experiment_id,
tenant_id = tenant_id,
name = name,
description = description,
hypothesis = hypothesis,
status = ExperimentStatus.DRAFT,
variants = variants,
traffic_allocation = traffic_allocation,
traffic_split = traffic_split,
target_audience = target_audience,
primary_metric = primary_metric,
secondary_metrics = secondary_metrics,
start_date = None,
end_date = None,
min_sample_size = min_sample_size,
confidence_level = confidence_level,
created_at = now,
updated_at = now,
created_by = created_by or "system",
)
with self._get_db() as conn:
@@ -891,7 +891,7 @@ class GrowthManager:
"""获取实验详情"""
with self._get_db() as conn:
row = conn.execute(
"SELECT * FROM experiments WHERE id = ?", (experiment_id,)
"SELECT * FROM experiments WHERE id = ?", (experiment_id, )
).fetchone()
if row:
@@ -973,7 +973,7 @@ class GrowthManager:
total = sum(weights)
normalized_weights = [w / total for w in weights]
return random.choices(variant_ids, weights=normalized_weights, k=1)[0]
return random.choices(variant_ids, weights = normalized_weights, k = 1)[0]
def _stratified_allocation(
self, variants: list[dict], traffic_split: dict[str, float], user_attributes: dict
@@ -1027,7 +1027,7 @@ class GrowthManager:
user_id: str,
metric_name: str,
metric_value: float,
):
) -> None:
"""记录实验指标"""
with self._get_db() as conn:
conn.execute(
@@ -1196,21 +1196,21 @@ class GrowthManager:
variables = re.findall(r"\{\{(\w+)\}\}", html_content)
template = EmailTemplate(
id=template_id,
tenant_id=tenant_id,
name=name,
template_type=template_type,
subject=subject,
html_content=html_content,
text_content=text_content or re.sub(r"<[^>]+>", "", html_content),
variables=variables,
preview_text=None,
from_name=from_name or "InsightFlow",
from_email=from_email or "noreply@insightflow.io",
reply_to=reply_to,
is_active=True,
created_at=now,
updated_at=now,
id = template_id,
tenant_id = tenant_id,
name = name,
template_type = template_type,
subject = subject,
html_content = html_content,
text_content = text_content or re.sub(r"<[^>]+>", "", html_content),
variables = variables,
preview_text = None,
from_name = from_name or "InsightFlow",
from_email = from_email or "noreply@insightflow.io",
reply_to = reply_to,
is_active = True,
created_at = now,
updated_at = now,
)
with self._get_db() as conn:
@@ -1246,7 +1246,7 @@ class GrowthManager:
"""获取邮件模板"""
with self._get_db() as conn:
row = conn.execute(
"SELECT * FROM email_templates WHERE id = ?", (template_id,)
"SELECT * FROM email_templates WHERE id = ?", (template_id, )
).fetchone()
if row:
@@ -1308,22 +1308,22 @@ class GrowthManager:
now = datetime.now().isoformat()
campaign = EmailCampaign(
id=campaign_id,
tenant_id=tenant_id,
name=name,
template_id=template_id,
status="draft",
recipient_count=len(recipient_list),
sent_count=0,
delivered_count=0,
opened_count=0,
clicked_count=0,
bounced_count=0,
failed_count=0,
scheduled_at=scheduled_at.isoformat() if scheduled_at else None,
started_at=None,
completed_at=None,
created_at=now,
id = campaign_id,
tenant_id = tenant_id,
name = name,
template_id = template_id,
status = "draft",
recipient_count = len(recipient_list),
sent_count = 0,
delivered_count = 0,
opened_count = 0,
clicked_count = 0,
bounced_count = 0,
failed_count = 0,
scheduled_at = scheduled_at.isoformat() if scheduled_at else None,
started_at = None,
completed_at = None,
created_at = now,
)
with self._get_db() as conn:
@@ -1452,7 +1452,7 @@ class GrowthManager:
"""发送整个营销活动"""
with self._get_db() as conn:
campaign_row = conn.execute(
"SELECT * FROM email_campaigns WHERE id = ?", (campaign_id,)
"SELECT * FROM email_campaigns WHERE id = ?", (campaign_id, )
).fetchone()
if not campaign_row:
@@ -1530,17 +1530,17 @@ class GrowthManager:
now = datetime.now().isoformat()
workflow = AutomationWorkflow(
id=workflow_id,
tenant_id=tenant_id,
name=name,
description=description,
trigger_type=trigger_type,
trigger_conditions=trigger_conditions,
actions=actions,
is_active=True,
execution_count=0,
created_at=now,
updated_at=now,
id = workflow_id,
tenant_id = tenant_id,
name = name,
description = description,
trigger_type = trigger_type,
trigger_conditions = trigger_conditions,
actions = actions,
is_active = True,
execution_count = 0,
created_at = now,
updated_at = now,
)
with self._get_db() as conn:
@@ -1569,11 +1569,11 @@ class GrowthManager:
return workflow
async def trigger_workflow(self, workflow_id: str, event_data: dict):
async def trigger_workflow(self, workflow_id: str, event_data: dict) -> None:
"""触发自动化工作流"""
with self._get_db() as conn:
row = conn.execute(
"SELECT * FROM automation_workflows WHERE id = ? AND is_active = 1", (workflow_id,)
"SELECT * FROM automation_workflows WHERE id = ? AND is_active = 1", (workflow_id, )
).fetchone()
if not row:
@@ -1592,7 +1592,7 @@ class GrowthManager:
# 更新执行计数
conn.execute(
"UPDATE automation_workflows SET execution_count = execution_count + 1 WHERE id = ?",
(workflow_id,),
(workflow_id, ),
)
conn.commit()
@@ -1606,7 +1606,7 @@ class GrowthManager:
return False
return True
async def _execute_action(self, action: dict, event_data: dict):
async def _execute_action(self, action: dict, event_data: dict) -> None:
"""执行工作流动作"""
action_type = action.get("type")
@@ -1640,20 +1640,20 @@ class GrowthManager:
now = datetime.now().isoformat()
program = ReferralProgram(
id=program_id,
tenant_id=tenant_id,
name=name,
description=description,
referrer_reward_type=referrer_reward_type,
referrer_reward_value=referrer_reward_value,
referee_reward_type=referee_reward_type,
referee_reward_value=referee_reward_value,
max_referrals_per_user=max_referrals_per_user,
referral_code_length=referral_code_length,
expiry_days=expiry_days,
is_active=True,
created_at=now,
updated_at=now,
id = program_id,
tenant_id = tenant_id,
name = name,
description = description,
referrer_reward_type = referrer_reward_type,
referrer_reward_value = referrer_reward_value,
referee_reward_type = referee_reward_type,
referee_reward_value = referee_reward_value,
max_referrals_per_user = max_referrals_per_user,
referral_code_length = referral_code_length,
expiry_days = expiry_days,
is_active = True,
created_at = now,
updated_at = now,
)
with self._get_db() as conn:
@@ -1708,24 +1708,24 @@ class GrowthManager:
referral_id = f"ref_{uuid.uuid4().hex[:16]}"
now = datetime.now()
expires_at = now + timedelta(days=program.expiry_days)
expires_at = now + timedelta(days = program.expiry_days)
referral = Referral(
id=referral_id,
program_id=program_id,
tenant_id=program.tenant_id,
referrer_id=referrer_id,
referee_id=None,
referral_code=referral_code,
status=ReferralStatus.PENDING,
referrer_rewarded=False,
referee_rewarded=False,
referrer_reward_value=program.referrer_reward_value,
referee_reward_value=program.referee_reward_value,
converted_at=None,
rewarded_at=None,
expires_at=expires_at,
created_at=now,
id = referral_id,
program_id = program_id,
tenant_id = program.tenant_id,
referrer_id = referrer_id,
referee_id = None,
referral_code = referral_code,
status = ReferralStatus.PENDING,
referrer_rewarded = False,
referee_rewarded = False,
referrer_reward_value = program.referrer_reward_value,
referee_reward_value = program.referee_reward_value,
converted_at = None,
rewarded_at = None,
expires_at = expires_at,
created_at = now,
)
conn.execute(
@@ -1762,11 +1762,11 @@ class GrowthManager:
"""生成唯一推荐码"""
chars = "ABCDEFGHJKLMNPQRSTUVWXYZ23456789" # 排除易混淆字符
while True:
code = "".join(random.choices(chars, k=length))
code = "".join(random.choices(chars, k = length))
with self._get_db() as conn:
row = conn.execute(
"SELECT 1 FROM referrals WHERE referral_code = ?", (code,)
"SELECT 1 FROM referrals WHERE referral_code = ?", (code, )
).fetchone()
if not row:
@@ -1776,7 +1776,7 @@ class GrowthManager:
"""获取推荐计划"""
with self._get_db() as conn:
row = conn.execute(
"SELECT * FROM referral_programs WHERE id = ?", (program_id,)
"SELECT * FROM referral_programs WHERE id = ?", (program_id, )
).fetchone()
if row:
@@ -1811,7 +1811,7 @@ class GrowthManager:
def reward_referral(self, referral_id: str) -> bool:
"""发放推荐奖励"""
with self._get_db() as conn:
row = conn.execute("SELECT * FROM referrals WHERE id = ?", (referral_id,)).fetchone()
row = conn.execute("SELECT * FROM referrals WHERE id = ?", (referral_id, )).fetchone()
if not row or row["status"] != ReferralStatus.CONVERTED.value:
return False
@@ -1883,18 +1883,18 @@ class GrowthManager:
now = datetime.now().isoformat()
incentive = TeamIncentive(
id=incentive_id,
tenant_id=tenant_id,
name=name,
description=description,
target_tier=target_tier,
min_team_size=min_team_size,
incentive_type=incentive_type,
incentive_value=incentive_value,
valid_from=valid_from.isoformat(),
valid_until=valid_until.isoformat(),
is_active=True,
created_at=now,
id = incentive_id,
tenant_id = tenant_id,
name = name,
description = description,
target_tier = target_tier,
min_team_size = min_team_size,
incentive_type = incentive_type,
incentive_value = incentive_value,
valid_from = valid_from.isoformat(),
valid_until = valid_until.isoformat(),
is_active = True,
created_at = now,
)
with self._get_db() as conn:
@@ -1947,7 +1947,7 @@ class GrowthManager:
def get_realtime_dashboard(self, tenant_id: str) -> dict:
"""获取实时分析仪表板数据"""
now = datetime.now()
today_start = now.replace(hour=0, minute=0, second=0, microsecond=0)
today_start = now.replace(hour = 0, minute = 0, second = 0, microsecond = 0)
with self._get_db() as conn:
# 今日统计
@@ -1972,7 +1972,7 @@ class GrowthManager:
ORDER BY timestamp DESC
LIMIT 20
""",
(tenant_id,),
(tenant_id, ),
).fetchall()
# 热门功能
@@ -1991,8 +1991,8 @@ class GrowthManager:
# 活跃用户趋势最近24小时每小时
hourly_trend = []
for i in range(24):
hour_start = now - timedelta(hours=i + 1)
hour_end = now - timedelta(hours=i)
hour_start = now - timedelta(hours = i + 1)
hour_end = now - timedelta(hours = i)
row = conn.execute(
"""
@@ -2035,116 +2035,116 @@ class GrowthManager:
def _row_to_user_profile(self, row) -> UserProfile:
"""将数据库行转换为 UserProfile"""
return UserProfile(
id=row["id"],
tenant_id=row["tenant_id"],
user_id=row["user_id"],
first_seen=datetime.fromisoformat(row["first_seen"]),
last_seen=datetime.fromisoformat(row["last_seen"]),
total_sessions=row["total_sessions"],
total_events=row["total_events"],
feature_usage=json.loads(row["feature_usage"]),
subscription_history=json.loads(row["subscription_history"]),
ltv=row["ltv"],
churn_risk_score=row["churn_risk_score"],
engagement_score=row["engagement_score"],
created_at=datetime.fromisoformat(row["created_at"]),
updated_at=datetime.fromisoformat(row["updated_at"]),
id = row["id"],
tenant_id = row["tenant_id"],
user_id = row["user_id"],
first_seen = datetime.fromisoformat(row["first_seen"]),
last_seen = datetime.fromisoformat(row["last_seen"]),
total_sessions = row["total_sessions"],
total_events = row["total_events"],
feature_usage = json.loads(row["feature_usage"]),
subscription_history = json.loads(row["subscription_history"]),
ltv = row["ltv"],
churn_risk_score = row["churn_risk_score"],
engagement_score = row["engagement_score"],
created_at = datetime.fromisoformat(row["created_at"]),
updated_at = datetime.fromisoformat(row["updated_at"]),
)
def _row_to_experiment(self, row) -> Experiment:
"""将数据库行转换为 Experiment"""
return Experiment(
id=row["id"],
tenant_id=row["tenant_id"],
name=row["name"],
description=row["description"],
hypothesis=row["hypothesis"],
status=ExperimentStatus(row["status"]),
variants=json.loads(row["variants"]),
traffic_allocation=TrafficAllocationType(row["traffic_allocation"]),
traffic_split=json.loads(row["traffic_split"]),
target_audience=json.loads(row["target_audience"]),
primary_metric=row["primary_metric"],
secondary_metrics=json.loads(row["secondary_metrics"]),
start_date=datetime.fromisoformat(row["start_date"]) if row["start_date"] else None,
end_date=datetime.fromisoformat(row["end_date"]) if row["end_date"] else None,
min_sample_size=row["min_sample_size"],
confidence_level=row["confidence_level"],
created_at=row["created_at"],
updated_at=row["updated_at"],
created_by=row["created_by"],
id = row["id"],
tenant_id = row["tenant_id"],
name = row["name"],
description = row["description"],
hypothesis = row["hypothesis"],
status = ExperimentStatus(row["status"]),
variants = json.loads(row["variants"]),
traffic_allocation = TrafficAllocationType(row["traffic_allocation"]),
traffic_split = json.loads(row["traffic_split"]),
target_audience = json.loads(row["target_audience"]),
primary_metric = row["primary_metric"],
secondary_metrics = json.loads(row["secondary_metrics"]),
start_date = datetime.fromisoformat(row["start_date"]) if row["start_date"] else None,
end_date = datetime.fromisoformat(row["end_date"]) if row["end_date"] else None,
min_sample_size = row["min_sample_size"],
confidence_level = row["confidence_level"],
created_at = row["created_at"],
updated_at = row["updated_at"],
created_by = row["created_by"],
)
def _row_to_email_template(self, row) -> EmailTemplate:
"""将数据库行转换为 EmailTemplate"""
return EmailTemplate(
id=row["id"],
tenant_id=row["tenant_id"],
name=row["name"],
template_type=EmailTemplateType(row["template_type"]),
subject=row["subject"],
html_content=row["html_content"],
text_content=row["text_content"],
variables=json.loads(row["variables"]),
preview_text=row["preview_text"],
from_name=row["from_name"],
from_email=row["from_email"],
reply_to=row["reply_to"],
is_active=bool(row["is_active"]),
created_at=row["created_at"],
updated_at=row["updated_at"],
id = row["id"],
tenant_id = row["tenant_id"],
name = row["name"],
template_type = EmailTemplateType(row["template_type"]),
subject = row["subject"],
html_content = row["html_content"],
text_content = row["text_content"],
variables = json.loads(row["variables"]),
preview_text = row["preview_text"],
from_name = row["from_name"],
from_email = row["from_email"],
reply_to = row["reply_to"],
is_active = bool(row["is_active"]),
created_at = row["created_at"],
updated_at = row["updated_at"],
)
def _row_to_automation_workflow(self, row) -> AutomationWorkflow:
"""将数据库行转换为 AutomationWorkflow"""
return AutomationWorkflow(
id=row["id"],
tenant_id=row["tenant_id"],
name=row["name"],
description=row["description"],
trigger_type=WorkflowTriggerType(row["trigger_type"]),
trigger_conditions=json.loads(row["trigger_conditions"]),
actions=json.loads(row["actions"]),
is_active=bool(row["is_active"]),
execution_count=row["execution_count"],
created_at=row["created_at"],
updated_at=row["updated_at"],
id = row["id"],
tenant_id = row["tenant_id"],
name = row["name"],
description = row["description"],
trigger_type = WorkflowTriggerType(row["trigger_type"]),
trigger_conditions = json.loads(row["trigger_conditions"]),
actions = json.loads(row["actions"]),
is_active = bool(row["is_active"]),
execution_count = row["execution_count"],
created_at = row["created_at"],
updated_at = row["updated_at"],
)
def _row_to_referral_program(self, row) -> ReferralProgram:
"""将数据库行转换为 ReferralProgram"""
return ReferralProgram(
id=row["id"],
tenant_id=row["tenant_id"],
name=row["name"],
description=row["description"],
referrer_reward_type=row["referrer_reward_type"],
referrer_reward_value=row["referrer_reward_value"],
referee_reward_type=row["referee_reward_type"],
referee_reward_value=row["referee_reward_value"],
max_referrals_per_user=row["max_referrals_per_user"],
referral_code_length=row["referral_code_length"],
expiry_days=row["expiry_days"],
is_active=bool(row["is_active"]),
created_at=row["created_at"],
updated_at=row["updated_at"],
id = row["id"],
tenant_id = row["tenant_id"],
name = row["name"],
description = row["description"],
referrer_reward_type = row["referrer_reward_type"],
referrer_reward_value = row["referrer_reward_value"],
referee_reward_type = row["referee_reward_type"],
referee_reward_value = row["referee_reward_value"],
max_referrals_per_user = row["max_referrals_per_user"],
referral_code_length = row["referral_code_length"],
expiry_days = row["expiry_days"],
is_active = bool(row["is_active"]),
created_at = row["created_at"],
updated_at = row["updated_at"],
)
def _row_to_team_incentive(self, row) -> TeamIncentive:
"""将数据库行转换为 TeamIncentive"""
return TeamIncentive(
id=row["id"],
tenant_id=row["tenant_id"],
name=row["name"],
description=row["description"],
target_tier=row["target_tier"],
min_team_size=row["min_team_size"],
incentive_type=row["incentive_type"],
incentive_value=row["incentive_value"],
valid_from=datetime.fromisoformat(row["valid_from"]),
valid_until=datetime.fromisoformat(row["valid_until"]),
is_active=bool(row["is_active"]),
created_at=row["created_at"],
id = row["id"],
tenant_id = row["tenant_id"],
name = row["name"],
description = row["description"],
target_tier = row["target_tier"],
min_team_size = row["min_team_size"],
incentive_type = row["incentive_type"],
incentive_value = row["incentive_value"],
valid_from = datetime.fromisoformat(row["valid_from"]),
valid_until = datetime.fromisoformat(row["valid_until"]),
is_active = bool(row["is_active"]),
created_at = row["created_at"],
)

View File

@@ -104,7 +104,7 @@ class ImageProcessor:
temp_dir: 临时文件目录
"""
self.temp_dir = temp_dir or os.path.join(os.getcwd(), "temp", "images")
os.makedirs(self.temp_dir, exist_ok=True)
os.makedirs(self.temp_dir, exist_ok = True)
def preprocess_image(self, image, image_type: str = None) -> None:
"""
@@ -169,7 +169,7 @@ class ImageProcessor:
gray = image.convert("L")
# 轻微降噪
blurred = gray.filter(ImageFilter.GaussianBlur(radius=1))
blurred = gray.filter(ImageFilter.GaussianBlur(radius = 1))
# 增强对比度
enhancer = ImageEnhance.Contrast(blurred)
@@ -255,10 +255,10 @@ class ImageProcessor:
processed_image = self.preprocess_image(image)
# 执行OCR
text = pytesseract.image_to_string(processed_image, lang=lang)
text = pytesseract.image_to_string(processed_image, lang = lang)
# 获取置信度
data = pytesseract.image_to_data(processed_image, output_type=pytesseract.Output.DICT)
data = pytesseract.image_to_data(processed_image, output_type = pytesseract.Output.DICT)
confidences = [int(c) for c in data["conf"] if int(c) > 0]
avg_confidence = sum(confidences) / len(confidences) if confidences else 0
@@ -288,12 +288,12 @@ class ImageProcessor:
for match in re.finditer(project_pattern, text):
name = match.group(1) or match.group(2)
if name and len(name) > 2:
entities.append(ImageEntity(name=name.strip(), type="PROJECT", confidence=0.7))
entities.append(ImageEntity(name = name.strip(), type = "PROJECT", confidence = 0.7))
# 人名(中文)
name_pattern = r"([\u4e00-\u9fa5]{2,4})(?:先生|女士|总|经理|工程师|老师)"
name_pattern = r"([\u4e00-\u9fa5]{2, 4})(?:先生|女士|总|经理|工程师|老师)"
for match in re.finditer(name_pattern, text):
entities.append(ImageEntity(name=match.group(1), type="PERSON", confidence=0.8))
entities.append(ImageEntity(name = match.group(1), type = "PERSON", confidence = 0.8))
# 技术术语
tech_keywords = [
@@ -314,7 +314,7 @@ class ImageProcessor:
]
for keyword in tech_keywords:
if keyword in text:
entities.append(ImageEntity(name=keyword, type="TECH", confidence=0.9))
entities.append(ImageEntity(name = keyword, type = "TECH", confidence = 0.9))
# 去重
seen = set()
@@ -381,16 +381,16 @@ class ImageProcessor:
if not PIL_AVAILABLE:
return ImageProcessingResult(
image_id=image_id,
image_type="other",
ocr_text="",
description="PIL not available",
entities=[],
relations=[],
width=0,
height=0,
success=False,
error_message="PIL library not available",
image_id = image_id,
image_type = "other",
ocr_text = "",
description = "PIL not available",
entities = [],
relations = [],
width = 0,
height = 0,
success = False,
error_message = "PIL library not available",
)
try:
@@ -421,29 +421,29 @@ class ImageProcessor:
image.save(save_path)
return ImageProcessingResult(
image_id=image_id,
image_type=image_type,
ocr_text=ocr_text,
description=description,
entities=entities,
relations=relations,
width=width,
height=height,
success=True,
image_id = image_id,
image_type = image_type,
ocr_text = ocr_text,
description = description,
entities = entities,
relations = relations,
width = width,
height = height,
success = True,
)
except Exception as e:
return ImageProcessingResult(
image_id=image_id,
image_type="other",
ocr_text="",
description="",
entities=[],
relations=[],
width=0,
height=0,
success=False,
error_message=str(e),
image_id = image_id,
image_type = "other",
ocr_text = "",
description = "",
entities = [],
relations = [],
width = 0,
height = 0,
success = False,
error_message = str(e),
)
def _extract_relations(self, entities: list[ImageEntity], text: str) -> list[ImageRelation]:
@@ -477,10 +477,10 @@ class ImageProcessor:
for j in range(i + 1, len(sentence_entities)):
relations.append(
ImageRelation(
source=sentence_entities[i].name,
target=sentence_entities[j].name,
relation_type="related",
confidence=0.5,
source = sentence_entities[i].name,
target = sentence_entities[j].name,
relation_type = "related",
confidence = 0.5,
)
)
@@ -513,10 +513,10 @@ class ImageProcessor:
failed_count += 1
return BatchProcessingResult(
results=results,
total_count=len(results),
success_count=success_count,
failed_count=failed_count,
results = results,
total_count = len(results),
success_count = success_count,
failed_count = failed_count,
)
def image_to_base64(self, image_data: bytes) -> str:
@@ -550,7 +550,7 @@ class ImageProcessor:
image.thumbnail(size, Image.Resampling.LANCZOS)
buffer = io.BytesIO()
image.save(buffer, format="JPEG")
image.save(buffer, format = "JPEG")
return buffer.getvalue()
except Exception as e:
print(f"Thumbnail generation error: {e}")

View File

@@ -51,7 +51,7 @@ class InferencePath:
class KnowledgeReasoner:
"""知识推理引擎"""
def __init__(self, api_key: str = None, base_url: str = None):
def __init__(self, api_key: str = None, base_url: str = None) -> None:
self.api_key = api_key or KIMI_API_KEY
self.base_url = base_url or KIMI_BASE_URL
self.headers = {
@@ -73,9 +73,9 @@ class KnowledgeReasoner:
async with httpx.AsyncClient() as client:
response = await client.post(
f"{self.base_url}/v1/chat/completions",
headers=self.headers,
json=payload,
timeout=120.0,
headers = self.headers,
json = payload,
timeout = 120.0,
)
response.raise_for_status()
result = response.json()
@@ -127,7 +127,7 @@ class KnowledgeReasoner:
- factual: 事实类问题(是什么、有哪些)
- opinion: 观点类问题(怎么看、态度、评价)"""
content = await self._call_llm(prompt, temperature=0.1)
content = await self._call_llm(prompt, temperature = 0.1)
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
if json_match:
@@ -144,8 +144,8 @@ class KnowledgeReasoner:
"""因果推理 - 分析原因和影响"""
# 构建因果分析提示
entities_str = json.dumps(graph_data.get("entities", []), ensure_ascii=False, indent=2)
relations_str = json.dumps(graph_data.get("relations", []), ensure_ascii=False, indent=2)
entities_str = json.dumps(graph_data.get("entities", []), ensure_ascii = False, indent = 2)
relations_str = json.dumps(graph_data.get("relations", []), ensure_ascii = False, indent = 2)
prompt = f"""基于以下知识图谱进行因果推理分析:
@@ -159,7 +159,7 @@ class KnowledgeReasoner:
{relations_str[:2000]}
## 项目上下文
{json.dumps(project_context, ensure_ascii=False, indent=2)[:1500]}
{json.dumps(project_context, ensure_ascii = False, indent = 2)[:1500]}
请进行因果分析,返回 JSON 格式:
{{
@@ -172,7 +172,7 @@ class KnowledgeReasoner:
"knowledge_gaps": ["缺失信息1"]
}}"""
content = await self._call_llm(prompt, temperature=0.3)
content = await self._call_llm(prompt, temperature = 0.3)
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
@@ -180,23 +180,23 @@ class KnowledgeReasoner:
try:
data = json.loads(json_match.group())
return ReasoningResult(
answer=data.get("answer", ""),
reasoning_type=ReasoningType.CAUSAL,
confidence=data.get("confidence", 0.7),
evidence=[{"text": e} for e in data.get("evidence", [])],
related_entities=[],
gaps=data.get("knowledge_gaps", []),
answer = data.get("answer", ""),
reasoning_type = ReasoningType.CAUSAL,
confidence = data.get("confidence", 0.7),
evidence = [{"text": e} for e in data.get("evidence", [])],
related_entities = [],
gaps = data.get("knowledge_gaps", []),
)
except (json.JSONDecodeError, KeyError):
pass
return ReasoningResult(
answer=content,
reasoning_type=ReasoningType.CAUSAL,
confidence=0.5,
evidence=[],
related_entities=[],
gaps=["无法完成因果推理"],
answer = content,
reasoning_type = ReasoningType.CAUSAL,
confidence = 0.5,
evidence = [],
related_entities = [],
gaps = ["无法完成因果推理"],
)
async def _comparative_reasoning(
@@ -210,10 +210,10 @@ class KnowledgeReasoner:
{query}
## 实体
{json.dumps(graph_data.get("entities", []), ensure_ascii=False, indent=2)[:2000]}
{json.dumps(graph_data.get("entities", []), ensure_ascii = False, indent = 2)[:2000]}
## 关系
{json.dumps(graph_data.get("relations", []), ensure_ascii=False, indent=2)[:1500]}
{json.dumps(graph_data.get("relations", []), ensure_ascii = False, indent = 2)[:1500]}
请进行对比分析,返回 JSON 格式:
{{
@@ -226,7 +226,7 @@ class KnowledgeReasoner:
"knowledge_gaps": []
}}"""
content = await self._call_llm(prompt, temperature=0.3)
content = await self._call_llm(prompt, temperature = 0.3)
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
@@ -234,23 +234,23 @@ class KnowledgeReasoner:
try:
data = json.loads(json_match.group())
return ReasoningResult(
answer=data.get("answer", ""),
reasoning_type=ReasoningType.COMPARATIVE,
confidence=data.get("confidence", 0.7),
evidence=[{"text": e} for e in data.get("evidence", [])],
related_entities=[],
gaps=data.get("knowledge_gaps", []),
answer = data.get("answer", ""),
reasoning_type = ReasoningType.COMPARATIVE,
confidence = data.get("confidence", 0.7),
evidence = [{"text": e} for e in data.get("evidence", [])],
related_entities = [],
gaps = data.get("knowledge_gaps", []),
)
except (json.JSONDecodeError, KeyError):
pass
return ReasoningResult(
answer=content,
reasoning_type=ReasoningType.COMPARATIVE,
confidence=0.5,
evidence=[],
related_entities=[],
gaps=[],
answer = content,
reasoning_type = ReasoningType.COMPARATIVE,
confidence = 0.5,
evidence = [],
related_entities = [],
gaps = [],
)
async def _temporal_reasoning(
@@ -264,10 +264,10 @@ class KnowledgeReasoner:
{query}
## 项目时间线
{json.dumps(project_context.get("timeline", []), ensure_ascii=False, indent=2)[:2000]}
{json.dumps(project_context.get("timeline", []), ensure_ascii = False, indent = 2)[:2000]}
## 实体提及历史
{json.dumps(graph_data.get("entities", []), ensure_ascii=False, indent=2)[:1500]}
{json.dumps(graph_data.get("entities", []), ensure_ascii = False, indent = 2)[:1500]}
请进行时序分析,返回 JSON 格式:
{{
@@ -280,7 +280,7 @@ class KnowledgeReasoner:
"knowledge_gaps": []
}}"""
content = await self._call_llm(prompt, temperature=0.3)
content = await self._call_llm(prompt, temperature = 0.3)
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
@@ -288,23 +288,23 @@ class KnowledgeReasoner:
try:
data = json.loads(json_match.group())
return ReasoningResult(
answer=data.get("answer", ""),
reasoning_type=ReasoningType.TEMPORAL,
confidence=data.get("confidence", 0.7),
evidence=[{"text": e} for e in data.get("evidence", [])],
related_entities=[],
gaps=data.get("knowledge_gaps", []),
answer = data.get("answer", ""),
reasoning_type = ReasoningType.TEMPORAL,
confidence = data.get("confidence", 0.7),
evidence = [{"text": e} for e in data.get("evidence", [])],
related_entities = [],
gaps = data.get("knowledge_gaps", []),
)
except (json.JSONDecodeError, KeyError):
pass
return ReasoningResult(
answer=content,
reasoning_type=ReasoningType.TEMPORAL,
confidence=0.5,
evidence=[],
related_entities=[],
gaps=[],
answer = content,
reasoning_type = ReasoningType.TEMPORAL,
confidence = 0.5,
evidence = [],
related_entities = [],
gaps = [],
)
async def _associative_reasoning(
@@ -318,10 +318,10 @@ class KnowledgeReasoner:
{query}
## 实体
{json.dumps(graph_data.get("entities", [])[:20], ensure_ascii=False, indent=2)}
{json.dumps(graph_data.get("entities", [])[:20], ensure_ascii = False, indent = 2)}
## 关系
{json.dumps(graph_data.get("relations", [])[:30], ensure_ascii=False, indent=2)}
{json.dumps(graph_data.get("relations", [])[:30], ensure_ascii = False, indent = 2)}
请进行关联推理,发现隐含联系,返回 JSON 格式:
{{
@@ -334,7 +334,7 @@ class KnowledgeReasoner:
"knowledge_gaps": []
}}"""
content = await self._call_llm(prompt, temperature=0.4)
content = await self._call_llm(prompt, temperature = 0.4)
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
@@ -342,23 +342,23 @@ class KnowledgeReasoner:
try:
data = json.loads(json_match.group())
return ReasoningResult(
answer=data.get("answer", ""),
reasoning_type=ReasoningType.ASSOCIATIVE,
confidence=data.get("confidence", 0.7),
evidence=[{"text": e} for e in data.get("evidence", [])],
related_entities=[],
gaps=data.get("knowledge_gaps", []),
answer = data.get("answer", ""),
reasoning_type = ReasoningType.ASSOCIATIVE,
confidence = data.get("confidence", 0.7),
evidence = [{"text": e} for e in data.get("evidence", [])],
related_entities = [],
gaps = data.get("knowledge_gaps", []),
)
except (json.JSONDecodeError, KeyError):
pass
return ReasoningResult(
answer=content,
reasoning_type=ReasoningType.ASSOCIATIVE,
confidence=0.5,
evidence=[],
related_entities=[],
gaps=[],
answer = content,
reasoning_type = ReasoningType.ASSOCIATIVE,
confidence = 0.5,
evidence = [],
related_entities = [],
gaps = [],
)
def find_inference_paths(
@@ -400,10 +400,10 @@ class KnowledgeReasoner:
# 找到一条路径
paths.append(
InferencePath(
start_entity=start_entity,
end_entity=end_entity,
path=path,
strength=self._calculate_path_strength(path),
start_entity = start_entity,
end_entity = end_entity,
path = path,
strength = self._calculate_path_strength(path),
)
)
continue
@@ -424,7 +424,7 @@ class KnowledgeReasoner:
queue.append((next_entity, new_path))
# 按强度排序
paths.sort(key=lambda p: p.strength, reverse=True)
paths.sort(key = lambda p: p.strength, reverse = True)
return paths
def _calculate_path_strength(self, path: list[dict]) -> float:
@@ -467,7 +467,7 @@ class KnowledgeReasoner:
prompt = f"""请对以下项目进行{type_prompts.get(summary_type, "全面总结")}
## 项目信息
{json.dumps(project_context, ensure_ascii=False, indent=2)[:3000]}
{json.dumps(project_context, ensure_ascii = False, indent = 2)[:3000]}
## 知识图谱
实体数: {len(graph_data.get("entities", []))}
@@ -483,7 +483,7 @@ class KnowledgeReasoner:
"confidence": 0.85
}}"""
content = await self._call_llm(prompt, temperature=0.3)
content = await self._call_llm(prompt, temperature = 0.3)
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)

View File

@@ -41,7 +41,7 @@ class RelationExtractionResult:
class LLMClient:
"""Kimi API 客户端"""
def __init__(self, api_key: str = None, base_url: str = None):
def __init__(self, api_key: str = None, base_url: str = None) -> None:
self.api_key = api_key or KIMI_API_KEY
self.base_url = base_url or KIMI_BASE_URL
self.headers = {
@@ -66,9 +66,9 @@ class LLMClient:
async with httpx.AsyncClient() as client:
response = await client.post(
f"{self.base_url}/v1/chat/completions",
headers=self.headers,
json=payload,
timeout=120.0,
headers = self.headers,
json = payload,
timeout = 120.0,
)
response.raise_for_status()
result = response.json()
@@ -92,9 +92,9 @@ class LLMClient:
async with client.stream(
"POST",
f"{self.base_url}/v1/chat/completions",
headers=self.headers,
json=payload,
timeout=120.0,
headers = self.headers,
json = payload,
timeout = 120.0,
) as response:
response.raise_for_status()
async for line in response.aiter_lines():
@@ -139,8 +139,8 @@ class LLMClient:
]
}}"""
messages = [ChatMessage(role="user", content=prompt)]
content = await self.chat(messages, temperature=0.1)
messages = [ChatMessage(role = "user", content = prompt)]
content = await self.chat(messages, temperature = 0.1)
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
if not json_match:
@@ -150,19 +150,19 @@ class LLMClient:
data = json.loads(json_match.group())
entities = [
EntityExtractionResult(
name=e["name"],
type=e.get("type", "OTHER"),
definition=e.get("definition", ""),
confidence=e.get("confidence", 0.8),
name = e["name"],
type = e.get("type", "OTHER"),
definition = e.get("definition", ""),
confidence = e.get("confidence", 0.8),
)
for e in data.get("entities", [])
]
relations = [
RelationExtractionResult(
source=r["source"],
target=r["target"],
type=r.get("type", "related"),
confidence=r.get("confidence", 0.8),
source = r["source"],
target = r["target"],
type = r.get("type", "related"),
confidence = r.get("confidence", 0.8),
)
for r in data.get("relations", [])
]
@@ -176,7 +176,7 @@ class LLMClient:
prompt = f"""你是一个专业的项目分析助手。基于以下项目信息回答问题:
## 项目信息
{json.dumps(project_context, ensure_ascii=False, indent=2)}
{json.dumps(project_context, ensure_ascii = False, indent = 2)}
## 相关上下文
{context[:4000]}
@@ -188,19 +188,19 @@ class LLMClient:
messages = [
ChatMessage(
role="system", content="你是一个专业的项目分析助手,擅长从会议记录中提取洞察。"
role = "system", content = "你是一个专业的项目分析助手,擅长从会议记录中提取洞察。"
),
ChatMessage(role="user", content=prompt),
ChatMessage(role = "user", content = prompt),
]
return await self.chat(messages, temperature=0.3)
return await self.chat(messages, temperature = 0.3)
async def agent_command(self, command: str, project_context: dict) -> dict:
"""Agent 指令解析 - 将自然语言指令转换为结构化操作"""
prompt = f"""解析以下用户指令,转换为结构化操作:
## 项目信息
{json.dumps(project_context, ensure_ascii=False, indent=2)}
{json.dumps(project_context, ensure_ascii = False, indent = 2)}
## 用户指令
{command}
@@ -221,8 +221,8 @@ class LLMClient:
- create_relation: 创建关系params 包含 source(源实体), target(目标实体), relation_type(关系类型)
"""
messages = [ChatMessage(role="user", content=prompt)]
content = await self.chat(messages, temperature=0.1)
messages = [ChatMessage(role = "user", content = prompt)]
content = await self.chat(messages, temperature = 0.1)
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
if not json_match:
@@ -255,8 +255,8 @@ class LLMClient:
用中文回答,结构清晰。"""
messages = [ChatMessage(role="user", content=prompt)]
return await self.chat(messages, temperature=0.3)
messages = [ChatMessage(role = "user", content = prompt)]
return await self.chat(messages, temperature = 0.3)
# Singleton instance

View File

@@ -260,8 +260,8 @@ class LocalizationManager:
"date_format": "MM/dd/yyyy",
"time_format": "h:mm a",
"datetime_format": "MM/dd/yyyy h:mm a",
"number_format": "#,##0.##",
"currency_format": "$#,##0.00",
"number_format": "#, ##0.##",
"currency_format": "$#, ##0.00",
"first_day_of_week": 0,
"calendar_type": CalendarType.GREGORIAN.value,
},
@@ -272,8 +272,8 @@ class LocalizationManager:
"date_format": "yyyy-MM-dd",
"time_format": "HH:mm",
"datetime_format": "yyyy-MM-dd HH:mm",
"number_format": "#,##0.##",
"currency_format": "¥#,##0.00",
"number_format": "#, ##0.##",
"currency_format": "¥#, ##0.00",
"first_day_of_week": 1,
"calendar_type": CalendarType.GREGORIAN.value,
},
@@ -284,8 +284,8 @@ class LocalizationManager:
"date_format": "yyyy/MM/dd",
"time_format": "HH:mm",
"datetime_format": "yyyy/MM/dd HH:mm",
"number_format": "#,##0.##",
"currency_format": "NT$#,##0.00",
"number_format": "#, ##0.##",
"currency_format": "NT$#, ##0.00",
"first_day_of_week": 0,
"calendar_type": CalendarType.GREGORIAN.value,
},
@@ -296,8 +296,8 @@ class LocalizationManager:
"date_format": "yyyy/MM/dd",
"time_format": "HH:mm",
"datetime_format": "yyyy/MM/dd HH:mm",
"number_format": "#,##0.##",
"currency_format": "¥#,##0",
"number_format": "#, ##0.##",
"currency_format": "¥#, ##0",
"first_day_of_week": 0,
"calendar_type": CalendarType.GREGORIAN.value,
},
@@ -308,8 +308,8 @@ class LocalizationManager:
"date_format": "yyyy. MM. dd",
"time_format": "HH:mm",
"datetime_format": "yyyy. MM. dd HH:mm",
"number_format": "#,##0.##",
"currency_format": "₩#,##0",
"number_format": "#, ##0.##",
"currency_format": "₩#, ##0",
"first_day_of_week": 0,
"calendar_type": CalendarType.GREGORIAN.value,
},
@@ -320,8 +320,8 @@ class LocalizationManager:
"date_format": "dd.MM.yyyy",
"time_format": "HH:mm",
"datetime_format": "dd.MM.yyyy HH:mm",
"number_format": "#,##0.##",
"currency_format": "#,##0.00 €",
"number_format": "#, ##0.##",
"currency_format": "#, ##0.00 €",
"first_day_of_week": 1,
"calendar_type": CalendarType.GREGORIAN.value,
},
@@ -332,8 +332,8 @@ class LocalizationManager:
"date_format": "dd/MM/yyyy",
"time_format": "HH:mm",
"datetime_format": "dd/MM/yyyy HH:mm",
"number_format": "#,##0.##",
"currency_format": "#,##0.00 €",
"number_format": "#, ##0.##",
"currency_format": "#, ##0.00 €",
"first_day_of_week": 1,
"calendar_type": CalendarType.GREGORIAN.value,
},
@@ -344,8 +344,8 @@ class LocalizationManager:
"date_format": "dd/MM/yyyy",
"time_format": "HH:mm",
"datetime_format": "dd/MM/yyyy HH:mm",
"number_format": "#,##0.##",
"currency_format": "#,##0.00 €",
"number_format": "#, ##0.##",
"currency_format": "#, ##0.00 €",
"first_day_of_week": 1,
"calendar_type": CalendarType.GREGORIAN.value,
},
@@ -356,8 +356,8 @@ class LocalizationManager:
"date_format": "dd/MM/yyyy",
"time_format": "HH:mm",
"datetime_format": "dd/MM/yyyy HH:mm",
"number_format": "#,##0.##",
"currency_format": "R$#,##0.00",
"number_format": "#, ##0.##",
"currency_format": "R$#, ##0.00",
"first_day_of_week": 0,
"calendar_type": CalendarType.GREGORIAN.value,
},
@@ -368,8 +368,8 @@ class LocalizationManager:
"date_format": "dd.MM.yyyy",
"time_format": "HH:mm",
"datetime_format": "dd.MM.yyyy HH:mm",
"number_format": "#,##0.##",
"currency_format": "#,##0.00 ₽",
"number_format": "#, ##0.##",
"currency_format": "#, ##0.00 ₽",
"first_day_of_week": 1,
"calendar_type": CalendarType.GREGORIAN.value,
},
@@ -380,8 +380,8 @@ class LocalizationManager:
"date_format": "dd/MM/yyyy",
"time_format": "hh:mm a",
"datetime_format": "dd/MM/yyyy hh:mm a",
"number_format": "#,##0.##",
"currency_format": "#,##0.00 ر.س",
"number_format": "#, ##0.##",
"currency_format": "#, ##0.00 ر.س",
"first_day_of_week": 6,
"calendar_type": CalendarType.ISLAMIC.value,
},
@@ -392,8 +392,8 @@ class LocalizationManager:
"date_format": "dd/MM/yyyy",
"time_format": "hh:mm a",
"datetime_format": "dd/MM/yyyy hh:mm a",
"number_format": "#,##0.##",
"currency_format": "₹#,##0.00",
"number_format": "#, ##0.##",
"currency_format": "₹#, ##0.00",
"first_day_of_week": 0,
"calendar_type": CalendarType.INDIAN.value,
},
@@ -719,7 +719,7 @@ class LocalizationManager:
},
}
def __init__(self, db_path: str = "insightflow.db"):
def __init__(self, db_path: str = "insightflow.db") -> None:
self.db_path = db_path
self._is_memory_db = db_path == ":memory:"
self._conn = None
@@ -736,11 +736,11 @@ class LocalizationManager:
conn.row_factory = sqlite3.Row
return conn
def _close_if_file_db(self, conn):
def _close_if_file_db(self, conn) -> None:
if not self._is_memory_db:
conn.close()
def _init_db(self):
def _init_db(self) -> None:
conn = self._get_connection()
try:
cursor = conn.cursor()
@@ -813,7 +813,7 @@ class LocalizationManager:
CREATE TABLE IF NOT EXISTS currency_configs (
code TEXT PRIMARY KEY, name TEXT NOT NULL, name_local TEXT DEFAULT '{}', symbol TEXT NOT NULL,
decimal_places INTEGER DEFAULT 2, decimal_separator TEXT DEFAULT '.',
thousands_separator TEXT DEFAULT ',', is_active INTEGER DEFAULT 1
thousands_separator TEXT DEFAULT ', ', is_active INTEGER DEFAULT 1
)
""")
cursor.execute("""
@@ -863,7 +863,7 @@ class LocalizationManager:
finally:
self._close_if_file_db(conn)
def _init_default_data(self):
def _init_default_data(self) -> None:
conn = self._get_connection()
try:
cursor = conn.cursor()
@@ -1054,7 +1054,7 @@ class LocalizationManager:
conn = self._get_connection()
try:
cursor = conn.cursor()
query = "SELECT * FROM translations WHERE 1=1"
query = "SELECT * FROM translations WHERE 1 = 1"
params = []
if language:
query += " AND language = ?"
@@ -1074,7 +1074,7 @@ class LocalizationManager:
conn = self._get_connection()
try:
cursor = conn.cursor()
cursor.execute("SELECT * FROM language_configs WHERE code = ?", (code,))
cursor.execute("SELECT * FROM language_configs WHERE code = ?", (code, ))
row = cursor.fetchone()
if row:
return self._row_to_language_config(row)
@@ -1100,7 +1100,7 @@ class LocalizationManager:
conn = self._get_connection()
try:
cursor = conn.cursor()
cursor.execute("SELECT * FROM data_centers WHERE id = ?", (dc_id,))
cursor.execute("SELECT * FROM data_centers WHERE id = ?", (dc_id, ))
row = cursor.fetchone()
if row:
return self._row_to_data_center(row)
@@ -1112,7 +1112,7 @@ class LocalizationManager:
conn = self._get_connection()
try:
cursor = conn.cursor()
cursor.execute("SELECT * FROM data_centers WHERE region_code = ?", (region_code,))
cursor.execute("SELECT * FROM data_centers WHERE region_code = ?", (region_code, ))
row = cursor.fetchone()
if row:
return self._row_to_data_center(row)
@@ -1126,7 +1126,7 @@ class LocalizationManager:
conn = self._get_connection()
try:
cursor = conn.cursor()
query = "SELECT * FROM data_centers WHERE 1=1"
query = "SELECT * FROM data_centers WHERE 1 = 1"
params = []
if status:
query += " AND status = ?"
@@ -1146,7 +1146,7 @@ class LocalizationManager:
try:
cursor = conn.cursor()
cursor.execute(
"SELECT * FROM tenant_data_center_mappings WHERE tenant_id = ?", (tenant_id,)
"SELECT * FROM tenant_data_center_mappings WHERE tenant_id = ?", (tenant_id, )
)
row = cursor.fetchone()
if row:
@@ -1166,7 +1166,7 @@ class LocalizationManager:
SELECT * FROM data_centers WHERE supported_regions LIKE ? AND status = 'active'
ORDER BY priority LIMIT 1
""",
(f'%"{region_code}"%',),
(f'%"{region_code}"%', ),
)
row = cursor.fetchone()
if not row:
@@ -1182,7 +1182,7 @@ class LocalizationManager:
"""
SELECT * FROM data_centers WHERE id != ? AND status = 'active' ORDER BY priority LIMIT 1
""",
(primary_dc_id,),
(primary_dc_id, ),
)
secondary_row = cursor.fetchone()
secondary_dc_id = secondary_row["id"] if secondary_row else None
@@ -1222,7 +1222,7 @@ class LocalizationManager:
try:
cursor = conn.cursor()
cursor.execute(
"SELECT * FROM localized_payment_methods WHERE provider = ?", (provider,)
"SELECT * FROM localized_payment_methods WHERE provider = ?", (provider, )
)
row = cursor.fetchone()
if row:
@@ -1237,7 +1237,7 @@ class LocalizationManager:
conn = self._get_connection()
try:
cursor = conn.cursor()
query = "SELECT * FROM localized_payment_methods WHERE 1=1"
query = "SELECT * FROM localized_payment_methods WHERE 1 = 1"
params = []
if active_only:
query += " AND is_active = 1"
@@ -1257,7 +1257,7 @@ class LocalizationManager:
def get_localized_payment_methods(
self, country_code: str, language: str = "en"
) -> list[dict[str, Any]]:
methods = self.list_payment_methods(country_code=country_code)
methods = self.list_payment_methods(country_code = country_code)
result = []
for method in methods:
name_local = method.name_local.get(language, method.name)
@@ -1278,7 +1278,7 @@ class LocalizationManager:
conn = self._get_connection()
try:
cursor = conn.cursor()
cursor.execute("SELECT * FROM country_configs WHERE code = ?", (code,))
cursor.execute("SELECT * FROM country_configs WHERE code = ?", (code, ))
row = cursor.fetchone()
if row:
return self._row_to_country_config(row)
@@ -1292,7 +1292,7 @@ class LocalizationManager:
conn = self._get_connection()
try:
cursor = conn.cursor()
query = "SELECT * FROM country_configs WHERE 1=1"
query = "SELECT * FROM country_configs WHERE 1 = 1"
params = []
if active_only:
query += " AND is_active = 1"
@@ -1332,11 +1332,11 @@ class LocalizationManager:
try:
locale = Locale.parse(language.replace("_", "-"))
if format_type == "date":
return dates.format_date(dt, locale=locale)
return dates.format_date(dt, locale = locale)
elif format_type == "time":
return dates.format_time(dt, locale=locale)
return dates.format_time(dt, locale = locale)
else:
return dates.format_datetime(dt, locale=locale)
return dates.format_datetime(dt, locale = locale)
except (ValueError, AttributeError):
pass
return dt.strftime(fmt)
@@ -1352,13 +1352,13 @@ class LocalizationManager:
try:
locale = Locale.parse(language.replace("_", "-"))
return numbers.format_decimal(
number, locale=locale, decimal_quantization=(decimal_places is not None)
number, locale = locale, decimal_quantization = (decimal_places is not None)
)
except (ValueError, AttributeError):
pass
if decimal_places is not None:
return f"{number:,.{decimal_places}f}"
return f"{number:,}"
return f"{number:, .{decimal_places}f}"
return f"{number:, }"
except Exception as e:
logger.error(f"Error formatting number: {e}")
return str(number)
@@ -1368,10 +1368,10 @@ class LocalizationManager:
if BABEL_AVAILABLE:
try:
locale = Locale.parse(language.replace("_", "-"))
return numbers.format_currency(amount, currency, locale=locale)
return numbers.format_currency(amount, currency, locale = locale)
except (ValueError, AttributeError):
pass
return f"{currency} {amount:,.2f}"
return f"{currency} {amount:, .2f}"
except Exception as e:
logger.error(f"Error formatting currency: {e}")
return f"{currency} {amount:.2f}"
@@ -1408,7 +1408,7 @@ class LocalizationManager:
conn = self._get_connection()
try:
cursor = conn.cursor()
cursor.execute("SELECT * FROM localization_settings WHERE tenant_id = ?", (tenant_id,))
cursor.execute("SELECT * FROM localization_settings WHERE tenant_id = ?", (tenant_id, ))
row = cursor.fetchone()
if row:
return self._row_to_localization_settings(row)
@@ -1453,7 +1453,7 @@ class LocalizationManager:
default_timezone,
lang_config.date_format if lang_config else "%Y-%m-%d",
lang_config.time_format if lang_config else "%H:%M",
lang_config.number_format if lang_config else "#,##0.##",
lang_config.number_format if lang_config else "#, ##0.##",
lang_config.calendar_type if lang_config else CalendarType.GREGORIAN.value,
lang_config.first_day_of_week if lang_config else 1,
region_code,
@@ -1517,7 +1517,7 @@ class LocalizationManager:
) -> dict[str, str]:
preferences = {"language": "en", "country": "US", "timezone": "UTC", "currency": "USD"}
if accept_language:
langs = accept_language.split(",")
langs = accept_language.split(", ")
for lang in langs:
lang_code = lang.split(";")[0].strip().replace("-", "_")
lang_config = self.get_language_config(lang_code)
@@ -1536,25 +1536,25 @@ class LocalizationManager:
def _row_to_translation(self, row: sqlite3.Row) -> Translation:
return Translation(
id=row["id"],
key=row["key"],
language=row["language"],
value=row["value"],
namespace=row["namespace"],
context=row["context"],
created_at=(
id = row["id"],
key = row["key"],
language = row["language"],
value = row["value"],
namespace = row["namespace"],
context = row["context"],
created_at = (
datetime.fromisoformat(row["created_at"])
if isinstance(row["created_at"], str)
else row["created_at"]
),
updated_at=(
updated_at = (
datetime.fromisoformat(row["updated_at"])
if isinstance(row["updated_at"], str)
else row["updated_at"]
),
is_reviewed=bool(row["is_reviewed"]),
reviewed_by=row["reviewed_by"],
reviewed_at=(
is_reviewed = bool(row["is_reviewed"]),
reviewed_by = row["reviewed_by"],
reviewed_at = (
datetime.fromisoformat(row["reviewed_at"])
if row["reviewed_at"] and isinstance(row["reviewed_at"], str)
else row["reviewed_at"]
@@ -1563,39 +1563,39 @@ class LocalizationManager:
def _row_to_language_config(self, row: sqlite3.Row) -> LanguageConfig:
return LanguageConfig(
code=row["code"],
name=row["name"],
name_local=row["name_local"],
is_rtl=bool(row["is_rtl"]),
is_active=bool(row["is_active"]),
is_default=bool(row["is_default"]),
fallback_language=row["fallback_language"],
date_format=row["date_format"],
time_format=row["time_format"],
datetime_format=row["datetime_format"],
number_format=row["number_format"],
currency_format=row["currency_format"],
first_day_of_week=row["first_day_of_week"],
calendar_type=row["calendar_type"],
code = row["code"],
name = row["name"],
name_local = row["name_local"],
is_rtl = bool(row["is_rtl"]),
is_active = bool(row["is_active"]),
is_default = bool(row["is_default"]),
fallback_language = row["fallback_language"],
date_format = row["date_format"],
time_format = row["time_format"],
datetime_format = row["datetime_format"],
number_format = row["number_format"],
currency_format = row["currency_format"],
first_day_of_week = row["first_day_of_week"],
calendar_type = row["calendar_type"],
)
def _row_to_data_center(self, row: sqlite3.Row) -> DataCenter:
return DataCenter(
id=row["id"],
region_code=row["region_code"],
name=row["name"],
location=row["location"],
endpoint=row["endpoint"],
status=row["status"],
priority=row["priority"],
supported_regions=json.loads(row["supported_regions"] or "[]"),
capabilities=json.loads(row["capabilities"] or "{}"),
created_at=(
id = row["id"],
region_code = row["region_code"],
name = row["name"],
location = row["location"],
endpoint = row["endpoint"],
status = row["status"],
priority = row["priority"],
supported_regions = json.loads(row["supported_regions"] or "[]"),
capabilities = json.loads(row["capabilities"] or "{}"),
created_at = (
datetime.fromisoformat(row["created_at"])
if isinstance(row["created_at"], str)
else row["created_at"]
),
updated_at=(
updated_at = (
datetime.fromisoformat(row["updated_at"])
if isinstance(row["updated_at"], str)
else row["updated_at"]
@@ -1604,18 +1604,18 @@ class LocalizationManager:
def _row_to_tenant_dc_mapping(self, row: sqlite3.Row) -> TenantDataCenterMapping:
return TenantDataCenterMapping(
id=row["id"],
tenant_id=row["tenant_id"],
primary_dc_id=row["primary_dc_id"],
secondary_dc_id=row["secondary_dc_id"],
region_code=row["region_code"],
data_residency=row["data_residency"],
created_at=(
id = row["id"],
tenant_id = row["tenant_id"],
primary_dc_id = row["primary_dc_id"],
secondary_dc_id = row["secondary_dc_id"],
region_code = row["region_code"],
data_residency = row["data_residency"],
created_at = (
datetime.fromisoformat(row["created_at"])
if isinstance(row["created_at"], str)
else row["created_at"]
),
updated_at=(
updated_at = (
datetime.fromisoformat(row["updated_at"])
if isinstance(row["updated_at"], str)
else row["updated_at"]
@@ -1624,24 +1624,24 @@ class LocalizationManager:
def _row_to_payment_method(self, row: sqlite3.Row) -> LocalizedPaymentMethod:
return LocalizedPaymentMethod(
id=row["id"],
provider=row["provider"],
name=row["name"],
name_local=json.loads(row["name_local"] or "{}"),
supported_countries=json.loads(row["supported_countries"] or "[]"),
supported_currencies=json.loads(row["supported_currencies"] or "[]"),
is_active=bool(row["is_active"]),
config=json.loads(row["config"] or "{}"),
icon_url=row["icon_url"],
display_order=row["display_order"],
min_amount=row["min_amount"],
max_amount=row["max_amount"],
created_at=(
id = row["id"],
provider = row["provider"],
name = row["name"],
name_local = json.loads(row["name_local"] or "{}"),
supported_countries = json.loads(row["supported_countries"] or "[]"),
supported_currencies = json.loads(row["supported_currencies"] or "[]"),
is_active = bool(row["is_active"]),
config = json.loads(row["config"] or "{}"),
icon_url = row["icon_url"],
display_order = row["display_order"],
min_amount = row["min_amount"],
max_amount = row["max_amount"],
created_at = (
datetime.fromisoformat(row["created_at"])
if isinstance(row["created_at"], str)
else row["created_at"]
),
updated_at=(
updated_at = (
datetime.fromisoformat(row["updated_at"])
if isinstance(row["updated_at"], str)
else row["updated_at"]
@@ -1650,48 +1650,48 @@ class LocalizationManager:
def _row_to_country_config(self, row: sqlite3.Row) -> CountryConfig:
return CountryConfig(
code=row["code"],
code3=row["code3"],
name=row["name"],
name_local=json.loads(row["name_local"] or "{}"),
region=row["region"],
default_language=row["default_language"],
supported_languages=json.loads(row["supported_languages"] or "[]"),
default_currency=row["default_currency"],
supported_currencies=json.loads(row["supported_currencies"] or "[]"),
timezone=row["timezone"],
calendar_type=row["calendar_type"],
date_format=row["date_format"],
time_format=row["time_format"],
number_format=row["number_format"],
address_format=row["address_format"],
phone_format=row["phone_format"],
vat_rate=row["vat_rate"],
is_active=bool(row["is_active"]),
code = row["code"],
code3 = row["code3"],
name = row["name"],
name_local = json.loads(row["name_local"] or "{}"),
region = row["region"],
default_language = row["default_language"],
supported_languages = json.loads(row["supported_languages"] or "[]"),
default_currency = row["default_currency"],
supported_currencies = json.loads(row["supported_currencies"] or "[]"),
timezone = row["timezone"],
calendar_type = row["calendar_type"],
date_format = row["date_format"],
time_format = row["time_format"],
number_format = row["number_format"],
address_format = row["address_format"],
phone_format = row["phone_format"],
vat_rate = row["vat_rate"],
is_active = bool(row["is_active"]),
)
def _row_to_localization_settings(self, row: sqlite3.Row) -> LocalizationSettings:
return LocalizationSettings(
id=row["id"],
tenant_id=row["tenant_id"],
default_language=row["default_language"],
supported_languages=json.loads(row["supported_languages"] or '["en"]'),
default_currency=row["default_currency"],
supported_currencies=json.loads(row["supported_currencies"] or '["USD"]'),
default_timezone=row["default_timezone"],
default_date_format=row["default_date_format"],
default_time_format=row["default_time_format"],
default_number_format=row["default_number_format"],
calendar_type=row["calendar_type"],
first_day_of_week=row["first_day_of_week"],
region_code=row["region_code"],
data_residency=row["data_residency"],
created_at=(
id = row["id"],
tenant_id = row["tenant_id"],
default_language = row["default_language"],
supported_languages = json.loads(row["supported_languages"] or '["en"]'),
default_currency = row["default_currency"],
supported_currencies = json.loads(row["supported_currencies"] or '["USD"]'),
default_timezone = row["default_timezone"],
default_date_format = row["default_date_format"],
default_time_format = row["default_time_format"],
default_number_format = row["default_number_format"],
calendar_type = row["calendar_type"],
first_day_of_week = row["first_day_of_week"],
region_code = row["region_code"],
data_residency = row["data_residency"],
created_at = (
datetime.fromisoformat(row["created_at"])
if isinstance(row["created_at"], str)
else row["created_at"]
),
updated_at=(
updated_at = (
datetime.fromisoformat(row["updated_at"])
if isinstance(row["updated_at"], str)
else row["updated_at"]

View File

@@ -32,7 +32,7 @@ class MultimodalEntity:
confidence: float
modality_features: dict = None # 模态特定特征
def __post_init__(self):
def __post_init__(self) -> None:
if self.modality_features is None:
self.modality_features = {}
@@ -200,11 +200,11 @@ class MultimodalEntityLinker:
if best_match:
return AlignmentResult(
entity_id=query_entity.get("id"),
matched_entity_id=best_match.get("id"),
similarity=best_similarity,
match_type=best_match_type,
confidence=best_similarity,
entity_id = query_entity.get("id"),
matched_entity_id = best_match.get("id"),
similarity = best_similarity,
match_type = best_match_type,
confidence = best_similarity,
)
return None
@@ -255,15 +255,15 @@ class MultimodalEntityLinker:
if result and result.matched_entity_id:
link = EntityLink(
id=str(uuid.uuid4())[:UUID_LENGTH],
project_id=project_id,
source_entity_id=ent1.get("id"),
target_entity_id=result.matched_entity_id,
link_type="same_as" if result.similarity > 0.95 else "related_to",
source_modality=mod1,
target_modality=mod2,
confidence=result.confidence,
evidence=f"Cross-modal alignment: {result.match_type}",
id = str(uuid.uuid4())[:UUID_LENGTH],
project_id = project_id,
source_entity_id = ent1.get("id"),
target_entity_id = result.matched_entity_id,
link_type = "same_as" if result.similarity > 0.95 else "related_to",
source_modality = mod1,
target_modality = mod2,
confidence = result.confidence,
evidence = f"Cross-modal alignment: {result.match_type}",
)
links.append(link)
@@ -319,7 +319,7 @@ class MultimodalEntityLinker:
# 选择最佳定义(最长的那个)
best_definition = (
max(fused_properties["definitions"], key=len) if fused_properties["definitions"] else ""
max(fused_properties["definitions"], key = len) if fused_properties["definitions"] else ""
)
# 选择最佳名称(最常见的那个)
@@ -330,9 +330,9 @@ class MultimodalEntityLinker:
# 构建融合结果
return FusionResult(
canonical_entity_id=entity_id,
merged_entity_ids=merged_ids,
fused_properties={
canonical_entity_id = entity_id,
merged_entity_ids = merged_ids,
fused_properties = {
"name": best_name,
"definition": best_definition,
"aliases": list(fused_properties["aliases"]),
@@ -340,8 +340,8 @@ class MultimodalEntityLinker:
"modalities": list(fused_properties["modalities"]),
"contexts": fused_properties["contexts"][:10], # 最多10个上下文
},
source_modalities=list(fused_properties["modalities"]),
confidence=min(1.0, len(linked_entities) * 0.2 + 0.5),
source_modalities = list(fused_properties["modalities"]),
confidence = min(1.0, len(linked_entities) * 0.2 + 0.5),
)
def detect_entity_conflicts(self, entities: list[dict]) -> list[dict]:
@@ -441,7 +441,7 @@ class MultimodalEntityLinker:
)
# 按相似度排序
suggestions.sort(key=lambda x: x["similarity"], reverse=True)
suggestions.sort(key = lambda x: x["similarity"], reverse = True)
return suggestions
@@ -469,14 +469,14 @@ class MultimodalEntityLinker:
多模态实体记录
"""
return MultimodalEntity(
id=str(uuid.uuid4())[:UUID_LENGTH],
entity_id=entity_id,
project_id=project_id,
name="", # 将在后续填充
source_type=source_type,
source_id=source_id,
mention_context=mention_context,
confidence=confidence,
id = str(uuid.uuid4())[:UUID_LENGTH],
entity_id = entity_id,
project_id = project_id,
name = "", # 将在后续填充
source_type = source_type,
source_id = source_id,
mention_context = mention_context,
confidence = confidence,
)
def analyze_modality_distribution(self, multimodal_entities: list[MultimodalEntity]) -> dict:

View File

@@ -52,7 +52,7 @@ class VideoFrame:
ocr_confidence: float = 0.0
entities_detected: list[dict] = None
def __post_init__(self):
def __post_init__(self) -> None:
if self.entities_detected is None:
self.entities_detected = []
@@ -76,7 +76,7 @@ class VideoInfo:
error_message: str = ""
metadata: dict = None
def __post_init__(self):
def __post_init__(self) -> None:
if self.metadata is None:
self.metadata = {}
@@ -112,9 +112,9 @@ class MultimodalProcessor:
self.audio_dir = os.path.join(self.temp_dir, "audio")
# 创建目录
os.makedirs(self.video_dir, exist_ok=True)
os.makedirs(self.frames_dir, exist_ok=True)
os.makedirs(self.audio_dir, exist_ok=True)
os.makedirs(self.video_dir, exist_ok = True)
os.makedirs(self.frames_dir, exist_ok = True)
os.makedirs(self.audio_dir, exist_ok = True)
def extract_video_info(self, video_path: str) -> dict:
"""
@@ -152,14 +152,14 @@ class MultimodalProcessor:
"-v",
"error",
"-show_entries",
"format=duration,bit_rate",
"format = duration, bit_rate",
"-show_entries",
"stream=width,height,r_frame_rate",
"stream = width, height, r_frame_rate",
"-of",
"json",
video_path,
]
result = subprocess.run(cmd, capture_output=True, text=True)
result = subprocess.run(cmd, capture_output = True, text = True)
if result.returncode == 0:
data = json.loads(result.stdout)
return {
@@ -196,9 +196,9 @@ class MultimodalProcessor:
if FFMPEG_AVAILABLE:
(
ffmpeg.input(video_path)
.output(output_path, ac=1, ar=16000, vn=None)
.output(output_path, ac = 1, ar = 16000, vn = None)
.overwrite_output()
.run(quiet=True)
.run(quiet = True)
)
else:
# 使用命令行 ffmpeg
@@ -216,7 +216,7 @@ class MultimodalProcessor:
"-y",
output_path,
]
subprocess.run(cmd, check=True, capture_output=True)
subprocess.run(cmd, check = True, capture_output = True)
return output_path
except Exception as e:
@@ -240,7 +240,7 @@ class MultimodalProcessor:
# 创建帧存储目录
video_frames_dir = os.path.join(self.frames_dir, video_id)
os.makedirs(video_frames_dir, exist_ok=True)
os.makedirs(video_frames_dir, exist_ok = True)
try:
if CV2_AVAILABLE:
@@ -278,13 +278,13 @@ class MultimodalProcessor:
"-i",
video_path,
"-vf",
f"fps=1/{interval}",
f"fps = 1/{interval}",
"-frame_pts",
"1",
"-y",
output_pattern,
]
subprocess.run(cmd, check=True, capture_output=True)
subprocess.run(cmd, check = True, capture_output = True)
# 获取生成的帧文件列表
frame_paths = sorted(
@@ -320,10 +320,10 @@ class MultimodalProcessor:
image = image.convert("L")
# 使用 pytesseract 进行 OCR
text = pytesseract.image_to_string(image, lang="chi_sim+eng")
text = pytesseract.image_to_string(image, lang = "chi_sim+eng")
# 获取置信度数据
data = pytesseract.image_to_data(image, output_type=pytesseract.Output.DICT)
data = pytesseract.image_to_data(image, output_type = pytesseract.Output.DICT)
confidences = [int(c) for c in data["conf"] if int(c) > 0]
avg_confidence = sum(confidences) / len(confidences) if confidences else 0
@@ -382,13 +382,13 @@ class MultimodalProcessor:
ocr_text, confidence = self.perform_ocr(frame_path)
frame = VideoFrame(
id=str(uuid.uuid4())[:UUID_LENGTH],
video_id=video_id,
frame_number=frame_number,
timestamp=timestamp,
frame_path=frame_path,
ocr_text=ocr_text,
ocr_confidence=confidence,
id = str(uuid.uuid4())[:UUID_LENGTH],
video_id = video_id,
frame_number = frame_number,
timestamp = timestamp,
frame_path = frame_path,
ocr_text = ocr_text,
ocr_confidence = confidence,
)
frames.append(frame)
@@ -407,23 +407,23 @@ class MultimodalProcessor:
full_ocr_text = "\n\n".join(all_ocr_text)
return VideoProcessingResult(
video_id=video_id,
audio_path=audio_path,
frames=frames,
ocr_results=ocr_results,
full_text=full_ocr_text,
success=True,
video_id = video_id,
audio_path = audio_path,
frames = frames,
ocr_results = ocr_results,
full_text = full_ocr_text,
success = True,
)
except Exception as e:
return VideoProcessingResult(
video_id=video_id,
audio_path="",
frames=[],
ocr_results=[],
full_text="",
success=False,
error_message=str(e),
video_id = video_id,
audio_path = "",
frames = [],
ocr_results = [],
full_text = "",
success = False,
error_message = str(e),
)
def cleanup(self, video_id: str = None) -> None:
@@ -450,7 +450,7 @@ class MultimodalProcessor:
for dir_path in [self.video_dir, self.frames_dir, self.audio_dir]:
if os.path.exists(dir_path):
shutil.rmtree(dir_path)
os.makedirs(dir_path, exist_ok=True)
os.makedirs(dir_path, exist_ok = True)
# Singleton instance

View File

@@ -39,7 +39,7 @@ class GraphEntity:
aliases: list[str] = None
properties: dict = None
def __post_init__(self):
def __post_init__(self) -> None:
if self.aliases is None:
self.aliases = []
if self.properties is None:
@@ -57,7 +57,7 @@ class GraphRelation:
evidence: str = ""
properties: dict = None
def __post_init__(self):
def __post_init__(self) -> None:
if self.properties is None:
self.properties = {}
@@ -95,7 +95,7 @@ class CentralityResult:
class Neo4jManager:
"""Neo4j 图数据库管理器"""
def __init__(self, uri: str = None, user: str = None, password: str = None):
def __init__(self, uri: str = None, user: str = None, password: str = None) -> None:
self.uri = uri or NEO4J_URI
self.user = user or NEO4J_USER
self.password = password or NEO4J_PASSWORD
@@ -113,7 +113,7 @@ class Neo4jManager:
return
try:
self._driver = GraphDatabase.driver(self.uri, auth=(self.user, self.password))
self._driver = GraphDatabase.driver(self.uri, auth = (self.user, self.password))
# 验证连接
self._driver.verify_connectivity()
logger.info(f"Connected to Neo4j at {self.uri}")
@@ -193,9 +193,9 @@ class Neo4jManager:
p.description = $description,
p.updated_at = datetime()
""",
project_id=project_id,
name=project_name,
description=project_description,
project_id = project_id,
name = project_name,
description = project_description,
)
def sync_entity(self, entity: GraphEntity) -> None:
@@ -218,13 +218,13 @@ class Neo4jManager:
MATCH (p:Project {id: $project_id})
MERGE (e)-[:BELONGS_TO]->(p)
""",
id=entity.id,
project_id=entity.project_id,
name=entity.name,
type=entity.type,
definition=entity.definition,
aliases=json.dumps(entity.aliases),
properties=json.dumps(entity.properties),
id = entity.id,
project_id = entity.project_id,
name = entity.name,
type = entity.type,
definition = entity.definition,
aliases = json.dumps(entity.aliases),
properties = json.dumps(entity.properties),
)
def sync_entities_batch(self, entities: list[GraphEntity]) -> None:
@@ -261,7 +261,7 @@ class Neo4jManager:
MATCH (p:Project {id: entity.project_id})
MERGE (e)-[:BELONGS_TO]->(p)
""",
entities=entities_data,
entities = entities_data,
)
def sync_relation(self, relation: GraphRelation) -> None:
@@ -280,12 +280,12 @@ class Neo4jManager:
r.properties = $properties,
r.updated_at = datetime()
""",
id=relation.id,
source_id=relation.source_id,
target_id=relation.target_id,
relation_type=relation.relation_type,
evidence=relation.evidence,
properties=json.dumps(relation.properties),
id = relation.id,
source_id = relation.source_id,
target_id = relation.target_id,
relation_type = relation.relation_type,
evidence = relation.evidence,
properties = json.dumps(relation.properties),
)
def sync_relations_batch(self, relations: list[GraphRelation]) -> None:
@@ -317,7 +317,7 @@ class Neo4jManager:
r.properties = rel.properties,
r.updated_at = datetime()
""",
relations=relations_data,
relations = relations_data,
)
def delete_entity(self, entity_id: str) -> None:
@@ -331,7 +331,7 @@ class Neo4jManager:
MATCH (e:Entity {id: $id})
DETACH DELETE e
""",
id=entity_id,
id = entity_id,
)
def delete_project(self, project_id: str) -> None:
@@ -346,7 +346,7 @@ class Neo4jManager:
OPTIONAL MATCH (e:Entity)-[:BELONGS_TO]->(p)
DETACH DELETE e, p
""",
id=project_id,
id = project_id,
)
# ==================== 复杂图查询 ====================
@@ -376,9 +376,9 @@ class Neo4jManager:
)
RETURN path
""",
source_id=source_id,
target_id=target_id,
max_depth=max_depth,
source_id = source_id,
target_id = target_id,
max_depth = max_depth,
)
record = result.single()
@@ -404,7 +404,7 @@ class Neo4jManager:
]
return PathResult(
nodes=nodes, relationships=relationships, length=len(path.relationships)
nodes = nodes, relationships = relationships, length = len(path.relationships)
)
def find_all_paths(
@@ -433,10 +433,10 @@ class Neo4jManager:
RETURN path
LIMIT $limit
""",
source_id=source_id,
target_id=target_id,
max_depth=max_depth,
limit=limit,
source_id = source_id,
target_id = target_id,
max_depth = max_depth,
limit = limit,
)
paths = []
@@ -460,7 +460,7 @@ class Neo4jManager:
paths.append(
PathResult(
nodes=nodes, relationships=relationships, length=len(path.relationships)
nodes = nodes, relationships = relationships, length = len(path.relationships)
)
)
@@ -491,9 +491,9 @@ class Neo4jManager:
RETURN neighbor, r.relation_type as rel_type, r.evidence as evidence
LIMIT $limit
""",
entity_id=entity_id,
relation_type=relation_type,
limit=limit,
entity_id = entity_id,
relation_type = relation_type,
limit = limit,
)
else:
result = session.run(
@@ -502,8 +502,8 @@ class Neo4jManager:
RETURN neighbor, r.relation_type as rel_type, r.evidence as evidence
LIMIT $limit
""",
entity_id=entity_id,
limit=limit,
entity_id = entity_id,
limit = limit,
)
neighbors = []
@@ -541,8 +541,8 @@ class Neo4jManager:
MATCH (e1:Entity {id: $id1})-[:RELATES_TO]-(common:Entity)-[:RELATES_TO]-(e2:Entity {id: $id2})
RETURN DISTINCT common
""",
id1=entity_id1,
id2=entity_id2,
id1 = entity_id1,
id2 = entity_id2,
)
return [
@@ -581,7 +581,7 @@ class Neo4jManager:
{}
) YIELD value RETURN value
""",
project_id=project_id,
project_id = project_id,
)
# 创建临时图
@@ -601,7 +601,7 @@ class Neo4jManager:
}
)
""",
project_id=project_id,
project_id = project_id,
)
# 运行 PageRank
@@ -615,8 +615,8 @@ class Neo4jManager:
ORDER BY score DESC
LIMIT $top_n
""",
project_id=project_id,
top_n=top_n,
project_id = project_id,
top_n = top_n,
)
rankings = []
@@ -624,10 +624,10 @@ class Neo4jManager:
for record in result:
rankings.append(
CentralityResult(
entity_id=record["entity_id"],
entity_name=record["entity_name"],
score=record["score"],
rank=rank,
entity_id = record["entity_id"],
entity_name = record["entity_name"],
score = record["score"],
rank = rank,
)
)
rank += 1
@@ -637,7 +637,7 @@ class Neo4jManager:
"""
CALL gds.graph.drop('project-graph-$project_id')
""",
project_id=project_id,
project_id = project_id,
)
return rankings
@@ -667,8 +667,8 @@ class Neo4jManager:
LIMIT $top_n
RETURN e.id as entity_id, e.name as entity_name, degree as score
""",
project_id=project_id,
top_n=top_n,
project_id = project_id,
top_n = top_n,
)
rankings = []
@@ -676,10 +676,10 @@ class Neo4jManager:
for record in result:
rankings.append(
CentralityResult(
entity_id=record["entity_id"],
entity_name=record["entity_name"],
score=float(record["score"]),
rank=rank,
entity_id = record["entity_id"],
entity_name = record["entity_name"],
score = float(record["score"]),
rank = rank,
)
)
rank += 1
@@ -710,7 +710,7 @@ class Neo4jManager:
connections, size(connections) as connection_count
ORDER BY connection_count DESC
""",
project_id=project_id,
project_id = project_id,
)
# 手动分组(基于连通性)
@@ -752,12 +752,12 @@ class Neo4jManager:
results.append(
CommunityResult(
community_id=comm_id, nodes=nodes, size=size, density=min(density, 1.0)
community_id = comm_id, nodes = nodes, size = size, density = min(density, 1.0)
)
)
# 按大小排序
results.sort(key=lambda x: x.size, reverse=True)
results.sort(key = lambda x: x.size, reverse = True)
return results
def find_central_entities(
@@ -787,7 +787,7 @@ class Neo4jManager:
ORDER BY degree DESC
LIMIT 20
""",
project_id=project_id,
project_id = project_id,
)
else:
# 默认使用度中心性
@@ -800,7 +800,7 @@ class Neo4jManager:
ORDER BY degree DESC
LIMIT 20
""",
project_id=project_id,
project_id = project_id,
)
rankings = []
@@ -808,10 +808,10 @@ class Neo4jManager:
for record in result:
rankings.append(
CentralityResult(
entity_id=record["entity_id"],
entity_name=record["entity_name"],
score=float(record["score"]),
rank=rank,
entity_id = record["entity_id"],
entity_name = record["entity_name"],
score = float(record["score"]),
rank = rank,
)
)
rank += 1
@@ -840,7 +840,7 @@ class Neo4jManager:
MATCH (e:Entity)-[:BELONGS_TO]->(p:Project {id: $project_id})
RETURN count(e) as count
""",
project_id=project_id,
project_id = project_id,
).single()["count"]
# 关系数量
@@ -850,7 +850,7 @@ class Neo4jManager:
MATCH (e)-[r:RELATES_TO]-()
RETURN count(r) as count
""",
project_id=project_id,
project_id = project_id,
).single()["count"]
# 实体类型分布
@@ -860,7 +860,7 @@ class Neo4jManager:
RETURN e.type as type, count(e) as count
ORDER BY count DESC
""",
project_id=project_id,
project_id = project_id,
)
types = {record["type"]: record["count"] for record in type_distribution}
@@ -873,7 +873,7 @@ class Neo4jManager:
WITH e, count(other) as degree
RETURN avg(degree) as avg_degree
""",
project_id=project_id,
project_id = project_id,
).single()["avg_degree"]
# 关系类型分布
@@ -885,7 +885,7 @@ class Neo4jManager:
ORDER BY count DESC
LIMIT 10
""",
project_id=project_id,
project_id = project_id,
)
relation_types = {record["type"]: record["count"] for record in rel_types}
@@ -927,8 +927,8 @@ class Neo4jManager:
}) YIELD node
RETURN DISTINCT node
""",
entity_ids=entity_ids,
depth=depth,
entity_ids = entity_ids,
depth = depth,
)
nodes = []
@@ -953,7 +953,7 @@ class Neo4jManager:
RETURN source.id as source_id, target.id as target_id,
r.relation_type as type, r.evidence as evidence
""",
node_ids=list(node_ids),
node_ids = list(node_ids),
)
relationships = [
@@ -1015,13 +1015,13 @@ def sync_project_to_neo4j(
# 同步实体
graph_entities = [
GraphEntity(
id=e["id"],
project_id=project_id,
name=e["name"],
type=e.get("type", "unknown"),
definition=e.get("definition", ""),
aliases=e.get("aliases", []),
properties=e.get("properties", {}),
id = e["id"],
project_id = project_id,
name = e["name"],
type = e.get("type", "unknown"),
definition = e.get("definition", ""),
aliases = e.get("aliases", []),
properties = e.get("properties", {}),
)
for e in entities
]
@@ -1030,12 +1030,12 @@ def sync_project_to_neo4j(
# 同步关系
graph_relations = [
GraphRelation(
id=r["id"],
source_id=r["source_entity_id"],
target_id=r["target_entity_id"],
relation_type=r["relation_type"],
evidence=r.get("evidence", ""),
properties=r.get("properties", {}),
id = r["id"],
source_id = r["source_entity_id"],
target_id = r["target_entity_id"],
relation_type = r["relation_type"],
evidence = r.get("evidence", ""),
properties = r.get("properties", {}),
)
for r in relations
]
@@ -1048,7 +1048,7 @@ def sync_project_to_neo4j(
if __name__ == "__main__":
# 测试代码
logging.basicConfig(level=logging.INFO)
logging.basicConfig(level = logging.INFO)
manager = Neo4jManager()
@@ -1065,11 +1065,11 @@ if __name__ == "__main__":
# 测试实体
test_entity = GraphEntity(
id="test-entity-1",
project_id="test-project",
name="Test Entity",
type="Person",
definition="A test entity",
id = "test-entity-1",
project_id = "test-project",
name = "Test Entity",
type = "Person",
definition = "A test entity",
)
manager.sync_entity(test_entity)
print("✅ Entity synced")

File diff suppressed because it is too large Load Diff

View File

@@ -11,7 +11,7 @@ import oss2
class OSSUploader:
def __init__(self):
def __init__(self) -> None:
self.access_key = os.getenv("ALI_ACCESS_KEY")
self.secret_key = os.getenv("ALI_SECRET_KEY")
self.bucket_name = os.getenv("OSS_BUCKET", "insightflow-audio")

View File

@@ -82,7 +82,7 @@ class PerformanceMetric:
endpoint: str | None
duration_ms: float
timestamp: str
metadata: dict = field(default_factory=dict)
metadata: dict = field(default_factory = dict)
def to_dict(self) -> dict:
return {
@@ -164,7 +164,7 @@ class CacheManager:
max_memory_size: int = 100 * 1024 * 1024, # 100MB
default_ttl: int = 3600, # 1小时
db_path: str = "insightflow.db",
):
) -> None:
self.db_path = db_path
self.default_ttl = default_ttl
self.max_memory_size = max_memory_size
@@ -176,7 +176,7 @@ class CacheManager:
if REDIS_AVAILABLE and redis_url:
try:
self.redis_client = redis.from_url(redis_url, decode_responses=True)
self.redis_client = redis.from_url(redis_url, decode_responses = True)
self.redis_client.ping()
self.use_redis = True
print(f"Redis 缓存已连接: {redis_url}")
@@ -233,7 +233,7 @@ class CacheManager:
def _get_entry_size(self, value: Any) -> int:
"""估算缓存条目大小"""
try:
return len(json.dumps(value, ensure_ascii=False).encode("utf-8"))
return len(json.dumps(value, ensure_ascii = False).encode("utf-8"))
except (TypeError, ValueError):
return 1024 # 默认估算
@@ -245,7 +245,7 @@ class CacheManager:
and self.memory_cache
):
# 移除最久未访问的
oldest_key, oldest_entry = self.memory_cache.popitem(last=False)
oldest_key, oldest_entry = self.memory_cache.popitem(last = False)
self.current_memory_size -= oldest_entry.size_bytes
self.stats.evictions += 1
@@ -314,7 +314,7 @@ class CacheManager:
if self.use_redis:
try:
serialized = json.dumps(value, ensure_ascii=False)
serialized = json.dumps(value, ensure_ascii = False)
self.redis_client.setex(key, ttl, serialized)
return True
except Exception as e:
@@ -331,12 +331,12 @@ class CacheManager:
now = time.time()
entry = CacheEntry(
key=key,
value=value,
created_at=now,
expires_at=now + ttl if ttl > 0 else None,
size_bytes=size,
last_accessed=now,
key = key,
value = value,
created_at = now,
expires_at = now + ttl if ttl > 0 else None,
size_bytes = size,
last_accessed = now,
)
# 如果已存在,更新大小
@@ -412,7 +412,7 @@ class CacheManager:
try:
pipe = self.redis_client.pipeline()
for key, value in mapping.items():
serialized = json.dumps(value, ensure_ascii=False)
serialized = json.dumps(value, ensure_ascii = False)
pipe.setex(key, ttl, serialized)
pipe.execute()
return True
@@ -500,12 +500,12 @@ class CacheManager:
WHERE e.project_id = ?
ORDER BY mention_count DESC
LIMIT 100""",
(project_id,),
(project_id, ),
).fetchall()
for entity in entities:
key = f"entity:{entity['id']}"
self.set(key, dict(entity), ttl=7200) # 2小时
self.set(key, dict(entity), ttl = 7200) # 2小时
stats["entities"] += 1
# 预热关系数据
@@ -517,12 +517,12 @@ class CacheManager:
JOIN entities e2 ON r.target_entity_id = e2.id
WHERE r.project_id = ?
LIMIT 200""",
(project_id,),
(project_id, ),
).fetchall()
for relation in relations:
key = f"relation:{relation['id']}"
self.set(key, dict(relation), ttl=3600)
self.set(key, dict(relation), ttl = 3600)
stats["relations"] += 1
# 预热最近的转录
@@ -531,7 +531,7 @@ class CacheManager:
WHERE project_id = ?
ORDER BY created_at DESC
LIMIT 10""",
(project_id,),
(project_id, ),
).fetchall()
for transcript in transcripts:
@@ -543,16 +543,16 @@ class CacheManager:
"type": transcript.get("type", "audio"),
"created_at": transcript["created_at"],
}
self.set(key, meta, ttl=1800) # 30分钟
self.set(key, meta, ttl = 1800) # 30分钟
stats["transcripts"] += 1
# 预热项目知识库摘要
entity_count = conn.execute(
"SELECT COUNT(*) FROM entities WHERE project_id = ?", (project_id,)
"SELECT COUNT(*) FROM entities WHERE project_id = ?", (project_id, )
).fetchone()[0]
relation_count = conn.execute(
"SELECT COUNT(*) FROM entity_relations WHERE project_id = ?", (project_id,)
"SELECT COUNT(*) FROM entity_relations WHERE project_id = ?", (project_id, )
).fetchone()[0]
summary = {
@@ -561,7 +561,7 @@ class CacheManager:
"relation_count": relation_count,
"cached_at": datetime.now().isoformat(),
}
self.set(f"project_summary:{project_id}", summary, ttl=3600)
self.set(f"project_summary:{project_id}", summary, ttl = 3600)
conn.close()
@@ -583,7 +583,7 @@ class CacheManager:
try:
# 使用 Redis 的 scan 查找相关 key
pattern = f"*:{project_id}:*"
for key in self.redis_client.scan_iter(match=pattern):
for key in self.redis_client.scan_iter(match = pattern):
self.redis_client.delete(key)
count += 1
except Exception as e:
@@ -619,13 +619,13 @@ class DatabaseSharding:
base_db_path: str = "insightflow.db",
shard_db_dir: str = "./shards",
shards_count: int = 4,
):
) -> None:
self.base_db_path = base_db_path
self.shard_db_dir = shard_db_dir
self.shards_count = shards_count
# 确保分片目录存在
os.makedirs(shard_db_dir, exist_ok=True)
os.makedirs(shard_db_dir, exist_ok = True)
# 分片映射
self.shard_map: dict[str, ShardInfo] = {}
@@ -650,10 +650,10 @@ class DatabaseSharding:
db_path = os.path.join(self.shard_db_dir, f"{shard_id}.db")
self.shard_map[shard_id] = ShardInfo(
shard_id=shard_id,
shard_key_range=(start_char, end_char),
db_path=db_path,
created_at=datetime.now().isoformat(),
shard_id = shard_id,
shard_key_range = (start_char, end_char),
db_path = db_path,
created_at = datetime.now().isoformat(),
)
# 确保分片数据库存在
@@ -757,11 +757,11 @@ class DatabaseSharding:
source_conn.row_factory = sqlite3.Row
entities = source_conn.execute(
"SELECT * FROM entities WHERE project_id = ?", (project_id,)
"SELECT * FROM entities WHERE project_id = ?", (project_id, )
).fetchall()
relations = source_conn.execute(
"SELECT * FROM entity_relations WHERE project_id = ?", (project_id,)
"SELECT * FROM entity_relations WHERE project_id = ?", (project_id, )
).fetchall()
source_conn.close()
@@ -794,8 +794,8 @@ class DatabaseSharding:
# 从源分片删除数据
source_conn = sqlite3.connect(source_info.db_path)
source_conn.execute("DELETE FROM entities WHERE project_id = ?", (project_id,))
source_conn.execute("DELETE FROM entity_relations WHERE project_id = ?", (project_id,))
source_conn.execute("DELETE FROM entities WHERE project_id = ?", (project_id, ))
source_conn.execute("DELETE FROM entity_relations WHERE project_id = ?", (project_id, ))
source_conn.commit()
source_conn.close()
@@ -917,7 +917,7 @@ class TaskQueue:
- 任务状态追踪和重试机制
"""
def __init__(self, redis_url: str | None = None, db_path: str = "insightflow.db"):
def __init__(self, redis_url: str | None = None, db_path: str = "insightflow.db") -> None:
self.db_path = db_path
self.redis_url = redis_url
self.celery_app = None
@@ -934,7 +934,7 @@ class TaskQueue:
# 初始化 Celery
if CELERY_AVAILABLE and redis_url:
try:
self.celery_app = Celery("insightflow", broker=redis_url, backend=redis_url)
self.celery_app = Celery("insightflow", broker = redis_url, backend = redis_url)
self.use_celery = True
print("Celery 任务队列已初始化")
except Exception as e:
@@ -989,12 +989,12 @@ class TaskQueue:
task_id = str(uuid.uuid4())[:16]
task = TaskInfo(
id=task_id,
task_type=task_type,
status="pending",
payload=payload,
created_at=datetime.now().isoformat(),
max_retries=max_retries,
id = task_id,
task_type = task_type,
status = "pending",
payload = payload,
created_at = datetime.now().isoformat(),
max_retries = max_retries,
)
if self.use_celery:
@@ -1003,10 +1003,10 @@ class TaskQueue:
# 这里简化处理,实际应该定义具体的 Celery 任务
result = self.celery_app.send_task(
f"insightflow.tasks.{task_type}",
args=[payload],
task_id=task_id,
retry=True,
retry_policy={
args = [payload],
task_id = task_id,
retry = True,
retry_policy = {
"max_retries": max_retries,
"interval_start": 10,
"interval_step": 10,
@@ -1024,7 +1024,7 @@ class TaskQueue:
with self.task_lock:
self.tasks[task_id] = task
# 异步执行
threading.Thread(target=self._execute_task, args=(task_id,), daemon=True).start()
threading.Thread(target = self._execute_task, args = (task_id, ), daemon = True).start()
# 保存到数据库
self._save_task(task)
@@ -1061,7 +1061,7 @@ class TaskQueue:
task.status = "retrying"
# 延迟重试
threading.Timer(
10 * task.retry_count, self._execute_task, args=(task_id,)
10 * task.retry_count, self._execute_task, args = (task_id, )
).start()
else:
task.status = "failed"
@@ -1089,8 +1089,8 @@ class TaskQueue:
task.id,
task.task_type,
task.status,
json.dumps(task.payload, ensure_ascii=False),
json.dumps(task.result, ensure_ascii=False) if task.result else None,
json.dumps(task.payload, ensure_ascii = False),
json.dumps(task.result, ensure_ascii = False) if task.result else None,
task.error_message,
task.retry_count,
task.max_retries,
@@ -1120,7 +1120,7 @@ class TaskQueue:
""",
(
task.status,
json.dumps(task.result, ensure_ascii=False) if task.result else None,
json.dumps(task.result, ensure_ascii = False) if task.result else None,
task.error_message,
task.retry_count,
task.started_at,
@@ -1136,7 +1136,7 @@ class TaskQueue:
"""获取任务状态"""
if self.use_celery:
try:
result = AsyncResult(task_id, app=self.celery_app)
result = AsyncResult(task_id, app = self.celery_app)
status_map = {
"PENDING": "pending",
@@ -1147,13 +1147,13 @@ class TaskQueue:
}
return TaskInfo(
id=task_id,
task_type="celery_task",
status=status_map.get(result.status, "unknown"),
payload={},
created_at="",
result=result.result if result.successful() else None,
error_message=str(result.result) if result.failed() else None,
id = task_id,
task_type = "celery_task",
status = status_map.get(result.status, "unknown"),
payload = {},
created_at = "",
result = result.result if result.successful() else None,
error_message = str(result.result) if result.failed() else None,
)
except Exception as e:
print(f"获取 Celery 任务状态失败: {e}")
@@ -1180,7 +1180,7 @@ class TaskQueue:
where_clauses.append("task_type = ?")
params.append(task_type)
where_str = " AND ".join(where_clauses) if where_clauses else "1=1"
where_str = " AND ".join(where_clauses) if where_clauses else "1 = 1"
rows = conn.execute(
f"""
@@ -1198,17 +1198,17 @@ class TaskQueue:
for row in rows:
tasks.append(
TaskInfo(
id=row["id"],
task_type=row["task_type"],
status=row["status"],
payload=json.loads(row["payload"]) if row["payload"] else {},
created_at=row["created_at"],
started_at=row["started_at"],
completed_at=row["completed_at"],
result=json.loads(row["result"]) if row["result"] else None,
error_message=row["error_message"],
retry_count=row["retry_count"],
max_retries=row["max_retries"],
id = row["id"],
task_type = row["task_type"],
status = row["status"],
payload = json.loads(row["payload"]) if row["payload"] else {},
created_at = row["created_at"],
started_at = row["started_at"],
completed_at = row["completed_at"],
result = json.loads(row["result"]) if row["result"] else None,
error_message = row["error_message"],
retry_count = row["retry_count"],
max_retries = row["max_retries"],
)
)
@@ -1218,7 +1218,7 @@ class TaskQueue:
"""取消任务"""
if self.use_celery:
try:
self.celery_app.control.revoke(task_id, terminate=True)
self.celery_app.control.revoke(task_id, terminate = True)
return True
except Exception as e:
print(f"取消 Celery 任务失败: {e}")
@@ -1248,7 +1248,7 @@ class TaskQueue:
if not self.use_celery:
with self.task_lock:
self.tasks[task_id] = task
threading.Thread(target=self._execute_task, args=(task_id,), daemon=True).start()
threading.Thread(target = self._execute_task, args = (task_id, ), daemon = True).start()
self._update_task_status(task)
return True
@@ -1307,7 +1307,7 @@ class PerformanceMonitor:
db_path: str = "insightflow.db",
slow_query_threshold: int = 1000,
alert_threshold: int = 5000, # 毫秒
): # 毫秒
) -> None: # 毫秒
self.db_path = db_path
self.slow_query_threshold = slow_query_threshold
self.alert_threshold = alert_threshold
@@ -1326,7 +1326,7 @@ class PerformanceMonitor:
duration_ms: float,
endpoint: str | None = None,
metadata: dict | None = None,
):
) -> None:
"""
记录性能指标
@@ -1337,12 +1337,12 @@ class PerformanceMonitor:
metadata: 额外元数据
"""
metric = PerformanceMetric(
id=str(uuid.uuid4())[:16],
metric_type=metric_type,
endpoint=endpoint,
duration_ms=duration_ms,
timestamp=datetime.now().isoformat(),
metadata=metadata or {},
id = str(uuid.uuid4())[:16],
metric_type = metric_type,
endpoint = endpoint,
duration_ms = duration_ms,
timestamp = datetime.now().isoformat(),
metadata = metadata or {},
)
# 添加到缓冲区
@@ -1379,7 +1379,7 @@ class PerformanceMonitor:
metric.endpoint,
metric.duration_ms,
metric.timestamp,
json.dumps(metric.metadata, ensure_ascii=False),
json.dumps(metric.metadata, ensure_ascii = False),
),
)
@@ -1439,7 +1439,7 @@ class PerformanceMonitor:
FROM performance_metrics
WHERE timestamp > datetime('now', ?)
""",
(f"-{hours} hours",),
(f"-{hours} hours", ),
).fetchone()
# 按类型统计
@@ -1454,7 +1454,7 @@ class PerformanceMonitor:
WHERE timestamp > datetime('now', ?)
GROUP BY metric_type
""",
(f"-{hours} hours",),
(f"-{hours} hours", ),
).fetchall()
# 按端点统计API
@@ -1472,7 +1472,7 @@ class PerformanceMonitor:
ORDER BY avg_duration DESC
LIMIT 20
""",
(f"-{hours} hours",),
(f"-{hours} hours", ),
).fetchall()
# 慢查询统计
@@ -1597,7 +1597,7 @@ class PerformanceMonitor:
DELETE FROM performance_metrics
WHERE timestamp < datetime('now', ?)
""",
(f"-{days} days",),
(f"-{days} days", ),
)
deleted = cursor.rowcount
@@ -1668,7 +1668,7 @@ def monitored(monitor: PerformanceMonitor, metric_type: str, endpoint: str | Non
def decorator(func: Callable) -> Callable:
@wraps(func)
def wrapper(*args, **kwargs):
def wrapper(*args, **kwargs) -> None:
start_time = time.time()
try:
@@ -1699,17 +1699,17 @@ class PerformanceManager:
db_path: str = "insightflow.db",
redis_url: str | None = None,
enable_sharding: bool = False,
):
) -> None:
self.db_path = db_path
# 初始化各模块
self.cache = CacheManager(redis_url=redis_url, db_path=db_path)
self.cache = CacheManager(redis_url = redis_url, db_path = db_path)
self.sharding = DatabaseSharding(base_db_path=db_path) if enable_sharding else None
self.sharding = DatabaseSharding(base_db_path = db_path) if enable_sharding else None
self.task_queue = TaskQueue(redis_url=redis_url, db_path=db_path)
self.task_queue = TaskQueue(redis_url = redis_url, db_path = db_path)
self.monitor = PerformanceMonitor(db_path=db_path)
self.monitor = PerformanceMonitor(db_path = db_path)
def get_health_status(self) -> dict:
"""获取系统健康状态"""
@@ -1760,6 +1760,6 @@ def get_performance_manager(
global _performance_manager
if _performance_manager is None:
_performance_manager = PerformanceManager(
db_path=db_path, redis_url=redis_url, enable_sharding=enable_sharding
db_path = db_path, redis_url = redis_url, enable_sharding = enable_sharding
)
return _performance_manager

View File

@@ -63,7 +63,7 @@ class Plugin:
plugin_type: str
project_id: str
status: str = "active"
config: dict = field(default_factory=dict)
config: dict = field(default_factory = dict)
created_at: str = ""
updated_at: str = ""
last_used_at: str | None = None
@@ -111,8 +111,8 @@ class WebhookEndpoint:
endpoint_url: str
project_id: str | None = None
auth_type: str = "none" # none, api_key, oauth, custom
auth_config: dict = field(default_factory=dict)
trigger_events: list[str] = field(default_factory=list)
auth_config: dict = field(default_factory = dict)
trigger_events: list[str] = field(default_factory = list)
is_active: bool = True
created_at: str = ""
updated_at: str = ""
@@ -151,7 +151,7 @@ class ChromeExtensionToken:
user_id: str | None = None
project_id: str | None = None
name: str = ""
permissions: list[str] = field(default_factory=lambda: ["read", "write"])
permissions: list[str] = field(default_factory = lambda: ["read", "write"])
expires_at: str | None = None
created_at: str = ""
last_used_at: str | None = None
@@ -162,7 +162,7 @@ class ChromeExtensionToken:
class PluginManager:
"""插件管理主类"""
def __init__(self, db_manager=None):
def __init__(self, db_manager = None) -> None:
self.db = db_manager
self._handlers = {}
self._register_default_handlers()
@@ -213,7 +213,7 @@ class PluginManager:
def get_plugin(self, plugin_id: str) -> Plugin | None:
"""获取插件"""
conn = self.db.get_conn()
row = conn.execute("SELECT * FROM plugins WHERE id = ?", (plugin_id,)).fetchone()
row = conn.execute("SELECT * FROM plugins WHERE id = ?", (plugin_id, )).fetchone()
conn.close()
if row:
@@ -239,7 +239,7 @@ class PluginManager:
conditions.append("status = ?")
params.append(status)
where_clause = " AND ".join(conditions) if conditions else "1=1"
where_clause = " AND ".join(conditions) if conditions else "1 = 1"
rows = conn.execute(
f"SELECT * FROM plugins WHERE {where_clause} ORDER BY created_at DESC", params
@@ -284,10 +284,10 @@ class PluginManager:
conn = self.db.get_conn()
# 删除关联的配置
conn.execute("DELETE FROM plugin_configs WHERE plugin_id = ?", (plugin_id,))
conn.execute("DELETE FROM plugin_configs WHERE plugin_id = ?", (plugin_id, ))
# 删除插件
cursor = conn.execute("DELETE FROM plugins WHERE id = ?", (plugin_id,))
cursor = conn.execute("DELETE FROM plugins WHERE id = ?", (plugin_id, ))
conn.commit()
conn.close()
@@ -296,16 +296,16 @@ class PluginManager:
def _row_to_plugin(self, row: sqlite3.Row) -> Plugin:
"""将数据库行转换为 Plugin 对象"""
return Plugin(
id=row["id"],
name=row["name"],
plugin_type=row["plugin_type"],
project_id=row["project_id"],
status=row["status"],
config=json.loads(row["config"]) if row["config"] else {},
created_at=row["created_at"],
updated_at=row["updated_at"],
last_used_at=row["last_used_at"],
use_count=row["use_count"],
id = row["id"],
name = row["name"],
plugin_type = row["plugin_type"],
project_id = row["project_id"],
status = row["status"],
config = json.loads(row["config"]) if row["config"] else {},
created_at = row["created_at"],
updated_at = row["updated_at"],
last_used_at = row["last_used_at"],
use_count = row["use_count"],
)
# ==================== Plugin Config ====================
@@ -343,13 +343,13 @@ class PluginManager:
conn.close()
return PluginConfig(
id=config_id,
plugin_id=plugin_id,
config_key=key,
config_value=value,
is_encrypted=is_encrypted,
created_at=now,
updated_at=now,
id = config_id,
plugin_id = plugin_id,
config_key = key,
config_value = value,
is_encrypted = is_encrypted,
created_at = now,
updated_at = now,
)
def get_plugin_config(self, plugin_id: str, key: str) -> str | None:
@@ -367,7 +367,7 @@ class PluginManager:
"""获取插件所有配置"""
conn = self.db.get_conn()
rows = conn.execute(
"SELECT config_key, config_value FROM plugin_configs WHERE plugin_id = ?", (plugin_id,)
"SELECT config_key, config_value FROM plugin_configs WHERE plugin_id = ?", (plugin_id, )
).fetchall()
conn.close()
@@ -402,7 +402,7 @@ class PluginManager:
class ChromeExtensionHandler:
"""Chrome 扩展处理器"""
def __init__(self, plugin_manager: PluginManager):
def __init__(self, plugin_manager: PluginManager) -> None:
self.pm = plugin_manager
def create_token(
@@ -417,7 +417,7 @@ class ChromeExtensionHandler:
token_id = str(uuid.uuid4())[:UUID_LENGTH]
# 生成随机令牌
raw_token = f"if_ext_{base64.urlsafe_b64encode(os.urandom(32)).decode('utf-8').rstrip('=')}"
raw_token = f"if_ext_{base64.urlsafe_b64encode(os.urandom(32)).decode('utf-8').rstrip(' = ')}"
# 哈希存储
token_hash = hashlib.sha256(raw_token.encode()).hexdigest()
@@ -427,7 +427,7 @@ class ChromeExtensionHandler:
if expires_days:
from datetime import timedelta
expires_at = (datetime.now() + timedelta(days=expires_days)).isoformat()
expires_at = (datetime.now() + timedelta(days = expires_days)).isoformat()
conn = self.pm.db.get_conn()
conn.execute(
@@ -452,14 +452,14 @@ class ChromeExtensionHandler:
conn.close()
return ChromeExtensionToken(
id=token_id,
token=raw_token, # 仅返回一次
user_id=user_id,
project_id=project_id,
name=name,
permissions=permissions or ["read"],
expires_at=expires_at,
created_at=now,
id = token_id,
token = raw_token, # 仅返回一次
user_id = user_id,
project_id = project_id,
name = name,
permissions = permissions or ["read"],
expires_at = expires_at,
created_at = now,
)
def validate_token(self, token: str) -> ChromeExtensionToken | None:
@@ -470,7 +470,7 @@ class ChromeExtensionHandler:
row = conn.execute(
"""SELECT * FROM chrome_extension_tokens
WHERE token_hash = ? AND is_revoked = 0""",
(token_hash,),
(token_hash, ),
).fetchone()
conn.close()
@@ -494,23 +494,23 @@ class ChromeExtensionHandler:
conn.close()
return ChromeExtensionToken(
id=row["id"],
token="", # 不返回实际令牌
user_id=row["user_id"],
project_id=row["project_id"],
name=row["name"],
permissions=json.loads(row["permissions"]),
expires_at=row["expires_at"],
created_at=row["created_at"],
last_used_at=now,
use_count=row["use_count"] + 1,
id = row["id"],
token = "", # 不返回实际令牌
user_id = row["user_id"],
project_id = row["project_id"],
name = row["name"],
permissions = json.loads(row["permissions"]),
expires_at = row["expires_at"],
created_at = row["created_at"],
last_used_at = now,
use_count = row["use_count"] + 1,
)
def revoke_token(self, token_id: str) -> bool:
"""撤销令牌"""
conn = self.pm.db.get_conn()
cursor = conn.execute(
"UPDATE chrome_extension_tokens SET is_revoked = 1 WHERE id = ?", (token_id,)
"UPDATE chrome_extension_tokens SET is_revoked = 1 WHERE id = ?", (token_id, )
)
conn.commit()
conn.close()
@@ -545,17 +545,17 @@ class ChromeExtensionHandler:
for row in rows:
tokens.append(
ChromeExtensionToken(
id=row["id"],
token="", # 不返回实际令牌
user_id=row["user_id"],
project_id=row["project_id"],
name=row["name"],
permissions=json.loads(row["permissions"]),
expires_at=row["expires_at"],
created_at=row["created_at"],
last_used_at=row["last_used_at"],
use_count=row["use_count"],
is_revoked=bool(row["is_revoked"]),
id = row["id"],
token = "", # 不返回实际令牌
user_id = row["user_id"],
project_id = row["project_id"],
name = row["name"],
permissions = json.loads(row["permissions"]),
expires_at = row["expires_at"],
created_at = row["created_at"],
last_used_at = row["last_used_at"],
use_count = row["use_count"],
is_revoked = bool(row["is_revoked"]),
)
)
@@ -606,7 +606,7 @@ class ChromeExtensionHandler:
class BotHandler:
"""飞书/钉钉机器人处理器"""
def __init__(self, plugin_manager: PluginManager, bot_type: str):
def __init__(self, plugin_manager: PluginManager, bot_type: str) -> None:
self.pm = plugin_manager
self.bot_type = bot_type
@@ -646,16 +646,16 @@ class BotHandler:
conn.close()
return BotSession(
id=bot_id,
bot_type=self.bot_type,
session_id=session_id,
session_name=session_name,
project_id=project_id,
webhook_url=webhook_url,
secret=secret,
is_active=True,
created_at=now,
updated_at=now,
id = bot_id,
bot_type = self.bot_type,
session_id = session_id,
session_name = session_name,
project_id = project_id,
webhook_url = webhook_url,
secret = secret,
is_active = True,
created_at = now,
updated_at = now,
)
def get_session(self, session_id: str) -> BotSession | None:
@@ -686,7 +686,7 @@ class BotHandler:
rows = conn.execute(
"""SELECT * FROM bot_sessions
WHERE bot_type = ? ORDER BY created_at DESC""",
(self.bot_type,),
(self.bot_type, ),
).fetchall()
conn.close()
@@ -739,18 +739,18 @@ class BotHandler:
def _row_to_session(self, row: sqlite3.Row) -> BotSession:
"""将数据库行转换为 BotSession 对象"""
return BotSession(
id=row["id"],
bot_type=row["bot_type"],
session_id=row["session_id"],
session_name=row["session_name"],
project_id=row["project_id"],
webhook_url=row["webhook_url"],
secret=row["secret"],
is_active=bool(row["is_active"]),
created_at=row["created_at"],
updated_at=row["updated_at"],
last_message_at=row["last_message_at"],
message_count=row["message_count"],
id = row["id"],
bot_type = row["bot_type"],
session_id = row["session_id"],
session_name = row["session_name"],
project_id = row["project_id"],
webhook_url = row["webhook_url"],
secret = row["secret"],
is_active = bool(row["is_active"]),
created_at = row["created_at"],
updated_at = row["updated_at"],
last_message_at = row["last_message_at"],
message_count = row["message_count"],
)
async def handle_message(self, session: BotSession, message: dict) -> dict:
@@ -880,7 +880,7 @@ class BotHandler:
hmac_code = hmac.new(
session.secret.encode("utf-8"),
string_to_sign.encode("utf-8"),
digestmod=hashlib.sha256,
digestmod = hashlib.sha256,
).digest()
sign = base64.b64encode(hmac_code).decode("utf-8")
else:
@@ -895,7 +895,7 @@ class BotHandler:
async with httpx.AsyncClient() as client:
response = await client.post(
session.webhook_url, json=payload, headers={"Content-Type": "application/json"}
session.webhook_url, json = payload, headers = {"Content-Type": "application/json"}
)
return response.status_code == 200
@@ -911,7 +911,7 @@ class BotHandler:
hmac_code = hmac.new(
session.secret.encode("utf-8"),
string_to_sign.encode("utf-8"),
digestmod=hashlib.sha256,
digestmod = hashlib.sha256,
).digest()
sign = base64.b64encode(hmac_code).decode("utf-8")
sign = urllib.parse.quote(sign)
@@ -922,11 +922,11 @@ class BotHandler:
url = session.webhook_url
if sign:
url = f"{url}&timestamp={timestamp}&sign={sign}"
url = f"{url}&timestamp = {timestamp}&sign = {sign}"
async with httpx.AsyncClient() as client:
response = await client.post(
url, json=payload, headers={"Content-Type": "application/json"}
url, json = payload, headers = {"Content-Type": "application/json"}
)
return response.status_code == 200
@@ -934,7 +934,7 @@ class BotHandler:
class WebhookIntegration:
"""Zapier/Make Webhook 集成"""
def __init__(self, plugin_manager: PluginManager, endpoint_type: str):
def __init__(self, plugin_manager: PluginManager, endpoint_type: str) -> None:
self.pm = plugin_manager
self.endpoint_type = endpoint_type
@@ -976,17 +976,17 @@ class WebhookIntegration:
conn.close()
return WebhookEndpoint(
id=endpoint_id,
name=name,
endpoint_type=self.endpoint_type,
endpoint_url=endpoint_url,
project_id=project_id,
auth_type=auth_type,
auth_config=auth_config or {},
trigger_events=trigger_events or [],
is_active=True,
created_at=now,
updated_at=now,
id = endpoint_id,
name = name,
endpoint_type = self.endpoint_type,
endpoint_url = endpoint_url,
project_id = project_id,
auth_type = auth_type,
auth_config = auth_config or {},
trigger_events = trigger_events or [],
is_active = True,
created_at = now,
updated_at = now,
)
def get_endpoint(self, endpoint_id: str) -> WebhookEndpoint | None:
@@ -1016,7 +1016,7 @@ class WebhookIntegration:
rows = conn.execute(
"""SELECT * FROM webhook_endpoints
WHERE endpoint_type = ? ORDER BY created_at DESC""",
(self.endpoint_type,),
(self.endpoint_type, ),
).fetchall()
conn.close()
@@ -1065,7 +1065,7 @@ class WebhookIntegration:
def delete_endpoint(self, endpoint_id: str) -> bool:
"""删除端点"""
conn = self.pm.db.get_conn()
cursor = conn.execute("DELETE FROM webhook_endpoints WHERE id = ?", (endpoint_id,))
cursor = conn.execute("DELETE FROM webhook_endpoints WHERE id = ?", (endpoint_id, ))
conn.commit()
conn.close()
@@ -1074,19 +1074,19 @@ class WebhookIntegration:
def _row_to_endpoint(self, row: sqlite3.Row) -> WebhookEndpoint:
"""将数据库行转换为 WebhookEndpoint 对象"""
return WebhookEndpoint(
id=row["id"],
name=row["name"],
endpoint_type=row["endpoint_type"],
endpoint_url=row["endpoint_url"],
project_id=row["project_id"],
auth_type=row["auth_type"],
auth_config=json.loads(row["auth_config"]) if row["auth_config"] else {},
trigger_events=json.loads(row["trigger_events"]) if row["trigger_events"] else [],
is_active=bool(row["is_active"]),
created_at=row["created_at"],
updated_at=row["updated_at"],
last_triggered_at=row["last_triggered_at"],
trigger_count=row["trigger_count"],
id = row["id"],
name = row["name"],
endpoint_type = row["endpoint_type"],
endpoint_url = row["endpoint_url"],
project_id = row["project_id"],
auth_type = row["auth_type"],
auth_config = json.loads(row["auth_config"]) if row["auth_config"] else {},
trigger_events = json.loads(row["trigger_events"]) if row["trigger_events"] else [],
is_active = bool(row["is_active"]),
created_at = row["created_at"],
updated_at = row["updated_at"],
last_triggered_at = row["last_triggered_at"],
trigger_count = row["trigger_count"],
)
async def trigger(self, endpoint: WebhookEndpoint, event_type: str, data: dict) -> bool:
@@ -1113,7 +1113,7 @@ class WebhookIntegration:
async with httpx.AsyncClient() as client:
response = await client.post(
endpoint.endpoint_url, json=payload, headers=headers, timeout=30.0
endpoint.endpoint_url, json = payload, headers = headers, timeout = 30.0
)
success = response.status_code in [200, 201, 202]
@@ -1157,7 +1157,7 @@ class WebhookIntegration:
class WebDAVSyncManager:
"""WebDAV 同步管理"""
def __init__(self, plugin_manager: PluginManager):
def __init__(self, plugin_manager: PluginManager) -> None:
self.pm = plugin_manager
def create_sync(
@@ -1202,25 +1202,25 @@ class WebDAVSyncManager:
conn.close()
return WebDAVSync(
id=sync_id,
name=name,
project_id=project_id,
server_url=server_url,
username=username,
password=password,
remote_path=remote_path,
sync_mode=sync_mode,
sync_interval=sync_interval,
last_sync_status="pending",
is_active=True,
created_at=now,
updated_at=now,
id = sync_id,
name = name,
project_id = project_id,
server_url = server_url,
username = username,
password = password,
remote_path = remote_path,
sync_mode = sync_mode,
sync_interval = sync_interval,
last_sync_status = "pending",
is_active = True,
created_at = now,
updated_at = now,
)
def get_sync(self, sync_id: str) -> WebDAVSync | None:
"""获取同步配置"""
conn = self.pm.db.get_conn()
row = conn.execute("SELECT * FROM webdav_syncs WHERE id = ?", (sync_id,)).fetchone()
row = conn.execute("SELECT * FROM webdav_syncs WHERE id = ?", (sync_id, )).fetchone()
conn.close()
if row:
@@ -1234,7 +1234,7 @@ class WebDAVSyncManager:
if project_id:
rows = conn.execute(
"SELECT * FROM webdav_syncs WHERE project_id = ? ORDER BY created_at DESC",
(project_id,),
(project_id, ),
).fetchall()
else:
rows = conn.execute("SELECT * FROM webdav_syncs ORDER BY created_at DESC").fetchall()
@@ -1283,7 +1283,7 @@ class WebDAVSyncManager:
def delete_sync(self, sync_id: str) -> bool:
"""删除同步配置"""
conn = self.pm.db.get_conn()
cursor = conn.execute("DELETE FROM webdav_syncs WHERE id = ?", (sync_id,))
cursor = conn.execute("DELETE FROM webdav_syncs WHERE id = ?", (sync_id, ))
conn.commit()
conn.close()
@@ -1292,22 +1292,22 @@ class WebDAVSyncManager:
def _row_to_sync(self, row: sqlite3.Row) -> WebDAVSync:
"""将数据库行转换为 WebDAVSync 对象"""
return WebDAVSync(
id=row["id"],
name=row["name"],
project_id=row["project_id"],
server_url=row["server_url"],
username=row["username"],
password=row["password"],
remote_path=row["remote_path"],
sync_mode=row["sync_mode"],
sync_interval=row["sync_interval"],
last_sync_at=row["last_sync_at"],
last_sync_status=row["last_sync_status"],
last_sync_error=row["last_sync_error"] or "",
is_active=bool(row["is_active"]),
created_at=row["created_at"],
updated_at=row["updated_at"],
sync_count=row["sync_count"],
id = row["id"],
name = row["name"],
project_id = row["project_id"],
server_url = row["server_url"],
username = row["username"],
password = row["password"],
remote_path = row["remote_path"],
sync_mode = row["sync_mode"],
sync_interval = row["sync_interval"],
last_sync_at = row["last_sync_at"],
last_sync_status = row["last_sync_status"],
last_sync_error = row["last_sync_error"] or "",
is_active = bool(row["is_active"]),
created_at = row["created_at"],
updated_at = row["updated_at"],
sync_count = row["sync_count"],
)
async def test_connection(self, sync: WebDAVSync) -> dict:
@@ -1316,7 +1316,7 @@ class WebDAVSyncManager:
return {"success": False, "error": "WebDAV library not available"}
try:
client = webdav_client.Client(sync.server_url, auth=(sync.username, sync.password))
client = webdav_client.Client(sync.server_url, auth = (sync.username, sync.password))
# 尝试列出根目录
client.list("/")
@@ -1335,7 +1335,7 @@ class WebDAVSyncManager:
return {"success": False, "error": "Sync is not active"}
try:
client = webdav_client.Client(sync.server_url, auth=(sync.username, sync.password))
client = webdav_client.Client(sync.server_url, auth = (sync.username, sync.password))
# 确保远程目录存在
remote_project_path = f"{sync.remote_path}/{sync.project_id}"
@@ -1367,13 +1367,13 @@ class WebDAVSyncManager:
}
# 上传 JSON 文件
json_content = json.dumps(export_data, ensure_ascii=False, indent=2)
json_content = json.dumps(export_data, ensure_ascii = False, indent = 2)
json_path = f"{remote_project_path}/project_export.json"
# 使用临时文件上传
import tempfile
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
with tempfile.NamedTemporaryFile(mode = "w", suffix = ".json", delete = False) as f:
f.write(json_content)
temp_path = f.name
@@ -1419,7 +1419,7 @@ class WebDAVSyncManager:
_plugin_manager = None
def get_plugin_manager(db_manager=None) -> None:
def get_plugin_manager(db_manager = None) -> None:
"""获取 PluginManager 单例"""
global _plugin_manager
if _plugin_manager is None:

View File

@@ -35,7 +35,7 @@ class RateLimitInfo:
class SlidingWindowCounter:
"""滑动窗口计数器"""
def __init__(self, window_size: int = 60):
def __init__(self, window_size: int = 60) -> None:
self.window_size = window_size
self.requests: dict[int, int] = defaultdict(int) # 秒级计数
self._lock = asyncio.Lock()
@@ -110,17 +110,17 @@ class RateLimiter:
# 检查是否超过限制
if current_count >= stored_config.requests_per_minute:
return RateLimitInfo(
allowed=False,
remaining=0,
reset_time=reset_time,
retry_after=stored_config.window_size,
allowed = False,
remaining = 0,
reset_time = reset_time,
retry_after = stored_config.window_size,
)
# 允许请求,增加计数
await counter.add_request()
return RateLimitInfo(
allowed=True, remaining=remaining - 1, reset_time=reset_time, retry_after=0
allowed = True, remaining = remaining - 1, reset_time = reset_time, retry_after = 0
)
async def get_limit_info(self, key: str) -> RateLimitInfo:
@@ -128,10 +128,10 @@ class RateLimiter:
if key not in self.counters:
config = RateLimitConfig()
return RateLimitInfo(
allowed=True,
remaining=config.requests_per_minute,
reset_time=int(time.time()) + config.window_size,
retry_after=0,
allowed = True,
remaining = config.requests_per_minute,
reset_time = int(time.time()) + config.window_size,
retry_after = 0,
)
counter = self.counters[key]
@@ -142,10 +142,10 @@ class RateLimiter:
reset_time = int(time.time()) + config.window_size
return RateLimitInfo(
allowed=current_count < config.requests_per_minute,
remaining=remaining,
reset_time=reset_time,
retry_after=max(0, config.window_size)
allowed = current_count < config.requests_per_minute,
remaining = remaining,
reset_time = reset_time,
retry_after = max(0, config.window_size)
if current_count >= config.requests_per_minute
else 0,
)
@@ -184,12 +184,12 @@ def rate_limit(requests_per_minute: int = 60, key_func: Callable | None = None)
key_func: 生成限流键的函数,默认为 None使用函数名
"""
def decorator(func):
def decorator(func) -> None:
limiter = get_rate_limiter()
config = RateLimitConfig(requests_per_minute=requests_per_minute)
config = RateLimitConfig(requests_per_minute = requests_per_minute)
@wraps(func)
async def async_wrapper(*args, **kwargs):
async def async_wrapper(*args, **kwargs) -> None:
key = key_func(*args, **kwargs) if key_func else func.__name__
info = await limiter.is_allowed(key, config)
@@ -201,7 +201,7 @@ def rate_limit(requests_per_minute: int = 60, key_func: Callable | None = None)
return await func(*args, **kwargs)
@wraps(func)
def sync_wrapper(*args, **kwargs):
def sync_wrapper(*args, **kwargs) -> None:
key = key_func(*args, **kwargs) if key_func else func.__name__
# 同步版本使用 asyncio.run
info = asyncio.run(limiter.is_allowed(key, config))

View File

@@ -49,8 +49,8 @@ class SearchResult:
content_type: str # transcript, entity, relation
project_id: str
score: float
highlights: list[tuple[int, int]] = field(default_factory=list) # 高亮位置
metadata: dict = field(default_factory=dict)
highlights: list[tuple[int, int]] = field(default_factory = list) # 高亮位置
metadata: dict = field(default_factory = dict)
def to_dict(self) -> dict:
return {
@@ -74,7 +74,7 @@ class SemanticSearchResult:
project_id: str
similarity: float
embedding: list[float] | None = None
metadata: dict = field(default_factory=dict)
metadata: dict = field(default_factory = dict)
def to_dict(self) -> dict:
result = {
@@ -132,7 +132,7 @@ class KnowledgeGap:
severity: str # high, medium, low
suggestions: list[str]
related_entities: list[str]
metadata: dict = field(default_factory=dict)
metadata: dict = field(default_factory = dict)
def to_dict(self) -> dict:
return {
@@ -189,7 +189,7 @@ class FullTextSearch:
- 支持布尔搜索AND/OR/NOT
"""
def __init__(self, db_path: str = "insightflow.db"):
def __init__(self, db_path: str = "insightflow.db") -> None:
self.db_path = db_path
self._init_search_tables()
@@ -318,8 +318,8 @@ class FullTextSearch:
content_id,
content_type,
project_id,
json.dumps(tokens, ensure_ascii=False),
json.dumps(token_positions, ensure_ascii=False),
json.dumps(tokens, ensure_ascii = False),
json.dumps(token_positions, ensure_ascii = False),
now,
now,
),
@@ -340,7 +340,7 @@ class FullTextSearch:
content_type,
project_id,
freq,
json.dumps(positions, ensure_ascii=False),
json.dumps(positions, ensure_ascii = False),
),
)
@@ -383,7 +383,7 @@ class FullTextSearch:
scored_results = self._score_results(results, parsed_query)
# 排序和分页
scored_results.sort(key=lambda x: x.score, reverse=True)
scored_results.sort(key = lambda x: x.score, reverse = True)
return scored_results[offset : offset + limit]
@@ -412,10 +412,10 @@ class FullTextSearch:
not_pattern = r"(?:NOT\s+|\-)(\w+)"
not_matches = re.findall(not_pattern, query_without_phrases, re.IGNORECASE)
not_terms.extend(not_matches)
query_without_phrases = re.sub(not_pattern, "", query_without_phrases, flags=re.IGNORECASE)
query_without_phrases = re.sub(not_pattern, "", query_without_phrases, flags = re.IGNORECASE)
# 处理 OR
or_parts = re.split(r"\s+OR\s+", query_without_phrases, flags=re.IGNORECASE)
or_parts = re.split(r"\s+OR\s+", query_without_phrases, flags = re.IGNORECASE)
if len(or_parts) > 1:
or_terms = [p.strip() for p in or_parts[1:] if p.strip()]
query_without_phrases = or_parts[0]
@@ -443,11 +443,11 @@ class FullTextSearch:
params.append(project_id)
if content_types:
placeholders = ",".join(["?" for _ in content_types])
placeholders = ", ".join(["?" for _ in content_types])
base_where.append(f"content_type IN ({placeholders})")
params.extend(content_types)
base_where_str = " AND ".join(base_where) if base_where else "1=1"
base_where_str = " AND ".join(base_where) if base_where else "1 = 1"
# 获取候选结果
candidates = set()
@@ -551,13 +551,13 @@ class FullTextSearch:
try:
if content_type == "transcript":
row = conn.execute(
"SELECT full_text FROM transcripts WHERE id = ?", (content_id,)
"SELECT full_text FROM transcripts WHERE id = ?", (content_id, )
).fetchone()
return row["full_text"] if row else None
elif content_type == "entity":
row = conn.execute(
"SELECT name, definition FROM entities WHERE id = ?", (content_id,)
"SELECT name, definition FROM entities WHERE id = ?", (content_id, )
).fetchone()
if row:
return f"{row['name']} {row['definition'] or ''}"
@@ -571,7 +571,7 @@ class FullTextSearch:
JOIN entities e1 ON r.source_entity_id = e1.id
JOIN entities e2 ON r.target_entity_id = e2.id
WHERE r.id = ?""",
(content_id,),
(content_id, ),
).fetchone()
if row:
return f"{row['source_name']} {row['relation_type']} {row['target_name']} {row['evidence'] or ''}"
@@ -589,15 +589,15 @@ class FullTextSearch:
try:
if content_type == "transcript":
row = conn.execute(
"SELECT project_id FROM transcripts WHERE id = ?", (content_id,)
"SELECT project_id FROM transcripts WHERE id = ?", (content_id, )
).fetchone()
elif content_type == "entity":
row = conn.execute(
"SELECT project_id FROM entities WHERE id = ?", (content_id,)
"SELECT project_id FROM entities WHERE id = ?", (content_id, )
).fetchone()
elif content_type == "relation":
row = conn.execute(
"SELECT project_id FROM entity_relations WHERE id = ?", (content_id,)
"SELECT project_id FROM entity_relations WHERE id = ?", (content_id, )
).fetchone()
else:
return None
@@ -654,13 +654,13 @@ class FullTextSearch:
scored.append(
SearchResult(
id=result["id"],
content=result["content"],
content_type=result["content_type"],
project_id=result["project_id"],
score=round(score, 4),
highlights=highlights[:10], # 限制高亮数量
metadata={},
id = result["id"],
content = result["content"],
content_type = result["content_type"],
project_id = result["project_id"],
score = round(score, 4),
highlights = highlights[:10], # 限制高亮数量
metadata = {},
)
)
@@ -699,7 +699,7 @@ class FullTextSearch:
snippet = snippet + "..."
# 添加高亮标记
for term in sorted(all_terms, key=len, reverse=True): # 长的先替换
for term in sorted(all_terms, key = len, reverse = True): # 长的先替换
pattern = re.compile(re.escape(term), re.IGNORECASE)
snippet = pattern.sub(f"**{term}**", snippet)
@@ -738,7 +738,7 @@ class FullTextSearch:
# 索引转录文本
transcripts = conn.execute(
"SELECT id, project_id, full_text FROM transcripts WHERE project_id = ?",
(project_id,),
(project_id, ),
).fetchall()
for t in transcripts:
@@ -751,7 +751,7 @@ class FullTextSearch:
# 索引实体
entities = conn.execute(
"SELECT id, project_id, name, definition FROM entities WHERE project_id = ?",
(project_id,),
(project_id, ),
).fetchall()
for e in entities:
@@ -769,7 +769,7 @@ class FullTextSearch:
JOIN entities e1 ON r.source_entity_id = e1.id
JOIN entities e2 ON r.target_entity_id = e2.id
WHERE r.project_id = ?""",
(project_id,),
(project_id, ),
).fetchall()
for r in relations:
@@ -805,7 +805,7 @@ class SemanticSearch:
self,
db_path: str = "insightflow.db",
model_name: str = "paraphrase-multilingual-MiniLM-L12-v2",
):
) -> None:
self.db_path = db_path
self.model_name = model_name
self.model = None
@@ -873,7 +873,7 @@ class SemanticSearch:
if len(text) > max_chars:
text = text[:max_chars]
embedding = self.model.encode(text, convert_to_list=True)
embedding = self.model.encode(text, convert_to_list = True)
return embedding
except Exception as e:
print(f"生成 embedding 失败: {e}")
@@ -971,11 +971,11 @@ class SemanticSearch:
params.append(project_id)
if content_types:
placeholders = ",".join(["?" for _ in content_types])
placeholders = ", ".join(["?" for _ in content_types])
where_clauses.append(f"content_type IN ({placeholders})")
params.extend(content_types)
where_str = " AND ".join(where_clauses) if where_clauses else "1=1"
where_str = " AND ".join(where_clauses) if where_clauses else "1 = 1"
rows = conn.execute(
f"""
@@ -1005,13 +1005,13 @@ class SemanticSearch:
results.append(
SemanticSearchResult(
id=row["content_id"],
content=content or "",
content_type=row["content_type"],
project_id=row["project_id"],
similarity=float(similarity),
embedding=None, # 不返回 embedding 以节省带宽
metadata={},
id = row["content_id"],
content = content or "",
content_type = row["content_type"],
project_id = row["project_id"],
similarity = float(similarity),
embedding = None, # 不返回 embedding 以节省带宽
metadata = {},
)
)
except Exception as e:
@@ -1019,7 +1019,7 @@ class SemanticSearch:
continue
# 排序并返回 top_k
results.sort(key=lambda x: x.similarity, reverse=True)
results.sort(key = lambda x: x.similarity, reverse = True)
return results[:top_k]
def _get_content_text(self, content_id: str, content_type: str) -> str | None:
@@ -1029,13 +1029,13 @@ class SemanticSearch:
try:
if content_type == "transcript":
row = conn.execute(
"SELECT full_text FROM transcripts WHERE id = ?", (content_id,)
"SELECT full_text FROM transcripts WHERE id = ?", (content_id, )
).fetchone()
result = row["full_text"] if row else None
elif content_type == "entity":
row = conn.execute(
"SELECT name, definition FROM entities WHERE id = ?", (content_id,)
"SELECT name, definition FROM entities WHERE id = ?", (content_id, )
).fetchone()
result = f"{row['name']}: {row['definition']}" if row else None
@@ -1047,7 +1047,7 @@ class SemanticSearch:
JOIN entities e1 ON r.source_entity_id = e1.id
JOIN entities e2 ON r.target_entity_id = e2.id
WHERE r.id = ?""",
(content_id,),
(content_id, ),
).fetchone()
result = (
f"{row['source_name']} {row['relation_type']} {row['target_name']}"
@@ -1121,18 +1121,18 @@ class SemanticSearch:
results.append(
SemanticSearchResult(
id=row["content_id"],
content=content or "",
content_type=row["content_type"],
project_id=row["project_id"],
similarity=float(similarity),
metadata={},
id = row["content_id"],
content = content or "",
content_type = row["content_type"],
project_id = row["project_id"],
similarity = float(similarity),
metadata = {},
)
)
except (KeyError, ValueError):
continue
results.sort(key=lambda x: x.similarity, reverse=True)
results.sort(key = lambda x: x.similarity, reverse = True)
return results[:top_k]
def delete_embedding(self, content_id: str, content_type: str) -> bool:
@@ -1165,7 +1165,7 @@ class EntityPathDiscovery:
- 路径可视化数据生成
"""
def __init__(self, db_path: str = "insightflow.db"):
def __init__(self, db_path: str = "insightflow.db") -> None:
self.db_path = db_path
def _get_conn(self) -> sqlite3.Connection:
@@ -1192,7 +1192,7 @@ class EntityPathDiscovery:
# 获取项目ID
row = conn.execute(
"SELECT project_id FROM entities WHERE id = ?", (source_entity_id,)
"SELECT project_id FROM entities WHERE id = ?", (source_entity_id, )
).fetchone()
if not row:
@@ -1267,7 +1267,7 @@ class EntityPathDiscovery:
# 获取项目ID
row = conn.execute(
"SELECT project_id FROM entities WHERE id = ?", (source_entity_id,)
"SELECT project_id FROM entities WHERE id = ?", (source_entity_id, )
).fetchone()
if not row:
@@ -1278,7 +1278,7 @@ class EntityPathDiscovery:
paths = []
def dfs(current_id: str, target_id: str, path: list[str], visited: set[str], depth: int):
def dfs(current_id: str, target_id: str, path: list[str], visited: set[str], depth: int) -> None:
if depth > max_depth:
return
@@ -1325,7 +1325,7 @@ class EntityPathDiscovery:
nodes = []
for entity_id in entity_ids:
row = conn.execute(
"SELECT id, name, type FROM entities WHERE id = ?", (entity_id,)
"SELECT id, name, type FROM entities WHERE id = ?", (entity_id, )
).fetchone()
if row:
nodes.append({"id": row["id"], "name": row["name"], "type": row["type"]})
@@ -1368,16 +1368,16 @@ class EntityPathDiscovery:
confidence = 1.0 / (len(entity_ids) - 1) if len(entity_ids) > 1 else 1.0
return EntityPath(
path_id=f"path_{entity_ids[0]}_{entity_ids[-1]}_{hash(tuple(entity_ids))}",
source_entity_id=entity_ids[0],
source_entity_name=nodes[0]["name"] if nodes else "",
target_entity_id=entity_ids[-1],
target_entity_name=nodes[-1]["name"] if nodes else "",
path_length=len(entity_ids) - 1,
nodes=nodes,
edges=edges,
confidence=round(confidence, 4),
path_description=path_desc,
path_id = f"path_{entity_ids[0]}_{entity_ids[-1]}_{hash(tuple(entity_ids))}",
source_entity_id = entity_ids[0],
source_entity_name = nodes[0]["name"] if nodes else "",
target_entity_id = entity_ids[-1],
target_entity_name = nodes[-1]["name"] if nodes else "",
path_length = len(entity_ids) - 1,
nodes = nodes,
edges = edges,
confidence = round(confidence, 4),
path_description = path_desc,
)
def find_multi_hop_relations(self, entity_id: str, max_hops: int = 3) -> list[dict]:
@@ -1395,7 +1395,7 @@ class EntityPathDiscovery:
# 获取项目ID
row = conn.execute(
"SELECT project_id, name FROM entities WHERE id = ?", (entity_id,)
"SELECT project_id, name FROM entities WHERE id = ?", (entity_id, )
).fetchone()
if not row:
@@ -1442,7 +1442,7 @@ class EntityPathDiscovery:
# 获取邻居信息
neighbor_info = conn.execute(
"SELECT name, type FROM entities WHERE id = ?", (neighbor_id,)
"SELECT name, type FROM entities WHERE id = ?", (neighbor_id, )
).fetchone()
if neighbor_info:
@@ -1463,7 +1463,7 @@ class EntityPathDiscovery:
conn.close()
# 按跳数排序
relations.sort(key=lambda x: x["hops"])
relations.sort(key = lambda x: x["hops"])
return relations
def _get_path_to_entity(
@@ -1562,7 +1562,7 @@ class EntityPathDiscovery:
# 获取所有实体
entities = conn.execute(
"SELECT id, name FROM entities WHERE project_id = ?", (project_id,)
"SELECT id, name FROM entities WHERE project_id = ?", (project_id, )
).fetchall()
# 计算每个实体作为桥梁的次数
@@ -1594,10 +1594,10 @@ class EntityPathDiscovery:
f"""
SELECT COUNT(*) as count
FROM entity_relations
WHERE ((source_entity_id IN ({",".join(["?" for _ in neighbor_ids])})
AND target_entity_id IN ({",".join(["?" for _ in neighbor_ids])}))
OR (target_entity_id IN ({",".join(["?" for _ in neighbor_ids])})
AND source_entity_id IN ({",".join(["?" for _ in neighbor_ids])})))
WHERE ((source_entity_id IN ({", ".join(["?" for _ in neighbor_ids])})
AND target_entity_id IN ({", ".join(["?" for _ in neighbor_ids])}))
OR (target_entity_id IN ({", ".join(["?" for _ in neighbor_ids])})
AND source_entity_id IN ({", ".join(["?" for _ in neighbor_ids])})))
AND project_id = ?
""",
list(neighbor_ids) * 4 + [project_id],
@@ -1620,7 +1620,7 @@ class EntityPathDiscovery:
conn.close()
# 按桥接分数排序
bridge_scores.sort(key=lambda x: x["bridge_score"], reverse=True)
bridge_scores.sort(key = lambda x: x["bridge_score"], reverse = True)
return bridge_scores[:20] # 返回前20
@@ -1638,7 +1638,7 @@ class KnowledgeGapDetection:
- 生成知识补全建议
"""
def __init__(self, db_path: str = "insightflow.db"):
def __init__(self, db_path: str = "insightflow.db") -> None:
self.db_path = db_path
def _get_conn(self) -> sqlite3.Connection:
@@ -1676,7 +1676,7 @@ class KnowledgeGapDetection:
# 按严重程度排序
severity_order = {"high": 0, "medium": 1, "low": 2}
gaps.sort(key=lambda x: severity_order.get(x.severity, 3))
gaps.sort(key = lambda x: severity_order.get(x.severity, 3))
return gaps
@@ -1688,7 +1688,7 @@ class KnowledgeGapDetection:
# 获取项目的属性模板
templates = conn.execute(
"SELECT id, name, type, is_required FROM attribute_templates WHERE project_id = ?",
(project_id,),
(project_id, ),
).fetchall()
if not templates:
@@ -1703,7 +1703,7 @@ class KnowledgeGapDetection:
# 检查每个实体的属性完整性
entities = conn.execute(
"SELECT id, name FROM entities WHERE project_id = ?", (project_id,)
"SELECT id, name FROM entities WHERE project_id = ?", (project_id, )
).fetchall()
for entity in entities:
@@ -1711,7 +1711,7 @@ class KnowledgeGapDetection:
# 获取实体已有的属性
existing_attrs = conn.execute(
"SELECT template_id FROM entity_attributes WHERE entity_id = ?", (entity_id,)
"SELECT template_id FROM entity_attributes WHERE entity_id = ?", (entity_id, )
).fetchall()
existing_template_ids = {a["template_id"] for a in existing_attrs}
@@ -1723,7 +1723,7 @@ class KnowledgeGapDetection:
missing_names = []
for template_id in missing_templates:
template = conn.execute(
"SELECT name FROM attribute_templates WHERE id = ?", (template_id,)
"SELECT name FROM attribute_templates WHERE id = ?", (template_id, )
).fetchone()
if template:
missing_names.append(template["name"])
@@ -1731,18 +1731,18 @@ class KnowledgeGapDetection:
if missing_names:
gaps.append(
KnowledgeGap(
gap_id=f"gap_attr_{entity_id}",
gap_type="missing_attribute",
entity_id=entity_id,
entity_name=entity["name"],
description=f"实体 '{entity['name']}' 缺少必需属性: {', '.join(missing_names)}",
severity="medium",
suggestions=[
gap_id = f"gap_attr_{entity_id}",
gap_type = "missing_attribute",
entity_id = entity_id,
entity_name = entity["name"],
description = f"实体 '{entity['name']}' 缺少必需属性: {', '.join(missing_names)}",
severity = "medium",
suggestions = [
f"为实体 '{entity['name']}' 补充以下属性: {', '.join(missing_names)}",
"检查属性模板定义是否合理",
],
related_entities=[],
metadata={"missing_attributes": missing_names},
related_entities = [],
metadata = {"missing_attributes": missing_names},
)
)
@@ -1756,7 +1756,7 @@ class KnowledgeGapDetection:
# 获取所有实体及其关系数量
entities = conn.execute(
"SELECT id, name, type FROM entities WHERE project_id = ?", (project_id,)
"SELECT id, name, type FROM entities WHERE project_id = ?", (project_id, )
).fetchall()
for entity in entities:
@@ -1793,19 +1793,19 @@ class KnowledgeGapDetection:
gaps.append(
KnowledgeGap(
gap_id=f"gap_sparse_{entity_id}",
gap_type="sparse_relation",
entity_id=entity_id,
entity_name=entity["name"],
description=f"实体 '{entity['name']}' 关系稀疏(仅有 {relation_count} 个关系)",
severity="medium" if relation_count == 0 else "low",
suggestions=[
gap_id = f"gap_sparse_{entity_id}",
gap_type = "sparse_relation",
entity_id = entity_id,
entity_name = entity["name"],
description = f"实体 '{entity['name']}' 关系稀疏(仅有 {relation_count} 个关系)",
severity = "medium" if relation_count == 0 else "low",
suggestions = [
f"检查转录文本中提及 '{entity['name']}' 的其他实体",
f"手动添加 '{entity['name']}' 与其他实体的关系",
"使用实体对齐功能合并相似实体",
],
related_entities=[r["id"] for r in potential_related],
metadata={
related_entities = [r["id"] for r in potential_related],
metadata = {
"relation_count": relation_count,
"potential_related": [r["name"] for r in potential_related],
},
@@ -1831,25 +1831,25 @@ class KnowledgeGapDetection:
AND r1.id IS NULL
AND r2.id IS NULL
""",
(project_id,),
(project_id, ),
).fetchall()
for entity in isolated:
gaps.append(
KnowledgeGap(
gap_id=f"gap_iso_{entity['id']}",
gap_type="isolated_entity",
entity_id=entity["id"],
entity_name=entity["name"],
description=f"实体 '{entity['name']}' 是孤立实体(没有任何关系)",
severity="high",
suggestions=[
gap_id = f"gap_iso_{entity['id']}",
gap_type = "isolated_entity",
entity_id = entity["id"],
entity_name = entity["name"],
description = f"实体 '{entity['name']}' 是孤立实体(没有任何关系)",
severity = "high",
suggestions = [
f"检查 '{entity['name']}' 是否应该与其他实体建立关系",
f"考虑删除不相关的实体 '{entity['name']}'",
"运行关系发现算法自动识别潜在关系",
],
related_entities=[],
metadata={"entity_type": entity["type"]},
related_entities = [],
metadata = {"entity_type": entity["type"]},
)
)
@@ -1869,21 +1869,21 @@ class KnowledgeGapDetection:
WHERE project_id = ?
AND (definition IS NULL OR definition = '')
""",
(project_id,),
(project_id, ),
).fetchall()
for entity in incomplete:
gaps.append(
KnowledgeGap(
gap_id=f"gap_inc_{entity['id']}",
gap_type="incomplete_entity",
entity_id=entity["id"],
entity_name=entity["name"],
description=f"实体 '{entity['name']}' 缺少定义",
severity="low",
suggestions=[f"'{entity['name']}' 添加定义", "从转录文本中提取定义信息"],
related_entities=[],
metadata={"entity_type": entity["type"]},
gap_id = f"gap_inc_{entity['id']}",
gap_type = "incomplete_entity",
entity_id = entity["id"],
entity_name = entity["name"],
description = f"实体 '{entity['name']}' 缺少定义",
severity = "low",
suggestions = [f"'{entity['name']}' 添加定义", "从转录文本中提取定义信息"],
related_entities = [],
metadata = {"entity_type": entity["type"]},
)
)
@@ -1897,7 +1897,7 @@ class KnowledgeGapDetection:
# 分析转录文本中频繁提及但未提取为实体的词
transcripts = conn.execute(
"SELECT full_text FROM transcripts WHERE project_id = ?", (project_id,)
"SELECT full_text FROM transcripts WHERE project_id = ?", (project_id, )
).fetchall()
# 合并所有文本
@@ -1905,7 +1905,7 @@ class KnowledgeGapDetection:
# 获取现有实体名称
existing_entities = conn.execute(
"SELECT name FROM entities WHERE project_id = ?", (project_id,)
"SELECT name FROM entities WHERE project_id = ?", (project_id, )
).fetchall()
existing_names = {e["name"].lower() for e in existing_entities}
@@ -1925,18 +1925,18 @@ class KnowledgeGapDetection:
if count >= 3: # 出现3次以上
gaps.append(
KnowledgeGap(
gap_id=f"gap_missing_{hash(entity) % 10000}",
gap_type="missing_key_entity",
entity_id=None,
entity_name=None,
description=f"文本中频繁提及 '{entity}' 但未提取为实体(出现 {count} 次)",
severity="low",
suggestions=[
gap_id = f"gap_missing_{hash(entity) % 10000}",
gap_type = "missing_key_entity",
entity_id = None,
entity_name = None,
description = f"文本中频繁提及 '{entity}' 但未提取为实体(出现 {count} 次)",
severity = "low",
suggestions = [
f"考虑将 '{entity}' 添加为实体",
"检查实体提取算法是否需要优化",
],
related_entities=[],
metadata={"mention_count": count},
related_entities = [],
metadata = {"mention_count": count},
)
)
@@ -2040,7 +2040,7 @@ class SearchManager:
整合全文搜索、语义搜索、实体路径发现和知识缺口识别功能
"""
def __init__(self, db_path: str = "insightflow.db"):
def __init__(self, db_path: str = "insightflow.db") -> None:
self.db_path = db_path
self.fulltext_search = FullTextSearch(db_path)
self.semantic_search = SemanticSearch(db_path)
@@ -2060,12 +2060,12 @@ class SearchManager:
Dict: 混合搜索结果
"""
# 全文搜索
fulltext_results = self.fulltext_search.search(query, project_id, limit=limit)
fulltext_results = self.fulltext_search.search(query, project_id, limit = limit)
# 语义搜索
semantic_results = []
if self.semantic_search.is_available():
semantic_results = self.semantic_search.search(query, project_id, top_k=limit)
semantic_results = self.semantic_search.search(query, project_id, top_k = limit)
# 合并结果(去重并加权)
combined = {}
@@ -2104,7 +2104,7 @@ class SearchManager:
# 排序
results = list(combined.values())
results.sort(key=lambda x: x["combined_score"], reverse=True)
results.sort(key = lambda x: x["combined_score"], reverse = True)
return {
"query": query,
@@ -2138,7 +2138,7 @@ class SearchManager:
# 索引转录文本
transcripts = conn.execute(
"SELECT id, project_id, full_text FROM transcripts WHERE project_id = ?",
(project_id,),
(project_id, ),
).fetchall()
for t in transcripts:
@@ -2152,7 +2152,7 @@ class SearchManager:
# 索引实体
entities = conn.execute(
"SELECT id, project_id, name, definition FROM entities WHERE project_id = ?",
(project_id,),
(project_id, ),
).fetchall()
for e in entities:
@@ -2191,7 +2191,7 @@ class SearchManager:
"""SELECT content_type, COUNT(*) as count
FROM search_indexes WHERE project_id = ?
GROUP BY content_type""",
(project_id,),
(project_id, ),
).fetchall()
type_stats = {r["content_type"]: r["count"] for r in rows}
@@ -2226,7 +2226,7 @@ def fulltext_search(
) -> list[SearchResult]:
"""全文搜索便捷函数"""
manager = get_search_manager()
return manager.fulltext_search.search(query, project_id, limit=limit)
return manager.fulltext_search.search(query, project_id, limit = limit)
def semantic_search(
@@ -2234,7 +2234,7 @@ def semantic_search(
) -> list[SemanticSearchResult]:
"""语义搜索便捷函数"""
manager = get_search_manager()
return manager.semantic_search.search(query, project_id, top_k=top_k)
return manager.semantic_search.search(query, project_id, top_k = top_k)
def find_entity_path(source_id: str, target_id: str, max_depth: int = 5) -> EntityPath | None:

View File

@@ -86,7 +86,7 @@ class AuditLog:
after_value: str | None = None
success: bool = True
error_message: str | None = None
created_at: str = field(default_factory=lambda: datetime.now().isoformat())
created_at: str = field(default_factory = lambda: datetime.now().isoformat())
def to_dict(self) -> dict[str, Any]:
return asdict(self)
@@ -103,8 +103,8 @@ class EncryptionConfig:
key_derivation: str = "pbkdf2" # pbkdf2, argon2
master_key_hash: str | None = None # 主密钥哈希(用于验证)
salt: str | None = None
created_at: str = field(default_factory=lambda: datetime.now().isoformat())
updated_at: str = field(default_factory=lambda: datetime.now().isoformat())
created_at: str = field(default_factory = lambda: datetime.now().isoformat())
updated_at: str = field(default_factory = lambda: datetime.now().isoformat())
def to_dict(self) -> dict[str, Any]:
return asdict(self)
@@ -123,8 +123,8 @@ class MaskingRule:
is_active: bool = True
priority: int = 0
description: str | None = None
created_at: str = field(default_factory=lambda: datetime.now().isoformat())
updated_at: str = field(default_factory=lambda: datetime.now().isoformat())
created_at: str = field(default_factory = lambda: datetime.now().isoformat())
updated_at: str = field(default_factory = lambda: datetime.now().isoformat())
def to_dict(self) -> dict[str, Any]:
return asdict(self)
@@ -145,8 +145,8 @@ class DataAccessPolicy:
max_access_count: int | None = None # 最大访问次数
require_approval: bool = False
is_active: bool = True
created_at: str = field(default_factory=lambda: datetime.now().isoformat())
updated_at: str = field(default_factory=lambda: datetime.now().isoformat())
created_at: str = field(default_factory = lambda: datetime.now().isoformat())
updated_at: str = field(default_factory = lambda: datetime.now().isoformat())
def to_dict(self) -> dict[str, Any]:
return asdict(self)
@@ -164,7 +164,7 @@ class AccessRequest:
approved_by: str | None = None
approved_at: str | None = None
expires_at: str | None = None
created_at: str = field(default_factory=lambda: datetime.now().isoformat())
created_at: str = field(default_factory = lambda: datetime.now().isoformat())
def to_dict(self) -> dict[str, Any]:
return asdict(self)
@@ -176,7 +176,7 @@ class SecurityManager:
# 预定义脱敏规则
DEFAULT_MASKING_RULES = {
MaskingRuleType.PHONE: {"pattern": r"(\d{3})\d{4}(\d{4})", "replacement": r"\1****\2"},
MaskingRuleType.EMAIL: {"pattern": r"(\w{1,3})\w+(@\w+\.\w+)", "replacement": r"\1***\2"},
MaskingRuleType.EMAIL: {"pattern": r"(\w{1, 3})\w+(@\w+\.\w+)", "replacement": r"\1***\2"},
MaskingRuleType.ID_CARD: {
"pattern": r"(\d{6})\d{8}(\d{4})",
"replacement": r"\1********\2",
@@ -190,12 +190,12 @@ class SecurityManager:
"replacement": r"\1**",
},
MaskingRuleType.ADDRESS: {
"pattern": r"([\u4e00-\u9fa5]{2,})([\u4e00-\u9fa5]+路|街|巷|号)(.+)",
"pattern": r"([\u4e00-\u9fa5]{2, })([\u4e00-\u9fa5]+路|街|巷|号)(.+)",
"replacement": r"\1\2***",
},
}
def __init__(self, db_path: str = "insightflow.db"):
def __init__(self, db_path: str = "insightflow.db") -> None:
self.db_path = db_path
# 预编译正则缓存
self._compiled_patterns: dict[str, re.Pattern] = {}
@@ -345,18 +345,18 @@ class SecurityManager:
) -> AuditLog:
"""记录审计日志"""
log = AuditLog(
id=self._generate_id(),
action_type=action_type.value,
user_id=user_id,
user_ip=user_ip,
user_agent=user_agent,
resource_type=resource_type,
resource_id=resource_id,
action_details=json.dumps(action_details) if action_details else None,
before_value=before_value,
after_value=after_value,
success=success,
error_message=error_message,
id = self._generate_id(),
action_type = action_type.value,
user_id = user_id,
user_ip = user_ip,
user_agent = user_agent,
resource_type = resource_type,
resource_id = resource_id,
action_details = json.dumps(action_details) if action_details else None,
before_value = before_value,
after_value = after_value,
success = success,
error_message = error_message,
)
conn = sqlite3.connect(self.db_path)
@@ -405,7 +405,7 @@ class SecurityManager:
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
query = "SELECT * FROM audit_logs WHERE 1=1"
query = "SELECT * FROM audit_logs WHERE 1 = 1"
params = []
if user_id:
@@ -444,19 +444,19 @@ class SecurityManager:
for row in rows:
log = AuditLog(
id=row[0],
action_type=row[1],
user_id=row[2],
user_ip=row[3],
user_agent=row[4],
resource_type=row[5],
resource_id=row[6],
action_details=row[7],
before_value=row[8],
after_value=row[9],
success=bool(row[10]),
error_message=row[11],
created_at=row[12],
id = row[0],
action_type = row[1],
user_id = row[2],
user_ip = row[3],
user_agent = row[4],
resource_type = row[5],
resource_id = row[6],
action_details = row[7],
before_value = row[8],
after_value = row[9],
success = bool(row[10]),
error_message = row[11],
created_at = row[12],
)
logs.append(log)
@@ -470,7 +470,7 @@ class SecurityManager:
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
query = "SELECT action_type, success, COUNT(*) FROM audit_logs WHERE 1=1"
query = "SELECT action_type, success, COUNT(*) FROM audit_logs WHERE 1 = 1"
params = []
if start_time:
@@ -513,10 +513,10 @@ class SecurityManager:
raise RuntimeError("cryptography library not available")
kdf = PBKDF2HMAC(
algorithm=hashes.SHA256(),
length=32,
salt=salt,
iterations=100000,
algorithm = hashes.SHA256(),
length = 32,
salt = salt,
iterations = 100000,
)
return base64.urlsafe_b64encode(kdf.derive(password.encode()))
@@ -533,20 +533,20 @@ class SecurityManager:
key_hash = hashlib.sha256(key).hexdigest()
config = EncryptionConfig(
id=self._generate_id(),
project_id=project_id,
is_enabled=True,
encryption_type="aes-256-gcm",
key_derivation="pbkdf2",
master_key_hash=key_hash,
salt=salt,
id = self._generate_id(),
project_id = project_id,
is_enabled = True,
encryption_type = "aes-256-gcm",
key_derivation = "pbkdf2",
master_key_hash = key_hash,
salt = salt,
)
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
# 检查是否已存在配置
cursor.execute("SELECT id FROM encryption_configs WHERE project_id = ?", (project_id,))
cursor.execute("SELECT id FROM encryption_configs WHERE project_id = ?", (project_id, ))
existing = cursor.fetchone()
if existing:
@@ -593,10 +593,10 @@ class SecurityManager:
# 记录审计日志
self.log_audit(
action_type=AuditActionType.ENCRYPTION_ENABLE,
resource_type="project",
resource_id=project_id,
action_details={"encryption_type": config.encryption_type},
action_type = AuditActionType.ENCRYPTION_ENABLE,
resource_type = "project",
resource_id = project_id,
action_details = {"encryption_type": config.encryption_type},
)
return config
@@ -624,9 +624,9 @@ class SecurityManager:
# 记录审计日志
self.log_audit(
action_type=AuditActionType.ENCRYPTION_DISABLE,
resource_type="project",
resource_id=project_id,
action_type = AuditActionType.ENCRYPTION_DISABLE,
resource_type = "project",
resource_id = project_id,
)
return True
@@ -641,7 +641,7 @@ class SecurityManager:
cursor.execute(
"SELECT master_key_hash, salt FROM encryption_configs WHERE project_id = ?",
(project_id,),
(project_id, ),
)
row = cursor.fetchone()
conn.close()
@@ -660,7 +660,7 @@ class SecurityManager:
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
cursor.execute("SELECT * FROM encryption_configs WHERE project_id = ?", (project_id,))
cursor.execute("SELECT * FROM encryption_configs WHERE project_id = ?", (project_id, ))
row = cursor.fetchone()
conn.close()
@@ -668,15 +668,15 @@ class SecurityManager:
return None
return EncryptionConfig(
id=row[0],
project_id=row[1],
is_enabled=bool(row[2]),
encryption_type=row[3],
key_derivation=row[4],
master_key_hash=row[5],
salt=row[6],
created_at=row[7],
updated_at=row[8],
id = row[0],
project_id = row[1],
is_enabled = bool(row[2]),
encryption_type = row[3],
key_derivation = row[4],
master_key_hash = row[5],
salt = row[6],
created_at = row[7],
updated_at = row[8],
)
def encrypt_data(self, data: str, password: str, salt: str | None = None) -> tuple[str, str]:
@@ -724,14 +724,14 @@ class SecurityManager:
replacement = replacement or default["replacement"]
rule = MaskingRule(
id=self._generate_id(),
project_id=project_id,
name=name,
rule_type=rule_type.value,
pattern=pattern or "",
replacement=replacement or "****",
description=description,
priority=priority,
id = self._generate_id(),
project_id = project_id,
name = name,
rule_type = rule_type.value,
pattern = pattern or "",
replacement = replacement or "****",
description = description,
priority = priority,
)
conn = sqlite3.connect(self.db_path)
@@ -764,10 +764,10 @@ class SecurityManager:
# 记录审计日志
self.log_audit(
action_type=AuditActionType.DATA_MASKING,
resource_type="project",
resource_id=project_id,
action_details={"action": "create_rule", "rule_name": name},
action_type = AuditActionType.DATA_MASKING,
resource_type = "project",
resource_id = project_id,
action_details = {"action": "create_rule", "rule_name": name},
)
return rule
@@ -793,17 +793,17 @@ class SecurityManager:
for row in rows:
rules.append(
MaskingRule(
id=row[0],
project_id=row[1],
name=row[2],
rule_type=row[3],
pattern=row[4],
replacement=row[5],
is_active=bool(row[6]),
priority=row[7],
description=row[8],
created_at=row[9],
updated_at=row[10],
id = row[0],
project_id = row[1],
name = row[2],
rule_type = row[3],
pattern = row[4],
replacement = row[5],
is_active = bool(row[6]),
priority = row[7],
description = row[8],
created_at = row[9],
updated_at = row[10],
)
)
@@ -847,7 +847,7 @@ class SecurityManager:
# 获取更新后的规则
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
cursor.execute("SELECT * FROM masking_rules WHERE id = ?", (rule_id,))
cursor.execute("SELECT * FROM masking_rules WHERE id = ?", (rule_id, ))
row = cursor.fetchone()
conn.close()
@@ -855,17 +855,17 @@ class SecurityManager:
return None
return MaskingRule(
id=row[0],
project_id=row[1],
name=row[2],
rule_type=row[3],
pattern=row[4],
replacement=row[5],
is_active=bool(row[6]),
priority=row[7],
description=row[8],
created_at=row[9],
updated_at=row[10],
id = row[0],
project_id = row[1],
name = row[2],
rule_type = row[3],
pattern = row[4],
replacement = row[5],
is_active = bool(row[6]),
priority = row[7],
description = row[8],
created_at = row[9],
updated_at = row[10],
)
def delete_masking_rule(self, rule_id: str) -> bool:
@@ -873,7 +873,7 @@ class SecurityManager:
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
cursor.execute("DELETE FROM masking_rules WHERE id = ?", (rule_id,))
cursor.execute("DELETE FROM masking_rules WHERE id = ?", (rule_id, ))
success = cursor.rowcount > 0
conn.commit()
@@ -936,16 +936,16 @@ class SecurityManager:
) -> DataAccessPolicy:
"""创建数据访问策略"""
policy = DataAccessPolicy(
id=self._generate_id(),
project_id=project_id,
name=name,
description=description,
allowed_users=json.dumps(allowed_users) if allowed_users else None,
allowed_roles=json.dumps(allowed_roles) if allowed_roles else None,
allowed_ips=json.dumps(allowed_ips) if allowed_ips else None,
time_restrictions=json.dumps(time_restrictions) if time_restrictions else None,
max_access_count=max_access_count,
require_approval=require_approval,
id = self._generate_id(),
project_id = project_id,
name = name,
description = description,
allowed_users = json.dumps(allowed_users) if allowed_users else None,
allowed_roles = json.dumps(allowed_roles) if allowed_roles else None,
allowed_ips = json.dumps(allowed_ips) if allowed_ips else None,
time_restrictions = json.dumps(time_restrictions) if time_restrictions else None,
max_access_count = max_access_count,
require_approval = require_approval,
)
conn = sqlite3.connect(self.db_path)
@@ -1002,19 +1002,19 @@ class SecurityManager:
for row in rows:
policies.append(
DataAccessPolicy(
id=row[0],
project_id=row[1],
name=row[2],
description=row[3],
allowed_users=row[4],
allowed_roles=row[5],
allowed_ips=row[6],
time_restrictions=row[7],
max_access_count=row[8],
require_approval=bool(row[9]),
is_active=bool(row[10]),
created_at=row[11],
updated_at=row[12],
id = row[0],
project_id = row[1],
name = row[2],
description = row[3],
allowed_users = row[4],
allowed_roles = row[5],
allowed_ips = row[6],
time_restrictions = row[7],
max_access_count = row[8],
require_approval = bool(row[9]),
is_active = bool(row[10]),
created_at = row[11],
updated_at = row[12],
)
)
@@ -1028,7 +1028,7 @@ class SecurityManager:
cursor = conn.cursor()
cursor.execute(
"SELECT * FROM data_access_policies WHERE id = ? AND is_active = 1", (policy_id,)
"SELECT * FROM data_access_policies WHERE id = ? AND is_active = 1", (policy_id, )
)
row = cursor.fetchone()
conn.close()
@@ -1037,19 +1037,19 @@ class SecurityManager:
return False, "Policy not found or inactive"
policy = DataAccessPolicy(
id=row[0],
project_id=row[1],
name=row[2],
description=row[3],
allowed_users=row[4],
allowed_roles=row[5],
allowed_ips=row[6],
time_restrictions=row[7],
max_access_count=row[8],
require_approval=bool(row[9]),
is_active=bool(row[10]),
created_at=row[11],
updated_at=row[12],
id = row[0],
project_id = row[1],
name = row[2],
description = row[3],
allowed_users = row[4],
allowed_roles = row[5],
allowed_ips = row[6],
time_restrictions = row[7],
max_access_count = row[8],
require_approval = bool(row[9]),
is_active = bool(row[10]),
created_at = row[11],
updated_at = row[12],
)
# 检查用户白名单
@@ -1113,7 +1113,7 @@ class SecurityManager:
try:
if "/" in pattern:
# CIDR 表示法
network = ipaddress.ip_network(pattern, strict=False)
network = ipaddress.ip_network(pattern, strict = False)
return ipaddress.ip_address(ip) in network
else:
# 精确匹配
@@ -1130,11 +1130,11 @@ class SecurityManager:
) -> AccessRequest:
"""创建访问请求"""
request = AccessRequest(
id=self._generate_id(),
policy_id=policy_id,
user_id=user_id,
request_reason=request_reason,
expires_at=(datetime.now() + timedelta(hours=expires_hours)).isoformat(),
id = self._generate_id(),
policy_id = policy_id,
user_id = user_id,
request_reason = request_reason,
expires_at = (datetime.now() + timedelta(hours = expires_hours)).isoformat(),
)
conn = sqlite3.connect(self.db_path)
@@ -1169,7 +1169,7 @@ class SecurityManager:
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
expires_at = (datetime.now() + timedelta(hours=expires_hours)).isoformat()
expires_at = (datetime.now() + timedelta(hours = expires_hours)).isoformat()
approved_at = datetime.now().isoformat()
cursor.execute(
@@ -1184,7 +1184,7 @@ class SecurityManager:
conn.commit()
# 获取更新后的请求
cursor.execute("SELECT * FROM access_requests WHERE id = ?", (request_id,))
cursor.execute("SELECT * FROM access_requests WHERE id = ?", (request_id, ))
row = cursor.fetchone()
conn.close()
@@ -1192,15 +1192,15 @@ class SecurityManager:
return None
return AccessRequest(
id=row[0],
policy_id=row[1],
user_id=row[2],
request_reason=row[3],
status=row[4],
approved_by=row[5],
approved_at=row[6],
expires_at=row[7],
created_at=row[8],
id = row[0],
policy_id = row[1],
user_id = row[2],
request_reason = row[3],
status = row[4],
approved_by = row[5],
approved_at = row[6],
expires_at = row[7],
created_at = row[8],
)
def reject_access_request(self, request_id: str, rejected_by: str) -> AccessRequest | None:
@@ -1219,7 +1219,7 @@ class SecurityManager:
conn.commit()
cursor.execute("SELECT * FROM access_requests WHERE id = ?", (request_id,))
cursor.execute("SELECT * FROM access_requests WHERE id = ?", (request_id, ))
row = cursor.fetchone()
conn.close()
@@ -1227,15 +1227,15 @@ class SecurityManager:
return None
return AccessRequest(
id=row[0],
policy_id=row[1],
user_id=row[2],
request_reason=row[3],
status=row[4],
approved_by=row[5],
approved_at=row[6],
expires_at=row[7],
created_at=row[8],
id = row[0],
policy_id = row[1],
user_id = row[2],
request_reason = row[3],
status = row[4],
approved_by = row[5],
approved_at = row[6],
expires_at = row[7],
created_at = row[8],
)

View File

@@ -313,7 +313,7 @@ class SubscriptionManager:
"export": {"unit": "page", "price": 0.1, "free_quota": 100}, # 0.1元/页PDF导出
}
def __init__(self, db_path: str = "insightflow.db"):
def __init__(self, db_path: str = "insightflow.db") -> None:
self.db_path = db_path
self._init_db()
self._init_default_plans()
@@ -572,7 +572,7 @@ class SubscriptionManager:
conn = self._get_connection()
try:
cursor = conn.cursor()
cursor.execute("SELECT * FROM subscription_plans WHERE id = ?", (plan_id,))
cursor.execute("SELECT * FROM subscription_plans WHERE id = ?", (plan_id, ))
row = cursor.fetchone()
if row:
@@ -588,7 +588,7 @@ class SubscriptionManager:
try:
cursor = conn.cursor()
cursor.execute(
"SELECT * FROM subscription_plans WHERE tier = ? AND is_active = 1", (tier,)
"SELECT * FROM subscription_plans WHERE tier = ? AND is_active = 1", (tier, )
)
row = cursor.fetchone()
@@ -635,19 +635,19 @@ class SubscriptionManager:
plan_id = str(uuid.uuid4())
plan = SubscriptionPlan(
id=plan_id,
name=name,
tier=tier,
description=description,
price_monthly=price_monthly,
price_yearly=price_yearly,
currency=currency,
features=features or [],
limits=limits or {},
is_active=True,
created_at=datetime.now(),
updated_at=datetime.now(),
metadata={},
id = plan_id,
name = name,
tier = tier,
description = description,
price_monthly = price_monthly,
price_yearly = price_yearly,
currency = currency,
features = features or [],
limits = limits or {},
is_active = True,
created_at = datetime.now(),
updated_at = datetime.now(),
metadata = {},
)
cursor = conn.cursor()
@@ -760,7 +760,7 @@ class SubscriptionManager:
SELECT * FROM subscriptions
WHERE tenant_id = ? AND status IN ('active', 'trial', 'pending')
""",
(tenant_id,),
(tenant_id, ),
)
existing = cursor.fetchone()
@@ -777,36 +777,36 @@ class SubscriptionManager:
# 计算周期
if billing_cycle == "yearly":
period_end = now + timedelta(days=365)
period_end = now + timedelta(days = 365)
else:
period_end = now + timedelta(days=30)
period_end = now + timedelta(days = 30)
# 试用处理
trial_start = None
trial_end = None
if trial_days > 0:
trial_start = now
trial_end = now + timedelta(days=trial_days)
trial_end = now + timedelta(days = trial_days)
status = SubscriptionStatus.TRIAL.value
else:
status = SubscriptionStatus.PENDING.value
subscription = Subscription(
id=subscription_id,
tenant_id=tenant_id,
plan_id=plan_id,
status=status,
current_period_start=now,
current_period_end=period_end,
cancel_at_period_end=False,
canceled_at=None,
trial_start=trial_start,
trial_end=trial_end,
payment_provider=payment_provider,
provider_subscription_id=None,
created_at=now,
updated_at=now,
metadata={"billing_cycle": billing_cycle},
id = subscription_id,
tenant_id = tenant_id,
plan_id = plan_id,
status = status,
current_period_start = now,
current_period_end = period_end,
cancel_at_period_end = False,
canceled_at = None,
trial_start = trial_start,
trial_end = trial_end,
payment_provider = payment_provider,
provider_subscription_id = None,
created_at = now,
updated_at = now,
metadata = {"billing_cycle": billing_cycle},
)
cursor.execute(
@@ -878,7 +878,7 @@ class SubscriptionManager:
conn = self._get_connection()
try:
cursor = conn.cursor()
cursor.execute("SELECT * FROM subscriptions WHERE id = ?", (subscription_id,))
cursor.execute("SELECT * FROM subscriptions WHERE id = ?", (subscription_id, ))
row = cursor.fetchone()
if row:
@@ -899,7 +899,7 @@ class SubscriptionManager:
WHERE tenant_id = ? AND status IN ('active', 'trial', 'past_due', 'pending')
ORDER BY created_at DESC LIMIT 1
""",
(tenant_id,),
(tenant_id, ),
)
row = cursor.fetchone()
@@ -1087,15 +1087,15 @@ class SubscriptionManager:
record_id = str(uuid.uuid4())
record = UsageRecord(
id=record_id,
tenant_id=tenant_id,
resource_type=resource_type,
quantity=quantity,
unit=unit,
recorded_at=datetime.now(),
cost=cost,
description=description,
metadata=metadata or {},
id = record_id,
tenant_id = tenant_id,
resource_type = resource_type,
quantity = quantity,
unit = unit,
recorded_at = datetime.now(),
cost = cost,
description = description,
metadata = metadata or {},
)
cursor = conn.cursor()
@@ -1214,22 +1214,22 @@ class SubscriptionManager:
now = datetime.now()
payment = Payment(
id=payment_id,
tenant_id=tenant_id,
subscription_id=subscription_id,
invoice_id=invoice_id,
amount=amount,
currency=currency,
provider=provider,
provider_payment_id=None,
status=PaymentStatus.PENDING.value,
payment_method=payment_method,
payment_details=payment_details or {},
paid_at=None,
failed_at=None,
failure_reason=None,
created_at=now,
updated_at=now,
id = payment_id,
tenant_id = tenant_id,
subscription_id = subscription_id,
invoice_id = invoice_id,
amount = amount,
currency = currency,
provider = provider,
provider_payment_id = None,
status = PaymentStatus.PENDING.value,
payment_method = payment_method,
payment_details = payment_details or {},
paid_at = None,
failed_at = None,
failure_reason = None,
created_at = now,
updated_at = now,
)
cursor = conn.cursor()
@@ -1389,7 +1389,7 @@ class SubscriptionManager:
def _get_payment_internal(self, conn: sqlite3.Connection, payment_id: str) -> Payment | None:
"""内部方法:获取支付记录"""
cursor = conn.cursor()
cursor.execute("SELECT * FROM payments WHERE id = ?", (payment_id,))
cursor.execute("SELECT * FROM payments WHERE id = ?", (payment_id, ))
row = cursor.fetchone()
if row:
@@ -1414,27 +1414,27 @@ class SubscriptionManager:
invoice_id = str(uuid.uuid4())
invoice_number = self._generate_invoice_number()
now = datetime.now()
due_date = now + timedelta(days=7) # 7天付款期限
due_date = now + timedelta(days = 7) # 7天付款期限
invoice = Invoice(
id=invoice_id,
tenant_id=tenant_id,
subscription_id=subscription_id,
invoice_number=invoice_number,
status=InvoiceStatus.DRAFT.value,
amount_due=amount,
amount_paid=0,
currency=currency,
period_start=period_start,
period_end=period_end,
description=description,
line_items=line_items or [{"description": description, "amount": amount}],
due_date=due_date,
paid_at=None,
voided_at=None,
void_reason=None,
created_at=now,
updated_at=now,
id = invoice_id,
tenant_id = tenant_id,
subscription_id = subscription_id,
invoice_number = invoice_number,
status = InvoiceStatus.DRAFT.value,
amount_due = amount,
amount_paid = 0,
currency = currency,
period_start = period_start,
period_end = period_end,
description = description,
line_items = line_items or [{"description": description, "amount": amount}],
due_date = due_date,
paid_at = None,
voided_at = None,
void_reason = None,
created_at = now,
updated_at = now,
)
cursor = conn.cursor()
@@ -1475,7 +1475,7 @@ class SubscriptionManager:
conn = self._get_connection()
try:
cursor = conn.cursor()
cursor.execute("SELECT * FROM invoices WHERE id = ?", (invoice_id,))
cursor.execute("SELECT * FROM invoices WHERE id = ?", (invoice_id, ))
row = cursor.fetchone()
if row:
@@ -1490,7 +1490,7 @@ class SubscriptionManager:
conn = self._get_connection()
try:
cursor = conn.cursor()
cursor.execute("SELECT * FROM invoices WHERE invoice_number = ?", (invoice_number,))
cursor.execute("SELECT * FROM invoices WHERE invoice_number = ?", (invoice_number, ))
row = cursor.fetchone()
if row:
@@ -1568,7 +1568,7 @@ class SubscriptionManager:
SELECT COUNT(*) as count FROM invoices
WHERE invoice_number LIKE ?
""",
(f"{prefix}%",),
(f"{prefix}%", ),
)
row = cursor.fetchone()
count = row["count"] + 1
@@ -1604,23 +1604,23 @@ class SubscriptionManager:
now = datetime.now()
refund = Refund(
id=refund_id,
tenant_id=tenant_id,
payment_id=payment_id,
invoice_id=payment.invoice_id,
amount=amount,
currency=payment.currency,
reason=reason,
status=RefundStatus.PENDING.value,
requested_by=requested_by,
requested_at=now,
approved_by=None,
approved_at=None,
completed_at=None,
provider_refund_id=None,
metadata={},
created_at=now,
updated_at=now,
id = refund_id,
tenant_id = tenant_id,
payment_id = payment_id,
invoice_id = payment.invoice_id,
amount = amount,
currency = payment.currency,
reason = reason,
status = RefundStatus.PENDING.value,
requested_by = requested_by,
requested_at = now,
approved_by = None,
approved_at = None,
completed_at = None,
provider_refund_id = None,
metadata = {},
created_at = now,
updated_at = now,
)
cursor = conn.cursor()
@@ -1803,7 +1803,7 @@ class SubscriptionManager:
def _get_refund_internal(self, conn: sqlite3.Connection, refund_id: str) -> Refund | None:
"""内部方法:获取退款记录"""
cursor = conn.cursor()
cursor.execute("SELECT * FROM refunds WHERE id = ?", (refund_id,))
cursor.execute("SELECT * FROM refunds WHERE id = ?", (refund_id, ))
row = cursor.fetchone()
if row:
@@ -1822,7 +1822,7 @@ class SubscriptionManager:
description: str,
reference_id: str,
balance_after: float,
):
) -> None:
"""内部方法:添加账单历史"""
history_id = str(uuid.uuid4())
@@ -1962,126 +1962,126 @@ class SubscriptionManager:
def _row_to_plan(self, row: sqlite3.Row) -> SubscriptionPlan:
"""数据库行转换为 SubscriptionPlan 对象"""
return SubscriptionPlan(
id=row["id"],
name=row["name"],
tier=row["tier"],
description=row["description"] or "",
price_monthly=row["price_monthly"],
price_yearly=row["price_yearly"],
currency=row["currency"],
features=json.loads(row["features"] or "[]"),
limits=json.loads(row["limits"] or "{}"),
is_active=bool(row["is_active"]),
created_at=(
id = row["id"],
name = row["name"],
tier = row["tier"],
description = row["description"] or "",
price_monthly = row["price_monthly"],
price_yearly = row["price_yearly"],
currency = row["currency"],
features = json.loads(row["features"] or "[]"),
limits = json.loads(row["limits"] or "{}"),
is_active = bool(row["is_active"]),
created_at = (
datetime.fromisoformat(row["created_at"])
if isinstance(row["created_at"], str)
else row["created_at"]
),
updated_at=(
updated_at = (
datetime.fromisoformat(row["updated_at"])
if isinstance(row["updated_at"], str)
else row["updated_at"]
),
metadata=json.loads(row["metadata"] or "{}"),
metadata = json.loads(row["metadata"] or "{}"),
)
def _row_to_subscription(self, row: sqlite3.Row) -> Subscription:
"""数据库行转换为 Subscription 对象"""
return Subscription(
id=row["id"],
tenant_id=row["tenant_id"],
plan_id=row["plan_id"],
status=row["status"],
current_period_start=(
id = row["id"],
tenant_id = row["tenant_id"],
plan_id = row["plan_id"],
status = row["status"],
current_period_start = (
datetime.fromisoformat(row["current_period_start"])
if row["current_period_start"] and isinstance(row["current_period_start"], str)
else row["current_period_start"]
),
current_period_end=(
current_period_end = (
datetime.fromisoformat(row["current_period_end"])
if row["current_period_end"] and isinstance(row["current_period_end"], str)
else row["current_period_end"]
),
cancel_at_period_end=bool(row["cancel_at_period_end"]),
canceled_at=(
cancel_at_period_end = bool(row["cancel_at_period_end"]),
canceled_at = (
datetime.fromisoformat(row["canceled_at"])
if row["canceled_at"] and isinstance(row["canceled_at"], str)
else row["canceled_at"]
),
trial_start=(
trial_start = (
datetime.fromisoformat(row["trial_start"])
if row["trial_start"] and isinstance(row["trial_start"], str)
else row["trial_start"]
),
trial_end=(
trial_end = (
datetime.fromisoformat(row["trial_end"])
if row["trial_end"] and isinstance(row["trial_end"], str)
else row["trial_end"]
),
payment_provider=row["payment_provider"],
provider_subscription_id=row["provider_subscription_id"],
created_at=(
payment_provider = row["payment_provider"],
provider_subscription_id = row["provider_subscription_id"],
created_at = (
datetime.fromisoformat(row["created_at"])
if isinstance(row["created_at"], str)
else row["created_at"]
),
updated_at=(
updated_at = (
datetime.fromisoformat(row["updated_at"])
if isinstance(row["updated_at"], str)
else row["updated_at"]
),
metadata=json.loads(row["metadata"] or "{}"),
metadata = json.loads(row["metadata"] or "{}"),
)
def _row_to_usage(self, row: sqlite3.Row) -> UsageRecord:
"""数据库行转换为 UsageRecord 对象"""
return UsageRecord(
id=row["id"],
tenant_id=row["tenant_id"],
resource_type=row["resource_type"],
quantity=row["quantity"],
unit=row["unit"],
recorded_at=(
id = row["id"],
tenant_id = row["tenant_id"],
resource_type = row["resource_type"],
quantity = row["quantity"],
unit = row["unit"],
recorded_at = (
datetime.fromisoformat(row["recorded_at"])
if isinstance(row["recorded_at"], str)
else row["recorded_at"]
),
cost=row["cost"],
description=row["description"],
metadata=json.loads(row["metadata"] or "{}"),
cost = row["cost"],
description = row["description"],
metadata = json.loads(row["metadata"] or "{}"),
)
def _row_to_payment(self, row: sqlite3.Row) -> Payment:
"""数据库行转换为 Payment 对象"""
return Payment(
id=row["id"],
tenant_id=row["tenant_id"],
subscription_id=row["subscription_id"],
invoice_id=row["invoice_id"],
amount=row["amount"],
currency=row["currency"],
provider=row["provider"],
provider_payment_id=row["provider_payment_id"],
status=row["status"],
payment_method=row["payment_method"],
payment_details=json.loads(row["payment_details"] or "{}"),
paid_at=(
id = row["id"],
tenant_id = row["tenant_id"],
subscription_id = row["subscription_id"],
invoice_id = row["invoice_id"],
amount = row["amount"],
currency = row["currency"],
provider = row["provider"],
provider_payment_id = row["provider_payment_id"],
status = row["status"],
payment_method = row["payment_method"],
payment_details = json.loads(row["payment_details"] or "{}"),
paid_at = (
datetime.fromisoformat(row["paid_at"])
if row["paid_at"] and isinstance(row["paid_at"], str)
else row["paid_at"]
),
failed_at=(
failed_at = (
datetime.fromisoformat(row["failed_at"])
if row["failed_at"] and isinstance(row["failed_at"], str)
else row["failed_at"]
),
failure_reason=row["failure_reason"],
created_at=(
failure_reason = row["failure_reason"],
created_at = (
datetime.fromisoformat(row["created_at"])
if isinstance(row["created_at"], str)
else row["created_at"]
),
updated_at=(
updated_at = (
datetime.fromisoformat(row["updated_at"])
if isinstance(row["updated_at"], str)
else row["updated_at"]
@@ -2091,48 +2091,48 @@ class SubscriptionManager:
def _row_to_invoice(self, row: sqlite3.Row) -> Invoice:
"""数据库行转换为 Invoice 对象"""
return Invoice(
id=row["id"],
tenant_id=row["tenant_id"],
subscription_id=row["subscription_id"],
invoice_number=row["invoice_number"],
status=row["status"],
amount_due=row["amount_due"],
amount_paid=row["amount_paid"],
currency=row["currency"],
period_start=(
id = row["id"],
tenant_id = row["tenant_id"],
subscription_id = row["subscription_id"],
invoice_number = row["invoice_number"],
status = row["status"],
amount_due = row["amount_due"],
amount_paid = row["amount_paid"],
currency = row["currency"],
period_start = (
datetime.fromisoformat(row["period_start"])
if row["period_start"] and isinstance(row["period_start"], str)
else row["period_start"]
),
period_end=(
period_end = (
datetime.fromisoformat(row["period_end"])
if row["period_end"] and isinstance(row["period_end"], str)
else row["period_end"]
),
description=row["description"],
line_items=json.loads(row["line_items"] or "[]"),
due_date=(
description = row["description"],
line_items = json.loads(row["line_items"] or "[]"),
due_date = (
datetime.fromisoformat(row["due_date"])
if row["due_date"] and isinstance(row["due_date"], str)
else row["due_date"]
),
paid_at=(
paid_at = (
datetime.fromisoformat(row["paid_at"])
if row["paid_at"] and isinstance(row["paid_at"], str)
else row["paid_at"]
),
voided_at=(
voided_at = (
datetime.fromisoformat(row["voided_at"])
if row["voided_at"] and isinstance(row["voided_at"], str)
else row["voided_at"]
),
void_reason=row["void_reason"],
created_at=(
void_reason = row["void_reason"],
created_at = (
datetime.fromisoformat(row["created_at"])
if isinstance(row["created_at"], str)
else row["created_at"]
),
updated_at=(
updated_at = (
datetime.fromisoformat(row["updated_at"])
if isinstance(row["updated_at"], str)
else row["updated_at"]
@@ -2142,39 +2142,39 @@ class SubscriptionManager:
def _row_to_refund(self, row: sqlite3.Row) -> Refund:
"""数据库行转换为 Refund 对象"""
return Refund(
id=row["id"],
tenant_id=row["tenant_id"],
payment_id=row["payment_id"],
invoice_id=row["invoice_id"],
amount=row["amount"],
currency=row["currency"],
reason=row["reason"],
status=row["status"],
requested_by=row["requested_by"],
requested_at=(
id = row["id"],
tenant_id = row["tenant_id"],
payment_id = row["payment_id"],
invoice_id = row["invoice_id"],
amount = row["amount"],
currency = row["currency"],
reason = row["reason"],
status = row["status"],
requested_by = row["requested_by"],
requested_at = (
datetime.fromisoformat(row["requested_at"])
if isinstance(row["requested_at"], str)
else row["requested_at"]
),
approved_by=row["approved_by"],
approved_at=(
approved_by = row["approved_by"],
approved_at = (
datetime.fromisoformat(row["approved_at"])
if row["approved_at"] and isinstance(row["approved_at"], str)
else row["approved_at"]
),
completed_at=(
completed_at = (
datetime.fromisoformat(row["completed_at"])
if row["completed_at"] and isinstance(row["completed_at"], str)
else row["completed_at"]
),
provider_refund_id=row["provider_refund_id"],
metadata=json.loads(row["metadata"] or "{}"),
created_at=(
provider_refund_id = row["provider_refund_id"],
metadata = json.loads(row["metadata"] or "{}"),
created_at = (
datetime.fromisoformat(row["created_at"])
if isinstance(row["created_at"], str)
else row["created_at"]
),
updated_at=(
updated_at = (
datetime.fromisoformat(row["updated_at"])
if isinstance(row["updated_at"], str)
else row["updated_at"]
@@ -2184,20 +2184,20 @@ class SubscriptionManager:
def _row_to_billing_history(self, row: sqlite3.Row) -> BillingHistory:
"""数据库行转换为 BillingHistory 对象"""
return BillingHistory(
id=row["id"],
tenant_id=row["tenant_id"],
type=row["type"],
amount=row["amount"],
currency=row["currency"],
description=row["description"],
reference_id=row["reference_id"],
balance_after=row["balance_after"],
created_at=(
id = row["id"],
tenant_id = row["tenant_id"],
type = row["type"],
amount = row["amount"],
currency = row["currency"],
description = row["description"],
reference_id = row["reference_id"],
balance_after = row["balance_after"],
created_at = (
datetime.fromisoformat(row["created_at"])
if isinstance(row["created_at"], str)
else row["created_at"]
),
metadata=json.loads(row["metadata"] or "{}"),
metadata = json.loads(row["metadata"] or "{}"),
)

View File

@@ -257,7 +257,7 @@ class TenantManager:
"export:basic": "基础导出",
}
def __init__(self, db_path: str = "insightflow.db"):
def __init__(self, db_path: str = "insightflow.db") -> None:
self.db_path = db_path
self._init_db()
@@ -437,19 +437,19 @@ class TenantManager:
)
tenant = Tenant(
id=tenant_id,
name=name,
slug=slug,
description=description,
tier=tier,
status=TenantStatus.PENDING.value,
owner_id=owner_id,
created_at=datetime.now(),
updated_at=datetime.now(),
expires_at=None,
settings=settings or {},
resource_limits=resource_limits,
metadata={},
id = tenant_id,
name = name,
slug = slug,
description = description,
tier = tier,
status = TenantStatus.PENDING.value,
owner_id = owner_id,
created_at = datetime.now(),
updated_at = datetime.now(),
expires_at = None,
settings = settings or {},
resource_limits = resource_limits,
metadata = {},
)
cursor = conn.cursor()
@@ -495,7 +495,7 @@ class TenantManager:
conn = self._get_connection()
try:
cursor = conn.cursor()
cursor.execute("SELECT * FROM tenants WHERE id = ?", (tenant_id,))
cursor.execute("SELECT * FROM tenants WHERE id = ?", (tenant_id, ))
row = cursor.fetchone()
if row:
@@ -510,7 +510,7 @@ class TenantManager:
conn = self._get_connection()
try:
cursor = conn.cursor()
cursor.execute("SELECT * FROM tenants WHERE slug = ?", (slug,))
cursor.execute("SELECT * FROM tenants WHERE slug = ?", (slug, ))
row = cursor.fetchone()
if row:
@@ -531,7 +531,7 @@ class TenantManager:
JOIN tenant_domains d ON t.id = d.tenant_id
WHERE d.domain = ? AND d.status = 'verified'
""",
(domain,),
(domain, ),
)
row = cursor.fetchone()
@@ -605,7 +605,7 @@ class TenantManager:
conn = self._get_connection()
try:
cursor = conn.cursor()
cursor.execute("DELETE FROM tenants WHERE id = ?", (tenant_id,))
cursor.execute("DELETE FROM tenants WHERE id = ?", (tenant_id, ))
conn.commit()
return cursor.rowcount > 0
finally:
@@ -619,7 +619,7 @@ class TenantManager:
try:
cursor = conn.cursor()
query = "SELECT * FROM tenants WHERE 1=1"
query = "SELECT * FROM tenants WHERE 1 = 1"
params = []
if status:
@@ -661,18 +661,18 @@ class TenantManager:
domain_id = str(uuid.uuid4())
tenant_domain = TenantDomain(
id=domain_id,
tenant_id=tenant_id,
domain=domain.lower(),
status=DomainStatus.PENDING.value,
verification_token=verification_token,
verification_method=verification_method,
verified_at=None,
created_at=datetime.now(),
updated_at=datetime.now(),
is_primary=is_primary,
ssl_enabled=False,
ssl_expires_at=None,
id = domain_id,
tenant_id = tenant_id,
domain = domain.lower(),
status = DomainStatus.PENDING.value,
verification_token = verification_token,
verification_method = verification_method,
verified_at = None,
created_at = datetime.now(),
updated_at = datetime.now(),
is_primary = is_primary,
ssl_enabled = False,
ssl_expires_at = None,
)
cursor = conn.cursor()
@@ -684,7 +684,7 @@ class TenantManager:
UPDATE tenant_domains SET is_primary = 0
WHERE tenant_id = ?
""",
(tenant_id,),
(tenant_id, ),
)
cursor.execute(
@@ -782,7 +782,7 @@ class TenantManager:
conn = self._get_connection()
try:
cursor = conn.cursor()
cursor.execute("SELECT * FROM tenant_domains WHERE id = ?", (domain_id,))
cursor.execute("SELECT * FROM tenant_domains WHERE id = ?", (domain_id, ))
row = cursor.fetchone()
if not row:
@@ -797,7 +797,7 @@ class TenantManager:
"dns_record": {
"type": "TXT",
"name": "_insightflow",
"value": f"insightflow-verify={token}",
"value": f"insightflow-verify = {token}",
"ttl": 3600,
},
"file_verification": {
@@ -805,7 +805,7 @@ class TenantManager:
"content": token,
},
"instructions": [
f"DNS 验证: 添加 TXT 记录 _insightflow.{domain},值为 insightflow-verify={token}",
f"DNS 验证: 添加 TXT 记录 _insightflow.{domain},值为 insightflow-verify = {token}",
f"文件验证: 在网站根目录创建 .well-known/insightflow-verify.txt内容为 {token}",
],
}
@@ -841,7 +841,7 @@ class TenantManager:
WHERE tenant_id = ?
ORDER BY is_primary DESC, created_at DESC
""",
(tenant_id,),
(tenant_id, ),
)
rows = cursor.fetchall()
@@ -857,7 +857,7 @@ class TenantManager:
conn = self._get_connection()
try:
cursor = conn.cursor()
cursor.execute("SELECT * FROM tenant_branding WHERE tenant_id = ?", (tenant_id,))
cursor.execute("SELECT * FROM tenant_branding WHERE tenant_id = ?", (tenant_id, ))
row = cursor.fetchone()
if row:
@@ -885,7 +885,7 @@ class TenantManager:
cursor = conn.cursor()
# 检查是否已存在
cursor.execute("SELECT id FROM tenant_branding WHERE tenant_id = ?", (tenant_id,))
cursor.execute("SELECT id FROM tenant_branding WHERE tenant_id = ?", (tenant_id, ))
existing = cursor.fetchone()
if existing:
@@ -1022,17 +1022,17 @@ class TenantManager:
final_permissions = permissions or default_permissions
member = TenantMember(
id=member_id,
tenant_id=tenant_id,
user_id="pending", # 临时值,待用户接受邀请后更新
email=email,
role=role,
permissions=final_permissions,
invited_by=invited_by,
invited_at=datetime.now(),
joined_at=None,
last_active_at=None,
status="pending",
id = member_id,
tenant_id = tenant_id,
user_id = "pending", # 临时值,待用户接受邀请后更新
email = email,
role = role,
permissions = final_permissions,
invited_by = invited_by,
invited_at = datetime.now(),
joined_at = None,
last_active_at = None,
status = "pending",
)
cursor = conn.cursor()
@@ -1197,7 +1197,7 @@ class TenantManager:
WHERE m.user_id = ? AND m.status = 'active'
ORDER BY t.created_at DESC
""",
(user_id,),
(user_id, ),
)
rows = cursor.fetchall()
@@ -1227,7 +1227,7 @@ class TenantManager:
projects_count: int = 0,
entities_count: int = 0,
members_count: int = 0,
):
) -> None:
"""记录资源使用"""
conn = self._get_connection()
try:
@@ -1388,7 +1388,7 @@ class TenantManager:
counter = 1
while True:
cursor.execute("SELECT id FROM tenants WHERE slug = ?", (slug,))
cursor.execute("SELECT id FROM tenants WHERE slug = ?", (slug, ))
if not cursor.fetchone():
break
slug = f"{base_slug}-{counter}"
@@ -1406,7 +1406,7 @@ class TenantManager:
def _validate_domain(self, domain: str) -> bool:
"""验证域名格式"""
pattern = r"^(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)*[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])$"
pattern = r"^(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0, 61}[a-zA-Z0-9])?\.)*[a-zA-Z0-9](?:[a-zA-Z0-9-]{0, 61}[a-zA-Z0-9])$"
return bool(re.match(pattern, domain))
def _check_domain_verification(self, domain: str, token: str, method: str) -> bool:
@@ -1431,7 +1431,7 @@ class TenantManager:
# TODO: 实现 HTTP 文件验证
# import requests
# try:
# response = requests.get(f"http://{domain}/.well-known/insightflow-verify.txt", timeout=10)
# response = requests.get(f"http://{domain}/.well-known/insightflow-verify.txt", timeout = 10)
# if response.status_code == 200 and token in response.text:
# return True
# except (ImportError, Exception):
@@ -1467,7 +1467,7 @@ class TenantManager:
email: str,
role: TenantRole,
invited_by: str | None,
):
) -> None:
"""内部方法:添加成员"""
cursor = conn.cursor()
member_id = str(uuid.uuid4())
@@ -1497,60 +1497,60 @@ class TenantManager:
def _row_to_tenant(self, row: sqlite3.Row) -> Tenant:
"""数据库行转换为 Tenant 对象"""
return Tenant(
id=row["id"],
name=row["name"],
slug=row["slug"],
description=row["description"],
tier=row["tier"],
status=row["status"],
owner_id=row["owner_id"],
created_at=(
id = row["id"],
name = row["name"],
slug = row["slug"],
description = row["description"],
tier = row["tier"],
status = row["status"],
owner_id = row["owner_id"],
created_at = (
datetime.fromisoformat(row["created_at"])
if isinstance(row["created_at"], str)
else row["created_at"]
),
updated_at=(
updated_at = (
datetime.fromisoformat(row["updated_at"])
if isinstance(row["updated_at"], str)
else row["updated_at"]
),
expires_at=(
expires_at = (
datetime.fromisoformat(row["expires_at"])
if row["expires_at"] and isinstance(row["expires_at"], str)
else row["expires_at"]
),
settings=json.loads(row["settings"] or "{}"),
resource_limits=json.loads(row["resource_limits"] or "{}"),
metadata=json.loads(row["metadata"] or "{}"),
settings = json.loads(row["settings"] or "{}"),
resource_limits = json.loads(row["resource_limits"] or "{}"),
metadata = json.loads(row["metadata"] or "{}"),
)
def _row_to_domain(self, row: sqlite3.Row) -> TenantDomain:
"""数据库行转换为 TenantDomain 对象"""
return TenantDomain(
id=row["id"],
tenant_id=row["tenant_id"],
domain=row["domain"],
status=row["status"],
verification_token=row["verification_token"],
verification_method=row["verification_method"],
verified_at=(
id = row["id"],
tenant_id = row["tenant_id"],
domain = row["domain"],
status = row["status"],
verification_token = row["verification_token"],
verification_method = row["verification_method"],
verified_at = (
datetime.fromisoformat(row["verified_at"])
if row["verified_at"] and isinstance(row["verified_at"], str)
else row["verified_at"]
),
created_at=(
created_at = (
datetime.fromisoformat(row["created_at"])
if isinstance(row["created_at"], str)
else row["created_at"]
),
updated_at=(
updated_at = (
datetime.fromisoformat(row["updated_at"])
if isinstance(row["updated_at"], str)
else row["updated_at"]
),
is_primary=bool(row["is_primary"]),
ssl_enabled=bool(row["ssl_enabled"]),
ssl_expires_at=(
is_primary = bool(row["is_primary"]),
ssl_enabled = bool(row["ssl_enabled"]),
ssl_expires_at = (
datetime.fromisoformat(row["ssl_expires_at"])
if row["ssl_expires_at"] and isinstance(row["ssl_expires_at"], str)
else row["ssl_expires_at"]
@@ -1560,22 +1560,22 @@ class TenantManager:
def _row_to_branding(self, row: sqlite3.Row) -> TenantBranding:
"""数据库行转换为 TenantBranding 对象"""
return TenantBranding(
id=row["id"],
tenant_id=row["tenant_id"],
logo_url=row["logo_url"],
favicon_url=row["favicon_url"],
primary_color=row["primary_color"],
secondary_color=row["secondary_color"],
custom_css=row["custom_css"],
custom_js=row["custom_js"],
login_page_bg=row["login_page_bg"],
email_template=row["email_template"],
created_at=(
id = row["id"],
tenant_id = row["tenant_id"],
logo_url = row["logo_url"],
favicon_url = row["favicon_url"],
primary_color = row["primary_color"],
secondary_color = row["secondary_color"],
custom_css = row["custom_css"],
custom_js = row["custom_js"],
login_page_bg = row["login_page_bg"],
email_template = row["email_template"],
created_at = (
datetime.fromisoformat(row["created_at"])
if isinstance(row["created_at"], str)
else row["created_at"]
),
updated_at=(
updated_at = (
datetime.fromisoformat(row["updated_at"])
if isinstance(row["updated_at"], str)
else row["updated_at"]
@@ -1585,29 +1585,29 @@ class TenantManager:
def _row_to_member(self, row: sqlite3.Row) -> TenantMember:
"""数据库行转换为 TenantMember 对象"""
return TenantMember(
id=row["id"],
tenant_id=row["tenant_id"],
user_id=row["user_id"],
email=row["email"],
role=row["role"],
permissions=json.loads(row["permissions"] or "[]"),
invited_by=row["invited_by"],
invited_at=(
id = row["id"],
tenant_id = row["tenant_id"],
user_id = row["user_id"],
email = row["email"],
role = row["role"],
permissions = json.loads(row["permissions"] or "[]"),
invited_by = row["invited_by"],
invited_at = (
datetime.fromisoformat(row["invited_at"])
if isinstance(row["invited_at"], str)
else row["invited_at"]
),
joined_at=(
joined_at = (
datetime.fromisoformat(row["joined_at"])
if row["joined_at"] and isinstance(row["joined_at"], str)
else row["joined_at"]
),
last_active_at=(
last_active_at = (
datetime.fromisoformat(row["last_active_at"])
if row["last_active_at"] and isinstance(row["last_active_at"], str)
else row["last_active_at"]
),
status=row["status"],
status = row["status"],
)

View File

@@ -10,9 +10,9 @@ import sys
# 添加 backend 目录到路径
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
print("=" * 60)
print(" = " * 60)
print("InsightFlow 多模态模块测试")
print("=" * 60)
print(" = " * 60)
# 测试导入
print("\n1. 测试模块导入...")
@@ -147,6 +147,6 @@ try:
except Exception as e:
print(f" ✗ 数据库多模态方法测试失败: {e}")
print("\n" + "=" * 60)
print("\n" + " = " * 60)
print("测试完成")
print("=" * 60)
print(" = " * 60)

View File

@@ -21,27 +21,27 @@ from search_manager import (
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
def test_fulltext_search():
def test_fulltext_search() -> None:
"""测试全文搜索"""
print("\n" + "=" * 60)
print("\n" + " = " * 60)
print("测试全文搜索 (FullTextSearch)")
print("=" * 60)
print(" = " * 60)
search = FullTextSearch()
# 测试索引创建
print("\n1. 测试索引创建...")
success = search.index_content(
content_id="test_entity_1",
content_type="entity",
project_id="test_project",
text="这是一个测试实体,用于验证全文搜索功能。支持关键词高亮显示。",
content_id = "test_entity_1",
content_type = "entity",
project_id = "test_project",
text = "这是一个测试实体,用于验证全文搜索功能。支持关键词高亮显示。",
)
print(f" 索引创建: {'✓ 成功' if success else '✗ 失败'}")
# 测试搜索
print("\n2. 测试关键词搜索...")
results = search.search("测试", project_id="test_project")
results = search.search("测试", project_id = "test_project")
print(f" 搜索结果数量: {len(results)}")
if results:
print(f" 第一个结果: {results[0].content[:50]}...")
@@ -49,10 +49,10 @@ def test_fulltext_search():
# 测试布尔搜索
print("\n3. 测试布尔搜索...")
results = search.search("测试 AND 全文", project_id="test_project")
results = search.search("测试 AND 全文", project_id = "test_project")
print(f" AND 搜索结果: {len(results)}")
results = search.search("测试 OR 关键词", project_id="test_project")
results = search.search("测试 OR 关键词", project_id = "test_project")
print(f" OR 搜索结果: {len(results)}")
# 测试高亮
@@ -64,11 +64,11 @@ def test_fulltext_search():
return True
def test_semantic_search():
def test_semantic_search() -> None:
"""测试语义搜索"""
print("\n" + "=" * 60)
print("\n" + " = " * 60)
print("测试语义搜索 (SemanticSearch)")
print("=" * 60)
print(" = " * 60)
semantic = SemanticSearch()
@@ -89,10 +89,10 @@ def test_semantic_search():
# 测试索引
print("\n3. 测试语义索引...")
success = semantic.index_embedding(
content_id="test_content_1",
content_type="transcript",
project_id="test_project",
text="这是用于语义搜索测试的文本内容。",
content_id = "test_content_1",
content_type = "transcript",
project_id = "test_project",
text = "这是用于语义搜索测试的文本内容。",
)
print(f" 索引创建: {'✓ 成功' if success else '✗ 失败'}")
@@ -100,11 +100,11 @@ def test_semantic_search():
return True
def test_entity_path_discovery():
def test_entity_path_discovery() -> None:
"""测试实体路径发现"""
print("\n" + "=" * 60)
print("\n" + " = " * 60)
print("测试实体路径发现 (EntityPathDiscovery)")
print("=" * 60)
print(" = " * 60)
discovery = EntityPathDiscovery()
@@ -119,11 +119,11 @@ def test_entity_path_discovery():
return True
def test_knowledge_gap_detection():
def test_knowledge_gap_detection() -> None:
"""测试知识缺口识别"""
print("\n" + "=" * 60)
print("\n" + " = " * 60)
print("测试知识缺口识别 (KnowledgeGapDetection)")
print("=" * 60)
print(" = " * 60)
detection = KnowledgeGapDetection()
@@ -138,11 +138,11 @@ def test_knowledge_gap_detection():
return True
def test_cache_manager():
def test_cache_manager() -> None:
"""测试缓存管理器"""
print("\n" + "=" * 60)
print("\n" + " = " * 60)
print("测试缓存管理器 (CacheManager)")
print("=" * 60)
print(" = " * 60)
cache = CacheManager()
@@ -150,7 +150,7 @@ def test_cache_manager():
print("\n2. 测试缓存操作...")
# 设置缓存
cache.set("test_key_1", {"name": "测试数据", "value": 123}, ttl=60)
cache.set("test_key_1", {"name": "测试数据", "value": 123}, ttl = 60)
print(" ✓ 设置缓存 test_key_1")
# 获取缓存
@@ -159,7 +159,7 @@ def test_cache_manager():
# 批量操作
cache.set_many(
{"batch_key_1": "value1", "batch_key_2": "value2", "batch_key_3": "value3"}, ttl=60
{"batch_key_1": "value1", "batch_key_2": "value2", "batch_key_3": "value3"}, ttl = 60
)
print(" ✓ 批量设置缓存")
@@ -186,11 +186,11 @@ def test_cache_manager():
return True
def test_task_queue():
def test_task_queue() -> None:
"""测试任务队列"""
print("\n" + "=" * 60)
print("\n" + " = " * 60)
print("测试任务队列 (TaskQueue)")
print("=" * 60)
print(" = " * 60)
queue = TaskQueue()
@@ -200,7 +200,7 @@ def test_task_queue():
print("\n2. 测试任务提交...")
# 定义测试任务处理器
def test_task_handler(payload):
def test_task_handler(payload) -> None:
print(f" 执行任务: {payload}")
return {"status": "success", "processed": True}
@@ -208,7 +208,7 @@ def test_task_queue():
# 提交任务
task_id = queue.submit(
task_type="test_task", payload={"test": "data", "timestamp": time.time()}
task_type = "test_task", payload = {"test": "data", "timestamp": time.time()}
)
print(" ✓ 提交任务: {task_id}")
@@ -227,11 +227,11 @@ def test_task_queue():
return True
def test_performance_monitor():
def test_performance_monitor() -> None:
"""测试性能监控"""
print("\n" + "=" * 60)
print("\n" + " = " * 60)
print("测试性能监控 (PerformanceMonitor)")
print("=" * 60)
print(" = " * 60)
monitor = PerformanceMonitor()
@@ -240,25 +240,25 @@ def test_performance_monitor():
# 记录一些测试指标
for i in range(5):
monitor.record_metric(
metric_type="api_response",
duration_ms=50 + i * 10,
endpoint="/api/v1/test",
metadata={"test": True},
metric_type = "api_response",
duration_ms = 50 + i * 10,
endpoint = "/api/v1/test",
metadata = {"test": True},
)
for i in range(3):
monitor.record_metric(
metric_type="db_query",
duration_ms=20 + i * 5,
endpoint="SELECT test",
metadata={"test": True},
metric_type = "db_query",
duration_ms = 20 + i * 5,
endpoint = "SELECT test",
metadata = {"test": True},
)
print(" ✓ 记录了 8 个测试指标")
# 获取统计
print("\n2. 获取性能统计...")
stats = monitor.get_stats(hours=1)
stats = monitor.get_stats(hours = 1)
print(f" 总请求数: {stats['overall']['total_requests']}")
print(f" 平均响应时间: {stats['overall']['avg_duration_ms']} ms")
print(f" 最大响应时间: {stats['overall']['max_duration_ms']} ms")
@@ -274,11 +274,11 @@ def test_performance_monitor():
return True
def test_search_manager():
def test_search_manager() -> None:
"""测试搜索管理器"""
print("\n" + "=" * 60)
print("\n" + " = " * 60)
print("测试搜索管理器 (SearchManager)")
print("=" * 60)
print(" = " * 60)
manager = get_search_manager()
@@ -295,11 +295,11 @@ def test_search_manager():
return True
def test_performance_manager():
def test_performance_manager() -> None:
"""测试性能管理器"""
print("\n" + "=" * 60)
print("\n" + " = " * 60)
print("测试性能管理器 (PerformanceManager)")
print("=" * 60)
print(" = " * 60)
manager = get_performance_manager()
@@ -320,12 +320,12 @@ def test_performance_manager():
return True
def run_all_tests():
def run_all_tests() -> None:
"""运行所有测试"""
print("\n" + "=" * 60)
print("\n" + " = " * 60)
print("InsightFlow Phase 7 Task 6 & 8 测试")
print("高级搜索与发现 + 性能优化与扩展")
print("=" * 60)
print(" = " * 60)
results = []
@@ -386,9 +386,9 @@ def run_all_tests():
results.append(("性能管理器", False))
# 打印测试汇总
print("\n" + "=" * 60)
print("\n" + " = " * 60)
print("测试汇总")
print("=" * 60)
print(" = " * 60)
passed = sum(1 for _, result in results if result)
total = len(results)

View File

@@ -18,18 +18,18 @@ from tenant_manager import get_tenant_manager
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
def test_tenant_management():
def test_tenant_management() -> None:
"""测试租户管理功能"""
print("=" * 60)
print(" = " * 60)
print("测试 1: 租户管理")
print("=" * 60)
print(" = " * 60)
manager = get_tenant_manager()
# 1. 创建租户
print("\n1.1 创建租户...")
tenant = manager.create_tenant(
name="Test Company", owner_id="user_001", tier="pro", description="A test company tenant"
name = "Test Company", owner_id = "user_001", tier = "pro", description = "A test company tenant"
)
print(f"✅ 租户创建成功: {tenant.id}")
print(f" - 名称: {tenant.name}")
@@ -53,30 +53,30 @@ def test_tenant_management():
# 4. 更新租户
print("\n1.4 更新租户信息...")
updated = manager.update_tenant(
tenant_id=tenant.id, name="Test Company Updated", tier="enterprise"
tenant_id = tenant.id, name = "Test Company Updated", tier = "enterprise"
)
assert updated is not None, "更新租户失败"
print(f"✅ 租户更新成功: {updated.name}, 层级: {updated.tier}")
# 5. 列出租户
print("\n1.5 列出租户...")
tenants = manager.list_tenants(limit=10)
tenants = manager.list_tenants(limit = 10)
print(f"✅ 找到 {len(tenants)} 个租户")
return tenant.id
def test_domain_management(tenant_id: str):
def test_domain_management(tenant_id: str) -> None:
"""测试域名管理功能"""
print("\n" + "=" * 60)
print("\n" + " = " * 60)
print("测试 2: 域名管理")
print("=" * 60)
print(" = " * 60)
manager = get_tenant_manager()
# 1. 添加域名
print("\n2.1 添加自定义域名...")
domain = manager.add_domain(tenant_id=tenant_id, domain="test.example.com", is_primary=True)
domain = manager.add_domain(tenant_id = tenant_id, domain = "test.example.com", is_primary = True)
print(f"✅ 域名添加成功: {domain.domain}")
print(f" - ID: {domain.id}")
print(f" - 状态: {domain.status}")
@@ -112,25 +112,25 @@ def test_domain_management(tenant_id: str):
return domain.id
def test_branding_management(tenant_id: str):
def test_branding_management(tenant_id: str) -> None:
"""测试品牌白标功能"""
print("\n" + "=" * 60)
print("\n" + " = " * 60)
print("测试 3: 品牌白标")
print("=" * 60)
print(" = " * 60)
manager = get_tenant_manager()
# 1. 更新品牌配置
print("\n3.1 更新品牌配置...")
branding = manager.update_branding(
tenant_id=tenant_id,
logo_url="https://example.com/logo.png",
favicon_url="https://example.com/favicon.ico",
primary_color="#1890ff",
secondary_color="#52c41a",
custom_css=".header { background: #1890ff; }",
custom_js="console.log('Custom JS loaded');",
login_page_bg="https://example.com/bg.jpg",
tenant_id = tenant_id,
logo_url = "https://example.com/logo.png",
favicon_url = "https://example.com/favicon.ico",
primary_color = "#1890ff",
secondary_color = "#52c41a",
custom_css = ".header { background: #1890ff; }",
custom_js = "console.log('Custom JS loaded');",
login_page_bg = "https://example.com/bg.jpg",
)
print("✅ 品牌配置更新成功")
print(f" - Logo: {branding.logo_url}")
@@ -152,18 +152,18 @@ def test_branding_management(tenant_id: str):
return branding.id
def test_member_management(tenant_id: str):
def test_member_management(tenant_id: str) -> None:
"""测试成员管理功能"""
print("\n" + "=" * 60)
print("\n" + " = " * 60)
print("测试 4: 成员管理")
print("=" * 60)
print(" = " * 60)
manager = get_tenant_manager()
# 1. 邀请成员
print("\n4.1 邀请成员...")
member1 = manager.invite_member(
tenant_id=tenant_id, email="admin@test.com", role="admin", invited_by="user_001"
tenant_id = tenant_id, email = "admin@test.com", role = "admin", invited_by = "user_001"
)
print(f"✅ 成员邀请成功: {member1.email}")
print(f" - ID: {member1.id}")
@@ -171,7 +171,7 @@ def test_member_management(tenant_id: str):
print(f" - 权限: {member1.permissions}")
member2 = manager.invite_member(
tenant_id=tenant_id, email="member@test.com", role="member", invited_by="user_001"
tenant_id = tenant_id, email = "member@test.com", role = "member", invited_by = "user_001"
)
print(f"✅ 成员邀请成功: {member2.email}")
@@ -207,24 +207,24 @@ def test_member_management(tenant_id: str):
return member1.id, member2.id
def test_usage_tracking(tenant_id: str):
def test_usage_tracking(tenant_id: str) -> None:
"""测试资源使用统计功能"""
print("\n" + "=" * 60)
print("\n" + " = " * 60)
print("测试 5: 资源使用统计")
print("=" * 60)
print(" = " * 60)
manager = get_tenant_manager()
# 1. 记录使用
print("\n5.1 记录资源使用...")
manager.record_usage(
tenant_id=tenant_id,
storage_bytes=1024 * 1024 * 50, # 50MB
transcription_seconds=600, # 10分钟
api_calls=100,
projects_count=5,
entities_count=50,
members_count=3,
tenant_id = tenant_id,
storage_bytes = 1024 * 1024 * 50, # 50MB
transcription_seconds = 600, # 10分钟
api_calls = 100,
projects_count = 5,
entities_count = 50,
members_count = 3,
)
print("✅ 资源使用记录成功")
@@ -249,11 +249,11 @@ def test_usage_tracking(tenant_id: str):
return stats
def cleanup(tenant_id: str, domain_id: str, member_ids: list):
def cleanup(tenant_id: str, domain_id: str, member_ids: list) -> None:
"""清理测试数据"""
print("\n" + "=" * 60)
print("\n" + " = " * 60)
print("清理测试数据")
print("=" * 60)
print(" = " * 60)
manager = get_tenant_manager()
@@ -273,11 +273,11 @@ def cleanup(tenant_id: str, domain_id: str, member_ids: list):
print(f"✅ 租户已删除: {tenant_id}")
def main():
def main() -> None:
"""主测试函数"""
print("\n" + "=" * 60)
print("\n" + " = " * 60)
print("InsightFlow Phase 8 Task 1 - 多租户 SaaS 架构测试")
print("=" * 60)
print(" = " * 60)
tenant_id = None
domain_id = None
@@ -292,9 +292,9 @@ def main():
member_ids = [m1, m2]
test_usage_tracking(tenant_id)
print("\n" + "=" * 60)
print("\n" + " = " * 60)
print("✅ 所有测试通过!")
print("=" * 60)
print(" = " * 60)
except Exception as e:
print(f"\n❌ 测试失败: {e}")

View File

@@ -12,17 +12,17 @@ from subscription_manager import PaymentProvider, SubscriptionManager
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
def test_subscription_manager():
def test_subscription_manager() -> None:
"""测试订阅管理器"""
print("=" * 60)
print(" = " * 60)
print("InsightFlow Phase 8 Task 2 - 订阅与计费系统测试")
print("=" * 60)
print(" = " * 60)
# 使用临时文件数据库进行测试
db_path = tempfile.mktemp(suffix=".db")
db_path = tempfile.mktemp(suffix = ".db")
try:
manager = SubscriptionManager(db_path=db_path)
manager = SubscriptionManager(db_path = db_path)
print("\n1. 测试订阅计划管理")
print("-" * 40)
@@ -53,10 +53,10 @@ def test_subscription_manager():
# 创建订阅
subscription = manager.create_subscription(
tenant_id=tenant_id,
plan_id=pro_plan.id,
payment_provider=PaymentProvider.STRIPE.value,
trial_days=14,
tenant_id = tenant_id,
plan_id = pro_plan.id,
payment_provider = PaymentProvider.STRIPE.value,
trial_days = 14,
)
print(f"✓ 创建订阅: {subscription.id}")
@@ -75,21 +75,21 @@ def test_subscription_manager():
# 记录转录用量
usage1 = manager.record_usage(
tenant_id=tenant_id,
resource_type="transcription",
quantity=120,
unit="minute",
description="会议转录",
tenant_id = tenant_id,
resource_type = "transcription",
quantity = 120,
unit = "minute",
description = "会议转录",
)
print(f"✓ 记录转录用量: {usage1.quantity} {usage1.unit}, 费用: ¥{usage1.cost:.2f}")
# 记录存储用量
usage2 = manager.record_usage(
tenant_id=tenant_id,
resource_type="storage",
quantity=2.5,
unit="gb",
description="文件存储",
tenant_id = tenant_id,
resource_type = "storage",
quantity = 2.5,
unit = "gb",
description = "文件存储",
)
print(f"✓ 记录存储用量: {usage2.quantity} {usage2.unit}, 费用: ¥{usage2.cost:.2f}")
@@ -105,11 +105,11 @@ def test_subscription_manager():
# 创建支付
payment = manager.create_payment(
tenant_id=tenant_id,
amount=99.0,
currency="CNY",
provider=PaymentProvider.ALIPAY.value,
payment_method="qrcode",
tenant_id = tenant_id,
amount = 99.0,
currency = "CNY",
provider = PaymentProvider.ALIPAY.value,
payment_method = "qrcode",
)
print(f"✓ 创建支付: {payment.id}")
print(f" - 金额: ¥{payment.amount}")
@@ -142,11 +142,11 @@ def test_subscription_manager():
# 申请退款
refund = manager.request_refund(
tenant_id=tenant_id,
payment_id=payment.id,
amount=50.0,
reason="服务不满意",
requested_by="user_001",
tenant_id = tenant_id,
payment_id = payment.id,
amount = 50.0,
reason = "服务不满意",
requested_by = "user_001",
)
print(f"✓ 申请退款: {refund.id}")
print(f" - 金额: ¥{refund.amount}")
@@ -178,19 +178,19 @@ def test_subscription_manager():
# Stripe Checkout
stripe_session = manager.create_stripe_checkout_session(
tenant_id=tenant_id,
plan_id=enterprise_plan.id,
success_url="https://example.com/success",
cancel_url="https://example.com/cancel",
tenant_id = tenant_id,
plan_id = enterprise_plan.id,
success_url = "https://example.com/success",
cancel_url = "https://example.com/cancel",
)
print(f"✓ Stripe Checkout 会话: {stripe_session['session_id']}")
# 支付宝订单
alipay_order = manager.create_alipay_order(tenant_id=tenant_id, plan_id=pro_plan.id)
alipay_order = manager.create_alipay_order(tenant_id = tenant_id, plan_id = pro_plan.id)
print(f"✓ 支付宝订单: {alipay_order['order_id']}")
# 微信支付订单
wechat_order = manager.create_wechat_order(tenant_id=tenant_id, plan_id=pro_plan.id)
wechat_order = manager.create_wechat_order(tenant_id = tenant_id, plan_id = pro_plan.id)
print(f"✓ 微信支付订单: {wechat_order['order_id']}")
# Webhook 处理
@@ -205,18 +205,18 @@ def test_subscription_manager():
# 更改计划
changed = manager.change_plan(
subscription_id=subscription.id, new_plan_id=enterprise_plan.id
subscription_id = subscription.id, new_plan_id = enterprise_plan.id
)
print(f"✓ 更改计划: {changed.plan_id} (Enterprise)")
# 取消订阅
cancelled = manager.cancel_subscription(subscription_id=subscription.id, at_period_end=True)
cancelled = manager.cancel_subscription(subscription_id = subscription.id, at_period_end = True)
print(f"✓ 取消订阅: {cancelled.status}")
print(f" - 周期结束时取消: {cancelled.cancel_at_period_end}")
print("\n" + "=" * 60)
print("\n" + " = " * 60)
print("所有测试通过! ✓")
print("=" * 60)
print(" = " * 60)
finally:
# 清理临时数据库

View File

@@ -14,7 +14,7 @@ from ai_manager import ModelType, PredictionType, get_ai_manager
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
def test_custom_model():
def test_custom_model() -> None:
"""测试自定义模型功能"""
print("\n=== 测试自定义模型 ===")
@@ -23,16 +23,16 @@ def test_custom_model():
# 1. 创建自定义模型
print("1. 创建自定义模型...")
model = manager.create_custom_model(
tenant_id="tenant_001",
name="领域实体识别模型",
description="用于识别医疗领域实体的自定义模型",
model_type=ModelType.CUSTOM_NER,
training_data={
tenant_id = "tenant_001",
name = "领域实体识别模型",
description = "用于识别医疗领域实体的自定义模型",
model_type = ModelType.CUSTOM_NER,
training_data = {
"entity_types": ["DISEASE", "SYMPTOM", "DRUG", "TREATMENT"],
"domain": "medical",
},
hyperparameters={"epochs": 15, "learning_rate": 0.001, "batch_size": 32},
created_by="user_001",
hyperparameters = {"epochs": 15, "learning_rate": 0.001, "batch_size": 32},
created_by = "user_001",
)
print(f" 创建成功: {model.id}, 状态: {model.status.value}")
@@ -67,10 +67,10 @@ def test_custom_model():
for sample_data in samples:
sample = manager.add_training_sample(
model_id=model.id,
text=sample_data["text"],
entities=sample_data["entities"],
metadata={"source": "manual"},
model_id = model.id,
text = sample_data["text"],
entities = sample_data["entities"],
metadata = {"source": "manual"},
)
print(f" 添加样本: {sample.id}")
@@ -81,7 +81,7 @@ def test_custom_model():
# 4. 列出自定义模型
print("4. 列出自定义模型...")
models = manager.list_custom_models(tenant_id="tenant_001")
models = manager.list_custom_models(tenant_id = "tenant_001")
print(f" 找到 {len(models)} 个模型")
for m in models:
print(f" - {m.name} ({m.model_type.value}): {m.status.value}")
@@ -89,7 +89,7 @@ def test_custom_model():
return model.id
async def test_train_and_predict(model_id: str):
async def test_train_and_predict(model_id: str) -> None:
"""测试训练和预测"""
print("\n=== 测试模型训练和预测 ===")
@@ -116,7 +116,7 @@ async def test_train_and_predict(model_id: str):
print(f" 预测失败: {e}")
def test_prediction_models():
def test_prediction_models() -> None:
"""测试预测模型"""
print("\n=== 测试预测模型 ===")
@@ -125,32 +125,32 @@ def test_prediction_models():
# 1. 创建趋势预测模型
print("1. 创建趋势预测模型...")
trend_model = manager.create_prediction_model(
tenant_id="tenant_001",
project_id="project_001",
name="实体数量趋势预测",
prediction_type=PredictionType.TREND,
target_entity_type="PERSON",
features=["entity_count", "time_period", "document_count"],
model_config={"algorithm": "linear_regression", "window_size": 7},
tenant_id = "tenant_001",
project_id = "project_001",
name = "实体数量趋势预测",
prediction_type = PredictionType.TREND,
target_entity_type = "PERSON",
features = ["entity_count", "time_period", "document_count"],
model_config = {"algorithm": "linear_regression", "window_size": 7},
)
print(f" 创建成功: {trend_model.id}")
# 2. 创建异常检测模型
print("2. 创建异常检测模型...")
anomaly_model = manager.create_prediction_model(
tenant_id="tenant_001",
project_id="project_001",
name="实体增长异常检测",
prediction_type=PredictionType.ANOMALY,
target_entity_type=None,
features=["daily_growth", "weekly_growth"],
model_config={"threshold": 2.5, "sensitivity": "medium"},
tenant_id = "tenant_001",
project_id = "project_001",
name = "实体增长异常检测",
prediction_type = PredictionType.ANOMALY,
target_entity_type = None,
features = ["daily_growth", "weekly_growth"],
model_config = {"threshold": 2.5, "sensitivity": "medium"},
)
print(f" 创建成功: {anomaly_model.id}")
# 3. 列出预测模型
print("3. 列出预测模型...")
models = manager.list_prediction_models(tenant_id="tenant_001")
models = manager.list_prediction_models(tenant_id = "tenant_001")
print(f" 找到 {len(models)} 个预测模型")
for m in models:
print(f" - {m.name} ({m.prediction_type.value})")
@@ -158,7 +158,7 @@ def test_prediction_models():
return trend_model.id, anomaly_model.id
async def test_predictions(trend_model_id: str, anomaly_model_id: str):
async def test_predictions(trend_model_id: str, anomaly_model_id: str) -> None:
"""测试预测功能"""
print("\n=== 测试预测功能 ===")
@@ -193,7 +193,7 @@ async def test_predictions(trend_model_id: str, anomaly_model_id: str):
print(f" 检测结果: {anomaly_result.prediction_data}")
def test_kg_rag():
def test_kg_rag() -> None:
"""测试知识图谱 RAG"""
print("\n=== 测试知识图谱 RAG ===")
@@ -202,28 +202,28 @@ def test_kg_rag():
# 创建 RAG 配置
print("1. 创建知识图谱 RAG 配置...")
rag = manager.create_kg_rag(
tenant_id="tenant_001",
project_id="project_001",
name="项目知识问答",
description="基于项目知识图谱的智能问答",
kg_config={
tenant_id = "tenant_001",
project_id = "project_001",
name = "项目知识问答",
description = "基于项目知识图谱的智能问答",
kg_config = {
"entity_types": ["PERSON", "ORG", "PROJECT", "TECH"],
"relation_types": ["works_with", "belongs_to", "depends_on"],
},
retrieval_config={"top_k": 5, "similarity_threshold": 0.7, "expand_relations": True},
generation_config={"temperature": 0.3, "max_tokens": 1000, "include_sources": True},
retrieval_config = {"top_k": 5, "similarity_threshold": 0.7, "expand_relations": True},
generation_config = {"temperature": 0.3, "max_tokens": 1000, "include_sources": True},
)
print(f" 创建成功: {rag.id}")
# 列出 RAG 配置
print("2. 列出 RAG 配置...")
rags = manager.list_kg_rags(tenant_id="tenant_001")
rags = manager.list_kg_rags(tenant_id = "tenant_001")
print(f" 找到 {len(rags)} 个配置")
return rag.id
async def test_kg_rag_query(rag_id: str):
async def test_kg_rag_query(rag_id: str) -> None:
"""测试 RAG 查询"""
print("\n=== 测试知识图谱 RAG 查询 ===")
@@ -279,10 +279,10 @@ async def test_kg_rag_query(rag_id: str):
try:
result = await manager.query_kg_rag(
rag_id=rag_id,
query=query_text,
project_entities=project_entities,
project_relations=project_relations,
rag_id = rag_id,
query = query_text,
project_entities = project_entities,
project_relations = project_relations,
)
print(f" 查询: {result.query}")
@@ -294,7 +294,7 @@ async def test_kg_rag_query(rag_id: str):
print(f" 查询失败: {e}")
async def test_smart_summary():
async def test_smart_summary() -> None:
"""测试智能摘要"""
print("\n=== 测试智能摘要 ===")
@@ -326,12 +326,12 @@ async def test_smart_summary():
print(f"1. 生成 {summary_type} 类型摘要...")
try:
summary = await manager.generate_smart_summary(
tenant_id="tenant_001",
project_id="project_001",
source_type="transcript",
source_id="transcript_001",
summary_type=summary_type,
content_data=content_data,
tenant_id = "tenant_001",
project_id = "project_001",
source_type = "transcript",
source_id = "transcript_001",
summary_type = summary_type,
content_data = content_data,
)
print(f" 摘要类型: {summary.summary_type}")
@@ -342,11 +342,11 @@ async def test_smart_summary():
print(f" 生成失败: {e}")
async def main():
async def main() -> None:
"""主测试函数"""
print("=" * 60)
print(" = " * 60)
print("InsightFlow Phase 8 Task 4 - AI 能力增强测试")
print("=" * 60)
print(" = " * 60)
try:
# 测试自定义模型
@@ -370,9 +370,9 @@ async def main():
# 测试智能摘要
await test_smart_summary()
print("\n" + "=" * 60)
print("\n" + " = " * 60)
print("所有测试完成!")
print("=" * 60)
print(" = " * 60)
except Exception as e:
print(f"\n测试失败: {e}")

View File

@@ -36,13 +36,13 @@ if backend_dir not in sys.path:
class TestGrowthManager:
"""测试 Growth Manager 功能"""
def __init__(self):
def __init__(self) -> None:
self.manager = GrowthManager()
self.test_tenant_id = "test_tenant_001"
self.test_user_id = "test_user_001"
self.test_results = []
def log(self, message: str, success: bool = True):
def log(self, message: str, success: bool = True) -> None:
"""记录测试结果"""
status = "" if success else ""
print(f"{status} {message}")
@@ -50,21 +50,21 @@ class TestGrowthManager:
# ==================== 测试用户行为分析 ====================
async def test_track_event(self):
async def test_track_event(self) -> None:
"""测试事件追踪"""
print("\n📊 测试事件追踪...")
try:
event = await self.manager.track_event(
tenant_id=self.test_tenant_id,
user_id=self.test_user_id,
event_type=EventType.PAGE_VIEW,
event_name="dashboard_view",
properties={"page": "/dashboard", "duration": 120},
session_id="session_001",
device_info={"browser": "Chrome", "os": "MacOS"},
referrer="https://google.com",
utm_params={"source": "google", "medium": "organic", "campaign": "summer"},
tenant_id = self.test_tenant_id,
user_id = self.test_user_id,
event_type = EventType.PAGE_VIEW,
event_name = "dashboard_view",
properties = {"page": "/dashboard", "duration": 120},
session_id = "session_001",
device_info = {"browser": "Chrome", "os": "MacOS"},
referrer = "https://google.com",
utm_params = {"source": "google", "medium": "organic", "campaign": "summer"},
)
assert event.id is not None
@@ -74,10 +74,10 @@ class TestGrowthManager:
self.log(f"事件追踪成功: {event.id}")
return True
except Exception as e:
self.log(f"事件追踪失败: {e}", success=False)
self.log(f"事件追踪失败: {e}", success = False)
return False
async def test_track_multiple_events(self):
async def test_track_multiple_events(self) -> None:
"""测试追踪多个事件"""
print("\n📊 测试追踪多个事件...")
@@ -91,20 +91,20 @@ class TestGrowthManager:
for event_type, event_name, props in events:
await self.manager.track_event(
tenant_id=self.test_tenant_id,
user_id=self.test_user_id,
event_type=event_type,
event_name=event_name,
properties=props,
tenant_id = self.test_tenant_id,
user_id = self.test_user_id,
event_type = event_type,
event_name = event_name,
properties = props,
)
self.log(f"成功追踪 {len(events)} 个事件")
return True
except Exception as e:
self.log(f"批量事件追踪失败: {e}", success=False)
self.log(f"批量事件追踪失败: {e}", success = False)
return False
def test_get_user_profile(self):
def test_get_user_profile(self) -> None:
"""测试获取用户画像"""
print("\n👤 测试用户画像...")
@@ -120,18 +120,18 @@ class TestGrowthManager:
return True
except Exception as e:
self.log(f"获取用户画像失败: {e}", success=False)
self.log(f"获取用户画像失败: {e}", success = False)
return False
def test_get_analytics_summary(self):
def test_get_analytics_summary(self) -> None:
"""测试获取分析汇总"""
print("\n📈 测试分析汇总...")
try:
summary = self.manager.get_user_analytics_summary(
tenant_id=self.test_tenant_id,
start_date=datetime.now() - timedelta(days=7),
end_date=datetime.now(),
tenant_id = self.test_tenant_id,
start_date = datetime.now() - timedelta(days = 7),
end_date = datetime.now(),
)
assert "unique_users" in summary
@@ -141,25 +141,25 @@ class TestGrowthManager:
self.log(f"分析汇总: {summary['unique_users']} 用户, {summary['total_events']} 事件")
return True
except Exception as e:
self.log(f"获取分析汇总失败: {e}", success=False)
self.log(f"获取分析汇总失败: {e}", success = False)
return False
def test_create_funnel(self):
def test_create_funnel(self) -> None:
"""测试创建转化漏斗"""
print("\n🎯 测试创建转化漏斗...")
try:
funnel = self.manager.create_funnel(
tenant_id=self.test_tenant_id,
name="用户注册转化漏斗",
description="从访问到完成注册的转化流程",
steps=[
tenant_id = self.test_tenant_id,
name = "用户注册转化漏斗",
description = "从访问到完成注册的转化流程",
steps = [
{"name": "访问首页", "event_name": "page_view_home"},
{"name": "点击注册", "event_name": "signup_click"},
{"name": "填写信息", "event_name": "signup_form_fill"},
{"name": "完成注册", "event_name": "signup_complete"},
],
created_by="test",
created_by = "test",
)
assert funnel.id is not None
@@ -168,10 +168,10 @@ class TestGrowthManager:
self.log(f"漏斗创建成功: {funnel.id}")
return funnel.id
except Exception as e:
self.log(f"创建漏斗失败: {e}", success=False)
self.log(f"创建漏斗失败: {e}", success = False)
return None
def test_analyze_funnel(self, funnel_id: str):
def test_analyze_funnel(self, funnel_id: str) -> None:
"""测试分析漏斗"""
print("\n📉 测试漏斗分析...")
@@ -181,9 +181,9 @@ class TestGrowthManager:
try:
analysis = self.manager.analyze_funnel(
funnel_id=funnel_id,
period_start=datetime.now() - timedelta(days=30),
period_end=datetime.now(),
funnel_id = funnel_id,
period_start = datetime.now() - timedelta(days = 30),
period_end = datetime.now(),
)
if analysis:
@@ -194,18 +194,18 @@ class TestGrowthManager:
self.log("漏斗分析返回空结果")
return False
except Exception as e:
self.log(f"漏斗分析失败: {e}", success=False)
self.log(f"漏斗分析失败: {e}", success = False)
return False
def test_calculate_retention(self):
def test_calculate_retention(self) -> None:
"""测试留存率计算"""
print("\n🔄 测试留存率计算...")
try:
retention = self.manager.calculate_retention(
tenant_id=self.test_tenant_id,
cohort_date=datetime.now() - timedelta(days=7),
periods=[1, 3, 7],
tenant_id = self.test_tenant_id,
cohort_date = datetime.now() - timedelta(days = 7),
periods = [1, 3, 7],
)
assert "cohort_date" in retention
@@ -214,34 +214,34 @@ class TestGrowthManager:
self.log(f"留存率计算完成: 同期群 {retention['cohort_size']} 用户")
return True
except Exception as e:
self.log(f"留存率计算失败: {e}", success=False)
self.log(f"留存率计算失败: {e}", success = False)
return False
# ==================== 测试 A/B 测试框架 ====================
def test_create_experiment(self):
def test_create_experiment(self) -> None:
"""测试创建实验"""
print("\n🧪 测试创建 A/B 测试实验...")
try:
experiment = self.manager.create_experiment(
tenant_id=self.test_tenant_id,
name="首页按钮颜色测试",
description="测试不同按钮颜色对转化率的影响",
hypothesis="蓝色按钮比红色按钮有更高的点击率",
variants=[
tenant_id = self.test_tenant_id,
name = "首页按钮颜色测试",
description = "测试不同按钮颜色对转化率的影响",
hypothesis = "蓝色按钮比红色按钮有更高的点击率",
variants = [
{"id": "control", "name": "红色按钮", "is_control": True},
{"id": "variant_a", "name": "蓝色按钮", "is_control": False},
{"id": "variant_b", "name": "绿色按钮", "is_control": False},
],
traffic_allocation=TrafficAllocationType.RANDOM,
traffic_split={"control": 0.34, "variant_a": 0.33, "variant_b": 0.33},
target_audience={"conditions": []},
primary_metric="button_click_rate",
secondary_metrics=["conversion_rate", "bounce_rate"],
min_sample_size=100,
confidence_level=0.95,
created_by="test",
traffic_allocation = TrafficAllocationType.RANDOM,
traffic_split = {"control": 0.34, "variant_a": 0.33, "variant_b": 0.33},
target_audience = {"conditions": []},
primary_metric = "button_click_rate",
secondary_metrics = ["conversion_rate", "bounce_rate"],
min_sample_size = 100,
confidence_level = 0.95,
created_by = "test",
)
assert experiment.id is not None
@@ -250,10 +250,10 @@ class TestGrowthManager:
self.log(f"实验创建成功: {experiment.id}")
return experiment.id
except Exception as e:
self.log(f"创建实验失败: {e}", success=False)
self.log(f"创建实验失败: {e}", success = False)
return None
def test_list_experiments(self):
def test_list_experiments(self) -> None:
"""测试列出实验"""
print("\n📋 测试列出实验...")
@@ -263,10 +263,10 @@ class TestGrowthManager:
self.log(f"列出 {len(experiments)} 个实验")
return True
except Exception as e:
self.log(f"列出实验失败: {e}", success=False)
self.log(f"列出实验失败: {e}", success = False)
return False
def test_assign_variant(self, experiment_id: str):
def test_assign_variant(self, experiment_id: str) -> None:
"""测试分配变体"""
print("\n🎲 测试分配实验变体...")
@@ -284,9 +284,9 @@ class TestGrowthManager:
for user_id in test_users:
variant_id = self.manager.assign_variant(
experiment_id=experiment_id,
user_id=user_id,
user_attributes={"user_id": user_id, "segment": "new"},
experiment_id = experiment_id,
user_id = user_id,
user_attributes = {"user_id": user_id, "segment": "new"},
)
if variant_id:
@@ -295,10 +295,10 @@ class TestGrowthManager:
self.log(f"变体分配完成: {len(assignments)} 个用户")
return True
except Exception as e:
self.log(f"变体分配失败: {e}", success=False)
self.log(f"变体分配失败: {e}", success = False)
return False
def test_record_experiment_metric(self, experiment_id: str):
def test_record_experiment_metric(self, experiment_id: str) -> None:
"""测试记录实验指标"""
print("\n📊 测试记录实验指标...")
@@ -318,20 +318,20 @@ class TestGrowthManager:
for user_id, variant_id, value in test_data:
self.manager.record_experiment_metric(
experiment_id=experiment_id,
variant_id=variant_id,
user_id=user_id,
metric_name="button_click_rate",
metric_value=value,
experiment_id = experiment_id,
variant_id = variant_id,
user_id = user_id,
metric_name = "button_click_rate",
metric_value = value,
)
self.log(f"成功记录 {len(test_data)} 条指标")
return True
except Exception as e:
self.log(f"记录指标失败: {e}", success=False)
self.log(f"记录指标失败: {e}", success = False)
return False
def test_analyze_experiment(self, experiment_id: str):
def test_analyze_experiment(self, experiment_id: str) -> None:
"""测试分析实验结果"""
print("\n📈 测试分析实验结果...")
@@ -346,25 +346,25 @@ class TestGrowthManager:
self.log(f"实验分析完成: {len(result.get('variant_results', {}))} 个变体")
return True
else:
self.log(f"实验分析返回错误: {result['error']}", success=False)
self.log(f"实验分析返回错误: {result['error']}", success = False)
return False
except Exception as e:
self.log(f"实验分析失败: {e}", success=False)
self.log(f"实验分析失败: {e}", success = False)
return False
# ==================== 测试邮件营销 ====================
def test_create_email_template(self):
def test_create_email_template(self) -> None:
"""测试创建邮件模板"""
print("\n📧 测试创建邮件模板...")
try:
template = self.manager.create_email_template(
tenant_id=self.test_tenant_id,
name="欢迎邮件",
template_type=EmailTemplateType.WELCOME,
subject="欢迎加入 InsightFlow",
html_content="""
tenant_id = self.test_tenant_id,
name = "欢迎邮件",
template_type = EmailTemplateType.WELCOME,
subject = "欢迎加入 InsightFlow",
html_content = """
<h1>欢迎,{{user_name}}</h1>
<p>感谢您注册 InsightFlow。我们很高兴您能加入我们</p>
<p>您的账户已创建,可以开始使用以下功能:</p>
@@ -373,10 +373,10 @@ class TestGrowthManager:
<li>智能实体提取</li>
<li>团队协作</li>
</ul>
<p><a href="{{dashboard_url}}">立即开始使用</a></p>
<p><a href = "{{dashboard_url}}">立即开始使用</a></p>
""",
from_name="InsightFlow 团队",
from_email="welcome@insightflow.io",
from_name = "InsightFlow 团队",
from_email = "welcome@insightflow.io",
)
assert template.id is not None
@@ -385,10 +385,10 @@ class TestGrowthManager:
self.log(f"邮件模板创建成功: {template.id}")
return template.id
except Exception as e:
self.log(f"创建邮件模板失败: {e}", success=False)
self.log(f"创建邮件模板失败: {e}", success = False)
return None
def test_list_email_templates(self):
def test_list_email_templates(self) -> None:
"""测试列出邮件模板"""
print("\n📧 测试列出邮件模板...")
@@ -398,10 +398,10 @@ class TestGrowthManager:
self.log(f"列出 {len(templates)} 个邮件模板")
return True
except Exception as e:
self.log(f"列出邮件模板失败: {e}", success=False)
self.log(f"列出邮件模板失败: {e}", success = False)
return False
def test_render_template(self, template_id: str):
def test_render_template(self, template_id: str) -> None:
"""测试渲染邮件模板"""
print("\n🎨 测试渲染邮件模板...")
@@ -411,8 +411,8 @@ class TestGrowthManager:
try:
rendered = self.manager.render_template(
template_id=template_id,
variables={
template_id = template_id,
variables = {
"user_name": "张三",
"dashboard_url": "https://app.insightflow.io/dashboard",
},
@@ -424,13 +424,13 @@ class TestGrowthManager:
self.log(f"模板渲染成功: {rendered['subject']}")
return True
else:
self.log("模板渲染返回空结果", success=False)
self.log("模板渲染返回空结果", success = False)
return False
except Exception as e:
self.log(f"模板渲染失败: {e}", success=False)
self.log(f"模板渲染失败: {e}", success = False)
return False
def test_create_email_campaign(self, template_id: str):
def test_create_email_campaign(self, template_id: str) -> None:
"""测试创建邮件营销活动"""
print("\n📮 测试创建邮件营销活动...")
@@ -440,10 +440,10 @@ class TestGrowthManager:
try:
campaign = self.manager.create_email_campaign(
tenant_id=self.test_tenant_id,
name="新用户欢迎活动",
template_id=template_id,
recipient_list=[
tenant_id = self.test_tenant_id,
name = "新用户欢迎活动",
template_id = template_id,
recipient_list = [
{"user_id": "user_001", "email": "user1@example.com"},
{"user_id": "user_002", "email": "user2@example.com"},
{"user_id": "user_003", "email": "user3@example.com"},
@@ -456,21 +456,21 @@ class TestGrowthManager:
self.log(f"营销活动创建成功: {campaign.id}, {campaign.recipient_count} 收件人")
return campaign.id
except Exception as e:
self.log(f"创建营销活动失败: {e}", success=False)
self.log(f"创建营销活动失败: {e}", success = False)
return None
def test_create_automation_workflow(self):
def test_create_automation_workflow(self) -> None:
"""测试创建自动化工作流"""
print("\n🤖 测试创建自动化工作流...")
try:
workflow = self.manager.create_automation_workflow(
tenant_id=self.test_tenant_id,
name="新用户欢迎序列",
description="用户注册后自动发送欢迎邮件序列",
trigger_type=WorkflowTriggerType.USER_SIGNUP,
trigger_conditions={"event": "user_signup"},
actions=[
tenant_id = self.test_tenant_id,
name = "新用户欢迎序列",
description = "用户注册后自动发送欢迎邮件序列",
trigger_type = WorkflowTriggerType.USER_SIGNUP,
trigger_conditions = {"event": "user_signup"},
actions = [
{"type": "send_email", "template_type": "welcome", "delay_hours": 0},
{"type": "send_email", "template_type": "onboarding", "delay_hours": 24},
{"type": "send_email", "template_type": "feature_tips", "delay_hours": 72},
@@ -483,27 +483,27 @@ class TestGrowthManager:
self.log(f"自动化工作流创建成功: {workflow.id}")
return True
except Exception as e:
self.log(f"创建工作流失败: {e}", success=False)
self.log(f"创建工作流失败: {e}", success = False)
return False
# ==================== 测试推荐系统 ====================
def test_create_referral_program(self):
def test_create_referral_program(self) -> None:
"""测试创建推荐计划"""
print("\n🎁 测试创建推荐计划...")
try:
program = self.manager.create_referral_program(
tenant_id=self.test_tenant_id,
name="邀请好友奖励计划",
description="邀请好友注册,双方获得积分奖励",
referrer_reward_type="credit",
referrer_reward_value=100.0,
referee_reward_type="credit",
referee_reward_value=50.0,
max_referrals_per_user=10,
referral_code_length=8,
expiry_days=30,
tenant_id = self.test_tenant_id,
name = "邀请好友奖励计划",
description = "邀请好友注册,双方获得积分奖励",
referrer_reward_type = "credit",
referrer_reward_value = 100.0,
referee_reward_type = "credit",
referee_reward_value = 50.0,
max_referrals_per_user = 10,
referral_code_length = 8,
expiry_days = 30,
)
assert program.id is not None
@@ -512,10 +512,10 @@ class TestGrowthManager:
self.log(f"推荐计划创建成功: {program.id}")
return program.id
except Exception as e:
self.log(f"创建推荐计划失败: {e}", success=False)
self.log(f"创建推荐计划失败: {e}", success = False)
return None
def test_generate_referral_code(self, program_id: str):
def test_generate_referral_code(self, program_id: str) -> None:
"""测试生成推荐码"""
print("\n🔑 测试生成推荐码...")
@@ -525,7 +525,7 @@ class TestGrowthManager:
try:
referral = self.manager.generate_referral_code(
program_id=program_id, referrer_id="referrer_user_001"
program_id = program_id, referrer_id = "referrer_user_001"
)
if referral:
@@ -535,13 +535,13 @@ class TestGrowthManager:
self.log(f"推荐码生成成功: {referral.referral_code}")
return referral.referral_code
else:
self.log("生成推荐码返回空结果", success=False)
self.log("生成推荐码返回空结果", success = False)
return None
except Exception as e:
self.log(f"生成推荐码失败: {e}", success=False)
self.log(f"生成推荐码失败: {e}", success = False)
return None
def test_apply_referral_code(self, referral_code: str):
def test_apply_referral_code(self, referral_code: str) -> None:
"""测试应用推荐码"""
print("\n✅ 测试应用推荐码...")
@@ -551,20 +551,20 @@ class TestGrowthManager:
try:
success = self.manager.apply_referral_code(
referral_code=referral_code, referee_id="new_user_001"
referral_code = referral_code, referee_id = "new_user_001"
)
if success:
self.log(f"推荐码应用成功: {referral_code}")
return True
else:
self.log("推荐码应用失败", success=False)
self.log("推荐码应用失败", success = False)
return False
except Exception as e:
self.log(f"应用推荐码失败: {e}", success=False)
self.log(f"应用推荐码失败: {e}", success = False)
return False
def test_get_referral_stats(self, program_id: str):
def test_get_referral_stats(self, program_id: str) -> None:
"""测试获取推荐统计"""
print("\n📊 测试获取推荐统计...")
@@ -583,24 +583,24 @@ class TestGrowthManager:
)
return True
except Exception as e:
self.log(f"获取推荐统计失败: {e}", success=False)
self.log(f"获取推荐统计失败: {e}", success = False)
return False
def test_create_team_incentive(self):
def test_create_team_incentive(self) -> None:
"""测试创建团队激励"""
print("\n🏆 测试创建团队升级激励...")
try:
incentive = self.manager.create_team_incentive(
tenant_id=self.test_tenant_id,
name="团队升级奖励",
description="团队规模达到5人升级到 Pro 计划可获得折扣",
target_tier="pro",
min_team_size=5,
incentive_type="discount",
incentive_value=20.0, # 20% 折扣
valid_from=datetime.now(),
valid_until=datetime.now() + timedelta(days=90),
tenant_id = self.test_tenant_id,
name = "团队升级奖励",
description = "团队规模达到5人升级到 Pro 计划可获得折扣",
target_tier = "pro",
min_team_size = 5,
incentive_type = "discount",
incentive_value = 20.0, # 20% 折扣
valid_from = datetime.now(),
valid_until = datetime.now() + timedelta(days = 90),
)
assert incentive.id is not None
@@ -609,27 +609,27 @@ class TestGrowthManager:
self.log(f"团队激励创建成功: {incentive.id}")
return True
except Exception as e:
self.log(f"创建团队激励失败: {e}", success=False)
self.log(f"创建团队激励失败: {e}", success = False)
return False
def test_check_team_incentive_eligibility(self):
def test_check_team_incentive_eligibility(self) -> None:
"""测试检查团队激励资格"""
print("\n🔍 测试检查团队激励资格...")
try:
incentives = self.manager.check_team_incentive_eligibility(
tenant_id=self.test_tenant_id, current_tier="free", team_size=5
tenant_id = self.test_tenant_id, current_tier = "free", team_size = 5
)
self.log(f"找到 {len(incentives)} 个符合条件的激励")
return True
except Exception as e:
self.log(f"检查激励资格失败: {e}", success=False)
self.log(f"检查激励资格失败: {e}", success = False)
return False
# ==================== 测试实时仪表板 ====================
def test_get_realtime_dashboard(self):
def test_get_realtime_dashboard(self) -> None:
"""测试获取实时仪表板"""
print("\n📺 测试实时分析仪表板...")
@@ -646,21 +646,21 @@ class TestGrowthManager:
)
return True
except Exception as e:
self.log(f"获取实时仪表板失败: {e}", success=False)
self.log(f"获取实时仪表板失败: {e}", success = False)
return False
# ==================== 运行所有测试 ====================
async def run_all_tests(self):
async def run_all_tests(self) -> None:
"""运行所有测试"""
print("=" * 60)
print(" = " * 60)
print("🚀 InsightFlow Phase 8 Task 5 - 运营与增长工具测试")
print("=" * 60)
print(" = " * 60)
# 用户行为分析测试
print("\n" + "=" * 60)
print("\n" + " = " * 60)
print("📊 模块 1: 用户行为分析")
print("=" * 60)
print(" = " * 60)
await self.test_track_event()
await self.test_track_multiple_events()
@@ -671,9 +671,9 @@ class TestGrowthManager:
self.test_calculate_retention()
# A/B 测试框架测试
print("\n" + "=" * 60)
print("\n" + " = " * 60)
print("🧪 模块 2: A/B 测试框架")
print("=" * 60)
print(" = " * 60)
experiment_id = self.test_create_experiment()
self.test_list_experiments()
@@ -682,9 +682,9 @@ class TestGrowthManager:
self.test_analyze_experiment(experiment_id)
# 邮件营销测试
print("\n" + "=" * 60)
print("\n" + " = " * 60)
print("📧 模块 3: 邮件营销自动化")
print("=" * 60)
print(" = " * 60)
template_id = self.test_create_email_template()
self.test_list_email_templates()
@@ -693,9 +693,9 @@ class TestGrowthManager:
self.test_create_automation_workflow()
# 推荐系统测试
print("\n" + "=" * 60)
print("\n" + " = " * 60)
print("🎁 模块 4: 推荐系统")
print("=" * 60)
print(" = " * 60)
program_id = self.test_create_referral_program()
referral_code = self.test_generate_referral_code(program_id)
@@ -705,16 +705,16 @@ class TestGrowthManager:
self.test_check_team_incentive_eligibility()
# 实时仪表板测试
print("\n" + "=" * 60)
print("\n" + " = " * 60)
print("📺 模块 5: 实时分析仪表板")
print("=" * 60)
print(" = " * 60)
self.test_get_realtime_dashboard()
# 测试总结
print("\n" + "=" * 60)
print("\n" + " = " * 60)
print("📋 测试总结")
print("=" * 60)
print(" = " * 60)
total_tests = len(self.test_results)
passed_tests = sum(1 for _, success in self.test_results if success)
@@ -731,12 +731,12 @@ class TestGrowthManager:
if not success:
print(f" - {message}")
print("\n" + "=" * 60)
print("\n" + " = " * 60)
print("✨ 测试完成!")
print("=" * 60)
print(" = " * 60)
async def main():
async def main() -> None:
"""主函数"""
tester = TestGrowthManager()
await tester.run_all_tests()

View File

@@ -33,7 +33,7 @@ if backend_dir not in sys.path:
class TestDeveloperEcosystem:
"""开发者生态系统测试类"""
def __init__(self):
def __init__(self) -> None:
self.manager = DeveloperEcosystemManager()
self.test_results = []
self.created_ids = {
@@ -45,7 +45,7 @@ class TestDeveloperEcosystem:
"portal_config": [],
}
def log(self, message: str, success: bool = True):
def log(self, message: str, success: bool = True) -> None:
"""记录测试结果"""
status = "" if success else ""
print(f"{status} {message}")
@@ -53,11 +53,11 @@ class TestDeveloperEcosystem:
{"message": message, "success": success, "timestamp": datetime.now().isoformat()}
)
def run_all_tests(self):
def run_all_tests(self) -> None:
"""运行所有测试"""
print("=" * 60)
print(" = " * 60)
print("InsightFlow Phase 8 Task 6: Developer Ecosystem Tests")
print("=" * 60)
print(" = " * 60)
# SDK Tests
print("\n📦 SDK Release & Management Tests")
@@ -119,69 +119,69 @@ class TestDeveloperEcosystem:
# Print Summary
self.print_summary()
def test_sdk_create(self):
def test_sdk_create(self) -> None:
"""测试创建 SDK"""
try:
sdk = self.manager.create_sdk_release(
name="InsightFlow Python SDK",
language=SDKLanguage.PYTHON,
version="1.0.0",
description="Python SDK for InsightFlow API",
changelog="Initial release",
download_url="https://pypi.org/insightflow/1.0.0",
documentation_url="https://docs.insightflow.io/python",
repository_url="https://github.com/insightflow/python-sdk",
package_name="insightflow",
min_platform_version="1.0.0",
dependencies=[{"name": "requests", "version": ">=2.0"}],
file_size=1024000,
checksum="abc123",
created_by="test_user",
name = "InsightFlow Python SDK",
language = SDKLanguage.PYTHON,
version = "1.0.0",
description = "Python SDK for InsightFlow API",
changelog = "Initial release",
download_url = "https://pypi.org/insightflow/1.0.0",
documentation_url = "https://docs.insightflow.io/python",
repository_url = "https://github.com/insightflow/python-sdk",
package_name = "insightflow",
min_platform_version = "1.0.0",
dependencies = [{"name": "requests", "version": ">= 2.0"}],
file_size = 1024000,
checksum = "abc123",
created_by = "test_user",
)
self.created_ids["sdk"].append(sdk.id)
self.log(f"Created SDK: {sdk.name} ({sdk.id})")
# Create JavaScript SDK
sdk_js = self.manager.create_sdk_release(
name="InsightFlow JavaScript SDK",
language=SDKLanguage.JAVASCRIPT,
version="1.0.0",
description="JavaScript SDK for InsightFlow API",
changelog="Initial release",
download_url="https://npmjs.com/insightflow/1.0.0",
documentation_url="https://docs.insightflow.io/js",
repository_url="https://github.com/insightflow/js-sdk",
package_name="@insightflow/sdk",
min_platform_version="1.0.0",
dependencies=[{"name": "axios", "version": ">=0.21"}],
file_size=512000,
checksum="def456",
created_by="test_user",
name = "InsightFlow JavaScript SDK",
language = SDKLanguage.JAVASCRIPT,
version = "1.0.0",
description = "JavaScript SDK for InsightFlow API",
changelog = "Initial release",
download_url = "https://npmjs.com/insightflow/1.0.0",
documentation_url = "https://docs.insightflow.io/js",
repository_url = "https://github.com/insightflow/js-sdk",
package_name = "@insightflow/sdk",
min_platform_version = "1.0.0",
dependencies = [{"name": "axios", "version": ">= 0.21"}],
file_size = 512000,
checksum = "def456",
created_by = "test_user",
)
self.created_ids["sdk"].append(sdk_js.id)
self.log(f"Created SDK: {sdk_js.name} ({sdk_js.id})")
except Exception as e:
self.log(f"Failed to create SDK: {str(e)}", success=False)
self.log(f"Failed to create SDK: {str(e)}", success = False)
def test_sdk_list(self):
def test_sdk_list(self) -> None:
"""测试列出 SDK"""
try:
sdks = self.manager.list_sdk_releases()
self.log(f"Listed {len(sdks)} SDKs")
# Test filter by language
python_sdks = self.manager.list_sdk_releases(language=SDKLanguage.PYTHON)
python_sdks = self.manager.list_sdk_releases(language = SDKLanguage.PYTHON)
self.log(f"Found {len(python_sdks)} Python SDKs")
# Test search
search_results = self.manager.list_sdk_releases(search="Python")
search_results = self.manager.list_sdk_releases(search = "Python")
self.log(f"Search found {len(search_results)} SDKs")
except Exception as e:
self.log(f"Failed to list SDKs: {str(e)}", success=False)
self.log(f"Failed to list SDKs: {str(e)}", success = False)
def test_sdk_get(self):
def test_sdk_get(self) -> None:
"""测试获取 SDK 详情"""
try:
if self.created_ids["sdk"]:
@@ -189,23 +189,23 @@ class TestDeveloperEcosystem:
if sdk:
self.log(f"Retrieved SDK: {sdk.name}")
else:
self.log("SDK not found", success=False)
self.log("SDK not found", success = False)
except Exception as e:
self.log(f"Failed to get SDK: {str(e)}", success=False)
self.log(f"Failed to get SDK: {str(e)}", success = False)
def test_sdk_update(self):
def test_sdk_update(self) -> None:
"""测试更新 SDK"""
try:
if self.created_ids["sdk"]:
sdk = self.manager.update_sdk_release(
self.created_ids["sdk"][0], description="Updated description"
self.created_ids["sdk"][0], description = "Updated description"
)
if sdk:
self.log(f"Updated SDK: {sdk.name}")
except Exception as e:
self.log(f"Failed to update SDK: {str(e)}", success=False)
self.log(f"Failed to update SDK: {str(e)}", success = False)
def test_sdk_publish(self):
def test_sdk_publish(self) -> None:
"""测试发布 SDK"""
try:
if self.created_ids["sdk"]:
@@ -213,86 +213,86 @@ class TestDeveloperEcosystem:
if sdk:
self.log(f"Published SDK: {sdk.name} (status: {sdk.status.value})")
except Exception as e:
self.log(f"Failed to publish SDK: {str(e)}", success=False)
self.log(f"Failed to publish SDK: {str(e)}", success = False)
def test_sdk_version_add(self):
def test_sdk_version_add(self) -> None:
"""测试添加 SDK 版本"""
try:
if self.created_ids["sdk"]:
version = self.manager.add_sdk_version(
sdk_id=self.created_ids["sdk"][0],
version="1.1.0",
is_lts=True,
release_notes="Bug fixes and improvements",
download_url="https://pypi.org/insightflow/1.1.0",
checksum="xyz789",
file_size=1100000,
sdk_id = self.created_ids["sdk"][0],
version = "1.1.0",
is_lts = True,
release_notes = "Bug fixes and improvements",
download_url = "https://pypi.org/insightflow/1.1.0",
checksum = "xyz789",
file_size = 1100000,
)
self.log(f"Added SDK version: {version.version}")
except Exception as e:
self.log(f"Failed to add SDK version: {str(e)}", success=False)
self.log(f"Failed to add SDK version: {str(e)}", success = False)
def test_template_create(self):
def test_template_create(self) -> None:
"""测试创建模板"""
try:
template = self.manager.create_template(
name="医疗行业实体识别模板",
description="专门针对医疗行业的实体识别模板,支持疾病、药物、症状等实体",
category=TemplateCategory.MEDICAL,
subcategory="entity_recognition",
tags=["medical", "healthcare", "ner"],
author_id="dev_001",
author_name="Medical AI Lab",
price=99.0,
currency="CNY",
preview_image_url="https://cdn.insightflow.io/templates/medical.png",
demo_url="https://demo.insightflow.io/medical",
documentation_url="https://docs.insightflow.io/templates/medical",
download_url="https://cdn.insightflow.io/templates/medical.zip",
version="1.0.0",
min_platform_version="2.0.0",
file_size=5242880,
checksum="tpl123",
name = "医疗行业实体识别模板",
description = "专门针对医疗行业的实体识别模板,支持疾病、药物、症状等实体",
category = TemplateCategory.MEDICAL,
subcategory = "entity_recognition",
tags = ["medical", "healthcare", "ner"],
author_id = "dev_001",
author_name = "Medical AI Lab",
price = 99.0,
currency = "CNY",
preview_image_url = "https://cdn.insightflow.io/templates/medical.png",
demo_url = "https://demo.insightflow.io/medical",
documentation_url = "https://docs.insightflow.io/templates/medical",
download_url = "https://cdn.insightflow.io/templates/medical.zip",
version = "1.0.0",
min_platform_version = "2.0.0",
file_size = 5242880,
checksum = "tpl123",
)
self.created_ids["template"].append(template.id)
self.log(f"Created template: {template.name} ({template.id})")
# Create free template
template_free = self.manager.create_template(
name="通用实体识别模板",
description="适用于一般场景的实体识别模板",
category=TemplateCategory.GENERAL,
subcategory=None,
tags=["general", "ner", "basic"],
author_id="dev_002",
author_name="InsightFlow Team",
price=0.0,
currency="CNY",
name = "通用实体识别模板",
description = "适用于一般场景的实体识别模板",
category = TemplateCategory.GENERAL,
subcategory = None,
tags = ["general", "ner", "basic"],
author_id = "dev_002",
author_name = "InsightFlow Team",
price = 0.0,
currency = "CNY",
)
self.created_ids["template"].append(template_free.id)
self.log(f"Created free template: {template_free.name}")
except Exception as e:
self.log(f"Failed to create template: {str(e)}", success=False)
self.log(f"Failed to create template: {str(e)}", success = False)
def test_template_list(self):
def test_template_list(self) -> None:
"""测试列出模板"""
try:
templates = self.manager.list_templates()
self.log(f"Listed {len(templates)} templates")
# Filter by category
medical_templates = self.manager.list_templates(category=TemplateCategory.MEDICAL)
medical_templates = self.manager.list_templates(category = TemplateCategory.MEDICAL)
self.log(f"Found {len(medical_templates)} medical templates")
# Filter by price
free_templates = self.manager.list_templates(max_price=0)
free_templates = self.manager.list_templates(max_price = 0)
self.log(f"Found {len(free_templates)} free templates")
except Exception as e:
self.log(f"Failed to list templates: {str(e)}", success=False)
self.log(f"Failed to list templates: {str(e)}", success = False)
def test_template_get(self):
def test_template_get(self) -> None:
"""测试获取模板详情"""
try:
if self.created_ids["template"]:
@@ -300,21 +300,21 @@ class TestDeveloperEcosystem:
if template:
self.log(f"Retrieved template: {template.name}")
except Exception as e:
self.log(f"Failed to get template: {str(e)}", success=False)
self.log(f"Failed to get template: {str(e)}", success = False)
def test_template_approve(self):
def test_template_approve(self) -> None:
"""测试审核通过模板"""
try:
if self.created_ids["template"]:
template = self.manager.approve_template(
self.created_ids["template"][0], reviewed_by="admin_001"
self.created_ids["template"][0], reviewed_by = "admin_001"
)
if template:
self.log(f"Approved template: {template.name}")
except Exception as e:
self.log(f"Failed to approve template: {str(e)}", success=False)
self.log(f"Failed to approve template: {str(e)}", success = False)
def test_template_publish(self):
def test_template_publish(self) -> None:
"""测试发布模板"""
try:
if self.created_ids["template"]:
@@ -322,84 +322,84 @@ class TestDeveloperEcosystem:
if template:
self.log(f"Published template: {template.name}")
except Exception as e:
self.log(f"Failed to publish template: {str(e)}", success=False)
self.log(f"Failed to publish template: {str(e)}", success = False)
def test_template_review(self):
def test_template_review(self) -> None:
"""测试添加模板评价"""
try:
if self.created_ids["template"]:
review = self.manager.add_template_review(
template_id=self.created_ids["template"][0],
user_id="user_001",
user_name="Test User",
rating=5,
comment="Great template! Very accurate for medical entities.",
is_verified_purchase=True,
template_id = self.created_ids["template"][0],
user_id = "user_001",
user_name = "Test User",
rating = 5,
comment = "Great template! Very accurate for medical entities.",
is_verified_purchase = True,
)
self.log(f"Added template review: {review.rating} stars")
except Exception as e:
self.log(f"Failed to add template review: {str(e)}", success=False)
self.log(f"Failed to add template review: {str(e)}", success = False)
def test_plugin_create(self):
def test_plugin_create(self) -> None:
"""测试创建插件"""
try:
plugin = self.manager.create_plugin(
name="飞书机器人集成插件",
description="将 InsightFlow 与飞书机器人集成,实现自动通知",
category=PluginCategory.INTEGRATION,
tags=["feishu", "bot", "integration", "notification"],
author_id="dev_003",
author_name="Integration Team",
price=49.0,
currency="CNY",
pricing_model="paid",
preview_image_url="https://cdn.insightflow.io/plugins/feishu.png",
demo_url="https://demo.insightflow.io/feishu",
documentation_url="https://docs.insightflow.io/plugins/feishu",
repository_url="https://github.com/insightflow/feishu-plugin",
download_url="https://cdn.insightflow.io/plugins/feishu.zip",
webhook_url="https://api.insightflow.io/webhooks/feishu",
permissions=["read:projects", "write:notifications"],
version="1.0.0",
min_platform_version="2.0.0",
file_size=1048576,
checksum="plg123",
name = "飞书机器人集成插件",
description = "将 InsightFlow 与飞书机器人集成,实现自动通知",
category = PluginCategory.INTEGRATION,
tags = ["feishu", "bot", "integration", "notification"],
author_id = "dev_003",
author_name = "Integration Team",
price = 49.0,
currency = "CNY",
pricing_model = "paid",
preview_image_url = "https://cdn.insightflow.io/plugins/feishu.png",
demo_url = "https://demo.insightflow.io/feishu",
documentation_url = "https://docs.insightflow.io/plugins/feishu",
repository_url = "https://github.com/insightflow/feishu-plugin",
download_url = "https://cdn.insightflow.io/plugins/feishu.zip",
webhook_url = "https://api.insightflow.io/webhooks/feishu",
permissions = ["read:projects", "write:notifications"],
version = "1.0.0",
min_platform_version = "2.0.0",
file_size = 1048576,
checksum = "plg123",
)
self.created_ids["plugin"].append(plugin.id)
self.log(f"Created plugin: {plugin.name} ({plugin.id})")
# Create free plugin
plugin_free = self.manager.create_plugin(
name="数据导出插件",
description="支持多种格式的数据导出",
category=PluginCategory.ANALYSIS,
tags=["export", "data", "csv", "json"],
author_id="dev_004",
author_name="Data Team",
price=0.0,
currency="CNY",
pricing_model="free",
name = "数据导出插件",
description = "支持多种格式的数据导出",
category = PluginCategory.ANALYSIS,
tags = ["export", "data", "csv", "json"],
author_id = "dev_004",
author_name = "Data Team",
price = 0.0,
currency = "CNY",
pricing_model = "free",
)
self.created_ids["plugin"].append(plugin_free.id)
self.log(f"Created free plugin: {plugin_free.name}")
except Exception as e:
self.log(f"Failed to create plugin: {str(e)}", success=False)
self.log(f"Failed to create plugin: {str(e)}", success = False)
def test_plugin_list(self):
def test_plugin_list(self) -> None:
"""测试列出插件"""
try:
plugins = self.manager.list_plugins()
self.log(f"Listed {len(plugins)} plugins")
# Filter by category
integration_plugins = self.manager.list_plugins(category=PluginCategory.INTEGRATION)
integration_plugins = self.manager.list_plugins(category = PluginCategory.INTEGRATION)
self.log(f"Found {len(integration_plugins)} integration plugins")
except Exception as e:
self.log(f"Failed to list plugins: {str(e)}", success=False)
self.log(f"Failed to list plugins: {str(e)}", success = False)
def test_plugin_get(self):
def test_plugin_get(self) -> None:
"""测试获取插件详情"""
try:
if self.created_ids["plugin"]:
@@ -407,24 +407,24 @@ class TestDeveloperEcosystem:
if plugin:
self.log(f"Retrieved plugin: {plugin.name}")
except Exception as e:
self.log(f"Failed to get plugin: {str(e)}", success=False)
self.log(f"Failed to get plugin: {str(e)}", success = False)
def test_plugin_review(self):
def test_plugin_review(self) -> None:
"""测试审核插件"""
try:
if self.created_ids["plugin"]:
plugin = self.manager.review_plugin(
self.created_ids["plugin"][0],
reviewed_by="admin_001",
status=PluginStatus.APPROVED,
notes="Code review passed",
reviewed_by = "admin_001",
status = PluginStatus.APPROVED,
notes = "Code review passed",
)
if plugin:
self.log(f"Reviewed plugin: {plugin.name} ({plugin.status.value})")
except Exception as e:
self.log(f"Failed to review plugin: {str(e)}", success=False)
self.log(f"Failed to review plugin: {str(e)}", success = False)
def test_plugin_publish(self):
def test_plugin_publish(self) -> None:
"""测试发布插件"""
try:
if self.created_ids["plugin"]:
@@ -432,56 +432,56 @@ class TestDeveloperEcosystem:
if plugin:
self.log(f"Published plugin: {plugin.name}")
except Exception as e:
self.log(f"Failed to publish plugin: {str(e)}", success=False)
self.log(f"Failed to publish plugin: {str(e)}", success = False)
def test_plugin_review_add(self):
def test_plugin_review_add(self) -> None:
"""测试添加插件评价"""
try:
if self.created_ids["plugin"]:
review = self.manager.add_plugin_review(
plugin_id=self.created_ids["plugin"][0],
user_id="user_002",
user_name="Plugin User",
rating=4,
comment="Works great with Feishu!",
is_verified_purchase=True,
plugin_id = self.created_ids["plugin"][0],
user_id = "user_002",
user_name = "Plugin User",
rating = 4,
comment = "Works great with Feishu!",
is_verified_purchase = True,
)
self.log(f"Added plugin review: {review.rating} stars")
except Exception as e:
self.log(f"Failed to add plugin review: {str(e)}", success=False)
self.log(f"Failed to add plugin review: {str(e)}", success = False)
def test_developer_profile_create(self):
def test_developer_profile_create(self) -> None:
"""测试创建开发者档案"""
try:
# Generate unique user IDs
unique_id = uuid.uuid4().hex[:8]
profile = self.manager.create_developer_profile(
user_id=f"user_dev_{unique_id}_001",
display_name="张三",
email=f"zhangsan_{unique_id}@example.com",
bio="专注于医疗AI和自然语言处理",
website="https://zhangsan.dev",
github_url="https://github.com/zhangsan",
avatar_url="https://cdn.example.com/avatars/zhangsan.png",
user_id = f"user_dev_{unique_id}_001",
display_name = "张三",
email = f"zhangsan_{unique_id}@example.com",
bio = "专注于医疗AI和自然语言处理",
website = "https://zhangsan.dev",
github_url = "https://github.com/zhangsan",
avatar_url = "https://cdn.example.com/avatars/zhangsan.png",
)
self.created_ids["developer"].append(profile.id)
self.log(f"Created developer profile: {profile.display_name} ({profile.id})")
# Create another developer
profile2 = self.manager.create_developer_profile(
user_id=f"user_dev_{unique_id}_002",
display_name="李四",
email=f"lisi_{unique_id}@example.com",
bio="全栈开发者,热爱开源",
user_id = f"user_dev_{unique_id}_002",
display_name = "李四",
email = f"lisi_{unique_id}@example.com",
bio = "全栈开发者,热爱开源",
)
self.created_ids["developer"].append(profile2.id)
self.log(f"Created developer profile: {profile2.display_name}")
except Exception as e:
self.log(f"Failed to create developer profile: {str(e)}", success=False)
self.log(f"Failed to create developer profile: {str(e)}", success = False)
def test_developer_profile_get(self):
def test_developer_profile_get(self) -> None:
"""测试获取开发者档案"""
try:
if self.created_ids["developer"]:
@@ -489,9 +489,9 @@ class TestDeveloperEcosystem:
if profile:
self.log(f"Retrieved developer profile: {profile.display_name}")
except Exception as e:
self.log(f"Failed to get developer profile: {str(e)}", success=False)
self.log(f"Failed to get developer profile: {str(e)}", success = False)
def test_developer_verify(self):
def test_developer_verify(self) -> None:
"""测试验证开发者"""
try:
if self.created_ids["developer"]:
@@ -501,9 +501,9 @@ class TestDeveloperEcosystem:
if profile:
self.log(f"Verified developer: {profile.display_name} ({profile.status.value})")
except Exception as e:
self.log(f"Failed to verify developer: {str(e)}", success=False)
self.log(f"Failed to verify developer: {str(e)}", success = False)
def test_developer_stats_update(self):
def test_developer_stats_update(self) -> None:
"""测试更新开发者统计"""
try:
if self.created_ids["developer"]:
@@ -513,38 +513,38 @@ class TestDeveloperEcosystem:
f"Updated developer stats: {profile.plugin_count} plugins, {profile.template_count} templates"
)
except Exception as e:
self.log(f"Failed to update developer stats: {str(e)}", success=False)
self.log(f"Failed to update developer stats: {str(e)}", success = False)
def test_code_example_create(self):
def test_code_example_create(self) -> None:
"""测试创建代码示例"""
try:
example = self.manager.create_code_example(
title="使用 Python SDK 创建项目",
description="演示如何使用 Python SDK 创建新项目",
language="python",
category="quickstart",
code="""from insightflow import Client
title = "使用 Python SDK 创建项目",
description = "演示如何使用 Python SDK 创建新项目",
language = "python",
category = "quickstart",
code = """from insightflow import Client
client = Client(api_key="your_api_key")
project = client.projects.create(name="My Project")
client = Client(api_key = "your_api_key")
project = client.projects.create(name = "My Project")
print(f"Created project: {project.id}")
""",
explanation="首先导入 Client 类,然后使用 API Key 初始化客户端,最后调用 create 方法创建项目。",
tags=["python", "quickstart", "projects"],
author_id="dev_001",
author_name="InsightFlow Team",
api_endpoints=["/api/v1/projects"],
explanation = "首先导入 Client 类,然后使用 API Key 初始化客户端,最后调用 create 方法创建项目。",
tags = ["python", "quickstart", "projects"],
author_id = "dev_001",
author_name = "InsightFlow Team",
api_endpoints = ["/api/v1/projects"],
)
self.created_ids["code_example"].append(example.id)
self.log(f"Created code example: {example.title}")
# Create JavaScript example
example_js = self.manager.create_code_example(
title="使用 JavaScript SDK 上传文件",
description="演示如何使用 JavaScript SDK 上传音频文件",
language="javascript",
category="upload",
code="""const { Client } = require('insightflow');
title = "使用 JavaScript SDK 上传文件",
description = "演示如何使用 JavaScript SDK 上传音频文件",
language = "javascript",
category = "upload",
code = """const { Client } = require('insightflow');
const client = new Client({ apiKey: 'your_api_key' });
const result = await client.uploads.create({
@@ -553,31 +553,31 @@ const result = await client.uploads.create({
});
console.log('Upload complete:', result.id);
""",
explanation="使用 JavaScript SDK 上传文件到 InsightFlow",
tags=["javascript", "upload", "audio"],
author_id="dev_002",
author_name="JS Team",
explanation = "使用 JavaScript SDK 上传文件到 InsightFlow",
tags = ["javascript", "upload", "audio"],
author_id = "dev_002",
author_name = "JS Team",
)
self.created_ids["code_example"].append(example_js.id)
self.log(f"Created code example: {example_js.title}")
except Exception as e:
self.log(f"Failed to create code example: {str(e)}", success=False)
self.log(f"Failed to create code example: {str(e)}", success = False)
def test_code_example_list(self):
def test_code_example_list(self) -> None:
"""测试列出代码示例"""
try:
examples = self.manager.list_code_examples()
self.log(f"Listed {len(examples)} code examples")
# Filter by language
python_examples = self.manager.list_code_examples(language="python")
python_examples = self.manager.list_code_examples(language = "python")
self.log(f"Found {len(python_examples)} Python examples")
except Exception as e:
self.log(f"Failed to list code examples: {str(e)}", success=False)
self.log(f"Failed to list code examples: {str(e)}", success = False)
def test_code_example_get(self):
def test_code_example_get(self) -> None:
"""测试获取代码示例详情"""
try:
if self.created_ids["code_example"]:
@@ -587,30 +587,30 @@ console.log('Upload complete:', result.id);
f"Retrieved code example: {example.title} (views: {example.view_count})"
)
except Exception as e:
self.log(f"Failed to get code example: {str(e)}", success=False)
self.log(f"Failed to get code example: {str(e)}", success = False)
def test_portal_config_create(self):
def test_portal_config_create(self) -> None:
"""测试创建开发者门户配置"""
try:
config = self.manager.create_portal_config(
name="InsightFlow Developer Portal",
description="开发者门户 - SDK、API 文档和示例代码",
theme="default",
primary_color="#1890ff",
secondary_color="#52c41a",
support_email="developers@insightflow.io",
support_url="https://support.insightflow.io",
github_url="https://github.com/insightflow",
discord_url="https://discord.gg/insightflow",
api_base_url="https://api.insightflow.io/v1",
name = "InsightFlow Developer Portal",
description = "开发者门户 - SDK、API 文档和示例代码",
theme = "default",
primary_color = "#1890ff",
secondary_color = "#52c41a",
support_email = "developers@insightflow.io",
support_url = "https://support.insightflow.io",
github_url = "https://github.com/insightflow",
discord_url = "https://discord.gg/insightflow",
api_base_url = "https://api.insightflow.io/v1",
)
self.created_ids["portal_config"].append(config.id)
self.log(f"Created portal config: {config.name}")
except Exception as e:
self.log(f"Failed to create portal config: {str(e)}", success=False)
self.log(f"Failed to create portal config: {str(e)}", success = False)
def test_portal_config_get(self):
def test_portal_config_get(self) -> None:
"""测试获取开发者门户配置"""
try:
if self.created_ids["portal_config"]:
@@ -624,29 +624,29 @@ console.log('Upload complete:', result.id);
self.log(f"Active portal config: {active_config.name}")
except Exception as e:
self.log(f"Failed to get portal config: {str(e)}", success=False)
self.log(f"Failed to get portal config: {str(e)}", success = False)
def test_revenue_record(self):
def test_revenue_record(self) -> None:
"""测试记录开发者收益"""
try:
if self.created_ids["developer"] and self.created_ids["plugin"]:
revenue = self.manager.record_revenue(
developer_id=self.created_ids["developer"][0],
item_type="plugin",
item_id=self.created_ids["plugin"][0],
item_name="飞书机器人集成插件",
sale_amount=49.0,
currency="CNY",
buyer_id="user_buyer_001",
transaction_id="txn_123456",
developer_id = self.created_ids["developer"][0],
item_type = "plugin",
item_id = self.created_ids["plugin"][0],
item_name = "飞书机器人集成插件",
sale_amount = 49.0,
currency = "CNY",
buyer_id = "user_buyer_001",
transaction_id = "txn_123456",
)
self.log(f"Recorded revenue: {revenue.sale_amount} {revenue.currency}")
self.log(f" - Platform fee: {revenue.platform_fee}")
self.log(f" - Developer earnings: {revenue.developer_earnings}")
except Exception as e:
self.log(f"Failed to record revenue: {str(e)}", success=False)
self.log(f"Failed to record revenue: {str(e)}", success = False)
def test_revenue_summary(self):
def test_revenue_summary(self) -> None:
"""测试获取开发者收益汇总"""
try:
if self.created_ids["developer"]:
@@ -659,13 +659,13 @@ console.log('Upload complete:', result.id);
self.log(f" - Total earnings: {summary['total_earnings']}")
self.log(f" - Transaction count: {summary['transaction_count']}")
except Exception as e:
self.log(f"Failed to get revenue summary: {str(e)}", success=False)
self.log(f"Failed to get revenue summary: {str(e)}", success = False)
def print_summary(self):
def print_summary(self) -> None:
"""打印测试摘要"""
print("\n" + "=" * 60)
print("\n" + " = " * 60)
print("Test Summary")
print("=" * 60)
print(" = " * 60)
total = len(self.test_results)
passed = sum(1 for r in self.test_results if r["success"])
@@ -686,10 +686,10 @@ console.log('Upload complete:', result.id);
if ids:
print(f" {resource_type}: {len(ids)}")
print("=" * 60)
print(" = " * 60)
def main():
def main() -> None:
"""主函数"""
test = TestDeveloperEcosystem()
test.run_all_tests()

View File

@@ -34,22 +34,22 @@ if backend_dir not in sys.path:
class TestOpsManager:
"""测试运维与监控管理器"""
def __init__(self):
def __init__(self) -> None:
self.manager = get_ops_manager()
self.tenant_id = "test_tenant_001"
self.test_results = []
def log(self, message: str, success: bool = True):
def log(self, message: str, success: bool = True) -> None:
"""记录测试结果"""
status = "" if success else ""
print(f"{status} {message}")
self.test_results.append((message, success))
def run_all_tests(self):
def run_all_tests(self) -> None:
"""运行所有测试"""
print("=" * 60)
print(" = " * 60)
print("InsightFlow Phase 8 Task 8: Operations & Monitoring Tests")
print("=" * 60)
print(" = " * 60)
# 1. 告警系统测试
self.test_alert_rules()
@@ -73,46 +73,46 @@ class TestOpsManager:
# 打印测试总结
self.print_summary()
def test_alert_rules(self):
def test_alert_rules(self) -> None:
"""测试告警规则管理"""
print("\n📋 Testing Alert Rules...")
try:
# 创建阈值告警规则
rule1 = self.manager.create_alert_rule(
tenant_id=self.tenant_id,
name="CPU 使用率告警",
description="当 CPU 使用率超过 80% 时触发告警",
rule_type=AlertRuleType.THRESHOLD,
severity=AlertSeverity.P1,
metric="cpu_usage_percent",
condition=">",
threshold=80.0,
duration=300,
evaluation_interval=60,
channels=[],
labels={"service": "api", "team": "platform"},
annotations={"summary": "CPU 使用率过高", "runbook": "https://wiki/runbooks/cpu"},
created_by="test_user",
tenant_id = self.tenant_id,
name = "CPU 使用率告警",
description = "当 CPU 使用率超过 80% 时触发告警",
rule_type = AlertRuleType.THRESHOLD,
severity = AlertSeverity.P1,
metric = "cpu_usage_percent",
condition = ">",
threshold = 80.0,
duration = 300,
evaluation_interval = 60,
channels = [],
labels = {"service": "api", "team": "platform"},
annotations = {"summary": "CPU 使用率过高", "runbook": "https://wiki/runbooks/cpu"},
created_by = "test_user",
)
self.log(f"Created alert rule: {rule1.name} (ID: {rule1.id})")
# 创建异常检测告警规则
rule2 = self.manager.create_alert_rule(
tenant_id=self.tenant_id,
name="内存异常检测",
description="检测内存使用异常",
rule_type=AlertRuleType.ANOMALY,
severity=AlertSeverity.P2,
metric="memory_usage_percent",
condition=">",
threshold=0.0,
duration=600,
evaluation_interval=300,
channels=[],
labels={"service": "database"},
annotations={},
created_by="test_user",
tenant_id = self.tenant_id,
name = "内存异常检测",
description = "检测内存使用异常",
rule_type = AlertRuleType.ANOMALY,
severity = AlertSeverity.P2,
metric = "memory_usage_percent",
condition = ">",
threshold = 0.0,
duration = 600,
evaluation_interval = 300,
channels = [],
labels = {"service": "database"},
annotations = {},
created_by = "test_user",
)
self.log(f"Created anomaly alert rule: {rule2.name} (ID: {rule2.id})")
@@ -129,7 +129,7 @@ class TestOpsManager:
# 更新告警规则
updated_rule = self.manager.update_alert_rule(
rule1.id, threshold=85.0, description="更新后的描述"
rule1.id, threshold = 85.0, description = "更新后的描述"
)
assert updated_rule.threshold == 85.0
self.log(f"Updated alert rule threshold to {updated_rule.threshold}")
@@ -140,46 +140,46 @@ class TestOpsManager:
self.log("Deleted test alert rules")
except Exception as e:
self.log(f"Alert rules test failed: {e}", success=False)
self.log(f"Alert rules test failed: {e}", success = False)
def test_alert_channels(self):
def test_alert_channels(self) -> None:
"""测试告警渠道管理"""
print("\n📢 Testing Alert Channels...")
try:
# 创建飞书告警渠道
channel1 = self.manager.create_alert_channel(
tenant_id=self.tenant_id,
name="飞书告警",
channel_type=AlertChannelType.FEISHU,
config={
tenant_id = self.tenant_id,
name = "飞书告警",
channel_type = AlertChannelType.FEISHU,
config = {
"webhook_url": "https://open.feishu.cn/open-apis/bot/v2/hook/test",
"secret": "test_secret",
},
severity_filter=["p0", "p1"],
severity_filter = ["p0", "p1"],
)
self.log(f"Created Feishu channel: {channel1.name} (ID: {channel1.id})")
# 创建钉钉告警渠道
channel2 = self.manager.create_alert_channel(
tenant_id=self.tenant_id,
name="钉钉告警",
channel_type=AlertChannelType.DINGTALK,
config={
"webhook_url": "https://oapi.dingtalk.com/robot/send?access_token=test",
tenant_id = self.tenant_id,
name = "钉钉告警",
channel_type = AlertChannelType.DINGTALK,
config = {
"webhook_url": "https://oapi.dingtalk.com/robot/send?access_token = test",
"secret": "test_secret",
},
severity_filter=["p0", "p1", "p2"],
severity_filter = ["p0", "p1", "p2"],
)
self.log(f"Created DingTalk channel: {channel2.name} (ID: {channel2.id})")
# 创建 Slack 告警渠道
channel3 = self.manager.create_alert_channel(
tenant_id=self.tenant_id,
name="Slack 告警",
channel_type=AlertChannelType.SLACK,
config={"webhook_url": "https://hooks.slack.com/services/test"},
severity_filter=["p0", "p1", "p2", "p3"],
tenant_id = self.tenant_id,
name = "Slack 告警",
channel_type = AlertChannelType.SLACK,
config = {"webhook_url": "https://hooks.slack.com/services/test"},
severity_filter = ["p0", "p1", "p2", "p3"],
)
self.log(f"Created Slack channel: {channel3.name} (ID: {channel3.id})")
@@ -198,46 +198,46 @@ class TestOpsManager:
for channel in channels:
if channel.tenant_id == self.tenant_id:
with self.manager._get_db() as conn:
conn.execute("DELETE FROM alert_channels WHERE id = ?", (channel.id,))
conn.execute("DELETE FROM alert_channels WHERE id = ?", (channel.id, ))
conn.commit()
self.log("Deleted test alert channels")
except Exception as e:
self.log(f"Alert channels test failed: {e}", success=False)
self.log(f"Alert channels test failed: {e}", success = False)
def test_alerts(self):
def test_alerts(self) -> None:
"""测试告警管理"""
print("\n🚨 Testing Alerts...")
try:
# 创建告警规则
rule = self.manager.create_alert_rule(
tenant_id=self.tenant_id,
name="测试告警规则",
description="用于测试的告警规则",
rule_type=AlertRuleType.THRESHOLD,
severity=AlertSeverity.P1,
metric="test_metric",
condition=">",
threshold=100.0,
duration=60,
evaluation_interval=60,
channels=[],
labels={},
annotations={},
created_by="test_user",
tenant_id = self.tenant_id,
name = "测试告警规则",
description = "用于测试的告警规则",
rule_type = AlertRuleType.THRESHOLD,
severity = AlertSeverity.P1,
metric = "test_metric",
condition = ">",
threshold = 100.0,
duration = 60,
evaluation_interval = 60,
channels = [],
labels = {},
annotations = {},
created_by = "test_user",
)
# 记录资源指标
for i in range(10):
self.manager.record_resource_metric(
tenant_id=self.tenant_id,
resource_type=ResourceType.CPU,
resource_id="server-001",
metric_name="test_metric",
metric_value=110.0 + i,
unit="percent",
metadata={"region": "cn-north-1"},
tenant_id = self.tenant_id,
resource_type = ResourceType.CPU,
resource_id = "server-001",
metric_name = "test_metric",
metric_value = 110.0 + i,
unit = "percent",
metadata = {"region": "cn-north-1"},
)
self.log("Recorded 10 resource metrics")
@@ -248,24 +248,24 @@ class TestOpsManager:
now = datetime.now().isoformat()
alert = Alert(
id=alert_id,
rule_id=rule.id,
tenant_id=self.tenant_id,
severity=AlertSeverity.P1,
status=AlertStatus.FIRING,
title="测试告警",
description="这是一条测试告警",
metric="test_metric",
value=120.0,
threshold=100.0,
labels={"test": "true"},
annotations={},
started_at=now,
resolved_at=None,
acknowledged_by=None,
acknowledged_at=None,
notification_sent={},
suppression_count=0,
id = alert_id,
rule_id = rule.id,
tenant_id = self.tenant_id,
severity = AlertSeverity.P1,
status = AlertStatus.FIRING,
title = "测试告警",
description = "这是一条测试告警",
metric = "test_metric",
value = 120.0,
threshold = 100.0,
labels = {"test": "true"},
annotations = {},
started_at = now,
resolved_at = None,
acknowledged_by = None,
acknowledged_at = None,
notification_sent = {},
suppression_count = 0,
)
with self.manager._get_db() as conn:
@@ -320,23 +320,23 @@ class TestOpsManager:
# 清理
self.manager.delete_alert_rule(rule.id)
with self.manager._get_db() as conn:
conn.execute("DELETE FROM alerts WHERE id = ?", (alert_id,))
conn.execute("DELETE FROM resource_metrics WHERE tenant_id = ?", (self.tenant_id,))
conn.execute("DELETE FROM alerts WHERE id = ?", (alert_id, ))
conn.execute("DELETE FROM resource_metrics WHERE tenant_id = ?", (self.tenant_id, ))
conn.commit()
self.log("Cleaned up test data")
except Exception as e:
self.log(f"Alerts test failed: {e}", success=False)
self.log(f"Alerts test failed: {e}", success = False)
def test_capacity_planning(self):
def test_capacity_planning(self) -> None:
"""测试容量规划"""
print("\n📊 Testing Capacity Planning...")
try:
# 记录历史指标数据
base_time = datetime.now() - timedelta(days=30)
base_time = datetime.now() - timedelta(days = 30)
for i in range(30):
timestamp = (base_time + timedelta(days=i)).isoformat()
timestamp = (base_time + timedelta(days = i)).isoformat()
with self.manager._get_db() as conn:
conn.execute(
"""
@@ -360,13 +360,13 @@ class TestOpsManager:
self.log("Recorded 30 days of historical metrics")
# 创建容量规划
prediction_date = (datetime.now() + timedelta(days=30)).strftime("%Y-%m-%d")
prediction_date = (datetime.now() + timedelta(days = 30)).strftime("%Y-%m-%d")
plan = self.manager.create_capacity_plan(
tenant_id=self.tenant_id,
resource_type=ResourceType.CPU,
current_capacity=100.0,
prediction_date=prediction_date,
confidence=0.85,
tenant_id = self.tenant_id,
resource_type = ResourceType.CPU,
current_capacity = 100.0,
prediction_date = prediction_date,
confidence = 0.85,
)
self.log(f"Created capacity plan: {plan.id}")
@@ -381,32 +381,32 @@ class TestOpsManager:
# 清理
with self.manager._get_db() as conn:
conn.execute("DELETE FROM capacity_plans WHERE tenant_id = ?", (self.tenant_id,))
conn.execute("DELETE FROM resource_metrics WHERE tenant_id = ?", (self.tenant_id,))
conn.execute("DELETE FROM capacity_plans WHERE tenant_id = ?", (self.tenant_id, ))
conn.execute("DELETE FROM resource_metrics WHERE tenant_id = ?", (self.tenant_id, ))
conn.commit()
self.log("Cleaned up capacity planning test data")
except Exception as e:
self.log(f"Capacity planning test failed: {e}", success=False)
self.log(f"Capacity planning test failed: {e}", success = False)
def test_auto_scaling(self):
def test_auto_scaling(self) -> None:
"""测试自动扩缩容"""
print("\n⚖️ Testing Auto Scaling...")
try:
# 创建自动扩缩容策略
policy = self.manager.create_auto_scaling_policy(
tenant_id=self.tenant_id,
name="API 服务自动扩缩容",
resource_type=ResourceType.CPU,
min_instances=2,
max_instances=10,
target_utilization=0.7,
scale_up_threshold=0.8,
scale_down_threshold=0.3,
scale_up_step=2,
scale_down_step=1,
cooldown_period=300,
tenant_id = self.tenant_id,
name = "API 服务自动扩缩容",
resource_type = ResourceType.CPU,
min_instances = 2,
max_instances = 10,
target_utilization = 0.7,
scale_up_threshold = 0.8,
scale_down_threshold = 0.3,
scale_up_step = 2,
scale_down_step = 1,
cooldown_period = 300,
)
self.log(f"Created auto scaling policy: {policy.name} (ID: {policy.id})")
@@ -421,7 +421,7 @@ class TestOpsManager:
# 模拟扩缩容评估
event = self.manager.evaluate_scaling_policy(
policy_id=policy.id, current_instances=3, current_utilization=0.85
policy_id = policy.id, current_instances = 3, current_utilization = 0.85
)
if event:
@@ -437,46 +437,46 @@ class TestOpsManager:
# 清理
with self.manager._get_db() as conn:
conn.execute("DELETE FROM scaling_events WHERE tenant_id = ?", (self.tenant_id,))
conn.execute("DELETE FROM scaling_events WHERE tenant_id = ?", (self.tenant_id, ))
conn.execute(
"DELETE FROM auto_scaling_policies WHERE tenant_id = ?", (self.tenant_id,)
"DELETE FROM auto_scaling_policies WHERE tenant_id = ?", (self.tenant_id, )
)
conn.commit()
self.log("Cleaned up auto scaling test data")
except Exception as e:
self.log(f"Auto scaling test failed: {e}", success=False)
self.log(f"Auto scaling test failed: {e}", success = False)
def test_health_checks(self):
def test_health_checks(self) -> None:
"""测试健康检查"""
print("\n💓 Testing Health Checks...")
try:
# 创建 HTTP 健康检查
check1 = self.manager.create_health_check(
tenant_id=self.tenant_id,
name="API 服务健康检查",
target_type="service",
target_id="api-service",
check_type="http",
check_config={"url": "https://api.insightflow.io/health", "expected_status": 200},
interval=60,
timeout=10,
retry_count=3,
tenant_id = self.tenant_id,
name = "API 服务健康检查",
target_type = "service",
target_id = "api-service",
check_type = "http",
check_config = {"url": "https://api.insightflow.io/health", "expected_status": 200},
interval = 60,
timeout = 10,
retry_count = 3,
)
self.log(f"Created HTTP health check: {check1.name} (ID: {check1.id})")
# 创建 TCP 健康检查
check2 = self.manager.create_health_check(
tenant_id=self.tenant_id,
name="数据库健康检查",
target_type="database",
target_id="postgres-001",
check_type="tcp",
check_config={"host": "db.insightflow.io", "port": 5432},
interval=30,
timeout=5,
retry_count=2,
tenant_id = self.tenant_id,
name = "数据库健康检查",
target_type = "database",
target_id = "postgres-001",
check_type = "tcp",
check_config = {"host": "db.insightflow.io", "port": 5432},
interval = 30,
timeout = 5,
retry_count = 2,
)
self.log(f"Created TCP health check: {check2.name} (ID: {check2.id})")
@@ -486,7 +486,7 @@ class TestOpsManager:
self.log(f"Listed {len(checks)} health checks")
# 执行健康检查(异步)
async def run_health_check():
async def run_health_check() -> None:
result = await self.manager.execute_health_check(check1.id)
return result
@@ -495,28 +495,28 @@ class TestOpsManager:
# 清理
with self.manager._get_db() as conn:
conn.execute("DELETE FROM health_checks WHERE tenant_id = ?", (self.tenant_id,))
conn.execute("DELETE FROM health_checks WHERE tenant_id = ?", (self.tenant_id, ))
conn.commit()
self.log("Cleaned up health check test data")
except Exception as e:
self.log(f"Health checks test failed: {e}", success=False)
self.log(f"Health checks test failed: {e}", success = False)
def test_failover(self):
def test_failover(self) -> None:
"""测试故障转移"""
print("\n🔄 Testing Failover...")
try:
# 创建故障转移配置
config = self.manager.create_failover_config(
tenant_id=self.tenant_id,
name="主备数据中心故障转移",
primary_region="cn-north-1",
secondary_regions=["cn-south-1", "cn-east-1"],
failover_trigger="health_check_failed",
auto_failover=False,
failover_timeout=300,
health_check_id=None,
tenant_id = self.tenant_id,
name = "主备数据中心故障转移",
primary_region = "cn-north-1",
secondary_regions = ["cn-south-1", "cn-east-1"],
failover_trigger = "health_check_failed",
auto_failover = False,
failover_timeout = 300,
health_check_id = None,
)
self.log(f"Created failover config: {config.name} (ID: {config.id})")
@@ -530,7 +530,7 @@ class TestOpsManager:
# 发起故障转移
event = self.manager.initiate_failover(
config_id=config.id, reason="Primary region health check failed"
config_id = config.id, reason = "Primary region health check failed"
)
if event:
@@ -550,31 +550,31 @@ class TestOpsManager:
# 清理
with self.manager._get_db() as conn:
conn.execute("DELETE FROM failover_events WHERE tenant_id = ?", (self.tenant_id,))
conn.execute("DELETE FROM failover_configs WHERE tenant_id = ?", (self.tenant_id,))
conn.execute("DELETE FROM failover_events WHERE tenant_id = ?", (self.tenant_id, ))
conn.execute("DELETE FROM failover_configs WHERE tenant_id = ?", (self.tenant_id, ))
conn.commit()
self.log("Cleaned up failover test data")
except Exception as e:
self.log(f"Failover test failed: {e}", success=False)
self.log(f"Failover test failed: {e}", success = False)
def test_backup(self):
def test_backup(self) -> None:
"""测试备份与恢复"""
print("\n💾 Testing Backup & Recovery...")
try:
# 创建备份任务
job = self.manager.create_backup_job(
tenant_id=self.tenant_id,
name="每日数据库备份",
backup_type="full",
target_type="database",
target_id="postgres-main",
schedule="0 2 * * *", # 每天凌晨2点
retention_days=30,
encryption_enabled=True,
compression_enabled=True,
storage_location="s3://insightflow-backups/",
tenant_id = self.tenant_id,
name = "每日数据库备份",
backup_type = "full",
target_type = "database",
target_id = "postgres-main",
schedule = "0 2 * * *", # 每天凌晨2点
retention_days = 30,
encryption_enabled = True,
compression_enabled = True,
storage_location = "s3://insightflow-backups/",
)
self.log(f"Created backup job: {job.name} (ID: {job.id})")
@@ -604,15 +604,15 @@ class TestOpsManager:
# 清理
with self.manager._get_db() as conn:
conn.execute("DELETE FROM backup_records WHERE tenant_id = ?", (self.tenant_id,))
conn.execute("DELETE FROM backup_jobs WHERE tenant_id = ?", (self.tenant_id,))
conn.execute("DELETE FROM backup_records WHERE tenant_id = ?", (self.tenant_id, ))
conn.execute("DELETE FROM backup_jobs WHERE tenant_id = ?", (self.tenant_id, ))
conn.commit()
self.log("Cleaned up backup test data")
except Exception as e:
self.log(f"Backup test failed: {e}", success=False)
self.log(f"Backup test failed: {e}", success = False)
def test_cost_optimization(self):
def test_cost_optimization(self) -> None:
"""测试成本优化"""
print("\n💰 Testing Cost Optimization...")
@@ -622,15 +622,15 @@ class TestOpsManager:
for i in range(5):
self.manager.record_resource_utilization(
tenant_id=self.tenant_id,
resource_type=ResourceType.CPU,
resource_id=f"server-{i:03d}",
utilization_rate=0.05 + random.random() * 0.1, # 低利用率
peak_utilization=0.15,
avg_utilization=0.08,
idle_time_percent=0.85,
report_date=report_date,
recommendations=["Consider downsizing this resource"],
tenant_id = self.tenant_id,
resource_type = ResourceType.CPU,
resource_id = f"server-{i:03d}",
utilization_rate = 0.05 + random.random() * 0.1, # 低利用率
peak_utilization = 0.15,
avg_utilization = 0.08,
idle_time_percent = 0.85,
report_date = report_date,
recommendations = ["Consider downsizing this resource"],
)
self.log("Recorded 5 resource utilization records")
@@ -638,7 +638,7 @@ class TestOpsManager:
# 生成成本报告
now = datetime.now()
report = self.manager.generate_cost_report(
tenant_id=self.tenant_id, year=now.year, month=now.month
tenant_id = self.tenant_id, year = now.year, month = now.month
)
self.log(f"Generated cost report: {report.id}")
@@ -687,24 +687,24 @@ class TestOpsManager:
with self.manager._get_db() as conn:
conn.execute(
"DELETE FROM cost_optimization_suggestions WHERE tenant_id = ?",
(self.tenant_id,),
(self.tenant_id, ),
)
conn.execute("DELETE FROM idle_resources WHERE tenant_id = ?", (self.tenant_id,))
conn.execute("DELETE FROM idle_resources WHERE tenant_id = ?", (self.tenant_id, ))
conn.execute(
"DELETE FROM resource_utilizations WHERE tenant_id = ?", (self.tenant_id,)
"DELETE FROM resource_utilizations WHERE tenant_id = ?", (self.tenant_id, )
)
conn.execute("DELETE FROM cost_reports WHERE tenant_id = ?", (self.tenant_id,))
conn.execute("DELETE FROM cost_reports WHERE tenant_id = ?", (self.tenant_id, ))
conn.commit()
self.log("Cleaned up cost optimization test data")
except Exception as e:
self.log(f"Cost optimization test failed: {e}", success=False)
self.log(f"Cost optimization test failed: {e}", success = False)
def print_summary(self):
def print_summary(self) -> None:
"""打印测试总结"""
print("\n" + "=" * 60)
print("\n" + " = " * 60)
print("Test Summary")
print("=" * 60)
print(" = " * 60)
total = len(self.test_results)
passed = sum(1 for _, success in self.test_results if success)
@@ -720,10 +720,10 @@ class TestOpsManager:
if not success:
print(f"{message}")
print("=" * 60)
print(" = " * 60)
def main():
def main() -> None:
"""主函数"""
test = TestOpsManager()
test.run_all_tests()

View File

@@ -10,7 +10,7 @@ from typing import Any
class TingwuClient:
def __init__(self):
def __init__(self) -> None:
self.access_key = os.getenv("ALI_ACCESS_KEY", "")
self.secret_key = os.getenv("ALI_SECRET_KEY", "")
self.endpoint = "https://tingwu.cn-beijing.aliyuncs.com"
@@ -31,7 +31,7 @@ class TingwuClient:
"x-acs-action": "CreateTask",
"x-acs-version": "2023-09-30",
"x-acs-date": timestamp,
"Authorization": f"ACS3-HMAC-SHA256 Credential={self.access_key}/acs/tingwu/cn-beijing",
"Authorization": f"ACS3-HMAC-SHA256 Credential = {self.access_key}/acs/tingwu/cn-beijing",
}
def create_task(self, audio_url: str, language: str = "zh") -> str:
@@ -43,17 +43,17 @@ class TingwuClient:
from alibabacloud_tingwu20230930.client import Client as TingwuSDKClient
config = open_api_models.Config(
access_key_id=self.access_key, access_key_secret=self.secret_key
access_key_id = self.access_key, access_key_secret = self.secret_key
)
config.endpoint = "tingwu.cn-beijing.aliyuncs.com"
client = TingwuSDKClient(config)
request = tingwu_models.CreateTaskRequest(
type="offline",
input=tingwu_models.Input(source="OSS", file_url=audio_url),
parameters=tingwu_models.Parameters(
transcription=tingwu_models.Transcription(
diarization_enabled=True, sentence_max_length=20
type = "offline",
input = tingwu_models.Input(source = "OSS", file_url = audio_url),
parameters = tingwu_models.Parameters(
transcription = tingwu_models.Transcription(
diarization_enabled = True, sentence_max_length = 20
)
),
)
@@ -78,12 +78,9 @@ class TingwuClient:
"""获取任务结果"""
try:
# 导入移到文件顶部会导致循环导入,保持在这里
from alibabacloud_tea_openapi import models as open_api_models
from alibabacloud_tingwu20230930 import models as tingwu_models
from alibabacloud_tingwu20230930.client import Client as TingwuSDKClient
config = open_api_models.Config(
access_key_id=self.access_key, access_key_secret=self.secret_key
access_key_id = self.access_key, access_key_secret = self.secret_key
)
config.endpoint = "tingwu.cn-beijing.aliyuncs.com"
client = TingwuSDKClient(config)

View File

@@ -37,7 +37,7 @@ DEFAULT_RETRY_COUNT = 3 # 默认重试次数
DEFAULT_RETRY_DELAY = 5 # 默认重试延迟(秒)
# Configure logging
logging.basicConfig(level=logging.INFO)
logging.basicConfig(level = logging.INFO)
logger = logging.getLogger(__name__)
@@ -87,16 +87,16 @@ class WorkflowTask:
workflow_id: str
name: str
task_type: str # analyze, align, discover_relations, notify, custom
config: dict = field(default_factory=dict)
config: dict = field(default_factory = dict)
order: int = 0
depends_on: list[str] = field(default_factory=list)
depends_on: list[str] = field(default_factory = list)
timeout_seconds: int = 300
retry_count: int = 3
retry_delay: int = 5
created_at: str = ""
updated_at: str = ""
def __post_init__(self):
def __post_init__(self) -> None:
if not self.created_at:
self.created_at = datetime.now().isoformat()
if not self.updated_at:
@@ -112,7 +112,7 @@ class WebhookConfig:
webhook_type: str # feishu, dingtalk, slack, custom
url: str
secret: str = "" # 用于签名验证
headers: dict = field(default_factory=dict)
headers: dict = field(default_factory = dict)
template: str = "" # 消息模板
is_active: bool = True
created_at: str = ""
@@ -121,7 +121,7 @@ class WebhookConfig:
success_count: int = 0
fail_count: int = 0
def __post_init__(self):
def __post_init__(self) -> None:
if not self.created_at:
self.created_at = datetime.now().isoformat()
if not self.updated_at:
@@ -140,8 +140,8 @@ class Workflow:
status: str = "active"
schedule: str | None = None # cron expression or interval
schedule_type: str = "manual" # manual, cron, interval
config: dict = field(default_factory=dict)
webhook_ids: list[str] = field(default_factory=list)
config: dict = field(default_factory = dict)
webhook_ids: list[str] = field(default_factory = list)
is_active: bool = True
created_at: str = ""
updated_at: str = ""
@@ -151,7 +151,7 @@ class Workflow:
success_count: int = 0
fail_count: int = 0
def __post_init__(self):
def __post_init__(self) -> None:
if not self.created_at:
self.created_at = datetime.now().isoformat()
if not self.updated_at:
@@ -169,12 +169,12 @@ class WorkflowLog:
start_time: str | None = None
end_time: str | None = None
duration_ms: int = 0
input_data: dict = field(default_factory=dict)
output_data: dict = field(default_factory=dict)
input_data: dict = field(default_factory = dict)
output_data: dict = field(default_factory = dict)
error_message: str = ""
created_at: str = ""
def __post_init__(self):
def __post_init__(self) -> None:
if not self.created_at:
self.created_at = datetime.now().isoformat()
@@ -182,8 +182,8 @@ class WorkflowLog:
class WebhookNotifier:
"""Webhook 通知器 - 支持飞书、钉钉、Slack"""
def __init__(self):
self.http_client = httpx.AsyncClient(timeout=30.0)
def __init__(self) -> None:
self.http_client = httpx.AsyncClient(timeout = 30.0)
async def send(self, config: WebhookConfig, message: dict) -> bool:
"""发送 Webhook 通知"""
@@ -210,7 +210,7 @@ class WebhookNotifier:
# 签名计算
if config.secret:
string_to_sign = f"{timestamp}\n{config.secret}"
hmac_code = hmac.new(string_to_sign.encode("utf-8"), digestmod=hashlib.sha256).digest()
hmac_code = hmac.new(string_to_sign.encode("utf-8"), digestmod = hashlib.sha256).digest()
sign = base64.b64encode(hmac_code).decode("utf-8")
else:
sign = ""
@@ -250,7 +250,7 @@ class WebhookNotifier:
headers = {"Content-Type": "application/json", **config.headers}
response = await self.http_client.post(config.url, json=payload, headers=headers)
response = await self.http_client.post(config.url, json = payload, headers = headers)
response.raise_for_status()
result = response.json()
@@ -265,10 +265,10 @@ class WebhookNotifier:
secret_enc = config.secret.encode("utf-8")
string_to_sign = f"{timestamp}\n{config.secret}"
hmac_code = hmac.new(
secret_enc, string_to_sign.encode("utf-8"), digestmod=hashlib.sha256
secret_enc, string_to_sign.encode("utf-8"), digestmod = hashlib.sha256
).digest()
sign = urllib.parse.quote_plus(base64.b64encode(hmac_code))
url = f"{config.url}&timestamp={timestamp}&sign={sign}"
url = f"{config.url}&timestamp = {timestamp}&sign = {sign}"
else:
url = config.url
@@ -295,7 +295,7 @@ class WebhookNotifier:
headers = {"Content-Type": "application/json", **config.headers}
response = await self.http_client.post(url, json=payload, headers=headers)
response = await self.http_client.post(url, json = payload, headers = headers)
response.raise_for_status()
result = response.json()
@@ -316,7 +316,7 @@ class WebhookNotifier:
headers = {"Content-Type": "application/json", **config.headers}
response = await self.http_client.post(config.url, json=payload, headers=headers)
response = await self.http_client.post(config.url, json = payload, headers = headers)
response.raise_for_status()
return response.text == "ok"
@@ -325,12 +325,12 @@ class WebhookNotifier:
"""发送自定义 Webhook 通知"""
headers = {"Content-Type": "application/json", **config.headers}
response = await self.http_client.post(config.url, json=message, headers=headers)
response = await self.http_client.post(config.url, json = message, headers = headers)
response.raise_for_status()
return True
async def close(self):
async def close(self) -> None:
"""关闭 HTTP 客户端"""
await self.http_client.aclose()
@@ -343,7 +343,7 @@ class WorkflowManager:
DEFAULT_RETRY_COUNT: int = 3
DEFAULT_RETRY_DELAY: int = 5
def __init__(self, db_manager=None):
def __init__(self, db_manager = None) -> None:
self.db = db_manager
self.scheduler = AsyncIOScheduler()
self.notifier = WebhookNotifier()
@@ -381,13 +381,13 @@ class WorkflowManager:
def stop(self) -> None:
"""停止工作流管理器"""
if self.scheduler.running:
self.scheduler.shutdown(wait=True)
self.scheduler.shutdown(wait = True)
logger.info("Workflow scheduler stopped")
async def _load_and_schedule_workflows(self):
async def _load_and_schedule_workflows(self) -> None:
"""从数据库加载并调度所有活跃工作流"""
try:
workflows = self.list_workflows(status="active")
workflows = self.list_workflows(status = "active")
for workflow in workflows:
if workflow.schedule and workflow.is_active:
self._schedule_workflow(workflow)
@@ -408,25 +408,25 @@ class WorkflowManager:
elif workflow.schedule_type == "interval":
# 间隔调度
interval_minutes = int(workflow.schedule)
trigger = IntervalTrigger(minutes=interval_minutes)
trigger = IntervalTrigger(minutes = interval_minutes)
else:
return
self.scheduler.add_job(
func=self._execute_workflow_job,
trigger=trigger,
id=job_id,
args=[workflow.id],
replace_existing=True,
max_instances=1,
coalesce=True,
func = self._execute_workflow_job,
trigger = trigger,
id = job_id,
args = [workflow.id],
replace_existing = True,
max_instances = 1,
coalesce = True,
)
logger.info(
f"Scheduled workflow {workflow.id} ({workflow.name}) with {workflow.schedule_type}"
)
async def _execute_workflow_job(self, workflow_id: str):
async def _execute_workflow_job(self, workflow_id: str) -> None:
"""调度器调用的工作流执行函数"""
try:
await self.execute_workflow(workflow_id)
@@ -488,7 +488,7 @@ class WorkflowManager:
"""获取工作流"""
conn = self.db.get_conn()
try:
row = conn.execute("SELECT * FROM workflows WHERE id = ?", (workflow_id,)).fetchone()
row = conn.execute("SELECT * FROM workflows WHERE id = ?", (workflow_id, )).fetchone()
if not row:
return None
@@ -516,7 +516,7 @@ class WorkflowManager:
conditions.append("workflow_type = ?")
params.append(workflow_type)
where_clause = " AND ".join(conditions) if conditions else "1=1"
where_clause = " AND ".join(conditions) if conditions else "1 = 1"
rows = conn.execute(
f"SELECT * FROM workflows WHERE {where_clause} ORDER BY created_at DESC", params
@@ -585,10 +585,10 @@ class WorkflowManager:
self.scheduler.remove_job(job_id)
# 删除相关任务
conn.execute("DELETE FROM workflow_tasks WHERE workflow_id = ?", (workflow_id,))
conn.execute("DELETE FROM workflow_tasks WHERE workflow_id = ?", (workflow_id, ))
# 删除工作流
conn.execute("DELETE FROM workflows WHERE id = ?", (workflow_id,))
conn.execute("DELETE FROM workflows WHERE id = ?", (workflow_id, ))
conn.commit()
return True
@@ -598,24 +598,24 @@ class WorkflowManager:
def _row_to_workflow(self, row) -> Workflow:
"""将数据库行转换为 Workflow 对象"""
return Workflow(
id=row["id"],
name=row["name"],
description=row["description"] or "",
workflow_type=row["workflow_type"],
project_id=row["project_id"],
status=row["status"],
schedule=row["schedule"],
schedule_type=row["schedule_type"],
config=json.loads(row["config"]) if row["config"] else {},
webhook_ids=json.loads(row["webhook_ids"]) if row["webhook_ids"] else [],
is_active=bool(row["is_active"]),
created_at=row["created_at"],
updated_at=row["updated_at"],
last_run_at=row["last_run_at"],
next_run_at=row["next_run_at"],
run_count=row["run_count"] or 0,
success_count=row["success_count"] or 0,
fail_count=row["fail_count"] or 0,
id = row["id"],
name = row["name"],
description = row["description"] or "",
workflow_type = row["workflow_type"],
project_id = row["project_id"],
status = row["status"],
schedule = row["schedule"],
schedule_type = row["schedule_type"],
config = json.loads(row["config"]) if row["config"] else {},
webhook_ids = json.loads(row["webhook_ids"]) if row["webhook_ids"] else [],
is_active = bool(row["is_active"]),
created_at = row["created_at"],
updated_at = row["updated_at"],
last_run_at = row["last_run_at"],
next_run_at = row["next_run_at"],
run_count = row["run_count"] or 0,
success_count = row["success_count"] or 0,
fail_count = row["fail_count"] or 0,
)
# ==================== Workflow Task CRUD ====================
@@ -654,7 +654,7 @@ class WorkflowManager:
"""获取任务"""
conn = self.db.get_conn()
try:
row = conn.execute("SELECT * FROM workflow_tasks WHERE id = ?", (task_id,)).fetchone()
row = conn.execute("SELECT * FROM workflow_tasks WHERE id = ?", (task_id, )).fetchone()
if not row:
return None
@@ -669,7 +669,7 @@ class WorkflowManager:
try:
rows = conn.execute(
"SELECT * FROM workflow_tasks WHERE workflow_id = ? ORDER BY task_order",
(workflow_id,),
(workflow_id, ),
).fetchall()
return [self._row_to_task(row) for row in rows]
@@ -720,7 +720,7 @@ class WorkflowManager:
"""删除任务"""
conn = self.db.get_conn()
try:
conn.execute("DELETE FROM workflow_tasks WHERE id = ?", (task_id,))
conn.execute("DELETE FROM workflow_tasks WHERE id = ?", (task_id, ))
conn.commit()
return True
finally:
@@ -729,18 +729,18 @@ class WorkflowManager:
def _row_to_task(self, row) -> WorkflowTask:
"""将数据库行转换为 WorkflowTask 对象"""
return WorkflowTask(
id=row["id"],
workflow_id=row["workflow_id"],
name=row["name"],
task_type=row["task_type"],
config=json.loads(row["config"]) if row["config"] else {},
order=row["task_order"] or 0,
depends_on=json.loads(row["depends_on"]) if row["depends_on"] else [],
timeout_seconds=row["timeout_seconds"] or 300,
retry_count=row["retry_count"] or 3,
retry_delay=row["retry_delay"] or 5,
created_at=row["created_at"],
updated_at=row["updated_at"],
id = row["id"],
workflow_id = row["workflow_id"],
name = row["name"],
task_type = row["task_type"],
config = json.loads(row["config"]) if row["config"] else {},
order = row["task_order"] or 0,
depends_on = json.loads(row["depends_on"]) if row["depends_on"] else [],
timeout_seconds = row["timeout_seconds"] or 300,
retry_count = row["retry_count"] or 3,
retry_delay = row["retry_delay"] or 5,
created_at = row["created_at"],
updated_at = row["updated_at"],
)
# ==================== Webhook Config CRUD ====================
@@ -781,7 +781,7 @@ class WorkflowManager:
conn = self.db.get_conn()
try:
row = conn.execute(
"SELECT * FROM webhook_configs WHERE id = ?", (webhook_id,)
"SELECT * FROM webhook_configs WHERE id = ?", (webhook_id, )
).fetchone()
if not row:
@@ -844,7 +844,7 @@ class WorkflowManager:
"""删除 Webhook 配置"""
conn = self.db.get_conn()
try:
conn.execute("DELETE FROM webhook_configs WHERE id = ?", (webhook_id,))
conn.execute("DELETE FROM webhook_configs WHERE id = ?", (webhook_id, ))
conn.commit()
return True
finally:
@@ -875,19 +875,19 @@ class WorkflowManager:
def _row_to_webhook(self, row) -> WebhookConfig:
"""将数据库行转换为 WebhookConfig 对象"""
return WebhookConfig(
id=row["id"],
name=row["name"],
webhook_type=row["webhook_type"],
url=row["url"],
secret=row["secret"] or "",
headers=json.loads(row["headers"]) if row["headers"] else {},
template=row["template"] or "",
is_active=bool(row["is_active"]),
created_at=row["created_at"],
updated_at=row["updated_at"],
last_used_at=row["last_used_at"],
success_count=row["success_count"] or 0,
fail_count=row["fail_count"] or 0,
id = row["id"],
name = row["name"],
webhook_type = row["webhook_type"],
url = row["url"],
secret = row["secret"] or "",
headers = json.loads(row["headers"]) if row["headers"] else {},
template = row["template"] or "",
is_active = bool(row["is_active"]),
created_at = row["created_at"],
updated_at = row["updated_at"],
last_used_at = row["last_used_at"],
success_count = row["success_count"] or 0,
fail_count = row["fail_count"] or 0,
)
# ==================== Workflow Log ====================
@@ -952,7 +952,7 @@ class WorkflowManager:
"""获取日志"""
conn = self.db.get_conn()
try:
row = conn.execute("SELECT * FROM workflow_logs WHERE id = ?", (log_id,)).fetchone()
row = conn.execute("SELECT * FROM workflow_logs WHERE id = ?", (log_id, )).fetchone()
if not row:
return None
@@ -985,7 +985,7 @@ class WorkflowManager:
conditions.append("status = ?")
params.append(status)
where_clause = " AND ".join(conditions) if conditions else "1=1"
where_clause = " AND ".join(conditions) if conditions else "1 = 1"
rows = conn.execute(
f"""SELECT * FROM workflow_logs
@@ -1003,7 +1003,7 @@ class WorkflowManager:
"""获取工作流统计"""
conn = self.db.get_conn()
try:
since = (datetime.now() - timedelta(days=days)).isoformat()
since = (datetime.now() - timedelta(days = days)).isoformat()
# 总执行次数
total = conn.execute(
@@ -1060,17 +1060,17 @@ class WorkflowManager:
def _row_to_log(self, row) -> WorkflowLog:
"""将数据库行转换为 WorkflowLog 对象"""
return WorkflowLog(
id=row["id"],
workflow_id=row["workflow_id"],
task_id=row["task_id"],
status=row["status"],
start_time=row["start_time"],
end_time=row["end_time"],
duration_ms=row["duration_ms"] or 0,
input_data=json.loads(row["input_data"]) if row["input_data"] else {},
output_data=json.loads(row["output_data"]) if row["output_data"] else {},
error_message=row["error_message"] or "",
created_at=row["created_at"],
id = row["id"],
workflow_id = row["workflow_id"],
task_id = row["task_id"],
status = row["status"],
start_time = row["start_time"],
end_time = row["end_time"],
duration_ms = row["duration_ms"] or 0,
input_data = json.loads(row["input_data"]) if row["input_data"] else {},
output_data = json.loads(row["output_data"]) if row["output_data"] else {},
error_message = row["error_message"] or "",
created_at = row["created_at"],
)
# ==================== Workflow Execution ====================
@@ -1086,15 +1086,15 @@ class WorkflowManager:
# 更新最后运行时间
now = datetime.now().isoformat()
self.update_workflow(workflow_id, last_run_at=now, run_count=workflow.run_count + 1)
self.update_workflow(workflow_id, last_run_at = now, run_count = workflow.run_count + 1)
# 创建工作流执行日志
log = WorkflowLog(
id=str(uuid.uuid4())[:UUID_LENGTH],
workflow_id=workflow_id,
status=TaskStatus.RUNNING.value,
start_time=now,
input_data=input_data or {},
id = str(uuid.uuid4())[:UUID_LENGTH],
workflow_id = workflow_id,
status = TaskStatus.RUNNING.value,
start_time = now,
input_data = input_data or {},
)
self.create_log(log)
@@ -1113,21 +1113,21 @@ class WorkflowManager:
results = await self._execute_tasks_with_deps(tasks, input_data, log.id)
# 发送通知
await self._send_workflow_notification(workflow, results, success=True)
await self._send_workflow_notification(workflow, results, success = True)
# 更新日志为成功
end_time = datetime.now()
duration = int((end_time - start_time).total_seconds() * 1000)
self.update_log(
log.id,
status=TaskStatus.SUCCESS.value,
end_time=end_time.isoformat(),
duration_ms=duration,
output_data=results,
status = TaskStatus.SUCCESS.value,
end_time = end_time.isoformat(),
duration_ms = duration,
output_data = results,
)
# 更新成功计数
self.update_workflow(workflow_id, success_count=workflow.success_count + 1)
self.update_workflow(workflow_id, success_count = workflow.success_count + 1)
return {
"success": True,
@@ -1145,17 +1145,17 @@ class WorkflowManager:
duration = int((end_time - start_time).total_seconds() * 1000)
self.update_log(
log.id,
status=TaskStatus.FAILED.value,
end_time=end_time.isoformat(),
duration_ms=duration,
error_message=str(e),
status = TaskStatus.FAILED.value,
end_time = end_time.isoformat(),
duration_ms = duration,
error_message = str(e),
)
# 更新失败计数
self.update_workflow(workflow_id, fail_count=workflow.fail_count + 1)
self.update_workflow(workflow_id, fail_count = workflow.fail_count + 1)
# 发送失败通知
await self._send_workflow_notification(workflow, {"error": str(e)}, success=False)
await self._send_workflow_notification(workflow, {"error": str(e)}, success = False)
raise
@@ -1185,7 +1185,7 @@ class WorkflowManager:
task_input = {**input_data, **results}
task_coros.append(self._execute_single_task(task, task_input, log_id))
task_results = await asyncio.gather(*task_coros, return_exceptions=True)
task_results = await asyncio.gather(*task_coros, return_exceptions = True)
for task, result in zip(ready_tasks, task_results):
if isinstance(result, Exception):
@@ -1217,25 +1217,25 @@ class WorkflowManager:
# 创建任务日志
task_log = WorkflowLog(
id=str(uuid.uuid4())[:UUID_LENGTH],
workflow_id=task.workflow_id,
task_id=task.id,
status=TaskStatus.RUNNING.value,
start_time=datetime.now().isoformat(),
input_data=input_data,
id = str(uuid.uuid4())[:UUID_LENGTH],
workflow_id = task.workflow_id,
task_id = task.id,
status = TaskStatus.RUNNING.value,
start_time = datetime.now().isoformat(),
input_data = input_data,
)
self.create_log(task_log)
try:
# 设置超时
result = await asyncio.wait_for(handler(task, input_data), timeout=task.timeout_seconds)
result = await asyncio.wait_for(handler(task, input_data), timeout = task.timeout_seconds)
# 更新任务日志为成功
self.update_log(
task_log.id,
status=TaskStatus.SUCCESS.value,
end_time=datetime.now().isoformat(),
output_data={"result": result} if not isinstance(result, dict) else result,
status = TaskStatus.SUCCESS.value,
end_time = datetime.now().isoformat(),
output_data = {"result": result} if not isinstance(result, dict) else result,
)
return result
@@ -1243,18 +1243,18 @@ class WorkflowManager:
except TimeoutError:
self.update_log(
task_log.id,
status=TaskStatus.FAILED.value,
end_time=datetime.now().isoformat(),
error_message="Task timeout",
status = TaskStatus.FAILED.value,
end_time = datetime.now().isoformat(),
error_message = "Task timeout",
)
raise TimeoutError(f"Task {task.id} timed out after {task.timeout_seconds}s")
except Exception as e:
self.update_log(
task_log.id,
status=TaskStatus.FAILED.value,
end_time=datetime.now().isoformat(),
error_message=str(e),
status = TaskStatus.FAILED.value,
end_time = datetime.now().isoformat(),
error_message = str(e),
)
raise
@@ -1415,7 +1415,7 @@ class WorkflowManager:
async def _send_workflow_notification(
self, workflow: Workflow, results: dict, success: bool = True
):
) -> None:
"""发送工作流执行通知"""
if not workflow.webhook_ids:
return
@@ -1476,7 +1476,7 @@ class WorkflowManager:
**结果:**
```json
{json.dumps(results, ensure_ascii=False, indent=2)}
{json.dumps(results, ensure_ascii = False, indent = 2)}
```
""",
}
@@ -1510,7 +1510,7 @@ class WorkflowManager:
_workflow_manager = None
def get_workflow_manager(db_manager=None) -> WorkflowManager:
def get_workflow_manager(db_manager = None) -> WorkflowManager:
"""获取 WorkflowManager 单例"""
global _workflow_manager
if _workflow_manager is None:

View File

@@ -55,7 +55,7 @@ def check_bare_excepts(content: str, file_path: Path) -> list[dict]:
for i, line in enumerate(lines, 1):
stripped = line.strip()
# 检查 except Exception: 或 except :
# 检查 except Exception: 或 except Exception:
if re.match(r'^except\s*:', stripped):
issues.append({
"line": i,
@@ -140,7 +140,7 @@ def check_magic_numbers(content: str, file_path: Path) -> list[dict]:
lines = content.split('\n')
# 常见魔法数字模式(排除常见索引和简单值)
magic_pattern = re.compile(r'(?<![\w\d_])(\d{3,})(?![\w\d_])')
magic_pattern = re.compile(r'(?<![\w\d_])(\d{3, })(?![\w\d_])')
for i, line in enumerate(lines, 1):
if line.strip().startswith('#'):
@@ -217,7 +217,7 @@ def fix_line_length(content: str) -> str:
for line in lines:
if len(line) > 100:
# 尝试在逗号或运算符处折行
if ',' in line[80:]:
if ', ' in line[80:]:
# 简单处理:截断并添加续行
indent = len(line) - len(line.lstrip())
new_lines.append(line)
@@ -231,7 +231,7 @@ def fix_line_length(content: str) -> str:
def analyze_file(file_path: Path) -> dict:
"""分析单个文件"""
try:
content = file_path.read_text(encoding='utf-8')
content = file_path.read_text(encoding = 'utf-8')
except Exception as e:
return {"error": str(e)}
@@ -251,7 +251,7 @@ def analyze_file(file_path: Path) -> dict:
def fix_file(file_path: Path, issues: dict) -> bool:
"""自动修复文件问题"""
try:
content = file_path.read_text(encoding='utf-8')
content = file_path.read_text(encoding = 'utf-8')
original_content = content
# 修复裸异常
@@ -260,7 +260,7 @@ def fix_file(file_path: Path, issues: dict) -> bool:
# 如果有修改,写回文件
if content != original_content:
file_path.write_text(content, encoding='utf-8')
file_path.write_text(content, encoding = 'utf-8')
return True
return False
except Exception as e:
@@ -333,7 +333,7 @@ def generate_report(all_issues: dict) -> str:
return '\n'.join(lines)
def git_commit_and_push():
def git_commit_and_push() -> None:
"""提交并推送代码"""
try:
os.chdir(PROJECT_PATH)
@@ -341,15 +341,15 @@ def git_commit_and_push():
# 检查是否有修改
result = subprocess.run(
["git", "status", "--porcelain"],
capture_output=True,
text=True
capture_output = True,
text = True
)
if not result.stdout.strip():
return "没有需要提交的更改"
# 添加所有修改
subprocess.run(["git", "add", "-A"], check=True)
subprocess.run(["git", "add", "-A"], check = True)
# 提交
subprocess.run(
@@ -359,11 +359,11 @@ def git_commit_and_push():
- 修复异常处理
- 修复PEP8格式问题
- 添加类型注解"""],
check=True
check = True
)
# 推送
subprocess.run(["git", "push"], check=True)
subprocess.run(["git", "push"], check = True)
return "✅ 提交并推送成功"
except subprocess.CalledProcessError as e:
@@ -371,7 +371,7 @@ def git_commit_and_push():
except Exception as e:
return f"❌ 错误: {e}"
def main():
def main() -> None:
"""主函数"""
print("🔍 开始代码审查...")
@@ -392,7 +392,7 @@ def main():
# 生成报告
report_content = generate_report(all_issues)
report_path = PROJECT_PATH / "AUTO_CODE_REVIEW_REPORT.md"
report_path.write_text(report_content, encoding='utf-8')
report_path.write_text(report_content, encoding = 'utf-8')
print("\n📄 报告已生成:", report_path)
@@ -402,7 +402,7 @@ def main():
print(git_result)
# 追加提交结果到报告
with open(report_path, 'a', encoding='utf-8') as f:
with open(report_path, 'a', encoding = 'utf-8') as f:
f.write(f"\n\n## Git 提交结果\n\n{git_result}\n")
print("\n✅ 代码审查完成!")

View File

@@ -16,7 +16,7 @@ class CodeIssue:
issue_type: str,
message: str,
severity: str = "info",
):
) -> None:
self.file_path = file_path
self.line_no = line_no
self.issue_type = issue_type
@@ -24,12 +24,12 @@ class CodeIssue:
self.severity = severity # info, warning, error
self.fixed = False
def __repr__(self):
def __repr__(self) -> None:
return f"{self.severity.upper()}: {self.file_path}:{self.line_no} - {self.issue_type}: {self.message}"
class CodeReviewer:
def __init__(self, base_path: str):
def __init__(self, base_path: str) -> None:
self.base_path = Path(base_path)
self.issues: list[CodeIssue] = []
self.fixed_issues: list[CodeIssue] = []
@@ -45,7 +45,7 @@ class CodeReviewer:
def scan_file(self, file_path: Path) -> None:
"""扫描单个文件"""
try:
with open(file_path, "r", encoding="utf-8") as f:
with open(file_path, "r", encoding = "utf-8") as f:
content = f.read()
lines = content.split("\n")
except Exception as e:
@@ -110,7 +110,7 @@ class CodeReviewer:
match = re.match(r"^(?:from\s+(\S+)\s+)?import\s+(.+)$", line.strip())
if match:
module = match.group(1) or ""
names = match.group(2).split(",")
names = match.group(2).split(", ")
for name in names:
name = name.strip().split()[0] # 处理 'as' 别名
key = f"{module}.{name}" if module else name
@@ -223,10 +223,10 @@ class CodeReviewer:
"""检查魔法数字"""
# 常见的魔法数字模式
magic_patterns = [
(r"=\s*(\d{3,})\s*[^:]", "可能的魔法数字"),
(r"timeout\s*=\s*(\d+)", "timeout 魔法数字"),
(r"limit\s*=\s*(\d+)", "limit 魔法数字"),
(r"port\s*=\s*(\d+)", "port 魔法数字"),
(r" = \s*(\d{3, })\s*[^:]", "可能的魔法数字"),
(r"timeout\s* = \s*(\d+)", "timeout 魔法数字"),
(r"limit\s* = \s*(\d+)", "limit 魔法数字"),
(r"port\s* = \s*(\d+)", "port 魔法数字"),
]
for i, line in enumerate(lines, 1):
@@ -238,7 +238,7 @@ class CodeReviewer:
for pattern, msg in magic_patterns:
if re.search(pattern, code_part, re.IGNORECASE):
# 排除常见的合理数字
match = re.search(r"(\d{3,})", code_part)
match = re.search(r"(\d{3, })", code_part)
if match:
num = int(match.group(1))
if num in [
@@ -303,7 +303,7 @@ class CodeReviewer:
for i, line in enumerate(lines, 1):
# 检查硬编码密钥
if re.search(
r'(password|secret|key|token)\s*=\s*["\'][^"\']+["\']',
r'(password|secret|key|token)\s* = \s*["\'][^"\']+["\']',
line,
re.IGNORECASE,
):
@@ -340,7 +340,7 @@ class CodeReviewer:
continue
try:
with open(full_path, "r", encoding="utf-8") as f:
with open(full_path, "r", encoding = "utf-8") as f:
content = f.read()
lines = content.split("\n")
except Exception as e:
@@ -363,9 +363,9 @@ class CodeReviewer:
idx = issue.line_no - 1
if 0 <= idx < len(lines):
line = lines[idx]
# 将 except: 改为 except Exception:
# 将 except Exception: 改为 except Exception:
if re.search(r"except\s*:\s*$", line.strip()):
lines[idx] = line.replace("except:", "except Exception:")
lines[idx] = line.replace("except Exception:", "except Exception:")
issue.fixed = True
elif re.search(r"except\s+Exception\s*:\s*$", line.strip()):
# 已经是 Exception但可能需要更具体
@@ -373,7 +373,7 @@ class CodeReviewer:
# 如果文件有修改,写回
if lines != original_lines:
with open(full_path, "w", encoding="utf-8") as f:
with open(full_path, "w", encoding = "utf-8") as f:
f.write("\n".join(lines))
print(f"Fixed issues in {file_path}")
@@ -421,7 +421,7 @@ class CodeReviewer:
return "\n".join(report)
def main():
def main() -> None:
base_path = "/root/.openclaw/workspace/projects/insightflow/backend"
reviewer = CodeReviewer(base_path)
@@ -439,7 +439,7 @@ def main():
# 生成报告
report = reviewer.generate_report()
report_path = Path(base_path).parent / "CODE_REVIEW_REPORT.md"
with open(report_path, "w", encoding="utf-8") as f:
with open(report_path, "w", encoding = "utf-8") as f:
f.write(report)
print(f"\n报告已保存到: {report_path}")