import argparse
import os
import pickle
import sys
import time
from datetime import datetime
from glob import glob
from typing import Dict, List, Optional, Tuple

import cv2
import numpy as np
import onnxruntime as ort
from tqdm import tqdm

from ..analysis.report_generator import ReportGenerator

# Default ONNX model used when no model path is given.
DEFAULT_MODEL_PATH = 'data/models/UAV-250321.onnx'

# Lower-case file extensions picked up when scanning an input directory.
IMAGE_EXTENSIONS = ('.jpg', '.jpeg')


class ONNXDetector:
    """UAV detector backed by an ONNX Runtime session.

    Runs the model on single images, filters detections by confidence,
    bounding-box area ratio and an optional ORB-based false-positive library,
    saves annotated images and target crops, and accumulates one record per
    processed image in ``detection_records`` (consumed by report generation).
    """

    @staticmethod
    def select_input_directory() -> Optional[str]:
        """Open a Tk folder picker and return the chosen directory.

        Declared static so it can be called before a detector (and its model)
        exists; ``detector.select_input_directory()`` still works.

        Returns:
            The selected directory path, or None if the dialog was cancelled.
        """
        import tkinter as tk
        from tkinter import filedialog

        root = tk.Tk()
        root.withdraw()
        try:
            folder_path = filedialog.askdirectory(title='选择输入目录')
        finally:
            # Release the hidden root window once the dialog has closed.
            root.destroy()
        return folder_path or None

    def __init__(self,
                 model_path: str = DEFAULT_MODEL_PATH,
                 threshold: float = 0.5,
                 output_dir: Optional[str] = None,
                 save_empty: bool = False,
                 max_bbox_ratio: float = 0.5,
                 input_dir: Optional[str] = None):
        """Load the ONNX model and prepare the output directories.

        Args:
            model_path: Path to the ONNX model file.
            threshold: Minimum confidence for a detection to be kept.
            output_dir: Directory for annotated images and crops. When None,
                it becomes ``<basename(input_dir)>_results`` if ``input_dir``
                is given, otherwise ``results``.
            save_empty: Also save images in which nothing was detected.
            max_bbox_ratio: Detections whose box covers more than this
                fraction of the image area are discarded.
            input_dir: Input directory; used here only to derive
                ``output_dir`` and stored as ``self.input_dir``.
        """
        # Earlier versions used the literal string "None" as "not set".
        if output_dir == "None":
            output_dir = None
        if input_dir == "None":
            input_dir = None

        self.image_count = 0
        self.detection_records = []
        self.model_path = model_path
        self.input_dir = input_dir

        # Prefer CUDA, but only request providers this onnxruntime build offers.
        available = ort.get_available_providers()
        self.providers = [
            p for p in ('CUDAExecutionProvider', 'CPUExecutionProvider')
            if p in available
        ] or ['CPUExecutionProvider']
        self.session = ort.InferenceSession(model_path, providers=self.providers)

        model_inputs = self.session.get_inputs()
        # Symbolic (dynamic) dims are replaced by 1. Indices 2/3 are used as H/W
        # below, i.e. an NCHW input is assumed.
        # NOTE(review): a dynamic H/W would yield a 1-pixel resize target - confirm
        # deployed models have fixed spatial dims.
        self.input_shape = tuple(
            dim if isinstance(dim, int) else 1 for dim in model_inputs[0].shape
        )
        self.input_name = model_inputs[0].name
        self._input_names = {inp.name for inp in model_inputs}
        self.output_name = self.session.get_outputs()[0].name

        # Version = text after the last '-' in the file name, up to the first '.',
        # e.g. "UAV-250321.onnx" -> "250321". Directory names are ignored.
        model_file = os.path.basename(model_path)
        self.model_version = (
            model_file.split('-')[-1].split('.')[0] if '-' in model_file else '1.0'
        )
        self.input_size = f"{self.input_shape[2]}x{self.input_shape[3]}"
        self.threshold = threshold
        self.max_bbox_ratio = max_bbox_ratio
        self.save_empty = save_empty

        # Class id -> display name.
        self.class_names = {0: 'UAV'}

        # Optional false-positive feature library stored next to this module.
        # SECURITY: pickle.load can execute arbitrary code; only ship trusted files.
        self.false_positive_features = {}
        fp_path = os.path.join(os.path.dirname(__file__), 'false_positive_features.pkl')
        try:
            with open(fp_path, 'rb') as f:
                self.false_positive_features = pickle.load(f)
        except FileNotFoundError:
            print("未找到误报特征库文件,跳过加载")

        # Created once and reused for every ROI checked in _is_false_positive.
        # algorithm=6 selects FLANN's LSH index, suitable for binary ORB descriptors.
        self._orb = cv2.ORB_create()
        self._fp_matcher = cv2.FlannBasedMatcher(dict(algorithm=6, table_number=6), dict())

        if output_dir is None:
            if input_dir is not None:
                output_dir = f"{os.path.basename(os.path.normpath(input_dir))}_results"
            else:
                output_dir = 'results'
        self.output_dir = output_dir

        # Cropped targets are stored in a subdirectory of the output directory.
        self.targets_dir = os.path.join(self.output_dir, 'targets')
        os.makedirs(self.targets_dir, exist_ok=True)

    def preprocess(self, image_path: str) -> Optional[np.ndarray]:
        """Load an image and convert it into the model's input tensor.

        Side effects: stores the decoded BGR image in ``self._orig_image`` and
        its ORIGINAL (pre-resize) size in ``self.orig_h`` / ``self.orig_w``,
        which postprocessing needs to map boxes back.

        Returns:
            A float32 array of shape (1, 3, H, W) with RGB values scaled to
            [0, 1], or None if the image cannot be read.
        """
        # Detect OpenCV CUDA support once per detector instance.
        if not hasattr(self, 'use_cuda'):
            self.use_cuda = cv2.cuda.getCudaEnabledDeviceCount() > 0
            self.gpu_frame = cv2.cuda_GpuMat() if self.use_cuda else None

        image_bgr = cv2.imread(image_path)
        if image_bgr is None:
            print(f"无法加载图像:{image_path}")
            return None
        self._orig_image = image_bgr
        self.orig_h, self.orig_w = image_bgr.shape[:2]

        target_size = (self.input_shape[3], self.input_shape[2])  # cv2 expects (w, h)
        if self.use_cuda:
            self.gpu_frame.upload(image_bgr)
            gpu_resized = cv2.cuda.resize(self.gpu_frame, target_size)
            image_rgb = cv2.cuda.cvtColor(gpu_resized, cv2.COLOR_BGR2RGB).download()
        else:
            image_rgb = cv2.cvtColor(cv2.resize(image_bgr, target_size), cv2.COLOR_BGR2RGB)

        tensor = image_rgb.astype(np.float32) / 255.0
        tensor = np.transpose(tensor, (2, 0, 1))  # HWC -> CHW
        return np.expand_dims(tensor, axis=0)

    def inference(self, input_data: np.ndarray) -> np.ndarray:
        """Run the model on ``input_data`` and return its first output."""
        feeds = {self.input_name: input_data}
        # Some exports (presumably PaddleDetection-style) take a 'scale_factor'
        # input. A factor of 1 keeps boxes in network-input pixels;
        # _convert_coordinates rescales them afterwards.
        if 'scale_factor' in self._input_names:
            feeds['scale_factor'] = np.array([[1, 1]], dtype=np.float32)
        return self.session.run([self.output_name], feeds)[0]

    def _is_false_positive(self, roi: np.ndarray) -> bool:
        """Return True if ``roi`` matches an entry of the false-positive library.

        ORB descriptors of the ROI are kNN-matched (k=2) against every stored
        descriptor set. More than 15 matches passing Lowe's ratio test (0.7)
        against a single entry count as a hit.
        """
        if not self.false_positive_features:
            return False

        _, des = self._orb.detectAndCompute(roi, None)
        if des is None:
            return False

        for fp_feature in self.false_positive_features.values():
            train_des = fp_feature['features']
            if train_des is None:
                continue
            matches = self._fp_matcher.knnMatch(des, train_des, k=2)
            # The LSH matcher may return fewer than two neighbours for a query.
            good_matches = sum(
                1 for pair in matches
                if len(pair) == 2 and pair[0].distance < 0.7 * pair[1].distance
            )
            if good_matches > 15:
                return True
        return False

    def postprocess(self, detections: np.ndarray, image_orig: np.ndarray,
                    image_path: str) -> Tuple[int, np.ndarray, List[Dict]]:
        """Filter raw detections, save target crops and draw annotations.

        Args:
            detections: Model output rows read as
                [class_id, confidence, x1, y1, x2, y2] in network-input pixels.
                This layout is inferred from the indexing below.
            image_orig: Original BGR image the detections belong to.
            image_path: Source image path; its stem names the saved crops.

        Returns:
            A tuple (kept detection count, annotated copy of ``image_orig``,
            list of per-detection dicts).
        """
        valid_detections = 0
        img_out = image_orig.copy()
        detections_list = []
        image_area = self.orig_w * self.orig_h
        stem = os.path.splitext(os.path.basename(image_path))[0]

        for det in detections:
            confidence = det[1]
            if confidence < self.threshold:
                continue
            class_id = int(det[0])

            x1, y1, x2, y2 = self._convert_coordinates(det[2:6])

            # Drop boxes covering more than max_bbox_ratio of the image.
            if (x2 - x1) * (y2 - y1) / image_area > self.max_bbox_ratio:
                continue

            roi = image_orig[y1:y2, x1:x2]
            if roi.size == 0 or self._is_false_positive(roi):
                continue

            # Timestamp plus per-image index keeps crop file names unique.
            timestamp = datetime.now().strftime('%Y%m%d%H%M%S%f')
            target_filename = f"{stem}_{timestamp}_{valid_detections}.jpg"
            cv2.imwrite(os.path.join(self.targets_dir, target_filename), roi)

            label = f'{self.class_names.get(class_id, str(class_id))} {confidence:.2f}'
            cv2.rectangle(img_out, (x1, y1), (x2, y2), (255, 0, 0), 4)
            # Keep the label inside the frame for boxes touching the top edge.
            cv2.putText(img_out, label, (x1, max(y1 - 10, 20)),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 255), 2)

            valid_detections += 1
            detections_list.append({
                'class_id': class_id,
                'confidence': float(confidence),
                'bbox': [x1, y1, x2, y2],
                'orig_w': self.orig_w,
                'orig_h': self.orig_h,
            })

        return valid_detections, img_out, detections_list

    def _convert_coordinates(self, coords: List[float]) -> tuple:
        """Scale network-input box coordinates to the original image size.

        Returns (x1, y1, x2, y2) as ints clamped to [0, size - 1].
        """
        scale_x = self.orig_w / self.input_shape[3]
        scale_y = self.orig_h / self.input_shape[2]
        x1 = int(coords[0] * scale_x)
        y1 = int(coords[1] * scale_y)
        x2 = int(coords[2] * scale_x)
        y2 = int(coords[3] * scale_y)

        def clamp(val: int, max_val: int) -> int:
            return max(0, min(val, max_val - 1))

        return (
            clamp(x1, self.orig_w),
            clamp(y1, self.orig_h),
            clamp(x2, self.orig_w),
            clamp(y2, self.orig_h),
        )

    def process_image(self, image_path: str) -> int:
        """Detect targets in one image, save the outputs and record the result.

        A record is appended to ``detection_records`` for every image that
        could be read, whether or not anything was detected.

        Returns:
            The number of kept detections (0 if the image cannot be read).
        """
        input_data = self.preprocess(image_path)
        if input_data is None:
            return 0

        detections = self.inference(input_data)
        # Reuse the image decoded by preprocess() instead of reading it again.
        valid_count, processed_img, detections_list = self.postprocess(
            detections, self._orig_image, image_path)

        record = {
            'image_path': image_path,
            'has_detection': valid_count > 0,
            'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
        }

        if valid_count > 0 or self.save_empty:
            # NOTE(review): output is named by file stem only, so same-named images
            # from different subdirectories overwrite each other.
            filename = f"{os.path.splitext(os.path.basename(image_path))[0]}.jpg"
            cv2.imwrite(os.path.join(self.output_dir, filename), processed_img)

        self.image_count += 1

        if valid_count > 0:
            record['detections'] = detections_list
        self.detection_records.append(record)
        return valid_count


def _collect_images(input_dir: str, output_dir: str) -> List[str]:
    """Recursively list image files under ``input_dir``, sorted, skipping ``output_dir``."""
    output_abs = os.path.abspath(output_dir)
    image_files = []
    for root, dirs, files in os.walk(input_dir):
        # Prune our own results directory so earlier outputs are not re-processed.
        dirs[:] = [d for d in dirs if os.path.abspath(os.path.join(root, d)) != output_abs]
        image_files.extend(
            os.path.join(root, f) for f in files if f.lower().endswith(IMAGE_EXTENSIONS)
        )
    return sorted(image_files)


def _write_reports(detector: ONNXDetector, report_format: str) -> None:
    """Write the CSV and/or Excel report into ``detector.output_dir``."""
    if report_format in ('csv', 'all'):
        csv_path = os.path.join(detector.output_dir, 'detection_report.csv')
        ReportGenerator(detector).generate_csv(csv_path)
        print(f'CSV报告已生成: {csv_path}')
    if report_format in ('excel', 'all'):
        excel_path = os.path.join(detector.output_dir, 'detection_report.xlsx')
        ReportGenerator(detector).generate_excel(excel_path)
        print(f'Excel报告已生成: {excel_path}')


def main() -> None:
    """Command-line entry point: detect UAVs in one image or a directory tree."""
    parser = argparse.ArgumentParser(description='UAV Detection using ONNX model')
    parser.add_argument('--input', type=str, help='输入图像路径或目录')
    parser.add_argument('--model', type=str, default=DEFAULT_MODEL_PATH, help='ONNX模型文件路径')
    parser.add_argument('--threshold', type=float, default=0.5, help='检测置信度阈值')
    parser.add_argument('--output', type=str, default=None, help='输出目录路径,默认为输入目录名+_results')
    parser.add_argument('--max-bbox-ratio', type=float, default=0.05, help='检测框最大面积比例阈值,默认0.05')
    parser.add_argument('--report', type=str, choices=['csv', 'excel', 'all'], default='all',
                        help='输出报告格式: csv|excel|all')
    parser.add_argument('--save-empty', action='store_true', help='是否保存未检测到目标的图片')
    parser.add_argument('--gui', action='store_true', help='启用图形界面选择输入目录')
    args = parser.parse_args()

    # GUI mode: pick the directory before loading the model.
    if args.gui:
        input_dir = ONNXDetector.select_input_directory()
        if not input_dir:
            print("未选择目录,程序退出")
            sys.exit()
        args.input = input_dir

    if not args.input:
        parser.error('请通过 --input 指定输入路径,或使用 --gui 选择目录')
    if not os.path.exists(args.input):
        parser.error(f'输入路径不存在: {args.input}')

    is_dir = os.path.isdir(args.input)
    detector = ONNXDetector(
        model_path=args.model,
        threshold=args.threshold,
        max_bbox_ratio=args.max_bbox_ratio,
        output_dir=args.output,
        save_empty=args.save_empty,
        input_dir=args.input if is_dir else None,
    )

    if is_dir:
        total = 0
        for img_file in tqdm(_collect_images(args.input, detector.output_dir),
                             desc='Processing images'):
            total += detector.process_image(img_file)
        print(f'批量处理完成!共检测到 {total} 个目标')
        _write_reports(detector, args.report)
    else:
        detections = detector.process_image(args.input)
        print(f'处理 {os.path.basename(args.input)} 完成,检测到 {detections} 个目标')


if __name__ == '__main__':
    main()