1234567891011121314151617181920212223242526272829303132333435363738394041424344454647 |
- import pytest
- import os
- import cv2
- import numpy as np
- from src.core.inference import InferenceEngine
class TestInference:
    """Tests for ``InferenceEngine``: single-image inference, batch inference,
    and rejection of invalid inputs.

    All tests share one engine built from ``data/models/UAV-250321.onnx``,
    resolved relative to the directory pytest is launched from.
    """

    @pytest.fixture(scope="class")
    def inference_engine(self):
        """Build one ``InferenceEngine`` shared by every test in this class.

        Class scope avoids reloading the ONNX model for each test. If the
        model file is absent, the dependent tests are skipped with an explicit
        reason. This avoids depending on how the engine reacts to a missing
        file.
        """
        model_path = os.path.join('data', 'models', 'UAV-250321.onnx')
        if not os.path.isfile(model_path):
            pytest.skip(f"ONNX model not found at {model_path!r}")
        return InferenceEngine(model_path)

    def test_single_image_inference(self, inference_engine):
        """``infer`` on one 640x640 3-channel uint8 image returns a list whose
        records each expose ``bbox``, ``confidence`` and ``class_id``.

        NOTE: if the engine returns no records for this synthetic image, the
        per-record key checks are vacuous and only the list type is verified.
        """
        # Black canvas with a single white 100x100 square as a trivial target.
        test_image = np.zeros((640, 640, 3), dtype=np.uint8)
        test_image[100:200, 100:200] = 255

        results = inference_engine.infer(test_image)

        # Verify the container type and the required keys of every record.
        assert isinstance(results, list)
        for result in results:
            for key in ('bbox', 'confidence', 'class_id'):
                assert key in result, f"result missing {key!r}: {result!r}"

    def test_batch_inference(self, inference_engine):
        """``batch_infer`` returns exactly one list of results per input
        image."""
        test_images = [
            np.zeros((640, 640, 3), dtype=np.uint8) for _ in range(3)
        ]

        results = inference_engine.batch_infer(test_images)

        assert len(results) == len(test_images)
        assert all(isinstance(r, list) for r in results)

    @pytest.mark.parametrize(
        "bad_input",
        [
            None,
            # Single-channel (grayscale) 2-D array instead of an HxWx3 image.
            np.zeros((100, 100), dtype=np.uint8),
        ],
        ids=["none", "grayscale"],
    )
    def test_invalid_input(self, inference_engine, bad_input):
        """``infer`` raises ``ValueError`` for a missing image (``None``) and
        for a grayscale 2-D array.

        Each case is a separate parametrized test, so one failing case cannot
        mask the other.
        """
        with pytest.raises(ValueError):
            inference_engine.infer(bad_input)
|