1517 lines
50 KiB
Python
1517 lines
50 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
InsightFlow AI Manager - Phase 8 Task 4
|
||
AI 能力增强模块
|
||
- 自定义模型训练(领域特定实体识别)
|
||
- 多模态大模型集成(GPT-4V、Claude 3)
|
||
- 智能摘要与问答(基于知识图谱的 RAG)
|
||
- 预测性分析(趋势预测、异常检测)
|
||
"""
|
||
|
||
import asyncio
|
||
import json
|
||
import os
|
||
import random
|
||
import re
|
||
import sqlite3
|
||
import statistics
|
||
import time
|
||
import uuid
|
||
from collections import defaultdict
|
||
from dataclasses import dataclass
|
||
from datetime import datetime
|
||
from enum import StrEnum
|
||
|
||
import httpx
|
||
|
||
# Database path
|
||
DB_PATH = os.path.join(os.path.dirname(__file__), "insightflow.db")
|
||
|
||
|
||
class ModelType(StrEnum):
    """Kinds of AI model managed by this module."""

    CUSTOM_NER = "custom_ner"  # domain-specific named-entity recognition
    MULTIMODAL = "multimodal"  # multimodal (image/video/audio) analysis
    SUMMARIZATION = "summarization"  # text summarization
    PREDICTION = "prediction"  # predictive analytics
|
||
|
||
|
||
class ModelStatus(StrEnum):
    """Lifecycle states of a custom model."""

    PENDING = "pending"
    TRAINING = "training"
    READY = "ready"
    FAILED = "failed"
    ARCHIVED = "archived"
|
||
|
||
|
||
class MultimodalProvider(StrEnum):
    """Supported multimodal model providers."""

    GPT4V = "gpt-4-vision"
    CLAUDE3 = "claude-3"
    GEMINI = "gemini-pro-vision"
    KIMI_VL = "kimi-vl"
|
||
|
||
|
||
class PredictionType(StrEnum):
    """Kinds of predictive analysis."""

    TREND = "trend"  # trend forecasting
    ANOMALY = "anomaly"  # anomaly detection
    ENTITY_GROWTH = "entity_growth"  # entity-growth forecasting
    RELATION_EVOLUTION = "relation_evolution"  # relation-evolution forecasting
|
||
|
||
|
||
@dataclass
class CustomModel:
    """A tenant-owned custom model and its training state."""

    id: str
    tenant_id: str
    name: str
    description: str
    model_type: ModelType
    status: ModelStatus
    training_data: dict  # training-data configuration
    hyperparameters: dict  # training hyperparameters
    metrics: dict  # metrics recorded after training
    model_path: str | None  # path to the serialized model file (None until trained)
    created_at: str
    updated_at: str
    trained_at: str | None
    created_by: str
|
||
|
||
|
||
@dataclass
class TrainingSample:
    """A single annotated text sample used to train a custom model."""

    id: str
    model_id: str
    text: str
    entities: list[dict]  # span annotations, e.g. [{"start": 0, "end": 5, "label": "PERSON", "text": "张三"}]
    metadata: dict
    created_at: str
|
||
|
||
|
||
@dataclass
class MultimodalAnalysis:
    """Result of one multimodal analysis call."""

    id: str
    tenant_id: str
    project_id: str
    provider: MultimodalProvider
    input_type: str  # image, video, audio, mixed
    input_urls: list[str]
    prompt: str
    result: dict  # analysis result returned by the provider
    tokens_used: int
    cost: float
    created_at: str
|
||
|
||
|
||
@dataclass
class KnowledgeGraphRAG:
    """Configuration for retrieval-augmented generation over a knowledge graph."""

    id: str
    tenant_id: str
    project_id: str
    name: str
    description: str
    kg_config: dict  # knowledge-graph configuration
    retrieval_config: dict  # retrieval settings (e.g. top_k)
    generation_config: dict  # generation settings (e.g. temperature, max_tokens)
    is_active: bool
    created_at: str
    updated_at: str
|
||
|
||
|
||
@dataclass
class RAGQuery:
    """One question/answer round against a knowledge-graph RAG config."""

    id: str
    rag_id: str
    query: str
    context: dict  # retrieved context (entities and relations)
    answer: str
    sources: list[dict]  # provenance of the retrieved entities
    confidence: float
    tokens_used: int
    latency_ms: int
    created_at: str
|
||
|
||
|
||
@dataclass
class PredictionModel:
    """A configured predictive-analytics model."""

    id: str
    tenant_id: str
    project_id: str
    name: str
    prediction_type: PredictionType
    target_entity_type: str | None  # entity type the model targets, if any
    features: list[str]  # feature names used by the model
    model_config: dict  # model configuration
    accuracy: float | None  # None until the model has been trained
    last_trained_at: str | None
    prediction_count: int
    is_active: bool
    created_at: str
    updated_at: str
|
||
|
||
|
||
@dataclass
class PredictionResult:
    """Outcome of a single prediction run."""

    id: str
    model_id: str
    prediction_type: PredictionType
    target_id: str | None  # id of the prediction target
    prediction_data: dict  # raw prediction output
    confidence: float
    explanation: str  # human-readable rationale for the prediction
    actual_value: str | None  # observed value, recorded later for validation
    is_correct: bool | None  # None until validated against actual_value
    created_at: str
|
||
|
||
|
||
@dataclass
class SmartSummary:
    """An LLM-generated summary of project content."""

    id: str
    tenant_id: str
    project_id: str
    source_type: str  # transcript, entity, project
    source_id: str
    summary_type: str  # extractive, abstractive, key_points, timeline
    content: str
    key_points: list[str]
    entities_mentioned: list[str]
    confidence: float
    tokens_used: int
    created_at: str
|
||
|
||
|
||
class AIManager:
|
||
"""AI 能力管理主类"""
|
||
|
||
def __init__(self, db_path: str = DB_PATH) -> None:
    """Set up the manager with a database path and provider credentials.

    Credentials are read from the environment; an empty string means the
    corresponding provider is not configured.
    """
    self.db_path = db_path
    env = os.environ
    self.kimi_api_key = env.get("KIMI_API_KEY", "")
    self.kimi_base_url = env.get("KIMI_BASE_URL", "https://api.kimi.com/coding")
    self.openai_api_key = env.get("OPENAI_API_KEY", "")
    self.anthropic_api_key = env.get("ANTHROPIC_API_KEY", "")
|
||
|
||
def _get_db(self) -> sqlite3.Connection:
    """Open a new SQLite connection with name-based row access.

    NOTE(review): every call opens a fresh connection, and callers use
    ``with self._get_db() as conn`` — sqlite3's connection context manager
    only scopes a transaction and does NOT close the connection, so these
    connections are left to the garbage collector. Consider
    ``contextlib.closing``; confirm all call sites before changing.
    """
    conn = sqlite3.connect(self.db_path)
    conn.row_factory = sqlite3.Row  # allow row["column"] access
    return conn
|
||
|
||
# ==================== 自定义模型训练 ====================
|
||
|
||
def create_custom_model(
    self,
    tenant_id: str,
    name: str,
    description: str,
    model_type: ModelType,
    training_data: dict,
    hyperparameters: dict,
    created_by: str,
) -> CustomModel:
    """Create a new custom model record in the PENDING state and persist it."""
    timestamp = datetime.now().isoformat()
    record = CustomModel(
        id=f"cm_{uuid.uuid4().hex[:16]}",
        tenant_id=tenant_id,
        name=name,
        description=description,
        model_type=model_type,
        status=ModelStatus.PENDING,
        training_data=training_data,
        hyperparameters=hyperparameters,
        metrics={},
        model_path=None,
        created_at=timestamp,
        updated_at=timestamp,
        trained_at=None,
        created_by=created_by,
    )

    # JSON-encode the dict fields for storage.
    row = (
        record.id,
        record.tenant_id,
        record.name,
        record.description,
        record.model_type.value,
        record.status.value,
        json.dumps(record.training_data),
        json.dumps(record.hyperparameters),
        json.dumps(record.metrics),
        record.model_path,
        record.created_at,
        record.updated_at,
        record.trained_at,
        record.created_by,
    )
    with self._get_db() as conn:
        conn.execute(
            """
            INSERT INTO custom_models
            (id, tenant_id, name, description, model_type, status, training_data,
             hyperparameters, metrics, model_path, created_at, updated_at,
             trained_at, created_by)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """,
            row,
        )
        conn.commit()

    return record
|
||
|
||
def get_custom_model(self, model_id: str) -> CustomModel | None:
    """Fetch a custom model by id, or None when it does not exist."""
    with self._get_db() as conn:
        row = conn.execute(
            "SELECT * FROM custom_models WHERE id = ?", (model_id,)
        ).fetchone()
    return self._row_to_custom_model(row) if row else None
|
||
|
||
def list_custom_models(
    self, tenant_id: str, model_type: ModelType | None = None, status: ModelStatus | None = None,
) -> list[CustomModel]:
    """List a tenant's custom models, newest first, optionally filtered."""
    clauses = ["tenant_id = ?"]
    args: list = [tenant_id]

    if model_type:
        clauses.append("model_type = ?")
        args.append(model_type.value)
    if status:
        clauses.append("status = ?")
        args.append(status.value)

    sql = (
        "SELECT * FROM custom_models WHERE "
        + " AND ".join(clauses)
        + " ORDER BY created_at DESC"
    )
    with self._get_db() as conn:
        return [
            self._row_to_custom_model(r) for r in conn.execute(sql, args).fetchall()
        ]
|
||
|
||
def add_training_sample(
    self, model_id: str, text: str, entities: list[dict], metadata: dict | None = None,
) -> TrainingSample:
    """Add an annotated training sample for a custom model.

    Args:
        model_id: id of the model the sample belongs to.
        text: raw sample text.
        entities: span annotations, e.g.
            [{"start": 0, "end": 5, "label": "PERSON", "text": "张三"}].
        metadata: optional extra information; defaults to an empty dict.
            (Annotation fixed: the parameter was typed ``dict`` while its
            default is ``None``.)

    Returns:
        The persisted TrainingSample.
    """
    sample = TrainingSample(
        id=f"ts_{uuid.uuid4().hex[:16]}",
        model_id=model_id,
        text=text,
        entities=entities,
        metadata=metadata or {},
        created_at=datetime.now().isoformat(),
    )

    with self._get_db() as conn:
        conn.execute(
            """
            INSERT INTO training_samples
            (id, model_id, text, entities, metadata, created_at)
            VALUES (?, ?, ?, ?, ?, ?)
            """,
            (
                sample.id,
                sample.model_id,
                sample.text,
                json.dumps(sample.entities),
                json.dumps(sample.metadata),
                sample.created_at,
            ),
        )
        conn.commit()

    return sample
|
||
|
||
def get_training_samples(self, model_id: str) -> list[TrainingSample]:
    """Return all training samples for a model, oldest first."""
    sql = "SELECT * FROM training_samples WHERE model_id = ? ORDER BY created_at"
    with self._get_db() as conn:
        return [
            self._row_to_training_sample(r)
            for r in conn.execute(sql, (model_id,)).fetchall()
        ]
|
||
|
||
async def train_custom_model(self, model_id: str) -> CustomModel:
    """Train a custom model (simulated) and return the refreshed record.

    The model is marked TRAINING, "trained", then marked READY with its
    metrics and model path stored. On any failure the model is marked
    FAILED and the original exception is re-raised.

    Raises:
        ValueError: if the model does not exist or has fewer than 10 samples.
    """
    model = self.get_custom_model(model_id)
    if not model:
        raise ValueError(f"Model {model_id} not found")

    # Mark as training so readers observe the state change.
    with self._get_db() as conn:
        conn.execute(
            "UPDATE custom_models SET status = ?, updated_at = ? WHERE id = ?",
            (ModelStatus.TRAINING.value, datetime.now().isoformat(), model_id),
        )
        conn.commit()

    try:
        samples = self.get_training_samples(model_id)
        if len(samples) < 10:
            raise ValueError("至少需要 10 个训练样本")

        # Simulated training; a real implementation would invoke a training
        # framework (spaCy, Hugging Face, ...) here.
        await asyncio.sleep(2)

        # Synthetic training metrics.
        metrics = {
            "samples_count": len(samples),
            "epochs": model.hyperparameters.get("epochs", 10),
            "learning_rate": model.hyperparameters.get("learning_rate", 0.001),
            "precision": round(0.85 + random.random() * 0.1, 4),
            "recall": round(0.82 + random.random() * 0.1, 4),
            "f1_score": round(0.84 + random.random() * 0.1, 4),
            "training_time_seconds": 120,
        }

        # Record where the artifact would live; only the directory is
        # created — no model file is actually written in this simulation.
        model_path = f"models/{model_id}.bin"
        os.makedirs("models", exist_ok=True)

        now = datetime.now().isoformat()

        with self._get_db() as conn:
            conn.execute(
                """
                UPDATE custom_models
                SET status = ?, metrics = ?, model_path = ?, trained_at = ?, updated_at = ?
                WHERE id = ?
                """,
                (ModelStatus.READY.value, json.dumps(metrics), model_path, now, now, model_id),
            )
            conn.commit()

        return self.get_custom_model(model_id)

    except Exception:
        with self._get_db() as conn:
            conn.execute(
                "UPDATE custom_models SET status = ?, updated_at = ? WHERE id = ?",
                (ModelStatus.FAILED.value, datetime.now().isoformat(), model_id),
            )
            conn.commit()
        # Bare `raise` preserves the original traceback (was `raise e`).
        raise
|
||
|
||
async def predict_with_custom_model(self, model_id: str, text: str) -> list[dict]:
    """Run entity extraction with a trained custom model.

    Inference is simulated by prompting an LLM to perform domain-specific
    entity recognition restricted to the model's configured entity types.

    Raises:
        ValueError: if the model is missing or not in the READY state.
    """
    model = self.get_custom_model(model_id)
    if not model or model.status != ModelStatus.READY:
        raise ValueError(f"Model {model_id} not ready")

    entity_types = model.training_data.get("entity_types", ["PERSON", "ORG", "TECH", "PROJECT"])

    prompt = f"""从以下文本中提取实体,类型限定为: {", ".join(entity_types)}

文本: {text}

以 JSON 格式返回实体列表: [{{"text": "实体文本", "label": "类型", "start": 0, "end": 5, "confidence": 0.95}}]
只返回 JSON 数组,不要其他内容。"""

    request_headers = {
        "Authorization": f"Bearer {self.kimi_api_key}",
        "Content-Type": "application/json",
    }
    request_body = {
        "model": "k2p5",
        "messages": [{"role": "user", "content": prompt}],
        "temperature": 0.1,
    }

    async with httpx.AsyncClient() as client:
        response = await client.post(
            f"{self.kimi_base_url}/v1/chat/completions",
            headers=request_headers,
            json=request_body,
            timeout=60.0,
        )
        response.raise_for_status()
        reply = response.json()["choices"][0]["message"]["content"]

    # The model is asked for a bare JSON array; pull out the first
    # bracketed span and parse it, returning [] when parsing fails.
    match = re.search(r"\[.*?\]", reply, re.DOTALL)
    if match:
        try:
            return json.loads(match.group())
        except (json.JSONDecodeError, ValueError):
            pass

    return []
|
||
|
||
# ==================== 多模态大模型集成 ====================
|
||
|
||
async def analyze_multimodal(
    self,
    tenant_id: str,
    project_id: str,
    provider: MultimodalProvider,
    input_type: str,
    input_urls: list[str],
    prompt: str,
) -> MultimodalAnalysis:
    """Run a multimodal analysis with the requested provider and persist it."""
    analysis_id = f"ma_{uuid.uuid4().hex[:16]}"
    created = datetime.now().isoformat()

    # Dispatch to whichever provider is both requested and configured;
    # anything else falls back to Kimi.
    if provider == MultimodalProvider.GPT4V and self.openai_api_key:
        outcome = await self._call_gpt4v(input_urls, prompt)
    elif provider == MultimodalProvider.CLAUDE3 and self.anthropic_api_key:
        outcome = await self._call_claude3(input_urls, prompt)
    else:
        outcome = await self._call_kimi_multimodal(input_urls, prompt)

    analysis = MultimodalAnalysis(
        id=analysis_id,
        tenant_id=tenant_id,
        project_id=project_id,
        provider=provider,
        input_type=input_type,
        input_urls=input_urls,
        prompt=prompt,
        result=outcome,
        tokens_used=outcome.get("tokens_used", 0),
        cost=outcome.get("cost", 0.0),
        created_at=created,
    )

    with self._get_db() as conn:
        conn.execute(
            """
            INSERT INTO multimodal_analyses
            (id, tenant_id, project_id, provider, input_type, input_urls, prompt,
             result, tokens_used, cost, created_at)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """,
            (
                analysis.id,
                analysis.tenant_id,
                analysis.project_id,
                analysis.provider.value,
                analysis.input_type,
                json.dumps(analysis.input_urls),
                analysis.prompt,
                json.dumps(analysis.result),
                analysis.tokens_used,
                analysis.cost,
                analysis.created_at,
            ),
        )
        conn.commit()

    return analysis
|
||
|
||
async def _call_gpt4v(self, image_urls: list[str], prompt: str) -> dict:
    """Call the OpenAI GPT-4V chat endpoint with a text part plus image parts."""
    parts = [{"type": "text", "text": prompt}]
    parts += [{"type": "image_url", "image_url": {"url": u}} for u in image_urls]

    async with httpx.AsyncClient() as client:
        response = await client.post(
            "https://api.openai.com/v1/chat/completions",
            headers={
                "Authorization": f"Bearer {self.openai_api_key}",
                "Content-Type": "application/json",
            },
            json={
                "model": "gpt-4-vision-preview",
                "messages": [{"role": "user", "content": parts}],
                "max_tokens": 2000,
            },
            timeout=120.0,
        )
        response.raise_for_status()
        data = response.json()

    total_tokens = data["usage"]["total_tokens"]
    return {
        "content": data["choices"][0]["message"]["content"],
        "tokens_used": total_tokens,
        "cost": total_tokens * 0.00001,  # rough cost estimate
    }
|
||
|
||
async def _call_claude3(self, image_urls: list[str], prompt: str) -> dict:
    """Call the Anthropic Messages API with image blocks followed by the prompt."""
    blocks = [
        {"type": "image", "source": {"type": "url", "url": u}} for u in image_urls
    ]
    blocks.append({"type": "text", "text": prompt})

    async with httpx.AsyncClient() as client:
        response = await client.post(
            "https://api.anthropic.com/v1/messages",
            headers={
                "x-api-key": self.anthropic_api_key,
                "Content-Type": "application/json",
                "anthropic-version": "2023-06-01",
            },
            json={
                "model": "claude-3-opus-20240229",
                "max_tokens": 2000,
                "messages": [{"role": "user", "content": blocks}],
            },
            timeout=120.0,
        )
        response.raise_for_status()
        data = response.json()

    usage = data["usage"]
    total_tokens = usage["input_tokens"] + usage["output_tokens"]
    return {
        "content": data["content"][0]["text"],
        "tokens_used": total_tokens,
        "cost": total_tokens * 0.000015,  # rough cost estimate
    }
|
||
|
||
async def _call_kimi_multimodal(self, image_urls: list[str], prompt: str) -> dict:
    """Fallback text-only Kimi call that folds image URLs into the prompt.

    Kimi may not support true multimodal input yet, so the URLs are
    described in plain text; update this once the Kimi API does.
    """
    content = f"图片 URL: {', '.join(image_urls)}\n\n{prompt}\n\n注意:请基于图片 URL 描述的内容进行回答。"

    async with httpx.AsyncClient() as client:
        response = await client.post(
            f"{self.kimi_base_url}/v1/chat/completions",
            headers={
                "Authorization": f"Bearer {self.kimi_api_key}",
                "Content-Type": "application/json",
            },
            json={
                "model": "k2p5",
                "messages": [{"role": "user", "content": content}],
                "temperature": 0.3,
            },
            timeout=60.0,
        )
        response.raise_for_status()
        data = response.json()

    total_tokens = data["usage"]["total_tokens"]
    return {
        "content": data["choices"][0]["message"]["content"],
        "tokens_used": total_tokens,
        "cost": total_tokens * 0.000005,  # rough cost estimate
    }
|
||
|
||
def get_multimodal_analyses(
    self, tenant_id: str, project_id: str | None = None,
) -> list[MultimodalAnalysis]:
    """Return a tenant's multimodal analysis history, newest first."""
    sql = "SELECT * FROM multimodal_analyses WHERE tenant_id = ?"
    args: list = [tenant_id]

    if project_id:
        sql += " AND project_id = ?"
        args.append(project_id)
    sql += " ORDER BY created_at DESC"

    with self._get_db() as conn:
        return [
            self._row_to_multimodal_analysis(r)
            for r in conn.execute(sql, args).fetchall()
        ]
|
||
|
||
# ==================== 智能摘要与问答(基于知识图谱的 RAG) ====================
|
||
|
||
def create_kg_rag(
    self,
    tenant_id: str,
    project_id: str,
    name: str,
    description: str,
    kg_config: dict,
    retrieval_config: dict,
    generation_config: dict,
) -> KnowledgeGraphRAG:
    """Create and persist an active knowledge-graph RAG configuration."""
    timestamp = datetime.now().isoformat()
    config = KnowledgeGraphRAG(
        id=f"kgr_{uuid.uuid4().hex[:16]}",
        tenant_id=tenant_id,
        project_id=project_id,
        name=name,
        description=description,
        kg_config=kg_config,
        retrieval_config=retrieval_config,
        generation_config=generation_config,
        is_active=True,
        created_at=timestamp,
        updated_at=timestamp,
    )

    with self._get_db() as conn:
        conn.execute(
            """
            INSERT INTO kg_rag_configs
            (id, tenant_id, project_id, name, description, kg_config, retrieval_config,
             generation_config, is_active, created_at, updated_at)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """,
            (
                config.id,
                config.tenant_id,
                config.project_id,
                config.name,
                config.description,
                json.dumps(config.kg_config),
                json.dumps(config.retrieval_config),
                json.dumps(config.generation_config),
                config.is_active,
                config.created_at,
                config.updated_at,
            ),
        )
        conn.commit()

    return config
|
||
|
||
def get_kg_rag(self, rag_id: str) -> KnowledgeGraphRAG | None:
    """Fetch a knowledge-graph RAG configuration by id, or None."""
    with self._get_db() as conn:
        row = conn.execute(
            "SELECT * FROM kg_rag_configs WHERE id = ?", (rag_id,)
        ).fetchone()
    return self._row_to_kg_rag(row) if row else None
|
||
|
||
def list_kg_rags(
    self, tenant_id: str, project_id: str | None = None,
) -> list[KnowledgeGraphRAG]:
    """List a tenant's RAG configurations, newest first, optionally per project."""
    sql = "SELECT * FROM kg_rag_configs WHERE tenant_id = ?"
    args: list = [tenant_id]

    if project_id:
        sql += " AND project_id = ?"
        args.append(project_id)
    sql += " ORDER BY created_at DESC"

    with self._get_db() as conn:
        return [self._row_to_kg_rag(r) for r in conn.execute(sql, args).fetchall()]
|
||
|
||
async def query_kg_rag(
    self, rag_id: str, query: str, project_entities: list[dict], project_relations: list[dict],
) -> RAGQuery:
    """Answer a question over the project knowledge graph (retrieve-then-generate).

    Pipeline: (1) lexically retrieve relevant entities/relations from the
    supplied graph, (2) render them into a text context, (3) ask the Kimi
    LLM to answer using that context, (4) persist and return the query
    record (including latency and token usage).

    Args:
        rag_id: id of the RAG configuration to use.
        query: the user's natural-language question.
        project_entities: entity dicts; "name"/"definition" are matched,
            "id" is required for relation lookup and sources.
        project_relations: relation dicts keyed by source/target entity id.

    Raises:
        ValueError: if the RAG configuration does not exist.
    """
    start_time = time.time()

    rag = self.get_kg_rag(rag_id)
    if not rag:
        raise ValueError(f"RAG config {rag_id} not found")

    # 1. Retrieve relevant entities and relations.
    retrieval_config = rag.retrieval_config
    top_k = retrieval_config.get("top_k", 5)

    # Simple lexical retrieval: score 0.5 for a name match, +0.3 for a
    # definition match against any query word.
    query_lower = query.lower()
    relevant_entities = []
    for entity in project_entities:
        score = 0
        name = entity.get("name", "").lower()
        definition = entity.get("definition", "").lower()

        if name in query_lower or any(word in name for word in query_lower.split()):
            score += 0.5
        if any(word in definition for word in query_lower.split()):
            score += 0.3

        if score > 0:
            relevant_entities.append({**entity, "relevance_score": score})

    relevant_entities.sort(key=lambda x: x["relevance_score"], reverse=True)
    relevant_entities = relevant_entities[:top_k]

    # Keep every relation that touches a retrieved entity on either end.
    relevant_relations = []
    entity_ids = {e["id"] for e in relevant_entities}
    for relation in project_relations:
        if (
            relation.get("source_entity_id") in entity_ids
            or relation.get("target_entity_id") in entity_ids
        ):
            relevant_relations.append(relation)

    # 2. Build the context. Only the first 10 relations are stored in the
    # record, but ALL relevant relations go into the prompt context below.
    context = {"entities": relevant_entities, "relations": relevant_relations[:10]}

    context_text = self._build_kg_context(relevant_entities, relevant_relations)

    # 3. Generate the answer.
    generation_config = rag.generation_config
    temperature = generation_config.get("temperature", 0.3)
    max_tokens = generation_config.get("max_tokens", 1000)

    prompt = f"""基于以下知识图谱信息回答问题:

## 知识图谱上下文
{context_text}

## 用户问题
{query}

请基于上述知识图谱信息回答问题。如果信息不足,请明确说明。
回答应该:
1. 准确引用知识图谱中的实体和关系
2. 如果涉及多个实体,说明它们之间的关联
3. 保持简洁专业"""

    headers = {
        "Authorization": f"Bearer {self.kimi_api_key}",
        "Content-Type": "application/json",
    }

    payload = {
        "model": "k2p5",
        "messages": [{"role": "user", "content": prompt}],
        "temperature": temperature,
        "max_tokens": max_tokens,
    }

    async with httpx.AsyncClient() as client:
        response = await client.post(
            f"{self.kimi_base_url}/v1/chat/completions",
            headers=headers,
            json=payload,
            timeout=60.0,
        )
        response.raise_for_status()
        result = response.json()

    answer = result["choices"][0]["message"]["content"]
    tokens_used = result["usage"]["total_tokens"]

    latency_ms = int((time.time() - start_time) * 1000)

    # 4. Persist the query record.
    query_id = f"rq_{uuid.uuid4().hex[:16]}"
    now = datetime.now().isoformat()

    sources = [
        {"entity_id": e["id"], "entity_name": e["name"], "score": e["relevance_score"]}
        for e in relevant_entities
    ]

    rag_query = RAGQuery(
        id=query_id,
        rag_id=rag_id,
        query=query,
        context=context,
        answer=answer,
        sources=sources,
        # Mean relevance of the retrieved entities; 0 when nothing matched.
        confidence=(
            sum(e["relevance_score"] for e in relevant_entities) / len(relevant_entities)
            if relevant_entities
            else 0
        ),
        tokens_used=tokens_used,
        latency_ms=latency_ms,
        created_at=now,
    )

    with self._get_db() as conn:
        conn.execute(
            """
            INSERT INTO rag_queries
            (id, rag_id, query, context, answer, sources, confidence,
             tokens_used, latency_ms, created_at)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """,
            (
                rag_query.id,
                rag_query.rag_id,
                rag_query.query,
                json.dumps(rag_query.context),
                rag_query.answer,
                json.dumps(rag_query.sources),
                rag_query.confidence,
                rag_query.tokens_used,
                rag_query.latency_ms,
                rag_query.created_at,
            ),
        )
        conn.commit()

    return rag_query
|
||
|
||
def _build_kg_context(self, entities: list[dict], relations: list[dict]) -> str:
    """Render retrieved entities and relations as a markdown context block."""
    lines: list[str] = []

    if entities:
        lines.append("### 相关实体")
        for entity in entities:
            ename = entity.get("name", "")
            etype = entity.get("type", "")
            edef = entity.get("definition", "")
            lines.append(f"- **{ename}** ({etype}): {edef}")

    if relations:
        lines.append("\n### 相关关系")
        for rel in relations:
            src = rel.get("source_name", "")
            dst = rel.get("target_name", "")
            kind = rel.get("relation_type", "")
            lines.append(f"- {src} --[{kind}]--> {dst}")
            evidence = rel.get("evidence", "")
            if evidence:
                # Evidence is truncated to keep the prompt compact.
                lines.append(f"  - 依据: {evidence[:100]}...")

    return "\n".join(lines)
|
||
|
||
async def generate_smart_summary(
    self,
    tenant_id: str,
    project_id: str,
    source_type: str,
    source_id: str,
    summary_type: str,
    content_data: dict,
) -> SmartSummary:
    """Generate a summary of ``content_data`` via the Kimi LLM and persist it.

    Args:
        source_type: origin of the content (transcript, entity, project).
        summary_type: "extractive", "abstractive", "key_points"; anything
            else falls through to a timeline-style summary.
        content_data: expects a "text" key (only the first 5000 characters
            are sent to the model) and optionally an "entities" list of
            dicts with a "name" key.

    Returns:
        The persisted SmartSummary (at most 8 key points, at most 10
        mentioned entity names).
    """
    summary_id = f"ss_{uuid.uuid4().hex[:16]}"
    now = datetime.now().isoformat()

    # Pick a prompt template for the requested summary style.
    if summary_type == "extractive":
        prompt = f"""从以下内容中提取关键句子作为摘要:

{content_data.get("text", "")[:5000]}

要求:
1. 提取 3-5 个最重要的句子
2. 保持原文表述
3. 以 JSON 格式返回: {{"summary": "摘要内容", "key_points": ["要点1", "要点2"]}}"""

    elif summary_type == "abstractive":
        prompt = f"""对以下内容生成简洁的摘要:

{content_data.get("text", "")[:5000]}

要求:
1. 用 2-3 句话概括核心内容
2. 使用自己的语言重新表述
3. 包含关键实体和概念"""

    elif summary_type == "key_points":
        prompt = f"""从以下内容中提取关键要点:

{content_data.get("text", "")[:5000]}

要求:
1. 列出 5-8 个关键要点
2. 每个要点简洁明了
3. 以 JSON 格式返回: {{"key_points": ["要点1", "要点2", ...]}}"""

    else:  # timeline (and any unrecognized summary_type)
        prompt = f"""基于以下内容生成时间线摘要:

{content_data.get("text", "")[:5000]}

要求:
1. 按时间顺序组织关键事件
2. 标注时间节点(如果有)
3. 突出里程碑事件"""

    headers = {
        "Authorization": f"Bearer {self.kimi_api_key}",
        "Content-Type": "application/json",
    }

    payload = {
        "model": "k2p5",
        "messages": [{"role": "user", "content": prompt}],
        "temperature": 0.3,
    }

    async with httpx.AsyncClient() as client:
        response = await client.post(
            f"{self.kimi_base_url}/v1/chat/completions",
            headers=headers,
            json=payload,
            timeout=60.0,
        )
        response.raise_for_status()
        result = response.json()

    content = result["choices"][0]["message"]["content"]
    tokens_used = result["usage"]["total_tokens"]

    # Try to pull structured data out of a JSON object in the reply.
    key_points: list[str] = []
    json_match = re.search(r"\{.*?\}", content, re.DOTALL)
    if json_match:
        try:
            data = json.loads(json_match.group())
            key_points = data.get("key_points", [])
            if "summary" in data:
                content = data["summary"]
        except (json.JSONDecodeError, ValueError):
            pass

    # Fall back to scraping bullet lines from the free-form reply.
    if not key_points:
        key_points = [
            # BUGFIX: also strip "•" — the original strip("- ") collected
            # "•"-bulleted lines but left the bullet character in place.
            line.strip("-• ").strip()
            for line in content.split("\n")
            if line.strip().startswith(("-", "•"))
        ]
    if not key_points:
        key_points = [content[:200] + "..."] if len(content) > 200 else [content]

    # Record up to 10 entity names mentioned in the source data.
    entities_mentioned = content_data.get("entities", [])
    entity_names = [e.get("name", "") for e in entities_mentioned[:10]]

    summary = SmartSummary(
        id=summary_id,
        tenant_id=tenant_id,
        project_id=project_id,
        source_type=source_type,
        source_id=source_id,
        summary_type=summary_type,
        content=content,
        key_points=key_points[:8],
        entities_mentioned=entity_names,
        confidence=0.85,
        tokens_used=tokens_used,
        created_at=now,
    )

    with self._get_db() as conn:
        conn.execute(
            """
            INSERT INTO smart_summaries
            (id, tenant_id, project_id, source_type, source_id, summary_type, content,
             key_points, entities_mentioned, confidence, tokens_used, created_at)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """,
            (
                summary.id,
                summary.tenant_id,
                summary.project_id,
                summary.source_type,
                summary.source_id,
                summary.summary_type,
                summary.content,
                json.dumps(summary.key_points),
                json.dumps(summary.entities_mentioned),
                summary.confidence,
                summary.tokens_used,
                summary.created_at,
            ),
        )
        conn.commit()

    return summary
|
||
|
||
# ==================== 预测性分析 ====================
|
||
|
||
def create_prediction_model(
    self,
    tenant_id: str,
    project_id: str,
    name: str,
    prediction_type: PredictionType,
    target_entity_type: str | None,
    features: list[str],
    model_config: dict,
) -> PredictionModel:
    """Create and persist an active, untrained prediction model."""
    timestamp = datetime.now().isoformat()
    record = PredictionModel(
        id=f"pm_{uuid.uuid4().hex[:16]}",
        tenant_id=tenant_id,
        project_id=project_id,
        name=name,
        prediction_type=prediction_type,
        target_entity_type=target_entity_type,
        features=features,
        model_config=model_config,
        accuracy=None,
        last_trained_at=None,
        prediction_count=0,
        is_active=True,
        created_at=timestamp,
        updated_at=timestamp,
    )

    with self._get_db() as conn:
        conn.execute(
            """
            INSERT INTO prediction_models
            (id, tenant_id, project_id, name, prediction_type, target_entity_type,
             features, model_config, accuracy, last_trained_at, prediction_count,
             is_active, created_at, updated_at)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """,
            (
                record.id,
                record.tenant_id,
                record.project_id,
                record.name,
                record.prediction_type.value,
                record.target_entity_type,
                json.dumps(record.features),
                json.dumps(record.model_config),
                record.accuracy,
                record.last_trained_at,
                record.prediction_count,
                record.is_active,
                record.created_at,
                record.updated_at,
            ),
        )
        conn.commit()

    return record
|
||
|
||
def get_prediction_model(self, model_id: str) -> PredictionModel | None:
    """Fetch a prediction model by id, or None when it does not exist."""
    with self._get_db() as conn:
        row = conn.execute(
            "SELECT * FROM prediction_models WHERE id = ?", (model_id,),
        ).fetchone()
    return self._row_to_prediction_model(row) if row else None
|
||
|
||
def list_prediction_models(
    self, tenant_id: str, project_id: str | None = None,
) -> list[PredictionModel]:
    """List a tenant's prediction models, newest first, optionally per project."""
    sql = "SELECT * FROM prediction_models WHERE tenant_id = ?"
    args: list = [tenant_id]

    if project_id:
        sql += " AND project_id = ?"
        args.append(project_id)
    sql += " ORDER BY created_at DESC"

    with self._get_db() as conn:
        return [
            self._row_to_prediction_model(r)
            for r in conn.execute(sql, args).fetchall()
        ]
|
||
|
||
async def train_prediction_model(
    self, model_id: str, historical_data: list[dict],
) -> PredictionModel:
    """Train a prediction model (simulated) and record its accuracy.

    ``historical_data`` is accepted for interface completeness but is not
    used by the simulated training step.

    Raises:
        ValueError: if the model does not exist.
    """
    if self.get_prediction_model(model_id) is None:
        raise ValueError(f"Prediction model {model_id} not found")

    # Simulated training delay followed by a synthetic accuracy score.
    await asyncio.sleep(1)
    accuracy = round(0.75 + random.random() * 0.2, 4)

    timestamp = datetime.now().isoformat()

    with self._get_db() as conn:
        conn.execute(
            """
            UPDATE prediction_models
            SET accuracy = ?, last_trained_at = ?, updated_at = ?
            WHERE id = ?
            """,
            (accuracy, timestamp, timestamp, model_id),
        )
        conn.commit()

    return self.get_prediction_model(model_id)
|
||
|
||
async def predict(self, model_id: str, input_data: dict) -> PredictionResult:
    """Run one prediction with an active model and persist the outcome.

    Dispatches to the type-specific predictor, stores the result row,
    increments the model's prediction counter, and returns the new
    PredictionResult. Raises ValueError when the model is missing or
    inactive.
    """
    model = self.get_prediction_model(model_id)
    if model is None or not model.is_active:
        raise ValueError(f"Prediction model {model_id} not available")

    prediction_id = f"pr_{uuid.uuid4().hex[:16]}"
    timestamp = datetime.now().isoformat()

    # Dispatch table: prediction type -> handler(input_data, model).
    handlers = {
        PredictionType.TREND: self._predict_trend,
        PredictionType.ANOMALY: self._detect_anomaly,
        PredictionType.ENTITY_GROWTH: self._predict_entity_growth,
        PredictionType.RELATION_EVOLUTION: self._predict_relation_evolution,
    }
    handler = handlers.get(model.prediction_type)
    if handler is not None:
        prediction_data = handler(input_data, model)
    else:
        prediction_data = {"value": "unknown", "confidence": 0}

    result = PredictionResult(
        id=prediction_id,
        model_id=model_id,
        prediction_type=model.prediction_type,
        target_id=input_data.get("target_id"),
        prediction_data=prediction_data,
        confidence=prediction_data.get("confidence", 0.8),
        explanation=prediction_data.get("explanation", "基于历史数据模式预测"),
        actual_value=None,
        is_correct=None,
        created_at=timestamp,
    )

    with self._get_db() as conn:
        conn.execute(
            """
            INSERT INTO prediction_results
            (id, model_id, prediction_type, target_id, prediction_data, confidence,
             explanation, actual_value, is_correct, created_at)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """,
            (
                result.id,
                result.model_id,
                result.prediction_type.value,
                result.target_id,
                json.dumps(result.prediction_data),
                result.confidence,
                result.explanation,
                result.actual_value,
                result.is_correct,
                result.created_at,
            ),
        )

        # Keep the per-model usage counter in sync with stored results.
        conn.execute(
            "UPDATE prediction_models SET prediction_count = prediction_count + 1 WHERE id = ?",
            (model_id,),
        )
        conn.commit()

    return result
|
||
|
||
def _predict_trend(self, input_data: dict, model: PredictionModel) -> dict:
    """Forecast the next value of a series via a least-squares linear fit.

    Fits slope over index positions 0..n-1 and extrapolates one step past
    the last observation. Falls back to a low-confidence "stable" answer
    when fewer than two points are supplied.
    """
    series = input_data.get("historical_values", [])

    if len(series) < 2:
        return {
            "predicted_value": 0,
            "trend": "stable",
            "confidence": 0.5,
            "explanation": "历史数据不足,无法准确预测趋势",
        }

    # Ordinary least squares: slope = cov(x, y) / var(x), x = 0..n-1.
    count = len(series)
    mean_x = sum(range(count)) / count
    mean_y = sum(series) / count

    numerator = sum((i - mean_x) * (v - mean_y) for i, v in enumerate(series))
    denominator = sum((i - mean_x) ** 2 for i in range(count))
    slope = numerator / denominator if denominator != 0 else 0

    # One-step extrapolation from the last observed value.
    next_value = series[-1] + slope

    if slope > 0.01:
        trend = "increasing"
    elif slope < -0.01:
        trend = "decreasing"
    else:
        trend = "stable"

    return {
        "predicted_value": round(next_value, 2),
        "trend": trend,
        "slope": round(slope, 4),
        "confidence": min(0.95, 0.6 + count * 0.02),
        "explanation": f"基于{count}个历史数据点,预测趋势为{trend}",
    }
|
||
|
||
def _detect_anomaly(self, input_data: dict, model: PredictionModel) -> dict:
    """Flag a value as anomalous when it lies beyond 2.5 sample std-devs.

    Uses the mean/stdev of the supplied history as the baseline. With a
    degenerate (zero-variance) history, any deviation from the constant
    counts as an anomaly. Returns a low-confidence non-anomaly verdict
    when the inputs are insufficient.
    """
    current = input_data.get("value")
    history = input_data.get("historical_values", [])

    if current is None or not history:
        return {
            "is_anomaly": False,
            "anomaly_score": 0,
            "confidence": 0.5,
            "explanation": "数据不足,无法进行异常检测",
        }

    mean = statistics.mean(history)
    std = statistics.stdev(history) if len(history) > 1 else 0

    if std == 0:
        # Constant history: any departure from the constant is anomalous.
        is_anomaly = current != mean
        z_score = 3 if is_anomaly else 0
    else:
        z_score = abs(current - mean) / std
        is_anomaly = z_score > 2.5  # 2.5 standard deviations => anomaly

    return {
        "is_anomaly": is_anomaly,
        "anomaly_score": round(min(z_score / 3, 1.0), 2),
        "z_score": round(z_score, 2),
        "mean": round(mean, 2),
        "std": round(std, 2),
        "confidence": min(0.95, 0.7 + len(history) * 0.01),
        "explanation": f"当前值偏离均值{z_score:.2f}个标准差,{'检测到异常' if is_anomaly else '处于正常范围'}",
    }
|
||
|
||
def _predict_entity_growth(self, input_data: dict, model: PredictionModel) -> dict:
    """Project the next-period entity count from average historical growth.

    Computes period-over-period growth rates (denominator clamped to 1 to
    avoid division by zero) and applies their mean to the latest count.
    Requires at least three history snapshots.
    """
    history = input_data.get("entity_history", [])

    if len(history) < 3:
        return {
            "predicted_count": len(history),
            "growth_rate": 0,
            "confidence": 0.5,
            "explanation": "历史数据不足,无法预测增长趋势",
        }

    counts = [snapshot.get("count", 0) for snapshot in history]

    # Period-over-period growth rates between consecutive counts.
    rates = []
    for previous, current in zip(counts, counts[1:]):
        rates.append((current - previous) / max(previous, 1))
    avg_rate = statistics.mean(rates) if rates else 0

    # Apply the mean growth rate to the most recent count.
    projected = counts[-1] * (1 + avg_rate)

    return {
        "predicted_count": round(projected),
        "current_count": counts[-1],
        "growth_rate": round(avg_rate, 4),
        "confidence": min(0.9, 0.6 + len(history) * 0.03),
        "explanation": f"基于过去{len(history)}个周期的数据,预测增长率{avg_rate * 100:.1f}%",
    }
|
||
|
||
def _predict_relation_evolution(self, input_data: dict, model: PredictionModel) -> dict:
    """Rank relation types by historical frequency to anticipate evolution.

    Tallies relation-type occurrences across all snapshots and reports the
    top five with a frequency-based likelihood (capped at 0.95). Requires
    at least two snapshots.
    """
    snapshots = input_data.get("relation_history", [])

    if len(snapshots) < 2:
        return {
            "predicted_relations": [],
            "confidence": 0.5,
            "explanation": "历史数据不足,无法预测关系演变",
        }

    # Tally how often each relation type appears across all snapshots.
    type_counts: dict[str, int] = {}
    for snapshot in snapshots:
        for relation in snapshot.get("relations", []):
            rel_type = relation.get("type", "unknown")
            type_counts[rel_type] = type_counts.get(rel_type, 0) + 1

    # Top five types; likelihood proportional to frequency, capped at 0.95.
    ranked = sorted(type_counts.items(), key=lambda item: item[1], reverse=True)
    predicted = [
        {"type": name, "likelihood": min(freq / len(snapshots), 0.95)}
        for name, freq in ranked[:5]
    ]

    return {
        "predicted_relations": predicted,
        "relation_trends": dict(type_counts),
        "confidence": min(0.85, 0.6 + len(snapshots) * 0.05),
        "explanation": f"基于{len(snapshots)}个历史快照分析关系演变趋势",
    }
|
||
|
||
def get_prediction_results(self, model_id: str, limit: int = 100) -> list[PredictionResult]:
    """Fetch the most recent prediction results for a model, newest first."""
    with self._get_db() as conn:
        fetched = conn.execute(
            """SELECT * FROM prediction_results
            WHERE model_id = ?
            ORDER BY created_at DESC
            LIMIT ?""",
            (model_id, limit),
        ).fetchall()
    return [self._row_to_prediction_result(record) for record in fetched]
|
||
|
||
def update_prediction_feedback(
    self, prediction_id: str, actual_value: str, is_correct: bool,
) -> None:
    """Record ground-truth feedback for a past prediction.

    Stores the observed value and whether the prediction was correct, so
    model quality can be tracked and improved later.
    """
    params = (actual_value, is_correct, prediction_id)
    with self._get_db() as conn:
        conn.execute(
            """UPDATE prediction_results
            SET actual_value = ?, is_correct = ?
            WHERE id = ?""",
            params,
        )
        conn.commit()
|
||
|
||
# ==================== Helper methods ====================
|
||
|
||
def _row_to_custom_model(self, row) -> CustomModel:
    """Hydrate a CustomModel dataclass from a database row."""
    # JSON-encoded columns are decoded up front.
    decoded = {
        key: json.loads(row[key])
        for key in ("training_data", "hyperparameters", "metrics")
    }
    return CustomModel(
        id=row["id"],
        tenant_id=row["tenant_id"],
        name=row["name"],
        description=row["description"],
        model_type=ModelType(row["model_type"]),
        status=ModelStatus(row["status"]),
        training_data=decoded["training_data"],
        hyperparameters=decoded["hyperparameters"],
        metrics=decoded["metrics"],
        model_path=row["model_path"],
        created_at=row["created_at"],
        updated_at=row["updated_at"],
        trained_at=row["trained_at"],
        created_by=row["created_by"],
    )
|
||
|
||
def _row_to_training_sample(self, row) -> TrainingSample:
    """Hydrate a TrainingSample dataclass from a database row."""
    # `entities` and `metadata` are stored as JSON text.
    entities = json.loads(row["entities"])
    metadata = json.loads(row["metadata"])
    return TrainingSample(
        id=row["id"],
        model_id=row["model_id"],
        text=row["text"],
        entities=entities,
        metadata=metadata,
        created_at=row["created_at"],
    )
|
||
|
||
def _row_to_multimodal_analysis(self, row) -> MultimodalAnalysis:
    """Hydrate a MultimodalAnalysis dataclass from a database row."""
    # `input_urls` and `result` are stored as JSON text.
    urls = json.loads(row["input_urls"])
    analysis_result = json.loads(row["result"])
    return MultimodalAnalysis(
        id=row["id"],
        tenant_id=row["tenant_id"],
        project_id=row["project_id"],
        provider=MultimodalProvider(row["provider"]),
        input_type=row["input_type"],
        input_urls=urls,
        prompt=row["prompt"],
        result=analysis_result,
        tokens_used=row["tokens_used"],
        cost=row["cost"],
        created_at=row["created_at"],
    )
|
||
|
||
def _row_to_kg_rag(self, row) -> KnowledgeGraphRAG:
    """Hydrate a KnowledgeGraphRAG dataclass from a database row."""
    # The three *_config columns are JSON-encoded dicts.
    configs = {
        key: json.loads(row[key])
        for key in ("kg_config", "retrieval_config", "generation_config")
    }
    return KnowledgeGraphRAG(
        id=row["id"],
        tenant_id=row["tenant_id"],
        project_id=row["project_id"],
        name=row["name"],
        description=row["description"],
        kg_config=configs["kg_config"],
        retrieval_config=configs["retrieval_config"],
        generation_config=configs["generation_config"],
        is_active=bool(row["is_active"]),  # SQLite stores booleans as 0/1
        created_at=row["created_at"],
        updated_at=row["updated_at"],
    )
|
||
|
||
def _row_to_prediction_model(self, row) -> PredictionModel:
    """Hydrate a PredictionModel dataclass from a database row."""
    # `features` and `model_config` are stored as JSON text.
    features = json.loads(row["features"])
    config = json.loads(row["model_config"])
    return PredictionModel(
        id=row["id"],
        tenant_id=row["tenant_id"],
        project_id=row["project_id"],
        name=row["name"],
        prediction_type=PredictionType(row["prediction_type"]),
        target_entity_type=row["target_entity_type"],
        features=features,
        model_config=config,
        accuracy=row["accuracy"],
        last_trained_at=row["last_trained_at"],
        prediction_count=row["prediction_count"],
        is_active=bool(row["is_active"]),  # SQLite stores booleans as 0/1
        created_at=row["created_at"],
        updated_at=row["updated_at"],
    )
|
||
|
||
def _row_to_prediction_result(self, row) -> PredictionResult:
    """Hydrate a PredictionResult dataclass from a database row.

    SQLite stores booleans as integers. `is_correct` is tri-state
    (None until feedback is recorded), so it is coerced back to bool
    only when non-NULL — consistent with the `bool(...)` coercion the
    other row converters apply to their boolean columns.
    """
    raw_is_correct = row["is_correct"]
    return PredictionResult(
        id=row["id"],
        model_id=row["model_id"],
        prediction_type=PredictionType(row["prediction_type"]),
        target_id=row["target_id"],
        prediction_data=json.loads(row["prediction_data"]),
        confidence=row["confidence"],
        explanation=row["explanation"],
        actual_value=row["actual_value"],
        # Preserve NULL (= no feedback yet); otherwise normalize 0/1 -> bool.
        is_correct=None if raw_is_correct is None else bool(raw_is_correct),
        created_at=row["created_at"],
    )
|
||
|
||
|
||
# Singleton instance
# Lazily created by get_ai_manager(); holds the process-wide AIManager.
_ai_manager: "AIManager | None" = None
|
||
|
||
|
||
def get_ai_manager() -> AIManager:
    """Return the process-wide AIManager, creating it on first use."""
    global _ai_manager
    if _ai_manager is not None:
        return _ai_manager
    _ai_manager = AIManager()
    return _ai_manager
|