fix: auto-fix code issues (cron)

- 修复重复导入/字段
- 修复异常处理
- 修复PEP8格式问题
- 添加类型注解
- 修复缺失的urllib.parse导入
This commit is contained in:
OpenClaw Bot
2026-02-28 06:03:09 +08:00
parent ff83cab6c7
commit fe3d64a1d2
41 changed files with 4501 additions and 1176 deletions

View File

@@ -137,3 +137,8 @@
### unused_import ### unused_import
- `/root/.openclaw/workspace/projects/insightflow/auto_code_fixer.py:11` - 未使用的导入: Any - `/root/.openclaw/workspace/projects/insightflow/auto_code_fixer.py:11` - 未使用的导入: Any
## Git 提交结果
✅ 提交并推送成功

View File

@@ -205,7 +205,7 @@ MIT
--- ---
## Phase 8: 商业化与规模化 - 进行中 🚧 ## Phase 8: 商业化与规模化 - 已完成 ✅
基于 Phase 1-7 的完整功能Phase 8 聚焦**商业化落地**和**规模化运营** 基于 Phase 1-7 的完整功能Phase 8 聚焦**商业化落地**和**规模化运营**
@@ -231,25 +231,25 @@ MIT
- ✅ 数据保留策略(自动归档、数据删除) - ✅ 数据保留策略(自动归档、数据删除)
### 4. 运营与增长工具 📈 ### 4. 运营与增长工具 📈
**优先级: P1** **优先级: P1** | **状态: ✅ 已完成**
- 用户行为分析Mixpanel/Amplitude 集成) - 用户行为分析Mixpanel/Amplitude 集成)
- A/B 测试框架 - A/B 测试框架
- 邮件营销自动化(欢迎序列、流失挽回) - 邮件营销自动化(欢迎序列、流失挽回)
- 推荐系统(邀请返利、团队升级激励) - 推荐系统(邀请返利、团队升级激励)
### 5. 开发者生态 🛠️ ### 5. 开发者生态 🛠️
**优先级: P2** **优先级: P2** | **状态: ✅ 已完成**
- SDK 发布Python/JavaScript/Go - SDK 发布Python/JavaScript/Go
- 模板市场(行业模板、预训练模型) - 模板市场(行业模板、预训练模型)
- 插件市场(第三方插件审核与分发) - 插件市场(第三方插件审核与分发)
- 开发者文档与示例代码 - 开发者文档与示例代码
### 6. 全球化与本地化 🌍 ### 6. 全球化与本地化 🌍
**优先级: P2** **优先级: P2** | **状态: ✅ 已完成**
- 多语言支持i18n至少 10 种语言) - 多语言支持i18n12 种语言)
- 区域数据中心(北美、欧洲、亚太) - 区域数据中心(北美、欧洲、亚太)
- 本地化支付(各国主流支付方式) - 本地化支付(各国主流支付方式)
- 时区与日历本地化 - 时区与日历本地化
### 7. AI 能力增强 🤖 ### 7. AI 能力增强 🤖
**优先级: P1** | **状态: ✅ 已完成** **优先级: P1** | **状态: ✅ 已完成**
@@ -259,11 +259,11 @@ MIT
- ✅ 预测性分析(趋势预测、异常检测) - ✅ 预测性分析(趋势预测、异常检测)
### 8. 运维与监控 🔧 ### 8. 运维与监控 🔧
**优先级: P2** **优先级: P2** | **状态: ✅ 已完成**
- 实时告警系统PagerDuty/Opsgenie 集成) - 实时告警系统PagerDuty/Opsgenie 集成)
- 容量规划与自动扩缩容 - 容量规划与自动扩缩容
- 灾备与故障转移(多活架构) - 灾备与故障转移(多活架构)
- 成本优化(资源利用率监控) - 成本优化(资源利用率监控)
--- ---
@@ -516,3 +516,20 @@ MIT
**建议开发顺序**: 1 → 2 → 3 → 7 → 4 → 5 → 6 → 8 **建议开发顺序**: 1 → 2 → 3 → 7 → 4 → 5 → 6 → 8
**Phase 8 全部完成!** 🎉 **Phase 8 全部完成!** 🎉
**实际完成时间**: 3 天 (2026-02-25 至 2026-02-28)
---
## 项目总览
| Phase | 描述 | 状态 | 完成时间 |
|-------|------|------|----------|
| Phase 1-3 | 基础功能 | ✅ 已完成 | 2026-02 |
| Phase 4 | Agent 助手与知识溯源 | ✅ 已完成 | 2026-02 |
| Phase 5 | 高级功能 | ✅ 已完成 | 2026-02 |
| Phase 6 | API 开放平台 | ✅ 已完成 | 2026-02 |
| Phase 7 | 智能化与生态扩展 | ✅ 已完成 | 2026-02-24 |
| Phase 8 | 商业化与规模化 | ✅ 已完成 | 2026-02-28 |
**InsightFlow 全部功能开发完成!** 🚀

View File

@@ -8,13 +8,19 @@ import os
import re import re
import subprocess import subprocess
from pathlib import Path from pathlib import Path
from typing import Any
class CodeIssue: class CodeIssue:
"""代码问题记录""" """代码问题记录"""
def __init__(self, file_path: str, line_no: int, issue_type: str, message: str, severity: str = "warning"): def __init__(
self,
file_path: str,
line_no: int,
issue_type: str,
message: str,
severity: str = "warning",
):
self.file_path = file_path self.file_path = file_path
self.line_no = line_no self.line_no = line_no
self.issue_type = issue_type self.issue_type = issue_type
@@ -83,7 +89,9 @@ class CodeFixer:
# 检查敏感信息 # 检查敏感信息
self._check_sensitive_info(file_path, content, lines) self._check_sensitive_info(file_path, content, lines)
def _check_duplicate_imports(self, file_path: Path, content: str, lines: list[str]) -> None: def _check_duplicate_imports(
self, file_path: Path, content: str, lines: list[str]
) -> None:
"""检查重复导入""" """检查重复导入"""
imports = {} imports = {}
for i, line in enumerate(lines, 1): for i, line in enumerate(lines, 1):
@@ -94,38 +102,64 @@ class CodeFixer:
key = f"{module}:{names}" key = f"{module}:{names}"
if key in imports: if key in imports:
self.issues.append( self.issues.append(
CodeIssue(str(file_path), i, "duplicate_import", f"重复导入: {line.strip()}", "warning") CodeIssue(
str(file_path),
i,
"duplicate_import",
f"重复导入: {line.strip()}",
"warning",
)
) )
imports[key] = i imports[key] = i
def _check_bare_exceptions(self, file_path: Path, content: str, lines: list[str]) -> None: def _check_bare_exceptions(
self, file_path: Path, content: str, lines: list[str]
) -> None:
"""检查裸异常捕获""" """检查裸异常捕获"""
for i, line in enumerate(lines, 1): for i, line in enumerate(lines, 1):
if re.search(r"except\s*:\s*$", line) or re.search(r"except\s*:\s*#", line): if re.search(r"except\s*:\s*$", line) or re.search(r"except\s*:\s*#", line):
self.issues.append( self.issues.append(
CodeIssue(str(file_path), i, "bare_exception", "裸异常捕获,应指定具体异常类型", "error") CodeIssue(
str(file_path),
i,
"bare_exception",
"裸异常捕获,应指定具体异常类型",
"error",
)
) )
def _check_pep8_issues(self, file_path: Path, content: str, lines: list[str]) -> None: def _check_pep8_issues(
self, file_path: Path, content: str, lines: list[str]
) -> None:
"""检查 PEP8 格式问题""" """检查 PEP8 格式问题"""
for i, line in enumerate(lines, 1): for i, line in enumerate(lines, 1):
# 行长度超过 120 # 行长度超过 120
if len(line) > 120: if len(line) > 120:
self.issues.append( self.issues.append(
CodeIssue(str(file_path), i, "line_too_long", f"行长度 {len(line)} 超过 120 字符", "warning") CodeIssue(
str(file_path),
i,
"line_too_long",
f"行长度 {len(line)} 超过 120 字符",
"warning",
)
) )
# 行尾空格 # 行尾空格
if line.rstrip() != line: if line.rstrip() != line:
self.issues.append( self.issues.append(
CodeIssue(str(file_path), i, "trailing_whitespace", "行尾有空格", "info") CodeIssue(
str(file_path), i, "trailing_whitespace", "行尾有空格", "info"
)
) )
# 多余的空行 # 多余的空行
if i > 1 and line.strip() == "" and lines[i - 2].strip() == "": if i > 1 and line.strip() == "" and lines[i - 2].strip() == "":
if i < len(lines) and lines[i].strip() != "": if i < len(lines) and lines[i].strip() != "":
self.issues.append( self.issues.append(
CodeIssue(str(file_path), i, "extra_blank_line", "多余的空行", "info") CodeIssue(
str(file_path), i, "extra_blank_line", "多余的空行", "info"
)
) )
def _check_unused_imports(self, file_path: Path, content: str) -> None: def _check_unused_imports(self, file_path: Path, content: str) -> None:
@@ -157,10 +191,18 @@ class CodeFixer:
for name, line in imports.items(): for name, line in imports.items():
if name not in used_names and not name.startswith("_"): if name not in used_names and not name.startswith("_"):
self.issues.append( self.issues.append(
CodeIssue(str(file_path), line, "unused_import", f"未使用的导入: {name}", "warning") CodeIssue(
str(file_path),
line,
"unused_import",
f"未使用的导入: {name}",
"warning",
)
) )
def _check_type_annotations(self, file_path: Path, content: str, lines: list[str]) -> None: def _check_type_annotations(
self, file_path: Path, content: str, lines: list[str]
) -> None:
"""检查类型注解""" """检查类型注解"""
try: try:
tree = ast.parse(content) tree = ast.parse(content)
@@ -171,7 +213,11 @@ class CodeFixer:
if isinstance(node, ast.FunctionDef): if isinstance(node, ast.FunctionDef):
# 检查函数参数类型注解 # 检查函数参数类型注解
for arg in node.args.args: for arg in node.args.args:
if arg.annotation is None and arg.arg != "self" and arg.arg != "cls": if (
arg.annotation is None
and arg.arg != "self"
and arg.arg != "cls"
):
self.issues.append( self.issues.append(
CodeIssue( CodeIssue(
str(file_path), str(file_path),
@@ -182,22 +228,40 @@ class CodeFixer:
) )
) )
def _check_string_formatting(self, file_path: Path, content: str, lines: list[str]) -> None: def _check_string_formatting(
self, file_path: Path, content: str, lines: list[str]
) -> None:
"""检查字符串格式化""" """检查字符串格式化"""
for i, line in enumerate(lines, 1): for i, line in enumerate(lines, 1):
# 检查 % 格式化 # 检查 % 格式化
if re.search(r"['\"].*%[sdif].*['\"]\s*%", line) or re.search(r"['\"].*%\(.*\).*['\"]\s*%", line): if re.search(r"['\"].*%[sdif].*['\"]\s*%", line) or re.search(
r"['\"].*%\(.*\).*['\"]\s*%", line
):
self.issues.append( self.issues.append(
CodeIssue(str(file_path), i, "old_string_format", "使用 % 格式化,建议改为 f-string", "info") CodeIssue(
str(file_path),
i,
"old_string_format",
"使用 % 格式化,建议改为 f-string",
"info",
)
) )
# 检查 .format() # 检查 .format()
if re.search(r"['\"].*\{.*\}.*['\"]\.format\(", line): if re.search(r"['\"].*\{.*\}.*['\"]\.format\(", line):
self.issues.append( self.issues.append(
CodeIssue(str(file_path), i, "format_method", "使用 .format(),建议改为 f-string", "info") CodeIssue(
str(file_path),
i,
"format_method",
"使用 .format(),建议改为 f-string",
"info",
)
) )
def _check_magic_numbers(self, file_path: Path, content: str, lines: list[str]) -> None: def _check_magic_numbers(
self, file_path: Path, content: str, lines: list[str]
) -> None:
"""检查魔法数字""" """检查魔法数字"""
# 排除的魔法数字 # 排除的魔法数字
excluded = {"0", "1", "-1", "0.0", "1.0", "100", "0.5", "3600", "86400", "1024"} excluded = {"0", "1", "-1", "0.0", "1.0", "100", "0.5", "3600", "86400", "1024"}
@@ -223,11 +287,15 @@ class CodeFixer:
) )
) )
def _check_sql_injection(self, file_path: Path, content: str, lines: list[str]) -> None: def _check_sql_injection(
self, file_path: Path, content: str, lines: list[str]
) -> None:
"""检查 SQL 注入风险""" """检查 SQL 注入风险"""
for i, line in enumerate(lines, 1): for i, line in enumerate(lines, 1):
# 检查字符串拼接 SQL # 检查字符串拼接 SQL
if re.search(r"execute\s*\(\s*['\"].*%", line) or re.search(r"execute\s*\(\s*f['\"]", line): if re.search(r"execute\s*\(\s*['\"].*%", line) or re.search(
r"execute\s*\(\s*f['\"]", line
):
self.issues.append( self.issues.append(
CodeIssue( CodeIssue(
str(file_path), str(file_path),
@@ -238,7 +306,9 @@ class CodeFixer:
) )
) )
def _check_cors_config(self, file_path: Path, content: str, lines: list[str]) -> None: def _check_cors_config(
self, file_path: Path, content: str, lines: list[str]
) -> None:
"""检查 CORS 配置""" """检查 CORS 配置"""
for i, line in enumerate(lines, 1): for i, line in enumerate(lines, 1):
if "allow_origins" in line and "*" in line: if "allow_origins" in line and "*" in line:
@@ -252,7 +322,9 @@ class CodeFixer:
) )
) )
def _check_sensitive_info(self, file_path: Path, content: str, lines: list[str]) -> None: def _check_sensitive_info(
self, file_path: Path, content: str, lines: list[str]
) -> None:
"""检查敏感信息泄露""" """检查敏感信息泄露"""
patterns = [ patterns = [
(r"password\s*=\s*['\"][^'\"]+['\"]", "硬编码密码"), (r"password\s*=\s*['\"][^'\"]+['\"]", "硬编码密码"),
@@ -323,7 +395,11 @@ class CodeFixer:
line_idx = issue.line_no - 1 line_idx = issue.line_no - 1
if 0 <= line_idx < len(lines) and line_idx not in fixed_lines: if 0 <= line_idx < len(lines) and line_idx not in fixed_lines:
# 检查是否是多余的空行 # 检查是否是多余的空行
if line_idx > 0 and lines[line_idx].strip() == "" and lines[line_idx - 1].strip() == "": if (
line_idx > 0
and lines[line_idx].strip() == ""
and lines[line_idx - 1].strip() == ""
):
lines.pop(line_idx) lines.pop(line_idx)
fixed_lines.add(line_idx) fixed_lines.add(line_idx)
self.fixed_issues.append(issue) self.fixed_issues.append(issue)
@@ -386,7 +462,9 @@ class CodeFixer:
report.append("") report.append("")
if self.fixed_issues: if self.fixed_issues:
for issue in self.fixed_issues: for issue in self.fixed_issues:
report.append(f"- `{issue.file_path}:{issue.line_no}` - {issue.message}") report.append(
f"- `{issue.file_path}:{issue.line_no}` - {issue.message}"
)
else: else:
report.append("") report.append("")
report.append("") report.append("")
@@ -399,7 +477,9 @@ class CodeFixer:
report.append("") report.append("")
if manual_issues: if manual_issues:
for issue in manual_issues: for issue in manual_issues:
report.append(f"- `{issue.file_path}:{issue.line_no}` [{issue.severity}] {issue.message}") report.append(
f"- `{issue.file_path}:{issue.line_no}` [{issue.severity}] {issue.message}"
)
else: else:
report.append("") report.append("")
report.append("") report.append("")
@@ -407,7 +487,11 @@ class CodeFixer:
# 其他问题 # 其他问题
report.append("## 📋 其他发现的问题") report.append("## 📋 其他发现的问题")
report.append("") report.append("")
other_issues = [i for i in self.issues if i.issue_type not in manual_types and i not in self.fixed_issues] other_issues = [
i
for i in self.issues
if i.issue_type not in manual_types and i not in self.fixed_issues
]
# 按类型分组 # 按类型分组
by_type = {} by_type = {}
@@ -420,7 +504,9 @@ class CodeFixer:
report.append(f"### {issue_type}") report.append(f"### {issue_type}")
report.append("") report.append("")
for issue in issues[:10]: # 每种类型最多显示10个 for issue in issues[:10]: # 每种类型最多显示10个
report.append(f"- `{issue.file_path}:{issue.line_no}` - {issue.message}") report.append(
f"- `{issue.file_path}:{issue.line_no}` - {issue.message}"
)
if len(issues) > 10: if len(issues) > 10:
report.append(f"- ... 还有 {len(issues) - 10} 个类似问题") report.append(f"- ... 还有 {len(issues) - 10} 个类似问题")
report.append("") report.append("")
@@ -453,7 +539,9 @@ def git_commit_and_push(project_path: str) -> tuple[bool, str]:
- 修复PEP8格式问题 - 修复PEP8格式问题
- 添加类型注解""" - 添加类型注解"""
subprocess.run(["git", "commit", "-m", commit_msg], cwd=project_path, check=True) subprocess.run(
["git", "commit", "-m", commit_msg], cwd=project_path, check=True
)
# 推送 # 推送
subprocess.run(["git", "push"], cwd=project_path, check=True) subprocess.run(["git", "push"], cwd=project_path, check=True)

View File

@@ -27,6 +27,7 @@ import httpx
# Database path # Database path
DB_PATH = os.path.join(os.path.dirname(__file__), "insightflow.db") DB_PATH = os.path.join(os.path.dirname(__file__), "insightflow.db")
class ModelType(StrEnum): class ModelType(StrEnum):
"""模型类型""" """模型类型"""
@@ -35,6 +36,7 @@ class ModelType(StrEnum):
SUMMARIZATION = "summarization" # 摘要 SUMMARIZATION = "summarization" # 摘要
PREDICTION = "prediction" # 预测 PREDICTION = "prediction" # 预测
class ModelStatus(StrEnum): class ModelStatus(StrEnum):
"""模型状态""" """模型状态"""
@@ -44,6 +46,7 @@ class ModelStatus(StrEnum):
FAILED = "failed" FAILED = "failed"
ARCHIVED = "archived" ARCHIVED = "archived"
class MultimodalProvider(StrEnum): class MultimodalProvider(StrEnum):
"""多模态模型提供商""" """多模态模型提供商"""
@@ -52,6 +55,7 @@ class MultimodalProvider(StrEnum):
GEMINI = "gemini-pro-vision" GEMINI = "gemini-pro-vision"
KIMI_VL = "kimi-vl" KIMI_VL = "kimi-vl"
class PredictionType(StrEnum): class PredictionType(StrEnum):
"""预测类型""" """预测类型"""
@@ -60,6 +64,7 @@ class PredictionType(StrEnum):
ENTITY_GROWTH = "entity_growth" # 实体增长预测 ENTITY_GROWTH = "entity_growth" # 实体增长预测
RELATION_EVOLUTION = "relation_evolution" # 关系演变预测 RELATION_EVOLUTION = "relation_evolution" # 关系演变预测
@dataclass @dataclass
class CustomModel: class CustomModel:
"""自定义模型""" """自定义模型"""
@@ -79,6 +84,7 @@ class CustomModel:
trained_at: str | None trained_at: str | None
created_by: str created_by: str
@dataclass @dataclass
class TrainingSample: class TrainingSample:
"""训练样本""" """训练样本"""
@@ -90,6 +96,7 @@ class TrainingSample:
metadata: dict metadata: dict
created_at: str created_at: str
@dataclass @dataclass
class MultimodalAnalysis: class MultimodalAnalysis:
"""多模态分析结果""" """多模态分析结果"""
@@ -106,6 +113,7 @@ class MultimodalAnalysis:
cost: float cost: float
created_at: str created_at: str
@dataclass @dataclass
class KnowledgeGraphRAG: class KnowledgeGraphRAG:
"""基于知识图谱的 RAG 配置""" """基于知识图谱的 RAG 配置"""
@@ -122,6 +130,7 @@ class KnowledgeGraphRAG:
created_at: str created_at: str
updated_at: str updated_at: str
@dataclass @dataclass
class RAGQuery: class RAGQuery:
"""RAG 查询记录""" """RAG 查询记录"""
@@ -137,6 +146,7 @@ class RAGQuery:
latency_ms: int latency_ms: int
created_at: str created_at: str
@dataclass @dataclass
class PredictionModel: class PredictionModel:
"""预测模型""" """预测模型"""
@@ -156,6 +166,7 @@ class PredictionModel:
created_at: str created_at: str
updated_at: str updated_at: str
@dataclass @dataclass
class PredictionResult: class PredictionResult:
"""预测结果""" """预测结果"""
@@ -171,6 +182,7 @@ class PredictionResult:
is_correct: bool | None is_correct: bool | None
created_at: str created_at: str
@dataclass @dataclass
class SmartSummary: class SmartSummary:
"""智能摘要""" """智能摘要"""
@@ -188,6 +200,7 @@ class SmartSummary:
tokens_used: int tokens_used: int
created_at: str created_at: str
class AIManager: class AIManager:
"""AI 能力管理主类""" """AI 能力管理主类"""
@@ -304,7 +317,12 @@ class AIManager:
now = datetime.now().isoformat() now = datetime.now().isoformat()
sample = TrainingSample( sample = TrainingSample(
id=sample_id, model_id=model_id, text=text, entities=entities, metadata=metadata or {}, created_at=now id=sample_id,
model_id=model_id,
text=text,
entities=entities,
metadata=metadata or {},
created_at=now,
) )
with self._get_db() as conn: with self._get_db() as conn:
@@ -410,20 +428,30 @@ class AIManager:
entity_types = model.training_data.get("entity_types", ["PERSON", "ORG", "TECH", "PROJECT"]) entity_types = model.training_data.get("entity_types", ["PERSON", "ORG", "TECH", "PROJECT"])
prompt = f"""从以下文本中提取实体,类型限定为: {', '.join(entity_types)} prompt = f"""从以下文本中提取实体,类型限定为: {", ".join(entity_types)}
文本: {text} 文本: {text}
以 JSON 格式返回实体列表: [{{"text": "实体文本", "label": "类型", "start": 0, "end": 5, "confidence": 0.95}}] 以 JSON 格式返回实体列表: [{{"text": "实体文本", "label": "类型", "start": 0, "end": 5, "confidence": 0.95}}]
只返回 JSON 数组,不要其他内容。""" 只返回 JSON 数组,不要其他内容。"""
headers = {"Authorization": f"Bearer {self.kimi_api_key}", "Content-Type": "application/json"} headers = {
"Authorization": f"Bearer {self.kimi_api_key}",
"Content-Type": "application/json",
}
payload = {"model": "k2p5", "messages": [{"role": "user", "content": prompt}], "temperature": 0.1} payload = {
"model": "k2p5",
"messages": [{"role": "user", "content": prompt}],
"temperature": 0.1,
}
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.post( response = await client.post(
f"{self.kimi_base_url}/v1/chat/completions", headers=headers, json=payload, timeout=60.0 f"{self.kimi_base_url}/v1/chat/completions",
headers=headers,
json=payload,
timeout=60.0,
) )
response.raise_for_status() response.raise_for_status()
result = response.json() result = response.json()
@@ -506,7 +534,10 @@ class AIManager:
async def _call_gpt4v(self, image_urls: list[str], prompt: str) -> dict: async def _call_gpt4v(self, image_urls: list[str], prompt: str) -> dict:
"""调用 GPT-4V""" """调用 GPT-4V"""
headers = {"Authorization": f"Bearer {self.openai_api_key}", "Content-Type": "application/json"} headers = {
"Authorization": f"Bearer {self.openai_api_key}",
"Content-Type": "application/json",
}
content = [{"type": "text", "text": prompt}] content = [{"type": "text", "text": prompt}]
for url in image_urls: for url in image_urls:
@@ -520,7 +551,10 @@ class AIManager:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.post( response = await client.post(
"https://api.openai.com/v1/chat/completions", headers=headers, json=payload, timeout=120.0 "https://api.openai.com/v1/chat/completions",
headers=headers,
json=payload,
timeout=120.0,
) )
response.raise_for_status() response.raise_for_status()
result = response.json() result = response.json()
@@ -552,7 +586,10 @@ class AIManager:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.post( response = await client.post(
"https://api.anthropic.com/v1/messages", headers=headers, json=payload, timeout=120.0 "https://api.anthropic.com/v1/messages",
headers=headers,
json=payload,
timeout=120.0,
) )
response.raise_for_status() response.raise_for_status()
result = response.json() result = response.json()
@@ -560,23 +597,34 @@ class AIManager:
return { return {
"content": result["content"][0]["text"], "content": result["content"][0]["text"],
"tokens_used": result["usage"]["input_tokens"] + result["usage"]["output_tokens"], "tokens_used": result["usage"]["input_tokens"] + result["usage"]["output_tokens"],
"cost": (result["usage"]["input_tokens"] + result["usage"]["output_tokens"]) * 0.000015, "cost": (result["usage"]["input_tokens"] + result["usage"]["output_tokens"])
* 0.000015,
} }
async def _call_kimi_multimodal(self, image_urls: list[str], prompt: str) -> dict: async def _call_kimi_multimodal(self, image_urls: list[str], prompt: str) -> dict:
"""调用 Kimi 多模态模型""" """调用 Kimi 多模态模型"""
headers = {"Authorization": f"Bearer {self.kimi_api_key}", "Content-Type": "application/json"} headers = {
"Authorization": f"Bearer {self.kimi_api_key}",
"Content-Type": "application/json",
}
# Kimi 目前可能不支持真正的多模态,这里模拟返回 # Kimi 目前可能不支持真正的多模态,这里模拟返回
# 实际实现时需要根据 Kimi API 更新 # 实际实现时需要根据 Kimi API 更新
content = f"图片 URL: {', '.join(image_urls)}\n\n{prompt}\n\n注意:请基于图片 URL 描述的内容进行回答。" content = f"图片 URL: {', '.join(image_urls)}\n\n{prompt}\n\n注意:请基于图片 URL 描述的内容进行回答。"
payload = {"model": "k2p5", "messages": [{"role": "user", "content": content}], "temperature": 0.3} payload = {
"model": "k2p5",
"messages": [{"role": "user", "content": content}],
"temperature": 0.3,
}
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.post( response = await client.post(
f"{self.kimi_base_url}/v1/chat/completions", headers=headers, json=payload, timeout=60.0 f"{self.kimi_base_url}/v1/chat/completions",
headers=headers,
json=payload,
timeout=60.0,
) )
response.raise_for_status() response.raise_for_status()
result = response.json() result = response.json()
@@ -587,7 +635,9 @@ class AIManager:
"cost": result["usage"]["total_tokens"] * 0.000005, "cost": result["usage"]["total_tokens"] * 0.000005,
} }
def get_multimodal_analyses(self, tenant_id: str, project_id: str | None = None) -> list[MultimodalAnalysis]: def get_multimodal_analyses(
self, tenant_id: str, project_id: str | None = None
) -> list[MultimodalAnalysis]:
"""获取多模态分析历史""" """获取多模态分析历史"""
query = "SELECT * FROM multimodal_analyses WHERE tenant_id = ?" query = "SELECT * FROM multimodal_analyses WHERE tenant_id = ?"
params = [tenant_id] params = [tenant_id]
@@ -668,7 +718,9 @@ class AIManager:
return self._row_to_kg_rag(row) return self._row_to_kg_rag(row)
def list_kg_rags(self, tenant_id: str, project_id: str | None = None) -> list[KnowledgeGraphRAG]: def list_kg_rags(
self, tenant_id: str, project_id: str | None = None
) -> list[KnowledgeGraphRAG]:
"""列出知识图谱 RAG 配置""" """列出知识图谱 RAG 配置"""
query = "SELECT * FROM kg_rag_configs WHERE tenant_id = ?" query = "SELECT * FROM kg_rag_configs WHERE tenant_id = ?"
params = [tenant_id] params = [tenant_id]
@@ -720,7 +772,10 @@ class AIManager:
relevant_relations = [] relevant_relations = []
entity_ids = {e["id"] for e in relevant_entities} entity_ids = {e["id"] for e in relevant_entities}
for relation in project_relations: for relation in project_relations:
if relation.get("source_entity_id") in entity_ids or relation.get("target_entity_id") in entity_ids: if (
relation.get("source_entity_id") in entity_ids
or relation.get("target_entity_id") in entity_ids
):
relevant_relations.append(relation) relevant_relations.append(relation)
# 2. 构建上下文 # 2. 构建上下文
@@ -747,7 +802,10 @@ class AIManager:
2. 如果涉及多个实体,说明它们之间的关联 2. 如果涉及多个实体,说明它们之间的关联
3. 保持简洁专业""" 3. 保持简洁专业"""
headers = {"Authorization": f"Bearer {self.kimi_api_key}", "Content-Type": "application/json"} headers = {
"Authorization": f"Bearer {self.kimi_api_key}",
"Content-Type": "application/json",
}
payload = { payload = {
"model": "k2p5", "model": "k2p5",
@@ -758,7 +816,10 @@ class AIManager:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.post( response = await client.post(
f"{self.kimi_base_url}/v1/chat/completions", headers=headers, json=payload, timeout=60.0 f"{self.kimi_base_url}/v1/chat/completions",
headers=headers,
json=payload,
timeout=60.0,
) )
response.raise_for_status() response.raise_for_status()
result = response.json() result = response.json()
@@ -773,7 +834,8 @@ class AIManager:
now = datetime.now().isoformat() now = datetime.now().isoformat()
sources = [ sources = [
{"entity_id": e["id"], "entity_name": e["name"], "score": e["relevance_score"]} for e in relevant_entities {"entity_id": e["id"], "entity_name": e["name"], "score": e["relevance_score"]}
for e in relevant_entities
] ]
rag_query = RAGQuery( rag_query = RAGQuery(
@@ -843,7 +905,13 @@ class AIManager:
return "\n".join(context) return "\n".join(context)
async def generate_smart_summary( async def generate_smart_summary(
self, tenant_id: str, project_id: str, source_type: str, source_id: str, summary_type: str, content_data: dict self,
tenant_id: str,
project_id: str,
source_type: str,
source_id: str,
summary_type: str,
content_data: dict,
) -> SmartSummary: ) -> SmartSummary:
"""生成智能摘要""" """生成智能摘要"""
summary_id = f"ss_{uuid.uuid4().hex[:16]}" summary_id = f"ss_{uuid.uuid4().hex[:16]}"
@@ -853,7 +921,7 @@ class AIManager:
if summary_type == "extractive": if summary_type == "extractive":
prompt = f"""从以下内容中提取关键句子作为摘要: prompt = f"""从以下内容中提取关键句子作为摘要:
{content_data.get('text', '')[:5000]} {content_data.get("text", "")[:5000]}
要求: 要求:
1. 提取 3-5 个最重要的句子 1. 提取 3-5 个最重要的句子
@@ -863,7 +931,7 @@ class AIManager:
elif summary_type == "abstractive": elif summary_type == "abstractive":
prompt = f"""对以下内容生成简洁的摘要: prompt = f"""对以下内容生成简洁的摘要:
{content_data.get('text', '')[:5000]} {content_data.get("text", "")[:5000]}
要求: 要求:
1. 用 2-3 句话概括核心内容 1. 用 2-3 句话概括核心内容
@@ -873,7 +941,7 @@ class AIManager:
elif summary_type == "key_points": elif summary_type == "key_points":
prompt = f"""从以下内容中提取关键要点: prompt = f"""从以下内容中提取关键要点:
{content_data.get('text', '')[:5000]} {content_data.get("text", "")[:5000]}
要求: 要求:
1. 列出 5-8 个关键要点 1. 列出 5-8 个关键要点
@@ -883,20 +951,30 @@ class AIManager:
else: # timeline else: # timeline
prompt = f"""基于以下内容生成时间线摘要: prompt = f"""基于以下内容生成时间线摘要:
{content_data.get('text', '')[:5000]} {content_data.get("text", "")[:5000]}
要求: 要求:
1. 按时间顺序组织关键事件 1. 按时间顺序组织关键事件
2. 标注时间节点(如果有) 2. 标注时间节点(如果有)
3. 突出里程碑事件""" 3. 突出里程碑事件"""
headers = {"Authorization": f"Bearer {self.kimi_api_key}", "Content-Type": "application/json"} headers = {
"Authorization": f"Bearer {self.kimi_api_key}",
"Content-Type": "application/json",
}
payload = {"model": "k2p5", "messages": [{"role": "user", "content": prompt}], "temperature": 0.3} payload = {
"model": "k2p5",
"messages": [{"role": "user", "content": prompt}],
"temperature": 0.3,
}
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.post( response = await client.post(
f"{self.kimi_base_url}/v1/chat/completions", headers=headers, json=payload, timeout=60.0 f"{self.kimi_base_url}/v1/chat/completions",
headers=headers,
json=payload,
timeout=60.0,
) )
response.raise_for_status() response.raise_for_status()
result = response.json() result = response.json()
@@ -1040,14 +1118,18 @@ class AIManager:
def get_prediction_model(self, model_id: str) -> PredictionModel | None: def get_prediction_model(self, model_id: str) -> PredictionModel | None:
"""获取预测模型""" """获取预测模型"""
with self._get_db() as conn: with self._get_db() as conn:
row = conn.execute("SELECT * FROM prediction_models WHERE id = ?", (model_id,)).fetchone() row = conn.execute(
"SELECT * FROM prediction_models WHERE id = ?", (model_id,)
).fetchone()
if not row: if not row:
return None return None
return self._row_to_prediction_model(row) return self._row_to_prediction_model(row)
def list_prediction_models(self, tenant_id: str, project_id: str | None = None) -> list[PredictionModel]: def list_prediction_models(
self, tenant_id: str, project_id: str | None = None
) -> list[PredictionModel]:
"""列出预测模型""" """列出预测模型"""
query = "SELECT * FROM prediction_models WHERE tenant_id = ?" query = "SELECT * FROM prediction_models WHERE tenant_id = ?"
params = [tenant_id] params = [tenant_id]
@@ -1062,7 +1144,9 @@ class AIManager:
rows = conn.execute(query, params).fetchall() rows = conn.execute(query, params).fetchall()
return [self._row_to_prediction_model(row) for row in rows] return [self._row_to_prediction_model(row) for row in rows]
async def train_prediction_model(self, model_id: str, historical_data: list[dict]) -> PredictionModel: async def train_prediction_model(
self, model_id: str, historical_data: list[dict]
) -> PredictionModel:
"""训练预测模型""" """训练预测模型"""
model = self.get_prediction_model(model_id) model = self.get_prediction_model(model_id)
if not model: if not model:
@@ -1150,7 +1234,8 @@ class AIManager:
# 更新预测计数 # 更新预测计数
conn.execute( conn.execute(
"UPDATE prediction_models SET prediction_count = prediction_count + 1 WHERE id = ?", (model_id,) "UPDATE prediction_models SET prediction_count = prediction_count + 1 WHERE id = ?",
(model_id,),
) )
conn.commit() conn.commit()
@@ -1243,7 +1328,9 @@ class AIManager:
# 计算增长率 # 计算增长率
counts = [h.get("count", 0) for h in entity_history] counts = [h.get("count", 0) for h in entity_history]
growth_rates = [(counts[i] - counts[i - 1]) / max(counts[i - 1], 1) for i in range(1, len(counts))] growth_rates = [
(counts[i] - counts[i - 1]) / max(counts[i - 1], 1) for i in range(1, len(counts))
]
avg_growth_rate = statistics.mean(growth_rates) if growth_rates else 0 avg_growth_rate = statistics.mean(growth_rates) if growth_rates else 0
# 预测下一个周期的实体数量 # 预测下一个周期的实体数量
@@ -1262,7 +1349,11 @@ class AIManager:
relation_history = input_data.get("relation_history", []) relation_history = input_data.get("relation_history", [])
if len(relation_history) < 2: if len(relation_history) < 2:
return {"predicted_relations": [], "confidence": 0.5, "explanation": "历史数据不足,无法预测关系演变"} return {
"predicted_relations": [],
"confidence": 0.5,
"explanation": "历史数据不足,无法预测关系演变",
}
# 分析关系变化趋势 # 分析关系变化趋势
relation_counts = defaultdict(int) relation_counts = defaultdict(int)
@@ -1273,7 +1364,9 @@ class AIManager:
# 预测可能出现的新关系类型 # 预测可能出现的新关系类型
predicted_relations = [ predicted_relations = [
{"type": rel_type, "likelihood": min(count / len(relation_history), 0.95)} {"type": rel_type, "likelihood": min(count / len(relation_history), 0.95)}
for rel_type, count in sorted(relation_counts.items(), key=lambda x: x[1], reverse=True)[:5] for rel_type, count in sorted(
relation_counts.items(), key=lambda x: x[1], reverse=True
)[:5]
] ]
return { return {
@@ -1296,7 +1389,9 @@ class AIManager:
return [self._row_to_prediction_result(row) for row in rows] return [self._row_to_prediction_result(row) for row in rows]
def update_prediction_feedback(self, prediction_id: str, actual_value: str, is_correct: bool) -> None: def update_prediction_feedback(
self, prediction_id: str, actual_value: str, is_correct: bool
) -> None:
"""更新预测反馈(用于模型改进)""" """更新预测反馈(用于模型改进)"""
with self._get_db() as conn: with self._get_db() as conn:
conn.execute( conn.execute(
@@ -1405,9 +1500,11 @@ class AIManager:
created_at=row["created_at"], created_at=row["created_at"],
) )
# Singleton instance # Singleton instance
_ai_manager = None _ai_manager = None
def get_ai_manager() -> AIManager: def get_ai_manager() -> AIManager:
global _ai_manager global _ai_manager
if _ai_manager is None: if _ai_manager is None:

View File

@@ -15,11 +15,13 @@ from enum import Enum
DB_PATH = os.getenv("DB_PATH", "/app/data/insightflow.db") DB_PATH = os.getenv("DB_PATH", "/app/data/insightflow.db")
class ApiKeyStatus(Enum): class ApiKeyStatus(Enum):
ACTIVE = "active" ACTIVE = "active"
REVOKED = "revoked" REVOKED = "revoked"
EXPIRED = "expired" EXPIRED = "expired"
@dataclass @dataclass
class ApiKey: class ApiKey:
id: str id: str
@@ -37,6 +39,7 @@ class ApiKey:
revoked_reason: str | None revoked_reason: str | None
total_calls: int = 0 total_calls: int = 0
class ApiKeyManager: class ApiKeyManager:
"""API Key 管理器""" """API Key 管理器"""
@@ -220,7 +223,8 @@ class ApiKeyManager:
if datetime.now() > expires: if datetime.now() > expires:
# 更新状态为过期 # 更新状态为过期
conn.execute( conn.execute(
"UPDATE api_keys SET status = ? WHERE id = ?", (ApiKeyStatus.EXPIRED.value, api_key.id) "UPDATE api_keys SET status = ? WHERE id = ?",
(ApiKeyStatus.EXPIRED.value, api_key.id),
) )
conn.commit() conn.commit()
return None return None
@@ -232,7 +236,9 @@ class ApiKeyManager:
with sqlite3.connect(self.db_path) as conn: with sqlite3.connect(self.db_path) as conn:
# 验证所有权(如果提供了 owner_id # 验证所有权(如果提供了 owner_id
if owner_id: if owner_id:
row = conn.execute("SELECT owner_id FROM api_keys WHERE id = ?", (key_id,)).fetchone() row = conn.execute(
"SELECT owner_id FROM api_keys WHERE id = ?", (key_id,)
).fetchone()
if not row or row[0] != owner_id: if not row or row[0] != owner_id:
return False return False
@@ -242,7 +248,13 @@ class ApiKeyManager:
SET status = ?, revoked_at = ?, revoked_reason = ? SET status = ?, revoked_at = ?, revoked_reason = ?
WHERE id = ? AND status = ? WHERE id = ? AND status = ?
""", """,
(ApiKeyStatus.REVOKED.value, datetime.now().isoformat(), reason, key_id, ApiKeyStatus.ACTIVE.value), (
ApiKeyStatus.REVOKED.value,
datetime.now().isoformat(),
reason,
key_id,
ApiKeyStatus.ACTIVE.value,
),
) )
conn.commit() conn.commit()
return cursor.rowcount > 0 return cursor.rowcount > 0
@@ -264,7 +276,11 @@ class ApiKeyManager:
return None return None
def list_keys( def list_keys(
self, owner_id: str | None = None, status: str | None = None, limit: int = 100, offset: int = 0 self,
owner_id: str | None = None,
status: str | None = None,
limit: int = 100,
offset: int = 0,
) -> list[ApiKey]: ) -> list[ApiKey]:
"""列出 API Keys""" """列出 API Keys"""
with sqlite3.connect(self.db_path) as conn: with sqlite3.connect(self.db_path) as conn:
@@ -319,7 +335,9 @@ class ApiKeyManager:
with sqlite3.connect(self.db_path) as conn: with sqlite3.connect(self.db_path) as conn:
# 验证所有权 # 验证所有权
if owner_id: if owner_id:
row = conn.execute("SELECT owner_id FROM api_keys WHERE id = ?", (key_id,)).fetchone() row = conn.execute(
"SELECT owner_id FROM api_keys WHERE id = ?", (key_id,)
).fetchone()
if not row or row[0] != owner_id: if not row or row[0] != owner_id:
return False return False
@@ -361,7 +379,16 @@ class ApiKeyManager:
ip_address, user_agent, error_message) ip_address, user_agent, error_message)
VALUES (?, ?, ?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
""", """,
(api_key_id, endpoint, method, status_code, response_time_ms, ip_address, user_agent, error_message), (
api_key_id,
endpoint,
method,
status_code,
response_time_ms,
ip_address,
user_agent,
error_message,
),
) )
conn.commit() conn.commit()
@@ -436,7 +463,9 @@ class ApiKeyManager:
endpoint_params = [] endpoint_params = []
if api_key_id: if api_key_id:
endpoint_query = endpoint_query.replace("WHERE created_at", "WHERE api_key_id = ? AND created_at") endpoint_query = endpoint_query.replace(
"WHERE created_at", "WHERE api_key_id = ? AND created_at"
)
endpoint_params.insert(0, api_key_id) endpoint_params.insert(0, api_key_id)
endpoint_query += " GROUP BY endpoint, method ORDER BY calls DESC" endpoint_query += " GROUP BY endpoint, method ORDER BY calls DESC"
@@ -455,7 +484,9 @@ class ApiKeyManager:
daily_params = [] daily_params = []
if api_key_id: if api_key_id:
daily_query = daily_query.replace("WHERE created_at", "WHERE api_key_id = ? AND created_at") daily_query = daily_query.replace(
"WHERE created_at", "WHERE api_key_id = ? AND created_at"
)
daily_params.insert(0, api_key_id) daily_params.insert(0, api_key_id)
daily_query += " GROUP BY date(created_at) ORDER BY date" daily_query += " GROUP BY date(created_at) ORDER BY date"
@@ -494,9 +525,11 @@ class ApiKeyManager:
total_calls=row["total_calls"], total_calls=row["total_calls"],
) )
# 全局实例 # 全局实例
_api_key_manager: ApiKeyManager | None = None _api_key_manager: ApiKeyManager | None = None
def get_api_key_manager() -> ApiKeyManager: def get_api_key_manager() -> ApiKeyManager:
"""获取 API Key 管理器实例""" """获取 API Key 管理器实例"""
global _api_key_manager global _api_key_manager

View File

@@ -11,6 +11,7 @@ from datetime import datetime, timedelta
from enum import Enum from enum import Enum
from typing import Any from typing import Any
class SharePermission(Enum): class SharePermission(Enum):
"""分享权限级别""" """分享权限级别"""
@@ -19,6 +20,7 @@ class SharePermission(Enum):
EDIT = "edit" # 可编辑 EDIT = "edit" # 可编辑
ADMIN = "admin" # 管理员 ADMIN = "admin" # 管理员
class CommentTargetType(Enum): class CommentTargetType(Enum):
"""评论目标类型""" """评论目标类型"""
@@ -27,6 +29,7 @@ class CommentTargetType(Enum):
TRANSCRIPT = "transcript" # 转录文本评论 TRANSCRIPT = "transcript" # 转录文本评论
PROJECT = "project" # 项目级评论 PROJECT = "project" # 项目级评论
class ChangeType(Enum): class ChangeType(Enum):
"""变更类型""" """变更类型"""
@@ -36,6 +39,7 @@ class ChangeType(Enum):
MERGE = "merge" # 合并 MERGE = "merge" # 合并
SPLIT = "split" # 拆分 SPLIT = "split" # 拆分
@dataclass @dataclass
class ProjectShare: class ProjectShare:
"""项目分享链接""" """项目分享链接"""
@@ -54,6 +58,7 @@ class ProjectShare:
allow_download: bool # 允许下载 allow_download: bool # 允许下载
allow_export: bool # 允许导出 allow_export: bool # 允许导出
@dataclass @dataclass
class Comment: class Comment:
"""评论/批注""" """评论/批注"""
@@ -74,6 +79,7 @@ class Comment:
mentions: list[str] # 提及的用户 mentions: list[str] # 提及的用户
attachments: list[dict] # 附件 attachments: list[dict] # 附件
@dataclass @dataclass
class ChangeRecord: class ChangeRecord:
"""变更记录""" """变更记录"""
@@ -95,6 +101,7 @@ class ChangeRecord:
reverted_at: str | None # 回滚时间 reverted_at: str | None # 回滚时间
reverted_by: str | None # 回滚者 reverted_by: str | None # 回滚者
@dataclass @dataclass
class TeamMember: class TeamMember:
"""团队成员""" """团队成员"""
@@ -110,6 +117,7 @@ class TeamMember:
last_active_at: str | None # 最后活跃时间 last_active_at: str | None # 最后活跃时间
permissions: list[str] # 具体权限列表 permissions: list[str] # 具体权限列表
@dataclass @dataclass
class TeamSpace: class TeamSpace:
"""团队空间""" """团队空间"""
@@ -124,6 +132,7 @@ class TeamSpace:
project_count: int project_count: int
settings: dict[str, Any] # 团队设置 settings: dict[str, Any] # 团队设置
class CollaborationManager: class CollaborationManager:
"""协作管理主类""" """协作管理主类"""
@@ -425,7 +434,9 @@ class CollaborationManager:
) )
self.db.conn.commit() self.db.conn.commit()
def get_comments(self, target_type: str, target_id: str, include_resolved: bool = True) -> list[Comment]: def get_comments(
self, target_type: str, target_id: str, include_resolved: bool = True
) -> list[Comment]:
"""获取评论列表""" """获取评论列表"""
if not self.db: if not self.db:
return [] return []
@@ -542,7 +553,9 @@ class CollaborationManager:
self.db.conn.commit() self.db.conn.commit()
return cursor.rowcount > 0 return cursor.rowcount > 0
def get_project_comments(self, project_id: str, limit: int = 50, offset: int = 0) -> list[Comment]: def get_project_comments(
self, project_id: str, limit: int = 50, offset: int = 0
) -> list[Comment]:
"""获取项目下的所有评论""" """获取项目下的所有评论"""
if not self.db: if not self.db:
return [] return []
@@ -978,9 +991,11 @@ class CollaborationManager:
) )
self.db.conn.commit() self.db.conn.commit()
# 全局协作管理器实例 # 全局协作管理器实例
_collaboration_manager = None _collaboration_manager = None
def get_collaboration_manager(db_manager=None) -> None: def get_collaboration_manager(db_manager=None) -> None:
"""获取协作管理器单例""" """获取协作管理器单例"""
global _collaboration_manager global _collaboration_manager

View File

@@ -14,6 +14,7 @@ from datetime import datetime
DB_PATH = os.getenv("DB_PATH", "/app/data/insightflow.db") DB_PATH = os.getenv("DB_PATH", "/app/data/insightflow.db")
@dataclass @dataclass
class Project: class Project:
id: str id: str
@@ -22,6 +23,7 @@ class Project:
created_at: str = "" created_at: str = ""
updated_at: str = "" updated_at: str = ""
@dataclass @dataclass
class Entity: class Entity:
id: str id: str
@@ -42,6 +44,7 @@ class Entity:
if self.attributes is None: if self.attributes is None:
self.attributes = {} self.attributes = {}
@dataclass @dataclass
class AttributeTemplate: class AttributeTemplate:
"""属性模板定义""" """属性模板定义"""
@@ -62,6 +65,7 @@ class AttributeTemplate:
if self.options is None: if self.options is None:
self.options = [] self.options = []
@dataclass @dataclass
class EntityAttribute: class EntityAttribute:
"""实体属性值""" """实体属性值"""
@@ -82,6 +86,7 @@ class EntityAttribute:
if self.options is None: if self.options is None:
self.options = [] self.options = []
@dataclass @dataclass
class AttributeHistory: class AttributeHistory:
"""属性变更历史""" """属性变更历史"""
@@ -95,6 +100,7 @@ class AttributeHistory:
changed_at: str = "" changed_at: str = ""
change_reason: str = "" change_reason: str = ""
@dataclass @dataclass
class EntityMention: class EntityMention:
id: str id: str
@@ -105,6 +111,7 @@ class EntityMention:
text_snippet: str text_snippet: str
confidence: float = 1.0 confidence: float = 1.0
class DatabaseManager: class DatabaseManager:
def __init__(self, db_path: str = DB_PATH): def __init__(self, db_path: str = DB_PATH):
self.db_path = db_path self.db_path = db_path
@@ -137,7 +144,9 @@ class DatabaseManager:
) )
conn.commit() conn.commit()
conn.close() conn.close()
return Project(id=project_id, name=name, description=description, created_at=now, updated_at=now) return Project(
id=project_id, name=name, description=description, created_at=now, updated_at=now
)
def get_project(self, project_id: str) -> Project | None: def get_project(self, project_id: str) -> Project | None:
conn = self.get_conn() conn = self.get_conn()
@@ -190,7 +199,9 @@ class DatabaseManager:
return Entity(**data) return Entity(**data)
return None return None
def find_similar_entities(self, project_id: str, name: str, threshold: float = 0.8) -> list[Entity]: def find_similar_entities(
self, project_id: str, name: str, threshold: float = 0.8
) -> list[Entity]:
"""查找相似实体""" """查找相似实体"""
conn = self.get_conn() conn = self.get_conn()
rows = conn.execute( rows = conn.execute(
@@ -224,12 +235,16 @@ class DatabaseManager:
"UPDATE entities SET aliases = ?, updated_at = ? WHERE id = ?", "UPDATE entities SET aliases = ?, updated_at = ? WHERE id = ?",
(json.dumps(list(target_aliases)), datetime.now().isoformat(), target_id), (json.dumps(list(target_aliases)), datetime.now().isoformat(), target_id),
) )
conn.execute("UPDATE entity_mentions SET entity_id = ? WHERE entity_id = ?", (target_id, source_id))
conn.execute( conn.execute(
"UPDATE entity_relations SET source_entity_id = ? WHERE source_entity_id = ?", (target_id, source_id) "UPDATE entity_mentions SET entity_id = ? WHERE entity_id = ?", (target_id, source_id)
) )
conn.execute( conn.execute(
"UPDATE entity_relations SET target_entity_id = ? WHERE target_entity_id = ?", (target_id, source_id) "UPDATE entity_relations SET source_entity_id = ? WHERE source_entity_id = ?",
(target_id, source_id),
)
conn.execute(
"UPDATE entity_relations SET target_entity_id = ? WHERE target_entity_id = ?",
(target_id, source_id),
) )
conn.execute("DELETE FROM entities WHERE id = ?", (source_id,)) conn.execute("DELETE FROM entities WHERE id = ?", (source_id,))
@@ -297,7 +312,8 @@ class DatabaseManager:
conn = self.get_conn() conn = self.get_conn()
conn.execute("DELETE FROM entity_mentions WHERE entity_id = ?", (entity_id,)) conn.execute("DELETE FROM entity_mentions WHERE entity_id = ?", (entity_id,))
conn.execute( conn.execute(
"DELETE FROM entity_relations WHERE source_entity_id = ? OR target_entity_id = ?", (entity_id, entity_id) "DELETE FROM entity_relations WHERE source_entity_id = ? OR target_entity_id = ?",
(entity_id, entity_id),
) )
conn.execute("DELETE FROM entity_attributes WHERE entity_id = ?", (entity_id,)) conn.execute("DELETE FROM entity_attributes WHERE entity_id = ?", (entity_id,))
conn.execute("DELETE FROM entities WHERE id = ?", (entity_id,)) conn.execute("DELETE FROM entities WHERE id = ?", (entity_id,))
@@ -328,7 +344,8 @@ class DatabaseManager:
def get_entity_mentions(self, entity_id: str) -> list[EntityMention]: def get_entity_mentions(self, entity_id: str) -> list[EntityMention]:
conn = self.get_conn() conn = self.get_conn()
rows = conn.execute( rows = conn.execute(
"SELECT * FROM entity_mentions WHERE entity_id = ? ORDER BY transcript_id, start_pos", (entity_id,) "SELECT * FROM entity_mentions WHERE entity_id = ? ORDER BY transcript_id, start_pos",
(entity_id,),
).fetchall() ).fetchall()
conn.close() conn.close()
return [EntityMention(**dict(r)) for r in rows] return [EntityMention(**dict(r)) for r in rows]
@@ -336,7 +353,12 @@ class DatabaseManager:
# ==================== Transcript Operations ==================== # ==================== Transcript Operations ====================
def save_transcript( def save_transcript(
self, transcript_id: str, project_id: str, filename: str, full_text: str, transcript_type: str = "audio" self,
transcript_id: str,
project_id: str,
filename: str,
full_text: str,
transcript_type: str = "audio",
): ):
conn = self.get_conn() conn = self.get_conn()
now = datetime.now().isoformat() now = datetime.now().isoformat()
@@ -365,7 +387,8 @@ class DatabaseManager:
conn = self.get_conn() conn = self.get_conn()
now = datetime.now().isoformat() now = datetime.now().isoformat()
conn.execute( conn.execute(
"UPDATE transcripts SET full_text = ?, updated_at = ? WHERE id = ?", (full_text, now, transcript_id) "UPDATE transcripts SET full_text = ?, updated_at = ? WHERE id = ?",
(full_text, now, transcript_id),
) )
conn.commit() conn.commit()
row = conn.execute("SELECT * FROM transcripts WHERE id = ?", (transcript_id,)).fetchone() row = conn.execute("SELECT * FROM transcripts WHERE id = ?", (transcript_id,)).fetchone()
@@ -390,7 +413,16 @@ class DatabaseManager:
"""INSERT INTO entity_relations """INSERT INTO entity_relations
(id, project_id, source_entity_id, target_entity_id, relation_type, evidence, transcript_id, created_at) (id, project_id, source_entity_id, target_entity_id, relation_type, evidence, transcript_id, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)""", VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
(relation_id, project_id, source_entity_id, target_entity_id, relation_type, evidence, transcript_id, now), (
relation_id,
project_id,
source_entity_id,
target_entity_id,
relation_type,
evidence,
transcript_id,
now,
),
) )
conn.commit() conn.commit()
conn.close() conn.close()
@@ -410,7 +442,8 @@ class DatabaseManager:
def list_project_relations(self, project_id: str) -> list[dict]: def list_project_relations(self, project_id: str) -> list[dict]:
conn = self.get_conn() conn = self.get_conn()
rows = conn.execute( rows = conn.execute(
"SELECT * FROM entity_relations WHERE project_id = ? ORDER BY created_at DESC", (project_id,) "SELECT * FROM entity_relations WHERE project_id = ? ORDER BY created_at DESC",
(project_id,),
).fetchall() ).fetchall()
conn.close() conn.close()
return [dict(r) for r in rows] return [dict(r) for r in rows]
@@ -451,7 +484,9 @@ class DatabaseManager:
).fetchone() ).fetchone()
if existing: if existing:
conn.execute("UPDATE glossary SET frequency = frequency + 1 WHERE id = ?", (existing["id"],)) conn.execute(
"UPDATE glossary SET frequency = frequency + 1 WHERE id = ?", (existing["id"],)
)
conn.commit() conn.commit()
conn.close() conn.close()
return existing["id"] return existing["id"]
@@ -593,9 +628,13 @@ class DatabaseManager:
"top_entities": [dict(e) for e in top_entities], "top_entities": [dict(e) for e in top_entities],
} }
def get_transcript_context(self, transcript_id: str, position: int, context_chars: int = 200) -> str: def get_transcript_context(
self, transcript_id: str, position: int, context_chars: int = 200
) -> str:
conn = self.get_conn() conn = self.get_conn()
row = conn.execute("SELECT full_text FROM transcripts WHERE id = ?", (transcript_id,)).fetchone() row = conn.execute(
"SELECT full_text FROM transcripts WHERE id = ?", (transcript_id,)
).fetchone()
conn.close() conn.close()
if not row: if not row:
return "" return ""
@@ -685,7 +724,10 @@ class DatabaseManager:
conn.close() conn.close()
return {"daily_activity": [dict(d) for d in daily_stats], "top_entities": [dict(e) for e in entity_stats]} return {
"daily_activity": [dict(d) for d in daily_stats],
"top_entities": [dict(e) for e in entity_stats],
}
# ==================== Phase 5: Entity Attributes ==================== # ==================== Phase 5: Entity Attributes ====================
@@ -716,7 +758,9 @@ class DatabaseManager:
def get_attribute_template(self, template_id: str) -> AttributeTemplate | None: def get_attribute_template(self, template_id: str) -> AttributeTemplate | None:
conn = self.get_conn() conn = self.get_conn()
row = conn.execute("SELECT * FROM attribute_templates WHERE id = ?", (template_id,)).fetchone() row = conn.execute(
"SELECT * FROM attribute_templates WHERE id = ?", (template_id,)
).fetchone()
conn.close() conn.close()
if row: if row:
data = dict(row) data = dict(row)
@@ -742,7 +786,15 @@ class DatabaseManager:
def update_attribute_template(self, template_id: str, **kwargs) -> AttributeTemplate | None: def update_attribute_template(self, template_id: str, **kwargs) -> AttributeTemplate | None:
conn = self.get_conn() conn = self.get_conn()
allowed_fields = ["name", "type", "options", "default_value", "description", "is_required", "sort_order"] allowed_fields = [
"name",
"type",
"options",
"default_value",
"description",
"is_required",
"sort_order",
]
updates = [] updates = []
values = [] values = []
@@ -844,7 +896,11 @@ class DatabaseManager:
return None return None
attrs = self.get_entity_attributes(entity_id) attrs = self.get_entity_attributes(entity_id)
entity.attributes = { entity.attributes = {
attr.template_name: {"value": attr.value, "type": attr.template_type, "template_id": attr.template_id} attr.template_name: {
"value": attr.value,
"type": attr.template_type,
"template_id": attr.template_id,
}
for attr in attrs for attr in attrs
} }
return entity return entity
@@ -854,7 +910,8 @@ class DatabaseManager:
): ):
conn = self.get_conn() conn = self.get_conn()
old_row = conn.execute( old_row = conn.execute(
"SELECT value FROM entity_attributes WHERE entity_id = ? AND template_id = ?", (entity_id, template_id) "SELECT value FROM entity_attributes WHERE entity_id = ? AND template_id = ?",
(entity_id, template_id),
).fetchone() ).fetchone()
if old_row: if old_row:
@@ -874,7 +931,8 @@ class DatabaseManager:
), ),
) )
conn.execute( conn.execute(
"DELETE FROM entity_attributes WHERE entity_id = ? AND template_id = ?", (entity_id, template_id) "DELETE FROM entity_attributes WHERE entity_id = ? AND template_id = ?",
(entity_id, template_id),
) )
conn.commit() conn.commit()
conn.close() conn.close()
@@ -905,7 +963,9 @@ class DatabaseManager:
conn.close() conn.close()
return [AttributeHistory(**dict(r)) for r in rows] return [AttributeHistory(**dict(r)) for r in rows]
def search_entities_by_attributes(self, project_id: str, attribute_filters: dict[str, str]) -> list[Entity]: def search_entities_by_attributes(
self, project_id: str, attribute_filters: dict[str, str]
) -> list[Entity]:
entities = self.list_project_entities(project_id) entities = self.list_project_entities(project_id)
if not attribute_filters: if not attribute_filters:
return entities return entities
@@ -999,8 +1059,12 @@ class DatabaseManager:
if row: if row:
data = dict(row) data = dict(row)
data["resolution"] = json.loads(data["resolution"]) if data["resolution"] else None data["resolution"] = json.loads(data["resolution"]) if data["resolution"] else None
data["extracted_entities"] = json.loads(data["extracted_entities"]) if data["extracted_entities"] else [] data["extracted_entities"] = (
data["extracted_relations"] = json.loads(data["extracted_relations"]) if data["extracted_relations"] else [] json.loads(data["extracted_entities"]) if data["extracted_entities"] else []
)
data["extracted_relations"] = (
json.loads(data["extracted_relations"]) if data["extracted_relations"] else []
)
return data return data
return None return None
@@ -1016,8 +1080,12 @@ class DatabaseManager:
for row in rows: for row in rows:
data = dict(row) data = dict(row)
data["resolution"] = json.loads(data["resolution"]) if data["resolution"] else None data["resolution"] = json.loads(data["resolution"]) if data["resolution"] else None
data["extracted_entities"] = json.loads(data["extracted_entities"]) if data["extracted_entities"] else [] data["extracted_entities"] = (
data["extracted_relations"] = json.loads(data["extracted_relations"]) if data["extracted_relations"] else [] json.loads(data["extracted_entities"]) if data["extracted_entities"] else []
)
data["extracted_relations"] = (
json.loads(data["extracted_relations"]) if data["extracted_relations"] else []
)
videos.append(data) videos.append(data)
return videos return videos
@@ -1065,7 +1133,9 @@ class DatabaseManager:
frames = [] frames = []
for row in rows: for row in rows:
data = dict(row) data = dict(row)
data["extracted_entities"] = json.loads(data["extracted_entities"]) if data["extracted_entities"] else [] data["extracted_entities"] = (
json.loads(data["extracted_entities"]) if data["extracted_entities"] else []
)
frames.append(data) frames.append(data)
return frames return frames
@@ -1113,8 +1183,12 @@ class DatabaseManager:
if row: if row:
data = dict(row) data = dict(row)
data["extracted_entities"] = json.loads(data["extracted_entities"]) if data["extracted_entities"] else [] data["extracted_entities"] = (
data["extracted_relations"] = json.loads(data["extracted_relations"]) if data["extracted_relations"] else [] json.loads(data["extracted_entities"]) if data["extracted_entities"] else []
)
data["extracted_relations"] = (
json.loads(data["extracted_relations"]) if data["extracted_relations"] else []
)
return data return data
return None return None
@@ -1129,8 +1203,12 @@ class DatabaseManager:
images = [] images = []
for row in rows: for row in rows:
data = dict(row) data = dict(row)
data["extracted_entities"] = json.loads(data["extracted_entities"]) if data["extracted_entities"] else [] data["extracted_entities"] = (
data["extracted_relations"] = json.loads(data["extracted_relations"]) if data["extracted_relations"] else [] json.loads(data["extracted_entities"]) if data["extracted_entities"] else []
)
data["extracted_relations"] = (
json.loads(data["extracted_relations"]) if data["extracted_relations"] else []
)
images.append(data) images.append(data)
return images return images
@@ -1154,7 +1232,17 @@ class DatabaseManager:
(id, project_id, entity_id, modality, source_id, source_type, (id, project_id, entity_id, modality, source_id, source_type,
text_snippet, confidence, created_at) text_snippet, confidence, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)""", VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)""",
(mention_id, project_id, entity_id, modality, source_id, source_type, text_snippet, confidence, now), (
mention_id,
project_id,
entity_id,
modality,
source_id,
source_type,
text_snippet,
confidence,
now,
),
) )
conn.commit() conn.commit()
conn.close() conn.close()
@@ -1217,7 +1305,16 @@ class DatabaseManager:
(id, entity_id, linked_entity_id, link_type, confidence, (id, entity_id, linked_entity_id, link_type, confidence,
evidence, modalities, created_at) evidence, modalities, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)""", VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
(link_id, entity_id, linked_entity_id, link_type, confidence, evidence, json.dumps(modalities or []), now), (
link_id,
entity_id,
linked_entity_id,
link_type,
confidence,
evidence,
json.dumps(modalities or []),
now,
),
) )
conn.commit() conn.commit()
conn.close() conn.close()
@@ -1256,11 +1353,15 @@ class DatabaseManager:
} }
# 视频数量 # 视频数量
row = conn.execute("SELECT COUNT(*) as count FROM videos WHERE project_id = ?", (project_id,)).fetchone() row = conn.execute(
"SELECT COUNT(*) as count FROM videos WHERE project_id = ?", (project_id,)
).fetchone()
stats["video_count"] = row["count"] stats["video_count"] = row["count"]
# 图片数量 # 图片数量
row = conn.execute("SELECT COUNT(*) as count FROM images WHERE project_id = ?", (project_id,)).fetchone() row = conn.execute(
"SELECT COUNT(*) as count FROM images WHERE project_id = ?", (project_id,)
).fetchone()
stats["image_count"] = row["count"] stats["image_count"] = row["count"]
# 多模态实体数量 # 多模态实体数量
@@ -1291,9 +1392,11 @@ class DatabaseManager:
conn.close() conn.close()
return stats return stats
# Singleton instance # Singleton instance
_db_manager = None _db_manager = None
def get_db_manager() -> DatabaseManager: def get_db_manager() -> DatabaseManager:
global _db_manager global _db_manager
if _db_manager is None: if _db_manager is None:

View File

@@ -21,6 +21,7 @@ from enum import StrEnum
# Database path # Database path
DB_PATH = os.path.join(os.path.dirname(__file__), "insightflow.db") DB_PATH = os.path.join(os.path.dirname(__file__), "insightflow.db")
class SDKLanguage(StrEnum): class SDKLanguage(StrEnum):
"""SDK 语言类型""" """SDK 语言类型"""
@@ -31,6 +32,7 @@ class SDKLanguage(StrEnum):
JAVA = "java" JAVA = "java"
RUST = "rust" RUST = "rust"
class SDKStatus(StrEnum): class SDKStatus(StrEnum):
"""SDK 状态""" """SDK 状态"""
@@ -40,6 +42,7 @@ class SDKStatus(StrEnum):
DEPRECATED = "deprecated" # 已弃用 DEPRECATED = "deprecated" # 已弃用
ARCHIVED = "archived" # 已归档 ARCHIVED = "archived" # 已归档
class TemplateCategory(StrEnum): class TemplateCategory(StrEnum):
"""模板分类""" """模板分类"""
@@ -50,6 +53,7 @@ class TemplateCategory(StrEnum):
TECH = "tech" # 科技 TECH = "tech" # 科技
GENERAL = "general" # 通用 GENERAL = "general" # 通用
class TemplateStatus(StrEnum): class TemplateStatus(StrEnum):
"""模板状态""" """模板状态"""
@@ -59,6 +63,7 @@ class TemplateStatus(StrEnum):
PUBLISHED = "published" # 已发布 PUBLISHED = "published" # 已发布
UNLISTED = "unlisted" # 未列出 UNLISTED = "unlisted" # 未列出
class PluginStatus(StrEnum): class PluginStatus(StrEnum):
"""插件状态""" """插件状态"""
@@ -69,6 +74,7 @@ class PluginStatus(StrEnum):
PUBLISHED = "published" # 已发布 PUBLISHED = "published" # 已发布
SUSPENDED = "suspended" # 已暂停 SUSPENDED = "suspended" # 已暂停
class PluginCategory(StrEnum): class PluginCategory(StrEnum):
"""插件分类""" """插件分类"""
@@ -79,6 +85,7 @@ class PluginCategory(StrEnum):
SECURITY = "security" # 安全 SECURITY = "security" # 安全
CUSTOM = "custom" # 自定义 CUSTOM = "custom" # 自定义
class DeveloperStatus(StrEnum): class DeveloperStatus(StrEnum):
"""开发者认证状态""" """开发者认证状态"""
@@ -88,6 +95,7 @@ class DeveloperStatus(StrEnum):
CERTIFIED = "certified" # 已认证(高级) CERTIFIED = "certified" # 已认证(高级)
SUSPENDED = "suspended" # 已暂停 SUSPENDED = "suspended" # 已暂停
@dataclass @dataclass
class SDKRelease: class SDKRelease:
"""SDK 发布""" """SDK 发布"""
@@ -113,6 +121,7 @@ class SDKRelease:
published_at: str | None published_at: str | None
created_by: str created_by: str
@dataclass @dataclass
class SDKVersion: class SDKVersion:
"""SDK 版本历史""" """SDK 版本历史"""
@@ -129,6 +138,7 @@ class SDKVersion:
download_count: int download_count: int
created_at: str created_at: str
@dataclass @dataclass
class TemplateMarketItem: class TemplateMarketItem:
"""模板市场项目""" """模板市场项目"""
@@ -160,6 +170,7 @@ class TemplateMarketItem:
updated_at: str updated_at: str
published_at: str | None published_at: str | None
@dataclass @dataclass
class TemplateReview: class TemplateReview:
"""模板评价""" """模板评价"""
@@ -175,6 +186,7 @@ class TemplateReview:
created_at: str created_at: str
updated_at: str updated_at: str
@dataclass @dataclass
class PluginMarketItem: class PluginMarketItem:
"""插件市场项目""" """插件市场项目"""
@@ -213,6 +225,7 @@ class PluginMarketItem:
reviewed_at: str | None reviewed_at: str | None
review_notes: str | None review_notes: str | None
@dataclass @dataclass
class PluginReview: class PluginReview:
"""插件评价""" """插件评价"""
@@ -228,6 +241,7 @@ class PluginReview:
created_at: str created_at: str
updated_at: str updated_at: str
@dataclass @dataclass
class DeveloperProfile: class DeveloperProfile:
"""开发者档案""" """开发者档案"""
@@ -251,6 +265,7 @@ class DeveloperProfile:
updated_at: str updated_at: str
verified_at: str | None verified_at: str | None
@dataclass @dataclass
class DeveloperRevenue: class DeveloperRevenue:
"""开发者收益""" """开发者收益"""
@@ -268,6 +283,7 @@ class DeveloperRevenue:
transaction_id: str transaction_id: str
created_at: str created_at: str
@dataclass @dataclass
class CodeExample: class CodeExample:
"""代码示例""" """代码示例"""
@@ -290,6 +306,7 @@ class CodeExample:
created_at: str created_at: str
updated_at: str updated_at: str
@dataclass @dataclass
class APIDocumentation: class APIDocumentation:
"""API 文档生成记录""" """API 文档生成记录"""
@@ -303,6 +320,7 @@ class APIDocumentation:
generated_at: str generated_at: str
generated_by: str generated_by: str
@dataclass @dataclass
class DeveloperPortalConfig: class DeveloperPortalConfig:
"""开发者门户配置""" """开发者门户配置"""
@@ -326,6 +344,7 @@ class DeveloperPortalConfig:
created_at: str created_at: str
updated_at: str updated_at: str
class DeveloperEcosystemManager: class DeveloperEcosystemManager:
"""开发者生态系统管理主类""" """开发者生态系统管理主类"""
@@ -432,7 +451,10 @@ class DeveloperEcosystemManager:
return None return None
def list_sdk_releases( def list_sdk_releases(
self, language: SDKLanguage | None = None, status: SDKStatus | None = None, search: str | None = None self,
language: SDKLanguage | None = None,
status: SDKStatus | None = None,
search: str | None = None,
) -> list[SDKRelease]: ) -> list[SDKRelease]:
"""列出 SDK 发布""" """列出 SDK 发布"""
query = "SELECT * FROM sdk_releases WHERE 1=1" query = "SELECT * FROM sdk_releases WHERE 1=1"
@@ -474,7 +496,10 @@ class DeveloperEcosystemManager:
with self._get_db() as conn: with self._get_db() as conn:
set_clause = ", ".join([f"{k} = ?" for k in updates.keys()]) set_clause = ", ".join([f"{k} = ?" for k in updates.keys()])
conn.execute(f"UPDATE sdk_releases SET {set_clause} WHERE id = ?", list(updates.values()) + [sdk_id]) conn.execute(
f"UPDATE sdk_releases SET {set_clause} WHERE id = ?",
list(updates.values()) + [sdk_id],
)
conn.commit() conn.commit()
return self.get_sdk_release(sdk_id) return self.get_sdk_release(sdk_id)
@@ -543,7 +568,19 @@ class DeveloperEcosystemManager:
checksum, file_size, download_count, created_at) checksum, file_size, download_count, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""", """,
(version_id, sdk_id, version, True, is_lts, release_notes, download_url, checksum, file_size, 0, now), (
version_id,
sdk_id,
version,
True,
is_lts,
release_notes,
download_url,
checksum,
file_size,
0,
now,
),
) )
conn.commit() conn.commit()
@@ -662,7 +699,9 @@ class DeveloperEcosystemManager:
def get_template(self, template_id: str) -> TemplateMarketItem | None: def get_template(self, template_id: str) -> TemplateMarketItem | None:
"""获取模板详情""" """获取模板详情"""
with self._get_db() as conn: with self._get_db() as conn:
row = conn.execute("SELECT * FROM template_market WHERE id = ?", (template_id,)).fetchone() row = conn.execute(
"SELECT * FROM template_market WHERE id = ?", (template_id,)
).fetchone()
if row: if row:
return self._row_to_template(row) return self._row_to_template(row)
@@ -851,7 +890,12 @@ class DeveloperEcosystemManager:
SET rating = ?, rating_count = ?, review_count = ? SET rating = ?, rating_count = ?, review_count = ?
WHERE id = ? WHERE id = ?
""", """,
(round(row["avg_rating"], 2) if row["avg_rating"] else 0, row["count"], row["count"], template_id), (
round(row["avg_rating"], 2) if row["avg_rating"] else 0,
row["count"],
row["count"],
template_id,
),
) )
def get_template_reviews(self, template_id: str, limit: int = 50) -> list[TemplateReview]: def get_template_reviews(self, template_id: str, limit: int = 50) -> list[TemplateReview]:
@@ -1159,7 +1203,12 @@ class DeveloperEcosystemManager:
SET rating = ?, rating_count = ?, review_count = ? SET rating = ?, rating_count = ?, review_count = ?
WHERE id = ? WHERE id = ?
""", """,
(round(row["avg_rating"], 2) if row["avg_rating"] else 0, row["count"], row["count"], plugin_id), (
round(row["avg_rating"], 2) if row["avg_rating"] else 0,
row["count"],
row["count"],
plugin_id,
),
) )
def get_plugin_reviews(self, plugin_id: str, limit: int = 50) -> list[PluginReview]: def get_plugin_reviews(self, plugin_id: str, limit: int = 50) -> list[PluginReview]:
@@ -1248,7 +1297,10 @@ class DeveloperEcosystemManager:
return revenue return revenue
def get_developer_revenues( def get_developer_revenues(
self, developer_id: str, start_date: datetime | None = None, end_date: datetime | None = None self,
developer_id: str,
start_date: datetime | None = None,
end_date: datetime | None = None,
) -> list[DeveloperRevenue]: ) -> list[DeveloperRevenue]:
"""获取开发者收益记录""" """获取开发者收益记录"""
query = "SELECT * FROM developer_revenues WHERE developer_id = ?" query = "SELECT * FROM developer_revenues WHERE developer_id = ?"
@@ -1365,7 +1417,9 @@ class DeveloperEcosystemManager:
def get_developer_profile(self, developer_id: str) -> DeveloperProfile | None: def get_developer_profile(self, developer_id: str) -> DeveloperProfile | None:
"""获取开发者档案""" """获取开发者档案"""
with self._get_db() as conn: with self._get_db() as conn:
row = conn.execute("SELECT * FROM developer_profiles WHERE id = ?", (developer_id,)).fetchone() row = conn.execute(
"SELECT * FROM developer_profiles WHERE id = ?", (developer_id,)
).fetchone()
if row: if row:
return self._row_to_developer_profile(row) return self._row_to_developer_profile(row)
@@ -1374,13 +1428,17 @@ class DeveloperEcosystemManager:
def get_developer_profile_by_user(self, user_id: str) -> DeveloperProfile | None: def get_developer_profile_by_user(self, user_id: str) -> DeveloperProfile | None:
"""通过用户 ID 获取开发者档案""" """通过用户 ID 获取开发者档案"""
with self._get_db() as conn: with self._get_db() as conn:
row = conn.execute("SELECT * FROM developer_profiles WHERE user_id = ?", (user_id,)).fetchone() row = conn.execute(
"SELECT * FROM developer_profiles WHERE user_id = ?", (user_id,)
).fetchone()
if row: if row:
return self._row_to_developer_profile(row) return self._row_to_developer_profile(row)
return None return None
def verify_developer(self, developer_id: str, status: DeveloperStatus) -> DeveloperProfile | None: def verify_developer(
self, developer_id: str, status: DeveloperStatus
) -> DeveloperProfile | None:
"""验证开发者""" """验证开发者"""
now = datetime.now().isoformat() now = datetime.now().isoformat()
@@ -1393,7 +1451,9 @@ class DeveloperEcosystemManager:
""", """,
( (
status.value, status.value,
now if status in [DeveloperStatus.VERIFIED, DeveloperStatus.CERTIFIED] else None, now
if status in [DeveloperStatus.VERIFIED, DeveloperStatus.CERTIFIED]
else None,
now, now,
developer_id, developer_id,
), ),
@@ -1642,7 +1702,9 @@ class DeveloperEcosystemManager:
def get_latest_api_documentation(self) -> APIDocumentation | None: def get_latest_api_documentation(self) -> APIDocumentation | None:
"""获取最新 API 文档""" """获取最新 API 文档"""
with self._get_db() as conn: with self._get_db() as conn:
row = conn.execute("SELECT * FROM api_documentation ORDER BY generated_at DESC LIMIT 1").fetchone() row = conn.execute(
"SELECT * FROM api_documentation ORDER BY generated_at DESC LIMIT 1"
).fetchone()
if row: if row:
return self._row_to_api_documentation(row) return self._row_to_api_documentation(row)
@@ -1729,7 +1791,9 @@ class DeveloperEcosystemManager:
def get_portal_config(self, config_id: str) -> DeveloperPortalConfig | None: def get_portal_config(self, config_id: str) -> DeveloperPortalConfig | None:
"""获取开发者门户配置""" """获取开发者门户配置"""
with self._get_db() as conn: with self._get_db() as conn:
row = conn.execute("SELECT * FROM developer_portal_configs WHERE id = ?", (config_id,)).fetchone() row = conn.execute(
"SELECT * FROM developer_portal_configs WHERE id = ?", (config_id,)
).fetchone()
if row: if row:
return self._row_to_portal_config(row) return self._row_to_portal_config(row)
@@ -1738,7 +1802,9 @@ class DeveloperEcosystemManager:
def get_active_portal_config(self) -> DeveloperPortalConfig | None: def get_active_portal_config(self) -> DeveloperPortalConfig | None:
"""获取活跃的开发者门户配置""" """获取活跃的开发者门户配置"""
with self._get_db() as conn: with self._get_db() as conn:
row = conn.execute("SELECT * FROM developer_portal_configs WHERE is_active = 1 LIMIT 1").fetchone() row = conn.execute(
"SELECT * FROM developer_portal_configs WHERE is_active = 1 LIMIT 1"
).fetchone()
if row: if row:
return self._row_to_portal_config(row) return self._row_to_portal_config(row)
@@ -1984,9 +2050,11 @@ class DeveloperEcosystemManager:
updated_at=row["updated_at"], updated_at=row["updated_at"],
) )
# Singleton instance # Singleton instance
_developer_ecosystem_manager = None _developer_ecosystem_manager = None
def get_developer_ecosystem_manager() -> DeveloperEcosystemManager: def get_developer_ecosystem_manager() -> DeveloperEcosystemManager:
"""获取开发者生态系统管理器单例""" """获取开发者生态系统管理器单例"""
global _developer_ecosystem_manager global _developer_ecosystem_manager

View File

@@ -7,6 +7,7 @@ Document Processor - Phase 3
import io import io
import os import os
class DocumentProcessor: class DocumentProcessor:
"""文档处理器 - 提取 PDF/DOCX 文本""" """文档处理器 - 提取 PDF/DOCX 文本"""
@@ -33,7 +34,9 @@ class DocumentProcessor:
ext = os.path.splitext(filename.lower())[1] ext = os.path.splitext(filename.lower())[1]
if ext not in self.supported_formats: if ext not in self.supported_formats:
raise ValueError(f"Unsupported file format: {ext}. Supported: {list(self.supported_formats.keys())}") raise ValueError(
f"Unsupported file format: {ext}. Supported: {list(self.supported_formats.keys())}"
)
extractor = self.supported_formats[ext] extractor = self.supported_formats[ext]
text = extractor(content) text = extractor(content)
@@ -71,7 +74,9 @@ class DocumentProcessor:
text_parts.append(page_text) text_parts.append(page_text)
return "\n\n".join(text_parts) return "\n\n".join(text_parts)
except ImportError: except ImportError:
raise ImportError("PDF processing requires PyPDF2 or pdfplumber. Install with: pip install PyPDF2") raise ImportError(
"PDF processing requires PyPDF2 or pdfplumber. Install with: pip install PyPDF2"
)
except Exception as e: except Exception as e:
raise ValueError(f"PDF extraction failed: {str(e)}") raise ValueError(f"PDF extraction failed: {str(e)}")
@@ -100,7 +105,9 @@ class DocumentProcessor:
return "\n\n".join(text_parts) return "\n\n".join(text_parts)
except ImportError: except ImportError:
raise ImportError("DOCX processing requires python-docx. Install with: pip install python-docx") raise ImportError(
"DOCX processing requires python-docx. Install with: pip install python-docx"
)
except Exception as e: except Exception as e:
raise ValueError(f"DOCX extraction failed: {str(e)}") raise ValueError(f"DOCX extraction failed: {str(e)}")
@@ -149,6 +156,7 @@ class DocumentProcessor:
ext = os.path.splitext(filename.lower())[1] ext = os.path.splitext(filename.lower())[1]
return ext in self.supported_formats return ext in self.supported_formats
# 简单的文本提取器(不需要外部依赖) # 简单的文本提取器(不需要外部依赖)
class SimpleTextExtractor: class SimpleTextExtractor:
"""简单的文本提取器,用于测试""" """简单的文本提取器,用于测试"""
@@ -165,6 +173,7 @@ class SimpleTextExtractor:
return content.decode("latin-1", errors="ignore") return content.decode("latin-1", errors="ignore")
if __name__ == "__main__": if __name__ == "__main__":
# 测试 # 测试
processor = DocumentProcessor() processor = DocumentProcessor()

View File

@@ -21,6 +21,7 @@ from typing import Any
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class SSOProvider(StrEnum): class SSOProvider(StrEnum):
"""SSO 提供商类型""" """SSO 提供商类型"""
@@ -32,6 +33,7 @@ class SSOProvider(StrEnum):
GOOGLE = "google" # Google Workspace GOOGLE = "google" # Google Workspace
CUSTOM_SAML = "custom_saml" # 自定义 SAML CUSTOM_SAML = "custom_saml" # 自定义 SAML
class SSOStatus(StrEnum): class SSOStatus(StrEnum):
"""SSO 配置状态""" """SSO 配置状态"""
@@ -40,6 +42,7 @@ class SSOStatus(StrEnum):
ACTIVE = "active" # 已启用 ACTIVE = "active" # 已启用
ERROR = "error" # 配置错误 ERROR = "error" # 配置错误
class SCIMSyncStatus(StrEnum): class SCIMSyncStatus(StrEnum):
"""SCIM 同步状态""" """SCIM 同步状态"""
@@ -48,6 +51,7 @@ class SCIMSyncStatus(StrEnum):
SUCCESS = "success" # 同步成功 SUCCESS = "success" # 同步成功
FAILED = "failed" # 同步失败 FAILED = "failed" # 同步失败
class AuditLogExportFormat(StrEnum): class AuditLogExportFormat(StrEnum):
"""审计日志导出格式""" """审计日志导出格式"""
@@ -56,6 +60,7 @@ class AuditLogExportFormat(StrEnum):
PDF = "pdf" PDF = "pdf"
XLSX = "xlsx" XLSX = "xlsx"
class DataRetentionAction(StrEnum): class DataRetentionAction(StrEnum):
"""数据保留策略动作""" """数据保留策略动作"""
@@ -63,6 +68,7 @@ class DataRetentionAction(StrEnum):
DELETE = "delete" # 删除 DELETE = "delete" # 删除
ANONYMIZE = "anonymize" # 匿名化 ANONYMIZE = "anonymize" # 匿名化
class ComplianceStandard(StrEnum): class ComplianceStandard(StrEnum):
"""合规标准""" """合规标准"""
@@ -72,6 +78,7 @@ class ComplianceStandard(StrEnum):
HIPAA = "hipaa" HIPAA = "hipaa"
PCI_DSS = "pci_dss" PCI_DSS = "pci_dss"
@dataclass @dataclass
class SSOConfig: class SSOConfig:
"""SSO 配置数据类""" """SSO 配置数据类"""
@@ -104,6 +111,7 @@ class SSOConfig:
last_tested_at: datetime | None last_tested_at: datetime | None
last_error: str | None last_error: str | None
@dataclass @dataclass
class SCIMConfig: class SCIMConfig:
"""SCIM 配置数据类""" """SCIM 配置数据类"""
@@ -128,6 +136,7 @@ class SCIMConfig:
created_at: datetime created_at: datetime
updated_at: datetime updated_at: datetime
@dataclass @dataclass
class SCIMUser: class SCIMUser:
"""SCIM 用户数据类""" """SCIM 用户数据类"""
@@ -147,6 +156,7 @@ class SCIMUser:
created_at: datetime created_at: datetime
updated_at: datetime updated_at: datetime
@dataclass @dataclass
class AuditLogExport: class AuditLogExport:
"""审计日志导出记录""" """审计日志导出记录"""
@@ -171,6 +181,7 @@ class AuditLogExport:
completed_at: datetime | None completed_at: datetime | None
error_message: str | None error_message: str | None
@dataclass @dataclass
class DataRetentionPolicy: class DataRetentionPolicy:
"""数据保留策略""" """数据保留策略"""
@@ -198,6 +209,7 @@ class DataRetentionPolicy:
created_at: datetime created_at: datetime
updated_at: datetime updated_at: datetime
@dataclass @dataclass
class DataRetentionJob: class DataRetentionJob:
"""数据保留任务""" """数据保留任务"""
@@ -215,6 +227,7 @@ class DataRetentionJob:
details: dict[str, Any] details: dict[str, Any]
created_at: datetime created_at: datetime
@dataclass @dataclass
class SAMLAuthRequest: class SAMLAuthRequest:
"""SAML 认证请求""" """SAML 认证请求"""
@@ -229,6 +242,7 @@ class SAMLAuthRequest:
used: bool used: bool
used_at: datetime | None used_at: datetime | None
@dataclass @dataclass
class SAMLAuthResponse: class SAMLAuthResponse:
"""SAML 认证响应""" """SAML 认证响应"""
@@ -245,13 +259,24 @@ class SAMLAuthResponse:
processed_at: datetime | None processed_at: datetime | None
created_at: datetime created_at: datetime
class EnterpriseManager: class EnterpriseManager:
"""企业级功能管理器""" """企业级功能管理器"""
# 默认属性映射 # 默认属性映射
DEFAULT_ATTRIBUTE_MAPPING = { DEFAULT_ATTRIBUTE_MAPPING = {
SSOProvider.WECHAT_WORK: {"email": "email", "name": "name", "department": "department", "position": "position"}, SSOProvider.WECHAT_WORK: {
SSOProvider.DINGTALK: {"email": "email", "name": "name", "department": "department", "job_title": "title"}, "email": "email",
"name": "name",
"department": "department",
"position": "position",
},
SSOProvider.DINGTALK: {
"email": "email",
"name": "name",
"department": "department",
"job_title": "title",
},
SSOProvider.FEISHU: { SSOProvider.FEISHU: {
"email": "email", "email": "email",
"name": "name", "name": "name",
@@ -505,18 +530,42 @@ class EnterpriseManager:
# 创建索引 # 创建索引
cursor.execute("CREATE INDEX IF NOT EXISTS idx_sso_tenant ON sso_configs(tenant_id)") cursor.execute("CREATE INDEX IF NOT EXISTS idx_sso_tenant ON sso_configs(tenant_id)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_sso_provider ON sso_configs(provider)") cursor.execute("CREATE INDEX IF NOT EXISTS idx_sso_provider ON sso_configs(provider)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_saml_requests_config ON saml_auth_requests(sso_config_id)") cursor.execute(
cursor.execute("CREATE INDEX IF NOT EXISTS idx_saml_requests_expires ON saml_auth_requests(expires_at)") "CREATE INDEX IF NOT EXISTS idx_saml_requests_config ON saml_auth_requests(sso_config_id)"
cursor.execute("CREATE INDEX IF NOT EXISTS idx_saml_responses_request ON saml_auth_responses(request_id)") )
cursor.execute("CREATE INDEX IF NOT EXISTS idx_scim_config_tenant ON scim_configs(tenant_id)") cursor.execute(
cursor.execute("CREATE INDEX IF NOT EXISTS idx_scim_users_tenant ON scim_users(tenant_id)") "CREATE INDEX IF NOT EXISTS idx_saml_requests_expires ON saml_auth_requests(expires_at)"
cursor.execute("CREATE INDEX IF NOT EXISTS idx_scim_users_external ON scim_users(external_id)") )
cursor.execute("CREATE INDEX IF NOT EXISTS idx_audit_export_tenant ON audit_log_exports(tenant_id)") cursor.execute(
cursor.execute("CREATE INDEX IF NOT EXISTS idx_audit_export_status ON audit_log_exports(status)") "CREATE INDEX IF NOT EXISTS idx_saml_responses_request ON saml_auth_responses(request_id)"
cursor.execute("CREATE INDEX IF NOT EXISTS idx_retention_tenant ON data_retention_policies(tenant_id)") )
cursor.execute("CREATE INDEX IF NOT EXISTS idx_retention_type ON data_retention_policies(resource_type)") cursor.execute(
cursor.execute("CREATE INDEX IF NOT EXISTS idx_retention_jobs_policy ON data_retention_jobs(policy_id)") "CREATE INDEX IF NOT EXISTS idx_scim_config_tenant ON scim_configs(tenant_id)"
cursor.execute("CREATE INDEX IF NOT EXISTS idx_retention_jobs_status ON data_retention_jobs(status)") )
cursor.execute(
"CREATE INDEX IF NOT EXISTS idx_scim_users_tenant ON scim_users(tenant_id)"
)
cursor.execute(
"CREATE INDEX IF NOT EXISTS idx_scim_users_external ON scim_users(external_id)"
)
cursor.execute(
"CREATE INDEX IF NOT EXISTS idx_audit_export_tenant ON audit_log_exports(tenant_id)"
)
cursor.execute(
"CREATE INDEX IF NOT EXISTS idx_audit_export_status ON audit_log_exports(status)"
)
cursor.execute(
"CREATE INDEX IF NOT EXISTS idx_retention_tenant ON data_retention_policies(tenant_id)"
)
cursor.execute(
"CREATE INDEX IF NOT EXISTS idx_retention_type ON data_retention_policies(resource_type)"
)
cursor.execute(
"CREATE INDEX IF NOT EXISTS idx_retention_jobs_policy ON data_retention_jobs(policy_id)"
)
cursor.execute(
"CREATE INDEX IF NOT EXISTS idx_retention_jobs_status ON data_retention_jobs(status)"
)
conn.commit() conn.commit()
logger.info("Enterprise tables initialized successfully") logger.info("Enterprise tables initialized successfully")
@@ -649,7 +698,9 @@ class EnterpriseManager:
finally: finally:
conn.close() conn.close()
def get_tenant_sso_config(self, tenant_id: str, provider: str | None = None) -> SSOConfig | None: def get_tenant_sso_config(
self, tenant_id: str, provider: str | None = None
) -> SSOConfig | None:
"""获取租户的 SSO 配置""" """获取租户的 SSO 配置"""
conn = self._get_connection() conn = self._get_connection()
try: try:
@@ -734,7 +785,7 @@ class EnterpriseManager:
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute( cursor.execute(
f""" f"""
UPDATE sso_configs SET {', '.join(updates)} UPDATE sso_configs SET {", ".join(updates)}
WHERE id = ? WHERE id = ?
""", """,
params, params,
@@ -943,7 +994,11 @@ class EnterpriseManager:
"""解析 SAML 响应(简化实现)""" """解析 SAML 响应(简化实现)"""
# 实际应该使用 python-saml 库解析 # 实际应该使用 python-saml 库解析
# 这里返回模拟数据 # 这里返回模拟数据
return {"email": "user@example.com", "name": "Test User", "session_index": f"_{uuid.uuid4().hex}"} return {
"email": "user@example.com",
"name": "Test User",
"session_index": f"_{uuid.uuid4().hex}",
}
def _generate_self_signed_cert(self) -> str: def _generate_self_signed_cert(self) -> str:
"""生成自签名证书(简化实现)""" """生成自签名证书(简化实现)"""
@@ -1094,7 +1149,7 @@ class EnterpriseManager:
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute( cursor.execute(
f""" f"""
UPDATE scim_configs SET {', '.join(updates)} UPDATE scim_configs SET {", ".join(updates)}
WHERE id = ? WHERE id = ?
""", """,
params, params,
@@ -1175,7 +1230,9 @@ class EnterpriseManager:
# GET {scim_base_url}/Users # GET {scim_base_url}/Users
return [] return []
def _upsert_scim_user(self, conn: sqlite3.Connection, tenant_id: str, user_data: dict[str, Any]) -> None: def _upsert_scim_user(
self, conn: sqlite3.Connection, tenant_id: str, user_data: dict[str, Any]
) -> None:
"""插入或更新 SCIM 用户""" """插入或更新 SCIM 用户"""
cursor = conn.cursor() cursor = conn.cursor()
@@ -1352,7 +1409,9 @@ class EnterpriseManager:
logs = self._apply_compliance_filter(logs, export.compliance_standard) logs = self._apply_compliance_filter(logs, export.compliance_standard)
# 生成导出文件 # 生成导出文件
file_path, file_size, checksum = self._generate_export_file(export_id, logs, export.export_format) file_path, file_size, checksum = self._generate_export_file(
export_id, logs, export.export_format
)
now = datetime.now() now = datetime.now()
@@ -1386,7 +1445,12 @@ class EnterpriseManager:
conn.close() conn.close()
def _fetch_audit_logs( def _fetch_audit_logs(
self, tenant_id: str, start_date: datetime, end_date: datetime, filters: dict[str, Any], db_manager=None self,
tenant_id: str,
start_date: datetime,
end_date: datetime,
filters: dict[str, Any],
db_manager=None,
) -> list[dict[str, Any]]: ) -> list[dict[str, Any]]:
"""获取审计日志数据""" """获取审计日志数据"""
if db_manager is None: if db_manager is None:
@@ -1396,7 +1460,9 @@ class EnterpriseManager:
# 这里简化实现 # 这里简化实现
return [] return []
def _apply_compliance_filter(self, logs: list[dict[str, Any]], standard: str) -> list[dict[str, Any]]: def _apply_compliance_filter(
self, logs: list[dict[str, Any]], standard: str
) -> list[dict[str, Any]]:
"""应用合规标准字段过滤""" """应用合规标准字段过滤"""
fields = self.COMPLIANCE_FIELDS.get(ComplianceStandard(standard), []) fields = self.COMPLIANCE_FIELDS.get(ComplianceStandard(standard), [])
@@ -1410,7 +1476,9 @@ class EnterpriseManager:
return filtered_logs return filtered_logs
def _generate_export_file(self, export_id: str, logs: list[dict[str, Any]], format: str) -> tuple[str, int, str]: def _generate_export_file(
self, export_id: str, logs: list[dict[str, Any]], format: str
) -> tuple[str, int, str]:
"""生成导出文件""" """生成导出文件"""
import hashlib import hashlib
import os import os
@@ -1599,7 +1667,9 @@ class EnterpriseManager:
finally: finally:
conn.close() conn.close()
def list_retention_policies(self, tenant_id: str, resource_type: str | None = None) -> list[DataRetentionPolicy]: def list_retention_policies(
self, tenant_id: str, resource_type: str | None = None
) -> list[DataRetentionPolicy]:
"""列出数据保留策略""" """列出数据保留策略"""
conn = self._get_connection() conn = self._get_connection()
try: try:
@@ -1667,7 +1737,7 @@ class EnterpriseManager:
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute( cursor.execute(
f""" f"""
UPDATE data_retention_policies SET {', '.join(updates)} UPDATE data_retention_policies SET {", ".join(updates)}
WHERE id = ? WHERE id = ?
""", """,
params, params,
@@ -1910,10 +1980,14 @@ class EnterpriseManager:
default_role=row["default_role"], default_role=row["default_role"],
domain_restriction=json.loads(row["domain_restriction"] or "[]"), domain_restriction=json.loads(row["domain_restriction"] or "[]"),
created_at=( created_at=(
datetime.fromisoformat(row["created_at"]) if isinstance(row["created_at"], str) else row["created_at"] datetime.fromisoformat(row["created_at"])
if isinstance(row["created_at"], str)
else row["created_at"]
), ),
updated_at=( updated_at=(
datetime.fromisoformat(row["updated_at"]) if isinstance(row["updated_at"], str) else row["updated_at"] datetime.fromisoformat(row["updated_at"])
if isinstance(row["updated_at"], str)
else row["updated_at"]
), ),
last_tested_at=( last_tested_at=(
datetime.fromisoformat(row["last_tested_at"]) datetime.fromisoformat(row["last_tested_at"])
@@ -1932,10 +2006,14 @@ class EnterpriseManager:
request_id=row["request_id"], request_id=row["request_id"],
relay_state=row["relay_state"], relay_state=row["relay_state"],
created_at=( created_at=(
datetime.fromisoformat(row["created_at"]) if isinstance(row["created_at"], str) else row["created_at"] datetime.fromisoformat(row["created_at"])
if isinstance(row["created_at"], str)
else row["created_at"]
), ),
expires_at=( expires_at=(
datetime.fromisoformat(row["expires_at"]) if isinstance(row["expires_at"], str) else row["expires_at"] datetime.fromisoformat(row["expires_at"])
if isinstance(row["expires_at"], str)
else row["expires_at"]
), ),
used=bool(row["used"]), used=bool(row["used"]),
used_at=( used_at=(
@@ -1966,10 +2044,14 @@ class EnterpriseManager:
attribute_mapping=json.loads(row["attribute_mapping"] or "{}"), attribute_mapping=json.loads(row["attribute_mapping"] or "{}"),
sync_rules=json.loads(row["sync_rules"] or "{}"), sync_rules=json.loads(row["sync_rules"] or "{}"),
created_at=( created_at=(
datetime.fromisoformat(row["created_at"]) if isinstance(row["created_at"], str) else row["created_at"] datetime.fromisoformat(row["created_at"])
if isinstance(row["created_at"], str)
else row["created_at"]
), ),
updated_at=( updated_at=(
datetime.fromisoformat(row["updated_at"]) if isinstance(row["updated_at"], str) else row["updated_at"] datetime.fromisoformat(row["updated_at"])
if isinstance(row["updated_at"], str)
else row["updated_at"]
), ),
) )
@@ -1988,13 +2070,19 @@ class EnterpriseManager:
groups=json.loads(row["groups"] or "[]"), groups=json.loads(row["groups"] or "[]"),
raw_data=json.loads(row["raw_data"] or "{}"), raw_data=json.loads(row["raw_data"] or "{}"),
synced_at=( synced_at=(
datetime.fromisoformat(row["synced_at"]) if isinstance(row["synced_at"], str) else row["synced_at"] datetime.fromisoformat(row["synced_at"])
if isinstance(row["synced_at"], str)
else row["synced_at"]
), ),
created_at=( created_at=(
datetime.fromisoformat(row["created_at"]) if isinstance(row["created_at"], str) else row["created_at"] datetime.fromisoformat(row["created_at"])
if isinstance(row["created_at"], str)
else row["created_at"]
), ),
updated_at=( updated_at=(
datetime.fromisoformat(row["updated_at"]) if isinstance(row["updated_at"], str) else row["updated_at"] datetime.fromisoformat(row["updated_at"])
if isinstance(row["updated_at"], str)
else row["updated_at"]
), ),
) )
@@ -2005,9 +2093,13 @@ class EnterpriseManager:
tenant_id=row["tenant_id"], tenant_id=row["tenant_id"],
export_format=row["export_format"], export_format=row["export_format"],
start_date=( start_date=(
datetime.fromisoformat(row["start_date"]) if isinstance(row["start_date"], str) else row["start_date"] datetime.fromisoformat(row["start_date"])
if isinstance(row["start_date"], str)
else row["start_date"]
), ),
end_date=datetime.fromisoformat(row["end_date"]) if isinstance(row["end_date"], str) else row["end_date"], end_date=datetime.fromisoformat(row["end_date"])
if isinstance(row["end_date"], str)
else row["end_date"],
filters=json.loads(row["filters"] or "{}"), filters=json.loads(row["filters"] or "{}"),
compliance_standard=row["compliance_standard"], compliance_standard=row["compliance_standard"],
status=row["status"], status=row["status"],
@@ -2022,11 +2114,15 @@ class EnterpriseManager:
else row["downloaded_at"] else row["downloaded_at"]
), ),
expires_at=( expires_at=(
datetime.fromisoformat(row["expires_at"]) if isinstance(row["expires_at"], str) else row["expires_at"] datetime.fromisoformat(row["expires_at"])
if isinstance(row["expires_at"], str)
else row["expires_at"]
), ),
created_by=row["created_by"], created_by=row["created_by"],
created_at=( created_at=(
datetime.fromisoformat(row["created_at"]) if isinstance(row["created_at"], str) else row["created_at"] datetime.fromisoformat(row["created_at"])
if isinstance(row["created_at"], str)
else row["created_at"]
), ),
completed_at=( completed_at=(
datetime.fromisoformat(row["completed_at"]) datetime.fromisoformat(row["completed_at"])
@@ -2060,10 +2156,14 @@ class EnterpriseManager:
), ),
last_execution_result=row["last_execution_result"], last_execution_result=row["last_execution_result"],
created_at=( created_at=(
datetime.fromisoformat(row["created_at"]) if isinstance(row["created_at"], str) else row["created_at"] datetime.fromisoformat(row["created_at"])
if isinstance(row["created_at"], str)
else row["created_at"]
), ),
updated_at=( updated_at=(
datetime.fromisoformat(row["updated_at"]) if isinstance(row["updated_at"], str) else row["updated_at"] datetime.fromisoformat(row["updated_at"])
if isinstance(row["updated_at"], str)
else row["updated_at"]
), ),
) )
@@ -2090,13 +2190,17 @@ class EnterpriseManager:
error_count=row["error_count"], error_count=row["error_count"],
details=json.loads(row["details"] or "{}"), details=json.loads(row["details"] or "{}"),
created_at=( created_at=(
datetime.fromisoformat(row["created_at"]) if isinstance(row["created_at"], str) else row["created_at"] datetime.fromisoformat(row["created_at"])
if isinstance(row["created_at"], str)
else row["created_at"]
), ),
) )
# 全局实例 # 全局实例
_enterprise_manager = None _enterprise_manager = None
def get_enterprise_manager(db_path: str = "insightflow.db") -> EnterpriseManager: def get_enterprise_manager(db_path: str = "insightflow.db") -> EnterpriseManager:
"""获取 EnterpriseManager 单例""" """获取 EnterpriseManager 单例"""
global _enterprise_manager global _enterprise_manager

View File

@@ -15,6 +15,7 @@ import numpy as np
KIMI_API_KEY = os.getenv("KIMI_API_KEY", "") KIMI_API_KEY = os.getenv("KIMI_API_KEY", "")
KIMI_BASE_URL = os.getenv("KIMI_BASE_URL", "https://api.kimi.com/coding") KIMI_BASE_URL = os.getenv("KIMI_BASE_URL", "https://api.kimi.com/coding")
@dataclass @dataclass
class EntityEmbedding: class EntityEmbedding:
entity_id: str entity_id: str
@@ -22,6 +23,7 @@ class EntityEmbedding:
definition: str definition: str
embedding: list[float] embedding: list[float]
class EntityAligner: class EntityAligner:
"""实体对齐器 - 使用 embedding 进行相似度匹配""" """实体对齐器 - 使用 embedding 进行相似度匹配"""
@@ -50,7 +52,10 @@ class EntityAligner:
try: try:
response = httpx.post( response = httpx.post(
f"{KIMI_BASE_URL}/v1/embeddings", f"{KIMI_BASE_URL}/v1/embeddings",
headers={"Authorization": f"Bearer {KIMI_API_KEY}", "Content-Type": "application/json"}, headers={
"Authorization": f"Bearer {KIMI_API_KEY}",
"Content-Type": "application/json",
},
json={"model": "k2p5", "input": text[:500]}, # 限制长度 json={"model": "k2p5", "input": text[:500]}, # 限制长度
timeout=30.0, timeout=30.0,
) )
@@ -230,7 +235,12 @@ class EntityAligner:
project_id, new_ent["name"], new_ent.get("definition", ""), threshold=threshold project_id, new_ent["name"], new_ent.get("definition", ""), threshold=threshold
) )
result = {"new_entity": new_ent, "matched_entity": None, "similarity": 0.0, "should_merge": False} result = {
"new_entity": new_ent,
"matched_entity": None,
"similarity": 0.0,
"should_merge": False,
}
if matched: if matched:
# 计算相似度 # 计算相似度
@@ -282,8 +292,15 @@ class EntityAligner:
try: try:
response = httpx.post( response = httpx.post(
f"{KIMI_BASE_URL}/v1/chat/completions", f"{KIMI_BASE_URL}/v1/chat/completions",
headers={"Authorization": f"Bearer {KIMI_API_KEY}", "Content-Type": "application/json"}, headers={
json={"model": "k2p5", "messages": [{"role": "user", "content": prompt}], "temperature": 0.3}, "Authorization": f"Bearer {KIMI_API_KEY}",
"Content-Type": "application/json",
},
json={
"model": "k2p5",
"messages": [{"role": "user", "content": prompt}],
"temperature": 0.3,
},
timeout=30.0, timeout=30.0,
) )
response.raise_for_status() response.raise_for_status()
@@ -301,6 +318,7 @@ class EntityAligner:
return [] return []
# 简单的字符串相似度计算(不使用 embedding # 简单的字符串相似度计算(不使用 embedding
def simple_similarity(str1: str, str2: str) -> float: def simple_similarity(str1: str, str2: str) -> float:
""" """
@@ -332,6 +350,7 @@ def simple_similarity(str1: str, str2: str) -> float:
return SequenceMatcher(None, s1, s2).ratio() return SequenceMatcher(None, s1, s2).ratio()
if __name__ == "__main__": if __name__ == "__main__":
# 测试 # 测试
aligner = EntityAligner() aligner = EntityAligner()

View File

@@ -23,12 +23,20 @@ try:
from reportlab.lib.pagesizes import A4 from reportlab.lib.pagesizes import A4
from reportlab.lib.styles import ParagraphStyle, getSampleStyleSheet from reportlab.lib.styles import ParagraphStyle, getSampleStyleSheet
from reportlab.lib.units import inch from reportlab.lib.units import inch
from reportlab.platypus import PageBreak, Paragraph, SimpleDocTemplate, Spacer, Table, TableStyle from reportlab.platypus import (
PageBreak,
Paragraph,
SimpleDocTemplate,
Spacer,
Table,
TableStyle,
)
REPORTLAB_AVAILABLE = True REPORTLAB_AVAILABLE = True
except ImportError: except ImportError:
REPORTLAB_AVAILABLE = False REPORTLAB_AVAILABLE = False
@dataclass @dataclass
class ExportEntity: class ExportEntity:
id: str id: str
@@ -39,6 +47,7 @@ class ExportEntity:
mention_count: int mention_count: int
attributes: dict[str, Any] attributes: dict[str, Any]
@dataclass @dataclass
class ExportRelation: class ExportRelation:
id: str id: str
@@ -48,6 +57,7 @@ class ExportRelation:
confidence: float confidence: float
evidence: str evidence: str
@dataclass @dataclass
class ExportTranscript: class ExportTranscript:
id: str id: str
@@ -57,6 +67,7 @@ class ExportTranscript:
segments: list[dict] segments: list[dict]
entity_mentions: list[dict] entity_mentions: list[dict]
class ExportManager: class ExportManager:
"""导出管理器 - 处理各种导出需求""" """导出管理器 - 处理各种导出需求"""
@@ -159,7 +170,9 @@ class ExportManager:
color = type_colors.get(entity.type, type_colors["default"]) color = type_colors.get(entity.type, type_colors["default"])
# 节点圆圈 # 节点圆圈
svg_parts.append(f'<circle cx="{x}" cy="{y}" r="35" fill="{color}" stroke="white" stroke-width="3"/>') svg_parts.append(
f'<circle cx="{x}" cy="{y}" r="35" fill="{color}" stroke="white" stroke-width="3"/>'
)
# 实体名称 # 实体名称
svg_parts.append( svg_parts.append(
@@ -184,16 +197,20 @@ class ExportManager:
f'fill="white" stroke="#bdc3c7" rx="5"/>' f'fill="white" stroke="#bdc3c7" rx="5"/>'
) )
svg_parts.append( svg_parts.append(
f'<text x="{legend_x}" y="{legend_y}" font-size="12" font-weight="bold" ' f'fill="#2c3e50">实体类型</text>' f'<text x="{legend_x}" y="{legend_y}" font-size="12" font-weight="bold" '
f'fill="#2c3e50">实体类型</text>'
) )
for i, (etype, color) in enumerate(type_colors.items()): for i, (etype, color) in enumerate(type_colors.items()):
if etype != "default": if etype != "default":
y_pos = legend_y + 25 + i * 20 y_pos = legend_y + 25 + i * 20
svg_parts.append(f'<circle cx="{legend_x + 10}" cy="{y_pos}" r="8" fill="{color}"/>') svg_parts.append(
f'<circle cx="{legend_x + 10}" cy="{y_pos}" r="8" fill="{color}"/>'
)
text_y = y_pos + 4 text_y = y_pos + 4
svg_parts.append( svg_parts.append(
f'<text x="{legend_x + 25}" y="{text_y}" font-size="10" ' f'fill="#2c3e50">{etype}</text>' f'<text x="{legend_x + 25}" y="{text_y}" font-size="10" '
f'fill="#2c3e50">{etype}</text>'
) )
svg_parts.append("</svg>") svg_parts.append("</svg>")
@@ -283,7 +300,9 @@ class ExportManager:
all_attrs.update(e.attributes.keys()) all_attrs.update(e.attributes.keys())
# 表头 # 表头
headers = ["ID", "名称", "类型", "定义", "别名", "提及次数"] + [f"属性:{a}" for a in sorted(all_attrs)] headers = ["ID", "名称", "类型", "定义", "别名", "提及次数"] + [
f"属性:{a}" for a in sorted(all_attrs)
]
writer = csv.writer(output) writer = csv.writer(output)
writer.writerow(headers) writer.writerow(headers)
@@ -314,7 +333,9 @@ class ExportManager:
return output.getvalue() return output.getvalue()
def export_transcript_markdown(self, transcript: ExportTranscript, entities_map: dict[str, ExportEntity]) -> str: def export_transcript_markdown(
self, transcript: ExportTranscript, entities_map: dict[str, ExportEntity]
) -> str:
""" """
导出转录文本为 Markdown 格式 导出转录文本为 Markdown 格式
@@ -392,15 +413,25 @@ class ExportManager:
raise ImportError("reportlab is required for PDF export") raise ImportError("reportlab is required for PDF export")
output = io.BytesIO() output = io.BytesIO()
doc = SimpleDocTemplate(output, pagesize=A4, rightMargin=72, leftMargin=72, topMargin=72, bottomMargin=18) doc = SimpleDocTemplate(
output, pagesize=A4, rightMargin=72, leftMargin=72, topMargin=72, bottomMargin=18
)
# 样式 # 样式
styles = getSampleStyleSheet() styles = getSampleStyleSheet()
title_style = ParagraphStyle( title_style = ParagraphStyle(
"CustomTitle", parent=styles["Heading1"], fontSize=24, spaceAfter=30, textColor=colors.HexColor("#2c3e50") "CustomTitle",
parent=styles["Heading1"],
fontSize=24,
spaceAfter=30,
textColor=colors.HexColor("#2c3e50"),
) )
heading_style = ParagraphStyle( heading_style = ParagraphStyle(
"CustomHeading", parent=styles["Heading2"], fontSize=16, spaceAfter=12, textColor=colors.HexColor("#34495e") "CustomHeading",
parent=styles["Heading2"],
fontSize=16,
spaceAfter=12,
textColor=colors.HexColor("#34495e"),
) )
story = [] story = []
@@ -408,7 +439,9 @@ class ExportManager:
# 标题页 # 标题页
story.append(Paragraph("InsightFlow 项目报告", title_style)) story.append(Paragraph("InsightFlow 项目报告", title_style))
story.append(Paragraph(f"项目名称: {project_name}", styles["Heading2"])) story.append(Paragraph(f"项目名称: {project_name}", styles["Heading2"]))
story.append(Paragraph(f"生成时间: {datetime.now().strftime('%Y-%m-%d %H:%M')}", styles["Normal"])) story.append(
Paragraph(f"生成时间: {datetime.now().strftime('%Y-%m-%d %H:%M')}", styles["Normal"])
)
story.append(Spacer(1, 0.3 * inch)) story.append(Spacer(1, 0.3 * inch))
# 统计概览 # 统计概览
@@ -458,7 +491,9 @@ class ExportManager:
story.append(Paragraph("实体列表", heading_style)) story.append(Paragraph("实体列表", heading_style))
entity_data = [["名称", "类型", "提及次数", "定义"]] entity_data = [["名称", "类型", "提及次数", "定义"]]
for e in sorted(entities, key=lambda x: x.mention_count, reverse=True)[:50]: # 限制前50个 for e in sorted(entities, key=lambda x: x.mention_count, reverse=True)[
:50
]: # 限制前50个
entity_data.append( entity_data.append(
[ [
e.name, e.name,
@@ -468,7 +503,9 @@ class ExportManager:
] ]
) )
entity_table = Table(entity_data, colWidths=[1.5 * inch, 1 * inch, 1 * inch, 2.5 * inch]) entity_table = Table(
entity_data, colWidths=[1.5 * inch, 1 * inch, 1 * inch, 2.5 * inch]
)
entity_table.setStyle( entity_table.setStyle(
TableStyle( TableStyle(
[ [
@@ -495,7 +532,9 @@ class ExportManager:
for r in relations[:100]: # 限制前100个 for r in relations[:100]: # 限制前100个
relation_data.append([r.source, r.relation_type, r.target, f"{r.confidence:.2f}"]) relation_data.append([r.source, r.relation_type, r.target, f"{r.confidence:.2f}"])
relation_table = Table(relation_data, colWidths=[2 * inch, 1.5 * inch, 2 * inch, 1 * inch]) relation_table = Table(
relation_data, colWidths=[2 * inch, 1.5 * inch, 2 * inch, 1 * inch]
)
relation_table.setStyle( relation_table.setStyle(
TableStyle( TableStyle(
[ [
@@ -557,16 +596,24 @@ class ExportManager:
for r in relations for r in relations
], ],
"transcripts": [ "transcripts": [
{"id": t.id, "name": t.name, "type": t.type, "content": t.content, "segments": t.segments} {
"id": t.id,
"name": t.name,
"type": t.type,
"content": t.content,
"segments": t.segments,
}
for t in transcripts for t in transcripts
], ],
} }
return json.dumps(data, ensure_ascii=False, indent=2) return json.dumps(data, ensure_ascii=False, indent=2)
# 全局导出管理器实例 # 全局导出管理器实例
_export_manager = None _export_manager = None
def get_export_manager(db_manager=None) -> None: def get_export_manager(db_manager=None) -> None:
"""获取导出管理器实例""" """获取导出管理器实例"""
global _export_manager global _export_manager

View File

@@ -28,6 +28,7 @@ import httpx
# Database path # Database path
DB_PATH = os.path.join(os.path.dirname(__file__), "insightflow.db") DB_PATH = os.path.join(os.path.dirname(__file__), "insightflow.db")
class EventType(StrEnum): class EventType(StrEnum):
"""事件类型""" """事件类型"""
@@ -43,6 +44,7 @@ class EventType(StrEnum):
INVITE_ACCEPTED = "invite_accepted" # 接受邀请 INVITE_ACCEPTED = "invite_accepted" # 接受邀请
REFERRAL_REWARD = "referral_reward" # 推荐奖励 REFERRAL_REWARD = "referral_reward" # 推荐奖励
class ExperimentStatus(StrEnum): class ExperimentStatus(StrEnum):
"""实验状态""" """实验状态"""
@@ -52,6 +54,7 @@ class ExperimentStatus(StrEnum):
COMPLETED = "completed" # 已完成 COMPLETED = "completed" # 已完成
ARCHIVED = "archived" # 已归档 ARCHIVED = "archived" # 已归档
class TrafficAllocationType(StrEnum): class TrafficAllocationType(StrEnum):
"""流量分配类型""" """流量分配类型"""
@@ -59,6 +62,7 @@ class TrafficAllocationType(StrEnum):
STRATIFIED = "stratified" # 分层分配 STRATIFIED = "stratified" # 分层分配
TARGETED = "targeted" # 定向分配 TARGETED = "targeted" # 定向分配
class EmailTemplateType(StrEnum): class EmailTemplateType(StrEnum):
"""邮件模板类型""" """邮件模板类型"""
@@ -70,6 +74,7 @@ class EmailTemplateType(StrEnum):
REFERRAL = "referral" # 推荐邀请 REFERRAL = "referral" # 推荐邀请
NEWSLETTER = "newsletter" # 新闻通讯 NEWSLETTER = "newsletter" # 新闻通讯
class EmailStatus(StrEnum): class EmailStatus(StrEnum):
"""邮件状态""" """邮件状态"""
@@ -83,6 +88,7 @@ class EmailStatus(StrEnum):
BOUNCED = "bounced" # 退信 BOUNCED = "bounced" # 退信
FAILED = "failed" # 失败 FAILED = "failed" # 失败
class WorkflowTriggerType(StrEnum): class WorkflowTriggerType(StrEnum):
"""工作流触发类型""" """工作流触发类型"""
@@ -94,6 +100,7 @@ class WorkflowTriggerType(StrEnum):
MILESTONE = "milestone" # 里程碑 MILESTONE = "milestone" # 里程碑
CUSTOM_EVENT = "custom_event" # 自定义事件 CUSTOM_EVENT = "custom_event" # 自定义事件
class ReferralStatus(StrEnum): class ReferralStatus(StrEnum):
"""推荐状态""" """推荐状态"""
@@ -102,6 +109,7 @@ class ReferralStatus(StrEnum):
REWARDED = "rewarded" # 已奖励 REWARDED = "rewarded" # 已奖励
EXPIRED = "expired" # 已过期 EXPIRED = "expired" # 已过期
@dataclass @dataclass
class AnalyticsEvent: class AnalyticsEvent:
"""分析事件""" """分析事件"""
@@ -120,6 +128,7 @@ class AnalyticsEvent:
utm_medium: str | None utm_medium: str | None
utm_campaign: str | None utm_campaign: str | None
@dataclass @dataclass
class UserProfile: class UserProfile:
"""用户画像""" """用户画像"""
@@ -139,6 +148,7 @@ class UserProfile:
created_at: datetime created_at: datetime
updated_at: datetime updated_at: datetime
@dataclass @dataclass
class Funnel: class Funnel:
"""转化漏斗""" """转化漏斗"""
@@ -151,6 +161,7 @@ class Funnel:
created_at: datetime created_at: datetime
updated_at: datetime updated_at: datetime
@dataclass @dataclass
class FunnelAnalysis: class FunnelAnalysis:
"""漏斗分析结果""" """漏斗分析结果"""
@@ -163,6 +174,7 @@ class FunnelAnalysis:
overall_conversion: float # 总体转化率 overall_conversion: float # 总体转化率
drop_off_points: list[dict] # 流失点 drop_off_points: list[dict] # 流失点
@dataclass @dataclass
class Experiment: class Experiment:
"""A/B 测试实验""" """A/B 测试实验"""
@@ -187,6 +199,7 @@ class Experiment:
updated_at: datetime updated_at: datetime
created_by: str created_by: str
@dataclass @dataclass
class ExperimentResult: class ExperimentResult:
"""实验结果""" """实验结果"""
@@ -204,6 +217,7 @@ class ExperimentResult:
uplift: float # 提升幅度 uplift: float # 提升幅度
created_at: datetime created_at: datetime
@dataclass @dataclass
class EmailTemplate: class EmailTemplate:
"""邮件模板""" """邮件模板"""
@@ -224,6 +238,7 @@ class EmailTemplate:
created_at: datetime created_at: datetime
updated_at: datetime updated_at: datetime
@dataclass @dataclass
class EmailCampaign: class EmailCampaign:
"""邮件营销活动""" """邮件营销活动"""
@@ -245,6 +260,7 @@ class EmailCampaign:
completed_at: datetime | None completed_at: datetime | None
created_at: datetime created_at: datetime
@dataclass @dataclass
class EmailLog: class EmailLog:
"""邮件发送记录""" """邮件发送记录"""
@@ -266,6 +282,7 @@ class EmailLog:
error_message: str | None error_message: str | None
created_at: datetime created_at: datetime
@dataclass @dataclass
class AutomationWorkflow: class AutomationWorkflow:
"""自动化工作流""" """自动化工作流"""
@@ -282,6 +299,7 @@ class AutomationWorkflow:
created_at: datetime created_at: datetime
updated_at: datetime updated_at: datetime
@dataclass @dataclass
class ReferralProgram: class ReferralProgram:
"""推荐计划""" """推荐计划"""
@@ -301,6 +319,7 @@ class ReferralProgram:
created_at: datetime created_at: datetime
updated_at: datetime updated_at: datetime
@dataclass @dataclass
class Referral: class Referral:
"""推荐记录""" """推荐记录"""
@@ -321,6 +340,7 @@ class Referral:
expires_at: datetime expires_at: datetime
created_at: datetime created_at: datetime
@dataclass @dataclass
class TeamIncentive: class TeamIncentive:
"""团队升级激励""" """团队升级激励"""
@@ -338,6 +358,7 @@ class TeamIncentive:
is_active: bool is_active: bool
created_at: datetime created_at: datetime
class GrowthManager: class GrowthManager:
"""运营与增长管理主类""" """运营与增长管理主类"""
@@ -437,7 +458,10 @@ class GrowthManager:
async def _send_to_mixpanel(self, event: AnalyticsEvent): async def _send_to_mixpanel(self, event: AnalyticsEvent):
"""发送事件到 Mixpanel""" """发送事件到 Mixpanel"""
try: try:
headers = {"Content-Type": "application/json", "Authorization": f"Basic {self.mixpanel_token}"} headers = {
"Content-Type": "application/json",
"Authorization": f"Basic {self.mixpanel_token}",
}
payload = { payload = {
"event": event.event_name, "event": event.event_name,
@@ -450,7 +474,9 @@ class GrowthManager:
} }
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
await client.post("https://api.mixpanel.com/track", headers=headers, json=[payload], timeout=10.0) await client.post(
"https://api.mixpanel.com/track", headers=headers, json=[payload], timeout=10.0
)
except Exception as e: except Exception as e:
print(f"Failed to send to Mixpanel: {e}") print(f"Failed to send to Mixpanel: {e}")
@@ -473,16 +499,24 @@ class GrowthManager:
} }
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
await client.post("https://api.amplitude.com/2/httpapi", headers=headers, json=payload, timeout=10.0) await client.post(
"https://api.amplitude.com/2/httpapi",
headers=headers,
json=payload,
timeout=10.0,
)
except Exception as e: except Exception as e:
print(f"Failed to send to Amplitude: {e}") print(f"Failed to send to Amplitude: {e}")
async def _update_user_profile(self, tenant_id: str, user_id: str, event_type: EventType, event_name: str): async def _update_user_profile(
self, tenant_id: str, user_id: str, event_type: EventType, event_name: str
):
"""更新用户画像""" """更新用户画像"""
with self._get_db() as conn: with self._get_db() as conn:
# 检查用户画像是否存在 # 检查用户画像是否存在
row = conn.execute( row = conn.execute(
"SELECT * FROM user_profiles WHERE tenant_id = ? AND user_id = ?", (tenant_id, user_id) "SELECT * FROM user_profiles WHERE tenant_id = ? AND user_id = ?",
(tenant_id, user_id),
).fetchone() ).fetchone()
now = datetime.now().isoformat() now = datetime.now().isoformat()
@@ -538,7 +572,8 @@ class GrowthManager:
"""获取用户画像""" """获取用户画像"""
with self._get_db() as conn: with self._get_db() as conn:
row = conn.execute( row = conn.execute(
"SELECT * FROM user_profiles WHERE tenant_id = ? AND user_id = ?", (tenant_id, user_id) "SELECT * FROM user_profiles WHERE tenant_id = ? AND user_id = ?",
(tenant_id, user_id),
).fetchone() ).fetchone()
if row: if row:
@@ -599,7 +634,9 @@ class GrowthManager:
"event_type_distribution": {r["event_type"]: r["count"] for r in type_rows}, "event_type_distribution": {r["event_type"]: r["count"] for r in type_rows},
} }
def create_funnel(self, tenant_id: str, name: str, description: str, steps: list[dict], created_by: str) -> Funnel: def create_funnel(
self, tenant_id: str, name: str, description: str, steps: list[dict], created_by: str
) -> Funnel:
"""创建转化漏斗""" """创建转化漏斗"""
funnel_id = f"fnl_{uuid.uuid4().hex[:16]}" funnel_id = f"fnl_{uuid.uuid4().hex[:16]}"
now = datetime.now().isoformat() now = datetime.now().isoformat()
@@ -664,7 +701,9 @@ class GrowthManager:
FROM analytics_events FROM analytics_events
WHERE event_name = ? AND timestamp >= ? AND timestamp <= ? WHERE event_name = ? AND timestamp >= ? AND timestamp <= ?
""" """
row = conn.execute(query, (event_name, period_start.isoformat(), period_end.isoformat())).fetchone() row = conn.execute(
query, (event_name, period_start.isoformat(), period_end.isoformat())
).fetchone()
user_count = row["user_count"] if row else 0 user_count = row["user_count"] if row else 0
@@ -696,7 +735,9 @@ class GrowthManager:
overall_conversion = 0.0 overall_conversion = 0.0
# 找出主要流失点 # 找出主要流失点
drop_off_points = [s for s in step_conversions if s["drop_off_rate"] > 0.2 and s != step_conversions[0]] drop_off_points = [
s for s in step_conversions if s["drop_off_rate"] > 0.2 and s != step_conversions[0]
]
return FunnelAnalysis( return FunnelAnalysis(
funnel_id=funnel_id, funnel_id=funnel_id,
@@ -708,7 +749,9 @@ class GrowthManager:
drop_off_points=drop_off_points, drop_off_points=drop_off_points,
) )
def calculate_retention(self, tenant_id: str, cohort_date: datetime, periods: list[int] = None) -> dict: def calculate_retention(
self, tenant_id: str, cohort_date: datetime, periods: list[int] = None
) -> dict:
"""计算留存率""" """计算留存率"""
if periods is None: if periods is None:
periods = [1, 3, 7, 14, 30] periods = [1, 3, 7, 14, 30]
@@ -725,7 +768,8 @@ class GrowthManager:
) )
""" """
cohort_rows = conn.execute( cohort_rows = conn.execute(
cohort_query, (tenant_id, cohort_date.isoformat(), tenant_id, cohort_date.isoformat()) cohort_query,
(tenant_id, cohort_date.isoformat(), tenant_id, cohort_date.isoformat()),
).fetchall() ).fetchall()
cohort_users = {r["user_id"] for r in cohort_rows} cohort_users = {r["user_id"] for r in cohort_rows}
@@ -757,7 +801,11 @@ class GrowthManager:
"retention_rate": round(retention_rate, 4), "retention_rate": round(retention_rate, 4),
} }
return {"cohort_date": cohort_date.isoformat(), "cohort_size": cohort_size, "retention": retention_rates} return {
"cohort_date": cohort_date.isoformat(),
"cohort_size": cohort_size,
"retention": retention_rates,
}
# ==================== A/B 测试框架 ==================== # ==================== A/B 测试框架 ====================
@@ -842,7 +890,9 @@ class GrowthManager:
def get_experiment(self, experiment_id: str) -> Experiment | None: def get_experiment(self, experiment_id: str) -> Experiment | None:
"""获取实验详情""" """获取实验详情"""
with self._get_db() as conn: with self._get_db() as conn:
row = conn.execute("SELECT * FROM experiments WHERE id = ?", (experiment_id,)).fetchone() row = conn.execute(
"SELECT * FROM experiments WHERE id = ?", (experiment_id,)
).fetchone()
if row: if row:
return self._row_to_experiment(row) return self._row_to_experiment(row)
@@ -863,7 +913,9 @@ class GrowthManager:
rows = conn.execute(query, params).fetchall() rows = conn.execute(query, params).fetchall()
return [self._row_to_experiment(row) for row in rows] return [self._row_to_experiment(row) for row in rows]
def assign_variant(self, experiment_id: str, user_id: str, user_attributes: dict = None) -> str | None: def assign_variant(
self, experiment_id: str, user_id: str, user_attributes: dict = None
) -> str | None:
"""为用户分配实验变体""" """为用户分配实验变体"""
experiment = self.get_experiment(experiment_id) experiment = self.get_experiment(experiment_id)
if not experiment or experiment.status != ExperimentStatus.RUNNING: if not experiment or experiment.status != ExperimentStatus.RUNNING:
@@ -884,9 +936,13 @@ class GrowthManager:
if experiment.traffic_allocation == TrafficAllocationType.RANDOM: if experiment.traffic_allocation == TrafficAllocationType.RANDOM:
variant_id = self._random_allocation(experiment.variants, experiment.traffic_split) variant_id = self._random_allocation(experiment.variants, experiment.traffic_split)
elif experiment.traffic_allocation == TrafficAllocationType.STRATIFIED: elif experiment.traffic_allocation == TrafficAllocationType.STRATIFIED:
variant_id = self._stratified_allocation(experiment.variants, experiment.traffic_split, user_attributes) variant_id = self._stratified_allocation(
experiment.variants, experiment.traffic_split, user_attributes
)
else: # TARGETED else: # TARGETED
variant_id = self._targeted_allocation(experiment.variants, experiment.target_audience, user_attributes) variant_id = self._targeted_allocation(
experiment.variants, experiment.target_audience, user_attributes
)
if variant_id: if variant_id:
now = datetime.now().isoformat() now = datetime.now().isoformat()
@@ -932,7 +988,9 @@ class GrowthManager:
return self._random_allocation(variants, traffic_split) return self._random_allocation(variants, traffic_split)
def _targeted_allocation(self, variants: list[dict], target_audience: dict, user_attributes: dict) -> str | None: def _targeted_allocation(
self, variants: list[dict], target_audience: dict, user_attributes: dict
) -> str | None:
"""定向分配(基于目标受众条件)""" """定向分配(基于目标受众条件)"""
# 检查用户是否符合目标受众条件 # 检查用户是否符合目标受众条件
conditions = target_audience.get("conditions", []) conditions = target_audience.get("conditions", [])
@@ -963,7 +1021,12 @@ class GrowthManager:
return self._random_allocation(variants, target_audience.get("traffic_split", {})) return self._random_allocation(variants, target_audience.get("traffic_split", {}))
def record_experiment_metric( def record_experiment_metric(
self, experiment_id: str, variant_id: str, user_id: str, metric_name: str, metric_value: float self,
experiment_id: str,
variant_id: str,
user_id: str,
metric_name: str,
metric_value: float,
): ):
"""记录实验指标""" """记录实验指标"""
with self._get_db() as conn: with self._get_db() as conn:
@@ -1022,7 +1085,9 @@ class GrowthManager:
(experiment_id, variant_id, experiment.primary_metric), (experiment_id, variant_id, experiment.primary_metric),
).fetchone() ).fetchone()
mean_value = metric_row["mean_value"] if metric_row and metric_row["mean_value"] else 0 mean_value = (
metric_row["mean_value"] if metric_row and metric_row["mean_value"] else 0
)
results[variant_id] = { results[variant_id] = {
"variant_name": variant.get("name", variant_id), "variant_name": variant.get("name", variant_id),
@@ -1073,7 +1138,13 @@ class GrowthManager:
SET status = ?, start_date = ?, updated_at = ? SET status = ?, start_date = ?, updated_at = ?
WHERE id = ? AND status = ? WHERE id = ? AND status = ?
""", """,
(ExperimentStatus.RUNNING.value, now, now, experiment_id, ExperimentStatus.DRAFT.value), (
ExperimentStatus.RUNNING.value,
now,
now,
experiment_id,
ExperimentStatus.DRAFT.value,
),
) )
conn.commit() conn.commit()
@@ -1089,7 +1160,13 @@ class GrowthManager:
SET status = ?, end_date = ?, updated_at = ? SET status = ?, end_date = ?, updated_at = ?
WHERE id = ? AND status = ? WHERE id = ? AND status = ?
""", """,
(ExperimentStatus.COMPLETED.value, now, now, experiment_id, ExperimentStatus.RUNNING.value), (
ExperimentStatus.COMPLETED.value,
now,
now,
experiment_id,
ExperimentStatus.RUNNING.value,
),
) )
conn.commit() conn.commit()
@@ -1168,13 +1245,17 @@ class GrowthManager:
def get_email_template(self, template_id: str) -> EmailTemplate | None: def get_email_template(self, template_id: str) -> EmailTemplate | None:
"""获取邮件模板""" """获取邮件模板"""
with self._get_db() as conn: with self._get_db() as conn:
row = conn.execute("SELECT * FROM email_templates WHERE id = ?", (template_id,)).fetchone() row = conn.execute(
"SELECT * FROM email_templates WHERE id = ?", (template_id,)
).fetchone()
if row: if row:
return self._row_to_email_template(row) return self._row_to_email_template(row)
return None return None
def list_email_templates(self, tenant_id: str, template_type: EmailTemplateType = None) -> list[EmailTemplate]: def list_email_templates(
self, tenant_id: str, template_type: EmailTemplateType = None
) -> list[EmailTemplate]:
"""列出邮件模板""" """列出邮件模板"""
query = "SELECT * FROM email_templates WHERE tenant_id = ? AND is_active = 1" query = "SELECT * FROM email_templates WHERE tenant_id = ? AND is_active = 1"
params = [tenant_id] params = [tenant_id]
@@ -1215,7 +1296,12 @@ class GrowthManager:
} }
def create_email_campaign( def create_email_campaign(
self, tenant_id: str, name: str, template_id: str, recipient_list: list[dict], scheduled_at: datetime = None self,
tenant_id: str,
name: str,
template_id: str,
recipient_list: list[dict],
scheduled_at: datetime = None,
) -> EmailCampaign: ) -> EmailCampaign:
"""创建邮件营销活动""" """创建邮件营销活动"""
campaign_id = f"ec_{uuid.uuid4().hex[:16]}" campaign_id = f"ec_{uuid.uuid4().hex[:16]}"
@@ -1294,7 +1380,9 @@ class GrowthManager:
return campaign return campaign
async def send_email(self, campaign_id: str, user_id: str, email: str, template_id: str, variables: dict) -> bool: async def send_email(
self, campaign_id: str, user_id: str, email: str, template_id: str, variables: dict
) -> bool:
"""发送单封邮件""" """发送单封邮件"""
template = self.get_email_template(template_id) template = self.get_email_template(template_id)
if not template: if not template:
@@ -1363,7 +1451,9 @@ class GrowthManager:
async def send_campaign(self, campaign_id: str) -> dict: async def send_campaign(self, campaign_id: str) -> dict:
"""发送整个营销活动""" """发送整个营销活动"""
with self._get_db() as conn: with self._get_db() as conn:
campaign_row = conn.execute("SELECT * FROM email_campaigns WHERE id = ?", (campaign_id,)).fetchone() campaign_row = conn.execute(
"SELECT * FROM email_campaigns WHERE id = ?", (campaign_id,)
).fetchone()
if not campaign_row: if not campaign_row:
return {"error": "Campaign not found"} return {"error": "Campaign not found"}
@@ -1378,7 +1468,8 @@ class GrowthManager:
# 更新活动状态 # 更新活动状态
now = datetime.now().isoformat() now = datetime.now().isoformat()
conn.execute( conn.execute(
"UPDATE email_campaigns SET status = ?, started_at = ? WHERE id = ?", ("sending", now, campaign_id) "UPDATE email_campaigns SET status = ?, started_at = ? WHERE id = ?",
("sending", now, campaign_id),
) )
conn.commit() conn.commit()
@@ -1390,7 +1481,9 @@ class GrowthManager:
# 获取用户变量 # 获取用户变量
variables = self._get_user_variables(log["tenant_id"], log["user_id"]) variables = self._get_user_variables(log["tenant_id"], log["user_id"])
success = await self.send_email(campaign_id, log["user_id"], log["email"], log["template_id"], variables) success = await self.send_email(
campaign_id, log["user_id"], log["email"], log["template_id"], variables
)
if success: if success:
success_count += 1 success_count += 1
@@ -1410,7 +1503,12 @@ class GrowthManager:
) )
conn.commit() conn.commit()
return {"campaign_id": campaign_id, "total": len(logs), "success": success_count, "failed": failed_count} return {
"campaign_id": campaign_id,
"total": len(logs),
"success": success_count,
"failed": failed_count,
}
def _get_user_variables(self, tenant_id: str, user_id: str) -> dict: def _get_user_variables(self, tenant_id: str, user_id: str) -> dict:
"""获取用户变量用于邮件模板""" """获取用户变量用于邮件模板"""
@@ -1493,7 +1591,8 @@ class GrowthManager:
# 更新执行计数 # 更新执行计数
conn.execute( conn.execute(
"UPDATE automation_workflows SET execution_count = execution_count + 1 WHERE id = ?", (workflow_id,) "UPDATE automation_workflows SET execution_count = execution_count + 1 WHERE id = ?",
(workflow_id,),
) )
conn.commit() conn.commit()
@@ -1666,7 +1765,9 @@ class GrowthManager:
code = "".join(random.choices(chars, k=length)) code = "".join(random.choices(chars, k=length))
with self._get_db() as conn: with self._get_db() as conn:
row = conn.execute("SELECT 1 FROM referrals WHERE referral_code = ?", (code,)).fetchone() row = conn.execute(
"SELECT 1 FROM referrals WHERE referral_code = ?", (code,)
).fetchone()
if not row: if not row:
return code return code
@@ -1674,7 +1775,9 @@ class GrowthManager:
def _get_referral_program(self, program_id: str) -> ReferralProgram | None: def _get_referral_program(self, program_id: str) -> ReferralProgram | None:
"""获取推荐计划""" """获取推荐计划"""
with self._get_db() as conn: with self._get_db() as conn:
row = conn.execute("SELECT * FROM referral_programs WHERE id = ?", (program_id,)).fetchone() row = conn.execute(
"SELECT * FROM referral_programs WHERE id = ?", (program_id,)
).fetchone()
if row: if row:
return self._row_to_referral_program(row) return self._row_to_referral_program(row)
@@ -1758,7 +1861,9 @@ class GrowthManager:
"rewarded": stats["rewarded"] or 0, "rewarded": stats["rewarded"] or 0,
"expired": stats["expired"] or 0, "expired": stats["expired"] or 0,
"unique_referrers": stats["unique_referrers"] or 0, "unique_referrers": stats["unique_referrers"] or 0,
"conversion_rate": round((stats["converted"] or 0) / max(stats["total_referrals"] or 1, 1), 4), "conversion_rate": round(
(stats["converted"] or 0) / max(stats["total_referrals"] or 1, 1), 4
),
} }
def create_team_incentive( def create_team_incentive(
@@ -1898,7 +2003,9 @@ class GrowthManager:
(tenant_id, hour_start.isoformat(), hour_end.isoformat()), (tenant_id, hour_start.isoformat(), hour_end.isoformat()),
).fetchone() ).fetchone()
hourly_trend.append({"hour": hour_end.strftime("%H:00"), "active_users": row["count"] or 0}) hourly_trend.append(
{"hour": hour_end.strftime("%H:00"), "active_users": row["count"] or 0}
)
return { return {
"tenant_id": tenant_id, "tenant_id": tenant_id,
@@ -1917,7 +2024,9 @@ class GrowthManager:
} }
for r in recent_events for r in recent_events
], ],
"top_features": [{"feature": r["event_name"], "usage_count": r["count"]} for r in top_features], "top_features": [
{"feature": r["event_name"], "usage_count": r["count"]} for r in top_features
],
"hourly_trend": list(reversed(hourly_trend)), "hourly_trend": list(reversed(hourly_trend)),
} }
@@ -2038,9 +2147,11 @@ class GrowthManager:
created_at=row["created_at"], created_at=row["created_at"],
) )
# Singleton instance # Singleton instance
_growth_manager = None _growth_manager = None
def get_growth_manager() -> GrowthManager: def get_growth_manager() -> GrowthManager:
global _growth_manager global _growth_manager
if _growth_manager is None: if _growth_manager is None:

View File

@@ -33,6 +33,7 @@ try:
except ImportError: except ImportError:
PYTESSERACT_AVAILABLE = False PYTESSERACT_AVAILABLE = False
@dataclass @dataclass
class ImageEntity: class ImageEntity:
"""图片中检测到的实体""" """图片中检测到的实体"""
@@ -42,6 +43,7 @@ class ImageEntity:
confidence: float confidence: float
bbox: tuple[int, int, int, int] | None = None # (x, y, width, height) bbox: tuple[int, int, int, int] | None = None # (x, y, width, height)
@dataclass @dataclass
class ImageRelation: class ImageRelation:
"""图片中检测到的关系""" """图片中检测到的关系"""
@@ -51,6 +53,7 @@ class ImageRelation:
relation_type: str relation_type: str
confidence: float confidence: float
@dataclass @dataclass
class ImageProcessingResult: class ImageProcessingResult:
"""图片处理结果""" """图片处理结果"""
@@ -66,6 +69,7 @@ class ImageProcessingResult:
success: bool success: bool
error_message: str = "" error_message: str = ""
@dataclass @dataclass
class BatchProcessingResult: class BatchProcessingResult:
"""批量图片处理结果""" """批量图片处理结果"""
@@ -75,6 +79,7 @@ class BatchProcessingResult:
success_count: int success_count: int
failed_count: int failed_count: int
class ImageProcessor: class ImageProcessor:
"""图片处理器 - 处理各种类型图片""" """图片处理器 - 处理各种类型图片"""
@@ -213,7 +218,10 @@ class ImageProcessor:
return "handwritten" return "handwritten"
# 检测是否为截图可能有UI元素 # 检测是否为截图可能有UI元素
if any(keyword in ocr_text.lower() for keyword in ["button", "menu", "click", "登录", "确定", "取消"]): if any(
keyword in ocr_text.lower()
for keyword in ["button", "menu", "click", "登录", "确定", "取消"]
):
return "screenshot" return "screenshot"
# 默认文档类型 # 默认文档类型
@@ -316,7 +324,9 @@ class ImageProcessor:
return unique_entities return unique_entities
def generate_description(self, image_type: str, ocr_text: str, entities: list[ImageEntity]) -> str: def generate_description(
self, image_type: str, ocr_text: str, entities: list[ImageEntity]
) -> str:
""" """
生成图片描述 生成图片描述
@@ -346,7 +356,11 @@ class ImageProcessor:
return " ".join(description_parts) return " ".join(description_parts)
def process_image( def process_image(
self, image_data: bytes, filename: str = None, image_id: str = None, detect_type: bool = True self,
image_data: bytes,
filename: str = None,
image_id: str = None,
detect_type: bool = True,
) -> ImageProcessingResult: ) -> ImageProcessingResult:
""" """
处理单张图片 处理单张图片
@@ -469,7 +483,9 @@ class ImageProcessor:
return relations return relations
def process_batch(self, images_data: list[tuple[bytes, str]], project_id: str = None) -> BatchProcessingResult: def process_batch(
self, images_data: list[tuple[bytes, str]], project_id: str = None
) -> BatchProcessingResult:
""" """
批量处理图片 批量处理图片
@@ -494,7 +510,10 @@ class ImageProcessor:
failed_count += 1 failed_count += 1
return BatchProcessingResult( return BatchProcessingResult(
results=results, total_count=len(results), success_count=success_count, failed_count=failed_count results=results,
total_count=len(results),
success_count=success_count,
failed_count=failed_count,
) )
def image_to_base64(self, image_data: bytes) -> str: def image_to_base64(self, image_data: bytes) -> str:
@@ -534,9 +553,11 @@ class ImageProcessor:
print(f"Thumbnail generation error: {e}") print(f"Thumbnail generation error: {e}")
return image_data return image_data
# Singleton instance # Singleton instance
_image_processor = None _image_processor = None
def get_image_processor(temp_dir: str = None) -> ImageProcessor: def get_image_processor(temp_dir: str = None) -> ImageProcessor:
"""获取图片处理器单例""" """获取图片处理器单例"""
global _image_processor global _image_processor

View File

@@ -15,6 +15,7 @@ import httpx
KIMI_API_KEY = os.getenv("KIMI_API_KEY", "") KIMI_API_KEY = os.getenv("KIMI_API_KEY", "")
KIMI_BASE_URL = os.getenv("KIMI_BASE_URL", "https://api.kimi.com/coding") KIMI_BASE_URL = os.getenv("KIMI_BASE_URL", "https://api.kimi.com/coding")
class ReasoningType(Enum): class ReasoningType(Enum):
"""推理类型""" """推理类型"""
@@ -24,6 +25,7 @@ class ReasoningType(Enum):
COMPARATIVE = "comparative" # 对比推理 COMPARATIVE = "comparative" # 对比推理
SUMMARY = "summary" # 总结推理 SUMMARY = "summary" # 总结推理
@dataclass @dataclass
class ReasoningResult: class ReasoningResult:
"""推理结果""" """推理结果"""
@@ -35,6 +37,7 @@ class ReasoningResult:
related_entities: list[str] # 相关实体 related_entities: list[str] # 相关实体
gaps: list[str] # 知识缺口 gaps: list[str] # 知识缺口
@dataclass @dataclass
class InferencePath: class InferencePath:
"""推理路径""" """推理路径"""
@@ -44,24 +47,35 @@ class InferencePath:
path: list[dict] # 路径上的节点和关系 path: list[dict] # 路径上的节点和关系
strength: float # 路径强度 strength: float # 路径强度
class KnowledgeReasoner: class KnowledgeReasoner:
"""知识推理引擎""" """知识推理引擎"""
def __init__(self, api_key: str = None, base_url: str = None): def __init__(self, api_key: str = None, base_url: str = None):
self.api_key = api_key or KIMI_API_KEY self.api_key = api_key or KIMI_API_KEY
self.base_url = base_url or KIMI_BASE_URL self.base_url = base_url or KIMI_BASE_URL
self.headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"} self.headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
}
async def _call_llm(self, prompt: str, temperature: float = 0.3) -> str: async def _call_llm(self, prompt: str, temperature: float = 0.3) -> str:
"""调用 LLM""" """调用 LLM"""
if not self.api_key: if not self.api_key:
raise ValueError("KIMI_API_KEY not set") raise ValueError("KIMI_API_KEY not set")
payload = {"model": "k2p5", "messages": [{"role": "user", "content": prompt}], "temperature": temperature} payload = {
"model": "k2p5",
"messages": [{"role": "user", "content": prompt}],
"temperature": temperature,
}
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.post( response = await client.post(
f"{self.base_url}/v1/chat/completions", headers=self.headers, json=payload, timeout=120.0 f"{self.base_url}/v1/chat/completions",
headers=self.headers,
json=payload,
timeout=120.0,
) )
response.raise_for_status() response.raise_for_status()
result = response.json() result = response.json()
@@ -124,7 +138,9 @@ class KnowledgeReasoner:
return {"type": "factual", "entities": [], "intent": "general", "complexity": "simple"} return {"type": "factual", "entities": [], "intent": "general", "complexity": "simple"}
async def _causal_reasoning(self, query: str, project_context: dict, graph_data: dict) -> ReasoningResult: async def _causal_reasoning(
self, query: str, project_context: dict, graph_data: dict
) -> ReasoningResult:
"""因果推理 - 分析原因和影响""" """因果推理 - 分析原因和影响"""
# 构建因果分析提示 # 构建因果分析提示
@@ -183,7 +199,9 @@ class KnowledgeReasoner:
gaps=["无法完成因果推理"], gaps=["无法完成因果推理"],
) )
async def _comparative_reasoning(self, query: str, project_context: dict, graph_data: dict) -> ReasoningResult: async def _comparative_reasoning(
self, query: str, project_context: dict, graph_data: dict
) -> ReasoningResult:
"""对比推理 - 比较实体间的异同""" """对比推理 - 比较实体间的异同"""
prompt = f"""基于以下知识图谱进行对比分析: prompt = f"""基于以下知识图谱进行对比分析:
@@ -235,7 +253,9 @@ class KnowledgeReasoner:
gaps=[], gaps=[],
) )
async def _temporal_reasoning(self, query: str, project_context: dict, graph_data: dict) -> ReasoningResult: async def _temporal_reasoning(
self, query: str, project_context: dict, graph_data: dict
) -> ReasoningResult:
"""时序推理 - 分析时间线和演变""" """时序推理 - 分析时间线和演变"""
prompt = f"""基于以下知识图谱进行时序分析: prompt = f"""基于以下知识图谱进行时序分析:
@@ -287,7 +307,9 @@ class KnowledgeReasoner:
gaps=[], gaps=[],
) )
async def _associative_reasoning(self, query: str, project_context: dict, graph_data: dict) -> ReasoningResult: async def _associative_reasoning(
self, query: str, project_context: dict, graph_data: dict
) -> ReasoningResult:
"""关联推理 - 发现实体间的隐含关联""" """关联推理 - 发现实体间的隐含关联"""
prompt = f"""基于以下知识图谱进行关联分析: prompt = f"""基于以下知识图谱进行关联分析:
@@ -360,7 +382,9 @@ class KnowledgeReasoner:
adj[tgt] = [] adj[tgt] = []
adj[src].append({"target": tgt, "relation": r.get("type", "related"), "data": r}) adj[src].append({"target": tgt, "relation": r.get("type", "related"), "data": r})
# 无向图也添加反向 # 无向图也添加反向
adj[tgt].append({"target": src, "relation": r.get("type", "related"), "data": r, "reverse": True}) adj[tgt].append(
{"target": src, "relation": r.get("type", "related"), "data": r, "reverse": True}
)
# BFS 搜索路径 # BFS 搜索路径
from collections import deque from collections import deque
@@ -478,9 +502,11 @@ class KnowledgeReasoner:
"confidence": 0.5, "confidence": 0.5,
} }
# Singleton instance # Singleton instance
_reasoner = None _reasoner = None
def get_knowledge_reasoner() -> KnowledgeReasoner: def get_knowledge_reasoner() -> KnowledgeReasoner:
global _reasoner global _reasoner
if _reasoner is None: if _reasoner is None:

View File

@@ -15,11 +15,13 @@ import httpx
KIMI_API_KEY = os.getenv("KIMI_API_KEY", "") KIMI_API_KEY = os.getenv("KIMI_API_KEY", "")
KIMI_BASE_URL = os.getenv("KIMI_BASE_URL", "https://api.kimi.com/coding") KIMI_BASE_URL = os.getenv("KIMI_BASE_URL", "https://api.kimi.com/coding")
@dataclass @dataclass
class ChatMessage: class ChatMessage:
role: str role: str
content: str content: str
@dataclass @dataclass
class EntityExtractionResult: class EntityExtractionResult:
name: str name: str
@@ -27,6 +29,7 @@ class EntityExtractionResult:
definition: str definition: str
confidence: float confidence: float
@dataclass @dataclass
class RelationExtractionResult: class RelationExtractionResult:
source: str source: str
@@ -34,15 +37,21 @@ class RelationExtractionResult:
type: str type: str
confidence: float confidence: float
class LLMClient: class LLMClient:
"""Kimi API 客户端""" """Kimi API 客户端"""
def __init__(self, api_key: str = None, base_url: str = None): def __init__(self, api_key: str = None, base_url: str = None):
self.api_key = api_key or KIMI_API_KEY self.api_key = api_key or KIMI_API_KEY
self.base_url = base_url or KIMI_BASE_URL self.base_url = base_url or KIMI_BASE_URL
self.headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"} self.headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
}
async def chat(self, messages: list[ChatMessage], temperature: float = 0.3, stream: bool = False) -> str: async def chat(
self, messages: list[ChatMessage], temperature: float = 0.3, stream: bool = False
) -> str:
"""发送聊天请求""" """发送聊天请求"""
if not self.api_key: if not self.api_key:
raise ValueError("KIMI_API_KEY not set") raise ValueError("KIMI_API_KEY not set")
@@ -56,13 +65,18 @@ class LLMClient:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.post( response = await client.post(
f"{self.base_url}/v1/chat/completions", headers=self.headers, json=payload, timeout=120.0 f"{self.base_url}/v1/chat/completions",
headers=self.headers,
json=payload,
timeout=120.0,
) )
response.raise_for_status() response.raise_for_status()
result = response.json() result = response.json()
return result["choices"][0]["message"]["content"] return result["choices"][0]["message"]["content"]
async def chat_stream(self, messages: list[ChatMessage], temperature: float = 0.3) -> AsyncGenerator[str, None]: async def chat_stream(
self, messages: list[ChatMessage], temperature: float = 0.3
) -> AsyncGenerator[str, None]:
"""流式聊天请求""" """流式聊天请求"""
if not self.api_key: if not self.api_key:
raise ValueError("KIMI_API_KEY not set") raise ValueError("KIMI_API_KEY not set")
@@ -76,7 +90,11 @@ class LLMClient:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
async with client.stream( async with client.stream(
"POST", f"{self.base_url}/v1/chat/completions", headers=self.headers, json=payload, timeout=120.0 "POST",
f"{self.base_url}/v1/chat/completions",
headers=self.headers,
json=payload,
timeout=120.0,
) as response: ) as response:
response.raise_for_status() response.raise_for_status()
async for line in response.aiter_lines(): async for line in response.aiter_lines():
@@ -164,7 +182,9 @@ class LLMClient:
请用中文回答,保持简洁专业。如果信息不足,请明确说明。""" 请用中文回答,保持简洁专业。如果信息不足,请明确说明。"""
messages = [ messages = [
ChatMessage(role="system", content="你是一个专业的项目分析助手,擅长从会议记录中提取洞察。"), ChatMessage(
role="system", content="你是一个专业的项目分析助手,擅长从会议记录中提取洞察。"
),
ChatMessage(role="user", content=prompt), ChatMessage(role="user", content=prompt),
] ]
@@ -211,7 +231,10 @@ class LLMClient:
async def analyze_entity_evolution(self, entity_name: str, mentions: list[dict]) -> str: async def analyze_entity_evolution(self, entity_name: str, mentions: list[dict]) -> str:
"""分析实体在项目中的演变/态度变化""" """分析实体在项目中的演变/态度变化"""
mentions_text = "\n".join( mentions_text = "\n".join(
[f"[{m.get('created_at', '未知时间')}] {m.get('text_snippet', '')}" for m in mentions[:20]] # 限制数量 [
f"[{m.get('created_at', '未知时间')}] {m.get('text_snippet', '')}"
for m in mentions[:20]
] # 限制数量
) )
prompt = f"""分析实体 "{entity_name}" 在项目中的演变和态度变化: prompt = f"""分析实体 "{entity_name}" 在项目中的演变和态度变化:
@@ -230,9 +253,11 @@ class LLMClient:
messages = [ChatMessage(role="user", content=prompt)] messages = [ChatMessage(role="user", content=prompt)]
return await self.chat(messages, temperature=0.3) return await self.chat(messages, temperature=0.3)
# Singleton instance # Singleton instance
_llm_client = None _llm_client = None
def get_llm_client() -> LLMClient: def get_llm_client() -> LLMClient:
global _llm_client global _llm_client
if _llm_client is None: if _llm_client is None:

View File

@@ -35,6 +35,7 @@ except ImportError:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class LanguageCode(StrEnum): class LanguageCode(StrEnum):
"""支持的语言代码""" """支持的语言代码"""
@@ -51,6 +52,7 @@ class LanguageCode(StrEnum):
AR = "ar" AR = "ar"
HI = "hi" HI = "hi"
class RegionCode(StrEnum): class RegionCode(StrEnum):
"""区域代码""" """区域代码"""
@@ -62,6 +64,7 @@ class RegionCode(StrEnum):
LATIN_AMERICA = "latam" LATIN_AMERICA = "latam"
MIDDLE_EAST = "me" MIDDLE_EAST = "me"
class DataCenterRegion(StrEnum): class DataCenterRegion(StrEnum):
"""数据中心区域""" """数据中心区域"""
@@ -75,6 +78,7 @@ class DataCenterRegion(StrEnum):
CN_NORTH = "cn-north" CN_NORTH = "cn-north"
CN_EAST = "cn-east" CN_EAST = "cn-east"
class PaymentProvider(StrEnum): class PaymentProvider(StrEnum):
"""支付提供商""" """支付提供商"""
@@ -91,6 +95,7 @@ class PaymentProvider(StrEnum):
SEPA = "sepa" SEPA = "sepa"
UNIONPAY = "unionpay" UNIONPAY = "unionpay"
class CalendarType(StrEnum): class CalendarType(StrEnum):
"""日历类型""" """日历类型"""
@@ -102,6 +107,7 @@ class CalendarType(StrEnum):
PERSIAN = "persian" PERSIAN = "persian"
BUDDHIST = "buddhist" BUDDHIST = "buddhist"
@dataclass @dataclass
class Translation: class Translation:
id: str id: str
@@ -116,6 +122,7 @@ class Translation:
reviewed_by: str | None reviewed_by: str | None
reviewed_at: datetime | None reviewed_at: datetime | None
@dataclass @dataclass
class LanguageConfig: class LanguageConfig:
code: str code: str
@@ -133,6 +140,7 @@ class LanguageConfig:
first_day_of_week: int first_day_of_week: int
calendar_type: str calendar_type: str
@dataclass @dataclass
class DataCenter: class DataCenter:
id: str id: str
@@ -147,6 +155,7 @@ class DataCenter:
created_at: datetime created_at: datetime
updated_at: datetime updated_at: datetime
@dataclass @dataclass
class TenantDataCenterMapping: class TenantDataCenterMapping:
id: str id: str
@@ -158,6 +167,7 @@ class TenantDataCenterMapping:
created_at: datetime created_at: datetime
updated_at: datetime updated_at: datetime
@dataclass @dataclass
class LocalizedPaymentMethod: class LocalizedPaymentMethod:
id: str id: str
@@ -175,6 +185,7 @@ class LocalizedPaymentMethod:
created_at: datetime created_at: datetime
updated_at: datetime updated_at: datetime
@dataclass @dataclass
class CountryConfig: class CountryConfig:
code: str code: str
@@ -196,6 +207,7 @@ class CountryConfig:
vat_rate: float | None vat_rate: float | None
is_active: bool is_active: bool
@dataclass @dataclass
class TimezoneConfig: class TimezoneConfig:
id: str id: str
@@ -206,6 +218,7 @@ class TimezoneConfig:
region: str region: str
is_active: bool is_active: bool
@dataclass @dataclass
class CurrencyConfig: class CurrencyConfig:
code: str code: str
@@ -217,6 +230,7 @@ class CurrencyConfig:
thousands_separator: str thousands_separator: str
is_active: bool is_active: bool
@dataclass @dataclass
class LocalizationSettings: class LocalizationSettings:
id: str id: str
@@ -236,6 +250,7 @@ class LocalizationSettings:
created_at: datetime created_at: datetime
updated_at: datetime updated_at: datetime
class LocalizationManager: class LocalizationManager:
DEFAULT_LANGUAGES = { DEFAULT_LANGUAGES = {
LanguageCode.EN: { LanguageCode.EN: {
@@ -807,16 +822,32 @@ class LocalizationManager:
) )
""") """)
cursor.execute("CREATE INDEX IF NOT EXISTS idx_translations_key ON translations(key)") cursor.execute("CREATE INDEX IF NOT EXISTS idx_translations_key ON translations(key)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_translations_lang ON translations(language)") cursor.execute(
cursor.execute("CREATE INDEX IF NOT EXISTS idx_translations_ns ON translations(namespace)") "CREATE INDEX IF NOT EXISTS idx_translations_lang ON translations(language)"
)
cursor.execute(
"CREATE INDEX IF NOT EXISTS idx_translations_ns ON translations(namespace)"
)
cursor.execute("CREATE INDEX IF NOT EXISTS idx_dc_region ON data_centers(region_code)") cursor.execute("CREATE INDEX IF NOT EXISTS idx_dc_region ON data_centers(region_code)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_dc_status ON data_centers(status)") cursor.execute("CREATE INDEX IF NOT EXISTS idx_dc_status ON data_centers(status)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_tenant_dc ON tenant_data_center_mappings(tenant_id)") cursor.execute(
cursor.execute("CREATE INDEX IF NOT EXISTS idx_payment_provider ON localized_payment_methods(provider)") "CREATE INDEX IF NOT EXISTS idx_tenant_dc ON tenant_data_center_mappings(tenant_id)"
cursor.execute("CREATE INDEX IF NOT EXISTS idx_payment_active ON localized_payment_methods(is_active)") )
cursor.execute("CREATE INDEX IF NOT EXISTS idx_country_region ON country_configs(region)") cursor.execute(
cursor.execute("CREATE INDEX IF NOT EXISTS idx_tz_country ON timezone_configs(country_code)") "CREATE INDEX IF NOT EXISTS idx_payment_provider ON localized_payment_methods(provider)"
cursor.execute("CREATE INDEX IF NOT EXISTS idx_locale_settings_tenant ON localization_settings(tenant_id)") )
cursor.execute(
"CREATE INDEX IF NOT EXISTS idx_payment_active ON localized_payment_methods(is_active)"
)
cursor.execute(
"CREATE INDEX IF NOT EXISTS idx_country_region ON country_configs(region)"
)
cursor.execute(
"CREATE INDEX IF NOT EXISTS idx_tz_country ON timezone_configs(country_code)"
)
cursor.execute(
"CREATE INDEX IF NOT EXISTS idx_locale_settings_tenant ON localization_settings(tenant_id)"
)
conn.commit() conn.commit()
logger.info("Localization tables initialized successfully") logger.info("Localization tables initialized successfully")
except Exception as e: except Exception as e:
@@ -923,7 +954,9 @@ class LocalizationManager:
finally: finally:
self._close_if_file_db(conn) self._close_if_file_db(conn)
def get_translation(self, key: str, language: str, namespace: str = "common", fallback: bool = True) -> str | None: def get_translation(
self, key: str, language: str, namespace: str = "common", fallback: bool = True
) -> str | None:
conn = self._get_connection() conn = self._get_connection()
try: try:
cursor = conn.cursor() cursor = conn.cursor()
@@ -937,7 +970,9 @@ class LocalizationManager:
if fallback: if fallback:
lang_config = self.get_language_config(language) lang_config = self.get_language_config(language)
if lang_config and lang_config.fallback_language: if lang_config and lang_config.fallback_language:
return self.get_translation(key, lang_config.fallback_language, namespace, False) return self.get_translation(
key, lang_config.fallback_language, namespace, False
)
if language != "en": if language != "en":
return self.get_translation(key, "en", namespace, False) return self.get_translation(key, "en", namespace, False)
return None return None
@@ -945,7 +980,12 @@ class LocalizationManager:
self._close_if_file_db(conn) self._close_if_file_db(conn)
def set_translation( def set_translation(
self, key: str, language: str, value: str, namespace: str = "common", context: str | None = None self,
key: str,
language: str,
value: str,
namespace: str = "common",
context: str | None = None,
) -> Translation: ) -> Translation:
conn = self._get_connection() conn = self._get_connection()
try: try:
@@ -971,7 +1011,8 @@ class LocalizationManager:
) -> Translation | None: ) -> Translation | None:
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute( cursor.execute(
"SELECT * FROM translations WHERE key = ? AND language = ? AND namespace = ?", (key, language, namespace) "SELECT * FROM translations WHERE key = ? AND language = ? AND namespace = ?",
(key, language, namespace),
) )
row = cursor.fetchone() row = cursor.fetchone()
if row: if row:
@@ -983,7 +1024,8 @@ class LocalizationManager:
try: try:
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute( cursor.execute(
"DELETE FROM translations WHERE key = ? AND language = ? AND namespace = ?", (key, language, namespace) "DELETE FROM translations WHERE key = ? AND language = ? AND namespace = ?",
(key, language, namespace),
) )
conn.commit() conn.commit()
return cursor.rowcount > 0 return cursor.rowcount > 0
@@ -991,7 +1033,11 @@ class LocalizationManager:
self._close_if_file_db(conn) self._close_if_file_db(conn)
def list_translations( def list_translations(
self, language: str | None = None, namespace: str | None = None, limit: int = 1000, offset: int = 0 self,
language: str | None = None,
namespace: str | None = None,
limit: int = 1000,
offset: int = 0,
) -> list[Translation]: ) -> list[Translation]:
conn = self._get_connection() conn = self._get_connection()
try: try:
@@ -1062,7 +1108,9 @@ class LocalizationManager:
finally: finally:
self._close_if_file_db(conn) self._close_if_file_db(conn)
def list_data_centers(self, status: str | None = None, region: str | None = None) -> list[DataCenter]: def list_data_centers(
self, status: str | None = None, region: str | None = None
) -> list[DataCenter]:
conn = self._get_connection() conn = self._get_connection()
try: try:
cursor = conn.cursor() cursor = conn.cursor()
@@ -1085,7 +1133,9 @@ class LocalizationManager:
conn = self._get_connection() conn = self._get_connection()
try: try:
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute("SELECT * FROM tenant_data_center_mappings WHERE tenant_id = ?", (tenant_id,)) cursor.execute(
"SELECT * FROM tenant_data_center_mappings WHERE tenant_id = ?", (tenant_id,)
)
row = cursor.fetchone() row = cursor.fetchone()
if row: if row:
return self._row_to_tenant_dc_mapping(row) return self._row_to_tenant_dc_mapping(row)
@@ -1135,7 +1185,16 @@ class LocalizationManager:
primary_dc_id = excluded.primary_dc_id, secondary_dc_id = excluded.secondary_dc_id, primary_dc_id = excluded.primary_dc_id, secondary_dc_id = excluded.secondary_dc_id,
region_code = excluded.region_code, data_residency = excluded.data_residency, updated_at = excluded.updated_at region_code = excluded.region_code, data_residency = excluded.data_residency, updated_at = excluded.updated_at
""", """,
(mapping_id, tenant_id, primary_dc_id, secondary_dc_id, region_code, data_residency, now, now), (
mapping_id,
tenant_id,
primary_dc_id,
secondary_dc_id,
region_code,
data_residency,
now,
now,
),
) )
conn.commit() conn.commit()
return self.get_tenant_data_center(tenant_id) return self.get_tenant_data_center(tenant_id)
@@ -1146,7 +1205,9 @@ class LocalizationManager:
conn = self._get_connection() conn = self._get_connection()
try: try:
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute("SELECT * FROM localized_payment_methods WHERE provider = ?", (provider,)) cursor.execute(
"SELECT * FROM localized_payment_methods WHERE provider = ?", (provider,)
)
row = cursor.fetchone() row = cursor.fetchone()
if row: if row:
return self._row_to_payment_method(row) return self._row_to_payment_method(row)
@@ -1177,7 +1238,9 @@ class LocalizationManager:
finally: finally:
self._close_if_file_db(conn) self._close_if_file_db(conn)
def get_localized_payment_methods(self, country_code: str, language: str = "en") -> list[dict[str, Any]]: def get_localized_payment_methods(
self, country_code: str, language: str = "en"
) -> list[dict[str, Any]]:
methods = self.list_payment_methods(country_code=country_code) methods = self.list_payment_methods(country_code=country_code)
result = [] result = []
for method in methods: for method in methods:
@@ -1207,7 +1270,9 @@ class LocalizationManager:
finally: finally:
self._close_if_file_db(conn) self._close_if_file_db(conn)
def list_country_configs(self, region: str | None = None, active_only: bool = True) -> list[CountryConfig]: def list_country_configs(
self, region: str | None = None, active_only: bool = True
) -> list[CountryConfig]:
conn = self._get_connection() conn = self._get_connection()
try: try:
cursor = conn.cursor() cursor = conn.cursor()
@@ -1226,7 +1291,11 @@ class LocalizationManager:
self._close_if_file_db(conn) self._close_if_file_db(conn)
def format_datetime( def format_datetime(
self, dt: datetime, language: str = "en", timezone: str | None = None, format_type: str = "datetime" self,
dt: datetime,
language: str = "en",
timezone: str | None = None,
format_type: str = "datetime",
) -> str: ) -> str:
try: try:
if timezone and PYTZ_AVAILABLE: if timezone and PYTZ_AVAILABLE:
@@ -1259,7 +1328,9 @@ class LocalizationManager:
logger.error(f"Error formatting datetime: {e}") logger.error(f"Error formatting datetime: {e}")
return dt.strftime("%Y-%m-%d %H:%M") return dt.strftime("%Y-%m-%d %H:%M")
def format_number(self, number: float, language: str = "en", decimal_places: int | None = None) -> str: def format_number(
self, number: float, language: str = "en", decimal_places: int | None = None
) -> str:
try: try:
if BABEL_AVAILABLE: if BABEL_AVAILABLE:
try: try:
@@ -1417,7 +1488,9 @@ class LocalizationManager:
params.append(datetime.now()) params.append(datetime.now())
params.append(tenant_id) params.append(tenant_id)
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute(f"UPDATE localization_settings SET {', '.join(updates)} WHERE tenant_id = ?", params) cursor.execute(
f"UPDATE localization_settings SET {', '.join(updates)} WHERE tenant_id = ?", params
)
conn.commit() conn.commit()
return self.get_localization_settings(tenant_id) return self.get_localization_settings(tenant_id)
finally: finally:
@@ -1454,10 +1527,14 @@ class LocalizationManager:
namespace=row["namespace"], namespace=row["namespace"],
context=row["context"], context=row["context"],
created_at=( created_at=(
datetime.fromisoformat(row["created_at"]) if isinstance(row["created_at"], str) else row["created_at"] datetime.fromisoformat(row["created_at"])
if isinstance(row["created_at"], str)
else row["created_at"]
), ),
updated_at=( updated_at=(
datetime.fromisoformat(row["updated_at"]) if isinstance(row["updated_at"], str) else row["updated_at"] datetime.fromisoformat(row["updated_at"])
if isinstance(row["updated_at"], str)
else row["updated_at"]
), ),
is_reviewed=bool(row["is_reviewed"]), is_reviewed=bool(row["is_reviewed"]),
reviewed_by=row["reviewed_by"], reviewed_by=row["reviewed_by"],
@@ -1498,10 +1575,14 @@ class LocalizationManager:
supported_regions=json.loads(row["supported_regions"] or "[]"), supported_regions=json.loads(row["supported_regions"] or "[]"),
capabilities=json.loads(row["capabilities"] or "{}"), capabilities=json.loads(row["capabilities"] or "{}"),
created_at=( created_at=(
datetime.fromisoformat(row["created_at"]) if isinstance(row["created_at"], str) else row["created_at"] datetime.fromisoformat(row["created_at"])
if isinstance(row["created_at"], str)
else row["created_at"]
), ),
updated_at=( updated_at=(
datetime.fromisoformat(row["updated_at"]) if isinstance(row["updated_at"], str) else row["updated_at"] datetime.fromisoformat(row["updated_at"])
if isinstance(row["updated_at"], str)
else row["updated_at"]
), ),
) )
@@ -1514,10 +1595,14 @@ class LocalizationManager:
region_code=row["region_code"], region_code=row["region_code"],
data_residency=row["data_residency"], data_residency=row["data_residency"],
created_at=( created_at=(
datetime.fromisoformat(row["created_at"]) if isinstance(row["created_at"], str) else row["created_at"] datetime.fromisoformat(row["created_at"])
if isinstance(row["created_at"], str)
else row["created_at"]
), ),
updated_at=( updated_at=(
datetime.fromisoformat(row["updated_at"]) if isinstance(row["updated_at"], str) else row["updated_at"] datetime.fromisoformat(row["updated_at"])
if isinstance(row["updated_at"], str)
else row["updated_at"]
), ),
) )
@@ -1536,10 +1621,14 @@ class LocalizationManager:
min_amount=row["min_amount"], min_amount=row["min_amount"],
max_amount=row["max_amount"], max_amount=row["max_amount"],
created_at=( created_at=(
datetime.fromisoformat(row["created_at"]) if isinstance(row["created_at"], str) else row["created_at"] datetime.fromisoformat(row["created_at"])
if isinstance(row["created_at"], str)
else row["created_at"]
), ),
updated_at=( updated_at=(
datetime.fromisoformat(row["updated_at"]) if isinstance(row["updated_at"], str) else row["updated_at"] datetime.fromisoformat(row["updated_at"])
if isinstance(row["updated_at"], str)
else row["updated_at"]
), ),
) )
@@ -1582,15 +1671,21 @@ class LocalizationManager:
region_code=row["region_code"], region_code=row["region_code"],
data_residency=row["data_residency"], data_residency=row["data_residency"],
created_at=( created_at=(
datetime.fromisoformat(row["created_at"]) if isinstance(row["created_at"], str) else row["created_at"] datetime.fromisoformat(row["created_at"])
if isinstance(row["created_at"], str)
else row["created_at"]
), ),
updated_at=( updated_at=(
datetime.fromisoformat(row["updated_at"]) if isinstance(row["updated_at"], str) else row["updated_at"] datetime.fromisoformat(row["updated_at"])
if isinstance(row["updated_at"], str)
else row["updated_at"]
), ),
) )
_localization_manager = None _localization_manager = None
def get_localization_manager(db_path: str = "insightflow.db") -> LocalizationManager: def get_localization_manager(db_path: str = "insightflow.db") -> LocalizationManager:
global _localization_manager global _localization_manager
if _localization_manager is None: if _localization_manager is None:

File diff suppressed because it is too large Load Diff

View File

@@ -14,6 +14,7 @@ try:
except ImportError: except ImportError:
NUMPY_AVAILABLE = False NUMPY_AVAILABLE = False
@dataclass @dataclass
class MultimodalEntity: class MultimodalEntity:
"""多模态实体""" """多模态实体"""
@@ -32,6 +33,7 @@ class MultimodalEntity:
if self.modality_features is None: if self.modality_features is None:
self.modality_features = {} self.modality_features = {}
@dataclass @dataclass
class EntityLink: class EntityLink:
"""实体关联""" """实体关联"""
@@ -46,6 +48,7 @@ class EntityLink:
confidence: float confidence: float
evidence: str evidence: str
@dataclass @dataclass
class AlignmentResult: class AlignmentResult:
"""对齐结果""" """对齐结果"""
@@ -56,6 +59,7 @@ class AlignmentResult:
match_type: str # exact, fuzzy, embedding match_type: str # exact, fuzzy, embedding
confidence: float confidence: float
@dataclass @dataclass
class FusionResult: class FusionResult:
"""知识融合结果""" """知识融合结果"""
@@ -66,11 +70,17 @@ class FusionResult:
source_modalities: list[str] source_modalities: list[str]
confidence: float confidence: float
class MultimodalEntityLinker: class MultimodalEntityLinker:
"""多模态实体关联器 - 跨模态实体对齐和知识融合""" """多模态实体关联器 - 跨模态实体对齐和知识融合"""
# 关联类型 # 关联类型
LINK_TYPES = {"same_as": "同一实体", "related_to": "相关实体", "part_of": "组成部分", "mentions": "提及关系"} LINK_TYPES = {
"same_as": "同一实体",
"related_to": "相关实体",
"part_of": "组成部分",
"mentions": "提及关系",
}
# 模态类型 # 模态类型
MODALITIES = ["audio", "video", "image", "document"] MODALITIES = ["audio", "video", "image", "document"]
@@ -123,7 +133,9 @@ class MultimodalEntityLinker:
(相似度, 匹配类型) (相似度, 匹配类型)
""" """
# 名称相似度 # 名称相似度
name_sim = self.calculate_string_similarity(entity1.get("name", ""), entity2.get("name", "")) name_sim = self.calculate_string_similarity(
entity1.get("name", ""), entity2.get("name", "")
)
# 如果名称完全匹配 # 如果名称完全匹配
if name_sim == 1.0: if name_sim == 1.0:
@@ -142,7 +154,9 @@ class MultimodalEntityLinker:
return 0.95, "alias_match" return 0.95, "alias_match"
# 定义相似度 # 定义相似度
def_sim = self.calculate_string_similarity(entity1.get("definition", ""), entity2.get("definition", "")) def_sim = self.calculate_string_similarity(
entity1.get("definition", ""), entity2.get("definition", "")
)
# 综合相似度 # 综合相似度
combined_sim = name_sim * 0.7 + def_sim * 0.3 combined_sim = name_sim * 0.7 + def_sim * 0.3
@@ -301,7 +315,9 @@ class MultimodalEntityLinker:
fused_properties["contexts"].append(mention.get("mention_context")) fused_properties["contexts"].append(mention.get("mention_context"))
# 选择最佳定义(最长的那个) # 选择最佳定义(最长的那个)
best_definition = max(fused_properties["definitions"], key=len) if fused_properties["definitions"] else "" best_definition = (
max(fused_properties["definitions"], key=len) if fused_properties["definitions"] else ""
)
# 选择最佳名称(最常见的那个) # 选择最佳名称(最常见的那个)
from collections import Counter from collections import Counter
@@ -374,7 +390,9 @@ class MultimodalEntityLinker:
return conflicts return conflicts
def suggest_entity_merges(self, entities: list[dict], existing_links: list[EntityLink] = None) -> list[dict]: def suggest_entity_merges(
self, entities: list[dict], existing_links: list[EntityLink] = None
) -> list[dict]:
""" """
建议实体合并 建议实体合并
@@ -489,12 +507,16 @@ class MultimodalEntityLinker:
"total_multimodal_records": len(multimodal_entities), "total_multimodal_records": len(multimodal_entities),
"unique_entities": len(entity_modalities), "unique_entities": len(entity_modalities),
"cross_modal_entities": cross_modal_count, "cross_modal_entities": cross_modal_count,
"cross_modal_ratio": cross_modal_count / len(entity_modalities) if entity_modalities else 0, "cross_modal_ratio": cross_modal_count / len(entity_modalities)
if entity_modalities
else 0,
} }
# Singleton instance # Singleton instance
_multimodal_entity_linker = None _multimodal_entity_linker = None
def get_multimodal_entity_linker(similarity_threshold: float = 0.85) -> MultimodalEntityLinker: def get_multimodal_entity_linker(similarity_threshold: float = 0.85) -> MultimodalEntityLinker:
"""获取多模态实体关联器单例""" """获取多模态实体关联器单例"""
global _multimodal_entity_linker global _multimodal_entity_linker

View File

@@ -35,6 +35,7 @@ try:
except ImportError: except ImportError:
FFMPEG_AVAILABLE = False FFMPEG_AVAILABLE = False
@dataclass @dataclass
class VideoFrame: class VideoFrame:
"""视频关键帧数据类""" """视频关键帧数据类"""
@@ -52,6 +53,7 @@ class VideoFrame:
if self.entities_detected is None: if self.entities_detected is None:
self.entities_detected = [] self.entities_detected = []
@dataclass @dataclass
class VideoInfo: class VideoInfo:
"""视频信息数据类""" """视频信息数据类"""
@@ -75,6 +77,7 @@ class VideoInfo:
if self.metadata is None: if self.metadata is None:
self.metadata = {} self.metadata = {}
@dataclass @dataclass
class VideoProcessingResult: class VideoProcessingResult:
"""视频处理结果""" """视频处理结果"""
@@ -87,6 +90,7 @@ class VideoProcessingResult:
success: bool success: bool
error_message: str = "" error_message: str = ""
class MultimodalProcessor: class MultimodalProcessor:
"""多模态处理器 - 处理视频文件""" """多模态处理器 - 处理视频文件"""
@@ -122,8 +126,12 @@ class MultimodalProcessor:
try: try:
if FFMPEG_AVAILABLE: if FFMPEG_AVAILABLE:
probe = ffmpeg.probe(video_path) probe = ffmpeg.probe(video_path)
video_stream = next((s for s in probe["streams"] if s["codec_type"] == "video"), None) video_stream = next(
audio_stream = next((s for s in probe["streams"] if s["codec_type"] == "audio"), None) (s for s in probe["streams"] if s["codec_type"] == "video"), None
)
audio_stream = next(
(s for s in probe["streams"] if s["codec_type"] == "audio"), None
)
if video_stream: if video_stream:
return { return {
@@ -154,7 +162,9 @@ class MultimodalProcessor:
return { return {
"duration": float(data["format"].get("duration", 0)), "duration": float(data["format"].get("duration", 0)),
"width": int(data["streams"][0].get("width", 0)) if data["streams"] else 0, "width": int(data["streams"][0].get("width", 0)) if data["streams"] else 0,
"height": int(data["streams"][0].get("height", 0)) if data["streams"] else 0, "height": int(data["streams"][0].get("height", 0))
if data["streams"]
else 0,
"fps": 30.0, # 默认值 "fps": 30.0, # 默认值
"has_audio": len(data["streams"]) > 1, "has_audio": len(data["streams"]) > 1,
"bitrate": int(data["format"].get("bit_rate", 0)), "bitrate": int(data["format"].get("bit_rate", 0)),
@@ -246,7 +256,9 @@ class MultimodalProcessor:
if frame_number % frame_interval_frames == 0: if frame_number % frame_interval_frames == 0:
timestamp = frame_number / fps timestamp = frame_number / fps
frame_path = os.path.join(video_frames_dir, f"frame_{frame_number:06d}_{timestamp:.2f}.jpg") frame_path = os.path.join(
video_frames_dir, f"frame_{frame_number:06d}_{timestamp:.2f}.jpg"
)
cv2.imwrite(frame_path, frame) cv2.imwrite(frame_path, frame)
frame_paths.append(frame_path) frame_paths.append(frame_path)
@@ -258,12 +270,26 @@ class MultimodalProcessor:
Path(video_path).stem Path(video_path).stem
output_pattern = os.path.join(video_frames_dir, "frame_%06d_%t.jpg") output_pattern = os.path.join(video_frames_dir, "frame_%06d_%t.jpg")
cmd = ["ffmpeg", "-i", video_path, "-vf", f"fps=1/{interval}", "-frame_pts", "1", "-y", output_pattern] cmd = [
"ffmpeg",
"-i",
video_path,
"-vf",
f"fps=1/{interval}",
"-frame_pts",
"1",
"-y",
output_pattern,
]
subprocess.run(cmd, check=True, capture_output=True) subprocess.run(cmd, check=True, capture_output=True)
# 获取生成的帧文件列表 # 获取生成的帧文件列表
frame_paths = sorted( frame_paths = sorted(
[os.path.join(video_frames_dir, f) for f in os.listdir(video_frames_dir) if f.startswith("frame_")] [
os.path.join(video_frames_dir, f)
for f in os.listdir(video_frames_dir)
if f.startswith("frame_")
]
) )
except Exception as e: except Exception as e:
print(f"Error extracting keyframes: {e}") print(f"Error extracting keyframes: {e}")
@@ -409,7 +435,9 @@ class MultimodalProcessor:
if video_id: if video_id:
# 清理特定视频的文件 # 清理特定视频的文件
for dir_path in [self.video_dir, self.frames_dir, self.audio_dir]: for dir_path in [self.video_dir, self.frames_dir, self.audio_dir]:
target_dir = os.path.join(dir_path, video_id) if dir_path == self.frames_dir else dir_path target_dir = (
os.path.join(dir_path, video_id) if dir_path == self.frames_dir else dir_path
)
if os.path.exists(target_dir): if os.path.exists(target_dir):
for f in os.listdir(target_dir): for f in os.listdir(target_dir):
if video_id in f: if video_id in f:
@@ -421,9 +449,11 @@ class MultimodalProcessor:
shutil.rmtree(dir_path) shutil.rmtree(dir_path)
os.makedirs(dir_path, exist_ok=True) os.makedirs(dir_path, exist_ok=True)
# Singleton instance # Singleton instance
_multimodal_processor = None _multimodal_processor = None
def get_multimodal_processor(temp_dir: str = None, frame_interval: int = 5) -> MultimodalProcessor: def get_multimodal_processor(temp_dir: str = None, frame_interval: int = 5) -> MultimodalProcessor:
"""获取多模态处理器单例""" """获取多模态处理器单例"""
global _multimodal_processor global _multimodal_processor

View File

@@ -26,6 +26,7 @@ except ImportError:
NEO4J_AVAILABLE = False NEO4J_AVAILABLE = False
logger.warning("Neo4j driver not installed. Neo4j features will be disabled.") logger.warning("Neo4j driver not installed. Neo4j features will be disabled.")
@dataclass @dataclass
class GraphEntity: class GraphEntity:
"""图数据库中的实体节点""" """图数据库中的实体节点"""
@@ -44,6 +45,7 @@ class GraphEntity:
if self.properties is None: if self.properties is None:
self.properties = {} self.properties = {}
@dataclass @dataclass
class GraphRelation: class GraphRelation:
"""图数据库中的关系边""" """图数据库中的关系边"""
@@ -59,6 +61,7 @@ class GraphRelation:
if self.properties is None: if self.properties is None:
self.properties = {} self.properties = {}
@dataclass @dataclass
class PathResult: class PathResult:
"""路径查询结果""" """路径查询结果"""
@@ -68,6 +71,7 @@ class PathResult:
length: int length: int
total_weight: float = 0.0 total_weight: float = 0.0
@dataclass @dataclass
class CommunityResult: class CommunityResult:
"""社区发现结果""" """社区发现结果"""
@@ -77,6 +81,7 @@ class CommunityResult:
size: int size: int
density: float = 0.0 density: float = 0.0
@dataclass @dataclass
class CentralityResult: class CentralityResult:
"""中心性分析结果""" """中心性分析结果"""
@@ -86,6 +91,7 @@ class CentralityResult:
score: float score: float
rank: int = 0 rank: int = 0
class Neo4jManager: class Neo4jManager:
"""Neo4j 图数据库管理器""" """Neo4j 图数据库管理器"""
@@ -172,7 +178,9 @@ class Neo4jManager:
# ==================== 数据同步 ==================== # ==================== 数据同步 ====================
def sync_project(self, project_id: str, project_name: str, project_description: str = "") -> None: def sync_project(
self, project_id: str, project_name: str, project_description: str = ""
) -> None:
"""同步项目节点到 Neo4j""" """同步项目节点到 Neo4j"""
if not self._driver: if not self._driver:
return return
@@ -343,7 +351,9 @@ class Neo4jManager:
# ==================== 复杂图查询 ==================== # ==================== 复杂图查询 ====================
def find_shortest_path(self, source_id: str, target_id: str, max_depth: int = 10) -> PathResult | None: def find_shortest_path(
self, source_id: str, target_id: str, max_depth: int = 10
) -> PathResult | None:
""" """
查找两个实体之间的最短路径 查找两个实体之间的最短路径
@@ -378,7 +388,10 @@ class Neo4jManager:
path = record["path"] path = record["path"]
# 提取节点和关系 # 提取节点和关系
nodes = [{"id": node["id"], "name": node["name"], "type": node["type"]} for node in path.nodes] nodes = [
{"id": node["id"], "name": node["name"], "type": node["type"]}
for node in path.nodes
]
relationships = [ relationships = [
{ {
@@ -390,9 +403,13 @@ class Neo4jManager:
for rel in path.relationships for rel in path.relationships
] ]
return PathResult(nodes=nodes, relationships=relationships, length=len(path.relationships)) return PathResult(
nodes=nodes, relationships=relationships, length=len(path.relationships)
)
def find_all_paths(self, source_id: str, target_id: str, max_depth: int = 5, limit: int = 10) -> list[PathResult]: def find_all_paths(
self, source_id: str, target_id: str, max_depth: int = 5, limit: int = 10
) -> list[PathResult]:
""" """
查找两个实体之间的所有路径 查找两个实体之间的所有路径
@@ -426,7 +443,10 @@ class Neo4jManager:
for record in result: for record in result:
path = record["path"] path = record["path"]
nodes = [{"id": node["id"], "name": node["name"], "type": node["type"]} for node in path.nodes] nodes = [
{"id": node["id"], "name": node["name"], "type": node["type"]}
for node in path.nodes
]
relationships = [ relationships = [
{ {
@@ -438,11 +458,17 @@ class Neo4jManager:
for rel in path.relationships for rel in path.relationships
] ]
paths.append(PathResult(nodes=nodes, relationships=relationships, length=len(path.relationships))) paths.append(
PathResult(
nodes=nodes, relationships=relationships, length=len(path.relationships)
)
)
return paths return paths
def find_neighbors(self, entity_id: str, relation_type: str = None, limit: int = 50) -> list[dict]: def find_neighbors(
self, entity_id: str, relation_type: str = None, limit: int = 50
) -> list[dict]:
""" """
查找实体的邻居节点 查找实体的邻居节点
@@ -520,7 +546,11 @@ class Neo4jManager:
) )
return [ return [
{"id": record["common"]["id"], "name": record["common"]["name"], "type": record["common"]["type"]} {
"id": record["common"]["id"],
"name": record["common"]["name"],
"type": record["common"]["type"],
}
for record in result for record in result
] ]
@@ -720,13 +750,19 @@ class Neo4jManager:
actual_edges = sum(n["connections"] for n in nodes) / 2 actual_edges = sum(n["connections"] for n in nodes) / 2
density = actual_edges / max_edges if max_edges > 0 else 0 density = actual_edges / max_edges if max_edges > 0 else 0
results.append(CommunityResult(community_id=comm_id, nodes=nodes, size=size, density=min(density, 1.0))) results.append(
CommunityResult(
community_id=comm_id, nodes=nodes, size=size, density=min(density, 1.0)
)
)
# 按大小排序 # 按大小排序
results.sort(key=lambda x: x.size, reverse=True) results.sort(key=lambda x: x.size, reverse=True)
return results return results
def find_central_entities(self, project_id: str, metric: str = "degree") -> list[CentralityResult]: def find_central_entities(
self, project_id: str, metric: str = "degree"
) -> list[CentralityResult]:
""" """
查找中心实体 查找中心实体
@@ -860,7 +896,9 @@ class Neo4jManager:
"type_distribution": types, "type_distribution": types,
"average_degree": round(avg_degree, 2) if avg_degree else 0, "average_degree": round(avg_degree, 2) if avg_degree else 0,
"relation_type_distribution": relation_types, "relation_type_distribution": relation_types,
"density": round(relation_count / (entity_count * (entity_count - 1)), 4) if entity_count > 1 else 0, "density": round(relation_count / (entity_count * (entity_count - 1)), 4)
if entity_count > 1
else 0,
} }
def get_subgraph(self, entity_ids: list[str], depth: int = 1) -> dict: def get_subgraph(self, entity_ids: list[str], depth: int = 1) -> dict:
@@ -930,9 +968,11 @@ class Neo4jManager:
return {"nodes": nodes, "relationships": relationships} return {"nodes": nodes, "relationships": relationships}
# 全局单例 # 全局单例
_neo4j_manager = None _neo4j_manager = None
def get_neo4j_manager() -> Neo4jManager: def get_neo4j_manager() -> Neo4jManager:
"""获取 Neo4j 管理器单例""" """获取 Neo4j 管理器单例"""
global _neo4j_manager global _neo4j_manager
@@ -940,6 +980,7 @@ def get_neo4j_manager() -> Neo4jManager:
_neo4j_manager = Neo4jManager() _neo4j_manager = Neo4jManager()
return _neo4j_manager return _neo4j_manager
def close_neo4j_manager() -> None: def close_neo4j_manager() -> None:
"""关闭 Neo4j 连接""" """关闭 Neo4j 连接"""
global _neo4j_manager global _neo4j_manager
@@ -947,8 +988,11 @@ def close_neo4j_manager() -> None:
_neo4j_manager.close() _neo4j_manager.close()
_neo4j_manager = None _neo4j_manager = None
# 便捷函数 # 便捷函数
def sync_project_to_neo4j(project_id: str, project_name: str, entities: list[dict], relations: list[dict]) -> None: def sync_project_to_neo4j(
project_id: str, project_name: str, entities: list[dict], relations: list[dict]
) -> None:
""" """
同步整个项目到 Neo4j 同步整个项目到 Neo4j
@@ -995,7 +1039,10 @@ def sync_project_to_neo4j(project_id: str, project_name: str, entities: list[dic
] ]
manager.sync_relations_batch(graph_relations) manager.sync_relations_batch(graph_relations)
logger.info(f"Synced project {project_id} to Neo4j: {len(entities)} entities, {len(relations)} relations") logger.info(
f"Synced project {project_id} to Neo4j: {len(entities)} entities, {len(relations)} relations"
)
if __name__ == "__main__": if __name__ == "__main__":
# 测试代码 # 测试代码
@@ -1016,7 +1063,11 @@ if __name__ == "__main__":
# 测试实体 # 测试实体
test_entity = GraphEntity( test_entity = GraphEntity(
id="test-entity-1", project_id="test-project", name="Test Entity", type="Person", definition="A test entity" id="test-entity-1",
project_id="test-project",
name="Test Entity",
type="Person",
definition="A test entity",
) )
manager.sync_entity(test_entity) manager.sync_entity(test_entity)
print("✅ Entity synced") print("✅ Entity synced")

View File

@@ -29,6 +29,7 @@ import httpx
# Database path # Database path
DB_PATH = os.path.join(os.path.dirname(__file__), "insightflow.db") DB_PATH = os.path.join(os.path.dirname(__file__), "insightflow.db")
class AlertSeverity(StrEnum): class AlertSeverity(StrEnum):
"""告警严重级别 P0-P3""" """告警严重级别 P0-P3"""
@@ -37,6 +38,7 @@ class AlertSeverity(StrEnum):
P2 = "p2" # 一般 - 部分功能受影响需要4小时内处理 P2 = "p2" # 一般 - 部分功能受影响需要4小时内处理
P3 = "p3" # 轻微 - 非核心功能问题24小时内处理 P3 = "p3" # 轻微 - 非核心功能问题24小时内处理
class AlertStatus(StrEnum): class AlertStatus(StrEnum):
"""告警状态""" """告警状态"""
@@ -45,6 +47,7 @@ class AlertStatus(StrEnum):
ACKNOWLEDGED = "acknowledged" # 已确认 ACKNOWLEDGED = "acknowledged" # 已确认
SUPPRESSED = "suppressed" # 已抑制 SUPPRESSED = "suppressed" # 已抑制
class AlertChannelType(StrEnum): class AlertChannelType(StrEnum):
"""告警渠道类型""" """告警渠道类型"""
@@ -57,6 +60,7 @@ class AlertChannelType(StrEnum):
SMS = "sms" SMS = "sms"
WEBHOOK = "webhook" WEBHOOK = "webhook"
class AlertRuleType(StrEnum): class AlertRuleType(StrEnum):
"""告警规则类型""" """告警规则类型"""
@@ -65,6 +69,7 @@ class AlertRuleType(StrEnum):
PREDICTIVE = "predictive" # 预测性告警 PREDICTIVE = "predictive" # 预测性告警
COMPOSITE = "composite" # 复合告警 COMPOSITE = "composite" # 复合告警
class ResourceType(StrEnum): class ResourceType(StrEnum):
"""资源类型""" """资源类型"""
@@ -77,6 +82,7 @@ class ResourceType(StrEnum):
CACHE = "cache" CACHE = "cache"
QUEUE = "queue" QUEUE = "queue"
class ScalingAction(StrEnum): class ScalingAction(StrEnum):
"""扩缩容动作""" """扩缩容动作"""
@@ -84,6 +90,7 @@ class ScalingAction(StrEnum):
SCALE_DOWN = "scale_down" # 缩容 SCALE_DOWN = "scale_down" # 缩容
MAINTAIN = "maintain" # 保持 MAINTAIN = "maintain" # 保持
class HealthStatus(StrEnum): class HealthStatus(StrEnum):
"""健康状态""" """健康状态"""
@@ -92,6 +99,7 @@ class HealthStatus(StrEnum):
UNHEALTHY = "unhealthy" UNHEALTHY = "unhealthy"
UNKNOWN = "unknown" UNKNOWN = "unknown"
class BackupStatus(StrEnum): class BackupStatus(StrEnum):
"""备份状态""" """备份状态"""
@@ -101,6 +109,7 @@ class BackupStatus(StrEnum):
FAILED = "failed" FAILED = "failed"
VERIFIED = "verified" VERIFIED = "verified"
@dataclass @dataclass
class AlertRule: class AlertRule:
"""告警规则""" """告警规则"""
@@ -124,6 +133,7 @@ class AlertRule:
updated_at: str updated_at: str
created_by: str created_by: str
@dataclass @dataclass
class AlertChannel: class AlertChannel:
"""告警渠道配置""" """告警渠道配置"""
@@ -141,6 +151,7 @@ class AlertChannel:
created_at: str created_at: str
updated_at: str updated_at: str
@dataclass @dataclass
class Alert: class Alert:
"""告警实例""" """告警实例"""
@@ -164,6 +175,7 @@ class Alert:
notification_sent: dict[str, bool] # 渠道发送状态 notification_sent: dict[str, bool] # 渠道发送状态
suppression_count: int # 抑制计数 suppression_count: int # 抑制计数
@dataclass @dataclass
class AlertSuppressionRule: class AlertSuppressionRule:
"""告警抑制规则""" """告警抑制规则"""
@@ -177,6 +189,7 @@ class AlertSuppressionRule:
created_at: str created_at: str
expires_at: str | None expires_at: str | None
@dataclass @dataclass
class AlertGroup: class AlertGroup:
"""告警聚合组""" """告警聚合组"""
@@ -188,6 +201,7 @@ class AlertGroup:
created_at: str created_at: str
updated_at: str updated_at: str
@dataclass @dataclass
class ResourceMetric: class ResourceMetric:
"""资源指标""" """资源指标"""
@@ -202,6 +216,7 @@ class ResourceMetric:
timestamp: str timestamp: str
metadata: dict metadata: dict
@dataclass @dataclass
class CapacityPlan: class CapacityPlan:
"""容量规划""" """容量规划"""
@@ -217,6 +232,7 @@ class CapacityPlan:
estimated_cost: float estimated_cost: float
created_at: str created_at: str
@dataclass @dataclass
class AutoScalingPolicy: class AutoScalingPolicy:
"""自动扩缩容策略""" """自动扩缩容策略"""
@@ -237,6 +253,7 @@ class AutoScalingPolicy:
created_at: str created_at: str
updated_at: str updated_at: str
@dataclass @dataclass
class ScalingEvent: class ScalingEvent:
"""扩缩容事件""" """扩缩容事件"""
@@ -254,6 +271,7 @@ class ScalingEvent:
completed_at: str | None completed_at: str | None
error_message: str | None error_message: str | None
@dataclass @dataclass
class HealthCheck: class HealthCheck:
"""健康检查配置""" """健康检查配置"""
@@ -274,6 +292,7 @@ class HealthCheck:
created_at: str created_at: str
updated_at: str updated_at: str
@dataclass @dataclass
class HealthCheckResult: class HealthCheckResult:
"""健康检查结果""" """健康检查结果"""
@@ -287,6 +306,7 @@ class HealthCheckResult:
details: dict details: dict
checked_at: str checked_at: str
@dataclass @dataclass
class FailoverConfig: class FailoverConfig:
"""故障转移配置""" """故障转移配置"""
@@ -304,6 +324,7 @@ class FailoverConfig:
created_at: str created_at: str
updated_at: str updated_at: str
@dataclass @dataclass
class FailoverEvent: class FailoverEvent:
"""故障转移事件""" """故障转移事件"""
@@ -319,6 +340,7 @@ class FailoverEvent:
completed_at: str | None completed_at: str | None
rolled_back_at: str | None rolled_back_at: str | None
@dataclass @dataclass
class BackupJob: class BackupJob:
"""备份任务""" """备份任务"""
@@ -338,6 +360,7 @@ class BackupJob:
created_at: str created_at: str
updated_at: str updated_at: str
@dataclass @dataclass
class BackupRecord: class BackupRecord:
"""备份记录""" """备份记录"""
@@ -354,6 +377,7 @@ class BackupRecord:
error_message: str | None error_message: str | None
storage_path: str storage_path: str
@dataclass @dataclass
class CostReport: class CostReport:
"""成本报告""" """成本报告"""
@@ -368,6 +392,7 @@ class CostReport:
anomalies: list[dict] # 异常检测 anomalies: list[dict] # 异常检测
created_at: str created_at: str
@dataclass @dataclass
class ResourceUtilization: class ResourceUtilization:
"""资源利用率""" """资源利用率"""
@@ -383,6 +408,7 @@ class ResourceUtilization:
report_date: str report_date: str
recommendations: list[str] recommendations: list[str]
@dataclass @dataclass
class IdleResource: class IdleResource:
"""闲置资源""" """闲置资源"""
@@ -399,6 +425,7 @@ class IdleResource:
recommendation: str recommendation: str
detected_at: str detected_at: str
@dataclass @dataclass
class CostOptimizationSuggestion: class CostOptimizationSuggestion:
"""成本优化建议""" """成本优化建议"""
@@ -418,6 +445,7 @@ class CostOptimizationSuggestion:
created_at: str created_at: str
applied_at: str | None applied_at: str | None
class OpsManager: class OpsManager:
"""运维与监控管理主类""" """运维与监控管理主类"""
@@ -577,7 +605,10 @@ class OpsManager:
with self._get_db() as conn: with self._get_db() as conn:
set_clause = ", ".join([f"{k} = ?" for k in updates.keys()]) set_clause = ", ".join([f"{k} = ?" for k in updates.keys()])
conn.execute(f"UPDATE alert_rules SET {set_clause} WHERE id = ?", list(updates.values()) + [rule_id]) conn.execute(
f"UPDATE alert_rules SET {set_clause} WHERE id = ?",
list(updates.values()) + [rule_id],
)
conn.commit() conn.commit()
return self.get_alert_rule(rule_id) return self.get_alert_rule(rule_id)
@@ -592,7 +623,12 @@ class OpsManager:
# ==================== 告警渠道管理 ==================== # ==================== 告警渠道管理 ====================
def create_alert_channel( def create_alert_channel(
self, tenant_id: str, name: str, channel_type: AlertChannelType, config: dict, severity_filter: list[str] = None self,
tenant_id: str,
name: str,
channel_type: AlertChannelType,
config: dict,
severity_filter: list[str] = None,
) -> AlertChannel: ) -> AlertChannel:
"""创建告警渠道""" """创建告警渠道"""
channel_id = f"ac_{uuid.uuid4().hex[:16]}" channel_id = f"ac_{uuid.uuid4().hex[:16]}"
@@ -643,7 +679,9 @@ class OpsManager:
def get_alert_channel(self, channel_id: str) -> AlertChannel | None: def get_alert_channel(self, channel_id: str) -> AlertChannel | None:
"""获取告警渠道""" """获取告警渠道"""
with self._get_db() as conn: with self._get_db() as conn:
row = conn.execute("SELECT * FROM alert_channels WHERE id = ?", (channel_id,)).fetchone() row = conn.execute(
"SELECT * FROM alert_channels WHERE id = ?", (channel_id,)
).fetchone()
if row: if row:
return self._row_to_alert_channel(row) return self._row_to_alert_channel(row)
@@ -653,7 +691,8 @@ class OpsManager:
"""列出租户的所有告警渠道""" """列出租户的所有告警渠道"""
with self._get_db() as conn: with self._get_db() as conn:
rows = conn.execute( rows = conn.execute(
"SELECT * FROM alert_channels WHERE tenant_id = ? ORDER BY created_at DESC", (tenant_id,) "SELECT * FROM alert_channels WHERE tenant_id = ? ORDER BY created_at DESC",
(tenant_id,),
).fetchall() ).fetchall()
return [self._row_to_alert_channel(row) for row in rows] return [self._row_to_alert_channel(row) for row in rows]
@@ -779,7 +818,9 @@ class OpsManager:
for rule in rules: for rule in rules:
# 获取相关指标 # 获取相关指标
metrics = self.get_recent_metrics(tenant_id, rule.metric, seconds=rule.duration + rule.evaluation_interval) metrics = self.get_recent_metrics(
tenant_id, rule.metric, seconds=rule.duration + rule.evaluation_interval
)
# 评估规则 # 评估规则
evaluator = self._alert_evaluators.get(rule.rule_type.value) evaluator = self._alert_evaluators.get(rule.rule_type.value)
@@ -921,7 +962,10 @@ class OpsManager:
"card": { "card": {
"config": {"wide_screen_mode": True}, "config": {"wide_screen_mode": True},
"header": { "header": {
"title": {"tag": "plain_text", "content": f"🚨 [{alert.severity.value.upper()}] {alert.title}"}, "title": {
"tag": "plain_text",
"content": f"🚨 [{alert.severity.value.upper()}] {alert.title}",
},
"template": severity_colors.get(alert.severity.value, "blue"), "template": severity_colors.get(alert.severity.value, "blue"),
}, },
"elements": [ "elements": [
@@ -932,7 +976,10 @@ class OpsManager:
"content": f"**描述:** {alert.description}\n\n**指标:** {alert.metric}\n**当前值:** {alert.value}\n**阈值:** {alert.threshold}", "content": f"**描述:** {alert.description}\n\n**指标:** {alert.metric}\n**当前值:** {alert.value}\n**阈值:** {alert.threshold}",
}, },
}, },
{"tag": "div", "text": {"tag": "lark_md", "content": f"**时间:** {alert.started_at}"}}, {
"tag": "div",
"text": {"tag": "lark_md", "content": f"**时间:** {alert.started_at}"},
},
], ],
}, },
} }
@@ -999,7 +1046,10 @@ class OpsManager:
"blocks": [ "blocks": [
{ {
"type": "header", "type": "header",
"text": {"type": "plain_text", "text": f"{emoji} [{alert.severity.value.upper()}] {alert.title}"}, "text": {
"type": "plain_text",
"text": f"{emoji} [{alert.severity.value.upper()}] {alert.title}",
},
}, },
{ {
"type": "section", "type": "section",
@@ -1010,7 +1060,10 @@ class OpsManager:
{"type": "mrkdwn", "text": f"*阈值:*\n{alert.threshold}"}, {"type": "mrkdwn", "text": f"*阈值:*\n{alert.threshold}"},
], ],
}, },
{"type": "context", "elements": [{"type": "mrkdwn", "text": f"触发时间: {alert.started_at}"}]}, {
"type": "context",
"elements": [{"type": "mrkdwn", "text": f"触发时间: {alert.started_at}"}],
},
], ],
} }
@@ -1070,7 +1123,9 @@ class OpsManager:
} }
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.post("https://events.pagerduty.com/v2/enqueue", json=message, timeout=30.0) response = await client.post(
"https://events.pagerduty.com/v2/enqueue", json=message, timeout=30.0
)
success = response.status_code == 202 success = response.status_code == 202
self._update_channel_stats(channel.id, success) self._update_channel_stats(channel.id, success)
return success return success
@@ -1095,7 +1150,11 @@ class OpsManager:
"description": alert.description, "description": alert.description,
"priority": priority_map.get(alert.severity.value, "P3"), "priority": priority_map.get(alert.severity.value, "P3"),
"alias": alert.id, "alias": alert.id,
"details": {"metric": alert.metric, "value": str(alert.value), "threshold": str(alert.threshold)}, "details": {
"metric": alert.metric,
"value": str(alert.value),
"threshold": str(alert.threshold),
},
} }
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
@@ -1234,17 +1293,22 @@ class OpsManager:
) )
conn.commit() conn.commit()
def _update_alert_notification_status(self, alert_id: str, channel_id: str, success: bool) -> None: def _update_alert_notification_status(
self, alert_id: str, channel_id: str, success: bool
) -> None:
"""更新告警通知状态""" """更新告警通知状态"""
with self._get_db() as conn: with self._get_db() as conn:
row = conn.execute("SELECT notification_sent FROM alerts WHERE id = ?", (alert_id,)).fetchone() row = conn.execute(
"SELECT notification_sent FROM alerts WHERE id = ?", (alert_id,)
).fetchone()
if row: if row:
notification_sent = json.loads(row["notification_sent"]) notification_sent = json.loads(row["notification_sent"])
notification_sent[channel_id] = success notification_sent[channel_id] = success
conn.execute( conn.execute(
"UPDATE alerts SET notification_sent = ? WHERE id = ?", (json.dumps(notification_sent), alert_id) "UPDATE alerts SET notification_sent = ? WHERE id = ?",
(json.dumps(notification_sent), alert_id),
) )
conn.commit() conn.commit()
@@ -1409,7 +1473,9 @@ class OpsManager:
return metric return metric
def get_recent_metrics(self, tenant_id: str, metric_name: str, seconds: int = 3600) -> list[ResourceMetric]: def get_recent_metrics(
self, tenant_id: str, metric_name: str, seconds: int = 3600
) -> list[ResourceMetric]:
"""获取最近的指标数据""" """获取最近的指标数据"""
cutoff_time = (datetime.now() - timedelta(seconds=seconds)).isoformat() cutoff_time = (datetime.now() - timedelta(seconds=seconds)).isoformat()
@@ -1459,7 +1525,9 @@ class OpsManager:
now = datetime.now().isoformat() now = datetime.now().isoformat()
# 基于历史数据预测 # 基于历史数据预测
metrics = self.get_recent_metrics(tenant_id, f"{resource_type.value}_usage", seconds=30 * 24 * 3600) metrics = self.get_recent_metrics(
tenant_id, f"{resource_type.value}_usage", seconds=30 * 24 * 3600
)
if metrics: if metrics:
values = [m.metric_value for m in metrics] values = [m.metric_value for m in metrics]
@@ -1553,7 +1621,8 @@ class OpsManager:
"""获取容量规划列表""" """获取容量规划列表"""
with self._get_db() as conn: with self._get_db() as conn:
rows = conn.execute( rows = conn.execute(
"SELECT * FROM capacity_plans WHERE tenant_id = ? ORDER BY created_at DESC", (tenant_id,) "SELECT * FROM capacity_plans WHERE tenant_id = ? ORDER BY created_at DESC",
(tenant_id,),
).fetchall() ).fetchall()
return [self._row_to_capacity_plan(row) for row in rows] return [self._row_to_capacity_plan(row) for row in rows]
@@ -1629,7 +1698,9 @@ class OpsManager:
def get_auto_scaling_policy(self, policy_id: str) -> AutoScalingPolicy | None: def get_auto_scaling_policy(self, policy_id: str) -> AutoScalingPolicy | None:
"""获取自动扩缩容策略""" """获取自动扩缩容策略"""
with self._get_db() as conn: with self._get_db() as conn:
row = conn.execute("SELECT * FROM auto_scaling_policies WHERE id = ?", (policy_id,)).fetchone() row = conn.execute(
"SELECT * FROM auto_scaling_policies WHERE id = ?", (policy_id,)
).fetchone()
if row: if row:
return self._row_to_auto_scaling_policy(row) return self._row_to_auto_scaling_policy(row)
@@ -1639,7 +1710,8 @@ class OpsManager:
"""列出租户的自动扩缩容策略""" """列出租户的自动扩缩容策略"""
with self._get_db() as conn: with self._get_db() as conn:
rows = conn.execute( rows = conn.execute(
"SELECT * FROM auto_scaling_policies WHERE tenant_id = ? ORDER BY created_at DESC", (tenant_id,) "SELECT * FROM auto_scaling_policies WHERE tenant_id = ? ORDER BY created_at DESC",
(tenant_id,),
).fetchall() ).fetchall()
return [self._row_to_auto_scaling_policy(row) for row in rows] return [self._row_to_auto_scaling_policy(row) for row in rows]
@@ -1664,7 +1736,9 @@ class OpsManager:
if current_utilization > policy.scale_up_threshold: if current_utilization > policy.scale_up_threshold:
if current_instances < policy.max_instances: if current_instances < policy.max_instances:
action = ScalingAction.SCALE_UP action = ScalingAction.SCALE_UP
reason = f"利用率 {current_utilization:.1%} 超过扩容阈值 {policy.scale_up_threshold:.1%}" reason = (
f"利用率 {current_utilization:.1%} 超过扩容阈值 {policy.scale_up_threshold:.1%}"
)
elif current_utilization < policy.scale_down_threshold: elif current_utilization < policy.scale_down_threshold:
if current_instances > policy.min_instances: if current_instances > policy.min_instances:
action = ScalingAction.SCALE_DOWN action = ScalingAction.SCALE_DOWN
@@ -1681,7 +1755,12 @@ class OpsManager:
return None return None
def _create_scaling_event( def _create_scaling_event(
self, policy: AutoScalingPolicy, action: ScalingAction, from_count: int, to_count: int, reason: str self,
policy: AutoScalingPolicy,
action: ScalingAction,
from_count: int,
to_count: int,
reason: str,
) -> ScalingEvent: ) -> ScalingEvent:
"""创建扩缩容事件""" """创建扩缩容事件"""
event_id = f"se_{uuid.uuid4().hex[:16]}" event_id = f"se_{uuid.uuid4().hex[:16]}"
@@ -1741,7 +1820,9 @@ class OpsManager:
return self._row_to_scaling_event(row) return self._row_to_scaling_event(row)
return None return None
def update_scaling_event_status(self, event_id: str, status: str, error_message: str = None) -> ScalingEvent | None: def update_scaling_event_status(
self, event_id: str, status: str, error_message: str = None
) -> ScalingEvent | None:
"""更新扩缩容事件状态""" """更新扩缩容事件状态"""
now = datetime.now().isoformat() now = datetime.now().isoformat()
@@ -1777,7 +1858,9 @@ class OpsManager:
return self._row_to_scaling_event(row) return self._row_to_scaling_event(row)
return None return None
def list_scaling_events(self, tenant_id: str, policy_id: str = None, limit: int = 100) -> list[ScalingEvent]: def list_scaling_events(
self, tenant_id: str, policy_id: str = None, limit: int = 100
) -> list[ScalingEvent]:
"""列出租户的扩缩容事件""" """列出租户的扩缩容事件"""
query = "SELECT * FROM scaling_events WHERE tenant_id = ?" query = "SELECT * FROM scaling_events WHERE tenant_id = ?"
params = [tenant_id] params = [tenant_id]
@@ -1873,7 +1956,8 @@ class OpsManager:
"""列出租户的健康检查""" """列出租户的健康检查"""
with self._get_db() as conn: with self._get_db() as conn:
rows = conn.execute( rows = conn.execute(
"SELECT * FROM health_checks WHERE tenant_id = ? ORDER BY created_at DESC", (tenant_id,) "SELECT * FROM health_checks WHERE tenant_id = ? ORDER BY created_at DESC",
(tenant_id,),
).fetchall() ).fetchall()
return [self._row_to_health_check(row) for row in rows] return [self._row_to_health_check(row) for row in rows]
@@ -1947,7 +2031,11 @@ class OpsManager:
if response.status_code == expected_status: if response.status_code == expected_status:
return HealthStatus.HEALTHY, response_time, "OK" return HealthStatus.HEALTHY, response_time, "OK"
else: else:
return HealthStatus.DEGRADED, response_time, f"Unexpected status: {response.status_code}" return (
HealthStatus.DEGRADED,
response_time,
f"Unexpected status: {response.status_code}",
)
except Exception as e: except Exception as e:
return HealthStatus.UNHEALTHY, (time.time() - start_time) * 1000, str(e) return HealthStatus.UNHEALTHY, (time.time() - start_time) * 1000, str(e)
@@ -1962,7 +2050,9 @@ class OpsManager:
start_time = time.time() start_time = time.time()
try: try:
reader, writer = await asyncio.wait_for(asyncio.open_connection(host, port), timeout=check.timeout) reader, writer = await asyncio.wait_for(
asyncio.open_connection(host, port), timeout=check.timeout
)
response_time = (time.time() - start_time) * 1000 response_time = (time.time() - start_time) * 1000
writer.close() writer.close()
await writer.wait_closed() await writer.wait_closed()
@@ -2057,7 +2147,9 @@ class OpsManager:
def get_failover_config(self, config_id: str) -> FailoverConfig | None: def get_failover_config(self, config_id: str) -> FailoverConfig | None:
"""获取故障转移配置""" """获取故障转移配置"""
with self._get_db() as conn: with self._get_db() as conn:
row = conn.execute("SELECT * FROM failover_configs WHERE id = ?", (config_id,)).fetchone() row = conn.execute(
"SELECT * FROM failover_configs WHERE id = ?", (config_id,)
).fetchone()
if row: if row:
return self._row_to_failover_config(row) return self._row_to_failover_config(row)
@@ -2067,7 +2159,8 @@ class OpsManager:
"""列出租户的故障转移配置""" """列出租户的故障转移配置"""
with self._get_db() as conn: with self._get_db() as conn:
rows = conn.execute( rows = conn.execute(
"SELECT * FROM failover_configs WHERE tenant_id = ? ORDER BY created_at DESC", (tenant_id,) "SELECT * FROM failover_configs WHERE tenant_id = ? ORDER BY created_at DESC",
(tenant_id,),
).fetchall() ).fetchall()
return [self._row_to_failover_config(row) for row in rows] return [self._row_to_failover_config(row) for row in rows]
@@ -2256,7 +2349,8 @@ class OpsManager:
"""列出租户的备份任务""" """列出租户的备份任务"""
with self._get_db() as conn: with self._get_db() as conn:
rows = conn.execute( rows = conn.execute(
"SELECT * FROM backup_jobs WHERE tenant_id = ? ORDER BY created_at DESC", (tenant_id,) "SELECT * FROM backup_jobs WHERE tenant_id = ? ORDER BY created_at DESC",
(tenant_id,),
).fetchall() ).fetchall()
return [self._row_to_backup_job(row) for row in rows] return [self._row_to_backup_job(row) for row in rows]
@@ -2334,7 +2428,9 @@ class OpsManager:
return self._row_to_backup_record(row) return self._row_to_backup_record(row)
return None return None
def list_backup_records(self, tenant_id: str, job_id: str = None, limit: int = 100) -> list[BackupRecord]: def list_backup_records(
self, tenant_id: str, job_id: str = None, limit: int = 100
) -> list[BackupRecord]:
"""列出租户的备份记录""" """列出租户的备份记录"""
query = "SELECT * FROM backup_records WHERE tenant_id = ?" query = "SELECT * FROM backup_records WHERE tenant_id = ?"
params = [tenant_id] params = [tenant_id]
@@ -2379,7 +2475,9 @@ class OpsManager:
# 简化计算:假设每单位资源每月成本 # 简化计算:假设每单位资源每月成本
unit_cost = 10.0 unit_cost = 10.0
resource_cost = unit_cost * util.utilization_rate resource_cost = unit_cost * util.utilization_rate
breakdown[util.resource_type.value] = breakdown.get(util.resource_type.value, 0) + resource_cost breakdown[util.resource_type.value] = (
breakdown.get(util.resource_type.value, 0) + resource_cost
)
total_cost += resource_cost total_cost += resource_cost
# 检测异常 # 检测异常
@@ -2457,7 +2555,11 @@ class OpsManager:
def _calculate_cost_trends(self, tenant_id: str, year: int, month: int) -> dict: def _calculate_cost_trends(self, tenant_id: str, year: int, month: int) -> dict:
"""计算成本趋势""" """计算成本趋势"""
# 简化实现:返回模拟趋势 # 简化实现:返回模拟趋势
return {"month_over_month": 0.05, "year_over_year": 0.15, "forecast_next_month": 1.05} # 5% 增长 # 15% 增长 return {
"month_over_month": 0.05,
"year_over_year": 0.15,
"forecast_next_month": 1.05,
} # 5% 增长 # 15% 增长
def record_resource_utilization( def record_resource_utilization(
self, self,
@@ -2512,7 +2614,9 @@ class OpsManager:
return util return util
def get_resource_utilizations(self, tenant_id: str, report_period: str) -> list[ResourceUtilization]: def get_resource_utilizations(
self, tenant_id: str, report_period: str
) -> list[ResourceUtilization]:
"""获取资源利用率列表""" """获取资源利用率列表"""
with self._get_db() as conn: with self._get_db() as conn:
rows = conn.execute( rows = conn.execute(
@@ -2590,11 +2694,14 @@ class OpsManager:
"""获取闲置资源列表""" """获取闲置资源列表"""
with self._get_db() as conn: with self._get_db() as conn:
rows = conn.execute( rows = conn.execute(
"SELECT * FROM idle_resources WHERE tenant_id = ? ORDER BY detected_at DESC", (tenant_id,) "SELECT * FROM idle_resources WHERE tenant_id = ? ORDER BY detected_at DESC",
(tenant_id,),
).fetchall() ).fetchall()
return [self._row_to_idle_resource(row) for row in rows] return [self._row_to_idle_resource(row) for row in rows]
def generate_cost_optimization_suggestions(self, tenant_id: str) -> list[CostOptimizationSuggestion]: def generate_cost_optimization_suggestions(
self, tenant_id: str
) -> list[CostOptimizationSuggestion]:
"""生成成本优化建议""" """生成成本优化建议"""
suggestions = [] suggestions = []
@@ -2677,7 +2784,9 @@ class OpsManager:
rows = conn.execute(query, params).fetchall() rows = conn.execute(query, params).fetchall()
return [self._row_to_cost_optimization_suggestion(row) for row in rows] return [self._row_to_cost_optimization_suggestion(row) for row in rows]
def apply_cost_optimization_suggestion(self, suggestion_id: str) -> CostOptimizationSuggestion | None: def apply_cost_optimization_suggestion(
self, suggestion_id: str
) -> CostOptimizationSuggestion | None:
"""应用成本优化建议""" """应用成本优化建议"""
now = datetime.now().isoformat() now = datetime.now().isoformat()
@@ -2694,10 +2803,14 @@ class OpsManager:
return self.get_cost_optimization_suggestion(suggestion_id) return self.get_cost_optimization_suggestion(suggestion_id)
def get_cost_optimization_suggestion(self, suggestion_id: str) -> CostOptimizationSuggestion | None: def get_cost_optimization_suggestion(
self, suggestion_id: str
) -> CostOptimizationSuggestion | None:
"""获取成本优化建议详情""" """获取成本优化建议详情"""
with self._get_db() as conn: with self._get_db() as conn:
row = conn.execute("SELECT * FROM cost_optimization_suggestions WHERE id = ?", (suggestion_id,)).fetchone() row = conn.execute(
"SELECT * FROM cost_optimization_suggestions WHERE id = ?", (suggestion_id,)
).fetchone()
if row: if row:
return self._row_to_cost_optimization_suggestion(row) return self._row_to_cost_optimization_suggestion(row)
@@ -2980,9 +3093,11 @@ class OpsManager:
applied_at=row["applied_at"], applied_at=row["applied_at"],
) )
# Singleton instance # Singleton instance
_ops_manager = None _ops_manager = None
def get_ops_manager() -> OpsManager: def get_ops_manager() -> OpsManager:
global _ops_manager global _ops_manager
if _ops_manager is None: if _ops_manager is None:

View File

@@ -9,6 +9,7 @@ from datetime import datetime
import oss2 import oss2
class OSSUploader: class OSSUploader:
def __init__(self): def __init__(self):
self.access_key = os.getenv("ALI_ACCESS_KEY") self.access_key = os.getenv("ALI_ACCESS_KEY")
@@ -40,9 +41,11 @@ class OSSUploader:
"""删除 OSS 对象""" """删除 OSS 对象"""
self.bucket.delete_object(object_name) self.bucket.delete_object(object_name)
# 单例 # 单例
_oss_uploader = None _oss_uploader = None
def get_oss_uploader() -> OSSUploader: def get_oss_uploader() -> OSSUploader:
global _oss_uploader global _oss_uploader
if _oss_uploader is None: if _oss_uploader is None:

View File

@@ -42,6 +42,7 @@ except ImportError:
# ==================== 数据模型 ==================== # ==================== 数据模型 ====================
@dataclass @dataclass
class CacheStats: class CacheStats:
"""缓存统计数据模型""" """缓存统计数据模型"""
@@ -58,6 +59,7 @@ class CacheStats:
if self.total_requests > 0: if self.total_requests > 0:
self.hit_rate = round(self.hits / self.total_requests, 4) self.hit_rate = round(self.hits / self.total_requests, 4)
@dataclass @dataclass
class CacheEntry: class CacheEntry:
"""缓存条目数据模型""" """缓存条目数据模型"""
@@ -70,6 +72,7 @@ class CacheEntry:
last_accessed: float = 0 last_accessed: float = 0
size_bytes: int = 0 size_bytes: int = 0
@dataclass @dataclass
class PerformanceMetric: class PerformanceMetric:
"""性能指标数据模型""" """性能指标数据模型"""
@@ -91,6 +94,7 @@ class PerformanceMetric:
"metadata": self.metadata, "metadata": self.metadata,
} }
@dataclass @dataclass
class TaskInfo: class TaskInfo:
"""任务信息数据模型""" """任务信息数据模型"""
@@ -122,6 +126,7 @@ class TaskInfo:
"max_retries": self.max_retries, "max_retries": self.max_retries,
} }
@dataclass @dataclass
class ShardInfo: class ShardInfo:
"""分片信息数据模型""" """分片信息数据模型"""
@@ -134,8 +139,10 @@ class ShardInfo:
created_at: str = "" created_at: str = ""
last_accessed: str = "" last_accessed: str = ""
# ==================== Redis 缓存层 ==================== # ==================== Redis 缓存层 ====================
class CacheManager: class CacheManager:
""" """
缓存管理器 缓存管理器
@@ -213,8 +220,12 @@ class CacheManager:
) )
""") """)
conn.execute("CREATE INDEX IF NOT EXISTS idx_metrics_type ON performance_metrics(metric_type)") conn.execute(
conn.execute("CREATE INDEX IF NOT EXISTS idx_metrics_time ON performance_metrics(timestamp)") "CREATE INDEX IF NOT EXISTS idx_metrics_type ON performance_metrics(metric_type)"
)
conn.execute(
"CREATE INDEX IF NOT EXISTS idx_metrics_time ON performance_metrics(timestamp)"
)
conn.commit() conn.commit()
conn.close() conn.close()
@@ -229,7 +240,10 @@ class CacheManager:
def _evict_lru(self, required_space: int = 0) -> None: def _evict_lru(self, required_space: int = 0) -> None:
"""LRU 淘汰策略""" """LRU 淘汰策略"""
with self.cache_lock: with self.cache_lock:
while self.current_memory_size + required_space > self.max_memory_size and self.memory_cache: while (
self.current_memory_size + required_space > self.max_memory_size
and self.memory_cache
):
# 移除最久未访问的 # 移除最久未访问的
oldest_key, oldest_entry = self.memory_cache.popitem(last=False) oldest_key, oldest_entry = self.memory_cache.popitem(last=False)
self.current_memory_size -= oldest_entry.size_bytes self.current_memory_size -= oldest_entry.size_bytes
@@ -429,7 +443,9 @@ class CacheManager:
{ {
"memory_size_bytes": self.current_memory_size, "memory_size_bytes": self.current_memory_size,
"max_memory_size_bytes": self.max_memory_size, "max_memory_size_bytes": self.max_memory_size,
"memory_usage_percent": round(self.current_memory_size / self.max_memory_size * 100, 2), "memory_usage_percent": round(
self.current_memory_size / self.max_memory_size * 100, 2
),
"cache_entries": len(self.memory_cache), "cache_entries": len(self.memory_cache),
} }
) )
@@ -531,7 +547,9 @@ class CacheManager:
stats["transcripts"] += 1 stats["transcripts"] += 1
# 预热项目知识库摘要 # 预热项目知识库摘要
entity_count = conn.execute("SELECT COUNT(*) FROM entities WHERE project_id = ?", (project_id,)).fetchone()[0] entity_count = conn.execute(
"SELECT COUNT(*) FROM entities WHERE project_id = ?", (project_id,)
).fetchone()[0]
relation_count = conn.execute( relation_count = conn.execute(
"SELECT COUNT(*) FROM entity_relations WHERE project_id = ?", (project_id,) "SELECT COUNT(*) FROM entity_relations WHERE project_id = ?", (project_id,)
@@ -581,8 +599,10 @@ class CacheManager:
return count return count
# ==================== 数据库分片 ==================== # ==================== 数据库分片 ====================
class DatabaseSharding: class DatabaseSharding:
""" """
数据库分片管理器 数据库分片管理器
@@ -594,7 +614,12 @@ class DatabaseSharding:
- 分片迁移工具 - 分片迁移工具
""" """
def __init__(self, base_db_path: str = "insightflow.db", shard_db_dir: str = "./shards", shards_count: int = 4): def __init__(
self,
base_db_path: str = "insightflow.db",
shard_db_dir: str = "./shards",
shards_count: int = 4,
):
self.base_db_path = base_db_path self.base_db_path = base_db_path
self.shard_db_dir = shard_db_dir self.shard_db_dir = shard_db_dir
self.shards_count = shards_count self.shards_count = shards_count
@@ -731,7 +756,9 @@ class DatabaseSharding:
source_conn = sqlite3.connect(source_info.db_path) source_conn = sqlite3.connect(source_info.db_path)
source_conn.row_factory = sqlite3.Row source_conn.row_factory = sqlite3.Row
entities = source_conn.execute("SELECT * FROM entities WHERE project_id = ?", (project_id,)).fetchall() entities = source_conn.execute(
"SELECT * FROM entities WHERE project_id = ?", (project_id,)
).fetchall()
relations = source_conn.execute( relations = source_conn.execute(
"SELECT * FROM entity_relations WHERE project_id = ?", (project_id,) "SELECT * FROM entity_relations WHERE project_id = ?", (project_id,)
@@ -875,8 +902,10 @@ class DatabaseSharding:
"message": "Rebalancing analysis completed", "message": "Rebalancing analysis completed",
} }
# ==================== 异步任务队列 ==================== # ==================== 异步任务队列 ====================
class TaskQueue: class TaskQueue:
""" """
异步任务队列管理器 异步任务队列管理器
@@ -1031,7 +1060,9 @@ class TaskQueue:
if task.retry_count <= task.max_retries: if task.retry_count <= task.max_retries:
task.status = "retrying" task.status = "retrying"
# 延迟重试 # 延迟重试
threading.Timer(10 * task.retry_count, self._execute_task, args=(task_id,)).start() threading.Timer(
10 * task.retry_count, self._execute_task, args=(task_id,)
).start()
else: else:
task.status = "failed" task.status = "failed"
task.error_message = str(e) task.error_message = str(e)
@@ -1131,7 +1162,9 @@ class TaskQueue:
with self.task_lock: with self.task_lock:
return self.tasks.get(task_id) return self.tasks.get(task_id)
def list_tasks(self, status: str | None = None, task_type: str | None = None, limit: int = 100) -> list[TaskInfo]: def list_tasks(
self, status: str | None = None, task_type: str | None = None, limit: int = 100
) -> list[TaskInfo]:
"""列出任务""" """列出任务"""
conn = sqlite3.connect(self.db_path) conn = sqlite3.connect(self.db_path)
conn.row_factory = sqlite3.Row conn.row_factory = sqlite3.Row
@@ -1254,8 +1287,10 @@ class TaskQueue:
"backend": "celery" if self.use_celery else "memory", "backend": "celery" if self.use_celery else "memory",
} }
# ==================== 性能监控 ==================== # ==================== 性能监控 ====================
class PerformanceMonitor: class PerformanceMonitor:
""" """
性能监控器 性能监控器
@@ -1268,7 +1303,10 @@ class PerformanceMonitor:
""" """
def __init__( def __init__(
self, db_path: str = "insightflow.db", slow_query_threshold: int = 1000, alert_threshold: int = 5000 # 毫秒 self,
db_path: str = "insightflow.db",
slow_query_threshold: int = 1000,
alert_threshold: int = 5000, # 毫秒
): # 毫秒 ): # 毫秒
self.db_path = db_path self.db_path = db_path
self.slow_query_threshold = slow_query_threshold self.slow_query_threshold = slow_query_threshold
@@ -1283,7 +1321,11 @@ class PerformanceMonitor:
self.alert_handlers: list[Callable] = [] self.alert_handlers: list[Callable] = []
def record_metric( def record_metric(
self, metric_type: str, duration_ms: float, endpoint: str | None = None, metadata: dict | None = None self,
metric_type: str,
duration_ms: float,
endpoint: str | None = None,
metadata: dict | None = None,
): ):
""" """
记录性能指标 记录性能指标
@@ -1565,10 +1607,15 @@ class PerformanceMonitor:
return deleted return deleted
# ==================== 性能装饰器 ==================== # ==================== 性能装饰器 ====================
def cached( def cached(
cache_manager: CacheManager, key_prefix: str = "", ttl: int = 3600, key_func: Callable | None = None cache_manager: CacheManager,
key_prefix: str = "",
ttl: int = 3600,
key_func: Callable | None = None,
) -> None: ) -> None:
""" """
缓存装饰器 缓存装饰器
@@ -1608,6 +1655,7 @@ def cached(
return decorator return decorator
def monitored(monitor: PerformanceMonitor, metric_type: str, endpoint: str | None = None) -> None: def monitored(monitor: PerformanceMonitor, metric_type: str, endpoint: str | None = None) -> None:
""" """
性能监控装饰器 性能监控装饰器
@@ -1635,8 +1683,10 @@ def monitored(monitor: PerformanceMonitor, metric_type: str, endpoint: str | Non
return decorator return decorator
# ==================== 性能管理器 ==================== # ==================== 性能管理器 ====================
class PerformanceManager: class PerformanceManager:
""" """
性能管理器 - 统一入口 性能管理器 - 统一入口
@@ -1644,7 +1694,12 @@ class PerformanceManager:
整合缓存管理、数据库分片、任务队列和性能监控功能 整合缓存管理、数据库分片、任务队列和性能监控功能
""" """
def __init__(self, db_path: str = "insightflow.db", redis_url: str | None = None, enable_sharding: bool = False): def __init__(
self,
db_path: str = "insightflow.db",
redis_url: str | None = None,
enable_sharding: bool = False,
):
self.db_path = db_path self.db_path = db_path
# 初始化各模块 # 初始化各模块
@@ -1693,14 +1748,18 @@ class PerformanceManager:
return stats return stats
# 单例模式 # 单例模式
_performance_manager = None _performance_manager = None
def get_performance_manager( def get_performance_manager(
db_path: str = "insightflow.db", redis_url: str | None = None, enable_sharding: bool = False db_path: str = "insightflow.db", redis_url: str | None = None, enable_sharding: bool = False
) -> PerformanceManager: ) -> PerformanceManager:
"""获取性能管理器单例""" """获取性能管理器单例"""
global _performance_manager global _performance_manager
if _performance_manager is None: if _performance_manager is None:
_performance_manager = PerformanceManager(db_path=db_path, redis_url=redis_url, enable_sharding=enable_sharding) _performance_manager = PerformanceManager(
db_path=db_path, redis_url=redis_url, enable_sharding=enable_sharding
)
return _performance_manager return _performance_manager

View File

@@ -11,6 +11,7 @@ import json
import os import os
import sqlite3 import sqlite3
import time import time
import urllib.parse
import uuid import uuid
from dataclasses import dataclass, field from dataclasses import dataclass, field
from datetime import datetime from datetime import datetime
@@ -27,6 +28,7 @@ try:
except ImportError: except ImportError:
WEBDAV_AVAILABLE = False WEBDAV_AVAILABLE = False
class PluginType(Enum): class PluginType(Enum):
"""插件类型""" """插件类型"""
@@ -38,6 +40,7 @@ class PluginType(Enum):
WEBDAV = "webdav" WEBDAV = "webdav"
CUSTOM = "custom" CUSTOM = "custom"
class PluginStatus(Enum): class PluginStatus(Enum):
"""插件状态""" """插件状态"""
@@ -46,6 +49,7 @@ class PluginStatus(Enum):
ERROR = "error" ERROR = "error"
PENDING = "pending" PENDING = "pending"
@dataclass @dataclass
class Plugin: class Plugin:
"""插件配置""" """插件配置"""
@@ -61,6 +65,7 @@ class Plugin:
last_used_at: str | None = None last_used_at: str | None = None
use_count: int = 0 use_count: int = 0
@dataclass @dataclass
class PluginConfig: class PluginConfig:
"""插件详细配置""" """插件详细配置"""
@@ -73,6 +78,7 @@ class PluginConfig:
created_at: str = "" created_at: str = ""
updated_at: str = "" updated_at: str = ""
@dataclass @dataclass
class BotSession: class BotSession:
"""机器人会话""" """机器人会话"""
@@ -90,6 +96,7 @@ class BotSession:
last_message_at: str | None = None last_message_at: str | None = None
message_count: int = 0 message_count: int = 0
@dataclass @dataclass
class WebhookEndpoint: class WebhookEndpoint:
"""Webhook 端点配置Zapier/Make集成""" """Webhook 端点配置Zapier/Make集成"""
@@ -108,6 +115,7 @@ class WebhookEndpoint:
last_triggered_at: str | None = None last_triggered_at: str | None = None
trigger_count: int = 0 trigger_count: int = 0
@dataclass @dataclass
class WebDAVSync: class WebDAVSync:
"""WebDAV 同步配置""" """WebDAV 同步配置"""
@@ -129,6 +137,7 @@ class WebDAVSync:
updated_at: str = "" updated_at: str = ""
sync_count: int = 0 sync_count: int = 0
@dataclass @dataclass
class ChromeExtensionToken: class ChromeExtensionToken:
"""Chrome 扩展令牌""" """Chrome 扩展令牌"""
@@ -145,6 +154,7 @@ class ChromeExtensionToken:
use_count: int = 0 use_count: int = 0
is_revoked: bool = False is_revoked: bool = False
class PluginManager: class PluginManager:
"""插件管理主类""" """插件管理主类"""
@@ -206,7 +216,9 @@ class PluginManager:
return self._row_to_plugin(row) return self._row_to_plugin(row)
return None return None
def list_plugins(self, project_id: str = None, plugin_type: str = None, status: str = None) -> list[Plugin]: def list_plugins(
self, project_id: str = None, plugin_type: str = None, status: str = None
) -> list[Plugin]:
"""列出插件""" """列出插件"""
conn = self.db.get_conn() conn = self.db.get_conn()
@@ -225,7 +237,9 @@ class PluginManager:
where_clause = " AND ".join(conditions) if conditions else "1=1" where_clause = " AND ".join(conditions) if conditions else "1=1"
rows = conn.execute(f"SELECT * FROM plugins WHERE {where_clause} ORDER BY created_at DESC", params).fetchall() rows = conn.execute(
f"SELECT * FROM plugins WHERE {where_clause} ORDER BY created_at DESC", params
).fetchall()
conn.close() conn.close()
return [self._row_to_plugin(row) for row in rows] return [self._row_to_plugin(row) for row in rows]
@@ -292,7 +306,9 @@ class PluginManager:
# ==================== Plugin Config ==================== # ==================== Plugin Config ====================
def set_plugin_config(self, plugin_id: str, key: str, value: str, is_encrypted: bool = False) -> PluginConfig: def set_plugin_config(
self, plugin_id: str, key: str, value: str, is_encrypted: bool = False
) -> PluginConfig:
"""设置插件配置""" """设置插件配置"""
conn = self.db.get_conn() conn = self.db.get_conn()
now = datetime.now().isoformat() now = datetime.now().isoformat()
@@ -336,7 +352,8 @@ class PluginManager:
"""获取插件配置""" """获取插件配置"""
conn = self.db.get_conn() conn = self.db.get_conn()
row = conn.execute( row = conn.execute(
"SELECT config_value FROM plugin_configs WHERE plugin_id = ? AND config_key = ?", (plugin_id, key) "SELECT config_value FROM plugin_configs WHERE plugin_id = ? AND config_key = ?",
(plugin_id, key),
).fetchone() ).fetchone()
conn.close() conn.close()
@@ -355,7 +372,9 @@ class PluginManager:
def delete_plugin_config(self, plugin_id: str, key: str) -> bool: def delete_plugin_config(self, plugin_id: str, key: str) -> bool:
"""删除插件配置""" """删除插件配置"""
conn = self.db.get_conn() conn = self.db.get_conn()
cursor = conn.execute("DELETE FROM plugin_configs WHERE plugin_id = ? AND config_key = ?", (plugin_id, key)) cursor = conn.execute(
"DELETE FROM plugin_configs WHERE plugin_id = ? AND config_key = ?", (plugin_id, key)
)
conn.commit() conn.commit()
conn.close() conn.close()
@@ -375,6 +394,7 @@ class PluginManager:
conn.commit() conn.commit()
conn.close() conn.close()
class ChromeExtensionHandler: class ChromeExtensionHandler:
"""Chrome 扩展处理器""" """Chrome 扩展处理器"""
@@ -485,13 +505,17 @@ class ChromeExtensionHandler:
def revoke_token(self, token_id: str) -> bool: def revoke_token(self, token_id: str) -> bool:
"""撤销令牌""" """撤销令牌"""
conn = self.pm.db.get_conn() conn = self.pm.db.get_conn()
cursor = conn.execute("UPDATE chrome_extension_tokens SET is_revoked = 1 WHERE id = ?", (token_id,)) cursor = conn.execute(
"UPDATE chrome_extension_tokens SET is_revoked = 1 WHERE id = ?", (token_id,)
)
conn.commit() conn.commit()
conn.close() conn.close()
return cursor.rowcount > 0 return cursor.rowcount > 0
def list_tokens(self, user_id: str = None, project_id: str = None) -> list[ChromeExtensionToken]: def list_tokens(
self, user_id: str = None, project_id: str = None
) -> list[ChromeExtensionToken]:
"""列出令牌""" """列出令牌"""
conn = self.pm.db.get_conn() conn = self.pm.db.get_conn()
@@ -508,7 +532,8 @@ class ChromeExtensionHandler:
where_clause = " AND ".join(conditions) where_clause = " AND ".join(conditions)
rows = conn.execute( rows = conn.execute(
f"SELECT * FROM chrome_extension_tokens WHERE {where_clause} ORDER BY created_at DESC", params f"SELECT * FROM chrome_extension_tokens WHERE {where_clause} ORDER BY created_at DESC",
params,
).fetchall() ).fetchall()
conn.close() conn.close()
@@ -533,7 +558,12 @@ class ChromeExtensionHandler:
return tokens return tokens
async def import_webpage( async def import_webpage(
self, token: ChromeExtensionToken, url: str, title: str, content: str, html_content: str = None self,
token: ChromeExtensionToken,
url: str,
title: str,
content: str,
html_content: str = None,
) -> dict: ) -> dict:
"""导入网页内容""" """导入网页内容"""
if not token.project_id: if not token.project_id:
@@ -568,6 +598,7 @@ class ChromeExtensionHandler:
"content_length": len(content), "content_length": len(content),
} }
class BotHandler: class BotHandler:
"""飞书/钉钉机器人处理器""" """飞书/钉钉机器人处理器"""
@@ -576,7 +607,12 @@ class BotHandler:
self.bot_type = bot_type self.bot_type = bot_type
def create_session( def create_session(
self, session_id: str, session_name: str, project_id: str = None, webhook_url: str = "", secret: str = "" self,
session_id: str,
session_name: str,
project_id: str = None,
webhook_url: str = "",
secret: str = "",
) -> BotSession: ) -> BotSession:
"""创建机器人会话""" """创建机器人会话"""
bot_id = str(uuid.uuid4())[:8] bot_id = str(uuid.uuid4())[:8]
@@ -588,7 +624,19 @@ class BotHandler:
(id, bot_type, session_id, session_name, project_id, webhook_url, secret, (id, bot_type, session_id, session_name, project_id, webhook_url, secret,
is_active, created_at, updated_at, message_count) is_active, created_at, updated_at, message_count)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
(bot_id, self.bot_type, session_id, session_name, project_id, webhook_url, secret, True, now, now, 0), (
bot_id,
self.bot_type,
session_id,
session_name,
project_id,
webhook_url,
secret,
True,
now,
now,
0,
),
) )
conn.commit() conn.commit()
conn.close() conn.close()
@@ -663,7 +711,9 @@ class BotHandler:
values.append(session_id) values.append(session_id)
values.append(self.bot_type) values.append(self.bot_type)
query = f"UPDATE bot_sessions SET {', '.join(updates)} WHERE session_id = ? AND bot_type = ?" query = (
f"UPDATE bot_sessions SET {', '.join(updates)} WHERE session_id = ? AND bot_type = ?"
)
conn.execute(query, values) conn.execute(query, values)
conn.commit() conn.commit()
conn.close() conn.close()
@@ -674,7 +724,8 @@ class BotHandler:
"""删除会话""" """删除会话"""
conn = self.pm.db.get_conn() conn = self.pm.db.get_conn()
cursor = conn.execute( cursor = conn.execute(
"DELETE FROM bot_sessions WHERE session_id = ? AND bot_type = ?", (session_id, self.bot_type) "DELETE FROM bot_sessions WHERE session_id = ? AND bot_type = ?",
(session_id, self.bot_type),
) )
conn.commit() conn.commit()
conn.close() conn.close()
@@ -753,13 +804,16 @@ class BotHandler:
return { return {
"success": True, "success": True,
"response": f"""📊 项目状态: "response": f"""📊 项目状态:
实体数量: {stats.get('entity_count', 0)} 实体数量: {stats.get("entity_count", 0)}
关系数量: {stats.get('relation_count', 0)} 关系数量: {stats.get("relation_count", 0)}
转录数量: {stats.get('transcript_count', 0)}""", 转录数量: {stats.get("transcript_count", 0)}""",
} }
# 默认回复 # 默认回复
return {"success": True, "response": f"收到消息:{text[:100]}...\n\n使用 /help 查看可用命令"} return {
"success": True,
"response": f"收到消息:{text[:100]}...\n\n使用 /help 查看可用命令",
}
async def _handle_audio_message(self, session: BotSession, message: dict) -> dict: async def _handle_audio_message(self, session: BotSession, message: dict) -> dict:
"""处理音频消息""" """处理音频消息"""
@@ -820,13 +874,20 @@ class BotHandler:
if session.secret: if session.secret:
string_to_sign = f"{timestamp}\n{session.secret}" string_to_sign = f"{timestamp}\n{session.secret}"
hmac_code = hmac.new( hmac_code = hmac.new(
session.secret.encode("utf-8"), string_to_sign.encode("utf-8"), digestmod=hashlib.sha256 session.secret.encode("utf-8"),
string_to_sign.encode("utf-8"),
digestmod=hashlib.sha256,
).digest() ).digest()
sign = base64.b64encode(hmac_code).decode("utf-8") sign = base64.b64encode(hmac_code).decode("utf-8")
else: else:
sign = "" sign = ""
payload = {"timestamp": timestamp, "sign": sign, "msg_type": "text", "content": {"text": message}} payload = {
"timestamp": timestamp,
"sign": sign,
"msg_type": "text",
"content": {"text": message},
}
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.post( response = await client.post(
@@ -834,7 +895,9 @@ class BotHandler:
) )
return response.status_code == 200 return response.status_code == 200
async def _send_dingtalk_message(self, session: BotSession, message: str, msg_type: str) -> bool: async def _send_dingtalk_message(
self, session: BotSession, message: str, msg_type: str
) -> bool:
"""发送钉钉消息""" """发送钉钉消息"""
timestamp = str(round(time.time() * 1000)) timestamp = str(round(time.time() * 1000))
@@ -842,7 +905,9 @@ class BotHandler:
if session.secret: if session.secret:
string_to_sign = f"{timestamp}\n{session.secret}" string_to_sign = f"{timestamp}\n{session.secret}"
hmac_code = hmac.new( hmac_code = hmac.new(
session.secret.encode("utf-8"), string_to_sign.encode("utf-8"), digestmod=hashlib.sha256 session.secret.encode("utf-8"),
string_to_sign.encode("utf-8"),
digestmod=hashlib.sha256,
).digest() ).digest()
sign = base64.b64encode(hmac_code).decode("utf-8") sign = base64.b64encode(hmac_code).decode("utf-8")
sign = urllib.parse.quote(sign) sign = urllib.parse.quote(sign)
@@ -856,9 +921,12 @@ class BotHandler:
url = f"{url}&timestamp={timestamp}&sign={sign}" url = f"{url}&timestamp={timestamp}&sign={sign}"
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.post(url, json=payload, headers={"Content-Type": "application/json"}) response = await client.post(
url, json=payload, headers={"Content-Type": "application/json"}
)
return response.status_code == 200 return response.status_code == 200
class WebhookIntegration: class WebhookIntegration:
"""Zapier/Make Webhook 集成""" """Zapier/Make Webhook 集成"""
@@ -921,7 +989,8 @@ class WebhookIntegration:
"""获取端点""" """获取端点"""
conn = self.pm.db.get_conn() conn = self.pm.db.get_conn()
row = conn.execute( row = conn.execute(
"SELECT * FROM webhook_endpoints WHERE id = ? AND endpoint_type = ?", (endpoint_id, self.endpoint_type) "SELECT * FROM webhook_endpoints WHERE id = ? AND endpoint_type = ?",
(endpoint_id, self.endpoint_type),
).fetchone() ).fetchone()
conn.close() conn.close()
@@ -1039,7 +1108,9 @@ class WebhookIntegration:
payload = {"event": event_type, "timestamp": datetime.now().isoformat(), "data": data} payload = {"event": event_type, "timestamp": datetime.now().isoformat(), "data": data}
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.post(endpoint.endpoint_url, json=payload, headers=headers, timeout=30.0) response = await client.post(
endpoint.endpoint_url, json=payload, headers=headers, timeout=30.0
)
success = response.status_code in [200, 201, 202] success = response.status_code in [200, 201, 202]
@@ -1078,6 +1149,7 @@ class WebhookIntegration:
"message": "Test event sent successfully" if success else "Failed to send test event", "message": "Test event sent successfully" if success else "Failed to send test event",
} }
class WebDAVSyncManager: class WebDAVSyncManager:
"""WebDAV 同步管理""" """WebDAV 同步管理"""
@@ -1157,7 +1229,8 @@ class WebDAVSyncManager:
if project_id: if project_id:
rows = conn.execute( rows = conn.execute(
"SELECT * FROM webdav_syncs WHERE project_id = ? ORDER BY created_at DESC", (project_id,) "SELECT * FROM webdav_syncs WHERE project_id = ? ORDER BY created_at DESC",
(project_id,),
).fetchall() ).fetchall()
else: else:
rows = conn.execute("SELECT * FROM webdav_syncs ORDER BY created_at DESC").fetchall() rows = conn.execute("SELECT * FROM webdav_syncs ORDER BY created_at DESC").fetchall()
@@ -1278,7 +1351,11 @@ class WebDAVSyncManager:
transcripts = self.pm.db.list_project_transcripts(sync.project_id) transcripts = self.pm.db.list_project_transcripts(sync.project_id)
export_data = { export_data = {
"project": {"id": project.id, "name": project.name, "description": project.description}, "project": {
"id": project.id,
"name": project.name,
"description": project.description,
},
"entities": [{"id": e.id, "name": e.name, "type": e.type} for e in entities], "entities": [{"id": e.id, "name": e.name, "type": e.type} for e in entities],
"relations": relations, "relations": relations,
"transcripts": [{"id": t["id"], "filename": t["filename"]} for t in transcripts], "transcripts": [{"id": t["id"], "filename": t["filename"]} for t in transcripts],
@@ -1333,9 +1410,11 @@ class WebDAVSyncManager:
return {"success": False, "error": str(e)} return {"success": False, "error": str(e)}
# Singleton instance # Singleton instance
_plugin_manager = None _plugin_manager = None
def get_plugin_manager(db_manager=None) -> None: def get_plugin_manager(db_manager=None) -> None:
"""获取 PluginManager 单例""" """获取 PluginManager 单例"""
global _plugin_manager global _plugin_manager

View File

@@ -12,6 +12,7 @@ from collections.abc import Callable
from dataclasses import dataclass from dataclasses import dataclass
from functools import wraps from functools import wraps
@dataclass @dataclass
class RateLimitConfig: class RateLimitConfig:
"""限流配置""" """限流配置"""
@@ -20,6 +21,7 @@ class RateLimitConfig:
burst_size: int = 10 # 突发请求数 burst_size: int = 10 # 突发请求数
window_size: int = 60 # 窗口大小(秒) window_size: int = 60 # 窗口大小(秒)
@dataclass @dataclass
class RateLimitInfo: class RateLimitInfo:
"""限流信息""" """限流信息"""
@@ -29,6 +31,7 @@ class RateLimitInfo:
reset_time: int # 重置时间戳 reset_time: int # 重置时间戳
retry_after: int # 需要等待的秒数 retry_after: int # 需要等待的秒数
class SlidingWindowCounter: class SlidingWindowCounter:
"""滑动窗口计数器""" """滑动窗口计数器"""
@@ -60,6 +63,7 @@ class SlidingWindowCounter:
for k in old_keys: for k in old_keys:
self.requests.pop(k, None) self.requests.pop(k, None)
class RateLimiter: class RateLimiter:
"""API 限流器""" """API 限流器"""
@@ -106,13 +110,18 @@ class RateLimiter:
# 检查是否超过限制 # 检查是否超过限制
if current_count >= stored_config.requests_per_minute: if current_count >= stored_config.requests_per_minute:
return RateLimitInfo( return RateLimitInfo(
allowed=False, remaining=0, reset_time=reset_time, retry_after=stored_config.window_size allowed=False,
remaining=0,
reset_time=reset_time,
retry_after=stored_config.window_size,
) )
# 允许请求,增加计数 # 允许请求,增加计数
await counter.add_request() await counter.add_request()
return RateLimitInfo(allowed=True, remaining=remaining - 1, reset_time=reset_time, retry_after=0) return RateLimitInfo(
allowed=True, remaining=remaining - 1, reset_time=reset_time, retry_after=0
)
async def get_limit_info(self, key: str) -> RateLimitInfo: async def get_limit_info(self, key: str) -> RateLimitInfo:
"""获取限流信息(不增加计数)""" """获取限流信息(不增加计数)"""
@@ -136,7 +145,9 @@ class RateLimiter:
allowed=current_count < config.requests_per_minute, allowed=current_count < config.requests_per_minute,
remaining=remaining, remaining=remaining,
reset_time=reset_time, reset_time=reset_time,
retry_after=max(0, config.window_size) if current_count >= config.requests_per_minute else 0, retry_after=max(0, config.window_size)
if current_count >= config.requests_per_minute
else 0,
) )
def reset(self, key: str | None = None) -> None: def reset(self, key: str | None = None) -> None:
@@ -148,9 +159,11 @@ class RateLimiter:
self.counters.clear() self.counters.clear()
self.configs.clear() self.configs.clear()
# 全局限流器实例 # 全局限流器实例
_rate_limiter: RateLimiter | None = None _rate_limiter: RateLimiter | None = None
def get_rate_limiter() -> RateLimiter: def get_rate_limiter() -> RateLimiter:
"""获取限流器实例""" """获取限流器实例"""
global _rate_limiter global _rate_limiter
@@ -158,6 +171,7 @@ def get_rate_limiter() -> RateLimiter:
_rate_limiter = RateLimiter() _rate_limiter = RateLimiter()
return _rate_limiter return _rate_limiter
# 限流装饰器(用于函数级别限流) # 限流装饰器(用于函数级别限流)
def rate_limit(requests_per_minute: int = 60, key_func: Callable | None = None) -> None: def rate_limit(requests_per_minute: int = 60, key_func: Callable | None = None) -> None:
""" """
@@ -178,7 +192,9 @@ def rate_limit(requests_per_minute: int = 60, key_func: Callable | None = None)
info = await limiter.is_allowed(key, config) info = await limiter.is_allowed(key, config)
if not info.allowed: if not info.allowed:
raise RateLimitExceeded(f"Rate limit exceeded. Try again in {info.retry_after} seconds.") raise RateLimitExceeded(
f"Rate limit exceeded. Try again in {info.retry_after} seconds."
)
return await func(*args, **kwargs) return await func(*args, **kwargs)
@@ -189,7 +205,9 @@ def rate_limit(requests_per_minute: int = 60, key_func: Callable | None = None)
info = asyncio.run(limiter.is_allowed(key, config)) info = asyncio.run(limiter.is_allowed(key, config))
if not info.allowed: if not info.allowed:
raise RateLimitExceeded(f"Rate limit exceeded. Try again in {info.retry_after} seconds.") raise RateLimitExceeded(
f"Rate limit exceeded. Try again in {info.retry_after} seconds."
)
return func(*args, **kwargs) return func(*args, **kwargs)
@@ -197,5 +215,6 @@ def rate_limit(requests_per_minute: int = 60, key_func: Callable | None = None)
return decorator return decorator
class RateLimitExceeded(Exception): class RateLimitExceeded(Exception):
"""限流异常""" """限流异常"""

View File

@@ -19,6 +19,7 @@ from dataclasses import dataclass, field
from datetime import datetime from datetime import datetime
from enum import Enum from enum import Enum
class SearchOperator(Enum): class SearchOperator(Enum):
"""搜索操作符""" """搜索操作符"""
@@ -26,6 +27,7 @@ class SearchOperator(Enum):
OR = "OR" OR = "OR"
NOT = "NOT" NOT = "NOT"
# 尝试导入 sentence-transformers 用于语义搜索 # 尝试导入 sentence-transformers 用于语义搜索
try: try:
from sentence_transformers import SentenceTransformer from sentence_transformers import SentenceTransformer
@@ -37,6 +39,7 @@ except ImportError:
# ==================== 数据模型 ==================== # ==================== 数据模型 ====================
@dataclass @dataclass
class SearchResult: class SearchResult:
"""搜索结果数据模型""" """搜索结果数据模型"""
@@ -60,6 +63,7 @@ class SearchResult:
"metadata": self.metadata, "metadata": self.metadata,
} }
@dataclass @dataclass
class SemanticSearchResult: class SemanticSearchResult:
"""语义搜索结果数据模型""" """语义搜索结果数据模型"""
@@ -85,6 +89,7 @@ class SemanticSearchResult:
result["embedding_dim"] = len(self.embedding) result["embedding_dim"] = len(self.embedding)
return result return result
@dataclass @dataclass
class EntityPath: class EntityPath:
"""实体关系路径数据模型""" """实体关系路径数据模型"""
@@ -114,6 +119,7 @@ class EntityPath:
"path_description": self.path_description, "path_description": self.path_description,
} }
@dataclass @dataclass
class KnowledgeGap: class KnowledgeGap:
"""知识缺口数据模型""" """知识缺口数据模型"""
@@ -141,6 +147,7 @@ class KnowledgeGap:
"metadata": self.metadata, "metadata": self.metadata,
} }
@dataclass @dataclass
class SearchIndex: class SearchIndex:
"""搜索索引数据模型""" """搜索索引数据模型"""
@@ -154,6 +161,7 @@ class SearchIndex:
created_at: str created_at: str
updated_at: str updated_at: str
@dataclass @dataclass
class TextEmbedding: class TextEmbedding:
"""文本 Embedding 数据模型""" """文本 Embedding 数据模型"""
@@ -166,8 +174,10 @@ class TextEmbedding:
model_name: str model_name: str
created_at: str created_at: str
# ==================== 全文搜索 ==================== # ==================== 全文搜索 ====================
class FullTextSearch: class FullTextSearch:
""" """
全文搜索模块 全文搜索模块
@@ -222,10 +232,14 @@ class FullTextSearch:
""") """)
# 创建索引 # 创建索引
conn.execute("CREATE INDEX IF NOT EXISTS idx_search_content ON search_indexes(content_id, content_type)") conn.execute(
"CREATE INDEX IF NOT EXISTS idx_search_content ON search_indexes(content_id, content_type)"
)
conn.execute("CREATE INDEX IF NOT EXISTS idx_search_project ON search_indexes(project_id)") conn.execute("CREATE INDEX IF NOT EXISTS idx_search_project ON search_indexes(project_id)")
conn.execute("CREATE INDEX IF NOT EXISTS idx_term_freq_term ON search_term_freq(term)") conn.execute("CREATE INDEX IF NOT EXISTS idx_term_freq_term ON search_term_freq(term)")
conn.execute("CREATE INDEX IF NOT EXISTS idx_term_freq_project ON search_term_freq(project_id)") conn.execute(
"CREATE INDEX IF NOT EXISTS idx_term_freq_project ON search_term_freq(project_id)"
)
conn.commit() conn.commit()
conn.close() conn.close()
@@ -320,7 +334,14 @@ class FullTextSearch:
(term, content_id, content_type, project_id, frequency, positions) (term, content_id, content_type, project_id, frequency, positions)
VALUES (?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?)
""", """,
(token, content_id, content_type, project_id, freq, json.dumps(positions, ensure_ascii=False)), (
token,
content_id,
content_type,
project_id,
freq,
json.dumps(positions, ensure_ascii=False),
),
) )
conn.commit() conn.commit()
@@ -364,7 +385,7 @@ class FullTextSearch:
# 排序和分页 # 排序和分页
scored_results.sort(key=lambda x: x.score, reverse=True) scored_results.sort(key=lambda x: x.score, reverse=True)
return scored_results[offset: offset + limit] return scored_results[offset : offset + limit]
def _parse_boolean_query(self, query: str) -> dict: def _parse_boolean_query(self, query: str) -> dict:
""" """
@@ -405,7 +426,10 @@ class FullTextSearch:
return {"and": and_terms + phrases, "or": or_terms, "not": not_terms, "phrases": phrases} return {"and": and_terms + phrases, "or": or_terms, "not": not_terms, "phrases": phrases}
def _execute_boolean_search( def _execute_boolean_search(
self, parsed_query: dict, project_id: str | None = None, content_types: list[str] | None = None self,
parsed_query: dict,
project_id: str | None = None,
content_types: list[str] | None = None,
) -> list[dict]: ) -> list[dict]:
"""执行布尔搜索""" """执行布尔搜索"""
conn = self._get_conn() conn = self._get_conn()
@@ -510,7 +534,8 @@ class FullTextSearch:
{ {
"id": content_id, "id": content_id,
"content_type": content_type, "content_type": content_type,
"project_id": project_id or self._get_project_id(conn, content_id, content_type), "project_id": project_id
or self._get_project_id(conn, content_id, content_type),
"content": content, "content": content,
"terms": parsed_query["and"] + parsed_query["or"] + parsed_query["phrases"], "terms": parsed_query["and"] + parsed_query["or"] + parsed_query["phrases"],
} }
@@ -519,15 +544,21 @@ class FullTextSearch:
conn.close() conn.close()
return results return results
def _get_content_by_id(self, conn: sqlite3.Connection, content_id: str, content_type: str) -> str | None: def _get_content_by_id(
self, conn: sqlite3.Connection, content_id: str, content_type: str
) -> str | None:
"""根据ID获取内容""" """根据ID获取内容"""
try: try:
if content_type == "transcript": if content_type == "transcript":
row = conn.execute("SELECT full_text FROM transcripts WHERE id = ?", (content_id,)).fetchone() row = conn.execute(
"SELECT full_text FROM transcripts WHERE id = ?", (content_id,)
).fetchone()
return row["full_text"] if row else None return row["full_text"] if row else None
elif content_type == "entity": elif content_type == "entity":
row = conn.execute("SELECT name, definition FROM entities WHERE id = ?", (content_id,)).fetchone() row = conn.execute(
"SELECT name, definition FROM entities WHERE id = ?", (content_id,)
).fetchone()
if row: if row:
return f"{row['name']} {row['definition'] or ''}" return f"{row['name']} {row['definition'] or ''}"
return None return None
@@ -551,15 +582,23 @@ class FullTextSearch:
print(f"获取内容失败: {e}") print(f"获取内容失败: {e}")
return None return None
def _get_project_id(self, conn: sqlite3.Connection, content_id: str, content_type: str) -> str | None: def _get_project_id(
self, conn: sqlite3.Connection, content_id: str, content_type: str
) -> str | None:
"""获取内容所属的项目ID""" """获取内容所属的项目ID"""
try: try:
if content_type == "transcript": if content_type == "transcript":
row = conn.execute("SELECT project_id FROM transcripts WHERE id = ?", (content_id,)).fetchone() row = conn.execute(
"SELECT project_id FROM transcripts WHERE id = ?", (content_id,)
).fetchone()
elif content_type == "entity": elif content_type == "entity":
row = conn.execute("SELECT project_id FROM entities WHERE id = ?", (content_id,)).fetchone() row = conn.execute(
"SELECT project_id FROM entities WHERE id = ?", (content_id,)
).fetchone()
elif content_type == "relation": elif content_type == "relation":
row = conn.execute("SELECT project_id FROM entity_relations WHERE id = ?", (content_id,)).fetchone() row = conn.execute(
"SELECT project_id FROM entity_relations WHERE id = ?", (content_id,)
).fetchone()
else: else:
return None return None
@@ -673,12 +712,14 @@ class FullTextSearch:
# 删除索引 # 删除索引
conn.execute( conn.execute(
"DELETE FROM search_indexes WHERE content_id = ? AND content_type = ?", (content_id, content_type) "DELETE FROM search_indexes WHERE content_id = ? AND content_type = ?",
(content_id, content_type),
) )
# 删除词频统计 # 删除词频统计
conn.execute( conn.execute(
"DELETE FROM search_term_freq WHERE content_id = ? AND content_type = ?", (content_id, content_type) "DELETE FROM search_term_freq WHERE content_id = ? AND content_type = ?",
(content_id, content_type),
) )
conn.commit() conn.commit()
@@ -696,7 +737,8 @@ class FullTextSearch:
try: try:
# 索引转录文本 # 索引转录文本
transcripts = conn.execute( transcripts = conn.execute(
"SELECT id, project_id, full_text FROM transcripts WHERE project_id = ?", (project_id,) "SELECT id, project_id, full_text FROM transcripts WHERE project_id = ?",
(project_id,),
).fetchall() ).fetchall()
for t in transcripts: for t in transcripts:
@@ -708,7 +750,8 @@ class FullTextSearch:
# 索引实体 # 索引实体
entities = conn.execute( entities = conn.execute(
"SELECT id, project_id, name, definition FROM entities WHERE project_id = ?", (project_id,) "SELECT id, project_id, name, definition FROM entities WHERE project_id = ?",
(project_id,),
).fetchall() ).fetchall()
for e in entities: for e in entities:
@@ -743,8 +786,10 @@ class FullTextSearch:
conn.close() conn.close()
return stats return stats
# ==================== 语义搜索 ==================== # ==================== 语义搜索 ====================
class SemanticSearch: class SemanticSearch:
""" """
语义搜索模块 语义搜索模块
@@ -756,7 +801,11 @@ class SemanticSearch:
- 语义相似内容推荐 - 语义相似内容推荐
""" """
def __init__(self, db_path: str = "insightflow.db", model_name: str = "paraphrase-multilingual-MiniLM-L12-v2"): def __init__(
self,
db_path: str = "insightflow.db",
model_name: str = "paraphrase-multilingual-MiniLM-L12-v2",
):
self.db_path = db_path self.db_path = db_path
self.model_name = model_name self.model_name = model_name
self.model = None self.model = None
@@ -793,7 +842,9 @@ class SemanticSearch:
) )
""") """)
conn.execute("CREATE INDEX IF NOT EXISTS idx_embedding_content ON embeddings(content_id, content_type)") conn.execute(
"CREATE INDEX IF NOT EXISTS idx_embedding_content ON embeddings(content_id, content_type)"
)
conn.execute("CREATE INDEX IF NOT EXISTS idx_embedding_project ON embeddings(project_id)") conn.execute("CREATE INDEX IF NOT EXISTS idx_embedding_project ON embeddings(project_id)")
conn.commit() conn.commit()
@@ -828,7 +879,9 @@ class SemanticSearch:
print(f"生成 embedding 失败: {e}") print(f"生成 embedding 失败: {e}")
return None return None
def index_embedding(self, content_id: str, content_type: str, project_id: str, text: str) -> bool: def index_embedding(
self, content_id: str, content_type: str, project_id: str, text: str
) -> bool:
""" """
为内容生成并保存 embedding 为内容生成并保存 embedding
@@ -975,11 +1028,15 @@ class SemanticSearch:
try: try:
if content_type == "transcript": if content_type == "transcript":
row = conn.execute("SELECT full_text FROM transcripts WHERE id = ?", (content_id,)).fetchone() row = conn.execute(
"SELECT full_text FROM transcripts WHERE id = ?", (content_id,)
).fetchone()
result = row["full_text"] if row else None result = row["full_text"] if row else None
elif content_type == "entity": elif content_type == "entity":
row = conn.execute("SELECT name, definition FROM entities WHERE id = ?", (content_id,)).fetchone() row = conn.execute(
"SELECT name, definition FROM entities WHERE id = ?", (content_id,)
).fetchone()
result = f"{row['name']}: {row['definition']}" if row else None result = f"{row['name']}: {row['definition']}" if row else None
elif content_type == "relation": elif content_type == "relation":
@@ -992,7 +1049,11 @@ class SemanticSearch:
WHERE r.id = ?""", WHERE r.id = ?""",
(content_id,), (content_id,),
).fetchone() ).fetchone()
result = f"{row['source_name']} {row['relation_type']} {row['target_name']}" if row else None result = (
f"{row['source_name']} {row['relation_type']} {row['target_name']}"
if row
else None
)
else: else:
result = None result = None
@@ -1005,7 +1066,9 @@ class SemanticSearch:
print(f"获取内容失败: {e}") print(f"获取内容失败: {e}")
return None return None
def find_similar_content(self, content_id: str, content_type: str, top_k: int = 5) -> list[SemanticSearchResult]: def find_similar_content(
self, content_id: str, content_type: str, top_k: int = 5
) -> list[SemanticSearchResult]:
""" """
查找与指定内容相似的内容 查找与指定内容相似的内容
@@ -1076,7 +1139,10 @@ class SemanticSearch:
"""删除内容的 embedding""" """删除内容的 embedding"""
try: try:
conn = self._get_conn() conn = self._get_conn()
conn.execute("DELETE FROM embeddings WHERE content_id = ? AND content_type = ?", (content_id, content_type)) conn.execute(
"DELETE FROM embeddings WHERE content_id = ? AND content_type = ?",
(content_id, content_type),
)
conn.commit() conn.commit()
conn.close() conn.close()
return True return True
@@ -1084,8 +1150,10 @@ class SemanticSearch:
print(f"删除 embedding 失败: {e}") print(f"删除 embedding 失败: {e}")
return False return False
# ==================== 实体关系路径发现 ==================== # ==================== 实体关系路径发现 ====================
class EntityPathDiscovery: class EntityPathDiscovery:
""" """
实体关系路径发现模块 实体关系路径发现模块
@@ -1106,7 +1174,9 @@ class EntityPathDiscovery:
conn.row_factory = sqlite3.Row conn.row_factory = sqlite3.Row
return conn return conn
def find_shortest_path(self, source_entity_id: str, target_entity_id: str, max_depth: int = 5) -> EntityPath | None: def find_shortest_path(
self, source_entity_id: str, target_entity_id: str, max_depth: int = 5
) -> EntityPath | None:
""" """
查找两个实体之间的最短路径BFS算法 查找两个实体之间的最短路径BFS算法
@@ -1121,7 +1191,9 @@ class EntityPathDiscovery:
conn = self._get_conn() conn = self._get_conn()
# 获取项目ID # 获取项目ID
row = conn.execute("SELECT project_id FROM entities WHERE id = ?", (source_entity_id,)).fetchone() row = conn.execute(
"SELECT project_id FROM entities WHERE id = ?", (source_entity_id,)
).fetchone()
if not row: if not row:
conn.close() conn.close()
@@ -1194,7 +1266,9 @@ class EntityPathDiscovery:
conn = self._get_conn() conn = self._get_conn()
# 获取项目ID # 获取项目ID
row = conn.execute("SELECT project_id FROM entities WHERE id = ?", (source_entity_id,)).fetchone() row = conn.execute(
"SELECT project_id FROM entities WHERE id = ?", (source_entity_id,)
).fetchone()
if not row: if not row:
conn.close() conn.close()
@@ -1250,7 +1324,9 @@ class EntityPathDiscovery:
# 获取实体信息 # 获取实体信息
nodes = [] nodes = []
for entity_id in entity_ids: for entity_id in entity_ids:
row = conn.execute("SELECT id, name, type FROM entities WHERE id = ?", (entity_id,)).fetchone() row = conn.execute(
"SELECT id, name, type FROM entities WHERE id = ?", (entity_id,)
).fetchone()
if row: if row:
nodes.append({"id": row["id"], "name": row["name"], "type": row["type"]}) nodes.append({"id": row["id"], "name": row["name"], "type": row["type"]})
@@ -1318,7 +1394,9 @@ class EntityPathDiscovery:
conn = self._get_conn() conn = self._get_conn()
# 获取项目ID # 获取项目ID
row = conn.execute("SELECT project_id, name FROM entities WHERE id = ?", (entity_id,)).fetchone() row = conn.execute(
"SELECT project_id, name FROM entities WHERE id = ?", (entity_id,)
).fetchone()
if not row: if not row:
conn.close() conn.close()
@@ -1376,7 +1454,9 @@ class EntityPathDiscovery:
"hops": depth + 1, "hops": depth + 1,
"relation_type": neighbor["relation_type"], "relation_type": neighbor["relation_type"],
"evidence": neighbor["evidence"], "evidence": neighbor["evidence"],
"path": self._get_path_to_entity(entity_id, neighbor_id, project_id, conn), "path": self._get_path_to_entity(
entity_id, neighbor_id, project_id, conn
),
} }
) )
@@ -1481,7 +1561,9 @@ class EntityPathDiscovery:
conn = self._get_conn() conn = self._get_conn()
# 获取所有实体 # 获取所有实体
entities = conn.execute("SELECT id, name FROM entities WHERE project_id = ?", (project_id,)).fetchall() entities = conn.execute(
"SELECT id, name FROM entities WHERE project_id = ?", (project_id,)
).fetchall()
# 计算每个实体作为桥梁的次数 # 计算每个实体作为桥梁的次数
bridge_scores = [] bridge_scores = []
@@ -1512,10 +1594,10 @@ class EntityPathDiscovery:
f""" f"""
SELECT COUNT(*) as count SELECT COUNT(*) as count
FROM entity_relations FROM entity_relations
WHERE ((source_entity_id IN ({','.join(['?' for _ in neighbor_ids])}) WHERE ((source_entity_id IN ({",".join(["?" for _ in neighbor_ids])})
AND target_entity_id IN ({','.join(['?' for _ in neighbor_ids])})) AND target_entity_id IN ({",".join(["?" for _ in neighbor_ids])}))
OR (target_entity_id IN ({','.join(['?' for _ in neighbor_ids])}) OR (target_entity_id IN ({",".join(["?" for _ in neighbor_ids])})
AND source_entity_id IN ({','.join(['?' for _ in neighbor_ids])}))) AND source_entity_id IN ({",".join(["?" for _ in neighbor_ids])})))
AND project_id = ? AND project_id = ?
""", """,
list(neighbor_ids) * 4 + [project_id], list(neighbor_ids) * 4 + [project_id],
@@ -1541,8 +1623,10 @@ class EntityPathDiscovery:
bridge_scores.sort(key=lambda x: x["bridge_score"], reverse=True) bridge_scores.sort(key=lambda x: x["bridge_score"], reverse=True)
return bridge_scores[:20] # 返回前20 return bridge_scores[:20] # 返回前20
# ==================== 知识缺口识别 ==================== # ==================== 知识缺口识别 ====================
class KnowledgeGapDetection: class KnowledgeGapDetection:
""" """
知识缺口识别模块 知识缺口识别模块
@@ -1603,7 +1687,8 @@ class KnowledgeGapDetection:
# 获取项目的属性模板 # 获取项目的属性模板
templates = conn.execute( templates = conn.execute(
"SELECT id, name, type, is_required FROM attribute_templates WHERE project_id = ?", (project_id,) "SELECT id, name, type, is_required FROM attribute_templates WHERE project_id = ?",
(project_id,),
).fetchall() ).fetchall()
if not templates: if not templates:
@@ -1617,7 +1702,9 @@ class KnowledgeGapDetection:
return [] return []
# 检查每个实体的属性完整性 # 检查每个实体的属性完整性
entities = conn.execute("SELECT id, name FROM entities WHERE project_id = ?", (project_id,)).fetchall() entities = conn.execute(
"SELECT id, name FROM entities WHERE project_id = ?", (project_id,)
).fetchall()
for entity in entities: for entity in entities:
entity_id = entity["id"] entity_id = entity["id"]
@@ -1668,7 +1755,9 @@ class KnowledgeGapDetection:
gaps = [] gaps = []
# 获取所有实体及其关系数量 # 获取所有实体及其关系数量
entities = conn.execute("SELECT id, name, type FROM entities WHERE project_id = ?", (project_id,)).fetchall() entities = conn.execute(
"SELECT id, name, type FROM entities WHERE project_id = ?", (project_id,)
).fetchall()
for entity in entities: for entity in entities:
entity_id = entity["id"] entity_id = entity["id"]
@@ -1807,13 +1896,17 @@ class KnowledgeGapDetection:
gaps = [] gaps = []
# 分析转录文本中频繁提及但未提取为实体的词 # 分析转录文本中频繁提及但未提取为实体的词
transcripts = conn.execute("SELECT full_text FROM transcripts WHERE project_id = ?", (project_id,)).fetchall() transcripts = conn.execute(
"SELECT full_text FROM transcripts WHERE project_id = ?", (project_id,)
).fetchall()
# 合并所有文本 # 合并所有文本
all_text = " ".join([t["full_text"] or "" for t in transcripts]) all_text = " ".join([t["full_text"] or "" for t in transcripts])
# 获取现有实体名称 # 获取现有实体名称
existing_entities = conn.execute("SELECT name FROM entities WHERE project_id = ?", (project_id,)).fetchall() existing_entities = conn.execute(
"SELECT name FROM entities WHERE project_id = ?", (project_id,)
).fetchall()
existing_names = {e["name"].lower() for e in existing_entities} existing_names = {e["name"].lower() for e in existing_entities}
@@ -1838,7 +1931,10 @@ class KnowledgeGapDetection:
entity_name=None, entity_name=None,
description=f"文本中频繁提及 '{entity}' 但未提取为实体(出现 {count} 次)", description=f"文本中频繁提及 '{entity}' 但未提取为实体(出现 {count} 次)",
severity="low", severity="low",
suggestions=[f"考虑将 '{entity}' 添加为实体", "检查实体提取算法是否需要优化"], suggestions=[
f"考虑将 '{entity}' 添加为实体",
"检查实体提取算法是否需要优化",
],
related_entities=[], related_entities=[],
metadata={"mention_count": count}, metadata={"mention_count": count},
) )
@@ -1898,7 +1994,11 @@ class KnowledgeGapDetection:
"relation_count": stats["relation_count"], "relation_count": stats["relation_count"],
"transcript_count": stats["transcript_count"], "transcript_count": stats["transcript_count"],
}, },
"gap_summary": {"total": len(gaps), "by_type": dict(gap_by_type), "by_severity": severity_count}, "gap_summary": {
"total": len(gaps),
"by_type": dict(gap_by_type),
"by_severity": severity_count,
},
"top_gaps": [g.to_dict() for g in gaps[:10]], "top_gaps": [g.to_dict() for g in gaps[:10]],
"recommendations": self._generate_recommendations(gaps), "recommendations": self._generate_recommendations(gaps),
} }
@@ -1929,8 +2029,10 @@ class KnowledgeGapDetection:
return recommendations return recommendations
# ==================== 搜索管理器 ==================== # ==================== 搜索管理器 ====================
class SearchManager: class SearchManager:
""" """
搜索管理器 - 统一入口 搜索管理器 - 统一入口
@@ -2035,7 +2137,8 @@ class SearchManager:
# 索引转录文本 # 索引转录文本
transcripts = conn.execute( transcripts = conn.execute(
"SELECT id, project_id, full_text FROM transcripts WHERE project_id = ?", (project_id,) "SELECT id, project_id, full_text FROM transcripts WHERE project_id = ?",
(project_id,),
).fetchall() ).fetchall()
for t in transcripts: for t in transcripts:
@@ -2048,7 +2151,8 @@ class SearchManager:
# 索引实体 # 索引实体
entities = conn.execute( entities = conn.execute(
"SELECT id, project_id, name, definition FROM entities WHERE project_id = ?", (project_id,) "SELECT id, project_id, name, definition FROM entities WHERE project_id = ?",
(project_id,),
).fetchall() ).fetchall()
for e in entities: for e in entities:
@@ -2076,9 +2180,9 @@ class SearchManager:
).fetchone()["count"] ).fetchone()["count"]
# 语义索引统计 # 语义索引统计
semantic_count = conn.execute(f"SELECT COUNT(*) as count FROM embeddings {where_clause}", params).fetchone()[ semantic_count = conn.execute(
"count" f"SELECT COUNT(*) as count FROM embeddings {where_clause}", params
] ).fetchone()["count"]
# 按类型统计 # 按类型统计
type_stats = {} type_stats = {}
@@ -2101,9 +2205,11 @@ class SearchManager:
"semantic_search_available": self.semantic_search.is_available(), "semantic_search_available": self.semantic_search.is_available(),
} }
# 单例模式 # 单例模式
_search_manager = None _search_manager = None
def get_search_manager(db_path: str = "insightflow.db") -> SearchManager: def get_search_manager(db_path: str = "insightflow.db") -> SearchManager:
"""获取搜索管理器单例""" """获取搜索管理器单例"""
global _search_manager global _search_manager
@@ -2111,22 +2217,30 @@ def get_search_manager(db_path: str = "insightflow.db") -> SearchManager:
_search_manager = SearchManager(db_path) _search_manager = SearchManager(db_path)
return _search_manager return _search_manager
# 便捷函数 # 便捷函数
def fulltext_search(query: str, project_id: str | None = None, limit: int = 20) -> list[SearchResult]: def fulltext_search(
query: str, project_id: str | None = None, limit: int = 20
) -> list[SearchResult]:
"""全文搜索便捷函数""" """全文搜索便捷函数"""
manager = get_search_manager() manager = get_search_manager()
return manager.fulltext_search.search(query, project_id, limit=limit) return manager.fulltext_search.search(query, project_id, limit=limit)
def semantic_search(query: str, project_id: str | None = None, top_k: int = 10) -> list[SemanticSearchResult]:
def semantic_search(
query: str, project_id: str | None = None, top_k: int = 10
) -> list[SemanticSearchResult]:
"""语义搜索便捷函数""" """语义搜索便捷函数"""
manager = get_search_manager() manager = get_search_manager()
return manager.semantic_search.search(query, project_id, top_k=top_k) return manager.semantic_search.search(query, project_id, top_k=top_k)
def find_entity_path(source_id: str, target_id: str, max_depth: int = 5) -> EntityPath | None: def find_entity_path(source_id: str, target_id: str, max_depth: int = 5) -> EntityPath | None:
"""查找实体路径便捷函数""" """查找实体路径便捷函数"""
manager = get_search_manager() manager = get_search_manager()
return manager.path_discovery.find_shortest_path(source_id, target_id, max_depth) return manager.path_discovery.find_shortest_path(source_id, target_id, max_depth)
def detect_knowledge_gaps(project_id: str) -> list[KnowledgeGap]: def detect_knowledge_gaps(project_id: str) -> list[KnowledgeGap]:
"""知识缺口检测便捷函数""" """知识缺口检测便捷函数"""
manager = get_search_manager() manager = get_search_manager()

View File

@@ -25,6 +25,7 @@ except ImportError:
CRYPTO_AVAILABLE = False CRYPTO_AVAILABLE = False
print("Warning: cryptography not available, encryption features disabled") print("Warning: cryptography not available, encryption features disabled")
class AuditActionType(Enum): class AuditActionType(Enum):
"""审计动作类型""" """审计动作类型"""
@@ -47,6 +48,7 @@ class AuditActionType(Enum):
WEBHOOK_SEND = "webhook_send" WEBHOOK_SEND = "webhook_send"
BOT_MESSAGE = "bot_message" BOT_MESSAGE = "bot_message"
class DataSensitivityLevel(Enum): class DataSensitivityLevel(Enum):
"""数据敏感度级别""" """数据敏感度级别"""
@@ -55,6 +57,7 @@ class DataSensitivityLevel(Enum):
CONFIDENTIAL = "confidential" # 机密 CONFIDENTIAL = "confidential" # 机密
SECRET = "secret" # 绝密 SECRET = "secret" # 绝密
class MaskingRuleType(Enum): class MaskingRuleType(Enum):
"""脱敏规则类型""" """脱敏规则类型"""
@@ -66,6 +69,7 @@ class MaskingRuleType(Enum):
ADDRESS = "address" # 地址 ADDRESS = "address" # 地址
CUSTOM = "custom" # 自定义 CUSTOM = "custom" # 自定义
@dataclass @dataclass
class AuditLog: class AuditLog:
"""审计日志条目""" """审计日志条目"""
@@ -87,6 +91,7 @@ class AuditLog:
def to_dict(self) -> dict[str, Any]: def to_dict(self) -> dict[str, Any]:
return asdict(self) return asdict(self)
@dataclass @dataclass
class EncryptionConfig: class EncryptionConfig:
"""加密配置""" """加密配置"""
@@ -104,6 +109,7 @@ class EncryptionConfig:
def to_dict(self) -> dict[str, Any]: def to_dict(self) -> dict[str, Any]:
return asdict(self) return asdict(self)
@dataclass @dataclass
class MaskingRule: class MaskingRule:
"""脱敏规则""" """脱敏规则"""
@@ -123,6 +129,7 @@ class MaskingRule:
def to_dict(self) -> dict[str, Any]: def to_dict(self) -> dict[str, Any]:
return asdict(self) return asdict(self)
@dataclass @dataclass
class DataAccessPolicy: class DataAccessPolicy:
"""数据访问策略""" """数据访问策略"""
@@ -144,6 +151,7 @@ class DataAccessPolicy:
def to_dict(self) -> dict[str, Any]: def to_dict(self) -> dict[str, Any]:
return asdict(self) return asdict(self)
@dataclass @dataclass
class AccessRequest: class AccessRequest:
"""访问请求(用于需要审批的访问)""" """访问请求(用于需要审批的访问)"""
@@ -161,6 +169,7 @@ class AccessRequest:
def to_dict(self) -> dict[str, Any]: def to_dict(self) -> dict[str, Any]:
return asdict(self) return asdict(self)
class SecurityManager: class SecurityManager:
"""安全管理器""" """安全管理器"""
@@ -168,9 +177,18 @@ class SecurityManager:
DEFAULT_MASKING_RULES = { DEFAULT_MASKING_RULES = {
MaskingRuleType.PHONE: {"pattern": r"(\d{3})\d{4}(\d{4})", "replacement": r"\1****\2"}, MaskingRuleType.PHONE: {"pattern": r"(\d{3})\d{4}(\d{4})", "replacement": r"\1****\2"},
MaskingRuleType.EMAIL: {"pattern": r"(\w{1,3})\w+(@\w+\.\w+)", "replacement": r"\1***\2"}, MaskingRuleType.EMAIL: {"pattern": r"(\w{1,3})\w+(@\w+\.\w+)", "replacement": r"\1***\2"},
MaskingRuleType.ID_CARD: {"pattern": r"(\d{6})\d{8}(\d{4})", "replacement": r"\1********\2"}, MaskingRuleType.ID_CARD: {
MaskingRuleType.BANK_CARD: {"pattern": r"(\d{4})\d+(\d{4})", "replacement": r"\1 **** **** \2"}, "pattern": r"(\d{6})\d{8}(\d{4})",
MaskingRuleType.NAME: {"pattern": r"([\u4e00-\u9fa5])[\u4e00-\u9fa5]+", "replacement": r"\1**"}, "replacement": r"\1********\2",
},
MaskingRuleType.BANK_CARD: {
"pattern": r"(\d{4})\d+(\d{4})",
"replacement": r"\1 **** **** \2",
},
MaskingRuleType.NAME: {
"pattern": r"([\u4e00-\u9fa5])[\u4e00-\u9fa5]+",
"replacement": r"\1**",
},
MaskingRuleType.ADDRESS: { MaskingRuleType.ADDRESS: {
"pattern": r"([\u4e00-\u9fa5]{2,})([\u4e00-\u9fa5]+路|街|巷|号)(.+)", "pattern": r"([\u4e00-\u9fa5]{2,})([\u4e00-\u9fa5]+路|街|巷|号)(.+)",
"replacement": r"\1\2***", "replacement": r"\1\2***",
@@ -281,19 +299,33 @@ class SecurityManager:
# 创建索引 # 创建索引
cursor.execute("CREATE INDEX IF NOT EXISTS idx_audit_logs_user ON audit_logs(user_id)") cursor.execute("CREATE INDEX IF NOT EXISTS idx_audit_logs_user ON audit_logs(user_id)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_audit_logs_resource ON audit_logs(resource_type, resource_id)") cursor.execute(
cursor.execute("CREATE INDEX IF NOT EXISTS idx_audit_logs_action ON audit_logs(action_type)") "CREATE INDEX IF NOT EXISTS idx_audit_logs_resource ON audit_logs(resource_type, resource_id)"
cursor.execute("CREATE INDEX IF NOT EXISTS idx_audit_logs_created ON audit_logs(created_at)") )
cursor.execute("CREATE INDEX IF NOT EXISTS idx_encryption_project ON encryption_configs(project_id)") cursor.execute(
cursor.execute("CREATE INDEX IF NOT EXISTS idx_masking_project ON masking_rules(project_id)") "CREATE INDEX IF NOT EXISTS idx_audit_logs_action ON audit_logs(action_type)"
cursor.execute("CREATE INDEX IF NOT EXISTS idx_access_policy_project ON data_access_policies(project_id)") )
cursor.execute(
"CREATE INDEX IF NOT EXISTS idx_audit_logs_created ON audit_logs(created_at)"
)
cursor.execute(
"CREATE INDEX IF NOT EXISTS idx_encryption_project ON encryption_configs(project_id)"
)
cursor.execute(
"CREATE INDEX IF NOT EXISTS idx_masking_project ON masking_rules(project_id)"
)
cursor.execute(
"CREATE INDEX IF NOT EXISTS idx_access_policy_project ON data_access_policies(project_id)"
)
conn.commit() conn.commit()
conn.close() conn.close()
def _generate_id(self) -> str: def _generate_id(self) -> str:
"""生成唯一ID""" """生成唯一ID"""
return hashlib.sha256(f"{datetime.now().isoformat()}{secrets.token_hex(16)}".encode()).hexdigest()[:32] return hashlib.sha256(
f"{datetime.now().isoformat()}{secrets.token_hex(16)}".encode()
).hexdigest()[:32]
# ==================== 审计日志 ==================== # ==================== 审计日志 ====================
@@ -431,7 +463,9 @@ class SecurityManager:
conn.close() conn.close()
return logs return logs
def get_audit_stats(self, start_time: str | None = None, end_time: str | None = None) -> dict[str, Any]: def get_audit_stats(
self, start_time: str | None = None, end_time: str | None = None
) -> dict[str, Any]:
"""获取审计统计""" """获取审计统计"""
conn = sqlite3.connect(self.db_path) conn = sqlite3.connect(self.db_path)
cursor = conn.cursor() cursor = conn.cursor()
@@ -589,7 +623,11 @@ class SecurityManager:
conn.close() conn.close()
# 记录审计日志 # 记录审计日志
self.log_audit(action_type=AuditActionType.ENCRYPTION_DISABLE, resource_type="project", resource_id=project_id) self.log_audit(
action_type=AuditActionType.ENCRYPTION_DISABLE,
resource_type="project",
resource_id=project_id,
)
return True return True
@@ -601,7 +639,10 @@ class SecurityManager:
conn = sqlite3.connect(self.db_path) conn = sqlite3.connect(self.db_path)
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute("SELECT master_key_hash, salt FROM encryption_configs WHERE project_id = ?", (project_id,)) cursor.execute(
"SELECT master_key_hash, salt FROM encryption_configs WHERE project_id = ?",
(project_id,),
)
row = cursor.fetchone() row = cursor.fetchone()
conn.close() conn.close()
@@ -794,7 +835,7 @@ class SecurityManager:
cursor.execute( cursor.execute(
f""" f"""
UPDATE masking_rules UPDATE masking_rules
SET {', '.join(set_clauses)} SET {", ".join(set_clauses)}
WHERE id = ? WHERE id = ?
""", """,
params, params,
@@ -840,7 +881,9 @@ class SecurityManager:
return success return success
def apply_masking(self, text: str, project_id: str, rule_types: list[MaskingRuleType] | None = None) -> str: def apply_masking(
self, text: str, project_id: str, rule_types: list[MaskingRuleType] | None = None
) -> str:
"""应用脱敏规则到文本""" """应用脱敏规则到文本"""
rules = self.get_masking_rules(project_id) rules = self.get_masking_rules(project_id)
@@ -862,7 +905,9 @@ class SecurityManager:
return masked_text return masked_text
def apply_masking_to_entity(self, entity_data: dict[str, Any], project_id: str) -> dict[str, Any]: def apply_masking_to_entity(
self, entity_data: dict[str, Any], project_id: str
) -> dict[str, Any]:
"""对实体数据应用脱敏""" """对实体数据应用脱敏"""
masked_data = entity_data.copy() masked_data = entity_data.copy()
@@ -936,7 +981,9 @@ class SecurityManager:
return policy return policy
def get_access_policies(self, project_id: str, active_only: bool = True) -> list[DataAccessPolicy]: def get_access_policies(
self, project_id: str, active_only: bool = True
) -> list[DataAccessPolicy]:
"""获取数据访问策略""" """获取数据访问策略"""
conn = sqlite3.connect(self.db_path) conn = sqlite3.connect(self.db_path)
cursor = conn.cursor() cursor = conn.cursor()
@@ -980,7 +1027,9 @@ class SecurityManager:
conn = sqlite3.connect(self.db_path) conn = sqlite3.connect(self.db_path)
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute("SELECT * FROM data_access_policies WHERE id = ? AND is_active = 1", (policy_id,)) cursor.execute(
"SELECT * FROM data_access_policies WHERE id = ? AND is_active = 1", (policy_id,)
)
row = cursor.fetchone() row = cursor.fetchone()
conn.close() conn.close()
@@ -1073,7 +1122,11 @@ class SecurityManager:
return ip == pattern return ip == pattern
def create_access_request( def create_access_request(
self, policy_id: str, user_id: str, request_reason: str | None = None, expires_hours: int = 24 self,
policy_id: str,
user_id: str,
request_reason: str | None = None,
expires_hours: int = 24,
) -> AccessRequest: ) -> AccessRequest:
"""创建访问请求""" """创建访问请求"""
request = AccessRequest( request = AccessRequest(
@@ -1185,9 +1238,11 @@ class SecurityManager:
created_at=row[8], created_at=row[8],
) )
# 全局安全管理器实例 # 全局安全管理器实例
_security_manager = None _security_manager = None
def get_security_manager(db_path: str = "insightflow.db") -> SecurityManager: def get_security_manager(db_path: str = "insightflow.db") -> SecurityManager:
"""获取安全管理器实例""" """获取安全管理器实例"""
global _security_manager global _security_manager

View File

@@ -21,6 +21,7 @@ from typing import Any
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class SubscriptionStatus(StrEnum): class SubscriptionStatus(StrEnum):
"""订阅状态""" """订阅状态"""
@@ -31,6 +32,7 @@ class SubscriptionStatus(StrEnum):
TRIAL = "trial" # 试用中 TRIAL = "trial" # 试用中
PENDING = "pending" # 待支付 PENDING = "pending" # 待支付
class PaymentProvider(StrEnum): class PaymentProvider(StrEnum):
"""支付提供商""" """支付提供商"""
@@ -39,6 +41,7 @@ class PaymentProvider(StrEnum):
WECHAT = "wechat" # 微信支付 WECHAT = "wechat" # 微信支付
BANK_TRANSFER = "bank_transfer" # 银行转账 BANK_TRANSFER = "bank_transfer" # 银行转账
class PaymentStatus(StrEnum): class PaymentStatus(StrEnum):
"""支付状态""" """支付状态"""
@@ -49,6 +52,7 @@ class PaymentStatus(StrEnum):
REFUNDED = "refunded" # 已退款 REFUNDED = "refunded" # 已退款
PARTIAL_REFUNDED = "partial_refunded" # 部分退款 PARTIAL_REFUNDED = "partial_refunded" # 部分退款
class InvoiceStatus(StrEnum): class InvoiceStatus(StrEnum):
"""发票状态""" """发票状态"""
@@ -59,6 +63,7 @@ class InvoiceStatus(StrEnum):
VOID = "void" # 作废 VOID = "void" # 作废
CREDIT_NOTE = "credit_note" # 贷项通知单 CREDIT_NOTE = "credit_note" # 贷项通知单
class RefundStatus(StrEnum): class RefundStatus(StrEnum):
"""退款状态""" """退款状态"""
@@ -68,6 +73,7 @@ class RefundStatus(StrEnum):
COMPLETED = "completed" # 已完成 COMPLETED = "completed" # 已完成
FAILED = "failed" # 失败 FAILED = "failed" # 失败
@dataclass @dataclass
class SubscriptionPlan: class SubscriptionPlan:
"""订阅计划数据类""" """订阅计划数据类"""
@@ -86,6 +92,7 @@ class SubscriptionPlan:
updated_at: datetime updated_at: datetime
metadata: dict[str, Any] metadata: dict[str, Any]
@dataclass @dataclass
class Subscription: class Subscription:
"""订阅数据类""" """订阅数据类"""
@@ -106,6 +113,7 @@ class Subscription:
updated_at: datetime updated_at: datetime
metadata: dict[str, Any] metadata: dict[str, Any]
@dataclass @dataclass
class UsageRecord: class UsageRecord:
"""用量记录数据类""" """用量记录数据类"""
@@ -120,6 +128,7 @@ class UsageRecord:
description: str | None description: str | None
metadata: dict[str, Any] metadata: dict[str, Any]
@dataclass @dataclass
class Payment: class Payment:
"""支付记录数据类""" """支付记录数据类"""
@@ -141,6 +150,7 @@ class Payment:
created_at: datetime created_at: datetime
updated_at: datetime updated_at: datetime
@dataclass @dataclass
class Invoice: class Invoice:
"""发票数据类""" """发票数据类"""
@@ -164,6 +174,7 @@ class Invoice:
created_at: datetime created_at: datetime
updated_at: datetime updated_at: datetime
@dataclass @dataclass
class Refund: class Refund:
"""退款数据类""" """退款数据类"""
@@ -186,6 +197,7 @@ class Refund:
created_at: datetime created_at: datetime
updated_at: datetime updated_at: datetime
@dataclass @dataclass
class BillingHistory: class BillingHistory:
"""账单历史数据类""" """账单历史数据类"""
@@ -201,6 +213,7 @@ class BillingHistory:
created_at: datetime created_at: datetime
metadata: dict[str, Any] metadata: dict[str, Any]
class SubscriptionManager: class SubscriptionManager:
"""订阅与计费管理器""" """订阅与计费管理器"""
@@ -213,7 +226,13 @@ class SubscriptionManager:
"price_monthly": 0.0, "price_monthly": 0.0,
"price_yearly": 0.0, "price_yearly": 0.0,
"currency": "CNY", "currency": "CNY",
"features": ["basic_analysis", "export_png", "3_projects", "100_mb_storage", "60_min_transcription"], "features": [
"basic_analysis",
"export_png",
"3_projects",
"100_mb_storage",
"60_min_transcription",
],
"limits": { "limits": {
"max_projects": 3, "max_projects": 3,
"max_storage_mb": 100, "max_storage_mb": 100,
@@ -280,9 +299,17 @@ class SubscriptionManager:
# 按量计费单价CNY # 按量计费单价CNY
USAGE_PRICING = { USAGE_PRICING = {
"transcription": {"unit": "minute", "price": 0.5, "free_quota": 60}, # 0.5元/分钟 # 每月免费额度 "transcription": {
"unit": "minute",
"price": 0.5,
"free_quota": 60,
}, # 0.5元/分钟 # 每月免费额度
"storage": {"unit": "gb", "price": 10.0, "free_quota": 0.1}, # 10元/GB/月 # 100MB免费 "storage": {"unit": "gb", "price": 10.0, "free_quota": 0.1}, # 10元/GB/月 # 100MB免费
"api_call": {"unit": "1000_calls", "price": 5.0, "free_quota": 1000}, # 5元/1000次 # 每月免费1000次 "api_call": {
"unit": "1000_calls",
"price": 5.0,
"free_quota": 1000,
}, # 5元/1000次 # 每月免费1000次
"export": {"unit": "page", "price": 0.1, "free_quota": 100}, # 0.1元/页PDF导出 "export": {"unit": "page", "price": 0.1, "free_quota": 100}, # 0.1元/页PDF导出
} }
@@ -456,21 +483,39 @@ class SubscriptionManager:
""") """)
# 创建索引 # 创建索引
cursor.execute("CREATE INDEX IF NOT EXISTS idx_subscriptions_tenant ON subscriptions(tenant_id)") cursor.execute(
cursor.execute("CREATE INDEX IF NOT EXISTS idx_subscriptions_status ON subscriptions(status)") "CREATE INDEX IF NOT EXISTS idx_subscriptions_tenant ON subscriptions(tenant_id)"
cursor.execute("CREATE INDEX IF NOT EXISTS idx_subscriptions_plan ON subscriptions(plan_id)") )
cursor.execute("CREATE INDEX IF NOT EXISTS idx_usage_tenant ON usage_records(tenant_id)") cursor.execute(
cursor.execute("CREATE INDEX IF NOT EXISTS idx_usage_type ON usage_records(resource_type)") "CREATE INDEX IF NOT EXISTS idx_subscriptions_status ON subscriptions(status)"
cursor.execute("CREATE INDEX IF NOT EXISTS idx_usage_recorded ON usage_records(recorded_at)") )
cursor.execute(
"CREATE INDEX IF NOT EXISTS idx_subscriptions_plan ON subscriptions(plan_id)"
)
cursor.execute(
"CREATE INDEX IF NOT EXISTS idx_usage_tenant ON usage_records(tenant_id)"
)
cursor.execute(
"CREATE INDEX IF NOT EXISTS idx_usage_type ON usage_records(resource_type)"
)
cursor.execute(
"CREATE INDEX IF NOT EXISTS idx_usage_recorded ON usage_records(recorded_at)"
)
cursor.execute("CREATE INDEX IF NOT EXISTS idx_payments_tenant ON payments(tenant_id)") cursor.execute("CREATE INDEX IF NOT EXISTS idx_payments_tenant ON payments(tenant_id)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_payments_status ON payments(status)") cursor.execute("CREATE INDEX IF NOT EXISTS idx_payments_status ON payments(status)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_invoices_tenant ON invoices(tenant_id)") cursor.execute("CREATE INDEX IF NOT EXISTS idx_invoices_tenant ON invoices(tenant_id)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_invoices_status ON invoices(status)") cursor.execute("CREATE INDEX IF NOT EXISTS idx_invoices_status ON invoices(status)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_invoices_number ON invoices(invoice_number)") cursor.execute(
"CREATE INDEX IF NOT EXISTS idx_invoices_number ON invoices(invoice_number)"
)
cursor.execute("CREATE INDEX IF NOT EXISTS idx_refunds_tenant ON refunds(tenant_id)") cursor.execute("CREATE INDEX IF NOT EXISTS idx_refunds_tenant ON refunds(tenant_id)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_refunds_status ON refunds(status)") cursor.execute("CREATE INDEX IF NOT EXISTS idx_refunds_status ON refunds(status)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_billing_tenant ON billing_history(tenant_id)") cursor.execute(
cursor.execute("CREATE INDEX IF NOT EXISTS idx_billing_created ON billing_history(created_at)") "CREATE INDEX IF NOT EXISTS idx_billing_tenant ON billing_history(tenant_id)"
)
cursor.execute(
"CREATE INDEX IF NOT EXISTS idx_billing_created ON billing_history(created_at)"
)
conn.commit() conn.commit()
logger.info("Subscription tables initialized successfully") logger.info("Subscription tables initialized successfully")
@@ -542,7 +587,9 @@ class SubscriptionManager:
conn = self._get_connection() conn = self._get_connection()
try: try:
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute("SELECT * FROM subscription_plans WHERE tier = ? AND is_active = 1", (tier,)) cursor.execute(
"SELECT * FROM subscription_plans WHERE tier = ? AND is_active = 1", (tier,)
)
row = cursor.fetchone() row = cursor.fetchone()
if row: if row:
@@ -561,7 +608,9 @@ class SubscriptionManager:
if include_inactive: if include_inactive:
cursor.execute("SELECT * FROM subscription_plans ORDER BY price_monthly") cursor.execute("SELECT * FROM subscription_plans ORDER BY price_monthly")
else: else:
cursor.execute("SELECT * FROM subscription_plans WHERE is_active = 1 ORDER BY price_monthly") cursor.execute(
"SELECT * FROM subscription_plans WHERE is_active = 1 ORDER BY price_monthly"
)
rows = cursor.fetchall() rows = cursor.fetchall()
return [self._row_to_plan(row) for row in rows] return [self._row_to_plan(row) for row in rows]
@@ -679,7 +728,7 @@ class SubscriptionManager:
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute( cursor.execute(
f""" f"""
UPDATE subscription_plans SET {', '.join(updates)} UPDATE subscription_plans SET {", ".join(updates)}
WHERE id = ? WHERE id = ?
""", """,
params, params,
@@ -901,7 +950,7 @@ class SubscriptionManager:
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute( cursor.execute(
f""" f"""
UPDATE subscriptions SET {', '.join(updates)} UPDATE subscriptions SET {", ".join(updates)}
WHERE id = ? WHERE id = ?
""", """,
params, params,
@@ -913,7 +962,9 @@ class SubscriptionManager:
finally: finally:
conn.close() conn.close()
def cancel_subscription(self, subscription_id: str, at_period_end: bool = True) -> Subscription | None: def cancel_subscription(
self, subscription_id: str, at_period_end: bool = True
) -> Subscription | None:
"""取消订阅""" """取消订阅"""
conn = self._get_connection() conn = self._get_connection()
try: try:
@@ -965,7 +1016,9 @@ class SubscriptionManager:
finally: finally:
conn.close() conn.close()
def change_plan(self, subscription_id: str, new_plan_id: str, prorate: bool = True) -> Subscription | None: def change_plan(
self, subscription_id: str, new_plan_id: str, prorate: bool = True
) -> Subscription | None:
"""更改订阅计划""" """更改订阅计划"""
conn = self._get_connection() conn = self._get_connection()
try: try:
@@ -1214,7 +1267,9 @@ class SubscriptionManager:
finally: finally:
conn.close() conn.close()
def confirm_payment(self, payment_id: str, provider_payment_id: str | None = None) -> Payment | None: def confirm_payment(
self, payment_id: str, provider_payment_id: str | None = None
) -> Payment | None:
"""确认支付完成""" """确认支付完成"""
conn = self._get_connection() conn = self._get_connection()
try: try:
@@ -1525,7 +1580,9 @@ class SubscriptionManager:
# ==================== 退款管理 ==================== # ==================== 退款管理 ====================
def request_refund(self, tenant_id: str, payment_id: str, amount: float, reason: str, requested_by: str) -> Refund: def request_refund(
self, tenant_id: str, payment_id: str, amount: float, reason: str, requested_by: str
) -> Refund:
"""申请退款""" """申请退款"""
conn = self._get_connection() conn = self._get_connection()
try: try:
@@ -1632,7 +1689,9 @@ class SubscriptionManager:
finally: finally:
conn.close() conn.close()
def complete_refund(self, refund_id: str, provider_refund_id: str | None = None) -> Refund | None: def complete_refund(
self, refund_id: str, provider_refund_id: str | None = None
) -> Refund | None:
"""完成退款""" """完成退款"""
conn = self._get_connection() conn = self._get_connection()
try: try:
@@ -1825,7 +1884,12 @@ class SubscriptionManager:
# ==================== 支付提供商集成 ==================== # ==================== 支付提供商集成 ====================
def create_stripe_checkout_session( def create_stripe_checkout_session(
self, tenant_id: str, plan_id: str, success_url: str, cancel_url: str, billing_cycle: str = "monthly" self,
tenant_id: str,
plan_id: str,
success_url: str,
cancel_url: str,
billing_cycle: str = "monthly",
) -> dict[str, Any]: ) -> dict[str, Any]:
"""创建 Stripe Checkout 会话(占位实现)""" """创建 Stripe Checkout 会话(占位实现)"""
# 这里应该集成 Stripe SDK # 这里应该集成 Stripe SDK
@@ -1837,7 +1901,9 @@ class SubscriptionManager:
"provider": "stripe", "provider": "stripe",
} }
def create_alipay_order(self, tenant_id: str, plan_id: str, billing_cycle: str = "monthly") -> dict[str, Any]: def create_alipay_order(
self, tenant_id: str, plan_id: str, billing_cycle: str = "monthly"
) -> dict[str, Any]:
"""创建支付宝订单(占位实现)""" """创建支付宝订单(占位实现)"""
# 这里应该集成支付宝 SDK # 这里应该集成支付宝 SDK
plan = self.get_plan(plan_id) plan = self.get_plan(plan_id)
@@ -1852,7 +1918,9 @@ class SubscriptionManager:
"provider": "alipay", "provider": "alipay",
} }
def create_wechat_order(self, tenant_id: str, plan_id: str, billing_cycle: str = "monthly") -> dict[str, Any]: def create_wechat_order(
self, tenant_id: str, plan_id: str, billing_cycle: str = "monthly"
) -> dict[str, Any]:
"""创建微信支付订单(占位实现)""" """创建微信支付订单(占位实现)"""
# 这里应该集成微信支付 SDK # 这里应该集成微信支付 SDK
plan = self.get_plan(plan_id) plan = self.get_plan(plan_id)
@@ -1905,10 +1973,14 @@ class SubscriptionManager:
limits=json.loads(row["limits"] or "{}"), limits=json.loads(row["limits"] or "{}"),
is_active=bool(row["is_active"]), is_active=bool(row["is_active"]),
created_at=( created_at=(
datetime.fromisoformat(row["created_at"]) if isinstance(row["created_at"], str) else row["created_at"] datetime.fromisoformat(row["created_at"])
if isinstance(row["created_at"], str)
else row["created_at"]
), ),
updated_at=( updated_at=(
datetime.fromisoformat(row["updated_at"]) if isinstance(row["updated_at"], str) else row["updated_at"] datetime.fromisoformat(row["updated_at"])
if isinstance(row["updated_at"], str)
else row["updated_at"]
), ),
metadata=json.loads(row["metadata"] or "{}"), metadata=json.loads(row["metadata"] or "{}"),
) )
@@ -1949,10 +2021,14 @@ class SubscriptionManager:
payment_provider=row["payment_provider"], payment_provider=row["payment_provider"],
provider_subscription_id=row["provider_subscription_id"], provider_subscription_id=row["provider_subscription_id"],
created_at=( created_at=(
datetime.fromisoformat(row["created_at"]) if isinstance(row["created_at"], str) else row["created_at"] datetime.fromisoformat(row["created_at"])
if isinstance(row["created_at"], str)
else row["created_at"]
), ),
updated_at=( updated_at=(
datetime.fromisoformat(row["updated_at"]) if isinstance(row["updated_at"], str) else row["updated_at"] datetime.fromisoformat(row["updated_at"])
if isinstance(row["updated_at"], str)
else row["updated_at"]
), ),
metadata=json.loads(row["metadata"] or "{}"), metadata=json.loads(row["metadata"] or "{}"),
) )
@@ -2001,10 +2077,14 @@ class SubscriptionManager:
), ),
failure_reason=row["failure_reason"], failure_reason=row["failure_reason"],
created_at=( created_at=(
datetime.fromisoformat(row["created_at"]) if isinstance(row["created_at"], str) else row["created_at"] datetime.fromisoformat(row["created_at"])
if isinstance(row["created_at"], str)
else row["created_at"]
), ),
updated_at=( updated_at=(
datetime.fromisoformat(row["updated_at"]) if isinstance(row["updated_at"], str) else row["updated_at"] datetime.fromisoformat(row["updated_at"])
if isinstance(row["updated_at"], str)
else row["updated_at"]
), ),
) )
@@ -2048,10 +2128,14 @@ class SubscriptionManager:
), ),
void_reason=row["void_reason"], void_reason=row["void_reason"],
created_at=( created_at=(
datetime.fromisoformat(row["created_at"]) if isinstance(row["created_at"], str) else row["created_at"] datetime.fromisoformat(row["created_at"])
if isinstance(row["created_at"], str)
else row["created_at"]
), ),
updated_at=( updated_at=(
datetime.fromisoformat(row["updated_at"]) if isinstance(row["updated_at"], str) else row["updated_at"] datetime.fromisoformat(row["updated_at"])
if isinstance(row["updated_at"], str)
else row["updated_at"]
), ),
) )
@@ -2086,10 +2170,14 @@ class SubscriptionManager:
provider_refund_id=row["provider_refund_id"], provider_refund_id=row["provider_refund_id"],
metadata=json.loads(row["metadata"] or "{}"), metadata=json.loads(row["metadata"] or "{}"),
created_at=( created_at=(
datetime.fromisoformat(row["created_at"]) if isinstance(row["created_at"], str) else row["created_at"] datetime.fromisoformat(row["created_at"])
if isinstance(row["created_at"], str)
else row["created_at"]
), ),
updated_at=( updated_at=(
datetime.fromisoformat(row["updated_at"]) if isinstance(row["updated_at"], str) else row["updated_at"] datetime.fromisoformat(row["updated_at"])
if isinstance(row["updated_at"], str)
else row["updated_at"]
), ),
) )
@@ -2105,14 +2193,18 @@ class SubscriptionManager:
reference_id=row["reference_id"], reference_id=row["reference_id"],
balance_after=row["balance_after"], balance_after=row["balance_after"],
created_at=( created_at=(
datetime.fromisoformat(row["created_at"]) if isinstance(row["created_at"], str) else row["created_at"] datetime.fromisoformat(row["created_at"])
if isinstance(row["created_at"], str)
else row["created_at"]
), ),
metadata=json.loads(row["metadata"] or "{}"), metadata=json.loads(row["metadata"] or "{}"),
) )
# 全局订阅管理器实例 # 全局订阅管理器实例
subscription_manager = None subscription_manager = None
def get_subscription_manager(db_path: str = "insightflow.db") -> SubscriptionManager: def get_subscription_manager(db_path: str = "insightflow.db") -> SubscriptionManager:
"""获取订阅管理器实例(单例模式)""" """获取订阅管理器实例(单例模式)"""
global subscription_manager global subscription_manager

View File

@@ -23,6 +23,7 @@ from typing import Any
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class TenantLimits: class TenantLimits:
"""租户资源限制常量""" """租户资源限制常量"""
@@ -42,6 +43,7 @@ class TenantLimits:
UNLIMITED = -1 UNLIMITED = -1
class TenantStatus(StrEnum): class TenantStatus(StrEnum):
"""租户状态""" """租户状态"""
@@ -51,6 +53,7 @@ class TenantStatus(StrEnum):
EXPIRED = "expired" # 过期 EXPIRED = "expired" # 过期
PENDING = "pending" # 待激活 PENDING = "pending" # 待激活
class TenantTier(StrEnum): class TenantTier(StrEnum):
"""租户订阅层级""" """租户订阅层级"""
@@ -58,6 +61,7 @@ class TenantTier(StrEnum):
PRO = "pro" # 专业版 PRO = "pro" # 专业版
ENTERPRISE = "enterprise" # 企业版 ENTERPRISE = "enterprise" # 企业版
class TenantRole(StrEnum): class TenantRole(StrEnum):
"""租户角色""" """租户角色"""
@@ -66,6 +70,7 @@ class TenantRole(StrEnum):
MEMBER = "member" # 成员 MEMBER = "member" # 成员
VIEWER = "viewer" # 查看者 VIEWER = "viewer" # 查看者
class DomainStatus(StrEnum): class DomainStatus(StrEnum):
"""域名状态""" """域名状态"""
@@ -74,6 +79,7 @@ class DomainStatus(StrEnum):
FAILED = "failed" # 验证失败 FAILED = "failed" # 验证失败
EXPIRED = "expired" # 已过期 EXPIRED = "expired" # 已过期
@dataclass @dataclass
class Tenant: class Tenant:
"""租户数据类""" """租户数据类"""
@@ -92,6 +98,7 @@ class Tenant:
resource_limits: dict[str, Any] # 资源限制 resource_limits: dict[str, Any] # 资源限制
metadata: dict[str, Any] # 元数据 metadata: dict[str, Any] # 元数据
@dataclass @dataclass
class TenantDomain: class TenantDomain:
"""租户域名数据类""" """租户域名数据类"""
@@ -109,6 +116,7 @@ class TenantDomain:
ssl_enabled: bool # SSL 是否启用 ssl_enabled: bool # SSL 是否启用
ssl_expires_at: datetime | None ssl_expires_at: datetime | None
@dataclass @dataclass
class TenantBranding: class TenantBranding:
"""租户品牌配置数据类""" """租户品牌配置数据类"""
@@ -126,6 +134,7 @@ class TenantBranding:
created_at: datetime created_at: datetime
updated_at: datetime updated_at: datetime
@dataclass @dataclass
class TenantMember: class TenantMember:
"""租户成员数据类""" """租户成员数据类"""
@@ -142,6 +151,7 @@ class TenantMember:
last_active_at: datetime | None last_active_at: datetime | None
status: str # active/pending/suspended status: str # active/pending/suspended
@dataclass @dataclass
class TenantPermission: class TenantPermission:
"""租户权限定义数据类""" """租户权限定义数据类"""
@@ -156,6 +166,7 @@ class TenantPermission:
conditions: dict | None # 条件限制 conditions: dict | None # 条件限制
created_at: datetime created_at: datetime
class TenantManager: class TenantManager:
"""租户管理器 - 多租户 SaaS 架构核心""" """租户管理器 - 多租户 SaaS 架构核心"""
@@ -199,8 +210,24 @@ class TenantManager:
# 角色权限映射 # 角色权限映射
ROLE_PERMISSIONS = { ROLE_PERMISSIONS = {
TenantRole.OWNER: ["tenant:*", "project:*", "member:*", "billing:*", "settings:*", "api:*", "export:*"], TenantRole.OWNER: [
TenantRole.ADMIN: ["tenant:read", "project:*", "member:*", "billing:read", "settings:*", "api:*", "export:*"], "tenant:*",
"project:*",
"member:*",
"billing:*",
"settings:*",
"api:*",
"export:*",
],
TenantRole.ADMIN: [
"tenant:read",
"project:*",
"member:*",
"billing:read",
"settings:*",
"api:*",
"export:*",
],
TenantRole.MEMBER: [ TenantRole.MEMBER: [
"tenant:read", "tenant:read",
"project:create", "project:create",
@@ -360,10 +387,18 @@ class TenantManager:
cursor.execute("CREATE INDEX IF NOT EXISTS idx_tenants_slug ON tenants(slug)") cursor.execute("CREATE INDEX IF NOT EXISTS idx_tenants_slug ON tenants(slug)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_tenants_owner ON tenants(owner_id)") cursor.execute("CREATE INDEX IF NOT EXISTS idx_tenants_owner ON tenants(owner_id)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_tenants_status ON tenants(status)") cursor.execute("CREATE INDEX IF NOT EXISTS idx_tenants_status ON tenants(status)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_domains_tenant ON tenant_domains(tenant_id)") cursor.execute(
cursor.execute("CREATE INDEX IF NOT EXISTS idx_domains_domain ON tenant_domains(domain)") "CREATE INDEX IF NOT EXISTS idx_domains_tenant ON tenant_domains(tenant_id)"
cursor.execute("CREATE INDEX IF NOT EXISTS idx_domains_status ON tenant_domains(status)") )
cursor.execute("CREATE INDEX IF NOT EXISTS idx_members_tenant ON tenant_members(tenant_id)") cursor.execute(
"CREATE INDEX IF NOT EXISTS idx_domains_domain ON tenant_domains(domain)"
)
cursor.execute(
"CREATE INDEX IF NOT EXISTS idx_domains_status ON tenant_domains(status)"
)
cursor.execute(
"CREATE INDEX IF NOT EXISTS idx_members_tenant ON tenant_members(tenant_id)"
)
cursor.execute("CREATE INDEX IF NOT EXISTS idx_members_user ON tenant_members(user_id)") cursor.execute("CREATE INDEX IF NOT EXISTS idx_members_user ON tenant_members(user_id)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_usage_tenant ON tenant_usage(tenant_id)") cursor.execute("CREATE INDEX IF NOT EXISTS idx_usage_tenant ON tenant_usage(tenant_id)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_usage_date ON tenant_usage(date)") cursor.execute("CREATE INDEX IF NOT EXISTS idx_usage_date ON tenant_usage(date)")
@@ -380,7 +415,12 @@ class TenantManager:
# ==================== 租户管理 ==================== # ==================== 租户管理 ====================
def create_tenant( def create_tenant(
self, name: str, owner_id: str, tier: str = "free", description: str | None = None, settings: dict | None = None self,
name: str,
owner_id: str,
tier: str = "free",
description: str | None = None,
settings: dict | None = None,
) -> Tenant: ) -> Tenant:
"""创建新租户""" """创建新租户"""
conn = self._get_connection() conn = self._get_connection()
@@ -389,8 +429,12 @@ class TenantManager:
slug = self._generate_slug(name) slug = self._generate_slug(name)
# 获取对应层级的资源限制 # 获取对应层级的资源限制
tier_enum = TenantTier(tier) if tier in [t.value for t in TenantTier] else TenantTier.FREE tier_enum = (
resource_limits = self.DEFAULT_LIMITS.get(tier_enum, self.DEFAULT_LIMITS[TenantTier.FREE]) TenantTier(tier) if tier in [t.value for t in TenantTier] else TenantTier.FREE
)
resource_limits = self.DEFAULT_LIMITS.get(
tier_enum, self.DEFAULT_LIMITS[TenantTier.FREE]
)
tenant = Tenant( tenant = Tenant(
id=tenant_id, id=tenant_id,
@@ -544,7 +588,7 @@ class TenantManager:
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute( cursor.execute(
f""" f"""
UPDATE tenants SET {', '.join(updates)} UPDATE tenants SET {", ".join(updates)}
WHERE id = ? WHERE id = ?
""", """,
params, params,
@@ -599,7 +643,11 @@ class TenantManager:
# ==================== 域名管理 ==================== # ==================== 域名管理 ====================
def add_domain( def add_domain(
self, tenant_id: str, domain: str, is_primary: bool = False, verification_method: str = "dns" self,
tenant_id: str,
domain: str,
is_primary: bool = False,
verification_method: str = "dns",
) -> TenantDomain: ) -> TenantDomain:
"""为租户添加自定义域名""" """为租户添加自定义域名"""
conn = self._get_connection() conn = self._get_connection()
@@ -752,7 +800,10 @@ class TenantManager:
"value": f"insightflow-verify={token}", "value": f"insightflow-verify={token}",
"ttl": 3600, "ttl": 3600,
}, },
"file_verification": {"url": f"http://{domain}/.well-known/insightflow-verify.txt", "content": token}, "file_verification": {
"url": f"http://{domain}/.well-known/insightflow-verify.txt",
"content": token,
},
"instructions": [ "instructions": [
f"DNS 验证: 添加 TXT 记录 _insightflow.{domain},值为 insightflow-verify={token}", f"DNS 验证: 添加 TXT 记录 _insightflow.{domain},值为 insightflow-verify={token}",
f"文件验证: 在网站根目录创建 .well-known/insightflow-verify.txt内容为 {token}", f"文件验证: 在网站根目录创建 .well-known/insightflow-verify.txt内容为 {token}",
@@ -873,7 +924,7 @@ class TenantManager:
cursor.execute( cursor.execute(
f""" f"""
UPDATE tenant_branding SET {', '.join(updates)} UPDATE tenant_branding SET {", ".join(updates)}
WHERE tenant_id = ? WHERE tenant_id = ?
""", """,
params, params,
@@ -951,7 +1002,12 @@ class TenantManager:
# ==================== 成员与权限管理 ==================== # ==================== 成员与权限管理 ====================
def invite_member( def invite_member(
self, tenant_id: str, email: str, role: str, invited_by: str, permissions: list[str] | None = None self,
tenant_id: str,
email: str,
role: str,
invited_by: str,
permissions: list[str] | None = None,
) -> TenantMember: ) -> TenantMember:
"""邀请成员加入租户""" """邀请成员加入租户"""
conn = self._get_connection() conn = self._get_connection()
@@ -959,7 +1015,9 @@ class TenantManager:
member_id = str(uuid.uuid4()) member_id = str(uuid.uuid4())
# 使用角色默认权限 # 使用角色默认权限
role_enum = TenantRole(role) if role in [r.value for r in TenantRole] else TenantRole.MEMBER role_enum = (
TenantRole(role) if role in [r.value for r in TenantRole] else TenantRole.MEMBER
)
default_permissions = self.ROLE_PERMISSIONS.get(role_enum, []) default_permissions = self.ROLE_PERMISSIONS.get(role_enum, [])
final_permissions = permissions or default_permissions final_permissions = permissions or default_permissions
@@ -1146,7 +1204,13 @@ class TenantManager:
result = [] result = []
for row in rows: for row in rows:
tenant = self._row_to_tenant(row) tenant = self._row_to_tenant(row)
result.append({**asdict(tenant), "member_role": row["role"], "member_status": row["member_status"]}) result.append(
{
**asdict(tenant),
"member_role": row["role"],
"member_status": row["member_status"],
}
)
return result return result
finally: finally:
@@ -1253,14 +1317,21 @@ class TenantManager:
row["total_storage"] or 0, limits.get("max_storage_mb", 0) * 1024 * 1024 row["total_storage"] or 0, limits.get("max_storage_mb", 0) * 1024 * 1024
), ),
"transcription": self._calc_percentage( "transcription": self._calc_percentage(
row["total_transcription"] or 0, limits.get("max_transcription_minutes", 0) * 60 row["total_transcription"] or 0,
limits.get("max_transcription_minutes", 0) * 60,
), ),
"api_calls": self._calc_percentage( "api_calls": self._calc_percentage(
row["total_api_calls"] or 0, limits.get("max_api_calls_per_day", 0) row["total_api_calls"] or 0, limits.get("max_api_calls_per_day", 0)
), ),
"projects": self._calc_percentage(row["max_projects"] or 0, limits.get("max_projects", 0)), "projects": self._calc_percentage(
"entities": self._calc_percentage(row["max_entities"] or 0, limits.get("max_entities", 0)), row["max_projects"] or 0, limits.get("max_projects", 0)
"members": self._calc_percentage(row["max_members"] or 0, limits.get("max_team_members", 0)), ),
"entities": self._calc_percentage(
row["max_entities"] or 0, limits.get("max_entities", 0)
),
"members": self._calc_percentage(
row["max_members"] or 0, limits.get("max_team_members", 0)
),
}, },
} }
@@ -1434,10 +1505,14 @@ class TenantManager:
status=row["status"], status=row["status"],
owner_id=row["owner_id"], owner_id=row["owner_id"],
created_at=( created_at=(
datetime.fromisoformat(row["created_at"]) if isinstance(row["created_at"], str) else row["created_at"] datetime.fromisoformat(row["created_at"])
if isinstance(row["created_at"], str)
else row["created_at"]
), ),
updated_at=( updated_at=(
datetime.fromisoformat(row["updated_at"]) if isinstance(row["updated_at"], str) else row["updated_at"] datetime.fromisoformat(row["updated_at"])
if isinstance(row["updated_at"], str)
else row["updated_at"]
), ),
expires_at=( expires_at=(
datetime.fromisoformat(row["expires_at"]) datetime.fromisoformat(row["expires_at"])
@@ -1464,10 +1539,14 @@ class TenantManager:
else row["verified_at"] else row["verified_at"]
), ),
created_at=( created_at=(
datetime.fromisoformat(row["created_at"]) if isinstance(row["created_at"], str) else row["created_at"] datetime.fromisoformat(row["created_at"])
if isinstance(row["created_at"], str)
else row["created_at"]
), ),
updated_at=( updated_at=(
datetime.fromisoformat(row["updated_at"]) if isinstance(row["updated_at"], str) else row["updated_at"] datetime.fromisoformat(row["updated_at"])
if isinstance(row["updated_at"], str)
else row["updated_at"]
), ),
is_primary=bool(row["is_primary"]), is_primary=bool(row["is_primary"]),
ssl_enabled=bool(row["ssl_enabled"]), ssl_enabled=bool(row["ssl_enabled"]),
@@ -1492,10 +1571,14 @@ class TenantManager:
login_page_bg=row["login_page_bg"], login_page_bg=row["login_page_bg"],
email_template=row["email_template"], email_template=row["email_template"],
created_at=( created_at=(
datetime.fromisoformat(row["created_at"]) if isinstance(row["created_at"], str) else row["created_at"] datetime.fromisoformat(row["created_at"])
if isinstance(row["created_at"], str)
else row["created_at"]
), ),
updated_at=( updated_at=(
datetime.fromisoformat(row["updated_at"]) if isinstance(row["updated_at"], str) else row["updated_at"] datetime.fromisoformat(row["updated_at"])
if isinstance(row["updated_at"], str)
else row["updated_at"]
), ),
) )
@@ -1510,7 +1593,9 @@ class TenantManager:
permissions=json.loads(row["permissions"] or "[]"), permissions=json.loads(row["permissions"] or "[]"),
invited_by=row["invited_by"], invited_by=row["invited_by"],
invited_at=( invited_at=(
datetime.fromisoformat(row["invited_at"]) if isinstance(row["invited_at"], str) else row["invited_at"] datetime.fromisoformat(row["invited_at"])
if isinstance(row["invited_at"], str)
else row["invited_at"]
), ),
joined_at=( joined_at=(
datetime.fromisoformat(row["joined_at"]) datetime.fromisoformat(row["joined_at"])
@@ -1525,8 +1610,10 @@ class TenantManager:
status=row["status"], status=row["status"],
) )
# ==================== 租户上下文管理 ==================== # ==================== 租户上下文管理 ====================
class TenantContext: class TenantContext:
"""租户上下文管理器 - 用于请求级别的租户隔离""" """租户上下文管理器 - 用于请求级别的租户隔离"""
@@ -1559,9 +1646,11 @@ class TenantContext:
cls._current_tenant_id = None cls._current_tenant_id = None
cls._current_user_id = None cls._current_user_id = None
# 全局租户管理器实例 # 全局租户管理器实例
tenant_manager = None tenant_manager = None
def get_tenant_manager(db_path: str = "insightflow.db") -> TenantManager: def get_tenant_manager(db_path: str = "insightflow.db") -> TenantManager:
"""获取租户管理器实例(单例模式)""" """获取租户管理器实例(单例模式)"""
global tenant_manager global tenant_manager

View File

@@ -19,18 +19,21 @@ print("\n1. 测试模块导入...")
try: try:
from multimodal_processor import get_multimodal_processor from multimodal_processor import get_multimodal_processor
print(" ✓ multimodal_processor 导入成功") print(" ✓ multimodal_processor 导入成功")
except ImportError as e: except ImportError as e:
print(f" ✗ multimodal_processor 导入失败: {e}") print(f" ✗ multimodal_processor 导入失败: {e}")
try: try:
from image_processor import get_image_processor from image_processor import get_image_processor
print(" ✓ image_processor 导入成功") print(" ✓ image_processor 导入成功")
except ImportError as e: except ImportError as e:
print(f" ✗ image_processor 导入失败: {e}") print(f" ✗ image_processor 导入失败: {e}")
try: try:
from multimodal_entity_linker import get_multimodal_entity_linker from multimodal_entity_linker import get_multimodal_entity_linker
print(" ✓ multimodal_entity_linker 导入成功") print(" ✓ multimodal_entity_linker 导入成功")
except ImportError as e: except ImportError as e:
print(f" ✗ multimodal_entity_linker 导入失败: {e}") print(f" ✗ multimodal_entity_linker 导入失败: {e}")
@@ -110,7 +113,7 @@ try:
for dir_name, dir_path in [ for dir_name, dir_path in [
("视频", processor.video_dir), ("视频", processor.video_dir),
("", processor.frames_dir), ("", processor.frames_dir),
("音频", processor.audio_dir) ("音频", processor.audio_dir),
]: ]:
if os.path.exists(dir_path): if os.path.exists(dir_path):
print(f"{dir_name}目录存在: {dir_path}") print(f"{dir_name}目录存在: {dir_path}")
@@ -125,11 +128,12 @@ print("\n6. 测试数据库多模态方法...")
try: try:
from db_manager import get_db_manager from db_manager import get_db_manager
db = get_db_manager() db = get_db_manager()
# 检查多模态表是否存在 # 检查多模态表是否存在
conn = db.get_conn() conn = db.get_conn()
tables = ['videos', 'video_frames', 'images', 'multimodal_mentions', 'multimodal_entity_links'] tables = ["videos", "video_frames", "images", "multimodal_mentions", "multimodal_entity_links"]
for table in tables: for table in tables:
try: try:

View File

@@ -20,6 +20,7 @@ from search_manager import (
# 添加 backend 到路径 # 添加 backend 到路径
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
def test_fulltext_search(): def test_fulltext_search():
"""测试全文搜索""" """测试全文搜索"""
print("\n" + "=" * 60) print("\n" + "=" * 60)
@@ -34,7 +35,7 @@ def test_fulltext_search():
content_id="test_entity_1", content_id="test_entity_1",
content_type="entity", content_type="entity",
project_id="test_project", project_id="test_project",
text="这是一个测试实体,用于验证全文搜索功能。支持关键词高亮显示。" text="这是一个测试实体,用于验证全文搜索功能。支持关键词高亮显示。",
) )
print(f" 索引创建: {'✓ 成功' if success else '✗ 失败'}") print(f" 索引创建: {'✓ 成功' if success else '✗ 失败'}")
@@ -56,15 +57,13 @@ def test_fulltext_search():
# 测试高亮 # 测试高亮
print("\n4. 测试文本高亮...") print("\n4. 测试文本高亮...")
highlighted = search.highlight_text( highlighted = search.highlight_text("这是一个测试实体,用于验证全文搜索功能。", "测试 全文")
"这是一个测试实体,用于验证全文搜索功能。",
"测试 全文"
)
print(f" 高亮结果: {highlighted}") print(f" 高亮结果: {highlighted}")
print("\n✓ 全文搜索测试完成") print("\n✓ 全文搜索测试完成")
return True return True
def test_semantic_search(): def test_semantic_search():
"""测试语义搜索""" """测试语义搜索"""
print("\n" + "=" * 60) print("\n" + "=" * 60)
@@ -93,13 +92,14 @@ def test_semantic_search():
content_id="test_content_1", content_id="test_content_1",
content_type="transcript", content_type="transcript",
project_id="test_project", project_id="test_project",
text="这是用于语义搜索测试的文本内容。" text="这是用于语义搜索测试的文本内容。",
) )
print(f" 索引创建: {'✓ 成功' if success else '✗ 失败'}") print(f" 索引创建: {'✓ 成功' if success else '✗ 失败'}")
print("\n✓ 语义搜索测试完成") print("\n✓ 语义搜索测试完成")
return True return True
def test_entity_path_discovery(): def test_entity_path_discovery():
"""测试实体路径发现""" """测试实体路径发现"""
print("\n" + "=" * 60) print("\n" + "=" * 60)
@@ -118,6 +118,7 @@ def test_entity_path_discovery():
print("\n✓ 实体路径发现测试完成") print("\n✓ 实体路径发现测试完成")
return True return True
def test_knowledge_gap_detection(): def test_knowledge_gap_detection():
"""测试知识缺口识别""" """测试知识缺口识别"""
print("\n" + "=" * 60) print("\n" + "=" * 60)
@@ -136,6 +137,7 @@ def test_knowledge_gap_detection():
print("\n✓ 知识缺口识别测试完成") print("\n✓ 知识缺口识别测试完成")
return True return True
def test_cache_manager(): def test_cache_manager():
"""测试缓存管理器""" """测试缓存管理器"""
print("\n" + "=" * 60) print("\n" + "=" * 60)
@@ -156,11 +158,9 @@ def test_cache_manager():
print(" ✓ 获取缓存: {value}") print(" ✓ 获取缓存: {value}")
# 批量操作 # 批量操作
cache.set_many({ cache.set_many(
"batch_key_1": "value1", {"batch_key_1": "value1", "batch_key_2": "value2", "batch_key_3": "value3"}, ttl=60
"batch_key_2": "value2", )
"batch_key_3": "value3"
}, ttl=60)
print(" ✓ 批量设置缓存") print(" ✓ 批量设置缓存")
_ = cache.get_many(["batch_key_1", "batch_key_2", "batch_key_3"]) _ = cache.get_many(["batch_key_1", "batch_key_2", "batch_key_3"])
@@ -185,6 +185,7 @@ def test_cache_manager():
print("\n✓ 缓存管理器测试完成") print("\n✓ 缓存管理器测试完成")
return True return True
def test_task_queue(): def test_task_queue():
"""测试任务队列""" """测试任务队列"""
print("\n" + "=" * 60) print("\n" + "=" * 60)
@@ -207,8 +208,7 @@ def test_task_queue():
# 提交任务 # 提交任务
task_id = queue.submit( task_id = queue.submit(
task_type="test_task", task_type="test_task", payload={"test": "data", "timestamp": time.time()}
payload={"test": "data", "timestamp": time.time()}
) )
print(" ✓ 提交任务: {task_id}") print(" ✓ 提交任务: {task_id}")
@@ -226,6 +226,7 @@ def test_task_queue():
print("\n✓ 任务队列测试完成") print("\n✓ 任务队列测试完成")
return True return True
def test_performance_monitor(): def test_performance_monitor():
"""测试性能监控""" """测试性能监控"""
print("\n" + "=" * 60) print("\n" + "=" * 60)
@@ -242,7 +243,7 @@ def test_performance_monitor():
metric_type="api_response", metric_type="api_response",
duration_ms=50 + i * 10, duration_ms=50 + i * 10,
endpoint="/api/v1/test", endpoint="/api/v1/test",
metadata={"test": True} metadata={"test": True},
) )
for i in range(3): for i in range(3):
@@ -250,7 +251,7 @@ def test_performance_monitor():
metric_type="db_query", metric_type="db_query",
duration_ms=20 + i * 5, duration_ms=20 + i * 5,
endpoint="SELECT test", endpoint="SELECT test",
metadata={"test": True} metadata={"test": True},
) )
print(" ✓ 记录了 8 个测试指标") print(" ✓ 记录了 8 个测试指标")
@@ -263,13 +264,16 @@ def test_performance_monitor():
print(f" 最大响应时间: {stats['overall']['max_duration_ms']} ms") print(f" 最大响应时间: {stats['overall']['max_duration_ms']} ms")
print("\n3. 按类型统计:") print("\n3. 按类型统计:")
for type_stat in stats.get('by_type', []): for type_stat in stats.get("by_type", []):
print(f" {type_stat['type']}: {type_stat['count']} 次, " print(
f"平均 {type_stat['avg_duration_ms']} ms") f" {type_stat['type']}: {type_stat['count']} 次, "
f"平均 {type_stat['avg_duration_ms']} ms"
)
print("\n✓ 性能监控测试完成") print("\n✓ 性能监控测试完成")
return True return True
def test_search_manager(): def test_search_manager():
"""测试搜索管理器""" """测试搜索管理器"""
print("\n" + "=" * 60) print("\n" + "=" * 60)
@@ -290,6 +294,7 @@ def test_search_manager():
print("\n✓ 搜索管理器测试完成") print("\n✓ 搜索管理器测试完成")
return True return True
def test_performance_manager(): def test_performance_manager():
"""测试性能管理器""" """测试性能管理器"""
print("\n" + "=" * 60) print("\n" + "=" * 60)
@@ -314,6 +319,7 @@ def test_performance_manager():
print("\n✓ 性能管理器测试完成") print("\n✓ 性能管理器测试完成")
return True return True
def run_all_tests(): def run_all_tests():
"""运行所有测试""" """运行所有测试"""
print("\n" + "=" * 60) print("\n" + "=" * 60)
@@ -400,6 +406,7 @@ def run_all_tests():
return passed == total return passed == total
if __name__ == "__main__": if __name__ == "__main__":
success = run_all_tests() success = run_all_tests()
sys.exit(0 if success else 1) sys.exit(0 if success else 1)

View File

@@ -17,6 +17,7 @@ from tenant_manager import get_tenant_manager
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
def test_tenant_management(): def test_tenant_management():
"""测试租户管理功能""" """测试租户管理功能"""
print("=" * 60) print("=" * 60)
@@ -28,10 +29,7 @@ def test_tenant_management():
# 1. 创建租户 # 1. 创建租户
print("\n1.1 创建租户...") print("\n1.1 创建租户...")
tenant = manager.create_tenant( tenant = manager.create_tenant(
name="Test Company", name="Test Company", owner_id="user_001", tier="pro", description="A test company tenant"
owner_id="user_001",
tier="pro",
description="A test company tenant"
) )
print(f"✅ 租户创建成功: {tenant.id}") print(f"✅ 租户创建成功: {tenant.id}")
print(f" - 名称: {tenant.name}") print(f" - 名称: {tenant.name}")
@@ -55,9 +53,7 @@ def test_tenant_management():
# 4. 更新租户 # 4. 更新租户
print("\n1.4 更新租户信息...") print("\n1.4 更新租户信息...")
updated = manager.update_tenant( updated = manager.update_tenant(
tenant_id=tenant.id, tenant_id=tenant.id, name="Test Company Updated", tier="enterprise"
name="Test Company Updated",
tier="enterprise"
) )
assert updated is not None, "更新租户失败" assert updated is not None, "更新租户失败"
print(f"✅ 租户更新成功: {updated.name}, 层级: {updated.tier}") print(f"✅ 租户更新成功: {updated.name}, 层级: {updated.tier}")
@@ -69,6 +65,7 @@ def test_tenant_management():
return tenant.id return tenant.id
def test_domain_management(tenant_id: str): def test_domain_management(tenant_id: str):
"""测试域名管理功能""" """测试域名管理功能"""
print("\n" + "=" * 60) print("\n" + "=" * 60)
@@ -79,11 +76,7 @@ def test_domain_management(tenant_id: str):
# 1. 添加域名 # 1. 添加域名
print("\n2.1 添加自定义域名...") print("\n2.1 添加自定义域名...")
domain = manager.add_domain( domain = manager.add_domain(tenant_id=tenant_id, domain="test.example.com", is_primary=True)
tenant_id=tenant_id,
domain="test.example.com",
is_primary=True
)
print(f"✅ 域名添加成功: {domain.domain}") print(f"✅ 域名添加成功: {domain.domain}")
print(f" - ID: {domain.id}") print(f" - ID: {domain.id}")
print(f" - 状态: {domain.status}") print(f" - 状态: {domain.status}")
@@ -118,6 +111,7 @@ def test_domain_management(tenant_id: str):
return domain.id return domain.id
def test_branding_management(tenant_id: str): def test_branding_management(tenant_id: str):
"""测试品牌白标功能""" """测试品牌白标功能"""
print("\n" + "=" * 60) print("\n" + "=" * 60)
@@ -136,7 +130,7 @@ def test_branding_management(tenant_id: str):
secondary_color="#52c41a", secondary_color="#52c41a",
custom_css=".header { background: #1890ff; }", custom_css=".header { background: #1890ff; }",
custom_js="console.log('Custom JS loaded');", custom_js="console.log('Custom JS loaded');",
login_page_bg="https://example.com/bg.jpg" login_page_bg="https://example.com/bg.jpg",
) )
print("✅ 品牌配置更新成功") print("✅ 品牌配置更新成功")
print(f" - Logo: {branding.logo_url}") print(f" - Logo: {branding.logo_url}")
@@ -157,6 +151,7 @@ def test_branding_management(tenant_id: str):
return branding.id return branding.id
def test_member_management(tenant_id: str): def test_member_management(tenant_id: str):
"""测试成员管理功能""" """测试成员管理功能"""
print("\n" + "=" * 60) print("\n" + "=" * 60)
@@ -168,10 +163,7 @@ def test_member_management(tenant_id: str):
# 1. 邀请成员 # 1. 邀请成员
print("\n4.1 邀请成员...") print("\n4.1 邀请成员...")
member1 = manager.invite_member( member1 = manager.invite_member(
tenant_id=tenant_id, tenant_id=tenant_id, email="admin@test.com", role="admin", invited_by="user_001"
email="admin@test.com",
role="admin",
invited_by="user_001"
) )
print(f"✅ 成员邀请成功: {member1.email}") print(f"✅ 成员邀请成功: {member1.email}")
print(f" - ID: {member1.id}") print(f" - ID: {member1.id}")
@@ -179,10 +171,7 @@ def test_member_management(tenant_id: str):
print(f" - 权限: {member1.permissions}") print(f" - 权限: {member1.permissions}")
member2 = manager.invite_member( member2 = manager.invite_member(
tenant_id=tenant_id, tenant_id=tenant_id, email="member@test.com", role="member", invited_by="user_001"
email="member@test.com",
role="member",
invited_by="user_001"
) )
print(f"✅ 成员邀请成功: {member2.email}") print(f"✅ 成员邀请成功: {member2.email}")
@@ -217,6 +206,7 @@ def test_member_management(tenant_id: str):
return member1.id, member2.id return member1.id, member2.id
def test_usage_tracking(tenant_id: str): def test_usage_tracking(tenant_id: str):
"""测试资源使用统计功能""" """测试资源使用统计功能"""
print("\n" + "=" * 60) print("\n" + "=" * 60)
@@ -230,11 +220,11 @@ def test_usage_tracking(tenant_id: str):
manager.record_usage( manager.record_usage(
tenant_id=tenant_id, tenant_id=tenant_id,
storage_bytes=1024 * 1024 * 50, # 50MB storage_bytes=1024 * 1024 * 50, # 50MB
transcription_seconds=600, # 10分钟 transcription_seconds=600, # 10分钟
api_calls=100, api_calls=100,
projects_count=5, projects_count=5,
entities_count=50, entities_count=50,
members_count=3 members_count=3,
) )
print("✅ 资源使用记录成功") print("✅ 资源使用记录成功")
@@ -258,6 +248,7 @@ def test_usage_tracking(tenant_id: str):
return stats return stats
def cleanup(tenant_id: str, domain_id: str, member_ids: list): def cleanup(tenant_id: str, domain_id: str, member_ids: list):
"""清理测试数据""" """清理测试数据"""
print("\n" + "=" * 60) print("\n" + "=" * 60)
@@ -281,6 +272,7 @@ def cleanup(tenant_id: str, domain_id: str, member_ids: list):
manager.delete_tenant(tenant_id) manager.delete_tenant(tenant_id)
print(f"✅ 租户已删除: {tenant_id}") print(f"✅ 租户已删除: {tenant_id}")
def main(): def main():
"""主测试函数""" """主测试函数"""
print("\n" + "=" * 60) print("\n" + "=" * 60)
@@ -307,6 +299,7 @@ def main():
except Exception as e: except Exception as e:
print(f"\n❌ 测试失败: {e}") print(f"\n❌ 测试失败: {e}")
import traceback import traceback
traceback.print_exc() traceback.print_exc()
finally: finally:
@@ -317,5 +310,6 @@ def main():
except Exception as e: except Exception as e:
print(f"⚠️ 清理失败: {e}") print(f"⚠️ 清理失败: {e}")
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@@ -11,6 +11,7 @@ from subscription_manager import PaymentProvider, SubscriptionManager
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
def test_subscription_manager(): def test_subscription_manager():
"""测试订阅管理器""" """测试订阅管理器"""
print("=" * 60) print("=" * 60)
@@ -18,7 +19,7 @@ def test_subscription_manager():
print("=" * 60) print("=" * 60)
# 使用临时文件数据库进行测试 # 使用临时文件数据库进行测试
db_path = tempfile.mktemp(suffix='.db') db_path = tempfile.mktemp(suffix=".db")
try: try:
manager = SubscriptionManager(db_path=db_path) manager = SubscriptionManager(db_path=db_path)
@@ -55,7 +56,7 @@ def test_subscription_manager():
tenant_id=tenant_id, tenant_id=tenant_id,
plan_id=pro_plan.id, plan_id=pro_plan.id,
payment_provider=PaymentProvider.STRIPE.value, payment_provider=PaymentProvider.STRIPE.value,
trial_days=14 trial_days=14,
) )
print(f"✓ 创建订阅: {subscription.id}") print(f"✓ 创建订阅: {subscription.id}")
@@ -78,7 +79,7 @@ def test_subscription_manager():
resource_type="transcription", resource_type="transcription",
quantity=120, quantity=120,
unit="minute", unit="minute",
description="会议转录" description="会议转录",
) )
print(f"✓ 记录转录用量: {usage1.quantity} {usage1.unit}, 费用: ¥{usage1.cost:.2f}") print(f"✓ 记录转录用量: {usage1.quantity} {usage1.unit}, 费用: ¥{usage1.cost:.2f}")
@@ -88,7 +89,7 @@ def test_subscription_manager():
resource_type="storage", resource_type="storage",
quantity=2.5, quantity=2.5,
unit="gb", unit="gb",
description="文件存储" description="文件存储",
) )
print(f"✓ 记录存储用量: {usage2.quantity} {usage2.unit}, 费用: ¥{usage2.cost:.2f}") print(f"✓ 记录存储用量: {usage2.quantity} {usage2.unit}, 费用: ¥{usage2.cost:.2f}")
@@ -96,7 +97,7 @@ def test_subscription_manager():
summary = manager.get_usage_summary(tenant_id) summary = manager.get_usage_summary(tenant_id)
print("✓ 用量汇总:") print("✓ 用量汇总:")
print(f" - 总费用: ¥{summary['total_cost']:.2f}") print(f" - 总费用: ¥{summary['total_cost']:.2f}")
for resource, data in summary['breakdown'].items(): for resource, data in summary["breakdown"].items():
print(f" - {resource}: {data['quantity']}{data['cost']:.2f})") print(f" - {resource}: {data['quantity']}{data['cost']:.2f})")
print("\n4. 测试支付管理") print("\n4. 测试支付管理")
@@ -108,7 +109,7 @@ def test_subscription_manager():
amount=99.0, amount=99.0,
currency="CNY", currency="CNY",
provider=PaymentProvider.ALIPAY.value, provider=PaymentProvider.ALIPAY.value,
payment_method="qrcode" payment_method="qrcode",
) )
print(f"✓ 创建支付: {payment.id}") print(f"✓ 创建支付: {payment.id}")
print(f" - 金额: ¥{payment.amount}") print(f" - 金额: ¥{payment.amount}")
@@ -145,7 +146,7 @@ def test_subscription_manager():
payment_id=payment.id, payment_id=payment.id,
amount=50.0, amount=50.0,
reason="服务不满意", reason="服务不满意",
requested_by="user_001" requested_by="user_001",
) )
print(f"✓ 申请退款: {refund.id}") print(f"✓ 申请退款: {refund.id}")
print(f" - 金额: ¥{refund.amount}") print(f" - 金额: ¥{refund.amount}")
@@ -180,29 +181,23 @@ def test_subscription_manager():
tenant_id=tenant_id, tenant_id=tenant_id,
plan_id=enterprise_plan.id, plan_id=enterprise_plan.id,
success_url="https://example.com/success", success_url="https://example.com/success",
cancel_url="https://example.com/cancel" cancel_url="https://example.com/cancel",
) )
print(f"✓ Stripe Checkout 会话: {stripe_session['session_id']}") print(f"✓ Stripe Checkout 会话: {stripe_session['session_id']}")
# 支付宝订单 # 支付宝订单
alipay_order = manager.create_alipay_order( alipay_order = manager.create_alipay_order(tenant_id=tenant_id, plan_id=pro_plan.id)
tenant_id=tenant_id,
plan_id=pro_plan.id
)
print(f"✓ 支付宝订单: {alipay_order['order_id']}") print(f"✓ 支付宝订单: {alipay_order['order_id']}")
# 微信支付订单 # 微信支付订单
wechat_order = manager.create_wechat_order( wechat_order = manager.create_wechat_order(tenant_id=tenant_id, plan_id=pro_plan.id)
tenant_id=tenant_id,
plan_id=pro_plan.id
)
print(f"✓ 微信支付订单: {wechat_order['order_id']}") print(f"✓ 微信支付订单: {wechat_order['order_id']}")
# Webhook 处理 # Webhook 处理
webhook_result = manager.handle_webhook("stripe", { webhook_result = manager.handle_webhook(
"event_type": "checkout.session.completed", "stripe",
"data": {"object": {"id": "cs_test"}} {"event_type": "checkout.session.completed", "data": {"object": {"id": "cs_test"}}},
}) )
print(f"✓ Webhook 处理: {webhook_result}") print(f"✓ Webhook 处理: {webhook_result}")
print("\n9. 测试订阅变更") print("\n9. 测试订阅变更")
@@ -210,16 +205,12 @@ def test_subscription_manager():
# 更改计划 # 更改计划
changed = manager.change_plan( changed = manager.change_plan(
subscription_id=subscription.id, subscription_id=subscription.id, new_plan_id=enterprise_plan.id
new_plan_id=enterprise_plan.id
) )
print(f"✓ 更改计划: {changed.plan_id} (Enterprise)") print(f"✓ 更改计划: {changed.plan_id} (Enterprise)")
# 取消订阅 # 取消订阅
cancelled = manager.cancel_subscription( cancelled = manager.cancel_subscription(subscription_id=subscription.id, at_period_end=True)
subscription_id=subscription.id,
at_period_end=True
)
print(f"✓ 取消订阅: {cancelled.status}") print(f"✓ 取消订阅: {cancelled.status}")
print(f" - 周期结束时取消: {cancelled.cancel_at_period_end}") print(f" - 周期结束时取消: {cancelled.cancel_at_period_end}")
@@ -233,11 +224,13 @@ def test_subscription_manager():
os.remove(db_path) os.remove(db_path)
print(f"\n清理临时数据库: {db_path}") print(f"\n清理临时数据库: {db_path}")
if __name__ == "__main__": if __name__ == "__main__":
try: try:
test_subscription_manager() test_subscription_manager()
except Exception as e: except Exception as e:
print(f"\n❌ 测试失败: {e}") print(f"\n❌ 测试失败: {e}")
import traceback import traceback
traceback.print_exc() traceback.print_exc()
sys.exit(1) sys.exit(1)

View File

@@ -13,6 +13,7 @@ from ai_manager import ModelType, PredictionType, get_ai_manager
# Add backend directory to path # Add backend directory to path
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
def test_custom_model(): def test_custom_model():
"""测试自定义模型功能""" """测试自定义模型功能"""
print("\n=== 测试自定义模型 ===") print("\n=== 测试自定义模型 ===")
@@ -28,14 +29,10 @@ def test_custom_model():
model_type=ModelType.CUSTOM_NER, model_type=ModelType.CUSTOM_NER,
training_data={ training_data={
"entity_types": ["DISEASE", "SYMPTOM", "DRUG", "TREATMENT"], "entity_types": ["DISEASE", "SYMPTOM", "DRUG", "TREATMENT"],
"domain": "medical" "domain": "medical",
}, },
hyperparameters={ hyperparameters={"epochs": 15, "learning_rate": 0.001, "batch_size": 32},
"epochs": 15, created_by="user_001",
"learning_rate": 0.001,
"batch_size": 32
},
created_by="user_001"
) )
print(f" 创建成功: {model.id}, 状态: {model.status.value}") print(f" 创建成功: {model.id}, 状态: {model.status.value}")
@@ -47,8 +44,8 @@ def test_custom_model():
"entities": [ "entities": [
{"start": 2, "end": 4, "label": "PERSON", "text": "张三"}, {"start": 2, "end": 4, "label": "PERSON", "text": "张三"},
{"start": 6, "end": 9, "label": "DISEASE", "text": "高血压"}, {"start": 6, "end": 9, "label": "DISEASE", "text": "高血压"},
{"start": 14, "end": 17, "label": "DRUG", "text": "降压药"} {"start": 14, "end": 17, "label": "DRUG", "text": "降压药"},
] ],
}, },
{ {
"text": "李四因感冒发烧到医院就诊,医生开具了退烧药。", "text": "李四因感冒发烧到医院就诊,医生开具了退烧药。",
@@ -56,16 +53,16 @@ def test_custom_model():
{"start": 0, "end": 2, "label": "PERSON", "text": "李四"}, {"start": 0, "end": 2, "label": "PERSON", "text": "李四"},
{"start": 3, "end": 5, "label": "SYMPTOM", "text": "感冒"}, {"start": 3, "end": 5, "label": "SYMPTOM", "text": "感冒"},
{"start": 5, "end": 7, "label": "SYMPTOM", "text": "发烧"}, {"start": 5, "end": 7, "label": "SYMPTOM", "text": "发烧"},
{"start": 21, "end": 24, "label": "DRUG", "text": "退烧药"} {"start": 21, "end": 24, "label": "DRUG", "text": "退烧药"},
] ],
}, },
{ {
"text": "王五接受了心脏搭桥手术,术后恢复良好。", "text": "王五接受了心脏搭桥手术,术后恢复良好。",
"entities": [ "entities": [
{"start": 0, "end": 2, "label": "PERSON", "text": "王五"}, {"start": 0, "end": 2, "label": "PERSON", "text": "王五"},
{"start": 5, "end": 11, "label": "TREATMENT", "text": "心脏搭桥手术"} {"start": 5, "end": 11, "label": "TREATMENT", "text": "心脏搭桥手术"},
] ],
} },
] ]
for sample_data in samples: for sample_data in samples:
@@ -73,7 +70,7 @@ def test_custom_model():
model_id=model.id, model_id=model.id,
text=sample_data["text"], text=sample_data["text"],
entities=sample_data["entities"], entities=sample_data["entities"],
metadata={"source": "manual"} metadata={"source": "manual"},
) )
print(f" 添加样本: {sample.id}") print(f" 添加样本: {sample.id}")
@@ -91,6 +88,7 @@ def test_custom_model():
return model.id return model.id
async def test_train_and_predict(model_id: str): async def test_train_and_predict(model_id: str):
"""测试训练和预测""" """测试训练和预测"""
print("\n=== 测试模型训练和预测 ===") print("\n=== 测试模型训练和预测 ===")
@@ -117,6 +115,7 @@ async def test_train_and_predict(model_id: str):
except Exception as e: except Exception as e:
print(f" 预测失败: {e}") print(f" 预测失败: {e}")
def test_prediction_models(): def test_prediction_models():
"""测试预测模型""" """测试预测模型"""
print("\n=== 测试预测模型 ===") print("\n=== 测试预测模型 ===")
@@ -132,10 +131,7 @@ def test_prediction_models():
prediction_type=PredictionType.TREND, prediction_type=PredictionType.TREND,
target_entity_type="PERSON", target_entity_type="PERSON",
features=["entity_count", "time_period", "document_count"], features=["entity_count", "time_period", "document_count"],
model_config={ model_config={"algorithm": "linear_regression", "window_size": 7},
"algorithm": "linear_regression",
"window_size": 7
}
) )
print(f" 创建成功: {trend_model.id}") print(f" 创建成功: {trend_model.id}")
@@ -148,10 +144,7 @@ def test_prediction_models():
prediction_type=PredictionType.ANOMALY, prediction_type=PredictionType.ANOMALY,
target_entity_type=None, target_entity_type=None,
features=["daily_growth", "weekly_growth"], features=["daily_growth", "weekly_growth"],
model_config={ model_config={"threshold": 2.5, "sensitivity": "medium"},
"threshold": 2.5,
"sensitivity": "medium"
}
) )
print(f" 创建成功: {anomaly_model.id}") print(f" 创建成功: {anomaly_model.id}")
@@ -164,6 +157,7 @@ def test_prediction_models():
return trend_model.id, anomaly_model.id return trend_model.id, anomaly_model.id
async def test_predictions(trend_model_id: str, anomaly_model_id: str): async def test_predictions(trend_model_id: str, anomaly_model_id: str):
"""测试预测功能""" """测试预测功能"""
print("\n=== 测试预测功能 ===") print("\n=== 测试预测功能 ===")
@@ -179,7 +173,7 @@ async def test_predictions(trend_model_id: str, anomaly_model_id: str):
{"date": "2024-01-04", "value": 14}, {"date": "2024-01-04", "value": 14},
{"date": "2024-01-05", "value": 18}, {"date": "2024-01-05", "value": 18},
{"date": "2024-01-06", "value": 20}, {"date": "2024-01-06", "value": 20},
{"date": "2024-01-07", "value": 22} {"date": "2024-01-07", "value": 22},
] ]
trained = await manager.train_prediction_model(trend_model_id, historical_data) trained = await manager.train_prediction_model(trend_model_id, historical_data)
print(f" 训练完成,准确率: {trained.accuracy}") print(f" 训练完成,准确率: {trained.accuracy}")
@@ -187,22 +181,18 @@ async def test_predictions(trend_model_id: str, anomaly_model_id: str):
# 2. 趋势预测 # 2. 趋势预测
print("2. 趋势预测...") print("2. 趋势预测...")
trend_result = await manager.predict( trend_result = await manager.predict(
trend_model_id, trend_model_id, {"historical_values": [10, 12, 15, 14, 18, 20, 22]}
{"historical_values": [10, 12, 15, 14, 18, 20, 22]}
) )
print(f" 预测结果: {trend_result.prediction_data}") print(f" 预测结果: {trend_result.prediction_data}")
# 3. 异常检测 # 3. 异常检测
print("3. 异常检测...") print("3. 异常检测...")
anomaly_result = await manager.predict( anomaly_result = await manager.predict(
anomaly_model_id, anomaly_model_id, {"value": 50, "historical_values": [10, 12, 11, 13, 12, 14, 13]}
{
"value": 50,
"historical_values": [10, 12, 11, 13, 12, 14, 13]
}
) )
print(f" 检测结果: {anomaly_result.prediction_data}") print(f" 检测结果: {anomaly_result.prediction_data}")
def test_kg_rag(): def test_kg_rag():
"""测试知识图谱 RAG""" """测试知识图谱 RAG"""
print("\n=== 测试知识图谱 RAG ===") print("\n=== 测试知识图谱 RAG ===")
@@ -218,18 +208,10 @@ def test_kg_rag():
description="基于项目知识图谱的智能问答", description="基于项目知识图谱的智能问答",
kg_config={ kg_config={
"entity_types": ["PERSON", "ORG", "PROJECT", "TECH"], "entity_types": ["PERSON", "ORG", "PROJECT", "TECH"],
"relation_types": ["works_with", "belongs_to", "depends_on"] "relation_types": ["works_with", "belongs_to", "depends_on"],
}, },
retrieval_config={ retrieval_config={"top_k": 5, "similarity_threshold": 0.7, "expand_relations": True},
"top_k": 5, generation_config={"temperature": 0.3, "max_tokens": 1000, "include_sources": True},
"similarity_threshold": 0.7,
"expand_relations": True
},
generation_config={
"temperature": 0.3,
"max_tokens": 1000,
"include_sources": True
}
) )
print(f" 创建成功: {rag.id}") print(f" 创建成功: {rag.id}")
@@ -240,6 +222,7 @@ def test_kg_rag():
return rag.id return rag.id
async def test_kg_rag_query(rag_id: str): async def test_kg_rag_query(rag_id: str):
"""测试 RAG 查询""" """测试 RAG 查询"""
print("\n=== 测试知识图谱 RAG 查询 ===") print("\n=== 测试知识图谱 RAG 查询 ===")
@@ -252,33 +235,43 @@ async def test_kg_rag_query(rag_id: str):
{"id": "e2", "name": "李四", "type": "PERSON", "definition": "技术负责人"}, {"id": "e2", "name": "李四", "type": "PERSON", "definition": "技术负责人"},
{"id": "e3", "name": "Project Alpha", "type": "PROJECT", "definition": "核心产品项目"}, {"id": "e3", "name": "Project Alpha", "type": "PROJECT", "definition": "核心产品项目"},
{"id": "e4", "name": "Kubernetes", "type": "TECH", "definition": "容器编排平台"}, {"id": "e4", "name": "Kubernetes", "type": "TECH", "definition": "容器编排平台"},
{"id": "e5", "name": "TechCorp", "type": "ORG", "definition": "科技公司"} {"id": "e5", "name": "TechCorp", "type": "ORG", "definition": "科技公司"},
] ]
project_relations = [{"source_entity_id": "e1", project_relations = [
"target_entity_id": "e3", {
"source_name": "张三", "source_entity_id": "e1",
"target_name": "Project Alpha", "target_entity_id": "e3",
"relation_type": "works_with", "source_name": "张三",
"evidence": "张三负责 Project Alpha 的管理工作"}, "target_name": "Project Alpha",
{"source_entity_id": "e2", "relation_type": "works_with",
"target_entity_id": "e3", "evidence": "张三负责 Project Alpha 的管理工作",
"source_name": "李四", },
"target_name": "Project Alpha", {
"relation_type": "works_with", "source_entity_id": "e2",
"evidence": "李四负责 Project Alpha 的技术架构"}, "target_entity_id": "e3",
{"source_entity_id": "e3", "source_name": "李四",
"target_entity_id": "e4", "target_name": "Project Alpha",
"source_name": "Project Alpha", "relation_type": "works_with",
"target_name": "Kubernetes", "evidence": "李四负责 Project Alpha 的技术架构",
"relation_type": "depends_on", },
"evidence": "项目使用 Kubernetes 进行部署"}, {
{"source_entity_id": "e1", "source_entity_id": "e3",
"target_entity_id": "e5", "target_entity_id": "e4",
"source_name": "张三", "source_name": "Project Alpha",
"target_name": "TechCorp", "target_name": "Kubernetes",
"relation_type": "belongs_to", "relation_type": "depends_on",
"evidence": "张三是 TechCorp 的员工"}] "evidence": "项目使用 Kubernetes 进行部署",
},
{
"source_entity_id": "e1",
"target_entity_id": "e5",
"source_name": "张三",
"target_name": "TechCorp",
"relation_type": "belongs_to",
"evidence": "张三是 TechCorp 的员工",
},
]
# 执行查询 # 执行查询
print("1. 执行 RAG 查询...") print("1. 执行 RAG 查询...")
@@ -289,7 +282,7 @@ async def test_kg_rag_query(rag_id: str):
rag_id=rag_id, rag_id=rag_id,
query=query_text, query=query_text,
project_entities=project_entities, project_entities=project_entities,
project_relations=project_relations project_relations=project_relations,
) )
print(f" 查询: {result.query}") print(f" 查询: {result.query}")
@@ -300,6 +293,7 @@ async def test_kg_rag_query(rag_id: str):
except Exception as e: except Exception as e:
print(f" 查询失败: {e}") print(f" 查询失败: {e}")
async def test_smart_summary(): async def test_smart_summary():
"""测试智能摘要""" """测试智能摘要"""
print("\n=== 测试智能摘要 ===") print("\n=== 测试智能摘要 ===")
@@ -321,8 +315,8 @@ async def test_smart_summary():
{"name": "张三", "type": "PERSON"}, {"name": "张三", "type": "PERSON"},
{"name": "李四", "type": "PERSON"}, {"name": "李四", "type": "PERSON"},
{"name": "Project Alpha", "type": "PROJECT"}, {"name": "Project Alpha", "type": "PROJECT"},
{"name": "Kubernetes", "type": "TECH"} {"name": "Kubernetes", "type": "TECH"},
] ],
} }
# 生成不同类型的摘要 # 生成不同类型的摘要
@@ -337,7 +331,7 @@ async def test_smart_summary():
source_type="transcript", source_type="transcript",
source_id="transcript_001", source_id="transcript_001",
summary_type=summary_type, summary_type=summary_type,
content_data=content_data content_data=content_data,
) )
print(f" 摘要类型: {summary.summary_type}") print(f" 摘要类型: {summary.summary_type}")
@@ -347,6 +341,7 @@ async def test_smart_summary():
except Exception as e: except Exception as e:
print(f" 生成失败: {e}") print(f" 生成失败: {e}")
async def main(): async def main():
"""主测试函数""" """主测试函数"""
print("=" * 60) print("=" * 60)
@@ -382,7 +377,9 @@ async def main():
except Exception as e: except Exception as e:
print(f"\n测试失败: {e}") print(f"\n测试失败: {e}")
import traceback import traceback
traceback.print_exc() traceback.print_exc()
if __name__ == "__main__": if __name__ == "__main__":
asyncio.run(main()) asyncio.run(main())

View File

@@ -32,6 +32,7 @@ backend_dir = os.path.dirname(os.path.abspath(__file__))
if backend_dir not in sys.path: if backend_dir not in sys.path:
sys.path.insert(0, backend_dir) sys.path.insert(0, backend_dir)
class TestGrowthManager: class TestGrowthManager:
"""测试 Growth Manager 功能""" """测试 Growth Manager 功能"""
@@ -63,7 +64,7 @@ class TestGrowthManager:
session_id="session_001", session_id="session_001",
device_info={"browser": "Chrome", "os": "MacOS"}, device_info={"browser": "Chrome", "os": "MacOS"},
referrer="https://google.com", referrer="https://google.com",
utm_params={"source": "google", "medium": "organic", "campaign": "summer"} utm_params={"source": "google", "medium": "organic", "campaign": "summer"},
) )
assert event.id is not None assert event.id is not None
@@ -94,7 +95,7 @@ class TestGrowthManager:
user_id=self.test_user_id, user_id=self.test_user_id,
event_type=event_type, event_type=event_type,
event_name=event_name, event_name=event_name,
properties=props properties=props,
) )
self.log(f"成功追踪 {len(events)} 个事件") self.log(f"成功追踪 {len(events)} 个事件")
@@ -130,7 +131,7 @@ class TestGrowthManager:
summary = self.manager.get_user_analytics_summary( summary = self.manager.get_user_analytics_summary(
tenant_id=self.test_tenant_id, tenant_id=self.test_tenant_id,
start_date=datetime.now() - timedelta(days=7), start_date=datetime.now() - timedelta(days=7),
end_date=datetime.now() end_date=datetime.now(),
) )
assert "unique_users" in summary assert "unique_users" in summary
@@ -156,9 +157,9 @@ class TestGrowthManager:
{"name": "访问首页", "event_name": "page_view_home"}, {"name": "访问首页", "event_name": "page_view_home"},
{"name": "点击注册", "event_name": "signup_click"}, {"name": "点击注册", "event_name": "signup_click"},
{"name": "填写信息", "event_name": "signup_form_fill"}, {"name": "填写信息", "event_name": "signup_form_fill"},
{"name": "完成注册", "event_name": "signup_complete"} {"name": "完成注册", "event_name": "signup_complete"},
], ],
created_by="test" created_by="test",
) )
assert funnel.id is not None assert funnel.id is not None
@@ -182,7 +183,7 @@ class TestGrowthManager:
analysis = self.manager.analyze_funnel( analysis = self.manager.analyze_funnel(
funnel_id=funnel_id, funnel_id=funnel_id,
period_start=datetime.now() - timedelta(days=30), period_start=datetime.now() - timedelta(days=30),
period_end=datetime.now() period_end=datetime.now(),
) )
if analysis: if analysis:
@@ -204,7 +205,7 @@ class TestGrowthManager:
retention = self.manager.calculate_retention( retention = self.manager.calculate_retention(
tenant_id=self.test_tenant_id, tenant_id=self.test_tenant_id,
cohort_date=datetime.now() - timedelta(days=7), cohort_date=datetime.now() - timedelta(days=7),
periods=[1, 3, 7] periods=[1, 3, 7],
) )
assert "cohort_date" in retention assert "cohort_date" in retention
@@ -231,7 +232,7 @@ class TestGrowthManager:
variants=[ variants=[
{"id": "control", "name": "红色按钮", "is_control": True}, {"id": "control", "name": "红色按钮", "is_control": True},
{"id": "variant_a", "name": "蓝色按钮", "is_control": False}, {"id": "variant_a", "name": "蓝色按钮", "is_control": False},
{"id": "variant_b", "name": "绿色按钮", "is_control": False} {"id": "variant_b", "name": "绿色按钮", "is_control": False},
], ],
traffic_allocation=TrafficAllocationType.RANDOM, traffic_allocation=TrafficAllocationType.RANDOM,
traffic_split={"control": 0.34, "variant_a": 0.33, "variant_b": 0.33}, traffic_split={"control": 0.34, "variant_a": 0.33, "variant_b": 0.33},
@@ -240,7 +241,7 @@ class TestGrowthManager:
secondary_metrics=["conversion_rate", "bounce_rate"], secondary_metrics=["conversion_rate", "bounce_rate"],
min_sample_size=100, min_sample_size=100,
confidence_level=0.95, confidence_level=0.95,
created_by="test" created_by="test",
) )
assert experiment.id is not None assert experiment.id is not None
@@ -285,7 +286,7 @@ class TestGrowthManager:
variant_id = self.manager.assign_variant( variant_id = self.manager.assign_variant(
experiment_id=experiment_id, experiment_id=experiment_id,
user_id=user_id, user_id=user_id,
user_attributes={"user_id": user_id, "segment": "new"} user_attributes={"user_id": user_id, "segment": "new"},
) )
if variant_id: if variant_id:
@@ -321,7 +322,7 @@ class TestGrowthManager:
variant_id=variant_id, variant_id=variant_id,
user_id=user_id, user_id=user_id,
metric_name="button_click_rate", metric_name="button_click_rate",
metric_value=value metric_value=value,
) )
self.log(f"成功记录 {len(test_data)} 条指标") self.log(f"成功记录 {len(test_data)} 条指标")
@@ -375,7 +376,7 @@ class TestGrowthManager:
<p><a href="{{dashboard_url}}">立即开始使用</a></p> <p><a href="{{dashboard_url}}">立即开始使用</a></p>
""", """,
from_name="InsightFlow 团队", from_name="InsightFlow 团队",
from_email="welcome@insightflow.io" from_email="welcome@insightflow.io",
) )
assert template.id is not None assert template.id is not None
@@ -413,8 +414,8 @@ class TestGrowthManager:
template_id=template_id, template_id=template_id,
variables={ variables={
"user_name": "张三", "user_name": "张三",
"dashboard_url": "https://app.insightflow.io/dashboard" "dashboard_url": "https://app.insightflow.io/dashboard",
} },
) )
if rendered: if rendered:
@@ -445,8 +446,8 @@ class TestGrowthManager:
recipient_list=[ recipient_list=[
{"user_id": "user_001", "email": "user1@example.com"}, {"user_id": "user_001", "email": "user1@example.com"},
{"user_id": "user_002", "email": "user2@example.com"}, {"user_id": "user_002", "email": "user2@example.com"},
{"user_id": "user_003", "email": "user3@example.com"} {"user_id": "user_003", "email": "user3@example.com"},
] ],
) )
assert campaign.id is not None assert campaign.id is not None
@@ -472,8 +473,8 @@ class TestGrowthManager:
actions=[ actions=[
{"type": "send_email", "template_type": "welcome", "delay_hours": 0}, {"type": "send_email", "template_type": "welcome", "delay_hours": 0},
{"type": "send_email", "template_type": "onboarding", "delay_hours": 24}, {"type": "send_email", "template_type": "onboarding", "delay_hours": 24},
{"type": "send_email", "template_type": "feature_tips", "delay_hours": 72} {"type": "send_email", "template_type": "feature_tips", "delay_hours": 72},
] ],
) )
assert workflow.id is not None assert workflow.id is not None
@@ -502,7 +503,7 @@ class TestGrowthManager:
referee_reward_value=50.0, referee_reward_value=50.0,
max_referrals_per_user=10, max_referrals_per_user=10,
referral_code_length=8, referral_code_length=8,
expiry_days=30 expiry_days=30,
) )
assert program.id is not None assert program.id is not None
@@ -524,8 +525,7 @@ class TestGrowthManager:
try: try:
referral = self.manager.generate_referral_code( referral = self.manager.generate_referral_code(
program_id=program_id, program_id=program_id, referrer_id="referrer_user_001"
referrer_id="referrer_user_001"
) )
if referral: if referral:
@@ -551,8 +551,7 @@ class TestGrowthManager:
try: try:
success = self.manager.apply_referral_code( success = self.manager.apply_referral_code(
referral_code=referral_code, referral_code=referral_code, referee_id="new_user_001"
referee_id="new_user_001"
) )
if success: if success:
@@ -579,7 +578,9 @@ class TestGrowthManager:
assert "total_referrals" in stats assert "total_referrals" in stats
assert "conversion_rate" in stats assert "conversion_rate" in stats
self.log(f"推荐统计: {stats['total_referrals']} 推荐, {stats['conversion_rate']:.2%} 转化率") self.log(
f"推荐统计: {stats['total_referrals']} 推荐, {stats['conversion_rate']:.2%} 转化率"
)
return True return True
except Exception as e: except Exception as e:
self.log(f"获取推荐统计失败: {e}", success=False) self.log(f"获取推荐统计失败: {e}", success=False)
@@ -599,7 +600,7 @@ class TestGrowthManager:
incentive_type="discount", incentive_type="discount",
incentive_value=20.0, # 20% 折扣 incentive_value=20.0, # 20% 折扣
valid_from=datetime.now(), valid_from=datetime.now(),
valid_until=datetime.now() + timedelta(days=90) valid_until=datetime.now() + timedelta(days=90),
) )
assert incentive.id is not None assert incentive.id is not None
@@ -617,9 +618,7 @@ class TestGrowthManager:
try: try:
incentives = self.manager.check_team_incentive_eligibility( incentives = self.manager.check_team_incentive_eligibility(
tenant_id=self.test_tenant_id, tenant_id=self.test_tenant_id, current_tier="free", team_size=5
current_tier="free",
team_size=5
) )
self.log(f"找到 {len(incentives)} 个符合条件的激励") self.log(f"找到 {len(incentives)} 个符合条件的激励")
@@ -642,7 +641,9 @@ class TestGrowthManager:
assert "top_features" in dashboard assert "top_features" in dashboard
today = dashboard["today"] today = dashboard["today"]
self.log(f"实时仪表板: 今日 {today['active_users']} 活跃用户, {today['total_events']} 事件") self.log(
f"实时仪表板: 今日 {today['active_users']} 活跃用户, {today['total_events']} 事件"
)
return True return True
except Exception as e: except Exception as e:
self.log(f"获取实时仪表板失败: {e}", success=False) self.log(f"获取实时仪表板失败: {e}", success=False)
@@ -734,10 +735,12 @@ class TestGrowthManager:
print("✨ 测试完成!") print("✨ 测试完成!")
print("=" * 60) print("=" * 60)
async def main(): async def main():
"""主函数""" """主函数"""
tester = TestGrowthManager() tester = TestGrowthManager()
await tester.run_all_tests() await tester.run_all_tests()
if __name__ == "__main__": if __name__ == "__main__":
asyncio.run(main()) asyncio.run(main())

View File

@@ -29,6 +29,7 @@ backend_dir = os.path.dirname(os.path.abspath(__file__))
if backend_dir not in sys.path: if backend_dir not in sys.path:
sys.path.insert(0, backend_dir) sys.path.insert(0, backend_dir)
class TestDeveloperEcosystem: class TestDeveloperEcosystem:
"""开发者生态系统测试类""" """开发者生态系统测试类"""
@@ -36,23 +37,21 @@ class TestDeveloperEcosystem:
self.manager = DeveloperEcosystemManager() self.manager = DeveloperEcosystemManager()
self.test_results = [] self.test_results = []
self.created_ids = { self.created_ids = {
'sdk': [], "sdk": [],
'template': [], "template": [],
'plugin': [], "plugin": [],
'developer': [], "developer": [],
'code_example': [], "code_example": [],
'portal_config': [] "portal_config": [],
} }
def log(self, message: str, success: bool = True): def log(self, message: str, success: bool = True):
"""记录测试结果""" """记录测试结果"""
status = "" if success else "" status = "" if success else ""
print(f"{status} {message}") print(f"{status} {message}")
self.test_results.append({ self.test_results.append(
'message': message, {"message": message, "success": success, "timestamp": datetime.now().isoformat()}
'success': success, )
'timestamp': datetime.now().isoformat()
})
def run_all_tests(self): def run_all_tests(self):
"""运行所有测试""" """运行所有测试"""
@@ -137,9 +136,9 @@ class TestDeveloperEcosystem:
dependencies=[{"name": "requests", "version": ">=2.0"}], dependencies=[{"name": "requests", "version": ">=2.0"}],
file_size=1024000, file_size=1024000,
checksum="abc123", checksum="abc123",
created_by="test_user" created_by="test_user",
) )
self.created_ids['sdk'].append(sdk.id) self.created_ids["sdk"].append(sdk.id)
self.log(f"Created SDK: {sdk.name} ({sdk.id})") self.log(f"Created SDK: {sdk.name} ({sdk.id})")
# Create JavaScript SDK # Create JavaScript SDK
@@ -157,9 +156,9 @@ class TestDeveloperEcosystem:
dependencies=[{"name": "axios", "version": ">=0.21"}], dependencies=[{"name": "axios", "version": ">=0.21"}],
file_size=512000, file_size=512000,
checksum="def456", checksum="def456",
created_by="test_user" created_by="test_user",
) )
self.created_ids['sdk'].append(sdk_js.id) self.created_ids["sdk"].append(sdk_js.id)
self.log(f"Created SDK: {sdk_js.name} ({sdk_js.id})") self.log(f"Created SDK: {sdk_js.name} ({sdk_js.id})")
except Exception as e: except Exception as e:
@@ -185,8 +184,8 @@ class TestDeveloperEcosystem:
def test_sdk_get(self): def test_sdk_get(self):
"""测试获取 SDK 详情""" """测试获取 SDK 详情"""
try: try:
if self.created_ids['sdk']: if self.created_ids["sdk"]:
sdk = self.manager.get_sdk_release(self.created_ids['sdk'][0]) sdk = self.manager.get_sdk_release(self.created_ids["sdk"][0])
if sdk: if sdk:
self.log(f"Retrieved SDK: {sdk.name}") self.log(f"Retrieved SDK: {sdk.name}")
else: else:
@@ -197,10 +196,9 @@ class TestDeveloperEcosystem:
def test_sdk_update(self): def test_sdk_update(self):
"""测试更新 SDK""" """测试更新 SDK"""
try: try:
if self.created_ids['sdk']: if self.created_ids["sdk"]:
sdk = self.manager.update_sdk_release( sdk = self.manager.update_sdk_release(
self.created_ids['sdk'][0], self.created_ids["sdk"][0], description="Updated description"
description="Updated description"
) )
if sdk: if sdk:
self.log(f"Updated SDK: {sdk.name}") self.log(f"Updated SDK: {sdk.name}")
@@ -210,8 +208,8 @@ class TestDeveloperEcosystem:
def test_sdk_publish(self): def test_sdk_publish(self):
"""测试发布 SDK""" """测试发布 SDK"""
try: try:
if self.created_ids['sdk']: if self.created_ids["sdk"]:
sdk = self.manager.publish_sdk_release(self.created_ids['sdk'][0]) sdk = self.manager.publish_sdk_release(self.created_ids["sdk"][0])
if sdk: if sdk:
self.log(f"Published SDK: {sdk.name} (status: {sdk.status.value})") self.log(f"Published SDK: {sdk.name} (status: {sdk.status.value})")
except Exception as e: except Exception as e:
@@ -220,15 +218,15 @@ class TestDeveloperEcosystem:
def test_sdk_version_add(self): def test_sdk_version_add(self):
"""测试添加 SDK 版本""" """测试添加 SDK 版本"""
try: try:
if self.created_ids['sdk']: if self.created_ids["sdk"]:
version = self.manager.add_sdk_version( version = self.manager.add_sdk_version(
sdk_id=self.created_ids['sdk'][0], sdk_id=self.created_ids["sdk"][0],
version="1.1.0", version="1.1.0",
is_lts=True, is_lts=True,
release_notes="Bug fixes and improvements", release_notes="Bug fixes and improvements",
download_url="https://pypi.org/insightflow/1.1.0", download_url="https://pypi.org/insightflow/1.1.0",
checksum="xyz789", checksum="xyz789",
file_size=1100000 file_size=1100000,
) )
self.log(f"Added SDK version: {version.version}") self.log(f"Added SDK version: {version.version}")
except Exception as e: except Exception as e:
@@ -254,9 +252,9 @@ class TestDeveloperEcosystem:
version="1.0.0", version="1.0.0",
min_platform_version="2.0.0", min_platform_version="2.0.0",
file_size=5242880, file_size=5242880,
checksum="tpl123" checksum="tpl123",
) )
self.created_ids['template'].append(template.id) self.created_ids["template"].append(template.id)
self.log(f"Created template: {template.name} ({template.id})") self.log(f"Created template: {template.name} ({template.id})")
# Create free template # Create free template
@@ -269,9 +267,9 @@ class TestDeveloperEcosystem:
author_id="dev_002", author_id="dev_002",
author_name="InsightFlow Team", author_name="InsightFlow Team",
price=0.0, price=0.0,
currency="CNY" currency="CNY",
) )
self.created_ids['template'].append(template_free.id) self.created_ids["template"].append(template_free.id)
self.log(f"Created free template: {template_free.name}") self.log(f"Created free template: {template_free.name}")
except Exception as e: except Exception as e:
@@ -297,8 +295,8 @@ class TestDeveloperEcosystem:
def test_template_get(self): def test_template_get(self):
"""测试获取模板详情""" """测试获取模板详情"""
try: try:
if self.created_ids['template']: if self.created_ids["template"]:
template = self.manager.get_template(self.created_ids['template'][0]) template = self.manager.get_template(self.created_ids["template"][0])
if template: if template:
self.log(f"Retrieved template: {template.name}") self.log(f"Retrieved template: {template.name}")
except Exception as e: except Exception as e:
@@ -307,10 +305,9 @@ class TestDeveloperEcosystem:
def test_template_approve(self): def test_template_approve(self):
"""测试审核通过模板""" """测试审核通过模板"""
try: try:
if self.created_ids['template']: if self.created_ids["template"]:
template = self.manager.approve_template( template = self.manager.approve_template(
self.created_ids['template'][0], self.created_ids["template"][0], reviewed_by="admin_001"
reviewed_by="admin_001"
) )
if template: if template:
self.log(f"Approved template: {template.name}") self.log(f"Approved template: {template.name}")
@@ -320,8 +317,8 @@ class TestDeveloperEcosystem:
def test_template_publish(self): def test_template_publish(self):
"""测试发布模板""" """测试发布模板"""
try: try:
if self.created_ids['template']: if self.created_ids["template"]:
template = self.manager.publish_template(self.created_ids['template'][0]) template = self.manager.publish_template(self.created_ids["template"][0])
if template: if template:
self.log(f"Published template: {template.name}") self.log(f"Published template: {template.name}")
except Exception as e: except Exception as e:
@@ -330,14 +327,14 @@ class TestDeveloperEcosystem:
def test_template_review(self): def test_template_review(self):
"""测试添加模板评价""" """测试添加模板评价"""
try: try:
if self.created_ids['template']: if self.created_ids["template"]:
review = self.manager.add_template_review( review = self.manager.add_template_review(
template_id=self.created_ids['template'][0], template_id=self.created_ids["template"][0],
user_id="user_001", user_id="user_001",
user_name="Test User", user_name="Test User",
rating=5, rating=5,
comment="Great template! Very accurate for medical entities.", comment="Great template! Very accurate for medical entities.",
is_verified_purchase=True is_verified_purchase=True,
) )
self.log(f"Added template review: {review.rating} stars") self.log(f"Added template review: {review.rating} stars")
except Exception as e: except Exception as e:
@@ -366,9 +363,9 @@ class TestDeveloperEcosystem:
version="1.0.0", version="1.0.0",
min_platform_version="2.0.0", min_platform_version="2.0.0",
file_size=1048576, file_size=1048576,
checksum="plg123" checksum="plg123",
) )
self.created_ids['plugin'].append(plugin.id) self.created_ids["plugin"].append(plugin.id)
self.log(f"Created plugin: {plugin.name} ({plugin.id})") self.log(f"Created plugin: {plugin.name} ({plugin.id})")
# Create free plugin # Create free plugin
@@ -381,9 +378,9 @@ class TestDeveloperEcosystem:
author_name="Data Team", author_name="Data Team",
price=0.0, price=0.0,
currency="CNY", currency="CNY",
pricing_model="free" pricing_model="free",
) )
self.created_ids['plugin'].append(plugin_free.id) self.created_ids["plugin"].append(plugin_free.id)
self.log(f"Created free plugin: {plugin_free.name}") self.log(f"Created free plugin: {plugin_free.name}")
except Exception as e: except Exception as e:
@@ -405,8 +402,8 @@ class TestDeveloperEcosystem:
def test_plugin_get(self): def test_plugin_get(self):
"""测试获取插件详情""" """测试获取插件详情"""
try: try:
if self.created_ids['plugin']: if self.created_ids["plugin"]:
plugin = self.manager.get_plugin(self.created_ids['plugin'][0]) plugin = self.manager.get_plugin(self.created_ids["plugin"][0])
if plugin: if plugin:
self.log(f"Retrieved plugin: {plugin.name}") self.log(f"Retrieved plugin: {plugin.name}")
except Exception as e: except Exception as e:
@@ -415,12 +412,12 @@ class TestDeveloperEcosystem:
def test_plugin_review(self): def test_plugin_review(self):
"""测试审核插件""" """测试审核插件"""
try: try:
if self.created_ids['plugin']: if self.created_ids["plugin"]:
plugin = self.manager.review_plugin( plugin = self.manager.review_plugin(
self.created_ids['plugin'][0], self.created_ids["plugin"][0],
reviewed_by="admin_001", reviewed_by="admin_001",
status=PluginStatus.APPROVED, status=PluginStatus.APPROVED,
notes="Code review passed" notes="Code review passed",
) )
if plugin: if plugin:
self.log(f"Reviewed plugin: {plugin.name} ({plugin.status.value})") self.log(f"Reviewed plugin: {plugin.name} ({plugin.status.value})")
@@ -430,8 +427,8 @@ class TestDeveloperEcosystem:
def test_plugin_publish(self): def test_plugin_publish(self):
"""测试发布插件""" """测试发布插件"""
try: try:
if self.created_ids['plugin']: if self.created_ids["plugin"]:
plugin = self.manager.publish_plugin(self.created_ids['plugin'][0]) plugin = self.manager.publish_plugin(self.created_ids["plugin"][0])
if plugin: if plugin:
self.log(f"Published plugin: {plugin.name}") self.log(f"Published plugin: {plugin.name}")
except Exception as e: except Exception as e:
@@ -440,14 +437,14 @@ class TestDeveloperEcosystem:
def test_plugin_review_add(self): def test_plugin_review_add(self):
"""测试添加插件评价""" """测试添加插件评价"""
try: try:
if self.created_ids['plugin']: if self.created_ids["plugin"]:
review = self.manager.add_plugin_review( review = self.manager.add_plugin_review(
plugin_id=self.created_ids['plugin'][0], plugin_id=self.created_ids["plugin"][0],
user_id="user_002", user_id="user_002",
user_name="Plugin User", user_name="Plugin User",
rating=4, rating=4,
comment="Works great with Feishu!", comment="Works great with Feishu!",
is_verified_purchase=True is_verified_purchase=True,
) )
self.log(f"Added plugin review: {review.rating} stars") self.log(f"Added plugin review: {review.rating} stars")
except Exception as e: except Exception as e:
@@ -466,9 +463,9 @@ class TestDeveloperEcosystem:
bio="专注于医疗AI和自然语言处理", bio="专注于医疗AI和自然语言处理",
website="https://zhangsan.dev", website="https://zhangsan.dev",
github_url="https://github.com/zhangsan", github_url="https://github.com/zhangsan",
avatar_url="https://cdn.example.com/avatars/zhangsan.png" avatar_url="https://cdn.example.com/avatars/zhangsan.png",
) )
self.created_ids['developer'].append(profile.id) self.created_ids["developer"].append(profile.id)
self.log(f"Created developer profile: {profile.display_name} ({profile.id})") self.log(f"Created developer profile: {profile.display_name} ({profile.id})")
# Create another developer # Create another developer
@@ -476,9 +473,9 @@ class TestDeveloperEcosystem:
user_id=f"user_dev_{unique_id}_002", user_id=f"user_dev_{unique_id}_002",
display_name="李四", display_name="李四",
email=f"lisi_{unique_id}@example.com", email=f"lisi_{unique_id}@example.com",
bio="全栈开发者,热爱开源" bio="全栈开发者,热爱开源",
) )
self.created_ids['developer'].append(profile2.id) self.created_ids["developer"].append(profile2.id)
self.log(f"Created developer profile: {profile2.display_name}") self.log(f"Created developer profile: {profile2.display_name}")
except Exception as e: except Exception as e:
@@ -487,8 +484,8 @@ class TestDeveloperEcosystem:
def test_developer_profile_get(self): def test_developer_profile_get(self):
"""测试获取开发者档案""" """测试获取开发者档案"""
try: try:
if self.created_ids['developer']: if self.created_ids["developer"]:
profile = self.manager.get_developer_profile(self.created_ids['developer'][0]) profile = self.manager.get_developer_profile(self.created_ids["developer"][0])
if profile: if profile:
self.log(f"Retrieved developer profile: {profile.display_name}") self.log(f"Retrieved developer profile: {profile.display_name}")
except Exception as e: except Exception as e:
@@ -497,10 +494,9 @@ class TestDeveloperEcosystem:
def test_developer_verify(self): def test_developer_verify(self):
"""测试验证开发者""" """测试验证开发者"""
try: try:
if self.created_ids['developer']: if self.created_ids["developer"]:
profile = self.manager.verify_developer( profile = self.manager.verify_developer(
self.created_ids['developer'][0], self.created_ids["developer"][0], DeveloperStatus.VERIFIED
DeveloperStatus.VERIFIED
) )
if profile: if profile:
self.log(f"Verified developer: {profile.display_name} ({profile.status.value})") self.log(f"Verified developer: {profile.display_name} ({profile.status.value})")
@@ -510,10 +506,12 @@ class TestDeveloperEcosystem:
def test_developer_stats_update(self): def test_developer_stats_update(self):
"""测试更新开发者统计""" """测试更新开发者统计"""
try: try:
if self.created_ids['developer']: if self.created_ids["developer"]:
self.manager.update_developer_stats(self.created_ids['developer'][0]) self.manager.update_developer_stats(self.created_ids["developer"][0])
profile = self.manager.get_developer_profile(self.created_ids['developer'][0]) profile = self.manager.get_developer_profile(self.created_ids["developer"][0])
self.log(f"Updated developer stats: {profile.plugin_count} plugins, {profile.template_count} templates") self.log(
f"Updated developer stats: {profile.plugin_count} plugins, {profile.template_count} templates"
)
except Exception as e: except Exception as e:
self.log(f"Failed to update developer stats: {str(e)}", success=False) self.log(f"Failed to update developer stats: {str(e)}", success=False)
@@ -535,9 +533,9 @@ print(f"Created project: {project.id}")
tags=["python", "quickstart", "projects"], tags=["python", "quickstart", "projects"],
author_id="dev_001", author_id="dev_001",
author_name="InsightFlow Team", author_name="InsightFlow Team",
api_endpoints=["/api/v1/projects"] api_endpoints=["/api/v1/projects"],
) )
self.created_ids['code_example'].append(example.id) self.created_ids["code_example"].append(example.id)
self.log(f"Created code example: {example.title}") self.log(f"Created code example: {example.title}")
# Create JavaScript example # Create JavaScript example
@@ -558,9 +556,9 @@ console.log('Upload complete:', result.id);
explanation="使用 JavaScript SDK 上传文件到 InsightFlow", explanation="使用 JavaScript SDK 上传文件到 InsightFlow",
tags=["javascript", "upload", "audio"], tags=["javascript", "upload", "audio"],
author_id="dev_002", author_id="dev_002",
author_name="JS Team" author_name="JS Team",
) )
self.created_ids['code_example'].append(example_js.id) self.created_ids["code_example"].append(example_js.id)
self.log(f"Created code example: {example_js.title}") self.log(f"Created code example: {example_js.title}")
except Exception as e: except Exception as e:
@@ -582,10 +580,12 @@ console.log('Upload complete:', result.id);
def test_code_example_get(self): def test_code_example_get(self):
"""测试获取代码示例详情""" """测试获取代码示例详情"""
try: try:
if self.created_ids['code_example']: if self.created_ids["code_example"]:
example = self.manager.get_code_example(self.created_ids['code_example'][0]) example = self.manager.get_code_example(self.created_ids["code_example"][0])
if example: if example:
self.log(f"Retrieved code example: {example.title} (views: {example.view_count})") self.log(
f"Retrieved code example: {example.title} (views: {example.view_count})"
)
except Exception as e: except Exception as e:
self.log(f"Failed to get code example: {str(e)}", success=False) self.log(f"Failed to get code example: {str(e)}", success=False)
@@ -602,9 +602,9 @@ console.log('Upload complete:', result.id);
support_url="https://support.insightflow.io", support_url="https://support.insightflow.io",
github_url="https://github.com/insightflow", github_url="https://github.com/insightflow",
discord_url="https://discord.gg/insightflow", discord_url="https://discord.gg/insightflow",
api_base_url="https://api.insightflow.io/v1" api_base_url="https://api.insightflow.io/v1",
) )
self.created_ids['portal_config'].append(config.id) self.created_ids["portal_config"].append(config.id)
self.log(f"Created portal config: {config.name}") self.log(f"Created portal config: {config.name}")
except Exception as e: except Exception as e:
@@ -613,8 +613,8 @@ console.log('Upload complete:', result.id);
def test_portal_config_get(self): def test_portal_config_get(self):
"""测试获取开发者门户配置""" """测试获取开发者门户配置"""
try: try:
if self.created_ids['portal_config']: if self.created_ids["portal_config"]:
config = self.manager.get_portal_config(self.created_ids['portal_config'][0]) config = self.manager.get_portal_config(self.created_ids["portal_config"][0])
if config: if config:
self.log(f"Retrieved portal config: {config.name}") self.log(f"Retrieved portal config: {config.name}")
@@ -629,16 +629,16 @@ console.log('Upload complete:', result.id);
def test_revenue_record(self): def test_revenue_record(self):
"""测试记录开发者收益""" """测试记录开发者收益"""
try: try:
if self.created_ids['developer'] and self.created_ids['plugin']: if self.created_ids["developer"] and self.created_ids["plugin"]:
revenue = self.manager.record_revenue( revenue = self.manager.record_revenue(
developer_id=self.created_ids['developer'][0], developer_id=self.created_ids["developer"][0],
item_type="plugin", item_type="plugin",
item_id=self.created_ids['plugin'][0], item_id=self.created_ids["plugin"][0],
item_name="飞书机器人集成插件", item_name="飞书机器人集成插件",
sale_amount=49.0, sale_amount=49.0,
currency="CNY", currency="CNY",
buyer_id="user_buyer_001", buyer_id="user_buyer_001",
transaction_id="txn_123456" transaction_id="txn_123456",
) )
self.log(f"Recorded revenue: {revenue.sale_amount} {revenue.currency}") self.log(f"Recorded revenue: {revenue.sale_amount} {revenue.currency}")
self.log(f" - Platform fee: {revenue.platform_fee}") self.log(f" - Platform fee: {revenue.platform_fee}")
@@ -649,8 +649,10 @@ console.log('Upload complete:', result.id);
def test_revenue_summary(self): def test_revenue_summary(self):
"""测试获取开发者收益汇总""" """测试获取开发者收益汇总"""
try: try:
if self.created_ids['developer']: if self.created_ids["developer"]:
summary = self.manager.get_developer_revenue_summary(self.created_ids['developer'][0]) summary = self.manager.get_developer_revenue_summary(
self.created_ids["developer"][0]
)
self.log("Revenue summary for developer:") self.log("Revenue summary for developer:")
self.log(f" - Total sales: {summary['total_sales']}") self.log(f" - Total sales: {summary['total_sales']}")
self.log(f" - Total fees: {summary['total_fees']}") self.log(f" - Total fees: {summary['total_fees']}")
@@ -666,7 +668,7 @@ console.log('Upload complete:', result.id);
print("=" * 60) print("=" * 60)
total = len(self.test_results) total = len(self.test_results)
passed = sum(1 for r in self.test_results if r['success']) passed = sum(1 for r in self.test_results if r["success"])
failed = total - passed failed = total - passed
print(f"Total tests: {total}") print(f"Total tests: {total}")
@@ -676,7 +678,7 @@ console.log('Upload complete:', result.id);
if failed > 0: if failed > 0:
print("\nFailed tests:") print("\nFailed tests:")
for r in self.test_results: for r in self.test_results:
if not r['success']: if not r["success"]:
print(f" - {r['message']}") print(f" - {r['message']}")
print("\nCreated resources:") print("\nCreated resources:")
@@ -686,10 +688,12 @@ console.log('Upload complete:', result.id);
print("=" * 60) print("=" * 60)
def main(): def main():
"""主函数""" """主函数"""
test = TestDeveloperEcosystem() test = TestDeveloperEcosystem()
test.run_all_tests() test.run_all_tests()
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@@ -30,6 +30,7 @@ backend_dir = os.path.dirname(os.path.abspath(__file__))
if backend_dir not in sys.path: if backend_dir not in sys.path:
sys.path.insert(0, backend_dir) sys.path.insert(0, backend_dir)
class TestOpsManager: class TestOpsManager:
"""测试运维与监控管理器""" """测试运维与监控管理器"""
@@ -92,7 +93,7 @@ class TestOpsManager:
channels=[], channels=[],
labels={"service": "api", "team": "platform"}, labels={"service": "api", "team": "platform"},
annotations={"summary": "CPU 使用率过高", "runbook": "https://wiki/runbooks/cpu"}, annotations={"summary": "CPU 使用率过高", "runbook": "https://wiki/runbooks/cpu"},
created_by="test_user" created_by="test_user",
) )
self.log(f"Created alert rule: {rule1.name} (ID: {rule1.id})") self.log(f"Created alert rule: {rule1.name} (ID: {rule1.id})")
@@ -111,7 +112,7 @@ class TestOpsManager:
channels=[], channels=[],
labels={"service": "database"}, labels={"service": "database"},
annotations={}, annotations={},
created_by="test_user" created_by="test_user",
) )
self.log(f"Created anomaly alert rule: {rule2.name} (ID: {rule2.id})") self.log(f"Created anomaly alert rule: {rule2.name} (ID: {rule2.id})")
@@ -128,9 +129,7 @@ class TestOpsManager:
# 更新告警规则 # 更新告警规则
updated_rule = self.manager.update_alert_rule( updated_rule = self.manager.update_alert_rule(
rule1.id, rule1.id, threshold=85.0, description="更新后的描述"
threshold=85.0,
description="更新后的描述"
) )
assert updated_rule.threshold == 85.0 assert updated_rule.threshold == 85.0
self.log(f"Updated alert rule threshold to {updated_rule.threshold}") self.log(f"Updated alert rule threshold to {updated_rule.threshold}")
@@ -155,9 +154,9 @@ class TestOpsManager:
channel_type=AlertChannelType.FEISHU, channel_type=AlertChannelType.FEISHU,
config={ config={
"webhook_url": "https://open.feishu.cn/open-apis/bot/v2/hook/test", "webhook_url": "https://open.feishu.cn/open-apis/bot/v2/hook/test",
"secret": "test_secret" "secret": "test_secret",
}, },
severity_filter=["p0", "p1"] severity_filter=["p0", "p1"],
) )
self.log(f"Created Feishu channel: {channel1.name} (ID: {channel1.id})") self.log(f"Created Feishu channel: {channel1.name} (ID: {channel1.id})")
@@ -168,9 +167,9 @@ class TestOpsManager:
channel_type=AlertChannelType.DINGTALK, channel_type=AlertChannelType.DINGTALK,
config={ config={
"webhook_url": "https://oapi.dingtalk.com/robot/send?access_token=test", "webhook_url": "https://oapi.dingtalk.com/robot/send?access_token=test",
"secret": "test_secret" "secret": "test_secret",
}, },
severity_filter=["p0", "p1", "p2"] severity_filter=["p0", "p1", "p2"],
) )
self.log(f"Created DingTalk channel: {channel2.name} (ID: {channel2.id})") self.log(f"Created DingTalk channel: {channel2.name} (ID: {channel2.id})")
@@ -179,10 +178,8 @@ class TestOpsManager:
tenant_id=self.tenant_id, tenant_id=self.tenant_id,
name="Slack 告警", name="Slack 告警",
channel_type=AlertChannelType.SLACK, channel_type=AlertChannelType.SLACK,
config={ config={"webhook_url": "https://hooks.slack.com/services/test"},
"webhook_url": "https://hooks.slack.com/services/test" severity_filter=["p0", "p1", "p2", "p3"],
},
severity_filter=["p0", "p1", "p2", "p3"]
) )
self.log(f"Created Slack channel: {channel3.name} (ID: {channel3.id})") self.log(f"Created Slack channel: {channel3.name} (ID: {channel3.id})")
@@ -228,7 +225,7 @@ class TestOpsManager:
channels=[], channels=[],
labels={}, labels={},
annotations={}, annotations={},
created_by="test_user" created_by="test_user",
) )
# 记录资源指标 # 记录资源指标
@@ -240,12 +237,13 @@ class TestOpsManager:
metric_name="test_metric", metric_name="test_metric",
metric_value=110.0 + i, metric_value=110.0 + i,
unit="percent", unit="percent",
metadata={"region": "cn-north-1"} metadata={"region": "cn-north-1"},
) )
self.log("Recorded 10 resource metrics") self.log("Recorded 10 resource metrics")
# 手动创建告警 # 手动创建告警
from ops_manager import Alert from ops_manager import Alert
alert_id = f"test_alert_{datetime.now().strftime('%Y%m%d%H%M%S')}" alert_id = f"test_alert_{datetime.now().strftime('%Y%m%d%H%M%S')}"
now = datetime.now().isoformat() now = datetime.now().isoformat()
@@ -267,20 +265,35 @@ class TestOpsManager:
acknowledged_by=None, acknowledged_by=None,
acknowledged_at=None, acknowledged_at=None,
notification_sent={}, notification_sent={},
suppression_count=0 suppression_count=0,
) )
with self.manager._get_db() as conn: with self.manager._get_db() as conn:
conn.execute(""" conn.execute(
"""
INSERT INTO alerts INSERT INTO alerts
(id, rule_id, tenant_id, severity, status, title, description, (id, rule_id, tenant_id, severity, status, title, description,
metric, value, threshold, labels, annotations, started_at, notification_sent, suppression_count) metric, value, threshold, labels, annotations, started_at, notification_sent, suppression_count)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""", (alert.id, alert.rule_id, alert.tenant_id, alert.severity.value, """,
alert.status.value, alert.title, alert.description, (
alert.metric, alert.value, alert.threshold, alert.id,
json.dumps(alert.labels), json.dumps(alert.annotations), alert.rule_id,
alert.started_at, json.dumps(alert.notification_sent), alert.suppression_count)) alert.tenant_id,
alert.severity.value,
alert.status.value,
alert.title,
alert.description,
alert.metric,
alert.value,
alert.threshold,
json.dumps(alert.labels),
json.dumps(alert.annotations),
alert.started_at,
json.dumps(alert.notification_sent),
alert.suppression_count,
),
)
conn.commit() conn.commit()
self.log(f"Created test alert: {alert.id}") self.log(f"Created test alert: {alert.id}")
@@ -325,12 +338,23 @@ class TestOpsManager:
for i in range(30): for i in range(30):
timestamp = (base_time + timedelta(days=i)).isoformat() timestamp = (base_time + timedelta(days=i)).isoformat()
with self.manager._get_db() as conn: with self.manager._get_db() as conn:
conn.execute(""" conn.execute(
"""
INSERT INTO resource_metrics INSERT INTO resource_metrics
(id, tenant_id, resource_type, resource_id, metric_name, metric_value, unit, timestamp) (id, tenant_id, resource_type, resource_id, metric_name, metric_value, unit, timestamp)
VALUES (?, ?, ?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
""", (f"cm_{i}", self.tenant_id, ResourceType.CPU.value, "server-001", """,
"cpu_usage_percent", 50.0 + random.random() * 30, "percent", timestamp)) (
f"cm_{i}",
self.tenant_id,
ResourceType.CPU.value,
"server-001",
"cpu_usage_percent",
50.0 + random.random() * 30,
"percent",
timestamp,
),
)
conn.commit() conn.commit()
self.log("Recorded 30 days of historical metrics") self.log("Recorded 30 days of historical metrics")
@@ -342,7 +366,7 @@ class TestOpsManager:
resource_type=ResourceType.CPU, resource_type=ResourceType.CPU,
current_capacity=100.0, current_capacity=100.0,
prediction_date=prediction_date, prediction_date=prediction_date,
confidence=0.85 confidence=0.85,
) )
self.log(f"Created capacity plan: {plan.id}") self.log(f"Created capacity plan: {plan.id}")
@@ -382,7 +406,7 @@ class TestOpsManager:
scale_down_threshold=0.3, scale_down_threshold=0.3,
scale_up_step=2, scale_up_step=2,
scale_down_step=1, scale_down_step=1,
cooldown_period=300 cooldown_period=300,
) )
self.log(f"Created auto scaling policy: {policy.name} (ID: {policy.id})") self.log(f"Created auto scaling policy: {policy.name} (ID: {policy.id})")
@@ -397,9 +421,7 @@ class TestOpsManager:
# 模拟扩缩容评估 # 模拟扩缩容评估
event = self.manager.evaluate_scaling_policy( event = self.manager.evaluate_scaling_policy(
policy_id=policy.id, policy_id=policy.id, current_instances=3, current_utilization=0.85
current_instances=3,
current_utilization=0.85
) )
if event: if event:
@@ -416,7 +438,9 @@ class TestOpsManager:
# 清理 # 清理
with self.manager._get_db() as conn: with self.manager._get_db() as conn:
conn.execute("DELETE FROM scaling_events WHERE tenant_id = ?", (self.tenant_id,)) conn.execute("DELETE FROM scaling_events WHERE tenant_id = ?", (self.tenant_id,))
conn.execute("DELETE FROM auto_scaling_policies WHERE tenant_id = ?", (self.tenant_id,)) conn.execute(
"DELETE FROM auto_scaling_policies WHERE tenant_id = ?", (self.tenant_id,)
)
conn.commit() conn.commit()
self.log("Cleaned up auto scaling test data") self.log("Cleaned up auto scaling test data")
@@ -435,13 +459,10 @@ class TestOpsManager:
target_type="service", target_type="service",
target_id="api-service", target_id="api-service",
check_type="http", check_type="http",
check_config={ check_config={"url": "https://api.insightflow.io/health", "expected_status": 200},
"url": "https://api.insightflow.io/health",
"expected_status": 200
},
interval=60, interval=60,
timeout=10, timeout=10,
retry_count=3 retry_count=3,
) )
self.log(f"Created HTTP health check: {check1.name} (ID: {check1.id})") self.log(f"Created HTTP health check: {check1.name} (ID: {check1.id})")
@@ -452,13 +473,10 @@ class TestOpsManager:
target_type="database", target_type="database",
target_id="postgres-001", target_id="postgres-001",
check_type="tcp", check_type="tcp",
check_config={ check_config={"host": "db.insightflow.io", "port": 5432},
"host": "db.insightflow.io",
"port": 5432
},
interval=30, interval=30,
timeout=5, timeout=5,
retry_count=2 retry_count=2,
) )
self.log(f"Created TCP health check: {check2.name} (ID: {check2.id})") self.log(f"Created TCP health check: {check2.name} (ID: {check2.id})")
@@ -498,7 +516,7 @@ class TestOpsManager:
failover_trigger="health_check_failed", failover_trigger="health_check_failed",
auto_failover=False, auto_failover=False,
failover_timeout=300, failover_timeout=300,
health_check_id=None health_check_id=None,
) )
self.log(f"Created failover config: {config.name} (ID: {config.id})") self.log(f"Created failover config: {config.name} (ID: {config.id})")
@@ -512,8 +530,7 @@ class TestOpsManager:
# 发起故障转移 # 发起故障转移
event = self.manager.initiate_failover( event = self.manager.initiate_failover(
config_id=config.id, config_id=config.id, reason="Primary region health check failed"
reason="Primary region health check failed"
) )
if event: if event:
@@ -557,7 +574,7 @@ class TestOpsManager:
retention_days=30, retention_days=30,
encryption_enabled=True, encryption_enabled=True,
compression_enabled=True, compression_enabled=True,
storage_location="s3://insightflow-backups/" storage_location="s3://insightflow-backups/",
) )
self.log(f"Created backup job: {job.name} (ID: {job.id})") self.log(f"Created backup job: {job.name} (ID: {job.id})")
@@ -613,7 +630,7 @@ class TestOpsManager:
avg_utilization=0.08, avg_utilization=0.08,
idle_time_percent=0.85, idle_time_percent=0.85,
report_date=report_date, report_date=report_date,
recommendations=["Consider downsizing this resource"] recommendations=["Consider downsizing this resource"],
) )
self.log("Recorded 5 resource utilization records") self.log("Recorded 5 resource utilization records")
@@ -621,9 +638,7 @@ class TestOpsManager:
# 生成成本报告 # 生成成本报告
now = datetime.now() now = datetime.now()
report = self.manager.generate_cost_report( report = self.manager.generate_cost_report(
tenant_id=self.tenant_id, tenant_id=self.tenant_id, year=now.year, month=now.month
year=now.year,
month=now.month
) )
self.log(f"Generated cost report: {report.id}") self.log(f"Generated cost report: {report.id}")
@@ -639,9 +654,10 @@ class TestOpsManager:
idle_list = self.manager.get_idle_resources(self.tenant_id) idle_list = self.manager.get_idle_resources(self.tenant_id)
for resource in idle_list: for resource in idle_list:
self.log( self.log(
f" Idle resource: { f" Idle resource: {resource.resource_name} (est. cost: {
resource.resource_name} (est. cost: { resource.estimated_monthly_cost
resource.estimated_monthly_cost}/month)") }/month)"
)
# 生成成本优化建议 # 生成成本优化建议
suggestions = self.manager.generate_cost_optimization_suggestions(self.tenant_id) suggestions = self.manager.generate_cost_optimization_suggestions(self.tenant_id)
@@ -649,7 +665,9 @@ class TestOpsManager:
for suggestion in suggestions: for suggestion in suggestions:
self.log(f" Suggestion: {suggestion.title}") self.log(f" Suggestion: {suggestion.title}")
self.log(f" Potential savings: {suggestion.potential_savings} {suggestion.currency}") self.log(
f" Potential savings: {suggestion.potential_savings} {suggestion.currency}"
)
self.log(f" Confidence: {suggestion.confidence}") self.log(f" Confidence: {suggestion.confidence}")
self.log(f" Difficulty: {suggestion.difficulty}") self.log(f" Difficulty: {suggestion.difficulty}")
@@ -667,9 +685,14 @@ class TestOpsManager:
# 清理 # 清理
with self.manager._get_db() as conn: with self.manager._get_db() as conn:
conn.execute("DELETE FROM cost_optimization_suggestions WHERE tenant_id = ?", (self.tenant_id,)) conn.execute(
"DELETE FROM cost_optimization_suggestions WHERE tenant_id = ?",
(self.tenant_id,),
)
conn.execute("DELETE FROM idle_resources WHERE tenant_id = ?", (self.tenant_id,)) conn.execute("DELETE FROM idle_resources WHERE tenant_id = ?", (self.tenant_id,))
conn.execute("DELETE FROM resource_utilizations WHERE tenant_id = ?", (self.tenant_id,)) conn.execute(
"DELETE FROM resource_utilizations WHERE tenant_id = ?", (self.tenant_id,)
)
conn.execute("DELETE FROM cost_reports WHERE tenant_id = ?", (self.tenant_id,)) conn.execute("DELETE FROM cost_reports WHERE tenant_id = ?", (self.tenant_id,))
conn.commit() conn.commit()
self.log("Cleaned up cost optimization test data") self.log("Cleaned up cost optimization test data")
@@ -699,10 +722,12 @@ class TestOpsManager:
print("=" * 60) print("=" * 60)
def main(): def main():
"""主函数""" """主函数"""
test = TestOpsManager() test = TestOpsManager()
test.run_all_tests() test.run_all_tests()
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@@ -8,6 +8,7 @@ import time
from datetime import datetime from datetime import datetime
from typing import Any from typing import Any
class TingwuClient: class TingwuClient:
def __init__(self): def __init__(self):
self.access_key = os.getenv("ALI_ACCESS_KEY", "") self.access_key = os.getenv("ALI_ACCESS_KEY", "")
@@ -17,7 +18,9 @@ class TingwuClient:
if not self.access_key or not self.secret_key: if not self.access_key or not self.secret_key:
raise ValueError("ALI_ACCESS_KEY and ALI_SECRET_KEY required") raise ValueError("ALI_ACCESS_KEY and ALI_SECRET_KEY required")
def _sign_request(self, method: str, uri: str, query: str = "", body: str = "") -> dict[str, str]: def _sign_request(
self, method: str, uri: str, query: str = "", body: str = ""
) -> dict[str, str]:
"""阿里云签名 V3""" """阿里云签名 V3"""
timestamp = datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%SZ") timestamp = datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%SZ")
@@ -39,7 +42,9 @@ class TingwuClient:
from alibabacloud_tingwu20230930 import models as tingwu_models from alibabacloud_tingwu20230930 import models as tingwu_models
from alibabacloud_tingwu20230930.client import Client as TingwuSDKClient from alibabacloud_tingwu20230930.client import Client as TingwuSDKClient
config = open_api_models.Config(access_key_id=self.access_key, access_key_secret=self.secret_key) config = open_api_models.Config(
access_key_id=self.access_key, access_key_secret=self.secret_key
)
config.endpoint = "tingwu.cn-beijing.aliyuncs.com" config.endpoint = "tingwu.cn-beijing.aliyuncs.com"
client = TingwuSDKClient(config) client = TingwuSDKClient(config)
@@ -47,7 +52,9 @@ class TingwuClient:
type="offline", type="offline",
input=tingwu_models.Input(source="OSS", file_url=audio_url), input=tingwu_models.Input(source="OSS", file_url=audio_url),
parameters=tingwu_models.Parameters( parameters=tingwu_models.Parameters(
transcription=tingwu_models.Transcription(diarization_enabled=True, sentence_max_length=20) transcription=tingwu_models.Transcription(
diarization_enabled=True, sentence_max_length=20
)
), ),
) )
@@ -65,7 +72,9 @@ class TingwuClient:
print(f"Tingwu API error: {e}") print(f"Tingwu API error: {e}")
return f"mock_task_{int(time.time())}" return f"mock_task_{int(time.time())}"
def get_task_result(self, task_id: str, max_retries: int = 60, interval: int = 5) -> dict[str, Any]: def get_task_result(
self, task_id: str, max_retries: int = 60, interval: int = 5
) -> dict[str, Any]:
"""获取任务结果""" """获取任务结果"""
try: try:
# 导入移到文件顶部会导致循环导入,保持在这里 # 导入移到文件顶部会导致循环导入,保持在这里
@@ -73,7 +82,9 @@ class TingwuClient:
from alibabacloud_tingwu20230930 import models as tingwu_models from alibabacloud_tingwu20230930 import models as tingwu_models
from alibabacloud_tingwu20230930.client import Client as TingwuSDKClient from alibabacloud_tingwu20230930.client import Client as TingwuSDKClient
config = open_api_models.Config(access_key_id=self.access_key, access_key_secret=self.secret_key) config = open_api_models.Config(
access_key_id=self.access_key, access_key_secret=self.secret_key
)
config.endpoint = "tingwu.cn-beijing.aliyuncs.com" config.endpoint = "tingwu.cn-beijing.aliyuncs.com"
client = TingwuSDKClient(config) client = TingwuSDKClient(config)

View File

@@ -15,6 +15,7 @@ import hashlib
import hmac import hmac
import json import json
import logging import logging
import urllib.parse
import uuid import uuid
from collections.abc import Callable from collections.abc import Callable
from dataclasses import dataclass, field from dataclasses import dataclass, field
@@ -32,6 +33,7 @@ from apscheduler.triggers.interval import IntervalTrigger
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class WorkflowStatus(Enum): class WorkflowStatus(Enum):
"""工作流状态""" """工作流状态"""
@@ -40,6 +42,7 @@ class WorkflowStatus(Enum):
ERROR = "error" ERROR = "error"
COMPLETED = "completed" COMPLETED = "completed"
class WorkflowType(Enum): class WorkflowType(Enum):
"""工作流类型""" """工作流类型"""
@@ -49,6 +52,7 @@ class WorkflowType(Enum):
SCHEDULED_REPORT = "scheduled_report" # 定时报告 SCHEDULED_REPORT = "scheduled_report" # 定时报告
CUSTOM = "custom" # 自定义工作流 CUSTOM = "custom" # 自定义工作流
class WebhookType(Enum): class WebhookType(Enum):
"""Webhook 类型""" """Webhook 类型"""
@@ -57,6 +61,7 @@ class WebhookType(Enum):
SLACK = "slack" SLACK = "slack"
CUSTOM = "custom" CUSTOM = "custom"
class TaskStatus(Enum): class TaskStatus(Enum):
"""任务执行状态""" """任务执行状态"""
@@ -66,6 +71,7 @@ class TaskStatus(Enum):
FAILED = "failed" FAILED = "failed"
CANCELLED = "cancelled" CANCELLED = "cancelled"
@dataclass @dataclass
class WorkflowTask: class WorkflowTask:
"""工作流任务定义""" """工作流任务定义"""
@@ -89,6 +95,7 @@ class WorkflowTask:
if not self.updated_at: if not self.updated_at:
self.updated_at = self.created_at self.updated_at = self.created_at
@dataclass @dataclass
class WebhookConfig: class WebhookConfig:
"""Webhook 配置""" """Webhook 配置"""
@@ -113,6 +120,7 @@ class WebhookConfig:
if not self.updated_at: if not self.updated_at:
self.updated_at = self.created_at self.updated_at = self.created_at
@dataclass @dataclass
class Workflow: class Workflow:
"""工作流定义""" """工作流定义"""
@@ -142,6 +150,7 @@ class Workflow:
if not self.updated_at: if not self.updated_at:
self.updated_at = self.created_at self.updated_at = self.created_at
@dataclass @dataclass
class WorkflowLog: class WorkflowLog:
"""工作流执行日志""" """工作流执行日志"""
@@ -162,6 +171,7 @@ class WorkflowLog:
if not self.created_at: if not self.created_at:
self.created_at = datetime.now().isoformat() self.created_at = datetime.now().isoformat()
class WebhookNotifier: class WebhookNotifier:
"""Webhook 通知器 - 支持飞书、钉钉、Slack""" """Webhook 通知器 - 支持飞书、钉钉、Slack"""
@@ -213,11 +223,23 @@ class WebhookNotifier:
"timestamp": timestamp, "timestamp": timestamp,
"sign": sign, "sign": sign,
"msg_type": "post", "msg_type": "post",
"content": {"post": {"zh_cn": {"title": message.get("title", ""), "content": message.get("body", [])}}}, "content": {
"post": {
"zh_cn": {
"title": message.get("title", ""),
"content": message.get("body", []),
}
}
},
} }
else: else:
# 卡片消息 # 卡片消息
payload = {"timestamp": timestamp, "sign": sign, "msg_type": "interactive", "card": message.get("card", {})} payload = {
"timestamp": timestamp,
"sign": sign,
"msg_type": "interactive",
"card": message.get("card", {}),
}
headers = {"Content-Type": "application/json", **config.headers} headers = {"Content-Type": "application/json", **config.headers}
@@ -235,7 +257,9 @@ class WebhookNotifier:
if config.secret: if config.secret:
secret_enc = config.secret.encode("utf-8") secret_enc = config.secret.encode("utf-8")
string_to_sign = f"{timestamp}\n{config.secret}" string_to_sign = f"{timestamp}\n{config.secret}"
hmac_code = hmac.new(secret_enc, string_to_sign.encode("utf-8"), digestmod=hashlib.sha256).digest() hmac_code = hmac.new(
secret_enc, string_to_sign.encode("utf-8"), digestmod=hashlib.sha256
).digest()
sign = urllib.parse.quote_plus(base64.b64encode(hmac_code)) sign = urllib.parse.quote_plus(base64.b64encode(hmac_code))
url = f"{config.url}&timestamp={timestamp}&sign={sign}" url = f"{config.url}&timestamp={timestamp}&sign={sign}"
else: else:
@@ -303,6 +327,7 @@ class WebhookNotifier:
"""关闭 HTTP 客户端""" """关闭 HTTP 客户端"""
await self.http_client.aclose() await self.http_client.aclose()
class WorkflowManager: class WorkflowManager:
"""工作流管理器 - 核心管理类""" """工作流管理器 - 核心管理类"""
@@ -390,7 +415,9 @@ class WorkflowManager:
coalesce=True, coalesce=True,
) )
logger.info(f"Scheduled workflow {workflow.id} ({workflow.name}) with {workflow.schedule_type}") logger.info(
f"Scheduled workflow {workflow.id} ({workflow.name}) with {workflow.schedule_type}"
)
async def _execute_workflow_job(self, workflow_id: str): async def _execute_workflow_job(self, workflow_id: str):
"""调度器调用的工作流执行函数""" """调度器调用的工作流执行函数"""
@@ -463,7 +490,9 @@ class WorkflowManager:
finally: finally:
conn.close() conn.close()
def list_workflows(self, project_id: str = None, status: str = None, workflow_type: str = None) -> list[Workflow]: def list_workflows(
self, project_id: str = None, status: str = None, workflow_type: str = None
) -> list[Workflow]:
"""列出工作流""" """列出工作流"""
conn = self.db.get_conn() conn = self.db.get_conn()
try: try:
@@ -632,7 +661,8 @@ class WorkflowManager:
conn = self.db.get_conn() conn = self.db.get_conn()
try: try:
rows = conn.execute( rows = conn.execute(
"SELECT * FROM workflow_tasks WHERE workflow_id = ? ORDER BY task_order", (workflow_id,) "SELECT * FROM workflow_tasks WHERE workflow_id = ? ORDER BY task_order",
(workflow_id,),
).fetchall() ).fetchall()
return [self._row_to_task(row) for row in rows] return [self._row_to_task(row) for row in rows]
@@ -743,7 +773,9 @@ class WorkflowManager:
"""获取 Webhook 配置""" """获取 Webhook 配置"""
conn = self.db.get_conn() conn = self.db.get_conn()
try: try:
row = conn.execute("SELECT * FROM webhook_configs WHERE id = ?", (webhook_id,)).fetchone() row = conn.execute(
"SELECT * FROM webhook_configs WHERE id = ?", (webhook_id,)
).fetchone()
if not row: if not row:
return None return None
@@ -766,7 +798,15 @@ class WorkflowManager:
"""更新 Webhook 配置""" """更新 Webhook 配置"""
conn = self.db.get_conn() conn = self.db.get_conn()
try: try:
allowed_fields = ["name", "webhook_type", "url", "secret", "headers", "template", "is_active"] allowed_fields = [
"name",
"webhook_type",
"url",
"secret",
"headers",
"template",
"is_active",
]
updates = [] updates = []
values = [] values = []
@@ -915,7 +955,12 @@ class WorkflowManager:
conn.close() conn.close()
def list_logs( def list_logs(
self, workflow_id: str = None, task_id: str = None, status: str = None, limit: int = 100, offset: int = 0 self,
workflow_id: str = None,
task_id: str = None,
status: str = None,
limit: int = 100,
offset: int = 0,
) -> list[WorkflowLog]: ) -> list[WorkflowLog]:
"""列出工作流日志""" """列出工作流日志"""
conn = self.db.get_conn() conn = self.db.get_conn()
@@ -955,7 +1000,8 @@ class WorkflowManager:
# 总执行次数 # 总执行次数
total = conn.execute( total = conn.execute(
"SELECT COUNT(*) FROM workflow_logs WHERE workflow_id = ? AND created_at > ?", (workflow_id, since) "SELECT COUNT(*) FROM workflow_logs WHERE workflow_id = ? AND created_at > ?",
(workflow_id, since),
).fetchone()[0] ).fetchone()[0]
# 成功次数 # 成功次数
@@ -997,7 +1043,9 @@ class WorkflowManager:
"failed": failed, "failed": failed,
"success_rate": round(success / total * 100, 2) if total > 0 else 0, "success_rate": round(success / total * 100, 2) if total > 0 else 0,
"avg_duration_ms": round(avg_duration, 2), "avg_duration_ms": round(avg_duration, 2),
"daily": [{"date": r["date"], "count": r["count"], "success": r["success"]} for r in daily], "daily": [
{"date": r["date"], "count": r["count"], "success": r["success"]} for r in daily
],
} }
finally: finally:
conn.close() conn.close()
@@ -1104,7 +1152,9 @@ class WorkflowManager:
raise raise
async def _execute_tasks_with_deps(self, tasks: list[WorkflowTask], input_data: dict, log_id: str) -> dict: async def _execute_tasks_with_deps(
self, tasks: list[WorkflowTask], input_data: dict, log_id: str
) -> dict:
"""按依赖顺序执行任务""" """按依赖顺序执行任务"""
results = {} results = {}
completed_tasks = set() completed_tasks = set()
@@ -1112,7 +1162,10 @@ class WorkflowManager:
while len(completed_tasks) < len(tasks): while len(completed_tasks) < len(tasks):
# 找到可以执行的任务(依赖已完成) # 找到可以执行的任务(依赖已完成)
ready_tasks = [ ready_tasks = [
t for t in tasks if t.id not in completed_tasks and all(dep in completed_tasks for dep in t.depends_on) t
for t in tasks
if t.id not in completed_tasks
and all(dep in completed_tasks for dep in t.depends_on)
] ]
if not ready_tasks: if not ready_tasks:
@@ -1191,7 +1244,10 @@ class WorkflowManager:
except Exception as e: except Exception as e:
self.update_log( self.update_log(
task_log.id, status=TaskStatus.FAILED.value, end_time=datetime.now().isoformat(), error_message=str(e) task_log.id,
status=TaskStatus.FAILED.value,
end_time=datetime.now().isoformat(),
error_message=str(e),
) )
raise raise
@@ -1222,7 +1278,12 @@ class WorkflowManager:
# 这里调用现有的文件分析逻辑 # 这里调用现有的文件分析逻辑
# 实际实现需要与 main.py 中的 upload_audio 逻辑集成 # 实际实现需要与 main.py 中的 upload_audio 逻辑集成
return {"task": "analyze", "project_id": project_id, "files_processed": len(file_ids), "status": "completed"} return {
"task": "analyze",
"project_id": project_id,
"files_processed": len(file_ids),
"status": "completed",
}
async def _handle_align_task(self, task: WorkflowTask, input_data: dict) -> dict: async def _handle_align_task(self, task: WorkflowTask, input_data: dict) -> dict:
"""处理实体对齐任务""" """处理实体对齐任务"""
@@ -1283,7 +1344,12 @@ class WorkflowManager:
async def _handle_custom_task(self, task: WorkflowTask, input_data: dict) -> dict: async def _handle_custom_task(self, task: WorkflowTask, input_data: dict) -> dict:
"""处理自定义任务""" """处理自定义任务"""
# 自定义任务的具体逻辑由外部处理器实现 # 自定义任务的具体逻辑由外部处理器实现
return {"task": "custom", "task_name": task.name, "config": task.config, "status": "completed"} return {
"task": "custom",
"task_name": task.name,
"config": task.config,
"status": "completed",
}
# ==================== Default Workflow Implementations ==================== # ==================== Default Workflow Implementations ====================
@@ -1340,7 +1406,9 @@ class WorkflowManager:
# ==================== Notification ==================== # ==================== Notification ====================
async def _send_workflow_notification(self, workflow: Workflow, results: dict, success: bool = True): async def _send_workflow_notification(
self, workflow: Workflow, results: dict, success: bool = True
):
"""发送工作流执行通知""" """发送工作流执行通知"""
if not workflow.webhook_ids: if not workflow.webhook_ids:
return return
@@ -1397,7 +1465,7 @@ class WorkflowManager:
**状态:** {status_text} **状态:** {status_text}
**时间:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} **时间:** {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}
**结果:** **结果:**
```json ```json
@@ -1418,7 +1486,11 @@ class WorkflowManager:
"title": f"Workflow Execution: {workflow.name}", "title": f"Workflow Execution: {workflow.name}",
"fields": [ "fields": [
{"title": "Status", "value": status_text, "short": True}, {"title": "Status", "value": status_text, "short": True},
{"title": "Time", "value": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), "short": True}, {
"title": "Time",
"value": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
"short": True,
},
], ],
"footer": "InsightFlow", "footer": "InsightFlow",
"ts": int(datetime.now().timestamp()), "ts": int(datetime.now().timestamp()),
@@ -1426,9 +1498,11 @@ class WorkflowManager:
] ]
} }
# Singleton instance # Singleton instance
_workflow_manager = None _workflow_manager = None
def get_workflow_manager(db_manager=None) -> WorkflowManager: def get_workflow_manager(db_manager=None) -> WorkflowManager:
"""获取 WorkflowManager 单例""" """获取 WorkflowManager 单例"""
global _workflow_manager global _workflow_manager

View File

@@ -9,7 +9,14 @@ from pathlib import Path
class CodeIssue: class CodeIssue:
def __init__(self, file_path: str, line_no: int, issue_type: str, message: str, severity: str = "info"): def __init__(
self,
file_path: str,
line_no: int,
issue_type: str,
message: str,
severity: str = "info",
):
self.file_path = file_path self.file_path = file_path
self.line_no = line_no self.line_no = line_no
self.issue_type = issue_type self.issue_type = issue_type
@@ -74,17 +81,29 @@ class CodeReviewer:
# 9. 检查敏感信息 # 9. 检查敏感信息
self._check_sensitive_info(content, lines, rel_path) self._check_sensitive_info(content, lines, rel_path)
def _check_bare_exceptions(self, content: str, lines: list[str], file_path: str) -> None: def _check_bare_exceptions(
self, content: str, lines: list[str], file_path: str
) -> None:
"""检查裸异常捕获""" """检查裸异常捕获"""
for i, line in enumerate(lines, 1): for i, line in enumerate(lines, 1):
if re.search(r"except\s*:\s*$", line.strip()) or re.search(r"except\s+Exception\s*:\s*$", line.strip()): if re.search(r"except\s*:\s*$", line.strip()) or re.search(
r"except\s+Exception\s*:\s*$", line.strip()
):
# 跳过有注释说明的情况 # 跳过有注释说明的情况
if "# noqa" in line or "# intentional" in line.lower(): if "# noqa" in line or "# intentional" in line.lower():
continue continue
issue = CodeIssue(file_path, i, "bare_exception", "裸异常捕获,应该使用具体异常类型", "warning") issue = CodeIssue(
file_path,
i,
"bare_exception",
"裸异常捕获,应该使用具体异常类型",
"warning",
)
self.issues.append(issue) self.issues.append(issue)
def _check_duplicate_imports(self, content: str, lines: list[str], file_path: str) -> None: def _check_duplicate_imports(
self, content: str, lines: list[str], file_path: str
) -> None:
"""检查重复导入""" """检查重复导入"""
imports = {} imports = {}
for i, line in enumerate(lines, 1): for i, line in enumerate(lines, 1):
@@ -96,30 +115,50 @@ class CodeReviewer:
name = name.strip().split()[0] # 处理 'as' 别名 name = name.strip().split()[0] # 处理 'as' 别名
key = f"{module}.{name}" if module else name key = f"{module}.{name}" if module else name
if key in imports: if key in imports:
issue = CodeIssue(file_path, i, "duplicate_import", f"重复导入: {key}", "warning") issue = CodeIssue(
file_path,
i,
"duplicate_import",
f"重复导入: {key}",
"warning",
)
self.issues.append(issue) self.issues.append(issue)
imports[key] = i imports[key] = i
def _check_pep8_issues(self, content: str, lines: list[str], file_path: str) -> None: def _check_pep8_issues(
self, content: str, lines: list[str], file_path: str
) -> None:
"""检查 PEP8 问题""" """检查 PEP8 问题"""
for i, line in enumerate(lines, 1): for i, line in enumerate(lines, 1):
# 行长度超过 120 # 行长度超过 120
if len(line) > 120: if len(line) > 120:
issue = CodeIssue(file_path, i, "line_too_long", f"行长度 {len(line)} 超过 120 字符", "info") issue = CodeIssue(
file_path,
i,
"line_too_long",
f"行长度 {len(line)} 超过 120 字符",
"info",
)
self.issues.append(issue) self.issues.append(issue)
# 行尾空格 # 行尾空格
if line.rstrip() != line: if line.rstrip() != line:
issue = CodeIssue(file_path, i, "trailing_whitespace", "行尾有空格", "info") issue = CodeIssue(
file_path, i, "trailing_whitespace", "行尾有空格", "info"
)
self.issues.append(issue) self.issues.append(issue)
# 多余的空行 # 多余的空行
if i > 1 and line.strip() == "" and lines[i - 2].strip() == "": if i > 1 and line.strip() == "" and lines[i - 2].strip() == "":
if i < len(lines) and lines[i].strip() == "": if i < len(lines) and lines[i].strip() == "":
issue = CodeIssue(file_path, i, "extra_blank_line", "多余的空行", "info") issue = CodeIssue(
file_path, i, "extra_blank_line", "多余的空行", "info"
)
self.issues.append(issue) self.issues.append(issue)
def _check_unused_imports(self, content: str, lines: list[str], file_path: str) -> None: def _check_unused_imports(
self, content: str, lines: list[str], file_path: str
) -> None:
"""检查未使用的导入""" """检查未使用的导入"""
try: try:
tree = ast.parse(content) tree = ast.parse(content)
@@ -147,10 +186,14 @@ class CodeReviewer:
# 排除一些常见例外 # 排除一些常见例外
if name in ["annotations", "TYPE_CHECKING"]: if name in ["annotations", "TYPE_CHECKING"]:
continue continue
issue = CodeIssue(file_path, lineno, "unused_import", f"未使用的导入: {name}", "info") issue = CodeIssue(
file_path, lineno, "unused_import", f"未使用的导入: {name}", "info"
)
self.issues.append(issue) self.issues.append(issue)
def _check_string_formatting(self, content: str, lines: list[str], file_path: str) -> None: def _check_string_formatting(
self, content: str, lines: list[str], file_path: str
) -> None:
"""检查混合字符串格式化""" """检查混合字符串格式化"""
has_fstring = False has_fstring = False
has_percent = False has_percent = False
@@ -165,10 +208,18 @@ class CodeReviewer:
has_format = True has_format = True
if has_fstring and (has_percent or has_format): if has_fstring and (has_percent or has_format):
issue = CodeIssue(file_path, 0, "mixed_formatting", "文件混合使用多种字符串格式化方式,建议统一为 f-string", "info") issue = CodeIssue(
file_path,
0,
"mixed_formatting",
"文件混合使用多种字符串格式化方式,建议统一为 f-string",
"info",
)
self.issues.append(issue) self.issues.append(issue)
def _check_magic_numbers(self, content: str, lines: list[str], file_path: str) -> None: def _check_magic_numbers(
self, content: str, lines: list[str], file_path: str
) -> None:
"""检查魔法数字""" """检查魔法数字"""
# 常见的魔法数字模式 # 常见的魔法数字模式
magic_patterns = [ magic_patterns = [
@@ -190,36 +241,88 @@ class CodeReviewer:
match = re.search(r"(\d{3,})", code_part) match = re.search(r"(\d{3,})", code_part)
if match: if match:
num = int(match.group(1)) num = int(match.group(1))
if num in [200, 404, 500, 401, 403, 429, 1000, 1024, 2048, 4096, 8080, 3000, 8000]: if num in [
200,
404,
500,
401,
403,
429,
1000,
1024,
2048,
4096,
8080,
3000,
8000,
]:
continue continue
issue = CodeIssue(file_path, i, "magic_number", f"{msg}: {num}", "info") issue = CodeIssue(
file_path, i, "magic_number", f"{msg}: {num}", "info"
)
self.issues.append(issue) self.issues.append(issue)
def _check_sql_injection(self, content: str, lines: list[str], file_path: str) -> None: def _check_sql_injection(
self, content: str, lines: list[str], file_path: str
) -> None:
"""检查 SQL 注入风险""" """检查 SQL 注入风险"""
for i, line in enumerate(lines, 1): for i, line in enumerate(lines, 1):
# 检查字符串拼接的 SQL # 检查字符串拼接的 SQL
if re.search(r'execute\s*\(\s*["\'].*%s', line) or re.search(r'execute\s*\(\s*f["\']', line): if re.search(r'execute\s*\(\s*["\'].*%s', line) or re.search(
r'execute\s*\(\s*f["\']', line
):
if "?" not in line and "%s" in line: if "?" not in line and "%s" in line:
issue = CodeIssue(file_path, i, "sql_injection_risk", "可能的 SQL 注入风险 - 需要人工确认", "error") issue = CodeIssue(
file_path,
i,
"sql_injection_risk",
"可能的 SQL 注入风险 - 需要人工确认",
"error",
)
self.manual_review_issues.append(issue) self.manual_review_issues.append(issue)
def _check_cors_config(self, content: str, lines: list[str], file_path: str) -> None: def _check_cors_config(
self, content: str, lines: list[str], file_path: str
) -> None:
"""检查 CORS 配置""" """检查 CORS 配置"""
for i, line in enumerate(lines, 1): for i, line in enumerate(lines, 1):
if "allow_origins" in line and '["*"]' in line: if "allow_origins" in line and '["*"]' in line:
issue = CodeIssue(file_path, i, "cors_wildcard", "CORS 允许所有来源 - 需要人工确认", "warning") issue = CodeIssue(
file_path,
i,
"cors_wildcard",
"CORS 允许所有来源 - 需要人工确认",
"warning",
)
self.manual_review_issues.append(issue) self.manual_review_issues.append(issue)
def _check_sensitive_info(self, content: str, lines: list[str], file_path: str) -> None: def _check_sensitive_info(
self, content: str, lines: list[str], file_path: str
) -> None:
"""检查敏感信息""" """检查敏感信息"""
for i, line in enumerate(lines, 1): for i, line in enumerate(lines, 1):
# 检查硬编码密钥 # 检查硬编码密钥
if re.search(r'(password|secret|key|token)\s*=\s*["\'][^"\']+["\']', line, re.IGNORECASE): if re.search(
if "os.getenv" not in line and "environ" not in line and "getenv" not in line: r'(password|secret|key|token)\s*=\s*["\'][^"\']+["\']',
line,
re.IGNORECASE,
):
if (
"os.getenv" not in line
and "environ" not in line
and "getenv" not in line
):
# 排除一些常见假阳性 # 排除一些常见假阳性
if not re.search(r'["\']\*+["\']', line) and not re.search(r'["\']<[^"\']*>["\']', line): if not re.search(r'["\']\*+["\']', line) and not re.search(
issue = CodeIssue(file_path, i, "hardcoded_secret", "可能的硬编码敏感信息 - 需要人工确认", "error") r'["\']<[^"\']*>["\']', line
):
issue = CodeIssue(
file_path,
i,
"hardcoded_secret",
"可能的硬编码敏感信息 - 需要人工确认",
"error",
)
self.manual_review_issues.append(issue) self.manual_review_issues.append(issue)
def auto_fix(self) -> None: def auto_fix(self) -> None:
@@ -289,7 +392,9 @@ class CodeReviewer:
if self.fixed_issues: if self.fixed_issues:
report.append(f"共修复 {len(self.fixed_issues)} 个问题:\n") report.append(f"共修复 {len(self.fixed_issues)} 个问题:\n")
for issue in self.fixed_issues: for issue in self.fixed_issues:
report.append(f"- ✅ {issue.file_path}:{issue.line_no} - {issue.issue_type}: {issue.message}") report.append(
f"- ✅ {issue.file_path}:{issue.line_no} - {issue.issue_type}: {issue.message}"
)
else: else:
report.append("") report.append("")
@@ -297,7 +402,9 @@ class CodeReviewer:
if self.manual_review_issues: if self.manual_review_issues:
report.append(f"共发现 {len(self.manual_review_issues)} 个问题:\n") report.append(f"共发现 {len(self.manual_review_issues)} 个问题:\n")
for issue in self.manual_review_issues: for issue in self.manual_review_issues:
report.append(f"- ⚠️ {issue.file_path}:{issue.line_no} - {issue.issue_type}: {issue.message}") report.append(
f"- ⚠️ {issue.file_path}:{issue.line_no} - {issue.issue_type}: {issue.message}"
)
else: else:
report.append("") report.append("")
@@ -305,7 +412,9 @@ class CodeReviewer:
if self.issues: if self.issues:
report.append(f"共发现 {len(self.issues)} 个问题:\n") report.append(f"共发现 {len(self.issues)} 个问题:\n")
for issue in self.issues: for issue in self.issues:
report.append(f"- 📝 {issue.file_path}:{issue.line_no} - {issue.issue_type}: {issue.message}") report.append(
f"- 📝 {issue.file_path}:{issue.line_no} - {issue.issue_type}: {issue.message}"
)
else: else:
report.append("") report.append("")