#!/usr/bin/env python3 """ InsightFlow AI Manager - Phase 8 Task 4 AI 能力增强模块 - 自定义模型训练(领域特定实体识别) - 多模态大模型集成(GPT-4V、Claude 3) - 智能摘要与问答(基于知识图谱的 RAG) - 预测性分析(趋势预测、异常检测) """ import os import json import sqlite3 import httpx import asyncio import random import statistics from typing import List, Dict, Optional from dataclasses import dataclass from datetime import datetime from enum import Enum from collections import defaultdict import uuid # Database path DB_PATH = os.path.join(os.path.dirname(__file__), "insightflow.db") class ModelType(str, Enum): """模型类型""" CUSTOM_NER = "custom_ner" # 自定义实体识别 MULTIMODAL = "multimodal" # 多模态 SUMMARIZATION = "summarization" # 摘要 PREDICTION = "prediction" # 预测 class ModelStatus(str, Enum): """模型状态""" PENDING = "pending" TRAINING = "training" READY = "ready" FAILED = "failed" ARCHIVED = "archived" class MultimodalProvider(str, Enum): """多模态模型提供商""" GPT4V = "gpt-4-vision" CLAUDE3 = "claude-3" GEMINI = "gemini-pro-vision" KIMI_VL = "kimi-vl" class PredictionType(str, Enum): """预测类型""" TREND = "trend" # 趋势预测 ANOMALY = "anomaly" # 异常检测 ENTITY_GROWTH = "entity_growth" # 实体增长预测 RELATION_EVOLUTION = "relation_evolution" # 关系演变预测 @dataclass class CustomModel: """自定义模型""" id: str tenant_id: str name: str description: str model_type: ModelType status: ModelStatus training_data: Dict # 训练数据配置 hyperparameters: Dict # 超参数 metrics: Dict # 训练指标 model_path: Optional[str] # 模型文件路径 created_at: str updated_at: str trained_at: Optional[str] created_by: str @dataclass class TrainingSample: """训练样本""" id: str model_id: str text: str entities: List[Dict] # [{"start": 0, "end": 5, "label": "PERSON", "text": "张三"}] metadata: Dict created_at: str @dataclass class MultimodalAnalysis: """多模态分析结果""" id: str tenant_id: str project_id: str provider: MultimodalProvider input_type: str # image, video, audio, mixed input_urls: List[str] prompt: str result: Dict # 分析结果 tokens_used: int cost: float created_at: str @dataclass class KnowledgeGraphRAG: """基于知识图谱的 RAG 配置""" 
@dataclass
class KnowledgeGraphRAG:
    """Configuration for knowledge-graph-backed RAG."""

    id: str
    tenant_id: str
    project_id: str
    name: str
    description: str
    kg_config: Dict         # knowledge-graph configuration
    retrieval_config: Dict  # retrieval configuration
    generation_config: Dict  # generation configuration
    is_active: bool
    created_at: str
    updated_at: str


@dataclass
class RAGQuery:
    """One recorded RAG query and its answer."""

    id: str
    rag_id: str
    query: str
    context: Dict        # retrieved context
    answer: str
    sources: List[Dict]  # provenance information
    confidence: float
    tokens_used: int
    latency_ms: int
    created_at: str


@dataclass
class PredictionModel:
    """A predictive-analytics model definition."""

    id: str
    tenant_id: str
    project_id: str
    name: str
    prediction_type: PredictionType
    target_entity_type: Optional[str]  # entity type being predicted
    features: List[str]                # feature list
    model_config: Dict                 # model configuration
    accuracy: Optional[float]
    last_trained_at: Optional[str]
    prediction_count: int
    is_active: bool
    created_at: str
    updated_at: str


@dataclass
class PredictionResult:
    """One prediction produced by a PredictionModel."""

    id: str
    model_id: str
    prediction_type: PredictionType
    target_id: Optional[str]   # id of the prediction target
    prediction_data: Dict      # prediction payload
    confidence: float
    explanation: str           # human-readable explanation
    actual_value: Optional[str]  # observed value (for later validation)
    is_correct: Optional[bool]
    created_at: str


@dataclass
class SmartSummary:
    """An AI-generated summary of a transcript/entity/project."""

    id: str
    tenant_id: str
    project_id: str
    source_type: str   # transcript, entity, project
    source_id: str
    summary_type: str  # extractive, abstractive, key_points, timeline
    content: str
    key_points: List[str]
    entities_mentioned: List[str]
    confidence: float
    tokens_used: int
    created_at: str


class AIManager:
    """Main entry point for the module's AI capabilities."""

    def __init__(self, db_path: str = DB_PATH):
        self.db_path = db_path
        # Provider credentials are read from the environment; empty string
        # means the provider is unavailable and a fallback is used.
        self.kimi_api_key = os.getenv("KIMI_API_KEY", "")
        self.kimi_base_url = os.getenv("KIMI_BASE_URL", "https://api.kimi.com/coding")
        self.openai_api_key = os.getenv("OPENAI_API_KEY", "")
        self.anthropic_api_key = os.getenv("ANTHROPIC_API_KEY", "")

    def _get_db(self):
        """Open a SQLite connection with dict-like row access."""
        conn = sqlite3.connect(self.db_path)
        conn.row_factory = sqlite3.Row
        return conn
hyperparameters: Dict, created_by: str, ) -> CustomModel: """创建自定义模型""" model_id = f"cm_{uuid.uuid4().hex[:16]}" now = datetime.now().isoformat() model = CustomModel( id=model_id, tenant_id=tenant_id, name=name, description=description, model_type=model_type, status=ModelStatus.PENDING, training_data=training_data, hyperparameters=hyperparameters, metrics={}, model_path=None, created_at=now, updated_at=now, trained_at=None, created_by=created_by, ) with self._get_db() as conn: conn.execute( """ INSERT INTO custom_models (id, tenant_id, name, description, model_type, status, training_data, hyperparameters, metrics, model_path, created_at, updated_at, trained_at, created_by) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) """, ( model.id, model.tenant_id, model.name, model.description, model.model_type.value, model.status.value, json.dumps(model.training_data), json.dumps(model.hyperparameters), json.dumps(model.metrics), model.model_path, model.created_at, model.updated_at, model.trained_at, model.created_by, ), ) conn.commit() return model def get_custom_model(self, model_id: str) -> Optional[CustomModel]: """获取自定义模型""" with self._get_db() as conn: row = conn.execute("SELECT * FROM custom_models WHERE id = ?", (model_id,)).fetchone() if not row: return None return self._row_to_custom_model(row) def list_custom_models( self, tenant_id: str, model_type: Optional[ModelType] = None, status: Optional[ModelStatus] = None ) -> List[CustomModel]: """列出自定义模型""" query = "SELECT * FROM custom_models WHERE tenant_id = ?" params = [tenant_id] if model_type: query += " AND model_type = ?" params.append(model_type.value) if status: query += " AND status = ?" 
params.append(status.value) query += " ORDER BY created_at DESC" with self._get_db() as conn: rows = conn.execute(query, params).fetchall() return [self._row_to_custom_model(row) for row in rows] def add_training_sample( self, model_id: str, text: str, entities: List[Dict], metadata: Dict = None ) -> TrainingSample: """添加训练样本""" sample_id = f"ts_{uuid.uuid4().hex[:16]}" now = datetime.now().isoformat() sample = TrainingSample( id=sample_id, model_id=model_id, text=text, entities=entities, metadata=metadata or {}, created_at=now ) with self._get_db() as conn: conn.execute( """ INSERT INTO training_samples (id, model_id, text, entities, metadata, created_at) VALUES (?, ?, ?, ?, ?, ?) """, ( sample.id, sample.model_id, sample.text, json.dumps(sample.entities), json.dumps(sample.metadata), sample.created_at, ), ) conn.commit() return sample def get_training_samples(self, model_id: str) -> List[TrainingSample]: """获取训练样本""" with self._get_db() as conn: rows = conn.execute( "SELECT * FROM training_samples WHERE model_id = ? ORDER BY created_at", (model_id,) ).fetchall() return [self._row_to_training_sample(row) for row in rows] async def train_custom_model(self, model_id: str) -> CustomModel: """训练自定义模型""" model = self.get_custom_model(model_id) if not model: raise ValueError(f"Model {model_id} not found") # 更新状态为训练中 with self._get_db() as conn: conn.execute( "UPDATE custom_models SET status = ?, updated_at = ? 
WHERE id = ?", (ModelStatus.TRAINING.value, datetime.now().isoformat(), model_id), ) conn.commit() try: # 获取训练样本 samples = self.get_training_samples(model_id) if len(samples) < 10: raise ValueError("至少需要 10 个训练样本") # 模拟训练过程(实际项目中这里会调用训练框架如 spaCy、Hugging Face 等) await asyncio.sleep(2) # 模拟训练时间 # 计算训练指标 metrics = { "samples_count": len(samples), "epochs": model.hyperparameters.get("epochs", 10), "learning_rate": model.hyperparameters.get("learning_rate", 0.001), "precision": round(0.85 + random.random() * 0.1, 4), "recall": round(0.82 + random.random() * 0.1, 4), "f1_score": round(0.84 + random.random() * 0.1, 4), "training_time_seconds": 120, } # 保存模型(模拟) model_path = f"models/{model_id}.bin" os.makedirs("models", exist_ok=True) now = datetime.now().isoformat() with self._get_db() as conn: conn.execute( """ UPDATE custom_models SET status = ?, metrics = ?, model_path = ?, trained_at = ?, updated_at = ? WHERE id = ? """, (ModelStatus.READY.value, json.dumps(metrics), model_path, now, now, model_id), ) conn.commit() return self.get_custom_model(model_id) except Exception as e: with self._get_db() as conn: conn.execute( "UPDATE custom_models SET status = ?, updated_at = ? 
WHERE id = ?", (ModelStatus.FAILED.value, datetime.now().isoformat(), model_id), ) conn.commit() raise e async def predict_with_custom_model(self, model_id: str, text: str) -> List[Dict]: """使用自定义模型进行预测""" model = self.get_custom_model(model_id) if not model or model.status != ModelStatus.READY: raise ValueError(f"Model {model_id} not ready") # 模拟预测(实际项目中加载模型并进行推理) # 这里使用 LLM 模拟领域特定实体识别 entity_types = model.training_data.get("entity_types", ["PERSON", "ORG", "TECH", "PROJECT"]) prompt = f"""从以下文本中提取实体,类型限定为: {', '.join(entity_types)} 文本: {text} 以 JSON 格式返回实体列表: [{{"text": "实体文本", "label": "类型", "start": 0, "end": 5, "confidence": 0.95}}] 只返回 JSON 数组,不要其他内容。""" headers = {"Authorization": f"Bearer {self.kimi_api_key}", "Content-Type": "application/json"} payload = {"model": "k2p5", "messages": [{"role": "user", "content": prompt}], "temperature": 0.1} async with httpx.AsyncClient() as client: response = await client.post( f"{self.kimi_base_url}/v1/chat/completions", headers=headers, json=payload, timeout=60.0 ) response.raise_for_status() result = response.json() content = result["choices"][0]["message"]["content"] # 解析 JSON import re json_match = re.search(r"\[.*?\]", content, re.DOTALL) if json_match: try: entities = json.loads(json_match.group()) return entities except (json.JSONDecodeError, ValueError): pass return [] # ==================== 多模态大模型集成 ==================== async def analyze_multimodal( self, tenant_id: str, project_id: str, provider: MultimodalProvider, input_type: str, input_urls: List[str], prompt: str, ) -> MultimodalAnalysis: """多模态分析""" analysis_id = f"ma_{uuid.uuid4().hex[:16]}" now = datetime.now().isoformat() # 根据提供商调用不同的 API if provider == MultimodalProvider.GPT4V and self.openai_api_key: result = await self._call_gpt4v(input_urls, prompt) elif provider == MultimodalProvider.CLAUDE3 and self.anthropic_api_key: result = await self._call_claude3(input_urls, prompt) else: # 默认使用 Kimi result = await self._call_kimi_multimodal(input_urls, 
prompt) analysis = MultimodalAnalysis( id=analysis_id, tenant_id=tenant_id, project_id=project_id, provider=provider, input_type=input_type, input_urls=input_urls, prompt=prompt, result=result, tokens_used=result.get("tokens_used", 0), cost=result.get("cost", 0.0), created_at=now, ) with self._get_db() as conn: conn.execute( """ INSERT INTO multimodal_analyses (id, tenant_id, project_id, provider, input_type, input_urls, prompt, result, tokens_used, cost, created_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) """, ( analysis.id, analysis.tenant_id, analysis.project_id, analysis.provider.value, analysis.input_type, json.dumps(analysis.input_urls), analysis.prompt, json.dumps(analysis.result), analysis.tokens_used, analysis.cost, analysis.created_at, ), ) conn.commit() return analysis async def _call_gpt4v(self, image_urls: List[str], prompt: str) -> Dict: """调用 GPT-4V""" headers = {"Authorization": f"Bearer {self.openai_api_key}", "Content-Type": "application/json"} content = [{"type": "text", "text": prompt}] for url in image_urls: content.append({"type": "image_url", "image_url": {"url": url}}) payload = { "model": "gpt-4-vision-preview", "messages": [{"role": "user", "content": content}], "max_tokens": 2000, } async with httpx.AsyncClient() as client: response = await client.post( "https://api.openai.com/v1/chat/completions", headers=headers, json=payload, timeout=120.0 ) response.raise_for_status() result = response.json() return { "content": result["choices"][0]["message"]["content"], "tokens_used": result["usage"]["total_tokens"], "cost": result["usage"]["total_tokens"] * 0.00001, # 估算成本 } async def _call_claude3(self, image_urls: List[str], prompt: str) -> Dict: """调用 Claude 3""" headers = { "x-api-key": self.anthropic_api_key, "Content-Type": "application/json", "anthropic-version": "2023-06-01", } content = [] for url in image_urls: content.append({"type": "image", "source": {"type": "url", "url": url}}) content.append({"type": "text", "text": prompt}) 
payload = { "model": "claude-3-opus-20240229", "max_tokens": 2000, "messages": [{"role": "user", "content": content}], } async with httpx.AsyncClient() as client: response = await client.post( "https://api.anthropic.com/v1/messages", headers=headers, json=payload, timeout=120.0 ) response.raise_for_status() result = response.json() return { "content": result["content"][0]["text"], "tokens_used": result["usage"]["input_tokens"] + result["usage"]["output_tokens"], "cost": (result["usage"]["input_tokens"] + result["usage"]["output_tokens"]) * 0.000015, } async def _call_kimi_multimodal(self, image_urls: List[str], prompt: str) -> Dict: """调用 Kimi 多模态模型""" headers = {"Authorization": f"Bearer {self.kimi_api_key}", "Content-Type": "application/json"} # Kimi 目前可能不支持真正的多模态,这里模拟返回 # 实际实现时需要根据 Kimi API 更新 content = f"图片 URL: {', '.join(image_urls)}\n\n{prompt}\n\n注意:请基于图片 URL 描述的内容进行回答。" payload = {"model": "k2p5", "messages": [{"role": "user", "content": content}], "temperature": 0.3} async with httpx.AsyncClient() as client: response = await client.post( f"{self.kimi_base_url}/v1/chat/completions", headers=headers, json=payload, timeout=60.0 ) response.raise_for_status() result = response.json() return { "content": result["choices"][0]["message"]["content"], "tokens_used": result["usage"]["total_tokens"], "cost": result["usage"]["total_tokens"] * 0.000005, } def get_multimodal_analyses(self, tenant_id: str, project_id: Optional[str] = None) -> List[MultimodalAnalysis]: """获取多模态分析历史""" query = "SELECT * FROM multimodal_analyses WHERE tenant_id = ?" params = [tenant_id] if project_id: query += " AND project_id = ?" 
params.append(project_id) query += " ORDER BY created_at DESC" with self._get_db() as conn: rows = conn.execute(query, params).fetchall() return [self._row_to_multimodal_analysis(row) for row in rows] # ==================== 智能摘要与问答(基于知识图谱的 RAG) ==================== def create_kg_rag( self, tenant_id: str, project_id: str, name: str, description: str, kg_config: Dict, retrieval_config: Dict, generation_config: Dict, ) -> KnowledgeGraphRAG: """创建知识图谱 RAG 配置""" rag_id = f"kgr_{uuid.uuid4().hex[:16]}" now = datetime.now().isoformat() rag = KnowledgeGraphRAG( id=rag_id, tenant_id=tenant_id, project_id=project_id, name=name, description=description, kg_config=kg_config, retrieval_config=retrieval_config, generation_config=generation_config, is_active=True, created_at=now, updated_at=now, ) with self._get_db() as conn: conn.execute( """ INSERT INTO kg_rag_configs (id, tenant_id, project_id, name, description, kg_config, retrieval_config, generation_config, is_active, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) """, ( rag.id, rag.tenant_id, rag.project_id, rag.name, rag.description, json.dumps(rag.kg_config), json.dumps(rag.retrieval_config), json.dumps(rag.generation_config), rag.is_active, rag.created_at, rag.updated_at, ), ) conn.commit() return rag def get_kg_rag(self, rag_id: str) -> Optional[KnowledgeGraphRAG]: """获取知识图谱 RAG 配置""" with self._get_db() as conn: row = conn.execute("SELECT * FROM kg_rag_configs WHERE id = ?", (rag_id,)).fetchone() if not row: return None return self._row_to_kg_rag(row) def list_kg_rags(self, tenant_id: str, project_id: Optional[str] = None) -> List[KnowledgeGraphRAG]: """列出知识图谱 RAG 配置""" query = "SELECT * FROM kg_rag_configs WHERE tenant_id = ?" params = [tenant_id] if project_id: query += " AND project_id = ?" 
params.append(project_id) query += " ORDER BY created_at DESC" with self._get_db() as conn: rows = conn.execute(query, params).fetchall() return [self._row_to_kg_rag(row) for row in rows] async def query_kg_rag( self, rag_id: str, query: str, project_entities: List[Dict], project_relations: List[Dict] ) -> RAGQuery: """基于知识图谱的 RAG 查询""" import time start_time = time.time() rag = self.get_kg_rag(rag_id) if not rag: raise ValueError(f"RAG config {rag_id} not found") # 1. 检索相关实体和关系 retrieval_config = rag.retrieval_config top_k = retrieval_config.get("top_k", 5) # 简单的语义检索(基于实体名称匹配) query_lower = query.lower() relevant_entities = [] for entity in project_entities: score = 0 name = entity.get("name", "").lower() definition = entity.get("definition", "").lower() if name in query_lower or any(word in name for word in query_lower.split()): score += 0.5 if any(word in definition for word in query_lower.split()): score += 0.3 if score > 0: relevant_entities.append({**entity, "relevance_score": score}) relevant_entities.sort(key=lambda x: x["relevance_score"], reverse=True) relevant_entities = relevant_entities[:top_k] # 检索相关关系 relevant_relations = [] entity_ids = {e["id"] for e in relevant_entities} for relation in project_relations: if relation.get("source_entity_id") in entity_ids or relation.get("target_entity_id") in entity_ids: relevant_relations.append(relation) # 2. 构建上下文 context = {"entities": relevant_entities, "relations": relevant_relations[:10]} context_text = self._build_kg_context(relevant_entities, relevant_relations) # 3. 生成回答 generation_config = rag.generation_config temperature = generation_config.get("temperature", 0.3) max_tokens = generation_config.get("max_tokens", 1000) prompt = f"""基于以下知识图谱信息回答问题: ## 知识图谱上下文 {context_text} ## 用户问题 {query} 请基于上述知识图谱信息回答问题。如果信息不足,请明确说明。 回答应该: 1. 准确引用知识图谱中的实体和关系 2. 如果涉及多个实体,说明它们之间的关联 3. 
保持简洁专业""" headers = {"Authorization": f"Bearer {self.kimi_api_key}", "Content-Type": "application/json"} payload = { "model": "k2p5", "messages": [{"role": "user", "content": prompt}], "temperature": temperature, "max_tokens": max_tokens, } async with httpx.AsyncClient() as client: response = await client.post( f"{self.kimi_base_url}/v1/chat/completions", headers=headers, json=payload, timeout=60.0 ) response.raise_for_status() result = response.json() answer = result["choices"][0]["message"]["content"] tokens_used = result["usage"]["total_tokens"] latency_ms = int((time.time() - start_time) * 1000) # 4. 保存查询记录 query_id = f"rq_{uuid.uuid4().hex[:16]}" now = datetime.now().isoformat() sources = [ {"entity_id": e["id"], "entity_name": e["name"], "score": e["relevance_score"]} for e in relevant_entities ] rag_query = RAGQuery( id=query_id, rag_id=rag_id, query=query, context=context, answer=answer, sources=sources, confidence=( sum(e["relevance_score"] for e in relevant_entities) / len(relevant_entities) if relevant_entities else 0 ), tokens_used=tokens_used, latency_ms=latency_ms, created_at=now, ) with self._get_db() as conn: conn.execute( """ INSERT INTO rag_queries (id, rag_id, query, context, answer, sources, confidence, tokens_used, latency_ms, created_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
""", ( rag_query.id, rag_query.rag_id, rag_query.query, json.dumps(rag_query.context), rag_query.answer, json.dumps(rag_query.sources), rag_query.confidence, rag_query.tokens_used, rag_query.latency_ms, rag_query.created_at, ), ) conn.commit() return rag_query def _build_kg_context(self, entities: List[Dict], relations: List[Dict]) -> str: """构建知识图谱上下文文本""" context = [] if entities: context.append("### 相关实体") for entity in entities: name = entity.get("name", "") entity_type = entity.get("type", "") definition = entity.get("definition", "") context.append(f"- **{name}** ({entity_type}): {definition}") if relations: context.append("\n### 相关关系") for relation in relations: source = relation.get("source_name", "") target = relation.get("target_name", "") rel_type = relation.get("relation_type", "") evidence = relation.get("evidence", "") context.append(f"- {source} --[{rel_type}]--> {target}") if evidence: context.append(f" - 依据: {evidence[:100]}...") return "\n".join(context) async def generate_smart_summary( self, tenant_id: str, project_id: str, source_type: str, source_id: str, summary_type: str, content_data: Dict ) -> SmartSummary: """生成智能摘要""" summary_id = f"ss_{uuid.uuid4().hex[:16]}" now = datetime.now().isoformat() # 根据摘要类型生成不同的提示 if summary_type == "extractive": prompt = f"""从以下内容中提取关键句子作为摘要: {content_data.get('text', '')[:5000]} 要求: 1. 提取 3-5 个最重要的句子 2. 保持原文表述 3. 以 JSON 格式返回: {{"summary": "摘要内容", "key_points": ["要点1", "要点2"]}}""" elif summary_type == "abstractive": prompt = f"""对以下内容生成简洁的摘要: {content_data.get('text', '')[:5000]} 要求: 1. 用 2-3 句话概括核心内容 2. 使用自己的语言重新表述 3. 包含关键实体和概念""" elif summary_type == "key_points": prompt = f"""从以下内容中提取关键要点: {content_data.get('text', '')[:5000]} 要求: 1. 列出 5-8 个关键要点 2. 每个要点简洁明了 3. 以 JSON 格式返回: {{"key_points": ["要点1", "要点2", ...]}}""" else: # timeline prompt = f"""基于以下内容生成时间线摘要: {content_data.get('text', '')[:5000]} 要求: 1. 按时间顺序组织关键事件 2. 标注时间节点(如果有) 3. 
突出里程碑事件""" headers = {"Authorization": f"Bearer {self.kimi_api_key}", "Content-Type": "application/json"} payload = {"model": "k2p5", "messages": [{"role": "user", "content": prompt}], "temperature": 0.3} async with httpx.AsyncClient() as client: response = await client.post( f"{self.kimi_base_url}/v1/chat/completions", headers=headers, json=payload, timeout=60.0 ) response.raise_for_status() result = response.json() content = result["choices"][0]["message"]["content"] tokens_used = result["usage"]["total_tokens"] # 解析关键要点 key_points = [] import re # 尝试从 JSON 中提取 json_match = re.search(r"\{.*?\}", content, re.DOTALL) if json_match: try: data = json.loads(json_match.group()) key_points = data.get("key_points", []) if "summary" in data: content = data["summary"] except (json.JSONDecodeError, ValueError): pass # 如果没有提取到关键要点,从文本中提取 if not key_points: lines = content.split("\n") key_points = [ line.strip("- ").strip() for line in lines if line.strip().startswith("-") or line.strip().startswith("•") ] if not key_points: key_points = [content[:200] + "..."] if len(content) > 200 else [content] # 提取提及的实体 entities_mentioned = content_data.get("entities", []) entity_names = [e.get("name", "") for e in entities_mentioned[:10]] summary = SmartSummary( id=summary_id, tenant_id=tenant_id, project_id=project_id, source_type=source_type, source_id=source_id, summary_type=summary_type, content=content, key_points=key_points[:8], entities_mentioned=entity_names, confidence=0.85, tokens_used=tokens_used, created_at=now, ) with self._get_db() as conn: conn.execute( """ INSERT INTO smart_summaries (id, tenant_id, project_id, source_type, source_id, summary_type, content, key_points, entities_mentioned, confidence, tokens_used, created_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
""", ( summary.id, summary.tenant_id, summary.project_id, summary.source_type, summary.source_id, summary.summary_type, summary.content, json.dumps(summary.key_points), json.dumps(summary.entities_mentioned), summary.confidence, summary.tokens_used, summary.created_at, ), ) conn.commit() return summary # ==================== 预测性分析 ==================== def create_prediction_model( self, tenant_id: str, project_id: str, name: str, prediction_type: PredictionType, target_entity_type: Optional[str], features: List[str], model_config: Dict, ) -> PredictionModel: """创建预测模型""" model_id = f"pm_{uuid.uuid4().hex[:16]}" now = datetime.now().isoformat() model = PredictionModel( id=model_id, tenant_id=tenant_id, project_id=project_id, name=name, prediction_type=prediction_type, target_entity_type=target_entity_type, features=features, model_config=model_config, accuracy=None, last_trained_at=None, prediction_count=0, is_active=True, created_at=now, updated_at=now, ) with self._get_db() as conn: conn.execute( """ INSERT INTO prediction_models (id, tenant_id, project_id, name, prediction_type, target_entity_type, features, model_config, accuracy, last_trained_at, prediction_count, is_active, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
""", ( model.id, model.tenant_id, model.project_id, model.name, model.prediction_type.value, model.target_entity_type, json.dumps(model.features), json.dumps(model.model_config), model.accuracy, model.last_trained_at, model.prediction_count, model.is_active, model.created_at, model.updated_at, ), ) conn.commit() return model def get_prediction_model(self, model_id: str) -> Optional[PredictionModel]: """获取预测模型""" with self._get_db() as conn: row = conn.execute("SELECT * FROM prediction_models WHERE id = ?", (model_id,)).fetchone() if not row: return None return self._row_to_prediction_model(row) def list_prediction_models(self, tenant_id: str, project_id: Optional[str] = None) -> List[PredictionModel]: """列出预测模型""" query = "SELECT * FROM prediction_models WHERE tenant_id = ?" params = [tenant_id] if project_id: query += " AND project_id = ?" params.append(project_id) query += " ORDER BY created_at DESC" with self._get_db() as conn: rows = conn.execute(query, params).fetchall() return [self._row_to_prediction_model(row) for row in rows] async def train_prediction_model(self, model_id: str, historical_data: List[Dict]) -> PredictionModel: """训练预测模型""" model = self.get_prediction_model(model_id) if not model: raise ValueError(f"Prediction model {model_id} not found") # 模拟训练过程 await asyncio.sleep(1) # 计算准确率(模拟) accuracy = round(0.75 + random.random() * 0.2, 4) now = datetime.now().isoformat() with self._get_db() as conn: conn.execute( """ UPDATE prediction_models SET accuracy = ?, last_trained_at = ?, updated_at = ? WHERE id = ? 
""", (accuracy, now, now, model_id), ) conn.commit() return self.get_prediction_model(model_id) async def predict(self, model_id: str, input_data: Dict) -> PredictionResult: """进行预测""" model = self.get_prediction_model(model_id) if not model or not model.is_active: raise ValueError(f"Prediction model {model_id} not available") prediction_id = f"pr_{uuid.uuid4().hex[:16]}" now = datetime.now().isoformat() # 根据预测类型进行不同的预测逻辑 if model.prediction_type == PredictionType.TREND: prediction_data = self._predict_trend(input_data, model) elif model.prediction_type == PredictionType.ANOMALY: prediction_data = self._detect_anomaly(input_data, model) elif model.prediction_type == PredictionType.ENTITY_GROWTH: prediction_data = self._predict_entity_growth(input_data, model) elif model.prediction_type == PredictionType.RELATION_EVOLUTION: prediction_data = self._predict_relation_evolution(input_data, model) else: prediction_data = {"value": "unknown", "confidence": 0} confidence = prediction_data.get("confidence", 0.8) explanation = prediction_data.get("explanation", "基于历史数据模式预测") result = PredictionResult( id=prediction_id, model_id=model_id, prediction_type=model.prediction_type, target_id=input_data.get("target_id"), prediction_data=prediction_data, confidence=confidence, explanation=explanation, actual_value=None, is_correct=None, created_at=now, ) with self._get_db() as conn: conn.execute( """ INSERT INTO prediction_results (id, model_id, prediction_type, target_id, prediction_data, confidence, explanation, actual_value, is_correct, created_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
""", ( result.id, result.model_id, result.prediction_type.value, result.target_id, json.dumps(result.prediction_data), result.confidence, result.explanation, result.actual_value, result.is_correct, result.created_at, ), ) # 更新预测计数 conn.execute( "UPDATE prediction_models SET prediction_count = prediction_count + 1 WHERE id = ?", (model_id,) ) conn.commit() return result def _predict_trend(self, input_data: Dict, model: PredictionModel) -> Dict: """趋势预测""" historical_values = input_data.get("historical_values", []) if len(historical_values) < 2: return { "predicted_value": 0, "trend": "stable", "confidence": 0.5, "explanation": "历史数据不足,无法准确预测趋势", } # 简单线性趋势预测 - 使用最小二乘法计算斜率 n = len(historical_values) x = list(range(n)) y = historical_values # 计算均值 mean_x = sum(x) / n mean_y = sum(y) / n # 计算斜率 (最小二乘法) numerator = sum((x[i] - mean_x) * (y[i] - mean_y) for i in range(n)) denominator = sum((x[i] - mean_x) ** 2 for i in range(n)) slope = numerator / denominator if denominator != 0 else 0 # 预测下一个值 next_value = y[-1] + slope trend = "increasing" if slope > 0.01 else "decreasing" if slope < -0.01 else "stable" return { "predicted_value": round(next_value, 2), "trend": trend, "slope": round(slope, 4), "confidence": min(0.95, 0.6 + len(historical_values) * 0.02), "explanation": f"基于{len(historical_values)}个历史数据点,预测趋势为{trend}", } def _detect_anomaly(self, input_data: Dict, model: PredictionModel) -> Dict: """异常检测""" value = input_data.get("value") historical_values = input_data.get("historical_values", []) if not historical_values or value is None: return { "is_anomaly": False, "anomaly_score": 0, "confidence": 0.5, "explanation": "数据不足,无法进行异常检测", } # 计算均值和标准差 mean = statistics.mean(historical_values) std = statistics.stdev(historical_values) if len(historical_values) > 1 else 0 if std == 0: is_anomaly = value != mean z_score = 0 if value == mean else 3 else: z_score = abs(value - mean) / std is_anomaly = z_score > 2.5 # 2.5 个标准差视为异常 return { "is_anomaly": is_anomaly, 
"anomaly_score": round(min(z_score / 3, 1.0), 2), "z_score": round(z_score, 2), "mean": round(mean, 2), "std": round(std, 2), "confidence": min(0.95, 0.7 + len(historical_values) * 0.01), "explanation": f"当前值偏离均值{z_score:.2f}个标准差,{'检测到异常' if is_anomaly else '处于正常范围'}", } def _predict_entity_growth(self, input_data: Dict, model: PredictionModel) -> Dict: """实体增长预测""" entity_history = input_data.get("entity_history", []) if len(entity_history) < 3: return { "predicted_count": len(entity_history), "growth_rate": 0, "confidence": 0.5, "explanation": "历史数据不足,无法预测增长趋势", } # 计算增长率 counts = [h.get("count", 0) for h in entity_history] growth_rates = [(counts[i] - counts[i - 1]) / max(counts[i - 1], 1) for i in range(1, len(counts))] avg_growth_rate = statistics.mean(growth_rates) if growth_rates else 0 # 预测下一个周期的实体数量 predicted_count = counts[-1] * (1 + avg_growth_rate) return { "predicted_count": round(predicted_count), "current_count": counts[-1], "growth_rate": round(avg_growth_rate, 4), "confidence": min(0.9, 0.6 + len(entity_history) * 0.03), "explanation": f"基于过去{len(entity_history)}个周期的数据,预测增长率{avg_growth_rate * 100:.1f}%", } def _predict_relation_evolution(self, input_data: Dict, model: PredictionModel) -> Dict: """关系演变预测""" relation_history = input_data.get("relation_history", []) if len(relation_history) < 2: return {"predicted_relations": [], "confidence": 0.5, "explanation": "历史数据不足,无法预测关系演变"} # 分析关系变化趋势 relation_counts = defaultdict(int) for snapshot in relation_history: for rel in snapshot.get("relations", []): relation_counts[rel.get("type", "unknown")] += 1 # 预测可能出现的新关系类型 predicted_relations = [ {"type": rel_type, "likelihood": min(count / len(relation_history), 0.95)} for rel_type, count in sorted(relation_counts.items(), key=lambda x: x[1], reverse=True)[:5] ] return { "predicted_relations": predicted_relations, "relation_trends": dict(relation_counts), "confidence": min(0.85, 0.6 + len(relation_history) * 0.05), "explanation": 
f"基于{len(relation_history)}个历史快照分析关系演变趋势", } def get_prediction_results(self, model_id: str, limit: int = 100) -> List[PredictionResult]: """获取预测结果历史""" with self._get_db() as conn: rows = conn.execute( """SELECT * FROM prediction_results WHERE model_id = ? ORDER BY created_at DESC LIMIT ?""", (model_id, limit), ).fetchall() return [self._row_to_prediction_result(row) for row in rows] def update_prediction_feedback(self, prediction_id: str, actual_value: str, is_correct: bool): """更新预测反馈(用于模型改进)""" with self._get_db() as conn: conn.execute( """UPDATE prediction_results SET actual_value = ?, is_correct = ? WHERE id = ?""", (actual_value, is_correct, prediction_id), ) conn.commit() # ==================== 辅助方法 ==================== def _row_to_custom_model(self, row) -> CustomModel: """将数据库行转换为 CustomModel""" return CustomModel( id=row["id"], tenant_id=row["tenant_id"], name=row["name"], description=row["description"], model_type=ModelType(row["model_type"]), status=ModelStatus(row["status"]), training_data=json.loads(row["training_data"]), hyperparameters=json.loads(row["hyperparameters"]), metrics=json.loads(row["metrics"]), model_path=row["model_path"], created_at=row["created_at"], updated_at=row["updated_at"], trained_at=row["trained_at"], created_by=row["created_by"], ) def _row_to_training_sample(self, row) -> TrainingSample: """将数据库行转换为 TrainingSample""" return TrainingSample( id=row["id"], model_id=row["model_id"], text=row["text"], entities=json.loads(row["entities"]), metadata=json.loads(row["metadata"]), created_at=row["created_at"], ) def _row_to_multimodal_analysis(self, row) -> MultimodalAnalysis: """将数据库行转换为 MultimodalAnalysis""" return MultimodalAnalysis( id=row["id"], tenant_id=row["tenant_id"], project_id=row["project_id"], provider=MultimodalProvider(row["provider"]), input_type=row["input_type"], input_urls=json.loads(row["input_urls"]), prompt=row["prompt"], result=json.loads(row["result"]), tokens_used=row["tokens_used"], cost=row["cost"], 
created_at=row["created_at"], ) def _row_to_kg_rag(self, row) -> KnowledgeGraphRAG: """将数据库行转换为 KnowledgeGraphRAG""" return KnowledgeGraphRAG( id=row["id"], tenant_id=row["tenant_id"], project_id=row["project_id"], name=row["name"], description=row["description"], kg_config=json.loads(row["kg_config"]), retrieval_config=json.loads(row["retrieval_config"]), generation_config=json.loads(row["generation_config"]), is_active=bool(row["is_active"]), created_at=row["created_at"], updated_at=row["updated_at"], ) def _row_to_prediction_model(self, row) -> PredictionModel: """将数据库行转换为 PredictionModel""" return PredictionModel( id=row["id"], tenant_id=row["tenant_id"], project_id=row["project_id"], name=row["name"], prediction_type=PredictionType(row["prediction_type"]), target_entity_type=row["target_entity_type"], features=json.loads(row["features"]), model_config=json.loads(row["model_config"]), accuracy=row["accuracy"], last_trained_at=row["last_trained_at"], prediction_count=row["prediction_count"], is_active=bool(row["is_active"]), created_at=row["created_at"], updated_at=row["updated_at"], ) def _row_to_prediction_result(self, row) -> PredictionResult: """将数据库行转换为 PredictionResult""" return PredictionResult( id=row["id"], model_id=row["model_id"], prediction_type=PredictionType(row["prediction_type"]), target_id=row["target_id"], prediction_data=json.loads(row["prediction_data"]), confidence=row["confidence"], explanation=row["explanation"], actual_value=row["actual_value"], is_correct=row["is_correct"], created_at=row["created_at"], ) # Singleton instance _ai_manager = None def get_ai_manager() -> AIManager: global _ai_manager if _ai_manager is None: _ai_manager = AIManager() return _ai_manager