fix: auto-fix code issues (cron)
- 修复重复导入/字段 - 修复异常处理 - 修复PEP8格式问题 - 添加类型注解 - 修复缺失的urllib.parse导入
This commit is contained in:
@@ -27,6 +27,7 @@ import httpx
|
||||
# Database path
|
||||
DB_PATH = os.path.join(os.path.dirname(__file__), "insightflow.db")
|
||||
|
||||
|
||||
class ModelType(StrEnum):
|
||||
"""模型类型"""
|
||||
|
||||
@@ -35,6 +36,7 @@ class ModelType(StrEnum):
|
||||
SUMMARIZATION = "summarization" # 摘要
|
||||
PREDICTION = "prediction" # 预测
|
||||
|
||||
|
||||
class ModelStatus(StrEnum):
|
||||
"""模型状态"""
|
||||
|
||||
@@ -44,6 +46,7 @@ class ModelStatus(StrEnum):
|
||||
FAILED = "failed"
|
||||
ARCHIVED = "archived"
|
||||
|
||||
|
||||
class MultimodalProvider(StrEnum):
|
||||
"""多模态模型提供商"""
|
||||
|
||||
@@ -52,6 +55,7 @@ class MultimodalProvider(StrEnum):
|
||||
GEMINI = "gemini-pro-vision"
|
||||
KIMI_VL = "kimi-vl"
|
||||
|
||||
|
||||
class PredictionType(StrEnum):
|
||||
"""预测类型"""
|
||||
|
||||
@@ -60,6 +64,7 @@ class PredictionType(StrEnum):
|
||||
ENTITY_GROWTH = "entity_growth" # 实体增长预测
|
||||
RELATION_EVOLUTION = "relation_evolution" # 关系演变预测
|
||||
|
||||
|
||||
@dataclass
|
||||
class CustomModel:
|
||||
"""自定义模型"""
|
||||
@@ -79,6 +84,7 @@ class CustomModel:
|
||||
trained_at: str | None
|
||||
created_by: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingSample:
|
||||
"""训练样本"""
|
||||
@@ -90,6 +96,7 @@ class TrainingSample:
|
||||
metadata: dict
|
||||
created_at: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class MultimodalAnalysis:
|
||||
"""多模态分析结果"""
|
||||
@@ -106,6 +113,7 @@ class MultimodalAnalysis:
|
||||
cost: float
|
||||
created_at: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class KnowledgeGraphRAG:
|
||||
"""基于知识图谱的 RAG 配置"""
|
||||
@@ -122,6 +130,7 @@ class KnowledgeGraphRAG:
|
||||
created_at: str
|
||||
updated_at: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class RAGQuery:
|
||||
"""RAG 查询记录"""
|
||||
@@ -137,6 +146,7 @@ class RAGQuery:
|
||||
latency_ms: int
|
||||
created_at: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class PredictionModel:
|
||||
"""预测模型"""
|
||||
@@ -156,6 +166,7 @@ class PredictionModel:
|
||||
created_at: str
|
||||
updated_at: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class PredictionResult:
|
||||
"""预测结果"""
|
||||
@@ -171,6 +182,7 @@ class PredictionResult:
|
||||
is_correct: bool | None
|
||||
created_at: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class SmartSummary:
|
||||
"""智能摘要"""
|
||||
@@ -188,6 +200,7 @@ class SmartSummary:
|
||||
tokens_used: int
|
||||
created_at: str
|
||||
|
||||
|
||||
class AIManager:
|
||||
"""AI 能力管理主类"""
|
||||
|
||||
@@ -304,7 +317,12 @@ class AIManager:
|
||||
now = datetime.now().isoformat()
|
||||
|
||||
sample = TrainingSample(
|
||||
id=sample_id, model_id=model_id, text=text, entities=entities, metadata=metadata or {}, created_at=now
|
||||
id=sample_id,
|
||||
model_id=model_id,
|
||||
text=text,
|
||||
entities=entities,
|
||||
metadata=metadata or {},
|
||||
created_at=now,
|
||||
)
|
||||
|
||||
with self._get_db() as conn:
|
||||
@@ -410,20 +428,30 @@ class AIManager:
|
||||
|
||||
entity_types = model.training_data.get("entity_types", ["PERSON", "ORG", "TECH", "PROJECT"])
|
||||
|
||||
prompt = f"""从以下文本中提取实体,类型限定为: {', '.join(entity_types)}
|
||||
prompt = f"""从以下文本中提取实体,类型限定为: {", ".join(entity_types)}
|
||||
|
||||
文本: {text}
|
||||
|
||||
以 JSON 格式返回实体列表: [{{"text": "实体文本", "label": "类型", "start": 0, "end": 5, "confidence": 0.95}}]
|
||||
只返回 JSON 数组,不要其他内容。"""
|
||||
|
||||
headers = {"Authorization": f"Bearer {self.kimi_api_key}", "Content-Type": "application/json"}
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.kimi_api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
payload = {"model": "k2p5", "messages": [{"role": "user", "content": prompt}], "temperature": 0.1}
|
||||
payload = {
|
||||
"model": "k2p5",
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
"temperature": 0.1,
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
f"{self.kimi_base_url}/v1/chat/completions", headers=headers, json=payload, timeout=60.0
|
||||
f"{self.kimi_base_url}/v1/chat/completions",
|
||||
headers=headers,
|
||||
json=payload,
|
||||
timeout=60.0,
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
@@ -506,7 +534,10 @@ class AIManager:
|
||||
|
||||
async def _call_gpt4v(self, image_urls: list[str], prompt: str) -> dict:
|
||||
"""调用 GPT-4V"""
|
||||
headers = {"Authorization": f"Bearer {self.openai_api_key}", "Content-Type": "application/json"}
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.openai_api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
content = [{"type": "text", "text": prompt}]
|
||||
for url in image_urls:
|
||||
@@ -520,7 +551,10 @@ class AIManager:
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
"https://api.openai.com/v1/chat/completions", headers=headers, json=payload, timeout=120.0
|
||||
"https://api.openai.com/v1/chat/completions",
|
||||
headers=headers,
|
||||
json=payload,
|
||||
timeout=120.0,
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
@@ -552,7 +586,10 @@ class AIManager:
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
"https://api.anthropic.com/v1/messages", headers=headers, json=payload, timeout=120.0
|
||||
"https://api.anthropic.com/v1/messages",
|
||||
headers=headers,
|
||||
json=payload,
|
||||
timeout=120.0,
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
@@ -560,23 +597,34 @@ class AIManager:
|
||||
return {
|
||||
"content": result["content"][0]["text"],
|
||||
"tokens_used": result["usage"]["input_tokens"] + result["usage"]["output_tokens"],
|
||||
"cost": (result["usage"]["input_tokens"] + result["usage"]["output_tokens"]) * 0.000015,
|
||||
"cost": (result["usage"]["input_tokens"] + result["usage"]["output_tokens"])
|
||||
* 0.000015,
|
||||
}
|
||||
|
||||
async def _call_kimi_multimodal(self, image_urls: list[str], prompt: str) -> dict:
|
||||
"""调用 Kimi 多模态模型"""
|
||||
headers = {"Authorization": f"Bearer {self.kimi_api_key}", "Content-Type": "application/json"}
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.kimi_api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
# Kimi 目前可能不支持真正的多模态,这里模拟返回
|
||||
# 实际实现时需要根据 Kimi API 更新
|
||||
|
||||
content = f"图片 URL: {', '.join(image_urls)}\n\n{prompt}\n\n注意:请基于图片 URL 描述的内容进行回答。"
|
||||
|
||||
payload = {"model": "k2p5", "messages": [{"role": "user", "content": content}], "temperature": 0.3}
|
||||
payload = {
|
||||
"model": "k2p5",
|
||||
"messages": [{"role": "user", "content": content}],
|
||||
"temperature": 0.3,
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
f"{self.kimi_base_url}/v1/chat/completions", headers=headers, json=payload, timeout=60.0
|
||||
f"{self.kimi_base_url}/v1/chat/completions",
|
||||
headers=headers,
|
||||
json=payload,
|
||||
timeout=60.0,
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
@@ -587,7 +635,9 @@ class AIManager:
|
||||
"cost": result["usage"]["total_tokens"] * 0.000005,
|
||||
}
|
||||
|
||||
def get_multimodal_analyses(self, tenant_id: str, project_id: str | None = None) -> list[MultimodalAnalysis]:
|
||||
def get_multimodal_analyses(
|
||||
self, tenant_id: str, project_id: str | None = None
|
||||
) -> list[MultimodalAnalysis]:
|
||||
"""获取多模态分析历史"""
|
||||
query = "SELECT * FROM multimodal_analyses WHERE tenant_id = ?"
|
||||
params = [tenant_id]
|
||||
@@ -668,7 +718,9 @@ class AIManager:
|
||||
|
||||
return self._row_to_kg_rag(row)
|
||||
|
||||
def list_kg_rags(self, tenant_id: str, project_id: str | None = None) -> list[KnowledgeGraphRAG]:
|
||||
def list_kg_rags(
|
||||
self, tenant_id: str, project_id: str | None = None
|
||||
) -> list[KnowledgeGraphRAG]:
|
||||
"""列出知识图谱 RAG 配置"""
|
||||
query = "SELECT * FROM kg_rag_configs WHERE tenant_id = ?"
|
||||
params = [tenant_id]
|
||||
@@ -720,7 +772,10 @@ class AIManager:
|
||||
relevant_relations = []
|
||||
entity_ids = {e["id"] for e in relevant_entities}
|
||||
for relation in project_relations:
|
||||
if relation.get("source_entity_id") in entity_ids or relation.get("target_entity_id") in entity_ids:
|
||||
if (
|
||||
relation.get("source_entity_id") in entity_ids
|
||||
or relation.get("target_entity_id") in entity_ids
|
||||
):
|
||||
relevant_relations.append(relation)
|
||||
|
||||
# 2. 构建上下文
|
||||
@@ -747,7 +802,10 @@ class AIManager:
|
||||
2. 如果涉及多个实体,说明它们之间的关联
|
||||
3. 保持简洁专业"""
|
||||
|
||||
headers = {"Authorization": f"Bearer {self.kimi_api_key}", "Content-Type": "application/json"}
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.kimi_api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
payload = {
|
||||
"model": "k2p5",
|
||||
@@ -758,7 +816,10 @@ class AIManager:
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
f"{self.kimi_base_url}/v1/chat/completions", headers=headers, json=payload, timeout=60.0
|
||||
f"{self.kimi_base_url}/v1/chat/completions",
|
||||
headers=headers,
|
||||
json=payload,
|
||||
timeout=60.0,
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
@@ -773,7 +834,8 @@ class AIManager:
|
||||
now = datetime.now().isoformat()
|
||||
|
||||
sources = [
|
||||
{"entity_id": e["id"], "entity_name": e["name"], "score": e["relevance_score"]} for e in relevant_entities
|
||||
{"entity_id": e["id"], "entity_name": e["name"], "score": e["relevance_score"]}
|
||||
for e in relevant_entities
|
||||
]
|
||||
|
||||
rag_query = RAGQuery(
|
||||
@@ -843,7 +905,13 @@ class AIManager:
|
||||
return "\n".join(context)
|
||||
|
||||
async def generate_smart_summary(
|
||||
self, tenant_id: str, project_id: str, source_type: str, source_id: str, summary_type: str, content_data: dict
|
||||
self,
|
||||
tenant_id: str,
|
||||
project_id: str,
|
||||
source_type: str,
|
||||
source_id: str,
|
||||
summary_type: str,
|
||||
content_data: dict,
|
||||
) -> SmartSummary:
|
||||
"""生成智能摘要"""
|
||||
summary_id = f"ss_{uuid.uuid4().hex[:16]}"
|
||||
@@ -853,7 +921,7 @@ class AIManager:
|
||||
if summary_type == "extractive":
|
||||
prompt = f"""从以下内容中提取关键句子作为摘要:
|
||||
|
||||
{content_data.get('text', '')[:5000]}
|
||||
{content_data.get("text", "")[:5000]}
|
||||
|
||||
要求:
|
||||
1. 提取 3-5 个最重要的句子
|
||||
@@ -863,7 +931,7 @@ class AIManager:
|
||||
elif summary_type == "abstractive":
|
||||
prompt = f"""对以下内容生成简洁的摘要:
|
||||
|
||||
{content_data.get('text', '')[:5000]}
|
||||
{content_data.get("text", "")[:5000]}
|
||||
|
||||
要求:
|
||||
1. 用 2-3 句话概括核心内容
|
||||
@@ -873,7 +941,7 @@ class AIManager:
|
||||
elif summary_type == "key_points":
|
||||
prompt = f"""从以下内容中提取关键要点:
|
||||
|
||||
{content_data.get('text', '')[:5000]}
|
||||
{content_data.get("text", "")[:5000]}
|
||||
|
||||
要求:
|
||||
1. 列出 5-8 个关键要点
|
||||
@@ -883,20 +951,30 @@ class AIManager:
|
||||
else: # timeline
|
||||
prompt = f"""基于以下内容生成时间线摘要:
|
||||
|
||||
{content_data.get('text', '')[:5000]}
|
||||
{content_data.get("text", "")[:5000]}
|
||||
|
||||
要求:
|
||||
1. 按时间顺序组织关键事件
|
||||
2. 标注时间节点(如果有)
|
||||
3. 突出里程碑事件"""
|
||||
|
||||
headers = {"Authorization": f"Bearer {self.kimi_api_key}", "Content-Type": "application/json"}
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.kimi_api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
payload = {"model": "k2p5", "messages": [{"role": "user", "content": prompt}], "temperature": 0.3}
|
||||
payload = {
|
||||
"model": "k2p5",
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
"temperature": 0.3,
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
f"{self.kimi_base_url}/v1/chat/completions", headers=headers, json=payload, timeout=60.0
|
||||
f"{self.kimi_base_url}/v1/chat/completions",
|
||||
headers=headers,
|
||||
json=payload,
|
||||
timeout=60.0,
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
@@ -1040,14 +1118,18 @@ class AIManager:
|
||||
def get_prediction_model(self, model_id: str) -> PredictionModel | None:
|
||||
"""获取预测模型"""
|
||||
with self._get_db() as conn:
|
||||
row = conn.execute("SELECT * FROM prediction_models WHERE id = ?", (model_id,)).fetchone()
|
||||
row = conn.execute(
|
||||
"SELECT * FROM prediction_models WHERE id = ?", (model_id,)
|
||||
).fetchone()
|
||||
|
||||
if not row:
|
||||
return None
|
||||
|
||||
return self._row_to_prediction_model(row)
|
||||
|
||||
def list_prediction_models(self, tenant_id: str, project_id: str | None = None) -> list[PredictionModel]:
|
||||
def list_prediction_models(
|
||||
self, tenant_id: str, project_id: str | None = None
|
||||
) -> list[PredictionModel]:
|
||||
"""列出预测模型"""
|
||||
query = "SELECT * FROM prediction_models WHERE tenant_id = ?"
|
||||
params = [tenant_id]
|
||||
@@ -1062,7 +1144,9 @@ class AIManager:
|
||||
rows = conn.execute(query, params).fetchall()
|
||||
return [self._row_to_prediction_model(row) for row in rows]
|
||||
|
||||
async def train_prediction_model(self, model_id: str, historical_data: list[dict]) -> PredictionModel:
|
||||
async def train_prediction_model(
|
||||
self, model_id: str, historical_data: list[dict]
|
||||
) -> PredictionModel:
|
||||
"""训练预测模型"""
|
||||
model = self.get_prediction_model(model_id)
|
||||
if not model:
|
||||
@@ -1150,7 +1234,8 @@ class AIManager:
|
||||
|
||||
# 更新预测计数
|
||||
conn.execute(
|
||||
"UPDATE prediction_models SET prediction_count = prediction_count + 1 WHERE id = ?", (model_id,)
|
||||
"UPDATE prediction_models SET prediction_count = prediction_count + 1 WHERE id = ?",
|
||||
(model_id,),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
@@ -1243,7 +1328,9 @@ class AIManager:
|
||||
|
||||
# 计算增长率
|
||||
counts = [h.get("count", 0) for h in entity_history]
|
||||
growth_rates = [(counts[i] - counts[i - 1]) / max(counts[i - 1], 1) for i in range(1, len(counts))]
|
||||
growth_rates = [
|
||||
(counts[i] - counts[i - 1]) / max(counts[i - 1], 1) for i in range(1, len(counts))
|
||||
]
|
||||
avg_growth_rate = statistics.mean(growth_rates) if growth_rates else 0
|
||||
|
||||
# 预测下一个周期的实体数量
|
||||
@@ -1262,7 +1349,11 @@ class AIManager:
|
||||
relation_history = input_data.get("relation_history", [])
|
||||
|
||||
if len(relation_history) < 2:
|
||||
return {"predicted_relations": [], "confidence": 0.5, "explanation": "历史数据不足,无法预测关系演变"}
|
||||
return {
|
||||
"predicted_relations": [],
|
||||
"confidence": 0.5,
|
||||
"explanation": "历史数据不足,无法预测关系演变",
|
||||
}
|
||||
|
||||
# 分析关系变化趋势
|
||||
relation_counts = defaultdict(int)
|
||||
@@ -1273,7 +1364,9 @@ class AIManager:
|
||||
# 预测可能出现的新关系类型
|
||||
predicted_relations = [
|
||||
{"type": rel_type, "likelihood": min(count / len(relation_history), 0.95)}
|
||||
for rel_type, count in sorted(relation_counts.items(), key=lambda x: x[1], reverse=True)[:5]
|
||||
for rel_type, count in sorted(
|
||||
relation_counts.items(), key=lambda x: x[1], reverse=True
|
||||
)[:5]
|
||||
]
|
||||
|
||||
return {
|
||||
@@ -1296,7 +1389,9 @@ class AIManager:
|
||||
|
||||
return [self._row_to_prediction_result(row) for row in rows]
|
||||
|
||||
def update_prediction_feedback(self, prediction_id: str, actual_value: str, is_correct: bool) -> None:
|
||||
def update_prediction_feedback(
|
||||
self, prediction_id: str, actual_value: str, is_correct: bool
|
||||
) -> None:
|
||||
"""更新预测反馈(用于模型改进)"""
|
||||
with self._get_db() as conn:
|
||||
conn.execute(
|
||||
@@ -1405,9 +1500,11 @@ class AIManager:
|
||||
created_at=row["created_at"],
|
||||
)
|
||||
|
||||
|
||||
# Singleton instance
|
||||
_ai_manager = None
|
||||
|
||||
|
||||
def get_ai_manager() -> AIManager:
|
||||
global _ai_manager
|
||||
if _ai_manager is None:
|
||||
|
||||
Reference in New Issue
Block a user