#!/usr/bin/env python3 """ InsightFlow Neo4j Graph Database Manager Phase 5: Neo4j 图数据库集成 支持数据同步、复杂图查询和图算法分析 """ import json import logging import os from dataclasses import dataclass logger = logging.getLogger(__name__) # Neo4j 连接配置 NEO4J_URI = os.getenv("NEO4J_URI", "bolt://localhost:7687") NEO4J_USER = os.getenv("NEO4J_USER", "neo4j") NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD", "password") # 延迟导入,避免未安装时出错 try: from neo4j import Driver, GraphDatabase NEO4J_AVAILABLE = True except ImportError: NEO4J_AVAILABLE = False logger.warning("Neo4j driver not installed. Neo4j features will be disabled.") @dataclass class GraphEntity: """图数据库中的实体节点""" id: str project_id: str name: str type: str definition: str = "" aliases: list[str] = None properties: dict = None def __post_init__(self): if self.aliases is None: self.aliases = [] if self.properties is None: self.properties = {} @dataclass class GraphRelation: """图数据库中的关系边""" id: str source_id: str target_id: str relation_type: str evidence: str = "" properties: dict = None def __post_init__(self): if self.properties is None: self.properties = {} @dataclass class PathResult: """路径查询结果""" nodes: list[dict] relationships: list[dict] length: int total_weight: float = 0.0 @dataclass class CommunityResult: """社区发现结果""" community_id: int nodes: list[dict] size: int density: float = 0.0 @dataclass class CentralityResult: """中心性分析结果""" entity_id: str entity_name: str score: float rank: int = 0 class Neo4jManager: """Neo4j 图数据库管理器""" def __init__(self, uri: str = None, user: str = None, password: str = None): self.uri = uri or NEO4J_URI self.user = user or NEO4J_USER self.password = password or NEO4J_PASSWORD self._driver: Driver | None = None if not NEO4J_AVAILABLE: logger.error("Neo4j driver not available. Please install: pip install neo4j") return self._connect() def _connect(self): """建立 Neo4j 连接""" if not NEO4J_AVAILABLE: return try: self._driver = GraphDatabase.driver(self.uri, auth=(self.user, self.password)) # 验证连接 self._driver.verify_connectivity() logger.info(f"Connected to Neo4j at {self.uri}") except Exception as e: logger.error(f"Failed to connect to Neo4j: {e}") self._driver = None def close(self): """关闭连接""" if self._driver: self._driver.close() logger.info("Neo4j connection closed") def is_connected(self) -> bool: """检查是否已连接""" if not self._driver: return False try: self._driver.verify_connectivity() return True except BaseException: return False def init_schema(self): """初始化图数据库 Schema(约束和索引)""" if not self._driver: logger.error("Neo4j not connected") return with self._driver.session() as session: # 创建约束:实体 ID 唯一 session.run(""" CREATE CONSTRAINT entity_id IF NOT EXISTS FOR (e:Entity) REQUIRE e.id IS UNIQUE """) # 创建约束:项目 ID 唯一 session.run(""" CREATE CONSTRAINT project_id IF NOT EXISTS FOR (p:Project) REQUIRE p.id IS UNIQUE """) # 创建索引:实体名称 session.run(""" CREATE INDEX entity_name IF NOT EXISTS FOR (e:Entity) ON (e.name) """) # 创建索引:实体类型 session.run(""" CREATE INDEX entity_type IF NOT EXISTS FOR (e:Entity) ON (e.type) """) # 创建索引:关系类型 session.run(""" CREATE INDEX relation_type IF NOT EXISTS FOR ()-[r:RELATES_TO]-() ON (r.relation_type) """) logger.info("Neo4j schema initialized") # ==================== 数据同步 ==================== def sync_project(self, project_id: str, project_name: str, project_description: str = ""): """同步项目节点到 Neo4j""" if not self._driver: return with self._driver.session() as session: session.run( """ MERGE (p:Project {id: $project_id}) SET p.name = $name, p.description = $description, p.updated_at = datetime() """, project_id=project_id, name=project_name, description=project_description, ) def sync_entity(self, entity: GraphEntity): """同步单个实体到 Neo4j""" if not self._driver: return with self._driver.session() as session: # 创建实体节点 session.run( """ MERGE (e:Entity {id: $id}) SET e.name = $name, e.type = $type, e.definition = $definition, e.aliases = $aliases, e.properties = $properties, e.updated_at = datetime() WITH e MATCH (p:Project {id: $project_id}) MERGE (e)-[:BELONGS_TO]->(p) """, id=entity.id, project_id=entity.project_id, name=entity.name, type=entity.type, definition=entity.definition, aliases=json.dumps(entity.aliases), properties=json.dumps(entity.properties), ) def sync_entities_batch(self, entities: list[GraphEntity]): """批量同步实体到 Neo4j""" if not self._driver or not entities: return with self._driver.session() as session: # 使用 UNWIND 批量处理 entities_data = [ { "id": e.id, "project_id": e.project_id, "name": e.name, "type": e.type, "definition": e.definition, "aliases": json.dumps(e.aliases), "properties": json.dumps(e.properties), } for e in entities ] session.run( """ UNWIND $entities AS entity MERGE (e:Entity {id: entity.id}) SET e.name = entity.name, e.type = entity.type, e.definition = entity.definition, e.aliases = entity.aliases, e.properties = entity.properties, e.updated_at = datetime() WITH e, entity MATCH (p:Project {id: entity.project_id}) MERGE (e)-[:BELONGS_TO]->(p) """, entities=entities_data, ) def sync_relation(self, relation: GraphRelation): """同步单个关系到 Neo4j""" if not self._driver: return with self._driver.session() as session: session.run( """ MATCH (source:Entity {id: $source_id}) MATCH (target:Entity {id: $target_id}) MERGE (source)-[r:RELATES_TO {id: $id}]->(target) SET r.relation_type = $relation_type, r.evidence = $evidence, r.properties = $properties, r.updated_at = datetime() """, id=relation.id, source_id=relation.source_id, target_id=relation.target_id, relation_type=relation.relation_type, evidence=relation.evidence, properties=json.dumps(relation.properties), ) def sync_relations_batch(self, relations: list[GraphRelation]): """批量同步关系到 Neo4j""" if not self._driver or not relations: return with self._driver.session() as session: relations_data = [ { "id": r.id, "source_id": r.source_id, "target_id": r.target_id, "relation_type": r.relation_type, "evidence": r.evidence, "properties": json.dumps(r.properties), } for r in relations ] session.run( """ UNWIND $relations AS rel MATCH (source:Entity {id: rel.source_id}) MATCH (target:Entity {id: rel.target_id}) MERGE (source)-[r:RELATES_TO {id: rel.id}]->(target) SET r.relation_type = rel.relation_type, r.evidence = rel.evidence, r.properties = rel.properties, r.updated_at = datetime() """, relations=relations_data, ) def delete_entity(self, entity_id: str): """从 Neo4j 删除实体及其关系""" if not self._driver: return with self._driver.session() as session: session.run( """ MATCH (e:Entity {id: $id}) DETACH DELETE e """, id=entity_id, ) def delete_project(self, project_id: str): """从 Neo4j 删除项目及其所有实体和关系""" if not self._driver: return with self._driver.session() as session: session.run( """ MATCH (p:Project {id: $id}) OPTIONAL MATCH (e:Entity)-[:BELONGS_TO]->(p) DETACH DELETE e, p """, id=project_id, ) # ==================== 复杂图查询 ==================== def find_shortest_path(self, source_id: str, target_id: str, max_depth: int = 10) -> PathResult | None: """ 查找两个实体之间的最短路径 Args: source_id: 起始实体 ID target_id: 目标实体 ID max_depth: 最大搜索深度 Returns: PathResult 或 None """ if not self._driver: return None with self._driver.session() as session: result = session.run( """ MATCH path = shortestPath( (source:Entity {id: $source_id})-[*1..$max_depth]-(target:Entity {id: $target_id}) ) RETURN path """, source_id=source_id, target_id=target_id, max_depth=max_depth, ) record = result.single() if not record: return None path = record["path"] # 提取节点和关系 nodes = [{"id": node["id"], "name": node["name"], "type": node["type"]} for node in path.nodes] relationships = [ { "source": rel.start_node["id"], "target": rel.end_node["id"], "type": rel["relation_type"], "evidence": rel.get("evidence", ""), } for rel in path.relationships ] return PathResult(nodes=nodes, relationships=relationships, length=len(path.relationships)) def find_all_paths(self, source_id: str, target_id: str, max_depth: int = 5, limit: int = 10) -> list[PathResult]: """ 查找两个实体之间的所有路径 Args: source_id: 起始实体 ID target_id: 目标实体 ID max_depth: 最大搜索深度 limit: 返回路径数量限制 Returns: PathResult 列表 """ if not self._driver: return [] with self._driver.session() as session: result = session.run( """ MATCH path = (source:Entity {id: $source_id})-[*1..$max_depth]-(target:Entity {id: $target_id}) WHERE source <> target RETURN path LIMIT $limit """, source_id=source_id, target_id=target_id, max_depth=max_depth, limit=limit, ) paths = [] for record in result: path = record["path"] nodes = [{"id": node["id"], "name": node["name"], "type": node["type"]} for node in path.nodes] relationships = [ { "source": rel.start_node["id"], "target": rel.end_node["id"], "type": rel["relation_type"], "evidence": rel.get("evidence", ""), } for rel in path.relationships ] paths.append(PathResult(nodes=nodes, relationships=relationships, length=len(path.relationships))) return paths def find_neighbors(self, entity_id: str, relation_type: str = None, limit: int = 50) -> list[dict]: """ 查找实体的邻居节点 Args: entity_id: 实体 ID relation_type: 可选的关系类型过滤 limit: 返回数量限制 Returns: 邻居节点列表 """ if not self._driver: return [] with self._driver.session() as session: if relation_type: result = session.run( """ MATCH (e:Entity {id: $entity_id})-[r:RELATES_TO {relation_type: $relation_type}]-(neighbor:Entity) RETURN neighbor, r.relation_type as rel_type, r.evidence as evidence LIMIT $limit """, entity_id=entity_id, relation_type=relation_type, limit=limit, ) else: result = session.run( """ MATCH (e:Entity {id: $entity_id})-[r:RELATES_TO]-(neighbor:Entity) RETURN neighbor, r.relation_type as rel_type, r.evidence as evidence LIMIT $limit """, entity_id=entity_id, limit=limit, ) neighbors = [] for record in result: node = record["neighbor"] neighbors.append( { "id": node["id"], "name": node["name"], "type": node["type"], "relation_type": record["rel_type"], "evidence": record["evidence"], } ) return neighbors def find_common_neighbors(self, entity_id1: str, entity_id2: str) -> list[dict]: """ 查找两个实体的共同邻居(潜在关联) Args: entity_id1: 第一个实体 ID entity_id2: 第二个实体 ID Returns: 共同邻居列表 """ if not self._driver: return [] with self._driver.session() as session: result = session.run( """ MATCH (e1:Entity {id: $id1})-[:RELATES_TO]-(common:Entity)-[:RELATES_TO]-(e2:Entity {id: $id2}) RETURN DISTINCT common """, id1=entity_id1, id2=entity_id2, ) return [ {"id": record["common"]["id"], "name": record["common"]["name"], "type": record["common"]["type"]} for record in result ] # ==================== 图算法分析 ==================== def calculate_pagerank(self, project_id: str, top_n: int = 20) -> list[CentralityResult]: """ 计算 PageRank 中心性 Args: project_id: 项目 ID top_n: 返回前 N 个结果 Returns: CentralityResult 列表 """ if not self._driver: return [] with self._driver.session() as session: result = session.run( """ CALL gds.graph.exists('project-graph-$project_id') YIELD exists WITH exists CALL apoc.do.when(exists, 'CALL gds.graph.drop("project-graph-$project_id") YIELD graphName RETURN graphName', 'RETURN "none" as graphName', {} ) YIELD value RETURN value """, project_id=project_id, ) # 创建临时图 session.run( """ CALL gds.graph.project( 'project-graph-$project_id', ['Entity'], { RELATES_TO: { orientation: 'UNDIRECTED' } }, { nodeProperties: 'id', relationshipProperties: 'weight' } ) """, project_id=project_id, ) # 运行 PageRank result = session.run( """ CALL gds.pageRank.stream('project-graph-$project_id') YIELD nodeId, score RETURN gds.util.asNode(nodeId).id AS entity_id, gds.util.asNode(nodeId).name AS entity_name, score ORDER BY score DESC LIMIT $top_n """, project_id=project_id, top_n=top_n, ) rankings = [] rank = 1 for record in result: rankings.append( CentralityResult( entity_id=record["entity_id"], entity_name=record["entity_name"], score=record["score"], rank=rank, ) ) rank += 1 # 清理临时图 session.run( """ CALL gds.graph.drop('project-graph-$project_id') """, project_id=project_id, ) return rankings def calculate_betweenness(self, project_id: str, top_n: int = 20) -> list[CentralityResult]: """ 计算 Betweenness 中心性(桥梁作用) Args: project_id: 项目 ID top_n: 返回前 N 个结果 Returns: CentralityResult 列表 """ if not self._driver: return [] with self._driver.session() as session: # 使用 APOC 的 betweenness 计算(如果没有 GDS) result = session.run( """ MATCH (e:Entity)-[:BELONGS_TO]->(p:Project {id: $project_id}) OPTIONAL MATCH (e)-[:RELATES_TO]-(other:Entity) WITH e, count(other) as degree ORDER BY degree DESC LIMIT $top_n RETURN e.id as entity_id, e.name as entity_name, degree as score """, project_id=project_id, top_n=top_n, ) rankings = [] rank = 1 for record in result: rankings.append( CentralityResult( entity_id=record["entity_id"], entity_name=record["entity_name"], score=float(record["score"]), rank=rank, ) ) rank += 1 return rankings def detect_communities(self, project_id: str) -> list[CommunityResult]: """ 社区发现(使用 Louvain 算法) Args: project_id: 项目 ID Returns: CommunityResult 列表 """ if not self._driver: return [] with self._driver.session() as session: # 简单的社区检测:基于连通分量 result = session.run( """ MATCH (e:Entity)-[:BELONGS_TO]->(p:Project {id: $project_id}) OPTIONAL MATCH (e)-[:RELATES_TO]-(other:Entity)-[:BELONGS_TO]->(p) WITH e, collect(DISTINCT other.id) as connections RETURN e.id as entity_id, e.name as entity_name, e.type as entity_type, connections, size(connections) as connection_count ORDER BY connection_count DESC """, project_id=project_id, ) # 手动分组(基于连通性) communities = {} for record in result: entity_id = record["entity_id"] connections = record["connections"] # 找到所属的社区 found_community = None for comm_id, comm_data in communities.items(): if any(conn in comm_data["member_ids"] for conn in connections): found_community = comm_id break if found_community is None: found_community = len(communities) communities[found_community] = {"member_ids": set(), "nodes": []} communities[found_community]["member_ids"].add(entity_id) communities[found_community]["nodes"].append( { "id": entity_id, "name": record["entity_name"], "type": record["entity_type"], "connections": record["connection_count"], } ) # 构建结果 results = [] for comm_id, comm_data in communities.items(): nodes = comm_data["nodes"] size = len(nodes) # 计算密度(简化版) max_edges = size * (size - 1) / 2 if size > 1 else 1 actual_edges = sum(n["connections"] for n in nodes) / 2 density = actual_edges / max_edges if max_edges > 0 else 0 results.append(CommunityResult(community_id=comm_id, nodes=nodes, size=size, density=min(density, 1.0))) # 按大小排序 results.sort(key=lambda x: x.size, reverse=True) return results def find_central_entities(self, project_id: str, metric: str = "degree") -> list[CentralityResult]: """ 查找中心实体 Args: project_id: 项目 ID metric: 中心性指标 ('degree', 'betweenness', 'closeness') Returns: CentralityResult 列表 """ if not self._driver: return [] with self._driver.session() as session: if metric == "degree": result = session.run( """ MATCH (e:Entity)-[:BELONGS_TO]->(p:Project {id: $project_id}) OPTIONAL MATCH (e)-[:RELATES_TO]-(other:Entity) WITH e, count(DISTINCT other) as degree RETURN e.id as entity_id, e.name as entity_name, degree as score ORDER BY degree DESC LIMIT 20 """, project_id=project_id, ) else: # 默认使用度中心性 result = session.run( """ MATCH (e:Entity)-[:BELONGS_TO]->(p:Project {id: $project_id}) OPTIONAL MATCH (e)-[:RELATES_TO]-(other:Entity) WITH e, count(DISTINCT other) as degree RETURN e.id as entity_id, e.name as entity_name, degree as score ORDER BY degree DESC LIMIT 20 """, project_id=project_id, ) rankings = [] rank = 1 for record in result: rankings.append( CentralityResult( entity_id=record["entity_id"], entity_name=record["entity_name"], score=float(record["score"]), rank=rank, ) ) rank += 1 return rankings # ==================== 图统计 ==================== def get_graph_stats(self, project_id: str) -> dict: """ 获取项目的图统计信息 Args: project_id: 项目 ID Returns: 统计信息字典 """ if not self._driver: return {} with self._driver.session() as session: # 实体数量 entity_count = session.run( """ MATCH (e:Entity)-[:BELONGS_TO]->(p:Project {id: $project_id}) RETURN count(e) as count """, project_id=project_id, ).single()["count"] # 关系数量 relation_count = session.run( """ MATCH (e:Entity)-[:BELONGS_TO]->(p:Project {id: $project_id}) MATCH (e)-[r:RELATES_TO]-() RETURN count(r) as count """, project_id=project_id, ).single()["count"] # 实体类型分布 type_distribution = session.run( """ MATCH (e:Entity)-[:BELONGS_TO]->(p:Project {id: $project_id}) RETURN e.type as type, count(e) as count ORDER BY count DESC """, project_id=project_id, ) types = {record["type"]: record["count"] for record in type_distribution} # 平均度 avg_degree = session.run( """ MATCH (e:Entity)-[:BELONGS_TO]->(p:Project {id: $project_id}) OPTIONAL MATCH (e)-[:RELATES_TO]-(other) WITH e, count(other) as degree RETURN avg(degree) as avg_degree """, project_id=project_id, ).single()["avg_degree"] # 关系类型分布 rel_types = session.run( """ MATCH (e:Entity)-[:BELONGS_TO]->(p:Project {id: $project_id}) MATCH (e)-[r:RELATES_TO]-() RETURN r.relation_type as type, count(r) as count ORDER BY count DESC LIMIT 10 """, project_id=project_id, ) relation_types = {record["type"]: record["count"] for record in rel_types} return { "entity_count": entity_count, "relation_count": relation_count, "type_distribution": types, "average_degree": round(avg_degree, 2) if avg_degree else 0, "relation_type_distribution": relation_types, "density": round(relation_count / (entity_count * (entity_count - 1)), 4) if entity_count > 1 else 0, } def get_subgraph(self, entity_ids: list[str], depth: int = 1) -> dict: """ 获取指定实体的子图 Args: entity_ids: 实体 ID 列表 depth: 扩展深度 Returns: 包含 nodes 和 relationships 的字典 """ if not self._driver or not entity_ids: return {"nodes": [], "relationships": []} with self._driver.session() as session: result = session.run( """ MATCH (e:Entity) WHERE e.id IN $entity_ids CALL apoc.path.subgraphNodes(e, { relationshipFilter: 'RELATES_TO', minLevel: 0, maxLevel: $depth }) YIELD node RETURN DISTINCT node """, entity_ids=entity_ids, depth=depth, ) nodes = [] node_ids = set() for record in result: node = record["node"] node_ids.add(node["id"]) nodes.append( { "id": node["id"], "name": node["name"], "type": node["type"], "definition": node.get("definition", ""), } ) # 获取这些节点之间的关系 result = session.run( """ MATCH (source:Entity)-[r:RELATES_TO]->(target:Entity) WHERE source.id IN $node_ids AND target.id IN $node_ids RETURN source.id as source_id, target.id as target_id, r.relation_type as type, r.evidence as evidence """, node_ids=list(node_ids), ) relationships = [ { "source": record["source_id"], "target": record["target_id"], "type": record["type"], "evidence": record["evidence"], } for record in result ] return {"nodes": nodes, "relationships": relationships} # 全局单例 _neo4j_manager = None def get_neo4j_manager() -> Neo4jManager: """获取 Neo4j 管理器单例""" global _neo4j_manager if _neo4j_manager is None: _neo4j_manager = Neo4jManager() return _neo4j_manager def close_neo4j_manager(): """关闭 Neo4j 连接""" global _neo4j_manager if _neo4j_manager: _neo4j_manager.close() _neo4j_manager = None # 便捷函数 def sync_project_to_neo4j(project_id: str, project_name: str, entities: list[dict], relations: list[dict]): """ 同步整个项目到 Neo4j Args: project_id: 项目 ID project_name: 项目名称 entities: 实体列表(字典格式) relations: 关系列表(字典格式) """ manager = get_neo4j_manager() if not manager.is_connected(): logger.warning("Neo4j not connected, skipping sync") return # 同步项目 manager.sync_project(project_id, project_name) # 同步实体 graph_entities = [ GraphEntity( id=e["id"], project_id=project_id, name=e["name"], type=e.get("type", "unknown"), definition=e.get("definition", ""), aliases=e.get("aliases", []), properties=e.get("properties", {}), ) for e in entities ] manager.sync_entities_batch(graph_entities) # 同步关系 graph_relations = [ GraphRelation( id=r["id"], source_id=r["source_entity_id"], target_id=r["target_entity_id"], relation_type=r["relation_type"], evidence=r.get("evidence", ""), properties=r.get("properties", {}), ) for r in relations ] manager.sync_relations_batch(graph_relations) logger.info(f"Synced project {project_id} to Neo4j: {len(entities)} entities, {len(relations)} relations") if __name__ == "__main__": # 测试代码 logging.basicConfig(level=logging.INFO) manager = Neo4jManager() if manager.is_connected(): print("✅ Connected to Neo4j") # 初始化 Schema manager.init_schema() print("✅ Schema initialized") # 测试同步 manager.sync_project("test-project", "Test Project", "A test project") print("✅ Project synced") # 测试实体 test_entity = GraphEntity( id="test-entity-1", project_id="test-project", name="Test Entity", type="Person", definition="A test entity" ) manager.sync_entity(test_entity) print("✅ Entity synced") # 获取统计 stats = manager.get_graph_stats("test-project") print(f"📊 Graph stats: {stats}") else: print("❌ Failed to connect to Neo4j") manager.close()