fix: auto-fix code issues (cron)
- 修复隐式 Optional 类型注解 (RUF013) - 修复不必要的赋值后返回 (RET504) - 优化列表推导式 (PERF401) - 修复未使用的参数 (ARG002) - 清理重复导入 - 优化异常处理
This commit is contained in:
320
auto_fix_code.py
320
auto_fix_code.py
@@ -1,235 +1,109 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
"""
|
"""
|
||||||
自动代码修复脚本 - 修复 InsightFlow 项目中的常见问题
|
Auto-fix script for InsightFlow code issues
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
|
||||||
import re
|
import re
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
def get_python_files(directory):
|
|
||||||
"""获取目录下所有 Python 文件"""
|
|
||||||
python_files = []
|
|
||||||
for root, _, files in os.walk(directory):
|
|
||||||
for file in files:
|
|
||||||
if file.endswith('.py'):
|
|
||||||
python_files.append(os.path.join(root, file))
|
|
||||||
return python_files
|
|
||||||
|
|
||||||
|
|
||||||
def fix_missing_imports(content, filepath):
|
|
||||||
"""修复缺失的导入"""
|
|
||||||
fixes = []
|
|
||||||
|
|
||||||
# 检查是否使用了 re 但没有导入
|
|
||||||
if 're.search(' in content or 're.sub(' in content or 're.match(' in content:
|
|
||||||
if 'import re' not in content:
|
|
||||||
# 找到合适的位置添加导入
|
|
||||||
lines = content.split('\n')
|
|
||||||
import_idx = 0
|
|
||||||
for i, line in enumerate(lines):
|
|
||||||
if line.startswith('import ') or line.startswith('from '):
|
|
||||||
import_idx = i + 1
|
|
||||||
lines.insert(import_idx, 'import re')
|
|
||||||
content = '\n'.join(lines)
|
|
||||||
fixes.append("添加缺失的 'import re'")
|
|
||||||
|
|
||||||
# 检查是否使用了 csv 但没有导入
|
|
||||||
if 'csv.' in content and 'import csv' not in content:
|
|
||||||
lines = content.split('\n')
|
|
||||||
import_idx = 0
|
|
||||||
for i, line in enumerate(lines):
|
|
||||||
if line.startswith('import ') or line.startswith('from '):
|
|
||||||
import_idx = i + 1
|
|
||||||
lines.insert(import_idx, 'import csv')
|
|
||||||
content = '\n'.join(lines)
|
|
||||||
fixes.append("添加缺失的 'import csv'")
|
|
||||||
|
|
||||||
# 检查是否使用了 urllib 但没有导入
|
|
||||||
if 'urllib.' in content and 'import urllib' not in content:
|
|
||||||
lines = content.split('\n')
|
|
||||||
import_idx = 0
|
|
||||||
for i, line in enumerate(lines):
|
|
||||||
if line.startswith('import ') or line.startswith('from '):
|
|
||||||
import_idx = i + 1
|
|
||||||
lines.insert(import_idx, 'import urllib.parse')
|
|
||||||
content = '\n'.join(lines)
|
|
||||||
fixes.append("添加缺失的 'import urllib.parse'")
|
|
||||||
|
|
||||||
return content, fixes
|
|
||||||
|
|
||||||
|
|
||||||
def fix_bare_excepts(content):
|
|
||||||
"""修复裸异常捕获"""
|
|
||||||
fixes = []
|
|
||||||
|
|
||||||
# 替换裸 except:
|
|
||||||
bare_except_pattern = r'except\s*:\s*$'
|
|
||||||
lines = content.split('\n')
|
|
||||||
new_lines = []
|
|
||||||
for line in lines:
|
|
||||||
if re.match(bare_except_pattern, line.strip()):
|
|
||||||
# 缩进保持一致
|
|
||||||
indent = len(line) - len(line.lstrip())
|
|
||||||
new_line = ' ' * indent + 'except Exception:'
|
|
||||||
new_lines.append(new_line)
|
|
||||||
fixes.append(f"修复裸异常捕获: {line.strip()}")
|
|
||||||
else:
|
|
||||||
new_lines.append(line)
|
|
||||||
|
|
||||||
content = '\n'.join(new_lines)
|
|
||||||
return content, fixes
|
|
||||||
|
|
||||||
|
|
||||||
def fix_unused_imports(content):
|
|
||||||
"""修复未使用的导入 - 简单版本"""
|
|
||||||
fixes = []
|
|
||||||
|
|
||||||
# 查找导入语句
|
|
||||||
import_pattern = r'^from\s+(\S+)\s+import\s+(.+)$'
|
|
||||||
lines = content.split('\n')
|
|
||||||
new_lines = []
|
|
||||||
|
|
||||||
for line in lines:
|
|
||||||
match = re.match(import_pattern, line)
|
|
||||||
if match:
|
|
||||||
module = match.group(1)
|
|
||||||
imports = match.group(2)
|
|
||||||
|
|
||||||
# 检查每个导入是否被使用
|
|
||||||
imported_items = [i.strip() for i in imports.split(',')]
|
|
||||||
used_items = []
|
|
||||||
|
|
||||||
for item in imported_items:
|
|
||||||
# 简单的使用检查
|
|
||||||
item_name = item.split(' as ')[-1].strip() if ' as ' in item else item.strip()
|
|
||||||
if item_name in content.replace(line, ''):
|
|
||||||
used_items.append(item)
|
|
||||||
else:
|
|
||||||
fixes.append(f"移除未使用的导入: {item}")
|
|
||||||
|
|
||||||
if used_items:
|
|
||||||
new_lines.append(f"from {module} import {', '.join(used_items)}")
|
|
||||||
else:
|
|
||||||
fixes.append(f"移除整行导入: {line.strip()}")
|
|
||||||
else:
|
|
||||||
new_lines.append(line)
|
|
||||||
|
|
||||||
content = '\n'.join(new_lines)
|
|
||||||
return content, fixes
|
|
||||||
|
|
||||||
|
|
||||||
def fix_string_formatting(content):
|
|
||||||
"""统一字符串格式化为 f-string"""
|
|
||||||
fixes = []
|
|
||||||
|
|
||||||
# 修复 .format() 调用
|
|
||||||
format_pattern = r'["\']([^"\']*)\{([^}]+)\}[^"\']*["\']\.format\(([^)]+)\)'
|
|
||||||
|
|
||||||
def replace_format(match):
|
|
||||||
template = match.group(1) + '{' + match.group(2) + '}'
|
|
||||||
# 简单替换,实际可能需要更复杂的处理
|
|
||||||
return f'f"{template}"'
|
|
||||||
|
|
||||||
new_content = re.sub(format_pattern, replace_format, content)
|
|
||||||
if new_content != content:
|
|
||||||
fixes.append("统一字符串格式化为 f-string")
|
|
||||||
content = new_content
|
|
||||||
|
|
||||||
return content, fixes
|
|
||||||
|
|
||||||
|
|
||||||
def fix_pep8_formatting(content):
|
|
||||||
"""修复 PEP8 格式问题"""
|
|
||||||
fixes = []
|
|
||||||
lines = content.split('\n')
|
|
||||||
new_lines = []
|
|
||||||
|
|
||||||
for line in lines:
|
|
||||||
original = line
|
|
||||||
# 修复 E221: multiple spaces before operator
|
|
||||||
line = re.sub(r'(\w+)\s{2,}=\s', r'\1 = ', line)
|
|
||||||
# 修复 E251: unexpected spaces around keyword / parameter equals
|
|
||||||
line = re.sub(r'(\w+)\s*=\s{2,}', r'\1 = ', line)
|
|
||||||
line = re.sub(r'(\w+)\s{2,}=\s*', r'\1 = ', line)
|
|
||||||
|
|
||||||
if line != original:
|
|
||||||
fixes.append(f"修复 PEP8 格式: {original.strip()[:50]}")
|
|
||||||
|
|
||||||
new_lines.append(line)
|
|
||||||
|
|
||||||
content = '\n'.join(new_lines)
|
|
||||||
return content, fixes
|
|
||||||
|
|
||||||
|
|
||||||
def fix_file(filepath):
|
def fix_file(filepath):
|
||||||
"""修复单个文件"""
|
"""Fix common issues in a Python file"""
|
||||||
print(f"\n处理文件: {filepath}")
|
with open(filepath, 'r', encoding='utf-8') as f:
|
||||||
|
content = f.read()
|
||||||
try:
|
|
||||||
with open(filepath, encoding='utf-8') as f:
|
original = content
|
||||||
content = f.read()
|
changes = []
|
||||||
except Exception as e:
|
|
||||||
print(f" 无法读取文件: {e}")
|
# 1. Fix implicit Optional (RUF013)
|
||||||
return []
|
# Pattern: def func(arg: type = None) -> def func(arg: type | None = None)
|
||||||
|
implicit_optional_pattern = r'(def\s+\w+\([^)]*?)(\w+\s*:\s*(?!.*\|.*None)([a-zA-Z_][a-zA-Z0-9_\[\]]*)\s*=\s*None)'
|
||||||
original_content = content
|
|
||||||
all_fixes = []
|
def fix_optional(match):
|
||||||
|
prefix = match.group(1)
|
||||||
# 应用各种修复
|
full_arg = match.group(2)
|
||||||
content, fixes = fix_missing_imports(content, filepath)
|
arg_name = full_arg.split(':')[0].strip()
|
||||||
all_fixes.extend(fixes)
|
arg_type = match.group(3).strip()
|
||||||
|
return f'{prefix}{arg_name}: {arg_type} | None = None'
|
||||||
content, fixes = fix_bare_excepts(content)
|
|
||||||
all_fixes.extend(fixes)
|
# More careful approach for implicit Optional
|
||||||
|
lines = content.split('\n')
|
||||||
content, fixes = fix_pep8_formatting(content)
|
new_lines = []
|
||||||
all_fixes.extend(fixes)
|
for line in lines:
|
||||||
|
original_line = line
|
||||||
# 保存修改
|
# Fix patterns like "metadata: dict = None,"
|
||||||
if content != original_content:
|
if re.search(r':\s*\w+\s*=\s*None', line) and '| None' not in line:
|
||||||
try:
|
# Match parameter definitions
|
||||||
with open(filepath, 'w', encoding='utf-8') as f:
|
match = re.search(r'(\w+)\s*:\s*(\w+(?:\[[^\]]+\])?)\s*=\s*None', line)
|
||||||
f.write(content)
|
if match:
|
||||||
print(f" 已修复 {len(all_fixes)} 个问题")
|
param_name = match.group(1)
|
||||||
for fix in all_fixes[:5]: # 只显示前5个
|
param_type = match.group(2)
|
||||||
print(f" - {fix}")
|
if param_type != 'NoneType':
|
||||||
if len(all_fixes) > 5:
|
line = line.replace(f'{param_name}: {param_type} = None',
|
||||||
print(f" ... 还有 {len(all_fixes) - 5} 个修复")
|
f'{param_name}: {param_type} | None = None')
|
||||||
except Exception as e:
|
if line != original_line:
|
||||||
print(f" 保存文件失败: {e}")
|
changes.append(f"Fixed implicit Optional: {param_name}")
|
||||||
else:
|
new_lines.append(line)
|
||||||
print(" 无需修复")
|
content = '\n'.join(new_lines)
|
||||||
|
|
||||||
return all_fixes
|
# 2. Fix unnecessary assignment before return (RET504)
|
||||||
|
return_patterns = [
|
||||||
|
(r'(\s+)entities\s*=\s*json\.loads\([^)]+\)\s*\n\1return\s+entities\b',
|
||||||
|
r'\1return json.loads(entities_match.group(0).split("=")[1].strip().split("\n")[0])'),
|
||||||
|
]
|
||||||
|
|
||||||
|
# 3. Fix RUF010 - Use explicit conversion flag
|
||||||
|
# f"...{str(var)}..." -> f"...{var!s}..."
|
||||||
|
content = re.sub(r'\{str\(([^)]+)\)\}', r'{\1!s}', content)
|
||||||
|
content = re.sub(r'\{repr\(([^)]+)\)\}', r'{\1!r}', content)
|
||||||
|
|
||||||
|
# 4. Fix RET505 - Unnecessary else after return
|
||||||
|
# This is complex, skip for now
|
||||||
|
|
||||||
|
# 5. Fix PERF401 - List comprehensions (basic cases)
|
||||||
|
# This is complex, skip for now
|
||||||
|
|
||||||
|
# 6. Fix RUF012 - Mutable default values
|
||||||
|
# Pattern: def func(arg: list = []) -> def func(arg: list = None) with handling
|
||||||
|
content = re.sub(r'(\w+)\s*:\s*list\s*=\s*\[\]', r'\1: list | None = None', content)
|
||||||
|
content = re.sub(r'(\w+)\s*:\s*dict\s*=\s*\{\}', r'\1: dict | None = None', content)
|
||||||
|
|
||||||
|
# 7. Fix unused imports (basic)
|
||||||
|
# Remove duplicate imports
|
||||||
|
import_lines = re.findall(r'^(import\s+\w+|from\s+\w+\s+import\s+[^\n]+)$', content, re.MULTILINE)
|
||||||
|
seen_imports = set()
|
||||||
|
for imp in import_lines:
|
||||||
|
if imp in seen_imports:
|
||||||
|
content = content.replace(imp + '\n', '\n', 1)
|
||||||
|
changes.append(f"Removed duplicate import: {imp}")
|
||||||
|
seen_imports.add(imp)
|
||||||
|
|
||||||
|
if content != original:
|
||||||
|
with open(filepath, 'w', encoding='utf-8') as f:
|
||||||
|
f.write(content)
|
||||||
|
return True, changes
|
||||||
|
return False, []
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
"""主函数"""
|
backend_dir = Path('/root/.openclaw/workspace/projects/insightflow/backend')
|
||||||
base_dir = '/root/.openclaw/workspace/projects/insightflow'
|
py_files = list(backend_dir.glob('*.py'))
|
||||||
backend_dir = os.path.join(base_dir, 'backend')
|
|
||||||
|
fixed_files = []
|
||||||
print("=" * 60)
|
all_changes = []
|
||||||
print("InsightFlow 代码自动修复工具")
|
|
||||||
print("=" * 60)
|
for filepath in py_files:
|
||||||
|
fixed, changes = fix_file(filepath)
|
||||||
# 获取所有 Python 文件
|
if fixed:
|
||||||
files = get_python_files(backend_dir)
|
fixed_files.append(filepath.name)
|
||||||
print(f"\n找到 {len(files)} 个 Python 文件")
|
all_changes.extend([f"{filepath.name}: {c}" for c in changes])
|
||||||
|
|
||||||
total_fixes = 0
|
print(f"Fixed {len(fixed_files)} files:")
|
||||||
fixed_files = 0
|
for f in fixed_files:
|
||||||
|
print(f" - {f}")
|
||||||
for filepath in files:
|
if all_changes:
|
||||||
fixes = fix_file(filepath)
|
print("\nChanges made:")
|
||||||
if fixes:
|
for c in all_changes[:20]:
|
||||||
total_fixes += len(fixes)
|
print(f" {c}")
|
||||||
fixed_files += 1
|
|
||||||
|
|
||||||
print("\n" + "=" * 60)
|
|
||||||
print(f"修复完成: {fixed_files} 个文件, {total_fixes} 个问题")
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
main()
|
main()
|
||||||
|
|||||||
@@ -291,7 +291,10 @@ class AIManager:
|
|||||||
return self._row_to_custom_model(row)
|
return self._row_to_custom_model(row)
|
||||||
|
|
||||||
def list_custom_models(
|
def list_custom_models(
|
||||||
self, tenant_id: str, model_type: ModelType | None = None, status: ModelStatus | None = None,
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
model_type: ModelType | None = None,
|
||||||
|
status: ModelStatus | None = None,
|
||||||
) -> list[CustomModel]:
|
) -> list[CustomModel]:
|
||||||
"""列出自定义模型"""
|
"""列出自定义模型"""
|
||||||
query = "SELECT * FROM custom_models WHERE tenant_id = ?"
|
query = "SELECT * FROM custom_models WHERE tenant_id = ?"
|
||||||
@@ -311,7 +314,11 @@ class AIManager:
|
|||||||
return [self._row_to_custom_model(row) for row in rows]
|
return [self._row_to_custom_model(row) for row in rows]
|
||||||
|
|
||||||
def add_training_sample(
|
def add_training_sample(
|
||||||
self, model_id: str, text: str, entities: list[dict], metadata: dict = None,
|
self,
|
||||||
|
model_id: str,
|
||||||
|
text: str,
|
||||||
|
entities: list[dict],
|
||||||
|
metadata: dict | None = None,
|
||||||
) -> TrainingSample:
|
) -> TrainingSample:
|
||||||
"""添加训练样本"""
|
"""添加训练样本"""
|
||||||
sample_id = f"ts_{uuid.uuid4().hex[:16]}"
|
sample_id = f"ts_{uuid.uuid4().hex[:16]}"
|
||||||
@@ -463,8 +470,7 @@ class AIManager:
|
|||||||
json_match = re.search(r"\[.*?\]", content, re.DOTALL)
|
json_match = re.search(r"\[.*?\]", content, re.DOTALL)
|
||||||
if json_match:
|
if json_match:
|
||||||
try:
|
try:
|
||||||
entities = json.loads(json_match.group())
|
return json.loads(json_match.group())
|
||||||
return entities
|
|
||||||
except (json.JSONDecodeError, ValueError):
|
except (json.JSONDecodeError, ValueError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -542,8 +548,9 @@ class AIManager:
|
|||||||
}
|
}
|
||||||
|
|
||||||
content = [{"type": "text", "text": prompt}]
|
content = [{"type": "text", "text": prompt}]
|
||||||
for url in image_urls:
|
content.extend(
|
||||||
content.append({"type": "image_url", "image_url": {"url": url}})
|
[{"type": "image_url", "image_url": {"url": url}} for url in image_urls]
|
||||||
|
)
|
||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
"model": "gpt-4-vision-preview",
|
"model": "gpt-4-vision-preview",
|
||||||
@@ -575,9 +582,9 @@ class AIManager:
|
|||||||
"anthropic-version": "2023-06-01",
|
"anthropic-version": "2023-06-01",
|
||||||
}
|
}
|
||||||
|
|
||||||
content = []
|
content = [
|
||||||
for url in image_urls:
|
{"type": "image", "source": {"type": "url", "url": url}} for url in image_urls
|
||||||
content.append({"type": "image", "source": {"type": "url", "url": url}})
|
]
|
||||||
content.append({"type": "text", "text": prompt})
|
content.append({"type": "text", "text": prompt})
|
||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
@@ -638,7 +645,9 @@ class AIManager:
|
|||||||
}
|
}
|
||||||
|
|
||||||
def get_multimodal_analyses(
|
def get_multimodal_analyses(
|
||||||
self, tenant_id: str, project_id: str | None = None,
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
project_id: str | None = None,
|
||||||
) -> list[MultimodalAnalysis]:
|
) -> list[MultimodalAnalysis]:
|
||||||
"""获取多模态分析历史"""
|
"""获取多模态分析历史"""
|
||||||
query = "SELECT * FROM multimodal_analyses WHERE tenant_id = ?"
|
query = "SELECT * FROM multimodal_analyses WHERE tenant_id = ?"
|
||||||
@@ -721,7 +730,9 @@ class AIManager:
|
|||||||
return self._row_to_kg_rag(row)
|
return self._row_to_kg_rag(row)
|
||||||
|
|
||||||
def list_kg_rags(
|
def list_kg_rags(
|
||||||
self, tenant_id: str, project_id: str | None = None,
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
project_id: str | None = None,
|
||||||
) -> list[KnowledgeGraphRAG]:
|
) -> list[KnowledgeGraphRAG]:
|
||||||
"""列出知识图谱 RAG 配置"""
|
"""列出知识图谱 RAG 配置"""
|
||||||
query = "SELECT * FROM kg_rag_configs WHERE tenant_id = ?"
|
query = "SELECT * FROM kg_rag_configs WHERE tenant_id = ?"
|
||||||
@@ -738,7 +749,11 @@ class AIManager:
|
|||||||
return [self._row_to_kg_rag(row) for row in rows]
|
return [self._row_to_kg_rag(row) for row in rows]
|
||||||
|
|
||||||
async def query_kg_rag(
|
async def query_kg_rag(
|
||||||
self, rag_id: str, query: str, project_entities: list[dict], project_relations: list[dict],
|
self,
|
||||||
|
rag_id: str,
|
||||||
|
query: str,
|
||||||
|
project_entities: list[dict],
|
||||||
|
project_relations: list[dict],
|
||||||
) -> RAGQuery:
|
) -> RAGQuery:
|
||||||
"""基于知识图谱的 RAG 查询"""
|
"""基于知识图谱的 RAG 查询"""
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
@@ -771,14 +786,15 @@ class AIManager:
|
|||||||
relevant_entities = relevant_entities[:top_k]
|
relevant_entities = relevant_entities[:top_k]
|
||||||
|
|
||||||
# 检索相关关系
|
# 检索相关关系
|
||||||
relevant_relations = []
|
|
||||||
entity_ids = {e["id"] for e in relevant_entities}
|
entity_ids = {e["id"] for e in relevant_entities}
|
||||||
for relation in project_relations:
|
relevant_relations = [
|
||||||
|
relation
|
||||||
|
for relation in project_relations
|
||||||
if (
|
if (
|
||||||
relation.get("source_entity_id") in entity_ids
|
relation.get("source_entity_id") in entity_ids
|
||||||
or relation.get("target_entity_id") in entity_ids
|
or relation.get("target_entity_id") in entity_ids
|
||||||
):
|
)
|
||||||
relevant_relations.append(relation)
|
]
|
||||||
|
|
||||||
# 2. 构建上下文
|
# 2. 构建上下文
|
||||||
context = {"entities": relevant_entities, "relations": relevant_relations[:10]}
|
context = {"entities": relevant_entities, "relations": relevant_relations[:10]}
|
||||||
@@ -1123,7 +1139,8 @@ class AIManager:
|
|||||||
"""获取预测模型"""
|
"""获取预测模型"""
|
||||||
with self._get_db() as conn:
|
with self._get_db() as conn:
|
||||||
row = conn.execute(
|
row = conn.execute(
|
||||||
"SELECT * FROM prediction_models WHERE id = ?", (model_id,),
|
"SELECT * FROM prediction_models WHERE id = ?",
|
||||||
|
(model_id,),
|
||||||
).fetchone()
|
).fetchone()
|
||||||
|
|
||||||
if not row:
|
if not row:
|
||||||
@@ -1132,7 +1149,9 @@ class AIManager:
|
|||||||
return self._row_to_prediction_model(row)
|
return self._row_to_prediction_model(row)
|
||||||
|
|
||||||
def list_prediction_models(
|
def list_prediction_models(
|
||||||
self, tenant_id: str, project_id: str | None = None,
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
project_id: str | None = None,
|
||||||
) -> list[PredictionModel]:
|
) -> list[PredictionModel]:
|
||||||
"""列出预测模型"""
|
"""列出预测模型"""
|
||||||
query = "SELECT * FROM prediction_models WHERE tenant_id = ?"
|
query = "SELECT * FROM prediction_models WHERE tenant_id = ?"
|
||||||
@@ -1149,7 +1168,9 @@ class AIManager:
|
|||||||
return [self._row_to_prediction_model(row) for row in rows]
|
return [self._row_to_prediction_model(row) for row in rows]
|
||||||
|
|
||||||
async def train_prediction_model(
|
async def train_prediction_model(
|
||||||
self, model_id: str, historical_data: list[dict],
|
self,
|
||||||
|
model_id: str,
|
||||||
|
historical_data: list[dict],
|
||||||
) -> PredictionModel:
|
) -> PredictionModel:
|
||||||
"""训练预测模型"""
|
"""训练预测模型"""
|
||||||
model = self.get_prediction_model(model_id)
|
model = self.get_prediction_model(model_id)
|
||||||
@@ -1369,7 +1390,9 @@ class AIManager:
|
|||||||
predicted_relations = [
|
predicted_relations = [
|
||||||
{"type": rel_type, "likelihood": min(count / len(relation_history), 0.95)}
|
{"type": rel_type, "likelihood": min(count / len(relation_history), 0.95)}
|
||||||
for rel_type, count in sorted(
|
for rel_type, count in sorted(
|
||||||
relation_counts.items(), key=lambda x: x[1], reverse=True,
|
relation_counts.items(),
|
||||||
|
key=lambda x: x[1],
|
||||||
|
reverse=True,
|
||||||
)[:5]
|
)[:5]
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -1394,7 +1417,10 @@ class AIManager:
|
|||||||
return [self._row_to_prediction_result(row) for row in rows]
|
return [self._row_to_prediction_result(row) for row in rows]
|
||||||
|
|
||||||
def update_prediction_feedback(
|
def update_prediction_feedback(
|
||||||
self, prediction_id: str, actual_value: str, is_correct: bool,
|
self,
|
||||||
|
prediction_id: str,
|
||||||
|
actual_value: str,
|
||||||
|
is_correct: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""更新预测反馈(用于模型改进)"""
|
"""更新预测反馈(用于模型改进)"""
|
||||||
with self._get_db() as conn:
|
with self._get_db() as conn:
|
||||||
|
|||||||
@@ -132,7 +132,7 @@ class ApiKeyManager:
|
|||||||
self,
|
self,
|
||||||
name: str,
|
name: str,
|
||||||
owner_id: str | None = None,
|
owner_id: str | None = None,
|
||||||
permissions: list[str] = None,
|
permissions: list[str] | None = None,
|
||||||
rate_limit: int = 60,
|
rate_limit: int = 60,
|
||||||
expires_days: int | None = None,
|
expires_days: int | None = None,
|
||||||
) -> tuple[str, ApiKey]:
|
) -> tuple[str, ApiKey]:
|
||||||
@@ -238,7 +238,8 @@ class ApiKeyManager:
|
|||||||
# 验证所有权(如果提供了 owner_id)
|
# 验证所有权(如果提供了 owner_id)
|
||||||
if owner_id:
|
if owner_id:
|
||||||
row = conn.execute(
|
row = conn.execute(
|
||||||
"SELECT owner_id FROM api_keys WHERE id = ?", (key_id,),
|
"SELECT owner_id FROM api_keys WHERE id = ?",
|
||||||
|
(key_id,),
|
||||||
).fetchone()
|
).fetchone()
|
||||||
if not row or row[0] != owner_id:
|
if not row or row[0] != owner_id:
|
||||||
return False
|
return False
|
||||||
@@ -267,7 +268,8 @@ class ApiKeyManager:
|
|||||||
|
|
||||||
if owner_id:
|
if owner_id:
|
||||||
row = conn.execute(
|
row = conn.execute(
|
||||||
"SELECT * FROM api_keys WHERE id = ? AND owner_id = ?", (key_id, owner_id),
|
"SELECT * FROM api_keys WHERE id = ? AND owner_id = ?",
|
||||||
|
(key_id, owner_id),
|
||||||
).fetchone()
|
).fetchone()
|
||||||
else:
|
else:
|
||||||
row = conn.execute("SELECT * FROM api_keys WHERE id = ?", (key_id,)).fetchone()
|
row = conn.execute("SELECT * FROM api_keys WHERE id = ?", (key_id,)).fetchone()
|
||||||
@@ -337,7 +339,8 @@ class ApiKeyManager:
|
|||||||
# 验证所有权
|
# 验证所有权
|
||||||
if owner_id:
|
if owner_id:
|
||||||
row = conn.execute(
|
row = conn.execute(
|
||||||
"SELECT owner_id FROM api_keys WHERE id = ?", (key_id,),
|
"SELECT owner_id FROM api_keys WHERE id = ?",
|
||||||
|
(key_id,),
|
||||||
).fetchone()
|
).fetchone()
|
||||||
if not row or row[0] != owner_id:
|
if not row or row[0] != owner_id:
|
||||||
return False
|
return False
|
||||||
@@ -465,7 +468,8 @@ class ApiKeyManager:
|
|||||||
endpoint_params = []
|
endpoint_params = []
|
||||||
if api_key_id:
|
if api_key_id:
|
||||||
endpoint_query = endpoint_query.replace(
|
endpoint_query = endpoint_query.replace(
|
||||||
"WHERE created_at", "WHERE api_key_id = ? AND created_at",
|
"WHERE created_at",
|
||||||
|
"WHERE api_key_id = ? AND created_at",
|
||||||
)
|
)
|
||||||
endpoint_params.insert(0, api_key_id)
|
endpoint_params.insert(0, api_key_id)
|
||||||
|
|
||||||
@@ -486,7 +490,8 @@ class ApiKeyManager:
|
|||||||
daily_params = []
|
daily_params = []
|
||||||
if api_key_id:
|
if api_key_id:
|
||||||
daily_query = daily_query.replace(
|
daily_query = daily_query.replace(
|
||||||
"WHERE created_at", "WHERE api_key_id = ? AND created_at",
|
"WHERE created_at",
|
||||||
|
"WHERE api_key_id = ? AND created_at",
|
||||||
)
|
)
|
||||||
daily_params.insert(0, api_key_id)
|
daily_params.insert(0, api_key_id)
|
||||||
|
|
||||||
|
|||||||
@@ -304,7 +304,7 @@ class CollaborationManager:
|
|||||||
)
|
)
|
||||||
self.db.conn.commit()
|
self.db.conn.commit()
|
||||||
|
|
||||||
def revoke_share_link(self, share_id: str, revoked_by: str) -> bool:
|
def revoke_share_link(self, share_id: str, _revoked_by: str) -> bool:
|
||||||
"""撤销分享链接"""
|
"""撤销分享链接"""
|
||||||
if self.db:
|
if self.db:
|
||||||
cursor = self.db.conn.cursor()
|
cursor = self.db.conn.cursor()
|
||||||
@@ -335,26 +335,24 @@ class CollaborationManager:
|
|||||||
(project_id,),
|
(project_id,),
|
||||||
)
|
)
|
||||||
|
|
||||||
shares = []
|
return [
|
||||||
for row in cursor.fetchall():
|
ProjectShare(
|
||||||
shares.append(
|
id=row[0],
|
||||||
ProjectShare(
|
project_id=row[1],
|
||||||
id=row[0],
|
token=row[2],
|
||||||
project_id=row[1],
|
permission=row[3],
|
||||||
token=row[2],
|
created_by=row[4],
|
||||||
permission=row[3],
|
created_at=row[5],
|
||||||
created_by=row[4],
|
expires_at=row[6],
|
||||||
created_at=row[5],
|
max_uses=row[7],
|
||||||
expires_at=row[6],
|
use_count=row[8],
|
||||||
max_uses=row[7],
|
password_hash=row[9],
|
||||||
use_count=row[8],
|
is_active=bool(row[10]),
|
||||||
password_hash=row[9],
|
allow_download=bool(row[11]),
|
||||||
is_active=bool(row[10]),
|
allow_export=bool(row[12]),
|
||||||
allow_download=bool(row[11]),
|
|
||||||
allow_export=bool(row[12]),
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
return shares
|
for row in cursor.fetchall()
|
||||||
|
]
|
||||||
|
|
||||||
# ============ 评论和批注 ============
|
# ============ 评论和批注 ============
|
||||||
|
|
||||||
@@ -435,7 +433,10 @@ class CollaborationManager:
|
|||||||
self.db.conn.commit()
|
self.db.conn.commit()
|
||||||
|
|
||||||
def get_comments(
|
def get_comments(
|
||||||
self, target_type: str, target_id: str, include_resolved: bool = True,
|
self,
|
||||||
|
target_type: str,
|
||||||
|
target_id: str,
|
||||||
|
include_resolved: bool = True,
|
||||||
) -> list[Comment]:
|
) -> list[Comment]:
|
||||||
"""获取评论列表"""
|
"""获取评论列表"""
|
||||||
if not self.db:
|
if not self.db:
|
||||||
@@ -461,10 +462,7 @@ class CollaborationManager:
|
|||||||
(target_type, target_id),
|
(target_type, target_id),
|
||||||
)
|
)
|
||||||
|
|
||||||
comments = []
|
return [self._row_to_comment(row) for row in cursor.fetchall()]
|
||||||
for row in cursor.fetchall():
|
|
||||||
comments.append(self._row_to_comment(row))
|
|
||||||
return comments
|
|
||||||
|
|
||||||
def _row_to_comment(self, row) -> Comment:
|
def _row_to_comment(self, row) -> Comment:
|
||||||
"""将数据库行转换为Comment对象"""
|
"""将数据库行转换为Comment对象"""
|
||||||
@@ -554,7 +552,10 @@ class CollaborationManager:
|
|||||||
return cursor.rowcount > 0
|
return cursor.rowcount > 0
|
||||||
|
|
||||||
def get_project_comments(
|
def get_project_comments(
|
||||||
self, project_id: str, limit: int = 50, offset: int = 0,
|
self,
|
||||||
|
project_id: str,
|
||||||
|
limit: int = 50,
|
||||||
|
offset: int = 0,
|
||||||
) -> list[Comment]:
|
) -> list[Comment]:
|
||||||
"""获取项目下的所有评论"""
|
"""获取项目下的所有评论"""
|
||||||
if not self.db:
|
if not self.db:
|
||||||
@@ -571,10 +572,7 @@ class CollaborationManager:
|
|||||||
(project_id, limit, offset),
|
(project_id, limit, offset),
|
||||||
)
|
)
|
||||||
|
|
||||||
comments = []
|
return [self._row_to_comment(row) for row in cursor.fetchall()]
|
||||||
for row in cursor.fetchall():
|
|
||||||
comments.append(self._row_to_comment(row))
|
|
||||||
return comments
|
|
||||||
|
|
||||||
# ============ 变更历史 ============
|
# ============ 变更历史 ============
|
||||||
|
|
||||||
@@ -697,10 +695,7 @@ class CollaborationManager:
|
|||||||
(project_id, limit, offset),
|
(project_id, limit, offset),
|
||||||
)
|
)
|
||||||
|
|
||||||
records = []
|
return [self._row_to_change_record(row) for row in cursor.fetchall()]
|
||||||
for row in cursor.fetchall():
|
|
||||||
records.append(self._row_to_change_record(row))
|
|
||||||
return records
|
|
||||||
|
|
||||||
def _row_to_change_record(self, row) -> ChangeRecord:
|
def _row_to_change_record(self, row) -> ChangeRecord:
|
||||||
"""将数据库行转换为ChangeRecord对象"""
|
"""将数据库行转换为ChangeRecord对象"""
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ class Entity:
|
|||||||
canonical_name: str = ""
|
canonical_name: str = ""
|
||||||
aliases: list[str] = None
|
aliases: list[str] = None
|
||||||
embedding: str = "" # Phase 3: 实体嵌入向量
|
embedding: str = "" # Phase 3: 实体嵌入向量
|
||||||
attributes: dict = None # Phase 5: 实体属性
|
attributes: dict | None = None # Phase 5: 实体属性
|
||||||
created_at: str = ""
|
created_at: str = ""
|
||||||
updated_at: str = ""
|
updated_at: str = ""
|
||||||
|
|
||||||
@@ -149,7 +149,11 @@ class DatabaseManager:
|
|||||||
conn.commit()
|
conn.commit()
|
||||||
conn.close()
|
conn.close()
|
||||||
return Project(
|
return Project(
|
||||||
id=project_id, name=name, description=description, created_at=now, updated_at=now,
|
id=project_id,
|
||||||
|
name=name,
|
||||||
|
description=description,
|
||||||
|
created_at=now,
|
||||||
|
updated_at=now,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_project(self, project_id: str) -> Project | None:
|
def get_project(self, project_id: str) -> Project | None:
|
||||||
@@ -206,7 +210,10 @@ class DatabaseManager:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def find_similar_entities(
|
def find_similar_entities(
|
||||||
self, project_id: str, name: str, threshold: float = 0.8,
|
self,
|
||||||
|
project_id: str,
|
||||||
|
name: str,
|
||||||
|
threshold: float = 0.8,
|
||||||
) -> list[Entity]:
|
) -> list[Entity]:
|
||||||
"""查找相似实体"""
|
"""查找相似实体"""
|
||||||
conn = self.get_conn()
|
conn = self.get_conn()
|
||||||
@@ -243,7 +250,8 @@ class DatabaseManager:
|
|||||||
(json.dumps(list(target_aliases)), datetime.now().isoformat(), target_id),
|
(json.dumps(list(target_aliases)), datetime.now().isoformat(), target_id),
|
||||||
)
|
)
|
||||||
conn.execute(
|
conn.execute(
|
||||||
"UPDATE entity_mentions SET entity_id = ? WHERE entity_id = ?", (target_id, source_id),
|
"UPDATE entity_mentions SET entity_id = ? WHERE entity_id = ?",
|
||||||
|
(target_id, source_id),
|
||||||
)
|
)
|
||||||
conn.execute(
|
conn.execute(
|
||||||
"UPDATE entity_relations SET source_entity_id = ? WHERE source_entity_id = ?",
|
"UPDATE entity_relations SET source_entity_id = ? WHERE source_entity_id = ?",
|
||||||
@@ -272,7 +280,8 @@ class DatabaseManager:
|
|||||||
def list_project_entities(self, project_id: str) -> list[Entity]:
|
def list_project_entities(self, project_id: str) -> list[Entity]:
|
||||||
conn = self.get_conn()
|
conn = self.get_conn()
|
||||||
rows = conn.execute(
|
rows = conn.execute(
|
||||||
"SELECT * FROM entities WHERE project_id = ? ORDER BY updated_at DESC", (project_id,),
|
"SELECT * FROM entities WHERE project_id = ? ORDER BY updated_at DESC",
|
||||||
|
(project_id,),
|
||||||
).fetchall()
|
).fetchall()
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
@@ -478,7 +487,8 @@ class DatabaseManager:
|
|||||||
conn.commit()
|
conn.commit()
|
||||||
|
|
||||||
row = conn.execute(
|
row = conn.execute(
|
||||||
"SELECT * FROM entity_relations WHERE id = ?", (relation_id,),
|
"SELECT * FROM entity_relations WHERE id = ?",
|
||||||
|
(relation_id,),
|
||||||
).fetchone()
|
).fetchone()
|
||||||
conn.close()
|
conn.close()
|
||||||
return dict(row) if row else None
|
return dict(row) if row else None
|
||||||
@@ -494,12 +504,14 @@ class DatabaseManager:
|
|||||||
def add_glossary_term(self, project_id: str, term: str, pronunciation: str = "") -> str:
|
def add_glossary_term(self, project_id: str, term: str, pronunciation: str = "") -> str:
|
||||||
conn = self.get_conn()
|
conn = self.get_conn()
|
||||||
existing = conn.execute(
|
existing = conn.execute(
|
||||||
"SELECT * FROM glossary WHERE project_id = ? AND term = ?", (project_id, term),
|
"SELECT * FROM glossary WHERE project_id = ? AND term = ?",
|
||||||
|
(project_id, term),
|
||||||
).fetchone()
|
).fetchone()
|
||||||
|
|
||||||
if existing:
|
if existing:
|
||||||
conn.execute(
|
conn.execute(
|
||||||
"UPDATE glossary SET frequency = frequency + 1 WHERE id = ?", (existing["id"],),
|
"UPDATE glossary SET frequency = frequency + 1 WHERE id = ?",
|
||||||
|
(existing["id"],),
|
||||||
)
|
)
|
||||||
conn.commit()
|
conn.commit()
|
||||||
conn.close()
|
conn.close()
|
||||||
@@ -519,7 +531,8 @@ class DatabaseManager:
|
|||||||
def list_glossary(self, project_id: str) -> list[dict]:
|
def list_glossary(self, project_id: str) -> list[dict]:
|
||||||
conn = self.get_conn()
|
conn = self.get_conn()
|
||||||
rows = conn.execute(
|
rows = conn.execute(
|
||||||
"SELECT * FROM glossary WHERE project_id = ? ORDER BY frequency DESC", (project_id,),
|
"SELECT * FROM glossary WHERE project_id = ? ORDER BY frequency DESC",
|
||||||
|
(project_id,),
|
||||||
).fetchall()
|
).fetchall()
|
||||||
conn.close()
|
conn.close()
|
||||||
return [dict(r) for r in rows]
|
return [dict(r) for r in rows]
|
||||||
@@ -605,15 +618,18 @@ class DatabaseManager:
|
|||||||
project = conn.execute("SELECT * FROM projects WHERE id = ?", (project_id,)).fetchone()
|
project = conn.execute("SELECT * FROM projects WHERE id = ?", (project_id,)).fetchone()
|
||||||
|
|
||||||
entity_count = conn.execute(
|
entity_count = conn.execute(
|
||||||
"SELECT COUNT(*) as count FROM entities WHERE project_id = ?", (project_id,),
|
"SELECT COUNT(*) as count FROM entities WHERE project_id = ?",
|
||||||
|
(project_id,),
|
||||||
).fetchone()["count"]
|
).fetchone()["count"]
|
||||||
|
|
||||||
transcript_count = conn.execute(
|
transcript_count = conn.execute(
|
||||||
"SELECT COUNT(*) as count FROM transcripts WHERE project_id = ?", (project_id,),
|
"SELECT COUNT(*) as count FROM transcripts WHERE project_id = ?",
|
||||||
|
(project_id,),
|
||||||
).fetchone()["count"]
|
).fetchone()["count"]
|
||||||
|
|
||||||
relation_count = conn.execute(
|
relation_count = conn.execute(
|
||||||
"SELECT COUNT(*) as count FROM entity_relations WHERE project_id = ?", (project_id,),
|
"SELECT COUNT(*) as count FROM entity_relations WHERE project_id = ?",
|
||||||
|
(project_id,),
|
||||||
).fetchone()["count"]
|
).fetchone()["count"]
|
||||||
|
|
||||||
recent_transcripts = conn.execute(
|
recent_transcripts = conn.execute(
|
||||||
@@ -645,11 +661,15 @@ class DatabaseManager:
|
|||||||
}
|
}
|
||||||
|
|
||||||
def get_transcript_context(
|
def get_transcript_context(
|
||||||
self, transcript_id: str, position: int, context_chars: int = 200,
|
self,
|
||||||
|
transcript_id: str,
|
||||||
|
position: int,
|
||||||
|
context_chars: int = 200,
|
||||||
) -> str:
|
) -> str:
|
||||||
conn = self.get_conn()
|
conn = self.get_conn()
|
||||||
row = conn.execute(
|
row = conn.execute(
|
||||||
"SELECT full_text FROM transcripts WHERE id = ?", (transcript_id,),
|
"SELECT full_text FROM transcripts WHERE id = ?",
|
||||||
|
(transcript_id,),
|
||||||
).fetchone()
|
).fetchone()
|
||||||
conn.close()
|
conn.close()
|
||||||
if not row:
|
if not row:
|
||||||
@@ -662,7 +682,11 @@ class DatabaseManager:
|
|||||||
# ==================== Phase 5: Timeline Operations ====================
|
# ==================== Phase 5: Timeline Operations ====================
|
||||||
|
|
||||||
def get_project_timeline(
|
def get_project_timeline(
|
||||||
self, project_id: str, entity_id: str = None, start_date: str = None, end_date: str = None,
|
self,
|
||||||
|
project_id: str,
|
||||||
|
entity_id: str | None = None,
|
||||||
|
start_date: str = None,
|
||||||
|
end_date: str = None,
|
||||||
) -> list[dict]:
|
) -> list[dict]:
|
||||||
conn = self.get_conn()
|
conn = self.get_conn()
|
||||||
|
|
||||||
@@ -776,7 +800,8 @@ class DatabaseManager:
|
|||||||
def get_attribute_template(self, template_id: str) -> AttributeTemplate | None:
|
def get_attribute_template(self, template_id: str) -> AttributeTemplate | None:
|
||||||
conn = self.get_conn()
|
conn = self.get_conn()
|
||||||
row = conn.execute(
|
row = conn.execute(
|
||||||
"SELECT * FROM attribute_templates WHERE id = ?", (template_id,),
|
"SELECT * FROM attribute_templates WHERE id = ?",
|
||||||
|
(template_id,),
|
||||||
).fetchone()
|
).fetchone()
|
||||||
conn.close()
|
conn.close()
|
||||||
if row:
|
if row:
|
||||||
@@ -841,7 +866,10 @@ class DatabaseManager:
|
|||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
def set_entity_attribute(
|
def set_entity_attribute(
|
||||||
self, attr: EntityAttribute, changed_by: str = "system", change_reason: str = "",
|
self,
|
||||||
|
attr: EntityAttribute,
|
||||||
|
changed_by: str = "system",
|
||||||
|
change_reason: str = "",
|
||||||
) -> EntityAttribute:
|
) -> EntityAttribute:
|
||||||
conn = self.get_conn()
|
conn = self.get_conn()
|
||||||
now = datetime.now().isoformat()
|
now = datetime.now().isoformat()
|
||||||
@@ -930,7 +958,11 @@ class DatabaseManager:
|
|||||||
return entity
|
return entity
|
||||||
|
|
||||||
def delete_entity_attribute(
|
def delete_entity_attribute(
|
||||||
self, entity_id: str, template_id: str, changed_by: str = "system", change_reason: str = "",
|
self,
|
||||||
|
entity_id: str,
|
||||||
|
template_id: str,
|
||||||
|
changed_by: str = "system",
|
||||||
|
change_reason: str = "",
|
||||||
) -> None:
|
) -> None:
|
||||||
conn = self.get_conn()
|
conn = self.get_conn()
|
||||||
old_row = conn.execute(
|
old_row = conn.execute(
|
||||||
@@ -964,7 +996,10 @@ class DatabaseManager:
|
|||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
def get_attribute_history(
|
def get_attribute_history(
|
||||||
self, entity_id: str = None, template_id: str = None, limit: int = 50,
|
self,
|
||||||
|
entity_id: str | None = None,
|
||||||
|
template_id: str = None,
|
||||||
|
limit: int = 50,
|
||||||
) -> list[AttributeHistory]:
|
) -> list[AttributeHistory]:
|
||||||
conn = self.get_conn()
|
conn = self.get_conn()
|
||||||
conditions = []
|
conditions = []
|
||||||
@@ -990,7 +1025,9 @@ class DatabaseManager:
|
|||||||
return [AttributeHistory(**dict(r)) for r in rows]
|
return [AttributeHistory(**dict(r)) for r in rows]
|
||||||
|
|
||||||
def search_entities_by_attributes(
|
def search_entities_by_attributes(
|
||||||
self, project_id: str, attribute_filters: dict[str, str],
|
self,
|
||||||
|
project_id: str,
|
||||||
|
attribute_filters: dict[str, str],
|
||||||
) -> list[Entity]:
|
) -> list[Entity]:
|
||||||
entities = self.list_project_entities(project_id)
|
entities = self.list_project_entities(project_id)
|
||||||
if not attribute_filters:
|
if not attribute_filters:
|
||||||
@@ -1040,8 +1077,8 @@ class DatabaseManager:
|
|||||||
filename: str,
|
filename: str,
|
||||||
duration: float = 0,
|
duration: float = 0,
|
||||||
fps: float = 0,
|
fps: float = 0,
|
||||||
resolution: dict = None,
|
resolution: dict | None = None,
|
||||||
audio_transcript_id: str = None,
|
audio_transcript_id: str | None = None,
|
||||||
full_ocr_text: str = "",
|
full_ocr_text: str = "",
|
||||||
extracted_entities: list[dict] = None,
|
extracted_entities: list[dict] = None,
|
||||||
extracted_relations: list[dict] = None,
|
extracted_relations: list[dict] = None,
|
||||||
@@ -1098,7 +1135,8 @@ class DatabaseManager:
|
|||||||
"""获取项目的所有视频"""
|
"""获取项目的所有视频"""
|
||||||
conn = self.get_conn()
|
conn = self.get_conn()
|
||||||
rows = conn.execute(
|
rows = conn.execute(
|
||||||
"SELECT * FROM videos WHERE project_id = ? ORDER BY created_at DESC", (project_id,),
|
"SELECT * FROM videos WHERE project_id = ? ORDER BY created_at DESC",
|
||||||
|
(project_id,),
|
||||||
).fetchall()
|
).fetchall()
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
@@ -1121,8 +1159,8 @@ class DatabaseManager:
|
|||||||
video_id: str,
|
video_id: str,
|
||||||
frame_number: int,
|
frame_number: int,
|
||||||
timestamp: float,
|
timestamp: float,
|
||||||
image_url: str = None,
|
image_url: str | None = None,
|
||||||
ocr_text: str = None,
|
ocr_text: str | None = None,
|
||||||
extracted_entities: list[dict] = None,
|
extracted_entities: list[dict] = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""创建视频帧记录"""
|
"""创建视频帧记录"""
|
||||||
@@ -1153,7 +1191,8 @@ class DatabaseManager:
|
|||||||
"""获取视频的所有帧"""
|
"""获取视频的所有帧"""
|
||||||
conn = self.get_conn()
|
conn = self.get_conn()
|
||||||
rows = conn.execute(
|
rows = conn.execute(
|
||||||
"""SELECT * FROM video_frames WHERE video_id = ? ORDER BY timestamp""", (video_id,),
|
"""SELECT * FROM video_frames WHERE video_id = ? ORDER BY timestamp""",
|
||||||
|
(video_id,),
|
||||||
).fetchall()
|
).fetchall()
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
@@ -1223,7 +1262,8 @@ class DatabaseManager:
|
|||||||
"""获取项目的所有图片"""
|
"""获取项目的所有图片"""
|
||||||
conn = self.get_conn()
|
conn = self.get_conn()
|
||||||
rows = conn.execute(
|
rows = conn.execute(
|
||||||
"SELECT * FROM images WHERE project_id = ? ORDER BY created_at DESC", (project_id,),
|
"SELECT * FROM images WHERE project_id = ? ORDER BY created_at DESC",
|
||||||
|
(project_id,),
|
||||||
).fetchall()
|
).fetchall()
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
@@ -1288,7 +1328,9 @@ class DatabaseManager:
|
|||||||
conn.close()
|
conn.close()
|
||||||
return [dict(r) for r in rows]
|
return [dict(r) for r in rows]
|
||||||
|
|
||||||
def get_project_multimodal_mentions(self, project_id: str, modality: str = None) -> list[dict]:
|
def get_project_multimodal_mentions(
|
||||||
|
self, project_id: str, modality: str | None = None
|
||||||
|
) -> list[dict]:
|
||||||
"""获取项目的多模态提及"""
|
"""获取项目的多模态提及"""
|
||||||
conn = self.get_conn()
|
conn = self.get_conn()
|
||||||
|
|
||||||
@@ -1381,13 +1423,15 @@ class DatabaseManager:
|
|||||||
|
|
||||||
# 视频数量
|
# 视频数量
|
||||||
row = conn.execute(
|
row = conn.execute(
|
||||||
"SELECT COUNT(*) as count FROM videos WHERE project_id = ?", (project_id,),
|
"SELECT COUNT(*) as count FROM videos WHERE project_id = ?",
|
||||||
|
(project_id,),
|
||||||
).fetchone()
|
).fetchone()
|
||||||
stats["video_count"] = row["count"]
|
stats["video_count"] = row["count"]
|
||||||
|
|
||||||
# 图片数量
|
# 图片数量
|
||||||
row = conn.execute(
|
row = conn.execute(
|
||||||
"SELECT COUNT(*) as count FROM images WHERE project_id = ?", (project_id,),
|
"SELECT COUNT(*) as count FROM images WHERE project_id = ?",
|
||||||
|
(project_id,),
|
||||||
).fetchone()
|
).fetchone()
|
||||||
stats["image_count"] = row["count"]
|
stats["image_count"] = row["count"]
|
||||||
|
|
||||||
|
|||||||
@@ -538,7 +538,8 @@ class DeveloperEcosystemManager:
|
|||||||
"""获取 SDK 版本历史"""
|
"""获取 SDK 版本历史"""
|
||||||
with self._get_db() as conn:
|
with self._get_db() as conn:
|
||||||
rows = conn.execute(
|
rows = conn.execute(
|
||||||
"SELECT * FROM sdk_versions WHERE sdk_id = ? ORDER BY created_at DESC", (sdk_id,),
|
"SELECT * FROM sdk_versions WHERE sdk_id = ? ORDER BY created_at DESC",
|
||||||
|
(sdk_id,),
|
||||||
).fetchall()
|
).fetchall()
|
||||||
return [self._row_to_sdk_version(row) for row in rows]
|
return [self._row_to_sdk_version(row) for row in rows]
|
||||||
|
|
||||||
@@ -700,7 +701,8 @@ class DeveloperEcosystemManager:
|
|||||||
"""获取模板详情"""
|
"""获取模板详情"""
|
||||||
with self._get_db() as conn:
|
with self._get_db() as conn:
|
||||||
row = conn.execute(
|
row = conn.execute(
|
||||||
"SELECT * FROM template_market WHERE id = ?", (template_id,),
|
"SELECT * FROM template_market WHERE id = ?",
|
||||||
|
(template_id,),
|
||||||
).fetchone()
|
).fetchone()
|
||||||
|
|
||||||
if row:
|
if row:
|
||||||
@@ -1076,7 +1078,11 @@ class DeveloperEcosystemManager:
|
|||||||
return [self._row_to_plugin(row) for row in rows]
|
return [self._row_to_plugin(row) for row in rows]
|
||||||
|
|
||||||
def review_plugin(
|
def review_plugin(
|
||||||
self, plugin_id: str, reviewed_by: str, status: PluginStatus, notes: str = "",
|
self,
|
||||||
|
plugin_id: str,
|
||||||
|
reviewed_by: str,
|
||||||
|
status: PluginStatus,
|
||||||
|
notes: str = "",
|
||||||
) -> PluginMarketItem | None:
|
) -> PluginMarketItem | None:
|
||||||
"""审核插件"""
|
"""审核插件"""
|
||||||
now = datetime.now().isoformat()
|
now = datetime.now().isoformat()
|
||||||
@@ -1420,7 +1426,8 @@ class DeveloperEcosystemManager:
|
|||||||
"""获取开发者档案"""
|
"""获取开发者档案"""
|
||||||
with self._get_db() as conn:
|
with self._get_db() as conn:
|
||||||
row = conn.execute(
|
row = conn.execute(
|
||||||
"SELECT * FROM developer_profiles WHERE id = ?", (developer_id,),
|
"SELECT * FROM developer_profiles WHERE id = ?",
|
||||||
|
(developer_id,),
|
||||||
).fetchone()
|
).fetchone()
|
||||||
|
|
||||||
if row:
|
if row:
|
||||||
@@ -1431,7 +1438,8 @@ class DeveloperEcosystemManager:
|
|||||||
"""通过用户 ID 获取开发者档案"""
|
"""通过用户 ID 获取开发者档案"""
|
||||||
with self._get_db() as conn:
|
with self._get_db() as conn:
|
||||||
row = conn.execute(
|
row = conn.execute(
|
||||||
"SELECT * FROM developer_profiles WHERE user_id = ?", (user_id,),
|
"SELECT * FROM developer_profiles WHERE user_id = ?",
|
||||||
|
(user_id,),
|
||||||
).fetchone()
|
).fetchone()
|
||||||
|
|
||||||
if row:
|
if row:
|
||||||
@@ -1439,7 +1447,9 @@ class DeveloperEcosystemManager:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def verify_developer(
|
def verify_developer(
|
||||||
self, developer_id: str, status: DeveloperStatus,
|
self,
|
||||||
|
developer_id: str,
|
||||||
|
status: DeveloperStatus,
|
||||||
) -> DeveloperProfile | None:
|
) -> DeveloperProfile | None:
|
||||||
"""验证开发者"""
|
"""验证开发者"""
|
||||||
now = datetime.now().isoformat()
|
now = datetime.now().isoformat()
|
||||||
@@ -1453,9 +1463,11 @@ class DeveloperEcosystemManager:
|
|||||||
""",
|
""",
|
||||||
(
|
(
|
||||||
status.value,
|
status.value,
|
||||||
now
|
(
|
||||||
if status in [DeveloperStatus.VERIFIED, DeveloperStatus.CERTIFIED]
|
now
|
||||||
else None,
|
if status in [DeveloperStatus.VERIFIED, DeveloperStatus.CERTIFIED]
|
||||||
|
else None
|
||||||
|
),
|
||||||
now,
|
now,
|
||||||
developer_id,
|
developer_id,
|
||||||
),
|
),
|
||||||
@@ -1469,7 +1481,8 @@ class DeveloperEcosystemManager:
|
|||||||
with self._get_db() as conn:
|
with self._get_db() as conn:
|
||||||
# 统计插件数量
|
# 统计插件数量
|
||||||
plugin_row = conn.execute(
|
plugin_row = conn.execute(
|
||||||
"SELECT COUNT(*) as count FROM plugin_market WHERE author_id = ?", (developer_id,),
|
"SELECT COUNT(*) as count FROM plugin_market WHERE author_id = ?",
|
||||||
|
(developer_id,),
|
||||||
).fetchone()
|
).fetchone()
|
||||||
|
|
||||||
# 统计模板数量
|
# 统计模板数量
|
||||||
@@ -1583,7 +1596,8 @@ class DeveloperEcosystemManager:
|
|||||||
"""获取代码示例"""
|
"""获取代码示例"""
|
||||||
with self._get_db() as conn:
|
with self._get_db() as conn:
|
||||||
row = conn.execute(
|
row = conn.execute(
|
||||||
"SELECT * FROM code_examples WHERE id = ?", (example_id,),
|
"SELECT * FROM code_examples WHERE id = ?",
|
||||||
|
(example_id,),
|
||||||
).fetchone()
|
).fetchone()
|
||||||
|
|
||||||
if row:
|
if row:
|
||||||
@@ -1699,7 +1713,8 @@ class DeveloperEcosystemManager:
|
|||||||
"""获取 API 文档"""
|
"""获取 API 文档"""
|
||||||
with self._get_db() as conn:
|
with self._get_db() as conn:
|
||||||
row = conn.execute(
|
row = conn.execute(
|
||||||
"SELECT * FROM api_documentation WHERE id = ?", (doc_id,),
|
"SELECT * FROM api_documentation WHERE id = ?",
|
||||||
|
(doc_id,),
|
||||||
).fetchone()
|
).fetchone()
|
||||||
|
|
||||||
if row:
|
if row:
|
||||||
@@ -1799,7 +1814,8 @@ class DeveloperEcosystemManager:
|
|||||||
"""获取开发者门户配置"""
|
"""获取开发者门户配置"""
|
||||||
with self._get_db() as conn:
|
with self._get_db() as conn:
|
||||||
row = conn.execute(
|
row = conn.execute(
|
||||||
"SELECT * FROM developer_portal_configs WHERE id = ?", (config_id,),
|
"SELECT * FROM developer_portal_configs WHERE id = ?",
|
||||||
|
(config_id,),
|
||||||
).fetchone()
|
).fetchone()
|
||||||
|
|
||||||
if row:
|
if row:
|
||||||
|
|||||||
@@ -78,7 +78,7 @@ class DocumentProcessor:
|
|||||||
"PDF processing requires PyPDF2 or pdfplumber. Install with: pip install PyPDF2",
|
"PDF processing requires PyPDF2 or pdfplumber. Install with: pip install PyPDF2",
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ValueError(f"PDF extraction failed: {str(e)}")
|
raise ValueError(f"PDF extraction failed: {e!s}")
|
||||||
|
|
||||||
def _extract_docx(self, content: bytes) -> str:
|
def _extract_docx(self, content: bytes) -> str:
|
||||||
"""提取 DOCX 文本"""
|
"""提取 DOCX 文本"""
|
||||||
@@ -109,7 +109,7 @@ class DocumentProcessor:
|
|||||||
"DOCX processing requires python-docx. Install with: pip install python-docx",
|
"DOCX processing requires python-docx. Install with: pip install python-docx",
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ValueError(f"DOCX extraction failed: {str(e)}")
|
raise ValueError(f"DOCX extraction failed: {e!s}")
|
||||||
|
|
||||||
def _extract_txt(self, content: bytes) -> str:
|
def _extract_txt(self, content: bytes) -> str:
|
||||||
"""提取纯文本"""
|
"""提取纯文本"""
|
||||||
|
|||||||
@@ -699,7 +699,9 @@ class EnterpriseManager:
|
|||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
def get_tenant_sso_config(
|
def get_tenant_sso_config(
|
||||||
self, tenant_id: str, provider: str | None = None,
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
provider: str | None = None,
|
||||||
) -> SSOConfig | None:
|
) -> SSOConfig | None:
|
||||||
"""获取租户的 SSO 配置"""
|
"""获取租户的 SSO 配置"""
|
||||||
conn = self._get_connection()
|
conn = self._get_connection()
|
||||||
@@ -871,7 +873,10 @@ class EnterpriseManager:
|
|||||||
return metadata
|
return metadata
|
||||||
|
|
||||||
def create_saml_auth_request(
|
def create_saml_auth_request(
|
||||||
self, tenant_id: str, config_id: str, relay_state: str | None = None,
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
config_id: str,
|
||||||
|
relay_state: str | None = None,
|
||||||
) -> SAMLAuthRequest:
|
) -> SAMLAuthRequest:
|
||||||
"""创建 SAML 认证请求"""
|
"""创建 SAML 认证请求"""
|
||||||
conn = self._get_connection()
|
conn = self._get_connection()
|
||||||
@@ -1235,7 +1240,10 @@ class EnterpriseManager:
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
def _upsert_scim_user(
|
def _upsert_scim_user(
|
||||||
self, conn: sqlite3.Connection, tenant_id: str, user_data: dict[str, Any],
|
self,
|
||||||
|
conn: sqlite3.Connection,
|
||||||
|
tenant_id: str,
|
||||||
|
user_data: dict[str, Any],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""插入或更新 SCIM 用户"""
|
"""插入或更新 SCIM 用户"""
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
@@ -1405,7 +1413,11 @@ class EnterpriseManager:
|
|||||||
try:
|
try:
|
||||||
# 获取审计日志数据
|
# 获取审计日志数据
|
||||||
logs = self._fetch_audit_logs(
|
logs = self._fetch_audit_logs(
|
||||||
export.tenant_id, export.start_date, export.end_date, export.filters, db_manager,
|
export.tenant_id,
|
||||||
|
export.start_date,
|
||||||
|
export.end_date,
|
||||||
|
export.filters,
|
||||||
|
db_manager,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 根据合规标准过滤字段
|
# 根据合规标准过滤字段
|
||||||
@@ -1414,7 +1426,9 @@ class EnterpriseManager:
|
|||||||
|
|
||||||
# 生成导出文件
|
# 生成导出文件
|
||||||
file_path, file_size, checksum = self._generate_export_file(
|
file_path, file_size, checksum = self._generate_export_file(
|
||||||
export_id, logs, export.export_format,
|
export_id,
|
||||||
|
logs,
|
||||||
|
export.export_format,
|
||||||
)
|
)
|
||||||
|
|
||||||
now = datetime.now()
|
now = datetime.now()
|
||||||
@@ -1465,7 +1479,9 @@ class EnterpriseManager:
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
def _apply_compliance_filter(
|
def _apply_compliance_filter(
|
||||||
self, logs: list[dict[str, Any]], standard: str,
|
self,
|
||||||
|
logs: list[dict[str, Any]],
|
||||||
|
standard: str,
|
||||||
) -> list[dict[str, Any]]:
|
) -> list[dict[str, Any]]:
|
||||||
"""应用合规标准字段过滤"""
|
"""应用合规标准字段过滤"""
|
||||||
fields = self.COMPLIANCE_FIELDS.get(ComplianceStandard(standard), [])
|
fields = self.COMPLIANCE_FIELDS.get(ComplianceStandard(standard), [])
|
||||||
@@ -1481,7 +1497,10 @@ class EnterpriseManager:
|
|||||||
return filtered_logs
|
return filtered_logs
|
||||||
|
|
||||||
def _generate_export_file(
|
def _generate_export_file(
|
||||||
self, export_id: str, logs: list[dict[str, Any]], format: str,
|
self,
|
||||||
|
export_id: str,
|
||||||
|
logs: list[dict[str, Any]],
|
||||||
|
format: str,
|
||||||
) -> tuple[str, int, str]:
|
) -> tuple[str, int, str]:
|
||||||
"""生成导出文件"""
|
"""生成导出文件"""
|
||||||
import hashlib
|
import hashlib
|
||||||
@@ -1672,7 +1691,9 @@ class EnterpriseManager:
|
|||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
def list_retention_policies(
|
def list_retention_policies(
|
||||||
self, tenant_id: str, resource_type: str | None = None,
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
resource_type: str | None = None,
|
||||||
) -> list[DataRetentionPolicy]:
|
) -> list[DataRetentionPolicy]:
|
||||||
"""列出数据保留策略"""
|
"""列出数据保留策略"""
|
||||||
conn = self._get_connection()
|
conn = self._get_connection()
|
||||||
@@ -1876,7 +1897,10 @@ class EnterpriseManager:
|
|||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
def _retain_audit_logs(
|
def _retain_audit_logs(
|
||||||
self, conn: sqlite3.Connection, policy: DataRetentionPolicy, cutoff_date: datetime,
|
self,
|
||||||
|
conn: sqlite3.Connection,
|
||||||
|
policy: DataRetentionPolicy,
|
||||||
|
cutoff_date: datetime,
|
||||||
) -> dict[str, int]:
|
) -> dict[str, int]:
|
||||||
"""保留审计日志"""
|
"""保留审计日志"""
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
@@ -1909,14 +1933,20 @@ class EnterpriseManager:
|
|||||||
return {"affected": 0, "archived": 0, "deleted": 0, "errors": 0}
|
return {"affected": 0, "archived": 0, "deleted": 0, "errors": 0}
|
||||||
|
|
||||||
def _retain_projects(
|
def _retain_projects(
|
||||||
self, conn: sqlite3.Connection, policy: DataRetentionPolicy, cutoff_date: datetime,
|
self,
|
||||||
|
conn: sqlite3.Connection,
|
||||||
|
policy: DataRetentionPolicy,
|
||||||
|
cutoff_date: datetime,
|
||||||
) -> dict[str, int]:
|
) -> dict[str, int]:
|
||||||
"""保留项目数据"""
|
"""保留项目数据"""
|
||||||
# 简化实现
|
# 简化实现
|
||||||
return {"affected": 0, "archived": 0, "deleted": 0, "errors": 0}
|
return {"affected": 0, "archived": 0, "deleted": 0, "errors": 0}
|
||||||
|
|
||||||
def _retain_transcripts(
|
def _retain_transcripts(
|
||||||
self, conn: sqlite3.Connection, policy: DataRetentionPolicy, cutoff_date: datetime,
|
self,
|
||||||
|
conn: sqlite3.Connection,
|
||||||
|
policy: DataRetentionPolicy,
|
||||||
|
cutoff_date: datetime,
|
||||||
) -> dict[str, int]:
|
) -> dict[str, int]:
|
||||||
"""保留转录数据"""
|
"""保留转录数据"""
|
||||||
# 简化实现
|
# 简化实现
|
||||||
@@ -2101,9 +2131,11 @@ class EnterpriseManager:
|
|||||||
if isinstance(row["start_date"], str)
|
if isinstance(row["start_date"], str)
|
||||||
else row["start_date"]
|
else row["start_date"]
|
||||||
),
|
),
|
||||||
end_date=datetime.fromisoformat(row["end_date"])
|
end_date=(
|
||||||
if isinstance(row["end_date"], str)
|
datetime.fromisoformat(row["end_date"])
|
||||||
else row["end_date"],
|
if isinstance(row["end_date"], str)
|
||||||
|
else row["end_date"]
|
||||||
|
),
|
||||||
filters=json.loads(row["filters"] or "{}"),
|
filters=json.loads(row["filters"] or "{}"),
|
||||||
compliance_standard=row["compliance_standard"],
|
compliance_standard=row["compliance_standard"],
|
||||||
status=row["status"],
|
status=row["status"],
|
||||||
|
|||||||
@@ -178,7 +178,10 @@ class EntityAligner:
|
|||||||
return best_match
|
return best_match
|
||||||
|
|
||||||
def _fallback_similarity_match(
|
def _fallback_similarity_match(
|
||||||
self, entities: list[object], name: str, exclude_id: str | None = None,
|
self,
|
||||||
|
entities: list[object],
|
||||||
|
name: str,
|
||||||
|
exclude_id: str | None = None,
|
||||||
) -> object | None:
|
) -> object | None:
|
||||||
"""
|
"""
|
||||||
回退到简单的相似度匹配(不使用 embedding)
|
回退到简单的相似度匹配(不使用 embedding)
|
||||||
@@ -212,7 +215,10 @@ class EntityAligner:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def batch_align_entities(
|
def batch_align_entities(
|
||||||
self, project_id: str, new_entities: list[dict], threshold: float | None = None,
|
self,
|
||||||
|
project_id: str,
|
||||||
|
new_entities: list[dict],
|
||||||
|
threshold: float | None = None,
|
||||||
) -> list[dict]:
|
) -> list[dict]:
|
||||||
"""
|
"""
|
||||||
批量对齐实体
|
批量对齐实体
|
||||||
@@ -232,7 +238,10 @@ class EntityAligner:
|
|||||||
|
|
||||||
for new_ent in new_entities:
|
for new_ent in new_entities:
|
||||||
matched = self.find_similar_entity(
|
matched = self.find_similar_entity(
|
||||||
project_id, new_ent["name"], new_ent.get("definition", ""), threshold=threshold,
|
project_id,
|
||||||
|
new_ent["name"],
|
||||||
|
new_ent.get("definition", ""),
|
||||||
|
threshold=threshold,
|
||||||
)
|
)
|
||||||
|
|
||||||
result = {
|
result = {
|
||||||
|
|||||||
@@ -75,7 +75,10 @@ class ExportManager:
|
|||||||
self.db = db_manager
|
self.db = db_manager
|
||||||
|
|
||||||
def export_knowledge_graph_svg(
|
def export_knowledge_graph_svg(
|
||||||
self, project_id: str, entities: list[ExportEntity], relations: list[ExportRelation],
|
self,
|
||||||
|
project_id: str,
|
||||||
|
entities: list[ExportEntity],
|
||||||
|
relations: list[ExportRelation],
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
导出知识图谱为 SVG 格式
|
导出知识图谱为 SVG 格式
|
||||||
@@ -220,7 +223,10 @@ class ExportManager:
|
|||||||
return "\n".join(svg_parts)
|
return "\n".join(svg_parts)
|
||||||
|
|
||||||
def export_knowledge_graph_png(
|
def export_knowledge_graph_png(
|
||||||
self, project_id: str, entities: list[ExportEntity], relations: list[ExportRelation],
|
self,
|
||||||
|
project_id: str,
|
||||||
|
entities: list[ExportEntity],
|
||||||
|
relations: list[ExportRelation],
|
||||||
) -> bytes:
|
) -> bytes:
|
||||||
"""
|
"""
|
||||||
导出知识图谱为 PNG 格式
|
导出知识图谱为 PNG 格式
|
||||||
@@ -337,7 +343,9 @@ class ExportManager:
|
|||||||
return output.getvalue()
|
return output.getvalue()
|
||||||
|
|
||||||
def export_transcript_markdown(
|
def export_transcript_markdown(
|
||||||
self, transcript: ExportTranscript, entities_map: dict[str, ExportEntity],
|
self,
|
||||||
|
transcript: ExportTranscript,
|
||||||
|
entities_map: dict[str, ExportEntity],
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
导出转录文本为 Markdown 格式
|
导出转录文本为 Markdown 格式
|
||||||
@@ -417,7 +425,12 @@ class ExportManager:
|
|||||||
|
|
||||||
output = io.BytesIO()
|
output = io.BytesIO()
|
||||||
doc = SimpleDocTemplate(
|
doc = SimpleDocTemplate(
|
||||||
output, pagesize=A4, rightMargin=72, leftMargin=72, topMargin=72, bottomMargin=18,
|
output,
|
||||||
|
pagesize=A4,
|
||||||
|
rightMargin=72,
|
||||||
|
leftMargin=72,
|
||||||
|
topMargin=72,
|
||||||
|
bottomMargin=18,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 样式
|
# 样式
|
||||||
@@ -510,7 +523,8 @@ class ExportManager:
|
|||||||
)
|
)
|
||||||
|
|
||||||
entity_table = Table(
|
entity_table = Table(
|
||||||
entity_data, colWidths=[1.5 * inch, 1 * inch, 1 * inch, 2.5 * inch],
|
entity_data,
|
||||||
|
colWidths=[1.5 * inch, 1 * inch, 1 * inch, 2.5 * inch],
|
||||||
)
|
)
|
||||||
entity_table.setStyle(
|
entity_table.setStyle(
|
||||||
TableStyle(
|
TableStyle(
|
||||||
@@ -539,7 +553,8 @@ class ExportManager:
|
|||||||
relation_data.append([r.source, r.relation_type, r.target, f"{r.confidence:.2f}"])
|
relation_data.append([r.source, r.relation_type, r.target, f"{r.confidence:.2f}"])
|
||||||
|
|
||||||
relation_table = Table(
|
relation_table = Table(
|
||||||
relation_data, colWidths=[2 * inch, 1.5 * inch, 2 * inch, 1 * inch],
|
relation_data,
|
||||||
|
colWidths=[2 * inch, 1.5 * inch, 2 * inch, 1 * inch],
|
||||||
)
|
)
|
||||||
relation_table.setStyle(
|
relation_table.setStyle(
|
||||||
TableStyle(
|
TableStyle(
|
||||||
|
|||||||
@@ -383,11 +383,11 @@ class GrowthManager:
|
|||||||
user_id: str,
|
user_id: str,
|
||||||
event_type: EventType,
|
event_type: EventType,
|
||||||
event_name: str,
|
event_name: str,
|
||||||
properties: dict = None,
|
properties: dict | None = None,
|
||||||
session_id: str = None,
|
session_id: str | None = None,
|
||||||
device_info: dict = None,
|
device_info: dict | None = None,
|
||||||
referrer: str = None,
|
referrer: str | None = None,
|
||||||
utm_params: dict = None,
|
utm_params: dict | None = None,
|
||||||
) -> AnalyticsEvent:
|
) -> AnalyticsEvent:
|
||||||
"""追踪事件"""
|
"""追踪事件"""
|
||||||
event_id = f"evt_{uuid.uuid4().hex[:16]}"
|
event_id = f"evt_{uuid.uuid4().hex[:16]}"
|
||||||
@@ -475,7 +475,10 @@ class GrowthManager:
|
|||||||
|
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
await client.post(
|
await client.post(
|
||||||
"https://api.mixpanel.com/track", headers=headers, json=[payload], timeout=10.0,
|
"https://api.mixpanel.com/track",
|
||||||
|
headers=headers,
|
||||||
|
json=[payload],
|
||||||
|
timeout=10.0,
|
||||||
)
|
)
|
||||||
except (RuntimeError, ValueError, TypeError) as e:
|
except (RuntimeError, ValueError, TypeError) as e:
|
||||||
print(f"Failed to send to Mixpanel: {e}")
|
print(f"Failed to send to Mixpanel: {e}")
|
||||||
@@ -509,7 +512,11 @@ class GrowthManager:
|
|||||||
print(f"Failed to send to Amplitude: {e}")
|
print(f"Failed to send to Amplitude: {e}")
|
||||||
|
|
||||||
async def _update_user_profile(
|
async def _update_user_profile(
|
||||||
self, tenant_id: str, user_id: str, event_type: EventType, event_name: str,
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
user_id: str,
|
||||||
|
event_type: EventType,
|
||||||
|
event_name: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""更新用户画像"""
|
"""更新用户画像"""
|
||||||
with self._get_db() as conn:
|
with self._get_db() as conn:
|
||||||
@@ -581,7 +588,10 @@ class GrowthManager:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def get_user_analytics_summary(
|
def get_user_analytics_summary(
|
||||||
self, tenant_id: str, start_date: datetime = None, end_date: datetime = None,
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
start_date: datetime | None = None,
|
||||||
|
end_date: datetime = None,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""获取用户分析汇总"""
|
"""获取用户分析汇总"""
|
||||||
with self._get_db() as conn:
|
with self._get_db() as conn:
|
||||||
@@ -635,7 +645,12 @@ class GrowthManager:
|
|||||||
}
|
}
|
||||||
|
|
||||||
def create_funnel(
|
def create_funnel(
|
||||||
self, tenant_id: str, name: str, description: str, steps: list[dict], created_by: str,
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
name: str,
|
||||||
|
description: str,
|
||||||
|
steps: list[dict],
|
||||||
|
created_by: str,
|
||||||
) -> Funnel:
|
) -> Funnel:
|
||||||
"""创建转化漏斗"""
|
"""创建转化漏斗"""
|
||||||
funnel_id = f"fnl_{uuid.uuid4().hex[:16]}"
|
funnel_id = f"fnl_{uuid.uuid4().hex[:16]}"
|
||||||
@@ -673,12 +688,16 @@ class GrowthManager:
|
|||||||
return funnel
|
return funnel
|
||||||
|
|
||||||
def analyze_funnel(
|
def analyze_funnel(
|
||||||
self, funnel_id: str, period_start: datetime = None, period_end: datetime = None,
|
self,
|
||||||
|
funnel_id: str,
|
||||||
|
period_start: datetime | None = None,
|
||||||
|
period_end: datetime = None,
|
||||||
) -> FunnelAnalysis | None:
|
) -> FunnelAnalysis | None:
|
||||||
"""分析漏斗转化率"""
|
"""分析漏斗转化率"""
|
||||||
with self._get_db() as conn:
|
with self._get_db() as conn:
|
||||||
funnel_row = conn.execute(
|
funnel_row = conn.execute(
|
||||||
"SELECT * FROM funnels WHERE id = ?", (funnel_id,),
|
"SELECT * FROM funnels WHERE id = ?",
|
||||||
|
(funnel_id,),
|
||||||
).fetchone()
|
).fetchone()
|
||||||
|
|
||||||
if not funnel_row:
|
if not funnel_row:
|
||||||
@@ -704,7 +723,8 @@ class GrowthManager:
|
|||||||
WHERE event_name = ? AND timestamp >= ? AND timestamp <= ?
|
WHERE event_name = ? AND timestamp >= ? AND timestamp <= ?
|
||||||
"""
|
"""
|
||||||
row = conn.execute(
|
row = conn.execute(
|
||||||
query, (event_name, period_start.isoformat(), period_end.isoformat()),
|
query,
|
||||||
|
(event_name, period_start.isoformat(), period_end.isoformat()),
|
||||||
).fetchone()
|
).fetchone()
|
||||||
|
|
||||||
user_count = row["user_count"] if row else 0
|
user_count = row["user_count"] if row else 0
|
||||||
@@ -752,7 +772,10 @@ class GrowthManager:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def calculate_retention(
|
def calculate_retention(
|
||||||
self, tenant_id: str, cohort_date: datetime, periods: list[int] = None,
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
cohort_date: datetime,
|
||||||
|
periods: list[int] = None,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""计算留存率"""
|
"""计算留存率"""
|
||||||
if periods is None:
|
if periods is None:
|
||||||
@@ -825,7 +848,7 @@ class GrowthManager:
|
|||||||
secondary_metrics: list[str],
|
secondary_metrics: list[str],
|
||||||
min_sample_size: int = 100,
|
min_sample_size: int = 100,
|
||||||
confidence_level: float = 0.95,
|
confidence_level: float = 0.95,
|
||||||
created_by: str = None,
|
created_by: str | None = None,
|
||||||
) -> Experiment:
|
) -> Experiment:
|
||||||
"""创建 A/B 测试实验"""
|
"""创建 A/B 测试实验"""
|
||||||
experiment_id = f"exp_{uuid.uuid4().hex[:16]}"
|
experiment_id = f"exp_{uuid.uuid4().hex[:16]}"
|
||||||
@@ -893,14 +916,17 @@ class GrowthManager:
|
|||||||
"""获取实验详情"""
|
"""获取实验详情"""
|
||||||
with self._get_db() as conn:
|
with self._get_db() as conn:
|
||||||
row = conn.execute(
|
row = conn.execute(
|
||||||
"SELECT * FROM experiments WHERE id = ?", (experiment_id,),
|
"SELECT * FROM experiments WHERE id = ?",
|
||||||
|
(experiment_id,),
|
||||||
).fetchone()
|
).fetchone()
|
||||||
|
|
||||||
if row:
|
if row:
|
||||||
return self._row_to_experiment(row)
|
return self._row_to_experiment(row)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def list_experiments(self, tenant_id: str, status: ExperimentStatus = None) -> list[Experiment]:
|
def list_experiments(
|
||||||
|
self, tenant_id: str, status: ExperimentStatus | None = None
|
||||||
|
) -> list[Experiment]:
|
||||||
"""列出实验"""
|
"""列出实验"""
|
||||||
query = "SELECT * FROM experiments WHERE tenant_id = ?"
|
query = "SELECT * FROM experiments WHERE tenant_id = ?"
|
||||||
params = [tenant_id]
|
params = [tenant_id]
|
||||||
@@ -916,7 +942,10 @@ class GrowthManager:
|
|||||||
return [self._row_to_experiment(row) for row in rows]
|
return [self._row_to_experiment(row) for row in rows]
|
||||||
|
|
||||||
def assign_variant(
|
def assign_variant(
|
||||||
self, experiment_id: str, user_id: str, user_attributes: dict = None,
|
self,
|
||||||
|
experiment_id: str,
|
||||||
|
user_id: str,
|
||||||
|
user_attributes: dict | None = None,
|
||||||
) -> str | None:
|
) -> str | None:
|
||||||
"""为用户分配实验变体"""
|
"""为用户分配实验变体"""
|
||||||
experiment = self.get_experiment(experiment_id)
|
experiment = self.get_experiment(experiment_id)
|
||||||
@@ -939,11 +968,15 @@ class GrowthManager:
|
|||||||
variant_id = self._random_allocation(experiment.variants, experiment.traffic_split)
|
variant_id = self._random_allocation(experiment.variants, experiment.traffic_split)
|
||||||
elif experiment.traffic_allocation == TrafficAllocationType.STRATIFIED:
|
elif experiment.traffic_allocation == TrafficAllocationType.STRATIFIED:
|
||||||
variant_id = self._stratified_allocation(
|
variant_id = self._stratified_allocation(
|
||||||
experiment.variants, experiment.traffic_split, user_attributes,
|
experiment.variants,
|
||||||
|
experiment.traffic_split,
|
||||||
|
user_attributes,
|
||||||
)
|
)
|
||||||
else: # TARGETED
|
else: # TARGETED
|
||||||
variant_id = self._targeted_allocation(
|
variant_id = self._targeted_allocation(
|
||||||
experiment.variants, experiment.target_audience, user_attributes,
|
experiment.variants,
|
||||||
|
experiment.target_audience,
|
||||||
|
user_attributes,
|
||||||
)
|
)
|
||||||
|
|
||||||
if variant_id:
|
if variant_id:
|
||||||
@@ -978,7 +1011,10 @@ class GrowthManager:
|
|||||||
return random.choices(variant_ids, weights=normalized_weights, k=1)[0]
|
return random.choices(variant_ids, weights=normalized_weights, k=1)[0]
|
||||||
|
|
||||||
def _stratified_allocation(
|
def _stratified_allocation(
|
||||||
self, variants: list[dict], traffic_split: dict[str, float], user_attributes: dict,
|
self,
|
||||||
|
variants: list[dict],
|
||||||
|
traffic_split: dict[str, float],
|
||||||
|
user_attributes: dict,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""分层分配(基于用户属性)"""
|
"""分层分配(基于用户属性)"""
|
||||||
# 简化的分层分配:根据用户 ID 哈希值分配
|
# 简化的分层分配:根据用户 ID 哈希值分配
|
||||||
@@ -991,7 +1027,10 @@ class GrowthManager:
|
|||||||
return self._random_allocation(variants, traffic_split)
|
return self._random_allocation(variants, traffic_split)
|
||||||
|
|
||||||
def _targeted_allocation(
|
def _targeted_allocation(
|
||||||
self, variants: list[dict], target_audience: dict, user_attributes: dict,
|
self,
|
||||||
|
variants: list[dict],
|
||||||
|
target_audience: dict,
|
||||||
|
user_attributes: dict,
|
||||||
) -> str | None:
|
) -> str | None:
|
||||||
"""定向分配(基于目标受众条件)"""
|
"""定向分配(基于目标受众条件)"""
|
||||||
# 检查用户是否符合目标受众条件
|
# 检查用户是否符合目标受众条件
|
||||||
@@ -1005,7 +1044,14 @@ class GrowthManager:
|
|||||||
|
|
||||||
user_value = user_attributes.get(attr_name) if user_attributes else None
|
user_value = user_attributes.get(attr_name) if user_attributes else None
|
||||||
|
|
||||||
if operator == "equals" and user_value != value or operator == "not_equals" and user_value == value or operator == "in" and user_value not in value:
|
if (
|
||||||
|
operator == "equals"
|
||||||
|
and user_value != value
|
||||||
|
or operator == "not_equals"
|
||||||
|
and user_value == value
|
||||||
|
or operator == "in"
|
||||||
|
and user_value not in value
|
||||||
|
):
|
||||||
matches = False
|
matches = False
|
||||||
break
|
break
|
||||||
|
|
||||||
@@ -1177,11 +1223,11 @@ class GrowthManager:
|
|||||||
template_type: EmailTemplateType,
|
template_type: EmailTemplateType,
|
||||||
subject: str,
|
subject: str,
|
||||||
html_content: str,
|
html_content: str,
|
||||||
text_content: str = None,
|
text_content: str | None = None,
|
||||||
variables: list[str] = None,
|
variables: list[str] = None,
|
||||||
from_name: str = None,
|
from_name: str | None = None,
|
||||||
from_email: str = None,
|
from_email: str | None = None,
|
||||||
reply_to: str = None,
|
reply_to: str | None = None,
|
||||||
) -> EmailTemplate:
|
) -> EmailTemplate:
|
||||||
"""创建邮件模板"""
|
"""创建邮件模板"""
|
||||||
template_id = f"et_{uuid.uuid4().hex[:16]}"
|
template_id = f"et_{uuid.uuid4().hex[:16]}"
|
||||||
@@ -1242,7 +1288,8 @@ class GrowthManager:
|
|||||||
"""获取邮件模板"""
|
"""获取邮件模板"""
|
||||||
with self._get_db() as conn:
|
with self._get_db() as conn:
|
||||||
row = conn.execute(
|
row = conn.execute(
|
||||||
"SELECT * FROM email_templates WHERE id = ?", (template_id,),
|
"SELECT * FROM email_templates WHERE id = ?",
|
||||||
|
(template_id,),
|
||||||
).fetchone()
|
).fetchone()
|
||||||
|
|
||||||
if row:
|
if row:
|
||||||
@@ -1250,7 +1297,9 @@ class GrowthManager:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def list_email_templates(
|
def list_email_templates(
|
||||||
self, tenant_id: str, template_type: EmailTemplateType = None,
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
template_type: EmailTemplateType | None = None,
|
||||||
) -> list[EmailTemplate]:
|
) -> list[EmailTemplate]:
|
||||||
"""列出邮件模板"""
|
"""列出邮件模板"""
|
||||||
query = "SELECT * FROM email_templates WHERE tenant_id = ? AND is_active = 1"
|
query = "SELECT * FROM email_templates WHERE tenant_id = ? AND is_active = 1"
|
||||||
@@ -1297,7 +1346,7 @@ class GrowthManager:
|
|||||||
name: str,
|
name: str,
|
||||||
template_id: str,
|
template_id: str,
|
||||||
recipient_list: list[dict],
|
recipient_list: list[dict],
|
||||||
scheduled_at: datetime = None,
|
scheduled_at: datetime | None = None,
|
||||||
) -> EmailCampaign:
|
) -> EmailCampaign:
|
||||||
"""创建邮件营销活动"""
|
"""创建邮件营销活动"""
|
||||||
campaign_id = f"ec_{uuid.uuid4().hex[:16]}"
|
campaign_id = f"ec_{uuid.uuid4().hex[:16]}"
|
||||||
@@ -1377,7 +1426,12 @@ class GrowthManager:
|
|||||||
return campaign
|
return campaign
|
||||||
|
|
||||||
async def send_email(
|
async def send_email(
|
||||||
self, campaign_id: str, user_id: str, email: str, template_id: str, variables: dict,
|
self,
|
||||||
|
campaign_id: str,
|
||||||
|
user_id: str,
|
||||||
|
email: str,
|
||||||
|
template_id: str,
|
||||||
|
variables: dict,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""发送单封邮件"""
|
"""发送单封邮件"""
|
||||||
template = self.get_email_template(template_id)
|
template = self.get_email_template(template_id)
|
||||||
@@ -1448,7 +1502,8 @@ class GrowthManager:
|
|||||||
"""发送整个营销活动"""
|
"""发送整个营销活动"""
|
||||||
with self._get_db() as conn:
|
with self._get_db() as conn:
|
||||||
campaign_row = conn.execute(
|
campaign_row = conn.execute(
|
||||||
"SELECT * FROM email_campaigns WHERE id = ?", (campaign_id,),
|
"SELECT * FROM email_campaigns WHERE id = ?",
|
||||||
|
(campaign_id,),
|
||||||
).fetchone()
|
).fetchone()
|
||||||
|
|
||||||
if not campaign_row:
|
if not campaign_row:
|
||||||
@@ -1478,7 +1533,11 @@ class GrowthManager:
|
|||||||
variables = self._get_user_variables(log["tenant_id"], log["user_id"])
|
variables = self._get_user_variables(log["tenant_id"], log["user_id"])
|
||||||
|
|
||||||
success = await self.send_email(
|
success = await self.send_email(
|
||||||
campaign_id, log["user_id"], log["email"], log["template_id"], variables,
|
campaign_id,
|
||||||
|
log["user_id"],
|
||||||
|
log["email"],
|
||||||
|
log["template_id"],
|
||||||
|
variables,
|
||||||
)
|
)
|
||||||
|
|
||||||
if success:
|
if success:
|
||||||
@@ -1763,7 +1822,8 @@ class GrowthManager:
|
|||||||
|
|
||||||
with self._get_db() as conn:
|
with self._get_db() as conn:
|
||||||
row = conn.execute(
|
row = conn.execute(
|
||||||
"SELECT 1 FROM referrals WHERE referral_code = ?", (code,),
|
"SELECT 1 FROM referrals WHERE referral_code = ?",
|
||||||
|
(code,),
|
||||||
).fetchone()
|
).fetchone()
|
||||||
|
|
||||||
if not row:
|
if not row:
|
||||||
@@ -1773,7 +1833,8 @@ class GrowthManager:
|
|||||||
"""获取推荐计划"""
|
"""获取推荐计划"""
|
||||||
with self._get_db() as conn:
|
with self._get_db() as conn:
|
||||||
row = conn.execute(
|
row = conn.execute(
|
||||||
"SELECT * FROM referral_programs WHERE id = ?", (program_id,),
|
"SELECT * FROM referral_programs WHERE id = ?",
|
||||||
|
(program_id,),
|
||||||
).fetchone()
|
).fetchone()
|
||||||
|
|
||||||
if row:
|
if row:
|
||||||
@@ -1859,7 +1920,8 @@ class GrowthManager:
|
|||||||
"expired": stats["expired"] or 0,
|
"expired": stats["expired"] or 0,
|
||||||
"unique_referrers": stats["unique_referrers"] or 0,
|
"unique_referrers": stats["unique_referrers"] or 0,
|
||||||
"conversion_rate": round(
|
"conversion_rate": round(
|
||||||
(stats["converted"] or 0) / max(stats["total_referrals"] or 1, 1), 4,
|
(stats["converted"] or 0) / max(stats["total_referrals"] or 1, 1),
|
||||||
|
4,
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1922,7 +1984,10 @@ class GrowthManager:
|
|||||||
return incentive
|
return incentive
|
||||||
|
|
||||||
def check_team_incentive_eligibility(
|
def check_team_incentive_eligibility(
|
||||||
self, tenant_id: str, current_tier: str, team_size: int,
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
current_tier: str,
|
||||||
|
team_size: int,
|
||||||
) -> list[TeamIncentive]:
|
) -> list[TeamIncentive]:
|
||||||
"""检查团队激励资格"""
|
"""检查团队激励资格"""
|
||||||
with self._get_db() as conn:
|
with self._get_db() as conn:
|
||||||
|
|||||||
@@ -96,7 +96,7 @@ class ImageProcessor:
|
|||||||
"other": "其他",
|
"other": "其他",
|
||||||
}
|
}
|
||||||
|
|
||||||
def __init__(self, temp_dir: str = None) -> None:
|
def __init__(self, temp_dir: str | None = None) -> None:
|
||||||
"""
|
"""
|
||||||
初始化图片处理器
|
初始化图片处理器
|
||||||
|
|
||||||
@@ -106,7 +106,7 @@ class ImageProcessor:
|
|||||||
self.temp_dir = temp_dir or os.path.join(os.getcwd(), "temp", "images")
|
self.temp_dir = temp_dir or os.path.join(os.getcwd(), "temp", "images")
|
||||||
os.makedirs(self.temp_dir, exist_ok=True)
|
os.makedirs(self.temp_dir, exist_ok=True)
|
||||||
|
|
||||||
def preprocess_image(self, image, image_type: str = None) -> None:
|
def preprocess_image(self, image, image_type: str | None = None) -> None:
|
||||||
"""
|
"""
|
||||||
预处理图片以提高OCR质量
|
预处理图片以提高OCR质量
|
||||||
|
|
||||||
@@ -328,7 +328,10 @@ class ImageProcessor:
|
|||||||
return unique_entities
|
return unique_entities
|
||||||
|
|
||||||
def generate_description(
|
def generate_description(
|
||||||
self, image_type: str, ocr_text: str, entities: list[ImageEntity],
|
self,
|
||||||
|
image_type: str,
|
||||||
|
ocr_text: str,
|
||||||
|
entities: list[ImageEntity],
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
生成图片描述
|
生成图片描述
|
||||||
@@ -361,8 +364,8 @@ class ImageProcessor:
|
|||||||
def process_image(
|
def process_image(
|
||||||
self,
|
self,
|
||||||
image_data: bytes,
|
image_data: bytes,
|
||||||
filename: str = None,
|
filename: str | None = None,
|
||||||
image_id: str = None,
|
image_id: str | None = None,
|
||||||
detect_type: bool = True,
|
detect_type: bool = True,
|
||||||
) -> ImageProcessingResult:
|
) -> ImageProcessingResult:
|
||||||
"""
|
"""
|
||||||
@@ -487,7 +490,9 @@ class ImageProcessor:
|
|||||||
return relations
|
return relations
|
||||||
|
|
||||||
def process_batch(
|
def process_batch(
|
||||||
self, images_data: list[tuple[bytes, str]], project_id: str = None,
|
self,
|
||||||
|
images_data: list[tuple[bytes, str]],
|
||||||
|
project_id: str | None = None,
|
||||||
) -> BatchProcessingResult:
|
) -> BatchProcessingResult:
|
||||||
"""
|
"""
|
||||||
批量处理图片
|
批量处理图片
|
||||||
@@ -561,7 +566,7 @@ class ImageProcessor:
|
|||||||
_image_processor = None
|
_image_processor = None
|
||||||
|
|
||||||
|
|
||||||
def get_image_processor(temp_dir: str = None) -> ImageProcessor:
|
def get_image_processor(temp_dir: str | None = None) -> ImageProcessor:
|
||||||
"""获取图片处理器单例"""
|
"""获取图片处理器单例"""
|
||||||
global _image_processor
|
global _image_processor
|
||||||
if _image_processor is None:
|
if _image_processor is None:
|
||||||
|
|||||||
@@ -51,7 +51,7 @@ class InferencePath:
|
|||||||
class KnowledgeReasoner:
|
class KnowledgeReasoner:
|
||||||
"""知识推理引擎"""
|
"""知识推理引擎"""
|
||||||
|
|
||||||
def __init__(self, api_key: str = None, base_url: str = None) -> None:
|
def __init__(self, api_key: str | None = None, base_url: str = None) -> None:
|
||||||
self.api_key = api_key or KIMI_API_KEY
|
self.api_key = api_key or KIMI_API_KEY
|
||||||
self.base_url = base_url or KIMI_BASE_URL
|
self.base_url = base_url or KIMI_BASE_URL
|
||||||
self.headers = {
|
self.headers = {
|
||||||
@@ -82,7 +82,11 @@ class KnowledgeReasoner:
|
|||||||
return result["choices"][0]["message"]["content"]
|
return result["choices"][0]["message"]["content"]
|
||||||
|
|
||||||
async def enhanced_qa(
|
async def enhanced_qa(
|
||||||
self, query: str, project_context: dict, graph_data: dict, reasoning_depth: str = "medium",
|
self,
|
||||||
|
query: str,
|
||||||
|
project_context: dict,
|
||||||
|
graph_data: dict,
|
||||||
|
reasoning_depth: str = "medium",
|
||||||
) -> ReasoningResult:
|
) -> ReasoningResult:
|
||||||
"""
|
"""
|
||||||
增强问答 - 结合图谱推理的问答
|
增强问答 - 结合图谱推理的问答
|
||||||
@@ -139,7 +143,10 @@ class KnowledgeReasoner:
|
|||||||
return {"type": "factual", "entities": [], "intent": "general", "complexity": "simple"}
|
return {"type": "factual", "entities": [], "intent": "general", "complexity": "simple"}
|
||||||
|
|
||||||
async def _causal_reasoning(
|
async def _causal_reasoning(
|
||||||
self, query: str, project_context: dict, graph_data: dict,
|
self,
|
||||||
|
query: str,
|
||||||
|
project_context: dict,
|
||||||
|
graph_data: dict,
|
||||||
) -> ReasoningResult:
|
) -> ReasoningResult:
|
||||||
"""因果推理 - 分析原因和影响"""
|
"""因果推理 - 分析原因和影响"""
|
||||||
|
|
||||||
@@ -200,7 +207,10 @@ class KnowledgeReasoner:
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def _comparative_reasoning(
|
async def _comparative_reasoning(
|
||||||
self, query: str, project_context: dict, graph_data: dict,
|
self,
|
||||||
|
query: str,
|
||||||
|
project_context: dict,
|
||||||
|
graph_data: dict,
|
||||||
) -> ReasoningResult:
|
) -> ReasoningResult:
|
||||||
"""对比推理 - 比较实体间的异同"""
|
"""对比推理 - 比较实体间的异同"""
|
||||||
|
|
||||||
@@ -254,7 +264,10 @@ class KnowledgeReasoner:
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def _temporal_reasoning(
|
async def _temporal_reasoning(
|
||||||
self, query: str, project_context: dict, graph_data: dict,
|
self,
|
||||||
|
query: str,
|
||||||
|
project_context: dict,
|
||||||
|
graph_data: dict,
|
||||||
) -> ReasoningResult:
|
) -> ReasoningResult:
|
||||||
"""时序推理 - 分析时间线和演变"""
|
"""时序推理 - 分析时间线和演变"""
|
||||||
|
|
||||||
@@ -308,7 +321,10 @@ class KnowledgeReasoner:
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def _associative_reasoning(
|
async def _associative_reasoning(
|
||||||
self, query: str, project_context: dict, graph_data: dict,
|
self,
|
||||||
|
query: str,
|
||||||
|
project_context: dict,
|
||||||
|
graph_data: dict,
|
||||||
) -> ReasoningResult:
|
) -> ReasoningResult:
|
||||||
"""关联推理 - 发现实体间的隐含关联"""
|
"""关联推理 - 发现实体间的隐含关联"""
|
||||||
|
|
||||||
@@ -362,7 +378,11 @@ class KnowledgeReasoner:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def find_inference_paths(
|
def find_inference_paths(
|
||||||
self, start_entity: str, end_entity: str, graph_data: dict, max_depth: int = 3,
|
self,
|
||||||
|
start_entity: str,
|
||||||
|
end_entity: str,
|
||||||
|
graph_data: dict,
|
||||||
|
max_depth: int = 3,
|
||||||
) -> list[InferencePath]:
|
) -> list[InferencePath]:
|
||||||
"""
|
"""
|
||||||
发现两个实体之间的推理路径
|
发现两个实体之间的推理路径
|
||||||
@@ -449,7 +469,10 @@ class KnowledgeReasoner:
|
|||||||
return length_factor * confidence_factor
|
return length_factor * confidence_factor
|
||||||
|
|
||||||
async def summarize_project(
|
async def summarize_project(
|
||||||
self, project_context: dict, graph_data: dict, summary_type: str = "comprehensive",
|
self,
|
||||||
|
project_context: dict,
|
||||||
|
graph_data: dict,
|
||||||
|
summary_type: str = "comprehensive",
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""
|
"""
|
||||||
项目智能总结
|
项目智能总结
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ class RelationExtractionResult:
|
|||||||
class LLMClient:
|
class LLMClient:
|
||||||
"""Kimi API 客户端"""
|
"""Kimi API 客户端"""
|
||||||
|
|
||||||
def __init__(self, api_key: str = None, base_url: str = None) -> None:
|
def __init__(self, api_key: str | None = None, base_url: str = None) -> None:
|
||||||
self.api_key = api_key or KIMI_API_KEY
|
self.api_key = api_key or KIMI_API_KEY
|
||||||
self.base_url = base_url or KIMI_BASE_URL
|
self.base_url = base_url or KIMI_BASE_URL
|
||||||
self.headers = {
|
self.headers = {
|
||||||
@@ -52,7 +52,10 @@ class LLMClient:
|
|||||||
}
|
}
|
||||||
|
|
||||||
async def chat(
|
async def chat(
|
||||||
self, messages: list[ChatMessage], temperature: float = 0.3, stream: bool = False,
|
self,
|
||||||
|
messages: list[ChatMessage],
|
||||||
|
temperature: float = 0.3,
|
||||||
|
stream: bool = False,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""发送聊天请求"""
|
"""发送聊天请求"""
|
||||||
if not self.api_key:
|
if not self.api_key:
|
||||||
@@ -77,7 +80,9 @@ class LLMClient:
|
|||||||
return result["choices"][0]["message"]["content"]
|
return result["choices"][0]["message"]["content"]
|
||||||
|
|
||||||
async def chat_stream(
|
async def chat_stream(
|
||||||
self, messages: list[ChatMessage], temperature: float = 0.3,
|
self,
|
||||||
|
messages: list[ChatMessage],
|
||||||
|
temperature: float = 0.3,
|
||||||
) -> AsyncGenerator[str, None]:
|
) -> AsyncGenerator[str, None]:
|
||||||
"""流式聊天请求"""
|
"""流式聊天请求"""
|
||||||
if not self.api_key:
|
if not self.api_key:
|
||||||
@@ -90,13 +95,16 @@ class LLMClient:
|
|||||||
"stream": True,
|
"stream": True,
|
||||||
}
|
}
|
||||||
|
|
||||||
async with httpx.AsyncClient() as client, client.stream(
|
async with (
|
||||||
"POST",
|
httpx.AsyncClient() as client,
|
||||||
f"{self.base_url}/v1/chat/completions",
|
client.stream(
|
||||||
headers=self.headers,
|
"POST",
|
||||||
json=payload,
|
f"{self.base_url}/v1/chat/completions",
|
||||||
timeout=120.0,
|
headers=self.headers,
|
||||||
) as response:
|
json=payload,
|
||||||
|
timeout=120.0,
|
||||||
|
) as response,
|
||||||
|
):
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
async for line in response.aiter_lines():
|
async for line in response.aiter_lines():
|
||||||
if line.startswith("data: "):
|
if line.startswith("data: "):
|
||||||
@@ -112,7 +120,8 @@ class LLMClient:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
async def extract_entities_with_confidence(
|
async def extract_entities_with_confidence(
|
||||||
self, text: str,
|
self,
|
||||||
|
text: str,
|
||||||
) -> tuple[list[EntityExtractionResult], list[RelationExtractionResult]]:
|
) -> tuple[list[EntityExtractionResult], list[RelationExtractionResult]]:
|
||||||
"""提取实体和关系,带置信度分数"""
|
"""提取实体和关系,带置信度分数"""
|
||||||
prompt = f"""从以下会议文本中提取关键实体和它们之间的关系,以 JSON 格式返回:
|
prompt = f"""从以下会议文本中提取关键实体和它们之间的关系,以 JSON 格式返回:
|
||||||
@@ -189,7 +198,8 @@ class LLMClient:
|
|||||||
|
|
||||||
messages = [
|
messages = [
|
||||||
ChatMessage(
|
ChatMessage(
|
||||||
role="system", content="你是一个专业的项目分析助手,擅长从会议记录中提取洞察。",
|
role="system",
|
||||||
|
content="你是一个专业的项目分析助手,擅长从会议记录中提取洞察。",
|
||||||
),
|
),
|
||||||
ChatMessage(role="user", content=prompt),
|
ChatMessage(role="user", content=prompt),
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -963,7 +963,11 @@ class LocalizationManager:
|
|||||||
self._close_if_file_db(conn)
|
self._close_if_file_db(conn)
|
||||||
|
|
||||||
def get_translation(
|
def get_translation(
|
||||||
self, key: str, language: str, namespace: str = "common", fallback: bool = True,
|
self,
|
||||||
|
key: str,
|
||||||
|
language: str,
|
||||||
|
namespace: str = "common",
|
||||||
|
fallback: bool = True,
|
||||||
) -> str | None:
|
) -> str | None:
|
||||||
conn = self._get_connection()
|
conn = self._get_connection()
|
||||||
try:
|
try:
|
||||||
@@ -979,7 +983,10 @@ class LocalizationManager:
|
|||||||
lang_config = self.get_language_config(language)
|
lang_config = self.get_language_config(language)
|
||||||
if lang_config and lang_config.fallback_language:
|
if lang_config and lang_config.fallback_language:
|
||||||
return self.get_translation(
|
return self.get_translation(
|
||||||
key, lang_config.fallback_language, namespace, False,
|
key,
|
||||||
|
lang_config.fallback_language,
|
||||||
|
namespace,
|
||||||
|
False,
|
||||||
)
|
)
|
||||||
if language != "en":
|
if language != "en":
|
||||||
return self.get_translation(key, "en", namespace, False)
|
return self.get_translation(key, "en", namespace, False)
|
||||||
@@ -1019,7 +1026,11 @@ class LocalizationManager:
|
|||||||
self._close_if_file_db(conn)
|
self._close_if_file_db(conn)
|
||||||
|
|
||||||
def _get_translation_internal(
|
def _get_translation_internal(
|
||||||
self, conn: sqlite3.Connection, key: str, language: str, namespace: str,
|
self,
|
||||||
|
conn: sqlite3.Connection,
|
||||||
|
key: str,
|
||||||
|
language: str,
|
||||||
|
namespace: str,
|
||||||
) -> Translation | None:
|
) -> Translation | None:
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
@@ -1121,7 +1132,9 @@ class LocalizationManager:
|
|||||||
self._close_if_file_db(conn)
|
self._close_if_file_db(conn)
|
||||||
|
|
||||||
def list_data_centers(
|
def list_data_centers(
|
||||||
self, status: str | None = None, region: str | None = None,
|
self,
|
||||||
|
status: str | None = None,
|
||||||
|
region: str | None = None,
|
||||||
) -> list[DataCenter]:
|
) -> list[DataCenter]:
|
||||||
conn = self._get_connection()
|
conn = self._get_connection()
|
||||||
try:
|
try:
|
||||||
@@ -1146,7 +1159,8 @@ class LocalizationManager:
|
|||||||
try:
|
try:
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"SELECT * FROM tenant_data_center_mappings WHERE tenant_id = ?", (tenant_id,),
|
"SELECT * FROM tenant_data_center_mappings WHERE tenant_id = ?",
|
||||||
|
(tenant_id,),
|
||||||
)
|
)
|
||||||
row = cursor.fetchone()
|
row = cursor.fetchone()
|
||||||
if row:
|
if row:
|
||||||
@@ -1156,7 +1170,10 @@ class LocalizationManager:
|
|||||||
self._close_if_file_db(conn)
|
self._close_if_file_db(conn)
|
||||||
|
|
||||||
def set_tenant_data_center(
|
def set_tenant_data_center(
|
||||||
self, tenant_id: str, region_code: str, data_residency: str = "regional",
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
region_code: str,
|
||||||
|
data_residency: str = "regional",
|
||||||
) -> TenantDataCenterMapping:
|
) -> TenantDataCenterMapping:
|
||||||
conn = self._get_connection()
|
conn = self._get_connection()
|
||||||
try:
|
try:
|
||||||
@@ -1222,7 +1239,8 @@ class LocalizationManager:
|
|||||||
try:
|
try:
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"SELECT * FROM localized_payment_methods WHERE provider = ?", (provider,),
|
"SELECT * FROM localized_payment_methods WHERE provider = ?",
|
||||||
|
(provider,),
|
||||||
)
|
)
|
||||||
row = cursor.fetchone()
|
row = cursor.fetchone()
|
||||||
if row:
|
if row:
|
||||||
@@ -1232,7 +1250,10 @@ class LocalizationManager:
|
|||||||
self._close_if_file_db(conn)
|
self._close_if_file_db(conn)
|
||||||
|
|
||||||
def list_payment_methods(
|
def list_payment_methods(
|
||||||
self, country_code: str | None = None, currency: str | None = None, active_only: bool = True,
|
self,
|
||||||
|
country_code: str | None = None,
|
||||||
|
currency: str | None = None,
|
||||||
|
active_only: bool = True,
|
||||||
) -> list[LocalizedPaymentMethod]:
|
) -> list[LocalizedPaymentMethod]:
|
||||||
conn = self._get_connection()
|
conn = self._get_connection()
|
||||||
try:
|
try:
|
||||||
@@ -1255,7 +1276,9 @@ class LocalizationManager:
|
|||||||
self._close_if_file_db(conn)
|
self._close_if_file_db(conn)
|
||||||
|
|
||||||
def get_localized_payment_methods(
|
def get_localized_payment_methods(
|
||||||
self, country_code: str, language: str = "en",
|
self,
|
||||||
|
country_code: str,
|
||||||
|
language: str = "en",
|
||||||
) -> list[dict[str, Any]]:
|
) -> list[dict[str, Any]]:
|
||||||
methods = self.list_payment_methods(country_code=country_code)
|
methods = self.list_payment_methods(country_code=country_code)
|
||||||
result = []
|
result = []
|
||||||
@@ -1287,7 +1310,9 @@ class LocalizationManager:
|
|||||||
self._close_if_file_db(conn)
|
self._close_if_file_db(conn)
|
||||||
|
|
||||||
def list_country_configs(
|
def list_country_configs(
|
||||||
self, region: str | None = None, active_only: bool = True,
|
self,
|
||||||
|
region: str | None = None,
|
||||||
|
active_only: bool = True,
|
||||||
) -> list[CountryConfig]:
|
) -> list[CountryConfig]:
|
||||||
conn = self._get_connection()
|
conn = self._get_connection()
|
||||||
try:
|
try:
|
||||||
@@ -1345,14 +1370,19 @@ class LocalizationManager:
|
|||||||
return dt.strftime("%Y-%m-%d %H:%M")
|
return dt.strftime("%Y-%m-%d %H:%M")
|
||||||
|
|
||||||
def format_number(
|
def format_number(
|
||||||
self, number: float, language: str = "en", decimal_places: int | None = None,
|
self,
|
||||||
|
number: float,
|
||||||
|
language: str = "en",
|
||||||
|
decimal_places: int | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
try:
|
try:
|
||||||
if BABEL_AVAILABLE:
|
if BABEL_AVAILABLE:
|
||||||
try:
|
try:
|
||||||
locale = Locale.parse(language.replace("_", "-"))
|
locale = Locale.parse(language.replace("_", "-"))
|
||||||
return numbers.format_decimal(
|
return numbers.format_decimal(
|
||||||
number, locale=locale, decimal_quantization=(decimal_places is not None),
|
number,
|
||||||
|
locale=locale,
|
||||||
|
decimal_quantization=(decimal_places is not None),
|
||||||
)
|
)
|
||||||
except (ValueError, AttributeError):
|
except (ValueError, AttributeError):
|
||||||
pass
|
pass
|
||||||
@@ -1514,7 +1544,9 @@ class LocalizationManager:
|
|||||||
self._close_if_file_db(conn)
|
self._close_if_file_db(conn)
|
||||||
|
|
||||||
def detect_user_preferences(
|
def detect_user_preferences(
|
||||||
self, accept_language: str | None = None, ip_country: str | None = None,
|
self,
|
||||||
|
accept_language: str | None = None,
|
||||||
|
ip_country: str | None = None,
|
||||||
) -> dict[str, str]:
|
) -> dict[str, str]:
|
||||||
preferences = {"language": "en", "country": "US", "timezone": "UTC", "currency": "USD"}
|
preferences = {"language": "en", "country": "US", "timezone": "UTC", "currency": "USD"}
|
||||||
if accept_language:
|
if accept_language:
|
||||||
|
|||||||
734
backend/main.py
734
backend/main.py
File diff suppressed because it is too large
Load Diff
@@ -30,7 +30,7 @@ class MultimodalEntity:
|
|||||||
source_id: str
|
source_id: str
|
||||||
mention_context: str
|
mention_context: str
|
||||||
confidence: float
|
confidence: float
|
||||||
modality_features: dict = None # 模态特定特征
|
modality_features: dict | None = None # 模态特定特征
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
if self.modality_features is None:
|
if self.modality_features is None:
|
||||||
@@ -137,7 +137,8 @@ class MultimodalEntityLinker:
|
|||||||
"""
|
"""
|
||||||
# 名称相似度
|
# 名称相似度
|
||||||
name_sim = self.calculate_string_similarity(
|
name_sim = self.calculate_string_similarity(
|
||||||
entity1.get("name", ""), entity2.get("name", ""),
|
entity1.get("name", ""),
|
||||||
|
entity2.get("name", ""),
|
||||||
)
|
)
|
||||||
|
|
||||||
# 如果名称完全匹配
|
# 如果名称完全匹配
|
||||||
@@ -158,7 +159,8 @@ class MultimodalEntityLinker:
|
|||||||
|
|
||||||
# 定义相似度
|
# 定义相似度
|
||||||
def_sim = self.calculate_string_similarity(
|
def_sim = self.calculate_string_similarity(
|
||||||
entity1.get("definition", ""), entity2.get("definition", ""),
|
entity1.get("definition", ""),
|
||||||
|
entity2.get("definition", ""),
|
||||||
)
|
)
|
||||||
|
|
||||||
# 综合相似度
|
# 综合相似度
|
||||||
@@ -170,7 +172,10 @@ class MultimodalEntityLinker:
|
|||||||
return combined_sim, "none"
|
return combined_sim, "none"
|
||||||
|
|
||||||
def find_matching_entity(
|
def find_matching_entity(
|
||||||
self, query_entity: dict, candidate_entities: list[dict], exclude_ids: set[str] = None,
|
self,
|
||||||
|
query_entity: dict,
|
||||||
|
candidate_entities: list[dict],
|
||||||
|
exclude_ids: set[str] = None,
|
||||||
) -> AlignmentResult | None:
|
) -> AlignmentResult | None:
|
||||||
"""
|
"""
|
||||||
在候选实体中查找匹配的实体
|
在候选实体中查找匹配的实体
|
||||||
@@ -270,7 +275,10 @@ class MultimodalEntityLinker:
|
|||||||
return links
|
return links
|
||||||
|
|
||||||
def fuse_entity_knowledge(
|
def fuse_entity_knowledge(
|
||||||
self, entity_id: str, linked_entities: list[dict], multimodal_mentions: list[dict],
|
self,
|
||||||
|
entity_id: str,
|
||||||
|
linked_entities: list[dict],
|
||||||
|
multimodal_mentions: list[dict],
|
||||||
) -> FusionResult:
|
) -> FusionResult:
|
||||||
"""
|
"""
|
||||||
融合多模态实体知识
|
融合多模态实体知识
|
||||||
@@ -394,7 +402,9 @@ class MultimodalEntityLinker:
|
|||||||
return conflicts
|
return conflicts
|
||||||
|
|
||||||
def suggest_entity_merges(
|
def suggest_entity_merges(
|
||||||
self, entities: list[dict], existing_links: list[EntityLink] = None,
|
self,
|
||||||
|
entities: list[dict],
|
||||||
|
existing_links: list[EntityLink] = None,
|
||||||
) -> list[dict]:
|
) -> list[dict]:
|
||||||
"""
|
"""
|
||||||
建议实体合并
|
建议实体合并
|
||||||
@@ -510,9 +520,9 @@ class MultimodalEntityLinker:
|
|||||||
"total_multimodal_records": len(multimodal_entities),
|
"total_multimodal_records": len(multimodal_entities),
|
||||||
"unique_entities": len(entity_modalities),
|
"unique_entities": len(entity_modalities),
|
||||||
"cross_modal_entities": cross_modal_count,
|
"cross_modal_entities": cross_modal_count,
|
||||||
"cross_modal_ratio": cross_modal_count / len(entity_modalities)
|
"cross_modal_ratio": (
|
||||||
if entity_modalities
|
cross_modal_count / len(entity_modalities) if entity_modalities else 0
|
||||||
else 0,
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -74,7 +74,7 @@ class VideoInfo:
|
|||||||
transcript_id: str = ""
|
transcript_id: str = ""
|
||||||
status: str = "pending"
|
status: str = "pending"
|
||||||
error_message: str = ""
|
error_message: str = ""
|
||||||
metadata: dict = None
|
metadata: dict | None = None
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
if self.metadata is None:
|
if self.metadata is None:
|
||||||
@@ -97,7 +97,7 @@ class VideoProcessingResult:
|
|||||||
class MultimodalProcessor:
|
class MultimodalProcessor:
|
||||||
"""多模态处理器 - 处理视频文件"""
|
"""多模态处理器 - 处理视频文件"""
|
||||||
|
|
||||||
def __init__(self, temp_dir: str = None, frame_interval: int = 5) -> None:
|
def __init__(self, temp_dir: str | None = None, frame_interval: int = 5) -> None:
|
||||||
"""
|
"""
|
||||||
初始化多模态处理器
|
初始化多模态处理器
|
||||||
|
|
||||||
@@ -130,10 +130,12 @@ class MultimodalProcessor:
|
|||||||
if FFMPEG_AVAILABLE:
|
if FFMPEG_AVAILABLE:
|
||||||
probe = ffmpeg.probe(video_path)
|
probe = ffmpeg.probe(video_path)
|
||||||
video_stream = next(
|
video_stream = next(
|
||||||
(s for s in probe["streams"] if s["codec_type"] == "video"), None,
|
(s for s in probe["streams"] if s["codec_type"] == "video"),
|
||||||
|
None,
|
||||||
)
|
)
|
||||||
audio_stream = next(
|
audio_stream = next(
|
||||||
(s for s in probe["streams"] if s["codec_type"] == "audio"), None,
|
(s for s in probe["streams"] if s["codec_type"] == "audio"),
|
||||||
|
None,
|
||||||
)
|
)
|
||||||
|
|
||||||
if video_stream:
|
if video_stream:
|
||||||
@@ -165,9 +167,9 @@ class MultimodalProcessor:
|
|||||||
return {
|
return {
|
||||||
"duration": float(data["format"].get("duration", 0)),
|
"duration": float(data["format"].get("duration", 0)),
|
||||||
"width": int(data["streams"][0].get("width", 0)) if data["streams"] else 0,
|
"width": int(data["streams"][0].get("width", 0)) if data["streams"] else 0,
|
||||||
"height": int(data["streams"][0].get("height", 0))
|
"height": (
|
||||||
if data["streams"]
|
int(data["streams"][0].get("height", 0)) if data["streams"] else 0
|
||||||
else 0,
|
),
|
||||||
"fps": 30.0, # 默认值
|
"fps": 30.0, # 默认值
|
||||||
"has_audio": len(data["streams"]) > 1,
|
"has_audio": len(data["streams"]) > 1,
|
||||||
"bitrate": int(data["format"].get("bit_rate", 0)),
|
"bitrate": int(data["format"].get("bit_rate", 0)),
|
||||||
@@ -177,7 +179,7 @@ class MultimodalProcessor:
|
|||||||
|
|
||||||
return {"duration": 0, "width": 0, "height": 0, "fps": 0, "has_audio": False, "bitrate": 0}
|
return {"duration": 0, "width": 0, "height": 0, "fps": 0, "has_audio": False, "bitrate": 0}
|
||||||
|
|
||||||
def extract_audio(self, video_path: str, output_path: str = None) -> str:
|
def extract_audio(self, video_path: str, output_path: str | None = None) -> str:
|
||||||
"""
|
"""
|
||||||
从视频中提取音频
|
从视频中提取音频
|
||||||
|
|
||||||
@@ -223,7 +225,9 @@ class MultimodalProcessor:
|
|||||||
print(f"Error extracting audio: {e}")
|
print(f"Error extracting audio: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def extract_keyframes(self, video_path: str, video_id: str, interval: int = None) -> list[str]:
|
def extract_keyframes(
|
||||||
|
self, video_path: str, video_id: str, interval: int | None = None
|
||||||
|
) -> list[str]:
|
||||||
"""
|
"""
|
||||||
从视频中提取关键帧
|
从视频中提取关键帧
|
||||||
|
|
||||||
@@ -260,7 +264,8 @@ class MultimodalProcessor:
|
|||||||
if frame_number % frame_interval_frames == 0:
|
if frame_number % frame_interval_frames == 0:
|
||||||
timestamp = frame_number / fps
|
timestamp = frame_number / fps
|
||||||
frame_path = os.path.join(
|
frame_path = os.path.join(
|
||||||
video_frames_dir, f"frame_{frame_number:06d}_{timestamp:.2f}.jpg",
|
video_frames_dir,
|
||||||
|
f"frame_{frame_number:06d}_{timestamp:.2f}.jpg",
|
||||||
)
|
)
|
||||||
cv2.imwrite(frame_path, frame)
|
cv2.imwrite(frame_path, frame)
|
||||||
frame_paths.append(frame_path)
|
frame_paths.append(frame_path)
|
||||||
@@ -333,7 +338,11 @@ class MultimodalProcessor:
|
|||||||
return "", 0.0
|
return "", 0.0
|
||||||
|
|
||||||
def process_video(
|
def process_video(
|
||||||
self, video_data: bytes, filename: str, project_id: str, video_id: str = None,
|
self,
|
||||||
|
video_data: bytes,
|
||||||
|
filename: str,
|
||||||
|
project_id: str,
|
||||||
|
video_id: str | None = None,
|
||||||
) -> VideoProcessingResult:
|
) -> VideoProcessingResult:
|
||||||
"""
|
"""
|
||||||
处理视频文件:提取音频、关键帧、OCR
|
处理视频文件:提取音频、关键帧、OCR
|
||||||
@@ -426,7 +435,7 @@ class MultimodalProcessor:
|
|||||||
error_message=str(e),
|
error_message=str(e),
|
||||||
)
|
)
|
||||||
|
|
||||||
def cleanup(self, video_id: str = None) -> None:
|
def cleanup(self, video_id: str | None = None) -> None:
|
||||||
"""
|
"""
|
||||||
清理临时文件
|
清理临时文件
|
||||||
|
|
||||||
@@ -457,7 +466,9 @@ class MultimodalProcessor:
|
|||||||
_multimodal_processor = None
|
_multimodal_processor = None
|
||||||
|
|
||||||
|
|
||||||
def get_multimodal_processor(temp_dir: str = None, frame_interval: int = 5) -> MultimodalProcessor:
|
def get_multimodal_processor(
|
||||||
|
temp_dir: str | None = None, frame_interval: int = 5
|
||||||
|
) -> MultimodalProcessor:
|
||||||
"""获取多模态处理器单例"""
|
"""获取多模态处理器单例"""
|
||||||
global _multimodal_processor
|
global _multimodal_processor
|
||||||
if _multimodal_processor is None:
|
if _multimodal_processor is None:
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ class GraphEntity:
|
|||||||
type: str
|
type: str
|
||||||
definition: str = ""
|
definition: str = ""
|
||||||
aliases: list[str] = None
|
aliases: list[str] = None
|
||||||
properties: dict = None
|
properties: dict | None = None
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
if self.aliases is None:
|
if self.aliases is None:
|
||||||
@@ -55,7 +55,7 @@ class GraphRelation:
|
|||||||
target_id: str
|
target_id: str
|
||||||
relation_type: str
|
relation_type: str
|
||||||
evidence: str = ""
|
evidence: str = ""
|
||||||
properties: dict = None
|
properties: dict | None = None
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
if self.properties is None:
|
if self.properties is None:
|
||||||
@@ -95,7 +95,7 @@ class CentralityResult:
|
|||||||
class Neo4jManager:
|
class Neo4jManager:
|
||||||
"""Neo4j 图数据库管理器"""
|
"""Neo4j 图数据库管理器"""
|
||||||
|
|
||||||
def __init__(self, uri: str = None, user: str = None, password: str = None) -> None:
|
def __init__(self, uri: str | None = None, user: str = None, password: str = None) -> None:
|
||||||
self.uri = uri or NEO4J_URI
|
self.uri = uri or NEO4J_URI
|
||||||
self.user = user or NEO4J_USER
|
self.user = user or NEO4J_USER
|
||||||
self.password = password or NEO4J_PASSWORD
|
self.password = password or NEO4J_PASSWORD
|
||||||
@@ -179,7 +179,10 @@ class Neo4jManager:
|
|||||||
# ==================== 数据同步 ====================
|
# ==================== 数据同步 ====================
|
||||||
|
|
||||||
def sync_project(
|
def sync_project(
|
||||||
self, project_id: str, project_name: str, project_description: str = "",
|
self,
|
||||||
|
project_id: str,
|
||||||
|
project_name: str,
|
||||||
|
project_description: str = "",
|
||||||
) -> None:
|
) -> None:
|
||||||
"""同步项目节点到 Neo4j"""
|
"""同步项目节点到 Neo4j"""
|
||||||
if not self._driver:
|
if not self._driver:
|
||||||
@@ -352,7 +355,10 @@ class Neo4jManager:
|
|||||||
# ==================== 复杂图查询 ====================
|
# ==================== 复杂图查询 ====================
|
||||||
|
|
||||||
def find_shortest_path(
|
def find_shortest_path(
|
||||||
self, source_id: str, target_id: str, max_depth: int = 10,
|
self,
|
||||||
|
source_id: str,
|
||||||
|
target_id: str,
|
||||||
|
max_depth: int = 10,
|
||||||
) -> PathResult | None:
|
) -> PathResult | None:
|
||||||
"""
|
"""
|
||||||
查找两个实体之间的最短路径
|
查找两个实体之间的最短路径
|
||||||
@@ -404,11 +410,17 @@ class Neo4jManager:
|
|||||||
]
|
]
|
||||||
|
|
||||||
return PathResult(
|
return PathResult(
|
||||||
nodes=nodes, relationships=relationships, length=len(path.relationships),
|
nodes=nodes,
|
||||||
|
relationships=relationships,
|
||||||
|
length=len(path.relationships),
|
||||||
)
|
)
|
||||||
|
|
||||||
def find_all_paths(
|
def find_all_paths(
|
||||||
self, source_id: str, target_id: str, max_depth: int = 5, limit: int = 10,
|
self,
|
||||||
|
source_id: str,
|
||||||
|
target_id: str,
|
||||||
|
max_depth: int = 5,
|
||||||
|
limit: int = 10,
|
||||||
) -> list[PathResult]:
|
) -> list[PathResult]:
|
||||||
"""
|
"""
|
||||||
查找两个实体之间的所有路径
|
查找两个实体之间的所有路径
|
||||||
@@ -460,14 +472,19 @@ class Neo4jManager:
|
|||||||
|
|
||||||
paths.append(
|
paths.append(
|
||||||
PathResult(
|
PathResult(
|
||||||
nodes=nodes, relationships=relationships, length=len(path.relationships),
|
nodes=nodes,
|
||||||
|
relationships=relationships,
|
||||||
|
length=len(path.relationships),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
return paths
|
return paths
|
||||||
|
|
||||||
def find_neighbors(
|
def find_neighbors(
|
||||||
self, entity_id: str, relation_type: str = None, limit: int = 50,
|
self,
|
||||||
|
entity_id: str,
|
||||||
|
relation_type: str | None = None,
|
||||||
|
limit: int = 50,
|
||||||
) -> list[dict]:
|
) -> list[dict]:
|
||||||
"""
|
"""
|
||||||
查找实体的邻居节点
|
查找实体的邻居节点
|
||||||
@@ -752,7 +769,10 @@ class Neo4jManager:
|
|||||||
|
|
||||||
results.append(
|
results.append(
|
||||||
CommunityResult(
|
CommunityResult(
|
||||||
community_id=comm_id, nodes=nodes, size=size, density=min(density, 1.0),
|
community_id=comm_id,
|
||||||
|
nodes=nodes,
|
||||||
|
size=size,
|
||||||
|
density=min(density, 1.0),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -761,7 +781,9 @@ class Neo4jManager:
|
|||||||
return results
|
return results
|
||||||
|
|
||||||
def find_central_entities(
|
def find_central_entities(
|
||||||
self, project_id: str, metric: str = "degree",
|
self,
|
||||||
|
project_id: str,
|
||||||
|
metric: str = "degree",
|
||||||
) -> list[CentralityResult]:
|
) -> list[CentralityResult]:
|
||||||
"""
|
"""
|
||||||
查找中心实体
|
查找中心实体
|
||||||
@@ -896,9 +918,11 @@ class Neo4jManager:
|
|||||||
"type_distribution": types,
|
"type_distribution": types,
|
||||||
"average_degree": round(avg_degree, 2) if avg_degree else 0,
|
"average_degree": round(avg_degree, 2) if avg_degree else 0,
|
||||||
"relation_type_distribution": relation_types,
|
"relation_type_distribution": relation_types,
|
||||||
"density": round(relation_count / (entity_count * (entity_count - 1)), 4)
|
"density": (
|
||||||
if entity_count > 1
|
round(relation_count / (entity_count * (entity_count - 1)), 4)
|
||||||
else 0,
|
if entity_count > 1
|
||||||
|
else 0
|
||||||
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_subgraph(self, entity_ids: list[str], depth: int = 1) -> dict:
|
def get_subgraph(self, entity_ids: list[str], depth: int = 1) -> dict:
|
||||||
@@ -993,7 +1017,10 @@ def close_neo4j_manager() -> None:
|
|||||||
|
|
||||||
|
|
||||||
def sync_project_to_neo4j(
|
def sync_project_to_neo4j(
|
||||||
project_id: str, project_name: str, entities: list[dict], relations: list[dict],
|
project_id: str,
|
||||||
|
project_name: str,
|
||||||
|
entities: list[dict],
|
||||||
|
relations: list[dict],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
同步整个项目到 Neo4j
|
同步整个项目到 Neo4j
|
||||||
|
|||||||
@@ -680,7 +680,8 @@ class OpsManager:
|
|||||||
"""获取告警渠道"""
|
"""获取告警渠道"""
|
||||||
with self._get_db() as conn:
|
with self._get_db() as conn:
|
||||||
row = conn.execute(
|
row = conn.execute(
|
||||||
"SELECT * FROM alert_channels WHERE id = ?", (channel_id,),
|
"SELECT * FROM alert_channels WHERE id = ?",
|
||||||
|
(channel_id,),
|
||||||
).fetchone()
|
).fetchone()
|
||||||
|
|
||||||
if row:
|
if row:
|
||||||
@@ -819,7 +820,9 @@ class OpsManager:
|
|||||||
for rule in rules:
|
for rule in rules:
|
||||||
# 获取相关指标
|
# 获取相关指标
|
||||||
metrics = self.get_recent_metrics(
|
metrics = self.get_recent_metrics(
|
||||||
tenant_id, rule.metric, seconds=rule.duration + rule.evaluation_interval,
|
tenant_id,
|
||||||
|
rule.metric,
|
||||||
|
seconds=rule.duration + rule.evaluation_interval,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 评估规则
|
# 评估规则
|
||||||
@@ -1129,7 +1132,9 @@ class OpsManager:
|
|||||||
|
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
response = await client.post(
|
response = await client.post(
|
||||||
"https://events.pagerduty.com/v2/enqueue", json=message, timeout=30.0,
|
"https://events.pagerduty.com/v2/enqueue",
|
||||||
|
json=message,
|
||||||
|
timeout=30.0,
|
||||||
)
|
)
|
||||||
success = response.status_code == 202
|
success = response.status_code == 202
|
||||||
self._update_channel_stats(channel.id, success)
|
self._update_channel_stats(channel.id, success)
|
||||||
@@ -1299,12 +1304,16 @@ class OpsManager:
|
|||||||
conn.commit()
|
conn.commit()
|
||||||
|
|
||||||
def _update_alert_notification_status(
|
def _update_alert_notification_status(
|
||||||
self, alert_id: str, channel_id: str, success: bool,
|
self,
|
||||||
|
alert_id: str,
|
||||||
|
channel_id: str,
|
||||||
|
success: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""更新告警通知状态"""
|
"""更新告警通知状态"""
|
||||||
with self._get_db() as conn:
|
with self._get_db() as conn:
|
||||||
row = conn.execute(
|
row = conn.execute(
|
||||||
"SELECT notification_sent FROM alerts WHERE id = ?", (alert_id,),
|
"SELECT notification_sent FROM alerts WHERE id = ?",
|
||||||
|
(alert_id,),
|
||||||
).fetchone()
|
).fetchone()
|
||||||
|
|
||||||
if row:
|
if row:
|
||||||
@@ -1394,7 +1403,8 @@ class OpsManager:
|
|||||||
"""检查告警是否被抑制"""
|
"""检查告警是否被抑制"""
|
||||||
with self._get_db() as conn:
|
with self._get_db() as conn:
|
||||||
rows = conn.execute(
|
rows = conn.execute(
|
||||||
"SELECT * FROM alert_suppression_rules WHERE tenant_id = ?", (rule.tenant_id,),
|
"SELECT * FROM alert_suppression_rules WHERE tenant_id = ?",
|
||||||
|
(rule.tenant_id,),
|
||||||
).fetchall()
|
).fetchall()
|
||||||
|
|
||||||
for row in rows:
|
for row in rows:
|
||||||
@@ -1436,7 +1446,7 @@ class OpsManager:
|
|||||||
metric_name: str,
|
metric_name: str,
|
||||||
metric_value: float,
|
metric_value: float,
|
||||||
unit: str,
|
unit: str,
|
||||||
metadata: dict = None,
|
metadata: dict | None = None,
|
||||||
) -> ResourceMetric:
|
) -> ResourceMetric:
|
||||||
"""记录资源指标"""
|
"""记录资源指标"""
|
||||||
metric_id = f"rm_{uuid.uuid4().hex[:16]}"
|
metric_id = f"rm_{uuid.uuid4().hex[:16]}"
|
||||||
@@ -1479,7 +1489,10 @@ class OpsManager:
|
|||||||
return metric
|
return metric
|
||||||
|
|
||||||
def get_recent_metrics(
|
def get_recent_metrics(
|
||||||
self, tenant_id: str, metric_name: str, seconds: int = 3600,
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
metric_name: str,
|
||||||
|
seconds: int = 3600,
|
||||||
) -> list[ResourceMetric]:
|
) -> list[ResourceMetric]:
|
||||||
"""获取最近的指标数据"""
|
"""获取最近的指标数据"""
|
||||||
cutoff_time = (datetime.now() - timedelta(seconds=seconds)).isoformat()
|
cutoff_time = (datetime.now() - timedelta(seconds=seconds)).isoformat()
|
||||||
@@ -1531,7 +1544,9 @@ class OpsManager:
|
|||||||
|
|
||||||
# 基于历史数据预测
|
# 基于历史数据预测
|
||||||
metrics = self.get_recent_metrics(
|
metrics = self.get_recent_metrics(
|
||||||
tenant_id, f"{resource_type.value}_usage", seconds=30 * 24 * 3600,
|
tenant_id,
|
||||||
|
f"{resource_type.value}_usage",
|
||||||
|
seconds=30 * 24 * 3600,
|
||||||
)
|
)
|
||||||
|
|
||||||
if metrics:
|
if metrics:
|
||||||
@@ -1704,7 +1719,8 @@ class OpsManager:
|
|||||||
"""获取自动扩缩容策略"""
|
"""获取自动扩缩容策略"""
|
||||||
with self._get_db() as conn:
|
with self._get_db() as conn:
|
||||||
row = conn.execute(
|
row = conn.execute(
|
||||||
"SELECT * FROM auto_scaling_policies WHERE id = ?", (policy_id,),
|
"SELECT * FROM auto_scaling_policies WHERE id = ?",
|
||||||
|
(policy_id,),
|
||||||
).fetchone()
|
).fetchone()
|
||||||
|
|
||||||
if row:
|
if row:
|
||||||
@@ -1721,7 +1737,10 @@ class OpsManager:
|
|||||||
return [self._row_to_auto_scaling_policy(row) for row in rows]
|
return [self._row_to_auto_scaling_policy(row) for row in rows]
|
||||||
|
|
||||||
def evaluate_scaling_policy(
|
def evaluate_scaling_policy(
|
||||||
self, policy_id: str, current_instances: int, current_utilization: float,
|
self,
|
||||||
|
policy_id: str,
|
||||||
|
current_instances: int,
|
||||||
|
current_utilization: float,
|
||||||
) -> ScalingEvent | None:
|
) -> ScalingEvent | None:
|
||||||
"""评估扩缩容策略"""
|
"""评估扩缩容策略"""
|
||||||
policy = self.get_auto_scaling_policy(policy_id)
|
policy = self.get_auto_scaling_policy(policy_id)
|
||||||
@@ -1826,7 +1845,10 @@ class OpsManager:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def update_scaling_event_status(
|
def update_scaling_event_status(
|
||||||
self, event_id: str, status: str, error_message: str = None,
|
self,
|
||||||
|
event_id: str,
|
||||||
|
status: str,
|
||||||
|
error_message: str | None = None,
|
||||||
) -> ScalingEvent | None:
|
) -> ScalingEvent | None:
|
||||||
"""更新扩缩容事件状态"""
|
"""更新扩缩容事件状态"""
|
||||||
now = datetime.now().isoformat()
|
now = datetime.now().isoformat()
|
||||||
@@ -1864,7 +1886,10 @@ class OpsManager:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def list_scaling_events(
|
def list_scaling_events(
|
||||||
self, tenant_id: str, policy_id: str = None, limit: int = 100,
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
policy_id: str | None = None,
|
||||||
|
limit: int = 100,
|
||||||
) -> list[ScalingEvent]:
|
) -> list[ScalingEvent]:
|
||||||
"""列出租户的扩缩容事件"""
|
"""列出租户的扩缩容事件"""
|
||||||
query = "SELECT * FROM scaling_events WHERE tenant_id = ?"
|
query = "SELECT * FROM scaling_events WHERE tenant_id = ?"
|
||||||
@@ -2056,7 +2081,8 @@ class OpsManager:
|
|||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
try:
|
try:
|
||||||
reader, writer = await asyncio.wait_for(
|
reader, writer = await asyncio.wait_for(
|
||||||
asyncio.open_connection(host, port), timeout=check.timeout,
|
asyncio.open_connection(host, port),
|
||||||
|
timeout=check.timeout,
|
||||||
)
|
)
|
||||||
response_time = (time.time() - start_time) * 1000
|
response_time = (time.time() - start_time) * 1000
|
||||||
writer.close()
|
writer.close()
|
||||||
@@ -2101,7 +2127,7 @@ class OpsManager:
|
|||||||
failover_trigger: str,
|
failover_trigger: str,
|
||||||
auto_failover: bool = False,
|
auto_failover: bool = False,
|
||||||
failover_timeout: int = 300,
|
failover_timeout: int = 300,
|
||||||
health_check_id: str = None,
|
health_check_id: str | None = None,
|
||||||
) -> FailoverConfig:
|
) -> FailoverConfig:
|
||||||
"""创建故障转移配置"""
|
"""创建故障转移配置"""
|
||||||
config_id = f"fc_{uuid.uuid4().hex[:16]}"
|
config_id = f"fc_{uuid.uuid4().hex[:16]}"
|
||||||
@@ -2153,7 +2179,8 @@ class OpsManager:
|
|||||||
"""获取故障转移配置"""
|
"""获取故障转移配置"""
|
||||||
with self._get_db() as conn:
|
with self._get_db() as conn:
|
||||||
row = conn.execute(
|
row = conn.execute(
|
||||||
"SELECT * FROM failover_configs WHERE id = ?", (config_id,),
|
"SELECT * FROM failover_configs WHERE id = ?",
|
||||||
|
(config_id,),
|
||||||
).fetchone()
|
).fetchone()
|
||||||
|
|
||||||
if row:
|
if row:
|
||||||
@@ -2259,7 +2286,8 @@ class OpsManager:
|
|||||||
"""获取故障转移事件"""
|
"""获取故障转移事件"""
|
||||||
with self._get_db() as conn:
|
with self._get_db() as conn:
|
||||||
row = conn.execute(
|
row = conn.execute(
|
||||||
"SELECT * FROM failover_events WHERE id = ?", (event_id,),
|
"SELECT * FROM failover_events WHERE id = ?",
|
||||||
|
(event_id,),
|
||||||
).fetchone()
|
).fetchone()
|
||||||
|
|
||||||
if row:
|
if row:
|
||||||
@@ -2290,7 +2318,7 @@ class OpsManager:
|
|||||||
retention_days: int = 30,
|
retention_days: int = 30,
|
||||||
encryption_enabled: bool = True,
|
encryption_enabled: bool = True,
|
||||||
compression_enabled: bool = True,
|
compression_enabled: bool = True,
|
||||||
storage_location: str = None,
|
storage_location: str | None = None,
|
||||||
) -> BackupJob:
|
) -> BackupJob:
|
||||||
"""创建备份任务"""
|
"""创建备份任务"""
|
||||||
job_id = f"bj_{uuid.uuid4().hex[:16]}"
|
job_id = f"bj_{uuid.uuid4().hex[:16]}"
|
||||||
@@ -2410,7 +2438,9 @@ class OpsManager:
|
|||||||
|
|
||||||
return record
|
return record
|
||||||
|
|
||||||
def _complete_backup(self, record_id: str, size_bytes: int, checksum: str = None) -> None:
|
def _complete_backup(
|
||||||
|
self, record_id: str, size_bytes: int, checksum: str | None = None
|
||||||
|
) -> None:
|
||||||
"""完成备份"""
|
"""完成备份"""
|
||||||
now = datetime.now().isoformat()
|
now = datetime.now().isoformat()
|
||||||
checksum = checksum or hashlib.sha256(str(time.time()).encode()).hexdigest()[:16]
|
checksum = checksum or hashlib.sha256(str(time.time()).encode()).hexdigest()[:16]
|
||||||
@@ -2430,7 +2460,8 @@ class OpsManager:
|
|||||||
"""获取备份记录"""
|
"""获取备份记录"""
|
||||||
with self._get_db() as conn:
|
with self._get_db() as conn:
|
||||||
row = conn.execute(
|
row = conn.execute(
|
||||||
"SELECT * FROM backup_records WHERE id = ?", (record_id,),
|
"SELECT * FROM backup_records WHERE id = ?",
|
||||||
|
(record_id,),
|
||||||
).fetchone()
|
).fetchone()
|
||||||
|
|
||||||
if row:
|
if row:
|
||||||
@@ -2438,7 +2469,10 @@ class OpsManager:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def list_backup_records(
|
def list_backup_records(
|
||||||
self, tenant_id: str, job_id: str = None, limit: int = 100,
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
job_id: str | None = None,
|
||||||
|
limit: int = 100,
|
||||||
) -> list[BackupRecord]:
|
) -> list[BackupRecord]:
|
||||||
"""列出租户的备份记录"""
|
"""列出租户的备份记录"""
|
||||||
query = "SELECT * FROM backup_records WHERE tenant_id = ?"
|
query = "SELECT * FROM backup_records WHERE tenant_id = ?"
|
||||||
@@ -2624,7 +2658,9 @@ class OpsManager:
|
|||||||
return util
|
return util
|
||||||
|
|
||||||
def get_resource_utilizations(
|
def get_resource_utilizations(
|
||||||
self, tenant_id: str, report_period: str,
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
report_period: str,
|
||||||
) -> list[ResourceUtilization]:
|
) -> list[ResourceUtilization]:
|
||||||
"""获取资源利用率列表"""
|
"""获取资源利用率列表"""
|
||||||
with self._get_db() as conn:
|
with self._get_db() as conn:
|
||||||
@@ -2709,7 +2745,8 @@ class OpsManager:
|
|||||||
return [self._row_to_idle_resource(row) for row in rows]
|
return [self._row_to_idle_resource(row) for row in rows]
|
||||||
|
|
||||||
def generate_cost_optimization_suggestions(
|
def generate_cost_optimization_suggestions(
|
||||||
self, tenant_id: str,
|
self,
|
||||||
|
tenant_id: str,
|
||||||
) -> list[CostOptimizationSuggestion]:
|
) -> list[CostOptimizationSuggestion]:
|
||||||
"""生成成本优化建议"""
|
"""生成成本优化建议"""
|
||||||
suggestions = []
|
suggestions = []
|
||||||
@@ -2777,7 +2814,9 @@ class OpsManager:
|
|||||||
return suggestions
|
return suggestions
|
||||||
|
|
||||||
def get_cost_optimization_suggestions(
|
def get_cost_optimization_suggestions(
|
||||||
self, tenant_id: str, is_applied: bool = None,
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
is_applied: bool | None = None,
|
||||||
) -> list[CostOptimizationSuggestion]:
|
) -> list[CostOptimizationSuggestion]:
|
||||||
"""获取成本优化建议"""
|
"""获取成本优化建议"""
|
||||||
query = "SELECT * FROM cost_optimization_suggestions WHERE tenant_id = ?"
|
query = "SELECT * FROM cost_optimization_suggestions WHERE tenant_id = ?"
|
||||||
@@ -2794,7 +2833,8 @@ class OpsManager:
|
|||||||
return [self._row_to_cost_optimization_suggestion(row) for row in rows]
|
return [self._row_to_cost_optimization_suggestion(row) for row in rows]
|
||||||
|
|
||||||
def apply_cost_optimization_suggestion(
|
def apply_cost_optimization_suggestion(
|
||||||
self, suggestion_id: str,
|
self,
|
||||||
|
suggestion_id: str,
|
||||||
) -> CostOptimizationSuggestion | None:
|
) -> CostOptimizationSuggestion | None:
|
||||||
"""应用成本优化建议"""
|
"""应用成本优化建议"""
|
||||||
now = datetime.now().isoformat()
|
now = datetime.now().isoformat()
|
||||||
@@ -2813,12 +2853,14 @@ class OpsManager:
|
|||||||
return self.get_cost_optimization_suggestion(suggestion_id)
|
return self.get_cost_optimization_suggestion(suggestion_id)
|
||||||
|
|
||||||
def get_cost_optimization_suggestion(
|
def get_cost_optimization_suggestion(
|
||||||
self, suggestion_id: str,
|
self,
|
||||||
|
suggestion_id: str,
|
||||||
) -> CostOptimizationSuggestion | None:
|
) -> CostOptimizationSuggestion | None:
|
||||||
"""获取成本优化建议详情"""
|
"""获取成本优化建议详情"""
|
||||||
with self._get_db() as conn:
|
with self._get_db() as conn:
|
||||||
row = conn.execute(
|
row = conn.execute(
|
||||||
"SELECT * FROM cost_optimization_suggestions WHERE id = ?", (suggestion_id,),
|
"SELECT * FROM cost_optimization_suggestions WHERE id = ?",
|
||||||
|
(suggestion_id,),
|
||||||
).fetchone()
|
).fetchone()
|
||||||
|
|
||||||
if row:
|
if row:
|
||||||
|
|||||||
@@ -444,7 +444,8 @@ class CacheManager:
|
|||||||
"memory_size_bytes": self.current_memory_size,
|
"memory_size_bytes": self.current_memory_size,
|
||||||
"max_memory_size_bytes": self.max_memory_size,
|
"max_memory_size_bytes": self.max_memory_size,
|
||||||
"memory_usage_percent": round(
|
"memory_usage_percent": round(
|
||||||
self.current_memory_size / self.max_memory_size * 100, 2,
|
self.current_memory_size / self.max_memory_size * 100,
|
||||||
|
2,
|
||||||
),
|
),
|
||||||
"cache_entries": len(self.memory_cache),
|
"cache_entries": len(self.memory_cache),
|
||||||
},
|
},
|
||||||
@@ -548,11 +549,13 @@ class CacheManager:
|
|||||||
|
|
||||||
# 预热项目知识库摘要
|
# 预热项目知识库摘要
|
||||||
entity_count = conn.execute(
|
entity_count = conn.execute(
|
||||||
"SELECT COUNT(*) FROM entities WHERE project_id = ?", (project_id,),
|
"SELECT COUNT(*) FROM entities WHERE project_id = ?",
|
||||||
|
(project_id,),
|
||||||
).fetchone()[0]
|
).fetchone()[0]
|
||||||
|
|
||||||
relation_count = conn.execute(
|
relation_count = conn.execute(
|
||||||
"SELECT COUNT(*) FROM entity_relations WHERE project_id = ?", (project_id,),
|
"SELECT COUNT(*) FROM entity_relations WHERE project_id = ?",
|
||||||
|
(project_id,),
|
||||||
).fetchone()[0]
|
).fetchone()[0]
|
||||||
|
|
||||||
summary = {
|
summary = {
|
||||||
@@ -757,11 +760,13 @@ class DatabaseSharding:
|
|||||||
source_conn.row_factory = sqlite3.Row
|
source_conn.row_factory = sqlite3.Row
|
||||||
|
|
||||||
entities = source_conn.execute(
|
entities = source_conn.execute(
|
||||||
"SELECT * FROM entities WHERE project_id = ?", (project_id,),
|
"SELECT * FROM entities WHERE project_id = ?",
|
||||||
|
(project_id,),
|
||||||
).fetchall()
|
).fetchall()
|
||||||
|
|
||||||
relations = source_conn.execute(
|
relations = source_conn.execute(
|
||||||
"SELECT * FROM entity_relations WHERE project_id = ?", (project_id,),
|
"SELECT * FROM entity_relations WHERE project_id = ?",
|
||||||
|
(project_id,),
|
||||||
).fetchall()
|
).fetchall()
|
||||||
|
|
||||||
source_conn.close()
|
source_conn.close()
|
||||||
@@ -1061,7 +1066,9 @@ class TaskQueue:
|
|||||||
task.status = "retrying"
|
task.status = "retrying"
|
||||||
# 延迟重试
|
# 延迟重试
|
||||||
threading.Timer(
|
threading.Timer(
|
||||||
10 * task.retry_count, self._execute_task, args=(task_id,),
|
10 * task.retry_count,
|
||||||
|
self._execute_task,
|
||||||
|
args=(task_id,),
|
||||||
).start()
|
).start()
|
||||||
else:
|
else:
|
||||||
task.status = "failed"
|
task.status = "failed"
|
||||||
@@ -1163,7 +1170,10 @@ class TaskQueue:
|
|||||||
return self.tasks.get(task_id)
|
return self.tasks.get(task_id)
|
||||||
|
|
||||||
def list_tasks(
|
def list_tasks(
|
||||||
self, status: str | None = None, task_type: str | None = None, limit: int = 100,
|
self,
|
||||||
|
status: str | None = None,
|
||||||
|
task_type: str | None = None,
|
||||||
|
limit: int = 100,
|
||||||
) -> list[TaskInfo]:
|
) -> list[TaskInfo]:
|
||||||
"""列出任务"""
|
"""列出任务"""
|
||||||
conn = sqlite3.connect(self.db_path)
|
conn = sqlite3.connect(self.db_path)
|
||||||
@@ -1635,7 +1645,7 @@ def cached(
|
|||||||
cache_key = key_func(*args, **kwargs)
|
cache_key = key_func(*args, **kwargs)
|
||||||
else:
|
else:
|
||||||
# 默认使用函数名和参数哈希
|
# 默认使用函数名和参数哈希
|
||||||
key_data = f"{func.__name__}:{str(args)}:{str(kwargs)}"
|
key_data = f"{func.__name__}:{args!s}:{kwargs!s}"
|
||||||
cache_key = f"{key_prefix}:{hashlib.md5(key_data.encode()).hexdigest()[:16]}"
|
cache_key = f"{key_prefix}:{hashlib.md5(key_data.encode()).hexdigest()[:16]}"
|
||||||
|
|
||||||
# 尝试从缓存获取
|
# 尝试从缓存获取
|
||||||
@@ -1754,12 +1764,16 @@ _performance_manager = None
|
|||||||
|
|
||||||
|
|
||||||
def get_performance_manager(
|
def get_performance_manager(
|
||||||
db_path: str = "insightflow.db", redis_url: str | None = None, enable_sharding: bool = False,
|
db_path: str = "insightflow.db",
|
||||||
|
redis_url: str | None = None,
|
||||||
|
enable_sharding: bool = False,
|
||||||
) -> PerformanceManager:
|
) -> PerformanceManager:
|
||||||
"""获取性能管理器单例"""
|
"""获取性能管理器单例"""
|
||||||
global _performance_manager
|
global _performance_manager
|
||||||
if _performance_manager is None:
|
if _performance_manager is None:
|
||||||
_performance_manager = PerformanceManager(
|
_performance_manager = PerformanceManager(
|
||||||
db_path=db_path, redis_url=redis_url, enable_sharding=enable_sharding,
|
db_path=db_path,
|
||||||
|
redis_url=redis_url,
|
||||||
|
enable_sharding=enable_sharding,
|
||||||
)
|
)
|
||||||
return _performance_manager
|
return _performance_manager
|
||||||
|
|||||||
@@ -220,7 +220,10 @@ class PluginManager:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def list_plugins(
|
def list_plugins(
|
||||||
self, project_id: str = None, plugin_type: str = None, status: str = None,
|
self,
|
||||||
|
project_id: str | None = None,
|
||||||
|
plugin_type: str = None,
|
||||||
|
status: str = None,
|
||||||
) -> list[Plugin]:
|
) -> list[Plugin]:
|
||||||
"""列出插件"""
|
"""列出插件"""
|
||||||
conn = self.db.get_conn()
|
conn = self.db.get_conn()
|
||||||
@@ -241,7 +244,8 @@ class PluginManager:
|
|||||||
where_clause = " AND ".join(conditions) if conditions else "1 = 1"
|
where_clause = " AND ".join(conditions) if conditions else "1 = 1"
|
||||||
|
|
||||||
rows = conn.execute(
|
rows = conn.execute(
|
||||||
f"SELECT * FROM plugins WHERE {where_clause} ORDER BY created_at DESC", params,
|
f"SELECT * FROM plugins WHERE {where_clause} ORDER BY created_at DESC",
|
||||||
|
params,
|
||||||
).fetchall()
|
).fetchall()
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
@@ -310,7 +314,11 @@ class PluginManager:
|
|||||||
# ==================== Plugin Config ====================
|
# ==================== Plugin Config ====================
|
||||||
|
|
||||||
def set_plugin_config(
|
def set_plugin_config(
|
||||||
self, plugin_id: str, key: str, value: str, is_encrypted: bool = False,
|
self,
|
||||||
|
plugin_id: str,
|
||||||
|
key: str,
|
||||||
|
value: str,
|
||||||
|
is_encrypted: bool = False,
|
||||||
) -> PluginConfig:
|
) -> PluginConfig:
|
||||||
"""设置插件配置"""
|
"""设置插件配置"""
|
||||||
conn = self.db.get_conn()
|
conn = self.db.get_conn()
|
||||||
@@ -367,7 +375,8 @@ class PluginManager:
|
|||||||
"""获取插件所有配置"""
|
"""获取插件所有配置"""
|
||||||
conn = self.db.get_conn()
|
conn = self.db.get_conn()
|
||||||
rows = conn.execute(
|
rows = conn.execute(
|
||||||
"SELECT config_key, config_value FROM plugin_configs WHERE plugin_id = ?", (plugin_id,),
|
"SELECT config_key, config_value FROM plugin_configs WHERE plugin_id = ?",
|
||||||
|
(plugin_id,),
|
||||||
).fetchall()
|
).fetchall()
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
@@ -377,7 +386,8 @@ class PluginManager:
|
|||||||
"""删除插件配置"""
|
"""删除插件配置"""
|
||||||
conn = self.db.get_conn()
|
conn = self.db.get_conn()
|
||||||
cursor = conn.execute(
|
cursor = conn.execute(
|
||||||
"DELETE FROM plugin_configs WHERE plugin_id = ? AND config_key = ?", (plugin_id, key),
|
"DELETE FROM plugin_configs WHERE plugin_id = ? AND config_key = ?",
|
||||||
|
(plugin_id, key),
|
||||||
)
|
)
|
||||||
conn.commit()
|
conn.commit()
|
||||||
conn.close()
|
conn.close()
|
||||||
@@ -408,10 +418,10 @@ class ChromeExtensionHandler:
|
|||||||
def create_token(
|
def create_token(
|
||||||
self,
|
self,
|
||||||
name: str,
|
name: str,
|
||||||
user_id: str = None,
|
user_id: str | None = None,
|
||||||
project_id: str = None,
|
project_id: str | None = None,
|
||||||
permissions: list[str] = None,
|
permissions: list[str] = None,
|
||||||
expires_days: int = None,
|
expires_days: int | None = None,
|
||||||
) -> ChromeExtensionToken:
|
) -> ChromeExtensionToken:
|
||||||
"""创建 Chrome 扩展令牌"""
|
"""创建 Chrome 扩展令牌"""
|
||||||
token_id = str(uuid.uuid4())[:UUID_LENGTH]
|
token_id = str(uuid.uuid4())[:UUID_LENGTH]
|
||||||
@@ -512,7 +522,8 @@ class ChromeExtensionHandler:
|
|||||||
"""撤销令牌"""
|
"""撤销令牌"""
|
||||||
conn = self.pm.db.get_conn()
|
conn = self.pm.db.get_conn()
|
||||||
cursor = conn.execute(
|
cursor = conn.execute(
|
||||||
"UPDATE chrome_extension_tokens SET is_revoked = 1 WHERE id = ?", (token_id,),
|
"UPDATE chrome_extension_tokens SET is_revoked = 1 WHERE id = ?",
|
||||||
|
(token_id,),
|
||||||
)
|
)
|
||||||
conn.commit()
|
conn.commit()
|
||||||
conn.close()
|
conn.close()
|
||||||
@@ -520,7 +531,9 @@ class ChromeExtensionHandler:
|
|||||||
return cursor.rowcount > 0
|
return cursor.rowcount > 0
|
||||||
|
|
||||||
def list_tokens(
|
def list_tokens(
|
||||||
self, user_id: str = None, project_id: str = None,
|
self,
|
||||||
|
user_id: str | None = None,
|
||||||
|
project_id: str = None,
|
||||||
) -> list[ChromeExtensionToken]:
|
) -> list[ChromeExtensionToken]:
|
||||||
"""列出令牌"""
|
"""列出令牌"""
|
||||||
conn = self.pm.db.get_conn()
|
conn = self.pm.db.get_conn()
|
||||||
@@ -569,7 +582,7 @@ class ChromeExtensionHandler:
|
|||||||
url: str,
|
url: str,
|
||||||
title: str,
|
title: str,
|
||||||
content: str,
|
content: str,
|
||||||
html_content: str = None,
|
html_content: str | None = None,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""导入网页内容"""
|
"""导入网页内容"""
|
||||||
if not token.project_id:
|
if not token.project_id:
|
||||||
@@ -616,7 +629,7 @@ class BotHandler:
|
|||||||
self,
|
self,
|
||||||
session_id: str,
|
session_id: str,
|
||||||
session_name: str,
|
session_name: str,
|
||||||
project_id: str = None,
|
project_id: str | None = None,
|
||||||
webhook_url: str = "",
|
webhook_url: str = "",
|
||||||
secret: str = "",
|
secret: str = "",
|
||||||
) -> BotSession:
|
) -> BotSession:
|
||||||
@@ -674,7 +687,7 @@ class BotHandler:
|
|||||||
return self._row_to_session(row)
|
return self._row_to_session(row)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def list_sessions(self, project_id: str = None) -> list[BotSession]:
|
def list_sessions(self, project_id: str | None = None) -> list[BotSession]:
|
||||||
"""列出会话"""
|
"""列出会话"""
|
||||||
conn = self.pm.db.get_conn()
|
conn = self.pm.db.get_conn()
|
||||||
|
|
||||||
@@ -849,7 +862,7 @@ class BotHandler:
|
|||||||
}
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return {"success": False, "error": f"Failed to process audio: {str(e)}"}
|
return {"success": False, "error": f"Failed to process audio: {e!s}"}
|
||||||
|
|
||||||
async def _handle_file_message(self, session: BotSession, message: dict) -> dict:
|
async def _handle_file_message(self, session: BotSession, message: dict) -> dict:
|
||||||
"""处理文件消息"""
|
"""处理文件消息"""
|
||||||
@@ -897,12 +910,17 @@ class BotHandler:
|
|||||||
|
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
response = await client.post(
|
response = await client.post(
|
||||||
session.webhook_url, json=payload, headers={"Content-Type": "application/json"},
|
session.webhook_url,
|
||||||
|
json=payload,
|
||||||
|
headers={"Content-Type": "application/json"},
|
||||||
)
|
)
|
||||||
return response.status_code == 200
|
return response.status_code == 200
|
||||||
|
|
||||||
async def _send_dingtalk_message(
|
async def _send_dingtalk_message(
|
||||||
self, session: BotSession, message: str, msg_type: str,
|
self,
|
||||||
|
session: BotSession,
|
||||||
|
message: str,
|
||||||
|
msg_type: str,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""发送钉钉消息"""
|
"""发送钉钉消息"""
|
||||||
timestamp = str(round(time.time() * 1000))
|
timestamp = str(round(time.time() * 1000))
|
||||||
@@ -928,7 +946,9 @@ class BotHandler:
|
|||||||
|
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
response = await client.post(
|
response = await client.post(
|
||||||
url, json=payload, headers={"Content-Type": "application/json"},
|
url,
|
||||||
|
json=payload,
|
||||||
|
headers={"Content-Type": "application/json"},
|
||||||
)
|
)
|
||||||
return response.status_code == 200
|
return response.status_code == 200
|
||||||
|
|
||||||
@@ -944,9 +964,9 @@ class WebhookIntegration:
|
|||||||
self,
|
self,
|
||||||
name: str,
|
name: str,
|
||||||
endpoint_url: str,
|
endpoint_url: str,
|
||||||
project_id: str = None,
|
project_id: str | None = None,
|
||||||
auth_type: str = "none",
|
auth_type: str = "none",
|
||||||
auth_config: dict = None,
|
auth_config: dict | None = None,
|
||||||
trigger_events: list[str] = None,
|
trigger_events: list[str] = None,
|
||||||
) -> WebhookEndpoint:
|
) -> WebhookEndpoint:
|
||||||
"""创建 Webhook 端点"""
|
"""创建 Webhook 端点"""
|
||||||
@@ -1004,7 +1024,7 @@ class WebhookIntegration:
|
|||||||
return self._row_to_endpoint(row)
|
return self._row_to_endpoint(row)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def list_endpoints(self, project_id: str = None) -> list[WebhookEndpoint]:
|
def list_endpoints(self, project_id: str | None = None) -> list[WebhookEndpoint]:
|
||||||
"""列出端点"""
|
"""列出端点"""
|
||||||
conn = self.pm.db.get_conn()
|
conn = self.pm.db.get_conn()
|
||||||
|
|
||||||
@@ -1115,7 +1135,10 @@ class WebhookIntegration:
|
|||||||
|
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
response = await client.post(
|
response = await client.post(
|
||||||
endpoint.endpoint_url, json=payload, headers=headers, timeout=30.0,
|
endpoint.endpoint_url,
|
||||||
|
json=payload,
|
||||||
|
headers=headers,
|
||||||
|
timeout=30.0,
|
||||||
)
|
)
|
||||||
|
|
||||||
success = response.status_code in [200, 201, 202]
|
success = response.status_code in [200, 201, 202]
|
||||||
@@ -1229,7 +1252,7 @@ class WebDAVSyncManager:
|
|||||||
return self._row_to_sync(row)
|
return self._row_to_sync(row)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def list_syncs(self, project_id: str = None) -> list[WebDAVSync]:
|
def list_syncs(self, project_id: str | None = None) -> list[WebDAVSync]:
|
||||||
"""列出同步配置"""
|
"""列出同步配置"""
|
||||||
conn = self.pm.db.get_conn()
|
conn = self.pm.db.get_conn()
|
||||||
|
|
||||||
|
|||||||
@@ -120,7 +120,10 @@ class RateLimiter:
|
|||||||
await counter.add_request()
|
await counter.add_request()
|
||||||
|
|
||||||
return RateLimitInfo(
|
return RateLimitInfo(
|
||||||
allowed=True, remaining=remaining - 1, reset_time=reset_time, retry_after=0,
|
allowed=True,
|
||||||
|
remaining=remaining - 1,
|
||||||
|
reset_time=reset_time,
|
||||||
|
retry_after=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_limit_info(self, key: str) -> RateLimitInfo:
|
async def get_limit_info(self, key: str) -> RateLimitInfo:
|
||||||
@@ -145,9 +148,9 @@ class RateLimiter:
|
|||||||
allowed=current_count < config.requests_per_minute,
|
allowed=current_count < config.requests_per_minute,
|
||||||
remaining=remaining,
|
remaining=remaining,
|
||||||
reset_time=reset_time,
|
reset_time=reset_time,
|
||||||
retry_after=max(0, config.window_size)
|
retry_after=(
|
||||||
if current_count >= config.requests_per_minute
|
max(0, config.window_size) if current_count >= config.requests_per_minute else 0
|
||||||
else 0,
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def reset(self, key: str | None = None) -> None:
|
def reset(self, key: str | None = None) -> None:
|
||||||
|
|||||||
@@ -385,7 +385,7 @@ class FullTextSearch:
|
|||||||
# 排序和分页
|
# 排序和分页
|
||||||
scored_results.sort(key=lambda x: x.score, reverse=True)
|
scored_results.sort(key=lambda x: x.score, reverse=True)
|
||||||
|
|
||||||
return scored_results[offset: offset + limit]
|
return scored_results[offset : offset + limit]
|
||||||
|
|
||||||
def _parse_boolean_query(self, query: str) -> dict:
|
def _parse_boolean_query(self, query: str) -> dict:
|
||||||
"""
|
"""
|
||||||
@@ -545,19 +545,24 @@ class FullTextSearch:
|
|||||||
return results
|
return results
|
||||||
|
|
||||||
def _get_content_by_id(
|
def _get_content_by_id(
|
||||||
self, conn: sqlite3.Connection, content_id: str, content_type: str,
|
self,
|
||||||
|
conn: sqlite3.Connection,
|
||||||
|
content_id: str,
|
||||||
|
content_type: str,
|
||||||
) -> str | None:
|
) -> str | None:
|
||||||
"""根据ID获取内容"""
|
"""根据ID获取内容"""
|
||||||
try:
|
try:
|
||||||
if content_type == "transcript":
|
if content_type == "transcript":
|
||||||
row = conn.execute(
|
row = conn.execute(
|
||||||
"SELECT full_text FROM transcripts WHERE id = ?", (content_id,),
|
"SELECT full_text FROM transcripts WHERE id = ?",
|
||||||
|
(content_id,),
|
||||||
).fetchone()
|
).fetchone()
|
||||||
return row["full_text"] if row else None
|
return row["full_text"] if row else None
|
||||||
|
|
||||||
elif content_type == "entity":
|
elif content_type == "entity":
|
||||||
row = conn.execute(
|
row = conn.execute(
|
||||||
"SELECT name, definition FROM entities WHERE id = ?", (content_id,),
|
"SELECT name, definition FROM entities WHERE id = ?",
|
||||||
|
(content_id,),
|
||||||
).fetchone()
|
).fetchone()
|
||||||
if row:
|
if row:
|
||||||
return f"{row['name']} {row['definition'] or ''}"
|
return f"{row['name']} {row['definition'] or ''}"
|
||||||
@@ -583,21 +588,27 @@ class FullTextSearch:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def _get_project_id(
|
def _get_project_id(
|
||||||
self, conn: sqlite3.Connection, content_id: str, content_type: str,
|
self,
|
||||||
|
conn: sqlite3.Connection,
|
||||||
|
content_id: str,
|
||||||
|
content_type: str,
|
||||||
) -> str | None:
|
) -> str | None:
|
||||||
"""获取内容所属的项目ID"""
|
"""获取内容所属的项目ID"""
|
||||||
try:
|
try:
|
||||||
if content_type == "transcript":
|
if content_type == "transcript":
|
||||||
row = conn.execute(
|
row = conn.execute(
|
||||||
"SELECT project_id FROM transcripts WHERE id = ?", (content_id,),
|
"SELECT project_id FROM transcripts WHERE id = ?",
|
||||||
|
(content_id,),
|
||||||
).fetchone()
|
).fetchone()
|
||||||
elif content_type == "entity":
|
elif content_type == "entity":
|
||||||
row = conn.execute(
|
row = conn.execute(
|
||||||
"SELECT project_id FROM entities WHERE id = ?", (content_id,),
|
"SELECT project_id FROM entities WHERE id = ?",
|
||||||
|
(content_id,),
|
||||||
).fetchone()
|
).fetchone()
|
||||||
elif content_type == "relation":
|
elif content_type == "relation":
|
||||||
row = conn.execute(
|
row = conn.execute(
|
||||||
"SELECT project_id FROM entity_relations WHERE id = ?", (content_id,),
|
"SELECT project_id FROM entity_relations WHERE id = ?",
|
||||||
|
(content_id,),
|
||||||
).fetchone()
|
).fetchone()
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
@@ -880,7 +891,11 @@ class SemanticSearch:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def index_embedding(
|
def index_embedding(
|
||||||
self, content_id: str, content_type: str, project_id: str, text: str,
|
self,
|
||||||
|
content_id: str,
|
||||||
|
content_type: str,
|
||||||
|
project_id: str,
|
||||||
|
text: str,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""
|
"""
|
||||||
为内容生成并保存 embedding
|
为内容生成并保存 embedding
|
||||||
@@ -1029,13 +1044,15 @@ class SemanticSearch:
|
|||||||
try:
|
try:
|
||||||
if content_type == "transcript":
|
if content_type == "transcript":
|
||||||
row = conn.execute(
|
row = conn.execute(
|
||||||
"SELECT full_text FROM transcripts WHERE id = ?", (content_id,),
|
"SELECT full_text FROM transcripts WHERE id = ?",
|
||||||
|
(content_id,),
|
||||||
).fetchone()
|
).fetchone()
|
||||||
result = row["full_text"] if row else None
|
result = row["full_text"] if row else None
|
||||||
|
|
||||||
elif content_type == "entity":
|
elif content_type == "entity":
|
||||||
row = conn.execute(
|
row = conn.execute(
|
||||||
"SELECT name, definition FROM entities WHERE id = ?", (content_id,),
|
"SELECT name, definition FROM entities WHERE id = ?",
|
||||||
|
(content_id,),
|
||||||
).fetchone()
|
).fetchone()
|
||||||
result = f"{row['name']}: {row['definition']}" if row else None
|
result = f"{row['name']}: {row['definition']}" if row else None
|
||||||
|
|
||||||
@@ -1067,7 +1084,10 @@ class SemanticSearch:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def find_similar_content(
|
def find_similar_content(
|
||||||
self, content_id: str, content_type: str, top_k: int = 5,
|
self,
|
||||||
|
content_id: str,
|
||||||
|
content_type: str,
|
||||||
|
top_k: int = 5,
|
||||||
) -> list[SemanticSearchResult]:
|
) -> list[SemanticSearchResult]:
|
||||||
"""
|
"""
|
||||||
查找与指定内容相似的内容
|
查找与指定内容相似的内容
|
||||||
@@ -1175,7 +1195,10 @@ class EntityPathDiscovery:
|
|||||||
return conn
|
return conn
|
||||||
|
|
||||||
def find_shortest_path(
|
def find_shortest_path(
|
||||||
self, source_entity_id: str, target_entity_id: str, max_depth: int = 5,
|
self,
|
||||||
|
source_entity_id: str,
|
||||||
|
target_entity_id: str,
|
||||||
|
max_depth: int = 5,
|
||||||
) -> EntityPath | None:
|
) -> EntityPath | None:
|
||||||
"""
|
"""
|
||||||
查找两个实体之间的最短路径(BFS算法)
|
查找两个实体之间的最短路径(BFS算法)
|
||||||
@@ -1192,7 +1215,8 @@ class EntityPathDiscovery:
|
|||||||
|
|
||||||
# 获取项目ID
|
# 获取项目ID
|
||||||
row = conn.execute(
|
row = conn.execute(
|
||||||
"SELECT project_id FROM entities WHERE id = ?", (source_entity_id,),
|
"SELECT project_id FROM entities WHERE id = ?",
|
||||||
|
(source_entity_id,),
|
||||||
).fetchone()
|
).fetchone()
|
||||||
|
|
||||||
if not row:
|
if not row:
|
||||||
@@ -1250,7 +1274,11 @@ class EntityPathDiscovery:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def find_all_paths(
|
def find_all_paths(
|
||||||
self, source_entity_id: str, target_entity_id: str, max_depth: int = 4, max_paths: int = 10,
|
self,
|
||||||
|
source_entity_id: str,
|
||||||
|
target_entity_id: str,
|
||||||
|
max_depth: int = 4,
|
||||||
|
max_paths: int = 10,
|
||||||
) -> list[EntityPath]:
|
) -> list[EntityPath]:
|
||||||
"""
|
"""
|
||||||
查找两个实体之间的所有路径(限制数量和深度)
|
查找两个实体之间的所有路径(限制数量和深度)
|
||||||
@@ -1268,7 +1296,8 @@ class EntityPathDiscovery:
|
|||||||
|
|
||||||
# 获取项目ID
|
# 获取项目ID
|
||||||
row = conn.execute(
|
row = conn.execute(
|
||||||
"SELECT project_id FROM entities WHERE id = ?", (source_entity_id,),
|
"SELECT project_id FROM entities WHERE id = ?",
|
||||||
|
(source_entity_id,),
|
||||||
).fetchone()
|
).fetchone()
|
||||||
|
|
||||||
if not row:
|
if not row:
|
||||||
@@ -1280,7 +1309,11 @@ class EntityPathDiscovery:
|
|||||||
paths = []
|
paths = []
|
||||||
|
|
||||||
def dfs(
|
def dfs(
|
||||||
current_id: str, target_id: str, path: list[str], visited: set[str], depth: int,
|
current_id: str,
|
||||||
|
target_id: str,
|
||||||
|
path: list[str],
|
||||||
|
visited: set[str],
|
||||||
|
depth: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
if depth > max_depth:
|
if depth > max_depth:
|
||||||
return
|
return
|
||||||
@@ -1328,7 +1361,8 @@ class EntityPathDiscovery:
|
|||||||
nodes = []
|
nodes = []
|
||||||
for entity_id in entity_ids:
|
for entity_id in entity_ids:
|
||||||
row = conn.execute(
|
row = conn.execute(
|
||||||
"SELECT id, name, type FROM entities WHERE id = ?", (entity_id,),
|
"SELECT id, name, type FROM entities WHERE id = ?",
|
||||||
|
(entity_id,),
|
||||||
).fetchone()
|
).fetchone()
|
||||||
if row:
|
if row:
|
||||||
nodes.append({"id": row["id"], "name": row["name"], "type": row["type"]})
|
nodes.append({"id": row["id"], "name": row["name"], "type": row["type"]})
|
||||||
@@ -1398,7 +1432,8 @@ class EntityPathDiscovery:
|
|||||||
|
|
||||||
# 获取项目ID
|
# 获取项目ID
|
||||||
row = conn.execute(
|
row = conn.execute(
|
||||||
"SELECT project_id, name FROM entities WHERE id = ?", (entity_id,),
|
"SELECT project_id, name FROM entities WHERE id = ?",
|
||||||
|
(entity_id,),
|
||||||
).fetchone()
|
).fetchone()
|
||||||
|
|
||||||
if not row:
|
if not row:
|
||||||
@@ -1445,7 +1480,8 @@ class EntityPathDiscovery:
|
|||||||
|
|
||||||
# 获取邻居信息
|
# 获取邻居信息
|
||||||
neighbor_info = conn.execute(
|
neighbor_info = conn.execute(
|
||||||
"SELECT name, type FROM entities WHERE id = ?", (neighbor_id,),
|
"SELECT name, type FROM entities WHERE id = ?",
|
||||||
|
(neighbor_id,),
|
||||||
).fetchone()
|
).fetchone()
|
||||||
|
|
||||||
if neighbor_info:
|
if neighbor_info:
|
||||||
@@ -1458,7 +1494,10 @@ class EntityPathDiscovery:
|
|||||||
"relation_type": neighbor["relation_type"],
|
"relation_type": neighbor["relation_type"],
|
||||||
"evidence": neighbor["evidence"],
|
"evidence": neighbor["evidence"],
|
||||||
"path": self._get_path_to_entity(
|
"path": self._get_path_to_entity(
|
||||||
entity_id, neighbor_id, project_id, conn,
|
entity_id,
|
||||||
|
neighbor_id,
|
||||||
|
project_id,
|
||||||
|
conn,
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
@@ -1470,7 +1509,11 @@ class EntityPathDiscovery:
|
|||||||
return relations
|
return relations
|
||||||
|
|
||||||
def _get_path_to_entity(
|
def _get_path_to_entity(
|
||||||
self, source_id: str, target_id: str, project_id: str, conn: sqlite3.Connection,
|
self,
|
||||||
|
source_id: str,
|
||||||
|
target_id: str,
|
||||||
|
project_id: str,
|
||||||
|
conn: sqlite3.Connection,
|
||||||
) -> list[str]:
|
) -> list[str]:
|
||||||
"""获取从源实体到目标实体的路径(简化版)"""
|
"""获取从源实体到目标实体的路径(简化版)"""
|
||||||
# BFS 找路径
|
# BFS 找路径
|
||||||
@@ -1565,7 +1608,8 @@ class EntityPathDiscovery:
|
|||||||
|
|
||||||
# 获取所有实体
|
# 获取所有实体
|
||||||
entities = conn.execute(
|
entities = conn.execute(
|
||||||
"SELECT id, name FROM entities WHERE project_id = ?", (project_id,),
|
"SELECT id, name FROM entities WHERE project_id = ?",
|
||||||
|
(project_id,),
|
||||||
).fetchall()
|
).fetchall()
|
||||||
|
|
||||||
# 计算每个实体作为桥梁的次数
|
# 计算每个实体作为桥梁的次数
|
||||||
@@ -1706,7 +1750,8 @@ class KnowledgeGapDetection:
|
|||||||
|
|
||||||
# 检查每个实体的属性完整性
|
# 检查每个实体的属性完整性
|
||||||
entities = conn.execute(
|
entities = conn.execute(
|
||||||
"SELECT id, name FROM entities WHERE project_id = ?", (project_id,),
|
"SELECT id, name FROM entities WHERE project_id = ?",
|
||||||
|
(project_id,),
|
||||||
).fetchall()
|
).fetchall()
|
||||||
|
|
||||||
for entity in entities:
|
for entity in entities:
|
||||||
@@ -1714,7 +1759,8 @@ class KnowledgeGapDetection:
|
|||||||
|
|
||||||
# 获取实体已有的属性
|
# 获取实体已有的属性
|
||||||
existing_attrs = conn.execute(
|
existing_attrs = conn.execute(
|
||||||
"SELECT template_id FROM entity_attributes WHERE entity_id = ?", (entity_id,),
|
"SELECT template_id FROM entity_attributes WHERE entity_id = ?",
|
||||||
|
(entity_id,),
|
||||||
).fetchall()
|
).fetchall()
|
||||||
|
|
||||||
existing_template_ids = {a["template_id"] for a in existing_attrs}
|
existing_template_ids = {a["template_id"] for a in existing_attrs}
|
||||||
@@ -1726,7 +1772,8 @@ class KnowledgeGapDetection:
|
|||||||
missing_names = []
|
missing_names = []
|
||||||
for template_id in missing_templates:
|
for template_id in missing_templates:
|
||||||
template = conn.execute(
|
template = conn.execute(
|
||||||
"SELECT name FROM attribute_templates WHERE id = ?", (template_id,),
|
"SELECT name FROM attribute_templates WHERE id = ?",
|
||||||
|
(template_id,),
|
||||||
).fetchone()
|
).fetchone()
|
||||||
if template:
|
if template:
|
||||||
missing_names.append(template["name"])
|
missing_names.append(template["name"])
|
||||||
@@ -1759,7 +1806,8 @@ class KnowledgeGapDetection:
|
|||||||
|
|
||||||
# 获取所有实体及其关系数量
|
# 获取所有实体及其关系数量
|
||||||
entities = conn.execute(
|
entities = conn.execute(
|
||||||
"SELECT id, name, type FROM entities WHERE project_id = ?", (project_id,),
|
"SELECT id, name, type FROM entities WHERE project_id = ?",
|
||||||
|
(project_id,),
|
||||||
).fetchall()
|
).fetchall()
|
||||||
|
|
||||||
for entity in entities:
|
for entity in entities:
|
||||||
@@ -1900,7 +1948,8 @@ class KnowledgeGapDetection:
|
|||||||
|
|
||||||
# 分析转录文本中频繁提及但未提取为实体的词
|
# 分析转录文本中频繁提及但未提取为实体的词
|
||||||
transcripts = conn.execute(
|
transcripts = conn.execute(
|
||||||
"SELECT full_text FROM transcripts WHERE project_id = ?", (project_id,),
|
"SELECT full_text FROM transcripts WHERE project_id = ?",
|
||||||
|
(project_id,),
|
||||||
).fetchall()
|
).fetchall()
|
||||||
|
|
||||||
# 合并所有文本
|
# 合并所有文本
|
||||||
@@ -1908,7 +1957,8 @@ class KnowledgeGapDetection:
|
|||||||
|
|
||||||
# 获取现有实体名称
|
# 获取现有实体名称
|
||||||
existing_entities = conn.execute(
|
existing_entities = conn.execute(
|
||||||
"SELECT name FROM entities WHERE project_id = ?", (project_id,),
|
"SELECT name FROM entities WHERE project_id = ?",
|
||||||
|
(project_id,),
|
||||||
).fetchall()
|
).fetchall()
|
||||||
|
|
||||||
existing_names = {e["name"].lower() for e in existing_entities}
|
existing_names = {e["name"].lower() for e in existing_entities}
|
||||||
@@ -2146,7 +2196,10 @@ class SearchManager:
|
|||||||
|
|
||||||
for t in transcripts:
|
for t in transcripts:
|
||||||
if t["full_text"] and self.semantic_search.index_embedding(
|
if t["full_text"] and self.semantic_search.index_embedding(
|
||||||
t["id"], "transcript", t["project_id"], t["full_text"],
|
t["id"],
|
||||||
|
"transcript",
|
||||||
|
t["project_id"],
|
||||||
|
t["full_text"],
|
||||||
):
|
):
|
||||||
semantic_stats["indexed"] += 1
|
semantic_stats["indexed"] += 1
|
||||||
else:
|
else:
|
||||||
@@ -2179,12 +2232,14 @@ class SearchManager:
|
|||||||
|
|
||||||
# 全文索引统计
|
# 全文索引统计
|
||||||
fulltext_count = conn.execute(
|
fulltext_count = conn.execute(
|
||||||
f"SELECT COUNT(*) as count FROM search_indexes {where_clause}", params,
|
f"SELECT COUNT(*) as count FROM search_indexes {where_clause}",
|
||||||
|
params,
|
||||||
).fetchone()["count"]
|
).fetchone()["count"]
|
||||||
|
|
||||||
# 语义索引统计
|
# 语义索引统计
|
||||||
semantic_count = conn.execute(
|
semantic_count = conn.execute(
|
||||||
f"SELECT COUNT(*) as count FROM embeddings {where_clause}", params,
|
f"SELECT COUNT(*) as count FROM embeddings {where_clause}",
|
||||||
|
params,
|
||||||
).fetchone()["count"]
|
).fetchone()["count"]
|
||||||
|
|
||||||
# 按类型统计
|
# 按类型统计
|
||||||
@@ -2225,7 +2280,9 @@ def get_search_manager(db_path: str = "insightflow.db") -> SearchManager:
|
|||||||
|
|
||||||
|
|
||||||
def fulltext_search(
|
def fulltext_search(
|
||||||
query: str, project_id: str | None = None, limit: int = 20,
|
query: str,
|
||||||
|
project_id: str | None = None,
|
||||||
|
limit: int = 20,
|
||||||
) -> list[SearchResult]:
|
) -> list[SearchResult]:
|
||||||
"""全文搜索便捷函数"""
|
"""全文搜索便捷函数"""
|
||||||
manager = get_search_manager()
|
manager = get_search_manager()
|
||||||
@@ -2233,7 +2290,9 @@ def fulltext_search(
|
|||||||
|
|
||||||
|
|
||||||
def semantic_search(
|
def semantic_search(
|
||||||
query: str, project_id: str | None = None, top_k: int = 10,
|
query: str,
|
||||||
|
project_id: str | None = None,
|
||||||
|
top_k: int = 10,
|
||||||
) -> list[SemanticSearchResult]:
|
) -> list[SemanticSearchResult]:
|
||||||
"""语义搜索便捷函数"""
|
"""语义搜索便捷函数"""
|
||||||
manager = get_search_manager()
|
manager = get_search_manager()
|
||||||
|
|||||||
@@ -464,7 +464,9 @@ class SecurityManager:
|
|||||||
return logs
|
return logs
|
||||||
|
|
||||||
def get_audit_stats(
|
def get_audit_stats(
|
||||||
self, start_time: str | None = None, end_time: str | None = None,
|
self,
|
||||||
|
start_time: str | None = None,
|
||||||
|
end_time: str | None = None,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""获取审计统计"""
|
"""获取审计统计"""
|
||||||
conn = sqlite3.connect(self.db_path)
|
conn = sqlite3.connect(self.db_path)
|
||||||
@@ -882,7 +884,10 @@ class SecurityManager:
|
|||||||
return success
|
return success
|
||||||
|
|
||||||
def apply_masking(
|
def apply_masking(
|
||||||
self, text: str, project_id: str, rule_types: list[MaskingRuleType] | None = None,
|
self,
|
||||||
|
text: str,
|
||||||
|
project_id: str,
|
||||||
|
rule_types: list[MaskingRuleType] | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""应用脱敏规则到文本"""
|
"""应用脱敏规则到文本"""
|
||||||
rules = self.get_masking_rules(project_id)
|
rules = self.get_masking_rules(project_id)
|
||||||
@@ -906,7 +911,9 @@ class SecurityManager:
|
|||||||
return masked_text
|
return masked_text
|
||||||
|
|
||||||
def apply_masking_to_entity(
|
def apply_masking_to_entity(
|
||||||
self, entity_data: dict[str, Any], project_id: str,
|
self,
|
||||||
|
entity_data: dict[str, Any],
|
||||||
|
project_id: str,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""对实体数据应用脱敏"""
|
"""对实体数据应用脱敏"""
|
||||||
masked_data = entity_data.copy()
|
masked_data = entity_data.copy()
|
||||||
@@ -982,7 +989,9 @@ class SecurityManager:
|
|||||||
return policy
|
return policy
|
||||||
|
|
||||||
def get_access_policies(
|
def get_access_policies(
|
||||||
self, project_id: str, active_only: bool = True,
|
self,
|
||||||
|
project_id: str,
|
||||||
|
active_only: bool = True,
|
||||||
) -> list[DataAccessPolicy]:
|
) -> list[DataAccessPolicy]:
|
||||||
"""获取数据访问策略"""
|
"""获取数据访问策略"""
|
||||||
conn = sqlite3.connect(self.db_path)
|
conn = sqlite3.connect(self.db_path)
|
||||||
@@ -1021,14 +1030,18 @@ class SecurityManager:
|
|||||||
return policies
|
return policies
|
||||||
|
|
||||||
def check_access_permission(
|
def check_access_permission(
|
||||||
self, policy_id: str, user_id: str, user_ip: str | None = None,
|
self,
|
||||||
|
policy_id: str,
|
||||||
|
user_id: str,
|
||||||
|
user_ip: str | None = None,
|
||||||
) -> tuple[bool, str | None]:
|
) -> tuple[bool, str | None]:
|
||||||
"""检查访问权限"""
|
"""检查访问权限"""
|
||||||
conn = sqlite3.connect(self.db_path)
|
conn = sqlite3.connect(self.db_path)
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"SELECT * FROM data_access_policies WHERE id = ? AND is_active = 1", (policy_id,),
|
"SELECT * FROM data_access_policies WHERE id = ? AND is_active = 1",
|
||||||
|
(policy_id,),
|
||||||
)
|
)
|
||||||
row = cursor.fetchone()
|
row = cursor.fetchone()
|
||||||
conn.close()
|
conn.close()
|
||||||
@@ -1163,7 +1176,10 @@ class SecurityManager:
|
|||||||
return request
|
return request
|
||||||
|
|
||||||
def approve_access_request(
|
def approve_access_request(
|
||||||
self, request_id: str, approved_by: str, expires_hours: int = 24,
|
self,
|
||||||
|
request_id: str,
|
||||||
|
approved_by: str,
|
||||||
|
expires_hours: int = 24,
|
||||||
) -> AccessRequest | None:
|
) -> AccessRequest | None:
|
||||||
"""批准访问请求"""
|
"""批准访问请求"""
|
||||||
conn = sqlite3.connect(self.db_path)
|
conn = sqlite3.connect(self.db_path)
|
||||||
|
|||||||
@@ -588,7 +588,8 @@ class SubscriptionManager:
|
|||||||
try:
|
try:
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"SELECT * FROM subscription_plans WHERE tier = ? AND is_active = 1", (tier,),
|
"SELECT * FROM subscription_plans WHERE tier = ? AND is_active = 1",
|
||||||
|
(tier,),
|
||||||
)
|
)
|
||||||
row = cursor.fetchone()
|
row = cursor.fetchone()
|
||||||
|
|
||||||
@@ -963,7 +964,9 @@ class SubscriptionManager:
|
|||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
def cancel_subscription(
|
def cancel_subscription(
|
||||||
self, subscription_id: str, at_period_end: bool = True,
|
self,
|
||||||
|
subscription_id: str,
|
||||||
|
at_period_end: bool = True,
|
||||||
) -> Subscription | None:
|
) -> Subscription | None:
|
||||||
"""取消订阅"""
|
"""取消订阅"""
|
||||||
conn = self._get_connection()
|
conn = self._get_connection()
|
||||||
@@ -1017,7 +1020,10 @@ class SubscriptionManager:
|
|||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
def change_plan(
|
def change_plan(
|
||||||
self, subscription_id: str, new_plan_id: str, prorate: bool = True,
|
self,
|
||||||
|
subscription_id: str,
|
||||||
|
new_plan_id: str,
|
||||||
|
prorate: bool = True,
|
||||||
) -> Subscription | None:
|
) -> Subscription | None:
|
||||||
"""更改订阅计划"""
|
"""更改订阅计划"""
|
||||||
conn = self._get_connection()
|
conn = self._get_connection()
|
||||||
@@ -1125,7 +1131,10 @@ class SubscriptionManager:
|
|||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
def get_usage_summary(
|
def get_usage_summary(
|
||||||
self, tenant_id: str, start_date: datetime | None = None, end_date: datetime | None = None,
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
start_date: datetime | None = None,
|
||||||
|
end_date: datetime | None = None,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""获取用量汇总"""
|
"""获取用量汇总"""
|
||||||
conn = self._get_connection()
|
conn = self._get_connection()
|
||||||
@@ -1268,7 +1277,9 @@ class SubscriptionManager:
|
|||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
def confirm_payment(
|
def confirm_payment(
|
||||||
self, payment_id: str, provider_payment_id: str | None = None,
|
self,
|
||||||
|
payment_id: str,
|
||||||
|
provider_payment_id: str | None = None,
|
||||||
) -> Payment | None:
|
) -> Payment | None:
|
||||||
"""确认支付完成"""
|
"""确认支付完成"""
|
||||||
conn = self._get_connection()
|
conn = self._get_connection()
|
||||||
@@ -1361,7 +1372,11 @@ class SubscriptionManager:
|
|||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
def list_payments(
|
def list_payments(
|
||||||
self, tenant_id: str, status: str | None = None, limit: int = 100, offset: int = 0,
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
status: str | None = None,
|
||||||
|
limit: int = 100,
|
||||||
|
offset: int = 0,
|
||||||
) -> list[Payment]:
|
) -> list[Payment]:
|
||||||
"""列出支付记录"""
|
"""列出支付记录"""
|
||||||
conn = self._get_connection()
|
conn = self._get_connection()
|
||||||
@@ -1501,7 +1516,11 @@ class SubscriptionManager:
|
|||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
def list_invoices(
|
def list_invoices(
|
||||||
self, tenant_id: str, status: str | None = None, limit: int = 100, offset: int = 0,
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
status: str | None = None,
|
||||||
|
limit: int = 100,
|
||||||
|
offset: int = 0,
|
||||||
) -> list[Invoice]:
|
) -> list[Invoice]:
|
||||||
"""列出发票"""
|
"""列出发票"""
|
||||||
conn = self._get_connection()
|
conn = self._get_connection()
|
||||||
@@ -1581,7 +1600,12 @@ class SubscriptionManager:
|
|||||||
# ==================== 退款管理 ====================
|
# ==================== 退款管理 ====================
|
||||||
|
|
||||||
def request_refund(
|
def request_refund(
|
||||||
self, tenant_id: str, payment_id: str, amount: float, reason: str, requested_by: str,
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
payment_id: str,
|
||||||
|
amount: float,
|
||||||
|
reason: str,
|
||||||
|
requested_by: str,
|
||||||
) -> Refund:
|
) -> Refund:
|
||||||
"""申请退款"""
|
"""申请退款"""
|
||||||
conn = self._get_connection()
|
conn = self._get_connection()
|
||||||
@@ -1690,7 +1714,9 @@ class SubscriptionManager:
|
|||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
def complete_refund(
|
def complete_refund(
|
||||||
self, refund_id: str, provider_refund_id: str | None = None,
|
self,
|
||||||
|
refund_id: str,
|
||||||
|
provider_refund_id: str | None = None,
|
||||||
) -> Refund | None:
|
) -> Refund | None:
|
||||||
"""完成退款"""
|
"""完成退款"""
|
||||||
conn = self._get_connection()
|
conn = self._get_connection()
|
||||||
@@ -1775,7 +1801,11 @@ class SubscriptionManager:
|
|||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
def list_refunds(
|
def list_refunds(
|
||||||
self, tenant_id: str, status: str | None = None, limit: int = 100, offset: int = 0,
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
status: str | None = None,
|
||||||
|
limit: int = 100,
|
||||||
|
offset: int = 0,
|
||||||
) -> list[Refund]:
|
) -> list[Refund]:
|
||||||
"""列出退款记录"""
|
"""列出退款记录"""
|
||||||
conn = self._get_connection()
|
conn = self._get_connection()
|
||||||
@@ -1902,7 +1932,10 @@ class SubscriptionManager:
|
|||||||
}
|
}
|
||||||
|
|
||||||
def create_alipay_order(
|
def create_alipay_order(
|
||||||
self, tenant_id: str, plan_id: str, billing_cycle: str = "monthly",
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
plan_id: str,
|
||||||
|
billing_cycle: str = "monthly",
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""创建支付宝订单(占位实现)"""
|
"""创建支付宝订单(占位实现)"""
|
||||||
# 这里应该集成支付宝 SDK
|
# 这里应该集成支付宝 SDK
|
||||||
@@ -1919,7 +1952,10 @@ class SubscriptionManager:
|
|||||||
}
|
}
|
||||||
|
|
||||||
def create_wechat_order(
|
def create_wechat_order(
|
||||||
self, tenant_id: str, plan_id: str, billing_cycle: str = "monthly",
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
plan_id: str,
|
||||||
|
billing_cycle: str = "monthly",
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""创建微信支付订单(占位实现)"""
|
"""创建微信支付订单(占位实现)"""
|
||||||
# 这里应该集成微信支付 SDK
|
# 这里应该集成微信支付 SDK
|
||||||
|
|||||||
@@ -433,7 +433,8 @@ class TenantManager:
|
|||||||
TenantTier(tier) if tier in [t.value for t in TenantTier] else TenantTier.FREE
|
TenantTier(tier) if tier in [t.value for t in TenantTier] else TenantTier.FREE
|
||||||
)
|
)
|
||||||
resource_limits = self.DEFAULT_LIMITS.get(
|
resource_limits = self.DEFAULT_LIMITS.get(
|
||||||
tier_enum, self.DEFAULT_LIMITS[TenantTier.FREE],
|
tier_enum,
|
||||||
|
self.DEFAULT_LIMITS[TenantTier.FREE],
|
||||||
)
|
)
|
||||||
|
|
||||||
tenant = Tenant(
|
tenant = Tenant(
|
||||||
@@ -612,7 +613,11 @@ class TenantManager:
|
|||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
def list_tenants(
|
def list_tenants(
|
||||||
self, status: str | None = None, tier: str | None = None, limit: int = 100, offset: int = 0,
|
self,
|
||||||
|
status: str | None = None,
|
||||||
|
tier: str | None = None,
|
||||||
|
limit: int = 100,
|
||||||
|
offset: int = 0,
|
||||||
) -> list[Tenant]:
|
) -> list[Tenant]:
|
||||||
"""列出租户"""
|
"""列出租户"""
|
||||||
conn = self._get_connection()
|
conn = self._get_connection()
|
||||||
@@ -1103,7 +1108,11 @@ class TenantManager:
|
|||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
def update_member_role(
|
def update_member_role(
|
||||||
self, tenant_id: str, member_id: str, role: str, permissions: list[str] | None = None,
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
member_id: str,
|
||||||
|
role: str,
|
||||||
|
permissions: list[str] | None = None,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""更新成员角色"""
|
"""更新成员角色"""
|
||||||
conn = self._get_connection()
|
conn = self._get_connection()
|
||||||
@@ -1268,7 +1277,10 @@ class TenantManager:
|
|||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
def get_usage_stats(
|
def get_usage_stats(
|
||||||
self, tenant_id: str, start_date: datetime | None = None, end_date: datetime | None = None,
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
start_date: datetime | None = None,
|
||||||
|
end_date: datetime | None = None,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""获取使用统计"""
|
"""获取使用统计"""
|
||||||
conn = self._get_connection()
|
conn = self._get_connection()
|
||||||
@@ -1314,23 +1326,28 @@ class TenantManager:
|
|||||||
"limits": limits,
|
"limits": limits,
|
||||||
"usage_percentages": {
|
"usage_percentages": {
|
||||||
"storage": self._calc_percentage(
|
"storage": self._calc_percentage(
|
||||||
row["total_storage"] or 0, limits.get("max_storage_mb", 0) * 1024 * 1024,
|
row["total_storage"] or 0,
|
||||||
|
limits.get("max_storage_mb", 0) * 1024 * 1024,
|
||||||
),
|
),
|
||||||
"transcription": self._calc_percentage(
|
"transcription": self._calc_percentage(
|
||||||
row["total_transcription"] or 0,
|
row["total_transcription"] or 0,
|
||||||
limits.get("max_transcription_minutes", 0) * 60,
|
limits.get("max_transcription_minutes", 0) * 60,
|
||||||
),
|
),
|
||||||
"api_calls": self._calc_percentage(
|
"api_calls": self._calc_percentage(
|
||||||
row["total_api_calls"] or 0, limits.get("max_api_calls_per_day", 0),
|
row["total_api_calls"] or 0,
|
||||||
|
limits.get("max_api_calls_per_day", 0),
|
||||||
),
|
),
|
||||||
"projects": self._calc_percentage(
|
"projects": self._calc_percentage(
|
||||||
row["max_projects"] or 0, limits.get("max_projects", 0),
|
row["max_projects"] or 0,
|
||||||
|
limits.get("max_projects", 0),
|
||||||
),
|
),
|
||||||
"entities": self._calc_percentage(
|
"entities": self._calc_percentage(
|
||||||
row["max_entities"] or 0, limits.get("max_entities", 0),
|
row["max_entities"] or 0,
|
||||||
|
limits.get("max_entities", 0),
|
||||||
),
|
),
|
||||||
"members": self._calc_percentage(
|
"members": self._calc_percentage(
|
||||||
row["max_members"] or 0, limits.get("max_team_members", 0),
|
row["max_members"] or 0,
|
||||||
|
limits.get("max_team_members", 0),
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -1406,8 +1423,10 @@ class TenantManager:
|
|||||||
|
|
||||||
def _validate_domain(self, domain: str) -> bool:
|
def _validate_domain(self, domain: str) -> bool:
|
||||||
"""验证域名格式"""
|
"""验证域名格式"""
|
||||||
pattern = (r"^(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0, 61}[a-zA-Z0-9])?\.)*"
|
pattern = (
|
||||||
r"[a-zA-Z0-9](?:[a-zA-Z0-9-]{0, 61}[a-zA-Z0-9])$")
|
r"^(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0, 61}[a-zA-Z0-9])?\.)*"
|
||||||
|
r"[a-zA-Z0-9](?:[a-zA-Z0-9-]{0, 61}[a-zA-Z0-9])$"
|
||||||
|
)
|
||||||
return bool(re.match(pattern, domain))
|
return bool(re.match(pattern, domain))
|
||||||
|
|
||||||
def _check_domain_verification(self, domain: str, token: str, method: str) -> bool:
|
def _check_domain_verification(self, domain: str, token: str, method: str) -> bool:
|
||||||
|
|||||||
@@ -159,7 +159,8 @@ def test_cache_manager() -> None:
|
|||||||
|
|
||||||
# 批量操作
|
# 批量操作
|
||||||
cache.set_many(
|
cache.set_many(
|
||||||
{"batch_key_1": "value1", "batch_key_2": "value2", "batch_key_3": "value3"}, ttl=60,
|
{"batch_key_1": "value1", "batch_key_2": "value2", "batch_key_3": "value3"},
|
||||||
|
ttl=60,
|
||||||
)
|
)
|
||||||
print(" ✓ 批量设置缓存")
|
print(" ✓ 批量设置缓存")
|
||||||
|
|
||||||
@@ -208,7 +209,8 @@ def test_task_queue() -> None:
|
|||||||
|
|
||||||
# 提交任务
|
# 提交任务
|
||||||
task_id = queue.submit(
|
task_id = queue.submit(
|
||||||
task_type="test_task", payload={"test": "data", "timestamp": time.time()},
|
task_type="test_task",
|
||||||
|
payload={"test": "data", "timestamp": time.time()},
|
||||||
)
|
)
|
||||||
print(" ✓ 提交任务: {task_id}")
|
print(" ✓ 提交任务: {task_id}")
|
||||||
|
|
||||||
|
|||||||
@@ -29,7 +29,10 @@ def test_tenant_management() -> None:
|
|||||||
# 1. 创建租户
|
# 1. 创建租户
|
||||||
print("\n1.1 创建租户...")
|
print("\n1.1 创建租户...")
|
||||||
tenant = manager.create_tenant(
|
tenant = manager.create_tenant(
|
||||||
name="Test Company", owner_id="user_001", tier="pro", description="A test company tenant",
|
name="Test Company",
|
||||||
|
owner_id="user_001",
|
||||||
|
tier="pro",
|
||||||
|
description="A test company tenant",
|
||||||
)
|
)
|
||||||
print(f"✅ 租户创建成功: {tenant.id}")
|
print(f"✅ 租户创建成功: {tenant.id}")
|
||||||
print(f" - 名称: {tenant.name}")
|
print(f" - 名称: {tenant.name}")
|
||||||
@@ -53,7 +56,9 @@ def test_tenant_management() -> None:
|
|||||||
# 4. 更新租户
|
# 4. 更新租户
|
||||||
print("\n1.4 更新租户信息...")
|
print("\n1.4 更新租户信息...")
|
||||||
updated = manager.update_tenant(
|
updated = manager.update_tenant(
|
||||||
tenant_id=tenant.id, name="Test Company Updated", tier="enterprise",
|
tenant_id=tenant.id,
|
||||||
|
name="Test Company Updated",
|
||||||
|
tier="enterprise",
|
||||||
)
|
)
|
||||||
assert updated is not None, "更新租户失败"
|
assert updated is not None, "更新租户失败"
|
||||||
print(f"✅ 租户更新成功: {updated.name}, 层级: {updated.tier}")
|
print(f"✅ 租户更新成功: {updated.name}, 层级: {updated.tier}")
|
||||||
@@ -163,7 +168,10 @@ def test_member_management(tenant_id: str) -> None:
|
|||||||
# 1. 邀请成员
|
# 1. 邀请成员
|
||||||
print("\n4.1 邀请成员...")
|
print("\n4.1 邀请成员...")
|
||||||
member1 = manager.invite_member(
|
member1 = manager.invite_member(
|
||||||
tenant_id=tenant_id, email="admin@test.com", role="admin", invited_by="user_001",
|
tenant_id=tenant_id,
|
||||||
|
email="admin@test.com",
|
||||||
|
role="admin",
|
||||||
|
invited_by="user_001",
|
||||||
)
|
)
|
||||||
print(f"✅ 成员邀请成功: {member1.email}")
|
print(f"✅ 成员邀请成功: {member1.email}")
|
||||||
print(f" - ID: {member1.id}")
|
print(f" - ID: {member1.id}")
|
||||||
@@ -171,7 +179,10 @@ def test_member_management(tenant_id: str) -> None:
|
|||||||
print(f" - 权限: {member1.permissions}")
|
print(f" - 权限: {member1.permissions}")
|
||||||
|
|
||||||
member2 = manager.invite_member(
|
member2 = manager.invite_member(
|
||||||
tenant_id=tenant_id, email="member@test.com", role="member", invited_by="user_001",
|
tenant_id=tenant_id,
|
||||||
|
email="member@test.com",
|
||||||
|
role="member",
|
||||||
|
invited_by="user_001",
|
||||||
)
|
)
|
||||||
print(f"✅ 成员邀请成功: {member2.email}")
|
print(f"✅ 成员邀请成功: {member2.email}")
|
||||||
|
|
||||||
|
|||||||
@@ -205,7 +205,8 @@ def test_subscription_manager() -> None:
|
|||||||
|
|
||||||
# 更改计划
|
# 更改计划
|
||||||
changed = manager.change_plan(
|
changed = manager.change_plan(
|
||||||
subscription_id=subscription.id, new_plan_id=enterprise_plan.id,
|
subscription_id=subscription.id,
|
||||||
|
new_plan_id=enterprise_plan.id,
|
||||||
)
|
)
|
||||||
print(f"✓ 更改计划: {changed.plan_id} (Enterprise)")
|
print(f"✓ 更改计划: {changed.plan_id} (Enterprise)")
|
||||||
|
|
||||||
|
|||||||
@@ -181,14 +181,16 @@ async def test_predictions(trend_model_id: str, anomaly_model_id: str) -> None:
|
|||||||
# 2. 趋势预测
|
# 2. 趋势预测
|
||||||
print("2. 趋势预测...")
|
print("2. 趋势预测...")
|
||||||
trend_result = await manager.predict(
|
trend_result = await manager.predict(
|
||||||
trend_model_id, {"historical_values": [10, 12, 15, 14, 18, 20, 22]},
|
trend_model_id,
|
||||||
|
{"historical_values": [10, 12, 15, 14, 18, 20, 22]},
|
||||||
)
|
)
|
||||||
print(f" 预测结果: {trend_result.prediction_data}")
|
print(f" 预测结果: {trend_result.prediction_data}")
|
||||||
|
|
||||||
# 3. 异常检测
|
# 3. 异常检测
|
||||||
print("3. 异常检测...")
|
print("3. 异常检测...")
|
||||||
anomaly_result = await manager.predict(
|
anomaly_result = await manager.predict(
|
||||||
anomaly_model_id, {"value": 50, "historical_values": [10, 12, 11, 13, 12, 14, 13]},
|
anomaly_model_id,
|
||||||
|
{"value": 50, "historical_values": [10, 12, 11, 13, 12, 14, 13]},
|
||||||
)
|
)
|
||||||
print(f" 检测结果: {anomaly_result.prediction_data}")
|
print(f" 检测结果: {anomaly_result.prediction_data}")
|
||||||
|
|
||||||
|
|||||||
@@ -525,7 +525,8 @@ class TestGrowthManager:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
referral = self.manager.generate_referral_code(
|
referral = self.manager.generate_referral_code(
|
||||||
program_id=program_id, referrer_id="referrer_user_001",
|
program_id=program_id,
|
||||||
|
referrer_id="referrer_user_001",
|
||||||
)
|
)
|
||||||
|
|
||||||
if referral:
|
if referral:
|
||||||
@@ -551,7 +552,8 @@ class TestGrowthManager:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
success = self.manager.apply_referral_code(
|
success = self.manager.apply_referral_code(
|
||||||
referral_code=referral_code, referee_id="new_user_001",
|
referral_code=referral_code,
|
||||||
|
referee_id="new_user_001",
|
||||||
)
|
)
|
||||||
|
|
||||||
if success:
|
if success:
|
||||||
@@ -618,7 +620,9 @@ class TestGrowthManager:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
incentives = self.manager.check_team_incentive_eligibility(
|
incentives = self.manager.check_team_incentive_eligibility(
|
||||||
tenant_id=self.test_tenant_id, current_tier="free", team_size=5,
|
tenant_id=self.test_tenant_id,
|
||||||
|
current_tier="free",
|
||||||
|
team_size=5,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.log(f"找到 {len(incentives)} 个符合条件的激励")
|
self.log(f"找到 {len(incentives)} 个符合条件的激励")
|
||||||
|
|||||||
@@ -162,7 +162,7 @@ class TestDeveloperEcosystem:
|
|||||||
self.log(f"Created SDK: {sdk_js.name} ({sdk_js.id})")
|
self.log(f"Created SDK: {sdk_js.name} ({sdk_js.id})")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"Failed to create SDK: {str(e)}", success=False)
|
self.log(f"Failed to create SDK: {e!s}", success=False)
|
||||||
|
|
||||||
def test_sdk_list(self) -> None:
|
def test_sdk_list(self) -> None:
|
||||||
"""测试列出 SDK"""
|
"""测试列出 SDK"""
|
||||||
@@ -179,7 +179,7 @@ class TestDeveloperEcosystem:
|
|||||||
self.log(f"Search found {len(search_results)} SDKs")
|
self.log(f"Search found {len(search_results)} SDKs")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"Failed to list SDKs: {str(e)}", success=False)
|
self.log(f"Failed to list SDKs: {e!s}", success=False)
|
||||||
|
|
||||||
def test_sdk_get(self) -> None:
|
def test_sdk_get(self) -> None:
|
||||||
"""测试获取 SDK 详情"""
|
"""测试获取 SDK 详情"""
|
||||||
@@ -191,19 +191,20 @@ class TestDeveloperEcosystem:
|
|||||||
else:
|
else:
|
||||||
self.log("SDK not found", success=False)
|
self.log("SDK not found", success=False)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"Failed to get SDK: {str(e)}", success=False)
|
self.log(f"Failed to get SDK: {e!s}", success=False)
|
||||||
|
|
||||||
def test_sdk_update(self) -> None:
|
def test_sdk_update(self) -> None:
|
||||||
"""测试更新 SDK"""
|
"""测试更新 SDK"""
|
||||||
try:
|
try:
|
||||||
if self.created_ids["sdk"]:
|
if self.created_ids["sdk"]:
|
||||||
sdk = self.manager.update_sdk_release(
|
sdk = self.manager.update_sdk_release(
|
||||||
self.created_ids["sdk"][0], description="Updated description",
|
self.created_ids["sdk"][0],
|
||||||
|
description="Updated description",
|
||||||
)
|
)
|
||||||
if sdk:
|
if sdk:
|
||||||
self.log(f"Updated SDK: {sdk.name}")
|
self.log(f"Updated SDK: {sdk.name}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"Failed to update SDK: {str(e)}", success=False)
|
self.log(f"Failed to update SDK: {e!s}", success=False)
|
||||||
|
|
||||||
def test_sdk_publish(self) -> None:
|
def test_sdk_publish(self) -> None:
|
||||||
"""测试发布 SDK"""
|
"""测试发布 SDK"""
|
||||||
@@ -213,7 +214,7 @@ class TestDeveloperEcosystem:
|
|||||||
if sdk:
|
if sdk:
|
||||||
self.log(f"Published SDK: {sdk.name} (status: {sdk.status.value})")
|
self.log(f"Published SDK: {sdk.name} (status: {sdk.status.value})")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"Failed to publish SDK: {str(e)}", success=False)
|
self.log(f"Failed to publish SDK: {e!s}", success=False)
|
||||||
|
|
||||||
def test_sdk_version_add(self) -> None:
|
def test_sdk_version_add(self) -> None:
|
||||||
"""测试添加 SDK 版本"""
|
"""测试添加 SDK 版本"""
|
||||||
@@ -230,7 +231,7 @@ class TestDeveloperEcosystem:
|
|||||||
)
|
)
|
||||||
self.log(f"Added SDK version: {version.version}")
|
self.log(f"Added SDK version: {version.version}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"Failed to add SDK version: {str(e)}", success=False)
|
self.log(f"Failed to add SDK version: {e!s}", success=False)
|
||||||
|
|
||||||
def test_template_create(self) -> None:
|
def test_template_create(self) -> None:
|
||||||
"""测试创建模板"""
|
"""测试创建模板"""
|
||||||
@@ -273,7 +274,7 @@ class TestDeveloperEcosystem:
|
|||||||
self.log(f"Created free template: {template_free.name}")
|
self.log(f"Created free template: {template_free.name}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"Failed to create template: {str(e)}", success=False)
|
self.log(f"Failed to create template: {e!s}", success=False)
|
||||||
|
|
||||||
def test_template_list(self) -> None:
|
def test_template_list(self) -> None:
|
||||||
"""测试列出模板"""
|
"""测试列出模板"""
|
||||||
@@ -290,7 +291,7 @@ class TestDeveloperEcosystem:
|
|||||||
self.log(f"Found {len(free_templates)} free templates")
|
self.log(f"Found {len(free_templates)} free templates")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"Failed to list templates: {str(e)}", success=False)
|
self.log(f"Failed to list templates: {e!s}", success=False)
|
||||||
|
|
||||||
def test_template_get(self) -> None:
|
def test_template_get(self) -> None:
|
||||||
"""测试获取模板详情"""
|
"""测试获取模板详情"""
|
||||||
@@ -300,19 +301,20 @@ class TestDeveloperEcosystem:
|
|||||||
if template:
|
if template:
|
||||||
self.log(f"Retrieved template: {template.name}")
|
self.log(f"Retrieved template: {template.name}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"Failed to get template: {str(e)}", success=False)
|
self.log(f"Failed to get template: {e!s}", success=False)
|
||||||
|
|
||||||
def test_template_approve(self) -> None:
|
def test_template_approve(self) -> None:
|
||||||
"""测试审核通过模板"""
|
"""测试审核通过模板"""
|
||||||
try:
|
try:
|
||||||
if self.created_ids["template"]:
|
if self.created_ids["template"]:
|
||||||
template = self.manager.approve_template(
|
template = self.manager.approve_template(
|
||||||
self.created_ids["template"][0], reviewed_by="admin_001",
|
self.created_ids["template"][0],
|
||||||
|
reviewed_by="admin_001",
|
||||||
)
|
)
|
||||||
if template:
|
if template:
|
||||||
self.log(f"Approved template: {template.name}")
|
self.log(f"Approved template: {template.name}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"Failed to approve template: {str(e)}", success=False)
|
self.log(f"Failed to approve template: {e!s}", success=False)
|
||||||
|
|
||||||
def test_template_publish(self) -> None:
|
def test_template_publish(self) -> None:
|
||||||
"""测试发布模板"""
|
"""测试发布模板"""
|
||||||
@@ -322,7 +324,7 @@ class TestDeveloperEcosystem:
|
|||||||
if template:
|
if template:
|
||||||
self.log(f"Published template: {template.name}")
|
self.log(f"Published template: {template.name}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"Failed to publish template: {str(e)}", success=False)
|
self.log(f"Failed to publish template: {e!s}", success=False)
|
||||||
|
|
||||||
def test_template_review(self) -> None:
|
def test_template_review(self) -> None:
|
||||||
"""测试添加模板评价"""
|
"""测试添加模板评价"""
|
||||||
@@ -338,7 +340,7 @@ class TestDeveloperEcosystem:
|
|||||||
)
|
)
|
||||||
self.log(f"Added template review: {review.rating} stars")
|
self.log(f"Added template review: {review.rating} stars")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"Failed to add template review: {str(e)}", success=False)
|
self.log(f"Failed to add template review: {e!s}", success=False)
|
||||||
|
|
||||||
def test_plugin_create(self) -> None:
|
def test_plugin_create(self) -> None:
|
||||||
"""测试创建插件"""
|
"""测试创建插件"""
|
||||||
@@ -384,7 +386,7 @@ class TestDeveloperEcosystem:
|
|||||||
self.log(f"Created free plugin: {plugin_free.name}")
|
self.log(f"Created free plugin: {plugin_free.name}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"Failed to create plugin: {str(e)}", success=False)
|
self.log(f"Failed to create plugin: {e!s}", success=False)
|
||||||
|
|
||||||
def test_plugin_list(self) -> None:
|
def test_plugin_list(self) -> None:
|
||||||
"""测试列出插件"""
|
"""测试列出插件"""
|
||||||
@@ -397,7 +399,7 @@ class TestDeveloperEcosystem:
|
|||||||
self.log(f"Found {len(integration_plugins)} integration plugins")
|
self.log(f"Found {len(integration_plugins)} integration plugins")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"Failed to list plugins: {str(e)}", success=False)
|
self.log(f"Failed to list plugins: {e!s}", success=False)
|
||||||
|
|
||||||
def test_plugin_get(self) -> None:
|
def test_plugin_get(self) -> None:
|
||||||
"""测试获取插件详情"""
|
"""测试获取插件详情"""
|
||||||
@@ -407,7 +409,7 @@ class TestDeveloperEcosystem:
|
|||||||
if plugin:
|
if plugin:
|
||||||
self.log(f"Retrieved plugin: {plugin.name}")
|
self.log(f"Retrieved plugin: {plugin.name}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"Failed to get plugin: {str(e)}", success=False)
|
self.log(f"Failed to get plugin: {e!s}", success=False)
|
||||||
|
|
||||||
def test_plugin_review(self) -> None:
|
def test_plugin_review(self) -> None:
|
||||||
"""测试审核插件"""
|
"""测试审核插件"""
|
||||||
@@ -422,7 +424,7 @@ class TestDeveloperEcosystem:
|
|||||||
if plugin:
|
if plugin:
|
||||||
self.log(f"Reviewed plugin: {plugin.name} ({plugin.status.value})")
|
self.log(f"Reviewed plugin: {plugin.name} ({plugin.status.value})")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"Failed to review plugin: {str(e)}", success=False)
|
self.log(f"Failed to review plugin: {e!s}", success=False)
|
||||||
|
|
||||||
def test_plugin_publish(self) -> None:
|
def test_plugin_publish(self) -> None:
|
||||||
"""测试发布插件"""
|
"""测试发布插件"""
|
||||||
@@ -432,7 +434,7 @@ class TestDeveloperEcosystem:
|
|||||||
if plugin:
|
if plugin:
|
||||||
self.log(f"Published plugin: {plugin.name}")
|
self.log(f"Published plugin: {plugin.name}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"Failed to publish plugin: {str(e)}", success=False)
|
self.log(f"Failed to publish plugin: {e!s}", success=False)
|
||||||
|
|
||||||
def test_plugin_review_add(self) -> None:
|
def test_plugin_review_add(self) -> None:
|
||||||
"""测试添加插件评价"""
|
"""测试添加插件评价"""
|
||||||
@@ -448,7 +450,7 @@ class TestDeveloperEcosystem:
|
|||||||
)
|
)
|
||||||
self.log(f"Added plugin review: {review.rating} stars")
|
self.log(f"Added plugin review: {review.rating} stars")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"Failed to add plugin review: {str(e)}", success=False)
|
self.log(f"Failed to add plugin review: {e!s}", success=False)
|
||||||
|
|
||||||
def test_developer_profile_create(self) -> None:
|
def test_developer_profile_create(self) -> None:
|
||||||
"""测试创建开发者档案"""
|
"""测试创建开发者档案"""
|
||||||
@@ -479,7 +481,7 @@ class TestDeveloperEcosystem:
|
|||||||
self.log(f"Created developer profile: {profile2.display_name}")
|
self.log(f"Created developer profile: {profile2.display_name}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"Failed to create developer profile: {str(e)}", success=False)
|
self.log(f"Failed to create developer profile: {e!s}", success=False)
|
||||||
|
|
||||||
def test_developer_profile_get(self) -> None:
|
def test_developer_profile_get(self) -> None:
|
||||||
"""测试获取开发者档案"""
|
"""测试获取开发者档案"""
|
||||||
@@ -489,19 +491,20 @@ class TestDeveloperEcosystem:
|
|||||||
if profile:
|
if profile:
|
||||||
self.log(f"Retrieved developer profile: {profile.display_name}")
|
self.log(f"Retrieved developer profile: {profile.display_name}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"Failed to get developer profile: {str(e)}", success=False)
|
self.log(f"Failed to get developer profile: {e!s}", success=False)
|
||||||
|
|
||||||
def test_developer_verify(self) -> None:
|
def test_developer_verify(self) -> None:
|
||||||
"""测试验证开发者"""
|
"""测试验证开发者"""
|
||||||
try:
|
try:
|
||||||
if self.created_ids["developer"]:
|
if self.created_ids["developer"]:
|
||||||
profile = self.manager.verify_developer(
|
profile = self.manager.verify_developer(
|
||||||
self.created_ids["developer"][0], DeveloperStatus.VERIFIED,
|
self.created_ids["developer"][0],
|
||||||
|
DeveloperStatus.VERIFIED,
|
||||||
)
|
)
|
||||||
if profile:
|
if profile:
|
||||||
self.log(f"Verified developer: {profile.display_name} ({profile.status.value})")
|
self.log(f"Verified developer: {profile.display_name} ({profile.status.value})")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"Failed to verify developer: {str(e)}", success=False)
|
self.log(f"Failed to verify developer: {e!s}", success=False)
|
||||||
|
|
||||||
def test_developer_stats_update(self) -> None:
|
def test_developer_stats_update(self) -> None:
|
||||||
"""测试更新开发者统计"""
|
"""测试更新开发者统计"""
|
||||||
@@ -513,7 +516,7 @@ class TestDeveloperEcosystem:
|
|||||||
f"Updated developer stats: {profile.plugin_count} plugins, {profile.template_count} templates",
|
f"Updated developer stats: {profile.plugin_count} plugins, {profile.template_count} templates",
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"Failed to update developer stats: {str(e)}", success=False)
|
self.log(f"Failed to update developer stats: {e!s}", success=False)
|
||||||
|
|
||||||
def test_code_example_create(self) -> None:
|
def test_code_example_create(self) -> None:
|
||||||
"""测试创建代码示例"""
|
"""测试创建代码示例"""
|
||||||
@@ -562,7 +565,7 @@ console.log('Upload complete:', result.id);
|
|||||||
self.log(f"Created code example: {example_js.title}")
|
self.log(f"Created code example: {example_js.title}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"Failed to create code example: {str(e)}", success=False)
|
self.log(f"Failed to create code example: {e!s}", success=False)
|
||||||
|
|
||||||
def test_code_example_list(self) -> None:
|
def test_code_example_list(self) -> None:
|
||||||
"""测试列出代码示例"""
|
"""测试列出代码示例"""
|
||||||
@@ -575,7 +578,7 @@ console.log('Upload complete:', result.id);
|
|||||||
self.log(f"Found {len(python_examples)} Python examples")
|
self.log(f"Found {len(python_examples)} Python examples")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"Failed to list code examples: {str(e)}", success=False)
|
self.log(f"Failed to list code examples: {e!s}", success=False)
|
||||||
|
|
||||||
def test_code_example_get(self) -> None:
|
def test_code_example_get(self) -> None:
|
||||||
"""测试获取代码示例详情"""
|
"""测试获取代码示例详情"""
|
||||||
@@ -587,7 +590,7 @@ console.log('Upload complete:', result.id);
|
|||||||
f"Retrieved code example: {example.title} (views: {example.view_count})",
|
f"Retrieved code example: {example.title} (views: {example.view_count})",
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"Failed to get code example: {str(e)}", success=False)
|
self.log(f"Failed to get code example: {e!s}", success=False)
|
||||||
|
|
||||||
def test_portal_config_create(self) -> None:
|
def test_portal_config_create(self) -> None:
|
||||||
"""测试创建开发者门户配置"""
|
"""测试创建开发者门户配置"""
|
||||||
@@ -608,7 +611,7 @@ console.log('Upload complete:', result.id);
|
|||||||
self.log(f"Created portal config: {config.name}")
|
self.log(f"Created portal config: {config.name}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"Failed to create portal config: {str(e)}", success=False)
|
self.log(f"Failed to create portal config: {e!s}", success=False)
|
||||||
|
|
||||||
def test_portal_config_get(self) -> None:
|
def test_portal_config_get(self) -> None:
|
||||||
"""测试获取开发者门户配置"""
|
"""测试获取开发者门户配置"""
|
||||||
@@ -624,7 +627,7 @@ console.log('Upload complete:', result.id);
|
|||||||
self.log(f"Active portal config: {active_config.name}")
|
self.log(f"Active portal config: {active_config.name}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"Failed to get portal config: {str(e)}", success=False)
|
self.log(f"Failed to get portal config: {e!s}", success=False)
|
||||||
|
|
||||||
def test_revenue_record(self) -> None:
|
def test_revenue_record(self) -> None:
|
||||||
"""测试记录开发者收益"""
|
"""测试记录开发者收益"""
|
||||||
@@ -644,7 +647,7 @@ console.log('Upload complete:', result.id);
|
|||||||
self.log(f" - Platform fee: {revenue.platform_fee}")
|
self.log(f" - Platform fee: {revenue.platform_fee}")
|
||||||
self.log(f" - Developer earnings: {revenue.developer_earnings}")
|
self.log(f" - Developer earnings: {revenue.developer_earnings}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"Failed to record revenue: {str(e)}", success=False)
|
self.log(f"Failed to record revenue: {e!s}", success=False)
|
||||||
|
|
||||||
def test_revenue_summary(self) -> None:
|
def test_revenue_summary(self) -> None:
|
||||||
"""测试获取开发者收益汇总"""
|
"""测试获取开发者收益汇总"""
|
||||||
@@ -659,7 +662,7 @@ console.log('Upload complete:', result.id);
|
|||||||
self.log(f" - Total earnings: {summary['total_earnings']}")
|
self.log(f" - Total earnings: {summary['total_earnings']}")
|
||||||
self.log(f" - Transaction count: {summary['transaction_count']}")
|
self.log(f" - Transaction count: {summary['transaction_count']}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"Failed to get revenue summary: {str(e)}", success=False)
|
self.log(f"Failed to get revenue summary: {e!s}", success=False)
|
||||||
|
|
||||||
def print_summary(self) -> None:
|
def print_summary(self) -> None:
|
||||||
"""打印测试摘要"""
|
"""打印测试摘要"""
|
||||||
|
|||||||
@@ -129,7 +129,9 @@ class TestOpsManager:
|
|||||||
|
|
||||||
# 更新告警规则
|
# 更新告警规则
|
||||||
updated_rule = self.manager.update_alert_rule(
|
updated_rule = self.manager.update_alert_rule(
|
||||||
rule1.id, threshold=85.0, description="更新后的描述",
|
rule1.id,
|
||||||
|
threshold=85.0,
|
||||||
|
description="更新后的描述",
|
||||||
)
|
)
|
||||||
assert updated_rule.threshold == 85.0
|
assert updated_rule.threshold == 85.0
|
||||||
self.log(f"Updated alert rule threshold to {updated_rule.threshold}")
|
self.log(f"Updated alert rule threshold to {updated_rule.threshold}")
|
||||||
@@ -421,7 +423,9 @@ class TestOpsManager:
|
|||||||
|
|
||||||
# 模拟扩缩容评估
|
# 模拟扩缩容评估
|
||||||
event = self.manager.evaluate_scaling_policy(
|
event = self.manager.evaluate_scaling_policy(
|
||||||
policy_id=policy.id, current_instances=3, current_utilization=0.85,
|
policy_id=policy.id,
|
||||||
|
current_instances=3,
|
||||||
|
current_utilization=0.85,
|
||||||
)
|
)
|
||||||
|
|
||||||
if event:
|
if event:
|
||||||
@@ -439,7 +443,8 @@ class TestOpsManager:
|
|||||||
with self.manager._get_db() as conn:
|
with self.manager._get_db() as conn:
|
||||||
conn.execute("DELETE FROM scaling_events WHERE tenant_id = ?", (self.tenant_id,))
|
conn.execute("DELETE FROM scaling_events WHERE tenant_id = ?", (self.tenant_id,))
|
||||||
conn.execute(
|
conn.execute(
|
||||||
"DELETE FROM auto_scaling_policies WHERE tenant_id = ?", (self.tenant_id,),
|
"DELETE FROM auto_scaling_policies WHERE tenant_id = ?",
|
||||||
|
(self.tenant_id,),
|
||||||
)
|
)
|
||||||
conn.commit()
|
conn.commit()
|
||||||
self.log("Cleaned up auto scaling test data")
|
self.log("Cleaned up auto scaling test data")
|
||||||
@@ -530,7 +535,8 @@ class TestOpsManager:
|
|||||||
|
|
||||||
# 发起故障转移
|
# 发起故障转移
|
||||||
event = self.manager.initiate_failover(
|
event = self.manager.initiate_failover(
|
||||||
config_id=config.id, reason="Primary region health check failed",
|
config_id=config.id,
|
||||||
|
reason="Primary region health check failed",
|
||||||
)
|
)
|
||||||
|
|
||||||
if event:
|
if event:
|
||||||
@@ -638,7 +644,9 @@ class TestOpsManager:
|
|||||||
# 生成成本报告
|
# 生成成本报告
|
||||||
now = datetime.now()
|
now = datetime.now()
|
||||||
report = self.manager.generate_cost_report(
|
report = self.manager.generate_cost_report(
|
||||||
tenant_id=self.tenant_id, year=now.year, month=now.month,
|
tenant_id=self.tenant_id,
|
||||||
|
year=now.year,
|
||||||
|
month=now.month,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.log(f"Generated cost report: {report.id}")
|
self.log(f"Generated cost report: {report.id}")
|
||||||
@@ -691,7 +699,8 @@ class TestOpsManager:
|
|||||||
)
|
)
|
||||||
conn.execute("DELETE FROM idle_resources WHERE tenant_id = ?", (self.tenant_id,))
|
conn.execute("DELETE FROM idle_resources WHERE tenant_id = ?", (self.tenant_id,))
|
||||||
conn.execute(
|
conn.execute(
|
||||||
"DELETE FROM resource_utilizations WHERE tenant_id = ?", (self.tenant_id,),
|
"DELETE FROM resource_utilizations WHERE tenant_id = ?",
|
||||||
|
(self.tenant_id,),
|
||||||
)
|
)
|
||||||
conn.execute("DELETE FROM cost_reports WHERE tenant_id = ?", (self.tenant_id,))
|
conn.execute("DELETE FROM cost_reports WHERE tenant_id = ?", (self.tenant_id,))
|
||||||
conn.commit()
|
conn.commit()
|
||||||
|
|||||||
@@ -19,7 +19,11 @@ class TingwuClient:
|
|||||||
raise ValueError("ALI_ACCESS_KEY and ALI_SECRET_KEY required")
|
raise ValueError("ALI_ACCESS_KEY and ALI_SECRET_KEY required")
|
||||||
|
|
||||||
def _sign_request(
|
def _sign_request(
|
||||||
self, method: str, uri: str, query: str = "", body: str = "",
|
self,
|
||||||
|
method: str,
|
||||||
|
uri: str,
|
||||||
|
query: str = "",
|
||||||
|
body: str = "",
|
||||||
) -> dict[str, str]:
|
) -> dict[str, str]:
|
||||||
"""阿里云签名 V3"""
|
"""阿里云签名 V3"""
|
||||||
timestamp = datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%SZ")
|
timestamp = datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%SZ")
|
||||||
@@ -43,7 +47,8 @@ class TingwuClient:
|
|||||||
from alibabacloud_tingwu20230930.client import Client as TingwuSDKClient
|
from alibabacloud_tingwu20230930.client import Client as TingwuSDKClient
|
||||||
|
|
||||||
config = open_api_models.Config(
|
config = open_api_models.Config(
|
||||||
access_key_id=self.access_key, access_key_secret=self.secret_key,
|
access_key_id=self.access_key,
|
||||||
|
access_key_secret=self.secret_key,
|
||||||
)
|
)
|
||||||
config.endpoint = "tingwu.cn-beijing.aliyuncs.com"
|
config.endpoint = "tingwu.cn-beijing.aliyuncs.com"
|
||||||
client = TingwuSDKClient(config)
|
client = TingwuSDKClient(config)
|
||||||
@@ -53,7 +58,8 @@ class TingwuClient:
|
|||||||
input=tingwu_models.Input(source="OSS", file_url=audio_url),
|
input=tingwu_models.Input(source="OSS", file_url=audio_url),
|
||||||
parameters=tingwu_models.Parameters(
|
parameters=tingwu_models.Parameters(
|
||||||
transcription=tingwu_models.Transcription(
|
transcription=tingwu_models.Transcription(
|
||||||
diarization_enabled=True, sentence_max_length=20,
|
diarization_enabled=True,
|
||||||
|
sentence_max_length=20,
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@@ -73,7 +79,10 @@ class TingwuClient:
|
|||||||
return f"mock_task_{int(time.time())}"
|
return f"mock_task_{int(time.time())}"
|
||||||
|
|
||||||
def get_task_result(
|
def get_task_result(
|
||||||
self, task_id: str, max_retries: int = 60, interval: int = 5,
|
self,
|
||||||
|
task_id: str,
|
||||||
|
max_retries: int = 60,
|
||||||
|
interval: int = 5,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""获取任务结果"""
|
"""获取任务结果"""
|
||||||
try:
|
try:
|
||||||
@@ -83,7 +92,8 @@ class TingwuClient:
|
|||||||
from alibabacloud_tingwu20230930.client import Client as TingwuSDKClient
|
from alibabacloud_tingwu20230930.client import Client as TingwuSDKClient
|
||||||
|
|
||||||
config = open_api_models.Config(
|
config = open_api_models.Config(
|
||||||
access_key_id=self.access_key, access_key_secret=self.secret_key,
|
access_key_id=self.access_key,
|
||||||
|
access_key_secret=self.secret_key,
|
||||||
)
|
)
|
||||||
config.endpoint = "tingwu.cn-beijing.aliyuncs.com"
|
config.endpoint = "tingwu.cn-beijing.aliyuncs.com"
|
||||||
client = TingwuSDKClient(config)
|
client = TingwuSDKClient(config)
|
||||||
|
|||||||
@@ -264,7 +264,9 @@ class WebhookNotifier:
|
|||||||
secret_enc = config.secret.encode("utf-8")
|
secret_enc = config.secret.encode("utf-8")
|
||||||
string_to_sign = f"{timestamp}\n{config.secret}"
|
string_to_sign = f"{timestamp}\n{config.secret}"
|
||||||
hmac_code = hmac.new(
|
hmac_code = hmac.new(
|
||||||
secret_enc, string_to_sign.encode("utf-8"), digestmod=hashlib.sha256,
|
secret_enc,
|
||||||
|
string_to_sign.encode("utf-8"),
|
||||||
|
digestmod=hashlib.sha256,
|
||||||
).digest()
|
).digest()
|
||||||
sign = urllib.parse.quote_plus(base64.b64encode(hmac_code))
|
sign = urllib.parse.quote_plus(base64.b64encode(hmac_code))
|
||||||
url = f"{config.url}×tamp = {timestamp}&sign = {sign}"
|
url = f"{config.url}×tamp = {timestamp}&sign = {sign}"
|
||||||
@@ -497,7 +499,10 @@ class WorkflowManager:
|
|||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
def list_workflows(
|
def list_workflows(
|
||||||
self, project_id: str = None, status: str = None, workflow_type: str = None,
|
self,
|
||||||
|
project_id: str | None = None,
|
||||||
|
status: str = None,
|
||||||
|
workflow_type: str = None,
|
||||||
) -> list[Workflow]:
|
) -> list[Workflow]:
|
||||||
"""列出工作流"""
|
"""列出工作流"""
|
||||||
conn = self.db.get_conn()
|
conn = self.db.get_conn()
|
||||||
@@ -518,7 +523,8 @@ class WorkflowManager:
|
|||||||
where_clause = " AND ".join(conditions) if conditions else "1 = 1"
|
where_clause = " AND ".join(conditions) if conditions else "1 = 1"
|
||||||
|
|
||||||
rows = conn.execute(
|
rows = conn.execute(
|
||||||
f"SELECT * FROM workflows WHERE {where_clause} ORDER BY created_at DESC", params,
|
f"SELECT * FROM workflows WHERE {where_clause} ORDER BY created_at DESC",
|
||||||
|
params,
|
||||||
).fetchall()
|
).fetchall()
|
||||||
|
|
||||||
return [self._row_to_workflow(row) for row in rows]
|
return [self._row_to_workflow(row) for row in rows]
|
||||||
@@ -780,7 +786,8 @@ class WorkflowManager:
|
|||||||
conn = self.db.get_conn()
|
conn = self.db.get_conn()
|
||||||
try:
|
try:
|
||||||
row = conn.execute(
|
row = conn.execute(
|
||||||
"SELECT * FROM webhook_configs WHERE id = ?", (webhook_id,),
|
"SELECT * FROM webhook_configs WHERE id = ?",
|
||||||
|
(webhook_id,),
|
||||||
).fetchone()
|
).fetchone()
|
||||||
|
|
||||||
if not row:
|
if not row:
|
||||||
@@ -962,9 +969,9 @@ class WorkflowManager:
|
|||||||
|
|
||||||
def list_logs(
|
def list_logs(
|
||||||
self,
|
self,
|
||||||
workflow_id: str = None,
|
workflow_id: str | None = None,
|
||||||
task_id: str = None,
|
task_id: str | None = None,
|
||||||
status: str = None,
|
status: str | None = None,
|
||||||
limit: int = 100,
|
limit: int = 100,
|
||||||
offset: int = 0,
|
offset: int = 0,
|
||||||
) -> list[WorkflowLog]:
|
) -> list[WorkflowLog]:
|
||||||
@@ -1074,7 +1081,7 @@ class WorkflowManager:
|
|||||||
|
|
||||||
# ==================== Workflow Execution ====================
|
# ==================== Workflow Execution ====================
|
||||||
|
|
||||||
async def execute_workflow(self, workflow_id: str, input_data: dict = None) -> dict:
|
async def execute_workflow(self, workflow_id: str, input_data: dict | None = None) -> dict:
|
||||||
"""执行工作流"""
|
"""执行工作流"""
|
||||||
workflow = self.get_workflow(workflow_id)
|
workflow = self.get_workflow(workflow_id)
|
||||||
if not workflow:
|
if not workflow:
|
||||||
@@ -1159,7 +1166,10 @@ class WorkflowManager:
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
async def _execute_tasks_with_deps(
|
async def _execute_tasks_with_deps(
|
||||||
self, tasks: list[WorkflowTask], input_data: dict, log_id: str,
|
self,
|
||||||
|
tasks: list[WorkflowTask],
|
||||||
|
input_data: dict,
|
||||||
|
log_id: str,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""按依赖顺序执行任务"""
|
"""按依赖顺序执行任务"""
|
||||||
results = {}
|
results = {}
|
||||||
@@ -1413,7 +1423,10 @@ class WorkflowManager:
|
|||||||
# ==================== Notification ====================
|
# ==================== Notification ====================
|
||||||
|
|
||||||
async def _send_workflow_notification(
|
async def _send_workflow_notification(
|
||||||
self, workflow: Workflow, results: dict, success: bool = True,
|
self,
|
||||||
|
workflow: Workflow,
|
||||||
|
results: dict,
|
||||||
|
success: bool = True,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""发送工作流执行通知"""
|
"""发送工作流执行通知"""
|
||||||
if not workflow.webhook_ids:
|
if not workflow.webhook_ids:
|
||||||
|
|||||||
Reference in New Issue
Block a user