"""
InsightFlow Export Module - Phase 5

Exports knowledge graphs, project reports, entity data and transcript text.
"""
|
||
|
||
import base64
import csv
import io
import json
import math
from dataclasses import dataclass
from datetime import datetime
from typing import Any
from xml.sax.saxutils import escape

try:
    import pandas as pd

    PANDAS_AVAILABLE = True
except ImportError:
    PANDAS_AVAILABLE = False

try:
    from reportlab.lib import colors
    from reportlab.lib.pagesizes import A4
    from reportlab.lib.styles import ParagraphStyle, getSampleStyleSheet
    from reportlab.lib.units import inch
    from reportlab.platypus import (
        PageBreak,
        Paragraph,
        SimpleDocTemplate,
        Spacer,
        Table,
        TableStyle,
    )

    REPORTLAB_AVAILABLE = True
except ImportError:
    REPORTLAB_AVAILABLE = False
|
||
|
||
|
||
@dataclass
class ExportEntity:
    """Plain record describing one entity handed to the exporters."""

    # Identifier; relations refer to entities through this value.
    id: str
    # Display name.
    name: str
    # Entity type label (used as the colour key in the SVG export).
    type: str
    # Free-text definition.
    definition: str
    # Alternative names; joined with ", " in the tabular exports.
    aliases: list[str]
    # Number of mentions; the PDF report sorts entities by it.
    mention_count: int
    # Extra attributes; exported as additional "属性:<name>" columns.
    attributes: dict[str, Any]
|
||
|
||
|
||
@dataclass
class ExportRelation:
    """Plain record describing one directed relation between two entities."""

    # Identifier of the relation.
    id: str
    # Id of the entity the relation starts from.
    source: str
    # Id of the entity the relation points to.
    target: str
    # Relation label; drawn on the edge in the SVG export.
    relation_type: str
    # Confidence value; the PDF report prints it with two decimals.
    confidence: float
    # Supporting evidence text.
    evidence: str
|
||
|
||
|
||
@dataclass
class ExportTranscript:
    """Plain record describing one transcript (or document) to export."""

    # Identifier of the transcript.
    id: str
    # Title; used as the Markdown heading.
    name: str
    type: str  # audio/document
    # Full text content.
    content: str
    # Segment dicts; the Markdown export reads "speaker", "start", "end", "text".
    segments: list[dict]
    # Mention dicts; the Markdown export reads "entity_id", "entity_name",
    # "position" and "context".
    entity_mentions: list[dict]
|
||
|
||
|
||
class ExportManager:
    """Export manager that renders project data into several output formats.

    Provides knowledge-graph SVG/PNG rendering, entity Excel/CSV export,
    relation CSV export, transcript Markdown export, a PDF project report and
    a JSON project dump. Every export method works only on the objects passed
    to it; ``self.db`` is stored by the constructor and not read by any method
    of this class.
    """

    # Node fill colour per entity type; "default" is the fallback colour for
    # any type that is not listed and is not shown in the legend.
    _TYPE_COLORS = {
        "PERSON": "#FF6B6B",
        "ORGANIZATION": "#4ECDC4",
        "LOCATION": "#45B7D1",
        "PRODUCT": "#96CEB4",
        "TECHNOLOGY": "#FFEAA7",
        "EVENT": "#DDA0DD",
        "CONCEPT": "#98D8C8",
        "default": "#BDC3C7",
    }

    def __init__(self, db_manager=None):
        """Store the optional database manager."""
        self.db = db_manager

    @staticmethod
    def _md_cell(value) -> str:
        """Return ``value`` as text that is safe inside a one-line Markdown table cell."""
        text = str(value)
        # A raw pipe would start a new column and a newline would end the row.
        return text.replace("|", "\\|").replace("\r\n", " ").replace("\n", " ")

    @staticmethod
    def _report_table_style(header_font_size, align, valign=None):
        """Build the table style shared by the tables of the PDF report.

        Args:
            header_font_size: Font size of the header row.
            align: Horizontal alignment applied to every cell.
            valign: Optional vertical alignment applied to every cell.

        Returns:
            A reportlab ``TableStyle``.
        """
        commands = [
            ("BACKGROUND", (0, 0), (-1, 0), colors.HexColor("#34495e")),
            ("TEXTCOLOR", (0, 0), (-1, 0), colors.whitesmoke),
            ("ALIGN", (0, 0), (-1, -1), align),
            ("FONTNAME", (0, 0), (-1, 0), "Helvetica-Bold"),
            ("FONTSIZE", (0, 0), (-1, 0), header_font_size),
            ("BOTTOMPADDING", (0, 0), (-1, 0), 12),
            ("BACKGROUND", (0, 1), (-1, -1), colors.HexColor("#ecf0f1")),
            ("GRID", (0, 0), (-1, -1), 1, colors.HexColor("#bdc3c7")),
        ]
        if valign is not None:
            commands.append(("VALIGN", (0, 0), (-1, -1), valign))
        return TableStyle(commands)

    def export_knowledge_graph_svg(
        self, project_id: str, entities: list[ExportEntity], relations: list[ExportRelation]
    ) -> str:
        """Export the knowledge graph as an SVG document.

        Entities are spread evenly on a circle around the canvas centre,
        relations whose endpoints are both known are drawn as labelled arrows,
        and a legend of the known entity types is placed on the right. All
        user-provided text is XML-escaped before it is embedded.

        Args:
            project_id: Shown in the title line of the graphic.
            entities: Entities to draw as nodes.
            relations: Relations to draw as edges; a relation is skipped when
                its source or target id is not among ``entities``.

        Returns:
            The SVG markup as a string.
        """
        # Layout parameters
        width = 1200
        height = 800
        center_x = width / 2
        center_y = height / 2
        radius = 300
        type_colors = self._TYPE_COLORS

        # Place the nodes on a circle so every node stays inside the viewBox.
        angle_step = 2 * math.pi / max(len(entities), 1)
        entity_positions = {}
        for i, entity in enumerate(entities):
            angle = i * angle_step
            entity_positions[entity.id] = (
                center_x + radius * math.cos(angle),
                center_y + radius * math.sin(angle),
            )

        # SVG header, arrow marker, background and title
        svg_parts = [
            f'<svg xmlns="http://www.w3.org/2000/svg" width="{width}" height="{height}" '
            f'viewBox="0 0 {width} {height}">',
            "<defs>",
            '  <marker id="arrowhead" markerWidth="10" markerHeight="7" '
            'refX="9" refY="3.5" orient="auto">',
            '    <polygon points="0 0, 10 3.5, 0 7" fill="#7f8c8d"/>',
            "  </marker>",
            "</defs>",
            f'<rect width="{width}" height="{height}" fill="#f8f9fa"/>',
            f'<text x="{center_x}" y="30" text-anchor="middle" font-size="20" '
            f'font-weight="bold" fill="#2c3e50">知识图谱 - {escape(str(project_id))}</text>',
        ]

        # Relation edges
        for rel in relations:
            if rel.source not in entity_positions or rel.target not in entity_positions:
                continue
            x1, y1 = entity_positions[rel.source]
            x2, y2 = entity_positions[rel.target]

            # Pull the arrow tip back so it does not cover the target node.
            dx = x2 - x1
            dy = y2 - y1
            dist = math.hypot(dx, dy)
            if dist > 0:
                offset = 40
                x2 = x2 - dx * offset / dist
                y2 = y2 - dy * offset / dist

            svg_parts.append(
                f'<line x1="{x1:.1f}" y1="{y1:.1f}" x2="{x2:.1f}" y2="{y2:.1f}" '
                f'stroke="#7f8c8d" stroke-width="2" marker-end="url(#arrowhead)" opacity="0.6"/>'
            )

            # Relation label at the middle of the edge
            mid_x = (x1 + x2) / 2
            mid_y = (y1 + y2) / 2
            svg_parts.append(
                f'<rect x="{mid_x - 30:.1f}" y="{mid_y - 10:.1f}" width="60" height="20" '
                f'fill="white" stroke="#bdc3c7" rx="3"/>'
            )
            svg_parts.append(
                f'<text x="{mid_x:.1f}" y="{mid_y + 5:.1f}" text-anchor="middle" '
                f'font-size="10" fill="#2c3e50">{escape(str(rel.relation_type))}</text>'
            )

        # Entity nodes
        for entity in entities:
            x, y = entity_positions[entity.id]
            color = type_colors.get(entity.type, type_colors["default"])

            # Node circle
            svg_parts.append(
                f'<circle cx="{x:.1f}" cy="{y:.1f}" r="35" fill="{color}" '
                f'stroke="white" stroke-width="3"/>'
            )

            # Entity name, truncated to 8 characters before escaping
            short_name = escape((entity.name or "")[:8])
            svg_parts.append(
                f'<text x="{x:.1f}" y="{y + 5:.1f}" text-anchor="middle" font-size="12" '
                f'font-weight="bold" fill="white">{short_name}</text>'
            )

            # Entity type below the node
            svg_parts.append(
                f'<text x="{x:.1f}" y="{y + 55:.1f}" text-anchor="middle" font-size="10" '
                f'fill="#7f8c8d">{escape(str(entity.type))}</text>'
            )

        # Legend
        legend_x = width - 150
        legend_y = 80
        rect_x = legend_x - 10
        rect_y = legend_y - 20
        rect_height = len(type_colors) * 25 + 10
        svg_parts.append(
            f'<rect x="{rect_x}" y="{rect_y}" width="140" height="{rect_height}" '
            f'fill="white" stroke="#bdc3c7" rx="5"/>'
        )
        svg_parts.append(
            f'<text x="{legend_x}" y="{legend_y}" font-size="12" font-weight="bold" '
            f'fill="#2c3e50">实体类型</text>'
        )

        legend_items = [(etype, color) for etype, color in type_colors.items() if etype != "default"]
        for i, (etype, color) in enumerate(legend_items):
            y_pos = legend_y + 25 + i * 20
            svg_parts.append(f'<circle cx="{legend_x + 10}" cy="{y_pos}" r="8" fill="{color}"/>')
            svg_parts.append(
                f'<text x="{legend_x + 25}" y="{y_pos + 4}" font-size="10" '
                f'fill="#2c3e50">{etype}</text>'
            )

        svg_parts.append("</svg>")
        return "\n".join(svg_parts)

    def export_knowledge_graph_png(
        self, project_id: str, entities: list[ExportEntity], relations: list[ExportRelation]
    ) -> bytes:
        """Export the knowledge graph as a PNG image.

        The SVG export is rendered to PNG with ``cairosvg``.

        Returns:
            PNG image bytes. When ``cairosvg`` cannot be loaded, the
            base64-encoded SVG document is returned instead (kept as a
            best-effort fallback; that value is NOT PNG data).
        """
        svg_bytes = self.export_knowledge_graph_svg(project_id, entities, relations).encode("utf-8")
        try:
            import cairosvg
        except (ImportError, OSError):
            # ImportError: cairosvg is not installed. OSError: it is installed
            # but its native cairo library could not be loaded.
            return base64.b64encode(svg_bytes)
        return cairosvg.svg2png(bytestring=svg_bytes)

    def export_entities_excel(self, entities: list[ExportEntity]) -> bytes:
        """Export entity data as an Excel workbook.

        One row is written per entity into the sheet "实体列表"; every
        attribute becomes an additional "属性:<name>" column.

        Returns:
            The bytes of the Excel file.

        Raises:
            ImportError: If pandas is not available.
        """
        if not PANDAS_AVAILABLE:
            raise ImportError("pandas is required for Excel export")

        # Build one row dict per entity
        data = []
        for e in entities:
            row = {
                "ID": e.id,
                "名称": e.name,
                "类型": e.type,
                "定义": e.definition,
                "别名": ", ".join(e.aliases or []),
                "提及次数": e.mention_count,
            }
            for attr_name, attr_value in (e.attributes or {}).items():
                row[f"属性:{attr_name}"] = attr_value
            data.append(row)

        df = pd.DataFrame(data)

        # Write the workbook into memory
        output = io.BytesIO()
        with pd.ExcelWriter(output, engine="openpyxl") as writer:
            df.to_excel(writer, sheet_name="实体列表", index=False)

            # Fit each column to its longest value, capped at 50
            worksheet = writer.sheets["实体列表"]
            for column in worksheet.columns:
                max_length = max(
                    (len(str(cell.value)) for cell in column if cell.value is not None),
                    default=0,
                )
                column_letter = column[0].column_letter
                worksheet.column_dimensions[column_letter].width = min(max_length + 2, 50)

        return output.getvalue()

    def export_entities_csv(self, entities: list[ExportEntity]) -> str:
        """Export entity data as CSV.

        The header holds the fixed columns followed by one "属性:<name>"
        column per attribute name found on any entity, in sorted order.
        Entities that lack an attribute get an empty cell.

        Returns:
            The CSV text.
        """
        output = io.StringIO()

        # Collect every attribute name once; the sorted list fixes the column order.
        all_attrs = set()
        for e in entities:
            all_attrs.update((e.attributes or {}).keys())
        attr_names = sorted(all_attrs)

        headers = ["ID", "名称", "类型", "定义", "别名", "提及次数"] + [
            f"属性:{a}" for a in attr_names
        ]

        writer = csv.writer(output)
        writer.writerow(headers)

        for e in entities:
            attributes = e.attributes or {}
            row = [e.id, e.name, e.type, e.definition, ", ".join(e.aliases or []), e.mention_count]
            row.extend(attributes.get(attr, "") for attr in attr_names)
            writer.writerow(row)

        return output.getvalue()

    def export_relations_csv(self, relations: list[ExportRelation]) -> str:
        """Export relation data as CSV.

        Returns:
            The CSV text with one header row and one row per relation.
        """
        output = io.StringIO()
        writer = csv.writer(output)
        writer.writerow(["ID", "源实体", "目标实体", "关系类型", "置信度", "证据"])
        writer.writerows(
            [r.id, r.source, r.target, r.relation_type, r.confidence, r.evidence] for r in relations
        )
        return output.getvalue()

    def export_transcript_markdown(
        self, transcript: ExportTranscript, entities_map: dict[str, ExportEntity]
    ) -> str:
        """Export a transcript as Markdown.

        The document contains the transcript header and content, an optional
        per-segment section and an optional table of entity mentions.

        Args:
            transcript: The transcript to render.
            entities_map: Entities keyed by id, used to resolve the name and
                type of each mention. Unresolved mentions fall back to the
                mention's own "entity_name" and the type "Unknown".

        Returns:
            The Markdown text.
        """
        lines = [
            f"# {transcript.name}",
            "",
            f"**类型**: {transcript.type}",
            f"**ID**: {transcript.id}",
            "",
            "---",
            "",
            "## 内容",
            "",
            transcript.content or "",
            "",
            "---",
            "",
        ]

        if transcript.segments:
            lines.extend(["## 分段详情", ""])
            for seg in transcript.segments:
                speaker = seg.get("speaker", "Unknown")
                # A missing or None timestamp is rendered as 0.0.
                start = float(seg.get("start") or 0)
                end = float(seg.get("end") or 0)
                text = seg.get("text", "")
                lines.append(f"**[{start:.1f}s - {end:.1f}s] {speaker}**: {text}")
                lines.append("")

        if transcript.entity_mentions:
            lines.extend(
                [
                    "",
                    "## 实体提及",
                    "",
                    "| 实体 | 类型 | 位置 | 上下文 |",
                    "|------|------|------|--------|",
                ]
            )
            for mention in transcript.entity_mentions:
                entity = entities_map.get(mention.get("entity_id", ""))
                entity_name = entity.name if entity else mention.get("entity_name", "Unknown")
                entity_type = entity.type if entity else "Unknown"
                position = mention.get("position", "")
                raw_context = mention.get("context") or ""
                # Only mark the context as shortened when it actually was.
                context = (raw_context[:50] + "...") if len(raw_context) > 50 else raw_context
                cells = [self._md_cell(c) for c in (entity_name, entity_type, position, context)]
                lines.append("| " + " | ".join(cells) + " |")

        return "\n".join(lines)

    def export_project_report_pdf(
        self,
        project_id: str,
        project_name: str,
        entities: list[ExportEntity],
        relations: list[ExportRelation],
        transcripts: list[ExportTranscript],
        summary: str = "",
    ) -> bytes:
        """Export a project report as PDF.

        The report holds a title section, a statistics table, the optional
        summary, the 50 most-mentioned entities and the first 100 relations.

        Args:
            project_id: Accepted for interface symmetry; not used in the report.
            project_name: Shown on the title section.
            entities: Entities to count and list.
            relations: Relations to count and list.
            transcripts: Only their number is reported.
            summary: Optional summary text; newlines become line breaks.

        Returns:
            The bytes of the PDF file.

        Raises:
            ImportError: If reportlab is not available.
        """
        if not REPORTLAB_AVAILABLE:
            raise ImportError("reportlab is required for PDF export")

        output = io.BytesIO()
        doc = SimpleDocTemplate(
            output, pagesize=A4, rightMargin=72, leftMargin=72, topMargin=72, bottomMargin=18
        )

        # Styles
        # NOTE(review): the sample styles and "Helvetica-Bold" are Latin fonts;
        # the Chinese labels below presumably need a registered CJK font to
        # render as glyphs - confirm with a generated report.
        styles = getSampleStyleSheet()
        title_style = ParagraphStyle(
            "CustomTitle",
            parent=styles["Heading1"],
            fontSize=24,
            spaceAfter=30,
            textColor=colors.HexColor("#2c3e50"),
        )
        heading_style = ParagraphStyle(
            "CustomHeading",
            parent=styles["Heading2"],
            fontSize=16,
            spaceAfter=12,
            textColor=colors.HexColor("#34495e"),
        )

        story = []

        # Title section. Paragraph text is parsed as XML-like markup, so
        # caller-provided text is escaped first.
        story.append(Paragraph("InsightFlow 项目报告", title_style))
        story.append(Paragraph(f"项目名称: {escape(str(project_name))}", styles["Heading2"]))
        story.append(
            Paragraph(
                f"生成时间: {datetime.now().strftime('%Y-%m-%d %H:%M')}",
                styles["Normal"],
            )
        )
        story.append(Spacer(1, 0.3 * inch))

        # Statistics overview
        story.append(Paragraph("项目概览", heading_style))
        stats_data = [
            ["指标", "数值"],
            ["实体数量", str(len(entities))],
            ["关系数量", str(len(relations))],
            ["文档数量", str(len(transcripts))],
        ]

        # Entity count per type
        type_counts = {}
        for e in entities:
            type_counts[e.type] = type_counts.get(e.type, 0) + 1
        for etype, count in sorted(type_counts.items()):
            stats_data.append([f"{etype} 实体", str(count)])

        stats_table = Table(stats_data, colWidths=[3 * inch, 2 * inch])
        stats_table.setStyle(self._report_table_style(12, "CENTER"))
        story.append(stats_table)
        story.append(Spacer(1, 0.3 * inch))

        # Project summary
        if summary:
            story.append(Paragraph("项目总结", heading_style))
            story.append(Paragraph(escape(summary).replace("\n", "<br/>"), styles["Normal"]))
            story.append(Spacer(1, 0.3 * inch))

        # Entity list
        if entities:
            story.append(PageBreak())
            story.append(Paragraph("实体列表", heading_style))

            entity_data = [["名称", "类型", "提及次数", "定义"]]
            top_entities = sorted(entities, key=lambda x: x.mention_count, reverse=True)[:50]
            for e in top_entities:  # limited to the first 50
                definition = e.definition or ""
                if len(definition) > 100:
                    definition = definition[:100] + "..."
                entity_data.append([e.name, e.type, str(e.mention_count), definition])

            entity_table = Table(
                entity_data, colWidths=[1.5 * inch, 1 * inch, 1 * inch, 2.5 * inch]
            )
            entity_table.setStyle(self._report_table_style(10, "LEFT", valign="TOP"))
            story.append(entity_table)

        # Relation list
        if relations:
            story.append(PageBreak())
            story.append(Paragraph("关系列表", heading_style))

            relation_data = [["源实体", "关系", "目标实体", "置信度"]]
            for r in relations[:100]:  # limited to the first 100
                relation_data.append([r.source, r.relation_type, r.target, f"{r.confidence:.2f}"])

            relation_table = Table(
                relation_data, colWidths=[2 * inch, 1.5 * inch, 2 * inch, 1 * inch]
            )
            relation_table.setStyle(self._report_table_style(10, "LEFT"))
            story.append(relation_table)

        doc.build(story)
        return output.getvalue()

    def export_project_json(
        self,
        project_id: str,
        project_name: str,
        entities: list[ExportEntity],
        relations: list[ExportRelation],
        transcripts: list[ExportTranscript],
    ) -> str:
        """Export the project data as JSON.

        The document holds the project id and name, the export time and the
        entity, relation and transcript records. Transcripts are written
        without their entity mentions.

        Returns:
            The JSON text (UTF-8 characters kept as-is, indented by 2).
        """
        data = {
            "project_id": project_id,
            "project_name": project_name,
            "export_time": datetime.now().isoformat(),
            "entities": [
                {
                    "id": e.id,
                    "name": e.name,
                    "type": e.type,
                    "definition": e.definition,
                    "aliases": e.aliases,
                    "mention_count": e.mention_count,
                    "attributes": e.attributes,
                }
                for e in entities
            ],
            "relations": [
                {
                    "id": r.id,
                    "source": r.source,
                    "target": r.target,
                    "relation_type": r.relation_type,
                    "confidence": r.confidence,
                    "evidence": r.evidence,
                }
                for r in relations
            ],
            "transcripts": [
                {
                    "id": t.id,
                    "name": t.name,
                    "type": t.type,
                    "content": t.content,
                    "segments": t.segments,
                }
                for t in transcripts
            ],
        }

        return json.dumps(data, ensure_ascii=False, indent=2)
|
||
|
||
|
||
# Module-level export manager instance, created lazily by get_export_manager().
_export_manager = None


def get_export_manager(db_manager=None) -> ExportManager:
    """Return the shared ExportManager instance, creating it on first use.

    Args:
        db_manager: Passed to ``ExportManager`` only when the shared instance
            is created; it is ignored on every later call.

    Returns:
        The module-level ``ExportManager`` instance.
    """
    global _export_manager
    if _export_manager is None:
        _export_manager = ExportManager(db_manager)
    return _export_manager
|