fix: auto-fix code issues (cron)

- 修复重复导入/字段
- 修复异常处理
- 修复PEP8格式问题
- 添加类型注解
This commit is contained in:
OpenClaw Bot
2026-02-27 18:09:24 +08:00
parent 646b64daf7
commit 17bda3dbce
38 changed files with 1993 additions and 1972 deletions

View File

@@ -8,25 +8,25 @@ AI 能力增强模块
- 预测性分析(趋势预测、异常检测) - 预测性分析(趋势预测、异常检测)
""" """
import os
import json
import sqlite3
import httpx
import asyncio import asyncio
import json
import os
import random import random
import sqlite3
import statistics import statistics
from typing import List, Dict, Optional import uuid
from collections import defaultdict
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime from datetime import datetime
from enum import Enum from enum import StrEnum
from collections import defaultdict
import uuid import httpx
# Database path # Database path
DB_PATH = os.path.join(os.path.dirname(__file__), "insightflow.db") DB_PATH = os.path.join(os.path.dirname(__file__), "insightflow.db")
class ModelType(str, Enum): class ModelType(StrEnum):
"""模型类型""" """模型类型"""
CUSTOM_NER = "custom_ner" # 自定义实体识别 CUSTOM_NER = "custom_ner" # 自定义实体识别
@@ -35,7 +35,7 @@ class ModelType(str, Enum):
PREDICTION = "prediction" # 预测 PREDICTION = "prediction" # 预测
class ModelStatus(str, Enum): class ModelStatus(StrEnum):
"""模型状态""" """模型状态"""
PENDING = "pending" PENDING = "pending"
@@ -45,7 +45,7 @@ class ModelStatus(str, Enum):
ARCHIVED = "archived" ARCHIVED = "archived"
class MultimodalProvider(str, Enum): class MultimodalProvider(StrEnum):
"""多模态模型提供商""" """多模态模型提供商"""
GPT4V = "gpt-4-vision" GPT4V = "gpt-4-vision"
@@ -54,7 +54,7 @@ class MultimodalProvider(str, Enum):
KIMI_VL = "kimi-vl" KIMI_VL = "kimi-vl"
class PredictionType(str, Enum): class PredictionType(StrEnum):
"""预测类型""" """预测类型"""
TREND = "trend" # 趋势预测 TREND = "trend" # 趋势预测
@@ -73,13 +73,13 @@ class CustomModel:
description: str description: str
model_type: ModelType model_type: ModelType
status: ModelStatus status: ModelStatus
training_data: Dict # 训练数据配置 training_data: dict # 训练数据配置
hyperparameters: Dict # 超参数 hyperparameters: dict # 超参数
metrics: Dict # 训练指标 metrics: dict # 训练指标
model_path: Optional[str] # 模型文件路径 model_path: str | None # 模型文件路径
created_at: str created_at: str
updated_at: str updated_at: str
trained_at: Optional[str] trained_at: str | None
created_by: str created_by: str
@@ -90,8 +90,8 @@ class TrainingSample:
id: str id: str
model_id: str model_id: str
text: str text: str
entities: List[Dict] # [{"start": 0, "end": 5, "label": "PERSON", "text": "张三"}] entities: list[dict] # [{"start": 0, "end": 5, "label": "PERSON", "text": "张三"}]
metadata: Dict metadata: dict
created_at: str created_at: str
@@ -104,9 +104,9 @@ class MultimodalAnalysis:
project_id: str project_id: str
provider: MultimodalProvider provider: MultimodalProvider
input_type: str # image, video, audio, mixed input_type: str # image, video, audio, mixed
input_urls: List[str] input_urls: list[str]
prompt: str prompt: str
result: Dict # 分析结果 result: dict # 分析结果
tokens_used: int tokens_used: int
cost: float cost: float
created_at: str created_at: str
@@ -121,9 +121,9 @@ class KnowledgeGraphRAG:
project_id: str project_id: str
name: str name: str
description: str description: str
kg_config: Dict # 知识图谱配置 kg_config: dict # 知识图谱配置
retrieval_config: Dict # 检索配置 retrieval_config: dict # 检索配置
generation_config: Dict # 生成配置 generation_config: dict # 生成配置
is_active: bool is_active: bool
created_at: str created_at: str
updated_at: str updated_at: str
@@ -136,9 +136,9 @@ class RAGQuery:
id: str id: str
rag_id: str rag_id: str
query: str query: str
context: Dict # 检索到的上下文 context: dict # 检索到的上下文
answer: str answer: str
sources: List[Dict] # 来源信息 sources: list[dict] # 来源信息
confidence: float confidence: float
tokens_used: int tokens_used: int
latency_ms: int latency_ms: int
@@ -154,11 +154,11 @@ class PredictionModel:
project_id: str project_id: str
name: str name: str
prediction_type: PredictionType prediction_type: PredictionType
target_entity_type: Optional[str] # 目标实体类型 target_entity_type: str | None # 目标实体类型
features: List[str] # 特征列表 features: list[str] # 特征列表
model_config: Dict # 模型配置 model_config: dict # 模型配置
accuracy: Optional[float] accuracy: float | None
last_trained_at: Optional[str] last_trained_at: str | None
prediction_count: int prediction_count: int
is_active: bool is_active: bool
created_at: str created_at: str
@@ -172,12 +172,12 @@ class PredictionResult:
id: str id: str
model_id: str model_id: str
prediction_type: PredictionType prediction_type: PredictionType
target_id: Optional[str] # 预测目标ID target_id: str | None # 预测目标ID
prediction_data: Dict # 预测数据 prediction_data: dict # 预测数据
confidence: float confidence: float
explanation: str # 预测解释 explanation: str # 预测解释
actual_value: Optional[str] # 实际值(用于验证) actual_value: str | None # 实际值(用于验证)
is_correct: Optional[bool] is_correct: bool | None
created_at: str created_at: str
@@ -192,8 +192,8 @@ class SmartSummary:
source_id: str source_id: str
summary_type: str # extractive, abstractive, key_points, timeline summary_type: str # extractive, abstractive, key_points, timeline
content: str content: str
key_points: List[str] key_points: list[str]
entities_mentioned: List[str] entities_mentioned: list[str]
confidence: float confidence: float
tokens_used: int tokens_used: int
created_at: str created_at: str
@@ -223,8 +223,8 @@ class AIManager:
name: str, name: str,
description: str, description: str,
model_type: ModelType, model_type: ModelType,
training_data: Dict, training_data: dict,
hyperparameters: Dict, hyperparameters: dict,
created_by: str, created_by: str,
) -> CustomModel: ) -> CustomModel:
"""创建自定义模型""" """创建自定义模型"""
@@ -277,7 +277,7 @@ class AIManager:
return model return model
def get_custom_model(self, model_id: str) -> Optional[CustomModel]: def get_custom_model(self, model_id: str) -> CustomModel | None:
"""获取自定义模型""" """获取自定义模型"""
with self._get_db() as conn: with self._get_db() as conn:
row = conn.execute("SELECT * FROM custom_models WHERE id = ?", (model_id,)).fetchone() row = conn.execute("SELECT * FROM custom_models WHERE id = ?", (model_id,)).fetchone()
@@ -288,8 +288,8 @@ class AIManager:
return self._row_to_custom_model(row) return self._row_to_custom_model(row)
def list_custom_models( def list_custom_models(
self, tenant_id: str, model_type: Optional[ModelType] = None, status: Optional[ModelStatus] = None self, tenant_id: str, model_type: ModelType | None = None, status: ModelStatus | None = None
) -> List[CustomModel]: ) -> list[CustomModel]:
"""列出自定义模型""" """列出自定义模型"""
query = "SELECT * FROM custom_models WHERE tenant_id = ?" query = "SELECT * FROM custom_models WHERE tenant_id = ?"
params = [tenant_id] params = [tenant_id]
@@ -308,7 +308,7 @@ class AIManager:
return [self._row_to_custom_model(row) for row in rows] return [self._row_to_custom_model(row) for row in rows]
def add_training_sample( def add_training_sample(
self, model_id: str, text: str, entities: List[Dict], metadata: Dict = None self, model_id: str, text: str, entities: list[dict], metadata: dict = None
) -> TrainingSample: ) -> TrainingSample:
"""添加训练样本""" """添加训练样本"""
sample_id = f"ts_{uuid.uuid4().hex[:16]}" sample_id = f"ts_{uuid.uuid4().hex[:16]}"
@@ -338,7 +338,7 @@ class AIManager:
return sample return sample
def get_training_samples(self, model_id: str) -> List[TrainingSample]: def get_training_samples(self, model_id: str) -> list[TrainingSample]:
"""获取训练样本""" """获取训练样本"""
with self._get_db() as conn: with self._get_db() as conn:
rows = conn.execute( rows = conn.execute(
@@ -410,7 +410,7 @@ class AIManager:
conn.commit() conn.commit()
raise e raise e
async def predict_with_custom_model(self, model_id: str, text: str) -> List[Dict]: async def predict_with_custom_model(self, model_id: str, text: str) -> list[dict]:
"""使用自定义模型进行预测""" """使用自定义模型进行预测"""
model = self.get_custom_model(model_id) model = self.get_custom_model(model_id)
if not model or model.status != ModelStatus.READY: if not model or model.status != ModelStatus.READY:
@@ -461,7 +461,7 @@ class AIManager:
project_id: str, project_id: str,
provider: MultimodalProvider, provider: MultimodalProvider,
input_type: str, input_type: str,
input_urls: List[str], input_urls: list[str],
prompt: str, prompt: str,
) -> MultimodalAnalysis: ) -> MultimodalAnalysis:
"""多模态分析""" """多模态分析"""
@@ -517,7 +517,7 @@ class AIManager:
return analysis return analysis
async def _call_gpt4v(self, image_urls: List[str], prompt: str) -> Dict: async def _call_gpt4v(self, image_urls: list[str], prompt: str) -> dict:
"""调用 GPT-4V""" """调用 GPT-4V"""
headers = {"Authorization": f"Bearer {self.openai_api_key}", "Content-Type": "application/json"} headers = {"Authorization": f"Bearer {self.openai_api_key}", "Content-Type": "application/json"}
@@ -544,7 +544,7 @@ class AIManager:
"cost": result["usage"]["total_tokens"] * 0.00001, # 估算成本 "cost": result["usage"]["total_tokens"] * 0.00001, # 估算成本
} }
async def _call_claude3(self, image_urls: List[str], prompt: str) -> Dict: async def _call_claude3(self, image_urls: list[str], prompt: str) -> dict:
"""调用 Claude 3""" """调用 Claude 3"""
headers = { headers = {
"x-api-key": self.anthropic_api_key, "x-api-key": self.anthropic_api_key,
@@ -576,7 +576,7 @@ class AIManager:
"cost": (result["usage"]["input_tokens"] + result["usage"]["output_tokens"]) * 0.000015, "cost": (result["usage"]["input_tokens"] + result["usage"]["output_tokens"]) * 0.000015,
} }
async def _call_kimi_multimodal(self, image_urls: List[str], prompt: str) -> Dict: async def _call_kimi_multimodal(self, image_urls: list[str], prompt: str) -> dict:
"""调用 Kimi 多模态模型""" """调用 Kimi 多模态模型"""
headers = {"Authorization": f"Bearer {self.kimi_api_key}", "Content-Type": "application/json"} headers = {"Authorization": f"Bearer {self.kimi_api_key}", "Content-Type": "application/json"}
@@ -600,7 +600,7 @@ class AIManager:
"cost": result["usage"]["total_tokens"] * 0.000005, "cost": result["usage"]["total_tokens"] * 0.000005,
} }
def get_multimodal_analyses(self, tenant_id: str, project_id: Optional[str] = None) -> List[MultimodalAnalysis]: def get_multimodal_analyses(self, tenant_id: str, project_id: str | None = None) -> list[MultimodalAnalysis]:
"""获取多模态分析历史""" """获取多模态分析历史"""
query = "SELECT * FROM multimodal_analyses WHERE tenant_id = ?" query = "SELECT * FROM multimodal_analyses WHERE tenant_id = ?"
params = [tenant_id] params = [tenant_id]
@@ -623,9 +623,9 @@ class AIManager:
project_id: str, project_id: str,
name: str, name: str,
description: str, description: str,
kg_config: Dict, kg_config: dict,
retrieval_config: Dict, retrieval_config: dict,
generation_config: Dict, generation_config: dict,
) -> KnowledgeGraphRAG: ) -> KnowledgeGraphRAG:
"""创建知识图谱 RAG 配置""" """创建知识图谱 RAG 配置"""
rag_id = f"kgr_{uuid.uuid4().hex[:16]}" rag_id = f"kgr_{uuid.uuid4().hex[:16]}"
@@ -671,7 +671,7 @@ class AIManager:
return rag return rag
def get_kg_rag(self, rag_id: str) -> Optional[KnowledgeGraphRAG]: def get_kg_rag(self, rag_id: str) -> KnowledgeGraphRAG | None:
"""获取知识图谱 RAG 配置""" """获取知识图谱 RAG 配置"""
with self._get_db() as conn: with self._get_db() as conn:
row = conn.execute("SELECT * FROM kg_rag_configs WHERE id = ?", (rag_id,)).fetchone() row = conn.execute("SELECT * FROM kg_rag_configs WHERE id = ?", (rag_id,)).fetchone()
@@ -681,7 +681,7 @@ class AIManager:
return self._row_to_kg_rag(row) return self._row_to_kg_rag(row)
def list_kg_rags(self, tenant_id: str, project_id: Optional[str] = None) -> List[KnowledgeGraphRAG]: def list_kg_rags(self, tenant_id: str, project_id: str | None = None) -> list[KnowledgeGraphRAG]:
"""列出知识图谱 RAG 配置""" """列出知识图谱 RAG 配置"""
query = "SELECT * FROM kg_rag_configs WHERE tenant_id = ?" query = "SELECT * FROM kg_rag_configs WHERE tenant_id = ?"
params = [tenant_id] params = [tenant_id]
@@ -697,7 +697,7 @@ class AIManager:
return [self._row_to_kg_rag(row) for row in rows] return [self._row_to_kg_rag(row) for row in rows]
async def query_kg_rag( async def query_kg_rag(
self, rag_id: str, query: str, project_entities: List[Dict], project_relations: List[Dict] self, rag_id: str, query: str, project_entities: list[dict], project_relations: list[dict]
) -> RAGQuery: ) -> RAGQuery:
"""基于知识图谱的 RAG 查询""" """基于知识图谱的 RAG 查询"""
import time import time
@@ -832,7 +832,7 @@ class AIManager:
return rag_query return rag_query
def _build_kg_context(self, entities: List[Dict], relations: List[Dict]) -> str: def _build_kg_context(self, entities: list[dict], relations: list[dict]) -> str:
"""构建知识图谱上下文文本""" """构建知识图谱上下文文本"""
context = [] context = []
@@ -858,7 +858,7 @@ class AIManager:
return "\n".join(context) return "\n".join(context)
async def generate_smart_summary( async def generate_smart_summary(
self, tenant_id: str, project_id: str, source_type: str, source_id: str, summary_type: str, content_data: Dict self, tenant_id: str, project_id: str, source_type: str, source_id: str, summary_type: str, content_data: dict
) -> SmartSummary: ) -> SmartSummary:
"""生成智能摘要""" """生成智能摘要"""
summary_id = f"ss_{uuid.uuid4().hex[:16]}" summary_id = f"ss_{uuid.uuid4().hex[:16]}"
@@ -999,9 +999,9 @@ class AIManager:
project_id: str, project_id: str,
name: str, name: str,
prediction_type: PredictionType, prediction_type: PredictionType,
target_entity_type: Optional[str], target_entity_type: str | None,
features: List[str], features: list[str],
model_config: Dict, model_config: dict,
) -> PredictionModel: ) -> PredictionModel:
"""创建预测模型""" """创建预测模型"""
model_id = f"pm_{uuid.uuid4().hex[:16]}" model_id = f"pm_{uuid.uuid4().hex[:16]}"
@@ -1053,7 +1053,7 @@ class AIManager:
return model return model
def get_prediction_model(self, model_id: str) -> Optional[PredictionModel]: def get_prediction_model(self, model_id: str) -> PredictionModel | None:
"""获取预测模型""" """获取预测模型"""
with self._get_db() as conn: with self._get_db() as conn:
row = conn.execute("SELECT * FROM prediction_models WHERE id = ?", (model_id,)).fetchone() row = conn.execute("SELECT * FROM prediction_models WHERE id = ?", (model_id,)).fetchone()
@@ -1063,7 +1063,7 @@ class AIManager:
return self._row_to_prediction_model(row) return self._row_to_prediction_model(row)
def list_prediction_models(self, tenant_id: str, project_id: Optional[str] = None) -> List[PredictionModel]: def list_prediction_models(self, tenant_id: str, project_id: str | None = None) -> list[PredictionModel]:
"""列出预测模型""" """列出预测模型"""
query = "SELECT * FROM prediction_models WHERE tenant_id = ?" query = "SELECT * FROM prediction_models WHERE tenant_id = ?"
params = [tenant_id] params = [tenant_id]
@@ -1078,7 +1078,7 @@ class AIManager:
rows = conn.execute(query, params).fetchall() rows = conn.execute(query, params).fetchall()
return [self._row_to_prediction_model(row) for row in rows] return [self._row_to_prediction_model(row) for row in rows]
async def train_prediction_model(self, model_id: str, historical_data: List[Dict]) -> PredictionModel: async def train_prediction_model(self, model_id: str, historical_data: list[dict]) -> PredictionModel:
"""训练预测模型""" """训练预测模型"""
model = self.get_prediction_model(model_id) model = self.get_prediction_model(model_id)
if not model: if not model:
@@ -1105,7 +1105,7 @@ class AIManager:
return self.get_prediction_model(model_id) return self.get_prediction_model(model_id)
async def predict(self, model_id: str, input_data: Dict) -> PredictionResult: async def predict(self, model_id: str, input_data: dict) -> PredictionResult:
"""进行预测""" """进行预测"""
model = self.get_prediction_model(model_id) model = self.get_prediction_model(model_id)
if not model or not model.is_active: if not model or not model.is_active:
@@ -1172,7 +1172,7 @@ class AIManager:
return result return result
def _predict_trend(self, input_data: Dict, model: PredictionModel) -> Dict: def _predict_trend(self, input_data: dict, model: PredictionModel) -> dict:
"""趋势预测""" """趋势预测"""
historical_values = input_data.get("historical_values", []) historical_values = input_data.get("historical_values", [])
@@ -1211,7 +1211,7 @@ class AIManager:
"explanation": f"基于{len(historical_values)}个历史数据点,预测趋势为{trend}", "explanation": f"基于{len(historical_values)}个历史数据点,预测趋势为{trend}",
} }
def _detect_anomaly(self, input_data: Dict, model: PredictionModel) -> Dict: def _detect_anomaly(self, input_data: dict, model: PredictionModel) -> dict:
"""异常检测""" """异常检测"""
value = input_data.get("value") value = input_data.get("value")
historical_values = input_data.get("historical_values", []) historical_values = input_data.get("historical_values", [])
@@ -1245,7 +1245,7 @@ class AIManager:
"explanation": f"当前值偏离均值{z_score:.2f}个标准差,{'检测到异常' if is_anomaly else '处于正常范围'}", "explanation": f"当前值偏离均值{z_score:.2f}个标准差,{'检测到异常' if is_anomaly else '处于正常范围'}",
} }
def _predict_entity_growth(self, input_data: Dict, model: PredictionModel) -> Dict: def _predict_entity_growth(self, input_data: dict, model: PredictionModel) -> dict:
"""实体增长预测""" """实体增长预测"""
entity_history = input_data.get("entity_history", []) entity_history = input_data.get("entity_history", [])
@@ -1273,7 +1273,7 @@ class AIManager:
"explanation": f"基于过去{len(entity_history)}个周期的数据,预测增长率{avg_growth_rate * 100:.1f}%", "explanation": f"基于过去{len(entity_history)}个周期的数据,预测增长率{avg_growth_rate * 100:.1f}%",
} }
def _predict_relation_evolution(self, input_data: Dict, model: PredictionModel) -> Dict: def _predict_relation_evolution(self, input_data: dict, model: PredictionModel) -> dict:
"""关系演变预测""" """关系演变预测"""
relation_history = input_data.get("relation_history", []) relation_history = input_data.get("relation_history", [])
@@ -1299,7 +1299,7 @@ class AIManager:
"explanation": f"基于{len(relation_history)}个历史快照分析关系演变趋势", "explanation": f"基于{len(relation_history)}个历史快照分析关系演变趋势",
} }
def get_prediction_results(self, model_id: str, limit: int = 100) -> List[PredictionResult]: def get_prediction_results(self, model_id: str, limit: int = 100) -> list[PredictionResult]:
"""获取预测结果历史""" """获取预测结果历史"""
with self._get_db() as conn: with self._get_db() as conn:
rows = conn.execute( rows = conn.execute(

View File

@@ -4,14 +4,13 @@ InsightFlow API Key Manager - Phase 6
API Key 管理模块:生成、验证、撤销 API Key 管理模块:生成、验证、撤销
""" """
import os
import json
import hashlib import hashlib
import json
import os
import secrets import secrets
import sqlite3 import sqlite3
from datetime import datetime, timedelta
from typing import Optional, List, Dict
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime, timedelta
from enum import Enum from enum import Enum
DB_PATH = os.getenv("DB_PATH", "/app/data/insightflow.db") DB_PATH = os.getenv("DB_PATH", "/app/data/insightflow.db")
@@ -29,15 +28,15 @@ class ApiKey:
key_hash: str # 存储哈希值,不存储原始 key key_hash: str # 存储哈希值,不存储原始 key
key_preview: str # 前8位预览如 "ak_live_abc..." key_preview: str # 前8位预览如 "ak_live_abc..."
name: str # 密钥名称/描述 name: str # 密钥名称/描述
owner_id: Optional[str] # 所有者ID预留多用户支持 owner_id: str | None # 所有者ID预留多用户支持
permissions: List[str] # 权限列表,如 ["read", "write"] permissions: list[str] # 权限列表,如 ["read", "write"]
rate_limit: int # 每分钟请求限制 rate_limit: int # 每分钟请求限制
status: str # active, revoked, expired status: str # active, revoked, expired
created_at: str created_at: str
expires_at: Optional[str] expires_at: str | None
last_used_at: Optional[str] last_used_at: str | None
revoked_at: Optional[str] revoked_at: str | None
revoked_reason: Optional[str] revoked_reason: str | None
total_calls: int = 0 total_calls: int = 0
@@ -131,10 +130,10 @@ class ApiKeyManager:
def create_key( def create_key(
self, self,
name: str, name: str,
owner_id: Optional[str] = None, owner_id: str | None = None,
permissions: List[str] = None, permissions: list[str] = None,
rate_limit: int = 60, rate_limit: int = 60,
expires_days: Optional[int] = None, expires_days: int | None = None,
) -> tuple[str, ApiKey]: ) -> tuple[str, ApiKey]:
""" """
创建新的 API Key 创建新的 API Key
@@ -196,7 +195,7 @@ class ApiKeyManager:
return raw_key, api_key return raw_key, api_key
def validate_key(self, key: str) -> Optional[ApiKey]: def validate_key(self, key: str) -> ApiKey | None:
""" """
验证 API Key 验证 API Key
@@ -231,7 +230,7 @@ class ApiKeyManager:
return api_key return api_key
def revoke_key(self, key_id: str, reason: str = "", owner_id: Optional[str] = None) -> bool: def revoke_key(self, key_id: str, reason: str = "", owner_id: str | None = None) -> bool:
"""撤销 API Key""" """撤销 API Key"""
with sqlite3.connect(self.db_path) as conn: with sqlite3.connect(self.db_path) as conn:
# 验证所有权(如果提供了 owner_id # 验证所有权(如果提供了 owner_id
@@ -251,7 +250,7 @@ class ApiKeyManager:
conn.commit() conn.commit()
return cursor.rowcount > 0 return cursor.rowcount > 0
def get_key_by_id(self, key_id: str, owner_id: Optional[str] = None) -> Optional[ApiKey]: def get_key_by_id(self, key_id: str, owner_id: str | None = None) -> ApiKey | None:
"""通过 ID 获取 API Key不包含敏感信息""" """通过 ID 获取 API Key不包含敏感信息"""
with sqlite3.connect(self.db_path) as conn: with sqlite3.connect(self.db_path) as conn:
conn.row_factory = sqlite3.Row conn.row_factory = sqlite3.Row
@@ -268,8 +267,8 @@ class ApiKeyManager:
return None return None
def list_keys( def list_keys(
self, owner_id: Optional[str] = None, status: Optional[str] = None, limit: int = 100, offset: int = 0 self, owner_id: str | None = None, status: str | None = None, limit: int = 100, offset: int = 0
) -> List[ApiKey]: ) -> list[ApiKey]:
"""列出 API Keys""" """列出 API Keys"""
with sqlite3.connect(self.db_path) as conn: with sqlite3.connect(self.db_path) as conn:
conn.row_factory = sqlite3.Row conn.row_factory = sqlite3.Row
@@ -294,10 +293,10 @@ class ApiKeyManager:
def update_key( def update_key(
self, self,
key_id: str, key_id: str,
name: Optional[str] = None, name: str | None = None,
permissions: Optional[List[str]] = None, permissions: list[str] | None = None,
rate_limit: Optional[int] = None, rate_limit: int | None = None,
owner_id: Optional[str] = None, owner_id: str | None = None,
) -> bool: ) -> bool:
"""更新 API Key 信息""" """更新 API Key 信息"""
updates = [] updates = []
@@ -371,12 +370,12 @@ class ApiKeyManager:
def get_call_logs( def get_call_logs(
self, self,
api_key_id: Optional[str] = None, api_key_id: str | None = None,
start_date: Optional[str] = None, start_date: str | None = None,
end_date: Optional[str] = None, end_date: str | None = None,
limit: int = 100, limit: int = 100,
offset: int = 0, offset: int = 0,
) -> List[Dict]: ) -> list[dict]:
"""获取 API 调用日志""" """获取 API 调用日志"""
with sqlite3.connect(self.db_path) as conn: with sqlite3.connect(self.db_path) as conn:
conn.row_factory = sqlite3.Row conn.row_factory = sqlite3.Row
@@ -402,13 +401,13 @@ class ApiKeyManager:
rows = conn.execute(query, params).fetchall() rows = conn.execute(query, params).fetchall()
return [dict(row) for row in rows] return [dict(row) for row in rows]
def get_call_stats(self, api_key_id: Optional[str] = None, days: int = 30) -> Dict: def get_call_stats(self, api_key_id: str | None = None, days: int = 30) -> dict:
"""获取 API 调用统计""" """获取 API 调用统计"""
with sqlite3.connect(self.db_path) as conn: with sqlite3.connect(self.db_path) as conn:
conn.row_factory = sqlite3.Row conn.row_factory = sqlite3.Row
# 总体统计 # 总体统计
query = """ query = f"""
SELECT SELECT
COUNT(*) as total_calls, COUNT(*) as total_calls,
COUNT(CASE WHEN status_code < 400 THEN 1 END) as success_calls, COUNT(CASE WHEN status_code < 400 THEN 1 END) as success_calls,
@@ -417,8 +416,8 @@ class ApiKeyManager:
MAX(response_time_ms) as max_response_time, MAX(response_time_ms) as max_response_time,
MIN(response_time_ms) as min_response_time MIN(response_time_ms) as min_response_time
FROM api_call_logs FROM api_call_logs
WHERE created_at >= date('now', '-{} days') WHERE created_at >= date('now', '-{days} days')
""".format(days) """
params = [] params = []
if api_key_id: if api_key_id:
@@ -428,15 +427,15 @@ class ApiKeyManager:
row = conn.execute(query, params).fetchone() row = conn.execute(query, params).fetchone()
# 按端点统计 # 按端点统计
endpoint_query = """ endpoint_query = f"""
SELECT SELECT
endpoint, endpoint,
method, method,
COUNT(*) as calls, COUNT(*) as calls,
AVG(response_time_ms) as avg_time AVG(response_time_ms) as avg_time
FROM api_call_logs FROM api_call_logs
WHERE created_at >= date('now', '-{} days') WHERE created_at >= date('now', '-{days} days')
""".format(days) """
endpoint_params = [] endpoint_params = []
if api_key_id: if api_key_id:
@@ -448,14 +447,14 @@ class ApiKeyManager:
endpoint_rows = conn.execute(endpoint_query, endpoint_params).fetchall() endpoint_rows = conn.execute(endpoint_query, endpoint_params).fetchall()
# 按天统计 # 按天统计
daily_query = """ daily_query = f"""
SELECT SELECT
date(created_at) as date, date(created_at) as date,
COUNT(*) as calls, COUNT(*) as calls,
COUNT(CASE WHEN status_code < 400 THEN 1 END) as success COUNT(CASE WHEN status_code < 400 THEN 1 END) as success
FROM api_call_logs FROM api_call_logs
WHERE created_at >= date('now', '-{} days') WHERE created_at >= date('now', '-{days} days')
""".format(days) """
daily_params = [] daily_params = []
if api_key_id: if api_key_id:
@@ -500,7 +499,7 @@ class ApiKeyManager:
# 全局实例 # 全局实例
_api_key_manager: Optional[ApiKeyManager] = None _api_key_manager: ApiKeyManager | None = None
def get_api_key_manager() -> ApiKeyManager: def get_api_key_manager() -> ApiKeyManager:

View File

@@ -3,13 +3,13 @@ InsightFlow - 协作与共享模块 (Phase 7 Task 4)
支持项目分享、评论批注、变更历史、团队空间 支持项目分享、评论批注、变更历史、团队空间
""" """
import hashlib
import json import json
import uuid import uuid
import hashlib
from datetime import datetime, timedelta
from typing import List, Optional, Dict, Any
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime, timedelta
from enum import Enum from enum import Enum
from typing import Any
class SharePermission(Enum): class SharePermission(Enum):
@@ -50,10 +50,10 @@ class ProjectShare:
permission: str # 权限级别 permission: str # 权限级别
created_by: str # 创建者 created_by: str # 创建者
created_at: str created_at: str
expires_at: Optional[str] # 过期时间 expires_at: str | None # 过期时间
max_uses: Optional[int] # 最大使用次数 max_uses: int | None # 最大使用次数
use_count: int # 已使用次数 use_count: int # 已使用次数
password_hash: Optional[str] # 密码保护 password_hash: str | None # 密码保护
is_active: bool # 是否激活 is_active: bool # 是否激活
allow_download: bool # 允许下载 allow_download: bool # 允许下载
allow_export: bool # 允许导出 allow_export: bool # 允许导出
@@ -67,17 +67,17 @@ class Comment:
project_id: str project_id: str
target_type: str # 评论目标类型 target_type: str # 评论目标类型
target_id: str # 目标ID target_id: str # 目标ID
parent_id: Optional[str] # 父评论ID支持回复 parent_id: str | None # 父评论ID支持回复
author: str # 作者 author: str # 作者
author_name: str # 作者显示名 author_name: str # 作者显示名
content: str # 评论内容 content: str # 评论内容
created_at: str created_at: str
updated_at: str updated_at: str
resolved: bool # 是否已解决 resolved: bool # 是否已解决
resolved_by: Optional[str] # 解决者 resolved_by: str | None # 解决者
resolved_at: Optional[str] # 解决时间 resolved_at: str | None # 解决时间
mentions: List[str] # 提及的用户 mentions: list[str] # 提及的用户
attachments: List[Dict] # 附件 attachments: list[dict] # 附件
@dataclass @dataclass
@@ -93,13 +93,13 @@ class ChangeRecord:
changed_by: str # 变更者 changed_by: str # 变更者
changed_by_name: str # 变更者显示名 changed_by_name: str # 变更者显示名
changed_at: str changed_at: str
old_value: Optional[Dict] # 旧值 old_value: dict | None # 旧值
new_value: Optional[Dict] # 新值 new_value: dict | None # 新值
description: str # 变更描述 description: str # 变更描述
session_id: Optional[str] # 会话ID批量变更关联 session_id: str | None # 会话ID批量变更关联
reverted: bool # 是否已回滚 reverted: bool # 是否已回滚
reverted_at: Optional[str] # 回滚时间 reverted_at: str | None # 回滚时间
reverted_by: Optional[str] # 回滚者 reverted_by: str | None # 回滚者
@dataclass @dataclass
@@ -114,8 +114,8 @@ class TeamMember:
role: str # 角色 (owner/admin/editor/viewer) role: str # 角色 (owner/admin/editor/viewer)
joined_at: str joined_at: str
invited_by: str # 邀请者 invited_by: str # 邀请者
last_active_at: Optional[str] # 最后活跃时间 last_active_at: str | None # 最后活跃时间
permissions: List[str] # 具体权限列表 permissions: list[str] # 具体权限列表
@dataclass @dataclass
@@ -130,7 +130,7 @@ class TeamSpace:
updated_at: str updated_at: str
member_count: int member_count: int
project_count: int project_count: int
settings: Dict[str, Any] # 团队设置 settings: dict[str, Any] # 团队设置
class CollaborationManager: class CollaborationManager:
@@ -138,8 +138,8 @@ class CollaborationManager:
def __init__(self, db_manager=None): def __init__(self, db_manager=None):
self.db = db_manager self.db = db_manager
self._shares_cache: Dict[str, ProjectShare] = {} self._shares_cache: dict[str, ProjectShare] = {}
self._comments_cache: Dict[str, List[Comment]] = {} self._comments_cache: dict[str, list[Comment]] = {}
# ============ 项目分享 ============ # ============ 项目分享 ============
@@ -148,9 +148,9 @@ class CollaborationManager:
project_id: str, project_id: str,
created_by: str, created_by: str,
permission: str = "read_only", permission: str = "read_only",
expires_in_days: Optional[int] = None, expires_in_days: int | None = None,
max_uses: Optional[int] = None, max_uses: int | None = None,
password: Optional[str] = None, password: str | None = None,
allow_download: bool = False, allow_download: bool = False,
allow_export: bool = False, allow_export: bool = False,
) -> ProjectShare: ) -> ProjectShare:
@@ -224,7 +224,7 @@ class CollaborationManager:
) )
self.db.conn.commit() self.db.conn.commit()
def validate_share_token(self, token: str, password: Optional[str] = None) -> Optional[ProjectShare]: def validate_share_token(self, token: str, password: str | None = None) -> ProjectShare | None:
"""验证分享令牌""" """验证分享令牌"""
# 从缓存或数据库获取 # 从缓存或数据库获取
share = self._shares_cache.get(token) share = self._shares_cache.get(token)
@@ -256,7 +256,7 @@ class CollaborationManager:
return share return share
def _get_share_from_db(self, token: str) -> Optional[ProjectShare]: def _get_share_from_db(self, token: str) -> ProjectShare | None:
"""从数据库获取分享记录""" """从数据库获取分享记录"""
cursor = self.db.conn.cursor() cursor = self.db.conn.cursor()
cursor.execute( cursor.execute(
@@ -320,7 +320,7 @@ class CollaborationManager:
return cursor.rowcount > 0 return cursor.rowcount > 0
return False return False
def list_project_shares(self, project_id: str) -> List[ProjectShare]: def list_project_shares(self, project_id: str) -> list[ProjectShare]:
"""列出项目的所有分享链接""" """列出项目的所有分享链接"""
if not self.db: if not self.db:
return [] return []
@@ -366,9 +366,9 @@ class CollaborationManager:
author: str, author: str,
author_name: str, author_name: str,
content: str, content: str,
parent_id: Optional[str] = None, parent_id: str | None = None,
mentions: Optional[List[str]] = None, mentions: list[str] | None = None,
attachments: Optional[List[Dict]] = None, attachments: list[dict] | None = None,
) -> Comment: ) -> Comment:
"""添加评论""" """添加评论"""
comment_id = str(uuid.uuid4()) comment_id = str(uuid.uuid4())
@@ -434,7 +434,7 @@ class CollaborationManager:
) )
self.db.conn.commit() self.db.conn.commit()
def get_comments(self, target_type: str, target_id: str, include_resolved: bool = True) -> List[Comment]: def get_comments(self, target_type: str, target_id: str, include_resolved: bool = True) -> list[Comment]:
"""获取评论列表""" """获取评论列表"""
if not self.db: if not self.db:
return [] return []
@@ -484,7 +484,7 @@ class CollaborationManager:
attachments=json.loads(row[14]) if row[14] else [], attachments=json.loads(row[14]) if row[14] else [],
) )
def update_comment(self, comment_id: str, content: str, updated_by: str) -> Optional[Comment]: def update_comment(self, comment_id: str, content: str, updated_by: str) -> Comment | None:
"""更新评论""" """更新评论"""
if not self.db: if not self.db:
return None return None
@@ -505,7 +505,7 @@ class CollaborationManager:
return self._get_comment_by_id(comment_id) return self._get_comment_by_id(comment_id)
return None return None
def _get_comment_by_id(self, comment_id: str) -> Optional[Comment]: def _get_comment_by_id(self, comment_id: str) -> Comment | None:
"""根据ID获取评论""" """根据ID获取评论"""
cursor = self.db.conn.cursor() cursor = self.db.conn.cursor()
cursor.execute("SELECT * FROM comments WHERE id = ?", (comment_id,)) cursor.execute("SELECT * FROM comments WHERE id = ?", (comment_id,))
@@ -551,7 +551,7 @@ class CollaborationManager:
self.db.conn.commit() self.db.conn.commit()
return cursor.rowcount > 0 return cursor.rowcount > 0
def get_project_comments(self, project_id: str, limit: int = 50, offset: int = 0) -> List[Comment]: def get_project_comments(self, project_id: str, limit: int = 50, offset: int = 0) -> list[Comment]:
"""获取项目下的所有评论""" """获取项目下的所有评论"""
if not self.db: if not self.db:
return [] return []
@@ -583,10 +583,10 @@ class CollaborationManager:
entity_name: str, entity_name: str,
changed_by: str, changed_by: str,
changed_by_name: str, changed_by_name: str,
old_value: Optional[Dict] = None, old_value: dict | None = None,
new_value: Optional[Dict] = None, new_value: dict | None = None,
description: str = "", description: str = "",
session_id: Optional[str] = None, session_id: str | None = None,
) -> ChangeRecord: ) -> ChangeRecord:
"""记录变更""" """记录变更"""
record_id = str(uuid.uuid4()) record_id = str(uuid.uuid4())
@@ -651,11 +651,11 @@ class CollaborationManager:
def get_change_history( def get_change_history(
self, self,
project_id: str, project_id: str,
entity_type: Optional[str] = None, entity_type: str | None = None,
entity_id: Optional[str] = None, entity_id: str | None = None,
limit: int = 50, limit: int = 50,
offset: int = 0, offset: int = 0,
) -> List[ChangeRecord]: ) -> list[ChangeRecord]:
"""获取变更历史""" """获取变更历史"""
if not self.db: if not self.db:
return [] return []
@@ -719,7 +719,7 @@ class CollaborationManager:
reverted_by=row[15], reverted_by=row[15],
) )
def get_entity_version_history(self, entity_type: str, entity_id: str) -> List[ChangeRecord]: def get_entity_version_history(self, entity_type: str, entity_id: str) -> list[ChangeRecord]:
"""获取实体的版本历史(用于版本对比)""" """获取实体的版本历史(用于版本对比)"""
if not self.db: if not self.db:
return [] return []
@@ -757,7 +757,7 @@ class CollaborationManager:
self.db.conn.commit() self.db.conn.commit()
return cursor.rowcount > 0 return cursor.rowcount > 0
def get_change_stats(self, project_id: str) -> Dict[str, Any]: def get_change_stats(self, project_id: str) -> dict[str, Any]:
"""获取变更统计""" """获取变更统计"""
if not self.db: if not self.db:
return {} return {}
@@ -823,7 +823,7 @@ class CollaborationManager:
user_email: str, user_email: str,
role: str, role: str,
invited_by: str, invited_by: str,
permissions: Optional[List[str]] = None, permissions: list[str] | None = None,
) -> TeamMember: ) -> TeamMember:
"""添加团队成员""" """添加团队成员"""
member_id = str(uuid.uuid4()) member_id = str(uuid.uuid4())
@@ -851,7 +851,7 @@ class CollaborationManager:
return member return member
def _get_default_permissions(self, role: str) -> List[str]: def _get_default_permissions(self, role: str) -> list[str]:
"""获取角色的默认权限""" """获取角色的默认权限"""
permissions_map = { permissions_map = {
"owner": ["read", "write", "delete", "share", "admin", "export"], "owner": ["read", "write", "delete", "share", "admin", "export"],
@@ -887,7 +887,7 @@ class CollaborationManager:
) )
self.db.conn.commit() self.db.conn.commit()
def get_team_members(self, project_id: str) -> List[TeamMember]: def get_team_members(self, project_id: str) -> list[TeamMember]:
"""获取团队成员列表""" """获取团队成员列表"""
if not self.db: if not self.db:
return [] return []

View File

@@ -5,13 +5,12 @@ InsightFlow Database Manager - Phase 5
支持实体属性扩展 支持实体属性扩展
""" """
import os
import json import json
import os
import sqlite3 import sqlite3
import uuid import uuid
from datetime import datetime
from typing import List, Dict, Optional
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime
DB_PATH = os.getenv("DB_PATH", "/app/data/insightflow.db") DB_PATH = os.getenv("DB_PATH", "/app/data/insightflow.db")
@@ -33,9 +32,9 @@ class Entity:
type: str type: str
definition: str = "" definition: str = ""
canonical_name: str = "" canonical_name: str = ""
aliases: List[str] = None aliases: list[str] = None
embedding: str = "" # Phase 3: 实体嵌入向量 embedding: str = "" # Phase 3: 实体嵌入向量
attributes: Dict = None # Phase 5: 实体属性 attributes: dict = None # Phase 5: 实体属性
created_at: str = "" created_at: str = ""
updated_at: str = "" updated_at: str = ""
@@ -54,7 +53,7 @@ class AttributeTemplate:
project_id: str project_id: str
name: str name: str
type: str # text, number, date, select, multiselect, boolean type: str # text, number, date, select, multiselect, boolean
options: List[str] = None # 用于 select/multiselect options: list[str] = None # 用于 select/multiselect
default_value: str = "" default_value: str = ""
description: str = "" description: str = ""
is_required: bool = False is_required: bool = False
@@ -73,11 +72,11 @@ class EntityAttribute:
id: str id: str
entity_id: str entity_id: str
template_id: Optional[str] = None template_id: str | None = None
name: str = "" # 属性名称 name: str = "" # 属性名称
type: str = "text" # 属性类型 type: str = "text" # 属性类型
value: str = "" value: str = ""
options: List[str] = None # 选项列表 options: list[str] = None # 选项列表
template_name: str = "" # 关联查询时填充 template_name: str = "" # 关联查询时填充
template_type: str = "" # 关联查询时填充 template_type: str = "" # 关联查询时填充
created_at: str = "" created_at: str = ""
@@ -126,7 +125,7 @@ class DatabaseManager:
def init_db(self): def init_db(self):
"""初始化数据库表""" """初始化数据库表"""
with open(os.path.join(os.path.dirname(__file__), "schema.sql"), "r") as f: with open(os.path.join(os.path.dirname(__file__), "schema.sql")) as f:
schema = f.read() schema = f.read()
conn = self.get_conn() conn = self.get_conn()
@@ -147,7 +146,7 @@ class DatabaseManager:
conn.close() conn.close()
return Project(id=project_id, name=name, description=description, created_at=now, updated_at=now) return Project(id=project_id, name=name, description=description, created_at=now, updated_at=now)
def get_project(self, project_id: str) -> Optional[Project]: def get_project(self, project_id: str) -> Project | None:
conn = self.get_conn() conn = self.get_conn()
row = conn.execute("SELECT * FROM projects WHERE id = ?", (project_id,)).fetchone() row = conn.execute("SELECT * FROM projects WHERE id = ?", (project_id,)).fetchone()
conn.close() conn.close()
@@ -155,7 +154,7 @@ class DatabaseManager:
return Project(**dict(row)) return Project(**dict(row))
return None return None
def list_projects(self) -> List[Project]: def list_projects(self) -> list[Project]:
conn = self.get_conn() conn = self.get_conn()
rows = conn.execute("SELECT * FROM projects ORDER BY updated_at DESC").fetchall() rows = conn.execute("SELECT * FROM projects ORDER BY updated_at DESC").fetchall()
conn.close() conn.close()
@@ -184,7 +183,7 @@ class DatabaseManager:
conn.close() conn.close()
return entity return entity
def get_entity_by_name(self, project_id: str, name: str) -> Optional[Entity]: def get_entity_by_name(self, project_id: str, name: str) -> Entity | None:
"""通过名称查找实体(用于对齐)""" """通过名称查找实体(用于对齐)"""
conn = self.get_conn() conn = self.get_conn()
row = conn.execute( row = conn.execute(
@@ -198,7 +197,7 @@ class DatabaseManager:
return Entity(**data) return Entity(**data)
return None return None
def find_similar_entities(self, project_id: str, name: str, threshold: float = 0.8) -> List[Entity]: def find_similar_entities(self, project_id: str, name: str, threshold: float = 0.8) -> list[Entity]:
"""查找相似实体""" """查找相似实体"""
conn = self.get_conn() conn = self.get_conn()
rows = conn.execute( rows = conn.execute(
@@ -245,7 +244,7 @@ class DatabaseManager:
conn.close() conn.close()
return self.get_entity(target_id) return self.get_entity(target_id)
def get_entity(self, entity_id: str) -> Optional[Entity]: def get_entity(self, entity_id: str) -> Entity | None:
conn = self.get_conn() conn = self.get_conn()
row = conn.execute("SELECT * FROM entities WHERE id = ?", (entity_id,)).fetchone() row = conn.execute("SELECT * FROM entities WHERE id = ?", (entity_id,)).fetchone()
conn.close() conn.close()
@@ -255,7 +254,7 @@ class DatabaseManager:
return Entity(**data) return Entity(**data)
return None return None
def list_project_entities(self, project_id: str) -> List[Entity]: def list_project_entities(self, project_id: str) -> list[Entity]:
conn = self.get_conn() conn = self.get_conn()
rows = conn.execute( rows = conn.execute(
"SELECT * FROM entities WHERE project_id = ? ORDER BY updated_at DESC", (project_id,) "SELECT * FROM entities WHERE project_id = ? ORDER BY updated_at DESC", (project_id,)
@@ -333,7 +332,7 @@ class DatabaseManager:
conn.close() conn.close()
return mention return mention
def get_entity_mentions(self, entity_id: str) -> List[EntityMention]: def get_entity_mentions(self, entity_id: str) -> list[EntityMention]:
conn = self.get_conn() conn = self.get_conn()
rows = conn.execute( rows = conn.execute(
"SELECT * FROM entity_mentions WHERE entity_id = ? ORDER BY transcript_id, start_pos", (entity_id,) "SELECT * FROM entity_mentions WHERE entity_id = ? ORDER BY transcript_id, start_pos", (entity_id,)
@@ -355,13 +354,13 @@ class DatabaseManager:
conn.commit() conn.commit()
conn.close() conn.close()
def get_transcript(self, transcript_id: str) -> Optional[dict]: def get_transcript(self, transcript_id: str) -> dict | None:
conn = self.get_conn() conn = self.get_conn()
row = conn.execute("SELECT * FROM transcripts WHERE id = ?", (transcript_id,)).fetchone() row = conn.execute("SELECT * FROM transcripts WHERE id = ?", (transcript_id,)).fetchone()
conn.close() conn.close()
return dict(row) if row else None return dict(row) if row else None
def list_project_transcripts(self, project_id: str) -> List[dict]: def list_project_transcripts(self, project_id: str) -> list[dict]:
conn = self.get_conn() conn = self.get_conn()
rows = conn.execute( rows = conn.execute(
"SELECT * FROM transcripts WHERE project_id = ? ORDER BY created_at DESC", (project_id,) "SELECT * FROM transcripts WHERE project_id = ? ORDER BY created_at DESC", (project_id,)
@@ -404,7 +403,7 @@ class DatabaseManager:
conn.close() conn.close()
return relation_id return relation_id
def get_entity_relations(self, entity_id: str) -> List[dict]: def get_entity_relations(self, entity_id: str) -> list[dict]:
conn = self.get_conn() conn = self.get_conn()
rows = conn.execute( rows = conn.execute(
"""SELECT * FROM entity_relations """SELECT * FROM entity_relations
@@ -415,7 +414,7 @@ class DatabaseManager:
conn.close() conn.close()
return [dict(r) for r in rows] return [dict(r) for r in rows]
def list_project_relations(self, project_id: str) -> List[dict]: def list_project_relations(self, project_id: str) -> list[dict]:
conn = self.get_conn() conn = self.get_conn()
rows = conn.execute( rows = conn.execute(
"SELECT * FROM entity_relations WHERE project_id = ? ORDER BY created_at DESC", (project_id,) "SELECT * FROM entity_relations WHERE project_id = ? ORDER BY created_at DESC", (project_id,)
@@ -473,7 +472,7 @@ class DatabaseManager:
conn.close() conn.close()
return term_id return term_id
def list_glossary(self, project_id: str) -> List[dict]: def list_glossary(self, project_id: str) -> list[dict]:
conn = self.get_conn() conn = self.get_conn()
rows = conn.execute( rows = conn.execute(
"SELECT * FROM glossary WHERE project_id = ? ORDER BY frequency DESC", (project_id,) "SELECT * FROM glossary WHERE project_id = ? ORDER BY frequency DESC", (project_id,)
@@ -489,7 +488,7 @@ class DatabaseManager:
# ==================== Phase 4: Agent & Provenance ==================== # ==================== Phase 4: Agent & Provenance ====================
def get_relation_with_details(self, relation_id: str) -> Optional[dict]: def get_relation_with_details(self, relation_id: str) -> dict | None:
conn = self.get_conn() conn = self.get_conn()
row = conn.execute( row = conn.execute(
"""SELECT r.*, """SELECT r.*,
@@ -505,7 +504,7 @@ class DatabaseManager:
conn.close() conn.close()
return dict(row) if row else None return dict(row) if row else None
def get_entity_with_mentions(self, entity_id: str) -> Optional[dict]: def get_entity_with_mentions(self, entity_id: str) -> dict | None:
conn = self.get_conn() conn = self.get_conn()
entity_row = conn.execute("SELECT * FROM entities WHERE id = ?", (entity_id,)).fetchone() entity_row = conn.execute("SELECT * FROM entities WHERE id = ?", (entity_id,)).fetchone()
if not entity_row: if not entity_row:
@@ -539,7 +538,7 @@ class DatabaseManager:
conn.close() conn.close()
return entity return entity
def search_entities(self, project_id: str, query: str) -> List[Entity]: def search_entities(self, project_id: str, query: str) -> list[Entity]:
conn = self.get_conn() conn = self.get_conn()
rows = conn.execute( rows = conn.execute(
"""SELECT * FROM entities """SELECT * FROM entities
@@ -616,7 +615,7 @@ class DatabaseManager:
def get_project_timeline( def get_project_timeline(
self, project_id: str, entity_id: str = None, start_date: str = None, end_date: str = None self, project_id: str, entity_id: str = None, start_date: str = None, end_date: str = None
) -> List[dict]: ) -> list[dict]:
conn = self.get_conn() conn = self.get_conn()
conditions = ["t.project_id = ?"] conditions = ["t.project_id = ?"]
@@ -722,7 +721,7 @@ class DatabaseManager:
conn.close() conn.close()
return template return template
def get_attribute_template(self, template_id: str) -> Optional[AttributeTemplate]: def get_attribute_template(self, template_id: str) -> AttributeTemplate | None:
conn = self.get_conn() conn = self.get_conn()
row = conn.execute("SELECT * FROM attribute_templates WHERE id = ?", (template_id,)).fetchone() row = conn.execute("SELECT * FROM attribute_templates WHERE id = ?", (template_id,)).fetchone()
conn.close() conn.close()
@@ -732,7 +731,7 @@ class DatabaseManager:
return AttributeTemplate(**data) return AttributeTemplate(**data)
return None return None
def list_attribute_templates(self, project_id: str) -> List[AttributeTemplate]: def list_attribute_templates(self, project_id: str) -> list[AttributeTemplate]:
conn = self.get_conn() conn = self.get_conn()
rows = conn.execute( rows = conn.execute(
"""SELECT * FROM attribute_templates WHERE project_id = ? """SELECT * FROM attribute_templates WHERE project_id = ?
@@ -748,7 +747,7 @@ class DatabaseManager:
templates.append(AttributeTemplate(**data)) templates.append(AttributeTemplate(**data))
return templates return templates
def update_attribute_template(self, template_id: str, **kwargs) -> Optional[AttributeTemplate]: def update_attribute_template(self, template_id: str, **kwargs) -> AttributeTemplate | None:
conn = self.get_conn() conn = self.get_conn()
allowed_fields = ["name", "type", "options", "default_value", "description", "is_required", "sort_order"] allowed_fields = ["name", "type", "options", "default_value", "description", "is_required", "sort_order"]
updates = [] updates = []
@@ -834,7 +833,7 @@ class DatabaseManager:
conn.close() conn.close()
return attr return attr
def get_entity_attributes(self, entity_id: str) -> List[EntityAttribute]: def get_entity_attributes(self, entity_id: str) -> list[EntityAttribute]:
conn = self.get_conn() conn = self.get_conn()
rows = conn.execute( rows = conn.execute(
"""SELECT ea.*, at.name as template_name, at.type as template_type """SELECT ea.*, at.name as template_name, at.type as template_type
@@ -846,7 +845,7 @@ class DatabaseManager:
conn.close() conn.close()
return [EntityAttribute(**dict(r)) for r in rows] return [EntityAttribute(**dict(r)) for r in rows]
def get_entity_with_attributes(self, entity_id: str) -> Optional[Entity]: def get_entity_with_attributes(self, entity_id: str) -> Entity | None:
entity = self.get_entity(entity_id) entity = self.get_entity(entity_id)
if not entity: if not entity:
return None return None
@@ -889,7 +888,7 @@ class DatabaseManager:
def get_attribute_history( def get_attribute_history(
self, entity_id: str = None, template_id: str = None, limit: int = 50 self, entity_id: str = None, template_id: str = None, limit: int = 50
) -> List[AttributeHistory]: ) -> list[AttributeHistory]:
conn = self.get_conn() conn = self.get_conn()
conditions = [] conditions = []
params = [] params = []
@@ -913,7 +912,7 @@ class DatabaseManager:
conn.close() conn.close()
return [AttributeHistory(**dict(r)) for r in rows] return [AttributeHistory(**dict(r)) for r in rows]
def search_entities_by_attributes(self, project_id: str, attribute_filters: Dict[str, str]) -> List[Entity]: def search_entities_by_attributes(self, project_id: str, attribute_filters: dict[str, str]) -> list[Entity]:
entities = self.list_project_entities(project_id) entities = self.list_project_entities(project_id)
if not attribute_filters: if not attribute_filters:
return entities return entities
@@ -962,11 +961,11 @@ class DatabaseManager:
filename: str, filename: str,
duration: float = 0, duration: float = 0,
fps: float = 0, fps: float = 0,
resolution: Dict = None, resolution: dict = None,
audio_transcript_id: str = None, audio_transcript_id: str = None,
full_ocr_text: str = "", full_ocr_text: str = "",
extracted_entities: List[Dict] = None, extracted_entities: list[dict] = None,
extracted_relations: List[Dict] = None, extracted_relations: list[dict] = None,
) -> str: ) -> str:
"""创建视频记录""" """创建视频记录"""
conn = self.get_conn() conn = self.get_conn()
@@ -998,7 +997,7 @@ class DatabaseManager:
conn.close() conn.close()
return video_id return video_id
def get_video(self, video_id: str) -> Optional[Dict]: def get_video(self, video_id: str) -> dict | None:
"""获取视频信息""" """获取视频信息"""
conn = self.get_conn() conn = self.get_conn()
row = conn.execute("SELECT * FROM videos WHERE id = ?", (video_id,)).fetchone() row = conn.execute("SELECT * FROM videos WHERE id = ?", (video_id,)).fetchone()
@@ -1012,7 +1011,7 @@ class DatabaseManager:
return data return data
return None return None
def list_project_videos(self, project_id: str) -> List[Dict]: def list_project_videos(self, project_id: str) -> list[dict]:
"""获取项目的所有视频""" """获取项目的所有视频"""
conn = self.get_conn() conn = self.get_conn()
rows = conn.execute( rows = conn.execute(
@@ -1037,7 +1036,7 @@ class DatabaseManager:
timestamp: float, timestamp: float,
image_url: str = None, image_url: str = None,
ocr_text: str = None, ocr_text: str = None,
extracted_entities: List[Dict] = None, extracted_entities: list[dict] = None,
) -> str: ) -> str:
"""创建视频帧记录""" """创建视频帧记录"""
conn = self.get_conn() conn = self.get_conn()
@@ -1062,7 +1061,7 @@ class DatabaseManager:
conn.close() conn.close()
return frame_id return frame_id
def get_video_frames(self, video_id: str) -> List[Dict]: def get_video_frames(self, video_id: str) -> list[dict]:
"""获取视频的所有帧""" """获取视频的所有帧"""
conn = self.get_conn() conn = self.get_conn()
rows = conn.execute( rows = conn.execute(
@@ -1084,8 +1083,8 @@ class DatabaseManager:
filename: str, filename: str,
ocr_text: str = "", ocr_text: str = "",
description: str = "", description: str = "",
extracted_entities: List[Dict] = None, extracted_entities: list[dict] = None,
extracted_relations: List[Dict] = None, extracted_relations: list[dict] = None,
) -> str: ) -> str:
"""创建图片记录""" """创建图片记录"""
conn = self.get_conn() conn = self.get_conn()
@@ -1113,7 +1112,7 @@ class DatabaseManager:
conn.close() conn.close()
return image_id return image_id
def get_image(self, image_id: str) -> Optional[Dict]: def get_image(self, image_id: str) -> dict | None:
"""获取图片信息""" """获取图片信息"""
conn = self.get_conn() conn = self.get_conn()
row = conn.execute("SELECT * FROM images WHERE id = ?", (image_id,)).fetchone() row = conn.execute("SELECT * FROM images WHERE id = ?", (image_id,)).fetchone()
@@ -1126,7 +1125,7 @@ class DatabaseManager:
return data return data
return None return None
def list_project_images(self, project_id: str) -> List[Dict]: def list_project_images(self, project_id: str) -> list[dict]:
"""获取项目的所有图片""" """获取项目的所有图片"""
conn = self.get_conn() conn = self.get_conn()
rows = conn.execute( rows = conn.execute(
@@ -1168,7 +1167,7 @@ class DatabaseManager:
conn.close() conn.close()
return mention_id return mention_id
def get_entity_multimodal_mentions(self, entity_id: str) -> List[Dict]: def get_entity_multimodal_mentions(self, entity_id: str) -> list[dict]:
"""获取实体的多模态提及""" """获取实体的多模态提及"""
conn = self.get_conn() conn = self.get_conn()
rows = conn.execute( rows = conn.execute(
@@ -1181,7 +1180,7 @@ class DatabaseManager:
conn.close() conn.close()
return [dict(r) for r in rows] return [dict(r) for r in rows]
def get_project_multimodal_mentions(self, project_id: str, modality: str = None) -> List[Dict]: def get_project_multimodal_mentions(self, project_id: str, modality: str = None) -> list[dict]:
"""获取项目的多模态提及""" """获取项目的多模态提及"""
conn = self.get_conn() conn = self.get_conn()
@@ -1214,7 +1213,7 @@ class DatabaseManager:
link_type: str, link_type: str,
confidence: float = 1.0, confidence: float = 1.0,
evidence: str = "", evidence: str = "",
modalities: List[str] = None, modalities: list[str] = None,
) -> str: ) -> str:
"""创建多模态实体关联""" """创建多模态实体关联"""
conn = self.get_conn() conn = self.get_conn()
@@ -1231,7 +1230,7 @@ class DatabaseManager:
conn.close() conn.close()
return link_id return link_id
def get_entity_multimodal_links(self, entity_id: str) -> List[Dict]: def get_entity_multimodal_links(self, entity_id: str) -> list[dict]:
"""获取实体的多模态关联""" """获取实体的多模态关联"""
conn = self.get_conn() conn = self.get_conn()
rows = conn.execute( rows = conn.execute(
@@ -1251,7 +1250,7 @@ class DatabaseManager:
links.append(data) links.append(data)
return links return links
def get_project_multimodal_stats(self, project_id: str) -> Dict: def get_project_multimodal_stats(self, project_id: str) -> dict:
"""获取项目多模态统计信息""" """获取项目多模态统计信息"""
conn = self.get_conn() conn = self.get_conn()

View File

@@ -10,20 +10,19 @@ InsightFlow Developer Ecosystem Manager - Phase 8 Task 6
作者: InsightFlow Team 作者: InsightFlow Team
""" """
import os
import json import json
import os
import sqlite3 import sqlite3
import uuid import uuid
from typing import List, Dict, Optional
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime from datetime import datetime
from enum import Enum from enum import StrEnum
# Database path # Database path
DB_PATH = os.path.join(os.path.dirname(__file__), "insightflow.db") DB_PATH = os.path.join(os.path.dirname(__file__), "insightflow.db")
class SDKLanguage(str, Enum): class SDKLanguage(StrEnum):
"""SDK 语言类型""" """SDK 语言类型"""
PYTHON = "python" PYTHON = "python"
@@ -34,7 +33,7 @@ class SDKLanguage(str, Enum):
RUST = "rust" RUST = "rust"
class SDKStatus(str, Enum): class SDKStatus(StrEnum):
"""SDK 状态""" """SDK 状态"""
DRAFT = "draft" # 草稿 DRAFT = "draft" # 草稿
@@ -44,7 +43,7 @@ class SDKStatus(str, Enum):
ARCHIVED = "archived" # 已归档 ARCHIVED = "archived" # 已归档
class TemplateCategory(str, Enum): class TemplateCategory(StrEnum):
"""模板分类""" """模板分类"""
MEDICAL = "medical" # 医疗 MEDICAL = "medical" # 医疗
@@ -55,7 +54,7 @@ class TemplateCategory(str, Enum):
GENERAL = "general" # 通用 GENERAL = "general" # 通用
class TemplateStatus(str, Enum): class TemplateStatus(StrEnum):
"""模板状态""" """模板状态"""
PENDING = "pending" # 待审核 PENDING = "pending" # 待审核
@@ -65,7 +64,7 @@ class TemplateStatus(str, Enum):
UNLISTED = "unlisted" # 未列出 UNLISTED = "unlisted" # 未列出
class PluginStatus(str, Enum): class PluginStatus(StrEnum):
"""插件状态""" """插件状态"""
PENDING = "pending" # 待审核 PENDING = "pending" # 待审核
@@ -76,7 +75,7 @@ class PluginStatus(str, Enum):
SUSPENDED = "suspended" # 已暂停 SUSPENDED = "suspended" # 已暂停
class PluginCategory(str, Enum): class PluginCategory(StrEnum):
"""插件分类""" """插件分类"""
INTEGRATION = "integration" # 集成 INTEGRATION = "integration" # 集成
@@ -87,7 +86,7 @@ class PluginCategory(str, Enum):
CUSTOM = "custom" # 自定义 CUSTOM = "custom" # 自定义
class DeveloperStatus(str, Enum): class DeveloperStatus(StrEnum):
"""开发者认证状态""" """开发者认证状态"""
UNVERIFIED = "unverified" # 未认证 UNVERIFIED = "unverified" # 未认证
@@ -113,13 +112,13 @@ class SDKRelease:
package_name: str # pip/npm/go module name package_name: str # pip/npm/go module name
status: SDKStatus status: SDKStatus
min_platform_version: str min_platform_version: str
dependencies: List[Dict] # [{"name": "requests", "version": ">=2.0"}] dependencies: list[dict] # [{"name": "requests", "version": ">=2.0"}]
file_size: int file_size: int
checksum: str checksum: str
download_count: int download_count: int
created_at: str created_at: str
updated_at: str updated_at: str
published_at: Optional[str] published_at: str | None
created_by: str created_by: str
@@ -148,17 +147,17 @@ class TemplateMarketItem:
name: str name: str
description: str description: str
category: TemplateCategory category: TemplateCategory
subcategory: Optional[str] subcategory: str | None
tags: List[str] tags: list[str]
author_id: str author_id: str
author_name: str author_name: str
status: TemplateStatus status: TemplateStatus
price: float # 0 = 免费 price: float # 0 = 免费
currency: str currency: str
preview_image_url: Optional[str] preview_image_url: str | None
demo_url: Optional[str] demo_url: str | None
documentation_url: Optional[str] documentation_url: str | None
download_url: Optional[str] download_url: str | None
install_count: int install_count: int
rating: float rating: float
rating_count: int rating_count: int
@@ -169,7 +168,7 @@ class TemplateMarketItem:
checksum: str checksum: str
created_at: str created_at: str
updated_at: str updated_at: str
published_at: Optional[str] published_at: str | None
@dataclass @dataclass
@@ -196,20 +195,20 @@ class PluginMarketItem:
name: str name: str
description: str description: str
category: PluginCategory category: PluginCategory
tags: List[str] tags: list[str]
author_id: str author_id: str
author_name: str author_name: str
status: PluginStatus status: PluginStatus
price: float price: float
currency: str currency: str
pricing_model: str # free, paid, freemium, subscription pricing_model: str # free, paid, freemium, subscription
preview_image_url: Optional[str] preview_image_url: str | None
demo_url: Optional[str] demo_url: str | None
documentation_url: Optional[str] documentation_url: str | None
repository_url: Optional[str] repository_url: str | None
download_url: Optional[str] download_url: str | None
webhook_url: Optional[str] # 用于插件回调 webhook_url: str | None # 用于插件回调
permissions: List[str] # 需要的权限列表 permissions: list[str] # 需要的权限列表
install_count: int install_count: int
active_install_count: int active_install_count: int
rating: float rating: float
@@ -221,10 +220,10 @@ class PluginMarketItem:
checksum: str checksum: str
created_at: str created_at: str
updated_at: str updated_at: str
published_at: Optional[str] published_at: str | None
reviewed_by: Optional[str] reviewed_by: str | None
reviewed_at: Optional[str] reviewed_at: str | None
review_notes: Optional[str] review_notes: str | None
@dataclass @dataclass
@@ -251,12 +250,12 @@ class DeveloperProfile:
user_id: str user_id: str
display_name: str display_name: str
email: str email: str
bio: Optional[str] bio: str | None
website: Optional[str] website: str | None
github_url: Optional[str] github_url: str | None
avatar_url: Optional[str] avatar_url: str | None
status: DeveloperStatus status: DeveloperStatus
verification_documents: Dict # 认证文档 verification_documents: dict # 认证文档
total_sales: float total_sales: float
total_downloads: int total_downloads: int
plugin_count: int plugin_count: int
@@ -264,7 +263,7 @@ class DeveloperProfile:
rating_average: float rating_average: float
created_at: str created_at: str
updated_at: str updated_at: str
verified_at: Optional[str] verified_at: str | None
@dataclass @dataclass
@@ -296,11 +295,11 @@ class CodeExample:
category: str category: str
code: str code: str
explanation: str explanation: str
tags: List[str] tags: list[str]
author_id: str author_id: str
author_name: str author_name: str
sdk_id: Optional[str] # 关联的 SDK sdk_id: str | None # 关联的 SDK
api_endpoints: List[str] # 涉及的 API 端点 api_endpoints: list[str] # 涉及的 API 端点
view_count: int view_count: int
copy_count: int copy_count: int
rating: float rating: float
@@ -330,16 +329,16 @@ class DeveloperPortalConfig:
name: str name: str
description: str description: str
theme: str theme: str
custom_css: Optional[str] custom_css: str | None
custom_js: Optional[str] custom_js: str | None
logo_url: Optional[str] logo_url: str | None
favicon_url: Optional[str] favicon_url: str | None
primary_color: str primary_color: str
secondary_color: str secondary_color: str
support_email: str support_email: str
support_url: Optional[str] support_url: str | None
github_url: Optional[str] github_url: str | None
discord_url: Optional[str] discord_url: str | None
api_base_url: str api_base_url: str
is_active: bool is_active: bool
created_at: str created_at: str
@@ -373,7 +372,7 @@ class DeveloperEcosystemManager:
repository_url: str, repository_url: str,
package_name: str, package_name: str,
min_platform_version: str, min_platform_version: str,
dependencies: List[Dict], dependencies: list[dict],
file_size: int, file_size: int,
checksum: str, checksum: str,
created_by: str, created_by: str,
@@ -442,7 +441,7 @@ class DeveloperEcosystemManager:
return sdk return sdk
def get_sdk_release(self, sdk_id: str) -> Optional[SDKRelease]: def get_sdk_release(self, sdk_id: str) -> SDKRelease | None:
"""获取 SDK 发布详情""" """获取 SDK 发布详情"""
with self._get_db() as conn: with self._get_db() as conn:
row = conn.execute("SELECT * FROM sdk_releases WHERE id = ?", (sdk_id,)).fetchone() row = conn.execute("SELECT * FROM sdk_releases WHERE id = ?", (sdk_id,)).fetchone()
@@ -452,8 +451,8 @@ class DeveloperEcosystemManager:
return None return None
def list_sdk_releases( def list_sdk_releases(
self, language: Optional[SDKLanguage] = None, status: Optional[SDKStatus] = None, search: Optional[str] = None self, language: SDKLanguage | None = None, status: SDKStatus | None = None, search: str | None = None
) -> List[SDKRelease]: ) -> list[SDKRelease]:
"""列出 SDK 发布""" """列出 SDK 发布"""
query = "SELECT * FROM sdk_releases WHERE 1=1" query = "SELECT * FROM sdk_releases WHERE 1=1"
params = [] params = []
@@ -474,7 +473,7 @@ class DeveloperEcosystemManager:
rows = conn.execute(query, params).fetchall() rows = conn.execute(query, params).fetchall()
return [self._row_to_sdk_release(row) for row in rows] return [self._row_to_sdk_release(row) for row in rows]
def update_sdk_release(self, sdk_id: str, **kwargs) -> Optional[SDKRelease]: def update_sdk_release(self, sdk_id: str, **kwargs) -> SDKRelease | None:
"""更新 SDK 发布""" """更新 SDK 发布"""
allowed_fields = [ allowed_fields = [
"name", "name",
@@ -499,7 +498,7 @@ class DeveloperEcosystemManager:
return self.get_sdk_release(sdk_id) return self.get_sdk_release(sdk_id)
def publish_sdk_release(self, sdk_id: str) -> Optional[SDKRelease]: def publish_sdk_release(self, sdk_id: str) -> SDKRelease | None:
"""发布 SDK""" """发布 SDK"""
now = datetime.now().isoformat() now = datetime.now().isoformat()
@@ -529,7 +528,7 @@ class DeveloperEcosystemManager:
) )
conn.commit() conn.commit()
def get_sdk_versions(self, sdk_id: str) -> List[SDKVersion]: def get_sdk_versions(self, sdk_id: str) -> list[SDKVersion]:
"""获取 SDK 版本历史""" """获取 SDK 版本历史"""
with self._get_db() as conn: with self._get_db() as conn:
rows = conn.execute( rows = conn.execute(
@@ -588,16 +587,16 @@ class DeveloperEcosystemManager:
name: str, name: str,
description: str, description: str,
category: TemplateCategory, category: TemplateCategory,
subcategory: Optional[str], subcategory: str | None,
tags: List[str], tags: list[str],
author_id: str, author_id: str,
author_name: str, author_name: str,
price: float = 0.0, price: float = 0.0,
currency: str = "CNY", currency: str = "CNY",
preview_image_url: Optional[str] = None, preview_image_url: str | None = None,
demo_url: Optional[str] = None, demo_url: str | None = None,
documentation_url: Optional[str] = None, documentation_url: str | None = None,
download_url: Optional[str] = None, download_url: str | None = None,
version: str = "1.0.0", version: str = "1.0.0",
min_platform_version: str = "1.0.0", min_platform_version: str = "1.0.0",
file_size: int = 0, file_size: int = 0,
@@ -679,7 +678,7 @@ class DeveloperEcosystemManager:
return template return template
def get_template(self, template_id: str) -> Optional[TemplateMarketItem]: def get_template(self, template_id: str) -> TemplateMarketItem | None:
"""获取模板详情""" """获取模板详情"""
with self._get_db() as conn: with self._get_db() as conn:
row = conn.execute("SELECT * FROM template_market WHERE id = ?", (template_id,)).fetchone() row = conn.execute("SELECT * FROM template_market WHERE id = ?", (template_id,)).fetchone()
@@ -690,14 +689,14 @@ class DeveloperEcosystemManager:
def list_templates( def list_templates(
self, self,
category: Optional[TemplateCategory] = None, category: TemplateCategory | None = None,
status: Optional[TemplateStatus] = None, status: TemplateStatus | None = None,
search: Optional[str] = None, search: str | None = None,
author_id: Optional[str] = None, author_id: str | None = None,
min_price: Optional[float] = None, min_price: float | None = None,
max_price: Optional[float] = None, max_price: float | None = None,
sort_by: str = "created_at", sort_by: str = "created_at",
) -> List[TemplateMarketItem]: ) -> list[TemplateMarketItem]:
"""列出模板""" """列出模板"""
query = "SELECT * FROM template_market WHERE 1=1" query = "SELECT * FROM template_market WHERE 1=1"
params = [] params = []
@@ -735,7 +734,7 @@ class DeveloperEcosystemManager:
rows = conn.execute(query, params).fetchall() rows = conn.execute(query, params).fetchall()
return [self._row_to_template(row) for row in rows] return [self._row_to_template(row) for row in rows]
def approve_template(self, template_id: str, reviewed_by: str) -> Optional[TemplateMarketItem]: def approve_template(self, template_id: str, reviewed_by: str) -> TemplateMarketItem | None:
"""审核通过模板""" """审核通过模板"""
now = datetime.now().isoformat() now = datetime.now().isoformat()
@@ -752,7 +751,7 @@ class DeveloperEcosystemManager:
return self.get_template(template_id) return self.get_template(template_id)
def publish_template(self, template_id: str) -> Optional[TemplateMarketItem]: def publish_template(self, template_id: str) -> TemplateMarketItem | None:
"""发布模板""" """发布模板"""
now = datetime.now().isoformat() now = datetime.now().isoformat()
@@ -769,7 +768,7 @@ class DeveloperEcosystemManager:
return self.get_template(template_id) return self.get_template(template_id)
def reject_template(self, template_id: str, reason: str) -> Optional[TemplateMarketItem]: def reject_template(self, template_id: str, reason: str) -> TemplateMarketItem | None:
"""拒绝模板""" """拒绝模板"""
now = datetime.now().isoformat() now = datetime.now().isoformat()
@@ -874,7 +873,7 @@ class DeveloperEcosystemManager:
(round(row["avg_rating"], 2) if row["avg_rating"] else 0, row["count"], row["count"], template_id), (round(row["avg_rating"], 2) if row["avg_rating"] else 0, row["count"], row["count"], template_id),
) )
def get_template_reviews(self, template_id: str, limit: int = 50) -> List[TemplateReview]: def get_template_reviews(self, template_id: str, limit: int = 50) -> list[TemplateReview]:
"""获取模板评价""" """获取模板评价"""
with self._get_db() as conn: with self._get_db() as conn:
rows = conn.execute( rows = conn.execute(
@@ -893,19 +892,19 @@ class DeveloperEcosystemManager:
name: str, name: str,
description: str, description: str,
category: PluginCategory, category: PluginCategory,
tags: List[str], tags: list[str],
author_id: str, author_id: str,
author_name: str, author_name: str,
price: float = 0.0, price: float = 0.0,
currency: str = "CNY", currency: str = "CNY",
pricing_model: str = "free", pricing_model: str = "free",
preview_image_url: Optional[str] = None, preview_image_url: str | None = None,
demo_url: Optional[str] = None, demo_url: str | None = None,
documentation_url: Optional[str] = None, documentation_url: str | None = None,
repository_url: Optional[str] = None, repository_url: str | None = None,
download_url: Optional[str] = None, download_url: str | None = None,
webhook_url: Optional[str] = None, webhook_url: str | None = None,
permissions: List[str] = None, permissions: list[str] = None,
version: str = "1.0.0", version: str = "1.0.0",
min_platform_version: str = "1.0.0", min_platform_version: str = "1.0.0",
file_size: int = 0, file_size: int = 0,
@@ -1003,7 +1002,7 @@ class DeveloperEcosystemManager:
return plugin return plugin
def get_plugin(self, plugin_id: str) -> Optional[PluginMarketItem]: def get_plugin(self, plugin_id: str) -> PluginMarketItem | None:
"""获取插件详情""" """获取插件详情"""
with self._get_db() as conn: with self._get_db() as conn:
row = conn.execute("SELECT * FROM plugin_market WHERE id = ?", (plugin_id,)).fetchone() row = conn.execute("SELECT * FROM plugin_market WHERE id = ?", (plugin_id,)).fetchone()
@@ -1014,12 +1013,12 @@ class DeveloperEcosystemManager:
def list_plugins( def list_plugins(
self, self,
category: Optional[PluginCategory] = None, category: PluginCategory | None = None,
status: Optional[PluginStatus] = None, status: PluginStatus | None = None,
search: Optional[str] = None, search: str | None = None,
author_id: Optional[str] = None, author_id: str | None = None,
sort_by: str = "created_at", sort_by: str = "created_at",
) -> List[PluginMarketItem]: ) -> list[PluginMarketItem]:
"""列出插件""" """列出插件"""
query = "SELECT * FROM plugin_market WHERE 1=1" query = "SELECT * FROM plugin_market WHERE 1=1"
params = [] params = []
@@ -1051,7 +1050,7 @@ class DeveloperEcosystemManager:
def review_plugin( def review_plugin(
self, plugin_id: str, reviewed_by: str, status: PluginStatus, notes: str = "" self, plugin_id: str, reviewed_by: str, status: PluginStatus, notes: str = ""
) -> Optional[PluginMarketItem]: ) -> PluginMarketItem | None:
"""审核插件""" """审核插件"""
now = datetime.now().isoformat() now = datetime.now().isoformat()
@@ -1068,7 +1067,7 @@ class DeveloperEcosystemManager:
return self.get_plugin(plugin_id) return self.get_plugin(plugin_id)
def publish_plugin(self, plugin_id: str) -> Optional[PluginMarketItem]: def publish_plugin(self, plugin_id: str) -> PluginMarketItem | None:
"""发布插件""" """发布插件"""
now = datetime.now().isoformat() now = datetime.now().isoformat()
@@ -1182,7 +1181,7 @@ class DeveloperEcosystemManager:
(round(row["avg_rating"], 2) if row["avg_rating"] else 0, row["count"], row["count"], plugin_id), (round(row["avg_rating"], 2) if row["avg_rating"] else 0, row["count"], row["count"], plugin_id),
) )
def get_plugin_reviews(self, plugin_id: str, limit: int = 50) -> List[PluginReview]: def get_plugin_reviews(self, plugin_id: str, limit: int = 50) -> list[PluginReview]:
"""获取插件评价""" """获取插件评价"""
with self._get_db() as conn: with self._get_db() as conn:
rows = conn.execute( rows = conn.execute(
@@ -1268,8 +1267,8 @@ class DeveloperEcosystemManager:
return revenue return revenue
def get_developer_revenues( def get_developer_revenues(
self, developer_id: str, start_date: Optional[datetime] = None, end_date: Optional[datetime] = None self, developer_id: str, start_date: datetime | None = None, end_date: datetime | None = None
) -> List[DeveloperRevenue]: ) -> list[DeveloperRevenue]:
"""获取开发者收益记录""" """获取开发者收益记录"""
query = "SELECT * FROM developer_revenues WHERE developer_id = ?" query = "SELECT * FROM developer_revenues WHERE developer_id = ?"
params = [developer_id] params = [developer_id]
@@ -1287,7 +1286,7 @@ class DeveloperEcosystemManager:
rows = conn.execute(query, params).fetchall() rows = conn.execute(query, params).fetchall()
return [self._row_to_developer_revenue(row) for row in rows] return [self._row_to_developer_revenue(row) for row in rows]
def get_developer_revenue_summary(self, developer_id: str) -> Dict: def get_developer_revenue_summary(self, developer_id: str) -> dict:
"""获取开发者收益汇总""" """获取开发者收益汇总"""
with self._get_db() as conn: with self._get_db() as conn:
row = conn.execute( row = conn.execute(
@@ -1318,10 +1317,10 @@ class DeveloperEcosystemManager:
user_id: str, user_id: str,
display_name: str, display_name: str,
email: str, email: str,
bio: Optional[str] = None, bio: str | None = None,
website: Optional[str] = None, website: str | None = None,
github_url: Optional[str] = None, github_url: str | None = None,
avatar_url: Optional[str] = None, avatar_url: str | None = None,
) -> DeveloperProfile: ) -> DeveloperProfile:
"""创建开发者档案""" """创建开发者档案"""
profile_id = f"dev_{uuid.uuid4().hex[:16]}" profile_id = f"dev_{uuid.uuid4().hex[:16]}"
@@ -1382,7 +1381,7 @@ class DeveloperEcosystemManager:
return profile return profile
def get_developer_profile(self, developer_id: str) -> Optional[DeveloperProfile]: def get_developer_profile(self, developer_id: str) -> DeveloperProfile | None:
"""获取开发者档案""" """获取开发者档案"""
with self._get_db() as conn: with self._get_db() as conn:
row = conn.execute("SELECT * FROM developer_profiles WHERE id = ?", (developer_id,)).fetchone() row = conn.execute("SELECT * FROM developer_profiles WHERE id = ?", (developer_id,)).fetchone()
@@ -1391,7 +1390,7 @@ class DeveloperEcosystemManager:
return self._row_to_developer_profile(row) return self._row_to_developer_profile(row)
return None return None
def get_developer_profile_by_user(self, user_id: str) -> Optional[DeveloperProfile]: def get_developer_profile_by_user(self, user_id: str) -> DeveloperProfile | None:
"""通过用户 ID 获取开发者档案""" """通过用户 ID 获取开发者档案"""
with self._get_db() as conn: with self._get_db() as conn:
row = conn.execute("SELECT * FROM developer_profiles WHERE user_id = ?", (user_id,)).fetchone() row = conn.execute("SELECT * FROM developer_profiles WHERE user_id = ?", (user_id,)).fetchone()
@@ -1400,7 +1399,7 @@ class DeveloperEcosystemManager:
return self._row_to_developer_profile(row) return self._row_to_developer_profile(row)
return None return None
def verify_developer(self, developer_id: str, status: DeveloperStatus) -> Optional[DeveloperProfile]: def verify_developer(self, developer_id: str, status: DeveloperStatus) -> DeveloperProfile | None:
"""验证开发者""" """验证开发者"""
now = datetime.now().isoformat() now = datetime.now().isoformat()
@@ -1473,11 +1472,11 @@ class DeveloperEcosystemManager:
category: str, category: str,
code: str, code: str,
explanation: str, explanation: str,
tags: List[str], tags: list[str],
author_id: str, author_id: str,
author_name: str, author_name: str,
sdk_id: Optional[str] = None, sdk_id: str | None = None,
api_endpoints: List[str] = None, api_endpoints: list[str] = None,
) -> CodeExample: ) -> CodeExample:
"""创建代码示例""" """创建代码示例"""
example_id = f"ex_{uuid.uuid4().hex[:16]}" example_id = f"ex_{uuid.uuid4().hex[:16]}"
@@ -1536,7 +1535,7 @@ class DeveloperEcosystemManager:
return example return example
def get_code_example(self, example_id: str) -> Optional[CodeExample]: def get_code_example(self, example_id: str) -> CodeExample | None:
"""获取代码示例""" """获取代码示例"""
with self._get_db() as conn: with self._get_db() as conn:
row = conn.execute("SELECT * FROM code_examples WHERE id = ?", (example_id,)).fetchone() row = conn.execute("SELECT * FROM code_examples WHERE id = ?", (example_id,)).fetchone()
@@ -1547,11 +1546,11 @@ class DeveloperEcosystemManager:
def list_code_examples( def list_code_examples(
self, self,
language: Optional[str] = None, language: str | None = None,
category: Optional[str] = None, category: str | None = None,
sdk_id: Optional[str] = None, sdk_id: str | None = None,
search: Optional[str] = None, search: str | None = None,
) -> List[CodeExample]: ) -> list[CodeExample]:
"""列出代码示例""" """列出代码示例"""
query = "SELECT * FROM code_examples WHERE 1=1" query = "SELECT * FROM code_examples WHERE 1=1"
params = [] params = []
@@ -1650,7 +1649,7 @@ class DeveloperEcosystemManager:
return doc return doc
def get_api_documentation(self, doc_id: str) -> Optional[APIDocumentation]: def get_api_documentation(self, doc_id: str) -> APIDocumentation | None:
"""获取 API 文档""" """获取 API 文档"""
with self._get_db() as conn: with self._get_db() as conn:
row = conn.execute("SELECT * FROM api_documentation WHERE id = ?", (doc_id,)).fetchone() row = conn.execute("SELECT * FROM api_documentation WHERE id = ?", (doc_id,)).fetchone()
@@ -1659,7 +1658,7 @@ class DeveloperEcosystemManager:
return self._row_to_api_documentation(row) return self._row_to_api_documentation(row)
return None return None
def get_latest_api_documentation(self) -> Optional[APIDocumentation]: def get_latest_api_documentation(self) -> APIDocumentation | None:
"""获取最新 API 文档""" """获取最新 API 文档"""
with self._get_db() as conn: with self._get_db() as conn:
row = conn.execute("SELECT * FROM api_documentation ORDER BY generated_at DESC LIMIT 1").fetchone() row = conn.execute("SELECT * FROM api_documentation ORDER BY generated_at DESC LIMIT 1").fetchone()
@@ -1675,16 +1674,16 @@ class DeveloperEcosystemManager:
name: str, name: str,
description: str, description: str,
theme: str = "default", theme: str = "default",
custom_css: Optional[str] = None, custom_css: str | None = None,
custom_js: Optional[str] = None, custom_js: str | None = None,
logo_url: Optional[str] = None, logo_url: str | None = None,
favicon_url: Optional[str] = None, favicon_url: str | None = None,
primary_color: str = "#1890ff", primary_color: str = "#1890ff",
secondary_color: str = "#52c41a", secondary_color: str = "#52c41a",
support_email: str = "support@insightflow.io", support_email: str = "support@insightflow.io",
support_url: Optional[str] = None, support_url: str | None = None,
github_url: Optional[str] = None, github_url: str | None = None,
discord_url: Optional[str] = None, discord_url: str | None = None,
api_base_url: str = "https://api.insightflow.io", api_base_url: str = "https://api.insightflow.io",
) -> DeveloperPortalConfig: ) -> DeveloperPortalConfig:
"""创建开发者门户配置""" """创建开发者门户配置"""
@@ -1746,7 +1745,7 @@ class DeveloperEcosystemManager:
return config return config
def get_portal_config(self, config_id: str) -> Optional[DeveloperPortalConfig]: def get_portal_config(self, config_id: str) -> DeveloperPortalConfig | None:
"""获取开发者门户配置""" """获取开发者门户配置"""
with self._get_db() as conn: with self._get_db() as conn:
row = conn.execute("SELECT * FROM developer_portal_configs WHERE id = ?", (config_id,)).fetchone() row = conn.execute("SELECT * FROM developer_portal_configs WHERE id = ?", (config_id,)).fetchone()
@@ -1755,7 +1754,7 @@ class DeveloperEcosystemManager:
return self._row_to_portal_config(row) return self._row_to_portal_config(row)
return None return None
def get_active_portal_config(self) -> Optional[DeveloperPortalConfig]: def get_active_portal_config(self) -> DeveloperPortalConfig | None:
"""获取活跃的开发者门户配置""" """获取活跃的开发者门户配置"""
with self._get_db() as conn: with self._get_db() as conn:
row = conn.execute("SELECT * FROM developer_portal_configs WHERE is_active = 1 LIMIT 1").fetchone() row = conn.execute("SELECT * FROM developer_portal_configs WHERE is_active = 1 LIMIT 1").fetchone()

View File

@@ -4,9 +4,8 @@ Document Processor - Phase 3
支持 PDF 和 DOCX 文档导入 支持 PDF 和 DOCX 文档导入
""" """
import os
import io import io
from typing import Dict import os
class DocumentProcessor: class DocumentProcessor:
@@ -21,7 +20,7 @@ class DocumentProcessor:
".md": self._extract_txt, ".md": self._extract_txt,
} }
def process(self, content: bytes, filename: str) -> Dict[str, str]: def process(self, content: bytes, filename: str) -> dict[str, str]:
""" """
处理文档并提取文本 处理文档并提取文本

View File

@@ -10,19 +10,19 @@ InsightFlow Phase 8 - 企业级功能管理模块
作者: InsightFlow Team 作者: InsightFlow Team
""" """
import sqlite3
import json import json
import uuid
from datetime import datetime, timedelta
from typing import Optional, List, Dict, Any, Tuple
from dataclasses import dataclass
from enum import Enum
import logging import logging
import sqlite3
import uuid
from dataclasses import dataclass
from datetime import datetime, timedelta
from enum import StrEnum
from typing import Any
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class SSOProvider(str, Enum): class SSOProvider(StrEnum):
"""SSO 提供商类型""" """SSO 提供商类型"""
WECHAT_WORK = "wechat_work" # 企业微信 WECHAT_WORK = "wechat_work" # 企业微信
@@ -34,7 +34,7 @@ class SSOProvider(str, Enum):
CUSTOM_SAML = "custom_saml" # 自定义 SAML CUSTOM_SAML = "custom_saml" # 自定义 SAML
class SSOStatus(str, Enum): class SSOStatus(StrEnum):
"""SSO 配置状态""" """SSO 配置状态"""
DISABLED = "disabled" # 未启用 DISABLED = "disabled" # 未启用
@@ -43,7 +43,7 @@ class SSOStatus(str, Enum):
ERROR = "error" # 配置错误 ERROR = "error" # 配置错误
class SCIMSyncStatus(str, Enum): class SCIMSyncStatus(StrEnum):
"""SCIM 同步状态""" """SCIM 同步状态"""
IDLE = "idle" # 空闲 IDLE = "idle" # 空闲
@@ -52,7 +52,7 @@ class SCIMSyncStatus(str, Enum):
FAILED = "failed" # 同步失败 FAILED = "failed" # 同步失败
class AuditLogExportFormat(str, Enum): class AuditLogExportFormat(StrEnum):
"""审计日志导出格式""" """审计日志导出格式"""
JSON = "json" JSON = "json"
@@ -61,7 +61,7 @@ class AuditLogExportFormat(str, Enum):
XLSX = "xlsx" XLSX = "xlsx"
class DataRetentionAction(str, Enum): class DataRetentionAction(StrEnum):
"""数据保留策略动作""" """数据保留策略动作"""
ARCHIVE = "archive" # 归档 ARCHIVE = "archive" # 归档
@@ -69,7 +69,7 @@ class DataRetentionAction(str, Enum):
ANONYMIZE = "anonymize" # 匿名化 ANONYMIZE = "anonymize" # 匿名化
class ComplianceStandard(str, Enum): class ComplianceStandard(StrEnum):
"""合规标准""" """合规标准"""
SOC2 = "soc2" SOC2 = "soc2"
@@ -87,29 +87,29 @@ class SSOConfig:
tenant_id: str tenant_id: str
provider: str # SSO 提供商 provider: str # SSO 提供商
status: str # 状态 status: str # 状态
entity_id: Optional[str] # SAML Entity ID entity_id: str | None # SAML Entity ID
sso_url: Optional[str] # SAML SSO URL sso_url: str | None # SAML SSO URL
slo_url: Optional[str] # SAML SLO URL slo_url: str | None # SAML SLO URL
certificate: Optional[str] # SAML 证书 (X.509) certificate: str | None # SAML 证书 (X.509)
metadata_url: Optional[str] # SAML 元数据 URL metadata_url: str | None # SAML 元数据 URL
metadata_xml: Optional[str] # SAML 元数据 XML metadata_xml: str | None # SAML 元数据 XML
# OAuth/OIDC 配置 # OAuth/OIDC 配置
client_id: Optional[str] client_id: str | None
client_secret: Optional[str] client_secret: str | None
authorization_url: Optional[str] authorization_url: str | None
token_url: Optional[str] token_url: str | None
userinfo_url: Optional[str] userinfo_url: str | None
scopes: List[str] scopes: list[str]
# 属性映射 # 属性映射
attribute_mapping: Dict[str, str] # 如 {"email": "user.mail", "name": "user.name"} attribute_mapping: dict[str, str] # 如 {"email": "user.mail", "name": "user.name"}
# 其他配置 # 其他配置
auto_provision: bool # 自动创建用户 auto_provision: bool # 自动创建用户
default_role: str # 默认角色 default_role: str # 默认角色
domain_restriction: List[str] # 允许的邮箱域名 domain_restriction: list[str] # 允许的邮箱域名
created_at: datetime created_at: datetime
updated_at: datetime updated_at: datetime
last_tested_at: Optional[datetime] last_tested_at: datetime | None
last_error: Optional[str] last_error: str | None
@dataclass @dataclass
@@ -125,14 +125,14 @@ class SCIMConfig:
scim_token: str # SCIM 访问令牌 scim_token: str # SCIM 访问令牌
# 同步配置 # 同步配置
sync_interval_minutes: int # 同步间隔(分钟) sync_interval_minutes: int # 同步间隔(分钟)
last_sync_at: Optional[datetime] last_sync_at: datetime | None
last_sync_status: Optional[str] last_sync_status: str | None
last_sync_error: Optional[str] last_sync_error: str | None
last_sync_users_count: int last_sync_users_count: int
# 属性映射 # 属性映射
attribute_mapping: Dict[str, str] attribute_mapping: dict[str, str]
# 同步规则 # 同步规则
sync_rules: Dict[str, Any] # 过滤规则、转换规则等 sync_rules: dict[str, Any] # 过滤规则、转换规则等
created_at: datetime created_at: datetime
updated_at: datetime updated_at: datetime
@@ -146,12 +146,12 @@ class SCIMUser:
external_id: str # 外部系统 ID external_id: str # 外部系统 ID
user_name: str user_name: str
email: str email: str
display_name: Optional[str] display_name: str | None
given_name: Optional[str] given_name: str | None
family_name: Optional[str] family_name: str | None
active: bool active: bool
groups: List[str] groups: list[str]
raw_data: Dict[str, Any] # 原始 SCIM 数据 raw_data: dict[str, Any] # 原始 SCIM 数据
synced_at: datetime synced_at: datetime
created_at: datetime created_at: datetime
updated_at: datetime updated_at: datetime
@@ -166,20 +166,20 @@ class AuditLogExport:
export_format: str export_format: str
start_date: datetime start_date: datetime
end_date: datetime end_date: datetime
filters: Dict[str, Any] # 过滤条件 filters: dict[str, Any] # 过滤条件
compliance_standard: Optional[str] compliance_standard: str | None
status: str # pending/processing/completed/failed status: str # pending/processing/completed/failed
file_path: Optional[str] file_path: str | None
file_size: Optional[int] file_size: int | None
record_count: Optional[int] record_count: int | None
checksum: Optional[str] # 文件校验和 checksum: str | None # 文件校验和
downloaded_by: Optional[str] downloaded_by: str | None
downloaded_at: Optional[datetime] downloaded_at: datetime | None
expires_at: Optional[datetime] # 文件过期时间 expires_at: datetime | None # 文件过期时间
created_by: str created_by: str
created_at: datetime created_at: datetime
completed_at: Optional[datetime] completed_at: datetime | None
error_message: Optional[str] error_message: str | None
@dataclass @dataclass
@@ -189,23 +189,23 @@ class DataRetentionPolicy:
id: str id: str
tenant_id: str tenant_id: str
name: str name: str
description: Optional[str] description: str | None
resource_type: str # project/transcript/entity/audit_log/user_data resource_type: str # project/transcript/entity/audit_log/user_data
retention_days: int # 保留天数 retention_days: int # 保留天数
action: str # archive/delete/anonymize action: str # archive/delete/anonymize
# 条件 # 条件
conditions: Dict[str, Any] # 触发条件 conditions: dict[str, Any] # 触发条件
# 执行配置 # 执行配置
auto_execute: bool # 自动执行 auto_execute: bool # 自动执行
execute_at: Optional[str] # 执行时间 (cron 表达式) execute_at: str | None # 执行时间 (cron 表达式)
notify_before_days: int # 提前通知天数 notify_before_days: int # 提前通知天数
# 归档配置 # 归档配置
archive_location: Optional[str] # 归档位置 archive_location: str | None # 归档位置
archive_encryption: bool # 归档加密 archive_encryption: bool # 归档加密
# 状态 # 状态
is_active: bool is_active: bool
last_executed_at: Optional[datetime] last_executed_at: datetime | None
last_execution_result: Optional[str] last_execution_result: str | None
created_at: datetime created_at: datetime
updated_at: datetime updated_at: datetime
@@ -218,13 +218,13 @@ class DataRetentionJob:
policy_id: str policy_id: str
tenant_id: str tenant_id: str
status: str # pending/running/completed/failed status: str # pending/running/completed/failed
started_at: Optional[datetime] started_at: datetime | None
completed_at: Optional[datetime] completed_at: datetime | None
affected_records: int affected_records: int
archived_records: int archived_records: int
deleted_records: int deleted_records: int
error_count: int error_count: int
details: Dict[str, Any] details: dict[str, Any]
created_at: datetime created_at: datetime
@@ -236,11 +236,11 @@ class SAMLAuthRequest:
tenant_id: str tenant_id: str
sso_config_id: str sso_config_id: str
request_id: str # SAML Request ID request_id: str # SAML Request ID
relay_state: Optional[str] relay_state: str | None
created_at: datetime created_at: datetime
expires_at: datetime expires_at: datetime
used: bool used: bool
used_at: Optional[datetime] used_at: datetime | None
@dataclass @dataclass
@@ -250,13 +250,13 @@ class SAMLAuthResponse:
id: str id: str
request_id: str request_id: str
tenant_id: str tenant_id: str
user_id: Optional[str] user_id: str | None
email: Optional[str] email: str | None
name: Optional[str] name: str | None
attributes: Dict[str, Any] attributes: dict[str, Any]
session_index: Optional[str] session_index: str | None
processed: bool processed: bool
processed_at: Optional[datetime] processed_at: datetime | None
created_at: datetime created_at: datetime
@@ -548,22 +548,22 @@ class EnterpriseManager:
self, self,
tenant_id: str, tenant_id: str,
provider: str, provider: str,
entity_id: Optional[str] = None, entity_id: str | None = None,
sso_url: Optional[str] = None, sso_url: str | None = None,
slo_url: Optional[str] = None, slo_url: str | None = None,
certificate: Optional[str] = None, certificate: str | None = None,
metadata_url: Optional[str] = None, metadata_url: str | None = None,
metadata_xml: Optional[str] = None, metadata_xml: str | None = None,
client_id: Optional[str] = None, client_id: str | None = None,
client_secret: Optional[str] = None, client_secret: str | None = None,
authorization_url: Optional[str] = None, authorization_url: str | None = None,
token_url: Optional[str] = None, token_url: str | None = None,
userinfo_url: Optional[str] = None, userinfo_url: str | None = None,
scopes: Optional[List[str]] = None, scopes: list[str] | None = None,
attribute_mapping: Optional[Dict[str, str]] = None, attribute_mapping: dict[str, str] | None = None,
auto_provision: bool = True, auto_provision: bool = True,
default_role: str = "member", default_role: str = "member",
domain_restriction: Optional[List[str]] = None, domain_restriction: list[str] | None = None,
) -> SSOConfig: ) -> SSOConfig:
"""创建 SSO 配置""" """创建 SSO 配置"""
conn = self._get_connection() conn = self._get_connection()
@@ -649,7 +649,7 @@ class EnterpriseManager:
finally: finally:
conn.close() conn.close()
def get_sso_config(self, config_id: str) -> Optional[SSOConfig]: def get_sso_config(self, config_id: str) -> SSOConfig | None:
"""获取 SSO 配置""" """获取 SSO 配置"""
conn = self._get_connection() conn = self._get_connection()
try: try:
@@ -664,7 +664,7 @@ class EnterpriseManager:
finally: finally:
conn.close() conn.close()
def get_tenant_sso_config(self, tenant_id: str, provider: Optional[str] = None) -> Optional[SSOConfig]: def get_tenant_sso_config(self, tenant_id: str, provider: str | None = None) -> SSOConfig | None:
"""获取租户的 SSO 配置""" """获取租户的 SSO 配置"""
conn = self._get_connection() conn = self._get_connection()
try: try:
@@ -698,7 +698,7 @@ class EnterpriseManager:
finally: finally:
conn.close() conn.close()
def update_sso_config(self, config_id: str, **kwargs) -> Optional[SSOConfig]: def update_sso_config(self, config_id: str, **kwargs) -> SSOConfig | None:
"""更新 SSO 配置""" """更新 SSO 配置"""
conn = self._get_connection() conn = self._get_connection()
try: try:
@@ -772,7 +772,7 @@ class EnterpriseManager:
finally: finally:
conn.close() conn.close()
def list_sso_configs(self, tenant_id: str) -> List[SSOConfig]: def list_sso_configs(self, tenant_id: str) -> list[SSOConfig]:
"""列出租户的所有 SSO 配置""" """列出租户的所有 SSO 配置"""
conn = self._get_connection() conn = self._get_connection()
try: try:
@@ -835,7 +835,7 @@ class EnterpriseManager:
return metadata return metadata
def create_saml_auth_request( def create_saml_auth_request(
self, tenant_id: str, config_id: str, relay_state: Optional[str] = None self, tenant_id: str, config_id: str, relay_state: str | None = None
) -> SAMLAuthRequest: ) -> SAMLAuthRequest:
"""创建 SAML 认证请求""" """创建 SAML 认证请求"""
conn = self._get_connection() conn = self._get_connection()
@@ -881,7 +881,7 @@ class EnterpriseManager:
finally: finally:
conn.close() conn.close()
def get_saml_auth_request(self, request_id: str) -> Optional[SAMLAuthRequest]: def get_saml_auth_request(self, request_id: str) -> SAMLAuthRequest | None:
"""获取 SAML 认证请求""" """获取 SAML 认证请求"""
conn = self._get_connection() conn = self._get_connection()
try: try:
@@ -901,7 +901,7 @@ class EnterpriseManager:
finally: finally:
conn.close() conn.close()
def process_saml_response(self, request_id: str, saml_response: str) -> Optional[SAMLAuthResponse]: def process_saml_response(self, request_id: str, saml_response: str) -> SAMLAuthResponse | None:
"""处理 SAML 响应""" """处理 SAML 响应"""
# 这里应该实现实际的 SAML 响应解析 # 这里应该实现实际的 SAML 响应解析
# 简化实现:假设响应已经验证并解析 # 简化实现:假设响应已经验证并解析
@@ -954,7 +954,7 @@ class EnterpriseManager:
finally: finally:
conn.close() conn.close()
def _parse_saml_response(self, saml_response: str) -> Dict[str, Any]: def _parse_saml_response(self, saml_response: str) -> dict[str, Any]:
"""解析 SAML 响应(简化实现)""" """解析 SAML 响应(简化实现)"""
# 实际应该使用 python-saml 库解析 # 实际应该使用 python-saml 库解析
# 这里返回模拟数据 # 这里返回模拟数据
@@ -974,8 +974,8 @@ class EnterpriseManager:
scim_base_url: str, scim_base_url: str,
scim_token: str, scim_token: str,
sync_interval_minutes: int = 60, sync_interval_minutes: int = 60,
attribute_mapping: Optional[Dict[str, str]] = None, attribute_mapping: dict[str, str] | None = None,
sync_rules: Optional[Dict[str, Any]] = None, sync_rules: dict[str, Any] | None = None,
) -> SCIMConfig: ) -> SCIMConfig:
"""创建 SCIM 配置""" """创建 SCIM 配置"""
conn = self._get_connection() conn = self._get_connection()
@@ -1035,7 +1035,7 @@ class EnterpriseManager:
finally: finally:
conn.close() conn.close()
def get_scim_config(self, config_id: str) -> Optional[SCIMConfig]: def get_scim_config(self, config_id: str) -> SCIMConfig | None:
"""获取 SCIM 配置""" """获取 SCIM 配置"""
conn = self._get_connection() conn = self._get_connection()
try: try:
@@ -1050,7 +1050,7 @@ class EnterpriseManager:
finally: finally:
conn.close() conn.close()
def get_tenant_scim_config(self, tenant_id: str) -> Optional[SCIMConfig]: def get_tenant_scim_config(self, tenant_id: str) -> SCIMConfig | None:
"""获取租户的 SCIM 配置""" """获取租户的 SCIM 配置"""
conn = self._get_connection() conn = self._get_connection()
try: try:
@@ -1071,7 +1071,7 @@ class EnterpriseManager:
finally: finally:
conn.close() conn.close()
def update_scim_config(self, config_id: str, **kwargs) -> Optional[SCIMConfig]: def update_scim_config(self, config_id: str, **kwargs) -> SCIMConfig | None:
"""更新 SCIM 配置""" """更新 SCIM 配置"""
conn = self._get_connection() conn = self._get_connection()
try: try:
@@ -1121,7 +1121,7 @@ class EnterpriseManager:
finally: finally:
conn.close() conn.close()
def sync_scim_users(self, config_id: str) -> Dict[str, Any]: def sync_scim_users(self, config_id: str) -> dict[str, Any]:
"""执行 SCIM 用户同步""" """执行 SCIM 用户同步"""
config = self.get_scim_config(config_id) config = self.get_scim_config(config_id)
if not config: if not config:
@@ -1184,13 +1184,13 @@ class EnterpriseManager:
finally: finally:
conn.close() conn.close()
def _fetch_scim_users(self, config: SCIMConfig) -> List[Dict[str, Any]]: def _fetch_scim_users(self, config: SCIMConfig) -> list[dict[str, Any]]:
"""从 SCIM 服务端获取用户(模拟实现)""" """从 SCIM 服务端获取用户(模拟实现)"""
# 实际应该使用 HTTP 请求获取 # 实际应该使用 HTTP 请求获取
# GET {scim_base_url}/Users # GET {scim_base_url}/Users
return [] return []
def _upsert_scim_user(self, conn: sqlite3.Connection, tenant_id: str, user_data: Dict[str, Any]): def _upsert_scim_user(self, conn: sqlite3.Connection, tenant_id: str, user_data: dict[str, Any]):
"""插入或更新 SCIM 用户""" """插入或更新 SCIM 用户"""
cursor = conn.cursor() cursor = conn.cursor()
@@ -1238,7 +1238,7 @@ class EnterpriseManager:
), ),
) )
def list_scim_users(self, tenant_id: str, active_only: bool = True) -> List[SCIMUser]: def list_scim_users(self, tenant_id: str, active_only: bool = True) -> list[SCIMUser]:
"""列出 SCIM 用户""" """列出 SCIM 用户"""
conn = self._get_connection() conn = self._get_connection()
try: try:
@@ -1269,8 +1269,8 @@ class EnterpriseManager:
start_date: datetime, start_date: datetime,
end_date: datetime, end_date: datetime,
created_by: str, created_by: str,
filters: Optional[Dict[str, Any]] = None, filters: dict[str, Any] | None = None,
compliance_standard: Optional[str] = None, compliance_standard: str | None = None,
) -> AuditLogExport: ) -> AuditLogExport:
"""创建审计日志导出任务""" """创建审计日志导出任务"""
conn = self._get_connection() conn = self._get_connection()
@@ -1337,7 +1337,7 @@ class EnterpriseManager:
finally: finally:
conn.close() conn.close()
def process_audit_export(self, export_id: str, db_manager=None) -> Optional[AuditLogExport]: def process_audit_export(self, export_id: str, db_manager=None) -> AuditLogExport | None:
"""处理审计日志导出任务""" """处理审计日志导出任务"""
export = self.get_audit_export(export_id) export = self.get_audit_export(export_id)
if not export: if not export:
@@ -1401,8 +1401,8 @@ class EnterpriseManager:
conn.close() conn.close()
def _fetch_audit_logs( def _fetch_audit_logs(
self, tenant_id: str, start_date: datetime, end_date: datetime, filters: Dict[str, Any], db_manager=None self, tenant_id: str, start_date: datetime, end_date: datetime, filters: dict[str, Any], db_manager=None
) -> List[Dict[str, Any]]: ) -> list[dict[str, Any]]:
"""获取审计日志数据""" """获取审计日志数据"""
if db_manager is None: if db_manager is None:
return [] return []
@@ -1411,7 +1411,7 @@ class EnterpriseManager:
# 这里简化实现 # 这里简化实现
return [] return []
def _apply_compliance_filter(self, logs: List[Dict[str, Any]], standard: str) -> List[Dict[str, Any]]: def _apply_compliance_filter(self, logs: list[dict[str, Any]], standard: str) -> list[dict[str, Any]]:
"""应用合规标准字段过滤""" """应用合规标准字段过滤"""
fields = self.COMPLIANCE_FIELDS.get(ComplianceStandard(standard), []) fields = self.COMPLIANCE_FIELDS.get(ComplianceStandard(standard), [])
@@ -1425,10 +1425,10 @@ class EnterpriseManager:
return filtered_logs return filtered_logs
def _generate_export_file(self, export_id: str, logs: List[Dict[str, Any]], format: str) -> Tuple[str, int, str]: def _generate_export_file(self, export_id: str, logs: list[dict[str, Any]], format: str) -> tuple[str, int, str]:
"""生成导出文件""" """生成导出文件"""
import os
import hashlib import hashlib
import os
export_dir = "/tmp/insightflow/exports" export_dir = "/tmp/insightflow/exports"
os.makedirs(export_dir, exist_ok=True) os.makedirs(export_dir, exist_ok=True)
@@ -1461,7 +1461,7 @@ class EnterpriseManager:
return file_path, file_size, checksum return file_path, file_size, checksum
def get_audit_export(self, export_id: str) -> Optional[AuditLogExport]: def get_audit_export(self, export_id: str) -> AuditLogExport | None:
"""获取审计日志导出记录""" """获取审计日志导出记录"""
conn = self._get_connection() conn = self._get_connection()
try: try:
@@ -1476,7 +1476,7 @@ class EnterpriseManager:
finally: finally:
conn.close() conn.close()
def list_audit_exports(self, tenant_id: str, limit: int = 100) -> List[AuditLogExport]: def list_audit_exports(self, tenant_id: str, limit: int = 100) -> list[AuditLogExport]:
"""列出审计日志导出记录""" """列出审计日志导出记录"""
conn = self._get_connection() conn = self._get_connection()
try: try:
@@ -1524,12 +1524,12 @@ class EnterpriseManager:
resource_type: str, resource_type: str,
retention_days: int, retention_days: int,
action: str, action: str,
description: Optional[str] = None, description: str | None = None,
conditions: Optional[Dict[str, Any]] = None, conditions: dict[str, Any] | None = None,
auto_execute: bool = False, auto_execute: bool = False,
execute_at: Optional[str] = None, execute_at: str | None = None,
notify_before_days: int = 7, notify_before_days: int = 7,
archive_location: Optional[str] = None, archive_location: str | None = None,
archive_encryption: bool = True, archive_encryption: bool = True,
) -> DataRetentionPolicy: ) -> DataRetentionPolicy:
"""创建数据保留策略""" """创建数据保留策略"""
@@ -1599,7 +1599,7 @@ class EnterpriseManager:
finally: finally:
conn.close() conn.close()
def get_retention_policy(self, policy_id: str) -> Optional[DataRetentionPolicy]: def get_retention_policy(self, policy_id: str) -> DataRetentionPolicy | None:
"""获取数据保留策略""" """获取数据保留策略"""
conn = self._get_connection() conn = self._get_connection()
try: try:
@@ -1614,7 +1614,7 @@ class EnterpriseManager:
finally: finally:
conn.close() conn.close()
def list_retention_policies(self, tenant_id: str, resource_type: Optional[str] = None) -> List[DataRetentionPolicy]: def list_retention_policies(self, tenant_id: str, resource_type: str | None = None) -> list[DataRetentionPolicy]:
"""列出数据保留策略""" """列出数据保留策略"""
conn = self._get_connection() conn = self._get_connection()
try: try:
@@ -1637,7 +1637,7 @@ class EnterpriseManager:
finally: finally:
conn.close() conn.close()
def update_retention_policy(self, policy_id: str, **kwargs) -> Optional[DataRetentionPolicy]: def update_retention_policy(self, policy_id: str, **kwargs) -> DataRetentionPolicy | None:
"""更新数据保留策略""" """更新数据保留策略"""
conn = self._get_connection() conn = self._get_connection()
try: try:
@@ -1818,7 +1818,7 @@ class EnterpriseManager:
def _retain_audit_logs( def _retain_audit_logs(
self, conn: sqlite3.Connection, policy: DataRetentionPolicy, cutoff_date: datetime self, conn: sqlite3.Connection, policy: DataRetentionPolicy, cutoff_date: datetime
) -> Dict[str, int]: ) -> dict[str, int]:
"""保留审计日志""" """保留审计日志"""
cursor = conn.cursor() cursor = conn.cursor()
@@ -1851,19 +1851,19 @@ class EnterpriseManager:
def _retain_projects( def _retain_projects(
self, conn: sqlite3.Connection, policy: DataRetentionPolicy, cutoff_date: datetime self, conn: sqlite3.Connection, policy: DataRetentionPolicy, cutoff_date: datetime
) -> Dict[str, int]: ) -> dict[str, int]:
"""保留项目数据""" """保留项目数据"""
# 简化实现 # 简化实现
return {"affected": 0, "archived": 0, "deleted": 0, "errors": 0} return {"affected": 0, "archived": 0, "deleted": 0, "errors": 0}
def _retain_transcripts( def _retain_transcripts(
self, conn: sqlite3.Connection, policy: DataRetentionPolicy, cutoff_date: datetime self, conn: sqlite3.Connection, policy: DataRetentionPolicy, cutoff_date: datetime
) -> Dict[str, int]: ) -> dict[str, int]:
"""保留转录数据""" """保留转录数据"""
# 简化实现 # 简化实现
return {"affected": 0, "archived": 0, "deleted": 0, "errors": 0} return {"affected": 0, "archived": 0, "deleted": 0, "errors": 0}
def get_retention_job(self, job_id: str) -> Optional[DataRetentionJob]: def get_retention_job(self, job_id: str) -> DataRetentionJob | None:
"""获取数据保留任务""" """获取数据保留任务"""
conn = self._get_connection() conn = self._get_connection()
try: try:
@@ -1878,7 +1878,7 @@ class EnterpriseManager:
finally: finally:
conn.close() conn.close()
def list_retention_jobs(self, policy_id: str, limit: int = 100) -> List[DataRetentionJob]: def list_retention_jobs(self, policy_id: str, limit: int = 100) -> list[DataRetentionJob]:
"""列出数据保留任务""" """列出数据保留任务"""
conn = self._get_connection() conn = self._get_connection()
try: try:

View File

@@ -4,12 +4,12 @@ Entity Aligner - Phase 3
使用 embedding 进行实体对齐 使用 embedding 进行实体对齐
""" """
import os
import json import json
import os
from dataclasses import dataclass
import httpx import httpx
import numpy as np import numpy as np
from typing import List, Optional, Dict
from dataclasses import dataclass
# API Keys # API Keys
KIMI_API_KEY = os.getenv("KIMI_API_KEY", "") KIMI_API_KEY = os.getenv("KIMI_API_KEY", "")
@@ -21,7 +21,7 @@ class EntityEmbedding:
entity_id: str entity_id: str
name: str name: str
definition: str definition: str
embedding: List[float] embedding: list[float]
class EntityAligner: class EntityAligner:
@@ -29,9 +29,9 @@ class EntityAligner:
def __init__(self, similarity_threshold: float = 0.85): def __init__(self, similarity_threshold: float = 0.85):
self.similarity_threshold = similarity_threshold self.similarity_threshold = similarity_threshold
self.embedding_cache: Dict[str, List[float]] = {} self.embedding_cache: dict[str, list[float]] = {}
def get_embedding(self, text: str) -> Optional[List[float]]: def get_embedding(self, text: str) -> list[float] | None:
""" """
使用 Kimi API 获取文本的 embedding 使用 Kimi API 获取文本的 embedding
@@ -67,7 +67,7 @@ class EntityAligner:
print(f"Embedding API failed: {e}") print(f"Embedding API failed: {e}")
return None return None
def compute_similarity(self, embedding1: List[float], embedding2: List[float]) -> float: def compute_similarity(self, embedding1: list[float], embedding2: list[float]) -> float:
""" """
计算两个 embedding 的余弦相似度 计算两个 embedding 的余弦相似度
@@ -111,9 +111,9 @@ class EntityAligner:
project_id: str, project_id: str,
name: str, name: str,
definition: str = "", definition: str = "",
exclude_id: Optional[str] = None, exclude_id: str | None = None,
threshold: Optional[float] = None, threshold: float | None = None,
) -> Optional[object]: ) -> object | None:
""" """
查找相似的实体 查找相似的实体
@@ -175,8 +175,8 @@ class EntityAligner:
return best_match return best_match
def _fallback_similarity_match( def _fallback_similarity_match(
self, entities: List[object], name: str, exclude_id: Optional[str] = None self, entities: list[object], name: str, exclude_id: str | None = None
) -> Optional[object]: ) -> object | None:
""" """
回退到简单的相似度匹配(不使用 embedding 回退到简单的相似度匹配(不使用 embedding
@@ -209,8 +209,8 @@ class EntityAligner:
return None return None
def batch_align_entities( def batch_align_entities(
self, project_id: str, new_entities: List[Dict], threshold: Optional[float] = None self, project_id: str, new_entities: list[dict], threshold: float | None = None
) -> List[Dict]: ) -> list[dict]:
""" """
批量对齐实体 批量对齐实体
@@ -257,7 +257,7 @@ class EntityAligner:
return results return results
def suggest_entity_aliases(self, entity_name: str, entity_definition: str = "") -> List[str]: def suggest_entity_aliases(self, entity_name: str, entity_definition: str = "") -> list[str]:
""" """
使用 LLM 建议实体的别名 使用 LLM 建议实体的别名

View File

@@ -3,12 +3,12 @@ InsightFlow Export Module - Phase 5
支持导出知识图谱、项目报告、实体数据和转录文本 支持导出知识图谱、项目报告、实体数据和转录文本
""" """
import base64
import io import io
import json import json
import base64
from datetime import datetime
from typing import List, Dict, Any
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime
from typing import Any
try: try:
import pandas as pd import pandas as pd
@@ -20,9 +20,16 @@ except ImportError:
try: try:
from reportlab.lib import colors from reportlab.lib import colors
from reportlab.lib.pagesizes import A4 from reportlab.lib.pagesizes import A4
from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle from reportlab.lib.styles import ParagraphStyle, getSampleStyleSheet
from reportlab.lib.units import inch from reportlab.lib.units import inch
from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer, Table, TableStyle, PageBreak from reportlab.platypus import (
PageBreak,
Paragraph,
SimpleDocTemplate,
Spacer,
Table,
TableStyle,
)
REPORTLAB_AVAILABLE = True REPORTLAB_AVAILABLE = True
except ImportError: except ImportError:
@@ -35,9 +42,9 @@ class ExportEntity:
name: str name: str
type: str type: str
definition: str definition: str
aliases: List[str] aliases: list[str]
mention_count: int mention_count: int
attributes: Dict[str, Any] attributes: dict[str, Any]
@dataclass @dataclass
@@ -56,8 +63,8 @@ class ExportTranscript:
name: str name: str
type: str # audio/document type: str # audio/document
content: str content: str
segments: List[Dict] segments: list[dict]
entity_mentions: List[Dict] entity_mentions: list[dict]
class ExportManager: class ExportManager:
@@ -67,7 +74,7 @@ class ExportManager:
self.db = db_manager self.db = db_manager
def export_knowledge_graph_svg( def export_knowledge_graph_svg(
self, project_id: str, entities: List[ExportEntity], relations: List[ExportRelation] self, project_id: str, entities: list[ExportEntity], relations: list[ExportRelation]
) -> str: ) -> str:
""" """
导出知识图谱为 SVG 格式 导出知识图谱为 SVG 格式
@@ -207,7 +214,7 @@ class ExportManager:
return "\n".join(svg_parts) return "\n".join(svg_parts)
def export_knowledge_graph_png( def export_knowledge_graph_png(
self, project_id: str, entities: List[ExportEntity], relations: List[ExportRelation] self, project_id: str, entities: list[ExportEntity], relations: list[ExportRelation]
) -> bytes: ) -> bytes:
""" """
导出知识图谱为 PNG 格式 导出知识图谱为 PNG 格式
@@ -226,7 +233,7 @@ class ExportManager:
svg_content = self.export_knowledge_graph_svg(project_id, entities, relations) svg_content = self.export_knowledge_graph_svg(project_id, entities, relations)
return base64.b64encode(svg_content.encode("utf-8")) return base64.b64encode(svg_content.encode("utf-8"))
def export_entities_excel(self, entities: List[ExportEntity]) -> bytes: def export_entities_excel(self, entities: list[ExportEntity]) -> bytes:
""" """
导出实体数据为 Excel 格式 导出实体数据为 Excel 格式
@@ -275,7 +282,7 @@ class ExportManager:
return output.getvalue() return output.getvalue()
def export_entities_csv(self, entities: List[ExportEntity]) -> str: def export_entities_csv(self, entities: list[ExportEntity]) -> str:
""" """
导出实体数据为 CSV 格式 导出实体数据为 CSV 格式
@@ -306,7 +313,7 @@ class ExportManager:
return output.getvalue() return output.getvalue()
def export_relations_csv(self, relations: List[ExportRelation]) -> str: def export_relations_csv(self, relations: list[ExportRelation]) -> str:
""" """
导出关系数据为 CSV 格式 导出关系数据为 CSV 格式
@@ -324,7 +331,7 @@ class ExportManager:
return output.getvalue() return output.getvalue()
def export_transcript_markdown(self, transcript: ExportTranscript, entities_map: Dict[str, ExportEntity]) -> str: def export_transcript_markdown(self, transcript: ExportTranscript, entities_map: dict[str, ExportEntity]) -> str:
""" """
导出转录文本为 Markdown 格式 导出转录文本为 Markdown 格式
@@ -387,9 +394,9 @@ class ExportManager:
self, self,
project_id: str, project_id: str,
project_name: str, project_name: str,
entities: List[ExportEntity], entities: list[ExportEntity],
relations: List[ExportRelation], relations: list[ExportRelation],
transcripts: List[ExportTranscript], transcripts: list[ExportTranscript],
summary: str = "", summary: str = "",
) -> bytes: ) -> bytes:
""" """
@@ -529,9 +536,9 @@ class ExportManager:
self, self,
project_id: str, project_id: str,
project_name: str, project_name: str,
entities: List[ExportEntity], entities: list[ExportEntity],
relations: List[ExportRelation], relations: list[ExportRelation],
transcripts: List[ExportTranscript], transcripts: list[ExportTranscript],
) -> str: ) -> str:
""" """
导出完整项目数据为 JSON 格式 导出完整项目数据为 JSON 格式

View File

@@ -10,25 +10,26 @@ InsightFlow Growth Manager - Phase 8 Task 5
作者: InsightFlow Team 作者: InsightFlow Team
""" """
import os
import json
import sqlite3
import httpx
import asyncio import asyncio
import hashlib
import json
import os
import random import random
from typing import List, Dict, Optional, Any, Tuple import re
import sqlite3
import uuid
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime, timedelta from datetime import datetime, timedelta
from enum import Enum from enum import StrEnum
import hashlib from typing import Any
import uuid
import re import httpx
# Database path # Database path
DB_PATH = os.path.join(os.path.dirname(__file__), "insightflow.db") DB_PATH = os.path.join(os.path.dirname(__file__), "insightflow.db")
class EventType(str, Enum): class EventType(StrEnum):
"""事件类型""" """事件类型"""
PAGE_VIEW = "page_view" # 页面浏览 PAGE_VIEW = "page_view" # 页面浏览
@@ -44,7 +45,7 @@ class EventType(str, Enum):
REFERRAL_REWARD = "referral_reward" # 推荐奖励 REFERRAL_REWARD = "referral_reward" # 推荐奖励
class ExperimentStatus(str, Enum): class ExperimentStatus(StrEnum):
"""实验状态""" """实验状态"""
DRAFT = "draft" # 草稿 DRAFT = "draft" # 草稿
@@ -54,7 +55,7 @@ class ExperimentStatus(str, Enum):
ARCHIVED = "archived" # 已归档 ARCHIVED = "archived" # 已归档
class TrafficAllocationType(str, Enum): class TrafficAllocationType(StrEnum):
"""流量分配类型""" """流量分配类型"""
RANDOM = "random" # 随机分配 RANDOM = "random" # 随机分配
@@ -62,7 +63,7 @@ class TrafficAllocationType(str, Enum):
TARGETED = "targeted" # 定向分配 TARGETED = "targeted" # 定向分配
class EmailTemplateType(str, Enum): class EmailTemplateType(StrEnum):
"""邮件模板类型""" """邮件模板类型"""
WELCOME = "welcome" # 欢迎邮件 WELCOME = "welcome" # 欢迎邮件
@@ -74,7 +75,7 @@ class EmailTemplateType(str, Enum):
NEWSLETTER = "newsletter" # 新闻通讯 NEWSLETTER = "newsletter" # 新闻通讯
class EmailStatus(str, Enum): class EmailStatus(StrEnum):
"""邮件状态""" """邮件状态"""
DRAFT = "draft" # 草稿 DRAFT = "draft" # 草稿
@@ -88,7 +89,7 @@ class EmailStatus(str, Enum):
FAILED = "failed" # 失败 FAILED = "failed" # 失败
class WorkflowTriggerType(str, Enum): class WorkflowTriggerType(StrEnum):
"""工作流触发类型""" """工作流触发类型"""
USER_SIGNUP = "user_signup" # 用户注册 USER_SIGNUP = "user_signup" # 用户注册
@@ -100,7 +101,7 @@ class WorkflowTriggerType(str, Enum):
CUSTOM_EVENT = "custom_event" # 自定义事件 CUSTOM_EVENT = "custom_event" # 自定义事件
class ReferralStatus(str, Enum): class ReferralStatus(StrEnum):
"""推荐状态""" """推荐状态"""
PENDING = "pending" # 待处理 PENDING = "pending" # 待处理
@@ -118,14 +119,14 @@ class AnalyticsEvent:
user_id: str user_id: str
event_type: EventType event_type: EventType
event_name: str event_name: str
properties: Dict[str, Any] # 事件属性 properties: dict[str, Any] # 事件属性
timestamp: datetime timestamp: datetime
session_id: Optional[str] session_id: str | None
device_info: Dict[str, str] # 设备信息 device_info: dict[str, str] # 设备信息
referrer: Optional[str] referrer: str | None
utm_source: Optional[str] utm_source: str | None
utm_medium: Optional[str] utm_medium: str | None
utm_campaign: Optional[str] utm_campaign: str | None
@dataclass @dataclass
@@ -139,8 +140,8 @@ class UserProfile:
last_seen: datetime last_seen: datetime
total_sessions: int total_sessions: int
total_events: int total_events: int
feature_usage: Dict[str, int] # 功能使用次数 feature_usage: dict[str, int] # 功能使用次数
subscription_history: List[Dict] subscription_history: list[dict]
ltv: float # 生命周期价值 ltv: float # 生命周期价值
churn_risk_score: float # 流失风险分数 churn_risk_score: float # 流失风险分数
engagement_score: float # 参与度分数 engagement_score: float # 参与度分数
@@ -156,7 +157,7 @@ class Funnel:
tenant_id: str tenant_id: str
name: str name: str
description: str description: str
steps: List[Dict] # 漏斗步骤 steps: list[dict] # 漏斗步骤
created_at: datetime created_at: datetime
updated_at: datetime updated_at: datetime
@@ -169,9 +170,9 @@ class FunnelAnalysis:
period_start: datetime period_start: datetime
period_end: datetime period_end: datetime
total_users: int total_users: int
step_conversions: List[Dict] # 每步转化数据 step_conversions: list[dict] # 每步转化数据
overall_conversion: float # 总体转化率 overall_conversion: float # 总体转化率
drop_off_points: List[Dict] # 流失点 drop_off_points: list[dict] # 流失点
@dataclass @dataclass
@@ -184,14 +185,14 @@ class Experiment:
description: str description: str
hypothesis: str hypothesis: str
status: ExperimentStatus status: ExperimentStatus
variants: List[Dict] # 实验变体 variants: list[dict] # 实验变体
traffic_allocation: TrafficAllocationType traffic_allocation: TrafficAllocationType
traffic_split: Dict[str, float] # 流量分配比例 traffic_split: dict[str, float] # 流量分配比例
target_audience: Dict # 目标受众 target_audience: dict # 目标受众
primary_metric: str # 主要指标 primary_metric: str # 主要指标
secondary_metrics: List[str] # 次要指标 secondary_metrics: list[str] # 次要指标
start_date: Optional[datetime] start_date: datetime | None
end_date: Optional[datetime] end_date: datetime | None
min_sample_size: int # 最小样本量 min_sample_size: int # 最小样本量
confidence_level: float # 置信水平 confidence_level: float # 置信水平
created_at: datetime created_at: datetime
@@ -210,7 +211,7 @@ class ExperimentResult:
sample_size: int sample_size: int
mean_value: float mean_value: float
std_dev: float std_dev: float
confidence_interval: Tuple[float, float] confidence_interval: tuple[float, float]
p_value: float p_value: float
is_significant: bool is_significant: bool
uplift: float # 提升幅度 uplift: float # 提升幅度
@@ -228,11 +229,11 @@ class EmailTemplate:
subject: str subject: str
html_content: str html_content: str
text_content: str text_content: str
variables: List[str] # 模板变量 variables: list[str] # 模板变量
preview_text: Optional[str] preview_text: str | None
from_name: str from_name: str
from_email: str from_email: str
reply_to: Optional[str] reply_to: str | None
is_active: bool is_active: bool
created_at: datetime created_at: datetime
updated_at: datetime updated_at: datetime
@@ -254,9 +255,9 @@ class EmailCampaign:
clicked_count: int clicked_count: int
bounced_count: int bounced_count: int
failed_count: int failed_count: int
scheduled_at: Optional[datetime] scheduled_at: datetime | None
started_at: Optional[datetime] started_at: datetime | None
completed_at: Optional[datetime] completed_at: datetime | None
created_at: datetime created_at: datetime
@@ -272,13 +273,13 @@ class EmailLog:
template_id: str template_id: str
status: EmailStatus status: EmailStatus
subject: str subject: str
sent_at: Optional[datetime] sent_at: datetime | None
delivered_at: Optional[datetime] delivered_at: datetime | None
opened_at: Optional[datetime] opened_at: datetime | None
clicked_at: Optional[datetime] clicked_at: datetime | None
ip_address: Optional[str] ip_address: str | None
user_agent: Optional[str] user_agent: str | None
error_message: Optional[str] error_message: str | None
created_at: datetime created_at: datetime
@@ -291,8 +292,8 @@ class AutomationWorkflow:
name: str name: str
description: str description: str
trigger_type: WorkflowTriggerType trigger_type: WorkflowTriggerType
trigger_conditions: Dict # 触发条件 trigger_conditions: dict # 触发条件
actions: List[Dict] # 执行动作 actions: list[dict] # 执行动作
is_active: bool is_active: bool
execution_count: int execution_count: int
created_at: datetime created_at: datetime
@@ -327,15 +328,15 @@ class Referral:
program_id: str program_id: str
tenant_id: str tenant_id: str
referrer_id: str # 推荐人 referrer_id: str # 推荐人
referee_id: Optional[str] # 被推荐人 referee_id: str | None # 被推荐人
referral_code: str referral_code: str
status: ReferralStatus status: ReferralStatus
referrer_rewarded: bool referrer_rewarded: bool
referee_rewarded: bool referee_rewarded: bool
referrer_reward_value: float referrer_reward_value: float
referee_reward_value: float referee_reward_value: float
converted_at: Optional[datetime] converted_at: datetime | None
rewarded_at: Optional[datetime] rewarded_at: datetime | None
expires_at: datetime expires_at: datetime
created_at: datetime created_at: datetime
@@ -382,11 +383,11 @@ class GrowthManager:
user_id: str, user_id: str,
event_type: EventType, event_type: EventType,
event_name: str, event_name: str,
properties: Dict = None, properties: dict = None,
session_id: str = None, session_id: str = None,
device_info: Dict = None, device_info: dict = None,
referrer: str = None, referrer: str = None,
utm_params: Dict = None, utm_params: dict = None,
) -> AnalyticsEvent: ) -> AnalyticsEvent:
"""追踪事件""" """追踪事件"""
event_id = f"evt_{uuid.uuid4().hex[:16]}" event_id = f"evt_{uuid.uuid4().hex[:16]}"
@@ -554,7 +555,7 @@ class GrowthManager:
conn.commit() conn.commit()
def get_user_profile(self, tenant_id: str, user_id: str) -> Optional[UserProfile]: def get_user_profile(self, tenant_id: str, user_id: str) -> UserProfile | None:
"""获取用户画像""" """获取用户画像"""
with self._get_db() as conn: with self._get_db() as conn:
row = conn.execute( row = conn.execute(
@@ -567,7 +568,7 @@ class GrowthManager:
def get_user_analytics_summary( def get_user_analytics_summary(
self, tenant_id: str, start_date: datetime = None, end_date: datetime = None self, tenant_id: str, start_date: datetime = None, end_date: datetime = None
) -> Dict: ) -> dict:
"""获取用户分析汇总""" """获取用户分析汇总"""
with self._get_db() as conn: with self._get_db() as conn:
query = """ query = """
@@ -619,7 +620,7 @@ class GrowthManager:
"event_type_distribution": {r["event_type"]: r["count"] for r in type_rows}, "event_type_distribution": {r["event_type"]: r["count"] for r in type_rows},
} }
def create_funnel(self, tenant_id: str, name: str, description: str, steps: List[Dict], created_by: str) -> Funnel: def create_funnel(self, tenant_id: str, name: str, description: str, steps: list[dict], created_by: str) -> Funnel:
"""创建转化漏斗""" """创建转化漏斗"""
funnel_id = f"fnl_{uuid.uuid4().hex[:16]}" funnel_id = f"fnl_{uuid.uuid4().hex[:16]}"
now = datetime.now().isoformat() now = datetime.now().isoformat()
@@ -657,7 +658,7 @@ class GrowthManager:
def analyze_funnel( def analyze_funnel(
self, funnel_id: str, period_start: datetime = None, period_end: datetime = None self, funnel_id: str, period_start: datetime = None, period_end: datetime = None
) -> Optional[FunnelAnalysis]: ) -> FunnelAnalysis | None:
"""分析漏斗转化率""" """分析漏斗转化率"""
with self._get_db() as conn: with self._get_db() as conn:
funnel_row = conn.execute("SELECT * FROM funnels WHERE id = ?", (funnel_id,)).fetchone() funnel_row = conn.execute("SELECT * FROM funnels WHERE id = ?", (funnel_id,)).fetchone()
@@ -728,7 +729,7 @@ class GrowthManager:
drop_off_points=drop_off_points, drop_off_points=drop_off_points,
) )
def calculate_retention(self, tenant_id: str, cohort_date: datetime, periods: List[int] = None) -> Dict: def calculate_retention(self, tenant_id: str, cohort_date: datetime, periods: list[int] = None) -> dict:
"""计算留存率""" """计算留存率"""
if periods is None: if periods is None:
periods = [1, 3, 7, 14, 30] periods = [1, 3, 7, 14, 30]
@@ -787,12 +788,12 @@ class GrowthManager:
name: str, name: str,
description: str, description: str,
hypothesis: str, hypothesis: str,
variants: List[Dict], variants: list[dict],
traffic_allocation: TrafficAllocationType, traffic_allocation: TrafficAllocationType,
traffic_split: Dict[str, float], traffic_split: dict[str, float],
target_audience: Dict, target_audience: dict,
primary_metric: str, primary_metric: str,
secondary_metrics: List[str], secondary_metrics: list[str],
min_sample_size: int = 100, min_sample_size: int = 100,
confidence_level: float = 0.95, confidence_level: float = 0.95,
created_by: str = None, created_by: str = None,
@@ -859,7 +860,7 @@ class GrowthManager:
return experiment return experiment
def get_experiment(self, experiment_id: str) -> Optional[Experiment]: def get_experiment(self, experiment_id: str) -> Experiment | None:
"""获取实验详情""" """获取实验详情"""
with self._get_db() as conn: with self._get_db() as conn:
row = conn.execute("SELECT * FROM experiments WHERE id = ?", (experiment_id,)).fetchone() row = conn.execute("SELECT * FROM experiments WHERE id = ?", (experiment_id,)).fetchone()
@@ -868,7 +869,7 @@ class GrowthManager:
return self._row_to_experiment(row) return self._row_to_experiment(row)
return None return None
def list_experiments(self, tenant_id: str, status: ExperimentStatus = None) -> List[Experiment]: def list_experiments(self, tenant_id: str, status: ExperimentStatus = None) -> list[Experiment]:
"""列出实验""" """列出实验"""
query = "SELECT * FROM experiments WHERE tenant_id = ?" query = "SELECT * FROM experiments WHERE tenant_id = ?"
params = [tenant_id] params = [tenant_id]
@@ -883,7 +884,7 @@ class GrowthManager:
rows = conn.execute(query, params).fetchall() rows = conn.execute(query, params).fetchall()
return [self._row_to_experiment(row) for row in rows] return [self._row_to_experiment(row) for row in rows]
def assign_variant(self, experiment_id: str, user_id: str, user_attributes: Dict = None) -> Optional[str]: def assign_variant(self, experiment_id: str, user_id: str, user_attributes: dict = None) -> str | None:
"""为用户分配实验变体""" """为用户分配实验变体"""
experiment = self.get_experiment(experiment_id) experiment = self.get_experiment(experiment_id)
if not experiment or experiment.status != ExperimentStatus.RUNNING: if not experiment or experiment.status != ExperimentStatus.RUNNING:
@@ -929,7 +930,7 @@ class GrowthManager:
return variant_id return variant_id
def _random_allocation(self, variants: List[Dict], traffic_split: Dict[str, float]) -> str: def _random_allocation(self, variants: list[dict], traffic_split: dict[str, float]) -> str:
"""随机分配""" """随机分配"""
variant_ids = [v["id"] for v in variants] variant_ids = [v["id"] for v in variants]
weights = [traffic_split.get(v_id, 1.0 / len(variants)) for v_id in variant_ids] weights = [traffic_split.get(v_id, 1.0 / len(variants)) for v_id in variant_ids]
@@ -940,7 +941,7 @@ class GrowthManager:
return random.choices(variant_ids, weights=normalized_weights, k=1)[0] return random.choices(variant_ids, weights=normalized_weights, k=1)[0]
def _stratified_allocation( def _stratified_allocation(
self, variants: List[Dict], traffic_split: Dict[str, float], user_attributes: Dict self, variants: list[dict], traffic_split: dict[str, float], user_attributes: dict
) -> str: ) -> str:
"""分层分配(基于用户属性)""" """分层分配(基于用户属性)"""
# 简化的分层分配:根据用户 ID 哈希值分配 # 简化的分层分配:根据用户 ID 哈希值分配
@@ -952,7 +953,7 @@ class GrowthManager:
return self._random_allocation(variants, traffic_split) return self._random_allocation(variants, traffic_split)
def _targeted_allocation(self, variants: List[Dict], target_audience: Dict, user_attributes: Dict) -> Optional[str]: def _targeted_allocation(self, variants: list[dict], target_audience: dict, user_attributes: dict) -> str | None:
"""定向分配(基于目标受众条件)""" """定向分配(基于目标受众条件)"""
# 检查用户是否符合目标受众条件 # 检查用户是否符合目标受众条件
conditions = target_audience.get("conditions", []) conditions = target_audience.get("conditions", [])
@@ -1005,7 +1006,7 @@ class GrowthManager:
) )
conn.commit() conn.commit()
def analyze_experiment(self, experiment_id: str) -> Dict: def analyze_experiment(self, experiment_id: str) -> dict:
"""分析实验结果""" """分析实验结果"""
experiment = self.get_experiment(experiment_id) experiment = self.get_experiment(experiment_id)
if not experiment: if not experiment:
@@ -1083,7 +1084,7 @@ class GrowthManager:
"variant_results": results, "variant_results": results,
} }
def start_experiment(self, experiment_id: str) -> Optional[Experiment]: def start_experiment(self, experiment_id: str) -> Experiment | None:
"""启动实验""" """启动实验"""
with self._get_db() as conn: with self._get_db() as conn:
now = datetime.now().isoformat() now = datetime.now().isoformat()
@@ -1099,7 +1100,7 @@ class GrowthManager:
return self.get_experiment(experiment_id) return self.get_experiment(experiment_id)
def stop_experiment(self, experiment_id: str) -> Optional[Experiment]: def stop_experiment(self, experiment_id: str) -> Experiment | None:
"""停止实验""" """停止实验"""
with self._get_db() as conn: with self._get_db() as conn:
now = datetime.now().isoformat() now = datetime.now().isoformat()
@@ -1125,7 +1126,7 @@ class GrowthManager:
subject: str, subject: str,
html_content: str, html_content: str,
text_content: str = None, text_content: str = None,
variables: List[str] = None, variables: list[str] = None,
from_name: str = None, from_name: str = None,
from_email: str = None, from_email: str = None,
reply_to: str = None, reply_to: str = None,
@@ -1185,7 +1186,7 @@ class GrowthManager:
return template return template
def get_email_template(self, template_id: str) -> Optional[EmailTemplate]: def get_email_template(self, template_id: str) -> EmailTemplate | None:
"""获取邮件模板""" """获取邮件模板"""
with self._get_db() as conn: with self._get_db() as conn:
row = conn.execute("SELECT * FROM email_templates WHERE id = ?", (template_id,)).fetchone() row = conn.execute("SELECT * FROM email_templates WHERE id = ?", (template_id,)).fetchone()
@@ -1194,7 +1195,7 @@ class GrowthManager:
return self._row_to_email_template(row) return self._row_to_email_template(row)
return None return None
def list_email_templates(self, tenant_id: str, template_type: EmailTemplateType = None) -> List[EmailTemplate]: def list_email_templates(self, tenant_id: str, template_type: EmailTemplateType = None) -> list[EmailTemplate]:
"""列出邮件模板""" """列出邮件模板"""
query = "SELECT * FROM email_templates WHERE tenant_id = ? AND is_active = 1" query = "SELECT * FROM email_templates WHERE tenant_id = ? AND is_active = 1"
params = [tenant_id] params = [tenant_id]
@@ -1209,7 +1210,7 @@ class GrowthManager:
rows = conn.execute(query, params).fetchall() rows = conn.execute(query, params).fetchall()
return [self._row_to_email_template(row) for row in rows] return [self._row_to_email_template(row) for row in rows]
def render_template(self, template_id: str, variables: Dict) -> Dict[str, str]: def render_template(self, template_id: str, variables: dict) -> dict[str, str]:
"""渲染邮件模板""" """渲染邮件模板"""
template = self.get_email_template(template_id) template = self.get_email_template(template_id)
if not template: if not template:
@@ -1235,7 +1236,7 @@ class GrowthManager:
} }
def create_email_campaign( def create_email_campaign(
self, tenant_id: str, name: str, template_id: str, recipient_list: List[Dict], scheduled_at: datetime = None self, tenant_id: str, name: str, template_id: str, recipient_list: list[dict], scheduled_at: datetime = None
) -> EmailCampaign: ) -> EmailCampaign:
"""创建邮件营销活动""" """创建邮件营销活动"""
campaign_id = f"ec_{uuid.uuid4().hex[:16]}" campaign_id = f"ec_{uuid.uuid4().hex[:16]}"
@@ -1314,7 +1315,7 @@ class GrowthManager:
return campaign return campaign
async def send_email(self, campaign_id: str, user_id: str, email: str, template_id: str, variables: Dict) -> bool: async def send_email(self, campaign_id: str, user_id: str, email: str, template_id: str, variables: dict) -> bool:
"""发送单封邮件""" """发送单封邮件"""
template = self.get_email_template(template_id) template = self.get_email_template(template_id)
if not template: if not template:
@@ -1380,7 +1381,7 @@ class GrowthManager:
conn.commit() conn.commit()
return False return False
async def send_campaign(self, campaign_id: str) -> Dict: async def send_campaign(self, campaign_id: str) -> dict:
"""发送整个营销活动""" """发送整个营销活动"""
with self._get_db() as conn: with self._get_db() as conn:
campaign_row = conn.execute("SELECT * FROM email_campaigns WHERE id = ?", (campaign_id,)).fetchone() campaign_row = conn.execute("SELECT * FROM email_campaigns WHERE id = ?", (campaign_id,)).fetchone()
@@ -1432,7 +1433,7 @@ class GrowthManager:
return {"campaign_id": campaign_id, "total": len(logs), "success": success_count, "failed": failed_count} return {"campaign_id": campaign_id, "total": len(logs), "success": success_count, "failed": failed_count}
def _get_user_variables(self, tenant_id: str, user_id: str) -> Dict: def _get_user_variables(self, tenant_id: str, user_id: str) -> dict:
"""获取用户变量用于邮件模板""" """获取用户变量用于邮件模板"""
# 这里应该从用户服务获取用户信息 # 这里应该从用户服务获取用户信息
# 简化实现 # 简化实现
@@ -1444,8 +1445,8 @@ class GrowthManager:
name: str, name: str,
description: str, description: str,
trigger_type: WorkflowTriggerType, trigger_type: WorkflowTriggerType,
trigger_conditions: Dict, trigger_conditions: dict,
actions: List[Dict], actions: list[dict],
) -> AutomationWorkflow: ) -> AutomationWorkflow:
"""创建自动化工作流""" """创建自动化工作流"""
workflow_id = f"aw_{uuid.uuid4().hex[:16]}" workflow_id = f"aw_{uuid.uuid4().hex[:16]}"
@@ -1491,7 +1492,7 @@ class GrowthManager:
return workflow return workflow
async def trigger_workflow(self, workflow_id: str, event_data: Dict): async def trigger_workflow(self, workflow_id: str, event_data: dict):
"""触发自动化工作流""" """触发自动化工作流"""
with self._get_db() as conn: with self._get_db() as conn:
row = conn.execute( row = conn.execute(
@@ -1519,7 +1520,7 @@ class GrowthManager:
return True return True
def _check_trigger_conditions(self, conditions: Dict, event_data: Dict) -> bool: def _check_trigger_conditions(self, conditions: dict, event_data: dict) -> bool:
"""检查触发条件""" """检查触发条件"""
# 简化的条件检查 # 简化的条件检查
for key, value in conditions.items(): for key, value in conditions.items():
@@ -1527,7 +1528,7 @@ class GrowthManager:
return False return False
return True return True
async def _execute_action(self, action: Dict, event_data: Dict): async def _execute_action(self, action: dict, event_data: dict):
"""执行工作流动作""" """执行工作流动作"""
action_type = action.get("type") action_type = action.get("type")
@@ -1691,7 +1692,7 @@ class GrowthManager:
if not row: if not row:
return code return code
def _get_referral_program(self, program_id: str) -> Optional[ReferralProgram]: def _get_referral_program(self, program_id: str) -> ReferralProgram | None:
"""获取推荐计划""" """获取推荐计划"""
with self._get_db() as conn: with self._get_db() as conn:
row = conn.execute("SELECT * FROM referral_programs WHERE id = ?", (program_id,)).fetchone() row = conn.execute("SELECT * FROM referral_programs WHERE id = ?", (program_id,)).fetchone()
@@ -1746,7 +1747,7 @@ class GrowthManager:
return True return True
def get_referral_stats(self, program_id: str) -> Dict: def get_referral_stats(self, program_id: str) -> dict:
"""获取推荐统计""" """获取推荐统计"""
with self._get_db() as conn: with self._get_db() as conn:
stats = conn.execute( stats = conn.execute(
@@ -1841,7 +1842,7 @@ class GrowthManager:
def check_team_incentive_eligibility( def check_team_incentive_eligibility(
self, tenant_id: str, current_tier: str, team_size: int self, tenant_id: str, current_tier: str, team_size: int
) -> List[TeamIncentive]: ) -> list[TeamIncentive]:
"""检查团队激励资格""" """检查团队激励资格"""
with self._get_db() as conn: with self._get_db() as conn:
now = datetime.now().isoformat() now = datetime.now().isoformat()
@@ -1859,7 +1860,7 @@ class GrowthManager:
# ==================== 实时分析仪表板 ==================== # ==================== 实时分析仪表板 ====================
def get_realtime_dashboard(self, tenant_id: str) -> Dict: def get_realtime_dashboard(self, tenant_id: str) -> dict:
"""获取实时分析仪表板数据""" """获取实时分析仪表板数据"""
now = datetime.now() now = datetime.now()
today_start = now.replace(hour=0, minute=0, second=0, microsecond=0) today_start = now.replace(hour=0, minute=0, second=0, microsecond=0)

View File

@@ -4,11 +4,10 @@ InsightFlow Image Processor - Phase 7
图片处理模块识别白板、PPT、手写笔记等内容 图片处理模块识别白板、PPT、手写笔记等内容
""" """
import os
import io
import uuid
import base64 import base64
from typing import List, Optional, Tuple import io
import os
import uuid
from dataclasses import dataclass from dataclasses import dataclass
# 尝试导入图像处理库 # 尝试导入图像处理库
@@ -42,7 +41,7 @@ class ImageEntity:
name: str name: str
type: str type: str
confidence: float confidence: float
bbox: Optional[Tuple[int, int, int, int]] = None # (x, y, width, height) bbox: tuple[int, int, int, int] | None = None # (x, y, width, height)
@dataclass @dataclass
@@ -63,8 +62,8 @@ class ImageProcessingResult:
image_type: str # whiteboard, ppt, handwritten, screenshot, other image_type: str # whiteboard, ppt, handwritten, screenshot, other
ocr_text: str ocr_text: str
description: str description: str
entities: List[ImageEntity] entities: list[ImageEntity]
relations: List[ImageRelation] relations: list[ImageRelation]
width: int width: int
height: int height: int
success: bool success: bool
@@ -75,7 +74,7 @@ class ImageProcessingResult:
class BatchProcessingResult: class BatchProcessingResult:
"""批量图片处理结果""" """批量图片处理结果"""
results: List[ImageProcessingResult] results: list[ImageProcessingResult]
total_count: int total_count: int
success_count: int success_count: int
failed_count: int failed_count: int
@@ -231,7 +230,7 @@ class ImageProcessor:
print(f"Image type detection error: {e}") print(f"Image type detection error: {e}")
return "other" return "other"
def perform_ocr(self, image, lang: str = "chi_sim+eng") -> Tuple[str, float]: def perform_ocr(self, image, lang: str = "chi_sim+eng") -> tuple[str, float]:
""" """
对图片进行OCR识别 对图片进行OCR识别
@@ -262,7 +261,7 @@ class ImageProcessor:
print(f"OCR error: {e}") print(f"OCR error: {e}")
return "", 0.0 return "", 0.0
def extract_entities_from_text(self, text: str) -> List[ImageEntity]: def extract_entities_from_text(self, text: str) -> list[ImageEntity]:
""" """
从OCR文本中提取实体 从OCR文本中提取实体
@@ -322,7 +321,7 @@ class ImageProcessor:
return unique_entities return unique_entities
def generate_description(self, image_type: str, ocr_text: str, entities: List[ImageEntity]) -> str: def generate_description(self, image_type: str, ocr_text: str, entities: list[ImageEntity]) -> str:
""" """
生成图片描述 生成图片描述
@@ -435,7 +434,7 @@ class ImageProcessor:
error_message=str(e), error_message=str(e),
) )
def _extract_relations(self, entities: List[ImageEntity], text: str) -> List[ImageRelation]: def _extract_relations(self, entities: list[ImageEntity], text: str) -> list[ImageRelation]:
""" """
从文本中提取实体关系 从文本中提取实体关系
@@ -475,7 +474,7 @@ class ImageProcessor:
return relations return relations
def process_batch(self, images_data: List[Tuple[bytes, str]], project_id: str = None) -> BatchProcessingResult: def process_batch(self, images_data: list[tuple[bytes, str]], project_id: str = None) -> BatchProcessingResult:
""" """
批量处理图片 批量处理图片
@@ -515,7 +514,7 @@ class ImageProcessor:
""" """
return base64.b64encode(image_data).decode("utf-8") return base64.b64encode(image_data).decode("utf-8")
def get_image_thumbnail(self, image_data: bytes, size: Tuple[int, int] = (200, 200)) -> bytes: def get_image_thumbnail(self, image_data: bytes, size: tuple[int, int] = (200, 200)) -> bytes:
""" """
生成图片缩略图 生成图片缩略图

View File

@@ -1,8 +1,8 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
"""Initialize database with schema""" """Initialize database with schema"""
import sqlite3
import os import os
import sqlite3
db_path = os.path.join(os.path.dirname(__file__), "insightflow.db") db_path = os.path.join(os.path.dirname(__file__), "insightflow.db")
schema_path = os.path.join(os.path.dirname(__file__), "schema.sql") schema_path = os.path.join(os.path.dirname(__file__), "schema.sql")
@@ -11,7 +11,7 @@ print(f"Database path: {db_path}")
print(f"Schema path: {schema_path}") print(f"Schema path: {schema_path}")
# Read schema # Read schema
with open(schema_path, "r") as f: with open(schema_path) as f:
schema = f.read() schema = f.read()
# Execute schema # Execute schema

View File

@@ -4,13 +4,13 @@ InsightFlow Knowledge Reasoning - Phase 5
知识推理与问答增强模块 知识推理与问答增强模块
""" """
import os
import json import json
import httpx import os
from typing import List, Dict
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum from enum import Enum
import httpx
KIMI_API_KEY = os.getenv("KIMI_API_KEY", "") KIMI_API_KEY = os.getenv("KIMI_API_KEY", "")
KIMI_BASE_URL = os.getenv("KIMI_BASE_URL", "https://api.kimi.com/coding") KIMI_BASE_URL = os.getenv("KIMI_BASE_URL", "https://api.kimi.com/coding")
@@ -32,9 +32,9 @@ class ReasoningResult:
answer: str answer: str
reasoning_type: ReasoningType reasoning_type: ReasoningType
confidence: float confidence: float
evidence: List[Dict] # 支撑证据 evidence: list[dict] # 支撑证据
related_entities: List[str] # 相关实体 related_entities: list[str] # 相关实体
gaps: List[str] # 知识缺口 gaps: list[str] # 知识缺口
@dataclass @dataclass
@@ -43,7 +43,7 @@ class InferencePath:
start_entity: str start_entity: str
end_entity: str end_entity: str
path: List[Dict] # 路径上的节点和关系 path: list[dict] # 路径上的节点和关系
strength: float # 路径强度 strength: float # 路径强度
@@ -71,7 +71,7 @@ class KnowledgeReasoner:
return result["choices"][0]["message"]["content"] return result["choices"][0]["message"]["content"]
async def enhanced_qa( async def enhanced_qa(
self, query: str, project_context: Dict, graph_data: Dict, reasoning_depth: str = "medium" self, query: str, project_context: dict, graph_data: dict, reasoning_depth: str = "medium"
) -> ReasoningResult: ) -> ReasoningResult:
""" """
增强问答 - 结合图谱推理的问答 增强问答 - 结合图谱推理的问答
@@ -95,7 +95,7 @@ class KnowledgeReasoner:
else: else:
return await self._associative_reasoning(query, project_context, graph_data) return await self._associative_reasoning(query, project_context, graph_data)
async def _analyze_question(self, query: str) -> Dict: async def _analyze_question(self, query: str) -> dict:
"""分析问题类型和意图""" """分析问题类型和意图"""
prompt = f"""分析以下问题的类型和意图: prompt = f"""分析以下问题的类型和意图:
@@ -129,7 +129,7 @@ class KnowledgeReasoner:
return {"type": "factual", "entities": [], "intent": "general", "complexity": "simple"} return {"type": "factual", "entities": [], "intent": "general", "complexity": "simple"}
async def _causal_reasoning(self, query: str, project_context: Dict, graph_data: Dict) -> ReasoningResult: async def _causal_reasoning(self, query: str, project_context: dict, graph_data: dict) -> ReasoningResult:
"""因果推理 - 分析原因和影响""" """因果推理 - 分析原因和影响"""
# 构建因果分析提示 # 构建因果分析提示
@@ -190,7 +190,7 @@ class KnowledgeReasoner:
gaps=["无法完成因果推理"], gaps=["无法完成因果推理"],
) )
async def _comparative_reasoning(self, query: str, project_context: Dict, graph_data: Dict) -> ReasoningResult: async def _comparative_reasoning(self, query: str, project_context: dict, graph_data: dict) -> ReasoningResult:
"""对比推理 - 比较实体间的异同""" """对比推理 - 比较实体间的异同"""
prompt = f"""基于以下知识图谱进行对比分析: prompt = f"""基于以下知识图谱进行对比分析:
@@ -244,7 +244,7 @@ class KnowledgeReasoner:
gaps=[], gaps=[],
) )
async def _temporal_reasoning(self, query: str, project_context: Dict, graph_data: Dict) -> ReasoningResult: async def _temporal_reasoning(self, query: str, project_context: dict, graph_data: dict) -> ReasoningResult:
"""时序推理 - 分析时间线和演变""" """时序推理 - 分析时间线和演变"""
prompt = f"""基于以下知识图谱进行时序分析: prompt = f"""基于以下知识图谱进行时序分析:
@@ -298,7 +298,7 @@ class KnowledgeReasoner:
gaps=[], gaps=[],
) )
async def _associative_reasoning(self, query: str, project_context: Dict, graph_data: Dict) -> ReasoningResult: async def _associative_reasoning(self, query: str, project_context: dict, graph_data: dict) -> ReasoningResult:
"""关联推理 - 发现实体间的隐含关联""" """关联推理 - 发现实体间的隐含关联"""
prompt = f"""基于以下知识图谱进行关联分析: prompt = f"""基于以下知识图谱进行关联分析:
@@ -353,8 +353,8 @@ class KnowledgeReasoner:
) )
def find_inference_paths( def find_inference_paths(
self, start_entity: str, end_entity: str, graph_data: Dict, max_depth: int = 3 self, start_entity: str, end_entity: str, graph_data: dict, max_depth: int = 3
) -> List[InferencePath]: ) -> list[InferencePath]:
""" """
发现两个实体之间的推理路径 发现两个实体之间的推理路径
@@ -416,7 +416,7 @@ class KnowledgeReasoner:
paths.sort(key=lambda p: p.strength, reverse=True) paths.sort(key=lambda p: p.strength, reverse=True)
return paths return paths
def _calculate_path_strength(self, path: List[Dict]) -> float: def _calculate_path_strength(self, path: list[dict]) -> float:
"""计算路径强度""" """计算路径强度"""
if len(path) < 2: if len(path) < 2:
return 0.0 return 0.0
@@ -438,8 +438,8 @@ class KnowledgeReasoner:
return length_factor * confidence_factor return length_factor * confidence_factor
async def summarize_project( async def summarize_project(
self, project_context: Dict, graph_data: Dict, summary_type: str = "comprehensive" self, project_context: dict, graph_data: dict, summary_type: str = "comprehensive"
) -> Dict: ) -> dict:
""" """
项目智能总结 项目智能总结

View File

@@ -4,12 +4,13 @@ InsightFlow LLM Client - Phase 4
用于与 Kimi API 交互,支持 RAG 问答和 Agent 功能 用于与 Kimi API 交互,支持 RAG 问答和 Agent 功能
""" """
import os
import json import json
import httpx import os
from typing import List, Dict, AsyncGenerator from collections.abc import AsyncGenerator
from dataclasses import dataclass from dataclasses import dataclass
import httpx
KIMI_API_KEY = os.getenv("KIMI_API_KEY", "") KIMI_API_KEY = os.getenv("KIMI_API_KEY", "")
KIMI_BASE_URL = os.getenv("KIMI_BASE_URL", "https://api.kimi.com/coding") KIMI_BASE_URL = os.getenv("KIMI_BASE_URL", "https://api.kimi.com/coding")
@@ -44,7 +45,7 @@ class LLMClient:
self.base_url = base_url or KIMI_BASE_URL self.base_url = base_url or KIMI_BASE_URL
self.headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"} self.headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
async def chat(self, messages: List[ChatMessage], temperature: float = 0.3, stream: bool = False) -> str: async def chat(self, messages: list[ChatMessage], temperature: float = 0.3, stream: bool = False) -> str:
"""发送聊天请求""" """发送聊天请求"""
if not self.api_key: if not self.api_key:
raise ValueError("KIMI_API_KEY not set") raise ValueError("KIMI_API_KEY not set")
@@ -64,7 +65,7 @@ class LLMClient:
result = response.json() result = response.json()
return result["choices"][0]["message"]["content"] return result["choices"][0]["message"]["content"]
async def chat_stream(self, messages: List[ChatMessage], temperature: float = 0.3) -> AsyncGenerator[str, None]: async def chat_stream(self, messages: list[ChatMessage], temperature: float = 0.3) -> AsyncGenerator[str, None]:
"""流式聊天请求""" """流式聊天请求"""
if not self.api_key: if not self.api_key:
raise ValueError("KIMI_API_KEY not set") raise ValueError("KIMI_API_KEY not set")
@@ -96,7 +97,7 @@ class LLMClient:
async def extract_entities_with_confidence( async def extract_entities_with_confidence(
self, text: str self, text: str
) -> tuple[List[EntityExtractionResult], List[RelationExtractionResult]]: ) -> tuple[list[EntityExtractionResult], list[RelationExtractionResult]]:
"""提取实体和关系,带置信度分数""" """提取实体和关系,带置信度分数"""
prompt = f"""从以下会议文本中提取关键实体和它们之间的关系,以 JSON 格式返回: prompt = f"""从以下会议文本中提取关键实体和它们之间的关系,以 JSON 格式返回:
@@ -152,7 +153,7 @@ class LLMClient:
print(f"Parse extraction result failed: {e}") print(f"Parse extraction result failed: {e}")
return [], [] return [], []
async def rag_query(self, query: str, context: str, project_context: Dict) -> str: async def rag_query(self, query: str, context: str, project_context: dict) -> str:
"""RAG 问答 - 基于项目上下文回答问题""" """RAG 问答 - 基于项目上下文回答问题"""
prompt = f"""你是一个专业的项目分析助手。基于以下项目信息回答问题: prompt = f"""你是一个专业的项目分析助手。基于以下项目信息回答问题:
@@ -174,7 +175,7 @@ class LLMClient:
return await self.chat(messages, temperature=0.3) return await self.chat(messages, temperature=0.3)
async def agent_command(self, command: str, project_context: Dict) -> Dict: async def agent_command(self, command: str, project_context: dict) -> dict:
"""Agent 指令解析 - 将自然语言指令转换为结构化操作""" """Agent 指令解析 - 将自然语言指令转换为结构化操作"""
prompt = f"""解析以下用户指令,转换为结构化操作: prompt = f"""解析以下用户指令,转换为结构化操作:
@@ -214,7 +215,7 @@ class LLMClient:
except BaseException: except BaseException:
return {"intent": "unknown", "explanation": "解析失败"} return {"intent": "unknown", "explanation": "解析失败"}
async def analyze_entity_evolution(self, entity_name: str, mentions: List[Dict]) -> str: async def analyze_entity_evolution(self, entity_name: str, mentions: list[dict]) -> str:
"""分析实体在项目中的演变/态度变化""" """分析实体在项目中的演变/态度变化"""
mentions_text = "\n".join( mentions_text = "\n".join(
[f"[{m.get('created_at', '未知时间')}] {m.get('text_snippet', '')}" for m in mentions[:20]] # 限制数量 [f"[{m.get('created_at', '未知时间')}] {m.get('text_snippet', '')}" for m in mentions[:20]] # 限制数量

View File

@@ -10,14 +10,14 @@ InsightFlow Phase 8 - 全球化与本地化管理模块
作者: InsightFlow Team 作者: InsightFlow Team
""" """
import sqlite3
import json import json
import uuid
from datetime import datetime
from typing import Optional, List, Dict, Any
from dataclasses import dataclass
from enum import Enum
import logging import logging
import sqlite3
import uuid
from dataclasses import dataclass
from datetime import datetime
from enum import StrEnum
from typing import Any
try: try:
import pytz import pytz
@@ -36,7 +36,7 @@ except ImportError:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class LanguageCode(str, Enum): class LanguageCode(StrEnum):
"""支持的语言代码""" """支持的语言代码"""
EN = "en" EN = "en"
@@ -53,7 +53,7 @@ class LanguageCode(str, Enum):
HI = "hi" HI = "hi"
class RegionCode(str, Enum): class RegionCode(StrEnum):
"""区域代码""" """区域代码"""
GLOBAL = "global" GLOBAL = "global"
@@ -65,7 +65,7 @@ class RegionCode(str, Enum):
MIDDLE_EAST = "me" MIDDLE_EAST = "me"
class DataCenterRegion(str, Enum): class DataCenterRegion(StrEnum):
"""数据中心区域""" """数据中心区域"""
US_EAST = "us-east" US_EAST = "us-east"
@@ -79,7 +79,7 @@ class DataCenterRegion(str, Enum):
CN_EAST = "cn-east" CN_EAST = "cn-east"
class PaymentProvider(str, Enum): class PaymentProvider(StrEnum):
"""支付提供商""" """支付提供商"""
STRIPE = "stripe" STRIPE = "stripe"
@@ -96,7 +96,7 @@ class PaymentProvider(str, Enum):
UNIONPAY = "unionpay" UNIONPAY = "unionpay"
class CalendarType(str, Enum): class CalendarType(StrEnum):
"""日历类型""" """日历类型"""
GREGORIAN = "gregorian" GREGORIAN = "gregorian"
@@ -115,12 +115,12 @@ class Translation:
language: str language: str
value: str value: str
namespace: str namespace: str
context: Optional[str] context: str | None
created_at: datetime created_at: datetime
updated_at: datetime updated_at: datetime
is_reviewed: bool is_reviewed: bool
reviewed_by: Optional[str] reviewed_by: str | None
reviewed_at: Optional[datetime] reviewed_at: datetime | None
@dataclass @dataclass
@@ -131,7 +131,7 @@ class LanguageConfig:
is_rtl: bool is_rtl: bool
is_active: bool is_active: bool
is_default: bool is_default: bool
fallback_language: Optional[str] fallback_language: str | None
date_format: str date_format: str
time_format: str time_format: str
datetime_format: str datetime_format: str
@@ -150,8 +150,8 @@ class DataCenter:
endpoint: str endpoint: str
status: str status: str
priority: int priority: int
supported_regions: List[str] supported_regions: list[str]
capabilities: Dict[str, Any] capabilities: dict[str, Any]
created_at: datetime created_at: datetime
updated_at: datetime updated_at: datetime
@@ -161,7 +161,7 @@ class TenantDataCenterMapping:
id: str id: str
tenant_id: str tenant_id: str
primary_dc_id: str primary_dc_id: str
secondary_dc_id: Optional[str] secondary_dc_id: str | None
region_code: str region_code: str
data_residency: str data_residency: str
created_at: datetime created_at: datetime
@@ -173,15 +173,15 @@ class LocalizedPaymentMethod:
id: str id: str
provider: str provider: str
name: str name: str
name_local: Dict[str, str] name_local: dict[str, str]
supported_countries: List[str] supported_countries: list[str]
supported_currencies: List[str] supported_currencies: list[str]
is_active: bool is_active: bool
config: Dict[str, Any] config: dict[str, Any]
icon_url: Optional[str] icon_url: str | None
display_order: int display_order: int
min_amount: Optional[float] min_amount: float | None
max_amount: Optional[float] max_amount: float | None
created_at: datetime created_at: datetime
updated_at: datetime updated_at: datetime
@@ -191,20 +191,20 @@ class CountryConfig:
code: str code: str
code3: str code3: str
name: str name: str
name_local: Dict[str, str] name_local: dict[str, str]
region: str region: str
default_language: str default_language: str
supported_languages: List[str] supported_languages: list[str]
default_currency: str default_currency: str
supported_currencies: List[str] supported_currencies: list[str]
timezone: str timezone: str
calendar_type: str calendar_type: str
date_format: Optional[str] date_format: str | None
time_format: Optional[str] time_format: str | None
number_format: Optional[str] number_format: str | None
address_format: Optional[str] address_format: str | None
phone_format: Optional[str] phone_format: str | None
vat_rate: Optional[float] vat_rate: float | None
is_active: bool is_active: bool
@@ -213,7 +213,7 @@ class TimezoneConfig:
id: str id: str
timezone: str timezone: str
utc_offset: str utc_offset: str
dst_offset: Optional[str] dst_offset: str | None
country_code: str country_code: str
region: str region: str
is_active: bool is_active: bool
@@ -223,7 +223,7 @@ class TimezoneConfig:
class CurrencyConfig: class CurrencyConfig:
code: str code: str
name: str name: str
name_local: Dict[str, str] name_local: dict[str, str]
symbol: str symbol: str
decimal_places: int decimal_places: int
decimal_separator: str decimal_separator: str
@@ -236,13 +236,13 @@ class LocalizationSettings:
id: str id: str
tenant_id: str tenant_id: str
default_language: str default_language: str
supported_languages: List[str] supported_languages: list[str]
default_currency: str default_currency: str
supported_currencies: List[str] supported_currencies: list[str]
default_timezone: str default_timezone: str
default_date_format: Optional[str] default_date_format: str | None
default_time_format: Optional[str] default_time_format: str | None
default_number_format: Optional[str] default_number_format: str | None
calendar_type: str calendar_type: str
first_day_of_week: int first_day_of_week: int
region_code: str region_code: str
@@ -940,7 +940,7 @@ class LocalizationManager:
def get_translation( def get_translation(
self, key: str, language: str, namespace: str = "common", fallback: bool = True self, key: str, language: str, namespace: str = "common", fallback: bool = True
) -> Optional[str]: ) -> str | None:
conn = self._get_connection() conn = self._get_connection()
try: try:
cursor = conn.cursor() cursor = conn.cursor()
@@ -962,7 +962,7 @@ class LocalizationManager:
self._close_if_file_db(conn) self._close_if_file_db(conn)
def set_translation( def set_translation(
self, key: str, language: str, value: str, namespace: str = "common", context: Optional[str] = None self, key: str, language: str, value: str, namespace: str = "common", context: str | None = None
) -> Translation: ) -> Translation:
conn = self._get_connection() conn = self._get_connection()
try: try:
@@ -985,7 +985,7 @@ class LocalizationManager:
def _get_translation_internal( def _get_translation_internal(
self, conn: sqlite3.Connection, key: str, language: str, namespace: str self, conn: sqlite3.Connection, key: str, language: str, namespace: str
) -> Optional[Translation]: ) -> Translation | None:
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute( cursor.execute(
"SELECT * FROM translations WHERE key = ? AND language = ? AND namespace = ?", (key, language, namespace) "SELECT * FROM translations WHERE key = ? AND language = ? AND namespace = ?", (key, language, namespace)
@@ -1008,8 +1008,8 @@ class LocalizationManager:
self._close_if_file_db(conn) self._close_if_file_db(conn)
def list_translations( def list_translations(
self, language: Optional[str] = None, namespace: Optional[str] = None, limit: int = 1000, offset: int = 0 self, language: str | None = None, namespace: str | None = None, limit: int = 1000, offset: int = 0
) -> List[Translation]: ) -> list[Translation]:
conn = self._get_connection() conn = self._get_connection()
try: try:
cursor = conn.cursor() cursor = conn.cursor()
@@ -1029,7 +1029,7 @@ class LocalizationManager:
finally: finally:
self._close_if_file_db(conn) self._close_if_file_db(conn)
def get_language_config(self, code: str) -> Optional[LanguageConfig]: def get_language_config(self, code: str) -> LanguageConfig | None:
conn = self._get_connection() conn = self._get_connection()
try: try:
cursor = conn.cursor() cursor = conn.cursor()
@@ -1041,7 +1041,7 @@ class LocalizationManager:
finally: finally:
self._close_if_file_db(conn) self._close_if_file_db(conn)
def list_language_configs(self, active_only: bool = True) -> List[LanguageConfig]: def list_language_configs(self, active_only: bool = True) -> list[LanguageConfig]:
conn = self._get_connection() conn = self._get_connection()
try: try:
cursor = conn.cursor() cursor = conn.cursor()
@@ -1055,7 +1055,7 @@ class LocalizationManager:
finally: finally:
self._close_if_file_db(conn) self._close_if_file_db(conn)
def get_data_center(self, dc_id: str) -> Optional[DataCenter]: def get_data_center(self, dc_id: str) -> DataCenter | None:
conn = self._get_connection() conn = self._get_connection()
try: try:
cursor = conn.cursor() cursor = conn.cursor()
@@ -1067,7 +1067,7 @@ class LocalizationManager:
finally: finally:
self._close_if_file_db(conn) self._close_if_file_db(conn)
def get_data_center_by_region(self, region_code: str) -> Optional[DataCenter]: def get_data_center_by_region(self, region_code: str) -> DataCenter | None:
conn = self._get_connection() conn = self._get_connection()
try: try:
cursor = conn.cursor() cursor = conn.cursor()
@@ -1079,7 +1079,7 @@ class LocalizationManager:
finally: finally:
self._close_if_file_db(conn) self._close_if_file_db(conn)
def list_data_centers(self, status: Optional[str] = None, region: Optional[str] = None) -> List[DataCenter]: def list_data_centers(self, status: str | None = None, region: str | None = None) -> list[DataCenter]:
conn = self._get_connection() conn = self._get_connection()
try: try:
cursor = conn.cursor() cursor = conn.cursor()
@@ -1098,7 +1098,7 @@ class LocalizationManager:
finally: finally:
self._close_if_file_db(conn) self._close_if_file_db(conn)
def get_tenant_data_center(self, tenant_id: str) -> Optional[TenantDataCenterMapping]: def get_tenant_data_center(self, tenant_id: str) -> TenantDataCenterMapping | None:
conn = self._get_connection() conn = self._get_connection()
try: try:
cursor = conn.cursor() cursor = conn.cursor()
@@ -1159,7 +1159,7 @@ class LocalizationManager:
finally: finally:
self._close_if_file_db(conn) self._close_if_file_db(conn)
def get_payment_method(self, provider: str) -> Optional[LocalizedPaymentMethod]: def get_payment_method(self, provider: str) -> LocalizedPaymentMethod | None:
conn = self._get_connection() conn = self._get_connection()
try: try:
cursor = conn.cursor() cursor = conn.cursor()
@@ -1172,8 +1172,8 @@ class LocalizationManager:
self._close_if_file_db(conn) self._close_if_file_db(conn)
def list_payment_methods( def list_payment_methods(
self, country_code: Optional[str] = None, currency: Optional[str] = None, active_only: bool = True self, country_code: str | None = None, currency: str | None = None, active_only: bool = True
) -> List[LocalizedPaymentMethod]: ) -> list[LocalizedPaymentMethod]:
conn = self._get_connection() conn = self._get_connection()
try: try:
cursor = conn.cursor() cursor = conn.cursor()
@@ -1194,7 +1194,7 @@ class LocalizationManager:
finally: finally:
self._close_if_file_db(conn) self._close_if_file_db(conn)
def get_localized_payment_methods(self, country_code: str, language: str = "en") -> List[Dict[str, Any]]: def get_localized_payment_methods(self, country_code: str, language: str = "en") -> list[dict[str, Any]]:
methods = self.list_payment_methods(country_code=country_code) methods = self.list_payment_methods(country_code=country_code)
result = [] result = []
for method in methods: for method in methods:
@@ -1212,7 +1212,7 @@ class LocalizationManager:
) )
return result return result
def get_country_config(self, code: str) -> Optional[CountryConfig]: def get_country_config(self, code: str) -> CountryConfig | None:
conn = self._get_connection() conn = self._get_connection()
try: try:
cursor = conn.cursor() cursor = conn.cursor()
@@ -1224,7 +1224,7 @@ class LocalizationManager:
finally: finally:
self._close_if_file_db(conn) self._close_if_file_db(conn)
def list_country_configs(self, region: Optional[str] = None, active_only: bool = True) -> List[CountryConfig]: def list_country_configs(self, region: str | None = None, active_only: bool = True) -> list[CountryConfig]:
conn = self._get_connection() conn = self._get_connection()
try: try:
cursor = conn.cursor() cursor = conn.cursor()
@@ -1243,7 +1243,7 @@ class LocalizationManager:
self._close_if_file_db(conn) self._close_if_file_db(conn)
def format_datetime( def format_datetime(
self, dt: datetime, language: str = "en", timezone: Optional[str] = None, format_type: str = "datetime" self, dt: datetime, language: str = "en", timezone: str | None = None, format_type: str = "datetime"
) -> str: ) -> str:
try: try:
if timezone and PYTZ_AVAILABLE: if timezone and PYTZ_AVAILABLE:
@@ -1276,7 +1276,7 @@ class LocalizationManager:
logger.error(f"Error formatting datetime: {e}") logger.error(f"Error formatting datetime: {e}")
return dt.strftime("%Y-%m-%d %H:%M") return dt.strftime("%Y-%m-%d %H:%M")
def format_number(self, number: float, language: str = "en", decimal_places: Optional[int] = None) -> str: def format_number(self, number: float, language: str = "en", decimal_places: int | None = None) -> str:
try: try:
if BABEL_AVAILABLE: if BABEL_AVAILABLE:
try: try:
@@ -1319,7 +1319,7 @@ class LocalizationManager:
logger.error(f"Error converting timezone: {e}") logger.error(f"Error converting timezone: {e}")
return dt return dt
def get_calendar_info(self, calendar_type: str, year: int, month: int) -> Dict[str, Any]: def get_calendar_info(self, calendar_type: str, year: int, month: int) -> dict[str, Any]:
import calendar import calendar
cal = calendar.Calendar() cal = calendar.Calendar()
@@ -1334,7 +1334,7 @@ class LocalizationManager:
"weeks": month_days, "weeks": month_days,
} }
def get_localization_settings(self, tenant_id: str) -> Optional[LocalizationSettings]: def get_localization_settings(self, tenant_id: str) -> LocalizationSettings | None:
conn = self._get_connection() conn = self._get_connection()
try: try:
cursor = conn.cursor() cursor = conn.cursor()
@@ -1350,9 +1350,9 @@ class LocalizationManager:
self, self,
tenant_id: str, tenant_id: str,
default_language: str = "en", default_language: str = "en",
supported_languages: Optional[List[str]] = None, supported_languages: list[str] | None = None,
default_currency: str = "USD", default_currency: str = "USD",
supported_currencies: Optional[List[str]] = None, supported_currencies: list[str] | None = None,
default_timezone: str = "UTC", default_timezone: str = "UTC",
region_code: str = "global", region_code: str = "global",
data_residency: str = "regional", data_residency: str = "regional",
@@ -1397,7 +1397,7 @@ class LocalizationManager:
finally: finally:
self._close_if_file_db(conn) self._close_if_file_db(conn)
def update_localization_settings(self, tenant_id: str, **kwargs) -> Optional[LocalizationSettings]: def update_localization_settings(self, tenant_id: str, **kwargs) -> LocalizationSettings | None:
conn = self._get_connection() conn = self._get_connection()
try: try:
settings = self.get_localization_settings(tenant_id) settings = self.get_localization_settings(tenant_id)
@@ -1441,8 +1441,8 @@ class LocalizationManager:
self._close_if_file_db(conn) self._close_if_file_db(conn)
def detect_user_preferences( def detect_user_preferences(
self, accept_language: Optional[str] = None, ip_country: Optional[str] = None self, accept_language: str | None = None, ip_country: str | None = None
) -> Dict[str, str]: ) -> dict[str, str]:
preferences = {"language": "en", "country": "US", "timezone": "UTC", "currency": "USD"} preferences = {"language": "en", "country": "US", "timezone": "UTC", "currency": "USD"}
if accept_language: if accept_language:
langs = accept_language.split(",") langs = accept_language.split(",")

File diff suppressed because it is too large Load Diff

View File

@@ -5,7 +5,6 @@ InsightFlow Multimodal Entity Linker - Phase 7
""" """
import uuid import uuid
from typing import List, Dict, Optional, Tuple, Set
from dataclasses import dataclass from dataclasses import dataclass
from difflib import SequenceMatcher from difflib import SequenceMatcher
@@ -28,7 +27,7 @@ class MultimodalEntity:
source_id: str source_id: str
mention_context: str mention_context: str
confidence: float confidence: float
modality_features: Dict = None # 模态特定特征 modality_features: dict = None # 模态特定特征
def __post_init__(self): def __post_init__(self):
if self.modality_features is None: if self.modality_features is None:
@@ -55,7 +54,7 @@ class AlignmentResult:
"""对齐结果""" """对齐结果"""
entity_id: str entity_id: str
matched_entity_id: Optional[str] matched_entity_id: str | None
similarity: float similarity: float
match_type: str # exact, fuzzy, embedding match_type: str # exact, fuzzy, embedding
confidence: float confidence: float
@@ -66,9 +65,9 @@ class FusionResult:
"""知识融合结果""" """知识融合结果"""
canonical_entity_id: str canonical_entity_id: str
merged_entity_ids: List[str] merged_entity_ids: list[str]
fused_properties: Dict fused_properties: dict
source_modalities: List[str] source_modalities: list[str]
confidence: float confidence: float
@@ -117,7 +116,7 @@ class MultimodalEntityLinker:
# 编辑距离相似度 # 编辑距离相似度
return SequenceMatcher(None, s1, s2).ratio() return SequenceMatcher(None, s1, s2).ratio()
def calculate_entity_similarity(self, entity1: Dict, entity2: Dict) -> Tuple[float, str]: def calculate_entity_similarity(self, entity1: dict, entity2: dict) -> tuple[float, str]:
""" """
计算两个实体的综合相似度 计算两个实体的综合相似度
@@ -159,8 +158,8 @@ class MultimodalEntityLinker:
return combined_sim, "none" return combined_sim, "none"
def find_matching_entity( def find_matching_entity(
self, query_entity: Dict, candidate_entities: List[Dict], exclude_ids: Set[str] = None self, query_entity: dict, candidate_entities: list[dict], exclude_ids: set[str] = None
) -> Optional[AlignmentResult]: ) -> AlignmentResult | None:
""" """
在候选实体中查找匹配的实体 在候选实体中查找匹配的实体
@@ -201,11 +200,11 @@ class MultimodalEntityLinker:
def align_cross_modal_entities( def align_cross_modal_entities(
self, self,
project_id: str, project_id: str,
audio_entities: List[Dict], audio_entities: list[dict],
video_entities: List[Dict], video_entities: list[dict],
image_entities: List[Dict], image_entities: list[dict],
document_entities: List[Dict], document_entities: list[dict],
) -> List[EntityLink]: ) -> list[EntityLink]:
""" """
跨模态实体对齐 跨模态实体对齐
@@ -259,7 +258,7 @@ class MultimodalEntityLinker:
return links return links
def fuse_entity_knowledge( def fuse_entity_knowledge(
self, entity_id: str, linked_entities: List[Dict], multimodal_mentions: List[Dict] self, entity_id: str, linked_entities: list[dict], multimodal_mentions: list[dict]
) -> FusionResult: ) -> FusionResult:
""" """
融合多模态实体知识 融合多模态实体知识
@@ -331,7 +330,7 @@ class MultimodalEntityLinker:
confidence=min(1.0, len(linked_entities) * 0.2 + 0.5), confidence=min(1.0, len(linked_entities) * 0.2 + 0.5),
) )
def detect_entity_conflicts(self, entities: List[Dict]) -> List[Dict]: def detect_entity_conflicts(self, entities: list[dict]) -> list[dict]:
""" """
检测实体冲突(同名但不同义) 检测实体冲突(同名但不同义)
@@ -380,7 +379,7 @@ class MultimodalEntityLinker:
return conflicts return conflicts
def suggest_entity_merges(self, entities: List[Dict], existing_links: List[EntityLink] = None) -> List[Dict]: def suggest_entity_merges(self, entities: list[dict], existing_links: list[EntityLink] = None) -> list[dict]:
""" """
建议实体合并 建议实体合并
@@ -464,7 +463,7 @@ class MultimodalEntityLinker:
confidence=confidence, confidence=confidence,
) )
def analyze_modality_distribution(self, multimodal_entities: List[MultimodalEntity]) -> Dict: def analyze_modality_distribution(self, multimodal_entities: list[MultimodalEntity]) -> dict:
""" """
分析模态分布 分析模态分布

View File

@@ -4,12 +4,11 @@ InsightFlow Multimodal Processor - Phase 7
视频处理模块提取音频、关键帧、OCR识别 视频处理模块提取音频、关键帧、OCR识别
""" """
import os
import json import json
import uuid import os
import tempfile
import subprocess import subprocess
from typing import List, Dict, Tuple import tempfile
import uuid
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
@@ -48,7 +47,7 @@ class VideoFrame:
frame_path: str frame_path: str
ocr_text: str = "" ocr_text: str = ""
ocr_confidence: float = 0.0 ocr_confidence: float = 0.0
entities_detected: List[Dict] = None entities_detected: list[dict] = None
def __post_init__(self): def __post_init__(self):
if self.entities_detected is None: if self.entities_detected is None:
@@ -72,7 +71,7 @@ class VideoInfo:
transcript_id: str = "" transcript_id: str = ""
status: str = "pending" status: str = "pending"
error_message: str = "" error_message: str = ""
metadata: Dict = None metadata: dict = None
def __post_init__(self): def __post_init__(self):
if self.metadata is None: if self.metadata is None:
@@ -85,8 +84,8 @@ class VideoProcessingResult:
video_id: str video_id: str
audio_path: str audio_path: str
frames: List[VideoFrame] frames: list[VideoFrame]
ocr_results: List[Dict] ocr_results: list[dict]
full_text: str # 整合的文本(音频转录 + OCR文本 full_text: str # 整合的文本(音频转录 + OCR文本
success: bool success: bool
error_message: str = "" error_message: str = ""
@@ -114,7 +113,7 @@ class MultimodalProcessor:
os.makedirs(self.frames_dir, exist_ok=True) os.makedirs(self.frames_dir, exist_ok=True)
os.makedirs(self.audio_dir, exist_ok=True) os.makedirs(self.audio_dir, exist_ok=True)
def extract_video_info(self, video_path: str) -> Dict: def extract_video_info(self, video_path: str) -> dict:
""" """
提取视频基本信息 提取视频基本信息
@@ -215,7 +214,7 @@ class MultimodalProcessor:
print(f"Error extracting audio: {e}") print(f"Error extracting audio: {e}")
raise raise
def extract_keyframes(self, video_path: str, video_id: str, interval: int = None) -> List[str]: def extract_keyframes(self, video_path: str, video_id: str, interval: int = None) -> list[str]:
""" """
从视频中提取关键帧 从视频中提取关键帧
@@ -275,7 +274,7 @@ class MultimodalProcessor:
return frame_paths return frame_paths
def perform_ocr(self, image_path: str) -> Tuple[str, float]: def perform_ocr(self, image_path: str) -> tuple[str, float]:
""" """
对图片进行OCR识别 对图片进行OCR识别

View File

@@ -5,10 +5,9 @@ Phase 5: Neo4j 图数据库集成
支持数据同步、复杂图查询和图算法分析 支持数据同步、复杂图查询和图算法分析
""" """
import os
import json import json
import logging import logging
from typing import List, Dict, Optional import os
from dataclasses import dataclass from dataclasses import dataclass
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -20,7 +19,7 @@ NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD", "password")
# 延迟导入,避免未安装时出错 # 延迟导入,避免未安装时出错
try: try:
from neo4j import GraphDatabase, Driver from neo4j import Driver, GraphDatabase
NEO4J_AVAILABLE = True NEO4J_AVAILABLE = True
except ImportError: except ImportError:
@@ -37,8 +36,8 @@ class GraphEntity:
name: str name: str
type: str type: str
definition: str = "" definition: str = ""
aliases: List[str] = None aliases: list[str] = None
properties: Dict = None properties: dict = None
def __post_init__(self): def __post_init__(self):
if self.aliases is None: if self.aliases is None:
@@ -56,7 +55,7 @@ class GraphRelation:
target_id: str target_id: str
relation_type: str relation_type: str
evidence: str = "" evidence: str = ""
properties: Dict = None properties: dict = None
def __post_init__(self): def __post_init__(self):
if self.properties is None: if self.properties is None:
@@ -67,8 +66,8 @@ class GraphRelation:
class PathResult: class PathResult:
"""路径查询结果""" """路径查询结果"""
nodes: List[Dict] nodes: list[dict]
relationships: List[Dict] relationships: list[dict]
length: int length: int
total_weight: float = 0.0 total_weight: float = 0.0
@@ -78,7 +77,7 @@ class CommunityResult:
"""社区发现结果""" """社区发现结果"""
community_id: int community_id: int
nodes: List[Dict] nodes: list[dict]
size: int size: int
density: float = 0.0 density: float = 0.0
@@ -100,7 +99,7 @@ class Neo4jManager:
self.uri = uri or NEO4J_URI self.uri = uri or NEO4J_URI
self.user = user or NEO4J_USER self.user = user or NEO4J_USER
self.password = password or NEO4J_PASSWORD self.password = password or NEO4J_PASSWORD
self._driver: Optional["Driver"] = None self._driver: Driver | None = None
if not NEO4J_AVAILABLE: if not NEO4J_AVAILABLE:
logger.error("Neo4j driver not available. Please install: pip install neo4j") logger.error("Neo4j driver not available. Please install: pip install neo4j")
@@ -226,7 +225,7 @@ class Neo4jManager:
properties=json.dumps(entity.properties), properties=json.dumps(entity.properties),
) )
def sync_entities_batch(self, entities: List[GraphEntity]): def sync_entities_batch(self, entities: list[GraphEntity]):
"""批量同步实体到 Neo4j""" """批量同步实体到 Neo4j"""
if not self._driver or not entities: if not self._driver or not entities:
return return
@@ -287,7 +286,7 @@ class Neo4jManager:
properties=json.dumps(relation.properties), properties=json.dumps(relation.properties),
) )
def sync_relations_batch(self, relations: List[GraphRelation]): def sync_relations_batch(self, relations: list[GraphRelation]):
"""批量同步关系到 Neo4j""" """批量同步关系到 Neo4j"""
if not self._driver or not relations: if not self._driver or not relations:
return return
@@ -350,7 +349,7 @@ class Neo4jManager:
# ==================== 复杂图查询 ==================== # ==================== 复杂图查询 ====================
def find_shortest_path(self, source_id: str, target_id: str, max_depth: int = 10) -> Optional[PathResult]: def find_shortest_path(self, source_id: str, target_id: str, max_depth: int = 10) -> PathResult | None:
""" """
查找两个实体之间的最短路径 查找两个实体之间的最短路径
@@ -399,7 +398,7 @@ class Neo4jManager:
return PathResult(nodes=nodes, relationships=relationships, length=len(path.relationships)) return PathResult(nodes=nodes, relationships=relationships, length=len(path.relationships))
def find_all_paths(self, source_id: str, target_id: str, max_depth: int = 5, limit: int = 10) -> List[PathResult]: def find_all_paths(self, source_id: str, target_id: str, max_depth: int = 5, limit: int = 10) -> list[PathResult]:
""" """
查找两个实体之间的所有路径 查找两个实体之间的所有路径
@@ -449,7 +448,7 @@ class Neo4jManager:
return paths return paths
def find_neighbors(self, entity_id: str, relation_type: str = None, limit: int = 50) -> List[Dict]: def find_neighbors(self, entity_id: str, relation_type: str = None, limit: int = 50) -> list[dict]:
""" """
查找实体的邻居节点 查找实体的邻居节点
@@ -502,7 +501,7 @@ class Neo4jManager:
return neighbors return neighbors
def find_common_neighbors(self, entity_id1: str, entity_id2: str) -> List[Dict]: def find_common_neighbors(self, entity_id1: str, entity_id2: str) -> list[dict]:
""" """
查找两个实体的共同邻居(潜在关联) 查找两个实体的共同邻居(潜在关联)
@@ -533,7 +532,7 @@ class Neo4jManager:
# ==================== 图算法分析 ==================== # ==================== 图算法分析 ====================
def calculate_pagerank(self, project_id: str, top_n: int = 20) -> List[CentralityResult]: def calculate_pagerank(self, project_id: str, top_n: int = 20) -> list[CentralityResult]:
""" """
计算 PageRank 中心性 计算 PageRank 中心性
@@ -619,7 +618,7 @@ class Neo4jManager:
return rankings return rankings
def calculate_betweenness(self, project_id: str, top_n: int = 20) -> List[CentralityResult]: def calculate_betweenness(self, project_id: str, top_n: int = 20) -> list[CentralityResult]:
""" """
计算 Betweenness 中心性(桥梁作用) 计算 Betweenness 中心性(桥梁作用)
@@ -663,7 +662,7 @@ class Neo4jManager:
return rankings return rankings
def detect_communities(self, project_id: str) -> List[CommunityResult]: def detect_communities(self, project_id: str) -> list[CommunityResult]:
""" """
社区发现(使用 Louvain 算法) 社区发现(使用 Louvain 算法)
@@ -733,7 +732,7 @@ class Neo4jManager:
results.sort(key=lambda x: x.size, reverse=True) results.sort(key=lambda x: x.size, reverse=True)
return results return results
def find_central_entities(self, project_id: str, metric: str = "degree") -> List[CentralityResult]: def find_central_entities(self, project_id: str, metric: str = "degree") -> list[CentralityResult]:
""" """
查找中心实体 查找中心实体
@@ -791,7 +790,7 @@ class Neo4jManager:
# ==================== 图统计 ==================== # ==================== 图统计 ====================
def get_graph_stats(self, project_id: str) -> Dict: def get_graph_stats(self, project_id: str) -> dict:
""" """
获取项目的图统计信息 获取项目的图统计信息
@@ -870,7 +869,7 @@ class Neo4jManager:
"density": round(relation_count / (entity_count * (entity_count - 1)), 4) if entity_count > 1 else 0, "density": round(relation_count / (entity_count * (entity_count - 1)), 4) if entity_count > 1 else 0,
} }
def get_subgraph(self, entity_ids: List[str], depth: int = 1) -> Dict: def get_subgraph(self, entity_ids: list[str], depth: int = 1) -> dict:
""" """
获取指定实体的子图 获取指定实体的子图
@@ -959,7 +958,7 @@ def close_neo4j_manager():
# 便捷函数 # 便捷函数
def sync_project_to_neo4j(project_id: str, project_name: str, entities: List[Dict], relations: List[Dict]): def sync_project_to_neo4j(project_id: str, project_name: str, entities: list[dict], relations: list[dict]):
""" """
同步整个项目到 Neo4j 同步整个项目到 Neo4j

View File

@@ -10,26 +10,27 @@ InsightFlow Operations & Monitoring Manager - Phase 8 Task 8
作者: InsightFlow Team 作者: InsightFlow Team
""" """
import os
import json
import sqlite3
import httpx
import asyncio import asyncio
import hashlib import hashlib
import uuid import json
import os
import re import re
import time import sqlite3
import statistics import statistics
from typing import List, Dict, Optional, Tuple, Callable import time
import uuid
from collections.abc import Callable
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime, timedelta from datetime import datetime, timedelta
from enum import Enum from enum import StrEnum
import httpx
# Database path # Database path
DB_PATH = os.path.join(os.path.dirname(__file__), "insightflow.db") DB_PATH = os.path.join(os.path.dirname(__file__), "insightflow.db")
class AlertSeverity(str, Enum): class AlertSeverity(StrEnum):
"""告警严重级别 P0-P3""" """告警严重级别 P0-P3"""
P0 = "p0" # 紧急 - 系统不可用,需要立即处理 P0 = "p0" # 紧急 - 系统不可用,需要立即处理
@@ -38,7 +39,7 @@ class AlertSeverity(str, Enum):
P3 = "p3" # 轻微 - 非核心功能问题24小时内处理 P3 = "p3" # 轻微 - 非核心功能问题24小时内处理
class AlertStatus(str, Enum): class AlertStatus(StrEnum):
"""告警状态""" """告警状态"""
FIRING = "firing" # 正在告警 FIRING = "firing" # 正在告警
@@ -47,7 +48,7 @@ class AlertStatus(str, Enum):
SUPPRESSED = "suppressed" # 已抑制 SUPPRESSED = "suppressed" # 已抑制
class AlertChannelType(str, Enum): class AlertChannelType(StrEnum):
"""告警渠道类型""" """告警渠道类型"""
PAGERDUTY = "pagerduty" PAGERDUTY = "pagerduty"
@@ -60,7 +61,7 @@ class AlertChannelType(str, Enum):
WEBHOOK = "webhook" WEBHOOK = "webhook"
class AlertRuleType(str, Enum): class AlertRuleType(StrEnum):
"""告警规则类型""" """告警规则类型"""
THRESHOLD = "threshold" # 阈值告警 THRESHOLD = "threshold" # 阈值告警
@@ -69,7 +70,7 @@ class AlertRuleType(str, Enum):
COMPOSITE = "composite" # 复合告警 COMPOSITE = "composite" # 复合告警
class ResourceType(str, Enum): class ResourceType(StrEnum):
"""资源类型""" """资源类型"""
CPU = "cpu" CPU = "cpu"
@@ -82,7 +83,7 @@ class ResourceType(str, Enum):
QUEUE = "queue" QUEUE = "queue"
class ScalingAction(str, Enum): class ScalingAction(StrEnum):
"""扩缩容动作""" """扩缩容动作"""
SCALE_UP = "scale_up" # 扩容 SCALE_UP = "scale_up" # 扩容
@@ -90,7 +91,7 @@ class ScalingAction(str, Enum):
MAINTAIN = "maintain" # 保持 MAINTAIN = "maintain" # 保持
class HealthStatus(str, Enum): class HealthStatus(StrEnum):
"""健康状态""" """健康状态"""
HEALTHY = "healthy" HEALTHY = "healthy"
@@ -99,7 +100,7 @@ class HealthStatus(str, Enum):
UNKNOWN = "unknown" UNKNOWN = "unknown"
class BackupStatus(str, Enum): class BackupStatus(StrEnum):
"""备份状态""" """备份状态"""
PENDING = "pending" PENDING = "pending"
@@ -124,9 +125,9 @@ class AlertRule:
threshold: float threshold: float
duration: int # 持续时间(秒) duration: int # 持续时间(秒)
evaluation_interval: int # 评估间隔(秒) evaluation_interval: int # 评估间隔(秒)
channels: List[str] # 告警渠道ID列表 channels: list[str] # 告警渠道ID列表
labels: Dict[str, str] # 标签 labels: dict[str, str] # 标签
annotations: Dict[str, str] # 注释 annotations: dict[str, str] # 注释
is_enabled: bool is_enabled: bool
created_at: str created_at: str
updated_at: str updated_at: str
@@ -141,12 +142,12 @@ class AlertChannel:
tenant_id: str tenant_id: str
name: str name: str
channel_type: AlertChannelType channel_type: AlertChannelType
config: Dict # 渠道特定配置 config: dict # 渠道特定配置
severity_filter: List[str] # 过滤的告警级别 severity_filter: list[str] # 过滤的告警级别
is_enabled: bool is_enabled: bool
success_count: int success_count: int
fail_count: int fail_count: int
last_used_at: Optional[str] last_used_at: str | None
created_at: str created_at: str
updated_at: str updated_at: str
@@ -165,13 +166,13 @@ class Alert:
metric: str metric: str
value: float value: float
threshold: float threshold: float
labels: Dict[str, str] labels: dict[str, str]
annotations: Dict[str, str] annotations: dict[str, str]
started_at: str started_at: str
resolved_at: Optional[str] resolved_at: str | None
acknowledged_by: Optional[str] acknowledged_by: str | None
acknowledged_at: Optional[str] acknowledged_at: str | None
notification_sent: Dict[str, bool] # 渠道发送状态 notification_sent: dict[str, bool] # 渠道发送状态
suppression_count: int # 抑制计数 suppression_count: int # 抑制计数
@@ -182,11 +183,11 @@ class AlertSuppressionRule:
id: str id: str
tenant_id: str tenant_id: str
name: str name: str
matchers: Dict[str, str] # 匹配条件 matchers: dict[str, str] # 匹配条件
duration: int # 抑制持续时间(秒) duration: int # 抑制持续时间(秒)
is_regex: bool # 是否使用正则匹配 is_regex: bool # 是否使用正则匹配
created_at: str created_at: str
expires_at: Optional[str] expires_at: str | None
@dataclass @dataclass
@@ -196,7 +197,7 @@ class AlertGroup:
id: str id: str
tenant_id: str tenant_id: str
group_key: str # 聚合键 group_key: str # 聚合键
alerts: List[str] # 告警ID列表 alerts: list[str] # 告警ID列表
created_at: str created_at: str
updated_at: str updated_at: str
@@ -213,7 +214,7 @@ class ResourceMetric:
metric_value: float metric_value: float
unit: str unit: str
timestamp: str timestamp: str
metadata: Dict metadata: dict
@dataclass @dataclass
@@ -267,8 +268,8 @@ class ScalingEvent:
triggered_by: str # 触发来源: manual, auto, scheduled triggered_by: str # 触发来源: manual, auto, scheduled
status: str # pending, in_progress, completed, failed status: str # pending, in_progress, completed, failed
started_at: str started_at: str
completed_at: Optional[str] completed_at: str | None
error_message: Optional[str] error_message: str | None
@dataclass @dataclass
@@ -281,7 +282,7 @@ class HealthCheck:
target_type: str # service, database, api, etc. target_type: str # service, database, api, etc.
target_id: str target_id: str
check_type: str # http, tcp, ping, custom check_type: str # http, tcp, ping, custom
check_config: Dict # 检查配置 check_config: dict # 检查配置
interval: int # 检查间隔(秒) interval: int # 检查间隔(秒)
timeout: int # 超时时间(秒) timeout: int # 超时时间(秒)
retry_count: int retry_count: int
@@ -302,7 +303,7 @@ class HealthCheckResult:
status: HealthStatus status: HealthStatus
response_time: float # 响应时间(毫秒) response_time: float # 响应时间(毫秒)
message: str message: str
details: Dict details: dict
checked_at: str checked_at: str
@@ -314,7 +315,7 @@ class FailoverConfig:
tenant_id: str tenant_id: str
name: str name: str
primary_region: str primary_region: str
secondary_regions: List[str] # 备用区域列表 secondary_regions: list[str] # 备用区域列表
failover_trigger: str # 触发条件 failover_trigger: str # 触发条件
auto_failover: bool auto_failover: bool
failover_timeout: int # 故障转移超时(秒) failover_timeout: int # 故障转移超时(秒)
@@ -336,8 +337,8 @@ class FailoverEvent:
reason: str reason: str
status: str # initiated, in_progress, completed, failed, rolled_back status: str # initiated, in_progress, completed, failed, rolled_back
started_at: str started_at: str
completed_at: Optional[str] completed_at: str | None
rolled_back_at: Optional[str] rolled_back_at: str | None
@dataclass @dataclass
@@ -371,9 +372,9 @@ class BackupRecord:
size_bytes: int size_bytes: int
checksum: str checksum: str
started_at: str started_at: str
completed_at: Optional[str] completed_at: str | None
verified_at: Optional[str] verified_at: str | None
error_message: Optional[str] error_message: str | None
storage_path: str storage_path: str
@@ -386,9 +387,9 @@ class CostReport:
report_period: str # YYYY-MM report_period: str # YYYY-MM
total_cost: float total_cost: float
currency: str currency: str
breakdown: Dict[str, float] # 按资源类型分解 breakdown: dict[str, float] # 按资源类型分解
trends: Dict # 趋势数据 trends: dict # 趋势数据
anomalies: List[Dict] # 异常检测 anomalies: list[dict] # 异常检测
created_at: str created_at: str
@@ -405,7 +406,7 @@ class ResourceUtilization:
avg_utilization: float avg_utilization: float
idle_time_percent: float idle_time_percent: float
report_date: str report_date: str
recommendations: List[str] recommendations: list[str]
@dataclass @dataclass
@@ -438,11 +439,11 @@ class CostOptimizationSuggestion:
currency: str currency: str
confidence: float confidence: float
difficulty: str # easy, medium, hard difficulty: str # easy, medium, hard
implementation_steps: List[str] implementation_steps: list[str]
risk_level: str # low, medium, high risk_level: str # low, medium, high
is_applied: bool is_applied: bool
created_at: str created_at: str
applied_at: Optional[str] applied_at: str | None
class OpsManager: class OpsManager:
@@ -450,7 +451,7 @@ class OpsManager:
def __init__(self, db_path: str = DB_PATH): def __init__(self, db_path: str = DB_PATH):
self.db_path = db_path self.db_path = db_path
self._alert_evaluators: Dict[str, Callable] = {} self._alert_evaluators: dict[str, Callable] = {}
self._running = False self._running = False
self._evaluator_thread = None self._evaluator_thread = None
self._register_default_evaluators() self._register_default_evaluators()
@@ -481,9 +482,9 @@ class OpsManager:
threshold: float, threshold: float,
duration: int, duration: int,
evaluation_interval: int, evaluation_interval: int,
channels: List[str], channels: list[str],
labels: Dict, labels: dict,
annotations: Dict, annotations: dict,
created_by: str, created_by: str,
) -> AlertRule: ) -> AlertRule:
"""创建告警规则""" """创建告警规则"""
@@ -545,7 +546,7 @@ class OpsManager:
return rule return rule
def get_alert_rule(self, rule_id: str) -> Optional[AlertRule]: def get_alert_rule(self, rule_id: str) -> AlertRule | None:
"""获取告警规则""" """获取告警规则"""
with self._get_db() as conn: with self._get_db() as conn:
row = conn.execute("SELECT * FROM alert_rules WHERE id = ?", (rule_id,)).fetchone() row = conn.execute("SELECT * FROM alert_rules WHERE id = ?", (rule_id,)).fetchone()
@@ -554,7 +555,7 @@ class OpsManager:
return self._row_to_alert_rule(row) return self._row_to_alert_rule(row)
return None return None
def list_alert_rules(self, tenant_id: str, is_enabled: Optional[bool] = None) -> List[AlertRule]: def list_alert_rules(self, tenant_id: str, is_enabled: bool | None = None) -> list[AlertRule]:
"""列出租户的所有告警规则""" """列出租户的所有告警规则"""
query = "SELECT * FROM alert_rules WHERE tenant_id = ?" query = "SELECT * FROM alert_rules WHERE tenant_id = ?"
params = [tenant_id] params = [tenant_id]
@@ -569,7 +570,7 @@ class OpsManager:
rows = conn.execute(query, params).fetchall() rows = conn.execute(query, params).fetchall()
return [self._row_to_alert_rule(row) for row in rows] return [self._row_to_alert_rule(row) for row in rows]
def update_alert_rule(self, rule_id: str, **kwargs) -> Optional[AlertRule]: def update_alert_rule(self, rule_id: str, **kwargs) -> AlertRule | None:
"""更新告警规则""" """更新告警规则"""
allowed_fields = [ allowed_fields = [
"name", "name",
@@ -619,7 +620,7 @@ class OpsManager:
# ==================== 告警渠道管理 ==================== # ==================== 告警渠道管理 ====================
def create_alert_channel( def create_alert_channel(
self, tenant_id: str, name: str, channel_type: AlertChannelType, config: Dict, severity_filter: List[str] = None self, tenant_id: str, name: str, channel_type: AlertChannelType, config: dict, severity_filter: list[str] = None
) -> AlertChannel: ) -> AlertChannel:
"""创建告警渠道""" """创建告警渠道"""
channel_id = f"ac_{uuid.uuid4().hex[:16]}" channel_id = f"ac_{uuid.uuid4().hex[:16]}"
@@ -667,7 +668,7 @@ class OpsManager:
return channel return channel
def get_alert_channel(self, channel_id: str) -> Optional[AlertChannel]: def get_alert_channel(self, channel_id: str) -> AlertChannel | None:
"""获取告警渠道""" """获取告警渠道"""
with self._get_db() as conn: with self._get_db() as conn:
row = conn.execute("SELECT * FROM alert_channels WHERE id = ?", (channel_id,)).fetchone() row = conn.execute("SELECT * FROM alert_channels WHERE id = ?", (channel_id,)).fetchone()
@@ -676,7 +677,7 @@ class OpsManager:
return self._row_to_alert_channel(row) return self._row_to_alert_channel(row)
return None return None
def list_alert_channels(self, tenant_id: str) -> List[AlertChannel]: def list_alert_channels(self, tenant_id: str) -> list[AlertChannel]:
"""列出租户的所有告警渠道""" """列出租户的所有告警渠道"""
with self._get_db() as conn: with self._get_db() as conn:
rows = conn.execute( rows = conn.execute(
@@ -715,7 +716,7 @@ class OpsManager:
# ==================== 告警评估与触发 ==================== # ==================== 告警评估与触发 ====================
def _evaluate_threshold_rule(self, rule: AlertRule, metrics: List[ResourceMetric]) -> bool: def _evaluate_threshold_rule(self, rule: AlertRule, metrics: list[ResourceMetric]) -> bool:
"""评估阈值告警规则""" """评估阈值告警规则"""
if not metrics: if not metrics:
return False return False
@@ -746,7 +747,7 @@ class OpsManager:
return False return False
def _evaluate_anomaly_rule(self, rule: AlertRule, metrics: List[ResourceMetric]) -> bool: def _evaluate_anomaly_rule(self, rule: AlertRule, metrics: list[ResourceMetric]) -> bool:
"""评估异常检测规则(基于标准差)""" """评估异常检测规则(基于标准差)"""
if len(metrics) < 10: if len(metrics) < 10:
return False return False
@@ -764,7 +765,7 @@ class OpsManager:
return z_score > 3.0 return z_score > 3.0
def _evaluate_predictive_rule(self, rule: AlertRule, metrics: List[ResourceMetric]) -> bool: def _evaluate_predictive_rule(self, rule: AlertRule, metrics: list[ResourceMetric]) -> bool:
"""评估预测性告警规则(基于线性趋势)""" """评估预测性告警规则(基于线性趋势)"""
if len(metrics) < 5: if len(metrics) < 5:
return False return False
@@ -814,7 +815,7 @@ class OpsManager:
# 触发告警 # 触发告警
await self._trigger_alert(rule, metrics[-1] if metrics else None) await self._trigger_alert(rule, metrics[-1] if metrics else None)
async def _trigger_alert(self, rule: AlertRule, metric: Optional[ResourceMetric]): async def _trigger_alert(self, rule: AlertRule, metric: ResourceMetric | None):
"""触发告警""" """触发告警"""
# 检查是否已有相同告警在触发中 # 检查是否已有相同告警在触发中
existing = self.get_active_alert_by_rule(rule.id) existing = self.get_active_alert_by_rule(rule.id)
@@ -1166,7 +1167,7 @@ class OpsManager:
# ==================== 告警查询与管理 ==================== # ==================== 告警查询与管理 ====================
def get_active_alert_by_rule(self, rule_id: str) -> Optional[Alert]: def get_active_alert_by_rule(self, rule_id: str) -> Alert | None:
"""获取规则对应的活跃告警""" """获取规则对应的活跃告警"""
with self._get_db() as conn: with self._get_db() as conn:
row = conn.execute( row = conn.execute(
@@ -1180,7 +1181,7 @@ class OpsManager:
return self._row_to_alert(row) return self._row_to_alert(row)
return None return None
def get_alert(self, alert_id: str) -> Optional[Alert]: def get_alert(self, alert_id: str) -> Alert | None:
"""获取告警详情""" """获取告警详情"""
with self._get_db() as conn: with self._get_db() as conn:
row = conn.execute("SELECT * FROM alerts WHERE id = ?", (alert_id,)).fetchone() row = conn.execute("SELECT * FROM alerts WHERE id = ?", (alert_id,)).fetchone()
@@ -1192,10 +1193,10 @@ class OpsManager:
def list_alerts( def list_alerts(
self, self,
tenant_id: str, tenant_id: str,
status: Optional[AlertStatus] = None, status: AlertStatus | None = None,
severity: Optional[AlertSeverity] = None, severity: AlertSeverity | None = None,
limit: int = 100, limit: int = 100,
) -> List[Alert]: ) -> list[Alert]:
"""列出租户的告警""" """列出租户的告警"""
query = "SELECT * FROM alerts WHERE tenant_id = ?" query = "SELECT * FROM alerts WHERE tenant_id = ?"
params = [tenant_id] params = [tenant_id]
@@ -1214,7 +1215,7 @@ class OpsManager:
rows = conn.execute(query, params).fetchall() rows = conn.execute(query, params).fetchall()
return [self._row_to_alert(row) for row in rows] return [self._row_to_alert(row) for row in rows]
def acknowledge_alert(self, alert_id: str, user_id: str) -> Optional[Alert]: def acknowledge_alert(self, alert_id: str, user_id: str) -> Alert | None:
"""确认告警""" """确认告警"""
now = datetime.now().isoformat() now = datetime.now().isoformat()
@@ -1231,7 +1232,7 @@ class OpsManager:
return self.get_alert(alert_id) return self.get_alert(alert_id)
def resolve_alert(self, alert_id: str) -> Optional[Alert]: def resolve_alert(self, alert_id: str) -> Alert | None:
"""解决告警""" """解决告警"""
now = datetime.now().isoformat() now = datetime.now().isoformat()
@@ -1306,10 +1307,10 @@ class OpsManager:
self, self,
tenant_id: str, tenant_id: str,
name: str, name: str,
matchers: Dict[str, str], matchers: dict[str, str],
duration: int, duration: int,
is_regex: bool = False, is_regex: bool = False,
expires_at: Optional[str] = None, expires_at: str | None = None,
) -> AlertSuppressionRule: ) -> AlertSuppressionRule:
"""创建告警抑制规则""" """创建告警抑制规则"""
rule_id = f"sr_{uuid.uuid4().hex[:16]}" rule_id = f"sr_{uuid.uuid4().hex[:16]}"
@@ -1394,7 +1395,7 @@ class OpsManager:
metric_name: str, metric_name: str,
metric_value: float, metric_value: float,
unit: str, unit: str,
metadata: Dict = None, metadata: dict = None,
) -> ResourceMetric: ) -> ResourceMetric:
"""记录资源指标""" """记录资源指标"""
metric_id = f"rm_{uuid.uuid4().hex[:16]}" metric_id = f"rm_{uuid.uuid4().hex[:16]}"
@@ -1436,7 +1437,7 @@ class OpsManager:
return metric return metric
def get_recent_metrics(self, tenant_id: str, metric_name: str, seconds: int = 3600) -> List[ResourceMetric]: def get_recent_metrics(self, tenant_id: str, metric_name: str, seconds: int = 3600) -> list[ResourceMetric]:
"""获取最近的指标数据""" """获取最近的指标数据"""
cutoff_time = (datetime.now() - timedelta(seconds=seconds)).isoformat() cutoff_time = (datetime.now() - timedelta(seconds=seconds)).isoformat()
@@ -1458,7 +1459,7 @@ class OpsManager:
metric_name: str, metric_name: str,
start_time: str, start_time: str,
end_time: str, end_time: str,
) -> List[ResourceMetric]: ) -> list[ResourceMetric]:
"""获取指定资源的指标数据""" """获取指定资源的指标数据"""
with self._get_db() as conn: with self._get_db() as conn:
rows = conn.execute( rows = conn.execute(
@@ -1549,7 +1550,7 @@ class OpsManager:
return plan return plan
def _calculate_trend(self, values: List[float]) -> float: def _calculate_trend(self, values: list[float]) -> float:
"""计算趋势(增长率)""" """计算趋势(增长率)"""
if len(values) < 2: if len(values) < 2:
return 0.0 return 0.0
@@ -1576,7 +1577,7 @@ class OpsManager:
return slope / mean_y return slope / mean_y
return 0.0 return 0.0
def get_capacity_plans(self, tenant_id: str) -> List[CapacityPlan]: def get_capacity_plans(self, tenant_id: str) -> list[CapacityPlan]:
"""获取容量规划列表""" """获取容量规划列表"""
with self._get_db() as conn: with self._get_db() as conn:
rows = conn.execute( rows = conn.execute(
@@ -1653,7 +1654,7 @@ class OpsManager:
return policy return policy
def get_auto_scaling_policy(self, policy_id: str) -> Optional[AutoScalingPolicy]: def get_auto_scaling_policy(self, policy_id: str) -> AutoScalingPolicy | None:
"""获取自动扩缩容策略""" """获取自动扩缩容策略"""
with self._get_db() as conn: with self._get_db() as conn:
row = conn.execute("SELECT * FROM auto_scaling_policies WHERE id = ?", (policy_id,)).fetchone() row = conn.execute("SELECT * FROM auto_scaling_policies WHERE id = ?", (policy_id,)).fetchone()
@@ -1662,7 +1663,7 @@ class OpsManager:
return self._row_to_auto_scaling_policy(row) return self._row_to_auto_scaling_policy(row)
return None return None
def list_auto_scaling_policies(self, tenant_id: str) -> List[AutoScalingPolicy]: def list_auto_scaling_policies(self, tenant_id: str) -> list[AutoScalingPolicy]:
"""列出租户的自动扩缩容策略""" """列出租户的自动扩缩容策略"""
with self._get_db() as conn: with self._get_db() as conn:
rows = conn.execute( rows = conn.execute(
@@ -1672,7 +1673,7 @@ class OpsManager:
def evaluate_scaling_policy( def evaluate_scaling_policy(
self, policy_id: str, current_instances: int, current_utilization: float self, policy_id: str, current_instances: int, current_utilization: float
) -> Optional[ScalingEvent]: ) -> ScalingEvent | None:
"""评估扩缩容策略""" """评估扩缩容策略"""
policy = self.get_auto_scaling_policy(policy_id) policy = self.get_auto_scaling_policy(policy_id)
if not policy or not policy.is_enabled: if not policy or not policy.is_enabled:
@@ -1754,7 +1755,7 @@ class OpsManager:
return event return event
def get_last_scaling_event(self, policy_id: str) -> Optional[ScalingEvent]: def get_last_scaling_event(self, policy_id: str) -> ScalingEvent | None:
"""获取最近的扩缩容事件""" """获取最近的扩缩容事件"""
with self._get_db() as conn: with self._get_db() as conn:
row = conn.execute( row = conn.execute(
@@ -1770,7 +1771,7 @@ class OpsManager:
def update_scaling_event_status( def update_scaling_event_status(
self, event_id: str, status: str, error_message: str = None self, event_id: str, status: str, error_message: str = None
) -> Optional[ScalingEvent]: ) -> ScalingEvent | None:
"""更新扩缩容事件状态""" """更新扩缩容事件状态"""
now = datetime.now().isoformat() now = datetime.now().isoformat()
@@ -1797,7 +1798,7 @@ class OpsManager:
return self.get_scaling_event(event_id) return self.get_scaling_event(event_id)
def get_scaling_event(self, event_id: str) -> Optional[ScalingEvent]: def get_scaling_event(self, event_id: str) -> ScalingEvent | None:
"""获取扩缩容事件""" """获取扩缩容事件"""
with self._get_db() as conn: with self._get_db() as conn:
row = conn.execute("SELECT * FROM scaling_events WHERE id = ?", (event_id,)).fetchone() row = conn.execute("SELECT * FROM scaling_events WHERE id = ?", (event_id,)).fetchone()
@@ -1806,7 +1807,7 @@ class OpsManager:
return self._row_to_scaling_event(row) return self._row_to_scaling_event(row)
return None return None
def list_scaling_events(self, tenant_id: str, policy_id: str = None, limit: int = 100) -> List[ScalingEvent]: def list_scaling_events(self, tenant_id: str, policy_id: str = None, limit: int = 100) -> list[ScalingEvent]:
"""列出租户的扩缩容事件""" """列出租户的扩缩容事件"""
query = "SELECT * FROM scaling_events WHERE tenant_id = ?" query = "SELECT * FROM scaling_events WHERE tenant_id = ?"
params = [tenant_id] params = [tenant_id]
@@ -1831,7 +1832,7 @@ class OpsManager:
target_type: str, target_type: str,
target_id: str, target_id: str,
check_type: str, check_type: str,
check_config: Dict, check_config: dict,
interval: int = 60, interval: int = 60,
timeout: int = 10, timeout: int = 10,
retry_count: int = 3, retry_count: int = 3,
@@ -1889,7 +1890,7 @@ class OpsManager:
return check return check
def get_health_check(self, check_id: str) -> Optional[HealthCheck]: def get_health_check(self, check_id: str) -> HealthCheck | None:
"""获取健康检查配置""" """获取健康检查配置"""
with self._get_db() as conn: with self._get_db() as conn:
row = conn.execute("SELECT * FROM health_checks WHERE id = ?", (check_id,)).fetchone() row = conn.execute("SELECT * FROM health_checks WHERE id = ?", (check_id,)).fetchone()
@@ -1898,7 +1899,7 @@ class OpsManager:
return self._row_to_health_check(row) return self._row_to_health_check(row)
return None return None
def list_health_checks(self, tenant_id: str) -> List[HealthCheck]: def list_health_checks(self, tenant_id: str) -> list[HealthCheck]:
"""列出租户的健康检查""" """列出租户的健康检查"""
with self._get_db() as conn: with self._get_db() as conn:
rows = conn.execute( rows = conn.execute(
@@ -1958,7 +1959,7 @@ class OpsManager:
return result return result
async def _check_http_health(self, check: HealthCheck) -> Tuple[HealthStatus, float, str]: async def _check_http_health(self, check: HealthCheck) -> tuple[HealthStatus, float, str]:
"""HTTP 健康检查""" """HTTP 健康检查"""
config = check.check_config config = check.check_config
url = config.get("url") url = config.get("url")
@@ -1980,7 +1981,7 @@ class OpsManager:
except Exception as e: except Exception as e:
return HealthStatus.UNHEALTHY, (time.time() - start_time) * 1000, str(e) return HealthStatus.UNHEALTHY, (time.time() - start_time) * 1000, str(e)
async def _check_tcp_health(self, check: HealthCheck) -> Tuple[HealthStatus, float, str]: async def _check_tcp_health(self, check: HealthCheck) -> tuple[HealthStatus, float, str]:
"""TCP 健康检查""" """TCP 健康检查"""
config = check.check_config config = check.check_config
host = config.get("host") host = config.get("host")
@@ -1996,12 +1997,12 @@ class OpsManager:
writer.close() writer.close()
await writer.wait_closed() await writer.wait_closed()
return HealthStatus.HEALTHY, response_time, "TCP connection successful" return HealthStatus.HEALTHY, response_time, "TCP connection successful"
except asyncio.TimeoutError: except TimeoutError:
return HealthStatus.UNHEALTHY, (time.time() - start_time) * 1000, "Connection timeout" return HealthStatus.UNHEALTHY, (time.time() - start_time) * 1000, "Connection timeout"
except Exception as e: except Exception as e:
return HealthStatus.UNHEALTHY, (time.time() - start_time) * 1000, str(e) return HealthStatus.UNHEALTHY, (time.time() - start_time) * 1000, str(e)
async def _check_ping_health(self, check: HealthCheck) -> Tuple[HealthStatus, float, str]: async def _check_ping_health(self, check: HealthCheck) -> tuple[HealthStatus, float, str]:
"""Ping 健康检查(模拟)""" """Ping 健康检查(模拟)"""
config = check.check_config config = check.check_config
host = config.get("host") host = config.get("host")
@@ -2013,7 +2014,7 @@ class OpsManager:
# 这里模拟成功 # 这里模拟成功
return HealthStatus.HEALTHY, 10.0, "Ping successful" return HealthStatus.HEALTHY, 10.0, "Ping successful"
def get_health_check_results(self, check_id: str, limit: int = 100) -> List[HealthCheckResult]: def get_health_check_results(self, check_id: str, limit: int = 100) -> list[HealthCheckResult]:
"""获取健康检查历史结果""" """获取健康检查历史结果"""
with self._get_db() as conn: with self._get_db() as conn:
rows = conn.execute( rows = conn.execute(
@@ -2031,7 +2032,7 @@ class OpsManager:
tenant_id: str, tenant_id: str,
name: str, name: str,
primary_region: str, primary_region: str,
secondary_regions: List[str], secondary_regions: list[str],
failover_trigger: str, failover_trigger: str,
auto_failover: bool = False, auto_failover: bool = False,
failover_timeout: int = 300, failover_timeout: int = 300,
@@ -2083,7 +2084,7 @@ class OpsManager:
return config return config
def get_failover_config(self, config_id: str) -> Optional[FailoverConfig]: def get_failover_config(self, config_id: str) -> FailoverConfig | None:
"""获取故障转移配置""" """获取故障转移配置"""
with self._get_db() as conn: with self._get_db() as conn:
row = conn.execute("SELECT * FROM failover_configs WHERE id = ?", (config_id,)).fetchone() row = conn.execute("SELECT * FROM failover_configs WHERE id = ?", (config_id,)).fetchone()
@@ -2092,7 +2093,7 @@ class OpsManager:
return self._row_to_failover_config(row) return self._row_to_failover_config(row)
return None return None
def list_failover_configs(self, tenant_id: str) -> List[FailoverConfig]: def list_failover_configs(self, tenant_id: str) -> list[FailoverConfig]:
"""列出租户的故障转移配置""" """列出租户的故障转移配置"""
with self._get_db() as conn: with self._get_db() as conn:
rows = conn.execute( rows = conn.execute(
@@ -2100,7 +2101,7 @@ class OpsManager:
).fetchall() ).fetchall()
return [self._row_to_failover_config(row) for row in rows] return [self._row_to_failover_config(row) for row in rows]
def initiate_failover(self, config_id: str, reason: str) -> Optional[FailoverEvent]: def initiate_failover(self, config_id: str, reason: str) -> FailoverEvent | None:
"""发起故障转移""" """发起故障转移"""
config = self.get_failover_config(config_id) config = self.get_failover_config(config_id)
if not config or not config.is_enabled: if not config or not config.is_enabled:
@@ -2150,7 +2151,7 @@ class OpsManager:
return event return event
def update_failover_status(self, event_id: str, status: str) -> Optional[FailoverEvent]: def update_failover_status(self, event_id: str, status: str) -> FailoverEvent | None:
"""更新故障转移状态""" """更新故障转移状态"""
now = datetime.now().isoformat() now = datetime.now().isoformat()
@@ -2186,7 +2187,7 @@ class OpsManager:
return self.get_failover_event(event_id) return self.get_failover_event(event_id)
def get_failover_event(self, event_id: str) -> Optional[FailoverEvent]: def get_failover_event(self, event_id: str) -> FailoverEvent | None:
"""获取故障转移事件""" """获取故障转移事件"""
with self._get_db() as conn: with self._get_db() as conn:
row = conn.execute("SELECT * FROM failover_events WHERE id = ?", (event_id,)).fetchone() row = conn.execute("SELECT * FROM failover_events WHERE id = ?", (event_id,)).fetchone()
@@ -2195,7 +2196,7 @@ class OpsManager:
return self._row_to_failover_event(row) return self._row_to_failover_event(row)
return None return None
def list_failover_events(self, tenant_id: str, limit: int = 100) -> List[FailoverEvent]: def list_failover_events(self, tenant_id: str, limit: int = 100) -> list[FailoverEvent]:
"""列出租户的故障转移事件""" """列出租户的故障转移事件"""
with self._get_db() as conn: with self._get_db() as conn:
rows = conn.execute( rows = conn.execute(
@@ -2272,7 +2273,7 @@ class OpsManager:
return job return job
def get_backup_job(self, job_id: str) -> Optional[BackupJob]: def get_backup_job(self, job_id: str) -> BackupJob | None:
"""获取备份任务""" """获取备份任务"""
with self._get_db() as conn: with self._get_db() as conn:
row = conn.execute("SELECT * FROM backup_jobs WHERE id = ?", (job_id,)).fetchone() row = conn.execute("SELECT * FROM backup_jobs WHERE id = ?", (job_id,)).fetchone()
@@ -2281,7 +2282,7 @@ class OpsManager:
return self._row_to_backup_job(row) return self._row_to_backup_job(row)
return None return None
def list_backup_jobs(self, tenant_id: str) -> List[BackupJob]: def list_backup_jobs(self, tenant_id: str) -> list[BackupJob]:
"""列出租户的备份任务""" """列出租户的备份任务"""
with self._get_db() as conn: with self._get_db() as conn:
rows = conn.execute( rows = conn.execute(
@@ -2289,7 +2290,7 @@ class OpsManager:
).fetchall() ).fetchall()
return [self._row_to_backup_job(row) for row in rows] return [self._row_to_backup_job(row) for row in rows]
def execute_backup(self, job_id: str) -> Optional[BackupRecord]: def execute_backup(self, job_id: str) -> BackupRecord | None:
"""执行备份""" """执行备份"""
job = self.get_backup_job(job_id) job = self.get_backup_job(job_id)
if not job or not job.is_enabled: if not job or not job.is_enabled:
@@ -2354,7 +2355,7 @@ class OpsManager:
) )
conn.commit() conn.commit()
def get_backup_record(self, record_id: str) -> Optional[BackupRecord]: def get_backup_record(self, record_id: str) -> BackupRecord | None:
"""获取备份记录""" """获取备份记录"""
with self._get_db() as conn: with self._get_db() as conn:
row = conn.execute("SELECT * FROM backup_records WHERE id = ?", (record_id,)).fetchone() row = conn.execute("SELECT * FROM backup_records WHERE id = ?", (record_id,)).fetchone()
@@ -2363,7 +2364,7 @@ class OpsManager:
return self._row_to_backup_record(row) return self._row_to_backup_record(row)
return None return None
def list_backup_records(self, tenant_id: str, job_id: str = None, limit: int = 100) -> List[BackupRecord]: def list_backup_records(self, tenant_id: str, job_id: str = None, limit: int = 100) -> list[BackupRecord]:
"""列出租户的备份记录""" """列出租户的备份记录"""
query = "SELECT * FROM backup_records WHERE tenant_id = ?" query = "SELECT * FROM backup_records WHERE tenant_id = ?"
params = [tenant_id] params = [tenant_id]
@@ -2452,7 +2453,7 @@ class OpsManager:
return report return report
def _detect_cost_anomalies(self, utilizations: List[ResourceUtilization]) -> List[Dict]: def _detect_cost_anomalies(self, utilizations: list[ResourceUtilization]) -> list[dict]:
"""检测成本异常""" """检测成本异常"""
anomalies = [] anomalies = []
@@ -2483,7 +2484,7 @@ class OpsManager:
return anomalies return anomalies
def _calculate_cost_trends(self, tenant_id: str, year: int, month: int) -> Dict: def _calculate_cost_trends(self, tenant_id: str, year: int, month: int) -> dict:
"""计算成本趋势""" """计算成本趋势"""
# 简化实现:返回模拟趋势 # 简化实现:返回模拟趋势
return {"month_over_month": 0.05, "year_over_year": 0.15, "forecast_next_month": 1.05} # 5% 增长 # 15% 增长 return {"month_over_month": 0.05, "year_over_year": 0.15, "forecast_next_month": 1.05} # 5% 增长 # 15% 增长
@@ -2498,7 +2499,7 @@ class OpsManager:
avg_utilization: float, avg_utilization: float,
idle_time_percent: float, idle_time_percent: float,
report_date: str, report_date: str,
recommendations: List[str] = None, recommendations: list[str] = None,
) -> ResourceUtilization: ) -> ResourceUtilization:
"""记录资源利用率""" """记录资源利用率"""
util_id = f"ru_{uuid.uuid4().hex[:16]}" util_id = f"ru_{uuid.uuid4().hex[:16]}"
@@ -2541,7 +2542,7 @@ class OpsManager:
return util return util
def get_resource_utilizations(self, tenant_id: str, report_period: str) -> List[ResourceUtilization]: def get_resource_utilizations(self, tenant_id: str, report_period: str) -> list[ResourceUtilization]:
"""获取资源利用率列表""" """获取资源利用率列表"""
with self._get_db() as conn: with self._get_db() as conn:
rows = conn.execute( rows = conn.execute(
@@ -2552,7 +2553,7 @@ class OpsManager:
).fetchall() ).fetchall()
return [self._row_to_resource_utilization(row) for row in rows] return [self._row_to_resource_utilization(row) for row in rows]
def detect_idle_resources(self, tenant_id: str) -> List[IdleResource]: def detect_idle_resources(self, tenant_id: str) -> list[IdleResource]:
"""检测闲置资源""" """检测闲置资源"""
idle_resources = [] idle_resources = []
@@ -2615,7 +2616,7 @@ class OpsManager:
return idle_resources return idle_resources
def get_idle_resources(self, tenant_id: str) -> List[IdleResource]: def get_idle_resources(self, tenant_id: str) -> list[IdleResource]:
"""获取闲置资源列表""" """获取闲置资源列表"""
with self._get_db() as conn: with self._get_db() as conn:
rows = conn.execute( rows = conn.execute(
@@ -2623,7 +2624,7 @@ class OpsManager:
).fetchall() ).fetchall()
return [self._row_to_idle_resource(row) for row in rows] return [self._row_to_idle_resource(row) for row in rows]
def generate_cost_optimization_suggestions(self, tenant_id: str) -> List[CostOptimizationSuggestion]: def generate_cost_optimization_suggestions(self, tenant_id: str) -> list[CostOptimizationSuggestion]:
"""生成成本优化建议""" """生成成本优化建议"""
suggestions = [] suggestions = []
@@ -2691,7 +2692,7 @@ class OpsManager:
def get_cost_optimization_suggestions( def get_cost_optimization_suggestions(
self, tenant_id: str, is_applied: bool = None self, tenant_id: str, is_applied: bool = None
) -> List[CostOptimizationSuggestion]: ) -> list[CostOptimizationSuggestion]:
"""获取成本优化建议""" """获取成本优化建议"""
query = "SELECT * FROM cost_optimization_suggestions WHERE tenant_id = ?" query = "SELECT * FROM cost_optimization_suggestions WHERE tenant_id = ?"
params = [tenant_id] params = [tenant_id]
@@ -2706,7 +2707,7 @@ class OpsManager:
rows = conn.execute(query, params).fetchall() rows = conn.execute(query, params).fetchall()
return [self._row_to_cost_optimization_suggestion(row) for row in rows] return [self._row_to_cost_optimization_suggestion(row) for row in rows]
def apply_cost_optimization_suggestion(self, suggestion_id: str) -> Optional[CostOptimizationSuggestion]: def apply_cost_optimization_suggestion(self, suggestion_id: str) -> CostOptimizationSuggestion | None:
"""应用成本优化建议""" """应用成本优化建议"""
now = datetime.now().isoformat() now = datetime.now().isoformat()
@@ -2723,7 +2724,7 @@ class OpsManager:
return self.get_cost_optimization_suggestion(suggestion_id) return self.get_cost_optimization_suggestion(suggestion_id)
def get_cost_optimization_suggestion(self, suggestion_id: str) -> Optional[CostOptimizationSuggestion]: def get_cost_optimization_suggestion(self, suggestion_id: str) -> CostOptimizationSuggestion | None:
"""获取成本优化建议详情""" """获取成本优化建议详情"""
with self._get_db() as conn: with self._get_db() as conn:
row = conn.execute("SELECT * FROM cost_optimization_suggestions WHERE id = ?", (suggestion_id,)).fetchone() row = conn.execute("SELECT * FROM cost_optimization_suggestions WHERE id = ?", (suggestion_id,)).fetchone()

View File

@@ -6,6 +6,7 @@ OSS 上传工具 - 用于阿里听悟音频上传
import os import os
import uuid import uuid
from datetime import datetime from datetime import datetime
import oss2 import oss2

View File

@@ -9,18 +9,19 @@ Phase 7 Task 8: Performance Optimization & Scaling
4. PerformanceMonitor - 性能监控API响应、查询性能、缓存命中率 4. PerformanceMonitor - 性能监控API响应、查询性能、缓存命中率
""" """
import os
import json
import time
import hashlib import hashlib
import json
import os
import sqlite3 import sqlite3
import threading import threading
from dataclasses import dataclass, field import time
from typing import Dict, List, Optional, Any, Callable, Tuple
from datetime import datetime
from collections import OrderedDict
from functools import wraps
import uuid import uuid
from collections import OrderedDict
from collections.abc import Callable
from dataclasses import dataclass, field
from datetime import datetime
from functools import wraps
from typing import Any
# 尝试导入 Redis # 尝试导入 Redis
try: try:
@@ -67,7 +68,7 @@ class CacheEntry:
key: str key: str
value: Any value: Any
created_at: float created_at: float
expires_at: Optional[float] expires_at: float | None
access_count: int = 0 access_count: int = 0
last_accessed: float = 0 last_accessed: float = 0
size_bytes: int = 0 size_bytes: int = 0
@@ -79,12 +80,12 @@ class PerformanceMetric:
id: str id: str
metric_type: str # api_response, db_query, cache_operation metric_type: str # api_response, db_query, cache_operation
endpoint: Optional[str] endpoint: str | None
duration_ms: float duration_ms: float
timestamp: str timestamp: str
metadata: Dict = field(default_factory=dict) metadata: dict = field(default_factory=dict)
def to_dict(self) -> Dict: def to_dict(self) -> dict:
return { return {
"id": self.id, "id": self.id,
"metric_type": self.metric_type, "metric_type": self.metric_type,
@@ -102,16 +103,16 @@ class TaskInfo:
id: str id: str
task_type: str task_type: str
status: str # pending, running, success, failed, retrying status: str # pending, running, success, failed, retrying
payload: Dict payload: dict
created_at: str created_at: str
started_at: Optional[str] = None started_at: str | None = None
completed_at: Optional[str] = None completed_at: str | None = None
result: Optional[Any] = None result: Any | None = None
error_message: Optional[str] = None error_message: str | None = None
retry_count: int = 0 retry_count: int = 0
max_retries: int = 3 max_retries: int = 3
def to_dict(self) -> Dict: def to_dict(self) -> dict:
return { return {
"id": self.id, "id": self.id,
"task_type": self.task_type, "task_type": self.task_type,
@@ -132,7 +133,7 @@ class ShardInfo:
"""分片信息数据模型""" """分片信息数据模型"""
shard_id: str shard_id: str
shard_key_range: Tuple[str, str] # (start, end) shard_key_range: tuple[str, str] # (start, end)
db_path: str db_path: str
entity_count: int = 0 entity_count: int = 0
is_active: bool = True is_active: bool = True
@@ -160,7 +161,7 @@ class CacheManager:
def __init__( def __init__(
self, self,
redis_url: Optional[str] = None, redis_url: str | None = None,
max_memory_size: int = 100 * 1024 * 1024, # 100MB max_memory_size: int = 100 * 1024 * 1024, # 100MB
default_ttl: int = 3600, # 1小时 default_ttl: int = 3600, # 1小时
db_path: str = "insightflow.db", db_path: str = "insightflow.db",
@@ -242,7 +243,7 @@ class CacheManager:
self.current_memory_size -= oldest_entry.size_bytes self.current_memory_size -= oldest_entry.size_bytes
self.stats.evictions += 1 self.stats.evictions += 1
def get(self, key: str) -> Optional[Any]: def get(self, key: str) -> Any | None:
""" """
获取缓存值 获取缓存值
@@ -291,7 +292,7 @@ class CacheManager:
self.stats.misses += 1 self.stats.misses += 1
return None return None
def set(self, key: str, value: Any, ttl: Optional[int] = None) -> bool: def set(self, key: str, value: Any, ttl: int | None = None) -> bool:
""" """
设置缓存值 设置缓存值
@@ -373,7 +374,7 @@ class CacheManager:
self.current_memory_size = 0 self.current_memory_size = 0
return True return True
def get_many(self, keys: List[str]) -> Dict[str, Any]: def get_many(self, keys: list[str]) -> dict[str, Any]:
"""批量获取缓存""" """批量获取缓存"""
results = {} results = {}
@@ -397,7 +398,7 @@ class CacheManager:
return results return results
def set_many(self, mapping: Dict[str, Any], ttl: Optional[int] = None) -> bool: def set_many(self, mapping: dict[str, Any], ttl: int | None = None) -> bool:
"""批量设置缓存""" """批量设置缓存"""
ttl = ttl or self.default_ttl ttl = ttl or self.default_ttl
@@ -417,7 +418,7 @@ class CacheManager:
self.set(key, value, ttl) self.set(key, value, ttl)
return True return True
def get_stats(self) -> Dict: def get_stats(self) -> dict:
"""获取缓存统计""" """获取缓存统计"""
self.stats.update_hit_rate() self.stats.update_hit_rate()
@@ -468,7 +469,7 @@ class CacheManager:
conn.commit() conn.commit()
conn.close() conn.close()
def warm_up(self, project_id: str) -> Dict: def warm_up(self, project_id: str) -> dict:
""" """
缓存预热 - 加载项目的热点数据 缓存预热 - 加载项目的热点数据
@@ -612,7 +613,7 @@ class DatabaseSharding:
os.makedirs(shard_db_dir, exist_ok=True) os.makedirs(shard_db_dir, exist_ok=True)
# 分片映射 # 分片映射
self.shard_map: Dict[str, ShardInfo] = {} self.shard_map: dict[str, ShardInfo] = {}
# 初始化分片 # 初始化分片
self._init_shards() self._init_shards()
@@ -708,7 +709,7 @@ class DatabaseSharding:
return conn return conn
def get_all_shards(self) -> List[ShardInfo]: def get_all_shards(self) -> list[ShardInfo]:
"""获取所有分片信息""" """获取所有分片信息"""
return list(self.shard_map.values()) return list(self.shard_map.values())
@@ -805,7 +806,7 @@ class DatabaseSharding:
conn.close() conn.close()
def cross_shard_query(self, query_func: Callable) -> List[Dict]: def cross_shard_query(self, query_func: Callable) -> list[dict]:
""" """
跨分片查询 跨分片查询
@@ -831,7 +832,7 @@ class DatabaseSharding:
return results return results
def get_shard_stats(self) -> List[Dict]: def get_shard_stats(self) -> list[dict]:
"""获取所有分片的统计信息""" """获取所有分片的统计信息"""
stats = [] stats = []
@@ -852,7 +853,7 @@ class DatabaseSharding:
return stats return stats
def rebalance_shards(self) -> Dict: def rebalance_shards(self) -> dict:
""" """
重新平衡分片 重新平衡分片
@@ -899,15 +900,15 @@ class TaskQueue:
- 任务状态追踪和重试机制 - 任务状态追踪和重试机制
""" """
def __init__(self, redis_url: Optional[str] = None, db_path: str = "insightflow.db"): def __init__(self, redis_url: str | None = None, db_path: str = "insightflow.db"):
self.db_path = db_path self.db_path = db_path
self.redis_url = redis_url self.redis_url = redis_url
self.celery_app = None self.celery_app = None
self.use_celery = False self.use_celery = False
# 内存任务存储(非 Celery 模式) # 内存任务存储(非 Celery 模式)
self.tasks: Dict[str, TaskInfo] = {} self.tasks: dict[str, TaskInfo] = {}
self.task_handlers: Dict[str, Callable] = {} self.task_handlers: dict[str, Callable] = {}
self.task_lock = threading.RLock() self.task_lock = threading.RLock()
# 初始化任务队列表 # 初始化任务队列表
@@ -956,7 +957,7 @@ class TaskQueue:
"""注册任务处理器""" """注册任务处理器"""
self.task_handlers[task_type] = handler self.task_handlers[task_type] = handler
def submit(self, task_type: str, payload: Dict, max_retries: int = 3) -> str: def submit(self, task_type: str, payload: dict, max_retries: int = 3) -> str:
""" """
提交任务 提交任务
@@ -1112,7 +1113,7 @@ class TaskQueue:
conn.commit() conn.commit()
conn.close() conn.close()
def get_status(self, task_id: str) -> Optional[TaskInfo]: def get_status(self, task_id: str) -> TaskInfo | None:
"""获取任务状态""" """获取任务状态"""
if self.use_celery: if self.use_celery:
try: try:
@@ -1143,8 +1144,8 @@ class TaskQueue:
return self.tasks.get(task_id) return self.tasks.get(task_id)
def list_tasks( def list_tasks(
self, status: Optional[str] = None, task_type: Optional[str] = None, limit: int = 100 self, status: str | None = None, task_type: str | None = None, limit: int = 100
) -> List[TaskInfo]: ) -> list[TaskInfo]:
"""列出任务""" """列出任务"""
conn = sqlite3.connect(self.db_path) conn = sqlite3.connect(self.db_path)
conn.row_factory = sqlite3.Row conn.row_factory = sqlite3.Row
@@ -1233,7 +1234,7 @@ class TaskQueue:
self._update_task_status(task) self._update_task_status(task)
return True return True
def get_stats(self) -> Dict: def get_stats(self) -> dict:
"""获取任务队列统计""" """获取任务队列统计"""
conn = sqlite3.connect(self.db_path) conn = sqlite3.connect(self.db_path)
@@ -1290,15 +1291,15 @@ class PerformanceMonitor:
self.alert_threshold = alert_threshold self.alert_threshold = alert_threshold
# 内存中的指标缓存 # 内存中的指标缓存
self.metrics_buffer: List[PerformanceMetric] = [] self.metrics_buffer: list[PerformanceMetric] = []
self.buffer_lock = threading.RLock() self.buffer_lock = threading.RLock()
self.buffer_size = 100 self.buffer_size = 100
# 告警回调 # 告警回调
self.alert_handlers: List[Callable] = [] self.alert_handlers: list[Callable] = []
def record_metric( def record_metric(
self, metric_type: str, duration_ms: float, endpoint: Optional[str] = None, metadata: Optional[Dict] = None self, metric_type: str, duration_ms: float, endpoint: str | None = None, metadata: dict | None = None
): ):
""" """
记录性能指标 记录性能指标
@@ -1385,7 +1386,7 @@ class PerformanceMonitor:
"""注册告警处理器""" """注册告警处理器"""
self.alert_handlers.append(handler) self.alert_handlers.append(handler)
def get_stats(self, hours: int = 24) -> Dict: def get_stats(self, hours: int = 24) -> dict:
""" """
获取性能统计 获取性能统计
@@ -1504,7 +1505,7 @@ class PerformanceMonitor:
], ],
} }
def get_api_performance(self, endpoint: Optional[str] = None, hours: int = 24) -> Dict: def get_api_performance(self, endpoint: str | None = None, hours: int = 24) -> dict:
"""获取 API 性能详情""" """获取 API 性能详情"""
self._flush_metrics() self._flush_metrics()
@@ -1584,7 +1585,7 @@ class PerformanceMonitor:
# ==================== 性能装饰器 ==================== # ==================== 性能装饰器 ====================
def cached(cache_manager: CacheManager, key_prefix: str = "", ttl: int = 3600, key_func: Optional[Callable] = None): def cached(cache_manager: CacheManager, key_prefix: str = "", ttl: int = 3600, key_func: Callable | None = None):
""" """
缓存装饰器 缓存装饰器
@@ -1624,7 +1625,7 @@ def cached(cache_manager: CacheManager, key_prefix: str = "", ttl: int = 3600, k
return decorator return decorator
def monitored(monitor: PerformanceMonitor, metric_type: str, endpoint: Optional[str] = None): def monitored(monitor: PerformanceMonitor, metric_type: str, endpoint: str | None = None):
""" """
性能监控装饰器 性能监控装饰器
@@ -1662,7 +1663,7 @@ class PerformanceManager:
整合缓存管理、数据库分片、任务队列和性能监控功能 整合缓存管理、数据库分片、任务队列和性能监控功能
""" """
def __init__(self, db_path: str = "insightflow.db", redis_url: Optional[str] = None, enable_sharding: bool = False): def __init__(self, db_path: str = "insightflow.db", redis_url: str | None = None, enable_sharding: bool = False):
self.db_path = db_path self.db_path = db_path
# 初始化各模块 # 初始化各模块
@@ -1674,7 +1675,7 @@ class PerformanceManager:
self.monitor = PerformanceMonitor(db_path=db_path) self.monitor = PerformanceMonitor(db_path=db_path)
def get_health_status(self) -> Dict: def get_health_status(self) -> dict:
"""获取系统健康状态""" """获取系统健康状态"""
return { return {
"cache": { "cache": {
@@ -1698,7 +1699,7 @@ class PerformanceManager:
}, },
} }
def get_full_stats(self) -> Dict: def get_full_stats(self) -> dict:
"""获取完整统计信息""" """获取完整统计信息"""
stats = { stats = {
"cache": self.cache.get_stats(), "cache": self.cache.get_stats(),
@@ -1717,7 +1718,7 @@ _performance_manager = None
def get_performance_manager( def get_performance_manager(
db_path: str = "insightflow.db", redis_url: Optional[str] = None, enable_sharding: bool = False db_path: str = "insightflow.db", redis_url: str | None = None, enable_sharding: bool = False
) -> PerformanceManager: ) -> PerformanceManager:
"""获取性能管理器单例""" """获取性能管理器单例"""
global _performance_manager global _performance_manager

View File

@@ -4,20 +4,21 @@ InsightFlow Plugin Manager - Phase 7 Task 7
插件与集成系统Chrome插件、飞书/钉钉机器人、Zapier/Make集成、WebDAV同步 插件与集成系统Chrome插件、飞书/钉钉机器人、Zapier/Make集成、WebDAV同步
""" """
import os import base64
import json
import hashlib import hashlib
import hmac import hmac
import base64 import json
import time import os
import uuid
import httpx
import urllib.parse
from datetime import datetime
from typing import Dict, List, Optional, Any
from dataclasses import dataclass, field
from enum import Enum
import sqlite3 import sqlite3
import time
import urllib.parse
import uuid
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Any
import httpx
# WebDAV 支持 # WebDAV 支持
try: try:
@@ -58,10 +59,10 @@ class Plugin:
plugin_type: str plugin_type: str
project_id: str project_id: str
status: str = "active" status: str = "active"
config: Dict = field(default_factory=dict) config: dict = field(default_factory=dict)
created_at: str = "" created_at: str = ""
updated_at: str = "" updated_at: str = ""
last_used_at: Optional[str] = None last_used_at: str | None = None
use_count: int = 0 use_count: int = 0
@@ -86,13 +87,13 @@ class BotSession:
bot_type: str # feishu, dingtalk bot_type: str # feishu, dingtalk
session_id: str # 群ID或会话ID session_id: str # 群ID或会话ID
session_name: str session_name: str
project_id: Optional[str] = None project_id: str | None = None
webhook_url: str = "" webhook_url: str = ""
secret: str = "" secret: str = ""
is_active: bool = True is_active: bool = True
created_at: str = "" created_at: str = ""
updated_at: str = "" updated_at: str = ""
last_message_at: Optional[str] = None last_message_at: str | None = None
message_count: int = 0 message_count: int = 0
@@ -104,14 +105,14 @@ class WebhookEndpoint:
name: str name: str
endpoint_type: str # zapier, make, custom endpoint_type: str # zapier, make, custom
endpoint_url: str endpoint_url: str
project_id: Optional[str] = None project_id: str | None = None
auth_type: str = "none" # none, api_key, oauth, custom auth_type: str = "none" # none, api_key, oauth, custom
auth_config: Dict = field(default_factory=dict) auth_config: dict = field(default_factory=dict)
trigger_events: List[str] = field(default_factory=list) trigger_events: list[str] = field(default_factory=list)
is_active: bool = True is_active: bool = True
created_at: str = "" created_at: str = ""
updated_at: str = "" updated_at: str = ""
last_triggered_at: Optional[str] = None last_triggered_at: str | None = None
trigger_count: int = 0 trigger_count: int = 0
@@ -128,7 +129,7 @@ class WebDAVSync:
remote_path: str = "/insightflow" remote_path: str = "/insightflow"
sync_mode: str = "bidirectional" # bidirectional, upload_only, download_only sync_mode: str = "bidirectional" # bidirectional, upload_only, download_only
sync_interval: int = 3600 # 秒 sync_interval: int = 3600 # 秒
last_sync_at: Optional[str] = None last_sync_at: str | None = None
last_sync_status: str = "pending" # pending, success, failed last_sync_status: str = "pending" # pending, success, failed
last_sync_error: str = "" last_sync_error: str = ""
is_active: bool = True is_active: bool = True
@@ -143,13 +144,13 @@ class ChromeExtensionToken:
id: str id: str
token: str token: str
user_id: Optional[str] = None user_id: str | None = None
project_id: Optional[str] = None project_id: str | None = None
name: str = "" name: str = ""
permissions: List[str] = field(default_factory=lambda: ["read", "write"]) permissions: list[str] = field(default_factory=lambda: ["read", "write"])
expires_at: Optional[str] = None expires_at: str | None = None
created_at: str = "" created_at: str = ""
last_used_at: Optional[str] = None last_used_at: str | None = None
use_count: int = 0 use_count: int = 0
is_revoked: bool = False is_revoked: bool = False
@@ -171,7 +172,7 @@ class PluginManager:
self._handlers[PluginType.MAKE] = WebhookIntegration(self, "make") self._handlers[PluginType.MAKE] = WebhookIntegration(self, "make")
self._handlers[PluginType.WEBDAV] = WebDAVSyncManager(self) self._handlers[PluginType.WEBDAV] = WebDAVSyncManager(self)
def get_handler(self, plugin_type: PluginType) -> Optional[Any]: def get_handler(self, plugin_type: PluginType) -> Any | None:
"""获取插件处理器""" """获取插件处理器"""
return self._handlers.get(plugin_type) return self._handlers.get(plugin_type)
@@ -205,7 +206,7 @@ class PluginManager:
plugin.updated_at = now plugin.updated_at = now
return plugin return plugin
def get_plugin(self, plugin_id: str) -> Optional[Plugin]: def get_plugin(self, plugin_id: str) -> Plugin | None:
"""获取插件""" """获取插件"""
conn = self.db.get_conn() conn = self.db.get_conn()
row = conn.execute("SELECT * FROM plugins WHERE id = ?", (plugin_id,)).fetchone() row = conn.execute("SELECT * FROM plugins WHERE id = ?", (plugin_id,)).fetchone()
@@ -215,7 +216,7 @@ class PluginManager:
return self._row_to_plugin(row) return self._row_to_plugin(row)
return None return None
def list_plugins(self, project_id: str = None, plugin_type: str = None, status: str = None) -> List[Plugin]: def list_plugins(self, project_id: str = None, plugin_type: str = None, status: str = None) -> list[Plugin]:
"""列出插件""" """列出插件"""
conn = self.db.get_conn() conn = self.db.get_conn()
@@ -239,7 +240,7 @@ class PluginManager:
return [self._row_to_plugin(row) for row in rows] return [self._row_to_plugin(row) for row in rows]
def update_plugin(self, plugin_id: str, **kwargs) -> Optional[Plugin]: def update_plugin(self, plugin_id: str, **kwargs) -> Plugin | None:
"""更新插件""" """更新插件"""
conn = self.db.get_conn() conn = self.db.get_conn()
@@ -341,7 +342,7 @@ class PluginManager:
updated_at=now, updated_at=now,
) )
def get_plugin_config(self, plugin_id: str, key: str) -> Optional[str]: def get_plugin_config(self, plugin_id: str, key: str) -> str | None:
"""获取插件配置""" """获取插件配置"""
conn = self.db.get_conn() conn = self.db.get_conn()
row = conn.execute( row = conn.execute(
@@ -351,7 +352,7 @@ class PluginManager:
return row["config_value"] if row else None return row["config_value"] if row else None
def get_all_plugin_configs(self, plugin_id: str) -> Dict[str, str]: def get_all_plugin_configs(self, plugin_id: str) -> dict[str, str]:
"""获取插件所有配置""" """获取插件所有配置"""
conn = self.db.get_conn() conn = self.db.get_conn()
rows = conn.execute( rows = conn.execute(
@@ -396,7 +397,7 @@ class ChromeExtensionHandler:
name: str, name: str,
user_id: str = None, user_id: str = None,
project_id: str = None, project_id: str = None,
permissions: List[str] = None, permissions: list[str] = None,
expires_days: int = None, expires_days: int = None,
) -> ChromeExtensionToken: ) -> ChromeExtensionToken:
"""创建 Chrome 扩展令牌""" """创建 Chrome 扩展令牌"""
@@ -448,7 +449,7 @@ class ChromeExtensionHandler:
created_at=now, created_at=now,
) )
def validate_token(self, token: str) -> Optional[ChromeExtensionToken]: def validate_token(self, token: str) -> ChromeExtensionToken | None:
"""验证 Chrome 扩展令牌""" """验证 Chrome 扩展令牌"""
token_hash = hashlib.sha256(token.encode()).hexdigest() token_hash = hashlib.sha256(token.encode()).hexdigest()
@@ -501,7 +502,7 @@ class ChromeExtensionHandler:
return cursor.rowcount > 0 return cursor.rowcount > 0
def list_tokens(self, user_id: str = None, project_id: str = None) -> List[ChromeExtensionToken]: def list_tokens(self, user_id: str = None, project_id: str = None) -> list[ChromeExtensionToken]:
"""列出令牌""" """列出令牌"""
conn = self.pm.db.get_conn() conn = self.pm.db.get_conn()
@@ -544,7 +545,7 @@ class ChromeExtensionHandler:
async def import_webpage( async def import_webpage(
self, token: ChromeExtensionToken, url: str, title: str, content: str, html_content: str = None self, token: ChromeExtensionToken, url: str, title: str, content: str, html_content: str = None
) -> Dict: ) -> dict:
"""导入网页内容""" """导入网页内容"""
if not token.project_id: if not token.project_id:
return {"success": False, "error": "Token not associated with any project"} return {"success": False, "error": "Token not associated with any project"}
@@ -617,7 +618,7 @@ class BotHandler:
updated_at=now, updated_at=now,
) )
def get_session(self, session_id: str) -> Optional[BotSession]: def get_session(self, session_id: str) -> BotSession | None:
"""获取会话""" """获取会话"""
conn = self.pm.db.get_conn() conn = self.pm.db.get_conn()
row = conn.execute( row = conn.execute(
@@ -631,7 +632,7 @@ class BotHandler:
return self._row_to_session(row) return self._row_to_session(row)
return None return None
def list_sessions(self, project_id: str = None) -> List[BotSession]: def list_sessions(self, project_id: str = None) -> list[BotSession]:
"""列出会话""" """列出会话"""
conn = self.pm.db.get_conn() conn = self.pm.db.get_conn()
@@ -652,7 +653,7 @@ class BotHandler:
return [self._row_to_session(row) for row in rows] return [self._row_to_session(row) for row in rows]
def update_session(self, session_id: str, **kwargs) -> Optional[BotSession]: def update_session(self, session_id: str, **kwargs) -> BotSession | None:
"""更新会话""" """更新会话"""
conn = self.pm.db.get_conn() conn = self.pm.db.get_conn()
@@ -709,7 +710,7 @@ class BotHandler:
message_count=row["message_count"], message_count=row["message_count"],
) )
async def handle_message(self, session: BotSession, message: Dict) -> Dict: async def handle_message(self, session: BotSession, message: dict) -> dict:
"""处理收到的消息""" """处理收到的消息"""
now = datetime.now().isoformat() now = datetime.now().isoformat()
@@ -740,7 +741,7 @@ class BotHandler:
return {"success": False, "error": "Unsupported message type"} return {"success": False, "error": "Unsupported message type"}
async def _handle_text_message(self, session: BotSession, text: str, raw_message: Dict) -> Dict: async def _handle_text_message(self, session: BotSession, text: str, raw_message: dict) -> dict:
"""处理文本消息""" """处理文本消息"""
# 简单命令处理 # 简单命令处理
if text.startswith("/help"): if text.startswith("/help"):
@@ -772,7 +773,7 @@ class BotHandler:
# 默认回复 # 默认回复
return {"success": True, "response": f"收到消息:{text[:100]}...\n\n使用 /help 查看可用命令"} return {"success": True, "response": f"收到消息:{text[:100]}...\n\n使用 /help 查看可用命令"}
async def _handle_audio_message(self, session: BotSession, message: Dict) -> Dict: async def _handle_audio_message(self, session: BotSession, message: dict) -> dict:
"""处理音频消息""" """处理音频消息"""
if not session.project_id: if not session.project_id:
return {"success": False, "error": "Session not bound to any project"} return {"success": False, "error": "Session not bound to any project"}
@@ -802,7 +803,7 @@ class BotHandler:
except Exception as e: except Exception as e:
return {"success": False, "error": f"Failed to process audio: {str(e)}"} return {"success": False, "error": f"Failed to process audio: {str(e)}"}
async def _handle_file_message(self, session: BotSession, message: Dict) -> Dict: async def _handle_file_message(self, session: BotSession, message: dict) -> dict:
"""处理文件消息""" """处理文件消息"""
return {"success": True, "response": "📎 收到文件,正在处理中..."} return {"success": True, "response": "📎 收到文件,正在处理中..."}
@@ -825,8 +826,8 @@ class BotHandler:
async def _send_feishu_message(self, session: BotSession, message: str, msg_type: str) -> bool: async def _send_feishu_message(self, session: BotSession, message: str, msg_type: str) -> bool:
"""发送飞书消息""" """发送飞书消息"""
import hashlib
import base64 import base64
import hashlib
timestamp = str(int(time.time())) timestamp = str(int(time.time()))
@@ -850,8 +851,8 @@ class BotHandler:
async def _send_dingtalk_message(self, session: BotSession, message: str, msg_type: str) -> bool: async def _send_dingtalk_message(self, session: BotSession, message: str, msg_type: str) -> bool:
"""发送钉钉消息""" """发送钉钉消息"""
import hashlib
import base64 import base64
import hashlib
timestamp = str(round(time.time() * 1000)) timestamp = str(round(time.time() * 1000))
@@ -890,8 +891,8 @@ class WebhookIntegration:
endpoint_url: str, endpoint_url: str,
project_id: str = None, project_id: str = None,
auth_type: str = "none", auth_type: str = "none",
auth_config: Dict = None, auth_config: dict = None,
trigger_events: List[str] = None, trigger_events: list[str] = None,
) -> WebhookEndpoint: ) -> WebhookEndpoint:
"""创建 Webhook 端点""" """创建 Webhook 端点"""
endpoint_id = str(uuid.uuid4())[:8] endpoint_id = str(uuid.uuid4())[:8]
@@ -935,7 +936,7 @@ class WebhookIntegration:
updated_at=now, updated_at=now,
) )
def get_endpoint(self, endpoint_id: str) -> Optional[WebhookEndpoint]: def get_endpoint(self, endpoint_id: str) -> WebhookEndpoint | None:
"""获取端点""" """获取端点"""
conn = self.pm.db.get_conn() conn = self.pm.db.get_conn()
row = conn.execute( row = conn.execute(
@@ -947,7 +948,7 @@ class WebhookIntegration:
return self._row_to_endpoint(row) return self._row_to_endpoint(row)
return None return None
def list_endpoints(self, project_id: str = None) -> List[WebhookEndpoint]: def list_endpoints(self, project_id: str = None) -> list[WebhookEndpoint]:
"""列出端点""" """列出端点"""
conn = self.pm.db.get_conn() conn = self.pm.db.get_conn()
@@ -968,7 +969,7 @@ class WebhookIntegration:
return [self._row_to_endpoint(row) for row in rows] return [self._row_to_endpoint(row) for row in rows]
def update_endpoint(self, endpoint_id: str, **kwargs) -> Optional[WebhookEndpoint]: def update_endpoint(self, endpoint_id: str, **kwargs) -> WebhookEndpoint | None:
"""更新端点""" """更新端点"""
conn = self.pm.db.get_conn() conn = self.pm.db.get_conn()
@@ -1034,7 +1035,7 @@ class WebhookIntegration:
trigger_count=row["trigger_count"], trigger_count=row["trigger_count"],
) )
async def trigger(self, endpoint: WebhookEndpoint, event_type: str, data: Dict) -> bool: async def trigger(self, endpoint: WebhookEndpoint, event_type: str, data: dict) -> bool:
"""触发 Webhook""" """触发 Webhook"""
if not endpoint.is_active: if not endpoint.is_active:
return False return False
@@ -1079,7 +1080,7 @@ class WebhookIntegration:
print(f"Failed to trigger webhook: {e}") print(f"Failed to trigger webhook: {e}")
return False return False
async def test_endpoint(self, endpoint: WebhookEndpoint) -> Dict: async def test_endpoint(self, endpoint: WebhookEndpoint) -> dict:
"""测试端点""" """测试端点"""
test_data = { test_data = {
"message": "This is a test event from InsightFlow", "message": "This is a test event from InsightFlow",
@@ -1160,7 +1161,7 @@ class WebDAVSyncManager:
updated_at=now, updated_at=now,
) )
def get_sync(self, sync_id: str) -> Optional[WebDAVSync]: def get_sync(self, sync_id: str) -> WebDAVSync | None:
"""获取同步配置""" """获取同步配置"""
conn = self.pm.db.get_conn() conn = self.pm.db.get_conn()
row = conn.execute("SELECT * FROM webdav_syncs WHERE id = ?", (sync_id,)).fetchone() row = conn.execute("SELECT * FROM webdav_syncs WHERE id = ?", (sync_id,)).fetchone()
@@ -1170,7 +1171,7 @@ class WebDAVSyncManager:
return self._row_to_sync(row) return self._row_to_sync(row)
return None return None
def list_syncs(self, project_id: str = None) -> List[WebDAVSync]: def list_syncs(self, project_id: str = None) -> list[WebDAVSync]:
"""列出同步配置""" """列出同步配置"""
conn = self.pm.db.get_conn() conn = self.pm.db.get_conn()
@@ -1185,7 +1186,7 @@ class WebDAVSyncManager:
return [self._row_to_sync(row) for row in rows] return [self._row_to_sync(row) for row in rows]
def update_sync(self, sync_id: str, **kwargs) -> Optional[WebDAVSync]: def update_sync(self, sync_id: str, **kwargs) -> WebDAVSync | None:
"""更新同步配置""" """更新同步配置"""
conn = self.pm.db.get_conn() conn = self.pm.db.get_conn()
@@ -1252,7 +1253,7 @@ class WebDAVSyncManager:
sync_count=row["sync_count"], sync_count=row["sync_count"],
) )
async def test_connection(self, sync: WebDAVSync) -> Dict: async def test_connection(self, sync: WebDAVSync) -> dict:
"""测试 WebDAV 连接""" """测试 WebDAV 连接"""
if not WEBDAV_AVAILABLE: if not WEBDAV_AVAILABLE:
return {"success": False, "error": "WebDAV library not available"} return {"success": False, "error": "WebDAV library not available"}
@@ -1268,7 +1269,7 @@ class WebDAVSyncManager:
except Exception as e: except Exception as e:
return {"success": False, "error": str(e)} return {"success": False, "error": str(e)}
async def sync_project(self, sync: WebDAVSync) -> Dict: async def sync_project(self, sync: WebDAVSync) -> dict:
"""同步项目到 WebDAV""" """同步项目到 WebDAV"""
if not WEBDAV_AVAILABLE: if not WEBDAV_AVAILABLE:
return {"success": False, "error": "WebDAV library not available"} return {"success": False, "error": "WebDAV library not available"}

View File

@@ -5,11 +5,11 @@ API 限流中间件
支持基于内存的滑动窗口限流 支持基于内存的滑动窗口限流
""" """
import time
import asyncio import asyncio
from typing import Dict, Optional, Callable import time
from dataclasses import dataclass
from collections import defaultdict from collections import defaultdict
from collections.abc import Callable
from dataclasses import dataclass
from functools import wraps from functools import wraps
@@ -37,7 +37,7 @@ class SlidingWindowCounter:
def __init__(self, window_size: int = 60): def __init__(self, window_size: int = 60):
self.window_size = window_size self.window_size = window_size
self.requests: Dict[int, int] = defaultdict(int) # 秒级计数 self.requests: dict[int, int] = defaultdict(int) # 秒级计数
self._lock = asyncio.Lock() self._lock = asyncio.Lock()
self._cleanup_lock = asyncio.Lock() self._cleanup_lock = asyncio.Lock()
@@ -69,13 +69,13 @@ class RateLimiter:
def __init__(self): def __init__(self):
# key -> SlidingWindowCounter # key -> SlidingWindowCounter
self.counters: Dict[str, SlidingWindowCounter] = {} self.counters: dict[str, SlidingWindowCounter] = {}
# key -> RateLimitConfig # key -> RateLimitConfig
self.configs: Dict[str, RateLimitConfig] = {} self.configs: dict[str, RateLimitConfig] = {}
self._lock = asyncio.Lock() self._lock = asyncio.Lock()
self._cleanup_lock = asyncio.Lock() self._cleanup_lock = asyncio.Lock()
async def is_allowed(self, key: str, config: Optional[RateLimitConfig] = None) -> RateLimitInfo: async def is_allowed(self, key: str, config: RateLimitConfig | None = None) -> RateLimitInfo:
""" """
检查是否允许请求 检查是否允许请求
@@ -143,7 +143,7 @@ class RateLimiter:
retry_after=max(0, config.window_size) if current_count >= config.requests_per_minute else 0, retry_after=max(0, config.window_size) if current_count >= config.requests_per_minute else 0,
) )
def reset(self, key: Optional[str] = None): def reset(self, key: str | None = None):
"""重置限流计数器""" """重置限流计数器"""
if key: if key:
self.counters.pop(key, None) self.counters.pop(key, None)
@@ -154,7 +154,7 @@ class RateLimiter:
# 全局限流器实例 # 全局限流器实例
_rate_limiter: Optional[RateLimiter] = None _rate_limiter: RateLimiter | None = None
def get_rate_limiter() -> RateLimiter: def get_rate_limiter() -> RateLimiter:
@@ -166,7 +166,7 @@ def get_rate_limiter() -> RateLimiter:
# 限流装饰器(用于函数级别限流) # 限流装饰器(用于函数级别限流)
def rate_limit(requests_per_minute: int = 60, key_func: Optional[Callable] = None): def rate_limit(requests_per_minute: int = 60, key_func: Callable | None = None):
""" """
限流装饰器 限流装饰器

View File

@@ -9,15 +9,14 @@ Phase 7 Task 6: Advanced Search & Discovery
4. KnowledgeGapDetection - 知识缺口识别 4. KnowledgeGapDetection - 知识缺口识别
""" """
import re import hashlib
import json import json
import math import math
import re
import sqlite3 import sqlite3
import hashlib
from dataclasses import dataclass, field
from typing import List, Dict, Optional, Tuple, Set
from datetime import datetime
from collections import defaultdict from collections import defaultdict
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum from enum import Enum
@@ -46,10 +45,10 @@ class SearchResult:
content_type: str # transcript, entity, relation content_type: str # transcript, entity, relation
project_id: str project_id: str
score: float score: float
highlights: List[Tuple[int, int]] = field(default_factory=list) # 高亮位置 highlights: list[tuple[int, int]] = field(default_factory=list) # 高亮位置
metadata: Dict = field(default_factory=dict) metadata: dict = field(default_factory=dict)
def to_dict(self) -> Dict: def to_dict(self) -> dict:
return { return {
"id": self.id, "id": self.id,
"content": self.content, "content": self.content,
@@ -69,10 +68,10 @@ class SemanticSearchResult:
content_type: str content_type: str
project_id: str project_id: str
similarity: float similarity: float
embedding: Optional[List[float]] = None embedding: list[float] | None = None
metadata: Dict = field(default_factory=dict) metadata: dict = field(default_factory=dict)
def to_dict(self) -> Dict: def to_dict(self) -> dict:
result = { result = {
"id": self.id, "id": self.id,
"content": self.content[:500] + "..." if len(self.content) > 500 else self.content, "content": self.content[:500] + "..." if len(self.content) > 500 else self.content,
@@ -95,12 +94,12 @@ class EntityPath:
target_entity_id: str target_entity_id: str
target_entity_name: str target_entity_name: str
path_length: int path_length: int
nodes: List[Dict] # 路径上的节点 nodes: list[dict] # 路径上的节点
edges: List[Dict] # 路径上的边 edges: list[dict] # 路径上的边
confidence: float confidence: float
path_description: str path_description: str
def to_dict(self) -> Dict: def to_dict(self) -> dict:
return { return {
"path_id": self.path_id, "path_id": self.path_id,
"source_entity_id": self.source_entity_id, "source_entity_id": self.source_entity_id,
@@ -120,15 +119,15 @@ class KnowledgeGap:
"""知识缺口数据模型""" """知识缺口数据模型"""
gap_id: str gap_id: str
gap_type: str # missing_attribute, sparse_relation, isolated_entity, incomplete_entity gap_type: str # missing_attribute, sparse_relation, isolated_entity, incomplete_entity
entity_id: Optional[str] entity_id: str | None
entity_name: Optional[str] entity_name: str | None
description: str description: str
severity: str # high, medium, low severity: str # high, medium, low
suggestions: List[str] suggestions: list[str]
related_entities: List[str] related_entities: list[str]
metadata: Dict = field(default_factory=dict) metadata: dict = field(default_factory=dict)
def to_dict(self) -> Dict: def to_dict(self) -> dict:
return { return {
"gap_id": self.gap_id, "gap_id": self.gap_id,
"gap_type": self.gap_type, "gap_type": self.gap_type,
@@ -149,8 +148,8 @@ class SearchIndex:
content_id: str content_id: str
content_type: str content_type: str
project_id: str project_id: str
tokens: List[str] tokens: list[str]
token_positions: Dict[str, List[int]] # 词 -> 位置列表 token_positions: dict[str, list[int]] # 词 -> 位置列表
created_at: str created_at: str
updated_at: str updated_at: str
@@ -162,7 +161,7 @@ class TextEmbedding:
content_id: str content_id: str
content_type: str content_type: str
project_id: str project_id: str
embedding: List[float] embedding: list[float]
model_name: str model_name: str
created_at: str created_at: str
@@ -231,7 +230,7 @@ class FullTextSearch:
conn.commit() conn.commit()
conn.close() conn.close()
def _tokenize(self, text: str) -> List[str]: def _tokenize(self, text: str) -> list[str]:
""" """
中文分词(简化版) 中文分词(简化版)
@@ -243,7 +242,7 @@ class FullTextSearch:
tokens = re.findall(r'[\u4e00-\u9fa5]+|[a-z]+|\d+', text) tokens = re.findall(r'[\u4e00-\u9fa5]+|[a-z]+|\d+', text)
return tokens return tokens
def _extract_positions(self, text: str, tokens: List[str]) -> Dict[str, List[int]]: def _extract_positions(self, text: str, tokens: list[str]) -> dict[str, list[int]]:
"""提取每个词在文本中的位置""" """提取每个词在文本中的位置"""
positions = defaultdict(list) positions = defaultdict(list)
text_lower = text.lower() text_lower = text.lower()
@@ -326,9 +325,9 @@ class FullTextSearch:
print(f"索引创建失败: {e}") print(f"索引创建失败: {e}")
return False return False
def search(self, query: str, project_id: Optional[str] = None, def search(self, query: str, project_id: str | None = None,
content_types: Optional[List[str]] = None, content_types: list[str] | None = None,
limit: int = 20, offset: int = 0) -> List[SearchResult]: limit: int = 20, offset: int = 0) -> list[SearchResult]:
""" """
全文搜索 全文搜索
@@ -358,7 +357,7 @@ class FullTextSearch:
return scored_results[offset:offset + limit] return scored_results[offset:offset + limit]
def _parse_boolean_query(self, query: str) -> Dict: def _parse_boolean_query(self, query: str) -> dict:
""" """
解析布尔查询 解析布尔查询
@@ -401,9 +400,9 @@ class FullTextSearch:
"phrases": phrases "phrases": phrases
} }
def _execute_boolean_search(self, parsed_query: Dict, def _execute_boolean_search(self, parsed_query: dict,
project_id: Optional[str] = None, project_id: str | None = None,
content_types: Optional[List[str]] = None) -> List[Dict]: content_types: list[str] | None = None) -> list[dict]:
"""执行布尔搜索""" """执行布尔搜索"""
conn = self._get_conn() conn = self._get_conn()
@@ -503,7 +502,7 @@ class FullTextSearch:
return results return results
def _get_content_by_id(self, conn: sqlite3.Connection, def _get_content_by_id(self, conn: sqlite3.Connection,
content_id: str, content_type: str) -> Optional[str]: content_id: str, content_type: str) -> str | None:
"""根据ID获取内容""" """根据ID获取内容"""
try: try:
if content_type == "transcript": if content_type == "transcript":
@@ -542,7 +541,7 @@ class FullTextSearch:
return None return None
def _get_project_id(self, conn: sqlite3.Connection, def _get_project_id(self, conn: sqlite3.Connection,
content_id: str, content_type: str) -> Optional[str]: content_id: str, content_type: str) -> str | None:
"""获取内容所属的项目ID""" """获取内容所属的项目ID"""
try: try:
if content_type == "transcript": if content_type == "transcript":
@@ -567,7 +566,7 @@ class FullTextSearch:
except Exception: except Exception:
return None return None
def _score_results(self, results: List[Dict], parsed_query: Dict) -> List[SearchResult]: def _score_results(self, results: list[dict], parsed_query: dict) -> list[SearchResult]:
"""计算搜索结果的相关性分数""" """计算搜索结果的相关性分数"""
scored = [] scored = []
all_terms = parsed_query["and"] + parsed_query["or"] + parsed_query["phrases"] all_terms = parsed_query["and"] + parsed_query["or"] + parsed_query["phrases"]
@@ -689,7 +688,7 @@ class FullTextSearch:
print(f"删除索引失败: {e}") print(f"删除索引失败: {e}")
return False return False
def reindex_project(self, project_id: str) -> Dict: def reindex_project(self, project_id: str) -> dict:
"""重新索引整个项目""" """重新索引整个项目"""
conn = self._get_conn() conn = self._get_conn()
stats = {"transcripts": 0, "entities": 0, "relations": 0, "errors": 0} stats = {"transcripts": 0, "entities": 0, "relations": 0, "errors": 0}
@@ -808,7 +807,7 @@ class SemanticSearch:
"""检查语义搜索是否可用""" """检查语义搜索是否可用"""
return self.model is not None and SENTENCE_TRANSFORMERS_AVAILABLE return self.model is not None and SENTENCE_TRANSFORMERS_AVAILABLE
def generate_embedding(self, text: str) -> Optional[List[float]]: def generate_embedding(self, text: str) -> list[float] | None:
""" """
生成文本的 embedding 向量 生成文本的 embedding 向量
@@ -878,9 +877,9 @@ class SemanticSearch:
print(f"索引 embedding 失败: {e}") print(f"索引 embedding 失败: {e}")
return False return False
def search(self, query: str, project_id: Optional[str] = None, def search(self, query: str, project_id: str | None = None,
content_types: Optional[List[str]] = None, content_types: list[str] | None = None,
top_k: int = 10, threshold: float = 0.5) -> List[SemanticSearchResult]: top_k: int = 10, threshold: float = 0.5) -> list[SemanticSearchResult]:
""" """
语义搜索 语义搜索
@@ -959,7 +958,7 @@ class SemanticSearch:
results.sort(key=lambda x: x.similarity, reverse=True) results.sort(key=lambda x: x.similarity, reverse=True)
return results[:top_k] return results[:top_k]
def _get_content_text(self, content_id: str, content_type: str) -> Optional[str]: def _get_content_text(self, content_id: str, content_type: str) -> str | None:
"""获取内容文本""" """获取内容文本"""
conn = self._get_conn() conn = self._get_conn()
@@ -1002,7 +1001,7 @@ class SemanticSearch:
return None return None
def find_similar_content(self, content_id: str, content_type: str, def find_similar_content(self, content_id: str, content_type: str,
top_k: int = 5) -> List[SemanticSearchResult]: top_k: int = 5) -> list[SemanticSearchResult]:
""" """
查找与指定内容相似的内容 查找与指定内容相似的内容
@@ -1107,7 +1106,7 @@ class EntityPathDiscovery:
def find_shortest_path(self, source_entity_id: str, def find_shortest_path(self, source_entity_id: str,
target_entity_id: str, target_entity_id: str,
max_depth: int = 5) -> Optional[EntityPath]: max_depth: int = 5) -> EntityPath | None:
""" """
查找两个实体之间的最短路径BFS算法 查找两个实体之间的最短路径BFS算法
@@ -1181,7 +1180,7 @@ class EntityPathDiscovery:
def find_all_paths(self, source_entity_id: str, def find_all_paths(self, source_entity_id: str,
target_entity_id: str, target_entity_id: str,
max_depth: int = 4, max_depth: int = 4,
max_paths: int = 10) -> List[EntityPath]: max_paths: int = 10) -> list[EntityPath]:
""" """
查找两个实体之间的所有路径(限制数量和深度) 查找两个实体之间的所有路径(限制数量和深度)
@@ -1211,7 +1210,7 @@ class EntityPathDiscovery:
paths = [] paths = []
def dfs(current_id: str, target_id: str, def dfs(current_id: str, target_id: str,
path: List[str], visited: Set[str], depth: int): path: list[str], visited: set[str], depth: int):
if depth > max_depth: if depth > max_depth:
return return
@@ -1247,7 +1246,7 @@ class EntityPathDiscovery:
# 构建路径对象 # 构建路径对象
return [self._build_path_object(path, project_id) for path in paths] return [self._build_path_object(path, project_id) for path in paths]
def _build_path_object(self, entity_ids: List[str], def _build_path_object(self, entity_ids: list[str],
project_id: str) -> EntityPath: project_id: str) -> EntityPath:
"""构建路径对象""" """构建路径对象"""
conn = self._get_conn() conn = self._get_conn()
@@ -1312,7 +1311,7 @@ class EntityPathDiscovery:
) )
def find_multi_hop_relations(self, entity_id: str, def find_multi_hop_relations(self, entity_id: str,
max_hops: int = 3) -> List[Dict]: max_hops: int = 3) -> list[dict]:
""" """
查找实体的多跳关系 查找实体的多跳关系
@@ -1394,7 +1393,7 @@ class EntityPathDiscovery:
return relations return relations
def _get_path_to_entity(self, source_id: str, target_id: str, def _get_path_to_entity(self, source_id: str, target_id: str,
project_id: str, conn: sqlite3.Connection) -> List[str]: project_id: str, conn: sqlite3.Connection) -> list[str]:
"""获取从源实体到目标实体的路径(简化版)""" """获取从源实体到目标实体的路径(简化版)"""
# BFS 找路径 # BFS 找路径
visited = {source_id} visited = {source_id}
@@ -1428,7 +1427,7 @@ class EntityPathDiscovery:
return [] return []
def generate_path_visualization(self, path: EntityPath) -> Dict: def generate_path_visualization(self, path: EntityPath) -> dict:
""" """
生成路径可视化数据 生成路径可视化数据
@@ -1467,7 +1466,7 @@ class EntityPathDiscovery:
"confidence": path.confidence "confidence": path.confidence
} }
def analyze_path_centrality(self, project_id: str) -> List[Dict]: def analyze_path_centrality(self, project_id: str) -> list[dict]:
""" """
分析项目中实体的路径中心性(桥接程度) 分析项目中实体的路径中心性(桥接程度)
@@ -1558,7 +1557,7 @@ class KnowledgeGapDetection:
conn.row_factory = sqlite3.Row conn.row_factory = sqlite3.Row
return conn return conn
def analyze_project(self, project_id: str) -> List[KnowledgeGap]: def analyze_project(self, project_id: str) -> list[KnowledgeGap]:
""" """
分析项目中的知识缺口 分析项目中的知识缺口
@@ -1591,7 +1590,7 @@ class KnowledgeGapDetection:
return gaps return gaps
def _check_entity_attribute_completeness(self, project_id: str) -> List[KnowledgeGap]: def _check_entity_attribute_completeness(self, project_id: str) -> list[KnowledgeGap]:
"""检查实体属性完整性""" """检查实体属性完整性"""
conn = self._get_conn() conn = self._get_conn()
gaps = [] gaps = []
@@ -1661,7 +1660,7 @@ class KnowledgeGapDetection:
conn.close() conn.close()
return gaps return gaps
def _check_relation_sparsity(self, project_id: str) -> List[KnowledgeGap]: def _check_relation_sparsity(self, project_id: str) -> list[KnowledgeGap]:
"""检查关系稀疏度""" """检查关系稀疏度"""
conn = self._get_conn() conn = self._get_conn()
gaps = [] gaps = []
@@ -1720,7 +1719,7 @@ class KnowledgeGapDetection:
conn.close() conn.close()
return gaps return gaps
def _check_isolated_entities(self, project_id: str) -> List[KnowledgeGap]: def _check_isolated_entities(self, project_id: str) -> list[KnowledgeGap]:
"""检查孤立实体(没有任何关系)""" """检查孤立实体(没有任何关系)"""
conn = self._get_conn() conn = self._get_conn()
gaps = [] gaps = []
@@ -1756,7 +1755,7 @@ class KnowledgeGapDetection:
conn.close() conn.close()
return gaps return gaps
def _check_incomplete_entities(self, project_id: str) -> List[KnowledgeGap]: def _check_incomplete_entities(self, project_id: str) -> list[KnowledgeGap]:
"""检查不完整实体(缺少名称、类型或定义)""" """检查不完整实体(缺少名称、类型或定义)"""
conn = self._get_conn() conn = self._get_conn()
gaps = [] gaps = []
@@ -1788,7 +1787,7 @@ class KnowledgeGapDetection:
conn.close() conn.close()
return gaps return gaps
def _check_missing_key_entities(self, project_id: str) -> List[KnowledgeGap]: def _check_missing_key_entities(self, project_id: str) -> list[KnowledgeGap]:
"""检查可能缺失的关键实体""" """检查可能缺失的关键实体"""
conn = self._get_conn() conn = self._get_conn()
gaps = [] gaps = []
@@ -1841,7 +1840,7 @@ class KnowledgeGapDetection:
conn.close() conn.close()
return gaps[:10] # 限制数量 return gaps[:10] # 限制数量
def generate_completeness_report(self, project_id: str) -> Dict: def generate_completeness_report(self, project_id: str) -> dict:
""" """
生成知识完整性报告 生成知识完整性报告
@@ -1898,7 +1897,7 @@ class KnowledgeGapDetection:
"recommendations": self._generate_recommendations(gaps) "recommendations": self._generate_recommendations(gaps)
} }
def _generate_recommendations(self, gaps: List[KnowledgeGap]) -> List[str]: def _generate_recommendations(self, gaps: list[KnowledgeGap]) -> list[str]:
"""生成改进建议""" """生成改进建议"""
recommendations = [] recommendations = []
@@ -1941,8 +1940,8 @@ class SearchManager:
self.path_discovery = EntityPathDiscovery(db_path) self.path_discovery = EntityPathDiscovery(db_path)
self.gap_detection = KnowledgeGapDetection(db_path) self.gap_detection = KnowledgeGapDetection(db_path)
def hybrid_search(self, query: str, project_id: Optional[str] = None, def hybrid_search(self, query: str, project_id: str | None = None,
limit: int = 20) -> Dict: limit: int = 20) -> dict:
""" """
混合搜索(全文 + 语义) 混合搜索(全文 + 语义)
@@ -2014,7 +2013,7 @@ class SearchManager:
"results": results[:limit] "results": results[:limit]
} }
def index_project(self, project_id: str) -> Dict: def index_project(self, project_id: str) -> dict:
""" """
为项目建立所有索引 为项目建立所有索引
@@ -2071,7 +2070,7 @@ class SearchManager:
"semantic": semantic_stats "semantic": semantic_stats
} }
def get_search_stats(self, project_id: Optional[str] = None) -> Dict: def get_search_stats(self, project_id: str | None = None) -> dict:
"""获取搜索统计信息""" """获取搜索统计信息"""
conn = sqlite3.connect(self.db_path) conn = sqlite3.connect(self.db_path)
conn.row_factory = sqlite3.Row conn.row_factory = sqlite3.Row
@@ -2126,28 +2125,28 @@ def get_search_manager(db_path: str = "insightflow.db") -> SearchManager:
# 便捷函数 # 便捷函数
def fulltext_search(query: str, project_id: Optional[str] = None, def fulltext_search(query: str, project_id: str | None = None,
limit: int = 20) -> List[SearchResult]: limit: int = 20) -> list[SearchResult]:
"""全文搜索便捷函数""" """全文搜索便捷函数"""
manager = get_search_manager() manager = get_search_manager()
return manager.fulltext_search.search(query, project_id, limit=limit) return manager.fulltext_search.search(query, project_id, limit=limit)
def semantic_search(query: str, project_id: Optional[str] = None, def semantic_search(query: str, project_id: str | None = None,
top_k: int = 10) -> List[SemanticSearchResult]: top_k: int = 10) -> list[SemanticSearchResult]:
"""语义搜索便捷函数""" """语义搜索便捷函数"""
manager = get_search_manager() manager = get_search_manager()
return manager.semantic_search.search(query, project_id, top_k=top_k) return manager.semantic_search.search(query, project_id, top_k=top_k)
def find_entity_path(source_id: str, target_id: str, def find_entity_path(source_id: str, target_id: str,
max_depth: int = 5) -> Optional[EntityPath]: max_depth: int = 5) -> EntityPath | None:
"""查找实体路径便捷函数""" """查找实体路径便捷函数"""
manager = get_search_manager() manager = get_search_manager()
return manager.path_discovery.find_shortest_path(source_id, target_id, max_depth) return manager.path_discovery.find_shortest_path(source_id, target_id, max_depth)
def detect_knowledge_gaps(project_id: str) -> List[KnowledgeGap]: def detect_knowledge_gaps(project_id: str) -> list[KnowledgeGap]:
"""知识缺口检测便捷函数""" """知识缺口检测便捷函数"""
manager = get_search_manager() manager = get_search_manager()
return manager.gap_detection.analyze_project(project_id) return manager.gap_detection.analyze_project(project_id)

View File

@@ -3,16 +3,16 @@ InsightFlow Phase 7 Task 3: 数据安全与合规模块
Security Manager - 端到端加密、数据脱敏、审计日志 Security Manager - 端到端加密、数据脱敏、审计日志
""" """
import json
import hashlib
import secrets
import base64 import base64
import hashlib
import json
import re import re
from datetime import datetime, timedelta import secrets
from typing import List, Optional, Dict, Any, Tuple
from dataclasses import dataclass, field, asdict
from enum import Enum
import sqlite3 import sqlite3
from dataclasses import asdict, dataclass, field
from datetime import datetime, timedelta
from enum import Enum
from typing import Any
# 加密相关 # 加密相关
try: try:
@@ -71,19 +71,19 @@ class AuditLog:
"""审计日志条目""" """审计日志条目"""
id: str id: str
action_type: str action_type: str
user_id: Optional[str] = None user_id: str | None = None
user_ip: Optional[str] = None user_ip: str | None = None
user_agent: Optional[str] = None user_agent: str | None = None
resource_type: Optional[str] = None # project, entity, transcript, etc. resource_type: str | None = None # project, entity, transcript, etc.
resource_id: Optional[str] = None resource_id: str | None = None
action_details: Optional[str] = None # JSON string action_details: str | None = None # JSON string
before_value: Optional[str] = None before_value: str | None = None
after_value: Optional[str] = None after_value: str | None = None
success: bool = True success: bool = True
error_message: Optional[str] = None error_message: str | None = None
created_at: str = field(default_factory=lambda: datetime.now().isoformat()) created_at: str = field(default_factory=lambda: datetime.now().isoformat())
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> dict[str, Any]:
return asdict(self) return asdict(self)
@@ -95,12 +95,12 @@ class EncryptionConfig:
is_enabled: bool = False is_enabled: bool = False
encryption_type: str = "aes-256-gcm" # aes-256-gcm, chacha20-poly1305 encryption_type: str = "aes-256-gcm" # aes-256-gcm, chacha20-poly1305
key_derivation: str = "pbkdf2" # pbkdf2, argon2 key_derivation: str = "pbkdf2" # pbkdf2, argon2
master_key_hash: Optional[str] = None # 主密钥哈希(用于验证) master_key_hash: str | None = None # 主密钥哈希(用于验证)
salt: Optional[str] = None salt: str | None = None
created_at: str = field(default_factory=lambda: datetime.now().isoformat()) created_at: str = field(default_factory=lambda: datetime.now().isoformat())
updated_at: str = field(default_factory=lambda: datetime.now().isoformat()) updated_at: str = field(default_factory=lambda: datetime.now().isoformat())
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> dict[str, Any]:
return asdict(self) return asdict(self)
@@ -115,11 +115,11 @@ class MaskingRule:
replacement: str # 替换模板,如 "****" replacement: str # 替换模板,如 "****"
is_active: bool = True is_active: bool = True
priority: int = 0 priority: int = 0
description: Optional[str] = None description: str | None = None
created_at: str = field(default_factory=lambda: datetime.now().isoformat()) created_at: str = field(default_factory=lambda: datetime.now().isoformat())
updated_at: str = field(default_factory=lambda: datetime.now().isoformat()) updated_at: str = field(default_factory=lambda: datetime.now().isoformat())
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> dict[str, Any]:
return asdict(self) return asdict(self)
@@ -129,18 +129,18 @@ class DataAccessPolicy:
id: str id: str
project_id: str project_id: str
name: str name: str
description: Optional[str] = None description: str | None = None
allowed_users: Optional[str] = None # JSON array of user IDs allowed_users: str | None = None # JSON array of user IDs
allowed_roles: Optional[str] = None # JSON array of roles allowed_roles: str | None = None # JSON array of roles
allowed_ips: Optional[str] = None # JSON array of IP patterns allowed_ips: str | None = None # JSON array of IP patterns
time_restrictions: Optional[str] = None # JSON: {"start_time": "09:00", "end_time": "18:00"} time_restrictions: str | None = None # JSON: {"start_time": "09:00", "end_time": "18:00"}
max_access_count: Optional[int] = None # 最大访问次数 max_access_count: int | None = None # 最大访问次数
require_approval: bool = False require_approval: bool = False
is_active: bool = True is_active: bool = True
created_at: str = field(default_factory=lambda: datetime.now().isoformat()) created_at: str = field(default_factory=lambda: datetime.now().isoformat())
updated_at: str = field(default_factory=lambda: datetime.now().isoformat()) updated_at: str = field(default_factory=lambda: datetime.now().isoformat())
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> dict[str, Any]:
return asdict(self) return asdict(self)
@@ -150,14 +150,14 @@ class AccessRequest:
id: str id: str
policy_id: str policy_id: str
user_id: str user_id: str
request_reason: Optional[str] = None request_reason: str | None = None
status: str = "pending" # pending, approved, rejected, expired status: str = "pending" # pending, approved, rejected, expired
approved_by: Optional[str] = None approved_by: str | None = None
approved_at: Optional[str] = None approved_at: str | None = None
expires_at: Optional[str] = None expires_at: str | None = None
created_at: str = field(default_factory=lambda: datetime.now().isoformat()) created_at: str = field(default_factory=lambda: datetime.now().isoformat())
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> dict[str, Any]:
return asdict(self) return asdict(self)
@@ -196,7 +196,7 @@ class SecurityManager:
self.db_path = db_path self.db_path = db_path
self.db_path = db_path self.db_path = db_path
# 预编译正则缓存 # 预编译正则缓存
self._compiled_patterns: Dict[str, re.Pattern] = {} self._compiled_patterns: dict[str, re.Pattern] = {}
self._local = {} self._local = {}
self._init_db() self._init_db()
@@ -317,16 +317,16 @@ class SecurityManager:
def log_audit( def log_audit(
self, self,
action_type: AuditActionType, action_type: AuditActionType,
user_id: Optional[str] = None, user_id: str | None = None,
user_ip: Optional[str] = None, user_ip: str | None = None,
user_agent: Optional[str] = None, user_agent: str | None = None,
resource_type: Optional[str] = None, resource_type: str | None = None,
resource_id: Optional[str] = None, resource_id: str | None = None,
action_details: Optional[Dict] = None, action_details: dict | None = None,
before_value: Optional[str] = None, before_value: str | None = None,
after_value: Optional[str] = None, after_value: str | None = None,
success: bool = True, success: bool = True,
error_message: Optional[str] = None error_message: str | None = None
) -> AuditLog: ) -> AuditLog:
"""记录审计日志""" """记录审计日志"""
log = AuditLog( log = AuditLog(
@@ -364,16 +364,16 @@ class SecurityManager:
def get_audit_logs( def get_audit_logs(
self, self,
user_id: Optional[str] = None, user_id: str | None = None,
resource_type: Optional[str] = None, resource_type: str | None = None,
resource_id: Optional[str] = None, resource_id: str | None = None,
action_type: Optional[str] = None, action_type: str | None = None,
start_time: Optional[str] = None, start_time: str | None = None,
end_time: Optional[str] = None, end_time: str | None = None,
success: Optional[bool] = None, success: bool | None = None,
limit: int = 100, limit: int = 100,
offset: int = 0 offset: int = 0
) -> List[AuditLog]: ) -> list[AuditLog]:
"""查询审计日志""" """查询审计日志"""
conn = sqlite3.connect(self.db_path) conn = sqlite3.connect(self.db_path)
cursor = conn.cursor() cursor = conn.cursor()
@@ -438,9 +438,9 @@ class SecurityManager:
def get_audit_stats( def get_audit_stats(
self, self,
start_time: Optional[str] = None, start_time: str | None = None,
end_time: Optional[str] = None end_time: str | None = None
) -> Dict[str, Any]: ) -> dict[str, Any]:
"""获取审计统计""" """获取审计统计"""
conn = sqlite3.connect(self.db_path) conn = sqlite3.connect(self.db_path)
cursor = conn.cursor() cursor = conn.cursor()
@@ -633,7 +633,7 @@ class SecurityManager:
return key_hash == stored_hash return key_hash == stored_hash
def get_encryption_config(self, project_id: str) -> Optional[EncryptionConfig]: def get_encryption_config(self, project_id: str) -> EncryptionConfig | None:
"""获取加密配置""" """获取加密配置"""
conn = sqlite3.connect(self.db_path) conn = sqlite3.connect(self.db_path)
cursor = conn.cursor() cursor = conn.cursor()
@@ -664,8 +664,8 @@ class SecurityManager:
self, self,
data: str, data: str,
password: str, password: str,
salt: Optional[str] = None salt: str | None = None
) -> Tuple[str, str]: ) -> tuple[str, str]:
"""加密数据""" """加密数据"""
if not CRYPTO_AVAILABLE: if not CRYPTO_AVAILABLE:
raise RuntimeError("cryptography library not available") raise RuntimeError("cryptography library not available")
@@ -702,9 +702,9 @@ class SecurityManager:
project_id: str, project_id: str,
name: str, name: str,
rule_type: MaskingRuleType, rule_type: MaskingRuleType,
pattern: Optional[str] = None, pattern: str | None = None,
replacement: Optional[str] = None, replacement: str | None = None,
description: Optional[str] = None, description: str | None = None,
priority: int = 0 priority: int = 0
) -> MaskingRule: ) -> MaskingRule:
"""创建脱敏规则""" """创建脱敏规则"""
@@ -756,7 +756,7 @@ class SecurityManager:
self, self,
project_id: str, project_id: str,
active_only: bool = True active_only: bool = True
) -> List[MaskingRule]: ) -> list[MaskingRule]:
"""获取脱敏规则""" """获取脱敏规则"""
conn = sqlite3.connect(self.db_path) conn = sqlite3.connect(self.db_path)
cursor = conn.cursor() cursor = conn.cursor()
@@ -795,7 +795,7 @@ class SecurityManager:
self, self,
rule_id: str, rule_id: str,
**kwargs **kwargs
) -> Optional[MaskingRule]: ) -> MaskingRule | None:
"""更新脱敏规则""" """更新脱敏规则"""
allowed_fields = ["name", "pattern", "replacement", "is_active", "priority", "description"] allowed_fields = ["name", "pattern", "replacement", "is_active", "priority", "description"]
@@ -868,7 +868,7 @@ class SecurityManager:
self, self,
text: str, text: str,
project_id: str, project_id: str,
rule_types: Optional[List[MaskingRuleType]] = None rule_types: list[MaskingRuleType] | None = None
) -> str: ) -> str:
"""应用脱敏规则到文本""" """应用脱敏规则到文本"""
rules = self.get_masking_rules(project_id) rules = self.get_masking_rules(project_id)
@@ -897,9 +897,9 @@ class SecurityManager:
def apply_masking_to_entity( def apply_masking_to_entity(
self, self,
entity_data: Dict[str, Any], entity_data: dict[str, Any],
project_id: str project_id: str
) -> Dict[str, Any]: ) -> dict[str, Any]:
"""对实体数据应用脱敏""" """对实体数据应用脱敏"""
masked_data = entity_data.copy() masked_data = entity_data.copy()
@@ -918,12 +918,12 @@ class SecurityManager:
self, self,
project_id: str, project_id: str,
name: str, name: str,
description: Optional[str] = None, description: str | None = None,
allowed_users: Optional[List[str]] = None, allowed_users: list[str] | None = None,
allowed_roles: Optional[List[str]] = None, allowed_roles: list[str] | None = None,
allowed_ips: Optional[List[str]] = None, allowed_ips: list[str] | None = None,
time_restrictions: Optional[Dict] = None, time_restrictions: dict | None = None,
max_access_count: Optional[int] = None, max_access_count: int | None = None,
require_approval: bool = False require_approval: bool = False
) -> DataAccessPolicy: ) -> DataAccessPolicy:
"""创建数据访问策略""" """创建数据访问策略"""
@@ -966,7 +966,7 @@ class SecurityManager:
self, self,
project_id: str, project_id: str,
active_only: bool = True active_only: bool = True
) -> List[DataAccessPolicy]: ) -> list[DataAccessPolicy]:
"""获取数据访问策略""" """获取数据访问策略"""
conn = sqlite3.connect(self.db_path) conn = sqlite3.connect(self.db_path)
cursor = conn.cursor() cursor = conn.cursor()
@@ -1005,8 +1005,8 @@ class SecurityManager:
self, self,
policy_id: str, policy_id: str,
user_id: str, user_id: str,
user_ip: Optional[str] = None user_ip: str | None = None
) -> Tuple[bool, Optional[str]]: ) -> tuple[bool, str | None]:
"""检查访问权限""" """检查访问权限"""
conn = sqlite3.connect(self.db_path) conn = sqlite3.connect(self.db_path)
cursor = conn.cursor() cursor = conn.cursor()
@@ -1107,7 +1107,7 @@ class SecurityManager:
self, self,
policy_id: str, policy_id: str,
user_id: str, user_id: str,
request_reason: Optional[str] = None, request_reason: str | None = None,
expires_hours: int = 24 expires_hours: int = 24
) -> AccessRequest: ) -> AccessRequest:
"""创建访问请求""" """创建访问请求"""
@@ -1142,7 +1142,7 @@ class SecurityManager:
request_id: str, request_id: str,
approved_by: str, approved_by: str,
expires_hours: int = 24 expires_hours: int = 24
) -> Optional[AccessRequest]: ) -> AccessRequest | None:
"""批准访问请求""" """批准访问请求"""
conn = sqlite3.connect(self.db_path) conn = sqlite3.connect(self.db_path)
cursor = conn.cursor() cursor = conn.cursor()
@@ -1182,7 +1182,7 @@ class SecurityManager:
self, self,
request_id: str, request_id: str,
rejected_by: str rejected_by: str
) -> Optional[AccessRequest]: ) -> AccessRequest | None:
"""拒绝访问请求""" """拒绝访问请求"""
conn = sqlite3.connect(self.db_path) conn = sqlite3.connect(self.db_path)
cursor = conn.cursor() cursor = conn.cursor()

View File

@@ -10,19 +10,19 @@ InsightFlow Phase 8 - 订阅与计费系统模块
作者: InsightFlow Team 作者: InsightFlow Team
""" """
import sqlite3
import json import json
import uuid
from datetime import datetime, timedelta
from typing import Optional, List, Dict, Any
from dataclasses import dataclass
from enum import Enum
import logging import logging
import sqlite3
import uuid
from dataclasses import dataclass
from datetime import datetime, timedelta
from enum import StrEnum
from typing import Any
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class SubscriptionStatus(str, Enum): class SubscriptionStatus(StrEnum):
"""订阅状态""" """订阅状态"""
ACTIVE = "active" # 活跃 ACTIVE = "active" # 活跃
CANCELLED = "cancelled" # 已取消 CANCELLED = "cancelled" # 已取消
@@ -32,7 +32,7 @@ class SubscriptionStatus(str, Enum):
PENDING = "pending" # 待支付 PENDING = "pending" # 待支付
class PaymentProvider(str, Enum): class PaymentProvider(StrEnum):
"""支付提供商""" """支付提供商"""
STRIPE = "stripe" # Stripe STRIPE = "stripe" # Stripe
ALIPAY = "alipay" # 支付宝 ALIPAY = "alipay" # 支付宝
@@ -40,7 +40,7 @@ class PaymentProvider(str, Enum):
BANK_TRANSFER = "bank_transfer" # 银行转账 BANK_TRANSFER = "bank_transfer" # 银行转账
class PaymentStatus(str, Enum): class PaymentStatus(StrEnum):
"""支付状态""" """支付状态"""
PENDING = "pending" # 待支付 PENDING = "pending" # 待支付
PROCESSING = "processing" # 处理中 PROCESSING = "processing" # 处理中
@@ -50,7 +50,7 @@ class PaymentStatus(str, Enum):
PARTIAL_REFUNDED = "partial_refunded" # 部分退款 PARTIAL_REFUNDED = "partial_refunded" # 部分退款
class InvoiceStatus(str, Enum): class InvoiceStatus(StrEnum):
"""发票状态""" """发票状态"""
DRAFT = "draft" # 草稿 DRAFT = "draft" # 草稿
ISSUED = "issued" # 已开具 ISSUED = "issued" # 已开具
@@ -60,7 +60,7 @@ class InvoiceStatus(str, Enum):
CREDIT_NOTE = "credit_note" # 贷项通知单 CREDIT_NOTE = "credit_note" # 贷项通知单
class RefundStatus(str, Enum): class RefundStatus(StrEnum):
"""退款状态""" """退款状态"""
PENDING = "pending" # 待处理 PENDING = "pending" # 待处理
APPROVED = "approved" # 已批准 APPROVED = "approved" # 已批准
@@ -79,12 +79,12 @@ class SubscriptionPlan:
price_monthly: float # 月付价格 price_monthly: float # 月付价格
price_yearly: float # 年付价格 price_yearly: float # 年付价格
currency: str # CNY/USD currency: str # CNY/USD
features: List[str] # 功能列表 features: list[str] # 功能列表
limits: Dict[str, Any] # 资源限制 limits: dict[str, Any] # 资源限制
is_active: bool is_active: bool
created_at: datetime created_at: datetime
updated_at: datetime updated_at: datetime
metadata: Dict[str, Any] metadata: dict[str, Any]
@dataclass @dataclass
@@ -97,14 +97,14 @@ class Subscription:
current_period_start: datetime current_period_start: datetime
current_period_end: datetime current_period_end: datetime
cancel_at_period_end: bool cancel_at_period_end: bool
canceled_at: Optional[datetime] canceled_at: datetime | None
trial_start: Optional[datetime] trial_start: datetime | None
trial_end: Optional[datetime] trial_end: datetime | None
payment_provider: Optional[str] payment_provider: str | None
provider_subscription_id: Optional[str] # 支付提供商的订阅ID provider_subscription_id: str | None # 支付提供商的订阅ID
created_at: datetime created_at: datetime
updated_at: datetime updated_at: datetime
metadata: Dict[str, Any] metadata: dict[str, Any]
@dataclass @dataclass
@@ -117,8 +117,8 @@ class UsageRecord:
unit: str # minutes/mb/count unit: str # minutes/mb/count
recorded_at: datetime recorded_at: datetime
cost: float # 费用 cost: float # 费用
description: Optional[str] description: str | None
metadata: Dict[str, Any] metadata: dict[str, Any]
@dataclass @dataclass
@@ -126,18 +126,18 @@ class Payment:
"""支付记录数据类""" """支付记录数据类"""
id: str id: str
tenant_id: str tenant_id: str
subscription_id: Optional[str] subscription_id: str | None
invoice_id: Optional[str] invoice_id: str | None
amount: float amount: float
currency: str currency: str
provider: str provider: str
provider_payment_id: Optional[str] provider_payment_id: str | None
status: str status: str
payment_method: Optional[str] payment_method: str | None
payment_details: Dict[str, Any] payment_details: dict[str, Any]
paid_at: Optional[datetime] paid_at: datetime | None
failed_at: Optional[datetime] failed_at: datetime | None
failure_reason: Optional[str] failure_reason: str | None
created_at: datetime created_at: datetime
updated_at: datetime updated_at: datetime
@@ -147,7 +147,7 @@ class Invoice:
"""发票数据类""" """发票数据类"""
id: str id: str
tenant_id: str tenant_id: str
subscription_id: Optional[str] subscription_id: str | None
invoice_number: str invoice_number: str
status: str status: str
amount_due: float amount_due: float
@@ -156,11 +156,11 @@ class Invoice:
period_start: datetime period_start: datetime
period_end: datetime period_end: datetime
description: str description: str
line_items: List[Dict[str, Any]] line_items: list[dict[str, Any]]
due_date: datetime due_date: datetime
paid_at: Optional[datetime] paid_at: datetime | None
voided_at: Optional[datetime] voided_at: datetime | None
void_reason: Optional[str] void_reason: str | None
created_at: datetime created_at: datetime
updated_at: datetime updated_at: datetime
@@ -171,18 +171,18 @@ class Refund:
id: str id: str
tenant_id: str tenant_id: str
payment_id: str payment_id: str
invoice_id: Optional[str] invoice_id: str | None
amount: float amount: float
currency: str currency: str
reason: str reason: str
status: str status: str
requested_by: str requested_by: str
requested_at: datetime requested_at: datetime
approved_by: Optional[str] approved_by: str | None
approved_at: Optional[str] approved_at: str | None
completed_at: Optional[datetime] completed_at: datetime | None
provider_refund_id: Optional[str] provider_refund_id: str | None
metadata: Dict[str, Any] metadata: dict[str, Any]
created_at: datetime created_at: datetime
updated_at: datetime updated_at: datetime
@@ -199,7 +199,7 @@ class BillingHistory:
reference_id: str # 关联的订阅/支付/退款ID reference_id: str # 关联的订阅/支付/退款ID
balance_after: float # 操作后余额 balance_after: float # 操作后余额
created_at: datetime created_at: datetime
metadata: Dict[str, Any] metadata: dict[str, Any]
class SubscriptionManager: class SubscriptionManager:
@@ -542,7 +542,7 @@ class SubscriptionManager:
# ==================== 订阅计划管理 ==================== # ==================== 订阅计划管理 ====================
def get_plan(self, plan_id: str) -> Optional[SubscriptionPlan]: def get_plan(self, plan_id: str) -> SubscriptionPlan | None:
"""获取订阅计划""" """获取订阅计划"""
conn = self._get_connection() conn = self._get_connection()
try: try:
@@ -557,7 +557,7 @@ class SubscriptionManager:
finally: finally:
conn.close() conn.close()
def get_plan_by_tier(self, tier: str) -> Optional[SubscriptionPlan]: def get_plan_by_tier(self, tier: str) -> SubscriptionPlan | None:
"""通过层级获取订阅计划""" """通过层级获取订阅计划"""
conn = self._get_connection() conn = self._get_connection()
try: try:
@@ -572,7 +572,7 @@ class SubscriptionManager:
finally: finally:
conn.close() conn.close()
def list_plans(self, include_inactive: bool = False) -> List[SubscriptionPlan]: def list_plans(self, include_inactive: bool = False) -> list[SubscriptionPlan]:
"""列出所有订阅计划""" """列出所有订阅计划"""
conn = self._get_connection() conn = self._get_connection()
try: try:
@@ -591,8 +591,8 @@ class SubscriptionManager:
def create_plan(self, name: str, tier: str, description: str, def create_plan(self, name: str, tier: str, description: str,
price_monthly: float, price_yearly: float, price_monthly: float, price_yearly: float,
currency: str = "CNY", features: List[str] = None, currency: str = "CNY", features: list[str] = None,
limits: Dict[str, Any] = None) -> SubscriptionPlan: limits: dict[str, Any] = None) -> SubscriptionPlan:
"""创建新订阅计划""" """创建新订阅计划"""
conn = self._get_connection() conn = self._get_connection()
try: try:
@@ -639,7 +639,7 @@ class SubscriptionManager:
finally: finally:
conn.close() conn.close()
def update_plan(self, plan_id: str, **kwargs) -> Optional[SubscriptionPlan]: def update_plan(self, plan_id: str, **kwargs) -> SubscriptionPlan | None:
"""更新订阅计划""" """更新订阅计划"""
conn = self._get_connection() conn = self._get_connection()
try: try:
@@ -685,7 +685,7 @@ class SubscriptionManager:
# ==================== 订阅管理 ==================== # ==================== 订阅管理 ====================
def create_subscription(self, tenant_id: str, plan_id: str, def create_subscription(self, tenant_id: str, plan_id: str,
payment_provider: Optional[str] = None, payment_provider: str | None = None,
trial_days: int = 0, trial_days: int = 0,
billing_cycle: str = "monthly") -> Subscription: billing_cycle: str = "monthly") -> Subscription:
"""创建新订阅""" """创建新订阅"""
@@ -785,7 +785,7 @@ class SubscriptionManager:
finally: finally:
conn.close() conn.close()
def get_subscription(self, subscription_id: str) -> Optional[Subscription]: def get_subscription(self, subscription_id: str) -> Subscription | None:
"""获取订阅信息""" """获取订阅信息"""
conn = self._get_connection() conn = self._get_connection()
try: try:
@@ -800,7 +800,7 @@ class SubscriptionManager:
finally: finally:
conn.close() conn.close()
def get_tenant_subscription(self, tenant_id: str) -> Optional[Subscription]: def get_tenant_subscription(self, tenant_id: str) -> Subscription | None:
"""获取租户的当前订阅""" """获取租户的当前订阅"""
conn = self._get_connection() conn = self._get_connection()
try: try:
@@ -819,7 +819,7 @@ class SubscriptionManager:
finally: finally:
conn.close() conn.close()
def update_subscription(self, subscription_id: str, **kwargs) -> Optional[Subscription]: def update_subscription(self, subscription_id: str, **kwargs) -> Subscription | None:
"""更新订阅""" """更新订阅"""
conn = self._get_connection() conn = self._get_connection()
try: try:
@@ -862,7 +862,7 @@ class SubscriptionManager:
conn.close() conn.close()
def cancel_subscription(self, subscription_id: str, def cancel_subscription(self, subscription_id: str,
at_period_end: bool = True) -> Optional[Subscription]: at_period_end: bool = True) -> Subscription | None:
"""取消订阅""" """取消订阅"""
conn = self._get_connection() conn = self._get_connection()
try: try:
@@ -904,7 +904,7 @@ class SubscriptionManager:
conn.close() conn.close()
def change_plan(self, subscription_id: str, new_plan_id: str, def change_plan(self, subscription_id: str, new_plan_id: str,
prorate: bool = True) -> Optional[Subscription]: prorate: bool = True) -> Subscription | None:
"""更改订阅计划""" """更改订阅计划"""
conn = self._get_connection() conn = self._get_connection()
try: try:
@@ -950,8 +950,8 @@ class SubscriptionManager:
def record_usage(self, tenant_id: str, resource_type: str, def record_usage(self, tenant_id: str, resource_type: str,
quantity: float, unit: str, quantity: float, unit: str,
description: Optional[str] = None, description: str | None = None,
metadata: Optional[Dict] = None) -> UsageRecord: metadata: dict | None = None) -> UsageRecord:
"""记录用量""" """记录用量"""
conn = self._get_connection() conn = self._get_connection()
try: try:
@@ -989,8 +989,8 @@ class SubscriptionManager:
conn.close() conn.close()
def get_usage_summary(self, tenant_id: str, def get_usage_summary(self, tenant_id: str,
start_date: Optional[datetime] = None, start_date: datetime | None = None,
end_date: Optional[datetime] = None) -> Dict[str, Any]: end_date: datetime | None = None) -> dict[str, Any]:
"""获取用量汇总""" """获取用量汇总"""
conn = self._get_connection() conn = self._get_connection()
try: try:
@@ -1061,10 +1061,10 @@ class SubscriptionManager:
# ==================== 支付管理 ==================== # ==================== 支付管理 ====================
def create_payment(self, tenant_id: str, amount: float, currency: str, def create_payment(self, tenant_id: str, amount: float, currency: str,
provider: str, subscription_id: Optional[str] = None, provider: str, subscription_id: str | None = None,
invoice_id: Optional[str] = None, invoice_id: str | None = None,
payment_method: Optional[str] = None, payment_method: str | None = None,
payment_details: Optional[Dict] = None) -> Payment: payment_details: dict | None = None) -> Payment:
"""创建支付记录""" """创建支付记录"""
conn = self._get_connection() conn = self._get_connection()
try: try:
@@ -1113,7 +1113,7 @@ class SubscriptionManager:
conn.close() conn.close()
def confirm_payment(self, payment_id: str, def confirm_payment(self, payment_id: str,
provider_payment_id: Optional[str] = None) -> Optional[Payment]: provider_payment_id: str | None = None) -> Payment | None:
"""确认支付完成""" """确认支付完成"""
conn = self._get_connection() conn = self._get_connection()
try: try:
@@ -1160,7 +1160,7 @@ class SubscriptionManager:
finally: finally:
conn.close() conn.close()
def fail_payment(self, payment_id: str, reason: str) -> Optional[Payment]: def fail_payment(self, payment_id: str, reason: str) -> Payment | None:
"""标记支付失败""" """标记支付失败"""
conn = self._get_connection() conn = self._get_connection()
try: try:
@@ -1179,7 +1179,7 @@ class SubscriptionManager:
finally: finally:
conn.close() conn.close()
def get_payment(self, payment_id: str) -> Optional[Payment]: def get_payment(self, payment_id: str) -> Payment | None:
"""获取支付记录""" """获取支付记录"""
conn = self._get_connection() conn = self._get_connection()
try: try:
@@ -1187,8 +1187,8 @@ class SubscriptionManager:
finally: finally:
conn.close() conn.close()
def list_payments(self, tenant_id: str, status: Optional[str] = None, def list_payments(self, tenant_id: str, status: str | None = None,
limit: int = 100, offset: int = 0) -> List[Payment]: limit: int = 100, offset: int = 0) -> list[Payment]:
"""列出支付记录""" """列出支付记录"""
conn = self._get_connection() conn = self._get_connection()
try: try:
@@ -1212,7 +1212,7 @@ class SubscriptionManager:
finally: finally:
conn.close() conn.close()
def _get_payment_internal(self, conn: sqlite3.Connection, payment_id: str) -> Optional[Payment]: def _get_payment_internal(self, conn: sqlite3.Connection, payment_id: str) -> Payment | None:
"""内部方法:获取支付记录""" """内部方法:获取支付记录"""
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute("SELECT * FROM payments WHERE id = ?", (payment_id,)) cursor.execute("SELECT * FROM payments WHERE id = ?", (payment_id,))
@@ -1225,10 +1225,10 @@ class SubscriptionManager:
# ==================== 发票管理 ==================== # ==================== 发票管理 ====================
def _create_invoice_internal(self, conn: sqlite3.Connection, tenant_id: str, def _create_invoice_internal(self, conn: sqlite3.Connection, tenant_id: str,
subscription_id: Optional[str], amount: float, subscription_id: str | None, amount: float,
currency: str, period_start: datetime, currency: str, period_start: datetime,
period_end: datetime, description: str, period_end: datetime, description: str,
line_items: Optional[List[Dict]] = None) -> Invoice: line_items: list[dict] | None = None) -> Invoice:
"""内部方法:创建发票""" """内部方法:创建发票"""
invoice_id = str(uuid.uuid4()) invoice_id = str(uuid.uuid4())
invoice_number = self._generate_invoice_number() invoice_number = self._generate_invoice_number()
@@ -1275,7 +1275,7 @@ class SubscriptionManager:
return invoice return invoice
def get_invoice(self, invoice_id: str) -> Optional[Invoice]: def get_invoice(self, invoice_id: str) -> Invoice | None:
"""获取发票""" """获取发票"""
conn = self._get_connection() conn = self._get_connection()
try: try:
@@ -1290,7 +1290,7 @@ class SubscriptionManager:
finally: finally:
conn.close() conn.close()
def get_invoice_by_number(self, invoice_number: str) -> Optional[Invoice]: def get_invoice_by_number(self, invoice_number: str) -> Invoice | None:
"""通过发票号获取发票""" """通过发票号获取发票"""
conn = self._get_connection() conn = self._get_connection()
try: try:
@@ -1305,8 +1305,8 @@ class SubscriptionManager:
finally: finally:
conn.close() conn.close()
def list_invoices(self, tenant_id: str, status: Optional[str] = None, def list_invoices(self, tenant_id: str, status: str | None = None,
limit: int = 100, offset: int = 0) -> List[Invoice]: limit: int = 100, offset: int = 0) -> list[Invoice]:
"""列出发票""" """列出发票"""
conn = self._get_connection() conn = self._get_connection()
try: try:
@@ -1330,7 +1330,7 @@ class SubscriptionManager:
finally: finally:
conn.close() conn.close()
def void_invoice(self, invoice_id: str, reason: str) -> Optional[Invoice]: def void_invoice(self, invoice_id: str, reason: str) -> Invoice | None:
"""作废发票""" """作废发票"""
conn = self._get_connection() conn = self._get_connection()
try: try:
@@ -1442,7 +1442,7 @@ class SubscriptionManager:
finally: finally:
conn.close() conn.close()
def approve_refund(self, refund_id: str, approved_by: str) -> Optional[Refund]: def approve_refund(self, refund_id: str, approved_by: str) -> Refund | None:
"""批准退款""" """批准退款"""
conn = self._get_connection() conn = self._get_connection()
try: try:
@@ -1469,7 +1469,7 @@ class SubscriptionManager:
conn.close() conn.close()
def complete_refund(self, refund_id: str, def complete_refund(self, refund_id: str,
provider_refund_id: Optional[str] = None) -> Optional[Refund]: provider_refund_id: str | None = None) -> Refund | None:
"""完成退款""" """完成退款"""
conn = self._get_connection() conn = self._get_connection()
try: try:
@@ -1507,7 +1507,7 @@ class SubscriptionManager:
finally: finally:
conn.close() conn.close()
def reject_refund(self, refund_id: str, reason: str) -> Optional[Refund]: def reject_refund(self, refund_id: str, reason: str) -> Refund | None:
"""拒绝退款""" """拒绝退款"""
conn = self._get_connection() conn = self._get_connection()
try: try:
@@ -1530,7 +1530,7 @@ class SubscriptionManager:
finally: finally:
conn.close() conn.close()
def get_refund(self, refund_id: str) -> Optional[Refund]: def get_refund(self, refund_id: str) -> Refund | None:
"""获取退款记录""" """获取退款记录"""
conn = self._get_connection() conn = self._get_connection()
try: try:
@@ -1538,8 +1538,8 @@ class SubscriptionManager:
finally: finally:
conn.close() conn.close()
def list_refunds(self, tenant_id: str, status: Optional[str] = None, def list_refunds(self, tenant_id: str, status: str | None = None,
limit: int = 100, offset: int = 0) -> List[Refund]: limit: int = 100, offset: int = 0) -> list[Refund]:
"""列出退款记录""" """列出退款记录"""
conn = self._get_connection() conn = self._get_connection()
try: try:
@@ -1563,7 +1563,7 @@ class SubscriptionManager:
finally: finally:
conn.close() conn.close()
def _get_refund_internal(self, conn: sqlite3.Connection, refund_id: str) -> Optional[Refund]: def _get_refund_internal(self, conn: sqlite3.Connection, refund_id: str) -> Refund | None:
"""内部方法:获取退款记录""" """内部方法:获取退款记录"""
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute("SELECT * FROM refunds WHERE id = ?", (refund_id,)) cursor.execute("SELECT * FROM refunds WHERE id = ?", (refund_id,))
@@ -1593,9 +1593,9 @@ class SubscriptionManager:
)) ))
def get_billing_history(self, tenant_id: str, def get_billing_history(self, tenant_id: str,
start_date: Optional[datetime] = None, start_date: datetime | None = None,
end_date: Optional[datetime] = None, end_date: datetime | None = None,
limit: int = 100, offset: int = 0) -> List[BillingHistory]: limit: int = 100, offset: int = 0) -> list[BillingHistory]:
"""获取账单历史""" """获取账单历史"""
conn = self._get_connection() conn = self._get_connection()
try: try:
@@ -1626,7 +1626,7 @@ class SubscriptionManager:
def create_stripe_checkout_session(self, tenant_id: str, plan_id: str, def create_stripe_checkout_session(self, tenant_id: str, plan_id: str,
success_url: str, cancel_url: str, success_url: str, cancel_url: str,
billing_cycle: str = "monthly") -> Dict[str, Any]: billing_cycle: str = "monthly") -> dict[str, Any]:
"""创建 Stripe Checkout 会话(占位实现)""" """创建 Stripe Checkout 会话(占位实现)"""
# 这里应该集成 Stripe SDK # 这里应该集成 Stripe SDK
# 简化实现,返回模拟数据 # 简化实现,返回模拟数据
@@ -1638,7 +1638,7 @@ class SubscriptionManager:
} }
def create_alipay_order(self, tenant_id: str, plan_id: str, def create_alipay_order(self, tenant_id: str, plan_id: str,
billing_cycle: str = "monthly") -> Dict[str, Any]: billing_cycle: str = "monthly") -> dict[str, Any]:
"""创建支付宝订单(占位实现)""" """创建支付宝订单(占位实现)"""
# 这里应该集成支付宝 SDK # 这里应该集成支付宝 SDK
plan = self.get_plan(plan_id) plan = self.get_plan(plan_id)
@@ -1654,7 +1654,7 @@ class SubscriptionManager:
} }
def create_wechat_order(self, tenant_id: str, plan_id: str, def create_wechat_order(self, tenant_id: str, plan_id: str,
billing_cycle: str = "monthly") -> Dict[str, Any]: billing_cycle: str = "monthly") -> dict[str, Any]:
"""创建微信支付订单(占位实现)""" """创建微信支付订单(占位实现)"""
# 这里应该集成微信支付 SDK # 这里应该集成微信支付 SDK
plan = self.get_plan(plan_id) plan = self.get_plan(plan_id)
@@ -1669,7 +1669,7 @@ class SubscriptionManager:
"provider": "wechat" "provider": "wechat"
} }
def handle_webhook(self, provider: str, payload: Dict[str, Any]) -> bool: def handle_webhook(self, provider: str, payload: dict[str, Any]) -> bool:
"""处理支付提供商的 Webhook占位实现""" """处理支付提供商的 Webhook占位实现"""
# 这里应该实现实际的 Webhook 处理逻辑 # 这里应该实现实际的 Webhook 处理逻辑
logger.info(f"Received webhook from {provider}: {payload.get('event_type', 'unknown')}") logger.info(f"Received webhook from {provider}: {payload.get('event_type', 'unknown')}")

View File

@@ -11,16 +11,16 @@ InsightFlow Phase 8 - 多租户 SaaS 架构管理模块
作者: InsightFlow Team 作者: InsightFlow Team
""" """
import sqlite3
import json
import uuid
import hashlib import hashlib
import re import json
from datetime import datetime
from typing import Optional, List, Dict, Any, Tuple
from dataclasses import dataclass, asdict
from enum import Enum
import logging import logging
import re
import sqlite3
import uuid
from dataclasses import asdict, dataclass
from datetime import datetime
from enum import StrEnum
from typing import Any
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -44,7 +44,7 @@ class TenantLimits:
UNLIMITED = -1 UNLIMITED = -1
class TenantStatus(str, Enum): class TenantStatus(StrEnum):
"""租户状态""" """租户状态"""
ACTIVE = "active" # 活跃 ACTIVE = "active" # 活跃
SUSPENDED = "suspended" # 暂停 SUSPENDED = "suspended" # 暂停
@@ -53,14 +53,14 @@ class TenantStatus(str, Enum):
PENDING = "pending" # 待激活 PENDING = "pending" # 待激活
class TenantTier(str, Enum): class TenantTier(StrEnum):
"""租户订阅层级""" """租户订阅层级"""
FREE = "free" # 免费版 FREE = "free" # 免费版
PRO = "pro" # 专业版 PRO = "pro" # 专业版
ENTERPRISE = "enterprise" # 企业版 ENTERPRISE = "enterprise" # 企业版
class TenantRole(str, Enum): class TenantRole(StrEnum):
"""租户角色""" """租户角色"""
OWNER = "owner" # 所有者 OWNER = "owner" # 所有者
ADMIN = "admin" # 管理员 ADMIN = "admin" # 管理员
@@ -68,7 +68,7 @@ class TenantRole(str, Enum):
VIEWER = "viewer" # 查看者 VIEWER = "viewer" # 查看者
class DomainStatus(str, Enum): class DomainStatus(StrEnum):
"""域名状态""" """域名状态"""
PENDING = "pending" # 待验证 PENDING = "pending" # 待验证
VERIFIED = "verified" # 已验证 VERIFIED = "verified" # 已验证
@@ -82,16 +82,16 @@ class Tenant:
id: str id: str
name: str name: str
slug: str # URL 友好的唯一标识 slug: str # URL 友好的唯一标识
description: Optional[str] description: str | None
tier: str # free/pro/enterprise tier: str # free/pro/enterprise
status: str # active/suspended/trial/expired/pending status: str # active/suspended/trial/expired/pending
owner_id: str # 所有者用户ID owner_id: str # 所有者用户ID
created_at: datetime created_at: datetime
updated_at: datetime updated_at: datetime
expires_at: Optional[datetime] # 订阅过期时间 expires_at: datetime | None # 订阅过期时间
settings: Dict[str, Any] # 租户级设置 settings: dict[str, Any] # 租户级设置
resource_limits: Dict[str, Any] # 资源限制 resource_limits: dict[str, Any] # 资源限制
metadata: Dict[str, Any] # 元数据 metadata: dict[str, Any] # 元数据
@dataclass @dataclass
@@ -103,12 +103,12 @@ class TenantDomain:
status: str # pending/verified/failed/expired status: str # pending/verified/failed/expired
verification_token: str # 验证令牌 verification_token: str # 验证令牌
verification_method: str # dns/file verification_method: str # dns/file
verified_at: Optional[datetime] verified_at: datetime | None
created_at: datetime created_at: datetime
updated_at: datetime updated_at: datetime
is_primary: bool # 是否主域名 is_primary: bool # 是否主域名
ssl_enabled: bool # SSL 是否启用 ssl_enabled: bool # SSL 是否启用
ssl_expires_at: Optional[datetime] ssl_expires_at: datetime | None
@dataclass @dataclass
@@ -116,14 +116,14 @@ class TenantBranding:
"""租户品牌配置数据类""" """租户品牌配置数据类"""
id: str id: str
tenant_id: str tenant_id: str
logo_url: Optional[str] # Logo URL logo_url: str | None # Logo URL
favicon_url: Optional[str] # Favicon URL favicon_url: str | None # Favicon URL
primary_color: Optional[str] # 主题主色 primary_color: str | None # 主题主色
secondary_color: Optional[str] # 主题次色 secondary_color: str | None # 主题次色
custom_css: Optional[str] # 自定义 CSS custom_css: str | None # 自定义 CSS
custom_js: Optional[str] # 自定义 JS custom_js: str | None # 自定义 JS
login_page_bg: Optional[str] # 登录页背景 login_page_bg: str | None # 登录页背景
email_template: Optional[str] # 邮件模板 email_template: str | None # 邮件模板
created_at: datetime created_at: datetime
updated_at: datetime updated_at: datetime
@@ -136,11 +136,11 @@ class TenantMember:
user_id: str user_id: str
email: str email: str
role: str # owner/admin/member/viewer role: str # owner/admin/member/viewer
permissions: List[str] # 具体权限列表 permissions: list[str] # 具体权限列表
invited_by: Optional[str] # 邀请者 invited_by: str | None # 邀请者
invited_at: datetime invited_at: datetime
joined_at: Optional[datetime] joined_at: datetime | None
last_active_at: Optional[datetime] last_active_at: datetime | None
status: str # active/pending/suspended status: str # active/pending/suspended
@@ -151,10 +151,10 @@ class TenantPermission:
tenant_id: str tenant_id: str
name: str # 权限名称 name: str # 权限名称
code: str # 权限代码 code: str # 权限代码
description: Optional[str] description: str | None
resource_type: str # project/entity/api/etc resource_type: str # project/entity/api/etc
actions: List[str] # create/read/update/delete/etc actions: list[str] # create/read/update/delete/etc
conditions: Optional[Dict] # 条件限制 conditions: dict | None # 条件限制
created_at: datetime created_at: datetime
@@ -381,8 +381,8 @@ class TenantManager:
def create_tenant(self, name: str, owner_id: str, def create_tenant(self, name: str, owner_id: str,
tier: str = "free", tier: str = "free",
description: Optional[str] = None, description: str | None = None,
settings: Optional[Dict] = None) -> Tenant: settings: dict | None = None) -> Tenant:
"""创建新租户""" """创建新租户"""
conn = self._get_connection() conn = self._get_connection()
try: try:
@@ -436,7 +436,7 @@ class TenantManager:
finally: finally:
conn.close() conn.close()
def get_tenant(self, tenant_id: str) -> Optional[Tenant]: def get_tenant(self, tenant_id: str) -> Tenant | None:
"""获取租户信息""" """获取租户信息"""
conn = self._get_connection() conn = self._get_connection()
try: try:
@@ -451,7 +451,7 @@ class TenantManager:
finally: finally:
conn.close() conn.close()
def get_tenant_by_slug(self, slug: str) -> Optional[Tenant]: def get_tenant_by_slug(self, slug: str) -> Tenant | None:
"""通过 slug 获取租户""" """通过 slug 获取租户"""
conn = self._get_connection() conn = self._get_connection()
try: try:
@@ -466,7 +466,7 @@ class TenantManager:
finally: finally:
conn.close() conn.close()
def get_tenant_by_domain(self, domain: str) -> Optional[Tenant]: def get_tenant_by_domain(self, domain: str) -> Tenant | None:
"""通过自定义域名获取租户""" """通过自定义域名获取租户"""
conn = self._get_connection() conn = self._get_connection()
try: try:
@@ -486,11 +486,11 @@ class TenantManager:
conn.close() conn.close()
def update_tenant(self, tenant_id: str, def update_tenant(self, tenant_id: str,
name: Optional[str] = None, name: str | None = None,
description: Optional[str] = None, description: str | None = None,
tier: Optional[str] = None, tier: str | None = None,
status: Optional[str] = None, status: str | None = None,
settings: Optional[Dict] = None) -> Optional[Tenant]: settings: dict | None = None) -> Tenant | None:
"""更新租户信息""" """更新租户信息"""
conn = self._get_connection() conn = self._get_connection()
try: try:
@@ -548,9 +548,9 @@ class TenantManager:
finally: finally:
conn.close() conn.close()
def list_tenants(self, status: Optional[str] = None, def list_tenants(self, status: str | None = None,
tier: Optional[str] = None, tier: str | None = None,
limit: int = 100, offset: int = 0) -> List[Tenant]: limit: int = 100, offset: int = 0) -> list[Tenant]:
"""列出租户""" """列出租户"""
conn = self._get_connection() conn = self._get_connection()
try: try:
@@ -689,7 +689,7 @@ class TenantManager:
finally: finally:
conn.close() conn.close()
def get_domain_verification_instructions(self, domain_id: str) -> Dict[str, Any]: def get_domain_verification_instructions(self, domain_id: str) -> dict[str, Any]:
"""获取域名验证指导""" """获取域名验证指导"""
conn = self._get_connection() conn = self._get_connection()
try: try:
@@ -739,7 +739,7 @@ class TenantManager:
finally: finally:
conn.close() conn.close()
def list_domains(self, tenant_id: str) -> List[TenantDomain]: def list_domains(self, tenant_id: str) -> list[TenantDomain]:
"""列出租户的所有域名""" """列出租户的所有域名"""
conn = self._get_connection() conn = self._get_connection()
try: try:
@@ -758,7 +758,7 @@ class TenantManager:
# ==================== 品牌白标管理 ==================== # ==================== 品牌白标管理 ====================
def get_branding(self, tenant_id: str) -> Optional[TenantBranding]: def get_branding(self, tenant_id: str) -> TenantBranding | None:
"""获取租户品牌配置""" """获取租户品牌配置"""
conn = self._get_connection() conn = self._get_connection()
try: try:
@@ -774,14 +774,14 @@ class TenantManager:
conn.close() conn.close()
def update_branding(self, tenant_id: str, def update_branding(self, tenant_id: str,
logo_url: Optional[str] = None, logo_url: str | None = None,
favicon_url: Optional[str] = None, favicon_url: str | None = None,
primary_color: Optional[str] = None, primary_color: str | None = None,
secondary_color: Optional[str] = None, secondary_color: str | None = None,
custom_css: Optional[str] = None, custom_css: str | None = None,
custom_js: Optional[str] = None, custom_js: str | None = None,
login_page_bg: Optional[str] = None, login_page_bg: str | None = None,
email_template: Optional[str] = None) -> TenantBranding: email_template: str | None = None) -> TenantBranding:
"""更新租户品牌配置""" """更新租户品牌配置"""
conn = self._get_connection() conn = self._get_connection()
try: try:
@@ -890,7 +890,7 @@ class TenantManager:
# ==================== 成员与权限管理 ==================== # ==================== 成员与权限管理 ====================
def invite_member(self, tenant_id: str, email: str, role: str, def invite_member(self, tenant_id: str, email: str, role: str,
invited_by: str, permissions: Optional[List[str]] = None) -> TenantMember: invited_by: str, permissions: list[str] | None = None) -> TenantMember:
"""邀请成员加入租户""" """邀请成员加入租户"""
conn = self._get_connection() conn = self._get_connection()
try: try:
@@ -967,7 +967,7 @@ class TenantManager:
conn.close() conn.close()
def update_member_role(self, tenant_id: str, member_id: str, def update_member_role(self, tenant_id: str, member_id: str,
role: str, permissions: Optional[List[str]] = None) -> bool: role: str, permissions: list[str] | None = None) -> bool:
"""更新成员角色""" """更新成员角色"""
conn = self._get_connection() conn = self._get_connection()
try: try:
@@ -988,7 +988,7 @@ class TenantManager:
finally: finally:
conn.close() conn.close()
def list_members(self, tenant_id: str, status: Optional[str] = None) -> List[TenantMember]: def list_members(self, tenant_id: str, status: str | None = None) -> list[TenantMember]:
"""列出租户成员""" """列出租户成员"""
conn = self._get_connection() conn = self._get_connection()
try: try:
@@ -1042,7 +1042,7 @@ class TenantManager:
finally: finally:
conn.close() conn.close()
def get_user_tenants(self, user_id: str) -> List[Dict[str, Any]]: def get_user_tenants(self, user_id: str) -> list[dict[str, Any]]:
"""获取用户所属的所有租户""" """获取用户所属的所有租户"""
conn = self._get_connection() conn = self._get_connection()
try: try:
@@ -1108,8 +1108,8 @@ class TenantManager:
conn.close() conn.close()
def get_usage_stats(self, tenant_id: str, def get_usage_stats(self, tenant_id: str,
start_date: Optional[datetime] = None, start_date: datetime | None = None,
end_date: Optional[datetime] = None) -> Dict[str, Any]: end_date: datetime | None = None) -> dict[str, Any]:
"""获取使用统计""" """获取使用统计"""
conn = self._get_connection() conn = self._get_connection()
try: try:
@@ -1165,7 +1165,7 @@ class TenantManager:
finally: finally:
conn.close() conn.close()
def check_resource_limit(self, tenant_id: str, resource_type: str) -> Tuple[bool, int, int]: def check_resource_limit(self, tenant_id: str, resource_type: str) -> tuple[bool, int, int]:
"""检查资源是否超限 """检查资源是否超限
Returns: Returns:
@@ -1288,7 +1288,7 @@ class TenantManager:
def _add_member_internal(self, conn: sqlite3.Connection, tenant_id: str, def _add_member_internal(self, conn: sqlite3.Connection, tenant_id: str,
user_id: str, email: str, role: TenantRole, user_id: str, email: str, role: TenantRole,
invited_by: Optional[str]): invited_by: str | None):
"""内部方法:添加成员""" """内部方法:添加成员"""
cursor = conn.cursor() cursor = conn.cursor()
member_id = str(uuid.uuid4()) member_id = str(uuid.uuid4())
@@ -1416,8 +1416,8 @@ class TenantManager:
class TenantContext: class TenantContext:
"""租户上下文管理器 - 用于请求级别的租户隔离""" """租户上下文管理器 - 用于请求级别的租户隔离"""
_current_tenant_id: Optional[str] = None _current_tenant_id: str | None = None
_current_user_id: Optional[str] = None _current_user_id: str | None = None
@classmethod @classmethod
def set_current_tenant(cls, tenant_id: str): def set_current_tenant(cls, tenant_id: str):
@@ -1425,7 +1425,7 @@ class TenantContext:
cls._current_tenant_id = tenant_id cls._current_tenant_id = tenant_id
@classmethod @classmethod
def get_current_tenant(cls) -> Optional[str]: def get_current_tenant(cls) -> str | None:
"""获取当前租户ID""" """获取当前租户ID"""
return cls._current_tenant_id return cls._current_tenant_id
@@ -1435,7 +1435,7 @@ class TenantContext:
cls._current_user_id = user_id cls._current_user_id = user_id
@classmethod @classmethod
def get_current_user(cls) -> Optional[str]: def get_current_user(cls) -> str | None:
"""获取当前用户ID""" """获取当前用户ID"""
return cls._current_user_id return cls._current_user_id

View File

@@ -4,8 +4,8 @@ InsightFlow Multimodal Module Test Script
测试多模态支持模块 测试多模态支持模块
""" """
import sys
import os import os
import sys
# 添加 backend 目录到路径 # 添加 backend 目录到路径
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
@@ -18,25 +18,19 @@ print("=" * 60)
print("\n1. 测试模块导入...") print("\n1. 测试模块导入...")
try: try:
from multimodal_processor import ( from multimodal_processor import get_multimodal_processor
get_multimodal_processor
)
print(" ✓ multimodal_processor 导入成功") print(" ✓ multimodal_processor 导入成功")
except ImportError as e: except ImportError as e:
print(f" ✗ multimodal_processor 导入失败: {e}") print(f" ✗ multimodal_processor 导入失败: {e}")
try: try:
from image_processor import ( from image_processor import get_image_processor
get_image_processor
)
print(" ✓ image_processor 导入成功") print(" ✓ image_processor 导入成功")
except ImportError as e: except ImportError as e:
print(f" ✗ image_processor 导入失败: {e}") print(f" ✗ image_processor 导入失败: {e}")
try: try:
from multimodal_entity_linker import ( from multimodal_entity_linker import get_multimodal_entity_linker
get_multimodal_entity_linker
)
print(" ✓ multimodal_entity_linker 导入成功") print(" ✓ multimodal_entity_linker 导入成功")
except ImportError as e: except ImportError as e:
print(f" ✗ multimodal_entity_linker 导入失败: {e}") print(f" ✗ multimodal_entity_linker 导入失败: {e}")

View File

@@ -4,19 +4,19 @@ InsightFlow Phase 7 Task 6 & 8 测试脚本
测试高级搜索与发现、性能优化与扩展功能 测试高级搜索与发现、性能优化与扩展功能
""" """
from performance_manager import (
get_performance_manager, CacheManager,
TaskQueue, PerformanceMonitor
)
from search_manager import (
get_search_manager, FullTextSearch,
SemanticSearch, EntityPathDiscovery,
KnowledgeGapDetection
)
import os import os
import sys import sys
import time import time
from performance_manager import CacheManager, PerformanceMonitor, TaskQueue, get_performance_manager
from search_manager import (
EntityPathDiscovery,
FullTextSearch,
KnowledgeGapDetection,
SemanticSearch,
get_search_manager,
)
# 添加 backend 到路径 # 添加 backend 到路径
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

View File

@@ -10,11 +10,11 @@ InsightFlow Phase 8 Task 1 - 多租户 SaaS 架构测试脚本
5. 资源使用统计 5. 资源使用统计
""" """
from tenant_manager import (
get_tenant_manager
)
import sys
import os import os
import sys
from tenant_manager import get_tenant_manager
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

View File

@@ -3,13 +3,12 @@
InsightFlow Phase 8 Task 2 测试脚本 - 订阅与计费系统 InsightFlow Phase 8 Task 2 测试脚本 - 订阅与计费系统
""" """
from subscription_manager import (
SubscriptionManager, PaymentProvider
)
import sys
import os import os
import sys
import tempfile import tempfile
from subscription_manager import PaymentProvider, SubscriptionManager
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

View File

@@ -4,12 +4,11 @@ InsightFlow Phase 8 Task 4 测试脚本
测试 AI 能力增强功能 测试 AI 能力增强功能
""" """
from ai_manager import (
get_ai_manager, ModelType, PredictionType
)
import asyncio import asyncio
import sys
import os import os
import sys
from ai_manager import ModelType, PredictionType, get_ai_manager
# Add backend directory to path # Add backend directory to path
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

View File

@@ -13,14 +13,20 @@ InsightFlow Phase 8 Task 5 - 运营与增长工具测试脚本
python test_phase8_task5.py python test_phase8_task5.py
""" """
from growth_manager import (
GrowthManager, EventType, ExperimentStatus, TrafficAllocationType, EmailTemplateType, WorkflowTriggerType
)
import asyncio import asyncio
import sys
import os import os
import sys
from datetime import datetime, timedelta from datetime import datetime, timedelta
from growth_manager import (
EmailTemplateType,
EventType,
ExperimentStatus,
GrowthManager,
TrafficAllocationType,
WorkflowTriggerType,
)
# 添加 backend 目录到路径 # 添加 backend 目录到路径
backend_dir = os.path.dirname(os.path.abspath(__file__)) backend_dir = os.path.dirname(os.path.abspath(__file__))
if backend_dir not in sys.path: if backend_dir not in sys.path:

View File

@@ -10,17 +10,20 @@ InsightFlow Phase 8 Task 6: Developer Ecosystem Test Script
4. 开发者文档与示例代码 4. 开发者文档与示例代码
""" """
from developer_ecosystem_manager import (
DeveloperEcosystemManager,
SDKLanguage, TemplateCategory,
PluginCategory, PluginStatus,
DeveloperStatus
)
import sys
import os import os
import sys
import uuid import uuid
from datetime import datetime from datetime import datetime
from developer_ecosystem_manager import (
DeveloperEcosystemManager,
DeveloperStatus,
PluginCategory,
PluginStatus,
SDKLanguage,
TemplateCategory,
)
# Add backend directory to path # Add backend directory to path
backend_dir = os.path.dirname(os.path.abspath(__file__)) backend_dir = os.path.dirname(os.path.abspath(__file__))
if backend_dir not in sys.path: if backend_dir not in sys.path:

View File

@@ -10,15 +10,20 @@ InsightFlow Phase 8 Task 8: Operations & Monitoring Test Script
4. 成本优化 4. 成本优化
""" """
from ops_manager import ( import json
get_ops_manager, AlertSeverity, AlertStatus, AlertChannelType, AlertRuleType,
ResourceType
)
import os import os
import sys import sys
import json
from datetime import datetime, timedelta from datetime import datetime, timedelta
from ops_manager import (
AlertChannelType,
AlertRuleType,
AlertSeverity,
AlertStatus,
ResourceType,
get_ops_manager,
)
# Add backend directory to path # Add backend directory to path
backend_dir = os.path.dirname(os.path.abspath(__file__)) backend_dir = os.path.dirname(os.path.abspath(__file__))
if backend_dir not in sys.path: if backend_dir not in sys.path:

View File

@@ -6,7 +6,7 @@
import os import os
import time import time
from datetime import datetime from datetime import datetime
from typing import Dict, Any from typing import Any
class TingwuClient: class TingwuClient:
@@ -18,7 +18,7 @@ class TingwuClient:
if not self.access_key or not self.secret_key: if not self.access_key or not self.secret_key:
raise ValueError("ALI_ACCESS_KEY and ALI_SECRET_KEY required") raise ValueError("ALI_ACCESS_KEY and ALI_SECRET_KEY required")
def _sign_request(self, method: str, uri: str, query: str = "", body: str = "") -> Dict[str, str]: def _sign_request(self, method: str, uri: str, query: str = "", body: str = "") -> dict[str, str]:
"""阿里云签名 V3""" """阿里云签名 V3"""
timestamp = datetime.utcnow().strftime('%Y-%m-%dT%H:%M:%SZ') timestamp = datetime.utcnow().strftime('%Y-%m-%dT%H:%M:%SZ')
@@ -35,9 +35,9 @@ class TingwuClient:
def create_task(self, audio_url: str, language: str = "zh") -> str: def create_task(self, audio_url: str, language: str = "zh") -> str:
"""创建听悟任务""" """创建听悟任务"""
try: try:
from alibabacloud_tea_openapi import models as open_api_models
from alibabacloud_tingwu20230930 import models as tingwu_models from alibabacloud_tingwu20230930 import models as tingwu_models
from alibabacloud_tingwu20230930.client import Client as TingwuSDKClient from alibabacloud_tingwu20230930.client import Client as TingwuSDKClient
from alibabacloud_tea_openapi import models as open_api_models
config = open_api_models.Config( config = open_api_models.Config(
access_key_id=self.access_key, access_key_id=self.access_key,
@@ -74,12 +74,12 @@ class TingwuClient:
print(f"Tingwu API error: {e}") print(f"Tingwu API error: {e}")
return f"mock_task_{int(time.time())}" return f"mock_task_{int(time.time())}"
def get_task_result(self, task_id: str, max_retries: int = 60, interval: int = 5) -> Dict[str, Any]: def get_task_result(self, task_id: str, max_retries: int = 60, interval: int = 5) -> dict[str, Any]:
"""获取任务结果""" """获取任务结果"""
try: try:
from alibabacloud_tea_openapi import models as open_api_models
from alibabacloud_tingwu20230930 import models as tingwu_models from alibabacloud_tingwu20230930 import models as tingwu_models
from alibabacloud_tingwu20230930.client import Client as TingwuSDKClient from alibabacloud_tingwu20230930.client import Client as TingwuSDKClient
from alibabacloud_tea_openapi import models as open_api_models
config = open_api_models.Config( config = open_api_models.Config(
access_key_id=self.access_key, access_key_id=self.access_key,
@@ -114,7 +114,7 @@ class TingwuClient:
raise TimeoutError(f"Task {task_id} timeout") raise TimeoutError(f"Task {task_id} timeout")
def _parse_result(self, data) -> Dict[str, Any]: def _parse_result(self, data) -> dict[str, Any]:
"""解析结果""" """解析结果"""
result = data.result result = data.result
transcription = result.transcription transcription = result.transcription
@@ -140,7 +140,7 @@ class TingwuClient:
"segments": segments "segments": segments
} }
def _mock_result(self) -> Dict[str, Any]: def _mock_result(self) -> dict[str, Any]:
"""Mock 结果""" """Mock 结果"""
return { return {
"full_text": "这是一个示例转录文本,包含 Project Alpha 和 K8s 等术语。", "full_text": "这是一个示例转录文本,包含 Project Alpha 和 K8s 等术语。",
@@ -149,7 +149,7 @@ class TingwuClient:
] ]
} }
def transcribe(self, audio_url: str, language: str = "zh") -> Dict[str, Any]: def transcribe(self, audio_url: str, language: str = "zh") -> dict[str, Any]:
"""一键转录""" """一键转录"""
task_id = self.create_task(audio_url, language) task_id = self.create_task(audio_url, language)
print(f"Tingwu task: {task_id}") print(f"Tingwu task: {task_id}")

View File

@@ -9,20 +9,21 @@ InsightFlow Workflow Manager - Phase 7
- 工作流配置管理 - 工作流配置管理
""" """
import json
import uuid
import asyncio import asyncio
import httpx import json
import logging import logging
from datetime import datetime, timedelta import uuid
from typing import List, Dict, Optional, Callable, Any from collections.abc import Callable
from dataclasses import dataclass, field from dataclasses import dataclass, field
from datetime import datetime, timedelta
from enum import Enum from enum import Enum
from typing import Any
import httpx
from apscheduler.events import EVENT_JOB_ERROR, EVENT_JOB_EXECUTED
from apscheduler.schedulers.asyncio import AsyncIOScheduler from apscheduler.schedulers.asyncio import AsyncIOScheduler
from apscheduler.triggers.cron import CronTrigger from apscheduler.triggers.cron import CronTrigger
from apscheduler.triggers.interval import IntervalTrigger from apscheduler.triggers.interval import IntervalTrigger
from apscheduler.events import EVENT_JOB_EXECUTED, EVENT_JOB_ERROR
# Configure logging # Configure logging
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
@@ -70,9 +71,9 @@ class WorkflowTask:
workflow_id: str workflow_id: str
name: str name: str
task_type: str # analyze, align, discover_relations, notify, custom task_type: str # analyze, align, discover_relations, notify, custom
config: Dict = field(default_factory=dict) config: dict = field(default_factory=dict)
order: int = 0 order: int = 0
depends_on: List[str] = field(default_factory=list) depends_on: list[str] = field(default_factory=list)
timeout_seconds: int = 300 timeout_seconds: int = 300
retry_count: int = 3 retry_count: int = 3
retry_delay: int = 5 retry_delay: int = 5
@@ -94,12 +95,12 @@ class WebhookConfig:
webhook_type: str # feishu, dingtalk, slack, custom webhook_type: str # feishu, dingtalk, slack, custom
url: str url: str
secret: str = "" # 用于签名验证 secret: str = "" # 用于签名验证
headers: Dict = field(default_factory=dict) headers: dict = field(default_factory=dict)
template: str = "" # 消息模板 template: str = "" # 消息模板
is_active: bool = True is_active: bool = True
created_at: str = "" created_at: str = ""
updated_at: str = "" updated_at: str = ""
last_used_at: Optional[str] = None last_used_at: str | None = None
success_count: int = 0 success_count: int = 0
fail_count: int = 0 fail_count: int = 0
@@ -119,15 +120,15 @@ class Workflow:
workflow_type: str workflow_type: str
project_id: str project_id: str
status: str = "active" status: str = "active"
schedule: Optional[str] = None # cron expression or interval schedule: str | None = None # cron expression or interval
schedule_type: str = "manual" # manual, cron, interval schedule_type: str = "manual" # manual, cron, interval
config: Dict = field(default_factory=dict) config: dict = field(default_factory=dict)
webhook_ids: List[str] = field(default_factory=list) webhook_ids: list[str] = field(default_factory=list)
is_active: bool = True is_active: bool = True
created_at: str = "" created_at: str = ""
updated_at: str = "" updated_at: str = ""
last_run_at: Optional[str] = None last_run_at: str | None = None
next_run_at: Optional[str] = None next_run_at: str | None = None
run_count: int = 0 run_count: int = 0
success_count: int = 0 success_count: int = 0
fail_count: int = 0 fail_count: int = 0
@@ -144,13 +145,13 @@ class WorkflowLog:
"""工作流执行日志""" """工作流执行日志"""
id: str id: str
workflow_id: str workflow_id: str
task_id: Optional[str] = None task_id: str | None = None
status: str = "pending" # pending, running, success, failed, cancelled status: str = "pending" # pending, running, success, failed, cancelled
start_time: Optional[str] = None start_time: str | None = None
end_time: Optional[str] = None end_time: str | None = None
duration_ms: int = 0 duration_ms: int = 0
input_data: Dict = field(default_factory=dict) input_data: dict = field(default_factory=dict)
output_data: Dict = field(default_factory=dict) output_data: dict = field(default_factory=dict)
error_message: str = "" error_message: str = ""
created_at: str = "" created_at: str = ""
@@ -165,7 +166,7 @@ class WebhookNotifier:
def __init__(self): def __init__(self):
self.http_client = httpx.AsyncClient(timeout=30.0) self.http_client = httpx.AsyncClient(timeout=30.0)
async def send(self, config: WebhookConfig, message: Dict) -> bool: async def send(self, config: WebhookConfig, message: dict) -> bool:
"""发送 Webhook 通知""" """发送 Webhook 通知"""
try: try:
webhook_type = WebhookType(config.webhook_type) webhook_type = WebhookType(config.webhook_type)
@@ -179,14 +180,14 @@ class WebhookNotifier:
else: else:
return await self._send_custom(config, message) return await self._send_custom(config, message)
except (httpx.HTTPError, asyncio.TimeoutError) as e: except (TimeoutError, httpx.HTTPError) as e:
logger.error(f"Webhook send failed: {e}") logger.error(f"Webhook send failed: {e}")
return False return False
async def _send_feishu(self, config: WebhookConfig, message: Dict) -> bool: async def _send_feishu(self, config: WebhookConfig, message: dict) -> bool:
"""发送飞书通知""" """发送飞书通知"""
import hashlib
import base64 import base64
import hashlib
import hmac import hmac
timestamp = str(int(datetime.now().timestamp())) timestamp = str(int(datetime.now().timestamp()))
@@ -252,10 +253,10 @@ class WebhookNotifier:
return result.get("code") == 0 return result.get("code") == 0
async def _send_dingtalk(self, config: WebhookConfig, message: Dict) -> bool: async def _send_dingtalk(self, config: WebhookConfig, message: dict) -> bool:
"""发送钉钉通知""" """发送钉钉通知"""
import hashlib
import base64 import base64
import hashlib
import hmac import hmac
import urllib.parse import urllib.parse
@@ -314,7 +315,7 @@ class WebhookNotifier:
return result.get("errcode") == 0 return result.get("errcode") == 0
async def _send_slack(self, config: WebhookConfig, message: Dict) -> bool: async def _send_slack(self, config: WebhookConfig, message: dict) -> bool:
"""发送 Slack 通知""" """发送 Slack 通知"""
# Slack 直接支持标准 webhook 格式 # Slack 直接支持标准 webhook 格式
payload = { payload = {
@@ -341,7 +342,7 @@ class WebhookNotifier:
return response.text == "ok" return response.text == "ok"
async def _send_custom(self, config: WebhookConfig, message: Dict) -> bool: async def _send_custom(self, config: WebhookConfig, message: dict) -> bool:
"""发送自定义 Webhook 通知""" """发送自定义 Webhook 通知"""
headers = { headers = {
"Content-Type": "application/json", "Content-Type": "application/json",
@@ -374,8 +375,8 @@ class WorkflowManager:
self.db = db_manager self.db = db_manager
self.scheduler = AsyncIOScheduler() self.scheduler = AsyncIOScheduler()
self.notifier = WebhookNotifier() self.notifier = WebhookNotifier()
self._task_handlers: Dict[str, Callable] = {} self._task_handlers: dict[str, Callable] = {}
self._running_tasks: Dict[str, asyncio.Task] = {} self._running_tasks: dict[str, asyncio.Task] = {}
self._setup_default_handlers() self._setup_default_handlers()
# 添加调度器事件监听 # 添加调度器事件监听
@@ -421,7 +422,7 @@ class WorkflowManager:
for workflow in workflows: for workflow in workflows:
if workflow.schedule and workflow.is_active: if workflow.schedule and workflow.is_active:
self._schedule_workflow(workflow) self._schedule_workflow(workflow)
except (httpx.HTTPError, asyncio.TimeoutError) as e: except (TimeoutError, httpx.HTTPError) as e:
logger.error(f"Failed to load workflows: {e}") logger.error(f"Failed to load workflows: {e}")
def _schedule_workflow(self, workflow: Workflow): def _schedule_workflow(self, workflow: Workflow):
@@ -458,7 +459,7 @@ class WorkflowManager:
"""调度器调用的工作流执行函数""" """调度器调用的工作流执行函数"""
try: try:
await self.execute_workflow(workflow_id) await self.execute_workflow(workflow_id)
except (httpx.HTTPError, asyncio.TimeoutError) as e: except (TimeoutError, httpx.HTTPError) as e:
logger.error(f"Scheduled workflow execution failed: {e}") logger.error(f"Scheduled workflow execution failed: {e}")
def _on_job_executed(self, event): def _on_job_executed(self, event):
@@ -497,7 +498,7 @@ class WorkflowManager:
finally: finally:
conn.close() conn.close()
def get_workflow(self, workflow_id: str) -> Optional[Workflow]: def get_workflow(self, workflow_id: str) -> Workflow | None:
"""获取工作流""" """获取工作流"""
conn = self.db.get_conn() conn = self.db.get_conn()
try: try:
@@ -514,7 +515,7 @@ class WorkflowManager:
conn.close() conn.close()
def list_workflows(self, project_id: str = None, status: str = None, def list_workflows(self, project_id: str = None, status: str = None,
workflow_type: str = None) -> List[Workflow]: workflow_type: str = None) -> list[Workflow]:
"""列出工作流""" """列出工作流"""
conn = self.db.get_conn() conn = self.db.get_conn()
try: try:
@@ -542,7 +543,7 @@ class WorkflowManager:
finally: finally:
conn.close() conn.close()
def update_workflow(self, workflow_id: str, **kwargs) -> Optional[Workflow]: def update_workflow(self, workflow_id: str, **kwargs) -> Workflow | None:
"""更新工作流""" """更新工作流"""
conn = self.db.get_conn() conn = self.db.get_conn()
try: try:
@@ -648,7 +649,7 @@ class WorkflowManager:
finally: finally:
conn.close() conn.close()
def get_task(self, task_id: str) -> Optional[WorkflowTask]: def get_task(self, task_id: str) -> WorkflowTask | None:
"""获取任务""" """获取任务"""
conn = self.db.get_conn() conn = self.db.get_conn()
try: try:
@@ -664,7 +665,7 @@ class WorkflowManager:
finally: finally:
conn.close() conn.close()
def list_tasks(self, workflow_id: str) -> List[WorkflowTask]: def list_tasks(self, workflow_id: str) -> list[WorkflowTask]:
"""列出工作流的所有任务""" """列出工作流的所有任务"""
conn = self.db.get_conn() conn = self.db.get_conn()
try: try:
@@ -677,7 +678,7 @@ class WorkflowManager:
finally: finally:
conn.close() conn.close()
def update_task(self, task_id: str, **kwargs) -> Optional[WorkflowTask]: def update_task(self, task_id: str, **kwargs) -> WorkflowTask | None:
"""更新任务""" """更新任务"""
conn = self.db.get_conn() conn = self.db.get_conn()
try: try:
@@ -758,7 +759,7 @@ class WorkflowManager:
finally: finally:
conn.close() conn.close()
def get_webhook(self, webhook_id: str) -> Optional[WebhookConfig]: def get_webhook(self, webhook_id: str) -> WebhookConfig | None:
"""获取 Webhook 配置""" """获取 Webhook 配置"""
conn = self.db.get_conn() conn = self.db.get_conn()
try: try:
@@ -774,7 +775,7 @@ class WorkflowManager:
finally: finally:
conn.close() conn.close()
def list_webhooks(self) -> List[WebhookConfig]: def list_webhooks(self) -> list[WebhookConfig]:
"""列出所有 Webhook 配置""" """列出所有 Webhook 配置"""
conn = self.db.get_conn() conn = self.db.get_conn()
try: try:
@@ -786,7 +787,7 @@ class WorkflowManager:
finally: finally:
conn.close() conn.close()
def update_webhook(self, webhook_id: str, **kwargs) -> Optional[WebhookConfig]: def update_webhook(self, webhook_id: str, **kwargs) -> WebhookConfig | None:
"""更新 Webhook 配置""" """更新 Webhook 配置"""
conn = self.db.get_conn() conn = self.db.get_conn()
try: try:
@@ -889,7 +890,7 @@ class WorkflowManager:
finally: finally:
conn.close() conn.close()
def update_log(self, log_id: str, **kwargs) -> Optional[WorkflowLog]: def update_log(self, log_id: str, **kwargs) -> WorkflowLog | None:
"""更新工作流日志""" """更新工作流日志"""
conn = self.db.get_conn() conn = self.db.get_conn()
try: try:
@@ -918,7 +919,7 @@ class WorkflowManager:
finally: finally:
conn.close() conn.close()
def get_log(self, log_id: str) -> Optional[WorkflowLog]: def get_log(self, log_id: str) -> WorkflowLog | None:
"""获取日志""" """获取日志"""
conn = self.db.get_conn() conn = self.db.get_conn()
try: try:
@@ -935,7 +936,7 @@ class WorkflowManager:
conn.close() conn.close()
def list_logs(self, workflow_id: str = None, task_id: str = None, def list_logs(self, workflow_id: str = None, task_id: str = None,
status: str = None, limit: int = 100, offset: int = 0) -> List[WorkflowLog]: status: str = None, limit: int = 100, offset: int = 0) -> list[WorkflowLog]:
"""列出工作流日志""" """列出工作流日志"""
conn = self.db.get_conn() conn = self.db.get_conn()
try: try:
@@ -966,7 +967,7 @@ class WorkflowManager:
finally: finally:
conn.close() conn.close()
def get_workflow_stats(self, workflow_id: str, days: int = 30) -> Dict: def get_workflow_stats(self, workflow_id: str, days: int = 30) -> dict:
"""获取工作流统计""" """获取工作流统计"""
conn = self.db.get_conn() conn = self.db.get_conn()
try: try:
@@ -1037,7 +1038,7 @@ class WorkflowManager:
# ==================== Workflow Execution ==================== # ==================== Workflow Execution ====================
async def execute_workflow(self, workflow_id: str, input_data: Dict = None) -> Dict: async def execute_workflow(self, workflow_id: str, input_data: dict = None) -> dict:
"""执行工作流""" """执行工作流"""
workflow = self.get_workflow(workflow_id) workflow = self.get_workflow(workflow_id)
if not workflow: if not workflow:
@@ -1100,7 +1101,7 @@ class WorkflowManager:
"duration_ms": duration "duration_ms": duration
} }
except (httpx.HTTPError, asyncio.TimeoutError) as e: except (TimeoutError, httpx.HTTPError) as e:
logger.error(f"Workflow {workflow_id} execution failed: {e}") logger.error(f"Workflow {workflow_id} execution failed: {e}")
# 更新日志为失败 # 更新日志为失败
@@ -1122,8 +1123,8 @@ class WorkflowManager:
raise raise
async def _execute_tasks_with_deps(self, tasks: List[WorkflowTask], async def _execute_tasks_with_deps(self, tasks: list[WorkflowTask],
input_data: Dict, log_id: str) -> Dict: input_data: dict, log_id: str) -> dict:
"""按依赖顺序执行任务""" """按依赖顺序执行任务"""
results = {} results = {}
completed_tasks = set() completed_tasks = set()
@@ -1158,7 +1159,7 @@ class WorkflowManager:
try: try:
result = await self._execute_single_task(task, task_input, log_id) result = await self._execute_single_task(task, task_input, log_id)
break break
except (httpx.HTTPError, asyncio.TimeoutError) as e: except (TimeoutError, httpx.HTTPError) as e:
logger.error(f"Task {task.id} retry {attempt + 1} failed: {e}") logger.error(f"Task {task.id} retry {attempt + 1} failed: {e}")
if attempt == task.retry_count - 1: if attempt == task.retry_count - 1:
raise raise
@@ -1171,7 +1172,7 @@ class WorkflowManager:
return results return results
async def _execute_single_task(self, task: WorkflowTask, async def _execute_single_task(self, task: WorkflowTask,
input_data: Dict, log_id: str) -> Any: input_data: dict, log_id: str) -> Any:
"""执行单个任务""" """执行单个任务"""
handler = self._task_handlers.get(task.task_type) handler = self._task_handlers.get(task.task_type)
if not handler: if not handler:
@@ -1205,7 +1206,7 @@ class WorkflowManager:
return result return result
except asyncio.TimeoutError: except TimeoutError:
self.update_log( self.update_log(
task_log.id, task_log.id,
status=TaskStatus.FAILED.value, status=TaskStatus.FAILED.value,
@@ -1224,7 +1225,7 @@ class WorkflowManager:
raise raise
async def _execute_default_workflow(self, workflow: Workflow, async def _execute_default_workflow(self, workflow: Workflow,
input_data: Dict) -> Dict: input_data: dict) -> dict:
"""执行默认工作流(根据类型)""" """执行默认工作流(根据类型)"""
workflow_type = WorkflowType(workflow.workflow_type) workflow_type = WorkflowType(workflow.workflow_type)
@@ -1241,7 +1242,7 @@ class WorkflowManager:
# ==================== Default Task Handlers ==================== # ==================== Default Task Handlers ====================
async def _handle_analyze_task(self, task: WorkflowTask, input_data: Dict) -> Dict: async def _handle_analyze_task(self, task: WorkflowTask, input_data: dict) -> dict:
"""处理分析任务""" """处理分析任务"""
project_id = input_data.get("project_id") project_id = input_data.get("project_id")
file_ids = input_data.get("file_ids", []) file_ids = input_data.get("file_ids", [])
@@ -1258,7 +1259,7 @@ class WorkflowManager:
"status": "completed" "status": "completed"
} }
async def _handle_align_task(self, task: WorkflowTask, input_data: Dict) -> Dict: async def _handle_align_task(self, task: WorkflowTask, input_data: dict) -> dict:
"""处理实体对齐任务""" """处理实体对齐任务"""
project_id = input_data.get("project_id") project_id = input_data.get("project_id")
threshold = task.config.get("threshold", 0.85) threshold = task.config.get("threshold", 0.85)
@@ -1276,7 +1277,7 @@ class WorkflowManager:
} }
async def _handle_discover_relations_task(self, task: WorkflowTask, async def _handle_discover_relations_task(self, task: WorkflowTask,
input_data: Dict) -> Dict: input_data: dict) -> dict:
"""处理关系发现任务""" """处理关系发现任务"""
project_id = input_data.get("project_id") project_id = input_data.get("project_id")
@@ -1291,7 +1292,7 @@ class WorkflowManager:
"status": "completed" "status": "completed"
} }
async def _handle_notify_task(self, task: WorkflowTask, input_data: Dict) -> Dict: async def _handle_notify_task(self, task: WorkflowTask, input_data: dict) -> dict:
"""处理通知任务""" """处理通知任务"""
webhook_id = task.config.get("webhook_id") webhook_id = task.config.get("webhook_id")
message = task.config.get("message", {}) message = task.config.get("message", {})
@@ -1319,7 +1320,7 @@ class WorkflowManager:
"success": success "success": success
} }
async def _handle_custom_task(self, task: WorkflowTask, input_data: Dict) -> Dict: async def _handle_custom_task(self, task: WorkflowTask, input_data: dict) -> dict:
"""处理自定义任务""" """处理自定义任务"""
# 自定义任务的具体逻辑由外部处理器实现 # 自定义任务的具体逻辑由外部处理器实现
return { return {
@@ -1331,7 +1332,7 @@ class WorkflowManager:
# ==================== Default Workflow Implementations ==================== # ==================== Default Workflow Implementations ====================
async def _auto_analyze_files(self, workflow: Workflow, input_data: Dict) -> Dict: async def _auto_analyze_files(self, workflow: Workflow, input_data: dict) -> dict:
"""自动分析新上传的文件""" """自动分析新上传的文件"""
project_id = workflow.project_id project_id = workflow.project_id
@@ -1346,7 +1347,7 @@ class WorkflowManager:
"status": "completed" "status": "completed"
} }
async def _auto_align_entities(self, workflow: Workflow, input_data: Dict) -> Dict: async def _auto_align_entities(self, workflow: Workflow, input_data: dict) -> dict:
"""自动实体对齐""" """自动实体对齐"""
project_id = workflow.project_id project_id = workflow.project_id
threshold = workflow.config.get("threshold", 0.85) threshold = workflow.config.get("threshold", 0.85)
@@ -1359,7 +1360,7 @@ class WorkflowManager:
"status": "completed" "status": "completed"
} }
async def _auto_discover_relations(self, workflow: Workflow, input_data: Dict) -> Dict: async def _auto_discover_relations(self, workflow: Workflow, input_data: dict) -> dict:
"""自动关系发现""" """自动关系发现"""
project_id = workflow.project_id project_id = workflow.project_id
@@ -1370,7 +1371,7 @@ class WorkflowManager:
"status": "completed" "status": "completed"
} }
async def _generate_scheduled_report(self, workflow: Workflow, input_data: Dict) -> Dict: async def _generate_scheduled_report(self, workflow: Workflow, input_data: dict) -> dict:
"""生成定时报告""" """生成定时报告"""
project_id = workflow.project_id project_id = workflow.project_id
report_type = workflow.config.get("report_type", "summary") report_type = workflow.config.get("report_type", "summary")
@@ -1385,7 +1386,7 @@ class WorkflowManager:
# ==================== Notification ==================== # ==================== Notification ====================
async def _send_workflow_notification(self, workflow: Workflow, async def _send_workflow_notification(self, workflow: Workflow,
results: Dict, success: bool = True): results: dict, success: bool = True):
"""发送工作流执行通知""" """发送工作流执行通知"""
if not workflow.webhook_ids: if not workflow.webhook_ids:
return return
@@ -1414,11 +1415,11 @@ class WorkflowManager:
try: try:
result = await self.notifier.send(webhook, message) result = await self.notifier.send(webhook, message)
self.update_webhook_stats(webhook_id, result) self.update_webhook_stats(webhook_id, result)
except (httpx.HTTPError, asyncio.TimeoutError) as e: except (TimeoutError, httpx.HTTPError) as e:
logger.error(f"Failed to send notification to {webhook_id}: {e}") logger.error(f"Failed to send notification to {webhook_id}: {e}")
def _build_feishu_message(self, workflow: Workflow, results: Dict, def _build_feishu_message(self, workflow: Workflow, results: dict,
success: bool) -> Dict: success: bool) -> dict:
"""构建飞书消息""" """构建飞书消息"""
status_text = "✅ 成功" if success else "❌ 失败" status_text = "✅ 成功" if success else "❌ 失败"
@@ -1431,8 +1432,8 @@ class WorkflowManager:
] ]
} }
def _build_dingtalk_message(self, workflow: Workflow, results: Dict, def _build_dingtalk_message(self, workflow: Workflow, results: dict,
success: bool) -> Dict: success: bool) -> dict:
"""构建钉钉消息""" """构建钉钉消息"""
status_text = "✅ 成功" if success else "❌ 失败" status_text = "✅ 成功" if success else "❌ 失败"
@@ -1453,8 +1454,8 @@ class WorkflowManager:
""" """
} }
def _build_slack_message(self, workflow: Workflow, results: Dict, def _build_slack_message(self, workflow: Workflow, results: dict,
success: bool) -> Dict: success: bool) -> dict:
"""构建 Slack 消息""" """构建 Slack 消息"""
color = "#36a64f" if success else "#ff0000" color = "#36a64f" if success else "#ff0000"
status_text = "Success" if success else "Failed" status_text = "Success" if success else "Failed"