feat: Phase 3 knowledge growth - multi-file fusion + entity alignment
This commit is contained in:
231
backend/db_manager.py
Normal file
231
backend/db_manager.py
Normal file
@@ -0,0 +1,231 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
InsightFlow Database Manager
|
||||
处理项目、实体、关系的持久化
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import sqlite3
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Optional, Tuple
|
||||
from dataclasses import dataclass
|
||||
|
||||
DB_PATH = os.getenv("DB_PATH", "/app/data/insightflow.db")
|
||||
|
||||
@dataclass
|
||||
class Project:
|
||||
id: str
|
||||
name: str
|
||||
description: str = ""
|
||||
created_at: str = ""
|
||||
updated_at: str = ""
|
||||
|
||||
@dataclass
|
||||
class Entity:
|
||||
id: str
|
||||
project_id: str
|
||||
name: str
|
||||
type: str
|
||||
definition: str = ""
|
||||
canonical_name: str = ""
|
||||
aliases: List[str] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.aliases is None:
|
||||
self.aliases = []
|
||||
|
||||
@dataclass
|
||||
class EntityMention:
|
||||
id: str
|
||||
entity_id: str
|
||||
transcript_id: str
|
||||
start_pos: int
|
||||
end_pos: int
|
||||
text_snippet: str
|
||||
confidence: float = 1.0
|
||||
|
||||
class DatabaseManager:
|
||||
def __init__(self, db_path: str = DB_PATH):
|
||||
self.db_path = db_path
|
||||
os.makedirs(os.path.dirname(db_path), exist_ok=True)
|
||||
self.init_db()
|
||||
|
||||
def get_conn(self):
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
conn.row_factory = sqlite3.Row
|
||||
return conn
|
||||
|
||||
def init_db(self):
|
||||
"""初始化数据库表"""
|
||||
with open(os.path.join(os.path.dirname(__file__), 'schema.sql'), 'r') as f:
|
||||
schema = f.read()
|
||||
|
||||
conn = self.get_conn()
|
||||
conn.executescript(schema)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
# Project operations
|
||||
def create_project(self, project_id: str, name: str, description: str = "") -> Project:
|
||||
conn = self.get_conn()
|
||||
now = datetime.now().isoformat()
|
||||
conn.execute(
|
||||
"INSERT INTO projects (id, name, description, created_at, updated_at) VALUES (?, ?, ?, ?, ?)",
|
||||
(project_id, name, description, now, now)
|
||||
)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
return Project(id=project_id, name=name, description=description, created_at=now, updated_at=now)
|
||||
|
||||
def get_project(self, project_id: str) -> Optional[Project]:
|
||||
conn = self.get_conn()
|
||||
row = conn.execute("SELECT * FROM projects WHERE id = ?", (project_id,)).fetchone()
|
||||
conn.close()
|
||||
if row:
|
||||
return Project(**dict(row))
|
||||
return None
|
||||
|
||||
def list_projects(self) -> List[Project]:
|
||||
conn = self.get_conn()
|
||||
rows = conn.execute("SELECT * FROM projects ORDER BY updated_at DESC").fetchall()
|
||||
conn.close()
|
||||
return [Project(**dict(r)) for r in rows]
|
||||
|
||||
# Entity operations
|
||||
def create_entity(self, entity: Entity) -> Entity:
|
||||
conn = self.get_conn()
|
||||
conn.execute(
|
||||
"""INSERT INTO entities (id, project_id, name, canonical_name, type, definition, aliases, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)""",
|
||||
(entity.id, entity.project_id, entity.name, entity.canonical_name, entity.type,
|
||||
entity.definition, json.dumps(entity.aliases), datetime.now().isoformat(), datetime.now().isoformat())
|
||||
)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
return entity
|
||||
|
||||
def get_entity_by_name(self, project_id: str, name: str) -> Optional[Entity]:
|
||||
"""通过名称查找实体(用于对齐)"""
|
||||
conn = self.get_conn()
|
||||
row = conn.execute(
|
||||
"SELECT * FROM entities WHERE project_id = ? AND (name = ? OR canonical_name = ? OR aliases LIKE ?)",
|
||||
(project_id, name, name, f'%"{name}"%')
|
||||
).fetchone()
|
||||
conn.close()
|
||||
if row:
|
||||
data = dict(row)
|
||||
data['aliases'] = json.loads(data['aliases']) if data['aliases'] else []
|
||||
return Entity(**data)
|
||||
return None
|
||||
|
||||
def find_similar_entities(self, project_id: str, name: str, threshold: float = 0.8) -> List[Entity]:
|
||||
"""查找相似实体(简单实现,生产可用 embedding)"""
|
||||
# TODO: 使用 embedding 或模糊匹配
|
||||
# 现在简单返回包含相同关键词的实体
|
||||
conn = self.get_conn()
|
||||
rows = conn.execute(
|
||||
"SELECT * FROM entities WHERE project_id = ? AND name LIKE ?",
|
||||
(project_id, f"%{name}%")
|
||||
).fetchall()
|
||||
conn.close()
|
||||
|
||||
entities = []
|
||||
for row in rows:
|
||||
data = dict(row)
|
||||
data['aliases'] = json.loads(data['aliases']) if data['aliases'] else []
|
||||
entities.append(Entity(**data))
|
||||
return entities
|
||||
|
||||
def merge_entities(self, target_id: str, source_id: str) -> Entity:
|
||||
"""合并两个实体(实体对齐)"""
|
||||
conn = self.get_conn()
|
||||
|
||||
# 获取两个实体
|
||||
target = conn.execute("SELECT * FROM entities WHERE id = ?", (target_id,)).fetchone()
|
||||
source = conn.execute("SELECT * FROM entities WHERE id = ?", (source_id,)).fetchone()
|
||||
|
||||
if not target or not source:
|
||||
conn.close()
|
||||
raise ValueError("Entity not found")
|
||||
|
||||
# 合并别名
|
||||
target_aliases = set(json.loads(target['aliases']) if target['aliases'] else [])
|
||||
target_aliases.add(source['name'])
|
||||
target_aliases.update(json.loads(source['aliases']) if source['aliases'] else [])
|
||||
|
||||
# 更新目标实体
|
||||
conn.execute(
|
||||
"UPDATE entities SET aliases = ?, updated_at = ? WHERE id = ?",
|
||||
(json.dumps(list(target_aliases)), datetime.now().isoformat(), target_id)
|
||||
)
|
||||
|
||||
# 更新提及记录
|
||||
conn.execute(
|
||||
"UPDATE entity_mentions SET entity_id = ? WHERE entity_id = ?",
|
||||
(target_id, source_id)
|
||||
)
|
||||
|
||||
# 删除源实体
|
||||
conn.execute("DELETE FROM entities WHERE id = ?", (source_id,))
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
return self.get_entity(target_id)
|
||||
|
||||
def get_entity(self, entity_id: str) -> Optional[Entity]:
|
||||
conn = self.get_conn()
|
||||
row = conn.execute("SELECT * FROM entities WHERE id = ?", (entity_id,)).fetchone()
|
||||
conn.close()
|
||||
if row:
|
||||
data = dict(row)
|
||||
data['aliases'] = json.loads(data['aliases']) if data['aliases'] else []
|
||||
return Entity(**data)
|
||||
return None
|
||||
|
||||
def list_project_entities(self, project_id: str) -> List[Entity]:
|
||||
conn = self.get_conn()
|
||||
rows = conn.execute(
|
||||
"SELECT * FROM entities WHERE project_id = ? ORDER BY updated_at DESC",
|
||||
(project_id,)
|
||||
).fetchall()
|
||||
conn.close()
|
||||
|
||||
entities = []
|
||||
for row in rows:
|
||||
data = dict(row)
|
||||
data['aliases'] = json.loads(data['aliases']) if data['aliases'] else []
|
||||
entities.append(Entity(**data))
|
||||
return entities
|
||||
|
||||
# Mention operations
|
||||
def add_mention(self, mention: EntityMention) -> EntityMention:
|
||||
conn = self.get_conn()
|
||||
conn.execute(
|
||||
"""INSERT INTO entity_mentions (id, entity_id, transcript_id, start_pos, end_pos, text_snippet, confidence)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?)""",
|
||||
(mention.id, mention.entity_id, mention.transcript_id, mention.start_pos,
|
||||
mention.end_pos, mention.text_snippet, mention.confidence)
|
||||
)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
return mention
|
||||
|
||||
def get_entity_mentions(self, entity_id: str) -> List[EntityMention]:
|
||||
conn = self.get_conn()
|
||||
rows = conn.execute(
|
||||
"SELECT * FROM entity_mentions WHERE entity_id = ? ORDER BY transcript_id, start_pos",
|
||||
(entity_id,)
|
||||
).fetchall()
|
||||
conn.close()
|
||||
return [EntityMention(**dict(r)) for r in rows]
|
||||
|
||||
# Singleton instance
|
||||
_db_manager = None
|
||||
|
||||
def get_db_manager() -> DatabaseManager:
|
||||
global _db_manager
|
||||
if _db_manager is None:
|
||||
_db_manager = DatabaseManager()
|
||||
return _db_manager
|
||||
Reference in New Issue
Block a user