Phase 3: Memory & Growth - Multi-file fusion, Entity alignment with embedding, Document import, Knowledge base panel
This commit is contained in:
372
backend/entity_aligner.py
Normal file
372
backend/entity_aligner.py
Normal file
@@ -0,0 +1,372 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Entity Aligner - Phase 3
|
||||
使用 embedding 进行实体对齐
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import httpx
|
||||
import numpy as np
|
||||
from typing import List, Optional, Dict
|
||||
from dataclasses import dataclass
|
||||
|
||||
# API Keys
|
||||
KIMI_API_KEY = os.getenv("KIMI_API_KEY", "")
|
||||
KIMI_BASE_URL = os.getenv("KIMI_BASE_URL", "https://api.kimi.com/coding")
|
||||
|
||||
@dataclass
|
||||
class EntityEmbedding:
|
||||
entity_id: str
|
||||
name: str
|
||||
definition: str
|
||||
embedding: List[float]
|
||||
|
||||
class EntityAligner:
|
||||
"""实体对齐器 - 使用 embedding 进行相似度匹配"""
|
||||
|
||||
def __init__(self, similarity_threshold: float = 0.85):
|
||||
self.similarity_threshold = similarity_threshold
|
||||
self.embedding_cache: Dict[str, List[float]] = {}
|
||||
|
||||
def get_embedding(self, text: str) -> Optional[List[float]]:
|
||||
"""
|
||||
使用 Kimi API 获取文本的 embedding
|
||||
|
||||
Args:
|
||||
text: 输入文本
|
||||
|
||||
Returns:
|
||||
embedding 向量或 None
|
||||
"""
|
||||
if not KIMI_API_KEY:
|
||||
return None
|
||||
|
||||
# 检查缓存
|
||||
cache_key = hash(text)
|
||||
if cache_key in self.embedding_cache:
|
||||
return self.embedding_cache[cache_key]
|
||||
|
||||
try:
|
||||
response = httpx.post(
|
||||
f"{KIMI_BASE_URL}/v1/embeddings",
|
||||
headers={"Authorization": f"Bearer {KIMI_API_KEY}", "Content-Type": "application/json"},
|
||||
json={
|
||||
"model": "k2p5",
|
||||
"input": text[:500] # 限制长度
|
||||
},
|
||||
timeout=30.0
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
|
||||
embedding = result["data"][0]["embedding"]
|
||||
self.embedding_cache[cache_key] = embedding
|
||||
return embedding
|
||||
|
||||
except Exception as e:
|
||||
print(f"Embedding API failed: {e}")
|
||||
return None
|
||||
|
||||
def compute_similarity(self, embedding1: List[float], embedding2: List[float]) -> float:
|
||||
"""
|
||||
计算两个 embedding 的余弦相似度
|
||||
|
||||
Args:
|
||||
embedding1: 第一个向量
|
||||
embedding2: 第二个向量
|
||||
|
||||
Returns:
|
||||
相似度分数 (0-1)
|
||||
"""
|
||||
vec1 = np.array(embedding1)
|
||||
vec2 = np.array(embedding2)
|
||||
|
||||
# 余弦相似度
|
||||
dot_product = np.dot(vec1, vec2)
|
||||
norm1 = np.linalg.norm(vec1)
|
||||
norm2 = np.linalg.norm(vec2)
|
||||
|
||||
if norm1 == 0 or norm2 == 0:
|
||||
return 0.0
|
||||
|
||||
return float(dot_product / (norm1 * norm2))
|
||||
|
||||
def get_entity_text(self, name: str, definition: str = "") -> str:
|
||||
"""
|
||||
构建用于 embedding 的实体文本
|
||||
|
||||
Args:
|
||||
name: 实体名称
|
||||
definition: 实体定义
|
||||
|
||||
Returns:
|
||||
组合文本
|
||||
"""
|
||||
if definition:
|
||||
return f"{name}: {definition}"
|
||||
return name
|
||||
|
||||
def find_similar_entity(
|
||||
self,
|
||||
project_id: str,
|
||||
name: str,
|
||||
definition: str = "",
|
||||
exclude_id: Optional[str] = None,
|
||||
threshold: Optional[float] = None
|
||||
) -> Optional[object]:
|
||||
"""
|
||||
查找相似的实体
|
||||
|
||||
Args:
|
||||
project_id: 项目 ID
|
||||
name: 实体名称
|
||||
definition: 实体定义
|
||||
exclude_id: 要排除的实体 ID
|
||||
threshold: 相似度阈值
|
||||
|
||||
Returns:
|
||||
相似的实体或 None
|
||||
"""
|
||||
if threshold is None:
|
||||
threshold = self.similarity_threshold
|
||||
|
||||
try:
|
||||
from db_manager import get_db_manager
|
||||
db = get_db_manager()
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
# 获取项目的所有实体
|
||||
entities = db.get_all_entities_for_embedding(project_id)
|
||||
|
||||
if not entities:
|
||||
return None
|
||||
|
||||
# 获取查询实体的 embedding
|
||||
query_text = self.get_entity_text(name, definition)
|
||||
query_embedding = self.get_embedding(query_text)
|
||||
|
||||
if query_embedding is None:
|
||||
# 如果 embedding API 失败,回退到简单匹配
|
||||
return self._fallback_similarity_match(entities, name, exclude_id)
|
||||
|
||||
best_match = None
|
||||
best_score = threshold
|
||||
|
||||
for entity in entities:
|
||||
if exclude_id and entity.id == exclude_id:
|
||||
continue
|
||||
|
||||
# 获取实体的 embedding
|
||||
entity_text = self.get_entity_text(entity.name, entity.definition)
|
||||
entity_embedding = self.get_embedding(entity_text)
|
||||
|
||||
if entity_embedding is None:
|
||||
continue
|
||||
|
||||
# 计算相似度
|
||||
similarity = self.compute_similarity(query_embedding, entity_embedding)
|
||||
|
||||
if similarity > best_score:
|
||||
best_score = similarity
|
||||
best_match = entity
|
||||
|
||||
return best_match
|
||||
|
||||
def _fallback_similarity_match(
|
||||
self,
|
||||
entities: List[object],
|
||||
name: str,
|
||||
exclude_id: Optional[str] = None
|
||||
) -> Optional[object]:
|
||||
"""
|
||||
回退到简单的相似度匹配(不使用 embedding)
|
||||
|
||||
Args:
|
||||
entities: 实体列表
|
||||
name: 查询名称
|
||||
exclude_id: 要排除的实体 ID
|
||||
|
||||
Returns:
|
||||
最相似的实体或 None
|
||||
"""
|
||||
name_lower = name.lower()
|
||||
|
||||
# 1. 精确匹配
|
||||
for entity in entities:
|
||||
if exclude_id and entity.id == exclude_id:
|
||||
continue
|
||||
if entity.name.lower() == name_lower:
|
||||
return entity
|
||||
if entity.aliases and name_lower in [a.lower() for a in entity.aliases]:
|
||||
return entity
|
||||
|
||||
# 2. 包含匹配
|
||||
for entity in entities:
|
||||
if exclude_id and entity.id == exclude_id:
|
||||
continue
|
||||
if name_lower in entity.name.lower() or entity.name.lower() in name_lower:
|
||||
return entity
|
||||
|
||||
return None
|
||||
|
||||
def batch_align_entities(
|
||||
self,
|
||||
project_id: str,
|
||||
new_entities: List[Dict],
|
||||
threshold: Optional[float] = None
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
批量对齐实体
|
||||
|
||||
Args:
|
||||
project_id: 项目 ID
|
||||
new_entities: 新实体列表 [{"name": "...", "definition": "..."}]
|
||||
threshold: 相似度阈值
|
||||
|
||||
Returns:
|
||||
对齐结果列表 [{"new_entity": {...}, "matched_entity": {...}, "similarity": 0.9}]
|
||||
"""
|
||||
if threshold is None:
|
||||
threshold = self.similarity_threshold
|
||||
|
||||
results = []
|
||||
|
||||
for new_ent in new_entities:
|
||||
matched = self.find_similar_entity(
|
||||
project_id,
|
||||
new_ent["name"],
|
||||
new_ent.get("definition", ""),
|
||||
threshold=threshold
|
||||
)
|
||||
|
||||
result = {
|
||||
"new_entity": new_ent,
|
||||
"matched_entity": None,
|
||||
"similarity": 0.0,
|
||||
"should_merge": False
|
||||
}
|
||||
|
||||
if matched:
|
||||
# 计算相似度
|
||||
query_text = self.get_entity_text(new_ent["name"], new_ent.get("definition", ""))
|
||||
matched_text = self.get_entity_text(matched.name, matched.definition)
|
||||
|
||||
query_emb = self.get_embedding(query_text)
|
||||
matched_emb = self.get_embedding(matched_text)
|
||||
|
||||
if query_emb and matched_emb:
|
||||
similarity = self.compute_similarity(query_emb, matched_emb)
|
||||
result["matched_entity"] = {
|
||||
"id": matched.id,
|
||||
"name": matched.name,
|
||||
"type": matched.type,
|
||||
"definition": matched.definition
|
||||
}
|
||||
result["similarity"] = similarity
|
||||
result["should_merge"] = similarity >= threshold
|
||||
|
||||
results.append(result)
|
||||
|
||||
return results
|
||||
|
||||
def suggest_entity_aliases(self, entity_name: str, entity_definition: str = "") -> List[str]:
|
||||
"""
|
||||
使用 LLM 建议实体的别名
|
||||
|
||||
Args:
|
||||
entity_name: 实体名称
|
||||
entity_definition: 实体定义
|
||||
|
||||
Returns:
|
||||
建议的别名列表
|
||||
"""
|
||||
if not KIMI_API_KEY:
|
||||
return []
|
||||
|
||||
prompt = f"""为以下实体生成可能的别名或简称:
|
||||
|
||||
实体名称:{entity_name}
|
||||
定义:{entity_definition}
|
||||
|
||||
请返回 JSON 格式的别名列表:
|
||||
{{"aliases": ["别名1", "别名2", "别名3"]}}
|
||||
|
||||
只返回 JSON,不要其他内容。"""
|
||||
|
||||
try:
|
||||
response = httpx.post(
|
||||
f"{KIMI_BASE_URL}/v1/chat/completions",
|
||||
headers={"Authorization": f"Bearer {KIMI_API_KEY}", "Content-Type": "application/json"},
|
||||
json={
|
||||
"model": "k2p5",
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
"temperature": 0.3
|
||||
},
|
||||
timeout=30.0
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
content = result["choices"][0]["message"]["content"]
|
||||
|
||||
import re
|
||||
json_match = re.search(r'\{{.*?\}}', content, re.DOTALL)
|
||||
if json_match:
|
||||
data = json.loads(json_match.group())
|
||||
return data.get("aliases", [])
|
||||
except Exception as e:
|
||||
print(f"Alias suggestion failed: {e}")
|
||||
|
||||
return []
|
||||
|
||||
|
||||
# 简单的字符串相似度计算(不使用 embedding)
|
||||
def simple_similarity(str1: str, str2: str) -> float:
|
||||
"""
|
||||
计算两个字符串的简单相似度
|
||||
|
||||
Args:
|
||||
str1: 第一个字符串
|
||||
str2: 第二个字符串
|
||||
|
||||
Returns:
|
||||
相似度分数 (0-1)
|
||||
"""
|
||||
if str1 == str2:
|
||||
return 1.0
|
||||
|
||||
if not str1 or not str2:
|
||||
return 0.0
|
||||
|
||||
# 转换为小写
|
||||
s1 = str1.lower()
|
||||
s2 = str2.lower()
|
||||
|
||||
# 包含关系
|
||||
if s1 in s2 or s2 in s1:
|
||||
return 0.8
|
||||
|
||||
# 计算编辑距离相似度
|
||||
from difflib import SequenceMatcher
|
||||
return SequenceMatcher(None, s1, s2).ratio()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 测试
|
||||
aligner = EntityAligner()
|
||||
|
||||
# 测试 embedding
|
||||
test_text = "Kubernetes 容器编排平台"
|
||||
embedding = aligner.get_embedding(test_text)
|
||||
if embedding:
|
||||
print(f"Embedding dimension: {len(embedding)}")
|
||||
print(f"First 5 values: {embedding[:5]}")
|
||||
else:
|
||||
print("Embedding API not available")
|
||||
|
||||
# 测试相似度计算
|
||||
emb1 = [1.0, 0.0, 0.0]
|
||||
emb2 = [0.9, 0.1, 0.0]
|
||||
sim = aligner.compute_similarity(emb1, emb2)
|
||||
print(f"Similarity: {sim:.4f}")
|
||||
Reference in New Issue
Block a user