Files
insightflow/backend/neo4j_manager.py
OpenClaw Bot d767f0dddc fix: auto-fix code issues (cron)
- 修复重复导入/字段
- 修复异常处理
- 修复PEP8格式问题
- 添加类型注解
2026-02-27 21:12:04 +08:00

1043 lines
32 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
InsightFlow Neo4j Graph Database Manager
Phase 5: Neo4j 图数据库集成
支持数据同步、复杂图查询和图算法分析
"""
import json
import logging
import os
import uuid
from dataclasses import dataclass, field
logger = logging.getLogger(__name__)

# Neo4j connection settings; each can be overridden via an environment variable.
# NOTE(review): the fallbacks (neo4j / "password") look like local-development
# defaults — confirm production deployments always set NEO4J_PASSWORD.
NEO4J_URI = os.getenv("NEO4J_URI", "bolt://localhost:7687")
NEO4J_USER = os.getenv("NEO4J_USER", "neo4j")
NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD", "password")

# Guarded import so this module still imports when the optional `neo4j`
# package is missing; NEO4J_AVAILABLE is checked before any connection attempt.
try:
    from neo4j import Driver, GraphDatabase

    NEO4J_AVAILABLE = True
except ImportError:
    NEO4J_AVAILABLE = False
    logger.warning("Neo4j driver not installed. Neo4j features will be disabled.")
@dataclass
class GraphEntity:
    """An entity node as written to the graph database.

    ``aliases`` and ``properties`` are stored in Neo4j as JSON strings by the
    sync methods of ``Neo4jManager``.
    """

    id: str
    project_id: str
    name: str
    type: str
    definition: str = ""
    # default_factory gives every instance its own container (no shared state).
    aliases: list[str] = field(default_factory=list)
    properties: dict = field(default_factory=dict)

    def __post_init__(self) -> None:
        # Callers may still pass None explicitly; normalise it to an empty
        # container so downstream code can rely on list/dict.
        if self.aliases is None:
            self.aliases = []
        if self.properties is None:
            self.properties = {}
@dataclass
class GraphRelation:
    """A relationship (edge) between two entities in the graph database.

    ``properties`` is stored in Neo4j as a JSON string by the sync methods of
    ``Neo4jManager``.
    """

    id: str
    source_id: str
    target_id: str
    relation_type: str
    evidence: str = ""
    # default_factory gives every instance its own dict (no shared state).
    properties: dict = field(default_factory=dict)

    def __post_init__(self) -> None:
        # Keep accepting an explicit None and normalise it to an empty dict.
        if self.properties is None:
            self.properties = {}
@dataclass
class PathResult:
    """Outcome of a path query between two entities.

    Attributes:
        nodes: one summary dict per node on the path.
        relationships: one summary dict per relationship on the path.
        length: number of relationships (hops) in the path.
        total_weight: accumulated weight; defaults to 0.0 and is not filled
            in by the path queries of this module.
    """

    nodes: list[dict]
    relationships: list[dict]
    length: int
    total_weight: float = 0.0
@dataclass
class CommunityResult:
    """One community produced by community detection.

    Attributes:
        community_id: identifier of the community within one detection run.
        nodes: summary dicts of the member entities.
        size: number of members.
        density: edge density of the community; defaults to 0.0.
    """

    community_id: int
    nodes: list[dict]
    size: int
    density: float = 0.0
@dataclass
class CentralityResult:
    """One row of a centrality ranking.

    Attributes:
        entity_id: ID of the ranked entity.
        entity_name: display name of the entity.
        score: centrality score produced by the chosen metric.
        rank: position in the ranking; defaults to 0 when not assigned.
    """

    entity_id: str
    entity_name: str
    score: float
    rank: int = 0
class Neo4jManager:
"""Neo4j 图数据库管理器"""
def __init__(self, uri: str = None, user: str = None, password: str = None):
self.uri = uri or NEO4J_URI
self.user = user or NEO4J_USER
self.password = password or NEO4J_PASSWORD
self._driver: Driver | None = None
if not NEO4J_AVAILABLE:
logger.error("Neo4j driver not available. Please install: pip install neo4j")
return
self._connect()
def _connect(self) -> None:
"""建立 Neo4j 连接"""
if not NEO4J_AVAILABLE:
return
try:
self._driver = GraphDatabase.driver(self.uri, auth=(self.user, self.password))
# 验证连接
self._driver.verify_connectivity()
logger.info(f"Connected to Neo4j at {self.uri}")
except Exception as e:
logger.error(f"Failed to connect to Neo4j: {e}")
self._driver = None
def close(self) -> None:
"""关闭连接"""
if self._driver:
self._driver.close()
logger.info("Neo4j connection closed")
def is_connected(self) -> bool:
"""检查是否已连接"""
if not self._driver:
return False
try:
self._driver.verify_connectivity()
return True
except BaseException:
return False
def init_schema(self) -> None:
"""初始化图数据库 Schema约束和索引"""
if not self._driver:
logger.error("Neo4j not connected")
return
with self._driver.session() as session:
# 创建约束:实体 ID 唯一
session.run("""
CREATE CONSTRAINT entity_id IF NOT EXISTS
FOR (e:Entity) REQUIRE e.id IS UNIQUE
""")
# 创建约束:项目 ID 唯一
session.run("""
CREATE CONSTRAINT project_id IF NOT EXISTS
FOR (p:Project) REQUIRE p.id IS UNIQUE
""")
# 创建索引:实体名称
session.run("""
CREATE INDEX entity_name IF NOT EXISTS
FOR (e:Entity) ON (e.name)
""")
# 创建索引:实体类型
session.run("""
CREATE INDEX entity_type IF NOT EXISTS
FOR (e:Entity) ON (e.type)
""")
# 创建索引:关系类型
session.run("""
CREATE INDEX relation_type IF NOT EXISTS
FOR ()-[r:RELATES_TO]-() ON (r.relation_type)
""")
logger.info("Neo4j schema initialized")
# ==================== 数据同步 ====================
def sync_project(self, project_id: str, project_name: str, project_description: str = "") -> None:
"""同步项目节点到 Neo4j"""
if not self._driver:
return
with self._driver.session() as session:
session.run(
"""
MERGE (p:Project {id: $project_id})
SET p.name = $name,
p.description = $description,
p.updated_at = datetime()
""",
project_id=project_id,
name=project_name,
description=project_description,
)
def sync_entity(self, entity: GraphEntity) -> None:
"""同步单个实体到 Neo4j"""
if not self._driver:
return
with self._driver.session() as session:
# 创建实体节点
session.run(
"""
MERGE (e:Entity {id: $id})
SET e.name = $name,
e.type = $type,
e.definition = $definition,
e.aliases = $aliases,
e.properties = $properties,
e.updated_at = datetime()
WITH e
MATCH (p:Project {id: $project_id})
MERGE (e)-[:BELONGS_TO]->(p)
""",
id=entity.id,
project_id=entity.project_id,
name=entity.name,
type=entity.type,
definition=entity.definition,
aliases=json.dumps(entity.aliases),
properties=json.dumps(entity.properties),
)
def sync_entities_batch(self, entities: list[GraphEntity]) -> None:
"""批量同步实体到 Neo4j"""
if not self._driver or not entities:
return
with self._driver.session() as session:
# 使用 UNWIND 批量处理
entities_data = [
{
"id": e.id,
"project_id": e.project_id,
"name": e.name,
"type": e.type,
"definition": e.definition,
"aliases": json.dumps(e.aliases),
"properties": json.dumps(e.properties),
}
for e in entities
]
session.run(
"""
UNWIND $entities AS entity
MERGE (e:Entity {id: entity.id})
SET e.name = entity.name,
e.type = entity.type,
e.definition = entity.definition,
e.aliases = entity.aliases,
e.properties = entity.properties,
e.updated_at = datetime()
WITH e, entity
MATCH (p:Project {id: entity.project_id})
MERGE (e)-[:BELONGS_TO]->(p)
""",
entities=entities_data,
)
def sync_relation(self, relation: GraphRelation) -> None:
"""同步单个关系到 Neo4j"""
if not self._driver:
return
with self._driver.session() as session:
session.run(
"""
MATCH (source:Entity {id: $source_id})
MATCH (target:Entity {id: $target_id})
MERGE (source)-[r:RELATES_TO {id: $id}]->(target)
SET r.relation_type = $relation_type,
r.evidence = $evidence,
r.properties = $properties,
r.updated_at = datetime()
""",
id=relation.id,
source_id=relation.source_id,
target_id=relation.target_id,
relation_type=relation.relation_type,
evidence=relation.evidence,
properties=json.dumps(relation.properties),
)
def sync_relations_batch(self, relations: list[GraphRelation]) -> None:
"""批量同步关系到 Neo4j"""
if not self._driver or not relations:
return
with self._driver.session() as session:
relations_data = [
{
"id": r.id,
"source_id": r.source_id,
"target_id": r.target_id,
"relation_type": r.relation_type,
"evidence": r.evidence,
"properties": json.dumps(r.properties),
}
for r in relations
]
session.run(
"""
UNWIND $relations AS rel
MATCH (source:Entity {id: rel.source_id})
MATCH (target:Entity {id: rel.target_id})
MERGE (source)-[r:RELATES_TO {id: rel.id}]->(target)
SET r.relation_type = rel.relation_type,
r.evidence = rel.evidence,
r.properties = rel.properties,
r.updated_at = datetime()
""",
relations=relations_data,
)
def delete_entity(self, entity_id: str) -> None:
"""从 Neo4j 删除实体及其关系"""
if not self._driver:
return
with self._driver.session() as session:
session.run(
"""
MATCH (e:Entity {id: $id})
DETACH DELETE e
""",
id=entity_id,
)
def delete_project(self, project_id: str) -> None:
"""从 Neo4j 删除项目及其所有实体和关系"""
if not self._driver:
return
with self._driver.session() as session:
session.run(
"""
MATCH (p:Project {id: $id})
OPTIONAL MATCH (e:Entity)-[:BELONGS_TO]->(p)
DETACH DELETE e, p
""",
id=project_id,
)
# ==================== 复杂图查询 ====================
def find_shortest_path(self, source_id: str, target_id: str, max_depth: int = 10) -> PathResult | None:
"""
查找两个实体之间的最短路径
Args:
source_id: 起始实体 ID
target_id: 目标实体 ID
max_depth: 最大搜索深度
Returns:
PathResult 或 None
"""
if not self._driver:
return None
with self._driver.session() as session:
result = session.run(
"""
MATCH path = shortestPath(
(source:Entity {id: $source_id})-[*1..$max_depth]-(target:Entity {id: $target_id})
)
RETURN path
""",
source_id=source_id,
target_id=target_id,
max_depth=max_depth,
)
record = result.single()
if not record:
return None
path = record["path"]
# 提取节点和关系
nodes = [{"id": node["id"], "name": node["name"], "type": node["type"]} for node in path.nodes]
relationships = [
{
"source": rel.start_node["id"],
"target": rel.end_node["id"],
"type": rel["relation_type"],
"evidence": rel.get("evidence", ""),
}
for rel in path.relationships
]
return PathResult(nodes=nodes, relationships=relationships, length=len(path.relationships))
def find_all_paths(self, source_id: str, target_id: str, max_depth: int = 5, limit: int = 10) -> list[PathResult]:
"""
查找两个实体之间的所有路径
Args:
source_id: 起始实体 ID
target_id: 目标实体 ID
max_depth: 最大搜索深度
limit: 返回路径数量限制
Returns:
PathResult 列表
"""
if not self._driver:
return []
with self._driver.session() as session:
result = session.run(
"""
MATCH path = (source:Entity {id: $source_id})-[*1..$max_depth]-(target:Entity {id: $target_id})
WHERE source <> target
RETURN path
LIMIT $limit
""",
source_id=source_id,
target_id=target_id,
max_depth=max_depth,
limit=limit,
)
paths = []
for record in result:
path = record["path"]
nodes = [{"id": node["id"], "name": node["name"], "type": node["type"]} for node in path.nodes]
relationships = [
{
"source": rel.start_node["id"],
"target": rel.end_node["id"],
"type": rel["relation_type"],
"evidence": rel.get("evidence", ""),
}
for rel in path.relationships
]
paths.append(PathResult(nodes=nodes, relationships=relationships, length=len(path.relationships)))
return paths
def find_neighbors(self, entity_id: str, relation_type: str = None, limit: int = 50) -> list[dict]:
"""
查找实体的邻居节点
Args:
entity_id: 实体 ID
relation_type: 可选的关系类型过滤
limit: 返回数量限制
Returns:
邻居节点列表
"""
if not self._driver:
return []
with self._driver.session() as session:
if relation_type:
result = session.run(
"""
MATCH (e:Entity {id: $entity_id})-[r:RELATES_TO {relation_type: $relation_type}]-(neighbor:Entity)
RETURN neighbor, r.relation_type as rel_type, r.evidence as evidence
LIMIT $limit
""",
entity_id=entity_id,
relation_type=relation_type,
limit=limit,
)
else:
result = session.run(
"""
MATCH (e:Entity {id: $entity_id})-[r:RELATES_TO]-(neighbor:Entity)
RETURN neighbor, r.relation_type as rel_type, r.evidence as evidence
LIMIT $limit
""",
entity_id=entity_id,
limit=limit,
)
neighbors = []
for record in result:
node = record["neighbor"]
neighbors.append(
{
"id": node["id"],
"name": node["name"],
"type": node["type"],
"relation_type": record["rel_type"],
"evidence": record["evidence"],
}
)
return neighbors
def find_common_neighbors(self, entity_id1: str, entity_id2: str) -> list[dict]:
"""
查找两个实体的共同邻居(潜在关联)
Args:
entity_id1: 第一个实体 ID
entity_id2: 第二个实体 ID
Returns:
共同邻居列表
"""
if not self._driver:
return []
with self._driver.session() as session:
result = session.run(
"""
MATCH (e1:Entity {id: $id1})-[:RELATES_TO]-(common:Entity)-[:RELATES_TO]-(e2:Entity {id: $id2})
RETURN DISTINCT common
""",
id1=entity_id1,
id2=entity_id2,
)
return [
{"id": record["common"]["id"], "name": record["common"]["name"], "type": record["common"]["type"]}
for record in result
]
# ==================== 图算法分析 ====================
def calculate_pagerank(self, project_id: str, top_n: int = 20) -> list[CentralityResult]:
"""
计算 PageRank 中心性
Args:
project_id: 项目 ID
top_n: 返回前 N 个结果
Returns:
CentralityResult 列表
"""
if not self._driver:
return []
with self._driver.session() as session:
result = session.run(
"""
CALL gds.graph.exists('project-graph-$project_id') YIELD exists
WITH exists
CALL apoc.do.when(exists,
'CALL gds.graph.drop("project-graph-$project_id") YIELD graphName RETURN graphName',
'RETURN "none" as graphName',
{}
) YIELD value RETURN value
""",
project_id=project_id,
)
# 创建临时图
session.run(
"""
CALL gds.graph.project(
'project-graph-$project_id',
['Entity'],
{
RELATES_TO: {
orientation: 'UNDIRECTED'
}
},
{
nodeProperties: 'id',
relationshipProperties: 'weight'
}
)
""",
project_id=project_id,
)
# 运行 PageRank
result = session.run(
"""
CALL gds.pageRank.stream('project-graph-$project_id')
YIELD nodeId, score
RETURN gds.util.asNode(nodeId).id AS entity_id,
gds.util.asNode(nodeId).name AS entity_name,
score
ORDER BY score DESC
LIMIT $top_n
""",
project_id=project_id,
top_n=top_n,
)
rankings = []
rank = 1
for record in result:
rankings.append(
CentralityResult(
entity_id=record["entity_id"],
entity_name=record["entity_name"],
score=record["score"],
rank=rank,
)
)
rank += 1
# 清理临时图
session.run(
"""
CALL gds.graph.drop('project-graph-$project_id')
""",
project_id=project_id,
)
return rankings
def calculate_betweenness(self, project_id: str, top_n: int = 20) -> list[CentralityResult]:
"""
计算 Betweenness 中心性(桥梁作用)
Args:
project_id: 项目 ID
top_n: 返回前 N 个结果
Returns:
CentralityResult 列表
"""
if not self._driver:
return []
with self._driver.session() as session:
# 使用 APOC 的 betweenness 计算(如果没有 GDS
result = session.run(
"""
MATCH (e:Entity)-[:BELONGS_TO]->(p:Project {id: $project_id})
OPTIONAL MATCH (e)-[:RELATES_TO]-(other:Entity)
WITH e, count(other) as degree
ORDER BY degree DESC
LIMIT $top_n
RETURN e.id as entity_id, e.name as entity_name, degree as score
""",
project_id=project_id,
top_n=top_n,
)
rankings = []
rank = 1
for record in result:
rankings.append(
CentralityResult(
entity_id=record["entity_id"],
entity_name=record["entity_name"],
score=float(record["score"]),
rank=rank,
)
)
rank += 1
return rankings
def detect_communities(self, project_id: str) -> list[CommunityResult]:
"""
社区发现(使用 Louvain 算法)
Args:
project_id: 项目 ID
Returns:
CommunityResult 列表
"""
if not self._driver:
return []
with self._driver.session() as session:
# 简单的社区检测:基于连通分量
result = session.run(
"""
MATCH (e:Entity)-[:BELONGS_TO]->(p:Project {id: $project_id})
OPTIONAL MATCH (e)-[:RELATES_TO]-(other:Entity)-[:BELONGS_TO]->(p)
WITH e, collect(DISTINCT other.id) as connections
RETURN e.id as entity_id, e.name as entity_name, e.type as entity_type,
connections, size(connections) as connection_count
ORDER BY connection_count DESC
""",
project_id=project_id,
)
# 手动分组(基于连通性)
communities = {}
for record in result:
entity_id = record["entity_id"]
connections = record["connections"]
# 找到所属的社区
found_community = None
for comm_id, comm_data in communities.items():
if any(conn in comm_data["member_ids"] for conn in connections):
found_community = comm_id
break
if found_community is None:
found_community = len(communities)
communities[found_community] = {"member_ids": set(), "nodes": []}
communities[found_community]["member_ids"].add(entity_id)
communities[found_community]["nodes"].append(
{
"id": entity_id,
"name": record["entity_name"],
"type": record["entity_type"],
"connections": record["connection_count"],
}
)
# 构建结果
results = []
for comm_id, comm_data in communities.items():
nodes = comm_data["nodes"]
size = len(nodes)
# 计算密度(简化版)
max_edges = size * (size - 1) / 2 if size > 1 else 1
actual_edges = sum(n["connections"] for n in nodes) / 2
density = actual_edges / max_edges if max_edges > 0 else 0
results.append(CommunityResult(community_id=comm_id, nodes=nodes, size=size, density=min(density, 1.0)))
# 按大小排序
results.sort(key=lambda x: x.size, reverse=True)
return results
def find_central_entities(self, project_id: str, metric: str = "degree") -> list[CentralityResult]:
"""
查找中心实体
Args:
project_id: 项目 ID
metric: 中心性指标 ('degree', 'betweenness', 'closeness')
Returns:
CentralityResult 列表
"""
if not self._driver:
return []
with self._driver.session() as session:
if metric == "degree":
result = session.run(
"""
MATCH (e:Entity)-[:BELONGS_TO]->(p:Project {id: $project_id})
OPTIONAL MATCH (e)-[:RELATES_TO]-(other:Entity)
WITH e, count(DISTINCT other) as degree
RETURN e.id as entity_id, e.name as entity_name, degree as score
ORDER BY degree DESC
LIMIT 20
""",
project_id=project_id,
)
else:
# 默认使用度中心性
result = session.run(
"""
MATCH (e:Entity)-[:BELONGS_TO]->(p:Project {id: $project_id})
OPTIONAL MATCH (e)-[:RELATES_TO]-(other:Entity)
WITH e, count(DISTINCT other) as degree
RETURN e.id as entity_id, e.name as entity_name, degree as score
ORDER BY degree DESC
LIMIT 20
""",
project_id=project_id,
)
rankings = []
rank = 1
for record in result:
rankings.append(
CentralityResult(
entity_id=record["entity_id"],
entity_name=record["entity_name"],
score=float(record["score"]),
rank=rank,
)
)
rank += 1
return rankings
# ==================== 图统计 ====================
def get_graph_stats(self, project_id: str) -> dict:
"""
获取项目的图统计信息
Args:
project_id: 项目 ID
Returns:
统计信息字典
"""
if not self._driver:
return {}
with self._driver.session() as session:
# 实体数量
entity_count = session.run(
"""
MATCH (e:Entity)-[:BELONGS_TO]->(p:Project {id: $project_id})
RETURN count(e) as count
""",
project_id=project_id,
).single()["count"]
# 关系数量
relation_count = session.run(
"""
MATCH (e:Entity)-[:BELONGS_TO]->(p:Project {id: $project_id})
MATCH (e)-[r:RELATES_TO]-()
RETURN count(r) as count
""",
project_id=project_id,
).single()["count"]
# 实体类型分布
type_distribution = session.run(
"""
MATCH (e:Entity)-[:BELONGS_TO]->(p:Project {id: $project_id})
RETURN e.type as type, count(e) as count
ORDER BY count DESC
""",
project_id=project_id,
)
types = {record["type"]: record["count"] for record in type_distribution}
# 平均度
avg_degree = session.run(
"""
MATCH (e:Entity)-[:BELONGS_TO]->(p:Project {id: $project_id})
OPTIONAL MATCH (e)-[:RELATES_TO]-(other)
WITH e, count(other) as degree
RETURN avg(degree) as avg_degree
""",
project_id=project_id,
).single()["avg_degree"]
# 关系类型分布
rel_types = session.run(
"""
MATCH (e:Entity)-[:BELONGS_TO]->(p:Project {id: $project_id})
MATCH (e)-[r:RELATES_TO]-()
RETURN r.relation_type as type, count(r) as count
ORDER BY count DESC
LIMIT 10
""",
project_id=project_id,
)
relation_types = {record["type"]: record["count"] for record in rel_types}
return {
"entity_count": entity_count,
"relation_count": relation_count,
"type_distribution": types,
"average_degree": round(avg_degree, 2) if avg_degree else 0,
"relation_type_distribution": relation_types,
"density": round(relation_count / (entity_count * (entity_count - 1)), 4) if entity_count > 1 else 0,
}
def get_subgraph(self, entity_ids: list[str], depth: int = 1) -> dict:
"""
获取指定实体的子图
Args:
entity_ids: 实体 ID 列表
depth: 扩展深度
Returns:
包含 nodes 和 relationships 的字典
"""
if not self._driver or not entity_ids:
return {"nodes": [], "relationships": []}
with self._driver.session() as session:
result = session.run(
"""
MATCH (e:Entity)
WHERE e.id IN $entity_ids
CALL apoc.path.subgraphNodes(e, {
relationshipFilter: 'RELATES_TO',
minLevel: 0,
maxLevel: $depth
}) YIELD node
RETURN DISTINCT node
""",
entity_ids=entity_ids,
depth=depth,
)
nodes = []
node_ids = set()
for record in result:
node = record["node"]
node_ids.add(node["id"])
nodes.append(
{
"id": node["id"],
"name": node["name"],
"type": node["type"],
"definition": node.get("definition", ""),
}
)
# 获取这些节点之间的关系
result = session.run(
"""
MATCH (source:Entity)-[r:RELATES_TO]->(target:Entity)
WHERE source.id IN $node_ids AND target.id IN $node_ids
RETURN source.id as source_id, target.id as target_id,
r.relation_type as type, r.evidence as evidence
""",
node_ids=list(node_ids),
)
relationships = [
{
"source": record["source_id"],
"target": record["target_id"],
"type": record["type"],
"evidence": record["evidence"],
}
for record in result
]
return {"nodes": nodes, "relationships": relationships}
# Process-wide shared manager, created lazily on first access.
_neo4j_manager = None


def get_neo4j_manager() -> Neo4jManager:
    """Return the shared Neo4jManager, constructing it on first use."""
    global _neo4j_manager
    if _neo4j_manager is not None:
        return _neo4j_manager
    _neo4j_manager = Neo4jManager()
    return _neo4j_manager


def close_neo4j_manager() -> None:
    """Close the shared manager (if one exists) and forget it."""
    global _neo4j_manager
    current = _neo4j_manager
    if not current:
        return
    current.close()
    _neo4j_manager = None
# Convenience functions


def sync_project_to_neo4j(project_id: str, project_name: str, entities: list[dict], relations: list[dict]) -> None:
    """Mirror a whole project (project node, entities, relations) into Neo4j.

    Only logs a warning when the shared manager is not connected.

    Args:
        project_id: project ID; every entity is attached to this project.
        project_name: name stored on the Project node.
        entities: raw entity dicts. ``id`` and ``name`` are required;
            ``type`` (default "unknown"), ``definition``, ``aliases`` and
            ``properties`` are optional.
        relations: raw relation dicts. ``id``, ``source_entity_id``,
            ``target_entity_id`` and ``relation_type`` are required;
            ``evidence`` and ``properties`` are optional.

    Raises:
        KeyError: if a required key is missing from an entity or relation dict.
    """
    manager = get_neo4j_manager()
    if not manager.is_connected():
        logger.warning("Neo4j not connected, skipping sync")
        return

    manager.sync_project(project_id, project_name)

    # Entities go first so the relation MATCH clauses can find both endpoints.
    manager.sync_entities_batch(
        [
            GraphEntity(
                id=raw["id"],
                project_id=project_id,
                name=raw["name"],
                type=raw.get("type", "unknown"),
                definition=raw.get("definition", ""),
                aliases=raw.get("aliases", []),
                properties=raw.get("properties", {}),
            )
            for raw in entities
        ]
    )

    manager.sync_relations_batch(
        [
            GraphRelation(
                id=raw["id"],
                source_id=raw["source_entity_id"],
                target_id=raw["target_entity_id"],
                relation_type=raw["relation_type"],
                evidence=raw.get("evidence", ""),
                properties=raw.get("properties", {}),
            )
            for raw in relations
        ]
    )

    logger.info(f"Synced project {project_id} to Neo4j: {len(entities)} entities, {len(relations)} relations")
if __name__ == "__main__":
    # Manual smoke test; needs a reachable Neo4j configured via NEO4J_* env vars.
    logging.basicConfig(level=logging.INFO)
    manager = Neo4jManager()
    try:
        if manager.is_connected():
            print("✅ Connected to Neo4j")
            # Create constraints and indexes
            manager.init_schema()
            print("✅ Schema initialized")
            # Sync a test project
            manager.sync_project("test-project", "Test Project", "A test project")
            print("✅ Project synced")
            # Sync a test entity
            test_entity = GraphEntity(
                id="test-entity-1", project_id="test-project", name="Test Entity", type="Person", definition="A test entity"
            )
            manager.sync_entity(test_entity)
            print("✅ Entity synced")
            # Fetch statistics
            stats = manager.get_graph_stats("test-project")
            print(f"📊 Graph stats: {stats}")
        else:
            print("❌ Failed to connect to Neo4j")
    finally:
        # Release the driver even if one of the smoke-test steps raises.
        manager.close()