Files
insightflow/backend/search_manager.py
OpenClaw Bot 74c2daa5ef fix: auto-fix code issues (cron)
- 修复重复导入/字段
- 修复异常处理
- 修复PEP8格式问题
- 添加类型注解
2026-02-28 09:11:38 +08:00

2224 lines
73 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
InsightFlow - 高级搜索与发现模块
Phase 7 Task 6: Advanced Search & Discovery
功能模块:
1. FullTextSearch - 全文搜索(关键词高亮、布尔搜索)
2. SemanticSearch - 语义搜索(基于 embedding 的相似度搜索)
3. EntityPathDiscovery - 实体关系路径发现
4. KnowledgeGapDetection - 知识缺口识别
"""
import hashlib
import json
import math
import re
import sqlite3
from collections import defaultdict, deque
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
class SearchOperator(Enum):
    """Boolean operators supported by the full-text query parser."""
    AND = "AND"
    OR = "OR"
    NOT = "NOT"
# 尝试导入 sentence-transformers 用于语义搜索
try:
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
SENTENCE_TRANSFORMERS_AVAILABLE = True
except ImportError:
SENTENCE_TRANSFORMERS_AVAILABLE = False
# ==================== 数据模型 ====================
@dataclass
class SearchResult:
    """A single full-text search hit with its relevance score."""
    id: str
    content: str
    content_type: str  # transcript, entity, relation
    project_id: str
    score: float
    highlights: list[tuple[int, int]] = field(default_factory=list)  # (start, end) match spans
    metadata: dict = field(default_factory=dict)

    def to_dict(self) -> dict:
        """Serialize this result into a plain dict for JSON responses."""
        keys = (
            "id",
            "content",
            "content_type",
            "project_id",
            "score",
            "highlights",
            "metadata",
        )
        return {key: getattr(self, key) for key in keys}
@dataclass
class SemanticSearchResult:
"""语义搜索结果数据模型"""
id: str
content: str
content_type: str
project_id: str
similarity: float
embedding: list[float] | None = None
metadata: dict = field(default_factory=dict)
def to_dict(self) -> dict:
result = {
"id": self.id,
"content": self.content[:500] + "..." if len(self.content) > 500 else self.content,
"content_type": self.content_type,
"project_id": self.project_id,
"similarity": round(self.similarity, 4),
"metadata": self.metadata,
}
if self.embedding:
result["embedding_dim"] = len(self.embedding)
return result
@dataclass
class EntityPath:
    """A discovered chain of relations linking two entities."""
    path_id: str
    source_entity_id: str
    source_entity_name: str
    target_entity_id: str
    target_entity_name: str
    path_length: int
    nodes: list[dict]  # node dicts along the path
    edges: list[dict]  # edge dicts along the path
    confidence: float
    path_description: str

    def to_dict(self) -> dict:
        """Serialize this path into a plain dict for JSON responses."""
        keys = (
            "path_id",
            "source_entity_id",
            "source_entity_name",
            "target_entity_id",
            "target_entity_name",
            "path_length",
            "nodes",
            "edges",
            "confidence",
            "path_description",
        )
        return {key: getattr(self, key) for key in keys}
@dataclass
class KnowledgeGap:
"""知识缺口数据模型"""
gap_id: str
gap_type: str # missing_attribute, sparse_relation, isolated_entity, incomplete_entity
entity_id: str | None
entity_name: str | None
description: str
severity: str # high, medium, low
suggestions: list[str]
related_entities: list[str]
metadata: dict = field(default_factory=dict)
def to_dict(self) -> dict:
return {
"gap_id": self.gap_id,
"gap_type": self.gap_type,
"entity_id": self.entity_id,
"entity_name": self.entity_name,
"description": self.description,
"severity": self.severity,
"suggestions": self.suggestions,
"related_entities": self.related_entities,
"metadata": self.metadata,
}
@dataclass
class SearchIndex:
    """Row model for the `search_indexes` table (per-document token index)."""
    id: str  # md5-derived index row id
    content_id: str  # indexed content id
    content_type: str  # transcript / entity / relation
    project_id: str  # owning project
    tokens: list[str]  # token list produced by the tokenizer
    token_positions: dict[str, list[int]]  # token -> list of character offsets
    created_at: str
    updated_at: str
@dataclass
class TextEmbedding:
    """Row model for the `embeddings` table (one vector per content item)."""
    id: str  # md5-derived embedding row id
    content_id: str  # embedded content id
    content_type: str  # transcript / entity / relation
    project_id: str  # owning project
    embedding: list[float]  # embedding vector (stored as JSON)
    model_name: str  # model that produced the vector
    created_at: str
# ==================== 全文搜索 ====================
class FullTextSearch:
"""
全文搜索模块
功能:
- 跨所有转录文本搜索
- 支持关键词高亮
- 搜索结果排序(相关性)
- 支持布尔搜索AND/OR/NOT
"""
def __init__(self, db_path: str = "insightflow.db"):
self.db_path = db_path
self._init_search_tables()
def _get_conn(self) -> sqlite3.Connection:
"""获取数据库连接"""
conn = sqlite3.connect(self.db_path)
conn.row_factory = sqlite3.Row
return conn
def _init_search_tables(self) -> None:
"""初始化搜索相关表"""
conn = self._get_conn()
# 搜索索引表
conn.execute("""
CREATE TABLE IF NOT EXISTS search_indexes (
id TEXT PRIMARY KEY,
content_id TEXT NOT NULL,
content_type TEXT NOT NULL,
project_id TEXT NOT NULL,
tokens TEXT, -- JSON 数组
token_positions TEXT, -- JSON 对象
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
UNIQUE(content_id, content_type)
)
""")
# 搜索词频统计表
conn.execute("""
CREATE TABLE IF NOT EXISTS search_term_freq (
term TEXT NOT NULL,
content_id TEXT NOT NULL,
content_type TEXT NOT NULL,
project_id TEXT NOT NULL,
frequency INTEGER DEFAULT 1,
positions TEXT, -- JSON 数组
PRIMARY KEY (term, content_id, content_type)
)
""")
# 创建索引
conn.execute(
"CREATE INDEX IF NOT EXISTS idx_search_content ON search_indexes(content_id, content_type)"
)
conn.execute("CREATE INDEX IF NOT EXISTS idx_search_project ON search_indexes(project_id)")
conn.execute("CREATE INDEX IF NOT EXISTS idx_term_freq_term ON search_term_freq(term)")
conn.execute(
"CREATE INDEX IF NOT EXISTS idx_term_freq_project ON search_term_freq(project_id)"
)
conn.commit()
conn.close()
def _tokenize(self, text: str) -> list[str]:
"""
中文分词(简化版)
实际生产环境可以使用 jieba 等分词工具
"""
# 清理文本
text = text.lower()
# 提取中文字符、英文单词和数字
tokens = re.findall(r"[\u4e00-\u9fa5]+|[a-z]+|\d+", text)
return tokens
def _extract_positions(self, text: str, tokens: list[str]) -> dict[str, list[int]]:
"""提取每个词在文本中的位置"""
positions = defaultdict(list)
text_lower = text.lower()
for token in tokens:
# 查找所有出现位置
start = 0
while True:
pos = text_lower.find(token, start)
if pos == -1:
break
positions[token].append(pos)
start = pos + 1
return dict(positions)
def index_content(self, content_id: str, content_type: str, project_id: str, text: str) -> bool:
"""
为内容创建搜索索引
Args:
content_id: 内容ID
content_type: 内容类型 (transcript, entity, relation)
project_id: 项目ID
text: 要索引的文本
Returns:
bool: 是否成功
"""
try:
conn = self._get_conn()
# 分词
tokens = self._tokenize(text)
if not tokens:
conn.close()
return False
# 提取位置信息
token_positions = self._extract_positions(text, tokens)
# 计算词频
token_freq = defaultdict(int)
for token in tokens:
token_freq[token] += 1
index_id = hashlib.md5(f"{content_id}:{content_type}".encode()).hexdigest()[:16]
now = datetime.now().isoformat()
# 保存索引
conn.execute(
"""
INSERT OR REPLACE INTO search_indexes
(id, content_id, content_type, project_id, tokens, token_positions, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
""",
(
index_id,
content_id,
content_type,
project_id,
json.dumps(tokens, ensure_ascii=False),
json.dumps(token_positions, ensure_ascii=False),
now,
now,
),
)
# 保存词频统计
for token, freq in token_freq.items():
positions = token_positions.get(token, [])
conn.execute(
"""
INSERT OR REPLACE INTO search_term_freq
(term, content_id, content_type, project_id, frequency, positions)
VALUES (?, ?, ?, ?, ?, ?)
""",
(
token,
content_id,
content_type,
project_id,
freq,
json.dumps(positions, ensure_ascii=False),
),
)
conn.commit()
conn.close()
return True
except Exception as e:
print(f"索引创建失败: {e}")
return False
def search(
self,
query: str,
project_id: str | None = None,
content_types: list[str] | None = None,
limit: int = 20,
offset: int = 0,
) -> list[SearchResult]:
"""
全文搜索
Args:
query: 搜索查询(支持布尔语法)
project_id: 可选的项目ID过滤
content_types: 可选的内容类型过滤
limit: 返回结果数量限制
offset: 分页偏移
Returns:
List[SearchResult]: 搜索结果列表
"""
# 解析布尔查询
parsed_query = self._parse_boolean_query(query)
# 执行搜索
results = self._execute_boolean_search(parsed_query, project_id, content_types)
# 计算相关性分数
scored_results = self._score_results(results, parsed_query)
# 排序和分页
scored_results.sort(key=lambda x: x.score, reverse=True)
return scored_results[offset : offset + limit]
def _parse_boolean_query(self, query: str) -> dict:
"""
解析布尔查询
支持语法:
- AND: 词1 AND 词2
- OR: 词1 OR 词2
- NOT: NOT 词1 或 词1 -词2
- 短语: "精确短语"
"""
query = query.strip()
# 提取短语(引号内的内容)
phrases = re.findall(r'"([^"]+)"', query)
query_without_phrases = re.sub(r'"[^"]+"', "", query)
# 解析布尔操作
and_terms = []
or_terms = []
not_terms = []
# 处理 NOT
not_pattern = r"(?:NOT\s+|\-)(\w+)"
not_matches = re.findall(not_pattern, query_without_phrases, re.IGNORECASE)
not_terms.extend(not_matches)
query_without_phrases = re.sub(not_pattern, "", query_without_phrases, flags=re.IGNORECASE)
# 处理 OR
or_parts = re.split(r"\s+OR\s+", query_without_phrases, flags=re.IGNORECASE)
if len(or_parts) > 1:
or_terms = [p.strip() for p in or_parts[1:] if p.strip()]
query_without_phrases = or_parts[0]
# 剩余的作为 AND 条件
and_terms = [t.strip() for t in query_without_phrases.split() if t.strip()]
return {"and": and_terms + phrases, "or": or_terms, "not": not_terms, "phrases": phrases}
def _execute_boolean_search(
self,
parsed_query: dict,
project_id: str | None = None,
content_types: list[str] | None = None,
) -> list[dict]:
"""执行布尔搜索"""
conn = self._get_conn()
# 构建基础查询
base_where = []
params = []
if project_id:
base_where.append("project_id = ?")
params.append(project_id)
if content_types:
placeholders = ",".join(["?" for _ in content_types])
base_where.append(f"content_type IN ({placeholders})")
params.extend(content_types)
base_where_str = " AND ".join(base_where) if base_where else "1=1"
# 获取候选结果
candidates = set()
# 处理 AND 条件
if parsed_query["and"]:
for term in parsed_query["and"]:
term_results = conn.execute(
f"""
SELECT content_id, content_type, project_id, frequency, positions
FROM search_term_freq
WHERE term = ? AND {base_where_str}
""",
[term] + params,
).fetchall()
term_contents = {(r["content_id"], r["content_type"]) for r in term_results}
if not candidates:
candidates = term_contents
else:
candidates &= term_contents # 交集
# 处理 OR 条件
if parsed_query["or"]:
for term in parsed_query["or"]:
term_results = conn.execute(
f"""
SELECT content_id, content_type, project_id, frequency, positions
FROM search_term_freq
WHERE term = ? AND {base_where_str}
""",
[term] + params,
).fetchall()
term_contents = {(r["content_id"], r["content_type"]) for r in term_results}
candidates |= term_contents # 并集
# 如果没有 AND 和 OR但有 phrases使用 phrases
if not candidates and parsed_query["phrases"]:
for phrase in parsed_query["phrases"]:
phrase_tokens = self._tokenize(phrase)
if phrase_tokens:
# 查找包含所有短语的文档
for token in phrase_tokens:
term_results = conn.execute(
f"""
SELECT content_id, content_type, project_id, frequency, positions
FROM search_term_freq
WHERE term = ? AND {base_where_str}
""",
[token] + params,
).fetchall()
term_contents = {(r["content_id"], r["content_type"]) for r in term_results}
if not candidates:
candidates = term_contents
else:
candidates &= term_contents
# 处理 NOT 条件(排除)
if parsed_query["not"]:
for term in parsed_query["not"]:
term_results = conn.execute(
f"""
SELECT content_id, content_type
FROM search_term_freq
WHERE term = ? AND {base_where_str}
""",
[term] + params,
).fetchall()
term_contents = {(r["content_id"], r["content_type"]) for r in term_results}
candidates -= term_contents # 差集
# 获取完整内容
results = []
for content_id, content_type in candidates:
# 获取原始内容
content = self._get_content_by_id(conn, content_id, content_type)
if content:
results.append(
{
"id": content_id,
"content_type": content_type,
"project_id": project_id
or self._get_project_id(conn, content_id, content_type),
"content": content,
"terms": parsed_query["and"] + parsed_query["or"] + parsed_query["phrases"],
}
)
conn.close()
return results
def _get_content_by_id(
self, conn: sqlite3.Connection, content_id: str, content_type: str
) -> str | None:
"""根据ID获取内容"""
try:
if content_type == "transcript":
row = conn.execute(
"SELECT full_text FROM transcripts WHERE id = ?", (content_id,)
).fetchone()
return row["full_text"] if row else None
elif content_type == "entity":
row = conn.execute(
"SELECT name, definition FROM entities WHERE id = ?", (content_id,)
).fetchone()
if row:
return f"{row['name']} {row['definition'] or ''}"
return None
elif content_type == "relation":
row = conn.execute(
"""SELECT r.relation_type, r.evidence,
e1.name as source_name, e2.name as target_name
FROM entity_relations r
JOIN entities e1 ON r.source_entity_id = e1.id
JOIN entities e2 ON r.target_entity_id = e2.id
WHERE r.id = ?""",
(content_id,),
).fetchone()
if row:
return f"{row['source_name']} {row['relation_type']} {row['target_name']} {row['evidence'] or ''}"
return None
return None
except Exception as e:
print(f"获取内容失败: {e}")
return None
def _get_project_id(
self, conn: sqlite3.Connection, content_id: str, content_type: str
) -> str | None:
"""获取内容所属的项目ID"""
try:
if content_type == "transcript":
row = conn.execute(
"SELECT project_id FROM transcripts WHERE id = ?", (content_id,)
).fetchone()
elif content_type == "entity":
row = conn.execute(
"SELECT project_id FROM entities WHERE id = ?", (content_id,)
).fetchone()
elif content_type == "relation":
row = conn.execute(
"SELECT project_id FROM entity_relations WHERE id = ?", (content_id,)
).fetchone()
else:
return None
return row["project_id"] if row else None
except (sqlite3.Error, KeyError):
return None
def _score_results(self, results: list[dict], parsed_query: dict) -> list[SearchResult]:
"""计算搜索结果的相关性分数"""
scored = []
all_terms = parsed_query["and"] + parsed_query["or"] + parsed_query["phrases"]
for result in results:
content = result["content"].lower()
# 基础分数
score = 0.0
highlights = []
# 计算每个词的匹配分数
for term in all_terms:
term_lower = term.lower()
count = content.count(term_lower)
if count > 0:
# TF 分数(词频)
tf_score = math.log(1 + count)
# 位置加分(标题/开头匹配分数更高)
position_bonus = 0
first_pos = content.find(term_lower)
if first_pos != -1:
if first_pos < 50: # 开头50个字符
position_bonus = 2.0
elif first_pos < 200: # 开头200个字符
position_bonus = 1.0
# 记录高亮位置
start = first_pos
while start != -1:
highlights.append((start, start + len(term)))
start = content.find(term_lower, start + 1)
score += tf_score + position_bonus
# 短语匹配额外加分
for phrase in parsed_query["phrases"]:
if phrase.lower() in content:
score *= 1.5 # 短语匹配加权
# 归一化分数
score = min(score / max(len(all_terms), 1), 10.0)
scored.append(
SearchResult(
id=result["id"],
content=result["content"],
content_type=result["content_type"],
project_id=result["project_id"],
score=round(score, 4),
highlights=highlights[:10], # 限制高亮数量
metadata={},
)
)
return scored
def highlight_text(self, text: str, query: str, max_length: int = 300) -> str:
"""
高亮文本中的关键词
Args:
text: 原始文本
query: 搜索查询
max_length: 返回文本的最大长度
Returns:
str: 带高亮标记的文本
"""
parsed = self._parse_boolean_query(query)
all_terms = parsed["and"] + parsed["or"] + parsed["phrases"]
# 找到第一个匹配位置
first_match = len(text)
for term in all_terms:
pos = text.lower().find(term.lower())
if pos != -1 and pos < first_match:
first_match = pos
# 截取上下文
start = max(0, first_match - 100)
end = min(len(text), start + max_length)
snippet = text[start:end]
if start > 0:
snippet = "..." + snippet
if end < len(text):
snippet = snippet + "..."
# 添加高亮标记
for term in sorted(all_terms, key=len, reverse=True): # 长的先替换
pattern = re.compile(re.escape(term), re.IGNORECASE)
snippet = pattern.sub(f"**{term}**", snippet)
return snippet
def delete_index(self, content_id: str, content_type: str) -> bool:
"""删除内容的搜索索引"""
try:
conn = self._get_conn()
# 删除索引
conn.execute(
"DELETE FROM search_indexes WHERE content_id = ? AND content_type = ?",
(content_id, content_type),
)
# 删除词频统计
conn.execute(
"DELETE FROM search_term_freq WHERE content_id = ? AND content_type = ?",
(content_id, content_type),
)
conn.commit()
conn.close()
return True
except Exception as e:
print(f"删除索引失败: {e}")
return False
def reindex_project(self, project_id: str) -> dict:
"""重新索引整个项目"""
conn = self._get_conn()
stats = {"transcripts": 0, "entities": 0, "relations": 0, "errors": 0}
try:
# 索引转录文本
transcripts = conn.execute(
"SELECT id, project_id, full_text FROM transcripts WHERE project_id = ?",
(project_id,),
).fetchall()
for t in transcripts:
if t["full_text"]:
if self.index_content(t["id"], "transcript", t["project_id"], t["full_text"]):
stats["transcripts"] += 1
else:
stats["errors"] += 1
# 索引实体
entities = conn.execute(
"SELECT id, project_id, name, definition FROM entities WHERE project_id = ?",
(project_id,),
).fetchall()
for e in entities:
text = f"{e['name']} {e['definition'] or ''}"
if self.index_content(e["id"], "entity", e["project_id"], text):
stats["entities"] += 1
else:
stats["errors"] += 1
# 索引关系
relations = conn.execute(
"""SELECT r.id, r.project_id, r.relation_type, r.evidence,
e1.name as source_name, e2.name as target_name
FROM entity_relations r
JOIN entities e1 ON r.source_entity_id = e1.id
JOIN entities e2 ON r.target_entity_id = e2.id
WHERE r.project_id = ?""",
(project_id,),
).fetchall()
for r in relations:
text = f"{r['source_name']} {r['relation_type']} {r['target_name']} {r['evidence'] or ''}"
if self.index_content(r["id"], "relation", r["project_id"], text):
stats["relations"] += 1
else:
stats["errors"] += 1
except Exception as e:
print(f"重新索引失败: {e}")
stats["errors"] += 1
conn.close()
return stats
# ==================== 语义搜索 ====================
class SemanticSearch:
"""
语义搜索模块
功能:
- 基于 embedding 的相似度搜索
- 使用 sentence-transformers 生成文本 embedding
- 支持余弦相似度计算
- 语义相似内容推荐
"""
def __init__(
self,
db_path: str = "insightflow.db",
model_name: str = "paraphrase-multilingual-MiniLM-L12-v2",
):
self.db_path = db_path
self.model_name = model_name
self.model = None
self._init_embedding_tables()
# 延迟加载模型
if SENTENCE_TRANSFORMERS_AVAILABLE:
try:
self.model = SentenceTransformer(model_name)
print(f"语义搜索模型加载成功: {model_name}")
except Exception as e:
print(f"模型加载失败: {e}")
def _get_conn(self) -> sqlite3.Connection:
"""获取数据库连接"""
conn = sqlite3.connect(self.db_path)
conn.row_factory = sqlite3.Row
return conn
def _init_embedding_tables(self) -> None:
"""初始化 embedding 相关表"""
conn = self._get_conn()
conn.execute("""
CREATE TABLE IF NOT EXISTS embeddings (
id TEXT PRIMARY KEY,
content_id TEXT NOT NULL,
content_type TEXT NOT NULL,
project_id TEXT NOT NULL,
embedding TEXT, -- JSON 数组
model_name TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
UNIQUE(content_id, content_type)
)
""")
conn.execute(
"CREATE INDEX IF NOT EXISTS idx_embedding_content ON embeddings(content_id, content_type)"
)
conn.execute("CREATE INDEX IF NOT EXISTS idx_embedding_project ON embeddings(project_id)")
conn.commit()
conn.close()
def is_available(self) -> bool:
"""检查语义搜索是否可用"""
return self.model is not None and SENTENCE_TRANSFORMERS_AVAILABLE
def generate_embedding(self, text: str) -> list[float] | None:
"""
生成文本的 embedding 向量
Args:
text: 输入文本
Returns:
Optional[List[float]]: embedding 向量
"""
if not self.is_available():
return None
try:
# 截断长文本
max_chars = 5000
if len(text) > max_chars:
text = text[:max_chars]
embedding = self.model.encode(text, convert_to_list=True)
return embedding
except Exception as e:
print(f"生成 embedding 失败: {e}")
return None
def index_embedding(
self, content_id: str, content_type: str, project_id: str, text: str
) -> bool:
"""
为内容生成并保存 embedding
Args:
content_id: 内容ID
content_type: 内容类型
project_id: 项目ID
text: 文本内容
Returns:
bool: 是否成功
"""
if not self.is_available():
return False
try:
embedding = self.generate_embedding(text)
if not embedding:
return False
conn = self._get_conn()
embedding_id = hashlib.md5(f"{content_id}:{content_type}".encode()).hexdigest()[:16]
conn.execute(
"""
INSERT OR REPLACE INTO embeddings
(id, content_id, content_type, project_id, embedding, model_name, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?)
""",
(
embedding_id,
content_id,
content_type,
project_id,
json.dumps(embedding),
self.model_name,
datetime.now().isoformat(),
),
)
conn.commit()
conn.close()
return True
except Exception as e:
print(f"索引 embedding 失败: {e}")
return False
def search(
self,
query: str,
project_id: str | None = None,
content_types: list[str] | None = None,
top_k: int = 10,
threshold: float = 0.5,
) -> list[SemanticSearchResult]:
"""
语义搜索
Args:
query: 搜索查询
project_id: 可选的项目ID过滤
content_types: 可选的内容类型过滤
top_k: 返回结果数量
threshold: 相似度阈值
Returns:
List[SemanticSearchResult]: 语义搜索结果
"""
if not self.is_available():
return []
# 生成查询的 embedding
query_embedding = self.generate_embedding(query)
if not query_embedding:
return []
# 获取候选 embedding
conn = self._get_conn()
where_clauses = []
params = []
if project_id:
where_clauses.append("project_id = ?")
params.append(project_id)
if content_types:
placeholders = ",".join(["?" for _ in content_types])
where_clauses.append(f"content_type IN ({placeholders})")
params.extend(content_types)
where_str = " AND ".join(where_clauses) if where_clauses else "1=1"
rows = conn.execute(
f"""
SELECT content_id, content_type, project_id, embedding
FROM embeddings
WHERE {where_str}
""",
params,
).fetchall()
conn.close()
# 计算相似度
results = []
query_vec = [query_embedding]
for row in rows:
try:
content_embedding = json.loads(row["embedding"])
# 计算余弦相似度
similarity = cosine_similarity(query_vec, [content_embedding])[0][0]
if similarity >= threshold:
# 获取原始内容
content = self._get_content_text(row["content_id"], row["content_type"])
results.append(
SemanticSearchResult(
id=row["content_id"],
content=content or "",
content_type=row["content_type"],
project_id=row["project_id"],
similarity=float(similarity),
embedding=None, # 不返回 embedding 以节省带宽
metadata={},
)
)
except Exception as e:
print(f"计算相似度失败: {e}")
continue
# 排序并返回 top_k
results.sort(key=lambda x: x.similarity, reverse=True)
return results[:top_k]
def _get_content_text(self, content_id: str, content_type: str) -> str | None:
"""获取内容文本"""
conn = self._get_conn()
try:
if content_type == "transcript":
row = conn.execute(
"SELECT full_text FROM transcripts WHERE id = ?", (content_id,)
).fetchone()
result = row["full_text"] if row else None
elif content_type == "entity":
row = conn.execute(
"SELECT name, definition FROM entities WHERE id = ?", (content_id,)
).fetchone()
result = f"{row['name']}: {row['definition']}" if row else None
elif content_type == "relation":
row = conn.execute(
"""SELECT r.relation_type, r.evidence,
e1.name as source_name, e2.name as target_name
FROM entity_relations r
JOIN entities e1 ON r.source_entity_id = e1.id
JOIN entities e2 ON r.target_entity_id = e2.id
WHERE r.id = ?""",
(content_id,),
).fetchone()
result = (
f"{row['source_name']} {row['relation_type']} {row['target_name']}"
if row
else None
)
else:
result = None
conn.close()
return result
except Exception as e:
conn.close()
print(f"获取内容失败: {e}")
return None
def find_similar_content(
self, content_id: str, content_type: str, top_k: int = 5
) -> list[SemanticSearchResult]:
"""
查找与指定内容相似的内容
Args:
content_id: 内容ID
content_type: 内容类型
top_k: 返回结果数量
Returns:
List[SemanticSearchResult]: 相似内容列表
"""
if not self.is_available():
return []
# 获取源内容的 embedding
conn = self._get_conn()
row = conn.execute(
"SELECT embedding, project_id FROM embeddings WHERE content_id = ? AND content_type = ?",
(content_id, content_type),
).fetchone()
if not row:
conn.close()
return []
source_embedding = json.loads(row["embedding"])
project_id = row["project_id"]
# 获取其他内容的 embedding
rows = conn.execute(
"""SELECT content_id, content_type, project_id, embedding
FROM embeddings
WHERE project_id = ? AND (content_id != ? OR content_type != ?)""",
(project_id, content_id, content_type),
).fetchall()
conn.close()
# 计算相似度
results = []
source_vec = [source_embedding]
for row in rows:
try:
content_embedding = json.loads(row["embedding"])
similarity = cosine_similarity(source_vec, [content_embedding])[0][0]
content = self._get_content_text(row["content_id"], row["content_type"])
results.append(
SemanticSearchResult(
id=row["content_id"],
content=content or "",
content_type=row["content_type"],
project_id=row["project_id"],
similarity=float(similarity),
metadata={},
)
)
except (KeyError, ValueError):
continue
results.sort(key=lambda x: x.similarity, reverse=True)
return results[:top_k]
def delete_embedding(self, content_id: str, content_type: str) -> bool:
"""删除内容的 embedding"""
try:
conn = self._get_conn()
conn.execute(
"DELETE FROM embeddings WHERE content_id = ? AND content_type = ?",
(content_id, content_type),
)
conn.commit()
conn.close()
return True
except Exception as e:
print(f"删除 embedding 失败: {e}")
return False
# ==================== 实体关系路径发现 ====================
class EntityPathDiscovery:
"""
实体关系路径发现模块
功能:
- 查找两个实体之间的关联路径
- 支持最短路径算法
- 支持多跳关系发现
- 路径可视化数据生成
"""
def __init__(self, db_path: str = "insightflow.db"):
self.db_path = db_path
def _get_conn(self) -> sqlite3.Connection:
"""获取数据库连接"""
conn = sqlite3.connect(self.db_path)
conn.row_factory = sqlite3.Row
return conn
def find_shortest_path(
self, source_entity_id: str, target_entity_id: str, max_depth: int = 5
) -> EntityPath | None:
"""
查找两个实体之间的最短路径BFS算法
Args:
source_entity_id: 源实体ID
target_entity_id: 目标实体ID
max_depth: 最大搜索深度
Returns:
Optional[EntityPath]: 最短路径
"""
conn = self._get_conn()
# 获取项目ID
row = conn.execute(
"SELECT project_id FROM entities WHERE id = ?", (source_entity_id,)
).fetchone()
if not row:
conn.close()
return None
project_id = row["project_id"]
# 验证目标实体也在同一项目
row = conn.execute(
"SELECT 1 FROM entities WHERE id = ? AND project_id = ?", (target_entity_id, project_id)
).fetchone()
if not row:
conn.close()
return None
# BFS
visited = {source_entity_id}
queue = [(source_entity_id, [source_entity_id])]
while queue:
current_id, path = queue.pop(0)
if len(path) > max_depth + 1:
continue
if current_id == target_entity_id:
# 找到路径
conn.close()
return self._build_path_object(path, project_id)
# 获取邻居
neighbors = conn.execute(
"""
SELECT target_entity_id as neighbor_id, relation_type, evidence
FROM entity_relations
WHERE source_entity_id = ? AND project_id = ?
UNION
SELECT source_entity_id as neighbor_id, relation_type, evidence
FROM entity_relations
WHERE target_entity_id = ? AND project_id = ?
""",
(current_id, project_id, current_id, project_id),
).fetchall()
for neighbor in neighbors:
neighbor_id = neighbor["neighbor_id"]
if neighbor_id not in visited:
visited.add(neighbor_id)
queue.append((neighbor_id, path + [neighbor_id]))
conn.close()
return None
def find_all_paths(
self, source_entity_id: str, target_entity_id: str, max_depth: int = 4, max_paths: int = 10
) -> list[EntityPath]:
"""
查找两个实体之间的所有路径(限制数量和深度)
Args:
source_entity_id: 源实体ID
target_entity_id: 目标实体ID
max_depth: 最大路径深度
max_paths: 最大返回路径数
Returns:
List[EntityPath]: 路径列表
"""
conn = self._get_conn()
# 获取项目ID
row = conn.execute(
"SELECT project_id FROM entities WHERE id = ?", (source_entity_id,)
).fetchone()
if not row:
conn.close()
return []
project_id = row["project_id"]
paths = []
def dfs(current_id: str, target_id: str, path: list[str], visited: set[str], depth: int):
if depth > max_depth:
return
if current_id == target_id:
paths.append(path.copy())
return
# 获取邻居
neighbors = conn.execute(
"""
SELECT target_entity_id as neighbor_id
FROM entity_relations
WHERE source_entity_id = ? AND project_id = ?
UNION
SELECT source_entity_id as neighbor_id
FROM entity_relations
WHERE target_entity_id = ? AND project_id = ?
""",
(current_id, project_id, current_id, project_id),
).fetchall()
for neighbor in neighbors:
neighbor_id = neighbor["neighbor_id"]
if neighbor_id not in visited and len(paths) < max_paths:
visited.add(neighbor_id)
path.append(neighbor_id)
dfs(neighbor_id, target_id, path, visited, depth + 1)
path.pop()
visited.remove(neighbor_id)
visited = {source_entity_id}
dfs(source_entity_id, target_entity_id, [source_entity_id], visited, 0)
conn.close()
# 构建路径对象
return [self._build_path_object(path, project_id) for path in paths]
def _build_path_object(self, entity_ids: list[str], project_id: str) -> EntityPath:
"""构建路径对象"""
conn = self._get_conn()
# 获取实体信息
nodes = []
for entity_id in entity_ids:
row = conn.execute(
"SELECT id, name, type FROM entities WHERE id = ?", (entity_id,)
).fetchone()
if row:
nodes.append({"id": row["id"], "name": row["name"], "type": row["type"]})
# 获取边信息
edges = []
for i in range(len(entity_ids) - 1):
source_id = entity_ids[i]
target_id = entity_ids[i + 1]
row = conn.execute(
"""
SELECT id, relation_type, evidence
FROM entity_relations
WHERE ((source_entity_id = ? AND target_entity_id = ?)
OR (source_entity_id = ? AND target_entity_id = ?))
AND project_id = ?
""",
(source_id, target_id, target_id, source_id, project_id),
).fetchone()
if row:
edges.append(
{
"id": row["id"],
"source": source_id,
"target": target_id,
"relation_type": row["relation_type"],
"evidence": row["evidence"],
}
)
conn.close()
# 生成路径描述
node_names = [n["name"] for n in nodes]
path_desc = "".join(node_names)
# 计算置信度(基于路径长度和关系数量)
confidence = 1.0 / (len(entity_ids) - 1) if len(entity_ids) > 1 else 1.0
return EntityPath(
path_id=f"path_{entity_ids[0]}_{entity_ids[-1]}_{hash(tuple(entity_ids))}",
source_entity_id=entity_ids[0],
source_entity_name=nodes[0]["name"] if nodes else "",
target_entity_id=entity_ids[-1],
target_entity_name=nodes[-1]["name"] if nodes else "",
path_length=len(entity_ids) - 1,
nodes=nodes,
edges=edges,
confidence=round(confidence, 4),
path_description=path_desc,
)
def find_multi_hop_relations(self, entity_id: str, max_hops: int = 3) -> list[dict]:
"""
查找实体的多跳关系
Args:
entity_id: 实体ID
max_hops: 最大跳数
Returns:
List[Dict]: 多跳关系列表
"""
conn = self._get_conn()
# 获取项目ID
row = conn.execute(
"SELECT project_id, name FROM entities WHERE id = ?", (entity_id,)
).fetchone()
if not row:
conn.close()
return []
project_id = row["project_id"]
row["name"]
# BFS 收集多跳关系
visited = {entity_id: 0}
queue = [(entity_id, 0)]
relations = []
while queue:
current_id, depth = queue.pop(0)
if depth >= max_hops:
continue
# 获取邻居
neighbors = conn.execute(
"""
SELECT
CASE
WHEN source_entity_id = ? THEN target_entity_id
ELSE source_entity_id
END as neighbor_id,
relation_type,
evidence
FROM entity_relations
WHERE (source_entity_id = ? OR target_entity_id = ?)
AND project_id = ?
""",
(current_id, current_id, current_id, project_id),
).fetchall()
for neighbor in neighbors:
neighbor_id = neighbor["neighbor_id"]
if neighbor_id not in visited:
visited[neighbor_id] = depth + 1
queue.append((neighbor_id, depth + 1))
# 获取邻居信息
neighbor_info = conn.execute(
"SELECT name, type FROM entities WHERE id = ?", (neighbor_id,)
).fetchone()
if neighbor_info:
relations.append(
{
"entity_id": neighbor_id,
"entity_name": neighbor_info["name"],
"entity_type": neighbor_info["type"],
"hops": depth + 1,
"relation_type": neighbor["relation_type"],
"evidence": neighbor["evidence"],
"path": self._get_path_to_entity(
entity_id, neighbor_id, project_id, conn
),
}
)
conn.close()
# 按跳数排序
relations.sort(key=lambda x: x["hops"])
return relations
def _get_path_to_entity(
self, source_id: str, target_id: str, project_id: str, conn: sqlite3.Connection
) -> list[str]:
"""获取从源实体到目标实体的路径(简化版)"""
# BFS 找路径
visited = {source_id}
queue = [(source_id, [source_id])]
while queue:
current, path = queue.pop(0)
if current == target_id:
return path
if len(path) > 5: # 限制路径长度
continue
neighbors = conn.execute(
"""
SELECT
CASE
WHEN source_entity_id = ? THEN target_entity_id
ELSE source_entity_id
END as neighbor_id
FROM entity_relations
WHERE (source_entity_id = ? OR target_entity_id = ?)
AND project_id = ?
""",
(current, current, current, project_id),
).fetchall()
for neighbor in neighbors:
neighbor_id = neighbor["neighbor_id"]
if neighbor_id not in visited:
visited.add(neighbor_id)
queue.append((neighbor_id, path + [neighbor_id]))
return []
def generate_path_visualization(self, path: EntityPath) -> dict:
"""
生成路径可视化数据
Args:
path: 实体路径
Returns:
Dict: D3.js 可视化数据格式
"""
# 节点数据
nodes = []
for node in path.nodes:
nodes.append(
{
"id": node["id"],
"name": node["name"],
"type": node["type"],
"is_source": node["id"] == path.source_entity_id,
"is_target": node["id"] == path.target_entity_id,
}
)
# 边数据
links = []
for edge in path.edges:
links.append(
{
"source": edge["source"],
"target": edge["target"],
"relation_type": edge["relation_type"],
"evidence": edge["evidence"],
}
)
return {
"nodes": nodes,
"links": links,
"path_description": path.path_description,
"path_length": path.path_length,
"confidence": path.confidence,
}
def analyze_path_centrality(self, project_id: str) -> list[dict]:
"""
分析项目中实体的路径中心性(桥接程度)
Args:
project_id: 项目ID
Returns:
List[Dict]: 中心性分析结果
"""
conn = self._get_conn()
# 获取所有实体
entities = conn.execute(
"SELECT id, name FROM entities WHERE project_id = ?", (project_id,)
).fetchall()
# 计算每个实体作为桥梁的次数
bridge_scores = []
for entity in entities:
entity_id = entity["id"]
# 计算该实体连接的不同群组数量
neighbors = conn.execute(
"""
SELECT
CASE
WHEN source_entity_id = ? THEN target_entity_id
ELSE source_entity_id
END as neighbor_id
FROM entity_relations
WHERE (source_entity_id = ? OR target_entity_id = ?)
AND project_id = ?
""",
(entity_id, entity_id, entity_id, project_id),
).fetchall()
neighbor_ids = {n["neighbor_id"] for n in neighbors}
# 计算邻居之间的连接数(用于评估桥接程度)
if len(neighbor_ids) > 1:
connections = conn.execute(
f"""
SELECT COUNT(*) as count
FROM entity_relations
WHERE ((source_entity_id IN ({",".join(["?" for _ in neighbor_ids])})
AND target_entity_id IN ({",".join(["?" for _ in neighbor_ids])}))
OR (target_entity_id IN ({",".join(["?" for _ in neighbor_ids])})
AND source_entity_id IN ({",".join(["?" for _ in neighbor_ids])})))
AND project_id = ?
""",
list(neighbor_ids) * 4 + [project_id],
).fetchone()
# 桥接分数 = 邻居数量 / (邻居间连接数 + 1)
bridge_score = len(neighbor_ids) / (connections["count"] + 1)
else:
bridge_score = 0
bridge_scores.append(
{
"entity_id": entity_id,
"entity_name": entity["name"],
"neighbor_count": len(neighbor_ids),
"bridge_score": round(bridge_score, 4),
}
)
conn.close()
# 按桥接分数排序
bridge_scores.sort(key=lambda x: x["bridge_score"], reverse=True)
return bridge_scores[:20] # 返回前20
# ==================== Knowledge Gap Detection ====================
class KnowledgeGapDetection:
    """
    Knowledge-gap detection.

    Inspects a project's knowledge graph and reports missing or weak spots:

    - entities missing attributes marked as required in the template
    - entities with few relations (sparse) or none at all (isolated)
    - entities without a definition
    - terms mentioned repeatedly in transcripts but never extracted as entities

    ``analyze_project`` returns the individual ``KnowledgeGap`` findings;
    ``generate_completeness_report`` aggregates them into a scored summary.
    """

    def __init__(self, db_path: str = "insightflow.db"):
        # SQLite file path; each method opens and closes its own connection.
        self.db_path = db_path

    def _get_conn(self) -> sqlite3.Connection:
        """Open a new connection with name-addressable rows (sqlite3.Row)."""
        conn = sqlite3.connect(self.db_path)
        conn.row_factory = sqlite3.Row
        return conn

    def analyze_project(self, project_id: str) -> list[KnowledgeGap]:
        """
        Run every gap check for one project.

        Args:
            project_id: Project identifier.

        Returns:
            List[KnowledgeGap]: all findings, most severe first.
        """
        gaps: list[KnowledgeGap] = []
        gaps.extend(self._check_entity_attribute_completeness(project_id))
        gaps.extend(self._check_relation_sparsity(project_id))
        gaps.extend(self._check_isolated_entities(project_id))
        gaps.extend(self._check_incomplete_entities(project_id))
        gaps.extend(self._check_missing_key_entities(project_id))
        # high > medium > low; any unexpected severity sorts last.
        severity_order = {"high": 0, "medium": 1, "low": 2}
        gaps.sort(key=lambda g: severity_order.get(g.severity, 3))
        return gaps

    def _check_entity_attribute_completeness(self, project_id: str) -> list[KnowledgeGap]:
        """Flag entities missing attributes marked required in the template."""
        conn = self._get_conn()
        gaps: list[KnowledgeGap] = []
        templates = conn.execute(
            "SELECT id, name, type, is_required FROM attribute_templates WHERE project_id = ?",
            (project_id,),
        ).fetchall()
        if not templates:
            conn.close()
            return []
        required_template_ids = {t["id"] for t in templates if t["is_required"]}
        if not required_template_ids:
            # No required attributes defined: nothing can be missing.
            conn.close()
            return []
        entities = conn.execute(
            "SELECT id, name FROM entities WHERE project_id = ?", (project_id,)
        ).fetchall()
        for entity in entities:
            entity_id = entity["id"]
            existing_attrs = conn.execute(
                "SELECT template_id FROM entity_attributes WHERE entity_id = ?", (entity_id,)
            ).fetchall()
            existing_template_ids = {a["template_id"] for a in existing_attrs}
            missing_templates = required_template_ids - existing_template_ids
            if not missing_templates:
                continue
            # Resolve template ids back to human-readable attribute names.
            missing_names = []
            for template_id in missing_templates:
                template = conn.execute(
                    "SELECT name FROM attribute_templates WHERE id = ?", (template_id,)
                ).fetchone()
                if template:
                    missing_names.append(template["name"])
            if missing_names:
                gaps.append(
                    KnowledgeGap(
                        gap_id=f"gap_attr_{entity_id}",
                        gap_type="missing_attribute",
                        entity_id=entity_id,
                        entity_name=entity["name"],
                        description=f"实体 '{entity['name']}' 缺少必需属性: {', '.join(missing_names)}",
                        severity="medium",
                        suggestions=[
                            f"为实体 '{entity['name']}' 补充以下属性: {', '.join(missing_names)}",
                            "检查属性模板定义是否合理",
                        ],
                        related_entities=[],
                        metadata={"missing_attributes": missing_names},
                    )
                )
        conn.close()
        return gaps

    def _check_relation_sparsity(self, project_id: str) -> list[KnowledgeGap]:
        """Flag entities whose relation count is at or below a type-based threshold."""
        conn = self._get_conn()
        gaps: list[KnowledgeGap] = []
        entities = conn.execute(
            "SELECT id, name, type FROM entities WHERE project_id = ?", (project_id,)
        ).fetchall()
        for entity in entities:
            entity_id = entity["id"]
            relation_count = conn.execute(
                """
                SELECT COUNT(*) as count
                FROM entity_relations
                WHERE (source_entity_id = ? OR target_entity_id = ?)
                    AND project_id = ?
                """,
                (entity_id, entity_id, project_id),
            ).fetchone()["count"]
            # People and organizations are expected to have at least one relation.
            threshold = 1 if entity["type"] in ["PERSON", "ORG"] else 0
            if relation_count > threshold:
                continue
            # Candidate related entities: others whose project transcripts
            # mention this entity's name. DISTINCT prevents the same entity
            # appearing once per matching transcript row.
            potential_related = conn.execute(
                """
                SELECT DISTINCT e.id, e.name
                FROM entities e
                JOIN transcripts t ON t.project_id = e.project_id
                WHERE e.project_id = ?
                    AND e.id != ?
                    AND t.full_text LIKE ?
                LIMIT 5
                """,
                (project_id, entity_id, f"%{entity['name']}%"),
            ).fetchall()
            gaps.append(
                KnowledgeGap(
                    gap_id=f"gap_sparse_{entity_id}",
                    gap_type="sparse_relation",
                    entity_id=entity_id,
                    entity_name=entity["name"],
                    description=f"实体 '{entity['name']}' 关系稀疏(仅有 {relation_count} 个关系)",
                    severity="medium" if relation_count == 0 else "low",
                    suggestions=[
                        f"检查转录文本中提及 '{entity['name']}' 的其他实体",
                        f"手动添加 '{entity['name']}' 与其他实体的关系",
                        "使用实体对齐功能合并相似实体",
                    ],
                    related_entities=[r["id"] for r in potential_related],
                    metadata={
                        "relation_count": relation_count,
                        "potential_related": [r["name"] for r in potential_related],
                    },
                )
            )
        conn.close()
        return gaps

    def _check_isolated_entities(self, project_id: str) -> list[KnowledgeGap]:
        """Flag entities that participate in no relation at all."""
        conn = self._get_conn()
        # An entity is isolated when it appears on neither side of any relation.
        isolated = conn.execute(
            """
            SELECT e.id, e.name, e.type
            FROM entities e
            LEFT JOIN entity_relations r1 ON e.id = r1.source_entity_id
            LEFT JOIN entity_relations r2 ON e.id = r2.target_entity_id
            WHERE e.project_id = ?
                AND r1.id IS NULL
                AND r2.id IS NULL
            """,
            (project_id,),
        ).fetchall()
        gaps = [
            KnowledgeGap(
                gap_id=f"gap_iso_{entity['id']}",
                gap_type="isolated_entity",
                entity_id=entity["id"],
                entity_name=entity["name"],
                description=f"实体 '{entity['name']}' 是孤立实体(没有任何关系)",
                severity="high",
                suggestions=[
                    f"检查 '{entity['name']}' 是否应该与其他实体建立关系",
                    f"考虑删除不相关的实体 '{entity['name']}'",
                    "运行关系发现算法自动识别潜在关系",
                ],
                related_entities=[],
                metadata={"entity_type": entity["type"]},
            )
            for entity in isolated
        ]
        conn.close()
        return gaps

    def _check_incomplete_entities(self, project_id: str) -> list[KnowledgeGap]:
        """Flag entities whose definition field is NULL or empty."""
        conn = self._get_conn()
        incomplete = conn.execute(
            """
            SELECT id, name, type, definition
            FROM entities
            WHERE project_id = ?
                AND (definition IS NULL OR definition = '')
            """,
            (project_id,),
        ).fetchall()
        gaps = [
            KnowledgeGap(
                gap_id=f"gap_inc_{entity['id']}",
                gap_type="incomplete_entity",
                entity_id=entity["id"],
                entity_name=entity["name"],
                description=f"实体 '{entity['name']}' 缺少定义",
                severity="low",
                suggestions=[f"为 '{entity['name']}' 添加定义", "从转录文本中提取定义信息"],
                related_entities=[],
                metadata={"entity_type": entity["type"]},
            )
            for entity in incomplete
        ]
        conn.close()
        return gaps

    def _check_missing_key_entities(self, project_id: str) -> list[KnowledgeGap]:
        """Flag capitalized terms that recur in transcripts but were never extracted."""
        conn = self._get_conn()
        gaps: list[KnowledgeGap] = []
        transcripts = conn.execute(
            "SELECT full_text FROM transcripts WHERE project_id = ?", (project_id,)
        ).fetchall()
        all_text = " ".join(t["full_text"] or "" for t in transcripts)
        existing_entities = conn.execute(
            "SELECT name FROM entities WHERE project_id = ?", (project_id,)
        ).fetchall()
        existing_names = {e["name"].lower() for e in existing_entities}
        # Naive proper-noun heuristic: runs of Capitalized words.  A real NLP
        # pipeline could replace this.
        potential_entities = re.findall(r"[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*", all_text)
        freq: defaultdict[str, int] = defaultdict(int)
        for term in potential_entities:
            if len(term) > 3 and term.lower() not in existing_names:
                freq[term] += 1
        for term, count in freq.items():
            if count < 3:  # only report terms mentioned 3 or more times
                continue
            # md5 keeps the gap id stable across runs; the builtin hash() is
            # randomized per process (PYTHONHASHSEED), so ids previously
            # changed on every restart.
            digest = hashlib.md5(term.encode("utf-8")).hexdigest()[:8]
            gaps.append(
                KnowledgeGap(
                    gap_id=f"gap_missing_{digest}",
                    gap_type="missing_key_entity",
                    entity_id=None,
                    entity_name=None,
                    description=f"文本中频繁提及 '{term}' 但未提取为实体(出现 {count} 次)",
                    severity="low",
                    suggestions=[
                        f"考虑将 '{term}' 添加为实体",
                        "检查实体提取算法是否需要优化",
                    ],
                    related_entities=[],
                    metadata={"mention_count": count},
                )
            )
        conn.close()
        return gaps[:10]  # cap the number of findings

    def generate_completeness_report(self, project_id: str) -> dict:
        """
        Build a knowledge-completeness report for a project.

        The score starts at 100 and loses 10/5/2 points per high/medium/low
        severity gap, floored at 0.

        Args:
            project_id: Project identifier.

        Returns:
            Dict: score, base statistics, gap summary, top gaps and advice.
        """
        conn = self._get_conn()
        stats = conn.execute(
            """
            SELECT
                (SELECT COUNT(*) FROM entities WHERE project_id = ?) as entity_count,
                (SELECT COUNT(*) FROM entity_relations WHERE project_id = ?) as relation_count,
                (SELECT COUNT(*) FROM transcripts WHERE project_id = ?) as transcript_count
            """,
            (project_id, project_id, project_id),
        ).fetchone()
        gaps = self.analyze_project(project_id)
        # Tally gaps per type and severity for the summary section.
        gap_by_type: defaultdict[str, int] = defaultdict(int)
        severity_count = {"high": 0, "medium": 0, "low": 0}
        for gap in gaps:
            gap_by_type[gap.gap_type] += 1
            severity_count[gap.severity] += 1
        # Completeness score: 100 minus weighted deductions, never below 0.
        score = 100
        score -= severity_count["high"] * 10
        score -= severity_count["medium"] * 5
        score -= severity_count["low"] * 2
        score = max(0, score)
        conn.close()
        return {
            "project_id": project_id,
            "completeness_score": score,
            "statistics": {
                "entity_count": stats["entity_count"],
                "relation_count": stats["relation_count"],
                "transcript_count": stats["transcript_count"],
            },
            "gap_summary": {
                "total": len(gaps),
                "by_type": dict(gap_by_type),
                "by_severity": severity_count,
            },
            "top_gaps": [g.to_dict() for g in gaps[:10]],
            "recommendations": self._generate_recommendations(gaps),
        }

    def _generate_recommendations(self, gaps: list[KnowledgeGap]) -> list[str]:
        """Turn the set of detected gap types into improvement advice."""
        recommendations = []
        gap_types = {g.gap_type for g in gaps}
        if "isolated_entity" in gap_types:
            recommendations.append("优先处理孤立实体,建立实体间的关系连接")
        if "missing_attribute" in gap_types:
            recommendations.append("完善实体属性信息,补充必需的属性字段")
        if "sparse_relation" in gap_types:
            recommendations.append("运行自动关系发现算法,识别更多实体关系")
        if "incomplete_entity" in gap_types:
            recommendations.append("为缺少定义的实体补充描述信息")
        if "missing_key_entity" in gap_types:
            recommendations.append("优化实体提取算法,确保关键实体被正确识别")
        if not recommendations:
            recommendations.append("知识图谱完整性良好,继续保持")
        return recommendations
# ==================== Search Manager ====================
class SearchManager:
"""
搜索管理器 - 统一入口
整合全文搜索、语义搜索、实体路径发现和知识缺口识别功能
"""
def __init__(self, db_path: str = "insightflow.db"):
self.db_path = db_path
self.fulltext_search = FullTextSearch(db_path)
self.semantic_search = SemanticSearch(db_path)
self.path_discovery = EntityPathDiscovery(db_path)
self.gap_detection = KnowledgeGapDetection(db_path)
def hybrid_search(self, query: str, project_id: str | None = None, limit: int = 20) -> dict:
"""
混合搜索(全文 + 语义)
Args:
query: 搜索查询
project_id: 可选的项目ID
limit: 返回结果数量
Returns:
Dict: 混合搜索结果
"""
# 全文搜索
fulltext_results = self.fulltext_search.search(query, project_id, limit=limit)
# 语义搜索
semantic_results = []
if self.semantic_search.is_available():
semantic_results = self.semantic_search.search(query, project_id, top_k=limit)
# 合并结果(去重并加权)
combined = {}
# 添加全文搜索结果
for r in fulltext_results:
key = (r.id, r.content_type)
combined[key] = {
"id": r.id,
"content": r.content,
"content_type": r.content_type,
"project_id": r.project_id,
"fulltext_score": r.score,
"semantic_score": 0,
"combined_score": r.score * 0.6, # 全文权重 60%
"highlights": r.highlights,
}
# 添加语义搜索结果
for r in semantic_results:
key = (r.id, r.content_type)
if key in combined:
combined[key]["semantic_score"] = r.similarity
combined[key]["combined_score"] += r.similarity * 0.4 # 语义权重 40%
else:
combined[key] = {
"id": r.id,
"content": r.content,
"content_type": r.content_type,
"project_id": r.project_id,
"fulltext_score": 0,
"semantic_score": r.similarity,
"combined_score": r.similarity * 0.4,
"highlights": [],
}
# 排序
results = list(combined.values())
results.sort(key=lambda x: x["combined_score"], reverse=True)
return {
"query": query,
"project_id": project_id,
"total": len(results),
"fulltext_count": len(fulltext_results),
"semantic_count": len(semantic_results),
"results": results[:limit],
}
def index_project(self, project_id: str) -> dict:
"""
为项目建立所有索引
Args:
project_id: 项目ID
Returns:
Dict: 索引统计
"""
# 全文索引
fulltext_stats = self.fulltext_search.reindex_project(project_id)
# 语义索引
semantic_stats = {"indexed": 0, "errors": 0}
if self.semantic_search.is_available():
conn = sqlite3.connect(self.db_path)
conn.row_factory = sqlite3.Row
# 索引转录文本
transcripts = conn.execute(
"SELECT id, project_id, full_text FROM transcripts WHERE project_id = ?",
(project_id,),
).fetchall()
for t in transcripts:
if t["full_text"] and self.semantic_search.index_embedding(
t["id"], "transcript", t["project_id"], t["full_text"]
):
semantic_stats["indexed"] += 1
else:
semantic_stats["errors"] += 1
# 索引实体
entities = conn.execute(
"SELECT id, project_id, name, definition FROM entities WHERE project_id = ?",
(project_id,),
).fetchall()
for e in entities:
text = f"{e['name']} {e['definition'] or ''}"
if self.semantic_search.index_embedding(e["id"], "entity", e["project_id"], text):
semantic_stats["indexed"] += 1
else:
semantic_stats["errors"] += 1
conn.close()
return {"project_id": project_id, "fulltext": fulltext_stats, "semantic": semantic_stats}
def get_search_stats(self, project_id: str | None = None) -> dict:
"""获取搜索统计信息"""
conn = sqlite3.connect(self.db_path)
conn.row_factory = sqlite3.Row
where_clause = "WHERE project_id = ?" if project_id else ""
params = [project_id] if project_id else []
# 全文索引统计
fulltext_count = conn.execute(
f"SELECT COUNT(*) as count FROM search_indexes {where_clause}", params
).fetchone()["count"]
# 语义索引统计
semantic_count = conn.execute(
f"SELECT COUNT(*) as count FROM embeddings {where_clause}", params
).fetchone()["count"]
# 按类型统计
type_stats = {}
if project_id:
rows = conn.execute(
"""SELECT content_type, COUNT(*) as count
FROM search_indexes WHERE project_id = ?
GROUP BY content_type""",
(project_id,),
).fetchall()
type_stats = {r["content_type"]: r["count"] for r in rows}
conn.close()
return {
"project_id": project_id,
"fulltext_indexed": fulltext_count,
"semantic_indexed": semantic_count,
"by_content_type": type_stats,
"semantic_search_available": self.semantic_search.is_available(),
}
# Lazily-created module-level singleton (see get_search_manager).
_search_manager: SearchManager | None = None
def get_search_manager(db_path: str = "insightflow.db") -> SearchManager:
    """
    Return the process-wide SearchManager, creating it lazily.

    Note: ``db_path`` only takes effect on the very first call; later calls
    return the instance that already exists.
    """
    global _search_manager
    manager = _search_manager
    if manager is None:
        manager = SearchManager(db_path)
        _search_manager = manager
    return manager
# Convenience functions
def fulltext_search(
    query: str, project_id: str | None = None, limit: int = 20
) -> list[SearchResult]:
    """Module-level shortcut: full-text search via the shared SearchManager."""
    return get_search_manager().fulltext_search.search(query, project_id, limit=limit)
def semantic_search(
    query: str, project_id: str | None = None, top_k: int = 10
) -> list[SemanticSearchResult]:
    """Module-level shortcut: semantic search via the shared SearchManager."""
    return get_search_manager().semantic_search.search(query, project_id, top_k=top_k)
def find_entity_path(source_id: str, target_id: str, max_depth: int = 5) -> EntityPath | None:
    """Module-level shortcut: shortest entity path via the shared SearchManager."""
    return get_search_manager().path_discovery.find_shortest_path(
        source_id, target_id, max_depth
    )
def detect_knowledge_gaps(project_id: str) -> list[KnowledgeGap]:
    """Module-level shortcut: knowledge-gap analysis via the shared SearchManager."""
    return get_search_manager().gap_detection.analyze_project(project_id)