@@ -9,7 +9,9 @@ from typing import List, Dict, Optional
import time
from datetime import datetime
from tqdm import tqdm
-from ..analysis.report_generator import ReportGenerator
+from src.analysis.report_generator import ReportGenerator
+import onnx
+
class ONNXDetector:
    def select_input_directory(self) -> Optional[str]:
        """Select the input directory via a GUI dialog."""
@@ -20,23 +22,38 @@ class ONNXDetector:
        folder_path = filedialog.askdirectory(title='选择输入目录')
        return folder_path if folder_path else None

-    def __init__(self, model_path: str = 'D:/PythonProject/Model/Data/models/250411_Anti_UAV.onnx', threshold: float = 0.5,
+    def __init__(self, threshold: float = 0.5,
                 output_dir: str = None, save_empty: bool = False,
                 max_bbox_ratio: float = 0.5,
-                 input_dir: str = None):
+                 input_dir: str = None, model_type: str = None, model_path: str = None):
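+        """Initialize the detector; when model_path/model_type are omitted they are resolved automatically via get_model_type()."""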
        self.image_count = 0
        self.detection_records = []
-        self.model_path = model_path
        self.input_dir = input_dir
        self.save_empty = save_empty
        self.threshold = threshold
        self.confThreshold = threshold

+        # Automatically determine the model type
+        if model_path:
+            self.model_path = model_path
+            self.model_type = model_type if model_type else get_model_type(model_path)
+        else:
+            self.model_type = model_type if model_type else get_model_type(r"D:\PythonProject\R360-UAVmodelTool\src\models\UAV-250411.onnx")
+            # Select the bundled model path that matches the model type
+            if self.model_type == 'Anti_UAV':
+                self.model_path = r"D:\PythonProject\R360-UAVmodelTool\src\models\250411_Anti_UAV.onnx"
+            elif self.model_type == 'UAV-250411':
+                self.model_path = r"D:\PythonProject\R360-UAVmodelTool\src\models\UAV-250411.onnx"
+            else:
+                self.model_path = r"D:\PythonProject\R360-UAVmodelTool\src\models\uav_and_bird.onnx"
+        print(f"模型类型: {self.model_type}")
+        print(f"模型路径: {self.model_path}")
+
        # Initialize the ONNX Runtime session
        so = ort.SessionOptions()
        so.log_severity_level = 3
        self.providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
-        self.session = ort.InferenceSession(model_path, so, providers=self.providers)
+        self.session = ort.InferenceSession(self.model_path, so, providers=self.providers)
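+        # With this provider list, ONNX Runtime falls back to the CPU provider if CUDA is unavailable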

        # Model parameter settings
        self.input_size = (640, 640)  # width, height
@@ -45,11 +62,16 @@ class ONNXDetector:
        self.max_bbox_ratio = max_bbox_ratio

        # Parse the model version from the model path
-        self.model_version = os.path.basename(model_path).split('.')[0]
+        self.model_version = os.path.basename(self.model_path).split('.')[0]

        # Get input/output tensor names
        self.input_name = self.session.get_inputs()[0].name
-        self.output_name = self.session.get_outputs()[0].name
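+        # The exported models expose different output tensor names, so the name is hardcoded per model type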
+        if self.model_type == 'Anti_UAV':
+            self.output_name = 'multiclass_nms3_0.tmp_0'
+        elif self.model_type == 'UAV-250411':
+            self.output_name = 'tmp_17'
+        else:
+            self.output_name = 'conv2d_308.tmp_1'

        # Class-ID to name mapping
        self.class_names = {0: 'UAV'}
@@ -140,13 +162,21 @@ class ONNXDetector:
            order = order[inds + 1]
        return keep

-    def inference(self, input_data: np.ndarray) -> np.ndarray:
-        """Run model inference."""
-        scale_factor = np.array([[1, 1]], dtype=np.float32)
-        return self.session.run(
-            [self.output_name],
-            {self.input_name: input_data, 'scale_factor': scale_factor}
-        )[0]
+    def inference(self, input_data: np.ndarray):
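+        """Run model inference; the format of the returned outputs depends on self.model_type."""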
+        if self.model_type == 'uav_and_bird':
+            # Return every raw output head
+            return self.session.run(None, {self.input_name: input_data})
+        elif self.model_type in ['Anti_UAV', 'UAV-250411']:
+            scale_factor = np.array([[1, 1]], dtype=np.float32)
+            return self.session.run(
+                [self.output_name],
+                {self.input_name: input_data, 'scale_factor': scale_factor}
+            )[0]
+        else:
+            return self.session.run(
+                [self.output_name],
+                {self.input_name: input_data}
+            )[0]

    def _is_false_positive(self, roi: np.ndarray) -> bool:
        """Use ORB feature matching to verify whether a detection is a false positive."""
@@ -185,65 +215,148 @@ class ONNXDetector:
                return True
        return False

-    def postprocess(self, detections: np.ndarray, image_orig: np.ndarray, image_path: str) -> tuple:
-        """Post-process detection results."""
+    def postprocess(self, detections, image_orig: np.ndarray, image_path: str) -> tuple:
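+        """Dispatch model-specific post-processing; returns (valid_detections, annotated image, list of detection dicts)."""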
        valid_detections = 0
        img_out = image_orig.copy()
        detections_list = []
-
-        # Filter out low-confidence detections
-        keep_idx = (detections[:, 1] > self.confThreshold)
-        detections = detections[keep_idx]
-
-        if len(detections) == 0:
-            return 0, img_out, []
-
-        # Convert coordinates
-        ratioh = self.orig_h / self.input_size[1]
-        ratiow = self.orig_w / self.input_size[0]
-        detections[:, 2:6] *= np.array([ratiow, ratioh, ratiow, ratioh])
-
-        # Apply NMS
-        keep = self.nms(detections[:, 2:6], detections[:, 1:2], self.confThreshold, 0.4)
-
-        for idx in keep:
-            class_id = int(detections[idx, 0])
-            confidence = detections[idx, 1]
-            x1, y1, x2, y2 = detections[idx, 2:6].astype(int)
-
-            # Check the bounding-box to image area ratio
-            bbox_area = (x2 - x1) * (y2 - y1)
-            image_area = self.orig_w * self.orig_h
-            if bbox_area / image_area > self.max_bbox_ratio:
-                continue
-
-            # Crop the ROI for false-positive verification
-            roi = image_orig[y1:y2, x1:x2]
-            if roi.size == 0:
-                continue
-
-            if self._is_false_positive(roi):
-                continue
-
-            # Build a unique filename and save the target crop
-            target_filename = f"{os.path.splitext(os.path.basename(image_path))[0]}_{valid_detections}.jpg"
-            cv2.imwrite(os.path.join(self.targets_dir, target_filename), roi)
-
-            # Draw the bounding box and label
-            label = f'{self.class_names[class_id]} {confidence:.2f}'
-            cv2.rectangle(img_out, (x1, y1), (x2, y2), (255, 0, 0), 4)
-            cv2.putText(img_out, label, (x1, y1 - 10),
-                        cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 255), 2)
-            valid_detections += 1
-            detections_list.append({
-                'class_id': class_id,
-                'confidence': float(confidence),
-                'bbox': [x1, y1, x2, y2],
-                'orig_w': self.orig_w,
-                'orig_h': self.orig_h
-            })
-
+        if self.model_type == 'Anti_UAV':
+            # Post-processing for the Anti_UAV model
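+            # Each detection row is [class_id, score, x1, y1, x2, y2] in input-size (640x640) coordinates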
+            keep_idx = (detections[:, 1] > self.confThreshold)
+            detections = detections[keep_idx]
+            if len(detections) == 0:
+                return 0, img_out, []
+            ratioh = self.orig_h / self.input_size[1]
+            ratiow = self.orig_w / self.input_size[0]
+            detections[:, 2:6] *= np.array([ratiow, ratioh, ratiow, ratioh])
+            keep = self.nms(detections[:, 2:6], detections[:, 1:2], self.confThreshold, 0.4)
+            for idx in keep:
+                class_id = int(detections[idx, 0])
+                confidence = detections[idx, 1]
+                x1, y1, x2, y2 = detections[idx, 2:6].astype(int)
+                bbox_area = (x2 - x1) * (y2 - y1)
+                image_area = self.orig_w * self.orig_h
+                if bbox_area / image_area > self.max_bbox_ratio:
+                    continue
+                roi = image_orig[y1:y2, x1:x2]
+                if roi.size == 0:
+                    continue
+                if self._is_false_positive(roi):
+                    continue
+                target_filename = f"{os.path.splitext(os.path.basename(image_path))[0]}_{valid_detections}.jpg"
+                cv2.imwrite(os.path.join(self.targets_dir, target_filename), roi)
+                label = f'{self.class_names[class_id]} {confidence:.2f}'
+                cv2.rectangle(img_out, (x1, y1), (x2, y2), (255, 0, 0), 4)
+                cv2.putText(img_out, label, (x1, y1 - 10),
+                            cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 255), 2)
+                valid_detections += 1
+                detections_list.append({
+                    'class_id': class_id,
+                    'confidence': float(confidence),
+                    'bbox': [x1, y1, x2, y2],
+                    'orig_w': self.orig_w,
+                    'orig_h': self.orig_h
+                })
+        elif self.model_type == 'UAV-250411':
+            # Post-processing for the UAV-250411 model
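+            # Same decode/NMS/filter steps as the Anti_UAV branch above, kept as a separate branch per model type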
+            keep_idx = (detections[:, 1] > self.confThreshold)
+            detections = detections[keep_idx]
+            if len(detections) == 0:
+                return 0, img_out, []
+            ratioh = self.orig_h / self.input_size[1]
+            ratiow = self.orig_w / self.input_size[0]
+            detections[:, 2:6] *= np.array([ratiow, ratioh, ratiow, ratioh])
+            keep = self.nms(detections[:, 2:6], detections[:, 1:2], self.confThreshold, 0.4)
+            for idx in keep:
+                class_id = int(detections[idx, 0])
+                confidence = detections[idx, 1]
+                x1, y1, x2, y2 = detections[idx, 2:6].astype(int)
+                bbox_area = (x2 - x1) * (y2 - y1)
+                image_area = self.orig_w * self.orig_h
+                if bbox_area / image_area > self.max_bbox_ratio:
+                    continue
+                roi = image_orig[y1:y2, x1:x2]
+                if roi.size == 0:
+                    continue
+                if self._is_false_positive(roi):
+                    continue
+                target_filename = f"{os.path.splitext(os.path.basename(image_path))[0]}_{valid_detections}.jpg"
+                cv2.imwrite(os.path.join(self.targets_dir, target_filename), roi)
+                label = f'{self.class_names[class_id]} {confidence:.2f}'
+                cv2.rectangle(img_out, (x1, y1), (x2, y2), (255, 0, 0), 4)
+                cv2.putText(img_out, label, (x1, y1 - 10),
+                            cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 255), 2)
+                valid_detections += 1
+                detections_list.append({
+                    'class_id': class_id,
+                    'confidence': float(confidence),
+                    'bbox': [x1, y1, x2, y2],
+                    'orig_w': self.orig_w,
+                    'orig_h': self.orig_h
+                })
+        elif self.model_type == 'uav_and_bird':
+            # Output order (per Netron): [bbox20, cls20, obj20, bbox40, cls40, obj40, bbox80, cls80, obj80]
+            all_boxes, all_scores, all_classes = [], [], []
+            scales = [
+                (0, 1, 2, 20, 20),
+                (3, 4, 5, 40, 40),
+                (6, 7, 8, 80, 80),
+            ]
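+            # Per scale: flatten (C, H, W) maps to (H*W, C), score = objectness * sigmoid(class logits), keep scores above the threshold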
+            for bbox_idx, cls_idx, obj_idx, H, W in scales:
+                bbox_pred = detections[bbox_idx][0]  # (68, H, W)
+                cls_pred = detections[cls_idx][0]    # (2, H, W)
+                obj_pred = detections[obj_idx][0]    # (1, H, W)
+                bbox_pred = bbox_pred.reshape(68, -1).T  # (H*W, 68)
+                cls_pred = cls_pred.reshape(2, -1).T     # (H*W, 2)
+                obj_pred = obj_pred.reshape(-1)          # (H*W,)
+                scores = obj_pred[:, None] * (1 / (1 + np.exp(-cls_pred)))
+                max_scores = np.max(scores, axis=1)
+                class_ids = np.argmax(scores, axis=1)
+                mask = max_scores > self.confThreshold
+                boxes = bbox_pred[mask]
+                scores = max_scores[mask]
+                classes = class_ids[mask]
+                all_boxes.append(boxes)
+                all_scores.append(scores)
+                all_classes.append(classes)
+            if len(all_boxes) == 0 or all(len(b) == 0 for b in all_boxes):
+                return 0, img_out, []
+            boxes = np.concatenate(all_boxes, axis=0)
+            scores = np.concatenate(all_scores, axis=0)
+            classes = np.concatenate(all_classes, axis=0)
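+            # NOTE: the code below assumes the first four bbox channels are (x1, y1, x2, y2) in input-size coordinates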
+            # Scale coordinates back to the original image size
+            ratioh = self.orig_h / self.input_size[1]
+            ratiow = self.orig_w / self.input_size[0]
+            boxes[:, [0, 2]] *= ratiow
+            boxes[:, [1, 3]] *= ratioh
+            # NMS
+            keep = self.nms(boxes, scores[:, None], self.confThreshold, 0.4)
+            for idx in keep:
+                x1, y1, x2, y2 = boxes[idx][:4].astype(int)
+                confidence = scores[idx]
+                class_id = int(classes[idx])
+                bbox_area = (x2 - x1) * (y2 - y1)
+                image_area = self.orig_w * self.orig_h
+                if bbox_area / image_area > self.max_bbox_ratio:
+                    continue
+                roi = image_orig[y1:y2, x1:x2]
+                if roi.size == 0:
+                    continue
+                if self._is_false_positive(roi):
+                    continue
+                target_filename = f"{os.path.splitext(os.path.basename(image_path))[0]}_{valid_detections}.jpg"
+                cv2.imwrite(os.path.join(self.targets_dir, target_filename), roi)
+                label = f'class {class_id} {confidence:.2f}'
+                cv2.rectangle(img_out, (x1, y1), (x2, y2), (255, 0, 0), 4)
+                cv2.putText(img_out, label, (x1, y1 - 10),
+                            cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 255), 2)
+                valid_detections += 1
+                detections_list.append({
+                    'class_id': class_id,
+                    'confidence': float(confidence),
+                    'bbox': [x1, y1, x2, y2],
+                    'orig_w': self.orig_w,
+                    'orig_h': self.orig_h
+                })
        return valid_detections, img_out, detections_list

    def process_image(self, image_path: str) -> int:
@@ -282,6 +395,36 @@ class ONNXDetector:
            self.detection_records.append(record)
        return valid_count

+def get_model_type(model_path: str) -> str:
+    """
+    Determine the model type from the ONNX model's input/output structure.
+    Returns:
+        'Anti_UAV':     250411_Anti_UAV.onnx
+        'UAV-250411':   UAV-250411.onnx
+        'uav_and_bird': uav_and_bird.onnx
+        'unknown':      unrecognized model
+    """
+    try:
+        model = onnx.load(model_path)
+        input_names = [i.name for i in model.graph.input]
+        output_names = [o.name for o in model.graph.output]
+        # Identify the Anti_UAV-style exports
+        if 'scale_factor' in input_names:
+            if any(name.startswith('multiclass_nms3') for name in output_names):
+                return 'Anti_UAV'
+            elif len(output_names) == 2 and output_names[0].startswith('tmp_'):
+                return 'UAV-250411'
+            else:
+                return 'UAV-250411'
+        else:
+            # No 'scale_factor' input and multiple outputs
+            if len(output_names) > 2:
+                return 'uav_and_bird'
+        return 'unknown'
+    except Exception as e:
+        print(f"模型类型识别失败: {e}")
+        return 'unknown'
+
# Command-line interface
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='UAV Detection using ONNX model')
@@ -294,6 +437,9 @@ if __name__ == '__main__':
                        help='是否保存未检测到目标的图片')
    parser.add_argument('--gui', action='store_true',
                        help='启用图形界面选择输入目录')
+    parser.add_argument('--model-type', type=str, choices=['Anti_UAV', 'UAV-250411', 'uav_and_bird'],
+                        help='指定模型类型,不指定则自动识别')
+    parser.add_argument('--model-path', type=str, help='指定模型路径,不指定则根据模型类型自动选择')
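+    # If neither --model-type nor --model-path is given, the detector auto-detects the type and picks a bundled model path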
    args = parser.parse_args()

    # GUI mode handling
@@ -302,7 +448,9 @@ if __name__ == '__main__':
            threshold=args.threshold,
            max_bbox_ratio=args.max_bbox_ratio,
            output_dir=args.output,
-            save_empty=args.save_empty
+            save_empty=args.save_empty,
+            model_type=args.model_type,
+            model_path=args.model_path
        )
        input_dir = detector.select_input_directory()
        if not input_dir:
@@ -316,7 +464,9 @@ if __name__ == '__main__':
            max_bbox_ratio=args.max_bbox_ratio,
            output_dir=args.output,
            save_empty=args.save_empty,
-            input_dir=args.input if os.path.isdir(args.input) else None
+            input_dir=args.input if os.path.isdir(args.input) else None,
+            model_type=args.model_type,
+            model_path=args.model_path
        )

    def process_single(image_path: str):
@@ -333,7 +483,7 @@ if __name__ == '__main__':
        print(f'批量处理完成!共检测到 {total} 个目标')

        # Generate the CSV report
-        csv_path = os.path.join(detector.output_dir, 'detection_report.csv')
+        csv_path = os.path.join(detector.output_dir, f'detection_report_{detector.model_type}.csv')
        ReportGenerator(detector).generate_csv(csv_path)
        print(f'CSV报告已生成: {csv_path}')
    else: