Files
insightflow/backend/entity_aligner.py
OpenClaw Bot 1a9b5391f7 fix: auto-fix code issues (cron)
- 修复重复导入/字段
- 修复异常处理
- 修复PEP8格式问题
- 添加类型注解
2026-02-28 09:15:51 +08:00

372 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
Entity Aligner - Phase 3
使用 embedding 进行实体对齐
"""
import json
import os
from dataclasses import dataclass
import httpx
import numpy as np
# API configuration, read from the environment once at import time.
# An empty KIMI_API_KEY makes EntityAligner.get_embedding() return None and
# EntityAligner.suggest_entity_aliases() return [] without any HTTP call.
KIMI_API_KEY = os.getenv("KIMI_API_KEY", "")
# Base URL to which "/v1/embeddings" and "/v1/chat/completions" are appended.
KIMI_BASE_URL = os.getenv("KIMI_BASE_URL", "https://api.kimi.com/coding")
@dataclass
class EntityEmbedding:
    """Record pairing an entity's identifying fields with an embedding vector."""

    # Identifier of the entity.
    entity_id: str
    # Entity name.
    name: str
    # Entity definition text.
    definition: str
    # Embedding vector stored as a plain list of floats.
    embedding: list[float]
class EntityAligner:
"""实体对齐器 - 使用 embedding 进行相似度匹配"""
def __init__(self, similarity_threshold: float = 0.85):
self.similarity_threshold = similarity_threshold
self.embedding_cache: dict[str, list[float]] = {}
def get_embedding(self, text: str) -> list[float] | None:
"""
使用 Kimi API 获取文本的 embedding
Args:
text: 输入文本
Returns:
embedding 向量或 None
"""
if not KIMI_API_KEY:
return None
# 检查缓存
cache_key = hash(text)
if cache_key in self.embedding_cache:
return self.embedding_cache[cache_key]
try:
response = httpx.post(
f"{KIMI_BASE_URL}/v1/embeddings",
headers={
"Authorization": f"Bearer {KIMI_API_KEY}",
"Content-Type": "application/json",
},
json={"model": "k2p5", "input": text[:500]}, # 限制长度
timeout=30.0,
)
response.raise_for_status()
result = response.json()
embedding = result["data"][0]["embedding"]
self.embedding_cache[cache_key] = embedding
return embedding
except Exception as e:
print(f"Embedding API failed: {e}")
return None
def compute_similarity(self, embedding1: list[float], embedding2: list[float]) -> float:
"""
计算两个 embedding 的余弦相似度
Args:
embedding1: 第一个向量
embedding2: 第二个向量
Returns:
相似度分数 (0-1)
"""
vec1 = np.array(embedding1)
vec2 = np.array(embedding2)
# 余弦相似度
dot_product = np.dot(vec1, vec2)
norm1 = np.linalg.norm(vec1)
norm2 = np.linalg.norm(vec2)
if norm1 == 0 or norm2 == 0:
return 0.0
return float(dot_product / (norm1 * norm2))
def get_entity_text(self, name: str, definition: str = "") -> str:
"""
构建用于 embedding 的实体文本
Args:
name: 实体名称
definition: 实体定义
Returns:
组合文本
"""
if definition:
return f"{name}: {definition}"
return name
def find_similar_entity(
self,
project_id: str,
name: str,
definition: str = "",
exclude_id: str | None = None,
threshold: float | None = None,
) -> object | None:
"""
查找相似的实体
Args:
project_id: 项目 ID
name: 实体名称
definition: 实体定义
exclude_id: 要排除的实体 ID
threshold: 相似度阈值
Returns:
相似的实体或 None
"""
if threshold is None:
threshold = self.similarity_threshold
try:
from db_manager import get_db_manager
db = get_db_manager()
except ImportError:
return None
# 获取项目的所有实体
entities = db.get_all_entities_for_embedding(project_id)
if not entities:
return None
# 获取查询实体的 embedding
query_text = self.get_entity_text(name, definition)
query_embedding = self.get_embedding(query_text)
if query_embedding is None:
# 如果 embedding API 失败,回退到简单匹配
return self._fallback_similarity_match(entities, name, exclude_id)
best_match = None
best_score = threshold
for entity in entities:
if exclude_id and entity.id == exclude_id:
continue
# 获取实体的 embedding
entity_text = self.get_entity_text(entity.name, entity.definition)
entity_embedding = self.get_embedding(entity_text)
if entity_embedding is None:
continue
# 计算相似度
similarity = self.compute_similarity(query_embedding, entity_embedding)
if similarity > best_score:
best_score = similarity
best_match = entity
return best_match
def _fallback_similarity_match(
self, entities: list[object], name: str, exclude_id: str | None = None
) -> object | None:
"""
回退到简单的相似度匹配(不使用 embedding
Args:
entities: 实体列表
name: 查询名称
exclude_id: 要排除的实体 ID
Returns:
最相似的实体或 None
"""
name_lower = name.lower()
# 1. 精确匹配
for entity in entities:
if exclude_id and entity.id == exclude_id:
continue
if entity.name.lower() == name_lower:
return entity
if entity.aliases and name_lower in [a.lower() for a in entity.aliases]:
return entity
# 2. 包含匹配
for entity in entities:
if exclude_id and entity.id == exclude_id:
continue
if name_lower in entity.name.lower() or entity.name.lower() in name_lower:
return entity
return None
def batch_align_entities(
self, project_id: str, new_entities: list[dict], threshold: float | None = None
) -> list[dict]:
"""
批量对齐实体
Args:
project_id: 项目 ID
new_entities: 新实体列表 [{"name": "...", "definition": "..."}]
threshold: 相似度阈值
Returns:
对齐结果列表 [{"new_entity": {...}, "matched_entity": {...}, "similarity": 0.9}]
"""
if threshold is None:
threshold = self.similarity_threshold
results = []
for new_ent in new_entities:
matched = self.find_similar_entity(
project_id, new_ent["name"], new_ent.get("definition", ""), threshold=threshold
)
result = {
"new_entity": new_ent,
"matched_entity": None,
"similarity": 0.0,
"should_merge": False,
}
if matched:
# 计算相似度
query_text = self.get_entity_text(new_ent["name"], new_ent.get("definition", ""))
matched_text = self.get_entity_text(matched.name, matched.definition)
query_emb = self.get_embedding(query_text)
matched_emb = self.get_embedding(matched_text)
if query_emb and matched_emb:
similarity = self.compute_similarity(query_emb, matched_emb)
result["matched_entity"] = {
"id": matched.id,
"name": matched.name,
"type": matched.type,
"definition": matched.definition,
}
result["similarity"] = similarity
result["should_merge"] = similarity >= threshold
results.append(result)
return results
def suggest_entity_aliases(self, entity_name: str, entity_definition: str = "") -> list[str]:
"""
使用 LLM 建议实体的别名
Args:
entity_name: 实体名称
entity_definition: 实体定义
Returns:
建议的别名列表
"""
if not KIMI_API_KEY:
return []
prompt = f"""为以下实体生成可能的别名或简称:
实体名称:{entity_name}
定义:{entity_definition}
请返回 JSON 格式的别名列表:
{{"aliases": ["别名1", "别名2", "别名3"]}}
只返回 JSON不要其他内容。"""
try:
response = httpx.post(
f"{KIMI_BASE_URL}/v1/chat/completions",
headers={
"Authorization": f"Bearer {KIMI_API_KEY}",
"Content-Type": "application/json",
},
json={
"model": "k2p5",
"messages": [{"role": "user", "content": prompt}],
"temperature": 0.3,
},
timeout=30.0,
)
response.raise_for_status()
result = response.json()
content = result["choices"][0]["message"]["content"]
import re
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
if json_match:
data = json.loads(json_match.group())
return data.get("aliases", [])
except Exception as e:
print(f"Alias suggestion failed: {e}")
return []
# Simple string similarity (does not use embeddings)
def simple_similarity(str1: str, str2: str) -> float:
    """Compute a simple, case-insensitive similarity of two strings.

    Args:
        str1: First string.
        str2: Second string.

    Returns:
        ``1.0`` for equal strings (ignoring case), ``0.0`` when exactly one
        of them is empty, ``0.8`` when one contains the other, otherwise the
        ``difflib.SequenceMatcher`` ratio of the lowercased strings.
    """
    from difflib import SequenceMatcher

    if not str1 or not str2:
        # Two empty inputs are identical; one empty input matches nothing.
        return 1.0 if str1 == str2 else 0.0

    s1 = str1.lower()
    s2 = str2.lower()

    # Compare after lowercasing: previously "ABC" vs "abc" fell through to
    # the containment branch and scored 0.8 instead of 1.0.
    if s1 == s2:
        return 1.0

    # One string contained in the other.
    if s1 in s2 or s2 in s1:
        return 0.8

    return SequenceMatcher(None, s1, s2).ratio()
if __name__ == "__main__":
    # Manual smoke test.
    smoke_aligner = EntityAligner()

    # Embedding lookup; get_embedding() returns None when it cannot obtain one.
    sample_vector = smoke_aligner.get_embedding("Kubernetes 容器编排平台")
    if not sample_vector:
        print("Embedding API not available")
    else:
        print(f"Embedding dimension: {len(sample_vector)}")
        print(f"First 5 values: {sample_vector[:5]}")

    # Cosine similarity of two hand-written vectors.
    score = smoke_aligner.compute_similarity([1.0, 0.0, 0.0], [0.9, 0.1, 0.0])
    print(f"Similarity: {score:.4f}")