inference.py

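"""UAV detection inference script.

Runs an ONNX detection model over a single image or a directory of images,
filters detections by confidence, bounding-box area ratio and an optional
ORB-based false-positive feature library, then writes annotated images,
cropped targets and a CSV report to a timestamped output directory.
"""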
import numpy as np
import onnxruntime as ort
import cv2
import os
import argparse
import pickle
from glob import glob
from typing import List, Dict, Optional
import time
from datetime import datetime
from tqdm import tqdm
from ..analysis.report_generator import ReportGenerator


class ONNXDetector:
    def select_input_directory(self) -> Optional[str]:
        """Select the input directory via a GUI dialog."""
        import tkinter as tk
        from tkinter import filedialog
        root = tk.Tk()
        root.withdraw()
        folder_path = filedialog.askdirectory(title='Select input directory')
        return folder_path if folder_path else None

    def __init__(self, model_path: str = 'D:/PythonProject/Model/Data/models/250411_Anti_UAV.onnx', threshold: float = 0.5,
                 output_dir: str = None, save_empty: bool = False,
                 max_bbox_ratio: float = 0.5,
                 input_dir: str = None):
        self.image_count = 0
        self.detection_records = []
        self.model_path = model_path
        self.input_dir = input_dir
        self.save_empty = save_empty
        self.threshold = threshold
        self.confThreshold = threshold
        # Initialize the ONNX Runtime session
        so = ort.SessionOptions()
        so.log_severity_level = 3
        self.providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
        self.session = ort.InferenceSession(model_path, so, providers=self.providers)
        # Model parameters
        self.input_size = (640, 640)  # width, height
        # ImageNet mean/std used for input normalization
        self.mean_ = np.array([0.485, 0.456, 0.406], dtype=np.float32)
        self.std_ = np.array([0.229, 0.224, 0.225], dtype=np.float32)
        self.max_bbox_ratio = max_bbox_ratio
        # Parse the model version from the model filename
        self.model_version = os.path.basename(model_path).split('.')[0]
        # Input/output tensor names
        self.input_name = self.session.get_inputs()[0].name
        self.output_name = self.session.get_outputs()[0].name
        # Class-name mapping
        self.class_names = {0: 'UAV'}
        # Load the false-positive feature library
        self.false_positive_features = {}
        try:
            with open(os.path.join(os.path.dirname(__file__), 'false_positive_features.pkl'), 'rb') as f:
                self.false_positive_features = pickle.load(f)
        except FileNotFoundError:
            print("False-positive feature library not found, skipping")
        # Create the output directory
        output_base = os.path.join(os.path.dirname(__file__), '..', '..', 'Output')
        os.makedirs(output_base, exist_ok=True)
        timestamp = datetime.now().strftime('%m%d%H%M%S')
        if output_dir is None and input_dir is not None:
            self.output_dir = os.path.join(output_base, f"{os.path.basename(os.path.normpath(input_dir))}_{timestamp}")
        else:
            self.output_dir = os.path.join(output_base, output_dir or f"output_{timestamp}")
        # Create the directory for cropped targets
        self.targets_dir = os.path.join(self.output_dir, 'targets')
        os.makedirs(self.targets_dir, exist_ok=True)

    def preprocess(self, image_path: str) -> Optional[np.ndarray]:
        """Image preprocessing."""
        # Lazily initialize CUDA acceleration
        if not hasattr(self, 'use_cuda'):
            self.use_cuda = cv2.cuda.getCudaEnabledDeviceCount() > 0
            self.gpu_frame = cv2.cuda_GpuMat() if self.use_cuda else None
        if self.use_cuda:
            # Read the image and upload it to the GPU
            image_bgr = cv2.imread(image_path)
            if image_bgr is None:
                print(f"Failed to load image: {image_path}")
                return None
            # Record the original size before resizing so boxes can be mapped back
            self.orig_h, self.orig_w = image_bgr.shape[:2]
            self.gpu_frame.upload(image_bgr)
            # GPU preprocessing pipeline
            gpu_resized = cv2.cuda.resize(self.gpu_frame, self.input_size)
            gpu_rgb = cv2.cuda.cvtColor(gpu_resized, cv2.COLOR_BGR2RGB)
            # Download to the CPU for the remaining steps
            image_orig = gpu_rgb.download()
        else:
            # CPU fallback path
            image_orig = cv2.imread(image_path)
            if image_orig is None:
                print(f"Failed to load image: {image_path}")
                return None
            self.orig_h, self.orig_w = image_orig.shape[:2]
            image_orig = cv2.resize(image_orig, self.input_size)
            image_orig = cv2.cvtColor(image_orig, cv2.COLOR_BGR2RGB)
        # Shared normalization
        image = image_orig.astype(np.float32) / 255.0
        image -= self.mean_[None, None, :]
        image /= self.std_[None, None, :]
        image = np.transpose(image, (2, 0, 1))  # CHW format
        return np.expand_dims(image, axis=0)

    def nms(self, boxes: np.ndarray, scores: np.ndarray, conf_threshold: float, iou_threshold: float) -> List[int]:
        """Non-maximum suppression (conf_threshold is unused here; confidence filtering happens in postprocess)."""
        x1 = boxes[:, 0]
        y1 = boxes[:, 1]
        x2 = boxes[:, 2]
        y2 = boxes[:, 3]
        areas = (x2 - x1 + 1) * (y2 - y1 + 1)
        # scores arrives as an (N, 1) column; flatten before sorting in descending order
        order = scores.flatten().argsort()[::-1]
        keep = []
        while order.size > 0:
            i = order[0]
            keep.append(i)
            xx1 = np.maximum(x1[i], x1[order[1:]])
            yy1 = np.maximum(y1[i], y1[order[1:]])
            xx2 = np.minimum(x2[i], x2[order[1:]])
            yy2 = np.minimum(y2[i], y2[order[1:]])
            w = np.maximum(0.0, xx2 - xx1 + 1)
            h = np.maximum(0.0, yy2 - yy1 + 1)
            inter = w * h
            ovr = inter / (areas[i] + areas[order[1:]] - inter)
            inds = np.where(ovr <= iou_threshold)[0]
            order = order[inds + 1]
        return keep
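
    # Usage sketch for nms() with hypothetical values: two heavily overlapping
    # boxes, where the lower-scoring one exceeds the 0.4 IoU threshold and is
    # suppressed, leaving only index 0:
    #   self.nms(np.array([[0, 0, 10, 10], [1, 1, 11, 11]], dtype=np.float32),
    #            np.array([[0.9], [0.8]], dtype=np.float32), 0.5, 0.4)  # -> [0]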

    def inference(self, input_data: np.ndarray) -> np.ndarray:
        """Run model inference."""
        scale_factor = np.array([[1, 1]], dtype=np.float32)
        return self.session.run(
            [self.output_name],
            {self.input_name: input_data, 'scale_factor': scale_factor}
        )[0]
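
    # Note (assumption inferred from how postprocess() consumes the output): the
    # model is expected to return an (N, 6) array with rows laid out as
    # [class_id, confidence, x1, y1, x2, y2] in 640x640 input coordinates, and to
    # accept a 'scale_factor' input, as in PaddleDetection-style ONNX exports.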

    def _is_false_positive(self, roi: np.ndarray) -> bool:
        """Check whether a detection matches the false-positive library using ORB feature matching."""
        if not self.false_positive_features:
            return False
        # Initialize the feature detector
        detector = cv2.ORB_create()
        _, des = detector.detectAndCompute(roi, None)
        # Match against each entry in the feature library
        for fp_feature in self.false_positive_features.values():
            if des is None or fp_feature['features'] is None:
                continue
            # FLANN matcher with an LSH index (algorithm=6) for binary ORB descriptors
            flann = cv2.FlannBasedMatcher(dict(algorithm=6, table_number=6), dict())
            matches = flann.knnMatch(des, fp_feature['features'], k=2)
            # Skip when no matches are returned
            if not matches:
                continue
            # Validate match results defensively
            good_matches = []
            for match_group in matches:
                # Ensure the group has two elements to avoid index errors
                if len(match_group) < 2:
                    continue
                m, n = match_group
                if m.distance < 0.7 * n.distance:
                    good_matches.append(m)
            # Count high-quality matches
            if len(good_matches) > 15:  # match threshold
                return True
        return False

    def postprocess(self, detections: np.ndarray, image_orig: np.ndarray, image_path: str) -> tuple:
        """Post-process detection results."""
        valid_detections = 0
        img_out = image_orig.copy()
        detections_list = []
        # Filter out low-confidence detections
        keep_idx = (detections[:, 1] > self.confThreshold)
        detections = detections[keep_idx]
        if len(detections) == 0:
            return 0, img_out, []
        # Map box coordinates back to the original image size
        ratioh = self.orig_h / self.input_size[1]
        ratiow = self.orig_w / self.input_size[0]
        detections[:, 2:6] *= np.array([ratiow, ratioh, ratiow, ratioh])
        # Apply NMS
        keep = self.nms(detections[:, 2:6], detections[:, 1:2], self.confThreshold, 0.4)
        for idx in keep:
            class_id = int(detections[idx, 0])
            confidence = detections[idx, 1]
            x1, y1, x2, y2 = detections[idx, 2:6].astype(int)
            # Reject boxes covering too large a fraction of the image
            bbox_area = (x2 - x1) * (y2 - y1)
            image_area = self.orig_w * self.orig_h
            if bbox_area / image_area > self.max_bbox_ratio:
                continue
            # Crop the ROI for false-positive verification
            roi = image_orig[y1:y2, x1:x2]
            if roi.size == 0:
                continue
            if self._is_false_positive(roi):
                continue
            # Save the target crop under a unique filename
            target_filename = f"{os.path.splitext(os.path.basename(image_path))[0]}_{valid_detections}.jpg"
            cv2.imwrite(os.path.join(self.targets_dir, target_filename), roi)
            # Draw the bounding box and label
            label = f'{self.class_names[class_id]} {confidence:.2f}'
            cv2.rectangle(img_out, (x1, y1), (x2, y2), (255, 0, 0), 4)
            cv2.putText(img_out, label, (x1, y1 - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 255), 2)
            valid_detections += 1
            detections_list.append({
                'class_id': class_id,
                'confidence': float(confidence),
                'bbox': [x1, y1, x2, y2],
                'orig_w': self.orig_w,
                'orig_h': self.orig_h
            })
        return valid_detections, img_out, detections_list

    def process_image(self, image_path: str) -> int:
        """End-to-end processing of a single image."""
        input_data = self.preprocess(image_path)
        if input_data is None:
            return 0
        detections = self.inference(input_data)
        valid_count, processed_img, detections_list = self.postprocess(detections, cv2.imread(image_path), image_path)
        # Record information whether or not anything was detected
        record = {
            'image_path': image_path,
            'has_detection': valid_count > 0,
            'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        }
        filename = f"{os.path.splitext(os.path.basename(image_path))[0]}.jpg"
        # Create the output image directory
        output_img_dir = os.path.join(self.output_dir, 'imgs')
        os.makedirs(output_img_dir, exist_ok=True)
        output_path = os.path.join(output_img_dir, filename)
        # Save the annotated image according to the settings
        if valid_count > 0 or self.save_empty:
            cv2.imwrite(output_path, processed_img)
        else:
            output_path = None
        self.image_count += 1
        # Attach detection details when present
        if valid_count > 0:
            record['detections'] = detections_list
        self.detection_records.append(record)
        return valid_count
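
# Programmatic usage sketch (model and image paths below are placeholders):
#   detector = ONNXDetector(model_path='models/uav.onnx', input_dir='imgs/')
#   count = detector.process_image('imgs/frame_0001.jpg')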

# Command-line interface
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='UAV Detection using ONNX model')
    parser.add_argument('--input', type=str, help='Input image path or directory')
    parser.add_argument('--threshold', type=float, default=0.5, help='Detection confidence threshold')
    parser.add_argument('--output', type=str, default=None,
                        help='Output directory; defaults to the input directory name plus a timestamp')
    parser.add_argument('--max-bbox-ratio', type=float, default=0.05,
                        help='Maximum bounding-box area ratio, default 0.05')
    parser.add_argument('--save-empty', action='store_true',
                        help='Also save images with no detections')
    parser.add_argument('--gui', action='store_true',
                        help='Select the input directory via a GUI dialog')
    args = parser.parse_args()
    # GUI mode: ask for the input directory interactively
    if args.gui:
        detector = ONNXDetector(
            threshold=args.threshold,
            max_bbox_ratio=args.max_bbox_ratio,
            output_dir=args.output,
            save_empty=args.save_empty
        )
        input_dir = detector.select_input_directory()
        if not input_dir:
            print("No directory selected, exiting")
            exit()
        args.input = input_dir
    # Initialize the detector with the parsed arguments
    detector = ONNXDetector(
        threshold=args.threshold,
        max_bbox_ratio=args.max_bbox_ratio,
        output_dir=args.output,
        save_empty=args.save_empty,
        input_dir=args.input if os.path.isdir(args.input) else None
    )

    def process_single(image_path: str):
        detections = detector.process_image(image_path)
        print(f'Finished processing {os.path.basename(image_path)}, detected {detections} target(s)')

    if os.path.isdir(args.input):
        total = 0
        image_files = []
        for root, _, files in os.walk(args.input):
            image_files.extend([os.path.join(root, f) for f in files if f.lower().endswith('.jpg')])
        for img_file in tqdm(image_files, desc='Processing images'):
            total += detector.process_image(img_file)
        print(f'Batch processing finished, {total} target(s) detected in total')
        # Generate the CSV report
        csv_path = os.path.join(detector.output_dir, 'detection_report.csv')
        ReportGenerator(detector).generate_csv(csv_path)
        print(f'CSV report written to: {csv_path}')
    else:
        detections = detector.process_image(args.input)
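
# Example invocations (paths are placeholders; run in a context where the
# relative import of ReportGenerator resolves, e.g. as part of its package):
#   python inference.py --input D:/data/images --threshold 0.5 --save-empty
#   python inference.py --gui --max-bbox-ratio 0.1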