#!/usr/bin/env python3
"""
InsightFlow Database Manager.

Handles persistence of projects, entities, entity mentions, transcripts and
entity relations in a SQLite database.
"""
import os
import json
import sqlite3
import uuid
from contextlib import contextmanager
from datetime import datetime
from typing import Any, Dict, Iterator, List, Mapping, Optional, Tuple
from dataclasses import dataclass, field, fields

DB_PATH = os.getenv("DB_PATH", "/app/data/insightflow.db")


@dataclass
class Project:
    """A row of the ``projects`` table."""
    id: str
    name: str
    description: str = ""
    created_at: str = ""
    updated_at: str = ""


@dataclass
class Entity:
    """A row of the ``entities`` table.

    ``aliases`` is stored in the database as a JSON-encoded list of strings.
    """
    id: str
    project_id: str
    name: str
    type: str
    definition: str = ""
    canonical_name: str = ""
    aliases: List[str] = field(default_factory=list)
    # create_entity writes these columns, so they must exist here for rows
    # read back from the entities table to be convertible into an Entity.
    created_at: str = ""
    updated_at: str = ""

    def __post_init__(self):
        # Callers may still pass aliases=None explicitly.
        if self.aliases is None:
            self.aliases = []


@dataclass
class EntityMention:
    """A single occurrence of an entity inside a transcript."""
    id: str
    entity_id: str
    transcript_id: str
    start_pos: int
    end_pos: int
    text_snippet: str
    confidence: float = 1.0


def _from_row(cls, row: Mapping[str, Any]):
    """Build dataclass *cls* from a DB row, ignoring columns it has no field for."""
    names = {f.name for f in fields(cls)}
    return cls(**{key: value for key, value in dict(row).items() if key in names})


def _load_aliases(raw: Optional[str]) -> List[str]:
    """Decode the JSON ``aliases`` column; empty, NULL and JSON ``null`` give ``[]``."""
    return (json.loads(raw) if raw else None) or []


def _row_to_entity(row: Mapping[str, Any]) -> Entity:
    """Convert an ``entities`` row into an Entity, decoding the aliases column."""
    data = dict(row)
    data["aliases"] = _load_aliases(data.get("aliases"))
    return _from_row(Entity, data)


def _escape_like(value: str) -> str:
    """Escape LIKE wildcards so *value* matches literally (pair with ESCAPE '\\')."""
    return value.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")


class DatabaseManager:
    """Data-access layer over the InsightFlow SQLite database.

    Every public method opens its own short-lived connection and closes it
    before returning, even when a query raises.
    """

    def __init__(self, db_path: str = DB_PATH, schema_path: Optional[str] = None):
        """Create the manager and apply the schema script.

        Args:
            db_path: Path of the SQLite database file; its parent directory is
                created if missing.
            schema_path: SQL script executed on startup. Defaults to
                ``schema.sql`` located next to this module.
        """
        self.db_path = db_path
        self.schema_path = schema_path or os.path.join(os.path.dirname(__file__), "schema.sql")
        # os.path.dirname("file.db") is "" and os.makedirs("") raises.
        db_dir = os.path.dirname(db_path)
        if db_dir:
            os.makedirs(db_dir, exist_ok=True)
        self.init_db()

    def get_conn(self) -> sqlite3.Connection:
        """Open a new connection whose rows support access by column name."""
        conn = sqlite3.connect(self.db_path)
        conn.row_factory = sqlite3.Row
        return conn

    @contextmanager
    def _connection(self) -> Iterator[sqlite3.Connection]:
        """Yield a connection; commit on success, roll back on error, always close."""
        conn = self.get_conn()
        try:
            yield conn
            conn.commit()
        except BaseException:
            conn.rollback()
            raise
        finally:
            conn.close()

    def init_db(self):
        """Initialise the database tables by executing the schema script."""
        # Explicit utf-8: the platform default encoding may not decode the script.
        with open(self.schema_path, "r", encoding="utf-8") as f:
            schema = f.read()
        with self._connection() as conn:
            conn.executescript(schema)

    # Project operations
    def create_project(self, project_id: str, name: str, description: str = "") -> Project:
        """Insert a new project; created_at and updated_at are both set to now."""
        now = datetime.now().isoformat()
        with self._connection() as conn:
            conn.execute(
                "INSERT INTO projects (id, name, description, created_at, updated_at) VALUES (?, ?, ?, ?, ?)",
                (project_id, name, description, now, now),
            )
        return Project(id=project_id, name=name, description=description, created_at=now, updated_at=now)

    def get_project(self, project_id: str) -> Optional[Project]:
        """Return the project with *project_id*, or None if it does not exist."""
        with self._connection() as conn:
            row = conn.execute("SELECT * FROM projects WHERE id = ?", (project_id,)).fetchone()
        return _from_row(Project, row) if row is not None else None

    def list_projects(self) -> List[Project]:
        """Return all projects, most recently updated first."""
        with self._connection() as conn:
            rows = conn.execute("SELECT * FROM projects ORDER BY updated_at DESC").fetchall()
        return [_from_row(Project, r) for r in rows]

    # Entity operations
    def create_entity(self, entity: Entity) -> Entity:
        """Insert *entity*; return it with created_at/updated_at set to the insert time."""
        now = datetime.now().isoformat()
        with self._connection() as conn:
            conn.execute(
                """INSERT INTO entities (id, project_id, name, canonical_name, type, definition, aliases, created_at, updated_at)
                   VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)""",
                (entity.id, entity.project_id, entity.name, entity.canonical_name, entity.type,
                 entity.definition, json.dumps(entity.aliases or []), now, now),
            )
        entity.created_at = entity.updated_at = now
        return entity

    def get_entity_by_name(self, project_id: str, name: str) -> Optional[Entity]:
        """Find an entity by name, canonical name or alias (used for alignment).

        Exact name/canonical-name matches are preferred over alias matches.
        Returns None when nothing matches.
        """
        # Aliases may be stored ASCII-escaped (json.dumps default) or raw UTF-8;
        # search for both encodings of the quoted name, wildcards escaped.
        ascii_pattern = "%" + _escape_like(json.dumps(name)) + "%"
        raw_pattern = "%" + _escape_like(json.dumps(name, ensure_ascii=False)) + "%"
        with self._connection() as conn:
            row = conn.execute(
                "SELECT * FROM entities WHERE project_id = ? AND (name = ? OR canonical_name = ?)",
                (project_id, name, name),
            ).fetchone()
            if row is not None:
                return _row_to_entity(row)
            candidates = conn.execute(
                "SELECT * FROM entities WHERE project_id = ? "
                "AND (aliases LIKE ? ESCAPE '\\' OR aliases LIKE ? ESCAPE '\\')",
                (project_id, ascii_pattern, raw_pattern),
            ).fetchall()
        # LIKE is case-insensitive for ASCII, so confirm an exact alias match.
        for candidate in candidates:
            entity = _row_to_entity(candidate)
            if name in entity.aliases:
                return entity
        return None

    def find_similar_entities(self, project_id: str, name: str, threshold: float = 0.8) -> List[Entity]:
        """Find entities whose name contains *name* as a substring.

        Simple implementation; production could use embeddings. ``threshold``
        is accepted for future use and is currently ignored. Matching follows
        SQLite LIKE semantics (case-insensitive for ASCII letters).
        """
        # TODO: use embeddings or fuzzy matching and honour `threshold`.
        with self._connection() as conn:
            rows = conn.execute(
                "SELECT * FROM entities WHERE project_id = ? AND name LIKE ? ESCAPE '\\'",
                (project_id, f"%{_escape_like(name)}%"),
            ).fetchall()
        return [_row_to_entity(r) for r in rows]

    def merge_entities(self, target_id: str, source_id: str) -> Entity:
        """Merge entity *source_id* into *target_id* (entity alignment).

        The source's name and aliases are appended to the target's aliases,
        its mentions and relations are re-pointed to the target, and the
        source row is deleted, all in a single transaction.

        Raises:
            ValueError: if either entity does not exist or the ids are equal.
        """
        if target_id == source_id:
            # Otherwise the final DELETE would remove the entity being kept.
            raise ValueError("Cannot merge an entity into itself")
        with self._connection() as conn:
            target = conn.execute("SELECT * FROM entities WHERE id = ?", (target_id,)).fetchone()
            source = conn.execute("SELECT * FROM entities WHERE id = ?", (source_id,)).fetchone()
            if target is None or source is None:
                raise ValueError("Entity not found")

            # Deduplicate while keeping a deterministic order.
            merged_aliases = list(dict.fromkeys(
                [*_load_aliases(target["aliases"]), source["name"], *_load_aliases(source["aliases"])]
            ))
            conn.execute(
                "UPDATE entities SET aliases = ?, updated_at = ? WHERE id = ?",
                (json.dumps(merged_aliases), datetime.now().isoformat(), target_id),
            )
            conn.execute(
                "UPDATE entity_mentions SET entity_id = ? WHERE entity_id = ?",
                (target_id, source_id),
            )
            # Re-point relations as well, otherwise they dangle after the delete.
            conn.execute(
                "UPDATE entity_relations SET source_entity_id = ? WHERE source_entity_id = ?",
                (target_id, source_id),
            )
            conn.execute(
                "UPDATE entity_relations SET target_entity_id = ? WHERE target_entity_id = ?",
                (target_id, source_id),
            )
            conn.execute("DELETE FROM entities WHERE id = ?", (source_id,))
        return self.get_entity(target_id)

    def get_entity(self, entity_id: str) -> Optional[Entity]:
        """Return the entity with *entity_id*, or None if it does not exist."""
        with self._connection() as conn:
            row = conn.execute("SELECT * FROM entities WHERE id = ?", (entity_id,)).fetchone()
        return _row_to_entity(row) if row is not None else None

    def list_project_entities(self, project_id: str) -> List[Entity]:
        """Return all entities of a project, most recently updated first."""
        with self._connection() as conn:
            rows = conn.execute(
                "SELECT * FROM entities WHERE project_id = ? ORDER BY updated_at DESC",
                (project_id,),
            ).fetchall()
        return [_row_to_entity(r) for r in rows]

    # Mention operations
    def add_mention(self, mention: EntityMention) -> EntityMention:
        """Insert an entity mention and return it unchanged."""
        with self._connection() as conn:
            conn.execute(
                """INSERT INTO entity_mentions (id, entity_id, transcript_id, start_pos, end_pos, text_snippet, confidence)
                   VALUES (?, ?, ?, ?, ?, ?, ?)""",
                (mention.id, mention.entity_id, mention.transcript_id, mention.start_pos,
                 mention.end_pos, mention.text_snippet, mention.confidence),
            )
        return mention

    def get_entity_mentions(self, entity_id: str) -> List[EntityMention]:
        """Return the mentions of an entity ordered by transcript, then position."""
        with self._connection() as conn:
            rows = conn.execute(
                "SELECT * FROM entity_mentions WHERE entity_id = ? ORDER BY transcript_id, start_pos",
                (entity_id,),
            ).fetchall()
        return [_from_row(EntityMention, r) for r in rows]

    # Transcript operations
    def save_transcript(self, transcript_id: str, project_id: str, filename: str, full_text: str):
        """Save a transcript record with created_at set to now."""
        now = datetime.now().isoformat()
        with self._connection() as conn:
            conn.execute(
                "INSERT INTO transcripts (id, project_id, filename, full_text, created_at) VALUES (?, ?, ?, ?, ?)",
                (transcript_id, project_id, filename, full_text, now),
            )

    def get_transcript(self, transcript_id: str) -> Optional[dict]:
        """Return a transcript record as a dict, or None if it does not exist."""
        with self._connection() as conn:
            row = conn.execute("SELECT * FROM transcripts WHERE id = ?", (transcript_id,)).fetchone()
        return dict(row) if row is not None else None

    def list_project_transcripts(self, project_id: str) -> List[dict]:
        """Return all transcripts of a project, newest first."""
        with self._connection() as conn:
            rows = conn.execute(
                "SELECT * FROM transcripts WHERE project_id = ? ORDER BY created_at DESC",
                (project_id,),
            ).fetchall()
        return [dict(r) for r in rows]

    # Relation operations
    def create_relation(self, project_id: str, source_entity_id: str, target_entity_id: str,
                        relation_type: str = "related", evidence: str = "",
                        transcript_id: str = "") -> str:
        """Create a relation between two entities and return its id.

        The id is the first 8 hex characters of a random UUID4.
        """
        relation_id = str(uuid.uuid4())[:8]
        now = datetime.now().isoformat()
        with self._connection() as conn:
            conn.execute(
                """INSERT INTO entity_relations
                   (id, project_id, source_entity_id, target_entity_id, relation_type, evidence, transcript_id, created_at)
                   VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
                (relation_id, project_id, source_entity_id, target_entity_id,
                 relation_type, evidence, transcript_id, now),
            )
        return relation_id

    def get_entity_relations(self, entity_id: str) -> List[dict]:
        """Return every relation in which the entity is source or target, newest first."""
        with self._connection() as conn:
            rows = conn.execute(
                """SELECT * FROM entity_relations
                   WHERE source_entity_id = ? OR target_entity_id = ?
                   ORDER BY created_at DESC""",
                (entity_id, entity_id),
            ).fetchall()
        return [dict(r) for r in rows]

    def list_project_relations(self, project_id: str) -> List[dict]:
        """Return all relations of a project, newest first."""
        with self._connection() as conn:
            rows = conn.execute(
                "SELECT * FROM entity_relations WHERE project_id = ? ORDER BY created_at DESC",
                (project_id,),
            ).fetchall()
        return [dict(r) for r in rows]


# Singleton instance, created lazily by get_db_manager().
_db_manager = None


def get_db_manager() -> DatabaseManager:
    """Return the process-wide DatabaseManager, creating it on first use."""
    global _db_manager
    if _db_manager is None:
        _db_manager = DatabaseManager()
    return _db_manager