inference.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314
  1. import numpy as np
  2. import onnxruntime as ort
  3. import cv2
  4. import os
  5. import argparse
  6. import pickle
  7. from glob import glob
  8. from typing import List, Dict, Optional
  9. import time
  10. from datetime import datetime
  11. from tqdm import tqdm
  12. from ..analysis.report_generator import ReportGenerator
  13. class ONNXDetector:
  14. def select_input_directory(self) -> Optional[str]:
  15. """图形界面选择输入目录"""
  16. import tkinter as tk
  17. from tkinter import filedialog
  18. root = tk.Tk()
  19. root.withdraw()
  20. folder_path = filedialog.askdirectory(title='选择输入目录')
  21. return folder_path if folder_path else None
  22. def __init__(self, model_path: str = 'data/models/UAV-250321.onnx', threshold: float = 0.5,
  23. output_dir: str = "None", save_empty: bool = False,
  24. max_bbox_ratio: float = 0.5, # 已有该参数
  25. input_dir: str = "None"):
  26. self.image_count = 0
  27. self.detection_records = []
  28. self.model_path = model_path
  29. self.input_dir = input_dir
  30. # 初始化ONNX会话
  31. self.providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
  32. self.session = ort.InferenceSession(model_path, providers=self.providers)
  33. # 获取模型输入尺寸
  34. input_dims = self.session.get_inputs()[0].shape
  35. self.input_shape = tuple([dim if isinstance(dim, int) else 1 for dim in input_dims])
  36. # 从模型路径解析版本号
  37. self.model_version = model_path.split('-')[-1].split('.')[0] if '-' in model_path else '1.0'
  38. self.input_size = f"{self.input_shape[2]}x{self.input_shape[3]}"
  39. self.threshold = threshold
  40. self.max_bbox_ratio = max_bbox_ratio # 已有该赋值
  41. self.save_empty = save_empty
  42. self.input_name = self.session.get_inputs()[0].name
  43. self.output_name = self.session.get_outputs()[0].name
  44. # 类别映射
  45. self.class_names = {0: 'UAV'}
  46. # 加载误报特征库
  47. self.false_positive_features = {}
  48. try:
  49. with open(os.path.join(os.path.dirname(__file__), 'false_positive_features.pkl'), 'rb') as f:
  50. self.false_positive_features = pickle.load(f)
  51. except FileNotFoundError:
  52. print("未找到误报特征库文件,跳过加载")
  53. # 创建输出目录
  54. if output_dir is None and input_dir is not None:
  55. self.output_dir = f"{os.path.basename(os.path.normpath(input_dir))}_results"
  56. else:
  57. self.output_dir = output_dir
  58. # 创建目标存储目录
  59. self.targets_dir = os.path.join(self.output_dir, 'targets')
  60. os.makedirs(self.targets_dir, exist_ok=True)
  61. def preprocess(self, image_path: str) -> Optional[np.ndarray]:
  62. """图像预处理"""
  63. # 初始化CUDA加速
  64. if not hasattr(self, 'use_cuda'):
  65. self.use_cuda = cv2.cuda.getCudaEnabledDeviceCount() > 0
  66. self.gpu_frame = cv2.cuda_GpuMat() if self.use_cuda else None
  67. # 读取图像到GPU(如果可用)
  68. if self.use_cuda:
  69. self.gpu_frame.upload(cv2.imread(image_path))
  70. if self.gpu_frame.empty():
  71. print(f"无法加载图像:{image_path}")
  72. return None
  73. # GPU预处理流水线
  74. gpu_resized = cv2.cuda.resize(self.gpu_frame, (self.input_shape[3], self.input_shape[2]))
  75. gpu_rgb = cv2.cuda.cvtColor(gpu_resized, cv2.COLOR_BGR2RGB)
  76. # 下载到CPU进行后续处理
  77. image_orig = gpu_rgb.download()
  78. self.orig_h, self.orig_w = image_orig.shape[:2]
  79. else:
  80. # CPU回退路径
  81. image_orig = cv2.imread(image_path)
  82. if image_orig is None:
  83. print(f"无法加载图像:{image_path}")
  84. return None
  85. self.orig_h, self.orig_w = image_orig.shape[:2]
  86. image_orig = cv2.resize(image_orig, (self.input_shape[3], self.input_shape[2]))
  87. image_orig = cv2.cvtColor(image_orig, cv2.COLOR_BGR2RGB)
  88. # 统一的后处理
  89. image = image_orig.astype(np.float32) / 255.0
  90. image = np.transpose(image, (2, 0, 1)) # CHW 格式
  91. return np.expand_dims(image, axis=0)
  92. def inference(self, input_data: np.ndarray) -> np.ndarray:
  93. """执行模型推理"""
  94. scale_factor = np.array([[1, 1]], dtype=np.float32)
  95. return self.session.run(
  96. [self.output_name],
  97. {self.input_name: input_data, 'scale_factor': scale_factor}
  98. )[0]
  99. def _is_false_positive(self, roi: np.ndarray) -> bool:
  100. """使用ORB特征匹配验证是否误报"""
  101. if not self.false_positive_features:
  102. return False
  103. # 初始化特征检测器
  104. detector = cv2.ORB_create()
  105. _, des = detector.detectAndCompute(roi, None)
  106. # 与特征库进行匹配
  107. for fp_feature in self.false_positive_features.values():
  108. if des is None or fp_feature['features'] is None:
  109. continue
  110. # 使用FLANN匹配器
  111. flann = cv2.FlannBasedMatcher(dict(algorithm=6, table_number=6), dict())
  112. matches = flann.knnMatch(des, fp_feature['features'], k=2)
  113. # 处理空匹配情况
  114. if not matches:
  115. continue
  116. # 安全验证匹配结果
  117. good_matches = []
  118. for match_group in matches:
  119. # 确保match_group有足够元素防止索引错误
  120. if len(match_group) < 2:
  121. continue
  122. m, n = match_group
  123. if m.distance < 0.7 * n.distance:
  124. good_matches.append(m)
  125. # 计算优质匹配数量
  126. if len(good_matches) > 15: # 匹配阈值
  127. return True
  128. return False
  129. def postprocess(self, detections: np.ndarray, image_orig: np.ndarray, image_path: str) -> int:
  130. """后处理检测结果"""
  131. valid_detections = 0
  132. img_out = image_orig.copy()
  133. detections_list = []
  134. for det in detections:
  135. class_id = int(det[0])
  136. confidence = det[1]
  137. if confidence < self.threshold:
  138. continue
  139. # 坐标转换和边界检查
  140. x1, y1, x2, y2 = self._convert_coordinates(det[2:6])
  141. # 计算检测框面积比例
  142. bbox_area = (x2 - x1) * (y2 - y1)
  143. image_area = self.orig_w * self.orig_h
  144. if bbox_area / image_area > self.max_bbox_ratio:
  145. continue
  146. # 截取ROI区域进行误报验证
  147. roi = image_orig[y1:y2, x1:x2]
  148. if roi.size == 0:
  149. continue
  150. if self._is_false_positive(roi):
  151. continue
  152. # 生成唯一文件名并保存目标
  153. timestamp = datetime.now().strftime('%Y%m%d%H%M%S%f')
  154. target_filename = f"{os.path.splitext(os.path.basename(image_path))[0]}_{timestamp}_{valid_detections}.jpg"
  155. cv2.imwrite(os.path.join(self.targets_dir, target_filename), roi)
  156. # 绘制检测框和标签
  157. label = f'{self.class_names[class_id]} {confidence:.2f}'
  158. cv2.rectangle(img_out, (x1, y1), (x2, y2), (255, 0, 0), 4)
  159. cv2.putText(img_out, label, (x1, y1 - 10),
  160. cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 255), 2)
  161. valid_detections += 1
  162. detections_list.append({
  163. 'class_id': class_id,
  164. 'confidence': float(confidence),
  165. 'bbox': [x1, y1, x2, y2],
  166. 'orig_w': self.orig_w,
  167. 'orig_h': self.orig_h
  168. })
  169. return valid_detections, img_out, detections_list
  170. def _convert_coordinates(self, coords: List[float]) -> tuple:
  171. """将模型输出坐标转换为原始图像尺寸"""
  172. x1 = int(coords[0] * self.orig_w / self.input_shape[3])
  173. y1 = int(coords[1] * self.orig_h / self.input_shape[2])
  174. x2 = int(coords[2] * self.orig_w / self.input_shape[3])
  175. y2 = int(coords[3] * self.orig_h / self.input_shape[2])
  176. # 边界检查
  177. clamp = lambda val, max_val: max(0, min(val, max_val - 1))
  178. return (
  179. clamp(x1, self.orig_w),
  180. clamp(y1, self.orig_h),
  181. clamp(x2, self.orig_w),
  182. clamp(y2, self.orig_h)
  183. )
  184. def process_image(self, image_path: str) -> int:
  185. """端到端处理单张图像"""
  186. input_data = self.preprocess(image_path)
  187. if input_data is None:
  188. return 0
  189. detections = self.inference(input_data)
  190. valid_count, processed_img, detections_list = self.postprocess(detections, cv2.imread(image_path), image_path)
  191. # 无论是否检测到目标都记录信息
  192. record = {
  193. 'image_path': image_path,
  194. 'has_detection': valid_count > 0,
  195. 'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
  196. }
  197. filename = f"{os.path.splitext(os.path.basename(image_path))[0]}.jpg"
  198. output_path = os.path.join(self.output_dir, filename)
  199. # 根据设置保存图片
  200. if valid_count > 0 or self.save_empty:
  201. cv2.imwrite(output_path, processed_img)
  202. else:
  203. output_path = None
  204. self.image_count += 1
  205. # 记录检测信息
  206. if valid_count > 0:
  207. record['detections'] = detections_list
  208. self.detection_records.append(record)
  209. return valid_count
  210. # 命令行接口
  211. if __name__ == '__main__':
  212. parser = argparse.ArgumentParser(description='UAV Detection using ONNX model')
  213. parser.add_argument('--input', type=str, help='输入图像路径或目录')
  214. parser.add_argument('--threshold', type=float, default=0.5, help='检测置信度阈值')
  215. parser.add_argument('--output', type=str, default=None, help='输出目录路径,默认为输入目录名+_results')
  216. parser.add_argument('--max-bbox-ratio', type=float, default=0.05,
  217. help='检测框最大面积比例阈值,默认0.05')
  218. parser.add_argument('--report', type=str, choices=['csv', 'excel', 'all'], default='all',
  219. help='输出报告格式: csv|excel|all')
  220. parser.add_argument('--save-empty', action='store_true',
  221. help='是否保存未检测到目标的图片')
  222. parser.add_argument('--gui', action='store_true',
  223. help='启用图形界面选择输入目录')
  224. args = parser.parse_args()
  225. # GUI模式处理
  226. if args.gui:
  227. detector = ONNXDetector(
  228. threshold=args.threshold,
  229. max_bbox_ratio=args.max_bbox_ratio,
  230. output_dir=args.output,
  231. save_empty=args.save_empty
  232. )
  233. input_dir = detector.select_input_directory()
  234. if not input_dir:
  235. print("未选择目录,程序退出")
  236. exit()
  237. args.input = input_dir
  238. # 初始化检测器时传递参数
  239. detector = ONNXDetector(
  240. threshold=args.threshold,
  241. max_bbox_ratio=args.max_bbox_ratio, # 添加该参数传递
  242. output_dir=args.output,
  243. save_empty=args.save_empty,
  244. input_dir=args.input if os.path.isdir(args.input) else None
  245. )
  246. def process_single(image_path: str):
  247. detections = detector.process_image(image_path)
  248. print(f'处理 {os.path.basename(image_path)} 完成,检测到 {detections} 个目标')
  249. if os.path.isdir(args.input):
  250. total = 0
  251. image_files = []
  252. for root, _, files in os.walk(args.input):
  253. image_files.extend([os.path.join(root, f) for f in files if f.lower().endswith('.jpg')])
  254. for img_file in tqdm(image_files, desc='Processing images'):
  255. total += detector.process_image(img_file)
  256. print(f'批量处理完成!共检测到 {total} 个目标')
  257. # 生成检测报告
  258. if args.report in ('csv', 'all'):
  259. csv_path = os.path.join(detector.output_dir, 'detection_report.csv')
  260. ReportGenerator(detector).generate_csv(csv_path)
  261. print(f'CSV报告已生成: {csv_path}')
  262. if args.report in ('excel', 'all'):
  263. excel_path = os.path.join(detector.output_dir, 'detection_report.xlsx')
  264. ReportGenerator(detector).generate_excel(excel_path)
  265. print(f'Excel报告已生成: {excel_path}')
  266. else:
  267. detections = detector.process_image(args.input)