Files
insightflow/backend/image_processor.py
OpenClaw Bot 33555642db fix: auto-fix code issues (cron)
- 修复重复导入/字段
- 修复异常处理
- 修复PEP8格式问题
- 添加类型注解
2026-02-28 03:03:50 +08:00

546 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
InsightFlow Image Processor - Phase 7
图片处理模块识别白板、PPT、手写笔记等内容
"""
import base64
import io
import os
import re
import uuid
from dataclasses import dataclass

# Optional image-processing libraries: record availability instead of failing at import time.
try:
    from PIL import Image, ImageEnhance, ImageFilter

    PIL_AVAILABLE = True
except ImportError:
    PIL_AVAILABLE = False

try:
    import cv2
    import numpy as np

    CV2_AVAILABLE = True
except ImportError:
    CV2_AVAILABLE = False

try:
    import pytesseract

    PYTESSERACT_AVAILABLE = True
except ImportError:
    PYTESSERACT_AVAILABLE = False
@dataclass
class ImageEntity:
"""图片中检测到的实体"""
name: str
type: str
confidence: float
bbox: tuple[int, int, int, int] | None = None # (x, y, width, height)
@dataclass
class ImageRelation:
    """A relation detected between two entities of an image.

    Attributes:
        source: Name of the entity the relation starts from.
        target: Name of the entity the relation points to.
        relation_type: Label describing the kind of relation.
        confidence: Confidence score attached to the relation.
    """

    source: str
    target: str
    relation_type: str
    confidence: float
@dataclass
class ImageProcessingResult:
    """Outcome of processing a single image.

    Attributes:
        image_id: Identifier of the processed image.
        image_type: One of whiteboard, ppt, handwritten, screenshot, other.
        ocr_text: Text recognized by OCR.
        description: Human-readable description of the image.
        entities: Entities detected in the image.
        relations: Relations detected between the entities.
        width: Image width.
        height: Image height.
        success: Whether processing succeeded.
        error_message: Error details; defaults to an empty string.
    """

    image_id: str
    image_type: str  # whiteboard, ppt, handwritten, screenshot, other
    ocr_text: str
    description: str
    entities: list[ImageEntity]
    relations: list[ImageRelation]
    width: int
    height: int
    success: bool
    error_message: str = ""
@dataclass
class BatchProcessingResult:
    """Outcome of processing a batch of images.

    Attributes:
        results: Per-image processing results.
        total_count: Number of images in the batch.
        success_count: Number of images processed successfully.
        failed_count: Number of images that failed.
    """

    results: list[ImageProcessingResult]
    total_count: int
    success_count: int
    failed_count: int
class ImageProcessor:
"""图片处理器 - 处理各种类型图片"""
# 图片类型定义
IMAGE_TYPES = {
"whiteboard": "白板",
"ppt": "PPT/演示文稿",
"handwritten": "手写笔记",
"screenshot": "屏幕截图",
"document": "文档图片",
"other": "其他",
}
def __init__(self, temp_dir: str = None) -> None:
"""
初始化图片处理器
Args:
temp_dir: 临时文件目录
"""
self.temp_dir = temp_dir or os.path.join(os.getcwd(), "temp", "images")
os.makedirs(self.temp_dir, exist_ok=True)
def preprocess_image(self, image, image_type: str = None) -> None:
"""
预处理图片以提高OCR质量
Args:
image: PIL Image 对象
image_type: 图片类型(用于针对性处理)
Returns:
处理后的图片
"""
if not PIL_AVAILABLE:
return image
try:
# 转换为RGB如果是RGBA
if image.mode == "RGBA":
image = image.convert("RGB")
# 根据图片类型进行针对性处理
if image_type == "whiteboard":
# 白板:增强对比度,去除背景
image = self._enhance_whiteboard(image)
elif image_type == "handwritten":
# 手写笔记:降噪,增强对比度
image = self._enhance_handwritten(image)
elif image_type == "screenshot":
# 截图:轻微锐化
image = image.filter(ImageFilter.SHARPEN)
# 通用处理:调整大小(如果太大)
max_size = 4096
if max(image.size) > max_size:
ratio = max_size / max(image.size)
new_size = (int(image.size[0] * ratio), int(image.size[1] * ratio))
image = image.resize(new_size, Image.Resampling.LANCZOS)
return image
except Exception as e:
print(f"Image preprocessing error: {e}")
return image
def _enhance_whiteboard(self, image) -> None:
"""增强白板图片"""
# 转换为灰度
gray = image.convert("L")
# 增强对比度
enhancer = ImageEnhance.Contrast(gray)
enhanced = enhancer.enhance(2.0)
# 二值化
threshold = 128
binary = enhanced.point(lambda x: 0 if x < threshold else 255, "1")
return binary.convert("L")
def _enhance_handwritten(self, image) -> None:
"""增强手写笔记图片"""
# 转换为灰度
gray = image.convert("L")
# 轻微降噪
blurred = gray.filter(ImageFilter.GaussianBlur(radius=1))
# 增强对比度
enhancer = ImageEnhance.Contrast(blurred)
enhanced = enhancer.enhance(1.5)
return enhanced
def detect_image_type(self, image, ocr_text: str = "") -> str:
"""
自动检测图片类型
Args:
image: PIL Image 对象
ocr_text: OCR识别的文本
Returns:
图片类型字符串
"""
if not PIL_AVAILABLE:
return "other"
try:
# 基于图片特征和OCR内容判断类型
width, height = image.size
aspect_ratio = width / height
# 检测是否为PPT通常是16:9或4:3
if 1.3 <= aspect_ratio <= 1.8:
# 检查是否有典型的PPT特征标题、项目符号等
if any(keyword in ocr_text.lower() for keyword in ["slide", "page", "", ""]):
return "ppt"
# 检测是否为白板(大量手写文字,可能有箭头、框等)
if CV2_AVAILABLE:
img_array = np.array(image.convert("RGB"))
gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
# 检测边缘(白板通常有很多线条)
edges = cv2.Canny(gray, 50, 150)
edge_ratio = np.sum(edges > 0) / edges.size
# 如果边缘比例高,可能是白板
if edge_ratio > 0.05 and len(ocr_text) > 50:
return "whiteboard"
# 检测是否为手写笔记(文字密度高,可能有涂鸦)
if len(ocr_text) > 100 and aspect_ratio < 1.5:
# 检查手写特征(不规则的行高)
return "handwritten"
# 检测是否为截图可能有UI元素
if any(keyword in ocr_text.lower() for keyword in ["button", "menu", "click", "登录", "确定", "取消"]):
return "screenshot"
# 默认文档类型
if len(ocr_text) > 200:
return "document"
return "other"
except Exception as e:
print(f"Image type detection error: {e}")
return "other"
def perform_ocr(self, image, lang: str = "chi_sim+eng") -> tuple[str, float]:
"""
对图片进行OCR识别
Args:
image: PIL Image 对象
lang: OCR语言
Returns:
(识别的文本, 置信度)
"""
if not PYTESSERACT_AVAILABLE:
return "", 0.0
try:
# 预处理图片
processed_image = self.preprocess_image(image)
# 执行OCR
text = pytesseract.image_to_string(processed_image, lang=lang)
# 获取置信度
data = pytesseract.image_to_data(processed_image, output_type=pytesseract.Output.DICT)
confidences = [int(c) for c in data["conf"] if int(c) > 0]
avg_confidence = sum(confidences) / len(confidences) if confidences else 0
return text.strip(), avg_confidence / 100.0
except Exception as e:
print(f"OCR error: {e}")
return "", 0.0
def extract_entities_from_text(self, text: str) -> list[ImageEntity]:
"""
从OCR文本中提取实体
Args:
text: OCR识别的文本
Returns:
实体列表
"""
entities = []
# 简单的实体提取规则可以替换为LLM调用
# 提取大写字母开头的词组(可能是专有名词)
import re
# 项目名称(通常是大写或带引号)
project_pattern = r'["\']([^"\']+)["\']|([A-Z][a-zA-Z0-9]*(?:\s+[A-Z][a-zA-Z0-9]*)+)'
for match in re.finditer(project_pattern, text):
name = match.group(1) or match.group(2)
if name and len(name) > 2:
entities.append(ImageEntity(name=name.strip(), type="PROJECT", confidence=0.7))
# 人名(中文)
name_pattern = r"([\u4e00-\u9fa5]{2,4})(?:先生|女士|总|经理|工程师|老师)"
for match in re.finditer(name_pattern, text):
entities.append(ImageEntity(name=match.group(1), type="PERSON", confidence=0.8))
# 技术术语
tech_keywords = [
"K8s",
"Kubernetes",
"Docker",
"API",
"SDK",
"AI",
"ML",
"Python",
"Java",
"React",
"Vue",
"Node.js",
"数据库",
"服务器",
]
for keyword in tech_keywords:
if keyword in text:
entities.append(ImageEntity(name=keyword, type="TECH", confidence=0.9))
# 去重
seen = set()
unique_entities = []
for e in entities:
key = (e.name.lower(), e.type)
if key not in seen:
seen.add(key)
unique_entities.append(e)
return unique_entities
def generate_description(self, image_type: str, ocr_text: str, entities: list[ImageEntity]) -> str:
"""
生成图片描述
Args:
image_type: 图片类型
ocr_text: OCR文本
entities: 检测到的实体
Returns:
图片描述
"""
type_name = self.IMAGE_TYPES.get(image_type, "图片")
description_parts = [f"这是一张{type_name}图片。"]
if ocr_text:
# 提取前200字符作为摘要
text_preview = ocr_text[:200].replace("\n", " ")
if len(ocr_text) > 200:
text_preview += "..."
description_parts.append(f"内容摘要:{text_preview}")
if entities:
entity_names = [e.name for e in entities[:5]] # 最多显示5个实体
description_parts.append(f"识别到的关键实体:{', '.join(entity_names)}")
return " ".join(description_parts)
def process_image(
self, image_data: bytes, filename: str = None, image_id: str = None, detect_type: bool = True
) -> ImageProcessingResult:
"""
处理单张图片
Args:
image_data: 图片二进制数据
filename: 文件名
image_id: 图片ID可选
detect_type: 是否自动检测图片类型
Returns:
图片处理结果
"""
image_id = image_id or str(uuid.uuid4())[:8]
if not PIL_AVAILABLE:
return ImageProcessingResult(
image_id=image_id,
image_type="other",
ocr_text="",
description="PIL not available",
entities=[],
relations=[],
width=0,
height=0,
success=False,
error_message="PIL library not available",
)
try:
# 加载图片
image = Image.open(io.BytesIO(image_data))
width, height = image.size
# 执行OCR
ocr_text, ocr_confidence = self.perform_ocr(image)
# 检测图片类型
image_type = "other"
if detect_type:
image_type = self.detect_image_type(image, ocr_text)
# 提取实体
entities = self.extract_entities_from_text(ocr_text)
# 生成描述
description = self.generate_description(image_type, ocr_text, entities)
# 提取关系(基于实体共现)
relations = self._extract_relations(entities, ocr_text)
# 保存图片文件(可选)
if filename:
save_path = os.path.join(self.temp_dir, f"{image_id}_{filename}")
image.save(save_path)
return ImageProcessingResult(
image_id=image_id,
image_type=image_type,
ocr_text=ocr_text,
description=description,
entities=entities,
relations=relations,
width=width,
height=height,
success=True,
)
except Exception as e:
return ImageProcessingResult(
image_id=image_id,
image_type="other",
ocr_text="",
description="",
entities=[],
relations=[],
width=0,
height=0,
success=False,
error_message=str(e),
)
def _extract_relations(self, entities: list[ImageEntity], text: str) -> list[ImageRelation]:
"""
从文本中提取实体关系
Args:
entities: 实体列表
text: 文本内容
Returns:
关系列表
"""
relations = []
if len(entities) < 2:
return relations
# 简单的关系提取:如果两个实体在同一句子中出现,则认为它们相关
sentences = text.replace("", ".").replace("", "!").replace("", "?").split(".")
for sentence in sentences:
sentence_entities = []
for entity in entities:
if entity.name in sentence:
sentence_entities.append(entity)
# 如果句子中有多个实体,建立关系
if len(sentence_entities) >= 2:
for i in range(len(sentence_entities)):
for j in range(i + 1, len(sentence_entities)):
relations.append(
ImageRelation(
source=sentence_entities[i].name,
target=sentence_entities[j].name,
relation_type="related",
confidence=0.5,
)
)
return relations
def process_batch(self, images_data: list[tuple[bytes, str]], project_id: str = None) -> BatchProcessingResult:
"""
批量处理图片
Args:
images_data: 图片数据列表,每项为 (image_data, filename)
project_id: 项目ID
Returns:
批量处理结果
"""
results = []
success_count = 0
failed_count = 0
for image_data, filename in images_data:
result = self.process_image(image_data, filename)
results.append(result)
if result.success:
success_count += 1
else:
failed_count += 1
return BatchProcessingResult(
results=results, total_count=len(results), success_count=success_count, failed_count=failed_count
)
def image_to_base64(self, image_data: bytes) -> str:
"""
将图片转换为base64编码
Args:
image_data: 图片二进制数据
Returns:
base64编码的字符串
"""
return base64.b64encode(image_data).decode("utf-8")
def get_image_thumbnail(self, image_data: bytes, size: tuple[int, int] = (200, 200)) -> bytes:
"""
生成图片缩略图
Args:
image_data: 图片二进制数据
size: 缩略图尺寸
Returns:
缩略图二进制数据
"""
if not PIL_AVAILABLE:
return image_data
try:
image = Image.open(io.BytesIO(image_data))
image.thumbnail(size, Image.Resampling.LANCZOS)
buffer = io.BytesIO()
image.save(buffer, format="JPEG")
return buffer.getvalue()
except Exception as e:
print(f"Thumbnail generation error: {e}")
return image_data
# Module-level singleton, created lazily by get_image_processor().
_image_processor: ImageProcessor | None = None


def get_image_processor(temp_dir: str | None = None) -> ImageProcessor:
    """Return the shared ImageProcessor, creating it on first use.

    Args:
        temp_dir: Temp directory passed to ``ImageProcessor`` when the
            instance is first created. It is ignored on later calls, which
            return the existing instance unchanged.

    Returns:
        The shared ``ImageProcessor`` instance.
    """
    global _image_processor
    if _image_processor is None:
        _image_processor = ImageProcessor(temp_dir)
    return _image_processor