|
@@ -0,0 +1,348 @@
|
|
|
|
+import numpy as np
|
|
|
|
+import onnxruntime as ort
|
|
|
|
+import cv2
|
|
|
|
+import os
|
|
|
|
+import argparse
|
|
|
|
+import pickle
|
|
|
|
+from glob import glob
|
|
|
|
+from typing import List, Dict, Optional
|
|
|
|
+import time
|
|
|
|
+from datetime import datetime
|
|
|
|
+from tqdm import tqdm
|
|
|
|
+import sys
|
|
|
|
+
|
|
|
|
+from report_generator import ReportGenerator
|
|
|
|
class ONNXDetector:
    """Run an ONNX UAV-detection model on images and save annotated results.

    The pipeline for one image is ``preprocess`` -> ``inference`` ->
    ``postprocess``; ``process_image`` chains them, writes the annotated
    image / cropped targets to the output directory and appends a record to
    ``detection_records``.
    """

    def select_input_directory(self) -> Optional[str]:
        """Show a folder-selection dialog and return the chosen directory.

        Returns:
            The selected directory path, or ``None`` when the dialog is
            cancelled.
        """
        # Imported lazily so that non-GUI runs do not need tkinter.
        import tkinter as tk
        from tkinter import filedialog

        root = tk.Tk()
        root.withdraw()  # hide the empty root window, show only the dialog
        try:
            folder_path = filedialog.askdirectory(title='选择输入目录')
        finally:
            # Release the hidden root window once the dialog is closed.
            root.destroy()
        return folder_path if folder_path else None

    def __init__(self, model_path: str = 'D:/PythonProject/Model/Data/models/250411_Anti_UAV.onnx', threshold: float = 0.5,
                 output_dir: Optional[str] = None, save_empty: bool = False,
                 max_bbox_ratio: float = 0.5,
                 input_dir: Optional[str] = None):
        """Create the inference session and prepare the output directories.

        Args:
            model_path: Path of the ONNX model file.
            threshold: Minimum confidence for a detection to be kept.
            output_dir: Output directory name/path joined onto the ``Output``
                base directory; when ``None`` a timestamped name is generated.
            save_empty: Whether images without detections are also saved.
            max_bbox_ratio: Detections whose box area divided by the image
                area exceeds this value are discarded.
            input_dir: Input directory; when given (and ``output_dir`` is
                ``None``) its base name is used for the output directory name.
        """
        self.image_count = 0
        self.detection_records = []
        self.model_path = model_path
        self.input_dir = input_dir
        self.save_empty = save_empty
        self.threshold = threshold
        self.confThreshold = threshold

        # Initialise the ONNX Runtime session.
        so = ort.SessionOptions()
        so.log_severity_level = 3
        self.providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
        self.session = ort.InferenceSession(model_path, so, providers=self.providers)

        # Model parameters.
        self.input_size = (640, 640)  # width, height
        self.mean_ = np.array([0.485, 0.456, 0.406], dtype=np.float32)
        self.std_ = np.array([0.229, 0.224, 0.225], dtype=np.float32)
        self.max_bbox_ratio = max_bbox_ratio

        # Derive the model version from the model file name.
        self.model_version = os.path.basename(model_path).split('.')[0]

        # Names of the first model input and output.
        self.input_name = self.session.get_inputs()[0].name
        self.output_name = self.session.get_outputs()[0].name

        # Class id -> label mapping.
        self.class_names = {0: 'UAV'}

        # Load the false-positive feature library (optional).
        # NOTE(security): pickle.load can execute arbitrary code; the feature
        # file must come from a trusted source.
        self.false_positive_features = {}
        try:
            with open(os.path.join(os.path.dirname(__file__), 'false_positive_features.pkl'), 'rb') as f:
                self.false_positive_features = pickle.load(f)
        except FileNotFoundError:
            print("未找到误报特征库文件,跳过加载")

        # Create the base output directory.
        output_base = os.path.join(os.path.dirname(__file__), '..', '..', 'Output')
        os.makedirs(output_base, exist_ok=True)

        timestamp = datetime.now().strftime('%m%d%H%M%S')
        if output_dir is None and input_dir is not None:
            self.output_dir = os.path.join(output_base, f"{os.path.basename(os.path.normpath(input_dir))}_{timestamp}")
        else:
            self.output_dir = os.path.join(output_base, output_dir or f"output_{timestamp}")

        # Directory that receives the cropped targets.
        self.targets_dir = os.path.join(self.output_dir, 'targets')
        os.makedirs(self.targets_dir, exist_ok=True)

    def preprocess(self, image_path: str) -> Optional[np.ndarray]:
        """Load an image and convert it into the model input tensor.

        The image is resized to ``input_size``, converted BGR -> RGB, scaled
        to [0, 1], normalised with ``mean_`` / ``std_`` and returned in
        ``(1, C, H, W)`` layout. ``orig_h`` / ``orig_w`` are set to the size
        of the image as loaded, before resizing.

        Args:
            image_path: Path of the image file.

        Returns:
            The float32 input tensor, or ``None`` if the image cannot be read.
        """
        # Detect CUDA support once and cache the decision on the instance.
        if not hasattr(self, 'use_cuda'):
            self.use_cuda = cv2.cuda.getCudaEnabledDeviceCount() > 0
            self.gpu_frame = cv2.cuda_GpuMat() if self.use_cuda else None

        # Read on the CPU first: imread returns None on failure and that must
        # be caught before anything is uploaded to the GPU.
        image_bgr = cv2.imread(image_path)
        if image_bgr is None:
            print(f"无法加载图像:{image_path}")
            return None

        # Record the source size BEFORE resizing; postprocess relies on it to
        # map boxes back to the original image.
        self.orig_h, self.orig_w = image_bgr.shape[:2]

        if self.use_cuda:
            # GPU pipeline: upload, resize and colour-convert, then download
            # for the remaining CPU steps.
            self.gpu_frame.upload(image_bgr)
            gpu_resized = cv2.cuda.resize(self.gpu_frame, self.input_size)
            gpu_rgb = cv2.cuda.cvtColor(gpu_resized, cv2.COLOR_BGR2RGB)
            image_rgb = gpu_rgb.download()
        else:
            # CPU fallback.
            image_rgb = cv2.resize(image_bgr, self.input_size)
            image_rgb = cv2.cvtColor(image_rgb, cv2.COLOR_BGR2RGB)

        # Shared normalisation.
        image = image_rgb.astype(np.float32) / 255.0
        image -= self.mean_[None, None, :]
        image /= self.std_[None, None, :]
        image = np.transpose(image, (2, 0, 1))  # HWC -> CHW
        return np.expand_dims(image, axis=0)

    def nms(self, boxes: np.ndarray, scores: np.ndarray, conf_threshold: float, iou_threshold: float) -> List[int]:
        """Greedy non-maximum suppression.

        Args:
            boxes: ``(N, 4)`` array of ``x1, y1, x2, y2`` boxes.
            scores: N scores; any shape holding N values (``(N,)``,
                ``(N, 1)`` or ``(1, N)``) is accepted.
            conf_threshold: Unused; kept for interface compatibility
                (callers filter by confidence beforehand).
            iou_threshold: Boxes overlapping a kept box with IoU above this
                value are suppressed.

        Returns:
            Indices of the kept boxes, highest score first.
        """
        x1 = boxes[:, 0]
        y1 = boxes[:, 1]
        x2 = boxes[:, 2]
        y2 = boxes[:, 3]
        areas = (x2 - x1 + 1) * (y2 - y1 + 1)

        # Flatten before sorting: argsort on an (N, 1) column sorts along the
        # length-1 axis and would yield only index 0.
        order = np.asarray(scores).reshape(-1).argsort()[::-1]
        keep = []

        while order.size > 0:
            i = order[0]
            keep.append(int(i))
            rest = order[1:]
            xx1 = np.maximum(x1[i], x1[rest])
            yy1 = np.maximum(y1[i], y1[rest])
            xx2 = np.minimum(x2[i], x2[rest])
            yy2 = np.minimum(y2[i], y2[rest])
            w = np.maximum(0.0, xx2 - xx1 + 1)
            h = np.maximum(0.0, yy2 - yy1 + 1)
            inter = w * h
            ovr = inter / (areas[i] + areas[rest] - inter)
            # Keep only candidates that do not overlap the chosen box too
            # much; +1 re-aligns indices of `rest` with `order`.
            inds = np.where(ovr <= iou_threshold)[0]
            order = order[inds + 1]
        return keep

    def inference(self, input_data: np.ndarray) -> np.ndarray:
        """Run the model on a preprocessed input tensor.

        Args:
            input_data: Tensor produced by ``preprocess``.

        Returns:
            The first model output.
        """
        # A scale factor of 1 is fed to the model; boxes are rescaled to the
        # source image in postprocess instead.
        scale_factor = np.array([[1, 1]], dtype=np.float32)
        return self.session.run(
            [self.output_name],
            {self.input_name: input_data, 'scale_factor': scale_factor}
        )[0]

    def _is_false_positive(self, roi: np.ndarray) -> bool:
        """Check a cropped region against the false-positive feature library.

        ORB descriptors of ``roi`` are matched (FLANN, ratio test 0.7) against
        every library entry; more than 15 good matches with any entry marks
        the region as a false positive.

        Args:
            roi: Cropped image region of a detection.

        Returns:
            ``True`` if the region matches a known false positive.
        """
        if not self.false_positive_features:
            return False

        # Compute descriptors of the candidate region.
        detector = cv2.ORB_create()
        _, des = detector.detectAndCompute(roi, None)
        if des is None:
            # No descriptors -> nothing can match any library entry.
            return False

        # FLANN matcher (algorithm 6 = LSH index, for binary descriptors).
        flann = cv2.FlannBasedMatcher(dict(algorithm=6, table_number=6), dict())

        for fp_feature in self.false_positive_features.values():
            if fp_feature['features'] is None:
                continue

            matches = flann.knnMatch(des, fp_feature['features'], k=2)
            if not matches:
                continue

            # Ratio test; groups with fewer than two neighbours are skipped.
            good_matches = []
            for match_group in matches:
                if len(match_group) < 2:
                    continue
                m, n = match_group
                if m.distance < 0.7 * n.distance:
                    good_matches.append(m)

            if len(good_matches) > 15:  # match-count threshold
                return True
        return False

    def postprocess(self, detections: np.ndarray, image_orig: np.ndarray, image_path: str) -> tuple:
        """Filter raw detections, save crops and draw the surviving boxes.

        Args:
            detections: Model output; rows are read as
                ``class_id, score, x1, y1, x2, y2`` in model-input coordinates.
            image_orig: Source image (BGR) the boxes are mapped onto.
            image_path: Path of the source image, used to name target crops.

        Returns:
            ``(valid_count, annotated_image, detections_list)`` where
            ``detections_list`` holds one dict per kept detection.
        """
        valid_detections = 0
        img_out = image_orig.copy()
        detections_list = []

        # Drop low-confidence detections (boolean indexing returns a copy).
        detections = detections[detections[:, 1] > self.confThreshold]
        if len(detections) == 0:
            return 0, img_out, []

        # Map boxes from model-input coordinates to the source image.
        img_h, img_w = image_orig.shape[:2]
        ratioh = img_h / self.input_size[1]
        ratiow = img_w / self.input_size[0]
        detections[:, 2:6] *= np.array([ratiow, ratioh, ratiow, ratioh])

        keep = self.nms(detections[:, 2:6], detections[:, 1], self.confThreshold, 0.4)
        image_area = img_w * img_h
        image_stem = os.path.splitext(os.path.basename(image_path))[0]

        for idx in keep:
            class_id = int(detections[idx, 0])
            confidence = detections[idx, 1]
            x1, y1, x2, y2 = (int(v) for v in detections[idx, 2:6].astype(int))

            # Clamp to the image: negative coordinates would otherwise wrap
            # around when slicing and crop the wrong region.
            x1 = min(max(x1, 0), img_w)
            x2 = min(max(x2, 0), img_w)
            y1 = min(max(y1, 0), img_h)
            y2 = min(max(y2, 0), img_h)

            # Reject boxes covering too large a share of the image.
            bbox_area = (x2 - x1) * (y2 - y1)
            if bbox_area / image_area > self.max_bbox_ratio:
                continue

            # Crop the region and run the false-positive check.
            roi = image_orig[y1:y2, x1:x2]
            if roi.size == 0:
                continue
            if self._is_false_positive(roi):
                continue

            # Save the cropped target under a per-image unique name.
            target_filename = f"{image_stem}_{valid_detections}.jpg"
            cv2.imwrite(os.path.join(self.targets_dir, target_filename), roi)

            # Draw the box and its label.
            class_label = self.class_names.get(class_id, str(class_id))
            label = f'{class_label} {confidence:.2f}'
            cv2.rectangle(img_out, (x1, y1), (x2, y2), (255, 0, 0), 4)
            cv2.putText(img_out, label, (x1, y1 - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 255), 2)
            valid_detections += 1
            detections_list.append({
                'class_id': class_id,
                'confidence': float(confidence),
                'bbox': [x1, y1, x2, y2],
                'orig_w': img_w,
                'orig_h': img_h
            })

        return valid_detections, img_out, detections_list

    def process_image(self, image_path: str) -> int:
        """Process a single image end to end.

        Args:
            image_path: Path of the image file.

        Returns:
            Number of valid detections (0 if the image cannot be loaded).
        """
        input_data = self.preprocess(image_path)
        if input_data is None:
            return 0

        detections = self.inference(input_data)
        valid_count, processed_img, detections_list = self.postprocess(detections, cv2.imread(image_path), image_path)

        # A record is kept whether or not anything was detected.
        record = {
            'image_path': image_path,
            'has_detection': valid_count > 0,
            'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        }

        # Save the annotated image according to the save_empty setting.
        if valid_count > 0 or self.save_empty:
            filename = f"{os.path.splitext(os.path.basename(image_path))[0]}.jpg"
            output_img_dir = os.path.join(self.output_dir, 'imgs')
            os.makedirs(output_img_dir, exist_ok=True)
            cv2.imwrite(os.path.join(output_img_dir, filename), processed_img)

        # NOTE(review): the original indentation was lost; this is assumed to
        # count every processed image, not only unsaved ones — confirm.
        self.image_count += 1

        # Attach detection details when present.
        if valid_count > 0:
            record['detections'] = detections_list

        self.detection_records.append(record)
        return valid_count
|
|
|
|
+
|
|
|
|
# Command-line interface
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='UAV Detection using ONNX model')
    parser.add_argument('--input', type=str, help='输入图像路径或目录')
    parser.add_argument('--threshold', type=float, default=0.5, help='检测置信度阈值')
    parser.add_argument('--output', type=str, default=None, help='输出目录路径,默认为输入目录名+_results')
    parser.add_argument('--max-bbox-ratio', type=float, default=0.05,
                        help='检测框最大面积比例阈值,默认0.05')
    parser.add_argument('--save-empty', action='store_true',
                        help='是否保存未检测到目标的图片')
    parser.add_argument('--gui', action='store_true',
                        help='启用图形界面选择输入目录')
    args = parser.parse_args()

    # GUI mode: let the user pick the input directory.
    if args.gui:
        # select_input_directory() reads no instance state, so it is called
        # unbound. Building a throwaway detector here would load the model a
        # second time and leave an empty timestamped output directory behind.
        input_dir = ONNXDetector.select_input_directory(None)
        if not input_dir:
            print("未选择目录,程序退出")
            sys.exit()
        args.input = input_dir

    # Validate the input argument.
    if args.input is None:
        print("错误:请提供输入路径或使用 --gui 参数选择目录")
        parser.print_help()
        sys.exit(1)

    # Build the detector; input_dir is only passed for directory inputs so
    # that the output directory is named after it.
    input_is_dir = os.path.isdir(args.input)
    detector = ONNXDetector(
        threshold=args.threshold,
        max_bbox_ratio=args.max_bbox_ratio,
        output_dir=args.output,
        save_empty=args.save_empty,
        input_dir=args.input if input_is_dir else None
    )

    def process_single(image_path: str):
        """Process one image and print how many targets were found."""
        detections = detector.process_image(image_path)
        print(f'处理 {os.path.basename(image_path)} 完成,检测到 {detections} 个目标')

    if input_is_dir:
        # Batch mode: collect every .jpg below the input directory.
        total = 0
        image_files = []
        for root, _, files in os.walk(args.input):
            image_files.extend([os.path.join(root, f) for f in files if f.lower().endswith('.jpg')])
        for img_file in tqdm(image_files, desc='Processing images'):
            total += detector.process_image(img_file)
        print(f'批量处理完成!共检测到 {total} 个目标')

        # Generate the CSV report.
        csv_path = os.path.join(detector.output_dir, 'detection_report.csv')
        ReportGenerator(detector).generate_csv(csv_path)
        print(f'CSV报告已生成: {csv_path}')
    else:
        # Single-image mode: report the result instead of finishing silently.
        process_single(args.input)
|