123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314 |
- import numpy as np
- import onnxruntime as ort
- import cv2
- import os
- import argparse
- import pickle
- from glob import glob
- from typing import List, Dict, Optional
- import time
- from datetime import datetime
- from tqdm import tqdm
- from report_generator import ReportGenerator
class ONNXDetector:
    """UAV detector that runs an ONNX model through ONNX Runtime.

    For every processed image the detector filters the raw model output by
    confidence and relative box size, optionally rejects known false
    positives via ORB feature matching, saves each accepted target crop to
    ``<output_dir>/targets`` and appends a record to ``detection_records``.
    """

    # Directory used when neither an output nor an input directory is given.
    DEFAULT_OUTPUT_DIR = 'results'
    # OpenCV's FLANN index id for locality-sensitive hashing (binary descriptors).
    FLANN_INDEX_LSH = 6
    # A match counts as "good" when it is this much closer than the runner-up.
    FP_RATIO = 0.7
    # More good matches than this against one library entry => false positive.
    FP_MIN_GOOD_MATCHES = 15

    def __init__(self, model_path: str = 'UAV-250321.onnx', threshold: float = 0.5,
                 output_dir: Optional[str] = None, save_empty: bool = False,
                 max_bbox_ratio: float = 0.5,
                 input_dir: Optional[str] = None):
        """Load the model and prepare the output directories.

        Args:
            model_path: Path of the ONNX model file.
            threshold: Minimum confidence for a detection to be kept.
            output_dir: Directory for annotated images. When unset it is
                derived from ``input_dir`` (``<name>_results``) or falls back
                to ``DEFAULT_OUTPUT_DIR``.
            save_empty: Also save images in which nothing was detected.
            max_bbox_ratio: Detections whose box covers a larger fraction of
                the image than this are discarded.
            input_dir: Directory the images come from, if any.

        The literal string ``"None"`` (the former default of ``output_dir``
        and ``input_dir``) is still accepted and treated as "not set".
        """
        if self._is_unset(input_dir):
            input_dir = None

        self.image_count = 0
        self.detection_records = []
        self.model_path = model_path
        self.input_dir = input_dir
        self.threshold = threshold
        self.max_bbox_ratio = max_bbox_ratio
        self.save_empty = save_empty

        # Create the ONNX Runtime session.
        self.providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
        self.session = ort.InferenceSession(model_path, providers=self.providers)

        model_inputs = self.session.get_inputs()
        self.input_name = model_inputs[0].name
        self.output_name = self.session.get_outputs()[0].name
        # Only feed 'scale_factor' when the model actually declares that input.
        self._feeds_scale_factor = any(inp.name == 'scale_factor' for inp in model_inputs)

        # Symbolic (non-integer) dimensions are replaced by 1.
        self.input_shape = tuple(
            dim if isinstance(dim, int) else 1 for dim in model_inputs[0].shape
        )
        self.input_size = f"{self.input_shape[2]}x{self.input_shape[3]}"
        self.model_version = self._parse_model_version(model_path)

        # Class id -> display name.
        self.class_names = {0: 'UAV'}
        self.false_positive_features = self._load_false_positive_features()

        self.output_dir = self._resolve_output_dir(output_dir, input_dir)
        # Creating the nested targets directory also creates output_dir.
        self.targets_dir = os.path.join(self.output_dir, 'targets')
        os.makedirs(self.targets_dir, exist_ok=True)

    @staticmethod
    def _is_unset(value) -> bool:
        """Return True for ``None`` and for the legacy ``"None"`` placeholder."""
        return value is None or value == 'None'

    @staticmethod
    def _parse_model_version(model_path: str) -> str:
        """Return the text between the last '-' and the first following '.'.

        Only the file name is inspected, so dashes or dots in directory names
        cannot leak into the version. Without a dash the version is '1.0'.
        """
        file_name = os.path.basename(model_path)
        if '-' not in file_name:
            return '1.0'
        return file_name.split('-')[-1].split('.')[0]

    @staticmethod
    def _resolve_output_dir(output_dir: Optional[str], input_dir: Optional[str]) -> str:
        """Pick the output directory from the explicit value or the input dir."""
        if not ONNXDetector._is_unset(output_dir):
            return output_dir
        if not ONNXDetector._is_unset(input_dir):
            return f"{os.path.basename(os.path.normpath(input_dir))}_results"
        return ONNXDetector.DEFAULT_OUTPUT_DIR

    @staticmethod
    def _load_false_positive_features() -> dict:
        """Load the false-positive feature library stored next to this file.

        Returns an empty dict when the file does not exist.
        """
        features_path = os.path.join(os.path.dirname(__file__), 'false_positive_features.pkl')
        try:
            with open(features_path, 'rb') as f:
                # NOTE(security): pickle.load can execute arbitrary code; this
                # file must only ever come from a trusted source.
                return pickle.load(f)
        except FileNotFoundError:
            print("未找到误报特征库文件,跳过加载")
            return {}

    def select_input_directory(self) -> Optional[str]:
        """Ask the user for an input directory with a native dialog.

        Returns:
            The chosen path, or None when the dialog was cancelled.
        """
        import tkinter as tk
        from tkinter import filedialog

        root = tk.Tk()
        root.withdraw()  # show only the dialog, not an empty main window
        try:
            folder_path = filedialog.askdirectory(title='选择输入目录')
        finally:
            root.destroy()  # release the hidden Tk root
        return folder_path if folder_path else None

    def _cuda_available(self) -> bool:
        """Detect OpenCV CUDA support once and cache it on the instance."""
        if not hasattr(self, 'use_cuda'):
            cuda = getattr(cv2, 'cuda', None)
            self.use_cuda = cuda is not None and cuda.getCudaEnabledDeviceCount() > 0
            self.gpu_frame = cv2.cuda_GpuMat() if self.use_cuda else None
        return self.use_cuda

    def _prepare_input(self, image_bgr: np.ndarray) -> np.ndarray:
        """Convert a decoded BGR image into the model's input tensor.

        Records the original size in ``orig_h`` / ``orig_w`` before resizing,
        then resizes to the model input size, converts to RGB, scales to
        [0, 1] and returns a float32 array laid out as (1, C, H, W).
        """
        self.orig_h, self.orig_w = image_bgr.shape[:2]
        target_size = (self.input_shape[3], self.input_shape[2])  # (width, height)

        if self._cuda_available():
            self.gpu_frame.upload(image_bgr)
            gpu_resized = cv2.cuda.resize(self.gpu_frame, target_size)
            rgb = cv2.cuda.cvtColor(gpu_resized, cv2.COLOR_BGR2RGB).download()
        else:
            rgb = cv2.cvtColor(cv2.resize(image_bgr, target_size), cv2.COLOR_BGR2RGB)

        chw = np.transpose(rgb.astype(np.float32) / 255.0, (2, 0, 1))
        return np.expand_dims(chw, axis=0)

    def preprocess(self, image_path: str) -> Optional[np.ndarray]:
        """Read an image from disk and build the model input tensor.

        Returns:
            The (1, C, H, W) float32 tensor, or None if the image cannot be read.
        """
        image = cv2.imread(image_path)
        if image is None:
            print(f"无法加载图像:{image_path}")
            return None
        return self._prepare_input(image)

    def inference(self, input_data: np.ndarray) -> np.ndarray:
        """Run the model and return its first output.

        A ``scale_factor`` of [[1, 1]] is supplied when the model declares
        that input (presumably so boxes come back in network-input pixels;
        confirm against the model export).
        """
        feed = {self.input_name: input_data}
        if self._feeds_scale_factor:
            feed['scale_factor'] = np.array([[1, 1]], dtype=np.float32)
        return self.session.run([self.output_name], feed)[0]

    def _is_false_positive(self, roi: np.ndarray) -> bool:
        """Check a crop against the false-positive library using ORB features.

        Returns True as soon as one library entry yields more than
        ``FP_MIN_GOOD_MATCHES`` matches that pass the ratio test.
        """
        if not self.false_positive_features:
            return False

        _, descriptors = cv2.ORB_create().detectAndCompute(roi, None)
        if descriptors is None:
            # No features in the crop, so nothing can match.
            return False

        matcher = cv2.FlannBasedMatcher(
            dict(algorithm=self.FLANN_INDEX_LSH, table_number=6), dict()
        )
        for entry in self.false_positive_features.values():
            reference = entry['features']
            if reference is None:
                continue
            matches = matcher.knnMatch(descriptors, reference, k=2)
            # The matcher may return fewer than two neighbours; skip those.
            good_count = sum(
                1 for pair in matches
                if len(pair) == 2 and pair[0].distance < self.FP_RATIO * pair[1].distance
            )
            if good_count > self.FP_MIN_GOOD_MATCHES:
                return True
        return False

    def postprocess(self, detections: np.ndarray, image_orig: np.ndarray, image_path: str) -> tuple:
        """Filter raw detections, save target crops and draw the results.

        Each detection row is read as ``[class_id, confidence, x1, y1, x2, y2]``.

        Args:
            detections: Raw model output, one row per detection.
            image_orig: The original (unresized) BGR image.
            image_path: Source path; its stem prefixes the crop file names.

        Returns:
            ``(valid_count, annotated_image, detections_list)`` where
            ``detections_list`` holds one dict per accepted detection.
        """
        # Take the size from the image itself so box scaling is always
        # relative to the true original dimensions.
        self.orig_h, self.orig_w = image_orig.shape[:2]
        image_area = self.orig_w * self.orig_h
        stem = os.path.splitext(os.path.basename(image_path))[0]

        valid_detections = 0
        img_out = image_orig.copy()
        detections_list = []

        for det in detections:
            class_id = int(det[0])
            confidence = det[1]
            if confidence < self.threshold:
                continue

            x1, y1, x2, y2 = self._convert_coordinates(det[2:6])

            # Drop boxes that cover too much of the image.
            bbox_area = (x2 - x1) * (y2 - y1)
            if bbox_area / image_area > self.max_bbox_ratio:
                continue

            roi = image_orig[y1:y2, x1:x2]
            if roi.size == 0:
                continue
            if self._is_false_positive(roi):
                continue

            # Timestamp plus running index keeps crop names unique.
            timestamp = datetime.now().strftime('%Y%m%d%H%M%S%f')
            target_filename = f"{stem}_{timestamp}_{valid_detections}.jpg"
            cv2.imwrite(os.path.join(self.targets_dir, target_filename), roi)

            # Fall back to the numeric id for classes without a name.
            class_name = self.class_names.get(class_id, str(class_id))
            label = f'{class_name} {confidence:.2f}'
            cv2.rectangle(img_out, (x1, y1), (x2, y2), (255, 0, 0), 4)
            cv2.putText(img_out, label, (x1, y1 - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 255), 2)

            valid_detections += 1
            detections_list.append({
                'class_id': class_id,
                'confidence': float(confidence),
                'bbox': [x1, y1, x2, y2],
                'orig_w': self.orig_w,
                'orig_h': self.orig_h
            })

        return valid_detections, img_out, detections_list

    def _convert_coordinates(self, coords: List[float]) -> tuple:
        """Scale a box from model-input pixels to original-image pixels.

        Each coordinate is truncated to int and clamped to the valid pixel
        index range of the original image.

        Returns:
            ``(x1, y1, x2, y2)`` as ints.
        """
        in_h, in_w = self.input_shape[2], self.input_shape[3]
        limits = (self.orig_w, self.orig_h, self.orig_w, self.orig_h)
        scaled = (
            int(coords[0] * self.orig_w / in_w),
            int(coords[1] * self.orig_h / in_h),
            int(coords[2] * self.orig_w / in_w),
            int(coords[3] * self.orig_h / in_h),
        )
        return tuple(max(0, min(value, limit - 1)) for value, limit in zip(scaled, limits))

    def process_image(self, image_path: str) -> int:
        """Run the full pipeline on one image file.

        The annotated image is written to ``output_dir`` when at least one
        target was found or ``save_empty`` is set, and a record is appended
        to ``detection_records`` either way.

        Returns:
            The number of accepted detections (0 if the image cannot be read).
        """
        # Decode once and reuse the image for both inference and drawing.
        image = cv2.imread(image_path)
        if image is None:
            print(f"无法加载图像:{image_path}")
            return 0

        detections = self.inference(self._prepare_input(image))
        valid_count, processed_img, detections_list = self.postprocess(
            detections, image, image_path
        )

        record = {
            'image_path': image_path,
            'has_detection': valid_count > 0,
            'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        }

        if valid_count > 0 or self.save_empty:
            filename = f"{os.path.splitext(os.path.basename(image_path))[0]}.jpg"
            cv2.imwrite(os.path.join(self.output_dir, filename), processed_img)

        # Counts every successfully processed image.
        self.image_count += 1

        if valid_count > 0:
            record['detections'] = detections_list

        self.detection_records.append(record)
        return valid_count
# Command-line interface
def build_arg_parser() -> argparse.ArgumentParser:
    """Create the argument parser for the detection script."""
    parser = argparse.ArgumentParser(description='UAV Detection using ONNX model')
    parser.add_argument('--input', type=str, help='输入图像路径或目录')
    parser.add_argument('--threshold', type=float, default=0.5, help='检测置信度阈值')
    parser.add_argument('--output', type=str, default=None, help='输出目录路径,默认为输入目录名+_results')
    parser.add_argument('--max-bbox-ratio', type=float, default=0.05,
                        help='检测框最大面积比例阈值,默认0.05')
    parser.add_argument('--report', type=str, choices=['csv', 'excel', 'all'], default='all',
                        help='输出报告格式: csv|excel|all')
    parser.add_argument('--save-empty', action='store_true',
                        help='是否保存未检测到目标的图片')
    parser.add_argument('--gui', action='store_true',
                        help='启用图形界面选择输入目录')
    return parser


def collect_image_files(directory: str) -> List[str]:
    """Return every ``.jpg`` file (case-insensitive) found under ``directory``."""
    return [
        os.path.join(root, name)
        for root, _, files in os.walk(directory)
        for name in files
        if name.lower().endswith('.jpg')
    ]


def main() -> None:
    """Parse the command line, run detection and write the reports."""
    parser = build_arg_parser()
    args = parser.parse_args()

    if args.gui:
        # The dialog does not use any detector state, so call it unbound
        # instead of loading the model just to pick a directory.
        input_dir = ONNXDetector.select_input_directory(None)
        if not input_dir:
            print("未选择目录,程序退出")
            raise SystemExit(0)
        args.input = input_dir
    elif not args.input:
        # Fail with a usage message instead of a TypeError further down.
        parser.error('必须通过 --input 指定输入路径,或使用 --gui 选择目录')

    is_directory = os.path.isdir(args.input)
    detector = ONNXDetector(
        threshold=args.threshold,
        max_bbox_ratio=args.max_bbox_ratio,
        output_dir=args.output,
        save_empty=args.save_empty,
        input_dir=args.input if is_directory else None
    )

    if not is_directory:
        # Single-image mode: process the file and report the result.
        detections = detector.process_image(args.input)
        print(f'处理 {os.path.basename(args.input)} 完成,检测到 {detections} 个目标')
        return

    total = 0
    for img_file in tqdm(collect_image_files(args.input), desc='Processing images'):
        total += detector.process_image(img_file)
    print(f'批量处理完成!共检测到 {total} 个目标')

    # Write the detection reports.
    if args.report in ('csv', 'all'):
        csv_path = os.path.join(detector.output_dir, 'detection_report.csv')
        ReportGenerator(detector).generate_csv(csv_path)
        print(f'CSV报告已生成: {csv_path}')
    if args.report in ('excel', 'all'):
        excel_path = os.path.join(detector.output_dir, 'detection_report.xlsx')
        ReportGenerator(detector).generate_excel(excel_path)
        print(f'Excel报告已生成: {excel_path}')


if __name__ == '__main__':
    main()
|