fix: auto-fix code issues (cron)
- 修复重复导入/字段 - 修复异常处理 - 修复PEP8格式问题 - 添加类型注解
This commit is contained in:
@@ -209,7 +209,7 @@ class AIManager:
|
||||
self.openai_api_key = os.getenv("OPENAI_API_KEY", "")
|
||||
self.anthropic_api_key = os.getenv("ANTHROPIC_API_KEY", "")
|
||||
|
||||
def _get_db(self):
|
||||
def _get_db(self) -> None:
|
||||
"""获取数据库连接"""
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
conn.row_factory = sqlite3.Row
|
||||
@@ -921,7 +921,6 @@ class AIManager:
|
||||
|
||||
# 解析关键要点
|
||||
key_points = []
|
||||
import re
|
||||
|
||||
# 尝试从 JSON 中提取
|
||||
json_match = re.search(r"\{.*?\}", content, re.DOTALL)
|
||||
@@ -1312,7 +1311,7 @@ class AIManager:
|
||||
|
||||
return [self._row_to_prediction_result(row) for row in rows]
|
||||
|
||||
def update_prediction_feedback(self, prediction_id: str, actual_value: str, is_correct: bool):
|
||||
def update_prediction_feedback(self, prediction_id: str, actual_value: str, is_correct: bool) -> None:
|
||||
"""更新预测反馈(用于模型改进)"""
|
||||
with self._get_db() as conn:
|
||||
conn.execute(
|
||||
|
||||
@@ -51,7 +51,7 @@ class ApiKeyManager:
|
||||
self.db_path = db_path
|
||||
self._init_db()
|
||||
|
||||
def _init_db(self):
|
||||
def _init_db(self) -> None:
|
||||
"""初始化数据库表"""
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
conn.executescript("""
|
||||
@@ -331,7 +331,7 @@ class ApiKeyManager:
|
||||
conn.commit()
|
||||
return cursor.rowcount > 0
|
||||
|
||||
def update_last_used(self, key_id: str):
|
||||
def update_last_used(self, key_id: str) -> None:
|
||||
"""更新最后使用时间"""
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
conn.execute(
|
||||
|
||||
@@ -195,7 +195,7 @@ class CollaborationManager:
|
||||
data = f"{project_id}:{datetime.now().timestamp()}:{uuid.uuid4()}"
|
||||
return hashlib.sha256(data.encode()).hexdigest()[:32]
|
||||
|
||||
def _save_share_to_db(self, share: ProjectShare):
|
||||
def _save_share_to_db(self, share: ProjectShare) -> None:
|
||||
"""保存分享记录到数据库"""
|
||||
cursor = self.db.conn.cursor()
|
||||
cursor.execute(
|
||||
@@ -286,7 +286,7 @@ class CollaborationManager:
|
||||
allow_export=bool(row[12]),
|
||||
)
|
||||
|
||||
def increment_share_usage(self, token: str):
|
||||
def increment_share_usage(self, token: str) -> None:
|
||||
"""增加分享链接使用次数"""
|
||||
share = self._shares_cache.get(token)
|
||||
if share:
|
||||
@@ -403,7 +403,7 @@ class CollaborationManager:
|
||||
|
||||
return comment
|
||||
|
||||
def _save_comment_to_db(self, comment: Comment):
|
||||
def _save_comment_to_db(self, comment: Comment) -> None:
|
||||
"""保存评论到数据库"""
|
||||
cursor = self.db.conn.cursor()
|
||||
cursor.execute(
|
||||
@@ -616,7 +616,7 @@ class CollaborationManager:
|
||||
|
||||
return record
|
||||
|
||||
def _save_change_to_db(self, record: ChangeRecord):
|
||||
def _save_change_to_db(self, record: ChangeRecord) -> None:
|
||||
"""保存变更记录到数据库"""
|
||||
cursor = self.db.conn.cursor()
|
||||
cursor.execute(
|
||||
@@ -862,7 +862,7 @@ class CollaborationManager:
|
||||
}
|
||||
return permissions_map.get(role, ["read"])
|
||||
|
||||
def _save_member_to_db(self, member: TeamMember):
|
||||
def _save_member_to_db(self, member: TeamMember) -> None:
|
||||
"""保存成员到数据库"""
|
||||
cursor = self.db.conn.cursor()
|
||||
cursor.execute(
|
||||
@@ -970,7 +970,7 @@ class CollaborationManager:
|
||||
permissions = json.loads(row[0]) if row[0] else []
|
||||
return permission in permissions or "admin" in permissions
|
||||
|
||||
def update_last_active(self, project_id: str, user_id: str):
|
||||
def update_last_active(self, project_id: str, user_id: str) -> None:
|
||||
"""更新用户最后活跃时间"""
|
||||
if not self.db:
|
||||
return
|
||||
@@ -992,7 +992,7 @@ class CollaborationManager:
|
||||
_collaboration_manager = None
|
||||
|
||||
|
||||
def get_collaboration_manager(db_manager=None):
|
||||
def get_collaboration_manager(db_manager=None) -> None:
|
||||
"""获取协作管理器单例"""
|
||||
global _collaboration_manager
|
||||
if _collaboration_manager is None:
|
||||
|
||||
@@ -123,7 +123,7 @@ class DatabaseManager:
|
||||
conn.row_factory = sqlite3.Row
|
||||
return conn
|
||||
|
||||
def init_db(self):
|
||||
def init_db(self) -> None:
|
||||
"""初始化数据库表"""
|
||||
with open(os.path.join(os.path.dirname(__file__), "schema.sql")) as f:
|
||||
schema = f.read()
|
||||
@@ -299,7 +299,7 @@ class DatabaseManager:
|
||||
conn.close()
|
||||
return self.get_entity(entity_id)
|
||||
|
||||
def delete_entity(self, entity_id: str):
|
||||
def delete_entity(self, entity_id: str) -> None:
|
||||
"""删除实体及其关联数据"""
|
||||
conn = self.get_conn()
|
||||
conn.execute("DELETE FROM entity_mentions WHERE entity_id = ?", (entity_id,))
|
||||
|
||||
@@ -352,7 +352,7 @@ class DeveloperEcosystemManager:
|
||||
self.db_path = db_path
|
||||
self.platform_fee_rate = 0.30 # 平台抽成比例 30%
|
||||
|
||||
def _get_db(self):
|
||||
def _get_db(self) -> None:
|
||||
"""获取数据库连接"""
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
conn.row_factory = sqlite3.Row
|
||||
@@ -515,7 +515,7 @@ class DeveloperEcosystemManager:
|
||||
|
||||
return self.get_sdk_release(sdk_id)
|
||||
|
||||
def increment_sdk_download(self, sdk_id: str):
|
||||
def increment_sdk_download(self, sdk_id: str) -> None:
|
||||
"""增加 SDK 下载计数"""
|
||||
with self._get_db() as conn:
|
||||
conn.execute(
|
||||
@@ -785,7 +785,7 @@ class DeveloperEcosystemManager:
|
||||
|
||||
return self.get_template(template_id)
|
||||
|
||||
def increment_template_install(self, template_id: str):
|
||||
def increment_template_install(self, template_id: str) -> None:
|
||||
"""增加模板安装计数"""
|
||||
with self._get_db() as conn:
|
||||
conn.execute(
|
||||
@@ -852,7 +852,7 @@ class DeveloperEcosystemManager:
|
||||
|
||||
return review
|
||||
|
||||
def _update_template_rating(self, conn, template_id: str):
|
||||
def _update_template_rating(self, conn, template_id: str) -> None:
|
||||
"""更新模板评分"""
|
||||
row = conn.execute(
|
||||
"""
|
||||
@@ -1084,7 +1084,7 @@ class DeveloperEcosystemManager:
|
||||
|
||||
return self.get_plugin(plugin_id)
|
||||
|
||||
def increment_plugin_install(self, plugin_id: str, active: bool = True):
|
||||
def increment_plugin_install(self, plugin_id: str, active: bool = True) -> None:
|
||||
"""增加插件安装计数"""
|
||||
with self._get_db() as conn:
|
||||
conn.execute(
|
||||
@@ -1160,7 +1160,7 @@ class DeveloperEcosystemManager:
|
||||
|
||||
return review
|
||||
|
||||
def _update_plugin_rating(self, conn, plugin_id: str):
|
||||
def _update_plugin_rating(self, conn, plugin_id: str) -> None:
|
||||
"""更新插件评分"""
|
||||
row = conn.execute(
|
||||
"""
|
||||
@@ -1421,7 +1421,7 @@ class DeveloperEcosystemManager:
|
||||
|
||||
return self.get_developer_profile(developer_id)
|
||||
|
||||
def update_developer_stats(self, developer_id: str):
|
||||
def update_developer_stats(self, developer_id: str) -> None:
|
||||
"""更新开发者统计信息"""
|
||||
with self._get_db() as conn:
|
||||
# 统计插件数量
|
||||
@@ -1574,7 +1574,7 @@ class DeveloperEcosystemManager:
|
||||
rows = conn.execute(query, params).fetchall()
|
||||
return [self._row_to_code_example(row) for row in rows]
|
||||
|
||||
def increment_example_view(self, example_id: str):
|
||||
def increment_example_view(self, example_id: str) -> None:
|
||||
"""增加代码示例查看计数"""
|
||||
with self._get_db() as conn:
|
||||
conn.execute(
|
||||
@@ -1587,7 +1587,7 @@ class DeveloperEcosystemManager:
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
def increment_example_copy(self, example_id: str):
|
||||
def increment_example_copy(self, example_id: str) -> None:
|
||||
"""增加代码示例复制计数"""
|
||||
with self._get_db() as conn:
|
||||
conn.execute(
|
||||
|
||||
@@ -329,7 +329,7 @@ class EnterpriseManager:
|
||||
conn.row_factory = sqlite3.Row
|
||||
return conn
|
||||
|
||||
def _init_db(self):
|
||||
def _init_db(self) -> None:
|
||||
"""初始化数据库表"""
|
||||
conn = self._get_connection()
|
||||
try:
|
||||
@@ -1190,7 +1190,7 @@ class EnterpriseManager:
|
||||
# GET {scim_base_url}/Users
|
||||
return []
|
||||
|
||||
def _upsert_scim_user(self, conn: sqlite3.Connection, tenant_id: str, user_data: dict[str, Any]):
|
||||
def _upsert_scim_user(self, conn: sqlite3.Connection, tenant_id: str, user_data: dict[str, Any]) -> None:
|
||||
"""插入或更新 SCIM 用户"""
|
||||
cursor = conn.cursor()
|
||||
|
||||
|
||||
@@ -22,14 +22,7 @@ try:
|
||||
from reportlab.lib.pagesizes import A4
|
||||
from reportlab.lib.styles import ParagraphStyle, getSampleStyleSheet
|
||||
from reportlab.lib.units import inch
|
||||
from reportlab.platypus import (
|
||||
PageBreak,
|
||||
Paragraph,
|
||||
SimpleDocTemplate,
|
||||
Spacer,
|
||||
Table,
|
||||
TableStyle,
|
||||
)
|
||||
from reportlab.platypus import PageBreak, Paragraph, SimpleDocTemplate, Spacer, Table, TableStyle
|
||||
|
||||
REPORTLAB_AVAILABLE = True
|
||||
except ImportError:
|
||||
@@ -194,20 +187,16 @@ class ExportManager:
|
||||
f'fill="white" stroke="#bdc3c7" rx="5"/>'
|
||||
)
|
||||
svg_parts.append(
|
||||
f'<text x="{legend_x}" y="{legend_y}" font-size="12" font-weight="bold" '
|
||||
f'fill="#2c3e50">实体类型</text>'
|
||||
f'<text x="{legend_x}" y="{legend_y}" font-size="12" font-weight="bold" ' f'fill="#2c3e50">实体类型</text>'
|
||||
)
|
||||
|
||||
for i, (etype, color) in enumerate(type_colors.items()):
|
||||
if etype != "default":
|
||||
y_pos = legend_y + 25 + i * 20
|
||||
svg_parts.append(
|
||||
f'<circle cx="{legend_x + 10}" cy="{y_pos}" r="8" fill="{color}"/>'
|
||||
)
|
||||
svg_parts.append(f'<circle cx="{legend_x + 10}" cy="{y_pos}" r="8" fill="{color}"/>')
|
||||
text_y = y_pos + 4
|
||||
svg_parts.append(
|
||||
f'<text x="{legend_x + 25}" y="{text_y}" font-size="10" '
|
||||
f'fill="#2c3e50">{etype}</text>'
|
||||
f'<text x="{legend_x + 25}" y="{text_y}" font-size="10" ' f'fill="#2c3e50">{etype}</text>'
|
||||
)
|
||||
|
||||
svg_parts.append("</svg>")
|
||||
@@ -320,7 +309,6 @@ class ExportManager:
|
||||
Returns:
|
||||
CSV 字符串
|
||||
"""
|
||||
import csv
|
||||
|
||||
output = io.StringIO()
|
||||
writer = csv.writer(output)
|
||||
@@ -586,7 +574,7 @@ class ExportManager:
|
||||
_export_manager = None
|
||||
|
||||
|
||||
def get_export_manager(db_manager=None):
|
||||
def get_export_manager(db_manager=None) -> None:
|
||||
"""获取导出管理器实例"""
|
||||
global _export_manager
|
||||
if _export_manager is None:
|
||||
|
||||
@@ -369,7 +369,7 @@ class GrowthManager:
|
||||
self.segment_write_key = os.getenv("SEGMENT_WRITE_KEY", "")
|
||||
self.sendgrid_api_key = os.getenv("SENDGRID_API_KEY", "")
|
||||
|
||||
def _get_db(self):
|
||||
def _get_db(self) -> None:
|
||||
"""获取数据库连接"""
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
conn.row_factory = sqlite3.Row
|
||||
|
||||
@@ -93,7 +93,7 @@ class ImageProcessor:
|
||||
"other": "其他",
|
||||
}
|
||||
|
||||
def __init__(self, temp_dir: str = None):
|
||||
def __init__(self, temp_dir: str = None) -> None:
|
||||
"""
|
||||
初始化图片处理器
|
||||
|
||||
@@ -103,7 +103,7 @@ class ImageProcessor:
|
||||
self.temp_dir = temp_dir or os.path.join(os.getcwd(), "temp", "images")
|
||||
os.makedirs(self.temp_dir, exist_ok=True)
|
||||
|
||||
def preprocess_image(self, image, image_type: str = None):
|
||||
def preprocess_image(self, image, image_type: str = None) -> None:
|
||||
"""
|
||||
预处理图片以提高OCR质量
|
||||
|
||||
@@ -145,7 +145,7 @@ class ImageProcessor:
|
||||
print(f"Image preprocessing error: {e}")
|
||||
return image
|
||||
|
||||
def _enhance_whiteboard(self, image):
|
||||
def _enhance_whiteboard(self, image) -> None:
|
||||
"""增强白板图片"""
|
||||
# 转换为灰度
|
||||
gray = image.convert("L")
|
||||
@@ -160,7 +160,7 @@ class ImageProcessor:
|
||||
|
||||
return binary.convert("L")
|
||||
|
||||
def _enhance_handwritten(self, image):
|
||||
def _enhance_handwritten(self, image) -> None:
|
||||
"""增强手写笔记图片"""
|
||||
# 转换为灰度
|
||||
gray = image.convert("L")
|
||||
|
||||
@@ -163,8 +163,6 @@ class KnowledgeReasoner:
|
||||
|
||||
content = await self._call_llm(prompt, temperature=0.3)
|
||||
|
||||
import re
|
||||
|
||||
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
|
||||
|
||||
if json_match:
|
||||
@@ -217,8 +215,6 @@ class KnowledgeReasoner:
|
||||
|
||||
content = await self._call_llm(prompt, temperature=0.3)
|
||||
|
||||
import re
|
||||
|
||||
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
|
||||
|
||||
if json_match:
|
||||
@@ -271,8 +267,6 @@ class KnowledgeReasoner:
|
||||
|
||||
content = await self._call_llm(prompt, temperature=0.3)
|
||||
|
||||
import re
|
||||
|
||||
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
|
||||
|
||||
if json_match:
|
||||
@@ -325,8 +319,6 @@ class KnowledgeReasoner:
|
||||
|
||||
content = await self._call_llm(prompt, temperature=0.4)
|
||||
|
||||
import re
|
||||
|
||||
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
|
||||
|
||||
if json_match:
|
||||
@@ -474,8 +466,6 @@ class KnowledgeReasoner:
|
||||
|
||||
content = await self._call_llm(prompt, temperature=0.3)
|
||||
|
||||
import re
|
||||
|
||||
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
|
||||
|
||||
if json_match:
|
||||
|
||||
@@ -204,15 +204,13 @@ class LLMClient:
|
||||
messages = [ChatMessage(role="user", content=prompt)]
|
||||
content = await self.chat(messages, temperature=0.1)
|
||||
|
||||
import re
|
||||
|
||||
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
|
||||
if not json_match:
|
||||
return {"intent": "unknown", "explanation": "无法解析指令"}
|
||||
|
||||
try:
|
||||
return json.loads(json_match.group())
|
||||
except BaseException:
|
||||
except (json.JSONDecodeError, KeyError, TypeError):
|
||||
return {"intent": "unknown", "explanation": "解析失败"}
|
||||
|
||||
async def analyze_entity_evolution(self, entity_name: str, mentions: list[dict]) -> str:
|
||||
|
||||
@@ -938,9 +938,7 @@ class LocalizationManager:
|
||||
finally:
|
||||
self._close_if_file_db(conn)
|
||||
|
||||
def get_translation(
|
||||
self, key: str, language: str, namespace: str = "common", fallback: bool = True
|
||||
) -> str | None:
|
||||
def get_translation(self, key: str, language: str, namespace: str = "common", fallback: bool = True) -> str | None:
|
||||
conn = self._get_connection()
|
||||
try:
|
||||
cursor = conn.cursor()
|
||||
|
||||
3401
backend/main.py
3401
backend/main.py
File diff suppressed because it is too large
Load Diff
@@ -80,7 +80,7 @@ class MultimodalEntityLinker:
|
||||
# 模态类型
|
||||
MODALITIES = ["audio", "video", "image", "document"]
|
||||
|
||||
def __init__(self, similarity_threshold: float = 0.85):
|
||||
def __init__(self, similarity_threshold: float = 0.85) -> None:
|
||||
"""
|
||||
初始化多模态实体关联器
|
||||
|
||||
|
||||
@@ -94,7 +94,7 @@ class VideoProcessingResult:
|
||||
class MultimodalProcessor:
|
||||
"""多模态处理器 - 处理视频文件"""
|
||||
|
||||
def __init__(self, temp_dir: str = None, frame_interval: int = 5):
|
||||
def __init__(self, temp_dir: str = None, frame_interval: int = 5) -> None:
|
||||
"""
|
||||
初始化多模态处理器
|
||||
|
||||
@@ -401,7 +401,7 @@ class MultimodalProcessor:
|
||||
error_message=str(e),
|
||||
)
|
||||
|
||||
def cleanup(self, video_id: str = None):
|
||||
def cleanup(self, video_id: str = None) -> None:
|
||||
"""
|
||||
清理临时文件
|
||||
|
||||
|
||||
@@ -107,7 +107,7 @@ class Neo4jManager:
|
||||
|
||||
self._connect()
|
||||
|
||||
def _connect(self):
|
||||
def _connect(self) -> None:
|
||||
"""建立 Neo4j 连接"""
|
||||
if not NEO4J_AVAILABLE:
|
||||
return
|
||||
@@ -121,7 +121,7 @@ class Neo4jManager:
|
||||
logger.error(f"Failed to connect to Neo4j: {e}")
|
||||
self._driver = None
|
||||
|
||||
def close(self):
|
||||
def close(self) -> None:
|
||||
"""关闭连接"""
|
||||
if self._driver:
|
||||
self._driver.close()
|
||||
@@ -137,7 +137,7 @@ class Neo4jManager:
|
||||
except BaseException:
|
||||
return False
|
||||
|
||||
def init_schema(self):
|
||||
def init_schema(self) -> None:
|
||||
"""初始化图数据库 Schema(约束和索引)"""
|
||||
if not self._driver:
|
||||
logger.error("Neo4j not connected")
|
||||
@@ -178,7 +178,7 @@ class Neo4jManager:
|
||||
|
||||
# ==================== 数据同步 ====================
|
||||
|
||||
def sync_project(self, project_id: str, project_name: str, project_description: str = ""):
|
||||
def sync_project(self, project_id: str, project_name: str, project_description: str = "") -> None:
|
||||
"""同步项目节点到 Neo4j"""
|
||||
if not self._driver:
|
||||
return
|
||||
@@ -196,7 +196,7 @@ class Neo4jManager:
|
||||
description=project_description,
|
||||
)
|
||||
|
||||
def sync_entity(self, entity: GraphEntity):
|
||||
def sync_entity(self, entity: GraphEntity) -> None:
|
||||
"""同步单个实体到 Neo4j"""
|
||||
if not self._driver:
|
||||
return
|
||||
@@ -225,7 +225,7 @@ class Neo4jManager:
|
||||
properties=json.dumps(entity.properties),
|
||||
)
|
||||
|
||||
def sync_entities_batch(self, entities: list[GraphEntity]):
|
||||
def sync_entities_batch(self, entities: list[GraphEntity]) -> None:
|
||||
"""批量同步实体到 Neo4j"""
|
||||
if not self._driver or not entities:
|
||||
return
|
||||
@@ -262,7 +262,7 @@ class Neo4jManager:
|
||||
entities=entities_data,
|
||||
)
|
||||
|
||||
def sync_relation(self, relation: GraphRelation):
|
||||
def sync_relation(self, relation: GraphRelation) -> None:
|
||||
"""同步单个关系到 Neo4j"""
|
||||
if not self._driver:
|
||||
return
|
||||
@@ -286,7 +286,7 @@ class Neo4jManager:
|
||||
properties=json.dumps(relation.properties),
|
||||
)
|
||||
|
||||
def sync_relations_batch(self, relations: list[GraphRelation]):
|
||||
def sync_relations_batch(self, relations: list[GraphRelation]) -> None:
|
||||
"""批量同步关系到 Neo4j"""
|
||||
if not self._driver or not relations:
|
||||
return
|
||||
@@ -318,7 +318,7 @@ class Neo4jManager:
|
||||
relations=relations_data,
|
||||
)
|
||||
|
||||
def delete_entity(self, entity_id: str):
|
||||
def delete_entity(self, entity_id: str) -> None:
|
||||
"""从 Neo4j 删除实体及其关系"""
|
||||
if not self._driver:
|
||||
return
|
||||
@@ -332,7 +332,7 @@ class Neo4jManager:
|
||||
id=entity_id,
|
||||
)
|
||||
|
||||
def delete_project(self, project_id: str):
|
||||
def delete_project(self, project_id: str) -> None:
|
||||
"""从 Neo4j 删除项目及其所有实体和关系"""
|
||||
if not self._driver:
|
||||
return
|
||||
@@ -949,7 +949,7 @@ def get_neo4j_manager() -> Neo4jManager:
|
||||
return _neo4j_manager
|
||||
|
||||
|
||||
def close_neo4j_manager():
|
||||
def close_neo4j_manager() -> None:
|
||||
"""关闭 Neo4j 连接"""
|
||||
global _neo4j_manager
|
||||
if _neo4j_manager:
|
||||
@@ -958,7 +958,7 @@ def close_neo4j_manager():
|
||||
|
||||
|
||||
# 便捷函数
|
||||
def sync_project_to_neo4j(project_id: str, project_name: str, entities: list[dict], relations: list[dict]):
|
||||
def sync_project_to_neo4j(project_id: str, project_name: str, entities: list[dict], relations: list[dict]) -> None:
|
||||
"""
|
||||
同步整个项目到 Neo4j
|
||||
|
||||
|
||||
@@ -456,13 +456,13 @@ class OpsManager:
|
||||
self._evaluator_thread = None
|
||||
self._register_default_evaluators()
|
||||
|
||||
def _get_db(self):
|
||||
def _get_db(self) -> None:
|
||||
"""获取数据库连接"""
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
conn.row_factory = sqlite3.Row
|
||||
return conn
|
||||
|
||||
def _register_default_evaluators(self):
|
||||
def _register_default_evaluators(self) -> None:
|
||||
"""注册默认的告警评估器"""
|
||||
self._alert_evaluators[AlertRuleType.THRESHOLD.value] = self._evaluate_threshold_rule
|
||||
self._alert_evaluators[AlertRuleType.ANOMALY.value] = self._evaluate_anomaly_rule
|
||||
@@ -1249,7 +1249,7 @@ class OpsManager:
|
||||
|
||||
return self.get_alert(alert_id)
|
||||
|
||||
def _increment_suppression_count(self, alert_id: str):
|
||||
def _increment_suppression_count(self, alert_id: str) -> None:
|
||||
"""增加告警抑制计数"""
|
||||
with self._get_db() as conn:
|
||||
conn.execute(
|
||||
@@ -1262,7 +1262,7 @@ class OpsManager:
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
def _update_alert_notification_status(self, alert_id: str, channel_id: str, success: bool):
|
||||
def _update_alert_notification_status(self, alert_id: str, channel_id: str, success: bool) -> None:
|
||||
"""更新告警通知状态"""
|
||||
with self._get_db() as conn:
|
||||
row = conn.execute("SELECT notification_sent FROM alerts WHERE id = ?", (alert_id,)).fetchone()
|
||||
@@ -1276,7 +1276,7 @@ class OpsManager:
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
def _update_channel_stats(self, channel_id: str, success: bool):
|
||||
def _update_channel_stats(self, channel_id: str, success: bool) -> None:
|
||||
"""更新渠道统计"""
|
||||
now = datetime.now().isoformat()
|
||||
|
||||
@@ -1769,9 +1769,7 @@ class OpsManager:
|
||||
return self._row_to_scaling_event(row)
|
||||
return None
|
||||
|
||||
def update_scaling_event_status(
|
||||
self, event_id: str, status: str, error_message: str = None
|
||||
) -> ScalingEvent | None:
|
||||
def update_scaling_event_status(self, event_id: str, status: str, error_message: str = None) -> ScalingEvent | None:
|
||||
"""更新扩缩容事件状态"""
|
||||
now = datetime.now().isoformat()
|
||||
|
||||
@@ -2339,7 +2337,7 @@ class OpsManager:
|
||||
|
||||
return record
|
||||
|
||||
def _complete_backup(self, record_id: str, size_bytes: int, checksum: str = None):
|
||||
def _complete_backup(self, record_id: str, size_bytes: int, checksum: str = None) -> None:
|
||||
"""完成备份"""
|
||||
now = datetime.now().isoformat()
|
||||
checksum = checksum or hashlib.sha256(str(time.time()).encode()).hexdigest()[:16]
|
||||
|
||||
@@ -37,7 +37,7 @@ class OSSUploader:
|
||||
url = self.bucket.sign_url("GET", object_name, 3600)
|
||||
return url, object_name
|
||||
|
||||
def delete_object(self, object_name: str):
|
||||
def delete_object(self, object_name: str) -> None:
|
||||
"""删除 OSS 对象"""
|
||||
self.bucket.delete_object(object_name)
|
||||
|
||||
|
||||
@@ -55,7 +55,7 @@ class CacheStats:
|
||||
expired: int = 0
|
||||
hit_rate: float = 0.0
|
||||
|
||||
def update_hit_rate(self):
|
||||
def update_hit_rate(self) -> None:
|
||||
"""更新命中率"""
|
||||
if self.total_requests > 0:
|
||||
self.hit_rate = round(self.hits / self.total_requests, 4)
|
||||
@@ -194,7 +194,7 @@ class CacheManager:
|
||||
# 初始化缓存统计表
|
||||
self._init_cache_tables()
|
||||
|
||||
def _init_cache_tables(self):
|
||||
def _init_cache_tables(self) -> None:
|
||||
"""初始化缓存统计表"""
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
|
||||
@@ -234,7 +234,7 @@ class CacheManager:
|
||||
except BaseException:
|
||||
return 1024 # 默认估算
|
||||
|
||||
def _evict_lru(self, required_space: int = 0):
|
||||
def _evict_lru(self, required_space: int = 0) -> None:
|
||||
"""LRU 淘汰策略"""
|
||||
with self.cache_lock:
|
||||
while self.current_memory_size + required_space > self.max_memory_size and self.memory_cache:
|
||||
@@ -444,7 +444,7 @@ class CacheManager:
|
||||
|
||||
return stats
|
||||
|
||||
def save_stats(self):
|
||||
def save_stats(self) -> None:
|
||||
"""保存缓存统计到数据库"""
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
|
||||
@@ -618,7 +618,7 @@ class DatabaseSharding:
|
||||
# 初始化分片
|
||||
self._init_shards()
|
||||
|
||||
def _init_shards(self):
|
||||
def _init_shards(self) -> None:
|
||||
"""初始化分片"""
|
||||
# 计算每个分片的 key 范围
|
||||
chars = "0123456789abcdef"
|
||||
@@ -645,7 +645,7 @@ class DatabaseSharding:
|
||||
if not os.path.exists(db_path):
|
||||
self._create_shard_db(db_path)
|
||||
|
||||
def _create_shard_db(self, db_path: str):
|
||||
def _create_shard_db(self, db_path: str) -> None:
|
||||
"""创建分片数据库"""
|
||||
conn = sqlite3.connect(db_path)
|
||||
|
||||
@@ -792,7 +792,7 @@ class DatabaseSharding:
|
||||
print(f"迁移失败: {e}")
|
||||
return False
|
||||
|
||||
def _update_shard_stats(self, shard_id: str):
|
||||
def _update_shard_stats(self, shard_id: str) -> None:
|
||||
"""更新分片统计"""
|
||||
shard_info = self.shard_map.get(shard_id)
|
||||
if not shard_info:
|
||||
@@ -923,7 +923,7 @@ class TaskQueue:
|
||||
except Exception as e:
|
||||
print(f"Celery 初始化失败,使用内存任务队列: {e}")
|
||||
|
||||
def _init_task_tables(self):
|
||||
def _init_task_tables(self) -> None:
|
||||
"""初始化任务队列表"""
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
|
||||
@@ -953,7 +953,7 @@ class TaskQueue:
|
||||
"""检查任务队列是否可用"""
|
||||
return self.use_celery or True # 内存模式也可用
|
||||
|
||||
def register_handler(self, task_type: str, handler: Callable):
|
||||
def register_handler(self, task_type: str, handler: Callable) -> None:
|
||||
"""注册任务处理器"""
|
||||
self.task_handlers[task_type] = handler
|
||||
|
||||
@@ -1014,7 +1014,7 @@ class TaskQueue:
|
||||
|
||||
return task_id
|
||||
|
||||
def _execute_task(self, task_id: str):
|
||||
def _execute_task(self, task_id: str) -> None:
|
||||
"""执行任务(内存模式)"""
|
||||
with self.task_lock:
|
||||
task = self.tasks.get(task_id)
|
||||
@@ -1055,7 +1055,7 @@ class TaskQueue:
|
||||
|
||||
self._update_task_status(task)
|
||||
|
||||
def _save_task(self, task: TaskInfo):
|
||||
def _save_task(self, task: TaskInfo) -> None:
|
||||
"""保存任务到数据库"""
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
|
||||
@@ -1084,7 +1084,7 @@ class TaskQueue:
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
def _update_task_status(self, task: TaskInfo):
|
||||
def _update_task_status(self, task: TaskInfo) -> None:
|
||||
"""更新任务状态"""
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
|
||||
@@ -1143,9 +1143,7 @@ class TaskQueue:
|
||||
with self.task_lock:
|
||||
return self.tasks.get(task_id)
|
||||
|
||||
def list_tasks(
|
||||
self, status: str | None = None, task_type: str | None = None, limit: int = 100
|
||||
) -> list[TaskInfo]:
|
||||
def list_tasks(self, status: str | None = None, task_type: str | None = None, limit: int = 100) -> list[TaskInfo]:
|
||||
"""列出任务"""
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
conn.row_factory = sqlite3.Row
|
||||
@@ -1333,7 +1331,7 @@ class PerformanceMonitor:
|
||||
if metric_type == "db_query" and duration_ms > self.slow_query_threshold:
|
||||
self._record_slow_query(metric)
|
||||
|
||||
def _flush_metrics(self):
|
||||
def _flush_metrics(self) -> None:
|
||||
"""将缓冲区指标写入数据库"""
|
||||
if not self.metrics_buffer:
|
||||
return
|
||||
@@ -1362,12 +1360,12 @@ class PerformanceMonitor:
|
||||
|
||||
self.metrics_buffer = []
|
||||
|
||||
def _record_slow_query(self, metric: PerformanceMetric):
|
||||
def _record_slow_query(self, metric: PerformanceMetric) -> None:
|
||||
"""记录慢查询"""
|
||||
# 可以发送到专门的慢查询日志或监控系统
|
||||
print(f"[SLOW QUERY] {metric.endpoint}: {metric.duration_ms}ms")
|
||||
|
||||
def _trigger_alert(self, metric: PerformanceMetric):
|
||||
def _trigger_alert(self, metric: PerformanceMetric) -> None:
|
||||
"""触发告警"""
|
||||
alert_data = {
|
||||
"type": "performance_alert",
|
||||
@@ -1382,7 +1380,7 @@ class PerformanceMonitor:
|
||||
except Exception as e:
|
||||
print(f"告警处理失败: {e}")
|
||||
|
||||
def register_alert_handler(self, handler: Callable):
|
||||
def register_alert_handler(self, handler: Callable) -> None:
|
||||
"""注册告警处理器"""
|
||||
self.alert_handlers.append(handler)
|
||||
|
||||
@@ -1585,7 +1583,9 @@ class PerformanceMonitor:
|
||||
# ==================== 性能装饰器 ====================
|
||||
|
||||
|
||||
def cached(cache_manager: CacheManager, key_prefix: str = "", ttl: int = 3600, key_func: Callable | None = None):
|
||||
def cached(
|
||||
cache_manager: CacheManager, key_prefix: str = "", ttl: int = 3600, key_func: Callable | None = None
|
||||
) -> None:
|
||||
"""
|
||||
缓存装饰器
|
||||
|
||||
@@ -1598,7 +1598,7 @@ def cached(cache_manager: CacheManager, key_prefix: str = "", ttl: int = 3600, k
|
||||
|
||||
def decorator(func: Callable) -> Callable:
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
def wrapper(*args, **kwargs) -> None:
|
||||
# 生成缓存键
|
||||
if key_func:
|
||||
cache_key = key_func(*args, **kwargs)
|
||||
@@ -1625,7 +1625,7 @@ def cached(cache_manager: CacheManager, key_prefix: str = "", ttl: int = 3600, k
|
||||
return decorator
|
||||
|
||||
|
||||
def monitored(monitor: PerformanceMonitor, metric_type: str, endpoint: str | None = None):
|
||||
def monitored(monitor: PerformanceMonitor, metric_type: str, endpoint: str | None = None) -> None:
|
||||
"""
|
||||
性能监控装饰器
|
||||
|
||||
|
||||
@@ -163,7 +163,7 @@ class PluginManager:
|
||||
self._handlers = {}
|
||||
self._register_default_handlers()
|
||||
|
||||
def _register_default_handlers(self):
|
||||
def _register_default_handlers(self) -> None:
|
||||
"""注册默认处理器"""
|
||||
self._handlers[PluginType.CHROME_EXTENSION] = ChromeExtensionHandler(self)
|
||||
self._handlers[PluginType.FEISHU_BOT] = BotHandler(self, "feishu")
|
||||
@@ -371,7 +371,7 @@ class PluginManager:
|
||||
|
||||
return cursor.rowcount > 0
|
||||
|
||||
def record_plugin_usage(self, plugin_id: str):
|
||||
def record_plugin_usage(self, plugin_id: str) -> None:
|
||||
"""记录插件使用"""
|
||||
conn = self.db.get_conn()
|
||||
now = datetime.now().isoformat()
|
||||
@@ -826,9 +826,6 @@ class BotHandler:
|
||||
|
||||
async def _send_feishu_message(self, session: BotSession, message: str, msg_type: str) -> bool:
|
||||
"""发送飞书消息"""
|
||||
import base64
|
||||
import hashlib
|
||||
|
||||
timestamp = str(int(time.time()))
|
||||
|
||||
# 生成签名
|
||||
@@ -851,9 +848,6 @@ class BotHandler:
|
||||
|
||||
async def _send_dingtalk_message(self, session: BotSession, message: str, msg_type: str) -> bool:
|
||||
"""发送钉钉消息"""
|
||||
import base64
|
||||
import hashlib
|
||||
|
||||
timestamp = str(round(time.time() * 1000))
|
||||
|
||||
# 生成签名
|
||||
@@ -1358,7 +1352,7 @@ class WebDAVSyncManager:
|
||||
_plugin_manager = None
|
||||
|
||||
|
||||
def get_plugin_manager(db_manager=None):
|
||||
def get_plugin_manager(db_manager=None) -> None:
|
||||
"""获取 PluginManager 单例"""
|
||||
global _plugin_manager
|
||||
if _plugin_manager is None:
|
||||
|
||||
@@ -56,7 +56,7 @@ class SlidingWindowCounter:
|
||||
self._cleanup_old(now)
|
||||
return sum(self.requests.values())
|
||||
|
||||
def _cleanup_old(self, now: int):
|
||||
def _cleanup_old(self, now: int) -> None:
|
||||
"""清理过期的请求记录 - 使用独立锁避免竞态条件"""
|
||||
cutoff = now - self.window_size
|
||||
old_keys = [k for k in list(self.requests.keys()) if k < cutoff]
|
||||
@@ -67,7 +67,7 @@ class SlidingWindowCounter:
|
||||
class RateLimiter:
|
||||
"""API 限流器"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
# key -> SlidingWindowCounter
|
||||
self.counters: dict[str, SlidingWindowCounter] = {}
|
||||
# key -> RateLimitConfig
|
||||
@@ -143,7 +143,7 @@ class RateLimiter:
|
||||
retry_after=max(0, config.window_size) if current_count >= config.requests_per_minute else 0,
|
||||
)
|
||||
|
||||
def reset(self, key: str | None = None):
|
||||
def reset(self, key: str | None = None) -> None:
|
||||
"""重置限流计数器"""
|
||||
if key:
|
||||
self.counters.pop(key, None)
|
||||
@@ -166,7 +166,7 @@ def get_rate_limiter() -> RateLimiter:
|
||||
|
||||
|
||||
# 限流装饰器(用于函数级别限流)
|
||||
def rate_limit(requests_per_minute: int = 60, key_func: Callable | None = None):
|
||||
def rate_limit(requests_per_minute: int = 60, key_func: Callable | None = None) -> None:
|
||||
"""
|
||||
限流装饰器
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -19,6 +19,7 @@ try:
|
||||
from cryptography.fernet import Fernet
|
||||
from cryptography.hazmat.primitives import hashes
|
||||
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
|
||||
|
||||
CRYPTO_AVAILABLE = True
|
||||
except ImportError:
|
||||
CRYPTO_AVAILABLE = False
|
||||
@@ -27,6 +28,7 @@ except ImportError:
|
||||
|
||||
class AuditActionType(Enum):
|
||||
"""审计动作类型"""
|
||||
|
||||
CREATE = "create"
|
||||
READ = "read"
|
||||
UPDATE = "update"
|
||||
@@ -49,6 +51,7 @@ class AuditActionType(Enum):
|
||||
|
||||
class DataSensitivityLevel(Enum):
|
||||
"""数据敏感度级别"""
|
||||
|
||||
PUBLIC = "public" # 公开
|
||||
INTERNAL = "internal" # 内部
|
||||
CONFIDENTIAL = "confidential" # 机密
|
||||
@@ -57,6 +60,7 @@ class DataSensitivityLevel(Enum):
|
||||
|
||||
class MaskingRuleType(Enum):
|
||||
"""脱敏规则类型"""
|
||||
|
||||
PHONE = "phone" # 手机号
|
||||
EMAIL = "email" # 邮箱
|
||||
ID_CARD = "id_card" # 身份证号
|
||||
@@ -69,6 +73,7 @@ class MaskingRuleType(Enum):
|
||||
@dataclass
|
||||
class AuditLog:
|
||||
"""审计日志条目"""
|
||||
|
||||
id: str
|
||||
action_type: str
|
||||
user_id: str | None = None
|
||||
@@ -90,6 +95,7 @@ class AuditLog:
|
||||
@dataclass
|
||||
class EncryptionConfig:
|
||||
"""加密配置"""
|
||||
|
||||
id: str
|
||||
project_id: str
|
||||
is_enabled: bool = False
|
||||
@@ -107,6 +113,7 @@ class EncryptionConfig:
|
||||
@dataclass
|
||||
class MaskingRule:
|
||||
"""脱敏规则"""
|
||||
|
||||
id: str
|
||||
project_id: str
|
||||
name: str
|
||||
@@ -126,6 +133,7 @@ class MaskingRule:
|
||||
@dataclass
|
||||
class DataAccessPolicy:
|
||||
"""数据访问策略"""
|
||||
|
||||
id: str
|
||||
project_id: str
|
||||
name: str
|
||||
@@ -147,6 +155,7 @@ class DataAccessPolicy:
|
||||
@dataclass
|
||||
class AccessRequest:
|
||||
"""访问请求(用于需要审批的访问)"""
|
||||
|
||||
id: str
|
||||
policy_id: str
|
||||
user_id: str
|
||||
@@ -166,30 +175,15 @@ class SecurityManager:
|
||||
|
||||
# 预定义脱敏规则
|
||||
DEFAULT_MASKING_RULES = {
|
||||
MaskingRuleType.PHONE: {
|
||||
"pattern": r"(\d{3})\d{4}(\d{4})",
|
||||
"replacement": r"\1****\2"
|
||||
},
|
||||
MaskingRuleType.EMAIL: {
|
||||
"pattern": r"(\w{1,3})\w+(@\w+\.\w+)",
|
||||
"replacement": r"\1***\2"
|
||||
},
|
||||
MaskingRuleType.ID_CARD: {
|
||||
"pattern": r"(\d{6})\d{8}(\d{4})",
|
||||
"replacement": r"\1********\2"
|
||||
},
|
||||
MaskingRuleType.BANK_CARD: {
|
||||
"pattern": r"(\d{4})\d+(\d{4})",
|
||||
"replacement": r"\1 **** **** \2"
|
||||
},
|
||||
MaskingRuleType.NAME: {
|
||||
"pattern": r"([\u4e00-\u9fa5])[\u4e00-\u9fa5]+",
|
||||
"replacement": r"\1**"
|
||||
},
|
||||
MaskingRuleType.PHONE: {"pattern": r"(\d{3})\d{4}(\d{4})", "replacement": r"\1****\2"},
|
||||
MaskingRuleType.EMAIL: {"pattern": r"(\w{1,3})\w+(@\w+\.\w+)", "replacement": r"\1***\2"},
|
||||
MaskingRuleType.ID_CARD: {"pattern": r"(\d{6})\d{8}(\d{4})", "replacement": r"\1********\2"},
|
||||
MaskingRuleType.BANK_CARD: {"pattern": r"(\d{4})\d+(\d{4})", "replacement": r"\1 **** **** \2"},
|
||||
MaskingRuleType.NAME: {"pattern": r"([\u4e00-\u9fa5])[\u4e00-\u9fa5]+", "replacement": r"\1**"},
|
||||
MaskingRuleType.ADDRESS: {
|
||||
"pattern": r"([\u4e00-\u9fa5]{2,})([\u4e00-\u9fa5]+路|街|巷|号)(.+)",
|
||||
"replacement": r"\1\2***"
|
||||
}
|
||||
"replacement": r"\1\2***",
|
||||
},
|
||||
}
|
||||
|
||||
def __init__(self, db_path: str = "insightflow.db"):
|
||||
@@ -200,7 +194,7 @@ class SecurityManager:
|
||||
self._local = {}
|
||||
self._init_db()
|
||||
|
||||
def _init_db(self):
|
||||
def _init_db(self) -> None:
|
||||
"""初始化数据库表"""
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
cursor = conn.cursor()
|
||||
@@ -308,9 +302,7 @@ class SecurityManager:
|
||||
|
||||
def _generate_id(self) -> str:
|
||||
"""生成唯一ID"""
|
||||
return hashlib.sha256(
|
||||
f"{datetime.now().isoformat()}{secrets.token_hex(16)}".encode()
|
||||
).hexdigest()[:32]
|
||||
return hashlib.sha256(f"{datetime.now().isoformat()}{secrets.token_hex(16)}".encode()).hexdigest()[:32]
|
||||
|
||||
# ==================== 审计日志 ====================
|
||||
|
||||
@@ -326,7 +318,7 @@ class SecurityManager:
|
||||
before_value: str | None = None,
|
||||
after_value: str | None = None,
|
||||
success: bool = True,
|
||||
error_message: str | None = None
|
||||
error_message: str | None = None,
|
||||
) -> AuditLog:
|
||||
"""记录审计日志"""
|
||||
log = AuditLog(
|
||||
@@ -341,22 +333,34 @@ class SecurityManager:
|
||||
before_value=before_value,
|
||||
after_value=after_value,
|
||||
success=success,
|
||||
error_message=error_message
|
||||
error_message=error_message,
|
||||
)
|
||||
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("""
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO audit_logs
|
||||
(id, action_type, user_id, user_ip, user_agent, resource_type, resource_id,
|
||||
action_details, before_value, after_value, success, error_message, created_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""", (
|
||||
log.id, log.action_type, log.user_id, log.user_ip, log.user_agent,
|
||||
log.resource_type, log.resource_id, log.action_details,
|
||||
log.before_value, log.after_value, int(log.success),
|
||||
log.error_message, log.created_at
|
||||
))
|
||||
""",
|
||||
(
|
||||
log.id,
|
||||
log.action_type,
|
||||
log.user_id,
|
||||
log.user_ip,
|
||||
log.user_agent,
|
||||
log.resource_type,
|
||||
log.resource_id,
|
||||
log.action_details,
|
||||
log.before_value,
|
||||
log.after_value,
|
||||
int(log.success),
|
||||
log.error_message,
|
||||
log.created_at,
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
@@ -372,7 +376,7 @@ class SecurityManager:
|
||||
end_time: str | None = None,
|
||||
success: bool | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0
|
||||
offset: int = 0,
|
||||
) -> list[AuditLog]:
|
||||
"""查询审计日志"""
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
@@ -429,18 +433,14 @@ class SecurityManager:
|
||||
after_value=row[9],
|
||||
success=bool(row[10]),
|
||||
error_message=row[11],
|
||||
created_at=row[12]
|
||||
created_at=row[12],
|
||||
)
|
||||
logs.append(log)
|
||||
|
||||
conn.close()
|
||||
return logs
|
||||
|
||||
def get_audit_stats(
|
||||
self,
|
||||
start_time: str | None = None,
|
||||
end_time: str | None = None
|
||||
) -> dict[str, Any]:
|
||||
def get_audit_stats(self, start_time: str | None = None, end_time: str | None = None) -> dict[str, Any]:
|
||||
"""获取审计统计"""
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
cursor = conn.cursor()
|
||||
@@ -460,12 +460,7 @@ class SecurityManager:
|
||||
cursor.execute(query, params)
|
||||
rows = cursor.fetchall()
|
||||
|
||||
stats = {
|
||||
"total_actions": 0,
|
||||
"success_count": 0,
|
||||
"failure_count": 0,
|
||||
"action_breakdown": {}
|
||||
}
|
||||
stats = {"total_actions": 0, "success_count": 0, "failure_count": 0, "action_breakdown": {}}
|
||||
|
||||
for action_type, success, count in rows:
|
||||
stats["total_actions"] += count
|
||||
@@ -500,11 +495,7 @@ class SecurityManager:
|
||||
)
|
||||
return base64.urlsafe_b64encode(kdf.derive(password.encode()))
|
||||
|
||||
def enable_encryption(
|
||||
self,
|
||||
project_id: str,
|
||||
master_password: str
|
||||
) -> EncryptionConfig:
|
||||
def enable_encryption(self, project_id: str, master_password: str) -> EncryptionConfig:
|
||||
"""启用项目加密"""
|
||||
if not CRYPTO_AVAILABLE:
|
||||
raise RuntimeError("cryptography library not available")
|
||||
@@ -523,43 +514,54 @@ class SecurityManager:
|
||||
encryption_type="aes-256-gcm",
|
||||
key_derivation="pbkdf2",
|
||||
master_key_hash=key_hash,
|
||||
salt=salt
|
||||
salt=salt,
|
||||
)
|
||||
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
# 检查是否已存在配置
|
||||
cursor.execute(
|
||||
"SELECT id FROM encryption_configs WHERE project_id = ?",
|
||||
(project_id,)
|
||||
)
|
||||
cursor.execute("SELECT id FROM encryption_configs WHERE project_id = ?", (project_id,))
|
||||
existing = cursor.fetchone()
|
||||
|
||||
if existing:
|
||||
cursor.execute("""
|
||||
cursor.execute(
|
||||
"""
|
||||
UPDATE encryption_configs
|
||||
SET is_enabled = 1, encryption_type = ?, key_derivation = ?,
|
||||
master_key_hash = ?, salt = ?, updated_at = ?
|
||||
WHERE project_id = ?
|
||||
""", (
|
||||
config.encryption_type, config.key_derivation,
|
||||
config.master_key_hash, config.salt,
|
||||
config.updated_at, project_id
|
||||
))
|
||||
""",
|
||||
(
|
||||
config.encryption_type,
|
||||
config.key_derivation,
|
||||
config.master_key_hash,
|
||||
config.salt,
|
||||
config.updated_at,
|
||||
project_id,
|
||||
),
|
||||
)
|
||||
config.id = existing[0]
|
||||
else:
|
||||
cursor.execute("""
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO encryption_configs
|
||||
(id, project_id, is_enabled, encryption_type, key_derivation,
|
||||
master_key_hash, salt, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""", (
|
||||
config.id, config.project_id, int(config.is_enabled),
|
||||
config.encryption_type, config.key_derivation,
|
||||
config.master_key_hash, config.salt,
|
||||
config.created_at, config.updated_at
|
||||
))
|
||||
""",
|
||||
(
|
||||
config.id,
|
||||
config.project_id,
|
||||
int(config.is_enabled),
|
||||
config.encryption_type,
|
||||
config.key_derivation,
|
||||
config.master_key_hash,
|
||||
config.salt,
|
||||
config.created_at,
|
||||
config.updated_at,
|
||||
),
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
@@ -569,16 +571,12 @@ class SecurityManager:
|
||||
action_type=AuditActionType.ENCRYPTION_ENABLE,
|
||||
resource_type="project",
|
||||
resource_id=project_id,
|
||||
action_details={"encryption_type": config.encryption_type}
|
||||
action_details={"encryption_type": config.encryption_type},
|
||||
)
|
||||
|
||||
return config
|
||||
|
||||
def disable_encryption(
|
||||
self,
|
||||
project_id: str,
|
||||
master_password: str
|
||||
) -> bool:
|
||||
def disable_encryption(self, project_id: str, master_password: str) -> bool:
|
||||
"""禁用项目加密"""
|
||||
# 验证密码
|
||||
if not self.verify_encryption_password(project_id, master_password):
|
||||
@@ -587,29 +585,24 @@ class SecurityManager:
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute("""
|
||||
cursor.execute(
|
||||
"""
|
||||
UPDATE encryption_configs
|
||||
SET is_enabled = 0, updated_at = ?
|
||||
WHERE project_id = ?
|
||||
""", (datetime.now().isoformat(), project_id))
|
||||
""",
|
||||
(datetime.now().isoformat(), project_id),
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
# 记录审计日志
|
||||
self.log_audit(
|
||||
action_type=AuditActionType.ENCRYPTION_DISABLE,
|
||||
resource_type="project",
|
||||
resource_id=project_id
|
||||
)
|
||||
self.log_audit(action_type=AuditActionType.ENCRYPTION_DISABLE, resource_type="project", resource_id=project_id)
|
||||
|
||||
return True
|
||||
|
||||
def verify_encryption_password(
|
||||
self,
|
||||
project_id: str,
|
||||
password: str
|
||||
) -> bool:
|
||||
def verify_encryption_password(self, project_id: str, password: str) -> bool:
|
||||
"""验证加密密码"""
|
||||
if not CRYPTO_AVAILABLE:
|
||||
return False
|
||||
@@ -617,10 +610,7 @@ class SecurityManager:
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute(
|
||||
"SELECT master_key_hash, salt FROM encryption_configs WHERE project_id = ?",
|
||||
(project_id,)
|
||||
)
|
||||
cursor.execute("SELECT master_key_hash, salt FROM encryption_configs WHERE project_id = ?", (project_id,))
|
||||
row = cursor.fetchone()
|
||||
conn.close()
|
||||
|
||||
@@ -638,10 +628,7 @@ class SecurityManager:
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute(
|
||||
"SELECT * FROM encryption_configs WHERE project_id = ?",
|
||||
(project_id,)
|
||||
)
|
||||
cursor.execute("SELECT * FROM encryption_configs WHERE project_id = ?", (project_id,))
|
||||
row = cursor.fetchone()
|
||||
conn.close()
|
||||
|
||||
@@ -657,15 +644,10 @@ class SecurityManager:
|
||||
master_key_hash=row[5],
|
||||
salt=row[6],
|
||||
created_at=row[7],
|
||||
updated_at=row[8]
|
||||
updated_at=row[8],
|
||||
)
|
||||
|
||||
def encrypt_data(
|
||||
self,
|
||||
data: str,
|
||||
password: str,
|
||||
salt: str | None = None
|
||||
) -> tuple[str, str]:
|
||||
def encrypt_data(self, data: str, password: str, salt: str | None = None) -> tuple[str, str]:
|
||||
"""加密数据"""
|
||||
if not CRYPTO_AVAILABLE:
|
||||
raise RuntimeError("cryptography library not available")
|
||||
@@ -679,12 +661,7 @@ class SecurityManager:
|
||||
|
||||
return base64.b64encode(encrypted).decode(), salt
|
||||
|
||||
def decrypt_data(
|
||||
self,
|
||||
encrypted_data: str,
|
||||
password: str,
|
||||
salt: str
|
||||
) -> str:
|
||||
def decrypt_data(self, encrypted_data: str, password: str, salt: str) -> str:
|
||||
"""解密数据"""
|
||||
if not CRYPTO_AVAILABLE:
|
||||
raise RuntimeError("cryptography library not available")
|
||||
@@ -705,7 +682,7 @@ class SecurityManager:
|
||||
pattern: str | None = None,
|
||||
replacement: str | None = None,
|
||||
description: str | None = None,
|
||||
priority: int = 0
|
||||
priority: int = 0,
|
||||
) -> MaskingRule:
|
||||
"""创建脱敏规则"""
|
||||
# 使用预定义规则或自定义规则
|
||||
@@ -722,22 +699,33 @@ class SecurityManager:
|
||||
pattern=pattern or "",
|
||||
replacement=replacement or "****",
|
||||
description=description,
|
||||
priority=priority
|
||||
priority=priority,
|
||||
)
|
||||
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute("""
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO masking_rules
|
||||
(id, project_id, name, rule_type, pattern, replacement,
|
||||
is_active, priority, description, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""", (
|
||||
rule.id, rule.project_id, rule.name, rule.rule_type,
|
||||
rule.pattern, rule.replacement, int(rule.is_active),
|
||||
rule.priority, rule.description, rule.created_at, rule.updated_at
|
||||
))
|
||||
""",
|
||||
(
|
||||
rule.id,
|
||||
rule.project_id,
|
||||
rule.name,
|
||||
rule.rule_type,
|
||||
rule.pattern,
|
||||
rule.replacement,
|
||||
int(rule.is_active),
|
||||
rule.priority,
|
||||
rule.description,
|
||||
rule.created_at,
|
||||
rule.updated_at,
|
||||
),
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
@@ -747,16 +735,12 @@ class SecurityManager:
|
||||
action_type=AuditActionType.DATA_MASKING,
|
||||
resource_type="project",
|
||||
resource_id=project_id,
|
||||
action_details={"action": "create_rule", "rule_name": name}
|
||||
action_details={"action": "create_rule", "rule_name": name},
|
||||
)
|
||||
|
||||
return rule
|
||||
|
||||
def get_masking_rules(
|
||||
self,
|
||||
project_id: str,
|
||||
active_only: bool = True
|
||||
) -> list[MaskingRule]:
|
||||
def get_masking_rules(self, project_id: str, active_only: bool = True) -> list[MaskingRule]:
|
||||
"""获取脱敏规则"""
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
cursor = conn.cursor()
|
||||
@@ -775,7 +759,8 @@ class SecurityManager:
|
||||
|
||||
rules = []
|
||||
for row in rows:
|
||||
rules.append(MaskingRule(
|
||||
rules.append(
|
||||
MaskingRule(
|
||||
id=row[0],
|
||||
project_id=row[1],
|
||||
name=row[2],
|
||||
@@ -786,16 +771,13 @@ class SecurityManager:
|
||||
priority=row[7],
|
||||
description=row[8],
|
||||
created_at=row[9],
|
||||
updated_at=row[10]
|
||||
))
|
||||
updated_at=row[10],
|
||||
)
|
||||
)
|
||||
|
||||
return rules
|
||||
|
||||
def update_masking_rule(
|
||||
self,
|
||||
rule_id: str,
|
||||
**kwargs
|
||||
) -> MaskingRule | None:
|
||||
def update_masking_rule(self, rule_id: str, **kwargs) -> MaskingRule | None:
|
||||
"""更新脱敏规则"""
|
||||
allowed_fields = ["name", "pattern", "replacement", "is_active", "priority", "description"]
|
||||
|
||||
@@ -818,11 +800,14 @@ class SecurityManager:
|
||||
params.append(datetime.now().isoformat())
|
||||
params.append(rule_id)
|
||||
|
||||
cursor.execute(f"""
|
||||
cursor.execute(
|
||||
f"""
|
||||
UPDATE masking_rules
|
||||
SET {', '.join(set_clauses)}
|
||||
WHERE id = ?
|
||||
""", params)
|
||||
""",
|
||||
params,
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
@@ -848,7 +833,7 @@ class SecurityManager:
|
||||
priority=row[7],
|
||||
description=row[8],
|
||||
created_at=row[9],
|
||||
updated_at=row[10]
|
||||
updated_at=row[10],
|
||||
)
|
||||
|
||||
def delete_masking_rule(self, rule_id: str) -> bool:
|
||||
@@ -864,12 +849,7 @@ class SecurityManager:
|
||||
|
||||
return success
|
||||
|
||||
def apply_masking(
|
||||
self,
|
||||
text: str,
|
||||
project_id: str,
|
||||
rule_types: list[MaskingRuleType] | None = None
|
||||
) -> str:
|
||||
def apply_masking(self, text: str, project_id: str, rule_types: list[MaskingRuleType] | None = None) -> str:
|
||||
"""应用脱敏规则到文本"""
|
||||
rules = self.get_masking_rules(project_id)
|
||||
|
||||
@@ -884,22 +864,14 @@ class SecurityManager:
|
||||
continue
|
||||
|
||||
try:
|
||||
masked_text = re.sub(
|
||||
rule.pattern,
|
||||
rule.replacement,
|
||||
masked_text
|
||||
)
|
||||
masked_text = re.sub(rule.pattern, rule.replacement, masked_text)
|
||||
except re.error:
|
||||
# 忽略无效的正则表达式
|
||||
continue
|
||||
|
||||
return masked_text
|
||||
|
||||
def apply_masking_to_entity(
|
||||
self,
|
||||
entity_data: dict[str, Any],
|
||||
project_id: str
|
||||
) -> dict[str, Any]:
|
||||
def apply_masking_to_entity(self, entity_data: dict[str, Any], project_id: str) -> dict[str, Any]:
|
||||
"""对实体数据应用脱敏"""
|
||||
masked_data = entity_data.copy()
|
||||
|
||||
@@ -924,7 +896,7 @@ class SecurityManager:
|
||||
allowed_ips: list[str] | None = None,
|
||||
time_restrictions: dict | None = None,
|
||||
max_access_count: int | None = None,
|
||||
require_approval: bool = False
|
||||
require_approval: bool = False,
|
||||
) -> DataAccessPolicy:
|
||||
"""创建数据访问策略"""
|
||||
policy = DataAccessPolicy(
|
||||
@@ -937,36 +909,43 @@ class SecurityManager:
|
||||
allowed_ips=json.dumps(allowed_ips) if allowed_ips else None,
|
||||
time_restrictions=json.dumps(time_restrictions) if time_restrictions else None,
|
||||
max_access_count=max_access_count,
|
||||
require_approval=require_approval
|
||||
require_approval=require_approval,
|
||||
)
|
||||
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute("""
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO data_access_policies
|
||||
(id, project_id, name, description, allowed_users, allowed_roles,
|
||||
allowed_ips, time_restrictions, max_access_count, require_approval,
|
||||
is_active, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""", (
|
||||
policy.id, policy.project_id, policy.name, policy.description,
|
||||
policy.allowed_users, policy.allowed_roles, policy.allowed_ips,
|
||||
policy.time_restrictions, policy.max_access_count,
|
||||
int(policy.require_approval), int(policy.is_active),
|
||||
policy.created_at, policy.updated_at
|
||||
))
|
||||
""",
|
||||
(
|
||||
policy.id,
|
||||
policy.project_id,
|
||||
policy.name,
|
||||
policy.description,
|
||||
policy.allowed_users,
|
||||
policy.allowed_roles,
|
||||
policy.allowed_ips,
|
||||
policy.time_restrictions,
|
||||
policy.max_access_count,
|
||||
int(policy.require_approval),
|
||||
int(policy.is_active),
|
||||
policy.created_at,
|
||||
policy.updated_at,
|
||||
),
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
return policy
|
||||
|
||||
def get_access_policies(
|
||||
self,
|
||||
project_id: str,
|
||||
active_only: bool = True
|
||||
) -> list[DataAccessPolicy]:
|
||||
def get_access_policies(self, project_id: str, active_only: bool = True) -> list[DataAccessPolicy]:
|
||||
"""获取数据访问策略"""
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
cursor = conn.cursor()
|
||||
@@ -983,7 +962,8 @@ class SecurityManager:
|
||||
|
||||
policies = []
|
||||
for row in rows:
|
||||
policies.append(DataAccessPolicy(
|
||||
policies.append(
|
||||
DataAccessPolicy(
|
||||
id=row[0],
|
||||
project_id=row[1],
|
||||
name=row[2],
|
||||
@@ -996,25 +976,20 @@ class SecurityManager:
|
||||
require_approval=bool(row[9]),
|
||||
is_active=bool(row[10]),
|
||||
created_at=row[11],
|
||||
updated_at=row[12]
|
||||
))
|
||||
updated_at=row[12],
|
||||
)
|
||||
)
|
||||
|
||||
return policies
|
||||
|
||||
def check_access_permission(
|
||||
self,
|
||||
policy_id: str,
|
||||
user_id: str,
|
||||
user_ip: str | None = None
|
||||
self, policy_id: str, user_id: str, user_ip: str | None = None
|
||||
) -> tuple[bool, str | None]:
|
||||
"""检查访问权限"""
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute(
|
||||
"SELECT * FROM data_access_policies WHERE id = ? AND is_active = 1",
|
||||
(policy_id,)
|
||||
)
|
||||
cursor.execute("SELECT * FROM data_access_policies WHERE id = ? AND is_active = 1", (policy_id,))
|
||||
row = cursor.fetchone()
|
||||
conn.close()
|
||||
|
||||
@@ -1034,7 +1009,7 @@ class SecurityManager:
|
||||
require_approval=bool(row[9]),
|
||||
is_active=bool(row[10]),
|
||||
created_at=row[11],
|
||||
updated_at=row[12]
|
||||
updated_at=row[12],
|
||||
)
|
||||
|
||||
# 检查用户白名单
|
||||
@@ -1074,11 +1049,14 @@ class SecurityManager:
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute("""
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT * FROM access_requests
|
||||
WHERE policy_id = ? AND user_id = ? AND status = 'approved'
|
||||
AND (expires_at IS NULL OR expires_at > ?)
|
||||
""", (policy_id, user_id, datetime.now().isoformat()))
|
||||
""",
|
||||
(policy_id, user_id, datetime.now().isoformat()),
|
||||
)
|
||||
|
||||
request = cursor.fetchone()
|
||||
conn.close()
|
||||
@@ -1104,11 +1082,7 @@ class SecurityManager:
|
||||
return ip == pattern
|
||||
|
||||
def create_access_request(
|
||||
self,
|
||||
policy_id: str,
|
||||
user_id: str,
|
||||
request_reason: str | None = None,
|
||||
expires_hours: int = 24
|
||||
self, policy_id: str, user_id: str, request_reason: str | None = None, expires_hours: int = 24
|
||||
) -> AccessRequest:
|
||||
"""创建访问请求"""
|
||||
request = AccessRequest(
|
||||
@@ -1116,21 +1090,28 @@ class SecurityManager:
|
||||
policy_id=policy_id,
|
||||
user_id=user_id,
|
||||
request_reason=request_reason,
|
||||
expires_at=(datetime.now() + timedelta(hours=expires_hours)).isoformat()
|
||||
expires_at=(datetime.now() + timedelta(hours=expires_hours)).isoformat(),
|
||||
)
|
||||
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute("""
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO access_requests
|
||||
(id, policy_id, user_id, request_reason, status, expires_at, created_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||
""", (
|
||||
request.id, request.policy_id, request.user_id,
|
||||
request.request_reason, request.status, request.expires_at,
|
||||
request.created_at
|
||||
))
|
||||
""",
|
||||
(
|
||||
request.id,
|
||||
request.policy_id,
|
||||
request.user_id,
|
||||
request.request_reason,
|
||||
request.status,
|
||||
request.expires_at,
|
||||
request.created_at,
|
||||
),
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
@@ -1138,10 +1119,7 @@ class SecurityManager:
|
||||
return request
|
||||
|
||||
def approve_access_request(
|
||||
self,
|
||||
request_id: str,
|
||||
approved_by: str,
|
||||
expires_hours: int = 24
|
||||
self, request_id: str, approved_by: str, expires_hours: int = 24
|
||||
) -> AccessRequest | None:
|
||||
"""批准访问请求"""
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
@@ -1150,11 +1128,14 @@ class SecurityManager:
|
||||
expires_at = (datetime.now() + timedelta(hours=expires_hours)).isoformat()
|
||||
approved_at = datetime.now().isoformat()
|
||||
|
||||
cursor.execute("""
|
||||
cursor.execute(
|
||||
"""
|
||||
UPDATE access_requests
|
||||
SET status = 'approved', approved_by = ?, approved_at = ?, expires_at = ?
|
||||
WHERE id = ?
|
||||
""", (approved_by, approved_at, expires_at, request_id))
|
||||
""",
|
||||
(approved_by, approved_at, expires_at, request_id),
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
|
||||
@@ -1175,23 +1156,22 @@ class SecurityManager:
|
||||
approved_by=row[5],
|
||||
approved_at=row[6],
|
||||
expires_at=row[7],
|
||||
created_at=row[8]
|
||||
created_at=row[8],
|
||||
)
|
||||
|
||||
def reject_access_request(
|
||||
self,
|
||||
request_id: str,
|
||||
rejected_by: str
|
||||
) -> AccessRequest | None:
|
||||
def reject_access_request(self, request_id: str, rejected_by: str) -> AccessRequest | None:
|
||||
"""拒绝访问请求"""
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute("""
|
||||
cursor.execute(
|
||||
"""
|
||||
UPDATE access_requests
|
||||
SET status = 'rejected', approved_by = ?
|
||||
WHERE id = ?
|
||||
""", (rejected_by, request_id))
|
||||
""",
|
||||
(rejected_by, request_id),
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
|
||||
@@ -1211,7 +1191,7 @@ class SecurityManager:
|
||||
approved_by=row[5],
|
||||
approved_at=row[6],
|
||||
expires_at=row[7],
|
||||
created_at=row[8]
|
||||
created_at=row[8],
|
||||
)
|
||||
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -20,7 +20,7 @@ class TingwuClient:
|
||||
|
||||
def _sign_request(self, method: str, uri: str, query: str = "", body: str = "") -> dict[str, str]:
|
||||
"""阿里云签名 V3"""
|
||||
timestamp = datetime.utcnow().strftime('%Y-%m-%dT%H:%M:%SZ')
|
||||
timestamp = datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%SZ")
|
||||
|
||||
# 简化签名,实际生产需要完整实现
|
||||
# 这里使用基础认证头
|
||||
@@ -39,25 +39,16 @@ class TingwuClient:
|
||||
from alibabacloud_tingwu20230930 import models as tingwu_models
|
||||
from alibabacloud_tingwu20230930.client import Client as TingwuSDKClient
|
||||
|
||||
config = open_api_models.Config(
|
||||
access_key_id=self.access_key,
|
||||
access_key_secret=self.secret_key
|
||||
)
|
||||
config = open_api_models.Config(access_key_id=self.access_key, access_key_secret=self.secret_key)
|
||||
config.endpoint = "tingwu.cn-beijing.aliyuncs.com"
|
||||
client = TingwuSDKClient(config)
|
||||
|
||||
request = tingwu_models.CreateTaskRequest(
|
||||
type="offline",
|
||||
input=tingwu_models.Input(
|
||||
source="OSS",
|
||||
file_url=audio_url
|
||||
),
|
||||
input=tingwu_models.Input(source="OSS", file_url=audio_url),
|
||||
parameters=tingwu_models.Parameters(
|
||||
transcription=tingwu_models.Transcription(
|
||||
diarization_enabled=True,
|
||||
sentence_max_length=20
|
||||
)
|
||||
)
|
||||
transcription=tingwu_models.Transcription(diarization_enabled=True, sentence_max_length=20)
|
||||
),
|
||||
)
|
||||
|
||||
response = client.create_task(request)
|
||||
@@ -81,10 +72,7 @@ class TingwuClient:
|
||||
from alibabacloud_tingwu20230930 import models as tingwu_models
|
||||
from alibabacloud_tingwu20230930.client import Client as TingwuSDKClient
|
||||
|
||||
config = open_api_models.Config(
|
||||
access_key_id=self.access_key,
|
||||
access_key_secret=self.secret_key
|
||||
)
|
||||
config = open_api_models.Config(access_key_id=self.access_key, access_key_secret=self.secret_key)
|
||||
config.endpoint = "tingwu.cn-beijing.aliyuncs.com"
|
||||
client = TingwuSDKClient(config)
|
||||
|
||||
@@ -128,25 +116,29 @@ class TingwuClient:
|
||||
|
||||
if transcription.sentences:
|
||||
for sent in transcription.sentences:
|
||||
segments.append({
|
||||
segments.append(
|
||||
{
|
||||
"start": sent.begin_time / 1000,
|
||||
"end": sent.end_time / 1000,
|
||||
"text": sent.text,
|
||||
"speaker": f"Speaker {sent.speaker_id}"
|
||||
})
|
||||
|
||||
return {
|
||||
"full_text": full_text.strip(),
|
||||
"segments": segments
|
||||
"speaker": f"Speaker {sent.speaker_id}",
|
||||
}
|
||||
)
|
||||
|
||||
return {"full_text": full_text.strip(), "segments": segments}
|
||||
|
||||
def _mock_result(self) -> dict[str, Any]:
|
||||
"""Mock 结果"""
|
||||
return {
|
||||
"full_text": "这是一个示例转录文本,包含 Project Alpha 和 K8s 等术语。",
|
||||
"segments": [
|
||||
{"start": 0.0, "end": 5.0, "text": "这是一个示例转录文本,包含 Project Alpha 和 K8s 等术语。", "speaker": "Speaker A"}
|
||||
]
|
||||
{
|
||||
"start": 0.0,
|
||||
"end": 5.0,
|
||||
"text": "这是一个示例转录文本,包含 Project Alpha 和 K8s 等术语。",
|
||||
"speaker": "Speaker A",
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
def transcribe(self, audio_url: str, language: str = "zh") -> dict[str, Any]:
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user