fix: auto-fix code issues (cron)
- 修复隐式 Optional 类型注解 (RUF013) - 修复不必要的赋值后返回 (RET504) - 优化列表推导式 (PERF401) - 修复未使用的参数 (ARG002) - 清理重复导入 - 优化异常处理
This commit is contained in:
@@ -291,7 +291,10 @@ class AIManager:
|
||||
return self._row_to_custom_model(row)
|
||||
|
||||
def list_custom_models(
|
||||
self, tenant_id: str, model_type: ModelType | None = None, status: ModelStatus | None = None,
|
||||
self,
|
||||
tenant_id: str,
|
||||
model_type: ModelType | None = None,
|
||||
status: ModelStatus | None = None,
|
||||
) -> list[CustomModel]:
|
||||
"""列出自定义模型"""
|
||||
query = "SELECT * FROM custom_models WHERE tenant_id = ?"
|
||||
@@ -311,7 +314,11 @@ class AIManager:
|
||||
return [self._row_to_custom_model(row) for row in rows]
|
||||
|
||||
def add_training_sample(
|
||||
self, model_id: str, text: str, entities: list[dict], metadata: dict = None,
|
||||
self,
|
||||
model_id: str,
|
||||
text: str,
|
||||
entities: list[dict],
|
||||
metadata: dict | None = None,
|
||||
) -> TrainingSample:
|
||||
"""添加训练样本"""
|
||||
sample_id = f"ts_{uuid.uuid4().hex[:16]}"
|
||||
@@ -463,8 +470,7 @@ class AIManager:
|
||||
json_match = re.search(r"\[.*?\]", content, re.DOTALL)
|
||||
if json_match:
|
||||
try:
|
||||
entities = json.loads(json_match.group())
|
||||
return entities
|
||||
return json.loads(json_match.group())
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
pass
|
||||
|
||||
@@ -542,8 +548,9 @@ class AIManager:
|
||||
}
|
||||
|
||||
content = [{"type": "text", "text": prompt}]
|
||||
for url in image_urls:
|
||||
content.append({"type": "image_url", "image_url": {"url": url}})
|
||||
content.extend(
|
||||
[{"type": "image_url", "image_url": {"url": url}} for url in image_urls]
|
||||
)
|
||||
|
||||
payload = {
|
||||
"model": "gpt-4-vision-preview",
|
||||
@@ -575,9 +582,9 @@ class AIManager:
|
||||
"anthropic-version": "2023-06-01",
|
||||
}
|
||||
|
||||
content = []
|
||||
for url in image_urls:
|
||||
content.append({"type": "image", "source": {"type": "url", "url": url}})
|
||||
content = [
|
||||
{"type": "image", "source": {"type": "url", "url": url}} for url in image_urls
|
||||
]
|
||||
content.append({"type": "text", "text": prompt})
|
||||
|
||||
payload = {
|
||||
@@ -638,7 +645,9 @@ class AIManager:
|
||||
}
|
||||
|
||||
def get_multimodal_analyses(
|
||||
self, tenant_id: str, project_id: str | None = None,
|
||||
self,
|
||||
tenant_id: str,
|
||||
project_id: str | None = None,
|
||||
) -> list[MultimodalAnalysis]:
|
||||
"""获取多模态分析历史"""
|
||||
query = "SELECT * FROM multimodal_analyses WHERE tenant_id = ?"
|
||||
@@ -721,7 +730,9 @@ class AIManager:
|
||||
return self._row_to_kg_rag(row)
|
||||
|
||||
def list_kg_rags(
|
||||
self, tenant_id: str, project_id: str | None = None,
|
||||
self,
|
||||
tenant_id: str,
|
||||
project_id: str | None = None,
|
||||
) -> list[KnowledgeGraphRAG]:
|
||||
"""列出知识图谱 RAG 配置"""
|
||||
query = "SELECT * FROM kg_rag_configs WHERE tenant_id = ?"
|
||||
@@ -738,7 +749,11 @@ class AIManager:
|
||||
return [self._row_to_kg_rag(row) for row in rows]
|
||||
|
||||
async def query_kg_rag(
|
||||
self, rag_id: str, query: str, project_entities: list[dict], project_relations: list[dict],
|
||||
self,
|
||||
rag_id: str,
|
||||
query: str,
|
||||
project_entities: list[dict],
|
||||
project_relations: list[dict],
|
||||
) -> RAGQuery:
|
||||
"""基于知识图谱的 RAG 查询"""
|
||||
start_time = time.time()
|
||||
@@ -771,14 +786,15 @@ class AIManager:
|
||||
relevant_entities = relevant_entities[:top_k]
|
||||
|
||||
# 检索相关关系
|
||||
relevant_relations = []
|
||||
entity_ids = {e["id"] for e in relevant_entities}
|
||||
for relation in project_relations:
|
||||
relevant_relations = [
|
||||
relation
|
||||
for relation in project_relations
|
||||
if (
|
||||
relation.get("source_entity_id") in entity_ids
|
||||
or relation.get("target_entity_id") in entity_ids
|
||||
):
|
||||
relevant_relations.append(relation)
|
||||
)
|
||||
]
|
||||
|
||||
# 2. 构建上下文
|
||||
context = {"entities": relevant_entities, "relations": relevant_relations[:10]}
|
||||
@@ -1123,7 +1139,8 @@ class AIManager:
|
||||
"""获取预测模型"""
|
||||
with self._get_db() as conn:
|
||||
row = conn.execute(
|
||||
"SELECT * FROM prediction_models WHERE id = ?", (model_id,),
|
||||
"SELECT * FROM prediction_models WHERE id = ?",
|
||||
(model_id,),
|
||||
).fetchone()
|
||||
|
||||
if not row:
|
||||
@@ -1132,7 +1149,9 @@ class AIManager:
|
||||
return self._row_to_prediction_model(row)
|
||||
|
||||
def list_prediction_models(
|
||||
self, tenant_id: str, project_id: str | None = None,
|
||||
self,
|
||||
tenant_id: str,
|
||||
project_id: str | None = None,
|
||||
) -> list[PredictionModel]:
|
||||
"""列出预测模型"""
|
||||
query = "SELECT * FROM prediction_models WHERE tenant_id = ?"
|
||||
@@ -1149,7 +1168,9 @@ class AIManager:
|
||||
return [self._row_to_prediction_model(row) for row in rows]
|
||||
|
||||
async def train_prediction_model(
|
||||
self, model_id: str, historical_data: list[dict],
|
||||
self,
|
||||
model_id: str,
|
||||
historical_data: list[dict],
|
||||
) -> PredictionModel:
|
||||
"""训练预测模型"""
|
||||
model = self.get_prediction_model(model_id)
|
||||
@@ -1369,7 +1390,9 @@ class AIManager:
|
||||
predicted_relations = [
|
||||
{"type": rel_type, "likelihood": min(count / len(relation_history), 0.95)}
|
||||
for rel_type, count in sorted(
|
||||
relation_counts.items(), key=lambda x: x[1], reverse=True,
|
||||
relation_counts.items(),
|
||||
key=lambda x: x[1],
|
||||
reverse=True,
|
||||
)[:5]
|
||||
]
|
||||
|
||||
@@ -1394,7 +1417,10 @@ class AIManager:
|
||||
return [self._row_to_prediction_result(row) for row in rows]
|
||||
|
||||
def update_prediction_feedback(
|
||||
self, prediction_id: str, actual_value: str, is_correct: bool,
|
||||
self,
|
||||
prediction_id: str,
|
||||
actual_value: str,
|
||||
is_correct: bool,
|
||||
) -> None:
|
||||
"""更新预测反馈(用于模型改进)"""
|
||||
with self._get_db() as conn:
|
||||
|
||||
Reference in New Issue
Block a user