fix: auto-fix code issues (cron)
- 修复重复导入/字段 - 修复异常处理 - 修复PEP8格式问题 - 添加类型注解 - 修复缺失的urllib.parse导入
This commit is contained in:
@@ -13,6 +13,7 @@ from ai_manager import ModelType, PredictionType, get_ai_manager
|
||||
# Add backend directory to path
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
|
||||
def test_custom_model():
|
||||
"""测试自定义模型功能"""
|
||||
print("\n=== 测试自定义模型 ===")
|
||||
@@ -28,14 +29,10 @@ def test_custom_model():
|
||||
model_type=ModelType.CUSTOM_NER,
|
||||
training_data={
|
||||
"entity_types": ["DISEASE", "SYMPTOM", "DRUG", "TREATMENT"],
|
||||
"domain": "medical"
|
||||
"domain": "medical",
|
||||
},
|
||||
hyperparameters={
|
||||
"epochs": 15,
|
||||
"learning_rate": 0.001,
|
||||
"batch_size": 32
|
||||
},
|
||||
created_by="user_001"
|
||||
hyperparameters={"epochs": 15, "learning_rate": 0.001, "batch_size": 32},
|
||||
created_by="user_001",
|
||||
)
|
||||
print(f" 创建成功: {model.id}, 状态: {model.status.value}")
|
||||
|
||||
@@ -47,8 +44,8 @@ def test_custom_model():
|
||||
"entities": [
|
||||
{"start": 2, "end": 4, "label": "PERSON", "text": "张三"},
|
||||
{"start": 6, "end": 9, "label": "DISEASE", "text": "高血压"},
|
||||
{"start": 14, "end": 17, "label": "DRUG", "text": "降压药"}
|
||||
]
|
||||
{"start": 14, "end": 17, "label": "DRUG", "text": "降压药"},
|
||||
],
|
||||
},
|
||||
{
|
||||
"text": "李四因感冒发烧到医院就诊,医生开具了退烧药。",
|
||||
@@ -56,16 +53,16 @@ def test_custom_model():
|
||||
{"start": 0, "end": 2, "label": "PERSON", "text": "李四"},
|
||||
{"start": 3, "end": 5, "label": "SYMPTOM", "text": "感冒"},
|
||||
{"start": 5, "end": 7, "label": "SYMPTOM", "text": "发烧"},
|
||||
{"start": 21, "end": 24, "label": "DRUG", "text": "退烧药"}
|
||||
]
|
||||
{"start": 21, "end": 24, "label": "DRUG", "text": "退烧药"},
|
||||
],
|
||||
},
|
||||
{
|
||||
"text": "王五接受了心脏搭桥手术,术后恢复良好。",
|
||||
"entities": [
|
||||
{"start": 0, "end": 2, "label": "PERSON", "text": "王五"},
|
||||
{"start": 5, "end": 11, "label": "TREATMENT", "text": "心脏搭桥手术"}
|
||||
]
|
||||
}
|
||||
{"start": 5, "end": 11, "label": "TREATMENT", "text": "心脏搭桥手术"},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
for sample_data in samples:
|
||||
@@ -73,7 +70,7 @@ def test_custom_model():
|
||||
model_id=model.id,
|
||||
text=sample_data["text"],
|
||||
entities=sample_data["entities"],
|
||||
metadata={"source": "manual"}
|
||||
metadata={"source": "manual"},
|
||||
)
|
||||
print(f" 添加样本: {sample.id}")
|
||||
|
||||
@@ -91,6 +88,7 @@ def test_custom_model():
|
||||
|
||||
return model.id
|
||||
|
||||
|
||||
async def test_train_and_predict(model_id: str):
|
||||
"""测试训练和预测"""
|
||||
print("\n=== 测试模型训练和预测 ===")
|
||||
@@ -117,6 +115,7 @@ async def test_train_and_predict(model_id: str):
|
||||
except Exception as e:
|
||||
print(f" 预测失败: {e}")
|
||||
|
||||
|
||||
def test_prediction_models():
|
||||
"""测试预测模型"""
|
||||
print("\n=== 测试预测模型 ===")
|
||||
@@ -132,10 +131,7 @@ def test_prediction_models():
|
||||
prediction_type=PredictionType.TREND,
|
||||
target_entity_type="PERSON",
|
||||
features=["entity_count", "time_period", "document_count"],
|
||||
model_config={
|
||||
"algorithm": "linear_regression",
|
||||
"window_size": 7
|
||||
}
|
||||
model_config={"algorithm": "linear_regression", "window_size": 7},
|
||||
)
|
||||
print(f" 创建成功: {trend_model.id}")
|
||||
|
||||
@@ -148,10 +144,7 @@ def test_prediction_models():
|
||||
prediction_type=PredictionType.ANOMALY,
|
||||
target_entity_type=None,
|
||||
features=["daily_growth", "weekly_growth"],
|
||||
model_config={
|
||||
"threshold": 2.5,
|
||||
"sensitivity": "medium"
|
||||
}
|
||||
model_config={"threshold": 2.5, "sensitivity": "medium"},
|
||||
)
|
||||
print(f" 创建成功: {anomaly_model.id}")
|
||||
|
||||
@@ -164,6 +157,7 @@ def test_prediction_models():
|
||||
|
||||
return trend_model.id, anomaly_model.id
|
||||
|
||||
|
||||
async def test_predictions(trend_model_id: str, anomaly_model_id: str):
|
||||
"""测试预测功能"""
|
||||
print("\n=== 测试预测功能 ===")
|
||||
@@ -179,7 +173,7 @@ async def test_predictions(trend_model_id: str, anomaly_model_id: str):
|
||||
{"date": "2024-01-04", "value": 14},
|
||||
{"date": "2024-01-05", "value": 18},
|
||||
{"date": "2024-01-06", "value": 20},
|
||||
{"date": "2024-01-07", "value": 22}
|
||||
{"date": "2024-01-07", "value": 22},
|
||||
]
|
||||
trained = await manager.train_prediction_model(trend_model_id, historical_data)
|
||||
print(f" 训练完成,准确率: {trained.accuracy}")
|
||||
@@ -187,22 +181,18 @@ async def test_predictions(trend_model_id: str, anomaly_model_id: str):
|
||||
# 2. 趋势预测
|
||||
print("2. 趋势预测...")
|
||||
trend_result = await manager.predict(
|
||||
trend_model_id,
|
||||
{"historical_values": [10, 12, 15, 14, 18, 20, 22]}
|
||||
trend_model_id, {"historical_values": [10, 12, 15, 14, 18, 20, 22]}
|
||||
)
|
||||
print(f" 预测结果: {trend_result.prediction_data}")
|
||||
|
||||
# 3. 异常检测
|
||||
print("3. 异常检测...")
|
||||
anomaly_result = await manager.predict(
|
||||
anomaly_model_id,
|
||||
{
|
||||
"value": 50,
|
||||
"historical_values": [10, 12, 11, 13, 12, 14, 13]
|
||||
}
|
||||
anomaly_model_id, {"value": 50, "historical_values": [10, 12, 11, 13, 12, 14, 13]}
|
||||
)
|
||||
print(f" 检测结果: {anomaly_result.prediction_data}")
|
||||
|
||||
|
||||
def test_kg_rag():
|
||||
"""测试知识图谱 RAG"""
|
||||
print("\n=== 测试知识图谱 RAG ===")
|
||||
@@ -218,18 +208,10 @@ def test_kg_rag():
|
||||
description="基于项目知识图谱的智能问答",
|
||||
kg_config={
|
||||
"entity_types": ["PERSON", "ORG", "PROJECT", "TECH"],
|
||||
"relation_types": ["works_with", "belongs_to", "depends_on"]
|
||||
"relation_types": ["works_with", "belongs_to", "depends_on"],
|
||||
},
|
||||
retrieval_config={
|
||||
"top_k": 5,
|
||||
"similarity_threshold": 0.7,
|
||||
"expand_relations": True
|
||||
},
|
||||
generation_config={
|
||||
"temperature": 0.3,
|
||||
"max_tokens": 1000,
|
||||
"include_sources": True
|
||||
}
|
||||
retrieval_config={"top_k": 5, "similarity_threshold": 0.7, "expand_relations": True},
|
||||
generation_config={"temperature": 0.3, "max_tokens": 1000, "include_sources": True},
|
||||
)
|
||||
print(f" 创建成功: {rag.id}")
|
||||
|
||||
@@ -240,6 +222,7 @@ def test_kg_rag():
|
||||
|
||||
return rag.id
|
||||
|
||||
|
||||
async def test_kg_rag_query(rag_id: str):
|
||||
"""测试 RAG 查询"""
|
||||
print("\n=== 测试知识图谱 RAG 查询 ===")
|
||||
@@ -252,33 +235,43 @@ async def test_kg_rag_query(rag_id: str):
|
||||
{"id": "e2", "name": "李四", "type": "PERSON", "definition": "技术负责人"},
|
||||
{"id": "e3", "name": "Project Alpha", "type": "PROJECT", "definition": "核心产品项目"},
|
||||
{"id": "e4", "name": "Kubernetes", "type": "TECH", "definition": "容器编排平台"},
|
||||
{"id": "e5", "name": "TechCorp", "type": "ORG", "definition": "科技公司"}
|
||||
{"id": "e5", "name": "TechCorp", "type": "ORG", "definition": "科技公司"},
|
||||
]
|
||||
|
||||
project_relations = [{"source_entity_id": "e1",
|
||||
"target_entity_id": "e3",
|
||||
"source_name": "张三",
|
||||
"target_name": "Project Alpha",
|
||||
"relation_type": "works_with",
|
||||
"evidence": "张三负责 Project Alpha 的管理工作"},
|
||||
{"source_entity_id": "e2",
|
||||
"target_entity_id": "e3",
|
||||
"source_name": "李四",
|
||||
"target_name": "Project Alpha",
|
||||
"relation_type": "works_with",
|
||||
"evidence": "李四负责 Project Alpha 的技术架构"},
|
||||
{"source_entity_id": "e3",
|
||||
"target_entity_id": "e4",
|
||||
"source_name": "Project Alpha",
|
||||
"target_name": "Kubernetes",
|
||||
"relation_type": "depends_on",
|
||||
"evidence": "项目使用 Kubernetes 进行部署"},
|
||||
{"source_entity_id": "e1",
|
||||
"target_entity_id": "e5",
|
||||
"source_name": "张三",
|
||||
"target_name": "TechCorp",
|
||||
"relation_type": "belongs_to",
|
||||
"evidence": "张三是 TechCorp 的员工"}]
|
||||
project_relations = [
|
||||
{
|
||||
"source_entity_id": "e1",
|
||||
"target_entity_id": "e3",
|
||||
"source_name": "张三",
|
||||
"target_name": "Project Alpha",
|
||||
"relation_type": "works_with",
|
||||
"evidence": "张三负责 Project Alpha 的管理工作",
|
||||
},
|
||||
{
|
||||
"source_entity_id": "e2",
|
||||
"target_entity_id": "e3",
|
||||
"source_name": "李四",
|
||||
"target_name": "Project Alpha",
|
||||
"relation_type": "works_with",
|
||||
"evidence": "李四负责 Project Alpha 的技术架构",
|
||||
},
|
||||
{
|
||||
"source_entity_id": "e3",
|
||||
"target_entity_id": "e4",
|
||||
"source_name": "Project Alpha",
|
||||
"target_name": "Kubernetes",
|
||||
"relation_type": "depends_on",
|
||||
"evidence": "项目使用 Kubernetes 进行部署",
|
||||
},
|
||||
{
|
||||
"source_entity_id": "e1",
|
||||
"target_entity_id": "e5",
|
||||
"source_name": "张三",
|
||||
"target_name": "TechCorp",
|
||||
"relation_type": "belongs_to",
|
||||
"evidence": "张三是 TechCorp 的员工",
|
||||
},
|
||||
]
|
||||
|
||||
# 执行查询
|
||||
print("1. 执行 RAG 查询...")
|
||||
@@ -289,7 +282,7 @@ async def test_kg_rag_query(rag_id: str):
|
||||
rag_id=rag_id,
|
||||
query=query_text,
|
||||
project_entities=project_entities,
|
||||
project_relations=project_relations
|
||||
project_relations=project_relations,
|
||||
)
|
||||
|
||||
print(f" 查询: {result.query}")
|
||||
@@ -300,6 +293,7 @@ async def test_kg_rag_query(rag_id: str):
|
||||
except Exception as e:
|
||||
print(f" 查询失败: {e}")
|
||||
|
||||
|
||||
async def test_smart_summary():
|
||||
"""测试智能摘要"""
|
||||
print("\n=== 测试智能摘要 ===")
|
||||
@@ -321,8 +315,8 @@ async def test_smart_summary():
|
||||
{"name": "张三", "type": "PERSON"},
|
||||
{"name": "李四", "type": "PERSON"},
|
||||
{"name": "Project Alpha", "type": "PROJECT"},
|
||||
{"name": "Kubernetes", "type": "TECH"}
|
||||
]
|
||||
{"name": "Kubernetes", "type": "TECH"},
|
||||
],
|
||||
}
|
||||
|
||||
# 生成不同类型的摘要
|
||||
@@ -337,7 +331,7 @@ async def test_smart_summary():
|
||||
source_type="transcript",
|
||||
source_id="transcript_001",
|
||||
summary_type=summary_type,
|
||||
content_data=content_data
|
||||
content_data=content_data,
|
||||
)
|
||||
|
||||
print(f" 摘要类型: {summary.summary_type}")
|
||||
@@ -347,6 +341,7 @@ async def test_smart_summary():
|
||||
except Exception as e:
|
||||
print(f" 生成失败: {e}")
|
||||
|
||||
|
||||
async def main():
|
||||
"""主测试函数"""
|
||||
print("=" * 60)
|
||||
@@ -382,7 +377,9 @@ async def main():
|
||||
except Exception as e:
|
||||
print(f"\n测试失败: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
||||
Reference in New Issue
Block a user