diff --git a/backend/ai_manager.py b/backend/ai_manager.py
index 8aab05e..9468613 100644
--- a/backend/ai_manager.py
+++ b/backend/ai_manager.py
@@ -209,7 +209,7 @@ class AIManager:
self.openai_api_key = os.getenv("OPENAI_API_KEY", "")
self.anthropic_api_key = os.getenv("ANTHROPIC_API_KEY", "")
- def _get_db(self):
+ def _get_db(self) -> sqlite3.Connection:
"""获取数据库连接"""
conn = sqlite3.connect(self.db_path)
conn.row_factory = sqlite3.Row
@@ -921,7 +921,6 @@ class AIManager:
# 解析关键要点
key_points = []
- import re
# 尝试从 JSON 中提取
json_match = re.search(r"\{.*?\}", content, re.DOTALL)
@@ -1312,7 +1311,7 @@ class AIManager:
return [self._row_to_prediction_result(row) for row in rows]
- def update_prediction_feedback(self, prediction_id: str, actual_value: str, is_correct: bool):
+ def update_prediction_feedback(self, prediction_id: str, actual_value: str, is_correct: bool) -> None:
"""更新预测反馈(用于模型改进)"""
with self._get_db() as conn:
conn.execute(
diff --git a/backend/api_key_manager.py b/backend/api_key_manager.py
index 65cae67..fd7846e 100644
--- a/backend/api_key_manager.py
+++ b/backend/api_key_manager.py
@@ -51,7 +51,7 @@ class ApiKeyManager:
self.db_path = db_path
self._init_db()
- def _init_db(self):
+ def _init_db(self) -> None:
"""初始化数据库表"""
with sqlite3.connect(self.db_path) as conn:
conn.executescript("""
@@ -331,7 +331,7 @@ class ApiKeyManager:
conn.commit()
return cursor.rowcount > 0
- def update_last_used(self, key_id: str):
+ def update_last_used(self, key_id: str) -> None:
"""更新最后使用时间"""
with sqlite3.connect(self.db_path) as conn:
conn.execute(
diff --git a/backend/collaboration_manager.py b/backend/collaboration_manager.py
index d7e0f8a..ba8cdcb 100644
--- a/backend/collaboration_manager.py
+++ b/backend/collaboration_manager.py
@@ -195,7 +195,7 @@ class CollaborationManager:
data = f"{project_id}:{datetime.now().timestamp()}:{uuid.uuid4()}"
return hashlib.sha256(data.encode()).hexdigest()[:32]
- def _save_share_to_db(self, share: ProjectShare):
+ def _save_share_to_db(self, share: ProjectShare) -> None:
"""保存分享记录到数据库"""
cursor = self.db.conn.cursor()
cursor.execute(
@@ -286,7 +286,7 @@ class CollaborationManager:
allow_export=bool(row[12]),
)
- def increment_share_usage(self, token: str):
+ def increment_share_usage(self, token: str) -> None:
"""增加分享链接使用次数"""
share = self._shares_cache.get(token)
if share:
@@ -403,7 +403,7 @@ class CollaborationManager:
return comment
- def _save_comment_to_db(self, comment: Comment):
+ def _save_comment_to_db(self, comment: Comment) -> None:
"""保存评论到数据库"""
cursor = self.db.conn.cursor()
cursor.execute(
@@ -616,7 +616,7 @@ class CollaborationManager:
return record
- def _save_change_to_db(self, record: ChangeRecord):
+ def _save_change_to_db(self, record: ChangeRecord) -> None:
"""保存变更记录到数据库"""
cursor = self.db.conn.cursor()
cursor.execute(
@@ -862,7 +862,7 @@ class CollaborationManager:
}
return permissions_map.get(role, ["read"])
- def _save_member_to_db(self, member: TeamMember):
+ def _save_member_to_db(self, member: TeamMember) -> None:
"""保存成员到数据库"""
cursor = self.db.conn.cursor()
cursor.execute(
@@ -970,7 +970,7 @@ class CollaborationManager:
permissions = json.loads(row[0]) if row[0] else []
return permission in permissions or "admin" in permissions
- def update_last_active(self, project_id: str, user_id: str):
+ def update_last_active(self, project_id: str, user_id: str) -> None:
"""更新用户最后活跃时间"""
if not self.db:
return
@@ -992,7 +992,7 @@ class CollaborationManager:
_collaboration_manager = None
-def get_collaboration_manager(db_manager=None):
+def get_collaboration_manager(db_manager=None) -> "CollaborationManager":
"""获取协作管理器单例"""
global _collaboration_manager
if _collaboration_manager is None:
diff --git a/backend/db_manager.py b/backend/db_manager.py
index 667f5cc..03afbaf 100644
--- a/backend/db_manager.py
+++ b/backend/db_manager.py
@@ -123,7 +123,7 @@ class DatabaseManager:
conn.row_factory = sqlite3.Row
return conn
- def init_db(self):
+ def init_db(self) -> None:
"""初始化数据库表"""
with open(os.path.join(os.path.dirname(__file__), "schema.sql")) as f:
schema = f.read()
@@ -299,7 +299,7 @@ class DatabaseManager:
conn.close()
return self.get_entity(entity_id)
- def delete_entity(self, entity_id: str):
+ def delete_entity(self, entity_id: str) -> None:
"""删除实体及其关联数据"""
conn = self.get_conn()
conn.execute("DELETE FROM entity_mentions WHERE entity_id = ?", (entity_id,))
diff --git a/backend/developer_ecosystem_manager.py b/backend/developer_ecosystem_manager.py
index c0bbc89..a1b393e 100644
--- a/backend/developer_ecosystem_manager.py
+++ b/backend/developer_ecosystem_manager.py
@@ -352,7 +352,7 @@ class DeveloperEcosystemManager:
self.db_path = db_path
self.platform_fee_rate = 0.30 # 平台抽成比例 30%
- def _get_db(self):
+ def _get_db(self) -> sqlite3.Connection:
"""获取数据库连接"""
conn = sqlite3.connect(self.db_path)
conn.row_factory = sqlite3.Row
@@ -515,7 +515,7 @@ class DeveloperEcosystemManager:
return self.get_sdk_release(sdk_id)
- def increment_sdk_download(self, sdk_id: str):
+ def increment_sdk_download(self, sdk_id: str) -> None:
"""增加 SDK 下载计数"""
with self._get_db() as conn:
conn.execute(
@@ -785,7 +785,7 @@ class DeveloperEcosystemManager:
return self.get_template(template_id)
- def increment_template_install(self, template_id: str):
+ def increment_template_install(self, template_id: str) -> None:
"""增加模板安装计数"""
with self._get_db() as conn:
conn.execute(
@@ -852,7 +852,7 @@ class DeveloperEcosystemManager:
return review
- def _update_template_rating(self, conn, template_id: str):
+ def _update_template_rating(self, conn, template_id: str) -> None:
"""更新模板评分"""
row = conn.execute(
"""
@@ -1084,7 +1084,7 @@ class DeveloperEcosystemManager:
return self.get_plugin(plugin_id)
- def increment_plugin_install(self, plugin_id: str, active: bool = True):
+ def increment_plugin_install(self, plugin_id: str, active: bool = True) -> None:
"""增加插件安装计数"""
with self._get_db() as conn:
conn.execute(
@@ -1160,7 +1160,7 @@ class DeveloperEcosystemManager:
return review
- def _update_plugin_rating(self, conn, plugin_id: str):
+ def _update_plugin_rating(self, conn, plugin_id: str) -> None:
"""更新插件评分"""
row = conn.execute(
"""
@@ -1421,7 +1421,7 @@ class DeveloperEcosystemManager:
return self.get_developer_profile(developer_id)
- def update_developer_stats(self, developer_id: str):
+ def update_developer_stats(self, developer_id: str) -> None:
"""更新开发者统计信息"""
with self._get_db() as conn:
# 统计插件数量
@@ -1574,7 +1574,7 @@ class DeveloperEcosystemManager:
rows = conn.execute(query, params).fetchall()
return [self._row_to_code_example(row) for row in rows]
- def increment_example_view(self, example_id: str):
+ def increment_example_view(self, example_id: str) -> None:
"""增加代码示例查看计数"""
with self._get_db() as conn:
conn.execute(
@@ -1587,7 +1587,7 @@ class DeveloperEcosystemManager:
)
conn.commit()
- def increment_example_copy(self, example_id: str):
+ def increment_example_copy(self, example_id: str) -> None:
"""增加代码示例复制计数"""
with self._get_db() as conn:
conn.execute(
diff --git a/backend/enterprise_manager.py b/backend/enterprise_manager.py
index 6d86c1b..aab1d23 100644
--- a/backend/enterprise_manager.py
+++ b/backend/enterprise_manager.py
@@ -329,7 +329,7 @@ class EnterpriseManager:
conn.row_factory = sqlite3.Row
return conn
- def _init_db(self):
+ def _init_db(self) -> None:
"""初始化数据库表"""
conn = self._get_connection()
try:
@@ -1190,7 +1190,7 @@ class EnterpriseManager:
# GET {scim_base_url}/Users
return []
- def _upsert_scim_user(self, conn: sqlite3.Connection, tenant_id: str, user_data: dict[str, Any]):
+ def _upsert_scim_user(self, conn: sqlite3.Connection, tenant_id: str, user_data: dict[str, Any]) -> None:
"""插入或更新 SCIM 用户"""
cursor = conn.cursor()
diff --git a/backend/export_manager.py b/backend/export_manager.py
index da174cf..f908927 100644
--- a/backend/export_manager.py
+++ b/backend/export_manager.py
@@ -22,14 +22,7 @@ try:
from reportlab.lib.pagesizes import A4
from reportlab.lib.styles import ParagraphStyle, getSampleStyleSheet
from reportlab.lib.units import inch
- from reportlab.platypus import (
- PageBreak,
- Paragraph,
- SimpleDocTemplate,
- Spacer,
- Table,
- TableStyle,
- )
+ from reportlab.platypus import PageBreak, Paragraph, SimpleDocTemplate, Spacer, Table, TableStyle
REPORTLAB_AVAILABLE = True
except ImportError:
@@ -194,20 +187,16 @@ class ExportManager:
f'fill="white" stroke="#bdc3c7" rx="5"/>'
)
svg_parts.append(
- f'实体类型'
+ f'实体类型'
)
for i, (etype, color) in enumerate(type_colors.items()):
if etype != "default":
y_pos = legend_y + 25 + i * 20
- svg_parts.append(
- f''
- )
+ svg_parts.append(f'')
text_y = y_pos + 4
svg_parts.append(
- f'{etype}'
+ f'{etype}'
)
svg_parts.append("")
@@ -320,7 +309,6 @@ class ExportManager:
Returns:
CSV 字符串
"""
- import csv
output = io.StringIO()
writer = csv.writer(output)
@@ -586,7 +574,7 @@ class ExportManager:
_export_manager = None
-def get_export_manager(db_manager=None):
+def get_export_manager(db_manager=None) -> "ExportManager":
"""获取导出管理器实例"""
global _export_manager
if _export_manager is None:
diff --git a/backend/growth_manager.py b/backend/growth_manager.py
index a630e34..a988c88 100644
--- a/backend/growth_manager.py
+++ b/backend/growth_manager.py
@@ -369,7 +369,7 @@ class GrowthManager:
self.segment_write_key = os.getenv("SEGMENT_WRITE_KEY", "")
self.sendgrid_api_key = os.getenv("SENDGRID_API_KEY", "")
- def _get_db(self):
+ def _get_db(self) -> sqlite3.Connection:
"""获取数据库连接"""
conn = sqlite3.connect(self.db_path)
conn.row_factory = sqlite3.Row
diff --git a/backend/image_processor.py b/backend/image_processor.py
index 58e514f..513410f 100644
--- a/backend/image_processor.py
+++ b/backend/image_processor.py
@@ -93,7 +93,7 @@ class ImageProcessor:
"other": "其他",
}
- def __init__(self, temp_dir: str = None):
+ def __init__(self, temp_dir: str | None = None) -> None:
"""
初始化图片处理器
@@ -103,7 +103,7 @@ class ImageProcessor:
self.temp_dir = temp_dir or os.path.join(os.getcwd(), "temp", "images")
os.makedirs(self.temp_dir, exist_ok=True)
- def preprocess_image(self, image, image_type: str = None):
+ def preprocess_image(self, image, image_type: str | None = None):
"""
预处理图片以提高OCR质量
@@ -145,7 +145,7 @@ class ImageProcessor:
print(f"Image preprocessing error: {e}")
return image
- def _enhance_whiteboard(self, image):
+ def _enhance_whiteboard(self, image):
"""增强白板图片"""
# 转换为灰度
gray = image.convert("L")
@@ -160,7 +160,7 @@ class ImageProcessor:
return binary.convert("L")
- def _enhance_handwritten(self, image):
+ def _enhance_handwritten(self, image):
"""增强手写笔记图片"""
# 转换为灰度
gray = image.convert("L")
diff --git a/backend/knowledge_reasoner.py b/backend/knowledge_reasoner.py
index 0cac3b6..1810ef6 100644
--- a/backend/knowledge_reasoner.py
+++ b/backend/knowledge_reasoner.py
@@ -163,8 +163,6 @@ class KnowledgeReasoner:
content = await self._call_llm(prompt, temperature=0.3)
- import re
-
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
if json_match:
@@ -217,8 +215,6 @@ class KnowledgeReasoner:
content = await self._call_llm(prompt, temperature=0.3)
- import re
-
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
if json_match:
@@ -271,8 +267,6 @@ class KnowledgeReasoner:
content = await self._call_llm(prompt, temperature=0.3)
- import re
-
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
if json_match:
@@ -325,8 +319,6 @@ class KnowledgeReasoner:
content = await self._call_llm(prompt, temperature=0.4)
- import re
-
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
if json_match:
@@ -474,8 +466,6 @@ class KnowledgeReasoner:
content = await self._call_llm(prompt, temperature=0.3)
- import re
-
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
if json_match:
diff --git a/backend/llm_client.py b/backend/llm_client.py
index 8e5da38..2068177 100644
--- a/backend/llm_client.py
+++ b/backend/llm_client.py
@@ -204,15 +204,13 @@ class LLMClient:
messages = [ChatMessage(role="user", content=prompt)]
content = await self.chat(messages, temperature=0.1)
- import re
-
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
if not json_match:
return {"intent": "unknown", "explanation": "无法解析指令"}
try:
return json.loads(json_match.group())
- except BaseException:
+ except (json.JSONDecodeError, KeyError, TypeError):
return {"intent": "unknown", "explanation": "解析失败"}
async def analyze_entity_evolution(self, entity_name: str, mentions: list[dict]) -> str:
diff --git a/backend/localization_manager.py b/backend/localization_manager.py
index 96c05c1..b4fe000 100644
--- a/backend/localization_manager.py
+++ b/backend/localization_manager.py
@@ -938,9 +938,7 @@ class LocalizationManager:
finally:
self._close_if_file_db(conn)
- def get_translation(
- self, key: str, language: str, namespace: str = "common", fallback: bool = True
- ) -> str | None:
+ def get_translation(self, key: str, language: str, namespace: str = "common", fallback: bool = True) -> str | None:
conn = self._get_connection()
try:
cursor = conn.cursor()
diff --git a/backend/main.py b/backend/main.py
index af752dc..33a5944 100644
--- a/backend/main.py
+++ b/backend/main.py
@@ -17,18 +17,7 @@ from datetime import datetime, timedelta
from typing import Any, Optional
import httpx
-from fastapi import (
- Body,
- Depends,
- FastAPI,
- File,
- Form,
- Header,
- HTTPException,
- Query,
- Request,
- UploadFile,
-)
+from fastapi import Body, Depends, FastAPI, File, Form, Header, HTTPException, Query, Request, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
@@ -45,18 +34,21 @@ if backend_dir not in sys.path:
# Import clients
try:
from oss_uploader import get_oss_uploader
+
OSS_AVAILABLE = True
except ImportError:
OSS_AVAILABLE = False
try:
from tingwu_client import TingwuClient
+
TINGWU_AVAILABLE = True
except ImportError:
TINGWU_AVAILABLE = False
try:
from db_manager import Entity, EntityMention, get_db_manager
+
DB_AVAILABLE = True
except ImportError as e:
print(f"DB import error: {e}")
@@ -64,30 +56,35 @@ except ImportError as e:
try:
from document_processor import DocumentProcessor
+
DOC_PROCESSOR_AVAILABLE = True
except ImportError:
DOC_PROCESSOR_AVAILABLE = False
try:
from entity_aligner import EntityAligner
+
ALIGNER_AVAILABLE = True
except ImportError:
ALIGNER_AVAILABLE = False
try:
from llm_client import ChatMessage, get_llm_client
+
LLM_CLIENT_AVAILABLE = True
except ImportError:
LLM_CLIENT_AVAILABLE = False
try:
from knowledge_reasoner import get_knowledge_reasoner
+
REASONER_AVAILABLE = True
except ImportError:
REASONER_AVAILABLE = False
try:
from export_manager import ExportEntity, ExportRelation, ExportTranscript, get_export_manager
+
EXPORT_AVAILABLE = True
except ImportError:
EXPORT_AVAILABLE = False
@@ -100,6 +97,7 @@ except ImportError:
# Phase 6: API Key Manager
try:
from api_key_manager import get_api_key_manager
+
API_KEY_AVAILABLE = True
except ImportError as e:
print(f"API Key Manager import error: {e}")
@@ -108,6 +106,7 @@ except ImportError as e:
# Phase 6: Rate Limiter
try:
from rate_limiter import RateLimitConfig, get_rate_limiter
+
RATE_LIMITER_AVAILABLE = True
except ImportError as e:
print(f"Rate Limiter import error: {e}")
@@ -116,6 +115,7 @@ except ImportError as e:
# Phase 7: Workflow Manager
try:
from workflow_manager import WebhookConfig, Workflow
+
WORKFLOW_AVAILABLE = True
except ImportError as e:
print(f"Workflow Manager import error: {e}")
@@ -124,6 +124,7 @@ except ImportError as e:
# Phase 7: Multimodal Support
try:
from multimodal_processor import get_multimodal_processor
+
MULTIMODAL_AVAILABLE = True
except ImportError as e:
print(f"Multimodal Processor import error: {e}")
@@ -131,6 +132,7 @@ except ImportError as e:
try:
from image_processor import get_image_processor
+
IMAGE_PROCESSOR_AVAILABLE = True
except ImportError as e:
print(f"Image Processor import error: {e}")
@@ -138,6 +140,7 @@ except ImportError as e:
try:
from multimodal_entity_linker import EntityLink, get_multimodal_entity_linker
+
MULTIMODAL_LINKER_AVAILABLE = True
except ImportError as e:
print(f"Multimodal Entity Linker import error: {e}")
@@ -145,14 +148,8 @@ except ImportError as e:
# Phase 7 Task 7: Plugin Manager
try:
- from plugin_manager import (
- BotHandler,
- Plugin,
- PluginStatus,
- PluginType,
- WebhookIntegration,
- get_plugin_manager,
- )
+ from plugin_manager import BotHandler, Plugin, PluginStatus, PluginType, WebhookIntegration, get_plugin_manager
+
PLUGIN_MANAGER_AVAILABLE = True
except ImportError as e:
print(f"Plugin Manager import error: {e}")
@@ -161,6 +158,7 @@ except ImportError as e:
# Phase 7 Task 3: Security Manager
try:
from security_manager import MaskingRuleType, get_security_manager
+
SECURITY_MANAGER_AVAILABLE = True
except ImportError as e:
print(f"Security Manager import error: {e}")
@@ -169,6 +167,7 @@ except ImportError as e:
# Phase 7 Task 4: Collaboration Manager
try:
from collaboration_manager import get_collaboration_manager
+
COLLABORATION_AVAILABLE = True
except ImportError as e:
print(f"Collaboration Manager import error: {e}")
@@ -184,6 +183,7 @@ except ImportError as e:
# Phase 7 Task 6: Search Manager
try:
from search_manager import SearchOperator, get_search_manager
+
SEARCH_MANAGER_AVAILABLE = True
except ImportError as e:
print(f"Search Manager import error: {e}")
@@ -192,6 +192,7 @@ except ImportError as e:
# Phase 7 Task 8: Performance Manager
try:
from performance_manager import get_performance_manager
+
PERFORMANCE_MANAGER_AVAILABLE = True
except ImportError as e:
print(f"Performance Manager import error: {e}")
@@ -200,6 +201,7 @@ except ImportError as e:
# Phase 8: Tenant Manager (Multi-Tenant SaaS)
try:
from tenant_manager import TenantRole, TenantStatus, TenantTier, get_tenant_manager
+
TENANT_MANAGER_AVAILABLE = True
except ImportError as e:
print(f"Tenant Manager import error: {e}")
@@ -208,6 +210,7 @@ except ImportError as e:
# Phase 8: Subscription Manager
try:
from subscription_manager import get_subscription_manager
+
SUBSCRIPTION_MANAGER_AVAILABLE = True
except ImportError as e:
print(f"Subscription Manager import error: {e}")
@@ -216,6 +219,7 @@ except ImportError as e:
# Phase 8: Enterprise Manager
try:
from enterprise_manager import get_enterprise_manager
+
ENTERPRISE_MANAGER_AVAILABLE = True
except ImportError as e:
print(f"Enterprise Manager import error: {e}")
@@ -224,6 +228,7 @@ except ImportError as e:
# Phase 8: Localization Manager
try:
from localization_manager import get_localization_manager
+
LOCALIZATION_MANAGER_AVAILABLE = True
except ImportError as e:
print(f"Localization Manager import error: {e}")
@@ -231,13 +236,8 @@ except ImportError as e:
# Phase 8 Task 4: AI Manager
try:
- from ai_manager import (
- ModelStatus,
- ModelType,
- MultimodalProvider,
- PredictionType,
- get_ai_manager,
- )
+ from ai_manager import ModelStatus, ModelType, MultimodalProvider, PredictionType, get_ai_manager
+
AI_MANAGER_AVAILABLE = True
except ImportError as e:
print(f"AI Manager import error: {e}")
@@ -253,6 +253,7 @@ try:
TrafficAllocationType,
WorkflowTriggerType,
)
+
GROWTH_MANAGER_AVAILABLE = True
except ImportError as e:
print(f"Growth Manager import error: {e}")
@@ -260,14 +261,8 @@ except ImportError as e:
# Phase 8 Task 8: Operations & Monitoring Manager
try:
- from ops_manager import (
- AlertChannelType,
- AlertRuleType,
- AlertSeverity,
- AlertStatus,
- ResourceType,
- get_ops_manager,
- )
+ from ops_manager import AlertChannelType, AlertRuleType, AlertSeverity, AlertStatus, ResourceType, get_ops_manager
+
OPS_MANAGER_AVAILABLE = True
except ImportError as e:
print(f"Ops Manager import error: {e}")
@@ -330,9 +325,12 @@ app = FastAPI(
{"name": "Localization", "description": "全球化与本地化(多语言、数据中心、支付方式、时区日历)"},
{"name": "AI Enhancement", "description": "AI 能力增强(自定义模型、多模态分析、智能摘要、预测分析)"},
{"name": "Growth & Analytics", "description": "运营与增长工具(用户行为分析、A/B 测试、邮件营销、推荐系统)"},
- {"name": "Operations & Monitoring", "description": "运维与监控(实时告警、容量规划、自动扩缩容、灾备故障转移、成本优化)"},
+ {
+ "name": "Operations & Monitoring",
+ "description": "运维与监控(实时告警、容量规划、自动扩缩容、灾备故障转移、成本优化)",
+ },
{"name": "System", "description": "系统信息"},
- ]
+ ],
)
app.add_middleware(
@@ -347,8 +345,12 @@ app.add_middleware(
# 公开访问的路径(不需要 API Key)
PUBLIC_PATHS = {
- "/", "/docs", "/openapi.json", "/redoc",
- "/api/v1/health", "/api/v1/status",
+ "/",
+ "/docs",
+ "/openapi.json",
+ "/redoc",
+ "/api/v1/health",
+ "/api/v1/status",
"/api/v1/api-keys", # POST 创建 API Key 不需要认证
}
@@ -384,8 +386,7 @@ async def verify_api_key(request: Request, x_api_key: str | None = Header(None,
if any(path.startswith(p) for p in ADMIN_PATHS):
if not x_api_key or x_api_key != MASTER_KEY:
raise HTTPException(
- status_code=403,
- detail="Admin access required. Provide valid master key in X-API-Key header."
+ status_code=403, detail="Admin access required. Provide valid master key in X-API-Key header."
)
return {"type": "admin", "key": x_api_key}
@@ -398,7 +399,7 @@ async def verify_api_key(request: Request, x_api_key: str | None = Header(None,
raise HTTPException(
status_code=401,
detail="API Key required. Provide your key in X-API-Key header.",
- headers={"WWW-Authenticate": "ApiKey"}
+ headers={"WWW-Authenticate": "ApiKey"},
)
# 验证 API Key
@@ -406,10 +407,7 @@ async def verify_api_key(request: Request, x_api_key: str | None = Header(None,
api_key = key_manager.validate_key(x_api_key)
if not api_key:
- raise HTTPException(
- status_code=401,
- detail="Invalid or expired API Key"
- )
+ raise HTTPException(status_code=401, detail="Invalid or expired API Key")
# 更新最后使用时间
key_manager.update_last_used(api_key.id)
@@ -445,7 +443,7 @@ async def rate_limit_middleware(request: Request, call_next):
# Master key 有更高的限流
config = RateLimitConfig(requests_per_minute=1000)
limit_key = f"master:{x_api_key[:16]}"
- elif hasattr(request.state, 'api_key') and request.state.api_key:
+ elif hasattr(request.state, "api_key") and request.state.api_key:
# 使用 API Key 的限流配置
api_key = request.state.api_key
config = RateLimitConfig(requests_per_minute=api_key.rate_limit)
@@ -466,14 +464,14 @@ async def rate_limit_middleware(request: Request, call_next):
"error": "Rate limit exceeded",
"retry_after": info.retry_after,
"limit": config.requests_per_minute,
- "window": "minute"
+ "window": "minute",
},
headers={
"X-RateLimit-Limit": str(config.requests_per_minute),
"X-RateLimit-Remaining": "0",
"X-RateLimit-Reset": str(info.reset_time),
- "Retry-After": str(info.retry_after)
- }
+ "Retry-After": str(info.retry_after),
+ },
)
# 继续处理请求
@@ -487,7 +485,7 @@ async def rate_limit_middleware(request: Request, call_next):
# 记录 API 调用日志
try:
- if hasattr(request.state, 'api_key') and request.state.api_key:
+ if hasattr(request.state, "api_key") and request.state.api_key:
api_key = request.state.api_key
response_time = int((time.time() - start_time) * 1000)
key_manager = get_api_key_manager()
@@ -498,7 +496,7 @@ async def rate_limit_middleware(request: Request, call_next):
status_code=response.status_code,
response_time_ms=response_time,
ip_address=request.client.host if request.client else "",
- user_agent=request.headers.get("User-Agent", "")
+ user_agent=request.headers.get("User-Agent", ""),
)
except Exception as e:
# 日志记录失败不应影响主流程
@@ -659,11 +657,13 @@ class GlossaryTermCreate(BaseModel):
# ==================== Phase 7: Workflow Pydantic Models ====================
+
class WorkflowCreate(BaseModel):
name: str = Field(..., description="工作流名称")
description: str = Field(default="", description="工作流描述")
- workflow_type: str = Field(...,
- description="工作流类型: auto_analyze, auto_align, auto_relation, scheduled_report, custom")
+ workflow_type: str = Field(
+ ..., description="工作流类型: auto_analyze, auto_align, auto_relation, scheduled_report, custom"
+ )
project_id: str = Field(..., description="所属项目ID")
schedule: str | None = Field(default=None, description="调度表达式(cron或分钟数)")
schedule_type: str = Field(default="manual", description="调度类型: manual, cron, interval")
@@ -861,6 +861,7 @@ def get_collab_manager():
_collaboration_manager = get_collaboration_manager(db)
return _collaboration_manager
+
# Phase 2: Entity Edit API
@@ -884,7 +885,7 @@ async def update_entity(entity_id: str, update: EntityUpdate, _=Depends(verify_a
"name": updated.name,
"type": updated.type,
"definition": updated.definition,
- "aliases": updated.aliases
+ "aliases": updated.aliases,
}
@@ -926,10 +927,11 @@ async def merge_entities_endpoint(entity_id: str, merge_req: EntityMergeRequest,
"name": result.name,
"type": result.type,
"definition": result.definition,
- "aliases": result.aliases
- }
+ "aliases": result.aliases,
+ },
}
+
# Phase 2: Relation Edit API
@@ -953,7 +955,7 @@ async def create_relation_endpoint(project_id: str, relation: RelationCreate, _=
source_entity_id=relation.source_entity_id,
target_entity_id=relation.target_entity_id,
relation_type=relation.relation_type,
- evidence=relation.evidence
+ evidence=relation.evidence,
)
return {
@@ -961,7 +963,7 @@ async def create_relation_endpoint(project_id: str, relation: RelationCreate, _=
"source_id": relation.source_entity_id,
"target_id": relation.target_entity_id,
"type": relation.relation_type,
- "success": True
+ "success": True,
}
@@ -984,17 +986,11 @@ async def update_relation(relation_id: str, relation: RelationCreate, _=Depends(
db = get_db_manager()
updated = db.update_relation(
- relation_id=relation_id,
- relation_type=relation.relation_type,
- evidence=relation.evidence
+ relation_id=relation_id, relation_type=relation.relation_type, evidence=relation.evidence
)
- return {
- "id": relation_id,
- "type": updated["relation_type"],
- "evidence": updated["evidence"],
- "success": True
- }
+ return {"id": relation_id, "type": updated["relation_type"], "evidence": updated["evidence"], "success": True}
+
# Phase 2: Transcript Edit API
@@ -1031,9 +1027,10 @@ async def update_transcript(transcript_id: str, update: TranscriptUpdate, _=Depe
"id": transcript_id,
"full_text": updated["full_text"],
"updated_at": updated["updated_at"],
- "success": True
+ "success": True,
}
+
# Phase 2: Manual Entity Creation
@@ -1057,21 +1054,12 @@ async def create_manual_entity(project_id: str, entity: ManualEntityCreate, _=De
# 检查是否已存在
existing = db.get_entity_by_name(project_id, entity.name)
if existing:
- return {
- "id": existing.id,
- "name": existing.name,
- "type": existing.type,
- "existed": True
- }
+ return {"id": existing.id, "name": existing.name, "type": existing.type, "existed": True}
entity_id = str(uuid.uuid4())[:8]
- new_entity = db.create_entity(Entity(
- id=entity_id,
- project_id=project_id,
- name=entity.name,
- type=entity.type,
- definition=entity.definition
- ))
+ new_entity = db.create_entity(
+ Entity(id=entity_id, project_id=project_id, name=entity.name, type=entity.type, definition=entity.definition)
+ )
# 如果有提及位置信息,保存提及
if entity.transcript_id and entity.start_pos is not None and entity.end_pos is not None:
@@ -1084,8 +1072,8 @@ async def create_manual_entity(project_id: str, entity: ManualEntityCreate, _=De
transcript_id=entity.transcript_id,
start_pos=entity.start_pos,
end_pos=entity.end_pos,
- text_snippet=text[max(0, entity.start_pos - 20):min(len(text), entity.end_pos + 20)],
- confidence=1.0
+ text_snippet=text[max(0, entity.start_pos - 20) : min(len(text), entity.end_pos + 20)],
+ confidence=1.0,
)
db.add_mention(mention)
@@ -1094,7 +1082,7 @@ async def create_manual_entity(project_id: str, entity: ManualEntityCreate, _=De
"name": new_entity.name,
"type": new_entity.type,
"definition": new_entity.definition,
- "success": True
+ "success": True,
}
@@ -1134,8 +1122,13 @@ def mock_transcribe() -> dict:
return {
"full_text": "我们今天讨论 Project Alpha 的进度,K8s 集群已经部署完成。",
"segments": [
- {"start": 0.0, "end": 5.0, "text": "我们今天讨论 Project Alpha 的进度,K8s 集群已经部署完成。", "speaker": "Speaker A"}
- ]
+ {
+ "start": 0.0,
+ "end": 5.0,
+ "text": "我们今天讨论 Project Alpha 的进度,K8s 集群已经部署完成。",
+ "speaker": "Speaker A",
+ }
+ ],
}
@@ -1174,14 +1167,15 @@ def extract_entities_with_llm(text: str) -> tuple[list[dict], list[dict]]:
f"{KIMI_BASE_URL}/v1/chat/completions",
headers={"Authorization": f"Bearer {KIMI_API_KEY}", "Content-Type": "application/json"},
json={"model": "k2p5", "messages": [{"role": "user", "content": prompt}], "temperature": 0.1},
- timeout=60.0
+ timeout=60.0,
)
response.raise_for_status()
result = response.json()
content = result["choices"][0]["message"]["content"]
import re
- json_match = re.search(r'\{{.*?\}}', content, re.DOTALL)
+
+ json_match = re.search(r"\{.*?\}", content, re.DOTALL)
if json_match:
data = json.loads(json_match.group())
return data.get("entities", []), data.get("relations", [])
@@ -1191,7 +1185,7 @@ def extract_entities_with_llm(text: str) -> tuple[list[dict], list[dict]]:
return [], []
-def align_entity(project_id: str, name: str, db, definition: str = "") -> Optional['Entity']:
+def align_entity(project_id: str, name: str, db, definition: str = "") -> Optional["Entity"]:
"""实体对齐 - Phase 3: 使用 embedding 对齐"""
# 1. 首先尝试精确匹配
existing = db.get_entity_by_name(project_id, name)
@@ -1212,6 +1206,7 @@ def align_entity(project_id: str, name: str, db, definition: str = "") -> Option
return None
+
# API Endpoints
@@ -1262,10 +1257,7 @@ async def upload_audio(project_id: str, file: UploadFile = File(...), _=Depends(
# 保存转录记录
transcript_id = str(uuid.uuid4())[:8]
db.save_transcript(
- transcript_id=transcript_id,
- project_id=project_id,
- filename=file.filename,
- full_text=tw_result["full_text"]
+ transcript_id=transcript_id, project_id=project_id, filename=file.filename, full_text=tw_result["full_text"]
)
# 实体对齐并保存 - Phase 3: 使用增强对齐
@@ -1281,23 +1273,20 @@ async def upload_audio(project_id: str, file: UploadFile = File(...), _=Depends(
name=existing.name,
type=existing.type,
definition=existing.definition,
- aliases=existing.aliases
+ aliases=existing.aliases,
)
entity_name_to_id[raw_ent["name"]] = existing.id
else:
- new_ent = db.create_entity(Entity(
- id=str(uuid.uuid4())[:8],
- project_id=project_id,
- name=raw_ent["name"],
- type=raw_ent.get("type", "OTHER"),
- definition=raw_ent.get("definition", "")
- ))
- ent_model = EntityModel(
- id=new_ent.id,
- name=new_ent.name,
- type=new_ent.type,
- definition=new_ent.definition
+ new_ent = db.create_entity(
+ Entity(
+ id=str(uuid.uuid4())[:8],
+ project_id=project_id,
+ name=raw_ent["name"],
+ type=raw_ent.get("type", "OTHER"),
+ definition=raw_ent.get("definition", ""),
+ )
)
+ ent_model = EntityModel(id=new_ent.id, name=new_ent.name, type=new_ent.type, definition=new_ent.definition)
entity_name_to_id[raw_ent["name"]] = new_ent.id
aligned_entities.append(ent_model)
@@ -1316,8 +1305,8 @@ async def upload_audio(project_id: str, file: UploadFile = File(...), _=Depends(
transcript_id=transcript_id,
start_pos=pos,
end_pos=pos + len(name),
- text_snippet=full_text[max(0, pos - 20):min(len(full_text), pos + len(name) + 20)],
- confidence=1.0
+ text_snippet=full_text[max(0, pos - 20) : min(len(full_text), pos + len(name) + 20)],
+ confidence=1.0,
)
db.add_mention(mention)
start_pos = pos + 1
@@ -1333,7 +1322,7 @@ async def upload_audio(project_id: str, file: UploadFile = File(...), _=Depends(
target_entity_id=target_id,
relation_type=rel.get("type", "related"),
evidence=tw_result["full_text"][:200],
- transcript_id=transcript_id
+ transcript_id=transcript_id,
)
# 构建片段
@@ -1345,9 +1334,10 @@ async def upload_audio(project_id: str, file: UploadFile = File(...), _=Depends(
segments=segments,
entities=aligned_entities,
full_text=tw_result["full_text"],
- created_at=datetime.now().isoformat()
+ created_at=datetime.now().isoformat(),
)
+
# Phase 3: Document Upload API
@@ -1381,7 +1371,7 @@ async def upload_document(project_id: str, file: UploadFile = File(...), _=Depen
project_id=project_id,
filename=file.filename,
full_text=result["text"],
- transcript_type="document"
+ transcript_type="document",
)
# 提取实体和关系
@@ -1396,28 +1386,29 @@ async def upload_document(project_id: str, file: UploadFile = File(...), _=Depen
if existing:
entity_name_to_id[raw_ent["name"]] = existing.id
- aligned_entities.append(EntityModel(
- id=existing.id,
- name=existing.name,
- type=existing.type,
- definition=existing.definition,
- aliases=existing.aliases
- ))
+ aligned_entities.append(
+ EntityModel(
+ id=existing.id,
+ name=existing.name,
+ type=existing.type,
+ definition=existing.definition,
+ aliases=existing.aliases,
+ )
+ )
else:
- new_ent = db.create_entity(Entity(
- id=str(uuid.uuid4())[:8],
- project_id=project_id,
- name=raw_ent["name"],
- type=raw_ent.get("type", "OTHER"),
- definition=raw_ent.get("definition", "")
- ))
+ new_ent = db.create_entity(
+ Entity(
+ id=str(uuid.uuid4())[:8],
+ project_id=project_id,
+ name=raw_ent["name"],
+ type=raw_ent.get("type", "OTHER"),
+ definition=raw_ent.get("definition", ""),
+ )
+ )
entity_name_to_id[raw_ent["name"]] = new_ent.id
- aligned_entities.append(EntityModel(
- id=new_ent.id,
- name=new_ent.name,
- type=new_ent.type,
- definition=new_ent.definition
- ))
+ aligned_entities.append(
+ EntityModel(id=new_ent.id, name=new_ent.name, type=new_ent.type, definition=new_ent.definition)
+ )
# 保存实体提及位置
full_text = result["text"]
@@ -1433,8 +1424,8 @@ async def upload_document(project_id: str, file: UploadFile = File(...), _=Depen
transcript_id=transcript_id,
start_pos=pos,
end_pos=pos + len(name),
- text_snippet=full_text[max(0, pos - 20):min(len(full_text), pos + len(name) + 20)],
- confidence=1.0
+ text_snippet=full_text[max(0, pos - 20) : min(len(full_text), pos + len(name) + 20)],
+ confidence=1.0,
)
db.add_mention(mention)
start_pos = pos + 1
@@ -1450,7 +1441,7 @@ async def upload_document(project_id: str, file: UploadFile = File(...), _=Depen
target_entity_id=target_id,
relation_type=rel.get("type", "related"),
evidence=result["text"][:200],
- transcript_id=transcript_id
+ transcript_id=transcript_id,
)
return {
@@ -1459,9 +1450,10 @@ async def upload_document(project_id: str, file: UploadFile = File(...), _=Depen
"filename": file.filename,
"text_length": len(result["text"]),
"entities": [e.dict() for e in aligned_entities],
- "created_at": datetime.now().isoformat()
+ "created_at": datetime.now().isoformat(),
}
+
# Phase 3: Knowledge Base API
@@ -1495,7 +1487,7 @@ async def get_knowledge_base(project_id: str, _=Depends(verify_api_key)):
mentions = db.get_entity_mentions(ent.id)
entity_stats[ent.id] = {
"mention_count": len(mentions),
- "transcript_ids": list(set([m.transcript_id for m in mentions]))
+ "transcript_ids": list(set([m.transcript_id for m in mentions])),
}
# Phase 5: 获取实体属性
attrs = db.get_entity_attributes(ent.id)
@@ -1505,16 +1497,12 @@ async def get_knowledge_base(project_id: str, _=Depends(verify_api_key)):
entity_map = {e.id: e.name for e in entities}
return {
- "project": {
- "id": project.id,
- "name": project.name,
- "description": project.description
- },
+ "project": {"id": project.id, "name": project.name, "description": project.description},
"stats": {
"entity_count": len(entities),
"relation_count": len(relations),
"transcript_count": len(transcripts),
- "glossary_count": len(glossary)
+ "glossary_count": len(glossary),
},
"entities": [
{
@@ -1525,7 +1513,7 @@ async def get_knowledge_base(project_id: str, _=Depends(verify_api_key)):
"aliases": e.aliases,
"mention_count": entity_stats.get(e.id, {}).get("mention_count", 0),
"appears_in": entity_stats.get(e.id, {}).get("transcript_ids", []),
- "attributes": entity_attributes.get(e.id, []) # Phase 5: 包含属性
+ "attributes": entity_attributes.get(e.id, []), # Phase 5: 包含属性
}
for e in entities
],
@@ -1537,30 +1525,21 @@ async def get_knowledge_base(project_id: str, _=Depends(verify_api_key)):
"target_id": r["target_entity_id"],
"target_name": entity_map.get(r["target_entity_id"], "Unknown"),
"type": r["relation_type"],
- "evidence": r["evidence"]
+ "evidence": r["evidence"],
}
for r in relations
],
"glossary": [
- {
- "id": g["id"],
- "term": g["term"],
- "pronunciation": g["pronunciation"],
- "frequency": g["frequency"]
- }
+ {"id": g["id"], "term": g["term"], "pronunciation": g["pronunciation"], "frequency": g["frequency"]}
for g in glossary
],
"transcripts": [
- {
- "id": t["id"],
- "filename": t["filename"],
- "type": t.get("type", "audio"),
- "created_at": t["created_at"]
- }
+ {"id": t["id"], "filename": t["filename"], "type": t.get("type", "audio"), "created_at": t["created_at"]}
for t in transcripts
- ]
+ ],
}
+
# Phase 3: Glossary API
@@ -1575,18 +1554,9 @@ async def add_glossary_term(project_id: str, term: GlossaryTermCreate, _=Depends
if not project:
raise HTTPException(status_code=404, detail="Project not found")
- term_id = db.add_glossary_term(
- project_id=project_id,
- term=term.term,
- pronunciation=term.pronunciation
- )
+ term_id = db.add_glossary_term(project_id=project_id, term=term.term, pronunciation=term.pronunciation)
- return {
- "id": term_id,
- "term": term.term,
- "pronunciation": term.pronunciation,
- "success": True
- }
+ return {"id": term_id, "term": term.term, "pronunciation": term.pronunciation, "success": True}
@app.get("/api/v1/projects/{project_id}/glossary")
@@ -1610,6 +1580,7 @@ async def delete_glossary_term(term_id: str, _=Depends(verify_api_key)):
db.delete_glossary_term(term_id)
return {"success": True}
+
# Phase 3: Entity Alignment API
@@ -1637,27 +1608,16 @@ async def align_project_entities(project_id: str, threshold: float = 0.85, _=Dep
continue
similar = aligner.find_similar_entity(
- project_id,
- entity.name,
- entity.definition,
- exclude_id=entity.id,
- threshold=threshold
+ project_id, entity.name, entity.definition, exclude_id=entity.id, threshold=threshold
)
if similar:
# 合并实体
db.merge_entities(similar.id, entity.id)
merged_count += 1
- merged_pairs.append({
- "source": entity.name,
- "target": similar.name
- })
+ merged_pairs.append({"source": entity.name, "target": similar.name})
- return {
- "success": True,
- "merged_count": merged_count,
- "merged_pairs": merged_pairs
- }
+ return {"success": True, "merged_count": merged_count, "merged_pairs": merged_pairs}
@app.get("/api/v1/projects/{project_id}/entities")
@@ -1668,8 +1628,9 @@ async def get_project_entities(project_id: str, _=Depends(verify_api_key)):
db = get_db_manager()
entities = db.list_project_entities(project_id)
- return [{"id": e.id, "name": e.name, "type": e.type, "definition": e.definition, "aliases": e.aliases}
- for e in entities]
+ return [
+ {"id": e.id, "name": e.name, "type": e.type, "definition": e.definition, "aliases": e.aliases} for e in entities
+ ]
@app.get("/api/v1/projects/{project_id}/relations")
@@ -1685,15 +1646,18 @@ async def get_project_relations(project_id: str, _=Depends(verify_api_key)):
entities = db.list_project_entities(project_id)
entity_map = {e.id: e.name for e in entities}
- return [{
- "id": r["id"],
- "source_id": r["source_entity_id"],
- "source_name": entity_map.get(r["source_entity_id"], "Unknown"),
- "target_id": r["target_entity_id"],
- "target_name": entity_map.get(r["target_entity_id"], "Unknown"),
- "type": r["relation_type"],
- "evidence": r["evidence"]
- } for r in relations]
+ return [
+ {
+ "id": r["id"],
+ "source_id": r["source_entity_id"],
+ "source_name": entity_map.get(r["source_entity_id"], "Unknown"),
+ "target_id": r["target_entity_id"],
+ "target_name": entity_map.get(r["target_entity_id"], "Unknown"),
+ "type": r["relation_type"],
+ "evidence": r["evidence"],
+ }
+ for r in relations
+ ]
@app.get("/api/v1/projects/{project_id}/transcripts")
@@ -1704,13 +1668,16 @@ async def get_project_transcripts(project_id: str, _=Depends(verify_api_key)):
db = get_db_manager()
transcripts = db.list_project_transcripts(project_id)
- return [{
- "id": t["id"],
- "filename": t["filename"],
- "type": t.get("type", "audio"),
- "created_at": t["created_at"],
- "preview": t["full_text"][:100] + "..." if len(t["full_text"]) > 100 else t["full_text"]
- } for t in transcripts]
+ return [
+ {
+ "id": t["id"],
+ "filename": t["filename"],
+ "type": t.get("type", "audio"),
+ "created_at": t["created_at"],
+ "preview": t["full_text"][:100] + "..." if len(t["full_text"]) > 100 else t["full_text"],
+ }
+ for t in transcripts
+ ]
@app.get("/api/v1/entities/{entity_id}/mentions")
@@ -1721,14 +1688,18 @@ async def get_entity_mentions(entity_id: str, _=Depends(verify_api_key)):
db = get_db_manager()
mentions = db.get_entity_mentions(entity_id)
- return [{
- "id": m.id,
- "transcript_id": m.transcript_id,
- "start_pos": m.start_pos,
- "end_pos": m.end_pos,
- "text_snippet": m.text_snippet,
- "confidence": m.confidence
- } for m in mentions]
+ return [
+ {
+ "id": m.id,
+ "transcript_id": m.transcript_id,
+ "start_pos": m.start_pos,
+ "end_pos": m.end_pos,
+ "text_snippet": m.text_snippet,
+ "confidence": m.confidence,
+ }
+ for m in mentions
+ ]
+
# Health check - Legacy endpoint (deprecated, use /api/v1/health)
@@ -1749,12 +1720,13 @@ async def legacy_health_check():
"multimodal_available": MULTIMODAL_AVAILABLE,
"image_processor_available": IMAGE_PROCESSOR_AVAILABLE,
"multimodal_linker_available": MULTIMODAL_LINKER_AVAILABLE,
- "plugin_manager_available": PLUGIN_MANAGER_AVAILABLE
+ "plugin_manager_available": PLUGIN_MANAGER_AVAILABLE,
}
# ==================== Phase 4: Agent 助手 API ====================
+
@app.post("/api/v1/projects/{project_id}/agent/query")
async def agent_query(project_id: str, query: AgentQuery, _=Depends(verify_api_key)):
"""Agent RAG 问答"""
@@ -1773,20 +1745,21 @@ async def agent_query(project_id: str, query: AgentQuery, _=Depends(verify_api_k
# 构建上下文
context_parts = []
- for t in project_context.get('recent_transcripts', []):
+ for t in project_context.get("recent_transcripts", []):
context_parts.append(f"【{t['filename']}】\n{t['full_text'][:1000]}")
context = "\n\n".join(context_parts)
if query.stream:
- import json
from fastapi.responses import StreamingResponse
async def stream_response():
messages = [
ChatMessage(role="system", content="你是一个专业的项目分析助手,擅长从会议记录中提取洞察。"),
- ChatMessage(role="user", content=f"""基于以下项目信息回答问题:
+ ChatMessage(
+ role="user",
+ content=f"""基于以下项目信息回答问题:
## 项目信息
{json.dumps(project_context, ensure_ascii=False, indent=2)}
@@ -1797,7 +1770,8 @@ async def agent_query(project_id: str, query: AgentQuery, _=Depends(verify_api_k
## 用户问题
{query.query}
-请用中文回答,保持简洁专业。如果信息不足,请明确说明。""")
+请用中文回答,保持简洁专业。如果信息不足,请明确说明。""",
+ ),
]
async for chunk in llm.chat_stream(messages):
@@ -1896,7 +1870,9 @@ async def agent_command(project_id: str, command: AgentCommand, _=Depends(verify
else:
result["action"] = "none"
- result["message"] = "无法理解的指令,请尝试:\n- 合并实体:把所有'客户端'合并到'App'\n- 提问:张总对项目的态度如何?\n- 编辑:修改'K8s'的定义为..."
+ result["message"] = (
+ "无法理解的指令,请尝试:\n- 合并实体:把所有'客户端'合并到'App'\n- 提问:张总对项目的态度如何?\n- 编辑:修改'K8s'的定义为..."
+ )
return result
@@ -1927,8 +1903,7 @@ async def agent_suggest(project_id: str, _=Depends(verify_api_key)):
messages = [ChatMessage(role="user", content=prompt)]
content = await llm.chat(messages, temperature=0.3)
- import re
- json_match = re.search(r'\{{.*?\}}', content, re.DOTALL)
+ json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
if json_match:
try:
data = json.loads(json_match.group())
@@ -1941,6 +1916,7 @@ async def agent_suggest(project_id: str, _=Depends(verify_api_key)):
# ==================== Phase 4: 知识溯源 API ====================
+
@app.get("/api/v1/relations/{relation_id}/provenance")
async def get_relation_provenance(relation_id: str, _=Depends(verify_api_key)):
"""获取关系的知识溯源信息"""
@@ -1959,10 +1935,14 @@ async def get_relation_provenance(relation_id: str, _=Depends(verify_api_key)):
"target": relation.get("target_name"),
"type": relation.get("relation_type"),
"evidence": relation.get("evidence"),
- "transcript": {
- "id": relation.get("transcript_id"),
- "filename": relation.get("transcript_filename"),
- } if relation.get("transcript_id") else None
+ "transcript": (
+ {
+ "id": relation.get("transcript_id"),
+ "filename": relation.get("transcript_filename"),
+ }
+ if relation.get("transcript_id")
+ else None
+ ),
}
@@ -2007,15 +1987,16 @@ async def get_entity_evolution(entity_id: str, _=Depends(verify_api_key)):
"date": m.get("transcript_date"),
"snippet": m.get("text_snippet"),
"transcript_id": m.get("transcript_id"),
- "filename": m.get("filename")
+ "filename": m.get("filename"),
}
for m in entity.get("mentions", [])
- ]
+ ],
}
# ==================== Phase 4: 实体管理增强 API ====================
+
@app.get("/api/v1/projects/{project_id}/entities/search")
async def search_entities(project_id: str, q: str, _=Depends(verify_api_key)):
"""搜索实体"""
@@ -2029,13 +2010,10 @@ async def search_entities(project_id: str, q: str, _=Depends(verify_api_key)):
# ==================== Phase 5: 时间线视图 API ====================
+
@app.get("/api/v1/projects/{project_id}/timeline")
async def get_project_timeline(
- project_id: str,
- entity_id: str = None,
- start_date: str = None,
- end_date: str = None,
- _=Depends(verify_api_key)
+ project_id: str, entity_id: str = None, start_date: str = None, end_date: str = None, _=Depends(verify_api_key)
):
"""获取项目时间线 - 按时间顺序的实体提及和关系事件"""
if not DB_AVAILABLE:
@@ -2048,11 +2026,7 @@ async def get_project_timeline(
timeline = db.get_project_timeline(project_id, entity_id, start_date, end_date)
- return {
- "project_id": project_id,
- "events": timeline,
- "total_count": len(timeline)
- }
+ return {"project_id": project_id, "events": timeline, "total_count": len(timeline)}
@app.get("/api/v1/projects/{project_id}/timeline/summary")
@@ -2068,11 +2042,7 @@ async def get_timeline_summary(project_id: str, _=Depends(verify_api_key)):
summary = db.get_entity_timeline_summary(project_id)
- return {
- "project_id": project_id,
- "project_name": project.name,
- **summary
- }
+ return {"project_id": project_id, "project_name": project.name, **summary}
@app.get("/api/v1/entities/{entity_id}/timeline")
@@ -2093,12 +2063,13 @@ async def get_entity_timeline(entity_id: str, _=Depends(verify_api_key)):
"entity_name": entity.name,
"entity_type": entity.type,
"events": timeline,
- "total_count": len(timeline)
+ "total_count": len(timeline),
}
# ==================== Phase 5: 知识推理与问答增强 API ====================
+
class ReasoningQuery(BaseModel):
query: str
reasoning_depth: str = "medium" # shallow/medium/deep
@@ -2135,15 +2106,12 @@ async def reasoning_query(project_id: str, query: ReasoningQuery, _=Depends(veri
graph_data = {
"entities": [{"id": e.id, "name": e.name, "type": e.type, "definition": e.definition} for e in entities],
- "relations": relations
+ "relations": relations,
}
# 执行增强问答
result = await reasoner.enhanced_qa(
- query=query.query,
- project_context=project_context,
- graph_data=graph_data,
- reasoning_depth=query.reasoning_depth
+ query=query.query, project_context=project_context, graph_data=graph_data, reasoning_depth=query.reasoning_depth
)
return {
@@ -2152,17 +2120,12 @@ async def reasoning_query(project_id: str, query: ReasoningQuery, _=Depends(veri
"confidence": result.confidence,
"evidence": result.evidence,
"knowledge_gaps": result.gaps,
- "project_id": project_id
+ "project_id": project_id,
}
@app.post("/api/v1/projects/{project_id}/reasoning/inference-path")
-async def find_inference_path(
- project_id: str,
- start_entity: str,
- end_entity: str,
- _=Depends(verify_api_key)
-):
+async def find_inference_path(project_id: str, start_entity: str, end_entity: str, _=Depends(verify_api_key)):
"""
发现两个实体之间的推理路径
@@ -2182,10 +2145,7 @@ async def find_inference_path(
entities = db.list_project_entities(project_id)
relations = db.list_project_relations(project_id)
- graph_data = {
- "entities": [{"id": e.id, "name": e.name, "type": e.type} for e in entities],
- "relations": relations
- }
+ graph_data = {"entities": [{"id": e.id, "name": e.name, "type": e.type} for e in entities], "relations": relations}
# 查找推理路径
paths = reasoner.find_inference_paths(start_entity, end_entity, graph_data)
@@ -2197,11 +2157,11 @@ async def find_inference_path(
{
"path": path.path,
"strength": path.strength,
- "path_description": " -> ".join([p["entity"] for p in path.path])
+ "path_description": " -> ".join([p["entity"] for p in path.path]),
}
for path in paths[:5] # 最多返回5条路径
],
- "total_paths": len(paths)
+ "total_paths": len(paths),
}
@@ -2237,28 +2197,19 @@ async def project_summary(project_id: str, req: SummaryRequest, _=Depends(verify
entities = db.list_project_entities(project_id)
relations = db.list_project_relations(project_id)
- graph_data = {
- "entities": [{"id": e.id, "name": e.name, "type": e.type} for e in entities],
- "relations": relations
- }
+ graph_data = {"entities": [{"id": e.id, "name": e.name, "type": e.type} for e in entities], "relations": relations}
# 生成总结
summary = await reasoner.summarize_project(
- project_context=project_context,
- graph_data=graph_data,
- summary_type=req.summary_type
+ project_context=project_context, graph_data=graph_data, summary_type=req.summary_type
)
-    return {
-        "project_id": project_id,
-        "summary_type": req.summary_type,
-        **summary
-    }
+    return {"project_id": project_id, "summary_type": req.summary_type, **summary}
# ==================== Phase 5: 实体属性扩展 API ====================
+
class AttributeTemplateCreate(BaseModel):
name: str
type: str # text, number, date, select, multiselect, boolean
@@ -2296,9 +2247,8 @@ class EntityAttributeBatchSet(BaseModel):
# 属性模板管理 API
@app.post("/api/v1/projects/{project_id}/attribute-templates")
async def create_attribute_template_endpoint(
- project_id: str,
- template: AttributeTemplateCreate,
- _=Depends(verify_api_key)):
+ project_id: str, template: AttributeTemplateCreate, _=Depends(verify_api_key)
+):
"""创建属性模板"""
if not DB_AVAILABLE:
raise HTTPException(status_code=500, detail="Database not available")
@@ -2319,17 +2269,12 @@ async def create_attribute_template_endpoint(
default_value=template.default_value or "",
description=template.description or "",
is_required=template.is_required,
- sort_order=template.sort_order
+ sort_order=template.sort_order,
)
db.create_attribute_template(new_template)
- return {
- "id": new_template.id,
- "name": new_template.name,
- "type": new_template.type,
- "success": True
- }
+ return {"id": new_template.id, "name": new_template.name, "type": new_template.type, "success": True}
@app.get("/api/v1/projects/{project_id}/attribute-templates")
@@ -2350,7 +2295,7 @@ async def list_attribute_templates_endpoint(project_id: str, _=Depends(verify_ap
"default_value": t.default_value,
"description": t.description,
"is_required": t.is_required,
- "sort_order": t.sort_order
+ "sort_order": t.sort_order,
}
for t in templates
]
@@ -2376,15 +2321,14 @@ async def get_attribute_template_endpoint(template_id: str, _=Depends(verify_api
"default_value": template.default_value,
"description": template.description,
"is_required": template.is_required,
- "sort_order": template.sort_order
+ "sort_order": template.sort_order,
}
@app.put("/api/v1/attribute-templates/{template_id}")
async def update_attribute_template_endpoint(
- template_id: str,
- update: AttributeTemplateUpdate,
- _=Depends(verify_api_key)):
+ template_id: str, update: AttributeTemplateUpdate, _=Depends(verify_api_key)
+):
"""更新属性模板"""
if not DB_AVAILABLE:
raise HTTPException(status_code=500, detail="Database not available")
@@ -2397,12 +2341,7 @@ async def update_attribute_template_endpoint(
update_data = {k: v for k, v in update.dict().items() if v is not None}
updated = db.update_attribute_template(template_id, **update_data)
- return {
- "id": updated.id,
- "name": updated.name,
- "type": updated.type,
- "success": True
- }
+ return {"id": updated.id, "name": updated.name, "type": updated.type, "success": True}
@app.delete("/api/v1/attribute-templates/{template_id}")
@@ -2430,13 +2369,13 @@ async def set_entity_attribute_endpoint(entity_id: str, attr: EntityAttributeSet
raise HTTPException(status_code=404, detail="Entity not found")
# 验证类型
- valid_types = ['text', 'number', 'date', 'select', 'multiselect']
+ valid_types = ["text", "number", "date", "select", "multiselect"]
if attr.type not in valid_types:
raise HTTPException(status_code=400, detail=f"Invalid type. Must be one of: {valid_types}")
# 处理 value
value = attr.value
- if attr.type == 'multiselect' and isinstance(value, list):
+ if attr.type == "multiselect" and isinstance(value, list):
value = json.dumps(value)
elif value is not None:
value = str(value)
@@ -2449,8 +2388,7 @@ async def set_entity_attribute_endpoint(entity_id: str, attr: EntityAttributeSet
# 检查是否已存在
conn = db.get_conn()
existing = conn.execute(
- "SELECT * FROM entity_attributes WHERE entity_id = ? AND name = ?",
- (entity_id, attr.name)
+ "SELECT * FROM entity_attributes WHERE entity_id = ? AND name = ?", (entity_id, attr.name)
).fetchone()
now = datetime.now().isoformat()
@@ -2461,8 +2399,16 @@ async def set_entity_attribute_endpoint(entity_id: str, attr: EntityAttributeSet
"""INSERT INTO attribute_history
(id, entity_id, attribute_name, old_value, new_value, changed_by, changed_at, change_reason)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
- (str(uuid.uuid4())[:8], entity_id, attr.name, existing['value'], value,
- "user", now, attr.change_reason or "")
+ (
+ str(uuid.uuid4())[:8],
+ entity_id,
+ attr.name,
+ existing["value"],
+ value,
+ "user",
+ now,
+ attr.change_reason or "",
+ ),
)
# 更新
@@ -2470,9 +2416,9 @@ async def set_entity_attribute_endpoint(entity_id: str, attr: EntityAttributeSet
"""UPDATE entity_attributes
SET value = ?, type = ?, options = ?, updated_at = ?
WHERE id = ?""",
- (value, attr.type, options, now, existing['id'])
+ (value, attr.type, options, now, existing["id"]),
)
- attr_id = existing['id']
+ attr_id = existing["id"]
else:
# 创建
attr_id = str(uuid.uuid4())[:8]
@@ -2480,7 +2426,7 @@ async def set_entity_attribute_endpoint(entity_id: str, attr: EntityAttributeSet
"""INSERT INTO entity_attributes
(id, entity_id, template_id, name, type, value, options, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)""",
- (attr_id, entity_id, attr.template_id, attr.name, attr.type, value, options, now, now)
+ (attr_id, entity_id, attr.template_id, attr.name, attr.type, value, options, now, now),
)
# 记录历史
@@ -2488,8 +2434,7 @@ async def set_entity_attribute_endpoint(entity_id: str, attr: EntityAttributeSet
"""INSERT INTO attribute_history
(id, entity_id, attribute_name, old_value, new_value, changed_by, changed_at, change_reason)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
- (str(uuid.uuid4())[:8], entity_id, attr.name, None, value,
- "user", now, attr.change_reason or "创建属性")
+ (str(uuid.uuid4())[:8], entity_id, attr.name, None, value, "user", now, attr.change_reason or "创建属性"),
)
conn.commit()
@@ -2501,15 +2446,14 @@ async def set_entity_attribute_endpoint(entity_id: str, attr: EntityAttributeSet
"name": attr.name,
"type": attr.type,
"value": attr.value,
- "success": True
+ "success": True,
}
@app.post("/api/v1/entities/{entity_id}/attributes/batch")
async def batch_set_entity_attributes_endpoint(
- entity_id: str,
- batch: EntityAttributeBatchSet,
- _=Depends(verify_api_key)):
+ entity_id: str, batch: EntityAttributeBatchSet, _=Depends(verify_api_key)
+):
"""批量设置实体属性值"""
if not DB_AVAILABLE:
raise HTTPException(status_code=500, detail="Database not available")
@@ -2526,25 +2470,14 @@ async def batch_set_entity_attributes_endpoint(
template = db.get_attribute_template(attr_data.template_id)
if template:
new_attr = EntityAttribute(
- id=str(uuid.uuid4())[:8],
- entity_id=entity_id,
- template_id=attr_data.template_id,
- value=attr_data.value
+ id=str(uuid.uuid4())[:8], entity_id=entity_id, template_id=attr_data.template_id, value=attr_data.value
+ )
+ db.set_entity_attribute(new_attr, changed_by="user", change_reason=batch.change_reason or "批量更新")
+ results.append(
+ {"template_id": attr_data.template_id, "template_name": template.name, "value": attr_data.value}
)
- db.set_entity_attribute(new_attr, changed_by="user",
- change_reason=batch.change_reason or "批量更新")
- results.append({
- "template_id": attr_data.template_id,
- "template_name": template.name,
- "value": attr_data.value
- })
- return {
- "entity_id": entity_id,
- "updated_count": len(results),
- "attributes": results,
- "success": True
- }
+ return {"entity_id": entity_id, "updated_count": len(results), "attributes": results, "success": True}
@app.get("/api/v1/entities/{entity_id}/attributes")
@@ -2566,22 +2499,22 @@ async def get_entity_attributes_endpoint(entity_id: str, _=Depends(verify_api_ke
"template_id": a.template_id,
"template_name": a.template_name,
"template_type": a.template_type,
- "value": a.value
+ "value": a.value,
}
for a in attrs
]
@app.delete("/api/v1/entities/{entity_id}/attributes/{template_id}")
-async def delete_entity_attribute_endpoint(entity_id: str, template_id: str,
- reason: str | None = "", _=Depends(verify_api_key)):
+async def delete_entity_attribute_endpoint(
+ entity_id: str, template_id: str, reason: str | None = "", _=Depends(verify_api_key)
+):
"""删除实体属性值"""
if not DB_AVAILABLE:
raise HTTPException(status_code=500, detail="Database not available")
db = get_db_manager()
- db.delete_entity_attribute(entity_id, template_id,
- changed_by="user", change_reason=reason)
+ db.delete_entity_attribute(entity_id, template_id, changed_by="user", change_reason=reason)
return {"success": True, "message": "Attribute deleted"}
@@ -2604,7 +2537,7 @@ async def get_entity_attribute_history_endpoint(entity_id: str, limit: int = 50,
"new_value": h.new_value,
"changed_by": h.changed_by,
"changed_at": h.changed_at,
- "change_reason": h.change_reason
+ "change_reason": h.change_reason,
}
for h in history
]
@@ -2628,7 +2561,7 @@ async def get_template_history_endpoint(template_id: str, limit: int = 50, _=Dep
"new_value": h.new_value,
"changed_by": h.changed_by,
"changed_at": h.changed_at,
- "change_reason": h.change_reason
+ "change_reason": h.change_reason,
}
for h in history
]
@@ -2639,7 +2572,7 @@ async def get_template_history_endpoint(template_id: str, limit: int = 50, _=Dep
async def search_entities_by_attributes_endpoint(
project_id: str,
attribute_filter: str | None = None, # JSON 格式: {"职位": "经理", "部门": "技术部"}
- _=Depends(verify_api_key)
+ _=Depends(verify_api_key),
):
"""根据属性筛选搜索实体"""
if not DB_AVAILABLE:
@@ -2660,13 +2593,7 @@ async def search_entities_by_attributes_endpoint(
entities = db.search_entities_by_attributes(project_id, filters)
return [
- {
- "id": e.id,
- "name": e.name,
- "type": e.type,
- "definition": e.definition,
- "attributes": e.attributes
- }
+ {"id": e.id, "name": e.name, "type": e.type, "definition": e.definition, "attributes": e.attributes}
for e in entities
]
@@ -2693,34 +2620,38 @@ async def export_graph_svg_endpoint(project_id: str, _=Depends(verify_api_key)):
entities = []
for e in entities_data:
attrs = db.get_entity_attributes(e.id)
- entities.append(ExportEntity(
- id=e.id,
- name=e.name,
- type=e.type,
- definition=e.definition or "",
- aliases=json.loads(e.aliases) if e.aliases else [],
- mention_count=e.mention_count,
- attributes={a.template_name: a.value for a in attrs}
- ))
+ entities.append(
+ ExportEntity(
+ id=e.id,
+ name=e.name,
+ type=e.type,
+ definition=e.definition or "",
+ aliases=json.loads(e.aliases) if e.aliases else [],
+ mention_count=e.mention_count,
+ attributes={a.template_name: a.value for a in attrs},
+ )
+ )
relations = []
for r in relations_data:
- relations.append(ExportRelation(
- id=r.id,
- source=r.source_name,
- target=r.target_name,
- relation_type=r.relation_type,
- confidence=r.confidence,
- evidence=r.evidence or ""
- ))
+ relations.append(
+ ExportRelation(
+ id=r.id,
+ source=r.source_name,
+ target=r.target_name,
+ relation_type=r.relation_type,
+ confidence=r.confidence,
+ evidence=r.evidence or "",
+ )
+ )
export_mgr = get_export_manager()
svg_content = export_mgr.export_knowledge_graph_svg(project_id, entities, relations)
return StreamingResponse(
- io.BytesIO(svg_content.encode('utf-8')),
+ io.BytesIO(svg_content.encode("utf-8")),
media_type="image/svg+xml",
- headers={"Content-Disposition": f"attachment; filename=insightflow-graph-{project_id}.svg"}
+ headers={"Content-Disposition": f"attachment; filename=insightflow-graph-{project_id}.svg"},
)
@@ -2743,26 +2674,30 @@ async def export_graph_png_endpoint(project_id: str, _=Depends(verify_api_key)):
entities = []
for e in entities_data:
attrs = db.get_entity_attributes(e.id)
- entities.append(ExportEntity(
- id=e.id,
- name=e.name,
- type=e.type,
- definition=e.definition or "",
- aliases=json.loads(e.aliases) if e.aliases else [],
- mention_count=e.mention_count,
- attributes={a.template_name: a.value for a in attrs}
- ))
+ entities.append(
+ ExportEntity(
+ id=e.id,
+ name=e.name,
+ type=e.type,
+ definition=e.definition or "",
+ aliases=json.loads(e.aliases) if e.aliases else [],
+ mention_count=e.mention_count,
+ attributes={a.template_name: a.value for a in attrs},
+ )
+ )
relations = []
for r in relations_data:
- relations.append(ExportRelation(
- id=r.id,
- source=r.source_name,
- target=r.target_name,
- relation_type=r.relation_type,
- confidence=r.confidence,
- evidence=r.evidence or ""
- ))
+ relations.append(
+ ExportRelation(
+ id=r.id,
+ source=r.source_name,
+ target=r.target_name,
+ relation_type=r.relation_type,
+ confidence=r.confidence,
+ evidence=r.evidence or "",
+ )
+ )
export_mgr = get_export_manager()
png_bytes = export_mgr.export_knowledge_graph_png(project_id, entities, relations)
@@ -2770,7 +2705,7 @@ async def export_graph_png_endpoint(project_id: str, _=Depends(verify_api_key)):
return StreamingResponse(
io.BytesIO(png_bytes),
media_type="image/png",
- headers={"Content-Disposition": f"attachment; filename=insightflow-graph-{project_id}.png"}
+ headers={"Content-Disposition": f"attachment; filename=insightflow-graph-{project_id}.png"},
)
@@ -2791,15 +2726,17 @@ async def export_entities_excel_endpoint(project_id: str, _=Depends(verify_api_k
entities = []
for e in entities_data:
attrs = db.get_entity_attributes(e.id)
- entities.append(ExportEntity(
- id=e.id,
- name=e.name,
- type=e.type,
- definition=e.definition or "",
- aliases=json.loads(e.aliases) if e.aliases else [],
- mention_count=e.mention_count,
- attributes={a.template_name: a.value for a in attrs}
- ))
+ entities.append(
+ ExportEntity(
+ id=e.id,
+ name=e.name,
+ type=e.type,
+ definition=e.definition or "",
+ aliases=json.loads(e.aliases) if e.aliases else [],
+ mention_count=e.mention_count,
+ attributes={a.template_name: a.value for a in attrs},
+ )
+ )
export_mgr = get_export_manager()
excel_bytes = export_mgr.export_entities_excel(entities)
@@ -2807,7 +2744,7 @@ async def export_entities_excel_endpoint(project_id: str, _=Depends(verify_api_k
return StreamingResponse(
io.BytesIO(excel_bytes),
media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
- headers={"Content-Disposition": f"attachment; filename=insightflow-entities-{project_id}.xlsx"}
+ headers={"Content-Disposition": f"attachment; filename=insightflow-entities-{project_id}.xlsx"},
)
@@ -2828,23 +2765,25 @@ async def export_entities_csv_endpoint(project_id: str, _=Depends(verify_api_key
entities = []
for e in entities_data:
attrs = db.get_entity_attributes(e.id)
- entities.append(ExportEntity(
- id=e.id,
- name=e.name,
- type=e.type,
- definition=e.definition or "",
- aliases=json.loads(e.aliases) if e.aliases else [],
- mention_count=e.mention_count,
- attributes={a.template_name: a.value for a in attrs}
- ))
+ entities.append(
+ ExportEntity(
+ id=e.id,
+ name=e.name,
+ type=e.type,
+ definition=e.definition or "",
+ aliases=json.loads(e.aliases) if e.aliases else [],
+ mention_count=e.mention_count,
+ attributes={a.template_name: a.value for a in attrs},
+ )
+ )
export_mgr = get_export_manager()
csv_content = export_mgr.export_entities_csv(entities)
return StreamingResponse(
- io.BytesIO(csv_content.encode('utf-8')),
+ io.BytesIO(csv_content.encode("utf-8")),
media_type="text/csv",
- headers={"Content-Disposition": f"attachment; filename=insightflow-entities-{project_id}.csv"}
+ headers={"Content-Disposition": f"attachment; filename=insightflow-entities-{project_id}.csv"},
)
@@ -2864,22 +2803,24 @@ async def export_relations_csv_endpoint(project_id: str, _=Depends(verify_api_ke
relations = []
for r in relations_data:
- relations.append(ExportRelation(
- id=r.id,
- source=r.source_name,
- target=r.target_name,
- relation_type=r.relation_type,
- confidence=r.confidence,
- evidence=r.evidence or ""
- ))
+ relations.append(
+ ExportRelation(
+ id=r.id,
+ source=r.source_name,
+ target=r.target_name,
+ relation_type=r.relation_type,
+ confidence=r.confidence,
+ evidence=r.evidence or "",
+ )
+ )
export_mgr = get_export_manager()
csv_content = export_mgr.export_relations_csv(relations)
return StreamingResponse(
- io.BytesIO(csv_content.encode('utf-8')),
+ io.BytesIO(csv_content.encode("utf-8")),
media_type="text/csv",
- headers={"Content-Disposition": f"attachment; filename=insightflow-relations-{project_id}.csv"}
+ headers={"Content-Disposition": f"attachment; filename=insightflow-relations-{project_id}.csv"},
)
@@ -2903,38 +2844,39 @@ async def export_report_pdf_endpoint(project_id: str, _=Depends(verify_api_key))
entities = []
for e in entities_data:
attrs = db.get_entity_attributes(e.id)
- entities.append(ExportEntity(
- id=e.id,
- name=e.name,
- type=e.type,
- definition=e.definition or "",
- aliases=json.loads(e.aliases) if e.aliases else [],
- mention_count=e.mention_count,
- attributes={a.template_name: a.value for a in attrs}
- ))
+ entities.append(
+ ExportEntity(
+ id=e.id,
+ name=e.name,
+ type=e.type,
+ definition=e.definition or "",
+ aliases=json.loads(e.aliases) if e.aliases else [],
+ mention_count=e.mention_count,
+ attributes={a.template_name: a.value for a in attrs},
+ )
+ )
relations = []
for r in relations_data:
- relations.append(ExportRelation(
- id=r.id,
- source=r.source_name,
- target=r.target_name,
- relation_type=r.relation_type,
- confidence=r.confidence,
- evidence=r.evidence or ""
- ))
+ relations.append(
+ ExportRelation(
+ id=r.id,
+ source=r.source_name,
+ target=r.target_name,
+ relation_type=r.relation_type,
+ confidence=r.confidence,
+ evidence=r.evidence or "",
+ )
+ )
transcripts = []
for t in transcripts_data:
segments = json.loads(t.segments) if t.segments else []
- transcripts.append(ExportTranscript(
- id=t.id,
- name=t.name,
- type=t.type,
- content=t.full_text or "",
- segments=segments,
- entity_mentions=[]
- ))
+ transcripts.append(
+ ExportTranscript(
+ id=t.id, name=t.name, type=t.type, content=t.full_text or "", segments=segments, entity_mentions=[]
+ )
+ )
# 获取项目总结
summary = ""
@@ -2954,7 +2896,7 @@ async def export_report_pdf_endpoint(project_id: str, _=Depends(verify_api_key))
return StreamingResponse(
io.BytesIO(pdf_bytes),
media_type="application/pdf",
- headers={"Content-Disposition": f"attachment; filename=insightflow-report-{project_id}.pdf"}
+ headers={"Content-Disposition": f"attachment; filename=insightflow-report-{project_id}.pdf"},
)
@@ -2978,48 +2920,47 @@ async def export_project_json_endpoint(project_id: str, _=Depends(verify_api_key
entities = []
for e in entities_data:
attrs = db.get_entity_attributes(e.id)
- entities.append(ExportEntity(
- id=e.id,
- name=e.name,
- type=e.type,
- definition=e.definition or "",
- aliases=json.loads(e.aliases) if e.aliases else [],
- mention_count=e.mention_count,
- attributes={a.template_name: a.value for a in attrs}
- ))
+ entities.append(
+ ExportEntity(
+ id=e.id,
+ name=e.name,
+ type=e.type,
+ definition=e.definition or "",
+ aliases=json.loads(e.aliases) if e.aliases else [],
+ mention_count=e.mention_count,
+ attributes={a.template_name: a.value for a in attrs},
+ )
+ )
relations = []
for r in relations_data:
- relations.append(ExportRelation(
- id=r.id,
- source=r.source_name,
- target=r.target_name,
- relation_type=r.relation_type,
- confidence=r.confidence,
- evidence=r.evidence or ""
- ))
+ relations.append(
+ ExportRelation(
+ id=r.id,
+ source=r.source_name,
+ target=r.target_name,
+ relation_type=r.relation_type,
+ confidence=r.confidence,
+ evidence=r.evidence or "",
+ )
+ )
transcripts = []
for t in transcripts_data:
segments = json.loads(t.segments) if t.segments else []
- transcripts.append(ExportTranscript(
- id=t.id,
- name=t.name,
- type=t.type,
- content=t.full_text or "",
- segments=segments,
- entity_mentions=[]
- ))
+ transcripts.append(
+ ExportTranscript(
+ id=t.id, name=t.name, type=t.type, content=t.full_text or "", segments=segments, entity_mentions=[]
+ )
+ )
export_mgr = get_export_manager()
- json_content = export_mgr.export_project_json(
- project_id, project.name, entities, relations, transcripts
- )
+ json_content = export_mgr.export_project_json(project_id, project.name, entities, relations, transcripts)
return StreamingResponse(
- io.BytesIO(json_content.encode('utf-8')),
+ io.BytesIO(json_content.encode("utf-8")),
media_type="application/json",
- headers={"Content-Disposition": f"attachment; filename=insightflow-project-{project_id}.json"}
+ headers={"Content-Disposition": f"attachment; filename=insightflow-project-{project_id}.json"},
)
@@ -3039,15 +2980,18 @@ async def export_transcript_markdown_endpoint(transcript_id: str, _=Depends(veri
# 获取项目实体用于映射
entities_data = db.get_project_entities(transcript.project_id)
- entities_map = {e.id: ExportEntity(
- id=e.id,
- name=e.name,
- type=e.type,
- definition=e.definition or "",
- aliases=json.loads(e.aliases) if e.aliases else [],
- mention_count=e.mention_count,
- attributes={}
- ) for e in entities_data}
+ entities_map = {
+ e.id: ExportEntity(
+ id=e.id,
+ name=e.name,
+ type=e.type,
+ definition=e.definition or "",
+ aliases=json.loads(e.aliases) if e.aliases else [],
+ mention_count=e.mention_count,
+ attributes={},
+ )
+ for e in entities_data
+ }
segments = json.loads(transcript.segments) if transcript.segments else []
@@ -3057,26 +3001,25 @@ async def export_transcript_markdown_endpoint(transcript_id: str, _=Depends(veri
type=transcript.type,
content=transcript.full_text or "",
segments=segments,
- entity_mentions=[{
- "entity_id": m.entity_id,
- "entity_name": m.entity_name,
- "position": m.position,
- "context": m.context
- } for m in mentions]
+ entity_mentions=[
+ {"entity_id": m.entity_id, "entity_name": m.entity_name, "position": m.position, "context": m.context}
+ for m in mentions
+ ],
)
export_mgr = get_export_manager()
markdown_content = export_mgr.export_transcript_markdown(export_transcript, entities_map)
return StreamingResponse(
- io.BytesIO(markdown_content.encode('utf-8')),
+ io.BytesIO(markdown_content.encode("utf-8")),
media_type="text/markdown",
- headers={"Content-Disposition": f"attachment; filename=insightflow-transcript-{transcript_id}.md"}
+ headers={"Content-Disposition": f"attachment; filename=insightflow-transcript-{transcript_id}.md"},
)
# ==================== Neo4j Graph Database API ====================
+
class Neo4jSyncRequest(BaseModel):
project_id: str
@@ -3096,11 +3039,7 @@ class GraphQueryRequest(BaseModel):
async def neo4j_status(_=Depends(verify_api_key)):
"""获取 Neo4j 连接状态"""
if not NEO4J_AVAILABLE:
- return {
- "available": False,
- "connected": False,
- "message": "Neo4j driver not installed"
- }
+ return {"available": False, "connected": False, "message": "Neo4j driver not installed"}
try:
manager = get_neo4j_manager()
@@ -3109,14 +3048,10 @@ async def neo4j_status(_=Depends(verify_api_key)):
"available": True,
"connected": connected,
"uri": manager.uri if connected else None,
- "message": "Connected" if connected else "Not connected"
+ "message": "Connected" if connected else "Not connected",
}
except Exception as e:
- return {
- "available": True,
- "connected": False,
- "message": str(e)
- }
+ return {"available": True, "connected": False, "message": str(e)}
@app.post("/api/v1/neo4j/sync")
@@ -3141,34 +3076,35 @@ async def neo4j_sync_project(request: Neo4jSyncRequest, _=Depends(verify_api_key
entities = db.get_project_entities(request.project_id)
entities_data = []
for e in entities:
- entities_data.append({
- "id": e.id,
- "name": e.name,
- "type": e.type,
- "definition": e.definition,
- "aliases": json.loads(e.aliases) if e.aliases else [],
- "properties": e.attributes if hasattr(e, 'attributes') else {}
- })
+ entities_data.append(
+ {
+ "id": e.id,
+ "name": e.name,
+ "type": e.type,
+ "definition": e.definition,
+ "aliases": json.loads(e.aliases) if e.aliases else [],
+ "properties": e.attributes if hasattr(e, "attributes") else {},
+ }
+ )
# 获取项目所有关系
relations = db.get_project_relations(request.project_id)
relations_data = []
for r in relations:
- relations_data.append({
- "id": r.id,
- "source_entity_id": r.source_entity_id,
- "target_entity_id": r.target_entity_id,
- "relation_type": r.relation_type,
- "evidence": r.evidence,
- "properties": {}
- })
+ relations_data.append(
+ {
+ "id": r.id,
+ "source_entity_id": r.source_entity_id,
+ "target_entity_id": r.target_entity_id,
+ "relation_type": r.relation_type,
+ "evidence": r.evidence,
+ "properties": {},
+ }
+ )
# 同步到 Neo4j
sync_project_to_neo4j(
- project_id=request.project_id,
- project_name=project.name,
- entities=entities_data,
- relations=relations_data
+ project_id=request.project_id, project_name=project.name, entities=entities_data, relations=relations_data
)
return {
@@ -3176,7 +3112,7 @@ async def neo4j_sync_project(request: Neo4jSyncRequest, _=Depends(verify_api_key
"project_id": request.project_id,
"entities_synced": len(entities_data),
"relations_synced": len(relations_data),
- "message": f"Synced {len(entities_data)} entities and {len(relations_data)} relations to Neo4j"
+ "message": f"Synced {len(entities_data)} entities and {len(relations_data)} relations to Neo4j",
}
@@ -3204,26 +3140,12 @@ async def find_shortest_path(request: PathQueryRequest, _=Depends(verify_api_key
if not manager.is_connected():
raise HTTPException(status_code=503, detail="Neo4j not connected")
- path = manager.find_shortest_path(
- request.source_entity_id,
- request.target_entity_id,
- request.max_depth
- )
+ path = manager.find_shortest_path(request.source_entity_id, request.target_entity_id, request.max_depth)
if not path:
- return {
- "found": False,
- "message": "No path found between entities"
- }
+ return {"found": False, "message": "No path found between entities"}
- return {
- "found": True,
- "path": {
- "nodes": path.nodes,
- "relationships": path.relationships,
- "length": path.length
- }
- }
+ return {"found": True, "path": {"nodes": path.nodes, "relationships": path.relationships, "length": path.length}}
@app.post("/api/v1/graph/paths")
@@ -3236,32 +3158,16 @@ async def find_all_paths(request: PathQueryRequest, _=Depends(verify_api_key)):
if not manager.is_connected():
raise HTTPException(status_code=503, detail="Neo4j not connected")
- paths = manager.find_all_paths(
- request.source_entity_id,
- request.target_entity_id,
- request.max_depth
- )
+ paths = manager.find_all_paths(request.source_entity_id, request.target_entity_id, request.max_depth)
return {
"count": len(paths),
- "paths": [
- {
- "nodes": p.nodes,
- "relationships": p.relationships,
- "length": p.length
- }
- for p in paths
- ]
+ "paths": [{"nodes": p.nodes, "relationships": p.relationships, "length": p.length} for p in paths],
}
@app.get("/api/v1/entities/{entity_id}/neighbors")
-async def get_entity_neighbors(
- entity_id: str,
- relation_type: str = None,
- limit: int = 50,
- _=Depends(verify_api_key)
-):
+async def get_entity_neighbors(entity_id: str, relation_type: str = None, limit: int = 50, _=Depends(verify_api_key)):
"""获取实体的邻居节点"""
if not NEO4J_AVAILABLE:
raise HTTPException(status_code=503, detail="Neo4j not available")
@@ -3271,11 +3177,7 @@ async def get_entity_neighbors(
raise HTTPException(status_code=503, detail="Neo4j not connected")
neighbors = manager.find_neighbors(entity_id, relation_type, limit)
- return {
- "entity_id": entity_id,
- "count": len(neighbors),
- "neighbors": neighbors
- }
+ return {"entity_id": entity_id, "count": len(neighbors), "neighbors": neighbors}
@app.get("/api/v1/entities/{entity_id1}/common-neighbors/{entity_id2}")
@@ -3289,20 +3191,11 @@ async def get_common_neighbors(entity_id1: str, entity_id2: str, _=Depends(verif
raise HTTPException(status_code=503, detail="Neo4j not connected")
common = manager.find_common_neighbors(entity_id1, entity_id2)
- return {
- "entity_id1": entity_id1,
- "entity_id2": entity_id2,
- "count": len(common),
- "common_neighbors": common
- }
+ return {"entity_id1": entity_id1, "entity_id2": entity_id2, "count": len(common), "common_neighbors": common}
@app.get("/api/v1/projects/{project_id}/graph/centrality")
-async def get_centrality_analysis(
- project_id: str,
- metric: str = "degree",
- _=Depends(verify_api_key)
-):
+async def get_centrality_analysis(project_id: str, metric: str = "degree", _=Depends(verify_api_key)):
"""获取中心性分析结果"""
if not NEO4J_AVAILABLE:
raise HTTPException(status_code=503, detail="Neo4j not available")
@@ -3316,14 +3209,8 @@ async def get_centrality_analysis(
"metric": metric,
"count": len(rankings),
"rankings": [
- {
- "entity_id": r.entity_id,
- "entity_name": r.entity_name,
- "score": r.score,
- "rank": r.rank
- }
- for r in rankings
- ]
+ {"entity_id": r.entity_id, "entity_name": r.entity_name, "score": r.score, "rank": r.rank} for r in rankings
+ ],
}
@@ -3341,14 +3228,9 @@ async def get_communities(project_id: str, _=Depends(verify_api_key)):
return {
"count": len(communities),
"communities": [
- {
- "community_id": c.community_id,
- "size": c.size,
- "density": c.density,
- "nodes": c.nodes
- }
+ {"community_id": c.community_id, "size": c.size, "density": c.density, "nodes": c.nodes}
for c in communities
- ]
+ ],
}
@@ -3368,6 +3250,7 @@ async def get_subgraph(request: GraphQueryRequest, _=Depends(verify_api_key)):
# ==================== Phase 6: API Key Management Endpoints ====================
+
@app.post("/api/v1/api-keys", response_model=ApiKeyCreateResponse, tags=["API Keys"])
async def create_api_key(request: ApiKeyCreate, _=Depends(verify_api_key)):
"""
@@ -3386,7 +3269,7 @@ async def create_api_key(request: ApiKeyCreate, _=Depends(verify_api_key)):
name=request.name,
permissions=request.permissions,
rate_limit=request.rate_limit,
- expires_days=request.expires_days
+ expires_days=request.expires_days,
)
return ApiKeyCreateResponse(
@@ -3401,18 +3284,13 @@ async def create_api_key(request: ApiKeyCreate, _=Depends(verify_api_key)):
created_at=api_key.created_at,
expires_at=api_key.expires_at,
last_used_at=api_key.last_used_at,
- total_calls=api_key.total_calls
- )
+ total_calls=api_key.total_calls,
+ ),
)
@app.get("/api/v1/api-keys", response_model=ApiKeyListResponse, tags=["API Keys"])
-async def list_api_keys(
- status: str | None = None,
- limit: int = 100,
- offset: int = 0,
- _=Depends(verify_api_key)
-):
+async def list_api_keys(status: str | None = None, limit: int = 100, offset: int = 0, _=Depends(verify_api_key)):
"""
列出所有 API Keys
@@ -3438,11 +3316,11 @@ async def list_api_keys(
created_at=k.created_at,
expires_at=k.expires_at,
last_used_at=k.last_used_at,
- total_calls=k.total_calls
+ total_calls=k.total_calls,
)
for k in keys
],
- total=len(keys)
+ total=len(keys),
)
@@ -3468,7 +3346,7 @@ async def get_api_key(key_id: str, _=Depends(verify_api_key)):
created_at=key.created_at,
expires_at=key.expires_at,
last_used_at=key.last_used_at,
- total_calls=key.total_calls
+ total_calls=key.total_calls,
)
@@ -3513,7 +3391,7 @@ async def update_api_key(key_id: str, request: ApiKeyUpdate, _=Depends(verify_ap
created_at=key.created_at,
expires_at=key.expires_at,
last_used_at=key.last_used_at,
- total_calls=key.total_calls
+ total_calls=key.total_calls,
)
@@ -3556,19 +3434,12 @@ async def get_api_key_stats(key_id: str, days: int = 30, _=Depends(verify_api_ke
stats = key_manager.get_call_stats(key_id, days=days)
return ApiStatsResponse(
- summary=ApiCallStats(**stats["summary"]),
- endpoints=stats["endpoints"],
- daily=stats["daily"]
+ summary=ApiCallStats(**stats["summary"]), endpoints=stats["endpoints"], daily=stats["daily"]
)
@app.get("/api/v1/api-keys/{key_id}/logs", response_model=ApiLogsResponse, tags=["API Keys"])
-async def get_api_key_logs(
- key_id: str,
- limit: int = 100,
- offset: int = 0,
- _=Depends(verify_api_key)
-):
+async def get_api_key_logs(key_id: str, limit: int = 100, offset: int = 0, _=Depends(verify_api_key)):
"""
获取 API Key 的调用日志
@@ -3598,11 +3469,11 @@ async def get_api_key_logs(
ip_address=log["ip_address"],
user_agent=log["user_agent"],
error_message=log["error_message"],
- created_at=log["created_at"]
+ created_at=log["created_at"],
)
for log in logs
],
- total=len(logs)
+ total=len(logs),
)
@@ -3610,17 +3481,12 @@ async def get_api_key_logs(
async def get_rate_limit_status(request: Request, _=Depends(verify_api_key)):
"""获取当前请求的限流状态"""
if not RATE_LIMITER_AVAILABLE:
- return RateLimitStatus(
- limit=60,
- remaining=60,
- reset_time=int(time.time()) + 60,
- window="minute"
- )
+ return RateLimitStatus(limit=60, remaining=60, reset_time=int(time.time()) + 60, window="minute")
limiter = get_rate_limiter()
# 获取限流键
- if hasattr(request.state, 'api_key') and request.state.api_key:
+ if hasattr(request.state, "api_key") and request.state.api_key:
api_key = request.state.api_key
limit_key = f"api_key:{api_key.id}"
limit = api_key.rate_limit
@@ -3631,24 +3497,16 @@ async def get_rate_limit_status(request: Request, _=Depends(verify_api_key)):
info = await limiter.get_limit_info(limit_key)
- return RateLimitStatus(
- limit=limit,
- remaining=info.remaining,
- reset_time=info.reset_time,
- window="minute"
- )
+ return RateLimitStatus(limit=limit, remaining=info.remaining, reset_time=info.reset_time, window="minute")
# ==================== Phase 6: System Endpoints ====================
+
@app.get("/api/v1/health", tags=["System"])
async def api_health_check():
"""健康检查端点"""
- return {
- "status": "healthy",
- "version": "0.7.0",
- "timestamp": datetime.now().isoformat()
- }
+ return {"status": "healthy", "version": "0.7.0", "timestamp": datetime.now().isoformat()}
@app.get("/api/v1/status", tags=["System"])
@@ -3675,7 +3533,7 @@ async def system_status():
"documentation": "/docs",
"openapi": "/openapi.json",
},
- "timestamp": datetime.now().isoformat()
+ "timestamp": datetime.now().isoformat(),
}
return status
@@ -3691,6 +3549,7 @@ def get_workflow_manager_instance():
global _workflow_manager
if _workflow_manager is None and WORKFLOW_AVAILABLE and DB_AVAILABLE:
from workflow_manager import WorkflowManager
+
db = get_db_manager()
_workflow_manager = WorkflowManager(db)
_workflow_manager.start()
@@ -3733,7 +3592,7 @@ async def create_workflow_endpoint(request: WorkflowCreate, _=Depends(verify_api
schedule=request.schedule,
schedule_type=request.schedule_type,
config=request.config,
- webhook_ids=request.webhook_ids
+ webhook_ids=request.webhook_ids,
)
created = manager.create_workflow(workflow)
@@ -3756,7 +3615,7 @@ async def create_workflow_endpoint(request: WorkflowCreate, _=Depends(verify_api
next_run_at=created.next_run_at,
run_count=created.run_count,
success_count=created.success_count,
- fail_count=created.fail_count
+ fail_count=created.fail_count,
)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@@ -3767,7 +3626,7 @@ async def list_workflows_endpoint(
project_id: str | None = None,
status: str | None = None,
workflow_type: str | None = None,
- _=Depends(verify_api_key)
+ _=Depends(verify_api_key),
):
"""获取工作流列表"""
if not WORKFLOW_AVAILABLE:
@@ -3796,11 +3655,11 @@ async def list_workflows_endpoint(
next_run_at=w.next_run_at,
run_count=w.run_count,
success_count=w.success_count,
- fail_count=w.fail_count
+ fail_count=w.fail_count,
)
for w in workflows
],
- total=len(workflows)
+ total=len(workflows),
)
@@ -3834,7 +3693,7 @@ async def get_workflow_endpoint(workflow_id: str, _=Depends(verify_api_key)):
next_run_at=workflow.next_run_at,
run_count=workflow.run_count,
success_count=workflow.success_count,
- fail_count=workflow.fail_count
+ fail_count=workflow.fail_count,
)
@@ -3870,7 +3729,7 @@ async def update_workflow_endpoint(workflow_id: str, request: WorkflowUpdate, _=
next_run_at=updated.next_run_at,
run_count=updated.run_count,
success_count=updated.success_count,
- fail_count=updated.fail_count
+ fail_count=updated.fail_count,
)
@@ -3891,9 +3750,8 @@ async def delete_workflow_endpoint(workflow_id: str, _=Depends(verify_api_key)):
@app.post("/api/v1/workflows/{workflow_id}/trigger", response_model=WorkflowTriggerResponse, tags=["Workflows"])
async def trigger_workflow_endpoint(
- workflow_id: str,
- request: WorkflowTriggerRequest = None,
- _=Depends(verify_api_key)):
+ workflow_id: str, request: WorkflowTriggerRequest = None, _=Depends(verify_api_key)
+):
"""手动触发工作流"""
if not WORKFLOW_AVAILABLE:
raise HTTPException(status_code=503, detail="Workflow automation not available")
@@ -3901,17 +3759,14 @@ async def trigger_workflow_endpoint(
manager = get_workflow_manager_instance()
try:
- result = await manager.execute_workflow(
- workflow_id,
- input_data=request.input_data if request else {}
- )
+ result = await manager.execute_workflow(workflow_id, input_data=request.input_data if request else {})
return WorkflowTriggerResponse(
success=result["success"],
workflow_id=result["workflow_id"],
log_id=result["log_id"],
results=result["results"],
- duration_ms=result["duration_ms"]
+ duration_ms=result["duration_ms"],
)
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))
@@ -3921,11 +3776,7 @@ async def trigger_workflow_endpoint(
@app.get("/api/v1/workflows/{workflow_id}/logs", response_model=WorkflowLogListResponse, tags=["Workflows"])
async def get_workflow_logs_endpoint(
- workflow_id: str,
- status: str | None = None,
- limit: int = 100,
- offset: int = 0,
- _=Depends(verify_api_key)
+ workflow_id: str, status: str | None = None, limit: int = 100, offset: int = 0, _=Depends(verify_api_key)
):
"""获取工作流执行日志"""
if not WORKFLOW_AVAILABLE:
@@ -3947,11 +3798,11 @@ async def get_workflow_logs_endpoint(
input_data=log.input_data,
output_data=log.output_data,
error_message=log.error_message,
- created_at=log.created_at
+ created_at=log.created_at,
)
for log in logs
],
- total=len(logs)
+ total=len(logs),
)
@@ -3969,6 +3820,7 @@ async def get_workflow_stats_endpoint(workflow_id: str, days: int = 30, _=Depend
# ==================== Phase 7: Webhook Endpoints ====================
+
@app.post("/api/v1/webhooks", response_model=WebhookResponse, tags=["Webhooks"])
async def create_webhook_endpoint(request: WebhookCreate, _=Depends(verify_api_key)):
"""
@@ -3993,7 +3845,7 @@ async def create_webhook_endpoint(request: WebhookCreate, _=Depends(verify_api_k
url=request.url,
secret=request.secret,
headers=request.headers,
- template=request.template
+ template=request.template,
)
created = manager.create_webhook(webhook)
@@ -4010,7 +3862,7 @@ async def create_webhook_endpoint(request: WebhookCreate, _=Depends(verify_api_k
updated_at=created.updated_at,
last_used_at=created.last_used_at,
success_count=created.success_count,
- fail_count=created.fail_count
+ fail_count=created.fail_count,
)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@@ -4039,11 +3891,11 @@ async def list_webhooks_endpoint(_=Depends(verify_api_key)):
updated_at=w.updated_at,
last_used_at=w.last_used_at,
success_count=w.success_count,
- fail_count=w.fail_count
+ fail_count=w.fail_count,
)
for w in webhooks
],
- total=len(webhooks)
+ total=len(webhooks),
)
@@ -4071,7 +3923,7 @@ async def get_webhook_endpoint(webhook_id: str, _=Depends(verify_api_key)):
updated_at=webhook.updated_at,
last_used_at=webhook.last_used_at,
success_count=webhook.success_count,
- fail_count=webhook.fail_count
+ fail_count=webhook.fail_count,
)
@@ -4101,7 +3953,7 @@ async def update_webhook_endpoint(webhook_id: str, request: WebhookUpdate, _=Dep
updated_at=updated.updated_at,
last_used_at=updated.last_used_at,
success_count=updated.success_count,
- fail_count=updated.fail_count
+ fail_count=updated.fail_count,
)
@@ -4138,7 +3990,9 @@ async def test_webhook_endpoint(webhook_id: str, _=Depends(verify_api_key)):
}
if webhook.webhook_type == "slack":
- test_message = {"text": "🔔 这是来自 InsightFlow 的 Webhook 测试消息\n\n如果您收到这条消息,说明 Webhook 配置正确!"}
+ test_message = {
+ "text": "🔔 这是来自 InsightFlow 的 Webhook 测试消息\n\n如果您收到这条消息,说明 Webhook 配置正确!"
+ }
success = await manager.notifier.send(webhook, test_message)
manager.update_webhook_stats(webhook_id, success)
@@ -4151,6 +4005,7 @@ async def test_webhook_endpoint(webhook_id: str, _=Depends(verify_api_key)):
# ==================== Phase 7: Multimodal Support Endpoints ====================
+
# Pydantic Models for Multimodal API
class VideoUploadResponse(BaseModel):
video_id: str
@@ -4208,10 +4063,7 @@ class MultimodalStatsResponse(BaseModel):
@app.post("/api/v1/projects/{project_id}/upload-video", response_model=VideoUploadResponse, tags=["Multimodal"])
async def upload_video_endpoint(
- project_id: str,
- file: UploadFile = File(...),
- extract_interval: int = Form(5),
- _=Depends(verify_api_key)
+ project_id: str, file: UploadFile = File(...), extract_interval: int = Form(5), _=Depends(verify_api_key)
):
"""
上传视频文件进行处理
@@ -4261,10 +4113,21 @@ async def upload_video_endpoint(
audio_transcript_id, full_ocr_text, extracted_entities,
extracted_relations, status, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
- (video_id, project_id, file.filename, video_info.get('duration', 0),
- video_info.get('fps', 0),
- json.dumps({'width': video_info.get('width', 0), 'height': video_info.get('height', 0)}),
- None, result.full_text, '[]', '[]', 'completed', now, now)
+ (
+ video_id,
+ project_id,
+ file.filename,
+ video_info.get("duration", 0),
+ video_info.get("fps", 0),
+ json.dumps({"width": video_info.get("width", 0), "height": video_info.get("height", 0)}),
+ None,
+ result.full_text,
+ "[]",
+ "[]",
+ "completed",
+ now,
+ now,
+ ),
)
# 保存关键帧信息
@@ -4273,8 +4136,16 @@ async def upload_video_endpoint(
"""INSERT INTO video_frames
(id, video_id, frame_number, timestamp, image_url, ocr_text, extracted_entities, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
- (frame.id, frame.video_id, frame.frame_number, frame.timestamp,
- frame.frame_path, frame.ocr_text, json.dumps(frame.entities_detected), now)
+ (
+ frame.id,
+ frame.video_id,
+ frame.frame_number,
+ frame.timestamp,
+ frame.frame_path,
+ frame.ocr_text,
+ json.dumps(frame.entities_detected),
+ now,
+ ),
)
conn.commit()
@@ -4292,13 +4163,15 @@ async def upload_video_endpoint(
if existing:
entity_name_to_id[raw_ent["name"]] = existing.id
else:
- new_ent = db.create_entity(Entity(
- id=str(uuid.uuid4())[:8],
- project_id=project_id,
- name=raw_ent["name"],
- type=raw_ent.get("type", "OTHER"),
- definition=raw_ent.get("definition", "")
- ))
+ new_ent = db.create_entity(
+ Entity(
+ id=str(uuid.uuid4())[:8],
+ project_id=project_id,
+ name=raw_ent["name"],
+ type=raw_ent.get("type", "OTHER"),
+ definition=raw_ent.get("definition", ""),
+ )
+ )
entity_name_to_id[raw_ent["name"]] = new_ent.id
# 保存多模态实体提及
@@ -4307,8 +4180,17 @@ async def upload_video_endpoint(
"""INSERT OR REPLACE INTO multimodal_mentions
(id, project_id, entity_id, modality, source_id, source_type, text_snippet, confidence, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)""",
- (str(uuid.uuid4())[:8], project_id, entity_name_to_id[raw_ent["name"]],
- 'video', video_id, 'video_frame', raw_ent.get("name", ""), 1.0, now)
+ (
+ str(uuid.uuid4())[:8],
+ project_id,
+ entity_name_to_id[raw_ent["name"]],
+ "video",
+ video_id,
+ "video_frame",
+ raw_ent.get("name", ""),
+ 1.0,
+ now,
+ ),
)
conn.commit()
conn.close()
@@ -4323,14 +4205,14 @@ async def upload_video_endpoint(
source_entity_id=source_id,
target_entity_id=target_id,
relation_type=rel.get("type", "related"),
- evidence=result.full_text[:200]
+ evidence=result.full_text[:200],
)
# 更新视频的实体和关系信息
conn = db.get_conn()
conn.execute(
"UPDATE videos SET extracted_entities = ?, extracted_relations = ? WHERE id = ?",
- (json.dumps(raw_entities), json.dumps(raw_relations), video_id)
+ (json.dumps(raw_entities), json.dumps(raw_relations), video_id),
)
conn.commit()
conn.close()
@@ -4343,16 +4225,13 @@ async def upload_video_endpoint(
audio_extracted=bool(result.audio_path),
frame_count=len(result.frames),
ocr_text_preview=result.full_text[:200] + "..." if len(result.full_text) > 200 else result.full_text,
- message="Video processed successfully"
+ message="Video processed successfully",
)
@app.post("/api/v1/projects/{project_id}/upload-image", response_model=ImageUploadResponse, tags=["Multimodal"])
async def upload_image_endpoint(
- project_id: str,
- file: UploadFile = File(...),
- detect_type: bool = Form(True),
- _=Depends(verify_api_key)
+ project_id: str, file: UploadFile = File(...), detect_type: bool = Form(True), _=Depends(verify_api_key)
):
"""
上传图片文件进行处理
@@ -4397,10 +4276,18 @@ async def upload_image_endpoint(
(id, project_id, filename, ocr_text, description,
extracted_entities, extracted_relations, status, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
- (image_id, project_id, file.filename, result.ocr_text, result.description,
- json.dumps([{"name": e.name, "type": e.type, "confidence": e.confidence} for e in result.entities]),
- json.dumps([{"source": r.source, "target": r.target, "type": r.relation_type} for r in result.relations]),
- 'completed', now, now)
+ (
+ image_id,
+ project_id,
+ file.filename,
+ result.ocr_text,
+ result.description,
+ json.dumps([{"name": e.name, "type": e.type, "confidence": e.confidence} for e in result.entities]),
+ json.dumps([{"source": r.source, "target": r.target, "type": r.relation_type} for r in result.relations]),
+ "completed",
+ now,
+ now,
+ ),
)
conn.commit()
conn.close()
@@ -4410,13 +4297,11 @@ async def upload_image_endpoint(
existing = align_entity(project_id, entity.name, db, "")
if not existing:
- new_ent = db.create_entity(Entity(
- id=str(uuid.uuid4())[:8],
- project_id=project_id,
- name=entity.name,
- type=entity.type,
- definition=""
- ))
+ new_ent = db.create_entity(
+ Entity(
+ id=str(uuid.uuid4())[:8], project_id=project_id, name=entity.name, type=entity.type, definition=""
+ )
+ )
entity_id = new_ent.id
else:
entity_id = existing.id
@@ -4427,8 +4312,17 @@ async def upload_image_endpoint(
"""INSERT OR REPLACE INTO multimodal_mentions
(id, project_id, entity_id, modality, source_id, source_type, text_snippet, confidence, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)""",
- (str(uuid.uuid4())[:8], project_id, entity_id,
- 'image', image_id, result.image_type, entity.name, entity.confidence, now)
+ (
+ str(uuid.uuid4())[:8],
+ project_id,
+ entity_id,
+ "image",
+ image_id,
+ result.image_type,
+ entity.name,
+ entity.confidence,
+ now,
+ ),
)
conn.commit()
conn.close()
@@ -4444,7 +4338,7 @@ async def upload_image_endpoint(
source_entity_id=source_entity.id,
target_entity_id=target_entity.id,
relation_type=relation.relation_type,
- evidence=result.ocr_text[:200]
+ evidence=result.ocr_text[:200],
)
return ImageUploadResponse(
@@ -4455,16 +4349,12 @@ async def upload_image_endpoint(
ocr_text_preview=result.ocr_text[:200] + "..." if len(result.ocr_text) > 200 else result.ocr_text,
description=result.description,
entity_count=len(result.entities),
- status="completed"
+ status="completed",
)
@app.post("/api/v1/projects/{project_id}/upload-images-batch", tags=["Multimodal"])
-async def upload_images_batch_endpoint(
- project_id: str,
- files: list[UploadFile] = File(...),
- _=Depends(verify_api_key)
-):
+async def upload_images_batch_endpoint(project_id: str, files: list[UploadFile] = File(...), _=Depends(verify_api_key)):
"""
批量上传图片文件进行处理
@@ -4506,43 +4396,46 @@ async def upload_images_batch_endpoint(
(id, project_id, filename, ocr_text, description,
extracted_entities, extracted_relations, status, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
- (image_id, project_id, "batch_image", result.ocr_text, result.description,
- json.dumps([{"name": e.name, "type": e.type} for e in result.entities]),
- json.dumps([{"source": r.source, "target": r.target} for r in result.relations]),
- 'completed', now, now)
+ (
+ image_id,
+ project_id,
+ "batch_image",
+ result.ocr_text,
+ result.description,
+ json.dumps([{"name": e.name, "type": e.type} for e in result.entities]),
+ json.dumps([{"source": r.source, "target": r.target} for r in result.relations]),
+ "completed",
+ now,
+ now,
+ ),
)
conn.commit()
conn.close()
- results.append({
- "image_id": image_id,
- "status": "success",
- "image_type": result.image_type,
- "entity_count": len(result.entities)
- })
+ results.append(
+ {
+ "image_id": image_id,
+ "status": "success",
+ "image_type": result.image_type,
+ "entity_count": len(result.entities),
+ }
+ )
else:
- results.append({
- "image_id": result.image_id,
- "status": "failed",
- "error": result.error_message
- })
+ results.append({"image_id": result.image_id, "status": "failed", "error": result.error_message})
return {
"project_id": project_id,
"total_count": batch_result.total_count,
"success_count": batch_result.success_count,
"failed_count": batch_result.failed_count,
- "results": results
+ "results": results,
}
-@app.post("/api/v1/projects/{project_id}/multimodal/align",
- response_model=MultimodalAlignmentResponse, tags=["Multimodal"])
-async def align_multimodal_entities_endpoint(
- project_id: str,
- threshold: float = 0.85,
- _=Depends(verify_api_key)
-):
+@app.post(
+ "/api/v1/projects/{project_id}/multimodal/align", response_model=MultimodalAlignmentResponse, tags=["Multimodal"]
+)
+async def align_multimodal_entities_endpoint(project_id: str, threshold: float = 0.85, _=Depends(verify_api_key)):
"""
跨模态实体对齐
@@ -4567,35 +4460,34 @@ async def align_multimodal_entities_endpoint(
# 获取多模态提及
conn = db.get_conn()
- mentions = conn.execute(
- """SELECT * FROM multimodal_mentions WHERE project_id = ?""",
- (project_id,)
- ).fetchall()
+ mentions = conn.execute("""SELECT * FROM multimodal_mentions WHERE project_id = ?""", (project_id,)).fetchall()
conn.close()
# 按模态分组实体
modality_entities = {"audio": [], "video": [], "image": [], "document": []}
for mention in mentions:
- modality = mention['modality']
- entity = db.get_entity(mention['entity_id'])
- if entity and entity.id not in [e.get('id') for e in modality_entities[modality]]:
- modality_entities[modality].append({
- 'id': entity.id,
- 'name': entity.name,
- 'type': entity.type,
- 'definition': entity.definition,
- 'aliases': entity.aliases
- })
+ modality = mention["modality"]
+ entity = db.get_entity(mention["entity_id"])
+ if entity and entity.id not in [e.get("id") for e in modality_entities[modality]]:
+ modality_entities[modality].append(
+ {
+ "id": entity.id,
+ "name": entity.name,
+ "type": entity.type,
+ "definition": entity.definition,
+ "aliases": entity.aliases,
+ }
+ )
# 跨模态对齐
linker = get_multimodal_entity_linker(similarity_threshold=threshold)
links = linker.align_cross_modal_entities(
project_id=project_id,
- audio_entities=modality_entities['audio'],
- video_entities=modality_entities['video'],
- image_entities=modality_entities['image'],
- document_entities=modality_entities['document']
+ audio_entities=modality_entities["audio"],
+ video_entities=modality_entities["video"],
+ image_entities=modality_entities["image"],
+ document_entities=modality_entities["document"],
)
# 保存关联到数据库
@@ -4608,20 +4500,29 @@ async def align_multimodal_entities_endpoint(
"""INSERT OR REPLACE INTO multimodal_entity_links
(id, entity_id, linked_entity_id, link_type, confidence, evidence, modalities, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
- (link.id, link.source_entity_id, link.target_entity_id, link.link_type,
- link.confidence, link.evidence,
- json.dumps([link.source_modality, link.target_modality]), now)
+ (
+ link.id,
+ link.source_entity_id,
+ link.target_entity_id,
+ link.link_type,
+ link.confidence,
+ link.evidence,
+ json.dumps([link.source_modality, link.target_modality]),
+ now,
+ ),
+ )
+ saved_links.append(
+ MultimodalEntityLinkResponse(
+ link_id=link.id,
+ source_entity_id=link.source_entity_id,
+ target_entity_id=link.target_entity_id,
+ source_modality=link.source_modality,
+ target_modality=link.target_modality,
+ link_type=link.link_type,
+ confidence=link.confidence,
+ evidence=link.evidence,
+ )
)
- saved_links.append(MultimodalEntityLinkResponse(
- link_id=link.id,
- source_entity_id=link.source_entity_id,
- target_entity_id=link.target_entity_id,
- source_modality=link.source_modality,
- target_modality=link.target_modality,
- link_type=link.link_type,
- confidence=link.confidence,
- evidence=link.evidence
- ))
conn.commit()
conn.close()
@@ -4630,7 +4531,7 @@ async def align_multimodal_entities_endpoint(
project_id=project_id,
aligned_count=len(saved_links),
links=saved_links,
- message=f"Successfully aligned {len(saved_links)} cross-modal entity pairs"
+ message=f"Successfully aligned {len(saved_links)} cross-modal entity pairs",
)
@@ -4652,36 +4553,33 @@ async def get_multimodal_stats_endpoint(project_id: str, _=Depends(verify_api_ke
conn = db.get_conn()
# 统计视频数量
- video_count = conn.execute(
- "SELECT COUNT(*) as count FROM videos WHERE project_id = ?",
- (project_id,)
- ).fetchone()['count']
+ video_count = conn.execute("SELECT COUNT(*) as count FROM videos WHERE project_id = ?", (project_id,)).fetchone()[
+ "count"
+ ]
# 统计图片数量
- image_count = conn.execute(
- "SELECT COUNT(*) as count FROM images WHERE project_id = ?",
- (project_id,)
- ).fetchone()['count']
+ image_count = conn.execute("SELECT COUNT(*) as count FROM images WHERE project_id = ?", (project_id,)).fetchone()[
+ "count"
+ ]
# 统计多模态实体提及
multimodal_count = conn.execute(
- "SELECT COUNT(DISTINCT entity_id) as count FROM multimodal_mentions WHERE project_id = ?",
- (project_id,)
- ).fetchone()['count']
+ "SELECT COUNT(DISTINCT entity_id) as count FROM multimodal_mentions WHERE project_id = ?", (project_id,)
+ ).fetchone()["count"]
# 统计跨模态关联
cross_modal_count = conn.execute(
"SELECT COUNT(*) as count FROM multimodal_entity_links WHERE entity_id IN (SELECT id FROM entities WHERE project_id = ?)",
- (project_id,)
- ).fetchone()['count']
+ (project_id,),
+ ).fetchone()["count"]
# 模态分布
modality_dist = {}
- for modality in ['audio', 'video', 'image', 'document']:
+ for modality in ["audio", "video", "image", "document"]:
count = conn.execute(
"SELECT COUNT(*) as count FROM multimodal_mentions WHERE project_id = ? AND modality = ?",
- (project_id, modality)
- ).fetchone()['count']
+ (project_id, modality),
+ ).fetchone()["count"]
modality_dist[modality] = count
conn.close()
@@ -4692,7 +4590,7 @@ async def get_multimodal_stats_endpoint(project_id: str, _=Depends(verify_api_ke
image_count=image_count,
multimodal_entity_count=multimodal_count,
cross_modal_links=cross_modal_count,
- modality_distribution=modality_dist
+ modality_distribution=modality_dist,
)
@@ -4709,21 +4607,28 @@ async def list_project_videos_endpoint(project_id: str, _=Depends(verify_api_key
"""SELECT id, filename, duration, fps, resolution,
full_ocr_text, status, created_at
FROM videos WHERE project_id = ? ORDER BY created_at DESC""",
- (project_id,)
+ (project_id,),
).fetchall()
conn.close()
- return [{
- "id": v['id'],
- "filename": v['filename'],
- "duration": v['duration'],
- "fps": v['fps'],
- "resolution": json.loads(v['resolution']) if v['resolution'] else None,
- "ocr_preview": v['full_ocr_text'][:200] + "..." if v['full_ocr_text'] and len(v['full_ocr_text']) > 200 else v['full_ocr_text'],
- "status": v['status'],
- "created_at": v['created_at']
- } for v in videos]
+ return [
+ {
+ "id": v["id"],
+ "filename": v["filename"],
+ "duration": v["duration"],
+ "fps": v["fps"],
+ "resolution": json.loads(v["resolution"]) if v["resolution"] else None,
+ "ocr_preview": (
+ v["full_ocr_text"][:200] + "..."
+ if v["full_ocr_text"] and len(v["full_ocr_text"]) > 200
+ else v["full_ocr_text"]
+ ),
+ "status": v["status"],
+ "created_at": v["created_at"],
+ }
+ for v in videos
+ ]
@app.get("/api/v1/projects/{project_id}/images", tags=["Multimodal"])
@@ -4739,20 +4644,25 @@ async def list_project_images_endpoint(project_id: str, _=Depends(verify_api_key
"""SELECT id, filename, ocr_text, description,
extracted_entities, status, created_at
FROM images WHERE project_id = ? ORDER BY created_at DESC""",
- (project_id,)
+ (project_id,),
).fetchall()
conn.close()
- return [{
- "id": img['id'],
- "filename": img['filename'],
- "ocr_preview": img['ocr_text'][:200] + "..." if img['ocr_text'] and len(img['ocr_text']) > 200 else img['ocr_text'],
- "description": img['description'],
- "entity_count": len(json.loads(img['extracted_entities'])) if img['extracted_entities'] else 0,
- "status": img['status'],
- "created_at": img['created_at']
- } for img in images]
+ return [
+ {
+ "id": img["id"],
+ "filename": img["filename"],
+ "ocr_preview": (
+ img["ocr_text"][:200] + "..." if img["ocr_text"] and len(img["ocr_text"]) > 200 else img["ocr_text"]
+ ),
+ "description": img["description"],
+ "entity_count": len(json.loads(img["extracted_entities"])) if img["extracted_entities"] else 0,
+ "status": img["status"],
+ "created_at": img["created_at"],
+ }
+ for img in images
+ ]
@app.get("/api/v1/videos/{video_id}/frames", tags=["Multimodal"])
@@ -4767,19 +4677,22 @@ async def get_video_frames_endpoint(video_id: str, _=Depends(verify_api_key)):
frames = conn.execute(
"""SELECT id, frame_number, timestamp, image_url, ocr_text, extracted_entities
FROM video_frames WHERE video_id = ? ORDER BY timestamp""",
- (video_id,)
+ (video_id,),
).fetchall()
conn.close()
- return [{
- "id": f['id'],
- "frame_number": f['frame_number'],
- "timestamp": f['timestamp'],
- "image_url": f['image_url'],
- "ocr_text": f['ocr_text'],
- "entities": json.loads(f['extracted_entities']) if f['extracted_entities'] else []
- } for f in frames]
+ return [
+ {
+ "id": f["id"],
+ "frame_number": f["frame_number"],
+ "timestamp": f["timestamp"],
+ "image_url": f["image_url"],
+ "ocr_text": f["ocr_text"],
+ "entities": json.loads(f["extracted_entities"]) if f["extracted_entities"] else [],
+ }
+ for f in frames
+ ]
@app.get("/api/v1/entities/{entity_id}/multimodal-mentions", tags=["Multimodal"])
@@ -4796,22 +4709,25 @@ async def get_entity_multimodal_mentions_endpoint(entity_id: str, _=Depends(veri
FROM multimodal_mentions m
JOIN entities e ON m.entity_id = e.id
WHERE m.entity_id = ? ORDER BY m.created_at DESC""",
- (entity_id,)
+ (entity_id,),
).fetchall()
conn.close()
- return [{
- "id": m['id'],
- "entity_id": m['entity_id'],
- "entity_name": m['entity_name'],
- "modality": m['modality'],
- "source_id": m['source_id'],
- "source_type": m['source_type'],
- "text_snippet": m['text_snippet'],
- "confidence": m['confidence'],
- "created_at": m['created_at']
- } for m in mentions]
+ return [
+ {
+ "id": m["id"],
+ "entity_id": m["entity_id"],
+ "entity_name": m["entity_name"],
+ "modality": m["modality"],
+ "source_id": m["source_id"],
+ "source_type": m["source_type"],
+ "text_snippet": m["text_snippet"],
+ "confidence": m["confidence"],
+ "created_at": m["created_at"],
+ }
+ for m in mentions
+ ]
@app.get("/api/v1/projects/{project_id}/multimodal/suggest-merges", tags=["Multimodal"])
@@ -4834,36 +4750,34 @@ async def suggest_multimodal_merges_endpoint(project_id: str, _=Depends(verify_a
# 获取所有实体
entities = db.list_project_entities(project_id)
- entity_dicts = [{
- 'id': e.id,
- 'name': e.name,
- 'type': e.type,
- 'definition': e.definition,
- 'aliases': e.aliases
- } for e in entities]
+ entity_dicts = [
+ {"id": e.id, "name": e.name, "type": e.type, "definition": e.definition, "aliases": e.aliases} for e in entities
+ ]
# 获取现有链接
conn = db.get_conn()
existing_links = conn.execute(
"""SELECT * FROM multimodal_entity_links
WHERE entity_id IN (SELECT id FROM entities WHERE project_id = ?)""",
- (project_id,)
+ (project_id,),
).fetchall()
conn.close()
existing_link_objects = []
for row in existing_links:
- existing_link_objects.append(EntityLink(
- id=row['id'],
- project_id=project_id,
- source_entity_id=row['entity_id'],
- target_entity_id=row['linked_entity_id'],
- link_type=row['link_type'],
- source_modality='unknown',
- target_modality='unknown',
- confidence=row['confidence'],
- evidence=row['evidence'] or ""
- ))
+ existing_link_objects.append(
+ EntityLink(
+ id=row["id"],
+ project_id=project_id,
+ source_entity_id=row["entity_id"],
+ target_entity_id=row["linked_entity_id"],
+ link_type=row["link_type"],
+ source_modality="unknown",
+ target_modality="unknown",
+ confidence=row["confidence"],
+ evidence=row["evidence"] or "",
+ )
+ )
# 获取建议
linker = get_multimodal_entity_linker()
@@ -4875,26 +4789,27 @@ async def suggest_multimodal_merges_endpoint(project_id: str, _=Depends(verify_a
"suggestions": [
{
"entity1": {
- "id": s['entity1'].get('id'),
- "name": s['entity1'].get('name'),
- "type": s['entity1'].get('type')
+ "id": s["entity1"].get("id"),
+ "name": s["entity1"].get("name"),
+ "type": s["entity1"].get("type"),
},
"entity2": {
- "id": s['entity2'].get('id'),
- "name": s['entity2'].get('name'),
- "type": s['entity2'].get('type')
+ "id": s["entity2"].get("id"),
+ "name": s["entity2"].get("name"),
+ "type": s["entity2"].get("type"),
},
- "similarity": s['similarity'],
- "match_type": s['match_type'],
- "suggested_action": s['suggested_action']
+ "similarity": s["similarity"],
+ "match_type": s["match_type"],
+ "suggested_action": s["suggested_action"],
}
for s in suggestions[:20] # 最多返回20个建议
- ]
+ ],
}
# ==================== Phase 7: Multimodal Support API ====================
+
class VideoUploadResponse(BaseModel):
video_id: str
filename: str
@@ -4934,10 +4849,12 @@ class MultimodalProfileResponse(BaseModel):
# ==================== Phase 7 Task 7: Plugin Management Pydantic Models ====================
+
class PluginCreate(BaseModel):
name: str = Field(..., description="插件名称")
- plugin_type: str = Field(...,
- description="插件类型: chrome_extension, feishu_bot, dingtalk_bot, zapier, make, webdav, custom")
+ plugin_type: str = Field(
+ ..., description="插件类型: chrome_extension, feishu_bot, dingtalk_bot, zapier, make, webdav, custom"
+ )
project_id: str = Field(..., description="关联项目ID")
config: dict = Field(default_factory=dict, description="插件配置")
@@ -5109,6 +5026,7 @@ def get_plugin_manager_instance():
# ==================== Phase 7 Task 7: Plugin Management Endpoints ====================
+
@app.post("/api/v1/plugins", response_model=PluginResponse, tags=["Plugins"])
async def create_plugin_endpoint(request: PluginCreate, _=Depends(verify_api_key)):
"""
@@ -5133,7 +5051,7 @@ async def create_plugin_endpoint(request: PluginCreate, _=Depends(verify_api_key
name=request.name,
plugin_type=request.plugin_type,
project_id=request.project_id,
- config=request.config
+ config=request.config,
)
created = manager.create_plugin(plugin)
@@ -5148,16 +5066,13 @@ async def create_plugin_endpoint(request: PluginCreate, _=Depends(verify_api_key
created_at=created.created_at,
updated_at=created.updated_at,
last_used_at=created.last_used_at,
- use_count=created.use_count
+ use_count=created.use_count,
)
@app.get("/api/v1/plugins", response_model=PluginListResponse, tags=["Plugins"])
async def list_plugins_endpoint(
- project_id: str | None = None,
- plugin_type: str | None = None,
- status: str | None = None,
- _=Depends(verify_api_key)
+ project_id: str | None = None, plugin_type: str | None = None, status: str | None = None, _=Depends(verify_api_key)
):
"""获取插件列表"""
if not PLUGIN_MANAGER_AVAILABLE:
@@ -5178,11 +5093,11 @@ async def list_plugins_endpoint(
created_at=p.created_at,
updated_at=p.updated_at,
last_used_at=p.last_used_at,
- use_count=p.use_count
+ use_count=p.use_count,
)
for p in plugins
],
- total=len(plugins)
+ total=len(plugins),
)
@@ -5208,7 +5123,7 @@ async def get_plugin_endpoint(plugin_id: str, _=Depends(verify_api_key)):
created_at=plugin.created_at,
updated_at=plugin.updated_at,
last_used_at=plugin.last_used_at,
- use_count=plugin.use_count
+ use_count=plugin.use_count,
)
@@ -5236,7 +5151,7 @@ async def update_plugin_endpoint(plugin_id: str, request: PluginUpdate, _=Depend
created_at=updated.created_at,
updated_at=updated.updated_at,
last_used_at=updated.last_used_at,
- use_count=updated.use_count
+ use_count=updated.use_count,
)
@@ -5257,6 +5172,7 @@ async def delete_plugin_endpoint(plugin_id: str, _=Depends(verify_api_key)):
# ==================== Phase 7 Task 7: Chrome Extension Endpoints ====================
+
@app.post("/api/v1/plugins/chrome/tokens", response_model=ChromeExtensionTokenResponse, tags=["Chrome Extension"])
async def create_chrome_token_endpoint(request: ChromeExtensionTokenCreate, _=Depends(verify_api_key)):
"""
@@ -5277,7 +5193,7 @@ async def create_chrome_token_endpoint(request: ChromeExtensionTokenCreate, _=De
name=request.name,
project_id=request.project_id,
permissions=request.permissions,
- expires_days=request.expires_days
+ expires_days=request.expires_days,
)
return ChromeExtensionTokenResponse(
@@ -5287,15 +5203,12 @@ async def create_chrome_token_endpoint(request: ChromeExtensionTokenCreate, _=De
project_id=token.project_id,
permissions=token.permissions,
expires_at=token.expires_at,
- created_at=token.created_at
+ created_at=token.created_at,
)
@app.get("/api/v1/plugins/chrome/tokens", tags=["Chrome Extension"])
-async def list_chrome_tokens_endpoint(
- project_id: str | None = None,
- _=Depends(verify_api_key)
-):
+async def list_chrome_tokens_endpoint(project_id: str | None = None, _=Depends(verify_api_key)):
"""列出 Chrome 扩展令牌"""
if not PLUGIN_MANAGER_AVAILABLE:
raise HTTPException(status_code=503, detail="Plugin manager not available")
@@ -5319,11 +5232,11 @@ async def list_chrome_tokens_endpoint(
"created_at": t.created_at,
"last_used_at": t.last_used_at,
"use_count": t.use_count,
- "is_revoked": t.is_revoked
+ "is_revoked": t.is_revoked,
}
for t in tokens
],
- "total": len(tokens)
+ "total": len(tokens),
}
@@ -5370,11 +5283,7 @@ async def chrome_import_webpage_endpoint(request: ChromeExtensionImportRequest):
# 导入网页
result = await handler.import_webpage(
- token=token,
- url=request.url,
- title=request.title,
- content=request.content,
- html_content=request.html_content
+ token=token, url=request.url, title=request.title, content=request.content, html_content=request.html_content
)
if not result["success"]:
@@ -5385,6 +5294,7 @@ async def chrome_import_webpage_endpoint(request: ChromeExtensionImportRequest):
# ==================== Phase 7 Task 7: Bot Endpoints ====================
+
@app.post("/api/v1/plugins/bot/feishu/sessions", response_model=BotSessionResponse, tags=["Bot"])
async def create_feishu_session_endpoint(request: BotSessionCreate, _=Depends(verify_api_key)):
"""创建飞书机器人会话"""
@@ -5402,7 +5312,7 @@ async def create_feishu_session_endpoint(request: BotSessionCreate, _=Depends(ve
session_name=request.session_name,
project_id=request.project_id,
webhook_url=request.webhook_url,
- secret=request.secret
+ secret=request.secret,
)
return BotSessionResponse(
@@ -5415,7 +5325,7 @@ async def create_feishu_session_endpoint(request: BotSessionCreate, _=Depends(ve
is_active=session.is_active,
created_at=session.created_at,
last_message_at=session.last_message_at,
- message_count=session.message_count
+ message_count=session.message_count,
)
@@ -5436,7 +5346,7 @@ async def create_dingtalk_session_endpoint(request: BotSessionCreate, _=Depends(
session_name=request.session_name,
project_id=request.project_id,
webhook_url=request.webhook_url,
- secret=request.secret
+ secret=request.secret,
)
return BotSessionResponse(
@@ -5449,16 +5359,12 @@ async def create_dingtalk_session_endpoint(request: BotSessionCreate, _=Depends(
is_active=session.is_active,
created_at=session.created_at,
last_message_at=session.last_message_at,
- message_count=session.message_count
+ message_count=session.message_count,
)
@app.get("/api/v1/plugins/bot/{bot_type}/sessions", tags=["Bot"])
-async def list_bot_sessions_endpoint(
- bot_type: str,
- project_id: str | None = None,
- _=Depends(verify_api_key)
-):
+async def list_bot_sessions_endpoint(bot_type: str, project_id: str | None = None, _=Depends(verify_api_key)):
"""列出机器人会话"""
if not PLUGIN_MANAGER_AVAILABLE:
raise HTTPException(status_code=503, detail="Plugin manager not available")
@@ -5488,11 +5394,11 @@ async def list_bot_sessions_endpoint(
"is_active": s.is_active,
"created_at": s.created_at,
"last_message_at": s.last_message_at,
- "message_count": s.message_count
+ "message_count": s.message_count,
}
for s in sessions
],
- "total": len(sessions)
+ "total": len(sessions),
}
@@ -5523,9 +5429,9 @@ async def bot_webhook_endpoint(bot_type: str, request: Request):
# 获取会话ID(飞书和钉钉的格式不同)
if bot_type == "feishu":
- session_id = message.get('chat_id') or message.get('open_chat_id')
+ session_id = message.get("chat_id") or message.get("open_chat_id")
else: # dingtalk
- session_id = message.get('conversationId') or message.get('senderStaffId')
+ session_id = message.get("conversationId") or message.get("senderStaffId")
if not session_id:
raise HTTPException(status_code=400, detail="Cannot identify session")
@@ -5534,11 +5440,7 @@ async def bot_webhook_endpoint(bot_type: str, request: Request):
session = handler.get_session(session_id)
if not session:
# 自动创建会话
- session = handler.create_session(
- session_id=session_id,
- session_name=f"Auto-{session_id[:8]}",
- webhook_url=""
- )
+ session = handler.create_session(session_id=session_id, session_name=f"Auto-{session_id[:8]}", webhook_url="")
# 处理消息
result = await handler.handle_message(session, message)
@@ -5551,12 +5453,7 @@ async def bot_webhook_endpoint(bot_type: str, request: Request):
@app.post("/api/v1/plugins/bot/{bot_type}/sessions/{session_id}/send", tags=["Bot"])
-async def send_bot_message_endpoint(
- bot_type: str,
- session_id: str,
- message: str,
- _=Depends(verify_api_key)
-):
+async def send_bot_message_endpoint(bot_type: str, session_id: str, message: str, _=Depends(verify_api_key)):
"""发送消息到机器人会话"""
if not PLUGIN_MANAGER_AVAILABLE:
raise HTTPException(status_code=503, detail="Plugin manager not available")
@@ -5584,6 +5481,7 @@ async def send_bot_message_endpoint(
# ==================== Phase 7 Task 7: Integration Endpoints ====================
+
@app.post("/api/v1/plugins/integrations/zapier", response_model=WebhookEndpointResponse, tags=["Integrations"])
async def create_zapier_endpoint(request: WebhookEndpointCreate, _=Depends(verify_api_key)):
"""创建 Zapier Webhook 端点"""
@@ -5602,7 +5500,7 @@ async def create_zapier_endpoint(request: WebhookEndpointCreate, _=Depends(verif
project_id=request.project_id,
auth_type=request.auth_type,
auth_config=request.auth_config,
- trigger_events=request.trigger_events
+ trigger_events=request.trigger_events,
)
return WebhookEndpointResponse(
@@ -5616,7 +5514,7 @@ async def create_zapier_endpoint(request: WebhookEndpointCreate, _=Depends(verif
is_active=endpoint.is_active,
created_at=endpoint.created_at,
last_triggered_at=endpoint.last_triggered_at,
- trigger_count=endpoint.trigger_count
+ trigger_count=endpoint.trigger_count,
)
@@ -5638,7 +5536,7 @@ async def create_make_endpoint(request: WebhookEndpointCreate, _=Depends(verify_
project_id=request.project_id,
auth_type=request.auth_type,
auth_config=request.auth_config,
- trigger_events=request.trigger_events
+ trigger_events=request.trigger_events,
)
return WebhookEndpointResponse(
@@ -5652,15 +5550,13 @@ async def create_make_endpoint(request: WebhookEndpointCreate, _=Depends(verify_
is_active=endpoint.is_active,
created_at=endpoint.created_at,
last_triggered_at=endpoint.last_triggered_at,
- trigger_count=endpoint.trigger_count
+ trigger_count=endpoint.trigger_count,
)
@app.get("/api/v1/plugins/integrations/{endpoint_type}", tags=["Integrations"])
async def list_integration_endpoints_endpoint(
- endpoint_type: str,
- project_id: str | None = None,
- _=Depends(verify_api_key)
+ endpoint_type: str, project_id: str | None = None, _=Depends(verify_api_key)
):
"""列出集成端点"""
if not PLUGIN_MANAGER_AVAILABLE:
@@ -5693,11 +5589,11 @@ async def list_integration_endpoints_endpoint(
"is_active": e.is_active,
"created_at": e.created_at,
"last_triggered_at": e.last_triggered_at,
- "trigger_count": e.trigger_count
+ "trigger_count": e.trigger_count,
}
for e in endpoints
],
- "total": len(endpoints)
+ "total": len(endpoints),
}
@@ -5722,20 +5618,11 @@ async def test_integration_endpoint(endpoint_id: str, _=Depends(verify_api_key))
result = await handler.test_endpoint(endpoint)
- return WebhookTestResponse(
- success=result["success"],
- endpoint_id=endpoint_id,
- message=result["message"]
- )
+ return WebhookTestResponse(success=result["success"], endpoint_id=endpoint_id, message=result["message"])
@app.post("/api/v1/plugins/integrations/{endpoint_id}/trigger", tags=["Integrations"])
-async def trigger_integration_endpoint(
- endpoint_id: str,
- event_type: str,
- data: dict,
- _=Depends(verify_api_key)
-):
+async def trigger_integration_endpoint(endpoint_id: str, event_type: str, data: dict, _=Depends(verify_api_key)):
"""手动触发集成端点"""
if not PLUGIN_MANAGER_AVAILABLE:
raise HTTPException(status_code=503, detail="Plugin manager not available")
@@ -5760,6 +5647,7 @@ async def trigger_integration_endpoint(
# ==================== Phase 7 Task 7: WebDAV Endpoints ====================
+
@app.post("/api/v1/plugins/webdav", response_model=WebDAVSyncResponse, tags=["WebDAV"])
async def create_webdav_sync_endpoint(request: WebDAVSyncCreate, _=Depends(verify_api_key)):
"""
@@ -5784,7 +5672,7 @@ async def create_webdav_sync_endpoint(request: WebDAVSyncCreate, _=Depends(verif
password=request.password,
remote_path=request.remote_path,
sync_mode=request.sync_mode,
- sync_interval=request.sync_interval
+ sync_interval=request.sync_interval,
)
return WebDAVSyncResponse(
@@ -5800,15 +5688,12 @@ async def create_webdav_sync_endpoint(request: WebDAVSyncCreate, _=Depends(verif
last_sync_status=sync.last_sync_status,
is_active=sync.is_active,
created_at=sync.created_at,
- sync_count=sync.sync_count
+ sync_count=sync.sync_count,
)
@app.get("/api/v1/plugins/webdav", tags=["WebDAV"])
-async def list_webdav_syncs_endpoint(
- project_id: str | None = None,
- _=Depends(verify_api_key)
-):
+async def list_webdav_syncs_endpoint(project_id: str | None = None, _=Depends(verify_api_key)):
"""列出 WebDAV 同步配置"""
if not PLUGIN_MANAGER_AVAILABLE:
raise HTTPException(status_code=503, detail="Plugin manager not available")
@@ -5836,11 +5721,11 @@ async def list_webdav_syncs_endpoint(
"last_sync_status": s.last_sync_status,
"is_active": s.is_active,
"created_at": s.created_at,
- "sync_count": s.sync_count
+ "sync_count": s.sync_count,
}
for s in syncs
],
- "total": len(syncs)
+ "total": len(syncs),
}
@@ -5863,8 +5748,7 @@ async def test_webdav_connection_endpoint(sync_id: str, _=Depends(verify_api_key
result = await handler.test_connection(sync)
return WebDAVTestResponse(
- success=result["success"],
- message=result.get("message") or result.get("error", "Unknown result")
+ success=result["success"], message=result.get("message") or result.get("error", "Unknown result")
)
@@ -5892,7 +5776,7 @@ async def sync_webdav_endpoint(sync_id: str, _=Depends(verify_api_key)):
entities_count=result.get("entities_count"),
relations_count=result.get("relations_count"),
remote_path=result.get("remote_path"),
- error=result.get("error")
+ error=result.get("error"),
)
@@ -5920,12 +5804,9 @@ async def delete_webdav_sync_endpoint(sync_id: str, _=Depends(verify_api_key)):
async def get_openapi():
"""获取 OpenAPI 规范"""
from fastapi.openapi.utils import get_openapi
+
return get_openapi(
- title=app.title,
- version=app.version,
- description=app.description,
- routes=app.routes,
- tags=app.openapi_tags
+ title=app.title, version=app.version, description=app.description, routes=app.routes, tags=app.openapi_tags
)
@@ -5934,6 +5815,7 @@ app.mount("/", StaticFiles(directory="frontend", html=True), name="frontend")
if __name__ == "__main__":
import uvicorn
+
uvicorn.run(app, host="0.0.0.0", port=8000)
@@ -6036,20 +5918,14 @@ class WebhookPayload(BaseModel):
@app.post("/api/v1/plugins", response_model=PluginResponse, tags=["Plugins"])
-async def create_plugin(
- request: PluginCreateRequest,
- api_key: str = Depends(verify_api_key)
-):
+async def create_plugin(request: PluginCreateRequest, api_key: str = Depends(verify_api_key)):
"""创建插件"""
if not PLUGIN_MANAGER_AVAILABLE:
raise HTTPException(status_code=503, detail="Plugin manager not available")
manager = get_plugin_manager()
plugin = manager.create_plugin(
- name=request.name,
- plugin_type=request.plugin_type,
- project_id=request.project_id,
- config=request.config
+ name=request.name, plugin_type=request.plugin_type, project_id=request.project_id, config=request.config
)
return PluginResponse(
@@ -6059,15 +5935,13 @@ async def create_plugin(
project_id=plugin.project_id,
status=plugin.status,
api_key=plugin.api_key,
- created_at=plugin.created_at
+ created_at=plugin.created_at,
)
@app.get("/api/v1/plugins", tags=["Plugins"])
async def list_plugins(
- project_id: str | None = None,
- plugin_type: str | None = None,
- api_key: str = Depends(verify_api_key)
+ project_id: str | None = None, plugin_type: str | None = None, api_key: str = Depends(verify_api_key)
):
"""列出插件"""
if not PLUGIN_MANAGER_AVAILABLE:
@@ -6085,7 +5959,7 @@ async def list_plugins(
"project_id": p.project_id,
"status": p.status,
"use_count": p.use_count,
- "created_at": p.created_at
+ "created_at": p.created_at,
}
for p in plugins
]
@@ -6093,10 +5967,7 @@ async def list_plugins(
@app.get("/api/v1/plugins/{plugin_id}", response_model=PluginResponse, tags=["Plugins"])
-async def get_plugin(
- plugin_id: str,
- api_key: str = Depends(verify_api_key)
-):
+async def get_plugin(plugin_id: str, api_key: str = Depends(verify_api_key)):
"""获取插件详情"""
if not PLUGIN_MANAGER_AVAILABLE:
raise HTTPException(status_code=503, detail="Plugin manager not available")
@@ -6114,15 +5985,12 @@ async def get_plugin(
project_id=plugin.project_id,
status=plugin.status,
api_key=plugin.api_key,
- created_at=plugin.created_at
+ created_at=plugin.created_at,
)
@app.delete("/api/v1/plugins/{plugin_id}", tags=["Plugins"])
-async def delete_plugin(
- plugin_id: str,
- api_key: str = Depends(verify_api_key)
-):
+async def delete_plugin(plugin_id: str, api_key: str = Depends(verify_api_key)):
"""删除插件"""
if not PLUGIN_MANAGER_AVAILABLE:
raise HTTPException(status_code=503, detail="Plugin manager not available")
@@ -6134,10 +6002,7 @@ async def delete_plugin(
@app.post("/api/v1/plugins/{plugin_id}/regenerate-key", tags=["Plugins"])
-async def regenerate_plugin_key(
- plugin_id: str,
- api_key: str = Depends(verify_api_key)
-):
+async def regenerate_plugin_key(plugin_id: str, api_key: str = Depends(verify_api_key)):
"""重新生成插件 API Key"""
if not PLUGIN_MANAGER_AVAILABLE:
raise HTTPException(status_code=503, detail="Plugin manager not available")
@@ -6150,11 +6015,9 @@ async def regenerate_plugin_key(
# ==================== Chrome Extension API ====================
+
@app.post("/api/v1/plugins/chrome/clip", response_model=ChromeClipResponse, tags=["Chrome Extension"])
-async def chrome_clip(
- request: ChromeClipRequest,
- x_api_key: str | None = Header(None, alias="X-API-Key")
-):
+async def chrome_clip(request: ChromeClipRequest, x_api_key: str | None = Header(None, alias="X-API-Key")):
"""Chrome 插件保存网页内容"""
if not PLUGIN_MANAGER_AVAILABLE:
raise HTTPException(status_code=503, detail="Plugin manager not available")
@@ -6195,7 +6058,7 @@ URL: {request.url}
project_id=project_id,
filename=f"clip_{request.title[:50]}.md",
full_text=doc_content,
- transcript_type="document"
+ transcript_type="document",
)
# 记录活动
@@ -6203,12 +6066,7 @@ URL: {request.url}
plugin_id=plugin.id,
activity_type="clip",
source="chrome_extension",
- details={
- "url": request.url,
- "title": request.title,
- "project_id": project_id,
- "transcript_id": transcript_id
- }
+ details={"url": request.url, "title": request.title, "project_id": project_id, "transcript_id": transcript_id},
)
return ChromeClipResponse(
@@ -6217,18 +6075,15 @@ URL: {request.url}
url=request.url,
title=request.title,
status="success",
- message="Content saved successfully"
+ message="Content saved successfully",
)
# ==================== Bot API ====================
+
@app.post("/api/v1/bots/webhook/{platform}", response_model=BotMessageResponse, tags=["Bot"])
-async def bot_webhook(
- platform: str,
- request: Request,
- x_signature: str | None = Header(None, alias="X-Signature")
-):
+async def bot_webhook(platform: str, request: Request, x_signature: str | None = Header(None, alias="X-Signature")):
"""接收机器人 Webhook 消息(飞书/钉钉/Slack)"""
if not PLUGIN_MANAGER_AVAILABLE:
raise HTTPException(status_code=503, detail="Plugin manager not available")
@@ -6257,15 +6112,13 @@ async def bot_webhook(
success=True,
reply="收到消息!请使用 InsightFlow 控制台查看更多功能。",
session_id=message.get("session_id", ""),
- action="reply"
+ action="reply",
)
@app.get("/api/v1/bots/sessions", response_model=list[BotSessionResponse], tags=["Bot"])
async def list_bot_sessions(
- plugin_id: str | None = None,
- project_id: str | None = None,
- api_key: str = Depends(verify_api_key)
+ plugin_id: str | None = None, project_id: str | None = None, api_key: str = Depends(verify_api_key)
):
"""列出机器人会话"""
if not PLUGIN_MANAGER_AVAILABLE:
@@ -6285,7 +6138,7 @@ async def list_bot_sessions(
project_id=s.project_id,
message_count=s.message_count,
created_at=s.created_at,
- last_message_at=s.last_message_at
+ last_message_at=s.last_message_at,
)
for s in sessions
]
@@ -6293,6 +6146,7 @@ async def list_bot_sessions(
# ==================== Webhook Integration API ====================
+
@app.post("/api/v1/webhook-endpoints", response_model=WebhookEndpointResponse, tags=["Integrations"])
async def create_integration_webhook_endpoint(
plugin_id: str,
@@ -6300,7 +6154,7 @@ async def create_integration_webhook_endpoint(
endpoint_type: str,
target_project_id: str | None = None,
allowed_events: list[str] | None = None,
- api_key: str = Depends(verify_api_key)
+ api_key: str = Depends(verify_api_key),
):
"""创建 Webhook 端点(用于 Zapier/Make 集成)"""
if not PLUGIN_MANAGER_AVAILABLE:
@@ -6312,7 +6166,7 @@ async def create_integration_webhook_endpoint(
name=name,
endpoint_type=endpoint_type,
target_project_id=target_project_id,
- allowed_events=allowed_events
+ allowed_events=allowed_events,
)
return WebhookEndpointResponse(
@@ -6324,15 +6178,12 @@ async def create_integration_webhook_endpoint(
target_project_id=endpoint.target_project_id,
is_active=endpoint.is_active,
trigger_count=endpoint.trigger_count,
- created_at=endpoint.created_at
+ created_at=endpoint.created_at,
)
@app.get("/api/v1/webhook-endpoints", response_model=list[WebhookEndpointResponse], tags=["Integrations"])
-async def list_webhook_endpoints(
- plugin_id: str | None = None,
- api_key: str = Depends(verify_api_key)
-):
+async def list_webhook_endpoints(plugin_id: str | None = None, api_key: str = Depends(verify_api_key)):
"""列出 Webhook 端点"""
if not PLUGIN_MANAGER_AVAILABLE:
raise HTTPException(status_code=503, detail="Plugin manager not available")
@@ -6350,7 +6201,7 @@ async def list_webhook_endpoints(
target_project_id=e.target_project_id,
is_active=e.is_active,
trigger_count=e.trigger_count,
- created_at=e.created_at
+ created_at=e.created_at,
)
for e in endpoints
]
@@ -6358,10 +6209,7 @@ async def list_webhook_endpoints(
@app.post("/webhook/{endpoint_type}/{token}", tags=["Integrations"])
async def receive_webhook(
- endpoint_type: str,
- token: str,
- request: Request,
- x_signature: str | None = Header(None, alias="X-Signature")
+ endpoint_type: str, token: str, request: Request, x_signature: str | None = Header(None, alias="X-Signature")
):
"""接收外部 Webhook 调用(Zapier/Make/Custom)"""
if not PLUGIN_MANAGER_AVAILABLE:
@@ -6397,22 +6245,19 @@ async def receive_webhook(
details={
"endpoint_id": endpoint.id,
"event": body.get("event"),
- "data_keys": list(body.get("data", {}).keys())
- }
+ "data_keys": list(body.get("data", {}).keys()),
+ },
)
# 处理数据(简化版本)
# 实际应该根据 endpoint.target_project_id 和 body 内容创建文档/实体等
- return {
- "success": True,
- "endpoint_id": endpoint.id,
- "received_at": datetime.now().isoformat()
- }
+ return {"success": True, "endpoint_id": endpoint.id, "received_at": datetime.now().isoformat()}
# ==================== WebDAV API ====================
+
@app.post("/api/v1/webdav-syncs", response_model=WebDAVSyncResponse, tags=["WebDAV"])
async def create_webdav_sync(
plugin_id: str,
@@ -6425,7 +6270,7 @@ async def create_webdav_sync(
sync_direction: str = "bidirectional",
sync_mode: str = "manual",
auto_analyze: bool = True,
- api_key: str = Depends(verify_api_key)
+ api_key: str = Depends(verify_api_key),
):
"""创建 WebDAV 同步配置"""
if not PLUGIN_MANAGER_AVAILABLE:
@@ -6442,7 +6287,7 @@ async def create_webdav_sync(
local_path=local_path,
sync_direction=sync_direction,
sync_mode=sync_mode,
- auto_analyze=auto_analyze
+ auto_analyze=auto_analyze,
)
return WebDAVSyncResponse(
@@ -6458,15 +6303,12 @@ async def create_webdav_sync(
auto_analyze=sync.auto_analyze,
is_active=sync.is_active,
last_sync_at=sync.last_sync_at,
- created_at=sync.created_at
+ created_at=sync.created_at,
)
@app.get("/api/v1/webdav-syncs", response_model=list[WebDAVSyncResponse], tags=["WebDAV"])
-async def list_webdav_syncs(
- plugin_id: str | None = None,
- api_key: str = Depends(verify_api_key)
-):
+async def list_webdav_syncs(plugin_id: str | None = None, api_key: str = Depends(verify_api_key)):
"""列出 WebDAV 同步配置"""
if not PLUGIN_MANAGER_AVAILABLE:
raise HTTPException(status_code=503, detail="Plugin manager not available")
@@ -6488,17 +6330,14 @@ async def list_webdav_syncs(
auto_analyze=s.auto_analyze,
is_active=s.is_active,
last_sync_at=s.last_sync_at,
- created_at=s.created_at
+ created_at=s.created_at,
)
for s in syncs
]
@app.post("/api/v1/webdav-syncs/{sync_id}/test", tags=["WebDAV"])
-async def test_webdav_connection(
- sync_id: str,
- api_key: str = Depends(verify_api_key)
-):
+async def test_webdav_connection(sync_id: str, api_key: str = Depends(verify_api_key)):
"""测试 WebDAV 连接"""
if not PLUGIN_MANAGER_AVAILABLE:
raise HTTPException(status_code=503, detail="Plugin manager not available")
@@ -6510,22 +6349,16 @@ async def test_webdav_connection(
raise HTTPException(status_code=404, detail="WebDAV sync not found")
from plugin_manager import WebDAVSync as WebDAVSyncHandler
+
handler = WebDAVSyncHandler(manager)
- success, message = await handler.test_connection(
- sync.server_url,
- sync.username,
- sync.password
- )
+ success, message = await handler.test_connection(sync.server_url, sync.username, sync.password)
return {"success": success, "message": message}
@app.post("/api/v1/webdav-syncs/{sync_id}/sync", tags=["WebDAV"])
-async def trigger_webdav_sync(
- sync_id: str,
- api_key: str = Depends(verify_api_key)
-):
+async def trigger_webdav_sync(sync_id: str, api_key: str = Depends(verify_api_key)):
"""手动触发 WebDAV 同步"""
if not PLUGIN_MANAGER_AVAILABLE:
raise HTTPException(status_code=503, detail="Plugin manager not available")
@@ -6539,39 +6372,24 @@ async def trigger_webdav_sync(
# 这里应该启动异步同步任务
# 简化版本,仅返回成功
- manager.update_webdav_sync(
- sync_id,
- last_sync_at=datetime.now().isoformat(),
- last_sync_status="running"
- )
+ manager.update_webdav_sync(sync_id, last_sync_at=datetime.now().isoformat(), last_sync_status="running")
- return {
- "success": True,
- "sync_id": sync_id,
- "status": "running",
- "message": "Sync started"
- }
+ return {"success": True, "sync_id": sync_id, "status": "running", "message": "Sync started"}
# ==================== Plugin Activity Logs ====================
+
@app.get("/api/v1/plugins/{plugin_id}/logs", tags=["Plugins"])
async def get_plugin_logs(
- plugin_id: str,
- activity_type: str | None = None,
- limit: int = 100,
- api_key: str = Depends(verify_api_key)
+ plugin_id: str, activity_type: str | None = None, limit: int = 100, api_key: str = Depends(verify_api_key)
):
"""获取插件活动日志"""
if not PLUGIN_MANAGER_AVAILABLE:
raise HTTPException(status_code=503, detail="Plugin manager not available")
manager = get_plugin_manager()
- logs = manager.get_activity_logs(
- plugin_id=plugin_id,
- activity_type=activity_type,
- limit=limit
- )
+ logs = manager.get_activity_logs(plugin_id=plugin_id, activity_type=activity_type, limit=limit)
return {
"logs": [
@@ -6580,7 +6398,7 @@ async def get_plugin_logs(
"activity_type": log.activity_type,
"source": log.source,
"details": log.details,
- "created_at": log.created_at
+ "created_at": log.created_at,
}
for log in logs
]
@@ -6589,6 +6407,7 @@ async def get_plugin_logs(
# ==================== Phase 7 Task 3: Security & Compliance API ====================
+
# Pydantic models for security API
class AuditLogResponse(BaseModel):
id: str
@@ -6704,6 +6523,7 @@ class AccessRequestResponse(BaseModel):
# ==================== Audit Logs API ====================
+
@app.get("/api/v1/audit-logs", response_model=list[AuditLogResponse], tags=["Security"])
async def get_audit_logs(
user_id: str | None = None,
@@ -6715,7 +6535,7 @@ async def get_audit_logs(
success: bool | None = None,
limit: int = 100,
offset: int = 0,
- api_key: str = Depends(verify_api_key)
+ api_key: str = Depends(verify_api_key),
):
"""查询审计日志"""
if not SECURITY_MANAGER_AVAILABLE:
@@ -6731,7 +6551,7 @@ async def get_audit_logs(
end_time=end_time,
success=success,
limit=limit,
- offset=offset
+ offset=offset,
)
return [
@@ -6745,7 +6565,7 @@ async def get_audit_logs(
action_details=log.action_details,
success=log.success,
error_message=log.error_message,
- created_at=log.created_at
+ created_at=log.created_at,
)
for log in logs
]
@@ -6753,9 +6573,7 @@ async def get_audit_logs(
@app.get("/api/v1/audit-logs/stats", response_model=AuditStatsResponse, tags=["Security"])
async def get_audit_stats(
- start_time: str | None = None,
- end_time: str | None = None,
- api_key: str = Depends(verify_api_key)
+ start_time: str | None = None, end_time: str | None = None, api_key: str = Depends(verify_api_key)
):
"""获取审计统计"""
if not SECURITY_MANAGER_AVAILABLE:
@@ -6769,11 +6587,10 @@ async def get_audit_stats(
# ==================== Encryption API ====================
+
@app.post("/api/v1/projects/{project_id}/encryption/enable", response_model=EncryptionConfigResponse, tags=["Security"])
async def enable_project_encryption(
- project_id: str,
- request: EncryptionEnableRequest,
- api_key: str = Depends(verify_api_key)
+ project_id: str, request: EncryptionEnableRequest, api_key: str = Depends(verify_api_key)
):
"""启用项目端到端加密"""
if not SECURITY_MANAGER_AVAILABLE:
@@ -6789,7 +6606,7 @@ async def enable_project_encryption(
is_enabled=config.is_enabled,
encryption_type=config.encryption_type,
created_at=config.created_at,
- updated_at=config.updated_at
+ updated_at=config.updated_at,
)
except RuntimeError as e:
raise HTTPException(status_code=400, detail=str(e))
@@ -6797,9 +6614,7 @@ async def enable_project_encryption(
@app.post("/api/v1/projects/{project_id}/encryption/disable", tags=["Security"])
async def disable_project_encryption(
- project_id: str,
- request: EncryptionEnableRequest,
- api_key: str = Depends(verify_api_key)
+ project_id: str, request: EncryptionEnableRequest, api_key: str = Depends(verify_api_key)
):
"""禁用项目加密"""
if not SECURITY_MANAGER_AVAILABLE:
@@ -6816,9 +6631,7 @@ async def disable_project_encryption(
@app.post("/api/v1/projects/{project_id}/encryption/verify", tags=["Security"])
async def verify_encryption_password(
- project_id: str,
- request: EncryptionEnableRequest,
- api_key: str = Depends(verify_api_key)
+ project_id: str, request: EncryptionEnableRequest, api_key: str = Depends(verify_api_key)
):
"""验证加密密码"""
if not SECURITY_MANAGER_AVAILABLE:
@@ -6830,12 +6643,10 @@ async def verify_encryption_password(
return {"valid": is_valid}
-@app.get("/api/v1/projects/{project_id}/encryption",
- response_model=Optional[EncryptionConfigResponse], tags=["Security"])
-async def get_encryption_config(
- project_id: str,
- api_key: str = Depends(verify_api_key)
-):
+@app.get(
+ "/api/v1/projects/{project_id}/encryption", response_model=Optional[EncryptionConfigResponse], tags=["Security"]
+)
+async def get_encryption_config(project_id: str, api_key: str = Depends(verify_api_key)):
"""获取项目加密配置"""
if not SECURITY_MANAGER_AVAILABLE:
raise HTTPException(status_code=503, detail="Security manager not available")
@@ -6852,17 +6663,16 @@ async def get_encryption_config(
is_enabled=config.is_enabled,
encryption_type=config.encryption_type,
created_at=config.created_at,
- updated_at=config.updated_at
+ updated_at=config.updated_at,
)
# ==================== Data Masking API ====================
+
@app.post("/api/v1/projects/{project_id}/masking-rules", response_model=MaskingRuleResponse, tags=["Security"])
async def create_masking_rule(
- project_id: str,
- request: MaskingRuleCreateRequest,
- api_key: str = Depends(verify_api_key)
+ project_id: str, request: MaskingRuleCreateRequest, api_key: str = Depends(verify_api_key)
):
"""创建数据脱敏规则"""
if not SECURITY_MANAGER_AVAILABLE:
@@ -6882,7 +6692,7 @@ async def create_masking_rule(
pattern=request.pattern,
replacement=request.replacement,
description=request.description,
- priority=request.priority
+ priority=request.priority,
)
return MaskingRuleResponse(
@@ -6896,16 +6706,12 @@ async def create_masking_rule(
priority=rule.priority,
description=rule.description,
created_at=rule.created_at,
- updated_at=rule.updated_at
+ updated_at=rule.updated_at,
)
@app.get("/api/v1/projects/{project_id}/masking-rules", response_model=list[MaskingRuleResponse], tags=["Security"])
-async def get_masking_rules(
- project_id: str,
- active_only: bool = True,
- api_key: str = Depends(verify_api_key)
-):
+async def get_masking_rules(project_id: str, active_only: bool = True, api_key: str = Depends(verify_api_key)):
"""获取项目脱敏规则"""
if not SECURITY_MANAGER_AVAILABLE:
raise HTTPException(status_code=503, detail="Security manager not available")
@@ -6925,7 +6731,7 @@ async def get_masking_rules(
priority=rule.priority,
description=rule.description,
created_at=rule.created_at,
- updated_at=rule.updated_at
+ updated_at=rule.updated_at,
)
for rule in rules
]
@@ -6940,7 +6746,7 @@ async def update_masking_rule(
is_active: bool | None = None,
priority: int | None = None,
description: str | None = None,
- api_key: str = Depends(verify_api_key)
+ api_key: str = Depends(verify_api_key),
):
"""更新脱敏规则"""
if not SECURITY_MANAGER_AVAILABLE:
@@ -6978,15 +6784,12 @@ async def update_masking_rule(
priority=rule.priority,
description=rule.description,
created_at=rule.created_at,
- updated_at=rule.updated_at
+ updated_at=rule.updated_at,
)
@app.delete("/api/v1/masking-rules/{rule_id}", tags=["Security"])
-async def delete_masking_rule(
- rule_id: str,
- api_key: str = Depends(verify_api_key)
-):
+async def delete_masking_rule(rule_id: str, api_key: str = Depends(verify_api_key)):
"""删除脱敏规则"""
if not SECURITY_MANAGER_AVAILABLE:
raise HTTPException(status_code=503, detail="Security manager not available")
@@ -7001,11 +6804,7 @@ async def delete_masking_rule(
@app.post("/api/v1/projects/{project_id}/masking/apply", response_model=MaskingApplyResponse, tags=["Security"])
-async def apply_masking(
- project_id: str,
- request: MaskingApplyRequest,
- api_key: str = Depends(verify_api_key)
-):
+async def apply_masking(project_id: str, request: MaskingApplyRequest, api_key: str = Depends(verify_api_key)):
"""应用脱敏规则到文本"""
if not SECURITY_MANAGER_AVAILABLE:
raise HTTPException(status_code=503, detail="Security manager not available")
@@ -7023,20 +6822,15 @@ async def apply_masking(
rules = manager.get_masking_rules(project_id)
applied_rules = [r.name for r in rules if r.is_active]
- return MaskingApplyResponse(
- original_text=request.text,
- masked_text=masked_text,
- applied_rules=applied_rules
- )
+ return MaskingApplyResponse(original_text=request.text, masked_text=masked_text, applied_rules=applied_rules)
# ==================== Data Access Policy API ====================
+
@app.post("/api/v1/projects/{project_id}/access-policies", response_model=AccessPolicyResponse, tags=["Security"])
async def create_access_policy(
- project_id: str,
- request: AccessPolicyCreateRequest,
- api_key: str = Depends(verify_api_key)
+ project_id: str, request: AccessPolicyCreateRequest, api_key: str = Depends(verify_api_key)
):
"""创建数据访问策略"""
if not SECURITY_MANAGER_AVAILABLE:
@@ -7053,7 +6847,7 @@ async def create_access_policy(
allowed_ips=request.allowed_ips,
time_restrictions=request.time_restrictions,
max_access_count=request.max_access_count,
- require_approval=request.require_approval
+ require_approval=request.require_approval,
)
return AccessPolicyResponse(
@@ -7069,16 +6863,12 @@ async def create_access_policy(
require_approval=policy.require_approval,
is_active=policy.is_active,
created_at=policy.created_at,
- updated_at=policy.updated_at
+ updated_at=policy.updated_at,
)
@app.get("/api/v1/projects/{project_id}/access-policies", response_model=list[AccessPolicyResponse], tags=["Security"])
-async def get_access_policies(
- project_id: str,
- active_only: bool = True,
- api_key: str = Depends(verify_api_key)
-):
+async def get_access_policies(project_id: str, active_only: bool = True, api_key: str = Depends(verify_api_key)):
"""获取项目访问策略"""
if not SECURITY_MANAGER_AVAILABLE:
raise HTTPException(status_code=503, detail="Security manager not available")
@@ -7100,7 +6890,7 @@ async def get_access_policies(
require_approval=policy.require_approval,
is_active=policy.is_active,
created_at=policy.created_at,
- updated_at=policy.updated_at
+ updated_at=policy.updated_at,
)
for policy in policies
]
@@ -7108,10 +6898,7 @@ async def get_access_policies(
@app.post("/api/v1/access-policies/{policy_id}/check", tags=["Security"])
async def check_access_permission(
- policy_id: str,
- user_id: str,
- user_ip: str | None = None,
- api_key: str = Depends(verify_api_key)
+ policy_id: str, user_id: str, user_ip: str | None = None, api_key: str = Depends(verify_api_key)
):
"""检查访问权限"""
if not SECURITY_MANAGER_AVAILABLE:
@@ -7120,19 +6907,17 @@ async def check_access_permission(
manager = get_security_manager()
allowed, reason = manager.check_access_permission(policy_id, user_id, user_ip)
- return {
- "allowed": allowed,
- "reason": reason if not allowed else None
- }
+ return {"allowed": allowed, "reason": reason if not allowed else None}
# ==================== Access Request API ====================
+
@app.post("/api/v1/access-requests", response_model=AccessRequestResponse, tags=["Security"])
async def create_access_request(
request: AccessRequestCreateRequest,
user_id: str, # 实际应该从认证信息中获取
- api_key: str = Depends(verify_api_key)
+ api_key: str = Depends(verify_api_key),
):
"""创建访问请求"""
if not SECURITY_MANAGER_AVAILABLE:
@@ -7144,7 +6929,7 @@ async def create_access_request(
policy_id=request.policy_id,
user_id=user_id,
request_reason=request.request_reason,
- expires_hours=request.expires_hours
+ expires_hours=request.expires_hours,
)
return AccessRequestResponse(
@@ -7156,16 +6941,13 @@ async def create_access_request(
approved_by=access_request.approved_by,
approved_at=access_request.approved_at,
expires_at=access_request.expires_at,
- created_at=access_request.created_at
+ created_at=access_request.created_at,
)
@app.post("/api/v1/access-requests/{request_id}/approve", response_model=AccessRequestResponse, tags=["Security"])
async def approve_access_request(
- request_id: str,
- approved_by: str,
- expires_hours: int = 24,
- api_key: str = Depends(verify_api_key)
+ request_id: str, approved_by: str, expires_hours: int = 24, api_key: str = Depends(verify_api_key)
):
"""批准访问请求"""
if not SECURITY_MANAGER_AVAILABLE:
@@ -7186,16 +6968,12 @@ async def approve_access_request(
approved_by=access_request.approved_by,
approved_at=access_request.approved_at,
expires_at=access_request.expires_at,
- created_at=access_request.created_at
+ created_at=access_request.created_at,
)
@app.post("/api/v1/access-requests/{request_id}/reject", response_model=AccessRequestResponse, tags=["Security"])
-async def reject_access_request(
- request_id: str,
- rejected_by: str,
- api_key: str = Depends(verify_api_key)
-):
+async def reject_access_request(request_id: str, rejected_by: str, api_key: str = Depends(verify_api_key)):
"""拒绝访问请求"""
if not SECURITY_MANAGER_AVAILABLE:
raise HTTPException(status_code=503, detail="Security manager not available")
@@ -7215,7 +6993,7 @@ async def reject_access_request(
approved_by=access_request.approved_by,
approved_at=access_request.approved_at,
expires_at=access_request.expires_at,
- created_at=access_request.created_at
+ created_at=access_request.created_at,
)
@@ -7225,6 +7003,7 @@ async def reject_access_request(
# ----- 请求模型 -----
+
class ShareLinkCreate(BaseModel):
permission: str = "read_only" # read_only, comment, edit, admin
expires_in_days: int | None = None
@@ -7268,6 +7047,7 @@ class TeamMemberRoleUpdate(BaseModel):
# ----- 项目分享 -----
+
@app.post("/api/v1/projects/{project_id}/shares")
async def create_share_link(project_id: str, request: ShareLinkCreate, created_by: str = "current_user"):
"""创建项目分享链接"""
@@ -7283,7 +7063,7 @@ async def create_share_link(project_id: str, request: ShareLinkCreate, created_b
max_uses=request.max_uses,
password=request.password,
allow_download=request.allow_download,
- allow_export=request.allow_export
+ allow_export=request.allow_export,
)
return {
@@ -7293,7 +7073,7 @@ async def create_share_link(project_id: str, request: ShareLinkCreate, created_b
"created_at": share.created_at,
"expires_at": share.expires_at,
"max_uses": share.max_uses,
- "share_url": f"/share/{share.token}"
+ "share_url": f"/share/{share.token}",
}
@@ -7319,7 +7099,7 @@ async def list_project_shares(project_id: str):
"is_active": s.is_active,
"has_password": s.password_hash is not None,
"allow_download": s.allow_download,
- "allow_export": s.allow_export
+ "allow_export": s.allow_export,
}
for s in shares
]
@@ -7346,7 +7126,7 @@ async def verify_share_link(request: ShareLinkVerify):
"project_id": share.project_id,
"permission": share.permission,
"allow_download": share.allow_download,
- "allow_export": share.allow_export
+ "allow_export": share.allow_export,
}
@@ -7380,11 +7160,11 @@ async def access_shared_project(token: str, password: str | None = None):
"id": project.id,
"name": project.name,
"description": project.description,
- "created_at": project.created_at
+ "created_at": project.created_at,
},
"permission": share.permission,
"allow_download": share.allow_download,
- "allow_export": share.allow_export
+ "allow_export": share.allow_export,
}
@@ -7402,6 +7182,7 @@ async def revoke_share_link(share_id: str, revoked_by: str = "current_user"):
return {"success": True, "message": "Share link revoked"}
+
# ----- 评论和批注 -----
@@ -7420,7 +7201,7 @@ async def add_comment(project_id: str, request: CommentCreate, author: str = "cu
author_name=author_name,
content=request.content,
parent_id=request.parent_id,
- mentions=request.mentions
+ mentions=request.mentions,
)
return {
@@ -7432,7 +7213,7 @@ async def add_comment(project_id: str, request: CommentCreate, author: str = "cu
"author_name": comment.author_name,
"content": comment.content,
"created_at": comment.created_at,
- "resolved": comment.resolved
+ "resolved": comment.resolved,
}
@@ -7458,10 +7239,10 @@ async def get_comments(target_type: str, target_id: str, include_resolved: bool
"updated_at": c.updated_at,
"resolved": c.resolved,
"resolved_by": c.resolved_by,
- "resolved_at": c.resolved_at
+ "resolved_at": c.resolved_at,
}
for c in comments
- ]
+ ],
}
@@ -7486,10 +7267,10 @@ async def get_project_comments(project_id: str, limit: int = 50, offset: int = 0
"author_name": c.author_name,
"content": c.content,
"created_at": c.created_at,
- "resolved": c.resolved
+ "resolved": c.resolved,
}
for c in comments
- ]
+ ],
}
@@ -7505,11 +7286,7 @@ async def update_comment(comment_id: str, request: CommentUpdate, updated_by: st
if not comment:
raise HTTPException(status_code=404, detail="Comment not found or not authorized")
- return {
- "id": comment.id,
- "content": comment.content,
- "updated_at": comment.updated_at
- }
+ return {"id": comment.id, "content": comment.content, "updated_at": comment.updated_at}
@app.post("/api/v1/comments/{comment_id}/resolve")
@@ -7541,16 +7318,13 @@ async def delete_comment(comment_id: str, deleted_by: str = "current_user"):
return {"success": True, "message": "Comment deleted"}
+
# ----- 变更历史 -----
@app.get("/api/v1/projects/{project_id}/history")
async def get_change_history(
- project_id: str,
- entity_type: str | None = None,
- entity_id: str | None = None,
- limit: int = 50,
- offset: int = 0
+ project_id: str, entity_type: str | None = None, entity_id: str | None = None, limit: int = 50, offset: int = 0
):
"""获取变更历史"""
if not COLLABORATION_AVAILABLE:
@@ -7574,10 +7348,10 @@ async def get_change_history(
"old_value": r.old_value,
"new_value": r.new_value,
"description": r.description,
- "reverted": r.reverted
+ "reverted": r.reverted,
}
for r in records
- ]
+ ],
}
@@ -7613,10 +7387,10 @@ async def get_entity_versions(entity_type: str, entity_id: str):
"changed_at": r.changed_at,
"old_value": r.old_value,
"new_value": r.new_value,
- "description": r.description
+ "description": r.description,
}
for r in records
- ]
+ ],
}
@@ -7634,6 +7408,7 @@ async def revert_change(record_id: str, reverted_by: str = "current_user"):
return {"success": True, "message": "Change reverted"}
+
# ----- 团队成员 -----
@@ -7650,7 +7425,7 @@ async def invite_team_member(project_id: str, request: TeamMemberInvite, invited
user_name=request.user_name,
user_email=request.user_email,
role=request.role,
- invited_by=invited_by
+ invited_by=invited_by,
)
return {
@@ -7660,7 +7435,7 @@ async def invite_team_member(project_id: str, request: TeamMemberInvite, invited
"user_email": member.user_email,
"role": member.role,
"joined_at": member.joined_at,
- "permissions": member.permissions
+ "permissions": member.permissions,
}
@@ -7684,10 +7459,10 @@ async def list_team_members(project_id: str):
"role": m.role,
"joined_at": m.joined_at,
"last_active_at": m.last_active_at,
- "permissions": m.permissions
+ "permissions": m.permissions,
}
for m in members
- ]
+ ],
}
@@ -7737,23 +7512,17 @@ async def check_project_permissions(project_id: str, user_id: str = "current_use
break
if not user_member:
- return {
- "has_access": False,
- "role": None,
- "permissions": []
- }
+ return {"has_access": False, "role": None, "permissions": []}
- return {
- "has_access": True,
- "role": user_member.role,
- "permissions": user_member.permissions
- }
+ return {"has_access": True, "role": user_member.role, "permissions": user_member.permissions}
# ==================== Phase 7 Task 6: Advanced Search & Discovery ====================
+
class FullTextSearchRequest(BaseModel):
"""全文搜索请求"""
+
query: str
content_types: list[str] | None = None
operator: str = "AND" # AND, OR, NOT
@@ -7762,6 +7531,7 @@ class FullTextSearchRequest(BaseModel):
class SemanticSearchRequest(BaseModel):
"""语义搜索请求"""
+
query: str
content_types: list[str] | None = None
threshold: float = 0.7
@@ -7769,11 +7539,7 @@ class SemanticSearchRequest(BaseModel):
@app.post("/api/v1/search/fulltext", tags=["Search"])
-async def fulltext_search(
- project_id: str,
- request: FullTextSearchRequest,
- _=Depends(verify_api_key)
-):
+async def fulltext_search(project_id: str, request: FullTextSearchRequest, _=Depends(verify_api_key)):
"""全文搜索"""
if not SEARCH_MANAGER_AVAILABLE:
raise HTTPException(status_code=500, detail="Search manager not available")
@@ -7790,30 +7556,29 @@ async def fulltext_search(
project_id=project_id,
content_types=request.content_types,
operator=operator,
- limit=request.limit
+ limit=request.limit,
)
return {
"query": request.query,
"operator": request.operator,
"total": len(results),
- "results": [{
- "id": r.id,
- "type": r.type,
- "title": r.title,
- "content": r.content,
- "highlights": r.highlights,
- "score": r.score
- } for r in results]
+ "results": [
+ {
+ "id": r.id,
+ "type": r.type,
+ "title": r.title,
+ "content": r.content,
+ "highlights": r.highlights,
+ "score": r.score,
+ }
+ for r in results
+ ],
}
@app.post("/api/v1/search/semantic", tags=["Search"])
-async def semantic_search(
- project_id: str,
- request: SemanticSearchRequest,
- _=Depends(verify_api_key)
-):
+async def semantic_search(project_id: str, request: SemanticSearchRequest, _=Depends(verify_api_key)):
"""语义搜索"""
if not SEARCH_MANAGER_AVAILABLE:
raise HTTPException(status_code=500, detail="Search manager not available")
@@ -7825,29 +7590,20 @@ async def semantic_search(
project_id=project_id,
content_types=request.content_types,
threshold=request.threshold,
- limit=request.limit
+ limit=request.limit,
)
return {
"query": request.query,
"threshold": request.threshold,
"total": len(results),
- "results": [{
- "id": r.id,
- "type": r.type,
- "text": r.text,
- "similarity": r.similarity
- } for r in results]
+ "results": [{"id": r.id, "type": r.type, "text": r.text, "similarity": r.similarity} for r in results],
}
@app.get("/api/v1/entities/{entity_id}/paths/{target_entity_id}", tags=["Search"])
async def find_entity_paths(
- entity_id: str,
- target_entity_id: str,
- max_depth: int = 5,
- find_all: bool = False,
- _=Depends(verify_api_key)
+ entity_id: str, target_entity_id: str, max_depth: int = 5, find_all: bool = False, _=Depends(verify_api_key)
):
"""查找实体关系路径"""
if not SEARCH_MANAGER_AVAILABLE:
@@ -7857,15 +7613,11 @@ async def find_entity_paths(
if find_all:
paths = search_manager.path_discovery.find_all_paths(
- source_entity_id=entity_id,
- target_entity_id=target_entity_id,
- max_depth=max_depth
+ source_entity_id=entity_id, target_entity_id=target_entity_id, max_depth=max_depth
)
else:
path = search_manager.path_discovery.find_shortest_path(
- source_entity_id=entity_id,
- target_entity_id=target_entity_id,
- max_depth=max_depth
+ source_entity_id=entity_id, target_entity_id=target_entity_id, max_depth=max_depth
)
paths = [path] if path else []
@@ -7873,22 +7625,21 @@ async def find_entity_paths(
"source_entity_id": entity_id,
"target_entity_id": target_entity_id,
"path_count": len(paths),
- "paths": [{
- "path_id": p.path_id,
- "path_length": p.path_length,
- "nodes": p.nodes,
- "edges": p.edges,
- "confidence": p.confidence
- } for p in paths]
+ "paths": [
+ {
+ "path_id": p.path_id,
+ "path_length": p.path_length,
+ "nodes": p.nodes,
+ "edges": p.edges,
+ "confidence": p.confidence,
+ }
+ for p in paths
+ ],
}
@app.get("/api/v1/entities/{entity_id}/network", tags=["Search"])
-async def get_entity_network(
- entity_id: str,
- depth: int = 2,
- _=Depends(verify_api_key)
-):
+async def get_entity_network(entity_id: str, depth: int = 2, _=Depends(verify_api_key)):
"""获取实体关系网络"""
if not SEARCH_MANAGER_AVAILABLE:
raise HTTPException(status_code=500, detail="Search manager not available")
@@ -7900,10 +7651,7 @@ async def get_entity_network(
@app.get("/api/v1/projects/{project_id}/knowledge-gaps", tags=["Search"])
-async def detect_knowledge_gaps(
- project_id: str,
- _=Depends(verify_api_key)
-):
+async def detect_knowledge_gaps(project_id: str, _=Depends(verify_api_key)):
"""检测知识缺口"""
if not SEARCH_MANAGER_AVAILABLE:
raise HTTPException(status_code=500, detail="Search manager not available")
@@ -7917,23 +7665,23 @@ async def detect_knowledge_gaps(
"project_id": project_id,
"completeness": completeness,
"gap_count": len(gaps),
- "gaps": [{
- "gap_id": g.gap_id,
- "gap_type": g.gap_type,
- "entity_id": g.entity_id,
- "entity_name": g.entity_name,
- "description": g.description,
- "severity": g.severity,
- "suggestion": g.suggestion
- } for g in gaps]
+ "gaps": [
+ {
+ "gap_id": g.gap_id,
+ "gap_type": g.gap_type,
+ "entity_id": g.entity_id,
+ "entity_name": g.entity_name,
+ "description": g.description,
+ "severity": g.severity,
+ "suggestion": g.suggestion,
+ }
+ for g in gaps
+ ],
}
@app.post("/api/v1/projects/{project_id}/search/index", tags=["Search"])
-async def index_project_for_search(
- project_id: str,
- _=Depends(verify_api_key)
-):
+async def index_project_for_search(project_id: str, _=Depends(verify_api_key)):
"""为项目创建搜索索引"""
if not SEARCH_MANAGER_AVAILABLE:
raise HTTPException(status_code=500, detail="Search manager not available")
@@ -7949,10 +7697,9 @@ async def index_project_for_search(
# ==================== Phase 7 Task 8: Performance & Scaling ====================
+
@app.get("/api/v1/cache/stats", tags=["Performance"])
-async def get_cache_stats(
- _=Depends(verify_api_key)
-):
+async def get_cache_stats(_=Depends(verify_api_key)):
"""获取缓存统计"""
if not PERFORMANCE_MANAGER_AVAILABLE:
raise HTTPException(status_code=500, detail="Performance manager not available")
@@ -7967,15 +7714,12 @@ async def get_cache_stats(
"miss_count": stats.miss_count,
"hit_rate": stats.hit_rate,
"evicted_count": stats.evicted_count,
- "expired_count": stats.expired_count
+ "expired_count": stats.expired_count,
}
@app.post("/api/v1/cache/clear", tags=["Performance"])
-async def clear_cache(
- pattern: str | None = None,
- _=Depends(verify_api_key)
-):
+async def clear_cache(pattern: str | None = None, _=Depends(verify_api_key)):
"""清除缓存"""
if not PERFORMANCE_MANAGER_AVAILABLE:
raise HTTPException(status_code=500, detail="Performance manager not available")
@@ -7995,7 +7739,7 @@ async def get_performance_metrics(
endpoint: str | None = None,
hours: int = 24,
limit: int = 1000,
- _=Depends(verify_api_key)
+ _=Depends(verify_api_key),
):
"""获取性能指标"""
if not PERFORMANCE_MANAGER_AVAILABLE:
@@ -8006,31 +7750,28 @@ async def get_performance_metrics(
start_time = (datetime.now() - timedelta(hours=hours)).isoformat()
metrics = perf_manager.monitor.get_metrics(
- metric_type=metric_type,
- endpoint=endpoint,
- start_time=start_time,
- limit=limit
+ metric_type=metric_type, endpoint=endpoint, start_time=start_time, limit=limit
)
return {
"period_hours": hours,
"total": len(metrics),
- "metrics": [{
- "id": m.id,
- "metric_type": m.metric_type,
- "endpoint": m.endpoint,
- "duration_ms": m.duration_ms,
- "status_code": m.status_code,
- "timestamp": m.timestamp
- } for m in metrics]
+ "metrics": [
+ {
+ "id": m.id,
+ "metric_type": m.metric_type,
+ "endpoint": m.endpoint,
+ "duration_ms": m.duration_ms,
+ "status_code": m.status_code,
+ "timestamp": m.timestamp,
+ }
+ for m in metrics
+ ],
}
@app.get("/api/v1/performance/summary", tags=["Performance"])
-async def get_performance_summary(
- hours: int = 24,
- _=Depends(verify_api_key)
-):
+async def get_performance_summary(hours: int = 24, _=Depends(verify_api_key)):
"""获取性能汇总统计"""
if not PERFORMANCE_MANAGER_AVAILABLE:
raise HTTPException(status_code=500, detail="Performance manager not available")
@@ -8042,10 +7783,7 @@ async def get_performance_summary(
@app.get("/api/v1/tasks/{task_id}/status", tags=["Performance"])
-async def get_task_status(
- task_id: str,
- _=Depends(verify_api_key)
-):
+async def get_task_status(task_id: str, _=Depends(verify_api_key)):
"""获取任务状态"""
if not PERFORMANCE_MANAGER_AVAILABLE:
raise HTTPException(status_code=500, detail="Performance manager not available")
@@ -8068,16 +7806,13 @@ async def get_task_status(
"started_at": task.started_at,
"completed_at": task.completed_at,
"retry_count": task.retry_count,
- "priority": task.priority
+ "priority": task.priority,
}
@app.get("/api/v1/tasks", tags=["Performance"])
async def list_tasks(
- project_id: str | None = None,
- status: str | None = None,
- limit: int = 50,
- _=Depends(verify_api_key)
+ project_id: str | None = None, status: str | None = None, limit: int = 50, _=Depends(verify_api_key)
):
"""列出任务"""
if not PERFORMANCE_MANAGER_AVAILABLE:
@@ -8088,23 +7823,23 @@ async def list_tasks(
return {
"total": len(tasks),
- "tasks": [{
- "task_id": t.task_id,
- "task_type": t.task_type,
- "status": t.status,
- "project_id": t.project_id,
- "created_at": t.created_at,
- "retry_count": t.retry_count,
- "priority": t.priority
- } for t in tasks]
+ "tasks": [
+ {
+ "task_id": t.task_id,
+ "task_type": t.task_type,
+ "status": t.status,
+ "project_id": t.project_id,
+ "created_at": t.created_at,
+ "retry_count": t.retry_count,
+ "priority": t.priority,
+ }
+ for t in tasks
+ ],
}
@app.post("/api/v1/tasks/{task_id}/cancel", tags=["Performance"])
-async def cancel_task(
- task_id: str,
- _=Depends(verify_api_key)
-):
+async def cancel_task(task_id: str, _=Depends(verify_api_key)):
"""取消任务"""
if not PERFORMANCE_MANAGER_AVAILABLE:
raise HTTPException(status_code=500, detail="Performance manager not available")
@@ -8119,9 +7854,7 @@ async def cancel_task(
@app.get("/api/v1/shards", tags=["Performance"])
-async def list_shards(
- _=Depends(verify_api_key)
-):
+async def list_shards(_=Depends(verify_api_key)):
"""列出数据库分片"""
if not PERFORMANCE_MANAGER_AVAILABLE:
raise HTTPException(status_code=500, detail="Performance manager not available")
@@ -8131,12 +7864,10 @@ async def list_shards(
return {
"shard_count": len(shards),
- "shards": [{
- "shard_id": s.shard_id,
- "entity_count": s.entity_count,
- "db_path": s.db_path,
- "created_at": s.created_at
- } for s in shards]
+ "shards": [
+ {"shard_id": s.shard_id, "entity_count": s.entity_count, "db_path": s.db_path, "created_at": s.created_at}
+ for s in shards
+ ],
}
@@ -8144,6 +7875,7 @@ async def list_shards(
# Phase 8: Multi-Tenant SaaS APIs
# ============================================
+
class CreateTenantRequest(BaseModel):
name: str
description: str | None = None
@@ -8184,9 +7916,7 @@ class UpdateMemberRequest(BaseModel):
# Tenant Management APIs
@app.post("/api/v1/tenants", tags=["Tenants"])
async def create_tenant(
- request: CreateTenantRequest,
- user_id: str = Header(..., description="当前用户ID"),
- _=Depends(verify_api_key)
+ request: CreateTenantRequest, user_id: str = Header(..., description="当前用户ID"), _=Depends(verify_api_key)
):
"""创建新租户"""
if not TENANT_MANAGER_AVAILABLE:
@@ -8195,10 +7925,7 @@ async def create_tenant(
manager = get_tenant_manager()
try:
tenant = manager.create_tenant(
- name=request.name,
- owner_id=user_id,
- tier=request.tier,
- description=request.description
+ name=request.name, owner_id=user_id, tier=request.tier, description=request.description
)
return {
"id": tenant.id,
@@ -8206,17 +7933,14 @@ async def create_tenant(
"slug": tenant.slug,
"tier": tenant.tier,
"status": tenant.status,
- "created_at": tenant.created_at.isoformat()
+ "created_at": tenant.created_at.isoformat(),
}
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
@app.get("/api/v1/tenants", tags=["Tenants"])
-async def list_my_tenants(
- user_id: str = Header(..., description="当前用户ID"),
- _=Depends(verify_api_key)
-):
+async def list_my_tenants(user_id: str = Header(..., description="当前用户ID"), _=Depends(verify_api_key)):
"""获取当前用户的所有租户"""
if not TENANT_MANAGER_AVAILABLE:
raise HTTPException(status_code=500, detail="Tenant manager not available")
@@ -8227,10 +7951,7 @@ async def list_my_tenants(
@app.get("/api/v1/tenants/{tenant_id}", tags=["Tenants"])
-async def get_tenant(
- tenant_id: str,
- _=Depends(verify_api_key)
-):
+async def get_tenant(tenant_id: str, _=Depends(verify_api_key)):
"""获取租户详情"""
if not TENANT_MANAGER_AVAILABLE:
raise HTTPException(status_code=500, detail="Tenant manager not available")
@@ -8251,16 +7972,12 @@ async def get_tenant(
"owner_id": tenant.owner_id,
"created_at": tenant.created_at.isoformat(),
"settings": tenant.settings,
- "resource_limits": tenant.resource_limits
+ "resource_limits": tenant.resource_limits,
}
@app.put("/api/v1/tenants/{tenant_id}", tags=["Tenants"])
-async def update_tenant(
- tenant_id: str,
- request: UpdateTenantRequest,
- _=Depends(verify_api_key)
-):
+async def update_tenant(tenant_id: str, request: UpdateTenantRequest, _=Depends(verify_api_key)):
"""更新租户信息"""
if not TENANT_MANAGER_AVAILABLE:
raise HTTPException(status_code=500, detail="Tenant manager not available")
@@ -8271,7 +7988,7 @@ async def update_tenant(
name=request.name,
description=request.description,
tier=request.tier,
- status=request.status
+ status=request.status,
)
if not tenant:
@@ -8283,15 +8000,12 @@ async def update_tenant(
"slug": tenant.slug,
"tier": tenant.tier,
"status": tenant.status,
- "updated_at": tenant.updated_at.isoformat()
+ "updated_at": tenant.updated_at.isoformat(),
}
@app.delete("/api/v1/tenants/{tenant_id}", tags=["Tenants"])
-async def delete_tenant(
- tenant_id: str,
- _=Depends(verify_api_key)
-):
+async def delete_tenant(tenant_id: str, _=Depends(verify_api_key)):
"""删除租户"""
if not TENANT_MANAGER_AVAILABLE:
raise HTTPException(status_code=500, detail="Tenant manager not available")
@@ -8307,22 +8021,14 @@ async def delete_tenant(
# Domain Management APIs
@app.post("/api/v1/tenants/{tenant_id}/domains", tags=["Tenants"])
-async def add_domain(
- tenant_id: str,
- request: AddDomainRequest,
- _=Depends(verify_api_key)
-):
+async def add_domain(tenant_id: str, request: AddDomainRequest, _=Depends(verify_api_key)):
"""为租户添加自定义域名"""
if not TENANT_MANAGER_AVAILABLE:
raise HTTPException(status_code=500, detail="Tenant manager not available")
manager = get_tenant_manager()
try:
- domain = manager.add_domain(
- tenant_id=tenant_id,
- domain=request.domain,
- is_primary=request.is_primary
- )
+ domain = manager.add_domain(tenant_id=tenant_id, domain=request.domain, is_primary=request.is_primary)
# 获取验证指导
instructions = manager.get_domain_verification_instructions(domain.id)
@@ -8334,17 +8040,14 @@ async def add_domain(
"is_primary": domain.is_primary,
"verification_token": domain.verification_token,
"verification_instructions": instructions,
- "created_at": domain.created_at.isoformat()
+ "created_at": domain.created_at.isoformat(),
}
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
@app.get("/api/v1/tenants/{tenant_id}/domains", tags=["Tenants"])
-async def list_domains(
- tenant_id: str,
- _=Depends(verify_api_key)
-):
+async def list_domains(tenant_id: str, _=Depends(verify_api_key)):
"""列出租户的所有域名"""
if not TENANT_MANAGER_AVAILABLE:
raise HTTPException(status_code=500, detail="Tenant manager not available")
@@ -8353,24 +8056,23 @@ async def list_domains(
domains = manager.list_domains(tenant_id)
return {
- "domains": [{
- "id": d.id,
- "domain": d.domain,
- "status": d.status,
- "is_primary": d.is_primary,
- "ssl_enabled": d.ssl_enabled,
- "verified_at": d.verified_at.isoformat() if d.verified_at else None,
- "created_at": d.created_at.isoformat()
- } for d in domains]
+ "domains": [
+ {
+ "id": d.id,
+ "domain": d.domain,
+ "status": d.status,
+ "is_primary": d.is_primary,
+ "ssl_enabled": d.ssl_enabled,
+ "verified_at": d.verified_at.isoformat() if d.verified_at else None,
+ "created_at": d.created_at.isoformat(),
+ }
+ for d in domains
+ ]
}
@app.post("/api/v1/tenants/{tenant_id}/domains/{domain_id}/verify", tags=["Tenants"])
-async def verify_domain(
- tenant_id: str,
- domain_id: str,
- _=Depends(verify_api_key)
-):
+async def verify_domain(tenant_id: str, domain_id: str, _=Depends(verify_api_key)):
"""验证域名所有权"""
if not TENANT_MANAGER_AVAILABLE:
raise HTTPException(status_code=500, detail="Tenant manager not available")
@@ -8378,18 +8080,11 @@ async def verify_domain(
manager = get_tenant_manager()
success = manager.verify_domain(tenant_id, domain_id)
- return {
- "success": success,
- "message": "Domain verified successfully" if success else "Domain verification failed"
- }
+ return {"success": success, "message": "Domain verified successfully" if success else "Domain verification failed"}
@app.delete("/api/v1/tenants/{tenant_id}/domains/{domain_id}", tags=["Tenants"])
-async def remove_domain(
- tenant_id: str,
- domain_id: str,
- _=Depends(verify_api_key)
-):
+async def remove_domain(tenant_id: str, domain_id: str, _=Depends(verify_api_key)):
"""移除域名绑定"""
if not TENANT_MANAGER_AVAILABLE:
raise HTTPException(status_code=500, detail="Tenant manager not available")
@@ -8405,10 +8100,7 @@ async def remove_domain(
# Branding APIs
@app.get("/api/v1/tenants/{tenant_id}/branding", tags=["Tenants"])
-async def get_branding(
- tenant_id: str,
- _=Depends(verify_api_key)
-):
+async def get_branding(tenant_id: str, _=Depends(verify_api_key)):
"""获取租户品牌配置"""
if not TENANT_MANAGER_AVAILABLE:
raise HTTPException(status_code=500, detail="Tenant manager not available")
@@ -8423,7 +8115,7 @@ async def get_branding(
"favicon_url": None,
"primary_color": None,
"secondary_color": None,
- "custom_css": None
+ "custom_css": None,
}
return {
@@ -8434,16 +8126,12 @@ async def get_branding(
"secondary_color": branding.secondary_color,
"custom_css": branding.custom_css,
"custom_js": branding.custom_js,
- "login_page_bg": branding.login_page_bg
+ "login_page_bg": branding.login_page_bg,
}
@app.put("/api/v1/tenants/{tenant_id}/branding", tags=["Tenants"])
-async def update_branding(
- tenant_id: str,
- request: UpdateBrandingRequest,
- _=Depends(verify_api_key)
-):
+async def update_branding(tenant_id: str, request: UpdateBrandingRequest, _=Depends(verify_api_key)):
"""更新租户品牌配置"""
if not TENANT_MANAGER_AVAILABLE:
raise HTTPException(status_code=500, detail="Tenant manager not available")
@@ -8457,7 +8145,7 @@ async def update_branding(
secondary_color=request.secondary_color,
custom_css=request.custom_css,
custom_js=request.custom_js,
- login_page_bg=request.login_page_bg
+ login_page_bg=request.login_page_bg,
)
return {
@@ -8466,7 +8154,7 @@ async def update_branding(
"favicon_url": branding.favicon_url,
"primary_color": branding.primary_color,
"secondary_color": branding.secondary_color,
- "updated_at": branding.updated_at.isoformat()
+ "updated_at": branding.updated_at.isoformat(),
}
@@ -8480,6 +8168,7 @@ async def get_branding_css(tenant_id: str):
css = manager.get_branding_css(tenant_id)
from fastapi.responses import PlainTextResponse
+
return PlainTextResponse(content=css, media_type="text/css")
@@ -8489,7 +8178,7 @@ async def invite_member(
tenant_id: str,
request: InviteMemberRequest,
user_id: str = Header(..., description="邀请者用户ID"),
- _=Depends(verify_api_key)
+ _=Depends(verify_api_key),
):
"""邀请成员加入租户"""
if not TENANT_MANAGER_AVAILABLE:
@@ -8497,30 +8186,21 @@ async def invite_member(
manager = get_tenant_manager()
try:
- member = manager.invite_member(
- tenant_id=tenant_id,
- email=request.email,
- role=request.role,
- invited_by=user_id
- )
+ member = manager.invite_member(tenant_id=tenant_id, email=request.email, role=request.role, invited_by=user_id)
return {
"id": member.id,
"email": member.email,
"role": member.role,
"status": member.status,
- "invited_at": member.invited_at.isoformat()
+ "invited_at": member.invited_at.isoformat(),
}
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
@app.get("/api/v1/tenants/{tenant_id}/members", tags=["Tenants"])
-async def list_members(
- tenant_id: str,
- status: str | None = None,
- _=Depends(verify_api_key)
-):
+async def list_members(tenant_id: str, status: str | None = None, _=Depends(verify_api_key)):
"""列出租户成员"""
if not TENANT_MANAGER_AVAILABLE:
raise HTTPException(status_code=500, detail="Tenant manager not available")
@@ -8529,27 +8209,25 @@ async def list_members(
members = manager.list_members(tenant_id, status)
return {
- "members": [{
- "id": m.id,
- "user_id": m.user_id,
- "email": m.email,
- "role": m.role,
- "status": m.status,
- "permissions": m.permissions,
- "invited_at": m.invited_at.isoformat(),
- "joined_at": m.joined_at.isoformat() if m.joined_at else None,
- "last_active_at": m.last_active_at.isoformat() if m.last_active_at else None
- } for m in members]
+ "members": [
+ {
+ "id": m.id,
+ "user_id": m.user_id,
+ "email": m.email,
+ "role": m.role,
+ "status": m.status,
+ "permissions": m.permissions,
+ "invited_at": m.invited_at.isoformat(),
+ "joined_at": m.joined_at.isoformat() if m.joined_at else None,
+ "last_active_at": m.last_active_at.isoformat() if m.last_active_at else None,
+ }
+ for m in members
+ ]
}
@app.put("/api/v1/tenants/{tenant_id}/members/{member_id}", tags=["Tenants"])
-async def update_member(
- tenant_id: str,
- member_id: str,
- request: UpdateMemberRequest,
- _=Depends(verify_api_key)
-):
+async def update_member(tenant_id: str, member_id: str, request: UpdateMemberRequest, _=Depends(verify_api_key)):
"""更新成员角色"""
if not TENANT_MANAGER_AVAILABLE:
raise HTTPException(status_code=500, detail="Tenant manager not available")
@@ -8564,11 +8242,7 @@ async def update_member(
@app.delete("/api/v1/tenants/{tenant_id}/members/{member_id}", tags=["Tenants"])
-async def remove_member(
- tenant_id: str,
- member_id: str,
- _=Depends(verify_api_key)
-):
+async def remove_member(tenant_id: str, member_id: str, _=Depends(verify_api_key)):
"""移除成员"""
if not TENANT_MANAGER_AVAILABLE:
raise HTTPException(status_code=500, detail="Tenant manager not available")
@@ -8584,10 +8258,7 @@ async def remove_member(
# Usage & Limits APIs
@app.get("/api/v1/tenants/{tenant_id}/usage", tags=["Tenants"])
-async def get_tenant_usage(
- tenant_id: str,
- _=Depends(verify_api_key)
-):
+async def get_tenant_usage(tenant_id: str, _=Depends(verify_api_key)):
"""获取租户资源使用统计"""
if not TENANT_MANAGER_AVAILABLE:
raise HTTPException(status_code=500, detail="Tenant manager not available")
@@ -8599,11 +8270,7 @@ async def get_tenant_usage(
@app.get("/api/v1/tenants/{tenant_id}/limits/{resource_type}", tags=["Tenants"])
-async def check_resource_limit(
- tenant_id: str,
- resource_type: str,
- _=Depends(verify_api_key)
-):
+async def check_resource_limit(tenant_id: str, resource_type: str, _=Depends(verify_api_key)):
"""检查特定资源是否超限"""
if not TENANT_MANAGER_AVAILABLE:
raise HTTPException(status_code=500, detail="Tenant manager not available")
@@ -8616,7 +8283,7 @@ async def check_resource_limit(
"allowed": allowed,
"current": current,
"limit": limit,
- "usage_percentage": round(current / limit * 100, 2) if limit > 0 else 0
+ "usage_percentage": round(current / limit * 100, 2) if limit > 0 else 0,
}
@@ -8643,19 +8310,15 @@ async def resolve_tenant_by_domain(domain: str):
"branding": {
"logo_url": branding.logo_url if branding else None,
"primary_color": branding.primary_color if branding else None,
- "favicon_url": branding.favicon_url if branding else None
- }
+ "favicon_url": branding.favicon_url if branding else None,
+ },
}
@app.get("/api/v1/health", tags=["System"])
async def detailed_health_check():
"""健康检查"""
- health = {
- "status": "healthy",
- "timestamp": datetime.now().isoformat(),
- "components": {}
- }
+ health = {"status": "healthy", "timestamp": datetime.now().isoformat(), "components": {}}
# 数据库检查
if DB_AVAILABLE:
@@ -8700,6 +8363,7 @@ async def detailed_health_check():
# ==================== Phase 8: Multi-Tenant SaaS API ====================
+
# Pydantic Models for Tenant API
class TenantCreate(BaseModel):
name: str = Field(..., description="租户名称")
@@ -8829,7 +8493,7 @@ async def create_tenant_endpoint(tenant: TenantCreate, request: Request, _=Depen
# 获取当前用户ID(从请求状态或API Key)
user_id = ""
- if hasattr(request.state, 'api_key') and request.state.api_key:
+ if hasattr(request.state, "api_key") and request.state.api_key:
user_id = request.state.api_key.created_by or ""
try:
@@ -8839,7 +8503,7 @@ async def create_tenant_endpoint(tenant: TenantCreate, request: Request, _=Depen
created_by=user_id,
description=tenant.description,
plan=TenantTier(tenant.plan),
- billing_email=tenant.billing_email
+ billing_email=tenant.billing_email,
)
return new_tenant.to_dict()
except ValueError as e:
@@ -8848,11 +8512,7 @@ async def create_tenant_endpoint(tenant: TenantCreate, request: Request, _=Depen
@app.get("/api/v1/tenants", response_model=list[TenantResponse], tags=["Tenants"])
async def list_tenants_endpoint(
- status: str | None = None,
- plan: str | None = None,
- limit: int = 100,
- offset: int = 0,
- _=Depends(verify_api_key)
+ status: str | None = None, plan: str | None = None, limit: int = 100, offset: int = 0, _=Depends(verify_api_key)
):
"""列出租户"""
if not TENANT_MANAGER_AVAILABLE:
@@ -8863,12 +8523,7 @@ async def list_tenants_endpoint(
status_enum = TenantStatus(status) if status else None
plan_enum = TenantTier(plan) if plan else None
- tenants = tenant_manager.list_tenants(
- status=status_enum,
- plan=plan_enum,
- limit=limit,
- offset=offset
- )
+ tenants = tenant_manager.list_tenants(status=status_enum, plan=plan_enum, limit=limit, offset=offset)
return [t.to_dict() for t in tenants]
@@ -9031,11 +8686,7 @@ async def get_tenant_branding_endpoint(tenant_id: str, _=Depends(verify_api_key)
@app.put("/api/v1/tenants/{tenant_id}/branding", tags=["Tenants"])
-async def update_tenant_branding_endpoint(
- tenant_id: str,
- branding: TenantBrandingUpdate,
- _=Depends(verify_api_key)
-):
+async def update_tenant_branding_endpoint(tenant_id: str, branding: TenantBrandingUpdate, _=Depends(verify_api_key)):
"""更新租户品牌配置"""
if not TENANT_MANAGER_AVAILABLE:
raise HTTPException(status_code=500, detail="Tenant manager not available")
@@ -9064,17 +8715,13 @@ async def get_tenant_theme_css_endpoint(tenant_id: str):
if not branding:
raise HTTPException(status_code=404, detail="Branding not found")
- from fastapi.responses import PlainTextResponse
return PlainTextResponse(content=branding.get_theme_css(), media_type="text/css")
# Tenant Member API
@app.post("/api/v1/tenants/{tenant_id}/members/invite", response_model=TenantMemberResponse, tags=["Tenants"])
async def invite_tenant_member_endpoint(
- tenant_id: str,
- invite: TenantMemberInvite,
- request: Request,
- _=Depends(verify_api_key)
+ tenant_id: str, invite: TenantMemberInvite, request: Request, _=Depends(verify_api_key)
):
"""邀请成员加入租户"""
if not TENANT_MANAGER_AVAILABLE:
@@ -9084,7 +8731,7 @@ async def invite_tenant_member_endpoint(
# 获取当前用户ID
invited_by = ""
- if hasattr(request.state, 'api_key') and request.state.api_key:
+ if hasattr(request.state, "api_key") and request.state.api_key:
invited_by = request.state.api_key.created_by or ""
try:
@@ -9093,7 +8740,7 @@ async def invite_tenant_member_endpoint(
email=invite.email,
role=TenantRole(invite.role),
invited_by=invited_by,
- name=invite.name
+ name=invite.name,
)
return member.to_dict()
except ValueError as e:
@@ -9117,10 +8764,7 @@ async def accept_invitation_endpoint(token: str, user_id: str):
@app.get("/api/v1/tenants/{tenant_id}/members", response_model=list[TenantMemberResponse], tags=["Tenants"])
async def list_tenant_members_endpoint(
- tenant_id: str,
- status: str | None = None,
- role: str | None = None,
- _=Depends(verify_api_key)
+ tenant_id: str, status: str | None = None, role: str | None = None, _=Depends(verify_api_key)
):
"""列出租户成员"""
if not TENANT_MANAGER_AVAILABLE:
@@ -9137,11 +8781,7 @@ async def list_tenant_members_endpoint(
@app.put("/api/v1/tenants/{tenant_id}/members/{member_id}/role", tags=["Tenants"])
async def update_member_role_endpoint(
- tenant_id: str,
- member_id: str,
- role: str,
- request: Request,
- _=Depends(verify_api_key)
+ tenant_id: str, member_id: str, role: str, request: Request, _=Depends(verify_api_key)
):
"""更新成员角色"""
if not TENANT_MANAGER_AVAILABLE:
@@ -9151,15 +8791,12 @@ async def update_member_role_endpoint(
# 获取当前用户ID
updated_by = ""
- if hasattr(request.state, 'api_key') and request.state.api_key:
+ if hasattr(request.state, "api_key") and request.state.api_key:
updated_by = request.state.api_key.created_by or ""
try:
updated = tenant_manager.update_member_role(
- tenant_id=tenant_id,
- member_id=member_id,
- new_role=TenantRole(role),
- updated_by=updated_by
+ tenant_id=tenant_id, member_id=member_id, new_role=TenantRole(role), updated_by=updated_by
)
if not updated:
raise HTTPException(status_code=404, detail="Member not found")
@@ -9169,12 +8806,7 @@ async def update_member_role_endpoint(
@app.delete("/api/v1/tenants/{tenant_id}/members/{member_id}", tags=["Tenants"])
-async def remove_tenant_member_endpoint(
- tenant_id: str,
- member_id: str,
- request: Request,
- _=Depends(verify_api_key)
-):
+async def remove_tenant_member_endpoint(tenant_id: str, member_id: str, request: Request, _=Depends(verify_api_key)):
"""移除租户成员"""
if not TENANT_MANAGER_AVAILABLE:
raise HTTPException(status_code=500, detail="Tenant manager not available")
@@ -9183,7 +8815,7 @@ async def remove_tenant_member_endpoint(
# 获取当前用户ID
removed_by = ""
- if hasattr(request.state, 'api_key') and request.state.api_key:
+ if hasattr(request.state, "api_key") and request.state.api_key:
removed_by = request.state.api_key.created_by or ""
try:
@@ -9208,11 +8840,7 @@ async def list_tenant_roles_endpoint(tenant_id: str, _=Depends(verify_api_key)):
@app.post("/api/v1/tenants/{tenant_id}/roles", response_model=TenantRoleResponse, tags=["Tenants"])
-async def create_tenant_role_endpoint(
- tenant_id: str,
- role: TenantRoleCreate,
- _=Depends(verify_api_key)
-):
+async def create_tenant_role_endpoint(tenant_id: str, role: TenantRoleCreate, _=Depends(verify_api_key)):
"""创建自定义角色"""
if not TENANT_MANAGER_AVAILABLE:
raise HTTPException(status_code=500, detail="Tenant manager not available")
@@ -9221,10 +8849,7 @@ async def create_tenant_role_endpoint(
try:
new_role = tenant_manager.create_custom_role(
- tenant_id=tenant_id,
- name=role.name,
- description=role.description,
- permissions=role.permissions
+ tenant_id=tenant_id, name=role.name, description=role.description, permissions=role.permissions
)
return new_role.to_dict()
except ValueError as e:
@@ -9233,10 +8858,7 @@ async def create_tenant_role_endpoint(
@app.put("/api/v1/tenants/{tenant_id}/roles/{role_id}/permissions", tags=["Tenants"])
async def update_role_permissions_endpoint(
- tenant_id: str,
- role_id: str,
- permissions: list[str],
- _=Depends(verify_api_key)
+ tenant_id: str, role_id: str, permissions: list[str], _=Depends(verify_api_key)
):
"""更新角色权限"""
if not TENANT_MANAGER_AVAILABLE:
@@ -9277,32 +8899,20 @@ async def list_tenant_permissions_endpoint(_=Depends(verify_api_key)):
raise HTTPException(status_code=500, detail="Tenant manager not available")
tenant_manager = get_tenant_manager()
- return {
- "permissions": [
- {"id": k, "name": v}
- for k, v in tenant_manager.PERMISSION_NAMES.items()
- ]
- }
+ return {"permissions": [{"id": k, "name": v} for k, v in tenant_manager.PERMISSION_NAMES.items()]}
# Tenant Resolution API
@app.get("/api/v1/tenants/resolve", tags=["Tenants"])
async def resolve_tenant_endpoint(
- host: str | None = None,
- slug: str | None = None,
- tenant_id: str | None = None,
- _=Depends(verify_api_key)
+ host: str | None = None, slug: str | None = None, tenant_id: str | None = None, _=Depends(verify_api_key)
):
"""从请求信息解析租户"""
if not TENANT_MANAGER_AVAILABLE:
raise HTTPException(status_code=500, detail="Tenant manager not available")
tenant_manager = get_tenant_manager()
- tenant = tenant_manager.resolve_tenant_from_request(
- host=host,
- slug=slug,
- tenant_id=tenant_id
- )
+ tenant = tenant_manager.resolve_tenant_from_request(host=host, slug=slug, tenant_id=tenant_id)
if not tenant:
raise HTTPException(status_code=404, detail="Tenant not found")
@@ -9329,6 +8939,7 @@ async def get_tenant_context_endpoint(tenant_id: str, _=Depends(verify_api_key))
# Phase 8 Task 2: Subscription & Billing APIs
# ============================================
+
# Pydantic Models for Subscription API
class CreateSubscriptionRequest(BaseModel):
plan_id: str = Field(..., description="订阅计划ID")
@@ -9381,8 +8992,7 @@ class CreateCheckoutSessionRequest(BaseModel):
# Subscription Plan APIs
@app.get("/api/v1/subscription-plans", tags=["Subscriptions"])
async def list_subscription_plans(
- include_inactive: bool = Query(default=False, description="包含已停用计划"),
- _=Depends(verify_api_key)
+ include_inactive: bool = Query(default=False, description="包含已停用计划"), _=Depends(verify_api_key)
):
"""获取所有订阅计划"""
if not SUBSCRIPTION_MANAGER_AVAILABLE:
@@ -9403,7 +9013,7 @@ async def list_subscription_plans(
"currency": p.currency,
"features": p.features,
"limits": p.limits,
- "is_active": p.is_active
+ "is_active": p.is_active,
}
for p in plans
]
@@ -9411,10 +9021,7 @@ async def list_subscription_plans(
@app.get("/api/v1/subscription-plans/{plan_id}", tags=["Subscriptions"])
-async def get_subscription_plan(
- plan_id: str,
- _=Depends(verify_api_key)
-):
+async def get_subscription_plan(plan_id: str, _=Depends(verify_api_key)):
"""获取订阅计划详情"""
if not SUBSCRIPTION_MANAGER_AVAILABLE:
raise HTTPException(status_code=500, detail="Subscription manager not available")
@@ -9436,7 +9043,7 @@ async def get_subscription_plan(
"features": plan.features,
"limits": plan.limits,
"is_active": plan.is_active,
- "created_at": plan.created_at.isoformat()
+ "created_at": plan.created_at.isoformat(),
}
@@ -9446,7 +9053,7 @@ async def create_subscription(
tenant_id: str,
request: CreateSubscriptionRequest,
user_id: str = Header(..., description="当前用户ID"),
- _=Depends(verify_api_key)
+ _=Depends(verify_api_key),
):
"""创建新订阅"""
if not SUBSCRIPTION_MANAGER_AVAILABLE:
@@ -9459,7 +9066,7 @@ async def create_subscription(
plan_id=request.plan_id,
payment_provider=request.payment_provider,
trial_days=request.trial_days,
- billing_cycle=request.billing_cycle
+ billing_cycle=request.billing_cycle,
)
return {
@@ -9471,17 +9078,14 @@ async def create_subscription(
"current_period_end": subscription.current_period_end.isoformat(),
"trial_start": subscription.trial_start.isoformat() if subscription.trial_start else None,
"trial_end": subscription.trial_end.isoformat() if subscription.trial_end else None,
- "created_at": subscription.created_at.isoformat()
+ "created_at": subscription.created_at.isoformat(),
}
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
@app.get("/api/v1/tenants/{tenant_id}/subscription", tags=["Subscriptions"])
-async def get_tenant_subscription(
- tenant_id: str,
- _=Depends(verify_api_key)
-):
+async def get_tenant_subscription(tenant_id: str, _=Depends(verify_api_key)):
"""获取租户当前订阅"""
if not SUBSCRIPTION_MANAGER_AVAILABLE:
raise HTTPException(status_code=500, detail="Subscription manager not available")
@@ -9508,17 +9112,13 @@ async def get_tenant_subscription(
"canceled_at": subscription.canceled_at.isoformat() if subscription.canceled_at else None,
"trial_start": subscription.trial_start.isoformat() if subscription.trial_start else None,
"trial_end": subscription.trial_end.isoformat() if subscription.trial_end else None,
- "created_at": subscription.created_at.isoformat()
+ "created_at": subscription.created_at.isoformat(),
}
}
@app.put("/api/v1/tenants/{tenant_id}/subscription/change-plan", tags=["Subscriptions"])
-async def change_subscription_plan(
- tenant_id: str,
- request: ChangePlanRequest,
- _=Depends(verify_api_key)
-):
+async def change_subscription_plan(tenant_id: str, request: ChangePlanRequest, _=Depends(verify_api_key)):
"""更改订阅计划"""
if not SUBSCRIPTION_MANAGER_AVAILABLE:
raise HTTPException(status_code=500, detail="Subscription manager not available")
@@ -9531,27 +9131,21 @@ async def change_subscription_plan(
try:
updated = manager.change_plan(
- subscription_id=subscription.id,
- new_plan_id=request.new_plan_id,
- prorate=request.prorate
+ subscription_id=subscription.id, new_plan_id=request.new_plan_id, prorate=request.prorate
)
return {
"id": updated.id,
"plan_id": updated.plan_id,
"status": updated.status,
- "message": "Plan changed successfully"
+ "message": "Plan changed successfully",
}
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
@app.post("/api/v1/tenants/{tenant_id}/subscription/cancel", tags=["Subscriptions"])
-async def cancel_subscription(
- tenant_id: str,
- request: CancelSubscriptionRequest,
- _=Depends(verify_api_key)
-):
+async def cancel_subscription(tenant_id: str, request: CancelSubscriptionRequest, _=Depends(verify_api_key)):
"""取消订阅"""
if not SUBSCRIPTION_MANAGER_AVAILABLE:
raise HTTPException(status_code=500, detail="Subscription manager not available")
@@ -9563,17 +9157,14 @@ async def cancel_subscription(
raise HTTPException(status_code=404, detail="No active subscription found")
try:
- updated = manager.cancel_subscription(
- subscription_id=subscription.id,
- at_period_end=request.at_period_end
- )
+ updated = manager.cancel_subscription(subscription_id=subscription.id, at_period_end=request.at_period_end)
return {
"id": updated.id,
"status": updated.status,
"cancel_at_period_end": updated.cancel_at_period_end,
"canceled_at": updated.canceled_at.isoformat() if updated.canceled_at else None,
- "message": "Subscription cancelled"
+ "message": "Subscription cancelled",
}
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
@@ -9581,11 +9172,7 @@ async def cancel_subscription(
# Usage APIs
@app.post("/api/v1/tenants/{tenant_id}/usage", tags=["Subscriptions"])
-async def record_usage(
- tenant_id: str,
- request: RecordUsageRequest,
- _=Depends(verify_api_key)
-):
+async def record_usage(tenant_id: str, request: RecordUsageRequest, _=Depends(verify_api_key)):
"""记录用量"""
if not SUBSCRIPTION_MANAGER_AVAILABLE:
raise HTTPException(status_code=500, detail="Subscription manager not available")
@@ -9596,7 +9183,7 @@ async def record_usage(
resource_type=request.resource_type,
quantity=request.quantity,
unit=request.unit,
- description=request.description
+ description=request.description,
)
return {
@@ -9606,7 +9193,7 @@ async def record_usage(
"quantity": record.quantity,
"unit": record.unit,
"cost": record.cost,
- "recorded_at": record.recorded_at.isoformat()
+ "recorded_at": record.recorded_at.isoformat(),
}
@@ -9615,7 +9202,7 @@ async def get_usage_summary(
tenant_id: str,
start_date: str | None = Query(default=None, description="开始日期 (ISO格式)"),
end_date: str | None = Query(default=None, description="结束日期 (ISO格式)"),
- _=Depends(verify_api_key)
+ _=Depends(verify_api_key),
):
"""获取用量汇总"""
if not SUBSCRIPTION_MANAGER_AVAILABLE:
@@ -9638,7 +9225,7 @@ async def list_payments(
status: str | None = Query(default=None, description="支付状态过滤"),
limit: int = Query(default=100, description="返回数量限制"),
offset: int = Query(default=0, description="偏移量"),
- _=Depends(verify_api_key)
+ _=Depends(verify_api_key),
):
"""获取支付记录列表"""
if not SUBSCRIPTION_MANAGER_AVAILABLE:
@@ -9658,20 +9245,16 @@ async def list_payments(
"payment_method": p.payment_method,
"paid_at": p.paid_at.isoformat() if p.paid_at else None,
"failed_at": p.failed_at.isoformat() if p.failed_at else None,
- "created_at": p.created_at.isoformat()
+ "created_at": p.created_at.isoformat(),
}
for p in payments
],
- "total": len(payments)
+ "total": len(payments),
}
@app.get("/api/v1/tenants/{tenant_id}/payments/{payment_id}", tags=["Subscriptions"])
-async def get_payment(
- tenant_id: str,
- payment_id: str,
- _=Depends(verify_api_key)
-):
+async def get_payment(tenant_id: str, payment_id: str, _=Depends(verify_api_key)):
"""获取支付记录详情"""
if not SUBSCRIPTION_MANAGER_AVAILABLE:
raise HTTPException(status_code=500, detail="Subscription manager not available")
@@ -9696,7 +9279,7 @@ async def get_payment(
"paid_at": payment.paid_at.isoformat() if payment.paid_at else None,
"failed_at": payment.failed_at.isoformat() if payment.failed_at else None,
"failure_reason": payment.failure_reason,
- "created_at": payment.created_at.isoformat()
+ "created_at": payment.created_at.isoformat(),
}
@@ -9707,7 +9290,7 @@ async def list_invoices(
status: str | None = Query(default=None, description="发票状态过滤"),
limit: int = Query(default=100, description="返回数量限制"),
offset: int = Query(default=0, description="偏移量"),
- _=Depends(verify_api_key)
+ _=Depends(verify_api_key),
):
"""获取发票列表"""
if not SUBSCRIPTION_MANAGER_AVAILABLE:
@@ -9730,20 +9313,16 @@ async def list_invoices(
"description": inv.description,
"due_date": inv.due_date.isoformat() if inv.due_date else None,
"paid_at": inv.paid_at.isoformat() if inv.paid_at else None,
- "created_at": inv.created_at.isoformat()
+ "created_at": inv.created_at.isoformat(),
}
for inv in invoices
],
- "total": len(invoices)
+ "total": len(invoices),
}
@app.get("/api/v1/tenants/{tenant_id}/invoices/{invoice_id}", tags=["Subscriptions"])
-async def get_invoice(
- tenant_id: str,
- invoice_id: str,
- _=Depends(verify_api_key)
-):
+async def get_invoice(tenant_id: str, invoice_id: str, _=Depends(verify_api_key)):
"""获取发票详情"""
if not SUBSCRIPTION_MANAGER_AVAILABLE:
raise HTTPException(status_code=500, detail="Subscription manager not available")
@@ -9769,7 +9348,7 @@ async def get_invoice(
"paid_at": invoice.paid_at.isoformat() if invoice.paid_at else None,
"voided_at": invoice.voided_at.isoformat() if invoice.voided_at else None,
"void_reason": invoice.void_reason,
- "created_at": invoice.created_at.isoformat()
+ "created_at": invoice.created_at.isoformat(),
}
@@ -9779,7 +9358,7 @@ async def request_refund(
tenant_id: str,
request: RequestRefundRequest,
user_id: str = Header(..., description="当前用户ID"),
- _=Depends(verify_api_key)
+ _=Depends(verify_api_key),
):
"""申请退款"""
if not SUBSCRIPTION_MANAGER_AVAILABLE:
@@ -9792,7 +9371,7 @@ async def request_refund(
payment_id=request.payment_id,
amount=request.amount,
reason=request.reason,
- requested_by=user_id
+ requested_by=user_id,
)
return {
@@ -9802,7 +9381,7 @@ async def request_refund(
"currency": refund.currency,
"reason": refund.reason,
"status": refund.status,
- "requested_at": refund.requested_at.isoformat()
+ "requested_at": refund.requested_at.isoformat(),
}
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
@@ -9814,7 +9393,7 @@ async def list_refunds(
status: str | None = Query(default=None, description="退款状态过滤"),
limit: int = Query(default=100, description="返回数量限制"),
offset: int = Query(default=0, description="偏移量"),
- _=Depends(verify_api_key)
+ _=Depends(verify_api_key),
):
"""获取退款记录列表"""
if not SUBSCRIPTION_MANAGER_AVAILABLE:
@@ -9836,11 +9415,11 @@ async def list_refunds(
"requested_at": r.requested_at.isoformat(),
"approved_by": r.approved_by,
"approved_at": r.approved_at.isoformat() if r.approved_at else None,
- "completed_at": r.completed_at.isoformat() if r.completed_at else None
+ "completed_at": r.completed_at.isoformat() if r.completed_at else None,
}
for r in refunds
],
- "total": len(refunds)
+ "total": len(refunds),
}
@@ -9850,7 +9429,7 @@ async def process_refund(
refund_id: str,
request: ProcessRefundRequest,
user_id: str = Header(..., description="当前用户ID"),
- _=Depends(verify_api_key)
+ _=Depends(verify_api_key),
):
"""处理退款申请(管理员)"""
if not SUBSCRIPTION_MANAGER_AVAILABLE:
@@ -9866,11 +9445,7 @@ async def process_refund(
# 自动完成退款(简化实现)
refund = manager.complete_refund(refund_id)
- return {
- "id": refund.id,
- "status": refund.status,
- "message": "Refund approved and processed"
- }
+ return {"id": refund.id, "status": refund.status, "message": "Refund approved and processed"}
elif request.action == "reject":
if not request.reason:
@@ -9880,11 +9455,7 @@ async def process_refund(
if not refund:
raise HTTPException(status_code=404, detail="Refund not found")
- return {
- "id": refund.id,
- "status": refund.status,
- "message": "Refund rejected"
- }
+ return {"id": refund.id, "status": refund.status, "message": "Refund rejected"}
else:
raise HTTPException(status_code=400, detail="Invalid action")
@@ -9898,7 +9469,7 @@ async def get_billing_history(
end_date: str | None = Query(default=None, description="结束日期 (ISO格式)"),
limit: int = Query(default=100, description="返回数量限制"),
offset: int = Query(default=0, description="偏移量"),
- _=Depends(verify_api_key)
+ _=Depends(verify_api_key),
):
"""获取账单历史"""
if not SUBSCRIPTION_MANAGER_AVAILABLE:
@@ -9921,21 +9492,17 @@ async def get_billing_history(
"description": h.description,
"reference_id": h.reference_id,
"balance_after": h.balance_after,
- "created_at": h.created_at.isoformat()
+ "created_at": h.created_at.isoformat(),
}
for h in history
],
- "total": len(history)
+ "total": len(history),
}
# Payment Provider Integration APIs
@app.post("/api/v1/tenants/{tenant_id}/checkout/stripe", tags=["Subscriptions"])
-async def create_stripe_checkout(
- tenant_id: str,
- request: CreateCheckoutSessionRequest,
- _=Depends(verify_api_key)
-):
+async def create_stripe_checkout(tenant_id: str, request: CreateCheckoutSessionRequest, _=Depends(verify_api_key)):
"""创建 Stripe Checkout 会话"""
if not SUBSCRIPTION_MANAGER_AVAILABLE:
raise HTTPException(status_code=500, detail="Subscription manager not available")
@@ -9948,7 +9515,7 @@ async def create_stripe_checkout(
plan_id=request.plan_id,
success_url=request.success_url,
cancel_url=request.cancel_url,
- billing_cycle=request.billing_cycle
+ billing_cycle=request.billing_cycle,
)
return session
@@ -9961,7 +9528,7 @@ async def create_alipay_order(
tenant_id: str,
plan_id: str,
billing_cycle: str = Query(default="monthly", description="计费周期"),
- _=Depends(verify_api_key)
+ _=Depends(verify_api_key),
):
"""创建支付宝订单"""
if not SUBSCRIPTION_MANAGER_AVAILABLE:
@@ -9970,11 +9537,7 @@ async def create_alipay_order(
manager = get_subscription_manager()
try:
- order = manager.create_alipay_order(
- tenant_id=tenant_id,
- plan_id=plan_id,
- billing_cycle=billing_cycle
- )
+ order = manager.create_alipay_order(tenant_id=tenant_id, plan_id=plan_id, billing_cycle=billing_cycle)
return order
except Exception as e:
@@ -9986,7 +9549,7 @@ async def create_wechat_order(
tenant_id: str,
plan_id: str,
billing_cycle: str = Query(default="monthly", description="计费周期"),
- _=Depends(verify_api_key)
+ _=Depends(verify_api_key),
):
"""创建微信支付订单"""
if not SUBSCRIPTION_MANAGER_AVAILABLE:
@@ -9995,11 +9558,7 @@ async def create_wechat_order(
manager = get_subscription_manager()
try:
- order = manager.create_wechat_order(
- tenant_id=tenant_id,
- plan_id=plan_id,
- billing_cycle=billing_cycle
- )
+ order = manager.create_wechat_order(tenant_id=tenant_id, plan_id=plan_id, billing_cycle=billing_cycle)
return order
except Exception as e:
@@ -10062,6 +9621,7 @@ async def wechat_webhook(request: Request):
# Pydantic Models for Enterprise
+
class SSOConfigCreate(BaseModel):
provider: str = Field(..., description="SSO 提供商: wechat_work/dingtalk/feishu/okta/azure_ad/google/custom_saml")
entity_id: str | None = Field(default=None, description="SAML Entity ID")
@@ -10158,12 +9718,9 @@ class RetentionPolicyUpdate(BaseModel):
# SSO/SAML APIs
+
@app.post("/api/v1/tenants/{tenant_id}/sso-configs", tags=["Enterprise"])
-async def create_sso_config_endpoint(
- tenant_id: str,
- config: SSOConfigCreate,
- _=Depends(verify_api_key)
-):
+async def create_sso_config_endpoint(tenant_id: str, config: SSOConfigCreate, _=Depends(verify_api_key)):
"""创建 SSO 配置"""
if not ENTERPRISE_MANAGER_AVAILABLE:
raise HTTPException(status_code=500, detail="Enterprise manager not available")
@@ -10189,7 +9746,7 @@ async def create_sso_config_endpoint(
attribute_mapping=config.attribute_mapping,
auto_provision=config.auto_provision,
default_role=config.default_role,
- domain_restriction=config.domain_restriction
+ domain_restriction=config.domain_restriction,
)
return {
@@ -10203,17 +9760,14 @@ async def create_sso_config_endpoint(
"scopes": sso_config.scopes,
"auto_provision": sso_config.auto_provision,
"default_role": sso_config.default_role,
- "created_at": sso_config.created_at.isoformat()
+ "created_at": sso_config.created_at.isoformat(),
}
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
@app.get("/api/v1/tenants/{tenant_id}/sso-configs", tags=["Enterprise"])
-async def list_sso_configs_endpoint(
- tenant_id: str,
- _=Depends(verify_api_key)
-):
+async def list_sso_configs_endpoint(tenant_id: str, _=Depends(verify_api_key)):
"""列出租户的所有 SSO 配置"""
if not ENTERPRISE_MANAGER_AVAILABLE:
raise HTTPException(status_code=500, detail="Enterprise manager not available")
@@ -10232,20 +9786,16 @@ async def list_sso_configs_endpoint(
"authorization_url": c.authorization_url,
"auto_provision": c.auto_provision,
"default_role": c.default_role,
- "created_at": c.created_at.isoformat()
+ "created_at": c.created_at.isoformat(),
}
for c in configs
],
- "total": len(configs)
+ "total": len(configs),
}
@app.get("/api/v1/tenants/{tenant_id}/sso-configs/{config_id}", tags=["Enterprise"])
-async def get_sso_config_endpoint(
- tenant_id: str,
- config_id: str,
- _=Depends(verify_api_key)
-):
+async def get_sso_config_endpoint(tenant_id: str, config_id: str, _=Depends(verify_api_key)):
"""获取 SSO 配置详情"""
if not ENTERPRISE_MANAGER_AVAILABLE:
raise HTTPException(status_code=500, detail="Enterprise manager not available")
@@ -10274,16 +9824,13 @@ async def get_sso_config_endpoint(
"default_role": config.default_role,
"domain_restriction": config.domain_restriction,
"created_at": config.created_at.isoformat(),
- "updated_at": config.updated_at.isoformat()
+ "updated_at": config.updated_at.isoformat(),
}
@app.put("/api/v1/tenants/{tenant_id}/sso-configs/{config_id}", tags=["Enterprise"])
async def update_sso_config_endpoint(
- tenant_id: str,
- config_id: str,
- update: SSOConfigUpdate,
- _=Depends(verify_api_key)
+ tenant_id: str, config_id: str, update: SSOConfigUpdate, _=Depends(verify_api_key)
):
"""更新 SSO 配置"""
if not ENTERPRISE_MANAGER_AVAILABLE:
@@ -10296,23 +9843,14 @@ async def update_sso_config_endpoint(
raise HTTPException(status_code=404, detail="SSO config not found")
updated = manager.update_sso_config(
- config_id=config_id,
- **{k: v for k, v in update.dict().items() if v is not None}
+ config_id=config_id, **{k: v for k, v in update.dict().items() if v is not None}
)
- return {
- "id": updated.id,
- "status": updated.status,
- "updated_at": updated.updated_at.isoformat()
- }
+ return {"id": updated.id, "status": updated.status, "updated_at": updated.updated_at.isoformat()}
@app.delete("/api/v1/tenants/{tenant_id}/sso-configs/{config_id}", tags=["Enterprise"])
-async def delete_sso_config_endpoint(
- tenant_id: str,
- config_id: str,
- _=Depends(verify_api_key)
-):
+async def delete_sso_config_endpoint(tenant_id: str, config_id: str, _=Depends(verify_api_key)):
"""删除 SSO 配置"""
if not ENTERPRISE_MANAGER_AVAILABLE:
raise HTTPException(status_code=500, detail="Enterprise manager not available")
@@ -10329,10 +9867,7 @@ async def delete_sso_config_endpoint(
@app.get("/api/v1/tenants/{tenant_id}/sso-configs/{config_id}/metadata", tags=["Enterprise"])
async def get_sso_metadata_endpoint(
- tenant_id: str,
- config_id: str,
- base_url: str = Query(..., description="服务基础 URL"),
- _=Depends(verify_api_key)
+ tenant_id: str, config_id: str, base_url: str = Query(..., description="服务基础 URL"), _=Depends(verify_api_key)
):
"""获取 SAML Service Provider 元数据"""
if not ENTERPRISE_MANAGER_AVAILABLE:
@@ -10350,18 +9885,15 @@ async def get_sso_metadata_endpoint(
"metadata_xml": metadata,
"entity_id": f"{base_url}/api/v1/sso/saml/{tenant_id}",
"acs_url": f"{base_url}/api/v1/sso/saml/{tenant_id}/acs",
- "slo_url": f"{base_url}/api/v1/sso/saml/{tenant_id}/slo"
+ "slo_url": f"{base_url}/api/v1/sso/saml/{tenant_id}/slo",
}
# SCIM APIs
+
@app.post("/api/v1/tenants/{tenant_id}/scim-configs", tags=["Enterprise"])
-async def create_scim_config_endpoint(
- tenant_id: str,
- config: SCIMConfigCreate,
- _=Depends(verify_api_key)
-):
+async def create_scim_config_endpoint(tenant_id: str, config: SCIMConfigCreate, _=Depends(verify_api_key)):
"""创建 SCIM 配置"""
if not ENTERPRISE_MANAGER_AVAILABLE:
raise HTTPException(status_code=500, detail="Enterprise manager not available")
@@ -10376,7 +9908,7 @@ async def create_scim_config_endpoint(
scim_token=config.scim_token,
sync_interval_minutes=config.sync_interval_minutes,
attribute_mapping=config.attribute_mapping,
- sync_rules=config.sync_rules
+ sync_rules=config.sync_rules,
)
return {
@@ -10386,17 +9918,14 @@ async def create_scim_config_endpoint(
"status": scim_config.status,
"scim_base_url": scim_config.scim_base_url,
"sync_interval_minutes": scim_config.sync_interval_minutes,
- "created_at": scim_config.created_at.isoformat()
+ "created_at": scim_config.created_at.isoformat(),
}
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
@app.get("/api/v1/tenants/{tenant_id}/scim-configs", tags=["Enterprise"])
-async def get_scim_config_endpoint(
- tenant_id: str,
- _=Depends(verify_api_key)
-):
+async def get_scim_config_endpoint(tenant_id: str, _=Depends(verify_api_key)):
"""获取租户的 SCIM 配置"""
if not ENTERPRISE_MANAGER_AVAILABLE:
raise HTTPException(status_code=500, detail="Enterprise manager not available")
@@ -10417,16 +9946,13 @@ async def get_scim_config_endpoint(
"last_sync_at": config.last_sync_at.isoformat() if config.last_sync_at else None,
"last_sync_status": config.last_sync_status,
"last_sync_users_count": config.last_sync_users_count,
- "created_at": config.created_at.isoformat()
+ "created_at": config.created_at.isoformat(),
}
@app.put("/api/v1/tenants/{tenant_id}/scim-configs/{config_id}", tags=["Enterprise"])
async def update_scim_config_endpoint(
- tenant_id: str,
- config_id: str,
- update: SCIMConfigUpdate,
- _=Depends(verify_api_key)
+ tenant_id: str, config_id: str, update: SCIMConfigUpdate, _=Depends(verify_api_key)
):
"""更新 SCIM 配置"""
if not ENTERPRISE_MANAGER_AVAILABLE:
@@ -10439,23 +9965,14 @@ async def update_scim_config_endpoint(
raise HTTPException(status_code=404, detail="SCIM config not found")
updated = manager.update_scim_config(
- config_id=config_id,
- **{k: v for k, v in update.dict().items() if v is not None}
+ config_id=config_id, **{k: v for k, v in update.dict().items() if v is not None}
)
- return {
- "id": updated.id,
- "status": updated.status,
- "updated_at": updated.updated_at.isoformat()
- }
+ return {"id": updated.id, "status": updated.status, "updated_at": updated.updated_at.isoformat()}
@app.post("/api/v1/tenants/{tenant_id}/scim-configs/{config_id}/sync", tags=["Enterprise"])
-async def sync_scim_users_endpoint(
- tenant_id: str,
- config_id: str,
- _=Depends(verify_api_key)
-):
+async def sync_scim_users_endpoint(tenant_id: str, config_id: str, _=Depends(verify_api_key)):
"""执行 SCIM 用户同步"""
if not ENTERPRISE_MANAGER_AVAILABLE:
raise HTTPException(status_code=500, detail="Enterprise manager not available")
@@ -10473,9 +9990,7 @@ async def sync_scim_users_endpoint(
@app.get("/api/v1/tenants/{tenant_id}/scim-users", tags=["Enterprise"])
async def list_scim_users_endpoint(
- tenant_id: str,
- active_only: bool = Query(default=True, description="仅显示活跃用户"),
- _=Depends(verify_api_key)
+ tenant_id: str, active_only: bool = Query(default=True, description="仅显示活跃用户"), _=Depends(verify_api_key)
):
"""列出 SCIM 用户"""
if not ENTERPRISE_MANAGER_AVAILABLE:
@@ -10494,22 +10009,23 @@ async def list_scim_users_endpoint(
"display_name": u.display_name,
"active": u.active,
"groups": u.groups,
- "synced_at": u.synced_at.isoformat()
+ "synced_at": u.synced_at.isoformat(),
}
for u in users
],
- "total": len(users)
+ "total": len(users),
}
# Audit Log Export APIs
+
@app.post("/api/v1/tenants/{tenant_id}/audit-exports", tags=["Enterprise"])
async def create_audit_export_endpoint(
tenant_id: str,
request: AuditExportCreate,
current_user: str = Header(default="user", description="当前用户ID"),
- _=Depends(verify_api_key)
+ _=Depends(verify_api_key),
):
"""创建审计日志导出任务"""
if not ENTERPRISE_MANAGER_AVAILABLE:
@@ -10528,7 +10044,7 @@ async def create_audit_export_endpoint(
end_date=end_date,
created_by=current_user,
filters=request.filters,
- compliance_standard=request.compliance_standard
+ compliance_standard=request.compliance_standard,
)
return {
@@ -10540,7 +10056,7 @@ async def create_audit_export_endpoint(
"compliance_standard": export.compliance_standard,
"status": export.status,
"expires_at": export.expires_at.isoformat() if export.expires_at else None,
- "created_at": export.created_at.isoformat()
+ "created_at": export.created_at.isoformat(),
}
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
@@ -10548,9 +10064,7 @@ async def create_audit_export_endpoint(
@app.get("/api/v1/tenants/{tenant_id}/audit-exports", tags=["Enterprise"])
async def list_audit_exports_endpoint(
- tenant_id: str,
- limit: int = Query(default=100, description="返回数量限制"),
- _=Depends(verify_api_key)
+ tenant_id: str, limit: int = Query(default=100, description="返回数量限制"), _=Depends(verify_api_key)
):
"""列出审计日志导出记录"""
if not ENTERPRISE_MANAGER_AVAILABLE:
@@ -10572,20 +10086,16 @@ async def list_audit_exports_endpoint(
"record_count": e.record_count,
"downloaded_by": e.downloaded_by,
"expires_at": e.expires_at.isoformat() if e.expires_at else None,
- "created_at": e.created_at.isoformat()
+ "created_at": e.created_at.isoformat(),
}
for e in exports
],
- "total": len(exports)
+ "total": len(exports),
}
@app.get("/api/v1/tenants/{tenant_id}/audit-exports/{export_id}", tags=["Enterprise"])
-async def get_audit_export_endpoint(
- tenant_id: str,
- export_id: str,
- _=Depends(verify_api_key)
-):
+async def get_audit_export_endpoint(tenant_id: str, export_id: str, _=Depends(verify_api_key)):
"""获取审计日志导出详情"""
if not ENTERPRISE_MANAGER_AVAILABLE:
raise HTTPException(status_code=500, detail="Enterprise manager not available")
@@ -10612,7 +10122,7 @@ async def get_audit_export_endpoint(
"expires_at": export.expires_at.isoformat() if export.expires_at else None,
"created_at": export.created_at.isoformat(),
"completed_at": export.completed_at.isoformat() if export.completed_at else None,
- "error_message": export.error_message
+ "error_message": export.error_message,
}
@@ -10621,7 +10131,7 @@ async def download_audit_export_endpoint(
tenant_id: str,
export_id: str,
current_user: str = Header(default="user", description="当前用户ID"),
- _=Depends(verify_api_key)
+ _=Depends(verify_api_key),
):
"""下载审计日志导出文件"""
if not ENTERPRISE_MANAGER_AVAILABLE:
@@ -10642,18 +10152,15 @@ async def download_audit_export_endpoint(
# 返回文件下载信息
return {
"download_url": f"/api/v1/tenants/{tenant_id}/audit-exports/{export_id}/file",
- "expires_at": export.expires_at.isoformat() if export.expires_at else None
+ "expires_at": export.expires_at.isoformat() if export.expires_at else None,
}
# Data Retention Policy APIs
+
@app.post("/api/v1/tenants/{tenant_id}/retention-policies", tags=["Enterprise"])
-async def create_retention_policy_endpoint(
- tenant_id: str,
- policy: RetentionPolicyCreate,
- _=Depends(verify_api_key)
-):
+async def create_retention_policy_endpoint(tenant_id: str, policy: RetentionPolicyCreate, _=Depends(verify_api_key)):
"""创建数据保留策略"""
if not ENTERPRISE_MANAGER_AVAILABLE:
raise HTTPException(status_code=500, detail="Enterprise manager not available")
@@ -10673,7 +10180,7 @@ async def create_retention_policy_endpoint(
execute_at=policy.execute_at,
notify_before_days=policy.notify_before_days,
archive_location=policy.archive_location,
- archive_encryption=policy.archive_encryption
+ archive_encryption=policy.archive_encryption,
)
return {
@@ -10685,7 +10192,7 @@ async def create_retention_policy_endpoint(
"action": new_policy.action,
"auto_execute": new_policy.auto_execute,
"is_active": new_policy.is_active,
- "created_at": new_policy.created_at.isoformat()
+ "created_at": new_policy.created_at.isoformat(),
}
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
@@ -10695,7 +10202,7 @@ async def create_retention_policy_endpoint(
async def list_retention_policies_endpoint(
tenant_id: str,
resource_type: str | None = Query(default=None, description="资源类型过滤"),
- _=Depends(verify_api_key)
+ _=Depends(verify_api_key),
):
"""列出数据保留策略"""
if not ENTERPRISE_MANAGER_AVAILABLE:
@@ -10714,20 +10221,16 @@ async def list_retention_policies_endpoint(
"action": p.action,
"auto_execute": p.auto_execute,
"is_active": p.is_active,
- "last_executed_at": p.last_executed_at.isoformat() if p.last_executed_at else None
+ "last_executed_at": p.last_executed_at.isoformat() if p.last_executed_at else None,
}
for p in policies
],
- "total": len(policies)
+ "total": len(policies),
}
@app.get("/api/v1/tenants/{tenant_id}/retention-policies/{policy_id}", tags=["Enterprise"])
-async def get_retention_policy_endpoint(
- tenant_id: str,
- policy_id: str,
- _=Depends(verify_api_key)
-):
+async def get_retention_policy_endpoint(tenant_id: str, policy_id: str, _=Depends(verify_api_key)):
"""获取数据保留策略详情"""
if not ENTERPRISE_MANAGER_AVAILABLE:
raise HTTPException(status_code=500, detail="Enterprise manager not available")
@@ -10755,16 +10258,13 @@ async def get_retention_policy_endpoint(
"is_active": policy.is_active,
"last_executed_at": policy.last_executed_at.isoformat() if policy.last_executed_at else None,
"last_execution_result": policy.last_execution_result,
- "created_at": policy.created_at.isoformat()
+ "created_at": policy.created_at.isoformat(),
}
@app.put("/api/v1/tenants/{tenant_id}/retention-policies/{policy_id}", tags=["Enterprise"])
async def update_retention_policy_endpoint(
- tenant_id: str,
- policy_id: str,
- update: RetentionPolicyUpdate,
- _=Depends(verify_api_key)
+ tenant_id: str, policy_id: str, update: RetentionPolicyUpdate, _=Depends(verify_api_key)
):
"""更新数据保留策略"""
if not ENTERPRISE_MANAGER_AVAILABLE:
@@ -10777,22 +10277,14 @@ async def update_retention_policy_endpoint(
raise HTTPException(status_code=404, detail="Policy not found")
updated = manager.update_retention_policy(
- policy_id=policy_id,
- **{k: v for k, v in update.dict().items() if v is not None}
+ policy_id=policy_id, **{k: v for k, v in update.dict().items() if v is not None}
)
- return {
- "id": updated.id,
- "updated_at": updated.updated_at.isoformat()
- }
+ return {"id": updated.id, "updated_at": updated.updated_at.isoformat()}
@app.delete("/api/v1/tenants/{tenant_id}/retention-policies/{policy_id}", tags=["Enterprise"])
-async def delete_retention_policy_endpoint(
- tenant_id: str,
- policy_id: str,
- _=Depends(verify_api_key)
-):
+async def delete_retention_policy_endpoint(tenant_id: str, policy_id: str, _=Depends(verify_api_key)):
"""删除数据保留策略"""
if not ENTERPRISE_MANAGER_AVAILABLE:
raise HTTPException(status_code=500, detail="Enterprise manager not available")
@@ -10808,11 +10300,7 @@ async def delete_retention_policy_endpoint(
@app.post("/api/v1/tenants/{tenant_id}/retention-policies/{policy_id}/execute", tags=["Enterprise"])
-async def execute_retention_policy_endpoint(
- tenant_id: str,
- policy_id: str,
- _=Depends(verify_api_key)
-):
+async def execute_retention_policy_endpoint(tenant_id: str, policy_id: str, _=Depends(verify_api_key)):
"""执行数据保留策略"""
if not ENTERPRISE_MANAGER_AVAILABLE:
raise HTTPException(status_code=500, detail="Enterprise manager not available")
@@ -10830,7 +10318,7 @@ async def execute_retention_policy_endpoint(
"policy_id": job.policy_id,
"status": job.status,
"started_at": job.started_at.isoformat() if job.started_at else None,
- "created_at": job.created_at.isoformat()
+ "created_at": job.created_at.isoformat(),
}
@@ -10839,7 +10327,7 @@ async def list_retention_jobs_endpoint(
tenant_id: str,
policy_id: str,
limit: int = Query(default=100, description="返回数量限制"),
- _=Depends(verify_api_key)
+ _=Depends(verify_api_key),
):
"""列出数据保留任务"""
if not ENTERPRISE_MANAGER_AVAILABLE:
@@ -10863,11 +10351,11 @@ async def list_retention_jobs_endpoint(
"affected_records": j.affected_records,
"archived_records": j.archived_records,
"deleted_records": j.deleted_records,
- "error_count": j.error_count
+ "error_count": j.error_count,
}
for j in jobs
],
- "total": len(jobs)
+ "total": len(jobs),
}
@@ -10875,6 +10363,7 @@ async def list_retention_jobs_endpoint(
# Phase 8 Task 7: Globalization & Localization API
# ============================================
+
# Pydantic Models for Localization API
class TranslationCreate(BaseModel):
key: str = Field(..., description="翻译键")
@@ -10938,10 +10427,7 @@ class ConvertTimezoneRequest(BaseModel):
# Translation APIs
@app.get("/api/v1/translations/{language}/{key}", tags=["Localization"])
async def get_translation(
- language: str,
- key: str,
- namespace: str = Query(default="common", description="命名空间"),
- _=Depends(verify_api_key)
+ language: str, key: str, namespace: str = Query(default="common", description="命名空间"), _=Depends(verify_api_key)
):
"""获取翻译"""
if not LOCALIZATION_MANAGER_AVAILABLE:
@@ -10953,31 +10439,18 @@ async def get_translation(
if value is None:
raise HTTPException(status_code=404, detail="Translation not found")
- return {
- "key": key,
- "language": language,
- "namespace": namespace,
- "value": value
- }
+ return {"key": key, "language": language, "namespace": namespace, "value": value}
@app.post("/api/v1/translations/{language}", tags=["Localization"])
-async def create_translation(
- language: str,
- request: TranslationCreate,
- _=Depends(verify_api_key)
-):
+async def create_translation(language: str, request: TranslationCreate, _=Depends(verify_api_key)):
"""创建/更新翻译"""
if not LOCALIZATION_MANAGER_AVAILABLE:
raise HTTPException(status_code=500, detail="Localization manager not available")
manager = get_localization_manager()
translation = manager.set_translation(
- key=request.key,
- language=language,
- value=request.value,
- namespace=request.namespace,
- context=request.context
+ key=request.key, language=language, value=request.value, namespace=request.namespace, context=request.context
)
return {
@@ -10986,7 +10459,7 @@ async def create_translation(
"language": translation.language,
"namespace": translation.namespace,
"value": translation.value,
- "created_at": translation.created_at.isoformat()
+ "created_at": translation.created_at.isoformat(),
}
@@ -10996,7 +10469,7 @@ async def update_translation(
key: str,
request: TranslationUpdate,
namespace: str = Query(default="common", description="命名空间"),
- _=Depends(verify_api_key)
+ _=Depends(verify_api_key),
):
"""更新翻译"""
if not LOCALIZATION_MANAGER_AVAILABLE:
@@ -11004,11 +10477,7 @@ async def update_translation(
manager = get_localization_manager()
translation = manager.set_translation(
- key=key,
- language=language,
- value=request.value,
- namespace=namespace,
- context=request.context
+ key=key, language=language, value=request.value, namespace=namespace, context=request.context
)
return {
@@ -11017,16 +10486,13 @@ async def update_translation(
"language": translation.language,
"namespace": translation.namespace,
"value": translation.value,
- "updated_at": translation.updated_at.isoformat()
+ "updated_at": translation.updated_at.isoformat(),
}
@app.delete("/api/v1/translations/{language}/{key}", tags=["Localization"])
async def delete_translation(
- language: str,
- key: str,
- namespace: str = Query(default="common", description="命名空间"),
- _=Depends(verify_api_key)
+ language: str, key: str, namespace: str = Query(default="common", description="命名空间"), _=Depends(verify_api_key)
):
"""删除翻译"""
if not LOCALIZATION_MANAGER_AVAILABLE:
@@ -11047,7 +10513,7 @@ async def list_translations(
namespace: str | None = Query(default=None, description="命名空间"),
limit: int = Query(default=1000, description="返回数量限制"),
offset: int = Query(default=0, description="偏移量"),
- _=Depends(verify_api_key)
+ _=Depends(verify_api_key),
):
"""列出翻译"""
if not LOCALIZATION_MANAGER_AVAILABLE:
@@ -11065,19 +10531,17 @@ async def list_translations(
"namespace": t.namespace,
"value": t.value,
"is_reviewed": t.is_reviewed,
- "updated_at": t.updated_at.isoformat()
+ "updated_at": t.updated_at.isoformat(),
}
for t in translations
],
- "total": len(translations)
+ "total": len(translations),
}
# Language APIs
@app.get("/api/v1/languages", tags=["Localization"])
-async def list_languages(
- active_only: bool = Query(default=True, description="仅返回激活的语言")
-):
+async def list_languages(active_only: bool = Query(default=True, description="仅返回激活的语言")):
"""列出支持的语言"""
if not LOCALIZATION_MANAGER_AVAILABLE:
raise HTTPException(status_code=500, detail="Localization manager not available")
@@ -11096,11 +10560,11 @@ async def list_languages(
"is_default": lang.is_default,
"date_format": lang.date_format,
"time_format": lang.time_format,
- "calendar_type": lang.calendar_type
+ "calendar_type": lang.calendar_type,
}
for lang in languages
],
- "total": len(languages)
+ "total": len(languages),
}
@@ -11130,7 +10594,7 @@ async def get_language(code: str):
"number_format": lang.number_format,
"currency_format": lang.currency_format,
"first_day_of_week": lang.first_day_of_week,
- "calendar_type": lang.calendar_type
+ "calendar_type": lang.calendar_type,
}
@@ -11138,7 +10602,7 @@ async def get_language(code: str):
@app.get("/api/v1/data-centers", tags=["Localization"])
async def list_data_centers(
status: str | None = Query(default=None, description="状态过滤"),
- region: str | None = Query(default=None, description="区域过滤")
+ region: str | None = Query(default=None, description="区域过滤"),
):
"""列出数据中心"""
if not LOCALIZATION_MANAGER_AVAILABLE:
@@ -11157,11 +10621,11 @@ async def list_data_centers(
"endpoint": dc.endpoint,
"status": dc.status,
"priority": dc.priority,
- "supported_regions": dc.supported_regions
+ "supported_regions": dc.supported_regions,
}
for dc in data_centers
],
- "total": len(data_centers)
+ "total": len(data_centers),
}
@@ -11186,15 +10650,12 @@ async def get_data_center(dc_id: str):
"status": dc.status,
"priority": dc.priority,
"supported_regions": dc.supported_regions,
- "capabilities": dc.capabilities
+ "capabilities": dc.capabilities,
}
@app.get("/api/v1/tenants/{tenant_id}/data-center", tags=["Localization"])
-async def get_tenant_data_center(
- tenant_id: str,
- _=Depends(verify_api_key)
-):
+async def get_tenant_data_center(tenant_id: str, _=Depends(verify_api_key)):
"""获取租户数据中心配置"""
if not LOCALIZATION_MANAGER_AVAILABLE:
raise HTTPException(status_code=500, detail="Localization manager not available")
@@ -11214,37 +10675,39 @@ async def get_tenant_data_center(
"tenant_id": mapping.tenant_id,
"region_code": mapping.region_code,
"data_residency": mapping.data_residency,
- "primary_dc": {
- "id": primary_dc.id,
- "region_code": primary_dc.region_code,
- "name": primary_dc.name,
- "endpoint": primary_dc.endpoint
- } if primary_dc else None,
- "secondary_dc": {
- "id": secondary_dc.id,
- "region_code": secondary_dc.region_code,
- "name": secondary_dc.name,
- "endpoint": secondary_dc.endpoint
- } if secondary_dc else None,
- "created_at": mapping.created_at.isoformat()
+ "primary_dc": (
+ {
+ "id": primary_dc.id,
+ "region_code": primary_dc.region_code,
+ "name": primary_dc.name,
+ "endpoint": primary_dc.endpoint,
+ }
+ if primary_dc
+ else None
+ ),
+ "secondary_dc": (
+ {
+ "id": secondary_dc.id,
+ "region_code": secondary_dc.region_code,
+ "name": secondary_dc.name,
+ "endpoint": secondary_dc.endpoint,
+ }
+ if secondary_dc
+ else None
+ ),
+ "created_at": mapping.created_at.isoformat(),
}
@app.post("/api/v1/tenants/{tenant_id}/data-center", tags=["Localization"])
-async def set_tenant_data_center(
- tenant_id: str,
- request: DataCenterMappingRequest,
- _=Depends(verify_api_key)
-):
+async def set_tenant_data_center(tenant_id: str, request: DataCenterMappingRequest, _=Depends(verify_api_key)):
"""设置租户数据中心"""
if not LOCALIZATION_MANAGER_AVAILABLE:
raise HTTPException(status_code=500, detail="Localization manager not available")
manager = get_localization_manager()
mapping = manager.set_tenant_data_center(
- tenant_id=tenant_id,
- region_code=request.region_code,
- data_residency=request.data_residency
+ tenant_id=tenant_id, region_code=request.region_code, data_residency=request.data_residency
)
return {
@@ -11252,7 +10715,7 @@ async def set_tenant_data_center(
"tenant_id": mapping.tenant_id,
"region_code": mapping.region_code,
"data_residency": mapping.data_residency,
- "created_at": mapping.created_at.isoformat()
+ "created_at": mapping.created_at.isoformat(),
}
@@ -11261,7 +10724,7 @@ async def set_tenant_data_center(
async def list_payment_methods(
country_code: str | None = Query(default=None, description="国家代码"),
currency: str | None = Query(default=None, description="货币代码"),
- active_only: bool = Query(default=True, description="仅返回激活的支付方式")
+ active_only: bool = Query(default=True, description="仅返回激活的支付方式"),
):
"""列出支付方式"""
if not LOCALIZATION_MANAGER_AVAILABLE:
@@ -11282,18 +10745,17 @@ async def list_payment_methods(
"is_active": m.is_active,
"display_order": m.display_order,
"min_amount": m.min_amount,
- "max_amount": m.max_amount
+ "max_amount": m.max_amount,
}
for m in methods
],
- "total": len(methods)
+ "total": len(methods),
}
@app.get("/api/v1/payment-methods/localized", tags=["Localization"])
async def get_localized_payment_methods(
- country_code: str = Query(..., description="国家代码"),
- language: str = Query(default="en", description="语言代码")
+ country_code: str = Query(..., description="国家代码"), language: str = Query(default="en", description="语言代码")
):
"""获取本地化的支付方式列表"""
if not LOCALIZATION_MANAGER_AVAILABLE:
@@ -11302,18 +10764,14 @@ async def get_localized_payment_methods(
manager = get_localization_manager()
methods = manager.get_localized_payment_methods(country_code, language)
- return {
- "country_code": country_code,
- "language": language,
- "payment_methods": methods
- }
+ return {"country_code": country_code, "language": language, "payment_methods": methods}
# Country APIs
@app.get("/api/v1/countries", tags=["Localization"])
async def list_countries(
region: str | None = Query(default=None, description="区域过滤"),
- active_only: bool = Query(default=True, description="仅返回激活的国家")
+ active_only: bool = Query(default=True, description="仅返回激活的国家"),
):
"""列出国家/地区"""
if not LOCALIZATION_MANAGER_AVAILABLE:
@@ -11333,11 +10791,11 @@ async def list_countries(
"default_currency": c.default_currency,
"timezone": c.timezone,
"calendar_type": c.calendar_type,
- "vat_rate": c.vat_rate
+ "vat_rate": c.vat_rate,
}
for c in countries
],
- "total": len(countries)
+ "total": len(countries),
}
@@ -11365,16 +10823,13 @@ async def get_country(code: str):
"supported_currencies": country.supported_currencies,
"timezone": country.timezone,
"calendar_type": country.calendar_type,
- "vat_rate": country.vat_rate
+ "vat_rate": country.vat_rate,
}
# Localization Settings APIs
@app.get("/api/v1/tenants/{tenant_id}/localization", tags=["Localization"])
-async def get_localization_settings(
- tenant_id: str,
- _=Depends(verify_api_key)
-):
+async def get_localization_settings(tenant_id: str, _=Depends(verify_api_key)):
"""获取租户本地化设置"""
if not LOCALIZATION_MANAGER_AVAILABLE:
raise HTTPException(status_code=500, detail="Localization manager not available")
@@ -11399,16 +10854,12 @@ async def get_localization_settings(
"first_day_of_week": settings.first_day_of_week,
"region_code": settings.region_code,
"data_residency": settings.data_residency,
- "updated_at": settings.updated_at.isoformat()
+ "updated_at": settings.updated_at.isoformat(),
}
@app.post("/api/v1/tenants/{tenant_id}/localization", tags=["Localization"])
-async def create_localization_settings(
- tenant_id: str,
- request: LocalizationSettingsCreate,
- _=Depends(verify_api_key)
-):
+async def create_localization_settings(tenant_id: str, request: LocalizationSettingsCreate, _=Depends(verify_api_key)):
"""创建租户本地化设置"""
if not LOCALIZATION_MANAGER_AVAILABLE:
raise HTTPException(status_code=500, detail="Localization manager not available")
@@ -11422,7 +10873,7 @@ async def create_localization_settings(
supported_currencies=request.supported_currencies,
default_timezone=request.default_timezone,
region_code=request.region_code,
- data_residency=request.data_residency
+ data_residency=request.data_residency,
)
return {
@@ -11435,16 +10886,12 @@ async def create_localization_settings(
"default_timezone": settings.default_timezone,
"region_code": settings.region_code,
"data_residency": settings.data_residency,
- "created_at": settings.created_at.isoformat()
+ "created_at": settings.created_at.isoformat(),
}
@app.put("/api/v1/tenants/{tenant_id}/localization", tags=["Localization"])
-async def update_localization_settings(
- tenant_id: str,
- request: LocalizationSettingsUpdate,
- _=Depends(verify_api_key)
-):
+async def update_localization_settings(tenant_id: str, request: LocalizationSettingsUpdate, _=Depends(verify_api_key)):
"""更新租户本地化设置"""
if not LOCALIZATION_MANAGER_AVAILABLE:
raise HTTPException(status_code=500, detail="Localization manager not available")
@@ -11467,15 +10914,14 @@ async def update_localization_settings(
"default_timezone": settings.default_timezone,
"region_code": settings.region_code,
"data_residency": settings.data_residency,
- "updated_at": settings.updated_at.isoformat()
+ "updated_at": settings.updated_at.isoformat(),
}
# Formatting APIs
@app.post("/api/v1/format/datetime", tags=["Localization"])
async def format_datetime_endpoint(
- request: FormatDateTimeRequest,
- language: str = Query(default="en", description="语言代码")
+ request: FormatDateTimeRequest, language: str = Query(default="en", description="语言代码")
):
"""格式化日期时间"""
if not LOCALIZATION_MANAGER_AVAILABLE:
@@ -11484,15 +10930,12 @@ async def format_datetime_endpoint(
manager = get_localization_manager()
try:
- dt = datetime.fromisoformat(request.timestamp.replace('Z', '+00:00'))
+ dt = datetime.fromisoformat(request.timestamp.replace("Z", "+00:00"))
except ValueError:
raise HTTPException(status_code=400, detail="Invalid timestamp format")
formatted = manager.format_datetime(
- dt=dt,
- language=language,
- timezone=request.timezone,
- format_type=request.format_type
+ dt=dt, language=language, timezone=request.timezone, format_type=request.format_type
)
return {
@@ -11500,61 +10943,40 @@ async def format_datetime_endpoint(
"formatted": formatted,
"language": language,
"timezone": request.timezone,
- "format_type": request.format_type
+ "format_type": request.format_type,
}
@app.post("/api/v1/format/number", tags=["Localization"])
async def format_number_endpoint(
- request: FormatNumberRequest,
- language: str = Query(default="en", description="语言代码")
+ request: FormatNumberRequest, language: str = Query(default="en", description="语言代码")
):
"""格式化数字"""
if not LOCALIZATION_MANAGER_AVAILABLE:
raise HTTPException(status_code=500, detail="Localization manager not available")
manager = get_localization_manager()
- formatted = manager.format_number(
- number=request.number,
- language=language,
- decimal_places=request.decimal_places
- )
+ formatted = manager.format_number(number=request.number, language=language, decimal_places=request.decimal_places)
- return {
- "original": request.number,
- "formatted": formatted,
- "language": language
- }
+ return {"original": request.number, "formatted": formatted, "language": language}
@app.post("/api/v1/format/currency", tags=["Localization"])
async def format_currency_endpoint(
- request: FormatCurrencyRequest,
- language: str = Query(default="en", description="语言代码")
+ request: FormatCurrencyRequest, language: str = Query(default="en", description="语言代码")
):
"""格式化货币"""
if not LOCALIZATION_MANAGER_AVAILABLE:
raise HTTPException(status_code=500, detail="Localization manager not available")
manager = get_localization_manager()
- formatted = manager.format_currency(
- amount=request.amount,
- currency=request.currency,
- language=language
- )
+ formatted = manager.format_currency(amount=request.amount, currency=request.currency, language=language)
- return {
- "original": request.amount,
- "currency": request.currency,
- "formatted": formatted,
- "language": language
- }
+ return {"original": request.amount, "currency": request.currency, "formatted": formatted, "language": language}
@app.post("/api/v1/convert/timezone", tags=["Localization"])
-async def convert_timezone_endpoint(
- request: ConvertTimezoneRequest
-):
+async def convert_timezone_endpoint(request: ConvertTimezoneRequest):
"""转换时区"""
if not LOCALIZATION_MANAGER_AVAILABLE:
raise HTTPException(status_code=500, detail="Localization manager not available")
@@ -11562,47 +10984,38 @@ async def convert_timezone_endpoint(
manager = get_localization_manager()
try:
- dt = datetime.fromisoformat(request.timestamp.replace('Z', '+00:00'))
+ dt = datetime.fromisoformat(request.timestamp.replace("Z", "+00:00"))
except ValueError:
raise HTTPException(status_code=400, detail="Invalid timestamp format")
- converted = manager.convert_timezone(
- dt=dt,
- from_tz=request.from_tz,
- to_tz=request.to_tz
- )
+ converted = manager.convert_timezone(dt=dt, from_tz=request.from_tz, to_tz=request.to_tz)
return {
"original": request.timestamp,
"from_timezone": request.from_tz,
"to_timezone": request.to_tz,
- "converted": converted.isoformat()
+ "converted": converted.isoformat(),
}
@app.get("/api/v1/detect/locale", tags=["Localization"])
async def detect_locale(
accept_language: str | None = Header(default=None, description="Accept-Language 头"),
- ip_country: str | None = Query(default=None, description="IP国家代码")
+ ip_country: str | None = Query(default=None, description="IP国家代码"),
):
"""检测用户本地化偏好"""
if not LOCALIZATION_MANAGER_AVAILABLE:
raise HTTPException(status_code=500, detail="Localization manager not available")
manager = get_localization_manager()
- preferences = manager.detect_user_preferences(
- accept_language=accept_language,
- ip_country=ip_country
- )
+ preferences = manager.detect_user_preferences(accept_language=accept_language, ip_country=ip_country)
return preferences
@app.get("/api/v1/calendar/{calendar_type}", tags=["Localization"])
async def get_calendar_info(
- calendar_type: str,
- year: int = Query(..., description="年份"),
- month: int = Query(..., description="月份")
+ calendar_type: str, year: int = Query(..., description="年份"), month: int = Query(..., description="月份")
):
"""获取日历信息"""
if not LOCALIZATION_MANAGER_AVAILABLE:
@@ -11618,6 +11031,7 @@ async def get_calendar_info(
# Phase 8 Task 4: AI 能力增强 API
# ============================================
+
class CreateCustomModelRequest(BaseModel):
name: str
description: str
@@ -11690,9 +11104,7 @@ class PredictionFeedbackRequest(BaseModel):
# 自定义模型管理 API
@app.post("/api/v1/tenants/{tenant_id}/ai/custom-models", tags=["AI Enhancement"])
async def create_custom_model(
- tenant_id: str,
- request: CreateCustomModelRequest,
- created_by: str = Query(..., description="创建者ID")
+ tenant_id: str, request: CreateCustomModelRequest, created_by: str = Query(..., description="创建者ID")
):
"""创建自定义模型"""
if not AI_MANAGER_AVAILABLE:
@@ -11708,14 +11120,14 @@ async def create_custom_model(
model_type=ModelType(request.model_type),
training_data=request.training_data,
hyperparameters=request.hyperparameters,
- created_by=created_by
+ created_by=created_by,
)
return {
"id": model.id,
"name": model.name,
"model_type": model.model_type.value,
"status": model.status.value,
- "created_at": model.created_at
+ "created_at": model.created_at,
}
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@@ -11725,7 +11137,7 @@ async def create_custom_model(
async def list_custom_models(
tenant_id: str,
model_type: str | None = Query(default=None, description="模型类型过滤"),
- status: str | None = Query(default=None, description="状态过滤")
+ status: str | None = Query(default=None, description="状态过滤"),
):
"""列出自定义模型"""
if not AI_MANAGER_AVAILABLE:
@@ -11746,7 +11158,7 @@ async def list_custom_models(
"model_type": m.model_type.value,
"status": m.status.value,
"metrics": m.metrics,
- "created_at": m.created_at
+ "created_at": m.created_at,
}
for m in models
]
@@ -11778,15 +11190,12 @@ async def get_custom_model(model_id: str):
"model_path": model.model_path,
"created_at": model.created_at,
"trained_at": model.trained_at,
- "created_by": model.created_by
+ "created_by": model.created_by,
}
@app.post("/api/v1/ai/custom-models/{model_id}/samples", tags=["AI Enhancement"])
-async def add_training_sample(
- model_id: str,
- request: AddTrainingSampleRequest
-):
+async def add_training_sample(model_id: str, request: AddTrainingSampleRequest):
"""添加训练样本"""
if not AI_MANAGER_AVAILABLE:
raise HTTPException(status_code=500, detail="AI manager not available")
@@ -11794,10 +11203,7 @@ async def add_training_sample(
manager = get_ai_manager()
sample = manager.add_training_sample(
- model_id=model_id,
- text=request.text,
- entities=request.entities,
- metadata=request.metadata
+ model_id=model_id, text=request.text, entities=request.entities, metadata=request.metadata
)
return {
@@ -11805,7 +11211,7 @@ async def add_training_sample(
"model_id": sample.model_id,
"text": sample.text,
"entities": sample.entities,
- "created_at": sample.created_at
+ "created_at": sample.created_at,
}
@@ -11820,13 +11226,7 @@ async def get_training_samples(model_id: str):
return {
"samples": [
- {
- "id": s.id,
- "text": s.text,
- "entities": s.entities,
- "metadata": s.metadata,
- "created_at": s.created_at
- }
+ {"id": s.id, "text": s.text, "entities": s.entities, "metadata": s.metadata, "created_at": s.created_at}
for s in samples
]
}
@@ -11842,12 +11242,7 @@ async def train_custom_model(model_id: str):
try:
model = await manager.train_custom_model(model_id)
- return {
- "id": model.id,
- "status": model.status.value,
- "metrics": model.metrics,
- "trained_at": model.trained_at
- }
+ return {"id": model.id, "status": model.status.value, "metrics": model.metrics, "trained_at": model.trained_at}
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@@ -11862,22 +11257,14 @@ async def predict_with_custom_model(request: PredictRequest):
try:
entities = await manager.predict_with_custom_model(request.model_id, request.text)
- return {
- "model_id": request.model_id,
- "text": request.text,
- "entities": entities
- }
+ return {"model_id": request.model_id, "text": request.text, "entities": entities}
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
# 多模态分析 API
@app.post("/api/v1/tenants/{tenant_id}/projects/{project_id}/ai/multimodal", tags=["AI Enhancement"])
-async def analyze_multimodal(
- tenant_id: str,
- project_id: str,
- request: MultimodalAnalysisRequest
-):
+async def analyze_multimodal(tenant_id: str, project_id: str, request: MultimodalAnalysisRequest):
"""多模态分析"""
if not AI_MANAGER_AVAILABLE:
raise HTTPException(status_code=500, detail="AI manager not available")
@@ -11891,7 +11278,7 @@ async def analyze_multimodal(
provider=MultimodalProvider(request.provider),
input_type=request.input_type,
input_urls=request.input_urls,
- prompt=request.prompt
+ prompt=request.prompt,
)
return {
@@ -11901,7 +11288,7 @@ async def analyze_multimodal(
"result": analysis.result,
"tokens_used": analysis.tokens_used,
"cost": analysis.cost,
- "created_at": analysis.created_at
+ "created_at": analysis.created_at,
}
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@@ -11909,8 +11296,7 @@ async def analyze_multimodal(
@app.get("/api/v1/tenants/{tenant_id}/ai/multimodal", tags=["AI Enhancement"])
async def list_multimodal_analyses(
- tenant_id: str,
- project_id: str | None = Query(default=None, description="项目ID过滤")
+ tenant_id: str, project_id: str | None = Query(default=None, description="项目ID过滤")
):
"""获取多模态分析历史"""
if not AI_MANAGER_AVAILABLE:
@@ -11930,7 +11316,7 @@ async def list_multimodal_analyses(
"result": a.result,
"tokens_used": a.tokens_used,
"cost": a.cost,
- "created_at": a.created_at
+ "created_at": a.created_at,
}
for a in analyses
]
@@ -11939,11 +11325,7 @@ async def list_multimodal_analyses(
# 知识图谱 RAG API
@app.post("/api/v1/tenants/{tenant_id}/projects/{project_id}/ai/kg-rag", tags=["AI Enhancement"])
-async def create_kg_rag(
- tenant_id: str,
- project_id: str,
- request: CreateKGRAGRequest
-):
+async def create_kg_rag(tenant_id: str, project_id: str, request: CreateKGRAGRequest):
"""创建知识图谱 RAG 配置"""
if not AI_MANAGER_AVAILABLE:
raise HTTPException(status_code=500, detail="AI manager not available")
@@ -11957,7 +11339,7 @@ async def create_kg_rag(
description=request.description,
kg_config=request.kg_config,
retrieval_config=request.retrieval_config,
- generation_config=request.generation_config
+ generation_config=request.generation_config,
)
return {
@@ -11965,15 +11347,12 @@ async def create_kg_rag(
"name": rag.name,
"description": rag.description,
"is_active": rag.is_active,
- "created_at": rag.created_at
+ "created_at": rag.created_at,
}
@app.get("/api/v1/tenants/{tenant_id}/ai/kg-rag", tags=["AI Enhancement"])
-async def list_kg_rags(
- tenant_id: str,
- project_id: str | None = Query(default=None, description="项目ID过滤")
-):
+async def list_kg_rags(tenant_id: str, project_id: str | None = Query(default=None, description="项目ID过滤")):
"""列出知识图谱 RAG 配置"""
if not AI_MANAGER_AVAILABLE:
raise HTTPException(status_code=500, detail="AI manager not available")
@@ -11989,7 +11368,7 @@ async def list_kg_rags(
"name": r.name,
"description": r.description,
"is_active": r.is_active,
- "created_at": r.created_at
+ "created_at": r.created_at,
}
for r in rags
]
@@ -12000,7 +11379,7 @@ async def list_kg_rags(
async def query_kg_rag(
request: KGRAGQueryRequest,
project_entities: list[dict] = Body(default=[], description="项目实体列表"),
- project_relations: list[dict] = Body(default=[], description="项目关系列表")
+ project_relations: list[dict] = Body(default=[], description="项目关系列表"),
):
"""基于知识图谱的 RAG 查询"""
if not AI_MANAGER_AVAILABLE:
@@ -12013,7 +11392,7 @@ async def query_kg_rag(
rag_id=request.rag_id,
query=request.query,
project_entities=project_entities,
- project_relations=project_relations
+ project_relations=project_relations,
)
return {
@@ -12025,7 +11404,7 @@ async def query_kg_rag(
"confidence": result.confidence,
"tokens_used": result.tokens_used,
"latency_ms": result.latency_ms,
- "created_at": result.created_at
+ "created_at": result.created_at,
}
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@@ -12033,11 +11412,7 @@ async def query_kg_rag(
# 智能摘要 API
@app.post("/api/v1/tenants/{tenant_id}/projects/{project_id}/ai/summarize", tags=["AI Enhancement"])
-async def generate_smart_summary(
- tenant_id: str,
- project_id: str,
- request: SmartSummaryRequest
-):
+async def generate_smart_summary(tenant_id: str, project_id: str, request: SmartSummaryRequest):
"""生成智能摘要"""
if not AI_MANAGER_AVAILABLE:
raise HTTPException(status_code=500, detail="AI manager not available")
@@ -12050,7 +11425,7 @@ async def generate_smart_summary(
source_type=request.source_type,
source_id=request.source_id,
summary_type=request.summary_type,
- content_data=request.content_data
+ content_data=request.content_data,
)
return {
@@ -12063,7 +11438,7 @@ async def generate_smart_summary(
"entities_mentioned": summary.entities_mentioned,
"confidence": summary.confidence,
"tokens_used": summary.tokens_used,
- "created_at": summary.created_at
+ "created_at": summary.created_at,
}
@@ -12072,7 +11447,7 @@ async def list_smart_summaries(
tenant_id: str,
project_id: str,
source_type: str | None = Query(default=None, description="来源类型过滤"),
- source_id: str | None = Query(default=None, description="来源ID过滤")
+ source_id: str | None = Query(default=None, description="来源ID过滤"),
):
"""获取智能摘要列表"""
if not AI_MANAGER_AVAILABLE:
@@ -12086,11 +11461,7 @@ async def list_smart_summaries(
# 预测模型 API
@app.post("/api/v1/tenants/{tenant_id}/projects/{project_id}/ai/prediction-models", tags=["AI Enhancement"])
-async def create_prediction_model(
- tenant_id: str,
- project_id: str,
- request: CreatePredictionModelRequest
-):
+async def create_prediction_model(tenant_id: str, project_id: str, request: CreatePredictionModelRequest):
"""创建预测模型"""
if not AI_MANAGER_AVAILABLE:
raise HTTPException(status_code=500, detail="AI manager not available")
@@ -12105,7 +11476,7 @@ async def create_prediction_model(
prediction_type=PredictionType(request.prediction_type),
target_entity_type=request.target_entity_type,
features=request.features,
- model_config=request.model_config
+ model_config=request.model_config,
)
return {
@@ -12115,7 +11486,7 @@ async def create_prediction_model(
"target_entity_type": model.target_entity_type,
"features": model.features,
"is_active": model.is_active,
- "created_at": model.created_at
+ "created_at": model.created_at,
}
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@@ -12123,8 +11494,7 @@ async def create_prediction_model(
@app.get("/api/v1/tenants/{tenant_id}/ai/prediction-models", tags=["AI Enhancement"])
async def list_prediction_models(
- tenant_id: str,
- project_id: str | None = Query(default=None, description="项目ID过滤")
+ tenant_id: str, project_id: str | None = Query(default=None, description="项目ID过滤")
):
"""列出预测模型"""
if not AI_MANAGER_AVAILABLE:
@@ -12145,7 +11515,7 @@ async def list_prediction_models(
"accuracy": m.accuracy,
"last_trained_at": m.last_trained_at,
"prediction_count": m.prediction_count,
- "is_active": m.is_active
+ "is_active": m.is_active,
}
for m in models
]
@@ -12177,15 +11547,12 @@ async def get_prediction_model(model_id: str):
"last_trained_at": model.last_trained_at,
"prediction_count": model.prediction_count,
"is_active": model.is_active,
- "created_at": model.created_at
+ "created_at": model.created_at,
}
@app.post("/api/v1/ai/prediction-models/{model_id}/train", tags=["AI Enhancement"])
-async def train_prediction_model(
- model_id: str,
- historical_data: list[dict] = Body(..., description="历史训练数据")
-):
+async def train_prediction_model(model_id: str, historical_data: list[dict] = Body(..., description="历史训练数据")):
"""训练预测模型"""
if not AI_MANAGER_AVAILABLE:
raise HTTPException(status_code=500, detail="AI manager not available")
@@ -12194,11 +11561,7 @@ async def train_prediction_model(
try:
model = await manager.train_prediction_model(model_id, historical_data)
- return {
- "id": model.id,
- "accuracy": model.accuracy,
- "last_trained_at": model.last_trained_at
- }
+ return {"id": model.id, "accuracy": model.accuracy, "last_trained_at": model.last_trained_at}
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@@ -12222,17 +11585,14 @@ async def predict(request: PredictDataRequest):
"prediction_data": result.prediction_data,
"confidence": result.confidence,
"explanation": result.explanation,
- "created_at": result.created_at
+ "created_at": result.created_at,
}
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@app.get("/api/v1/ai/prediction-models/{model_id}/results", tags=["AI Enhancement"])
-async def get_prediction_results(
- model_id: str,
- limit: int = Query(default=100, description="返回结果数量限制")
-):
+async def get_prediction_results(model_id: str, limit: int = Query(default=100, description="返回结果数量限制")):
"""获取预测结果历史"""
if not AI_MANAGER_AVAILABLE:
raise HTTPException(status_code=500, detail="AI manager not available")
@@ -12251,7 +11611,7 @@ async def get_prediction_results(
"explanation": r.explanation,
"actual_value": r.actual_value,
"is_correct": r.is_correct,
- "created_at": r.created_at
+ "created_at": r.created_at,
}
for r in results
]
@@ -12266,9 +11626,7 @@ async def update_prediction_feedback(request: PredictionFeedbackRequest):
manager = get_ai_manager()
manager.update_prediction_feedback(
- prediction_id=request.prediction_id,
- actual_value=request.actual_value,
- is_correct=request.is_correct
+ prediction_id=request.prediction_id, actual_value=request.actual_value, is_correct=request.is_correct
)
return {"status": "success", "message": "Feedback updated"}
@@ -12276,6 +11634,7 @@ async def update_prediction_feedback(request: PredictionFeedbackRequest):
# ==================== Phase 8 Task 5: Growth & Analytics Endpoints ====================
+
# Pydantic Models for Growth API
class TrackEventRequest(BaseModel):
tenant_id: str
@@ -12391,6 +11750,7 @@ def get_growth_manager_instance():
# ==================== 用户行为分析 API ====================
+
@app.post("/api/v1/analytics/track", tags=["Growth & Analytics"])
async def track_event_endpoint(request: TrackEventRequest):
"""
@@ -12413,18 +11773,14 @@ async def track_event_endpoint(request: TrackEventRequest):
session_id=request.session_id,
device_info=request.device_info,
referrer=request.referrer,
- utm_params={
- "source": request.utm_source,
- "medium": request.utm_medium,
- "campaign": request.utm_campaign
- } if any([request.utm_source, request.utm_medium, request.utm_campaign]) else None
+ utm_params=(
+ {"source": request.utm_source, "medium": request.utm_medium, "campaign": request.utm_campaign}
+ if any([request.utm_source, request.utm_medium, request.utm_campaign])
+ else None
+ ),
)
- return {
- "success": True,
- "event_id": event.id,
- "timestamp": event.timestamp.isoformat()
- }
+ return {"success": True, "event_id": event.id, "timestamp": event.timestamp.isoformat()}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@@ -12442,11 +11798,7 @@ async def get_analytics_dashboard(tenant_id: str):
@app.get("/api/v1/analytics/summary/{tenant_id}", tags=["Growth & Analytics"])
-async def get_analytics_summary(
- tenant_id: str,
- start_date: str | None = None,
- end_date: str | None = None
-):
+async def get_analytics_summary(tenant_id: str, start_date: str | None = None, end_date: str | None = None):
"""获取用户分析汇总"""
if not GROWTH_MANAGER_AVAILABLE:
raise HTTPException(status_code=503, detail="Growth manager not available")
@@ -12483,12 +11835,13 @@ async def get_user_profile(tenant_id: str, user_id: str):
"feature_usage": profile.feature_usage,
"ltv": profile.ltv,
"churn_risk_score": profile.churn_risk_score,
- "engagement_score": profile.engagement_score
+ "engagement_score": profile.engagement_score,
}
# ==================== 转化漏斗 API ====================
+
@app.post("/api/v1/analytics/funnels", tags=["Growth & Analytics"])
async def create_funnel_endpoint(request: CreateFunnelRequest, created_by: str = "system"):
"""创建转化漏斗"""
@@ -12505,23 +11858,14 @@ async def create_funnel_endpoint(request: CreateFunnelRequest, created_by: str =
name=request.name,
description=request.description,
steps=request.steps,
- created_by=created_by
+ created_by=created_by,
)
- return {
- "id": funnel.id,
- "name": funnel.name,
- "steps": funnel.steps,
- "created_at": funnel.created_at
- }
+ return {"id": funnel.id, "name": funnel.name, "steps": funnel.steps, "created_at": funnel.created_at}
@app.get("/api/v1/analytics/funnels/{funnel_id}/analyze", tags=["Growth & Analytics"])
-async def analyze_funnel_endpoint(
- funnel_id: str,
- period_start: str | None = None,
- period_end: str | None = None
-):
+async def analyze_funnel_endpoint(funnel_id: str, period_start: str | None = None, period_end: str | None = None):
"""分析漏斗转化率"""
if not GROWTH_MANAGER_AVAILABLE:
raise HTTPException(status_code=503, detail="Growth manager not available")
@@ -12543,15 +11887,13 @@ async def analyze_funnel_endpoint(
"total_users": analysis.total_users,
"step_conversions": analysis.step_conversions,
"overall_conversion": analysis.overall_conversion,
- "drop_off_points": analysis.drop_off_points
+ "drop_off_points": analysis.drop_off_points,
}
@app.get("/api/v1/analytics/retention/{tenant_id}", tags=["Growth & Analytics"])
async def calculate_retention(
- tenant_id: str,
- cohort_date: str,
- periods: str | None = None # JSON array: [1, 3, 7, 14, 30]
+ tenant_id: str, cohort_date: str, periods: str | None = None # JSON array: [1, 3, 7, 14, 30]
):
"""计算留存率"""
if not GROWTH_MANAGER_AVAILABLE:
@@ -12569,6 +11911,7 @@ async def calculate_retention(
# ==================== A/B 测试 API ====================
+
@app.post("/api/v1/experiments", tags=["Growth & Analytics"])
async def create_experiment_endpoint(request: CreateExperimentRequest, created_by: str = "system"):
"""创建 A/B 测试实验"""
@@ -12593,7 +11936,7 @@ async def create_experiment_endpoint(request: CreateExperimentRequest, created_b
secondary_metrics=request.secondary_metrics,
min_sample_size=request.min_sample_size,
confidence_level=request.confidence_level,
- created_by=created_by
+ created_by=created_by,
)
return {
@@ -12601,7 +11944,7 @@ async def create_experiment_endpoint(request: CreateExperimentRequest, created_b
"name": experiment.name,
"status": experiment.status.value,
"variants": experiment.variants,
- "created_at": experiment.created_at
+ "created_at": experiment.created_at,
}
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
@@ -12628,7 +11971,7 @@ async def list_experiments(status: str | None = None):
"hypothesis": e.hypothesis,
"primary_metric": e.primary_metric,
"start_date": e.start_date.isoformat() if e.start_date else None,
- "end_date": e.end_date.isoformat() if e.end_date else None
+ "end_date": e.end_date.isoformat() if e.end_date else None,
}
for e in experiments
]
@@ -12658,7 +12001,7 @@ async def get_experiment_endpoint(experiment_id: str):
"primary_metric": experiment.primary_metric,
"secondary_metrics": experiment.secondary_metrics,
"start_date": experiment.start_date.isoformat() if experiment.start_date else None,
- "end_date": experiment.end_date.isoformat() if experiment.end_date else None
+ "end_date": experiment.end_date.isoformat() if experiment.end_date else None,
}
@@ -12671,19 +12014,13 @@ async def assign_variant_endpoint(experiment_id: str, request: AssignVariantRequ
manager = get_growth_manager_instance()
variant_id = manager.assign_variant(
- experiment_id=experiment_id,
- user_id=request.user_id,
- user_attributes=request.user_attributes
+ experiment_id=experiment_id, user_id=request.user_id, user_attributes=request.user_attributes
)
if not variant_id:
raise HTTPException(status_code=400, detail="Failed to assign variant")
- return {
- "experiment_id": experiment_id,
- "user_id": request.user_id,
- "variant_id": variant_id
- }
+ return {"experiment_id": experiment_id, "user_id": request.user_id, "variant_id": variant_id}
@app.post("/api/v1/experiments/{experiment_id}/metrics", tags=["Growth & Analytics"])
@@ -12699,7 +12036,7 @@ async def record_experiment_metric_endpoint(experiment_id: str, request: RecordM
variant_id=request.variant_id,
user_id=request.user_id,
metric_name=request.metric_name,
- metric_value=request.metric_value
+ metric_value=request.metric_value,
)
return {"success": True}
@@ -12737,7 +12074,7 @@ async def start_experiment_endpoint(experiment_id: str):
return {
"id": experiment.id,
"status": experiment.status.value,
- "start_date": experiment.start_date.isoformat() if experiment.start_date else None
+ "start_date": experiment.start_date.isoformat() if experiment.start_date else None,
}
@@ -12757,12 +12094,13 @@ async def stop_experiment_endpoint(experiment_id: str):
return {
"id": experiment.id,
"status": experiment.status.value,
- "end_date": experiment.end_date.isoformat() if experiment.end_date else None
+ "end_date": experiment.end_date.isoformat() if experiment.end_date else None,
}
# ==================== 邮件营销 API ====================
+
@app.post("/api/v1/email/templates", tags=["Growth & Analytics"])
async def create_email_template_endpoint(request: CreateEmailTemplateRequest):
"""创建邮件模板"""
@@ -12783,7 +12121,7 @@ async def create_email_template_endpoint(request: CreateEmailTemplateRequest):
variables=request.variables,
from_name=request.from_name,
from_email=request.from_email,
- reply_to=request.reply_to
+ reply_to=request.reply_to,
)
return {
@@ -12792,7 +12130,7 @@ async def create_email_template_endpoint(request: CreateEmailTemplateRequest):
"template_type": template.template_type.value,
"subject": template.subject,
"variables": template.variables,
- "created_at": template.created_at
+ "created_at": template.created_at,
}
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
@@ -12818,7 +12156,7 @@ async def list_email_templates(template_type: str | None = None):
"template_type": t.template_type.value,
"subject": t.subject,
"variables": t.variables,
- "is_active": t.is_active
+ "is_active": t.is_active,
}
for t in templates
]
@@ -12846,7 +12184,7 @@ async def get_email_template_endpoint(template_id: str):
"text_content": template.text_content,
"variables": template.variables,
"from_name": template.from_name,
- "from_email": template.from_email
+ "from_email": template.from_email,
}
@@ -12882,7 +12220,7 @@ async def create_email_campaign_endpoint(request: CreateCampaignRequest):
name=request.name,
template_id=request.template_id,
recipient_list=request.recipients,
- scheduled_at=scheduled_at
+ scheduled_at=scheduled_at,
)
return {
@@ -12891,7 +12229,7 @@ async def create_email_campaign_endpoint(request: CreateCampaignRequest):
"template_id": campaign.template_id,
"status": campaign.status,
"recipient_count": campaign.recipient_count,
- "scheduled_at": campaign.scheduled_at
+ "scheduled_at": campaign.scheduled_at,
}
@@ -12926,7 +12264,7 @@ async def create_automation_workflow_endpoint(request: CreateAutomationWorkflowR
description=request.description,
trigger_type=WorkflowTriggerType(request.trigger_type),
trigger_conditions=request.trigger_conditions,
- actions=request.actions
+ actions=request.actions,
)
return {
@@ -12934,12 +12272,13 @@ async def create_automation_workflow_endpoint(request: CreateAutomationWorkflowR
"name": workflow.name,
"trigger_type": workflow.trigger_type.value,
"is_active": workflow.is_active,
- "created_at": workflow.created_at
+ "created_at": workflow.created_at,
}
# ==================== 推荐系统 API ====================
+
@app.post("/api/v1/referral/programs", tags=["Growth & Analytics"])
async def create_referral_program_endpoint(request: CreateReferralProgramRequest):
"""创建推荐计划"""
@@ -12959,7 +12298,7 @@ async def create_referral_program_endpoint(request: CreateReferralProgramRequest
referee_reward_value=request.referee_reward_value,
max_referrals_per_user=request.max_referrals_per_user,
referral_code_length=request.referral_code_length,
- expiry_days=request.expiry_days
+ expiry_days=request.expiry_days,
)
return {
@@ -12969,7 +12308,7 @@ async def create_referral_program_endpoint(request: CreateReferralProgramRequest
"referrer_reward_value": program.referrer_reward_value,
"referee_reward_type": program.referee_reward_type,
"referee_reward_value": program.referee_reward_value,
- "is_active": program.is_active
+ "is_active": program.is_active,
}
@@ -12991,7 +12330,7 @@ async def generate_referral_code_endpoint(program_id: str, referrer_id: str):
"referral_code": referral.referral_code,
"referrer_id": referral.referrer_id,
"status": referral.status.value,
- "expires_at": referral.expires_at.isoformat()
+ "expires_at": referral.expires_at.isoformat(),
}
@@ -13042,7 +12381,7 @@ async def create_team_incentive_endpoint(request: CreateTeamIncentiveRequest):
incentive_type=request.incentive_type,
incentive_value=request.incentive_value,
valid_from=datetime.fromisoformat(request.valid_from),
- valid_until=datetime.fromisoformat(request.valid_until)
+ valid_until=datetime.fromisoformat(request.valid_until),
)
return {
@@ -13053,16 +12392,12 @@ async def create_team_incentive_endpoint(request: CreateTeamIncentiveRequest):
"incentive_type": incentive.incentive_type,
"incentive_value": incentive.incentive_value,
"valid_from": incentive.valid_from.isoformat(),
- "valid_until": incentive.valid_until.isoformat()
+ "valid_until": incentive.valid_until.isoformat(),
}
@app.get("/api/v1/team-incentives/check", tags=["Growth & Analytics"])
-async def check_team_incentive_eligibility(
- tenant_id: str,
- current_tier: str,
- team_size: int
-):
+async def check_team_incentive_eligibility(tenant_id: str, current_tier: str, team_size: int):
"""检查团队激励资格"""
if not GROWTH_MANAGER_AVAILABLE:
raise HTTPException(status_code=503, detail="Growth manager not available")
@@ -13073,12 +12408,7 @@ async def check_team_incentive_eligibility(
return {
"eligible_incentives": [
- {
- "id": i.id,
- "name": i.name,
- "incentive_type": i.incentive_type,
- "incentive_value": i.incentive_value
- }
+ {"id": i.id, "name": i.name, "incentive_type": i.incentive_type, "incentive_value": i.incentive_value}
for i in incentives
]
}
@@ -13101,6 +12431,7 @@ try:
TemplateCategory,
TemplateStatus,
)
+
DEVELOPER_ECOSYSTEM_AVAILABLE = True
except ImportError as e:
print(f"Developer Ecosystem Manager import error: {e}")
@@ -13253,10 +12584,10 @@ def get_developer_ecosystem_manager_instance():
# ==================== SDK Release & Management API ====================
+
@app.post("/api/v1/developer/sdks", tags=["Developer Ecosystem"])
async def create_sdk_release_endpoint(
- request: SDKReleaseCreate,
- created_by: str = Header(default="system", description="创建者ID")
+ request: SDKReleaseCreate, created_by: str = Header(default="system", description="创建者ID")
):
"""创建 SDK 发布"""
if not DEVELOPER_ECOSYSTEM_AVAILABLE:
@@ -13279,7 +12610,7 @@ async def create_sdk_release_endpoint(
dependencies=request.dependencies,
file_size=request.file_size,
checksum=request.checksum,
- created_by=created_by
+ created_by=created_by,
)
return {
@@ -13289,7 +12620,7 @@ async def create_sdk_release_endpoint(
"version": sdk.version,
"status": sdk.status.value,
"package_name": sdk.package_name,
- "created_at": sdk.created_at
+ "created_at": sdk.created_at,
}
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@@ -13299,7 +12630,7 @@ async def create_sdk_release_endpoint(
async def list_sdk_releases_endpoint(
language: str | None = Query(default=None, description="SDK语言过滤"),
status: str | None = Query(default=None, description="状态过滤"),
- search: str | None = Query(default=None, description="搜索关键词")
+ search: str | None = Query(default=None, description="搜索关键词"),
):
"""列出 SDK 发布"""
if not DEVELOPER_ECOSYSTEM_AVAILABLE:
@@ -13323,7 +12654,7 @@ async def list_sdk_releases_endpoint(
"package_name": s.package_name,
"status": s.status.value,
"download_count": s.download_count,
- "created_at": s.created_at
+ "created_at": s.created_at,
}
for s in sdks
]
@@ -13360,7 +12691,7 @@ async def get_sdk_release_endpoint(sdk_id: str):
"checksum": sdk.checksum,
"download_count": sdk.download_count,
"created_at": sdk.created_at,
- "published_at": sdk.published_at
+ "published_at": sdk.published_at,
}
@@ -13378,12 +12709,7 @@ async def update_sdk_release_endpoint(sdk_id: str, request: SDKReleaseUpdate):
if not sdk:
raise HTTPException(status_code=404, detail="SDK not found")
- return {
- "id": sdk.id,
- "name": sdk.name,
- "status": sdk.status.value,
- "updated_at": sdk.updated_at
- }
+ return {"id": sdk.id, "name": sdk.name, "status": sdk.status.value, "updated_at": sdk.updated_at}
@app.post("/api/v1/developer/sdks/{sdk_id}/publish", tags=["Developer Ecosystem"])
@@ -13398,11 +12724,7 @@ async def publish_sdk_release_endpoint(sdk_id: str):
if not sdk:
raise HTTPException(status_code=404, detail="SDK not found")
- return {
- "id": sdk.id,
- "status": sdk.status.value,
- "published_at": sdk.published_at
- }
+ return {"id": sdk.id, "status": sdk.status.value, "published_at": sdk.published_at}
@app.post("/api/v1/developer/sdks/{sdk_id}/download", tags=["Developer Ecosystem"])
@@ -13434,7 +12756,7 @@ async def get_sdk_versions_endpoint(sdk_id: str):
"is_latest": v.is_latest,
"is_lts": v.is_lts,
"download_count": v.download_count,
- "created_at": v.created_at
+ "created_at": v.created_at,
}
for v in versions
]
@@ -13456,7 +12778,7 @@ async def add_sdk_version_endpoint(sdk_id: str, request: SDKVersionCreate):
release_notes=request.release_notes,
download_url=request.download_url,
checksum=request.checksum,
- file_size=request.file_size
+ file_size=request.file_size,
)
return {
@@ -13464,17 +12786,18 @@ async def add_sdk_version_endpoint(sdk_id: str, request: SDKVersionCreate):
"version": version.version,
"is_latest": version.is_latest,
"is_lts": version.is_lts,
- "created_at": version.created_at
+ "created_at": version.created_at,
}
# ==================== Template Market API ====================
+
@app.post("/api/v1/developer/templates", tags=["Developer Ecosystem"])
async def create_template_endpoint(
request: TemplateCreate,
author_id: str = Header(default="system", description="作者ID"),
- author_name: str = Header(default="System", description="作者名称")
+ author_name: str = Header(default="System", description="作者名称"),
):
"""创建模板"""
if not DEVELOPER_ECOSYSTEM_AVAILABLE:
@@ -13500,7 +12823,7 @@ async def create_template_endpoint(
version=request.version,
min_platform_version=request.min_platform_version,
file_size=request.file_size,
- checksum=request.checksum
+ checksum=request.checksum,
)
return {
@@ -13509,7 +12832,7 @@ async def create_template_endpoint(
"category": template.category.value,
"status": template.status.value,
"price": template.price,
- "created_at": template.created_at
+ "created_at": template.created_at,
}
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@@ -13523,7 +12846,7 @@ async def list_templates_endpoint(
author_id: str | None = Query(default=None, description="作者ID过滤"),
min_price: float | None = Query(default=None, description="最低价格"),
max_price: float | None = Query(default=None, description="最高价格"),
- sort_by: str = Query(default="created_at", description="排序方式")
+ sort_by: str = Query(default="created_at", description="排序方式"),
):
"""列出模板"""
if not DEVELOPER_ECOSYSTEM_AVAILABLE:
@@ -13541,7 +12864,7 @@ async def list_templates_endpoint(
author_id=author_id,
min_price=min_price,
max_price=max_price,
- sort_by=sort_by
+ sort_by=sort_by,
)
return {
@@ -13558,7 +12881,7 @@ async def list_templates_endpoint(
"rating": t.rating,
"install_count": t.install_count,
"version": t.version,
- "created_at": t.created_at
+ "created_at": t.created_at,
}
for t in templates
]
@@ -13598,7 +12921,7 @@ async def get_template_endpoint(template_id: str):
"rating_count": template.rating_count,
"review_count": template.review_count,
"version": template.version,
- "created_at": template.created_at
+ "created_at": template.created_at,
}
@@ -13614,10 +12937,7 @@ async def approve_template_endpoint(template_id: str, reviewed_by: str = Header(
if not template:
raise HTTPException(status_code=404, detail="Template not found")
- return {
- "id": template.id,
- "status": template.status.value
- }
+ return {"id": template.id, "status": template.status.value}
@app.post("/api/v1/developer/templates/{template_id}/publish", tags=["Developer Ecosystem"])
@@ -13632,11 +12952,7 @@ async def publish_template_endpoint(template_id: str):
if not template:
raise HTTPException(status_code=404, detail="Template not found")
- return {
- "id": template.id,
- "status": template.status.value,
- "published_at": template.published_at
- }
+ return {"id": template.id, "status": template.status.value, "published_at": template.published_at}
@app.post("/api/v1/developer/templates/{template_id}/reject", tags=["Developer Ecosystem"])
@@ -13651,10 +12967,7 @@ async def reject_template_endpoint(template_id: str, reason: str = ""):
if not template:
raise HTTPException(status_code=404, detail="Template not found")
- return {
- "id": template.id,
- "status": template.status.value
- }
+ return {"id": template.id, "status": template.status.value}
@app.post("/api/v1/developer/templates/{template_id}/install", tags=["Developer Ecosystem"])
@@ -13674,7 +12987,7 @@ async def add_template_review_endpoint(
template_id: str,
request: TemplateReviewCreate,
user_id: str = Header(default="user", description="用户ID"),
- user_name: str = Header(default="User", description="用户名称")
+ user_name: str = Header(default="User", description="用户名称"),
):
"""添加模板评价"""
if not DEVELOPER_ECOSYSTEM_AVAILABLE:
@@ -13688,22 +13001,14 @@ async def add_template_review_endpoint(
user_name=user_name,
rating=request.rating,
comment=request.comment,
- is_verified_purchase=request.is_verified_purchase
+ is_verified_purchase=request.is_verified_purchase,
)
- return {
- "id": review.id,
- "rating": review.rating,
- "comment": review.comment,
- "created_at": review.created_at
- }
+ return {"id": review.id, "rating": review.rating, "comment": review.comment, "created_at": review.created_at}
@app.get("/api/v1/developer/templates/{template_id}/reviews", tags=["Developer Ecosystem"])
-async def get_template_reviews_endpoint(
- template_id: str,
- limit: int = Query(default=50, description="返回数量限制")
-):
+async def get_template_reviews_endpoint(template_id: str, limit: int = Query(default=50, description="返回数量限制")):
"""获取模板评价"""
if not DEVELOPER_ECOSYSTEM_AVAILABLE:
raise HTTPException(status_code=503, detail="Developer ecosystem manager not available")
@@ -13720,7 +13025,7 @@ async def get_template_reviews_endpoint(
"comment": r.comment,
"is_verified_purchase": r.is_verified_purchase,
"helpful_count": r.helpful_count,
- "created_at": r.created_at
+ "created_at": r.created_at,
}
for r in reviews
]
@@ -13729,11 +13034,12 @@ async def get_template_reviews_endpoint(
# ==================== Plugin Market API ====================
+
@app.post("/api/v1/developer/plugins", tags=["Developer Ecosystem"])
async def create_developer_plugin_endpoint(
request: PluginCreate,
author_id: str = Header(default="system", description="作者ID"),
- author_name: str = Header(default="System", description="作者名称")
+ author_name: str = Header(default="System", description="作者名称"),
):
"""创建插件"""
if not DEVELOPER_ECOSYSTEM_AVAILABLE:
@@ -13762,7 +13068,7 @@ async def create_developer_plugin_endpoint(
version=request.version,
min_platform_version=request.min_platform_version,
file_size=request.file_size,
- checksum=request.checksum
+ checksum=request.checksum,
)
return {
@@ -13772,7 +13078,7 @@ async def create_developer_plugin_endpoint(
"status": plugin.status.value,
"price": plugin.price,
"pricing_model": plugin.pricing_model,
- "created_at": plugin.created_at
+ "created_at": plugin.created_at,
}
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@@ -13784,7 +13090,7 @@ async def list_developer_plugins_endpoint(
status: str | None = Query(default=None, description="状态过滤"),
search: str | None = Query(default=None, description="搜索关键词"),
author_id: str | None = Query(default=None, description="作者ID过滤"),
- sort_by: str = Query(default="created_at", description="排序方式")
+ sort_by: str = Query(default="created_at", description="排序方式"),
):
"""列出插件"""
if not DEVELOPER_ECOSYSTEM_AVAILABLE:
@@ -13796,11 +13102,7 @@ async def list_developer_plugins_endpoint(
status_enum = PluginStatus(status) if status else None
plugins = manager.list_plugins(
- category=category_enum,
- status=status_enum,
- search=search,
- author_id=author_id,
- sort_by=sort_by
+ category=category_enum, status=status_enum, search=search, author_id=author_id, sort_by=sort_by
)
return {
@@ -13818,7 +13120,7 @@ async def list_developer_plugins_endpoint(
"install_count": p.install_count,
"active_install_count": p.active_install_count,
"version": p.version,
- "created_at": p.created_at
+ "created_at": p.created_at,
}
for p in plugins
]
@@ -13861,7 +13163,7 @@ async def get_developer_plugin_endpoint(plugin_id: str):
"version": plugin.version,
"reviewed_by": plugin.reviewed_by,
"reviewed_at": plugin.reviewed_at,
- "created_at": plugin.created_at
+ "created_at": plugin.created_at,
}
@@ -13870,7 +13172,7 @@ async def review_plugin_endpoint(
plugin_id: str,
status: str = Query(..., description="审核状态: approved/rejected"),
reviewed_by: str = Header(default="system", description="审核人ID"),
- notes: str = ""
+ notes: str = "",
):
"""审核插件"""
if not DEVELOPER_ECOSYSTEM_AVAILABLE:
@@ -13889,7 +13191,7 @@ async def review_plugin_endpoint(
"id": plugin.id,
"status": plugin.status.value,
"reviewed_by": plugin.reviewed_by,
- "reviewed_at": plugin.reviewed_at
+ "reviewed_at": plugin.reviewed_at,
}
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@@ -13907,11 +13209,7 @@ async def publish_plugin_endpoint(plugin_id: str):
if not plugin:
raise HTTPException(status_code=404, detail="Plugin not found")
- return {
- "id": plugin.id,
- "status": plugin.status.value,
- "published_at": plugin.published_at
- }
+ return {"id": plugin.id, "status": plugin.status.value, "published_at": plugin.published_at}
@app.post("/api/v1/developer/plugins/{plugin_id}/install", tags=["Developer Ecosystem"])
@@ -13931,7 +13229,7 @@ async def add_plugin_review_endpoint(
plugin_id: str,
request: PluginReviewCreate,
user_id: str = Header(default="user", description="用户ID"),
- user_name: str = Header(default="User", description="用户名称")
+ user_name: str = Header(default="User", description="用户名称"),
):
"""添加插件评价"""
if not DEVELOPER_ECOSYSTEM_AVAILABLE:
@@ -13945,22 +13243,14 @@ async def add_plugin_review_endpoint(
user_name=user_name,
rating=request.rating,
comment=request.comment,
- is_verified_purchase=request.is_verified_purchase
+ is_verified_purchase=request.is_verified_purchase,
)
- return {
- "id": review.id,
- "rating": review.rating,
- "comment": review.comment,
- "created_at": review.created_at
- }
+ return {"id": review.id, "rating": review.rating, "comment": review.comment, "created_at": review.created_at}
@app.get("/api/v1/developer/plugins/{plugin_id}/reviews", tags=["Developer Ecosystem"])
-async def get_plugin_reviews_endpoint(
- plugin_id: str,
- limit: int = Query(default=50, description="返回数量限制")
-):
+async def get_plugin_reviews_endpoint(plugin_id: str, limit: int = Query(default=50, description="返回数量限制")):
"""获取插件评价"""
if not DEVELOPER_ECOSYSTEM_AVAILABLE:
raise HTTPException(status_code=503, detail="Developer ecosystem manager not available")
@@ -13977,7 +13267,7 @@ async def get_plugin_reviews_endpoint(
"comment": r.comment,
"is_verified_purchase": r.is_verified_purchase,
"helpful_count": r.helpful_count,
- "created_at": r.created_at
+ "created_at": r.created_at,
}
for r in reviews
]
@@ -13986,11 +13276,12 @@ async def get_plugin_reviews_endpoint(
# ==================== Developer Revenue Sharing API ====================
+
@app.get("/api/v1/developer/revenues/{developer_id}", tags=["Developer Ecosystem"])
async def get_developer_revenues_endpoint(
developer_id: str,
start_date: str | None = Query(default=None, description="开始日期 (ISO格式)"),
- end_date: str | None = Query(default=None, description="结束日期 (ISO格式)")
+ end_date: str | None = Query(default=None, description="结束日期 (ISO格式)"),
):
"""获取开发者收益记录"""
if not DEVELOPER_ECOSYSTEM_AVAILABLE:
@@ -14013,7 +13304,7 @@ async def get_developer_revenues_endpoint(
"platform_fee": r.platform_fee,
"developer_earnings": r.developer_earnings,
"currency": r.currency,
- "created_at": r.created_at
+ "created_at": r.created_at,
}
for r in revenues
]
@@ -14034,6 +13325,7 @@ async def get_developer_revenue_summary_endpoint(developer_id: str):
# ==================== Developer Profile & Management API ====================
+
@app.post("/api/v1/developer/profiles", tags=["Developer Ecosystem"])
async def create_developer_profile_endpoint(request: DeveloperProfileCreate):
"""创建开发者档案"""
@@ -14051,7 +13343,7 @@ async def create_developer_profile_endpoint(request: DeveloperProfileCreate):
bio=request.bio,
website=request.website,
github_url=request.github_url,
- avatar_url=request.avatar_url
+ avatar_url=request.avatar_url,
)
return {
@@ -14060,7 +13352,7 @@ async def create_developer_profile_endpoint(request: DeveloperProfileCreate):
"display_name": profile.display_name,
"email": profile.email,
"status": profile.status.value,
- "created_at": profile.created_at
+ "created_at": profile.created_at,
}
@@ -14092,7 +13384,7 @@ async def get_developer_profile_endpoint(developer_id: str):
"template_count": profile.template_count,
"rating_average": profile.rating_average,
"created_at": profile.created_at,
- "verified_at": profile.verified_at
+ "verified_at": profile.verified_at,
}
@@ -14114,7 +13406,7 @@ async def get_developer_profile_by_user_endpoint(user_id: str):
"display_name": profile.display_name,
"status": profile.status.value,
"total_sales": profile.total_sales,
- "total_downloads": profile.total_downloads
+ "total_downloads": profile.total_downloads,
}
@@ -14129,8 +13421,7 @@ async def update_developer_profile_endpoint(developer_id: str, request: Develope
@app.post("/api/v1/developer/profiles/{developer_id}/verify", tags=["Developer Ecosystem"])
async def verify_developer_endpoint(
- developer_id: str,
- status: str = Query(..., description="认证状态: verified/certified/suspended")
+ developer_id: str, status: str = Query(..., description="认证状态: verified/certified/suspended")
):
"""验证开发者"""
if not DEVELOPER_ECOSYSTEM_AVAILABLE:
@@ -14145,11 +13436,7 @@ async def verify_developer_endpoint(
if not profile:
raise HTTPException(status_code=404, detail="Developer profile not found")
- return {
- "id": profile.id,
- "status": profile.status.value,
- "verified_at": profile.verified_at
- }
+ return {"id": profile.id, "status": profile.status.value, "verified_at": profile.verified_at}
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@@ -14168,11 +13455,12 @@ async def update_developer_stats_endpoint(developer_id: str):
# ==================== Code Examples API ====================
+
@app.post("/api/v1/developer/code-examples", tags=["Developer Ecosystem"])
async def create_code_example_endpoint(
request: CodeExampleCreate,
author_id: str = Header(default="system", description="作者ID"),
- author_name: str = Header(default="System", description="作者名称")
+ author_name: str = Header(default="System", description="作者名称"),
):
"""创建代码示例"""
if not DEVELOPER_ECOSYSTEM_AVAILABLE:
@@ -14191,7 +13479,7 @@ async def create_code_example_endpoint(
author_id=author_id,
author_name=author_name,
sdk_id=request.sdk_id,
- api_endpoints=request.api_endpoints
+ api_endpoints=request.api_endpoints,
)
return {
@@ -14200,7 +13488,7 @@ async def create_code_example_endpoint(
"language": example.language,
"category": example.category,
"tags": example.tags,
- "created_at": example.created_at
+ "created_at": example.created_at,
}
@@ -14209,7 +13497,7 @@ async def list_code_examples_endpoint(
language: str | None = Query(default=None, description="编程语言过滤"),
category: str | None = Query(default=None, description="分类过滤"),
sdk_id: str | None = Query(default=None, description="SDK ID过滤"),
- search: str | None = Query(default=None, description="搜索关键词")
+ search: str | None = Query(default=None, description="搜索关键词"),
):
"""列出代码示例"""
if not DEVELOPER_ECOSYSTEM_AVAILABLE:
@@ -14231,7 +13519,7 @@ async def list_code_examples_endpoint(
"view_count": e.view_count,
"copy_count": e.copy_count,
"rating": e.rating,
- "created_at": e.created_at
+ "created_at": e.created_at,
}
for e in examples
]
@@ -14267,7 +13555,7 @@ async def get_code_example_endpoint(example_id: str):
"view_count": example.view_count,
"copy_count": example.copy_count,
"rating": example.rating,
- "created_at": example.created_at
+ "created_at": example.created_at,
}
@@ -14285,6 +13573,7 @@ async def copy_code_example_endpoint(example_id: str):
# ==================== API Documentation API ====================
+
@app.get("/api/v1/developer/api-docs", tags=["Developer Ecosystem"])
async def get_latest_api_documentation_endpoint():
"""获取最新 API 文档"""
@@ -14302,7 +13591,7 @@ async def get_latest_api_documentation_endpoint():
"version": doc.version,
"changelog": doc.changelog,
"generated_at": doc.generated_at,
- "generated_by": doc.generated_by
+ "generated_by": doc.generated_by,
}
@@ -14326,12 +13615,13 @@ async def get_api_documentation_endpoint(doc_id: str):
"html_content": doc.html_content,
"changelog": doc.changelog,
"generated_at": doc.generated_at,
- "generated_by": doc.generated_by
+ "generated_by": doc.generated_by,
}
# ==================== Developer Portal API ====================
+
@app.post("/api/v1/developer/portal-configs", tags=["Developer Ecosystem"])
async def create_portal_config_endpoint(request: PortalConfigCreate):
"""创建开发者门户配置"""
@@ -14354,7 +13644,7 @@ async def create_portal_config_endpoint(request: PortalConfigCreate):
support_url=request.support_url,
github_url=request.github_url,
discord_url=request.discord_url,
- api_base_url=request.api_base_url
+ api_base_url=request.api_base_url,
)
return {
@@ -14362,7 +13652,7 @@ async def create_portal_config_endpoint(request: PortalConfigCreate):
"name": config.name,
"theme": config.theme,
"is_active": config.is_active,
- "created_at": config.created_at
+ "created_at": config.created_at,
}
@@ -14392,7 +13682,7 @@ async def get_active_portal_config_endpoint():
"github_url": config.github_url,
"discord_url": config.discord_url,
"api_base_url": config.api_base_url,
- "is_active": config.is_active
+ "is_active": config.is_active,
}
@@ -14417,7 +13707,7 @@ async def get_portal_config_endpoint(config_id: str):
"secondary_color": config.secondary_color,
"support_email": config.support_email,
"api_base_url": config.api_base_url,
- "is_active": config.is_active
+ "is_active": config.is_active,
}
@@ -14471,8 +13761,9 @@ class AlertRuleResponse(BaseModel):
class AlertChannelCreate(BaseModel):
name: str = Field(..., description="渠道名称")
- channel_type: str = Field(...,
- description="渠道类型: pagerduty, opsgenie, feishu, dingtalk, slack, email, sms, webhook")
+ channel_type: str = Field(
+ ..., description="渠道类型: pagerduty, opsgenie, feishu, dingtalk, slack, email, sms, webhook"
+ )
config: dict = Field(default_factory=dict, description="渠道特定配置")
severity_filter: list[str] = Field(default_factory=lambda: ["p0", "p1", "p2", "p3"], description="过滤的告警级别")
@@ -14558,10 +13849,7 @@ class BackupJobCreate(BaseModel):
# Alert Rules API
@app.post("/api/v1/ops/alert-rules", response_model=AlertRuleResponse, tags=["Operations & Monitoring"])
async def create_alert_rule_endpoint(
- tenant_id: str,
- request: AlertRuleCreate,
- user_id: str = "system",
- _=Depends(verify_api_key)
+ tenant_id: str, request: AlertRuleCreate, user_id: str = "system", _=Depends(verify_api_key)
):
"""创建告警规则"""
if not OPS_MANAGER_AVAILABLE:
@@ -14584,7 +13872,7 @@ async def create_alert_rule_endpoint(
channels=request.channels,
labels=request.labels,
annotations=request.annotations,
- created_by=user_id
+ created_by=user_id,
)
return AlertRuleResponse(
@@ -14603,18 +13891,14 @@ async def create_alert_rule_endpoint(
annotations=rule.annotations,
is_enabled=rule.is_enabled,
created_at=rule.created_at,
- updated_at=rule.updated_at
+ updated_at=rule.updated_at,
)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@app.get("/api/v1/ops/alert-rules", tags=["Operations & Monitoring"])
-async def list_alert_rules_endpoint(
- tenant_id: str,
- is_enabled: bool | None = None,
- _=Depends(verify_api_key)
-):
+async def list_alert_rules_endpoint(tenant_id: str, is_enabled: bool | None = None, _=Depends(verify_api_key)):
"""列出租户的告警规则"""
if not OPS_MANAGER_AVAILABLE:
raise HTTPException(status_code=503, detail="Operations manager not available")
@@ -14639,7 +13923,7 @@ async def list_alert_rules_endpoint(
annotations=rule.annotations,
is_enabled=rule.is_enabled,
created_at=rule.created_at,
- updated_at=rule.updated_at
+ updated_at=rule.updated_at,
)
for rule in rules
]
@@ -14673,16 +13957,12 @@ async def get_alert_rule_endpoint(rule_id: str, _=Depends(verify_api_key)):
annotations=rule.annotations,
is_enabled=rule.is_enabled,
created_at=rule.created_at,
- updated_at=rule.updated_at
+ updated_at=rule.updated_at,
)
@app.patch("/api/v1/ops/alert-rules/{rule_id}", response_model=AlertRuleResponse, tags=["Operations & Monitoring"])
-async def update_alert_rule_endpoint(
- rule_id: str,
- updates: dict,
- _=Depends(verify_api_key)
-):
+async def update_alert_rule_endpoint(rule_id: str, updates: dict, _=Depends(verify_api_key)):
"""更新告警规则"""
if not OPS_MANAGER_AVAILABLE:
raise HTTPException(status_code=503, detail="Operations manager not available")
@@ -14709,7 +13989,7 @@ async def update_alert_rule_endpoint(
annotations=rule.annotations,
is_enabled=rule.is_enabled,
created_at=rule.created_at,
- updated_at=rule.updated_at
+ updated_at=rule.updated_at,
)
@@ -14730,11 +14010,7 @@ async def delete_alert_rule_endpoint(rule_id: str, _=Depends(verify_api_key)):
# Alert Channels API
@app.post("/api/v1/ops/alert-channels", response_model=AlertChannelResponse, tags=["Operations & Monitoring"])
-async def create_alert_channel_endpoint(
- tenant_id: str,
- request: AlertChannelCreate,
- _=Depends(verify_api_key)
-):
+async def create_alert_channel_endpoint(tenant_id: str, request: AlertChannelCreate, _=Depends(verify_api_key)):
"""创建告警渠道"""
if not OPS_MANAGER_AVAILABLE:
raise HTTPException(status_code=503, detail="Operations manager not available")
@@ -14747,7 +14023,7 @@ async def create_alert_channel_endpoint(
name=request.name,
channel_type=AlertChannelType(request.channel_type),
config=request.config,
- severity_filter=request.severity_filter
+ severity_filter=request.severity_filter,
)
return AlertChannelResponse(
@@ -14760,7 +14036,7 @@ async def create_alert_channel_endpoint(
success_count=channel.success_count,
fail_count=channel.fail_count,
last_used_at=channel.last_used_at,
- created_at=channel.created_at
+ created_at=channel.created_at,
)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@@ -14786,7 +14062,7 @@ async def list_alert_channels_endpoint(tenant_id: str, _=Depends(verify_api_key)
success_count=channel.success_count,
fail_count=channel.fail_count,
last_used_at=channel.last_used_at,
- created_at=channel.created_at
+ created_at=channel.created_at,
)
for channel in channels
]
@@ -14810,11 +14086,7 @@ async def test_alert_channel_endpoint(channel_id: str, _=Depends(verify_api_key)
# Alerts API
@app.get("/api/v1/ops/alerts", tags=["Operations & Monitoring"])
async def list_alerts_endpoint(
- tenant_id: str,
- status: str | None = None,
- severity: str | None = None,
- limit: int = 100,
- _=Depends(verify_api_key)
+ tenant_id: str, status: str | None = None, severity: str | None = None, limit: int = 100, _=Depends(verify_api_key)
):
"""列出租户的告警"""
if not OPS_MANAGER_AVAILABLE:
@@ -14842,18 +14114,14 @@ async def list_alerts_endpoint(
started_at=alert.started_at,
resolved_at=alert.resolved_at,
acknowledged_by=alert.acknowledged_by,
- suppression_count=alert.suppression_count
+ suppression_count=alert.suppression_count,
)
for alert in alerts
]
@app.post("/api/v1/ops/alerts/{alert_id}/acknowledge", tags=["Operations & Monitoring"])
-async def acknowledge_alert_endpoint(
- alert_id: str,
- user_id: str = "system",
- _=Depends(verify_api_key)
-):
+async def acknowledge_alert_endpoint(alert_id: str, user_id: str = "system", _=Depends(verify_api_key)):
"""确认告警"""
if not OPS_MANAGER_AVAILABLE:
raise HTTPException(status_code=503, detail="Operations manager not available")
@@ -14892,7 +14160,7 @@ async def record_resource_metric_endpoint(
metric_value: float,
unit: str,
metadata: dict = None,
- _=Depends(verify_api_key)
+ _=Depends(verify_api_key),
):
"""记录资源指标"""
if not OPS_MANAGER_AVAILABLE:
@@ -14908,7 +14176,7 @@ async def record_resource_metric_endpoint(
metric_name=metric_name,
metric_value=metric_value,
unit=unit,
- metadata=metadata
+ metadata=metadata,
)
return {
@@ -14917,7 +14185,7 @@ async def record_resource_metric_endpoint(
"metric_name": metric.metric_name,
"metric_value": metric.metric_value,
"unit": metric.unit,
- "timestamp": metric.timestamp
+ "timestamp": metric.timestamp,
}
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@@ -14925,10 +14193,7 @@ async def record_resource_metric_endpoint(
@app.get("/api/v1/ops/resource-metrics", tags=["Operations & Monitoring"])
async def get_resource_metrics_endpoint(
- tenant_id: str,
- metric_name: str,
- seconds: int = 3600,
- _=Depends(verify_api_key)
+ tenant_id: str, metric_name: str, seconds: int = 3600, _=Depends(verify_api_key)
):
"""获取资源指标数据"""
if not OPS_MANAGER_AVAILABLE:
@@ -14945,7 +14210,7 @@ async def get_resource_metrics_endpoint(
"metric_name": m.metric_name,
"metric_value": m.metric_value,
"unit": m.unit,
- "timestamp": m.timestamp
+ "timestamp": m.timestamp,
}
for m in metrics
]
@@ -14959,7 +14224,7 @@ async def create_capacity_plan_endpoint(
current_capacity: float,
prediction_date: str,
confidence: float = 0.8,
- _=Depends(verify_api_key)
+ _=Depends(verify_api_key),
):
"""创建容量规划"""
if not OPS_MANAGER_AVAILABLE:
@@ -14973,7 +14238,7 @@ async def create_capacity_plan_endpoint(
resource_type=ResourceType(resource_type),
current_capacity=current_capacity,
prediction_date=prediction_date,
- confidence=confidence
+ confidence=confidence,
)
return {
@@ -14985,7 +14250,7 @@ async def create_capacity_plan_endpoint(
"confidence": plan.confidence,
"recommended_action": plan.recommended_action,
"estimated_cost": plan.estimated_cost,
- "created_at": plan.created_at
+ "created_at": plan.created_at,
}
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@@ -15010,7 +14275,7 @@ async def list_capacity_plans_endpoint(tenant_id: str, _=Depends(verify_api_key)
"confidence": plan.confidence,
"recommended_action": plan.recommended_action,
"estimated_cost": plan.estimated_cost,
- "created_at": plan.created_at
+ "created_at": plan.created_at,
}
for plan in plans
]
@@ -15019,9 +14284,7 @@ async def list_capacity_plans_endpoint(tenant_id: str, _=Depends(verify_api_key)
# Auto Scaling API
@app.post("/api/v1/ops/auto-scaling-policies", tags=["Operations & Monitoring"])
async def create_auto_scaling_policy_endpoint(
- tenant_id: str,
- request: AutoScalingPolicyCreate,
- _=Depends(verify_api_key)
+ tenant_id: str, request: AutoScalingPolicyCreate, _=Depends(verify_api_key)
):
"""创建自动扩缩容策略"""
if not OPS_MANAGER_AVAILABLE:
@@ -15041,7 +14304,7 @@ async def create_auto_scaling_policy_endpoint(
scale_down_threshold=request.scale_down_threshold,
scale_up_step=request.scale_up_step,
scale_down_step=request.scale_down_step,
- cooldown_period=request.cooldown_period
+ cooldown_period=request.cooldown_period,
)
return {
@@ -15054,7 +14317,7 @@ async def create_auto_scaling_policy_endpoint(
"scale_up_threshold": policy.scale_up_threshold,
"scale_down_threshold": policy.scale_down_threshold,
"is_enabled": policy.is_enabled,
- "created_at": policy.created_at
+ "created_at": policy.created_at,
}
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@@ -15078,7 +14341,7 @@ async def list_auto_scaling_policies_endpoint(tenant_id: str, _=Depends(verify_a
"max_instances": policy.max_instances,
"target_utilization": policy.target_utilization,
"is_enabled": policy.is_enabled,
- "created_at": policy.created_at
+ "created_at": policy.created_at,
}
for policy in policies
]
@@ -15086,10 +14349,7 @@ async def list_auto_scaling_policies_endpoint(tenant_id: str, _=Depends(verify_a
@app.get("/api/v1/ops/scaling-events", tags=["Operations & Monitoring"])
async def list_scaling_events_endpoint(
- tenant_id: str,
- policy_id: str | None = None,
- limit: int = 100,
- _=Depends(verify_api_key)
+ tenant_id: str, policy_id: str | None = None, limit: int = 100, _=Depends(verify_api_key)
):
"""获取扩缩容事件列表"""
if not OPS_MANAGER_AVAILABLE:
@@ -15108,7 +14368,7 @@ async def list_scaling_events_endpoint(
"reason": event.reason,
"status": event.status,
"started_at": event.started_at,
- "completed_at": event.completed_at
+ "completed_at": event.completed_at,
}
for event in events
]
@@ -15116,11 +14376,7 @@ async def list_scaling_events_endpoint(
# Health Check API
@app.post("/api/v1/ops/health-checks", response_model=HealthCheckResponse, tags=["Operations & Monitoring"])
-async def create_health_check_endpoint(
- tenant_id: str,
- request: HealthCheckCreate,
- _=Depends(verify_api_key)
-):
+async def create_health_check_endpoint(tenant_id: str, request: HealthCheckCreate, _=Depends(verify_api_key)):
"""创建健康检查"""
if not OPS_MANAGER_AVAILABLE:
raise HTTPException(status_code=503, detail="Operations manager not available")
@@ -15136,7 +14392,7 @@ async def create_health_check_endpoint(
check_config=request.check_config,
interval=request.interval,
timeout=request.timeout,
- retry_count=request.retry_count
+ retry_count=request.retry_count,
)
return HealthCheckResponse(
@@ -15148,7 +14404,7 @@ async def create_health_check_endpoint(
interval=check.interval,
timeout=check.timeout,
is_enabled=check.is_enabled,
- created_at=check.created_at
+ created_at=check.created_at,
)
@@ -15171,7 +14427,7 @@ async def list_health_checks_endpoint(tenant_id: str, _=Depends(verify_api_key))
"interval": check.interval,
"timeout": check.timeout,
"is_enabled": check.is_enabled,
- "created_at": check.created_at
+ "created_at": check.created_at,
}
for check in checks
]
@@ -15192,17 +14448,13 @@ async def execute_health_check_endpoint(check_id: str, _=Depends(verify_api_key)
"status": result.status.value,
"response_time": result.response_time,
"message": result.message,
- "checked_at": result.checked_at
+ "checked_at": result.checked_at,
}
# Backup API
@app.post("/api/v1/ops/backup-jobs", tags=["Operations & Monitoring"])
-async def create_backup_job_endpoint(
- tenant_id: str,
- request: BackupJobCreate,
- _=Depends(verify_api_key)
-):
+async def create_backup_job_endpoint(tenant_id: str, request: BackupJobCreate, _=Depends(verify_api_key)):
"""创建备份任务"""
if not OPS_MANAGER_AVAILABLE:
raise HTTPException(status_code=503, detail="Operations manager not available")
@@ -15219,7 +14471,7 @@ async def create_backup_job_endpoint(
retention_days=request.retention_days,
encryption_enabled=request.encryption_enabled,
compression_enabled=request.compression_enabled,
- storage_location=request.storage_location
+ storage_location=request.storage_location,
)
return {
@@ -15229,7 +14481,7 @@ async def create_backup_job_endpoint(
"target_type": job.target_type,
"schedule": job.schedule,
"is_enabled": job.is_enabled,
- "created_at": job.created_at
+ "created_at": job.created_at,
}
@@ -15250,7 +14502,7 @@ async def list_backup_jobs_endpoint(tenant_id: str, _=Depends(verify_api_key)):
"target_type": job.target_type,
"schedule": job.schedule,
"is_enabled": job.is_enabled,
- "created_at": job.created_at
+ "created_at": job.created_at,
}
for job in jobs
]
@@ -15273,16 +14525,13 @@ async def execute_backup_endpoint(job_id: str, _=Depends(verify_api_key)):
"job_id": record.job_id,
"status": record.status.value,
"started_at": record.started_at,
- "storage_path": record.storage_path
+ "storage_path": record.storage_path,
}
@app.get("/api/v1/ops/backup-records", tags=["Operations & Monitoring"])
async def list_backup_records_endpoint(
- tenant_id: str,
- job_id: str | None = None,
- limit: int = 100,
- _=Depends(verify_api_key)
+ tenant_id: str, job_id: str | None = None, limit: int = 100, _=Depends(verify_api_key)
):
"""获取备份记录列表"""
if not OPS_MANAGER_AVAILABLE:
@@ -15300,7 +14549,7 @@ async def list_backup_records_endpoint(
"checksum": record.checksum,
"started_at": record.started_at,
"completed_at": record.completed_at,
- "storage_path": record.storage_path
+ "storage_path": record.storage_path,
}
for record in records
]
@@ -15308,12 +14557,7 @@ async def list_backup_records_endpoint(
# Cost Optimization API
@app.post("/api/v1/ops/cost-reports", tags=["Operations & Monitoring"])
-async def generate_cost_report_endpoint(
- tenant_id: str,
- year: int,
- month: int,
- _=Depends(verify_api_key)
-):
+async def generate_cost_report_endpoint(tenant_id: str, year: int, month: int, _=Depends(verify_api_key)):
"""生成成本报告"""
if not OPS_MANAGER_AVAILABLE:
raise HTTPException(status_code=503, detail="Operations manager not available")
@@ -15329,7 +14573,7 @@ async def generate_cost_report_endpoint(
"breakdown": report.breakdown,
"trends": report.trends,
"anomalies": report.anomalies,
- "created_at": report.created_at
+ "created_at": report.created_at,
}
@@ -15352,17 +14596,14 @@ async def get_idle_resources_endpoint(tenant_id: str, _=Depends(verify_api_key))
"estimated_monthly_cost": resource.estimated_monthly_cost,
"currency": resource.currency,
"reason": resource.reason,
- "recommendation": resource.recommendation
+ "recommendation": resource.recommendation,
}
for resource in idle_resources
]
@app.post("/api/v1/ops/cost-optimization-suggestions", tags=["Operations & Monitoring"])
-async def generate_cost_optimization_suggestions_endpoint(
- tenant_id: str,
- _=Depends(verify_api_key)
-):
+async def generate_cost_optimization_suggestions_endpoint(tenant_id: str, _=Depends(verify_api_key)):
"""生成成本优化建议"""
if not OPS_MANAGER_AVAILABLE:
raise HTTPException(status_code=503, detail="Operations manager not available")
@@ -15382,7 +14623,7 @@ async def generate_cost_optimization_suggestions_endpoint(
"difficulty": suggestion.difficulty,
"risk_level": suggestion.risk_level,
"is_applied": suggestion.is_applied,
- "created_at": suggestion.created_at
+ "created_at": suggestion.created_at,
}
for suggestion in suggestions
]
@@ -15390,9 +14631,7 @@ async def generate_cost_optimization_suggestions_endpoint(
@app.get("/api/v1/ops/cost-optimization-suggestions", tags=["Operations & Monitoring"])
async def list_cost_optimization_suggestions_endpoint(
- tenant_id: str,
- is_applied: bool | None = None,
- _=Depends(verify_api_key)
+ tenant_id: str, is_applied: bool | None = None, _=Depends(verify_api_key)
):
"""获取成本优化建议列表"""
if not OPS_MANAGER_AVAILABLE:
@@ -15412,17 +14651,14 @@ async def list_cost_optimization_suggestions_endpoint(
"difficulty": suggestion.difficulty,
"risk_level": suggestion.risk_level,
"is_applied": suggestion.is_applied,
- "created_at": suggestion.created_at
+ "created_at": suggestion.created_at,
}
for suggestion in suggestions
]
@app.post("/api/v1/ops/cost-optimization-suggestions/{suggestion_id}/apply", tags=["Operations & Monitoring"])
-async def apply_cost_optimization_suggestion_endpoint(
- suggestion_id: str,
- _=Depends(verify_api_key)
-):
+async def apply_cost_optimization_suggestion_endpoint(suggestion_id: str, _=Depends(verify_api_key)):
"""应用成本优化建议"""
if not OPS_MANAGER_AVAILABLE:
raise HTTPException(status_code=503, detail="Operations manager not available")
@@ -15440,11 +14676,10 @@ async def apply_cost_optimization_suggestion_endpoint(
"id": suggestion.id,
"title": suggestion.title,
"is_applied": suggestion.is_applied,
- "applied_at": suggestion.applied_at
- }
+ "applied_at": suggestion.applied_at,
+ },
}
if __name__ == "__main__":
- import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)
diff --git a/backend/multimodal_entity_linker.py b/backend/multimodal_entity_linker.py
index 141f779..1ce7f43 100644
--- a/backend/multimodal_entity_linker.py
+++ b/backend/multimodal_entity_linker.py
@@ -80,7 +80,7 @@ class MultimodalEntityLinker:
# 模态类型
MODALITIES = ["audio", "video", "image", "document"]
- def __init__(self, similarity_threshold: float = 0.85):
+ def __init__(self, similarity_threshold: float = 0.85) -> None:
"""
初始化多模态实体关联器
diff --git a/backend/multimodal_processor.py b/backend/multimodal_processor.py
index 868198e..acd5a03 100644
--- a/backend/multimodal_processor.py
+++ b/backend/multimodal_processor.py
@@ -94,7 +94,7 @@ class VideoProcessingResult:
class MultimodalProcessor:
"""多模态处理器 - 处理视频文件"""
- def __init__(self, temp_dir: str = None, frame_interval: int = 5):
+ def __init__(self, temp_dir: str = None, frame_interval: int = 5) -> None:
"""
初始化多模态处理器
@@ -401,7 +401,7 @@ class MultimodalProcessor:
error_message=str(e),
)
- def cleanup(self, video_id: str = None):
+ def cleanup(self, video_id: str = None) -> None:
"""
清理临时文件
diff --git a/backend/neo4j_manager.py b/backend/neo4j_manager.py
index c39fa99..d0888be 100644
--- a/backend/neo4j_manager.py
+++ b/backend/neo4j_manager.py
@@ -107,7 +107,7 @@ class Neo4jManager:
self._connect()
- def _connect(self):
+ def _connect(self) -> None:
"""建立 Neo4j 连接"""
if not NEO4J_AVAILABLE:
return
@@ -121,7 +121,7 @@ class Neo4jManager:
logger.error(f"Failed to connect to Neo4j: {e}")
self._driver = None
- def close(self):
+ def close(self) -> None:
"""关闭连接"""
if self._driver:
self._driver.close()
@@ -137,7 +137,7 @@ class Neo4jManager:
except BaseException:
return False
- def init_schema(self):
+ def init_schema(self) -> None:
"""初始化图数据库 Schema(约束和索引)"""
if not self._driver:
logger.error("Neo4j not connected")
@@ -178,7 +178,7 @@ class Neo4jManager:
# ==================== 数据同步 ====================
- def sync_project(self, project_id: str, project_name: str, project_description: str = ""):
+ def sync_project(self, project_id: str, project_name: str, project_description: str = "") -> None:
"""同步项目节点到 Neo4j"""
if not self._driver:
return
@@ -196,7 +196,7 @@ class Neo4jManager:
description=project_description,
)
- def sync_entity(self, entity: GraphEntity):
+ def sync_entity(self, entity: GraphEntity) -> None:
"""同步单个实体到 Neo4j"""
if not self._driver:
return
@@ -225,7 +225,7 @@ class Neo4jManager:
properties=json.dumps(entity.properties),
)
- def sync_entities_batch(self, entities: list[GraphEntity]):
+ def sync_entities_batch(self, entities: list[GraphEntity]) -> None:
"""批量同步实体到 Neo4j"""
if not self._driver or not entities:
return
@@ -262,7 +262,7 @@ class Neo4jManager:
entities=entities_data,
)
- def sync_relation(self, relation: GraphRelation):
+ def sync_relation(self, relation: GraphRelation) -> None:
"""同步单个关系到 Neo4j"""
if not self._driver:
return
@@ -286,7 +286,7 @@ class Neo4jManager:
properties=json.dumps(relation.properties),
)
- def sync_relations_batch(self, relations: list[GraphRelation]):
+ def sync_relations_batch(self, relations: list[GraphRelation]) -> None:
"""批量同步关系到 Neo4j"""
if not self._driver or not relations:
return
@@ -318,7 +318,7 @@ class Neo4jManager:
relations=relations_data,
)
- def delete_entity(self, entity_id: str):
+ def delete_entity(self, entity_id: str) -> None:
"""从 Neo4j 删除实体及其关系"""
if not self._driver:
return
@@ -332,7 +332,7 @@ class Neo4jManager:
id=entity_id,
)
- def delete_project(self, project_id: str):
+ def delete_project(self, project_id: str) -> None:
"""从 Neo4j 删除项目及其所有实体和关系"""
if not self._driver:
return
@@ -949,7 +949,7 @@ def get_neo4j_manager() -> Neo4jManager:
return _neo4j_manager
-def close_neo4j_manager():
+def close_neo4j_manager() -> None:
"""关闭 Neo4j 连接"""
global _neo4j_manager
if _neo4j_manager:
@@ -958,7 +958,7 @@ def close_neo4j_manager():
# 便捷函数
-def sync_project_to_neo4j(project_id: str, project_name: str, entities: list[dict], relations: list[dict]):
+def sync_project_to_neo4j(project_id: str, project_name: str, entities: list[dict], relations: list[dict]) -> None:
"""
同步整个项目到 Neo4j
diff --git a/backend/ops_manager.py b/backend/ops_manager.py
index b444b45..064e7c5 100644
--- a/backend/ops_manager.py
+++ b/backend/ops_manager.py
@@ -456,13 +456,13 @@ class OpsManager:
self._evaluator_thread = None
self._register_default_evaluators()
- def _get_db(self):
+ def _get_db(self) -> sqlite3.Connection:
"""获取数据库连接"""
conn = sqlite3.connect(self.db_path)
conn.row_factory = sqlite3.Row
return conn
- def _register_default_evaluators(self):
+ def _register_default_evaluators(self) -> None:
"""注册默认的告警评估器"""
self._alert_evaluators[AlertRuleType.THRESHOLD.value] = self._evaluate_threshold_rule
self._alert_evaluators[AlertRuleType.ANOMALY.value] = self._evaluate_anomaly_rule
@@ -1249,7 +1249,7 @@ class OpsManager:
return self.get_alert(alert_id)
- def _increment_suppression_count(self, alert_id: str):
+ def _increment_suppression_count(self, alert_id: str) -> None:
"""增加告警抑制计数"""
with self._get_db() as conn:
conn.execute(
@@ -1262,7 +1262,7 @@ class OpsManager:
)
conn.commit()
- def _update_alert_notification_status(self, alert_id: str, channel_id: str, success: bool):
+ def _update_alert_notification_status(self, alert_id: str, channel_id: str, success: bool) -> None:
"""更新告警通知状态"""
with self._get_db() as conn:
row = conn.execute("SELECT notification_sent FROM alerts WHERE id = ?", (alert_id,)).fetchone()
@@ -1276,7 +1276,7 @@ class OpsManager:
)
conn.commit()
- def _update_channel_stats(self, channel_id: str, success: bool):
+ def _update_channel_stats(self, channel_id: str, success: bool) -> None:
"""更新渠道统计"""
now = datetime.now().isoformat()
@@ -1769,9 +1769,7 @@ class OpsManager:
return self._row_to_scaling_event(row)
return None
- def update_scaling_event_status(
- self, event_id: str, status: str, error_message: str = None
- ) -> ScalingEvent | None:
+ def update_scaling_event_status(self, event_id: str, status: str, error_message: str = None) -> ScalingEvent | None:
"""更新扩缩容事件状态"""
now = datetime.now().isoformat()
@@ -2339,7 +2337,7 @@ class OpsManager:
return record
- def _complete_backup(self, record_id: str, size_bytes: int, checksum: str = None):
+ def _complete_backup(self, record_id: str, size_bytes: int, checksum: str = None) -> None:
"""完成备份"""
now = datetime.now().isoformat()
checksum = checksum or hashlib.sha256(str(time.time()).encode()).hexdigest()[:16]
diff --git a/backend/oss_uploader.py b/backend/oss_uploader.py
index af72403..83de463 100644
--- a/backend/oss_uploader.py
+++ b/backend/oss_uploader.py
@@ -37,7 +37,7 @@ class OSSUploader:
url = self.bucket.sign_url("GET", object_name, 3600)
return url, object_name
- def delete_object(self, object_name: str):
+ def delete_object(self, object_name: str) -> None:
"""删除 OSS 对象"""
self.bucket.delete_object(object_name)
diff --git a/backend/performance_manager.py b/backend/performance_manager.py
index 3a177ac..ed617d5 100644
--- a/backend/performance_manager.py
+++ b/backend/performance_manager.py
@@ -55,7 +55,7 @@ class CacheStats:
expired: int = 0
hit_rate: float = 0.0
- def update_hit_rate(self):
+ def update_hit_rate(self) -> None:
"""更新命中率"""
if self.total_requests > 0:
self.hit_rate = round(self.hits / self.total_requests, 4)
@@ -194,7 +194,7 @@ class CacheManager:
# 初始化缓存统计表
self._init_cache_tables()
- def _init_cache_tables(self):
+ def _init_cache_tables(self) -> None:
"""初始化缓存统计表"""
conn = sqlite3.connect(self.db_path)
@@ -234,7 +234,7 @@ class CacheManager:
except BaseException:
return 1024 # 默认估算
- def _evict_lru(self, required_space: int = 0):
+ def _evict_lru(self, required_space: int = 0) -> None:
"""LRU 淘汰策略"""
with self.cache_lock:
while self.current_memory_size + required_space > self.max_memory_size and self.memory_cache:
@@ -444,7 +444,7 @@ class CacheManager:
return stats
- def save_stats(self):
+ def save_stats(self) -> None:
"""保存缓存统计到数据库"""
conn = sqlite3.connect(self.db_path)
@@ -618,7 +618,7 @@ class DatabaseSharding:
# 初始化分片
self._init_shards()
- def _init_shards(self):
+ def _init_shards(self) -> None:
"""初始化分片"""
# 计算每个分片的 key 范围
chars = "0123456789abcdef"
@@ -645,7 +645,7 @@ class DatabaseSharding:
if not os.path.exists(db_path):
self._create_shard_db(db_path)
- def _create_shard_db(self, db_path: str):
+ def _create_shard_db(self, db_path: str) -> None:
"""创建分片数据库"""
conn = sqlite3.connect(db_path)
@@ -792,7 +792,7 @@ class DatabaseSharding:
print(f"迁移失败: {e}")
return False
- def _update_shard_stats(self, shard_id: str):
+ def _update_shard_stats(self, shard_id: str) -> None:
"""更新分片统计"""
shard_info = self.shard_map.get(shard_id)
if not shard_info:
@@ -923,7 +923,7 @@ class TaskQueue:
except Exception as e:
print(f"Celery 初始化失败,使用内存任务队列: {e}")
- def _init_task_tables(self):
+ def _init_task_tables(self) -> None:
"""初始化任务队列表"""
conn = sqlite3.connect(self.db_path)
@@ -953,7 +953,7 @@ class TaskQueue:
"""检查任务队列是否可用"""
return self.use_celery or True # 内存模式也可用
- def register_handler(self, task_type: str, handler: Callable):
+ def register_handler(self, task_type: str, handler: Callable) -> None:
"""注册任务处理器"""
self.task_handlers[task_type] = handler
@@ -1014,7 +1014,7 @@ class TaskQueue:
return task_id
- def _execute_task(self, task_id: str):
+ def _execute_task(self, task_id: str) -> None:
"""执行任务(内存模式)"""
with self.task_lock:
task = self.tasks.get(task_id)
@@ -1055,7 +1055,7 @@ class TaskQueue:
self._update_task_status(task)
- def _save_task(self, task: TaskInfo):
+ def _save_task(self, task: TaskInfo) -> None:
"""保存任务到数据库"""
conn = sqlite3.connect(self.db_path)
@@ -1084,7 +1084,7 @@ class TaskQueue:
conn.commit()
conn.close()
- def _update_task_status(self, task: TaskInfo):
+ def _update_task_status(self, task: TaskInfo) -> None:
"""更新任务状态"""
conn = sqlite3.connect(self.db_path)
@@ -1143,9 +1143,7 @@ class TaskQueue:
with self.task_lock:
return self.tasks.get(task_id)
- def list_tasks(
- self, status: str | None = None, task_type: str | None = None, limit: int = 100
- ) -> list[TaskInfo]:
+ def list_tasks(self, status: str | None = None, task_type: str | None = None, limit: int = 100) -> list[TaskInfo]:
"""列出任务"""
conn = sqlite3.connect(self.db_path)
conn.row_factory = sqlite3.Row
@@ -1333,7 +1331,7 @@ class PerformanceMonitor:
if metric_type == "db_query" and duration_ms > self.slow_query_threshold:
self._record_slow_query(metric)
- def _flush_metrics(self):
+ def _flush_metrics(self) -> None:
"""将缓冲区指标写入数据库"""
if not self.metrics_buffer:
return
@@ -1362,12 +1360,12 @@ class PerformanceMonitor:
self.metrics_buffer = []
- def _record_slow_query(self, metric: PerformanceMetric):
+ def _record_slow_query(self, metric: PerformanceMetric) -> None:
"""记录慢查询"""
# 可以发送到专门的慢查询日志或监控系统
print(f"[SLOW QUERY] {metric.endpoint}: {metric.duration_ms}ms")
- def _trigger_alert(self, metric: PerformanceMetric):
+ def _trigger_alert(self, metric: PerformanceMetric) -> None:
"""触发告警"""
alert_data = {
"type": "performance_alert",
@@ -1382,7 +1380,7 @@ class PerformanceMonitor:
except Exception as e:
print(f"告警处理失败: {e}")
- def register_alert_handler(self, handler: Callable):
+ def register_alert_handler(self, handler: Callable) -> None:
"""注册告警处理器"""
self.alert_handlers.append(handler)
@@ -1585,7 +1583,9 @@ class PerformanceMonitor:
# ==================== 性能装饰器 ====================
-def cached(cache_manager: CacheManager, key_prefix: str = "", ttl: int = 3600, key_func: Callable | None = None):
+def cached(
+ cache_manager: CacheManager, key_prefix: str = "", ttl: int = 3600, key_func: Callable | None = None
+) -> Callable:
"""
缓存装饰器
@@ -1598,7 +1598,7 @@ def cached(cache_manager: CacheManager, key_prefix: str = "", ttl: int = 3600, k
def decorator(func: Callable) -> Callable:
@wraps(func)
- def wrapper(*args, **kwargs):
+ def wrapper(*args, **kwargs):
# 生成缓存键
if key_func:
cache_key = key_func(*args, **kwargs)
@@ -1625,7 +1625,7 @@ def cached(cache_manager: CacheManager, key_prefix: str = "", ttl: int = 3600, k
return decorator
-def monitored(monitor: PerformanceMonitor, metric_type: str, endpoint: str | None = None):
+def monitored(monitor: PerformanceMonitor, metric_type: str, endpoint: str | None = None) -> Callable:
"""
性能监控装饰器
diff --git a/backend/plugin_manager.py b/backend/plugin_manager.py
index d330716..a58637f 100644
--- a/backend/plugin_manager.py
+++ b/backend/plugin_manager.py
@@ -163,7 +163,7 @@ class PluginManager:
self._handlers = {}
self._register_default_handlers()
- def _register_default_handlers(self):
+ def _register_default_handlers(self) -> None:
"""注册默认处理器"""
self._handlers[PluginType.CHROME_EXTENSION] = ChromeExtensionHandler(self)
self._handlers[PluginType.FEISHU_BOT] = BotHandler(self, "feishu")
@@ -371,7 +371,7 @@ class PluginManager:
return cursor.rowcount > 0
- def record_plugin_usage(self, plugin_id: str):
+ def record_plugin_usage(self, plugin_id: str) -> None:
"""记录插件使用"""
conn = self.db.get_conn()
now = datetime.now().isoformat()
@@ -826,9 +826,6 @@ class BotHandler:
async def _send_feishu_message(self, session: BotSession, message: str, msg_type: str) -> bool:
"""发送飞书消息"""
- import base64
- import hashlib
-
timestamp = str(int(time.time()))
# 生成签名
@@ -851,9 +848,6 @@ class BotHandler:
async def _send_dingtalk_message(self, session: BotSession, message: str, msg_type: str) -> bool:
"""发送钉钉消息"""
- import base64
- import hashlib
-
timestamp = str(round(time.time() * 1000))
# 生成签名
@@ -1358,7 +1352,7 @@ class WebDAVSyncManager:
_plugin_manager = None
-def get_plugin_manager(db_manager=None):
+def get_plugin_manager(db_manager=None) -> "PluginManager":
"""获取 PluginManager 单例"""
global _plugin_manager
if _plugin_manager is None:
diff --git a/backend/rate_limiter.py b/backend/rate_limiter.py
index 5e01b0b..db9142d 100644
--- a/backend/rate_limiter.py
+++ b/backend/rate_limiter.py
@@ -56,7 +56,7 @@ class SlidingWindowCounter:
self._cleanup_old(now)
return sum(self.requests.values())
- def _cleanup_old(self, now: int):
+ def _cleanup_old(self, now: int) -> None:
"""清理过期的请求记录 - 使用独立锁避免竞态条件"""
cutoff = now - self.window_size
old_keys = [k for k in list(self.requests.keys()) if k < cutoff]
@@ -67,7 +67,7 @@ class SlidingWindowCounter:
class RateLimiter:
"""API 限流器"""
- def __init__(self):
+ def __init__(self) -> None:
# key -> SlidingWindowCounter
self.counters: dict[str, SlidingWindowCounter] = {}
# key -> RateLimitConfig
@@ -143,7 +143,7 @@ class RateLimiter:
retry_after=max(0, config.window_size) if current_count >= config.requests_per_minute else 0,
)
- def reset(self, key: str | None = None):
+ def reset(self, key: str | None = None) -> None:
"""重置限流计数器"""
if key:
self.counters.pop(key, None)
@@ -166,7 +166,7 @@ def get_rate_limiter() -> RateLimiter:
# 限流装饰器(用于函数级别限流)
-def rate_limit(requests_per_minute: int = 60, key_func: Callable | None = None):
+def rate_limit(requests_per_minute: int = 60, key_func: Callable | None = None) -> Callable:
"""
限流装饰器
diff --git a/backend/search_manager.py b/backend/search_manager.py
index ce25840..7c0a8e0 100644
--- a/backend/search_manager.py
+++ b/backend/search_manager.py
@@ -22,6 +22,7 @@ from enum import Enum
class SearchOperator(Enum):
"""搜索操作符"""
+
AND = "AND"
OR = "OR"
NOT = "NOT"
@@ -31,15 +32,18 @@ class SearchOperator(Enum):
try:
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
+
SENTENCE_TRANSFORMERS_AVAILABLE = True
except ImportError:
SENTENCE_TRANSFORMERS_AVAILABLE = False
# ==================== 数据模型 ====================
+
@dataclass
class SearchResult:
"""搜索结果数据模型"""
+
id: str
content: str
content_type: str # transcript, entity, relation
@@ -56,13 +60,14 @@ class SearchResult:
"project_id": self.project_id,
"score": self.score,
"highlights": self.highlights,
- "metadata": self.metadata
+ "metadata": self.metadata,
}
@dataclass
class SemanticSearchResult:
"""语义搜索结果数据模型"""
+
id: str
content: str
content_type: str
@@ -78,7 +83,7 @@ class SemanticSearchResult:
"content_type": self.content_type,
"project_id": self.project_id,
"similarity": round(self.similarity, 4),
- "metadata": self.metadata
+ "metadata": self.metadata,
}
if self.embedding:
result["embedding_dim"] = len(self.embedding)
@@ -88,6 +93,7 @@ class SemanticSearchResult:
@dataclass
class EntityPath:
"""实体关系路径数据模型"""
+
path_id: str
source_entity_id: str
source_entity_name: str
@@ -110,13 +116,14 @@ class EntityPath:
"nodes": self.nodes,
"edges": self.edges,
"confidence": self.confidence,
- "path_description": self.path_description
+ "path_description": self.path_description,
}
@dataclass
class KnowledgeGap:
"""知识缺口数据模型"""
+
gap_id: str
gap_type: str # missing_attribute, sparse_relation, isolated_entity, incomplete_entity
entity_id: str | None
@@ -137,13 +144,14 @@ class KnowledgeGap:
"severity": self.severity,
"suggestions": self.suggestions,
"related_entities": self.related_entities,
- "metadata": self.metadata
+ "metadata": self.metadata,
}
@dataclass
class SearchIndex:
"""搜索索引数据模型"""
+
id: str
content_id: str
content_type: str
@@ -157,6 +165,7 @@ class SearchIndex:
@dataclass
class TextEmbedding:
"""文本 Embedding 数据模型"""
+
id: str
content_id: str
content_type: str
@@ -168,6 +177,7 @@ class TextEmbedding:
# ==================== 全文搜索 ====================
+
class FullTextSearch:
"""
全文搜索模块
@@ -189,7 +199,7 @@ class FullTextSearch:
conn.row_factory = sqlite3.Row
return conn
- def _init_search_tables(self):
+ def _init_search_tables(self) -> None:
"""初始化搜索相关表"""
conn = self._get_conn()
@@ -239,7 +249,7 @@ class FullTextSearch:
# 清理文本
text = text.lower()
# 提取中文字符、英文单词和数字
- tokens = re.findall(r'[\u4e00-\u9fa5]+|[a-z]+|\d+', text)
+ tokens = re.findall(r"[\u4e00-\u9fa5]+|[a-z]+|\d+", text)
return tokens
def _extract_positions(self, text: str, tokens: list[str]) -> dict[str, list[int]]:
@@ -259,8 +269,7 @@ class FullTextSearch:
return dict(positions)
- def index_content(self, content_id: str, content_type: str,
- project_id: str, text: str) -> bool:
+ def index_content(self, content_id: str, content_type: str, project_id: str, text: str) -> bool:
"""
为内容创建搜索索引
@@ -294,28 +303,35 @@ class FullTextSearch:
now = datetime.now().isoformat()
# 保存索引
- conn.execute("""
+ conn.execute(
+ """
INSERT OR REPLACE INTO search_indexes
(id, content_id, content_type, project_id, tokens, token_positions, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
- """, (
- index_id, content_id, content_type, project_id,
- json.dumps(tokens, ensure_ascii=False),
- json.dumps(token_positions, ensure_ascii=False),
- now, now
- ))
+ """,
+ (
+ index_id,
+ content_id,
+ content_type,
+ project_id,
+ json.dumps(tokens, ensure_ascii=False),
+ json.dumps(token_positions, ensure_ascii=False),
+ now,
+ now,
+ ),
+ )
# 保存词频统计
for token, freq in token_freq.items():
positions = token_positions.get(token, [])
- conn.execute("""
+ conn.execute(
+ """
INSERT OR REPLACE INTO search_term_freq
(term, content_id, content_type, project_id, frequency, positions)
VALUES (?, ?, ?, ?, ?, ?)
- """, (
- token, content_id, content_type, project_id, freq,
- json.dumps(positions, ensure_ascii=False)
- ))
+ """,
+ (token, content_id, content_type, project_id, freq, json.dumps(positions, ensure_ascii=False)),
+ )
conn.commit()
conn.close()
@@ -325,9 +341,14 @@ class FullTextSearch:
print(f"索引创建失败: {e}")
return False
- def search(self, query: str, project_id: str | None = None,
- content_types: list[str] | None = None,
- limit: int = 20, offset: int = 0) -> list[SearchResult]:
+ def search(
+ self,
+ query: str,
+ project_id: str | None = None,
+ content_types: list[str] | None = None,
+ limit: int = 20,
+ offset: int = 0,
+ ) -> list[SearchResult]:
"""
全文搜索
@@ -345,9 +366,7 @@ class FullTextSearch:
parsed_query = self._parse_boolean_query(query)
# 执行搜索
- results = self._execute_boolean_search(
- parsed_query, project_id, content_types
- )
+ results = self._execute_boolean_search(parsed_query, project_id, content_types)
# 计算相关性分数
scored_results = self._score_results(results, parsed_query)
@@ -355,7 +374,7 @@ class FullTextSearch:
# 排序和分页
scored_results.sort(key=lambda x: x.score, reverse=True)
- return scored_results[offset:offset + limit]
+ return scored_results[offset : offset + limit]
def _parse_boolean_query(self, query: str) -> dict:
"""
@@ -371,7 +390,7 @@ class FullTextSearch:
# 提取短语(引号内的内容)
phrases = re.findall(r'"([^"]+)"', query)
- query_without_phrases = re.sub(r'"[^"]+"', '', query)
+ query_without_phrases = re.sub(r'"[^"]+"', "", query)
# 解析布尔操作
and_terms = []
@@ -379,13 +398,13 @@ class FullTextSearch:
not_terms = []
# 处理 NOT
- not_pattern = r'(?:NOT\s+|\-)(\w+)'
+ not_pattern = r"(?:NOT\s+|\-)(\w+)"
not_matches = re.findall(not_pattern, query_without_phrases, re.IGNORECASE)
not_terms.extend(not_matches)
- query_without_phrases = re.sub(not_pattern, '', query_without_phrases, flags=re.IGNORECASE)
+ query_without_phrases = re.sub(not_pattern, "", query_without_phrases, flags=re.IGNORECASE)
# 处理 OR
- or_parts = re.split(r'\s+OR\s+', query_without_phrases, flags=re.IGNORECASE)
+ or_parts = re.split(r"\s+OR\s+", query_without_phrases, flags=re.IGNORECASE)
if len(or_parts) > 1:
or_terms = [p.strip() for p in or_parts[1:] if p.strip()]
query_without_phrases = or_parts[0]
@@ -393,16 +412,11 @@ class FullTextSearch:
# 剩余的作为 AND 条件
and_terms = [t.strip() for t in query_without_phrases.split() if t.strip()]
- return {
- "and": and_terms + phrases,
- "or": or_terms,
- "not": not_terms,
- "phrases": phrases
- }
+ return {"and": and_terms + phrases, "or": or_terms, "not": not_terms, "phrases": phrases}
- def _execute_boolean_search(self, parsed_query: dict,
- project_id: str | None = None,
- content_types: list[str] | None = None) -> list[dict]:
+ def _execute_boolean_search(
+ self, parsed_query: dict, project_id: str | None = None, content_types: list[str] | None = None
+ ) -> list[dict]:
"""执行布尔搜索"""
conn = self._get_conn()
@@ -415,7 +429,7 @@ class FullTextSearch:
params.append(project_id)
if content_types:
- placeholders = ','.join(['?' for _ in content_types])
+ placeholders = ",".join(["?" for _ in content_types])
base_where.append(f"content_type IN ({placeholders})")
params.extend(content_types)
@@ -427,13 +441,16 @@ class FullTextSearch:
# 处理 AND 条件
if parsed_query["and"]:
for term in parsed_query["and"]:
- term_results = conn.execute(f"""
+ term_results = conn.execute(
+ f"""
SELECT content_id, content_type, project_id, frequency, positions
FROM search_term_freq
WHERE term = ? AND {base_where_str}
- """, [term] + params).fetchall()
+ """,
+ [term] + params,
+ ).fetchall()
- term_contents = {(r['content_id'], r['content_type']) for r in term_results}
+ term_contents = {(r["content_id"], r["content_type"]) for r in term_results}
if not candidates:
candidates = term_contents
@@ -443,13 +460,16 @@ class FullTextSearch:
# 处理 OR 条件
if parsed_query["or"]:
for term in parsed_query["or"]:
- term_results = conn.execute(f"""
+ term_results = conn.execute(
+ f"""
SELECT content_id, content_type, project_id, frequency, positions
FROM search_term_freq
WHERE term = ? AND {base_where_str}
- """, [term] + params).fetchall()
+ """,
+ [term] + params,
+ ).fetchall()
- term_contents = {(r['content_id'], r['content_type']) for r in term_results}
+ term_contents = {(r["content_id"], r["content_type"]) for r in term_results}
candidates |= term_contents # 并集
# 如果没有 AND 和 OR,但有 phrases,使用 phrases
@@ -459,13 +479,16 @@ class FullTextSearch:
if phrase_tokens:
# 查找包含所有短语的文档
for token in phrase_tokens:
- term_results = conn.execute(f"""
+ term_results = conn.execute(
+ f"""
SELECT content_id, content_type, project_id, frequency, positions
FROM search_term_freq
WHERE term = ? AND {base_where_str}
- """, [token] + params).fetchall()
+ """,
+ [token] + params,
+ ).fetchall()
- term_contents = {(r['content_id'], r['content_type']) for r in term_results}
+ term_contents = {(r["content_id"], r["content_type"]) for r in term_results}
if not candidates:
candidates = term_contents
@@ -475,13 +498,16 @@ class FullTextSearch:
# 处理 NOT 条件(排除)
if parsed_query["not"]:
for term in parsed_query["not"]:
- term_results = conn.execute(f"""
+ term_results = conn.execute(
+ f"""
SELECT content_id, content_type
FROM search_term_freq
WHERE term = ? AND {base_where_str}
- """, [term] + params).fetchall()
+ """,
+ [term] + params,
+ ).fetchall()
- term_contents = {(r['content_id'], r['content_type']) for r in term_results}
+ term_contents = {(r["content_id"], r["content_type"]) for r in term_results}
candidates -= term_contents # 差集
# 获取完整内容
@@ -490,33 +516,28 @@ class FullTextSearch:
# 获取原始内容
content = self._get_content_by_id(conn, content_id, content_type)
if content:
- results.append({
- "id": content_id,
- "content_type": content_type,
- "project_id": project_id or self._get_project_id(conn, content_id, content_type),
- "content": content,
- "terms": parsed_query["and"] + parsed_query["or"] + parsed_query["phrases"]
- })
+ results.append(
+ {
+ "id": content_id,
+ "content_type": content_type,
+ "project_id": project_id or self._get_project_id(conn, content_id, content_type),
+ "content": content,
+ "terms": parsed_query["and"] + parsed_query["or"] + parsed_query["phrases"],
+ }
+ )
conn.close()
return results
- def _get_content_by_id(self, conn: sqlite3.Connection,
- content_id: str, content_type: str) -> str | None:
+ def _get_content_by_id(self, conn: sqlite3.Connection, content_id: str, content_type: str) -> str | None:
"""根据ID获取内容"""
try:
if content_type == "transcript":
- row = conn.execute(
- "SELECT full_text FROM transcripts WHERE id = ?",
- (content_id,)
- ).fetchone()
- return row['full_text'] if row else None
+ row = conn.execute("SELECT full_text FROM transcripts WHERE id = ?", (content_id,)).fetchone()
+ return row["full_text"] if row else None
elif content_type == "entity":
- row = conn.execute(
- "SELECT name, definition FROM entities WHERE id = ?",
- (content_id,)
- ).fetchone()
+ row = conn.execute("SELECT name, definition FROM entities WHERE id = ?", (content_id,)).fetchone()
if row:
return f"{row['name']} {row['definition'] or ''}"
return None
@@ -529,7 +550,7 @@ class FullTextSearch:
JOIN entities e1 ON r.source_entity_id = e1.id
JOIN entities e2 ON r.target_entity_id = e2.id
WHERE r.id = ?""",
- (content_id,)
+ (content_id,),
).fetchone()
if row:
return f"{row['source_name']} {row['relation_type']} {row['target_name']} {row['evidence'] or ''}"
@@ -540,29 +561,19 @@ class FullTextSearch:
print(f"获取内容失败: {e}")
return None
- def _get_project_id(self, conn: sqlite3.Connection,
- content_id: str, content_type: str) -> str | None:
+ def _get_project_id(self, conn: sqlite3.Connection, content_id: str, content_type: str) -> str | None:
"""获取内容所属的项目ID"""
try:
if content_type == "transcript":
- row = conn.execute(
- "SELECT project_id FROM transcripts WHERE id = ?",
- (content_id,)
- ).fetchone()
+ row = conn.execute("SELECT project_id FROM transcripts WHERE id = ?", (content_id,)).fetchone()
elif content_type == "entity":
- row = conn.execute(
- "SELECT project_id FROM entities WHERE id = ?",
- (content_id,)
- ).fetchone()
+ row = conn.execute("SELECT project_id FROM entities WHERE id = ?", (content_id,)).fetchone()
elif content_type == "relation":
- row = conn.execute(
- "SELECT project_id FROM entity_relations WHERE id = ?",
- (content_id,)
- ).fetchone()
+ row = conn.execute("SELECT project_id FROM entity_relations WHERE id = ?", (content_id,)).fetchone()
else:
return None
- return row['project_id'] if row else None
+ return row["project_id"] if row else None
except Exception:
return None
@@ -612,20 +623,21 @@ class FullTextSearch:
# 归一化分数
score = min(score / max(len(all_terms), 1), 10.0)
- scored.append(SearchResult(
- id=result["id"],
- content=result["content"],
- content_type=result["content_type"],
- project_id=result["project_id"],
- score=round(score, 4),
- highlights=highlights[:10], # 限制高亮数量
- metadata={}
- ))
+ scored.append(
+ SearchResult(
+ id=result["id"],
+ content=result["content"],
+ content_type=result["content_type"],
+ project_id=result["project_id"],
+ score=round(score, 4),
+ highlights=highlights[:10], # 限制高亮数量
+ metadata={},
+ )
+ )
return scored
- def highlight_text(self, text: str, query: str,
- max_length: int = 300) -> str:
+ def highlight_text(self, text: str, query: str, max_length: int = 300) -> str:
"""
高亮文本中的关键词
@@ -671,14 +683,12 @@ class FullTextSearch:
# 删除索引
conn.execute(
- "DELETE FROM search_indexes WHERE content_id = ? AND content_type = ?",
- (content_id, content_type)
+ "DELETE FROM search_indexes WHERE content_id = ? AND content_type = ?", (content_id, content_type)
)
# 删除词频统计
conn.execute(
- "DELETE FROM search_term_freq WHERE content_id = ? AND content_type = ?",
- (content_id, content_type)
+ "DELETE FROM search_term_freq WHERE content_id = ? AND content_type = ?", (content_id, content_type)
)
conn.commit()
@@ -696,26 +706,24 @@ class FullTextSearch:
try:
# 索引转录文本
transcripts = conn.execute(
- "SELECT id, project_id, full_text FROM transcripts WHERE project_id = ?",
- (project_id,)
+ "SELECT id, project_id, full_text FROM transcripts WHERE project_id = ?", (project_id,)
).fetchall()
for t in transcripts:
- if t['full_text']:
- if self.index_content(t['id'], 'transcript', t['project_id'], t['full_text']):
+ if t["full_text"]:
+ if self.index_content(t["id"], "transcript", t["project_id"], t["full_text"]):
stats["transcripts"] += 1
else:
stats["errors"] += 1
# 索引实体
entities = conn.execute(
- "SELECT id, project_id, name, definition FROM entities WHERE project_id = ?",
- (project_id,)
+ "SELECT id, project_id, name, definition FROM entities WHERE project_id = ?", (project_id,)
).fetchall()
for e in entities:
text = f"{e['name']} {e['definition'] or ''}"
- if self.index_content(e['id'], 'entity', e['project_id'], text):
+ if self.index_content(e["id"], "entity", e["project_id"], text):
stats["entities"] += 1
else:
stats["errors"] += 1
@@ -728,12 +736,12 @@ class FullTextSearch:
JOIN entities e1 ON r.source_entity_id = e1.id
JOIN entities e2 ON r.target_entity_id = e2.id
WHERE r.project_id = ?""",
- (project_id,)
+ (project_id,),
).fetchall()
for r in relations:
text = f"{r['source_name']} {r['relation_type']} {r['target_name']} {r['evidence'] or ''}"
- if self.index_content(r['id'], 'relation', r['project_id'], text):
+ if self.index_content(r["id"], "relation", r["project_id"], text):
stats["relations"] += 1
else:
stats["errors"] += 1
@@ -748,6 +756,7 @@ class FullTextSearch:
# ==================== 语义搜索 ====================
+
class SemanticSearch:
"""
语义搜索模块
@@ -759,8 +768,7 @@ class SemanticSearch:
- 语义相似内容推荐
"""
- def __init__(self, db_path: str = "insightflow.db",
- model_name: str = "paraphrase-multilingual-MiniLM-L12-v2"):
+ def __init__(self, db_path: str = "insightflow.db", model_name: str = "paraphrase-multilingual-MiniLM-L12-v2"):
self.db_path = db_path
self.model_name = model_name
self.model = None
@@ -780,7 +788,7 @@ class SemanticSearch:
conn.row_factory = sqlite3.Row
return conn
- def _init_embedding_tables(self):
+ def _init_embedding_tables(self) -> None:
"""初始化 embedding 相关表"""
conn = self._get_conn()
@@ -832,8 +840,7 @@ class SemanticSearch:
print(f"生成 embedding 失败: {e}")
return None
- def index_embedding(self, content_id: str, content_type: str,
- project_id: str, text: str) -> bool:
+ def index_embedding(self, content_id: str, content_type: str, project_id: str, text: str) -> bool:
"""
为内容生成并保存 embedding
@@ -858,16 +865,22 @@ class SemanticSearch:
embedding_id = hashlib.md5(f"{content_id}:{content_type}".encode()).hexdigest()[:16]
- conn.execute("""
+ conn.execute(
+ """
INSERT OR REPLACE INTO embeddings
(id, content_id, content_type, project_id, embedding, model_name, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?)
- """, (
- embedding_id, content_id, content_type, project_id,
- json.dumps(embedding),
- self.model_name,
- datetime.now().isoformat()
- ))
+ """,
+ (
+ embedding_id,
+ content_id,
+ content_type,
+ project_id,
+ json.dumps(embedding),
+ self.model_name,
+ datetime.now().isoformat(),
+ ),
+ )
conn.commit()
conn.close()
@@ -877,9 +890,14 @@ class SemanticSearch:
print(f"索引 embedding 失败: {e}")
return False
- def search(self, query: str, project_id: str | None = None,
- content_types: list[str] | None = None,
- top_k: int = 10, threshold: float = 0.5) -> list[SemanticSearchResult]:
+ def search(
+ self,
+ query: str,
+ project_id: str | None = None,
+ content_types: list[str] | None = None,
+ top_k: int = 10,
+ threshold: float = 0.5,
+ ) -> list[SemanticSearchResult]:
"""
语义搜索
@@ -912,17 +930,20 @@ class SemanticSearch:
params.append(project_id)
if content_types:
- placeholders = ','.join(['?' for _ in content_types])
+ placeholders = ",".join(["?" for _ in content_types])
where_clauses.append(f"content_type IN ({placeholders})")
params.extend(content_types)
where_str = " AND ".join(where_clauses) if where_clauses else "1=1"
- rows = conn.execute(f"""
+ rows = conn.execute(
+ f"""
SELECT content_id, content_type, project_id, embedding
FROM embeddings
WHERE {where_str}
- """, params).fetchall()
+ """,
+ params,
+ ).fetchall()
conn.close()
@@ -932,24 +953,26 @@ class SemanticSearch:
for row in rows:
try:
- content_embedding = json.loads(row['embedding'])
+ content_embedding = json.loads(row["embedding"])
# 计算余弦相似度
similarity = cosine_similarity(query_vec, [content_embedding])[0][0]
if similarity >= threshold:
# 获取原始内容
- content = self._get_content_text(row['content_id'], row['content_type'])
+ content = self._get_content_text(row["content_id"], row["content_type"])
- results.append(SemanticSearchResult(
- id=row['content_id'],
- content=content or "",
- content_type=row['content_type'],
- project_id=row['project_id'],
- similarity=float(similarity),
- embedding=None, # 不返回 embedding 以节省带宽
- metadata={}
- ))
+ results.append(
+ SemanticSearchResult(
+ id=row["content_id"],
+ content=content or "",
+ content_type=row["content_type"],
+ project_id=row["project_id"],
+ similarity=float(similarity),
+ embedding=None, # 不返回 embedding 以节省带宽
+ metadata={},
+ )
+ )
except Exception as e:
print(f"计算相似度失败: {e}")
continue
@@ -964,17 +987,11 @@ class SemanticSearch:
try:
if content_type == "transcript":
- row = conn.execute(
- "SELECT full_text FROM transcripts WHERE id = ?",
- (content_id,)
- ).fetchone()
- result = row['full_text'] if row else None
+ row = conn.execute("SELECT full_text FROM transcripts WHERE id = ?", (content_id,)).fetchone()
+ result = row["full_text"] if row else None
elif content_type == "entity":
- row = conn.execute(
- "SELECT name, definition FROM entities WHERE id = ?",
- (content_id,)
- ).fetchone()
+ row = conn.execute("SELECT name, definition FROM entities WHERE id = ?", (content_id,)).fetchone()
result = f"{row['name']}: {row['definition']}" if row else None
elif content_type == "relation":
@@ -985,7 +1002,7 @@ class SemanticSearch:
JOIN entities e1 ON r.source_entity_id = e1.id
JOIN entities e2 ON r.target_entity_id = e2.id
WHERE r.id = ?""",
- (content_id,)
+ (content_id,),
).fetchone()
result = f"{row['source_name']} {row['relation_type']} {row['target_name']}" if row else None
@@ -1000,8 +1017,7 @@ class SemanticSearch:
print(f"获取内容失败: {e}")
return None
- def find_similar_content(self, content_id: str, content_type: str,
- top_k: int = 5) -> list[SemanticSearchResult]:
+ def find_similar_content(self, content_id: str, content_type: str, top_k: int = 5) -> list[SemanticSearchResult]:
"""
查找与指定内容相似的内容
@@ -1021,22 +1037,22 @@ class SemanticSearch:
row = conn.execute(
"SELECT embedding, project_id FROM embeddings WHERE content_id = ? AND content_type = ?",
- (content_id, content_type)
+ (content_id, content_type),
).fetchone()
if not row:
conn.close()
return []
- source_embedding = json.loads(row['embedding'])
- project_id = row['project_id']
+ source_embedding = json.loads(row["embedding"])
+ project_id = row["project_id"]
# 获取其他内容的 embedding
rows = conn.execute(
"""SELECT content_id, content_type, project_id, embedding
FROM embeddings
WHERE project_id = ? AND (content_id != ? OR content_type != ?)""",
- (project_id, content_id, content_type)
+ (project_id, content_id, content_type),
).fetchall()
conn.close()
@@ -1047,19 +1063,21 @@ class SemanticSearch:
for row in rows:
try:
- content_embedding = json.loads(row['embedding'])
+ content_embedding = json.loads(row["embedding"])
similarity = cosine_similarity(source_vec, [content_embedding])[0][0]
- content = self._get_content_text(row['content_id'], row['content_type'])
+ content = self._get_content_text(row["content_id"], row["content_type"])
- results.append(SemanticSearchResult(
- id=row['content_id'],
- content=content or "",
- content_type=row['content_type'],
- project_id=row['project_id'],
- similarity=float(similarity),
- metadata={}
- ))
+ results.append(
+ SemanticSearchResult(
+ id=row["content_id"],
+ content=content or "",
+ content_type=row["content_type"],
+ project_id=row["project_id"],
+ similarity=float(similarity),
+ metadata={},
+ )
+ )
except Exception:
continue
@@ -1070,10 +1088,7 @@ class SemanticSearch:
"""删除内容的 embedding"""
try:
conn = self._get_conn()
- conn.execute(
- "DELETE FROM embeddings WHERE content_id = ? AND content_type = ?",
- (content_id, content_type)
- )
+ conn.execute("DELETE FROM embeddings WHERE content_id = ? AND content_type = ?", (content_id, content_type))
conn.commit()
conn.close()
return True
@@ -1084,6 +1099,7 @@ class SemanticSearch:
# ==================== 实体关系路径发现 ====================
+
class EntityPathDiscovery:
"""
实体关系路径发现模块
@@ -1104,9 +1120,7 @@ class EntityPathDiscovery:
conn.row_factory = sqlite3.Row
return conn
- def find_shortest_path(self, source_entity_id: str,
- target_entity_id: str,
- max_depth: int = 5) -> EntityPath | None:
+ def find_shortest_path(self, source_entity_id: str, target_entity_id: str, max_depth: int = 5) -> EntityPath | None:
"""
查找两个实体之间的最短路径(BFS算法)
@@ -1121,21 +1135,17 @@ class EntityPathDiscovery:
conn = self._get_conn()
# 获取项目ID
- row = conn.execute(
- "SELECT project_id FROM entities WHERE id = ?",
- (source_entity_id,)
- ).fetchone()
+ row = conn.execute("SELECT project_id FROM entities WHERE id = ?", (source_entity_id,)).fetchone()
if not row:
conn.close()
return None
- project_id = row['project_id']
+ project_id = row["project_id"]
# 验证目标实体也在同一项目
row = conn.execute(
- "SELECT 1 FROM entities WHERE id = ? AND project_id = ?",
- (target_entity_id, project_id)
+ "SELECT 1 FROM entities WHERE id = ? AND project_id = ?", (target_entity_id, project_id)
).fetchone()
if not row:
@@ -1158,7 +1168,8 @@ class EntityPathDiscovery:
return self._build_path_object(path, project_id)
# 获取邻居
- neighbors = conn.execute("""
+ neighbors = conn.execute(
+ """
SELECT target_entity_id as neighbor_id, relation_type, evidence
FROM entity_relations
WHERE source_entity_id = ? AND project_id = ?
@@ -1166,10 +1177,12 @@ class EntityPathDiscovery:
SELECT source_entity_id as neighbor_id, relation_type, evidence
FROM entity_relations
WHERE target_entity_id = ? AND project_id = ?
- """, (current_id, project_id, current_id, project_id)).fetchall()
+ """,
+ (current_id, project_id, current_id, project_id),
+ ).fetchall()
for neighbor in neighbors:
- neighbor_id = neighbor['neighbor_id']
+ neighbor_id = neighbor["neighbor_id"]
if neighbor_id not in visited:
visited.add(neighbor_id)
queue.append((neighbor_id, path + [neighbor_id]))
@@ -1177,10 +1190,9 @@ class EntityPathDiscovery:
conn.close()
return None
- def find_all_paths(self, source_entity_id: str,
- target_entity_id: str,
- max_depth: int = 4,
- max_paths: int = 10) -> list[EntityPath]:
+ def find_all_paths(
+ self, source_entity_id: str, target_entity_id: str, max_depth: int = 4, max_paths: int = 10
+ ) -> list[EntityPath]:
"""
查找两个实体之间的所有路径(限制数量和深度)
@@ -1196,21 +1208,17 @@ class EntityPathDiscovery:
conn = self._get_conn()
# 获取项目ID
- row = conn.execute(
- "SELECT project_id FROM entities WHERE id = ?",
- (source_entity_id,)
- ).fetchone()
+ row = conn.execute("SELECT project_id FROM entities WHERE id = ?", (source_entity_id,)).fetchone()
if not row:
conn.close()
return []
- project_id = row['project_id']
+ project_id = row["project_id"]
paths = []
- def dfs(current_id: str, target_id: str,
- path: list[str], visited: set[str], depth: int):
+ def dfs(current_id: str, target_id: str, path: list[str], visited: set[str], depth: int):
if depth > max_depth:
return
@@ -1219,7 +1227,8 @@ class EntityPathDiscovery:
return
# 获取邻居
- neighbors = conn.execute("""
+ neighbors = conn.execute(
+ """
SELECT target_entity_id as neighbor_id
FROM entity_relations
WHERE source_entity_id = ? AND project_id = ?
@@ -1227,10 +1236,12 @@ class EntityPathDiscovery:
SELECT source_entity_id as neighbor_id
FROM entity_relations
WHERE target_entity_id = ? AND project_id = ?
- """, (current_id, project_id, current_id, project_id)).fetchall()
+ """,
+ (current_id, project_id, current_id, project_id),
+ ).fetchall()
for neighbor in neighbors:
- neighbor_id = neighbor['neighbor_id']
+ neighbor_id = neighbor["neighbor_id"]
if neighbor_id not in visited and len(paths) < max_paths:
visited.add(neighbor_id)
path.append(neighbor_id)
@@ -1246,24 +1257,16 @@ class EntityPathDiscovery:
# 构建路径对象
return [self._build_path_object(path, project_id) for path in paths]
- def _build_path_object(self, entity_ids: list[str],
- project_id: str) -> EntityPath:
+ def _build_path_object(self, entity_ids: list[str], project_id: str) -> EntityPath:
"""构建路径对象"""
conn = self._get_conn()
# 获取实体信息
nodes = []
for entity_id in entity_ids:
- row = conn.execute(
- "SELECT id, name, type FROM entities WHERE id = ?",
- (entity_id,)
- ).fetchone()
+ row = conn.execute("SELECT id, name, type FROM entities WHERE id = ?", (entity_id,)).fetchone()
if row:
- nodes.append({
- "id": row['id'],
- "name": row['name'],
- "type": row['type']
- })
+ nodes.append({"id": row["id"], "name": row["name"], "type": row["type"]})
# 获取边信息
edges = []
@@ -1271,27 +1274,32 @@ class EntityPathDiscovery:
source_id = entity_ids[i]
target_id = entity_ids[i + 1]
- row = conn.execute("""
+ row = conn.execute(
+ """
SELECT id, relation_type, evidence
FROM entity_relations
WHERE ((source_entity_id = ? AND target_entity_id = ?)
OR (source_entity_id = ? AND target_entity_id = ?))
AND project_id = ?
- """, (source_id, target_id, target_id, source_id, project_id)).fetchone()
+ """,
+ (source_id, target_id, target_id, source_id, project_id),
+ ).fetchone()
if row:
- edges.append({
- "id": row['id'],
- "source": source_id,
- "target": target_id,
- "relation_type": row['relation_type'],
- "evidence": row['evidence']
- })
+ edges.append(
+ {
+ "id": row["id"],
+ "source": source_id,
+ "target": target_id,
+ "relation_type": row["relation_type"],
+ "evidence": row["evidence"],
+ }
+ )
conn.close()
# 生成路径描述
- node_names = [n['name'] for n in nodes]
+ node_names = [n["name"] for n in nodes]
path_desc = " → ".join(node_names)
# 计算置信度(基于路径长度和关系数量)
@@ -1300,18 +1308,17 @@ class EntityPathDiscovery:
return EntityPath(
path_id=f"path_{entity_ids[0]}_{entity_ids[-1]}_{hash(tuple(entity_ids))}",
source_entity_id=entity_ids[0],
- source_entity_name=nodes[0]['name'] if nodes else "",
+ source_entity_name=nodes[0]["name"] if nodes else "",
target_entity_id=entity_ids[-1],
- target_entity_name=nodes[-1]['name'] if nodes else "",
+ target_entity_name=nodes[-1]["name"] if nodes else "",
path_length=len(entity_ids) - 1,
nodes=nodes,
edges=edges,
confidence=round(confidence, 4),
- path_description=path_desc
+ path_description=path_desc,
)
- def find_multi_hop_relations(self, entity_id: str,
- max_hops: int = 3) -> list[dict]:
+ def find_multi_hop_relations(self, entity_id: str, max_hops: int = 3) -> list[dict]:
"""
查找实体的多跳关系
@@ -1325,17 +1332,14 @@ class EntityPathDiscovery:
conn = self._get_conn()
# 获取项目ID
- row = conn.execute(
- "SELECT project_id, name FROM entities WHERE id = ?",
- (entity_id,)
- ).fetchone()
+ row = conn.execute("SELECT project_id, name FROM entities WHERE id = ?", (entity_id,)).fetchone()
if not row:
conn.close()
return []
- project_id = row['project_id']
- row['name']
+ project_id = row["project_id"]
+ row["name"]
# BFS 收集多跳关系
visited = {entity_id: 0}
@@ -1349,7 +1353,8 @@ class EntityPathDiscovery:
continue
# 获取邻居
- neighbors = conn.execute("""
+ neighbors = conn.execute(
+ """
SELECT
CASE
WHEN source_entity_id = ? THEN target_entity_id
@@ -1360,10 +1365,12 @@ class EntityPathDiscovery:
FROM entity_relations
WHERE (source_entity_id = ? OR target_entity_id = ?)
AND project_id = ?
- """, (current_id, current_id, current_id, project_id)).fetchall()
+ """,
+ (current_id, current_id, current_id, project_id),
+ ).fetchall()
for neighbor in neighbors:
- neighbor_id = neighbor['neighbor_id']
+ neighbor_id = neighbor["neighbor_id"]
if neighbor_id not in visited:
visited[neighbor_id] = depth + 1
@@ -1371,29 +1378,31 @@ class EntityPathDiscovery:
# 获取邻居信息
neighbor_info = conn.execute(
- "SELECT name, type FROM entities WHERE id = ?",
- (neighbor_id,)
+ "SELECT name, type FROM entities WHERE id = ?", (neighbor_id,)
).fetchone()
if neighbor_info:
- relations.append({
- "entity_id": neighbor_id,
- "entity_name": neighbor_info['name'],
- "entity_type": neighbor_info['type'],
- "hops": depth + 1,
- "relation_type": neighbor['relation_type'],
- "evidence": neighbor['evidence'],
- "path": self._get_path_to_entity(entity_id, neighbor_id, project_id, conn)
- })
+ relations.append(
+ {
+ "entity_id": neighbor_id,
+ "entity_name": neighbor_info["name"],
+ "entity_type": neighbor_info["type"],
+ "hops": depth + 1,
+ "relation_type": neighbor["relation_type"],
+ "evidence": neighbor["evidence"],
+ "path": self._get_path_to_entity(entity_id, neighbor_id, project_id, conn),
+ }
+ )
conn.close()
# 按跳数排序
- relations.sort(key=lambda x: x['hops'])
+ relations.sort(key=lambda x: x["hops"])
return relations
- def _get_path_to_entity(self, source_id: str, target_id: str,
- project_id: str, conn: sqlite3.Connection) -> list[str]:
+ def _get_path_to_entity(
+ self, source_id: str, target_id: str, project_id: str, conn: sqlite3.Connection
+ ) -> list[str]:
"""获取从源实体到目标实体的路径(简化版)"""
# BFS 找路径
visited = {source_id}
@@ -1408,7 +1417,8 @@ class EntityPathDiscovery:
if len(path) > 5: # 限制路径长度
continue
- neighbors = conn.execute("""
+ neighbors = conn.execute(
+ """
SELECT
CASE
WHEN source_entity_id = ? THEN target_entity_id
@@ -1417,10 +1427,12 @@ class EntityPathDiscovery:
FROM entity_relations
WHERE (source_entity_id = ? OR target_entity_id = ?)
AND project_id = ?
- """, (current, current, current, project_id)).fetchall()
+ """,
+ (current, current, current, project_id),
+ ).fetchall()
for neighbor in neighbors:
- neighbor_id = neighbor['neighbor_id']
+ neighbor_id = neighbor["neighbor_id"]
if neighbor_id not in visited:
visited.add(neighbor_id)
queue.append((neighbor_id, path + [neighbor_id]))
@@ -1440,30 +1452,34 @@ class EntityPathDiscovery:
# 节点数据
nodes = []
for node in path.nodes:
- nodes.append({
- "id": node["id"],
- "name": node["name"],
- "type": node["type"],
- "is_source": node["id"] == path.source_entity_id,
- "is_target": node["id"] == path.target_entity_id
- })
+ nodes.append(
+ {
+ "id": node["id"],
+ "name": node["name"],
+ "type": node["type"],
+ "is_source": node["id"] == path.source_entity_id,
+ "is_target": node["id"] == path.target_entity_id,
+ }
+ )
# 边数据
links = []
for edge in path.edges:
- links.append({
- "source": edge["source"],
- "target": edge["target"],
- "relation_type": edge["relation_type"],
- "evidence": edge["evidence"]
- })
+ links.append(
+ {
+ "source": edge["source"],
+ "target": edge["target"],
+ "relation_type": edge["relation_type"],
+ "evidence": edge["evidence"],
+ }
+ )
return {
"nodes": nodes,
"links": links,
"path_description": path.path_description,
"path_length": path.path_length,
- "confidence": path.confidence
+ "confidence": path.confidence,
}
def analyze_path_centrality(self, project_id: str) -> list[dict]:
@@ -1479,19 +1495,17 @@ class EntityPathDiscovery:
conn = self._get_conn()
# 获取所有实体
- entities = conn.execute(
- "SELECT id, name FROM entities WHERE project_id = ?",
- (project_id,)
- ).fetchall()
+ entities = conn.execute("SELECT id, name FROM entities WHERE project_id = ?", (project_id,)).fetchall()
# 计算每个实体作为桥梁的次数
bridge_scores = []
for entity in entities:
- entity_id = entity['id']
+ entity_id = entity["id"]
# 计算该实体连接的不同群组数量
- neighbors = conn.execute("""
+ neighbors = conn.execute(
+ """
SELECT
CASE
WHEN source_entity_id = ? THEN target_entity_id
@@ -1500,13 +1514,16 @@ class EntityPathDiscovery:
FROM entity_relations
WHERE (source_entity_id = ? OR target_entity_id = ?)
AND project_id = ?
- """, (entity_id, entity_id, entity_id, project_id)).fetchall()
+ """,
+ (entity_id, entity_id, entity_id, project_id),
+ ).fetchall()
- neighbor_ids = {n['neighbor_id'] for n in neighbors}
+ neighbor_ids = {n["neighbor_id"] for n in neighbors}
# 计算邻居之间的连接数(用于评估桥接程度)
if len(neighbor_ids) > 1:
- connections = conn.execute(f"""
+ connections = conn.execute(
+ f"""
SELECT COUNT(*) as count
FROM entity_relations
WHERE ((source_entity_id IN ({','.join(['?' for _ in neighbor_ids])})
@@ -1514,29 +1531,34 @@ class EntityPathDiscovery:
OR (target_entity_id IN ({','.join(['?' for _ in neighbor_ids])})
AND source_entity_id IN ({','.join(['?' for _ in neighbor_ids])})))
AND project_id = ?
- """, list(neighbor_ids) * 4 + [project_id]).fetchone()
+ """,
+ list(neighbor_ids) * 4 + [project_id],
+ ).fetchone()
# 桥接分数 = 邻居数量 / (邻居间连接数 + 1)
- bridge_score = len(neighbor_ids) / (connections['count'] + 1)
+ bridge_score = len(neighbor_ids) / (connections["count"] + 1)
else:
bridge_score = 0
- bridge_scores.append({
- "entity_id": entity_id,
- "entity_name": entity['name'],
- "neighbor_count": len(neighbor_ids),
- "bridge_score": round(bridge_score, 4)
- })
+ bridge_scores.append(
+ {
+ "entity_id": entity_id,
+ "entity_name": entity["name"],
+ "neighbor_count": len(neighbor_ids),
+ "bridge_score": round(bridge_score, 4),
+ }
+ )
conn.close()
# 按桥接分数排序
- bridge_scores.sort(key=lambda x: x['bridge_score'], reverse=True)
+ bridge_scores.sort(key=lambda x: x["bridge_score"], reverse=True)
return bridge_scores[:20] # 返回前20
# ==================== 知识缺口识别 ====================
+
class KnowledgeGapDetection:
"""
知识缺口识别模块
@@ -1597,36 +1619,31 @@ class KnowledgeGapDetection:
# 获取项目的属性模板
templates = conn.execute(
- "SELECT id, name, type, is_required FROM attribute_templates WHERE project_id = ?",
- (project_id,)
+ "SELECT id, name, type, is_required FROM attribute_templates WHERE project_id = ?", (project_id,)
).fetchall()
if not templates:
conn.close()
return []
- required_template_ids = {t['id'] for t in templates if t['is_required']}
+ required_template_ids = {t["id"] for t in templates if t["is_required"]}
if not required_template_ids:
conn.close()
return []
# 检查每个实体的属性完整性
- entities = conn.execute(
- "SELECT id, name FROM entities WHERE project_id = ?",
- (project_id,)
- ).fetchall()
+ entities = conn.execute("SELECT id, name FROM entities WHERE project_id = ?", (project_id,)).fetchall()
for entity in entities:
- entity_id = entity['id']
+ entity_id = entity["id"]
# 获取实体已有的属性
existing_attrs = conn.execute(
- "SELECT template_id FROM entity_attributes WHERE entity_id = ?",
- (entity_id,)
+ "SELECT template_id FROM entity_attributes WHERE entity_id = ?", (entity_id,)
).fetchall()
- existing_template_ids = {a['template_id'] for a in existing_attrs}
+ existing_template_ids = {a["template_id"] for a in existing_attrs}
# 找出缺失的必需属性
missing_templates = required_template_ids - existing_template_ids
@@ -1635,27 +1652,28 @@ class KnowledgeGapDetection:
missing_names = []
for template_id in missing_templates:
template = conn.execute(
- "SELECT name FROM attribute_templates WHERE id = ?",
- (template_id,)
+ "SELECT name FROM attribute_templates WHERE id = ?", (template_id,)
).fetchone()
if template:
- missing_names.append(template['name'])
+ missing_names.append(template["name"])
if missing_names:
- gaps.append(KnowledgeGap(
- gap_id=f"gap_attr_{entity_id}",
- gap_type="missing_attribute",
- entity_id=entity_id,
- entity_name=entity['name'],
- description=f"实体 '{entity['name']}' 缺少必需属性: {', '.join(missing_names)}",
- severity="medium",
- suggestions=[
- f"为实体 '{entity['name']}' 补充以下属性: {', '.join(missing_names)}",
- "检查属性模板定义是否合理"
- ],
- related_entities=[],
- metadata={"missing_attributes": missing_names}
- ))
+ gaps.append(
+ KnowledgeGap(
+ gap_id=f"gap_attr_{entity_id}",
+ gap_type="missing_attribute",
+ entity_id=entity_id,
+ entity_name=entity["name"],
+ description=f"实体 '{entity['name']}' 缺少必需属性: {', '.join(missing_names)}",
+ severity="medium",
+ suggestions=[
+ f"为实体 '{entity['name']}' 补充以下属性: {', '.join(missing_names)}",
+ "检查属性模板定义是否合理",
+ ],
+ related_entities=[],
+ metadata={"missing_attributes": missing_names},
+ )
+ )
conn.close()
return gaps
@@ -1666,28 +1684,29 @@ class KnowledgeGapDetection:
gaps = []
# 获取所有实体及其关系数量
- entities = conn.execute(
- "SELECT id, name, type FROM entities WHERE project_id = ?",
- (project_id,)
- ).fetchall()
+ entities = conn.execute("SELECT id, name, type FROM entities WHERE project_id = ?", (project_id,)).fetchall()
for entity in entities:
- entity_id = entity['id']
+ entity_id = entity["id"]
# 计算关系数量
- relation_count = conn.execute("""
+ relation_count = conn.execute(
+ """
SELECT COUNT(*) as count
FROM entity_relations
WHERE (source_entity_id = ? OR target_entity_id = ?)
AND project_id = ?
- """, (entity_id, entity_id, project_id)).fetchone()['count']
+ """,
+ (entity_id, entity_id, project_id),
+ ).fetchone()["count"]
# 根据实体类型判断阈值
- threshold = 1 if entity['type'] in ['PERSON', 'ORG'] else 0
+ threshold = 1 if entity["type"] in ["PERSON", "ORG"] else 0
if relation_count <= threshold:
# 查找潜在的相关实体
- potential_related = conn.execute("""
+ potential_related = conn.execute(
+ """
SELECT e.id, e.name
FROM entities e
JOIN transcripts t ON t.project_id = e.project_id
@@ -1695,26 +1714,30 @@ class KnowledgeGapDetection:
AND e.id != ?
AND t.full_text LIKE ?
LIMIT 5
- """, (project_id, entity_id, f"%{entity['name']}%")).fetchall()
+ """,
+ (project_id, entity_id, f"%{entity['name']}%"),
+ ).fetchall()
- gaps.append(KnowledgeGap(
- gap_id=f"gap_sparse_{entity_id}",
- gap_type="sparse_relation",
- entity_id=entity_id,
- entity_name=entity['name'],
- description=f"实体 '{entity['name']}' 关系稀疏(仅有 {relation_count} 个关系)",
- severity="medium" if relation_count == 0 else "low",
- suggestions=[
- f"检查转录文本中提及 '{entity['name']}' 的其他实体",
- f"手动添加 '{entity['name']}' 与其他实体的关系",
- "使用实体对齐功能合并相似实体"
- ],
- related_entities=[r['id'] for r in potential_related],
- metadata={
- "relation_count": relation_count,
- "potential_related": [r['name'] for r in potential_related]
- }
- ))
+ gaps.append(
+ KnowledgeGap(
+ gap_id=f"gap_sparse_{entity_id}",
+ gap_type="sparse_relation",
+ entity_id=entity_id,
+ entity_name=entity["name"],
+ description=f"实体 '{entity['name']}' 关系稀疏(仅有 {relation_count} 个关系)",
+ severity="medium" if relation_count == 0 else "low",
+ suggestions=[
+ f"检查转录文本中提及 '{entity['name']}' 的其他实体",
+ f"手动添加 '{entity['name']}' 与其他实体的关系",
+ "使用实体对齐功能合并相似实体",
+ ],
+ related_entities=[r["id"] for r in potential_related],
+ metadata={
+ "relation_count": relation_count,
+ "potential_related": [r["name"] for r in potential_related],
+ },
+ )
+ )
conn.close()
return gaps
@@ -1725,7 +1748,8 @@ class KnowledgeGapDetection:
gaps = []
# 查找没有关系的实体
- isolated = conn.execute("""
+ isolated = conn.execute(
+ """
SELECT e.id, e.name, e.type
FROM entities e
LEFT JOIN entity_relations r1 ON e.id = r1.source_entity_id
@@ -1733,24 +1757,28 @@ class KnowledgeGapDetection:
WHERE e.project_id = ?
AND r1.id IS NULL
AND r2.id IS NULL
- """, (project_id,)).fetchall()
+ """,
+ (project_id,),
+ ).fetchall()
for entity in isolated:
- gaps.append(KnowledgeGap(
- gap_id=f"gap_iso_{entity['id']}",
- gap_type="isolated_entity",
- entity_id=entity['id'],
- entity_name=entity['name'],
- description=f"实体 '{entity['name']}' 是孤立实体(没有任何关系)",
- severity="high",
- suggestions=[
- f"检查 '{entity['name']}' 是否应该与其他实体建立关系",
- f"考虑删除不相关的实体 '{entity['name']}'",
- "运行关系发现算法自动识别潜在关系"
- ],
- related_entities=[],
- metadata={"entity_type": entity['type']}
- ))
+ gaps.append(
+ KnowledgeGap(
+ gap_id=f"gap_iso_{entity['id']}",
+ gap_type="isolated_entity",
+ entity_id=entity["id"],
+ entity_name=entity["name"],
+ description=f"实体 '{entity['name']}' 是孤立实体(没有任何关系)",
+ severity="high",
+ suggestions=[
+ f"检查 '{entity['name']}' 是否应该与其他实体建立关系",
+ f"考虑删除不相关的实体 '{entity['name']}'",
+ "运行关系发现算法自动识别潜在关系",
+ ],
+ related_entities=[],
+ metadata={"entity_type": entity["type"]},
+ )
+ )
conn.close()
return gaps
@@ -1761,28 +1789,30 @@ class KnowledgeGapDetection:
gaps = []
# 查找缺少定义的实体
- incomplete = conn.execute("""
+ incomplete = conn.execute(
+ """
SELECT id, name, type, definition
FROM entities
WHERE project_id = ?
AND (definition IS NULL OR definition = '')
- """, (project_id,)).fetchall()
+ """,
+ (project_id,),
+ ).fetchall()
for entity in incomplete:
- gaps.append(KnowledgeGap(
- gap_id=f"gap_inc_{entity['id']}",
- gap_type="incomplete_entity",
- entity_id=entity['id'],
- entity_name=entity['name'],
- description=f"实体 '{entity['name']}' 缺少定义",
- severity="low",
- suggestions=[
- f"为 '{entity['name']}' 添加定义",
- "从转录文本中提取定义信息"
- ],
- related_entities=[],
- metadata={"entity_type": entity['type']}
- ))
+ gaps.append(
+ KnowledgeGap(
+ gap_id=f"gap_inc_{entity['id']}",
+ gap_type="incomplete_entity",
+ entity_id=entity["id"],
+ entity_name=entity["name"],
+ description=f"实体 '{entity['name']}' 缺少定义",
+ severity="low",
+ suggestions=[f"为 '{entity['name']}' 添加定义", "从转录文本中提取定义信息"],
+ related_entities=[],
+ metadata={"entity_type": entity["type"]},
+ )
+ )
conn.close()
return gaps
@@ -1793,25 +1823,19 @@ class KnowledgeGapDetection:
gaps = []
# 分析转录文本中频繁提及但未提取为实体的词
- transcripts = conn.execute(
- "SELECT full_text FROM transcripts WHERE project_id = ?",
- (project_id,)
- ).fetchall()
+ transcripts = conn.execute("SELECT full_text FROM transcripts WHERE project_id = ?", (project_id,)).fetchall()
# 合并所有文本
- all_text = " ".join([t['full_text'] or "" for t in transcripts])
+ all_text = " ".join([t["full_text"] or "" for t in transcripts])
# 获取现有实体名称
- existing_entities = conn.execute(
- "SELECT name FROM entities WHERE project_id = ?",
- (project_id,)
- ).fetchall()
+ existing_entities = conn.execute("SELECT name FROM entities WHERE project_id = ?", (project_id,)).fetchall()
- existing_names = {e['name'].lower() for e in existing_entities}
+ existing_names = {e["name"].lower() for e in existing_entities}
# 简单的关键词提取(实际可以使用更复杂的 NLP 方法)
# 查找大写的词组(可能是专有名词)
- potential_entities = re.findall(r'[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*', all_text)
+ potential_entities = re.findall(r"[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*", all_text)
# 统计频率
freq = defaultdict(int)
@@ -1822,20 +1846,19 @@ class KnowledgeGapDetection:
# 找出高频但未提取的词
for entity, count in freq.items():
if count >= 3: # 出现3次以上
- gaps.append(KnowledgeGap(
- gap_id=f"gap_missing_{hash(entity) % 10000}",
- gap_type="missing_key_entity",
- entity_id=None,
- entity_name=None,
- description=f"文本中频繁提及 '{entity}' 但未提取为实体(出现 {count} 次)",
- severity="low",
- suggestions=[
- f"考虑将 '{entity}' 添加为实体",
- "检查实体提取算法是否需要优化"
- ],
- related_entities=[],
- metadata={"mention_count": count}
- ))
+ gaps.append(
+ KnowledgeGap(
+ gap_id=f"gap_missing_{hash(entity) % 10000}",
+ gap_type="missing_key_entity",
+ entity_id=None,
+ entity_name=None,
+ description=f"文本中频繁提及 '{entity}' 但未提取为实体(出现 {count} 次)",
+ severity="low",
+ suggestions=[f"考虑将 '{entity}' 添加为实体", "检查实体提取算法是否需要优化"],
+ related_entities=[],
+ metadata={"mention_count": count},
+ )
+ )
conn.close()
return gaps[:10] # 限制数量
@@ -1853,12 +1876,15 @@ class KnowledgeGapDetection:
conn = self._get_conn()
# 基础统计
- stats = conn.execute("""
+ stats = conn.execute(
+ """
SELECT
(SELECT COUNT(*) FROM entities WHERE project_id = ?) as entity_count,
(SELECT COUNT(*) FROM entity_relations WHERE project_id = ?) as relation_count,
(SELECT COUNT(*) FROM transcripts WHERE project_id = ?) as transcript_count
- """, (project_id, project_id, project_id)).fetchone()
+ """,
+ (project_id, project_id, project_id),
+ ).fetchone()
# 计算完整性分数
gaps = self.analyze_project(project_id)
@@ -1884,17 +1910,13 @@ class KnowledgeGapDetection:
"project_id": project_id,
"completeness_score": score,
"statistics": {
- "entity_count": stats['entity_count'],
- "relation_count": stats['relation_count'],
- "transcript_count": stats['transcript_count']
- },
- "gap_summary": {
- "total": len(gaps),
- "by_type": dict(gap_by_type),
- "by_severity": severity_count
+ "entity_count": stats["entity_count"],
+ "relation_count": stats["relation_count"],
+ "transcript_count": stats["transcript_count"],
},
+ "gap_summary": {"total": len(gaps), "by_type": dict(gap_by_type), "by_severity": severity_count},
"top_gaps": [g.to_dict() for g in gaps[:10]],
- "recommendations": self._generate_recommendations(gaps)
+ "recommendations": self._generate_recommendations(gaps),
}
def _generate_recommendations(self, gaps: list[KnowledgeGap]) -> list[str]:
@@ -1926,6 +1948,7 @@ class KnowledgeGapDetection:
# ==================== 搜索管理器 ====================
+
class SearchManager:
"""
搜索管理器 - 统一入口
@@ -1940,8 +1963,7 @@ class SearchManager:
self.path_discovery = EntityPathDiscovery(db_path)
self.gap_detection = KnowledgeGapDetection(db_path)
- def hybrid_search(self, query: str, project_id: str | None = None,
- limit: int = 20) -> dict:
+ def hybrid_search(self, query: str, project_id: str | None = None, limit: int = 20) -> dict:
"""
混合搜索(全文 + 语义)
@@ -1954,16 +1976,12 @@ class SearchManager:
Dict: 混合搜索结果
"""
# 全文搜索
- fulltext_results = self.fulltext_search.search(
- query, project_id, limit=limit
- )
+ fulltext_results = self.fulltext_search.search(query, project_id, limit=limit)
# 语义搜索
semantic_results = []
if self.semantic_search.is_available():
- semantic_results = self.semantic_search.search(
- query, project_id, top_k=limit
- )
+ semantic_results = self.semantic_search.search(query, project_id, top_k=limit)
# 合并结果(去重并加权)
combined = {}
@@ -1979,7 +1997,7 @@ class SearchManager:
"fulltext_score": r.score,
"semantic_score": 0,
"combined_score": r.score * 0.6, # 全文权重 60%
- "highlights": r.highlights
+ "highlights": r.highlights,
}
# 添加语义搜索结果
@@ -1997,7 +2015,7 @@ class SearchManager:
"fulltext_score": 0,
"semantic_score": r.similarity,
"combined_score": r.similarity * 0.4,
- "highlights": []
+ "highlights": [],
}
# 排序
@@ -2010,7 +2028,7 @@ class SearchManager:
"total": len(results),
"fulltext_count": len(fulltext_results),
"semantic_count": len(semantic_results),
- "results": results[:limit]
+ "results": results[:limit],
}
def index_project(self, project_id: str) -> dict:
@@ -2035,13 +2053,12 @@ class SearchManager:
# 索引转录文本
transcripts = conn.execute(
- "SELECT id, project_id, full_text FROM transcripts WHERE project_id = ?",
- (project_id,)
+ "SELECT id, project_id, full_text FROM transcripts WHERE project_id = ?", (project_id,)
).fetchall()
for t in transcripts:
- if t['full_text'] and self.semantic_search.index_embedding(
- t['id'], 'transcript', t['project_id'], t['full_text']
+ if t["full_text"] and self.semantic_search.index_embedding(
+ t["id"], "transcript", t["project_id"], t["full_text"]
):
semantic_stats["indexed"] += 1
else:
@@ -2049,26 +2066,19 @@ class SearchManager:
# 索引实体
entities = conn.execute(
- "SELECT id, project_id, name, definition FROM entities WHERE project_id = ?",
- (project_id,)
+ "SELECT id, project_id, name, definition FROM entities WHERE project_id = ?", (project_id,)
).fetchall()
for e in entities:
text = f"{e['name']} {e['definition'] or ''}"
- if self.semantic_search.index_embedding(
- e['id'], 'entity', e['project_id'], text
- ):
+ if self.semantic_search.index_embedding(e["id"], "entity", e["project_id"], text):
semantic_stats["indexed"] += 1
else:
semantic_stats["errors"] += 1
conn.close()
- return {
- "project_id": project_id,
- "fulltext": fulltext_stats,
- "semantic": semantic_stats
- }
+ return {"project_id": project_id, "fulltext": fulltext_stats, "semantic": semantic_stats}
def get_search_stats(self, project_id: str | None = None) -> dict:
"""获取搜索统计信息"""
@@ -2080,15 +2090,13 @@ class SearchManager:
# 全文索引统计
fulltext_count = conn.execute(
- f"SELECT COUNT(*) as count FROM search_indexes {where_clause}",
- params
- ).fetchone()['count']
+ f"SELECT COUNT(*) as count FROM search_indexes {where_clause}", params
+ ).fetchone()["count"]
# 语义索引统计
- semantic_count = conn.execute(
- f"SELECT COUNT(*) as count FROM embeddings {where_clause}",
- params
- ).fetchone()['count']
+ semantic_count = conn.execute(f"SELECT COUNT(*) as count FROM embeddings {where_clause}", params).fetchone()[
+ "count"
+ ]
# 按类型统计
type_stats = {}
@@ -2097,9 +2105,9 @@ class SearchManager:
"""SELECT content_type, COUNT(*) as count
FROM search_indexes WHERE project_id = ?
GROUP BY content_type""",
- (project_id,)
+ (project_id,),
).fetchall()
- type_stats = {r['content_type']: r['count'] for r in rows}
+ type_stats = {r["content_type"]: r["count"] for r in rows}
conn.close()
@@ -2108,7 +2116,7 @@ class SearchManager:
"fulltext_indexed": fulltext_count,
"semantic_indexed": semantic_count,
"by_content_type": type_stats,
- "semantic_search_available": self.semantic_search.is_available()
+ "semantic_search_available": self.semantic_search.is_available(),
}
@@ -2125,22 +2133,19 @@ def get_search_manager(db_path: str = "insightflow.db") -> SearchManager:
# 便捷函数
-def fulltext_search(query: str, project_id: str | None = None,
- limit: int = 20) -> list[SearchResult]:
+def fulltext_search(query: str, project_id: str | None = None, limit: int = 20) -> list[SearchResult]:
"""全文搜索便捷函数"""
manager = get_search_manager()
return manager.fulltext_search.search(query, project_id, limit=limit)
-def semantic_search(query: str, project_id: str | None = None,
- top_k: int = 10) -> list[SemanticSearchResult]:
+def semantic_search(query: str, project_id: str | None = None, top_k: int = 10) -> list[SemanticSearchResult]:
"""语义搜索便捷函数"""
manager = get_search_manager()
return manager.semantic_search.search(query, project_id, top_k=top_k)
-def find_entity_path(source_id: str, target_id: str,
- max_depth: int = 5) -> EntityPath | None:
+def find_entity_path(source_id: str, target_id: str, max_depth: int = 5) -> EntityPath | None:
"""查找实体路径便捷函数"""
manager = get_search_manager()
return manager.path_discovery.find_shortest_path(source_id, target_id, max_depth)
diff --git a/backend/security_manager.py b/backend/security_manager.py
index 777bb5b..940a60b 100644
--- a/backend/security_manager.py
+++ b/backend/security_manager.py
@@ -19,6 +19,7 @@ try:
from cryptography.fernet import Fernet
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
+
CRYPTO_AVAILABLE = True
except ImportError:
CRYPTO_AVAILABLE = False
@@ -27,6 +28,7 @@ except ImportError:
class AuditActionType(Enum):
"""审计动作类型"""
+
CREATE = "create"
READ = "read"
UPDATE = "update"
@@ -49,26 +51,29 @@ class AuditActionType(Enum):
class DataSensitivityLevel(Enum):
"""数据敏感度级别"""
- PUBLIC = "public" # 公开
- INTERNAL = "internal" # 内部
+
+ PUBLIC = "public" # 公开
+ INTERNAL = "internal" # 内部
CONFIDENTIAL = "confidential" # 机密
- SECRET = "secret" # 绝密
+ SECRET = "secret" # 绝密
class MaskingRuleType(Enum):
"""脱敏规则类型"""
- PHONE = "phone" # 手机号
- EMAIL = "email" # 邮箱
- ID_CARD = "id_card" # 身份证号
- BANK_CARD = "bank_card" # 银行卡号
- NAME = "name" # 姓名
- ADDRESS = "address" # 地址
- CUSTOM = "custom" # 自定义
+
+ PHONE = "phone" # 手机号
+ EMAIL = "email" # 邮箱
+ ID_CARD = "id_card" # 身份证号
+ BANK_CARD = "bank_card" # 银行卡号
+ NAME = "name" # 姓名
+ ADDRESS = "address" # 地址
+ CUSTOM = "custom" # 自定义
@dataclass
class AuditLog:
"""审计日志条目"""
+
id: str
action_type: str
user_id: str | None = None
@@ -90,6 +95,7 @@ class AuditLog:
@dataclass
class EncryptionConfig:
"""加密配置"""
+
id: str
project_id: str
is_enabled: bool = False
@@ -107,6 +113,7 @@ class EncryptionConfig:
@dataclass
class MaskingRule:
"""脱敏规则"""
+
id: str
project_id: str
name: str
@@ -126,6 +133,7 @@ class MaskingRule:
@dataclass
class DataAccessPolicy:
"""数据访问策略"""
+
id: str
project_id: str
name: str
@@ -147,6 +155,7 @@ class DataAccessPolicy:
@dataclass
class AccessRequest:
"""访问请求(用于需要审批的访问)"""
+
id: str
policy_id: str
user_id: str
@@ -166,30 +175,15 @@ class SecurityManager:
# 预定义脱敏规则
DEFAULT_MASKING_RULES = {
- MaskingRuleType.PHONE: {
- "pattern": r"(\d{3})\d{4}(\d{4})",
- "replacement": r"\1****\2"
- },
- MaskingRuleType.EMAIL: {
- "pattern": r"(\w{1,3})\w+(@\w+\.\w+)",
- "replacement": r"\1***\2"
- },
- MaskingRuleType.ID_CARD: {
- "pattern": r"(\d{6})\d{8}(\d{4})",
- "replacement": r"\1********\2"
- },
- MaskingRuleType.BANK_CARD: {
- "pattern": r"(\d{4})\d+(\d{4})",
- "replacement": r"\1 **** **** \2"
- },
- MaskingRuleType.NAME: {
- "pattern": r"([\u4e00-\u9fa5])[\u4e00-\u9fa5]+",
- "replacement": r"\1**"
- },
+ MaskingRuleType.PHONE: {"pattern": r"(\d{3})\d{4}(\d{4})", "replacement": r"\1****\2"},
+ MaskingRuleType.EMAIL: {"pattern": r"(\w{1,3})\w+(@\w+\.\w+)", "replacement": r"\1***\2"},
+ MaskingRuleType.ID_CARD: {"pattern": r"(\d{6})\d{8}(\d{4})", "replacement": r"\1********\2"},
+ MaskingRuleType.BANK_CARD: {"pattern": r"(\d{4})\d+(\d{4})", "replacement": r"\1 **** **** \2"},
+ MaskingRuleType.NAME: {"pattern": r"([\u4e00-\u9fa5])[\u4e00-\u9fa5]+", "replacement": r"\1**"},
MaskingRuleType.ADDRESS: {
"pattern": r"([\u4e00-\u9fa5]{2,})([\u4e00-\u9fa5]+路|街|巷|号)(.+)",
- "replacement": r"\1\2***"
- }
+ "replacement": r"\1\2***",
+ },
}
def __init__(self, db_path: str = "insightflow.db"):
@@ -200,7 +194,7 @@ class SecurityManager:
self._local = {}
self._init_db()
- def _init_db(self):
+ def _init_db(self) -> None:
"""初始化数据库表"""
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
@@ -308,9 +302,7 @@ class SecurityManager:
def _generate_id(self) -> str:
"""生成唯一ID"""
- return hashlib.sha256(
- f"{datetime.now().isoformat()}{secrets.token_hex(16)}".encode()
- ).hexdigest()[:32]
+ return hashlib.sha256(f"{datetime.now().isoformat()}{secrets.token_hex(16)}".encode()).hexdigest()[:32]
# ==================== 审计日志 ====================
@@ -326,7 +318,7 @@ class SecurityManager:
before_value: str | None = None,
after_value: str | None = None,
success: bool = True,
- error_message: str | None = None
+ error_message: str | None = None,
) -> AuditLog:
"""记录审计日志"""
log = AuditLog(
@@ -341,22 +333,34 @@ class SecurityManager:
before_value=before_value,
after_value=after_value,
success=success,
- error_message=error_message
+ error_message=error_message,
)
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
- cursor.execute("""
+ cursor.execute(
+ """
INSERT INTO audit_logs
(id, action_type, user_id, user_ip, user_agent, resource_type, resource_id,
action_details, before_value, after_value, success, error_message, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
- """, (
- log.id, log.action_type, log.user_id, log.user_ip, log.user_agent,
- log.resource_type, log.resource_id, log.action_details,
- log.before_value, log.after_value, int(log.success),
- log.error_message, log.created_at
- ))
+ """,
+ (
+ log.id,
+ log.action_type,
+ log.user_id,
+ log.user_ip,
+ log.user_agent,
+ log.resource_type,
+ log.resource_id,
+ log.action_details,
+ log.before_value,
+ log.after_value,
+ int(log.success),
+ log.error_message,
+ log.created_at,
+ ),
+ )
conn.commit()
conn.close()
@@ -372,7 +376,7 @@ class SecurityManager:
end_time: str | None = None,
success: bool | None = None,
limit: int = 100,
- offset: int = 0
+ offset: int = 0,
) -> list[AuditLog]:
"""查询审计日志"""
conn = sqlite3.connect(self.db_path)
@@ -429,18 +433,14 @@ class SecurityManager:
after_value=row[9],
success=bool(row[10]),
error_message=row[11],
- created_at=row[12]
+ created_at=row[12],
)
logs.append(log)
conn.close()
return logs
- def get_audit_stats(
- self,
- start_time: str | None = None,
- end_time: str | None = None
- ) -> dict[str, Any]:
+ def get_audit_stats(self, start_time: str | None = None, end_time: str | None = None) -> dict[str, Any]:
"""获取审计统计"""
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
@@ -460,12 +460,7 @@ class SecurityManager:
cursor.execute(query, params)
rows = cursor.fetchall()
- stats = {
- "total_actions": 0,
- "success_count": 0,
- "failure_count": 0,
- "action_breakdown": {}
- }
+ stats = {"total_actions": 0, "success_count": 0, "failure_count": 0, "action_breakdown": {}}
for action_type, success, count in rows:
stats["total_actions"] += count
@@ -500,11 +495,7 @@ class SecurityManager:
)
return base64.urlsafe_b64encode(kdf.derive(password.encode()))
- def enable_encryption(
- self,
- project_id: str,
- master_password: str
- ) -> EncryptionConfig:
+ def enable_encryption(self, project_id: str, master_password: str) -> EncryptionConfig:
"""启用项目加密"""
if not CRYPTO_AVAILABLE:
raise RuntimeError("cryptography library not available")
@@ -523,43 +514,54 @@ class SecurityManager:
encryption_type="aes-256-gcm",
key_derivation="pbkdf2",
master_key_hash=key_hash,
- salt=salt
+ salt=salt,
)
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
# 检查是否已存在配置
- cursor.execute(
- "SELECT id FROM encryption_configs WHERE project_id = ?",
- (project_id,)
- )
+ cursor.execute("SELECT id FROM encryption_configs WHERE project_id = ?", (project_id,))
existing = cursor.fetchone()
if existing:
- cursor.execute("""
+ cursor.execute(
+ """
UPDATE encryption_configs
SET is_enabled = 1, encryption_type = ?, key_derivation = ?,
master_key_hash = ?, salt = ?, updated_at = ?
WHERE project_id = ?
- """, (
- config.encryption_type, config.key_derivation,
- config.master_key_hash, config.salt,
- config.updated_at, project_id
- ))
+ """,
+ (
+ config.encryption_type,
+ config.key_derivation,
+ config.master_key_hash,
+ config.salt,
+ config.updated_at,
+ project_id,
+ ),
+ )
config.id = existing[0]
else:
- cursor.execute("""
+ cursor.execute(
+ """
INSERT INTO encryption_configs
(id, project_id, is_enabled, encryption_type, key_derivation,
master_key_hash, salt, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
- """, (
- config.id, config.project_id, int(config.is_enabled),
- config.encryption_type, config.key_derivation,
- config.master_key_hash, config.salt,
- config.created_at, config.updated_at
- ))
+ """,
+ (
+ config.id,
+ config.project_id,
+ int(config.is_enabled),
+ config.encryption_type,
+ config.key_derivation,
+ config.master_key_hash,
+ config.salt,
+ config.created_at,
+ config.updated_at,
+ ),
+ )
conn.commit()
conn.close()
@@ -569,16 +571,12 @@ class SecurityManager:
action_type=AuditActionType.ENCRYPTION_ENABLE,
resource_type="project",
resource_id=project_id,
- action_details={"encryption_type": config.encryption_type}
+ action_details={"encryption_type": config.encryption_type},
)
return config
- def disable_encryption(
- self,
- project_id: str,
- master_password: str
- ) -> bool:
+ def disable_encryption(self, project_id: str, master_password: str) -> bool:
"""禁用项目加密"""
# 验证密码
if not self.verify_encryption_password(project_id, master_password):
@@ -587,29 +585,24 @@ class SecurityManager:
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
- cursor.execute("""
+ cursor.execute(
+ """
UPDATE encryption_configs
SET is_enabled = 0, updated_at = ?
WHERE project_id = ?
- """, (datetime.now().isoformat(), project_id))
+ """,
+ (datetime.now().isoformat(), project_id),
+ )
conn.commit()
conn.close()
# 记录审计日志
- self.log_audit(
- action_type=AuditActionType.ENCRYPTION_DISABLE,
- resource_type="project",
- resource_id=project_id
- )
+ self.log_audit(action_type=AuditActionType.ENCRYPTION_DISABLE, resource_type="project", resource_id=project_id)
return True
- def verify_encryption_password(
- self,
- project_id: str,
- password: str
- ) -> bool:
+ def verify_encryption_password(self, project_id: str, password: str) -> bool:
"""验证加密密码"""
if not CRYPTO_AVAILABLE:
return False
@@ -617,10 +610,7 @@ class SecurityManager:
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
- cursor.execute(
- "SELECT master_key_hash, salt FROM encryption_configs WHERE project_id = ?",
- (project_id,)
- )
+ cursor.execute("SELECT master_key_hash, salt FROM encryption_configs WHERE project_id = ?", (project_id,))
row = cursor.fetchone()
conn.close()
@@ -638,10 +628,7 @@ class SecurityManager:
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
- cursor.execute(
- "SELECT * FROM encryption_configs WHERE project_id = ?",
- (project_id,)
- )
+ cursor.execute("SELECT * FROM encryption_configs WHERE project_id = ?", (project_id,))
row = cursor.fetchone()
conn.close()
@@ -657,15 +644,10 @@ class SecurityManager:
master_key_hash=row[5],
salt=row[6],
created_at=row[7],
- updated_at=row[8]
+ updated_at=row[8],
)
- def encrypt_data(
- self,
- data: str,
- password: str,
- salt: str | None = None
- ) -> tuple[str, str]:
+ def encrypt_data(self, data: str, password: str, salt: str | None = None) -> tuple[str, str]:
"""加密数据"""
if not CRYPTO_AVAILABLE:
raise RuntimeError("cryptography library not available")
@@ -679,12 +661,7 @@ class SecurityManager:
return base64.b64encode(encrypted).decode(), salt
- def decrypt_data(
- self,
- encrypted_data: str,
- password: str,
- salt: str
- ) -> str:
+ def decrypt_data(self, encrypted_data: str, password: str, salt: str) -> str:
"""解密数据"""
if not CRYPTO_AVAILABLE:
raise RuntimeError("cryptography library not available")
@@ -705,7 +682,7 @@ class SecurityManager:
pattern: str | None = None,
replacement: str | None = None,
description: str | None = None,
- priority: int = 0
+ priority: int = 0,
) -> MaskingRule:
"""创建脱敏规则"""
# 使用预定义规则或自定义规则
@@ -722,22 +699,33 @@ class SecurityManager:
pattern=pattern or "",
replacement=replacement or "****",
description=description,
- priority=priority
+ priority=priority,
)
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
- cursor.execute("""
+ cursor.execute(
+ """
INSERT INTO masking_rules
(id, project_id, name, rule_type, pattern, replacement,
is_active, priority, description, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
- """, (
- rule.id, rule.project_id, rule.name, rule.rule_type,
- rule.pattern, rule.replacement, int(rule.is_active),
- rule.priority, rule.description, rule.created_at, rule.updated_at
- ))
+ """,
+ (
+ rule.id,
+ rule.project_id,
+ rule.name,
+ rule.rule_type,
+ rule.pattern,
+ rule.replacement,
+ int(rule.is_active),
+ rule.priority,
+ rule.description,
+ rule.created_at,
+ rule.updated_at,
+ ),
+ )
conn.commit()
conn.close()
@@ -747,16 +735,12 @@ class SecurityManager:
action_type=AuditActionType.DATA_MASKING,
resource_type="project",
resource_id=project_id,
- action_details={"action": "create_rule", "rule_name": name}
+ action_details={"action": "create_rule", "rule_name": name},
)
return rule
- def get_masking_rules(
- self,
- project_id: str,
- active_only: bool = True
- ) -> list[MaskingRule]:
+ def get_masking_rules(self, project_id: str, active_only: bool = True) -> list[MaskingRule]:
"""获取脱敏规则"""
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
@@ -775,27 +759,25 @@ class SecurityManager:
rules = []
for row in rows:
- rules.append(MaskingRule(
- id=row[0],
- project_id=row[1],
- name=row[2],
- rule_type=row[3],
- pattern=row[4],
- replacement=row[5],
- is_active=bool(row[6]),
- priority=row[7],
- description=row[8],
- created_at=row[9],
- updated_at=row[10]
- ))
+ rules.append(
+ MaskingRule(
+ id=row[0],
+ project_id=row[1],
+ name=row[2],
+ rule_type=row[3],
+ pattern=row[4],
+ replacement=row[5],
+ is_active=bool(row[6]),
+ priority=row[7],
+ description=row[8],
+ created_at=row[9],
+ updated_at=row[10],
+ )
+ )
return rules
- def update_masking_rule(
- self,
- rule_id: str,
- **kwargs
- ) -> MaskingRule | None:
+ def update_masking_rule(self, rule_id: str, **kwargs) -> MaskingRule | None:
"""更新脱敏规则"""
allowed_fields = ["name", "pattern", "replacement", "is_active", "priority", "description"]
@@ -818,11 +800,14 @@ class SecurityManager:
params.append(datetime.now().isoformat())
params.append(rule_id)
- cursor.execute(f"""
+ cursor.execute(
+ f"""
UPDATE masking_rules
SET {', '.join(set_clauses)}
WHERE id = ?
- """, params)
+ """,
+ params,
+ )
conn.commit()
conn.close()
@@ -848,7 +833,7 @@ class SecurityManager:
priority=row[7],
description=row[8],
created_at=row[9],
- updated_at=row[10]
+ updated_at=row[10],
)
def delete_masking_rule(self, rule_id: str) -> bool:
@@ -864,12 +849,7 @@ class SecurityManager:
return success
- def apply_masking(
- self,
- text: str,
- project_id: str,
- rule_types: list[MaskingRuleType] | None = None
- ) -> str:
+ def apply_masking(self, text: str, project_id: str, rule_types: list[MaskingRuleType] | None = None) -> str:
"""应用脱敏规则到文本"""
rules = self.get_masking_rules(project_id)
@@ -884,22 +864,14 @@ class SecurityManager:
continue
try:
- masked_text = re.sub(
- rule.pattern,
- rule.replacement,
- masked_text
- )
+ masked_text = re.sub(rule.pattern, rule.replacement, masked_text)
except re.error:
# 忽略无效的正则表达式
continue
return masked_text
- def apply_masking_to_entity(
- self,
- entity_data: dict[str, Any],
- project_id: str
- ) -> dict[str, Any]:
+ def apply_masking_to_entity(self, entity_data: dict[str, Any], project_id: str) -> dict[str, Any]:
"""对实体数据应用脱敏"""
masked_data = entity_data.copy()
@@ -924,7 +896,7 @@ class SecurityManager:
allowed_ips: list[str] | None = None,
time_restrictions: dict | None = None,
max_access_count: int | None = None,
- require_approval: bool = False
+ require_approval: bool = False,
) -> DataAccessPolicy:
"""创建数据访问策略"""
policy = DataAccessPolicy(
@@ -937,36 +909,43 @@ class SecurityManager:
allowed_ips=json.dumps(allowed_ips) if allowed_ips else None,
time_restrictions=json.dumps(time_restrictions) if time_restrictions else None,
max_access_count=max_access_count,
- require_approval=require_approval
+ require_approval=require_approval,
)
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
- cursor.execute("""
+ cursor.execute(
+ """
INSERT INTO data_access_policies
(id, project_id, name, description, allowed_users, allowed_roles,
allowed_ips, time_restrictions, max_access_count, require_approval,
is_active, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
- """, (
- policy.id, policy.project_id, policy.name, policy.description,
- policy.allowed_users, policy.allowed_roles, policy.allowed_ips,
- policy.time_restrictions, policy.max_access_count,
- int(policy.require_approval), int(policy.is_active),
- policy.created_at, policy.updated_at
- ))
+ """,
+ (
+ policy.id,
+ policy.project_id,
+ policy.name,
+ policy.description,
+ policy.allowed_users,
+ policy.allowed_roles,
+ policy.allowed_ips,
+ policy.time_restrictions,
+ policy.max_access_count,
+ int(policy.require_approval),
+ int(policy.is_active),
+ policy.created_at,
+ policy.updated_at,
+ ),
+ )
conn.commit()
conn.close()
return policy
- def get_access_policies(
- self,
- project_id: str,
- active_only: bool = True
- ) -> list[DataAccessPolicy]:
+ def get_access_policies(self, project_id: str, active_only: bool = True) -> list[DataAccessPolicy]:
"""获取数据访问策略"""
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
@@ -983,38 +962,34 @@ class SecurityManager:
policies = []
for row in rows:
- policies.append(DataAccessPolicy(
- id=row[0],
- project_id=row[1],
- name=row[2],
- description=row[3],
- allowed_users=row[4],
- allowed_roles=row[5],
- allowed_ips=row[6],
- time_restrictions=row[7],
- max_access_count=row[8],
- require_approval=bool(row[9]),
- is_active=bool(row[10]),
- created_at=row[11],
- updated_at=row[12]
- ))
+ policies.append(
+ DataAccessPolicy(
+ id=row[0],
+ project_id=row[1],
+ name=row[2],
+ description=row[3],
+ allowed_users=row[4],
+ allowed_roles=row[5],
+ allowed_ips=row[6],
+ time_restrictions=row[7],
+ max_access_count=row[8],
+ require_approval=bool(row[9]),
+ is_active=bool(row[10]),
+ created_at=row[11],
+ updated_at=row[12],
+ )
+ )
return policies
def check_access_permission(
- self,
- policy_id: str,
- user_id: str,
- user_ip: str | None = None
+ self, policy_id: str, user_id: str, user_ip: str | None = None
) -> tuple[bool, str | None]:
"""检查访问权限"""
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
- cursor.execute(
- "SELECT * FROM data_access_policies WHERE id = ? AND is_active = 1",
- (policy_id,)
- )
+ cursor.execute("SELECT * FROM data_access_policies WHERE id = ? AND is_active = 1", (policy_id,))
row = cursor.fetchone()
conn.close()
@@ -1034,7 +1009,7 @@ class SecurityManager:
require_approval=bool(row[9]),
is_active=bool(row[10]),
created_at=row[11],
- updated_at=row[12]
+ updated_at=row[12],
)
# 检查用户白名单
@@ -1074,11 +1049,14 @@ class SecurityManager:
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
- cursor.execute("""
+ cursor.execute(
+ """
SELECT * FROM access_requests
WHERE policy_id = ? AND user_id = ? AND status = 'approved'
AND (expires_at IS NULL OR expires_at > ?)
- """, (policy_id, user_id, datetime.now().isoformat()))
+ """,
+ (policy_id, user_id, datetime.now().isoformat()),
+ )
request = cursor.fetchone()
conn.close()
@@ -1104,11 +1082,7 @@ class SecurityManager:
return ip == pattern
def create_access_request(
- self,
- policy_id: str,
- user_id: str,
- request_reason: str | None = None,
- expires_hours: int = 24
+ self, policy_id: str, user_id: str, request_reason: str | None = None, expires_hours: int = 24
) -> AccessRequest:
"""创建访问请求"""
request = AccessRequest(
@@ -1116,21 +1090,28 @@ class SecurityManager:
policy_id=policy_id,
user_id=user_id,
request_reason=request_reason,
- expires_at=(datetime.now() + timedelta(hours=expires_hours)).isoformat()
+ expires_at=(datetime.now() + timedelta(hours=expires_hours)).isoformat(),
)
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
- cursor.execute("""
+ cursor.execute(
+ """
INSERT INTO access_requests
(id, policy_id, user_id, request_reason, status, expires_at, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?)
- """, (
- request.id, request.policy_id, request.user_id,
- request.request_reason, request.status, request.expires_at,
- request.created_at
- ))
+ """,
+ (
+ request.id,
+ request.policy_id,
+ request.user_id,
+ request.request_reason,
+ request.status,
+ request.expires_at,
+ request.created_at,
+ ),
+ )
conn.commit()
conn.close()
@@ -1138,10 +1119,7 @@ class SecurityManager:
return request
def approve_access_request(
- self,
- request_id: str,
- approved_by: str,
- expires_hours: int = 24
+ self, request_id: str, approved_by: str, expires_hours: int = 24
) -> AccessRequest | None:
"""批准访问请求"""
conn = sqlite3.connect(self.db_path)
@@ -1150,11 +1128,14 @@ class SecurityManager:
expires_at = (datetime.now() + timedelta(hours=expires_hours)).isoformat()
approved_at = datetime.now().isoformat()
- cursor.execute("""
+ cursor.execute(
+ """
UPDATE access_requests
SET status = 'approved', approved_by = ?, approved_at = ?, expires_at = ?
WHERE id = ?
- """, (approved_by, approved_at, expires_at, request_id))
+ """,
+ (approved_by, approved_at, expires_at, request_id),
+ )
conn.commit()
@@ -1175,23 +1156,22 @@ class SecurityManager:
approved_by=row[5],
approved_at=row[6],
expires_at=row[7],
- created_at=row[8]
+ created_at=row[8],
)
- def reject_access_request(
- self,
- request_id: str,
- rejected_by: str
- ) -> AccessRequest | None:
+ def reject_access_request(self, request_id: str, rejected_by: str) -> AccessRequest | None:
"""拒绝访问请求"""
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
- cursor.execute("""
+ cursor.execute(
+ """
UPDATE access_requests
SET status = 'rejected', approved_by = ?
WHERE id = ?
- """, (rejected_by, request_id))
+ """,
+ (rejected_by, request_id),
+ )
conn.commit()
@@ -1211,7 +1191,7 @@ class SecurityManager:
approved_by=row[5],
approved_at=row[6],
expires_at=row[7],
- created_at=row[8]
+ created_at=row[8],
)
diff --git a/backend/subscription_manager.py b/backend/subscription_manager.py
index 1f0a80f..2aef34d 100644
--- a/backend/subscription_manager.py
+++ b/backend/subscription_manager.py
@@ -24,63 +24,69 @@ logger = logging.getLogger(__name__)
class SubscriptionStatus(StrEnum):
"""订阅状态"""
- ACTIVE = "active" # 活跃
- CANCELLED = "cancelled" # 已取消
- EXPIRED = "expired" # 已过期
- PAST_DUE = "past_due" # 逾期
- TRIAL = "trial" # 试用中
- PENDING = "pending" # 待支付
+
+ ACTIVE = "active" # 活跃
+ CANCELLED = "cancelled" # 已取消
+ EXPIRED = "expired" # 已过期
+ PAST_DUE = "past_due" # 逾期
+ TRIAL = "trial" # 试用中
+ PENDING = "pending" # 待支付
class PaymentProvider(StrEnum):
"""支付提供商"""
- STRIPE = "stripe" # Stripe
- ALIPAY = "alipay" # 支付宝
- WECHAT = "wechat" # 微信支付
+
+ STRIPE = "stripe" # Stripe
+ ALIPAY = "alipay" # 支付宝
+ WECHAT = "wechat" # 微信支付
BANK_TRANSFER = "bank_transfer" # 银行转账
class PaymentStatus(StrEnum):
"""支付状态"""
- PENDING = "pending" # 待支付
- PROCESSING = "processing" # 处理中
- COMPLETED = "completed" # 已完成
- FAILED = "failed" # 失败
- REFUNDED = "refunded" # 已退款
+
+ PENDING = "pending" # 待支付
+ PROCESSING = "processing" # 处理中
+ COMPLETED = "completed" # 已完成
+ FAILED = "failed" # 失败
+ REFUNDED = "refunded" # 已退款
PARTIAL_REFUNDED = "partial_refunded" # 部分退款
class InvoiceStatus(StrEnum):
"""发票状态"""
- DRAFT = "draft" # 草稿
- ISSUED = "issued" # 已开具
- PAID = "paid" # 已支付
- OVERDUE = "overdue" # 逾期
- VOID = "void" # 作废
+
+ DRAFT = "draft" # 草稿
+ ISSUED = "issued" # 已开具
+ PAID = "paid" # 已支付
+ OVERDUE = "overdue" # 逾期
+ VOID = "void" # 作废
CREDIT_NOTE = "credit_note" # 贷项通知单
class RefundStatus(StrEnum):
"""退款状态"""
- PENDING = "pending" # 待处理
- APPROVED = "approved" # 已批准
- REJECTED = "rejected" # 已拒绝
- COMPLETED = "completed" # 已完成
- FAILED = "failed" # 失败
+
+ PENDING = "pending" # 待处理
+ APPROVED = "approved" # 已批准
+ REJECTED = "rejected" # 已拒绝
+ COMPLETED = "completed" # 已完成
+ FAILED = "failed" # 失败
@dataclass
class SubscriptionPlan:
"""订阅计划数据类"""
+
id: str
name: str
- tier: str # free/pro/enterprise
+ tier: str # free/pro/enterprise
description: str
- price_monthly: float # 月付价格
- price_yearly: float # 年付价格
- currency: str # CNY/USD
- features: list[str] # 功能列表
- limits: dict[str, Any] # 资源限制
+ price_monthly: float # 月付价格
+ price_yearly: float # 年付价格
+ currency: str # CNY/USD
+ features: list[str] # 功能列表
+ limits: dict[str, Any] # 资源限制
is_active: bool
created_at: datetime
updated_at: datetime
@@ -90,6 +96,7 @@ class SubscriptionPlan:
@dataclass
class Subscription:
"""订阅数据类"""
+
id: str
tenant_id: str
plan_id: str
@@ -110,13 +117,14 @@ class Subscription:
@dataclass
class UsageRecord:
"""用量记录数据类"""
+
id: str
tenant_id: str
- resource_type: str # transcription/storage/api_call
- quantity: float # 使用量
- unit: str # minutes/mb/count
+ resource_type: str # transcription/storage/api_call
+ quantity: float # 使用量
+ unit: str # minutes/mb/count
recorded_at: datetime
- cost: float # 费用
+ cost: float # 费用
description: str | None
metadata: dict[str, Any]
@@ -124,6 +132,7 @@ class UsageRecord:
@dataclass
class Payment:
"""支付记录数据类"""
+
id: str
tenant_id: str
subscription_id: str | None
@@ -145,6 +154,7 @@ class Payment:
@dataclass
class Invoice:
"""发票数据类"""
+
id: str
tenant_id: str
subscription_id: str | None
@@ -168,6 +178,7 @@ class Invoice:
@dataclass
class Refund:
"""退款数据类"""
+
id: str
tenant_id: str
payment_id: str
@@ -190,14 +201,15 @@ class Refund:
@dataclass
class BillingHistory:
"""账单历史数据类"""
+
id: str
tenant_id: str
- type: str # subscription/usage/payment/refund
+ type: str # subscription/usage/payment/refund
amount: float
currency: str
description: str
- reference_id: str # 关联的订阅/支付/退款ID
- balance_after: float # 操作后余额
+ reference_id: str # 关联的订阅/支付/退款ID
+ balance_after: float # 操作后余额
created_at: datetime
metadata: dict[str, Any]
@@ -214,21 +226,15 @@ class SubscriptionManager:
"price_monthly": 0.0,
"price_yearly": 0.0,
"currency": "CNY",
- "features": [
- "basic_analysis",
- "export_png",
- "3_projects",
- "100_mb_storage",
- "60_min_transcription"
- ],
+ "features": ["basic_analysis", "export_png", "3_projects", "100_mb_storage", "60_min_transcription"],
"limits": {
"max_projects": 3,
"max_storage_mb": 100,
"max_transcription_minutes": 60,
"max_api_calls_per_day": 100,
"max_team_members": 2,
- "max_entities": 100
- }
+ "max_entities": 100,
+ },
},
"pro": {
"name": "Pro",
@@ -246,7 +252,7 @@ class SubscriptionManager:
"collaboration",
"20_projects",
"10_gb_storage",
- "600_min_transcription"
+ "600_min_transcription",
],
"limits": {
"max_projects": 20,
@@ -254,8 +260,8 @@ class SubscriptionManager:
"max_transcription_minutes": 600,
"max_api_calls_per_day": 10000,
"max_team_members": 10,
- "max_entities": 1000
- }
+ "max_entities": 1000,
+ },
},
"enterprise": {
"name": "Enterprise",
@@ -272,7 +278,7 @@ class SubscriptionManager:
"priority_support",
"custom_integration",
"sla_guarantee",
- "dedicated_manager"
+ "dedicated_manager",
],
"limits": {
"max_projects": -1,
@@ -280,33 +286,17 @@ class SubscriptionManager:
"max_transcription_minutes": -1,
"max_api_calls_per_day": -1,
"max_team_members": -1,
- "max_entities": -1
- }
- }
+ "max_entities": -1,
+ },
+ },
}
# 按量计费单价(CNY)
USAGE_PRICING = {
- "transcription": {
- "unit": "minute",
- "price": 0.5, # 0.5元/分钟
- "free_quota": 60 # 每月免费额度
- },
- "storage": {
- "unit": "gb",
- "price": 10.0, # 10元/GB/月
- "free_quota": 0.1 # 100MB免费
- },
- "api_call": {
- "unit": "1000_calls",
- "price": 5.0, # 5元/1000次
- "free_quota": 1000 # 每月免费1000次
- },
- "export": {
- "unit": "page",
- "price": 0.1, # 0.1元/页(PDF导出)
- "free_quota": 100
- }
+        "transcription": {"unit": "minute", "price": 0.5, "free_quota": 60},  # 0.5元/分钟,每月免费额度60分钟
+        "storage": {"unit": "gb", "price": 10.0, "free_quota": 0.1},  # 10元/GB/月,免费额度100MB
+        "api_call": {"unit": "1000_calls", "price": 5.0, "free_quota": 1000},  # 5元/1000次,每月免费1000次
+        "export": {"unit": "page", "price": 0.1, "free_quota": 100},  # 0.1元/页(PDF导出)
}
def __init__(self, db_path: str = "insightflow.db"):
@@ -320,7 +310,7 @@ class SubscriptionManager:
conn.row_factory = sqlite3.Row
return conn
- def _init_db(self):
+ def _init_db(self) -> None:
"""初始化数据库表"""
conn = self._get_connection()
try:
@@ -504,33 +494,36 @@ class SubscriptionManager:
finally:
conn.close()
- def _init_default_plans(self):
+ def _init_default_plans(self) -> None:
"""初始化默认订阅计划"""
conn = self._get_connection()
try:
cursor = conn.cursor()
for tier, plan_data in self.DEFAULT_PLANS.items():
- cursor.execute("""
+ cursor.execute(
+ """
INSERT OR IGNORE INTO subscription_plans
(id, name, tier, description, price_monthly, price_yearly, currency,
features, limits, is_active, created_at, updated_at, metadata)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
- """, (
- str(uuid.uuid4()),
- plan_data["name"],
- plan_data["tier"],
- plan_data["description"],
- plan_data["price_monthly"],
- plan_data["price_yearly"],
- plan_data["currency"],
- json.dumps(plan_data["features"]),
- json.dumps(plan_data["limits"]),
- 1,
- datetime.now(),
- datetime.now(),
- json.dumps({})
- ))
+ """,
+ (
+ str(uuid.uuid4()),
+ plan_data["name"],
+ plan_data["tier"],
+ plan_data["description"],
+ plan_data["price_monthly"],
+ plan_data["price_yearly"],
+ plan_data["currency"],
+ json.dumps(plan_data["features"]),
+ json.dumps(plan_data["limits"]),
+ 1,
+ datetime.now(),
+ datetime.now(),
+ json.dumps({}),
+ ),
+ )
conn.commit()
logger.info("Default subscription plans initialized")
@@ -589,10 +582,17 @@ class SubscriptionManager:
finally:
conn.close()
- def create_plan(self, name: str, tier: str, description: str,
- price_monthly: float, price_yearly: float,
- currency: str = "CNY", features: list[str] = None,
- limits: dict[str, Any] = None) -> SubscriptionPlan:
+ def create_plan(
+ self,
+ name: str,
+ tier: str,
+ description: str,
+ price_monthly: float,
+ price_yearly: float,
+ currency: str = "CNY",
+        features: list[str] | None = None,
+        limits: dict[str, Any] | None = None,
+ ) -> SubscriptionPlan:
"""创建新订阅计划"""
conn = self._get_connection()
try:
@@ -611,22 +611,33 @@ class SubscriptionManager:
is_active=True,
created_at=datetime.now(),
updated_at=datetime.now(),
- metadata={}
+ metadata={},
)
cursor = conn.cursor()
- cursor.execute("""
+ cursor.execute(
+ """
INSERT INTO subscription_plans
(id, name, tier, description, price_monthly, price_yearly, currency,
features, limits, is_active, created_at, updated_at, metadata)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
- """, (
- plan.id, plan.name, plan.tier, plan.description,
- plan.price_monthly, plan.price_yearly, plan.currency,
- json.dumps(plan.features), json.dumps(plan.limits),
- int(plan.is_active), plan.created_at, plan.updated_at,
- json.dumps(plan.metadata)
- ))
+ """,
+ (
+ plan.id,
+ plan.name,
+ plan.tier,
+ plan.description,
+ plan.price_monthly,
+ plan.price_yearly,
+ plan.currency,
+ json.dumps(plan.features),
+ json.dumps(plan.limits),
+ int(plan.is_active),
+ plan.created_at,
+ plan.updated_at,
+ json.dumps(plan.metadata),
+ ),
+ )
conn.commit()
logger.info(f"Subscription plan created: {plan_id} ({name})")
@@ -650,15 +661,23 @@ class SubscriptionManager:
updates = []
params = []
- allowed_fields = ['name', 'description', 'price_monthly', 'price_yearly',
- 'currency', 'features', 'limits', 'is_active']
+ allowed_fields = [
+ "name",
+ "description",
+ "price_monthly",
+ "price_yearly",
+ "currency",
+ "features",
+ "limits",
+ "is_active",
+ ]
for key, value in kwargs.items():
if key in allowed_fields:
updates.append(f"{key} = ?")
- if key in ['features', 'limits']:
- params.append(json.dumps(value) if value else '{}')
- elif key == 'is_active':
+ if key in ["features", "limits"]:
+ params.append(json.dumps(value) if value else "{}")
+ elif key == "is_active":
params.append(int(value))
else:
params.append(value)
@@ -671,10 +690,13 @@ class SubscriptionManager:
params.append(plan_id)
cursor = conn.cursor()
- cursor.execute(f"""
+ cursor.execute(
+ f"""
UPDATE subscription_plans SET {', '.join(updates)}
WHERE id = ?
- """, params)
+ """,
+ params,
+ )
conn.commit()
return self.get_plan(plan_id)
@@ -684,19 +706,26 @@ class SubscriptionManager:
# ==================== 订阅管理 ====================
- def create_subscription(self, tenant_id: str, plan_id: str,
- payment_provider: str | None = None,
- trial_days: int = 0,
- billing_cycle: str = "monthly") -> Subscription:
+ def create_subscription(
+ self,
+ tenant_id: str,
+ plan_id: str,
+ payment_provider: str | None = None,
+ trial_days: int = 0,
+ billing_cycle: str = "monthly",
+ ) -> Subscription:
"""创建新订阅"""
conn = self._get_connection()
try:
# 检查是否已有活跃订阅
cursor = conn.cursor()
- cursor.execute("""
+ cursor.execute(
+ """
SELECT * FROM subscriptions
WHERE tenant_id = ? AND status IN ('active', 'trial', 'pending')
- """, (tenant_id,))
+ """,
+ (tenant_id,),
+ )
existing = cursor.fetchone()
if existing:
@@ -741,37 +770,60 @@ class SubscriptionManager:
provider_subscription_id=None,
created_at=now,
updated_at=now,
- metadata={"billing_cycle": billing_cycle}
+ metadata={"billing_cycle": billing_cycle},
)
- cursor.execute("""
+ cursor.execute(
+ """
INSERT INTO subscriptions
(id, tenant_id, plan_id, status, current_period_start, current_period_end,
cancel_at_period_end, canceled_at, trial_start, trial_end,
payment_provider, provider_subscription_id, created_at, updated_at, metadata)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
- """, (
- subscription.id, subscription.tenant_id, subscription.plan_id,
- subscription.status, subscription.current_period_start,
- subscription.current_period_end, int(subscription.cancel_at_period_end),
- subscription.canceled_at, subscription.trial_start, subscription.trial_end,
- subscription.payment_provider, subscription.provider_subscription_id,
- subscription.created_at, subscription.updated_at,
- json.dumps(subscription.metadata)
- ))
+ """,
+ (
+ subscription.id,
+ subscription.tenant_id,
+ subscription.plan_id,
+ subscription.status,
+ subscription.current_period_start,
+ subscription.current_period_end,
+ int(subscription.cancel_at_period_end),
+ subscription.canceled_at,
+ subscription.trial_start,
+ subscription.trial_end,
+ subscription.payment_provider,
+ subscription.provider_subscription_id,
+ subscription.created_at,
+ subscription.updated_at,
+ json.dumps(subscription.metadata),
+ ),
+ )
# 创建发票
amount = plan.price_yearly if billing_cycle == "yearly" else plan.price_monthly
if amount > 0 and trial_days == 0:
self._create_invoice_internal(
- conn, tenant_id, subscription_id, amount, plan.currency,
- now, period_end, f"{plan.name} Subscription ({billing_cycle})"
+ conn,
+ tenant_id,
+ subscription_id,
+ amount,
+ plan.currency,
+ now,
+ period_end,
+ f"{plan.name} Subscription ({billing_cycle})",
)
# 记录账单历史
self._add_billing_history_internal(
- conn, tenant_id, "subscription", 0, plan.currency,
- f"Subscription created: {plan.name}", subscription_id, 0
+ conn,
+ tenant_id,
+ "subscription",
+ 0,
+ plan.currency,
+ f"Subscription created: {plan.name}",
+ subscription_id,
+ 0,
)
conn.commit()
@@ -805,11 +857,14 @@ class SubscriptionManager:
conn = self._get_connection()
try:
cursor = conn.cursor()
- cursor.execute("""
+ cursor.execute(
+ """
SELECT * FROM subscriptions
WHERE tenant_id = ? AND status IN ('active', 'trial', 'past_due', 'pending')
ORDER BY created_at DESC LIMIT 1
- """, (tenant_id,))
+ """,
+ (tenant_id,),
+ )
row = cursor.fetchone()
if row:
@@ -830,14 +885,21 @@ class SubscriptionManager:
updates = []
params = []
- allowed_fields = ['status', 'current_period_start', 'current_period_end',
- 'cancel_at_period_end', 'canceled_at', 'trial_end',
- 'payment_provider', 'provider_subscription_id']
+ allowed_fields = [
+ "status",
+ "current_period_start",
+ "current_period_end",
+ "cancel_at_period_end",
+ "canceled_at",
+ "trial_end",
+ "payment_provider",
+ "provider_subscription_id",
+ ]
for key, value in kwargs.items():
if key in allowed_fields:
updates.append(f"{key} = ?")
- if key == 'cancel_at_period_end':
+ if key == "cancel_at_period_end":
params.append(int(value))
else:
params.append(value)
@@ -850,10 +912,13 @@ class SubscriptionManager:
params.append(subscription_id)
cursor = conn.cursor()
- cursor.execute(f"""
+ cursor.execute(
+ f"""
UPDATE subscriptions SET {', '.join(updates)}
WHERE id = ?
- """, params)
+ """,
+ params,
+ )
conn.commit()
return self.get_subscription(subscription_id)
@@ -861,8 +926,7 @@ class SubscriptionManager:
finally:
conn.close()
- def cancel_subscription(self, subscription_id: str,
- at_period_end: bool = True) -> Subscription | None:
+ def cancel_subscription(self, subscription_id: str, at_period_end: bool = True) -> Subscription | None:
"""取消订阅"""
conn = self._get_connection()
try:
@@ -875,25 +939,36 @@ class SubscriptionManager:
if at_period_end:
# 在周期结束时取消
cursor = conn.cursor()
- cursor.execute("""
+ cursor.execute(
+ """
UPDATE subscriptions
SET cancel_at_period_end = 1, canceled_at = ?, updated_at = ?
WHERE id = ?
- """, (now, now, subscription_id))
+ """,
+ (now, now, subscription_id),
+ )
else:
# 立即取消
cursor = conn.cursor()
- cursor.execute("""
+ cursor.execute(
+ """
UPDATE subscriptions
SET status = 'cancelled', canceled_at = ?, updated_at = ?
WHERE id = ?
- """, (now, now, subscription_id))
+ """,
+ (now, now, subscription_id),
+ )
# 记录账单历史
self._add_billing_history_internal(
- conn, subscription.tenant_id, "subscription", 0, "CNY",
+ conn,
+ subscription.tenant_id,
+ "subscription",
+ 0,
+ "CNY",
f"Subscription cancelled{' (at period end)' if at_period_end else ''}",
- subscription_id, 0
+ subscription_id,
+ 0,
)
conn.commit()
@@ -903,8 +978,7 @@ class SubscriptionManager:
finally:
conn.close()
- def change_plan(self, subscription_id: str, new_plan_id: str,
- prorate: bool = True) -> Subscription | None:
+ def change_plan(self, subscription_id: str, new_plan_id: str, prorate: bool = True) -> Subscription | None:
"""更改订阅计划"""
conn = self._get_connection()
try:
@@ -926,17 +1000,25 @@ class SubscriptionManager:
pass
cursor = conn.cursor()
- cursor.execute("""
+ cursor.execute(
+ """
UPDATE subscriptions
SET plan_id = ?, updated_at = ?
WHERE id = ?
- """, (new_plan_id, now, subscription_id))
+ """,
+ (new_plan_id, now, subscription_id),
+ )
# 记录账单历史
self._add_billing_history_internal(
- conn, subscription.tenant_id, "subscription", 0, new_plan.currency,
+ conn,
+ subscription.tenant_id,
+ "subscription",
+ 0,
+ new_plan.currency,
f"Plan changed from {old_plan.name if old_plan else 'unknown'} to {new_plan.name}",
- subscription_id, 0
+ subscription_id,
+ 0,
)
conn.commit()
@@ -948,10 +1030,15 @@ class SubscriptionManager:
# ==================== 用量计费 ====================
- def record_usage(self, tenant_id: str, resource_type: str,
- quantity: float, unit: str,
- description: str | None = None,
- metadata: dict | None = None) -> UsageRecord:
+ def record_usage(
+ self,
+ tenant_id: str,
+ resource_type: str,
+ quantity: float,
+ unit: str,
+ description: str | None = None,
+ metadata: dict | None = None,
+ ) -> UsageRecord:
"""记录用量"""
conn = self._get_connection()
try:
@@ -968,19 +1055,28 @@ class SubscriptionManager:
recorded_at=datetime.now(),
cost=cost,
description=description,
- metadata=metadata or {}
+ metadata=metadata or {},
)
cursor = conn.cursor()
- cursor.execute("""
+ cursor.execute(
+ """
INSERT INTO usage_records
(id, tenant_id, resource_type, quantity, unit, recorded_at, cost, description, metadata)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
- """, (
- record.id, record.tenant_id, record.resource_type,
- record.quantity, record.unit, record.recorded_at,
- record.cost, record.description, json.dumps(record.metadata)
- ))
+ """,
+ (
+ record.id,
+ record.tenant_id,
+ record.resource_type,
+ record.quantity,
+ record.unit,
+ record.recorded_at,
+ record.cost,
+ record.description,
+ json.dumps(record.metadata),
+ ),
+ )
conn.commit()
return record
@@ -988,9 +1084,9 @@ class SubscriptionManager:
finally:
conn.close()
- def get_usage_summary(self, tenant_id: str,
- start_date: datetime | None = None,
- end_date: datetime | None = None) -> dict[str, Any]:
+ def get_usage_summary(
+ self, tenant_id: str, start_date: datetime | None = None, end_date: datetime | None = None
+ ) -> dict[str, Any]:
"""获取用量汇总"""
conn = self._get_connection()
try:
@@ -1023,21 +1119,21 @@ class SubscriptionManager:
total_cost = 0
for row in rows:
- summary[row['resource_type']] = {
- "quantity": row['total_quantity'],
- "cost": row['total_cost'],
- "records": row['record_count']
+ summary[row["resource_type"]] = {
+ "quantity": row["total_quantity"],
+ "cost": row["total_cost"],
+ "records": row["record_count"],
}
- total_cost += row['total_cost']
+ total_cost += row["total_cost"]
return {
"tenant_id": tenant_id,
"period": {
"start": start_date.isoformat() if start_date else None,
- "end": end_date.isoformat() if end_date else None
+ "end": end_date.isoformat() if end_date else None,
},
"breakdown": summary,
- "total_cost": total_cost
+ "total_cost": total_cost,
}
finally:
@@ -1060,11 +1156,17 @@ class SubscriptionManager:
# ==================== 支付管理 ====================
- def create_payment(self, tenant_id: str, amount: float, currency: str,
- provider: str, subscription_id: str | None = None,
- invoice_id: str | None = None,
- payment_method: str | None = None,
- payment_details: dict | None = None) -> Payment:
+ def create_payment(
+ self,
+ tenant_id: str,
+ amount: float,
+ currency: str,
+ provider: str,
+ subscription_id: str | None = None,
+ invoice_id: str | None = None,
+ payment_method: str | None = None,
+ payment_details: dict | None = None,
+ ) -> Payment:
"""创建支付记录"""
conn = self._get_connection()
try:
@@ -1087,24 +1189,37 @@ class SubscriptionManager:
failed_at=None,
failure_reason=None,
created_at=now,
- updated_at=now
+ updated_at=now,
)
cursor = conn.cursor()
- cursor.execute("""
+ cursor.execute(
+ """
INSERT INTO payments
(id, tenant_id, subscription_id, invoice_id, amount, currency,
provider, provider_payment_id, status, payment_method, payment_details,
paid_at, failed_at, failure_reason, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
- """, (
- payment.id, payment.tenant_id, payment.subscription_id,
- payment.invoice_id, payment.amount, payment.currency,
- payment.provider, payment.provider_payment_id, payment.status,
- payment.payment_method, json.dumps(payment.payment_details),
- payment.paid_at, payment.failed_at, payment.failure_reason,
- payment.created_at, payment.updated_at
- ))
+ """,
+ (
+ payment.id,
+ payment.tenant_id,
+ payment.subscription_id,
+ payment.invoice_id,
+ payment.amount,
+ payment.currency,
+ payment.provider,
+ payment.provider_payment_id,
+ payment.status,
+ payment.payment_method,
+ json.dumps(payment.payment_details),
+ payment.paid_at,
+ payment.failed_at,
+ payment.failure_reason,
+ payment.created_at,
+ payment.updated_at,
+ ),
+ )
conn.commit()
return payment
@@ -1112,8 +1227,7 @@ class SubscriptionManager:
finally:
conn.close()
- def confirm_payment(self, payment_id: str,
- provider_payment_id: str | None = None) -> Payment | None:
+ def confirm_payment(self, payment_id: str, provider_payment_id: str | None = None) -> Payment | None:
"""确认支付完成"""
conn = self._get_connection()
try:
@@ -1124,33 +1238,47 @@ class SubscriptionManager:
now = datetime.now()
cursor = conn.cursor()
- cursor.execute("""
+ cursor.execute(
+ """
UPDATE payments
SET status = 'completed', provider_payment_id = ?, paid_at = ?, updated_at = ?
WHERE id = ?
- """, (provider_payment_id, now, now, payment_id))
+ """,
+ (provider_payment_id, now, now, payment_id),
+ )
# 如果有关联发票,更新发票状态
if payment.invoice_id:
- cursor.execute("""
+ cursor.execute(
+ """
UPDATE invoices
SET status = 'paid', amount_paid = amount_due, paid_at = ?
WHERE id = ?
- """, (now, payment.invoice_id))
+ """,
+ (now, payment.invoice_id),
+ )
# 如果有关联订阅,激活订阅
if payment.subscription_id:
- cursor.execute("""
+ cursor.execute(
+ """
UPDATE subscriptions
SET status = 'active', updated_at = ?
WHERE id = ? AND status = 'pending'
- """, (now, payment.subscription_id))
+ """,
+ (now, payment.subscription_id),
+ )
# 记录账单历史
self._add_billing_history_internal(
- conn, payment.tenant_id, "payment", payment.amount,
- payment.currency, f"Payment completed via {payment.provider}",
- payment_id, 0 # 余额更新应该在账户管理中处理
+ conn,
+ payment.tenant_id,
+ "payment",
+ payment.amount,
+ payment.currency,
+ f"Payment completed via {payment.provider}",
+ payment_id,
+ 0, # 余额更新应该在账户管理中处理
)
conn.commit()
@@ -1167,11 +1295,14 @@ class SubscriptionManager:
now = datetime.now()
cursor = conn.cursor()
- cursor.execute("""
+ cursor.execute(
+ """
UPDATE payments
SET status = 'failed', failure_reason = ?, failed_at = ?, updated_at = ?
WHERE id = ?
- """, (reason, now, now, payment_id))
+ """,
+ (reason, now, now, payment_id),
+ )
conn.commit()
return self._get_payment_internal(conn, payment_id)
@@ -1187,8 +1318,9 @@ class SubscriptionManager:
finally:
conn.close()
- def list_payments(self, tenant_id: str, status: str | None = None,
- limit: int = 100, offset: int = 0) -> list[Payment]:
+ def list_payments(
+ self, tenant_id: str, status: str | None = None, limit: int = 100, offset: int = 0
+ ) -> list[Payment]:
"""列出支付记录"""
conn = self._get_connection()
try:
@@ -1224,11 +1356,18 @@ class SubscriptionManager:
# ==================== 发票管理 ====================
- def _create_invoice_internal(self, conn: sqlite3.Connection, tenant_id: str,
- subscription_id: str | None, amount: float,
- currency: str, period_start: datetime,
- period_end: datetime, description: str,
- line_items: list[dict] | None = None) -> Invoice:
+ def _create_invoice_internal(
+ self,
+ conn: sqlite3.Connection,
+ tenant_id: str,
+ subscription_id: str | None,
+ amount: float,
+ currency: str,
+ period_start: datetime,
+ period_end: datetime,
+ description: str,
+ line_items: list[dict] | None = None,
+ ) -> Invoice:
"""内部方法:创建发票"""
invoice_id = str(uuid.uuid4())
invoice_number = self._generate_invoice_number()
@@ -1253,25 +1392,39 @@ class SubscriptionManager:
voided_at=None,
void_reason=None,
created_at=now,
- updated_at=now
+ updated_at=now,
)
cursor = conn.cursor()
- cursor.execute("""
+ cursor.execute(
+ """
INSERT INTO invoices
(id, tenant_id, subscription_id, invoice_number, status, amount_due, amount_paid,
currency, period_start, period_end, description, line_items, due_date,
paid_at, voided_at, void_reason, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
- """, (
- invoice.id, invoice.tenant_id, invoice.subscription_id,
- invoice.invoice_number, invoice.status, invoice.amount_due,
- invoice.amount_paid, invoice.currency, invoice.period_start,
- invoice.period_end, invoice.description,
- json.dumps(invoice.line_items), invoice.due_date,
- invoice.paid_at, invoice.voided_at, invoice.void_reason,
- invoice.created_at, invoice.updated_at
- ))
+ """,
+ (
+ invoice.id,
+ invoice.tenant_id,
+ invoice.subscription_id,
+ invoice.invoice_number,
+ invoice.status,
+ invoice.amount_due,
+ invoice.amount_paid,
+ invoice.currency,
+ invoice.period_start,
+ invoice.period_end,
+ invoice.description,
+ json.dumps(invoice.line_items),
+ invoice.due_date,
+ invoice.paid_at,
+ invoice.voided_at,
+ invoice.void_reason,
+ invoice.created_at,
+ invoice.updated_at,
+ ),
+ )
return invoice
@@ -1305,8 +1458,9 @@ class SubscriptionManager:
finally:
conn.close()
- def list_invoices(self, tenant_id: str, status: str | None = None,
- limit: int = 100, offset: int = 0) -> list[Invoice]:
+ def list_invoices(
+ self, tenant_id: str, status: str | None = None, limit: int = 100, offset: int = 0
+ ) -> list[Invoice]:
"""列出发票"""
conn = self._get_connection()
try:
@@ -1344,11 +1498,14 @@ class SubscriptionManager:
now = datetime.now()
cursor = conn.cursor()
- cursor.execute("""
+ cursor.execute(
+ """
UPDATE invoices
SET status = 'void', voided_at = ?, void_reason = ?, updated_at = ?
WHERE id = ?
- """, (now, reason, now, invoice_id))
+ """,
+ (now, reason, now, invoice_id),
+ )
conn.commit()
return self.get_invoice(invoice_id)
@@ -1364,12 +1521,15 @@ class SubscriptionManager:
conn = self._get_connection()
try:
cursor = conn.cursor()
- cursor.execute("""
+ cursor.execute(
+ """
SELECT COUNT(*) as count FROM invoices
WHERE invoice_number LIKE ?
- """, (f"{prefix}%",))
+ """,
+ (f"{prefix}%",),
+ )
row = cursor.fetchone()
- count = row['count'] + 1
+ count = row["count"] + 1
return f"{prefix}-{count:06d}"
@@ -1378,8 +1538,7 @@ class SubscriptionManager:
# ==================== 退款管理 ====================
- def request_refund(self, tenant_id: str, payment_id: str, amount: float,
- reason: str, requested_by: str) -> Refund:
+ def request_refund(self, tenant_id: str, payment_id: str, amount: float, reason: str, requested_by: str) -> Refund:
"""申请退款"""
conn = self._get_connection()
try:
@@ -1417,23 +1576,38 @@ class SubscriptionManager:
provider_refund_id=None,
metadata={},
created_at=now,
- updated_at=now
+ updated_at=now,
)
cursor = conn.cursor()
- cursor.execute("""
+ cursor.execute(
+ """
INSERT INTO refunds
(id, tenant_id, payment_id, invoice_id, amount, currency, reason, status,
requested_by, requested_at, approved_by, approved_at, completed_at,
provider_refund_id, metadata, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
- """, (
- refund.id, refund.tenant_id, refund.payment_id, refund.invoice_id,
- refund.amount, refund.currency, refund.reason, refund.status,
- refund.requested_by, refund.requested_at, refund.approved_by,
- refund.approved_at, refund.completed_at, refund.provider_refund_id,
- json.dumps(refund.metadata), refund.created_at, refund.updated_at
- ))
+ """,
+ (
+ refund.id,
+ refund.tenant_id,
+ refund.payment_id,
+ refund.invoice_id,
+ refund.amount,
+ refund.currency,
+ refund.reason,
+ refund.status,
+ refund.requested_by,
+ refund.requested_at,
+ refund.approved_by,
+ refund.approved_at,
+ refund.completed_at,
+ refund.provider_refund_id,
+ json.dumps(refund.metadata),
+ refund.created_at,
+ refund.updated_at,
+ ),
+ )
conn.commit()
logger.info(f"Refund requested: {refund_id} for payment {payment_id}")
@@ -1456,11 +1630,14 @@ class SubscriptionManager:
now = datetime.now()
cursor = conn.cursor()
- cursor.execute("""
+ cursor.execute(
+ """
UPDATE refunds
SET status = 'approved', approved_by = ?, approved_at = ?, updated_at = ?
WHERE id = ?
- """, (approved_by, now, now, refund_id))
+ """,
+ (approved_by, now, now, refund_id),
+ )
conn.commit()
return self._get_refund_internal(conn, refund_id)
@@ -1468,8 +1645,7 @@ class SubscriptionManager:
finally:
conn.close()
- def complete_refund(self, refund_id: str,
- provider_refund_id: str | None = None) -> Refund | None:
+ def complete_refund(self, refund_id: str, provider_refund_id: str | None = None) -> Refund | None:
"""完成退款"""
conn = self._get_connection()
try:
@@ -1480,24 +1656,35 @@ class SubscriptionManager:
now = datetime.now()
cursor = conn.cursor()
- cursor.execute("""
+ cursor.execute(
+ """
UPDATE refunds
SET status = 'completed', provider_refund_id = ?, completed_at = ?, updated_at = ?
WHERE id = ?
- """, (provider_refund_id, now, now, refund_id))
+ """,
+ (provider_refund_id, now, now, refund_id),
+ )
# 更新原支付记录状态
- cursor.execute("""
+ cursor.execute(
+ """
UPDATE payments
SET status = 'refunded', updated_at = ?
WHERE id = ?
- """, (now, refund.payment_id))
+ """,
+ (now, refund.payment_id),
+ )
# 记录账单历史
self._add_billing_history_internal(
- conn, refund.tenant_id, "refund", -refund.amount,
- refund.currency, f"Refund processed: {refund.reason}",
- refund_id, 0
+ conn,
+ refund.tenant_id,
+ "refund",
+ -refund.amount,
+ refund.currency,
+ f"Refund processed: {refund.reason}",
+ refund_id,
+ 0,
)
conn.commit()
@@ -1518,11 +1705,14 @@ class SubscriptionManager:
now = datetime.now()
cursor = conn.cursor()
- cursor.execute("""
+ cursor.execute(
+ """
UPDATE refunds
SET status = 'rejected', metadata = json_set(metadata, '$.rejection_reason', ?), updated_at = ?
WHERE id = ?
- """, (reason, now, refund_id))
+ """,
+ (reason, now, refund_id),
+ )
conn.commit()
return self._get_refund_internal(conn, refund_id)
@@ -1538,8 +1728,9 @@ class SubscriptionManager:
finally:
conn.close()
- def list_refunds(self, tenant_id: str, status: str | None = None,
- limit: int = 100, offset: int = 0) -> list[Refund]:
+ def list_refunds(
+ self, tenant_id: str, status: str | None = None, limit: int = 100, offset: int = 0
+ ) -> list[Refund]:
"""列出退款记录"""
conn = self._get_connection()
try:
@@ -1575,27 +1766,49 @@ class SubscriptionManager:
# ==================== 账单历史 ====================
- def _add_billing_history_internal(self, conn: sqlite3.Connection,
- tenant_id: str, type: str, amount: float,
- currency: str, description: str,
- reference_id: str, balance_after: float):
+ def _add_billing_history_internal(
+ self,
+ conn: sqlite3.Connection,
+ tenant_id: str,
+ type: str,
+ amount: float,
+ currency: str,
+ description: str,
+ reference_id: str,
+ balance_after: float,
+ ):
"""内部方法:添加账单历史"""
history_id = str(uuid.uuid4())
cursor = conn.cursor()
- cursor.execute("""
+ cursor.execute(
+ """
INSERT INTO billing_history
(id, tenant_id, type, amount, currency, description, reference_id, balance_after, created_at, metadata)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
- """, (
- history_id, tenant_id, type, amount, currency,
- description, reference_id, balance_after, datetime.now(), json.dumps({})
- ))
+ """,
+ (
+ history_id,
+ tenant_id,
+ type,
+ amount,
+ currency,
+ description,
+ reference_id,
+ balance_after,
+ datetime.now(),
+ json.dumps({}),
+ ),
+ )
- def get_billing_history(self, tenant_id: str,
- start_date: datetime | None = None,
- end_date: datetime | None = None,
- limit: int = 100, offset: int = 0) -> list[BillingHistory]:
+ def get_billing_history(
+ self,
+ tenant_id: str,
+ start_date: datetime | None = None,
+ end_date: datetime | None = None,
+ limit: int = 100,
+ offset: int = 0,
+ ) -> list[BillingHistory]:
"""获取账单历史"""
conn = self._get_connection()
try:
@@ -1624,9 +1837,9 @@ class SubscriptionManager:
# ==================== 支付提供商集成 ====================
- def create_stripe_checkout_session(self, tenant_id: str, plan_id: str,
- success_url: str, cancel_url: str,
- billing_cycle: str = "monthly") -> dict[str, Any]:
+ def create_stripe_checkout_session(
+ self, tenant_id: str, plan_id: str, success_url: str, cancel_url: str, billing_cycle: str = "monthly"
+ ) -> dict[str, Any]:
"""创建 Stripe Checkout 会话(占位实现)"""
# 这里应该集成 Stripe SDK
# 简化实现,返回模拟数据
@@ -1634,11 +1847,10 @@ class SubscriptionManager:
"session_id": f"cs_{uuid.uuid4().hex[:24]}",
"url": f"https://checkout.stripe.com/mock/{uuid.uuid4().hex[:24]}",
"status": "created",
- "provider": "stripe"
+ "provider": "stripe",
}
- def create_alipay_order(self, tenant_id: str, plan_id: str,
- billing_cycle: str = "monthly") -> dict[str, Any]:
+ def create_alipay_order(self, tenant_id: str, plan_id: str, billing_cycle: str = "monthly") -> dict[str, Any]:
"""创建支付宝订单(占位实现)"""
# 这里应该集成支付宝 SDK
plan = self.get_plan(plan_id)
@@ -1650,11 +1862,10 @@ class SubscriptionManager:
"currency": plan.currency,
"qr_code_url": f"https://qr.alipay.com/mock/{uuid.uuid4().hex[:16]}",
"status": "pending",
- "provider": "alipay"
+ "provider": "alipay",
}
- def create_wechat_order(self, tenant_id: str, plan_id: str,
- billing_cycle: str = "monthly") -> dict[str, Any]:
+ def create_wechat_order(self, tenant_id: str, plan_id: str, billing_cycle: str = "monthly") -> dict[str, Any]:
"""创建微信支付订单(占位实现)"""
# 这里应该集成微信支付 SDK
plan = self.get_plan(plan_id)
@@ -1666,7 +1877,7 @@ class SubscriptionManager:
"currency": plan.currency,
"prepay_id": f"wx{uuid.uuid4().hex[:32]}",
"status": "pending",
- "provider": "wechat"
+ "provider": "wechat",
}
def handle_webhook(self, provider: str, payload: dict[str, Any]) -> bool:
@@ -1696,220 +1907,221 @@ class SubscriptionManager:
def _row_to_plan(self, row: sqlite3.Row) -> SubscriptionPlan:
"""数据库行转换为 SubscriptionPlan 对象"""
return SubscriptionPlan(
- id=row['id'],
- name=row['name'],
- tier=row['tier'],
- description=row['description'] or "",
- price_monthly=row['price_monthly'],
- price_yearly=row['price_yearly'],
- currency=row['currency'],
- features=json.loads(
- row['features'] or '[]'),
- limits=json.loads(
- row['limits'] or '{}'),
- is_active=bool(
- row['is_active']),
- created_at=datetime.fromisoformat(
- row['created_at']) if isinstance(
- row['created_at'],
- str) else row['created_at'],
- updated_at=datetime.fromisoformat(
- row['updated_at']) if isinstance(
- row['updated_at'],
- str) else row['updated_at'],
- metadata=json.loads(
- row['metadata'] or '{}'))
+ id=row["id"],
+ name=row["name"],
+ tier=row["tier"],
+ description=row["description"] or "",
+ price_monthly=row["price_monthly"],
+ price_yearly=row["price_yearly"],
+ currency=row["currency"],
+ features=json.loads(row["features"] or "[]"),
+ limits=json.loads(row["limits"] or "{}"),
+ is_active=bool(row["is_active"]),
+ created_at=(
+ datetime.fromisoformat(row["created_at"]) if isinstance(row["created_at"], str) else row["created_at"]
+ ),
+ updated_at=(
+ datetime.fromisoformat(row["updated_at"]) if isinstance(row["updated_at"], str) else row["updated_at"]
+ ),
+ metadata=json.loads(row["metadata"] or "{}"),
+ )
def _row_to_subscription(self, row: sqlite3.Row) -> Subscription:
"""数据库行转换为 Subscription 对象"""
return Subscription(
- id=row['id'],
- tenant_id=row['tenant_id'],
- plan_id=row['plan_id'],
- status=row['status'],
- current_period_start=datetime.fromisoformat(
- row['current_period_start']) if row['current_period_start'] and isinstance(
- row['current_period_start'],
- str) else row['current_period_start'],
- current_period_end=datetime.fromisoformat(
- row['current_period_end']) if row['current_period_end'] and isinstance(
- row['current_period_end'],
- str) else row['current_period_end'],
- cancel_at_period_end=bool(
- row['cancel_at_period_end']),
- canceled_at=datetime.fromisoformat(
- row['canceled_at']) if row['canceled_at'] and isinstance(
- row['canceled_at'],
- str) else row['canceled_at'],
- trial_start=datetime.fromisoformat(
- row['trial_start']) if row['trial_start'] and isinstance(
- row['trial_start'],
- str) else row['trial_start'],
- trial_end=datetime.fromisoformat(
- row['trial_end']) if row['trial_end'] and isinstance(
- row['trial_end'],
- str) else row['trial_end'],
- payment_provider=row['payment_provider'],
- provider_subscription_id=row['provider_subscription_id'],
- created_at=datetime.fromisoformat(
- row['created_at']) if isinstance(
- row['created_at'],
- str) else row['created_at'],
- updated_at=datetime.fromisoformat(
- row['updated_at']) if isinstance(
- row['updated_at'],
- str) else row['updated_at'],
- metadata=json.loads(
- row['metadata'] or '{}'))
+ id=row["id"],
+ tenant_id=row["tenant_id"],
+ plan_id=row["plan_id"],
+ status=row["status"],
+ current_period_start=(
+ datetime.fromisoformat(row["current_period_start"])
+ if row["current_period_start"] and isinstance(row["current_period_start"], str)
+ else row["current_period_start"]
+ ),
+ current_period_end=(
+ datetime.fromisoformat(row["current_period_end"])
+ if row["current_period_end"] and isinstance(row["current_period_end"], str)
+ else row["current_period_end"]
+ ),
+ cancel_at_period_end=bool(row["cancel_at_period_end"]),
+ canceled_at=(
+ datetime.fromisoformat(row["canceled_at"])
+ if row["canceled_at"] and isinstance(row["canceled_at"], str)
+ else row["canceled_at"]
+ ),
+ trial_start=(
+ datetime.fromisoformat(row["trial_start"])
+ if row["trial_start"] and isinstance(row["trial_start"], str)
+ else row["trial_start"]
+ ),
+ trial_end=(
+ datetime.fromisoformat(row["trial_end"])
+ if row["trial_end"] and isinstance(row["trial_end"], str)
+ else row["trial_end"]
+ ),
+ payment_provider=row["payment_provider"],
+ provider_subscription_id=row["provider_subscription_id"],
+ created_at=(
+ datetime.fromisoformat(row["created_at"]) if isinstance(row["created_at"], str) else row["created_at"]
+ ),
+ updated_at=(
+ datetime.fromisoformat(row["updated_at"]) if isinstance(row["updated_at"], str) else row["updated_at"]
+ ),
+ metadata=json.loads(row["metadata"] or "{}"),
+ )
def _row_to_usage(self, row: sqlite3.Row) -> UsageRecord:
"""数据库行转换为 UsageRecord 对象"""
return UsageRecord(
- id=row['id'],
- tenant_id=row['tenant_id'],
- resource_type=row['resource_type'],
- quantity=row['quantity'],
- unit=row['unit'],
- recorded_at=datetime.fromisoformat(
- row['recorded_at']) if isinstance(
- row['recorded_at'],
- str) else row['recorded_at'],
- cost=row['cost'],
- description=row['description'],
- metadata=json.loads(
- row['metadata'] or '{}'))
+ id=row["id"],
+ tenant_id=row["tenant_id"],
+ resource_type=row["resource_type"],
+ quantity=row["quantity"],
+ unit=row["unit"],
+ recorded_at=(
+ datetime.fromisoformat(row["recorded_at"])
+ if isinstance(row["recorded_at"], str)
+ else row["recorded_at"]
+ ),
+ cost=row["cost"],
+ description=row["description"],
+ metadata=json.loads(row["metadata"] or "{}"),
+ )
def _row_to_payment(self, row: sqlite3.Row) -> Payment:
"""数据库行转换为 Payment 对象"""
return Payment(
- id=row['id'],
- tenant_id=row['tenant_id'],
- subscription_id=row['subscription_id'],
- invoice_id=row['invoice_id'],
- amount=row['amount'],
- currency=row['currency'],
- provider=row['provider'],
- provider_payment_id=row['provider_payment_id'],
- status=row['status'],
- payment_method=row['payment_method'],
- payment_details=json.loads(
- row['payment_details'] or '{}'),
- paid_at=datetime.fromisoformat(
- row['paid_at']) if row['paid_at'] and isinstance(
- row['paid_at'],
- str) else row['paid_at'],
- failed_at=datetime.fromisoformat(
- row['failed_at']) if row['failed_at'] and isinstance(
- row['failed_at'],
- str) else row['failed_at'],
- failure_reason=row['failure_reason'],
- created_at=datetime.fromisoformat(
- row['created_at']) if isinstance(
- row['created_at'],
- str) else row['created_at'],
- updated_at=datetime.fromisoformat(
- row['updated_at']) if isinstance(
- row['updated_at'],
- str) else row['updated_at'])
+ id=row["id"],
+ tenant_id=row["tenant_id"],
+ subscription_id=row["subscription_id"],
+ invoice_id=row["invoice_id"],
+ amount=row["amount"],
+ currency=row["currency"],
+ provider=row["provider"],
+ provider_payment_id=row["provider_payment_id"],
+ status=row["status"],
+ payment_method=row["payment_method"],
+ payment_details=json.loads(row["payment_details"] or "{}"),
+ paid_at=(
+ datetime.fromisoformat(row["paid_at"])
+ if row["paid_at"] and isinstance(row["paid_at"], str)
+ else row["paid_at"]
+ ),
+ failed_at=(
+ datetime.fromisoformat(row["failed_at"])
+ if row["failed_at"] and isinstance(row["failed_at"], str)
+ else row["failed_at"]
+ ),
+ failure_reason=row["failure_reason"],
+ created_at=(
+ datetime.fromisoformat(row["created_at"]) if isinstance(row["created_at"], str) else row["created_at"]
+ ),
+ updated_at=(
+ datetime.fromisoformat(row["updated_at"]) if isinstance(row["updated_at"], str) else row["updated_at"]
+ ),
+ )
def _row_to_invoice(self, row: sqlite3.Row) -> Invoice:
"""数据库行转换为 Invoice 对象"""
return Invoice(
- id=row['id'],
- tenant_id=row['tenant_id'],
- subscription_id=row['subscription_id'],
- invoice_number=row['invoice_number'],
- status=row['status'],
- amount_due=row['amount_due'],
- amount_paid=row['amount_paid'],
- currency=row['currency'],
- period_start=datetime.fromisoformat(
- row['period_start']) if row['period_start'] and isinstance(
- row['period_start'],
- str) else row['period_start'],
- period_end=datetime.fromisoformat(
- row['period_end']) if row['period_end'] and isinstance(
- row['period_end'],
- str) else row['period_end'],
- description=row['description'],
- line_items=json.loads(
- row['line_items'] or '[]'),
- due_date=datetime.fromisoformat(
- row['due_date']) if row['due_date'] and isinstance(
- row['due_date'],
- str) else row['due_date'],
- paid_at=datetime.fromisoformat(
- row['paid_at']) if row['paid_at'] and isinstance(
- row['paid_at'],
- str) else row['paid_at'],
- voided_at=datetime.fromisoformat(
- row['voided_at']) if row['voided_at'] and isinstance(
- row['voided_at'],
- str) else row['voided_at'],
- void_reason=row['void_reason'],
- created_at=datetime.fromisoformat(
- row['created_at']) if isinstance(
- row['created_at'],
- str) else row['created_at'],
- updated_at=datetime.fromisoformat(
- row['updated_at']) if isinstance(
- row['updated_at'],
- str) else row['updated_at'])
+ id=row["id"],
+ tenant_id=row["tenant_id"],
+ subscription_id=row["subscription_id"],
+ invoice_number=row["invoice_number"],
+ status=row["status"],
+ amount_due=row["amount_due"],
+ amount_paid=row["amount_paid"],
+ currency=row["currency"],
+ period_start=(
+ datetime.fromisoformat(row["period_start"])
+ if row["period_start"] and isinstance(row["period_start"], str)
+ else row["period_start"]
+ ),
+ period_end=(
+ datetime.fromisoformat(row["period_end"])
+ if row["period_end"] and isinstance(row["period_end"], str)
+ else row["period_end"]
+ ),
+ description=row["description"],
+ line_items=json.loads(row["line_items"] or "[]"),
+ due_date=(
+ datetime.fromisoformat(row["due_date"])
+ if row["due_date"] and isinstance(row["due_date"], str)
+ else row["due_date"]
+ ),
+ paid_at=(
+ datetime.fromisoformat(row["paid_at"])
+ if row["paid_at"] and isinstance(row["paid_at"], str)
+ else row["paid_at"]
+ ),
+ voided_at=(
+ datetime.fromisoformat(row["voided_at"])
+ if row["voided_at"] and isinstance(row["voided_at"], str)
+ else row["voided_at"]
+ ),
+ void_reason=row["void_reason"],
+ created_at=(
+ datetime.fromisoformat(row["created_at"]) if isinstance(row["created_at"], str) else row["created_at"]
+ ),
+ updated_at=(
+ datetime.fromisoformat(row["updated_at"]) if isinstance(row["updated_at"], str) else row["updated_at"]
+ ),
+ )
def _row_to_refund(self, row: sqlite3.Row) -> Refund:
"""数据库行转换为 Refund 对象"""
return Refund(
- id=row['id'],
- tenant_id=row['tenant_id'],
- payment_id=row['payment_id'],
- invoice_id=row['invoice_id'],
- amount=row['amount'],
- currency=row['currency'],
- reason=row['reason'],
- status=row['status'],
- requested_by=row['requested_by'],
- requested_at=datetime.fromisoformat(
- row['requested_at']) if isinstance(
- row['requested_at'],
- str) else row['requested_at'],
- approved_by=row['approved_by'],
- approved_at=datetime.fromisoformat(
- row['approved_at']) if row['approved_at'] and isinstance(
- row['approved_at'],
- str) else row['approved_at'],
- completed_at=datetime.fromisoformat(
- row['completed_at']) if row['completed_at'] and isinstance(
- row['completed_at'],
- str) else row['completed_at'],
- provider_refund_id=row['provider_refund_id'],
- metadata=json.loads(
- row['metadata'] or '{}'),
- created_at=datetime.fromisoformat(
- row['created_at']) if isinstance(
- row['created_at'],
- str) else row['created_at'],
- updated_at=datetime.fromisoformat(
- row['updated_at']) if isinstance(
- row['updated_at'],
- str) else row['updated_at'])
+ id=row["id"],
+ tenant_id=row["tenant_id"],
+ payment_id=row["payment_id"],
+ invoice_id=row["invoice_id"],
+ amount=row["amount"],
+ currency=row["currency"],
+ reason=row["reason"],
+ status=row["status"],
+ requested_by=row["requested_by"],
+ requested_at=(
+ datetime.fromisoformat(row["requested_at"])
+ if isinstance(row["requested_at"], str)
+ else row["requested_at"]
+ ),
+ approved_by=row["approved_by"],
+ approved_at=(
+ datetime.fromisoformat(row["approved_at"])
+ if row["approved_at"] and isinstance(row["approved_at"], str)
+ else row["approved_at"]
+ ),
+ completed_at=(
+ datetime.fromisoformat(row["completed_at"])
+ if row["completed_at"] and isinstance(row["completed_at"], str)
+ else row["completed_at"]
+ ),
+ provider_refund_id=row["provider_refund_id"],
+ metadata=json.loads(row["metadata"] or "{}"),
+ created_at=(
+ datetime.fromisoformat(row["created_at"]) if isinstance(row["created_at"], str) else row["created_at"]
+ ),
+ updated_at=(
+ datetime.fromisoformat(row["updated_at"]) if isinstance(row["updated_at"], str) else row["updated_at"]
+ ),
+ )
def _row_to_billing_history(self, row: sqlite3.Row) -> BillingHistory:
"""数据库行转换为 BillingHistory 对象"""
return BillingHistory(
- id=row['id'],
- tenant_id=row['tenant_id'],
- type=row['type'],
- amount=row['amount'],
- currency=row['currency'],
- description=row['description'],
- reference_id=row['reference_id'],
- balance_after=row['balance_after'],
- created_at=datetime.fromisoformat(
- row['created_at']) if isinstance(
- row['created_at'],
- str) else row['created_at'],
- metadata=json.loads(
- row['metadata'] or '{}'))
+ id=row["id"],
+ tenant_id=row["tenant_id"],
+ type=row["type"],
+ amount=row["amount"],
+ currency=row["currency"],
+ description=row["description"],
+ reference_id=row["reference_id"],
+ balance_after=row["balance_after"],
+ created_at=(
+ datetime.fromisoformat(row["created_at"]) if isinstance(row["created_at"], str) else row["created_at"]
+ ),
+ metadata=json.loads(row["metadata"] or "{}"),
+ )
# 全局订阅管理器实例
diff --git a/backend/tenant_manager.py b/backend/tenant_manager.py
index 5b71a25..5a38250 100644
--- a/backend/tenant_manager.py
+++ b/backend/tenant_manager.py
@@ -1,4 +1,3 @@
-
"""
InsightFlow Phase 8 - 多租户 SaaS 架构管理模块
@@ -27,6 +26,7 @@ logger = logging.getLogger(__name__)
class TenantLimits:
"""租户资源限制常量"""
+
FREE_MAX_PROJECTS = 3
FREE_MAX_STORAGE_MB = 100
FREE_MAX_TRANSCRIPTION_MINUTES = 60
@@ -46,83 +46,90 @@ class TenantLimits:
class TenantStatus(StrEnum):
"""租户状态"""
- ACTIVE = "active" # 活跃
- SUSPENDED = "suspended" # 暂停
- TRIAL = "trial" # 试用
- EXPIRED = "expired" # 过期
- PENDING = "pending" # 待激活
+
+ ACTIVE = "active" # 活跃
+ SUSPENDED = "suspended" # 暂停
+ TRIAL = "trial" # 试用
+ EXPIRED = "expired" # 过期
+ PENDING = "pending" # 待激活
class TenantTier(StrEnum):
"""租户订阅层级"""
- FREE = "free" # 免费版
- PRO = "pro" # 专业版
- ENTERPRISE = "enterprise" # 企业版
+
+ FREE = "free" # 免费版
+ PRO = "pro" # 专业版
+ ENTERPRISE = "enterprise" # 企业版
class TenantRole(StrEnum):
"""租户角色"""
- OWNER = "owner" # 所有者
- ADMIN = "admin" # 管理员
- MEMBER = "member" # 成员
- VIEWER = "viewer" # 查看者
+
+ OWNER = "owner" # 所有者
+ ADMIN = "admin" # 管理员
+ MEMBER = "member" # 成员
+ VIEWER = "viewer" # 查看者
class DomainStatus(StrEnum):
"""域名状态"""
- PENDING = "pending" # 待验证
- VERIFIED = "verified" # 已验证
- FAILED = "failed" # 验证失败
- EXPIRED = "expired" # 已过期
+
+ PENDING = "pending" # 待验证
+ VERIFIED = "verified" # 已验证
+ FAILED = "failed" # 验证失败
+ EXPIRED = "expired" # 已过期
@dataclass
class Tenant:
"""租户数据类"""
+
id: str
name: str
- slug: str # URL 友好的唯一标识
+ slug: str # URL 友好的唯一标识
description: str | None
- tier: str # free/pro/enterprise
- status: str # active/suspended/trial/expired/pending
- owner_id: str # 所有者用户ID
+ tier: str # free/pro/enterprise
+ status: str # active/suspended/trial/expired/pending
+ owner_id: str # 所有者用户ID
created_at: datetime
updated_at: datetime
expires_at: datetime | None # 订阅过期时间
- settings: dict[str, Any] # 租户级设置
+ settings: dict[str, Any] # 租户级设置
resource_limits: dict[str, Any] # 资源限制
- metadata: dict[str, Any] # 元数据
+ metadata: dict[str, Any] # 元数据
@dataclass
class TenantDomain:
"""租户域名数据类"""
+
id: str
tenant_id: str
- domain: str # 自定义域名
- status: str # pending/verified/failed/expired
- verification_token: str # 验证令牌
- verification_method: str # dns/file
+ domain: str # 自定义域名
+ status: str # pending/verified/failed/expired
+ verification_token: str # 验证令牌
+ verification_method: str # dns/file
verified_at: datetime | None
created_at: datetime
updated_at: datetime
- is_primary: bool # 是否主域名
- ssl_enabled: bool # SSL 是否启用
+ is_primary: bool # 是否主域名
+ ssl_enabled: bool # SSL 是否启用
ssl_expires_at: datetime | None
@dataclass
class TenantBranding:
"""租户品牌配置数据类"""
+
id: str
tenant_id: str
- logo_url: str | None # Logo URL
- favicon_url: str | None # Favicon URL
- primary_color: str | None # 主题主色
+ logo_url: str | None # Logo URL
+ favicon_url: str | None # Favicon URL
+ primary_color: str | None # 主题主色
secondary_color: str | None # 主题次色
- custom_css: str | None # 自定义 CSS
- custom_js: str | None # 自定义 JS
- login_page_bg: str | None # 登录页背景
+ custom_css: str | None # 自定义 CSS
+ custom_js: str | None # 自定义 JS
+ login_page_bg: str | None # 登录页背景
email_template: str | None # 邮件模板
created_at: datetime
updated_at: datetime
@@ -131,30 +138,32 @@ class TenantBranding:
@dataclass
class TenantMember:
"""租户成员数据类"""
+
id: str
tenant_id: str
user_id: str
email: str
- role: str # owner/admin/member/viewer
- permissions: list[str] # 具体权限列表
- invited_by: str | None # 邀请者
+ role: str # owner/admin/member/viewer
+ permissions: list[str] # 具体权限列表
+ invited_by: str | None # 邀请者
invited_at: datetime
joined_at: datetime | None
last_active_at: datetime | None
- status: str # active/pending/suspended
+ status: str # active/pending/suspended
@dataclass
class TenantPermission:
"""租户权限定义数据类"""
+
id: str
tenant_id: str
- name: str # 权限名称
- code: str # 权限代码
+ name: str # 权限名称
+ code: str # 权限代码
description: str | None
- resource_type: str # project/entity/api/etc
- actions: list[str] # create/read/update/delete/etc
- conditions: dict | None # 条件限制
+ resource_type: str # project/entity/api/etc
+ actions: list[str] # create/read/update/delete/etc
+ conditions: dict | None # 条件限制
created_at: datetime
@@ -170,7 +179,7 @@ class TenantManager:
"max_api_calls_per_day": TenantLimits.FREE_MAX_API_CALLS_PER_DAY,
"max_team_members": TenantLimits.FREE_MAX_TEAM_MEMBERS,
"max_entities": TenantLimits.FREE_MAX_ENTITIES,
- "features": ["basic_analysis", "export_png"]
+ "features": ["basic_analysis", "export_png"],
},
TenantTier.PRO: {
"max_projects": TenantLimits.PRO_MAX_PROJECTS,
@@ -179,8 +188,14 @@ class TenantManager:
"max_api_calls_per_day": TenantLimits.PRO_MAX_API_CALLS_PER_DAY,
"max_team_members": TenantLimits.PRO_MAX_TEAM_MEMBERS,
"max_entities": TenantLimits.PRO_MAX_ENTITIES,
- "features": ["basic_analysis", "advanced_analysis", "export_all",
- "api_access", "webhooks", "collaboration"]
+ "features": [
+ "basic_analysis",
+ "advanced_analysis",
+ "export_all",
+ "api_access",
+ "webhooks",
+ "collaboration",
+ ],
},
TenantTier.ENTERPRISE: {
"max_projects": TenantLimits.UNLIMITED, # 无限制
@@ -189,27 +204,23 @@ class TenantManager:
"max_api_calls_per_day": TenantLimits.UNLIMITED,
"max_team_members": TenantLimits.UNLIMITED,
"max_entities": TenantLimits.UNLIMITED,
- "features": ["all"] # 所有功能
- }
+ "features": ["all"], # 所有功能
+ },
}
# 角色权限映射
ROLE_PERMISSIONS = {
- TenantRole.OWNER: [
- "tenant:*", "project:*", "member:*", "billing:*",
- "settings:*", "api:*", "export:*"
- ],
- TenantRole.ADMIN: [
- "tenant:read", "project:*", "member:*", "billing:read",
- "settings:*", "api:*", "export:*"
- ],
+ TenantRole.OWNER: ["tenant:*", "project:*", "member:*", "billing:*", "settings:*", "api:*", "export:*"],
+ TenantRole.ADMIN: ["tenant:read", "project:*", "member:*", "billing:read", "settings:*", "api:*", "export:*"],
TenantRole.MEMBER: [
- "tenant:read", "project:create", "project:read", "project:update",
- "member:read", "export:basic"
+ "tenant:read",
+ "project:create",
+ "project:read",
+ "project:update",
+ "member:read",
+ "export:basic",
],
- TenantRole.VIEWER: [
- "tenant:read", "project:read", "member:read"
- ]
+ TenantRole.VIEWER: ["tenant:read", "project:read", "member:read"],
}
# 权限名称映射
@@ -227,7 +238,7 @@ class TenantManager:
"settings:*": "设置完全控制",
"api:*": "API完全控制",
"export:*": "导出完全控制",
- "export:basic": "基础导出"
+ "export:basic": "基础导出",
}
def __init__(self, db_path: str = "insightflow.db"):
@@ -240,7 +251,7 @@ class TenantManager:
conn.row_factory = sqlite3.Row
return conn
- def _init_db(self):
+ def _init_db(self) -> None:
"""初始化数据库表"""
conn = self._get_connection()
try:
@@ -379,10 +390,9 @@ class TenantManager:
# ==================== 租户管理 ====================
- def create_tenant(self, name: str, owner_id: str,
- tier: str = "free",
- description: str | None = None,
- settings: dict | None = None) -> Tenant:
+ def create_tenant(
+ self, name: str, owner_id: str, tier: str = "free", description: str | None = None, settings: dict | None = None
+ ) -> Tenant:
"""创建新租户"""
conn = self._get_connection()
try:
@@ -406,21 +416,32 @@ class TenantManager:
expires_at=None,
settings=settings or {},
resource_limits=resource_limits,
- metadata={}
+ metadata={},
)
cursor = conn.cursor()
- cursor.execute("""
+ cursor.execute(
+ """
INSERT INTO tenants (id, name, slug, description, tier, status, owner_id,
created_at, updated_at, expires_at, settings, resource_limits, metadata)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
- """, (
- tenant.id, tenant.name, tenant.slug, tenant.description,
- tenant.tier, tenant.status, tenant.owner_id,
- tenant.created_at, tenant.updated_at, tenant.expires_at,
- json.dumps(tenant.settings), json.dumps(tenant.resource_limits),
- json.dumps(tenant.metadata)
- ))
+ """,
+ (
+ tenant.id,
+ tenant.name,
+ tenant.slug,
+ tenant.description,
+ tenant.tier,
+ tenant.status,
+ tenant.owner_id,
+ tenant.created_at,
+ tenant.updated_at,
+ tenant.expires_at,
+ json.dumps(tenant.settings),
+ json.dumps(tenant.resource_limits),
+ json.dumps(tenant.metadata),
+ ),
+ )
# 自动将所有者添加为成员
self._add_member_internal(conn, tenant_id, owner_id, "", TenantRole.OWNER, None)
@@ -471,11 +492,14 @@ class TenantManager:
conn = self._get_connection()
try:
cursor = conn.cursor()
- cursor.execute("""
+ cursor.execute(
+ """
SELECT t.* FROM tenants t
JOIN tenant_domains d ON t.id = d.tenant_id
WHERE d.domain = ? AND d.status = 'verified'
- """, (domain,))
+ """,
+ (domain,),
+ )
row = cursor.fetchone()
if row:
@@ -485,12 +509,15 @@ class TenantManager:
finally:
conn.close()
- def update_tenant(self, tenant_id: str,
- name: str | None = None,
- description: str | None = None,
- tier: str | None = None,
- status: str | None = None,
- settings: dict | None = None) -> Tenant | None:
+ def update_tenant(
+ self,
+ tenant_id: str,
+ name: str | None = None,
+ description: str | None = None,
+ tier: str | None = None,
+ status: str | None = None,
+ settings: dict | None = None,
+ ) -> Tenant | None:
"""更新租户信息"""
conn = self._get_connection()
try:
@@ -526,10 +553,13 @@ class TenantManager:
params.append(tenant_id)
cursor = conn.cursor()
- cursor.execute(f"""
+ cursor.execute(
+ f"""
UPDATE tenants SET {', '.join(updates)}
WHERE id = ?
- """, params)
+ """,
+ params,
+ )
conn.commit()
return self.get_tenant(tenant_id)
@@ -548,9 +578,9 @@ class TenantManager:
finally:
conn.close()
- def list_tenants(self, status: str | None = None,
- tier: str | None = None,
- limit: int = 100, offset: int = 0) -> list[Tenant]:
+ def list_tenants(
+ self, status: str | None = None, tier: str | None = None, limit: int = 100, offset: int = 0
+ ) -> list[Tenant]:
"""列出租户"""
conn = self._get_connection()
try:
@@ -579,9 +609,9 @@ class TenantManager:
# ==================== 域名管理 ====================
- def add_domain(self, tenant_id: str, domain: str,
- is_primary: bool = False,
- verification_method: str = "dns") -> TenantDomain:
+ def add_domain(
+ self, tenant_id: str, domain: str, is_primary: bool = False, verification_method: str = "dns"
+ ) -> TenantDomain:
"""为租户添加自定义域名"""
conn = self._get_connection()
try:
@@ -605,31 +635,43 @@ class TenantManager:
updated_at=datetime.now(),
is_primary=is_primary,
ssl_enabled=False,
- ssl_expires_at=None
+ ssl_expires_at=None,
)
cursor = conn.cursor()
# 如果设为主域名,取消其他主域名
if is_primary:
- cursor.execute("""
+ cursor.execute(
+ """
UPDATE tenant_domains SET is_primary = 0
WHERE tenant_id = ?
- """, (tenant_id,))
+ """,
+ (tenant_id,),
+ )
- cursor.execute("""
+ cursor.execute(
+ """
INSERT INTO tenant_domains (id, tenant_id, domain, status,
verification_token, verification_method, verified_at,
created_at, updated_at, is_primary, ssl_enabled, ssl_expires_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
- """, (
- tenant_domain.id, tenant_domain.tenant_id, tenant_domain.domain,
- tenant_domain.status, tenant_domain.verification_token,
- tenant_domain.verification_method, tenant_domain.verified_at,
- tenant_domain.created_at, tenant_domain.updated_at,
- int(tenant_domain.is_primary), int(tenant_domain.ssl_enabled),
- tenant_domain.ssl_expires_at
- ))
+ """,
+ (
+ tenant_domain.id,
+ tenant_domain.tenant_id,
+ tenant_domain.domain,
+ tenant_domain.status,
+ tenant_domain.verification_token,
+ tenant_domain.verification_method,
+ tenant_domain.verified_at,
+ tenant_domain.created_at,
+ tenant_domain.updated_at,
+ int(tenant_domain.is_primary),
+ int(tenant_domain.ssl_enabled),
+ tenant_domain.ssl_expires_at,
+ ),
+ )
conn.commit()
logger.info(f"Domain added: {domain} for tenant {tenant_id}")
@@ -649,36 +691,45 @@ class TenantManager:
cursor = conn.cursor()
# 获取域名信息
- cursor.execute("""
+ cursor.execute(
+ """
SELECT * FROM tenant_domains
WHERE id = ? AND tenant_id = ?
- """, (domain_id, tenant_id))
+ """,
+ (domain_id, tenant_id),
+ )
row = cursor.fetchone()
if not row:
return False
- domain = row['domain']
- token = row['verification_token']
- method = row['verification_method']
+ domain = row["domain"]
+ token = row["verification_token"]
+ method = row["verification_method"]
# 执行验证
is_verified = self._check_domain_verification(domain, token, method)
if is_verified:
- cursor.execute("""
+ cursor.execute(
+ """
UPDATE tenant_domains
SET status = 'verified', verified_at = ?, updated_at = ?
WHERE id = ?
- """, (datetime.now(), datetime.now(), domain_id))
+ """,
+ (datetime.now(), datetime.now(), domain_id),
+ )
conn.commit()
logger.info(f"Domain verified: {domain}")
else:
- cursor.execute("""
+ cursor.execute(
+ """
UPDATE tenant_domains
SET status = 'failed', updated_at = ?
WHERE id = ?
- """, (datetime.now(), domain_id))
+ """,
+ (datetime.now(), domain_id),
+ )
conn.commit()
return is_verified
@@ -700,26 +751,23 @@ class TenantManager:
if not row:
return None
- domain = row['domain']
- token = row['verification_token']
+ domain = row["domain"]
+ token = row["verification_token"]
return {
"domain": domain,
- "verification_method": row['verification_method'],
+ "verification_method": row["verification_method"],
"dns_record": {
"type": "TXT",
"name": "_insightflow",
"value": f"insightflow-verify={token}",
- "ttl": 3600
- },
- "file_verification": {
- "url": f"http://{domain}/.well-known/insightflow-verify.txt",
- "content": token
+ "ttl": 3600,
},
+ "file_verification": {"url": f"http://{domain}/.well-known/insightflow-verify.txt", "content": token},
"instructions": [
f"DNS 验证: 添加 TXT 记录 _insightflow.{domain},值为 insightflow-verify={token}",
- f"文件验证: 在网站根目录创建 .well-known/insightflow-verify.txt,内容为 {token}"
- ]
+ f"文件验证: 在网站根目录创建 .well-known/insightflow-verify.txt,内容为 {token}",
+ ],
}
finally:
@@ -730,10 +778,13 @@ class TenantManager:
conn = self._get_connection()
try:
cursor = conn.cursor()
- cursor.execute("""
+ cursor.execute(
+ """
DELETE FROM tenant_domains
WHERE id = ? AND tenant_id = ?
- """, (domain_id, tenant_id))
+ """,
+ (domain_id, tenant_id),
+ )
conn.commit()
return cursor.rowcount > 0
finally:
@@ -744,11 +795,14 @@ class TenantManager:
conn = self._get_connection()
try:
cursor = conn.cursor()
- cursor.execute("""
+ cursor.execute(
+ """
SELECT * FROM tenant_domains
WHERE tenant_id = ?
ORDER BY is_primary DESC, created_at DESC
- """, (tenant_id,))
+ """,
+ (tenant_id,),
+ )
rows = cursor.fetchall()
return [self._row_to_domain(row) for row in rows]
@@ -773,15 +827,18 @@ class TenantManager:
finally:
conn.close()
- def update_branding(self, tenant_id: str,
- logo_url: str | None = None,
- favicon_url: str | None = None,
- primary_color: str | None = None,
- secondary_color: str | None = None,
- custom_css: str | None = None,
- custom_js: str | None = None,
- login_page_bg: str | None = None,
- email_template: str | None = None) -> TenantBranding:
+ def update_branding(
+ self,
+ tenant_id: str,
+ logo_url: str | None = None,
+ favicon_url: str | None = None,
+ primary_color: str | None = None,
+ secondary_color: str | None = None,
+ custom_css: str | None = None,
+ custom_js: str | None = None,
+ login_page_bg: str | None = None,
+ email_template: str | None = None,
+ ) -> TenantBranding:
"""更新租户品牌配置"""
conn = self._get_connection()
try:
@@ -825,23 +882,38 @@ class TenantManager:
params.append(datetime.now())
params.append(tenant_id)
- cursor.execute(f"""
+ cursor.execute(
+ f"""
UPDATE tenant_branding SET {', '.join(updates)}
WHERE tenant_id = ?
- """, params)
+ """,
+ params,
+ )
else:
# 创建
branding_id = str(uuid.uuid4())
- cursor.execute("""
+ cursor.execute(
+ """
INSERT INTO tenant_branding
(id, tenant_id, logo_url, favicon_url, primary_color, secondary_color,
custom_css, custom_js, login_page_bg, email_template, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
- """, (
- branding_id, tenant_id, logo_url, favicon_url, primary_color,
- secondary_color, custom_css, custom_js, login_page_bg, email_template,
- datetime.now(), datetime.now()
- ))
+ """,
+ (
+ branding_id,
+ tenant_id,
+ logo_url,
+ favicon_url,
+ primary_color,
+ secondary_color,
+ custom_css,
+ custom_js,
+ login_page_bg,
+ email_template,
+ datetime.now(),
+ datetime.now(),
+ ),
+ )
conn.commit()
return self.get_branding(tenant_id)
@@ -889,8 +961,9 @@ class TenantManager:
# ==================== 成员与权限管理 ====================
- def invite_member(self, tenant_id: str, email: str, role: str,
- invited_by: str, permissions: list[str] | None = None) -> TenantMember:
+ def invite_member(
+ self, tenant_id: str, email: str, role: str, invited_by: str, permissions: list[str] | None = None
+ ) -> TenantMember:
"""邀请成员加入租户"""
conn = self._get_connection()
try:
@@ -912,21 +985,31 @@ class TenantManager:
invited_at=datetime.now(),
joined_at=None,
last_active_at=None,
- status="pending"
+ status="pending",
)
cursor = conn.cursor()
- cursor.execute("""
+ cursor.execute(
+ """
INSERT INTO tenant_members
(id, tenant_id, user_id, email, role, permissions, invited_by,
invited_at, joined_at, last_active_at, status)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
- """, (
- member.id, member.tenant_id, member.user_id, member.email,
- member.role, json.dumps(member.permissions), member.invited_by,
- member.invited_at, member.joined_at, member.last_active_at,
- member.status
- ))
+ """,
+ (
+ member.id,
+ member.tenant_id,
+ member.user_id,
+ member.email,
+ member.role,
+ json.dumps(member.permissions),
+ member.invited_by,
+ member.invited_at,
+ member.joined_at,
+ member.last_active_at,
+ member.status,
+ ),
+ )
conn.commit()
logger.info(f"Member invited: {email} to tenant {tenant_id}")
@@ -940,11 +1023,14 @@ class TenantManager:
conn = self._get_connection()
try:
cursor = conn.cursor()
- cursor.execute("""
+ cursor.execute(
+ """
UPDATE tenant_members
SET user_id = ?, status = 'active', joined_at = ?
WHERE id = ? AND status = 'pending'
- """, (user_id, datetime.now(), invitation_id))
+ """,
+ (user_id, datetime.now(), invitation_id),
+ )
conn.commit()
return cursor.rowcount > 0
@@ -957,17 +1043,21 @@ class TenantManager:
conn = self._get_connection()
try:
cursor = conn.cursor()
- cursor.execute("""
+ cursor.execute(
+ """
DELETE FROM tenant_members
WHERE id = ? AND tenant_id = ?
- """, (member_id, tenant_id))
+ """,
+ (member_id, tenant_id),
+ )
conn.commit()
return cursor.rowcount > 0
finally:
conn.close()
- def update_member_role(self, tenant_id: str, member_id: str,
- role: str, permissions: list[str] | None = None) -> bool:
+ def update_member_role(
+ self, tenant_id: str, member_id: str, role: str, permissions: list[str] | None = None
+ ) -> bool:
"""更新成员角色"""
conn = self._get_connection()
try:
@@ -976,11 +1066,14 @@ class TenantManager:
final_permissions = permissions or default_permissions
cursor = conn.cursor()
- cursor.execute("""
+ cursor.execute(
+ """
UPDATE tenant_members
SET role = ?, permissions = ?, updated_at = ?
WHERE id = ? AND tenant_id = ?
- """, (role, json.dumps(final_permissions), datetime.now(), member_id, tenant_id))
+ """,
+ (role, json.dumps(final_permissions), datetime.now(), member_id, tenant_id),
+ )
conn.commit()
return cursor.rowcount > 0
@@ -1011,23 +1104,25 @@ class TenantManager:
finally:
conn.close()
- def check_permission(self, tenant_id: str, user_id: str,
- resource: str, action: str) -> bool:
+ def check_permission(self, tenant_id: str, user_id: str, resource: str, action: str) -> bool:
"""检查用户是否有特定权限"""
conn = self._get_connection()
try:
cursor = conn.cursor()
- cursor.execute("""
+ cursor.execute(
+ """
SELECT role, permissions FROM tenant_members
WHERE tenant_id = ? AND user_id = ? AND status = 'active'
- """, (tenant_id, user_id))
+ """,
+ (tenant_id, user_id),
+ )
row = cursor.fetchone()
if not row:
return False
- role = row['role']
- permissions = json.loads(row['permissions'] or '[]')
+ role = row["role"]
+ permissions = json.loads(row["permissions"] or "[]")
# 所有者拥有所有权限
if role == TenantRole.OWNER.value:
@@ -1047,23 +1142,22 @@ class TenantManager:
conn = self._get_connection()
try:
cursor = conn.cursor()
- cursor.execute("""
+ cursor.execute(
+ """
SELECT t.*, m.role, m.status as member_status
FROM tenants t
JOIN tenant_members m ON t.id = m.tenant_id
WHERE m.user_id = ? AND m.status = 'active'
ORDER BY t.created_at DESC
- """, (user_id,))
+ """,
+ (user_id,),
+ )
rows = cursor.fetchall()
result = []
for row in rows:
tenant = self._row_to_tenant(row)
- result.append({
- **asdict(tenant),
- "member_role": row['role'],
- "member_status": row['member_status']
- })
+ result.append({**asdict(tenant), "member_role": row["role"], "member_status": row["member_status"]})
return result
finally:
@@ -1071,13 +1165,16 @@ class TenantManager:
# ==================== 资源使用统计 ====================
- def record_usage(self, tenant_id: str,
- storage_bytes: int = 0,
- transcription_seconds: int = 0,
- api_calls: int = 0,
- projects_count: int = 0,
- entities_count: int = 0,
- members_count: int = 0):
+ def record_usage(
+ self,
+ tenant_id: str,
+ storage_bytes: int = 0,
+ transcription_seconds: int = 0,
+ api_calls: int = 0,
+ projects_count: int = 0,
+ entities_count: int = 0,
+ members_count: int = 0,
+ ):
"""记录资源使用"""
conn = self._get_connection()
try:
@@ -1085,7 +1182,8 @@ class TenantManager:
usage_id = str(uuid.uuid4())
cursor = conn.cursor()
- cursor.execute("""
+ cursor.execute(
+ """
INSERT INTO tenant_usage
(id, tenant_id, date, storage_bytes, transcription_seconds, api_calls,
projects_count, entities_count, members_count)
@@ -1097,19 +1195,28 @@ class TenantManager:
projects_count = MAX(projects_count, excluded.projects_count),
entities_count = MAX(entities_count, excluded.entities_count),
members_count = MAX(members_count, excluded.members_count)
- """, (
- usage_id, tenant_id, today, storage_bytes, transcription_seconds,
- api_calls, projects_count, entities_count, members_count
- ))
+ """,
+ (
+ usage_id,
+ tenant_id,
+ today,
+ storage_bytes,
+ transcription_seconds,
+ api_calls,
+ projects_count,
+ entities_count,
+ members_count,
+ ),
+ )
conn.commit()
finally:
conn.close()
- def get_usage_stats(self, tenant_id: str,
- start_date: datetime | None = None,
- end_date: datetime | None = None) -> dict[str, Any]:
+ def get_usage_stats(
+ self, tenant_id: str, start_date: datetime | None = None, end_date: datetime | None = None
+ ) -> dict[str, Any]:
"""获取使用统计"""
conn = self._get_connection()
try:
@@ -1143,23 +1250,29 @@ class TenantManager:
limits = tenant.resource_limits if tenant else {}
return {
- "storage_bytes": row['total_storage'] or 0,
- "storage_mb": (row['total_storage'] or 0) / (1024 * 1024),
- "transcription_seconds": row['total_transcription'] or 0,
- "transcription_minutes": (row['total_transcription'] or 0) / 60,
- "api_calls": row['total_api_calls'] or 0,
- "projects_count": row['max_projects'] or 0,
- "entities_count": row['max_entities'] or 0,
- "members_count": row['max_members'] or 0,
+ "storage_bytes": row["total_storage"] or 0,
+ "storage_mb": (row["total_storage"] or 0) / (1024 * 1024),
+ "transcription_seconds": row["total_transcription"] or 0,
+ "transcription_minutes": (row["total_transcription"] or 0) / 60,
+ "api_calls": row["total_api_calls"] or 0,
+ "projects_count": row["max_projects"] or 0,
+ "entities_count": row["max_entities"] or 0,
+ "members_count": row["max_members"] or 0,
"limits": limits,
"usage_percentages": {
- "storage": self._calc_percentage(row['total_storage'] or 0, limits.get('max_storage_mb', 0) * 1024 * 1024),
- "transcription": self._calc_percentage(row['total_transcription'] or 0, limits.get('max_transcription_minutes', 0) * 60),
- "api_calls": self._calc_percentage(row['total_api_calls'] or 0, limits.get('max_api_calls_per_day', 0)),
- "projects": self._calc_percentage(row['max_projects'] or 0, limits.get('max_projects', 0)),
- "entities": self._calc_percentage(row['max_entities'] or 0, limits.get('max_entities', 0)),
- "members": self._calc_percentage(row['max_members'] or 0, limits.get('max_team_members', 0))
- }
+ "storage": self._calc_percentage(
+ row["total_storage"] or 0, limits.get("max_storage_mb", 0) * 1024 * 1024
+ ),
+ "transcription": self._calc_percentage(
+ row["total_transcription"] or 0, limits.get("max_transcription_minutes", 0) * 60
+ ),
+ "api_calls": self._calc_percentage(
+ row["total_api_calls"] or 0, limits.get("max_api_calls_per_day", 0)
+ ),
+ "projects": self._calc_percentage(row["max_projects"] or 0, limits.get("max_projects", 0)),
+ "entities": self._calc_percentage(row["max_entities"] or 0, limits.get("max_entities", 0)),
+ "members": self._calc_percentage(row["max_members"] or 0, limits.get("max_team_members", 0)),
+ },
}
finally:
@@ -1179,12 +1292,12 @@ class TenantManager:
stats = self.get_usage_stats(tenant_id)
resource_map = {
- "storage": ("storage_mb", stats['storage_mb']),
- "transcription": ("max_transcription_minutes", stats['transcription_minutes']),
- "api_calls": ("max_api_calls_per_day", stats['api_calls']),
- "projects": ("max_projects", stats['projects_count']),
- "entities": ("max_entities", stats['entities_count']),
- "members": ("max_team_members", stats['members_count'])
+ "storage": ("storage_mb", stats["storage_mb"]),
+ "transcription": ("max_transcription_minutes", stats["transcription_minutes"]),
+ "api_calls": ("max_api_calls_per_day", stats["api_calls"]),
+ "projects": ("max_projects", stats["projects_count"]),
+ "entities": ("max_entities", stats["entities_count"]),
+ "members": ("max_team_members", stats["members_count"]),
}
if resource_type not in resource_map:
@@ -1204,8 +1317,8 @@ class TenantManager:
def _generate_slug(self, name: str) -> str:
"""生成 URL 友好的 slug"""
# 转换为小写,替换空格为连字符
- slug = re.sub(r'[^\w\s-]', '', name.lower())
- slug = re.sub(r'[-\s]+', '-', slug)
+ slug = re.sub(r"[^\w\s-]", "", name.lower())
+ slug = re.sub(r"[-\s]+", "-", slug)
# 检查是否已存在
conn = self._get_connection()
@@ -1233,7 +1346,7 @@ class TenantManager:
def _validate_domain(self, domain: str) -> bool:
"""验证域名格式"""
- pattern = r'^(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)*[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])$'
+ pattern = r"^(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)*[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])$"
return bool(re.match(pattern, domain))
def _check_domain_verification(self, domain: str, token: str, method: str) -> bool:
@@ -1269,7 +1382,7 @@ class TenantManager:
def _darken_color(self, hex_color: str, percent: int) -> str:
"""加深颜色"""
- hex_color = hex_color.lstrip('#')
+ hex_color = hex_color.lstrip("#")
r = int(hex_color[0:2], 16)
g = int(hex_color[2:4], 16)
b = int(hex_color[4:6], 16)
@@ -1286,133 +1399,147 @@ class TenantManager:
return 0.0 if limit == 0 else 100.0
return min(100.0, round(current / limit * 100, 2))
- def _add_member_internal(self, conn: sqlite3.Connection, tenant_id: str,
- user_id: str, email: str, role: TenantRole,
- invited_by: str | None):
+ def _add_member_internal(
+ self,
+ conn: sqlite3.Connection,
+ tenant_id: str,
+ user_id: str,
+ email: str,
+ role: TenantRole,
+ invited_by: str | None,
+ ):
"""内部方法:添加成员"""
cursor = conn.cursor()
member_id = str(uuid.uuid4())
- cursor.execute("""
+ cursor.execute(
+ """
INSERT OR IGNORE INTO tenant_members
(id, tenant_id, user_id, email, role, permissions, invited_by,
invited_at, joined_at, last_active_at, status)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
- """, (
- member_id, tenant_id, user_id, email, role.value,
- json.dumps(self.ROLE_PERMISSIONS.get(role, [])),
- invited_by, datetime.now(), datetime.now(), datetime.now(), "active"
- ))
+ """,
+ (
+ member_id,
+ tenant_id,
+ user_id,
+ email,
+ role.value,
+ json.dumps(self.ROLE_PERMISSIONS.get(role, [])),
+ invited_by,
+ datetime.now(),
+ datetime.now(),
+ datetime.now(),
+ "active",
+ ),
+ )
def _row_to_tenant(self, row: sqlite3.Row) -> Tenant:
"""数据库行转换为 Tenant 对象"""
return Tenant(
- id=row['id'],
- name=row['name'],
- slug=row['slug'],
- description=row['description'],
- tier=row['tier'],
- status=row['status'],
- owner_id=row['owner_id'],
- created_at=datetime.fromisoformat(
- row['created_at']) if isinstance(
- row['created_at'],
- str) else row['created_at'],
- updated_at=datetime.fromisoformat(
- row['updated_at']) if isinstance(
- row['updated_at'],
- str) else row['updated_at'],
- expires_at=datetime.fromisoformat(
- row['expires_at']) if row['expires_at'] and isinstance(
- row['expires_at'],
- str) else row['expires_at'],
- settings=json.loads(
- row['settings'] or '{}'),
- resource_limits=json.loads(
- row['resource_limits'] or '{}'),
- metadata=json.loads(
- row['metadata'] or '{}'))
+ id=row["id"],
+ name=row["name"],
+ slug=row["slug"],
+ description=row["description"],
+ tier=row["tier"],
+ status=row["status"],
+ owner_id=row["owner_id"],
+ created_at=(
+ datetime.fromisoformat(row["created_at"]) if isinstance(row["created_at"], str) else row["created_at"]
+ ),
+ updated_at=(
+ datetime.fromisoformat(row["updated_at"]) if isinstance(row["updated_at"], str) else row["updated_at"]
+ ),
+ expires_at=(
+ datetime.fromisoformat(row["expires_at"])
+ if row["expires_at"] and isinstance(row["expires_at"], str)
+ else row["expires_at"]
+ ),
+ settings=json.loads(row["settings"] or "{}"),
+ resource_limits=json.loads(row["resource_limits"] or "{}"),
+ metadata=json.loads(row["metadata"] or "{}"),
+ )
def _row_to_domain(self, row: sqlite3.Row) -> TenantDomain:
"""数据库行转换为 TenantDomain 对象"""
return TenantDomain(
- id=row['id'],
- tenant_id=row['tenant_id'],
- domain=row['domain'],
- status=row['status'],
- verification_token=row['verification_token'],
- verification_method=row['verification_method'],
- verified_at=datetime.fromisoformat(
- row['verified_at']) if row['verified_at'] and isinstance(
- row['verified_at'],
- str) else row['verified_at'],
- created_at=datetime.fromisoformat(
- row['created_at']) if isinstance(
- row['created_at'],
- str) else row['created_at'],
- updated_at=datetime.fromisoformat(
- row['updated_at']) if isinstance(
- row['updated_at'],
- str) else row['updated_at'],
- is_primary=bool(
- row['is_primary']),
- ssl_enabled=bool(
- row['ssl_enabled']),
- ssl_expires_at=datetime.fromisoformat(
- row['ssl_expires_at']) if row['ssl_expires_at'] and isinstance(
- row['ssl_expires_at'],
- str) else row['ssl_expires_at'])
+ id=row["id"],
+ tenant_id=row["tenant_id"],
+ domain=row["domain"],
+ status=row["status"],
+ verification_token=row["verification_token"],
+ verification_method=row["verification_method"],
+ verified_at=(
+ datetime.fromisoformat(row["verified_at"])
+ if row["verified_at"] and isinstance(row["verified_at"], str)
+ else row["verified_at"]
+ ),
+ created_at=(
+ datetime.fromisoformat(row["created_at"]) if isinstance(row["created_at"], str) else row["created_at"]
+ ),
+ updated_at=(
+ datetime.fromisoformat(row["updated_at"]) if isinstance(row["updated_at"], str) else row["updated_at"]
+ ),
+ is_primary=bool(row["is_primary"]),
+ ssl_enabled=bool(row["ssl_enabled"]),
+ ssl_expires_at=(
+ datetime.fromisoformat(row["ssl_expires_at"])
+ if row["ssl_expires_at"] and isinstance(row["ssl_expires_at"], str)
+ else row["ssl_expires_at"]
+ ),
+ )
def _row_to_branding(self, row: sqlite3.Row) -> TenantBranding:
"""数据库行转换为 TenantBranding 对象"""
return TenantBranding(
- id=row['id'],
- tenant_id=row['tenant_id'],
- logo_url=row['logo_url'],
- favicon_url=row['favicon_url'],
- primary_color=row['primary_color'],
- secondary_color=row['secondary_color'],
- custom_css=row['custom_css'],
- custom_js=row['custom_js'],
- login_page_bg=row['login_page_bg'],
- email_template=row['email_template'],
- created_at=datetime.fromisoformat(
- row['created_at']) if isinstance(
- row['created_at'],
- str) else row['created_at'],
- updated_at=datetime.fromisoformat(
- row['updated_at']) if isinstance(
- row['updated_at'],
- str) else row['updated_at'])
+ id=row["id"],
+ tenant_id=row["tenant_id"],
+ logo_url=row["logo_url"],
+ favicon_url=row["favicon_url"],
+ primary_color=row["primary_color"],
+ secondary_color=row["secondary_color"],
+ custom_css=row["custom_css"],
+ custom_js=row["custom_js"],
+ login_page_bg=row["login_page_bg"],
+ email_template=row["email_template"],
+ created_at=(
+ datetime.fromisoformat(row["created_at"]) if isinstance(row["created_at"], str) else row["created_at"]
+ ),
+ updated_at=(
+ datetime.fromisoformat(row["updated_at"]) if isinstance(row["updated_at"], str) else row["updated_at"]
+ ),
+ )
def _row_to_member(self, row: sqlite3.Row) -> TenantMember:
"""数据库行转换为 TenantMember 对象"""
return TenantMember(
- id=row['id'],
- tenant_id=row['tenant_id'],
- user_id=row['user_id'],
- email=row['email'],
- role=row['role'],
- permissions=json.loads(
- row['permissions'] or '[]'),
- invited_by=row['invited_by'],
- invited_at=datetime.fromisoformat(
- row['invited_at']) if isinstance(
- row['invited_at'],
- str) else row['invited_at'],
- joined_at=datetime.fromisoformat(
- row['joined_at']) if row['joined_at'] and isinstance(
- row['joined_at'],
- str) else row['joined_at'],
- last_active_at=datetime.fromisoformat(
- row['last_active_at']) if row['last_active_at'] and isinstance(
- row['last_active_at'],
- str) else row['last_active_at'],
- status=row['status'])
+ id=row["id"],
+ tenant_id=row["tenant_id"],
+ user_id=row["user_id"],
+ email=row["email"],
+ role=row["role"],
+ permissions=json.loads(row["permissions"] or "[]"),
+ invited_by=row["invited_by"],
+ invited_at=(
+ datetime.fromisoformat(row["invited_at"]) if isinstance(row["invited_at"], str) else row["invited_at"]
+ ),
+ joined_at=(
+ datetime.fromisoformat(row["joined_at"])
+ if row["joined_at"] and isinstance(row["joined_at"], str)
+ else row["joined_at"]
+ ),
+ last_active_at=(
+ datetime.fromisoformat(row["last_active_at"])
+ if row["last_active_at"] and isinstance(row["last_active_at"], str)
+ else row["last_active_at"]
+ ),
+ status=row["status"],
+ )
# ==================== 租户上下文管理 ====================
+
class TenantContext:
"""租户上下文管理器 - 用于请求级别的租户隔离"""
@@ -1420,7 +1547,7 @@ class TenantContext:
_current_user_id: str | None = None
@classmethod
- def set_current_tenant(cls, tenant_id: str):
+ def set_current_tenant(cls, tenant_id: str) -> None:
"""设置当前租户上下文"""
cls._current_tenant_id = tenant_id
@@ -1430,7 +1557,7 @@ class TenantContext:
return cls._current_tenant_id
@classmethod
- def set_current_user(cls, user_id: str):
+ def set_current_user(cls, user_id: str) -> None:
"""设置当前用户"""
cls._current_user_id = user_id
@@ -1440,7 +1567,7 @@ class TenantContext:
return cls._current_user_id
@classmethod
- def clear(cls):
+ def clear(cls) -> None:
"""清除上下文"""
cls._current_tenant_id = None
cls._current_user_id = None
diff --git a/backend/tingwu_client.py b/backend/tingwu_client.py
index fab65c7..9ff3862 100644
--- a/backend/tingwu_client.py
+++ b/backend/tingwu_client.py
@@ -20,7 +20,7 @@ class TingwuClient:
def _sign_request(self, method: str, uri: str, query: str = "", body: str = "") -> dict[str, str]:
"""阿里云签名 V3"""
- timestamp = datetime.utcnow().strftime('%Y-%m-%dT%H:%M:%SZ')
+ timestamp = datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%SZ")
# 简化签名,实际生产需要完整实现
# 这里使用基础认证头
@@ -39,25 +39,16 @@ class TingwuClient:
from alibabacloud_tingwu20230930 import models as tingwu_models
from alibabacloud_tingwu20230930.client import Client as TingwuSDKClient
- config = open_api_models.Config(
- access_key_id=self.access_key,
- access_key_secret=self.secret_key
- )
+ config = open_api_models.Config(access_key_id=self.access_key, access_key_secret=self.secret_key)
config.endpoint = "tingwu.cn-beijing.aliyuncs.com"
client = TingwuSDKClient(config)
request = tingwu_models.CreateTaskRequest(
type="offline",
- input=tingwu_models.Input(
- source="OSS",
- file_url=audio_url
- ),
+ input=tingwu_models.Input(source="OSS", file_url=audio_url),
parameters=tingwu_models.Parameters(
- transcription=tingwu_models.Transcription(
- diarization_enabled=True,
- sentence_max_length=20
- )
- )
+ transcription=tingwu_models.Transcription(diarization_enabled=True, sentence_max_length=20)
+ ),
)
response = client.create_task(request)
@@ -81,10 +72,7 @@ class TingwuClient:
from alibabacloud_tingwu20230930 import models as tingwu_models
from alibabacloud_tingwu20230930.client import Client as TingwuSDKClient
- config = open_api_models.Config(
- access_key_id=self.access_key,
- access_key_secret=self.secret_key
- )
+ config = open_api_models.Config(access_key_id=self.access_key, access_key_secret=self.secret_key)
config.endpoint = "tingwu.cn-beijing.aliyuncs.com"
client = TingwuSDKClient(config)
@@ -128,25 +116,29 @@ class TingwuClient:
if transcription.sentences:
for sent in transcription.sentences:
- segments.append({
- "start": sent.begin_time / 1000,
- "end": sent.end_time / 1000,
- "text": sent.text,
- "speaker": f"Speaker {sent.speaker_id}"
- })
+ segments.append(
+ {
+ "start": sent.begin_time / 1000,
+ "end": sent.end_time / 1000,
+ "text": sent.text,
+ "speaker": f"Speaker {sent.speaker_id}",
+ }
+ )
- return {
- "full_text": full_text.strip(),
- "segments": segments
- }
+ return {"full_text": full_text.strip(), "segments": segments}
def _mock_result(self) -> dict[str, Any]:
"""Mock 结果"""
return {
"full_text": "这是一个示例转录文本,包含 Project Alpha 和 K8s 等术语。",
"segments": [
- {"start": 0.0, "end": 5.0, "text": "这是一个示例转录文本,包含 Project Alpha 和 K8s 等术语。", "speaker": "Speaker A"}
- ]
+ {
+ "start": 0.0,
+ "end": 5.0,
+ "text": "这是一个示例转录文本,包含 Project Alpha 和 K8s 等术语。",
+ "speaker": "Speaker A",
+ }
+ ],
}
def transcribe(self, audio_url: str, language: str = "zh") -> dict[str, Any]:
diff --git a/backend/workflow_manager.py b/backend/workflow_manager.py
index 2404140..cb87aa4 100644
--- a/backend/workflow_manager.py
+++ b/backend/workflow_manager.py
@@ -10,8 +10,12 @@ InsightFlow Workflow Manager - Phase 7
"""
import asyncio
+import base64
+import hashlib
+import hmac
import json
import logging
+import urllib.parse
import uuid
from collections.abc import Callable
from dataclasses import dataclass, field
@@ -32,6 +36,7 @@ logger = logging.getLogger(__name__)
class WorkflowStatus(Enum):
"""工作流状态"""
+
ACTIVE = "active"
PAUSED = "paused"
ERROR = "error"
@@ -40,15 +45,17 @@ class WorkflowStatus(Enum):
class WorkflowType(Enum):
"""工作流类型"""
+
AUTO_ANALYZE = "auto_analyze" # 自动分析新文件
- AUTO_ALIGN = "auto_align" # 自动实体对齐
+ AUTO_ALIGN = "auto_align" # 自动实体对齐
AUTO_RELATION = "auto_relation" # 自动关系发现
SCHEDULED_REPORT = "scheduled_report" # 定时报告
- CUSTOM = "custom" # 自定义工作流
+ CUSTOM = "custom" # 自定义工作流
class WebhookType(Enum):
"""Webhook 类型"""
+
FEISHU = "feishu"
DINGTALK = "dingtalk"
SLACK = "slack"
@@ -57,6 +64,7 @@ class WebhookType(Enum):
class TaskStatus(Enum):
"""任务执行状态"""
+
PENDING = "pending"
RUNNING = "running"
SUCCESS = "success"
@@ -67,6 +75,7 @@ class TaskStatus(Enum):
@dataclass
class WorkflowTask:
"""工作流任务定义"""
+
id: str
workflow_id: str
name: str
@@ -90,6 +99,7 @@ class WorkflowTask:
@dataclass
class WebhookConfig:
"""Webhook 配置"""
+
id: str
name: str
webhook_type: str # feishu, dingtalk, slack, custom
@@ -114,6 +124,7 @@ class WebhookConfig:
@dataclass
class Workflow:
"""工作流定义"""
+
id: str
name: str
description: str
@@ -143,6 +154,7 @@ class Workflow:
@dataclass
class WorkflowLog:
"""工作流执行日志"""
+
id: str
workflow_id: str
task_id: str | None = None
@@ -186,20 +198,13 @@ class WebhookNotifier:
async def _send_feishu(self, config: WebhookConfig, message: dict) -> bool:
"""发送飞书通知"""
- import base64
- import hashlib
- import hmac
-
timestamp = str(int(datetime.now().timestamp()))
# 签名计算
if config.secret:
string_to_sign = f"{timestamp}\n{config.secret}"
- hmac_code = hmac.new(
- string_to_sign.encode('utf-8'),
- digestmod=hashlib.sha256
- ).digest()
- sign = base64.b64encode(hmac_code).decode('utf-8')
+ hmac_code = hmac.new(string_to_sign.encode("utf-8"), digestmod=hashlib.sha256).digest()
+ sign = base64.b64encode(hmac_code).decode("utf-8")
else:
sign = ""
@@ -210,9 +215,7 @@ class WebhookNotifier:
"timestamp": timestamp,
"sign": sign,
"msg_type": "text",
- "content": {
- "text": message["content"]
- }
+ "content": {"text": message["content"]},
}
elif "title" in message:
# 富文本消息
@@ -220,34 +223,15 @@ class WebhookNotifier:
"timestamp": timestamp,
"sign": sign,
"msg_type": "post",
- "content": {
- "post": {
- "zh_cn": {
- "title": message.get("title", ""),
- "content": message.get("body", [])
- }
- }
- }
+ "content": {"post": {"zh_cn": {"title": message.get("title", ""), "content": message.get("body", [])}}},
}
else:
# 卡片消息
- payload = {
- "timestamp": timestamp,
- "sign": sign,
- "msg_type": "interactive",
- "card": message.get("card", {})
- }
+ payload = {"timestamp": timestamp, "sign": sign, "msg_type": "interactive", "card": message.get("card", {})}
- headers = {
- "Content-Type": "application/json",
- **config.headers
- }
+ headers = {"Content-Type": "application/json", **config.headers}
- response = await self.http_client.post(
- config.url,
- json=payload,
- headers=headers
- )
+ response = await self.http_client.post(config.url, json=payload, headers=headers)
response.raise_for_status()
result = response.json()
@@ -255,18 +239,13 @@ class WebhookNotifier:
async def _send_dingtalk(self, config: WebhookConfig, message: dict) -> bool:
"""发送钉钉通知"""
- import base64
- import hashlib
- import hmac
- import urllib.parse
-
timestamp = str(round(datetime.now().timestamp() * 1000))
# 签名计算
if config.secret:
- secret_enc = config.secret.encode('utf-8')
+ secret_enc = config.secret.encode("utf-8")
string_to_sign = f"{timestamp}\n{config.secret}"
- hmac_code = hmac.new(secret_enc, string_to_sign.encode('utf-8'), digestmod=hashlib.sha256).digest()
+ hmac_code = hmac.new(secret_enc, string_to_sign.encode("utf-8"), digestmod=hashlib.sha256).digest()
sign = urllib.parse.quote_plus(base64.b64encode(hmac_code))
url = f"{config.url}×tamp={timestamp}&sign={sign}"
else:
@@ -274,19 +253,11 @@ class WebhookNotifier:
# 构建消息体
if "content" in message:
- payload = {
- "msgtype": "text",
- "text": {
- "content": message["content"]
- }
- }
+ payload = {"msgtype": "text", "text": {"content": message["content"]}}
elif "title" in message:
payload = {
"msgtype": "markdown",
- "markdown": {
- "title": message["title"],
- "text": message.get("markdown", "")
- }
+ "markdown": {"title": message["title"], "text": message.get("markdown", "")},
}
elif "link" in message:
payload = {
@@ -295,19 +266,13 @@ class WebhookNotifier:
"text": message.get("text", ""),
"title": message["title"],
"picUrl": message.get("pic_url", ""),
- "messageUrl": message["link"]
- }
+ "messageUrl": message["link"],
+ },
}
else:
- payload = {
- "msgtype": "action_card",
- "action_card": message.get("action_card", {})
- }
+ payload = {"msgtype": "action_card", "action_card": message.get("action_card", {})}
- headers = {
- "Content-Type": "application/json",
- **config.headers
- }
+ headers = {"Content-Type": "application/json", **config.headers}
response = await self.http_client.post(url, json=payload, headers=headers)
response.raise_for_status()
@@ -328,32 +293,18 @@ class WebhookNotifier:
if "attachments" in message:
payload["attachments"] = message["attachments"]
- headers = {
- "Content-Type": "application/json",
- **config.headers
- }
+ headers = {"Content-Type": "application/json", **config.headers}
- response = await self.http_client.post(
- config.url,
- json=payload,
- headers=headers
- )
+ response = await self.http_client.post(config.url, json=payload, headers=headers)
response.raise_for_status()
return response.text == "ok"
async def _send_custom(self, config: WebhookConfig, message: dict) -> bool:
"""发送自定义 Webhook 通知"""
- headers = {
- "Content-Type": "application/json",
- **config.headers
- }
+ headers = {"Content-Type": "application/json", **config.headers}
- response = await self.http_client.post(
- config.url,
- json=message,
- headers=headers
- )
+ response = await self.http_client.post(config.url, json=message, headers=headers)
response.raise_for_status()
return True
@@ -380,12 +331,9 @@ class WorkflowManager:
self._setup_default_handlers()
# 添加调度器事件监听
- self.scheduler.add_listener(
- self._on_job_executed,
- EVENT_JOB_EXECUTED | EVENT_JOB_ERROR
- )
+ self.scheduler.add_listener(self._on_job_executed, EVENT_JOB_EXECUTED | EVENT_JOB_ERROR)
- def _setup_default_handlers(self):
+ def _setup_default_handlers(self) -> None:
"""设置默认的任务处理器"""
self._task_handlers = {
"analyze": self._handle_analyze_task,
@@ -395,11 +343,11 @@ class WorkflowManager:
"custom": self._handle_custom_task,
}
- def register_task_handler(self, task_type: str, handler: Callable):
+ def register_task_handler(self, task_type: str, handler: Callable) -> None:
"""注册自定义任务处理器"""
self._task_handlers[task_type] = handler
- def start(self):
+ def start(self) -> None:
"""启动工作流管理器"""
if not self.scheduler.running:
self.scheduler.start()
@@ -409,7 +357,7 @@ class WorkflowManager:
if self.db:
asyncio.create_task(self._load_and_schedule_workflows())
- def stop(self):
+ def stop(self) -> None:
"""停止工作流管理器"""
if self.scheduler.running:
self.scheduler.shutdown(wait=True)
@@ -425,7 +373,7 @@ class WorkflowManager:
except (TimeoutError, httpx.HTTPError) as e:
logger.error(f"Failed to load workflows: {e}")
- def _schedule_workflow(self, workflow: Workflow):
+ def _schedule_workflow(self, workflow: Workflow) -> None:
"""调度工作流"""
job_id = f"workflow_{workflow.id}"
@@ -450,7 +398,7 @@ class WorkflowManager:
args=[workflow.id],
replace_existing=True,
max_instances=1,
- coalesce=True
+ coalesce=True,
)
logger.info(f"Scheduled workflow {workflow.id} ({workflow.name}) with {workflow.schedule_type}")
@@ -462,7 +410,7 @@ class WorkflowManager:
except (TimeoutError, httpx.HTTPError) as e:
logger.error(f"Scheduled workflow execution failed: {e}")
- def _on_job_executed(self, event):
+ def _on_job_executed(self, event) -> None:
"""调度器事件处理"""
if event.exception:
logger.error(f"Job {event.job_id} failed: {event.exception}")
@@ -482,11 +430,26 @@ class WorkflowManager:
created_at, updated_at, last_run_at, next_run_at,
run_count, success_count, fail_count)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
- (workflow.id, workflow.name, workflow.description, workflow.workflow_type,
- workflow.project_id, workflow.status, workflow.schedule, workflow.schedule_type,
- json.dumps(workflow.config), json.dumps(workflow.webhook_ids), workflow.is_active,
- workflow.created_at, workflow.updated_at, workflow.last_run_at, workflow.next_run_at,
- workflow.run_count, workflow.success_count, workflow.fail_count)
+ (
+ workflow.id,
+ workflow.name,
+ workflow.description,
+ workflow.workflow_type,
+ workflow.project_id,
+ workflow.status,
+ workflow.schedule,
+ workflow.schedule_type,
+ json.dumps(workflow.config),
+ json.dumps(workflow.webhook_ids),
+ workflow.is_active,
+ workflow.created_at,
+ workflow.updated_at,
+ workflow.last_run_at,
+ workflow.next_run_at,
+ workflow.run_count,
+ workflow.success_count,
+ workflow.fail_count,
+ ),
)
conn.commit()
@@ -502,10 +465,7 @@ class WorkflowManager:
"""获取工作流"""
conn = self.db.get_conn()
try:
- row = conn.execute(
- "SELECT * FROM workflows WHERE id = ?",
- (workflow_id,)
- ).fetchone()
+ row = conn.execute("SELECT * FROM workflows WHERE id = ?", (workflow_id,)).fetchone()
if not row:
return None
@@ -514,8 +474,7 @@ class WorkflowManager:
finally:
conn.close()
- def list_workflows(self, project_id: str = None, status: str = None,
- workflow_type: str = None) -> list[Workflow]:
+ def list_workflows(self, project_id: str = None, status: str = None, workflow_type: str = None) -> list[Workflow]:
"""列出工作流"""
conn = self.db.get_conn()
try:
@@ -535,8 +494,7 @@ class WorkflowManager:
where_clause = " AND ".join(conditions) if conditions else "1=1"
rows = conn.execute(
- f"SELECT * FROM workflows WHERE {where_clause} ORDER BY created_at DESC",
- params
+ f"SELECT * FROM workflows WHERE {where_clause} ORDER BY created_at DESC", params
).fetchall()
return [self._row_to_workflow(row) for row in rows]
@@ -547,15 +505,23 @@ class WorkflowManager:
"""更新工作流"""
conn = self.db.get_conn()
try:
- allowed_fields = ['name', 'description', 'status', 'schedule',
- 'schedule_type', 'is_active', 'config', 'webhook_ids']
+ allowed_fields = [
+ "name",
+ "description",
+ "status",
+ "schedule",
+ "schedule_type",
+ "is_active",
+ "config",
+ "webhook_ids",
+ ]
updates = []
values = []
for f in allowed_fields:
if f in kwargs:
updates.append(f"{f} = ?")
- if f in ['config', 'webhook_ids']:
+ if f in ["config", "webhook_ids"]:
values.append(json.dumps(kwargs[f]))
else:
values.append(kwargs[f])
@@ -607,24 +573,24 @@ class WorkflowManager:
def _row_to_workflow(self, row) -> Workflow:
"""将数据库行转换为 Workflow 对象"""
return Workflow(
- id=row['id'],
- name=row['name'],
- description=row['description'] or "",
- workflow_type=row['workflow_type'],
- project_id=row['project_id'],
- status=row['status'],
- schedule=row['schedule'],
- schedule_type=row['schedule_type'],
- config=json.loads(row['config']) if row['config'] else {},
- webhook_ids=json.loads(row['webhook_ids']) if row['webhook_ids'] else [],
- is_active=bool(row['is_active']),
- created_at=row['created_at'],
- updated_at=row['updated_at'],
- last_run_at=row['last_run_at'],
- next_run_at=row['next_run_at'],
- run_count=row['run_count'] or 0,
- success_count=row['success_count'] or 0,
- fail_count=row['fail_count'] or 0
+ id=row["id"],
+ name=row["name"],
+ description=row["description"] or "",
+ workflow_type=row["workflow_type"],
+ project_id=row["project_id"],
+ status=row["status"],
+ schedule=row["schedule"],
+ schedule_type=row["schedule_type"],
+ config=json.loads(row["config"]) if row["config"] else {},
+ webhook_ids=json.loads(row["webhook_ids"]) if row["webhook_ids"] else [],
+ is_active=bool(row["is_active"]),
+ created_at=row["created_at"],
+ updated_at=row["updated_at"],
+ last_run_at=row["last_run_at"],
+ next_run_at=row["next_run_at"],
+ run_count=row["run_count"] or 0,
+ success_count=row["success_count"] or 0,
+ fail_count=row["fail_count"] or 0,
)
# ==================== Workflow Task CRUD ====================
@@ -639,10 +605,20 @@ class WorkflowManager:
depends_on, timeout_seconds, retry_count, retry_delay,
created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
- (task.id, task.workflow_id, task.name, task.task_type,
- json.dumps(task.config), task.order, json.dumps(task.depends_on),
- task.timeout_seconds, task.retry_count, task.retry_delay,
- task.created_at, task.updated_at)
+ (
+ task.id,
+ task.workflow_id,
+ task.name,
+ task.task_type,
+ json.dumps(task.config),
+ task.order,
+ json.dumps(task.depends_on),
+ task.timeout_seconds,
+ task.retry_count,
+ task.retry_delay,
+ task.created_at,
+ task.updated_at,
+ ),
)
conn.commit()
return task
@@ -653,10 +629,7 @@ class WorkflowManager:
"""获取任务"""
conn = self.db.get_conn()
try:
- row = conn.execute(
- "SELECT * FROM workflow_tasks WHERE id = ?",
- (task_id,)
- ).fetchone()
+ row = conn.execute("SELECT * FROM workflow_tasks WHERE id = ?", (task_id,)).fetchone()
if not row:
return None
@@ -670,8 +643,7 @@ class WorkflowManager:
conn = self.db.get_conn()
try:
rows = conn.execute(
- "SELECT * FROM workflow_tasks WHERE workflow_id = ? ORDER BY task_order",
- (workflow_id,)
+ "SELECT * FROM workflow_tasks WHERE workflow_id = ? ORDER BY task_order", (workflow_id,)
).fetchall()
return [self._row_to_task(row) for row in rows]
@@ -682,15 +654,23 @@ class WorkflowManager:
"""更新任务"""
conn = self.db.get_conn()
try:
- allowed_fields = ['name', 'task_type', 'config', 'task_order',
- 'depends_on', 'timeout_seconds', 'retry_count', 'retry_delay']
+ allowed_fields = [
+ "name",
+ "task_type",
+ "config",
+ "task_order",
+ "depends_on",
+ "timeout_seconds",
+ "retry_count",
+ "retry_delay",
+ ]
updates = []
values = []
for f in allowed_fields:
if f in kwargs:
updates.append(f"{f} = ?")
- if f in ['config', 'depends_on']:
+ if f in ["config", "depends_on"]:
values.append(json.dumps(kwargs[f]))
else:
values.append(kwargs[f])
@@ -723,18 +703,18 @@ class WorkflowManager:
def _row_to_task(self, row) -> WorkflowTask:
"""将数据库行转换为 WorkflowTask 对象"""
return WorkflowTask(
- id=row['id'],
- workflow_id=row['workflow_id'],
- name=row['name'],
- task_type=row['task_type'],
- config=json.loads(row['config']) if row['config'] else {},
- order=row['task_order'] or 0,
- depends_on=json.loads(row['depends_on']) if row['depends_on'] else [],
- timeout_seconds=row['timeout_seconds'] or 300,
- retry_count=row['retry_count'] or 3,
- retry_delay=row['retry_delay'] or 5,
- created_at=row['created_at'],
- updated_at=row['updated_at']
+ id=row["id"],
+ workflow_id=row["workflow_id"],
+ name=row["name"],
+ task_type=row["task_type"],
+ config=json.loads(row["config"]) if row["config"] else {},
+ order=row["task_order"] or 0,
+ depends_on=json.loads(row["depends_on"]) if row["depends_on"] else [],
+ timeout_seconds=row["timeout_seconds"] or 300,
+ retry_count=row["retry_count"] or 3,
+ retry_delay=row["retry_delay"] or 5,
+ created_at=row["created_at"],
+ updated_at=row["updated_at"],
)
# ==================== Webhook Config CRUD ====================
@@ -749,10 +729,21 @@ class WorkflowManager:
is_active, created_at, updated_at, last_used_at,
success_count, fail_count)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
- (webhook.id, webhook.name, webhook.webhook_type, webhook.url,
- webhook.secret, json.dumps(webhook.headers), webhook.template,
- webhook.is_active, webhook.created_at, webhook.updated_at,
- webhook.last_used_at, webhook.success_count, webhook.fail_count)
+ (
+ webhook.id,
+ webhook.name,
+ webhook.webhook_type,
+ webhook.url,
+ webhook.secret,
+ json.dumps(webhook.headers),
+ webhook.template,
+ webhook.is_active,
+ webhook.created_at,
+ webhook.updated_at,
+ webhook.last_used_at,
+ webhook.success_count,
+ webhook.fail_count,
+ ),
)
conn.commit()
return webhook
@@ -763,10 +754,7 @@ class WorkflowManager:
"""获取 Webhook 配置"""
conn = self.db.get_conn()
try:
- row = conn.execute(
- "SELECT * FROM webhook_configs WHERE id = ?",
- (webhook_id,)
- ).fetchone()
+ row = conn.execute("SELECT * FROM webhook_configs WHERE id = ?", (webhook_id,)).fetchone()
if not row:
return None
@@ -779,9 +767,7 @@ class WorkflowManager:
"""列出所有 Webhook 配置"""
conn = self.db.get_conn()
try:
- rows = conn.execute(
- "SELECT * FROM webhook_configs ORDER BY created_at DESC"
- ).fetchall()
+ rows = conn.execute("SELECT * FROM webhook_configs ORDER BY created_at DESC").fetchall()
return [self._row_to_webhook(row) for row in rows]
finally:
@@ -791,15 +777,14 @@ class WorkflowManager:
"""更新 Webhook 配置"""
conn = self.db.get_conn()
try:
- allowed_fields = ['name', 'webhook_type', 'url', 'secret',
- 'headers', 'template', 'is_active']
+ allowed_fields = ["name", "webhook_type", "url", "secret", "headers", "template", "is_active"]
updates = []
values = []
for f in allowed_fields:
if f in kwargs:
updates.append(f"{f} = ?")
- if f == 'headers':
+ if f == "headers":
values.append(json.dumps(kwargs[f]))
else:
values.append(kwargs[f])
@@ -829,7 +814,7 @@ class WorkflowManager:
finally:
conn.close()
- def update_webhook_stats(self, webhook_id: str, success: bool):
+ def update_webhook_stats(self, webhook_id: str, success: bool) -> None:
"""更新 Webhook 统计"""
conn = self.db.get_conn()
try:
@@ -838,14 +823,14 @@ class WorkflowManager:
"""UPDATE webhook_configs
SET success_count = success_count + 1, last_used_at = ?
WHERE id = ?""",
- (datetime.now().isoformat(), webhook_id)
+ (datetime.now().isoformat(), webhook_id),
)
else:
conn.execute(
"""UPDATE webhook_configs
SET fail_count = fail_count + 1, last_used_at = ?
WHERE id = ?""",
- (datetime.now().isoformat(), webhook_id)
+ (datetime.now().isoformat(), webhook_id),
)
conn.commit()
finally:
@@ -854,19 +839,19 @@ class WorkflowManager:
def _row_to_webhook(self, row) -> WebhookConfig:
"""将数据库行转换为 WebhookConfig 对象"""
return WebhookConfig(
- id=row['id'],
- name=row['name'],
- webhook_type=row['webhook_type'],
- url=row['url'],
- secret=row['secret'] or "",
- headers=json.loads(row['headers']) if row['headers'] else {},
- template=row['template'] or "",
- is_active=bool(row['is_active']),
- created_at=row['created_at'],
- updated_at=row['updated_at'],
- last_used_at=row['last_used_at'],
- success_count=row['success_count'] or 0,
- fail_count=row['fail_count'] or 0
+ id=row["id"],
+ name=row["name"],
+ webhook_type=row["webhook_type"],
+ url=row["url"],
+ secret=row["secret"] or "",
+ headers=json.loads(row["headers"]) if row["headers"] else {},
+ template=row["template"] or "",
+ is_active=bool(row["is_active"]),
+ created_at=row["created_at"],
+ updated_at=row["updated_at"],
+ last_used_at=row["last_used_at"],
+ success_count=row["success_count"] or 0,
+ fail_count=row["fail_count"] or 0,
)
# ==================== Workflow Log ====================
@@ -880,10 +865,19 @@ class WorkflowManager:
(id, workflow_id, task_id, status, start_time, end_time,
duration_ms, input_data, output_data, error_message, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
- (log.id, log.workflow_id, log.task_id, log.status,
- log.start_time, log.end_time, log.duration_ms,
- json.dumps(log.input_data), json.dumps(log.output_data),
- log.error_message, log.created_at)
+ (
+ log.id,
+ log.workflow_id,
+ log.task_id,
+ log.status,
+ log.start_time,
+ log.end_time,
+ log.duration_ms,
+ json.dumps(log.input_data),
+ json.dumps(log.output_data),
+ log.error_message,
+ log.created_at,
+ ),
)
conn.commit()
return log
@@ -894,15 +888,14 @@ class WorkflowManager:
"""更新工作流日志"""
conn = self.db.get_conn()
try:
- allowed_fields = ['status', 'end_time', 'duration_ms',
- 'output_data', 'error_message']
+ allowed_fields = ["status", "end_time", "duration_ms", "output_data", "error_message"]
updates = []
values = []
for f in allowed_fields:
if f in kwargs:
updates.append(f"{f} = ?")
- if f == 'output_data':
+ if f == "output_data":
values.append(json.dumps(kwargs[f]))
else:
values.append(kwargs[f])
@@ -923,10 +916,7 @@ class WorkflowManager:
"""获取日志"""
conn = self.db.get_conn()
try:
- row = conn.execute(
- "SELECT * FROM workflow_logs WHERE id = ?",
- (log_id,)
- ).fetchone()
+ row = conn.execute("SELECT * FROM workflow_logs WHERE id = ?", (log_id,)).fetchone()
if not row:
return None
@@ -935,8 +925,9 @@ class WorkflowManager:
finally:
conn.close()
- def list_logs(self, workflow_id: str = None, task_id: str = None,
- status: str = None, limit: int = 100, offset: int = 0) -> list[WorkflowLog]:
+ def list_logs(
+ self, workflow_id: str = None, task_id: str = None, status: str = None, limit: int = 100, offset: int = 0
+ ) -> list[WorkflowLog]:
"""列出工作流日志"""
conn = self.db.get_conn()
try:
@@ -960,7 +951,7 @@ class WorkflowManager:
WHERE {where_clause}
ORDER BY created_at DESC
LIMIT ? OFFSET ?""",
- params + [limit, offset]
+ params + [limit, offset],
).fetchall()
return [self._row_to_log(row) for row in rows]
@@ -975,27 +966,29 @@ class WorkflowManager:
# 总执行次数
total = conn.execute(
- "SELECT COUNT(*) FROM workflow_logs WHERE workflow_id = ? AND created_at > ?",
- (workflow_id, since)
+ "SELECT COUNT(*) FROM workflow_logs WHERE workflow_id = ? AND created_at > ?", (workflow_id, since)
).fetchone()[0]
# 成功次数
success = conn.execute(
"SELECT COUNT(*) FROM workflow_logs WHERE workflow_id = ? AND status = 'success' AND created_at > ?",
- (workflow_id, since)
+ (workflow_id, since),
).fetchone()[0]
# 失败次数
failed = conn.execute(
"SELECT COUNT(*) FROM workflow_logs WHERE workflow_id = ? AND status = 'failed' AND created_at > ?",
- (workflow_id, since)
+ (workflow_id, since),
).fetchone()[0]
# 平均执行时间
- avg_duration = conn.execute(
- "SELECT AVG(duration_ms) FROM workflow_logs WHERE workflow_id = ? AND created_at > ?",
- (workflow_id, since)
- ).fetchone()[0] or 0
+ avg_duration = (
+ conn.execute(
+ "SELECT AVG(duration_ms) FROM workflow_logs WHERE workflow_id = ? AND created_at > ?",
+ (workflow_id, since),
+ ).fetchone()[0]
+ or 0
+ )
# 每日统计
daily = conn.execute(
@@ -1006,7 +999,7 @@ class WorkflowManager:
WHERE workflow_id = ? AND created_at > ?
GROUP BY DATE(created_at)
ORDER BY date""",
- (workflow_id, since)
+ (workflow_id, since),
).fetchall()
return {
@@ -1015,7 +1008,7 @@ class WorkflowManager:
"failed": failed,
"success_rate": round(success / total * 100, 2) if total > 0 else 0,
"avg_duration_ms": round(avg_duration, 2),
- "daily": [{"date": r["date"], "count": r["count"], "success": r["success"]} for r in daily]
+ "daily": [{"date": r["date"], "count": r["count"], "success": r["success"]} for r in daily],
}
finally:
conn.close()
@@ -1023,17 +1016,17 @@ class WorkflowManager:
def _row_to_log(self, row) -> WorkflowLog:
"""将数据库行转换为 WorkflowLog 对象"""
return WorkflowLog(
- id=row['id'],
- workflow_id=row['workflow_id'],
- task_id=row['task_id'],
- status=row['status'],
- start_time=row['start_time'],
- end_time=row['end_time'],
- duration_ms=row['duration_ms'] or 0,
- input_data=json.loads(row['input_data']) if row['input_data'] else {},
- output_data=json.loads(row['output_data']) if row['output_data'] else {},
- error_message=row['error_message'] or "",
- created_at=row['created_at']
+ id=row["id"],
+ workflow_id=row["workflow_id"],
+ task_id=row["task_id"],
+ status=row["status"],
+ start_time=row["start_time"],
+ end_time=row["end_time"],
+ duration_ms=row["duration_ms"] or 0,
+ input_data=json.loads(row["input_data"]) if row["input_data"] else {},
+ output_data=json.loads(row["output_data"]) if row["output_data"] else {},
+ error_message=row["error_message"] or "",
+ created_at=row["created_at"],
)
# ==================== Workflow Execution ====================
@@ -1049,8 +1042,7 @@ class WorkflowManager:
# 更新最后运行时间
now = datetime.now().isoformat()
- self.update_workflow(workflow_id, last_run_at=now,
- run_count=workflow.run_count + 1)
+ self.update_workflow(workflow_id, last_run_at=now, run_count=workflow.run_count + 1)
# 创建工作流执行日志
log = WorkflowLog(
@@ -1058,7 +1050,7 @@ class WorkflowManager:
workflow_id=workflow_id,
status=TaskStatus.RUNNING.value,
start_time=now,
- input_data=input_data or {}
+ input_data=input_data or {},
)
self.create_log(log)
@@ -1087,7 +1079,7 @@ class WorkflowManager:
status=TaskStatus.SUCCESS.value,
end_time=end_time.isoformat(),
duration_ms=duration,
- output_data=results
+ output_data=results,
)
# 更新成功计数
@@ -1098,7 +1090,7 @@ class WorkflowManager:
"workflow_id": workflow_id,
"log_id": log.id,
"results": results,
- "duration_ms": duration
+ "duration_ms": duration,
}
except (TimeoutError, httpx.HTTPError) as e:
@@ -1112,7 +1104,7 @@ class WorkflowManager:
status=TaskStatus.FAILED.value,
end_time=end_time.isoformat(),
duration_ms=duration,
- error_message=str(e)
+ error_message=str(e),
)
# 更新失败计数
@@ -1123,8 +1115,7 @@ class WorkflowManager:
raise
- async def _execute_tasks_with_deps(self, tasks: list[WorkflowTask],
- input_data: dict, log_id: str) -> dict:
+ async def _execute_tasks_with_deps(self, tasks: list[WorkflowTask], input_data: dict, log_id: str) -> dict:
"""按依赖顺序执行任务"""
results = {}
completed_tasks = set()
@@ -1132,9 +1123,7 @@ class WorkflowManager:
while len(completed_tasks) < len(tasks):
# 找到可以执行的任务(依赖已完成)
ready_tasks = [
- t for t in tasks
- if t.id not in completed_tasks
- and all(dep in completed_tasks for dep in t.depends_on)
+ t for t in tasks if t.id not in completed_tasks and all(dep in completed_tasks for dep in t.depends_on)
]
if not ready_tasks:
@@ -1171,8 +1160,7 @@ class WorkflowManager:
return results
- async def _execute_single_task(self, task: WorkflowTask,
- input_data: dict, log_id: str) -> Any:
+ async def _execute_single_task(self, task: WorkflowTask, input_data: dict, log_id: str) -> Any:
"""执行单个任务"""
handler = self._task_handlers.get(task.task_type)
if not handler:
@@ -1185,23 +1173,20 @@ class WorkflowManager:
task_id=task.id,
status=TaskStatus.RUNNING.value,
start_time=datetime.now().isoformat(),
- input_data=input_data
+ input_data=input_data,
)
self.create_log(task_log)
try:
# 设置超时
- result = await asyncio.wait_for(
- handler(task, input_data),
- timeout=task.timeout_seconds
- )
+ result = await asyncio.wait_for(handler(task, input_data), timeout=task.timeout_seconds)
# 更新任务日志为成功
self.update_log(
task_log.id,
status=TaskStatus.SUCCESS.value,
end_time=datetime.now().isoformat(),
- output_data={"result": result} if not isinstance(result, dict) else result
+ output_data={"result": result} if not isinstance(result, dict) else result,
)
return result
@@ -1211,21 +1196,17 @@ class WorkflowManager:
task_log.id,
status=TaskStatus.FAILED.value,
end_time=datetime.now().isoformat(),
- error_message="Task timeout"
+ error_message="Task timeout",
)
raise TimeoutError(f"Task {task.id} timed out after {task.timeout_seconds}s")
except Exception as e:
self.update_log(
- task_log.id,
- status=TaskStatus.FAILED.value,
- end_time=datetime.now().isoformat(),
- error_message=str(e)
+ task_log.id, status=TaskStatus.FAILED.value, end_time=datetime.now().isoformat(), error_message=str(e)
)
raise
- async def _execute_default_workflow(self, workflow: Workflow,
- input_data: dict) -> dict:
+ async def _execute_default_workflow(self, workflow: Workflow, input_data: dict) -> dict:
"""执行默认工作流(根据类型)"""
workflow_type = WorkflowType(workflow.workflow_type)
@@ -1252,12 +1233,7 @@ class WorkflowManager:
# 这里调用现有的文件分析逻辑
# 实际实现需要与 main.py 中的 upload_audio 逻辑集成
- return {
- "task": "analyze",
- "project_id": project_id,
- "files_processed": len(file_ids),
- "status": "completed"
- }
+ return {"task": "analyze", "project_id": project_id, "files_processed": len(file_ids), "status": "completed"}
async def _handle_align_task(self, task: WorkflowTask, input_data: dict) -> dict:
"""处理实体对齐任务"""
@@ -1273,11 +1249,10 @@ class WorkflowManager:
"project_id": project_id,
"threshold": threshold,
"entities_merged": 0, # 实际实现需要调用对齐逻辑
- "status": "completed"
+ "status": "completed",
}
- async def _handle_discover_relations_task(self, task: WorkflowTask,
- input_data: dict) -> dict:
+ async def _handle_discover_relations_task(self, task: WorkflowTask, input_data: dict) -> dict:
"""处理关系发现任务"""
project_id = input_data.get("project_id")
@@ -1289,7 +1264,7 @@ class WorkflowManager:
"task": "discover_relations",
"project_id": project_id,
"relations_found": 0, # 实际实现需要调用关系发现逻辑
- "status": "completed"
+ "status": "completed",
}
async def _handle_notify_task(self, task: WorkflowTask, input_data: dict) -> dict:
@@ -1314,21 +1289,12 @@ class WorkflowManager:
success = await self.notifier.send(webhook, message)
self.update_webhook_stats(webhook_id, success)
- return {
- "task": "notify",
- "webhook_id": webhook_id,
- "success": success
- }
+ return {"task": "notify", "webhook_id": webhook_id, "success": success}
async def _handle_custom_task(self, task: WorkflowTask, input_data: dict) -> dict:
"""处理自定义任务"""
# 自定义任务的具体逻辑由外部处理器实现
- return {
- "task": "custom",
- "task_name": task.name,
- "config": task.config,
- "status": "completed"
- }
+ return {"task": "custom", "task_name": task.name, "config": task.config, "status": "completed"}
# ==================== Default Workflow Implementations ====================
@@ -1344,7 +1310,7 @@ class WorkflowManager:
"files_analyzed": 0,
"entities_extracted": 0,
"relations_extracted": 0,
- "status": "completed"
+ "status": "completed",
}
async def _auto_align_entities(self, workflow: Workflow, input_data: dict) -> dict:
@@ -1357,7 +1323,7 @@ class WorkflowManager:
"project_id": project_id,
"threshold": threshold,
"entities_merged": 0,
- "status": "completed"
+ "status": "completed",
}
async def _auto_discover_relations(self, workflow: Workflow, input_data: dict) -> dict:
@@ -1368,7 +1334,7 @@ class WorkflowManager:
"workflow_type": "auto_relation",
"project_id": project_id,
"relations_discovered": 0,
- "status": "completed"
+ "status": "completed",
}
async def _generate_scheduled_report(self, workflow: Workflow, input_data: dict) -> dict:
@@ -1380,13 +1346,12 @@ class WorkflowManager:
"workflow_type": "scheduled_report",
"project_id": project_id,
"report_type": report_type,
- "status": "completed"
+ "status": "completed",
}
# ==================== Notification ====================
- async def _send_workflow_notification(self, workflow: Workflow,
- results: dict, success: bool = True):
+ async def _send_workflow_notification(self, workflow: Workflow, results: dict, success: bool = True):
"""发送工作流执行通知"""
if not workflow.webhook_ids:
return
@@ -1409,7 +1374,7 @@ class WorkflowManager:
"workflow_name": workflow.name,
"status": "success" if success else "failed",
"results": results,
- "timestamp": datetime.now().isoformat()
+ "timestamp": datetime.now().isoformat(),
}
try:
@@ -1418,8 +1383,7 @@ class WorkflowManager:
except (TimeoutError, httpx.HTTPError) as e:
logger.error(f"Failed to send notification to {webhook_id}: {e}")
- def _build_feishu_message(self, workflow: Workflow, results: dict,
- success: bool) -> dict:
+ def _build_feishu_message(self, workflow: Workflow, results: dict, success: bool) -> dict:
"""构建飞书消息"""
status_text = "✅ 成功" if success else "❌ 失败"
@@ -1429,11 +1393,10 @@ class WorkflowManager:
[{"tag": "text", "text": f"工作流: {workflow.name}"}],
[{"tag": "text", "text": f"状态: {status_text}"}],
[{"tag": "text", "text": f"时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"}],
- ]
+ ],
}
- def _build_dingtalk_message(self, workflow: Workflow, results: dict,
- success: bool) -> dict:
+ def _build_dingtalk_message(self, workflow: Workflow, results: dict, success: bool) -> dict:
"""构建钉钉消息"""
status_text = "✅ 成功" if success else "❌ 失败"
@@ -1451,11 +1414,10 @@ class WorkflowManager:
```json
{json.dumps(results, ensure_ascii=False, indent=2)}
```
-"""
+""",
}
- def _build_slack_message(self, workflow: Workflow, results: dict,
- success: bool) -> dict:
+ def _build_slack_message(self, workflow: Workflow, results: dict, success: bool) -> dict:
"""构建 Slack 消息"""
color = "#36a64f" if success else "#ff0000"
status_text = "Success" if success else "Failed"
@@ -1467,10 +1429,10 @@ class WorkflowManager:
"title": f"Workflow Execution: {workflow.name}",
"fields": [
{"title": "Status", "value": status_text, "short": True},
- {"title": "Time", "value": datetime.now().strftime('%Y-%m-%d %H:%M:%S'), "short": True}
+ {"title": "Time", "value": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), "short": True},
],
"footer": "InsightFlow",
- "ts": int(datetime.now().timestamp())
+ "ts": int(datetime.now().timestamp()),
}
]
}