546 lines
16 KiB
Python
546 lines
16 KiB
Python
#!/usr/bin/env python3
"""
InsightFlow Image Processor - Phase 7

图片处理模块:识别白板、PPT、手写笔记等内容
(Image processing module: recognizes whiteboards, PPT slides, handwritten notes and similar content.)
"""

import base64
import io
import os
import re
import uuid
from dataclasses import dataclass

# 尝试导入图像处理库
try:
    from PIL import Image, ImageEnhance, ImageFilter

    PIL_AVAILABLE = True
except ImportError:
    PIL_AVAILABLE = False

try:
    import cv2
    import numpy as np

    CV2_AVAILABLE = True
except ImportError:
    CV2_AVAILABLE = False

try:
    import pytesseract

    PYTESSERACT_AVAILABLE = True
except ImportError:
    PYTESSERACT_AVAILABLE = False


@dataclass
|
||
class ImageEntity:
|
||
"""图片中检测到的实体"""
|
||
|
||
name: str
|
||
type: str
|
||
confidence: float
|
||
bbox: tuple[int, int, int, int] | None = None # (x, y, width, height)
|
||
|
||
@dataclass
class ImageRelation:
    """A directed link between two entity names found in one image (图片中检测到的关系)."""

    # Name of the entity the relation starts from.
    source: str
    # Name of the entity the relation points to.
    target: str
    # Relation label, e.g. "related".
    relation_type: str
    # Heuristic certainty score attached by the extractor.
    confidence: float


@dataclass
class ImageProcessingResult:
    """Outcome of processing one image (图片处理结果).

    ``success`` tells whether the pipeline finished; when it is False,
    ``error_message`` carries the reason and the content fields are empty.
    """

    # Identifier assigned to the image.
    image_id: str
    # Detected category: whiteboard, ppt, handwritten, screenshot, document or other.
    image_type: str
    # Raw text recognised by OCR.
    ocr_text: str
    # Human-readable summary of the image.
    description: str
    # Entities extracted from the OCR text.
    entities: list[ImageEntity]
    # Relations between the extracted entities.
    relations: list[ImageRelation]
    # Original pixel width.
    width: int
    # Original pixel height.
    height: int
    # Whether processing completed without error.
    success: bool
    # Failure reason; empty string on success.
    error_message: str = ""


@dataclass
class BatchProcessingResult:
    """Aggregate outcome of processing several images (批量图片处理结果)."""

    # Per-image results, in input order.
    results: list[ImageProcessingResult]
    # Number of images processed.
    total_count: int
    # Number of results with success=True.
    success_count: int
    # Number of results with success=False.
    failed_count: int


class ImageProcessor:
    """图片处理器 - 处理各种类型图片 (processor for many kinds of images).

    Handles whiteboard photos, slides, handwritten notes, screenshots and
    document images. ``process_image`` runs the full pipeline:
    OCR -> image-type detection -> rule-based entity extraction ->
    text description -> sentence co-occurrence relations.

    Optional libraries are probed at import time (``PIL_AVAILABLE``,
    ``CV2_AVAILABLE``, ``PYTESSERACT_AVAILABLE``); each method falls back to
    a neutral result when the library it needs is missing.
    """

    # Supported image types -> display labels used in generated descriptions.
    IMAGE_TYPES = {
        "whiteboard": "白板",
        "ppt": "PPT/演示文稿",
        "handwritten": "手写笔记",
        "screenshot": "屏幕截图",
        "document": "文档图片",
        "other": "其他",
    }

    # Longest image edge (pixels) handed to OCR; larger images are downscaled first.
    MAX_OCR_EDGE = 4096

    # Modes normalised to RGB before OCR/enhancement: alpha, palette and CMYK
    # images break some PIL filters (e.g. filtering a palette image raises).
    _RGB_CONVERT_MODES = ("RGBA", "LA", "P", "CMYK")

    # OCR keywords hinting at a slide. NOTE: "第" / "页" alone are weak signals
    # because they also occur in ordinary Chinese prose.
    _PPT_KEYWORDS = ("slide", "page", "第", "页")
    # OCR keywords hinting at a UI screenshot.
    _SCREENSHOT_KEYWORDS = ("button", "menu", "click", "登录", "确定", "取消")

    # Project names: a quoted phrase on a single line, or a run of >= 2 capitalised words.
    _PROJECT_PATTERN = re.compile(r'["\']([^"\'\n]+)["\']|([A-Z][a-zA-Z0-9]*(?:\s+[A-Z][a-zA-Z0-9]*)+)')
    # Chinese person names (2-4 CJK chars) followed by a title or honorific.
    _PERSON_PATTERN = re.compile(r"([\u4e00-\u9fa5]{2,4})(?:先生|女士|总|经理|工程师|老师)")

    # Technical terms recognised verbatim.
    _TECH_KEYWORDS = (
        "K8s",
        "Kubernetes",
        "Docker",
        "API",
        "SDK",
        "AI",
        "ML",
        "Python",
        "Java",
        "React",
        "Vue",
        "Node.js",
        "数据库",
        "服务器",
    )
    # ASCII terms must not be glued to other letters/digits, so "AI" does not
    # match inside "EMAIL", "ML" inside "HTML" or "Java" inside "JavaScript".
    # CJK terms have no word separators and are matched as plain substrings.
    _TECH_PATTERNS = tuple(
        (
            keyword,
            re.compile(
                rf"(?<![A-Za-z0-9]){re.escape(keyword)}(?![A-Za-z0-9])"
                if keyword.isascii()
                else re.escape(keyword)
            ),
        )
        for keyword in _TECH_KEYWORDS
    )

    # Sentence boundaries: CJK/ASCII "!" and "?", "。", and "." only when it is
    # followed by whitespace or end of text, so "Node.js" or "v1.2" stay intact.
    _SENTENCE_BOUNDARY = re.compile(r"[。！？!?]|\.(?=\s|$)")

    def __init__(self, temp_dir: str | None = None) -> None:
        """Create a processor.

        Args:
            temp_dir: Directory where ``process_image`` stores images. Defaults
                to ``<cwd>/temp/images``. It is created if missing.
        """
        self.temp_dir = temp_dir or os.path.join(os.getcwd(), "temp", "images")
        os.makedirs(self.temp_dir, exist_ok=True)

    def preprocess_image(self, image, image_type: str | None = None):
        """Prepare a PIL image for OCR (预处理图片以提高OCR质量).

        Steps:

        1. Convert RGBA/LA/P/CMYK images to RGB.
        2. Downscale so the longest edge is at most ``MAX_OCR_EDGE``.
        3. Apply the enhancement for ``image_type``:
           - whiteboard: grey, contrast and binarise;
           - handwritten: grey, blur and contrast;
           - screenshot: sharpen.

        Args:
            image: PIL ``Image`` object.
            image_type: Optional key of ``IMAGE_TYPES`` selecting the enhancement.

        Returns:
            The processed image. If PIL is unavailable, returns the input
            unchanged. If an unexpected error occurs, it is printed and the
            image as processed up to that point is returned.
        """
        if not PIL_AVAILABLE:
            return image

        try:
            if image.mode in self._RGB_CONVERT_MODES:
                image = image.convert("RGB")

            # Downscale before enhancing: filtering a huge image is wasted work.
            longest_edge = max(image.size)
            if longest_edge > self.MAX_OCR_EDGE:
                ratio = self.MAX_OCR_EDGE / longest_edge
                # max(1, ...) keeps extremely thin images from collapsing to 0 px.
                new_size = tuple(max(1, int(side * ratio)) for side in image.size)
                image = image.resize(new_size, Image.Resampling.LANCZOS)

            if image_type == "whiteboard":
                image = self._enhance_whiteboard(image)
            elif image_type == "handwritten":
                image = self._enhance_handwritten(image)
            elif image_type == "screenshot":
                image = image.filter(ImageFilter.SHARPEN)

            return image
        except Exception as e:
            print(f"Image preprocessing error: {e}")
            return image

    def _enhance_whiteboard(self, image, threshold: int = 128):
        """Return a greyscale, high-contrast, binarised ("L" mode, 0/255) copy of a whiteboard image.

        Args:
            image: PIL ``Image`` object.
            threshold: Grey level below which a pixel becomes black after the
                contrast boost (default 128).
        """
        gray = image.convert("L")
        enhanced = ImageEnhance.Contrast(gray).enhance(2.0)
        binary = enhanced.point(lambda x: 0 if x < threshold else 255, "1")
        return binary.convert("L")

    def _enhance_handwritten(self, image):
        """Return a greyscale, lightly blurred (noise reduction), contrast-boosted copy of a handwritten-note image."""
        gray = image.convert("L")
        blurred = gray.filter(ImageFilter.GaussianBlur(radius=1))
        return ImageEnhance.Contrast(blurred).enhance(1.5)

    def detect_image_type(self, image, ocr_text: str = "") -> str:
        """Guess the image category from its geometry and OCR text (自动检测图片类型).

        Heuristics, checked in order:

        1. Slide: the aspect ratio is 1.3-1.8 and the text contains a slide keyword.
        2. Whiteboard (OpenCV only): Canny edge density above 5% and more than
           50 characters of text.
        3. Handwritten: more than 100 characters and aspect ratio below 1.5.
        4. Screenshot: the text contains a UI keyword.
        5. Document: more than 200 characters.
        6. Otherwise "other".

        Args:
            image: PIL ``Image`` object.
            ocr_text: Text previously recognised in the image.

        Returns:
            A key of ``IMAGE_TYPES``. Returns "other" when PIL is unavailable
            or on error.
        """
        if not PIL_AVAILABLE:
            return "other"

        try:
            width, height = image.size
            aspect_ratio = width / height
            lowered = ocr_text.lower()

            # Slides are usually 16:9 (~1.78) or 4:3 (~1.33).
            if 1.3 <= aspect_ratio <= 1.8 and any(kw in lowered for kw in self._PPT_KEYWORDS):
                return "ppt"

            if CV2_AVAILABLE:
                gray = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2GRAY)
                # Whiteboards tend to contain many strokes, arrows and boxes -> many edges.
                edges = cv2.Canny(gray, 50, 150)
                edge_ratio = np.count_nonzero(edges) / edges.size
                if edge_ratio > 0.05 and len(ocr_text) > 50:
                    return "whiteboard"

            if len(ocr_text) > 100 and aspect_ratio < 1.5:
                return "handwritten"

            if any(kw in lowered for kw in self._SCREENSHOT_KEYWORDS):
                return "screenshot"

            if len(ocr_text) > 200:
                return "document"

            return "other"
        except Exception as e:
            print(f"Image type detection error: {e}")
            return "other"

    @staticmethod
    def _mean_confidence(raw_confidences) -> float:
        """Average tesseract word confidences (0-100 scale) into a 0.0-1.0 score.

        Entries that are not numeric, or are <= 0, are skipped. Tesseract
        reports -1 for boxes that are not words. Values may arrive as ints,
        floats or numeric strings depending on the pytesseract/tesseract
        version, so they are parsed with ``float`` (``int("95.5")`` would raise).
        """
        scores = []
        for raw in raw_confidences:
            try:
                score = float(raw)
            except (TypeError, ValueError):
                continue
            if score > 0:
                scores.append(score)
        return sum(scores) / len(scores) / 100.0 if scores else 0.0

    def perform_ocr(self, image, lang: str = "chi_sim+eng", image_type: str | None = None) -> tuple[str, float]:
        """Run tesseract OCR on an image (对图片进行OCR识别).

        Args:
            image: PIL ``Image`` object.
            lang: Tesseract language spec used for both the text and the
                confidence pass.
            image_type: Optional type passed to ``preprocess_image`` for
                type-specific enhancement (default: none).

        Returns:
            ``(text, confidence)``: the stripped text and the mean word
            confidence in 0.0-1.0. Returns ``("", 0.0)`` when pytesseract is
            unavailable or OCR fails; failures are printed.
        """
        if not PYTESSERACT_AVAILABLE:
            return "", 0.0

        try:
            processed_image = self.preprocess_image(image, image_type)
            text = pytesseract.image_to_string(processed_image, lang=lang)
            # Same language model as the text pass, so the scores describe that recognition.
            data = pytesseract.image_to_data(processed_image, lang=lang, output_type=pytesseract.Output.DICT)
            return text.strip(), self._mean_confidence(data.get("conf", []))
        except Exception as e:
            print(f"OCR error: {e}")
            return "", 0.0

    @classmethod
    def _find_tech_keywords(cls, text: str) -> list[str]:
        """Return the known technical terms present in *text*, in ``_TECH_KEYWORDS`` order."""
        return [keyword for keyword, pattern in cls._TECH_PATTERNS if pattern.search(text)]

    def extract_entities_from_text(self, text: str) -> list[ImageEntity]:
        """Extract entities from OCR text with simple rules (从OCR文本中提取实体).

        Rules (could be replaced by an LLM call):

        - PROJECT (0.7): a quoted single-line phrase, or a run of >= 2
          capitalised words, longer than 2 characters after stripping.
        - PERSON (0.8): 2-4 CJK characters followed by a title such as 先生/经理.
        - TECH (0.9): a known technical term. ASCII terms must stand alone as
          a word.

        Duplicates (same case-insensitive name and same type) are dropped,
        keeping the first occurrence.

        Args:
            text: OCR text.

        Returns:
            The list of entities; empty for empty text.
        """
        if not text:
            return []

        candidates: list[ImageEntity] = []

        for match in self._PROJECT_PATTERN.finditer(text):
            # Strip before the length check so whitespace-only quotes are rejected.
            name = (match.group(1) or match.group(2) or "").strip()
            if len(name) > 2:
                candidates.append(ImageEntity(name=name, type="PROJECT", confidence=0.7))

        for match in self._PERSON_PATTERN.finditer(text):
            candidates.append(ImageEntity(name=match.group(1), type="PERSON", confidence=0.8))

        for keyword in self._find_tech_keywords(text):
            candidates.append(ImageEntity(name=keyword, type="TECH", confidence=0.9))

        seen: set[tuple[str, str]] = set()
        unique_entities: list[ImageEntity] = []
        for entity in candidates:
            key = (entity.name.lower(), entity.type)
            if key not in seen:
                seen.add(key)
                unique_entities.append(entity)
        return unique_entities

    def generate_description(self, image_type: str, ocr_text: str, entities: list[ImageEntity]) -> str:
        """Build a short Chinese description of an image (生成图片描述).

        Args:
            image_type: Key of ``IMAGE_TYPES``. Unknown keys are described
                generically as "图片".
            ocr_text: OCR text. Its first 200 characters, with newlines
                flattened, are used as a summary ("..." is appended when truncated).
            entities: Detected entities. At most the first 5 names are listed.

        Returns:
            The space-joined description sentences.
        """
        # Avoid "图片图片" for labels that already end in "图片" (e.g. "文档图片") or unknown types.
        label = self.IMAGE_TYPES.get(image_type, "")
        if not label.endswith("图片"):
            label += "图片"

        description_parts = [f"这是一张{label}。"]

        if ocr_text:
            text_preview = ocr_text[:200].replace("\n", " ")
            if len(ocr_text) > 200:
                text_preview += "..."
            description_parts.append(f"内容摘要:{text_preview}")

        if entities:
            entity_names = [e.name for e in entities[:5]]
            description_parts.append(f"识别到的关键实体:{', '.join(entity_names)}")

        return " ".join(description_parts)

    def _save_original(self, image_data: bytes, image_id: str, filename: str) -> str:
        """Write the uploaded bytes into ``temp_dir`` and return the path written.

        ``basename`` drops any directory components in the client-supplied
        names (e.g. "uploads/a.png" or "../../a.png"), so the file always
        lands directly in ``temp_dir``. The original bytes are written as-is.
        Re-encoding could fail, for example for RGBA data saved as ".jpg" or
        a name without an extension.
        """
        safe_name = f"{os.path.basename(str(image_id))}_{os.path.basename(filename)}"
        save_path = os.path.join(self.temp_dir, safe_name)
        with open(save_path, "wb") as fh:
            fh.write(image_data)
        return save_path

    def process_image(
        self,
        image_data: bytes,
        filename: str | None = None,
        image_id: str | None = None,
        detect_type: bool = True,
    ) -> ImageProcessingResult:
        """Process one image end to end (处理单张图片).

        Args:
            image_data: Encoded image bytes.
            filename: If given, the original bytes are saved to
                ``temp_dir/<image_id>_<basename(filename)>``.
            image_id: Identifier for the result. Defaults to the first 8
                characters of a random UUID4.
            detect_type: Whether to auto-detect the image type. If False, the
                type is "other".

        Returns:
            ``ImageProcessingResult``. When PIL is missing or any step raises,
            ``success`` is False and ``error_message`` holds the reason.
        """
        image_id = image_id or str(uuid.uuid4())[:8]

        if not PIL_AVAILABLE:
            return ImageProcessingResult(
                image_id=image_id,
                image_type="other",
                ocr_text="",
                description="PIL not available",
                entities=[],
                relations=[],
                width=0,
                height=0,
                success=False,
                error_message="PIL library not available",
            )

        try:
            with Image.open(io.BytesIO(image_data)) as image:
                width, height = image.size
                ocr_text, _confidence = self.perform_ocr(image)
                image_type = self.detect_image_type(image, ocr_text) if detect_type else "other"

            entities = self.extract_entities_from_text(ocr_text)
            description = self.generate_description(image_type, ocr_text, entities)
            relations = self._extract_relations(entities, ocr_text)

            if filename:
                self._save_original(image_data, image_id, filename)

            return ImageProcessingResult(
                image_id=image_id,
                image_type=image_type,
                ocr_text=ocr_text,
                description=description,
                entities=entities,
                relations=relations,
                width=width,
                height=height,
                success=True,
            )
        except Exception as e:
            return ImageProcessingResult(
                image_id=image_id,
                image_type="other",
                ocr_text="",
                description="",
                entities=[],
                relations=[],
                width=0,
                height=0,
                success=False,
                error_message=str(e),
            )

    @classmethod
    def _split_sentences(cls, text: str) -> list[str]:
        """Split *text* at sentence terminators, dropping blank segments."""
        return [segment for segment in cls._SENTENCE_BOUNDARY.split(text or "") if segment.strip()]

    def _extract_relations(self, entities: list[ImageEntity], text: str) -> list[ImageRelation]:
        """Link entities that co-occur in the same sentence (从文本中提取实体关系).

        Each unordered pair of distinct entity names is reported once, with
        relation_type "related" and confidence 0.5. Source and target follow
        the order of *entities*.

        Args:
            entities: Entities previously extracted from *text*.
            text: OCR text.

        Returns:
            The list of relations; empty when fewer than 2 entities are given.
        """
        relations: list[ImageRelation] = []
        if len(entities) < 2:
            return relations

        seen_pairs: set[frozenset[str]] = set()
        for sentence in self._split_sentences(text):
            present = [entity.name for entity in entities if entity.name in sentence]
            for i, source in enumerate(present):
                for target in present[i + 1 :]:
                    pair = frozenset((source, target))
                    # Skip self-links (same name under two types) and pairs already reported.
                    if source == target or pair in seen_pairs:
                        continue
                    seen_pairs.add(pair)
                    relations.append(
                        ImageRelation(
                            source=source,
                            target=target,
                            relation_type="related",
                            confidence=0.5,
                        )
                    )
        return relations

    def process_batch(self, images_data: list[tuple[bytes, str]], project_id: str | None = None) -> BatchProcessingResult:
        """Process several images sequentially (批量处理图片).

        Args:
            images_data: ``(image_data, filename)`` pairs.
            project_id: Accepted for API compatibility; not used here.

        Returns:
            ``BatchProcessingResult`` with the per-image results in input
            order, plus success and failure counts.
        """
        results = []
        success_count = 0
        failed_count = 0

        for image_data, filename in images_data:
            result = self.process_image(image_data, filename)
            results.append(result)
            if result.success:
                success_count += 1
            else:
                failed_count += 1

        return BatchProcessingResult(
            results=results, total_count=len(results), success_count=success_count, failed_count=failed_count
        )

    def image_to_base64(self, image_data: bytes) -> str:
        """Return *image_data* encoded as a base64 string (将图片转换为base64编码)."""
        return base64.b64encode(image_data).decode("utf-8")

    @staticmethod
    def _flatten_for_jpeg(image):
        """Return an RGB or L version of *image* that the JPEG encoder accepts.

        Transparent images (RGBA, LA, PA, or P with a transparency entry) are
        composited onto a white background. Other non-RGB/L modes are
        converted to RGB. RGB and L images are returned unchanged.
        """
        if image.mode in ("RGB", "L"):
            return image
        if image.mode in ("RGBA", "LA", "PA") or (image.mode == "P" and "transparency" in image.info):
            rgba = image.convert("RGBA")
            background = Image.new("RGB", rgba.size, (255, 255, 255))
            background.paste(rgba, mask=rgba.getchannel("A"))
            return background
        return image.convert("RGB")

    def get_image_thumbnail(self, image_data: bytes, size: tuple[int, int] = (200, 200)) -> bytes:
        """Build a JPEG thumbnail no larger than *size* (生成图片缩略图).

        Args:
            image_data: Encoded image bytes.
            size: Maximum ``(width, height)``. The aspect ratio is preserved.

        Returns:
            The JPEG bytes. Returns *image_data* unchanged when PIL is
            unavailable or on error; errors are printed.
        """
        if not PIL_AVAILABLE:
            return image_data

        try:
            with Image.open(io.BytesIO(image_data)) as image:
                # JPEG cannot store alpha/palette images; flatten first so PNG
                # uploads with transparency still get a thumbnail.
                thumb = self._flatten_for_jpeg(image)
                thumb.thumbnail(size, Image.Resampling.LANCZOS)
                buffer = io.BytesIO()
                thumb.save(buffer, format="JPEG")
            return buffer.getvalue()
        except Exception as e:
            print(f"Thumbnail generation error: {e}")
            return image_data


# Process-wide ImageProcessor, created lazily by get_image_processor().
_image_processor = None


def get_image_processor(temp_dir: str = None) -> ImageProcessor:
    """Return the shared ImageProcessor, building it on the first call (获取图片处理器单例).

    Args:
        temp_dir: Passed to ``ImageProcessor`` only when the instance is first
            created; ignored on later calls, which return the existing instance.
    """
    global _image_processor
    if _image_processor is not None:
        return _image_processor
    _image_processor = ImageProcessor(temp_dir)
    return _image_processor