372 lines
10 KiB
Python
372 lines
10 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
Entity Aligner - Phase 3
|
||
使用 embedding 进行实体对齐
|
||
"""
|
||
|
||
import json
|
||
import os
|
||
from dataclasses import dataclass
|
||
|
||
import httpx
|
||
import numpy as np
|
||
|
||
# Kimi API configuration, read once from the environment at import time.
# An empty key makes this module skip its API calls.
KIMI_API_KEY = os.getenv("KIMI_API_KEY", "")
# Trailing slashes are stripped so f"{KIMI_BASE_URL}/v1/..." never contains "//".
KIMI_BASE_URL = os.getenv("KIMI_BASE_URL", "https://api.kimi.com/coding").rstrip("/")
|
||
|
||
|
||
@dataclass
class EntityEmbedding:
    """Record pairing an entity's identity and text with its embedding vector."""

    # Identifier of the entity the vector belongs to.
    entity_id: str
    # Entity name.
    name: str
    # Entity definition text.
    definition: str
    # Embedding vector for the entity.
    embedding: list[float]
|
||
|
||
|
||
class EntityAligner:
|
||
"""实体对齐器 - 使用 embedding 进行相似度匹配"""
|
||
|
||
def __init__(self, similarity_threshold: float = 0.85):
|
||
self.similarity_threshold = similarity_threshold
|
||
self.embedding_cache: dict[str, list[float]] = {}
|
||
|
||
def get_embedding(self, text: str) -> list[float] | None:
|
||
"""
|
||
使用 Kimi API 获取文本的 embedding
|
||
|
||
Args:
|
||
text: 输入文本
|
||
|
||
Returns:
|
||
embedding 向量或 None
|
||
"""
|
||
if not KIMI_API_KEY:
|
||
return None
|
||
|
||
# 检查缓存
|
||
cache_key = hash(text)
|
||
if cache_key in self.embedding_cache:
|
||
return self.embedding_cache[cache_key]
|
||
|
||
try:
|
||
response = httpx.post(
|
||
f"{KIMI_BASE_URL}/v1/embeddings",
|
||
headers={
|
||
"Authorization": f"Bearer {KIMI_API_KEY}",
|
||
"Content-Type": "application/json",
|
||
},
|
||
json={"model": "k2p5", "input": text[:500]}, # 限制长度
|
||
timeout=30.0,
|
||
)
|
||
response.raise_for_status()
|
||
result = response.json()
|
||
|
||
embedding = result["data"][0]["embedding"]
|
||
self.embedding_cache[cache_key] = embedding
|
||
return embedding
|
||
|
||
except Exception as e:
|
||
print(f"Embedding API failed: {e}")
|
||
return None
|
||
|
||
def compute_similarity(self, embedding1: list[float], embedding2: list[float]) -> float:
|
||
"""
|
||
计算两个 embedding 的余弦相似度
|
||
|
||
Args:
|
||
embedding1: 第一个向量
|
||
embedding2: 第二个向量
|
||
|
||
Returns:
|
||
相似度分数 (0-1)
|
||
"""
|
||
vec1 = np.array(embedding1)
|
||
vec2 = np.array(embedding2)
|
||
|
||
# 余弦相似度
|
||
dot_product = np.dot(vec1, vec2)
|
||
norm1 = np.linalg.norm(vec1)
|
||
norm2 = np.linalg.norm(vec2)
|
||
|
||
if norm1 == 0 or norm2 == 0:
|
||
return 0.0
|
||
|
||
return float(dot_product / (norm1 * norm2))
|
||
|
||
def get_entity_text(self, name: str, definition: str = "") -> str:
|
||
"""
|
||
构建用于 embedding 的实体文本
|
||
|
||
Args:
|
||
name: 实体名称
|
||
definition: 实体定义
|
||
|
||
Returns:
|
||
组合文本
|
||
"""
|
||
if definition:
|
||
return f"{name}: {definition}"
|
||
return name
|
||
|
||
def find_similar_entity(
|
||
self,
|
||
project_id: str,
|
||
name: str,
|
||
definition: str = "",
|
||
exclude_id: str | None = None,
|
||
threshold: float | None = None,
|
||
) -> object | None:
|
||
"""
|
||
查找相似的实体
|
||
|
||
Args:
|
||
project_id: 项目 ID
|
||
name: 实体名称
|
||
definition: 实体定义
|
||
exclude_id: 要排除的实体 ID
|
||
threshold: 相似度阈值
|
||
|
||
Returns:
|
||
相似的实体或 None
|
||
"""
|
||
if threshold is None:
|
||
threshold = self.similarity_threshold
|
||
|
||
try:
|
||
from db_manager import get_db_manager
|
||
|
||
db = get_db_manager()
|
||
except ImportError:
|
||
return None
|
||
|
||
# 获取项目的所有实体
|
||
entities = db.get_all_entities_for_embedding(project_id)
|
||
|
||
if not entities:
|
||
return None
|
||
|
||
# 获取查询实体的 embedding
|
||
query_text = self.get_entity_text(name, definition)
|
||
query_embedding = self.get_embedding(query_text)
|
||
|
||
if query_embedding is None:
|
||
# 如果 embedding API 失败,回退到简单匹配
|
||
return self._fallback_similarity_match(entities, name, exclude_id)
|
||
|
||
best_match = None
|
||
best_score = threshold
|
||
|
||
for entity in entities:
|
||
if exclude_id and entity.id == exclude_id:
|
||
continue
|
||
|
||
# 获取实体的 embedding
|
||
entity_text = self.get_entity_text(entity.name, entity.definition)
|
||
entity_embedding = self.get_embedding(entity_text)
|
||
|
||
if entity_embedding is None:
|
||
continue
|
||
|
||
# 计算相似度
|
||
similarity = self.compute_similarity(query_embedding, entity_embedding)
|
||
|
||
if similarity > best_score:
|
||
best_score = similarity
|
||
best_match = entity
|
||
|
||
return best_match
|
||
|
||
def _fallback_similarity_match(
|
||
self, entities: list[object], name: str, exclude_id: str | None = None
|
||
) -> object | None:
|
||
"""
|
||
回退到简单的相似度匹配(不使用 embedding)
|
||
|
||
Args:
|
||
entities: 实体列表
|
||
name: 查询名称
|
||
exclude_id: 要排除的实体 ID
|
||
|
||
Returns:
|
||
最相似的实体或 None
|
||
"""
|
||
name_lower = name.lower()
|
||
|
||
# 1. 精确匹配
|
||
for entity in entities:
|
||
if exclude_id and entity.id == exclude_id:
|
||
continue
|
||
if entity.name.lower() == name_lower:
|
||
return entity
|
||
if entity.aliases and name_lower in [a.lower() for a in entity.aliases]:
|
||
return entity
|
||
|
||
# 2. 包含匹配
|
||
for entity in entities:
|
||
if exclude_id and entity.id == exclude_id:
|
||
continue
|
||
if name_lower in entity.name.lower() or entity.name.lower() in name_lower:
|
||
return entity
|
||
|
||
return None
|
||
|
||
def batch_align_entities(
|
||
self, project_id: str, new_entities: list[dict], threshold: float | None = None
|
||
) -> list[dict]:
|
||
"""
|
||
批量对齐实体
|
||
|
||
Args:
|
||
project_id: 项目 ID
|
||
new_entities: 新实体列表 [{"name": "...", "definition": "..."}]
|
||
threshold: 相似度阈值
|
||
|
||
Returns:
|
||
对齐结果列表 [{"new_entity": {...}, "matched_entity": {...}, "similarity": 0.9}]
|
||
"""
|
||
if threshold is None:
|
||
threshold = self.similarity_threshold
|
||
|
||
results = []
|
||
|
||
for new_ent in new_entities:
|
||
matched = self.find_similar_entity(
|
||
project_id, new_ent["name"], new_ent.get("definition", ""), threshold=threshold
|
||
)
|
||
|
||
result = {
|
||
"new_entity": new_ent,
|
||
"matched_entity": None,
|
||
"similarity": 0.0,
|
||
"should_merge": False,
|
||
}
|
||
|
||
if matched:
|
||
# 计算相似度
|
||
query_text = self.get_entity_text(new_ent["name"], new_ent.get("definition", ""))
|
||
matched_text = self.get_entity_text(matched.name, matched.definition)
|
||
|
||
query_emb = self.get_embedding(query_text)
|
||
matched_emb = self.get_embedding(matched_text)
|
||
|
||
if query_emb and matched_emb:
|
||
similarity = self.compute_similarity(query_emb, matched_emb)
|
||
result["matched_entity"] = {
|
||
"id": matched.id,
|
||
"name": matched.name,
|
||
"type": matched.type,
|
||
"definition": matched.definition,
|
||
}
|
||
result["similarity"] = similarity
|
||
result["should_merge"] = similarity >= threshold
|
||
|
||
results.append(result)
|
||
|
||
return results
|
||
|
||
def suggest_entity_aliases(self, entity_name: str, entity_definition: str = "") -> list[str]:
|
||
"""
|
||
使用 LLM 建议实体的别名
|
||
|
||
Args:
|
||
entity_name: 实体名称
|
||
entity_definition: 实体定义
|
||
|
||
Returns:
|
||
建议的别名列表
|
||
"""
|
||
if not KIMI_API_KEY:
|
||
return []
|
||
|
||
prompt = f"""为以下实体生成可能的别名或简称:
|
||
|
||
实体名称:{entity_name}
|
||
定义:{entity_definition}
|
||
|
||
请返回 JSON 格式的别名列表:
|
||
{{"aliases": ["别名1", "别名2", "别名3"]}}
|
||
|
||
只返回 JSON,不要其他内容。"""
|
||
|
||
try:
|
||
response = httpx.post(
|
||
f"{KIMI_BASE_URL}/v1/chat/completions",
|
||
headers={
|
||
"Authorization": f"Bearer {KIMI_API_KEY}",
|
||
"Content-Type": "application/json",
|
||
},
|
||
json={
|
||
"model": "k2p5",
|
||
"messages": [{"role": "user", "content": prompt}],
|
||
"temperature": 0.3,
|
||
},
|
||
timeout=30.0,
|
||
)
|
||
response.raise_for_status()
|
||
result = response.json()
|
||
content = result["choices"][0]["message"]["content"]
|
||
|
||
import re
|
||
|
||
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
|
||
if json_match:
|
||
data = json.loads(json_match.group())
|
||
return data.get("aliases", [])
|
||
except Exception as e:
|
||
print(f"Alias suggestion failed: {e}")
|
||
|
||
return []
|
||
|
||
|
||
# Simple string similarity that does not use embeddings.
def simple_similarity(str1: str, str2: str) -> float:
    """
    Compute a cheap, case-insensitive similarity score for two strings.

    Args:
        str1: first string
        str2: second string

    Returns:
        1.0 if the strings are equal (ignoring case); 0.0 if either is empty;
        0.8 if one lower-cased string contains the other; otherwise the
        ``difflib.SequenceMatcher`` ratio of the lower-cased strings, in [0, 1].
    """
    from difflib import SequenceMatcher

    if str1 == str2:
        return 1.0

    if not str1 or not str2:
        return 0.0

    s1 = str1.lower()
    s2 = str2.lower()

    # Compare after lowercasing: "ABC" vs "abc" is a perfect match, not
    # merely a containment hit.
    if s1 == s2:
        return 1.0

    # Containment in either direction.
    if s1 in s2 or s2 in s1:
        return 0.8

    return SequenceMatcher(None, s1, s2).ratio()
|
||
|
||
|
||
if __name__ == "__main__":
    # Smoke test: try the embedding API (if configured), then the cosine
    # similarity helper on a fixed pair of vectors.
    demo = EntityAligner()

    # 1) Embedding lookup; a None/empty result prints a notice instead.
    vector = demo.get_embedding("Kubernetes 容器编排平台")
    if not vector:
        print("Embedding API not available")
    else:
        print(f"Embedding dimension: {len(vector)}")
        print(f"First 5 values: {vector[:5]}")

    # 2) Similarity of two nearly parallel 3-d vectors.
    score = demo.compute_similarity([1.0, 0.0, 0.0], [0.9, 0.1, 0.0])
    print(f"Similarity: {score:.4f}")
|