inference.py

import numpy as np
import onnxruntime as ort
import cv2
import os
import argparse
import pickle
from glob import glob
from typing import List, Dict, Optional
import time
from datetime import datetime
from tqdm import tqdm
import sys
from report_generator import ReportGenerator
class ONNXDetector:
    def select_input_directory(self) -> Optional[str]:
        """Select the input directory via a GUI dialog."""
        import tkinter as tk
        from tkinter import filedialog
        root = tk.Tk()
        root.withdraw()
        folder_path = filedialog.askdirectory(title='Select input directory')
        return folder_path if folder_path else None
    def __init__(self, model_path: str = 'D:/PythonProject/Model/Data/models/250411_Anti_UAV.onnx', threshold: float = 0.5,
                 output_dir: str = None, save_empty: bool = False,
                 max_bbox_ratio: float = 0.5,
                 input_dir: str = None):
        self.image_count = 0
        self.detection_records = []
        self.model_path = model_path
        self.input_dir = input_dir
        self.save_empty = save_empty
        self.threshold = threshold
        self.confThreshold = threshold
        # Initialize the ONNX Runtime session
        so = ort.SessionOptions()
        so.log_severity_level = 3
        self.providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
        self.session = ort.InferenceSession(model_path, so, providers=self.providers)
        # Model parameters
        self.input_size = (640, 640)  # width, height
        self.mean_ = np.array([0.485, 0.456, 0.406], dtype=np.float32)
        self.std_ = np.array([0.229, 0.224, 0.225], dtype=np.float32)
        self.max_bbox_ratio = max_bbox_ratio
        # Parse the model version from the model file name
        self.model_version = os.path.basename(model_path).split('.')[0]
        # Input/output tensor names
        self.input_name = self.session.get_inputs()[0].name
        self.output_name = self.session.get_outputs()[0].name
        # Class-id to name mapping
        self.class_names = {0: 'UAV'}
        # Load the false-positive feature library
        self.false_positive_features = {}
        try:
            with open(os.path.join(os.path.dirname(__file__), 'false_positive_features.pkl'), 'rb') as f:
                self.false_positive_features = pickle.load(f)
        except FileNotFoundError:
            print("False-positive feature library not found, skipping")
        # Create the output directory
        output_base = os.path.join(os.path.dirname(__file__), '..', '..', 'Output')
        os.makedirs(output_base, exist_ok=True)
        timestamp = datetime.now().strftime('%m%d%H%M%S')
        if output_dir is None and input_dir is not None:
            self.output_dir = os.path.join(output_base, f"{os.path.basename(os.path.normpath(input_dir))}_{timestamp}")
        else:
            self.output_dir = os.path.join(output_base, output_dir or f"output_{timestamp}")
        # Create the directory for cropped targets
        self.targets_dir = os.path.join(self.output_dir, 'targets')
        os.makedirs(self.targets_dir, exist_ok=True)
    def preprocess(self, image_path: str) -> Optional[np.ndarray]:
        """Image preprocessing."""
        # Initialize CUDA acceleration lazily
        if not hasattr(self, 'use_cuda'):
            self.use_cuda = cv2.cuda.getCudaEnabledDeviceCount() > 0
            self.gpu_frame = cv2.cuda_GpuMat() if self.use_cuda else None
        # Read the image on the CPU first so it can be validated and its original size recorded
        image_bgr = cv2.imread(image_path)
        if image_bgr is None:
            print(f"Failed to load image: {image_path}")
            return None
        self.orig_h, self.orig_w = image_bgr.shape[:2]
        if self.use_cuda:
            # GPU preprocessing pipeline
            self.gpu_frame.upload(image_bgr)
            gpu_resized = cv2.cuda.resize(self.gpu_frame, self.input_size)
            gpu_rgb = cv2.cuda.cvtColor(gpu_resized, cv2.COLOR_BGR2RGB)
            # Download to the CPU for the remaining steps
            image_orig = gpu_rgb.download()
        else:
            # CPU fallback path
            image_orig = cv2.resize(image_bgr, self.input_size)
            image_orig = cv2.cvtColor(image_orig, cv2.COLOR_BGR2RGB)
        # Shared normalization steps
        image = image_orig.astype(np.float32) / 255.0
        image -= self.mean_[None, None, :]
        image /= self.std_[None, None, :]
        image = np.transpose(image, (2, 0, 1))  # CHW layout
        return np.expand_dims(image, axis=0)
    def nms(self, boxes: np.ndarray, scores: np.ndarray, conf_threshold: float, iou_threshold: float) -> List[int]:
        """Non-maximum suppression."""
        x1 = boxes[:, 0]
        y1 = boxes[:, 1]
        x2 = boxes[:, 2]
        y2 = boxes[:, 3]
        areas = (x2 - x1 + 1) * (y2 - y1 + 1)
        # Sort indices by descending score (flatten in case scores is an (N, 1) column)
        order = scores.flatten().argsort()[::-1]
        keep = []
        while order.size > 0:
            i = order[0]
            keep.append(i)
            xx1 = np.maximum(x1[i], x1[order[1:]])
            yy1 = np.maximum(y1[i], y1[order[1:]])
            xx2 = np.minimum(x2[i], x2[order[1:]])
            yy2 = np.minimum(y2[i], y2[order[1:]])
            w = np.maximum(0.0, xx2 - xx1 + 1)
            h = np.maximum(0.0, yy2 - yy1 + 1)
            inter = w * h
            ovr = inter / (areas[i] + areas[order[1:]] - inter)
            inds = np.where(ovr <= iou_threshold)[0]
            order = order[inds + 1]
        return keep
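    # A quick worked example of the NMS behaviour above (hypothetical values, not part of the
    # pipeline): with boxes [[0, 0, 10, 10], [1, 1, 11, 11], [50, 50, 60, 60]] and scores
    # [0.9, 0.8, 0.7], the second box overlaps the first with IoU = 100 / 142 ~ 0.70 > 0.4 and
    # is suppressed, while the third box has no overlap, so nms(...) returns the indices [0, 2].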
    def inference(self, input_data: np.ndarray) -> np.ndarray:
        """Run model inference."""
        scale_factor = np.array([[1, 1]], dtype=np.float32)
        return self.session.run(
            [self.output_name],
            {self.input_name: input_data, 'scale_factor': scale_factor}
        )[0]
    def _is_false_positive(self, roi: np.ndarray) -> bool:
        """Check a detection against the false-positive library using ORB feature matching."""
        if not self.false_positive_features:
            return False
        # Initialize the feature detector
        detector = cv2.ORB_create()
        _, des = detector.detectAndCompute(roi, None)
        # Match against every entry in the feature library
        for fp_feature in self.false_positive_features.values():
            if des is None or fp_feature['features'] is None:
                continue
            # Use a FLANN matcher with LSH indexing (algorithm=6), suitable for binary ORB descriptors
            flann = cv2.FlannBasedMatcher(dict(algorithm=6, table_number=6), dict())
            matches = flann.knnMatch(des, fp_feature['features'], k=2)
            # Skip when no matches were found
            if not matches:
                continue
            # Apply the ratio test, guarding against incomplete match groups
            good_matches = []
            for match_group in matches:
                # Ensure the group has at least two elements to avoid index errors
                if len(match_group) < 2:
                    continue
                m, n = match_group
                if m.distance < 0.7 * n.distance:
                    good_matches.append(m)
            # Enough good matches means the ROI matches a known false positive
            if len(good_matches) > 15:  # match-count threshold
                return True
        return False
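    # Note on the feature library consumed above: the matcher only assumes that
    # false_positive_features is a dict whose values are dicts carrying an ORB descriptor array
    # under the key 'features'. A minimal builder sketch under that assumption (hypothetical,
    # not part of this file):
    #   des = cv2.ORB_create().detectAndCompute(fp_image, None)[1]
    #   false_positive_features['some_id'] = {'features': des}
    #   with open('false_positive_features.pkl', 'wb') as f:
    #       pickle.dump(false_positive_features, f)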
    def postprocess(self, detections: np.ndarray, image_orig: np.ndarray, image_path: str) -> tuple:
        """Post-process the raw detections."""
        valid_detections = 0
        img_out = image_orig.copy()
        detections_list = []
        # Filter out low-confidence detections
        keep_idx = (detections[:, 1] > self.confThreshold)
        detections = detections[keep_idx]
        if len(detections) == 0:
            return 0, img_out, []
        # Rescale coordinates back to the original image size
        ratioh = self.orig_h / self.input_size[1]
        ratiow = self.orig_w / self.input_size[0]
        detections[:, 2:6] *= np.array([ratiow, ratioh, ratiow, ratioh])
        # Non-maximum suppression
        keep = self.nms(detections[:, 2:6], detections[:, 1:2], self.confThreshold, 0.4)
        for idx in keep:
            class_id = int(detections[idx, 0])
            confidence = detections[idx, 1]
            x1, y1, x2, y2 = detections[idx, 2:6].astype(int)
            # Skip boxes whose area ratio exceeds the configured maximum
            bbox_area = (x2 - x1) * (y2 - y1)
            image_area = self.orig_w * self.orig_h
            if bbox_area / image_area > self.max_bbox_ratio:
                continue
            # Crop the ROI for false-positive verification
            roi = image_orig[y1:y2, x1:x2]
            if roi.size == 0:
                continue
            if self._is_false_positive(roi):
                continue
            # Save the target crop under a unique file name
            target_filename = f"{os.path.splitext(os.path.basename(image_path))[0]}_{valid_detections}.jpg"
            cv2.imwrite(os.path.join(self.targets_dir, target_filename), roi)
            # Draw the bounding box and label
            label = f'{self.class_names[class_id]} {confidence:.2f}'
            cv2.rectangle(img_out, (x1, y1), (x2, y2), (255, 0, 0), 4)
            cv2.putText(img_out, label, (x1, y1 - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 255), 2)
            valid_detections += 1
            detections_list.append({
                'class_id': class_id,
                'confidence': float(confidence),
                'bbox': [x1, y1, x2, y2],
                'orig_w': self.orig_w,
                'orig_h': self.orig_h
            })
        return valid_detections, img_out, detections_list
    def process_image(self, image_path: str) -> int:
        """End-to-end processing of a single image."""
        input_data = self.preprocess(image_path)
        if input_data is None:
            return 0
        detections = self.inference(input_data)
        valid_count, processed_img, detections_list = self.postprocess(detections, cv2.imread(image_path), image_path)
        # Record information regardless of whether anything was detected
        record = {
            'image_path': image_path,
            'has_detection': valid_count > 0,
            'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        }
        filename = f"{os.path.splitext(os.path.basename(image_path))[0]}.jpg"
        # Create the output image directory
        output_img_dir = os.path.join(self.output_dir, 'imgs')
        os.makedirs(output_img_dir, exist_ok=True)
        output_path = os.path.join(output_img_dir, filename)
        # Save the image according to the configuration
        if valid_count > 0 or self.save_empty:
            cv2.imwrite(output_path, processed_img)
        else:
            output_path = None
        self.image_count += 1
        # Record detection details
        if valid_count > 0:
            record['detections'] = detections_list
        self.detection_records.append(record)
        return valid_count
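# Programmatic usage sketch (paths below are placeholders, adjust to your setup):
#   detector = ONNXDetector(model_path='path/to/250411_Anti_UAV.onnx',
#                           threshold=0.5, input_dir='path/to/images')
#   count = detector.process_image('path/to/images/frame_0001.jpg')
#   print(count, detector.detection_records[-1])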
# Command-line interface
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='UAV Detection using ONNX model')
    parser.add_argument('--input', type=str, help='Input image path or directory')
    parser.add_argument('--threshold', type=float, default=0.5, help='Detection confidence threshold')
    parser.add_argument('--output', type=str, default=None, help='Output directory; defaults to the input directory name plus a timestamp')
    parser.add_argument('--max-bbox-ratio', type=float, default=0.05,
                        help='Maximum bounding-box area ratio, default 0.05')
    parser.add_argument('--save-empty', action='store_true',
                        help='Also save images with no detections')
    parser.add_argument('--gui', action='store_true',
                        help='Select the input directory via a GUI dialog')
    args = parser.parse_args()
    # GUI mode
    if args.gui:
        detector = ONNXDetector(
            threshold=args.threshold,
            max_bbox_ratio=args.max_bbox_ratio,
            output_dir=args.output,
            save_empty=args.save_empty
        )
        input_dir = detector.select_input_directory()
        if not input_dir:
            print("No directory selected, exiting")
            sys.exit()
        args.input = input_dir
    # Validate input arguments
    if args.input is None:
        print("Error: provide an input path or use --gui to select a directory")
        parser.print_help()
        sys.exit(1)
    # Initialize the detector with the parsed arguments
    detector = ONNXDetector(
        threshold=args.threshold,
        max_bbox_ratio=args.max_bbox_ratio,
        output_dir=args.output,
        save_empty=args.save_empty,
        input_dir=args.input if args.input and os.path.isdir(args.input) else None
    )
    def process_single(image_path: str):
        detections = detector.process_image(image_path)
        print(f'Processed {os.path.basename(image_path)}, detected {detections} target(s)')
    if args.input and os.path.isdir(args.input):
        total = 0
        image_files = []
        for root, _, files in os.walk(args.input):
            image_files.extend([os.path.join(root, f) for f in files if f.lower().endswith('.jpg')])
        for img_file in tqdm(image_files, desc='Processing images'):
            total += detector.process_image(img_file)
        print(f'Batch processing finished, {total} target(s) detected in total')
        # Generate the CSV report
        csv_path = os.path.join(detector.output_dir, 'detection_report.csv')
        ReportGenerator(detector).generate_csv(csv_path)
        print(f'CSV report written to: {csv_path}')
    else:
        detections = detector.process_image(args.input)
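# Example invocations (paths are placeholders):
#   python inference.py --input path/to/images --threshold 0.5 --max-bbox-ratio 0.05
#   python inference.py --input path/to/single.jpg --save-empty
#   python inference.py --gui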