# test_inference.py (1.5 KB)

  1. import pytest
  2. import os
  3. import cv2
  4. import numpy as np
  5. from src.core.inference import InferenceEngine
  6. class TestInference:
  7. @pytest.fixture
  8. def inference_engine(self):
  9. model_path = os.path.join('data', 'models', 'UAV-250321.onnx')
  10. return InferenceEngine(model_path)
  11. def test_single_image_inference(self, inference_engine):
  12. # 创建一个测试图像
  13. test_image = np.zeros((640, 640, 3), dtype=np.uint8)
  14. test_image[100:200, 100:200] = 255 # 添加一个白色方块
  15. # 执行推理
  16. results = inference_engine.infer(test_image)
  17. # 验证结果格式
  18. assert isinstance(results, list)
  19. for result in results:
  20. assert 'bbox' in result
  21. assert 'confidence' in result
  22. assert 'class_id' in result
  23. def test_batch_inference(self, inference_engine):
  24. # 创建测试图像列表
  25. test_images = [
  26. np.zeros((640, 640, 3), dtype=np.uint8) for _ in range(3)
  27. ]
  28. # 执行批量推理
  29. results = inference_engine.batch_infer(test_images)
  30. # 验证结果
  31. assert len(results) == len(test_images)
  32. assert all(isinstance(r, list) for r in results)
  33. def test_invalid_input(self, inference_engine):
  34. # 测试无效输入
  35. with pytest.raises(ValueError):
  36. inference_engine.infer(None)
  37. with pytest.raises(ValueError):
  38. inference_engine.infer(np.zeros((100, 100), dtype=np.uint8)) # 灰度图像