fix: auto-fix code issues (cron)

- 修复重复导入/字段
- 修复异常处理
- 修复PEP8格式问题
- 添加类型注解
This commit is contained in:
OpenClaw Bot
2026-02-27 21:12:04 +08:00
parent 17bda3dbce
commit d767f0dddc
27 changed files with 3636 additions and 4158 deletions

View File

@@ -209,7 +209,7 @@ class AIManager:
self.openai_api_key = os.getenv("OPENAI_API_KEY", "") self.openai_api_key = os.getenv("OPENAI_API_KEY", "")
self.anthropic_api_key = os.getenv("ANTHROPIC_API_KEY", "") self.anthropic_api_key = os.getenv("ANTHROPIC_API_KEY", "")
def _get_db(self): def _get_db(self) -> None:
"""获取数据库连接""" """获取数据库连接"""
conn = sqlite3.connect(self.db_path) conn = sqlite3.connect(self.db_path)
conn.row_factory = sqlite3.Row conn.row_factory = sqlite3.Row
@@ -921,7 +921,6 @@ class AIManager:
# 解析关键要点 # 解析关键要点
key_points = [] key_points = []
import re
# 尝试从 JSON 中提取 # 尝试从 JSON 中提取
json_match = re.search(r"\{.*?\}", content, re.DOTALL) json_match = re.search(r"\{.*?\}", content, re.DOTALL)
@@ -1312,7 +1311,7 @@ class AIManager:
return [self._row_to_prediction_result(row) for row in rows] return [self._row_to_prediction_result(row) for row in rows]
def update_prediction_feedback(self, prediction_id: str, actual_value: str, is_correct: bool): def update_prediction_feedback(self, prediction_id: str, actual_value: str, is_correct: bool) -> None:
"""更新预测反馈(用于模型改进)""" """更新预测反馈(用于模型改进)"""
with self._get_db() as conn: with self._get_db() as conn:
conn.execute( conn.execute(

View File

@@ -51,7 +51,7 @@ class ApiKeyManager:
self.db_path = db_path self.db_path = db_path
self._init_db() self._init_db()
def _init_db(self): def _init_db(self) -> None:
"""初始化数据库表""" """初始化数据库表"""
with sqlite3.connect(self.db_path) as conn: with sqlite3.connect(self.db_path) as conn:
conn.executescript(""" conn.executescript("""
@@ -331,7 +331,7 @@ class ApiKeyManager:
conn.commit() conn.commit()
return cursor.rowcount > 0 return cursor.rowcount > 0
def update_last_used(self, key_id: str): def update_last_used(self, key_id: str) -> None:
"""更新最后使用时间""" """更新最后使用时间"""
with sqlite3.connect(self.db_path) as conn: with sqlite3.connect(self.db_path) as conn:
conn.execute( conn.execute(

View File

@@ -195,7 +195,7 @@ class CollaborationManager:
data = f"{project_id}:{datetime.now().timestamp()}:{uuid.uuid4()}" data = f"{project_id}:{datetime.now().timestamp()}:{uuid.uuid4()}"
return hashlib.sha256(data.encode()).hexdigest()[:32] return hashlib.sha256(data.encode()).hexdigest()[:32]
def _save_share_to_db(self, share: ProjectShare): def _save_share_to_db(self, share: ProjectShare) -> None:
"""保存分享记录到数据库""" """保存分享记录到数据库"""
cursor = self.db.conn.cursor() cursor = self.db.conn.cursor()
cursor.execute( cursor.execute(
@@ -286,7 +286,7 @@ class CollaborationManager:
allow_export=bool(row[12]), allow_export=bool(row[12]),
) )
def increment_share_usage(self, token: str): def increment_share_usage(self, token: str) -> None:
"""增加分享链接使用次数""" """增加分享链接使用次数"""
share = self._shares_cache.get(token) share = self._shares_cache.get(token)
if share: if share:
@@ -403,7 +403,7 @@ class CollaborationManager:
return comment return comment
def _save_comment_to_db(self, comment: Comment): def _save_comment_to_db(self, comment: Comment) -> None:
"""保存评论到数据库""" """保存评论到数据库"""
cursor = self.db.conn.cursor() cursor = self.db.conn.cursor()
cursor.execute( cursor.execute(
@@ -616,7 +616,7 @@ class CollaborationManager:
return record return record
def _save_change_to_db(self, record: ChangeRecord): def _save_change_to_db(self, record: ChangeRecord) -> None:
"""保存变更记录到数据库""" """保存变更记录到数据库"""
cursor = self.db.conn.cursor() cursor = self.db.conn.cursor()
cursor.execute( cursor.execute(
@@ -862,7 +862,7 @@ class CollaborationManager:
} }
return permissions_map.get(role, ["read"]) return permissions_map.get(role, ["read"])
def _save_member_to_db(self, member: TeamMember): def _save_member_to_db(self, member: TeamMember) -> None:
"""保存成员到数据库""" """保存成员到数据库"""
cursor = self.db.conn.cursor() cursor = self.db.conn.cursor()
cursor.execute( cursor.execute(
@@ -970,7 +970,7 @@ class CollaborationManager:
permissions = json.loads(row[0]) if row[0] else [] permissions = json.loads(row[0]) if row[0] else []
return permission in permissions or "admin" in permissions return permission in permissions or "admin" in permissions
def update_last_active(self, project_id: str, user_id: str): def update_last_active(self, project_id: str, user_id: str) -> None:
"""更新用户最后活跃时间""" """更新用户最后活跃时间"""
if not self.db: if not self.db:
return return
@@ -992,7 +992,7 @@ class CollaborationManager:
_collaboration_manager = None _collaboration_manager = None
def get_collaboration_manager(db_manager=None): def get_collaboration_manager(db_manager=None) -> None:
"""获取协作管理器单例""" """获取协作管理器单例"""
global _collaboration_manager global _collaboration_manager
if _collaboration_manager is None: if _collaboration_manager is None:

View File

@@ -123,7 +123,7 @@ class DatabaseManager:
conn.row_factory = sqlite3.Row conn.row_factory = sqlite3.Row
return conn return conn
def init_db(self): def init_db(self) -> None:
"""初始化数据库表""" """初始化数据库表"""
with open(os.path.join(os.path.dirname(__file__), "schema.sql")) as f: with open(os.path.join(os.path.dirname(__file__), "schema.sql")) as f:
schema = f.read() schema = f.read()
@@ -299,7 +299,7 @@ class DatabaseManager:
conn.close() conn.close()
return self.get_entity(entity_id) return self.get_entity(entity_id)
def delete_entity(self, entity_id: str): def delete_entity(self, entity_id: str) -> None:
"""删除实体及其关联数据""" """删除实体及其关联数据"""
conn = self.get_conn() conn = self.get_conn()
conn.execute("DELETE FROM entity_mentions WHERE entity_id = ?", (entity_id,)) conn.execute("DELETE FROM entity_mentions WHERE entity_id = ?", (entity_id,))

View File

@@ -352,7 +352,7 @@ class DeveloperEcosystemManager:
self.db_path = db_path self.db_path = db_path
self.platform_fee_rate = 0.30 # 平台抽成比例 30% self.platform_fee_rate = 0.30 # 平台抽成比例 30%
def _get_db(self): def _get_db(self) -> None:
"""获取数据库连接""" """获取数据库连接"""
conn = sqlite3.connect(self.db_path) conn = sqlite3.connect(self.db_path)
conn.row_factory = sqlite3.Row conn.row_factory = sqlite3.Row
@@ -515,7 +515,7 @@ class DeveloperEcosystemManager:
return self.get_sdk_release(sdk_id) return self.get_sdk_release(sdk_id)
def increment_sdk_download(self, sdk_id: str): def increment_sdk_download(self, sdk_id: str) -> None:
"""增加 SDK 下载计数""" """增加 SDK 下载计数"""
with self._get_db() as conn: with self._get_db() as conn:
conn.execute( conn.execute(
@@ -785,7 +785,7 @@ class DeveloperEcosystemManager:
return self.get_template(template_id) return self.get_template(template_id)
def increment_template_install(self, template_id: str): def increment_template_install(self, template_id: str) -> None:
"""增加模板安装计数""" """增加模板安装计数"""
with self._get_db() as conn: with self._get_db() as conn:
conn.execute( conn.execute(
@@ -852,7 +852,7 @@ class DeveloperEcosystemManager:
return review return review
def _update_template_rating(self, conn, template_id: str): def _update_template_rating(self, conn, template_id: str) -> None:
"""更新模板评分""" """更新模板评分"""
row = conn.execute( row = conn.execute(
""" """
@@ -1084,7 +1084,7 @@ class DeveloperEcosystemManager:
return self.get_plugin(plugin_id) return self.get_plugin(plugin_id)
def increment_plugin_install(self, plugin_id: str, active: bool = True): def increment_plugin_install(self, plugin_id: str, active: bool = True) -> None:
"""增加插件安装计数""" """增加插件安装计数"""
with self._get_db() as conn: with self._get_db() as conn:
conn.execute( conn.execute(
@@ -1160,7 +1160,7 @@ class DeveloperEcosystemManager:
return review return review
def _update_plugin_rating(self, conn, plugin_id: str): def _update_plugin_rating(self, conn, plugin_id: str) -> None:
"""更新插件评分""" """更新插件评分"""
row = conn.execute( row = conn.execute(
""" """
@@ -1421,7 +1421,7 @@ class DeveloperEcosystemManager:
return self.get_developer_profile(developer_id) return self.get_developer_profile(developer_id)
def update_developer_stats(self, developer_id: str): def update_developer_stats(self, developer_id: str) -> None:
"""更新开发者统计信息""" """更新开发者统计信息"""
with self._get_db() as conn: with self._get_db() as conn:
# 统计插件数量 # 统计插件数量
@@ -1574,7 +1574,7 @@ class DeveloperEcosystemManager:
rows = conn.execute(query, params).fetchall() rows = conn.execute(query, params).fetchall()
return [self._row_to_code_example(row) for row in rows] return [self._row_to_code_example(row) for row in rows]
def increment_example_view(self, example_id: str): def increment_example_view(self, example_id: str) -> None:
"""增加代码示例查看计数""" """增加代码示例查看计数"""
with self._get_db() as conn: with self._get_db() as conn:
conn.execute( conn.execute(
@@ -1587,7 +1587,7 @@ class DeveloperEcosystemManager:
) )
conn.commit() conn.commit()
def increment_example_copy(self, example_id: str): def increment_example_copy(self, example_id: str) -> None:
"""增加代码示例复制计数""" """增加代码示例复制计数"""
with self._get_db() as conn: with self._get_db() as conn:
conn.execute( conn.execute(

View File

@@ -329,7 +329,7 @@ class EnterpriseManager:
conn.row_factory = sqlite3.Row conn.row_factory = sqlite3.Row
return conn return conn
def _init_db(self): def _init_db(self) -> None:
"""初始化数据库表""" """初始化数据库表"""
conn = self._get_connection() conn = self._get_connection()
try: try:
@@ -1190,7 +1190,7 @@ class EnterpriseManager:
# GET {scim_base_url}/Users # GET {scim_base_url}/Users
return [] return []
def _upsert_scim_user(self, conn: sqlite3.Connection, tenant_id: str, user_data: dict[str, Any]): def _upsert_scim_user(self, conn: sqlite3.Connection, tenant_id: str, user_data: dict[str, Any]) -> None:
"""插入或更新 SCIM 用户""" """插入或更新 SCIM 用户"""
cursor = conn.cursor() cursor = conn.cursor()

View File

@@ -22,14 +22,7 @@ try:
from reportlab.lib.pagesizes import A4 from reportlab.lib.pagesizes import A4
from reportlab.lib.styles import ParagraphStyle, getSampleStyleSheet from reportlab.lib.styles import ParagraphStyle, getSampleStyleSheet
from reportlab.lib.units import inch from reportlab.lib.units import inch
from reportlab.platypus import ( from reportlab.platypus import PageBreak, Paragraph, SimpleDocTemplate, Spacer, Table, TableStyle
PageBreak,
Paragraph,
SimpleDocTemplate,
Spacer,
Table,
TableStyle,
)
REPORTLAB_AVAILABLE = True REPORTLAB_AVAILABLE = True
except ImportError: except ImportError:
@@ -194,20 +187,16 @@ class ExportManager:
f'fill="white" stroke="#bdc3c7" rx="5"/>' f'fill="white" stroke="#bdc3c7" rx="5"/>'
) )
svg_parts.append( svg_parts.append(
f'<text x="{legend_x}" y="{legend_y}" font-size="12" font-weight="bold" ' f'<text x="{legend_x}" y="{legend_y}" font-size="12" font-weight="bold" ' f'fill="#2c3e50">实体类型</text>'
f'fill="#2c3e50">实体类型</text>'
) )
for i, (etype, color) in enumerate(type_colors.items()): for i, (etype, color) in enumerate(type_colors.items()):
if etype != "default": if etype != "default":
y_pos = legend_y + 25 + i * 20 y_pos = legend_y + 25 + i * 20
svg_parts.append( svg_parts.append(f'<circle cx="{legend_x + 10}" cy="{y_pos}" r="8" fill="{color}"/>')
f'<circle cx="{legend_x + 10}" cy="{y_pos}" r="8" fill="{color}"/>'
)
text_y = y_pos + 4 text_y = y_pos + 4
svg_parts.append( svg_parts.append(
f'<text x="{legend_x + 25}" y="{text_y}" font-size="10" ' f'<text x="{legend_x + 25}" y="{text_y}" font-size="10" ' f'fill="#2c3e50">{etype}</text>'
f'fill="#2c3e50">{etype}</text>'
) )
svg_parts.append("</svg>") svg_parts.append("</svg>")
@@ -320,7 +309,6 @@ class ExportManager:
Returns: Returns:
CSV 字符串 CSV 字符串
""" """
import csv
output = io.StringIO() output = io.StringIO()
writer = csv.writer(output) writer = csv.writer(output)
@@ -586,7 +574,7 @@ class ExportManager:
_export_manager = None _export_manager = None
def get_export_manager(db_manager=None): def get_export_manager(db_manager=None) -> None:
"""获取导出管理器实例""" """获取导出管理器实例"""
global _export_manager global _export_manager
if _export_manager is None: if _export_manager is None:

View File

@@ -369,7 +369,7 @@ class GrowthManager:
self.segment_write_key = os.getenv("SEGMENT_WRITE_KEY", "") self.segment_write_key = os.getenv("SEGMENT_WRITE_KEY", "")
self.sendgrid_api_key = os.getenv("SENDGRID_API_KEY", "") self.sendgrid_api_key = os.getenv("SENDGRID_API_KEY", "")
def _get_db(self): def _get_db(self) -> None:
"""获取数据库连接""" """获取数据库连接"""
conn = sqlite3.connect(self.db_path) conn = sqlite3.connect(self.db_path)
conn.row_factory = sqlite3.Row conn.row_factory = sqlite3.Row

View File

@@ -93,7 +93,7 @@ class ImageProcessor:
"other": "其他", "other": "其他",
} }
def __init__(self, temp_dir: str = None): def __init__(self, temp_dir: str = None) -> None:
""" """
初始化图片处理器 初始化图片处理器
@@ -103,7 +103,7 @@ class ImageProcessor:
self.temp_dir = temp_dir or os.path.join(os.getcwd(), "temp", "images") self.temp_dir = temp_dir or os.path.join(os.getcwd(), "temp", "images")
os.makedirs(self.temp_dir, exist_ok=True) os.makedirs(self.temp_dir, exist_ok=True)
def preprocess_image(self, image, image_type: str = None): def preprocess_image(self, image, image_type: str = None) -> None:
""" """
预处理图片以提高OCR质量 预处理图片以提高OCR质量
@@ -145,7 +145,7 @@ class ImageProcessor:
print(f"Image preprocessing error: {e}") print(f"Image preprocessing error: {e}")
return image return image
def _enhance_whiteboard(self, image): def _enhance_whiteboard(self, image) -> None:
"""增强白板图片""" """增强白板图片"""
# 转换为灰度 # 转换为灰度
gray = image.convert("L") gray = image.convert("L")
@@ -160,7 +160,7 @@ class ImageProcessor:
return binary.convert("L") return binary.convert("L")
def _enhance_handwritten(self, image): def _enhance_handwritten(self, image) -> None:
"""增强手写笔记图片""" """增强手写笔记图片"""
# 转换为灰度 # 转换为灰度
gray = image.convert("L") gray = image.convert("L")

View File

@@ -163,8 +163,6 @@ class KnowledgeReasoner:
content = await self._call_llm(prompt, temperature=0.3) content = await self._call_llm(prompt, temperature=0.3)
import re
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL) json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
if json_match: if json_match:
@@ -217,8 +215,6 @@ class KnowledgeReasoner:
content = await self._call_llm(prompt, temperature=0.3) content = await self._call_llm(prompt, temperature=0.3)
import re
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL) json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
if json_match: if json_match:
@@ -271,8 +267,6 @@ class KnowledgeReasoner:
content = await self._call_llm(prompt, temperature=0.3) content = await self._call_llm(prompt, temperature=0.3)
import re
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL) json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
if json_match: if json_match:
@@ -325,8 +319,6 @@ class KnowledgeReasoner:
content = await self._call_llm(prompt, temperature=0.4) content = await self._call_llm(prompt, temperature=0.4)
import re
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL) json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
if json_match: if json_match:
@@ -474,8 +466,6 @@ class KnowledgeReasoner:
content = await self._call_llm(prompt, temperature=0.3) content = await self._call_llm(prompt, temperature=0.3)
import re
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL) json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
if json_match: if json_match:

View File

@@ -204,15 +204,13 @@ class LLMClient:
messages = [ChatMessage(role="user", content=prompt)] messages = [ChatMessage(role="user", content=prompt)]
content = await self.chat(messages, temperature=0.1) content = await self.chat(messages, temperature=0.1)
import re
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL) json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
if not json_match: if not json_match:
return {"intent": "unknown", "explanation": "无法解析指令"} return {"intent": "unknown", "explanation": "无法解析指令"}
try: try:
return json.loads(json_match.group()) return json.loads(json_match.group())
except BaseException: except (json.JSONDecodeError, KeyError, TypeError):
return {"intent": "unknown", "explanation": "解析失败"} return {"intent": "unknown", "explanation": "解析失败"}
async def analyze_entity_evolution(self, entity_name: str, mentions: list[dict]) -> str: async def analyze_entity_evolution(self, entity_name: str, mentions: list[dict]) -> str:

View File

@@ -938,9 +938,7 @@ class LocalizationManager:
finally: finally:
self._close_if_file_db(conn) self._close_if_file_db(conn)
def get_translation( def get_translation(self, key: str, language: str, namespace: str = "common", fallback: bool = True) -> str | None:
self, key: str, language: str, namespace: str = "common", fallback: bool = True
) -> str | None:
conn = self._get_connection() conn = self._get_connection()
try: try:
cursor = conn.cursor() cursor = conn.cursor()

File diff suppressed because it is too large Load Diff

View File

@@ -80,7 +80,7 @@ class MultimodalEntityLinker:
# 模态类型 # 模态类型
MODALITIES = ["audio", "video", "image", "document"] MODALITIES = ["audio", "video", "image", "document"]
def __init__(self, similarity_threshold: float = 0.85): def __init__(self, similarity_threshold: float = 0.85) -> None:
""" """
初始化多模态实体关联器 初始化多模态实体关联器

View File

@@ -94,7 +94,7 @@ class VideoProcessingResult:
class MultimodalProcessor: class MultimodalProcessor:
"""多模态处理器 - 处理视频文件""" """多模态处理器 - 处理视频文件"""
def __init__(self, temp_dir: str = None, frame_interval: int = 5): def __init__(self, temp_dir: str = None, frame_interval: int = 5) -> None:
""" """
初始化多模态处理器 初始化多模态处理器
@@ -401,7 +401,7 @@ class MultimodalProcessor:
error_message=str(e), error_message=str(e),
) )
def cleanup(self, video_id: str = None): def cleanup(self, video_id: str = None) -> None:
""" """
清理临时文件 清理临时文件

View File

@@ -107,7 +107,7 @@ class Neo4jManager:
self._connect() self._connect()
def _connect(self): def _connect(self) -> None:
"""建立 Neo4j 连接""" """建立 Neo4j 连接"""
if not NEO4J_AVAILABLE: if not NEO4J_AVAILABLE:
return return
@@ -121,7 +121,7 @@ class Neo4jManager:
logger.error(f"Failed to connect to Neo4j: {e}") logger.error(f"Failed to connect to Neo4j: {e}")
self._driver = None self._driver = None
def close(self): def close(self) -> None:
"""关闭连接""" """关闭连接"""
if self._driver: if self._driver:
self._driver.close() self._driver.close()
@@ -137,7 +137,7 @@ class Neo4jManager:
except BaseException: except BaseException:
return False return False
def init_schema(self): def init_schema(self) -> None:
"""初始化图数据库 Schema约束和索引""" """初始化图数据库 Schema约束和索引"""
if not self._driver: if not self._driver:
logger.error("Neo4j not connected") logger.error("Neo4j not connected")
@@ -178,7 +178,7 @@ class Neo4jManager:
# ==================== 数据同步 ==================== # ==================== 数据同步 ====================
def sync_project(self, project_id: str, project_name: str, project_description: str = ""): def sync_project(self, project_id: str, project_name: str, project_description: str = "") -> None:
"""同步项目节点到 Neo4j""" """同步项目节点到 Neo4j"""
if not self._driver: if not self._driver:
return return
@@ -196,7 +196,7 @@ class Neo4jManager:
description=project_description, description=project_description,
) )
def sync_entity(self, entity: GraphEntity): def sync_entity(self, entity: GraphEntity) -> None:
"""同步单个实体到 Neo4j""" """同步单个实体到 Neo4j"""
if not self._driver: if not self._driver:
return return
@@ -225,7 +225,7 @@ class Neo4jManager:
properties=json.dumps(entity.properties), properties=json.dumps(entity.properties),
) )
def sync_entities_batch(self, entities: list[GraphEntity]): def sync_entities_batch(self, entities: list[GraphEntity]) -> None:
"""批量同步实体到 Neo4j""" """批量同步实体到 Neo4j"""
if not self._driver or not entities: if not self._driver or not entities:
return return
@@ -262,7 +262,7 @@ class Neo4jManager:
entities=entities_data, entities=entities_data,
) )
def sync_relation(self, relation: GraphRelation): def sync_relation(self, relation: GraphRelation) -> None:
"""同步单个关系到 Neo4j""" """同步单个关系到 Neo4j"""
if not self._driver: if not self._driver:
return return
@@ -286,7 +286,7 @@ class Neo4jManager:
properties=json.dumps(relation.properties), properties=json.dumps(relation.properties),
) )
def sync_relations_batch(self, relations: list[GraphRelation]): def sync_relations_batch(self, relations: list[GraphRelation]) -> None:
"""批量同步关系到 Neo4j""" """批量同步关系到 Neo4j"""
if not self._driver or not relations: if not self._driver or not relations:
return return
@@ -318,7 +318,7 @@ class Neo4jManager:
relations=relations_data, relations=relations_data,
) )
def delete_entity(self, entity_id: str): def delete_entity(self, entity_id: str) -> None:
"""从 Neo4j 删除实体及其关系""" """从 Neo4j 删除实体及其关系"""
if not self._driver: if not self._driver:
return return
@@ -332,7 +332,7 @@ class Neo4jManager:
id=entity_id, id=entity_id,
) )
def delete_project(self, project_id: str): def delete_project(self, project_id: str) -> None:
"""从 Neo4j 删除项目及其所有实体和关系""" """从 Neo4j 删除项目及其所有实体和关系"""
if not self._driver: if not self._driver:
return return
@@ -949,7 +949,7 @@ def get_neo4j_manager() -> Neo4jManager:
return _neo4j_manager return _neo4j_manager
def close_neo4j_manager(): def close_neo4j_manager() -> None:
"""关闭 Neo4j 连接""" """关闭 Neo4j 连接"""
global _neo4j_manager global _neo4j_manager
if _neo4j_manager: if _neo4j_manager:
@@ -958,7 +958,7 @@ def close_neo4j_manager():
# 便捷函数 # 便捷函数
def sync_project_to_neo4j(project_id: str, project_name: str, entities: list[dict], relations: list[dict]): def sync_project_to_neo4j(project_id: str, project_name: str, entities: list[dict], relations: list[dict]) -> None:
""" """
同步整个项目到 Neo4j 同步整个项目到 Neo4j

View File

@@ -456,13 +456,13 @@ class OpsManager:
self._evaluator_thread = None self._evaluator_thread = None
self._register_default_evaluators() self._register_default_evaluators()
def _get_db(self): def _get_db(self) -> None:
"""获取数据库连接""" """获取数据库连接"""
conn = sqlite3.connect(self.db_path) conn = sqlite3.connect(self.db_path)
conn.row_factory = sqlite3.Row conn.row_factory = sqlite3.Row
return conn return conn
def _register_default_evaluators(self): def _register_default_evaluators(self) -> None:
"""注册默认的告警评估器""" """注册默认的告警评估器"""
self._alert_evaluators[AlertRuleType.THRESHOLD.value] = self._evaluate_threshold_rule self._alert_evaluators[AlertRuleType.THRESHOLD.value] = self._evaluate_threshold_rule
self._alert_evaluators[AlertRuleType.ANOMALY.value] = self._evaluate_anomaly_rule self._alert_evaluators[AlertRuleType.ANOMALY.value] = self._evaluate_anomaly_rule
@@ -1249,7 +1249,7 @@ class OpsManager:
return self.get_alert(alert_id) return self.get_alert(alert_id)
def _increment_suppression_count(self, alert_id: str): def _increment_suppression_count(self, alert_id: str) -> None:
"""增加告警抑制计数""" """增加告警抑制计数"""
with self._get_db() as conn: with self._get_db() as conn:
conn.execute( conn.execute(
@@ -1262,7 +1262,7 @@ class OpsManager:
) )
conn.commit() conn.commit()
def _update_alert_notification_status(self, alert_id: str, channel_id: str, success: bool): def _update_alert_notification_status(self, alert_id: str, channel_id: str, success: bool) -> None:
"""更新告警通知状态""" """更新告警通知状态"""
with self._get_db() as conn: with self._get_db() as conn:
row = conn.execute("SELECT notification_sent FROM alerts WHERE id = ?", (alert_id,)).fetchone() row = conn.execute("SELECT notification_sent FROM alerts WHERE id = ?", (alert_id,)).fetchone()
@@ -1276,7 +1276,7 @@ class OpsManager:
) )
conn.commit() conn.commit()
def _update_channel_stats(self, channel_id: str, success: bool): def _update_channel_stats(self, channel_id: str, success: bool) -> None:
"""更新渠道统计""" """更新渠道统计"""
now = datetime.now().isoformat() now = datetime.now().isoformat()
@@ -1769,9 +1769,7 @@ class OpsManager:
return self._row_to_scaling_event(row) return self._row_to_scaling_event(row)
return None return None
def update_scaling_event_status( def update_scaling_event_status(self, event_id: str, status: str, error_message: str = None) -> ScalingEvent | None:
self, event_id: str, status: str, error_message: str = None
) -> ScalingEvent | None:
"""更新扩缩容事件状态""" """更新扩缩容事件状态"""
now = datetime.now().isoformat() now = datetime.now().isoformat()
@@ -2339,7 +2337,7 @@ class OpsManager:
return record return record
def _complete_backup(self, record_id: str, size_bytes: int, checksum: str = None): def _complete_backup(self, record_id: str, size_bytes: int, checksum: str = None) -> None:
"""完成备份""" """完成备份"""
now = datetime.now().isoformat() now = datetime.now().isoformat()
checksum = checksum or hashlib.sha256(str(time.time()).encode()).hexdigest()[:16] checksum = checksum or hashlib.sha256(str(time.time()).encode()).hexdigest()[:16]

View File

@@ -37,7 +37,7 @@ class OSSUploader:
url = self.bucket.sign_url("GET", object_name, 3600) url = self.bucket.sign_url("GET", object_name, 3600)
return url, object_name return url, object_name
def delete_object(self, object_name: str): def delete_object(self, object_name: str) -> None:
"""删除 OSS 对象""" """删除 OSS 对象"""
self.bucket.delete_object(object_name) self.bucket.delete_object(object_name)

View File

@@ -55,7 +55,7 @@ class CacheStats:
expired: int = 0 expired: int = 0
hit_rate: float = 0.0 hit_rate: float = 0.0
def update_hit_rate(self): def update_hit_rate(self) -> None:
"""更新命中率""" """更新命中率"""
if self.total_requests > 0: if self.total_requests > 0:
self.hit_rate = round(self.hits / self.total_requests, 4) self.hit_rate = round(self.hits / self.total_requests, 4)
@@ -194,7 +194,7 @@ class CacheManager:
# 初始化缓存统计表 # 初始化缓存统计表
self._init_cache_tables() self._init_cache_tables()
def _init_cache_tables(self): def _init_cache_tables(self) -> None:
"""初始化缓存统计表""" """初始化缓存统计表"""
conn = sqlite3.connect(self.db_path) conn = sqlite3.connect(self.db_path)
@@ -234,7 +234,7 @@ class CacheManager:
except BaseException: except BaseException:
return 1024 # 默认估算 return 1024 # 默认估算
def _evict_lru(self, required_space: int = 0): def _evict_lru(self, required_space: int = 0) -> None:
"""LRU 淘汰策略""" """LRU 淘汰策略"""
with self.cache_lock: with self.cache_lock:
while self.current_memory_size + required_space > self.max_memory_size and self.memory_cache: while self.current_memory_size + required_space > self.max_memory_size and self.memory_cache:
@@ -444,7 +444,7 @@ class CacheManager:
return stats return stats
def save_stats(self): def save_stats(self) -> None:
"""保存缓存统计到数据库""" """保存缓存统计到数据库"""
conn = sqlite3.connect(self.db_path) conn = sqlite3.connect(self.db_path)
@@ -618,7 +618,7 @@ class DatabaseSharding:
# 初始化分片 # 初始化分片
self._init_shards() self._init_shards()
def _init_shards(self): def _init_shards(self) -> None:
"""初始化分片""" """初始化分片"""
# 计算每个分片的 key 范围 # 计算每个分片的 key 范围
chars = "0123456789abcdef" chars = "0123456789abcdef"
@@ -645,7 +645,7 @@ class DatabaseSharding:
if not os.path.exists(db_path): if not os.path.exists(db_path):
self._create_shard_db(db_path) self._create_shard_db(db_path)
def _create_shard_db(self, db_path: str): def _create_shard_db(self, db_path: str) -> None:
"""创建分片数据库""" """创建分片数据库"""
conn = sqlite3.connect(db_path) conn = sqlite3.connect(db_path)
@@ -792,7 +792,7 @@ class DatabaseSharding:
print(f"迁移失败: {e}") print(f"迁移失败: {e}")
return False return False
def _update_shard_stats(self, shard_id: str): def _update_shard_stats(self, shard_id: str) -> None:
"""更新分片统计""" """更新分片统计"""
shard_info = self.shard_map.get(shard_id) shard_info = self.shard_map.get(shard_id)
if not shard_info: if not shard_info:
@@ -923,7 +923,7 @@ class TaskQueue:
except Exception as e: except Exception as e:
print(f"Celery 初始化失败,使用内存任务队列: {e}") print(f"Celery 初始化失败,使用内存任务队列: {e}")
def _init_task_tables(self): def _init_task_tables(self) -> None:
"""初始化任务队列表""" """初始化任务队列表"""
conn = sqlite3.connect(self.db_path) conn = sqlite3.connect(self.db_path)
@@ -953,7 +953,7 @@ class TaskQueue:
"""检查任务队列是否可用""" """检查任务队列是否可用"""
return self.use_celery or True # 内存模式也可用 return self.use_celery or True # 内存模式也可用
def register_handler(self, task_type: str, handler: Callable): def register_handler(self, task_type: str, handler: Callable) -> None:
"""注册任务处理器""" """注册任务处理器"""
self.task_handlers[task_type] = handler self.task_handlers[task_type] = handler
@@ -1014,7 +1014,7 @@ class TaskQueue:
return task_id return task_id
def _execute_task(self, task_id: str): def _execute_task(self, task_id: str) -> None:
"""执行任务(内存模式)""" """执行任务(内存模式)"""
with self.task_lock: with self.task_lock:
task = self.tasks.get(task_id) task = self.tasks.get(task_id)
@@ -1055,7 +1055,7 @@ class TaskQueue:
self._update_task_status(task) self._update_task_status(task)
def _save_task(self, task: TaskInfo): def _save_task(self, task: TaskInfo) -> None:
"""保存任务到数据库""" """保存任务到数据库"""
conn = sqlite3.connect(self.db_path) conn = sqlite3.connect(self.db_path)
@@ -1084,7 +1084,7 @@ class TaskQueue:
conn.commit() conn.commit()
conn.close() conn.close()
def _update_task_status(self, task: TaskInfo): def _update_task_status(self, task: TaskInfo) -> None:
"""更新任务状态""" """更新任务状态"""
conn = sqlite3.connect(self.db_path) conn = sqlite3.connect(self.db_path)
@@ -1143,9 +1143,7 @@ class TaskQueue:
with self.task_lock: with self.task_lock:
return self.tasks.get(task_id) return self.tasks.get(task_id)
def list_tasks( def list_tasks(self, status: str | None = None, task_type: str | None = None, limit: int = 100) -> list[TaskInfo]:
self, status: str | None = None, task_type: str | None = None, limit: int = 100
) -> list[TaskInfo]:
"""列出任务""" """列出任务"""
conn = sqlite3.connect(self.db_path) conn = sqlite3.connect(self.db_path)
conn.row_factory = sqlite3.Row conn.row_factory = sqlite3.Row
@@ -1333,7 +1331,7 @@ class PerformanceMonitor:
if metric_type == "db_query" and duration_ms > self.slow_query_threshold: if metric_type == "db_query" and duration_ms > self.slow_query_threshold:
self._record_slow_query(metric) self._record_slow_query(metric)
def _flush_metrics(self): def _flush_metrics(self) -> None:
"""将缓冲区指标写入数据库""" """将缓冲区指标写入数据库"""
if not self.metrics_buffer: if not self.metrics_buffer:
return return
@@ -1362,12 +1360,12 @@ class PerformanceMonitor:
self.metrics_buffer = [] self.metrics_buffer = []
def _record_slow_query(self, metric: PerformanceMetric): def _record_slow_query(self, metric: PerformanceMetric) -> None:
"""记录慢查询""" """记录慢查询"""
# 可以发送到专门的慢查询日志或监控系统 # 可以发送到专门的慢查询日志或监控系统
print(f"[SLOW QUERY] {metric.endpoint}: {metric.duration_ms}ms") print(f"[SLOW QUERY] {metric.endpoint}: {metric.duration_ms}ms")
def _trigger_alert(self, metric: PerformanceMetric): def _trigger_alert(self, metric: PerformanceMetric) -> None:
"""触发告警""" """触发告警"""
alert_data = { alert_data = {
"type": "performance_alert", "type": "performance_alert",
@@ -1382,7 +1380,7 @@ class PerformanceMonitor:
except Exception as e: except Exception as e:
print(f"告警处理失败: {e}") print(f"告警处理失败: {e}")
def register_alert_handler(self, handler: Callable): def register_alert_handler(self, handler: Callable) -> None:
"""注册告警处理器""" """注册告警处理器"""
self.alert_handlers.append(handler) self.alert_handlers.append(handler)
@@ -1585,7 +1583,9 @@ class PerformanceMonitor:
# ==================== 性能装饰器 ==================== # ==================== 性能装饰器 ====================
def cached(cache_manager: CacheManager, key_prefix: str = "", ttl: int = 3600, key_func: Callable | None = None): def cached(
cache_manager: CacheManager, key_prefix: str = "", ttl: int = 3600, key_func: Callable | None = None
) -> None:
""" """
缓存装饰器 缓存装饰器
@@ -1598,7 +1598,7 @@ def cached(cache_manager: CacheManager, key_prefix: str = "", ttl: int = 3600, k
def decorator(func: Callable) -> Callable: def decorator(func: Callable) -> Callable:
@wraps(func) @wraps(func)
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs) -> None:
# 生成缓存键 # 生成缓存键
if key_func: if key_func:
cache_key = key_func(*args, **kwargs) cache_key = key_func(*args, **kwargs)
@@ -1625,7 +1625,7 @@ def cached(cache_manager: CacheManager, key_prefix: str = "", ttl: int = 3600, k
return decorator return decorator
def monitored(monitor: PerformanceMonitor, metric_type: str, endpoint: str | None = None): def monitored(monitor: PerformanceMonitor, metric_type: str, endpoint: str | None = None) -> None:
""" """
性能监控装饰器 性能监控装饰器

View File

@@ -163,7 +163,7 @@ class PluginManager:
self._handlers = {} self._handlers = {}
self._register_default_handlers() self._register_default_handlers()
def _register_default_handlers(self): def _register_default_handlers(self) -> None:
"""注册默认处理器""" """注册默认处理器"""
self._handlers[PluginType.CHROME_EXTENSION] = ChromeExtensionHandler(self) self._handlers[PluginType.CHROME_EXTENSION] = ChromeExtensionHandler(self)
self._handlers[PluginType.FEISHU_BOT] = BotHandler(self, "feishu") self._handlers[PluginType.FEISHU_BOT] = BotHandler(self, "feishu")
@@ -371,7 +371,7 @@ class PluginManager:
return cursor.rowcount > 0 return cursor.rowcount > 0
def record_plugin_usage(self, plugin_id: str): def record_plugin_usage(self, plugin_id: str) -> None:
"""记录插件使用""" """记录插件使用"""
conn = self.db.get_conn() conn = self.db.get_conn()
now = datetime.now().isoformat() now = datetime.now().isoformat()
@@ -826,9 +826,6 @@ class BotHandler:
async def _send_feishu_message(self, session: BotSession, message: str, msg_type: str) -> bool: async def _send_feishu_message(self, session: BotSession, message: str, msg_type: str) -> bool:
"""发送飞书消息""" """发送飞书消息"""
import base64
import hashlib
timestamp = str(int(time.time())) timestamp = str(int(time.time()))
# 生成签名 # 生成签名
@@ -851,9 +848,6 @@ class BotHandler:
async def _send_dingtalk_message(self, session: BotSession, message: str, msg_type: str) -> bool: async def _send_dingtalk_message(self, session: BotSession, message: str, msg_type: str) -> bool:
"""发送钉钉消息""" """发送钉钉消息"""
import base64
import hashlib
timestamp = str(round(time.time() * 1000)) timestamp = str(round(time.time() * 1000))
# 生成签名 # 生成签名
@@ -1358,7 +1352,7 @@ class WebDAVSyncManager:
_plugin_manager = None _plugin_manager = None
def get_plugin_manager(db_manager=None): def get_plugin_manager(db_manager=None) -> None:
"""获取 PluginManager 单例""" """获取 PluginManager 单例"""
global _plugin_manager global _plugin_manager
if _plugin_manager is None: if _plugin_manager is None:

View File

@@ -56,7 +56,7 @@ class SlidingWindowCounter:
self._cleanup_old(now) self._cleanup_old(now)
return sum(self.requests.values()) return sum(self.requests.values())
def _cleanup_old(self, now: int): def _cleanup_old(self, now: int) -> None:
"""清理过期的请求记录 - 使用独立锁避免竞态条件""" """清理过期的请求记录 - 使用独立锁避免竞态条件"""
cutoff = now - self.window_size cutoff = now - self.window_size
old_keys = [k for k in list(self.requests.keys()) if k < cutoff] old_keys = [k for k in list(self.requests.keys()) if k < cutoff]
@@ -67,7 +67,7 @@ class SlidingWindowCounter:
class RateLimiter: class RateLimiter:
"""API 限流器""" """API 限流器"""
def __init__(self): def __init__(self) -> None:
# key -> SlidingWindowCounter # key -> SlidingWindowCounter
self.counters: dict[str, SlidingWindowCounter] = {} self.counters: dict[str, SlidingWindowCounter] = {}
# key -> RateLimitConfig # key -> RateLimitConfig
@@ -143,7 +143,7 @@ class RateLimiter:
retry_after=max(0, config.window_size) if current_count >= config.requests_per_minute else 0, retry_after=max(0, config.window_size) if current_count >= config.requests_per_minute else 0,
) )
def reset(self, key: str | None = None): def reset(self, key: str | None = None) -> None:
"""重置限流计数器""" """重置限流计数器"""
if key: if key:
self.counters.pop(key, None) self.counters.pop(key, None)
@@ -166,7 +166,7 @@ def get_rate_limiter() -> RateLimiter:
# 限流装饰器(用于函数级别限流) # 限流装饰器(用于函数级别限流)
def rate_limit(requests_per_minute: int = 60, key_func: Callable | None = None): def rate_limit(requests_per_minute: int = 60, key_func: Callable | None = None) -> None:
""" """
限流装饰器 限流装饰器

File diff suppressed because it is too large Load Diff

View File

@@ -19,6 +19,7 @@ try:
from cryptography.fernet import Fernet from cryptography.fernet import Fernet
from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
CRYPTO_AVAILABLE = True CRYPTO_AVAILABLE = True
except ImportError: except ImportError:
CRYPTO_AVAILABLE = False CRYPTO_AVAILABLE = False
@@ -27,6 +28,7 @@ except ImportError:
class AuditActionType(Enum): class AuditActionType(Enum):
"""审计动作类型""" """审计动作类型"""
CREATE = "create" CREATE = "create"
READ = "read" READ = "read"
UPDATE = "update" UPDATE = "update"
@@ -49,26 +51,29 @@ class AuditActionType(Enum):
class DataSensitivityLevel(Enum): class DataSensitivityLevel(Enum):
"""数据敏感度级别""" """数据敏感度级别"""
PUBLIC = "public" # 公开
INTERNAL = "internal" # 内部 PUBLIC = "public" # 公开
INTERNAL = "internal" # 内部
CONFIDENTIAL = "confidential" # 机密 CONFIDENTIAL = "confidential" # 机密
SECRET = "secret" # 绝密 SECRET = "secret" # 绝密
class MaskingRuleType(Enum): class MaskingRuleType(Enum):
"""脱敏规则类型""" """脱敏规则类型"""
PHONE = "phone" # 手机号
EMAIL = "email" # 邮箱 PHONE = "phone" # 手机号
ID_CARD = "id_card" # 身份证号 EMAIL = "email" # 邮箱
BANK_CARD = "bank_card" # 银行卡 ID_CARD = "id_card" # 身份证
NAME = "name" # 姓名 BANK_CARD = "bank_card" # 银行卡号
ADDRESS = "address" # 地址 NAME = "name" # 姓名
CUSTOM = "custom" # 自定义 ADDRESS = "address" # 地址
CUSTOM = "custom" # 自定义
@dataclass @dataclass
class AuditLog: class AuditLog:
"""审计日志条目""" """审计日志条目"""
id: str id: str
action_type: str action_type: str
user_id: str | None = None user_id: str | None = None
@@ -90,6 +95,7 @@ class AuditLog:
@dataclass @dataclass
class EncryptionConfig: class EncryptionConfig:
"""加密配置""" """加密配置"""
id: str id: str
project_id: str project_id: str
is_enabled: bool = False is_enabled: bool = False
@@ -107,6 +113,7 @@ class EncryptionConfig:
@dataclass @dataclass
class MaskingRule: class MaskingRule:
"""脱敏规则""" """脱敏规则"""
id: str id: str
project_id: str project_id: str
name: str name: str
@@ -126,6 +133,7 @@ class MaskingRule:
@dataclass @dataclass
class DataAccessPolicy: class DataAccessPolicy:
"""数据访问策略""" """数据访问策略"""
id: str id: str
project_id: str project_id: str
name: str name: str
@@ -147,6 +155,7 @@ class DataAccessPolicy:
@dataclass @dataclass
class AccessRequest: class AccessRequest:
"""访问请求(用于需要审批的访问)""" """访问请求(用于需要审批的访问)"""
id: str id: str
policy_id: str policy_id: str
user_id: str user_id: str
@@ -166,30 +175,15 @@ class SecurityManager:
# 预定义脱敏规则 # 预定义脱敏规则
DEFAULT_MASKING_RULES = { DEFAULT_MASKING_RULES = {
MaskingRuleType.PHONE: { MaskingRuleType.PHONE: {"pattern": r"(\d{3})\d{4}(\d{4})", "replacement": r"\1****\2"},
"pattern": r"(\d{3})\d{4}(\d{4})", MaskingRuleType.EMAIL: {"pattern": r"(\w{1,3})\w+(@\w+\.\w+)", "replacement": r"\1***\2"},
"replacement": r"\1****\2" MaskingRuleType.ID_CARD: {"pattern": r"(\d{6})\d{8}(\d{4})", "replacement": r"\1********\2"},
}, MaskingRuleType.BANK_CARD: {"pattern": r"(\d{4})\d+(\d{4})", "replacement": r"\1 **** **** \2"},
MaskingRuleType.EMAIL: { MaskingRuleType.NAME: {"pattern": r"([\u4e00-\u9fa5])[\u4e00-\u9fa5]+", "replacement": r"\1**"},
"pattern": r"(\w{1,3})\w+(@\w+\.\w+)",
"replacement": r"\1***\2"
},
MaskingRuleType.ID_CARD: {
"pattern": r"(\d{6})\d{8}(\d{4})",
"replacement": r"\1********\2"
},
MaskingRuleType.BANK_CARD: {
"pattern": r"(\d{4})\d+(\d{4})",
"replacement": r"\1 **** **** \2"
},
MaskingRuleType.NAME: {
"pattern": r"([\u4e00-\u9fa5])[\u4e00-\u9fa5]+",
"replacement": r"\1**"
},
MaskingRuleType.ADDRESS: { MaskingRuleType.ADDRESS: {
"pattern": r"([\u4e00-\u9fa5]{2,})([\u4e00-\u9fa5]+路|街|巷|号)(.+)", "pattern": r"([\u4e00-\u9fa5]{2,})([\u4e00-\u9fa5]+路|街|巷|号)(.+)",
"replacement": r"\1\2***" "replacement": r"\1\2***",
} },
} }
def __init__(self, db_path: str = "insightflow.db"): def __init__(self, db_path: str = "insightflow.db"):
@@ -200,7 +194,7 @@ class SecurityManager:
self._local = {} self._local = {}
self._init_db() self._init_db()
def _init_db(self): def _init_db(self) -> None:
"""初始化数据库表""" """初始化数据库表"""
conn = sqlite3.connect(self.db_path) conn = sqlite3.connect(self.db_path)
cursor = conn.cursor() cursor = conn.cursor()
@@ -308,9 +302,7 @@ class SecurityManager:
def _generate_id(self) -> str: def _generate_id(self) -> str:
"""生成唯一ID""" """生成唯一ID"""
return hashlib.sha256( return hashlib.sha256(f"{datetime.now().isoformat()}{secrets.token_hex(16)}".encode()).hexdigest()[:32]
f"{datetime.now().isoformat()}{secrets.token_hex(16)}".encode()
).hexdigest()[:32]
# ==================== 审计日志 ==================== # ==================== 审计日志 ====================
@@ -326,7 +318,7 @@ class SecurityManager:
before_value: str | None = None, before_value: str | None = None,
after_value: str | None = None, after_value: str | None = None,
success: bool = True, success: bool = True,
error_message: str | None = None error_message: str | None = None,
) -> AuditLog: ) -> AuditLog:
"""记录审计日志""" """记录审计日志"""
log = AuditLog( log = AuditLog(
@@ -341,22 +333,34 @@ class SecurityManager:
before_value=before_value, before_value=before_value,
after_value=after_value, after_value=after_value,
success=success, success=success,
error_message=error_message error_message=error_message,
) )
conn = sqlite3.connect(self.db_path) conn = sqlite3.connect(self.db_path)
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute(""" cursor.execute(
"""
INSERT INTO audit_logs INSERT INTO audit_logs
(id, action_type, user_id, user_ip, user_agent, resource_type, resource_id, (id, action_type, user_id, user_ip, user_agent, resource_type, resource_id,
action_details, before_value, after_value, success, error_message, created_at) action_details, before_value, after_value, success, error_message, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""", ( """,
log.id, log.action_type, log.user_id, log.user_ip, log.user_agent, (
log.resource_type, log.resource_id, log.action_details, log.id,
log.before_value, log.after_value, int(log.success), log.action_type,
log.error_message, log.created_at log.user_id,
)) log.user_ip,
log.user_agent,
log.resource_type,
log.resource_id,
log.action_details,
log.before_value,
log.after_value,
int(log.success),
log.error_message,
log.created_at,
),
)
conn.commit() conn.commit()
conn.close() conn.close()
@@ -372,7 +376,7 @@ class SecurityManager:
end_time: str | None = None, end_time: str | None = None,
success: bool | None = None, success: bool | None = None,
limit: int = 100, limit: int = 100,
offset: int = 0 offset: int = 0,
) -> list[AuditLog]: ) -> list[AuditLog]:
"""查询审计日志""" """查询审计日志"""
conn = sqlite3.connect(self.db_path) conn = sqlite3.connect(self.db_path)
@@ -429,18 +433,14 @@ class SecurityManager:
after_value=row[9], after_value=row[9],
success=bool(row[10]), success=bool(row[10]),
error_message=row[11], error_message=row[11],
created_at=row[12] created_at=row[12],
) )
logs.append(log) logs.append(log)
conn.close() conn.close()
return logs return logs
def get_audit_stats( def get_audit_stats(self, start_time: str | None = None, end_time: str | None = None) -> dict[str, Any]:
self,
start_time: str | None = None,
end_time: str | None = None
) -> dict[str, Any]:
"""获取审计统计""" """获取审计统计"""
conn = sqlite3.connect(self.db_path) conn = sqlite3.connect(self.db_path)
cursor = conn.cursor() cursor = conn.cursor()
@@ -460,12 +460,7 @@ class SecurityManager:
cursor.execute(query, params) cursor.execute(query, params)
rows = cursor.fetchall() rows = cursor.fetchall()
stats = { stats = {"total_actions": 0, "success_count": 0, "failure_count": 0, "action_breakdown": {}}
"total_actions": 0,
"success_count": 0,
"failure_count": 0,
"action_breakdown": {}
}
for action_type, success, count in rows: for action_type, success, count in rows:
stats["total_actions"] += count stats["total_actions"] += count
@@ -500,11 +495,7 @@ class SecurityManager:
) )
return base64.urlsafe_b64encode(kdf.derive(password.encode())) return base64.urlsafe_b64encode(kdf.derive(password.encode()))
def enable_encryption( def enable_encryption(self, project_id: str, master_password: str) -> EncryptionConfig:
self,
project_id: str,
master_password: str
) -> EncryptionConfig:
"""启用项目加密""" """启用项目加密"""
if not CRYPTO_AVAILABLE: if not CRYPTO_AVAILABLE:
raise RuntimeError("cryptography library not available") raise RuntimeError("cryptography library not available")
@@ -523,43 +514,54 @@ class SecurityManager:
encryption_type="aes-256-gcm", encryption_type="aes-256-gcm",
key_derivation="pbkdf2", key_derivation="pbkdf2",
master_key_hash=key_hash, master_key_hash=key_hash,
salt=salt salt=salt,
) )
conn = sqlite3.connect(self.db_path) conn = sqlite3.connect(self.db_path)
cursor = conn.cursor() cursor = conn.cursor()
# 检查是否已存在配置 # 检查是否已存在配置
cursor.execute( cursor.execute("SELECT id FROM encryption_configs WHERE project_id = ?", (project_id,))
"SELECT id FROM encryption_configs WHERE project_id = ?",
(project_id,)
)
existing = cursor.fetchone() existing = cursor.fetchone()
if existing: if existing:
cursor.execute(""" cursor.execute(
"""
UPDATE encryption_configs UPDATE encryption_configs
SET is_enabled = 1, encryption_type = ?, key_derivation = ?, SET is_enabled = 1, encryption_type = ?, key_derivation = ?,
master_key_hash = ?, salt = ?, updated_at = ? master_key_hash = ?, salt = ?, updated_at = ?
WHERE project_id = ? WHERE project_id = ?
""", ( """,
config.encryption_type, config.key_derivation, (
config.master_key_hash, config.salt, config.encryption_type,
config.updated_at, project_id config.key_derivation,
)) config.master_key_hash,
config.salt,
config.updated_at,
project_id,
),
)
config.id = existing[0] config.id = existing[0]
else: else:
cursor.execute(""" cursor.execute(
"""
INSERT INTO encryption_configs INSERT INTO encryption_configs
(id, project_id, is_enabled, encryption_type, key_derivation, (id, project_id, is_enabled, encryption_type, key_derivation,
master_key_hash, salt, created_at, updated_at) master_key_hash, salt, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
""", ( """,
config.id, config.project_id, int(config.is_enabled), (
config.encryption_type, config.key_derivation, config.id,
config.master_key_hash, config.salt, config.project_id,
config.created_at, config.updated_at int(config.is_enabled),
)) config.encryption_type,
config.key_derivation,
config.master_key_hash,
config.salt,
config.created_at,
config.updated_at,
),
)
conn.commit() conn.commit()
conn.close() conn.close()
@@ -569,16 +571,12 @@ class SecurityManager:
action_type=AuditActionType.ENCRYPTION_ENABLE, action_type=AuditActionType.ENCRYPTION_ENABLE,
resource_type="project", resource_type="project",
resource_id=project_id, resource_id=project_id,
action_details={"encryption_type": config.encryption_type} action_details={"encryption_type": config.encryption_type},
) )
return config return config
def disable_encryption( def disable_encryption(self, project_id: str, master_password: str) -> bool:
self,
project_id: str,
master_password: str
) -> bool:
"""禁用项目加密""" """禁用项目加密"""
# 验证密码 # 验证密码
if not self.verify_encryption_password(project_id, master_password): if not self.verify_encryption_password(project_id, master_password):
@@ -587,29 +585,24 @@ class SecurityManager:
conn = sqlite3.connect(self.db_path) conn = sqlite3.connect(self.db_path)
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute(""" cursor.execute(
"""
UPDATE encryption_configs UPDATE encryption_configs
SET is_enabled = 0, updated_at = ? SET is_enabled = 0, updated_at = ?
WHERE project_id = ? WHERE project_id = ?
""", (datetime.now().isoformat(), project_id)) """,
(datetime.now().isoformat(), project_id),
)
conn.commit() conn.commit()
conn.close() conn.close()
# 记录审计日志 # 记录审计日志
self.log_audit( self.log_audit(action_type=AuditActionType.ENCRYPTION_DISABLE, resource_type="project", resource_id=project_id)
action_type=AuditActionType.ENCRYPTION_DISABLE,
resource_type="project",
resource_id=project_id
)
return True return True
def verify_encryption_password( def verify_encryption_password(self, project_id: str, password: str) -> bool:
self,
project_id: str,
password: str
) -> bool:
"""验证加密密码""" """验证加密密码"""
if not CRYPTO_AVAILABLE: if not CRYPTO_AVAILABLE:
return False return False
@@ -617,10 +610,7 @@ class SecurityManager:
conn = sqlite3.connect(self.db_path) conn = sqlite3.connect(self.db_path)
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute( cursor.execute("SELECT master_key_hash, salt FROM encryption_configs WHERE project_id = ?", (project_id,))
"SELECT master_key_hash, salt FROM encryption_configs WHERE project_id = ?",
(project_id,)
)
row = cursor.fetchone() row = cursor.fetchone()
conn.close() conn.close()
@@ -638,10 +628,7 @@ class SecurityManager:
conn = sqlite3.connect(self.db_path) conn = sqlite3.connect(self.db_path)
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute( cursor.execute("SELECT * FROM encryption_configs WHERE project_id = ?", (project_id,))
"SELECT * FROM encryption_configs WHERE project_id = ?",
(project_id,)
)
row = cursor.fetchone() row = cursor.fetchone()
conn.close() conn.close()
@@ -657,15 +644,10 @@ class SecurityManager:
master_key_hash=row[5], master_key_hash=row[5],
salt=row[6], salt=row[6],
created_at=row[7], created_at=row[7],
updated_at=row[8] updated_at=row[8],
) )
def encrypt_data( def encrypt_data(self, data: str, password: str, salt: str | None = None) -> tuple[str, str]:
self,
data: str,
password: str,
salt: str | None = None
) -> tuple[str, str]:
"""加密数据""" """加密数据"""
if not CRYPTO_AVAILABLE: if not CRYPTO_AVAILABLE:
raise RuntimeError("cryptography library not available") raise RuntimeError("cryptography library not available")
@@ -679,12 +661,7 @@ class SecurityManager:
return base64.b64encode(encrypted).decode(), salt return base64.b64encode(encrypted).decode(), salt
def decrypt_data( def decrypt_data(self, encrypted_data: str, password: str, salt: str) -> str:
self,
encrypted_data: str,
password: str,
salt: str
) -> str:
"""解密数据""" """解密数据"""
if not CRYPTO_AVAILABLE: if not CRYPTO_AVAILABLE:
raise RuntimeError("cryptography library not available") raise RuntimeError("cryptography library not available")
@@ -705,7 +682,7 @@ class SecurityManager:
pattern: str | None = None, pattern: str | None = None,
replacement: str | None = None, replacement: str | None = None,
description: str | None = None, description: str | None = None,
priority: int = 0 priority: int = 0,
) -> MaskingRule: ) -> MaskingRule:
"""创建脱敏规则""" """创建脱敏规则"""
# 使用预定义规则或自定义规则 # 使用预定义规则或自定义规则
@@ -722,22 +699,33 @@ class SecurityManager:
pattern=pattern or "", pattern=pattern or "",
replacement=replacement or "****", replacement=replacement or "****",
description=description, description=description,
priority=priority priority=priority,
) )
conn = sqlite3.connect(self.db_path) conn = sqlite3.connect(self.db_path)
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute(""" cursor.execute(
"""
INSERT INTO masking_rules INSERT INTO masking_rules
(id, project_id, name, rule_type, pattern, replacement, (id, project_id, name, rule_type, pattern, replacement,
is_active, priority, description, created_at, updated_at) is_active, priority, description, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""", ( """,
rule.id, rule.project_id, rule.name, rule.rule_type, (
rule.pattern, rule.replacement, int(rule.is_active), rule.id,
rule.priority, rule.description, rule.created_at, rule.updated_at rule.project_id,
)) rule.name,
rule.rule_type,
rule.pattern,
rule.replacement,
int(rule.is_active),
rule.priority,
rule.description,
rule.created_at,
rule.updated_at,
),
)
conn.commit() conn.commit()
conn.close() conn.close()
@@ -747,16 +735,12 @@ class SecurityManager:
action_type=AuditActionType.DATA_MASKING, action_type=AuditActionType.DATA_MASKING,
resource_type="project", resource_type="project",
resource_id=project_id, resource_id=project_id,
action_details={"action": "create_rule", "rule_name": name} action_details={"action": "create_rule", "rule_name": name},
) )
return rule return rule
def get_masking_rules( def get_masking_rules(self, project_id: str, active_only: bool = True) -> list[MaskingRule]:
self,
project_id: str,
active_only: bool = True
) -> list[MaskingRule]:
"""获取脱敏规则""" """获取脱敏规则"""
conn = sqlite3.connect(self.db_path) conn = sqlite3.connect(self.db_path)
cursor = conn.cursor() cursor = conn.cursor()
@@ -775,27 +759,25 @@ class SecurityManager:
rules = [] rules = []
for row in rows: for row in rows:
rules.append(MaskingRule( rules.append(
id=row[0], MaskingRule(
project_id=row[1], id=row[0],
name=row[2], project_id=row[1],
rule_type=row[3], name=row[2],
pattern=row[4], rule_type=row[3],
replacement=row[5], pattern=row[4],
is_active=bool(row[6]), replacement=row[5],
priority=row[7], is_active=bool(row[6]),
description=row[8], priority=row[7],
created_at=row[9], description=row[8],
updated_at=row[10] created_at=row[9],
)) updated_at=row[10],
)
)
return rules return rules
def update_masking_rule( def update_masking_rule(self, rule_id: str, **kwargs) -> MaskingRule | None:
self,
rule_id: str,
**kwargs
) -> MaskingRule | None:
"""更新脱敏规则""" """更新脱敏规则"""
allowed_fields = ["name", "pattern", "replacement", "is_active", "priority", "description"] allowed_fields = ["name", "pattern", "replacement", "is_active", "priority", "description"]
@@ -818,11 +800,14 @@ class SecurityManager:
params.append(datetime.now().isoformat()) params.append(datetime.now().isoformat())
params.append(rule_id) params.append(rule_id)
cursor.execute(f""" cursor.execute(
f"""
UPDATE masking_rules UPDATE masking_rules
SET {', '.join(set_clauses)} SET {', '.join(set_clauses)}
WHERE id = ? WHERE id = ?
""", params) """,
params,
)
conn.commit() conn.commit()
conn.close() conn.close()
@@ -848,7 +833,7 @@ class SecurityManager:
priority=row[7], priority=row[7],
description=row[8], description=row[8],
created_at=row[9], created_at=row[9],
updated_at=row[10] updated_at=row[10],
) )
def delete_masking_rule(self, rule_id: str) -> bool: def delete_masking_rule(self, rule_id: str) -> bool:
@@ -864,12 +849,7 @@ class SecurityManager:
return success return success
def apply_masking( def apply_masking(self, text: str, project_id: str, rule_types: list[MaskingRuleType] | None = None) -> str:
self,
text: str,
project_id: str,
rule_types: list[MaskingRuleType] | None = None
) -> str:
"""应用脱敏规则到文本""" """应用脱敏规则到文本"""
rules = self.get_masking_rules(project_id) rules = self.get_masking_rules(project_id)
@@ -884,22 +864,14 @@ class SecurityManager:
continue continue
try: try:
masked_text = re.sub( masked_text = re.sub(rule.pattern, rule.replacement, masked_text)
rule.pattern,
rule.replacement,
masked_text
)
except re.error: except re.error:
# 忽略无效的正则表达式 # 忽略无效的正则表达式
continue continue
return masked_text return masked_text
def apply_masking_to_entity( def apply_masking_to_entity(self, entity_data: dict[str, Any], project_id: str) -> dict[str, Any]:
self,
entity_data: dict[str, Any],
project_id: str
) -> dict[str, Any]:
"""对实体数据应用脱敏""" """对实体数据应用脱敏"""
masked_data = entity_data.copy() masked_data = entity_data.copy()
@@ -924,7 +896,7 @@ class SecurityManager:
allowed_ips: list[str] | None = None, allowed_ips: list[str] | None = None,
time_restrictions: dict | None = None, time_restrictions: dict | None = None,
max_access_count: int | None = None, max_access_count: int | None = None,
require_approval: bool = False require_approval: bool = False,
) -> DataAccessPolicy: ) -> DataAccessPolicy:
"""创建数据访问策略""" """创建数据访问策略"""
policy = DataAccessPolicy( policy = DataAccessPolicy(
@@ -937,36 +909,43 @@ class SecurityManager:
allowed_ips=json.dumps(allowed_ips) if allowed_ips else None, allowed_ips=json.dumps(allowed_ips) if allowed_ips else None,
time_restrictions=json.dumps(time_restrictions) if time_restrictions else None, time_restrictions=json.dumps(time_restrictions) if time_restrictions else None,
max_access_count=max_access_count, max_access_count=max_access_count,
require_approval=require_approval require_approval=require_approval,
) )
conn = sqlite3.connect(self.db_path) conn = sqlite3.connect(self.db_path)
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute(""" cursor.execute(
"""
INSERT INTO data_access_policies INSERT INTO data_access_policies
(id, project_id, name, description, allowed_users, allowed_roles, (id, project_id, name, description, allowed_users, allowed_roles,
allowed_ips, time_restrictions, max_access_count, require_approval, allowed_ips, time_restrictions, max_access_count, require_approval,
is_active, created_at, updated_at) is_active, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""", ( """,
policy.id, policy.project_id, policy.name, policy.description, (
policy.allowed_users, policy.allowed_roles, policy.allowed_ips, policy.id,
policy.time_restrictions, policy.max_access_count, policy.project_id,
int(policy.require_approval), int(policy.is_active), policy.name,
policy.created_at, policy.updated_at policy.description,
)) policy.allowed_users,
policy.allowed_roles,
policy.allowed_ips,
policy.time_restrictions,
policy.max_access_count,
int(policy.require_approval),
int(policy.is_active),
policy.created_at,
policy.updated_at,
),
)
conn.commit() conn.commit()
conn.close() conn.close()
return policy return policy
def get_access_policies( def get_access_policies(self, project_id: str, active_only: bool = True) -> list[DataAccessPolicy]:
self,
project_id: str,
active_only: bool = True
) -> list[DataAccessPolicy]:
"""获取数据访问策略""" """获取数据访问策略"""
conn = sqlite3.connect(self.db_path) conn = sqlite3.connect(self.db_path)
cursor = conn.cursor() cursor = conn.cursor()
@@ -983,38 +962,34 @@ class SecurityManager:
policies = [] policies = []
for row in rows: for row in rows:
policies.append(DataAccessPolicy( policies.append(
id=row[0], DataAccessPolicy(
project_id=row[1], id=row[0],
name=row[2], project_id=row[1],
description=row[3], name=row[2],
allowed_users=row[4], description=row[3],
allowed_roles=row[5], allowed_users=row[4],
allowed_ips=row[6], allowed_roles=row[5],
time_restrictions=row[7], allowed_ips=row[6],
max_access_count=row[8], time_restrictions=row[7],
require_approval=bool(row[9]), max_access_count=row[8],
is_active=bool(row[10]), require_approval=bool(row[9]),
created_at=row[11], is_active=bool(row[10]),
updated_at=row[12] created_at=row[11],
)) updated_at=row[12],
)
)
return policies return policies
def check_access_permission( def check_access_permission(
self, self, policy_id: str, user_id: str, user_ip: str | None = None
policy_id: str,
user_id: str,
user_ip: str | None = None
) -> tuple[bool, str | None]: ) -> tuple[bool, str | None]:
"""检查访问权限""" """检查访问权限"""
conn = sqlite3.connect(self.db_path) conn = sqlite3.connect(self.db_path)
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute( cursor.execute("SELECT * FROM data_access_policies WHERE id = ? AND is_active = 1", (policy_id,))
"SELECT * FROM data_access_policies WHERE id = ? AND is_active = 1",
(policy_id,)
)
row = cursor.fetchone() row = cursor.fetchone()
conn.close() conn.close()
@@ -1034,7 +1009,7 @@ class SecurityManager:
require_approval=bool(row[9]), require_approval=bool(row[9]),
is_active=bool(row[10]), is_active=bool(row[10]),
created_at=row[11], created_at=row[11],
updated_at=row[12] updated_at=row[12],
) )
# 检查用户白名单 # 检查用户白名单
@@ -1074,11 +1049,14 @@ class SecurityManager:
conn = sqlite3.connect(self.db_path) conn = sqlite3.connect(self.db_path)
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute(""" cursor.execute(
"""
SELECT * FROM access_requests SELECT * FROM access_requests
WHERE policy_id = ? AND user_id = ? AND status = 'approved' WHERE policy_id = ? AND user_id = ? AND status = 'approved'
AND (expires_at IS NULL OR expires_at > ?) AND (expires_at IS NULL OR expires_at > ?)
""", (policy_id, user_id, datetime.now().isoformat())) """,
(policy_id, user_id, datetime.now().isoformat()),
)
request = cursor.fetchone() request = cursor.fetchone()
conn.close() conn.close()
@@ -1104,11 +1082,7 @@ class SecurityManager:
return ip == pattern return ip == pattern
def create_access_request( def create_access_request(
self, self, policy_id: str, user_id: str, request_reason: str | None = None, expires_hours: int = 24
policy_id: str,
user_id: str,
request_reason: str | None = None,
expires_hours: int = 24
) -> AccessRequest: ) -> AccessRequest:
"""创建访问请求""" """创建访问请求"""
request = AccessRequest( request = AccessRequest(
@@ -1116,21 +1090,28 @@ class SecurityManager:
policy_id=policy_id, policy_id=policy_id,
user_id=user_id, user_id=user_id,
request_reason=request_reason, request_reason=request_reason,
expires_at=(datetime.now() + timedelta(hours=expires_hours)).isoformat() expires_at=(datetime.now() + timedelta(hours=expires_hours)).isoformat(),
) )
conn = sqlite3.connect(self.db_path) conn = sqlite3.connect(self.db_path)
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute(""" cursor.execute(
"""
INSERT INTO access_requests INSERT INTO access_requests
(id, policy_id, user_id, request_reason, status, expires_at, created_at) (id, policy_id, user_id, request_reason, status, expires_at, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?, ?)
""", ( """,
request.id, request.policy_id, request.user_id, (
request.request_reason, request.status, request.expires_at, request.id,
request.created_at request.policy_id,
)) request.user_id,
request.request_reason,
request.status,
request.expires_at,
request.created_at,
),
)
conn.commit() conn.commit()
conn.close() conn.close()
@@ -1138,10 +1119,7 @@ class SecurityManager:
return request return request
def approve_access_request( def approve_access_request(
self, self, request_id: str, approved_by: str, expires_hours: int = 24
request_id: str,
approved_by: str,
expires_hours: int = 24
) -> AccessRequest | None: ) -> AccessRequest | None:
"""批准访问请求""" """批准访问请求"""
conn = sqlite3.connect(self.db_path) conn = sqlite3.connect(self.db_path)
@@ -1150,11 +1128,14 @@ class SecurityManager:
expires_at = (datetime.now() + timedelta(hours=expires_hours)).isoformat() expires_at = (datetime.now() + timedelta(hours=expires_hours)).isoformat()
approved_at = datetime.now().isoformat() approved_at = datetime.now().isoformat()
cursor.execute(""" cursor.execute(
"""
UPDATE access_requests UPDATE access_requests
SET status = 'approved', approved_by = ?, approved_at = ?, expires_at = ? SET status = 'approved', approved_by = ?, approved_at = ?, expires_at = ?
WHERE id = ? WHERE id = ?
""", (approved_by, approved_at, expires_at, request_id)) """,
(approved_by, approved_at, expires_at, request_id),
)
conn.commit() conn.commit()
@@ -1175,23 +1156,22 @@ class SecurityManager:
approved_by=row[5], approved_by=row[5],
approved_at=row[6], approved_at=row[6],
expires_at=row[7], expires_at=row[7],
created_at=row[8] created_at=row[8],
) )
def reject_access_request( def reject_access_request(self, request_id: str, rejected_by: str) -> AccessRequest | None:
self,
request_id: str,
rejected_by: str
) -> AccessRequest | None:
"""拒绝访问请求""" """拒绝访问请求"""
conn = sqlite3.connect(self.db_path) conn = sqlite3.connect(self.db_path)
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute(""" cursor.execute(
"""
UPDATE access_requests UPDATE access_requests
SET status = 'rejected', approved_by = ? SET status = 'rejected', approved_by = ?
WHERE id = ? WHERE id = ?
""", (rejected_by, request_id)) """,
(rejected_by, request_id),
)
conn.commit() conn.commit()
@@ -1211,7 +1191,7 @@ class SecurityManager:
approved_by=row[5], approved_by=row[5],
approved_at=row[6], approved_at=row[6],
expires_at=row[7], expires_at=row[7],
created_at=row[8] created_at=row[8],
) )

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -20,7 +20,7 @@ class TingwuClient:
def _sign_request(self, method: str, uri: str, query: str = "", body: str = "") -> dict[str, str]: def _sign_request(self, method: str, uri: str, query: str = "", body: str = "") -> dict[str, str]:
"""阿里云签名 V3""" """阿里云签名 V3"""
timestamp = datetime.utcnow().strftime('%Y-%m-%dT%H:%M:%SZ') timestamp = datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%SZ")
# 简化签名,实际生产需要完整实现 # 简化签名,实际生产需要完整实现
# 这里使用基础认证头 # 这里使用基础认证头
@@ -39,25 +39,16 @@ class TingwuClient:
from alibabacloud_tingwu20230930 import models as tingwu_models from alibabacloud_tingwu20230930 import models as tingwu_models
from alibabacloud_tingwu20230930.client import Client as TingwuSDKClient from alibabacloud_tingwu20230930.client import Client as TingwuSDKClient
config = open_api_models.Config( config = open_api_models.Config(access_key_id=self.access_key, access_key_secret=self.secret_key)
access_key_id=self.access_key,
access_key_secret=self.secret_key
)
config.endpoint = "tingwu.cn-beijing.aliyuncs.com" config.endpoint = "tingwu.cn-beijing.aliyuncs.com"
client = TingwuSDKClient(config) client = TingwuSDKClient(config)
request = tingwu_models.CreateTaskRequest( request = tingwu_models.CreateTaskRequest(
type="offline", type="offline",
input=tingwu_models.Input( input=tingwu_models.Input(source="OSS", file_url=audio_url),
source="OSS",
file_url=audio_url
),
parameters=tingwu_models.Parameters( parameters=tingwu_models.Parameters(
transcription=tingwu_models.Transcription( transcription=tingwu_models.Transcription(diarization_enabled=True, sentence_max_length=20)
diarization_enabled=True, ),
sentence_max_length=20
)
)
) )
response = client.create_task(request) response = client.create_task(request)
@@ -81,10 +72,7 @@ class TingwuClient:
from alibabacloud_tingwu20230930 import models as tingwu_models from alibabacloud_tingwu20230930 import models as tingwu_models
from alibabacloud_tingwu20230930.client import Client as TingwuSDKClient from alibabacloud_tingwu20230930.client import Client as TingwuSDKClient
config = open_api_models.Config( config = open_api_models.Config(access_key_id=self.access_key, access_key_secret=self.secret_key)
access_key_id=self.access_key,
access_key_secret=self.secret_key
)
config.endpoint = "tingwu.cn-beijing.aliyuncs.com" config.endpoint = "tingwu.cn-beijing.aliyuncs.com"
client = TingwuSDKClient(config) client = TingwuSDKClient(config)
@@ -128,25 +116,29 @@ class TingwuClient:
if transcription.sentences: if transcription.sentences:
for sent in transcription.sentences: for sent in transcription.sentences:
segments.append({ segments.append(
"start": sent.begin_time / 1000, {
"end": sent.end_time / 1000, "start": sent.begin_time / 1000,
"text": sent.text, "end": sent.end_time / 1000,
"speaker": f"Speaker {sent.speaker_id}" "text": sent.text,
}) "speaker": f"Speaker {sent.speaker_id}",
}
)
return { return {"full_text": full_text.strip(), "segments": segments}
"full_text": full_text.strip(),
"segments": segments
}
def _mock_result(self) -> dict[str, Any]: def _mock_result(self) -> dict[str, Any]:
"""Mock 结果""" """Mock 结果"""
return { return {
"full_text": "这是一个示例转录文本,包含 Project Alpha 和 K8s 等术语。", "full_text": "这是一个示例转录文本,包含 Project Alpha 和 K8s 等术语。",
"segments": [ "segments": [
{"start": 0.0, "end": 5.0, "text": "这是一个示例转录文本,包含 Project Alpha 和 K8s 等术语。", "speaker": "Speaker A"} {
] "start": 0.0,
"end": 5.0,
"text": "这是一个示例转录文本,包含 Project Alpha 和 K8s 等术语。",
"speaker": "Speaker A",
}
],
} }
def transcribe(self, audio_url: str, language: str = "zh") -> dict[str, Any]: def transcribe(self, audio_url: str, language: str = "zh") -> dict[str, Any]:

File diff suppressed because it is too large Load Diff