Files
insightflow/backend/db_manager.py

467 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
InsightFlow Database Manager - Phase 3
处理项目、实体、关系的持久化
支持文档类型和多文件融合
"""
import json
import os
import sqlite3
import uuid
from contextlib import contextmanager
from dataclasses import dataclass, fields
from datetime import datetime
from typing import Dict, List, Optional, Tuple
# Location of the SQLite database; the DB_PATH environment variable overrides
# the container default.
DB_PATH = os.environ.get("DB_PATH", "/app/data/insightflow.db")
@dataclass()
class Project:
    """Persistent project record.

    Only ``id`` and ``name`` are required; the description and both
    timestamp strings default to empty strings.
    """

    id: str  # primary key of the projects table
    name: str
    description: str = ""
    created_at: str = ""  # timestamp string
    updated_at: str = ""  # timestamp string
@dataclass
class Entity:
    """A named entity belonging to a project.

    ``aliases`` lists alternative surface forms of the entity. It may be
    omitted or passed as ``None``; either way every instance ends up with
    its own (possibly empty) list.
    """

    id: str
    project_id: str
    name: str
    type: str
    definition: str = ""
    canonical_name: str = ""
    # Optional because None is the default and is accepted explicitly; a
    # literal [] default is rejected by dataclasses as a mutable default.
    aliases: Optional[List[str]] = None

    def __post_init__(self):
        # Replace a missing alias list with a fresh per-instance list.
        if self.aliases is None:
            self.aliases = []
@dataclass()
class EntityMention:
    """One occurrence of an entity inside a transcript.

    ``start_pos``/``end_pos`` locate ``text_snippet`` within the transcript
    (presumably character offsets -- confirm with the extraction code).
    ``confidence`` defaults to 1.0.
    """

    id: str
    entity_id: str  # id of the entity this mention refers to
    transcript_id: str  # id of the transcript the mention was found in
    start_pos: int
    end_pos: int
    text_snippet: str
    confidence: float = 1.0
class DatabaseManager:
    """SQLite persistence layer for projects, entities, mentions, relations,
    transcripts and glossary terms.

    Every public method opens its own short-lived connection through
    ``_connection()``. The connection is committed when the ``with`` body
    finishes normally and is always closed; because ``sqlite3`` does not
    commit on ``close()``, a method that fails part-way leaves no partial
    writes behind.
    """

    def __init__(self, db_path: str = DB_PATH):
        """Store *db_path*, create its parent directory and apply the schema.

        Args:
            db_path: Location of the SQLite database file (default ``DB_PATH``).
        """
        self.db_path = db_path
        parent_dir = os.path.dirname(db_path)
        # A bare filename has no directory part, and os.makedirs("") raises
        # FileNotFoundError, so only create the directory when there is one.
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        self.init_db()

    def get_conn(self) -> sqlite3.Connection:
        """Open a new connection whose rows can be indexed by column name.

        The caller owns the returned connection and must close it.
        """
        conn = sqlite3.connect(self.db_path)
        conn.row_factory = sqlite3.Row
        return conn

    @contextmanager
    def _connection(self):
        """Yield a fresh connection; commit on normal exit, always close it."""
        conn = self.get_conn()
        try:
            yield conn
            # Only reached when the with-body completed without raising.
            conn.commit()
        finally:
            # close() does not commit, so on error pending writes are discarded.
            conn.close()

    @staticmethod
    def _row_kwargs(model, row: sqlite3.Row) -> dict:
        """Return the columns of *row* that correspond to fields of dataclass *model*.

        Tables may carry columns the dataclasses do not declare (for example
        ``entities.created_at``/``updated_at``, written by create_entity);
        forwarding those with ``**`` would raise TypeError.
        """
        wanted = {f.name for f in fields(model)}
        return {key: row[key] for key in row.keys() if key in wanted}

    @staticmethod
    def _load_aliases(raw: Optional[str]) -> List[str]:
        """Decode a stored aliases JSON value; empty, NULL or JSON null gives []."""
        return (json.loads(raw) if raw else None) or []

    @classmethod
    def _row_to_entity(cls, row: sqlite3.Row) -> Entity:
        """Build an Entity from an ``entities`` row, decoding the aliases JSON."""
        data = cls._row_kwargs(Entity, row)
        data['aliases'] = cls._load_aliases(data.get('aliases'))
        return Entity(**data)

    @staticmethod
    def _escape_like(text: str) -> str:
        """Escape LIKE metacharacters (backslash, %, _) so *text* matches literally.

        The SQL using the result must declare a backslash ESCAPE character.
        """
        return text.replace('\\', '\\\\').replace('%', '\\%').replace('_', '\\_')

    def init_db(self):
        """Initialise the database by executing ``schema.sql``.

        The schema file is read from the directory containing this module.
        """
        schema_path = os.path.join(os.path.dirname(__file__), 'schema.sql')
        # Explicit encoding so reading the schema does not depend on the locale.
        with open(schema_path, 'r', encoding='utf-8') as f:
            schema = f.read()
        with self._connection() as conn:
            conn.executescript(schema)

    # Project operations

    def create_project(self, project_id: str, name: str, description: str = "") -> Project:
        """Insert a project and return it; both timestamps are set to the current time."""
        now = datetime.now().isoformat()
        with self._connection() as conn:
            conn.execute(
                "INSERT INTO projects (id, name, description, created_at, updated_at) VALUES (?, ?, ?, ?, ?)",
                (project_id, name, description, now, now),
            )
        return Project(id=project_id, name=name, description=description, created_at=now, updated_at=now)

    def get_project(self, project_id: str) -> Optional[Project]:
        """Return the project with *project_id*, or None if there is none."""
        with self._connection() as conn:
            row = conn.execute("SELECT * FROM projects WHERE id = ?", (project_id,)).fetchone()
        return Project(**self._row_kwargs(Project, row)) if row else None

    def list_projects(self) -> List[Project]:
        """Return all projects ordered by ``updated_at``, newest first."""
        with self._connection() as conn:
            rows = conn.execute("SELECT * FROM projects ORDER BY updated_at DESC").fetchall()
        return [Project(**self._row_kwargs(Project, r)) for r in rows]

    # Entity operations

    def create_entity(self, entity: Entity) -> Entity:
        """Insert *entity* (aliases stored as a JSON array) and return it unchanged.

        ``created_at`` and ``updated_at`` receive the same timestamp.
        """
        now = datetime.now().isoformat()
        with self._connection() as conn:
            conn.execute(
                """INSERT INTO entities (id, project_id, name, canonical_name, type, definition, aliases, created_at, updated_at)
                   VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)""",
                (entity.id, entity.project_id, entity.name, entity.canonical_name, entity.type,
                 entity.definition, json.dumps(entity.aliases), now, now),
            )
        return entity

    def get_entity_by_name(self, project_id: str, name: str) -> Optional[Entity]:
        """Find an entity in *project_id* by name, canonical name or alias (for alignment).

        Aliases are written with ``json.dumps`` and its default
        ``ensure_ascii=True``, so non-ASCII characters are stored as escape
        sequences. The alias pattern therefore searches for ``json.dumps(name)``
        (matching that stored form) as well as the unescaped form. LIKE
        wildcards in *name* are escaped. SQLite's LIKE is case-insensitive for
        ASCII letters by default, so the alias match is too, whereas the
        ``name``/``canonical_name`` comparisons are exact.

        Returns:
            The first matching entity (unspecified which if several match), or None.
        """
        ascii_pattern = f"%{self._escape_like(json.dumps(name))}%"
        unicode_pattern = f"%{self._escape_like(json.dumps(name, ensure_ascii=False))}%"
        with self._connection() as conn:
            row = conn.execute(
                """SELECT * FROM entities
                   WHERE project_id = ?
                     AND (name = ? OR canonical_name = ?
                          OR aliases LIKE ? ESCAPE '\\'
                          OR aliases LIKE ? ESCAPE '\\')""",
                (project_id, name, name, ascii_pattern, unicode_pattern),
            ).fetchone()
        return self._row_to_entity(row) if row else None

    def find_similar_entities(self, project_id: str, name: str, threshold: float = 0.8) -> List[Entity]:
        """Return entities in *project_id* whose name contains *name* as a substring.

        Simple placeholder: *threshold* is currently ignored. LIKE wildcards
        in *name* are escaped so they match literally.
        """
        # TODO: use embeddings or fuzzy matching; threshold is reserved for that.
        with self._connection() as conn:
            rows = conn.execute(
                "SELECT * FROM entities WHERE project_id = ? AND name LIKE ? ESCAPE '\\'",
                (project_id, f"%{self._escape_like(name)}%"),
            ).fetchall()
        return [self._row_to_entity(r) for r in rows]

    def merge_entities(self, target_id: str, source_id: str) -> Entity:
        """Merge entity *source_id* into *target_id* (entity alignment).

        The source's name and aliases are appended to the target's aliases
        (duplicates dropped, existing order kept). Mentions and relations
        that reference the source are re-pointed at the target, and the
        source row is deleted. All writes are committed together.

        Raises:
            ValueError: if either entity does not exist, or if both ids are
                equal (merging an entity into itself would delete it).
        """
        if target_id == source_id:
            raise ValueError("Cannot merge an entity into itself")
        with self._connection() as conn:
            target = conn.execute("SELECT * FROM entities WHERE id = ?", (target_id,)).fetchone()
            source = conn.execute("SELECT * FROM entities WHERE id = ?", (source_id,)).fetchone()
            if not target or not source:
                raise ValueError("Entity not found")

            # dict.fromkeys de-duplicates while preserving first-seen order.
            merged_aliases = list(dict.fromkeys([
                *self._load_aliases(target['aliases']),
                source['name'],
                *self._load_aliases(source['aliases']),
            ]))
            conn.execute(
                "UPDATE entities SET aliases = ?, updated_at = ? WHERE id = ?",
                (json.dumps(merged_aliases), datetime.now().isoformat(), target_id),
            )
            # Re-point mentions, then relations where the source is either endpoint.
            conn.execute(
                "UPDATE entity_mentions SET entity_id = ? WHERE entity_id = ?",
                (target_id, source_id),
            )
            conn.execute(
                "UPDATE entity_relations SET source_entity_id = ? WHERE source_entity_id = ?",
                (target_id, source_id),
            )
            conn.execute(
                "UPDATE entity_relations SET target_entity_id = ? WHERE target_entity_id = ?",
                (target_id, source_id),
            )
            conn.execute("DELETE FROM entities WHERE id = ?", (source_id,))
        return self.get_entity(target_id)

    def get_entity(self, entity_id: str) -> Optional[Entity]:
        """Return the entity with *entity_id*, or None if there is none."""
        with self._connection() as conn:
            row = conn.execute("SELECT * FROM entities WHERE id = ?", (entity_id,)).fetchone()
        return self._row_to_entity(row) if row else None

    def list_project_entities(self, project_id: str) -> List[Entity]:
        """Return all entities of *project_id* ordered by ``updated_at``, newest first."""
        with self._connection() as conn:
            rows = conn.execute(
                "SELECT * FROM entities WHERE project_id = ? ORDER BY updated_at DESC",
                (project_id,),
            ).fetchall()
        return [self._row_to_entity(r) for r in rows]

    # Mention operations

    def add_mention(self, mention: EntityMention) -> EntityMention:
        """Insert *mention* and return it unchanged."""
        with self._connection() as conn:
            conn.execute(
                """INSERT INTO entity_mentions (id, entity_id, transcript_id, start_pos, end_pos, text_snippet, confidence)
                   VALUES (?, ?, ?, ?, ?, ?, ?)""",
                (mention.id, mention.entity_id, mention.transcript_id, mention.start_pos,
                 mention.end_pos, mention.text_snippet, mention.confidence),
            )
        return mention

    def get_entity_mentions(self, entity_id: str) -> List[EntityMention]:
        """Return mentions of *entity_id* ordered by transcript id, then start position."""
        with self._connection() as conn:
            rows = conn.execute(
                "SELECT * FROM entity_mentions WHERE entity_id = ? ORDER BY transcript_id, start_pos",
                (entity_id,),
            ).fetchall()
        return [EntityMention(**self._row_kwargs(EntityMention, r)) for r in rows]

    # Transcript operations

    def save_transcript(self, transcript_id: str, project_id: str, filename: str, full_text: str, transcript_type: str = "audio"):
        """Insert a transcript record; *transcript_type* is stored in the ``type`` column."""
        now = datetime.now().isoformat()
        with self._connection() as conn:
            conn.execute(
                "INSERT INTO transcripts (id, project_id, filename, full_text, type, created_at) VALUES (?, ?, ?, ?, ?, ?)",
                (transcript_id, project_id, filename, full_text, transcript_type, now),
            )

    def get_transcript(self, transcript_id: str) -> Optional[dict]:
        """Return the transcript row as a dict, or None if there is none."""
        with self._connection() as conn:
            row = conn.execute("SELECT * FROM transcripts WHERE id = ?", (transcript_id,)).fetchone()
        return dict(row) if row else None

    def list_project_transcripts(self, project_id: str) -> List[dict]:
        """Return all transcripts of *project_id* as dicts, newest first."""
        with self._connection() as conn:
            rows = conn.execute(
                "SELECT * FROM transcripts WHERE project_id = ? ORDER BY created_at DESC",
                (project_id,),
            ).fetchall()
        return [dict(r) for r in rows]

    # Relation operations

    def create_relation(self, project_id: str, source_entity_id: str, target_entity_id: str,
                        relation_type: str = "related", evidence: str = "", transcript_id: str = "") -> str:
        """Insert a relation between two entities and return its generated id.

        The id is the first 8 characters of a random UUID4 string.
        """
        relation_id = str(uuid.uuid4())[:8]
        now = datetime.now().isoformat()
        with self._connection() as conn:
            conn.execute(
                """INSERT INTO entity_relations
                   (id, project_id, source_entity_id, target_entity_id, relation_type, evidence, transcript_id, created_at)
                   VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
                (relation_id, project_id, source_entity_id, target_entity_id, relation_type, evidence, transcript_id, now),
            )
        return relation_id

    def get_entity_relations(self, entity_id: str) -> List[dict]:
        """Return relations where *entity_id* is either endpoint, newest first."""
        with self._connection() as conn:
            rows = conn.execute(
                """SELECT * FROM entity_relations
                   WHERE source_entity_id = ? OR target_entity_id = ?
                   ORDER BY created_at DESC""",
                (entity_id, entity_id),
            ).fetchall()
        return [dict(r) for r in rows]

    def list_project_relations(self, project_id: str) -> List[dict]:
        """Return all relations of *project_id* as dicts, newest first."""
        with self._connection() as conn:
            rows = conn.execute(
                "SELECT * FROM entity_relations WHERE project_id = ? ORDER BY created_at DESC",
                (project_id,),
            ).fetchall()
        return [dict(r) for r in rows]

    def update_entity(self, entity_id: str, **kwargs) -> Optional[Entity]:
        """Update whitelisted columns of an entity and return the refreshed entity.

        Recognised keywords: ``name``, ``type``, ``definition``,
        ``canonical_name`` and ``aliases`` (stored as JSON). Other keywords
        are ignored. ``updated_at`` is refreshed when at least one column is
        updated. Returns None if the entity does not exist.
        """
        columns = ['name', 'type', 'definition', 'canonical_name']
        assignments = [f"{col} = ?" for col in columns if col in kwargs]
        params = [kwargs[col] for col in columns if col in kwargs]
        if 'aliases' in kwargs:
            assignments.append("aliases = ?")
            params.append(json.dumps(kwargs['aliases']))
        if not assignments:
            return self.get_entity(entity_id)

        assignments.append("updated_at = ?")
        params.extend([datetime.now().isoformat(), entity_id])
        # Column names come only from the fixed whitelist above, never from
        # caller input, so interpolating them into the SQL text is safe.
        with self._connection() as conn:
            conn.execute(f"UPDATE entities SET {', '.join(assignments)} WHERE id = ?", params)
        return self.get_entity(entity_id)

    def delete_entity(self, entity_id: str):
        """Delete an entity together with its mentions and any relations touching it."""
        with self._connection() as conn:
            conn.execute("DELETE FROM entity_mentions WHERE entity_id = ?", (entity_id,))
            conn.execute(
                "DELETE FROM entity_relations WHERE source_entity_id = ? OR target_entity_id = ?",
                (entity_id, entity_id),
            )
            conn.execute("DELETE FROM entities WHERE id = ?", (entity_id,))

    def delete_relation(self, relation_id: str):
        """Delete the relation with *relation_id* (no-op if absent)."""
        with self._connection() as conn:
            conn.execute("DELETE FROM entity_relations WHERE id = ?", (relation_id,))

    def update_relation(self, relation_id: str, **kwargs) -> Optional[dict]:
        """Update ``relation_type`` and/or ``evidence`` of a relation.

        Other keywords are ignored. Returns the (possibly unchanged) relation
        row as a dict, or None if it does not exist.
        """
        columns = ['relation_type', 'evidence']
        present = [col for col in columns if col in kwargs]
        with self._connection() as conn:
            if present:
                # Column names come from the fixed whitelist, not the caller.
                assignments = ', '.join(f"{col} = ?" for col in present)
                conn.execute(
                    f"UPDATE entity_relations SET {assignments} WHERE id = ?",
                    [*(kwargs[col] for col in present), relation_id],
                )
            row = conn.execute("SELECT * FROM entity_relations WHERE id = ?", (relation_id,)).fetchone()
        return dict(row) if row else None

    def update_transcript(self, transcript_id: str, full_text: str) -> Optional[dict]:
        """Replace a transcript's text and return the updated row as a dict (None if absent)."""
        now = datetime.now().isoformat()
        with self._connection() as conn:
            # NOTE(review): save_transcript never sets updated_at; this assumes
            # schema.sql declares that column on transcripts -- confirm.
            conn.execute(
                "UPDATE transcripts SET full_text = ?, updated_at = ? WHERE id = ?",
                (full_text, now, transcript_id),
            )
            row = conn.execute("SELECT * FROM transcripts WHERE id = ?", (transcript_id,)).fetchone()
        return dict(row) if row else None

    # Phase 3: Glossary operations

    def add_glossary_term(self, project_id: str, term: str, pronunciation: str = "") -> str:
        """Add *term* to the project glossary and return its id.

        If the term already exists in the project its frequency is incremented
        instead (the stored pronunciation is left unchanged).
        """
        with self._connection() as conn:
            existing = conn.execute(
                "SELECT id FROM glossary WHERE project_id = ? AND term = ?",
                (project_id, term),
            ).fetchone()
            if existing:
                conn.execute(
                    "UPDATE glossary SET frequency = frequency + 1 WHERE id = ?",
                    (existing['id'],),
                )
                return existing['id']

            term_id = str(uuid.uuid4())[:8]
            conn.execute(
                "INSERT INTO glossary (id, project_id, term, pronunciation, frequency) VALUES (?, ?, ?, ?, ?)",
                (term_id, project_id, term, pronunciation, 1),
            )
        return term_id

    def list_glossary(self, project_id: str) -> List[dict]:
        """Return the project's glossary terms as dicts, most frequent first."""
        with self._connection() as conn:
            rows = conn.execute(
                "SELECT * FROM glossary WHERE project_id = ? ORDER BY frequency DESC",
                (project_id,),
            ).fetchall()
        return [dict(r) for r in rows]

    def delete_glossary_term(self, term_id: str):
        """Delete the glossary term with *term_id* (no-op if absent)."""
        with self._connection() as conn:
            conn.execute("DELETE FROM glossary WHERE id = ?", (term_id,))

    # Phase 3: Get all entities for embedding

    def get_all_entities_for_embedding(self, project_id: str) -> List[Entity]:
        """Return every entity of *project_id* for embedding computation."""
        return self.list_project_entities(project_id)
# Process-wide DatabaseManager, created lazily by get_db_manager().
_db_manager = None


def get_db_manager() -> DatabaseManager:
    """Return the shared DatabaseManager, constructing it on the first call.

    The instance is built with the default database path. No lock is taken,
    so concurrent first calls could each construct an instance.
    """
    global _db_manager
    if _db_manager is not None:
        return _db_manager
    _db_manager = DatabaseManager()
    return _db_manager