|
@@ -0,0 +1,47 @@
|
|
|
+import pytest
|
|
|
+import os
|
|
|
+import cv2
|
|
|
+import numpy as np
|
|
|
+from src.core.inference import InferenceEngine
|
|
|
+
|
|
|
+class TestInference:
|
|
|
+ @pytest.fixture
|
|
|
+ def inference_engine(self):
|
|
|
+ model_path = os.path.join('data', 'models', 'UAV-250321.onnx')
|
|
|
+ return InferenceEngine(model_path)
|
|
|
+
|
|
|
+ def test_single_image_inference(self, inference_engine):
|
|
|
+ # 创建一个测试图像
|
|
|
+ test_image = np.zeros((640, 640, 3), dtype=np.uint8)
|
|
|
+ test_image[100:200, 100:200] = 255 # 添加一个白色方块
|
|
|
+
|
|
|
+ # 执行推理
|
|
|
+ results = inference_engine.infer(test_image)
|
|
|
+
|
|
|
+ # 验证结果格式
|
|
|
+ assert isinstance(results, list)
|
|
|
+ for result in results:
|
|
|
+ assert 'bbox' in result
|
|
|
+ assert 'confidence' in result
|
|
|
+ assert 'class_id' in result
|
|
|
+
|
|
|
+ def test_batch_inference(self, inference_engine):
|
|
|
+ # 创建测试图像列表
|
|
|
+ test_images = [
|
|
|
+ np.zeros((640, 640, 3), dtype=np.uint8) for _ in range(3)
|
|
|
+ ]
|
|
|
+
|
|
|
+ # 执行批量推理
|
|
|
+ results = inference_engine.batch_infer(test_images)
|
|
|
+
|
|
|
+ # 验证结果
|
|
|
+ assert len(results) == len(test_images)
|
|
|
+ assert all(isinstance(r, list) for r in results)
|
|
|
+
|
|
|
+ def test_invalid_input(self, inference_engine):
|
|
|
+ # 测试无效输入
|
|
|
+ with pytest.raises(ValueError):
|
|
|
+ inference_engine.infer(None)
|
|
|
+
|
|
|
+ with pytest.raises(ValueError):
|
|
|
+ inference_engine.infer(np.zeros((100, 100), dtype=np.uint8)) # 灰度图像
|