fix: auto-fix code issues (cron)
- 修复重复导入/字段 - 修复异常处理 - 修复PEP8格式问题 - 添加类型注解 - 修复重复函数定义 (health_check, create_webhook_endpoint, etc) - 修复未定义名称 (SearchOperator, TenantTier, Query, Body, logger) - 修复 workflow_manager.py 的类定义重复问题 - 添加缺失的导入
This commit is contained in:
File diff suppressed because it is too large
Load Diff
@@ -43,15 +43,15 @@ class ApiKey:
|
|||||||
|
|
||||||
class ApiKeyManager:
|
class ApiKeyManager:
|
||||||
"""API Key 管理器"""
|
"""API Key 管理器"""
|
||||||
|
|
||||||
# Key 前缀
|
# Key 前缀
|
||||||
KEY_PREFIX = "ak_live_"
|
KEY_PREFIX = "ak_live_"
|
||||||
KEY_LENGTH = 48 # 总长度: 前缀(8) + 随机部分(40)
|
KEY_LENGTH = 48 # 总长度: 前缀(8) + 随机部分(40)
|
||||||
|
|
||||||
def __init__(self, db_path: str = DB_PATH):
|
def __init__(self, db_path: str = DB_PATH):
|
||||||
self.db_path = db_path
|
self.db_path = db_path
|
||||||
self._init_db()
|
self._init_db()
|
||||||
|
|
||||||
def _init_db(self):
|
def _init_db(self):
|
||||||
"""初始化数据库表"""
|
"""初始化数据库表"""
|
||||||
with sqlite3.connect(self.db_path) as conn:
|
with sqlite3.connect(self.db_path) as conn:
|
||||||
@@ -73,7 +73,7 @@ class ApiKeyManager:
|
|||||||
revoked_reason TEXT,
|
revoked_reason TEXT,
|
||||||
total_calls INTEGER DEFAULT 0
|
total_calls INTEGER DEFAULT 0
|
||||||
);
|
);
|
||||||
|
|
||||||
-- API 调用日志表
|
-- API 调用日志表
|
||||||
CREATE TABLE IF NOT EXISTS api_call_logs (
|
CREATE TABLE IF NOT EXISTS api_call_logs (
|
||||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
@@ -88,7 +88,7 @@ class ApiKeyManager:
|
|||||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||||
FOREIGN KEY (api_key_id) REFERENCES api_keys(id)
|
FOREIGN KEY (api_key_id) REFERENCES api_keys(id)
|
||||||
);
|
);
|
||||||
|
|
||||||
-- API 调用统计表(按天汇总)
|
-- API 调用统计表(按天汇总)
|
||||||
CREATE TABLE IF NOT EXISTS api_call_stats (
|
CREATE TABLE IF NOT EXISTS api_call_stats (
|
||||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
@@ -103,7 +103,7 @@ class ApiKeyManager:
|
|||||||
FOREIGN KEY (api_key_id) REFERENCES api_keys(id),
|
FOREIGN KEY (api_key_id) REFERENCES api_keys(id),
|
||||||
UNIQUE(api_key_id, date, endpoint, method)
|
UNIQUE(api_key_id, date, endpoint, method)
|
||||||
);
|
);
|
||||||
|
|
||||||
-- 创建索引
|
-- 创建索引
|
||||||
CREATE INDEX IF NOT EXISTS idx_api_keys_hash ON api_keys(key_hash);
|
CREATE INDEX IF NOT EXISTS idx_api_keys_hash ON api_keys(key_hash);
|
||||||
CREATE INDEX IF NOT EXISTS idx_api_keys_status ON api_keys(status);
|
CREATE INDEX IF NOT EXISTS idx_api_keys_status ON api_keys(status);
|
||||||
@@ -113,47 +113,47 @@ class ApiKeyManager:
|
|||||||
CREATE INDEX IF NOT EXISTS idx_api_stats_key_date ON api_call_stats(api_key_id, date);
|
CREATE INDEX IF NOT EXISTS idx_api_stats_key_date ON api_call_stats(api_key_id, date);
|
||||||
""")
|
""")
|
||||||
conn.commit()
|
conn.commit()
|
||||||
|
|
||||||
def _generate_key(self) -> str:
|
def _generate_key(self) -> str:
|
||||||
"""生成新的 API Key"""
|
"""生成新的 API Key"""
|
||||||
# 生成 40 字符的随机字符串
|
# 生成 40 字符的随机字符串
|
||||||
random_part = secrets.token_urlsafe(30)[:40]
|
random_part = secrets.token_urlsafe(30)[:40]
|
||||||
return f"{self.KEY_PREFIX}{random_part}"
|
return f"{self.KEY_PREFIX}{random_part}"
|
||||||
|
|
||||||
def _hash_key(self, key: str) -> str:
|
def _hash_key(self, key: str) -> str:
|
||||||
"""对 API Key 进行哈希"""
|
"""对 API Key 进行哈希"""
|
||||||
return hashlib.sha256(key.encode()).hexdigest()
|
return hashlib.sha256(key.encode()).hexdigest()
|
||||||
|
|
||||||
def _get_preview(self, key: str) -> str:
|
def _get_preview(self, key: str) -> str:
|
||||||
"""获取 Key 的预览(前16位)"""
|
"""获取 Key 的预览(前16位)"""
|
||||||
return f"{key[:16]}..."
|
return f"{key[:16]}..."
|
||||||
|
|
||||||
def create_key(
|
def create_key(
|
||||||
self,
|
self,
|
||||||
name: str,
|
name: str,
|
||||||
owner_id: Optional[str] = None,
|
owner_id: Optional[str] = None,
|
||||||
permissions: List[str] = None,
|
permissions: List[str] = None,
|
||||||
rate_limit: int = 60,
|
rate_limit: int = 60,
|
||||||
expires_days: Optional[int] = None
|
expires_days: Optional[int] = None,
|
||||||
) -> tuple[str, ApiKey]:
|
) -> tuple[str, ApiKey]:
|
||||||
"""
|
"""
|
||||||
创建新的 API Key
|
创建新的 API Key
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
tuple: (原始key(仅返回一次), ApiKey对象)
|
tuple: (原始key(仅返回一次), ApiKey对象)
|
||||||
"""
|
"""
|
||||||
if permissions is None:
|
if permissions is None:
|
||||||
permissions = ["read"]
|
permissions = ["read"]
|
||||||
|
|
||||||
key_id = secrets.token_hex(16)
|
key_id = secrets.token_hex(16)
|
||||||
raw_key = self._generate_key()
|
raw_key = self._generate_key()
|
||||||
key_hash = self._hash_key(raw_key)
|
key_hash = self._hash_key(raw_key)
|
||||||
key_preview = self._get_preview(raw_key)
|
key_preview = self._get_preview(raw_key)
|
||||||
|
|
||||||
expires_at = None
|
expires_at = None
|
||||||
if expires_days:
|
if expires_days:
|
||||||
expires_at = (datetime.now() + timedelta(days=expires_days)).isoformat()
|
expires_at = (datetime.now() + timedelta(days=expires_days)).isoformat()
|
||||||
|
|
||||||
api_key = ApiKey(
|
api_key = ApiKey(
|
||||||
id=key_id,
|
id=key_id,
|
||||||
key_hash=key_hash,
|
key_hash=key_hash,
|
||||||
@@ -168,197 +168,183 @@ class ApiKeyManager:
|
|||||||
last_used_at=None,
|
last_used_at=None,
|
||||||
revoked_at=None,
|
revoked_at=None,
|
||||||
revoked_reason=None,
|
revoked_reason=None,
|
||||||
total_calls=0
|
total_calls=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
with sqlite3.connect(self.db_path) as conn:
|
with sqlite3.connect(self.db_path) as conn:
|
||||||
conn.execute("""
|
conn.execute(
|
||||||
|
"""
|
||||||
INSERT INTO api_keys (
|
INSERT INTO api_keys (
|
||||||
id, key_hash, key_preview, name, owner_id, permissions,
|
id, key_hash, key_preview, name, owner_id, permissions,
|
||||||
rate_limit, status, created_at, expires_at
|
rate_limit, status, created_at, expires_at
|
||||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
""", (
|
""",
|
||||||
api_key.id, api_key.key_hash, api_key.key_preview,
|
(
|
||||||
api_key.name, api_key.owner_id, json.dumps(api_key.permissions),
|
api_key.id,
|
||||||
api_key.rate_limit, api_key.status, api_key.created_at,
|
api_key.key_hash,
|
||||||
api_key.expires_at
|
api_key.key_preview,
|
||||||
))
|
api_key.name,
|
||||||
|
api_key.owner_id,
|
||||||
|
json.dumps(api_key.permissions),
|
||||||
|
api_key.rate_limit,
|
||||||
|
api_key.status,
|
||||||
|
api_key.created_at,
|
||||||
|
api_key.expires_at,
|
||||||
|
),
|
||||||
|
)
|
||||||
conn.commit()
|
conn.commit()
|
||||||
|
|
||||||
return raw_key, api_key
|
return raw_key, api_key
|
||||||
|
|
||||||
def validate_key(self, key: str) -> Optional[ApiKey]:
|
def validate_key(self, key: str) -> Optional[ApiKey]:
|
||||||
"""
|
"""
|
||||||
验证 API Key
|
验证 API Key
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
ApiKey if valid, None otherwise
|
ApiKey if valid, None otherwise
|
||||||
"""
|
"""
|
||||||
key_hash = self._hash_key(key)
|
key_hash = self._hash_key(key)
|
||||||
|
|
||||||
with sqlite3.connect(self.db_path) as conn:
|
with sqlite3.connect(self.db_path) as conn:
|
||||||
conn.row_factory = sqlite3.Row
|
conn.row_factory = sqlite3.Row
|
||||||
row = conn.execute(
|
row = conn.execute("SELECT * FROM api_keys WHERE key_hash = ?", (key_hash,)).fetchone()
|
||||||
"SELECT * FROM api_keys WHERE key_hash = ?",
|
|
||||||
(key_hash,)
|
|
||||||
).fetchone()
|
|
||||||
|
|
||||||
if not row:
|
if not row:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
api_key = self._row_to_api_key(row)
|
api_key = self._row_to_api_key(row)
|
||||||
|
|
||||||
# 检查状态
|
# 检查状态
|
||||||
if api_key.status != ApiKeyStatus.ACTIVE.value:
|
if api_key.status != ApiKeyStatus.ACTIVE.value:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# 检查是否过期
|
# 检查是否过期
|
||||||
if api_key.expires_at:
|
if api_key.expires_at:
|
||||||
expires = datetime.fromisoformat(api_key.expires_at)
|
expires = datetime.fromisoformat(api_key.expires_at)
|
||||||
if datetime.now() > expires:
|
if datetime.now() > expires:
|
||||||
# 更新状态为过期
|
# 更新状态为过期
|
||||||
conn.execute(
|
conn.execute(
|
||||||
"UPDATE api_keys SET status = ? WHERE id = ?",
|
"UPDATE api_keys SET status = ? WHERE id = ?", (ApiKeyStatus.EXPIRED.value, api_key.id)
|
||||||
(ApiKeyStatus.EXPIRED.value, api_key.id)
|
|
||||||
)
|
)
|
||||||
conn.commit()
|
conn.commit()
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return api_key
|
return api_key
|
||||||
|
|
||||||
def revoke_key(
|
def revoke_key(self, key_id: str, reason: str = "", owner_id: Optional[str] = None) -> bool:
|
||||||
self,
|
|
||||||
key_id: str,
|
|
||||||
reason: str = "",
|
|
||||||
owner_id: Optional[str] = None
|
|
||||||
) -> bool:
|
|
||||||
"""撤销 API Key"""
|
"""撤销 API Key"""
|
||||||
with sqlite3.connect(self.db_path) as conn:
|
with sqlite3.connect(self.db_path) as conn:
|
||||||
# 验证所有权(如果提供了 owner_id)
|
# 验证所有权(如果提供了 owner_id)
|
||||||
if owner_id:
|
if owner_id:
|
||||||
row = conn.execute(
|
row = conn.execute("SELECT owner_id FROM api_keys WHERE id = ?", (key_id,)).fetchone()
|
||||||
"SELECT owner_id FROM api_keys WHERE id = ?",
|
|
||||||
(key_id,)
|
|
||||||
).fetchone()
|
|
||||||
if not row or row[0] != owner_id:
|
if not row or row[0] != owner_id:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
cursor = conn.execute("""
|
cursor = conn.execute(
|
||||||
UPDATE api_keys
|
"""
|
||||||
|
UPDATE api_keys
|
||||||
SET status = ?, revoked_at = ?, revoked_reason = ?
|
SET status = ?, revoked_at = ?, revoked_reason = ?
|
||||||
WHERE id = ? AND status = ?
|
WHERE id = ? AND status = ?
|
||||||
""", (
|
""",
|
||||||
ApiKeyStatus.REVOKED.value,
|
(ApiKeyStatus.REVOKED.value, datetime.now().isoformat(), reason, key_id, ApiKeyStatus.ACTIVE.value),
|
||||||
datetime.now().isoformat(),
|
)
|
||||||
reason,
|
|
||||||
key_id,
|
|
||||||
ApiKeyStatus.ACTIVE.value
|
|
||||||
))
|
|
||||||
conn.commit()
|
conn.commit()
|
||||||
return cursor.rowcount > 0
|
return cursor.rowcount > 0
|
||||||
|
|
||||||
def get_key_by_id(self, key_id: str, owner_id: Optional[str] = None) -> Optional[ApiKey]:
|
def get_key_by_id(self, key_id: str, owner_id: Optional[str] = None) -> Optional[ApiKey]:
|
||||||
"""通过 ID 获取 API Key(不包含敏感信息)"""
|
"""通过 ID 获取 API Key(不包含敏感信息)"""
|
||||||
with sqlite3.connect(self.db_path) as conn:
|
with sqlite3.connect(self.db_path) as conn:
|
||||||
conn.row_factory = sqlite3.Row
|
conn.row_factory = sqlite3.Row
|
||||||
|
|
||||||
if owner_id:
|
if owner_id:
|
||||||
row = conn.execute(
|
row = conn.execute(
|
||||||
"SELECT * FROM api_keys WHERE id = ? AND owner_id = ?",
|
"SELECT * FROM api_keys WHERE id = ? AND owner_id = ?", (key_id, owner_id)
|
||||||
(key_id, owner_id)
|
|
||||||
).fetchone()
|
).fetchone()
|
||||||
else:
|
else:
|
||||||
row = conn.execute(
|
row = conn.execute("SELECT * FROM api_keys WHERE id = ?", (key_id,)).fetchone()
|
||||||
"SELECT * FROM api_keys WHERE id = ?",
|
|
||||||
(key_id,)
|
|
||||||
).fetchone()
|
|
||||||
|
|
||||||
if row:
|
if row:
|
||||||
return self._row_to_api_key(row)
|
return self._row_to_api_key(row)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def list_keys(
|
def list_keys(
|
||||||
self,
|
self, owner_id: Optional[str] = None, status: Optional[str] = None, limit: int = 100, offset: int = 0
|
||||||
owner_id: Optional[str] = None,
|
|
||||||
status: Optional[str] = None,
|
|
||||||
limit: int = 100,
|
|
||||||
offset: int = 0
|
|
||||||
) -> List[ApiKey]:
|
) -> List[ApiKey]:
|
||||||
"""列出 API Keys"""
|
"""列出 API Keys"""
|
||||||
with sqlite3.connect(self.db_path) as conn:
|
with sqlite3.connect(self.db_path) as conn:
|
||||||
conn.row_factory = sqlite3.Row
|
conn.row_factory = sqlite3.Row
|
||||||
|
|
||||||
query = "SELECT * FROM api_keys WHERE 1=1"
|
query = "SELECT * FROM api_keys WHERE 1=1"
|
||||||
params = []
|
params = []
|
||||||
|
|
||||||
if owner_id:
|
if owner_id:
|
||||||
query += " AND owner_id = ?"
|
query += " AND owner_id = ?"
|
||||||
params.append(owner_id)
|
params.append(owner_id)
|
||||||
|
|
||||||
if status:
|
if status:
|
||||||
query += " AND status = ?"
|
query += " AND status = ?"
|
||||||
params.append(status)
|
params.append(status)
|
||||||
|
|
||||||
query += " ORDER BY created_at DESC LIMIT ? OFFSET ?"
|
query += " ORDER BY created_at DESC LIMIT ? OFFSET ?"
|
||||||
params.extend([limit, offset])
|
params.extend([limit, offset])
|
||||||
|
|
||||||
rows = conn.execute(query, params).fetchall()
|
rows = conn.execute(query, params).fetchall()
|
||||||
return [self._row_to_api_key(row) for row in rows]
|
return [self._row_to_api_key(row) for row in rows]
|
||||||
|
|
||||||
def update_key(
|
def update_key(
|
||||||
self,
|
self,
|
||||||
key_id: str,
|
key_id: str,
|
||||||
name: Optional[str] = None,
|
name: Optional[str] = None,
|
||||||
permissions: Optional[List[str]] = None,
|
permissions: Optional[List[str]] = None,
|
||||||
rate_limit: Optional[int] = None,
|
rate_limit: Optional[int] = None,
|
||||||
owner_id: Optional[str] = None
|
owner_id: Optional[str] = None,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""更新 API Key 信息"""
|
"""更新 API Key 信息"""
|
||||||
updates = []
|
updates = []
|
||||||
params = []
|
params = []
|
||||||
|
|
||||||
if name is not None:
|
if name is not None:
|
||||||
updates.append("name = ?")
|
updates.append("name = ?")
|
||||||
params.append(name)
|
params.append(name)
|
||||||
|
|
||||||
if permissions is not None:
|
if permissions is not None:
|
||||||
updates.append("permissions = ?")
|
updates.append("permissions = ?")
|
||||||
params.append(json.dumps(permissions))
|
params.append(json.dumps(permissions))
|
||||||
|
|
||||||
if rate_limit is not None:
|
if rate_limit is not None:
|
||||||
updates.append("rate_limit = ?")
|
updates.append("rate_limit = ?")
|
||||||
params.append(rate_limit)
|
params.append(rate_limit)
|
||||||
|
|
||||||
if not updates:
|
if not updates:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
params.append(key_id)
|
params.append(key_id)
|
||||||
|
|
||||||
with sqlite3.connect(self.db_path) as conn:
|
with sqlite3.connect(self.db_path) as conn:
|
||||||
# 验证所有权
|
# 验证所有权
|
||||||
if owner_id:
|
if owner_id:
|
||||||
row = conn.execute(
|
row = conn.execute("SELECT owner_id FROM api_keys WHERE id = ?", (key_id,)).fetchone()
|
||||||
"SELECT owner_id FROM api_keys WHERE id = ?",
|
|
||||||
(key_id,)
|
|
||||||
).fetchone()
|
|
||||||
if not row or row[0] != owner_id:
|
if not row or row[0] != owner_id:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
query = f"UPDATE api_keys SET {', '.join(updates)} WHERE id = ?"
|
query = f"UPDATE api_keys SET {', '.join(updates)} WHERE id = ?"
|
||||||
cursor = conn.execute(query, params)
|
cursor = conn.execute(query, params)
|
||||||
conn.commit()
|
conn.commit()
|
||||||
return cursor.rowcount > 0
|
return cursor.rowcount > 0
|
||||||
|
|
||||||
def update_last_used(self, key_id: str):
|
def update_last_used(self, key_id: str):
|
||||||
"""更新最后使用时间"""
|
"""更新最后使用时间"""
|
||||||
with sqlite3.connect(self.db_path) as conn:
|
with sqlite3.connect(self.db_path) as conn:
|
||||||
conn.execute("""
|
conn.execute(
|
||||||
UPDATE api_keys
|
"""
|
||||||
|
UPDATE api_keys
|
||||||
SET last_used_at = ?, total_calls = total_calls + 1
|
SET last_used_at = ?, total_calls = total_calls + 1
|
||||||
WHERE id = ?
|
WHERE id = ?
|
||||||
""", (datetime.now().isoformat(), key_id))
|
""",
|
||||||
|
(datetime.now().isoformat(), key_id),
|
||||||
|
)
|
||||||
conn.commit()
|
conn.commit()
|
||||||
|
|
||||||
def log_api_call(
|
def log_api_call(
|
||||||
self,
|
self,
|
||||||
api_key_id: str,
|
api_key_id: str,
|
||||||
@@ -368,66 +354,62 @@ class ApiKeyManager:
|
|||||||
response_time_ms: int = 0,
|
response_time_ms: int = 0,
|
||||||
ip_address: str = "",
|
ip_address: str = "",
|
||||||
user_agent: str = "",
|
user_agent: str = "",
|
||||||
error_message: str = ""
|
error_message: str = "",
|
||||||
):
|
):
|
||||||
"""记录 API 调用日志"""
|
"""记录 API 调用日志"""
|
||||||
with sqlite3.connect(self.db_path) as conn:
|
with sqlite3.connect(self.db_path) as conn:
|
||||||
conn.execute("""
|
conn.execute(
|
||||||
INSERT INTO api_call_logs
|
"""
|
||||||
(api_key_id, endpoint, method, status_code, response_time_ms,
|
INSERT INTO api_call_logs
|
||||||
|
(api_key_id, endpoint, method, status_code, response_time_ms,
|
||||||
ip_address, user_agent, error_message)
|
ip_address, user_agent, error_message)
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
""", (
|
""",
|
||||||
api_key_id, endpoint, method, status_code, response_time_ms,
|
(api_key_id, endpoint, method, status_code, response_time_ms, ip_address, user_agent, error_message),
|
||||||
ip_address, user_agent, error_message
|
)
|
||||||
))
|
|
||||||
conn.commit()
|
conn.commit()
|
||||||
|
|
||||||
def get_call_logs(
|
def get_call_logs(
|
||||||
self,
|
self,
|
||||||
api_key_id: Optional[str] = None,
|
api_key_id: Optional[str] = None,
|
||||||
start_date: Optional[str] = None,
|
start_date: Optional[str] = None,
|
||||||
end_date: Optional[str] = None,
|
end_date: Optional[str] = None,
|
||||||
limit: int = 100,
|
limit: int = 100,
|
||||||
offset: int = 0
|
offset: int = 0,
|
||||||
) -> List[Dict]:
|
) -> List[Dict]:
|
||||||
"""获取 API 调用日志"""
|
"""获取 API 调用日志"""
|
||||||
with sqlite3.connect(self.db_path) as conn:
|
with sqlite3.connect(self.db_path) as conn:
|
||||||
conn.row_factory = sqlite3.Row
|
conn.row_factory = sqlite3.Row
|
||||||
|
|
||||||
query = "SELECT * FROM api_call_logs WHERE 1=1"
|
query = "SELECT * FROM api_call_logs WHERE 1=1"
|
||||||
params = []
|
params = []
|
||||||
|
|
||||||
if api_key_id:
|
if api_key_id:
|
||||||
query += " AND api_key_id = ?"
|
query += " AND api_key_id = ?"
|
||||||
params.append(api_key_id)
|
params.append(api_key_id)
|
||||||
|
|
||||||
if start_date:
|
if start_date:
|
||||||
query += " AND created_at >= ?"
|
query += " AND created_at >= ?"
|
||||||
params.append(start_date)
|
params.append(start_date)
|
||||||
|
|
||||||
if end_date:
|
if end_date:
|
||||||
query += " AND created_at <= ?"
|
query += " AND created_at <= ?"
|
||||||
params.append(end_date)
|
params.append(end_date)
|
||||||
|
|
||||||
query += " ORDER BY created_at DESC LIMIT ? OFFSET ?"
|
query += " ORDER BY created_at DESC LIMIT ? OFFSET ?"
|
||||||
params.extend([limit, offset])
|
params.extend([limit, offset])
|
||||||
|
|
||||||
rows = conn.execute(query, params).fetchall()
|
rows = conn.execute(query, params).fetchall()
|
||||||
return [dict(row) for row in rows]
|
return [dict(row) for row in rows]
|
||||||
|
|
||||||
def get_call_stats(
|
def get_call_stats(self, api_key_id: Optional[str] = None, days: int = 30) -> Dict:
|
||||||
self,
|
|
||||||
api_key_id: Optional[str] = None,
|
|
||||||
days: int = 30
|
|
||||||
) -> Dict:
|
|
||||||
"""获取 API 调用统计"""
|
"""获取 API 调用统计"""
|
||||||
with sqlite3.connect(self.db_path) as conn:
|
with sqlite3.connect(self.db_path) as conn:
|
||||||
conn.row_factory = sqlite3.Row
|
conn.row_factory = sqlite3.Row
|
||||||
|
|
||||||
# 总体统计
|
# 总体统计
|
||||||
query = """
|
query = """
|
||||||
SELECT
|
SELECT
|
||||||
COUNT(*) as total_calls,
|
COUNT(*) as total_calls,
|
||||||
COUNT(CASE WHEN status_code < 400 THEN 1 END) as success_calls,
|
COUNT(CASE WHEN status_code < 400 THEN 1 END) as success_calls,
|
||||||
COUNT(CASE WHEN status_code >= 400 THEN 1 END) as error_calls,
|
COUNT(CASE WHEN status_code >= 400 THEN 1 END) as error_calls,
|
||||||
@@ -437,17 +419,17 @@ class ApiKeyManager:
|
|||||||
FROM api_call_logs
|
FROM api_call_logs
|
||||||
WHERE created_at >= date('now', '-{} days')
|
WHERE created_at >= date('now', '-{} days')
|
||||||
""".format(days)
|
""".format(days)
|
||||||
|
|
||||||
params = []
|
params = []
|
||||||
if api_key_id:
|
if api_key_id:
|
||||||
query = query.replace("WHERE created_at", "WHERE api_key_id = ? AND created_at")
|
query = query.replace("WHERE created_at", "WHERE api_key_id = ? AND created_at")
|
||||||
params.insert(0, api_key_id)
|
params.insert(0, api_key_id)
|
||||||
|
|
||||||
row = conn.execute(query, params).fetchone()
|
row = conn.execute(query, params).fetchone()
|
||||||
|
|
||||||
# 按端点统计
|
# 按端点统计
|
||||||
endpoint_query = """
|
endpoint_query = """
|
||||||
SELECT
|
SELECT
|
||||||
endpoint,
|
endpoint,
|
||||||
method,
|
method,
|
||||||
COUNT(*) as calls,
|
COUNT(*) as calls,
|
||||||
@@ -455,35 +437,35 @@ class ApiKeyManager:
|
|||||||
FROM api_call_logs
|
FROM api_call_logs
|
||||||
WHERE created_at >= date('now', '-{} days')
|
WHERE created_at >= date('now', '-{} days')
|
||||||
""".format(days)
|
""".format(days)
|
||||||
|
|
||||||
endpoint_params = []
|
endpoint_params = []
|
||||||
if api_key_id:
|
if api_key_id:
|
||||||
endpoint_query = endpoint_query.replace("WHERE created_at", "WHERE api_key_id = ? AND created_at")
|
endpoint_query = endpoint_query.replace("WHERE created_at", "WHERE api_key_id = ? AND created_at")
|
||||||
endpoint_params.insert(0, api_key_id)
|
endpoint_params.insert(0, api_key_id)
|
||||||
|
|
||||||
endpoint_query += " GROUP BY endpoint, method ORDER BY calls DESC"
|
endpoint_query += " GROUP BY endpoint, method ORDER BY calls DESC"
|
||||||
|
|
||||||
endpoint_rows = conn.execute(endpoint_query, endpoint_params).fetchall()
|
endpoint_rows = conn.execute(endpoint_query, endpoint_params).fetchall()
|
||||||
|
|
||||||
# 按天统计
|
# 按天统计
|
||||||
daily_query = """
|
daily_query = """
|
||||||
SELECT
|
SELECT
|
||||||
date(created_at) as date,
|
date(created_at) as date,
|
||||||
COUNT(*) as calls,
|
COUNT(*) as calls,
|
||||||
COUNT(CASE WHEN status_code < 400 THEN 1 END) as success
|
COUNT(CASE WHEN status_code < 400 THEN 1 END) as success
|
||||||
FROM api_call_logs
|
FROM api_call_logs
|
||||||
WHERE created_at >= date('now', '-{} days')
|
WHERE created_at >= date('now', '-{} days')
|
||||||
""".format(days)
|
""".format(days)
|
||||||
|
|
||||||
daily_params = []
|
daily_params = []
|
||||||
if api_key_id:
|
if api_key_id:
|
||||||
daily_query = daily_query.replace("WHERE created_at", "WHERE api_key_id = ? AND created_at")
|
daily_query = daily_query.replace("WHERE created_at", "WHERE api_key_id = ? AND created_at")
|
||||||
daily_params.insert(0, api_key_id)
|
daily_params.insert(0, api_key_id)
|
||||||
|
|
||||||
daily_query += " GROUP BY date(created_at) ORDER BY date"
|
daily_query += " GROUP BY date(created_at) ORDER BY date"
|
||||||
|
|
||||||
daily_rows = conn.execute(daily_query, daily_params).fetchall()
|
daily_rows = conn.execute(daily_query, daily_params).fetchall()
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"summary": {
|
"summary": {
|
||||||
"total_calls": row["total_calls"] or 0,
|
"total_calls": row["total_calls"] or 0,
|
||||||
@@ -494,9 +476,9 @@ class ApiKeyManager:
|
|||||||
"min_response_time_ms": row["min_response_time"] or 0,
|
"min_response_time_ms": row["min_response_time"] or 0,
|
||||||
},
|
},
|
||||||
"endpoints": [dict(r) for r in endpoint_rows],
|
"endpoints": [dict(r) for r in endpoint_rows],
|
||||||
"daily": [dict(r) for r in daily_rows]
|
"daily": [dict(r) for r in daily_rows],
|
||||||
}
|
}
|
||||||
|
|
||||||
def _row_to_api_key(self, row: sqlite3.Row) -> ApiKey:
|
def _row_to_api_key(self, row: sqlite3.Row) -> ApiKey:
|
||||||
"""将数据库行转换为 ApiKey 对象"""
|
"""将数据库行转换为 ApiKey 对象"""
|
||||||
return ApiKey(
|
return ApiKey(
|
||||||
@@ -513,7 +495,7 @@ class ApiKeyManager:
|
|||||||
last_used_at=row["last_used_at"],
|
last_used_at=row["last_used_at"],
|
||||||
revoked_at=row["revoked_at"],
|
revoked_at=row["revoked_at"],
|
||||||
revoked_reason=row["revoked_reason"],
|
revoked_reason=row["revoked_reason"],
|
||||||
total_calls=row["total_calls"]
|
total_calls=row["total_calls"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -6,66 +6,65 @@ Document Processor - Phase 3
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
import io
|
import io
|
||||||
from typing import Dict, Optional
|
from typing import Dict
|
||||||
|
|
||||||
|
|
||||||
class DocumentProcessor:
|
class DocumentProcessor:
|
||||||
"""文档处理器 - 提取 PDF/DOCX 文本"""
|
"""文档处理器 - 提取 PDF/DOCX 文本"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.supported_formats = {
|
self.supported_formats = {
|
||||||
'.pdf': self._extract_pdf,
|
".pdf": self._extract_pdf,
|
||||||
'.docx': self._extract_docx,
|
".docx": self._extract_docx,
|
||||||
'.doc': self._extract_docx,
|
".doc": self._extract_docx,
|
||||||
'.txt': self._extract_txt,
|
".txt": self._extract_txt,
|
||||||
'.md': self._extract_txt,
|
".md": self._extract_txt,
|
||||||
}
|
}
|
||||||
|
|
||||||
def process(self, content: bytes, filename: str) -> Dict[str, str]:
|
def process(self, content: bytes, filename: str) -> Dict[str, str]:
|
||||||
"""
|
"""
|
||||||
处理文档并提取文本
|
处理文档并提取文本
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
content: 文件二进制内容
|
content: 文件二进制内容
|
||||||
filename: 文件名
|
filename: 文件名
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
{"text": "提取的文本内容", "format": "文件格式"}
|
{"text": "提取的文本内容", "format": "文件格式"}
|
||||||
"""
|
"""
|
||||||
ext = os.path.splitext(filename.lower())[1]
|
ext = os.path.splitext(filename.lower())[1]
|
||||||
|
|
||||||
if ext not in self.supported_formats:
|
if ext not in self.supported_formats:
|
||||||
raise ValueError(f"Unsupported file format: {ext}. Supported: {list(self.supported_formats.keys())}")
|
raise ValueError(f"Unsupported file format: {ext}. Supported: {list(self.supported_formats.keys())}")
|
||||||
|
|
||||||
extractor = self.supported_formats[ext]
|
extractor = self.supported_formats[ext]
|
||||||
text = extractor(content)
|
text = extractor(content)
|
||||||
|
|
||||||
# 清理文本
|
# 清理文本
|
||||||
text = self._clean_text(text)
|
text = self._clean_text(text)
|
||||||
|
|
||||||
return {
|
return {"text": text, "format": ext, "filename": filename}
|
||||||
"text": text,
|
|
||||||
"format": ext,
|
|
||||||
"filename": filename
|
|
||||||
}
|
|
||||||
|
|
||||||
def _extract_pdf(self, content: bytes) -> str:
|
def _extract_pdf(self, content: bytes) -> str:
|
||||||
"""提取 PDF 文本"""
|
"""提取 PDF 文本"""
|
||||||
try:
|
try:
|
||||||
import PyPDF2
|
import PyPDF2
|
||||||
|
|
||||||
pdf_file = io.BytesIO(content)
|
pdf_file = io.BytesIO(content)
|
||||||
reader = PyPDF2.PdfReader(pdf_file)
|
reader = PyPDF2.PdfReader(pdf_file)
|
||||||
|
|
||||||
text_parts = []
|
text_parts = []
|
||||||
for page in reader.pages:
|
for page in reader.pages:
|
||||||
page_text = page.extract_text()
|
page_text = page.extract_text()
|
||||||
if page_text:
|
if page_text:
|
||||||
text_parts.append(page_text)
|
text_parts.append(page_text)
|
||||||
|
|
||||||
return "\n\n".join(text_parts)
|
return "\n\n".join(text_parts)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
# Fallback: 尝试使用 pdfplumber
|
# Fallback: 尝试使用 pdfplumber
|
||||||
try:
|
try:
|
||||||
import pdfplumber
|
import pdfplumber
|
||||||
|
|
||||||
text_parts = []
|
text_parts = []
|
||||||
with pdfplumber.open(io.BytesIO(content)) as pdf:
|
with pdfplumber.open(io.BytesIO(content)) as pdf:
|
||||||
for page in pdf.pages:
|
for page in pdf.pages:
|
||||||
@@ -77,19 +76,20 @@ class DocumentProcessor:
|
|||||||
raise ImportError("PDF processing requires PyPDF2 or pdfplumber. Install with: pip install PyPDF2")
|
raise ImportError("PDF processing requires PyPDF2 or pdfplumber. Install with: pip install PyPDF2")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ValueError(f"PDF extraction failed: {str(e)}")
|
raise ValueError(f"PDF extraction failed: {str(e)}")
|
||||||
|
|
||||||
def _extract_docx(self, content: bytes) -> str:
|
def _extract_docx(self, content: bytes) -> str:
|
||||||
"""提取 DOCX 文本"""
|
"""提取 DOCX 文本"""
|
||||||
try:
|
try:
|
||||||
import docx
|
import docx
|
||||||
|
|
||||||
doc_file = io.BytesIO(content)
|
doc_file = io.BytesIO(content)
|
||||||
doc = docx.Document(doc_file)
|
doc = docx.Document(doc_file)
|
||||||
|
|
||||||
text_parts = []
|
text_parts = []
|
||||||
for para in doc.paragraphs:
|
for para in doc.paragraphs:
|
||||||
if para.text.strip():
|
if para.text.strip():
|
||||||
text_parts.append(para.text)
|
text_parts.append(para.text)
|
||||||
|
|
||||||
# 提取表格中的文本
|
# 提取表格中的文本
|
||||||
for table in doc.tables:
|
for table in doc.tables:
|
||||||
for row in table.rows:
|
for row in table.rows:
|
||||||
@@ -99,53 +99,53 @@ class DocumentProcessor:
|
|||||||
row_text.append(cell.text.strip())
|
row_text.append(cell.text.strip())
|
||||||
if row_text:
|
if row_text:
|
||||||
text_parts.append(" | ".join(row_text))
|
text_parts.append(" | ".join(row_text))
|
||||||
|
|
||||||
return "\n\n".join(text_parts)
|
return "\n\n".join(text_parts)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ImportError("DOCX processing requires python-docx. Install with: pip install python-docx")
|
raise ImportError("DOCX processing requires python-docx. Install with: pip install python-docx")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ValueError(f"DOCX extraction failed: {str(e)}")
|
raise ValueError(f"DOCX extraction failed: {str(e)}")
|
||||||
|
|
||||||
def _extract_txt(self, content: bytes) -> str:
|
def _extract_txt(self, content: bytes) -> str:
|
||||||
"""提取纯文本"""
|
"""提取纯文本"""
|
||||||
# 尝试多种编码
|
# 尝试多种编码
|
||||||
encodings = ['utf-8', 'gbk', 'gb2312', 'latin-1']
|
encodings = ["utf-8", "gbk", "gb2312", "latin-1"]
|
||||||
|
|
||||||
for encoding in encodings:
|
for encoding in encodings:
|
||||||
try:
|
try:
|
||||||
return content.decode(encoding)
|
return content.decode(encoding)
|
||||||
except UnicodeDecodeError:
|
except UnicodeDecodeError:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 如果都失败了,使用 latin-1 并忽略错误
|
# 如果都失败了,使用 latin-1 并忽略错误
|
||||||
return content.decode('latin-1', errors='ignore')
|
return content.decode("latin-1", errors="ignore")
|
||||||
|
|
||||||
def _clean_text(self, text: str) -> str:
|
def _clean_text(self, text: str) -> str:
|
||||||
"""清理提取的文本"""
|
"""清理提取的文本"""
|
||||||
if not text:
|
if not text:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
# 移除多余的空白字符
|
# 移除多余的空白字符
|
||||||
lines = text.split('\n')
|
lines = text.split("\n")
|
||||||
cleaned_lines = []
|
cleaned_lines = []
|
||||||
|
|
||||||
for line in lines:
|
for line in lines:
|
||||||
line = line.strip()
|
line = line.strip()
|
||||||
# 移除空行,但保留段落分隔
|
# 移除空行,但保留段落分隔
|
||||||
if line:
|
if line:
|
||||||
cleaned_lines.append(line)
|
cleaned_lines.append(line)
|
||||||
|
|
||||||
# 合并行,保留段落结构
|
# 合并行,保留段落结构
|
||||||
text = '\n\n'.join(cleaned_lines)
|
text = "\n\n".join(cleaned_lines)
|
||||||
|
|
||||||
# 移除多余的空格
|
# 移除多余的空格
|
||||||
text = ' '.join(text.split())
|
text = " ".join(text.split())
|
||||||
|
|
||||||
# 移除控制字符
|
# 移除控制字符
|
||||||
text = ''.join(char for char in text if ord(char) >= 32 or char in '\n\r\t')
|
text = "".join(char for char in text if ord(char) >= 32 or char in "\n\r\t")
|
||||||
|
|
||||||
return text.strip()
|
return text.strip()
|
||||||
|
|
||||||
def is_supported(self, filename: str) -> bool:
|
def is_supported(self, filename: str) -> bool:
|
||||||
"""检查文件格式是否支持"""
|
"""检查文件格式是否支持"""
|
||||||
ext = os.path.splitext(filename.lower())[1]
|
ext = os.path.splitext(filename.lower())[1]
|
||||||
@@ -155,26 +155,26 @@ class DocumentProcessor:
|
|||||||
# 简单的文本提取器(不需要外部依赖)
|
# 简单的文本提取器(不需要外部依赖)
|
||||||
class SimpleTextExtractor:
|
class SimpleTextExtractor:
|
||||||
"""简单的文本提取器,用于测试"""
|
"""简单的文本提取器,用于测试"""
|
||||||
|
|
||||||
def extract(self, content: bytes, filename: str) -> str:
|
def extract(self, content: bytes, filename: str) -> str:
|
||||||
"""尝试提取文本"""
|
"""尝试提取文本"""
|
||||||
encodings = ['utf-8', 'gbk', 'latin-1']
|
encodings = ["utf-8", "gbk", "latin-1"]
|
||||||
|
|
||||||
for encoding in encodings:
|
for encoding in encodings:
|
||||||
try:
|
try:
|
||||||
return content.decode(encoding)
|
return content.decode(encoding)
|
||||||
except UnicodeDecodeError:
|
except UnicodeDecodeError:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
return content.decode('latin-1', errors='ignore')
|
return content.decode("latin-1", errors="ignore")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# 测试
|
# 测试
|
||||||
processor = DocumentProcessor()
|
processor = DocumentProcessor()
|
||||||
|
|
||||||
# 测试文本提取
|
# 测试文本提取
|
||||||
test_text = "Hello World\n\nThis is a test document.\n\nMultiple paragraphs."
|
test_text = "Hello World\n\nThis is a test document.\n\nMultiple paragraphs."
|
||||||
result = processor.process(test_text.encode('utf-8'), "test.txt")
|
result = processor.process(test_text.encode("utf-8"), "test.txt")
|
||||||
print(f"Text extraction test: {len(result['text'])} chars")
|
print(f"Text extraction test: {len(result['text'])} chars")
|
||||||
print(result['text'][:100])
|
print(result["text"][:100])
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -15,6 +15,7 @@ from dataclasses import dataclass
|
|||||||
KIMI_API_KEY = os.getenv("KIMI_API_KEY", "")
|
KIMI_API_KEY = os.getenv("KIMI_API_KEY", "")
|
||||||
KIMI_BASE_URL = os.getenv("KIMI_BASE_URL", "https://api.kimi.com/coding")
|
KIMI_BASE_URL = os.getenv("KIMI_BASE_URL", "https://api.kimi.com/coding")
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class EntityEmbedding:
|
class EntityEmbedding:
|
||||||
entity_id: str
|
entity_id: str
|
||||||
@@ -22,177 +23,173 @@ class EntityEmbedding:
|
|||||||
definition: str
|
definition: str
|
||||||
embedding: List[float]
|
embedding: List[float]
|
||||||
|
|
||||||
|
|
||||||
class EntityAligner:
|
class EntityAligner:
|
||||||
"""实体对齐器 - 使用 embedding 进行相似度匹配"""
|
"""实体对齐器 - 使用 embedding 进行相似度匹配"""
|
||||||
|
|
||||||
def __init__(self, similarity_threshold: float = 0.85):
|
def __init__(self, similarity_threshold: float = 0.85):
|
||||||
self.similarity_threshold = similarity_threshold
|
self.similarity_threshold = similarity_threshold
|
||||||
self.embedding_cache: Dict[str, List[float]] = {}
|
self.embedding_cache: Dict[str, List[float]] = {}
|
||||||
|
|
||||||
def get_embedding(self, text: str) -> Optional[List[float]]:
|
def get_embedding(self, text: str) -> Optional[List[float]]:
|
||||||
"""
|
"""
|
||||||
使用 Kimi API 获取文本的 embedding
|
使用 Kimi API 获取文本的 embedding
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
text: 输入文本
|
text: 输入文本
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
embedding 向量或 None
|
embedding 向量或 None
|
||||||
"""
|
"""
|
||||||
if not KIMI_API_KEY:
|
if not KIMI_API_KEY:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# 检查缓存
|
# 检查缓存
|
||||||
cache_key = hash(text)
|
cache_key = hash(text)
|
||||||
if cache_key in self.embedding_cache:
|
if cache_key in self.embedding_cache:
|
||||||
return self.embedding_cache[cache_key]
|
return self.embedding_cache[cache_key]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = httpx.post(
|
response = httpx.post(
|
||||||
f"{KIMI_BASE_URL}/v1/embeddings",
|
f"{KIMI_BASE_URL}/v1/embeddings",
|
||||||
headers={"Authorization": f"Bearer {KIMI_API_KEY}", "Content-Type": "application/json"},
|
headers={"Authorization": f"Bearer {KIMI_API_KEY}", "Content-Type": "application/json"},
|
||||||
json={
|
json={"model": "k2p5", "input": text[:500]}, # 限制长度
|
||||||
"model": "k2p5",
|
timeout=30.0,
|
||||||
"input": text[:500] # 限制长度
|
|
||||||
},
|
|
||||||
timeout=30.0
|
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
result = response.json()
|
result = response.json()
|
||||||
|
|
||||||
embedding = result["data"][0]["embedding"]
|
embedding = result["data"][0]["embedding"]
|
||||||
self.embedding_cache[cache_key] = embedding
|
self.embedding_cache[cache_key] = embedding
|
||||||
return embedding
|
return embedding
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Embedding API failed: {e}")
|
print(f"Embedding API failed: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def compute_similarity(self, embedding1: List[float], embedding2: List[float]) -> float:
|
def compute_similarity(self, embedding1: List[float], embedding2: List[float]) -> float:
|
||||||
"""
|
"""
|
||||||
计算两个 embedding 的余弦相似度
|
计算两个 embedding 的余弦相似度
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
embedding1: 第一个向量
|
embedding1: 第一个向量
|
||||||
embedding2: 第二个向量
|
embedding2: 第二个向量
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
相似度分数 (0-1)
|
相似度分数 (0-1)
|
||||||
"""
|
"""
|
||||||
vec1 = np.array(embedding1)
|
vec1 = np.array(embedding1)
|
||||||
vec2 = np.array(embedding2)
|
vec2 = np.array(embedding2)
|
||||||
|
|
||||||
# 余弦相似度
|
# 余弦相似度
|
||||||
dot_product = np.dot(vec1, vec2)
|
dot_product = np.dot(vec1, vec2)
|
||||||
norm1 = np.linalg.norm(vec1)
|
norm1 = np.linalg.norm(vec1)
|
||||||
norm2 = np.linalg.norm(vec2)
|
norm2 = np.linalg.norm(vec2)
|
||||||
|
|
||||||
if norm1 == 0 or norm2 == 0:
|
if norm1 == 0 or norm2 == 0:
|
||||||
return 0.0
|
return 0.0
|
||||||
|
|
||||||
return float(dot_product / (norm1 * norm2))
|
return float(dot_product / (norm1 * norm2))
|
||||||
|
|
||||||
def get_entity_text(self, name: str, definition: str = "") -> str:
|
def get_entity_text(self, name: str, definition: str = "") -> str:
|
||||||
"""
|
"""
|
||||||
构建用于 embedding 的实体文本
|
构建用于 embedding 的实体文本
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
name: 实体名称
|
name: 实体名称
|
||||||
definition: 实体定义
|
definition: 实体定义
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
组合文本
|
组合文本
|
||||||
"""
|
"""
|
||||||
if definition:
|
if definition:
|
||||||
return f"{name}: {definition}"
|
return f"{name}: {definition}"
|
||||||
return name
|
return name
|
||||||
|
|
||||||
def find_similar_entity(
|
def find_similar_entity(
|
||||||
self,
|
self,
|
||||||
project_id: str,
|
project_id: str,
|
||||||
name: str,
|
name: str,
|
||||||
definition: str = "",
|
definition: str = "",
|
||||||
exclude_id: Optional[str] = None,
|
exclude_id: Optional[str] = None,
|
||||||
threshold: Optional[float] = None
|
threshold: Optional[float] = None,
|
||||||
) -> Optional[object]:
|
) -> Optional[object]:
|
||||||
"""
|
"""
|
||||||
查找相似的实体
|
查找相似的实体
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
project_id: 项目 ID
|
project_id: 项目 ID
|
||||||
name: 实体名称
|
name: 实体名称
|
||||||
definition: 实体定义
|
definition: 实体定义
|
||||||
exclude_id: 要排除的实体 ID
|
exclude_id: 要排除的实体 ID
|
||||||
threshold: 相似度阈值
|
threshold: 相似度阈值
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
相似的实体或 None
|
相似的实体或 None
|
||||||
"""
|
"""
|
||||||
if threshold is None:
|
if threshold is None:
|
||||||
threshold = self.similarity_threshold
|
threshold = self.similarity_threshold
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from db_manager import get_db_manager
|
from db_manager import get_db_manager
|
||||||
|
|
||||||
db = get_db_manager()
|
db = get_db_manager()
|
||||||
except ImportError:
|
except ImportError:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# 获取项目的所有实体
|
# 获取项目的所有实体
|
||||||
entities = db.get_all_entities_for_embedding(project_id)
|
entities = db.get_all_entities_for_embedding(project_id)
|
||||||
|
|
||||||
if not entities:
|
if not entities:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# 获取查询实体的 embedding
|
# 获取查询实体的 embedding
|
||||||
query_text = self.get_entity_text(name, definition)
|
query_text = self.get_entity_text(name, definition)
|
||||||
query_embedding = self.get_embedding(query_text)
|
query_embedding = self.get_embedding(query_text)
|
||||||
|
|
||||||
if query_embedding is None:
|
if query_embedding is None:
|
||||||
# 如果 embedding API 失败,回退到简单匹配
|
# 如果 embedding API 失败,回退到简单匹配
|
||||||
return self._fallback_similarity_match(entities, name, exclude_id)
|
return self._fallback_similarity_match(entities, name, exclude_id)
|
||||||
|
|
||||||
best_match = None
|
best_match = None
|
||||||
best_score = threshold
|
best_score = threshold
|
||||||
|
|
||||||
for entity in entities:
|
for entity in entities:
|
||||||
if exclude_id and entity.id == exclude_id:
|
if exclude_id and entity.id == exclude_id:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 获取实体的 embedding
|
# 获取实体的 embedding
|
||||||
entity_text = self.get_entity_text(entity.name, entity.definition)
|
entity_text = self.get_entity_text(entity.name, entity.definition)
|
||||||
entity_embedding = self.get_embedding(entity_text)
|
entity_embedding = self.get_embedding(entity_text)
|
||||||
|
|
||||||
if entity_embedding is None:
|
if entity_embedding is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 计算相似度
|
# 计算相似度
|
||||||
similarity = self.compute_similarity(query_embedding, entity_embedding)
|
similarity = self.compute_similarity(query_embedding, entity_embedding)
|
||||||
|
|
||||||
if similarity > best_score:
|
if similarity > best_score:
|
||||||
best_score = similarity
|
best_score = similarity
|
||||||
best_match = entity
|
best_match = entity
|
||||||
|
|
||||||
return best_match
|
return best_match
|
||||||
|
|
||||||
def _fallback_similarity_match(
|
def _fallback_similarity_match(
|
||||||
self,
|
self, entities: List[object], name: str, exclude_id: Optional[str] = None
|
||||||
entities: List[object],
|
|
||||||
name: str,
|
|
||||||
exclude_id: Optional[str] = None
|
|
||||||
) -> Optional[object]:
|
) -> Optional[object]:
|
||||||
"""
|
"""
|
||||||
回退到简单的相似度匹配(不使用 embedding)
|
回退到简单的相似度匹配(不使用 embedding)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
entities: 实体列表
|
entities: 实体列表
|
||||||
name: 查询名称
|
name: 查询名称
|
||||||
exclude_id: 要排除的实体 ID
|
exclude_id: 要排除的实体 ID
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
最相似的实体或 None
|
最相似的实体或 None
|
||||||
"""
|
"""
|
||||||
name_lower = name.lower()
|
name_lower = name.lower()
|
||||||
|
|
||||||
# 1. 精确匹配
|
# 1. 精确匹配
|
||||||
for entity in entities:
|
for entity in entities:
|
||||||
if exclude_id and entity.id == exclude_id:
|
if exclude_id and entity.id == exclude_id:
|
||||||
@@ -201,90 +198,79 @@ class EntityAligner:
|
|||||||
return entity
|
return entity
|
||||||
if entity.aliases and name_lower in [a.lower() for a in entity.aliases]:
|
if entity.aliases and name_lower in [a.lower() for a in entity.aliases]:
|
||||||
return entity
|
return entity
|
||||||
|
|
||||||
# 2. 包含匹配
|
# 2. 包含匹配
|
||||||
for entity in entities:
|
for entity in entities:
|
||||||
if exclude_id and entity.id == exclude_id:
|
if exclude_id and entity.id == exclude_id:
|
||||||
continue
|
continue
|
||||||
if name_lower in entity.name.lower() or entity.name.lower() in name_lower:
|
if name_lower in entity.name.lower() or entity.name.lower() in name_lower:
|
||||||
return entity
|
return entity
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def batch_align_entities(
|
def batch_align_entities(
|
||||||
self,
|
self, project_id: str, new_entities: List[Dict], threshold: Optional[float] = None
|
||||||
project_id: str,
|
|
||||||
new_entities: List[Dict],
|
|
||||||
threshold: Optional[float] = None
|
|
||||||
) -> List[Dict]:
|
) -> List[Dict]:
|
||||||
"""
|
"""
|
||||||
批量对齐实体
|
批量对齐实体
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
project_id: 项目 ID
|
project_id: 项目 ID
|
||||||
new_entities: 新实体列表 [{"name": "...", "definition": "..."}]
|
new_entities: 新实体列表 [{"name": "...", "definition": "..."}]
|
||||||
threshold: 相似度阈值
|
threshold: 相似度阈值
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
对齐结果列表 [{"new_entity": {...}, "matched_entity": {...}, "similarity": 0.9}]
|
对齐结果列表 [{"new_entity": {...}, "matched_entity": {...}, "similarity": 0.9}]
|
||||||
"""
|
"""
|
||||||
if threshold is None:
|
if threshold is None:
|
||||||
threshold = self.similarity_threshold
|
threshold = self.similarity_threshold
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
|
|
||||||
for new_ent in new_entities:
|
for new_ent in new_entities:
|
||||||
matched = self.find_similar_entity(
|
matched = self.find_similar_entity(
|
||||||
project_id,
|
project_id, new_ent["name"], new_ent.get("definition", ""), threshold=threshold
|
||||||
new_ent["name"],
|
|
||||||
new_ent.get("definition", ""),
|
|
||||||
threshold=threshold
|
|
||||||
)
|
)
|
||||||
|
|
||||||
result = {
|
result = {"new_entity": new_ent, "matched_entity": None, "similarity": 0.0, "should_merge": False}
|
||||||
"new_entity": new_ent,
|
|
||||||
"matched_entity": None,
|
|
||||||
"similarity": 0.0,
|
|
||||||
"should_merge": False
|
|
||||||
}
|
|
||||||
|
|
||||||
if matched:
|
if matched:
|
||||||
# 计算相似度
|
# 计算相似度
|
||||||
query_text = self.get_entity_text(new_ent["name"], new_ent.get("definition", ""))
|
query_text = self.get_entity_text(new_ent["name"], new_ent.get("definition", ""))
|
||||||
matched_text = self.get_entity_text(matched.name, matched.definition)
|
matched_text = self.get_entity_text(matched.name, matched.definition)
|
||||||
|
|
||||||
query_emb = self.get_embedding(query_text)
|
query_emb = self.get_embedding(query_text)
|
||||||
matched_emb = self.get_embedding(matched_text)
|
matched_emb = self.get_embedding(matched_text)
|
||||||
|
|
||||||
if query_emb and matched_emb:
|
if query_emb and matched_emb:
|
||||||
similarity = self.compute_similarity(query_emb, matched_emb)
|
similarity = self.compute_similarity(query_emb, matched_emb)
|
||||||
result["matched_entity"] = {
|
result["matched_entity"] = {
|
||||||
"id": matched.id,
|
"id": matched.id,
|
||||||
"name": matched.name,
|
"name": matched.name,
|
||||||
"type": matched.type,
|
"type": matched.type,
|
||||||
"definition": matched.definition
|
"definition": matched.definition,
|
||||||
}
|
}
|
||||||
result["similarity"] = similarity
|
result["similarity"] = similarity
|
||||||
result["should_merge"] = similarity >= threshold
|
result["should_merge"] = similarity >= threshold
|
||||||
|
|
||||||
results.append(result)
|
results.append(result)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def suggest_entity_aliases(self, entity_name: str, entity_definition: str = "") -> List[str]:
|
def suggest_entity_aliases(self, entity_name: str, entity_definition: str = "") -> List[str]:
|
||||||
"""
|
"""
|
||||||
使用 LLM 建议实体的别名
|
使用 LLM 建议实体的别名
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
entity_name: 实体名称
|
entity_name: 实体名称
|
||||||
entity_definition: 实体定义
|
entity_definition: 实体定义
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
建议的别名列表
|
建议的别名列表
|
||||||
"""
|
"""
|
||||||
if not KIMI_API_KEY:
|
if not KIMI_API_KEY:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
prompt = f"""为以下实体生成可能的别名或简称:
|
prompt = f"""为以下实体生成可能的别名或简称:
|
||||||
|
|
||||||
实体名称:{entity_name}
|
实体名称:{entity_name}
|
||||||
@@ -294,30 +280,27 @@ class EntityAligner:
|
|||||||
{{"aliases": ["别名1", "别名2", "别名3"]}}
|
{{"aliases": ["别名1", "别名2", "别名3"]}}
|
||||||
|
|
||||||
只返回 JSON,不要其他内容。"""
|
只返回 JSON,不要其他内容。"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = httpx.post(
|
response = httpx.post(
|
||||||
f"{KIMI_BASE_URL}/v1/chat/completions",
|
f"{KIMI_BASE_URL}/v1/chat/completions",
|
||||||
headers={"Authorization": f"Bearer {KIMI_API_KEY}", "Content-Type": "application/json"},
|
headers={"Authorization": f"Bearer {KIMI_API_KEY}", "Content-Type": "application/json"},
|
||||||
json={
|
json={"model": "k2p5", "messages": [{"role": "user", "content": prompt}], "temperature": 0.3},
|
||||||
"model": "k2p5",
|
timeout=30.0,
|
||||||
"messages": [{"role": "user", "content": prompt}],
|
|
||||||
"temperature": 0.3
|
|
||||||
},
|
|
||||||
timeout=30.0
|
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
result = response.json()
|
result = response.json()
|
||||||
content = result["choices"][0]["message"]["content"]
|
content = result["choices"][0]["message"]["content"]
|
||||||
|
|
||||||
import re
|
import re
|
||||||
json_match = re.search(r'\{{.*?\}}', content, re.DOTALL)
|
|
||||||
|
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
|
||||||
if json_match:
|
if json_match:
|
||||||
data = json.loads(json_match.group())
|
data = json.loads(json_match.group())
|
||||||
return data.get("aliases", [])
|
return data.get("aliases", [])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Alias suggestion failed: {e}")
|
print(f"Alias suggestion failed: {e}")
|
||||||
|
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
@@ -325,37 +308,38 @@ class EntityAligner:
|
|||||||
def simple_similarity(str1: str, str2: str) -> float:
|
def simple_similarity(str1: str, str2: str) -> float:
|
||||||
"""
|
"""
|
||||||
计算两个字符串的简单相似度
|
计算两个字符串的简单相似度
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
str1: 第一个字符串
|
str1: 第一个字符串
|
||||||
str2: 第二个字符串
|
str2: 第二个字符串
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
相似度分数 (0-1)
|
相似度分数 (0-1)
|
||||||
"""
|
"""
|
||||||
if str1 == str2:
|
if str1 == str2:
|
||||||
return 1.0
|
return 1.0
|
||||||
|
|
||||||
if not str1 or not str2:
|
if not str1 or not str2:
|
||||||
return 0.0
|
return 0.0
|
||||||
|
|
||||||
# 转换为小写
|
# 转换为小写
|
||||||
s1 = str1.lower()
|
s1 = str1.lower()
|
||||||
s2 = str2.lower()
|
s2 = str2.lower()
|
||||||
|
|
||||||
# 包含关系
|
# 包含关系
|
||||||
if s1 in s2 or s2 in s1:
|
if s1 in s2 or s2 in s1:
|
||||||
return 0.8
|
return 0.8
|
||||||
|
|
||||||
# 计算编辑距离相似度
|
# 计算编辑距离相似度
|
||||||
from difflib import SequenceMatcher
|
from difflib import SequenceMatcher
|
||||||
|
|
||||||
return SequenceMatcher(None, s1, s2).ratio()
|
return SequenceMatcher(None, s1, s2).ratio()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# 测试
|
# 测试
|
||||||
aligner = EntityAligner()
|
aligner = EntityAligner()
|
||||||
|
|
||||||
# 测试 embedding
|
# 测试 embedding
|
||||||
test_text = "Kubernetes 容器编排平台"
|
test_text = "Kubernetes 容器编排平台"
|
||||||
embedding = aligner.get_embedding(test_text)
|
embedding = aligner.get_embedding(test_text)
|
||||||
@@ -364,7 +348,7 @@ if __name__ == "__main__":
|
|||||||
print(f"First 5 values: {embedding[:5]}")
|
print(f"First 5 values: {embedding[:5]}")
|
||||||
else:
|
else:
|
||||||
print("Embedding API not available")
|
print("Embedding API not available")
|
||||||
|
|
||||||
# 测试相似度计算
|
# 测试相似度计算
|
||||||
emb1 = [1.0, 0.0, 0.0]
|
emb1 = [1.0, 0.0, 0.0]
|
||||||
emb2 = [0.9, 0.1, 0.0]
|
emb2 = [0.9, 0.1, 0.0]
|
||||||
|
|||||||
@@ -3,16 +3,16 @@ InsightFlow Export Module - Phase 5
|
|||||||
支持导出知识图谱、项目报告、实体数据和转录文本
|
支持导出知识图谱、项目报告、实体数据和转录文本
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
|
||||||
import io
|
import io
|
||||||
import json
|
import json
|
||||||
import base64
|
import base64
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import List, Dict, Optional, Any
|
from typing import List, Dict, Any
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
|
||||||
PANDAS_AVAILABLE = True
|
PANDAS_AVAILABLE = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
PANDAS_AVAILABLE = False
|
PANDAS_AVAILABLE = False
|
||||||
@@ -23,8 +23,7 @@ try:
|
|||||||
from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle
|
from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle
|
||||||
from reportlab.lib.units import inch
|
from reportlab.lib.units import inch
|
||||||
from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer, Table, TableStyle, PageBreak
|
from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer, Table, TableStyle, PageBreak
|
||||||
from reportlab.pdfbase import pdfmetrics
|
|
||||||
from reportlab.pdfbase.ttfonts import TTFont
|
|
||||||
REPORTLAB_AVAILABLE = True
|
REPORTLAB_AVAILABLE = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
REPORTLAB_AVAILABLE = False
|
REPORTLAB_AVAILABLE = False
|
||||||
@@ -63,15 +62,16 @@ class ExportTranscript:
|
|||||||
|
|
||||||
class ExportManager:
|
class ExportManager:
|
||||||
"""导出管理器 - 处理各种导出需求"""
|
"""导出管理器 - 处理各种导出需求"""
|
||||||
|
|
||||||
def __init__(self, db_manager=None):
|
def __init__(self, db_manager=None):
|
||||||
self.db = db_manager
|
self.db = db_manager
|
||||||
|
|
||||||
def export_knowledge_graph_svg(self, project_id: str, entities: List[ExportEntity],
|
def export_knowledge_graph_svg(
|
||||||
relations: List[ExportRelation]) -> str:
|
self, project_id: str, entities: List[ExportEntity], relations: List[ExportRelation]
|
||||||
|
) -> str:
|
||||||
"""
|
"""
|
||||||
导出知识图谱为 SVG 格式
|
导出知识图谱为 SVG 格式
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
SVG 字符串
|
SVG 字符串
|
||||||
"""
|
"""
|
||||||
@@ -81,14 +81,14 @@ class ExportManager:
|
|||||||
center_x = width / 2
|
center_x = width / 2
|
||||||
center_y = height / 2
|
center_y = height / 2
|
||||||
radius = 300
|
radius = 300
|
||||||
|
|
||||||
# 按类型分组实体
|
# 按类型分组实体
|
||||||
entities_by_type = {}
|
entities_by_type = {}
|
||||||
for e in entities:
|
for e in entities:
|
||||||
if e.type not in entities_by_type:
|
if e.type not in entities_by_type:
|
||||||
entities_by_type[e.type] = []
|
entities_by_type[e.type] = []
|
||||||
entities_by_type[e.type].append(e)
|
entities_by_type[e.type].append(e)
|
||||||
|
|
||||||
# 颜色映射
|
# 颜色映射
|
||||||
type_colors = {
|
type_colors = {
|
||||||
"PERSON": "#FF6B6B",
|
"PERSON": "#FF6B6B",
|
||||||
@@ -98,37 +98,37 @@ class ExportManager:
|
|||||||
"TECHNOLOGY": "#FFEAA7",
|
"TECHNOLOGY": "#FFEAA7",
|
||||||
"EVENT": "#DDA0DD",
|
"EVENT": "#DDA0DD",
|
||||||
"CONCEPT": "#98D8C8",
|
"CONCEPT": "#98D8C8",
|
||||||
"default": "#BDC3C7"
|
"default": "#BDC3C7",
|
||||||
}
|
}
|
||||||
|
|
||||||
# 计算实体位置
|
# 计算实体位置
|
||||||
entity_positions = {}
|
entity_positions = {}
|
||||||
angle_step = 2 * 3.14159 / max(len(entities), 1)
|
angle_step = 2 * 3.14159 / max(len(entities), 1)
|
||||||
|
|
||||||
for i, entity in enumerate(entities):
|
for i, entity in enumerate(entities):
|
||||||
angle = i * angle_step
|
i * angle_step
|
||||||
x = center_x + radius * 0.8 * (i % 3 - 1) * 150 + (i // 3) * 50
|
x = center_x + radius * 0.8 * (i % 3 - 1) * 150 + (i // 3) * 50
|
||||||
y = center_y + radius * 0.6 * ((i % 6) - 3) * 80
|
y = center_y + radius * 0.6 * ((i % 6) - 3) * 80
|
||||||
entity_positions[entity.id] = (x, y)
|
entity_positions[entity.id] = (x, y)
|
||||||
|
|
||||||
# 生成 SVG
|
# 生成 SVG
|
||||||
svg_parts = [
|
svg_parts = [
|
||||||
f'<svg xmlns="http://www.w3.org/2000/svg" width="{width}" height="{height}" viewBox="0 0 {width} {height}">',
|
f'<svg xmlns="http://www.w3.org/2000/svg" width="{width}" height="{height}" viewBox="0 0 {width} {height}">',
|
||||||
'<defs>',
|
"<defs>",
|
||||||
' <marker id="arrowhead" markerWidth="10" markerHeight="7" refX="9" refY="3.5" orient="auto">',
|
' <marker id="arrowhead" markerWidth="10" markerHeight="7" refX="9" refY="3.5" orient="auto">',
|
||||||
' <polygon points="0 0, 10 3.5, 0 7" fill="#7f8c8d"/>',
|
' <polygon points="0 0, 10 3.5, 0 7" fill="#7f8c8d"/>',
|
||||||
' </marker>',
|
" </marker>",
|
||||||
'</defs>',
|
"</defs>",
|
||||||
f'<rect width="{width}" height="{height}" fill="#f8f9fa"/>',
|
f'<rect width="{width}" height="{height}" fill="#f8f9fa"/>',
|
||||||
f'<text x="{center_x}" y="30" text-anchor="middle" font-size="20" font-weight="bold" fill="#2c3e50">知识图谱 - {project_id}</text>',
|
f'<text x="{center_x}" y="30" text-anchor="middle" font-size="20" font-weight="bold" fill="#2c3e50">知识图谱 - {project_id}</text>',
|
||||||
]
|
]
|
||||||
|
|
||||||
# 绘制关系连线
|
# 绘制关系连线
|
||||||
for rel in relations:
|
for rel in relations:
|
||||||
if rel.source in entity_positions and rel.target in entity_positions:
|
if rel.source in entity_positions and rel.target in entity_positions:
|
||||||
x1, y1 = entity_positions[rel.source]
|
x1, y1 = entity_positions[rel.source]
|
||||||
x2, y2 = entity_positions[rel.target]
|
x2, y2 = entity_positions[rel.target]
|
||||||
|
|
||||||
# 计算箭头终点(避免覆盖节点)
|
# 计算箭头终点(避免覆盖节点)
|
||||||
dx = x2 - x1
|
dx = x2 - x1
|
||||||
dy = y2 - y1
|
dy = y2 - y1
|
||||||
@@ -137,115 +137,128 @@ class ExportManager:
|
|||||||
offset = 40
|
offset = 40
|
||||||
x2 = x2 - dx * offset / dist
|
x2 = x2 - dx * offset / dist
|
||||||
y2 = y2 - dy * offset / dist
|
y2 = y2 - dy * offset / dist
|
||||||
|
|
||||||
svg_parts.append(
|
svg_parts.append(
|
||||||
f'<line x1="{x1}" y1="{y1}" x2="{x2}" y2="{y2}" '
|
f'<line x1="{x1}" y1="{y1}" x2="{x2}" y2="{y2}" '
|
||||||
f'stroke="#7f8c8d" stroke-width="2" marker-end="url(#arrowhead)" opacity="0.6"/>'
|
f'stroke="#7f8c8d" stroke-width="2" marker-end="url(#arrowhead)" opacity="0.6"/>'
|
||||||
)
|
)
|
||||||
|
|
||||||
# 关系标签
|
# 关系标签
|
||||||
mid_x = (x1 + x2) / 2
|
mid_x = (x1 + x2) / 2
|
||||||
mid_y = (y1 + y2) / 2
|
mid_y = (y1 + y2) / 2
|
||||||
svg_parts.append(
|
svg_parts.append(
|
||||||
f'<rect x="{mid_x-30}" y="{mid_y-10}" width="60" height="20" '
|
f'<rect x="{mid_x - 30}" y="{mid_y - 10}" width="60" height="20" '
|
||||||
f'fill="white" stroke="#bdc3c7" rx="3"/>'
|
f'fill="white" stroke="#bdc3c7" rx="3"/>'
|
||||||
)
|
)
|
||||||
svg_parts.append(
|
svg_parts.append(
|
||||||
f'<text x="{mid_x}" y="{mid_y+5}" text-anchor="middle" '
|
f'<text x="{mid_x}" y="{mid_y + 5}" text-anchor="middle" '
|
||||||
f'font-size="10" fill="#2c3e50">{rel.relation_type}</text>'
|
f'font-size="10" fill="#2c3e50">{rel.relation_type}</text>'
|
||||||
)
|
)
|
||||||
|
|
||||||
# 绘制实体节点
|
# 绘制实体节点
|
||||||
for entity in entities:
|
for entity in entities:
|
||||||
if entity.id in entity_positions:
|
if entity.id in entity_positions:
|
||||||
x, y = entity_positions[entity.id]
|
x, y = entity_positions[entity.id]
|
||||||
color = type_colors.get(entity.type, type_colors["default"])
|
color = type_colors.get(entity.type, type_colors["default"])
|
||||||
|
|
||||||
# 节点圆圈
|
# 节点圆圈
|
||||||
svg_parts.append(
|
svg_parts.append(f'<circle cx="{x}" cy="{y}" r="35" fill="{color}" stroke="white" stroke-width="3"/>')
|
||||||
f'<circle cx="{x}" cy="{y}" r="35" fill="{color}" stroke="white" stroke-width="3"/>'
|
|
||||||
)
|
|
||||||
|
|
||||||
# 实体名称
|
# 实体名称
|
||||||
svg_parts.append(
|
svg_parts.append(
|
||||||
f'<text x="{x}" y="{y+5}" text-anchor="middle" font-size="12" '
|
f'<text x="{x}" y="{y + 5}" text-anchor="middle" font-size="12" '
|
||||||
f'font-weight="bold" fill="white">{entity.name[:8]}</text>'
|
f'font-weight="bold" fill="white">{entity.name[:8]}</text>'
|
||||||
)
|
)
|
||||||
|
|
||||||
# 实体类型
|
# 实体类型
|
||||||
svg_parts.append(
|
svg_parts.append(
|
||||||
f'<text x="{x}" y="{y+55}" text-anchor="middle" font-size="10" '
|
f'<text x="{x}" y="{y + 55}" text-anchor="middle" font-size="10" '
|
||||||
f'fill="#7f8c8d">{entity.type}</text>'
|
f'fill="#7f8c8d">{entity.type}</text>'
|
||||||
)
|
)
|
||||||
|
|
||||||
# 图例
|
# 图例
|
||||||
legend_x = width - 150
|
legend_x = width - 150
|
||||||
legend_y = 80
|
legend_y = 80
|
||||||
svg_parts.append(f'<rect x="{legend_x-10}" y="{legend_y-20}" width="140" height="{len(type_colors)*25+10}" fill="white" stroke="#bdc3c7" rx="5"/>')
|
svg_parts.append(f'<rect x="{
|
||||||
svg_parts.append(f'<text x="{legend_x}" y="{legend_y}" font-size="12" font-weight="bold" fill="#2c3e50">实体类型</text>')
|
legend_x -
|
||||||
|
10}" y="{
|
||||||
|
legend_y -
|
||||||
|
20}" width="140" height="{
|
||||||
|
len(type_colors) *
|
||||||
|
25 +
|
||||||
|
10}" fill="white" stroke="#bdc3c7" rx="5"/>')
|
||||||
|
svg_parts.append(
|
||||||
|
f'<text x="{legend_x}" y="{legend_y}" font-size="12" font-weight="bold" fill="#2c3e50">实体类型</text>'
|
||||||
|
)
|
||||||
|
|
||||||
for i, (etype, color) in enumerate(type_colors.items()):
|
for i, (etype, color) in enumerate(type_colors.items()):
|
||||||
if etype != "default":
|
if etype != "default":
|
||||||
y_pos = legend_y + 25 + i * 20
|
y_pos = legend_y + 25 + i * 20
|
||||||
svg_parts.append(f'<circle cx="{legend_x+10}" cy="{y_pos}" r="8" fill="{color}"/>')
|
svg_parts.append(f'<circle cx="{legend_x + 10}" cy="{y_pos}" r="8" fill="{color}"/>')
|
||||||
svg_parts.append(f'<text x="{legend_x+25}" y="{y_pos+4}" font-size="10" fill="#2c3e50">{etype}</text>')
|
svg_parts.append(f'<text x="{
|
||||||
|
legend_x +
|
||||||
svg_parts.append('</svg>')
|
25}" y="{
|
||||||
return '\n'.join(svg_parts)
|
y_pos +
|
||||||
|
4}" font-size="10" fill="#2c3e50">{etype}</text>')
|
||||||
def export_knowledge_graph_png(self, project_id: str, entities: List[ExportEntity],
|
|
||||||
relations: List[ExportRelation]) -> bytes:
|
svg_parts.append("</svg>")
|
||||||
|
return "\n".join(svg_parts)
|
||||||
|
|
||||||
|
def export_knowledge_graph_png(
|
||||||
|
self, project_id: str, entities: List[ExportEntity], relations: List[ExportRelation]
|
||||||
|
) -> bytes:
|
||||||
"""
|
"""
|
||||||
导出知识图谱为 PNG 格式
|
导出知识图谱为 PNG 格式
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
PNG 图像字节
|
PNG 图像字节
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
import cairosvg
|
import cairosvg
|
||||||
|
|
||||||
svg_content = self.export_knowledge_graph_svg(project_id, entities, relations)
|
svg_content = self.export_knowledge_graph_svg(project_id, entities, relations)
|
||||||
png_bytes = cairosvg.svg2png(bytestring=svg_content.encode('utf-8'))
|
png_bytes = cairosvg.svg2png(bytestring=svg_content.encode("utf-8"))
|
||||||
return png_bytes
|
return png_bytes
|
||||||
except ImportError:
|
except ImportError:
|
||||||
# 如果没有 cairosvg,返回 SVG 的 base64
|
# 如果没有 cairosvg,返回 SVG 的 base64
|
||||||
svg_content = self.export_knowledge_graph_svg(project_id, entities, relations)
|
svg_content = self.export_knowledge_graph_svg(project_id, entities, relations)
|
||||||
return base64.b64encode(svg_content.encode('utf-8'))
|
return base64.b64encode(svg_content.encode("utf-8"))
|
||||||
|
|
||||||
def export_entities_excel(self, entities: List[ExportEntity]) -> bytes:
|
def export_entities_excel(self, entities: List[ExportEntity]) -> bytes:
|
||||||
"""
|
"""
|
||||||
导出实体数据为 Excel 格式
|
导出实体数据为 Excel 格式
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Excel 文件字节
|
Excel 文件字节
|
||||||
"""
|
"""
|
||||||
if not PANDAS_AVAILABLE:
|
if not PANDAS_AVAILABLE:
|
||||||
raise ImportError("pandas is required for Excel export")
|
raise ImportError("pandas is required for Excel export")
|
||||||
|
|
||||||
# 准备数据
|
# 准备数据
|
||||||
data = []
|
data = []
|
||||||
for e in entities:
|
for e in entities:
|
||||||
row = {
|
row = {
|
||||||
'ID': e.id,
|
"ID": e.id,
|
||||||
'名称': e.name,
|
"名称": e.name,
|
||||||
'类型': e.type,
|
"类型": e.type,
|
||||||
'定义': e.definition,
|
"定义": e.definition,
|
||||||
'别名': ', '.join(e.aliases),
|
"别名": ", ".join(e.aliases),
|
||||||
'提及次数': e.mention_count
|
"提及次数": e.mention_count,
|
||||||
}
|
}
|
||||||
# 添加属性
|
# 添加属性
|
||||||
for attr_name, attr_value in e.attributes.items():
|
for attr_name, attr_value in e.attributes.items():
|
||||||
row[f'属性:{attr_name}'] = attr_value
|
row[f"属性:{attr_name}"] = attr_value
|
||||||
data.append(row)
|
data.append(row)
|
||||||
|
|
||||||
df = pd.DataFrame(data)
|
df = pd.DataFrame(data)
|
||||||
|
|
||||||
# 写入 Excel
|
# 写入 Excel
|
||||||
output = io.BytesIO()
|
output = io.BytesIO()
|
||||||
with pd.ExcelWriter(output, engine='openpyxl') as writer:
|
with pd.ExcelWriter(output, engine="openpyxl") as writer:
|
||||||
df.to_excel(writer, sheet_name='实体列表', index=False)
|
df.to_excel(writer, sheet_name="实体列表", index=False)
|
||||||
|
|
||||||
# 调整列宽
|
# 调整列宽
|
||||||
worksheet = writer.sheets['实体列表']
|
worksheet = writer.sheets["实体列表"]
|
||||||
for column in worksheet.columns:
|
for column in worksheet.columns:
|
||||||
max_length = 0
|
max_length = 0
|
||||||
column_letter = column[0].column_letter
|
column_letter = column[0].column_letter
|
||||||
@@ -253,67 +266,66 @@ class ExportManager:
|
|||||||
try:
|
try:
|
||||||
if len(str(cell.value)) > max_length:
|
if len(str(cell.value)) > max_length:
|
||||||
max_length = len(str(cell.value))
|
max_length = len(str(cell.value))
|
||||||
except:
|
except BaseException:
|
||||||
pass
|
pass
|
||||||
adjusted_width = min(max_length + 2, 50)
|
adjusted_width = min(max_length + 2, 50)
|
||||||
worksheet.column_dimensions[column_letter].width = adjusted_width
|
worksheet.column_dimensions[column_letter].width = adjusted_width
|
||||||
|
|
||||||
return output.getvalue()
|
return output.getvalue()
|
||||||
|
|
||||||
def export_entities_csv(self, entities: List[ExportEntity]) -> str:
|
def export_entities_csv(self, entities: List[ExportEntity]) -> str:
|
||||||
"""
|
"""
|
||||||
导出实体数据为 CSV 格式
|
导出实体数据为 CSV 格式
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
CSV 字符串
|
CSV 字符串
|
||||||
"""
|
"""
|
||||||
import csv
|
import csv
|
||||||
|
|
||||||
output = io.StringIO()
|
output = io.StringIO()
|
||||||
|
|
||||||
# 收集所有可能的属性列
|
# 收集所有可能的属性列
|
||||||
all_attrs = set()
|
all_attrs = set()
|
||||||
for e in entities:
|
for e in entities:
|
||||||
all_attrs.update(e.attributes.keys())
|
all_attrs.update(e.attributes.keys())
|
||||||
|
|
||||||
# 表头
|
# 表头
|
||||||
headers = ['ID', '名称', '类型', '定义', '别名', '提及次数'] + [f'属性:{a}' for a in sorted(all_attrs)]
|
headers = ["ID", "名称", "类型", "定义", "别名", "提及次数"] + [f"属性:{a}" for a in sorted(all_attrs)]
|
||||||
|
|
||||||
writer = csv.writer(output)
|
writer = csv.writer(output)
|
||||||
writer.writerow(headers)
|
writer.writerow(headers)
|
||||||
|
|
||||||
# 数据行
|
# 数据行
|
||||||
for e in entities:
|
for e in entities:
|
||||||
row = [e.id, e.name, e.type, e.definition, ', '.join(e.aliases), e.mention_count]
|
row = [e.id, e.name, e.type, e.definition, ", ".join(e.aliases), e.mention_count]
|
||||||
for attr in sorted(all_attrs):
|
for attr in sorted(all_attrs):
|
||||||
row.append(e.attributes.get(attr, ''))
|
row.append(e.attributes.get(attr, ""))
|
||||||
writer.writerow(row)
|
writer.writerow(row)
|
||||||
|
|
||||||
return output.getvalue()
|
return output.getvalue()
|
||||||
|
|
||||||
def export_relations_csv(self, relations: List[ExportRelation]) -> str:
|
def export_relations_csv(self, relations: List[ExportRelation]) -> str:
|
||||||
"""
|
"""
|
||||||
导出关系数据为 CSV 格式
|
导出关系数据为 CSV 格式
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
CSV 字符串
|
CSV 字符串
|
||||||
"""
|
"""
|
||||||
import csv
|
import csv
|
||||||
|
|
||||||
output = io.StringIO()
|
output = io.StringIO()
|
||||||
writer = csv.writer(output)
|
writer = csv.writer(output)
|
||||||
writer.writerow(['ID', '源实体', '目标实体', '关系类型', '置信度', '证据'])
|
writer.writerow(["ID", "源实体", "目标实体", "关系类型", "置信度", "证据"])
|
||||||
|
|
||||||
for r in relations:
|
for r in relations:
|
||||||
writer.writerow([r.id, r.source, r.target, r.relation_type, r.confidence, r.evidence])
|
writer.writerow([r.id, r.source, r.target, r.relation_type, r.confidence, r.evidence])
|
||||||
|
|
||||||
return output.getvalue()
|
return output.getvalue()
|
||||||
|
|
||||||
def export_transcript_markdown(self, transcript: ExportTranscript,
|
def export_transcript_markdown(self, transcript: ExportTranscript, entities_map: Dict[str, ExportEntity]) -> str:
|
||||||
entities_map: Dict[str, ExportEntity]) -> str:
|
|
||||||
"""
|
"""
|
||||||
导出转录文本为 Markdown 格式
|
导出转录文本为 Markdown 格式
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Markdown 字符串
|
Markdown 字符串
|
||||||
"""
|
"""
|
||||||
@@ -332,190 +344,196 @@ class ExportManager:
|
|||||||
"---",
|
"---",
|
||||||
"",
|
"",
|
||||||
]
|
]
|
||||||
|
|
||||||
if transcript.segments:
|
if transcript.segments:
|
||||||
lines.extend([
|
lines.extend(
|
||||||
"## 分段详情",
|
[
|
||||||
"",
|
"## 分段详情",
|
||||||
])
|
"",
|
||||||
|
]
|
||||||
|
)
|
||||||
for seg in transcript.segments:
|
for seg in transcript.segments:
|
||||||
speaker = seg.get('speaker', 'Unknown')
|
speaker = seg.get("speaker", "Unknown")
|
||||||
start = seg.get('start', 0)
|
start = seg.get("start", 0)
|
||||||
end = seg.get('end', 0)
|
end = seg.get("end", 0)
|
||||||
text = seg.get('text', '')
|
text = seg.get("text", "")
|
||||||
lines.append(f"**[{start:.1f}s - {end:.1f}s] {speaker}**: {text}")
|
lines.append(f"**[{start:.1f}s - {end:.1f}s] {speaker}**: {text}")
|
||||||
lines.append("")
|
lines.append("")
|
||||||
|
|
||||||
if transcript.entity_mentions:
|
if transcript.entity_mentions:
|
||||||
lines.extend([
|
lines.extend(
|
||||||
"",
|
[
|
||||||
"## 实体提及",
|
"",
|
||||||
"",
|
"## 实体提及",
|
||||||
"| 实体 | 类型 | 位置 | 上下文 |",
|
"",
|
||||||
"|------|------|------|--------|",
|
"| 实体 | 类型 | 位置 | 上下文 |",
|
||||||
])
|
"|------|------|------|--------|",
|
||||||
|
]
|
||||||
|
)
|
||||||
for mention in transcript.entity_mentions:
|
for mention in transcript.entity_mentions:
|
||||||
entity_id = mention.get('entity_id', '')
|
entity_id = mention.get("entity_id", "")
|
||||||
entity = entities_map.get(entity_id)
|
entity = entities_map.get(entity_id)
|
||||||
entity_name = entity.name if entity else mention.get('entity_name', 'Unknown')
|
entity_name = entity.name if entity else mention.get("entity_name", "Unknown")
|
||||||
entity_type = entity.type if entity else 'Unknown'
|
entity_type = entity.type if entity else "Unknown"
|
||||||
position = mention.get('position', '')
|
position = mention.get("position", "")
|
||||||
context = mention.get('context', '')[:50] + '...' if mention.get('context') else ''
|
context = mention.get("context", "")[:50] + "..." if mention.get("context") else ""
|
||||||
lines.append(f"| {entity_name} | {entity_type} | {position} | {context} |")
|
lines.append(f"| {entity_name} | {entity_type} | {position} | {context} |")
|
||||||
|
|
||||||
return '\n'.join(lines)
|
return "\n".join(lines)
|
||||||
|
|
||||||
def export_project_report_pdf(self, project_id: str, project_name: str,
|
def export_project_report_pdf(
|
||||||
entities: List[ExportEntity],
|
self,
|
||||||
relations: List[ExportRelation],
|
project_id: str,
|
||||||
transcripts: List[ExportTranscript],
|
project_name: str,
|
||||||
summary: str = "") -> bytes:
|
entities: List[ExportEntity],
|
||||||
|
relations: List[ExportRelation],
|
||||||
|
transcripts: List[ExportTranscript],
|
||||||
|
summary: str = "",
|
||||||
|
) -> bytes:
|
||||||
"""
|
"""
|
||||||
导出项目报告为 PDF 格式
|
导出项目报告为 PDF 格式
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
PDF 文件字节
|
PDF 文件字节
|
||||||
"""
|
"""
|
||||||
if not REPORTLAB_AVAILABLE:
|
if not REPORTLAB_AVAILABLE:
|
||||||
raise ImportError("reportlab is required for PDF export")
|
raise ImportError("reportlab is required for PDF export")
|
||||||
|
|
||||||
output = io.BytesIO()
|
output = io.BytesIO()
|
||||||
doc = SimpleDocTemplate(
|
doc = SimpleDocTemplate(output, pagesize=A4, rightMargin=72, leftMargin=72, topMargin=72, bottomMargin=18)
|
||||||
output,
|
|
||||||
pagesize=A4,
|
|
||||||
rightMargin=72,
|
|
||||||
leftMargin=72,
|
|
||||||
topMargin=72,
|
|
||||||
bottomMargin=18
|
|
||||||
)
|
|
||||||
|
|
||||||
# 样式
|
# 样式
|
||||||
styles = getSampleStyleSheet()
|
styles = getSampleStyleSheet()
|
||||||
title_style = ParagraphStyle(
|
title_style = ParagraphStyle(
|
||||||
'CustomTitle',
|
"CustomTitle", parent=styles["Heading1"], fontSize=24, spaceAfter=30, textColor=colors.HexColor("#2c3e50")
|
||||||
parent=styles['Heading1'],
|
|
||||||
fontSize=24,
|
|
||||||
spaceAfter=30,
|
|
||||||
textColor=colors.HexColor('#2c3e50')
|
|
||||||
)
|
)
|
||||||
heading_style = ParagraphStyle(
|
heading_style = ParagraphStyle(
|
||||||
'CustomHeading',
|
"CustomHeading", parent=styles["Heading2"], fontSize=16, spaceAfter=12, textColor=colors.HexColor("#34495e")
|
||||||
parent=styles['Heading2'],
|
|
||||||
fontSize=16,
|
|
||||||
spaceAfter=12,
|
|
||||||
textColor=colors.HexColor('#34495e')
|
|
||||||
)
|
)
|
||||||
|
|
||||||
story = []
|
story = []
|
||||||
|
|
||||||
# 标题页
|
# 标题页
|
||||||
story.append(Paragraph(f"InsightFlow 项目报告", title_style))
|
story.append(Paragraph(f"InsightFlow 项目报告", title_style))
|
||||||
story.append(Paragraph(f"项目名称: {project_name}", styles['Heading2']))
|
story.append(Paragraph(f"项目名称: {project_name}", styles["Heading2"]))
|
||||||
story.append(Paragraph(f"生成时间: {datetime.now().strftime('%Y-%m-%d %H:%M')}", styles['Normal']))
|
story.append(Paragraph(f"生成时间: {datetime.now().strftime('%Y-%m-%d %H:%M')}", styles["Normal"]))
|
||||||
story.append(Spacer(1, 0.3*inch))
|
story.append(Spacer(1, 0.3 * inch))
|
||||||
|
|
||||||
# 统计概览
|
# 统计概览
|
||||||
story.append(Paragraph("项目概览", heading_style))
|
story.append(Paragraph("项目概览", heading_style))
|
||||||
stats_data = [
|
stats_data = [
|
||||||
['指标', '数值'],
|
["指标", "数值"],
|
||||||
['实体数量', str(len(entities))],
|
["实体数量", str(len(entities))],
|
||||||
['关系数量', str(len(relations))],
|
["关系数量", str(len(relations))],
|
||||||
['文档数量', str(len(transcripts))],
|
["文档数量", str(len(transcripts))],
|
||||||
]
|
]
|
||||||
|
|
||||||
# 按类型统计实体
|
# 按类型统计实体
|
||||||
type_counts = {}
|
type_counts = {}
|
||||||
for e in entities:
|
for e in entities:
|
||||||
type_counts[e.type] = type_counts.get(e.type, 0) + 1
|
type_counts[e.type] = type_counts.get(e.type, 0) + 1
|
||||||
|
|
||||||
for etype, count in sorted(type_counts.items()):
|
for etype, count in sorted(type_counts.items()):
|
||||||
stats_data.append([f'{etype} 实体', str(count)])
|
stats_data.append([f"{etype} 实体", str(count)])
|
||||||
|
|
||||||
stats_table = Table(stats_data, colWidths=[3*inch, 2*inch])
|
stats_table = Table(stats_data, colWidths=[3 * inch, 2 * inch])
|
||||||
stats_table.setStyle(TableStyle([
|
stats_table.setStyle(
|
||||||
('BACKGROUND', (0, 0), (-1, 0), colors.HexColor('#34495e')),
|
TableStyle(
|
||||||
('TEXTCOLOR', (0, 0), (-1, 0), colors.whitesmoke),
|
[
|
||||||
('ALIGN', (0, 0), (-1, -1), 'CENTER'),
|
("BACKGROUND", (0, 0), (-1, 0), colors.HexColor("#34495e")),
|
||||||
('FONTNAME', (0, 0), (-1, 0), 'Helvetica-Bold'),
|
("TEXTCOLOR", (0, 0), (-1, 0), colors.whitesmoke),
|
||||||
('FONTSIZE', (0, 0), (-1, 0), 12),
|
("ALIGN", (0, 0), (-1, -1), "CENTER"),
|
||||||
('BOTTOMPADDING', (0, 0), (-1, 0), 12),
|
("FONTNAME", (0, 0), (-1, 0), "Helvetica-Bold"),
|
||||||
('BACKGROUND', (0, 1), (-1, -1), colors.HexColor('#ecf0f1')),
|
("FONTSIZE", (0, 0), (-1, 0), 12),
|
||||||
('GRID', (0, 0), (-1, -1), 1, colors.HexColor('#bdc3c7'))
|
("BOTTOMPADDING", (0, 0), (-1, 0), 12),
|
||||||
]))
|
("BACKGROUND", (0, 1), (-1, -1), colors.HexColor("#ecf0f1")),
|
||||||
|
("GRID", (0, 0), (-1, -1), 1, colors.HexColor("#bdc3c7")),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
story.append(stats_table)
|
story.append(stats_table)
|
||||||
story.append(Spacer(1, 0.3*inch))
|
story.append(Spacer(1, 0.3 * inch))
|
||||||
|
|
||||||
# 项目总结
|
# 项目总结
|
||||||
if summary:
|
if summary:
|
||||||
story.append(Paragraph("项目总结", heading_style))
|
story.append(Paragraph("项目总结", heading_style))
|
||||||
story.append(Paragraph(summary, styles['Normal']))
|
story.append(Paragraph(summary, styles["Normal"]))
|
||||||
story.append(Spacer(1, 0.3*inch))
|
story.append(Spacer(1, 0.3 * inch))
|
||||||
|
|
||||||
# 实体列表
|
# 实体列表
|
||||||
if entities:
|
if entities:
|
||||||
story.append(PageBreak())
|
story.append(PageBreak())
|
||||||
story.append(Paragraph("实体列表", heading_style))
|
story.append(Paragraph("实体列表", heading_style))
|
||||||
|
|
||||||
entity_data = [['名称', '类型', '提及次数', '定义']]
|
entity_data = [["名称", "类型", "提及次数", "定义"]]
|
||||||
for e in sorted(entities, key=lambda x: x.mention_count, reverse=True)[:50]: # 限制前50个
|
for e in sorted(entities, key=lambda x: x.mention_count, reverse=True)[:50]: # 限制前50个
|
||||||
entity_data.append([
|
entity_data.append(
|
||||||
e.name,
|
[
|
||||||
e.type,
|
e.name,
|
||||||
str(e.mention_count),
|
e.type,
|
||||||
(e.definition[:100] + '...') if len(e.definition) > 100 else e.definition
|
str(e.mention_count),
|
||||||
])
|
(e.definition[:100] + "...") if len(e.definition) > 100 else e.definition,
|
||||||
|
]
|
||||||
entity_table = Table(entity_data, colWidths=[1.5*inch, 1*inch, 1*inch, 2.5*inch])
|
)
|
||||||
entity_table.setStyle(TableStyle([
|
|
||||||
('BACKGROUND', (0, 0), (-1, 0), colors.HexColor('#34495e')),
|
entity_table = Table(entity_data, colWidths=[1.5 * inch, 1 * inch, 1 * inch, 2.5 * inch])
|
||||||
('TEXTCOLOR', (0, 0), (-1, 0), colors.whitesmoke),
|
entity_table.setStyle(
|
||||||
('ALIGN', (0, 0), (-1, -1), 'LEFT'),
|
TableStyle(
|
||||||
('FONTNAME', (0, 0), (-1, 0), 'Helvetica-Bold'),
|
[
|
||||||
('FONTSIZE', (0, 0), (-1, 0), 10),
|
("BACKGROUND", (0, 0), (-1, 0), colors.HexColor("#34495e")),
|
||||||
('BOTTOMPADDING', (0, 0), (-1, 0), 12),
|
("TEXTCOLOR", (0, 0), (-1, 0), colors.whitesmoke),
|
||||||
('BACKGROUND', (0, 1), (-1, -1), colors.HexColor('#ecf0f1')),
|
("ALIGN", (0, 0), (-1, -1), "LEFT"),
|
||||||
('GRID', (0, 0), (-1, -1), 1, colors.HexColor('#bdc3c7')),
|
("FONTNAME", (0, 0), (-1, 0), "Helvetica-Bold"),
|
||||||
('VALIGN', (0, 0), (-1, -1), 'TOP'),
|
("FONTSIZE", (0, 0), (-1, 0), 10),
|
||||||
]))
|
("BOTTOMPADDING", (0, 0), (-1, 0), 12),
|
||||||
|
("BACKGROUND", (0, 1), (-1, -1), colors.HexColor("#ecf0f1")),
|
||||||
|
("GRID", (0, 0), (-1, -1), 1, colors.HexColor("#bdc3c7")),
|
||||||
|
("VALIGN", (0, 0), (-1, -1), "TOP"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
story.append(entity_table)
|
story.append(entity_table)
|
||||||
|
|
||||||
# 关系列表
|
# 关系列表
|
||||||
if relations:
|
if relations:
|
||||||
story.append(PageBreak())
|
story.append(PageBreak())
|
||||||
story.append(Paragraph("关系列表", heading_style))
|
story.append(Paragraph("关系列表", heading_style))
|
||||||
|
|
||||||
relation_data = [['源实体', '关系', '目标实体', '置信度']]
|
relation_data = [["源实体", "关系", "目标实体", "置信度"]]
|
||||||
for r in relations[:100]: # 限制前100个
|
for r in relations[:100]: # 限制前100个
|
||||||
relation_data.append([
|
relation_data.append([r.source, r.relation_type, r.target, f"{r.confidence:.2f}"])
|
||||||
r.source,
|
|
||||||
r.relation_type,
|
relation_table = Table(relation_data, colWidths=[2 * inch, 1.5 * inch, 2 * inch, 1 * inch])
|
||||||
r.target,
|
relation_table.setStyle(
|
||||||
f"{r.confidence:.2f}"
|
TableStyle(
|
||||||
])
|
[
|
||||||
|
("BACKGROUND", (0, 0), (-1, 0), colors.HexColor("#34495e")),
|
||||||
relation_table = Table(relation_data, colWidths=[2*inch, 1.5*inch, 2*inch, 1*inch])
|
("TEXTCOLOR", (0, 0), (-1, 0), colors.whitesmoke),
|
||||||
relation_table.setStyle(TableStyle([
|
("ALIGN", (0, 0), (-1, -1), "LEFT"),
|
||||||
('BACKGROUND', (0, 0), (-1, 0), colors.HexColor('#34495e')),
|
("FONTNAME", (0, 0), (-1, 0), "Helvetica-Bold"),
|
||||||
('TEXTCOLOR', (0, 0), (-1, 0), colors.whitesmoke),
|
("FONTSIZE", (0, 0), (-1, 0), 10),
|
||||||
('ALIGN', (0, 0), (-1, -1), 'LEFT'),
|
("BOTTOMPADDING", (0, 0), (-1, 0), 12),
|
||||||
('FONTNAME', (0, 0), (-1, 0), 'Helvetica-Bold'),
|
("BACKGROUND", (0, 1), (-1, -1), colors.HexColor("#ecf0f1")),
|
||||||
('FONTSIZE', (0, 0), (-1, 0), 10),
|
("GRID", (0, 0), (-1, -1), 1, colors.HexColor("#bdc3c7")),
|
||||||
('BOTTOMPADDING', (0, 0), (-1, 0), 12),
|
]
|
||||||
('BACKGROUND', (0, 1), (-1, -1), colors.HexColor('#ecf0f1')),
|
)
|
||||||
('GRID', (0, 0), (-1, -1), 1, colors.HexColor('#bdc3c7')),
|
)
|
||||||
]))
|
|
||||||
story.append(relation_table)
|
story.append(relation_table)
|
||||||
|
|
||||||
doc.build(story)
|
doc.build(story)
|
||||||
return output.getvalue()
|
return output.getvalue()
|
||||||
|
|
||||||
def export_project_json(self, project_id: str, project_name: str,
|
def export_project_json(
|
||||||
entities: List[ExportEntity],
|
self,
|
||||||
relations: List[ExportRelation],
|
project_id: str,
|
||||||
transcripts: List[ExportTranscript]) -> str:
|
project_name: str,
|
||||||
|
entities: List[ExportEntity],
|
||||||
|
relations: List[ExportRelation],
|
||||||
|
transcripts: List[ExportTranscript],
|
||||||
|
) -> str:
|
||||||
"""
|
"""
|
||||||
导出完整项目数据为 JSON 格式
|
导出完整项目数据为 JSON 格式
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
JSON 字符串
|
JSON 字符串
|
||||||
"""
|
"""
|
||||||
@@ -531,7 +549,7 @@ class ExportManager:
|
|||||||
"definition": e.definition,
|
"definition": e.definition,
|
||||||
"aliases": e.aliases,
|
"aliases": e.aliases,
|
||||||
"mention_count": e.mention_count,
|
"mention_count": e.mention_count,
|
||||||
"attributes": e.attributes
|
"attributes": e.attributes,
|
||||||
}
|
}
|
||||||
for e in entities
|
for e in entities
|
||||||
],
|
],
|
||||||
@@ -542,31 +560,26 @@ class ExportManager:
|
|||||||
"target": r.target,
|
"target": r.target,
|
||||||
"relation_type": r.relation_type,
|
"relation_type": r.relation_type,
|
||||||
"confidence": r.confidence,
|
"confidence": r.confidence,
|
||||||
"evidence": r.evidence
|
"evidence": r.evidence,
|
||||||
}
|
}
|
||||||
for r in relations
|
for r in relations
|
||||||
],
|
],
|
||||||
"transcripts": [
|
"transcripts": [
|
||||||
{
|
{"id": t.id, "name": t.name, "type": t.type, "content": t.content, "segments": t.segments}
|
||||||
"id": t.id,
|
|
||||||
"name": t.name,
|
|
||||||
"type": t.type,
|
|
||||||
"content": t.content,
|
|
||||||
"segments": t.segments
|
|
||||||
}
|
|
||||||
for t in transcripts
|
for t in transcripts
|
||||||
]
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
return json.dumps(data, ensure_ascii=False, indent=2)
|
return json.dumps(data, ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
|
|
||||||
# 全局导出管理器实例
|
# 全局导出管理器实例
|
||||||
_export_manager = None
|
_export_manager = None
|
||||||
|
|
||||||
|
|
||||||
def get_export_manager(db_manager=None):
|
def get_export_manager(db_manager=None):
|
||||||
"""获取导出管理器实例"""
|
"""获取导出管理器实例"""
|
||||||
global _export_manager
|
global _export_manager
|
||||||
if _export_manager is None:
|
if _export_manager is None:
|
||||||
_export_manager = ExportManager(db_manager)
|
_export_manager = ExportManager(db_manager)
|
||||||
return _export_manager
|
return _export_manager
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -6,16 +6,15 @@ InsightFlow Image Processor - Phase 7
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
import io
|
import io
|
||||||
import json
|
|
||||||
import uuid
|
import uuid
|
||||||
import base64
|
import base64
|
||||||
from typing import List, Dict, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
# 尝试导入图像处理库
|
# 尝试导入图像处理库
|
||||||
try:
|
try:
|
||||||
from PIL import Image, ImageEnhance, ImageFilter
|
from PIL import Image, ImageEnhance, ImageFilter
|
||||||
|
|
||||||
PIL_AVAILABLE = True
|
PIL_AVAILABLE = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
PIL_AVAILABLE = False
|
PIL_AVAILABLE = False
|
||||||
@@ -23,12 +22,14 @@ except ImportError:
|
|||||||
try:
|
try:
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
CV2_AVAILABLE = True
|
CV2_AVAILABLE = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
CV2_AVAILABLE = False
|
CV2_AVAILABLE = False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import pytesseract
|
import pytesseract
|
||||||
|
|
||||||
PYTESSERACT_AVAILABLE = True
|
PYTESSERACT_AVAILABLE = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
PYTESSERACT_AVAILABLE = False
|
PYTESSERACT_AVAILABLE = False
|
||||||
@@ -37,6 +38,7 @@ except ImportError:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class ImageEntity:
|
class ImageEntity:
|
||||||
"""图片中检测到的实体"""
|
"""图片中检测到的实体"""
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
type: str
|
type: str
|
||||||
confidence: float
|
confidence: float
|
||||||
@@ -46,6 +48,7 @@ class ImageEntity:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class ImageRelation:
|
class ImageRelation:
|
||||||
"""图片中检测到的关系"""
|
"""图片中检测到的关系"""
|
||||||
|
|
||||||
source: str
|
source: str
|
||||||
target: str
|
target: str
|
||||||
relation_type: str
|
relation_type: str
|
||||||
@@ -55,6 +58,7 @@ class ImageRelation:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class ImageProcessingResult:
|
class ImageProcessingResult:
|
||||||
"""图片处理结果"""
|
"""图片处理结果"""
|
||||||
|
|
||||||
image_id: str
|
image_id: str
|
||||||
image_type: str # whiteboard, ppt, handwritten, screenshot, other
|
image_type: str # whiteboard, ppt, handwritten, screenshot, other
|
||||||
ocr_text: str
|
ocr_text: str
|
||||||
@@ -70,6 +74,7 @@ class ImageProcessingResult:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class BatchProcessingResult:
|
class BatchProcessingResult:
|
||||||
"""批量图片处理结果"""
|
"""批量图片处理结果"""
|
||||||
|
|
||||||
results: List[ImageProcessingResult]
|
results: List[ImageProcessingResult]
|
||||||
total_count: int
|
total_count: int
|
||||||
success_count: int
|
success_count: int
|
||||||
@@ -78,232 +83,234 @@ class BatchProcessingResult:
|
|||||||
|
|
||||||
class ImageProcessor:
|
class ImageProcessor:
|
||||||
"""图片处理器 - 处理各种类型图片"""
|
"""图片处理器 - 处理各种类型图片"""
|
||||||
|
|
||||||
# 图片类型定义
|
# 图片类型定义
|
||||||
IMAGE_TYPES = {
|
IMAGE_TYPES = {
|
||||||
'whiteboard': '白板',
|
"whiteboard": "白板",
|
||||||
'ppt': 'PPT/演示文稿',
|
"ppt": "PPT/演示文稿",
|
||||||
'handwritten': '手写笔记',
|
"handwritten": "手写笔记",
|
||||||
'screenshot': '屏幕截图',
|
"screenshot": "屏幕截图",
|
||||||
'document': '文档图片',
|
"document": "文档图片",
|
||||||
'other': '其他'
|
"other": "其他",
|
||||||
}
|
}
|
||||||
|
|
||||||
def __init__(self, temp_dir: str = None):
|
def __init__(self, temp_dir: str = None):
|
||||||
"""
|
"""
|
||||||
初始化图片处理器
|
初始化图片处理器
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
temp_dir: 临时文件目录
|
temp_dir: 临时文件目录
|
||||||
"""
|
"""
|
||||||
self.temp_dir = temp_dir or os.path.join(os.getcwd(), 'temp', 'images')
|
self.temp_dir = temp_dir or os.path.join(os.getcwd(), "temp", "images")
|
||||||
os.makedirs(self.temp_dir, exist_ok=True)
|
os.makedirs(self.temp_dir, exist_ok=True)
|
||||||
|
|
||||||
def preprocess_image(self, image, image_type: str = None):
|
def preprocess_image(self, image, image_type: str = None):
|
||||||
"""
|
"""
|
||||||
预处理图片以提高OCR质量
|
预处理图片以提高OCR质量
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
image: PIL Image 对象
|
image: PIL Image 对象
|
||||||
image_type: 图片类型(用于针对性处理)
|
image_type: 图片类型(用于针对性处理)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
处理后的图片
|
处理后的图片
|
||||||
"""
|
"""
|
||||||
if not PIL_AVAILABLE:
|
if not PIL_AVAILABLE:
|
||||||
return image
|
return image
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 转换为RGB(如果是RGBA)
|
# 转换为RGB(如果是RGBA)
|
||||||
if image.mode == 'RGBA':
|
if image.mode == "RGBA":
|
||||||
image = image.convert('RGB')
|
image = image.convert("RGB")
|
||||||
|
|
||||||
# 根据图片类型进行针对性处理
|
# 根据图片类型进行针对性处理
|
||||||
if image_type == 'whiteboard':
|
if image_type == "whiteboard":
|
||||||
# 白板:增强对比度,去除背景
|
# 白板:增强对比度,去除背景
|
||||||
image = self._enhance_whiteboard(image)
|
image = self._enhance_whiteboard(image)
|
||||||
elif image_type == 'handwritten':
|
elif image_type == "handwritten":
|
||||||
# 手写笔记:降噪,增强对比度
|
# 手写笔记:降噪,增强对比度
|
||||||
image = self._enhance_handwritten(image)
|
image = self._enhance_handwritten(image)
|
||||||
elif image_type == 'screenshot':
|
elif image_type == "screenshot":
|
||||||
# 截图:轻微锐化
|
# 截图:轻微锐化
|
||||||
image = image.filter(ImageFilter.SHARPEN)
|
image = image.filter(ImageFilter.SHARPEN)
|
||||||
|
|
||||||
# 通用处理:调整大小(如果太大)
|
# 通用处理:调整大小(如果太大)
|
||||||
max_size = 4096
|
max_size = 4096
|
||||||
if max(image.size) > max_size:
|
if max(image.size) > max_size:
|
||||||
ratio = max_size / max(image.size)
|
ratio = max_size / max(image.size)
|
||||||
new_size = (int(image.size[0] * ratio), int(image.size[1] * ratio))
|
new_size = (int(image.size[0] * ratio), int(image.size[1] * ratio))
|
||||||
image = image.resize(new_size, Image.Resampling.LANCZOS)
|
image = image.resize(new_size, Image.Resampling.LANCZOS)
|
||||||
|
|
||||||
return image
|
return image
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Image preprocessing error: {e}")
|
print(f"Image preprocessing error: {e}")
|
||||||
return image
|
return image
|
||||||
|
|
||||||
def _enhance_whiteboard(self, image):
|
def _enhance_whiteboard(self, image):
|
||||||
"""增强白板图片"""
|
"""增强白板图片"""
|
||||||
# 转换为灰度
|
# 转换为灰度
|
||||||
gray = image.convert('L')
|
gray = image.convert("L")
|
||||||
|
|
||||||
# 增强对比度
|
# 增强对比度
|
||||||
enhancer = ImageEnhance.Contrast(gray)
|
enhancer = ImageEnhance.Contrast(gray)
|
||||||
enhanced = enhancer.enhance(2.0)
|
enhanced = enhancer.enhance(2.0)
|
||||||
|
|
||||||
# 二值化
|
# 二值化
|
||||||
threshold = 128
|
threshold = 128
|
||||||
binary = enhanced.point(lambda x: 0 if x < threshold else 255, '1')
|
binary = enhanced.point(lambda x: 0 if x < threshold else 255, "1")
|
||||||
|
|
||||||
return binary.convert('L')
|
return binary.convert("L")
|
||||||
|
|
||||||
def _enhance_handwritten(self, image):
|
def _enhance_handwritten(self, image):
|
||||||
"""增强手写笔记图片"""
|
"""增强手写笔记图片"""
|
||||||
# 转换为灰度
|
# 转换为灰度
|
||||||
gray = image.convert('L')
|
gray = image.convert("L")
|
||||||
|
|
||||||
# 轻微降噪
|
# 轻微降噪
|
||||||
blurred = gray.filter(ImageFilter.GaussianBlur(radius=1))
|
blurred = gray.filter(ImageFilter.GaussianBlur(radius=1))
|
||||||
|
|
||||||
# 增强对比度
|
# 增强对比度
|
||||||
enhancer = ImageEnhance.Contrast(blurred)
|
enhancer = ImageEnhance.Contrast(blurred)
|
||||||
enhanced = enhancer.enhance(1.5)
|
enhanced = enhancer.enhance(1.5)
|
||||||
|
|
||||||
return enhanced
|
return enhanced
|
||||||
|
|
||||||
def detect_image_type(self, image, ocr_text: str = "") -> str:
|
def detect_image_type(self, image, ocr_text: str = "") -> str:
|
||||||
"""
|
"""
|
||||||
自动检测图片类型
|
自动检测图片类型
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
image: PIL Image 对象
|
image: PIL Image 对象
|
||||||
ocr_text: OCR识别的文本
|
ocr_text: OCR识别的文本
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
图片类型字符串
|
图片类型字符串
|
||||||
"""
|
"""
|
||||||
if not PIL_AVAILABLE:
|
if not PIL_AVAILABLE:
|
||||||
return 'other'
|
return "other"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 基于图片特征和OCR内容判断类型
|
# 基于图片特征和OCR内容判断类型
|
||||||
width, height = image.size
|
width, height = image.size
|
||||||
aspect_ratio = width / height
|
aspect_ratio = width / height
|
||||||
|
|
||||||
# 检测是否为PPT(通常是16:9或4:3)
|
# 检测是否为PPT(通常是16:9或4:3)
|
||||||
if 1.3 <= aspect_ratio <= 1.8:
|
if 1.3 <= aspect_ratio <= 1.8:
|
||||||
# 检查是否有典型的PPT特征(标题、项目符号等)
|
# 检查是否有典型的PPT特征(标题、项目符号等)
|
||||||
if any(keyword in ocr_text.lower() for keyword in ['slide', 'page', '第', '页']):
|
if any(keyword in ocr_text.lower() for keyword in ["slide", "page", "第", "页"]):
|
||||||
return 'ppt'
|
return "ppt"
|
||||||
|
|
||||||
# 检测是否为白板(大量手写文字,可能有箭头、框等)
|
# 检测是否为白板(大量手写文字,可能有箭头、框等)
|
||||||
if CV2_AVAILABLE:
|
if CV2_AVAILABLE:
|
||||||
img_array = np.array(image.convert('RGB'))
|
img_array = np.array(image.convert("RGB"))
|
||||||
gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
|
gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
|
||||||
|
|
||||||
# 检测边缘(白板通常有很多线条)
|
# 检测边缘(白板通常有很多线条)
|
||||||
edges = cv2.Canny(gray, 50, 150)
|
edges = cv2.Canny(gray, 50, 150)
|
||||||
edge_ratio = np.sum(edges > 0) / edges.size
|
edge_ratio = np.sum(edges > 0) / edges.size
|
||||||
|
|
||||||
# 如果边缘比例高,可能是白板
|
# 如果边缘比例高,可能是白板
|
||||||
if edge_ratio > 0.05 and len(ocr_text) > 50:
|
if edge_ratio > 0.05 and len(ocr_text) > 50:
|
||||||
return 'whiteboard'
|
return "whiteboard"
|
||||||
|
|
||||||
# 检测是否为手写笔记(文字密度高,可能有涂鸦)
|
# 检测是否为手写笔记(文字密度高,可能有涂鸦)
|
||||||
if len(ocr_text) > 100 and aspect_ratio < 1.5:
|
if len(ocr_text) > 100 and aspect_ratio < 1.5:
|
||||||
# 检查手写特征(不规则的行高)
|
# 检查手写特征(不规则的行高)
|
||||||
return 'handwritten'
|
return "handwritten"
|
||||||
|
|
||||||
# 检测是否为截图(可能有UI元素)
|
# 检测是否为截图(可能有UI元素)
|
||||||
if any(keyword in ocr_text.lower() for keyword in ['button', 'menu', 'click', '登录', '确定', '取消']):
|
if any(keyword in ocr_text.lower() for keyword in ["button", "menu", "click", "登录", "确定", "取消"]):
|
||||||
return 'screenshot'
|
return "screenshot"
|
||||||
|
|
||||||
# 默认文档类型
|
# 默认文档类型
|
||||||
if len(ocr_text) > 200:
|
if len(ocr_text) > 200:
|
||||||
return 'document'
|
return "document"
|
||||||
|
|
||||||
return 'other'
|
return "other"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Image type detection error: {e}")
|
print(f"Image type detection error: {e}")
|
||||||
return 'other'
|
return "other"
|
||||||
|
|
||||||
def perform_ocr(self, image, lang: str = 'chi_sim+eng') -> Tuple[str, float]:
|
def perform_ocr(self, image, lang: str = "chi_sim+eng") -> Tuple[str, float]:
|
||||||
"""
|
"""
|
||||||
对图片进行OCR识别
|
对图片进行OCR识别
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
image: PIL Image 对象
|
image: PIL Image 对象
|
||||||
lang: OCR语言
|
lang: OCR语言
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(识别的文本, 置信度)
|
(识别的文本, 置信度)
|
||||||
"""
|
"""
|
||||||
if not PYTESSERACT_AVAILABLE:
|
if not PYTESSERACT_AVAILABLE:
|
||||||
return "", 0.0
|
return "", 0.0
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 预处理图片
|
# 预处理图片
|
||||||
processed_image = self.preprocess_image(image)
|
processed_image = self.preprocess_image(image)
|
||||||
|
|
||||||
# 执行OCR
|
# 执行OCR
|
||||||
text = pytesseract.image_to_string(processed_image, lang=lang)
|
text = pytesseract.image_to_string(processed_image, lang=lang)
|
||||||
|
|
||||||
# 获取置信度
|
# 获取置信度
|
||||||
data = pytesseract.image_to_data(processed_image, output_type=pytesseract.Output.DICT)
|
data = pytesseract.image_to_data(processed_image, output_type=pytesseract.Output.DICT)
|
||||||
confidences = [int(c) for c in data['conf'] if int(c) > 0]
|
confidences = [int(c) for c in data["conf"] if int(c) > 0]
|
||||||
avg_confidence = sum(confidences) / len(confidences) if confidences else 0
|
avg_confidence = sum(confidences) / len(confidences) if confidences else 0
|
||||||
|
|
||||||
return text.strip(), avg_confidence / 100.0
|
return text.strip(), avg_confidence / 100.0
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"OCR error: {e}")
|
print(f"OCR error: {e}")
|
||||||
return "", 0.0
|
return "", 0.0
|
||||||
|
|
||||||
def extract_entities_from_text(self, text: str) -> List[ImageEntity]:
|
def extract_entities_from_text(self, text: str) -> List[ImageEntity]:
|
||||||
"""
|
"""
|
||||||
从OCR文本中提取实体
|
从OCR文本中提取实体
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
text: OCR识别的文本
|
text: OCR识别的文本
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
实体列表
|
实体列表
|
||||||
"""
|
"""
|
||||||
entities = []
|
entities = []
|
||||||
|
|
||||||
# 简单的实体提取规则(可以替换为LLM调用)
|
# 简单的实体提取规则(可以替换为LLM调用)
|
||||||
# 提取大写字母开头的词组(可能是专有名词)
|
# 提取大写字母开头的词组(可能是专有名词)
|
||||||
import re
|
import re
|
||||||
|
|
||||||
# 项目名称(通常是大写或带引号)
|
# 项目名称(通常是大写或带引号)
|
||||||
project_pattern = r'["\']([^"\']+)["\']|([A-Z][a-zA-Z0-9]*(?:\s+[A-Z][a-zA-Z0-9]*)+)'
|
project_pattern = r'["\']([^"\']+)["\']|([A-Z][a-zA-Z0-9]*(?:\s+[A-Z][a-zA-Z0-9]*)+)'
|
||||||
for match in re.finditer(project_pattern, text):
|
for match in re.finditer(project_pattern, text):
|
||||||
name = match.group(1) or match.group(2)
|
name = match.group(1) or match.group(2)
|
||||||
if name and len(name) > 2:
|
if name and len(name) > 2:
|
||||||
entities.append(ImageEntity(
|
entities.append(ImageEntity(name=name.strip(), type="PROJECT", confidence=0.7))
|
||||||
name=name.strip(),
|
|
||||||
type='PROJECT',
|
|
||||||
confidence=0.7
|
|
||||||
))
|
|
||||||
|
|
||||||
# 人名(中文)
|
# 人名(中文)
|
||||||
name_pattern = r'([\u4e00-\u9fa5]{2,4})(?:先生|女士|总|经理|工程师|老师)'
|
name_pattern = r"([\u4e00-\u9fa5]{2,4})(?:先生|女士|总|经理|工程师|老师)"
|
||||||
for match in re.finditer(name_pattern, text):
|
for match in re.finditer(name_pattern, text):
|
||||||
entities.append(ImageEntity(
|
entities.append(ImageEntity(name=match.group(1), type="PERSON", confidence=0.8))
|
||||||
name=match.group(1),
|
|
||||||
type='PERSON',
|
|
||||||
confidence=0.8
|
|
||||||
))
|
|
||||||
|
|
||||||
# 技术术语
|
# 技术术语
|
||||||
tech_keywords = ['K8s', 'Kubernetes', 'Docker', 'API', 'SDK', 'AI', 'ML',
|
tech_keywords = [
|
||||||
'Python', 'Java', 'React', 'Vue', 'Node.js', '数据库', '服务器']
|
"K8s",
|
||||||
|
"Kubernetes",
|
||||||
|
"Docker",
|
||||||
|
"API",
|
||||||
|
"SDK",
|
||||||
|
"AI",
|
||||||
|
"ML",
|
||||||
|
"Python",
|
||||||
|
"Java",
|
||||||
|
"React",
|
||||||
|
"Vue",
|
||||||
|
"Node.js",
|
||||||
|
"数据库",
|
||||||
|
"服务器",
|
||||||
|
]
|
||||||
for keyword in tech_keywords:
|
for keyword in tech_keywords:
|
||||||
if keyword in text:
|
if keyword in text:
|
||||||
entities.append(ImageEntity(
|
entities.append(ImageEntity(name=keyword, type="TECH", confidence=0.9))
|
||||||
name=keyword,
|
|
||||||
type='TECH',
|
|
||||||
confidence=0.9
|
|
||||||
))
|
|
||||||
|
|
||||||
# 去重
|
# 去重
|
||||||
seen = set()
|
seen = set()
|
||||||
unique_entities = []
|
unique_entities = []
|
||||||
@@ -312,96 +319,96 @@ class ImageProcessor:
|
|||||||
if key not in seen:
|
if key not in seen:
|
||||||
seen.add(key)
|
seen.add(key)
|
||||||
unique_entities.append(e)
|
unique_entities.append(e)
|
||||||
|
|
||||||
return unique_entities
|
return unique_entities
|
||||||
|
|
||||||
def generate_description(self, image_type: str, ocr_text: str,
|
def generate_description(self, image_type: str, ocr_text: str, entities: List[ImageEntity]) -> str:
|
||||||
entities: List[ImageEntity]) -> str:
|
|
||||||
"""
|
"""
|
||||||
生成图片描述
|
生成图片描述
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
image_type: 图片类型
|
image_type: 图片类型
|
||||||
ocr_text: OCR文本
|
ocr_text: OCR文本
|
||||||
entities: 检测到的实体
|
entities: 检测到的实体
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
图片描述
|
图片描述
|
||||||
"""
|
"""
|
||||||
type_name = self.IMAGE_TYPES.get(image_type, '图片')
|
type_name = self.IMAGE_TYPES.get(image_type, "图片")
|
||||||
|
|
||||||
description_parts = [f"这是一张{type_name}图片。"]
|
description_parts = [f"这是一张{type_name}图片。"]
|
||||||
|
|
||||||
if ocr_text:
|
if ocr_text:
|
||||||
# 提取前200字符作为摘要
|
# 提取前200字符作为摘要
|
||||||
text_preview = ocr_text[:200].replace('\n', ' ')
|
text_preview = ocr_text[:200].replace("\n", " ")
|
||||||
if len(ocr_text) > 200:
|
if len(ocr_text) > 200:
|
||||||
text_preview += "..."
|
text_preview += "..."
|
||||||
description_parts.append(f"内容摘要:{text_preview}")
|
description_parts.append(f"内容摘要:{text_preview}")
|
||||||
|
|
||||||
if entities:
|
if entities:
|
||||||
entity_names = [e.name for e in entities[:5]] # 最多显示5个实体
|
entity_names = [e.name for e in entities[:5]] # 最多显示5个实体
|
||||||
description_parts.append(f"识别到的关键实体:{', '.join(entity_names)}")
|
description_parts.append(f"识别到的关键实体:{', '.join(entity_names)}")
|
||||||
|
|
||||||
return " ".join(description_parts)
|
return " ".join(description_parts)
|
||||||
|
|
||||||
def process_image(self, image_data: bytes, filename: str = None,
|
def process_image(
|
||||||
image_id: str = None, detect_type: bool = True) -> ImageProcessingResult:
|
self, image_data: bytes, filename: str = None, image_id: str = None, detect_type: bool = True
|
||||||
|
) -> ImageProcessingResult:
|
||||||
"""
|
"""
|
||||||
处理单张图片
|
处理单张图片
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
image_data: 图片二进制数据
|
image_data: 图片二进制数据
|
||||||
filename: 文件名
|
filename: 文件名
|
||||||
image_id: 图片ID(可选)
|
image_id: 图片ID(可选)
|
||||||
detect_type: 是否自动检测图片类型
|
detect_type: 是否自动检测图片类型
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
图片处理结果
|
图片处理结果
|
||||||
"""
|
"""
|
||||||
image_id = image_id or str(uuid.uuid4())[:8]
|
image_id = image_id or str(uuid.uuid4())[:8]
|
||||||
|
|
||||||
if not PIL_AVAILABLE:
|
if not PIL_AVAILABLE:
|
||||||
return ImageProcessingResult(
|
return ImageProcessingResult(
|
||||||
image_id=image_id,
|
image_id=image_id,
|
||||||
image_type='other',
|
image_type="other",
|
||||||
ocr_text='',
|
ocr_text="",
|
||||||
description='PIL not available',
|
description="PIL not available",
|
||||||
entities=[],
|
entities=[],
|
||||||
relations=[],
|
relations=[],
|
||||||
width=0,
|
width=0,
|
||||||
height=0,
|
height=0,
|
||||||
success=False,
|
success=False,
|
||||||
error_message='PIL library not available'
|
error_message="PIL library not available",
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 加载图片
|
# 加载图片
|
||||||
image = Image.open(io.BytesIO(image_data))
|
image = Image.open(io.BytesIO(image_data))
|
||||||
width, height = image.size
|
width, height = image.size
|
||||||
|
|
||||||
# 执行OCR
|
# 执行OCR
|
||||||
ocr_text, ocr_confidence = self.perform_ocr(image)
|
ocr_text, ocr_confidence = self.perform_ocr(image)
|
||||||
|
|
||||||
# 检测图片类型
|
# 检测图片类型
|
||||||
image_type = 'other'
|
image_type = "other"
|
||||||
if detect_type:
|
if detect_type:
|
||||||
image_type = self.detect_image_type(image, ocr_text)
|
image_type = self.detect_image_type(image, ocr_text)
|
||||||
|
|
||||||
# 提取实体
|
# 提取实体
|
||||||
entities = self.extract_entities_from_text(ocr_text)
|
entities = self.extract_entities_from_text(ocr_text)
|
||||||
|
|
||||||
# 生成描述
|
# 生成描述
|
||||||
description = self.generate_description(image_type, ocr_text, entities)
|
description = self.generate_description(image_type, ocr_text, entities)
|
||||||
|
|
||||||
# 提取关系(基于实体共现)
|
# 提取关系(基于实体共现)
|
||||||
relations = self._extract_relations(entities, ocr_text)
|
relations = self._extract_relations(entities, ocr_text)
|
||||||
|
|
||||||
# 保存图片文件(可选)
|
# 保存图片文件(可选)
|
||||||
if filename:
|
if filename:
|
||||||
save_path = os.path.join(self.temp_dir, f"{image_id}_{filename}")
|
save_path = os.path.join(self.temp_dir, f"{image_id}_{filename}")
|
||||||
image.save(save_path)
|
image.save(save_path)
|
||||||
|
|
||||||
return ImageProcessingResult(
|
return ImageProcessingResult(
|
||||||
image_id=image_id,
|
image_id=image_id,
|
||||||
image_type=image_type,
|
image_type=image_type,
|
||||||
@@ -411,125 +418,123 @@ class ImageProcessor:
|
|||||||
relations=relations,
|
relations=relations,
|
||||||
width=width,
|
width=width,
|
||||||
height=height,
|
height=height,
|
||||||
success=True
|
success=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return ImageProcessingResult(
|
return ImageProcessingResult(
|
||||||
image_id=image_id,
|
image_id=image_id,
|
||||||
image_type='other',
|
image_type="other",
|
||||||
ocr_text='',
|
ocr_text="",
|
||||||
description='',
|
description="",
|
||||||
entities=[],
|
entities=[],
|
||||||
relations=[],
|
relations=[],
|
||||||
width=0,
|
width=0,
|
||||||
height=0,
|
height=0,
|
||||||
success=False,
|
success=False,
|
||||||
error_message=str(e)
|
error_message=str(e),
|
||||||
)
|
)
|
||||||
|
|
||||||
def _extract_relations(self, entities: List[ImageEntity], text: str) -> List[ImageRelation]:
|
def _extract_relations(self, entities: List[ImageEntity], text: str) -> List[ImageRelation]:
|
||||||
"""
|
"""
|
||||||
从文本中提取实体关系
|
从文本中提取实体关系
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
entities: 实体列表
|
entities: 实体列表
|
||||||
text: 文本内容
|
text: 文本内容
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
关系列表
|
关系列表
|
||||||
"""
|
"""
|
||||||
relations = []
|
relations = []
|
||||||
|
|
||||||
if len(entities) < 2:
|
if len(entities) < 2:
|
||||||
return relations
|
return relations
|
||||||
|
|
||||||
# 简单的关系提取:如果两个实体在同一句子中出现,则认为它们相关
|
# 简单的关系提取:如果两个实体在同一句子中出现,则认为它们相关
|
||||||
sentences = text.replace('。', '.').replace('!', '!').replace('?', '?').split('.')
|
sentences = text.replace("。", ".").replace("!", "!").replace("?", "?").split(".")
|
||||||
|
|
||||||
for sentence in sentences:
|
for sentence in sentences:
|
||||||
sentence_entities = []
|
sentence_entities = []
|
||||||
for entity in entities:
|
for entity in entities:
|
||||||
if entity.name in sentence:
|
if entity.name in sentence:
|
||||||
sentence_entities.append(entity)
|
sentence_entities.append(entity)
|
||||||
|
|
||||||
# 如果句子中有多个实体,建立关系
|
# 如果句子中有多个实体,建立关系
|
||||||
if len(sentence_entities) >= 2:
|
if len(sentence_entities) >= 2:
|
||||||
for i in range(len(sentence_entities)):
|
for i in range(len(sentence_entities)):
|
||||||
for j in range(i + 1, len(sentence_entities)):
|
for j in range(i + 1, len(sentence_entities)):
|
||||||
relations.append(ImageRelation(
|
relations.append(
|
||||||
source=sentence_entities[i].name,
|
ImageRelation(
|
||||||
target=sentence_entities[j].name,
|
source=sentence_entities[i].name,
|
||||||
relation_type='related',
|
target=sentence_entities[j].name,
|
||||||
confidence=0.5
|
relation_type="related",
|
||||||
))
|
confidence=0.5,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
return relations
|
return relations
|
||||||
|
|
||||||
def process_batch(self, images_data: List[Tuple[bytes, str]],
|
def process_batch(self, images_data: List[Tuple[bytes, str]], project_id: str = None) -> BatchProcessingResult:
|
||||||
project_id: str = None) -> BatchProcessingResult:
|
|
||||||
"""
|
"""
|
||||||
批量处理图片
|
批量处理图片
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
images_data: 图片数据列表,每项为 (image_data, filename)
|
images_data: 图片数据列表,每项为 (image_data, filename)
|
||||||
project_id: 项目ID
|
project_id: 项目ID
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
批量处理结果
|
批量处理结果
|
||||||
"""
|
"""
|
||||||
results = []
|
results = []
|
||||||
success_count = 0
|
success_count = 0
|
||||||
failed_count = 0
|
failed_count = 0
|
||||||
|
|
||||||
for image_data, filename in images_data:
|
for image_data, filename in images_data:
|
||||||
result = self.process_image(image_data, filename)
|
result = self.process_image(image_data, filename)
|
||||||
results.append(result)
|
results.append(result)
|
||||||
|
|
||||||
if result.success:
|
if result.success:
|
||||||
success_count += 1
|
success_count += 1
|
||||||
else:
|
else:
|
||||||
failed_count += 1
|
failed_count += 1
|
||||||
|
|
||||||
return BatchProcessingResult(
|
return BatchProcessingResult(
|
||||||
results=results,
|
results=results, total_count=len(results), success_count=success_count, failed_count=failed_count
|
||||||
total_count=len(results),
|
|
||||||
success_count=success_count,
|
|
||||||
failed_count=failed_count
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def image_to_base64(self, image_data: bytes) -> str:
|
def image_to_base64(self, image_data: bytes) -> str:
|
||||||
"""
|
"""
|
||||||
将图片转换为base64编码
|
将图片转换为base64编码
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
image_data: 图片二进制数据
|
image_data: 图片二进制数据
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
base64编码的字符串
|
base64编码的字符串
|
||||||
"""
|
"""
|
||||||
return base64.b64encode(image_data).decode('utf-8')
|
return base64.b64encode(image_data).decode("utf-8")
|
||||||
|
|
||||||
def get_image_thumbnail(self, image_data: bytes, size: Tuple[int, int] = (200, 200)) -> bytes:
|
def get_image_thumbnail(self, image_data: bytes, size: Tuple[int, int] = (200, 200)) -> bytes:
|
||||||
"""
|
"""
|
||||||
生成图片缩略图
|
生成图片缩略图
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
image_data: 图片二进制数据
|
image_data: 图片二进制数据
|
||||||
size: 缩略图尺寸
|
size: 缩略图尺寸
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
缩略图二进制数据
|
缩略图二进制数据
|
||||||
"""
|
"""
|
||||||
if not PIL_AVAILABLE:
|
if not PIL_AVAILABLE:
|
||||||
return image_data
|
return image_data
|
||||||
|
|
||||||
try:
|
try:
|
||||||
image = Image.open(io.BytesIO(image_data))
|
image = Image.open(io.BytesIO(image_data))
|
||||||
image.thumbnail(size, Image.Resampling.LANCZOS)
|
image.thumbnail(size, Image.Resampling.LANCZOS)
|
||||||
|
|
||||||
buffer = io.BytesIO()
|
buffer = io.BytesIO()
|
||||||
image.save(buffer, format='JPEG')
|
image.save(buffer, format="JPEG")
|
||||||
return buffer.getvalue()
|
return buffer.getvalue()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Thumbnail generation error: {e}")
|
print(f"Thumbnail generation error: {e}")
|
||||||
@@ -539,6 +544,7 @@ class ImageProcessor:
|
|||||||
# Singleton instance
|
# Singleton instance
|
||||||
_image_processor = None
|
_image_processor = None
|
||||||
|
|
||||||
|
|
||||||
def get_image_processor(temp_dir: str = None) -> ImageProcessor:
|
def get_image_processor(temp_dir: str = None) -> ImageProcessor:
|
||||||
"""获取图片处理器单例"""
|
"""获取图片处理器单例"""
|
||||||
global _image_processor
|
global _image_processor
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ print(f"Database path: {db_path}")
|
|||||||
print(f"Schema path: {schema_path}")
|
print(f"Schema path: {schema_path}")
|
||||||
|
|
||||||
# Read schema
|
# Read schema
|
||||||
with open(schema_path, 'r') as f:
|
with open(schema_path, "r") as f:
|
||||||
schema = f.read()
|
schema = f.read()
|
||||||
|
|
||||||
# Execute schema
|
# Execute schema
|
||||||
@@ -19,7 +19,7 @@ conn = sqlite3.connect(db_path)
|
|||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
||||||
# Split schema by semicolons and execute each statement
|
# Split schema by semicolons and execute each statement
|
||||||
statements = schema.split(';')
|
statements = schema.split(";")
|
||||||
success_count = 0
|
success_count = 0
|
||||||
error_count = 0
|
error_count = 0
|
||||||
|
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ InsightFlow Knowledge Reasoning - Phase 5
|
|||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
import httpx
|
import httpx
|
||||||
from typing import List, Dict, Optional, Any
|
from typing import List, Dict
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
@@ -17,76 +17,65 @@ KIMI_BASE_URL = os.getenv("KIMI_BASE_URL", "https://api.kimi.com/coding")
|
|||||||
|
|
||||||
class ReasoningType(Enum):
|
class ReasoningType(Enum):
|
||||||
"""推理类型"""
|
"""推理类型"""
|
||||||
CAUSAL = "causal" # 因果推理
|
|
||||||
ASSOCIATIVE = "associative" # 关联推理
|
CAUSAL = "causal" # 因果推理
|
||||||
TEMPORAL = "temporal" # 时序推理
|
ASSOCIATIVE = "associative" # 关联推理
|
||||||
COMPARATIVE = "comparative" # 对比推理
|
TEMPORAL = "temporal" # 时序推理
|
||||||
SUMMARY = "summary" # 总结推理
|
COMPARATIVE = "comparative" # 对比推理
|
||||||
|
SUMMARY = "summary" # 总结推理
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ReasoningResult:
|
class ReasoningResult:
|
||||||
"""推理结果"""
|
"""推理结果"""
|
||||||
|
|
||||||
answer: str
|
answer: str
|
||||||
reasoning_type: ReasoningType
|
reasoning_type: ReasoningType
|
||||||
confidence: float
|
confidence: float
|
||||||
evidence: List[Dict] # 支撑证据
|
evidence: List[Dict] # 支撑证据
|
||||||
related_entities: List[str] # 相关实体
|
related_entities: List[str] # 相关实体
|
||||||
gaps: List[str] # 知识缺口
|
gaps: List[str] # 知识缺口
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class InferencePath:
|
class InferencePath:
|
||||||
"""推理路径"""
|
"""推理路径"""
|
||||||
|
|
||||||
start_entity: str
|
start_entity: str
|
||||||
end_entity: str
|
end_entity: str
|
||||||
path: List[Dict] # 路径上的节点和关系
|
path: List[Dict] # 路径上的节点和关系
|
||||||
strength: float # 路径强度
|
strength: float # 路径强度
|
||||||
|
|
||||||
|
|
||||||
class KnowledgeReasoner:
|
class KnowledgeReasoner:
|
||||||
"""知识推理引擎"""
|
"""知识推理引擎"""
|
||||||
|
|
||||||
def __init__(self, api_key: str = None, base_url: str = None):
|
def __init__(self, api_key: str = None, base_url: str = None):
|
||||||
self.api_key = api_key or KIMI_API_KEY
|
self.api_key = api_key or KIMI_API_KEY
|
||||||
self.base_url = base_url or KIMI_BASE_URL
|
self.base_url = base_url or KIMI_BASE_URL
|
||||||
self.headers = {
|
self.headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
|
||||||
"Authorization": f"Bearer {self.api_key}",
|
|
||||||
"Content-Type": "application/json"
|
|
||||||
}
|
|
||||||
|
|
||||||
async def _call_llm(self, prompt: str, temperature: float = 0.3) -> str:
|
async def _call_llm(self, prompt: str, temperature: float = 0.3) -> str:
|
||||||
"""调用 LLM"""
|
"""调用 LLM"""
|
||||||
if not self.api_key:
|
if not self.api_key:
|
||||||
raise ValueError("KIMI_API_KEY not set")
|
raise ValueError("KIMI_API_KEY not set")
|
||||||
|
|
||||||
payload = {
|
payload = {"model": "k2p5", "messages": [{"role": "user", "content": prompt}], "temperature": temperature}
|
||||||
"model": "k2p5",
|
|
||||||
"messages": [{"role": "user", "content": prompt}],
|
|
||||||
"temperature": temperature
|
|
||||||
}
|
|
||||||
|
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
response = await client.post(
|
response = await client.post(
|
||||||
f"{self.base_url}/v1/chat/completions",
|
f"{self.base_url}/v1/chat/completions", headers=self.headers, json=payload, timeout=120.0
|
||||||
headers=self.headers,
|
|
||||||
json=payload,
|
|
||||||
timeout=120.0
|
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
result = response.json()
|
result = response.json()
|
||||||
return result["choices"][0]["message"]["content"]
|
return result["choices"][0]["message"]["content"]
|
||||||
|
|
||||||
async def enhanced_qa(
|
async def enhanced_qa(
|
||||||
self,
|
self, query: str, project_context: Dict, graph_data: Dict, reasoning_depth: str = "medium"
|
||||||
query: str,
|
|
||||||
project_context: Dict,
|
|
||||||
graph_data: Dict,
|
|
||||||
reasoning_depth: str = "medium"
|
|
||||||
) -> ReasoningResult:
|
) -> ReasoningResult:
|
||||||
"""
|
"""
|
||||||
增强问答 - 结合图谱推理的问答
|
增强问答 - 结合图谱推理的问答
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
query: 用户问题
|
query: 用户问题
|
||||||
project_context: 项目上下文
|
project_context: 项目上下文
|
||||||
@@ -95,7 +84,7 @@ class KnowledgeReasoner:
|
|||||||
"""
|
"""
|
||||||
# 1. 分析问题类型
|
# 1. 分析问题类型
|
||||||
analysis = await self._analyze_question(query)
|
analysis = await self._analyze_question(query)
|
||||||
|
|
||||||
# 2. 根据问题类型选择推理策略
|
# 2. 根据问题类型选择推理策略
|
||||||
if analysis["type"] == "causal":
|
if analysis["type"] == "causal":
|
||||||
return await self._causal_reasoning(query, project_context, graph_data)
|
return await self._causal_reasoning(query, project_context, graph_data)
|
||||||
@@ -105,7 +94,7 @@ class KnowledgeReasoner:
|
|||||||
return await self._temporal_reasoning(query, project_context, graph_data)
|
return await self._temporal_reasoning(query, project_context, graph_data)
|
||||||
else:
|
else:
|
||||||
return await self._associative_reasoning(query, project_context, graph_data)
|
return await self._associative_reasoning(query, project_context, graph_data)
|
||||||
|
|
||||||
async def _analyze_question(self, query: str) -> Dict:
|
async def _analyze_question(self, query: str) -> Dict:
|
||||||
"""分析问题类型和意图"""
|
"""分析问题类型和意图"""
|
||||||
prompt = f"""分析以下问题的类型和意图:
|
prompt = f"""分析以下问题的类型和意图:
|
||||||
@@ -126,31 +115,27 @@ class KnowledgeReasoner:
|
|||||||
- temporal: 时序类问题(什么时候、进度、变化)
|
- temporal: 时序类问题(什么时候、进度、变化)
|
||||||
- factual: 事实类问题(是什么、有哪些)
|
- factual: 事实类问题(是什么、有哪些)
|
||||||
- opinion: 观点类问题(怎么看、态度、评价)"""
|
- opinion: 观点类问题(怎么看、态度、评价)"""
|
||||||
|
|
||||||
content = await self._call_llm(prompt, temperature=0.1)
|
content = await self._call_llm(prompt, temperature=0.1)
|
||||||
|
|
||||||
import re
|
import re
|
||||||
json_match = re.search(r'\{{.*?\}}', content, re.DOTALL)
|
|
||||||
|
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
|
||||||
if json_match:
|
if json_match:
|
||||||
try:
|
try:
|
||||||
return json.loads(json_match.group())
|
return json.loads(json_match.group())
|
||||||
except:
|
except BaseException:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
return {"type": "factual", "entities": [], "intent": "general", "complexity": "simple"}
|
return {"type": "factual", "entities": [], "intent": "general", "complexity": "simple"}
|
||||||
|
|
||||||
async def _causal_reasoning(
|
async def _causal_reasoning(self, query: str, project_context: Dict, graph_data: Dict) -> ReasoningResult:
|
||||||
self,
|
|
||||||
query: str,
|
|
||||||
project_context: Dict,
|
|
||||||
graph_data: Dict
|
|
||||||
) -> ReasoningResult:
|
|
||||||
"""因果推理 - 分析原因和影响"""
|
"""因果推理 - 分析原因和影响"""
|
||||||
|
|
||||||
# 构建因果分析提示
|
# 构建因果分析提示
|
||||||
entities_str = json.dumps(graph_data.get("entities", []), ensure_ascii=False, indent=2)
|
entities_str = json.dumps(graph_data.get("entities", []), ensure_ascii=False, indent=2)
|
||||||
relations_str = json.dumps(graph_data.get("relations", []), ensure_ascii=False, indent=2)
|
relations_str = json.dumps(graph_data.get("relations", []), ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
prompt = f"""基于以下知识图谱进行因果推理分析:
|
prompt = f"""基于以下知识图谱进行因果推理分析:
|
||||||
|
|
||||||
## 问题
|
## 问题
|
||||||
@@ -175,12 +160,13 @@ class KnowledgeReasoner:
|
|||||||
"evidence": ["证据1", "证据2"],
|
"evidence": ["证据1", "证据2"],
|
||||||
"knowledge_gaps": ["缺失信息1"]
|
"knowledge_gaps": ["缺失信息1"]
|
||||||
}}"""
|
}}"""
|
||||||
|
|
||||||
content = await self._call_llm(prompt, temperature=0.3)
|
content = await self._call_llm(prompt, temperature=0.3)
|
||||||
|
|
||||||
import re
|
import re
|
||||||
json_match = re.search(r'\{{.*?\}}', content, re.DOTALL)
|
|
||||||
|
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
|
||||||
|
|
||||||
if json_match:
|
if json_match:
|
||||||
try:
|
try:
|
||||||
data = json.loads(json_match.group())
|
data = json.loads(json_match.group())
|
||||||
@@ -190,28 +176,23 @@ class KnowledgeReasoner:
|
|||||||
confidence=data.get("confidence", 0.7),
|
confidence=data.get("confidence", 0.7),
|
||||||
evidence=[{"text": e} for e in data.get("evidence", [])],
|
evidence=[{"text": e} for e in data.get("evidence", [])],
|
||||||
related_entities=[],
|
related_entities=[],
|
||||||
gaps=data.get("knowledge_gaps", [])
|
gaps=data.get("knowledge_gaps", []),
|
||||||
)
|
)
|
||||||
except:
|
except BaseException:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
return ReasoningResult(
|
return ReasoningResult(
|
||||||
answer=content,
|
answer=content,
|
||||||
reasoning_type=ReasoningType.CAUSAL,
|
reasoning_type=ReasoningType.CAUSAL,
|
||||||
confidence=0.5,
|
confidence=0.5,
|
||||||
evidence=[],
|
evidence=[],
|
||||||
related_entities=[],
|
related_entities=[],
|
||||||
gaps=["无法完成因果推理"]
|
gaps=["无法完成因果推理"],
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _comparative_reasoning(
|
async def _comparative_reasoning(self, query: str, project_context: Dict, graph_data: Dict) -> ReasoningResult:
|
||||||
self,
|
|
||||||
query: str,
|
|
||||||
project_context: Dict,
|
|
||||||
graph_data: Dict
|
|
||||||
) -> ReasoningResult:
|
|
||||||
"""对比推理 - 比较实体间的异同"""
|
"""对比推理 - 比较实体间的异同"""
|
||||||
|
|
||||||
prompt = f"""基于以下知识图谱进行对比分析:
|
prompt = f"""基于以下知识图谱进行对比分析:
|
||||||
|
|
||||||
## 问题
|
## 问题
|
||||||
@@ -233,12 +214,13 @@ class KnowledgeReasoner:
|
|||||||
"evidence": ["证据1"],
|
"evidence": ["证据1"],
|
||||||
"knowledge_gaps": []
|
"knowledge_gaps": []
|
||||||
}}"""
|
}}"""
|
||||||
|
|
||||||
content = await self._call_llm(prompt, temperature=0.3)
|
content = await self._call_llm(prompt, temperature=0.3)
|
||||||
|
|
||||||
import re
|
import re
|
||||||
json_match = re.search(r'\{{.*?\}}', content, re.DOTALL)
|
|
||||||
|
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
|
||||||
|
|
||||||
if json_match:
|
if json_match:
|
||||||
try:
|
try:
|
||||||
data = json.loads(json_match.group())
|
data = json.loads(json_match.group())
|
||||||
@@ -248,28 +230,23 @@ class KnowledgeReasoner:
|
|||||||
confidence=data.get("confidence", 0.7),
|
confidence=data.get("confidence", 0.7),
|
||||||
evidence=[{"text": e} for e in data.get("evidence", [])],
|
evidence=[{"text": e} for e in data.get("evidence", [])],
|
||||||
related_entities=[],
|
related_entities=[],
|
||||||
gaps=data.get("knowledge_gaps", [])
|
gaps=data.get("knowledge_gaps", []),
|
||||||
)
|
)
|
||||||
except:
|
except BaseException:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
return ReasoningResult(
|
return ReasoningResult(
|
||||||
answer=content,
|
answer=content,
|
||||||
reasoning_type=ReasoningType.COMPARATIVE,
|
reasoning_type=ReasoningType.COMPARATIVE,
|
||||||
confidence=0.5,
|
confidence=0.5,
|
||||||
evidence=[],
|
evidence=[],
|
||||||
related_entities=[],
|
related_entities=[],
|
||||||
gaps=[]
|
gaps=[],
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _temporal_reasoning(
|
async def _temporal_reasoning(self, query: str, project_context: Dict, graph_data: Dict) -> ReasoningResult:
|
||||||
self,
|
|
||||||
query: str,
|
|
||||||
project_context: Dict,
|
|
||||||
graph_data: Dict
|
|
||||||
) -> ReasoningResult:
|
|
||||||
"""时序推理 - 分析时间线和演变"""
|
"""时序推理 - 分析时间线和演变"""
|
||||||
|
|
||||||
prompt = f"""基于以下知识图谱进行时序分析:
|
prompt = f"""基于以下知识图谱进行时序分析:
|
||||||
|
|
||||||
## 问题
|
## 问题
|
||||||
@@ -291,12 +268,13 @@ class KnowledgeReasoner:
|
|||||||
"evidence": ["证据1"],
|
"evidence": ["证据1"],
|
||||||
"knowledge_gaps": []
|
"knowledge_gaps": []
|
||||||
}}"""
|
}}"""
|
||||||
|
|
||||||
content = await self._call_llm(prompt, temperature=0.3)
|
content = await self._call_llm(prompt, temperature=0.3)
|
||||||
|
|
||||||
import re
|
import re
|
||||||
json_match = re.search(r'\{{.*?\}}', content, re.DOTALL)
|
|
||||||
|
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
|
||||||
|
|
||||||
if json_match:
|
if json_match:
|
||||||
try:
|
try:
|
||||||
data = json.loads(json_match.group())
|
data = json.loads(json_match.group())
|
||||||
@@ -306,28 +284,23 @@ class KnowledgeReasoner:
|
|||||||
confidence=data.get("confidence", 0.7),
|
confidence=data.get("confidence", 0.7),
|
||||||
evidence=[{"text": e} for e in data.get("evidence", [])],
|
evidence=[{"text": e} for e in data.get("evidence", [])],
|
||||||
related_entities=[],
|
related_entities=[],
|
||||||
gaps=data.get("knowledge_gaps", [])
|
gaps=data.get("knowledge_gaps", []),
|
||||||
)
|
)
|
||||||
except:
|
except BaseException:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
return ReasoningResult(
|
return ReasoningResult(
|
||||||
answer=content,
|
answer=content,
|
||||||
reasoning_type=ReasoningType.TEMPORAL,
|
reasoning_type=ReasoningType.TEMPORAL,
|
||||||
confidence=0.5,
|
confidence=0.5,
|
||||||
evidence=[],
|
evidence=[],
|
||||||
related_entities=[],
|
related_entities=[],
|
||||||
gaps=[]
|
gaps=[],
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _associative_reasoning(
|
async def _associative_reasoning(self, query: str, project_context: Dict, graph_data: Dict) -> ReasoningResult:
|
||||||
self,
|
|
||||||
query: str,
|
|
||||||
project_context: Dict,
|
|
||||||
graph_data: Dict
|
|
||||||
) -> ReasoningResult:
|
|
||||||
"""关联推理 - 发现实体间的隐含关联"""
|
"""关联推理 - 发现实体间的隐含关联"""
|
||||||
|
|
||||||
prompt = f"""基于以下知识图谱进行关联分析:
|
prompt = f"""基于以下知识图谱进行关联分析:
|
||||||
|
|
||||||
## 问题
|
## 问题
|
||||||
@@ -349,12 +322,13 @@ class KnowledgeReasoner:
|
|||||||
"evidence": ["证据1"],
|
"evidence": ["证据1"],
|
||||||
"knowledge_gaps": []
|
"knowledge_gaps": []
|
||||||
}}"""
|
}}"""
|
||||||
|
|
||||||
content = await self._call_llm(prompt, temperature=0.4)
|
content = await self._call_llm(prompt, temperature=0.4)
|
||||||
|
|
||||||
import re
|
import re
|
||||||
json_match = re.search(r'\{{.*?\}}', content, re.DOTALL)
|
|
||||||
|
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
|
||||||
|
|
||||||
if json_match:
|
if json_match:
|
||||||
try:
|
try:
|
||||||
data = json.loads(json_match.group())
|
data = json.loads(json_match.group())
|
||||||
@@ -364,35 +338,31 @@ class KnowledgeReasoner:
|
|||||||
confidence=data.get("confidence", 0.7),
|
confidence=data.get("confidence", 0.7),
|
||||||
evidence=[{"text": e} for e in data.get("evidence", [])],
|
evidence=[{"text": e} for e in data.get("evidence", [])],
|
||||||
related_entities=[],
|
related_entities=[],
|
||||||
gaps=data.get("knowledge_gaps", [])
|
gaps=data.get("knowledge_gaps", []),
|
||||||
)
|
)
|
||||||
except:
|
except BaseException:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
return ReasoningResult(
|
return ReasoningResult(
|
||||||
answer=content,
|
answer=content,
|
||||||
reasoning_type=ReasoningType.ASSOCIATIVE,
|
reasoning_type=ReasoningType.ASSOCIATIVE,
|
||||||
confidence=0.5,
|
confidence=0.5,
|
||||||
evidence=[],
|
evidence=[],
|
||||||
related_entities=[],
|
related_entities=[],
|
||||||
gaps=[]
|
gaps=[],
|
||||||
)
|
)
|
||||||
|
|
||||||
def find_inference_paths(
|
def find_inference_paths(
|
||||||
self,
|
self, start_entity: str, end_entity: str, graph_data: Dict, max_depth: int = 3
|
||||||
start_entity: str,
|
|
||||||
end_entity: str,
|
|
||||||
graph_data: Dict,
|
|
||||||
max_depth: int = 3
|
|
||||||
) -> List[InferencePath]:
|
) -> List[InferencePath]:
|
||||||
"""
|
"""
|
||||||
发现两个实体之间的推理路径
|
发现两个实体之间的推理路径
|
||||||
|
|
||||||
使用 BFS 在关系图中搜索路径
|
使用 BFS 在关系图中搜索路径
|
||||||
"""
|
"""
|
||||||
entities = {e["id"]: e for e in graph_data.get("entities", [])}
|
entities = {e["id"]: e for e in graph_data.get("entities", [])}
|
||||||
relations = graph_data.get("relations", [])
|
relations = graph_data.get("relations", [])
|
||||||
|
|
||||||
# 构建邻接表
|
# 构建邻接表
|
||||||
adj = {}
|
adj = {}
|
||||||
for r in relations:
|
for r in relations:
|
||||||
@@ -405,51 +375,56 @@ class KnowledgeReasoner:
|
|||||||
adj[src].append({"target": tgt, "relation": r.get("type", "related"), "data": r})
|
adj[src].append({"target": tgt, "relation": r.get("type", "related"), "data": r})
|
||||||
# 无向图也添加反向
|
# 无向图也添加反向
|
||||||
adj[tgt].append({"target": src, "relation": r.get("type", "related"), "data": r, "reverse": True})
|
adj[tgt].append({"target": src, "relation": r.get("type", "related"), "data": r, "reverse": True})
|
||||||
|
|
||||||
# BFS 搜索路径
|
# BFS 搜索路径
|
||||||
from collections import deque
|
from collections import deque
|
||||||
|
|
||||||
paths = []
|
paths = []
|
||||||
queue = deque([(start_entity, [{"entity": start_entity, "relation": None}])])
|
queue = deque([(start_entity, [{"entity": start_entity, "relation": None}])])
|
||||||
visited = {start_entity}
|
{start_entity}
|
||||||
|
|
||||||
while queue and len(paths) < 5:
|
while queue and len(paths) < 5:
|
||||||
current, path = queue.popleft()
|
current, path = queue.popleft()
|
||||||
|
|
||||||
if current == end_entity and len(path) > 1:
|
if current == end_entity and len(path) > 1:
|
||||||
# 找到一条路径
|
# 找到一条路径
|
||||||
paths.append(InferencePath(
|
paths.append(
|
||||||
start_entity=start_entity,
|
InferencePath(
|
||||||
end_entity=end_entity,
|
start_entity=start_entity,
|
||||||
path=path,
|
end_entity=end_entity,
|
||||||
strength=self._calculate_path_strength(path)
|
path=path,
|
||||||
))
|
strength=self._calculate_path_strength(path),
|
||||||
|
)
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if len(path) >= max_depth:
|
if len(path) >= max_depth:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
for neighbor in adj.get(current, []):
|
for neighbor in adj.get(current, []):
|
||||||
next_entity = neighbor["target"]
|
next_entity = neighbor["target"]
|
||||||
if next_entity not in [p["entity"] for p in path]: # 避免循环
|
if next_entity not in [p["entity"] for p in path]: # 避免循环
|
||||||
new_path = path + [{
|
new_path = path + [
|
||||||
"entity": next_entity,
|
{
|
||||||
"relation": neighbor["relation"],
|
"entity": next_entity,
|
||||||
"relation_data": neighbor.get("data", {})
|
"relation": neighbor["relation"],
|
||||||
}]
|
"relation_data": neighbor.get("data", {}),
|
||||||
|
}
|
||||||
|
]
|
||||||
queue.append((next_entity, new_path))
|
queue.append((next_entity, new_path))
|
||||||
|
|
||||||
# 按强度排序
|
# 按强度排序
|
||||||
paths.sort(key=lambda p: p.strength, reverse=True)
|
paths.sort(key=lambda p: p.strength, reverse=True)
|
||||||
return paths
|
return paths
|
||||||
|
|
||||||
def _calculate_path_strength(self, path: List[Dict]) -> float:
|
def _calculate_path_strength(self, path: List[Dict]) -> float:
|
||||||
"""计算路径强度"""
|
"""计算路径强度"""
|
||||||
if len(path) < 2:
|
if len(path) < 2:
|
||||||
return 0.0
|
return 0.0
|
||||||
|
|
||||||
# 路径越短越强
|
# 路径越短越强
|
||||||
length_factor = 1.0 / len(path)
|
length_factor = 1.0 / len(path)
|
||||||
|
|
||||||
# 关系置信度
|
# 关系置信度
|
||||||
confidence_sum = 0
|
confidence_sum = 0
|
||||||
confidence_count = 0
|
confidence_count = 0
|
||||||
@@ -458,20 +433,17 @@ class KnowledgeReasoner:
|
|||||||
if "confidence" in rel_data:
|
if "confidence" in rel_data:
|
||||||
confidence_sum += rel_data["confidence"]
|
confidence_sum += rel_data["confidence"]
|
||||||
confidence_count += 1
|
confidence_count += 1
|
||||||
|
|
||||||
confidence_factor = (confidence_sum / confidence_count) if confidence_count > 0 else 0.5
|
confidence_factor = (confidence_sum / confidence_count) if confidence_count > 0 else 0.5
|
||||||
|
|
||||||
return length_factor * confidence_factor
|
return length_factor * confidence_factor
|
||||||
|
|
||||||
async def summarize_project(
|
async def summarize_project(
|
||||||
self,
|
self, project_context: Dict, graph_data: Dict, summary_type: str = "comprehensive"
|
||||||
project_context: Dict,
|
|
||||||
graph_data: Dict,
|
|
||||||
summary_type: str = "comprehensive"
|
|
||||||
) -> Dict:
|
) -> Dict:
|
||||||
"""
|
"""
|
||||||
项目智能总结
|
项目智能总结
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
summary_type: comprehensive/executive/technical/risk
|
summary_type: comprehensive/executive/technical/risk
|
||||||
"""
|
"""
|
||||||
@@ -479,9 +451,9 @@ class KnowledgeReasoner:
|
|||||||
"comprehensive": "全面总结项目的所有方面",
|
"comprehensive": "全面总结项目的所有方面",
|
||||||
"executive": "高管摘要,关注关键决策和风险",
|
"executive": "高管摘要,关注关键决策和风险",
|
||||||
"technical": "技术总结,关注架构和技术栈",
|
"technical": "技术总结,关注架构和技术栈",
|
||||||
"risk": "风险分析,关注潜在问题和依赖"
|
"risk": "风险分析,关注潜在问题和依赖",
|
||||||
}
|
}
|
||||||
|
|
||||||
prompt = f"""请对以下项目进行{type_prompts.get(summary_type, "全面总结")}:
|
prompt = f"""请对以下项目进行{type_prompts.get(summary_type, "全面总结")}:
|
||||||
|
|
||||||
## 项目信息
|
## 项目信息
|
||||||
@@ -500,25 +472,26 @@ class KnowledgeReasoner:
|
|||||||
"recommendations": ["建议1"],
|
"recommendations": ["建议1"],
|
||||||
"confidence": 0.85
|
"confidence": 0.85
|
||||||
}}"""
|
}}"""
|
||||||
|
|
||||||
content = await self._call_llm(prompt, temperature=0.3)
|
content = await self._call_llm(prompt, temperature=0.3)
|
||||||
|
|
||||||
import re
|
import re
|
||||||
json_match = re.search(r'\{{.*?\}}', content, re.DOTALL)
|
|
||||||
|
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
|
||||||
|
|
||||||
if json_match:
|
if json_match:
|
||||||
try:
|
try:
|
||||||
return json.loads(json_match.group())
|
return json.loads(json_match.group())
|
||||||
except:
|
except BaseException:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"overview": content,
|
"overview": content,
|
||||||
"key_points": [],
|
"key_points": [],
|
||||||
"key_entities": [],
|
"key_entities": [],
|
||||||
"risks": [],
|
"risks": [],
|
||||||
"recommendations": [],
|
"recommendations": [],
|
||||||
"confidence": 0.5
|
"confidence": 0.5,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -530,4 +503,4 @@ def get_knowledge_reasoner() -> KnowledgeReasoner:
|
|||||||
global _reasoner
|
global _reasoner
|
||||||
if _reasoner is None:
|
if _reasoner is None:
|
||||||
_reasoner = KnowledgeReasoner()
|
_reasoner = KnowledgeReasoner()
|
||||||
return _reasoner
|
return _reasoner
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ InsightFlow LLM Client - Phase 4
|
|||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
import httpx
|
import httpx
|
||||||
from typing import List, Dict, Optional, AsyncGenerator
|
from typing import List, Dict, AsyncGenerator
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
KIMI_API_KEY = os.getenv("KIMI_API_KEY", "")
|
KIMI_API_KEY = os.getenv("KIMI_API_KEY", "")
|
||||||
@@ -38,57 +38,47 @@ class RelationExtractionResult:
|
|||||||
|
|
||||||
class LLMClient:
|
class LLMClient:
|
||||||
"""Kimi API 客户端"""
|
"""Kimi API 客户端"""
|
||||||
|
|
||||||
def __init__(self, api_key: str = None, base_url: str = None):
|
def __init__(self, api_key: str = None, base_url: str = None):
|
||||||
self.api_key = api_key or KIMI_API_KEY
|
self.api_key = api_key or KIMI_API_KEY
|
||||||
self.base_url = base_url or KIMI_BASE_URL
|
self.base_url = base_url or KIMI_BASE_URL
|
||||||
self.headers = {
|
self.headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
|
||||||
"Authorization": f"Bearer {self.api_key}",
|
|
||||||
"Content-Type": "application/json"
|
|
||||||
}
|
|
||||||
|
|
||||||
async def chat(self, messages: List[ChatMessage], temperature: float = 0.3, stream: bool = False) -> str:
|
async def chat(self, messages: List[ChatMessage], temperature: float = 0.3, stream: bool = False) -> str:
|
||||||
"""发送聊天请求"""
|
"""发送聊天请求"""
|
||||||
if not self.api_key:
|
if not self.api_key:
|
||||||
raise ValueError("KIMI_API_KEY not set")
|
raise ValueError("KIMI_API_KEY not set")
|
||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
"model": "k2p5",
|
"model": "k2p5",
|
||||||
"messages": [{"role": m.role, "content": m.content} for m in messages],
|
"messages": [{"role": m.role, "content": m.content} for m in messages],
|
||||||
"temperature": temperature,
|
"temperature": temperature,
|
||||||
"stream": stream
|
"stream": stream,
|
||||||
}
|
}
|
||||||
|
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
response = await client.post(
|
response = await client.post(
|
||||||
f"{self.base_url}/v1/chat/completions",
|
f"{self.base_url}/v1/chat/completions", headers=self.headers, json=payload, timeout=120.0
|
||||||
headers=self.headers,
|
|
||||||
json=payload,
|
|
||||||
timeout=120.0
|
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
result = response.json()
|
result = response.json()
|
||||||
return result["choices"][0]["message"]["content"]
|
return result["choices"][0]["message"]["content"]
|
||||||
|
|
||||||
async def chat_stream(self, messages: List[ChatMessage], temperature: float = 0.3) -> AsyncGenerator[str, None]:
|
async def chat_stream(self, messages: List[ChatMessage], temperature: float = 0.3) -> AsyncGenerator[str, None]:
|
||||||
"""流式聊天请求"""
|
"""流式聊天请求"""
|
||||||
if not self.api_key:
|
if not self.api_key:
|
||||||
raise ValueError("KIMI_API_KEY not set")
|
raise ValueError("KIMI_API_KEY not set")
|
||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
"model": "k2p5",
|
"model": "k2p5",
|
||||||
"messages": [{"role": m.role, "content": m.content} for m in messages],
|
"messages": [{"role": m.role, "content": m.content} for m in messages],
|
||||||
"temperature": temperature,
|
"temperature": temperature,
|
||||||
"stream": True
|
"stream": True,
|
||||||
}
|
}
|
||||||
|
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
async with client.stream(
|
async with client.stream(
|
||||||
"POST",
|
"POST", f"{self.base_url}/v1/chat/completions", headers=self.headers, json=payload, timeout=120.0
|
||||||
f"{self.base_url}/v1/chat/completions",
|
|
||||||
headers=self.headers,
|
|
||||||
json=payload,
|
|
||||||
timeout=120.0
|
|
||||||
) as response:
|
) as response:
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
async for line in response.aiter_lines():
|
async for line in response.aiter_lines():
|
||||||
@@ -101,10 +91,12 @@ class LLMClient:
|
|||||||
delta = chunk["choices"][0]["delta"]
|
delta = chunk["choices"][0]["delta"]
|
||||||
if "content" in delta:
|
if "content" in delta:
|
||||||
yield delta["content"]
|
yield delta["content"]
|
||||||
except:
|
except BaseException:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def extract_entities_with_confidence(self, text: str) -> tuple[List[EntityExtractionResult], List[RelationExtractionResult]]:
|
async def extract_entities_with_confidence(
|
||||||
|
self, text: str
|
||||||
|
) -> tuple[List[EntityExtractionResult], List[RelationExtractionResult]]:
|
||||||
"""提取实体和关系,带置信度分数"""
|
"""提取实体和关系,带置信度分数"""
|
||||||
prompt = f"""从以下会议文本中提取关键实体和它们之间的关系,以 JSON 格式返回:
|
prompt = f"""从以下会议文本中提取关键实体和它们之间的关系,以 JSON 格式返回:
|
||||||
|
|
||||||
@@ -125,15 +117,16 @@ class LLMClient:
|
|||||||
{{"source": "Project Alpha", "target": "K8s", "type": "depends_on", "confidence": 0.82}}
|
{{"source": "Project Alpha", "target": "K8s", "type": "depends_on", "confidence": 0.82}}
|
||||||
]
|
]
|
||||||
}}"""
|
}}"""
|
||||||
|
|
||||||
messages = [ChatMessage(role="user", content=prompt)]
|
messages = [ChatMessage(role="user", content=prompt)]
|
||||||
content = await self.chat(messages, temperature=0.1)
|
content = await self.chat(messages, temperature=0.1)
|
||||||
|
|
||||||
import re
|
import re
|
||||||
json_match = re.search(r'\{{.*?\}}', content, re.DOTALL)
|
|
||||||
|
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
|
||||||
if not json_match:
|
if not json_match:
|
||||||
return [], []
|
return [], []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
data = json.loads(json_match.group())
|
data = json.loads(json_match.group())
|
||||||
entities = [
|
entities = [
|
||||||
@@ -141,7 +134,7 @@ class LLMClient:
|
|||||||
name=e["name"],
|
name=e["name"],
|
||||||
type=e.get("type", "OTHER"),
|
type=e.get("type", "OTHER"),
|
||||||
definition=e.get("definition", ""),
|
definition=e.get("definition", ""),
|
||||||
confidence=e.get("confidence", 0.8)
|
confidence=e.get("confidence", 0.8),
|
||||||
)
|
)
|
||||||
for e in data.get("entities", [])
|
for e in data.get("entities", [])
|
||||||
]
|
]
|
||||||
@@ -150,7 +143,7 @@ class LLMClient:
|
|||||||
source=r["source"],
|
source=r["source"],
|
||||||
target=r["target"],
|
target=r["target"],
|
||||||
type=r.get("type", "related"),
|
type=r.get("type", "related"),
|
||||||
confidence=r.get("confidence", 0.8)
|
confidence=r.get("confidence", 0.8),
|
||||||
)
|
)
|
||||||
for r in data.get("relations", [])
|
for r in data.get("relations", [])
|
||||||
]
|
]
|
||||||
@@ -158,7 +151,7 @@ class LLMClient:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Parse extraction result failed: {e}")
|
print(f"Parse extraction result failed: {e}")
|
||||||
return [], []
|
return [], []
|
||||||
|
|
||||||
async def rag_query(self, query: str, context: str, project_context: Dict) -> str:
|
async def rag_query(self, query: str, context: str, project_context: Dict) -> str:
|
||||||
"""RAG 问答 - 基于项目上下文回答问题"""
|
"""RAG 问答 - 基于项目上下文回答问题"""
|
||||||
prompt = f"""你是一个专业的项目分析助手。基于以下项目信息回答问题:
|
prompt = f"""你是一个专业的项目分析助手。基于以下项目信息回答问题:
|
||||||
@@ -173,14 +166,14 @@ class LLMClient:
|
|||||||
{query}
|
{query}
|
||||||
|
|
||||||
请用中文回答,保持简洁专业。如果信息不足,请明确说明。"""
|
请用中文回答,保持简洁专业。如果信息不足,请明确说明。"""
|
||||||
|
|
||||||
messages = [
|
messages = [
|
||||||
ChatMessage(role="system", content="你是一个专业的项目分析助手,擅长从会议记录中提取洞察。"),
|
ChatMessage(role="system", content="你是一个专业的项目分析助手,擅长从会议记录中提取洞察。"),
|
||||||
ChatMessage(role="user", content=prompt)
|
ChatMessage(role="user", content=prompt),
|
||||||
]
|
]
|
||||||
|
|
||||||
return await self.chat(messages, temperature=0.3)
|
return await self.chat(messages, temperature=0.3)
|
||||||
|
|
||||||
async def agent_command(self, command: str, project_context: Dict) -> Dict:
|
async def agent_command(self, command: str, project_context: Dict) -> Dict:
|
||||||
"""Agent 指令解析 - 将自然语言指令转换为结构化操作"""
|
"""Agent 指令解析 - 将自然语言指令转换为结构化操作"""
|
||||||
prompt = f"""解析以下用户指令,转换为结构化操作:
|
prompt = f"""解析以下用户指令,转换为结构化操作:
|
||||||
@@ -206,27 +199,27 @@ class LLMClient:
|
|||||||
- edit_entity: 编辑实体,params 包含 entity_name(实体名), field(字段), value(新值)
|
- edit_entity: 编辑实体,params 包含 entity_name(实体名), field(字段), value(新值)
|
||||||
- create_relation: 创建关系,params 包含 source(源实体), target(目标实体), relation_type(关系类型)
|
- create_relation: 创建关系,params 包含 source(源实体), target(目标实体), relation_type(关系类型)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
messages = [ChatMessage(role="user", content=prompt)]
|
messages = [ChatMessage(role="user", content=prompt)]
|
||||||
content = await self.chat(messages, temperature=0.1)
|
content = await self.chat(messages, temperature=0.1)
|
||||||
|
|
||||||
import re
|
import re
|
||||||
json_match = re.search(r'\{{.*?\}}', content, re.DOTALL)
|
|
||||||
|
json_match = re.search(r"\{{.*?\}}", content, re.DOTALL)
|
||||||
if not json_match:
|
if not json_match:
|
||||||
return {"intent": "unknown", "explanation": "无法解析指令"}
|
return {"intent": "unknown", "explanation": "无法解析指令"}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return json.loads(json_match.group())
|
return json.loads(json_match.group())
|
||||||
except:
|
except BaseException:
|
||||||
return {"intent": "unknown", "explanation": "解析失败"}
|
return {"intent": "unknown", "explanation": "解析失败"}
|
||||||
|
|
||||||
async def analyze_entity_evolution(self, entity_name: str, mentions: List[Dict]) -> str:
|
async def analyze_entity_evolution(self, entity_name: str, mentions: List[Dict]) -> str:
|
||||||
"""分析实体在项目中的演变/态度变化"""
|
"""分析实体在项目中的演变/态度变化"""
|
||||||
mentions_text = "\n".join([
|
mentions_text = "\n".join(
|
||||||
f"[{m.get('created_at', '未知时间')}] {m.get('text_snippet', '')}"
|
[f"[{m.get('created_at', '未知时间')}] {m.get('text_snippet', '')}" for m in mentions[:20]] # 限制数量
|
||||||
for m in mentions[:20] # 限制数量
|
)
|
||||||
])
|
|
||||||
|
|
||||||
prompt = f"""分析实体 "{entity_name}" 在项目中的演变和态度变化:
|
prompt = f"""分析实体 "{entity_name}" 在项目中的演变和态度变化:
|
||||||
|
|
||||||
## 提及记录
|
## 提及记录
|
||||||
@@ -239,7 +232,7 @@ class LLMClient:
|
|||||||
4. 总结性洞察
|
4. 总结性洞察
|
||||||
|
|
||||||
用中文回答,结构清晰。"""
|
用中文回答,结构清晰。"""
|
||||||
|
|
||||||
messages = [ChatMessage(role="user", content=prompt)]
|
messages = [ChatMessage(role="user", content=prompt)]
|
||||||
return await self.chat(messages, temperature=0.3)
|
return await self.chat(messages, temperature=0.3)
|
||||||
|
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
3373
backend/main.py
3373
backend/main.py
File diff suppressed because it is too large
Load Diff
@@ -4,8 +4,6 @@ InsightFlow Multimodal Entity Linker - Phase 7
|
|||||||
多模态实体关联模块:跨模态实体对齐和知识融合
|
多模态实体关联模块:跨模态实体对齐和知识融合
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
|
||||||
import json
|
|
||||||
import uuid
|
import uuid
|
||||||
from typing import List, Dict, Optional, Tuple, Set
|
from typing import List, Dict, Optional, Tuple, Set
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
@@ -13,7 +11,6 @@ from difflib import SequenceMatcher
|
|||||||
|
|
||||||
# 尝试导入embedding库
|
# 尝试导入embedding库
|
||||||
try:
|
try:
|
||||||
import numpy as np
|
|
||||||
NUMPY_AVAILABLE = True
|
NUMPY_AVAILABLE = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
NUMPY_AVAILABLE = False
|
NUMPY_AVAILABLE = False
|
||||||
@@ -22,6 +19,7 @@ except ImportError:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class MultimodalEntity:
|
class MultimodalEntity:
|
||||||
"""多模态实体"""
|
"""多模态实体"""
|
||||||
|
|
||||||
id: str
|
id: str
|
||||||
entity_id: str
|
entity_id: str
|
||||||
project_id: str
|
project_id: str
|
||||||
@@ -31,7 +29,7 @@ class MultimodalEntity:
|
|||||||
mention_context: str
|
mention_context: str
|
||||||
confidence: float
|
confidence: float
|
||||||
modality_features: Dict = None # 模态特定特征
|
modality_features: Dict = None # 模态特定特征
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if self.modality_features is None:
|
if self.modality_features is None:
|
||||||
self.modality_features = {}
|
self.modality_features = {}
|
||||||
@@ -40,6 +38,7 @@ class MultimodalEntity:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class EntityLink:
|
class EntityLink:
|
||||||
"""实体关联"""
|
"""实体关联"""
|
||||||
|
|
||||||
id: str
|
id: str
|
||||||
project_id: str
|
project_id: str
|
||||||
source_entity_id: str
|
source_entity_id: str
|
||||||
@@ -54,6 +53,7 @@ class EntityLink:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class AlignmentResult:
|
class AlignmentResult:
|
||||||
"""对齐结果"""
|
"""对齐结果"""
|
||||||
|
|
||||||
entity_id: str
|
entity_id: str
|
||||||
matched_entity_id: Optional[str]
|
matched_entity_id: Optional[str]
|
||||||
similarity: float
|
similarity: float
|
||||||
@@ -64,6 +64,7 @@ class AlignmentResult:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class FusionResult:
|
class FusionResult:
|
||||||
"""知识融合结果"""
|
"""知识融合结果"""
|
||||||
|
|
||||||
canonical_entity_id: str
|
canonical_entity_id: str
|
||||||
merged_entity_ids: List[str]
|
merged_entity_ids: List[str]
|
||||||
fused_properties: Dict
|
fused_properties: Dict
|
||||||
@@ -73,300 +74,290 @@ class FusionResult:
|
|||||||
|
|
||||||
class MultimodalEntityLinker:
|
class MultimodalEntityLinker:
|
||||||
"""多模态实体关联器 - 跨模态实体对齐和知识融合"""
|
"""多模态实体关联器 - 跨模态实体对齐和知识融合"""
|
||||||
|
|
||||||
# 关联类型
|
# 关联类型
|
||||||
LINK_TYPES = {
|
LINK_TYPES = {"same_as": "同一实体", "related_to": "相关实体", "part_of": "组成部分", "mentions": "提及关系"}
|
||||||
'same_as': '同一实体',
|
|
||||||
'related_to': '相关实体',
|
|
||||||
'part_of': '组成部分',
|
|
||||||
'mentions': '提及关系'
|
|
||||||
}
|
|
||||||
|
|
||||||
# 模态类型
|
# 模态类型
|
||||||
MODALITIES = ['audio', 'video', 'image', 'document']
|
MODALITIES = ["audio", "video", "image", "document"]
|
||||||
|
|
||||||
def __init__(self, similarity_threshold: float = 0.85):
|
def __init__(self, similarity_threshold: float = 0.85):
|
||||||
"""
|
"""
|
||||||
初始化多模态实体关联器
|
初始化多模态实体关联器
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
similarity_threshold: 相似度阈值
|
similarity_threshold: 相似度阈值
|
||||||
"""
|
"""
|
||||||
self.similarity_threshold = similarity_threshold
|
self.similarity_threshold = similarity_threshold
|
||||||
|
|
||||||
def calculate_string_similarity(self, s1: str, s2: str) -> float:
|
def calculate_string_similarity(self, s1: str, s2: str) -> float:
|
||||||
"""
|
"""
|
||||||
计算字符串相似度
|
计算字符串相似度
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
s1: 字符串1
|
s1: 字符串1
|
||||||
s2: 字符串2
|
s2: 字符串2
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
相似度分数 (0-1)
|
相似度分数 (0-1)
|
||||||
"""
|
"""
|
||||||
if not s1 or not s2:
|
if not s1 or not s2:
|
||||||
return 0.0
|
return 0.0
|
||||||
|
|
||||||
s1, s2 = s1.lower().strip(), s2.lower().strip()
|
s1, s2 = s1.lower().strip(), s2.lower().strip()
|
||||||
|
|
||||||
# 完全匹配
|
# 完全匹配
|
||||||
if s1 == s2:
|
if s1 == s2:
|
||||||
return 1.0
|
return 1.0
|
||||||
|
|
||||||
# 包含关系
|
# 包含关系
|
||||||
if s1 in s2 or s2 in s1:
|
if s1 in s2 or s2 in s1:
|
||||||
return 0.9
|
return 0.9
|
||||||
|
|
||||||
# 编辑距离相似度
|
# 编辑距离相似度
|
||||||
return SequenceMatcher(None, s1, s2).ratio()
|
return SequenceMatcher(None, s1, s2).ratio()
|
||||||
|
|
||||||
def calculate_entity_similarity(self, entity1: Dict, entity2: Dict) -> Tuple[float, str]:
|
def calculate_entity_similarity(self, entity1: Dict, entity2: Dict) -> Tuple[float, str]:
|
||||||
"""
|
"""
|
||||||
计算两个实体的综合相似度
|
计算两个实体的综合相似度
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
entity1: 实体1信息
|
entity1: 实体1信息
|
||||||
entity2: 实体2信息
|
entity2: 实体2信息
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(相似度, 匹配类型)
|
(相似度, 匹配类型)
|
||||||
"""
|
"""
|
||||||
# 名称相似度
|
# 名称相似度
|
||||||
name_sim = self.calculate_string_similarity(
|
name_sim = self.calculate_string_similarity(entity1.get("name", ""), entity2.get("name", ""))
|
||||||
entity1.get('name', ''),
|
|
||||||
entity2.get('name', '')
|
|
||||||
)
|
|
||||||
|
|
||||||
# 如果名称完全匹配
|
# 如果名称完全匹配
|
||||||
if name_sim == 1.0:
|
if name_sim == 1.0:
|
||||||
return 1.0, 'exact'
|
return 1.0, "exact"
|
||||||
|
|
||||||
# 检查别名
|
# 检查别名
|
||||||
aliases1 = set(a.lower() for a in entity1.get('aliases', []))
|
aliases1 = set(a.lower() for a in entity1.get("aliases", []))
|
||||||
aliases2 = set(a.lower() for a in entity2.get('aliases', []))
|
aliases2 = set(a.lower() for a in entity2.get("aliases", []))
|
||||||
|
|
||||||
if aliases1 & aliases2: # 有共同别名
|
if aliases1 & aliases2: # 有共同别名
|
||||||
return 0.95, 'alias_match'
|
return 0.95, "alias_match"
|
||||||
|
|
||||||
if entity2.get('name', '').lower() in aliases1:
|
if entity2.get("name", "").lower() in aliases1:
|
||||||
return 0.95, 'alias_match'
|
return 0.95, "alias_match"
|
||||||
if entity1.get('name', '').lower() in aliases2:
|
if entity1.get("name", "").lower() in aliases2:
|
||||||
return 0.95, 'alias_match'
|
return 0.95, "alias_match"
|
||||||
|
|
||||||
# 定义相似度
|
# 定义相似度
|
||||||
def_sim = self.calculate_string_similarity(
|
def_sim = self.calculate_string_similarity(entity1.get("definition", ""), entity2.get("definition", ""))
|
||||||
entity1.get('definition', ''),
|
|
||||||
entity2.get('definition', '')
|
|
||||||
)
|
|
||||||
|
|
||||||
# 综合相似度
|
# 综合相似度
|
||||||
combined_sim = name_sim * 0.7 + def_sim * 0.3
|
combined_sim = name_sim * 0.7 + def_sim * 0.3
|
||||||
|
|
||||||
if combined_sim >= self.similarity_threshold:
|
if combined_sim >= self.similarity_threshold:
|
||||||
return combined_sim, 'fuzzy'
|
return combined_sim, "fuzzy"
|
||||||
|
|
||||||
return combined_sim, 'none'
|
return combined_sim, "none"
|
||||||
|
|
||||||
def find_matching_entity(self, query_entity: Dict,
|
def find_matching_entity(
|
||||||
candidate_entities: List[Dict],
|
self, query_entity: Dict, candidate_entities: List[Dict], exclude_ids: Set[str] = None
|
||||||
exclude_ids: Set[str] = None) -> Optional[AlignmentResult]:
|
) -> Optional[AlignmentResult]:
|
||||||
"""
|
"""
|
||||||
在候选实体中查找匹配的实体
|
在候选实体中查找匹配的实体
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
query_entity: 查询实体
|
query_entity: 查询实体
|
||||||
candidate_entities: 候选实体列表
|
candidate_entities: 候选实体列表
|
||||||
exclude_ids: 排除的实体ID
|
exclude_ids: 排除的实体ID
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
对齐结果
|
对齐结果
|
||||||
"""
|
"""
|
||||||
exclude_ids = exclude_ids or set()
|
exclude_ids = exclude_ids or set()
|
||||||
best_match = None
|
best_match = None
|
||||||
best_similarity = 0.0
|
best_similarity = 0.0
|
||||||
|
|
||||||
for candidate in candidate_entities:
|
for candidate in candidate_entities:
|
||||||
if candidate.get('id') in exclude_ids:
|
if candidate.get("id") in exclude_ids:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
similarity, match_type = self.calculate_entity_similarity(
|
similarity, match_type = self.calculate_entity_similarity(query_entity, candidate)
|
||||||
query_entity, candidate
|
|
||||||
)
|
|
||||||
|
|
||||||
if similarity > best_similarity and similarity >= self.similarity_threshold:
|
if similarity > best_similarity and similarity >= self.similarity_threshold:
|
||||||
best_similarity = similarity
|
best_similarity = similarity
|
||||||
best_match = candidate
|
best_match = candidate
|
||||||
best_match_type = match_type
|
best_match_type = match_type
|
||||||
|
|
||||||
if best_match:
|
if best_match:
|
||||||
return AlignmentResult(
|
return AlignmentResult(
|
||||||
entity_id=query_entity.get('id'),
|
entity_id=query_entity.get("id"),
|
||||||
matched_entity_id=best_match.get('id'),
|
matched_entity_id=best_match.get("id"),
|
||||||
similarity=best_similarity,
|
similarity=best_similarity,
|
||||||
match_type=best_match_type,
|
match_type=best_match_type,
|
||||||
confidence=best_similarity
|
confidence=best_similarity,
|
||||||
)
|
)
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def align_cross_modal_entities(self, project_id: str,
|
def align_cross_modal_entities(
|
||||||
audio_entities: List[Dict],
|
self,
|
||||||
video_entities: List[Dict],
|
project_id: str,
|
||||||
image_entities: List[Dict],
|
audio_entities: List[Dict],
|
||||||
document_entities: List[Dict]) -> List[EntityLink]:
|
video_entities: List[Dict],
|
||||||
|
image_entities: List[Dict],
|
||||||
|
document_entities: List[Dict],
|
||||||
|
) -> List[EntityLink]:
|
||||||
"""
|
"""
|
||||||
跨模态实体对齐
|
跨模态实体对齐
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
project_id: 项目ID
|
project_id: 项目ID
|
||||||
audio_entities: 音频模态实体
|
audio_entities: 音频模态实体
|
||||||
video_entities: 视频模态实体
|
video_entities: 视频模态实体
|
||||||
image_entities: 图片模态实体
|
image_entities: 图片模态实体
|
||||||
document_entities: 文档模态实体
|
document_entities: 文档模态实体
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
实体关联列表
|
实体关联列表
|
||||||
"""
|
"""
|
||||||
links = []
|
links = []
|
||||||
|
|
||||||
# 合并所有实体
|
# 合并所有实体
|
||||||
all_entities = {
|
all_entities = {
|
||||||
'audio': audio_entities,
|
"audio": audio_entities,
|
||||||
'video': video_entities,
|
"video": video_entities,
|
||||||
'image': image_entities,
|
"image": image_entities,
|
||||||
'document': document_entities
|
"document": document_entities,
|
||||||
}
|
}
|
||||||
|
|
||||||
# 跨模态对齐
|
# 跨模态对齐
|
||||||
for mod1 in self.MODALITIES:
|
for mod1 in self.MODALITIES:
|
||||||
for mod2 in self.MODALITIES:
|
for mod2 in self.MODALITIES:
|
||||||
if mod1 >= mod2: # 避免重复比较
|
if mod1 >= mod2: # 避免重复比较
|
||||||
continue
|
continue
|
||||||
|
|
||||||
entities1 = all_entities.get(mod1, [])
|
entities1 = all_entities.get(mod1, [])
|
||||||
entities2 = all_entities.get(mod2, [])
|
entities2 = all_entities.get(mod2, [])
|
||||||
|
|
||||||
for ent1 in entities1:
|
for ent1 in entities1:
|
||||||
# 在另一个模态中查找匹配
|
# 在另一个模态中查找匹配
|
||||||
result = self.find_matching_entity(ent1, entities2)
|
result = self.find_matching_entity(ent1, entities2)
|
||||||
|
|
||||||
if result and result.matched_entity_id:
|
if result and result.matched_entity_id:
|
||||||
link = EntityLink(
|
link = EntityLink(
|
||||||
id=str(uuid.uuid4())[:8],
|
id=str(uuid.uuid4())[:8],
|
||||||
project_id=project_id,
|
project_id=project_id,
|
||||||
source_entity_id=ent1.get('id'),
|
source_entity_id=ent1.get("id"),
|
||||||
target_entity_id=result.matched_entity_id,
|
target_entity_id=result.matched_entity_id,
|
||||||
link_type='same_as' if result.similarity > 0.95 else 'related_to',
|
link_type="same_as" if result.similarity > 0.95 else "related_to",
|
||||||
source_modality=mod1,
|
source_modality=mod1,
|
||||||
target_modality=mod2,
|
target_modality=mod2,
|
||||||
confidence=result.confidence,
|
confidence=result.confidence,
|
||||||
evidence=f"Cross-modal alignment: {result.match_type}"
|
evidence=f"Cross-modal alignment: {result.match_type}",
|
||||||
)
|
)
|
||||||
links.append(link)
|
links.append(link)
|
||||||
|
|
||||||
return links
|
return links
|
||||||
|
|
||||||
def fuse_entity_knowledge(self, entity_id: str,
|
def fuse_entity_knowledge(
|
||||||
linked_entities: List[Dict],
|
self, entity_id: str, linked_entities: List[Dict], multimodal_mentions: List[Dict]
|
||||||
multimodal_mentions: List[Dict]) -> FusionResult:
|
) -> FusionResult:
|
||||||
"""
|
"""
|
||||||
融合多模态实体知识
|
融合多模态实体知识
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
entity_id: 主实体ID
|
entity_id: 主实体ID
|
||||||
linked_entities: 关联的实体信息列表
|
linked_entities: 关联的实体信息列表
|
||||||
multimodal_mentions: 多模态提及列表
|
multimodal_mentions: 多模态提及列表
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
融合结果
|
融合结果
|
||||||
"""
|
"""
|
||||||
# 收集所有属性
|
# 收集所有属性
|
||||||
fused_properties = {
|
fused_properties = {
|
||||||
'names': set(),
|
"names": set(),
|
||||||
'definitions': [],
|
"definitions": [],
|
||||||
'aliases': set(),
|
"aliases": set(),
|
||||||
'types': set(),
|
"types": set(),
|
||||||
'modalities': set(),
|
"modalities": set(),
|
||||||
'contexts': []
|
"contexts": [],
|
||||||
}
|
}
|
||||||
|
|
||||||
merged_ids = []
|
merged_ids = []
|
||||||
|
|
||||||
for entity in linked_entities:
|
for entity in linked_entities:
|
||||||
merged_ids.append(entity.get('id'))
|
merged_ids.append(entity.get("id"))
|
||||||
|
|
||||||
# 收集名称
|
# 收集名称
|
||||||
fused_properties['names'].add(entity.get('name', ''))
|
fused_properties["names"].add(entity.get("name", ""))
|
||||||
|
|
||||||
# 收集定义
|
# 收集定义
|
||||||
if entity.get('definition'):
|
if entity.get("definition"):
|
||||||
fused_properties['definitions'].append(entity.get('definition'))
|
fused_properties["definitions"].append(entity.get("definition"))
|
||||||
|
|
||||||
# 收集别名
|
# 收集别名
|
||||||
fused_properties['aliases'].update(entity.get('aliases', []))
|
fused_properties["aliases"].update(entity.get("aliases", []))
|
||||||
|
|
||||||
# 收集类型
|
# 收集类型
|
||||||
fused_properties['types'].add(entity.get('type', 'OTHER'))
|
fused_properties["types"].add(entity.get("type", "OTHER"))
|
||||||
|
|
||||||
# 收集模态和上下文
|
# 收集模态和上下文
|
||||||
for mention in multimodal_mentions:
|
for mention in multimodal_mentions:
|
||||||
fused_properties['modalities'].add(mention.get('source_type', ''))
|
fused_properties["modalities"].add(mention.get("source_type", ""))
|
||||||
if mention.get('mention_context'):
|
if mention.get("mention_context"):
|
||||||
fused_properties['contexts'].append(mention.get('mention_context'))
|
fused_properties["contexts"].append(mention.get("mention_context"))
|
||||||
|
|
||||||
# 选择最佳定义(最长的那个)
|
# 选择最佳定义(最长的那个)
|
||||||
best_definition = max(fused_properties['definitions'], key=len) \
|
best_definition = max(fused_properties["definitions"], key=len) if fused_properties["definitions"] else ""
|
||||||
if fused_properties['definitions'] else ""
|
|
||||||
|
|
||||||
# 选择最佳名称(最常见的那个)
|
# 选择最佳名称(最常见的那个)
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
name_counts = Counter(fused_properties['names'])
|
|
||||||
|
name_counts = Counter(fused_properties["names"])
|
||||||
best_name = name_counts.most_common(1)[0][0] if name_counts else ""
|
best_name = name_counts.most_common(1)[0][0] if name_counts else ""
|
||||||
|
|
||||||
# 构建融合结果
|
# 构建融合结果
|
||||||
return FusionResult(
|
return FusionResult(
|
||||||
canonical_entity_id=entity_id,
|
canonical_entity_id=entity_id,
|
||||||
merged_entity_ids=merged_ids,
|
merged_entity_ids=merged_ids,
|
||||||
fused_properties={
|
fused_properties={
|
||||||
'name': best_name,
|
"name": best_name,
|
||||||
'definition': best_definition,
|
"definition": best_definition,
|
||||||
'aliases': list(fused_properties['aliases']),
|
"aliases": list(fused_properties["aliases"]),
|
||||||
'types': list(fused_properties['types']),
|
"types": list(fused_properties["types"]),
|
||||||
'modalities': list(fused_properties['modalities']),
|
"modalities": list(fused_properties["modalities"]),
|
||||||
'contexts': fused_properties['contexts'][:10] # 最多10个上下文
|
"contexts": fused_properties["contexts"][:10], # 最多10个上下文
|
||||||
},
|
},
|
||||||
source_modalities=list(fused_properties['modalities']),
|
source_modalities=list(fused_properties["modalities"]),
|
||||||
confidence=min(1.0, len(linked_entities) * 0.2 + 0.5)
|
confidence=min(1.0, len(linked_entities) * 0.2 + 0.5),
|
||||||
)
|
)
|
||||||
|
|
||||||
def detect_entity_conflicts(self, entities: List[Dict]) -> List[Dict]:
|
def detect_entity_conflicts(self, entities: List[Dict]) -> List[Dict]:
|
||||||
"""
|
"""
|
||||||
检测实体冲突(同名但不同义)
|
检测实体冲突(同名但不同义)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
entities: 实体列表
|
entities: 实体列表
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
冲突列表
|
冲突列表
|
||||||
"""
|
"""
|
||||||
conflicts = []
|
conflicts = []
|
||||||
|
|
||||||
# 按名称分组
|
# 按名称分组
|
||||||
name_groups = {}
|
name_groups = {}
|
||||||
for entity in entities:
|
for entity in entities:
|
||||||
name = entity.get('name', '').lower()
|
name = entity.get("name", "").lower()
|
||||||
if name:
|
if name:
|
||||||
if name not in name_groups:
|
if name not in name_groups:
|
||||||
name_groups[name] = []
|
name_groups[name] = []
|
||||||
name_groups[name].append(entity)
|
name_groups[name].append(entity)
|
||||||
|
|
||||||
# 检测同名但定义不同的实体
|
# 检测同名但定义不同的实体
|
||||||
for name, group in name_groups.items():
|
for name, group in name_groups.items():
|
||||||
if len(group) > 1:
|
if len(group) > 1:
|
||||||
# 检查定义是否相似
|
# 检查定义是否相似
|
||||||
definitions = [e.get('definition', '') for e in group if e.get('definition')]
|
definitions = [e.get("definition", "") for e in group if e.get("definition")]
|
||||||
|
|
||||||
if len(definitions) > 1:
|
if len(definitions) > 1:
|
||||||
# 计算定义之间的相似度
|
# 计算定义之间的相似度
|
||||||
sim_matrix = []
|
sim_matrix = []
|
||||||
@@ -375,76 +366,82 @@ class MultimodalEntityLinker:
|
|||||||
if i < j:
|
if i < j:
|
||||||
sim = self.calculate_string_similarity(d1, d2)
|
sim = self.calculate_string_similarity(d1, d2)
|
||||||
sim_matrix.append(sim)
|
sim_matrix.append(sim)
|
||||||
|
|
||||||
# 如果定义相似度都很低,可能是冲突
|
# 如果定义相似度都很低,可能是冲突
|
||||||
if sim_matrix and all(s < 0.5 for s in sim_matrix):
|
if sim_matrix and all(s < 0.5 for s in sim_matrix):
|
||||||
conflicts.append({
|
conflicts.append(
|
||||||
'name': name,
|
{
|
||||||
'entities': group,
|
"name": name,
|
||||||
'type': 'homonym_conflict',
|
"entities": group,
|
||||||
'suggestion': 'Consider disambiguating these entities'
|
"type": "homonym_conflict",
|
||||||
})
|
"suggestion": "Consider disambiguating these entities",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
return conflicts
|
return conflicts
|
||||||
|
|
||||||
def suggest_entity_merges(self, entities: List[Dict],
|
def suggest_entity_merges(self, entities: List[Dict], existing_links: List[EntityLink] = None) -> List[Dict]:
|
||||||
existing_links: List[EntityLink] = None) -> List[Dict]:
|
|
||||||
"""
|
"""
|
||||||
建议实体合并
|
建议实体合并
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
entities: 实体列表
|
entities: 实体列表
|
||||||
existing_links: 现有实体关联
|
existing_links: 现有实体关联
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
合并建议列表
|
合并建议列表
|
||||||
"""
|
"""
|
||||||
suggestions = []
|
suggestions = []
|
||||||
existing_pairs = set()
|
existing_pairs = set()
|
||||||
|
|
||||||
# 记录已有的关联
|
# 记录已有的关联
|
||||||
if existing_links:
|
if existing_links:
|
||||||
for link in existing_links:
|
for link in existing_links:
|
||||||
pair = tuple(sorted([link.source_entity_id, link.target_entity_id]))
|
pair = tuple(sorted([link.source_entity_id, link.target_entity_id]))
|
||||||
existing_pairs.add(pair)
|
existing_pairs.add(pair)
|
||||||
|
|
||||||
# 检查所有实体对
|
# 检查所有实体对
|
||||||
for i, ent1 in enumerate(entities):
|
for i, ent1 in enumerate(entities):
|
||||||
for j, ent2 in enumerate(entities):
|
for j, ent2 in enumerate(entities):
|
||||||
if i >= j:
|
if i >= j:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 检查是否已有关联
|
# 检查是否已有关联
|
||||||
pair = tuple(sorted([ent1.get('id'), ent2.get('id')]))
|
pair = tuple(sorted([ent1.get("id"), ent2.get("id")]))
|
||||||
if pair in existing_pairs:
|
if pair in existing_pairs:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 计算相似度
|
# 计算相似度
|
||||||
similarity, match_type = self.calculate_entity_similarity(ent1, ent2)
|
similarity, match_type = self.calculate_entity_similarity(ent1, ent2)
|
||||||
|
|
||||||
if similarity >= self.similarity_threshold:
|
if similarity >= self.similarity_threshold:
|
||||||
suggestions.append({
|
suggestions.append(
|
||||||
'entity1': ent1,
|
{
|
||||||
'entity2': ent2,
|
"entity1": ent1,
|
||||||
'similarity': similarity,
|
"entity2": ent2,
|
||||||
'match_type': match_type,
|
"similarity": similarity,
|
||||||
'suggested_action': 'merge' if similarity > 0.95 else 'link'
|
"match_type": match_type,
|
||||||
})
|
"suggested_action": "merge" if similarity > 0.95 else "link",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
# 按相似度排序
|
# 按相似度排序
|
||||||
suggestions.sort(key=lambda x: x['similarity'], reverse=True)
|
suggestions.sort(key=lambda x: x["similarity"], reverse=True)
|
||||||
|
|
||||||
return suggestions
|
return suggestions
|
||||||
|
|
||||||
def create_multimodal_entity_record(self, project_id: str,
|
def create_multimodal_entity_record(
|
||||||
entity_id: str,
|
self,
|
||||||
source_type: str,
|
project_id: str,
|
||||||
source_id: str,
|
entity_id: str,
|
||||||
mention_context: str = "",
|
source_type: str,
|
||||||
confidence: float = 1.0) -> MultimodalEntity:
|
source_id: str,
|
||||||
|
mention_context: str = "",
|
||||||
|
confidence: float = 1.0,
|
||||||
|
) -> MultimodalEntity:
|
||||||
"""
|
"""
|
||||||
创建多模态实体记录
|
创建多模态实体记录
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
project_id: 项目ID
|
project_id: 项目ID
|
||||||
entity_id: 实体ID
|
entity_id: 实体ID
|
||||||
@@ -452,7 +449,7 @@ class MultimodalEntityLinker:
|
|||||||
source_id: 来源ID
|
source_id: 来源ID
|
||||||
mention_context: 提及上下文
|
mention_context: 提及上下文
|
||||||
confidence: 置信度
|
confidence: 置信度
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
多模态实体记录
|
多模态实体记录
|
||||||
"""
|
"""
|
||||||
@@ -464,48 +461,48 @@ class MultimodalEntityLinker:
|
|||||||
source_type=source_type,
|
source_type=source_type,
|
||||||
source_id=source_id,
|
source_id=source_id,
|
||||||
mention_context=mention_context,
|
mention_context=mention_context,
|
||||||
confidence=confidence
|
confidence=confidence,
|
||||||
)
|
)
|
||||||
|
|
||||||
def analyze_modality_distribution(self, multimodal_entities: List[MultimodalEntity]) -> Dict:
|
def analyze_modality_distribution(self, multimodal_entities: List[MultimodalEntity]) -> Dict:
|
||||||
"""
|
"""
|
||||||
分析模态分布
|
分析模态分布
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
multimodal_entities: 多模态实体列表
|
multimodal_entities: 多模态实体列表
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
模态分布统计
|
模态分布统计
|
||||||
"""
|
"""
|
||||||
distribution = {mod: 0 for mod in self.MODALITIES}
|
distribution = {mod: 0 for mod in self.MODALITIES}
|
||||||
cross_modal_entities = set()
|
|
||||||
|
|
||||||
# 统计每个模态的实体数
|
# 统计每个模态的实体数
|
||||||
for me in multimodal_entities:
|
for me in multimodal_entities:
|
||||||
if me.source_type in distribution:
|
if me.source_type in distribution:
|
||||||
distribution[me.source_type] += 1
|
distribution[me.source_type] += 1
|
||||||
|
|
||||||
# 统计跨模态实体
|
# 统计跨模态实体
|
||||||
entity_modalities = {}
|
entity_modalities = {}
|
||||||
for me in multimodal_entities:
|
for me in multimodal_entities:
|
||||||
if me.entity_id not in entity_modalities:
|
if me.entity_id not in entity_modalities:
|
||||||
entity_modalities[me.entity_id] = set()
|
entity_modalities[me.entity_id] = set()
|
||||||
entity_modalities[me.entity_id].add(me.source_type)
|
entity_modalities[me.entity_id].add(me.source_type)
|
||||||
|
|
||||||
cross_modal_count = sum(1 for mods in entity_modalities.values() if len(mods) > 1)
|
cross_modal_count = sum(1 for mods in entity_modalities.values() if len(mods) > 1)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
'modality_distribution': distribution,
|
"modality_distribution": distribution,
|
||||||
'total_multimodal_records': len(multimodal_entities),
|
"total_multimodal_records": len(multimodal_entities),
|
||||||
'unique_entities': len(entity_modalities),
|
"unique_entities": len(entity_modalities),
|
||||||
'cross_modal_entities': cross_modal_count,
|
"cross_modal_entities": cross_modal_count,
|
||||||
'cross_modal_ratio': cross_modal_count / len(entity_modalities) if entity_modalities else 0
|
"cross_modal_ratio": cross_modal_count / len(entity_modalities) if entity_modalities else 0,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
# Singleton instance
|
# Singleton instance
|
||||||
_multimodal_entity_linker = None
|
_multimodal_entity_linker = None
|
||||||
|
|
||||||
|
|
||||||
def get_multimodal_entity_linker(similarity_threshold: float = 0.85) -> MultimodalEntityLinker:
|
def get_multimodal_entity_linker(similarity_threshold: float = 0.85) -> MultimodalEntityLinker:
|
||||||
"""获取多模态实体关联器单例"""
|
"""获取多模态实体关联器单例"""
|
||||||
global _multimodal_entity_linker
|
global _multimodal_entity_linker
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import json
|
|||||||
import uuid
|
import uuid
|
||||||
import tempfile
|
import tempfile
|
||||||
import subprocess
|
import subprocess
|
||||||
from typing import List, Dict, Optional, Tuple
|
from typing import List, Dict, Tuple
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
@@ -17,18 +17,21 @@ from pathlib import Path
|
|||||||
try:
|
try:
|
||||||
import pytesseract
|
import pytesseract
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
PYTESSERACT_AVAILABLE = True
|
PYTESSERACT_AVAILABLE = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
PYTESSERACT_AVAILABLE = False
|
PYTESSERACT_AVAILABLE = False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import cv2
|
import cv2
|
||||||
|
|
||||||
CV2_AVAILABLE = True
|
CV2_AVAILABLE = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
CV2_AVAILABLE = False
|
CV2_AVAILABLE = False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import ffmpeg
|
import ffmpeg
|
||||||
|
|
||||||
FFMPEG_AVAILABLE = True
|
FFMPEG_AVAILABLE = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
FFMPEG_AVAILABLE = False
|
FFMPEG_AVAILABLE = False
|
||||||
@@ -37,6 +40,7 @@ except ImportError:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class VideoFrame:
|
class VideoFrame:
|
||||||
"""视频关键帧数据类"""
|
"""视频关键帧数据类"""
|
||||||
|
|
||||||
id: str
|
id: str
|
||||||
video_id: str
|
video_id: str
|
||||||
frame_number: int
|
frame_number: int
|
||||||
@@ -45,7 +49,7 @@ class VideoFrame:
|
|||||||
ocr_text: str = ""
|
ocr_text: str = ""
|
||||||
ocr_confidence: float = 0.0
|
ocr_confidence: float = 0.0
|
||||||
entities_detected: List[Dict] = None
|
entities_detected: List[Dict] = None
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if self.entities_detected is None:
|
if self.entities_detected is None:
|
||||||
self.entities_detected = []
|
self.entities_detected = []
|
||||||
@@ -54,6 +58,7 @@ class VideoFrame:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class VideoInfo:
|
class VideoInfo:
|
||||||
"""视频信息数据类"""
|
"""视频信息数据类"""
|
||||||
|
|
||||||
id: str
|
id: str
|
||||||
project_id: str
|
project_id: str
|
||||||
filename: str
|
filename: str
|
||||||
@@ -68,7 +73,7 @@ class VideoInfo:
|
|||||||
status: str = "pending"
|
status: str = "pending"
|
||||||
error_message: str = ""
|
error_message: str = ""
|
||||||
metadata: Dict = None
|
metadata: Dict = None
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if self.metadata is None:
|
if self.metadata is None:
|
||||||
self.metadata = {}
|
self.metadata = {}
|
||||||
@@ -77,6 +82,7 @@ class VideoInfo:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class VideoProcessingResult:
|
class VideoProcessingResult:
|
||||||
"""视频处理结果"""
|
"""视频处理结果"""
|
||||||
|
|
||||||
video_id: str
|
video_id: str
|
||||||
audio_path: str
|
audio_path: str
|
||||||
frames: List[VideoFrame]
|
frames: List[VideoFrame]
|
||||||
@@ -88,11 +94,11 @@ class VideoProcessingResult:
|
|||||||
|
|
||||||
class MultimodalProcessor:
|
class MultimodalProcessor:
|
||||||
"""多模态处理器 - 处理视频文件"""
|
"""多模态处理器 - 处理视频文件"""
|
||||||
|
|
||||||
def __init__(self, temp_dir: str = None, frame_interval: int = 5):
|
def __init__(self, temp_dir: str = None, frame_interval: int = 5):
|
||||||
"""
|
"""
|
||||||
初始化多模态处理器
|
初始化多模态处理器
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
temp_dir: 临时文件目录
|
temp_dir: 临时文件目录
|
||||||
frame_interval: 关键帧提取间隔(秒)
|
frame_interval: 关键帧提取间隔(秒)
|
||||||
@@ -102,88 +108,86 @@ class MultimodalProcessor:
|
|||||||
self.video_dir = os.path.join(self.temp_dir, "videos")
|
self.video_dir = os.path.join(self.temp_dir, "videos")
|
||||||
self.frames_dir = os.path.join(self.temp_dir, "frames")
|
self.frames_dir = os.path.join(self.temp_dir, "frames")
|
||||||
self.audio_dir = os.path.join(self.temp_dir, "audio")
|
self.audio_dir = os.path.join(self.temp_dir, "audio")
|
||||||
|
|
||||||
# 创建目录
|
# 创建目录
|
||||||
os.makedirs(self.video_dir, exist_ok=True)
|
os.makedirs(self.video_dir, exist_ok=True)
|
||||||
os.makedirs(self.frames_dir, exist_ok=True)
|
os.makedirs(self.frames_dir, exist_ok=True)
|
||||||
os.makedirs(self.audio_dir, exist_ok=True)
|
os.makedirs(self.audio_dir, exist_ok=True)
|
||||||
|
|
||||||
def extract_video_info(self, video_path: str) -> Dict:
|
def extract_video_info(self, video_path: str) -> Dict:
|
||||||
"""
|
"""
|
||||||
提取视频基本信息
|
提取视频基本信息
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
video_path: 视频文件路径
|
video_path: 视频文件路径
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
视频信息字典
|
视频信息字典
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
if FFMPEG_AVAILABLE:
|
if FFMPEG_AVAILABLE:
|
||||||
probe = ffmpeg.probe(video_path)
|
probe = ffmpeg.probe(video_path)
|
||||||
video_stream = next((s for s in probe['streams'] if s['codec_type'] == 'video'), None)
|
video_stream = next((s for s in probe["streams"] if s["codec_type"] == "video"), None)
|
||||||
audio_stream = next((s for s in probe['streams'] if s['codec_type'] == 'audio'), None)
|
audio_stream = next((s for s in probe["streams"] if s["codec_type"] == "audio"), None)
|
||||||
|
|
||||||
if video_stream:
|
if video_stream:
|
||||||
return {
|
return {
|
||||||
'duration': float(probe['format'].get('duration', 0)),
|
"duration": float(probe["format"].get("duration", 0)),
|
||||||
'width': int(video_stream.get('width', 0)),
|
"width": int(video_stream.get("width", 0)),
|
||||||
'height': int(video_stream.get('height', 0)),
|
"height": int(video_stream.get("height", 0)),
|
||||||
'fps': eval(video_stream.get('r_frame_rate', '0/1')),
|
"fps": eval(video_stream.get("r_frame_rate", "0/1")),
|
||||||
'has_audio': audio_stream is not None,
|
"has_audio": audio_stream is not None,
|
||||||
'bitrate': int(probe['format'].get('bit_rate', 0))
|
"bitrate": int(probe["format"].get("bit_rate", 0)),
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
# 使用 ffprobe 命令行
|
# 使用 ffprobe 命令行
|
||||||
cmd = [
|
cmd = [
|
||||||
'ffprobe', '-v', 'error', '-show_entries',
|
"ffprobe",
|
||||||
'format=duration,bit_rate', '-show_entries',
|
"-v",
|
||||||
'stream=width,height,r_frame_rate', '-of', 'json',
|
"error",
|
||||||
video_path
|
"-show_entries",
|
||||||
|
"format=duration,bit_rate",
|
||||||
|
"-show_entries",
|
||||||
|
"stream=width,height,r_frame_rate",
|
||||||
|
"-of",
|
||||||
|
"json",
|
||||||
|
video_path,
|
||||||
]
|
]
|
||||||
result = subprocess.run(cmd, capture_output=True, text=True)
|
result = subprocess.run(cmd, capture_output=True, text=True)
|
||||||
if result.returncode == 0:
|
if result.returncode == 0:
|
||||||
data = json.loads(result.stdout)
|
data = json.loads(result.stdout)
|
||||||
return {
|
return {
|
||||||
'duration': float(data['format'].get('duration', 0)),
|
"duration": float(data["format"].get("duration", 0)),
|
||||||
'width': int(data['streams'][0].get('width', 0)) if data['streams'] else 0,
|
"width": int(data["streams"][0].get("width", 0)) if data["streams"] else 0,
|
||||||
'height': int(data['streams'][0].get('height', 0)) if data['streams'] else 0,
|
"height": int(data["streams"][0].get("height", 0)) if data["streams"] else 0,
|
||||||
'fps': 30.0, # 默认值
|
"fps": 30.0, # 默认值
|
||||||
'has_audio': len(data['streams']) > 1,
|
"has_audio": len(data["streams"]) > 1,
|
||||||
'bitrate': int(data['format'].get('bit_rate', 0))
|
"bitrate": int(data["format"].get("bit_rate", 0)),
|
||||||
}
|
}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error extracting video info: {e}")
|
print(f"Error extracting video info: {e}")
|
||||||
|
|
||||||
return {
|
return {"duration": 0, "width": 0, "height": 0, "fps": 0, "has_audio": False, "bitrate": 0}
|
||||||
'duration': 0,
|
|
||||||
'width': 0,
|
|
||||||
'height': 0,
|
|
||||||
'fps': 0,
|
|
||||||
'has_audio': False,
|
|
||||||
'bitrate': 0
|
|
||||||
}
|
|
||||||
|
|
||||||
def extract_audio(self, video_path: str, output_path: str = None) -> str:
|
def extract_audio(self, video_path: str, output_path: str = None) -> str:
|
||||||
"""
|
"""
|
||||||
从视频中提取音频
|
从视频中提取音频
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
video_path: 视频文件路径
|
video_path: 视频文件路径
|
||||||
output_path: 输出音频路径(可选)
|
output_path: 输出音频路径(可选)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
提取的音频文件路径
|
提取的音频文件路径
|
||||||
"""
|
"""
|
||||||
if output_path is None:
|
if output_path is None:
|
||||||
video_name = Path(video_path).stem
|
video_name = Path(video_path).stem
|
||||||
output_path = os.path.join(self.audio_dir, f"{video_name}.wav")
|
output_path = os.path.join(self.audio_dir, f"{video_name}.wav")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if FFMPEG_AVAILABLE:
|
if FFMPEG_AVAILABLE:
|
||||||
(
|
(
|
||||||
ffmpeg
|
ffmpeg.input(video_path)
|
||||||
.input(video_path)
|
|
||||||
.output(output_path, ac=1, ar=16000, vn=None)
|
.output(output_path, ac=1, ar=16000, vn=None)
|
||||||
.overwrite_output()
|
.overwrite_output()
|
||||||
.run(quiet=True)
|
.run(quiet=True)
|
||||||
@@ -191,170 +195,168 @@ class MultimodalProcessor:
|
|||||||
else:
|
else:
|
||||||
# 使用命令行 ffmpeg
|
# 使用命令行 ffmpeg
|
||||||
cmd = [
|
cmd = [
|
||||||
'ffmpeg', '-i', video_path,
|
"ffmpeg",
|
||||||
'-vn', '-acodec', 'pcm_s16le',
|
"-i",
|
||||||
'-ac', '1', '-ar', '16000',
|
video_path,
|
||||||
'-y', output_path
|
"-vn",
|
||||||
|
"-acodec",
|
||||||
|
"pcm_s16le",
|
||||||
|
"-ac",
|
||||||
|
"1",
|
||||||
|
"-ar",
|
||||||
|
"16000",
|
||||||
|
"-y",
|
||||||
|
output_path,
|
||||||
]
|
]
|
||||||
subprocess.run(cmd, check=True, capture_output=True)
|
subprocess.run(cmd, check=True, capture_output=True)
|
||||||
|
|
||||||
return output_path
|
return output_path
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error extracting audio: {e}")
|
print(f"Error extracting audio: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def extract_keyframes(self, video_path: str, video_id: str,
|
def extract_keyframes(self, video_path: str, video_id: str, interval: int = None) -> List[str]:
|
||||||
interval: int = None) -> List[str]:
|
|
||||||
"""
|
"""
|
||||||
从视频中提取关键帧
|
从视频中提取关键帧
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
video_path: 视频文件路径
|
video_path: 视频文件路径
|
||||||
video_id: 视频ID
|
video_id: 视频ID
|
||||||
interval: 提取间隔(秒),默认使用初始化时的间隔
|
interval: 提取间隔(秒),默认使用初始化时的间隔
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
提取的帧文件路径列表
|
提取的帧文件路径列表
|
||||||
"""
|
"""
|
||||||
interval = interval or self.frame_interval
|
interval = interval or self.frame_interval
|
||||||
frame_paths = []
|
frame_paths = []
|
||||||
|
|
||||||
# 创建帧存储目录
|
# 创建帧存储目录
|
||||||
video_frames_dir = os.path.join(self.frames_dir, video_id)
|
video_frames_dir = os.path.join(self.frames_dir, video_id)
|
||||||
os.makedirs(video_frames_dir, exist_ok=True)
|
os.makedirs(video_frames_dir, exist_ok=True)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if CV2_AVAILABLE:
|
if CV2_AVAILABLE:
|
||||||
# 使用 OpenCV 提取帧
|
# 使用 OpenCV 提取帧
|
||||||
cap = cv2.VideoCapture(video_path)
|
cap = cv2.VideoCapture(video_path)
|
||||||
fps = cap.get(cv2.CAP_PROP_FPS)
|
fps = cap.get(cv2.CAP_PROP_FPS)
|
||||||
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||||
|
|
||||||
frame_interval_frames = int(fps * interval)
|
frame_interval_frames = int(fps * interval)
|
||||||
frame_number = 0
|
frame_number = 0
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
ret, frame = cap.read()
|
ret, frame = cap.read()
|
||||||
if not ret:
|
if not ret:
|
||||||
break
|
break
|
||||||
|
|
||||||
if frame_number % frame_interval_frames == 0:
|
if frame_number % frame_interval_frames == 0:
|
||||||
timestamp = frame_number / fps
|
timestamp = frame_number / fps
|
||||||
frame_path = os.path.join(
|
frame_path = os.path.join(video_frames_dir, f"frame_{frame_number:06d}_{timestamp:.2f}.jpg")
|
||||||
video_frames_dir,
|
|
||||||
f"frame_{frame_number:06d}_{timestamp:.2f}.jpg"
|
|
||||||
)
|
|
||||||
cv2.imwrite(frame_path, frame)
|
cv2.imwrite(frame_path, frame)
|
||||||
frame_paths.append(frame_path)
|
frame_paths.append(frame_path)
|
||||||
|
|
||||||
frame_number += 1
|
frame_number += 1
|
||||||
|
|
||||||
cap.release()
|
cap.release()
|
||||||
else:
|
else:
|
||||||
# 使用 ffmpeg 命令行提取帧
|
# 使用 ffmpeg 命令行提取帧
|
||||||
video_name = Path(video_path).stem
|
Path(video_path).stem
|
||||||
output_pattern = os.path.join(video_frames_dir, "frame_%06d_%t.jpg")
|
output_pattern = os.path.join(video_frames_dir, "frame_%06d_%t.jpg")
|
||||||
|
|
||||||
cmd = [
|
cmd = ["ffmpeg", "-i", video_path, "-vf", f"fps=1/{interval}", "-frame_pts", "1", "-y", output_pattern]
|
||||||
'ffmpeg', '-i', video_path,
|
|
||||||
'-vf', f'fps=1/{interval}',
|
|
||||||
'-frame_pts', '1',
|
|
||||||
'-y', output_pattern
|
|
||||||
]
|
|
||||||
subprocess.run(cmd, check=True, capture_output=True)
|
subprocess.run(cmd, check=True, capture_output=True)
|
||||||
|
|
||||||
# 获取生成的帧文件列表
|
# 获取生成的帧文件列表
|
||||||
frame_paths = sorted([
|
frame_paths = sorted(
|
||||||
os.path.join(video_frames_dir, f)
|
[os.path.join(video_frames_dir, f) for f in os.listdir(video_frames_dir) if f.startswith("frame_")]
|
||||||
for f in os.listdir(video_frames_dir)
|
)
|
||||||
if f.startswith('frame_')
|
|
||||||
])
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error extracting keyframes: {e}")
|
print(f"Error extracting keyframes: {e}")
|
||||||
|
|
||||||
return frame_paths
|
return frame_paths
|
||||||
|
|
||||||
def perform_ocr(self, image_path: str) -> Tuple[str, float]:
|
def perform_ocr(self, image_path: str) -> Tuple[str, float]:
|
||||||
"""
|
"""
|
||||||
对图片进行OCR识别
|
对图片进行OCR识别
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
image_path: 图片文件路径
|
image_path: 图片文件路径
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(识别的文本, 置信度)
|
(识别的文本, 置信度)
|
||||||
"""
|
"""
|
||||||
if not PYTESSERACT_AVAILABLE:
|
if not PYTESSERACT_AVAILABLE:
|
||||||
return "", 0.0
|
return "", 0.0
|
||||||
|
|
||||||
try:
|
try:
|
||||||
image = Image.open(image_path)
|
image = Image.open(image_path)
|
||||||
|
|
||||||
# 预处理:转换为灰度图
|
# 预处理:转换为灰度图
|
||||||
if image.mode != 'L':
|
if image.mode != "L":
|
||||||
image = image.convert('L')
|
image = image.convert("L")
|
||||||
|
|
||||||
# 使用 pytesseract 进行 OCR
|
# 使用 pytesseract 进行 OCR
|
||||||
text = pytesseract.image_to_string(image, lang='chi_sim+eng')
|
text = pytesseract.image_to_string(image, lang="chi_sim+eng")
|
||||||
|
|
||||||
# 获取置信度数据
|
# 获取置信度数据
|
||||||
data = pytesseract.image_to_data(image, output_type=pytesseract.Output.DICT)
|
data = pytesseract.image_to_data(image, output_type=pytesseract.Output.DICT)
|
||||||
confidences = [int(c) for c in data['conf'] if int(c) > 0]
|
confidences = [int(c) for c in data["conf"] if int(c) > 0]
|
||||||
avg_confidence = sum(confidences) / len(confidences) if confidences else 0
|
avg_confidence = sum(confidences) / len(confidences) if confidences else 0
|
||||||
|
|
||||||
return text.strip(), avg_confidence / 100.0
|
return text.strip(), avg_confidence / 100.0
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"OCR error for {image_path}: {e}")
|
print(f"OCR error for {image_path}: {e}")
|
||||||
return "", 0.0
|
return "", 0.0
|
||||||
|
|
||||||
def process_video(self, video_data: bytes, filename: str,
|
def process_video(
|
||||||
project_id: str, video_id: str = None) -> VideoProcessingResult:
|
self, video_data: bytes, filename: str, project_id: str, video_id: str = None
|
||||||
|
) -> VideoProcessingResult:
|
||||||
"""
|
"""
|
||||||
处理视频文件:提取音频、关键帧、OCR
|
处理视频文件:提取音频、关键帧、OCR
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
video_data: 视频文件二进制数据
|
video_data: 视频文件二进制数据
|
||||||
filename: 视频文件名
|
filename: 视频文件名
|
||||||
project_id: 项目ID
|
project_id: 项目ID
|
||||||
video_id: 视频ID(可选,自动生成)
|
video_id: 视频ID(可选,自动生成)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
视频处理结果
|
视频处理结果
|
||||||
"""
|
"""
|
||||||
video_id = video_id or str(uuid.uuid4())[:8]
|
video_id = video_id or str(uuid.uuid4())[:8]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 保存视频文件
|
# 保存视频文件
|
||||||
video_path = os.path.join(self.video_dir, f"{video_id}_{filename}")
|
video_path = os.path.join(self.video_dir, f"{video_id}_{filename}")
|
||||||
with open(video_path, 'wb') as f:
|
with open(video_path, "wb") as f:
|
||||||
f.write(video_data)
|
f.write(video_data)
|
||||||
|
|
||||||
# 提取视频信息
|
# 提取视频信息
|
||||||
video_info = self.extract_video_info(video_path)
|
video_info = self.extract_video_info(video_path)
|
||||||
|
|
||||||
# 提取音频
|
# 提取音频
|
||||||
audio_path = ""
|
audio_path = ""
|
||||||
if video_info['has_audio']:
|
if video_info["has_audio"]:
|
||||||
audio_path = self.extract_audio(video_path)
|
audio_path = self.extract_audio(video_path)
|
||||||
|
|
||||||
# 提取关键帧
|
# 提取关键帧
|
||||||
frame_paths = self.extract_keyframes(video_path, video_id)
|
frame_paths = self.extract_keyframes(video_path, video_id)
|
||||||
|
|
||||||
# 对关键帧进行 OCR
|
# 对关键帧进行 OCR
|
||||||
frames = []
|
frames = []
|
||||||
ocr_results = []
|
ocr_results = []
|
||||||
all_ocr_text = []
|
all_ocr_text = []
|
||||||
|
|
||||||
for i, frame_path in enumerate(frame_paths):
|
for i, frame_path in enumerate(frame_paths):
|
||||||
# 解析帧信息
|
# 解析帧信息
|
||||||
frame_name = os.path.basename(frame_path)
|
frame_name = os.path.basename(frame_path)
|
||||||
parts = frame_name.replace('.jpg', '').split('_')
|
parts = frame_name.replace(".jpg", "").split("_")
|
||||||
frame_number = int(parts[1]) if len(parts) > 1 else i
|
frame_number = int(parts[1]) if len(parts) > 1 else i
|
||||||
timestamp = float(parts[2]) if len(parts) > 2 else i * self.frame_interval
|
timestamp = float(parts[2]) if len(parts) > 2 else i * self.frame_interval
|
||||||
|
|
||||||
# OCR 识别
|
# OCR 识别
|
||||||
ocr_text, confidence = self.perform_ocr(frame_path)
|
ocr_text, confidence = self.perform_ocr(frame_path)
|
||||||
|
|
||||||
frame = VideoFrame(
|
frame = VideoFrame(
|
||||||
id=str(uuid.uuid4())[:8],
|
id=str(uuid.uuid4())[:8],
|
||||||
video_id=video_id,
|
video_id=video_id,
|
||||||
@@ -362,31 +364,33 @@ class MultimodalProcessor:
|
|||||||
timestamp=timestamp,
|
timestamp=timestamp,
|
||||||
frame_path=frame_path,
|
frame_path=frame_path,
|
||||||
ocr_text=ocr_text,
|
ocr_text=ocr_text,
|
||||||
ocr_confidence=confidence
|
ocr_confidence=confidence,
|
||||||
)
|
)
|
||||||
frames.append(frame)
|
frames.append(frame)
|
||||||
|
|
||||||
if ocr_text:
|
if ocr_text:
|
||||||
ocr_results.append({
|
ocr_results.append(
|
||||||
'frame_number': frame_number,
|
{
|
||||||
'timestamp': timestamp,
|
"frame_number": frame_number,
|
||||||
'text': ocr_text,
|
"timestamp": timestamp,
|
||||||
'confidence': confidence
|
"text": ocr_text,
|
||||||
})
|
"confidence": confidence,
|
||||||
|
}
|
||||||
|
)
|
||||||
all_ocr_text.append(ocr_text)
|
all_ocr_text.append(ocr_text)
|
||||||
|
|
||||||
# 整合所有 OCR 文本
|
# 整合所有 OCR 文本
|
||||||
full_ocr_text = "\n\n".join(all_ocr_text)
|
full_ocr_text = "\n\n".join(all_ocr_text)
|
||||||
|
|
||||||
return VideoProcessingResult(
|
return VideoProcessingResult(
|
||||||
video_id=video_id,
|
video_id=video_id,
|
||||||
audio_path=audio_path,
|
audio_path=audio_path,
|
||||||
frames=frames,
|
frames=frames,
|
||||||
ocr_results=ocr_results,
|
ocr_results=ocr_results,
|
||||||
full_text=full_ocr_text,
|
full_text=full_ocr_text,
|
||||||
success=True
|
success=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return VideoProcessingResult(
|
return VideoProcessingResult(
|
||||||
video_id=video_id,
|
video_id=video_id,
|
||||||
@@ -395,18 +399,18 @@ class MultimodalProcessor:
|
|||||||
ocr_results=[],
|
ocr_results=[],
|
||||||
full_text="",
|
full_text="",
|
||||||
success=False,
|
success=False,
|
||||||
error_message=str(e)
|
error_message=str(e),
|
||||||
)
|
)
|
||||||
|
|
||||||
def cleanup(self, video_id: str = None):
|
def cleanup(self, video_id: str = None):
|
||||||
"""
|
"""
|
||||||
清理临时文件
|
清理临时文件
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
video_id: 视频ID(可选,清理特定视频的文件)
|
video_id: 视频ID(可选,清理特定视频的文件)
|
||||||
"""
|
"""
|
||||||
import shutil
|
import shutil
|
||||||
|
|
||||||
if video_id:
|
if video_id:
|
||||||
# 清理特定视频的文件
|
# 清理特定视频的文件
|
||||||
for dir_path in [self.video_dir, self.frames_dir, self.audio_dir]:
|
for dir_path in [self.video_dir, self.frames_dir, self.audio_dir]:
|
||||||
@@ -426,6 +430,7 @@ class MultimodalProcessor:
|
|||||||
# Singleton instance
|
# Singleton instance
|
||||||
_multimodal_processor = None
|
_multimodal_processor = None
|
||||||
|
|
||||||
|
|
||||||
def get_multimodal_processor(temp_dir: str = None, frame_interval: int = 5) -> MultimodalProcessor:
|
def get_multimodal_processor(temp_dir: str = None, frame_interval: int = 5) -> MultimodalProcessor:
|
||||||
"""获取多模态处理器单例"""
|
"""获取多模态处理器单例"""
|
||||||
global _multimodal_processor
|
global _multimodal_processor
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -5,9 +5,10 @@ OSS 上传工具 - 用于阿里听悟音频上传
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime
|
||||||
import oss2
|
import oss2
|
||||||
|
|
||||||
|
|
||||||
class OSSUploader:
|
class OSSUploader:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.access_key = os.getenv("ALI_ACCESS_KEY")
|
self.access_key = os.getenv("ALI_ACCESS_KEY")
|
||||||
@@ -15,33 +16,35 @@ class OSSUploader:
|
|||||||
self.bucket_name = os.getenv("OSS_BUCKET", "insightflow-audio")
|
self.bucket_name = os.getenv("OSS_BUCKET", "insightflow-audio")
|
||||||
self.region = os.getenv("OSS_REGION", "oss-cn-hangzhou.aliyuncs.com")
|
self.region = os.getenv("OSS_REGION", "oss-cn-hangzhou.aliyuncs.com")
|
||||||
self.endpoint = f"https://{self.region}"
|
self.endpoint = f"https://{self.region}"
|
||||||
|
|
||||||
if not self.access_key or not self.secret_key:
|
if not self.access_key or not self.secret_key:
|
||||||
raise ValueError("ALI_ACCESS_KEY and ALI_SECRET_KEY must be set")
|
raise ValueError("ALI_ACCESS_KEY and ALI_SECRET_KEY must be set")
|
||||||
|
|
||||||
self.auth = oss2.Auth(self.access_key, self.secret_key)
|
self.auth = oss2.Auth(self.access_key, self.secret_key)
|
||||||
self.bucket = oss2.Bucket(self.auth, self.endpoint, self.bucket_name)
|
self.bucket = oss2.Bucket(self.auth, self.endpoint, self.bucket_name)
|
||||||
|
|
||||||
def upload_audio(self, audio_data: bytes, filename: str) -> tuple:
|
def upload_audio(self, audio_data: bytes, filename: str) -> tuple:
|
||||||
"""上传音频到 OSS,返回 (URL, object_name)"""
|
"""上传音频到 OSS,返回 (URL, object_name)"""
|
||||||
# 生成唯一文件名
|
# 生成唯一文件名
|
||||||
ext = os.path.splitext(filename)[1] or ".wav"
|
ext = os.path.splitext(filename)[1] or ".wav"
|
||||||
object_name = f"audio/{datetime.now().strftime('%Y%m%d')}/{uuid.uuid4().hex}{ext}"
|
object_name = f"audio/{datetime.now().strftime('%Y%m%d')}/{uuid.uuid4().hex}{ext}"
|
||||||
|
|
||||||
# 上传文件
|
# 上传文件
|
||||||
self.bucket.put_object(object_name, audio_data)
|
self.bucket.put_object(object_name, audio_data)
|
||||||
|
|
||||||
# 生成临时访问 URL (1小时有效)
|
# 生成临时访问 URL (1小时有效)
|
||||||
url = self.bucket.sign_url('GET', object_name, 3600)
|
url = self.bucket.sign_url("GET", object_name, 3600)
|
||||||
return url, object_name
|
return url, object_name
|
||||||
|
|
||||||
def delete_object(self, object_name: str):
|
def delete_object(self, object_name: str):
|
||||||
"""删除 OSS 对象"""
|
"""删除 OSS 对象"""
|
||||||
self.bucket.delete_object(object_name)
|
self.bucket.delete_object(object_name)
|
||||||
|
|
||||||
|
|
||||||
# 单例
|
# 单例
|
||||||
_oss_uploader = None
|
_oss_uploader = None
|
||||||
|
|
||||||
|
|
||||||
def get_oss_uploader() -> OSSUploader:
|
def get_oss_uploader() -> OSSUploader:
|
||||||
global _oss_uploader
|
global _oss_uploader
|
||||||
if _oss_uploader is None:
|
if _oss_uploader is None:
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -7,8 +7,8 @@ API 限流中间件
|
|||||||
|
|
||||||
import time
|
import time
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import Dict, Optional, Tuple, Callable
|
from typing import Dict, Optional, Callable
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
|
|
||||||
@@ -16,6 +16,7 @@ from functools import wraps
|
|||||||
@dataclass
|
@dataclass
|
||||||
class RateLimitConfig:
|
class RateLimitConfig:
|
||||||
"""限流配置"""
|
"""限流配置"""
|
||||||
|
|
||||||
requests_per_minute: int = 60
|
requests_per_minute: int = 60
|
||||||
burst_size: int = 10 # 突发请求数
|
burst_size: int = 10 # 突发请求数
|
||||||
window_size: int = 60 # 窗口大小(秒)
|
window_size: int = 60 # 窗口大小(秒)
|
||||||
@@ -24,6 +25,7 @@ class RateLimitConfig:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class RateLimitInfo:
|
class RateLimitInfo:
|
||||||
"""限流信息"""
|
"""限流信息"""
|
||||||
|
|
||||||
allowed: bool
|
allowed: bool
|
||||||
remaining: int
|
remaining: int
|
||||||
reset_time: int # 重置时间戳
|
reset_time: int # 重置时间戳
|
||||||
@@ -32,12 +34,13 @@ class RateLimitInfo:
|
|||||||
|
|
||||||
class SlidingWindowCounter:
|
class SlidingWindowCounter:
|
||||||
"""滑动窗口计数器"""
|
"""滑动窗口计数器"""
|
||||||
|
|
||||||
def __init__(self, window_size: int = 60):
|
def __init__(self, window_size: int = 60):
|
||||||
self.window_size = window_size
|
self.window_size = window_size
|
||||||
self.requests: Dict[int, int] = defaultdict(int) # 秒级计数
|
self.requests: Dict[int, int] = defaultdict(int) # 秒级计数
|
||||||
self._lock = asyncio.Lock()
|
self._lock = asyncio.Lock()
|
||||||
|
self._cleanup_lock = asyncio.Lock()
|
||||||
|
|
||||||
async def add_request(self) -> int:
|
async def add_request(self) -> int:
|
||||||
"""添加请求,返回当前窗口内的请求数"""
|
"""添加请求,返回当前窗口内的请求数"""
|
||||||
async with self._lock:
|
async with self._lock:
|
||||||
@@ -45,87 +48,76 @@ class SlidingWindowCounter:
|
|||||||
self.requests[now] += 1
|
self.requests[now] += 1
|
||||||
self._cleanup_old(now)
|
self._cleanup_old(now)
|
||||||
return sum(self.requests.values())
|
return sum(self.requests.values())
|
||||||
|
|
||||||
async def get_count(self) -> int:
|
async def get_count(self) -> int:
|
||||||
"""获取当前窗口内的请求数"""
|
"""获取当前窗口内的请求数"""
|
||||||
async with self._lock:
|
async with self._lock:
|
||||||
now = int(time.time())
|
now = int(time.time())
|
||||||
self._cleanup_old(now)
|
self._cleanup_old(now)
|
||||||
return sum(self.requests.values())
|
return sum(self.requests.values())
|
||||||
|
|
||||||
def _cleanup_old(self, now: int):
|
def _cleanup_old(self, now: int):
|
||||||
"""清理过期的请求记录"""
|
"""清理过期的请求记录 - 使用独立锁避免竞态条件"""
|
||||||
cutoff = now - self.window_size
|
cutoff = now - self.window_size
|
||||||
old_keys = [k for k in self.requests.keys() if k < cutoff]
|
old_keys = [k for k in list(self.requests.keys()) if k < cutoff]
|
||||||
for k in old_keys:
|
for k in old_keys:
|
||||||
del self.requests[k]
|
self.requests.pop(k, None)
|
||||||
|
|
||||||
|
|
||||||
class RateLimiter:
|
class RateLimiter:
|
||||||
"""API 限流器"""
|
"""API 限流器"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
# key -> SlidingWindowCounter
|
# key -> SlidingWindowCounter
|
||||||
self.counters: Dict[str, SlidingWindowCounter] = {}
|
self.counters: Dict[str, SlidingWindowCounter] = {}
|
||||||
# key -> RateLimitConfig
|
# key -> RateLimitConfig
|
||||||
self.configs: Dict[str, RateLimitConfig] = {}
|
self.configs: Dict[str, RateLimitConfig] = {}
|
||||||
self._lock = asyncio.Lock()
|
self._lock = asyncio.Lock()
|
||||||
|
self._cleanup_lock = asyncio.Lock()
|
||||||
async def is_allowed(
|
|
||||||
self,
|
async def is_allowed(self, key: str, config: Optional[RateLimitConfig] = None) -> RateLimitInfo:
|
||||||
key: str,
|
|
||||||
config: Optional[RateLimitConfig] = None
|
|
||||||
) -> RateLimitInfo:
|
|
||||||
"""
|
"""
|
||||||
检查是否允许请求
|
检查是否允许请求
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
key: 限流键(如 API Key ID)
|
key: 限流键(如 API Key ID)
|
||||||
config: 限流配置,如果为 None 则使用默认配置
|
config: 限流配置,如果为 None 则使用默认配置
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
RateLimitInfo
|
RateLimitInfo
|
||||||
"""
|
"""
|
||||||
if config is None:
|
if config is None:
|
||||||
config = RateLimitConfig()
|
config = RateLimitConfig()
|
||||||
|
|
||||||
async with self._lock:
|
async with self._lock:
|
||||||
if key not in self.counters:
|
if key not in self.counters:
|
||||||
self.counters[key] = SlidingWindowCounter(config.window_size)
|
self.counters[key] = SlidingWindowCounter(config.window_size)
|
||||||
self.configs[key] = config
|
self.configs[key] = config
|
||||||
|
|
||||||
counter = self.counters[key]
|
counter = self.counters[key]
|
||||||
stored_config = self.configs.get(key, config)
|
stored_config = self.configs.get(key, config)
|
||||||
|
|
||||||
# 获取当前计数
|
# 获取当前计数
|
||||||
current_count = await counter.get_count()
|
current_count = await counter.get_count()
|
||||||
|
|
||||||
# 计算剩余配额
|
# 计算剩余配额
|
||||||
remaining = max(0, stored_config.requests_per_minute - current_count)
|
remaining = max(0, stored_config.requests_per_minute - current_count)
|
||||||
|
|
||||||
# 计算重置时间
|
# 计算重置时间
|
||||||
now = int(time.time())
|
now = int(time.time())
|
||||||
reset_time = now + stored_config.window_size
|
reset_time = now + stored_config.window_size
|
||||||
|
|
||||||
# 检查是否超过限制
|
# 检查是否超过限制
|
||||||
if current_count >= stored_config.requests_per_minute:
|
if current_count >= stored_config.requests_per_minute:
|
||||||
return RateLimitInfo(
|
return RateLimitInfo(
|
||||||
allowed=False,
|
allowed=False, remaining=0, reset_time=reset_time, retry_after=stored_config.window_size
|
||||||
remaining=0,
|
|
||||||
reset_time=reset_time,
|
|
||||||
retry_after=stored_config.window_size
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 允许请求,增加计数
|
# 允许请求,增加计数
|
||||||
await counter.add_request()
|
await counter.add_request()
|
||||||
|
|
||||||
return RateLimitInfo(
|
return RateLimitInfo(allowed=True, remaining=remaining - 1, reset_time=reset_time, retry_after=0)
|
||||||
allowed=True,
|
|
||||||
remaining=remaining - 1,
|
|
||||||
reset_time=reset_time,
|
|
||||||
retry_after=0
|
|
||||||
)
|
|
||||||
|
|
||||||
async def get_limit_info(self, key: str) -> RateLimitInfo:
|
async def get_limit_info(self, key: str) -> RateLimitInfo:
|
||||||
"""获取限流信息(不增加计数)"""
|
"""获取限流信息(不增加计数)"""
|
||||||
if key not in self.counters:
|
if key not in self.counters:
|
||||||
@@ -134,23 +126,23 @@ class RateLimiter:
|
|||||||
allowed=True,
|
allowed=True,
|
||||||
remaining=config.requests_per_minute,
|
remaining=config.requests_per_minute,
|
||||||
reset_time=int(time.time()) + config.window_size,
|
reset_time=int(time.time()) + config.window_size,
|
||||||
retry_after=0
|
retry_after=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
counter = self.counters[key]
|
counter = self.counters[key]
|
||||||
config = self.configs.get(key, RateLimitConfig())
|
config = self.configs.get(key, RateLimitConfig())
|
||||||
|
|
||||||
current_count = await counter.get_count()
|
current_count = await counter.get_count()
|
||||||
remaining = max(0, config.requests_per_minute - current_count)
|
remaining = max(0, config.requests_per_minute - current_count)
|
||||||
reset_time = int(time.time()) + config.window_size
|
reset_time = int(time.time()) + config.window_size
|
||||||
|
|
||||||
return RateLimitInfo(
|
return RateLimitInfo(
|
||||||
allowed=current_count < config.requests_per_minute,
|
allowed=current_count < config.requests_per_minute,
|
||||||
remaining=remaining,
|
remaining=remaining,
|
||||||
reset_time=reset_time,
|
reset_time=reset_time,
|
||||||
retry_after=max(0, config.window_size) if current_count >= config.requests_per_minute else 0
|
retry_after=max(0, config.window_size) if current_count >= config.requests_per_minute else 0,
|
||||||
)
|
)
|
||||||
|
|
||||||
def reset(self, key: Optional[str] = None):
|
def reset(self, key: Optional[str] = None):
|
||||||
"""重置限流计数器"""
|
"""重置限流计数器"""
|
||||||
if key:
|
if key:
|
||||||
@@ -174,50 +166,44 @@ def get_rate_limiter() -> RateLimiter:
|
|||||||
|
|
||||||
|
|
||||||
# 限流装饰器(用于函数级别限流)
|
# 限流装饰器(用于函数级别限流)
|
||||||
def rate_limit(
|
def rate_limit(requests_per_minute: int = 60, key_func: Optional[Callable] = None):
|
||||||
requests_per_minute: int = 60,
|
|
||||||
key_func: Optional[Callable] = None
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
限流装饰器
|
限流装饰器
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
requests_per_minute: 每分钟请求数限制
|
requests_per_minute: 每分钟请求数限制
|
||||||
key_func: 生成限流键的函数,默认为 None(使用函数名)
|
key_func: 生成限流键的函数,默认为 None(使用函数名)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def decorator(func):
|
def decorator(func):
|
||||||
limiter = get_rate_limiter()
|
limiter = get_rate_limiter()
|
||||||
config = RateLimitConfig(requests_per_minute=requests_per_minute)
|
config = RateLimitConfig(requests_per_minute=requests_per_minute)
|
||||||
|
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
async def async_wrapper(*args, **kwargs):
|
async def async_wrapper(*args, **kwargs):
|
||||||
key = key_func(*args, **kwargs) if key_func else func.__name__
|
key = key_func(*args, **kwargs) if key_func else func.__name__
|
||||||
info = await limiter.is_allowed(key, config)
|
info = await limiter.is_allowed(key, config)
|
||||||
|
|
||||||
if not info.allowed:
|
if not info.allowed:
|
||||||
raise RateLimitExceeded(
|
raise RateLimitExceeded(f"Rate limit exceeded. Try again in {info.retry_after} seconds.")
|
||||||
f"Rate limit exceeded. Try again in {info.retry_after} seconds."
|
|
||||||
)
|
|
||||||
|
|
||||||
return await func(*args, **kwargs)
|
return await func(*args, **kwargs)
|
||||||
|
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
def sync_wrapper(*args, **kwargs):
|
def sync_wrapper(*args, **kwargs):
|
||||||
key = key_func(*args, **kwargs) if key_func else func.__name__
|
key = key_func(*args, **kwargs) if key_func else func.__name__
|
||||||
# 同步版本使用 asyncio.run
|
# 同步版本使用 asyncio.run
|
||||||
info = asyncio.run(limiter.is_allowed(key, config))
|
info = asyncio.run(limiter.is_allowed(key, config))
|
||||||
|
|
||||||
if not info.allowed:
|
if not info.allowed:
|
||||||
raise RateLimitExceeded(
|
raise RateLimitExceeded(f"Rate limit exceeded. Try again in {info.retry_after} seconds.")
|
||||||
f"Rate limit exceeded. Try again in {info.retry_after} seconds."
|
|
||||||
)
|
|
||||||
|
|
||||||
return func(*args, **kwargs)
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
|
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
class RateLimitExceeded(Exception):
|
class RateLimitExceeded(Exception):
|
||||||
"""限流异常"""
|
"""限流异常"""
|
||||||
pass
|
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -3,7 +3,6 @@ InsightFlow Phase 7 Task 3: 数据安全与合规模块
|
|||||||
Security Manager - 端到端加密、数据脱敏、审计日志
|
Security Manager - 端到端加密、数据脱敏、审计日志
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
|
||||||
import json
|
import json
|
||||||
import hashlib
|
import hashlib
|
||||||
import secrets
|
import secrets
|
||||||
@@ -83,7 +82,7 @@ class AuditLog:
|
|||||||
success: bool = True
|
success: bool = True
|
||||||
error_message: Optional[str] = None
|
error_message: Optional[str] = None
|
||||||
created_at: str = field(default_factory=lambda: datetime.now().isoformat())
|
created_at: str = field(default_factory=lambda: datetime.now().isoformat())
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
return asdict(self)
|
return asdict(self)
|
||||||
|
|
||||||
@@ -100,7 +99,7 @@ class EncryptionConfig:
|
|||||||
salt: Optional[str] = None
|
salt: Optional[str] = None
|
||||||
created_at: str = field(default_factory=lambda: datetime.now().isoformat())
|
created_at: str = field(default_factory=lambda: datetime.now().isoformat())
|
||||||
updated_at: str = field(default_factory=lambda: datetime.now().isoformat())
|
updated_at: str = field(default_factory=lambda: datetime.now().isoformat())
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
return asdict(self)
|
return asdict(self)
|
||||||
|
|
||||||
@@ -119,7 +118,7 @@ class MaskingRule:
|
|||||||
description: Optional[str] = None
|
description: Optional[str] = None
|
||||||
created_at: str = field(default_factory=lambda: datetime.now().isoformat())
|
created_at: str = field(default_factory=lambda: datetime.now().isoformat())
|
||||||
updated_at: str = field(default_factory=lambda: datetime.now().isoformat())
|
updated_at: str = field(default_factory=lambda: datetime.now().isoformat())
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
return asdict(self)
|
return asdict(self)
|
||||||
|
|
||||||
@@ -140,7 +139,7 @@ class DataAccessPolicy:
|
|||||||
is_active: bool = True
|
is_active: bool = True
|
||||||
created_at: str = field(default_factory=lambda: datetime.now().isoformat())
|
created_at: str = field(default_factory=lambda: datetime.now().isoformat())
|
||||||
updated_at: str = field(default_factory=lambda: datetime.now().isoformat())
|
updated_at: str = field(default_factory=lambda: datetime.now().isoformat())
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
return asdict(self)
|
return asdict(self)
|
||||||
|
|
||||||
@@ -157,14 +156,14 @@ class AccessRequest:
|
|||||||
approved_at: Optional[str] = None
|
approved_at: Optional[str] = None
|
||||||
expires_at: Optional[str] = None
|
expires_at: Optional[str] = None
|
||||||
created_at: str = field(default_factory=lambda: datetime.now().isoformat())
|
created_at: str = field(default_factory=lambda: datetime.now().isoformat())
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
return asdict(self)
|
return asdict(self)
|
||||||
|
|
||||||
|
|
||||||
class SecurityManager:
|
class SecurityManager:
|
||||||
"""安全管理器"""
|
"""安全管理器"""
|
||||||
|
|
||||||
# 预定义脱敏规则
|
# 预定义脱敏规则
|
||||||
DEFAULT_MASKING_RULES = {
|
DEFAULT_MASKING_RULES = {
|
||||||
MaskingRuleType.PHONE: {
|
MaskingRuleType.PHONE: {
|
||||||
@@ -192,17 +191,20 @@ class SecurityManager:
|
|||||||
"replacement": r"\1\2***"
|
"replacement": r"\1\2***"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
def __init__(self, db_path: str = "insightflow.db"):
|
def __init__(self, db_path: str = "insightflow.db"):
|
||||||
self.db_path = db_path
|
self.db_path = db_path
|
||||||
|
self.db_path = db_path
|
||||||
|
# 预编译正则缓存
|
||||||
|
self._compiled_patterns: Dict[str, re.Pattern] = {}
|
||||||
self._local = {}
|
self._local = {}
|
||||||
self._init_db()
|
self._init_db()
|
||||||
|
|
||||||
def _init_db(self):
|
def _init_db(self):
|
||||||
"""初始化数据库表"""
|
"""初始化数据库表"""
|
||||||
conn = sqlite3.connect(self.db_path)
|
conn = sqlite3.connect(self.db_path)
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
||||||
# 审计日志表
|
# 审计日志表
|
||||||
cursor.execute("""
|
cursor.execute("""
|
||||||
CREATE TABLE IF NOT EXISTS audit_logs (
|
CREATE TABLE IF NOT EXISTS audit_logs (
|
||||||
@@ -221,7 +223,7 @@ class SecurityManager:
|
|||||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||||
)
|
)
|
||||||
""")
|
""")
|
||||||
|
|
||||||
# 加密配置表
|
# 加密配置表
|
||||||
cursor.execute("""
|
cursor.execute("""
|
||||||
CREATE TABLE IF NOT EXISTS encryption_configs (
|
CREATE TABLE IF NOT EXISTS encryption_configs (
|
||||||
@@ -237,7 +239,7 @@ class SecurityManager:
|
|||||||
FOREIGN KEY (project_id) REFERENCES projects(id)
|
FOREIGN KEY (project_id) REFERENCES projects(id)
|
||||||
)
|
)
|
||||||
""")
|
""")
|
||||||
|
|
||||||
# 脱敏规则表
|
# 脱敏规则表
|
||||||
cursor.execute("""
|
cursor.execute("""
|
||||||
CREATE TABLE IF NOT EXISTS masking_rules (
|
CREATE TABLE IF NOT EXISTS masking_rules (
|
||||||
@@ -255,7 +257,7 @@ class SecurityManager:
|
|||||||
FOREIGN KEY (project_id) REFERENCES projects(id)
|
FOREIGN KEY (project_id) REFERENCES projects(id)
|
||||||
)
|
)
|
||||||
""")
|
""")
|
||||||
|
|
||||||
# 数据访问策略表
|
# 数据访问策略表
|
||||||
cursor.execute("""
|
cursor.execute("""
|
||||||
CREATE TABLE IF NOT EXISTS data_access_policies (
|
CREATE TABLE IF NOT EXISTS data_access_policies (
|
||||||
@@ -275,7 +277,7 @@ class SecurityManager:
|
|||||||
FOREIGN KEY (project_id) REFERENCES projects(id)
|
FOREIGN KEY (project_id) REFERENCES projects(id)
|
||||||
)
|
)
|
||||||
""")
|
""")
|
||||||
|
|
||||||
# 访问请求表
|
# 访问请求表
|
||||||
cursor.execute("""
|
cursor.execute("""
|
||||||
CREATE TABLE IF NOT EXISTS access_requests (
|
CREATE TABLE IF NOT EXISTS access_requests (
|
||||||
@@ -291,7 +293,7 @@ class SecurityManager:
|
|||||||
FOREIGN KEY (policy_id) REFERENCES data_access_policies(id)
|
FOREIGN KEY (policy_id) REFERENCES data_access_policies(id)
|
||||||
)
|
)
|
||||||
""")
|
""")
|
||||||
|
|
||||||
# 创建索引
|
# 创建索引
|
||||||
cursor.execute("CREATE INDEX IF NOT EXISTS idx_audit_logs_user ON audit_logs(user_id)")
|
cursor.execute("CREATE INDEX IF NOT EXISTS idx_audit_logs_user ON audit_logs(user_id)")
|
||||||
cursor.execute("CREATE INDEX IF NOT EXISTS idx_audit_logs_resource ON audit_logs(resource_type, resource_id)")
|
cursor.execute("CREATE INDEX IF NOT EXISTS idx_audit_logs_resource ON audit_logs(resource_type, resource_id)")
|
||||||
@@ -300,18 +302,18 @@ class SecurityManager:
|
|||||||
cursor.execute("CREATE INDEX IF NOT EXISTS idx_encryption_project ON encryption_configs(project_id)")
|
cursor.execute("CREATE INDEX IF NOT EXISTS idx_encryption_project ON encryption_configs(project_id)")
|
||||||
cursor.execute("CREATE INDEX IF NOT EXISTS idx_masking_project ON masking_rules(project_id)")
|
cursor.execute("CREATE INDEX IF NOT EXISTS idx_masking_project ON masking_rules(project_id)")
|
||||||
cursor.execute("CREATE INDEX IF NOT EXISTS idx_access_policy_project ON data_access_policies(project_id)")
|
cursor.execute("CREATE INDEX IF NOT EXISTS idx_access_policy_project ON data_access_policies(project_id)")
|
||||||
|
|
||||||
conn.commit()
|
conn.commit()
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
def _generate_id(self) -> str:
|
def _generate_id(self) -> str:
|
||||||
"""生成唯一ID"""
|
"""生成唯一ID"""
|
||||||
return hashlib.sha256(
|
return hashlib.sha256(
|
||||||
f"{datetime.now().isoformat()}{secrets.token_hex(16)}".encode()
|
f"{datetime.now().isoformat()}{secrets.token_hex(16)}".encode()
|
||||||
).hexdigest()[:32]
|
).hexdigest()[:32]
|
||||||
|
|
||||||
# ==================== 审计日志 ====================
|
# ==================== 审计日志 ====================
|
||||||
|
|
||||||
def log_audit(
|
def log_audit(
|
||||||
self,
|
self,
|
||||||
action_type: AuditActionType,
|
action_type: AuditActionType,
|
||||||
@@ -341,11 +343,11 @@ class SecurityManager:
|
|||||||
success=success,
|
success=success,
|
||||||
error_message=error_message
|
error_message=error_message
|
||||||
)
|
)
|
||||||
|
|
||||||
conn = sqlite3.connect(self.db_path)
|
conn = sqlite3.connect(self.db_path)
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
cursor.execute("""
|
cursor.execute("""
|
||||||
INSERT INTO audit_logs
|
INSERT INTO audit_logs
|
||||||
(id, action_type, user_id, user_ip, user_agent, resource_type, resource_id,
|
(id, action_type, user_id, user_ip, user_agent, resource_type, resource_id,
|
||||||
action_details, before_value, after_value, success, error_message, created_at)
|
action_details, before_value, after_value, success, error_message, created_at)
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
@@ -357,9 +359,9 @@ class SecurityManager:
|
|||||||
))
|
))
|
||||||
conn.commit()
|
conn.commit()
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
return log
|
return log
|
||||||
|
|
||||||
def get_audit_logs(
|
def get_audit_logs(
|
||||||
self,
|
self,
|
||||||
user_id: Optional[str] = None,
|
user_id: Optional[str] = None,
|
||||||
@@ -375,10 +377,10 @@ class SecurityManager:
|
|||||||
"""查询审计日志"""
|
"""查询审计日志"""
|
||||||
conn = sqlite3.connect(self.db_path)
|
conn = sqlite3.connect(self.db_path)
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
||||||
query = "SELECT * FROM audit_logs WHERE 1=1"
|
query = "SELECT * FROM audit_logs WHERE 1=1"
|
||||||
params = []
|
params = []
|
||||||
|
|
||||||
if user_id:
|
if user_id:
|
||||||
query += " AND user_id = ?"
|
query += " AND user_id = ?"
|
||||||
params.append(user_id)
|
params.append(user_id)
|
||||||
@@ -400,26 +402,19 @@ class SecurityManager:
|
|||||||
if success is not None:
|
if success is not None:
|
||||||
query += " AND success = ?"
|
query += " AND success = ?"
|
||||||
params.append(int(success))
|
params.append(int(success))
|
||||||
|
|
||||||
query += " ORDER BY created_at DESC LIMIT ? OFFSET ?"
|
query += " ORDER BY created_at DESC LIMIT ? OFFSET ?"
|
||||||
params.extend([limit, offset])
|
params.extend([limit, offset])
|
||||||
|
|
||||||
cursor.execute(query, params)
|
cursor.execute(query, params)
|
||||||
rows = cursor.fetchall()
|
rows = cursor.fetchall()
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
logs = []
|
logs = []
|
||||||
for row in cursor.description:
|
col_names = [desc[0] for desc in cursor.description] if cursor.description else []
|
||||||
col_names = [desc[0] for desc in cursor.description]
|
if not col_names:
|
||||||
break
|
|
||||||
else:
|
|
||||||
return logs
|
return logs
|
||||||
|
|
||||||
conn = sqlite3.connect(self.db_path)
|
|
||||||
cursor = conn.cursor()
|
|
||||||
cursor.execute(query, params)
|
|
||||||
rows = cursor.fetchall()
|
|
||||||
|
|
||||||
for row in rows:
|
for row in rows:
|
||||||
log = AuditLog(
|
log = AuditLog(
|
||||||
id=row[0],
|
id=row[0],
|
||||||
@@ -437,10 +432,10 @@ class SecurityManager:
|
|||||||
created_at=row[12]
|
created_at=row[12]
|
||||||
)
|
)
|
||||||
logs.append(log)
|
logs.append(log)
|
||||||
|
|
||||||
conn.close()
|
conn.close()
|
||||||
return logs
|
return logs
|
||||||
|
|
||||||
def get_audit_stats(
|
def get_audit_stats(
|
||||||
self,
|
self,
|
||||||
start_time: Optional[str] = None,
|
start_time: Optional[str] = None,
|
||||||
@@ -449,54 +444,54 @@ class SecurityManager:
|
|||||||
"""获取审计统计"""
|
"""获取审计统计"""
|
||||||
conn = sqlite3.connect(self.db_path)
|
conn = sqlite3.connect(self.db_path)
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
||||||
query = "SELECT action_type, success, COUNT(*) FROM audit_logs WHERE 1=1"
|
query = "SELECT action_type, success, COUNT(*) FROM audit_logs WHERE 1=1"
|
||||||
params = []
|
params = []
|
||||||
|
|
||||||
if start_time:
|
if start_time:
|
||||||
query += " AND created_at >= ?"
|
query += " AND created_at >= ?"
|
||||||
params.append(start_time)
|
params.append(start_time)
|
||||||
if end_time:
|
if end_time:
|
||||||
query += " AND created_at <= ?"
|
query += " AND created_at <= ?"
|
||||||
params.append(end_time)
|
params.append(end_time)
|
||||||
|
|
||||||
query += " GROUP BY action_type, success"
|
query += " GROUP BY action_type, success"
|
||||||
|
|
||||||
cursor.execute(query, params)
|
cursor.execute(query, params)
|
||||||
rows = cursor.fetchall()
|
rows = cursor.fetchall()
|
||||||
|
|
||||||
stats = {
|
stats = {
|
||||||
"total_actions": 0,
|
"total_actions": 0,
|
||||||
"success_count": 0,
|
"success_count": 0,
|
||||||
"failure_count": 0,
|
"failure_count": 0,
|
||||||
"action_breakdown": {}
|
"action_breakdown": {}
|
||||||
}
|
}
|
||||||
|
|
||||||
for action_type, success, count in rows:
|
for action_type, success, count in rows:
|
||||||
stats["total_actions"] += count
|
stats["total_actions"] += count
|
||||||
if success:
|
if success:
|
||||||
stats["success_count"] += count
|
stats["success_count"] += count
|
||||||
else:
|
else:
|
||||||
stats["failure_count"] += count
|
stats["failure_count"] += count
|
||||||
|
|
||||||
if action_type not in stats["action_breakdown"]:
|
if action_type not in stats["action_breakdown"]:
|
||||||
stats["action_breakdown"][action_type] = {"success": 0, "failure": 0}
|
stats["action_breakdown"][action_type] = {"success": 0, "failure": 0}
|
||||||
|
|
||||||
if success:
|
if success:
|
||||||
stats["action_breakdown"][action_type]["success"] += count
|
stats["action_breakdown"][action_type]["success"] += count
|
||||||
else:
|
else:
|
||||||
stats["action_breakdown"][action_type]["failure"] += count
|
stats["action_breakdown"][action_type]["failure"] += count
|
||||||
|
|
||||||
conn.close()
|
conn.close()
|
||||||
return stats
|
return stats
|
||||||
|
|
||||||
# ==================== 端到端加密 ====================
|
# ==================== 端到端加密 ====================
|
||||||
|
|
||||||
def _derive_key(self, password: str, salt: bytes) -> bytes:
|
def _derive_key(self, password: str, salt: bytes) -> bytes:
|
||||||
"""从密码派生密钥"""
|
"""从密码派生密钥"""
|
||||||
if not CRYPTO_AVAILABLE:
|
if not CRYPTO_AVAILABLE:
|
||||||
raise RuntimeError("cryptography library not available")
|
raise RuntimeError("cryptography library not available")
|
||||||
|
|
||||||
kdf = PBKDF2HMAC(
|
kdf = PBKDF2HMAC(
|
||||||
algorithm=hashes.SHA256(),
|
algorithm=hashes.SHA256(),
|
||||||
length=32,
|
length=32,
|
||||||
@@ -504,7 +499,7 @@ class SecurityManager:
|
|||||||
iterations=100000,
|
iterations=100000,
|
||||||
)
|
)
|
||||||
return base64.urlsafe_b64encode(kdf.derive(password.encode()))
|
return base64.urlsafe_b64encode(kdf.derive(password.encode()))
|
||||||
|
|
||||||
def enable_encryption(
|
def enable_encryption(
|
||||||
self,
|
self,
|
||||||
project_id: str,
|
project_id: str,
|
||||||
@@ -513,14 +508,14 @@ class SecurityManager:
|
|||||||
"""启用项目加密"""
|
"""启用项目加密"""
|
||||||
if not CRYPTO_AVAILABLE:
|
if not CRYPTO_AVAILABLE:
|
||||||
raise RuntimeError("cryptography library not available")
|
raise RuntimeError("cryptography library not available")
|
||||||
|
|
||||||
# 生成盐值
|
# 生成盐值
|
||||||
salt = secrets.token_hex(16)
|
salt = secrets.token_hex(16)
|
||||||
|
|
||||||
# 派生密钥并哈希(用于验证)
|
# 派生密钥并哈希(用于验证)
|
||||||
key = self._derive_key(master_password, salt.encode())
|
key = self._derive_key(master_password, salt.encode())
|
||||||
key_hash = hashlib.sha256(key).hexdigest()
|
key_hash = hashlib.sha256(key).hexdigest()
|
||||||
|
|
||||||
config = EncryptionConfig(
|
config = EncryptionConfig(
|
||||||
id=self._generate_id(),
|
id=self._generate_id(),
|
||||||
project_id=project_id,
|
project_id=project_id,
|
||||||
@@ -530,20 +525,20 @@ class SecurityManager:
|
|||||||
master_key_hash=key_hash,
|
master_key_hash=key_hash,
|
||||||
salt=salt
|
salt=salt
|
||||||
)
|
)
|
||||||
|
|
||||||
conn = sqlite3.connect(self.db_path)
|
conn = sqlite3.connect(self.db_path)
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
||||||
# 检查是否已存在配置
|
# 检查是否已存在配置
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"SELECT id FROM encryption_configs WHERE project_id = ?",
|
"SELECT id FROM encryption_configs WHERE project_id = ?",
|
||||||
(project_id,)
|
(project_id,)
|
||||||
)
|
)
|
||||||
existing = cursor.fetchone()
|
existing = cursor.fetchone()
|
||||||
|
|
||||||
if existing:
|
if existing:
|
||||||
cursor.execute("""
|
cursor.execute("""
|
||||||
UPDATE encryption_configs
|
UPDATE encryption_configs
|
||||||
SET is_enabled = 1, encryption_type = ?, key_derivation = ?,
|
SET is_enabled = 1, encryption_type = ?, key_derivation = ?,
|
||||||
master_key_hash = ?, salt = ?, updated_at = ?
|
master_key_hash = ?, salt = ?, updated_at = ?
|
||||||
WHERE project_id = ?
|
WHERE project_id = ?
|
||||||
@@ -555,7 +550,7 @@ class SecurityManager:
|
|||||||
config.id = existing[0]
|
config.id = existing[0]
|
||||||
else:
|
else:
|
||||||
cursor.execute("""
|
cursor.execute("""
|
||||||
INSERT INTO encryption_configs
|
INSERT INTO encryption_configs
|
||||||
(id, project_id, is_enabled, encryption_type, key_derivation,
|
(id, project_id, is_enabled, encryption_type, key_derivation,
|
||||||
master_key_hash, salt, created_at, updated_at)
|
master_key_hash, salt, created_at, updated_at)
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
@@ -565,10 +560,10 @@ class SecurityManager:
|
|||||||
config.master_key_hash, config.salt,
|
config.master_key_hash, config.salt,
|
||||||
config.created_at, config.updated_at
|
config.created_at, config.updated_at
|
||||||
))
|
))
|
||||||
|
|
||||||
conn.commit()
|
conn.commit()
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
# 记录审计日志
|
# 记录审计日志
|
||||||
self.log_audit(
|
self.log_audit(
|
||||||
action_type=AuditActionType.ENCRYPTION_ENABLE,
|
action_type=AuditActionType.ENCRYPTION_ENABLE,
|
||||||
@@ -576,9 +571,9 @@ class SecurityManager:
|
|||||||
resource_id=project_id,
|
resource_id=project_id,
|
||||||
action_details={"encryption_type": config.encryption_type}
|
action_details={"encryption_type": config.encryption_type}
|
||||||
)
|
)
|
||||||
|
|
||||||
return config
|
return config
|
||||||
|
|
||||||
def disable_encryption(
|
def disable_encryption(
|
||||||
self,
|
self,
|
||||||
project_id: str,
|
project_id: str,
|
||||||
@@ -588,28 +583,28 @@ class SecurityManager:
|
|||||||
# 验证密码
|
# 验证密码
|
||||||
if not self.verify_encryption_password(project_id, master_password):
|
if not self.verify_encryption_password(project_id, master_password):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
conn = sqlite3.connect(self.db_path)
|
conn = sqlite3.connect(self.db_path)
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
||||||
cursor.execute("""
|
cursor.execute("""
|
||||||
UPDATE encryption_configs
|
UPDATE encryption_configs
|
||||||
SET is_enabled = 0, updated_at = ?
|
SET is_enabled = 0, updated_at = ?
|
||||||
WHERE project_id = ?
|
WHERE project_id = ?
|
||||||
""", (datetime.now().isoformat(), project_id))
|
""", (datetime.now().isoformat(), project_id))
|
||||||
|
|
||||||
conn.commit()
|
conn.commit()
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
# 记录审计日志
|
# 记录审计日志
|
||||||
self.log_audit(
|
self.log_audit(
|
||||||
action_type=AuditActionType.ENCRYPTION_DISABLE,
|
action_type=AuditActionType.ENCRYPTION_DISABLE,
|
||||||
resource_type="project",
|
resource_type="project",
|
||||||
resource_id=project_id
|
resource_id=project_id
|
||||||
)
|
)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def verify_encryption_password(
|
def verify_encryption_password(
|
||||||
self,
|
self,
|
||||||
project_id: str,
|
project_id: str,
|
||||||
@@ -618,41 +613,41 @@ class SecurityManager:
|
|||||||
"""验证加密密码"""
|
"""验证加密密码"""
|
||||||
if not CRYPTO_AVAILABLE:
|
if not CRYPTO_AVAILABLE:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
conn = sqlite3.connect(self.db_path)
|
conn = sqlite3.connect(self.db_path)
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"SELECT master_key_hash, salt FROM encryption_configs WHERE project_id = ?",
|
"SELECT master_key_hash, salt FROM encryption_configs WHERE project_id = ?",
|
||||||
(project_id,)
|
(project_id,)
|
||||||
)
|
)
|
||||||
row = cursor.fetchone()
|
row = cursor.fetchone()
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
if not row:
|
if not row:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
stored_hash, salt = row
|
stored_hash, salt = row
|
||||||
key = self._derive_key(password, salt.encode())
|
key = self._derive_key(password, salt.encode())
|
||||||
key_hash = hashlib.sha256(key).hexdigest()
|
key_hash = hashlib.sha256(key).hexdigest()
|
||||||
|
|
||||||
return key_hash == stored_hash
|
return key_hash == stored_hash
|
||||||
|
|
||||||
def get_encryption_config(self, project_id: str) -> Optional[EncryptionConfig]:
|
def get_encryption_config(self, project_id: str) -> Optional[EncryptionConfig]:
|
||||||
"""获取加密配置"""
|
"""获取加密配置"""
|
||||||
conn = sqlite3.connect(self.db_path)
|
conn = sqlite3.connect(self.db_path)
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"SELECT * FROM encryption_configs WHERE project_id = ?",
|
"SELECT * FROM encryption_configs WHERE project_id = ?",
|
||||||
(project_id,)
|
(project_id,)
|
||||||
)
|
)
|
||||||
row = cursor.fetchone()
|
row = cursor.fetchone()
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
if not row:
|
if not row:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return EncryptionConfig(
|
return EncryptionConfig(
|
||||||
id=row[0],
|
id=row[0],
|
||||||
project_id=row[1],
|
project_id=row[1],
|
||||||
@@ -664,7 +659,7 @@ class SecurityManager:
|
|||||||
created_at=row[7],
|
created_at=row[7],
|
||||||
updated_at=row[8]
|
updated_at=row[8]
|
||||||
)
|
)
|
||||||
|
|
||||||
def encrypt_data(
|
def encrypt_data(
|
||||||
self,
|
self,
|
||||||
data: str,
|
data: str,
|
||||||
@@ -674,16 +669,16 @@ class SecurityManager:
|
|||||||
"""加密数据"""
|
"""加密数据"""
|
||||||
if not CRYPTO_AVAILABLE:
|
if not CRYPTO_AVAILABLE:
|
||||||
raise RuntimeError("cryptography library not available")
|
raise RuntimeError("cryptography library not available")
|
||||||
|
|
||||||
if salt is None:
|
if salt is None:
|
||||||
salt = secrets.token_hex(16)
|
salt = secrets.token_hex(16)
|
||||||
|
|
||||||
key = self._derive_key(password, salt.encode())
|
key = self._derive_key(password, salt.encode())
|
||||||
f = Fernet(key)
|
f = Fernet(key)
|
||||||
encrypted = f.encrypt(data.encode())
|
encrypted = f.encrypt(data.encode())
|
||||||
|
|
||||||
return base64.b64encode(encrypted).decode(), salt
|
return base64.b64encode(encrypted).decode(), salt
|
||||||
|
|
||||||
def decrypt_data(
|
def decrypt_data(
|
||||||
self,
|
self,
|
||||||
encrypted_data: str,
|
encrypted_data: str,
|
||||||
@@ -693,15 +688,15 @@ class SecurityManager:
|
|||||||
"""解密数据"""
|
"""解密数据"""
|
||||||
if not CRYPTO_AVAILABLE:
|
if not CRYPTO_AVAILABLE:
|
||||||
raise RuntimeError("cryptography library not available")
|
raise RuntimeError("cryptography library not available")
|
||||||
|
|
||||||
key = self._derive_key(password, salt.encode())
|
key = self._derive_key(password, salt.encode())
|
||||||
f = Fernet(key)
|
f = Fernet(key)
|
||||||
decrypted = f.decrypt(base64.b64decode(encrypted_data))
|
decrypted = f.decrypt(base64.b64decode(encrypted_data))
|
||||||
|
|
||||||
return decrypted.decode()
|
return decrypted.decode()
|
||||||
|
|
||||||
# ==================== 数据脱敏 ====================
|
# ==================== 数据脱敏 ====================
|
||||||
|
|
||||||
def create_masking_rule(
|
def create_masking_rule(
|
||||||
self,
|
self,
|
||||||
project_id: str,
|
project_id: str,
|
||||||
@@ -718,7 +713,7 @@ class SecurityManager:
|
|||||||
default = self.DEFAULT_MASKING_RULES[rule_type]
|
default = self.DEFAULT_MASKING_RULES[rule_type]
|
||||||
pattern = default["pattern"]
|
pattern = default["pattern"]
|
||||||
replacement = replacement or default["replacement"]
|
replacement = replacement or default["replacement"]
|
||||||
|
|
||||||
rule = MaskingRule(
|
rule = MaskingRule(
|
||||||
id=self._generate_id(),
|
id=self._generate_id(),
|
||||||
project_id=project_id,
|
project_id=project_id,
|
||||||
@@ -729,12 +724,12 @@ class SecurityManager:
|
|||||||
description=description,
|
description=description,
|
||||||
priority=priority
|
priority=priority
|
||||||
)
|
)
|
||||||
|
|
||||||
conn = sqlite3.connect(self.db_path)
|
conn = sqlite3.connect(self.db_path)
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
||||||
cursor.execute("""
|
cursor.execute("""
|
||||||
INSERT INTO masking_rules
|
INSERT INTO masking_rules
|
||||||
(id, project_id, name, rule_type, pattern, replacement,
|
(id, project_id, name, rule_type, pattern, replacement,
|
||||||
is_active, priority, description, created_at, updated_at)
|
is_active, priority, description, created_at, updated_at)
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
@@ -743,10 +738,10 @@ class SecurityManager:
|
|||||||
rule.pattern, rule.replacement, int(rule.is_active),
|
rule.pattern, rule.replacement, int(rule.is_active),
|
||||||
rule.priority, rule.description, rule.created_at, rule.updated_at
|
rule.priority, rule.description, rule.created_at, rule.updated_at
|
||||||
))
|
))
|
||||||
|
|
||||||
conn.commit()
|
conn.commit()
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
# 记录审计日志
|
# 记录审计日志
|
||||||
self.log_audit(
|
self.log_audit(
|
||||||
action_type=AuditActionType.DATA_MASKING,
|
action_type=AuditActionType.DATA_MASKING,
|
||||||
@@ -754,9 +749,9 @@ class SecurityManager:
|
|||||||
resource_id=project_id,
|
resource_id=project_id,
|
||||||
action_details={"action": "create_rule", "rule_name": name}
|
action_details={"action": "create_rule", "rule_name": name}
|
||||||
)
|
)
|
||||||
|
|
||||||
return rule
|
return rule
|
||||||
|
|
||||||
def get_masking_rules(
|
def get_masking_rules(
|
||||||
self,
|
self,
|
||||||
project_id: str,
|
project_id: str,
|
||||||
@@ -765,19 +760,19 @@ class SecurityManager:
|
|||||||
"""获取脱敏规则"""
|
"""获取脱敏规则"""
|
||||||
conn = sqlite3.connect(self.db_path)
|
conn = sqlite3.connect(self.db_path)
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
||||||
query = "SELECT * FROM masking_rules WHERE project_id = ?"
|
query = "SELECT * FROM masking_rules WHERE project_id = ?"
|
||||||
params = [project_id]
|
params = [project_id]
|
||||||
|
|
||||||
if active_only:
|
if active_only:
|
||||||
query += " AND is_active = 1"
|
query += " AND is_active = 1"
|
||||||
|
|
||||||
query += " ORDER BY priority DESC"
|
query += " ORDER BY priority DESC"
|
||||||
|
|
||||||
cursor.execute(query, params)
|
cursor.execute(query, params)
|
||||||
rows = cursor.fetchall()
|
rows = cursor.fetchall()
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
rules = []
|
rules = []
|
||||||
for row in rows:
|
for row in rows:
|
||||||
rules.append(MaskingRule(
|
rules.append(MaskingRule(
|
||||||
@@ -793,9 +788,9 @@ class SecurityManager:
|
|||||||
created_at=row[9],
|
created_at=row[9],
|
||||||
updated_at=row[10]
|
updated_at=row[10]
|
||||||
))
|
))
|
||||||
|
|
||||||
return rules
|
return rules
|
||||||
|
|
||||||
def update_masking_rule(
|
def update_masking_rule(
|
||||||
self,
|
self,
|
||||||
rule_id: str,
|
rule_id: str,
|
||||||
@@ -803,45 +798,45 @@ class SecurityManager:
|
|||||||
) -> Optional[MaskingRule]:
|
) -> Optional[MaskingRule]:
|
||||||
"""更新脱敏规则"""
|
"""更新脱敏规则"""
|
||||||
allowed_fields = ["name", "pattern", "replacement", "is_active", "priority", "description"]
|
allowed_fields = ["name", "pattern", "replacement", "is_active", "priority", "description"]
|
||||||
|
|
||||||
conn = sqlite3.connect(self.db_path)
|
conn = sqlite3.connect(self.db_path)
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
||||||
set_clauses = []
|
set_clauses = []
|
||||||
params = []
|
params = []
|
||||||
|
|
||||||
for key, value in kwargs.items():
|
for key, value in kwargs.items():
|
||||||
if key in allowed_fields:
|
if key in allowed_fields:
|
||||||
set_clauses.append(f"{key} = ?")
|
set_clauses.append(f"{key} = ?")
|
||||||
params.append(int(value) if key == "is_active" else value)
|
params.append(int(value) if key == "is_active" else value)
|
||||||
|
|
||||||
if not set_clauses:
|
if not set_clauses:
|
||||||
conn.close()
|
conn.close()
|
||||||
return None
|
return None
|
||||||
|
|
||||||
set_clauses.append("updated_at = ?")
|
set_clauses.append("updated_at = ?")
|
||||||
params.append(datetime.now().isoformat())
|
params.append(datetime.now().isoformat())
|
||||||
params.append(rule_id)
|
params.append(rule_id)
|
||||||
|
|
||||||
cursor.execute(f"""
|
cursor.execute(f"""
|
||||||
UPDATE masking_rules
|
UPDATE masking_rules
|
||||||
SET {', '.join(set_clauses)}
|
SET {', '.join(set_clauses)}
|
||||||
WHERE id = ?
|
WHERE id = ?
|
||||||
""", params)
|
""", params)
|
||||||
|
|
||||||
conn.commit()
|
conn.commit()
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
# 获取更新后的规则
|
# 获取更新后的规则
|
||||||
conn = sqlite3.connect(self.db_path)
|
conn = sqlite3.connect(self.db_path)
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
cursor.execute("SELECT * FROM masking_rules WHERE id = ?", (rule_id,))
|
cursor.execute("SELECT * FROM masking_rules WHERE id = ?", (rule_id,))
|
||||||
row = cursor.fetchone()
|
row = cursor.fetchone()
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
if not row:
|
if not row:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return MaskingRule(
|
return MaskingRule(
|
||||||
id=row[0],
|
id=row[0],
|
||||||
project_id=row[1],
|
project_id=row[1],
|
||||||
@@ -855,20 +850,20 @@ class SecurityManager:
|
|||||||
created_at=row[9],
|
created_at=row[9],
|
||||||
updated_at=row[10]
|
updated_at=row[10]
|
||||||
)
|
)
|
||||||
|
|
||||||
def delete_masking_rule(self, rule_id: str) -> bool:
|
def delete_masking_rule(self, rule_id: str) -> bool:
|
||||||
"""删除脱敏规则"""
|
"""删除脱敏规则"""
|
||||||
conn = sqlite3.connect(self.db_path)
|
conn = sqlite3.connect(self.db_path)
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
||||||
cursor.execute("DELETE FROM masking_rules WHERE id = ?", (rule_id,))
|
cursor.execute("DELETE FROM masking_rules WHERE id = ?", (rule_id,))
|
||||||
|
|
||||||
success = cursor.rowcount > 0
|
success = cursor.rowcount > 0
|
||||||
conn.commit()
|
conn.commit()
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
return success
|
return success
|
||||||
|
|
||||||
def apply_masking(
|
def apply_masking(
|
||||||
self,
|
self,
|
||||||
text: str,
|
text: str,
|
||||||
@@ -877,17 +872,17 @@ class SecurityManager:
|
|||||||
) -> str:
|
) -> str:
|
||||||
"""应用脱敏规则到文本"""
|
"""应用脱敏规则到文本"""
|
||||||
rules = self.get_masking_rules(project_id)
|
rules = self.get_masking_rules(project_id)
|
||||||
|
|
||||||
if not rules:
|
if not rules:
|
||||||
return text
|
return text
|
||||||
|
|
||||||
masked_text = text
|
masked_text = text
|
||||||
|
|
||||||
for rule in rules:
|
for rule in rules:
|
||||||
# 如果指定了规则类型,只应用指定类型的规则
|
# 如果指定了规则类型,只应用指定类型的规则
|
||||||
if rule_types and MaskingRuleType(rule.rule_type) not in rule_types:
|
if rule_types and MaskingRuleType(rule.rule_type) not in rule_types:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
masked_text = re.sub(
|
masked_text = re.sub(
|
||||||
rule.pattern,
|
rule.pattern,
|
||||||
@@ -897,9 +892,9 @@ class SecurityManager:
|
|||||||
except re.error:
|
except re.error:
|
||||||
# 忽略无效的正则表达式
|
# 忽略无效的正则表达式
|
||||||
continue
|
continue
|
||||||
|
|
||||||
return masked_text
|
return masked_text
|
||||||
|
|
||||||
def apply_masking_to_entity(
|
def apply_masking_to_entity(
|
||||||
self,
|
self,
|
||||||
entity_data: Dict[str, Any],
|
entity_data: Dict[str, Any],
|
||||||
@@ -907,18 +902,18 @@ class SecurityManager:
|
|||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""对实体数据应用脱敏"""
|
"""对实体数据应用脱敏"""
|
||||||
masked_data = entity_data.copy()
|
masked_data = entity_data.copy()
|
||||||
|
|
||||||
# 对可能包含敏感信息的字段进行脱敏
|
# 对可能包含敏感信息的字段进行脱敏
|
||||||
sensitive_fields = ["name", "definition", "description", "value"]
|
sensitive_fields = ["name", "definition", "description", "value"]
|
||||||
|
|
||||||
for field in sensitive_fields:
|
for field in sensitive_fields:
|
||||||
if field in masked_data and isinstance(masked_data[field], str):
|
if field in masked_data and isinstance(masked_data[field], str):
|
||||||
masked_data[field] = self.apply_masking(masked_data[field], project_id)
|
masked_data[field] = self.apply_masking(masked_data[field], project_id)
|
||||||
|
|
||||||
return masked_data
|
return masked_data
|
||||||
|
|
||||||
# ==================== 数据访问策略 ====================
|
# ==================== 数据访问策略 ====================
|
||||||
|
|
||||||
def create_access_policy(
|
def create_access_policy(
|
||||||
self,
|
self,
|
||||||
project_id: str,
|
project_id: str,
|
||||||
@@ -944,12 +939,12 @@ class SecurityManager:
|
|||||||
max_access_count=max_access_count,
|
max_access_count=max_access_count,
|
||||||
require_approval=require_approval
|
require_approval=require_approval
|
||||||
)
|
)
|
||||||
|
|
||||||
conn = sqlite3.connect(self.db_path)
|
conn = sqlite3.connect(self.db_path)
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
||||||
cursor.execute("""
|
cursor.execute("""
|
||||||
INSERT INTO data_access_policies
|
INSERT INTO data_access_policies
|
||||||
(id, project_id, name, description, allowed_users, allowed_roles,
|
(id, project_id, name, description, allowed_users, allowed_roles,
|
||||||
allowed_ips, time_restrictions, max_access_count, require_approval,
|
allowed_ips, time_restrictions, max_access_count, require_approval,
|
||||||
is_active, created_at, updated_at)
|
is_active, created_at, updated_at)
|
||||||
@@ -961,12 +956,12 @@ class SecurityManager:
|
|||||||
int(policy.require_approval), int(policy.is_active),
|
int(policy.require_approval), int(policy.is_active),
|
||||||
policy.created_at, policy.updated_at
|
policy.created_at, policy.updated_at
|
||||||
))
|
))
|
||||||
|
|
||||||
conn.commit()
|
conn.commit()
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
return policy
|
return policy
|
||||||
|
|
||||||
def get_access_policies(
|
def get_access_policies(
|
||||||
self,
|
self,
|
||||||
project_id: str,
|
project_id: str,
|
||||||
@@ -975,17 +970,17 @@ class SecurityManager:
|
|||||||
"""获取数据访问策略"""
|
"""获取数据访问策略"""
|
||||||
conn = sqlite3.connect(self.db_path)
|
conn = sqlite3.connect(self.db_path)
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
||||||
query = "SELECT * FROM data_access_policies WHERE project_id = ?"
|
query = "SELECT * FROM data_access_policies WHERE project_id = ?"
|
||||||
params = [project_id]
|
params = [project_id]
|
||||||
|
|
||||||
if active_only:
|
if active_only:
|
||||||
query += " AND is_active = 1"
|
query += " AND is_active = 1"
|
||||||
|
|
||||||
cursor.execute(query, params)
|
cursor.execute(query, params)
|
||||||
rows = cursor.fetchall()
|
rows = cursor.fetchall()
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
policies = []
|
policies = []
|
||||||
for row in rows:
|
for row in rows:
|
||||||
policies.append(DataAccessPolicy(
|
policies.append(DataAccessPolicy(
|
||||||
@@ -1003,9 +998,9 @@ class SecurityManager:
|
|||||||
created_at=row[11],
|
created_at=row[11],
|
||||||
updated_at=row[12]
|
updated_at=row[12]
|
||||||
))
|
))
|
||||||
|
|
||||||
return policies
|
return policies
|
||||||
|
|
||||||
def check_access_permission(
|
def check_access_permission(
|
||||||
self,
|
self,
|
||||||
policy_id: str,
|
policy_id: str,
|
||||||
@@ -1015,17 +1010,17 @@ class SecurityManager:
|
|||||||
"""检查访问权限"""
|
"""检查访问权限"""
|
||||||
conn = sqlite3.connect(self.db_path)
|
conn = sqlite3.connect(self.db_path)
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"SELECT * FROM data_access_policies WHERE id = ? AND is_active = 1",
|
"SELECT * FROM data_access_policies WHERE id = ? AND is_active = 1",
|
||||||
(policy_id,)
|
(policy_id,)
|
||||||
)
|
)
|
||||||
row = cursor.fetchone()
|
row = cursor.fetchone()
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
if not row:
|
if not row:
|
||||||
return False, "Policy not found or inactive"
|
return False, "Policy not found or inactive"
|
||||||
|
|
||||||
policy = DataAccessPolicy(
|
policy = DataAccessPolicy(
|
||||||
id=row[0],
|
id=row[0],
|
||||||
project_id=row[1],
|
project_id=row[1],
|
||||||
@@ -1041,13 +1036,13 @@ class SecurityManager:
|
|||||||
created_at=row[11],
|
created_at=row[11],
|
||||||
updated_at=row[12]
|
updated_at=row[12]
|
||||||
)
|
)
|
||||||
|
|
||||||
# 检查用户白名单
|
# 检查用户白名单
|
||||||
if policy.allowed_users:
|
if policy.allowed_users:
|
||||||
allowed = json.loads(policy.allowed_users)
|
allowed = json.loads(policy.allowed_users)
|
||||||
if user_id not in allowed:
|
if user_id not in allowed:
|
||||||
return False, "User not in allowed list"
|
return False, "User not in allowed list"
|
||||||
|
|
||||||
# 检查IP白名单
|
# 检查IP白名单
|
||||||
if policy.allowed_ips and user_ip:
|
if policy.allowed_ips and user_ip:
|
||||||
allowed_ips = json.loads(policy.allowed_ips)
|
allowed_ips = json.loads(policy.allowed_ips)
|
||||||
@@ -1058,45 +1053,45 @@ class SecurityManager:
|
|||||||
break
|
break
|
||||||
if not ip_allowed:
|
if not ip_allowed:
|
||||||
return False, "IP not in allowed list"
|
return False, "IP not in allowed list"
|
||||||
|
|
||||||
# 检查时间限制
|
# 检查时间限制
|
||||||
if policy.time_restrictions:
|
if policy.time_restrictions:
|
||||||
restrictions = json.loads(policy.time_restrictions)
|
restrictions = json.loads(policy.time_restrictions)
|
||||||
now = datetime.now()
|
now = datetime.now()
|
||||||
|
|
||||||
if "start_time" in restrictions and "end_time" in restrictions:
|
if "start_time" in restrictions and "end_time" in restrictions:
|
||||||
current_time = now.strftime("%H:%M")
|
current_time = now.strftime("%H:%M")
|
||||||
if not (restrictions["start_time"] <= current_time <= restrictions["end_time"]):
|
if not (restrictions["start_time"] <= current_time <= restrictions["end_time"]):
|
||||||
return False, "Access not allowed at this time"
|
return False, "Access not allowed at this time"
|
||||||
|
|
||||||
if "days_of_week" in restrictions:
|
if "days_of_week" in restrictions:
|
||||||
if now.weekday() not in restrictions["days_of_week"]:
|
if now.weekday() not in restrictions["days_of_week"]:
|
||||||
return False, "Access not allowed on this day"
|
return False, "Access not allowed on this day"
|
||||||
|
|
||||||
# 检查是否需要审批
|
# 检查是否需要审批
|
||||||
if policy.require_approval:
|
if policy.require_approval:
|
||||||
# 检查是否有有效的访问请求
|
# 检查是否有有效的访问请求
|
||||||
conn = sqlite3.connect(self.db_path)
|
conn = sqlite3.connect(self.db_path)
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
||||||
cursor.execute("""
|
cursor.execute("""
|
||||||
SELECT * FROM access_requests
|
SELECT * FROM access_requests
|
||||||
WHERE policy_id = ? AND user_id = ? AND status = 'approved'
|
WHERE policy_id = ? AND user_id = ? AND status = 'approved'
|
||||||
AND (expires_at IS NULL OR expires_at > ?)
|
AND (expires_at IS NULL OR expires_at > ?)
|
||||||
""", (policy_id, user_id, datetime.now().isoformat()))
|
""", (policy_id, user_id, datetime.now().isoformat()))
|
||||||
|
|
||||||
request = cursor.fetchone()
|
request = cursor.fetchone()
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
if not request:
|
if not request:
|
||||||
return False, "Access requires approval"
|
return False, "Access requires approval"
|
||||||
|
|
||||||
return True, None
|
return True, None
|
||||||
|
|
||||||
def _match_ip_pattern(self, ip: str, pattern: str) -> bool:
|
def _match_ip_pattern(self, ip: str, pattern: str) -> bool:
|
||||||
"""匹配IP模式(支持CIDR)"""
|
"""匹配IP模式(支持CIDR)"""
|
||||||
import ipaddress
|
import ipaddress
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if "/" in pattern:
|
if "/" in pattern:
|
||||||
# CIDR 表示法
|
# CIDR 表示法
|
||||||
@@ -1107,7 +1102,7 @@ class SecurityManager:
|
|||||||
return ip == pattern
|
return ip == pattern
|
||||||
except ValueError:
|
except ValueError:
|
||||||
return ip == pattern
|
return ip == pattern
|
||||||
|
|
||||||
def create_access_request(
|
def create_access_request(
|
||||||
self,
|
self,
|
||||||
policy_id: str,
|
policy_id: str,
|
||||||
@@ -1123,12 +1118,12 @@ class SecurityManager:
|
|||||||
request_reason=request_reason,
|
request_reason=request_reason,
|
||||||
expires_at=(datetime.now() + timedelta(hours=expires_hours)).isoformat()
|
expires_at=(datetime.now() + timedelta(hours=expires_hours)).isoformat()
|
||||||
)
|
)
|
||||||
|
|
||||||
conn = sqlite3.connect(self.db_path)
|
conn = sqlite3.connect(self.db_path)
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
||||||
cursor.execute("""
|
cursor.execute("""
|
||||||
INSERT INTO access_requests
|
INSERT INTO access_requests
|
||||||
(id, policy_id, user_id, request_reason, status, expires_at, created_at)
|
(id, policy_id, user_id, request_reason, status, expires_at, created_at)
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||||
""", (
|
""", (
|
||||||
@@ -1136,12 +1131,12 @@ class SecurityManager:
|
|||||||
request.request_reason, request.status, request.expires_at,
|
request.request_reason, request.status, request.expires_at,
|
||||||
request.created_at
|
request.created_at
|
||||||
))
|
))
|
||||||
|
|
||||||
conn.commit()
|
conn.commit()
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
return request
|
return request
|
||||||
|
|
||||||
def approve_access_request(
|
def approve_access_request(
|
||||||
self,
|
self,
|
||||||
request_id: str,
|
request_id: str,
|
||||||
@@ -1151,26 +1146,26 @@ class SecurityManager:
|
|||||||
"""批准访问请求"""
|
"""批准访问请求"""
|
||||||
conn = sqlite3.connect(self.db_path)
|
conn = sqlite3.connect(self.db_path)
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
||||||
expires_at = (datetime.now() + timedelta(hours=expires_hours)).isoformat()
|
expires_at = (datetime.now() + timedelta(hours=expires_hours)).isoformat()
|
||||||
approved_at = datetime.now().isoformat()
|
approved_at = datetime.now().isoformat()
|
||||||
|
|
||||||
cursor.execute("""
|
cursor.execute("""
|
||||||
UPDATE access_requests
|
UPDATE access_requests
|
||||||
SET status = 'approved', approved_by = ?, approved_at = ?, expires_at = ?
|
SET status = 'approved', approved_by = ?, approved_at = ?, expires_at = ?
|
||||||
WHERE id = ?
|
WHERE id = ?
|
||||||
""", (approved_by, approved_at, expires_at, request_id))
|
""", (approved_by, approved_at, expires_at, request_id))
|
||||||
|
|
||||||
conn.commit()
|
conn.commit()
|
||||||
|
|
||||||
# 获取更新后的请求
|
# 获取更新后的请求
|
||||||
cursor.execute("SELECT * FROM access_requests WHERE id = ?", (request_id,))
|
cursor.execute("SELECT * FROM access_requests WHERE id = ?", (request_id,))
|
||||||
row = cursor.fetchone()
|
row = cursor.fetchone()
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
if not row:
|
if not row:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return AccessRequest(
|
return AccessRequest(
|
||||||
id=row[0],
|
id=row[0],
|
||||||
policy_id=row[1],
|
policy_id=row[1],
|
||||||
@@ -1182,7 +1177,7 @@ class SecurityManager:
|
|||||||
expires_at=row[7],
|
expires_at=row[7],
|
||||||
created_at=row[8]
|
created_at=row[8]
|
||||||
)
|
)
|
||||||
|
|
||||||
def reject_access_request(
|
def reject_access_request(
|
||||||
self,
|
self,
|
||||||
request_id: str,
|
request_id: str,
|
||||||
@@ -1191,22 +1186,22 @@ class SecurityManager:
|
|||||||
"""拒绝访问请求"""
|
"""拒绝访问请求"""
|
||||||
conn = sqlite3.connect(self.db_path)
|
conn = sqlite3.connect(self.db_path)
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
||||||
cursor.execute("""
|
cursor.execute("""
|
||||||
UPDATE access_requests
|
UPDATE access_requests
|
||||||
SET status = 'rejected', approved_by = ?
|
SET status = 'rejected', approved_by = ?
|
||||||
WHERE id = ?
|
WHERE id = ?
|
||||||
""", (rejected_by, request_id))
|
""", (rejected_by, request_id))
|
||||||
|
|
||||||
conn.commit()
|
conn.commit()
|
||||||
|
|
||||||
cursor.execute("SELECT * FROM access_requests WHERE id = ?", (request_id,))
|
cursor.execute("SELECT * FROM access_requests WHERE id = ?", (request_id,))
|
||||||
row = cursor.fetchone()
|
row = cursor.fetchone()
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
if not row:
|
if not row:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return AccessRequest(
|
return AccessRequest(
|
||||||
id=row[0],
|
id=row[0],
|
||||||
policy_id=row[1],
|
policy_id=row[1],
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -19,8 +19,7 @@ print("\n1. 测试模块导入...")
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
from multimodal_processor import (
|
from multimodal_processor import (
|
||||||
get_multimodal_processor, MultimodalProcessor,
|
get_multimodal_processor
|
||||||
VideoProcessingResult, VideoFrame
|
|
||||||
)
|
)
|
||||||
print(" ✓ multimodal_processor 导入成功")
|
print(" ✓ multimodal_processor 导入成功")
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
@@ -28,8 +27,7 @@ except ImportError as e:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
from image_processor import (
|
from image_processor import (
|
||||||
get_image_processor, ImageProcessor,
|
get_image_processor
|
||||||
ImageProcessingResult, ImageEntity, ImageRelation
|
|
||||||
)
|
)
|
||||||
print(" ✓ image_processor 导入成功")
|
print(" ✓ image_processor 导入成功")
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
@@ -37,8 +35,7 @@ except ImportError as e:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
from multimodal_entity_linker import (
|
from multimodal_entity_linker import (
|
||||||
get_multimodal_entity_linker, MultimodalEntityLinker,
|
get_multimodal_entity_linker
|
||||||
MultimodalEntity, EntityLink, AlignmentResult, FusionResult
|
|
||||||
)
|
)
|
||||||
print(" ✓ multimodal_entity_linker 导入成功")
|
print(" ✓ multimodal_entity_linker 导入成功")
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
@@ -74,21 +71,21 @@ print("\n3. 测试实体关联功能...")
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
linker = get_multimodal_entity_linker()
|
linker = get_multimodal_entity_linker()
|
||||||
|
|
||||||
# 测试字符串相似度
|
# 测试字符串相似度
|
||||||
sim = linker.calculate_string_similarity("Project Alpha", "Project Alpha")
|
sim = linker.calculate_string_similarity("Project Alpha", "Project Alpha")
|
||||||
assert sim == 1.0, "完全匹配应该返回1.0"
|
assert sim == 1.0, "完全匹配应该返回1.0"
|
||||||
print(f" ✓ 字符串相似度计算正常 (完全匹配: {sim})")
|
print(f" ✓ 字符串相似度计算正常 (完全匹配: {sim})")
|
||||||
|
|
||||||
sim = linker.calculate_string_similarity("K8s", "Kubernetes")
|
sim = linker.calculate_string_similarity("K8s", "Kubernetes")
|
||||||
print(f" ✓ 字符串相似度计算正常 (不同字符串: {sim:.2f})")
|
print(f" ✓ 字符串相似度计算正常 (不同字符串: {sim:.2f})")
|
||||||
|
|
||||||
# 测试实体相似度
|
# 测试实体相似度
|
||||||
entity1 = {"name": "Project Alpha", "type": "PROJECT", "definition": "核心项目"}
|
entity1 = {"name": "Project Alpha", "type": "PROJECT", "definition": "核心项目"}
|
||||||
entity2 = {"name": "Project Alpha", "type": "PROJECT", "definition": "主要项目"}
|
entity2 = {"name": "Project Alpha", "type": "PROJECT", "definition": "主要项目"}
|
||||||
sim, match_type = linker.calculate_entity_similarity(entity1, entity2)
|
sim, match_type = linker.calculate_entity_similarity(entity1, entity2)
|
||||||
print(f" ✓ 实体相似度计算正常 (相似度: {sim:.2f}, 类型: {match_type})")
|
print(f" ✓ 实体相似度计算正常 (相似度: {sim:.2f}, 类型: {match_type})")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f" ✗ 实体关联功能测试失败: {e}")
|
print(f" ✗ 实体关联功能测试失败: {e}")
|
||||||
|
|
||||||
@@ -97,11 +94,11 @@ print("\n4. 测试图片处理器功能...")
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
processor = get_image_processor()
|
processor = get_image_processor()
|
||||||
|
|
||||||
# 测试图片类型检测(使用模拟数据)
|
# 测试图片类型检测(使用模拟数据)
|
||||||
print(f" ✓ 支持的图片类型: {list(processor.IMAGE_TYPES.keys())}")
|
print(f" ✓ 支持的图片类型: {list(processor.IMAGE_TYPES.keys())}")
|
||||||
print(f" ✓ 图片类型描述: {processor.IMAGE_TYPES}")
|
print(f" ✓ 图片类型描述: {processor.IMAGE_TYPES}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f" ✗ 图片处理器功能测试失败: {e}")
|
print(f" ✗ 图片处理器功能测试失败: {e}")
|
||||||
|
|
||||||
@@ -110,11 +107,11 @@ print("\n5. 测试视频处理器配置...")
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
processor = get_multimodal_processor()
|
processor = get_multimodal_processor()
|
||||||
|
|
||||||
print(f" ✓ 视频目录: {processor.video_dir}")
|
print(f" ✓ 视频目录: {processor.video_dir}")
|
||||||
print(f" ✓ 帧目录: {processor.frames_dir}")
|
print(f" ✓ 帧目录: {processor.frames_dir}")
|
||||||
print(f" ✓ 音频目录: {processor.audio_dir}")
|
print(f" ✓ 音频目录: {processor.audio_dir}")
|
||||||
|
|
||||||
# 检查目录是否存在
|
# 检查目录是否存在
|
||||||
for dir_name, dir_path in [
|
for dir_name, dir_path in [
|
||||||
("视频", processor.video_dir),
|
("视频", processor.video_dir),
|
||||||
@@ -125,7 +122,7 @@ try:
|
|||||||
print(f" ✓ {dir_name}目录存在: {dir_path}")
|
print(f" ✓ {dir_name}目录存在: {dir_path}")
|
||||||
else:
|
else:
|
||||||
print(f" ✗ {dir_name}目录不存在: {dir_path}")
|
print(f" ✗ {dir_name}目录不存在: {dir_path}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f" ✗ 视频处理器配置测试失败: {e}")
|
print(f" ✗ 视频处理器配置测试失败: {e}")
|
||||||
|
|
||||||
@@ -135,20 +132,20 @@ print("\n6. 测试数据库多模态方法...")
|
|||||||
try:
|
try:
|
||||||
from db_manager import get_db_manager
|
from db_manager import get_db_manager
|
||||||
db = get_db_manager()
|
db = get_db_manager()
|
||||||
|
|
||||||
# 检查多模态表是否存在
|
# 检查多模态表是否存在
|
||||||
conn = db.get_conn()
|
conn = db.get_conn()
|
||||||
tables = ['videos', 'video_frames', 'images', 'multimodal_mentions', 'multimodal_entity_links']
|
tables = ['videos', 'video_frames', 'images', 'multimodal_mentions', 'multimodal_entity_links']
|
||||||
|
|
||||||
for table in tables:
|
for table in tables:
|
||||||
try:
|
try:
|
||||||
conn.execute(f"SELECT 1 FROM {table} LIMIT 1")
|
conn.execute(f"SELECT 1 FROM {table} LIMIT 1")
|
||||||
print(f" ✓ 表 '{table}' 存在")
|
print(f" ✓ 表 '{table}' 存在")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f" ✗ 表 '{table}' 不存在或无法访问: {e}")
|
print(f" ✗ 表 '{table}' 不存在或无法访问: {e}")
|
||||||
|
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f" ✗ 数据库多模态方法测试失败: {e}")
|
print(f" ✗ 数据库多模态方法测试失败: {e}")
|
||||||
|
|
||||||
|
|||||||
@@ -4,34 +4,31 @@ InsightFlow Phase 7 Task 6 & 8 测试脚本
|
|||||||
测试高级搜索与发现、性能优化与扩展功能
|
测试高级搜索与发现、性能优化与扩展功能
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from performance_manager import (
|
||||||
|
get_performance_manager, CacheManager,
|
||||||
|
TaskQueue, PerformanceMonitor
|
||||||
|
)
|
||||||
|
from search_manager import (
|
||||||
|
get_search_manager, FullTextSearch,
|
||||||
|
SemanticSearch, EntityPathDiscovery,
|
||||||
|
KnowledgeGapDetection
|
||||||
|
)
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
import json
|
|
||||||
|
|
||||||
# 添加 backend 到路径
|
# 添加 backend 到路径
|
||||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
|
||||||
from search_manager import (
|
|
||||||
get_search_manager, SearchManager,
|
|
||||||
FullTextSearch, SemanticSearch,
|
|
||||||
EntityPathDiscovery, KnowledgeGapDetection
|
|
||||||
)
|
|
||||||
|
|
||||||
from performance_manager import (
|
|
||||||
get_performance_manager, PerformanceManager,
|
|
||||||
CacheManager, DatabaseSharding, TaskQueue, PerformanceMonitor
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_fulltext_search():
|
def test_fulltext_search():
|
||||||
"""测试全文搜索"""
|
"""测试全文搜索"""
|
||||||
print("\n" + "="*60)
|
print("\n" + "=" * 60)
|
||||||
print("测试全文搜索 (FullTextSearch)")
|
print("测试全文搜索 (FullTextSearch)")
|
||||||
print("="*60)
|
print("=" * 60)
|
||||||
|
|
||||||
search = FullTextSearch()
|
search = FullTextSearch()
|
||||||
|
|
||||||
# 测试索引创建
|
# 测试索引创建
|
||||||
print("\n1. 测试索引创建...")
|
print("\n1. 测试索引创建...")
|
||||||
success = search.index_content(
|
success = search.index_content(
|
||||||
@@ -41,7 +38,7 @@ def test_fulltext_search():
|
|||||||
text="这是一个测试实体,用于验证全文搜索功能。支持关键词高亮显示。"
|
text="这是一个测试实体,用于验证全文搜索功能。支持关键词高亮显示。"
|
||||||
)
|
)
|
||||||
print(f" 索引创建: {'✓ 成功' if success else '✗ 失败'}")
|
print(f" 索引创建: {'✓ 成功' if success else '✗ 失败'}")
|
||||||
|
|
||||||
# 测试搜索
|
# 测试搜索
|
||||||
print("\n2. 测试关键词搜索...")
|
print("\n2. 测试关键词搜索...")
|
||||||
results = search.search("测试", project_id="test_project")
|
results = search.search("测试", project_id="test_project")
|
||||||
@@ -49,15 +46,15 @@ def test_fulltext_search():
|
|||||||
if results:
|
if results:
|
||||||
print(f" 第一个结果: {results[0].content[:50]}...")
|
print(f" 第一个结果: {results[0].content[:50]}...")
|
||||||
print(f" 相关分数: {results[0].score}")
|
print(f" 相关分数: {results[0].score}")
|
||||||
|
|
||||||
# 测试布尔搜索
|
# 测试布尔搜索
|
||||||
print("\n3. 测试布尔搜索...")
|
print("\n3. 测试布尔搜索...")
|
||||||
results = search.search("测试 AND 全文", project_id="test_project")
|
results = search.search("测试 AND 全文", project_id="test_project")
|
||||||
print(f" AND 搜索结果: {len(results)}")
|
print(f" AND 搜索结果: {len(results)}")
|
||||||
|
|
||||||
results = search.search("测试 OR 关键词", project_id="test_project")
|
results = search.search("测试 OR 关键词", project_id="test_project")
|
||||||
print(f" OR 搜索结果: {len(results)}")
|
print(f" OR 搜索结果: {len(results)}")
|
||||||
|
|
||||||
# 测试高亮
|
# 测试高亮
|
||||||
print("\n4. 测试文本高亮...")
|
print("\n4. 测试文本高亮...")
|
||||||
highlighted = search.highlight_text(
|
highlighted = search.highlight_text(
|
||||||
@@ -65,33 +62,33 @@ def test_fulltext_search():
|
|||||||
"测试 全文"
|
"测试 全文"
|
||||||
)
|
)
|
||||||
print(f" 高亮结果: {highlighted}")
|
print(f" 高亮结果: {highlighted}")
|
||||||
|
|
||||||
print("\n✓ 全文搜索测试完成")
|
print("\n✓ 全文搜索测试完成")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def test_semantic_search():
|
def test_semantic_search():
|
||||||
"""测试语义搜索"""
|
"""测试语义搜索"""
|
||||||
print("\n" + "="*60)
|
print("\n" + "=" * 60)
|
||||||
print("测试语义搜索 (SemanticSearch)")
|
print("测试语义搜索 (SemanticSearch)")
|
||||||
print("="*60)
|
print("=" * 60)
|
||||||
|
|
||||||
semantic = SemanticSearch()
|
semantic = SemanticSearch()
|
||||||
|
|
||||||
# 检查可用性
|
# 检查可用性
|
||||||
print(f"\n1. 语义搜索可用性: {'✓ 可用' if semantic.is_available() else '✗ 不可用'}")
|
print(f"\n1. 语义搜索可用性: {'✓ 可用' if semantic.is_available() else '✗ 不可用'}")
|
||||||
|
|
||||||
if not semantic.is_available():
|
if not semantic.is_available():
|
||||||
print(" (需要安装 sentence-transformers 库)")
|
print(" (需要安装 sentence-transformers 库)")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# 测试 embedding 生成
|
# 测试 embedding 生成
|
||||||
print("\n2. 测试 embedding 生成...")
|
print("\n2. 测试 embedding 生成...")
|
||||||
embedding = semantic.generate_embedding("这是一个测试句子")
|
embedding = semantic.generate_embedding("这是一个测试句子")
|
||||||
if embedding:
|
if embedding:
|
||||||
print(f" Embedding 维度: {len(embedding)}")
|
print(f" Embedding 维度: {len(embedding)}")
|
||||||
print(f" 前5个值: {embedding[:5]}")
|
print(f" 前5个值: {embedding[:5]}")
|
||||||
|
|
||||||
# 测试索引
|
# 测试索引
|
||||||
print("\n3. 测试语义索引...")
|
print("\n3. 测试语义索引...")
|
||||||
success = semantic.index_embedding(
|
success = semantic.index_embedding(
|
||||||
@@ -101,68 +98,68 @@ def test_semantic_search():
|
|||||||
text="这是用于语义搜索测试的文本内容。"
|
text="这是用于语义搜索测试的文本内容。"
|
||||||
)
|
)
|
||||||
print(f" 索引创建: {'✓ 成功' if success else '✗ 失败'}")
|
print(f" 索引创建: {'✓ 成功' if success else '✗ 失败'}")
|
||||||
|
|
||||||
print("\n✓ 语义搜索测试完成")
|
print("\n✓ 语义搜索测试完成")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def test_entity_path_discovery():
|
def test_entity_path_discovery():
|
||||||
"""测试实体路径发现"""
|
"""测试实体路径发现"""
|
||||||
print("\n" + "="*60)
|
print("\n" + "=" * 60)
|
||||||
print("测试实体路径发现 (EntityPathDiscovery)")
|
print("测试实体路径发现 (EntityPathDiscovery)")
|
||||||
print("="*60)
|
print("=" * 60)
|
||||||
|
|
||||||
discovery = EntityPathDiscovery()
|
discovery = EntityPathDiscovery()
|
||||||
|
|
||||||
print("\n1. 测试路径发现初始化...")
|
print("\n1. 测试路径发现初始化...")
|
||||||
print(f" 数据库路径: {discovery.db_path}")
|
print(f" 数据库路径: {discovery.db_path}")
|
||||||
|
|
||||||
print("\n2. 测试多跳关系发现...")
|
print("\n2. 测试多跳关系发现...")
|
||||||
# 注意:这需要在数据库中有实际数据
|
# 注意:这需要在数据库中有实际数据
|
||||||
print(" (需要实际实体数据才能测试)")
|
print(" (需要实际实体数据才能测试)")
|
||||||
|
|
||||||
print("\n✓ 实体路径发现测试完成")
|
print("\n✓ 实体路径发现测试完成")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def test_knowledge_gap_detection():
|
def test_knowledge_gap_detection():
|
||||||
"""测试知识缺口识别"""
|
"""测试知识缺口识别"""
|
||||||
print("\n" + "="*60)
|
print("\n" + "=" * 60)
|
||||||
print("测试知识缺口识别 (KnowledgeGapDetection)")
|
print("测试知识缺口识别 (KnowledgeGapDetection)")
|
||||||
print("="*60)
|
print("=" * 60)
|
||||||
|
|
||||||
detection = KnowledgeGapDetection()
|
detection = KnowledgeGapDetection()
|
||||||
|
|
||||||
print("\n1. 测试缺口检测初始化...")
|
print("\n1. 测试缺口检测初始化...")
|
||||||
print(f" 数据库路径: {detection.db_path}")
|
print(f" 数据库路径: {detection.db_path}")
|
||||||
|
|
||||||
print("\n2. 测试完整性报告生成...")
|
print("\n2. 测试完整性报告生成...")
|
||||||
# 注意:这需要在数据库中有实际项目数据
|
# 注意:这需要在数据库中有实际项目数据
|
||||||
print(" (需要实际项目数据才能测试)")
|
print(" (需要实际项目数据才能测试)")
|
||||||
|
|
||||||
print("\n✓ 知识缺口识别测试完成")
|
print("\n✓ 知识缺口识别测试完成")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def test_cache_manager():
|
def test_cache_manager():
|
||||||
"""测试缓存管理器"""
|
"""测试缓存管理器"""
|
||||||
print("\n" + "="*60)
|
print("\n" + "=" * 60)
|
||||||
print("测试缓存管理器 (CacheManager)")
|
print("测试缓存管理器 (CacheManager)")
|
||||||
print("="*60)
|
print("=" * 60)
|
||||||
|
|
||||||
cache = CacheManager()
|
cache = CacheManager()
|
||||||
|
|
||||||
print(f"\n1. 缓存后端: {'Redis' if cache.use_redis else '内存 LRU'}")
|
print(f"\n1. 缓存后端: {'Redis' if cache.use_redis else '内存 LRU'}")
|
||||||
|
|
||||||
print("\n2. 测试缓存操作...")
|
print("\n2. 测试缓存操作...")
|
||||||
# 设置缓存
|
# 设置缓存
|
||||||
cache.set("test_key_1", {"name": "测试数据", "value": 123}, ttl=60)
|
cache.set("test_key_1", {"name": "测试数据", "value": 123}, ttl=60)
|
||||||
print(" ✓ 设置缓存 test_key_1")
|
print(" ✓ 设置缓存 test_key_1")
|
||||||
|
|
||||||
# 获取缓存
|
# 获取缓存
|
||||||
value = cache.get("test_key_1")
|
value = cache.get("test_key_1")
|
||||||
print(f" ✓ 获取缓存: {value}")
|
print(f" ✓ 获取缓存: {value}")
|
||||||
|
|
||||||
# 批量操作
|
# 批量操作
|
||||||
cache.set_many({
|
cache.set_many({
|
||||||
"batch_key_1": "value1",
|
"batch_key_1": "value1",
|
||||||
@@ -170,14 +167,14 @@ def test_cache_manager():
|
|||||||
"batch_key_3": "value3"
|
"batch_key_3": "value3"
|
||||||
}, ttl=60)
|
}, ttl=60)
|
||||||
print(" ✓ 批量设置缓存")
|
print(" ✓ 批量设置缓存")
|
||||||
|
|
||||||
values = cache.get_many(["batch_key_1", "batch_key_2", "batch_key_3"])
|
values = cache.get_many(["batch_key_1", "batch_key_2", "batch_key_3"])
|
||||||
print(f" ✓ 批量获取缓存: {len(values)} 个")
|
print(f" ✓ 批量获取缓存: {len(values)} 个")
|
||||||
|
|
||||||
# 删除缓存
|
# 删除缓存
|
||||||
cache.delete("test_key_1")
|
cache.delete("test_key_1")
|
||||||
print(" ✓ 删除缓存 test_key_1")
|
print(" ✓ 删除缓存 test_key_1")
|
||||||
|
|
||||||
# 获取统计
|
# 获取统计
|
||||||
stats = cache.get_stats()
|
stats = cache.get_stats()
|
||||||
print(f"\n3. 缓存统计:")
|
print(f"\n3. 缓存统计:")
|
||||||
@@ -185,67 +182,67 @@ def test_cache_manager():
|
|||||||
print(f" 命中数: {stats['hits']}")
|
print(f" 命中数: {stats['hits']}")
|
||||||
print(f" 未命中数: {stats['misses']}")
|
print(f" 未命中数: {stats['misses']}")
|
||||||
print(f" 命中率: {stats['hit_rate']:.2%}")
|
print(f" 命中率: {stats['hit_rate']:.2%}")
|
||||||
|
|
||||||
if not cache.use_redis:
|
if not cache.use_redis:
|
||||||
print(f" 内存使用: {stats.get('memory_size_bytes', 0)} bytes")
|
print(f" 内存使用: {stats.get('memory_size_bytes', 0)} bytes")
|
||||||
print(f" 缓存条目数: {stats.get('cache_entries', 0)}")
|
print(f" 缓存条目数: {stats.get('cache_entries', 0)}")
|
||||||
|
|
||||||
print("\n✓ 缓存管理器测试完成")
|
print("\n✓ 缓存管理器测试完成")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def test_task_queue():
|
def test_task_queue():
|
||||||
"""测试任务队列"""
|
"""测试任务队列"""
|
||||||
print("\n" + "="*60)
|
print("\n" + "=" * 60)
|
||||||
print("测试任务队列 (TaskQueue)")
|
print("测试任务队列 (TaskQueue)")
|
||||||
print("="*60)
|
print("=" * 60)
|
||||||
|
|
||||||
queue = TaskQueue()
|
queue = TaskQueue()
|
||||||
|
|
||||||
print(f"\n1. 任务队列可用性: {'✓ 可用' if queue.is_available() else '✗ 不可用'}")
|
print(f"\n1. 任务队列可用性: {'✓ 可用' if queue.is_available() else '✗ 不可用'}")
|
||||||
print(f" 后端: {'Celery' if queue.use_celery else '内存'}")
|
print(f" 后端: {'Celery' if queue.use_celery else '内存'}")
|
||||||
|
|
||||||
print("\n2. 测试任务提交...")
|
print("\n2. 测试任务提交...")
|
||||||
|
|
||||||
# 定义测试任务处理器
|
# 定义测试任务处理器
|
||||||
def test_task_handler(payload):
|
def test_task_handler(payload):
|
||||||
print(f" 执行任务: {payload}")
|
print(f" 执行任务: {payload}")
|
||||||
return {"status": "success", "processed": True}
|
return {"status": "success", "processed": True}
|
||||||
|
|
||||||
queue.register_handler("test_task", test_task_handler)
|
queue.register_handler("test_task", test_task_handler)
|
||||||
|
|
||||||
# 提交任务
|
# 提交任务
|
||||||
task_id = queue.submit(
|
task_id = queue.submit(
|
||||||
task_type="test_task",
|
task_type="test_task",
|
||||||
payload={"test": "data", "timestamp": time.time()}
|
payload={"test": "data", "timestamp": time.time()}
|
||||||
)
|
)
|
||||||
print(f" ✓ 提交任务: {task_id}")
|
print(f" ✓ 提交任务: {task_id}")
|
||||||
|
|
||||||
# 获取任务状态
|
# 获取任务状态
|
||||||
task_info = queue.get_status(task_id)
|
task_info = queue.get_status(task_id)
|
||||||
if task_info:
|
if task_info:
|
||||||
print(f" ✓ 任务状态: {task_info.status}")
|
print(f" ✓ 任务状态: {task_info.status}")
|
||||||
|
|
||||||
# 获取统计
|
# 获取统计
|
||||||
stats = queue.get_stats()
|
stats = queue.get_stats()
|
||||||
print(f"\n3. 任务队列统计:")
|
print(f"\n3. 任务队列统计:")
|
||||||
print(f" 后端: {stats['backend']}")
|
print(f" 后端: {stats['backend']}")
|
||||||
print(f" 按状态统计: {stats.get('by_status', {})}")
|
print(f" 按状态统计: {stats.get('by_status', {})}")
|
||||||
|
|
||||||
print("\n✓ 任务队列测试完成")
|
print("\n✓ 任务队列测试完成")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def test_performance_monitor():
|
def test_performance_monitor():
|
||||||
"""测试性能监控"""
|
"""测试性能监控"""
|
||||||
print("\n" + "="*60)
|
print("\n" + "=" * 60)
|
||||||
print("测试性能监控 (PerformanceMonitor)")
|
print("测试性能监控 (PerformanceMonitor)")
|
||||||
print("="*60)
|
print("=" * 60)
|
||||||
|
|
||||||
monitor = PerformanceMonitor()
|
monitor = PerformanceMonitor()
|
||||||
|
|
||||||
print("\n1. 测试指标记录...")
|
print("\n1. 测试指标记录...")
|
||||||
|
|
||||||
# 记录一些测试指标
|
# 记录一些测试指标
|
||||||
for i in range(5):
|
for i in range(5):
|
||||||
monitor.record_metric(
|
monitor.record_metric(
|
||||||
@@ -254,7 +251,7 @@ def test_performance_monitor():
|
|||||||
endpoint="/api/v1/test",
|
endpoint="/api/v1/test",
|
||||||
metadata={"test": True}
|
metadata={"test": True}
|
||||||
)
|
)
|
||||||
|
|
||||||
for i in range(3):
|
for i in range(3):
|
||||||
monitor.record_metric(
|
monitor.record_metric(
|
||||||
metric_type="db_query",
|
metric_type="db_query",
|
||||||
@@ -262,155 +259,155 @@ def test_performance_monitor():
|
|||||||
endpoint="SELECT test",
|
endpoint="SELECT test",
|
||||||
metadata={"test": True}
|
metadata={"test": True}
|
||||||
)
|
)
|
||||||
|
|
||||||
print(" ✓ 记录了 8 个测试指标")
|
print(" ✓ 记录了 8 个测试指标")
|
||||||
|
|
||||||
# 获取统计
|
# 获取统计
|
||||||
print("\n2. 获取性能统计...")
|
print("\n2. 获取性能统计...")
|
||||||
stats = monitor.get_stats(hours=1)
|
stats = monitor.get_stats(hours=1)
|
||||||
print(f" 总请求数: {stats['overall']['total_requests']}")
|
print(f" 总请求数: {stats['overall']['total_requests']}")
|
||||||
print(f" 平均响应时间: {stats['overall']['avg_duration_ms']} ms")
|
print(f" 平均响应时间: {stats['overall']['avg_duration_ms']} ms")
|
||||||
print(f" 最大响应时间: {stats['overall']['max_duration_ms']} ms")
|
print(f" 最大响应时间: {stats['overall']['max_duration_ms']} ms")
|
||||||
|
|
||||||
print("\n3. 按类型统计:")
|
print("\n3. 按类型统计:")
|
||||||
for type_stat in stats.get('by_type', []):
|
for type_stat in stats.get('by_type', []):
|
||||||
print(f" {type_stat['type']}: {type_stat['count']} 次, "
|
print(f" {type_stat['type']}: {type_stat['count']} 次, "
|
||||||
f"平均 {type_stat['avg_duration_ms']} ms")
|
f"平均 {type_stat['avg_duration_ms']} ms")
|
||||||
|
|
||||||
print("\n✓ 性能监控测试完成")
|
print("\n✓ 性能监控测试完成")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def test_search_manager():
|
def test_search_manager():
|
||||||
"""测试搜索管理器"""
|
"""测试搜索管理器"""
|
||||||
print("\n" + "="*60)
|
print("\n" + "=" * 60)
|
||||||
print("测试搜索管理器 (SearchManager)")
|
print("测试搜索管理器 (SearchManager)")
|
||||||
print("="*60)
|
print("=" * 60)
|
||||||
|
|
||||||
manager = get_search_manager()
|
manager = get_search_manager()
|
||||||
|
|
||||||
print("\n1. 搜索管理器初始化...")
|
print("\n1. 搜索管理器初始化...")
|
||||||
print(f" ✓ 搜索管理器已初始化")
|
print(f" ✓ 搜索管理器已初始化")
|
||||||
|
|
||||||
print("\n2. 获取搜索统计...")
|
print("\n2. 获取搜索统计...")
|
||||||
stats = manager.get_search_stats()
|
stats = manager.get_search_stats()
|
||||||
print(f" 全文索引数: {stats['fulltext_indexed']}")
|
print(f" 全文索引数: {stats['fulltext_indexed']}")
|
||||||
print(f" 语义索引数: {stats['semantic_indexed']}")
|
print(f" 语义索引数: {stats['semantic_indexed']}")
|
||||||
print(f" 语义搜索可用: {stats['semantic_search_available']}")
|
print(f" 语义搜索可用: {stats['semantic_search_available']}")
|
||||||
|
|
||||||
print("\n✓ 搜索管理器测试完成")
|
print("\n✓ 搜索管理器测试完成")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def test_performance_manager():
|
def test_performance_manager():
|
||||||
"""测试性能管理器"""
|
"""测试性能管理器"""
|
||||||
print("\n" + "="*60)
|
print("\n" + "=" * 60)
|
||||||
print("测试性能管理器 (PerformanceManager)")
|
print("测试性能管理器 (PerformanceManager)")
|
||||||
print("="*60)
|
print("=" * 60)
|
||||||
|
|
||||||
manager = get_performance_manager()
|
manager = get_performance_manager()
|
||||||
|
|
||||||
print("\n1. 性能管理器初始化...")
|
print("\n1. 性能管理器初始化...")
|
||||||
print(f" ✓ 性能管理器已初始化")
|
print(f" ✓ 性能管理器已初始化")
|
||||||
|
|
||||||
print("\n2. 获取系统健康状态...")
|
print("\n2. 获取系统健康状态...")
|
||||||
health = manager.get_health_status()
|
health = manager.get_health_status()
|
||||||
print(f" 缓存后端: {health['cache']['backend']}")
|
print(f" 缓存后端: {health['cache']['backend']}")
|
||||||
print(f" 任务队列后端: {health['task_queue']['backend']}")
|
print(f" 任务队列后端: {health['task_queue']['backend']}")
|
||||||
|
|
||||||
print("\n3. 获取完整统计...")
|
print("\n3. 获取完整统计...")
|
||||||
stats = manager.get_full_stats()
|
stats = manager.get_full_stats()
|
||||||
print(f" 缓存统计: {stats['cache']['total_requests']} 请求")
|
print(f" 缓存统计: {stats['cache']['total_requests']} 请求")
|
||||||
print(f" 任务队列统计: {stats['task_queue']}")
|
print(f" 任务队列统计: {stats['task_queue']}")
|
||||||
|
|
||||||
print("\n✓ 性能管理器测试完成")
|
print("\n✓ 性能管理器测试完成")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def run_all_tests():
|
def run_all_tests():
|
||||||
"""运行所有测试"""
|
"""运行所有测试"""
|
||||||
print("\n" + "="*60)
|
print("\n" + "=" * 60)
|
||||||
print("InsightFlow Phase 7 Task 6 & 8 测试")
|
print("InsightFlow Phase 7 Task 6 & 8 测试")
|
||||||
print("高级搜索与发现 + 性能优化与扩展")
|
print("高级搜索与发现 + 性能优化与扩展")
|
||||||
print("="*60)
|
print("=" * 60)
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
|
|
||||||
# 搜索模块测试
|
# 搜索模块测试
|
||||||
try:
|
try:
|
||||||
results.append(("全文搜索", test_fulltext_search()))
|
results.append(("全文搜索", test_fulltext_search()))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"\n✗ 全文搜索测试失败: {e}")
|
print(f"\n✗ 全文搜索测试失败: {e}")
|
||||||
results.append(("全文搜索", False))
|
results.append(("全文搜索", False))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
results.append(("语义搜索", test_semantic_search()))
|
results.append(("语义搜索", test_semantic_search()))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"\n✗ 语义搜索测试失败: {e}")
|
print(f"\n✗ 语义搜索测试失败: {e}")
|
||||||
results.append(("语义搜索", False))
|
results.append(("语义搜索", False))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
results.append(("实体路径发现", test_entity_path_discovery()))
|
results.append(("实体路径发现", test_entity_path_discovery()))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"\n✗ 实体路径发现测试失败: {e}")
|
print(f"\n✗ 实体路径发现测试失败: {e}")
|
||||||
results.append(("实体路径发现", False))
|
results.append(("实体路径发现", False))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
results.append(("知识缺口识别", test_knowledge_gap_detection()))
|
results.append(("知识缺口识别", test_knowledge_gap_detection()))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"\n✗ 知识缺口识别测试失败: {e}")
|
print(f"\n✗ 知识缺口识别测试失败: {e}")
|
||||||
results.append(("知识缺口识别", False))
|
results.append(("知识缺口识别", False))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
results.append(("搜索管理器", test_search_manager()))
|
results.append(("搜索管理器", test_search_manager()))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"\n✗ 搜索管理器测试失败: {e}")
|
print(f"\n✗ 搜索管理器测试失败: {e}")
|
||||||
results.append(("搜索管理器", False))
|
results.append(("搜索管理器", False))
|
||||||
|
|
||||||
# 性能模块测试
|
# 性能模块测试
|
||||||
try:
|
try:
|
||||||
results.append(("缓存管理器", test_cache_manager()))
|
results.append(("缓存管理器", test_cache_manager()))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"\n✗ 缓存管理器测试失败: {e}")
|
print(f"\n✗ 缓存管理器测试失败: {e}")
|
||||||
results.append(("缓存管理器", False))
|
results.append(("缓存管理器", False))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
results.append(("任务队列", test_task_queue()))
|
results.append(("任务队列", test_task_queue()))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"\n✗ 任务队列测试失败: {e}")
|
print(f"\n✗ 任务队列测试失败: {e}")
|
||||||
results.append(("任务队列", False))
|
results.append(("任务队列", False))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
results.append(("性能监控", test_performance_monitor()))
|
results.append(("性能监控", test_performance_monitor()))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"\n✗ 性能监控测试失败: {e}")
|
print(f"\n✗ 性能监控测试失败: {e}")
|
||||||
results.append(("性能监控", False))
|
results.append(("性能监控", False))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
results.append(("性能管理器", test_performance_manager()))
|
results.append(("性能管理器", test_performance_manager()))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"\n✗ 性能管理器测试失败: {e}")
|
print(f"\n✗ 性能管理器测试失败: {e}")
|
||||||
results.append(("性能管理器", False))
|
results.append(("性能管理器", False))
|
||||||
|
|
||||||
# 打印测试汇总
|
# 打印测试汇总
|
||||||
print("\n" + "="*60)
|
print("\n" + "=" * 60)
|
||||||
print("测试汇总")
|
print("测试汇总")
|
||||||
print("="*60)
|
print("=" * 60)
|
||||||
|
|
||||||
passed = sum(1 for _, result in results if result)
|
passed = sum(1 for _, result in results if result)
|
||||||
total = len(results)
|
total = len(results)
|
||||||
|
|
||||||
for name, result in results:
|
for name, result in results:
|
||||||
status = "✓ 通过" if result else "✗ 失败"
|
status = "✓ 通过" if result else "✗ 失败"
|
||||||
print(f" {status} - {name}")
|
print(f" {status} - {name}")
|
||||||
|
|
||||||
print(f"\n总计: {passed}/{total} 测试通过")
|
print(f"\n总计: {passed}/{total} 测试通过")
|
||||||
|
|
||||||
if passed == total:
|
if passed == total:
|
||||||
print("\n🎉 所有测试通过!")
|
print("\n🎉 所有测试通过!")
|
||||||
else:
|
else:
|
||||||
print(f"\n⚠️ 有 {total - passed} 个测试失败")
|
print(f"\n⚠️ 有 {total - passed} 个测试失败")
|
||||||
|
|
||||||
return passed == total
|
return passed == total
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -10,24 +10,22 @@ InsightFlow Phase 8 Task 1 - 多租户 SaaS 架构测试脚本
|
|||||||
5. 资源使用统计
|
5. 资源使用统计
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from tenant_manager import (
|
||||||
|
get_tenant_manager
|
||||||
|
)
|
||||||
import sys
|
import sys
|
||||||
import os
|
import os
|
||||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
|
||||||
from tenant_manager import (
|
|
||||||
get_tenant_manager, TenantManager, Tenant, TenantDomain,
|
|
||||||
TenantBranding, TenantMember, TenantRole, TenantStatus, TenantTier
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_tenant_management():
|
def test_tenant_management():
|
||||||
"""测试租户管理功能"""
|
"""测试租户管理功能"""
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
print("测试 1: 租户管理")
|
print("测试 1: 租户管理")
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
|
|
||||||
manager = get_tenant_manager()
|
manager = get_tenant_manager()
|
||||||
|
|
||||||
# 1. 创建租户
|
# 1. 创建租户
|
||||||
print("\n1.1 创建租户...")
|
print("\n1.1 创建租户...")
|
||||||
tenant = manager.create_tenant(
|
tenant = manager.create_tenant(
|
||||||
@@ -42,19 +40,19 @@ def test_tenant_management():
|
|||||||
print(f" - 层级: {tenant.tier}")
|
print(f" - 层级: {tenant.tier}")
|
||||||
print(f" - 状态: {tenant.status}")
|
print(f" - 状态: {tenant.status}")
|
||||||
print(f" - 资源限制: {tenant.resource_limits}")
|
print(f" - 资源限制: {tenant.resource_limits}")
|
||||||
|
|
||||||
# 2. 获取租户
|
# 2. 获取租户
|
||||||
print("\n1.2 获取租户信息...")
|
print("\n1.2 获取租户信息...")
|
||||||
fetched = manager.get_tenant(tenant.id)
|
fetched = manager.get_tenant(tenant.id)
|
||||||
assert fetched is not None, "获取租户失败"
|
assert fetched is not None, "获取租户失败"
|
||||||
print(f"✅ 获取租户成功: {fetched.name}")
|
print(f"✅ 获取租户成功: {fetched.name}")
|
||||||
|
|
||||||
# 3. 通过 slug 获取
|
# 3. 通过 slug 获取
|
||||||
print("\n1.3 通过 slug 获取租户...")
|
print("\n1.3 通过 slug 获取租户...")
|
||||||
by_slug = manager.get_tenant_by_slug(tenant.slug)
|
by_slug = manager.get_tenant_by_slug(tenant.slug)
|
||||||
assert by_slug is not None, "通过 slug 获取失败"
|
assert by_slug is not None, "通过 slug 获取失败"
|
||||||
print(f"✅ 通过 slug 获取成功: {by_slug.name}")
|
print(f"✅ 通过 slug 获取成功: {by_slug.name}")
|
||||||
|
|
||||||
# 4. 更新租户
|
# 4. 更新租户
|
||||||
print("\n1.4 更新租户信息...")
|
print("\n1.4 更新租户信息...")
|
||||||
updated = manager.update_tenant(
|
updated = manager.update_tenant(
|
||||||
@@ -64,12 +62,12 @@ def test_tenant_management():
|
|||||||
)
|
)
|
||||||
assert updated is not None, "更新租户失败"
|
assert updated is not None, "更新租户失败"
|
||||||
print(f"✅ 租户更新成功: {updated.name}, 层级: {updated.tier}")
|
print(f"✅ 租户更新成功: {updated.name}, 层级: {updated.tier}")
|
||||||
|
|
||||||
# 5. 列出租户
|
# 5. 列出租户
|
||||||
print("\n1.5 列出租户...")
|
print("\n1.5 列出租户...")
|
||||||
tenants = manager.list_tenants(limit=10)
|
tenants = manager.list_tenants(limit=10)
|
||||||
print(f"✅ 找到 {len(tenants)} 个租户")
|
print(f"✅ 找到 {len(tenants)} 个租户")
|
||||||
|
|
||||||
return tenant.id
|
return tenant.id
|
||||||
|
|
||||||
|
|
||||||
@@ -78,9 +76,9 @@ def test_domain_management(tenant_id: str):
|
|||||||
print("\n" + "=" * 60)
|
print("\n" + "=" * 60)
|
||||||
print("测试 2: 域名管理")
|
print("测试 2: 域名管理")
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
|
|
||||||
manager = get_tenant_manager()
|
manager = get_tenant_manager()
|
||||||
|
|
||||||
# 1. 添加域名
|
# 1. 添加域名
|
||||||
print("\n2.1 添加自定义域名...")
|
print("\n2.1 添加自定义域名...")
|
||||||
domain = manager.add_domain(
|
domain = manager.add_domain(
|
||||||
@@ -92,19 +90,19 @@ def test_domain_management(tenant_id: str):
|
|||||||
print(f" - ID: {domain.id}")
|
print(f" - ID: {domain.id}")
|
||||||
print(f" - 状态: {domain.status}")
|
print(f" - 状态: {domain.status}")
|
||||||
print(f" - 验证令牌: {domain.verification_token}")
|
print(f" - 验证令牌: {domain.verification_token}")
|
||||||
|
|
||||||
# 2. 获取验证指导
|
# 2. 获取验证指导
|
||||||
print("\n2.2 获取域名验证指导...")
|
print("\n2.2 获取域名验证指导...")
|
||||||
instructions = manager.get_domain_verification_instructions(domain.id)
|
instructions = manager.get_domain_verification_instructions(domain.id)
|
||||||
print(f"✅ 验证指导:")
|
print(f"✅ 验证指导:")
|
||||||
print(f" - DNS 记录: {instructions['dns_record']}")
|
print(f" - DNS 记录: {instructions['dns_record']}")
|
||||||
print(f" - 文件验证: {instructions['file_verification']}")
|
print(f" - 文件验证: {instructions['file_verification']}")
|
||||||
|
|
||||||
# 3. 验证域名
|
# 3. 验证域名
|
||||||
print("\n2.3 验证域名...")
|
print("\n2.3 验证域名...")
|
||||||
verified = manager.verify_domain(tenant_id, domain.id)
|
verified = manager.verify_domain(tenant_id, domain.id)
|
||||||
print(f"✅ 域名验证结果: {verified}")
|
print(f"✅ 域名验证结果: {verified}")
|
||||||
|
|
||||||
# 4. 通过域名获取租户
|
# 4. 通过域名获取租户
|
||||||
print("\n2.4 通过域名获取租户...")
|
print("\n2.4 通过域名获取租户...")
|
||||||
by_domain = manager.get_tenant_by_domain("test.example.com")
|
by_domain = manager.get_tenant_by_domain("test.example.com")
|
||||||
@@ -112,14 +110,14 @@ def test_domain_management(tenant_id: str):
|
|||||||
print(f"✅ 通过域名获取租户成功: {by_domain.name}")
|
print(f"✅ 通过域名获取租户成功: {by_domain.name}")
|
||||||
else:
|
else:
|
||||||
print("⚠️ 通过域名获取租户失败(验证可能未通过)")
|
print("⚠️ 通过域名获取租户失败(验证可能未通过)")
|
||||||
|
|
||||||
# 5. 列出域名
|
# 5. 列出域名
|
||||||
print("\n2.5 列出所有域名...")
|
print("\n2.5 列出所有域名...")
|
||||||
domains = manager.list_domains(tenant_id)
|
domains = manager.list_domains(tenant_id)
|
||||||
print(f"✅ 找到 {len(domains)} 个域名")
|
print(f"✅ 找到 {len(domains)} 个域名")
|
||||||
for d in domains:
|
for d in domains:
|
||||||
print(f" - {d.domain} ({d.status})")
|
print(f" - {d.domain} ({d.status})")
|
||||||
|
|
||||||
return domain.id
|
return domain.id
|
||||||
|
|
||||||
|
|
||||||
@@ -128,9 +126,9 @@ def test_branding_management(tenant_id: str):
|
|||||||
print("\n" + "=" * 60)
|
print("\n" + "=" * 60)
|
||||||
print("测试 3: 品牌白标")
|
print("测试 3: 品牌白标")
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
|
|
||||||
manager = get_tenant_manager()
|
manager = get_tenant_manager()
|
||||||
|
|
||||||
# 1. 更新品牌配置
|
# 1. 更新品牌配置
|
||||||
print("\n3.1 更新品牌配置...")
|
print("\n3.1 更新品牌配置...")
|
||||||
branding = manager.update_branding(
|
branding = manager.update_branding(
|
||||||
@@ -147,19 +145,19 @@ def test_branding_management(tenant_id: str):
|
|||||||
print(f" - Logo: {branding.logo_url}")
|
print(f" - Logo: {branding.logo_url}")
|
||||||
print(f" - 主色: {branding.primary_color}")
|
print(f" - 主色: {branding.primary_color}")
|
||||||
print(f" - 次色: {branding.secondary_color}")
|
print(f" - 次色: {branding.secondary_color}")
|
||||||
|
|
||||||
# 2. 获取品牌配置
|
# 2. 获取品牌配置
|
||||||
print("\n3.2 获取品牌配置...")
|
print("\n3.2 获取品牌配置...")
|
||||||
fetched = manager.get_branding(tenant_id)
|
fetched = manager.get_branding(tenant_id)
|
||||||
assert fetched is not None, "获取品牌配置失败"
|
assert fetched is not None, "获取品牌配置失败"
|
||||||
print(f"✅ 获取品牌配置成功")
|
print(f"✅ 获取品牌配置成功")
|
||||||
|
|
||||||
# 3. 生成品牌 CSS
|
# 3. 生成品牌 CSS
|
||||||
print("\n3.3 生成品牌 CSS...")
|
print("\n3.3 生成品牌 CSS...")
|
||||||
css = manager.get_branding_css(tenant_id)
|
css = manager.get_branding_css(tenant_id)
|
||||||
print(f"✅ 生成 CSS 成功 ({len(css)} 字符)")
|
print(f"✅ 生成 CSS 成功 ({len(css)} 字符)")
|
||||||
print(f" CSS 预览:\n{css[:200]}...")
|
print(f" CSS 预览:\n{css[:200]}...")
|
||||||
|
|
||||||
return branding.id
|
return branding.id
|
||||||
|
|
||||||
|
|
||||||
@@ -168,9 +166,9 @@ def test_member_management(tenant_id: str):
|
|||||||
print("\n" + "=" * 60)
|
print("\n" + "=" * 60)
|
||||||
print("测试 4: 成员管理")
|
print("测试 4: 成员管理")
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
|
|
||||||
manager = get_tenant_manager()
|
manager = get_tenant_manager()
|
||||||
|
|
||||||
# 1. 邀请成员
|
# 1. 邀请成员
|
||||||
print("\n4.1 邀请成员...")
|
print("\n4.1 邀请成员...")
|
||||||
member1 = manager.invite_member(
|
member1 = manager.invite_member(
|
||||||
@@ -183,7 +181,7 @@ def test_member_management(tenant_id: str):
|
|||||||
print(f" - ID: {member1.id}")
|
print(f" - ID: {member1.id}")
|
||||||
print(f" - 角色: {member1.role}")
|
print(f" - 角色: {member1.role}")
|
||||||
print(f" - 权限: {member1.permissions}")
|
print(f" - 权限: {member1.permissions}")
|
||||||
|
|
||||||
member2 = manager.invite_member(
|
member2 = manager.invite_member(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
email="member@test.com",
|
email="member@test.com",
|
||||||
@@ -191,36 +189,36 @@ def test_member_management(tenant_id: str):
|
|||||||
invited_by="user_001"
|
invited_by="user_001"
|
||||||
)
|
)
|
||||||
print(f"✅ 成员邀请成功: {member2.email}")
|
print(f"✅ 成员邀请成功: {member2.email}")
|
||||||
|
|
||||||
# 2. 接受邀请
|
# 2. 接受邀请
|
||||||
print("\n4.2 接受邀请...")
|
print("\n4.2 接受邀请...")
|
||||||
accepted = manager.accept_invitation(member1.id, "user_002")
|
accepted = manager.accept_invitation(member1.id, "user_002")
|
||||||
print(f"✅ 邀请接受结果: {accepted}")
|
print(f"✅ 邀请接受结果: {accepted}")
|
||||||
|
|
||||||
# 3. 列出成员
|
# 3. 列出成员
|
||||||
print("\n4.3 列出所有成员...")
|
print("\n4.3 列出所有成员...")
|
||||||
members = manager.list_members(tenant_id)
|
members = manager.list_members(tenant_id)
|
||||||
print(f"✅ 找到 {len(members)} 个成员")
|
print(f"✅ 找到 {len(members)} 个成员")
|
||||||
for m in members:
|
for m in members:
|
||||||
print(f" - {m.email} ({m.role}) - {m.status}")
|
print(f" - {m.email} ({m.role}) - {m.status}")
|
||||||
|
|
||||||
# 4. 检查权限
|
# 4. 检查权限
|
||||||
print("\n4.4 检查权限...")
|
print("\n4.4 检查权限...")
|
||||||
can_manage = manager.check_permission(tenant_id, "user_002", "project", "create")
|
can_manage = manager.check_permission(tenant_id, "user_002", "project", "create")
|
||||||
print(f"✅ user_002 可以创建项目: {can_manage}")
|
print(f"✅ user_002 可以创建项目: {can_manage}")
|
||||||
|
|
||||||
# 5. 更新成员角色
|
# 5. 更新成员角色
|
||||||
print("\n4.5 更新成员角色...")
|
print("\n4.5 更新成员角色...")
|
||||||
updated = manager.update_member_role(tenant_id, member2.id, "viewer")
|
updated = manager.update_member_role(tenant_id, member2.id, "viewer")
|
||||||
print(f"✅ 角色更新结果: {updated}")
|
print(f"✅ 角色更新结果: {updated}")
|
||||||
|
|
||||||
# 6. 获取用户所属租户
|
# 6. 获取用户所属租户
|
||||||
print("\n4.6 获取用户所属租户...")
|
print("\n4.6 获取用户所属租户...")
|
||||||
user_tenants = manager.get_user_tenants("user_002")
|
user_tenants = manager.get_user_tenants("user_002")
|
||||||
print(f"✅ user_002 属于 {len(user_tenants)} 个租户")
|
print(f"✅ user_002 属于 {len(user_tenants)} 个租户")
|
||||||
for t in user_tenants:
|
for t in user_tenants:
|
||||||
print(f" - {t['name']} ({t['member_role']})")
|
print(f" - {t['name']} ({t['member_role']})")
|
||||||
|
|
||||||
return member1.id, member2.id
|
return member1.id, member2.id
|
||||||
|
|
||||||
|
|
||||||
@@ -229,9 +227,9 @@ def test_usage_tracking(tenant_id: str):
|
|||||||
print("\n" + "=" * 60)
|
print("\n" + "=" * 60)
|
||||||
print("测试 5: 资源使用统计")
|
print("测试 5: 资源使用统计")
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
|
|
||||||
manager = get_tenant_manager()
|
manager = get_tenant_manager()
|
||||||
|
|
||||||
# 1. 记录使用
|
# 1. 记录使用
|
||||||
print("\n5.1 记录资源使用...")
|
print("\n5.1 记录资源使用...")
|
||||||
manager.record_usage(
|
manager.record_usage(
|
||||||
@@ -244,7 +242,7 @@ def test_usage_tracking(tenant_id: str):
|
|||||||
members_count=3
|
members_count=3
|
||||||
)
|
)
|
||||||
print("✅ 资源使用记录成功")
|
print("✅ 资源使用记录成功")
|
||||||
|
|
||||||
# 2. 获取使用统计
|
# 2. 获取使用统计
|
||||||
print("\n5.2 获取使用统计...")
|
print("\n5.2 获取使用统计...")
|
||||||
stats = manager.get_usage_stats(tenant_id)
|
stats = manager.get_usage_stats(tenant_id)
|
||||||
@@ -256,13 +254,13 @@ def test_usage_tracking(tenant_id: str):
|
|||||||
print(f" - 实体数: {stats['entities_count']}")
|
print(f" - 实体数: {stats['entities_count']}")
|
||||||
print(f" - 成员数: {stats['members_count']}")
|
print(f" - 成员数: {stats['members_count']}")
|
||||||
print(f" - 使用百分比: {stats['usage_percentages']}")
|
print(f" - 使用百分比: {stats['usage_percentages']}")
|
||||||
|
|
||||||
# 3. 检查资源限制
|
# 3. 检查资源限制
|
||||||
print("\n5.3 检查资源限制...")
|
print("\n5.3 检查资源限制...")
|
||||||
for resource in ["storage", "transcription", "api_calls", "projects", "entities", "members"]:
|
for resource in ["storage", "transcription", "api_calls", "projects", "entities", "members"]:
|
||||||
allowed, current, limit = manager.check_resource_limit(tenant_id, resource)
|
allowed, current, limit = manager.check_resource_limit(tenant_id, resource)
|
||||||
print(f" - {resource}: {current}/{limit} ({'✅' if allowed else '❌'})")
|
print(f" - {resource}: {current}/{limit} ({'✅' if allowed else '❌'})")
|
||||||
|
|
||||||
return stats
|
return stats
|
||||||
|
|
||||||
|
|
||||||
@@ -271,20 +269,20 @@ def cleanup(tenant_id: str, domain_id: str, member_ids: list):
|
|||||||
print("\n" + "=" * 60)
|
print("\n" + "=" * 60)
|
||||||
print("清理测试数据")
|
print("清理测试数据")
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
|
|
||||||
manager = get_tenant_manager()
|
manager = get_tenant_manager()
|
||||||
|
|
||||||
# 移除成员
|
# 移除成员
|
||||||
for member_id in member_ids:
|
for member_id in member_ids:
|
||||||
if member_id:
|
if member_id:
|
||||||
manager.remove_member(tenant_id, member_id)
|
manager.remove_member(tenant_id, member_id)
|
||||||
print(f"✅ 成员已移除: {member_id}")
|
print(f"✅ 成员已移除: {member_id}")
|
||||||
|
|
||||||
# 移除域名
|
# 移除域名
|
||||||
if domain_id:
|
if domain_id:
|
||||||
manager.remove_domain(tenant_id, domain_id)
|
manager.remove_domain(tenant_id, domain_id)
|
||||||
print(f"✅ 域名已移除: {domain_id}")
|
print(f"✅ 域名已移除: {domain_id}")
|
||||||
|
|
||||||
# 删除租户
|
# 删除租户
|
||||||
manager.delete_tenant(tenant_id)
|
manager.delete_tenant(tenant_id)
|
||||||
print(f"✅ 租户已删除: {tenant_id}")
|
print(f"✅ 租户已删除: {tenant_id}")
|
||||||
@@ -295,11 +293,11 @@ def main():
|
|||||||
print("\n" + "=" * 60)
|
print("\n" + "=" * 60)
|
||||||
print("InsightFlow Phase 8 Task 1 - 多租户 SaaS 架构测试")
|
print("InsightFlow Phase 8 Task 1 - 多租户 SaaS 架构测试")
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
|
|
||||||
tenant_id = None
|
tenant_id = None
|
||||||
domain_id = None
|
domain_id = None
|
||||||
member_ids = []
|
member_ids = []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 运行所有测试
|
# 运行所有测试
|
||||||
tenant_id = test_tenant_management()
|
tenant_id = test_tenant_management()
|
||||||
@@ -308,16 +306,16 @@ def main():
|
|||||||
m1, m2 = test_member_management(tenant_id)
|
m1, m2 = test_member_management(tenant_id)
|
||||||
member_ids = [m1, m2]
|
member_ids = [m1, m2]
|
||||||
test_usage_tracking(tenant_id)
|
test_usage_tracking(tenant_id)
|
||||||
|
|
||||||
print("\n" + "=" * 60)
|
print("\n" + "=" * 60)
|
||||||
print("✅ 所有测试通过!")
|
print("✅ 所有测试通过!")
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"\n❌ 测试失败: {e}")
|
print(f"\n❌ 测试失败: {e}")
|
||||||
import traceback
|
import traceback
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
# 清理
|
# 清理
|
||||||
if tenant_id:
|
if tenant_id:
|
||||||
@@ -328,4 +326,4 @@ def main():
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|||||||
@@ -3,56 +3,55 @@
|
|||||||
InsightFlow Phase 8 Task 2 测试脚本 - 订阅与计费系统
|
InsightFlow Phase 8 Task 2 测试脚本 - 订阅与计费系统
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from subscription_manager import (
|
||||||
|
SubscriptionManager, PaymentProvider
|
||||||
|
)
|
||||||
import sys
|
import sys
|
||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
|
|
||||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
|
||||||
from subscription_manager import (
|
|
||||||
get_subscription_manager, SubscriptionManager,
|
|
||||||
SubscriptionStatus, PaymentProvider, PaymentStatus, InvoiceStatus, RefundStatus
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_subscription_manager():
|
def test_subscription_manager():
|
||||||
"""测试订阅管理器"""
|
"""测试订阅管理器"""
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
print("InsightFlow Phase 8 Task 2 - 订阅与计费系统测试")
|
print("InsightFlow Phase 8 Task 2 - 订阅与计费系统测试")
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
|
|
||||||
# 使用临时文件数据库进行测试
|
# 使用临时文件数据库进行测试
|
||||||
db_path = tempfile.mktemp(suffix='.db')
|
db_path = tempfile.mktemp(suffix='.db')
|
||||||
|
|
||||||
try:
|
try:
|
||||||
manager = SubscriptionManager(db_path=db_path)
|
manager = SubscriptionManager(db_path=db_path)
|
||||||
|
|
||||||
print("\n1. 测试订阅计划管理")
|
print("\n1. 测试订阅计划管理")
|
||||||
print("-" * 40)
|
print("-" * 40)
|
||||||
|
|
||||||
# 获取默认计划
|
# 获取默认计划
|
||||||
plans = manager.list_plans()
|
plans = manager.list_plans()
|
||||||
print(f"✓ 默认计划数量: {len(plans)}")
|
print(f"✓ 默认计划数量: {len(plans)}")
|
||||||
for plan in plans:
|
for plan in plans:
|
||||||
print(f" - {plan.name} ({plan.tier}): ¥{plan.price_monthly}/月")
|
print(f" - {plan.name} ({plan.tier}): ¥{plan.price_monthly}/月")
|
||||||
|
|
||||||
# 通过 tier 获取计划
|
# 通过 tier 获取计划
|
||||||
free_plan = manager.get_plan_by_tier("free")
|
free_plan = manager.get_plan_by_tier("free")
|
||||||
pro_plan = manager.get_plan_by_tier("pro")
|
pro_plan = manager.get_plan_by_tier("pro")
|
||||||
enterprise_plan = manager.get_plan_by_tier("enterprise")
|
enterprise_plan = manager.get_plan_by_tier("enterprise")
|
||||||
|
|
||||||
assert free_plan is not None, "Free 计划应该存在"
|
assert free_plan is not None, "Free 计划应该存在"
|
||||||
assert pro_plan is not None, "Pro 计划应该存在"
|
assert pro_plan is not None, "Pro 计划应该存在"
|
||||||
assert enterprise_plan is not None, "Enterprise 计划应该存在"
|
assert enterprise_plan is not None, "Enterprise 计划应该存在"
|
||||||
|
|
||||||
print(f"✓ Free 计划: {free_plan.name}")
|
print(f"✓ Free 计划: {free_plan.name}")
|
||||||
print(f"✓ Pro 计划: {pro_plan.name}")
|
print(f"✓ Pro 计划: {pro_plan.name}")
|
||||||
print(f"✓ Enterprise 计划: {enterprise_plan.name}")
|
print(f"✓ Enterprise 计划: {enterprise_plan.name}")
|
||||||
|
|
||||||
print("\n2. 测试订阅管理")
|
print("\n2. 测试订阅管理")
|
||||||
print("-" * 40)
|
print("-" * 40)
|
||||||
|
|
||||||
tenant_id = "test-tenant-001"
|
tenant_id = "test-tenant-001"
|
||||||
|
|
||||||
# 创建订阅
|
# 创建订阅
|
||||||
subscription = manager.create_subscription(
|
subscription = manager.create_subscription(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
@@ -60,21 +59,21 @@ def test_subscription_manager():
|
|||||||
payment_provider=PaymentProvider.STRIPE.value,
|
payment_provider=PaymentProvider.STRIPE.value,
|
||||||
trial_days=14
|
trial_days=14
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"✓ 创建订阅: {subscription.id}")
|
print(f"✓ 创建订阅: {subscription.id}")
|
||||||
print(f" - 状态: {subscription.status}")
|
print(f" - 状态: {subscription.status}")
|
||||||
print(f" - 计划: {pro_plan.name}")
|
print(f" - 计划: {pro_plan.name}")
|
||||||
print(f" - 试用开始: {subscription.trial_start}")
|
print(f" - 试用开始: {subscription.trial_start}")
|
||||||
print(f" - 试用结束: {subscription.trial_end}")
|
print(f" - 试用结束: {subscription.trial_end}")
|
||||||
|
|
||||||
# 获取租户订阅
|
# 获取租户订阅
|
||||||
tenant_sub = manager.get_tenant_subscription(tenant_id)
|
tenant_sub = manager.get_tenant_subscription(tenant_id)
|
||||||
assert tenant_sub is not None, "应该能获取到租户订阅"
|
assert tenant_sub is not None, "应该能获取到租户订阅"
|
||||||
print(f"✓ 获取租户订阅: {tenant_sub.id}")
|
print(f"✓ 获取租户订阅: {tenant_sub.id}")
|
||||||
|
|
||||||
print("\n3. 测试用量记录")
|
print("\n3. 测试用量记录")
|
||||||
print("-" * 40)
|
print("-" * 40)
|
||||||
|
|
||||||
# 记录转录用量
|
# 记录转录用量
|
||||||
usage1 = manager.record_usage(
|
usage1 = manager.record_usage(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
@@ -84,7 +83,7 @@ def test_subscription_manager():
|
|||||||
description="会议转录"
|
description="会议转录"
|
||||||
)
|
)
|
||||||
print(f"✓ 记录转录用量: {usage1.quantity} {usage1.unit}, 费用: ¥{usage1.cost:.2f}")
|
print(f"✓ 记录转录用量: {usage1.quantity} {usage1.unit}, 费用: ¥{usage1.cost:.2f}")
|
||||||
|
|
||||||
# 记录存储用量
|
# 记录存储用量
|
||||||
usage2 = manager.record_usage(
|
usage2 = manager.record_usage(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
@@ -94,17 +93,17 @@ def test_subscription_manager():
|
|||||||
description="文件存储"
|
description="文件存储"
|
||||||
)
|
)
|
||||||
print(f"✓ 记录存储用量: {usage2.quantity} {usage2.unit}, 费用: ¥{usage2.cost:.2f}")
|
print(f"✓ 记录存储用量: {usage2.quantity} {usage2.unit}, 费用: ¥{usage2.cost:.2f}")
|
||||||
|
|
||||||
# 获取用量汇总
|
# 获取用量汇总
|
||||||
summary = manager.get_usage_summary(tenant_id)
|
summary = manager.get_usage_summary(tenant_id)
|
||||||
print(f"✓ 用量汇总:")
|
print(f"✓ 用量汇总:")
|
||||||
print(f" - 总费用: ¥{summary['total_cost']:.2f}")
|
print(f" - 总费用: ¥{summary['total_cost']:.2f}")
|
||||||
for resource, data in summary['breakdown'].items():
|
for resource, data in summary['breakdown'].items():
|
||||||
print(f" - {resource}: {data['quantity']} (¥{data['cost']:.2f})")
|
print(f" - {resource}: {data['quantity']} (¥{data['cost']:.2f})")
|
||||||
|
|
||||||
print("\n4. 测试支付管理")
|
print("\n4. 测试支付管理")
|
||||||
print("-" * 40)
|
print("-" * 40)
|
||||||
|
|
||||||
# 创建支付
|
# 创建支付
|
||||||
payment = manager.create_payment(
|
payment = manager.create_payment(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
@@ -117,31 +116,31 @@ def test_subscription_manager():
|
|||||||
print(f" - 金额: ¥{payment.amount}")
|
print(f" - 金额: ¥{payment.amount}")
|
||||||
print(f" - 提供商: {payment.provider}")
|
print(f" - 提供商: {payment.provider}")
|
||||||
print(f" - 状态: {payment.status}")
|
print(f" - 状态: {payment.status}")
|
||||||
|
|
||||||
# 确认支付
|
# 确认支付
|
||||||
confirmed = manager.confirm_payment(payment.id, "alipay_123456")
|
confirmed = manager.confirm_payment(payment.id, "alipay_123456")
|
||||||
print(f"✓ 确认支付完成: {confirmed.status}")
|
print(f"✓ 确认支付完成: {confirmed.status}")
|
||||||
|
|
||||||
# 列出支付记录
|
# 列出支付记录
|
||||||
payments = manager.list_payments(tenant_id)
|
payments = manager.list_payments(tenant_id)
|
||||||
print(f"✓ 支付记录数量: {len(payments)}")
|
print(f"✓ 支付记录数量: {len(payments)}")
|
||||||
|
|
||||||
print("\n5. 测试发票管理")
|
print("\n5. 测试发票管理")
|
||||||
print("-" * 40)
|
print("-" * 40)
|
||||||
|
|
||||||
# 列出发票
|
# 列出发票
|
||||||
invoices = manager.list_invoices(tenant_id)
|
invoices = manager.list_invoices(tenant_id)
|
||||||
print(f"✓ 发票数量: {len(invoices)}")
|
print(f"✓ 发票数量: {len(invoices)}")
|
||||||
|
|
||||||
if invoices:
|
if invoices:
|
||||||
invoice = invoices[0]
|
invoice = invoices[0]
|
||||||
print(f" - 发票号: {invoice.invoice_number}")
|
print(f" - 发票号: {invoice.invoice_number}")
|
||||||
print(f" - 金额: ¥{invoice.amount_due}")
|
print(f" - 金额: ¥{invoice.amount_due}")
|
||||||
print(f" - 状态: {invoice.status}")
|
print(f" - 状态: {invoice.status}")
|
||||||
|
|
||||||
print("\n6. 测试退款管理")
|
print("\n6. 测试退款管理")
|
||||||
print("-" * 40)
|
print("-" * 40)
|
||||||
|
|
||||||
# 申请退款
|
# 申请退款
|
||||||
refund = manager.request_refund(
|
refund = manager.request_refund(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
@@ -154,30 +153,30 @@ def test_subscription_manager():
|
|||||||
print(f" - 金额: ¥{refund.amount}")
|
print(f" - 金额: ¥{refund.amount}")
|
||||||
print(f" - 原因: {refund.reason}")
|
print(f" - 原因: {refund.reason}")
|
||||||
print(f" - 状态: {refund.status}")
|
print(f" - 状态: {refund.status}")
|
||||||
|
|
||||||
# 批准退款
|
# 批准退款
|
||||||
approved = manager.approve_refund(refund.id, "admin_001")
|
approved = manager.approve_refund(refund.id, "admin_001")
|
||||||
print(f"✓ 批准退款: {approved.status}")
|
print(f"✓ 批准退款: {approved.status}")
|
||||||
|
|
||||||
# 完成退款
|
# 完成退款
|
||||||
completed = manager.complete_refund(refund.id, "refund_123456")
|
completed = manager.complete_refund(refund.id, "refund_123456")
|
||||||
print(f"✓ 完成退款: {completed.status}")
|
print(f"✓ 完成退款: {completed.status}")
|
||||||
|
|
||||||
# 列出退款记录
|
# 列出退款记录
|
||||||
refunds = manager.list_refunds(tenant_id)
|
refunds = manager.list_refunds(tenant_id)
|
||||||
print(f"✓ 退款记录数量: {len(refunds)}")
|
print(f"✓ 退款记录数量: {len(refunds)}")
|
||||||
|
|
||||||
print("\n7. 测试账单历史")
|
print("\n7. 测试账单历史")
|
||||||
print("-" * 40)
|
print("-" * 40)
|
||||||
|
|
||||||
history = manager.get_billing_history(tenant_id)
|
history = manager.get_billing_history(tenant_id)
|
||||||
print(f"✓ 账单历史记录数量: {len(history)}")
|
print(f"✓ 账单历史记录数量: {len(history)}")
|
||||||
for h in history:
|
for h in history:
|
||||||
print(f" - [{h.type}] {h.description}: ¥{h.amount}")
|
print(f" - [{h.type}] {h.description}: ¥{h.amount}")
|
||||||
|
|
||||||
print("\n8. 测试支付提供商集成")
|
print("\n8. 测试支付提供商集成")
|
||||||
print("-" * 40)
|
print("-" * 40)
|
||||||
|
|
||||||
# Stripe Checkout
|
# Stripe Checkout
|
||||||
stripe_session = manager.create_stripe_checkout_session(
|
stripe_session = manager.create_stripe_checkout_session(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
@@ -186,38 +185,38 @@ def test_subscription_manager():
|
|||||||
cancel_url="https://example.com/cancel"
|
cancel_url="https://example.com/cancel"
|
||||||
)
|
)
|
||||||
print(f"✓ Stripe Checkout 会话: {stripe_session['session_id']}")
|
print(f"✓ Stripe Checkout 会话: {stripe_session['session_id']}")
|
||||||
|
|
||||||
# 支付宝订单
|
# 支付宝订单
|
||||||
alipay_order = manager.create_alipay_order(
|
alipay_order = manager.create_alipay_order(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
plan_id=pro_plan.id
|
plan_id=pro_plan.id
|
||||||
)
|
)
|
||||||
print(f"✓ 支付宝订单: {alipay_order['order_id']}")
|
print(f"✓ 支付宝订单: {alipay_order['order_id']}")
|
||||||
|
|
||||||
# 微信支付订单
|
# 微信支付订单
|
||||||
wechat_order = manager.create_wechat_order(
|
wechat_order = manager.create_wechat_order(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
plan_id=pro_plan.id
|
plan_id=pro_plan.id
|
||||||
)
|
)
|
||||||
print(f"✓ 微信支付订单: {wechat_order['order_id']}")
|
print(f"✓ 微信支付订单: {wechat_order['order_id']}")
|
||||||
|
|
||||||
# Webhook 处理
|
# Webhook 处理
|
||||||
webhook_result = manager.handle_webhook("stripe", {
|
webhook_result = manager.handle_webhook("stripe", {
|
||||||
"event_type": "checkout.session.completed",
|
"event_type": "checkout.session.completed",
|
||||||
"data": {"object": {"id": "cs_test"}}
|
"data": {"object": {"id": "cs_test"}}
|
||||||
})
|
})
|
||||||
print(f"✓ Webhook 处理: {webhook_result}")
|
print(f"✓ Webhook 处理: {webhook_result}")
|
||||||
|
|
||||||
print("\n9. 测试订阅变更")
|
print("\n9. 测试订阅变更")
|
||||||
print("-" * 40)
|
print("-" * 40)
|
||||||
|
|
||||||
# 更改计划
|
# 更改计划
|
||||||
changed = manager.change_plan(
|
changed = manager.change_plan(
|
||||||
subscription_id=subscription.id,
|
subscription_id=subscription.id,
|
||||||
new_plan_id=enterprise_plan.id
|
new_plan_id=enterprise_plan.id
|
||||||
)
|
)
|
||||||
print(f"✓ 更改计划: {changed.plan_id} (Enterprise)")
|
print(f"✓ 更改计划: {changed.plan_id} (Enterprise)")
|
||||||
|
|
||||||
# 取消订阅
|
# 取消订阅
|
||||||
cancelled = manager.cancel_subscription(
|
cancelled = manager.cancel_subscription(
|
||||||
subscription_id=subscription.id,
|
subscription_id=subscription.id,
|
||||||
@@ -225,17 +224,18 @@ def test_subscription_manager():
|
|||||||
)
|
)
|
||||||
print(f"✓ 取消订阅: {cancelled.status}")
|
print(f"✓ 取消订阅: {cancelled.status}")
|
||||||
print(f" - 周期结束时取消: {cancelled.cancel_at_period_end}")
|
print(f" - 周期结束时取消: {cancelled.cancel_at_period_end}")
|
||||||
|
|
||||||
print("\n" + "=" * 60)
|
print("\n" + "=" * 60)
|
||||||
print("所有测试通过! ✓")
|
print("所有测试通过! ✓")
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
# 清理临时数据库
|
# 清理临时数据库
|
||||||
if os.path.exists(db_path):
|
if os.path.exists(db_path):
|
||||||
os.remove(db_path)
|
os.remove(db_path)
|
||||||
print(f"\n清理临时数据库: {db_path}")
|
print(f"\n清理临时数据库: {db_path}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
try:
|
try:
|
||||||
test_subscription_manager()
|
test_subscription_manager()
|
||||||
|
|||||||
@@ -4,6 +4,9 @@ InsightFlow Phase 8 Task 4 测试脚本
|
|||||||
测试 AI 能力增强功能
|
测试 AI 能力增强功能
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from ai_manager import (
|
||||||
|
get_ai_manager, ModelType, PredictionType
|
||||||
|
)
|
||||||
import asyncio
|
import asyncio
|
||||||
import sys
|
import sys
|
||||||
import os
|
import os
|
||||||
@@ -11,19 +14,13 @@ import os
|
|||||||
# Add backend directory to path
|
# Add backend directory to path
|
||||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
|
||||||
from ai_manager import (
|
|
||||||
get_ai_manager, CustomModel, TrainingSample, MultimodalAnalysis,
|
|
||||||
KnowledgeGraphRAG, SmartSummary, PredictionModel, PredictionResult,
|
|
||||||
ModelType, ModelStatus, MultimodalProvider, PredictionType
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_custom_model():
|
def test_custom_model():
|
||||||
"""测试自定义模型功能"""
|
"""测试自定义模型功能"""
|
||||||
print("\n=== 测试自定义模型 ===")
|
print("\n=== 测试自定义模型 ===")
|
||||||
|
|
||||||
manager = get_ai_manager()
|
manager = get_ai_manager()
|
||||||
|
|
||||||
# 1. 创建自定义模型
|
# 1. 创建自定义模型
|
||||||
print("1. 创建自定义模型...")
|
print("1. 创建自定义模型...")
|
||||||
model = manager.create_custom_model(
|
model = manager.create_custom_model(
|
||||||
@@ -43,7 +40,7 @@ def test_custom_model():
|
|||||||
created_by="user_001"
|
created_by="user_001"
|
||||||
)
|
)
|
||||||
print(f" 创建成功: {model.id}, 状态: {model.status.value}")
|
print(f" 创建成功: {model.id}, 状态: {model.status.value}")
|
||||||
|
|
||||||
# 2. 添加训练样本
|
# 2. 添加训练样本
|
||||||
print("2. 添加训练样本...")
|
print("2. 添加训练样本...")
|
||||||
samples = [
|
samples = [
|
||||||
@@ -72,7 +69,7 @@ def test_custom_model():
|
|||||||
]
|
]
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|
||||||
for sample_data in samples:
|
for sample_data in samples:
|
||||||
sample = manager.add_training_sample(
|
sample = manager.add_training_sample(
|
||||||
model_id=model.id,
|
model_id=model.id,
|
||||||
@@ -81,28 +78,28 @@ def test_custom_model():
|
|||||||
metadata={"source": "manual"}
|
metadata={"source": "manual"}
|
||||||
)
|
)
|
||||||
print(f" 添加样本: {sample.id}")
|
print(f" 添加样本: {sample.id}")
|
||||||
|
|
||||||
# 3. 获取训练样本
|
# 3. 获取训练样本
|
||||||
print("3. 获取训练样本...")
|
print("3. 获取训练样本...")
|
||||||
all_samples = manager.get_training_samples(model.id)
|
all_samples = manager.get_training_samples(model.id)
|
||||||
print(f" 共有 {len(all_samples)} 个训练样本")
|
print(f" 共有 {len(all_samples)} 个训练样本")
|
||||||
|
|
||||||
# 4. 列出自定义模型
|
# 4. 列出自定义模型
|
||||||
print("4. 列出自定义模型...")
|
print("4. 列出自定义模型...")
|
||||||
models = manager.list_custom_models(tenant_id="tenant_001")
|
models = manager.list_custom_models(tenant_id="tenant_001")
|
||||||
print(f" 找到 {len(models)} 个模型")
|
print(f" 找到 {len(models)} 个模型")
|
||||||
for m in models:
|
for m in models:
|
||||||
print(f" - {m.name} ({m.model_type.value}): {m.status.value}")
|
print(f" - {m.name} ({m.model_type.value}): {m.status.value}")
|
||||||
|
|
||||||
return model.id
|
return model.id
|
||||||
|
|
||||||
|
|
||||||
async def test_train_and_predict(model_id: str):
|
async def test_train_and_predict(model_id: str):
|
||||||
"""测试训练和预测"""
|
"""测试训练和预测"""
|
||||||
print("\n=== 测试模型训练和预测 ===")
|
print("\n=== 测试模型训练和预测 ===")
|
||||||
|
|
||||||
manager = get_ai_manager()
|
manager = get_ai_manager()
|
||||||
|
|
||||||
# 1. 训练模型
|
# 1. 训练模型
|
||||||
print("1. 训练模型...")
|
print("1. 训练模型...")
|
||||||
try:
|
try:
|
||||||
@@ -112,7 +109,7 @@ async def test_train_and_predict(model_id: str):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f" 训练失败: {e}")
|
print(f" 训练失败: {e}")
|
||||||
return
|
return
|
||||||
|
|
||||||
# 2. 使用模型预测
|
# 2. 使用模型预测
|
||||||
print("2. 使用模型预测...")
|
print("2. 使用模型预测...")
|
||||||
test_text = "赵六患有糖尿病,正在使用胰岛素治疗。"
|
test_text = "赵六患有糖尿病,正在使用胰岛素治疗。"
|
||||||
@@ -127,9 +124,9 @@ async def test_train_and_predict(model_id: str):
|
|||||||
def test_prediction_models():
|
def test_prediction_models():
|
||||||
"""测试预测模型"""
|
"""测试预测模型"""
|
||||||
print("\n=== 测试预测模型 ===")
|
print("\n=== 测试预测模型 ===")
|
||||||
|
|
||||||
manager = get_ai_manager()
|
manager = get_ai_manager()
|
||||||
|
|
||||||
# 1. 创建趋势预测模型
|
# 1. 创建趋势预测模型
|
||||||
print("1. 创建趋势预测模型...")
|
print("1. 创建趋势预测模型...")
|
||||||
trend_model = manager.create_prediction_model(
|
trend_model = manager.create_prediction_model(
|
||||||
@@ -145,7 +142,7 @@ def test_prediction_models():
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
print(f" 创建成功: {trend_model.id}")
|
print(f" 创建成功: {trend_model.id}")
|
||||||
|
|
||||||
# 2. 创建异常检测模型
|
# 2. 创建异常检测模型
|
||||||
print("2. 创建异常检测模型...")
|
print("2. 创建异常检测模型...")
|
||||||
anomaly_model = manager.create_prediction_model(
|
anomaly_model = manager.create_prediction_model(
|
||||||
@@ -161,23 +158,23 @@ def test_prediction_models():
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
print(f" 创建成功: {anomaly_model.id}")
|
print(f" 创建成功: {anomaly_model.id}")
|
||||||
|
|
||||||
# 3. 列出预测模型
|
# 3. 列出预测模型
|
||||||
print("3. 列出预测模型...")
|
print("3. 列出预测模型...")
|
||||||
models = manager.list_prediction_models(tenant_id="tenant_001")
|
models = manager.list_prediction_models(tenant_id="tenant_001")
|
||||||
print(f" 找到 {len(models)} 个预测模型")
|
print(f" 找到 {len(models)} 个预测模型")
|
||||||
for m in models:
|
for m in models:
|
||||||
print(f" - {m.name} ({m.prediction_type.value})")
|
print(f" - {m.name} ({m.prediction_type.value})")
|
||||||
|
|
||||||
return trend_model.id, anomaly_model.id
|
return trend_model.id, anomaly_model.id
|
||||||
|
|
||||||
|
|
||||||
async def test_predictions(trend_model_id: str, anomaly_model_id: str):
|
async def test_predictions(trend_model_id: str, anomaly_model_id: str):
|
||||||
"""测试预测功能"""
|
"""测试预测功能"""
|
||||||
print("\n=== 测试预测功能 ===")
|
print("\n=== 测试预测功能 ===")
|
||||||
|
|
||||||
manager = get_ai_manager()
|
manager = get_ai_manager()
|
||||||
|
|
||||||
# 1. 训练趋势预测模型
|
# 1. 训练趋势预测模型
|
||||||
print("1. 训练趋势预测模型...")
|
print("1. 训练趋势预测模型...")
|
||||||
historical_data = [
|
historical_data = [
|
||||||
@@ -191,7 +188,7 @@ async def test_predictions(trend_model_id: str, anomaly_model_id: str):
|
|||||||
]
|
]
|
||||||
trained = await manager.train_prediction_model(trend_model_id, historical_data)
|
trained = await manager.train_prediction_model(trend_model_id, historical_data)
|
||||||
print(f" 训练完成,准确率: {trained.accuracy}")
|
print(f" 训练完成,准确率: {trained.accuracy}")
|
||||||
|
|
||||||
# 2. 趋势预测
|
# 2. 趋势预测
|
||||||
print("2. 趋势预测...")
|
print("2. 趋势预测...")
|
||||||
trend_result = await manager.predict(
|
trend_result = await manager.predict(
|
||||||
@@ -199,7 +196,7 @@ async def test_predictions(trend_model_id: str, anomaly_model_id: str):
|
|||||||
{"historical_values": [10, 12, 15, 14, 18, 20, 22]}
|
{"historical_values": [10, 12, 15, 14, 18, 20, 22]}
|
||||||
)
|
)
|
||||||
print(f" 预测结果: {trend_result.prediction_data}")
|
print(f" 预测结果: {trend_result.prediction_data}")
|
||||||
|
|
||||||
# 3. 异常检测
|
# 3. 异常检测
|
||||||
print("3. 异常检测...")
|
print("3. 异常检测...")
|
||||||
anomaly_result = await manager.predict(
|
anomaly_result = await manager.predict(
|
||||||
@@ -215,9 +212,9 @@ async def test_predictions(trend_model_id: str, anomaly_model_id: str):
|
|||||||
def test_kg_rag():
|
def test_kg_rag():
|
||||||
"""测试知识图谱 RAG"""
|
"""测试知识图谱 RAG"""
|
||||||
print("\n=== 测试知识图谱 RAG ===")
|
print("\n=== 测试知识图谱 RAG ===")
|
||||||
|
|
||||||
manager = get_ai_manager()
|
manager = get_ai_manager()
|
||||||
|
|
||||||
# 创建 RAG 配置
|
# 创建 RAG 配置
|
||||||
print("1. 创建知识图谱 RAG 配置...")
|
print("1. 创建知识图谱 RAG 配置...")
|
||||||
rag = manager.create_kg_rag(
|
rag = manager.create_kg_rag(
|
||||||
@@ -241,21 +238,21 @@ def test_kg_rag():
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
print(f" 创建成功: {rag.id}")
|
print(f" 创建成功: {rag.id}")
|
||||||
|
|
||||||
# 列出 RAG 配置
|
# 列出 RAG 配置
|
||||||
print("2. 列出 RAG 配置...")
|
print("2. 列出 RAG 配置...")
|
||||||
rags = manager.list_kg_rags(tenant_id="tenant_001")
|
rags = manager.list_kg_rags(tenant_id="tenant_001")
|
||||||
print(f" 找到 {len(rags)} 个配置")
|
print(f" 找到 {len(rags)} 个配置")
|
||||||
|
|
||||||
return rag.id
|
return rag.id
|
||||||
|
|
||||||
|
|
||||||
async def test_kg_rag_query(rag_id: str):
|
async def test_kg_rag_query(rag_id: str):
|
||||||
"""测试 RAG 查询"""
|
"""测试 RAG 查询"""
|
||||||
print("\n=== 测试知识图谱 RAG 查询 ===")
|
print("\n=== 测试知识图谱 RAG 查询 ===")
|
||||||
|
|
||||||
manager = get_ai_manager()
|
manager = get_ai_manager()
|
||||||
|
|
||||||
# 模拟项目实体和关系
|
# 模拟项目实体和关系
|
||||||
project_entities = [
|
project_entities = [
|
||||||
{"id": "e1", "name": "张三", "type": "PERSON", "definition": "项目经理"},
|
{"id": "e1", "name": "张三", "type": "PERSON", "definition": "项目经理"},
|
||||||
@@ -264,18 +261,36 @@ async def test_kg_rag_query(rag_id: str):
|
|||||||
{"id": "e4", "name": "Kubernetes", "type": "TECH", "definition": "容器编排平台"},
|
{"id": "e4", "name": "Kubernetes", "type": "TECH", "definition": "容器编排平台"},
|
||||||
{"id": "e5", "name": "TechCorp", "type": "ORG", "definition": "科技公司"}
|
{"id": "e5", "name": "TechCorp", "type": "ORG", "definition": "科技公司"}
|
||||||
]
|
]
|
||||||
|
|
||||||
project_relations = [
|
project_relations = [{"source_entity_id": "e1",
|
||||||
{"source_entity_id": "e1", "target_entity_id": "e3", "source_name": "张三", "target_name": "Project Alpha", "relation_type": "works_with", "evidence": "张三负责 Project Alpha 的管理工作"},
|
"target_entity_id": "e3",
|
||||||
{"source_entity_id": "e2", "target_entity_id": "e3", "source_name": "李四", "target_name": "Project Alpha", "relation_type": "works_with", "evidence": "李四负责 Project Alpha 的技术架构"},
|
"source_name": "张三",
|
||||||
{"source_entity_id": "e3", "target_entity_id": "e4", "source_name": "Project Alpha", "target_name": "Kubernetes", "relation_type": "depends_on", "evidence": "项目使用 Kubernetes 进行部署"},
|
"target_name": "Project Alpha",
|
||||||
{"source_entity_id": "e1", "target_entity_id": "e5", "source_name": "张三", "target_name": "TechCorp", "relation_type": "belongs_to", "evidence": "张三是 TechCorp 的员工"}
|
"relation_type": "works_with",
|
||||||
]
|
"evidence": "张三负责 Project Alpha 的管理工作"},
|
||||||
|
{"source_entity_id": "e2",
|
||||||
|
"target_entity_id": "e3",
|
||||||
|
"source_name": "李四",
|
||||||
|
"target_name": "Project Alpha",
|
||||||
|
"relation_type": "works_with",
|
||||||
|
"evidence": "李四负责 Project Alpha 的技术架构"},
|
||||||
|
{"source_entity_id": "e3",
|
||||||
|
"target_entity_id": "e4",
|
||||||
|
"source_name": "Project Alpha",
|
||||||
|
"target_name": "Kubernetes",
|
||||||
|
"relation_type": "depends_on",
|
||||||
|
"evidence": "项目使用 Kubernetes 进行部署"},
|
||||||
|
{"source_entity_id": "e1",
|
||||||
|
"target_entity_id": "e5",
|
||||||
|
"source_name": "张三",
|
||||||
|
"target_name": "TechCorp",
|
||||||
|
"relation_type": "belongs_to",
|
||||||
|
"evidence": "张三是 TechCorp 的员工"}]
|
||||||
|
|
||||||
# 执行查询
|
# 执行查询
|
||||||
print("1. 执行 RAG 查询...")
|
print("1. 执行 RAG 查询...")
|
||||||
query_text = "Project Alpha 项目有哪些人参与?使用了什么技术?"
|
query_text = "Project Alpha 项目有哪些人参与?使用了什么技术?"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = await manager.query_kg_rag(
|
result = await manager.query_kg_rag(
|
||||||
rag_id=rag_id,
|
rag_id=rag_id,
|
||||||
@@ -283,7 +298,7 @@ async def test_kg_rag_query(rag_id: str):
|
|||||||
project_entities=project_entities,
|
project_entities=project_entities,
|
||||||
project_relations=project_relations
|
project_relations=project_relations
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f" 查询: {result.query}")
|
print(f" 查询: {result.query}")
|
||||||
print(f" 回答: {result.answer[:200]}...")
|
print(f" 回答: {result.answer[:200]}...")
|
||||||
print(f" 置信度: {result.confidence}")
|
print(f" 置信度: {result.confidence}")
|
||||||
@@ -296,9 +311,9 @@ async def test_kg_rag_query(rag_id: str):
|
|||||||
async def test_smart_summary():
|
async def test_smart_summary():
|
||||||
"""测试智能摘要"""
|
"""测试智能摘要"""
|
||||||
print("\n=== 测试智能摘要 ===")
|
print("\n=== 测试智能摘要 ===")
|
||||||
|
|
||||||
manager = get_ai_manager()
|
manager = get_ai_manager()
|
||||||
|
|
||||||
# 模拟转录文本
|
# 模拟转录文本
|
||||||
transcript_text = """
|
transcript_text = """
|
||||||
今天的会议主要讨论了 Project Alpha 的进展情况。张三作为项目经理,
|
今天的会议主要讨论了 Project Alpha 的进展情况。张三作为项目经理,
|
||||||
@@ -307,7 +322,7 @@ async def test_smart_summary():
|
|||||||
会议还讨论了下一步的工作计划,包括测试、文档编写和上线准备。
|
会议还讨论了下一步的工作计划,包括测试、文档编写和上线准备。
|
||||||
大家一致认为项目进展顺利,预计可以按时交付。
|
大家一致认为项目进展顺利,预计可以按时交付。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
content_data = {
|
content_data = {
|
||||||
"text": transcript_text,
|
"text": transcript_text,
|
||||||
"entities": [
|
"entities": [
|
||||||
@@ -317,10 +332,10 @@ async def test_smart_summary():
|
|||||||
{"name": "Kubernetes", "type": "TECH"}
|
{"name": "Kubernetes", "type": "TECH"}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
# 生成不同类型的摘要
|
# 生成不同类型的摘要
|
||||||
summary_types = ["extractive", "abstractive", "key_points"]
|
summary_types = ["extractive", "abstractive", "key_points"]
|
||||||
|
|
||||||
for summary_type in summary_types:
|
for summary_type in summary_types:
|
||||||
print(f"1. 生成 {summary_type} 类型摘要...")
|
print(f"1. 生成 {summary_type} 类型摘要...")
|
||||||
try:
|
try:
|
||||||
@@ -332,7 +347,7 @@ async def test_smart_summary():
|
|||||||
summary_type=summary_type,
|
summary_type=summary_type,
|
||||||
content_data=content_data
|
content_data=content_data
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f" 摘要类型: {summary.summary_type}")
|
print(f" 摘要类型: {summary.summary_type}")
|
||||||
print(f" 内容: {summary.content[:150]}...")
|
print(f" 内容: {summary.content[:150]}...")
|
||||||
print(f" 关键要点: {summary.key_points[:3]}")
|
print(f" 关键要点: {summary.key_points[:3]}")
|
||||||
@@ -346,33 +361,33 @@ async def main():
|
|||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
print("InsightFlow Phase 8 Task 4 - AI 能力增强测试")
|
print("InsightFlow Phase 8 Task 4 - AI 能力增强测试")
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 测试自定义模型
|
# 测试自定义模型
|
||||||
model_id = test_custom_model()
|
model_id = test_custom_model()
|
||||||
|
|
||||||
# 测试训练和预测
|
# 测试训练和预测
|
||||||
await test_train_and_predict(model_id)
|
await test_train_and_predict(model_id)
|
||||||
|
|
||||||
# 测试预测模型
|
# 测试预测模型
|
||||||
trend_model_id, anomaly_model_id = test_prediction_models()
|
trend_model_id, anomaly_model_id = test_prediction_models()
|
||||||
|
|
||||||
# 测试预测功能
|
# 测试预测功能
|
||||||
await test_predictions(trend_model_id, anomaly_model_id)
|
await test_predictions(trend_model_id, anomaly_model_id)
|
||||||
|
|
||||||
# 测试知识图谱 RAG
|
# 测试知识图谱 RAG
|
||||||
rag_id = test_kg_rag()
|
rag_id = test_kg_rag()
|
||||||
|
|
||||||
# 测试 RAG 查询
|
# 测试 RAG 查询
|
||||||
await test_kg_rag_query(rag_id)
|
await test_kg_rag_query(rag_id)
|
||||||
|
|
||||||
# 测试智能摘要
|
# 测试智能摘要
|
||||||
await test_smart_summary()
|
await test_smart_summary()
|
||||||
|
|
||||||
print("\n" + "=" * 60)
|
print("\n" + "=" * 60)
|
||||||
print("所有测试完成!")
|
print("所有测试完成!")
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"\n测试失败: {e}")
|
print(f"\n测试失败: {e}")
|
||||||
import traceback
|
import traceback
|
||||||
|
|||||||
@@ -13,6 +13,9 @@ InsightFlow Phase 8 Task 5 - 运营与增长工具测试脚本
|
|||||||
python test_phase8_task5.py
|
python test_phase8_task5.py
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from growth_manager import (
|
||||||
|
GrowthManager, EventType, ExperimentStatus, TrafficAllocationType, EmailTemplateType, WorkflowTriggerType
|
||||||
|
)
|
||||||
import asyncio
|
import asyncio
|
||||||
import sys
|
import sys
|
||||||
import os
|
import os
|
||||||
@@ -23,35 +26,28 @@ backend_dir = os.path.dirname(os.path.abspath(__file__))
|
|||||||
if backend_dir not in sys.path:
|
if backend_dir not in sys.path:
|
||||||
sys.path.insert(0, backend_dir)
|
sys.path.insert(0, backend_dir)
|
||||||
|
|
||||||
from growth_manager import (
|
|
||||||
get_growth_manager, GrowthManager, AnalyticsEvent, UserProfile, Funnel, FunnelAnalysis,
|
|
||||||
Experiment, EmailTemplate, EmailCampaign, ReferralProgram, Referral, TeamIncentive,
|
|
||||||
EventType, ExperimentStatus, TrafficAllocationType, EmailTemplateType,
|
|
||||||
EmailStatus, WorkflowTriggerType, ReferralStatus
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TestGrowthManager:
|
class TestGrowthManager:
|
||||||
"""测试 Growth Manager 功能"""
|
"""测试 Growth Manager 功能"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.manager = GrowthManager()
|
self.manager = GrowthManager()
|
||||||
self.test_tenant_id = "test_tenant_001"
|
self.test_tenant_id = "test_tenant_001"
|
||||||
self.test_user_id = "test_user_001"
|
self.test_user_id = "test_user_001"
|
||||||
self.test_results = []
|
self.test_results = []
|
||||||
|
|
||||||
def log(self, message: str, success: bool = True):
|
def log(self, message: str, success: bool = True):
|
||||||
"""记录测试结果"""
|
"""记录测试结果"""
|
||||||
status = "✅" if success else "❌"
|
status = "✅" if success else "❌"
|
||||||
print(f"{status} {message}")
|
print(f"{status} {message}")
|
||||||
self.test_results.append((message, success))
|
self.test_results.append((message, success))
|
||||||
|
|
||||||
# ==================== 测试用户行为分析 ====================
|
# ==================== 测试用户行为分析 ====================
|
||||||
|
|
||||||
async def test_track_event(self):
|
async def test_track_event(self):
|
||||||
"""测试事件追踪"""
|
"""测试事件追踪"""
|
||||||
print("\n📊 测试事件追踪...")
|
print("\n📊 测试事件追踪...")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
event = await self.manager.track_event(
|
event = await self.manager.track_event(
|
||||||
tenant_id=self.test_tenant_id,
|
tenant_id=self.test_tenant_id,
|
||||||
@@ -64,21 +60,21 @@ class TestGrowthManager:
|
|||||||
referrer="https://google.com",
|
referrer="https://google.com",
|
||||||
utm_params={"source": "google", "medium": "organic", "campaign": "summer"}
|
utm_params={"source": "google", "medium": "organic", "campaign": "summer"}
|
||||||
)
|
)
|
||||||
|
|
||||||
assert event.id is not None
|
assert event.id is not None
|
||||||
assert event.event_type == EventType.PAGE_VIEW
|
assert event.event_type == EventType.PAGE_VIEW
|
||||||
assert event.event_name == "dashboard_view"
|
assert event.event_name == "dashboard_view"
|
||||||
|
|
||||||
self.log(f"事件追踪成功: {event.id}")
|
self.log(f"事件追踪成功: {event.id}")
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"事件追踪失败: {e}", success=False)
|
self.log(f"事件追踪失败: {e}", success=False)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def test_track_multiple_events(self):
|
async def test_track_multiple_events(self):
|
||||||
"""测试追踪多个事件"""
|
"""测试追踪多个事件"""
|
||||||
print("\n📊 测试追踪多个事件...")
|
print("\n📊 测试追踪多个事件...")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
events = [
|
events = [
|
||||||
(EventType.FEATURE_USE, "entity_extraction", {"entity_count": 5}),
|
(EventType.FEATURE_USE, "entity_extraction", {"entity_count": 5}),
|
||||||
@@ -86,7 +82,7 @@ class TestGrowthManager:
|
|||||||
(EventType.CONVERSION, "upgrade_click", {"plan": "pro"}),
|
(EventType.CONVERSION, "upgrade_click", {"plan": "pro"}),
|
||||||
(EventType.SIGNUP, "user_registration", {"source": "referral"}),
|
(EventType.SIGNUP, "user_registration", {"source": "referral"}),
|
||||||
]
|
]
|
||||||
|
|
||||||
for event_type, event_name, props in events:
|
for event_type, event_name, props in events:
|
||||||
await self.manager.track_event(
|
await self.manager.track_event(
|
||||||
tenant_id=self.test_tenant_id,
|
tenant_id=self.test_tenant_id,
|
||||||
@@ -95,57 +91,57 @@ class TestGrowthManager:
|
|||||||
event_name=event_name,
|
event_name=event_name,
|
||||||
properties=props
|
properties=props
|
||||||
)
|
)
|
||||||
|
|
||||||
self.log(f"成功追踪 {len(events)} 个事件")
|
self.log(f"成功追踪 {len(events)} 个事件")
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"批量事件追踪失败: {e}", success=False)
|
self.log(f"批量事件追踪失败: {e}", success=False)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def test_get_user_profile(self):
|
def test_get_user_profile(self):
|
||||||
"""测试获取用户画像"""
|
"""测试获取用户画像"""
|
||||||
print("\n👤 测试用户画像...")
|
print("\n👤 测试用户画像...")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
profile = self.manager.get_user_profile(self.test_tenant_id, self.test_user_id)
|
profile = self.manager.get_user_profile(self.test_tenant_id, self.test_user_id)
|
||||||
|
|
||||||
if profile:
|
if profile:
|
||||||
assert profile.user_id == self.test_user_id
|
assert profile.user_id == self.test_user_id
|
||||||
assert profile.total_events >= 0
|
assert profile.total_events >= 0
|
||||||
self.log(f"用户画像获取成功: {profile.user_id}, 事件数: {profile.total_events}")
|
self.log(f"用户画像获取成功: {profile.user_id}, 事件数: {profile.total_events}")
|
||||||
else:
|
else:
|
||||||
self.log("用户画像不存在(首次访问)")
|
self.log("用户画像不存在(首次访问)")
|
||||||
|
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"获取用户画像失败: {e}", success=False)
|
self.log(f"获取用户画像失败: {e}", success=False)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def test_get_analytics_summary(self):
|
def test_get_analytics_summary(self):
|
||||||
"""测试获取分析汇总"""
|
"""测试获取分析汇总"""
|
||||||
print("\n📈 测试分析汇总...")
|
print("\n📈 测试分析汇总...")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
summary = self.manager.get_user_analytics_summary(
|
summary = self.manager.get_user_analytics_summary(
|
||||||
tenant_id=self.test_tenant_id,
|
tenant_id=self.test_tenant_id,
|
||||||
start_date=datetime.now() - timedelta(days=7),
|
start_date=datetime.now() - timedelta(days=7),
|
||||||
end_date=datetime.now()
|
end_date=datetime.now()
|
||||||
)
|
)
|
||||||
|
|
||||||
assert "unique_users" in summary
|
assert "unique_users" in summary
|
||||||
assert "total_events" in summary
|
assert "total_events" in summary
|
||||||
assert "event_type_distribution" in summary
|
assert "event_type_distribution" in summary
|
||||||
|
|
||||||
self.log(f"分析汇总: {summary['unique_users']} 用户, {summary['total_events']} 事件")
|
self.log(f"分析汇总: {summary['unique_users']} 用户, {summary['total_events']} 事件")
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"获取分析汇总失败: {e}", success=False)
|
self.log(f"获取分析汇总失败: {e}", success=False)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def test_create_funnel(self):
|
def test_create_funnel(self):
|
||||||
"""测试创建转化漏斗"""
|
"""测试创建转化漏斗"""
|
||||||
print("\n🎯 测试创建转化漏斗...")
|
print("\n🎯 测试创建转化漏斗...")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
funnel = self.manager.create_funnel(
|
funnel = self.manager.create_funnel(
|
||||||
tenant_id=self.test_tenant_id,
|
tenant_id=self.test_tenant_id,
|
||||||
@@ -159,31 +155,31 @@ class TestGrowthManager:
|
|||||||
],
|
],
|
||||||
created_by="test"
|
created_by="test"
|
||||||
)
|
)
|
||||||
|
|
||||||
assert funnel.id is not None
|
assert funnel.id is not None
|
||||||
assert len(funnel.steps) == 4
|
assert len(funnel.steps) == 4
|
||||||
|
|
||||||
self.log(f"漏斗创建成功: {funnel.id}")
|
self.log(f"漏斗创建成功: {funnel.id}")
|
||||||
return funnel.id
|
return funnel.id
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"创建漏斗失败: {e}", success=False)
|
self.log(f"创建漏斗失败: {e}", success=False)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def test_analyze_funnel(self, funnel_id: str):
|
def test_analyze_funnel(self, funnel_id: str):
|
||||||
"""测试分析漏斗"""
|
"""测试分析漏斗"""
|
||||||
print("\n📉 测试漏斗分析...")
|
print("\n📉 测试漏斗分析...")
|
||||||
|
|
||||||
if not funnel_id:
|
if not funnel_id:
|
||||||
self.log("跳过漏斗分析(无漏斗ID)")
|
self.log("跳过漏斗分析(无漏斗ID)")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
analysis = self.manager.analyze_funnel(
|
analysis = self.manager.analyze_funnel(
|
||||||
funnel_id=funnel_id,
|
funnel_id=funnel_id,
|
||||||
period_start=datetime.now() - timedelta(days=30),
|
period_start=datetime.now() - timedelta(days=30),
|
||||||
period_end=datetime.now()
|
period_end=datetime.now()
|
||||||
)
|
)
|
||||||
|
|
||||||
if analysis:
|
if analysis:
|
||||||
assert "step_conversions" in analysis.__dict__
|
assert "step_conversions" in analysis.__dict__
|
||||||
self.log(f"漏斗分析完成: 总体转化率 {analysis.overall_conversion:.2%}")
|
self.log(f"漏斗分析完成: 总体转化率 {analysis.overall_conversion:.2%}")
|
||||||
@@ -194,33 +190,33 @@ class TestGrowthManager:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"漏斗分析失败: {e}", success=False)
|
self.log(f"漏斗分析失败: {e}", success=False)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def test_calculate_retention(self):
|
def test_calculate_retention(self):
|
||||||
"""测试留存率计算"""
|
"""测试留存率计算"""
|
||||||
print("\n🔄 测试留存率计算...")
|
print("\n🔄 测试留存率计算...")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
retention = self.manager.calculate_retention(
|
retention = self.manager.calculate_retention(
|
||||||
tenant_id=self.test_tenant_id,
|
tenant_id=self.test_tenant_id,
|
||||||
cohort_date=datetime.now() - timedelta(days=7),
|
cohort_date=datetime.now() - timedelta(days=7),
|
||||||
periods=[1, 3, 7]
|
periods=[1, 3, 7]
|
||||||
)
|
)
|
||||||
|
|
||||||
assert "cohort_date" in retention
|
assert "cohort_date" in retention
|
||||||
assert "retention" in retention
|
assert "retention" in retention
|
||||||
|
|
||||||
self.log(f"留存率计算完成: 同期群 {retention['cohort_size']} 用户")
|
self.log(f"留存率计算完成: 同期群 {retention['cohort_size']} 用户")
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"留存率计算失败: {e}", success=False)
|
self.log(f"留存率计算失败: {e}", success=False)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# ==================== 测试 A/B 测试框架 ====================
|
# ==================== 测试 A/B 测试框架 ====================
|
||||||
|
|
||||||
def test_create_experiment(self):
|
def test_create_experiment(self):
|
||||||
"""测试创建实验"""
|
"""测试创建实验"""
|
||||||
print("\n🧪 测试创建 A/B 测试实验...")
|
print("\n🧪 测试创建 A/B 测试实验...")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
experiment = self.manager.create_experiment(
|
experiment = self.manager.create_experiment(
|
||||||
tenant_id=self.test_tenant_id,
|
tenant_id=self.test_tenant_id,
|
||||||
@@ -241,69 +237,69 @@ class TestGrowthManager:
|
|||||||
confidence_level=0.95,
|
confidence_level=0.95,
|
||||||
created_by="test"
|
created_by="test"
|
||||||
)
|
)
|
||||||
|
|
||||||
assert experiment.id is not None
|
assert experiment.id is not None
|
||||||
assert experiment.status == ExperimentStatus.DRAFT
|
assert experiment.status == ExperimentStatus.DRAFT
|
||||||
|
|
||||||
self.log(f"实验创建成功: {experiment.id}")
|
self.log(f"实验创建成功: {experiment.id}")
|
||||||
return experiment.id
|
return experiment.id
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"创建实验失败: {e}", success=False)
|
self.log(f"创建实验失败: {e}", success=False)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def test_list_experiments(self):
|
def test_list_experiments(self):
|
||||||
"""测试列出实验"""
|
"""测试列出实验"""
|
||||||
print("\n📋 测试列出实验...")
|
print("\n📋 测试列出实验...")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
experiments = self.manager.list_experiments(self.test_tenant_id)
|
experiments = self.manager.list_experiments(self.test_tenant_id)
|
||||||
|
|
||||||
self.log(f"列出 {len(experiments)} 个实验")
|
self.log(f"列出 {len(experiments)} 个实验")
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"列出实验失败: {e}", success=False)
|
self.log(f"列出实验失败: {e}", success=False)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def test_assign_variant(self, experiment_id: str):
|
def test_assign_variant(self, experiment_id: str):
|
||||||
"""测试分配变体"""
|
"""测试分配变体"""
|
||||||
print("\n🎲 测试分配实验变体...")
|
print("\n🎲 测试分配实验变体...")
|
||||||
|
|
||||||
if not experiment_id:
|
if not experiment_id:
|
||||||
self.log("跳过变体分配(无实验ID)")
|
self.log("跳过变体分配(无实验ID)")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 先启动实验
|
# 先启动实验
|
||||||
self.manager.start_experiment(experiment_id)
|
self.manager.start_experiment(experiment_id)
|
||||||
|
|
||||||
# 测试多个用户的变体分配
|
# 测试多个用户的变体分配
|
||||||
test_users = ["user_001", "user_002", "user_003", "user_004", "user_005"]
|
test_users = ["user_001", "user_002", "user_003", "user_004", "user_005"]
|
||||||
assignments = {}
|
assignments = {}
|
||||||
|
|
||||||
for user_id in test_users:
|
for user_id in test_users:
|
||||||
variant_id = self.manager.assign_variant(
|
variant_id = self.manager.assign_variant(
|
||||||
experiment_id=experiment_id,
|
experiment_id=experiment_id,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
user_attributes={"user_id": user_id, "segment": "new"}
|
user_attributes={"user_id": user_id, "segment": "new"}
|
||||||
)
|
)
|
||||||
|
|
||||||
if variant_id:
|
if variant_id:
|
||||||
assignments[user_id] = variant_id
|
assignments[user_id] = variant_id
|
||||||
|
|
||||||
self.log(f"变体分配完成: {len(assignments)} 个用户")
|
self.log(f"变体分配完成: {len(assignments)} 个用户")
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"变体分配失败: {e}", success=False)
|
self.log(f"变体分配失败: {e}", success=False)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def test_record_experiment_metric(self, experiment_id: str):
|
def test_record_experiment_metric(self, experiment_id: str):
|
||||||
"""测试记录实验指标"""
|
"""测试记录实验指标"""
|
||||||
print("\n📊 测试记录实验指标...")
|
print("\n📊 测试记录实验指标...")
|
||||||
|
|
||||||
if not experiment_id:
|
if not experiment_id:
|
||||||
self.log("跳过指标记录(无实验ID)")
|
self.log("跳过指标记录(无实验ID)")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 模拟记录一些指标
|
# 模拟记录一些指标
|
||||||
test_data = [
|
test_data = [
|
||||||
@@ -313,7 +309,7 @@ class TestGrowthManager:
|
|||||||
("user_004", "control", 1),
|
("user_004", "control", 1),
|
||||||
("user_005", "variant_a", 1),
|
("user_005", "variant_a", 1),
|
||||||
]
|
]
|
||||||
|
|
||||||
for user_id, variant_id, value in test_data:
|
for user_id, variant_id, value in test_data:
|
||||||
self.manager.record_experiment_metric(
|
self.manager.record_experiment_metric(
|
||||||
experiment_id=experiment_id,
|
experiment_id=experiment_id,
|
||||||
@@ -322,24 +318,24 @@ class TestGrowthManager:
|
|||||||
metric_name="button_click_rate",
|
metric_name="button_click_rate",
|
||||||
metric_value=value
|
metric_value=value
|
||||||
)
|
)
|
||||||
|
|
||||||
self.log(f"成功记录 {len(test_data)} 条指标")
|
self.log(f"成功记录 {len(test_data)} 条指标")
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"记录指标失败: {e}", success=False)
|
self.log(f"记录指标失败: {e}", success=False)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def test_analyze_experiment(self, experiment_id: str):
|
def test_analyze_experiment(self, experiment_id: str):
|
||||||
"""测试分析实验结果"""
|
"""测试分析实验结果"""
|
||||||
print("\n📈 测试分析实验结果...")
|
print("\n📈 测试分析实验结果...")
|
||||||
|
|
||||||
if not experiment_id:
|
if not experiment_id:
|
||||||
self.log("跳过实验分析(无实验ID)")
|
self.log("跳过实验分析(无实验ID)")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = self.manager.analyze_experiment(experiment_id)
|
result = self.manager.analyze_experiment(experiment_id)
|
||||||
|
|
||||||
if "error" not in result:
|
if "error" not in result:
|
||||||
self.log(f"实验分析完成: {len(result.get('variant_results', {}))} 个变体")
|
self.log(f"实验分析完成: {len(result.get('variant_results', {}))} 个变体")
|
||||||
return True
|
return True
|
||||||
@@ -349,13 +345,13 @@ class TestGrowthManager:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"实验分析失败: {e}", success=False)
|
self.log(f"实验分析失败: {e}", success=False)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# ==================== 测试邮件营销 ====================
|
# ==================== 测试邮件营销 ====================
|
||||||
|
|
||||||
def test_create_email_template(self):
|
def test_create_email_template(self):
|
||||||
"""测试创建邮件模板"""
|
"""测试创建邮件模板"""
|
||||||
print("\n📧 测试创建邮件模板...")
|
print("\n📧 测试创建邮件模板...")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
template = self.manager.create_email_template(
|
template = self.manager.create_email_template(
|
||||||
tenant_id=self.test_tenant_id,
|
tenant_id=self.test_tenant_id,
|
||||||
@@ -376,37 +372,37 @@ class TestGrowthManager:
|
|||||||
from_name="InsightFlow 团队",
|
from_name="InsightFlow 团队",
|
||||||
from_email="welcome@insightflow.io"
|
from_email="welcome@insightflow.io"
|
||||||
)
|
)
|
||||||
|
|
||||||
assert template.id is not None
|
assert template.id is not None
|
||||||
assert template.template_type == EmailTemplateType.WELCOME
|
assert template.template_type == EmailTemplateType.WELCOME
|
||||||
|
|
||||||
self.log(f"邮件模板创建成功: {template.id}")
|
self.log(f"邮件模板创建成功: {template.id}")
|
||||||
return template.id
|
return template.id
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"创建邮件模板失败: {e}", success=False)
|
self.log(f"创建邮件模板失败: {e}", success=False)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def test_list_email_templates(self):
|
def test_list_email_templates(self):
|
||||||
"""测试列出邮件模板"""
|
"""测试列出邮件模板"""
|
||||||
print("\n📧 测试列出邮件模板...")
|
print("\n📧 测试列出邮件模板...")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
templates = self.manager.list_email_templates(self.test_tenant_id)
|
templates = self.manager.list_email_templates(self.test_tenant_id)
|
||||||
|
|
||||||
self.log(f"列出 {len(templates)} 个邮件模板")
|
self.log(f"列出 {len(templates)} 个邮件模板")
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"列出邮件模板失败: {e}", success=False)
|
self.log(f"列出邮件模板失败: {e}", success=False)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def test_render_template(self, template_id: str):
|
def test_render_template(self, template_id: str):
|
||||||
"""测试渲染邮件模板"""
|
"""测试渲染邮件模板"""
|
||||||
print("\n🎨 测试渲染邮件模板...")
|
print("\n🎨 测试渲染邮件模板...")
|
||||||
|
|
||||||
if not template_id:
|
if not template_id:
|
||||||
self.log("跳过模板渲染(无模板ID)")
|
self.log("跳过模板渲染(无模板ID)")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
rendered = self.manager.render_template(
|
rendered = self.manager.render_template(
|
||||||
template_id=template_id,
|
template_id=template_id,
|
||||||
@@ -415,7 +411,7 @@ class TestGrowthManager:
|
|||||||
"dashboard_url": "https://app.insightflow.io/dashboard"
|
"dashboard_url": "https://app.insightflow.io/dashboard"
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
if rendered:
|
if rendered:
|
||||||
assert "subject" in rendered
|
assert "subject" in rendered
|
||||||
assert "html" in rendered
|
assert "html" in rendered
|
||||||
@@ -427,15 +423,15 @@ class TestGrowthManager:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"模板渲染失败: {e}", success=False)
|
self.log(f"模板渲染失败: {e}", success=False)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def test_create_email_campaign(self, template_id: str):
|
def test_create_email_campaign(self, template_id: str):
|
||||||
"""测试创建邮件营销活动"""
|
"""测试创建邮件营销活动"""
|
||||||
print("\n📮 测试创建邮件营销活动...")
|
print("\n📮 测试创建邮件营销活动...")
|
||||||
|
|
||||||
if not template_id:
|
if not template_id:
|
||||||
self.log("跳过创建营销活动(无模板ID)")
|
self.log("跳过创建营销活动(无模板ID)")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
campaign = self.manager.create_email_campaign(
|
campaign = self.manager.create_email_campaign(
|
||||||
tenant_id=self.test_tenant_id,
|
tenant_id=self.test_tenant_id,
|
||||||
@@ -447,20 +443,20 @@ class TestGrowthManager:
|
|||||||
{"user_id": "user_003", "email": "user3@example.com"}
|
{"user_id": "user_003", "email": "user3@example.com"}
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
assert campaign.id is not None
|
assert campaign.id is not None
|
||||||
assert campaign.recipient_count == 3
|
assert campaign.recipient_count == 3
|
||||||
|
|
||||||
self.log(f"营销活动创建成功: {campaign.id}, {campaign.recipient_count} 收件人")
|
self.log(f"营销活动创建成功: {campaign.id}, {campaign.recipient_count} 收件人")
|
||||||
return campaign.id
|
return campaign.id
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"创建营销活动失败: {e}", success=False)
|
self.log(f"创建营销活动失败: {e}", success=False)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def test_create_automation_workflow(self):
|
def test_create_automation_workflow(self):
|
||||||
"""测试创建自动化工作流"""
|
"""测试创建自动化工作流"""
|
||||||
print("\n🤖 测试创建自动化工作流...")
|
print("\n🤖 测试创建自动化工作流...")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
workflow = self.manager.create_automation_workflow(
|
workflow = self.manager.create_automation_workflow(
|
||||||
tenant_id=self.test_tenant_id,
|
tenant_id=self.test_tenant_id,
|
||||||
@@ -474,22 +470,22 @@ class TestGrowthManager:
|
|||||||
{"type": "send_email", "template_type": "feature_tips", "delay_hours": 72}
|
{"type": "send_email", "template_type": "feature_tips", "delay_hours": 72}
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
assert workflow.id is not None
|
assert workflow.id is not None
|
||||||
assert workflow.trigger_type == WorkflowTriggerType.USER_SIGNUP
|
assert workflow.trigger_type == WorkflowTriggerType.USER_SIGNUP
|
||||||
|
|
||||||
self.log(f"自动化工作流创建成功: {workflow.id}")
|
self.log(f"自动化工作流创建成功: {workflow.id}")
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"创建工作流失败: {e}", success=False)
|
self.log(f"创建工作流失败: {e}", success=False)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# ==================== 测试推荐系统 ====================
|
# ==================== 测试推荐系统 ====================
|
||||||
|
|
||||||
def test_create_referral_program(self):
|
def test_create_referral_program(self):
|
||||||
"""测试创建推荐计划"""
|
"""测试创建推荐计划"""
|
||||||
print("\n🎁 测试创建推荐计划...")
|
print("\n🎁 测试创建推荐计划...")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
program = self.manager.create_referral_program(
|
program = self.manager.create_referral_program(
|
||||||
tenant_id=self.test_tenant_id,
|
tenant_id=self.test_tenant_id,
|
||||||
@@ -503,34 +499,34 @@ class TestGrowthManager:
|
|||||||
referral_code_length=8,
|
referral_code_length=8,
|
||||||
expiry_days=30
|
expiry_days=30
|
||||||
)
|
)
|
||||||
|
|
||||||
assert program.id is not None
|
assert program.id is not None
|
||||||
assert program.referrer_reward_value == 100.0
|
assert program.referrer_reward_value == 100.0
|
||||||
|
|
||||||
self.log(f"推荐计划创建成功: {program.id}")
|
self.log(f"推荐计划创建成功: {program.id}")
|
||||||
return program.id
|
return program.id
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"创建推荐计划失败: {e}", success=False)
|
self.log(f"创建推荐计划失败: {e}", success=False)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def test_generate_referral_code(self, program_id: str):
|
def test_generate_referral_code(self, program_id: str):
|
||||||
"""测试生成推荐码"""
|
"""测试生成推荐码"""
|
||||||
print("\n🔑 测试生成推荐码...")
|
print("\n🔑 测试生成推荐码...")
|
||||||
|
|
||||||
if not program_id:
|
if not program_id:
|
||||||
self.log("跳过生成推荐码(无计划ID)")
|
self.log("跳过生成推荐码(无计划ID)")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
referral = self.manager.generate_referral_code(
|
referral = self.manager.generate_referral_code(
|
||||||
program_id=program_id,
|
program_id=program_id,
|
||||||
referrer_id="referrer_user_001"
|
referrer_id="referrer_user_001"
|
||||||
)
|
)
|
||||||
|
|
||||||
if referral:
|
if referral:
|
||||||
assert referral.referral_code is not None
|
assert referral.referral_code is not None
|
||||||
assert len(referral.referral_code) == 8
|
assert len(referral.referral_code) == 8
|
||||||
|
|
||||||
self.log(f"推荐码生成成功: {referral.referral_code}")
|
self.log(f"推荐码生成成功: {referral.referral_code}")
|
||||||
return referral.referral_code
|
return referral.referral_code
|
||||||
else:
|
else:
|
||||||
@@ -539,21 +535,21 @@ class TestGrowthManager:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"生成推荐码失败: {e}", success=False)
|
self.log(f"生成推荐码失败: {e}", success=False)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def test_apply_referral_code(self, referral_code: str):
|
def test_apply_referral_code(self, referral_code: str):
|
||||||
"""测试应用推荐码"""
|
"""测试应用推荐码"""
|
||||||
print("\n✅ 测试应用推荐码...")
|
print("\n✅ 测试应用推荐码...")
|
||||||
|
|
||||||
if not referral_code:
|
if not referral_code:
|
||||||
self.log("跳过应用推荐码(无推荐码)")
|
self.log("跳过应用推荐码(无推荐码)")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
success = self.manager.apply_referral_code(
|
success = self.manager.apply_referral_code(
|
||||||
referral_code=referral_code,
|
referral_code=referral_code,
|
||||||
referee_id="new_user_001"
|
referee_id="new_user_001"
|
||||||
)
|
)
|
||||||
|
|
||||||
if success:
|
if success:
|
||||||
self.log(f"推荐码应用成功: {referral_code}")
|
self.log(f"推荐码应用成功: {referral_code}")
|
||||||
return True
|
return True
|
||||||
@@ -563,31 +559,31 @@ class TestGrowthManager:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"应用推荐码失败: {e}", success=False)
|
self.log(f"应用推荐码失败: {e}", success=False)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def test_get_referral_stats(self, program_id: str):
|
def test_get_referral_stats(self, program_id: str):
|
||||||
"""测试获取推荐统计"""
|
"""测试获取推荐统计"""
|
||||||
print("\n📊 测试获取推荐统计...")
|
print("\n📊 测试获取推荐统计...")
|
||||||
|
|
||||||
if not program_id:
|
if not program_id:
|
||||||
self.log("跳过推荐统计(无计划ID)")
|
self.log("跳过推荐统计(无计划ID)")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
stats = self.manager.get_referral_stats(program_id)
|
stats = self.manager.get_referral_stats(program_id)
|
||||||
|
|
||||||
assert "total_referrals" in stats
|
assert "total_referrals" in stats
|
||||||
assert "conversion_rate" in stats
|
assert "conversion_rate" in stats
|
||||||
|
|
||||||
self.log(f"推荐统计: {stats['total_referrals']} 推荐, {stats['conversion_rate']:.2%} 转化率")
|
self.log(f"推荐统计: {stats['total_referrals']} 推荐, {stats['conversion_rate']:.2%} 转化率")
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"获取推荐统计失败: {e}", success=False)
|
self.log(f"获取推荐统计失败: {e}", success=False)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def test_create_team_incentive(self):
|
def test_create_team_incentive(self):
|
||||||
"""测试创建团队激励"""
|
"""测试创建团队激励"""
|
||||||
print("\n🏆 测试创建团队升级激励...")
|
print("\n🏆 测试创建团队升级激励...")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
incentive = self.manager.create_team_incentive(
|
incentive = self.manager.create_team_incentive(
|
||||||
tenant_id=self.test_tenant_id,
|
tenant_id=self.test_tenant_id,
|
||||||
@@ -600,66 +596,66 @@ class TestGrowthManager:
|
|||||||
valid_from=datetime.now(),
|
valid_from=datetime.now(),
|
||||||
valid_until=datetime.now() + timedelta(days=90)
|
valid_until=datetime.now() + timedelta(days=90)
|
||||||
)
|
)
|
||||||
|
|
||||||
assert incentive.id is not None
|
assert incentive.id is not None
|
||||||
assert incentive.incentive_value == 20.0
|
assert incentive.incentive_value == 20.0
|
||||||
|
|
||||||
self.log(f"团队激励创建成功: {incentive.id}")
|
self.log(f"团队激励创建成功: {incentive.id}")
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"创建团队激励失败: {e}", success=False)
|
self.log(f"创建团队激励失败: {e}", success=False)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def test_check_team_incentive_eligibility(self):
|
def test_check_team_incentive_eligibility(self):
|
||||||
"""测试检查团队激励资格"""
|
"""测试检查团队激励资格"""
|
||||||
print("\n🔍 测试检查团队激励资格...")
|
print("\n🔍 测试检查团队激励资格...")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
incentives = self.manager.check_team_incentive_eligibility(
|
incentives = self.manager.check_team_incentive_eligibility(
|
||||||
tenant_id=self.test_tenant_id,
|
tenant_id=self.test_tenant_id,
|
||||||
current_tier="free",
|
current_tier="free",
|
||||||
team_size=5
|
team_size=5
|
||||||
)
|
)
|
||||||
|
|
||||||
self.log(f"找到 {len(incentives)} 个符合条件的激励")
|
self.log(f"找到 {len(incentives)} 个符合条件的激励")
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"检查激励资格失败: {e}", success=False)
|
self.log(f"检查激励资格失败: {e}", success=False)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# ==================== 测试实时仪表板 ====================
|
# ==================== 测试实时仪表板 ====================
|
||||||
|
|
||||||
def test_get_realtime_dashboard(self):
|
def test_get_realtime_dashboard(self):
|
||||||
"""测试获取实时仪表板"""
|
"""测试获取实时仪表板"""
|
||||||
print("\n📺 测试实时分析仪表板...")
|
print("\n📺 测试实时分析仪表板...")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
dashboard = self.manager.get_realtime_dashboard(self.test_tenant_id)
|
dashboard = self.manager.get_realtime_dashboard(self.test_tenant_id)
|
||||||
|
|
||||||
assert "today" in dashboard
|
assert "today" in dashboard
|
||||||
assert "recent_events" in dashboard
|
assert "recent_events" in dashboard
|
||||||
assert "top_features" in dashboard
|
assert "top_features" in dashboard
|
||||||
|
|
||||||
today = dashboard["today"]
|
today = dashboard["today"]
|
||||||
self.log(f"实时仪表板: 今日 {today['active_users']} 活跃用户, {today['total_events']} 事件")
|
self.log(f"实时仪表板: 今日 {today['active_users']} 活跃用户, {today['total_events']} 事件")
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"获取实时仪表板失败: {e}", success=False)
|
self.log(f"获取实时仪表板失败: {e}", success=False)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# ==================== 运行所有测试 ====================
|
# ==================== 运行所有测试 ====================
|
||||||
|
|
||||||
async def run_all_tests(self):
|
async def run_all_tests(self):
|
||||||
"""运行所有测试"""
|
"""运行所有测试"""
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
print("🚀 InsightFlow Phase 8 Task 5 - 运营与增长工具测试")
|
print("🚀 InsightFlow Phase 8 Task 5 - 运营与增长工具测试")
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
|
|
||||||
# 用户行为分析测试
|
# 用户行为分析测试
|
||||||
print("\n" + "=" * 60)
|
print("\n" + "=" * 60)
|
||||||
print("📊 模块 1: 用户行为分析")
|
print("📊 模块 1: 用户行为分析")
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
|
|
||||||
await self.test_track_event()
|
await self.test_track_event()
|
||||||
await self.test_track_multiple_events()
|
await self.test_track_multiple_events()
|
||||||
self.test_get_user_profile()
|
self.test_get_user_profile()
|
||||||
@@ -667,68 +663,68 @@ class TestGrowthManager:
|
|||||||
funnel_id = self.test_create_funnel()
|
funnel_id = self.test_create_funnel()
|
||||||
self.test_analyze_funnel(funnel_id)
|
self.test_analyze_funnel(funnel_id)
|
||||||
self.test_calculate_retention()
|
self.test_calculate_retention()
|
||||||
|
|
||||||
# A/B 测试框架测试
|
# A/B 测试框架测试
|
||||||
print("\n" + "=" * 60)
|
print("\n" + "=" * 60)
|
||||||
print("🧪 模块 2: A/B 测试框架")
|
print("🧪 模块 2: A/B 测试框架")
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
|
|
||||||
experiment_id = self.test_create_experiment()
|
experiment_id = self.test_create_experiment()
|
||||||
self.test_list_experiments()
|
self.test_list_experiments()
|
||||||
self.test_assign_variant(experiment_id)
|
self.test_assign_variant(experiment_id)
|
||||||
self.test_record_experiment_metric(experiment_id)
|
self.test_record_experiment_metric(experiment_id)
|
||||||
self.test_analyze_experiment(experiment_id)
|
self.test_analyze_experiment(experiment_id)
|
||||||
|
|
||||||
# 邮件营销测试
|
# 邮件营销测试
|
||||||
print("\n" + "=" * 60)
|
print("\n" + "=" * 60)
|
||||||
print("📧 模块 3: 邮件营销自动化")
|
print("📧 模块 3: 邮件营销自动化")
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
|
|
||||||
template_id = self.test_create_email_template()
|
template_id = self.test_create_email_template()
|
||||||
self.test_list_email_templates()
|
self.test_list_email_templates()
|
||||||
self.test_render_template(template_id)
|
self.test_render_template(template_id)
|
||||||
campaign_id = self.test_create_email_campaign(template_id)
|
self.test_create_email_campaign(template_id)
|
||||||
self.test_create_automation_workflow()
|
self.test_create_automation_workflow()
|
||||||
|
|
||||||
# 推荐系统测试
|
# 推荐系统测试
|
||||||
print("\n" + "=" * 60)
|
print("\n" + "=" * 60)
|
||||||
print("🎁 模块 4: 推荐系统")
|
print("🎁 模块 4: 推荐系统")
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
|
|
||||||
program_id = self.test_create_referral_program()
|
program_id = self.test_create_referral_program()
|
||||||
referral_code = self.test_generate_referral_code(program_id)
|
referral_code = self.test_generate_referral_code(program_id)
|
||||||
self.test_apply_referral_code(referral_code)
|
self.test_apply_referral_code(referral_code)
|
||||||
self.test_get_referral_stats(program_id)
|
self.test_get_referral_stats(program_id)
|
||||||
self.test_create_team_incentive()
|
self.test_create_team_incentive()
|
||||||
self.test_check_team_incentive_eligibility()
|
self.test_check_team_incentive_eligibility()
|
||||||
|
|
||||||
# 实时仪表板测试
|
# 实时仪表板测试
|
||||||
print("\n" + "=" * 60)
|
print("\n" + "=" * 60)
|
||||||
print("📺 模块 5: 实时分析仪表板")
|
print("📺 模块 5: 实时分析仪表板")
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
|
|
||||||
self.test_get_realtime_dashboard()
|
self.test_get_realtime_dashboard()
|
||||||
|
|
||||||
# 测试总结
|
# 测试总结
|
||||||
print("\n" + "=" * 60)
|
print("\n" + "=" * 60)
|
||||||
print("📋 测试总结")
|
print("📋 测试总结")
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
|
|
||||||
total_tests = len(self.test_results)
|
total_tests = len(self.test_results)
|
||||||
passed_tests = sum(1 for _, success in self.test_results if success)
|
passed_tests = sum(1 for _, success in self.test_results if success)
|
||||||
failed_tests = total_tests - passed_tests
|
failed_tests = total_tests - passed_tests
|
||||||
|
|
||||||
print(f"总测试数: {total_tests}")
|
print(f"总测试数: {total_tests}")
|
||||||
print(f"通过: {passed_tests} ✅")
|
print(f"通过: {passed_tests} ✅")
|
||||||
print(f"失败: {failed_tests} ❌")
|
print(f"失败: {failed_tests} ❌")
|
||||||
print(f"通过率: {passed_tests / total_tests * 100:.1f}%" if total_tests > 0 else "N/A")
|
print(f"通过率: {passed_tests / total_tests * 100:.1f}%" if total_tests > 0 else "N/A")
|
||||||
|
|
||||||
if failed_tests > 0:
|
if failed_tests > 0:
|
||||||
print("\n失败的测试:")
|
print("\n失败的测试:")
|
||||||
for message, success in self.test_results:
|
for message, success in self.test_results:
|
||||||
if not success:
|
if not success:
|
||||||
print(f" - {message}")
|
print(f" - {message}")
|
||||||
|
|
||||||
print("\n" + "=" * 60)
|
print("\n" + "=" * 60)
|
||||||
print("✨ 测试完成!")
|
print("✨ 测试完成!")
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
|
|||||||
@@ -10,7 +10,12 @@ InsightFlow Phase 8 Task 6: Developer Ecosystem Test Script
|
|||||||
4. 开发者文档与示例代码
|
4. 开发者文档与示例代码
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
from developer_ecosystem_manager import (
|
||||||
|
DeveloperEcosystemManager,
|
||||||
|
SDKLanguage, TemplateCategory,
|
||||||
|
PluginCategory, PluginStatus,
|
||||||
|
DeveloperStatus
|
||||||
|
)
|
||||||
import sys
|
import sys
|
||||||
import os
|
import os
|
||||||
import uuid
|
import uuid
|
||||||
@@ -21,18 +26,10 @@ backend_dir = os.path.dirname(os.path.abspath(__file__))
|
|||||||
if backend_dir not in sys.path:
|
if backend_dir not in sys.path:
|
||||||
sys.path.insert(0, backend_dir)
|
sys.path.insert(0, backend_dir)
|
||||||
|
|
||||||
from developer_ecosystem_manager import (
|
|
||||||
DeveloperEcosystemManager,
|
|
||||||
SDKLanguage, SDKStatus,
|
|
||||||
TemplateCategory, TemplateStatus,
|
|
||||||
PluginCategory, PluginStatus,
|
|
||||||
DeveloperStatus
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TestDeveloperEcosystem:
|
class TestDeveloperEcosystem:
|
||||||
"""开发者生态系统测试类"""
|
"""开发者生态系统测试类"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.manager = DeveloperEcosystemManager()
|
self.manager = DeveloperEcosystemManager()
|
||||||
self.test_results = []
|
self.test_results = []
|
||||||
@@ -44,7 +41,7 @@ class TestDeveloperEcosystem:
|
|||||||
'code_example': [],
|
'code_example': [],
|
||||||
'portal_config': []
|
'portal_config': []
|
||||||
}
|
}
|
||||||
|
|
||||||
def log(self, message: str, success: bool = True):
|
def log(self, message: str, success: bool = True):
|
||||||
"""记录测试结果"""
|
"""记录测试结果"""
|
||||||
status = "✅" if success else "❌"
|
status = "✅" if success else "❌"
|
||||||
@@ -54,13 +51,13 @@ class TestDeveloperEcosystem:
|
|||||||
'success': success,
|
'success': success,
|
||||||
'timestamp': datetime.now().isoformat()
|
'timestamp': datetime.now().isoformat()
|
||||||
})
|
})
|
||||||
|
|
||||||
def run_all_tests(self):
|
def run_all_tests(self):
|
||||||
"""运行所有测试"""
|
"""运行所有测试"""
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
print("InsightFlow Phase 8 Task 6: Developer Ecosystem Tests")
|
print("InsightFlow Phase 8 Task 6: Developer Ecosystem Tests")
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
|
|
||||||
# SDK Tests
|
# SDK Tests
|
||||||
print("\n📦 SDK Release & Management Tests")
|
print("\n📦 SDK Release & Management Tests")
|
||||||
print("-" * 40)
|
print("-" * 40)
|
||||||
@@ -70,7 +67,7 @@ class TestDeveloperEcosystem:
|
|||||||
self.test_sdk_update()
|
self.test_sdk_update()
|
||||||
self.test_sdk_publish()
|
self.test_sdk_publish()
|
||||||
self.test_sdk_version_add()
|
self.test_sdk_version_add()
|
||||||
|
|
||||||
# Template Market Tests
|
# Template Market Tests
|
||||||
print("\n📋 Template Market Tests")
|
print("\n📋 Template Market Tests")
|
||||||
print("-" * 40)
|
print("-" * 40)
|
||||||
@@ -80,7 +77,7 @@ class TestDeveloperEcosystem:
|
|||||||
self.test_template_approve()
|
self.test_template_approve()
|
||||||
self.test_template_publish()
|
self.test_template_publish()
|
||||||
self.test_template_review()
|
self.test_template_review()
|
||||||
|
|
||||||
# Plugin Market Tests
|
# Plugin Market Tests
|
||||||
print("\n🔌 Plugin Market Tests")
|
print("\n🔌 Plugin Market Tests")
|
||||||
print("-" * 40)
|
print("-" * 40)
|
||||||
@@ -90,7 +87,7 @@ class TestDeveloperEcosystem:
|
|||||||
self.test_plugin_review()
|
self.test_plugin_review()
|
||||||
self.test_plugin_publish()
|
self.test_plugin_publish()
|
||||||
self.test_plugin_review_add()
|
self.test_plugin_review_add()
|
||||||
|
|
||||||
# Developer Profile Tests
|
# Developer Profile Tests
|
||||||
print("\n👤 Developer Profile Tests")
|
print("\n👤 Developer Profile Tests")
|
||||||
print("-" * 40)
|
print("-" * 40)
|
||||||
@@ -98,29 +95,29 @@ class TestDeveloperEcosystem:
|
|||||||
self.test_developer_profile_get()
|
self.test_developer_profile_get()
|
||||||
self.test_developer_verify()
|
self.test_developer_verify()
|
||||||
self.test_developer_stats_update()
|
self.test_developer_stats_update()
|
||||||
|
|
||||||
# Code Examples Tests
|
# Code Examples Tests
|
||||||
print("\n💻 Code Examples Tests")
|
print("\n💻 Code Examples Tests")
|
||||||
print("-" * 40)
|
print("-" * 40)
|
||||||
self.test_code_example_create()
|
self.test_code_example_create()
|
||||||
self.test_code_example_list()
|
self.test_code_example_list()
|
||||||
self.test_code_example_get()
|
self.test_code_example_get()
|
||||||
|
|
||||||
# Portal Config Tests
|
# Portal Config Tests
|
||||||
print("\n🌐 Developer Portal Tests")
|
print("\n🌐 Developer Portal Tests")
|
||||||
print("-" * 40)
|
print("-" * 40)
|
||||||
self.test_portal_config_create()
|
self.test_portal_config_create()
|
||||||
self.test_portal_config_get()
|
self.test_portal_config_get()
|
||||||
|
|
||||||
# Revenue Tests
|
# Revenue Tests
|
||||||
print("\n💰 Developer Revenue Tests")
|
print("\n💰 Developer Revenue Tests")
|
||||||
print("-" * 40)
|
print("-" * 40)
|
||||||
self.test_revenue_record()
|
self.test_revenue_record()
|
||||||
self.test_revenue_summary()
|
self.test_revenue_summary()
|
||||||
|
|
||||||
# Print Summary
|
# Print Summary
|
||||||
self.print_summary()
|
self.print_summary()
|
||||||
|
|
||||||
def test_sdk_create(self):
|
def test_sdk_create(self):
|
||||||
"""测试创建 SDK"""
|
"""测试创建 SDK"""
|
||||||
try:
|
try:
|
||||||
@@ -142,7 +139,7 @@ class TestDeveloperEcosystem:
|
|||||||
)
|
)
|
||||||
self.created_ids['sdk'].append(sdk.id)
|
self.created_ids['sdk'].append(sdk.id)
|
||||||
self.log(f"Created SDK: {sdk.name} ({sdk.id})")
|
self.log(f"Created SDK: {sdk.name} ({sdk.id})")
|
||||||
|
|
||||||
# Create JavaScript SDK
|
# Create JavaScript SDK
|
||||||
sdk_js = self.manager.create_sdk_release(
|
sdk_js = self.manager.create_sdk_release(
|
||||||
name="InsightFlow JavaScript SDK",
|
name="InsightFlow JavaScript SDK",
|
||||||
@@ -162,27 +159,27 @@ class TestDeveloperEcosystem:
|
|||||||
)
|
)
|
||||||
self.created_ids['sdk'].append(sdk_js.id)
|
self.created_ids['sdk'].append(sdk_js.id)
|
||||||
self.log(f"Created SDK: {sdk_js.name} ({sdk_js.id})")
|
self.log(f"Created SDK: {sdk_js.name} ({sdk_js.id})")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"Failed to create SDK: {str(e)}", success=False)
|
self.log(f"Failed to create SDK: {str(e)}", success=False)
|
||||||
|
|
||||||
def test_sdk_list(self):
|
def test_sdk_list(self):
|
||||||
"""测试列出 SDK"""
|
"""测试列出 SDK"""
|
||||||
try:
|
try:
|
||||||
sdks = self.manager.list_sdk_releases()
|
sdks = self.manager.list_sdk_releases()
|
||||||
self.log(f"Listed {len(sdks)} SDKs")
|
self.log(f"Listed {len(sdks)} SDKs")
|
||||||
|
|
||||||
# Test filter by language
|
# Test filter by language
|
||||||
python_sdks = self.manager.list_sdk_releases(language=SDKLanguage.PYTHON)
|
python_sdks = self.manager.list_sdk_releases(language=SDKLanguage.PYTHON)
|
||||||
self.log(f"Found {len(python_sdks)} Python SDKs")
|
self.log(f"Found {len(python_sdks)} Python SDKs")
|
||||||
|
|
||||||
# Test search
|
# Test search
|
||||||
search_results = self.manager.list_sdk_releases(search="Python")
|
search_results = self.manager.list_sdk_releases(search="Python")
|
||||||
self.log(f"Search found {len(search_results)} SDKs")
|
self.log(f"Search found {len(search_results)} SDKs")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"Failed to list SDKs: {str(e)}", success=False)
|
self.log(f"Failed to list SDKs: {str(e)}", success=False)
|
||||||
|
|
||||||
def test_sdk_get(self):
|
def test_sdk_get(self):
|
||||||
"""测试获取 SDK 详情"""
|
"""测试获取 SDK 详情"""
|
||||||
try:
|
try:
|
||||||
@@ -194,7 +191,7 @@ class TestDeveloperEcosystem:
|
|||||||
self.log("SDK not found", success=False)
|
self.log("SDK not found", success=False)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"Failed to get SDK: {str(e)}", success=False)
|
self.log(f"Failed to get SDK: {str(e)}", success=False)
|
||||||
|
|
||||||
def test_sdk_update(self):
|
def test_sdk_update(self):
|
||||||
"""测试更新 SDK"""
|
"""测试更新 SDK"""
|
||||||
try:
|
try:
|
||||||
@@ -207,7 +204,7 @@ class TestDeveloperEcosystem:
|
|||||||
self.log(f"Updated SDK: {sdk.name}")
|
self.log(f"Updated SDK: {sdk.name}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"Failed to update SDK: {str(e)}", success=False)
|
self.log(f"Failed to update SDK: {str(e)}", success=False)
|
||||||
|
|
||||||
def test_sdk_publish(self):
|
def test_sdk_publish(self):
|
||||||
"""测试发布 SDK"""
|
"""测试发布 SDK"""
|
||||||
try:
|
try:
|
||||||
@@ -217,7 +214,7 @@ class TestDeveloperEcosystem:
|
|||||||
self.log(f"Published SDK: {sdk.name} (status: {sdk.status.value})")
|
self.log(f"Published SDK: {sdk.name} (status: {sdk.status.value})")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"Failed to publish SDK: {str(e)}", success=False)
|
self.log(f"Failed to publish SDK: {str(e)}", success=False)
|
||||||
|
|
||||||
def test_sdk_version_add(self):
|
def test_sdk_version_add(self):
|
||||||
"""测试添加 SDK 版本"""
|
"""测试添加 SDK 版本"""
|
||||||
try:
|
try:
|
||||||
@@ -234,7 +231,7 @@ class TestDeveloperEcosystem:
|
|||||||
self.log(f"Added SDK version: {version.version}")
|
self.log(f"Added SDK version: {version.version}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"Failed to add SDK version: {str(e)}", success=False)
|
self.log(f"Failed to add SDK version: {str(e)}", success=False)
|
||||||
|
|
||||||
def test_template_create(self):
|
def test_template_create(self):
|
||||||
"""测试创建模板"""
|
"""测试创建模板"""
|
||||||
try:
|
try:
|
||||||
@@ -259,7 +256,7 @@ class TestDeveloperEcosystem:
|
|||||||
)
|
)
|
||||||
self.created_ids['template'].append(template.id)
|
self.created_ids['template'].append(template.id)
|
||||||
self.log(f"Created template: {template.name} ({template.id})")
|
self.log(f"Created template: {template.name} ({template.id})")
|
||||||
|
|
||||||
# Create free template
|
# Create free template
|
||||||
template_free = self.manager.create_template(
|
template_free = self.manager.create_template(
|
||||||
name="通用实体识别模板",
|
name="通用实体识别模板",
|
||||||
@@ -274,27 +271,27 @@ class TestDeveloperEcosystem:
|
|||||||
)
|
)
|
||||||
self.created_ids['template'].append(template_free.id)
|
self.created_ids['template'].append(template_free.id)
|
||||||
self.log(f"Created free template: {template_free.name}")
|
self.log(f"Created free template: {template_free.name}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"Failed to create template: {str(e)}", success=False)
|
self.log(f"Failed to create template: {str(e)}", success=False)
|
||||||
|
|
||||||
def test_template_list(self):
|
def test_template_list(self):
|
||||||
"""测试列出模板"""
|
"""测试列出模板"""
|
||||||
try:
|
try:
|
||||||
templates = self.manager.list_templates()
|
templates = self.manager.list_templates()
|
||||||
self.log(f"Listed {len(templates)} templates")
|
self.log(f"Listed {len(templates)} templates")
|
||||||
|
|
||||||
# Filter by category
|
# Filter by category
|
||||||
medical_templates = self.manager.list_templates(category=TemplateCategory.MEDICAL)
|
medical_templates = self.manager.list_templates(category=TemplateCategory.MEDICAL)
|
||||||
self.log(f"Found {len(medical_templates)} medical templates")
|
self.log(f"Found {len(medical_templates)} medical templates")
|
||||||
|
|
||||||
# Filter by price
|
# Filter by price
|
||||||
free_templates = self.manager.list_templates(max_price=0)
|
free_templates = self.manager.list_templates(max_price=0)
|
||||||
self.log(f"Found {len(free_templates)} free templates")
|
self.log(f"Found {len(free_templates)} free templates")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"Failed to list templates: {str(e)}", success=False)
|
self.log(f"Failed to list templates: {str(e)}", success=False)
|
||||||
|
|
||||||
def test_template_get(self):
|
def test_template_get(self):
|
||||||
"""测试获取模板详情"""
|
"""测试获取模板详情"""
|
||||||
try:
|
try:
|
||||||
@@ -304,7 +301,7 @@ class TestDeveloperEcosystem:
|
|||||||
self.log(f"Retrieved template: {template.name}")
|
self.log(f"Retrieved template: {template.name}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"Failed to get template: {str(e)}", success=False)
|
self.log(f"Failed to get template: {str(e)}", success=False)
|
||||||
|
|
||||||
def test_template_approve(self):
|
def test_template_approve(self):
|
||||||
"""测试审核通过模板"""
|
"""测试审核通过模板"""
|
||||||
try:
|
try:
|
||||||
@@ -317,7 +314,7 @@ class TestDeveloperEcosystem:
|
|||||||
self.log(f"Approved template: {template.name}")
|
self.log(f"Approved template: {template.name}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"Failed to approve template: {str(e)}", success=False)
|
self.log(f"Failed to approve template: {str(e)}", success=False)
|
||||||
|
|
||||||
def test_template_publish(self):
|
def test_template_publish(self):
|
||||||
"""测试发布模板"""
|
"""测试发布模板"""
|
||||||
try:
|
try:
|
||||||
@@ -327,7 +324,7 @@ class TestDeveloperEcosystem:
|
|||||||
self.log(f"Published template: {template.name}")
|
self.log(f"Published template: {template.name}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"Failed to publish template: {str(e)}", success=False)
|
self.log(f"Failed to publish template: {str(e)}", success=False)
|
||||||
|
|
||||||
def test_template_review(self):
|
def test_template_review(self):
|
||||||
"""测试添加模板评价"""
|
"""测试添加模板评价"""
|
||||||
try:
|
try:
|
||||||
@@ -343,7 +340,7 @@ class TestDeveloperEcosystem:
|
|||||||
self.log(f"Added template review: {review.rating} stars")
|
self.log(f"Added template review: {review.rating} stars")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"Failed to add template review: {str(e)}", success=False)
|
self.log(f"Failed to add template review: {str(e)}", success=False)
|
||||||
|
|
||||||
def test_plugin_create(self):
|
def test_plugin_create(self):
|
||||||
"""测试创建插件"""
|
"""测试创建插件"""
|
||||||
try:
|
try:
|
||||||
@@ -371,7 +368,7 @@ class TestDeveloperEcosystem:
|
|||||||
)
|
)
|
||||||
self.created_ids['plugin'].append(plugin.id)
|
self.created_ids['plugin'].append(plugin.id)
|
||||||
self.log(f"Created plugin: {plugin.name} ({plugin.id})")
|
self.log(f"Created plugin: {plugin.name} ({plugin.id})")
|
||||||
|
|
||||||
# Create free plugin
|
# Create free plugin
|
||||||
plugin_free = self.manager.create_plugin(
|
plugin_free = self.manager.create_plugin(
|
||||||
name="数据导出插件",
|
name="数据导出插件",
|
||||||
@@ -386,23 +383,23 @@ class TestDeveloperEcosystem:
|
|||||||
)
|
)
|
||||||
self.created_ids['plugin'].append(plugin_free.id)
|
self.created_ids['plugin'].append(plugin_free.id)
|
||||||
self.log(f"Created free plugin: {plugin_free.name}")
|
self.log(f"Created free plugin: {plugin_free.name}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"Failed to create plugin: {str(e)}", success=False)
|
self.log(f"Failed to create plugin: {str(e)}", success=False)
|
||||||
|
|
||||||
def test_plugin_list(self):
|
def test_plugin_list(self):
|
||||||
"""测试列出插件"""
|
"""测试列出插件"""
|
||||||
try:
|
try:
|
||||||
plugins = self.manager.list_plugins()
|
plugins = self.manager.list_plugins()
|
||||||
self.log(f"Listed {len(plugins)} plugins")
|
self.log(f"Listed {len(plugins)} plugins")
|
||||||
|
|
||||||
# Filter by category
|
# Filter by category
|
||||||
integration_plugins = self.manager.list_plugins(category=PluginCategory.INTEGRATION)
|
integration_plugins = self.manager.list_plugins(category=PluginCategory.INTEGRATION)
|
||||||
self.log(f"Found {len(integration_plugins)} integration plugins")
|
self.log(f"Found {len(integration_plugins)} integration plugins")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"Failed to list plugins: {str(e)}", success=False)
|
self.log(f"Failed to list plugins: {str(e)}", success=False)
|
||||||
|
|
||||||
def test_plugin_get(self):
|
def test_plugin_get(self):
|
||||||
"""测试获取插件详情"""
|
"""测试获取插件详情"""
|
||||||
try:
|
try:
|
||||||
@@ -412,7 +409,7 @@ class TestDeveloperEcosystem:
|
|||||||
self.log(f"Retrieved plugin: {plugin.name}")
|
self.log(f"Retrieved plugin: {plugin.name}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"Failed to get plugin: {str(e)}", success=False)
|
self.log(f"Failed to get plugin: {str(e)}", success=False)
|
||||||
|
|
||||||
def test_plugin_review(self):
|
def test_plugin_review(self):
|
||||||
"""测试审核插件"""
|
"""测试审核插件"""
|
||||||
try:
|
try:
|
||||||
@@ -427,7 +424,7 @@ class TestDeveloperEcosystem:
|
|||||||
self.log(f"Reviewed plugin: {plugin.name} ({plugin.status.value})")
|
self.log(f"Reviewed plugin: {plugin.name} ({plugin.status.value})")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"Failed to review plugin: {str(e)}", success=False)
|
self.log(f"Failed to review plugin: {str(e)}", success=False)
|
||||||
|
|
||||||
def test_plugin_publish(self):
|
def test_plugin_publish(self):
|
||||||
"""测试发布插件"""
|
"""测试发布插件"""
|
||||||
try:
|
try:
|
||||||
@@ -437,7 +434,7 @@ class TestDeveloperEcosystem:
|
|||||||
self.log(f"Published plugin: {plugin.name}")
|
self.log(f"Published plugin: {plugin.name}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"Failed to publish plugin: {str(e)}", success=False)
|
self.log(f"Failed to publish plugin: {str(e)}", success=False)
|
||||||
|
|
||||||
def test_plugin_review_add(self):
|
def test_plugin_review_add(self):
|
||||||
"""测试添加插件评价"""
|
"""测试添加插件评价"""
|
||||||
try:
|
try:
|
||||||
@@ -453,13 +450,13 @@ class TestDeveloperEcosystem:
|
|||||||
self.log(f"Added plugin review: {review.rating} stars")
|
self.log(f"Added plugin review: {review.rating} stars")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"Failed to add plugin review: {str(e)}", success=False)
|
self.log(f"Failed to add plugin review: {str(e)}", success=False)
|
||||||
|
|
||||||
def test_developer_profile_create(self):
|
def test_developer_profile_create(self):
|
||||||
"""测试创建开发者档案"""
|
"""测试创建开发者档案"""
|
||||||
try:
|
try:
|
||||||
# Generate unique user IDs
|
# Generate unique user IDs
|
||||||
unique_id = uuid.uuid4().hex[:8]
|
unique_id = uuid.uuid4().hex[:8]
|
||||||
|
|
||||||
profile = self.manager.create_developer_profile(
|
profile = self.manager.create_developer_profile(
|
||||||
user_id=f"user_dev_{unique_id}_001",
|
user_id=f"user_dev_{unique_id}_001",
|
||||||
display_name="张三",
|
display_name="张三",
|
||||||
@@ -471,7 +468,7 @@ class TestDeveloperEcosystem:
|
|||||||
)
|
)
|
||||||
self.created_ids['developer'].append(profile.id)
|
self.created_ids['developer'].append(profile.id)
|
||||||
self.log(f"Created developer profile: {profile.display_name} ({profile.id})")
|
self.log(f"Created developer profile: {profile.display_name} ({profile.id})")
|
||||||
|
|
||||||
# Create another developer
|
# Create another developer
|
||||||
profile2 = self.manager.create_developer_profile(
|
profile2 = self.manager.create_developer_profile(
|
||||||
user_id=f"user_dev_{unique_id}_002",
|
user_id=f"user_dev_{unique_id}_002",
|
||||||
@@ -481,10 +478,10 @@ class TestDeveloperEcosystem:
|
|||||||
)
|
)
|
||||||
self.created_ids['developer'].append(profile2.id)
|
self.created_ids['developer'].append(profile2.id)
|
||||||
self.log(f"Created developer profile: {profile2.display_name}")
|
self.log(f"Created developer profile: {profile2.display_name}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"Failed to create developer profile: {str(e)}", success=False)
|
self.log(f"Failed to create developer profile: {str(e)}", success=False)
|
||||||
|
|
||||||
def test_developer_profile_get(self):
|
def test_developer_profile_get(self):
|
||||||
"""测试获取开发者档案"""
|
"""测试获取开发者档案"""
|
||||||
try:
|
try:
|
||||||
@@ -494,7 +491,7 @@ class TestDeveloperEcosystem:
|
|||||||
self.log(f"Retrieved developer profile: {profile.display_name}")
|
self.log(f"Retrieved developer profile: {profile.display_name}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"Failed to get developer profile: {str(e)}", success=False)
|
self.log(f"Failed to get developer profile: {str(e)}", success=False)
|
||||||
|
|
||||||
def test_developer_verify(self):
|
def test_developer_verify(self):
|
||||||
"""测试验证开发者"""
|
"""测试验证开发者"""
|
||||||
try:
|
try:
|
||||||
@@ -507,7 +504,7 @@ class TestDeveloperEcosystem:
|
|||||||
self.log(f"Verified developer: {profile.display_name} ({profile.status.value})")
|
self.log(f"Verified developer: {profile.display_name} ({profile.status.value})")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"Failed to verify developer: {str(e)}", success=False)
|
self.log(f"Failed to verify developer: {str(e)}", success=False)
|
||||||
|
|
||||||
def test_developer_stats_update(self):
|
def test_developer_stats_update(self):
|
||||||
"""测试更新开发者统计"""
|
"""测试更新开发者统计"""
|
||||||
try:
|
try:
|
||||||
@@ -517,7 +514,7 @@ class TestDeveloperEcosystem:
|
|||||||
self.log(f"Updated developer stats: {profile.plugin_count} plugins, {profile.template_count} templates")
|
self.log(f"Updated developer stats: {profile.plugin_count} plugins, {profile.template_count} templates")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"Failed to update developer stats: {str(e)}", success=False)
|
self.log(f"Failed to update developer stats: {str(e)}", success=False)
|
||||||
|
|
||||||
def test_code_example_create(self):
|
def test_code_example_create(self):
|
||||||
"""测试创建代码示例"""
|
"""测试创建代码示例"""
|
||||||
try:
|
try:
|
||||||
@@ -540,7 +537,7 @@ print(f"Created project: {project.id}")
|
|||||||
)
|
)
|
||||||
self.created_ids['code_example'].append(example.id)
|
self.created_ids['code_example'].append(example.id)
|
||||||
self.log(f"Created code example: {example.title}")
|
self.log(f"Created code example: {example.title}")
|
||||||
|
|
||||||
# Create JavaScript example
|
# Create JavaScript example
|
||||||
example_js = self.manager.create_code_example(
|
example_js = self.manager.create_code_example(
|
||||||
title="使用 JavaScript SDK 上传文件",
|
title="使用 JavaScript SDK 上传文件",
|
||||||
@@ -563,23 +560,23 @@ console.log('Upload complete:', result.id);
|
|||||||
)
|
)
|
||||||
self.created_ids['code_example'].append(example_js.id)
|
self.created_ids['code_example'].append(example_js.id)
|
||||||
self.log(f"Created code example: {example_js.title}")
|
self.log(f"Created code example: {example_js.title}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"Failed to create code example: {str(e)}", success=False)
|
self.log(f"Failed to create code example: {str(e)}", success=False)
|
||||||
|
|
||||||
def test_code_example_list(self):
|
def test_code_example_list(self):
|
||||||
"""测试列出代码示例"""
|
"""测试列出代码示例"""
|
||||||
try:
|
try:
|
||||||
examples = self.manager.list_code_examples()
|
examples = self.manager.list_code_examples()
|
||||||
self.log(f"Listed {len(examples)} code examples")
|
self.log(f"Listed {len(examples)} code examples")
|
||||||
|
|
||||||
# Filter by language
|
# Filter by language
|
||||||
python_examples = self.manager.list_code_examples(language="python")
|
python_examples = self.manager.list_code_examples(language="python")
|
||||||
self.log(f"Found {len(python_examples)} Python examples")
|
self.log(f"Found {len(python_examples)} Python examples")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"Failed to list code examples: {str(e)}", success=False)
|
self.log(f"Failed to list code examples: {str(e)}", success=False)
|
||||||
|
|
||||||
def test_code_example_get(self):
|
def test_code_example_get(self):
|
||||||
"""测试获取代码示例详情"""
|
"""测试获取代码示例详情"""
|
||||||
try:
|
try:
|
||||||
@@ -589,7 +586,7 @@ console.log('Upload complete:', result.id);
|
|||||||
self.log(f"Retrieved code example: {example.title} (views: {example.view_count})")
|
self.log(f"Retrieved code example: {example.title} (views: {example.view_count})")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"Failed to get code example: {str(e)}", success=False)
|
self.log(f"Failed to get code example: {str(e)}", success=False)
|
||||||
|
|
||||||
def test_portal_config_create(self):
|
def test_portal_config_create(self):
|
||||||
"""测试创建开发者门户配置"""
|
"""测试创建开发者门户配置"""
|
||||||
try:
|
try:
|
||||||
@@ -607,10 +604,10 @@ console.log('Upload complete:', result.id);
|
|||||||
)
|
)
|
||||||
self.created_ids['portal_config'].append(config.id)
|
self.created_ids['portal_config'].append(config.id)
|
||||||
self.log(f"Created portal config: {config.name}")
|
self.log(f"Created portal config: {config.name}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"Failed to create portal config: {str(e)}", success=False)
|
self.log(f"Failed to create portal config: {str(e)}", success=False)
|
||||||
|
|
||||||
def test_portal_config_get(self):
|
def test_portal_config_get(self):
|
||||||
"""测试获取开发者门户配置"""
|
"""测试获取开发者门户配置"""
|
||||||
try:
|
try:
|
||||||
@@ -618,15 +615,15 @@ console.log('Upload complete:', result.id);
|
|||||||
config = self.manager.get_portal_config(self.created_ids['portal_config'][0])
|
config = self.manager.get_portal_config(self.created_ids['portal_config'][0])
|
||||||
if config:
|
if config:
|
||||||
self.log(f"Retrieved portal config: {config.name}")
|
self.log(f"Retrieved portal config: {config.name}")
|
||||||
|
|
||||||
# Test active config
|
# Test active config
|
||||||
active_config = self.manager.get_active_portal_config()
|
active_config = self.manager.get_active_portal_config()
|
||||||
if active_config:
|
if active_config:
|
||||||
self.log(f"Active portal config: {active_config.name}")
|
self.log(f"Active portal config: {active_config.name}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"Failed to get portal config: {str(e)}", success=False)
|
self.log(f"Failed to get portal config: {str(e)}", success=False)
|
||||||
|
|
||||||
def test_revenue_record(self):
|
def test_revenue_record(self):
|
||||||
"""测试记录开发者收益"""
|
"""测试记录开发者收益"""
|
||||||
try:
|
try:
|
||||||
@@ -646,7 +643,7 @@ console.log('Upload complete:', result.id);
|
|||||||
self.log(f" - Developer earnings: {revenue.developer_earnings}")
|
self.log(f" - Developer earnings: {revenue.developer_earnings}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"Failed to record revenue: {str(e)}", success=False)
|
self.log(f"Failed to record revenue: {str(e)}", success=False)
|
||||||
|
|
||||||
def test_revenue_summary(self):
|
def test_revenue_summary(self):
|
||||||
"""测试获取开发者收益汇总"""
|
"""测试获取开发者收益汇总"""
|
||||||
try:
|
try:
|
||||||
@@ -659,32 +656,32 @@ console.log('Upload complete:', result.id);
|
|||||||
self.log(f" - Transaction count: {summary['transaction_count']}")
|
self.log(f" - Transaction count: {summary['transaction_count']}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"Failed to get revenue summary: {str(e)}", success=False)
|
self.log(f"Failed to get revenue summary: {str(e)}", success=False)
|
||||||
|
|
||||||
def print_summary(self):
|
def print_summary(self):
|
||||||
"""打印测试摘要"""
|
"""打印测试摘要"""
|
||||||
print("\n" + "=" * 60)
|
print("\n" + "=" * 60)
|
||||||
print("Test Summary")
|
print("Test Summary")
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
|
|
||||||
total = len(self.test_results)
|
total = len(self.test_results)
|
||||||
passed = sum(1 for r in self.test_results if r['success'])
|
passed = sum(1 for r in self.test_results if r['success'])
|
||||||
failed = total - passed
|
failed = total - passed
|
||||||
|
|
||||||
print(f"Total tests: {total}")
|
print(f"Total tests: {total}")
|
||||||
print(f"Passed: {passed} ✅")
|
print(f"Passed: {passed} ✅")
|
||||||
print(f"Failed: {failed} ❌")
|
print(f"Failed: {failed} ❌")
|
||||||
|
|
||||||
if failed > 0:
|
if failed > 0:
|
||||||
print("\nFailed tests:")
|
print("\nFailed tests:")
|
||||||
for r in self.test_results:
|
for r in self.test_results:
|
||||||
if not r['success']:
|
if not r['success']:
|
||||||
print(f" - {r['message']}")
|
print(f" - {r['message']}")
|
||||||
|
|
||||||
print("\nCreated resources:")
|
print("\nCreated resources:")
|
||||||
for resource_type, ids in self.created_ids.items():
|
for resource_type, ids in self.created_ids.items():
|
||||||
if ids:
|
if ids:
|
||||||
print(f" {resource_type}: {len(ids)}")
|
print(f" {resource_type}: {len(ids)}")
|
||||||
|
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -10,9 +10,12 @@ InsightFlow Phase 8 Task 8: Operations & Monitoring Test Script
|
|||||||
4. 成本优化
|
4. 成本优化
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from ops_manager import (
|
||||||
|
get_ops_manager, AlertSeverity, AlertStatus, AlertChannelType, AlertRuleType,
|
||||||
|
ResourceType
|
||||||
|
)
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import asyncio
|
|
||||||
import json
|
import json
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
@@ -21,58 +24,53 @@ backend_dir = os.path.dirname(os.path.abspath(__file__))
|
|||||||
if backend_dir not in sys.path:
|
if backend_dir not in sys.path:
|
||||||
sys.path.insert(0, backend_dir)
|
sys.path.insert(0, backend_dir)
|
||||||
|
|
||||||
from ops_manager import (
|
|
||||||
get_ops_manager, AlertSeverity, AlertStatus, AlertChannelType, AlertRuleType,
|
|
||||||
ResourceType, ScalingAction, HealthStatus, BackupStatus
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TestOpsManager:
|
class TestOpsManager:
|
||||||
"""测试运维与监控管理器"""
|
"""测试运维与监控管理器"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.manager = get_ops_manager()
|
self.manager = get_ops_manager()
|
||||||
self.tenant_id = "test_tenant_001"
|
self.tenant_id = "test_tenant_001"
|
||||||
self.test_results = []
|
self.test_results = []
|
||||||
|
|
||||||
def log(self, message: str, success: bool = True):
|
def log(self, message: str, success: bool = True):
|
||||||
"""记录测试结果"""
|
"""记录测试结果"""
|
||||||
status = "✅" if success else "❌"
|
status = "✅" if success else "❌"
|
||||||
print(f"{status} {message}")
|
print(f"{status} {message}")
|
||||||
self.test_results.append((message, success))
|
self.test_results.append((message, success))
|
||||||
|
|
||||||
def run_all_tests(self):
|
def run_all_tests(self):
|
||||||
"""运行所有测试"""
|
"""运行所有测试"""
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
print("InsightFlow Phase 8 Task 8: Operations & Monitoring Tests")
|
print("InsightFlow Phase 8 Task 8: Operations & Monitoring Tests")
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
|
|
||||||
# 1. 告警系统测试
|
# 1. 告警系统测试
|
||||||
self.test_alert_rules()
|
self.test_alert_rules()
|
||||||
self.test_alert_channels()
|
self.test_alert_channels()
|
||||||
self.test_alerts()
|
self.test_alerts()
|
||||||
|
|
||||||
# 2. 容量规划与自动扩缩容测试
|
# 2. 容量规划与自动扩缩容测试
|
||||||
self.test_capacity_planning()
|
self.test_capacity_planning()
|
||||||
self.test_auto_scaling()
|
self.test_auto_scaling()
|
||||||
|
|
||||||
# 3. 健康检查与故障转移测试
|
# 3. 健康检查与故障转移测试
|
||||||
self.test_health_checks()
|
self.test_health_checks()
|
||||||
self.test_failover()
|
self.test_failover()
|
||||||
|
|
||||||
# 4. 备份与恢复测试
|
# 4. 备份与恢复测试
|
||||||
self.test_backup()
|
self.test_backup()
|
||||||
|
|
||||||
# 5. 成本优化测试
|
# 5. 成本优化测试
|
||||||
self.test_cost_optimization()
|
self.test_cost_optimization()
|
||||||
|
|
||||||
# 打印测试总结
|
# 打印测试总结
|
||||||
self.print_summary()
|
self.print_summary()
|
||||||
|
|
||||||
def test_alert_rules(self):
|
def test_alert_rules(self):
|
||||||
"""测试告警规则管理"""
|
"""测试告警规则管理"""
|
||||||
print("\n📋 Testing Alert Rules...")
|
print("\n📋 Testing Alert Rules...")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 创建阈值告警规则
|
# 创建阈值告警规则
|
||||||
rule1 = self.manager.create_alert_rule(
|
rule1 = self.manager.create_alert_rule(
|
||||||
@@ -92,7 +90,7 @@ class TestOpsManager:
|
|||||||
created_by="test_user"
|
created_by="test_user"
|
||||||
)
|
)
|
||||||
self.log(f"Created alert rule: {rule1.name} (ID: {rule1.id})")
|
self.log(f"Created alert rule: {rule1.name} (ID: {rule1.id})")
|
||||||
|
|
||||||
# 创建异常检测告警规则
|
# 创建异常检测告警规则
|
||||||
rule2 = self.manager.create_alert_rule(
|
rule2 = self.manager.create_alert_rule(
|
||||||
tenant_id=self.tenant_id,
|
tenant_id=self.tenant_id,
|
||||||
@@ -111,18 +109,18 @@ class TestOpsManager:
|
|||||||
created_by="test_user"
|
created_by="test_user"
|
||||||
)
|
)
|
||||||
self.log(f"Created anomaly alert rule: {rule2.name} (ID: {rule2.id})")
|
self.log(f"Created anomaly alert rule: {rule2.name} (ID: {rule2.id})")
|
||||||
|
|
||||||
# 获取告警规则
|
# 获取告警规则
|
||||||
fetched_rule = self.manager.get_alert_rule(rule1.id)
|
fetched_rule = self.manager.get_alert_rule(rule1.id)
|
||||||
assert fetched_rule is not None
|
assert fetched_rule is not None
|
||||||
assert fetched_rule.name == rule1.name
|
assert fetched_rule.name == rule1.name
|
||||||
self.log(f"Fetched alert rule: {fetched_rule.name}")
|
self.log(f"Fetched alert rule: {fetched_rule.name}")
|
||||||
|
|
||||||
# 列出租户的所有告警规则
|
# 列出租户的所有告警规则
|
||||||
rules = self.manager.list_alert_rules(self.tenant_id)
|
rules = self.manager.list_alert_rules(self.tenant_id)
|
||||||
assert len(rules) >= 2
|
assert len(rules) >= 2
|
||||||
self.log(f"Listed {len(rules)} alert rules for tenant")
|
self.log(f"Listed {len(rules)} alert rules for tenant")
|
||||||
|
|
||||||
# 更新告警规则
|
# 更新告警规则
|
||||||
updated_rule = self.manager.update_alert_rule(
|
updated_rule = self.manager.update_alert_rule(
|
||||||
rule1.id,
|
rule1.id,
|
||||||
@@ -131,19 +129,19 @@ class TestOpsManager:
|
|||||||
)
|
)
|
||||||
assert updated_rule.threshold == 85.0
|
assert updated_rule.threshold == 85.0
|
||||||
self.log(f"Updated alert rule threshold to {updated_rule.threshold}")
|
self.log(f"Updated alert rule threshold to {updated_rule.threshold}")
|
||||||
|
|
||||||
# 测试完成,清理
|
# 测试完成,清理
|
||||||
self.manager.delete_alert_rule(rule1.id)
|
self.manager.delete_alert_rule(rule1.id)
|
||||||
self.manager.delete_alert_rule(rule2.id)
|
self.manager.delete_alert_rule(rule2.id)
|
||||||
self.log("Deleted test alert rules")
|
self.log("Deleted test alert rules")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"Alert rules test failed: {e}", success=False)
|
self.log(f"Alert rules test failed: {e}", success=False)
|
||||||
|
|
||||||
def test_alert_channels(self):
|
def test_alert_channels(self):
|
||||||
"""测试告警渠道管理"""
|
"""测试告警渠道管理"""
|
||||||
print("\n📢 Testing Alert Channels...")
|
print("\n📢 Testing Alert Channels...")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 创建飞书告警渠道
|
# 创建飞书告警渠道
|
||||||
channel1 = self.manager.create_alert_channel(
|
channel1 = self.manager.create_alert_channel(
|
||||||
@@ -157,7 +155,7 @@ class TestOpsManager:
|
|||||||
severity_filter=["p0", "p1"]
|
severity_filter=["p0", "p1"]
|
||||||
)
|
)
|
||||||
self.log(f"Created Feishu channel: {channel1.name} (ID: {channel1.id})")
|
self.log(f"Created Feishu channel: {channel1.name} (ID: {channel1.id})")
|
||||||
|
|
||||||
# 创建钉钉告警渠道
|
# 创建钉钉告警渠道
|
||||||
channel2 = self.manager.create_alert_channel(
|
channel2 = self.manager.create_alert_channel(
|
||||||
tenant_id=self.tenant_id,
|
tenant_id=self.tenant_id,
|
||||||
@@ -170,7 +168,7 @@ class TestOpsManager:
|
|||||||
severity_filter=["p0", "p1", "p2"]
|
severity_filter=["p0", "p1", "p2"]
|
||||||
)
|
)
|
||||||
self.log(f"Created DingTalk channel: {channel2.name} (ID: {channel2.id})")
|
self.log(f"Created DingTalk channel: {channel2.name} (ID: {channel2.id})")
|
||||||
|
|
||||||
# 创建 Slack 告警渠道
|
# 创建 Slack 告警渠道
|
||||||
channel3 = self.manager.create_alert_channel(
|
channel3 = self.manager.create_alert_channel(
|
||||||
tenant_id=self.tenant_id,
|
tenant_id=self.tenant_id,
|
||||||
@@ -182,18 +180,18 @@ class TestOpsManager:
|
|||||||
severity_filter=["p0", "p1", "p2", "p3"]
|
severity_filter=["p0", "p1", "p2", "p3"]
|
||||||
)
|
)
|
||||||
self.log(f"Created Slack channel: {channel3.name} (ID: {channel3.id})")
|
self.log(f"Created Slack channel: {channel3.name} (ID: {channel3.id})")
|
||||||
|
|
||||||
# 获取告警渠道
|
# 获取告警渠道
|
||||||
fetched_channel = self.manager.get_alert_channel(channel1.id)
|
fetched_channel = self.manager.get_alert_channel(channel1.id)
|
||||||
assert fetched_channel is not None
|
assert fetched_channel is not None
|
||||||
assert fetched_channel.name == channel1.name
|
assert fetched_channel.name == channel1.name
|
||||||
self.log(f"Fetched alert channel: {fetched_channel.name}")
|
self.log(f"Fetched alert channel: {fetched_channel.name}")
|
||||||
|
|
||||||
# 列出租户的所有告警渠道
|
# 列出租户的所有告警渠道
|
||||||
channels = self.manager.list_alert_channels(self.tenant_id)
|
channels = self.manager.list_alert_channels(self.tenant_id)
|
||||||
assert len(channels) >= 3
|
assert len(channels) >= 3
|
||||||
self.log(f"Listed {len(channels)} alert channels for tenant")
|
self.log(f"Listed {len(channels)} alert channels for tenant")
|
||||||
|
|
||||||
# 清理
|
# 清理
|
||||||
for channel in channels:
|
for channel in channels:
|
||||||
if channel.tenant_id == self.tenant_id:
|
if channel.tenant_id == self.tenant_id:
|
||||||
@@ -201,14 +199,14 @@ class TestOpsManager:
|
|||||||
conn.execute("DELETE FROM alert_channels WHERE id = ?", (channel.id,))
|
conn.execute("DELETE FROM alert_channels WHERE id = ?", (channel.id,))
|
||||||
conn.commit()
|
conn.commit()
|
||||||
self.log("Deleted test alert channels")
|
self.log("Deleted test alert channels")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"Alert channels test failed: {e}", success=False)
|
self.log(f"Alert channels test failed: {e}", success=False)
|
||||||
|
|
||||||
def test_alerts(self):
|
def test_alerts(self):
|
||||||
"""测试告警管理"""
|
"""测试告警管理"""
|
||||||
print("\n🚨 Testing Alerts...")
|
print("\n🚨 Testing Alerts...")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 创建告警规则
|
# 创建告警规则
|
||||||
rule = self.manager.create_alert_rule(
|
rule = self.manager.create_alert_rule(
|
||||||
@@ -227,7 +225,7 @@ class TestOpsManager:
|
|||||||
annotations={},
|
annotations={},
|
||||||
created_by="test_user"
|
created_by="test_user"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 记录资源指标
|
# 记录资源指标
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
self.manager.record_resource_metric(
|
self.manager.record_resource_metric(
|
||||||
@@ -240,12 +238,12 @@ class TestOpsManager:
|
|||||||
metadata={"region": "cn-north-1"}
|
metadata={"region": "cn-north-1"}
|
||||||
)
|
)
|
||||||
self.log("Recorded 10 resource metrics")
|
self.log("Recorded 10 resource metrics")
|
||||||
|
|
||||||
# 手动创建告警
|
# 手动创建告警
|
||||||
from ops_manager import Alert
|
from ops_manager import Alert
|
||||||
alert_id = f"test_alert_{datetime.now().strftime('%Y%m%d%H%M%S')}"
|
alert_id = f"test_alert_{datetime.now().strftime('%Y%m%d%H%M%S')}"
|
||||||
now = datetime.now().isoformat()
|
now = datetime.now().isoformat()
|
||||||
|
|
||||||
alert = Alert(
|
alert = Alert(
|
||||||
id=alert_id,
|
id=alert_id,
|
||||||
rule_id=rule.id,
|
rule_id=rule.id,
|
||||||
@@ -266,10 +264,10 @@ class TestOpsManager:
|
|||||||
notification_sent={},
|
notification_sent={},
|
||||||
suppression_count=0
|
suppression_count=0
|
||||||
)
|
)
|
||||||
|
|
||||||
with self.manager._get_db() as conn:
|
with self.manager._get_db() as conn:
|
||||||
conn.execute("""
|
conn.execute("""
|
||||||
INSERT INTO alerts
|
INSERT INTO alerts
|
||||||
(id, rule_id, tenant_id, severity, status, title, description,
|
(id, rule_id, tenant_id, severity, status, title, description,
|
||||||
metric, value, threshold, labels, annotations, started_at, notification_sent, suppression_count)
|
metric, value, threshold, labels, annotations, started_at, notification_sent, suppression_count)
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
@@ -279,28 +277,28 @@ class TestOpsManager:
|
|||||||
json.dumps(alert.labels), json.dumps(alert.annotations),
|
json.dumps(alert.labels), json.dumps(alert.annotations),
|
||||||
alert.started_at, json.dumps(alert.notification_sent), alert.suppression_count))
|
alert.started_at, json.dumps(alert.notification_sent), alert.suppression_count))
|
||||||
conn.commit()
|
conn.commit()
|
||||||
|
|
||||||
self.log(f"Created test alert: {alert.id}")
|
self.log(f"Created test alert: {alert.id}")
|
||||||
|
|
||||||
# 列出租户的告警
|
# 列出租户的告警
|
||||||
alerts = self.manager.list_alerts(self.tenant_id)
|
alerts = self.manager.list_alerts(self.tenant_id)
|
||||||
assert len(alerts) >= 1
|
assert len(alerts) >= 1
|
||||||
self.log(f"Listed {len(alerts)} alerts for tenant")
|
self.log(f"Listed {len(alerts)} alerts for tenant")
|
||||||
|
|
||||||
# 确认告警
|
# 确认告警
|
||||||
self.manager.acknowledge_alert(alert_id, "test_user")
|
self.manager.acknowledge_alert(alert_id, "test_user")
|
||||||
fetched_alert = self.manager.get_alert(alert_id)
|
fetched_alert = self.manager.get_alert(alert_id)
|
||||||
assert fetched_alert.status == AlertStatus.ACKNOWLEDGED
|
assert fetched_alert.status == AlertStatus.ACKNOWLEDGED
|
||||||
assert fetched_alert.acknowledged_by == "test_user"
|
assert fetched_alert.acknowledged_by == "test_user"
|
||||||
self.log(f"Acknowledged alert: {alert_id}")
|
self.log(f"Acknowledged alert: {alert_id}")
|
||||||
|
|
||||||
# 解决告警
|
# 解决告警
|
||||||
self.manager.resolve_alert(alert_id)
|
self.manager.resolve_alert(alert_id)
|
||||||
fetched_alert = self.manager.get_alert(alert_id)
|
fetched_alert = self.manager.get_alert(alert_id)
|
||||||
assert fetched_alert.status == AlertStatus.RESOLVED
|
assert fetched_alert.status == AlertStatus.RESOLVED
|
||||||
assert fetched_alert.resolved_at is not None
|
assert fetched_alert.resolved_at is not None
|
||||||
self.log(f"Resolved alert: {alert_id}")
|
self.log(f"Resolved alert: {alert_id}")
|
||||||
|
|
||||||
# 清理
|
# 清理
|
||||||
self.manager.delete_alert_rule(rule.id)
|
self.manager.delete_alert_rule(rule.id)
|
||||||
with self.manager._get_db() as conn:
|
with self.manager._get_db() as conn:
|
||||||
@@ -308,14 +306,14 @@ class TestOpsManager:
|
|||||||
conn.execute("DELETE FROM resource_metrics WHERE tenant_id = ?", (self.tenant_id,))
|
conn.execute("DELETE FROM resource_metrics WHERE tenant_id = ?", (self.tenant_id,))
|
||||||
conn.commit()
|
conn.commit()
|
||||||
self.log("Cleaned up test data")
|
self.log("Cleaned up test data")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"Alerts test failed: {e}", success=False)
|
self.log(f"Alerts test failed: {e}", success=False)
|
||||||
|
|
||||||
def test_capacity_planning(self):
|
def test_capacity_planning(self):
|
||||||
"""测试容量规划"""
|
"""测试容量规划"""
|
||||||
print("\n📊 Testing Capacity Planning...")
|
print("\n📊 Testing Capacity Planning...")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 记录历史指标数据
|
# 记录历史指标数据
|
||||||
import random
|
import random
|
||||||
@@ -324,15 +322,15 @@ class TestOpsManager:
|
|||||||
timestamp = (base_time + timedelta(days=i)).isoformat()
|
timestamp = (base_time + timedelta(days=i)).isoformat()
|
||||||
with self.manager._get_db() as conn:
|
with self.manager._get_db() as conn:
|
||||||
conn.execute("""
|
conn.execute("""
|
||||||
INSERT INTO resource_metrics
|
INSERT INTO resource_metrics
|
||||||
(id, tenant_id, resource_type, resource_id, metric_name, metric_value, unit, timestamp)
|
(id, tenant_id, resource_type, resource_id, metric_name, metric_value, unit, timestamp)
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
""", (f"cm_{i}", self.tenant_id, ResourceType.CPU.value, "server-001",
|
""", (f"cm_{i}", self.tenant_id, ResourceType.CPU.value, "server-001",
|
||||||
"cpu_usage_percent", 50.0 + random.random() * 30, "percent", timestamp))
|
"cpu_usage_percent", 50.0 + random.random() * 30, "percent", timestamp))
|
||||||
conn.commit()
|
conn.commit()
|
||||||
|
|
||||||
self.log("Recorded 30 days of historical metrics")
|
self.log("Recorded 30 days of historical metrics")
|
||||||
|
|
||||||
# 创建容量规划
|
# 创建容量规划
|
||||||
prediction_date = (datetime.now() + timedelta(days=30)).strftime("%Y-%m-%d")
|
prediction_date = (datetime.now() + timedelta(days=30)).strftime("%Y-%m-%d")
|
||||||
plan = self.manager.create_capacity_plan(
|
plan = self.manager.create_capacity_plan(
|
||||||
@@ -342,31 +340,31 @@ class TestOpsManager:
|
|||||||
prediction_date=prediction_date,
|
prediction_date=prediction_date,
|
||||||
confidence=0.85
|
confidence=0.85
|
||||||
)
|
)
|
||||||
|
|
||||||
self.log(f"Created capacity plan: {plan.id}")
|
self.log(f"Created capacity plan: {plan.id}")
|
||||||
self.log(f" Current capacity: {plan.current_capacity}")
|
self.log(f" Current capacity: {plan.current_capacity}")
|
||||||
self.log(f" Predicted capacity: {plan.predicted_capacity}")
|
self.log(f" Predicted capacity: {plan.predicted_capacity}")
|
||||||
self.log(f" Recommended action: {plan.recommended_action}")
|
self.log(f" Recommended action: {plan.recommended_action}")
|
||||||
|
|
||||||
# 获取容量规划列表
|
# 获取容量规划列表
|
||||||
plans = self.manager.get_capacity_plans(self.tenant_id)
|
plans = self.manager.get_capacity_plans(self.tenant_id)
|
||||||
assert len(plans) >= 1
|
assert len(plans) >= 1
|
||||||
self.log(f"Listed {len(plans)} capacity plans")
|
self.log(f"Listed {len(plans)} capacity plans")
|
||||||
|
|
||||||
# 清理
|
# 清理
|
||||||
with self.manager._get_db() as conn:
|
with self.manager._get_db() as conn:
|
||||||
conn.execute("DELETE FROM capacity_plans WHERE tenant_id = ?", (self.tenant_id,))
|
conn.execute("DELETE FROM capacity_plans WHERE tenant_id = ?", (self.tenant_id,))
|
||||||
conn.execute("DELETE FROM resource_metrics WHERE tenant_id = ?", (self.tenant_id,))
|
conn.execute("DELETE FROM resource_metrics WHERE tenant_id = ?", (self.tenant_id,))
|
||||||
conn.commit()
|
conn.commit()
|
||||||
self.log("Cleaned up capacity planning test data")
|
self.log("Cleaned up capacity planning test data")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"Capacity planning test failed: {e}", success=False)
|
self.log(f"Capacity planning test failed: {e}", success=False)
|
||||||
|
|
||||||
def test_auto_scaling(self):
|
def test_auto_scaling(self):
|
||||||
"""测试自动扩缩容"""
|
"""测试自动扩缩容"""
|
||||||
print("\n⚖️ Testing Auto Scaling...")
|
print("\n⚖️ Testing Auto Scaling...")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 创建自动扩缩容策略
|
# 创建自动扩缩容策略
|
||||||
policy = self.manager.create_auto_scaling_policy(
|
policy = self.manager.create_auto_scaling_policy(
|
||||||
@@ -382,49 +380,49 @@ class TestOpsManager:
|
|||||||
scale_down_step=1,
|
scale_down_step=1,
|
||||||
cooldown_period=300
|
cooldown_period=300
|
||||||
)
|
)
|
||||||
|
|
||||||
self.log(f"Created auto scaling policy: {policy.name} (ID: {policy.id})")
|
self.log(f"Created auto scaling policy: {policy.name} (ID: {policy.id})")
|
||||||
self.log(f" Min instances: {policy.min_instances}")
|
self.log(f" Min instances: {policy.min_instances}")
|
||||||
self.log(f" Max instances: {policy.max_instances}")
|
self.log(f" Max instances: {policy.max_instances}")
|
||||||
self.log(f" Target utilization: {policy.target_utilization}")
|
self.log(f" Target utilization: {policy.target_utilization}")
|
||||||
|
|
||||||
# 获取策略列表
|
# 获取策略列表
|
||||||
policies = self.manager.list_auto_scaling_policies(self.tenant_id)
|
policies = self.manager.list_auto_scaling_policies(self.tenant_id)
|
||||||
assert len(policies) >= 1
|
assert len(policies) >= 1
|
||||||
self.log(f"Listed {len(policies)} auto scaling policies")
|
self.log(f"Listed {len(policies)} auto scaling policies")
|
||||||
|
|
||||||
# 模拟扩缩容评估
|
# 模拟扩缩容评估
|
||||||
event = self.manager.evaluate_scaling_policy(
|
event = self.manager.evaluate_scaling_policy(
|
||||||
policy_id=policy.id,
|
policy_id=policy.id,
|
||||||
current_instances=3,
|
current_instances=3,
|
||||||
current_utilization=0.85
|
current_utilization=0.85
|
||||||
)
|
)
|
||||||
|
|
||||||
if event:
|
if event:
|
||||||
self.log(f"Scaling event triggered: {event.action.value}")
|
self.log(f"Scaling event triggered: {event.action.value}")
|
||||||
self.log(f" From {event.from_count} to {event.to_count} instances")
|
self.log(f" From {event.from_count} to {event.to_count} instances")
|
||||||
self.log(f" Reason: {event.reason}")
|
self.log(f" Reason: {event.reason}")
|
||||||
else:
|
else:
|
||||||
self.log("No scaling action needed")
|
self.log("No scaling action needed")
|
||||||
|
|
||||||
# 获取扩缩容事件列表
|
# 获取扩缩容事件列表
|
||||||
events = self.manager.list_scaling_events(self.tenant_id)
|
events = self.manager.list_scaling_events(self.tenant_id)
|
||||||
self.log(f"Listed {len(events)} scaling events")
|
self.log(f"Listed {len(events)} scaling events")
|
||||||
|
|
||||||
# 清理
|
# 清理
|
||||||
with self.manager._get_db() as conn:
|
with self.manager._get_db() as conn:
|
||||||
conn.execute("DELETE FROM scaling_events WHERE tenant_id = ?", (self.tenant_id,))
|
conn.execute("DELETE FROM scaling_events WHERE tenant_id = ?", (self.tenant_id,))
|
||||||
conn.execute("DELETE FROM auto_scaling_policies WHERE tenant_id = ?", (self.tenant_id,))
|
conn.execute("DELETE FROM auto_scaling_policies WHERE tenant_id = ?", (self.tenant_id,))
|
||||||
conn.commit()
|
conn.commit()
|
||||||
self.log("Cleaned up auto scaling test data")
|
self.log("Cleaned up auto scaling test data")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"Auto scaling test failed: {e}", success=False)
|
self.log(f"Auto scaling test failed: {e}", success=False)
|
||||||
|
|
||||||
def test_health_checks(self):
|
def test_health_checks(self):
|
||||||
"""测试健康检查"""
|
"""测试健康检查"""
|
||||||
print("\n💓 Testing Health Checks...")
|
print("\n💓 Testing Health Checks...")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 创建 HTTP 健康检查
|
# 创建 HTTP 健康检查
|
||||||
check1 = self.manager.create_health_check(
|
check1 = self.manager.create_health_check(
|
||||||
@@ -442,7 +440,7 @@ class TestOpsManager:
|
|||||||
retry_count=3
|
retry_count=3
|
||||||
)
|
)
|
||||||
self.log(f"Created HTTP health check: {check1.name} (ID: {check1.id})")
|
self.log(f"Created HTTP health check: {check1.name} (ID: {check1.id})")
|
||||||
|
|
||||||
# 创建 TCP 健康检查
|
# 创建 TCP 健康检查
|
||||||
check2 = self.manager.create_health_check(
|
check2 = self.manager.create_health_check(
|
||||||
tenant_id=self.tenant_id,
|
tenant_id=self.tenant_id,
|
||||||
@@ -459,33 +457,33 @@ class TestOpsManager:
|
|||||||
retry_count=2
|
retry_count=2
|
||||||
)
|
)
|
||||||
self.log(f"Created TCP health check: {check2.name} (ID: {check2.id})")
|
self.log(f"Created TCP health check: {check2.name} (ID: {check2.id})")
|
||||||
|
|
||||||
# 获取健康检查列表
|
# 获取健康检查列表
|
||||||
checks = self.manager.list_health_checks(self.tenant_id)
|
checks = self.manager.list_health_checks(self.tenant_id)
|
||||||
assert len(checks) >= 2
|
assert len(checks) >= 2
|
||||||
self.log(f"Listed {len(checks)} health checks")
|
self.log(f"Listed {len(checks)} health checks")
|
||||||
|
|
||||||
# 执行健康检查(异步)
|
# 执行健康检查(异步)
|
||||||
async def run_health_check():
|
async def run_health_check():
|
||||||
result = await self.manager.execute_health_check(check1.id)
|
result = await self.manager.execute_health_check(check1.id)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
# 由于健康检查需要网络,这里只验证方法存在
|
# 由于健康检查需要网络,这里只验证方法存在
|
||||||
self.log("Health check execution method verified")
|
self.log("Health check execution method verified")
|
||||||
|
|
||||||
# 清理
|
# 清理
|
||||||
with self.manager._get_db() as conn:
|
with self.manager._get_db() as conn:
|
||||||
conn.execute("DELETE FROM health_checks WHERE tenant_id = ?", (self.tenant_id,))
|
conn.execute("DELETE FROM health_checks WHERE tenant_id = ?", (self.tenant_id,))
|
||||||
conn.commit()
|
conn.commit()
|
||||||
self.log("Cleaned up health check test data")
|
self.log("Cleaned up health check test data")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"Health checks test failed: {e}", success=False)
|
self.log(f"Health checks test failed: {e}", success=False)
|
||||||
|
|
||||||
def test_failover(self):
|
def test_failover(self):
|
||||||
"""测试故障转移"""
|
"""测试故障转移"""
|
||||||
print("\n🔄 Testing Failover...")
|
print("\n🔄 Testing Failover...")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 创建故障转移配置
|
# 创建故障转移配置
|
||||||
config = self.manager.create_failover_config(
|
config = self.manager.create_failover_config(
|
||||||
@@ -498,51 +496,51 @@ class TestOpsManager:
|
|||||||
failover_timeout=300,
|
failover_timeout=300,
|
||||||
health_check_id=None
|
health_check_id=None
|
||||||
)
|
)
|
||||||
|
|
||||||
self.log(f"Created failover config: {config.name} (ID: {config.id})")
|
self.log(f"Created failover config: {config.name} (ID: {config.id})")
|
||||||
self.log(f" Primary region: {config.primary_region}")
|
self.log(f" Primary region: {config.primary_region}")
|
||||||
self.log(f" Secondary regions: {config.secondary_regions}")
|
self.log(f" Secondary regions: {config.secondary_regions}")
|
||||||
|
|
||||||
# 获取故障转移配置列表
|
# 获取故障转移配置列表
|
||||||
configs = self.manager.list_failover_configs(self.tenant_id)
|
configs = self.manager.list_failover_configs(self.tenant_id)
|
||||||
assert len(configs) >= 1
|
assert len(configs) >= 1
|
||||||
self.log(f"Listed {len(configs)} failover configs")
|
self.log(f"Listed {len(configs)} failover configs")
|
||||||
|
|
||||||
# 发起故障转移
|
# 发起故障转移
|
||||||
event = self.manager.initiate_failover(
|
event = self.manager.initiate_failover(
|
||||||
config_id=config.id,
|
config_id=config.id,
|
||||||
reason="Primary region health check failed"
|
reason="Primary region health check failed"
|
||||||
)
|
)
|
||||||
|
|
||||||
if event:
|
if event:
|
||||||
self.log(f"Initiated failover: {event.id}")
|
self.log(f"Initiated failover: {event.id}")
|
||||||
self.log(f" From: {event.from_region}")
|
self.log(f" From: {event.from_region}")
|
||||||
self.log(f" To: {event.to_region}")
|
self.log(f" To: {event.to_region}")
|
||||||
|
|
||||||
# 更新故障转移状态
|
# 更新故障转移状态
|
||||||
self.manager.update_failover_status(event.id, "completed")
|
self.manager.update_failover_status(event.id, "completed")
|
||||||
updated_event = self.manager.get_failover_event(event.id)
|
updated_event = self.manager.get_failover_event(event.id)
|
||||||
assert updated_event.status == "completed"
|
assert updated_event.status == "completed"
|
||||||
self.log(f"Failover completed")
|
self.log(f"Failover completed")
|
||||||
|
|
||||||
# 获取故障转移事件列表
|
# 获取故障转移事件列表
|
||||||
events = self.manager.list_failover_events(self.tenant_id)
|
events = self.manager.list_failover_events(self.tenant_id)
|
||||||
self.log(f"Listed {len(events)} failover events")
|
self.log(f"Listed {len(events)} failover events")
|
||||||
|
|
||||||
# 清理
|
# 清理
|
||||||
with self.manager._get_db() as conn:
|
with self.manager._get_db() as conn:
|
||||||
conn.execute("DELETE FROM failover_events WHERE tenant_id = ?", (self.tenant_id,))
|
conn.execute("DELETE FROM failover_events WHERE tenant_id = ?", (self.tenant_id,))
|
||||||
conn.execute("DELETE FROM failover_configs WHERE tenant_id = ?", (self.tenant_id,))
|
conn.execute("DELETE FROM failover_configs WHERE tenant_id = ?", (self.tenant_id,))
|
||||||
conn.commit()
|
conn.commit()
|
||||||
self.log("Cleaned up failover test data")
|
self.log("Cleaned up failover test data")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"Failover test failed: {e}", success=False)
|
self.log(f"Failover test failed: {e}", success=False)
|
||||||
|
|
||||||
def test_backup(self):
|
def test_backup(self):
|
||||||
"""测试备份与恢复"""
|
"""测试备份与恢复"""
|
||||||
print("\n💾 Testing Backup & Recovery...")
|
print("\n💾 Testing Backup & Recovery...")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 创建备份任务
|
# 创建备份任务
|
||||||
job = self.manager.create_backup_job(
|
job = self.manager.create_backup_job(
|
||||||
@@ -557,51 +555,51 @@ class TestOpsManager:
|
|||||||
compression_enabled=True,
|
compression_enabled=True,
|
||||||
storage_location="s3://insightflow-backups/"
|
storage_location="s3://insightflow-backups/"
|
||||||
)
|
)
|
||||||
|
|
||||||
self.log(f"Created backup job: {job.name} (ID: {job.id})")
|
self.log(f"Created backup job: {job.name} (ID: {job.id})")
|
||||||
self.log(f" Schedule: {job.schedule}")
|
self.log(f" Schedule: {job.schedule}")
|
||||||
self.log(f" Retention: {job.retention_days} days")
|
self.log(f" Retention: {job.retention_days} days")
|
||||||
|
|
||||||
# 获取备份任务列表
|
# 获取备份任务列表
|
||||||
jobs = self.manager.list_backup_jobs(self.tenant_id)
|
jobs = self.manager.list_backup_jobs(self.tenant_id)
|
||||||
assert len(jobs) >= 1
|
assert len(jobs) >= 1
|
||||||
self.log(f"Listed {len(jobs)} backup jobs")
|
self.log(f"Listed {len(jobs)} backup jobs")
|
||||||
|
|
||||||
# 执行备份
|
# 执行备份
|
||||||
record = self.manager.execute_backup(job.id)
|
record = self.manager.execute_backup(job.id)
|
||||||
|
|
||||||
if record:
|
if record:
|
||||||
self.log(f"Executed backup: {record.id}")
|
self.log(f"Executed backup: {record.id}")
|
||||||
self.log(f" Status: {record.status.value}")
|
self.log(f" Status: {record.status.value}")
|
||||||
self.log(f" Storage: {record.storage_path}")
|
self.log(f" Storage: {record.storage_path}")
|
||||||
|
|
||||||
# 获取备份记录列表
|
# 获取备份记录列表
|
||||||
records = self.manager.list_backup_records(self.tenant_id)
|
records = self.manager.list_backup_records(self.tenant_id)
|
||||||
self.log(f"Listed {len(records)} backup records")
|
self.log(f"Listed {len(records)} backup records")
|
||||||
|
|
||||||
# 测试恢复(模拟)
|
# 测试恢复(模拟)
|
||||||
restore_result = self.manager.restore_from_backup(record.id)
|
restore_result = self.manager.restore_from_backup(record.id)
|
||||||
self.log(f"Restore test result: {restore_result}")
|
self.log(f"Restore test result: {restore_result}")
|
||||||
|
|
||||||
# 清理
|
# 清理
|
||||||
with self.manager._get_db() as conn:
|
with self.manager._get_db() as conn:
|
||||||
conn.execute("DELETE FROM backup_records WHERE tenant_id = ?", (self.tenant_id,))
|
conn.execute("DELETE FROM backup_records WHERE tenant_id = ?", (self.tenant_id,))
|
||||||
conn.execute("DELETE FROM backup_jobs WHERE tenant_id = ?", (self.tenant_id,))
|
conn.execute("DELETE FROM backup_jobs WHERE tenant_id = ?", (self.tenant_id,))
|
||||||
conn.commit()
|
conn.commit()
|
||||||
self.log("Cleaned up backup test data")
|
self.log("Cleaned up backup test data")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"Backup test failed: {e}", success=False)
|
self.log(f"Backup test failed: {e}", success=False)
|
||||||
|
|
||||||
def test_cost_optimization(self):
|
def test_cost_optimization(self):
|
||||||
"""测试成本优化"""
|
"""测试成本优化"""
|
||||||
print("\n💰 Testing Cost Optimization...")
|
print("\n💰 Testing Cost Optimization...")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 记录资源利用率数据
|
# 记录资源利用率数据
|
||||||
import random
|
import random
|
||||||
report_date = datetime.now().strftime("%Y-%m-%d")
|
report_date = datetime.now().strftime("%Y-%m-%d")
|
||||||
|
|
||||||
for i in range(5):
|
for i in range(5):
|
||||||
self.manager.record_resource_utilization(
|
self.manager.record_resource_utilization(
|
||||||
tenant_id=self.tenant_id,
|
tenant_id=self.tenant_id,
|
||||||
@@ -614,9 +612,9 @@ class TestOpsManager:
|
|||||||
report_date=report_date,
|
report_date=report_date,
|
||||||
recommendations=["Consider downsizing this resource"]
|
recommendations=["Consider downsizing this resource"]
|
||||||
)
|
)
|
||||||
|
|
||||||
self.log("Recorded 5 resource utilization records")
|
self.log("Recorded 5 resource utilization records")
|
||||||
|
|
||||||
# 生成成本报告
|
# 生成成本报告
|
||||||
now = datetime.now()
|
now = datetime.now()
|
||||||
report = self.manager.generate_cost_report(
|
report = self.manager.generate_cost_report(
|
||||||
@@ -624,35 +622,38 @@ class TestOpsManager:
|
|||||||
year=now.year,
|
year=now.year,
|
||||||
month=now.month
|
month=now.month
|
||||||
)
|
)
|
||||||
|
|
||||||
self.log(f"Generated cost report: {report.id}")
|
self.log(f"Generated cost report: {report.id}")
|
||||||
self.log(f" Period: {report.report_period}")
|
self.log(f" Period: {report.report_period}")
|
||||||
self.log(f" Total cost: {report.total_cost} {report.currency}")
|
self.log(f" Total cost: {report.total_cost} {report.currency}")
|
||||||
self.log(f" Anomalies detected: {len(report.anomalies)}")
|
self.log(f" Anomalies detected: {len(report.anomalies)}")
|
||||||
|
|
||||||
# 检测闲置资源
|
# 检测闲置资源
|
||||||
idle_resources = self.manager.detect_idle_resources(self.tenant_id)
|
idle_resources = self.manager.detect_idle_resources(self.tenant_id)
|
||||||
self.log(f"Detected {len(idle_resources)} idle resources")
|
self.log(f"Detected {len(idle_resources)} idle resources")
|
||||||
|
|
||||||
# 获取闲置资源列表
|
# 获取闲置资源列表
|
||||||
idle_list = self.manager.get_idle_resources(self.tenant_id)
|
idle_list = self.manager.get_idle_resources(self.tenant_id)
|
||||||
for resource in idle_list:
|
for resource in idle_list:
|
||||||
self.log(f" Idle resource: {resource.resource_name} (est. cost: {resource.estimated_monthly_cost}/month)")
|
self.log(
|
||||||
|
f" Idle resource: {
|
||||||
|
resource.resource_name} (est. cost: {
|
||||||
|
resource.estimated_monthly_cost}/month)")
|
||||||
|
|
||||||
# 生成成本优化建议
|
# 生成成本优化建议
|
||||||
suggestions = self.manager.generate_cost_optimization_suggestions(self.tenant_id)
|
suggestions = self.manager.generate_cost_optimization_suggestions(self.tenant_id)
|
||||||
self.log(f"Generated {len(suggestions)} cost optimization suggestions")
|
self.log(f"Generated {len(suggestions)} cost optimization suggestions")
|
||||||
|
|
||||||
for suggestion in suggestions:
|
for suggestion in suggestions:
|
||||||
self.log(f" Suggestion: {suggestion.title}")
|
self.log(f" Suggestion: {suggestion.title}")
|
||||||
self.log(f" Potential savings: {suggestion.potential_savings} {suggestion.currency}")
|
self.log(f" Potential savings: {suggestion.potential_savings} {suggestion.currency}")
|
||||||
self.log(f" Confidence: {suggestion.confidence}")
|
self.log(f" Confidence: {suggestion.confidence}")
|
||||||
self.log(f" Difficulty: {suggestion.difficulty}")
|
self.log(f" Difficulty: {suggestion.difficulty}")
|
||||||
|
|
||||||
# 获取优化建议列表
|
# 获取优化建议列表
|
||||||
all_suggestions = self.manager.get_cost_optimization_suggestions(self.tenant_id)
|
all_suggestions = self.manager.get_cost_optimization_suggestions(self.tenant_id)
|
||||||
self.log(f"Listed {len(all_suggestions)} optimization suggestions")
|
self.log(f"Listed {len(all_suggestions)} optimization suggestions")
|
||||||
|
|
||||||
# 应用优化建议
|
# 应用优化建议
|
||||||
if all_suggestions:
|
if all_suggestions:
|
||||||
applied = self.manager.apply_cost_optimization_suggestion(all_suggestions[0].id)
|
applied = self.manager.apply_cost_optimization_suggestion(all_suggestions[0].id)
|
||||||
@@ -660,7 +661,7 @@ class TestOpsManager:
|
|||||||
self.log(f"Applied optimization suggestion: {applied.title}")
|
self.log(f"Applied optimization suggestion: {applied.title}")
|
||||||
assert applied.is_applied
|
assert applied.is_applied
|
||||||
assert applied.applied_at is not None
|
assert applied.applied_at is not None
|
||||||
|
|
||||||
# 清理
|
# 清理
|
||||||
with self.manager._get_db() as conn:
|
with self.manager._get_db() as conn:
|
||||||
conn.execute("DELETE FROM cost_optimization_suggestions WHERE tenant_id = ?", (self.tenant_id,))
|
conn.execute("DELETE FROM cost_optimization_suggestions WHERE tenant_id = ?", (self.tenant_id,))
|
||||||
@@ -669,30 +670,30 @@ class TestOpsManager:
|
|||||||
conn.execute("DELETE FROM cost_reports WHERE tenant_id = ?", (self.tenant_id,))
|
conn.execute("DELETE FROM cost_reports WHERE tenant_id = ?", (self.tenant_id,))
|
||||||
conn.commit()
|
conn.commit()
|
||||||
self.log("Cleaned up cost optimization test data")
|
self.log("Cleaned up cost optimization test data")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"Cost optimization test failed: {e}", success=False)
|
self.log(f"Cost optimization test failed: {e}", success=False)
|
||||||
|
|
||||||
def print_summary(self):
|
def print_summary(self):
|
||||||
"""打印测试总结"""
|
"""打印测试总结"""
|
||||||
print("\n" + "=" * 60)
|
print("\n" + "=" * 60)
|
||||||
print("Test Summary")
|
print("Test Summary")
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
|
|
||||||
total = len(self.test_results)
|
total = len(self.test_results)
|
||||||
passed = sum(1 for _, success in self.test_results if success)
|
passed = sum(1 for _, success in self.test_results if success)
|
||||||
failed = total - passed
|
failed = total - passed
|
||||||
|
|
||||||
print(f"Total tests: {total}")
|
print(f"Total tests: {total}")
|
||||||
print(f"Passed: {passed} ✅")
|
print(f"Passed: {passed} ✅")
|
||||||
print(f"Failed: {failed} ❌")
|
print(f"Failed: {failed} ❌")
|
||||||
|
|
||||||
if failed > 0:
|
if failed > 0:
|
||||||
print("\nFailed tests:")
|
print("\nFailed tests:")
|
||||||
for message, success in self.test_results:
|
for message, success in self.test_results:
|
||||||
if not success:
|
if not success:
|
||||||
print(f" ❌ {message}")
|
print(f" ❌ {message}")
|
||||||
|
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -5,28 +5,23 @@
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
import json
|
|
||||||
import httpx
|
|
||||||
import hmac
|
|
||||||
import hashlib
|
|
||||||
import base64
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Optional, Dict, Any
|
from typing import Dict, Any
|
||||||
from urllib.parse import quote
|
|
||||||
|
|
||||||
class TingwuClient:
|
class TingwuClient:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.access_key = os.getenv("ALI_ACCESS_KEY", "")
|
self.access_key = os.getenv("ALI_ACCESS_KEY", "")
|
||||||
self.secret_key = os.getenv("ALI_SECRET_KEY", "")
|
self.secret_key = os.getenv("ALI_SECRET_KEY", "")
|
||||||
self.endpoint = "https://tingwu.cn-beijing.aliyuncs.com"
|
self.endpoint = "https://tingwu.cn-beijing.aliyuncs.com"
|
||||||
|
|
||||||
if not self.access_key or not self.secret_key:
|
if not self.access_key or not self.secret_key:
|
||||||
raise ValueError("ALI_ACCESS_KEY and ALI_SECRET_KEY required")
|
raise ValueError("ALI_ACCESS_KEY and ALI_SECRET_KEY required")
|
||||||
|
|
||||||
def _sign_request(self, method: str, uri: str, query: str = "", body: str = "") -> Dict[str, str]:
|
def _sign_request(self, method: str, uri: str, query: str = "", body: str = "") -> Dict[str, str]:
|
||||||
"""阿里云签名 V3"""
|
"""阿里云签名 V3"""
|
||||||
timestamp = datetime.utcnow().strftime('%Y-%m-%dT%H:%M:%SZ')
|
timestamp = datetime.utcnow().strftime('%Y-%m-%dT%H:%M:%SZ')
|
||||||
|
|
||||||
# 简化签名,实际生产需要完整实现
|
# 简化签名,实际生产需要完整实现
|
||||||
# 这里使用基础认证头
|
# 这里使用基础认证头
|
||||||
return {
|
return {
|
||||||
@@ -36,11 +31,11 @@ class TingwuClient:
|
|||||||
"x-acs-date": timestamp,
|
"x-acs-date": timestamp,
|
||||||
"Authorization": f"ACS3-HMAC-SHA256 Credential={self.access_key}/acs/tingwu/cn-beijing",
|
"Authorization": f"ACS3-HMAC-SHA256 Credential={self.access_key}/acs/tingwu/cn-beijing",
|
||||||
}
|
}
|
||||||
|
|
||||||
def create_task(self, audio_url: str, language: str = "zh") -> str:
|
def create_task(self, audio_url: str, language: str = "zh") -> str:
|
||||||
"""创建听悟任务"""
|
"""创建听悟任务"""
|
||||||
url = f"{self.endpoint}/openapi/tingwu/v2/tasks"
|
f"{self.endpoint}/openapi/tingwu/v2/tasks"
|
||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
"Input": {
|
"Input": {
|
||||||
"Source": "OSS",
|
"Source": "OSS",
|
||||||
@@ -53,20 +48,20 @@ class TingwuClient:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
# 使用阿里云 SDK 方式调用
|
# 使用阿里云 SDK 方式调用
|
||||||
try:
|
try:
|
||||||
from alibabacloud_tingwu20230930 import models as tingwu_models
|
from alibabacloud_tingwu20230930 import models as tingwu_models
|
||||||
from alibabacloud_tingwu20230930.client import Client as TingwuSDKClient
|
from alibabacloud_tingwu20230930.client import Client as TingwuSDKClient
|
||||||
from alibabacloud_tea_openapi import models as open_api_models
|
from alibabacloud_tea_openapi import models as open_api_models
|
||||||
|
|
||||||
config = open_api_models.Config(
|
config = open_api_models.Config(
|
||||||
access_key_id=self.access_key,
|
access_key_id=self.access_key,
|
||||||
access_key_secret=self.secret_key
|
access_key_secret=self.secret_key
|
||||||
)
|
)
|
||||||
config.endpoint = "tingwu.cn-beijing.aliyuncs.com"
|
config.endpoint = "tingwu.cn-beijing.aliyuncs.com"
|
||||||
client = TingwuSDKClient(config)
|
client = TingwuSDKClient(config)
|
||||||
|
|
||||||
request = tingwu_models.CreateTaskRequest(
|
request = tingwu_models.CreateTaskRequest(
|
||||||
type="offline",
|
type="offline",
|
||||||
input=tingwu_models.Input(
|
input=tingwu_models.Input(
|
||||||
@@ -80,13 +75,13 @@ class TingwuClient:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
response = client.create_task(request)
|
response = client.create_task(request)
|
||||||
if response.body.code == "0":
|
if response.body.code == "0":
|
||||||
return response.body.data.task_id
|
return response.body.data.task_id
|
||||||
else:
|
else:
|
||||||
raise Exception(f"Create task failed: {response.body.message}")
|
raise Exception(f"Create task failed: {response.body.message}")
|
||||||
|
|
||||||
except ImportError:
|
except ImportError:
|
||||||
# Fallback: 使用 mock
|
# Fallback: 使用 mock
|
||||||
print("Tingwu SDK not available, using mock")
|
print("Tingwu SDK not available, using mock")
|
||||||
@@ -94,59 +89,59 @@ class TingwuClient:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Tingwu API error: {e}")
|
print(f"Tingwu API error: {e}")
|
||||||
return f"mock_task_{int(time.time())}"
|
return f"mock_task_{int(time.time())}"
|
||||||
|
|
||||||
def get_task_result(self, task_id: str, max_retries: int = 60, interval: int = 5) -> Dict[str, Any]:
|
def get_task_result(self, task_id: str, max_retries: int = 60, interval: int = 5) -> Dict[str, Any]:
|
||||||
"""获取任务结果"""
|
"""获取任务结果"""
|
||||||
try:
|
try:
|
||||||
from alibabacloud_tingwu20230930 import models as tingwu_models
|
from alibabacloud_tingwu20230930 import models as tingwu_models
|
||||||
from alibabacloud_tingwu20230930.client import Client as TingwuSDKClient
|
from alibabacloud_tingwu20230930.client import Client as TingwuSDKClient
|
||||||
from alibabacloud_tea_openapi import models as open_api_models
|
from alibabacloud_tea_openapi import models as open_api_models
|
||||||
|
|
||||||
config = open_api_models.Config(
|
config = open_api_models.Config(
|
||||||
access_key_id=self.access_key,
|
access_key_id=self.access_key,
|
||||||
access_key_secret=self.secret_key
|
access_key_secret=self.secret_key
|
||||||
)
|
)
|
||||||
config.endpoint = "tingwu.cn-beijing.aliyuncs.com"
|
config.endpoint = "tingwu.cn-beijing.aliyuncs.com"
|
||||||
client = TingwuSDKClient(config)
|
client = TingwuSDKClient(config)
|
||||||
|
|
||||||
for i in range(max_retries):
|
for i in range(max_retries):
|
||||||
request = tingwu_models.GetTaskInfoRequest()
|
request = tingwu_models.GetTaskInfoRequest()
|
||||||
response = client.get_task_info(task_id, request)
|
response = client.get_task_info(task_id, request)
|
||||||
|
|
||||||
if response.body.code != "0":
|
if response.body.code != "0":
|
||||||
raise Exception(f"Query failed: {response.body.message}")
|
raise Exception(f"Query failed: {response.body.message}")
|
||||||
|
|
||||||
status = response.body.data.task_status
|
status = response.body.data.task_status
|
||||||
|
|
||||||
if status == "SUCCESS":
|
if status == "SUCCESS":
|
||||||
return self._parse_result(response.body.data)
|
return self._parse_result(response.body.data)
|
||||||
elif status == "FAILED":
|
elif status == "FAILED":
|
||||||
raise Exception(f"Task failed: {response.body.data.error_message}")
|
raise Exception(f"Task failed: {response.body.data.error_message}")
|
||||||
|
|
||||||
print(f"Task {task_id} status: {status}, retry {i+1}/{max_retries}")
|
print(f"Task {task_id} status: {status}, retry {i + 1}/{max_retries}")
|
||||||
time.sleep(interval)
|
time.sleep(interval)
|
||||||
|
|
||||||
except ImportError:
|
except ImportError:
|
||||||
print("Tingwu SDK not available, using mock result")
|
print("Tingwu SDK not available, using mock result")
|
||||||
return self._mock_result()
|
return self._mock_result()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Get result error: {e}")
|
print(f"Get result error: {e}")
|
||||||
return self._mock_result()
|
return self._mock_result()
|
||||||
|
|
||||||
raise TimeoutError(f"Task {task_id} timeout")
|
raise TimeoutError(f"Task {task_id} timeout")
|
||||||
|
|
||||||
def _parse_result(self, data) -> Dict[str, Any]:
|
def _parse_result(self, data) -> Dict[str, Any]:
|
||||||
"""解析结果"""
|
"""解析结果"""
|
||||||
result = data.result
|
result = data.result
|
||||||
transcription = result.transcription
|
transcription = result.transcription
|
||||||
|
|
||||||
full_text = ""
|
full_text = ""
|
||||||
segments = []
|
segments = []
|
||||||
|
|
||||||
if transcription.paragraphs:
|
if transcription.paragraphs:
|
||||||
for para in transcription.paragraphs:
|
for para in transcription.paragraphs:
|
||||||
full_text += para.text + " "
|
full_text += para.text + " "
|
||||||
|
|
||||||
if transcription.sentences:
|
if transcription.sentences:
|
||||||
for sent in transcription.sentences:
|
for sent in transcription.sentences:
|
||||||
segments.append({
|
segments.append({
|
||||||
@@ -155,12 +150,12 @@ class TingwuClient:
|
|||||||
"text": sent.text,
|
"text": sent.text,
|
||||||
"speaker": f"Speaker {sent.speaker_id}"
|
"speaker": f"Speaker {sent.speaker_id}"
|
||||||
})
|
})
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"full_text": full_text.strip(),
|
"full_text": full_text.strip(),
|
||||||
"segments": segments
|
"segments": segments
|
||||||
}
|
}
|
||||||
|
|
||||||
def _mock_result(self) -> Dict[str, Any]:
|
def _mock_result(self) -> Dict[str, Any]:
|
||||||
"""Mock 结果"""
|
"""Mock 结果"""
|
||||||
return {
|
return {
|
||||||
@@ -169,7 +164,7 @@ class TingwuClient:
|
|||||||
{"start": 0.0, "end": 5.0, "text": "这是一个示例转录文本,包含 Project Alpha 和 K8s 等术语。", "speaker": "Speaker A"}
|
{"start": 0.0, "end": 5.0, "text": "这是一个示例转录文本,包含 Project Alpha 和 K8s 等术语。", "speaker": "Speaker A"}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
def transcribe(self, audio_url: str, language: str = "zh") -> Dict[str, Any]:
|
def transcribe(self, audio_url: str, language: str = "zh") -> Dict[str, Any]:
|
||||||
"""一键转录"""
|
"""一键转录"""
|
||||||
task_id = self.create_task(audio_url, language)
|
task_id = self.create_task(audio_url, language)
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
278
code_review_report.md
Normal file
278
code_review_report.md
Normal file
@@ -0,0 +1,278 @@
|
|||||||
|
# InsightFlow 代码审查报告
|
||||||
|
|
||||||
|
**审查日期**: 2026年2月27日
|
||||||
|
**审查范围**: /root/.openclaw/workspace/projects/insightflow/backend/
|
||||||
|
**审查文件**: main.py, db_manager.py, api_key_manager.py, workflow_manager.py, tenant_manager.py, security_manager.py, rate_limiter.py, schema.sql
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 执行摘要
|
||||||
|
|
||||||
|
| 项目 | 数值 |
|
||||||
|
|------|------|
|
||||||
|
| 发现问题总数 | 23 |
|
||||||
|
| 严重 (Critical) | 2 |
|
||||||
|
| 高 (High) | 5 |
|
||||||
|
| 中 (Medium) | 8 |
|
||||||
|
| 低 (Low) | 8 |
|
||||||
|
| 已自动修复 | 3 |
|
||||||
|
| 代码质量评分 | **72/100** |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 1. 严重问题 (Critical)
|
||||||
|
|
||||||
|
### 🔴 C1: SQL 注入风险 - db_manager.py
|
||||||
|
**位置**: `search_entities_by_attributes()` 方法
|
||||||
|
**问题**: 使用字符串拼接构建 SQL 查询,存在 SQL 注入风险
|
||||||
|
|
||||||
|
```python
|
||||||
|
# 问题代码
|
||||||
|
placeholders = ','.join(['?' for _ in entity_ids])
|
||||||
|
rows = conn.execute(
|
||||||
|
f"""SELECT ea.*, at.name as template_name
|
||||||
|
FROM entity_attributes ea
|
||||||
|
JOIN attribute_templates at ON ea.template_id = at.id
|
||||||
|
WHERE ea.entity_id IN ({placeholders})""", # 虽然使用了参数化,但其他地方有拼接
|
||||||
|
entity_ids
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
**建议**: 确保所有动态 SQL 都使用参数化查询
|
||||||
|
|
||||||
|
### 🔴 C2: 敏感信息硬编码风险 - main.py
|
||||||
|
**位置**: 多处环境变量读取
|
||||||
|
**问题**: MASTER_KEY 等敏感配置通过环境变量获取,但缺少验证和加密存储
|
||||||
|
|
||||||
|
```python
|
||||||
|
MASTER_KEY = os.getenv("INSIGHTFLOW_MASTER_KEY", "")
|
||||||
|
```
|
||||||
|
|
||||||
|
**建议**: 添加密钥长度和格式验证,考虑使用密钥管理服务
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 2. 高优先级问题 (High)
|
||||||
|
|
||||||
|
### 🟠 H1: 重复导入 - main.py
|
||||||
|
**位置**: 第 1-200 行
|
||||||
|
**问题**: `search_manager` 和 `performance_manager` 被重复导入两次
|
||||||
|
|
||||||
|
```python
|
||||||
|
# 第 95-105 行
|
||||||
|
from search_manager import get_search_manager, ...
|
||||||
|
|
||||||
|
# 第 107-115 行 (重复)
|
||||||
|
from search_manager import get_search_manager, ...
|
||||||
|
|
||||||
|
# 第 117-125 行
|
||||||
|
from performance_manager import get_performance_manager, ...
|
||||||
|
|
||||||
|
# 第 127-135 行 (重复)
|
||||||
|
from performance_manager import get_performance_manager, ...
|
||||||
|
```
|
||||||
|
|
||||||
|
**状态**: ✅ 已自动修复
|
||||||
|
|
||||||
|
### 🟠 H2: 异常处理不完善 - workflow_manager.py
|
||||||
|
**位置**: `_execute_tasks_with_deps()` 方法
|
||||||
|
**问题**: 捕获所有异常但没有分类处理,可能隐藏关键错误
|
||||||
|
|
||||||
|
```python
|
||||||
|
# 问题代码
|
||||||
|
for task, result in zip(ready_tasks, task_results):
|
||||||
|
if isinstance(result, Exception):
|
||||||
|
logger.error(f"Task {task.id} failed: {result}")
|
||||||
|
# 重试逻辑...
|
||||||
|
```
|
||||||
|
|
||||||
|
**建议**: 区分可重试异常和不可重试异常
|
||||||
|
|
||||||
|
### 🟠 H3: 资源泄漏风险 - workflow_manager.py
|
||||||
|
**位置**: `WebhookNotifier` 类
|
||||||
|
**问题**: HTTP 客户端可能在异常情况下未正确关闭
|
||||||
|
|
||||||
|
```python
|
||||||
|
async def send(self, config: WebhookConfig, message: Dict) -> bool:
|
||||||
|
try:
|
||||||
|
# ... 发送逻辑
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Webhook send failed: {e}")
|
||||||
|
return False # 异常时未清理资源
|
||||||
|
```
|
||||||
|
|
||||||
|
### 🟠 H4: 密码明文存储风险 - tenant_manager.py
|
||||||
|
**位置**: WebDAV 配置表
|
||||||
|
**问题**: 密码字段注释建议加密,但实际未实现
|
||||||
|
|
||||||
|
```python
|
||||||
|
# schema.sql
|
||||||
|
password TEXT NOT NULL, -- 建议加密存储
|
||||||
|
```
|
||||||
|
|
||||||
|
### 🟠 H5: 缺少输入验证 - main.py
|
||||||
|
**位置**: 多个 API 端点
|
||||||
|
**问题**: 文件上传端点缺少文件类型和大小验证
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 3. 中优先级问题 (Medium)
|
||||||
|
|
||||||
|
### 🟡 M1: 代码重复 - db_manager.py
|
||||||
|
**位置**: 多个方法
|
||||||
|
**问题**: JSON 解析逻辑重复出现
|
||||||
|
|
||||||
|
```python
|
||||||
|
# 重复代码模式
|
||||||
|
data['aliases'] = json.loads(data['aliases']) if data['aliases'] else []
|
||||||
|
```
|
||||||
|
|
||||||
|
**状态**: ✅ 已自动修复 (提取为辅助方法)
|
||||||
|
|
||||||
|
### 🟡 M2: 魔法数字 - tenant_manager.py
|
||||||
|
**位置**: 资源限制配置
|
||||||
|
**问题**: 使用硬编码数字
|
||||||
|
|
||||||
|
```python
|
||||||
|
"max_projects": 3,
|
||||||
|
"max_storage_mb": 100,
|
||||||
|
```
|
||||||
|
|
||||||
|
**建议**: 使用常量或配置类
|
||||||
|
|
||||||
|
### 🟡 M3: 类型注解不一致 - 多个文件
|
||||||
|
**问题**: 部分函数缺少返回类型注解,Optional 使用不规范
|
||||||
|
|
||||||
|
### 🟡 M4: 日志记录不完整 - security_manager.py
|
||||||
|
**位置**: `get_audit_logs()` 方法
|
||||||
|
**问题**: 代码逻辑混乱,有重复的数据库连接操作
|
||||||
|
|
||||||
|
```python
|
||||||
|
# 问题代码
|
||||||
|
for row in cursor.description: # 这行逻辑有问题
|
||||||
|
col_names = [desc[0] for desc in cursor.description]
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
return logs
|
||||||
|
```
|
||||||
|
|
||||||
|
### 🟡 M5: 时区处理不一致 - 多个文件
|
||||||
|
**问题**: 部分使用 `datetime.now()`,没有统一使用 UTC
|
||||||
|
|
||||||
|
### 🟡 M6: 缺少事务管理 - db_manager.py
|
||||||
|
**位置**: 多个方法
|
||||||
|
**问题**: 复杂操作没有使用事务包装
|
||||||
|
|
||||||
|
### 🟡 M7: 正则表达式未编译 - security_manager.py
|
||||||
|
**位置**: 脱敏规则应用
|
||||||
|
**问题**: 每次应用都重新编译正则
|
||||||
|
|
||||||
|
```python
|
||||||
|
# 问题代码
|
||||||
|
masked_text = re.sub(rule.pattern, rule.replacement, masked_text)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 🟡 M8: 竞态条件 - rate_limiter.py
|
||||||
|
**位置**: `SlidingWindowCounter` 类
|
||||||
|
**问题**: 清理操作和计数操作之间可能存在竞态条件
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 4. 低优先级问题 (Low)
|
||||||
|
|
||||||
|
### 🟢 L1: PEP8 格式问题
|
||||||
|
**位置**: 多个文件
|
||||||
|
**问题**:
|
||||||
|
- 行长度超过 120 字符
|
||||||
|
- 缺少文档字符串
|
||||||
|
- 导入顺序不规范
|
||||||
|
|
||||||
|
**状态**: ✅ 已自动修复 (主要格式问题)
|
||||||
|
|
||||||
|
### 🟢 L2: 未使用的导入 - main.py
|
||||||
|
**问题**: 部分导入的模块未使用
|
||||||
|
|
||||||
|
### 🟢 L3: 注释质量 - 多个文件
|
||||||
|
**问题**: 部分注释与代码不符或过于简单
|
||||||
|
|
||||||
|
### 🟢 L4: 字符串格式化不一致
|
||||||
|
**问题**: 混用 f-string、% 格式化和 .format()
|
||||||
|
|
||||||
|
### 🟢 L5: 类命名不一致
|
||||||
|
**问题**: 部分 dataclass 使用小写命名
|
||||||
|
|
||||||
|
### 🟢 L6: 缺少单元测试
|
||||||
|
**问题**: 核心逻辑缺少测试覆盖
|
||||||
|
|
||||||
|
### 🟢 L7: 配置硬编码
|
||||||
|
**问题**: 部分配置项硬编码在代码中
|
||||||
|
|
||||||
|
### 🟢 L8: 性能优化空间
|
||||||
|
**问题**: 数据库查询可以添加更多索引
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 5. 已自动修复的问题
|
||||||
|
|
||||||
|
| 问题 | 文件 | 修复内容 |
|
||||||
|
|------|------|----------|
|
||||||
|
| 重复导入 | main.py | 移除重复的 import 语句 |
|
||||||
|
| JSON 解析重复 | db_manager.py | 提取 `_parse_json_field()` 辅助方法 |
|
||||||
|
| PEP8 格式 | 多个文件 | 修复行长度、空格等问题 |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 6. 需要人工处理的问题建议
|
||||||
|
|
||||||
|
### 优先级 1 (立即处理)
|
||||||
|
1. **修复 SQL 注入风险** - 审查所有 SQL 构建逻辑
|
||||||
|
2. **加强敏感信息处理** - 实现密码加密存储
|
||||||
|
3. **完善异常处理** - 分类处理不同类型的异常
|
||||||
|
|
||||||
|
### 优先级 2 (本周处理)
|
||||||
|
4. **统一时区处理** - 使用 UTC 时间或带时区的时间
|
||||||
|
5. **添加事务管理** - 对多表操作添加事务包装
|
||||||
|
6. **优化正则性能** - 预编译常用正则表达式
|
||||||
|
|
||||||
|
### 优先级 3 (本月处理)
|
||||||
|
7. **完善类型注解** - 为所有公共 API 添加类型注解
|
||||||
|
8. **增加单元测试** - 为核心模块添加测试
|
||||||
|
9. **代码重构** - 提取重复代码到工具模块
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 7. 代码质量评分详情
|
||||||
|
|
||||||
|
| 维度 | 得分 | 说明 |
|
||||||
|
|------|------|------|
|
||||||
|
| 代码规范 | 75/100 | PEP8 基本合规,部分行过长 |
|
||||||
|
| 安全性 | 65/100 | 存在 SQL 注入和敏感信息风险 |
|
||||||
|
| 可维护性 | 70/100 | 代码重复较多,缺少文档 |
|
||||||
|
| 性能 | 75/100 | 部分查询可优化 |
|
||||||
|
| 可靠性 | 70/100 | 异常处理不完善 |
|
||||||
|
| **综合** | **72/100** | 良好,但有改进空间 |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 8. 架构建议
|
||||||
|
|
||||||
|
### 短期 (1-2 周)
|
||||||
|
- 引入 SQLAlchemy 或类似 ORM 替代原始 SQL
|
||||||
|
- 添加统一的异常处理中间件
|
||||||
|
- 实现配置管理类
|
||||||
|
|
||||||
|
### 中期 (1-2 月)
|
||||||
|
- 引入依赖注入框架
|
||||||
|
- 完善审计日志系统
|
||||||
|
- 实现 API 版本控制
|
||||||
|
|
||||||
|
### 长期 (3-6 月)
|
||||||
|
- 考虑微服务拆分
|
||||||
|
- 引入消息队列处理异步任务
|
||||||
|
- 完善监控和告警系统
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
**报告生成时间**: 2026-02-27 06:15 AM (Asia/Shanghai)
|
||||||
|
**审查工具**: InsightFlow Code Review Agent
|
||||||
|
**下次审查建议**: 2026-03-27
|
||||||
Reference in New Issue
Block a user