image_processor.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. """图像预处理模块"""
  2. import cv2
  3. import numpy as np
  4. from typing import Optional, Tuple
  5. from config.config import ModelConfig
  6. class ImageProcessor:
  7. """图像预处理器"""
  8. def __init__(self, config: ModelConfig):
  9. self.config = config
  10. self.input_size = config.input_size
  11. self.mean = np.array(config.mean, dtype=np.float32)
  12. self.std = np.array(config.std, dtype=np.float32)
  13. # 初始化CUDA加速
  14. self.use_cuda = cv2.cuda.getCudaEnabledDeviceCount() > 0
  15. self.gpu_frame = cv2.cuda_GpuMat() if self.use_cuda else None
  16. # 存储原始图像尺寸
  17. self.orig_h = 0
  18. self.orig_w = 0
  19. def preprocess(self, image_path: str) -> Optional[np.ndarray]:
  20. """图像预处理
  21. Args:
  22. image_path: 图像文件路径
  23. Returns:
  24. 预处理后的图像数据,格式为NCHW
  25. """
  26. try:
  27. if self.use_cuda:
  28. return self._preprocess_gpu(image_path)
  29. else:
  30. return self._preprocess_cpu(image_path)
  31. except Exception as e:
  32. print(f"预处理图像 {image_path} 时出错: {e}")
  33. return None
  34. def _preprocess_gpu(self, image_path: str) -> Optional[np.ndarray]:
  35. """GPU加速预处理"""
  36. # 读取图像到GPU
  37. cpu_image = cv2.imread(image_path)
  38. if cpu_image is None:
  39. print(f"无法加载图像:{image_path}")
  40. return None
  41. self.orig_h, self.orig_w = cpu_image.shape[:2]
  42. self.gpu_frame.upload(cpu_image)
  43. if self.gpu_frame.empty():
  44. print(f"无法上传图像到GPU:{image_path}")
  45. return None
  46. # GPU预处理流水线
  47. gpu_resized = cv2.cuda.resize(self.gpu_frame, self.input_size)
  48. gpu_rgb = cv2.cuda.cvtColor(gpu_resized, cv2.COLOR_BGR2RGB)
  49. # 下载到CPU进行后续处理
  50. image_orig = gpu_rgb.download()
  51. return self._normalize_image(image_orig)
  52. def _preprocess_cpu(self, image_path: str) -> Optional[np.ndarray]:
  53. """CPU预处理"""
  54. image_orig = cv2.imread(image_path)
  55. if image_orig is None:
  56. print(f"无法加载图像:{image_path}")
  57. return None
  58. self.orig_h, self.orig_w = image_orig.shape[:2]
  59. image_orig = cv2.resize(image_orig, self.input_size)
  60. image_orig = cv2.cvtColor(image_orig, cv2.COLOR_BGR2RGB)
  61. return self._normalize_image(image_orig)
  62. def _normalize_image(self, image: np.ndarray) -> np.ndarray:
  63. """图像归一化"""
  64. # 归一化到[0,1]
  65. image = image.astype(np.float32) / 255.0
  66. # 标准化
  67. image -= self.mean[None, None, :]
  68. image /= self.std[None, None, :]
  69. # 转换为CHW格式
  70. image = np.transpose(image, (2, 0, 1))
  71. # 添加batch维度
  72. return np.expand_dims(image, axis=0)
  73. def get_original_size(self) -> Tuple[int, int]:
  74. """获取原始图像尺寸
  75. Returns:
  76. (width, height)
  77. """
  78. return self.orig_w, self.orig_h