fix: auto-fix code issues (cron)
- 修复重复导入/字段 - 修复异常处理 - 修复PEP8格式问题 - 添加类型注解
This commit is contained in:
@@ -209,7 +209,7 @@ class AIManager:
|
|||||||
self.openai_api_key = os.getenv("OPENAI_API_KEY", "")
|
self.openai_api_key = os.getenv("OPENAI_API_KEY", "")
|
||||||
self.anthropic_api_key = os.getenv("ANTHROPIC_API_KEY", "")
|
self.anthropic_api_key = os.getenv("ANTHROPIC_API_KEY", "")
|
||||||
|
|
||||||
def _get_db(self):
|
def _get_db(self) -> None:
|
||||||
"""获取数据库连接"""
|
"""获取数据库连接"""
|
||||||
conn = sqlite3.connect(self.db_path)
|
conn = sqlite3.connect(self.db_path)
|
||||||
conn.row_factory = sqlite3.Row
|
conn.row_factory = sqlite3.Row
|
||||||
@@ -921,7 +921,6 @@ class AIManager:
|
|||||||
|
|
||||||
# 解析关键要点
|
# 解析关键要点
|
||||||
key_points = []
|
key_points = []
|
||||||
import re
|
|
||||||
|
|
||||||
# 尝试从 JSON 中提取
|
# 尝试从 JSON 中提取
|
||||||
json_match = re.search(r"\{.*?\}", content, re.DOTALL)
|
json_match = re.search(r"\{.*?\}", content, re.DOTALL)
|
||||||
@@ -1312,7 +1311,7 @@ class AIManager:
|
|||||||
|
|
||||||
return [self._row_to_prediction_result(row) for row in rows]
|
return [self._row_to_prediction_result(row) for row in rows]
|
||||||
|
|
||||||
def update_prediction_feedback(self, prediction_id: str, actual_value: str, is_correct: bool):
|
def update_prediction_feedback(self, prediction_id: str, actual_value: str, is_correct: bool) -> None:
|
||||||
"""更新预测反馈(用于模型改进)"""
|
"""更新预测反馈(用于模型改进)"""
|
||||||
with self._get_db() as conn:
|
with self._get_db() as conn:
|
||||||
conn.execute(
|
conn.execute(
|
||||||
|
|||||||
@@ -51,7 +51,7 @@ class ApiKeyManager:
|
|||||||
self.db_path = db_path
|
self.db_path = db_path
|
||||||
self._init_db()
|
self._init_db()
|
||||||
|
|
||||||
def _init_db(self):
|
def _init_db(self) -> None:
|
||||||
"""初始化数据库表"""
|
"""初始化数据库表"""
|
||||||
with sqlite3.connect(self.db_path) as conn:
|
with sqlite3.connect(self.db_path) as conn:
|
||||||
conn.executescript("""
|
conn.executescript("""
|
||||||
@@ -331,7 +331,7 @@ class ApiKeyManager:
|
|||||||
conn.commit()
|
conn.commit()
|
||||||
return cursor.rowcount > 0
|
return cursor.rowcount > 0
|
||||||
|
|
||||||
def update_last_used(self, key_id: str):
|
def update_last_used(self, key_id: str) -> None:
|
||||||
"""更新最后使用时间"""
|
"""更新最后使用时间"""
|
||||||
with sqlite3.connect(self.db_path) as conn:
|
with sqlite3.connect(self.db_path) as conn:
|
||||||
conn.execute(
|
conn.execute(
|
||||||
|
|||||||
@@ -195,7 +195,7 @@ class CollaborationManager:
|
|||||||
data = f"{project_id}:{datetime.now().timestamp()}:{uuid.uuid4()}"
|
data = f"{project_id}:{datetime.now().timestamp()}:{uuid.uuid4()}"
|
||||||
return hashlib.sha256(data.encode()).hexdigest()[:32]
|
return hashlib.sha256(data.encode()).hexdigest()[:32]
|
||||||
|
|
||||||
def _save_share_to_db(self, share: ProjectShare):
|
def _save_share_to_db(self, share: ProjectShare) -> None:
|
||||||
"""保存分享记录到数据库"""
|
"""保存分享记录到数据库"""
|
||||||
cursor = self.db.conn.cursor()
|
cursor = self.db.conn.cursor()
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
@@ -286,7 +286,7 @@ class CollaborationManager:
|
|||||||
allow_export=bool(row[12]),
|
allow_export=bool(row[12]),
|
||||||
)
|
)
|
||||||
|
|
||||||
def increment_share_usage(self, token: str):
|
def increment_share_usage(self, token: str) -> None:
|
||||||
"""增加分享链接使用次数"""
|
"""增加分享链接使用次数"""
|
||||||
share = self._shares_cache.get(token)
|
share = self._shares_cache.get(token)
|
||||||
if share:
|
if share:
|
||||||
@@ -403,7 +403,7 @@ class CollaborationManager:
|
|||||||
|
|
||||||
return comment
|
return comment
|
||||||
|
|
||||||
def _save_comment_to_db(self, comment: Comment):
|
def _save_comment_to_db(self, comment: Comment) -> None:
|
||||||
"""保存评论到数据库"""
|
"""保存评论到数据库"""
|
||||||
cursor = self.db.conn.cursor()
|
cursor = self.db.conn.cursor()
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
@@ -616,7 +616,7 @@ class CollaborationManager:
|
|||||||
|
|
||||||
return record
|
return record
|
||||||
|
|
||||||
def _save_change_to_db(self, record: ChangeRecord):
|
def _save_change_to_db(self, record: ChangeRecord) -> None:
|
||||||
"""保存变更记录到数据库"""
|
"""保存变更记录到数据库"""
|
||||||
cursor = self.db.conn.cursor()
|
cursor = self.db.conn.cursor()
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
@@ -862,7 +862,7 @@ class CollaborationManager:
|
|||||||
}
|
}
|
||||||
return permissions_map.get(role, ["read"])
|
return permissions_map.get(role, ["read"])
|
||||||
|
|
||||||
def _save_member_to_db(self, member: TeamMember):
|
def _save_member_to_db(self, member: TeamMember) -> None:
|
||||||
"""保存成员到数据库"""
|
"""保存成员到数据库"""
|
||||||
cursor = self.db.conn.cursor()
|
cursor = self.db.conn.cursor()
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
@@ -970,7 +970,7 @@ class CollaborationManager:
|
|||||||
permissions = json.loads(row[0]) if row[0] else []
|
permissions = json.loads(row[0]) if row[0] else []
|
||||||
return permission in permissions or "admin" in permissions
|
return permission in permissions or "admin" in permissions
|
||||||
|
|
||||||
def update_last_active(self, project_id: str, user_id: str):
|
def update_last_active(self, project_id: str, user_id: str) -> None:
|
||||||
"""更新用户最后活跃时间"""
|
"""更新用户最后活跃时间"""
|
||||||
if not self.db:
|
if not self.db:
|
||||||
return
|
return
|
||||||
@@ -992,7 +992,7 @@ class CollaborationManager:
|
|||||||
_collaboration_manager = None
|
_collaboration_manager = None
|
||||||
|
|
||||||
|
|
||||||
def get_collaboration_manager(db_manager=None):
|
def get_collaboration_manager(db_manager=None) -> None:
|
||||||
"""获取协作管理器单例"""
|
"""获取协作管理器单例"""
|
||||||
global _collaboration_manager
|
global _collaboration_manager
|
||||||
if _collaboration_manager is None:
|
if _collaboration_manager is None:
|
||||||
|
|||||||
@@ -123,7 +123,7 @@ class DatabaseManager:
|
|||||||
conn.row_factory = sqlite3.Row
|
conn.row_factory = sqlite3.Row
|
||||||
return conn
|
return conn
|
||||||
|
|
||||||
def init_db(self):
|
def init_db(self) -> None:
|
||||||
"""初始化数据库表"""
|
"""初始化数据库表"""
|
||||||
with open(os.path.join(os.path.dirname(__file__), "schema.sql")) as f:
|
with open(os.path.join(os.path.dirname(__file__), "schema.sql")) as f:
|
||||||
schema = f.read()
|
schema = f.read()
|
||||||
@@ -299,7 +299,7 @@ class DatabaseManager:
|
|||||||
conn.close()
|
conn.close()
|
||||||
return self.get_entity(entity_id)
|
return self.get_entity(entity_id)
|
||||||
|
|
||||||
def delete_entity(self, entity_id: str):
|
def delete_entity(self, entity_id: str) -> None:
|
||||||
"""删除实体及其关联数据"""
|
"""删除实体及其关联数据"""
|
||||||
conn = self.get_conn()
|
conn = self.get_conn()
|
||||||
conn.execute("DELETE FROM entity_mentions WHERE entity_id = ?", (entity_id,))
|
conn.execute("DELETE FROM entity_mentions WHERE entity_id = ?", (entity_id,))
|
||||||
|
|||||||
@@ -352,7 +352,7 @@ class DeveloperEcosystemManager:
|
|||||||
self.db_path = db_path
|
self.db_path = db_path
|
||||||
self.platform_fee_rate = 0.30 # 平台抽成比例 30%
|
self.platform_fee_rate = 0.30 # 平台抽成比例 30%
|
||||||
|
|
||||||
def _get_db(self):
|
def _get_db(self) -> None:
|
||||||
"""获取数据库连接"""
|
"""获取数据库连接"""
|
||||||
conn = sqlite3.connect(self.db_path)
|
conn = sqlite3.connect(self.db_path)
|
||||||
conn.row_factory = sqlite3.Row
|
conn.row_factory = sqlite3.Row
|
||||||
@@ -515,7 +515,7 @@ class DeveloperEcosystemManager:
|
|||||||
|
|
||||||
return self.get_sdk_release(sdk_id)
|
return self.get_sdk_release(sdk_id)
|
||||||
|
|
||||||
def increment_sdk_download(self, sdk_id: str):
|
def increment_sdk_download(self, sdk_id: str) -> None:
|
||||||
"""增加 SDK 下载计数"""
|
"""增加 SDK 下载计数"""
|
||||||
with self._get_db() as conn:
|
with self._get_db() as conn:
|
||||||
conn.execute(
|
conn.execute(
|
||||||
@@ -785,7 +785,7 @@ class DeveloperEcosystemManager:
|
|||||||
|
|
||||||
return self.get_template(template_id)
|
return self.get_template(template_id)
|
||||||
|
|
||||||
def increment_template_install(self, template_id: str):
|
def increment_template_install(self, template_id: str) -> None:
|
||||||
"""增加模板安装计数"""
|
"""增加模板安装计数"""
|
||||||
with self._get_db() as conn:
|
with self._get_db() as conn:
|
||||||
conn.execute(
|
conn.execute(
|
||||||
@@ -852,7 +852,7 @@ class DeveloperEcosystemManager:
|
|||||||
|
|
||||||
return review
|
return review
|
||||||
|
|
||||||
def _update_template_rating(self, conn, template_id: str):
|
def _update_template_rating(self, conn, template_id: str) -> None:
|
||||||
"""更新模板评分"""
|
"""更新模板评分"""
|
||||||
row = conn.execute(
|
row = conn.execute(
|
||||||
"""
|
"""
|
||||||
@@ -1084,7 +1084,7 @@ class DeveloperEcosystemManager:
|
|||||||
|
|
||||||
return self.get_plugin(plugin_id)
|
return self.get_plugin(plugin_id)
|
||||||
|
|
||||||
def increment_plugin_install(self, plugin_id: str, active: bool = True):
|
def increment_plugin_install(self, plugin_id: str, active: bool = True) -> None:
|
||||||
"""增加插件安装计数"""
|
"""增加插件安装计数"""
|
||||||
with self._get_db() as conn:
|
with self._get_db() as conn:
|
||||||
conn.execute(
|
conn.execute(
|
||||||
@@ -1160,7 +1160,7 @@ class DeveloperEcosystemManager:
|
|||||||
|
|
||||||
return review
|
return review
|
||||||
|
|
||||||
def _update_plugin_rating(self, conn, plugin_id: str):
|
def _update_plugin_rating(self, conn, plugin_id: str) -> None:
|
||||||
"""更新插件评分"""
|
"""更新插件评分"""
|
||||||
row = conn.execute(
|
row = conn.execute(
|
||||||
"""
|
"""
|
||||||
@@ -1421,7 +1421,7 @@ class DeveloperEcosystemManager:
|
|||||||
|
|
||||||
return self.get_developer_profile(developer_id)
|
return self.get_developer_profile(developer_id)
|
||||||
|
|
||||||
def update_developer_stats(self, developer_id: str):
|
def update_developer_stats(self, developer_id: str) -> None:
|
||||||
"""更新开发者统计信息"""
|
"""更新开发者统计信息"""
|
||||||
with self._get_db() as conn:
|
with self._get_db() as conn:
|
||||||
# 统计插件数量
|
# 统计插件数量
|
||||||
@@ -1574,7 +1574,7 @@ class DeveloperEcosystemManager:
|
|||||||
rows = conn.execute(query, params).fetchall()
|
rows = conn.execute(query, params).fetchall()
|
||||||
return [self._row_to_code_example(row) for row in rows]
|
return [self._row_to_code_example(row) for row in rows]
|
||||||
|
|
||||||
def increment_example_view(self, example_id: str):
|
def increment_example_view(self, example_id: str) -> None:
|
||||||
"""增加代码示例查看计数"""
|
"""增加代码示例查看计数"""
|
||||||
with self._get_db() as conn:
|
with self._get_db() as conn:
|
||||||
conn.execute(
|
conn.execute(
|
||||||
@@ -1587,7 +1587,7 @@ class DeveloperEcosystemManager:
|
|||||||
)
|
)
|
||||||
conn.commit()
|
conn.commit()
|
||||||
|
|
||||||
def increment_example_copy(self, example_id: str):
|
def increment_example_copy(self, example_id: str) -> None:
|
||||||
"""增加代码示例复制计数"""
|
"""增加代码示例复制计数"""
|
||||||
with self._get_db() as conn:
|
with self._get_db() as conn:
|
||||||
conn.execute(
|
conn.execute(
|
||||||
|
|||||||
@@ -329,7 +329,7 @@ class EnterpriseManager:
|
|||||||
conn.row_factory = sqlite3.Row
|
conn.row_factory = sqlite3.Row
|
||||||
return conn
|
return conn
|
||||||
|
|
||||||
def _init_db(self):
|
def _init_db(self) -> None:
|
||||||
"""初始化数据库表"""
|
"""初始化数据库表"""
|
||||||
conn = self._get_connection()
|
conn = self._get_connection()
|
||||||
try:
|
try:
|
||||||
@@ -1190,7 +1190,7 @@ class EnterpriseManager:
|
|||||||
# GET {scim_base_url}/Users
|
# GET {scim_base_url}/Users
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def _upsert_scim_user(self, conn: sqlite3.Connection, tenant_id: str, user_data: dict[str, Any]):
|
def _upsert_scim_user(self, conn: sqlite3.Connection, tenant_id: str, user_data: dict[str, Any]) -> None:
|
||||||
"""插入或更新 SCIM 用户"""
|
"""插入或更新 SCIM 用户"""
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
||||||
|
|||||||
@@ -22,14 +22,7 @@ try:
|
|||||||
from reportlab.lib.pagesizes import A4
|
from reportlab.lib.pagesizes import A4
|
||||||
from reportlab.lib.styles import ParagraphStyle, getSampleStyleSheet
|
from reportlab.lib.styles import ParagraphStyle, getSampleStyleSheet
|
||||||
from reportlab.lib.units import inch
|
from reportlab.lib.units import inch
|
||||||
from reportlab.platypus import (
|
from reportlab.platypus import PageBreak, Paragraph, SimpleDocTemplate, Spacer, Table, TableStyle
|
||||||
PageBreak,
|
|
||||||
Paragraph,
|
|
||||||
SimpleDocTemplate,
|
|
||||||
Spacer,
|
|
||||||
Table,
|
|
||||||
TableStyle,
|
|
||||||
)
|
|
||||||
|
|
||||||
REPORTLAB_AVAILABLE = True
|
REPORTLAB_AVAILABLE = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@@ -194,20 +187,16 @@ class ExportManager:
|
|||||||
f'fill="white" stroke="#bdc3c7" rx="5"/>'
|
f'fill="white" stroke="#bdc3c7" rx="5"/>'
|
||||||
)
|
)
|
||||||
svg_parts.append(
|
svg_parts.append(
|
||||||
f'<text x="{legend_x}" y="{legend_y}" font-size="12" font-weight="bold" '
|
f'<text x="{legend_x}" y="{legend_y}" font-size="12" font-weight="bold" ' f'fill="#2c3e50">实体类型</text>'
|
||||||
f'fill="#2c3e50">实体类型</text>'
|
|
||||||
)
|
)
|
||||||
|
|
||||||
for i, (etype, color) in enumerate(type_colors.items()):
|
for i, (etype, color) in enumerate(type_colors.items()):
|
||||||
if etype != "default":
|
if etype != "default":
|
||||||
y_pos = legend_y + 25 + i * 20
|
y_pos = legend_y + 25 + i * 20
|
||||||
svg_parts.append(
|
svg_parts.append(f'<circle cx="{legend_x + 10}" cy="{y_pos}" r="8" fill="{color}"/>')
|
||||||
f'<circle cx="{legend_x + 10}" cy="{y_pos}" r="8" fill="{color}"/>'
|
|
||||||
)
|
|
||||||
text_y = y_pos + 4
|
text_y = y_pos + 4
|
||||||
svg_parts.append(
|
svg_parts.append(
|
||||||
f'<text x="{legend_x + 25}" y="{text_y}" font-size="10" '
|
f'<text x="{legend_x + 25}" y="{text_y}" font-size="10" ' f'fill="#2c3e50">{etype}</text>'
|
||||||
f'fill="#2c3e50">{etype}</text>'
|
|
||||||
)
|
)
|
||||||
|
|
||||||
svg_parts.append("</svg>")
|
svg_parts.append("</svg>")
|
||||||
@@ -320,7 +309,6 @@ class ExportManager:
|
|||||||
Returns:
|
Returns:
|
||||||
CSV 字符串
|
CSV 字符串
|
||||||
"""
|
"""
|
||||||
import csv
|
|
||||||
|
|
||||||
output = io.StringIO()
|
output = io.StringIO()
|
||||||
writer = csv.writer(output)
|
writer = csv.writer(output)
|
||||||
@@ -586,7 +574,7 @@ class ExportManager:
|
|||||||
_export_manager = None
|
_export_manager = None
|
||||||
|
|
||||||
|
|
||||||
def get_export_manager(db_manager=None):
|
def get_export_manager(db_manager=None) -> None:
|
||||||
"""获取导出管理器实例"""
|
"""获取导出管理器实例"""
|
||||||
global _export_manager
|
global _export_manager
|
||||||
if _export_manager is None:
|
if _export_manager is None:
|
||||||
|
|||||||
@@ -369,7 +369,7 @@ class GrowthManager:
|
|||||||
self.segment_write_key = os.getenv("SEGMENT_WRITE_KEY", "")
|
self.segment_write_key = os.getenv("SEGMENT_WRITE_KEY", "")
|
||||||
self.sendgrid_api_key = os.getenv("SENDGRID_API_KEY", "")
|
self.sendgrid_api_key = os.getenv("SENDGRID_API_KEY", "")
|
||||||
|
|
||||||
def _get_db(self):
|
def _get_db(self) -> None:
|
||||||
"""获取数据库连接"""
|
"""获取数据库连接"""
|
||||||
conn = sqlite3.connect(self.db_path)
|
conn = sqlite3.connect(self.db_path)
|
||||||
conn.row_factory = sqlite3.Row
|
conn.row_factory = sqlite3.Row
|
||||||
|
|||||||
@@ -93,7 +93,7 @@ class ImageProcessor:
|
|||||||
"other": "其他",
|
"other": "其他",
|
||||||
}
|
}
|
||||||
|
|
||||||
def __init__(self, temp_dir: str = None):
|
def __init__(self, temp_dir: str = None) -> None:
|
||||||
"""
|
"""
|
||||||
初始化图片处理器
|
初始化图片处理器
|
||||||
|
|
||||||
@@ -103,7 +103,7 @@ class ImageProcessor:
|
|||||||
self.temp_dir = temp_dir or os.path.join(os.getcwd(), "temp", "images")
|
self.temp_dir = temp_dir or os.path.join(os.getcwd(), "temp", "images")
|
||||||
os.makedirs(self.temp_dir, exist_ok=True)
|
os.makedirs(self.temp_dir, exist_ok=True)
|
||||||
|
|
||||||
def preprocess_image(self, image, image_type: str = None):
|
def preprocess_image(self, image, image_type: str = None) -> None:
|
||||||
"""
|
"""
|
||||||
预处理图片以提高OCR质量
|
预处理图片以提高OCR质量
|
||||||
|
|
||||||
@@ -145,7 +145,7 @@ class ImageProcessor:
|
|||||||
print(f"Image preprocessing error: {e}")
|
print(f"Image preprocessing error: {e}")
|
||||||
return image
|
return image
|
||||||
|
|
||||||
def _enhance_whiteboard(self, image):
|
def _enhance_whiteboard(self, image) -> None:
|
||||||
"""增强白板图片"""
|
"""增强白板图片"""
|
||||||
# 转换为灰度
|
# 转换为灰度
|
||||||
gray = image.convert("L")
|
gray = image.convert("L")
|
||||||
@@ -160,7 +160,7 @@ class ImageProcessor:
|
|||||||
|
|
||||||
return binary.convert("L")
|
return binary.convert("L")
|
||||||
|
|
||||||
def _enhance_handwritten(self, image):
|
def _enhance_handwritten(self, image) -> None:
|
||||||
"""增强手写笔记图片"""
|
"""增强手写笔记图片"""
|
||||||
# 转换为灰度
|
# 转换为灰度
|
||||||
gray = image.convert("L")
|
gray = image.convert("L")
|
||||||
|
|||||||
@@ -163,8 +163,6 @@ class KnowledgeReasoner:
|
|||||||
|
|
||||||
content = await self._call_llm(prompt, temperature=0.3)
|
content = await self._call_llm(prompt, temperature=0.3)
|
||||||
|
|
||||||
import re
|
|
||||||
|
|
||||||
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
|
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
|
||||||
|
|
||||||
if json_match:
|
if json_match:
|
||||||
@@ -217,8 +215,6 @@ class KnowledgeReasoner:
|
|||||||
|
|
||||||
content = await self._call_llm(prompt, temperature=0.3)
|
content = await self._call_llm(prompt, temperature=0.3)
|
||||||
|
|
||||||
import re
|
|
||||||
|
|
||||||
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
|
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
|
||||||
|
|
||||||
if json_match:
|
if json_match:
|
||||||
@@ -271,8 +267,6 @@ class KnowledgeReasoner:
|
|||||||
|
|
||||||
content = await self._call_llm(prompt, temperature=0.3)
|
content = await self._call_llm(prompt, temperature=0.3)
|
||||||
|
|
||||||
import re
|
|
||||||
|
|
||||||
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
|
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
|
||||||
|
|
||||||
if json_match:
|
if json_match:
|
||||||
@@ -325,8 +319,6 @@ class KnowledgeReasoner:
|
|||||||
|
|
||||||
content = await self._call_llm(prompt, temperature=0.4)
|
content = await self._call_llm(prompt, temperature=0.4)
|
||||||
|
|
||||||
import re
|
|
||||||
|
|
||||||
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
|
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
|
||||||
|
|
||||||
if json_match:
|
if json_match:
|
||||||
@@ -474,8 +466,6 @@ class KnowledgeReasoner:
|
|||||||
|
|
||||||
content = await self._call_llm(prompt, temperature=0.3)
|
content = await self._call_llm(prompt, temperature=0.3)
|
||||||
|
|
||||||
import re
|
|
||||||
|
|
||||||
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
|
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
|
||||||
|
|
||||||
if json_match:
|
if json_match:
|
||||||
|
|||||||
@@ -204,15 +204,13 @@ class LLMClient:
|
|||||||
messages = [ChatMessage(role="user", content=prompt)]
|
messages = [ChatMessage(role="user", content=prompt)]
|
||||||
content = await self.chat(messages, temperature=0.1)
|
content = await self.chat(messages, temperature=0.1)
|
||||||
|
|
||||||
import re
|
|
||||||
|
|
||||||
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
|
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
|
||||||
if not json_match:
|
if not json_match:
|
||||||
return {"intent": "unknown", "explanation": "无法解析指令"}
|
return {"intent": "unknown", "explanation": "无法解析指令"}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return json.loads(json_match.group())
|
return json.loads(json_match.group())
|
||||||
except BaseException:
|
except (json.JSONDecodeError, KeyError, TypeError):
|
||||||
return {"intent": "unknown", "explanation": "解析失败"}
|
return {"intent": "unknown", "explanation": "解析失败"}
|
||||||
|
|
||||||
async def analyze_entity_evolution(self, entity_name: str, mentions: list[dict]) -> str:
|
async def analyze_entity_evolution(self, entity_name: str, mentions: list[dict]) -> str:
|
||||||
|
|||||||
@@ -938,9 +938,7 @@ class LocalizationManager:
|
|||||||
finally:
|
finally:
|
||||||
self._close_if_file_db(conn)
|
self._close_if_file_db(conn)
|
||||||
|
|
||||||
def get_translation(
|
def get_translation(self, key: str, language: str, namespace: str = "common", fallback: bool = True) -> str | None:
|
||||||
self, key: str, language: str, namespace: str = "common", fallback: bool = True
|
|
||||||
) -> str | None:
|
|
||||||
conn = self._get_connection()
|
conn = self._get_connection()
|
||||||
try:
|
try:
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|||||||
3741
backend/main.py
3741
backend/main.py
File diff suppressed because it is too large
Load Diff
@@ -80,7 +80,7 @@ class MultimodalEntityLinker:
|
|||||||
# 模态类型
|
# 模态类型
|
||||||
MODALITIES = ["audio", "video", "image", "document"]
|
MODALITIES = ["audio", "video", "image", "document"]
|
||||||
|
|
||||||
def __init__(self, similarity_threshold: float = 0.85):
|
def __init__(self, similarity_threshold: float = 0.85) -> None:
|
||||||
"""
|
"""
|
||||||
初始化多模态实体关联器
|
初始化多模态实体关联器
|
||||||
|
|
||||||
|
|||||||
@@ -94,7 +94,7 @@ class VideoProcessingResult:
|
|||||||
class MultimodalProcessor:
|
class MultimodalProcessor:
|
||||||
"""多模态处理器 - 处理视频文件"""
|
"""多模态处理器 - 处理视频文件"""
|
||||||
|
|
||||||
def __init__(self, temp_dir: str = None, frame_interval: int = 5):
|
def __init__(self, temp_dir: str = None, frame_interval: int = 5) -> None:
|
||||||
"""
|
"""
|
||||||
初始化多模态处理器
|
初始化多模态处理器
|
||||||
|
|
||||||
@@ -401,7 +401,7 @@ class MultimodalProcessor:
|
|||||||
error_message=str(e),
|
error_message=str(e),
|
||||||
)
|
)
|
||||||
|
|
||||||
def cleanup(self, video_id: str = None):
|
def cleanup(self, video_id: str = None) -> None:
|
||||||
"""
|
"""
|
||||||
清理临时文件
|
清理临时文件
|
||||||
|
|
||||||
|
|||||||
@@ -107,7 +107,7 @@ class Neo4jManager:
|
|||||||
|
|
||||||
self._connect()
|
self._connect()
|
||||||
|
|
||||||
def _connect(self):
|
def _connect(self) -> None:
|
||||||
"""建立 Neo4j 连接"""
|
"""建立 Neo4j 连接"""
|
||||||
if not NEO4J_AVAILABLE:
|
if not NEO4J_AVAILABLE:
|
||||||
return
|
return
|
||||||
@@ -121,7 +121,7 @@ class Neo4jManager:
|
|||||||
logger.error(f"Failed to connect to Neo4j: {e}")
|
logger.error(f"Failed to connect to Neo4j: {e}")
|
||||||
self._driver = None
|
self._driver = None
|
||||||
|
|
||||||
def close(self):
|
def close(self) -> None:
|
||||||
"""关闭连接"""
|
"""关闭连接"""
|
||||||
if self._driver:
|
if self._driver:
|
||||||
self._driver.close()
|
self._driver.close()
|
||||||
@@ -137,7 +137,7 @@ class Neo4jManager:
|
|||||||
except BaseException:
|
except BaseException:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def init_schema(self):
|
def init_schema(self) -> None:
|
||||||
"""初始化图数据库 Schema(约束和索引)"""
|
"""初始化图数据库 Schema(约束和索引)"""
|
||||||
if not self._driver:
|
if not self._driver:
|
||||||
logger.error("Neo4j not connected")
|
logger.error("Neo4j not connected")
|
||||||
@@ -178,7 +178,7 @@ class Neo4jManager:
|
|||||||
|
|
||||||
# ==================== 数据同步 ====================
|
# ==================== 数据同步 ====================
|
||||||
|
|
||||||
def sync_project(self, project_id: str, project_name: str, project_description: str = ""):
|
def sync_project(self, project_id: str, project_name: str, project_description: str = "") -> None:
|
||||||
"""同步项目节点到 Neo4j"""
|
"""同步项目节点到 Neo4j"""
|
||||||
if not self._driver:
|
if not self._driver:
|
||||||
return
|
return
|
||||||
@@ -196,7 +196,7 @@ class Neo4jManager:
|
|||||||
description=project_description,
|
description=project_description,
|
||||||
)
|
)
|
||||||
|
|
||||||
def sync_entity(self, entity: GraphEntity):
|
def sync_entity(self, entity: GraphEntity) -> None:
|
||||||
"""同步单个实体到 Neo4j"""
|
"""同步单个实体到 Neo4j"""
|
||||||
if not self._driver:
|
if not self._driver:
|
||||||
return
|
return
|
||||||
@@ -225,7 +225,7 @@ class Neo4jManager:
|
|||||||
properties=json.dumps(entity.properties),
|
properties=json.dumps(entity.properties),
|
||||||
)
|
)
|
||||||
|
|
||||||
def sync_entities_batch(self, entities: list[GraphEntity]):
|
def sync_entities_batch(self, entities: list[GraphEntity]) -> None:
|
||||||
"""批量同步实体到 Neo4j"""
|
"""批量同步实体到 Neo4j"""
|
||||||
if not self._driver or not entities:
|
if not self._driver or not entities:
|
||||||
return
|
return
|
||||||
@@ -262,7 +262,7 @@ class Neo4jManager:
|
|||||||
entities=entities_data,
|
entities=entities_data,
|
||||||
)
|
)
|
||||||
|
|
||||||
def sync_relation(self, relation: GraphRelation):
|
def sync_relation(self, relation: GraphRelation) -> None:
|
||||||
"""同步单个关系到 Neo4j"""
|
"""同步单个关系到 Neo4j"""
|
||||||
if not self._driver:
|
if not self._driver:
|
||||||
return
|
return
|
||||||
@@ -286,7 +286,7 @@ class Neo4jManager:
|
|||||||
properties=json.dumps(relation.properties),
|
properties=json.dumps(relation.properties),
|
||||||
)
|
)
|
||||||
|
|
||||||
def sync_relations_batch(self, relations: list[GraphRelation]):
|
def sync_relations_batch(self, relations: list[GraphRelation]) -> None:
|
||||||
"""批量同步关系到 Neo4j"""
|
"""批量同步关系到 Neo4j"""
|
||||||
if not self._driver or not relations:
|
if not self._driver or not relations:
|
||||||
return
|
return
|
||||||
@@ -318,7 +318,7 @@ class Neo4jManager:
|
|||||||
relations=relations_data,
|
relations=relations_data,
|
||||||
)
|
)
|
||||||
|
|
||||||
def delete_entity(self, entity_id: str):
|
def delete_entity(self, entity_id: str) -> None:
|
||||||
"""从 Neo4j 删除实体及其关系"""
|
"""从 Neo4j 删除实体及其关系"""
|
||||||
if not self._driver:
|
if not self._driver:
|
||||||
return
|
return
|
||||||
@@ -332,7 +332,7 @@ class Neo4jManager:
|
|||||||
id=entity_id,
|
id=entity_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
def delete_project(self, project_id: str):
|
def delete_project(self, project_id: str) -> None:
|
||||||
"""从 Neo4j 删除项目及其所有实体和关系"""
|
"""从 Neo4j 删除项目及其所有实体和关系"""
|
||||||
if not self._driver:
|
if not self._driver:
|
||||||
return
|
return
|
||||||
@@ -949,7 +949,7 @@ def get_neo4j_manager() -> Neo4jManager:
|
|||||||
return _neo4j_manager
|
return _neo4j_manager
|
||||||
|
|
||||||
|
|
||||||
def close_neo4j_manager():
|
def close_neo4j_manager() -> None:
|
||||||
"""关闭 Neo4j 连接"""
|
"""关闭 Neo4j 连接"""
|
||||||
global _neo4j_manager
|
global _neo4j_manager
|
||||||
if _neo4j_manager:
|
if _neo4j_manager:
|
||||||
@@ -958,7 +958,7 @@ def close_neo4j_manager():
|
|||||||
|
|
||||||
|
|
||||||
# 便捷函数
|
# 便捷函数
|
||||||
def sync_project_to_neo4j(project_id: str, project_name: str, entities: list[dict], relations: list[dict]):
|
def sync_project_to_neo4j(project_id: str, project_name: str, entities: list[dict], relations: list[dict]) -> None:
|
||||||
"""
|
"""
|
||||||
同步整个项目到 Neo4j
|
同步整个项目到 Neo4j
|
||||||
|
|
||||||
|
|||||||
@@ -456,13 +456,13 @@ class OpsManager:
|
|||||||
self._evaluator_thread = None
|
self._evaluator_thread = None
|
||||||
self._register_default_evaluators()
|
self._register_default_evaluators()
|
||||||
|
|
||||||
def _get_db(self):
|
def _get_db(self) -> None:
|
||||||
"""获取数据库连接"""
|
"""获取数据库连接"""
|
||||||
conn = sqlite3.connect(self.db_path)
|
conn = sqlite3.connect(self.db_path)
|
||||||
conn.row_factory = sqlite3.Row
|
conn.row_factory = sqlite3.Row
|
||||||
return conn
|
return conn
|
||||||
|
|
||||||
def _register_default_evaluators(self):
|
def _register_default_evaluators(self) -> None:
|
||||||
"""注册默认的告警评估器"""
|
"""注册默认的告警评估器"""
|
||||||
self._alert_evaluators[AlertRuleType.THRESHOLD.value] = self._evaluate_threshold_rule
|
self._alert_evaluators[AlertRuleType.THRESHOLD.value] = self._evaluate_threshold_rule
|
||||||
self._alert_evaluators[AlertRuleType.ANOMALY.value] = self._evaluate_anomaly_rule
|
self._alert_evaluators[AlertRuleType.ANOMALY.value] = self._evaluate_anomaly_rule
|
||||||
@@ -1249,7 +1249,7 @@ class OpsManager:
|
|||||||
|
|
||||||
return self.get_alert(alert_id)
|
return self.get_alert(alert_id)
|
||||||
|
|
||||||
def _increment_suppression_count(self, alert_id: str):
|
def _increment_suppression_count(self, alert_id: str) -> None:
|
||||||
"""增加告警抑制计数"""
|
"""增加告警抑制计数"""
|
||||||
with self._get_db() as conn:
|
with self._get_db() as conn:
|
||||||
conn.execute(
|
conn.execute(
|
||||||
@@ -1262,7 +1262,7 @@ class OpsManager:
|
|||||||
)
|
)
|
||||||
conn.commit()
|
conn.commit()
|
||||||
|
|
||||||
def _update_alert_notification_status(self, alert_id: str, channel_id: str, success: bool):
|
def _update_alert_notification_status(self, alert_id: str, channel_id: str, success: bool) -> None:
|
||||||
"""更新告警通知状态"""
|
"""更新告警通知状态"""
|
||||||
with self._get_db() as conn:
|
with self._get_db() as conn:
|
||||||
row = conn.execute("SELECT notification_sent FROM alerts WHERE id = ?", (alert_id,)).fetchone()
|
row = conn.execute("SELECT notification_sent FROM alerts WHERE id = ?", (alert_id,)).fetchone()
|
||||||
@@ -1276,7 +1276,7 @@ class OpsManager:
|
|||||||
)
|
)
|
||||||
conn.commit()
|
conn.commit()
|
||||||
|
|
||||||
def _update_channel_stats(self, channel_id: str, success: bool):
|
def _update_channel_stats(self, channel_id: str, success: bool) -> None:
|
||||||
"""更新渠道统计"""
|
"""更新渠道统计"""
|
||||||
now = datetime.now().isoformat()
|
now = datetime.now().isoformat()
|
||||||
|
|
||||||
@@ -1769,9 +1769,7 @@ class OpsManager:
|
|||||||
return self._row_to_scaling_event(row)
|
return self._row_to_scaling_event(row)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def update_scaling_event_status(
|
def update_scaling_event_status(self, event_id: str, status: str, error_message: str = None) -> ScalingEvent | None:
|
||||||
self, event_id: str, status: str, error_message: str = None
|
|
||||||
) -> ScalingEvent | None:
|
|
||||||
"""更新扩缩容事件状态"""
|
"""更新扩缩容事件状态"""
|
||||||
now = datetime.now().isoformat()
|
now = datetime.now().isoformat()
|
||||||
|
|
||||||
@@ -2339,7 +2337,7 @@ class OpsManager:
|
|||||||
|
|
||||||
return record
|
return record
|
||||||
|
|
||||||
def _complete_backup(self, record_id: str, size_bytes: int, checksum: str = None):
|
def _complete_backup(self, record_id: str, size_bytes: int, checksum: str = None) -> None:
|
||||||
"""完成备份"""
|
"""完成备份"""
|
||||||
now = datetime.now().isoformat()
|
now = datetime.now().isoformat()
|
||||||
checksum = checksum or hashlib.sha256(str(time.time()).encode()).hexdigest()[:16]
|
checksum = checksum or hashlib.sha256(str(time.time()).encode()).hexdigest()[:16]
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ class OSSUploader:
|
|||||||
url = self.bucket.sign_url("GET", object_name, 3600)
|
url = self.bucket.sign_url("GET", object_name, 3600)
|
||||||
return url, object_name
|
return url, object_name
|
||||||
|
|
||||||
def delete_object(self, object_name: str):
|
def delete_object(self, object_name: str) -> None:
|
||||||
"""删除 OSS 对象"""
|
"""删除 OSS 对象"""
|
||||||
self.bucket.delete_object(object_name)
|
self.bucket.delete_object(object_name)
|
||||||
|
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ class CacheStats:
|
|||||||
expired: int = 0
|
expired: int = 0
|
||||||
hit_rate: float = 0.0
|
hit_rate: float = 0.0
|
||||||
|
|
||||||
def update_hit_rate(self):
|
def update_hit_rate(self) -> None:
|
||||||
"""更新命中率"""
|
"""更新命中率"""
|
||||||
if self.total_requests > 0:
|
if self.total_requests > 0:
|
||||||
self.hit_rate = round(self.hits / self.total_requests, 4)
|
self.hit_rate = round(self.hits / self.total_requests, 4)
|
||||||
@@ -194,7 +194,7 @@ class CacheManager:
|
|||||||
# 初始化缓存统计表
|
# 初始化缓存统计表
|
||||||
self._init_cache_tables()
|
self._init_cache_tables()
|
||||||
|
|
||||||
def _init_cache_tables(self):
|
def _init_cache_tables(self) -> None:
|
||||||
"""初始化缓存统计表"""
|
"""初始化缓存统计表"""
|
||||||
conn = sqlite3.connect(self.db_path)
|
conn = sqlite3.connect(self.db_path)
|
||||||
|
|
||||||
@@ -234,7 +234,7 @@ class CacheManager:
|
|||||||
except BaseException:
|
except BaseException:
|
||||||
return 1024 # 默认估算
|
return 1024 # 默认估算
|
||||||
|
|
||||||
def _evict_lru(self, required_space: int = 0):
|
def _evict_lru(self, required_space: int = 0) -> None:
|
||||||
"""LRU 淘汰策略"""
|
"""LRU 淘汰策略"""
|
||||||
with self.cache_lock:
|
with self.cache_lock:
|
||||||
while self.current_memory_size + required_space > self.max_memory_size and self.memory_cache:
|
while self.current_memory_size + required_space > self.max_memory_size and self.memory_cache:
|
||||||
@@ -444,7 +444,7 @@ class CacheManager:
|
|||||||
|
|
||||||
return stats
|
return stats
|
||||||
|
|
||||||
def save_stats(self):
|
def save_stats(self) -> None:
|
||||||
"""保存缓存统计到数据库"""
|
"""保存缓存统计到数据库"""
|
||||||
conn = sqlite3.connect(self.db_path)
|
conn = sqlite3.connect(self.db_path)
|
||||||
|
|
||||||
@@ -618,7 +618,7 @@ class DatabaseSharding:
|
|||||||
# 初始化分片
|
# 初始化分片
|
||||||
self._init_shards()
|
self._init_shards()
|
||||||
|
|
||||||
def _init_shards(self):
|
def _init_shards(self) -> None:
|
||||||
"""初始化分片"""
|
"""初始化分片"""
|
||||||
# 计算每个分片的 key 范围
|
# 计算每个分片的 key 范围
|
||||||
chars = "0123456789abcdef"
|
chars = "0123456789abcdef"
|
||||||
@@ -645,7 +645,7 @@ class DatabaseSharding:
|
|||||||
if not os.path.exists(db_path):
|
if not os.path.exists(db_path):
|
||||||
self._create_shard_db(db_path)
|
self._create_shard_db(db_path)
|
||||||
|
|
||||||
def _create_shard_db(self, db_path: str):
|
def _create_shard_db(self, db_path: str) -> None:
|
||||||
"""创建分片数据库"""
|
"""创建分片数据库"""
|
||||||
conn = sqlite3.connect(db_path)
|
conn = sqlite3.connect(db_path)
|
||||||
|
|
||||||
@@ -792,7 +792,7 @@ class DatabaseSharding:
|
|||||||
print(f"迁移失败: {e}")
|
print(f"迁移失败: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _update_shard_stats(self, shard_id: str):
|
def _update_shard_stats(self, shard_id: str) -> None:
|
||||||
"""更新分片统计"""
|
"""更新分片统计"""
|
||||||
shard_info = self.shard_map.get(shard_id)
|
shard_info = self.shard_map.get(shard_id)
|
||||||
if not shard_info:
|
if not shard_info:
|
||||||
@@ -923,7 +923,7 @@ class TaskQueue:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Celery 初始化失败,使用内存任务队列: {e}")
|
print(f"Celery 初始化失败,使用内存任务队列: {e}")
|
||||||
|
|
||||||
def _init_task_tables(self):
|
def _init_task_tables(self) -> None:
|
||||||
"""初始化任务队列表"""
|
"""初始化任务队列表"""
|
||||||
conn = sqlite3.connect(self.db_path)
|
conn = sqlite3.connect(self.db_path)
|
||||||
|
|
||||||
@@ -953,7 +953,7 @@ class TaskQueue:
|
|||||||
"""检查任务队列是否可用"""
|
"""检查任务队列是否可用"""
|
||||||
return self.use_celery or True # 内存模式也可用
|
return self.use_celery or True # 内存模式也可用
|
||||||
|
|
||||||
def register_handler(self, task_type: str, handler: Callable):
|
def register_handler(self, task_type: str, handler: Callable) -> None:
|
||||||
"""注册任务处理器"""
|
"""注册任务处理器"""
|
||||||
self.task_handlers[task_type] = handler
|
self.task_handlers[task_type] = handler
|
||||||
|
|
||||||
@@ -1014,7 +1014,7 @@ class TaskQueue:
|
|||||||
|
|
||||||
return task_id
|
return task_id
|
||||||
|
|
||||||
def _execute_task(self, task_id: str):
|
def _execute_task(self, task_id: str) -> None:
|
||||||
"""执行任务(内存模式)"""
|
"""执行任务(内存模式)"""
|
||||||
with self.task_lock:
|
with self.task_lock:
|
||||||
task = self.tasks.get(task_id)
|
task = self.tasks.get(task_id)
|
||||||
@@ -1055,7 +1055,7 @@ class TaskQueue:
|
|||||||
|
|
||||||
self._update_task_status(task)
|
self._update_task_status(task)
|
||||||
|
|
||||||
def _save_task(self, task: TaskInfo):
|
def _save_task(self, task: TaskInfo) -> None:
|
||||||
"""保存任务到数据库"""
|
"""保存任务到数据库"""
|
||||||
conn = sqlite3.connect(self.db_path)
|
conn = sqlite3.connect(self.db_path)
|
||||||
|
|
||||||
@@ -1084,7 +1084,7 @@ class TaskQueue:
|
|||||||
conn.commit()
|
conn.commit()
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
def _update_task_status(self, task: TaskInfo):
|
def _update_task_status(self, task: TaskInfo) -> None:
|
||||||
"""更新任务状态"""
|
"""更新任务状态"""
|
||||||
conn = sqlite3.connect(self.db_path)
|
conn = sqlite3.connect(self.db_path)
|
||||||
|
|
||||||
@@ -1143,9 +1143,7 @@ class TaskQueue:
|
|||||||
with self.task_lock:
|
with self.task_lock:
|
||||||
return self.tasks.get(task_id)
|
return self.tasks.get(task_id)
|
||||||
|
|
||||||
def list_tasks(
|
def list_tasks(self, status: str | None = None, task_type: str | None = None, limit: int = 100) -> list[TaskInfo]:
|
||||||
self, status: str | None = None, task_type: str | None = None, limit: int = 100
|
|
||||||
) -> list[TaskInfo]:
|
|
||||||
"""列出任务"""
|
"""列出任务"""
|
||||||
conn = sqlite3.connect(self.db_path)
|
conn = sqlite3.connect(self.db_path)
|
||||||
conn.row_factory = sqlite3.Row
|
conn.row_factory = sqlite3.Row
|
||||||
@@ -1333,7 +1331,7 @@ class PerformanceMonitor:
|
|||||||
if metric_type == "db_query" and duration_ms > self.slow_query_threshold:
|
if metric_type == "db_query" and duration_ms > self.slow_query_threshold:
|
||||||
self._record_slow_query(metric)
|
self._record_slow_query(metric)
|
||||||
|
|
||||||
def _flush_metrics(self):
|
def _flush_metrics(self) -> None:
|
||||||
"""将缓冲区指标写入数据库"""
|
"""将缓冲区指标写入数据库"""
|
||||||
if not self.metrics_buffer:
|
if not self.metrics_buffer:
|
||||||
return
|
return
|
||||||
@@ -1362,12 +1360,12 @@ class PerformanceMonitor:
|
|||||||
|
|
||||||
self.metrics_buffer = []
|
self.metrics_buffer = []
|
||||||
|
|
||||||
def _record_slow_query(self, metric: PerformanceMetric):
|
def _record_slow_query(self, metric: PerformanceMetric) -> None:
|
||||||
"""记录慢查询"""
|
"""记录慢查询"""
|
||||||
# 可以发送到专门的慢查询日志或监控系统
|
# 可以发送到专门的慢查询日志或监控系统
|
||||||
print(f"[SLOW QUERY] {metric.endpoint}: {metric.duration_ms}ms")
|
print(f"[SLOW QUERY] {metric.endpoint}: {metric.duration_ms}ms")
|
||||||
|
|
||||||
def _trigger_alert(self, metric: PerformanceMetric):
|
def _trigger_alert(self, metric: PerformanceMetric) -> None:
|
||||||
"""触发告警"""
|
"""触发告警"""
|
||||||
alert_data = {
|
alert_data = {
|
||||||
"type": "performance_alert",
|
"type": "performance_alert",
|
||||||
@@ -1382,7 +1380,7 @@ class PerformanceMonitor:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"告警处理失败: {e}")
|
print(f"告警处理失败: {e}")
|
||||||
|
|
||||||
def register_alert_handler(self, handler: Callable):
|
def register_alert_handler(self, handler: Callable) -> None:
|
||||||
"""注册告警处理器"""
|
"""注册告警处理器"""
|
||||||
self.alert_handlers.append(handler)
|
self.alert_handlers.append(handler)
|
||||||
|
|
||||||
@@ -1585,7 +1583,9 @@ class PerformanceMonitor:
|
|||||||
# ==================== 性能装饰器 ====================
|
# ==================== 性能装饰器 ====================
|
||||||
|
|
||||||
|
|
||||||
def cached(cache_manager: CacheManager, key_prefix: str = "", ttl: int = 3600, key_func: Callable | None = None):
|
def cached(
|
||||||
|
cache_manager: CacheManager, key_prefix: str = "", ttl: int = 3600, key_func: Callable | None = None
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
缓存装饰器
|
缓存装饰器
|
||||||
|
|
||||||
@@ -1598,7 +1598,7 @@ def cached(cache_manager: CacheManager, key_prefix: str = "", ttl: int = 3600, k
|
|||||||
|
|
||||||
def decorator(func: Callable) -> Callable:
|
def decorator(func: Callable) -> Callable:
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
def wrapper(*args, **kwargs):
|
def wrapper(*args, **kwargs) -> None:
|
||||||
# 生成缓存键
|
# 生成缓存键
|
||||||
if key_func:
|
if key_func:
|
||||||
cache_key = key_func(*args, **kwargs)
|
cache_key = key_func(*args, **kwargs)
|
||||||
@@ -1625,7 +1625,7 @@ def cached(cache_manager: CacheManager, key_prefix: str = "", ttl: int = 3600, k
|
|||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
def monitored(monitor: PerformanceMonitor, metric_type: str, endpoint: str | None = None):
|
def monitored(monitor: PerformanceMonitor, metric_type: str, endpoint: str | None = None) -> None:
|
||||||
"""
|
"""
|
||||||
性能监控装饰器
|
性能监控装饰器
|
||||||
|
|
||||||
|
|||||||
@@ -163,7 +163,7 @@ class PluginManager:
|
|||||||
self._handlers = {}
|
self._handlers = {}
|
||||||
self._register_default_handlers()
|
self._register_default_handlers()
|
||||||
|
|
||||||
def _register_default_handlers(self):
|
def _register_default_handlers(self) -> None:
|
||||||
"""注册默认处理器"""
|
"""注册默认处理器"""
|
||||||
self._handlers[PluginType.CHROME_EXTENSION] = ChromeExtensionHandler(self)
|
self._handlers[PluginType.CHROME_EXTENSION] = ChromeExtensionHandler(self)
|
||||||
self._handlers[PluginType.FEISHU_BOT] = BotHandler(self, "feishu")
|
self._handlers[PluginType.FEISHU_BOT] = BotHandler(self, "feishu")
|
||||||
@@ -371,7 +371,7 @@ class PluginManager:
|
|||||||
|
|
||||||
return cursor.rowcount > 0
|
return cursor.rowcount > 0
|
||||||
|
|
||||||
def record_plugin_usage(self, plugin_id: str):
|
def record_plugin_usage(self, plugin_id: str) -> None:
|
||||||
"""记录插件使用"""
|
"""记录插件使用"""
|
||||||
conn = self.db.get_conn()
|
conn = self.db.get_conn()
|
||||||
now = datetime.now().isoformat()
|
now = datetime.now().isoformat()
|
||||||
@@ -826,9 +826,6 @@ class BotHandler:
|
|||||||
|
|
||||||
async def _send_feishu_message(self, session: BotSession, message: str, msg_type: str) -> bool:
|
async def _send_feishu_message(self, session: BotSession, message: str, msg_type: str) -> bool:
|
||||||
"""发送飞书消息"""
|
"""发送飞书消息"""
|
||||||
import base64
|
|
||||||
import hashlib
|
|
||||||
|
|
||||||
timestamp = str(int(time.time()))
|
timestamp = str(int(time.time()))
|
||||||
|
|
||||||
# 生成签名
|
# 生成签名
|
||||||
@@ -851,9 +848,6 @@ class BotHandler:
|
|||||||
|
|
||||||
async def _send_dingtalk_message(self, session: BotSession, message: str, msg_type: str) -> bool:
|
async def _send_dingtalk_message(self, session: BotSession, message: str, msg_type: str) -> bool:
|
||||||
"""发送钉钉消息"""
|
"""发送钉钉消息"""
|
||||||
import base64
|
|
||||||
import hashlib
|
|
||||||
|
|
||||||
timestamp = str(round(time.time() * 1000))
|
timestamp = str(round(time.time() * 1000))
|
||||||
|
|
||||||
# 生成签名
|
# 生成签名
|
||||||
@@ -1358,7 +1352,7 @@ class WebDAVSyncManager:
|
|||||||
_plugin_manager = None
|
_plugin_manager = None
|
||||||
|
|
||||||
|
|
||||||
def get_plugin_manager(db_manager=None):
|
def get_plugin_manager(db_manager=None) -> None:
|
||||||
"""获取 PluginManager 单例"""
|
"""获取 PluginManager 单例"""
|
||||||
global _plugin_manager
|
global _plugin_manager
|
||||||
if _plugin_manager is None:
|
if _plugin_manager is None:
|
||||||
|
|||||||
@@ -56,7 +56,7 @@ class SlidingWindowCounter:
|
|||||||
self._cleanup_old(now)
|
self._cleanup_old(now)
|
||||||
return sum(self.requests.values())
|
return sum(self.requests.values())
|
||||||
|
|
||||||
def _cleanup_old(self, now: int):
|
def _cleanup_old(self, now: int) -> None:
|
||||||
"""清理过期的请求记录 - 使用独立锁避免竞态条件"""
|
"""清理过期的请求记录 - 使用独立锁避免竞态条件"""
|
||||||
cutoff = now - self.window_size
|
cutoff = now - self.window_size
|
||||||
old_keys = [k for k in list(self.requests.keys()) if k < cutoff]
|
old_keys = [k for k in list(self.requests.keys()) if k < cutoff]
|
||||||
@@ -67,7 +67,7 @@ class SlidingWindowCounter:
|
|||||||
class RateLimiter:
|
class RateLimiter:
|
||||||
"""API 限流器"""
|
"""API 限流器"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
# key -> SlidingWindowCounter
|
# key -> SlidingWindowCounter
|
||||||
self.counters: dict[str, SlidingWindowCounter] = {}
|
self.counters: dict[str, SlidingWindowCounter] = {}
|
||||||
# key -> RateLimitConfig
|
# key -> RateLimitConfig
|
||||||
@@ -143,7 +143,7 @@ class RateLimiter:
|
|||||||
retry_after=max(0, config.window_size) if current_count >= config.requests_per_minute else 0,
|
retry_after=max(0, config.window_size) if current_count >= config.requests_per_minute else 0,
|
||||||
)
|
)
|
||||||
|
|
||||||
def reset(self, key: str | None = None):
|
def reset(self, key: str | None = None) -> None:
|
||||||
"""重置限流计数器"""
|
"""重置限流计数器"""
|
||||||
if key:
|
if key:
|
||||||
self.counters.pop(key, None)
|
self.counters.pop(key, None)
|
||||||
@@ -166,7 +166,7 @@ def get_rate_limiter() -> RateLimiter:
|
|||||||
|
|
||||||
|
|
||||||
# 限流装饰器(用于函数级别限流)
|
# 限流装饰器(用于函数级别限流)
|
||||||
def rate_limit(requests_per_minute: int = 60, key_func: Callable | None = None):
|
def rate_limit(requests_per_minute: int = 60, key_func: Callable | None = None) -> None:
|
||||||
"""
|
"""
|
||||||
限流装饰器
|
限流装饰器
|
||||||
|
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -19,6 +19,7 @@ try:
|
|||||||
from cryptography.fernet import Fernet
|
from cryptography.fernet import Fernet
|
||||||
from cryptography.hazmat.primitives import hashes
|
from cryptography.hazmat.primitives import hashes
|
||||||
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
|
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
|
||||||
|
|
||||||
CRYPTO_AVAILABLE = True
|
CRYPTO_AVAILABLE = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
CRYPTO_AVAILABLE = False
|
CRYPTO_AVAILABLE = False
|
||||||
@@ -27,6 +28,7 @@ except ImportError:
|
|||||||
|
|
||||||
class AuditActionType(Enum):
|
class AuditActionType(Enum):
|
||||||
"""审计动作类型"""
|
"""审计动作类型"""
|
||||||
|
|
||||||
CREATE = "create"
|
CREATE = "create"
|
||||||
READ = "read"
|
READ = "read"
|
||||||
UPDATE = "update"
|
UPDATE = "update"
|
||||||
@@ -49,26 +51,29 @@ class AuditActionType(Enum):
|
|||||||
|
|
||||||
class DataSensitivityLevel(Enum):
|
class DataSensitivityLevel(Enum):
|
||||||
"""数据敏感度级别"""
|
"""数据敏感度级别"""
|
||||||
PUBLIC = "public" # 公开
|
|
||||||
INTERNAL = "internal" # 内部
|
PUBLIC = "public" # 公开
|
||||||
|
INTERNAL = "internal" # 内部
|
||||||
CONFIDENTIAL = "confidential" # 机密
|
CONFIDENTIAL = "confidential" # 机密
|
||||||
SECRET = "secret" # 绝密
|
SECRET = "secret" # 绝密
|
||||||
|
|
||||||
|
|
||||||
class MaskingRuleType(Enum):
|
class MaskingRuleType(Enum):
|
||||||
"""脱敏规则类型"""
|
"""脱敏规则类型"""
|
||||||
PHONE = "phone" # 手机号
|
|
||||||
EMAIL = "email" # 邮箱
|
PHONE = "phone" # 手机号
|
||||||
ID_CARD = "id_card" # 身份证号
|
EMAIL = "email" # 邮箱
|
||||||
BANK_CARD = "bank_card" # 银行卡号
|
ID_CARD = "id_card" # 身份证号
|
||||||
NAME = "name" # 姓名
|
BANK_CARD = "bank_card" # 银行卡号
|
||||||
ADDRESS = "address" # 地址
|
NAME = "name" # 姓名
|
||||||
CUSTOM = "custom" # 自定义
|
ADDRESS = "address" # 地址
|
||||||
|
CUSTOM = "custom" # 自定义
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AuditLog:
|
class AuditLog:
|
||||||
"""审计日志条目"""
|
"""审计日志条目"""
|
||||||
|
|
||||||
id: str
|
id: str
|
||||||
action_type: str
|
action_type: str
|
||||||
user_id: str | None = None
|
user_id: str | None = None
|
||||||
@@ -90,6 +95,7 @@ class AuditLog:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class EncryptionConfig:
|
class EncryptionConfig:
|
||||||
"""加密配置"""
|
"""加密配置"""
|
||||||
|
|
||||||
id: str
|
id: str
|
||||||
project_id: str
|
project_id: str
|
||||||
is_enabled: bool = False
|
is_enabled: bool = False
|
||||||
@@ -107,6 +113,7 @@ class EncryptionConfig:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class MaskingRule:
|
class MaskingRule:
|
||||||
"""脱敏规则"""
|
"""脱敏规则"""
|
||||||
|
|
||||||
id: str
|
id: str
|
||||||
project_id: str
|
project_id: str
|
||||||
name: str
|
name: str
|
||||||
@@ -126,6 +133,7 @@ class MaskingRule:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class DataAccessPolicy:
|
class DataAccessPolicy:
|
||||||
"""数据访问策略"""
|
"""数据访问策略"""
|
||||||
|
|
||||||
id: str
|
id: str
|
||||||
project_id: str
|
project_id: str
|
||||||
name: str
|
name: str
|
||||||
@@ -147,6 +155,7 @@ class DataAccessPolicy:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class AccessRequest:
|
class AccessRequest:
|
||||||
"""访问请求(用于需要审批的访问)"""
|
"""访问请求(用于需要审批的访问)"""
|
||||||
|
|
||||||
id: str
|
id: str
|
||||||
policy_id: str
|
policy_id: str
|
||||||
user_id: str
|
user_id: str
|
||||||
@@ -166,30 +175,15 @@ class SecurityManager:
|
|||||||
|
|
||||||
# 预定义脱敏规则
|
# 预定义脱敏规则
|
||||||
DEFAULT_MASKING_RULES = {
|
DEFAULT_MASKING_RULES = {
|
||||||
MaskingRuleType.PHONE: {
|
MaskingRuleType.PHONE: {"pattern": r"(\d{3})\d{4}(\d{4})", "replacement": r"\1****\2"},
|
||||||
"pattern": r"(\d{3})\d{4}(\d{4})",
|
MaskingRuleType.EMAIL: {"pattern": r"(\w{1,3})\w+(@\w+\.\w+)", "replacement": r"\1***\2"},
|
||||||
"replacement": r"\1****\2"
|
MaskingRuleType.ID_CARD: {"pattern": r"(\d{6})\d{8}(\d{4})", "replacement": r"\1********\2"},
|
||||||
},
|
MaskingRuleType.BANK_CARD: {"pattern": r"(\d{4})\d+(\d{4})", "replacement": r"\1 **** **** \2"},
|
||||||
MaskingRuleType.EMAIL: {
|
MaskingRuleType.NAME: {"pattern": r"([\u4e00-\u9fa5])[\u4e00-\u9fa5]+", "replacement": r"\1**"},
|
||||||
"pattern": r"(\w{1,3})\w+(@\w+\.\w+)",
|
|
||||||
"replacement": r"\1***\2"
|
|
||||||
},
|
|
||||||
MaskingRuleType.ID_CARD: {
|
|
||||||
"pattern": r"(\d{6})\d{8}(\d{4})",
|
|
||||||
"replacement": r"\1********\2"
|
|
||||||
},
|
|
||||||
MaskingRuleType.BANK_CARD: {
|
|
||||||
"pattern": r"(\d{4})\d+(\d{4})",
|
|
||||||
"replacement": r"\1 **** **** \2"
|
|
||||||
},
|
|
||||||
MaskingRuleType.NAME: {
|
|
||||||
"pattern": r"([\u4e00-\u9fa5])[\u4e00-\u9fa5]+",
|
|
||||||
"replacement": r"\1**"
|
|
||||||
},
|
|
||||||
MaskingRuleType.ADDRESS: {
|
MaskingRuleType.ADDRESS: {
|
||||||
"pattern": r"([\u4e00-\u9fa5]{2,})([\u4e00-\u9fa5]+路|街|巷|号)(.+)",
|
"pattern": r"([\u4e00-\u9fa5]{2,})([\u4e00-\u9fa5]+路|街|巷|号)(.+)",
|
||||||
"replacement": r"\1\2***"
|
"replacement": r"\1\2***",
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def __init__(self, db_path: str = "insightflow.db"):
|
def __init__(self, db_path: str = "insightflow.db"):
|
||||||
@@ -200,7 +194,7 @@ class SecurityManager:
|
|||||||
self._local = {}
|
self._local = {}
|
||||||
self._init_db()
|
self._init_db()
|
||||||
|
|
||||||
def _init_db(self):
|
def _init_db(self) -> None:
|
||||||
"""初始化数据库表"""
|
"""初始化数据库表"""
|
||||||
conn = sqlite3.connect(self.db_path)
|
conn = sqlite3.connect(self.db_path)
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
@@ -308,9 +302,7 @@ class SecurityManager:
|
|||||||
|
|
||||||
def _generate_id(self) -> str:
|
def _generate_id(self) -> str:
|
||||||
"""生成唯一ID"""
|
"""生成唯一ID"""
|
||||||
return hashlib.sha256(
|
return hashlib.sha256(f"{datetime.now().isoformat()}{secrets.token_hex(16)}".encode()).hexdigest()[:32]
|
||||||
f"{datetime.now().isoformat()}{secrets.token_hex(16)}".encode()
|
|
||||||
).hexdigest()[:32]
|
|
||||||
|
|
||||||
# ==================== 审计日志 ====================
|
# ==================== 审计日志 ====================
|
||||||
|
|
||||||
@@ -326,7 +318,7 @@ class SecurityManager:
|
|||||||
before_value: str | None = None,
|
before_value: str | None = None,
|
||||||
after_value: str | None = None,
|
after_value: str | None = None,
|
||||||
success: bool = True,
|
success: bool = True,
|
||||||
error_message: str | None = None
|
error_message: str | None = None,
|
||||||
) -> AuditLog:
|
) -> AuditLog:
|
||||||
"""记录审计日志"""
|
"""记录审计日志"""
|
||||||
log = AuditLog(
|
log = AuditLog(
|
||||||
@@ -341,22 +333,34 @@ class SecurityManager:
|
|||||||
before_value=before_value,
|
before_value=before_value,
|
||||||
after_value=after_value,
|
after_value=after_value,
|
||||||
success=success,
|
success=success,
|
||||||
error_message=error_message
|
error_message=error_message,
|
||||||
)
|
)
|
||||||
|
|
||||||
conn = sqlite3.connect(self.db_path)
|
conn = sqlite3.connect(self.db_path)
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
cursor.execute("""
|
cursor.execute(
|
||||||
|
"""
|
||||||
INSERT INTO audit_logs
|
INSERT INTO audit_logs
|
||||||
(id, action_type, user_id, user_ip, user_agent, resource_type, resource_id,
|
(id, action_type, user_id, user_ip, user_agent, resource_type, resource_id,
|
||||||
action_details, before_value, after_value, success, error_message, created_at)
|
action_details, before_value, after_value, success, error_message, created_at)
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
""", (
|
""",
|
||||||
log.id, log.action_type, log.user_id, log.user_ip, log.user_agent,
|
(
|
||||||
log.resource_type, log.resource_id, log.action_details,
|
log.id,
|
||||||
log.before_value, log.after_value, int(log.success),
|
log.action_type,
|
||||||
log.error_message, log.created_at
|
log.user_id,
|
||||||
))
|
log.user_ip,
|
||||||
|
log.user_agent,
|
||||||
|
log.resource_type,
|
||||||
|
log.resource_id,
|
||||||
|
log.action_details,
|
||||||
|
log.before_value,
|
||||||
|
log.after_value,
|
||||||
|
int(log.success),
|
||||||
|
log.error_message,
|
||||||
|
log.created_at,
|
||||||
|
),
|
||||||
|
)
|
||||||
conn.commit()
|
conn.commit()
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
@@ -372,7 +376,7 @@ class SecurityManager:
|
|||||||
end_time: str | None = None,
|
end_time: str | None = None,
|
||||||
success: bool | None = None,
|
success: bool | None = None,
|
||||||
limit: int = 100,
|
limit: int = 100,
|
||||||
offset: int = 0
|
offset: int = 0,
|
||||||
) -> list[AuditLog]:
|
) -> list[AuditLog]:
|
||||||
"""查询审计日志"""
|
"""查询审计日志"""
|
||||||
conn = sqlite3.connect(self.db_path)
|
conn = sqlite3.connect(self.db_path)
|
||||||
@@ -429,18 +433,14 @@ class SecurityManager:
|
|||||||
after_value=row[9],
|
after_value=row[9],
|
||||||
success=bool(row[10]),
|
success=bool(row[10]),
|
||||||
error_message=row[11],
|
error_message=row[11],
|
||||||
created_at=row[12]
|
created_at=row[12],
|
||||||
)
|
)
|
||||||
logs.append(log)
|
logs.append(log)
|
||||||
|
|
||||||
conn.close()
|
conn.close()
|
||||||
return logs
|
return logs
|
||||||
|
|
||||||
def get_audit_stats(
|
def get_audit_stats(self, start_time: str | None = None, end_time: str | None = None) -> dict[str, Any]:
|
||||||
self,
|
|
||||||
start_time: str | None = None,
|
|
||||||
end_time: str | None = None
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
"""获取审计统计"""
|
"""获取审计统计"""
|
||||||
conn = sqlite3.connect(self.db_path)
|
conn = sqlite3.connect(self.db_path)
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
@@ -460,12 +460,7 @@ class SecurityManager:
|
|||||||
cursor.execute(query, params)
|
cursor.execute(query, params)
|
||||||
rows = cursor.fetchall()
|
rows = cursor.fetchall()
|
||||||
|
|
||||||
stats = {
|
stats = {"total_actions": 0, "success_count": 0, "failure_count": 0, "action_breakdown": {}}
|
||||||
"total_actions": 0,
|
|
||||||
"success_count": 0,
|
|
||||||
"failure_count": 0,
|
|
||||||
"action_breakdown": {}
|
|
||||||
}
|
|
||||||
|
|
||||||
for action_type, success, count in rows:
|
for action_type, success, count in rows:
|
||||||
stats["total_actions"] += count
|
stats["total_actions"] += count
|
||||||
@@ -500,11 +495,7 @@ class SecurityManager:
|
|||||||
)
|
)
|
||||||
return base64.urlsafe_b64encode(kdf.derive(password.encode()))
|
return base64.urlsafe_b64encode(kdf.derive(password.encode()))
|
||||||
|
|
||||||
def enable_encryption(
|
def enable_encryption(self, project_id: str, master_password: str) -> EncryptionConfig:
|
||||||
self,
|
|
||||||
project_id: str,
|
|
||||||
master_password: str
|
|
||||||
) -> EncryptionConfig:
|
|
||||||
"""启用项目加密"""
|
"""启用项目加密"""
|
||||||
if not CRYPTO_AVAILABLE:
|
if not CRYPTO_AVAILABLE:
|
||||||
raise RuntimeError("cryptography library not available")
|
raise RuntimeError("cryptography library not available")
|
||||||
@@ -523,43 +514,54 @@ class SecurityManager:
|
|||||||
encryption_type="aes-256-gcm",
|
encryption_type="aes-256-gcm",
|
||||||
key_derivation="pbkdf2",
|
key_derivation="pbkdf2",
|
||||||
master_key_hash=key_hash,
|
master_key_hash=key_hash,
|
||||||
salt=salt
|
salt=salt,
|
||||||
)
|
)
|
||||||
|
|
||||||
conn = sqlite3.connect(self.db_path)
|
conn = sqlite3.connect(self.db_path)
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
||||||
# 检查是否已存在配置
|
# 检查是否已存在配置
|
||||||
cursor.execute(
|
cursor.execute("SELECT id FROM encryption_configs WHERE project_id = ?", (project_id,))
|
||||||
"SELECT id FROM encryption_configs WHERE project_id = ?",
|
|
||||||
(project_id,)
|
|
||||||
)
|
|
||||||
existing = cursor.fetchone()
|
existing = cursor.fetchone()
|
||||||
|
|
||||||
if existing:
|
if existing:
|
||||||
cursor.execute("""
|
cursor.execute(
|
||||||
|
"""
|
||||||
UPDATE encryption_configs
|
UPDATE encryption_configs
|
||||||
SET is_enabled = 1, encryption_type = ?, key_derivation = ?,
|
SET is_enabled = 1, encryption_type = ?, key_derivation = ?,
|
||||||
master_key_hash = ?, salt = ?, updated_at = ?
|
master_key_hash = ?, salt = ?, updated_at = ?
|
||||||
WHERE project_id = ?
|
WHERE project_id = ?
|
||||||
""", (
|
""",
|
||||||
config.encryption_type, config.key_derivation,
|
(
|
||||||
config.master_key_hash, config.salt,
|
config.encryption_type,
|
||||||
config.updated_at, project_id
|
config.key_derivation,
|
||||||
))
|
config.master_key_hash,
|
||||||
|
config.salt,
|
||||||
|
config.updated_at,
|
||||||
|
project_id,
|
||||||
|
),
|
||||||
|
)
|
||||||
config.id = existing[0]
|
config.id = existing[0]
|
||||||
else:
|
else:
|
||||||
cursor.execute("""
|
cursor.execute(
|
||||||
|
"""
|
||||||
INSERT INTO encryption_configs
|
INSERT INTO encryption_configs
|
||||||
(id, project_id, is_enabled, encryption_type, key_derivation,
|
(id, project_id, is_enabled, encryption_type, key_derivation,
|
||||||
master_key_hash, salt, created_at, updated_at)
|
master_key_hash, salt, created_at, updated_at)
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
""", (
|
""",
|
||||||
config.id, config.project_id, int(config.is_enabled),
|
(
|
||||||
config.encryption_type, config.key_derivation,
|
config.id,
|
||||||
config.master_key_hash, config.salt,
|
config.project_id,
|
||||||
config.created_at, config.updated_at
|
int(config.is_enabled),
|
||||||
))
|
config.encryption_type,
|
||||||
|
config.key_derivation,
|
||||||
|
config.master_key_hash,
|
||||||
|
config.salt,
|
||||||
|
config.created_at,
|
||||||
|
config.updated_at,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
conn.commit()
|
conn.commit()
|
||||||
conn.close()
|
conn.close()
|
||||||
@@ -569,16 +571,12 @@ class SecurityManager:
|
|||||||
action_type=AuditActionType.ENCRYPTION_ENABLE,
|
action_type=AuditActionType.ENCRYPTION_ENABLE,
|
||||||
resource_type="project",
|
resource_type="project",
|
||||||
resource_id=project_id,
|
resource_id=project_id,
|
||||||
action_details={"encryption_type": config.encryption_type}
|
action_details={"encryption_type": config.encryption_type},
|
||||||
)
|
)
|
||||||
|
|
||||||
return config
|
return config
|
||||||
|
|
||||||
def disable_encryption(
|
def disable_encryption(self, project_id: str, master_password: str) -> bool:
|
||||||
self,
|
|
||||||
project_id: str,
|
|
||||||
master_password: str
|
|
||||||
) -> bool:
|
|
||||||
"""禁用项目加密"""
|
"""禁用项目加密"""
|
||||||
# 验证密码
|
# 验证密码
|
||||||
if not self.verify_encryption_password(project_id, master_password):
|
if not self.verify_encryption_password(project_id, master_password):
|
||||||
@@ -587,29 +585,24 @@ class SecurityManager:
|
|||||||
conn = sqlite3.connect(self.db_path)
|
conn = sqlite3.connect(self.db_path)
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
||||||
cursor.execute("""
|
cursor.execute(
|
||||||
|
"""
|
||||||
UPDATE encryption_configs
|
UPDATE encryption_configs
|
||||||
SET is_enabled = 0, updated_at = ?
|
SET is_enabled = 0, updated_at = ?
|
||||||
WHERE project_id = ?
|
WHERE project_id = ?
|
||||||
""", (datetime.now().isoformat(), project_id))
|
""",
|
||||||
|
(datetime.now().isoformat(), project_id),
|
||||||
|
)
|
||||||
|
|
||||||
conn.commit()
|
conn.commit()
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
# 记录审计日志
|
# 记录审计日志
|
||||||
self.log_audit(
|
self.log_audit(action_type=AuditActionType.ENCRYPTION_DISABLE, resource_type="project", resource_id=project_id)
|
||||||
action_type=AuditActionType.ENCRYPTION_DISABLE,
|
|
||||||
resource_type="project",
|
|
||||||
resource_id=project_id
|
|
||||||
)
|
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def verify_encryption_password(
|
def verify_encryption_password(self, project_id: str, password: str) -> bool:
|
||||||
self,
|
|
||||||
project_id: str,
|
|
||||||
password: str
|
|
||||||
) -> bool:
|
|
||||||
"""验证加密密码"""
|
"""验证加密密码"""
|
||||||
if not CRYPTO_AVAILABLE:
|
if not CRYPTO_AVAILABLE:
|
||||||
return False
|
return False
|
||||||
@@ -617,10 +610,7 @@ class SecurityManager:
|
|||||||
conn = sqlite3.connect(self.db_path)
|
conn = sqlite3.connect(self.db_path)
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
||||||
cursor.execute(
|
cursor.execute("SELECT master_key_hash, salt FROM encryption_configs WHERE project_id = ?", (project_id,))
|
||||||
"SELECT master_key_hash, salt FROM encryption_configs WHERE project_id = ?",
|
|
||||||
(project_id,)
|
|
||||||
)
|
|
||||||
row = cursor.fetchone()
|
row = cursor.fetchone()
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
@@ -638,10 +628,7 @@ class SecurityManager:
|
|||||||
conn = sqlite3.connect(self.db_path)
|
conn = sqlite3.connect(self.db_path)
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
||||||
cursor.execute(
|
cursor.execute("SELECT * FROM encryption_configs WHERE project_id = ?", (project_id,))
|
||||||
"SELECT * FROM encryption_configs WHERE project_id = ?",
|
|
||||||
(project_id,)
|
|
||||||
)
|
|
||||||
row = cursor.fetchone()
|
row = cursor.fetchone()
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
@@ -657,15 +644,10 @@ class SecurityManager:
|
|||||||
master_key_hash=row[5],
|
master_key_hash=row[5],
|
||||||
salt=row[6],
|
salt=row[6],
|
||||||
created_at=row[7],
|
created_at=row[7],
|
||||||
updated_at=row[8]
|
updated_at=row[8],
|
||||||
)
|
)
|
||||||
|
|
||||||
def encrypt_data(
|
def encrypt_data(self, data: str, password: str, salt: str | None = None) -> tuple[str, str]:
|
||||||
self,
|
|
||||||
data: str,
|
|
||||||
password: str,
|
|
||||||
salt: str | None = None
|
|
||||||
) -> tuple[str, str]:
|
|
||||||
"""加密数据"""
|
"""加密数据"""
|
||||||
if not CRYPTO_AVAILABLE:
|
if not CRYPTO_AVAILABLE:
|
||||||
raise RuntimeError("cryptography library not available")
|
raise RuntimeError("cryptography library not available")
|
||||||
@@ -679,12 +661,7 @@ class SecurityManager:
|
|||||||
|
|
||||||
return base64.b64encode(encrypted).decode(), salt
|
return base64.b64encode(encrypted).decode(), salt
|
||||||
|
|
||||||
def decrypt_data(
|
def decrypt_data(self, encrypted_data: str, password: str, salt: str) -> str:
|
||||||
self,
|
|
||||||
encrypted_data: str,
|
|
||||||
password: str,
|
|
||||||
salt: str
|
|
||||||
) -> str:
|
|
||||||
"""解密数据"""
|
"""解密数据"""
|
||||||
if not CRYPTO_AVAILABLE:
|
if not CRYPTO_AVAILABLE:
|
||||||
raise RuntimeError("cryptography library not available")
|
raise RuntimeError("cryptography library not available")
|
||||||
@@ -705,7 +682,7 @@ class SecurityManager:
|
|||||||
pattern: str | None = None,
|
pattern: str | None = None,
|
||||||
replacement: str | None = None,
|
replacement: str | None = None,
|
||||||
description: str | None = None,
|
description: str | None = None,
|
||||||
priority: int = 0
|
priority: int = 0,
|
||||||
) -> MaskingRule:
|
) -> MaskingRule:
|
||||||
"""创建脱敏规则"""
|
"""创建脱敏规则"""
|
||||||
# 使用预定义规则或自定义规则
|
# 使用预定义规则或自定义规则
|
||||||
@@ -722,22 +699,33 @@ class SecurityManager:
|
|||||||
pattern=pattern or "",
|
pattern=pattern or "",
|
||||||
replacement=replacement or "****",
|
replacement=replacement or "****",
|
||||||
description=description,
|
description=description,
|
||||||
priority=priority
|
priority=priority,
|
||||||
)
|
)
|
||||||
|
|
||||||
conn = sqlite3.connect(self.db_path)
|
conn = sqlite3.connect(self.db_path)
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
||||||
cursor.execute("""
|
cursor.execute(
|
||||||
|
"""
|
||||||
INSERT INTO masking_rules
|
INSERT INTO masking_rules
|
||||||
(id, project_id, name, rule_type, pattern, replacement,
|
(id, project_id, name, rule_type, pattern, replacement,
|
||||||
is_active, priority, description, created_at, updated_at)
|
is_active, priority, description, created_at, updated_at)
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
""", (
|
""",
|
||||||
rule.id, rule.project_id, rule.name, rule.rule_type,
|
(
|
||||||
rule.pattern, rule.replacement, int(rule.is_active),
|
rule.id,
|
||||||
rule.priority, rule.description, rule.created_at, rule.updated_at
|
rule.project_id,
|
||||||
))
|
rule.name,
|
||||||
|
rule.rule_type,
|
||||||
|
rule.pattern,
|
||||||
|
rule.replacement,
|
||||||
|
int(rule.is_active),
|
||||||
|
rule.priority,
|
||||||
|
rule.description,
|
||||||
|
rule.created_at,
|
||||||
|
rule.updated_at,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
conn.commit()
|
conn.commit()
|
||||||
conn.close()
|
conn.close()
|
||||||
@@ -747,16 +735,12 @@ class SecurityManager:
|
|||||||
action_type=AuditActionType.DATA_MASKING,
|
action_type=AuditActionType.DATA_MASKING,
|
||||||
resource_type="project",
|
resource_type="project",
|
||||||
resource_id=project_id,
|
resource_id=project_id,
|
||||||
action_details={"action": "create_rule", "rule_name": name}
|
action_details={"action": "create_rule", "rule_name": name},
|
||||||
)
|
)
|
||||||
|
|
||||||
return rule
|
return rule
|
||||||
|
|
||||||
def get_masking_rules(
|
def get_masking_rules(self, project_id: str, active_only: bool = True) -> list[MaskingRule]:
|
||||||
self,
|
|
||||||
project_id: str,
|
|
||||||
active_only: bool = True
|
|
||||||
) -> list[MaskingRule]:
|
|
||||||
"""获取脱敏规则"""
|
"""获取脱敏规则"""
|
||||||
conn = sqlite3.connect(self.db_path)
|
conn = sqlite3.connect(self.db_path)
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
@@ -775,27 +759,25 @@ class SecurityManager:
|
|||||||
|
|
||||||
rules = []
|
rules = []
|
||||||
for row in rows:
|
for row in rows:
|
||||||
rules.append(MaskingRule(
|
rules.append(
|
||||||
id=row[0],
|
MaskingRule(
|
||||||
project_id=row[1],
|
id=row[0],
|
||||||
name=row[2],
|
project_id=row[1],
|
||||||
rule_type=row[3],
|
name=row[2],
|
||||||
pattern=row[4],
|
rule_type=row[3],
|
||||||
replacement=row[5],
|
pattern=row[4],
|
||||||
is_active=bool(row[6]),
|
replacement=row[5],
|
||||||
priority=row[7],
|
is_active=bool(row[6]),
|
||||||
description=row[8],
|
priority=row[7],
|
||||||
created_at=row[9],
|
description=row[8],
|
||||||
updated_at=row[10]
|
created_at=row[9],
|
||||||
))
|
updated_at=row[10],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
return rules
|
return rules
|
||||||
|
|
||||||
def update_masking_rule(
|
def update_masking_rule(self, rule_id: str, **kwargs) -> MaskingRule | None:
|
||||||
self,
|
|
||||||
rule_id: str,
|
|
||||||
**kwargs
|
|
||||||
) -> MaskingRule | None:
|
|
||||||
"""更新脱敏规则"""
|
"""更新脱敏规则"""
|
||||||
allowed_fields = ["name", "pattern", "replacement", "is_active", "priority", "description"]
|
allowed_fields = ["name", "pattern", "replacement", "is_active", "priority", "description"]
|
||||||
|
|
||||||
@@ -818,11 +800,14 @@ class SecurityManager:
|
|||||||
params.append(datetime.now().isoformat())
|
params.append(datetime.now().isoformat())
|
||||||
params.append(rule_id)
|
params.append(rule_id)
|
||||||
|
|
||||||
cursor.execute(f"""
|
cursor.execute(
|
||||||
|
f"""
|
||||||
UPDATE masking_rules
|
UPDATE masking_rules
|
||||||
SET {', '.join(set_clauses)}
|
SET {', '.join(set_clauses)}
|
||||||
WHERE id = ?
|
WHERE id = ?
|
||||||
""", params)
|
""",
|
||||||
|
params,
|
||||||
|
)
|
||||||
|
|
||||||
conn.commit()
|
conn.commit()
|
||||||
conn.close()
|
conn.close()
|
||||||
@@ -848,7 +833,7 @@ class SecurityManager:
|
|||||||
priority=row[7],
|
priority=row[7],
|
||||||
description=row[8],
|
description=row[8],
|
||||||
created_at=row[9],
|
created_at=row[9],
|
||||||
updated_at=row[10]
|
updated_at=row[10],
|
||||||
)
|
)
|
||||||
|
|
||||||
def delete_masking_rule(self, rule_id: str) -> bool:
|
def delete_masking_rule(self, rule_id: str) -> bool:
|
||||||
@@ -864,12 +849,7 @@ class SecurityManager:
|
|||||||
|
|
||||||
return success
|
return success
|
||||||
|
|
||||||
def apply_masking(
|
def apply_masking(self, text: str, project_id: str, rule_types: list[MaskingRuleType] | None = None) -> str:
|
||||||
self,
|
|
||||||
text: str,
|
|
||||||
project_id: str,
|
|
||||||
rule_types: list[MaskingRuleType] | None = None
|
|
||||||
) -> str:
|
|
||||||
"""应用脱敏规则到文本"""
|
"""应用脱敏规则到文本"""
|
||||||
rules = self.get_masking_rules(project_id)
|
rules = self.get_masking_rules(project_id)
|
||||||
|
|
||||||
@@ -884,22 +864,14 @@ class SecurityManager:
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
masked_text = re.sub(
|
masked_text = re.sub(rule.pattern, rule.replacement, masked_text)
|
||||||
rule.pattern,
|
|
||||||
rule.replacement,
|
|
||||||
masked_text
|
|
||||||
)
|
|
||||||
except re.error:
|
except re.error:
|
||||||
# 忽略无效的正则表达式
|
# 忽略无效的正则表达式
|
||||||
continue
|
continue
|
||||||
|
|
||||||
return masked_text
|
return masked_text
|
||||||
|
|
||||||
def apply_masking_to_entity(
|
def apply_masking_to_entity(self, entity_data: dict[str, Any], project_id: str) -> dict[str, Any]:
|
||||||
self,
|
|
||||||
entity_data: dict[str, Any],
|
|
||||||
project_id: str
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
"""对实体数据应用脱敏"""
|
"""对实体数据应用脱敏"""
|
||||||
masked_data = entity_data.copy()
|
masked_data = entity_data.copy()
|
||||||
|
|
||||||
@@ -924,7 +896,7 @@ class SecurityManager:
|
|||||||
allowed_ips: list[str] | None = None,
|
allowed_ips: list[str] | None = None,
|
||||||
time_restrictions: dict | None = None,
|
time_restrictions: dict | None = None,
|
||||||
max_access_count: int | None = None,
|
max_access_count: int | None = None,
|
||||||
require_approval: bool = False
|
require_approval: bool = False,
|
||||||
) -> DataAccessPolicy:
|
) -> DataAccessPolicy:
|
||||||
"""创建数据访问策略"""
|
"""创建数据访问策略"""
|
||||||
policy = DataAccessPolicy(
|
policy = DataAccessPolicy(
|
||||||
@@ -937,36 +909,43 @@ class SecurityManager:
|
|||||||
allowed_ips=json.dumps(allowed_ips) if allowed_ips else None,
|
allowed_ips=json.dumps(allowed_ips) if allowed_ips else None,
|
||||||
time_restrictions=json.dumps(time_restrictions) if time_restrictions else None,
|
time_restrictions=json.dumps(time_restrictions) if time_restrictions else None,
|
||||||
max_access_count=max_access_count,
|
max_access_count=max_access_count,
|
||||||
require_approval=require_approval
|
require_approval=require_approval,
|
||||||
)
|
)
|
||||||
|
|
||||||
conn = sqlite3.connect(self.db_path)
|
conn = sqlite3.connect(self.db_path)
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
||||||
cursor.execute("""
|
cursor.execute(
|
||||||
|
"""
|
||||||
INSERT INTO data_access_policies
|
INSERT INTO data_access_policies
|
||||||
(id, project_id, name, description, allowed_users, allowed_roles,
|
(id, project_id, name, description, allowed_users, allowed_roles,
|
||||||
allowed_ips, time_restrictions, max_access_count, require_approval,
|
allowed_ips, time_restrictions, max_access_count, require_approval,
|
||||||
is_active, created_at, updated_at)
|
is_active, created_at, updated_at)
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
""", (
|
""",
|
||||||
policy.id, policy.project_id, policy.name, policy.description,
|
(
|
||||||
policy.allowed_users, policy.allowed_roles, policy.allowed_ips,
|
policy.id,
|
||||||
policy.time_restrictions, policy.max_access_count,
|
policy.project_id,
|
||||||
int(policy.require_approval), int(policy.is_active),
|
policy.name,
|
||||||
policy.created_at, policy.updated_at
|
policy.description,
|
||||||
))
|
policy.allowed_users,
|
||||||
|
policy.allowed_roles,
|
||||||
|
policy.allowed_ips,
|
||||||
|
policy.time_restrictions,
|
||||||
|
policy.max_access_count,
|
||||||
|
int(policy.require_approval),
|
||||||
|
int(policy.is_active),
|
||||||
|
policy.created_at,
|
||||||
|
policy.updated_at,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
conn.commit()
|
conn.commit()
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
return policy
|
return policy
|
||||||
|
|
||||||
def get_access_policies(
|
def get_access_policies(self, project_id: str, active_only: bool = True) -> list[DataAccessPolicy]:
|
||||||
self,
|
|
||||||
project_id: str,
|
|
||||||
active_only: bool = True
|
|
||||||
) -> list[DataAccessPolicy]:
|
|
||||||
"""获取数据访问策略"""
|
"""获取数据访问策略"""
|
||||||
conn = sqlite3.connect(self.db_path)
|
conn = sqlite3.connect(self.db_path)
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
@@ -983,38 +962,34 @@ class SecurityManager:
|
|||||||
|
|
||||||
policies = []
|
policies = []
|
||||||
for row in rows:
|
for row in rows:
|
||||||
policies.append(DataAccessPolicy(
|
policies.append(
|
||||||
id=row[0],
|
DataAccessPolicy(
|
||||||
project_id=row[1],
|
id=row[0],
|
||||||
name=row[2],
|
project_id=row[1],
|
||||||
description=row[3],
|
name=row[2],
|
||||||
allowed_users=row[4],
|
description=row[3],
|
||||||
allowed_roles=row[5],
|
allowed_users=row[4],
|
||||||
allowed_ips=row[6],
|
allowed_roles=row[5],
|
||||||
time_restrictions=row[7],
|
allowed_ips=row[6],
|
||||||
max_access_count=row[8],
|
time_restrictions=row[7],
|
||||||
require_approval=bool(row[9]),
|
max_access_count=row[8],
|
||||||
is_active=bool(row[10]),
|
require_approval=bool(row[9]),
|
||||||
created_at=row[11],
|
is_active=bool(row[10]),
|
||||||
updated_at=row[12]
|
created_at=row[11],
|
||||||
))
|
updated_at=row[12],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
return policies
|
return policies
|
||||||
|
|
||||||
def check_access_permission(
|
def check_access_permission(
|
||||||
self,
|
self, policy_id: str, user_id: str, user_ip: str | None = None
|
||||||
policy_id: str,
|
|
||||||
user_id: str,
|
|
||||||
user_ip: str | None = None
|
|
||||||
) -> tuple[bool, str | None]:
|
) -> tuple[bool, str | None]:
|
||||||
"""检查访问权限"""
|
"""检查访问权限"""
|
||||||
conn = sqlite3.connect(self.db_path)
|
conn = sqlite3.connect(self.db_path)
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
||||||
cursor.execute(
|
cursor.execute("SELECT * FROM data_access_policies WHERE id = ? AND is_active = 1", (policy_id,))
|
||||||
"SELECT * FROM data_access_policies WHERE id = ? AND is_active = 1",
|
|
||||||
(policy_id,)
|
|
||||||
)
|
|
||||||
row = cursor.fetchone()
|
row = cursor.fetchone()
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
@@ -1034,7 +1009,7 @@ class SecurityManager:
|
|||||||
require_approval=bool(row[9]),
|
require_approval=bool(row[9]),
|
||||||
is_active=bool(row[10]),
|
is_active=bool(row[10]),
|
||||||
created_at=row[11],
|
created_at=row[11],
|
||||||
updated_at=row[12]
|
updated_at=row[12],
|
||||||
)
|
)
|
||||||
|
|
||||||
# 检查用户白名单
|
# 检查用户白名单
|
||||||
@@ -1074,11 +1049,14 @@ class SecurityManager:
|
|||||||
conn = sqlite3.connect(self.db_path)
|
conn = sqlite3.connect(self.db_path)
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
||||||
cursor.execute("""
|
cursor.execute(
|
||||||
|
"""
|
||||||
SELECT * FROM access_requests
|
SELECT * FROM access_requests
|
||||||
WHERE policy_id = ? AND user_id = ? AND status = 'approved'
|
WHERE policy_id = ? AND user_id = ? AND status = 'approved'
|
||||||
AND (expires_at IS NULL OR expires_at > ?)
|
AND (expires_at IS NULL OR expires_at > ?)
|
||||||
""", (policy_id, user_id, datetime.now().isoformat()))
|
""",
|
||||||
|
(policy_id, user_id, datetime.now().isoformat()),
|
||||||
|
)
|
||||||
|
|
||||||
request = cursor.fetchone()
|
request = cursor.fetchone()
|
||||||
conn.close()
|
conn.close()
|
||||||
@@ -1104,11 +1082,7 @@ class SecurityManager:
|
|||||||
return ip == pattern
|
return ip == pattern
|
||||||
|
|
||||||
def create_access_request(
|
def create_access_request(
|
||||||
self,
|
self, policy_id: str, user_id: str, request_reason: str | None = None, expires_hours: int = 24
|
||||||
policy_id: str,
|
|
||||||
user_id: str,
|
|
||||||
request_reason: str | None = None,
|
|
||||||
expires_hours: int = 24
|
|
||||||
) -> AccessRequest:
|
) -> AccessRequest:
|
||||||
"""创建访问请求"""
|
"""创建访问请求"""
|
||||||
request = AccessRequest(
|
request = AccessRequest(
|
||||||
@@ -1116,21 +1090,28 @@ class SecurityManager:
|
|||||||
policy_id=policy_id,
|
policy_id=policy_id,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
request_reason=request_reason,
|
request_reason=request_reason,
|
||||||
expires_at=(datetime.now() + timedelta(hours=expires_hours)).isoformat()
|
expires_at=(datetime.now() + timedelta(hours=expires_hours)).isoformat(),
|
||||||
)
|
)
|
||||||
|
|
||||||
conn = sqlite3.connect(self.db_path)
|
conn = sqlite3.connect(self.db_path)
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
||||||
cursor.execute("""
|
cursor.execute(
|
||||||
|
"""
|
||||||
INSERT INTO access_requests
|
INSERT INTO access_requests
|
||||||
(id, policy_id, user_id, request_reason, status, expires_at, created_at)
|
(id, policy_id, user_id, request_reason, status, expires_at, created_at)
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||||
""", (
|
""",
|
||||||
request.id, request.policy_id, request.user_id,
|
(
|
||||||
request.request_reason, request.status, request.expires_at,
|
request.id,
|
||||||
request.created_at
|
request.policy_id,
|
||||||
))
|
request.user_id,
|
||||||
|
request.request_reason,
|
||||||
|
request.status,
|
||||||
|
request.expires_at,
|
||||||
|
request.created_at,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
conn.commit()
|
conn.commit()
|
||||||
conn.close()
|
conn.close()
|
||||||
@@ -1138,10 +1119,7 @@ class SecurityManager:
|
|||||||
return request
|
return request
|
||||||
|
|
||||||
def approve_access_request(
|
def approve_access_request(
|
||||||
self,
|
self, request_id: str, approved_by: str, expires_hours: int = 24
|
||||||
request_id: str,
|
|
||||||
approved_by: str,
|
|
||||||
expires_hours: int = 24
|
|
||||||
) -> AccessRequest | None:
|
) -> AccessRequest | None:
|
||||||
"""批准访问请求"""
|
"""批准访问请求"""
|
||||||
conn = sqlite3.connect(self.db_path)
|
conn = sqlite3.connect(self.db_path)
|
||||||
@@ -1150,11 +1128,14 @@ class SecurityManager:
|
|||||||
expires_at = (datetime.now() + timedelta(hours=expires_hours)).isoformat()
|
expires_at = (datetime.now() + timedelta(hours=expires_hours)).isoformat()
|
||||||
approved_at = datetime.now().isoformat()
|
approved_at = datetime.now().isoformat()
|
||||||
|
|
||||||
cursor.execute("""
|
cursor.execute(
|
||||||
|
"""
|
||||||
UPDATE access_requests
|
UPDATE access_requests
|
||||||
SET status = 'approved', approved_by = ?, approved_at = ?, expires_at = ?
|
SET status = 'approved', approved_by = ?, approved_at = ?, expires_at = ?
|
||||||
WHERE id = ?
|
WHERE id = ?
|
||||||
""", (approved_by, approved_at, expires_at, request_id))
|
""",
|
||||||
|
(approved_by, approved_at, expires_at, request_id),
|
||||||
|
)
|
||||||
|
|
||||||
conn.commit()
|
conn.commit()
|
||||||
|
|
||||||
@@ -1175,23 +1156,22 @@ class SecurityManager:
|
|||||||
approved_by=row[5],
|
approved_by=row[5],
|
||||||
approved_at=row[6],
|
approved_at=row[6],
|
||||||
expires_at=row[7],
|
expires_at=row[7],
|
||||||
created_at=row[8]
|
created_at=row[8],
|
||||||
)
|
)
|
||||||
|
|
||||||
def reject_access_request(
|
def reject_access_request(self, request_id: str, rejected_by: str) -> AccessRequest | None:
|
||||||
self,
|
|
||||||
request_id: str,
|
|
||||||
rejected_by: str
|
|
||||||
) -> AccessRequest | None:
|
|
||||||
"""拒绝访问请求"""
|
"""拒绝访问请求"""
|
||||||
conn = sqlite3.connect(self.db_path)
|
conn = sqlite3.connect(self.db_path)
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
||||||
cursor.execute("""
|
cursor.execute(
|
||||||
|
"""
|
||||||
UPDATE access_requests
|
UPDATE access_requests
|
||||||
SET status = 'rejected', approved_by = ?
|
SET status = 'rejected', approved_by = ?
|
||||||
WHERE id = ?
|
WHERE id = ?
|
||||||
""", (rejected_by, request_id))
|
""",
|
||||||
|
(rejected_by, request_id),
|
||||||
|
)
|
||||||
|
|
||||||
conn.commit()
|
conn.commit()
|
||||||
|
|
||||||
@@ -1211,7 +1191,7 @@ class SecurityManager:
|
|||||||
approved_by=row[5],
|
approved_by=row[5],
|
||||||
approved_at=row[6],
|
approved_at=row[6],
|
||||||
expires_at=row[7],
|
expires_at=row[7],
|
||||||
created_at=row[8]
|
created_at=row[8],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -20,7 +20,7 @@ class TingwuClient:
|
|||||||
|
|
||||||
def _sign_request(self, method: str, uri: str, query: str = "", body: str = "") -> dict[str, str]:
|
def _sign_request(self, method: str, uri: str, query: str = "", body: str = "") -> dict[str, str]:
|
||||||
"""阿里云签名 V3"""
|
"""阿里云签名 V3"""
|
||||||
timestamp = datetime.utcnow().strftime('%Y-%m-%dT%H:%M:%SZ')
|
timestamp = datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%SZ")
|
||||||
|
|
||||||
# 简化签名,实际生产需要完整实现
|
# 简化签名,实际生产需要完整实现
|
||||||
# 这里使用基础认证头
|
# 这里使用基础认证头
|
||||||
@@ -39,25 +39,16 @@ class TingwuClient:
|
|||||||
from alibabacloud_tingwu20230930 import models as tingwu_models
|
from alibabacloud_tingwu20230930 import models as tingwu_models
|
||||||
from alibabacloud_tingwu20230930.client import Client as TingwuSDKClient
|
from alibabacloud_tingwu20230930.client import Client as TingwuSDKClient
|
||||||
|
|
||||||
config = open_api_models.Config(
|
config = open_api_models.Config(access_key_id=self.access_key, access_key_secret=self.secret_key)
|
||||||
access_key_id=self.access_key,
|
|
||||||
access_key_secret=self.secret_key
|
|
||||||
)
|
|
||||||
config.endpoint = "tingwu.cn-beijing.aliyuncs.com"
|
config.endpoint = "tingwu.cn-beijing.aliyuncs.com"
|
||||||
client = TingwuSDKClient(config)
|
client = TingwuSDKClient(config)
|
||||||
|
|
||||||
request = tingwu_models.CreateTaskRequest(
|
request = tingwu_models.CreateTaskRequest(
|
||||||
type="offline",
|
type="offline",
|
||||||
input=tingwu_models.Input(
|
input=tingwu_models.Input(source="OSS", file_url=audio_url),
|
||||||
source="OSS",
|
|
||||||
file_url=audio_url
|
|
||||||
),
|
|
||||||
parameters=tingwu_models.Parameters(
|
parameters=tingwu_models.Parameters(
|
||||||
transcription=tingwu_models.Transcription(
|
transcription=tingwu_models.Transcription(diarization_enabled=True, sentence_max_length=20)
|
||||||
diarization_enabled=True,
|
),
|
||||||
sentence_max_length=20
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
response = client.create_task(request)
|
response = client.create_task(request)
|
||||||
@@ -81,10 +72,7 @@ class TingwuClient:
|
|||||||
from alibabacloud_tingwu20230930 import models as tingwu_models
|
from alibabacloud_tingwu20230930 import models as tingwu_models
|
||||||
from alibabacloud_tingwu20230930.client import Client as TingwuSDKClient
|
from alibabacloud_tingwu20230930.client import Client as TingwuSDKClient
|
||||||
|
|
||||||
config = open_api_models.Config(
|
config = open_api_models.Config(access_key_id=self.access_key, access_key_secret=self.secret_key)
|
||||||
access_key_id=self.access_key,
|
|
||||||
access_key_secret=self.secret_key
|
|
||||||
)
|
|
||||||
config.endpoint = "tingwu.cn-beijing.aliyuncs.com"
|
config.endpoint = "tingwu.cn-beijing.aliyuncs.com"
|
||||||
client = TingwuSDKClient(config)
|
client = TingwuSDKClient(config)
|
||||||
|
|
||||||
@@ -128,25 +116,29 @@ class TingwuClient:
|
|||||||
|
|
||||||
if transcription.sentences:
|
if transcription.sentences:
|
||||||
for sent in transcription.sentences:
|
for sent in transcription.sentences:
|
||||||
segments.append({
|
segments.append(
|
||||||
"start": sent.begin_time / 1000,
|
{
|
||||||
"end": sent.end_time / 1000,
|
"start": sent.begin_time / 1000,
|
||||||
"text": sent.text,
|
"end": sent.end_time / 1000,
|
||||||
"speaker": f"Speaker {sent.speaker_id}"
|
"text": sent.text,
|
||||||
})
|
"speaker": f"Speaker {sent.speaker_id}",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
return {
|
return {"full_text": full_text.strip(), "segments": segments}
|
||||||
"full_text": full_text.strip(),
|
|
||||||
"segments": segments
|
|
||||||
}
|
|
||||||
|
|
||||||
def _mock_result(self) -> dict[str, Any]:
|
def _mock_result(self) -> dict[str, Any]:
|
||||||
"""Mock 结果"""
|
"""Mock 结果"""
|
||||||
return {
|
return {
|
||||||
"full_text": "这是一个示例转录文本,包含 Project Alpha 和 K8s 等术语。",
|
"full_text": "这是一个示例转录文本,包含 Project Alpha 和 K8s 等术语。",
|
||||||
"segments": [
|
"segments": [
|
||||||
{"start": 0.0, "end": 5.0, "text": "这是一个示例转录文本,包含 Project Alpha 和 K8s 等术语。", "speaker": "Speaker A"}
|
{
|
||||||
]
|
"start": 0.0,
|
||||||
|
"end": 5.0,
|
||||||
|
"text": "这是一个示例转录文本,包含 Project Alpha 和 K8s 等术语。",
|
||||||
|
"speaker": "Speaker A",
|
||||||
|
}
|
||||||
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
def transcribe(self, audio_url: str, language: str = "zh") -> dict[str, Any]:
|
def transcribe(self, audio_url: str, language: str = "zh") -> dict[str, Any]:
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user