fix: auto-fix code issues (cron)

- 修复重复导入/字段
- 修复异常处理
- 修复PEP8格式问题
- 添加类型注解
This commit is contained in:
OpenClaw Bot
2026-02-27 18:09:24 +08:00
parent 646b64daf7
commit 17bda3dbce
38 changed files with 1993 additions and 1972 deletions

View File

@@ -8,25 +8,25 @@ AI 能力增强模块
- 预测性分析(趋势预测、异常检测)
"""
import os
import json
import sqlite3
import httpx
import asyncio
import json
import os
import random
import sqlite3
import statistics
from typing import List, Dict, Optional
import uuid
from collections import defaultdict
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from collections import defaultdict
import uuid
from enum import StrEnum
import httpx
# Database path
DB_PATH = os.path.join(os.path.dirname(__file__), "insightflow.db")
class ModelType(str, Enum):
class ModelType(StrEnum):
"""模型类型"""
CUSTOM_NER = "custom_ner" # 自定义实体识别
@@ -35,7 +35,7 @@ class ModelType(str, Enum):
PREDICTION = "prediction" # 预测
class ModelStatus(str, Enum):
class ModelStatus(StrEnum):
"""模型状态"""
PENDING = "pending"
@@ -45,7 +45,7 @@ class ModelStatus(str, Enum):
ARCHIVED = "archived"
class MultimodalProvider(str, Enum):
class MultimodalProvider(StrEnum):
"""多模态模型提供商"""
GPT4V = "gpt-4-vision"
@@ -54,7 +54,7 @@ class MultimodalProvider(str, Enum):
KIMI_VL = "kimi-vl"
class PredictionType(str, Enum):
class PredictionType(StrEnum):
"""预测类型"""
TREND = "trend" # 趋势预测
@@ -73,13 +73,13 @@ class CustomModel:
description: str
model_type: ModelType
status: ModelStatus
training_data: Dict # 训练数据配置
hyperparameters: Dict # 超参数
metrics: Dict # 训练指标
model_path: Optional[str] # 模型文件路径
training_data: dict # 训练数据配置
hyperparameters: dict # 超参数
metrics: dict # 训练指标
model_path: str | None # 模型文件路径
created_at: str
updated_at: str
trained_at: Optional[str]
trained_at: str | None
created_by: str
@@ -90,8 +90,8 @@ class TrainingSample:
id: str
model_id: str
text: str
entities: List[Dict] # [{"start": 0, "end": 5, "label": "PERSON", "text": "张三"}]
metadata: Dict
entities: list[dict] # [{"start": 0, "end": 5, "label": "PERSON", "text": "张三"}]
metadata: dict
created_at: str
@@ -104,9 +104,9 @@ class MultimodalAnalysis:
project_id: str
provider: MultimodalProvider
input_type: str # image, video, audio, mixed
input_urls: List[str]
input_urls: list[str]
prompt: str
result: Dict # 分析结果
result: dict # 分析结果
tokens_used: int
cost: float
created_at: str
@@ -121,9 +121,9 @@ class KnowledgeGraphRAG:
project_id: str
name: str
description: str
kg_config: Dict # 知识图谱配置
retrieval_config: Dict # 检索配置
generation_config: Dict # 生成配置
kg_config: dict # 知识图谱配置
retrieval_config: dict # 检索配置
generation_config: dict # 生成配置
is_active: bool
created_at: str
updated_at: str
@@ -136,9 +136,9 @@ class RAGQuery:
id: str
rag_id: str
query: str
context: Dict # 检索到的上下文
context: dict # 检索到的上下文
answer: str
sources: List[Dict] # 来源信息
sources: list[dict] # 来源信息
confidence: float
tokens_used: int
latency_ms: int
@@ -154,11 +154,11 @@ class PredictionModel:
project_id: str
name: str
prediction_type: PredictionType
target_entity_type: Optional[str] # 目标实体类型
features: List[str] # 特征列表
model_config: Dict # 模型配置
accuracy: Optional[float]
last_trained_at: Optional[str]
target_entity_type: str | None # 目标实体类型
features: list[str] # 特征列表
model_config: dict # 模型配置
accuracy: float | None
last_trained_at: str | None
prediction_count: int
is_active: bool
created_at: str
@@ -172,12 +172,12 @@ class PredictionResult:
id: str
model_id: str
prediction_type: PredictionType
target_id: Optional[str] # 预测目标ID
prediction_data: Dict # 预测数据
target_id: str | None # 预测目标ID
prediction_data: dict # 预测数据
confidence: float
explanation: str # 预测解释
actual_value: Optional[str] # 实际值(用于验证)
is_correct: Optional[bool]
actual_value: str | None # 实际值(用于验证)
is_correct: bool | None
created_at: str
@@ -192,8 +192,8 @@ class SmartSummary:
source_id: str
summary_type: str # extractive, abstractive, key_points, timeline
content: str
key_points: List[str]
entities_mentioned: List[str]
key_points: list[str]
entities_mentioned: list[str]
confidence: float
tokens_used: int
created_at: str
@@ -223,8 +223,8 @@ class AIManager:
name: str,
description: str,
model_type: ModelType,
training_data: Dict,
hyperparameters: Dict,
training_data: dict,
hyperparameters: dict,
created_by: str,
) -> CustomModel:
"""创建自定义模型"""
@@ -277,7 +277,7 @@ class AIManager:
return model
def get_custom_model(self, model_id: str) -> Optional[CustomModel]:
def get_custom_model(self, model_id: str) -> CustomModel | None:
"""获取自定义模型"""
with self._get_db() as conn:
row = conn.execute("SELECT * FROM custom_models WHERE id = ?", (model_id,)).fetchone()
@@ -288,8 +288,8 @@ class AIManager:
return self._row_to_custom_model(row)
def list_custom_models(
self, tenant_id: str, model_type: Optional[ModelType] = None, status: Optional[ModelStatus] = None
) -> List[CustomModel]:
self, tenant_id: str, model_type: ModelType | None = None, status: ModelStatus | None = None
) -> list[CustomModel]:
"""列出自定义模型"""
query = "SELECT * FROM custom_models WHERE tenant_id = ?"
params = [tenant_id]
@@ -308,7 +308,7 @@ class AIManager:
return [self._row_to_custom_model(row) for row in rows]
def add_training_sample(
self, model_id: str, text: str, entities: List[Dict], metadata: Dict = None
self, model_id: str, text: str, entities: list[dict], metadata: dict = None
) -> TrainingSample:
"""添加训练样本"""
sample_id = f"ts_{uuid.uuid4().hex[:16]}"
@@ -338,7 +338,7 @@ class AIManager:
return sample
def get_training_samples(self, model_id: str) -> List[TrainingSample]:
def get_training_samples(self, model_id: str) -> list[TrainingSample]:
"""获取训练样本"""
with self._get_db() as conn:
rows = conn.execute(
@@ -410,7 +410,7 @@ class AIManager:
conn.commit()
raise e
async def predict_with_custom_model(self, model_id: str, text: str) -> List[Dict]:
async def predict_with_custom_model(self, model_id: str, text: str) -> list[dict]:
"""使用自定义模型进行预测"""
model = self.get_custom_model(model_id)
if not model or model.status != ModelStatus.READY:
@@ -461,7 +461,7 @@ class AIManager:
project_id: str,
provider: MultimodalProvider,
input_type: str,
input_urls: List[str],
input_urls: list[str],
prompt: str,
) -> MultimodalAnalysis:
"""多模态分析"""
@@ -517,7 +517,7 @@ class AIManager:
return analysis
async def _call_gpt4v(self, image_urls: List[str], prompt: str) -> Dict:
async def _call_gpt4v(self, image_urls: list[str], prompt: str) -> dict:
"""调用 GPT-4V"""
headers = {"Authorization": f"Bearer {self.openai_api_key}", "Content-Type": "application/json"}
@@ -544,7 +544,7 @@ class AIManager:
"cost": result["usage"]["total_tokens"] * 0.00001, # 估算成本
}
async def _call_claude3(self, image_urls: List[str], prompt: str) -> Dict:
async def _call_claude3(self, image_urls: list[str], prompt: str) -> dict:
"""调用 Claude 3"""
headers = {
"x-api-key": self.anthropic_api_key,
@@ -576,7 +576,7 @@ class AIManager:
"cost": (result["usage"]["input_tokens"] + result["usage"]["output_tokens"]) * 0.000015,
}
async def _call_kimi_multimodal(self, image_urls: List[str], prompt: str) -> Dict:
async def _call_kimi_multimodal(self, image_urls: list[str], prompt: str) -> dict:
"""调用 Kimi 多模态模型"""
headers = {"Authorization": f"Bearer {self.kimi_api_key}", "Content-Type": "application/json"}
@@ -600,7 +600,7 @@ class AIManager:
"cost": result["usage"]["total_tokens"] * 0.000005,
}
def get_multimodal_analyses(self, tenant_id: str, project_id: Optional[str] = None) -> List[MultimodalAnalysis]:
def get_multimodal_analyses(self, tenant_id: str, project_id: str | None = None) -> list[MultimodalAnalysis]:
"""获取多模态分析历史"""
query = "SELECT * FROM multimodal_analyses WHERE tenant_id = ?"
params = [tenant_id]
@@ -623,9 +623,9 @@ class AIManager:
project_id: str,
name: str,
description: str,
kg_config: Dict,
retrieval_config: Dict,
generation_config: Dict,
kg_config: dict,
retrieval_config: dict,
generation_config: dict,
) -> KnowledgeGraphRAG:
"""创建知识图谱 RAG 配置"""
rag_id = f"kgr_{uuid.uuid4().hex[:16]}"
@@ -671,7 +671,7 @@ class AIManager:
return rag
def get_kg_rag(self, rag_id: str) -> Optional[KnowledgeGraphRAG]:
def get_kg_rag(self, rag_id: str) -> KnowledgeGraphRAG | None:
"""获取知识图谱 RAG 配置"""
with self._get_db() as conn:
row = conn.execute("SELECT * FROM kg_rag_configs WHERE id = ?", (rag_id,)).fetchone()
@@ -681,7 +681,7 @@ class AIManager:
return self._row_to_kg_rag(row)
def list_kg_rags(self, tenant_id: str, project_id: Optional[str] = None) -> List[KnowledgeGraphRAG]:
def list_kg_rags(self, tenant_id: str, project_id: str | None = None) -> list[KnowledgeGraphRAG]:
"""列出知识图谱 RAG 配置"""
query = "SELECT * FROM kg_rag_configs WHERE tenant_id = ?"
params = [tenant_id]
@@ -697,7 +697,7 @@ class AIManager:
return [self._row_to_kg_rag(row) for row in rows]
async def query_kg_rag(
self, rag_id: str, query: str, project_entities: List[Dict], project_relations: List[Dict]
self, rag_id: str, query: str, project_entities: list[dict], project_relations: list[dict]
) -> RAGQuery:
"""基于知识图谱的 RAG 查询"""
import time
@@ -832,7 +832,7 @@ class AIManager:
return rag_query
def _build_kg_context(self, entities: List[Dict], relations: List[Dict]) -> str:
def _build_kg_context(self, entities: list[dict], relations: list[dict]) -> str:
"""构建知识图谱上下文文本"""
context = []
@@ -858,7 +858,7 @@ class AIManager:
return "\n".join(context)
async def generate_smart_summary(
self, tenant_id: str, project_id: str, source_type: str, source_id: str, summary_type: str, content_data: Dict
self, tenant_id: str, project_id: str, source_type: str, source_id: str, summary_type: str, content_data: dict
) -> SmartSummary:
"""生成智能摘要"""
summary_id = f"ss_{uuid.uuid4().hex[:16]}"
@@ -999,9 +999,9 @@ class AIManager:
project_id: str,
name: str,
prediction_type: PredictionType,
target_entity_type: Optional[str],
features: List[str],
model_config: Dict,
target_entity_type: str | None,
features: list[str],
model_config: dict,
) -> PredictionModel:
"""创建预测模型"""
model_id = f"pm_{uuid.uuid4().hex[:16]}"
@@ -1053,7 +1053,7 @@ class AIManager:
return model
def get_prediction_model(self, model_id: str) -> Optional[PredictionModel]:
def get_prediction_model(self, model_id: str) -> PredictionModel | None:
"""获取预测模型"""
with self._get_db() as conn:
row = conn.execute("SELECT * FROM prediction_models WHERE id = ?", (model_id,)).fetchone()
@@ -1063,7 +1063,7 @@ class AIManager:
return self._row_to_prediction_model(row)
def list_prediction_models(self, tenant_id: str, project_id: Optional[str] = None) -> List[PredictionModel]:
def list_prediction_models(self, tenant_id: str, project_id: str | None = None) -> list[PredictionModel]:
"""列出预测模型"""
query = "SELECT * FROM prediction_models WHERE tenant_id = ?"
params = [tenant_id]
@@ -1078,7 +1078,7 @@ class AIManager:
rows = conn.execute(query, params).fetchall()
return [self._row_to_prediction_model(row) for row in rows]
async def train_prediction_model(self, model_id: str, historical_data: List[Dict]) -> PredictionModel:
async def train_prediction_model(self, model_id: str, historical_data: list[dict]) -> PredictionModel:
"""训练预测模型"""
model = self.get_prediction_model(model_id)
if not model:
@@ -1105,7 +1105,7 @@ class AIManager:
return self.get_prediction_model(model_id)
async def predict(self, model_id: str, input_data: Dict) -> PredictionResult:
async def predict(self, model_id: str, input_data: dict) -> PredictionResult:
"""进行预测"""
model = self.get_prediction_model(model_id)
if not model or not model.is_active:
@@ -1172,7 +1172,7 @@ class AIManager:
return result
def _predict_trend(self, input_data: Dict, model: PredictionModel) -> Dict:
def _predict_trend(self, input_data: dict, model: PredictionModel) -> dict:
"""趋势预测"""
historical_values = input_data.get("historical_values", [])
@@ -1211,7 +1211,7 @@ class AIManager:
"explanation": f"基于{len(historical_values)}个历史数据点,预测趋势为{trend}",
}
def _detect_anomaly(self, input_data: Dict, model: PredictionModel) -> Dict:
def _detect_anomaly(self, input_data: dict, model: PredictionModel) -> dict:
"""异常检测"""
value = input_data.get("value")
historical_values = input_data.get("historical_values", [])
@@ -1245,7 +1245,7 @@ class AIManager:
"explanation": f"当前值偏离均值{z_score:.2f}个标准差,{'检测到异常' if is_anomaly else '处于正常范围'}",
}
def _predict_entity_growth(self, input_data: Dict, model: PredictionModel) -> Dict:
def _predict_entity_growth(self, input_data: dict, model: PredictionModel) -> dict:
"""实体增长预测"""
entity_history = input_data.get("entity_history", [])
@@ -1273,7 +1273,7 @@ class AIManager:
"explanation": f"基于过去{len(entity_history)}个周期的数据,预测增长率{avg_growth_rate * 100:.1f}%",
}
def _predict_relation_evolution(self, input_data: Dict, model: PredictionModel) -> Dict:
def _predict_relation_evolution(self, input_data: dict, model: PredictionModel) -> dict:
"""关系演变预测"""
relation_history = input_data.get("relation_history", [])
@@ -1299,7 +1299,7 @@ class AIManager:
"explanation": f"基于{len(relation_history)}个历史快照分析关系演变趋势",
}
def get_prediction_results(self, model_id: str, limit: int = 100) -> List[PredictionResult]:
def get_prediction_results(self, model_id: str, limit: int = 100) -> list[PredictionResult]:
"""获取预测结果历史"""
with self._get_db() as conn:
rows = conn.execute(

View File

@@ -4,14 +4,13 @@ InsightFlow API Key Manager - Phase 6
API Key 管理模块:生成、验证、撤销
"""
import os
import json
import hashlib
import json
import os
import secrets
import sqlite3
from datetime import datetime, timedelta
from typing import Optional, List, Dict
from dataclasses import dataclass
from datetime import datetime, timedelta
from enum import Enum
DB_PATH = os.getenv("DB_PATH", "/app/data/insightflow.db")
@@ -29,15 +28,15 @@ class ApiKey:
key_hash: str # 存储哈希值,不存储原始 key
key_preview: str # 前8位预览如 "ak_live_abc..."
name: str # 密钥名称/描述
owner_id: Optional[str] # 所有者ID预留多用户支持
permissions: List[str] # 权限列表,如 ["read", "write"]
owner_id: str | None # 所有者ID预留多用户支持
permissions: list[str] # 权限列表,如 ["read", "write"]
rate_limit: int # 每分钟请求限制
status: str # active, revoked, expired
created_at: str
expires_at: Optional[str]
last_used_at: Optional[str]
revoked_at: Optional[str]
revoked_reason: Optional[str]
expires_at: str | None
last_used_at: str | None
revoked_at: str | None
revoked_reason: str | None
total_calls: int = 0
@@ -131,10 +130,10 @@ class ApiKeyManager:
def create_key(
self,
name: str,
owner_id: Optional[str] = None,
permissions: List[str] = None,
owner_id: str | None = None,
permissions: list[str] = None,
rate_limit: int = 60,
expires_days: Optional[int] = None,
expires_days: int | None = None,
) -> tuple[str, ApiKey]:
"""
创建新的 API Key
@@ -196,7 +195,7 @@ class ApiKeyManager:
return raw_key, api_key
def validate_key(self, key: str) -> Optional[ApiKey]:
def validate_key(self, key: str) -> ApiKey | None:
"""
验证 API Key
@@ -231,7 +230,7 @@ class ApiKeyManager:
return api_key
def revoke_key(self, key_id: str, reason: str = "", owner_id: Optional[str] = None) -> bool:
def revoke_key(self, key_id: str, reason: str = "", owner_id: str | None = None) -> bool:
"""撤销 API Key"""
with sqlite3.connect(self.db_path) as conn:
# 验证所有权(如果提供了 owner_id)
@@ -251,7 +250,7 @@ class ApiKeyManager:
conn.commit()
return cursor.rowcount > 0
def get_key_by_id(self, key_id: str, owner_id: Optional[str] = None) -> Optional[ApiKey]:
def get_key_by_id(self, key_id: str, owner_id: str | None = None) -> ApiKey | None:
"""通过 ID 获取 API Key(不包含敏感信息)"""
with sqlite3.connect(self.db_path) as conn:
conn.row_factory = sqlite3.Row
@@ -268,8 +267,8 @@ class ApiKeyManager:
return None
def list_keys(
self, owner_id: Optional[str] = None, status: Optional[str] = None, limit: int = 100, offset: int = 0
) -> List[ApiKey]:
self, owner_id: str | None = None, status: str | None = None, limit: int = 100, offset: int = 0
) -> list[ApiKey]:
"""列出 API Keys"""
with sqlite3.connect(self.db_path) as conn:
conn.row_factory = sqlite3.Row
@@ -294,10 +293,10 @@ class ApiKeyManager:
def update_key(
self,
key_id: str,
name: Optional[str] = None,
permissions: Optional[List[str]] = None,
rate_limit: Optional[int] = None,
owner_id: Optional[str] = None,
name: str | None = None,
permissions: list[str] | None = None,
rate_limit: int | None = None,
owner_id: str | None = None,
) -> bool:
"""更新 API Key 信息"""
updates = []
@@ -371,12 +370,12 @@ class ApiKeyManager:
def get_call_logs(
self,
api_key_id: Optional[str] = None,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
api_key_id: str | None = None,
start_date: str | None = None,
end_date: str | None = None,
limit: int = 100,
offset: int = 0,
) -> List[Dict]:
) -> list[dict]:
"""获取 API 调用日志"""
with sqlite3.connect(self.db_path) as conn:
conn.row_factory = sqlite3.Row
@@ -402,13 +401,13 @@ class ApiKeyManager:
rows = conn.execute(query, params).fetchall()
return [dict(row) for row in rows]
def get_call_stats(self, api_key_id: Optional[str] = None, days: int = 30) -> Dict:
def get_call_stats(self, api_key_id: str | None = None, days: int = 30) -> dict:
"""获取 API 调用统计"""
with sqlite3.connect(self.db_path) as conn:
conn.row_factory = sqlite3.Row
# 总体统计
query = """
query = f"""
SELECT
COUNT(*) as total_calls,
COUNT(CASE WHEN status_code < 400 THEN 1 END) as success_calls,
@@ -417,8 +416,8 @@ class ApiKeyManager:
MAX(response_time_ms) as max_response_time,
MIN(response_time_ms) as min_response_time
FROM api_call_logs
WHERE created_at >= date('now', '-{} days')
""".format(days)
WHERE created_at >= date('now', '-{days} days')
"""
params = []
if api_key_id:
@@ -428,15 +427,15 @@ class ApiKeyManager:
row = conn.execute(query, params).fetchone()
# 按端点统计
endpoint_query = """
endpoint_query = f"""
SELECT
endpoint,
method,
COUNT(*) as calls,
AVG(response_time_ms) as avg_time
FROM api_call_logs
WHERE created_at >= date('now', '-{} days')
""".format(days)
WHERE created_at >= date('now', '-{days} days')
"""
endpoint_params = []
if api_key_id:
@@ -448,14 +447,14 @@ class ApiKeyManager:
endpoint_rows = conn.execute(endpoint_query, endpoint_params).fetchall()
# 按天统计
daily_query = """
daily_query = f"""
SELECT
date(created_at) as date,
COUNT(*) as calls,
COUNT(CASE WHEN status_code < 400 THEN 1 END) as success
FROM api_call_logs
WHERE created_at >= date('now', '-{} days')
""".format(days)
WHERE created_at >= date('now', '-{days} days')
"""
daily_params = []
if api_key_id:
@@ -500,7 +499,7 @@ class ApiKeyManager:
# 全局实例
_api_key_manager: Optional[ApiKeyManager] = None
_api_key_manager: ApiKeyManager | None = None
def get_api_key_manager() -> ApiKeyManager:

View File

@@ -3,13 +3,13 @@ InsightFlow - 协作与共享模块 (Phase 7 Task 4)
支持项目分享、评论批注、变更历史、团队空间
"""
import hashlib
import json
import uuid
import hashlib
from datetime import datetime, timedelta
from typing import List, Optional, Dict, Any
from dataclasses import dataclass
from datetime import datetime, timedelta
from enum import Enum
from typing import Any
class SharePermission(Enum):
@@ -50,10 +50,10 @@ class ProjectShare:
permission: str # 权限级别
created_by: str # 创建者
created_at: str
expires_at: Optional[str] # 过期时间
max_uses: Optional[int] # 最大使用次数
expires_at: str | None # 过期时间
max_uses: int | None # 最大使用次数
use_count: int # 已使用次数
password_hash: Optional[str] # 密码保护
password_hash: str | None # 密码保护
is_active: bool # 是否激活
allow_download: bool # 允许下载
allow_export: bool # 允许导出
@@ -67,17 +67,17 @@ class Comment:
project_id: str
target_type: str # 评论目标类型
target_id: str # 目标ID
parent_id: Optional[str] # 父评论ID支持回复
parent_id: str | None # 父评论ID支持回复
author: str # 作者
author_name: str # 作者显示名
content: str # 评论内容
created_at: str
updated_at: str
resolved: bool # 是否已解决
resolved_by: Optional[str] # 解决者
resolved_at: Optional[str] # 解决时间
mentions: List[str] # 提及的用户
attachments: List[Dict] # 附件
resolved_by: str | None # 解决者
resolved_at: str | None # 解决时间
mentions: list[str] # 提及的用户
attachments: list[dict] # 附件
@dataclass
@@ -93,13 +93,13 @@ class ChangeRecord:
changed_by: str # 变更者
changed_by_name: str # 变更者显示名
changed_at: str
old_value: Optional[Dict] # 旧值
new_value: Optional[Dict] # 新值
old_value: dict | None # 旧值
new_value: dict | None # 新值
description: str # 变更描述
session_id: Optional[str] # 会话ID批量变更关联
session_id: str | None # 会话ID批量变更关联
reverted: bool # 是否已回滚
reverted_at: Optional[str] # 回滚时间
reverted_by: Optional[str] # 回滚者
reverted_at: str | None # 回滚时间
reverted_by: str | None # 回滚者
@dataclass
@@ -114,8 +114,8 @@ class TeamMember:
role: str # 角色 (owner/admin/editor/viewer)
joined_at: str
invited_by: str # 邀请者
last_active_at: Optional[str] # 最后活跃时间
permissions: List[str] # 具体权限列表
last_active_at: str | None # 最后活跃时间
permissions: list[str] # 具体权限列表
@dataclass
@@ -130,7 +130,7 @@ class TeamSpace:
updated_at: str
member_count: int
project_count: int
settings: Dict[str, Any] # 团队设置
settings: dict[str, Any] # 团队设置
class CollaborationManager:
@@ -138,8 +138,8 @@ class CollaborationManager:
def __init__(self, db_manager=None):
self.db = db_manager
self._shares_cache: Dict[str, ProjectShare] = {}
self._comments_cache: Dict[str, List[Comment]] = {}
self._shares_cache: dict[str, ProjectShare] = {}
self._comments_cache: dict[str, list[Comment]] = {}
# ============ 项目分享 ============
@@ -148,9 +148,9 @@ class CollaborationManager:
project_id: str,
created_by: str,
permission: str = "read_only",
expires_in_days: Optional[int] = None,
max_uses: Optional[int] = None,
password: Optional[str] = None,
expires_in_days: int | None = None,
max_uses: int | None = None,
password: str | None = None,
allow_download: bool = False,
allow_export: bool = False,
) -> ProjectShare:
@@ -224,7 +224,7 @@ class CollaborationManager:
)
self.db.conn.commit()
def validate_share_token(self, token: str, password: Optional[str] = None) -> Optional[ProjectShare]:
def validate_share_token(self, token: str, password: str | None = None) -> ProjectShare | None:
"""验证分享令牌"""
# 从缓存或数据库获取
share = self._shares_cache.get(token)
@@ -256,7 +256,7 @@ class CollaborationManager:
return share
def _get_share_from_db(self, token: str) -> Optional[ProjectShare]:
def _get_share_from_db(self, token: str) -> ProjectShare | None:
"""从数据库获取分享记录"""
cursor = self.db.conn.cursor()
cursor.execute(
@@ -320,7 +320,7 @@ class CollaborationManager:
return cursor.rowcount > 0
return False
def list_project_shares(self, project_id: str) -> List[ProjectShare]:
def list_project_shares(self, project_id: str) -> list[ProjectShare]:
"""列出项目的所有分享链接"""
if not self.db:
return []
@@ -366,9 +366,9 @@ class CollaborationManager:
author: str,
author_name: str,
content: str,
parent_id: Optional[str] = None,
mentions: Optional[List[str]] = None,
attachments: Optional[List[Dict]] = None,
parent_id: str | None = None,
mentions: list[str] | None = None,
attachments: list[dict] | None = None,
) -> Comment:
"""添加评论"""
comment_id = str(uuid.uuid4())
@@ -434,7 +434,7 @@ class CollaborationManager:
)
self.db.conn.commit()
def get_comments(self, target_type: str, target_id: str, include_resolved: bool = True) -> List[Comment]:
def get_comments(self, target_type: str, target_id: str, include_resolved: bool = True) -> list[Comment]:
"""获取评论列表"""
if not self.db:
return []
@@ -484,7 +484,7 @@ class CollaborationManager:
attachments=json.loads(row[14]) if row[14] else [],
)
def update_comment(self, comment_id: str, content: str, updated_by: str) -> Optional[Comment]:
def update_comment(self, comment_id: str, content: str, updated_by: str) -> Comment | None:
"""更新评论"""
if not self.db:
return None
@@ -505,7 +505,7 @@ class CollaborationManager:
return self._get_comment_by_id(comment_id)
return None
def _get_comment_by_id(self, comment_id: str) -> Optional[Comment]:
def _get_comment_by_id(self, comment_id: str) -> Comment | None:
"""根据ID获取评论"""
cursor = self.db.conn.cursor()
cursor.execute("SELECT * FROM comments WHERE id = ?", (comment_id,))
@@ -551,7 +551,7 @@ class CollaborationManager:
self.db.conn.commit()
return cursor.rowcount > 0
def get_project_comments(self, project_id: str, limit: int = 50, offset: int = 0) -> List[Comment]:
def get_project_comments(self, project_id: str, limit: int = 50, offset: int = 0) -> list[Comment]:
"""获取项目下的所有评论"""
if not self.db:
return []
@@ -583,10 +583,10 @@ class CollaborationManager:
entity_name: str,
changed_by: str,
changed_by_name: str,
old_value: Optional[Dict] = None,
new_value: Optional[Dict] = None,
old_value: dict | None = None,
new_value: dict | None = None,
description: str = "",
session_id: Optional[str] = None,
session_id: str | None = None,
) -> ChangeRecord:
"""记录变更"""
record_id = str(uuid.uuid4())
@@ -651,11 +651,11 @@ class CollaborationManager:
def get_change_history(
self,
project_id: str,
entity_type: Optional[str] = None,
entity_id: Optional[str] = None,
entity_type: str | None = None,
entity_id: str | None = None,
limit: int = 50,
offset: int = 0,
) -> List[ChangeRecord]:
) -> list[ChangeRecord]:
"""获取变更历史"""
if not self.db:
return []
@@ -719,7 +719,7 @@ class CollaborationManager:
reverted_by=row[15],
)
def get_entity_version_history(self, entity_type: str, entity_id: str) -> List[ChangeRecord]:
def get_entity_version_history(self, entity_type: str, entity_id: str) -> list[ChangeRecord]:
"""获取实体的版本历史(用于版本对比)"""
if not self.db:
return []
@@ -757,7 +757,7 @@ class CollaborationManager:
self.db.conn.commit()
return cursor.rowcount > 0
def get_change_stats(self, project_id: str) -> Dict[str, Any]:
def get_change_stats(self, project_id: str) -> dict[str, Any]:
"""获取变更统计"""
if not self.db:
return {}
@@ -823,7 +823,7 @@ class CollaborationManager:
user_email: str,
role: str,
invited_by: str,
permissions: Optional[List[str]] = None,
permissions: list[str] | None = None,
) -> TeamMember:
"""添加团队成员"""
member_id = str(uuid.uuid4())
@@ -851,7 +851,7 @@ class CollaborationManager:
return member
def _get_default_permissions(self, role: str) -> List[str]:
def _get_default_permissions(self, role: str) -> list[str]:
"""获取角色的默认权限"""
permissions_map = {
"owner": ["read", "write", "delete", "share", "admin", "export"],
@@ -887,7 +887,7 @@ class CollaborationManager:
)
self.db.conn.commit()
def get_team_members(self, project_id: str) -> List[TeamMember]:
def get_team_members(self, project_id: str) -> list[TeamMember]:
"""获取团队成员列表"""
if not self.db:
return []

View File

@@ -5,13 +5,12 @@ InsightFlow Database Manager - Phase 5
支持实体属性扩展
"""
import os
import json
import os
import sqlite3
import uuid
from datetime import datetime
from typing import List, Dict, Optional
from dataclasses import dataclass
from datetime import datetime
DB_PATH = os.getenv("DB_PATH", "/app/data/insightflow.db")
@@ -33,9 +32,9 @@ class Entity:
type: str
definition: str = ""
canonical_name: str = ""
aliases: List[str] = None
aliases: list[str] = None
embedding: str = "" # Phase 3: 实体嵌入向量
attributes: Dict = None # Phase 5: 实体属性
attributes: dict = None # Phase 5: 实体属性
created_at: str = ""
updated_at: str = ""
@@ -54,7 +53,7 @@ class AttributeTemplate:
project_id: str
name: str
type: str # text, number, date, select, multiselect, boolean
options: List[str] = None # 用于 select/multiselect
options: list[str] = None # 用于 select/multiselect
default_value: str = ""
description: str = ""
is_required: bool = False
@@ -73,11 +72,11 @@ class EntityAttribute:
id: str
entity_id: str
template_id: Optional[str] = None
template_id: str | None = None
name: str = "" # 属性名称
type: str = "text" # 属性类型
value: str = ""
options: List[str] = None # 选项列表
options: list[str] = None # 选项列表
template_name: str = "" # 关联查询时填充
template_type: str = "" # 关联查询时填充
created_at: str = ""
@@ -126,7 +125,7 @@ class DatabaseManager:
def init_db(self):
"""初始化数据库表"""
with open(os.path.join(os.path.dirname(__file__), "schema.sql"), "r") as f:
with open(os.path.join(os.path.dirname(__file__), "schema.sql")) as f:
schema = f.read()
conn = self.get_conn()
@@ -147,7 +146,7 @@ class DatabaseManager:
conn.close()
return Project(id=project_id, name=name, description=description, created_at=now, updated_at=now)
def get_project(self, project_id: str) -> Optional[Project]:
def get_project(self, project_id: str) -> Project | None:
conn = self.get_conn()
row = conn.execute("SELECT * FROM projects WHERE id = ?", (project_id,)).fetchone()
conn.close()
@@ -155,7 +154,7 @@ class DatabaseManager:
return Project(**dict(row))
return None
def list_projects(self) -> List[Project]:
def list_projects(self) -> list[Project]:
conn = self.get_conn()
rows = conn.execute("SELECT * FROM projects ORDER BY updated_at DESC").fetchall()
conn.close()
@@ -184,7 +183,7 @@ class DatabaseManager:
conn.close()
return entity
def get_entity_by_name(self, project_id: str, name: str) -> Optional[Entity]:
def get_entity_by_name(self, project_id: str, name: str) -> Entity | None:
"""通过名称查找实体(用于对齐)"""
conn = self.get_conn()
row = conn.execute(
@@ -198,7 +197,7 @@ class DatabaseManager:
return Entity(**data)
return None
def find_similar_entities(self, project_id: str, name: str, threshold: float = 0.8) -> List[Entity]:
def find_similar_entities(self, project_id: str, name: str, threshold: float = 0.8) -> list[Entity]:
"""查找相似实体"""
conn = self.get_conn()
rows = conn.execute(
@@ -245,7 +244,7 @@ class DatabaseManager:
conn.close()
return self.get_entity(target_id)
def get_entity(self, entity_id: str) -> Optional[Entity]:
def get_entity(self, entity_id: str) -> Entity | None:
conn = self.get_conn()
row = conn.execute("SELECT * FROM entities WHERE id = ?", (entity_id,)).fetchone()
conn.close()
@@ -255,7 +254,7 @@ class DatabaseManager:
return Entity(**data)
return None
def list_project_entities(self, project_id: str) -> List[Entity]:
def list_project_entities(self, project_id: str) -> list[Entity]:
conn = self.get_conn()
rows = conn.execute(
"SELECT * FROM entities WHERE project_id = ? ORDER BY updated_at DESC", (project_id,)
@@ -333,7 +332,7 @@ class DatabaseManager:
conn.close()
return mention
def get_entity_mentions(self, entity_id: str) -> List[EntityMention]:
def get_entity_mentions(self, entity_id: str) -> list[EntityMention]:
conn = self.get_conn()
rows = conn.execute(
"SELECT * FROM entity_mentions WHERE entity_id = ? ORDER BY transcript_id, start_pos", (entity_id,)
@@ -355,13 +354,13 @@ class DatabaseManager:
conn.commit()
conn.close()
def get_transcript(self, transcript_id: str) -> Optional[dict]:
def get_transcript(self, transcript_id: str) -> dict | None:
conn = self.get_conn()
row = conn.execute("SELECT * FROM transcripts WHERE id = ?", (transcript_id,)).fetchone()
conn.close()
return dict(row) if row else None
def list_project_transcripts(self, project_id: str) -> List[dict]:
def list_project_transcripts(self, project_id: str) -> list[dict]:
conn = self.get_conn()
rows = conn.execute(
"SELECT * FROM transcripts WHERE project_id = ? ORDER BY created_at DESC", (project_id,)
@@ -404,7 +403,7 @@ class DatabaseManager:
conn.close()
return relation_id
def get_entity_relations(self, entity_id: str) -> List[dict]:
def get_entity_relations(self, entity_id: str) -> list[dict]:
conn = self.get_conn()
rows = conn.execute(
"""SELECT * FROM entity_relations
@@ -415,7 +414,7 @@ class DatabaseManager:
conn.close()
return [dict(r) for r in rows]
def list_project_relations(self, project_id: str) -> List[dict]:
def list_project_relations(self, project_id: str) -> list[dict]:
conn = self.get_conn()
rows = conn.execute(
"SELECT * FROM entity_relations WHERE project_id = ? ORDER BY created_at DESC", (project_id,)
@@ -473,7 +472,7 @@ class DatabaseManager:
conn.close()
return term_id
def list_glossary(self, project_id: str) -> List[dict]:
def list_glossary(self, project_id: str) -> list[dict]:
conn = self.get_conn()
rows = conn.execute(
"SELECT * FROM glossary WHERE project_id = ? ORDER BY frequency DESC", (project_id,)
@@ -489,7 +488,7 @@ class DatabaseManager:
# ==================== Phase 4: Agent & Provenance ====================
def get_relation_with_details(self, relation_id: str) -> Optional[dict]:
def get_relation_with_details(self, relation_id: str) -> dict | None:
conn = self.get_conn()
row = conn.execute(
"""SELECT r.*,
@@ -505,7 +504,7 @@ class DatabaseManager:
conn.close()
return dict(row) if row else None
def get_entity_with_mentions(self, entity_id: str) -> Optional[dict]:
def get_entity_with_mentions(self, entity_id: str) -> dict | None:
conn = self.get_conn()
entity_row = conn.execute("SELECT * FROM entities WHERE id = ?", (entity_id,)).fetchone()
if not entity_row:
@@ -539,7 +538,7 @@ class DatabaseManager:
conn.close()
return entity
def search_entities(self, project_id: str, query: str) -> List[Entity]:
def search_entities(self, project_id: str, query: str) -> list[Entity]:
conn = self.get_conn()
rows = conn.execute(
"""SELECT * FROM entities
@@ -616,7 +615,7 @@ class DatabaseManager:
def get_project_timeline(
self, project_id: str, entity_id: str = None, start_date: str = None, end_date: str = None
) -> List[dict]:
) -> list[dict]:
conn = self.get_conn()
conditions = ["t.project_id = ?"]
@@ -722,7 +721,7 @@ class DatabaseManager:
conn.close()
return template
def get_attribute_template(self, template_id: str) -> Optional[AttributeTemplate]:
def get_attribute_template(self, template_id: str) -> AttributeTemplate | None:
conn = self.get_conn()
row = conn.execute("SELECT * FROM attribute_templates WHERE id = ?", (template_id,)).fetchone()
conn.close()
@@ -732,7 +731,7 @@ class DatabaseManager:
return AttributeTemplate(**data)
return None
def list_attribute_templates(self, project_id: str) -> List[AttributeTemplate]:
def list_attribute_templates(self, project_id: str) -> list[AttributeTemplate]:
conn = self.get_conn()
rows = conn.execute(
"""SELECT * FROM attribute_templates WHERE project_id = ?
@@ -748,7 +747,7 @@ class DatabaseManager:
templates.append(AttributeTemplate(**data))
return templates
def update_attribute_template(self, template_id: str, **kwargs) -> Optional[AttributeTemplate]:
def update_attribute_template(self, template_id: str, **kwargs) -> AttributeTemplate | None:
conn = self.get_conn()
allowed_fields = ["name", "type", "options", "default_value", "description", "is_required", "sort_order"]
updates = []
@@ -834,7 +833,7 @@ class DatabaseManager:
conn.close()
return attr
def get_entity_attributes(self, entity_id: str) -> List[EntityAttribute]:
def get_entity_attributes(self, entity_id: str) -> list[EntityAttribute]:
conn = self.get_conn()
rows = conn.execute(
"""SELECT ea.*, at.name as template_name, at.type as template_type
@@ -846,7 +845,7 @@ class DatabaseManager:
conn.close()
return [EntityAttribute(**dict(r)) for r in rows]
def get_entity_with_attributes(self, entity_id: str) -> Optional[Entity]:
def get_entity_with_attributes(self, entity_id: str) -> Entity | None:
entity = self.get_entity(entity_id)
if not entity:
return None
@@ -889,7 +888,7 @@ class DatabaseManager:
def get_attribute_history(
self, entity_id: str = None, template_id: str = None, limit: int = 50
) -> List[AttributeHistory]:
) -> list[AttributeHistory]:
conn = self.get_conn()
conditions = []
params = []
@@ -913,7 +912,7 @@ class DatabaseManager:
conn.close()
return [AttributeHistory(**dict(r)) for r in rows]
def search_entities_by_attributes(self, project_id: str, attribute_filters: Dict[str, str]) -> List[Entity]:
def search_entities_by_attributes(self, project_id: str, attribute_filters: dict[str, str]) -> list[Entity]:
entities = self.list_project_entities(project_id)
if not attribute_filters:
return entities
@@ -962,11 +961,11 @@ class DatabaseManager:
filename: str,
duration: float = 0,
fps: float = 0,
resolution: Dict = None,
resolution: dict = None,
audio_transcript_id: str = None,
full_ocr_text: str = "",
extracted_entities: List[Dict] = None,
extracted_relations: List[Dict] = None,
extracted_entities: list[dict] = None,
extracted_relations: list[dict] = None,
) -> str:
"""创建视频记录"""
conn = self.get_conn()
@@ -998,7 +997,7 @@ class DatabaseManager:
conn.close()
return video_id
def get_video(self, video_id: str) -> Optional[Dict]:
def get_video(self, video_id: str) -> dict | None:
"""获取视频信息"""
conn = self.get_conn()
row = conn.execute("SELECT * FROM videos WHERE id = ?", (video_id,)).fetchone()
@@ -1012,7 +1011,7 @@ class DatabaseManager:
return data
return None
def list_project_videos(self, project_id: str) -> List[Dict]:
def list_project_videos(self, project_id: str) -> list[dict]:
"""获取项目的所有视频"""
conn = self.get_conn()
rows = conn.execute(
@@ -1037,7 +1036,7 @@ class DatabaseManager:
timestamp: float,
image_url: str = None,
ocr_text: str = None,
extracted_entities: List[Dict] = None,
extracted_entities: list[dict] = None,
) -> str:
"""创建视频帧记录"""
conn = self.get_conn()
@@ -1062,7 +1061,7 @@ class DatabaseManager:
conn.close()
return frame_id
def get_video_frames(self, video_id: str) -> List[Dict]:
def get_video_frames(self, video_id: str) -> list[dict]:
"""获取视频的所有帧"""
conn = self.get_conn()
rows = conn.execute(
@@ -1084,8 +1083,8 @@ class DatabaseManager:
filename: str,
ocr_text: str = "",
description: str = "",
extracted_entities: List[Dict] = None,
extracted_relations: List[Dict] = None,
extracted_entities: list[dict] = None,
extracted_relations: list[dict] = None,
) -> str:
"""创建图片记录"""
conn = self.get_conn()
@@ -1113,7 +1112,7 @@ class DatabaseManager:
conn.close()
return image_id
def get_image(self, image_id: str) -> Optional[Dict]:
def get_image(self, image_id: str) -> dict | None:
"""获取图片信息"""
conn = self.get_conn()
row = conn.execute("SELECT * FROM images WHERE id = ?", (image_id,)).fetchone()
@@ -1126,7 +1125,7 @@ class DatabaseManager:
return data
return None
def list_project_images(self, project_id: str) -> List[Dict]:
def list_project_images(self, project_id: str) -> list[dict]:
"""获取项目的所有图片"""
conn = self.get_conn()
rows = conn.execute(
@@ -1168,7 +1167,7 @@ class DatabaseManager:
conn.close()
return mention_id
def get_entity_multimodal_mentions(self, entity_id: str) -> List[Dict]:
def get_entity_multimodal_mentions(self, entity_id: str) -> list[dict]:
"""获取实体的多模态提及"""
conn = self.get_conn()
rows = conn.execute(
@@ -1181,7 +1180,7 @@ class DatabaseManager:
conn.close()
return [dict(r) for r in rows]
def get_project_multimodal_mentions(self, project_id: str, modality: str = None) -> List[Dict]:
def get_project_multimodal_mentions(self, project_id: str, modality: str = None) -> list[dict]:
"""获取项目的多模态提及"""
conn = self.get_conn()
@@ -1214,7 +1213,7 @@ class DatabaseManager:
link_type: str,
confidence: float = 1.0,
evidence: str = "",
modalities: List[str] = None,
modalities: list[str] = None,
) -> str:
"""创建多模态实体关联"""
conn = self.get_conn()
@@ -1231,7 +1230,7 @@ class DatabaseManager:
conn.close()
return link_id
def get_entity_multimodal_links(self, entity_id: str) -> List[Dict]:
def get_entity_multimodal_links(self, entity_id: str) -> list[dict]:
"""获取实体的多模态关联"""
conn = self.get_conn()
rows = conn.execute(
@@ -1251,7 +1250,7 @@ class DatabaseManager:
links.append(data)
return links
def get_project_multimodal_stats(self, project_id: str) -> Dict:
def get_project_multimodal_stats(self, project_id: str) -> dict:
"""获取项目多模态统计信息"""
conn = self.get_conn()

View File

@@ -10,20 +10,19 @@ InsightFlow Developer Ecosystem Manager - Phase 8 Task 6
作者: InsightFlow Team
"""
import os
import json
import os
import sqlite3
import uuid
from typing import List, Dict, Optional
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from enum import StrEnum
# Database path
DB_PATH = os.path.join(os.path.dirname(__file__), "insightflow.db")
class SDKLanguage(str, Enum):
class SDKLanguage(StrEnum):
"""SDK 语言类型"""
PYTHON = "python"
@@ -34,7 +33,7 @@ class SDKLanguage(str, Enum):
RUST = "rust"
class SDKStatus(str, Enum):
class SDKStatus(StrEnum):
"""SDK 状态"""
DRAFT = "draft" # 草稿
@@ -44,7 +43,7 @@ class SDKStatus(str, Enum):
ARCHIVED = "archived" # 已归档
class TemplateCategory(str, Enum):
class TemplateCategory(StrEnum):
"""模板分类"""
MEDICAL = "medical" # 医疗
@@ -55,7 +54,7 @@ class TemplateCategory(str, Enum):
GENERAL = "general" # 通用
class TemplateStatus(str, Enum):
class TemplateStatus(StrEnum):
"""模板状态"""
PENDING = "pending" # 待审核
@@ -65,7 +64,7 @@ class TemplateStatus(str, Enum):
UNLISTED = "unlisted" # 未列出
class PluginStatus(str, Enum):
class PluginStatus(StrEnum):
"""插件状态"""
PENDING = "pending" # 待审核
@@ -76,7 +75,7 @@ class PluginStatus(str, Enum):
SUSPENDED = "suspended" # 已暂停
class PluginCategory(str, Enum):
class PluginCategory(StrEnum):
"""插件分类"""
INTEGRATION = "integration" # 集成
@@ -87,7 +86,7 @@ class PluginCategory(str, Enum):
CUSTOM = "custom" # 自定义
class DeveloperStatus(str, Enum):
class DeveloperStatus(StrEnum):
"""开发者认证状态"""
UNVERIFIED = "unverified" # 未认证
@@ -113,13 +112,13 @@ class SDKRelease:
package_name: str # pip/npm/go module name
status: SDKStatus
min_platform_version: str
dependencies: List[Dict] # [{"name": "requests", "version": ">=2.0"}]
dependencies: list[dict] # [{"name": "requests", "version": ">=2.0"}]
file_size: int
checksum: str
download_count: int
created_at: str
updated_at: str
published_at: Optional[str]
published_at: str | None
created_by: str
@@ -148,17 +147,17 @@ class TemplateMarketItem:
name: str
description: str
category: TemplateCategory
subcategory: Optional[str]
tags: List[str]
subcategory: str | None
tags: list[str]
author_id: str
author_name: str
status: TemplateStatus
price: float # 0 = 免费
currency: str
preview_image_url: Optional[str]
demo_url: Optional[str]
documentation_url: Optional[str]
download_url: Optional[str]
preview_image_url: str | None
demo_url: str | None
documentation_url: str | None
download_url: str | None
install_count: int
rating: float
rating_count: int
@@ -169,7 +168,7 @@ class TemplateMarketItem:
checksum: str
created_at: str
updated_at: str
published_at: Optional[str]
published_at: str | None
@dataclass
@@ -196,20 +195,20 @@ class PluginMarketItem:
name: str
description: str
category: PluginCategory
tags: List[str]
tags: list[str]
author_id: str
author_name: str
status: PluginStatus
price: float
currency: str
pricing_model: str # free, paid, freemium, subscription
preview_image_url: Optional[str]
demo_url: Optional[str]
documentation_url: Optional[str]
repository_url: Optional[str]
download_url: Optional[str]
webhook_url: Optional[str] # 用于插件回调
permissions: List[str] # 需要的权限列表
preview_image_url: str | None
demo_url: str | None
documentation_url: str | None
repository_url: str | None
download_url: str | None
webhook_url: str | None # 用于插件回调
permissions: list[str] # 需要的权限列表
install_count: int
active_install_count: int
rating: float
@@ -221,10 +220,10 @@ class PluginMarketItem:
checksum: str
created_at: str
updated_at: str
published_at: Optional[str]
reviewed_by: Optional[str]
reviewed_at: Optional[str]
review_notes: Optional[str]
published_at: str | None
reviewed_by: str | None
reviewed_at: str | None
review_notes: str | None
@dataclass
@@ -251,12 +250,12 @@ class DeveloperProfile:
user_id: str
display_name: str
email: str
bio: Optional[str]
website: Optional[str]
github_url: Optional[str]
avatar_url: Optional[str]
bio: str | None
website: str | None
github_url: str | None
avatar_url: str | None
status: DeveloperStatus
verification_documents: Dict # 认证文档
verification_documents: dict # 认证文档
total_sales: float
total_downloads: int
plugin_count: int
@@ -264,7 +263,7 @@ class DeveloperProfile:
rating_average: float
created_at: str
updated_at: str
verified_at: Optional[str]
verified_at: str | None
@dataclass
@@ -296,11 +295,11 @@ class CodeExample:
category: str
code: str
explanation: str
tags: List[str]
tags: list[str]
author_id: str
author_name: str
sdk_id: Optional[str] # 关联的 SDK
api_endpoints: List[str] # 涉及的 API 端点
sdk_id: str | None # 关联的 SDK
api_endpoints: list[str] # 涉及的 API 端点
view_count: int
copy_count: int
rating: float
@@ -330,16 +329,16 @@ class DeveloperPortalConfig:
name: str
description: str
theme: str
custom_css: Optional[str]
custom_js: Optional[str]
logo_url: Optional[str]
favicon_url: Optional[str]
custom_css: str | None
custom_js: str | None
logo_url: str | None
favicon_url: str | None
primary_color: str
secondary_color: str
support_email: str
support_url: Optional[str]
github_url: Optional[str]
discord_url: Optional[str]
support_url: str | None
github_url: str | None
discord_url: str | None
api_base_url: str
is_active: bool
created_at: str
@@ -373,7 +372,7 @@ class DeveloperEcosystemManager:
repository_url: str,
package_name: str,
min_platform_version: str,
dependencies: List[Dict],
dependencies: list[dict],
file_size: int,
checksum: str,
created_by: str,
@@ -442,7 +441,7 @@ class DeveloperEcosystemManager:
return sdk
def get_sdk_release(self, sdk_id: str) -> Optional[SDKRelease]:
def get_sdk_release(self, sdk_id: str) -> SDKRelease | None:
"""获取 SDK 发布详情"""
with self._get_db() as conn:
row = conn.execute("SELECT * FROM sdk_releases WHERE id = ?", (sdk_id,)).fetchone()
@@ -452,8 +451,8 @@ class DeveloperEcosystemManager:
return None
def list_sdk_releases(
self, language: Optional[SDKLanguage] = None, status: Optional[SDKStatus] = None, search: Optional[str] = None
) -> List[SDKRelease]:
self, language: SDKLanguage | None = None, status: SDKStatus | None = None, search: str | None = None
) -> list[SDKRelease]:
"""列出 SDK 发布"""
query = "SELECT * FROM sdk_releases WHERE 1=1"
params = []
@@ -474,7 +473,7 @@ class DeveloperEcosystemManager:
rows = conn.execute(query, params).fetchall()
return [self._row_to_sdk_release(row) for row in rows]
def update_sdk_release(self, sdk_id: str, **kwargs) -> Optional[SDKRelease]:
def update_sdk_release(self, sdk_id: str, **kwargs) -> SDKRelease | None:
"""更新 SDK 发布"""
allowed_fields = [
"name",
@@ -499,7 +498,7 @@ class DeveloperEcosystemManager:
return self.get_sdk_release(sdk_id)
def publish_sdk_release(self, sdk_id: str) -> Optional[SDKRelease]:
def publish_sdk_release(self, sdk_id: str) -> SDKRelease | None:
"""发布 SDK"""
now = datetime.now().isoformat()
@@ -529,7 +528,7 @@ class DeveloperEcosystemManager:
)
conn.commit()
def get_sdk_versions(self, sdk_id: str) -> List[SDKVersion]:
def get_sdk_versions(self, sdk_id: str) -> list[SDKVersion]:
"""获取 SDK 版本历史"""
with self._get_db() as conn:
rows = conn.execute(
@@ -588,16 +587,16 @@ class DeveloperEcosystemManager:
name: str,
description: str,
category: TemplateCategory,
subcategory: Optional[str],
tags: List[str],
subcategory: str | None,
tags: list[str],
author_id: str,
author_name: str,
price: float = 0.0,
currency: str = "CNY",
preview_image_url: Optional[str] = None,
demo_url: Optional[str] = None,
documentation_url: Optional[str] = None,
download_url: Optional[str] = None,
preview_image_url: str | None = None,
demo_url: str | None = None,
documentation_url: str | None = None,
download_url: str | None = None,
version: str = "1.0.0",
min_platform_version: str = "1.0.0",
file_size: int = 0,
@@ -679,7 +678,7 @@ class DeveloperEcosystemManager:
return template
def get_template(self, template_id: str) -> Optional[TemplateMarketItem]:
def get_template(self, template_id: str) -> TemplateMarketItem | None:
"""获取模板详情"""
with self._get_db() as conn:
row = conn.execute("SELECT * FROM template_market WHERE id = ?", (template_id,)).fetchone()
@@ -690,14 +689,14 @@ class DeveloperEcosystemManager:
def list_templates(
self,
category: Optional[TemplateCategory] = None,
status: Optional[TemplateStatus] = None,
search: Optional[str] = None,
author_id: Optional[str] = None,
min_price: Optional[float] = None,
max_price: Optional[float] = None,
category: TemplateCategory | None = None,
status: TemplateStatus | None = None,
search: str | None = None,
author_id: str | None = None,
min_price: float | None = None,
max_price: float | None = None,
sort_by: str = "created_at",
) -> List[TemplateMarketItem]:
) -> list[TemplateMarketItem]:
"""列出模板"""
query = "SELECT * FROM template_market WHERE 1=1"
params = []
@@ -735,7 +734,7 @@ class DeveloperEcosystemManager:
rows = conn.execute(query, params).fetchall()
return [self._row_to_template(row) for row in rows]
def approve_template(self, template_id: str, reviewed_by: str) -> Optional[TemplateMarketItem]:
def approve_template(self, template_id: str, reviewed_by: str) -> TemplateMarketItem | None:
"""审核通过模板"""
now = datetime.now().isoformat()
@@ -752,7 +751,7 @@ class DeveloperEcosystemManager:
return self.get_template(template_id)
def publish_template(self, template_id: str) -> Optional[TemplateMarketItem]:
def publish_template(self, template_id: str) -> TemplateMarketItem | None:
"""发布模板"""
now = datetime.now().isoformat()
@@ -769,7 +768,7 @@ class DeveloperEcosystemManager:
return self.get_template(template_id)
def reject_template(self, template_id: str, reason: str) -> Optional[TemplateMarketItem]:
def reject_template(self, template_id: str, reason: str) -> TemplateMarketItem | None:
"""拒绝模板"""
now = datetime.now().isoformat()
@@ -874,7 +873,7 @@ class DeveloperEcosystemManager:
(round(row["avg_rating"], 2) if row["avg_rating"] else 0, row["count"], row["count"], template_id),
)
def get_template_reviews(self, template_id: str, limit: int = 50) -> List[TemplateReview]:
def get_template_reviews(self, template_id: str, limit: int = 50) -> list[TemplateReview]:
"""获取模板评价"""
with self._get_db() as conn:
rows = conn.execute(
@@ -893,19 +892,19 @@ class DeveloperEcosystemManager:
name: str,
description: str,
category: PluginCategory,
tags: List[str],
tags: list[str],
author_id: str,
author_name: str,
price: float = 0.0,
currency: str = "CNY",
pricing_model: str = "free",
preview_image_url: Optional[str] = None,
demo_url: Optional[str] = None,
documentation_url: Optional[str] = None,
repository_url: Optional[str] = None,
download_url: Optional[str] = None,
webhook_url: Optional[str] = None,
permissions: List[str] = None,
preview_image_url: str | None = None,
demo_url: str | None = None,
documentation_url: str | None = None,
repository_url: str | None = None,
download_url: str | None = None,
webhook_url: str | None = None,
permissions: list[str] = None,
version: str = "1.0.0",
min_platform_version: str = "1.0.0",
file_size: int = 0,
@@ -1003,7 +1002,7 @@ class DeveloperEcosystemManager:
return plugin
def get_plugin(self, plugin_id: str) -> Optional[PluginMarketItem]:
def get_plugin(self, plugin_id: str) -> PluginMarketItem | None:
"""获取插件详情"""
with self._get_db() as conn:
row = conn.execute("SELECT * FROM plugin_market WHERE id = ?", (plugin_id,)).fetchone()
@@ -1014,12 +1013,12 @@ class DeveloperEcosystemManager:
def list_plugins(
self,
category: Optional[PluginCategory] = None,
status: Optional[PluginStatus] = None,
search: Optional[str] = None,
author_id: Optional[str] = None,
category: PluginCategory | None = None,
status: PluginStatus | None = None,
search: str | None = None,
author_id: str | None = None,
sort_by: str = "created_at",
) -> List[PluginMarketItem]:
) -> list[PluginMarketItem]:
"""列出插件"""
query = "SELECT * FROM plugin_market WHERE 1=1"
params = []
@@ -1051,7 +1050,7 @@ class DeveloperEcosystemManager:
def review_plugin(
self, plugin_id: str, reviewed_by: str, status: PluginStatus, notes: str = ""
) -> Optional[PluginMarketItem]:
) -> PluginMarketItem | None:
"""审核插件"""
now = datetime.now().isoformat()
@@ -1068,7 +1067,7 @@ class DeveloperEcosystemManager:
return self.get_plugin(plugin_id)
def publish_plugin(self, plugin_id: str) -> Optional[PluginMarketItem]:
def publish_plugin(self, plugin_id: str) -> PluginMarketItem | None:
"""发布插件"""
now = datetime.now().isoformat()
@@ -1182,7 +1181,7 @@ class DeveloperEcosystemManager:
(round(row["avg_rating"], 2) if row["avg_rating"] else 0, row["count"], row["count"], plugin_id),
)
def get_plugin_reviews(self, plugin_id: str, limit: int = 50) -> List[PluginReview]:
def get_plugin_reviews(self, plugin_id: str, limit: int = 50) -> list[PluginReview]:
"""获取插件评价"""
with self._get_db() as conn:
rows = conn.execute(
@@ -1268,8 +1267,8 @@ class DeveloperEcosystemManager:
return revenue
def get_developer_revenues(
self, developer_id: str, start_date: Optional[datetime] = None, end_date: Optional[datetime] = None
) -> List[DeveloperRevenue]:
self, developer_id: str, start_date: datetime | None = None, end_date: datetime | None = None
) -> list[DeveloperRevenue]:
"""获取开发者收益记录"""
query = "SELECT * FROM developer_revenues WHERE developer_id = ?"
params = [developer_id]
@@ -1287,7 +1286,7 @@ class DeveloperEcosystemManager:
rows = conn.execute(query, params).fetchall()
return [self._row_to_developer_revenue(row) for row in rows]
def get_developer_revenue_summary(self, developer_id: str) -> Dict:
def get_developer_revenue_summary(self, developer_id: str) -> dict:
"""获取开发者收益汇总"""
with self._get_db() as conn:
row = conn.execute(
@@ -1318,10 +1317,10 @@ class DeveloperEcosystemManager:
user_id: str,
display_name: str,
email: str,
bio: Optional[str] = None,
website: Optional[str] = None,
github_url: Optional[str] = None,
avatar_url: Optional[str] = None,
bio: str | None = None,
website: str | None = None,
github_url: str | None = None,
avatar_url: str | None = None,
) -> DeveloperProfile:
"""创建开发者档案"""
profile_id = f"dev_{uuid.uuid4().hex[:16]}"
@@ -1382,7 +1381,7 @@ class DeveloperEcosystemManager:
return profile
def get_developer_profile(self, developer_id: str) -> Optional[DeveloperProfile]:
def get_developer_profile(self, developer_id: str) -> DeveloperProfile | None:
"""获取开发者档案"""
with self._get_db() as conn:
row = conn.execute("SELECT * FROM developer_profiles WHERE id = ?", (developer_id,)).fetchone()
@@ -1391,7 +1390,7 @@ class DeveloperEcosystemManager:
return self._row_to_developer_profile(row)
return None
def get_developer_profile_by_user(self, user_id: str) -> Optional[DeveloperProfile]:
def get_developer_profile_by_user(self, user_id: str) -> DeveloperProfile | None:
"""通过用户 ID 获取开发者档案"""
with self._get_db() as conn:
row = conn.execute("SELECT * FROM developer_profiles WHERE user_id = ?", (user_id,)).fetchone()
@@ -1400,7 +1399,7 @@ class DeveloperEcosystemManager:
return self._row_to_developer_profile(row)
return None
def verify_developer(self, developer_id: str, status: DeveloperStatus) -> Optional[DeveloperProfile]:
def verify_developer(self, developer_id: str, status: DeveloperStatus) -> DeveloperProfile | None:
"""验证开发者"""
now = datetime.now().isoformat()
@@ -1473,11 +1472,11 @@ class DeveloperEcosystemManager:
category: str,
code: str,
explanation: str,
tags: List[str],
tags: list[str],
author_id: str,
author_name: str,
sdk_id: Optional[str] = None,
api_endpoints: List[str] = None,
sdk_id: str | None = None,
api_endpoints: list[str] = None,
) -> CodeExample:
"""创建代码示例"""
example_id = f"ex_{uuid.uuid4().hex[:16]}"
@@ -1536,7 +1535,7 @@ class DeveloperEcosystemManager:
return example
def get_code_example(self, example_id: str) -> Optional[CodeExample]:
def get_code_example(self, example_id: str) -> CodeExample | None:
"""获取代码示例"""
with self._get_db() as conn:
row = conn.execute("SELECT * FROM code_examples WHERE id = ?", (example_id,)).fetchone()
@@ -1547,11 +1546,11 @@ class DeveloperEcosystemManager:
def list_code_examples(
self,
language: Optional[str] = None,
category: Optional[str] = None,
sdk_id: Optional[str] = None,
search: Optional[str] = None,
) -> List[CodeExample]:
language: str | None = None,
category: str | None = None,
sdk_id: str | None = None,
search: str | None = None,
) -> list[CodeExample]:
"""列出代码示例"""
query = "SELECT * FROM code_examples WHERE 1=1"
params = []
@@ -1650,7 +1649,7 @@ class DeveloperEcosystemManager:
return doc
def get_api_documentation(self, doc_id: str) -> Optional[APIDocumentation]:
def get_api_documentation(self, doc_id: str) -> APIDocumentation | None:
"""获取 API 文档"""
with self._get_db() as conn:
row = conn.execute("SELECT * FROM api_documentation WHERE id = ?", (doc_id,)).fetchone()
@@ -1659,7 +1658,7 @@ class DeveloperEcosystemManager:
return self._row_to_api_documentation(row)
return None
def get_latest_api_documentation(self) -> Optional[APIDocumentation]:
def get_latest_api_documentation(self) -> APIDocumentation | None:
"""获取最新 API 文档"""
with self._get_db() as conn:
row = conn.execute("SELECT * FROM api_documentation ORDER BY generated_at DESC LIMIT 1").fetchone()
@@ -1675,16 +1674,16 @@ class DeveloperEcosystemManager:
name: str,
description: str,
theme: str = "default",
custom_css: Optional[str] = None,
custom_js: Optional[str] = None,
logo_url: Optional[str] = None,
favicon_url: Optional[str] = None,
custom_css: str | None = None,
custom_js: str | None = None,
logo_url: str | None = None,
favicon_url: str | None = None,
primary_color: str = "#1890ff",
secondary_color: str = "#52c41a",
support_email: str = "support@insightflow.io",
support_url: Optional[str] = None,
github_url: Optional[str] = None,
discord_url: Optional[str] = None,
support_url: str | None = None,
github_url: str | None = None,
discord_url: str | None = None,
api_base_url: str = "https://api.insightflow.io",
) -> DeveloperPortalConfig:
"""创建开发者门户配置"""
@@ -1746,7 +1745,7 @@ class DeveloperEcosystemManager:
return config
def get_portal_config(self, config_id: str) -> Optional[DeveloperPortalConfig]:
def get_portal_config(self, config_id: str) -> DeveloperPortalConfig | None:
"""获取开发者门户配置"""
with self._get_db() as conn:
row = conn.execute("SELECT * FROM developer_portal_configs WHERE id = ?", (config_id,)).fetchone()
@@ -1755,7 +1754,7 @@ class DeveloperEcosystemManager:
return self._row_to_portal_config(row)
return None
def get_active_portal_config(self) -> Optional[DeveloperPortalConfig]:
def get_active_portal_config(self) -> DeveloperPortalConfig | None:
"""获取活跃的开发者门户配置"""
with self._get_db() as conn:
row = conn.execute("SELECT * FROM developer_portal_configs WHERE is_active = 1 LIMIT 1").fetchone()

View File

@@ -4,9 +4,8 @@ Document Processor - Phase 3
支持 PDF 和 DOCX 文档导入
"""
import os
import io
from typing import Dict
import os
class DocumentProcessor:
@@ -21,7 +20,7 @@ class DocumentProcessor:
".md": self._extract_txt,
}
def process(self, content: bytes, filename: str) -> Dict[str, str]:
def process(self, content: bytes, filename: str) -> dict[str, str]:
"""
处理文档并提取文本

View File

@@ -10,19 +10,19 @@ InsightFlow Phase 8 - 企业级功能管理模块
作者: InsightFlow Team
"""
import sqlite3
import json
import uuid
from datetime import datetime, timedelta
from typing import Optional, List, Dict, Any, Tuple
from dataclasses import dataclass
from enum import Enum
import logging
import sqlite3
import uuid
from dataclasses import dataclass
from datetime import datetime, timedelta
from enum import StrEnum
from typing import Any
logger = logging.getLogger(__name__)
class SSOProvider(str, Enum):
class SSOProvider(StrEnum):
"""SSO 提供商类型"""
WECHAT_WORK = "wechat_work" # 企业微信
@@ -34,7 +34,7 @@ class SSOProvider(str, Enum):
CUSTOM_SAML = "custom_saml" # 自定义 SAML
class SSOStatus(str, Enum):
class SSOStatus(StrEnum):
"""SSO 配置状态"""
DISABLED = "disabled" # 未启用
@@ -43,7 +43,7 @@ class SSOStatus(str, Enum):
ERROR = "error" # 配置错误
class SCIMSyncStatus(str, Enum):
class SCIMSyncStatus(StrEnum):
"""SCIM 同步状态"""
IDLE = "idle" # 空闲
@@ -52,7 +52,7 @@ class SCIMSyncStatus(str, Enum):
FAILED = "failed" # 同步失败
class AuditLogExportFormat(str, Enum):
class AuditLogExportFormat(StrEnum):
"""审计日志导出格式"""
JSON = "json"
@@ -61,7 +61,7 @@ class AuditLogExportFormat(str, Enum):
XLSX = "xlsx"
class DataRetentionAction(str, Enum):
class DataRetentionAction(StrEnum):
"""数据保留策略动作"""
ARCHIVE = "archive" # 归档
@@ -69,7 +69,7 @@ class DataRetentionAction(str, Enum):
ANONYMIZE = "anonymize" # 匿名化
class ComplianceStandard(str, Enum):
class ComplianceStandard(StrEnum):
"""合规标准"""
SOC2 = "soc2"
@@ -87,29 +87,29 @@ class SSOConfig:
tenant_id: str
provider: str # SSO 提供商
status: str # 状态
entity_id: Optional[str] # SAML Entity ID
sso_url: Optional[str] # SAML SSO URL
slo_url: Optional[str] # SAML SLO URL
certificate: Optional[str] # SAML 证书 (X.509)
metadata_url: Optional[str] # SAML 元数据 URL
metadata_xml: Optional[str] # SAML 元数据 XML
entity_id: str | None # SAML Entity ID
sso_url: str | None # SAML SSO URL
slo_url: str | None # SAML SLO URL
certificate: str | None # SAML 证书 (X.509)
metadata_url: str | None # SAML 元数据 URL
metadata_xml: str | None # SAML 元数据 XML
# OAuth/OIDC 配置
client_id: Optional[str]
client_secret: Optional[str]
authorization_url: Optional[str]
token_url: Optional[str]
userinfo_url: Optional[str]
scopes: List[str]
client_id: str | None
client_secret: str | None
authorization_url: str | None
token_url: str | None
userinfo_url: str | None
scopes: list[str]
# 属性映射
attribute_mapping: Dict[str, str] # 如 {"email": "user.mail", "name": "user.name"}
attribute_mapping: dict[str, str] # 如 {"email": "user.mail", "name": "user.name"}
# 其他配置
auto_provision: bool # 自动创建用户
default_role: str # 默认角色
domain_restriction: List[str] # 允许的邮箱域名
domain_restriction: list[str] # 允许的邮箱域名
created_at: datetime
updated_at: datetime
last_tested_at: Optional[datetime]
last_error: Optional[str]
last_tested_at: datetime | None
last_error: str | None
@dataclass
@@ -125,14 +125,14 @@ class SCIMConfig:
scim_token: str # SCIM 访问令牌
# 同步配置
sync_interval_minutes: int # 同步间隔(分钟)
last_sync_at: Optional[datetime]
last_sync_status: Optional[str]
last_sync_error: Optional[str]
last_sync_at: datetime | None
last_sync_status: str | None
last_sync_error: str | None
last_sync_users_count: int
# 属性映射
attribute_mapping: Dict[str, str]
attribute_mapping: dict[str, str]
# 同步规则
sync_rules: Dict[str, Any] # 过滤规则、转换规则等
sync_rules: dict[str, Any] # 过滤规则、转换规则等
created_at: datetime
updated_at: datetime
@@ -146,12 +146,12 @@ class SCIMUser:
external_id: str # 外部系统 ID
user_name: str
email: str
display_name: Optional[str]
given_name: Optional[str]
family_name: Optional[str]
display_name: str | None
given_name: str | None
family_name: str | None
active: bool
groups: List[str]
raw_data: Dict[str, Any] # 原始 SCIM 数据
groups: list[str]
raw_data: dict[str, Any] # 原始 SCIM 数据
synced_at: datetime
created_at: datetime
updated_at: datetime
@@ -166,20 +166,20 @@ class AuditLogExport:
export_format: str
start_date: datetime
end_date: datetime
filters: Dict[str, Any] # 过滤条件
compliance_standard: Optional[str]
filters: dict[str, Any] # 过滤条件
compliance_standard: str | None
status: str # pending/processing/completed/failed
file_path: Optional[str]
file_size: Optional[int]
record_count: Optional[int]
checksum: Optional[str] # 文件校验和
downloaded_by: Optional[str]
downloaded_at: Optional[datetime]
expires_at: Optional[datetime] # 文件过期时间
file_path: str | None
file_size: int | None
record_count: int | None
checksum: str | None # 文件校验和
downloaded_by: str | None
downloaded_at: datetime | None
expires_at: datetime | None # 文件过期时间
created_by: str
created_at: datetime
completed_at: Optional[datetime]
error_message: Optional[str]
completed_at: datetime | None
error_message: str | None
@dataclass
@@ -189,23 +189,23 @@ class DataRetentionPolicy:
id: str
tenant_id: str
name: str
description: Optional[str]
description: str | None
resource_type: str # project/transcript/entity/audit_log/user_data
retention_days: int # 保留天数
action: str # archive/delete/anonymize
# 条件
conditions: Dict[str, Any] # 触发条件
conditions: dict[str, Any] # 触发条件
# 执行配置
auto_execute: bool # 自动执行
execute_at: Optional[str] # 执行时间 (cron 表达式)
execute_at: str | None # 执行时间 (cron 表达式)
notify_before_days: int # 提前通知天数
# 归档配置
archive_location: Optional[str] # 归档位置
archive_location: str | None # 归档位置
archive_encryption: bool # 归档加密
# 状态
is_active: bool
last_executed_at: Optional[datetime]
last_execution_result: Optional[str]
last_executed_at: datetime | None
last_execution_result: str | None
created_at: datetime
updated_at: datetime
@@ -218,13 +218,13 @@ class DataRetentionJob:
policy_id: str
tenant_id: str
status: str # pending/running/completed/failed
started_at: Optional[datetime]
completed_at: Optional[datetime]
started_at: datetime | None
completed_at: datetime | None
affected_records: int
archived_records: int
deleted_records: int
error_count: int
details: Dict[str, Any]
details: dict[str, Any]
created_at: datetime
@@ -236,11 +236,11 @@ class SAMLAuthRequest:
tenant_id: str
sso_config_id: str
request_id: str # SAML Request ID
relay_state: Optional[str]
relay_state: str | None
created_at: datetime
expires_at: datetime
used: bool
used_at: Optional[datetime]
used_at: datetime | None
@dataclass
@@ -250,13 +250,13 @@ class SAMLAuthResponse:
id: str
request_id: str
tenant_id: str
user_id: Optional[str]
email: Optional[str]
name: Optional[str]
attributes: Dict[str, Any]
session_index: Optional[str]
user_id: str | None
email: str | None
name: str | None
attributes: dict[str, Any]
session_index: str | None
processed: bool
processed_at: Optional[datetime]
processed_at: datetime | None
created_at: datetime
@@ -548,22 +548,22 @@ class EnterpriseManager:
self,
tenant_id: str,
provider: str,
entity_id: Optional[str] = None,
sso_url: Optional[str] = None,
slo_url: Optional[str] = None,
certificate: Optional[str] = None,
metadata_url: Optional[str] = None,
metadata_xml: Optional[str] = None,
client_id: Optional[str] = None,
client_secret: Optional[str] = None,
authorization_url: Optional[str] = None,
token_url: Optional[str] = None,
userinfo_url: Optional[str] = None,
scopes: Optional[List[str]] = None,
attribute_mapping: Optional[Dict[str, str]] = None,
entity_id: str | None = None,
sso_url: str | None = None,
slo_url: str | None = None,
certificate: str | None = None,
metadata_url: str | None = None,
metadata_xml: str | None = None,
client_id: str | None = None,
client_secret: str | None = None,
authorization_url: str | None = None,
token_url: str | None = None,
userinfo_url: str | None = None,
scopes: list[str] | None = None,
attribute_mapping: dict[str, str] | None = None,
auto_provision: bool = True,
default_role: str = "member",
domain_restriction: Optional[List[str]] = None,
domain_restriction: list[str] | None = None,
) -> SSOConfig:
"""创建 SSO 配置"""
conn = self._get_connection()
@@ -649,7 +649,7 @@ class EnterpriseManager:
finally:
conn.close()
def get_sso_config(self, config_id: str) -> Optional[SSOConfig]:
def get_sso_config(self, config_id: str) -> SSOConfig | None:
"""获取 SSO 配置"""
conn = self._get_connection()
try:
@@ -664,7 +664,7 @@ class EnterpriseManager:
finally:
conn.close()
def get_tenant_sso_config(self, tenant_id: str, provider: Optional[str] = None) -> Optional[SSOConfig]:
def get_tenant_sso_config(self, tenant_id: str, provider: str | None = None) -> SSOConfig | None:
"""获取租户的 SSO 配置"""
conn = self._get_connection()
try:
@@ -698,7 +698,7 @@ class EnterpriseManager:
finally:
conn.close()
def update_sso_config(self, config_id: str, **kwargs) -> Optional[SSOConfig]:
def update_sso_config(self, config_id: str, **kwargs) -> SSOConfig | None:
"""更新 SSO 配置"""
conn = self._get_connection()
try:
@@ -772,7 +772,7 @@ class EnterpriseManager:
finally:
conn.close()
def list_sso_configs(self, tenant_id: str) -> List[SSOConfig]:
def list_sso_configs(self, tenant_id: str) -> list[SSOConfig]:
"""列出租户的所有 SSO 配置"""
conn = self._get_connection()
try:
@@ -835,7 +835,7 @@ class EnterpriseManager:
return metadata
def create_saml_auth_request(
self, tenant_id: str, config_id: str, relay_state: Optional[str] = None
self, tenant_id: str, config_id: str, relay_state: str | None = None
) -> SAMLAuthRequest:
"""创建 SAML 认证请求"""
conn = self._get_connection()
@@ -881,7 +881,7 @@ class EnterpriseManager:
finally:
conn.close()
def get_saml_auth_request(self, request_id: str) -> Optional[SAMLAuthRequest]:
def get_saml_auth_request(self, request_id: str) -> SAMLAuthRequest | None:
"""获取 SAML 认证请求"""
conn = self._get_connection()
try:
@@ -901,7 +901,7 @@ class EnterpriseManager:
finally:
conn.close()
def process_saml_response(self, request_id: str, saml_response: str) -> Optional[SAMLAuthResponse]:
def process_saml_response(self, request_id: str, saml_response: str) -> SAMLAuthResponse | None:
"""处理 SAML 响应"""
# 这里应该实现实际的 SAML 响应解析
# 简化实现:假设响应已经验证并解析
@@ -954,7 +954,7 @@ class EnterpriseManager:
finally:
conn.close()
def _parse_saml_response(self, saml_response: str) -> Dict[str, Any]:
def _parse_saml_response(self, saml_response: str) -> dict[str, Any]:
"""解析 SAML 响应(简化实现)"""
# 实际应该使用 python-saml 库解析
# 这里返回模拟数据
@@ -974,8 +974,8 @@ class EnterpriseManager:
scim_base_url: str,
scim_token: str,
sync_interval_minutes: int = 60,
attribute_mapping: Optional[Dict[str, str]] = None,
sync_rules: Optional[Dict[str, Any]] = None,
attribute_mapping: dict[str, str] | None = None,
sync_rules: dict[str, Any] | None = None,
) -> SCIMConfig:
"""创建 SCIM 配置"""
conn = self._get_connection()
@@ -1035,7 +1035,7 @@ class EnterpriseManager:
finally:
conn.close()
def get_scim_config(self, config_id: str) -> Optional[SCIMConfig]:
def get_scim_config(self, config_id: str) -> SCIMConfig | None:
"""获取 SCIM 配置"""
conn = self._get_connection()
try:
@@ -1050,7 +1050,7 @@ class EnterpriseManager:
finally:
conn.close()
def get_tenant_scim_config(self, tenant_id: str) -> Optional[SCIMConfig]:
def get_tenant_scim_config(self, tenant_id: str) -> SCIMConfig | None:
"""获取租户的 SCIM 配置"""
conn = self._get_connection()
try:
@@ -1071,7 +1071,7 @@ class EnterpriseManager:
finally:
conn.close()
def update_scim_config(self, config_id: str, **kwargs) -> Optional[SCIMConfig]:
def update_scim_config(self, config_id: str, **kwargs) -> SCIMConfig | None:
"""更新 SCIM 配置"""
conn = self._get_connection()
try:
@@ -1121,7 +1121,7 @@ class EnterpriseManager:
finally:
conn.close()
def sync_scim_users(self, config_id: str) -> Dict[str, Any]:
def sync_scim_users(self, config_id: str) -> dict[str, Any]:
"""执行 SCIM 用户同步"""
config = self.get_scim_config(config_id)
if not config:
@@ -1184,13 +1184,13 @@ class EnterpriseManager:
finally:
conn.close()
def _fetch_scim_users(self, config: SCIMConfig) -> List[Dict[str, Any]]:
def _fetch_scim_users(self, config: SCIMConfig) -> list[dict[str, Any]]:
"""从 SCIM 服务端获取用户(模拟实现)"""
# 实际应该使用 HTTP 请求获取
# GET {scim_base_url}/Users
return []
def _upsert_scim_user(self, conn: sqlite3.Connection, tenant_id: str, user_data: Dict[str, Any]):
def _upsert_scim_user(self, conn: sqlite3.Connection, tenant_id: str, user_data: dict[str, Any]):
"""插入或更新 SCIM 用户"""
cursor = conn.cursor()
@@ -1238,7 +1238,7 @@ class EnterpriseManager:
),
)
def list_scim_users(self, tenant_id: str, active_only: bool = True) -> List[SCIMUser]:
def list_scim_users(self, tenant_id: str, active_only: bool = True) -> list[SCIMUser]:
"""列出 SCIM 用户"""
conn = self._get_connection()
try:
@@ -1269,8 +1269,8 @@ class EnterpriseManager:
start_date: datetime,
end_date: datetime,
created_by: str,
filters: Optional[Dict[str, Any]] = None,
compliance_standard: Optional[str] = None,
filters: dict[str, Any] | None = None,
compliance_standard: str | None = None,
) -> AuditLogExport:
"""创建审计日志导出任务"""
conn = self._get_connection()
@@ -1337,7 +1337,7 @@ class EnterpriseManager:
finally:
conn.close()
def process_audit_export(self, export_id: str, db_manager=None) -> Optional[AuditLogExport]:
def process_audit_export(self, export_id: str, db_manager=None) -> AuditLogExport | None:
"""处理审计日志导出任务"""
export = self.get_audit_export(export_id)
if not export:
@@ -1401,8 +1401,8 @@ class EnterpriseManager:
conn.close()
def _fetch_audit_logs(
self, tenant_id: str, start_date: datetime, end_date: datetime, filters: Dict[str, Any], db_manager=None
) -> List[Dict[str, Any]]:
self, tenant_id: str, start_date: datetime, end_date: datetime, filters: dict[str, Any], db_manager=None
) -> list[dict[str, Any]]:
"""获取审计日志数据"""
if db_manager is None:
return []
@@ -1411,7 +1411,7 @@ class EnterpriseManager:
# 这里简化实现
return []
def _apply_compliance_filter(self, logs: List[Dict[str, Any]], standard: str) -> List[Dict[str, Any]]:
def _apply_compliance_filter(self, logs: list[dict[str, Any]], standard: str) -> list[dict[str, Any]]:
"""应用合规标准字段过滤"""
fields = self.COMPLIANCE_FIELDS.get(ComplianceStandard(standard), [])
@@ -1425,10 +1425,10 @@ class EnterpriseManager:
return filtered_logs
def _generate_export_file(self, export_id: str, logs: List[Dict[str, Any]], format: str) -> Tuple[str, int, str]:
def _generate_export_file(self, export_id: str, logs: list[dict[str, Any]], format: str) -> tuple[str, int, str]:
"""生成导出文件"""
import os
import hashlib
import os
export_dir = "/tmp/insightflow/exports"
os.makedirs(export_dir, exist_ok=True)
@@ -1461,7 +1461,7 @@ class EnterpriseManager:
return file_path, file_size, checksum
def get_audit_export(self, export_id: str) -> Optional[AuditLogExport]:
def get_audit_export(self, export_id: str) -> AuditLogExport | None:
"""获取审计日志导出记录"""
conn = self._get_connection()
try:
@@ -1476,7 +1476,7 @@ class EnterpriseManager:
finally:
conn.close()
def list_audit_exports(self, tenant_id: str, limit: int = 100) -> List[AuditLogExport]:
def list_audit_exports(self, tenant_id: str, limit: int = 100) -> list[AuditLogExport]:
"""列出审计日志导出记录"""
conn = self._get_connection()
try:
@@ -1524,12 +1524,12 @@ class EnterpriseManager:
resource_type: str,
retention_days: int,
action: str,
description: Optional[str] = None,
conditions: Optional[Dict[str, Any]] = None,
description: str | None = None,
conditions: dict[str, Any] | None = None,
auto_execute: bool = False,
execute_at: Optional[str] = None,
execute_at: str | None = None,
notify_before_days: int = 7,
archive_location: Optional[str] = None,
archive_location: str | None = None,
archive_encryption: bool = True,
) -> DataRetentionPolicy:
"""创建数据保留策略"""
@@ -1599,7 +1599,7 @@ class EnterpriseManager:
finally:
conn.close()
def get_retention_policy(self, policy_id: str) -> Optional[DataRetentionPolicy]:
def get_retention_policy(self, policy_id: str) -> DataRetentionPolicy | None:
"""获取数据保留策略"""
conn = self._get_connection()
try:
@@ -1614,7 +1614,7 @@ class EnterpriseManager:
finally:
conn.close()
def list_retention_policies(self, tenant_id: str, resource_type: Optional[str] = None) -> List[DataRetentionPolicy]:
def list_retention_policies(self, tenant_id: str, resource_type: str | None = None) -> list[DataRetentionPolicy]:
"""列出数据保留策略"""
conn = self._get_connection()
try:
@@ -1637,7 +1637,7 @@ class EnterpriseManager:
finally:
conn.close()
def update_retention_policy(self, policy_id: str, **kwargs) -> Optional[DataRetentionPolicy]:
def update_retention_policy(self, policy_id: str, **kwargs) -> DataRetentionPolicy | None:
"""更新数据保留策略"""
conn = self._get_connection()
try:
@@ -1818,7 +1818,7 @@ class EnterpriseManager:
def _retain_audit_logs(
self, conn: sqlite3.Connection, policy: DataRetentionPolicy, cutoff_date: datetime
) -> Dict[str, int]:
) -> dict[str, int]:
"""保留审计日志"""
cursor = conn.cursor()
@@ -1851,19 +1851,19 @@ class EnterpriseManager:
def _retain_projects(
self, conn: sqlite3.Connection, policy: DataRetentionPolicy, cutoff_date: datetime
) -> Dict[str, int]:
) -> dict[str, int]:
"""保留项目数据"""
# 简化实现
return {"affected": 0, "archived": 0, "deleted": 0, "errors": 0}
def _retain_transcripts(
self, conn: sqlite3.Connection, policy: DataRetentionPolicy, cutoff_date: datetime
) -> Dict[str, int]:
) -> dict[str, int]:
"""保留转录数据"""
# 简化实现
return {"affected": 0, "archived": 0, "deleted": 0, "errors": 0}
def get_retention_job(self, job_id: str) -> Optional[DataRetentionJob]:
def get_retention_job(self, job_id: str) -> DataRetentionJob | None:
"""获取数据保留任务"""
conn = self._get_connection()
try:
@@ -1878,7 +1878,7 @@ class EnterpriseManager:
finally:
conn.close()
def list_retention_jobs(self, policy_id: str, limit: int = 100) -> List[DataRetentionJob]:
def list_retention_jobs(self, policy_id: str, limit: int = 100) -> list[DataRetentionJob]:
"""列出数据保留任务"""
conn = self._get_connection()
try:

View File

@@ -4,12 +4,12 @@ Entity Aligner - Phase 3
使用 embedding 进行实体对齐
"""
import os
import json
import os
from dataclasses import dataclass
import httpx
import numpy as np
from typing import List, Optional, Dict
from dataclasses import dataclass
# API Keys
KIMI_API_KEY = os.getenv("KIMI_API_KEY", "")
@@ -21,7 +21,7 @@ class EntityEmbedding:
entity_id: str
name: str
definition: str
embedding: List[float]
embedding: list[float]
class EntityAligner:
@@ -29,9 +29,9 @@ class EntityAligner:
def __init__(self, similarity_threshold: float = 0.85):
self.similarity_threshold = similarity_threshold
self.embedding_cache: Dict[str, List[float]] = {}
self.embedding_cache: dict[str, list[float]] = {}
def get_embedding(self, text: str) -> Optional[List[float]]:
def get_embedding(self, text: str) -> list[float] | None:
"""
使用 Kimi API 获取文本的 embedding
@@ -67,7 +67,7 @@ class EntityAligner:
print(f"Embedding API failed: {e}")
return None
def compute_similarity(self, embedding1: List[float], embedding2: List[float]) -> float:
def compute_similarity(self, embedding1: list[float], embedding2: list[float]) -> float:
"""
计算两个 embedding 的余弦相似度
@@ -111,9 +111,9 @@ class EntityAligner:
project_id: str,
name: str,
definition: str = "",
exclude_id: Optional[str] = None,
threshold: Optional[float] = None,
) -> Optional[object]:
exclude_id: str | None = None,
threshold: float | None = None,
) -> object | None:
"""
查找相似的实体
@@ -175,8 +175,8 @@ class EntityAligner:
return best_match
def _fallback_similarity_match(
self, entities: List[object], name: str, exclude_id: Optional[str] = None
) -> Optional[object]:
self, entities: list[object], name: str, exclude_id: str | None = None
) -> object | None:
"""
回退到简单的相似度匹配(不使用 embedding
@@ -209,8 +209,8 @@ class EntityAligner:
return None
def batch_align_entities(
self, project_id: str, new_entities: List[Dict], threshold: Optional[float] = None
) -> List[Dict]:
self, project_id: str, new_entities: list[dict], threshold: float | None = None
) -> list[dict]:
"""
批量对齐实体
@@ -257,7 +257,7 @@ class EntityAligner:
return results
def suggest_entity_aliases(self, entity_name: str, entity_definition: str = "") -> List[str]:
def suggest_entity_aliases(self, entity_name: str, entity_definition: str = "") -> list[str]:
"""
使用 LLM 建议实体的别名

View File

@@ -3,12 +3,12 @@ InsightFlow Export Module - Phase 5
支持导出知识图谱、项目报告、实体数据和转录文本
"""
import base64
import io
import json
import base64
from datetime import datetime
from typing import List, Dict, Any
from dataclasses import dataclass
from datetime import datetime
from typing import Any
try:
import pandas as pd
@@ -20,9 +20,16 @@ except ImportError:
try:
from reportlab.lib import colors
from reportlab.lib.pagesizes import A4
from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle
from reportlab.lib.styles import ParagraphStyle, getSampleStyleSheet
from reportlab.lib.units import inch
from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer, Table, TableStyle, PageBreak
from reportlab.platypus import (
PageBreak,
Paragraph,
SimpleDocTemplate,
Spacer,
Table,
TableStyle,
)
REPORTLAB_AVAILABLE = True
except ImportError:
@@ -35,9 +42,9 @@ class ExportEntity:
name: str
type: str
definition: str
aliases: List[str]
aliases: list[str]
mention_count: int
attributes: Dict[str, Any]
attributes: dict[str, Any]
@dataclass
@@ -56,8 +63,8 @@ class ExportTranscript:
name: str
type: str # audio/document
content: str
segments: List[Dict]
entity_mentions: List[Dict]
segments: list[dict]
entity_mentions: list[dict]
class ExportManager:
@@ -67,7 +74,7 @@ class ExportManager:
self.db = db_manager
def export_knowledge_graph_svg(
self, project_id: str, entities: List[ExportEntity], relations: List[ExportRelation]
self, project_id: str, entities: list[ExportEntity], relations: list[ExportRelation]
) -> str:
"""
导出知识图谱为 SVG 格式
@@ -207,7 +214,7 @@ class ExportManager:
return "\n".join(svg_parts)
def export_knowledge_graph_png(
self, project_id: str, entities: List[ExportEntity], relations: List[ExportRelation]
self, project_id: str, entities: list[ExportEntity], relations: list[ExportRelation]
) -> bytes:
"""
导出知识图谱为 PNG 格式
@@ -226,7 +233,7 @@ class ExportManager:
svg_content = self.export_knowledge_graph_svg(project_id, entities, relations)
return base64.b64encode(svg_content.encode("utf-8"))
def export_entities_excel(self, entities: List[ExportEntity]) -> bytes:
def export_entities_excel(self, entities: list[ExportEntity]) -> bytes:
"""
导出实体数据为 Excel 格式
@@ -275,7 +282,7 @@ class ExportManager:
return output.getvalue()
def export_entities_csv(self, entities: List[ExportEntity]) -> str:
def export_entities_csv(self, entities: list[ExportEntity]) -> str:
"""
导出实体数据为 CSV 格式
@@ -306,7 +313,7 @@ class ExportManager:
return output.getvalue()
def export_relations_csv(self, relations: List[ExportRelation]) -> str:
def export_relations_csv(self, relations: list[ExportRelation]) -> str:
"""
导出关系数据为 CSV 格式
@@ -324,7 +331,7 @@ class ExportManager:
return output.getvalue()
def export_transcript_markdown(self, transcript: ExportTranscript, entities_map: Dict[str, ExportEntity]) -> str:
def export_transcript_markdown(self, transcript: ExportTranscript, entities_map: dict[str, ExportEntity]) -> str:
"""
导出转录文本为 Markdown 格式
@@ -387,9 +394,9 @@ class ExportManager:
self,
project_id: str,
project_name: str,
entities: List[ExportEntity],
relations: List[ExportRelation],
transcripts: List[ExportTranscript],
entities: list[ExportEntity],
relations: list[ExportRelation],
transcripts: list[ExportTranscript],
summary: str = "",
) -> bytes:
"""
@@ -529,9 +536,9 @@ class ExportManager:
self,
project_id: str,
project_name: str,
entities: List[ExportEntity],
relations: List[ExportRelation],
transcripts: List[ExportTranscript],
entities: list[ExportEntity],
relations: list[ExportRelation],
transcripts: list[ExportTranscript],
) -> str:
"""
导出完整项目数据为 JSON 格式

View File

@@ -10,25 +10,26 @@ InsightFlow Growth Manager - Phase 8 Task 5
作者: InsightFlow Team
"""
import os
import json
import sqlite3
import httpx
import asyncio
import hashlib
import json
import os
import random
from typing import List, Dict, Optional, Any, Tuple
import re
import sqlite3
import uuid
from dataclasses import dataclass
from datetime import datetime, timedelta
from enum import Enum
import hashlib
import uuid
import re
from enum import StrEnum
from typing import Any
import httpx
# Database path
DB_PATH = os.path.join(os.path.dirname(__file__), "insightflow.db")
class EventType(str, Enum):
class EventType(StrEnum):
"""事件类型"""
PAGE_VIEW = "page_view" # 页面浏览
@@ -44,7 +45,7 @@ class EventType(str, Enum):
REFERRAL_REWARD = "referral_reward" # 推荐奖励
class ExperimentStatus(str, Enum):
class ExperimentStatus(StrEnum):
"""实验状态"""
DRAFT = "draft" # 草稿
@@ -54,7 +55,7 @@ class ExperimentStatus(str, Enum):
ARCHIVED = "archived" # 已归档
class TrafficAllocationType(str, Enum):
class TrafficAllocationType(StrEnum):
"""流量分配类型"""
RANDOM = "random" # 随机分配
@@ -62,7 +63,7 @@ class TrafficAllocationType(str, Enum):
TARGETED = "targeted" # 定向分配
class EmailTemplateType(str, Enum):
class EmailTemplateType(StrEnum):
"""邮件模板类型"""
WELCOME = "welcome" # 欢迎邮件
@@ -74,7 +75,7 @@ class EmailTemplateType(str, Enum):
NEWSLETTER = "newsletter" # 新闻通讯
class EmailStatus(str, Enum):
class EmailStatus(StrEnum):
"""邮件状态"""
DRAFT = "draft" # 草稿
@@ -88,7 +89,7 @@ class EmailStatus(str, Enum):
FAILED = "failed" # 失败
class WorkflowTriggerType(str, Enum):
class WorkflowTriggerType(StrEnum):
"""工作流触发类型"""
USER_SIGNUP = "user_signup" # 用户注册
@@ -100,7 +101,7 @@ class WorkflowTriggerType(str, Enum):
CUSTOM_EVENT = "custom_event" # 自定义事件
class ReferralStatus(str, Enum):
class ReferralStatus(StrEnum):
"""推荐状态"""
PENDING = "pending" # 待处理
@@ -118,14 +119,14 @@ class AnalyticsEvent:
user_id: str
event_type: EventType
event_name: str
properties: Dict[str, Any] # 事件属性
properties: dict[str, Any] # 事件属性
timestamp: datetime
session_id: Optional[str]
device_info: Dict[str, str] # 设备信息
referrer: Optional[str]
utm_source: Optional[str]
utm_medium: Optional[str]
utm_campaign: Optional[str]
session_id: str | None
device_info: dict[str, str] # 设备信息
referrer: str | None
utm_source: str | None
utm_medium: str | None
utm_campaign: str | None
@dataclass
@@ -139,8 +140,8 @@ class UserProfile:
last_seen: datetime
total_sessions: int
total_events: int
feature_usage: Dict[str, int] # 功能使用次数
subscription_history: List[Dict]
feature_usage: dict[str, int] # 功能使用次数
subscription_history: list[dict]
ltv: float # 生命周期价值
churn_risk_score: float # 流失风险分数
engagement_score: float # 参与度分数
@@ -156,7 +157,7 @@ class Funnel:
tenant_id: str
name: str
description: str
steps: List[Dict] # 漏斗步骤
steps: list[dict] # 漏斗步骤
created_at: datetime
updated_at: datetime
@@ -169,9 +170,9 @@ class FunnelAnalysis:
period_start: datetime
period_end: datetime
total_users: int
step_conversions: List[Dict] # 每步转化数据
step_conversions: list[dict] # 每步转化数据
overall_conversion: float # 总体转化率
drop_off_points: List[Dict] # 流失点
drop_off_points: list[dict] # 流失点
@dataclass
@@ -184,14 +185,14 @@ class Experiment:
description: str
hypothesis: str
status: ExperimentStatus
variants: List[Dict] # 实验变体
variants: list[dict] # 实验变体
traffic_allocation: TrafficAllocationType
traffic_split: Dict[str, float] # 流量分配比例
target_audience: Dict # 目标受众
traffic_split: dict[str, float] # 流量分配比例
target_audience: dict # 目标受众
primary_metric: str # 主要指标
secondary_metrics: List[str] # 次要指标
start_date: Optional[datetime]
end_date: Optional[datetime]
secondary_metrics: list[str] # 次要指标
start_date: datetime | None
end_date: datetime | None
min_sample_size: int # 最小样本量
confidence_level: float # 置信水平
created_at: datetime
@@ -210,7 +211,7 @@ class ExperimentResult:
sample_size: int
mean_value: float
std_dev: float
confidence_interval: Tuple[float, float]
confidence_interval: tuple[float, float]
p_value: float
is_significant: bool
uplift: float # 提升幅度
@@ -228,11 +229,11 @@ class EmailTemplate:
subject: str
html_content: str
text_content: str
variables: List[str] # 模板变量
preview_text: Optional[str]
variables: list[str] # 模板变量
preview_text: str | None
from_name: str
from_email: str
reply_to: Optional[str]
reply_to: str | None
is_active: bool
created_at: datetime
updated_at: datetime
@@ -254,9 +255,9 @@ class EmailCampaign:
clicked_count: int
bounced_count: int
failed_count: int
scheduled_at: Optional[datetime]
started_at: Optional[datetime]
completed_at: Optional[datetime]
scheduled_at: datetime | None
started_at: datetime | None
completed_at: datetime | None
created_at: datetime
@@ -272,13 +273,13 @@ class EmailLog:
template_id: str
status: EmailStatus
subject: str
sent_at: Optional[datetime]
delivered_at: Optional[datetime]
opened_at: Optional[datetime]
clicked_at: Optional[datetime]
ip_address: Optional[str]
user_agent: Optional[str]
error_message: Optional[str]
sent_at: datetime | None
delivered_at: datetime | None
opened_at: datetime | None
clicked_at: datetime | None
ip_address: str | None
user_agent: str | None
error_message: str | None
created_at: datetime
@@ -291,8 +292,8 @@ class AutomationWorkflow:
name: str
description: str
trigger_type: WorkflowTriggerType
trigger_conditions: Dict # 触发条件
actions: List[Dict] # 执行动作
trigger_conditions: dict # 触发条件
actions: list[dict] # 执行动作
is_active: bool
execution_count: int
created_at: datetime
@@ -327,15 +328,15 @@ class Referral:
program_id: str
tenant_id: str
referrer_id: str # 推荐人
referee_id: Optional[str] # 被推荐人
referee_id: str | None # 被推荐人
referral_code: str
status: ReferralStatus
referrer_rewarded: bool
referee_rewarded: bool
referrer_reward_value: float
referee_reward_value: float
converted_at: Optional[datetime]
rewarded_at: Optional[datetime]
converted_at: datetime | None
rewarded_at: datetime | None
expires_at: datetime
created_at: datetime
@@ -382,11 +383,11 @@ class GrowthManager:
user_id: str,
event_type: EventType,
event_name: str,
properties: Dict = None,
properties: dict = None,
session_id: str = None,
device_info: Dict = None,
device_info: dict = None,
referrer: str = None,
utm_params: Dict = None,
utm_params: dict = None,
) -> AnalyticsEvent:
"""追踪事件"""
event_id = f"evt_{uuid.uuid4().hex[:16]}"
@@ -554,7 +555,7 @@ class GrowthManager:
conn.commit()
def get_user_profile(self, tenant_id: str, user_id: str) -> Optional[UserProfile]:
def get_user_profile(self, tenant_id: str, user_id: str) -> UserProfile | None:
"""获取用户画像"""
with self._get_db() as conn:
row = conn.execute(
@@ -567,7 +568,7 @@ class GrowthManager:
def get_user_analytics_summary(
self, tenant_id: str, start_date: datetime = None, end_date: datetime = None
) -> Dict:
) -> dict:
"""获取用户分析汇总"""
with self._get_db() as conn:
query = """
@@ -619,7 +620,7 @@ class GrowthManager:
"event_type_distribution": {r["event_type"]: r["count"] for r in type_rows},
}
def create_funnel(self, tenant_id: str, name: str, description: str, steps: List[Dict], created_by: str) -> Funnel:
def create_funnel(self, tenant_id: str, name: str, description: str, steps: list[dict], created_by: str) -> Funnel:
"""创建转化漏斗"""
funnel_id = f"fnl_{uuid.uuid4().hex[:16]}"
now = datetime.now().isoformat()
@@ -657,7 +658,7 @@ class GrowthManager:
def analyze_funnel(
self, funnel_id: str, period_start: datetime = None, period_end: datetime = None
) -> Optional[FunnelAnalysis]:
) -> FunnelAnalysis | None:
"""分析漏斗转化率"""
with self._get_db() as conn:
funnel_row = conn.execute("SELECT * FROM funnels WHERE id = ?", (funnel_id,)).fetchone()
@@ -728,7 +729,7 @@ class GrowthManager:
drop_off_points=drop_off_points,
)
def calculate_retention(self, tenant_id: str, cohort_date: datetime, periods: List[int] = None) -> Dict:
def calculate_retention(self, tenant_id: str, cohort_date: datetime, periods: list[int] = None) -> dict:
"""计算留存率"""
if periods is None:
periods = [1, 3, 7, 14, 30]
@@ -787,12 +788,12 @@ class GrowthManager:
name: str,
description: str,
hypothesis: str,
variants: List[Dict],
variants: list[dict],
traffic_allocation: TrafficAllocationType,
traffic_split: Dict[str, float],
target_audience: Dict,
traffic_split: dict[str, float],
target_audience: dict,
primary_metric: str,
secondary_metrics: List[str],
secondary_metrics: list[str],
min_sample_size: int = 100,
confidence_level: float = 0.95,
created_by: str = None,
@@ -859,7 +860,7 @@ class GrowthManager:
return experiment
def get_experiment(self, experiment_id: str) -> Optional[Experiment]:
def get_experiment(self, experiment_id: str) -> Experiment | None:
"""获取实验详情"""
with self._get_db() as conn:
row = conn.execute("SELECT * FROM experiments WHERE id = ?", (experiment_id,)).fetchone()
@@ -868,7 +869,7 @@ class GrowthManager:
return self._row_to_experiment(row)
return None
def list_experiments(self, tenant_id: str, status: ExperimentStatus = None) -> List[Experiment]:
def list_experiments(self, tenant_id: str, status: ExperimentStatus = None) -> list[Experiment]:
"""列出实验"""
query = "SELECT * FROM experiments WHERE tenant_id = ?"
params = [tenant_id]
@@ -883,7 +884,7 @@ class GrowthManager:
rows = conn.execute(query, params).fetchall()
return [self._row_to_experiment(row) for row in rows]
def assign_variant(self, experiment_id: str, user_id: str, user_attributes: Dict = None) -> Optional[str]:
def assign_variant(self, experiment_id: str, user_id: str, user_attributes: dict = None) -> str | None:
"""为用户分配实验变体"""
experiment = self.get_experiment(experiment_id)
if not experiment or experiment.status != ExperimentStatus.RUNNING:
@@ -929,7 +930,7 @@ class GrowthManager:
return variant_id
def _random_allocation(self, variants: List[Dict], traffic_split: Dict[str, float]) -> str:
def _random_allocation(self, variants: list[dict], traffic_split: dict[str, float]) -> str:
"""随机分配"""
variant_ids = [v["id"] for v in variants]
weights = [traffic_split.get(v_id, 1.0 / len(variants)) for v_id in variant_ids]
@@ -940,7 +941,7 @@ class GrowthManager:
return random.choices(variant_ids, weights=normalized_weights, k=1)[0]
def _stratified_allocation(
self, variants: List[Dict], traffic_split: Dict[str, float], user_attributes: Dict
self, variants: list[dict], traffic_split: dict[str, float], user_attributes: dict
) -> str:
"""分层分配(基于用户属性)"""
# 简化的分层分配:根据用户 ID 哈希值分配
@@ -952,7 +953,7 @@ class GrowthManager:
return self._random_allocation(variants, traffic_split)
def _targeted_allocation(self, variants: List[Dict], target_audience: Dict, user_attributes: Dict) -> Optional[str]:
def _targeted_allocation(self, variants: list[dict], target_audience: dict, user_attributes: dict) -> str | None:
"""定向分配(基于目标受众条件)"""
# 检查用户是否符合目标受众条件
conditions = target_audience.get("conditions", [])
@@ -1005,7 +1006,7 @@ class GrowthManager:
)
conn.commit()
def analyze_experiment(self, experiment_id: str) -> Dict:
def analyze_experiment(self, experiment_id: str) -> dict:
"""分析实验结果"""
experiment = self.get_experiment(experiment_id)
if not experiment:
@@ -1083,7 +1084,7 @@ class GrowthManager:
"variant_results": results,
}
def start_experiment(self, experiment_id: str) -> Optional[Experiment]:
def start_experiment(self, experiment_id: str) -> Experiment | None:
"""启动实验"""
with self._get_db() as conn:
now = datetime.now().isoformat()
@@ -1099,7 +1100,7 @@ class GrowthManager:
return self.get_experiment(experiment_id)
def stop_experiment(self, experiment_id: str) -> Optional[Experiment]:
def stop_experiment(self, experiment_id: str) -> Experiment | None:
"""停止实验"""
with self._get_db() as conn:
now = datetime.now().isoformat()
@@ -1125,7 +1126,7 @@ class GrowthManager:
subject: str,
html_content: str,
text_content: str = None,
variables: List[str] = None,
variables: list[str] = None,
from_name: str = None,
from_email: str = None,
reply_to: str = None,
@@ -1185,7 +1186,7 @@ class GrowthManager:
return template
def get_email_template(self, template_id: str) -> Optional[EmailTemplate]:
def get_email_template(self, template_id: str) -> EmailTemplate | None:
"""获取邮件模板"""
with self._get_db() as conn:
row = conn.execute("SELECT * FROM email_templates WHERE id = ?", (template_id,)).fetchone()
@@ -1194,7 +1195,7 @@ class GrowthManager:
return self._row_to_email_template(row)
return None
def list_email_templates(self, tenant_id: str, template_type: EmailTemplateType = None) -> List[EmailTemplate]:
def list_email_templates(self, tenant_id: str, template_type: EmailTemplateType = None) -> list[EmailTemplate]:
"""列出邮件模板"""
query = "SELECT * FROM email_templates WHERE tenant_id = ? AND is_active = 1"
params = [tenant_id]
@@ -1209,7 +1210,7 @@ class GrowthManager:
rows = conn.execute(query, params).fetchall()
return [self._row_to_email_template(row) for row in rows]
def render_template(self, template_id: str, variables: Dict) -> Dict[str, str]:
def render_template(self, template_id: str, variables: dict) -> dict[str, str]:
"""渲染邮件模板"""
template = self.get_email_template(template_id)
if not template:
@@ -1235,7 +1236,7 @@ class GrowthManager:
}
def create_email_campaign(
self, tenant_id: str, name: str, template_id: str, recipient_list: List[Dict], scheduled_at: datetime = None
self, tenant_id: str, name: str, template_id: str, recipient_list: list[dict], scheduled_at: datetime = None
) -> EmailCampaign:
"""创建邮件营销活动"""
campaign_id = f"ec_{uuid.uuid4().hex[:16]}"
@@ -1314,7 +1315,7 @@ class GrowthManager:
return campaign
async def send_email(self, campaign_id: str, user_id: str, email: str, template_id: str, variables: Dict) -> bool:
async def send_email(self, campaign_id: str, user_id: str, email: str, template_id: str, variables: dict) -> bool:
"""发送单封邮件"""
template = self.get_email_template(template_id)
if not template:
@@ -1380,7 +1381,7 @@ class GrowthManager:
conn.commit()
return False
async def send_campaign(self, campaign_id: str) -> Dict:
async def send_campaign(self, campaign_id: str) -> dict:
"""发送整个营销活动"""
with self._get_db() as conn:
campaign_row = conn.execute("SELECT * FROM email_campaigns WHERE id = ?", (campaign_id,)).fetchone()
@@ -1432,7 +1433,7 @@ class GrowthManager:
return {"campaign_id": campaign_id, "total": len(logs), "success": success_count, "failed": failed_count}
def _get_user_variables(self, tenant_id: str, user_id: str) -> Dict:
def _get_user_variables(self, tenant_id: str, user_id: str) -> dict:
"""获取用户变量用于邮件模板"""
# 这里应该从用户服务获取用户信息
# 简化实现
@@ -1444,8 +1445,8 @@ class GrowthManager:
name: str,
description: str,
trigger_type: WorkflowTriggerType,
trigger_conditions: Dict,
actions: List[Dict],
trigger_conditions: dict,
actions: list[dict],
) -> AutomationWorkflow:
"""创建自动化工作流"""
workflow_id = f"aw_{uuid.uuid4().hex[:16]}"
@@ -1491,7 +1492,7 @@ class GrowthManager:
return workflow
async def trigger_workflow(self, workflow_id: str, event_data: Dict):
async def trigger_workflow(self, workflow_id: str, event_data: dict):
"""触发自动化工作流"""
with self._get_db() as conn:
row = conn.execute(
@@ -1519,7 +1520,7 @@ class GrowthManager:
return True
def _check_trigger_conditions(self, conditions: Dict, event_data: Dict) -> bool:
def _check_trigger_conditions(self, conditions: dict, event_data: dict) -> bool:
"""检查触发条件"""
# 简化的条件检查
for key, value in conditions.items():
@@ -1527,7 +1528,7 @@ class GrowthManager:
return False
return True
async def _execute_action(self, action: Dict, event_data: Dict):
async def _execute_action(self, action: dict, event_data: dict):
"""执行工作流动作"""
action_type = action.get("type")
@@ -1691,7 +1692,7 @@ class GrowthManager:
if not row:
return code
def _get_referral_program(self, program_id: str) -> Optional[ReferralProgram]:
def _get_referral_program(self, program_id: str) -> ReferralProgram | None:
"""获取推荐计划"""
with self._get_db() as conn:
row = conn.execute("SELECT * FROM referral_programs WHERE id = ?", (program_id,)).fetchone()
@@ -1746,7 +1747,7 @@ class GrowthManager:
return True
def get_referral_stats(self, program_id: str) -> Dict:
def get_referral_stats(self, program_id: str) -> dict:
"""获取推荐统计"""
with self._get_db() as conn:
stats = conn.execute(
@@ -1841,7 +1842,7 @@ class GrowthManager:
def check_team_incentive_eligibility(
self, tenant_id: str, current_tier: str, team_size: int
) -> List[TeamIncentive]:
) -> list[TeamIncentive]:
"""检查团队激励资格"""
with self._get_db() as conn:
now = datetime.now().isoformat()
@@ -1859,7 +1860,7 @@ class GrowthManager:
# ==================== 实时分析仪表板 ====================
def get_realtime_dashboard(self, tenant_id: str) -> Dict:
def get_realtime_dashboard(self, tenant_id: str) -> dict:
"""获取实时分析仪表板数据"""
now = datetime.now()
today_start = now.replace(hour=0, minute=0, second=0, microsecond=0)

View File

@@ -4,11 +4,10 @@ InsightFlow Image Processor - Phase 7
图片处理模块识别白板、PPT、手写笔记等内容
"""
import os
import io
import uuid
import base64
from typing import List, Optional, Tuple
import io
import os
import uuid
from dataclasses import dataclass
# 尝试导入图像处理库
@@ -42,7 +41,7 @@ class ImageEntity:
name: str
type: str
confidence: float
bbox: Optional[Tuple[int, int, int, int]] = None # (x, y, width, height)
bbox: tuple[int, int, int, int] | None = None # (x, y, width, height)
@dataclass
@@ -63,8 +62,8 @@ class ImageProcessingResult:
image_type: str # whiteboard, ppt, handwritten, screenshot, other
ocr_text: str
description: str
entities: List[ImageEntity]
relations: List[ImageRelation]
entities: list[ImageEntity]
relations: list[ImageRelation]
width: int
height: int
success: bool
@@ -75,7 +74,7 @@ class ImageProcessingResult:
class BatchProcessingResult:
"""批量图片处理结果"""
results: List[ImageProcessingResult]
results: list[ImageProcessingResult]
total_count: int
success_count: int
failed_count: int
@@ -231,7 +230,7 @@ class ImageProcessor:
print(f"Image type detection error: {e}")
return "other"
def perform_ocr(self, image, lang: str = "chi_sim+eng") -> Tuple[str, float]:
def perform_ocr(self, image, lang: str = "chi_sim+eng") -> tuple[str, float]:
"""
对图片进行OCR识别
@@ -262,7 +261,7 @@ class ImageProcessor:
print(f"OCR error: {e}")
return "", 0.0
def extract_entities_from_text(self, text: str) -> List[ImageEntity]:
def extract_entities_from_text(self, text: str) -> list[ImageEntity]:
"""
从OCR文本中提取实体
@@ -322,7 +321,7 @@ class ImageProcessor:
return unique_entities
def generate_description(self, image_type: str, ocr_text: str, entities: List[ImageEntity]) -> str:
def generate_description(self, image_type: str, ocr_text: str, entities: list[ImageEntity]) -> str:
"""
生成图片描述
@@ -435,7 +434,7 @@ class ImageProcessor:
error_message=str(e),
)
def _extract_relations(self, entities: List[ImageEntity], text: str) -> List[ImageRelation]:
def _extract_relations(self, entities: list[ImageEntity], text: str) -> list[ImageRelation]:
"""
从文本中提取实体关系
@@ -475,7 +474,7 @@ class ImageProcessor:
return relations
def process_batch(self, images_data: List[Tuple[bytes, str]], project_id: str = None) -> BatchProcessingResult:
def process_batch(self, images_data: list[tuple[bytes, str]], project_id: str = None) -> BatchProcessingResult:
"""
批量处理图片
@@ -515,7 +514,7 @@ class ImageProcessor:
"""
return base64.b64encode(image_data).decode("utf-8")
def get_image_thumbnail(self, image_data: bytes, size: Tuple[int, int] = (200, 200)) -> bytes:
def get_image_thumbnail(self, image_data: bytes, size: tuple[int, int] = (200, 200)) -> bytes:
"""
生成图片缩略图

View File

@@ -1,8 +1,8 @@
#!/usr/bin/env python3
"""Initialize database with schema"""
import sqlite3
import os
import sqlite3
db_path = os.path.join(os.path.dirname(__file__), "insightflow.db")
schema_path = os.path.join(os.path.dirname(__file__), "schema.sql")
@@ -11,7 +11,7 @@ print(f"Database path: {db_path}")
print(f"Schema path: {schema_path}")
# Read schema
with open(schema_path, "r") as f:
with open(schema_path) as f:
schema = f.read()
# Execute schema

View File

@@ -4,13 +4,13 @@ InsightFlow Knowledge Reasoning - Phase 5
知识推理与问答增强模块
"""
import os
import json
import httpx
from typing import List, Dict
import os
from dataclasses import dataclass
from enum import Enum
import httpx
KIMI_API_KEY = os.getenv("KIMI_API_KEY", "")
KIMI_BASE_URL = os.getenv("KIMI_BASE_URL", "https://api.kimi.com/coding")
@@ -32,9 +32,9 @@ class ReasoningResult:
answer: str
reasoning_type: ReasoningType
confidence: float
evidence: List[Dict] # 支撑证据
related_entities: List[str] # 相关实体
gaps: List[str] # 知识缺口
evidence: list[dict] # 支撑证据
related_entities: list[str] # 相关实体
gaps: list[str] # 知识缺口
@dataclass
@@ -43,7 +43,7 @@ class InferencePath:
start_entity: str
end_entity: str
path: List[Dict] # 路径上的节点和关系
path: list[dict] # 路径上的节点和关系
strength: float # 路径强度
@@ -71,7 +71,7 @@ class KnowledgeReasoner:
return result["choices"][0]["message"]["content"]
async def enhanced_qa(
self, query: str, project_context: Dict, graph_data: Dict, reasoning_depth: str = "medium"
self, query: str, project_context: dict, graph_data: dict, reasoning_depth: str = "medium"
) -> ReasoningResult:
"""
增强问答 - 结合图谱推理的问答
@@ -95,7 +95,7 @@ class KnowledgeReasoner:
else:
return await self._associative_reasoning(query, project_context, graph_data)
async def _analyze_question(self, query: str) -> Dict:
async def _analyze_question(self, query: str) -> dict:
"""分析问题类型和意图"""
prompt = f"""分析以下问题的类型和意图:
@@ -129,7 +129,7 @@ class KnowledgeReasoner:
return {"type": "factual", "entities": [], "intent": "general", "complexity": "simple"}
async def _causal_reasoning(self, query: str, project_context: Dict, graph_data: Dict) -> ReasoningResult:
async def _causal_reasoning(self, query: str, project_context: dict, graph_data: dict) -> ReasoningResult:
"""因果推理 - 分析原因和影响"""
# 构建因果分析提示
@@ -190,7 +190,7 @@ class KnowledgeReasoner:
gaps=["无法完成因果推理"],
)
async def _comparative_reasoning(self, query: str, project_context: Dict, graph_data: Dict) -> ReasoningResult:
async def _comparative_reasoning(self, query: str, project_context: dict, graph_data: dict) -> ReasoningResult:
"""对比推理 - 比较实体间的异同"""
prompt = f"""基于以下知识图谱进行对比分析:
@@ -244,7 +244,7 @@ class KnowledgeReasoner:
gaps=[],
)
async def _temporal_reasoning(self, query: str, project_context: Dict, graph_data: Dict) -> ReasoningResult:
async def _temporal_reasoning(self, query: str, project_context: dict, graph_data: dict) -> ReasoningResult:
"""时序推理 - 分析时间线和演变"""
prompt = f"""基于以下知识图谱进行时序分析:
@@ -298,7 +298,7 @@ class KnowledgeReasoner:
gaps=[],
)
async def _associative_reasoning(self, query: str, project_context: Dict, graph_data: Dict) -> ReasoningResult:
async def _associative_reasoning(self, query: str, project_context: dict, graph_data: dict) -> ReasoningResult:
"""关联推理 - 发现实体间的隐含关联"""
prompt = f"""基于以下知识图谱进行关联分析:
@@ -353,8 +353,8 @@ class KnowledgeReasoner:
)
def find_inference_paths(
self, start_entity: str, end_entity: str, graph_data: Dict, max_depth: int = 3
) -> List[InferencePath]:
self, start_entity: str, end_entity: str, graph_data: dict, max_depth: int = 3
) -> list[InferencePath]:
"""
发现两个实体之间的推理路径
@@ -416,7 +416,7 @@ class KnowledgeReasoner:
paths.sort(key=lambda p: p.strength, reverse=True)
return paths
def _calculate_path_strength(self, path: List[Dict]) -> float:
def _calculate_path_strength(self, path: list[dict]) -> float:
"""计算路径强度"""
if len(path) < 2:
return 0.0
@@ -438,8 +438,8 @@ class KnowledgeReasoner:
return length_factor * confidence_factor
async def summarize_project(
self, project_context: Dict, graph_data: Dict, summary_type: str = "comprehensive"
) -> Dict:
self, project_context: dict, graph_data: dict, summary_type: str = "comprehensive"
) -> dict:
"""
项目智能总结

View File

@@ -4,12 +4,13 @@ InsightFlow LLM Client - Phase 4
用于与 Kimi API 交互,支持 RAG 问答和 Agent 功能
"""
import os
import json
import httpx
from typing import List, Dict, AsyncGenerator
import os
from collections.abc import AsyncGenerator
from dataclasses import dataclass
import httpx
KIMI_API_KEY = os.getenv("KIMI_API_KEY", "")
KIMI_BASE_URL = os.getenv("KIMI_BASE_URL", "https://api.kimi.com/coding")
@@ -44,7 +45,7 @@ class LLMClient:
self.base_url = base_url or KIMI_BASE_URL
self.headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
async def chat(self, messages: List[ChatMessage], temperature: float = 0.3, stream: bool = False) -> str:
async def chat(self, messages: list[ChatMessage], temperature: float = 0.3, stream: bool = False) -> str:
"""发送聊天请求"""
if not self.api_key:
raise ValueError("KIMI_API_KEY not set")
@@ -64,7 +65,7 @@ class LLMClient:
result = response.json()
return result["choices"][0]["message"]["content"]
async def chat_stream(self, messages: List[ChatMessage], temperature: float = 0.3) -> AsyncGenerator[str, None]:
async def chat_stream(self, messages: list[ChatMessage], temperature: float = 0.3) -> AsyncGenerator[str, None]:
"""流式聊天请求"""
if not self.api_key:
raise ValueError("KIMI_API_KEY not set")
@@ -96,7 +97,7 @@ class LLMClient:
async def extract_entities_with_confidence(
self, text: str
) -> tuple[List[EntityExtractionResult], List[RelationExtractionResult]]:
) -> tuple[list[EntityExtractionResult], list[RelationExtractionResult]]:
"""提取实体和关系,带置信度分数"""
prompt = f"""从以下会议文本中提取关键实体和它们之间的关系,以 JSON 格式返回:
@@ -152,7 +153,7 @@ class LLMClient:
print(f"Parse extraction result failed: {e}")
return [], []
async def rag_query(self, query: str, context: str, project_context: Dict) -> str:
async def rag_query(self, query: str, context: str, project_context: dict) -> str:
"""RAG 问答 - 基于项目上下文回答问题"""
prompt = f"""你是一个专业的项目分析助手。基于以下项目信息回答问题:
@@ -174,7 +175,7 @@ class LLMClient:
return await self.chat(messages, temperature=0.3)
async def agent_command(self, command: str, project_context: Dict) -> Dict:
async def agent_command(self, command: str, project_context: dict) -> dict:
"""Agent 指令解析 - 将自然语言指令转换为结构化操作"""
prompt = f"""解析以下用户指令,转换为结构化操作:
@@ -214,7 +215,7 @@ class LLMClient:
except BaseException:
return {"intent": "unknown", "explanation": "解析失败"}
async def analyze_entity_evolution(self, entity_name: str, mentions: List[Dict]) -> str:
async def analyze_entity_evolution(self, entity_name: str, mentions: list[dict]) -> str:
"""分析实体在项目中的演变/态度变化"""
mentions_text = "\n".join(
[f"[{m.get('created_at', '未知时间')}] {m.get('text_snippet', '')}" for m in mentions[:20]] # 限制数量

View File

@@ -10,14 +10,14 @@ InsightFlow Phase 8 - 全球化与本地化管理模块
作者: InsightFlow Team
"""
import sqlite3
import json
import uuid
from datetime import datetime
from typing import Optional, List, Dict, Any
from dataclasses import dataclass
from enum import Enum
import logging
import sqlite3
import uuid
from dataclasses import dataclass
from datetime import datetime
from enum import StrEnum
from typing import Any
try:
import pytz
@@ -36,7 +36,7 @@ except ImportError:
logger = logging.getLogger(__name__)
class LanguageCode(str, Enum):
class LanguageCode(StrEnum):
"""支持的语言代码"""
EN = "en"
@@ -53,7 +53,7 @@ class LanguageCode(str, Enum):
HI = "hi"
class RegionCode(str, Enum):
class RegionCode(StrEnum):
"""区域代码"""
GLOBAL = "global"
@@ -65,7 +65,7 @@ class RegionCode(str, Enum):
MIDDLE_EAST = "me"
class DataCenterRegion(str, Enum):
class DataCenterRegion(StrEnum):
"""数据中心区域"""
US_EAST = "us-east"
@@ -79,7 +79,7 @@ class DataCenterRegion(str, Enum):
CN_EAST = "cn-east"
class PaymentProvider(str, Enum):
class PaymentProvider(StrEnum):
"""支付提供商"""
STRIPE = "stripe"
@@ -96,7 +96,7 @@ class PaymentProvider(str, Enum):
UNIONPAY = "unionpay"
class CalendarType(str, Enum):
class CalendarType(StrEnum):
"""日历类型"""
GREGORIAN = "gregorian"
@@ -115,12 +115,12 @@ class Translation:
language: str
value: str
namespace: str
context: Optional[str]
context: str | None
created_at: datetime
updated_at: datetime
is_reviewed: bool
reviewed_by: Optional[str]
reviewed_at: Optional[datetime]
reviewed_by: str | None
reviewed_at: datetime | None
@dataclass
@@ -131,7 +131,7 @@ class LanguageConfig:
is_rtl: bool
is_active: bool
is_default: bool
fallback_language: Optional[str]
fallback_language: str | None
date_format: str
time_format: str
datetime_format: str
@@ -150,8 +150,8 @@ class DataCenter:
endpoint: str
status: str
priority: int
supported_regions: List[str]
capabilities: Dict[str, Any]
supported_regions: list[str]
capabilities: dict[str, Any]
created_at: datetime
updated_at: datetime
@@ -161,7 +161,7 @@ class TenantDataCenterMapping:
id: str
tenant_id: str
primary_dc_id: str
secondary_dc_id: Optional[str]
secondary_dc_id: str | None
region_code: str
data_residency: str
created_at: datetime
@@ -173,15 +173,15 @@ class LocalizedPaymentMethod:
id: str
provider: str
name: str
name_local: Dict[str, str]
supported_countries: List[str]
supported_currencies: List[str]
name_local: dict[str, str]
supported_countries: list[str]
supported_currencies: list[str]
is_active: bool
config: Dict[str, Any]
icon_url: Optional[str]
config: dict[str, Any]
icon_url: str | None
display_order: int
min_amount: Optional[float]
max_amount: Optional[float]
min_amount: float | None
max_amount: float | None
created_at: datetime
updated_at: datetime
@@ -191,20 +191,20 @@ class CountryConfig:
code: str
code3: str
name: str
name_local: Dict[str, str]
name_local: dict[str, str]
region: str
default_language: str
supported_languages: List[str]
supported_languages: list[str]
default_currency: str
supported_currencies: List[str]
supported_currencies: list[str]
timezone: str
calendar_type: str
date_format: Optional[str]
time_format: Optional[str]
number_format: Optional[str]
address_format: Optional[str]
phone_format: Optional[str]
vat_rate: Optional[float]
date_format: str | None
time_format: str | None
number_format: str | None
address_format: str | None
phone_format: str | None
vat_rate: float | None
is_active: bool
@@ -213,7 +213,7 @@ class TimezoneConfig:
id: str
timezone: str
utc_offset: str
dst_offset: Optional[str]
dst_offset: str | None
country_code: str
region: str
is_active: bool
@@ -223,7 +223,7 @@ class TimezoneConfig:
class CurrencyConfig:
code: str
name: str
name_local: Dict[str, str]
name_local: dict[str, str]
symbol: str
decimal_places: int
decimal_separator: str
@@ -236,13 +236,13 @@ class LocalizationSettings:
id: str
tenant_id: str
default_language: str
supported_languages: List[str]
supported_languages: list[str]
default_currency: str
supported_currencies: List[str]
supported_currencies: list[str]
default_timezone: str
default_date_format: Optional[str]
default_time_format: Optional[str]
default_number_format: Optional[str]
default_date_format: str | None
default_time_format: str | None
default_number_format: str | None
calendar_type: str
first_day_of_week: int
region_code: str
@@ -940,7 +940,7 @@ class LocalizationManager:
def get_translation(
self, key: str, language: str, namespace: str = "common", fallback: bool = True
) -> Optional[str]:
) -> str | None:
conn = self._get_connection()
try:
cursor = conn.cursor()
@@ -962,7 +962,7 @@ class LocalizationManager:
self._close_if_file_db(conn)
def set_translation(
self, key: str, language: str, value: str, namespace: str = "common", context: Optional[str] = None
self, key: str, language: str, value: str, namespace: str = "common", context: str | None = None
) -> Translation:
conn = self._get_connection()
try:
@@ -985,7 +985,7 @@ class LocalizationManager:
def _get_translation_internal(
self, conn: sqlite3.Connection, key: str, language: str, namespace: str
) -> Optional[Translation]:
) -> Translation | None:
cursor = conn.cursor()
cursor.execute(
"SELECT * FROM translations WHERE key = ? AND language = ? AND namespace = ?", (key, language, namespace)
@@ -1008,8 +1008,8 @@ class LocalizationManager:
self._close_if_file_db(conn)
def list_translations(
self, language: Optional[str] = None, namespace: Optional[str] = None, limit: int = 1000, offset: int = 0
) -> List[Translation]:
self, language: str | None = None, namespace: str | None = None, limit: int = 1000, offset: int = 0
) -> list[Translation]:
conn = self._get_connection()
try:
cursor = conn.cursor()
@@ -1029,7 +1029,7 @@ class LocalizationManager:
finally:
self._close_if_file_db(conn)
def get_language_config(self, code: str) -> Optional[LanguageConfig]:
def get_language_config(self, code: str) -> LanguageConfig | None:
conn = self._get_connection()
try:
cursor = conn.cursor()
@@ -1041,7 +1041,7 @@ class LocalizationManager:
finally:
self._close_if_file_db(conn)
def list_language_configs(self, active_only: bool = True) -> List[LanguageConfig]:
def list_language_configs(self, active_only: bool = True) -> list[LanguageConfig]:
conn = self._get_connection()
try:
cursor = conn.cursor()
@@ -1055,7 +1055,7 @@ class LocalizationManager:
finally:
self._close_if_file_db(conn)
def get_data_center(self, dc_id: str) -> Optional[DataCenter]:
def get_data_center(self, dc_id: str) -> DataCenter | None:
conn = self._get_connection()
try:
cursor = conn.cursor()
@@ -1067,7 +1067,7 @@ class LocalizationManager:
finally:
self._close_if_file_db(conn)
def get_data_center_by_region(self, region_code: str) -> Optional[DataCenter]:
def get_data_center_by_region(self, region_code: str) -> DataCenter | None:
conn = self._get_connection()
try:
cursor = conn.cursor()
@@ -1079,7 +1079,7 @@ class LocalizationManager:
finally:
self._close_if_file_db(conn)
def list_data_centers(self, status: Optional[str] = None, region: Optional[str] = None) -> List[DataCenter]:
def list_data_centers(self, status: str | None = None, region: str | None = None) -> list[DataCenter]:
conn = self._get_connection()
try:
cursor = conn.cursor()
@@ -1098,7 +1098,7 @@ class LocalizationManager:
finally:
self._close_if_file_db(conn)
def get_tenant_data_center(self, tenant_id: str) -> Optional[TenantDataCenterMapping]:
def get_tenant_data_center(self, tenant_id: str) -> TenantDataCenterMapping | None:
conn = self._get_connection()
try:
cursor = conn.cursor()
@@ -1159,7 +1159,7 @@ class LocalizationManager:
finally:
self._close_if_file_db(conn)
def get_payment_method(self, provider: str) -> Optional[LocalizedPaymentMethod]:
def get_payment_method(self, provider: str) -> LocalizedPaymentMethod | None:
conn = self._get_connection()
try:
cursor = conn.cursor()
@@ -1172,8 +1172,8 @@ class LocalizationManager:
self._close_if_file_db(conn)
def list_payment_methods(
self, country_code: Optional[str] = None, currency: Optional[str] = None, active_only: bool = True
) -> List[LocalizedPaymentMethod]:
self, country_code: str | None = None, currency: str | None = None, active_only: bool = True
) -> list[LocalizedPaymentMethod]:
conn = self._get_connection()
try:
cursor = conn.cursor()
@@ -1194,7 +1194,7 @@ class LocalizationManager:
finally:
self._close_if_file_db(conn)
def get_localized_payment_methods(self, country_code: str, language: str = "en") -> List[Dict[str, Any]]:
def get_localized_payment_methods(self, country_code: str, language: str = "en") -> list[dict[str, Any]]:
methods = self.list_payment_methods(country_code=country_code)
result = []
for method in methods:
@@ -1212,7 +1212,7 @@ class LocalizationManager:
)
return result
def get_country_config(self, code: str) -> Optional[CountryConfig]:
def get_country_config(self, code: str) -> CountryConfig | None:
conn = self._get_connection()
try:
cursor = conn.cursor()
@@ -1224,7 +1224,7 @@ class LocalizationManager:
finally:
self._close_if_file_db(conn)
def list_country_configs(self, region: Optional[str] = None, active_only: bool = True) -> List[CountryConfig]:
def list_country_configs(self, region: str | None = None, active_only: bool = True) -> list[CountryConfig]:
conn = self._get_connection()
try:
cursor = conn.cursor()
@@ -1243,7 +1243,7 @@ class LocalizationManager:
self._close_if_file_db(conn)
def format_datetime(
self, dt: datetime, language: str = "en", timezone: Optional[str] = None, format_type: str = "datetime"
self, dt: datetime, language: str = "en", timezone: str | None = None, format_type: str = "datetime"
) -> str:
try:
if timezone and PYTZ_AVAILABLE:
@@ -1276,7 +1276,7 @@ class LocalizationManager:
logger.error(f"Error formatting datetime: {e}")
return dt.strftime("%Y-%m-%d %H:%M")
def format_number(self, number: float, language: str = "en", decimal_places: Optional[int] = None) -> str:
def format_number(self, number: float, language: str = "en", decimal_places: int | None = None) -> str:
try:
if BABEL_AVAILABLE:
try:
@@ -1319,7 +1319,7 @@ class LocalizationManager:
logger.error(f"Error converting timezone: {e}")
return dt
def get_calendar_info(self, calendar_type: str, year: int, month: int) -> Dict[str, Any]:
def get_calendar_info(self, calendar_type: str, year: int, month: int) -> dict[str, Any]:
import calendar
cal = calendar.Calendar()
@@ -1334,7 +1334,7 @@ class LocalizationManager:
"weeks": month_days,
}
def get_localization_settings(self, tenant_id: str) -> Optional[LocalizationSettings]:
def get_localization_settings(self, tenant_id: str) -> LocalizationSettings | None:
conn = self._get_connection()
try:
cursor = conn.cursor()
@@ -1350,9 +1350,9 @@ class LocalizationManager:
self,
tenant_id: str,
default_language: str = "en",
supported_languages: Optional[List[str]] = None,
supported_languages: list[str] | None = None,
default_currency: str = "USD",
supported_currencies: Optional[List[str]] = None,
supported_currencies: list[str] | None = None,
default_timezone: str = "UTC",
region_code: str = "global",
data_residency: str = "regional",
@@ -1397,7 +1397,7 @@ class LocalizationManager:
finally:
self._close_if_file_db(conn)
def update_localization_settings(self, tenant_id: str, **kwargs) -> Optional[LocalizationSettings]:
def update_localization_settings(self, tenant_id: str, **kwargs) -> LocalizationSettings | None:
conn = self._get_connection()
try:
settings = self.get_localization_settings(tenant_id)
@@ -1441,8 +1441,8 @@ class LocalizationManager:
self._close_if_file_db(conn)
def detect_user_preferences(
self, accept_language: Optional[str] = None, ip_country: Optional[str] = None
) -> Dict[str, str]:
self, accept_language: str | None = None, ip_country: str | None = None
) -> dict[str, str]:
preferences = {"language": "en", "country": "US", "timezone": "UTC", "currency": "USD"}
if accept_language:
langs = accept_language.split(",")

File diff suppressed because it is too large Load Diff

View File

@@ -5,7 +5,6 @@ InsightFlow Multimodal Entity Linker - Phase 7
"""
import uuid
from typing import List, Dict, Optional, Tuple, Set
from dataclasses import dataclass
from difflib import SequenceMatcher
@@ -28,7 +27,7 @@ class MultimodalEntity:
source_id: str
mention_context: str
confidence: float
modality_features: Dict = None # 模态特定特征
modality_features: dict = None # 模态特定特征
def __post_init__(self):
if self.modality_features is None:
@@ -55,7 +54,7 @@ class AlignmentResult:
"""对齐结果"""
entity_id: str
matched_entity_id: Optional[str]
matched_entity_id: str | None
similarity: float
match_type: str # exact, fuzzy, embedding
confidence: float
@@ -66,9 +65,9 @@ class FusionResult:
"""知识融合结果"""
canonical_entity_id: str
merged_entity_ids: List[str]
fused_properties: Dict
source_modalities: List[str]
merged_entity_ids: list[str]
fused_properties: dict
source_modalities: list[str]
confidence: float
@@ -117,7 +116,7 @@ class MultimodalEntityLinker:
# 编辑距离相似度
return SequenceMatcher(None, s1, s2).ratio()
def calculate_entity_similarity(self, entity1: Dict, entity2: Dict) -> Tuple[float, str]:
def calculate_entity_similarity(self, entity1: dict, entity2: dict) -> tuple[float, str]:
"""
计算两个实体的综合相似度
@@ -159,8 +158,8 @@ class MultimodalEntityLinker:
return combined_sim, "none"
def find_matching_entity(
self, query_entity: Dict, candidate_entities: List[Dict], exclude_ids: Set[str] = None
) -> Optional[AlignmentResult]:
self, query_entity: dict, candidate_entities: list[dict], exclude_ids: set[str] = None
) -> AlignmentResult | None:
"""
在候选实体中查找匹配的实体
@@ -201,11 +200,11 @@ class MultimodalEntityLinker:
def align_cross_modal_entities(
self,
project_id: str,
audio_entities: List[Dict],
video_entities: List[Dict],
image_entities: List[Dict],
document_entities: List[Dict],
) -> List[EntityLink]:
audio_entities: list[dict],
video_entities: list[dict],
image_entities: list[dict],
document_entities: list[dict],
) -> list[EntityLink]:
"""
跨模态实体对齐
@@ -259,7 +258,7 @@ class MultimodalEntityLinker:
return links
def fuse_entity_knowledge(
self, entity_id: str, linked_entities: List[Dict], multimodal_mentions: List[Dict]
self, entity_id: str, linked_entities: list[dict], multimodal_mentions: list[dict]
) -> FusionResult:
"""
融合多模态实体知识
@@ -331,7 +330,7 @@ class MultimodalEntityLinker:
confidence=min(1.0, len(linked_entities) * 0.2 + 0.5),
)
def detect_entity_conflicts(self, entities: List[Dict]) -> List[Dict]:
def detect_entity_conflicts(self, entities: list[dict]) -> list[dict]:
"""
检测实体冲突(同名但不同义)
@@ -380,7 +379,7 @@ class MultimodalEntityLinker:
return conflicts
def suggest_entity_merges(self, entities: List[Dict], existing_links: List[EntityLink] = None) -> List[Dict]:
def suggest_entity_merges(self, entities: list[dict], existing_links: list[EntityLink] = None) -> list[dict]:
"""
建议实体合并
@@ -464,7 +463,7 @@ class MultimodalEntityLinker:
confidence=confidence,
)
def analyze_modality_distribution(self, multimodal_entities: List[MultimodalEntity]) -> Dict:
def analyze_modality_distribution(self, multimodal_entities: list[MultimodalEntity]) -> dict:
"""
分析模态分布

View File

@@ -4,12 +4,11 @@ InsightFlow Multimodal Processor - Phase 7
视频处理模块提取音频、关键帧、OCR识别
"""
import os
import json
import uuid
import tempfile
import os
import subprocess
from typing import List, Dict, Tuple
import tempfile
import uuid
from dataclasses import dataclass
from pathlib import Path
@@ -48,7 +47,7 @@ class VideoFrame:
frame_path: str
ocr_text: str = ""
ocr_confidence: float = 0.0
entities_detected: List[Dict] = None
entities_detected: list[dict] = None
def __post_init__(self):
if self.entities_detected is None:
@@ -72,7 +71,7 @@ class VideoInfo:
transcript_id: str = ""
status: str = "pending"
error_message: str = ""
metadata: Dict = None
metadata: dict = None
def __post_init__(self):
if self.metadata is None:
@@ -85,8 +84,8 @@ class VideoProcessingResult:
video_id: str
audio_path: str
frames: List[VideoFrame]
ocr_results: List[Dict]
frames: list[VideoFrame]
ocr_results: list[dict]
full_text: str # 整合的文本(音频转录 + OCR文本
success: bool
error_message: str = ""
@@ -114,7 +113,7 @@ class MultimodalProcessor:
os.makedirs(self.frames_dir, exist_ok=True)
os.makedirs(self.audio_dir, exist_ok=True)
def extract_video_info(self, video_path: str) -> Dict:
def extract_video_info(self, video_path: str) -> dict:
"""
提取视频基本信息
@@ -215,7 +214,7 @@ class MultimodalProcessor:
print(f"Error extracting audio: {e}")
raise
def extract_keyframes(self, video_path: str, video_id: str, interval: int = None) -> List[str]:
def extract_keyframes(self, video_path: str, video_id: str, interval: int = None) -> list[str]:
"""
从视频中提取关键帧
@@ -275,7 +274,7 @@ class MultimodalProcessor:
return frame_paths
def perform_ocr(self, image_path: str) -> Tuple[str, float]:
def perform_ocr(self, image_path: str) -> tuple[str, float]:
"""
对图片进行OCR识别

View File

@@ -5,10 +5,9 @@ Phase 5: Neo4j 图数据库集成
支持数据同步、复杂图查询和图算法分析
"""
import os
import json
import logging
from typing import List, Dict, Optional
import os
from dataclasses import dataclass
logger = logging.getLogger(__name__)
@@ -20,7 +19,7 @@ NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD", "password")
# 延迟导入,避免未安装时出错
try:
from neo4j import GraphDatabase, Driver
from neo4j import Driver, GraphDatabase
NEO4J_AVAILABLE = True
except ImportError:
@@ -37,8 +36,8 @@ class GraphEntity:
name: str
type: str
definition: str = ""
aliases: List[str] = None
properties: Dict = None
aliases: list[str] = None
properties: dict = None
def __post_init__(self):
if self.aliases is None:
@@ -56,7 +55,7 @@ class GraphRelation:
target_id: str
relation_type: str
evidence: str = ""
properties: Dict = None
properties: dict = None
def __post_init__(self):
if self.properties is None:
@@ -67,8 +66,8 @@ class GraphRelation:
class PathResult:
"""路径查询结果"""
nodes: List[Dict]
relationships: List[Dict]
nodes: list[dict]
relationships: list[dict]
length: int
total_weight: float = 0.0
@@ -78,7 +77,7 @@ class CommunityResult:
"""社区发现结果"""
community_id: int
nodes: List[Dict]
nodes: list[dict]
size: int
density: float = 0.0
@@ -100,7 +99,7 @@ class Neo4jManager:
self.uri = uri or NEO4J_URI
self.user = user or NEO4J_USER
self.password = password or NEO4J_PASSWORD
self._driver: Optional["Driver"] = None
self._driver: Driver | None = None
if not NEO4J_AVAILABLE:
logger.error("Neo4j driver not available. Please install: pip install neo4j")
@@ -226,7 +225,7 @@ class Neo4jManager:
properties=json.dumps(entity.properties),
)
def sync_entities_batch(self, entities: List[GraphEntity]):
def sync_entities_batch(self, entities: list[GraphEntity]):
"""批量同步实体到 Neo4j"""
if not self._driver or not entities:
return
@@ -287,7 +286,7 @@ class Neo4jManager:
properties=json.dumps(relation.properties),
)
def sync_relations_batch(self, relations: List[GraphRelation]):
def sync_relations_batch(self, relations: list[GraphRelation]):
"""批量同步关系到 Neo4j"""
if not self._driver or not relations:
return
@@ -350,7 +349,7 @@ class Neo4jManager:
# ==================== 复杂图查询 ====================
def find_shortest_path(self, source_id: str, target_id: str, max_depth: int = 10) -> Optional[PathResult]:
def find_shortest_path(self, source_id: str, target_id: str, max_depth: int = 10) -> PathResult | None:
"""
查找两个实体之间的最短路径
@@ -399,7 +398,7 @@ class Neo4jManager:
return PathResult(nodes=nodes, relationships=relationships, length=len(path.relationships))
def find_all_paths(self, source_id: str, target_id: str, max_depth: int = 5, limit: int = 10) -> List[PathResult]:
def find_all_paths(self, source_id: str, target_id: str, max_depth: int = 5, limit: int = 10) -> list[PathResult]:
"""
查找两个实体之间的所有路径
@@ -449,7 +448,7 @@ class Neo4jManager:
return paths
def find_neighbors(self, entity_id: str, relation_type: str = None, limit: int = 50) -> List[Dict]:
def find_neighbors(self, entity_id: str, relation_type: str = None, limit: int = 50) -> list[dict]:
"""
查找实体的邻居节点
@@ -502,7 +501,7 @@ class Neo4jManager:
return neighbors
def find_common_neighbors(self, entity_id1: str, entity_id2: str) -> List[Dict]:
def find_common_neighbors(self, entity_id1: str, entity_id2: str) -> list[dict]:
"""
查找两个实体的共同邻居(潜在关联)
@@ -533,7 +532,7 @@ class Neo4jManager:
# ==================== 图算法分析 ====================
def calculate_pagerank(self, project_id: str, top_n: int = 20) -> List[CentralityResult]:
def calculate_pagerank(self, project_id: str, top_n: int = 20) -> list[CentralityResult]:
"""
计算 PageRank 中心性
@@ -619,7 +618,7 @@ class Neo4jManager:
return rankings
def calculate_betweenness(self, project_id: str, top_n: int = 20) -> List[CentralityResult]:
def calculate_betweenness(self, project_id: str, top_n: int = 20) -> list[CentralityResult]:
"""
计算 Betweenness 中心性(桥梁作用)
@@ -663,7 +662,7 @@ class Neo4jManager:
return rankings
def detect_communities(self, project_id: str) -> List[CommunityResult]:
def detect_communities(self, project_id: str) -> list[CommunityResult]:
"""
社区发现(使用 Louvain 算法)
@@ -733,7 +732,7 @@ class Neo4jManager:
results.sort(key=lambda x: x.size, reverse=True)
return results
def find_central_entities(self, project_id: str, metric: str = "degree") -> List[CentralityResult]:
def find_central_entities(self, project_id: str, metric: str = "degree") -> list[CentralityResult]:
"""
查找中心实体
@@ -791,7 +790,7 @@ class Neo4jManager:
# ==================== 图统计 ====================
def get_graph_stats(self, project_id: str) -> Dict:
def get_graph_stats(self, project_id: str) -> dict:
"""
获取项目的图统计信息
@@ -870,7 +869,7 @@ class Neo4jManager:
"density": round(relation_count / (entity_count * (entity_count - 1)), 4) if entity_count > 1 else 0,
}
def get_subgraph(self, entity_ids: List[str], depth: int = 1) -> Dict:
def get_subgraph(self, entity_ids: list[str], depth: int = 1) -> dict:
"""
获取指定实体的子图
@@ -959,7 +958,7 @@ def close_neo4j_manager():
# 便捷函数
def sync_project_to_neo4j(project_id: str, project_name: str, entities: List[Dict], relations: List[Dict]):
def sync_project_to_neo4j(project_id: str, project_name: str, entities: list[dict], relations: list[dict]):
"""
同步整个项目到 Neo4j

View File

@@ -10,26 +10,27 @@ InsightFlow Operations & Monitoring Manager - Phase 8 Task 8
作者: InsightFlow Team
"""
import os
import json
import sqlite3
import httpx
import asyncio
import hashlib
import uuid
import json
import os
import re
import time
import sqlite3
import statistics
from typing import List, Dict, Optional, Tuple, Callable
import time
import uuid
from collections.abc import Callable
from dataclasses import dataclass
from datetime import datetime, timedelta
from enum import Enum
from enum import StrEnum
import httpx
# Database path
DB_PATH = os.path.join(os.path.dirname(__file__), "insightflow.db")
class AlertSeverity(str, Enum):
class AlertSeverity(StrEnum):
"""告警严重级别 P0-P3"""
P0 = "p0" # 紧急 - 系统不可用,需要立即处理
@@ -38,7 +39,7 @@ class AlertSeverity(str, Enum):
P3 = "p3" # 轻微 - 非核心功能问题24小时内处理
class AlertStatus(str, Enum):
class AlertStatus(StrEnum):
"""告警状态"""
FIRING = "firing" # 正在告警
@@ -47,7 +48,7 @@ class AlertStatus(str, Enum):
SUPPRESSED = "suppressed" # 已抑制
class AlertChannelType(str, Enum):
class AlertChannelType(StrEnum):
"""告警渠道类型"""
PAGERDUTY = "pagerduty"
@@ -60,7 +61,7 @@ class AlertChannelType(str, Enum):
WEBHOOK = "webhook"
class AlertRuleType(str, Enum):
class AlertRuleType(StrEnum):
"""告警规则类型"""
THRESHOLD = "threshold" # 阈值告警
@@ -69,7 +70,7 @@ class AlertRuleType(str, Enum):
COMPOSITE = "composite" # 复合告警
class ResourceType(str, Enum):
class ResourceType(StrEnum):
"""资源类型"""
CPU = "cpu"
@@ -82,7 +83,7 @@ class ResourceType(str, Enum):
QUEUE = "queue"
class ScalingAction(str, Enum):
class ScalingAction(StrEnum):
"""扩缩容动作"""
SCALE_UP = "scale_up" # 扩容
@@ -90,7 +91,7 @@ class ScalingAction(str, Enum):
MAINTAIN = "maintain" # 保持
class HealthStatus(str, Enum):
class HealthStatus(StrEnum):
"""健康状态"""
HEALTHY = "healthy"
@@ -99,7 +100,7 @@ class HealthStatus(str, Enum):
UNKNOWN = "unknown"
class BackupStatus(str, Enum):
class BackupStatus(StrEnum):
"""备份状态"""
PENDING = "pending"
@@ -124,9 +125,9 @@ class AlertRule:
threshold: float
duration: int # 持续时间(秒)
evaluation_interval: int # 评估间隔(秒)
channels: List[str] # 告警渠道ID列表
labels: Dict[str, str] # 标签
annotations: Dict[str, str] # 注释
channels: list[str] # 告警渠道ID列表
labels: dict[str, str] # 标签
annotations: dict[str, str] # 注释
is_enabled: bool
created_at: str
updated_at: str
@@ -141,12 +142,12 @@ class AlertChannel:
tenant_id: str
name: str
channel_type: AlertChannelType
config: Dict # 渠道特定配置
severity_filter: List[str] # 过滤的告警级别
config: dict # 渠道特定配置
severity_filter: list[str] # 过滤的告警级别
is_enabled: bool
success_count: int
fail_count: int
last_used_at: Optional[str]
last_used_at: str | None
created_at: str
updated_at: str
@@ -165,13 +166,13 @@ class Alert:
metric: str
value: float
threshold: float
labels: Dict[str, str]
annotations: Dict[str, str]
labels: dict[str, str]
annotations: dict[str, str]
started_at: str
resolved_at: Optional[str]
acknowledged_by: Optional[str]
acknowledged_at: Optional[str]
notification_sent: Dict[str, bool] # 渠道发送状态
resolved_at: str | None
acknowledged_by: str | None
acknowledged_at: str | None
notification_sent: dict[str, bool] # 渠道发送状态
suppression_count: int # 抑制计数
@@ -182,11 +183,11 @@ class AlertSuppressionRule:
id: str
tenant_id: str
name: str
matchers: Dict[str, str] # 匹配条件
matchers: dict[str, str] # 匹配条件
duration: int # 抑制持续时间(秒)
is_regex: bool # 是否使用正则匹配
created_at: str
expires_at: Optional[str]
expires_at: str | None
@dataclass
@@ -196,7 +197,7 @@ class AlertGroup:
id: str
tenant_id: str
group_key: str # 聚合键
alerts: List[str] # 告警ID列表
alerts: list[str] # 告警ID列表
created_at: str
updated_at: str
@@ -213,7 +214,7 @@ class ResourceMetric:
metric_value: float
unit: str
timestamp: str
metadata: Dict
metadata: dict
@dataclass
@@ -267,8 +268,8 @@ class ScalingEvent:
triggered_by: str # 触发来源: manual, auto, scheduled
status: str # pending, in_progress, completed, failed
started_at: str
completed_at: Optional[str]
error_message: Optional[str]
completed_at: str | None
error_message: str | None
@dataclass
@@ -281,7 +282,7 @@ class HealthCheck:
target_type: str # service, database, api, etc.
target_id: str
check_type: str # http, tcp, ping, custom
check_config: Dict # 检查配置
check_config: dict # 检查配置
interval: int # 检查间隔(秒)
timeout: int # 超时时间(秒)
retry_count: int
@@ -302,7 +303,7 @@ class HealthCheckResult:
status: HealthStatus
response_time: float # 响应时间(毫秒)
message: str
details: Dict
details: dict
checked_at: str
@@ -314,7 +315,7 @@ class FailoverConfig:
tenant_id: str
name: str
primary_region: str
secondary_regions: List[str] # 备用区域列表
secondary_regions: list[str] # 备用区域列表
failover_trigger: str # 触发条件
auto_failover: bool
failover_timeout: int # 故障转移超时(秒)
@@ -336,8 +337,8 @@ class FailoverEvent:
reason: str
status: str # initiated, in_progress, completed, failed, rolled_back
started_at: str
completed_at: Optional[str]
rolled_back_at: Optional[str]
completed_at: str | None
rolled_back_at: str | None
@dataclass
@@ -371,9 +372,9 @@ class BackupRecord:
size_bytes: int
checksum: str
started_at: str
completed_at: Optional[str]
verified_at: Optional[str]
error_message: Optional[str]
completed_at: str | None
verified_at: str | None
error_message: str | None
storage_path: str
@@ -386,9 +387,9 @@ class CostReport:
report_period: str # YYYY-MM
total_cost: float
currency: str
breakdown: Dict[str, float] # 按资源类型分解
trends: Dict # 趋势数据
anomalies: List[Dict] # 异常检测
breakdown: dict[str, float] # 按资源类型分解
trends: dict # 趋势数据
anomalies: list[dict] # 异常检测
created_at: str
@@ -405,7 +406,7 @@ class ResourceUtilization:
avg_utilization: float
idle_time_percent: float
report_date: str
recommendations: List[str]
recommendations: list[str]
@dataclass
@@ -438,11 +439,11 @@ class CostOptimizationSuggestion:
currency: str
confidence: float
difficulty: str # easy, medium, hard
implementation_steps: List[str]
implementation_steps: list[str]
risk_level: str # low, medium, high
is_applied: bool
created_at: str
applied_at: Optional[str]
applied_at: str | None
class OpsManager:
@@ -450,7 +451,7 @@ class OpsManager:
def __init__(self, db_path: str = DB_PATH):
self.db_path = db_path
self._alert_evaluators: Dict[str, Callable] = {}
self._alert_evaluators: dict[str, Callable] = {}
self._running = False
self._evaluator_thread = None
self._register_default_evaluators()
@@ -481,9 +482,9 @@ class OpsManager:
threshold: float,
duration: int,
evaluation_interval: int,
channels: List[str],
labels: Dict,
annotations: Dict,
channels: list[str],
labels: dict,
annotations: dict,
created_by: str,
) -> AlertRule:
"""创建告警规则"""
@@ -545,7 +546,7 @@ class OpsManager:
return rule
def get_alert_rule(self, rule_id: str) -> Optional[AlertRule]:
def get_alert_rule(self, rule_id: str) -> AlertRule | None:
"""获取告警规则"""
with self._get_db() as conn:
row = conn.execute("SELECT * FROM alert_rules WHERE id = ?", (rule_id,)).fetchone()
@@ -554,7 +555,7 @@ class OpsManager:
return self._row_to_alert_rule(row)
return None
def list_alert_rules(self, tenant_id: str, is_enabled: Optional[bool] = None) -> List[AlertRule]:
def list_alert_rules(self, tenant_id: str, is_enabled: bool | None = None) -> list[AlertRule]:
"""列出租户的所有告警规则"""
query = "SELECT * FROM alert_rules WHERE tenant_id = ?"
params = [tenant_id]
@@ -569,7 +570,7 @@ class OpsManager:
rows = conn.execute(query, params).fetchall()
return [self._row_to_alert_rule(row) for row in rows]
def update_alert_rule(self, rule_id: str, **kwargs) -> Optional[AlertRule]:
def update_alert_rule(self, rule_id: str, **kwargs) -> AlertRule | None:
"""更新告警规则"""
allowed_fields = [
"name",
@@ -619,7 +620,7 @@ class OpsManager:
# ==================== 告警渠道管理 ====================
def create_alert_channel(
self, tenant_id: str, name: str, channel_type: AlertChannelType, config: Dict, severity_filter: List[str] = None
self, tenant_id: str, name: str, channel_type: AlertChannelType, config: dict, severity_filter: list[str] = None
) -> AlertChannel:
"""创建告警渠道"""
channel_id = f"ac_{uuid.uuid4().hex[:16]}"
@@ -667,7 +668,7 @@ class OpsManager:
return channel
def get_alert_channel(self, channel_id: str) -> Optional[AlertChannel]:
def get_alert_channel(self, channel_id: str) -> AlertChannel | None:
"""获取告警渠道"""
with self._get_db() as conn:
row = conn.execute("SELECT * FROM alert_channels WHERE id = ?", (channel_id,)).fetchone()
@@ -676,7 +677,7 @@ class OpsManager:
return self._row_to_alert_channel(row)
return None
def list_alert_channels(self, tenant_id: str) -> List[AlertChannel]:
def list_alert_channels(self, tenant_id: str) -> list[AlertChannel]:
"""列出租户的所有告警渠道"""
with self._get_db() as conn:
rows = conn.execute(
@@ -715,7 +716,7 @@ class OpsManager:
# ==================== 告警评估与触发 ====================
def _evaluate_threshold_rule(self, rule: AlertRule, metrics: List[ResourceMetric]) -> bool:
def _evaluate_threshold_rule(self, rule: AlertRule, metrics: list[ResourceMetric]) -> bool:
"""评估阈值告警规则"""
if not metrics:
return False
@@ -746,7 +747,7 @@ class OpsManager:
return False
def _evaluate_anomaly_rule(self, rule: AlertRule, metrics: List[ResourceMetric]) -> bool:
def _evaluate_anomaly_rule(self, rule: AlertRule, metrics: list[ResourceMetric]) -> bool:
"""评估异常检测规则(基于标准差)"""
if len(metrics) < 10:
return False
@@ -764,7 +765,7 @@ class OpsManager:
return z_score > 3.0
def _evaluate_predictive_rule(self, rule: AlertRule, metrics: List[ResourceMetric]) -> bool:
def _evaluate_predictive_rule(self, rule: AlertRule, metrics: list[ResourceMetric]) -> bool:
"""评估预测性告警规则(基于线性趋势)"""
if len(metrics) < 5:
return False
@@ -814,7 +815,7 @@ class OpsManager:
# 触发告警
await self._trigger_alert(rule, metrics[-1] if metrics else None)
async def _trigger_alert(self, rule: AlertRule, metric: Optional[ResourceMetric]):
async def _trigger_alert(self, rule: AlertRule, metric: ResourceMetric | None):
"""触发告警"""
# 检查是否已有相同告警在触发中
existing = self.get_active_alert_by_rule(rule.id)
@@ -1166,7 +1167,7 @@ class OpsManager:
# ==================== 告警查询与管理 ====================
def get_active_alert_by_rule(self, rule_id: str) -> Optional[Alert]:
def get_active_alert_by_rule(self, rule_id: str) -> Alert | None:
"""获取规则对应的活跃告警"""
with self._get_db() as conn:
row = conn.execute(
@@ -1180,7 +1181,7 @@ class OpsManager:
return self._row_to_alert(row)
return None
def get_alert(self, alert_id: str) -> Optional[Alert]:
def get_alert(self, alert_id: str) -> Alert | None:
"""获取告警详情"""
with self._get_db() as conn:
row = conn.execute("SELECT * FROM alerts WHERE id = ?", (alert_id,)).fetchone()
@@ -1192,10 +1193,10 @@ class OpsManager:
def list_alerts(
self,
tenant_id: str,
status: Optional[AlertStatus] = None,
severity: Optional[AlertSeverity] = None,
status: AlertStatus | None = None,
severity: AlertSeverity | None = None,
limit: int = 100,
) -> List[Alert]:
) -> list[Alert]:
"""列出租户的告警"""
query = "SELECT * FROM alerts WHERE tenant_id = ?"
params = [tenant_id]
@@ -1214,7 +1215,7 @@ class OpsManager:
rows = conn.execute(query, params).fetchall()
return [self._row_to_alert(row) for row in rows]
def acknowledge_alert(self, alert_id: str, user_id: str) -> Optional[Alert]:
def acknowledge_alert(self, alert_id: str, user_id: str) -> Alert | None:
"""确认告警"""
now = datetime.now().isoformat()
@@ -1231,7 +1232,7 @@ class OpsManager:
return self.get_alert(alert_id)
def resolve_alert(self, alert_id: str) -> Optional[Alert]:
def resolve_alert(self, alert_id: str) -> Alert | None:
"""解决告警"""
now = datetime.now().isoformat()
@@ -1306,10 +1307,10 @@ class OpsManager:
self,
tenant_id: str,
name: str,
matchers: Dict[str, str],
matchers: dict[str, str],
duration: int,
is_regex: bool = False,
expires_at: Optional[str] = None,
expires_at: str | None = None,
) -> AlertSuppressionRule:
"""创建告警抑制规则"""
rule_id = f"sr_{uuid.uuid4().hex[:16]}"
@@ -1394,7 +1395,7 @@ class OpsManager:
metric_name: str,
metric_value: float,
unit: str,
metadata: Dict = None,
metadata: dict = None,
) -> ResourceMetric:
"""记录资源指标"""
metric_id = f"rm_{uuid.uuid4().hex[:16]}"
@@ -1436,7 +1437,7 @@ class OpsManager:
return metric
def get_recent_metrics(self, tenant_id: str, metric_name: str, seconds: int = 3600) -> List[ResourceMetric]:
def get_recent_metrics(self, tenant_id: str, metric_name: str, seconds: int = 3600) -> list[ResourceMetric]:
"""获取最近的指标数据"""
cutoff_time = (datetime.now() - timedelta(seconds=seconds)).isoformat()
@@ -1458,7 +1459,7 @@ class OpsManager:
metric_name: str,
start_time: str,
end_time: str,
) -> List[ResourceMetric]:
) -> list[ResourceMetric]:
"""获取指定资源的指标数据"""
with self._get_db() as conn:
rows = conn.execute(
@@ -1549,7 +1550,7 @@ class OpsManager:
return plan
def _calculate_trend(self, values: List[float]) -> float:
def _calculate_trend(self, values: list[float]) -> float:
"""计算趋势(增长率)"""
if len(values) < 2:
return 0.0
@@ -1576,7 +1577,7 @@ class OpsManager:
return slope / mean_y
return 0.0
def get_capacity_plans(self, tenant_id: str) -> List[CapacityPlan]:
def get_capacity_plans(self, tenant_id: str) -> list[CapacityPlan]:
"""获取容量规划列表"""
with self._get_db() as conn:
rows = conn.execute(
@@ -1653,7 +1654,7 @@ class OpsManager:
return policy
def get_auto_scaling_policy(self, policy_id: str) -> Optional[AutoScalingPolicy]:
def get_auto_scaling_policy(self, policy_id: str) -> AutoScalingPolicy | None:
"""获取自动扩缩容策略"""
with self._get_db() as conn:
row = conn.execute("SELECT * FROM auto_scaling_policies WHERE id = ?", (policy_id,)).fetchone()
@@ -1662,7 +1663,7 @@ class OpsManager:
return self._row_to_auto_scaling_policy(row)
return None
def list_auto_scaling_policies(self, tenant_id: str) -> List[AutoScalingPolicy]:
def list_auto_scaling_policies(self, tenant_id: str) -> list[AutoScalingPolicy]:
"""列出租户的自动扩缩容策略"""
with self._get_db() as conn:
rows = conn.execute(
@@ -1672,7 +1673,7 @@ class OpsManager:
def evaluate_scaling_policy(
self, policy_id: str, current_instances: int, current_utilization: float
) -> Optional[ScalingEvent]:
) -> ScalingEvent | None:
"""评估扩缩容策略"""
policy = self.get_auto_scaling_policy(policy_id)
if not policy or not policy.is_enabled:
@@ -1754,7 +1755,7 @@ class OpsManager:
return event
def get_last_scaling_event(self, policy_id: str) -> Optional[ScalingEvent]:
def get_last_scaling_event(self, policy_id: str) -> ScalingEvent | None:
"""获取最近的扩缩容事件"""
with self._get_db() as conn:
row = conn.execute(
@@ -1770,7 +1771,7 @@ class OpsManager:
def update_scaling_event_status(
self, event_id: str, status: str, error_message: str = None
) -> Optional[ScalingEvent]:
) -> ScalingEvent | None:
"""更新扩缩容事件状态"""
now = datetime.now().isoformat()
@@ -1797,7 +1798,7 @@ class OpsManager:
return self.get_scaling_event(event_id)
def get_scaling_event(self, event_id: str) -> Optional[ScalingEvent]:
def get_scaling_event(self, event_id: str) -> ScalingEvent | None:
"""获取扩缩容事件"""
with self._get_db() as conn:
row = conn.execute("SELECT * FROM scaling_events WHERE id = ?", (event_id,)).fetchone()
@@ -1806,7 +1807,7 @@ class OpsManager:
return self._row_to_scaling_event(row)
return None
def list_scaling_events(self, tenant_id: str, policy_id: str = None, limit: int = 100) -> List[ScalingEvent]:
def list_scaling_events(self, tenant_id: str, policy_id: str = None, limit: int = 100) -> list[ScalingEvent]:
"""列出租户的扩缩容事件"""
query = "SELECT * FROM scaling_events WHERE tenant_id = ?"
params = [tenant_id]
@@ -1831,7 +1832,7 @@ class OpsManager:
target_type: str,
target_id: str,
check_type: str,
check_config: Dict,
check_config: dict,
interval: int = 60,
timeout: int = 10,
retry_count: int = 3,
@@ -1889,7 +1890,7 @@ class OpsManager:
return check
def get_health_check(self, check_id: str) -> Optional[HealthCheck]:
def get_health_check(self, check_id: str) -> HealthCheck | None:
"""获取健康检查配置"""
with self._get_db() as conn:
row = conn.execute("SELECT * FROM health_checks WHERE id = ?", (check_id,)).fetchone()
@@ -1898,7 +1899,7 @@ class OpsManager:
return self._row_to_health_check(row)
return None
def list_health_checks(self, tenant_id: str) -> List[HealthCheck]:
def list_health_checks(self, tenant_id: str) -> list[HealthCheck]:
"""列出租户的健康检查"""
with self._get_db() as conn:
rows = conn.execute(
@@ -1958,7 +1959,7 @@ class OpsManager:
return result
async def _check_http_health(self, check: HealthCheck) -> Tuple[HealthStatus, float, str]:
async def _check_http_health(self, check: HealthCheck) -> tuple[HealthStatus, float, str]:
"""HTTP 健康检查"""
config = check.check_config
url = config.get("url")
@@ -1980,7 +1981,7 @@ class OpsManager:
except Exception as e:
return HealthStatus.UNHEALTHY, (time.time() - start_time) * 1000, str(e)
async def _check_tcp_health(self, check: HealthCheck) -> Tuple[HealthStatus, float, str]:
async def _check_tcp_health(self, check: HealthCheck) -> tuple[HealthStatus, float, str]:
"""TCP 健康检查"""
config = check.check_config
host = config.get("host")
@@ -1996,12 +1997,12 @@ class OpsManager:
writer.close()
await writer.wait_closed()
return HealthStatus.HEALTHY, response_time, "TCP connection successful"
except asyncio.TimeoutError:
except TimeoutError:
return HealthStatus.UNHEALTHY, (time.time() - start_time) * 1000, "Connection timeout"
except Exception as e:
return HealthStatus.UNHEALTHY, (time.time() - start_time) * 1000, str(e)
async def _check_ping_health(self, check: HealthCheck) -> Tuple[HealthStatus, float, str]:
async def _check_ping_health(self, check: HealthCheck) -> tuple[HealthStatus, float, str]:
"""Ping 健康检查(模拟)"""
config = check.check_config
host = config.get("host")
@@ -2013,7 +2014,7 @@ class OpsManager:
# 这里模拟成功
return HealthStatus.HEALTHY, 10.0, "Ping successful"
def get_health_check_results(self, check_id: str, limit: int = 100) -> List[HealthCheckResult]:
def get_health_check_results(self, check_id: str, limit: int = 100) -> list[HealthCheckResult]:
"""获取健康检查历史结果"""
with self._get_db() as conn:
rows = conn.execute(
@@ -2031,7 +2032,7 @@ class OpsManager:
tenant_id: str,
name: str,
primary_region: str,
secondary_regions: List[str],
secondary_regions: list[str],
failover_trigger: str,
auto_failover: bool = False,
failover_timeout: int = 300,
@@ -2083,7 +2084,7 @@ class OpsManager:
return config
def get_failover_config(self, config_id: str) -> Optional[FailoverConfig]:
def get_failover_config(self, config_id: str) -> FailoverConfig | None:
"""获取故障转移配置"""
with self._get_db() as conn:
row = conn.execute("SELECT * FROM failover_configs WHERE id = ?", (config_id,)).fetchone()
@@ -2092,7 +2093,7 @@ class OpsManager:
return self._row_to_failover_config(row)
return None
def list_failover_configs(self, tenant_id: str) -> List[FailoverConfig]:
def list_failover_configs(self, tenant_id: str) -> list[FailoverConfig]:
"""列出租户的故障转移配置"""
with self._get_db() as conn:
rows = conn.execute(
@@ -2100,7 +2101,7 @@ class OpsManager:
).fetchall()
return [self._row_to_failover_config(row) for row in rows]
def initiate_failover(self, config_id: str, reason: str) -> Optional[FailoverEvent]:
def initiate_failover(self, config_id: str, reason: str) -> FailoverEvent | None:
"""发起故障转移"""
config = self.get_failover_config(config_id)
if not config or not config.is_enabled:
@@ -2150,7 +2151,7 @@ class OpsManager:
return event
def update_failover_status(self, event_id: str, status: str) -> Optional[FailoverEvent]:
def update_failover_status(self, event_id: str, status: str) -> FailoverEvent | None:
"""更新故障转移状态"""
now = datetime.now().isoformat()
@@ -2186,7 +2187,7 @@ class OpsManager:
return self.get_failover_event(event_id)
def get_failover_event(self, event_id: str) -> Optional[FailoverEvent]:
def get_failover_event(self, event_id: str) -> FailoverEvent | None:
"""获取故障转移事件"""
with self._get_db() as conn:
row = conn.execute("SELECT * FROM failover_events WHERE id = ?", (event_id,)).fetchone()
@@ -2195,7 +2196,7 @@ class OpsManager:
return self._row_to_failover_event(row)
return None
def list_failover_events(self, tenant_id: str, limit: int = 100) -> List[FailoverEvent]:
def list_failover_events(self, tenant_id: str, limit: int = 100) -> list[FailoverEvent]:
"""列出租户的故障转移事件"""
with self._get_db() as conn:
rows = conn.execute(
@@ -2272,7 +2273,7 @@ class OpsManager:
return job
def get_backup_job(self, job_id: str) -> Optional[BackupJob]:
def get_backup_job(self, job_id: str) -> BackupJob | None:
"""获取备份任务"""
with self._get_db() as conn:
row = conn.execute("SELECT * FROM backup_jobs WHERE id = ?", (job_id,)).fetchone()
@@ -2281,7 +2282,7 @@ class OpsManager:
return self._row_to_backup_job(row)
return None
def list_backup_jobs(self, tenant_id: str) -> List[BackupJob]:
def list_backup_jobs(self, tenant_id: str) -> list[BackupJob]:
"""列出租户的备份任务"""
with self._get_db() as conn:
rows = conn.execute(
@@ -2289,7 +2290,7 @@ class OpsManager:
).fetchall()
return [self._row_to_backup_job(row) for row in rows]
def execute_backup(self, job_id: str) -> Optional[BackupRecord]:
def execute_backup(self, job_id: str) -> BackupRecord | None:
"""执行备份"""
job = self.get_backup_job(job_id)
if not job or not job.is_enabled:
@@ -2354,7 +2355,7 @@ class OpsManager:
)
conn.commit()
def get_backup_record(self, record_id: str) -> Optional[BackupRecord]:
def get_backup_record(self, record_id: str) -> BackupRecord | None:
"""获取备份记录"""
with self._get_db() as conn:
row = conn.execute("SELECT * FROM backup_records WHERE id = ?", (record_id,)).fetchone()
@@ -2363,7 +2364,7 @@ class OpsManager:
return self._row_to_backup_record(row)
return None
def list_backup_records(self, tenant_id: str, job_id: str = None, limit: int = 100) -> List[BackupRecord]:
def list_backup_records(self, tenant_id: str, job_id: str = None, limit: int = 100) -> list[BackupRecord]:
"""列出租户的备份记录"""
query = "SELECT * FROM backup_records WHERE tenant_id = ?"
params = [tenant_id]
@@ -2452,7 +2453,7 @@ class OpsManager:
return report
def _detect_cost_anomalies(self, utilizations: List[ResourceUtilization]) -> List[Dict]:
def _detect_cost_anomalies(self, utilizations: list[ResourceUtilization]) -> list[dict]:
"""检测成本异常"""
anomalies = []
@@ -2483,7 +2484,7 @@ class OpsManager:
return anomalies
def _calculate_cost_trends(self, tenant_id: str, year: int, month: int) -> Dict:
def _calculate_cost_trends(self, tenant_id: str, year: int, month: int) -> dict:
"""计算成本趋势"""
# 简化实现:返回模拟趋势
return {"month_over_month": 0.05, "year_over_year": 0.15, "forecast_next_month": 1.05} # 5% 增长 # 15% 增长
@@ -2498,7 +2499,7 @@ class OpsManager:
avg_utilization: float,
idle_time_percent: float,
report_date: str,
recommendations: List[str] = None,
recommendations: list[str] = None,
) -> ResourceUtilization:
"""记录资源利用率"""
util_id = f"ru_{uuid.uuid4().hex[:16]}"
@@ -2541,7 +2542,7 @@ class OpsManager:
return util
def get_resource_utilizations(self, tenant_id: str, report_period: str) -> List[ResourceUtilization]:
def get_resource_utilizations(self, tenant_id: str, report_period: str) -> list[ResourceUtilization]:
"""获取资源利用率列表"""
with self._get_db() as conn:
rows = conn.execute(
@@ -2552,7 +2553,7 @@ class OpsManager:
).fetchall()
return [self._row_to_resource_utilization(row) for row in rows]
def detect_idle_resources(self, tenant_id: str) -> List[IdleResource]:
def detect_idle_resources(self, tenant_id: str) -> list[IdleResource]:
"""检测闲置资源"""
idle_resources = []
@@ -2615,7 +2616,7 @@ class OpsManager:
return idle_resources
def get_idle_resources(self, tenant_id: str) -> List[IdleResource]:
def get_idle_resources(self, tenant_id: str) -> list[IdleResource]:
"""获取闲置资源列表"""
with self._get_db() as conn:
rows = conn.execute(
@@ -2623,7 +2624,7 @@ class OpsManager:
).fetchall()
return [self._row_to_idle_resource(row) for row in rows]
def generate_cost_optimization_suggestions(self, tenant_id: str) -> List[CostOptimizationSuggestion]:
def generate_cost_optimization_suggestions(self, tenant_id: str) -> list[CostOptimizationSuggestion]:
"""生成成本优化建议"""
suggestions = []
@@ -2691,7 +2692,7 @@ class OpsManager:
def get_cost_optimization_suggestions(
self, tenant_id: str, is_applied: bool = None
) -> List[CostOptimizationSuggestion]:
) -> list[CostOptimizationSuggestion]:
"""获取成本优化建议"""
query = "SELECT * FROM cost_optimization_suggestions WHERE tenant_id = ?"
params = [tenant_id]
@@ -2706,7 +2707,7 @@ class OpsManager:
rows = conn.execute(query, params).fetchall()
return [self._row_to_cost_optimization_suggestion(row) for row in rows]
def apply_cost_optimization_suggestion(self, suggestion_id: str) -> Optional[CostOptimizationSuggestion]:
def apply_cost_optimization_suggestion(self, suggestion_id: str) -> CostOptimizationSuggestion | None:
"""应用成本优化建议"""
now = datetime.now().isoformat()
@@ -2723,7 +2724,7 @@ class OpsManager:
return self.get_cost_optimization_suggestion(suggestion_id)
def get_cost_optimization_suggestion(self, suggestion_id: str) -> Optional[CostOptimizationSuggestion]:
def get_cost_optimization_suggestion(self, suggestion_id: str) -> CostOptimizationSuggestion | None:
"""获取成本优化建议详情"""
with self._get_db() as conn:
row = conn.execute("SELECT * FROM cost_optimization_suggestions WHERE id = ?", (suggestion_id,)).fetchone()

View File

@@ -6,6 +6,7 @@ OSS 上传工具 - 用于阿里听悟音频上传
import os
import uuid
from datetime import datetime
import oss2

View File

@@ -9,18 +9,19 @@ Phase 7 Task 8: Performance Optimization & Scaling
4. PerformanceMonitor - 性能监控API响应、查询性能、缓存命中率
"""
import os
import json
import time
import hashlib
import json
import os
import sqlite3
import threading
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Any, Callable, Tuple
from datetime import datetime
from collections import OrderedDict
from functools import wraps
import time
import uuid
from collections import OrderedDict
from collections.abc import Callable
from dataclasses import dataclass, field
from datetime import datetime
from functools import wraps
from typing import Any
# 尝试导入 Redis
try:
@@ -67,7 +68,7 @@ class CacheEntry:
key: str
value: Any
created_at: float
expires_at: Optional[float]
expires_at: float | None
access_count: int = 0
last_accessed: float = 0
size_bytes: int = 0
@@ -79,12 +80,12 @@ class PerformanceMetric:
id: str
metric_type: str # api_response, db_query, cache_operation
endpoint: Optional[str]
endpoint: str | None
duration_ms: float
timestamp: str
metadata: Dict = field(default_factory=dict)
metadata: dict = field(default_factory=dict)
def to_dict(self) -> Dict:
def to_dict(self) -> dict:
return {
"id": self.id,
"metric_type": self.metric_type,
@@ -102,16 +103,16 @@ class TaskInfo:
id: str
task_type: str
status: str # pending, running, success, failed, retrying
payload: Dict
payload: dict
created_at: str
started_at: Optional[str] = None
completed_at: Optional[str] = None
result: Optional[Any] = None
error_message: Optional[str] = None
started_at: str | None = None
completed_at: str | None = None
result: Any | None = None
error_message: str | None = None
retry_count: int = 0
max_retries: int = 3
def to_dict(self) -> Dict:
def to_dict(self) -> dict:
return {
"id": self.id,
"task_type": self.task_type,
@@ -132,7 +133,7 @@ class ShardInfo:
"""分片信息数据模型"""
shard_id: str
shard_key_range: Tuple[str, str] # (start, end)
shard_key_range: tuple[str, str] # (start, end)
db_path: str
entity_count: int = 0
is_active: bool = True
@@ -160,7 +161,7 @@ class CacheManager:
def __init__(
self,
redis_url: Optional[str] = None,
redis_url: str | None = None,
max_memory_size: int = 100 * 1024 * 1024, # 100MB
default_ttl: int = 3600, # 1小时
db_path: str = "insightflow.db",
@@ -242,7 +243,7 @@ class CacheManager:
self.current_memory_size -= oldest_entry.size_bytes
self.stats.evictions += 1
def get(self, key: str) -> Optional[Any]:
def get(self, key: str) -> Any | None:
"""
获取缓存值
@@ -291,7 +292,7 @@ class CacheManager:
self.stats.misses += 1
return None
def set(self, key: str, value: Any, ttl: Optional[int] = None) -> bool:
def set(self, key: str, value: Any, ttl: int | None = None) -> bool:
"""
设置缓存值
@@ -373,7 +374,7 @@ class CacheManager:
self.current_memory_size = 0
return True
def get_many(self, keys: List[str]) -> Dict[str, Any]:
def get_many(self, keys: list[str]) -> dict[str, Any]:
"""批量获取缓存"""
results = {}
@@ -397,7 +398,7 @@ class CacheManager:
return results
def set_many(self, mapping: Dict[str, Any], ttl: Optional[int] = None) -> bool:
def set_many(self, mapping: dict[str, Any], ttl: int | None = None) -> bool:
"""批量设置缓存"""
ttl = ttl or self.default_ttl
@@ -417,7 +418,7 @@ class CacheManager:
self.set(key, value, ttl)
return True
def get_stats(self) -> Dict:
def get_stats(self) -> dict:
"""获取缓存统计"""
self.stats.update_hit_rate()
@@ -468,7 +469,7 @@ class CacheManager:
conn.commit()
conn.close()
def warm_up(self, project_id: str) -> Dict:
def warm_up(self, project_id: str) -> dict:
"""
缓存预热 - 加载项目的热点数据
@@ -612,7 +613,7 @@ class DatabaseSharding:
os.makedirs(shard_db_dir, exist_ok=True)
# 分片映射
self.shard_map: Dict[str, ShardInfo] = {}
self.shard_map: dict[str, ShardInfo] = {}
# 初始化分片
self._init_shards()
@@ -708,7 +709,7 @@ class DatabaseSharding:
return conn
def get_all_shards(self) -> List[ShardInfo]:
def get_all_shards(self) -> list[ShardInfo]:
"""获取所有分片信息"""
return list(self.shard_map.values())
@@ -805,7 +806,7 @@ class DatabaseSharding:
conn.close()
def cross_shard_query(self, query_func: Callable) -> List[Dict]:
def cross_shard_query(self, query_func: Callable) -> list[dict]:
"""
跨分片查询
@@ -831,7 +832,7 @@ class DatabaseSharding:
return results
def get_shard_stats(self) -> List[Dict]:
def get_shard_stats(self) -> list[dict]:
"""获取所有分片的统计信息"""
stats = []
@@ -852,7 +853,7 @@ class DatabaseSharding:
return stats
def rebalance_shards(self) -> Dict:
def rebalance_shards(self) -> dict:
"""
重新平衡分片
@@ -899,15 +900,15 @@ class TaskQueue:
- 任务状态追踪和重试机制
"""
def __init__(self, redis_url: Optional[str] = None, db_path: str = "insightflow.db"):
def __init__(self, redis_url: str | None = None, db_path: str = "insightflow.db"):
self.db_path = db_path
self.redis_url = redis_url
self.celery_app = None
self.use_celery = False
# 内存任务存储(非 Celery 模式)
self.tasks: Dict[str, TaskInfo] = {}
self.task_handlers: Dict[str, Callable] = {}
self.tasks: dict[str, TaskInfo] = {}
self.task_handlers: dict[str, Callable] = {}
self.task_lock = threading.RLock()
# 初始化任务队列表
@@ -956,7 +957,7 @@ class TaskQueue:
"""注册任务处理器"""
self.task_handlers[task_type] = handler
def submit(self, task_type: str, payload: Dict, max_retries: int = 3) -> str:
def submit(self, task_type: str, payload: dict, max_retries: int = 3) -> str:
"""
提交任务
@@ -1112,7 +1113,7 @@ class TaskQueue:
conn.commit()
conn.close()
def get_status(self, task_id: str) -> Optional[TaskInfo]:
def get_status(self, task_id: str) -> TaskInfo | None:
"""获取任务状态"""
if self.use_celery:
try:
@@ -1143,8 +1144,8 @@ class TaskQueue:
return self.tasks.get(task_id)
def list_tasks(
self, status: Optional[str] = None, task_type: Optional[str] = None, limit: int = 100
) -> List[TaskInfo]:
self, status: str | None = None, task_type: str | None = None, limit: int = 100
) -> list[TaskInfo]:
"""列出任务"""
conn = sqlite3.connect(self.db_path)
conn.row_factory = sqlite3.Row
@@ -1233,7 +1234,7 @@ class TaskQueue:
self._update_task_status(task)
return True
def get_stats(self) -> Dict:
def get_stats(self) -> dict:
"""获取任务队列统计"""
conn = sqlite3.connect(self.db_path)
@@ -1290,15 +1291,15 @@ class PerformanceMonitor:
self.alert_threshold = alert_threshold
# 内存中的指标缓存
self.metrics_buffer: List[PerformanceMetric] = []
self.metrics_buffer: list[PerformanceMetric] = []
self.buffer_lock = threading.RLock()
self.buffer_size = 100
# 告警回调
self.alert_handlers: List[Callable] = []
self.alert_handlers: list[Callable] = []
def record_metric(
self, metric_type: str, duration_ms: float, endpoint: Optional[str] = None, metadata: Optional[Dict] = None
self, metric_type: str, duration_ms: float, endpoint: str | None = None, metadata: dict | None = None
):
"""
记录性能指标
@@ -1385,7 +1386,7 @@ class PerformanceMonitor:
"""注册告警处理器"""
self.alert_handlers.append(handler)
def get_stats(self, hours: int = 24) -> Dict:
def get_stats(self, hours: int = 24) -> dict:
"""
获取性能统计
@@ -1504,7 +1505,7 @@ class PerformanceMonitor:
],
}
def get_api_performance(self, endpoint: Optional[str] = None, hours: int = 24) -> Dict:
def get_api_performance(self, endpoint: str | None = None, hours: int = 24) -> dict:
"""获取 API 性能详情"""
self._flush_metrics()
@@ -1584,7 +1585,7 @@ class PerformanceMonitor:
# ==================== 性能装饰器 ====================
def cached(cache_manager: CacheManager, key_prefix: str = "", ttl: int = 3600, key_func: Optional[Callable] = None):
def cached(cache_manager: CacheManager, key_prefix: str = "", ttl: int = 3600, key_func: Callable | None = None):
"""
缓存装饰器
@@ -1624,7 +1625,7 @@ def cached(cache_manager: CacheManager, key_prefix: str = "", ttl: int = 3600, k
return decorator
def monitored(monitor: PerformanceMonitor, metric_type: str, endpoint: Optional[str] = None):
def monitored(monitor: PerformanceMonitor, metric_type: str, endpoint: str | None = None):
"""
性能监控装饰器
@@ -1662,7 +1663,7 @@ class PerformanceManager:
整合缓存管理、数据库分片、任务队列和性能监控功能
"""
def __init__(self, db_path: str = "insightflow.db", redis_url: Optional[str] = None, enable_sharding: bool = False):
def __init__(self, db_path: str = "insightflow.db", redis_url: str | None = None, enable_sharding: bool = False):
self.db_path = db_path
# 初始化各模块
@@ -1674,7 +1675,7 @@ class PerformanceManager:
self.monitor = PerformanceMonitor(db_path=db_path)
def get_health_status(self) -> Dict:
def get_health_status(self) -> dict:
"""获取系统健康状态"""
return {
"cache": {
@@ -1698,7 +1699,7 @@ class PerformanceManager:
},
}
def get_full_stats(self) -> Dict:
def get_full_stats(self) -> dict:
"""获取完整统计信息"""
stats = {
"cache": self.cache.get_stats(),
@@ -1717,7 +1718,7 @@ _performance_manager = None
def get_performance_manager(
db_path: str = "insightflow.db", redis_url: Optional[str] = None, enable_sharding: bool = False
db_path: str = "insightflow.db", redis_url: str | None = None, enable_sharding: bool = False
) -> PerformanceManager:
"""获取性能管理器单例"""
global _performance_manager

View File

@@ -4,20 +4,21 @@ InsightFlow Plugin Manager - Phase 7 Task 7
插件与集成系统Chrome插件、飞书/钉钉机器人、Zapier/Make集成、WebDAV同步
"""
import os
import json
import base64
import hashlib
import hmac
import base64
import time
import uuid
import httpx
import urllib.parse
from datetime import datetime
from typing import Dict, List, Optional, Any
from dataclasses import dataclass, field
from enum import Enum
import json
import os
import sqlite3
import time
import urllib.parse
import uuid
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Any
import httpx
# WebDAV 支持
try:
@@ -58,10 +59,10 @@ class Plugin:
plugin_type: str
project_id: str
status: str = "active"
config: Dict = field(default_factory=dict)
config: dict = field(default_factory=dict)
created_at: str = ""
updated_at: str = ""
last_used_at: Optional[str] = None
last_used_at: str | None = None
use_count: int = 0
@@ -86,13 +87,13 @@ class BotSession:
bot_type: str # feishu, dingtalk
session_id: str # 群ID或会话ID
session_name: str
project_id: Optional[str] = None
project_id: str | None = None
webhook_url: str = ""
secret: str = ""
is_active: bool = True
created_at: str = ""
updated_at: str = ""
last_message_at: Optional[str] = None
last_message_at: str | None = None
message_count: int = 0
@@ -104,14 +105,14 @@ class WebhookEndpoint:
name: str
endpoint_type: str # zapier, make, custom
endpoint_url: str
project_id: Optional[str] = None
project_id: str | None = None
auth_type: str = "none" # none, api_key, oauth, custom
auth_config: Dict = field(default_factory=dict)
trigger_events: List[str] = field(default_factory=list)
auth_config: dict = field(default_factory=dict)
trigger_events: list[str] = field(default_factory=list)
is_active: bool = True
created_at: str = ""
updated_at: str = ""
last_triggered_at: Optional[str] = None
last_triggered_at: str | None = None
trigger_count: int = 0
@@ -128,7 +129,7 @@ class WebDAVSync:
remote_path: str = "/insightflow"
sync_mode: str = "bidirectional" # bidirectional, upload_only, download_only
sync_interval: int = 3600 # 秒
last_sync_at: Optional[str] = None
last_sync_at: str | None = None
last_sync_status: str = "pending" # pending, success, failed
last_sync_error: str = ""
is_active: bool = True
@@ -143,13 +144,13 @@ class ChromeExtensionToken:
id: str
token: str
user_id: Optional[str] = None
project_id: Optional[str] = None
user_id: str | None = None
project_id: str | None = None
name: str = ""
permissions: List[str] = field(default_factory=lambda: ["read", "write"])
expires_at: Optional[str] = None
permissions: list[str] = field(default_factory=lambda: ["read", "write"])
expires_at: str | None = None
created_at: str = ""
last_used_at: Optional[str] = None
last_used_at: str | None = None
use_count: int = 0
is_revoked: bool = False
@@ -171,7 +172,7 @@ class PluginManager:
self._handlers[PluginType.MAKE] = WebhookIntegration(self, "make")
self._handlers[PluginType.WEBDAV] = WebDAVSyncManager(self)
def get_handler(self, plugin_type: PluginType) -> Optional[Any]:
def get_handler(self, plugin_type: PluginType) -> Any | None:
"""获取插件处理器"""
return self._handlers.get(plugin_type)
@@ -205,7 +206,7 @@ class PluginManager:
plugin.updated_at = now
return plugin
def get_plugin(self, plugin_id: str) -> Optional[Plugin]:
def get_plugin(self, plugin_id: str) -> Plugin | None:
"""获取插件"""
conn = self.db.get_conn()
row = conn.execute("SELECT * FROM plugins WHERE id = ?", (plugin_id,)).fetchone()
@@ -215,7 +216,7 @@ class PluginManager:
return self._row_to_plugin(row)
return None
def list_plugins(self, project_id: str = None, plugin_type: str = None, status: str = None) -> List[Plugin]:
def list_plugins(self, project_id: str = None, plugin_type: str = None, status: str = None) -> list[Plugin]:
"""列出插件"""
conn = self.db.get_conn()
@@ -239,7 +240,7 @@ class PluginManager:
return [self._row_to_plugin(row) for row in rows]
def update_plugin(self, plugin_id: str, **kwargs) -> Optional[Plugin]:
def update_plugin(self, plugin_id: str, **kwargs) -> Plugin | None:
"""更新插件"""
conn = self.db.get_conn()
@@ -341,7 +342,7 @@ class PluginManager:
updated_at=now,
)
def get_plugin_config(self, plugin_id: str, key: str) -> Optional[str]:
def get_plugin_config(self, plugin_id: str, key: str) -> str | None:
"""获取插件配置"""
conn = self.db.get_conn()
row = conn.execute(
@@ -351,7 +352,7 @@ class PluginManager:
return row["config_value"] if row else None
def get_all_plugin_configs(self, plugin_id: str) -> Dict[str, str]:
def get_all_plugin_configs(self, plugin_id: str) -> dict[str, str]:
"""获取插件所有配置"""
conn = self.db.get_conn()
rows = conn.execute(
@@ -396,7 +397,7 @@ class ChromeExtensionHandler:
name: str,
user_id: str = None,
project_id: str = None,
permissions: List[str] = None,
permissions: list[str] = None,
expires_days: int = None,
) -> ChromeExtensionToken:
"""创建 Chrome 扩展令牌"""
@@ -448,7 +449,7 @@ class ChromeExtensionHandler:
created_at=now,
)
def validate_token(self, token: str) -> Optional[ChromeExtensionToken]:
def validate_token(self, token: str) -> ChromeExtensionToken | None:
"""验证 Chrome 扩展令牌"""
token_hash = hashlib.sha256(token.encode()).hexdigest()
@@ -501,7 +502,7 @@ class ChromeExtensionHandler:
return cursor.rowcount > 0
def list_tokens(self, user_id: str = None, project_id: str = None) -> List[ChromeExtensionToken]:
def list_tokens(self, user_id: str = None, project_id: str = None) -> list[ChromeExtensionToken]:
"""列出令牌"""
conn = self.pm.db.get_conn()
@@ -544,7 +545,7 @@ class ChromeExtensionHandler:
async def import_webpage(
self, token: ChromeExtensionToken, url: str, title: str, content: str, html_content: str = None
) -> Dict:
) -> dict:
"""导入网页内容"""
if not token.project_id:
return {"success": False, "error": "Token not associated with any project"}
@@ -617,7 +618,7 @@ class BotHandler:
updated_at=now,
)
def get_session(self, session_id: str) -> Optional[BotSession]:
def get_session(self, session_id: str) -> BotSession | None:
"""获取会话"""
conn = self.pm.db.get_conn()
row = conn.execute(
@@ -631,7 +632,7 @@ class BotHandler:
return self._row_to_session(row)
return None
def list_sessions(self, project_id: str = None) -> List[BotSession]:
def list_sessions(self, project_id: str = None) -> list[BotSession]:
"""列出会话"""
conn = self.pm.db.get_conn()
@@ -652,7 +653,7 @@ class BotHandler:
return [self._row_to_session(row) for row in rows]
def update_session(self, session_id: str, **kwargs) -> Optional[BotSession]:
def update_session(self, session_id: str, **kwargs) -> BotSession | None:
"""更新会话"""
conn = self.pm.db.get_conn()
@@ -709,7 +710,7 @@ class BotHandler:
message_count=row["message_count"],
)
async def handle_message(self, session: BotSession, message: Dict) -> Dict:
async def handle_message(self, session: BotSession, message: dict) -> dict:
"""处理收到的消息"""
now = datetime.now().isoformat()
@@ -740,7 +741,7 @@ class BotHandler:
return {"success": False, "error": "Unsupported message type"}
async def _handle_text_message(self, session: BotSession, text: str, raw_message: Dict) -> Dict:
async def _handle_text_message(self, session: BotSession, text: str, raw_message: dict) -> dict:
"""处理文本消息"""
# 简单命令处理
if text.startswith("/help"):
@@ -772,7 +773,7 @@ class BotHandler:
# 默认回复
return {"success": True, "response": f"收到消息:{text[:100]}...\n\n使用 /help 查看可用命令"}
async def _handle_audio_message(self, session: BotSession, message: Dict) -> Dict:
async def _handle_audio_message(self, session: BotSession, message: dict) -> dict:
"""处理音频消息"""
if not session.project_id:
return {"success": False, "error": "Session not bound to any project"}
@@ -802,7 +803,7 @@ class BotHandler:
except Exception as e:
return {"success": False, "error": f"Failed to process audio: {str(e)}"}
async def _handle_file_message(self, session: BotSession, message: Dict) -> Dict:
async def _handle_file_message(self, session: BotSession, message: dict) -> dict:
"""处理文件消息"""
return {"success": True, "response": "📎 收到文件,正在处理中..."}
@@ -825,8 +826,8 @@ class BotHandler:
async def _send_feishu_message(self, session: BotSession, message: str, msg_type: str) -> bool:
"""发送飞书消息"""
import hashlib
import base64
import hashlib
timestamp = str(int(time.time()))
@@ -850,8 +851,8 @@ class BotHandler:
async def _send_dingtalk_message(self, session: BotSession, message: str, msg_type: str) -> bool:
"""发送钉钉消息"""
import hashlib
import base64
import hashlib
timestamp = str(round(time.time() * 1000))
@@ -890,8 +891,8 @@ class WebhookIntegration:
endpoint_url: str,
project_id: str = None,
auth_type: str = "none",
auth_config: Dict = None,
trigger_events: List[str] = None,
auth_config: dict = None,
trigger_events: list[str] = None,
) -> WebhookEndpoint:
"""创建 Webhook 端点"""
endpoint_id = str(uuid.uuid4())[:8]
@@ -935,7 +936,7 @@ class WebhookIntegration:
updated_at=now,
)
def get_endpoint(self, endpoint_id: str) -> Optional[WebhookEndpoint]:
def get_endpoint(self, endpoint_id: str) -> WebhookEndpoint | None:
"""获取端点"""
conn = self.pm.db.get_conn()
row = conn.execute(
@@ -947,7 +948,7 @@ class WebhookIntegration:
return self._row_to_endpoint(row)
return None
def list_endpoints(self, project_id: str = None) -> List[WebhookEndpoint]:
def list_endpoints(self, project_id: str = None) -> list[WebhookEndpoint]:
"""列出端点"""
conn = self.pm.db.get_conn()
@@ -968,7 +969,7 @@ class WebhookIntegration:
return [self._row_to_endpoint(row) for row in rows]
def update_endpoint(self, endpoint_id: str, **kwargs) -> Optional[WebhookEndpoint]:
def update_endpoint(self, endpoint_id: str, **kwargs) -> WebhookEndpoint | None:
"""更新端点"""
conn = self.pm.db.get_conn()
@@ -1034,7 +1035,7 @@ class WebhookIntegration:
trigger_count=row["trigger_count"],
)
async def trigger(self, endpoint: WebhookEndpoint, event_type: str, data: Dict) -> bool:
async def trigger(self, endpoint: WebhookEndpoint, event_type: str, data: dict) -> bool:
"""触发 Webhook"""
if not endpoint.is_active:
return False
@@ -1079,7 +1080,7 @@ class WebhookIntegration:
print(f"Failed to trigger webhook: {e}")
return False
async def test_endpoint(self, endpoint: WebhookEndpoint) -> Dict:
async def test_endpoint(self, endpoint: WebhookEndpoint) -> dict:
"""测试端点"""
test_data = {
"message": "This is a test event from InsightFlow",
@@ -1160,7 +1161,7 @@ class WebDAVSyncManager:
updated_at=now,
)
def get_sync(self, sync_id: str) -> Optional[WebDAVSync]:
def get_sync(self, sync_id: str) -> WebDAVSync | None:
"""获取同步配置"""
conn = self.pm.db.get_conn()
row = conn.execute("SELECT * FROM webdav_syncs WHERE id = ?", (sync_id,)).fetchone()
@@ -1170,7 +1171,7 @@ class WebDAVSyncManager:
return self._row_to_sync(row)
return None
def list_syncs(self, project_id: str = None) -> List[WebDAVSync]:
def list_syncs(self, project_id: str = None) -> list[WebDAVSync]:
"""列出同步配置"""
conn = self.pm.db.get_conn()
@@ -1185,7 +1186,7 @@ class WebDAVSyncManager:
return [self._row_to_sync(row) for row in rows]
def update_sync(self, sync_id: str, **kwargs) -> Optional[WebDAVSync]:
def update_sync(self, sync_id: str, **kwargs) -> WebDAVSync | None:
"""更新同步配置"""
conn = self.pm.db.get_conn()
@@ -1252,7 +1253,7 @@ class WebDAVSyncManager:
sync_count=row["sync_count"],
)
async def test_connection(self, sync: WebDAVSync) -> Dict:
async def test_connection(self, sync: WebDAVSync) -> dict:
"""测试 WebDAV 连接"""
if not WEBDAV_AVAILABLE:
return {"success": False, "error": "WebDAV library not available"}
@@ -1268,7 +1269,7 @@ class WebDAVSyncManager:
except Exception as e:
return {"success": False, "error": str(e)}
async def sync_project(self, sync: WebDAVSync) -> Dict:
async def sync_project(self, sync: WebDAVSync) -> dict:
"""同步项目到 WebDAV"""
if not WEBDAV_AVAILABLE:
return {"success": False, "error": "WebDAV library not available"}

View File

@@ -5,11 +5,11 @@ API 限流中间件
支持基于内存的滑动窗口限流
"""
import time
import asyncio
from typing import Dict, Optional, Callable
from dataclasses import dataclass
import time
from collections import defaultdict
from collections.abc import Callable
from dataclasses import dataclass
from functools import wraps
@@ -37,7 +37,7 @@ class SlidingWindowCounter:
def __init__(self, window_size: int = 60):
self.window_size = window_size
self.requests: Dict[int, int] = defaultdict(int) # 秒级计数
self.requests: dict[int, int] = defaultdict(int) # 秒级计数
self._lock = asyncio.Lock()
self._cleanup_lock = asyncio.Lock()
@@ -69,13 +69,13 @@ class RateLimiter:
def __init__(self):
# key -> SlidingWindowCounter
self.counters: Dict[str, SlidingWindowCounter] = {}
self.counters: dict[str, SlidingWindowCounter] = {}
# key -> RateLimitConfig
self.configs: Dict[str, RateLimitConfig] = {}
self.configs: dict[str, RateLimitConfig] = {}
self._lock = asyncio.Lock()
self._cleanup_lock = asyncio.Lock()
async def is_allowed(self, key: str, config: Optional[RateLimitConfig] = None) -> RateLimitInfo:
async def is_allowed(self, key: str, config: RateLimitConfig | None = None) -> RateLimitInfo:
"""
检查是否允许请求
@@ -143,7 +143,7 @@ class RateLimiter:
retry_after=max(0, config.window_size) if current_count >= config.requests_per_minute else 0,
)
def reset(self, key: Optional[str] = None):
def reset(self, key: str | None = None):
"""重置限流计数器"""
if key:
self.counters.pop(key, None)
@@ -154,7 +154,7 @@ class RateLimiter:
# 全局限流器实例
_rate_limiter: Optional[RateLimiter] = None
_rate_limiter: RateLimiter | None = None
def get_rate_limiter() -> RateLimiter:
@@ -166,7 +166,7 @@ def get_rate_limiter() -> RateLimiter:
# 限流装饰器(用于函数级别限流)
def rate_limit(requests_per_minute: int = 60, key_func: Optional[Callable] = None):
def rate_limit(requests_per_minute: int = 60, key_func: Callable | None = None):
"""
限流装饰器

View File

@@ -9,15 +9,14 @@ Phase 7 Task 6: Advanced Search & Discovery
4. KnowledgeGapDetection - 知识缺口识别
"""
import re
import hashlib
import json
import math
import re
import sqlite3
import hashlib
from dataclasses import dataclass, field
from typing import List, Dict, Optional, Tuple, Set
from datetime import datetime
from collections import defaultdict
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
@@ -46,10 +45,10 @@ class SearchResult:
content_type: str # transcript, entity, relation
project_id: str
score: float
highlights: List[Tuple[int, int]] = field(default_factory=list) # 高亮位置
metadata: Dict = field(default_factory=dict)
highlights: list[tuple[int, int]] = field(default_factory=list) # 高亮位置
metadata: dict = field(default_factory=dict)
def to_dict(self) -> Dict:
def to_dict(self) -> dict:
return {
"id": self.id,
"content": self.content,
@@ -69,10 +68,10 @@ class SemanticSearchResult:
content_type: str
project_id: str
similarity: float
embedding: Optional[List[float]] = None
metadata: Dict = field(default_factory=dict)
embedding: list[float] | None = None
metadata: dict = field(default_factory=dict)
def to_dict(self) -> Dict:
def to_dict(self) -> dict:
result = {
"id": self.id,
"content": self.content[:500] + "..." if len(self.content) > 500 else self.content,
@@ -95,12 +94,12 @@ class EntityPath:
target_entity_id: str
target_entity_name: str
path_length: int
nodes: List[Dict] # 路径上的节点
edges: List[Dict] # 路径上的边
nodes: list[dict] # 路径上的节点
edges: list[dict] # 路径上的边
confidence: float
path_description: str
def to_dict(self) -> Dict:
def to_dict(self) -> dict:
return {
"path_id": self.path_id,
"source_entity_id": self.source_entity_id,
@@ -120,15 +119,15 @@ class KnowledgeGap:
"""知识缺口数据模型"""
gap_id: str
gap_type: str # missing_attribute, sparse_relation, isolated_entity, incomplete_entity
entity_id: Optional[str]
entity_name: Optional[str]
entity_id: str | None
entity_name: str | None
description: str
severity: str # high, medium, low
suggestions: List[str]
related_entities: List[str]
metadata: Dict = field(default_factory=dict)
suggestions: list[str]
related_entities: list[str]
metadata: dict = field(default_factory=dict)
def to_dict(self) -> Dict:
def to_dict(self) -> dict:
return {
"gap_id": self.gap_id,
"gap_type": self.gap_type,
@@ -149,8 +148,8 @@ class SearchIndex:
content_id: str
content_type: str
project_id: str
tokens: List[str]
token_positions: Dict[str, List[int]] # 词 -> 位置列表
tokens: list[str]
token_positions: dict[str, list[int]] # 词 -> 位置列表
created_at: str
updated_at: str
@@ -162,7 +161,7 @@ class TextEmbedding:
content_id: str
content_type: str
project_id: str
embedding: List[float]
embedding: list[float]
model_name: str
created_at: str
@@ -231,7 +230,7 @@ class FullTextSearch:
conn.commit()
conn.close()
def _tokenize(self, text: str) -> List[str]:
def _tokenize(self, text: str) -> list[str]:
"""
中文分词(简化版)
@@ -243,7 +242,7 @@ class FullTextSearch:
tokens = re.findall(r'[\u4e00-\u9fa5]+|[a-z]+|\d+', text)
return tokens
def _extract_positions(self, text: str, tokens: List[str]) -> Dict[str, List[int]]:
def _extract_positions(self, text: str, tokens: list[str]) -> dict[str, list[int]]:
"""提取每个词在文本中的位置"""
positions = defaultdict(list)
text_lower = text.lower()
@@ -326,9 +325,9 @@ class FullTextSearch:
print(f"索引创建失败: {e}")
return False
def search(self, query: str, project_id: Optional[str] = None,
content_types: Optional[List[str]] = None,
limit: int = 20, offset: int = 0) -> List[SearchResult]:
def search(self, query: str, project_id: str | None = None,
content_types: list[str] | None = None,
limit: int = 20, offset: int = 0) -> list[SearchResult]:
"""
全文搜索
@@ -358,7 +357,7 @@ class FullTextSearch:
return scored_results[offset:offset + limit]
def _parse_boolean_query(self, query: str) -> Dict:
def _parse_boolean_query(self, query: str) -> dict:
"""
解析布尔查询
@@ -401,9 +400,9 @@ class FullTextSearch:
"phrases": phrases
}
def _execute_boolean_search(self, parsed_query: Dict,
project_id: Optional[str] = None,
content_types: Optional[List[str]] = None) -> List[Dict]:
def _execute_boolean_search(self, parsed_query: dict,
project_id: str | None = None,
content_types: list[str] | None = None) -> list[dict]:
"""执行布尔搜索"""
conn = self._get_conn()
@@ -503,7 +502,7 @@ class FullTextSearch:
return results
def _get_content_by_id(self, conn: sqlite3.Connection,
content_id: str, content_type: str) -> Optional[str]:
content_id: str, content_type: str) -> str | None:
"""根据ID获取内容"""
try:
if content_type == "transcript":
@@ -542,7 +541,7 @@ class FullTextSearch:
return None
def _get_project_id(self, conn: sqlite3.Connection,
content_id: str, content_type: str) -> Optional[str]:
content_id: str, content_type: str) -> str | None:
"""获取内容所属的项目ID"""
try:
if content_type == "transcript":
@@ -567,7 +566,7 @@ class FullTextSearch:
except Exception:
return None
def _score_results(self, results: List[Dict], parsed_query: Dict) -> List[SearchResult]:
def _score_results(self, results: list[dict], parsed_query: dict) -> list[SearchResult]:
"""计算搜索结果的相关性分数"""
scored = []
all_terms = parsed_query["and"] + parsed_query["or"] + parsed_query["phrases"]
@@ -689,7 +688,7 @@ class FullTextSearch:
print(f"删除索引失败: {e}")
return False
def reindex_project(self, project_id: str) -> Dict:
def reindex_project(self, project_id: str) -> dict:
"""重新索引整个项目"""
conn = self._get_conn()
stats = {"transcripts": 0, "entities": 0, "relations": 0, "errors": 0}
@@ -808,7 +807,7 @@ class SemanticSearch:
"""检查语义搜索是否可用"""
return self.model is not None and SENTENCE_TRANSFORMERS_AVAILABLE
def generate_embedding(self, text: str) -> Optional[List[float]]:
def generate_embedding(self, text: str) -> list[float] | None:
"""
生成文本的 embedding 向量
@@ -878,9 +877,9 @@ class SemanticSearch:
print(f"索引 embedding 失败: {e}")
return False
def search(self, query: str, project_id: Optional[str] = None,
content_types: Optional[List[str]] = None,
top_k: int = 10, threshold: float = 0.5) -> List[SemanticSearchResult]:
def search(self, query: str, project_id: str | None = None,
content_types: list[str] | None = None,
top_k: int = 10, threshold: float = 0.5) -> list[SemanticSearchResult]:
"""
语义搜索
@@ -959,7 +958,7 @@ class SemanticSearch:
results.sort(key=lambda x: x.similarity, reverse=True)
return results[:top_k]
def _get_content_text(self, content_id: str, content_type: str) -> Optional[str]:
def _get_content_text(self, content_id: str, content_type: str) -> str | None:
"""获取内容文本"""
conn = self._get_conn()
@@ -1002,7 +1001,7 @@ class SemanticSearch:
return None
def find_similar_content(self, content_id: str, content_type: str,
top_k: int = 5) -> List[SemanticSearchResult]:
top_k: int = 5) -> list[SemanticSearchResult]:
"""
查找与指定内容相似的内容
@@ -1107,7 +1106,7 @@ class EntityPathDiscovery:
def find_shortest_path(self, source_entity_id: str,
target_entity_id: str,
max_depth: int = 5) -> Optional[EntityPath]:
max_depth: int = 5) -> EntityPath | None:
"""
查找两个实体之间的最短路径BFS算法
@@ -1181,7 +1180,7 @@ class EntityPathDiscovery:
def find_all_paths(self, source_entity_id: str,
target_entity_id: str,
max_depth: int = 4,
max_paths: int = 10) -> List[EntityPath]:
max_paths: int = 10) -> list[EntityPath]:
"""
查找两个实体之间的所有路径(限制数量和深度)
@@ -1211,7 +1210,7 @@ class EntityPathDiscovery:
paths = []
def dfs(current_id: str, target_id: str,
path: List[str], visited: Set[str], depth: int):
path: list[str], visited: set[str], depth: int):
if depth > max_depth:
return
@@ -1247,7 +1246,7 @@ class EntityPathDiscovery:
# 构建路径对象
return [self._build_path_object(path, project_id) for path in paths]
def _build_path_object(self, entity_ids: List[str],
def _build_path_object(self, entity_ids: list[str],
project_id: str) -> EntityPath:
"""构建路径对象"""
conn = self._get_conn()
@@ -1312,7 +1311,7 @@ class EntityPathDiscovery:
)
def find_multi_hop_relations(self, entity_id: str,
max_hops: int = 3) -> List[Dict]:
max_hops: int = 3) -> list[dict]:
"""
查找实体的多跳关系
@@ -1394,7 +1393,7 @@ class EntityPathDiscovery:
return relations
def _get_path_to_entity(self, source_id: str, target_id: str,
project_id: str, conn: sqlite3.Connection) -> List[str]:
project_id: str, conn: sqlite3.Connection) -> list[str]:
"""获取从源实体到目标实体的路径(简化版)"""
# BFS 找路径
visited = {source_id}
@@ -1428,7 +1427,7 @@ class EntityPathDiscovery:
return []
def generate_path_visualization(self, path: EntityPath) -> Dict:
def generate_path_visualization(self, path: EntityPath) -> dict:
"""
生成路径可视化数据
@@ -1467,7 +1466,7 @@ class EntityPathDiscovery:
"confidence": path.confidence
}
def analyze_path_centrality(self, project_id: str) -> List[Dict]:
def analyze_path_centrality(self, project_id: str) -> list[dict]:
"""
分析项目中实体的路径中心性(桥接程度)
@@ -1558,7 +1557,7 @@ class KnowledgeGapDetection:
conn.row_factory = sqlite3.Row
return conn
def analyze_project(self, project_id: str) -> List[KnowledgeGap]:
def analyze_project(self, project_id: str) -> list[KnowledgeGap]:
"""
分析项目中的知识缺口
@@ -1591,7 +1590,7 @@ class KnowledgeGapDetection:
return gaps
def _check_entity_attribute_completeness(self, project_id: str) -> List[KnowledgeGap]:
def _check_entity_attribute_completeness(self, project_id: str) -> list[KnowledgeGap]:
"""检查实体属性完整性"""
conn = self._get_conn()
gaps = []
@@ -1661,7 +1660,7 @@ class KnowledgeGapDetection:
conn.close()
return gaps
def _check_relation_sparsity(self, project_id: str) -> List[KnowledgeGap]:
def _check_relation_sparsity(self, project_id: str) -> list[KnowledgeGap]:
"""检查关系稀疏度"""
conn = self._get_conn()
gaps = []
@@ -1720,7 +1719,7 @@ class KnowledgeGapDetection:
conn.close()
return gaps
def _check_isolated_entities(self, project_id: str) -> List[KnowledgeGap]:
def _check_isolated_entities(self, project_id: str) -> list[KnowledgeGap]:
"""检查孤立实体(没有任何关系)"""
conn = self._get_conn()
gaps = []
@@ -1756,7 +1755,7 @@ class KnowledgeGapDetection:
conn.close()
return gaps
def _check_incomplete_entities(self, project_id: str) -> List[KnowledgeGap]:
def _check_incomplete_entities(self, project_id: str) -> list[KnowledgeGap]:
"""检查不完整实体(缺少名称、类型或定义)"""
conn = self._get_conn()
gaps = []
@@ -1788,7 +1787,7 @@ class KnowledgeGapDetection:
conn.close()
return gaps
def _check_missing_key_entities(self, project_id: str) -> List[KnowledgeGap]:
def _check_missing_key_entities(self, project_id: str) -> list[KnowledgeGap]:
"""检查可能缺失的关键实体"""
conn = self._get_conn()
gaps = []
@@ -1841,7 +1840,7 @@ class KnowledgeGapDetection:
conn.close()
return gaps[:10] # 限制数量
def generate_completeness_report(self, project_id: str) -> Dict:
def generate_completeness_report(self, project_id: str) -> dict:
"""
生成知识完整性报告
@@ -1898,7 +1897,7 @@ class KnowledgeGapDetection:
"recommendations": self._generate_recommendations(gaps)
}
def _generate_recommendations(self, gaps: List[KnowledgeGap]) -> List[str]:
def _generate_recommendations(self, gaps: list[KnowledgeGap]) -> list[str]:
"""生成改进建议"""
recommendations = []
@@ -1941,8 +1940,8 @@ class SearchManager:
self.path_discovery = EntityPathDiscovery(db_path)
self.gap_detection = KnowledgeGapDetection(db_path)
def hybrid_search(self, query: str, project_id: Optional[str] = None,
limit: int = 20) -> Dict:
def hybrid_search(self, query: str, project_id: str | None = None,
limit: int = 20) -> dict:
"""
混合搜索(全文 + 语义)
@@ -2014,7 +2013,7 @@ class SearchManager:
"results": results[:limit]
}
def index_project(self, project_id: str) -> Dict:
def index_project(self, project_id: str) -> dict:
"""
为项目建立所有索引
@@ -2071,7 +2070,7 @@ class SearchManager:
"semantic": semantic_stats
}
def get_search_stats(self, project_id: Optional[str] = None) -> Dict:
def get_search_stats(self, project_id: str | None = None) -> dict:
"""获取搜索统计信息"""
conn = sqlite3.connect(self.db_path)
conn.row_factory = sqlite3.Row
@@ -2126,28 +2125,28 @@ def get_search_manager(db_path: str = "insightflow.db") -> SearchManager:
# 便捷函数
def fulltext_search(query: str, project_id: Optional[str] = None,
limit: int = 20) -> List[SearchResult]:
def fulltext_search(query: str, project_id: str | None = None,
limit: int = 20) -> list[SearchResult]:
"""全文搜索便捷函数"""
manager = get_search_manager()
return manager.fulltext_search.search(query, project_id, limit=limit)
def semantic_search(query: str, project_id: Optional[str] = None,
top_k: int = 10) -> List[SemanticSearchResult]:
def semantic_search(query: str, project_id: str | None = None,
top_k: int = 10) -> list[SemanticSearchResult]:
"""语义搜索便捷函数"""
manager = get_search_manager()
return manager.semantic_search.search(query, project_id, top_k=top_k)
def find_entity_path(source_id: str, target_id: str,
max_depth: int = 5) -> Optional[EntityPath]:
max_depth: int = 5) -> EntityPath | None:
"""查找实体路径便捷函数"""
manager = get_search_manager()
return manager.path_discovery.find_shortest_path(source_id, target_id, max_depth)
def detect_knowledge_gaps(project_id: str) -> List[KnowledgeGap]:
def detect_knowledge_gaps(project_id: str) -> list[KnowledgeGap]:
"""知识缺口检测便捷函数"""
manager = get_search_manager()
return manager.gap_detection.analyze_project(project_id)

View File

@@ -3,16 +3,16 @@ InsightFlow Phase 7 Task 3: 数据安全与合规模块
Security Manager - 端到端加密、数据脱敏、审计日志
"""
import json
import hashlib
import secrets
import base64
import hashlib
import json
import re
from datetime import datetime, timedelta
from typing import List, Optional, Dict, Any, Tuple
from dataclasses import dataclass, field, asdict
from enum import Enum
import secrets
import sqlite3
from dataclasses import asdict, dataclass, field
from datetime import datetime, timedelta
from enum import Enum
from typing import Any
# 加密相关
try:
@@ -71,19 +71,19 @@ class AuditLog:
"""审计日志条目"""
id: str
action_type: str
user_id: Optional[str] = None
user_ip: Optional[str] = None
user_agent: Optional[str] = None
resource_type: Optional[str] = None # project, entity, transcript, etc.
resource_id: Optional[str] = None
action_details: Optional[str] = None # JSON string
before_value: Optional[str] = None
after_value: Optional[str] = None
user_id: str | None = None
user_ip: str | None = None
user_agent: str | None = None
resource_type: str | None = None # project, entity, transcript, etc.
resource_id: str | None = None
action_details: str | None = None # JSON string
before_value: str | None = None
after_value: str | None = None
success: bool = True
error_message: Optional[str] = None
error_message: str | None = None
created_at: str = field(default_factory=lambda: datetime.now().isoformat())
def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
return asdict(self)
@@ -95,12 +95,12 @@ class EncryptionConfig:
is_enabled: bool = False
encryption_type: str = "aes-256-gcm" # aes-256-gcm, chacha20-poly1305
key_derivation: str = "pbkdf2" # pbkdf2, argon2
master_key_hash: Optional[str] = None # 主密钥哈希(用于验证)
salt: Optional[str] = None
master_key_hash: str | None = None # 主密钥哈希(用于验证)
salt: str | None = None
created_at: str = field(default_factory=lambda: datetime.now().isoformat())
updated_at: str = field(default_factory=lambda: datetime.now().isoformat())
def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
return asdict(self)
@@ -115,11 +115,11 @@ class MaskingRule:
replacement: str # 替换模板,如 "****"
is_active: bool = True
priority: int = 0
description: Optional[str] = None
description: str | None = None
created_at: str = field(default_factory=lambda: datetime.now().isoformat())
updated_at: str = field(default_factory=lambda: datetime.now().isoformat())
def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
return asdict(self)
@@ -129,18 +129,18 @@ class DataAccessPolicy:
id: str
project_id: str
name: str
description: Optional[str] = None
allowed_users: Optional[str] = None # JSON array of user IDs
allowed_roles: Optional[str] = None # JSON array of roles
allowed_ips: Optional[str] = None # JSON array of IP patterns
time_restrictions: Optional[str] = None # JSON: {"start_time": "09:00", "end_time": "18:00"}
max_access_count: Optional[int] = None # 最大访问次数
description: str | None = None
allowed_users: str | None = None # JSON array of user IDs
allowed_roles: str | None = None # JSON array of roles
allowed_ips: str | None = None # JSON array of IP patterns
time_restrictions: str | None = None # JSON: {"start_time": "09:00", "end_time": "18:00"}
max_access_count: int | None = None # 最大访问次数
require_approval: bool = False
is_active: bool = True
created_at: str = field(default_factory=lambda: datetime.now().isoformat())
updated_at: str = field(default_factory=lambda: datetime.now().isoformat())
def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
return asdict(self)
@@ -150,14 +150,14 @@ class AccessRequest:
id: str
policy_id: str
user_id: str
request_reason: Optional[str] = None
request_reason: str | None = None
status: str = "pending" # pending, approved, rejected, expired
approved_by: Optional[str] = None
approved_at: Optional[str] = None
expires_at: Optional[str] = None
approved_by: str | None = None
approved_at: str | None = None
expires_at: str | None = None
created_at: str = field(default_factory=lambda: datetime.now().isoformat())
def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
return asdict(self)
@@ -196,7 +196,7 @@ class SecurityManager:
self.db_path = db_path
self.db_path = db_path
# 预编译正则缓存
self._compiled_patterns: Dict[str, re.Pattern] = {}
self._compiled_patterns: dict[str, re.Pattern] = {}
self._local = {}
self._init_db()
@@ -317,16 +317,16 @@ class SecurityManager:
def log_audit(
self,
action_type: AuditActionType,
user_id: Optional[str] = None,
user_ip: Optional[str] = None,
user_agent: Optional[str] = None,
resource_type: Optional[str] = None,
resource_id: Optional[str] = None,
action_details: Optional[Dict] = None,
before_value: Optional[str] = None,
after_value: Optional[str] = None,
user_id: str | None = None,
user_ip: str | None = None,
user_agent: str | None = None,
resource_type: str | None = None,
resource_id: str | None = None,
action_details: dict | None = None,
before_value: str | None = None,
after_value: str | None = None,
success: bool = True,
error_message: Optional[str] = None
error_message: str | None = None
) -> AuditLog:
"""记录审计日志"""
log = AuditLog(
@@ -364,16 +364,16 @@ class SecurityManager:
def get_audit_logs(
self,
user_id: Optional[str] = None,
resource_type: Optional[str] = None,
resource_id: Optional[str] = None,
action_type: Optional[str] = None,
start_time: Optional[str] = None,
end_time: Optional[str] = None,
success: Optional[bool] = None,
user_id: str | None = None,
resource_type: str | None = None,
resource_id: str | None = None,
action_type: str | None = None,
start_time: str | None = None,
end_time: str | None = None,
success: bool | None = None,
limit: int = 100,
offset: int = 0
) -> List[AuditLog]:
) -> list[AuditLog]:
"""查询审计日志"""
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
@@ -438,9 +438,9 @@ class SecurityManager:
def get_audit_stats(
self,
start_time: Optional[str] = None,
end_time: Optional[str] = None
) -> Dict[str, Any]:
start_time: str | None = None,
end_time: str | None = None
) -> dict[str, Any]:
"""获取审计统计"""
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
@@ -633,7 +633,7 @@ class SecurityManager:
return key_hash == stored_hash
def get_encryption_config(self, project_id: str) -> Optional[EncryptionConfig]:
def get_encryption_config(self, project_id: str) -> EncryptionConfig | None:
"""获取加密配置"""
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
@@ -664,8 +664,8 @@ class SecurityManager:
self,
data: str,
password: str,
salt: Optional[str] = None
) -> Tuple[str, str]:
salt: str | None = None
) -> tuple[str, str]:
"""加密数据"""
if not CRYPTO_AVAILABLE:
raise RuntimeError("cryptography library not available")
@@ -702,9 +702,9 @@ class SecurityManager:
project_id: str,
name: str,
rule_type: MaskingRuleType,
pattern: Optional[str] = None,
replacement: Optional[str] = None,
description: Optional[str] = None,
pattern: str | None = None,
replacement: str | None = None,
description: str | None = None,
priority: int = 0
) -> MaskingRule:
"""创建脱敏规则"""
@@ -756,7 +756,7 @@ class SecurityManager:
self,
project_id: str,
active_only: bool = True
) -> List[MaskingRule]:
) -> list[MaskingRule]:
"""获取脱敏规则"""
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
@@ -795,7 +795,7 @@ class SecurityManager:
self,
rule_id: str,
**kwargs
) -> Optional[MaskingRule]:
) -> MaskingRule | None:
"""更新脱敏规则"""
allowed_fields = ["name", "pattern", "replacement", "is_active", "priority", "description"]
@@ -868,7 +868,7 @@ class SecurityManager:
self,
text: str,
project_id: str,
rule_types: Optional[List[MaskingRuleType]] = None
rule_types: list[MaskingRuleType] | None = None
) -> str:
"""应用脱敏规则到文本"""
rules = self.get_masking_rules(project_id)
@@ -897,9 +897,9 @@ class SecurityManager:
def apply_masking_to_entity(
self,
entity_data: Dict[str, Any],
entity_data: dict[str, Any],
project_id: str
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""对实体数据应用脱敏"""
masked_data = entity_data.copy()
@@ -918,12 +918,12 @@ class SecurityManager:
self,
project_id: str,
name: str,
description: Optional[str] = None,
allowed_users: Optional[List[str]] = None,
allowed_roles: Optional[List[str]] = None,
allowed_ips: Optional[List[str]] = None,
time_restrictions: Optional[Dict] = None,
max_access_count: Optional[int] = None,
description: str | None = None,
allowed_users: list[str] | None = None,
allowed_roles: list[str] | None = None,
allowed_ips: list[str] | None = None,
time_restrictions: dict | None = None,
max_access_count: int | None = None,
require_approval: bool = False
) -> DataAccessPolicy:
"""创建数据访问策略"""
@@ -966,7 +966,7 @@ class SecurityManager:
self,
project_id: str,
active_only: bool = True
) -> List[DataAccessPolicy]:
) -> list[DataAccessPolicy]:
"""获取数据访问策略"""
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
@@ -1005,8 +1005,8 @@ class SecurityManager:
self,
policy_id: str,
user_id: str,
user_ip: Optional[str] = None
) -> Tuple[bool, Optional[str]]:
user_ip: str | None = None
) -> tuple[bool, str | None]:
"""检查访问权限"""
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
@@ -1107,7 +1107,7 @@ class SecurityManager:
self,
policy_id: str,
user_id: str,
request_reason: Optional[str] = None,
request_reason: str | None = None,
expires_hours: int = 24
) -> AccessRequest:
"""创建访问请求"""
@@ -1142,7 +1142,7 @@ class SecurityManager:
request_id: str,
approved_by: str,
expires_hours: int = 24
) -> Optional[AccessRequest]:
) -> AccessRequest | None:
"""批准访问请求"""
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
@@ -1182,7 +1182,7 @@ class SecurityManager:
self,
request_id: str,
rejected_by: str
) -> Optional[AccessRequest]:
) -> AccessRequest | None:
"""拒绝访问请求"""
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()

View File

@@ -10,19 +10,19 @@ InsightFlow Phase 8 - 订阅与计费系统模块
作者: InsightFlow Team
"""
import sqlite3
import json
import uuid
from datetime import datetime, timedelta
from typing import Optional, List, Dict, Any
from dataclasses import dataclass
from enum import Enum
import logging
import sqlite3
import uuid
from dataclasses import dataclass
from datetime import datetime, timedelta
from enum import StrEnum
from typing import Any
logger = logging.getLogger(__name__)
class SubscriptionStatus(str, Enum):
class SubscriptionStatus(StrEnum):
"""订阅状态"""
ACTIVE = "active" # 活跃
CANCELLED = "cancelled" # 已取消
@@ -32,7 +32,7 @@ class SubscriptionStatus(str, Enum):
PENDING = "pending" # 待支付
class PaymentProvider(str, Enum):
class PaymentProvider(StrEnum):
"""支付提供商"""
STRIPE = "stripe" # Stripe
ALIPAY = "alipay" # 支付宝
@@ -40,7 +40,7 @@ class PaymentProvider(str, Enum):
BANK_TRANSFER = "bank_transfer" # 银行转账
class PaymentStatus(str, Enum):
class PaymentStatus(StrEnum):
"""支付状态"""
PENDING = "pending" # 待支付
PROCESSING = "processing" # 处理中
@@ -50,7 +50,7 @@ class PaymentStatus(str, Enum):
PARTIAL_REFUNDED = "partial_refunded" # 部分退款
class InvoiceStatus(str, Enum):
class InvoiceStatus(StrEnum):
"""发票状态"""
DRAFT = "draft" # 草稿
ISSUED = "issued" # 已开具
@@ -60,7 +60,7 @@ class InvoiceStatus(str, Enum):
CREDIT_NOTE = "credit_note" # 贷项通知单
class RefundStatus(str, Enum):
class RefundStatus(StrEnum):
"""退款状态"""
PENDING = "pending" # 待处理
APPROVED = "approved" # 已批准
@@ -79,12 +79,12 @@ class SubscriptionPlan:
price_monthly: float # 月付价格
price_yearly: float # 年付价格
currency: str # CNY/USD
features: List[str] # 功能列表
limits: Dict[str, Any] # 资源限制
features: list[str] # 功能列表
limits: dict[str, Any] # 资源限制
is_active: bool
created_at: datetime
updated_at: datetime
metadata: Dict[str, Any]
metadata: dict[str, Any]
@dataclass
@@ -97,14 +97,14 @@ class Subscription:
current_period_start: datetime
current_period_end: datetime
cancel_at_period_end: bool
canceled_at: Optional[datetime]
trial_start: Optional[datetime]
trial_end: Optional[datetime]
payment_provider: Optional[str]
provider_subscription_id: Optional[str] # 支付提供商的订阅ID
canceled_at: datetime | None
trial_start: datetime | None
trial_end: datetime | None
payment_provider: str | None
provider_subscription_id: str | None # 支付提供商的订阅ID
created_at: datetime
updated_at: datetime
metadata: Dict[str, Any]
metadata: dict[str, Any]
@dataclass
@@ -117,8 +117,8 @@ class UsageRecord:
unit: str # minutes/mb/count
recorded_at: datetime
cost: float # 费用
description: Optional[str]
metadata: Dict[str, Any]
description: str | None
metadata: dict[str, Any]
@dataclass
@@ -126,18 +126,18 @@ class Payment:
"""支付记录数据类"""
id: str
tenant_id: str
subscription_id: Optional[str]
invoice_id: Optional[str]
subscription_id: str | None
invoice_id: str | None
amount: float
currency: str
provider: str
provider_payment_id: Optional[str]
provider_payment_id: str | None
status: str
payment_method: Optional[str]
payment_details: Dict[str, Any]
paid_at: Optional[datetime]
failed_at: Optional[datetime]
failure_reason: Optional[str]
payment_method: str | None
payment_details: dict[str, Any]
paid_at: datetime | None
failed_at: datetime | None
failure_reason: str | None
created_at: datetime
updated_at: datetime
@@ -147,7 +147,7 @@ class Invoice:
"""发票数据类"""
id: str
tenant_id: str
subscription_id: Optional[str]
subscription_id: str | None
invoice_number: str
status: str
amount_due: float
@@ -156,11 +156,11 @@ class Invoice:
period_start: datetime
period_end: datetime
description: str
line_items: List[Dict[str, Any]]
line_items: list[dict[str, Any]]
due_date: datetime
paid_at: Optional[datetime]
voided_at: Optional[datetime]
void_reason: Optional[str]
paid_at: datetime | None
voided_at: datetime | None
void_reason: str | None
created_at: datetime
updated_at: datetime
@@ -171,18 +171,18 @@ class Refund:
id: str
tenant_id: str
payment_id: str
invoice_id: Optional[str]
invoice_id: str | None
amount: float
currency: str
reason: str
status: str
requested_by: str
requested_at: datetime
approved_by: Optional[str]
approved_at: Optional[str]
completed_at: Optional[datetime]
provider_refund_id: Optional[str]
metadata: Dict[str, Any]
approved_by: str | None
approved_at: str | None
completed_at: datetime | None
provider_refund_id: str | None
metadata: dict[str, Any]
created_at: datetime
updated_at: datetime
@@ -199,7 +199,7 @@ class BillingHistory:
reference_id: str # 关联的订阅/支付/退款ID
balance_after: float # 操作后余额
created_at: datetime
metadata: Dict[str, Any]
metadata: dict[str, Any]
class SubscriptionManager:
@@ -542,7 +542,7 @@ class SubscriptionManager:
# ==================== 订阅计划管理 ====================
def get_plan(self, plan_id: str) -> Optional[SubscriptionPlan]:
def get_plan(self, plan_id: str) -> SubscriptionPlan | None:
"""获取订阅计划"""
conn = self._get_connection()
try:
@@ -557,7 +557,7 @@ class SubscriptionManager:
finally:
conn.close()
def get_plan_by_tier(self, tier: str) -> Optional[SubscriptionPlan]:
def get_plan_by_tier(self, tier: str) -> SubscriptionPlan | None:
"""通过层级获取订阅计划"""
conn = self._get_connection()
try:
@@ -572,7 +572,7 @@ class SubscriptionManager:
finally:
conn.close()
def list_plans(self, include_inactive: bool = False) -> List[SubscriptionPlan]:
def list_plans(self, include_inactive: bool = False) -> list[SubscriptionPlan]:
"""列出所有订阅计划"""
conn = self._get_connection()
try:
@@ -591,8 +591,8 @@ class SubscriptionManager:
def create_plan(self, name: str, tier: str, description: str,
price_monthly: float, price_yearly: float,
currency: str = "CNY", features: List[str] = None,
limits: Dict[str, Any] = None) -> SubscriptionPlan:
currency: str = "CNY", features: list[str] = None,
limits: dict[str, Any] = None) -> SubscriptionPlan:
"""创建新订阅计划"""
conn = self._get_connection()
try:
@@ -639,7 +639,7 @@ class SubscriptionManager:
finally:
conn.close()
def update_plan(self, plan_id: str, **kwargs) -> Optional[SubscriptionPlan]:
def update_plan(self, plan_id: str, **kwargs) -> SubscriptionPlan | None:
"""更新订阅计划"""
conn = self._get_connection()
try:
@@ -685,7 +685,7 @@ class SubscriptionManager:
# ==================== 订阅管理 ====================
def create_subscription(self, tenant_id: str, plan_id: str,
payment_provider: Optional[str] = None,
payment_provider: str | None = None,
trial_days: int = 0,
billing_cycle: str = "monthly") -> Subscription:
"""创建新订阅"""
@@ -785,7 +785,7 @@ class SubscriptionManager:
finally:
conn.close()
def get_subscription(self, subscription_id: str) -> Optional[Subscription]:
def get_subscription(self, subscription_id: str) -> Subscription | None:
"""获取订阅信息"""
conn = self._get_connection()
try:
@@ -800,7 +800,7 @@ class SubscriptionManager:
finally:
conn.close()
def get_tenant_subscription(self, tenant_id: str) -> Optional[Subscription]:
def get_tenant_subscription(self, tenant_id: str) -> Subscription | None:
"""获取租户的当前订阅"""
conn = self._get_connection()
try:
@@ -819,7 +819,7 @@ class SubscriptionManager:
finally:
conn.close()
def update_subscription(self, subscription_id: str, **kwargs) -> Optional[Subscription]:
def update_subscription(self, subscription_id: str, **kwargs) -> Subscription | None:
"""更新订阅"""
conn = self._get_connection()
try:
@@ -862,7 +862,7 @@ class SubscriptionManager:
conn.close()
def cancel_subscription(self, subscription_id: str,
at_period_end: bool = True) -> Optional[Subscription]:
at_period_end: bool = True) -> Subscription | None:
"""取消订阅"""
conn = self._get_connection()
try:
@@ -904,7 +904,7 @@ class SubscriptionManager:
conn.close()
def change_plan(self, subscription_id: str, new_plan_id: str,
prorate: bool = True) -> Optional[Subscription]:
prorate: bool = True) -> Subscription | None:
"""更改订阅计划"""
conn = self._get_connection()
try:
@@ -950,8 +950,8 @@ class SubscriptionManager:
def record_usage(self, tenant_id: str, resource_type: str,
quantity: float, unit: str,
description: Optional[str] = None,
metadata: Optional[Dict] = None) -> UsageRecord:
description: str | None = None,
metadata: dict | None = None) -> UsageRecord:
"""记录用量"""
conn = self._get_connection()
try:
@@ -989,8 +989,8 @@ class SubscriptionManager:
conn.close()
def get_usage_summary(self, tenant_id: str,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None) -> Dict[str, Any]:
start_date: datetime | None = None,
end_date: datetime | None = None) -> dict[str, Any]:
"""获取用量汇总"""
conn = self._get_connection()
try:
@@ -1061,10 +1061,10 @@ class SubscriptionManager:
# ==================== 支付管理 ====================
def create_payment(self, tenant_id: str, amount: float, currency: str,
provider: str, subscription_id: Optional[str] = None,
invoice_id: Optional[str] = None,
payment_method: Optional[str] = None,
payment_details: Optional[Dict] = None) -> Payment:
provider: str, subscription_id: str | None = None,
invoice_id: str | None = None,
payment_method: str | None = None,
payment_details: dict | None = None) -> Payment:
"""创建支付记录"""
conn = self._get_connection()
try:
@@ -1113,7 +1113,7 @@ class SubscriptionManager:
conn.close()
def confirm_payment(self, payment_id: str,
provider_payment_id: Optional[str] = None) -> Optional[Payment]:
provider_payment_id: str | None = None) -> Payment | None:
"""确认支付完成"""
conn = self._get_connection()
try:
@@ -1160,7 +1160,7 @@ class SubscriptionManager:
finally:
conn.close()
def fail_payment(self, payment_id: str, reason: str) -> Optional[Payment]:
def fail_payment(self, payment_id: str, reason: str) -> Payment | None:
"""标记支付失败"""
conn = self._get_connection()
try:
@@ -1179,7 +1179,7 @@ class SubscriptionManager:
finally:
conn.close()
def get_payment(self, payment_id: str) -> Optional[Payment]:
def get_payment(self, payment_id: str) -> Payment | None:
"""获取支付记录"""
conn = self._get_connection()
try:
@@ -1187,8 +1187,8 @@ class SubscriptionManager:
finally:
conn.close()
def list_payments(self, tenant_id: str, status: Optional[str] = None,
limit: int = 100, offset: int = 0) -> List[Payment]:
def list_payments(self, tenant_id: str, status: str | None = None,
limit: int = 100, offset: int = 0) -> list[Payment]:
"""列出支付记录"""
conn = self._get_connection()
try:
@@ -1212,7 +1212,7 @@ class SubscriptionManager:
finally:
conn.close()
def _get_payment_internal(self, conn: sqlite3.Connection, payment_id: str) -> Optional[Payment]:
def _get_payment_internal(self, conn: sqlite3.Connection, payment_id: str) -> Payment | None:
"""内部方法:获取支付记录"""
cursor = conn.cursor()
cursor.execute("SELECT * FROM payments WHERE id = ?", (payment_id,))
@@ -1225,10 +1225,10 @@ class SubscriptionManager:
# ==================== 发票管理 ====================
def _create_invoice_internal(self, conn: sqlite3.Connection, tenant_id: str,
subscription_id: Optional[str], amount: float,
subscription_id: str | None, amount: float,
currency: str, period_start: datetime,
period_end: datetime, description: str,
line_items: Optional[List[Dict]] = None) -> Invoice:
line_items: list[dict] | None = None) -> Invoice:
"""内部方法:创建发票"""
invoice_id = str(uuid.uuid4())
invoice_number = self._generate_invoice_number()
@@ -1275,7 +1275,7 @@ class SubscriptionManager:
return invoice
def get_invoice(self, invoice_id: str) -> Optional[Invoice]:
def get_invoice(self, invoice_id: str) -> Invoice | None:
"""获取发票"""
conn = self._get_connection()
try:
@@ -1290,7 +1290,7 @@ class SubscriptionManager:
finally:
conn.close()
def get_invoice_by_number(self, invoice_number: str) -> Optional[Invoice]:
def get_invoice_by_number(self, invoice_number: str) -> Invoice | None:
"""通过发票号获取发票"""
conn = self._get_connection()
try:
@@ -1305,8 +1305,8 @@ class SubscriptionManager:
finally:
conn.close()
def list_invoices(self, tenant_id: str, status: Optional[str] = None,
limit: int = 100, offset: int = 0) -> List[Invoice]:
def list_invoices(self, tenant_id: str, status: str | None = None,
limit: int = 100, offset: int = 0) -> list[Invoice]:
"""列出发票"""
conn = self._get_connection()
try:
@@ -1330,7 +1330,7 @@ class SubscriptionManager:
finally:
conn.close()
def void_invoice(self, invoice_id: str, reason: str) -> Optional[Invoice]:
def void_invoice(self, invoice_id: str, reason: str) -> Invoice | None:
"""作废发票"""
conn = self._get_connection()
try:
@@ -1442,7 +1442,7 @@ class SubscriptionManager:
finally:
conn.close()
def approve_refund(self, refund_id: str, approved_by: str) -> Optional[Refund]:
def approve_refund(self, refund_id: str, approved_by: str) -> Refund | None:
"""批准退款"""
conn = self._get_connection()
try:
@@ -1469,7 +1469,7 @@ class SubscriptionManager:
conn.close()
def complete_refund(self, refund_id: str,
provider_refund_id: Optional[str] = None) -> Optional[Refund]:
provider_refund_id: str | None = None) -> Refund | None:
"""完成退款"""
conn = self._get_connection()
try:
@@ -1507,7 +1507,7 @@ class SubscriptionManager:
finally:
conn.close()
def reject_refund(self, refund_id: str, reason: str) -> Optional[Refund]:
def reject_refund(self, refund_id: str, reason: str) -> Refund | None:
"""拒绝退款"""
conn = self._get_connection()
try:
@@ -1530,7 +1530,7 @@ class SubscriptionManager:
finally:
conn.close()
def get_refund(self, refund_id: str) -> Optional[Refund]:
def get_refund(self, refund_id: str) -> Refund | None:
"""获取退款记录"""
conn = self._get_connection()
try:
@@ -1538,8 +1538,8 @@ class SubscriptionManager:
finally:
conn.close()
def list_refunds(self, tenant_id: str, status: Optional[str] = None,
limit: int = 100, offset: int = 0) -> List[Refund]:
def list_refunds(self, tenant_id: str, status: str | None = None,
limit: int = 100, offset: int = 0) -> list[Refund]:
"""列出退款记录"""
conn = self._get_connection()
try:
@@ -1563,7 +1563,7 @@ class SubscriptionManager:
finally:
conn.close()
def _get_refund_internal(self, conn: sqlite3.Connection, refund_id: str) -> Optional[Refund]:
def _get_refund_internal(self, conn: sqlite3.Connection, refund_id: str) -> Refund | None:
"""内部方法:获取退款记录"""
cursor = conn.cursor()
cursor.execute("SELECT * FROM refunds WHERE id = ?", (refund_id,))
@@ -1593,9 +1593,9 @@ class SubscriptionManager:
))
def get_billing_history(self, tenant_id: str,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None,
limit: int = 100, offset: int = 0) -> List[BillingHistory]:
start_date: datetime | None = None,
end_date: datetime | None = None,
limit: int = 100, offset: int = 0) -> list[BillingHistory]:
"""获取账单历史"""
conn = self._get_connection()
try:
@@ -1626,7 +1626,7 @@ class SubscriptionManager:
def create_stripe_checkout_session(self, tenant_id: str, plan_id: str,
success_url: str, cancel_url: str,
billing_cycle: str = "monthly") -> Dict[str, Any]:
billing_cycle: str = "monthly") -> dict[str, Any]:
"""创建 Stripe Checkout 会话(占位实现)"""
# 这里应该集成 Stripe SDK
# 简化实现,返回模拟数据
@@ -1638,7 +1638,7 @@ class SubscriptionManager:
}
def create_alipay_order(self, tenant_id: str, plan_id: str,
billing_cycle: str = "monthly") -> Dict[str, Any]:
billing_cycle: str = "monthly") -> dict[str, Any]:
"""创建支付宝订单(占位实现)"""
# 这里应该集成支付宝 SDK
plan = self.get_plan(plan_id)
@@ -1654,7 +1654,7 @@ class SubscriptionManager:
}
def create_wechat_order(self, tenant_id: str, plan_id: str,
billing_cycle: str = "monthly") -> Dict[str, Any]:
billing_cycle: str = "monthly") -> dict[str, Any]:
"""创建微信支付订单(占位实现)"""
# 这里应该集成微信支付 SDK
plan = self.get_plan(plan_id)
@@ -1669,7 +1669,7 @@ class SubscriptionManager:
"provider": "wechat"
}
def handle_webhook(self, provider: str, payload: Dict[str, Any]) -> bool:
def handle_webhook(self, provider: str, payload: dict[str, Any]) -> bool:
"""处理支付提供商的 Webhook占位实现"""
# 这里应该实现实际的 Webhook 处理逻辑
logger.info(f"Received webhook from {provider}: {payload.get('event_type', 'unknown')}")

View File

@@ -11,16 +11,16 @@ InsightFlow Phase 8 - 多租户 SaaS 架构管理模块
作者: InsightFlow Team
"""
import sqlite3
import json
import uuid
import hashlib
import re
from datetime import datetime
from typing import Optional, List, Dict, Any, Tuple
from dataclasses import dataclass, asdict
from enum import Enum
import json
import logging
import re
import sqlite3
import uuid
from dataclasses import asdict, dataclass
from datetime import datetime
from enum import StrEnum
from typing import Any
logger = logging.getLogger(__name__)
@@ -44,7 +44,7 @@ class TenantLimits:
UNLIMITED = -1
class TenantStatus(str, Enum):
class TenantStatus(StrEnum):
"""租户状态"""
ACTIVE = "active" # 活跃
SUSPENDED = "suspended" # 暂停
@@ -53,14 +53,14 @@ class TenantStatus(str, Enum):
PENDING = "pending" # 待激活
class TenantTier(str, Enum):
class TenantTier(StrEnum):
"""租户订阅层级"""
FREE = "free" # 免费版
PRO = "pro" # 专业版
ENTERPRISE = "enterprise" # 企业版
class TenantRole(str, Enum):
class TenantRole(StrEnum):
"""租户角色"""
OWNER = "owner" # 所有者
ADMIN = "admin" # 管理员
@@ -68,7 +68,7 @@ class TenantRole(str, Enum):
VIEWER = "viewer" # 查看者
class DomainStatus(str, Enum):
class DomainStatus(StrEnum):
"""域名状态"""
PENDING = "pending" # 待验证
VERIFIED = "verified" # 已验证
@@ -82,16 +82,16 @@ class Tenant:
id: str
name: str
slug: str # URL 友好的唯一标识
description: Optional[str]
description: str | None
tier: str # free/pro/enterprise
status: str # active/suspended/trial/expired/pending
owner_id: str # 所有者用户ID
created_at: datetime
updated_at: datetime
expires_at: Optional[datetime] # 订阅过期时间
settings: Dict[str, Any] # 租户级设置
resource_limits: Dict[str, Any] # 资源限制
metadata: Dict[str, Any] # 元数据
expires_at: datetime | None # 订阅过期时间
settings: dict[str, Any] # 租户级设置
resource_limits: dict[str, Any] # 资源限制
metadata: dict[str, Any] # 元数据
@dataclass
@@ -103,12 +103,12 @@ class TenantDomain:
status: str # pending/verified/failed/expired
verification_token: str # 验证令牌
verification_method: str # dns/file
verified_at: Optional[datetime]
verified_at: datetime | None
created_at: datetime
updated_at: datetime
is_primary: bool # 是否主域名
ssl_enabled: bool # SSL 是否启用
ssl_expires_at: Optional[datetime]
ssl_expires_at: datetime | None
@dataclass
@@ -116,14 +116,14 @@ class TenantBranding:
"""租户品牌配置数据类"""
id: str
tenant_id: str
logo_url: Optional[str] # Logo URL
favicon_url: Optional[str] # Favicon URL
primary_color: Optional[str] # 主题主色
secondary_color: Optional[str] # 主题次色
custom_css: Optional[str] # 自定义 CSS
custom_js: Optional[str] # 自定义 JS
login_page_bg: Optional[str] # 登录页背景
email_template: Optional[str] # 邮件模板
logo_url: str | None # Logo URL
favicon_url: str | None # Favicon URL
primary_color: str | None # 主题主色
secondary_color: str | None # 主题次色
custom_css: str | None # 自定义 CSS
custom_js: str | None # 自定义 JS
login_page_bg: str | None # 登录页背景
email_template: str | None # 邮件模板
created_at: datetime
updated_at: datetime
@@ -136,11 +136,11 @@ class TenantMember:
user_id: str
email: str
role: str # owner/admin/member/viewer
permissions: List[str] # 具体权限列表
invited_by: Optional[str] # 邀请者
permissions: list[str] # 具体权限列表
invited_by: str | None # 邀请者
invited_at: datetime
joined_at: Optional[datetime]
last_active_at: Optional[datetime]
joined_at: datetime | None
last_active_at: datetime | None
status: str # active/pending/suspended
@@ -151,10 +151,10 @@ class TenantPermission:
tenant_id: str
name: str # 权限名称
code: str # 权限代码
description: Optional[str]
description: str | None
resource_type: str # project/entity/api/etc
actions: List[str] # create/read/update/delete/etc
conditions: Optional[Dict] # 条件限制
actions: list[str] # create/read/update/delete/etc
conditions: dict | None # 条件限制
created_at: datetime
@@ -381,8 +381,8 @@ class TenantManager:
def create_tenant(self, name: str, owner_id: str,
tier: str = "free",
description: Optional[str] = None,
settings: Optional[Dict] = None) -> Tenant:
description: str | None = None,
settings: dict | None = None) -> Tenant:
"""创建新租户"""
conn = self._get_connection()
try:
@@ -436,7 +436,7 @@ class TenantManager:
finally:
conn.close()
def get_tenant(self, tenant_id: str) -> Optional[Tenant]:
def get_tenant(self, tenant_id: str) -> Tenant | None:
"""获取租户信息"""
conn = self._get_connection()
try:
@@ -451,7 +451,7 @@ class TenantManager:
finally:
conn.close()
def get_tenant_by_slug(self, slug: str) -> Optional[Tenant]:
def get_tenant_by_slug(self, slug: str) -> Tenant | None:
"""通过 slug 获取租户"""
conn = self._get_connection()
try:
@@ -466,7 +466,7 @@ class TenantManager:
finally:
conn.close()
def get_tenant_by_domain(self, domain: str) -> Optional[Tenant]:
def get_tenant_by_domain(self, domain: str) -> Tenant | None:
"""通过自定义域名获取租户"""
conn = self._get_connection()
try:
@@ -486,11 +486,11 @@ class TenantManager:
conn.close()
def update_tenant(self, tenant_id: str,
name: Optional[str] = None,
description: Optional[str] = None,
tier: Optional[str] = None,
status: Optional[str] = None,
settings: Optional[Dict] = None) -> Optional[Tenant]:
name: str | None = None,
description: str | None = None,
tier: str | None = None,
status: str | None = None,
settings: dict | None = None) -> Tenant | None:
"""更新租户信息"""
conn = self._get_connection()
try:
@@ -548,9 +548,9 @@ class TenantManager:
finally:
conn.close()
def list_tenants(self, status: Optional[str] = None,
tier: Optional[str] = None,
limit: int = 100, offset: int = 0) -> List[Tenant]:
def list_tenants(self, status: str | None = None,
tier: str | None = None,
limit: int = 100, offset: int = 0) -> list[Tenant]:
"""列出租户"""
conn = self._get_connection()
try:
@@ -689,7 +689,7 @@ class TenantManager:
finally:
conn.close()
def get_domain_verification_instructions(self, domain_id: str) -> Dict[str, Any]:
def get_domain_verification_instructions(self, domain_id: str) -> dict[str, Any]:
"""获取域名验证指导"""
conn = self._get_connection()
try:
@@ -739,7 +739,7 @@ class TenantManager:
finally:
conn.close()
def list_domains(self, tenant_id: str) -> List[TenantDomain]:
def list_domains(self, tenant_id: str) -> list[TenantDomain]:
"""列出租户的所有域名"""
conn = self._get_connection()
try:
@@ -758,7 +758,7 @@ class TenantManager:
# ==================== 品牌白标管理 ====================
def get_branding(self, tenant_id: str) -> Optional[TenantBranding]:
def get_branding(self, tenant_id: str) -> TenantBranding | None:
"""获取租户品牌配置"""
conn = self._get_connection()
try:
@@ -774,14 +774,14 @@ class TenantManager:
conn.close()
def update_branding(self, tenant_id: str,
logo_url: Optional[str] = None,
favicon_url: Optional[str] = None,
primary_color: Optional[str] = None,
secondary_color: Optional[str] = None,
custom_css: Optional[str] = None,
custom_js: Optional[str] = None,
login_page_bg: Optional[str] = None,
email_template: Optional[str] = None) -> TenantBranding:
logo_url: str | None = None,
favicon_url: str | None = None,
primary_color: str | None = None,
secondary_color: str | None = None,
custom_css: str | None = None,
custom_js: str | None = None,
login_page_bg: str | None = None,
email_template: str | None = None) -> TenantBranding:
"""更新租户品牌配置"""
conn = self._get_connection()
try:
@@ -890,7 +890,7 @@ class TenantManager:
# ==================== 成员与权限管理 ====================
def invite_member(self, tenant_id: str, email: str, role: str,
invited_by: str, permissions: Optional[List[str]] = None) -> TenantMember:
invited_by: str, permissions: list[str] | None = None) -> TenantMember:
"""邀请成员加入租户"""
conn = self._get_connection()
try:
@@ -967,7 +967,7 @@ class TenantManager:
conn.close()
def update_member_role(self, tenant_id: str, member_id: str,
role: str, permissions: Optional[List[str]] = None) -> bool:
role: str, permissions: list[str] | None = None) -> bool:
"""更新成员角色"""
conn = self._get_connection()
try:
@@ -988,7 +988,7 @@ class TenantManager:
finally:
conn.close()
def list_members(self, tenant_id: str, status: Optional[str] = None) -> List[TenantMember]:
def list_members(self, tenant_id: str, status: str | None = None) -> list[TenantMember]:
"""列出租户成员"""
conn = self._get_connection()
try:
@@ -1042,7 +1042,7 @@ class TenantManager:
finally:
conn.close()
def get_user_tenants(self, user_id: str) -> List[Dict[str, Any]]:
def get_user_tenants(self, user_id: str) -> list[dict[str, Any]]:
"""获取用户所属的所有租户"""
conn = self._get_connection()
try:
@@ -1108,8 +1108,8 @@ class TenantManager:
conn.close()
def get_usage_stats(self, tenant_id: str,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None) -> Dict[str, Any]:
start_date: datetime | None = None,
end_date: datetime | None = None) -> dict[str, Any]:
"""获取使用统计"""
conn = self._get_connection()
try:
@@ -1165,7 +1165,7 @@ class TenantManager:
finally:
conn.close()
def check_resource_limit(self, tenant_id: str, resource_type: str) -> Tuple[bool, int, int]:
def check_resource_limit(self, tenant_id: str, resource_type: str) -> tuple[bool, int, int]:
"""检查资源是否超限
Returns:
@@ -1288,7 +1288,7 @@ class TenantManager:
def _add_member_internal(self, conn: sqlite3.Connection, tenant_id: str,
user_id: str, email: str, role: TenantRole,
invited_by: Optional[str]):
invited_by: str | None):
"""内部方法:添加成员"""
cursor = conn.cursor()
member_id = str(uuid.uuid4())
@@ -1416,8 +1416,8 @@ class TenantManager:
class TenantContext:
"""租户上下文管理器 - 用于请求级别的租户隔离"""
_current_tenant_id: Optional[str] = None
_current_user_id: Optional[str] = None
_current_tenant_id: str | None = None
_current_user_id: str | None = None
@classmethod
def set_current_tenant(cls, tenant_id: str):
@@ -1425,7 +1425,7 @@ class TenantContext:
cls._current_tenant_id = tenant_id
@classmethod
def get_current_tenant(cls) -> Optional[str]:
def get_current_tenant(cls) -> str | None:
"""获取当前租户ID"""
return cls._current_tenant_id
@@ -1435,7 +1435,7 @@ class TenantContext:
cls._current_user_id = user_id
@classmethod
def get_current_user(cls) -> Optional[str]:
def get_current_user(cls) -> str | None:
"""获取当前用户ID"""
return cls._current_user_id

View File

@@ -4,8 +4,8 @@ InsightFlow Multimodal Module Test Script
测试多模态支持模块
"""
import sys
import os
import sys
# 添加 backend 目录到路径
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
@@ -18,25 +18,19 @@ print("=" * 60)
print("\n1. 测试模块导入...")
try:
from multimodal_processor import (
get_multimodal_processor
)
from multimodal_processor import get_multimodal_processor
print(" ✓ multimodal_processor 导入成功")
except ImportError as e:
print(f" ✗ multimodal_processor 导入失败: {e}")
try:
from image_processor import (
get_image_processor
)
from image_processor import get_image_processor
print(" ✓ image_processor 导入成功")
except ImportError as e:
print(f" ✗ image_processor 导入失败: {e}")
try:
from multimodal_entity_linker import (
get_multimodal_entity_linker
)
from multimodal_entity_linker import get_multimodal_entity_linker
print(" ✓ multimodal_entity_linker 导入成功")
except ImportError as e:
print(f" ✗ multimodal_entity_linker 导入失败: {e}")

View File

@@ -4,19 +4,19 @@ InsightFlow Phase 7 Task 6 & 8 测试脚本
测试高级搜索与发现、性能优化与扩展功能
"""
from performance_manager import (
get_performance_manager, CacheManager,
TaskQueue, PerformanceMonitor
)
from search_manager import (
get_search_manager, FullTextSearch,
SemanticSearch, EntityPathDiscovery,
KnowledgeGapDetection
)
import os
import sys
import time
from performance_manager import CacheManager, PerformanceMonitor, TaskQueue, get_performance_manager
from search_manager import (
EntityPathDiscovery,
FullTextSearch,
KnowledgeGapDetection,
SemanticSearch,
get_search_manager,
)
# 添加 backend 到路径
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

View File

@@ -10,11 +10,11 @@ InsightFlow Phase 8 Task 1 - 多租户 SaaS 架构测试脚本
5. 资源使用统计
"""
from tenant_manager import (
get_tenant_manager
)
import sys
import os
import sys
from tenant_manager import get_tenant_manager
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

View File

@@ -3,13 +3,12 @@
InsightFlow Phase 8 Task 2 测试脚本 - 订阅与计费系统
"""
from subscription_manager import (
SubscriptionManager, PaymentProvider
)
import sys
import os
import sys
import tempfile
from subscription_manager import PaymentProvider, SubscriptionManager
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

View File

@@ -4,12 +4,11 @@ InsightFlow Phase 8 Task 4 测试脚本
测试 AI 能力增强功能
"""
from ai_manager import (
get_ai_manager, ModelType, PredictionType
)
import asyncio
import sys
import os
import sys
from ai_manager import ModelType, PredictionType, get_ai_manager
# Add backend directory to path
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

View File

@@ -13,14 +13,20 @@ InsightFlow Phase 8 Task 5 - 运营与增长工具测试脚本
python test_phase8_task5.py
"""
from growth_manager import (
GrowthManager, EventType, ExperimentStatus, TrafficAllocationType, EmailTemplateType, WorkflowTriggerType
)
import asyncio
import sys
import os
import sys
from datetime import datetime, timedelta
from growth_manager import (
EmailTemplateType,
EventType,
ExperimentStatus,
GrowthManager,
TrafficAllocationType,
WorkflowTriggerType,
)
# 添加 backend 目录到路径
backend_dir = os.path.dirname(os.path.abspath(__file__))
if backend_dir not in sys.path:

View File

@@ -10,17 +10,20 @@ InsightFlow Phase 8 Task 6: Developer Ecosystem Test Script
4. 开发者文档与示例代码
"""
from developer_ecosystem_manager import (
DeveloperEcosystemManager,
SDKLanguage, TemplateCategory,
PluginCategory, PluginStatus,
DeveloperStatus
)
import sys
import os
import sys
import uuid
from datetime import datetime
from developer_ecosystem_manager import (
DeveloperEcosystemManager,
DeveloperStatus,
PluginCategory,
PluginStatus,
SDKLanguage,
TemplateCategory,
)
# Add backend directory to path
backend_dir = os.path.dirname(os.path.abspath(__file__))
if backend_dir not in sys.path:

View File

@@ -10,15 +10,20 @@ InsightFlow Phase 8 Task 8: Operations & Monitoring Test Script
4. 成本优化
"""
from ops_manager import (
get_ops_manager, AlertSeverity, AlertStatus, AlertChannelType, AlertRuleType,
ResourceType
)
import json
import os
import sys
import json
from datetime import datetime, timedelta
from ops_manager import (
AlertChannelType,
AlertRuleType,
AlertSeverity,
AlertStatus,
ResourceType,
get_ops_manager,
)
# Add backend directory to path
backend_dir = os.path.dirname(os.path.abspath(__file__))
if backend_dir not in sys.path:

View File

@@ -6,7 +6,7 @@
import os
import time
from datetime import datetime
from typing import Dict, Any
from typing import Any
class TingwuClient:
@@ -18,7 +18,7 @@ class TingwuClient:
if not self.access_key or not self.secret_key:
raise ValueError("ALI_ACCESS_KEY and ALI_SECRET_KEY required")
def _sign_request(self, method: str, uri: str, query: str = "", body: str = "") -> Dict[str, str]:
def _sign_request(self, method: str, uri: str, query: str = "", body: str = "") -> dict[str, str]:
"""阿里云签名 V3"""
timestamp = datetime.utcnow().strftime('%Y-%m-%dT%H:%M:%SZ')
@@ -35,9 +35,9 @@ class TingwuClient:
def create_task(self, audio_url: str, language: str = "zh") -> str:
"""创建听悟任务"""
try:
from alibabacloud_tea_openapi import models as open_api_models
from alibabacloud_tingwu20230930 import models as tingwu_models
from alibabacloud_tingwu20230930.client import Client as TingwuSDKClient
from alibabacloud_tea_openapi import models as open_api_models
config = open_api_models.Config(
access_key_id=self.access_key,
@@ -74,12 +74,12 @@ class TingwuClient:
print(f"Tingwu API error: {e}")
return f"mock_task_{int(time.time())}"
def get_task_result(self, task_id: str, max_retries: int = 60, interval: int = 5) -> Dict[str, Any]:
def get_task_result(self, task_id: str, max_retries: int = 60, interval: int = 5) -> dict[str, Any]:
"""获取任务结果"""
try:
from alibabacloud_tea_openapi import models as open_api_models
from alibabacloud_tingwu20230930 import models as tingwu_models
from alibabacloud_tingwu20230930.client import Client as TingwuSDKClient
from alibabacloud_tea_openapi import models as open_api_models
config = open_api_models.Config(
access_key_id=self.access_key,
@@ -114,7 +114,7 @@ class TingwuClient:
raise TimeoutError(f"Task {task_id} timeout")
def _parse_result(self, data) -> Dict[str, Any]:
def _parse_result(self, data) -> dict[str, Any]:
"""解析结果"""
result = data.result
transcription = result.transcription
@@ -140,7 +140,7 @@ class TingwuClient:
"segments": segments
}
def _mock_result(self) -> Dict[str, Any]:
def _mock_result(self) -> dict[str, Any]:
"""Mock 结果"""
return {
"full_text": "这是一个示例转录文本,包含 Project Alpha 和 K8s 等术语。",
@@ -149,7 +149,7 @@ class TingwuClient:
]
}
def transcribe(self, audio_url: str, language: str = "zh") -> Dict[str, Any]:
def transcribe(self, audio_url: str, language: str = "zh") -> dict[str, Any]:
"""一键转录"""
task_id = self.create_task(audio_url, language)
print(f"Tingwu task: {task_id}")

View File

@@ -9,20 +9,21 @@ InsightFlow Workflow Manager - Phase 7
- 工作流配置管理
"""
import json
import uuid
import asyncio
import httpx
import json
import logging
from datetime import datetime, timedelta
from typing import List, Dict, Optional, Callable, Any
import uuid
from collections.abc import Callable
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from enum import Enum
from typing import Any
import httpx
from apscheduler.events import EVENT_JOB_ERROR, EVENT_JOB_EXECUTED
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from apscheduler.triggers.cron import CronTrigger
from apscheduler.triggers.interval import IntervalTrigger
from apscheduler.events import EVENT_JOB_EXECUTED, EVENT_JOB_ERROR
# Configure logging
logging.basicConfig(level=logging.INFO)
@@ -70,9 +71,9 @@ class WorkflowTask:
workflow_id: str
name: str
task_type: str # analyze, align, discover_relations, notify, custom
config: Dict = field(default_factory=dict)
config: dict = field(default_factory=dict)
order: int = 0
depends_on: List[str] = field(default_factory=list)
depends_on: list[str] = field(default_factory=list)
timeout_seconds: int = 300
retry_count: int = 3
retry_delay: int = 5
@@ -94,12 +95,12 @@ class WebhookConfig:
webhook_type: str # feishu, dingtalk, slack, custom
url: str
secret: str = "" # 用于签名验证
headers: Dict = field(default_factory=dict)
headers: dict = field(default_factory=dict)
template: str = "" # 消息模板
is_active: bool = True
created_at: str = ""
updated_at: str = ""
last_used_at: Optional[str] = None
last_used_at: str | None = None
success_count: int = 0
fail_count: int = 0
@@ -119,15 +120,15 @@ class Workflow:
workflow_type: str
project_id: str
status: str = "active"
schedule: Optional[str] = None # cron expression or interval
schedule: str | None = None # cron expression or interval
schedule_type: str = "manual" # manual, cron, interval
config: Dict = field(default_factory=dict)
webhook_ids: List[str] = field(default_factory=list)
config: dict = field(default_factory=dict)
webhook_ids: list[str] = field(default_factory=list)
is_active: bool = True
created_at: str = ""
updated_at: str = ""
last_run_at: Optional[str] = None
next_run_at: Optional[str] = None
last_run_at: str | None = None
next_run_at: str | None = None
run_count: int = 0
success_count: int = 0
fail_count: int = 0
@@ -144,13 +145,13 @@ class WorkflowLog:
"""工作流执行日志"""
id: str
workflow_id: str
task_id: Optional[str] = None
task_id: str | None = None
status: str = "pending" # pending, running, success, failed, cancelled
start_time: Optional[str] = None
end_time: Optional[str] = None
start_time: str | None = None
end_time: str | None = None
duration_ms: int = 0
input_data: Dict = field(default_factory=dict)
output_data: Dict = field(default_factory=dict)
input_data: dict = field(default_factory=dict)
output_data: dict = field(default_factory=dict)
error_message: str = ""
created_at: str = ""
@@ -165,7 +166,7 @@ class WebhookNotifier:
def __init__(self):
    """Create the notifier and the HTTP client it keeps on ``self.http_client``.

    A single ``httpx.AsyncClient`` is built with ``timeout=30.0``.
    """
    async_client = httpx.AsyncClient(timeout=30.0)
    self.http_client = async_client
async def send(self, config: WebhookConfig, message: Dict) -> bool:
async def send(self, config: WebhookConfig, message: dict) -> bool:
"""发送 Webhook 通知"""
try:
webhook_type = WebhookType(config.webhook_type)
@@ -179,14 +180,14 @@ class WebhookNotifier:
else:
return await self._send_custom(config, message)
except (httpx.HTTPError, asyncio.TimeoutError) as e:
except (TimeoutError, httpx.HTTPError) as e:
logger.error(f"Webhook send failed: {e}")
return False
async def _send_feishu(self, config: WebhookConfig, message: Dict) -> bool:
async def _send_feishu(self, config: WebhookConfig, message: dict) -> bool:
"""发送飞书通知"""
import hashlib
import base64
import hashlib
import hmac
timestamp = str(int(datetime.now().timestamp()))
@@ -252,10 +253,10 @@ class WebhookNotifier:
return result.get("code") == 0
async def _send_dingtalk(self, config: WebhookConfig, message: Dict) -> bool:
async def _send_dingtalk(self, config: WebhookConfig, message: dict) -> bool:
"""发送钉钉通知"""
import hashlib
import base64
import hashlib
import hmac
import urllib.parse
@@ -314,7 +315,7 @@ class WebhookNotifier:
return result.get("errcode") == 0
async def _send_slack(self, config: WebhookConfig, message: Dict) -> bool:
async def _send_slack(self, config: WebhookConfig, message: dict) -> bool:
"""发送 Slack 通知"""
# Slack 直接支持标准 webhook 格式
payload = {
@@ -341,7 +342,7 @@ class WebhookNotifier:
return response.text == "ok"
async def _send_custom(self, config: WebhookConfig, message: Dict) -> bool:
async def _send_custom(self, config: WebhookConfig, message: dict) -> bool:
"""发送自定义 Webhook 通知"""
headers = {
"Content-Type": "application/json",
@@ -374,8 +375,8 @@ class WorkflowManager:
self.db = db_manager
self.scheduler = AsyncIOScheduler()
self.notifier = WebhookNotifier()
self._task_handlers: Dict[str, Callable] = {}
self._running_tasks: Dict[str, asyncio.Task] = {}
self._task_handlers: dict[str, Callable] = {}
self._running_tasks: dict[str, asyncio.Task] = {}
self._setup_default_handlers()
# 添加调度器事件监听
@@ -421,7 +422,7 @@ class WorkflowManager:
for workflow in workflows:
if workflow.schedule and workflow.is_active:
self._schedule_workflow(workflow)
except (httpx.HTTPError, asyncio.TimeoutError) as e:
except (TimeoutError, httpx.HTTPError) as e:
logger.error(f"Failed to load workflows: {e}")
def _schedule_workflow(self, workflow: Workflow):
@@ -458,7 +459,7 @@ class WorkflowManager:
"""调度器调用的工作流执行函数"""
try:
await self.execute_workflow(workflow_id)
except (httpx.HTTPError, asyncio.TimeoutError) as e:
except (TimeoutError, httpx.HTTPError) as e:
logger.error(f"Scheduled workflow execution failed: {e}")
def _on_job_executed(self, event):
@@ -497,7 +498,7 @@ class WorkflowManager:
finally:
conn.close()
def get_workflow(self, workflow_id: str) -> Optional[Workflow]:
def get_workflow(self, workflow_id: str) -> Workflow | None:
"""获取工作流"""
conn = self.db.get_conn()
try:
@@ -514,7 +515,7 @@ class WorkflowManager:
conn.close()
def list_workflows(self, project_id: str = None, status: str = None,
workflow_type: str = None) -> List[Workflow]:
workflow_type: str = None) -> list[Workflow]:
"""列出工作流"""
conn = self.db.get_conn()
try:
@@ -542,7 +543,7 @@ class WorkflowManager:
finally:
conn.close()
def update_workflow(self, workflow_id: str, **kwargs) -> Optional[Workflow]:
def update_workflow(self, workflow_id: str, **kwargs) -> Workflow | None:
"""更新工作流"""
conn = self.db.get_conn()
try:
@@ -648,7 +649,7 @@ class WorkflowManager:
finally:
conn.close()
def get_task(self, task_id: str) -> Optional[WorkflowTask]:
def get_task(self, task_id: str) -> WorkflowTask | None:
"""获取任务"""
conn = self.db.get_conn()
try:
@@ -664,7 +665,7 @@ class WorkflowManager:
finally:
conn.close()
def list_tasks(self, workflow_id: str) -> List[WorkflowTask]:
def list_tasks(self, workflow_id: str) -> list[WorkflowTask]:
"""列出工作流的所有任务"""
conn = self.db.get_conn()
try:
@@ -677,7 +678,7 @@ class WorkflowManager:
finally:
conn.close()
def update_task(self, task_id: str, **kwargs) -> Optional[WorkflowTask]:
def update_task(self, task_id: str, **kwargs) -> WorkflowTask | None:
"""更新任务"""
conn = self.db.get_conn()
try:
@@ -758,7 +759,7 @@ class WorkflowManager:
finally:
conn.close()
def get_webhook(self, webhook_id: str) -> Optional[WebhookConfig]:
def get_webhook(self, webhook_id: str) -> WebhookConfig | None:
"""获取 Webhook 配置"""
conn = self.db.get_conn()
try:
@@ -774,7 +775,7 @@ class WorkflowManager:
finally:
conn.close()
def list_webhooks(self) -> List[WebhookConfig]:
def list_webhooks(self) -> list[WebhookConfig]:
"""列出所有 Webhook 配置"""
conn = self.db.get_conn()
try:
@@ -786,7 +787,7 @@ class WorkflowManager:
finally:
conn.close()
def update_webhook(self, webhook_id: str, **kwargs) -> Optional[WebhookConfig]:
def update_webhook(self, webhook_id: str, **kwargs) -> WebhookConfig | None:
"""更新 Webhook 配置"""
conn = self.db.get_conn()
try:
@@ -889,7 +890,7 @@ class WorkflowManager:
finally:
conn.close()
def update_log(self, log_id: str, **kwargs) -> Optional[WorkflowLog]:
def update_log(self, log_id: str, **kwargs) -> WorkflowLog | None:
"""更新工作流日志"""
conn = self.db.get_conn()
try:
@@ -918,7 +919,7 @@ class WorkflowManager:
finally:
conn.close()
def get_log(self, log_id: str) -> Optional[WorkflowLog]:
def get_log(self, log_id: str) -> WorkflowLog | None:
"""获取日志"""
conn = self.db.get_conn()
try:
@@ -935,7 +936,7 @@ class WorkflowManager:
conn.close()
def list_logs(self, workflow_id: str = None, task_id: str = None,
status: str = None, limit: int = 100, offset: int = 0) -> List[WorkflowLog]:
status: str = None, limit: int = 100, offset: int = 0) -> list[WorkflowLog]:
"""列出工作流日志"""
conn = self.db.get_conn()
try:
@@ -966,7 +967,7 @@ class WorkflowManager:
finally:
conn.close()
def get_workflow_stats(self, workflow_id: str, days: int = 30) -> Dict:
def get_workflow_stats(self, workflow_id: str, days: int = 30) -> dict:
"""获取工作流统计"""
conn = self.db.get_conn()
try:
@@ -1037,7 +1038,7 @@ class WorkflowManager:
# ==================== Workflow Execution ====================
async def execute_workflow(self, workflow_id: str, input_data: Dict = None) -> Dict:
async def execute_workflow(self, workflow_id: str, input_data: dict = None) -> dict:
"""执行工作流"""
workflow = self.get_workflow(workflow_id)
if not workflow:
@@ -1100,7 +1101,7 @@ class WorkflowManager:
"duration_ms": duration
}
except (httpx.HTTPError, asyncio.TimeoutError) as e:
except (TimeoutError, httpx.HTTPError) as e:
logger.error(f"Workflow {workflow_id} execution failed: {e}")
# 更新日志为失败
@@ -1122,8 +1123,8 @@ class WorkflowManager:
raise
async def _execute_tasks_with_deps(self, tasks: List[WorkflowTask],
input_data: Dict, log_id: str) -> Dict:
async def _execute_tasks_with_deps(self, tasks: list[WorkflowTask],
input_data: dict, log_id: str) -> dict:
"""按依赖顺序执行任务"""
results = {}
completed_tasks = set()
@@ -1158,7 +1159,7 @@ class WorkflowManager:
try:
result = await self._execute_single_task(task, task_input, log_id)
break
except (httpx.HTTPError, asyncio.TimeoutError) as e:
except (TimeoutError, httpx.HTTPError) as e:
logger.error(f"Task {task.id} retry {attempt + 1} failed: {e}")
if attempt == task.retry_count - 1:
raise
@@ -1171,7 +1172,7 @@ class WorkflowManager:
return results
async def _execute_single_task(self, task: WorkflowTask,
input_data: Dict, log_id: str) -> Any:
input_data: dict, log_id: str) -> Any:
"""执行单个任务"""
handler = self._task_handlers.get(task.task_type)
if not handler:
@@ -1205,7 +1206,7 @@ class WorkflowManager:
return result
except asyncio.TimeoutError:
except TimeoutError:
self.update_log(
task_log.id,
status=TaskStatus.FAILED.value,
@@ -1224,7 +1225,7 @@ class WorkflowManager:
raise
async def _execute_default_workflow(self, workflow: Workflow,
input_data: Dict) -> Dict:
input_data: dict) -> dict:
"""执行默认工作流(根据类型)"""
workflow_type = WorkflowType(workflow.workflow_type)
@@ -1241,7 +1242,7 @@ class WorkflowManager:
# ==================== Default Task Handlers ====================
async def _handle_analyze_task(self, task: WorkflowTask, input_data: Dict) -> Dict:
async def _handle_analyze_task(self, task: WorkflowTask, input_data: dict) -> dict:
"""处理分析任务"""
project_id = input_data.get("project_id")
file_ids = input_data.get("file_ids", [])
@@ -1258,7 +1259,7 @@ class WorkflowManager:
"status": "completed"
}
async def _handle_align_task(self, task: WorkflowTask, input_data: Dict) -> Dict:
async def _handle_align_task(self, task: WorkflowTask, input_data: dict) -> dict:
"""处理实体对齐任务"""
project_id = input_data.get("project_id")
threshold = task.config.get("threshold", 0.85)
@@ -1276,7 +1277,7 @@ class WorkflowManager:
}
async def _handle_discover_relations_task(self, task: WorkflowTask,
input_data: Dict) -> Dict:
input_data: dict) -> dict:
"""处理关系发现任务"""
project_id = input_data.get("project_id")
@@ -1291,7 +1292,7 @@ class WorkflowManager:
"status": "completed"
}
async def _handle_notify_task(self, task: WorkflowTask, input_data: Dict) -> Dict:
async def _handle_notify_task(self, task: WorkflowTask, input_data: dict) -> dict:
"""处理通知任务"""
webhook_id = task.config.get("webhook_id")
message = task.config.get("message", {})
@@ -1319,7 +1320,7 @@ class WorkflowManager:
"success": success
}
async def _handle_custom_task(self, task: WorkflowTask, input_data: Dict) -> Dict:
async def _handle_custom_task(self, task: WorkflowTask, input_data: dict) -> dict:
"""处理自定义任务"""
# 自定义任务的具体逻辑由外部处理器实现
return {
@@ -1331,7 +1332,7 @@ class WorkflowManager:
# ==================== Default Workflow Implementations ====================
async def _auto_analyze_files(self, workflow: Workflow, input_data: Dict) -> Dict:
async def _auto_analyze_files(self, workflow: Workflow, input_data: dict) -> dict:
"""自动分析新上传的文件"""
project_id = workflow.project_id
@@ -1346,7 +1347,7 @@ class WorkflowManager:
"status": "completed"
}
async def _auto_align_entities(self, workflow: Workflow, input_data: Dict) -> Dict:
async def _auto_align_entities(self, workflow: Workflow, input_data: dict) -> dict:
"""自动实体对齐"""
project_id = workflow.project_id
threshold = workflow.config.get("threshold", 0.85)
@@ -1359,7 +1360,7 @@ class WorkflowManager:
"status": "completed"
}
async def _auto_discover_relations(self, workflow: Workflow, input_data: Dict) -> Dict:
async def _auto_discover_relations(self, workflow: Workflow, input_data: dict) -> dict:
"""自动关系发现"""
project_id = workflow.project_id
@@ -1370,7 +1371,7 @@ class WorkflowManager:
"status": "completed"
}
async def _generate_scheduled_report(self, workflow: Workflow, input_data: Dict) -> Dict:
async def _generate_scheduled_report(self, workflow: Workflow, input_data: dict) -> dict:
"""生成定时报告"""
project_id = workflow.project_id
report_type = workflow.config.get("report_type", "summary")
@@ -1385,7 +1386,7 @@ class WorkflowManager:
# ==================== Notification ====================
async def _send_workflow_notification(self, workflow: Workflow,
results: Dict, success: bool = True):
results: dict, success: bool = True):
"""发送工作流执行通知"""
if not workflow.webhook_ids:
return
@@ -1414,11 +1415,11 @@ class WorkflowManager:
try:
result = await self.notifier.send(webhook, message)
self.update_webhook_stats(webhook_id, result)
except (httpx.HTTPError, asyncio.TimeoutError) as e:
except (TimeoutError, httpx.HTTPError) as e:
logger.error(f"Failed to send notification to {webhook_id}: {e}")
def _build_feishu_message(self, workflow: Workflow, results: Dict,
success: bool) -> Dict:
def _build_feishu_message(self, workflow: Workflow, results: dict,
success: bool) -> dict:
"""构建飞书消息"""
status_text = "✅ 成功" if success else "❌ 失败"
@@ -1431,8 +1432,8 @@ class WorkflowManager:
]
}
def _build_dingtalk_message(self, workflow: Workflow, results: Dict,
success: bool) -> Dict:
def _build_dingtalk_message(self, workflow: Workflow, results: dict,
success: bool) -> dict:
"""构建钉钉消息"""
status_text = "✅ 成功" if success else "❌ 失败"
@@ -1453,8 +1454,8 @@ class WorkflowManager:
"""
}
def _build_slack_message(self, workflow: Workflow, results: Dict,
success: bool) -> Dict:
def _build_slack_message(self, workflow: Workflow, results: dict,
success: bool) -> dict:
"""构建 Slack 消息"""
color = "#36a64f" if success else "#ff0000"
status_text = "Success" if success else "Failed"