fix: auto-fix code issues (cron)
- 修复重复导入/字段 - 修复异常处理 - 修复PEP8格式问题 - 添加类型注解
This commit is contained in:
@@ -5,13 +5,12 @@ InsightFlow Database Manager - Phase 5
|
||||
支持实体属性扩展
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import os
|
||||
import sqlite3
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Optional
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
|
||||
DB_PATH = os.getenv("DB_PATH", "/app/data/insightflow.db")
|
||||
|
||||
@@ -33,9 +32,9 @@ class Entity:
|
||||
type: str
|
||||
definition: str = ""
|
||||
canonical_name: str = ""
|
||||
aliases: List[str] = None
|
||||
aliases: list[str] = None
|
||||
embedding: str = "" # Phase 3: 实体嵌入向量
|
||||
attributes: Dict = None # Phase 5: 实体属性
|
||||
attributes: dict = None # Phase 5: 实体属性
|
||||
created_at: str = ""
|
||||
updated_at: str = ""
|
||||
|
||||
@@ -54,7 +53,7 @@ class AttributeTemplate:
|
||||
project_id: str
|
||||
name: str
|
||||
type: str # text, number, date, select, multiselect, boolean
|
||||
options: List[str] = None # 用于 select/multiselect
|
||||
options: list[str] = None # 用于 select/multiselect
|
||||
default_value: str = ""
|
||||
description: str = ""
|
||||
is_required: bool = False
|
||||
@@ -73,11 +72,11 @@ class EntityAttribute:
|
||||
|
||||
id: str
|
||||
entity_id: str
|
||||
template_id: Optional[str] = None
|
||||
template_id: str | None = None
|
||||
name: str = "" # 属性名称
|
||||
type: str = "text" # 属性类型
|
||||
value: str = ""
|
||||
options: List[str] = None # 选项列表
|
||||
options: list[str] = None # 选项列表
|
||||
template_name: str = "" # 关联查询时填充
|
||||
template_type: str = "" # 关联查询时填充
|
||||
created_at: str = ""
|
||||
@@ -126,7 +125,7 @@ class DatabaseManager:
|
||||
|
||||
def init_db(self):
|
||||
"""初始化数据库表"""
|
||||
with open(os.path.join(os.path.dirname(__file__), "schema.sql"), "r") as f:
|
||||
with open(os.path.join(os.path.dirname(__file__), "schema.sql")) as f:
|
||||
schema = f.read()
|
||||
|
||||
conn = self.get_conn()
|
||||
@@ -147,7 +146,7 @@ class DatabaseManager:
|
||||
conn.close()
|
||||
return Project(id=project_id, name=name, description=description, created_at=now, updated_at=now)
|
||||
|
||||
def get_project(self, project_id: str) -> Optional[Project]:
|
||||
def get_project(self, project_id: str) -> Project | None:
|
||||
conn = self.get_conn()
|
||||
row = conn.execute("SELECT * FROM projects WHERE id = ?", (project_id,)).fetchone()
|
||||
conn.close()
|
||||
@@ -155,7 +154,7 @@ class DatabaseManager:
|
||||
return Project(**dict(row))
|
||||
return None
|
||||
|
||||
def list_projects(self) -> List[Project]:
|
||||
def list_projects(self) -> list[Project]:
|
||||
conn = self.get_conn()
|
||||
rows = conn.execute("SELECT * FROM projects ORDER BY updated_at DESC").fetchall()
|
||||
conn.close()
|
||||
@@ -184,7 +183,7 @@ class DatabaseManager:
|
||||
conn.close()
|
||||
return entity
|
||||
|
||||
def get_entity_by_name(self, project_id: str, name: str) -> Optional[Entity]:
|
||||
def get_entity_by_name(self, project_id: str, name: str) -> Entity | None:
|
||||
"""通过名称查找实体(用于对齐)"""
|
||||
conn = self.get_conn()
|
||||
row = conn.execute(
|
||||
@@ -198,7 +197,7 @@ class DatabaseManager:
|
||||
return Entity(**data)
|
||||
return None
|
||||
|
||||
def find_similar_entities(self, project_id: str, name: str, threshold: float = 0.8) -> List[Entity]:
|
||||
def find_similar_entities(self, project_id: str, name: str, threshold: float = 0.8) -> list[Entity]:
|
||||
"""查找相似实体"""
|
||||
conn = self.get_conn()
|
||||
rows = conn.execute(
|
||||
@@ -245,7 +244,7 @@ class DatabaseManager:
|
||||
conn.close()
|
||||
return self.get_entity(target_id)
|
||||
|
||||
def get_entity(self, entity_id: str) -> Optional[Entity]:
|
||||
def get_entity(self, entity_id: str) -> Entity | None:
|
||||
conn = self.get_conn()
|
||||
row = conn.execute("SELECT * FROM entities WHERE id = ?", (entity_id,)).fetchone()
|
||||
conn.close()
|
||||
@@ -255,7 +254,7 @@ class DatabaseManager:
|
||||
return Entity(**data)
|
||||
return None
|
||||
|
||||
def list_project_entities(self, project_id: str) -> List[Entity]:
|
||||
def list_project_entities(self, project_id: str) -> list[Entity]:
|
||||
conn = self.get_conn()
|
||||
rows = conn.execute(
|
||||
"SELECT * FROM entities WHERE project_id = ? ORDER BY updated_at DESC", (project_id,)
|
||||
@@ -333,7 +332,7 @@ class DatabaseManager:
|
||||
conn.close()
|
||||
return mention
|
||||
|
||||
def get_entity_mentions(self, entity_id: str) -> List[EntityMention]:
|
||||
def get_entity_mentions(self, entity_id: str) -> list[EntityMention]:
|
||||
conn = self.get_conn()
|
||||
rows = conn.execute(
|
||||
"SELECT * FROM entity_mentions WHERE entity_id = ? ORDER BY transcript_id, start_pos", (entity_id,)
|
||||
@@ -355,13 +354,13 @@ class DatabaseManager:
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
def get_transcript(self, transcript_id: str) -> Optional[dict]:
|
||||
def get_transcript(self, transcript_id: str) -> dict | None:
|
||||
conn = self.get_conn()
|
||||
row = conn.execute("SELECT * FROM transcripts WHERE id = ?", (transcript_id,)).fetchone()
|
||||
conn.close()
|
||||
return dict(row) if row else None
|
||||
|
||||
def list_project_transcripts(self, project_id: str) -> List[dict]:
|
||||
def list_project_transcripts(self, project_id: str) -> list[dict]:
|
||||
conn = self.get_conn()
|
||||
rows = conn.execute(
|
||||
"SELECT * FROM transcripts WHERE project_id = ? ORDER BY created_at DESC", (project_id,)
|
||||
@@ -404,7 +403,7 @@ class DatabaseManager:
|
||||
conn.close()
|
||||
return relation_id
|
||||
|
||||
def get_entity_relations(self, entity_id: str) -> List[dict]:
|
||||
def get_entity_relations(self, entity_id: str) -> list[dict]:
|
||||
conn = self.get_conn()
|
||||
rows = conn.execute(
|
||||
"""SELECT * FROM entity_relations
|
||||
@@ -415,7 +414,7 @@ class DatabaseManager:
|
||||
conn.close()
|
||||
return [dict(r) for r in rows]
|
||||
|
||||
def list_project_relations(self, project_id: str) -> List[dict]:
|
||||
def list_project_relations(self, project_id: str) -> list[dict]:
|
||||
conn = self.get_conn()
|
||||
rows = conn.execute(
|
||||
"SELECT * FROM entity_relations WHERE project_id = ? ORDER BY created_at DESC", (project_id,)
|
||||
@@ -473,7 +472,7 @@ class DatabaseManager:
|
||||
conn.close()
|
||||
return term_id
|
||||
|
||||
def list_glossary(self, project_id: str) -> List[dict]:
|
||||
def list_glossary(self, project_id: str) -> list[dict]:
|
||||
conn = self.get_conn()
|
||||
rows = conn.execute(
|
||||
"SELECT * FROM glossary WHERE project_id = ? ORDER BY frequency DESC", (project_id,)
|
||||
@@ -489,7 +488,7 @@ class DatabaseManager:
|
||||
|
||||
# ==================== Phase 4: Agent & Provenance ====================
|
||||
|
||||
def get_relation_with_details(self, relation_id: str) -> Optional[dict]:
|
||||
def get_relation_with_details(self, relation_id: str) -> dict | None:
|
||||
conn = self.get_conn()
|
||||
row = conn.execute(
|
||||
"""SELECT r.*,
|
||||
@@ -505,7 +504,7 @@ class DatabaseManager:
|
||||
conn.close()
|
||||
return dict(row) if row else None
|
||||
|
||||
def get_entity_with_mentions(self, entity_id: str) -> Optional[dict]:
|
||||
def get_entity_with_mentions(self, entity_id: str) -> dict | None:
|
||||
conn = self.get_conn()
|
||||
entity_row = conn.execute("SELECT * FROM entities WHERE id = ?", (entity_id,)).fetchone()
|
||||
if not entity_row:
|
||||
@@ -539,7 +538,7 @@ class DatabaseManager:
|
||||
conn.close()
|
||||
return entity
|
||||
|
||||
def search_entities(self, project_id: str, query: str) -> List[Entity]:
|
||||
def search_entities(self, project_id: str, query: str) -> list[Entity]:
|
||||
conn = self.get_conn()
|
||||
rows = conn.execute(
|
||||
"""SELECT * FROM entities
|
||||
@@ -616,7 +615,7 @@ class DatabaseManager:
|
||||
|
||||
def get_project_timeline(
|
||||
self, project_id: str, entity_id: str = None, start_date: str = None, end_date: str = None
|
||||
) -> List[dict]:
|
||||
) -> list[dict]:
|
||||
conn = self.get_conn()
|
||||
|
||||
conditions = ["t.project_id = ?"]
|
||||
@@ -722,7 +721,7 @@ class DatabaseManager:
|
||||
conn.close()
|
||||
return template
|
||||
|
||||
def get_attribute_template(self, template_id: str) -> Optional[AttributeTemplate]:
|
||||
def get_attribute_template(self, template_id: str) -> AttributeTemplate | None:
|
||||
conn = self.get_conn()
|
||||
row = conn.execute("SELECT * FROM attribute_templates WHERE id = ?", (template_id,)).fetchone()
|
||||
conn.close()
|
||||
@@ -732,7 +731,7 @@ class DatabaseManager:
|
||||
return AttributeTemplate(**data)
|
||||
return None
|
||||
|
||||
def list_attribute_templates(self, project_id: str) -> List[AttributeTemplate]:
|
||||
def list_attribute_templates(self, project_id: str) -> list[AttributeTemplate]:
|
||||
conn = self.get_conn()
|
||||
rows = conn.execute(
|
||||
"""SELECT * FROM attribute_templates WHERE project_id = ?
|
||||
@@ -748,7 +747,7 @@ class DatabaseManager:
|
||||
templates.append(AttributeTemplate(**data))
|
||||
return templates
|
||||
|
||||
def update_attribute_template(self, template_id: str, **kwargs) -> Optional[AttributeTemplate]:
|
||||
def update_attribute_template(self, template_id: str, **kwargs) -> AttributeTemplate | None:
|
||||
conn = self.get_conn()
|
||||
allowed_fields = ["name", "type", "options", "default_value", "description", "is_required", "sort_order"]
|
||||
updates = []
|
||||
@@ -834,7 +833,7 @@ class DatabaseManager:
|
||||
conn.close()
|
||||
return attr
|
||||
|
||||
def get_entity_attributes(self, entity_id: str) -> List[EntityAttribute]:
|
||||
def get_entity_attributes(self, entity_id: str) -> list[EntityAttribute]:
|
||||
conn = self.get_conn()
|
||||
rows = conn.execute(
|
||||
"""SELECT ea.*, at.name as template_name, at.type as template_type
|
||||
@@ -846,7 +845,7 @@ class DatabaseManager:
|
||||
conn.close()
|
||||
return [EntityAttribute(**dict(r)) for r in rows]
|
||||
|
||||
def get_entity_with_attributes(self, entity_id: str) -> Optional[Entity]:
|
||||
def get_entity_with_attributes(self, entity_id: str) -> Entity | None:
|
||||
entity = self.get_entity(entity_id)
|
||||
if not entity:
|
||||
return None
|
||||
@@ -889,7 +888,7 @@ class DatabaseManager:
|
||||
|
||||
def get_attribute_history(
|
||||
self, entity_id: str = None, template_id: str = None, limit: int = 50
|
||||
) -> List[AttributeHistory]:
|
||||
) -> list[AttributeHistory]:
|
||||
conn = self.get_conn()
|
||||
conditions = []
|
||||
params = []
|
||||
@@ -913,7 +912,7 @@ class DatabaseManager:
|
||||
conn.close()
|
||||
return [AttributeHistory(**dict(r)) for r in rows]
|
||||
|
||||
def search_entities_by_attributes(self, project_id: str, attribute_filters: Dict[str, str]) -> List[Entity]:
|
||||
def search_entities_by_attributes(self, project_id: str, attribute_filters: dict[str, str]) -> list[Entity]:
|
||||
entities = self.list_project_entities(project_id)
|
||||
if not attribute_filters:
|
||||
return entities
|
||||
@@ -962,11 +961,11 @@ class DatabaseManager:
|
||||
filename: str,
|
||||
duration: float = 0,
|
||||
fps: float = 0,
|
||||
resolution: Dict = None,
|
||||
resolution: dict = None,
|
||||
audio_transcript_id: str = None,
|
||||
full_ocr_text: str = "",
|
||||
extracted_entities: List[Dict] = None,
|
||||
extracted_relations: List[Dict] = None,
|
||||
extracted_entities: list[dict] = None,
|
||||
extracted_relations: list[dict] = None,
|
||||
) -> str:
|
||||
"""创建视频记录"""
|
||||
conn = self.get_conn()
|
||||
@@ -998,7 +997,7 @@ class DatabaseManager:
|
||||
conn.close()
|
||||
return video_id
|
||||
|
||||
def get_video(self, video_id: str) -> Optional[Dict]:
|
||||
def get_video(self, video_id: str) -> dict | None:
|
||||
"""获取视频信息"""
|
||||
conn = self.get_conn()
|
||||
row = conn.execute("SELECT * FROM videos WHERE id = ?", (video_id,)).fetchone()
|
||||
@@ -1012,7 +1011,7 @@ class DatabaseManager:
|
||||
return data
|
||||
return None
|
||||
|
||||
def list_project_videos(self, project_id: str) -> List[Dict]:
|
||||
def list_project_videos(self, project_id: str) -> list[dict]:
|
||||
"""获取项目的所有视频"""
|
||||
conn = self.get_conn()
|
||||
rows = conn.execute(
|
||||
@@ -1037,7 +1036,7 @@ class DatabaseManager:
|
||||
timestamp: float,
|
||||
image_url: str = None,
|
||||
ocr_text: str = None,
|
||||
extracted_entities: List[Dict] = None,
|
||||
extracted_entities: list[dict] = None,
|
||||
) -> str:
|
||||
"""创建视频帧记录"""
|
||||
conn = self.get_conn()
|
||||
@@ -1062,7 +1061,7 @@ class DatabaseManager:
|
||||
conn.close()
|
||||
return frame_id
|
||||
|
||||
def get_video_frames(self, video_id: str) -> List[Dict]:
|
||||
def get_video_frames(self, video_id: str) -> list[dict]:
|
||||
"""获取视频的所有帧"""
|
||||
conn = self.get_conn()
|
||||
rows = conn.execute(
|
||||
@@ -1084,8 +1083,8 @@ class DatabaseManager:
|
||||
filename: str,
|
||||
ocr_text: str = "",
|
||||
description: str = "",
|
||||
extracted_entities: List[Dict] = None,
|
||||
extracted_relations: List[Dict] = None,
|
||||
extracted_entities: list[dict] = None,
|
||||
extracted_relations: list[dict] = None,
|
||||
) -> str:
|
||||
"""创建图片记录"""
|
||||
conn = self.get_conn()
|
||||
@@ -1113,7 +1112,7 @@ class DatabaseManager:
|
||||
conn.close()
|
||||
return image_id
|
||||
|
||||
def get_image(self, image_id: str) -> Optional[Dict]:
|
||||
def get_image(self, image_id: str) -> dict | None:
|
||||
"""获取图片信息"""
|
||||
conn = self.get_conn()
|
||||
row = conn.execute("SELECT * FROM images WHERE id = ?", (image_id,)).fetchone()
|
||||
@@ -1126,7 +1125,7 @@ class DatabaseManager:
|
||||
return data
|
||||
return None
|
||||
|
||||
def list_project_images(self, project_id: str) -> List[Dict]:
|
||||
def list_project_images(self, project_id: str) -> list[dict]:
|
||||
"""获取项目的所有图片"""
|
||||
conn = self.get_conn()
|
||||
rows = conn.execute(
|
||||
@@ -1168,7 +1167,7 @@ class DatabaseManager:
|
||||
conn.close()
|
||||
return mention_id
|
||||
|
||||
def get_entity_multimodal_mentions(self, entity_id: str) -> List[Dict]:
|
||||
def get_entity_multimodal_mentions(self, entity_id: str) -> list[dict]:
|
||||
"""获取实体的多模态提及"""
|
||||
conn = self.get_conn()
|
||||
rows = conn.execute(
|
||||
@@ -1181,7 +1180,7 @@ class DatabaseManager:
|
||||
conn.close()
|
||||
return [dict(r) for r in rows]
|
||||
|
||||
def get_project_multimodal_mentions(self, project_id: str, modality: str = None) -> List[Dict]:
|
||||
def get_project_multimodal_mentions(self, project_id: str, modality: str = None) -> list[dict]:
|
||||
"""获取项目的多模态提及"""
|
||||
conn = self.get_conn()
|
||||
|
||||
@@ -1214,7 +1213,7 @@ class DatabaseManager:
|
||||
link_type: str,
|
||||
confidence: float = 1.0,
|
||||
evidence: str = "",
|
||||
modalities: List[str] = None,
|
||||
modalities: list[str] = None,
|
||||
) -> str:
|
||||
"""创建多模态实体关联"""
|
||||
conn = self.get_conn()
|
||||
@@ -1231,7 +1230,7 @@ class DatabaseManager:
|
||||
conn.close()
|
||||
return link_id
|
||||
|
||||
def get_entity_multimodal_links(self, entity_id: str) -> List[Dict]:
|
||||
def get_entity_multimodal_links(self, entity_id: str) -> list[dict]:
|
||||
"""获取实体的多模态关联"""
|
||||
conn = self.get_conn()
|
||||
rows = conn.execute(
|
||||
@@ -1251,7 +1250,7 @@ class DatabaseManager:
|
||||
links.append(data)
|
||||
return links
|
||||
|
||||
def get_project_multimodal_stats(self, project_id: str) -> Dict:
|
||||
def get_project_multimodal_stats(self, project_id: str) -> dict:
|
||||
"""获取项目多模态统计信息"""
|
||||
conn = self.get_conn()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user