Files
insightflow/backend/ai_manager.py
AutoFix Bot f9dfb03d9a fix: auto-fix code issues (cron)
- 修复PEP8格式问题 (black格式化)
- 修复ai_manager.py中的行长度问题
2026-03-04 00:09:28 +08:00

1539 lines
50 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
InsightFlow AI Manager - Phase 8 Task 4
AI 能力增强模块
- 自定义模型训练(领域特定实体识别)
- 多模态大模型集成GPT-4V、Claude 3
- 智能摘要与问答(基于知识图谱的 RAG
- 预测性分析(趋势预测、异常检测)
"""
import asyncio
import json
import os
import random
import re
import sqlite3
import statistics
import time
import uuid
from collections import defaultdict
from dataclasses import dataclass
from datetime import datetime
from enum import StrEnum
import httpx
# Path to the SQLite database file, resolved relative to this module's directory.
DB_PATH = os.path.join(os.path.dirname(__file__), "insightflow.db")
class ModelType(StrEnum):
    """Kinds of custom AI models managed by this module."""

    CUSTOM_NER = "custom_ner"  # domain-specific named-entity recognition
    MULTIMODAL = "multimodal"  # image/video/audio analysis
    SUMMARIZATION = "summarization"  # text summarization
    PREDICTION = "prediction"  # predictive analytics
class ModelStatus(StrEnum):
    """Lifecycle states of a custom model (PENDING -> TRAINING -> READY/FAILED)."""

    PENDING = "pending"
    TRAINING = "training"
    READY = "ready"
    FAILED = "failed"
    ARCHIVED = "archived"
class MultimodalProvider(StrEnum):
    """Supported multimodal LLM providers."""

    GPT4V = "gpt-4-vision"
    CLAUDE3 = "claude-3"
    GEMINI = "gemini-pro-vision"
    KIMI_VL = "kimi-vl"
class PredictionType(StrEnum):
    """Kinds of predictive analyses."""

    TREND = "trend"  # trend forecasting
    ANOMALY = "anomaly"  # anomaly detection
    ENTITY_GROWTH = "entity_growth"  # entity-count growth forecasting
    RELATION_EVOLUTION = "relation_evolution"  # relation-evolution forecasting
@dataclass
class CustomModel:
    """A tenant-scoped custom model record with its training lifecycle state."""

    id: str
    tenant_id: str
    name: str
    description: str
    model_type: ModelType
    status: ModelStatus
    training_data: dict  # training-data configuration (stored as JSON)
    hyperparameters: dict  # training hyperparameters (stored as JSON)
    metrics: dict  # metrics recorded after training (empty until trained)
    model_path: str | None  # path to the serialized model file; None until trained
    created_at: str
    updated_at: str
    trained_at: str | None
    created_by: str
@dataclass
class TrainingSample:
    """A single labeled training example for a custom model."""

    id: str
    model_id: str
    text: str
    # Span annotations, e.g. [{"start": 0, "end": 5, "label": "PERSON", "text": "张三"}]
    entities: list[dict]
    metadata: dict
    created_at: str
@dataclass
class MultimodalAnalysis:
    """The persisted result of one multimodal analysis call."""

    id: str
    tenant_id: str
    project_id: str
    provider: MultimodalProvider
    input_type: str  # one of: image, video, audio, mixed
    input_urls: list[str]
    prompt: str
    result: dict  # provider response (content, tokens, cost)
    tokens_used: int
    cost: float  # estimated cost in provider currency units
    created_at: str
@dataclass
class KnowledgeGraphRAG:
    """Configuration for knowledge-graph-backed retrieval-augmented generation."""

    id: str
    tenant_id: str
    project_id: str
    name: str
    description: str
    kg_config: dict  # knowledge-graph configuration
    retrieval_config: dict  # retrieval settings (e.g. top_k)
    generation_config: dict  # generation settings (e.g. temperature, max_tokens)
    is_active: bool
    created_at: str
    updated_at: str
@dataclass
class RAGQuery:
    """A single RAG query and its answer, stored for auditing."""

    id: str
    rag_id: str
    query: str
    context: dict  # retrieved entities/relations used to answer
    answer: str
    sources: list[dict]  # provenance info (entity id/name/relevance score)
    confidence: float
    tokens_used: int
    latency_ms: int
    created_at: str
@dataclass
class PredictionModel:
    """A predictive-analytics model definition."""

    id: str
    tenant_id: str
    project_id: str
    name: str
    prediction_type: PredictionType
    target_entity_type: str | None  # entity type the model targets, if any
    features: list[str]  # feature names used by the model
    model_config: dict  # model configuration (stored as JSON)
    accuracy: float | None  # None until trained
    last_trained_at: str | None
    prediction_count: int  # number of predictions served so far
    is_active: bool
    created_at: str
    updated_at: str
@dataclass
class PredictionResult:
    """One prediction produced by a PredictionModel, with optional feedback."""

    id: str
    model_id: str
    prediction_type: PredictionType
    target_id: str | None  # id of the predicted target, if any
    prediction_data: dict  # type-specific prediction payload
    confidence: float
    explanation: str  # human-readable rationale
    actual_value: str | None  # ground truth supplied later via feedback
    is_correct: bool | None  # feedback verdict; None until evaluated
    created_at: str
@dataclass
class SmartSummary:
    """An AI-generated summary of some project artifact."""

    id: str
    tenant_id: str
    project_id: str
    source_type: str  # one of: transcript, entity, project
    source_id: str
    summary_type: str  # one of: extractive, abstractive, key_points, timeline
    content: str
    key_points: list[str]
    entities_mentioned: list[str]
    confidence: float
    tokens_used: int
    created_at: str
class AIManager:
"""AI 能力管理主类"""
    def __init__(self, db_path: str = DB_PATH) -> None:
        """Initialize with a SQLite path and provider credentials from env vars.

        Empty-string credentials mean the corresponding provider is unavailable;
        callers fall back to the Kimi endpoint in that case.
        """
        self.db_path = db_path
        self.kimi_api_key = os.getenv("KIMI_API_KEY", "")
        self.kimi_base_url = os.getenv("KIMI_BASE_URL", "https://api.kimi.com/coding")
        self.openai_api_key = os.getenv("OPENAI_API_KEY", "")
        self.anthropic_api_key = os.getenv("ANTHROPIC_API_KEY", "")
def _get_db(self) -> sqlite3.Connection:
"""获取数据库连接"""
conn = sqlite3.connect(self.db_path)
conn.row_factory = sqlite3.Row
return conn
# ==================== 自定义模型训练 ====================
    def create_custom_model(
        self,
        tenant_id: str,
        name: str,
        description: str,
        model_type: ModelType,
        training_data: dict,
        hyperparameters: dict,
        created_by: str,
    ) -> CustomModel:
        """Create and persist a custom model record in status PENDING.

        Args:
            tenant_id: Owning tenant.
            name: Display name.
            description: Free-form description.
            model_type: Kind of model (see ModelType).
            training_data: Training-data configuration (serialized to JSON).
            hyperparameters: Training hyperparameters (serialized to JSON).
            created_by: Id of the creating user.

        Returns:
            The newly created CustomModel (not yet trained).
        """
        model_id = f"cm_{uuid.uuid4().hex[:16]}"
        now = datetime.now().isoformat()
        model = CustomModel(
            id=model_id,
            tenant_id=tenant_id,
            name=name,
            description=description,
            model_type=model_type,
            status=ModelStatus.PENDING,
            training_data=training_data,
            hyperparameters=hyperparameters,
            metrics={},
            model_path=None,
            created_at=now,
            updated_at=now,
            trained_at=None,
            created_by=created_by,
        )
        # dict fields are stored as JSON text columns
        with self._get_db() as conn:
            conn.execute(
                """
                INSERT INTO custom_models
                (id, tenant_id, name, description, model_type, status, training_data,
                 hyperparameters, metrics, model_path, created_at, updated_at,
                 trained_at, created_by)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    model.id,
                    model.tenant_id,
                    model.name,
                    model.description,
                    model.model_type.value,
                    model.status.value,
                    json.dumps(model.training_data),
                    json.dumps(model.hyperparameters),
                    json.dumps(model.metrics),
                    model.model_path,
                    model.created_at,
                    model.updated_at,
                    model.trained_at,
                    model.created_by,
                ),
            )
            conn.commit()
        return model
def get_custom_model(self, model_id: str) -> CustomModel | None:
"""获取自定义模型"""
with self._get_db() as conn:
row = conn.execute("SELECT * FROM custom_models WHERE id = ?", (model_id,)).fetchone()
if not row:
return None
return self._row_to_custom_model(row)
def list_custom_models(
self,
tenant_id: str,
model_type: ModelType | None = None,
status: ModelStatus | None = None,
) -> list[CustomModel]:
"""列出自定义模型"""
query = "SELECT * FROM custom_models WHERE tenant_id = ?"
params = [tenant_id]
if model_type:
query += " AND model_type = ?"
params.append(model_type.value)
if status:
query += " AND status = ?"
params.append(status.value)
query += " ORDER BY created_at DESC"
with self._get_db() as conn:
rows = conn.execute(query, params).fetchall()
return [self._row_to_custom_model(row) for row in rows]
    def add_training_sample(
        self,
        model_id: str,
        text: str,
        entities: list[dict],
        metadata: dict | None = None,
    ) -> TrainingSample:
        """Persist one labeled training sample for the given model.

        Args:
            model_id: Target custom model.
            text: Raw sample text.
            entities: Span annotations (start/end/label/text dicts).
            metadata: Optional extra data; defaults to an empty dict.

        Returns:
            The stored TrainingSample.
        """
        sample_id = f"ts_{uuid.uuid4().hex[:16]}"
        now = datetime.now().isoformat()
        sample = TrainingSample(
            id=sample_id,
            model_id=model_id,
            text=text,
            entities=entities,
            metadata=metadata or {},
            created_at=now,
        )
        with self._get_db() as conn:
            conn.execute(
                """
                INSERT INTO training_samples
                (id, model_id, text, entities, metadata, created_at)
                VALUES (?, ?, ?, ?, ?, ?)
                """,
                (
                    sample.id,
                    sample.model_id,
                    sample.text,
                    json.dumps(sample.entities),
                    json.dumps(sample.metadata),
                    sample.created_at,
                ),
            )
            conn.commit()
        return sample
def get_training_samples(self, model_id: str) -> list[TrainingSample]:
"""获取训练样本"""
with self._get_db() as conn:
rows = conn.execute(
"SELECT * FROM training_samples WHERE model_id = ? ORDER BY created_at",
(model_id,),
).fetchall()
return [self._row_to_training_sample(row) for row in rows]
async def train_custom_model(self, model_id: str) -> CustomModel:
"""训练自定义模型"""
model = self.get_custom_model(model_id)
if not model:
raise ValueError(f"Model {model_id} not found")
# 更新状态为训练中
with self._get_db() as conn:
conn.execute(
"UPDATE custom_models SET status = ?, updated_at = ? WHERE id = ?",
(ModelStatus.TRAINING.value, datetime.now().isoformat(), model_id),
)
conn.commit()
try:
# 获取训练样本
samples = self.get_training_samples(model_id)
if len(samples) < 10:
raise ValueError("至少需要 10 个训练样本")
# 模拟训练过程(实际项目中这里会调用训练框架如 spaCy、Hugging Face 等)
await asyncio.sleep(2) # 模拟训练时间
# 计算训练指标
metrics = {
"samples_count": len(samples),
"epochs": model.hyperparameters.get("epochs", 10),
"learning_rate": model.hyperparameters.get("learning_rate", 0.001),
"precision": round(0.85 + random.random() * 0.1, 4),
"recall": round(0.82 + random.random() * 0.1, 4),
"f1_score": round(0.84 + random.random() * 0.1, 4),
"training_time_seconds": 120,
}
# 保存模型(模拟)
model_path = f"models/{model_id}.bin"
os.makedirs("models", exist_ok=True)
now = datetime.now().isoformat()
with self._get_db() as conn:
conn.execute(
"""
UPDATE custom_models
SET status = ?, metrics = ?, model_path = ?, trained_at = ?, updated_at = ?
WHERE id = ?
""",
(ModelStatus.READY.value, json.dumps(metrics), model_path, now, now, model_id),
)
conn.commit()
return self.get_custom_model(model_id)
except Exception as e:
with self._get_db() as conn:
conn.execute(
"UPDATE custom_models SET status = ?, updated_at = ? WHERE id = ?",
(ModelStatus.FAILED.value, datetime.now().isoformat(), model_id),
)
conn.commit()
raise e
async def predict_with_custom_model(self, model_id: str, text: str) -> list[dict]:
"""使用自定义模型进行预测"""
model = self.get_custom_model(model_id)
if not model or model.status != ModelStatus.READY:
raise ValueError(f"Model {model_id} not ready")
# 模拟预测(实际项目中加载模型并进行推理)
# 这里使用 LLM 模拟领域特定实体识别
entity_types = model.training_data.get("entity_types", ["PERSON", "ORG", "TECH", "PROJECT"])
prompt = f"""从以下文本中提取实体,类型限定为: {", ".join(entity_types)}
文本: {text}
以 JSON 格式返回实体列表: [{{"text": "实体文本", "label": "类型", "start": 0, "end": 5, "confidence": 0.95}}]
只返回 JSON 数组,不要其他内容。"""
headers = {
"Authorization": f"Bearer {self.kimi_api_key}",
"Content-Type": "application/json",
}
payload = {
"model": "k2p5",
"messages": [{"role": "user", "content": prompt}],
"temperature": 0.1,
}
async with httpx.AsyncClient() as client:
response = await client.post(
f"{self.kimi_base_url}/v1/chat/completions",
headers=headers,
json=payload,
timeout=60.0,
)
response.raise_for_status()
result = response.json()
content = result["choices"][0]["message"]["content"]
# 解析 JSON
json_match = re.search(r"\[.*?\]", content, re.DOTALL)
if json_match:
try:
return json.loads(json_match.group())
except (json.JSONDecodeError, ValueError):
pass
return []
# ==================== 多模态大模型集成 ====================
    async def analyze_multimodal(
        self,
        tenant_id: str,
        project_id: str,
        provider: MultimodalProvider,
        input_type: str,
        input_urls: list[str],
        prompt: str,
    ) -> MultimodalAnalysis:
        """Run a multimodal analysis with the chosen provider and persist it.

        Falls back to the Kimi endpoint when the requested provider's API key
        is not configured (or for providers without a dedicated client here).

        Args:
            tenant_id / project_id: Owning tenant and project.
            provider: Preferred provider.
            input_type: One of image, video, audio, mixed.
            input_urls: URLs of the media inputs.
            prompt: Instruction for the model.

        Returns:
            The persisted MultimodalAnalysis including token/cost accounting.
        """
        analysis_id = f"ma_{uuid.uuid4().hex[:16]}"
        now = datetime.now().isoformat()
        # Route to the provider-specific client; only use a provider whose key is set.
        if provider == MultimodalProvider.GPT4V and self.openai_api_key:
            result = await self._call_gpt4v(input_urls, prompt)
        elif provider == MultimodalProvider.CLAUDE3 and self.anthropic_api_key:
            result = await self._call_claude3(input_urls, prompt)
        else:
            # Default: Kimi (also the fallback when keys are missing).
            result = await self._call_kimi_multimodal(input_urls, prompt)
        analysis = MultimodalAnalysis(
            id=analysis_id,
            tenant_id=tenant_id,
            project_id=project_id,
            provider=provider,
            input_type=input_type,
            input_urls=input_urls,
            prompt=prompt,
            result=result,
            tokens_used=result.get("tokens_used", 0),
            cost=result.get("cost", 0.0),
            created_at=now,
        )
        with self._get_db() as conn:
            conn.execute(
                """
                INSERT INTO multimodal_analyses
                (id, tenant_id, project_id, provider, input_type, input_urls, prompt,
                 result, tokens_used, cost, created_at)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    analysis.id,
                    analysis.tenant_id,
                    analysis.project_id,
                    analysis.provider.value,
                    analysis.input_type,
                    json.dumps(analysis.input_urls),
                    analysis.prompt,
                    json.dumps(analysis.result),
                    analysis.tokens_used,
                    analysis.cost,
                    analysis.created_at,
                ),
            )
            conn.commit()
        return analysis
    async def _call_gpt4v(self, image_urls: list[str], prompt: str) -> dict:
        """Call the OpenAI GPT-4V chat API with text + image_url content parts.

        Returns a dict with `content`, `tokens_used`, and a rough `cost`
        estimate (flat per-token rate — TODO confirm against current pricing).

        Raises:
            httpx.HTTPStatusError: On a non-2xx API response.
        """
        headers = {
            "Authorization": f"Bearer {self.openai_api_key}",
            "Content-Type": "application/json",
        }
        # Message content: the prompt first, then one image_url part per URL.
        content = [{"type": "text", "text": prompt}]
        content.extend([{"type": "image_url", "image_url": {"url": url}} for url in image_urls])
        payload = {
            "model": "gpt-4-vision-preview",
            "messages": [{"role": "user", "content": content}],
            "max_tokens": 2000,
        }
        async with httpx.AsyncClient() as client:
            response = await client.post(
                "https://api.openai.com/v1/chat/completions",
                headers=headers,
                json=payload,
                timeout=120.0,
            )
            response.raise_for_status()
            result = response.json()
        return {
            "content": result["choices"][0]["message"]["content"],
            "tokens_used": result["usage"]["total_tokens"],
            "cost": result["usage"]["total_tokens"] * 0.00001,  # rough cost estimate
        }
    async def _call_claude3(self, image_urls: list[str], prompt: str) -> dict:
        """Call the Anthropic Messages API with image parts followed by the prompt.

        Returns a dict with `content`, `tokens_used` (input + output), and a
        rough `cost` estimate (flat per-token rate — TODO confirm pricing).

        Raises:
            httpx.HTTPStatusError: On a non-2xx API response.
        """
        headers = {
            "x-api-key": self.anthropic_api_key,
            "Content-Type": "application/json",
            "anthropic-version": "2023-06-01",
        }
        # Anthropic expects image blocks before the text instruction.
        content = [{"type": "image", "source": {"type": "url", "url": url}} for url in image_urls]
        content.append({"type": "text", "text": prompt})
        payload = {
            "model": "claude-3-opus-20240229",
            "max_tokens": 2000,
            "messages": [{"role": "user", "content": content}],
        }
        async with httpx.AsyncClient() as client:
            response = await client.post(
                "https://api.anthropic.com/v1/messages",
                headers=headers,
                json=payload,
                timeout=120.0,
            )
            response.raise_for_status()
            result = response.json()
        return {
            "content": result["content"][0]["text"],
            "tokens_used": result["usage"]["input_tokens"] + result["usage"]["output_tokens"],
            "cost": (result["usage"]["input_tokens"] + result["usage"]["output_tokens"])
            * 0.000015,  # rough cost estimate
        }
    async def _call_kimi_multimodal(self, image_urls: list[str], prompt: str) -> dict:
        """Call the Kimi chat API, passing image URLs inline as text.

        NOTE: this does not upload images — the URLs are embedded in the text
        prompt as a stand-in until true multimodal support is wired up.

        Raises:
            httpx.HTTPStatusError: On a non-2xx API response.
        """
        headers = {
            "Authorization": f"Bearer {self.kimi_api_key}",
            "Content-Type": "application/json",
        }
        # Kimi may not support true multimodal input yet; simulate by inlining URLs.
        # Update this once the Kimi API exposes image content parts.
        content = f"图片 URL: {', '.join(image_urls)}\n\n{prompt}\n\n注意:请基于图片 URL 描述的内容进行回答。"
        payload = {
            "model": "k2p5",
            "messages": [{"role": "user", "content": content}],
            "temperature": 0.3,
        }
        async with httpx.AsyncClient() as client:
            response = await client.post(
                f"{self.kimi_base_url}/v1/chat/completions",
                headers=headers,
                json=payload,
                timeout=60.0,
            )
            response.raise_for_status()
            result = response.json()
        return {
            "content": result["choices"][0]["message"]["content"],
            "tokens_used": result["usage"]["total_tokens"],
            "cost": result["usage"]["total_tokens"] * 0.000005,  # rough cost estimate
        }
def get_multimodal_analyses(
self,
tenant_id: str,
project_id: str | None = None,
) -> list[MultimodalAnalysis]:
"""获取多模态分析历史"""
query = "SELECT * FROM multimodal_analyses WHERE tenant_id = ?"
params = [tenant_id]
if project_id:
query += " AND project_id = ?"
params.append(project_id)
query += " ORDER BY created_at DESC"
with self._get_db() as conn:
rows = conn.execute(query, params).fetchall()
return [self._row_to_multimodal_analysis(row) for row in rows]
# ==================== 智能摘要与问答(基于知识图谱的 RAG ====================
    def create_kg_rag(
        self,
        tenant_id: str,
        project_id: str,
        name: str,
        description: str,
        kg_config: dict,
        retrieval_config: dict,
        generation_config: dict,
    ) -> KnowledgeGraphRAG:
        """Create and persist a knowledge-graph RAG configuration (active by default).

        Args:
            tenant_id / project_id: Owning tenant and project.
            name / description: Display metadata.
            kg_config: Knowledge-graph settings (serialized to JSON).
            retrieval_config: Retrieval settings, e.g. top_k (serialized to JSON).
            generation_config: Generation settings, e.g. temperature (serialized to JSON).

        Returns:
            The stored KnowledgeGraphRAG.
        """
        rag_id = f"kgr_{uuid.uuid4().hex[:16]}"
        now = datetime.now().isoformat()
        rag = KnowledgeGraphRAG(
            id=rag_id,
            tenant_id=tenant_id,
            project_id=project_id,
            name=name,
            description=description,
            kg_config=kg_config,
            retrieval_config=retrieval_config,
            generation_config=generation_config,
            is_active=True,
            created_at=now,
            updated_at=now,
        )
        with self._get_db() as conn:
            conn.execute(
                """
                INSERT INTO kg_rag_configs
                (id, tenant_id, project_id, name, description, kg_config, retrieval_config,
                 generation_config, is_active, created_at, updated_at)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    rag.id,
                    rag.tenant_id,
                    rag.project_id,
                    rag.name,
                    rag.description,
                    json.dumps(rag.kg_config),
                    json.dumps(rag.retrieval_config),
                    json.dumps(rag.generation_config),
                    rag.is_active,
                    rag.created_at,
                    rag.updated_at,
                ),
            )
            conn.commit()
        return rag
def get_kg_rag(self, rag_id: str) -> KnowledgeGraphRAG | None:
"""获取知识图谱 RAG 配置"""
with self._get_db() as conn:
row = conn.execute("SELECT * FROM kg_rag_configs WHERE id = ?", (rag_id,)).fetchone()
if not row:
return None
return self._row_to_kg_rag(row)
def list_kg_rags(
self,
tenant_id: str,
project_id: str | None = None,
) -> list[KnowledgeGraphRAG]:
"""列出知识图谱 RAG 配置"""
query = "SELECT * FROM kg_rag_configs WHERE tenant_id = ?"
params = [tenant_id]
if project_id:
query += " AND project_id = ?"
params.append(project_id)
query += " ORDER BY created_at DESC"
with self._get_db() as conn:
rows = conn.execute(query, params).fetchall()
return [self._row_to_kg_rag(row) for row in rows]
    async def query_kg_rag(
        self,
        rag_id: str,
        query: str,
        project_entities: list[dict],
        project_relations: list[dict],
    ) -> RAGQuery:
        """Answer a question using knowledge-graph retrieval + LLM generation.

        Pipeline: (1) score entities by keyword overlap with the query and keep
        the top_k; (2) keep relations touching those entities; (3) prompt the
        Kimi chat API with the assembled context; (4) persist the query record.

        Args:
            rag_id: RAG configuration to use.
            query: The user question.
            project_entities: Candidate entity dicts (expects id/name/definition).
            project_relations: Candidate relation dicts (expects
                source_entity_id/target_entity_id plus display fields).

        Returns:
            The persisted RAGQuery (answer, sources, confidence, latency).

        Raises:
            ValueError: If the RAG configuration does not exist.
            httpx.HTTPStatusError: On a non-2xx API response.
        """
        start_time = time.time()
        rag = self.get_kg_rag(rag_id)
        if not rag:
            raise ValueError(f"RAG config {rag_id} not found")
        # 1. Retrieve relevant entities and relations.
        retrieval_config = rag.retrieval_config
        top_k = retrieval_config.get("top_k", 5)
        # Simple lexical retrieval: name match scores 0.5, definition match 0.3.
        query_lower = query.lower()
        relevant_entities = []
        for entity in project_entities:
            score = 0
            name = entity.get("name", "").lower()
            definition = entity.get("definition", "").lower()
            if name in query_lower or any(word in name for word in query_lower.split()):
                score += 0.5
            if any(word in definition for word in query_lower.split()):
                score += 0.3
            if score > 0:
                relevant_entities.append({**entity, "relevance_score": score})
        relevant_entities.sort(key=lambda x: x["relevance_score"], reverse=True)
        relevant_entities = relevant_entities[:top_k]
        # Keep relations with at least one endpoint among the retrieved entities.
        entity_ids = {e["id"] for e in relevant_entities}
        relevant_relations = [
            relation
            for relation in project_relations
            if (
                relation.get("source_entity_id") in entity_ids
                or relation.get("target_entity_id") in entity_ids
            )
        ]
        # 2. Build the prompt context (relations capped at 10 in the stored context).
        context = {"entities": relevant_entities, "relations": relevant_relations[:10]}
        context_text = self._build_kg_context(relevant_entities, relevant_relations)
        # 3. Generate the answer.
        generation_config = rag.generation_config
        temperature = generation_config.get("temperature", 0.3)
        max_tokens = generation_config.get("max_tokens", 1000)
        prompt = f"""基于以下知识图谱信息回答问题:
## 知识图谱上下文
{context_text}
## 用户问题
{query}
请基于上述知识图谱信息回答问题。如果信息不足,请明确说明。
回答应该:
1. 准确引用知识图谱中的实体和关系
2. 如果涉及多个实体,说明它们之间的关联
3. 保持简洁专业"""
        headers = {
            "Authorization": f"Bearer {self.kimi_api_key}",
            "Content-Type": "application/json",
        }
        payload = {
            "model": "k2p5",
            "messages": [{"role": "user", "content": prompt}],
            "temperature": temperature,
            "max_tokens": max_tokens,
        }
        async with httpx.AsyncClient() as client:
            response = await client.post(
                f"{self.kimi_base_url}/v1/chat/completions",
                headers=headers,
                json=payload,
                timeout=60.0,
            )
            response.raise_for_status()
            result = response.json()
        answer = result["choices"][0]["message"]["content"]
        tokens_used = result["usage"]["total_tokens"]
        latency_ms = int((time.time() - start_time) * 1000)
        # 4. Persist the query record; confidence = mean retrieval score.
        query_id = f"rq_{uuid.uuid4().hex[:16]}"
        now = datetime.now().isoformat()
        sources = [
            {"entity_id": e["id"], "entity_name": e["name"], "score": e["relevance_score"]}
            for e in relevant_entities
        ]
        rag_query = RAGQuery(
            id=query_id,
            rag_id=rag_id,
            query=query,
            context=context,
            answer=answer,
            sources=sources,
            confidence=(
                sum(e["relevance_score"] for e in relevant_entities) / len(relevant_entities)
                if relevant_entities
                else 0
            ),
            tokens_used=tokens_used,
            latency_ms=latency_ms,
            created_at=now,
        )
        with self._get_db() as conn:
            conn.execute(
                """
                INSERT INTO rag_queries
                (id, rag_id, query, context, answer, sources, confidence,
                 tokens_used, latency_ms, created_at)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    rag_query.id,
                    rag_query.rag_id,
                    rag_query.query,
                    json.dumps(rag_query.context),
                    rag_query.answer,
                    json.dumps(rag_query.sources),
                    rag_query.confidence,
                    rag_query.tokens_used,
                    rag_query.latency_ms,
                    rag_query.created_at,
                ),
            )
            conn.commit()
        return rag_query
def _build_kg_context(self, entities: list[dict], relations: list[dict]) -> str:
"""构建知识图谱上下文文本"""
context = []
if entities:
context.append("### 相关实体")
for entity in entities:
name = entity.get("name", "")
entity_type = entity.get("type", "")
definition = entity.get("definition", "")
context.append(f"- **{name}** ({entity_type}): {definition}")
if relations:
context.append("\n### 相关关系")
for relation in relations:
source = relation.get("source_name", "")
target = relation.get("target_name", "")
rel_type = relation.get("relation_type", "")
evidence = relation.get("evidence", "")
context.append(f"- {source} --[{rel_type}]--> {target}")
if evidence:
context.append(f" - 依据: {evidence[:100]}...")
return "\n".join(context)
async def generate_smart_summary(
self,
tenant_id: str,
project_id: str,
source_type: str,
source_id: str,
summary_type: str,
content_data: dict,
) -> SmartSummary:
"""生成智能摘要"""
summary_id = f"ss_{uuid.uuid4().hex[:16]}"
now = datetime.now().isoformat()
# 根据摘要类型生成不同的提示
if summary_type == "extractive":
prompt = f"""从以下内容中提取关键句子作为摘要:
{content_data.get("text", "")[:5000]}
要求:
1. 提取 3-5 个最重要的句子
2. 保持原文表述
3. 以 JSON 格式返回: {{"summary": "摘要内容", "key_points": ["要点1", "要点2"]}}"""
elif summary_type == "abstractive":
prompt = f"""对以下内容生成简洁的摘要:
{content_data.get("text", "")[:5000]}
要求:
1. 用 2-3 句话概括核心内容
2. 使用自己的语言重新表述
3. 包含关键实体和概念"""
elif summary_type == "key_points":
prompt = f"""从以下内容中提取关键要点:
{content_data.get("text", "")[:5000]}
要求:
1. 列出 5-8 个关键要点
2. 每个要点简洁明了
3. 以 JSON 格式返回: {{"key_points": ["要点1", "要点2", ...]}}"""
else: # timeline
prompt = f"""基于以下内容生成时间线摘要:
{content_data.get("text", "")[:5000]}
要求:
1. 按时间顺序组织关键事件
2. 标注时间节点(如果有)
3. 突出里程碑事件"""
headers = {
"Authorization": f"Bearer {self.kimi_api_key}",
"Content-Type": "application/json",
}
payload = {
"model": "k2p5",
"messages": [{"role": "user", "content": prompt}],
"temperature": 0.3,
}
async with httpx.AsyncClient() as client:
response = await client.post(
f"{self.kimi_base_url}/v1/chat/completions",
headers=headers,
json=payload,
timeout=60.0,
)
response.raise_for_status()
result = response.json()
content = result["choices"][0]["message"]["content"]
tokens_used = result["usage"]["total_tokens"]
# 解析关键要点
key_points = []
# 尝试从 JSON 中提取
json_match = re.search(r"\{.*?\}", content, re.DOTALL)
if json_match:
try:
data = json.loads(json_match.group())
key_points = data.get("key_points", [])
if "summary" in data:
content = data["summary"]
except (json.JSONDecodeError, ValueError):
pass
# 如果没有提取到关键要点,从文本中提取
if not key_points:
lines = content.split("\n")
key_points = [
line.strip("- ").strip()
for line in lines
if line.strip().startswith("-") or line.strip().startswith("")
]
if not key_points:
key_points = [content[:200] + "..."] if len(content) > 200 else [content]
# 提取提及的实体
entities_mentioned = content_data.get("entities", [])
entity_names = [e.get("name", "") for e in entities_mentioned[:10]]
summary = SmartSummary(
id=summary_id,
tenant_id=tenant_id,
project_id=project_id,
source_type=source_type,
source_id=source_id,
summary_type=summary_type,
content=content,
key_points=key_points[:8],
entities_mentioned=entity_names,
confidence=0.85,
tokens_used=tokens_used,
created_at=now,
)
with self._get_db() as conn:
conn.execute(
"""
INSERT INTO smart_summaries
(id, tenant_id, project_id, source_type, source_id, summary_type, content,
key_points, entities_mentioned, confidence, tokens_used, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(
summary.id,
summary.tenant_id,
summary.project_id,
summary.source_type,
summary.source_id,
summary.summary_type,
summary.content,
json.dumps(summary.key_points),
json.dumps(summary.entities_mentioned),
summary.confidence,
summary.tokens_used,
summary.created_at,
),
)
conn.commit()
return summary
# ==================== 预测性分析 ====================
    def create_prediction_model(
        self,
        tenant_id: str,
        project_id: str,
        name: str,
        prediction_type: PredictionType,
        target_entity_type: str | None,
        features: list[str],
        model_config: dict,
    ) -> PredictionModel:
        """Create and persist a prediction model (active, untrained).

        Args:
            tenant_id / project_id: Owning tenant and project.
            name: Display name.
            prediction_type: Kind of prediction (see PredictionType).
            target_entity_type: Entity type the model targets, if any.
            features: Feature names (serialized to JSON).
            model_config: Model configuration (serialized to JSON).

        Returns:
            The stored PredictionModel (accuracy is None until trained).
        """
        model_id = f"pm_{uuid.uuid4().hex[:16]}"
        now = datetime.now().isoformat()
        model = PredictionModel(
            id=model_id,
            tenant_id=tenant_id,
            project_id=project_id,
            name=name,
            prediction_type=prediction_type,
            target_entity_type=target_entity_type,
            features=features,
            model_config=model_config,
            accuracy=None,
            last_trained_at=None,
            prediction_count=0,
            is_active=True,
            created_at=now,
            updated_at=now,
        )
        with self._get_db() as conn:
            conn.execute(
                """
                INSERT INTO prediction_models
                (id, tenant_id, project_id, name, prediction_type, target_entity_type,
                 features, model_config, accuracy, last_trained_at, prediction_count,
                 is_active, created_at, updated_at)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    model.id,
                    model.tenant_id,
                    model.project_id,
                    model.name,
                    model.prediction_type.value,
                    model.target_entity_type,
                    json.dumps(model.features),
                    json.dumps(model.model_config),
                    model.accuracy,
                    model.last_trained_at,
                    model.prediction_count,
                    model.is_active,
                    model.created_at,
                    model.updated_at,
                ),
            )
            conn.commit()
        return model
def get_prediction_model(self, model_id: str) -> PredictionModel | None:
"""获取预测模型"""
with self._get_db() as conn:
row = conn.execute(
"SELECT * FROM prediction_models WHERE id = ?",
(model_id,),
).fetchone()
if not row:
return None
return self._row_to_prediction_model(row)
def list_prediction_models(
self,
tenant_id: str,
project_id: str | None = None,
) -> list[PredictionModel]:
"""列出预测模型"""
query = "SELECT * FROM prediction_models WHERE tenant_id = ?"
params = [tenant_id]
if project_id:
query += " AND project_id = ?"
params.append(project_id)
query += " ORDER BY created_at DESC"
with self._get_db() as conn:
rows = conn.execute(query, params).fetchall()
return [self._row_to_prediction_model(row) for row in rows]
    async def train_prediction_model(
        self,
        model_id: str,
        historical_data: list[dict],
    ) -> PredictionModel:
        """Train (simulate training of) a prediction model and record its accuracy.

        NOTE(review): `historical_data` is currently unused — training is
        simulated by a short sleep and a random accuracy in [0.75, 0.95).

        Raises:
            ValueError: If the model does not exist.
        """
        model = self.get_prediction_model(model_id)
        if not model:
            raise ValueError(f"Prediction model {model_id} not found")
        # Simulated training delay.
        await asyncio.sleep(1)
        # Simulated accuracy.
        accuracy = round(0.75 + random.random() * 0.2, 4)
        now = datetime.now().isoformat()
        with self._get_db() as conn:
            conn.execute(
                """
                UPDATE prediction_models
                SET accuracy = ?, last_trained_at = ?, updated_at = ?
                WHERE id = ?
                """,
                (accuracy, now, now, model_id),
            )
            conn.commit()
        return self.get_prediction_model(model_id)
async def predict(self, model_id: str, input_data: dict) -> PredictionResult:
"""进行预测"""
model = self.get_prediction_model(model_id)
if not model or not model.is_active:
raise ValueError(f"Prediction model {model_id} not available")
prediction_id = f"pr_{uuid.uuid4().hex[:16]}"
now = datetime.now().isoformat()
# 根据预测类型进行不同的预测逻辑
if model.prediction_type == PredictionType.TREND:
prediction_data = self._predict_trend(input_data, model)
elif model.prediction_type == PredictionType.ANOMALY:
prediction_data = self._detect_anomaly(input_data, model)
elif model.prediction_type == PredictionType.ENTITY_GROWTH:
prediction_data = self._predict_entity_growth(input_data, model)
elif model.prediction_type == PredictionType.RELATION_EVOLUTION:
prediction_data = self._predict_relation_evolution(input_data, model)
else:
prediction_data = {"value": "unknown", "confidence": 0}
confidence = prediction_data.get("confidence", 0.8)
explanation = prediction_data.get("explanation", "基于历史数据模式预测")
result = PredictionResult(
id=prediction_id,
model_id=model_id,
prediction_type=model.prediction_type,
target_id=input_data.get("target_id"),
prediction_data=prediction_data,
confidence=confidence,
explanation=explanation,
actual_value=None,
is_correct=None,
created_at=now,
)
with self._get_db() as conn:
conn.execute(
"""
INSERT INTO prediction_results
(id, model_id, prediction_type, target_id, prediction_data, confidence,
explanation, actual_value, is_correct, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(
result.id,
result.model_id,
result.prediction_type.value,
result.target_id,
json.dumps(result.prediction_data),
result.confidence,
result.explanation,
result.actual_value,
result.is_correct,
result.created_at,
),
)
# 更新预测计数
conn.execute(
"UPDATE prediction_models SET prediction_count = prediction_count + 1 WHERE id = ?",
(model_id,),
)
conn.commit()
return result
def _predict_trend(self, input_data: dict, model: PredictionModel) -> dict:
"""趋势预测"""
historical_values = input_data.get("historical_values", [])
if len(historical_values) < 2:
return {
"predicted_value": 0,
"trend": "stable",
"confidence": 0.5,
"explanation": "历史数据不足,无法准确预测趋势",
}
# 简单线性趋势预测 - 使用最小二乘法计算斜率
n = len(historical_values)
x = list(range(n))
y = historical_values
# 计算均值
mean_x = sum(x) / n
mean_y = sum(y) / n
# 计算斜率 (最小二乘法)
numerator = sum((x[i] - mean_x) * (y[i] - mean_y) for i in range(n))
denominator = sum((x[i] - mean_x) ** 2 for i in range(n))
slope = numerator / denominator if denominator != 0 else 0
# 预测下一个值
next_value = y[-1] + slope
trend = "increasing" if slope > 0.01 else "decreasing" if slope < -0.01 else "stable"
return {
"predicted_value": round(next_value, 2),
"trend": trend,
"slope": round(slope, 4),
"confidence": min(0.95, 0.6 + len(historical_values) * 0.02),
"explanation": f"基于{len(historical_values)}个历史数据点,预测趋势为{trend}",
}
def _detect_anomaly(self, input_data: dict, model: PredictionModel) -> dict:
"""异常检测"""
value = input_data.get("value")
historical_values = input_data.get("historical_values", [])
if not historical_values or value is None:
return {
"is_anomaly": False,
"anomaly_score": 0,
"confidence": 0.5,
"explanation": "数据不足,无法进行异常检测",
}
# 计算均值和标准差
mean = statistics.mean(historical_values)
std = statistics.stdev(historical_values) if len(historical_values) > 1 else 0
if std == 0:
is_anomaly = value != mean
z_score = 0 if value == mean else 3
else:
z_score = abs(value - mean) / std
is_anomaly = z_score > 2.5 # 2.5 个标准差视为异常
return {
"is_anomaly": is_anomaly,
"anomaly_score": round(min(z_score / 3, 1.0), 2),
"z_score": round(z_score, 2),
"mean": round(mean, 2),
"std": round(std, 2),
"confidence": min(0.95, 0.7 + len(historical_values) * 0.01),
"explanation": f"当前值偏离均值{z_score:.2f}个标准差,{'检测到异常' if is_anomaly else '处于正常范围'}",
}
def _predict_entity_growth(self, input_data: dict, model: PredictionModel) -> dict:
"""实体增长预测"""
entity_history = input_data.get("entity_history", [])
if len(entity_history) < 3:
return {
"predicted_count": len(entity_history),
"growth_rate": 0,
"confidence": 0.5,
"explanation": "历史数据不足,无法预测增长趋势",
}
# 计算增长率
counts = [h.get("count", 0) for h in entity_history]
growth_rates = [
(counts[i] - counts[i - 1]) / max(counts[i - 1], 1) for i in range(1, len(counts))
]
avg_growth_rate = statistics.mean(growth_rates) if growth_rates else 0
# 预测下一个周期的实体数量
predicted_count = counts[-1] * (1 + avg_growth_rate)
return {
"predicted_count": round(predicted_count),
"current_count": counts[-1],
"growth_rate": round(avg_growth_rate, 4),
"confidence": min(0.9, 0.6 + len(entity_history) * 0.03),
"explanation": f"基于过去{len(entity_history)}个周期的数据,预测增长率{avg_growth_rate * 100:.1f}%",
}
def _predict_relation_evolution(self, input_data: dict, model: PredictionModel) -> dict:
"""关系演变预测"""
relation_history = input_data.get("relation_history", [])
if len(relation_history) < 2:
return {
"predicted_relations": [],
"confidence": 0.5,
"explanation": "历史数据不足,无法预测关系演变",
}
# 分析关系变化趋势
relation_counts = defaultdict(int)
for snapshot in relation_history:
for rel in snapshot.get("relations", []):
relation_counts[rel.get("type", "unknown")] += 1
# 预测可能出现的新关系类型
predicted_relations = [
{"type": rel_type, "likelihood": min(count / len(relation_history), 0.95)}
for rel_type, count in sorted(
relation_counts.items(),
key=lambda x: x[1],
reverse=True,
)[:5]
]
return {
"predicted_relations": predicted_relations,
"relation_trends": dict(relation_counts),
"confidence": min(0.85, 0.6 + len(relation_history) * 0.05),
"explanation": f"基于{len(relation_history)}个历史快照分析关系演变趋势",
}
def get_prediction_results(self, model_id: str, limit: int = 100) -> list[PredictionResult]:
"""获取预测结果历史"""
with self._get_db() as conn:
rows = conn.execute(
"""SELECT * FROM prediction_results
WHERE model_id = ?
ORDER BY created_at DESC
LIMIT ?""",
(model_id, limit),
).fetchall()
return [self._row_to_prediction_result(row) for row in rows]
    def update_prediction_feedback(
        self,
        prediction_id: str,
        actual_value: str,
        is_correct: bool,
    ) -> None:
        """Record ground-truth feedback on a prediction (for later model improvement).

        Args:
            prediction_id: The prediction to annotate.
            actual_value: The observed value.
            is_correct: Whether the prediction matched the observation.
        """
        with self._get_db() as conn:
            conn.execute(
                """UPDATE prediction_results
                SET actual_value = ?, is_correct = ?
                WHERE id = ?""",
                (actual_value, is_correct, prediction_id),
            )
            conn.commit()
# ==================== 辅助方法 ====================
    def _row_to_custom_model(self, row) -> CustomModel:
        """Convert a `custom_models` DB row into a CustomModel (JSON columns parsed)."""
        return CustomModel(
            id=row["id"],
            tenant_id=row["tenant_id"],
            name=row["name"],
            description=row["description"],
            model_type=ModelType(row["model_type"]),
            status=ModelStatus(row["status"]),
            training_data=json.loads(row["training_data"]),
            hyperparameters=json.loads(row["hyperparameters"]),
            metrics=json.loads(row["metrics"]),
            model_path=row["model_path"],
            created_at=row["created_at"],
            updated_at=row["updated_at"],
            trained_at=row["trained_at"],
            created_by=row["created_by"],
        )
    def _row_to_training_sample(self, row) -> TrainingSample:
        """Convert a `training_samples` DB row into a TrainingSample."""
        return TrainingSample(
            id=row["id"],
            model_id=row["model_id"],
            text=row["text"],
            entities=json.loads(row["entities"]),
            metadata=json.loads(row["metadata"]),
            created_at=row["created_at"],
        )
    def _row_to_multimodal_analysis(self, row) -> MultimodalAnalysis:
        """Convert a `multimodal_analyses` DB row into a MultimodalAnalysis."""
        return MultimodalAnalysis(
            id=row["id"],
            tenant_id=row["tenant_id"],
            project_id=row["project_id"],
            provider=MultimodalProvider(row["provider"]),
            input_type=row["input_type"],
            input_urls=json.loads(row["input_urls"]),
            prompt=row["prompt"],
            result=json.loads(row["result"]),
            tokens_used=row["tokens_used"],
            cost=row["cost"],
            created_at=row["created_at"],
        )
    def _row_to_kg_rag(self, row) -> KnowledgeGraphRAG:
        """Convert a `kg_rag_configs` DB row into a KnowledgeGraphRAG."""
        return KnowledgeGraphRAG(
            id=row["id"],
            tenant_id=row["tenant_id"],
            project_id=row["project_id"],
            name=row["name"],
            description=row["description"],
            kg_config=json.loads(row["kg_config"]),
            retrieval_config=json.loads(row["retrieval_config"]),
            generation_config=json.loads(row["generation_config"]),
            is_active=bool(row["is_active"]),  # SQLite stores booleans as 0/1
            created_at=row["created_at"],
            updated_at=row["updated_at"],
        )
    def _row_to_prediction_model(self, row) -> PredictionModel:
        """Convert a `prediction_models` DB row into a PredictionModel."""
        return PredictionModel(
            id=row["id"],
            tenant_id=row["tenant_id"],
            project_id=row["project_id"],
            name=row["name"],
            prediction_type=PredictionType(row["prediction_type"]),
            target_entity_type=row["target_entity_type"],
            features=json.loads(row["features"]),
            model_config=json.loads(row["model_config"]),
            accuracy=row["accuracy"],
            last_trained_at=row["last_trained_at"],
            prediction_count=row["prediction_count"],
            is_active=bool(row["is_active"]),  # SQLite stores booleans as 0/1
            created_at=row["created_at"],
            updated_at=row["updated_at"],
        )
def _row_to_prediction_result(self, row) -> PredictionResult:
"""将数据库行转换为 PredictionResult"""
return PredictionResult(
id=row["id"],
model_id=row["model_id"],
prediction_type=PredictionType(row["prediction_type"]),
target_id=row["target_id"],
prediction_data=json.loads(row["prediction_data"]),
confidence=row["confidence"],
explanation=row["explanation"],
actual_value=row["actual_value"],
is_correct=row["is_correct"],
created_at=row["created_at"],
)
# Module-level singleton cache for get_ai_manager().
_ai_manager: AIManager | None = None


def get_ai_manager() -> AIManager:
    """Return the shared AIManager, creating it lazily on first call."""
    global _ai_manager
    if _ai_manager is None:
        _ai_manager = AIManager()
    return _ai_manager