Przeglądaj źródła

feat: 添加多模型支持并优化推理流程

refactor(inference): 重构ONNXDetector类以支持不同模型类型
docs: 添加模型结构说明文档model_structure.md
build: 更新依赖版本并添加新依赖onnx
Hannnk 2 tygodni temu
rodzic
commit
1c06210aff

+ 1 - 1
.gitignore

@@ -21,7 +21,7 @@ wheels/
 *.egg
 
 # Virtual Environment
-venv/
+.venv/
 ENV/
 env/
 

+ 0 - 114
Output/output_20250328_173528_results/performance_stats.txt

@@ -1,114 +0,0 @@
-检测性能统计报告
-==================================================
-
-时间统计信息:
-采集开始时间: 2025-03-28 17:35:28
-采集结束时间: 2025-03-28 17:53:15
-持续时长: 1067秒
-
-详细统计信息:
-总样本数(CSV中唯一图像数): 27558
-检测到的目标总数: 320
-真实目标总数(TP): 97
-误报目标数(FP): 223
-真阴性数量(TN): 27238
-
-性能指标:
-准确率: 0.9919
-精确率: 0.3031
-召回率: 1.0000
-F1分数: 0.4652
-误报率(虚警率): 0.0081
-
-镜头统计信息:
-------------------------------
-镜头 cam_00:
-  总报警数: 4
-  误报数: 0
-  误报率: 0.0000
-------------------------------
-镜头 cam_01:
-  总报警数: 3
-  误报数: 1
-  误报率: 0.3333
-------------------------------
-镜头 cam_02:
-  总报警数: 2
-  误报数: 1
-  误报率: 0.5000
-------------------------------
-镜头 cam_03:
-  总报警数: 3
-  误报数: 0
-  误报率: 0.0000
-------------------------------
-镜头 cam_04:
-  总报警数: 21
-  误报数: 6
-  误报率: 0.2857
-------------------------------
-镜头 cam_05:
-  总报警数: 119
-  误报数: 119
-  误报率: 1.0000
-------------------------------
-镜头 cam_06:
-  总报警数: 26
-  误报数: 26
-  误报率: 1.0000
-------------------------------
-镜头 cam_07:
-  总报警数: 45
-  误报数: 41
-  误报率: 0.9111
-------------------------------
-镜头 cam_08:
-  总报警数: 22
-  误报数: 0
-  误报率: 0.0000
-------------------------------
-镜头 cam_09:
-  总报警数: 12
-  误报数: 6
-  误报率: 0.5000
-------------------------------
-镜头 cam_10:
-  总报警数: 7
-  误报数: 1
-  误报率: 0.1429
-------------------------------
-镜头 cam_11:
-  总报警数: 1
-  误报数: 0
-  误报率: 0.0000
-------------------------------
-镜头 cam_12:
-  总报警数: 3
-  误报数: 0
-  误报率: 0.0000
-------------------------------
-镜头 cam_13:
-  总报警数: 6
-  误报数: 0
-  误报率: 0.0000
-------------------------------
-镜头 cam_14:
-  总报警数: 24
-  误报数: 22
-  误报率: 0.9167
-------------------------------
-镜头 cam_15:
-  总报警数: 15
-  误报数: 0
-  误报率: 0.0000
-------------------------------
-镜头 cam_16:
-  总报警数: 4
-  误报数: 0
-  误报率: 0.0000
-------------------------------
-镜头 cam_17:
-  总报警数: 3
-  误报数: 0
-  误报率: 0.0000
-------------------------------

+ 0 - 114
Output/output_20250329_100448_results/performance_stats.txt

@@ -1,114 +0,0 @@
-检测性能统计报告
-==================================================
-
-时间统计信息:
-采集开始时间: 2025-03-29 10:04:48
-采集结束时间: 2025-03-29 10:26:52
-持续时长: 1324秒
-
-详细统计信息:
-总样本数(CSV中唯一图像数): 32850
-检测到的目标总数: 444
-真实目标总数(TP): 220
-误报目标数(FP): 224
-真阴性数量(TN): 32406
-
-性能指标:
-准确率: 0.9932
-精确率: 0.4955
-召回率: 1.0000
-F1分数: 0.6627
-误报率(虚警率): 0.0069
-
-镜头统计信息:
-------------------------------
-镜头 cam_00:
-  总报警数: 14
-  误报数: 0
-  误报率: 0.0000
-------------------------------
-镜头 cam_01:
-  总报警数: 49
-  误报数: 30
-  误报率: 0.6122
-------------------------------
-镜头 cam_02:
-  总报警数: 42
-  误报数: 17
-  误报率: 0.4048
-------------------------------
-镜头 cam_03:
-  总报警数: 21
-  误报数: 4
-  误报率: 0.1905
-------------------------------
-镜头 cam_04:
-  总报警数: 42
-  误报数: 23
-  误报率: 0.5476
-------------------------------
-镜头 cam_05:
-  总报警数: 72
-  误报数: 70
-  误报率: 0.9722
-------------------------------
-镜头 cam_06:
-  总报警数: 38
-  误报数: 34
-  误报率: 0.8947
-------------------------------
-镜头 cam_07:
-  总报警数: 47
-  误报数: 9
-  误报率: 0.1915
-------------------------------
-镜头 cam_08:
-  总报警数: 8
-  误报数: 0
-  误报率: 0.0000
-------------------------------
-镜头 cam_09:
-  总报警数: 15
-  误报数: 4
-  误报率: 0.2667
-------------------------------
-镜头 cam_10:
-  总报警数: 15
-  误报数: 7
-  误报率: 0.4667
-------------------------------
-镜头 cam_11:
-  总报警数: 11
-  误报数: 0
-  误报率: 0.0000
-------------------------------
-镜头 cam_12:
-  总报警数: 8
-  误报数: 0
-  误报率: 0.0000
-------------------------------
-镜头 cam_13:
-  总报警数: 6
-  误报数: 1
-  误报率: 0.1667
-------------------------------
-镜头 cam_14:
-  总报警数: 33
-  误报数: 24
-  误报率: 0.7273
-------------------------------
-镜头 cam_15:
-  总报警数: 6
-  误报数: 0
-  误报率: 0.0000
-------------------------------
-镜头 cam_16:
-  总报警数: 10
-  误报数: 0
-  误报率: 0.0000
-------------------------------
-镜头 cam_17:
-  总报警数: 7
-  误报数: 1
-  误报率: 0.1429
-------------------------------

+ 0 - 94
Output/output_20250329_111006_results/performance_stats.txt

@@ -1,94 +0,0 @@
-检测性能统计报告
-==================================================
-
-时间统计信息:
-采集开始时间: 2025-03-29 11:10:06
-采集结束时间: 2025-03-29 11:16:01
-持续时长: 355秒
-
-详细统计信息:
-总样本数(CSV中唯一图像数): 8136
-检测到的目标总数: 109
-真实目标总数(TP): 39
-误报目标数(FP): 70
-真阴性数量(TN): 8027
-
-性能指标:
-准确率: 0.9914
-精确率: 0.3578
-召回率: 1.0000
-F1分数: 0.5270
-误报率(虚警率): 0.0086
-
-镜头统计信息:
-------------------------------
-镜头 cam_01:
-  总报警数: 7
-  误报数: 6
-  误报率: 0.8571
-------------------------------
-镜头 cam_02:
-  总报警数: 13
-  误报数: 7
-  误报率: 0.5385
-------------------------------
-镜头 cam_03:
-  总报警数: 26
-  误报数: 23
-  误报率: 0.8846
-------------------------------
-镜头 cam_04:
-  总报警数: 2
-  误报数: 0
-  误报率: 0.0000
-------------------------------
-镜头 cam_05:
-  总报警数: 8
-  误报数: 8
-  误报率: 1.0000
-------------------------------
-镜头 cam_06:
-  总报警数: 4
-  误报数: 4
-  误报率: 1.0000
-------------------------------
-镜头 cam_07:
-  总报警数: 9
-  误报数: 0
-  误报率: 0.0000
-------------------------------
-镜头 cam_08:
-  总报警数: 2
-  误报数: 0
-  误报率: 0.0000
-------------------------------
-镜头 cam_09:
-  总报警数: 1
-  误报数: 0
-  误报率: 0.0000
-------------------------------
-镜头 cam_10:
-  总报警数: 3
-  误报数: 0
-  误报率: 0.0000
-------------------------------
-镜头 cam_11:
-  总报警数: 3
-  误报数: 0
-  误报率: 0.0000
-------------------------------
-镜头 cam_14:
-  总报警数: 26
-  误报数: 22
-  误报率: 0.8462
-------------------------------
-镜头 cam_15:
-  总报警数: 3
-  误报数: 0
-  误报率: 0.0000
-------------------------------
-镜头 cam_17:
-  总报警数: 2
-  误报数: 0
-  误报率: 0.0000
-------------------------------

+ 0 - 109
Output/output_20250329_130338_results/performance_stats.txt

@@ -1,109 +0,0 @@
-检测性能统计报告
-==================================================
-
-时间统计信息:
-采集开始时间: 2025-03-29 13:03:38
-采集结束时间: 2025-03-29 13:27:23
-持续时长: 1425秒
-
-详细统计信息:
-总样本数(CSV中唯一图像数): 34562
-检测到的目标总数: 644
-真实目标总数(TP): 177
-误报目标数(FP): 467
-真阴性数量(TN): 33918
-
-性能指标:
-准确率: 0.9865
-精确率: 0.2748
-召回率: 1.0000
-F1分数: 0.4312
-误报率(虚警率): 0.0136
-
-镜头统计信息:
-------------------------------
-镜头 cam_00:
-  总报警数: 5
-  误报数: 0
-  误报率: 0.0000
-------------------------------
-镜头 cam_01:
-  总报警数: 48
-  误报数: 32
-  误报率: 0.6667
-------------------------------
-镜头 cam_02:
-  总报警数: 76
-  误报数: 59
-  误报率: 0.7763
-------------------------------
-镜头 cam_03:
-  总报警数: 20
-  误报数: 11
-  误报率: 0.5500
-------------------------------
-镜头 cam_04:
-  总报警数: 115
-  误报数: 95
-  误报率: 0.8261
-------------------------------
-镜头 cam_05:
-  总报警数: 110
-  误报数: 110
-  误报率: 1.0000
-------------------------------
-镜头 cam_06:
-  总报警数: 117
-  误报数: 115
-  误报率: 0.9829
-------------------------------
-镜头 cam_07:
-  总报警数: 45
-  误报数: 10
-  误报率: 0.2222
-------------------------------
-镜头 cam_08:
-  总报警数: 10
-  误报数: 0
-  误报率: 0.0000
-------------------------------
-镜头 cam_09:
-  总报警数: 16
-  误报数: 13
-  误报率: 0.8125
-------------------------------
-镜头 cam_10:
-  总报警数: 15
-  误报数: 9
-  误报率: 0.6000
-------------------------------
-镜头 cam_11:
-  总报警数: 2
-  误报数: 0
-  误报率: 0.0000
-------------------------------
-镜头 cam_13:
-  总报警数: 9
-  误报数: 0
-  误报率: 0.0000
-------------------------------
-镜头 cam_14:
-  总报警数: 22
-  误报数: 13
-  误报率: 0.5909
-------------------------------
-镜头 cam_15:
-  总报警数: 8
-  误报数: 0
-  误报率: 0.0000
-------------------------------
-镜头 cam_16:
-  总报警数: 13
-  误报数: 0
-  误报率: 0.0000
-------------------------------
-镜头 cam_17:
-  总报警数: 13
-  误报数: 0
-  误报率: 0.0000
-------------------------------

+ 0 - 109
Output/output_20250329_133810_results/performance_stats.txt

@@ -1,109 +0,0 @@
-检测性能统计报告
-==================================================
-
-时间统计信息:
-采集开始时间: 2025-03-29 13:38:10
-采集结束时间: 2025-03-29 14:02:17
-持续时长: 1447秒
-
-详细统计信息:
-总样本数(CSV中唯一图像数): 36540
-检测到的目标总数: 347
-真实目标总数(TP): 111
-误报目标数(FP): 236
-真阴性数量(TN): 36193
-
-性能指标:
-准确率: 0.9935
-精确率: 0.3199
-召回率: 1.0000
-F1分数: 0.4847
-误报率(虚警率): 0.0065
-
-镜头统计信息:
-------------------------------
-镜头 cam_00:
-  总报警数: 16
-  误报数: 0
-  误报率: 0.0000
-------------------------------
-镜头 cam_01:
-  总报警数: 47
-  误报数: 44
-  误报率: 0.9362
-------------------------------
-镜头 cam_02:
-  总报警数: 8
-  误报数: 8
-  误报率: 1.0000
-------------------------------
-镜头 cam_03:
-  总报警数: 11
-  误报数: 9
-  误报率: 0.8182
-------------------------------
-镜头 cam_04:
-  总报警数: 34
-  误报数: 30
-  误报率: 0.8824
-------------------------------
-镜头 cam_05:
-  总报警数: 89
-  误报数: 78
-  误报率: 0.8764
-------------------------------
-镜头 cam_06:
-  总报警数: 50
-  误报数: 43
-  误报率: 0.8600
-------------------------------
-镜头 cam_07:
-  总报警数: 28
-  误报数: 4
-  误报率: 0.1429
-------------------------------
-镜头 cam_08:
-  总报警数: 7
-  误报数: 0
-  误报率: 0.0000
-------------------------------
-镜头 cam_09:
-  总报警数: 7
-  误报数: 1
-  误报率: 0.1429
-------------------------------
-镜头 cam_10:
-  总报警数: 10
-  误报数: 7
-  误报率: 0.7000
-------------------------------
-镜头 cam_11:
-  总报警数: 10
-  误报数: 0
-  误报率: 0.0000
-------------------------------
-镜头 cam_13:
-  总报警数: 8
-  误报数: 0
-  误报率: 0.0000
-------------------------------
-镜头 cam_14:
-  总报警数: 16
-  误报数: 12
-  误报率: 0.7500
-------------------------------
-镜头 cam_15:
-  总报警数: 1
-  误报数: 0
-  误报率: 0.0000
-------------------------------
-镜头 cam_16:
-  总报警数: 4
-  误报数: 0
-  误报率: 0.0000
-------------------------------
-镜头 cam_17:
-  总报警数: 1
-  误报数: 0
-  误报率: 0.0000
-------------------------------

+ 0 - 109
Output/output_20250329_140816_results/performance_stats.txt

@@ -1,109 +0,0 @@
-检测性能统计报告
-==================================================
-
-时间统计信息:
-采集开始时间: 2025-03-29 14:08:16
-采集结束时间: 2025-03-29 14:28:47
-持续时长: 1231秒
-
-详细统计信息:
-总样本数(CSV中唯一图像数): 31086
-检测到的目标总数: 361
-真实目标总数(TP): 196
-误报目标数(FP): 165
-真阴性数量(TN): 30725
-
-性能指标:
-准确率: 0.9947
-精确率: 0.5429
-召回率: 1.0000
-F1分数: 0.7038
-误报率(虚警率): 0.0053
-
-镜头统计信息:
-------------------------------
-镜头 cam_00:
-  总报警数: 11
-  误报数: 0
-  误报率: 0.0000
-------------------------------
-镜头 cam_01:
-  总报警数: 58
-  误报数: 39
-  误报率: 0.6724
-------------------------------
-镜头 cam_02:
-  总报警数: 30
-  误报数: 7
-  误报率: 0.2333
-------------------------------
-镜头 cam_03:
-  总报警数: 20
-  误报数: 5
-  误报率: 0.2500
-------------------------------
-镜头 cam_04:
-  总报警数: 43
-  误报数: 14
-  误报率: 0.3256
-------------------------------
-镜头 cam_05:
-  总报警数: 73
-  误报数: 62
-  误报率: 0.8493
-------------------------------
-镜头 cam_06:
-  总报警数: 43
-  误报数: 22
-  误报率: 0.5116
-------------------------------
-镜头 cam_07:
-  总报警数: 29
-  误报数: 3
-  误报率: 0.1034
-------------------------------
-镜头 cam_08:
-  总报警数: 5
-  误报数: 0
-  误报率: 0.0000
-------------------------------
-镜头 cam_09:
-  总报警数: 13
-  误报数: 2
-  误报率: 0.1538
-------------------------------
-镜头 cam_10:
-  总报警数: 10
-  误报数: 6
-  误报率: 0.6000
-------------------------------
-镜头 cam_11:
-  总报警数: 5
-  误报数: 0
-  误报率: 0.0000
-------------------------------
-镜头 cam_13:
-  总报警数: 2
-  误报数: 0
-  误报率: 0.0000
-------------------------------
-镜头 cam_14:
-  总报警数: 8
-  误报数: 5
-  误报率: 0.6250
-------------------------------
-镜头 cam_15:
-  总报警数: 3
-  误报数: 0
-  误报率: 0.0000
-------------------------------
-镜头 cam_16:
-  总报警数: 6
-  误报数: 0
-  误报率: 0.0000
-------------------------------
-镜头 cam_17:
-  总报警数: 2
-  误报数: 0
-  误报率: 0.0000
-------------------------------

+ 3 - 2
requirements.txt

@@ -1,5 +1,6 @@
-onnxruntime-gpu==1.16.0
-opencv-python==4.8.0
+onnxruntime-gpu==1.15.1
+onnx==1.17.0
+opencv-python==4.11.0.86
 numpy==1.24.3
 tqdm==4.66.1
 openpyxl==3.1.2

+ 4 - 0
src/__init__.py

@@ -1 +1,5 @@
+"""
+R360 UAV Model Tool
+"""
+
 # 使src成为一个Python包 

+ 4 - 0
src/analysis/__init__.py

@@ -1 +1,5 @@
+"""
+Analysis and reporting functionality
+"""
+
 # 使analysis成为一个Python包 

+ 223 - 73
src/core/inference.py

@@ -9,7 +9,9 @@ from typing import List, Dict, Optional
 import time
 from datetime import datetime
 from tqdm import tqdm
-from ..analysis.report_generator import ReportGenerator
+from src.analysis.report_generator import ReportGenerator
+import onnx
+
 class ONNXDetector:
     def select_input_directory(self) -> Optional[str]:
         """图形界面选择输入目录"""
@@ -20,23 +22,38 @@ class ONNXDetector:
         folder_path = filedialog.askdirectory(title='选择输入目录')
         return folder_path if folder_path else None
 
-    def __init__(self, model_path: str = 'D:/PythonProject/Model/Data/models/250411_Anti_UAV.onnx', threshold: float = 0.5, 
+    def __init__(self, threshold: float = 0.5, 
                  output_dir: str = None, save_empty: bool = False, 
                  max_bbox_ratio: float = 0.5,
-                 input_dir: str = None):
+                 input_dir: str = None, model_type: str = None, model_path: str = None):
         self.image_count = 0
         self.detection_records = []
-        self.model_path = model_path
         self.input_dir = input_dir
         self.save_empty = save_empty
         self.threshold = threshold
         self.confThreshold = threshold
         
+        # 自动识别模型类型
+        if model_path:
+            self.model_path = model_path
+            self.model_type = model_type if model_type else get_model_type(model_path)
+        else:
+            self.model_type = model_type if model_type else get_model_type(r"D:\PythonProject\R360-UAVmodelTool\src\models\UAV-250411.onnx")
+            # 根据模型类型选择不同的模型路径
+            if self.model_type == 'Anti_UAV':
+                self.model_path = r"D:\PythonProject\R360-UAVmodelTool\src\models\250411_Anti_UAV.onnx"
+            elif self.model_type == 'UAV-250411':
+                self.model_path = r"D:\PythonProject\R360-UAVmodelTool\src\models\UAV-250411.onnx"
+            else:
+                self.model_path = r"D:\PythonProject\R360-UAVmodelTool\src\models\uav_and_bird.onnx"
+        print(f"模型类型: {self.model_type}")
+        print(f"模型路径: {self.model_path}")
+        
         # 初始化ONNX会话
         so = ort.SessionOptions()
         so.log_severity_level = 3
         self.providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
-        self.session = ort.InferenceSession(model_path, so, providers=self.providers)
+        self.session = ort.InferenceSession(self.model_path, so, providers=self.providers)
         
         # 模型参数设置
         self.input_size = (640, 640)  # width, height
@@ -45,11 +62,16 @@ class ONNXDetector:
         self.max_bbox_ratio = max_bbox_ratio
         
         # 从模型路径解析版本号
-        self.model_version = os.path.basename(model_path).split('.')[0]
+        self.model_version = os.path.basename(self.model_path).split('.')[0]
         
         # 获取输入输出名称
         self.input_name = self.session.get_inputs()[0].name
-        self.output_name = self.session.get_outputs()[0].name
+        if self.model_type == 'Anti_UAV':
+            self.output_name = 'multiclass_nms3_0.tmp_0'
+        elif self.model_type == 'UAV-250411':
+            self.output_name = 'tmp_17'
+        else:
+            self.output_name = 'conv2d_308.tmp_1'
         
         # 类别映射
         self.class_names = {0: 'UAV'}
@@ -140,13 +162,21 @@ class ONNXDetector:
             order = order[inds + 1]
         return keep
 
-    def inference(self, input_data: np.ndarray) -> np.ndarray:
-        """执行模型推理"""
-        scale_factor = np.array([[1, 1]], dtype=np.float32)
-        return self.session.run(
-            [self.output_name],
-            {self.input_name: input_data, 'scale_factor': scale_factor}
-        )[0]
+    def inference(self, input_data: np.ndarray):
+        if self.model_type == 'uav_and_bird':
+            # 返回所有输出
+            return self.session.run(None, {self.input_name: input_data})
+        elif self.model_type in ['Anti_UAV', 'UAV-250411']:
+            scale_factor = np.array([[1, 1]], dtype=np.float32)
+            return self.session.run(
+                [self.output_name],
+                {self.input_name: input_data, 'scale_factor': scale_factor}
+            )[0]
+        else:
+            return self.session.run(
+                [self.output_name],
+                {self.input_name: input_data}
+            )[0]
 
     def _is_false_positive(self, roi: np.ndarray) -> bool:
         """使用ORB特征匹配验证是否误报"""
@@ -185,65 +215,148 @@ class ONNXDetector:
                 return True
         return False
 
-    def postprocess(self, detections: np.ndarray, image_orig: np.ndarray, image_path: str) -> tuple:
-        """后处理检测结果"""
+    def postprocess(self, detections, image_orig: np.ndarray, image_path: str) -> tuple:
         valid_detections = 0
         img_out = image_orig.copy()
         detections_list = []
-        
-        # 过滤低置信度检测
-        keep_idx = (detections[:, 1] > self.confThreshold)
-        detections = detections[keep_idx]
-        
-        if len(detections) == 0:
-            return 0, img_out, []
-            
-        # 坐标转换
-        ratioh = self.orig_h / self.input_size[1]
-        ratiow = self.orig_w / self.input_size[0]
-        detections[:, 2
-        :6] *= np.array([ratiow, ratioh, ratiow, ratioh])
-        
-        # NMS处理
-        keep = self.nms(detections[:, 2:6], detections[:, 1:2], self.confThreshold, 0.4)
-        
-        for idx in keep:
-            class_id = int(detections[idx, 0])
-            confidence = detections[idx, 1]
-            x1, y1, x2, y2 = detections[idx, 2:6].astype(int)
-            
-            # 计算检测框面积比例
-            bbox_area = (x2 - x1) * (y2 - y1)
-            image_area = self.orig_w * self.orig_h
-            if bbox_area / image_area > self.max_bbox_ratio:
-                continue
-
-            # 截取ROI区域进行误报验证
-            roi = image_orig[y1:y2, x1:x2]
-            if roi.size == 0:
-                continue
-                
-            if self._is_false_positive(roi):
-                continue
-                
-            # 生成唯一文件名并保存目标
-            target_filename = f"{os.path.splitext(os.path.basename(image_path))[0]}_{valid_detections}.jpg"
-            cv2.imwrite(os.path.join(self.targets_dir, target_filename), roi)
-            
-            # 绘制检测框和标签
-            label = f'{self.class_names[class_id]} {confidence:.2f}'
-            cv2.rectangle(img_out, (x1, y1), (x2, y2), (255, 0, 0), 4)
-            cv2.putText(img_out, label, (x1, y1 - 10), 
-                        cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 255), 2)
-            valid_detections += 1
-            detections_list.append({
-                'class_id': class_id,
-                'confidence': float(confidence),
-                'bbox': [x1, y1, x2, y2],
-                'orig_w': self.orig_w,
-                'orig_h': self.orig_h
-            })
-            
+        if self.model_type == 'Anti_UAV':
+            # Anti_UAV模型的后处理逻辑
+            keep_idx = (detections[:, 1] > self.confThreshold)
+            detections = detections[keep_idx]
+            if len(detections) == 0:
+                return 0, img_out, []
+            ratioh = self.orig_h / self.input_size[1]
+            ratiow = self.orig_w / self.input_size[0]
+            detections[:, 2:6] *= np.array([ratiow, ratioh, ratiow, ratioh])
+            keep = self.nms(detections[:, 2:6], detections[:, 1:2], self.confThreshold, 0.4)
+            for idx in keep:
+                class_id = int(detections[idx, 0])
+                confidence = detections[idx, 1]
+                x1, y1, x2, y2 = detections[idx, 2:6].astype(int)
+                bbox_area = (x2 - x1) * (y2 - y1)
+                image_area = self.orig_w * self.orig_h
+                if bbox_area / image_area > self.max_bbox_ratio:
+                    continue
+                roi = image_orig[y1:y2, x1:x2]
+                if roi.size == 0:
+                    continue
+                if self._is_false_positive(roi):
+                    continue
+                target_filename = f"{os.path.splitext(os.path.basename(image_path))[0]}_{valid_detections}.jpg"
+                cv2.imwrite(os.path.join(self.targets_dir, target_filename), roi)
+                label = f'{self.class_names[class_id]} {confidence:.2f}'
+                cv2.rectangle(img_out, (x1, y1), (x2, y2), (255, 0, 0), 4)
+                cv2.putText(img_out, label, (x1, y1 - 10), 
+                            cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 255), 2)
+                valid_detections += 1
+                detections_list.append({
+                    'class_id': class_id,
+                    'confidence': float(confidence),
+                    'bbox': [x1, y1, x2, y2],
+                    'orig_w': self.orig_w,
+                    'orig_h': self.orig_h
+                })
+        elif self.model_type == 'UAV-250411':
+            # UAV-250411模型的后处理逻辑
+            keep_idx = (detections[:, 1] > self.confThreshold)
+            detections = detections[keep_idx]
+            if len(detections) == 0:
+                return 0, img_out, []
+            ratioh = self.orig_h / self.input_size[1]
+            ratiow = self.orig_w / self.input_size[0]
+            detections[:, 2:6] *= np.array([ratiow, ratioh, ratiow, ratioh])
+            keep = self.nms(detections[:, 2:6], detections[:, 1:2], self.confThreshold, 0.4)
+            for idx in keep:
+                class_id = int(detections[idx, 0])
+                confidence = detections[idx, 1]
+                x1, y1, x2, y2 = detections[idx, 2:6].astype(int)
+                bbox_area = (x2 - x1) * (y2 - y1)
+                image_area = self.orig_w * self.orig_h
+                if bbox_area / image_area > self.max_bbox_ratio:
+                    continue
+                roi = image_orig[y1:y2, x1:x2]
+                if roi.size == 0:
+                    continue
+                if self._is_false_positive(roi):
+                    continue
+                target_filename = f"{os.path.splitext(os.path.basename(image_path))[0]}_{valid_detections}.jpg"
+                cv2.imwrite(os.path.join(self.targets_dir, target_filename), roi)
+                label = f'{self.class_names[class_id]} {confidence:.2f}'
+                cv2.rectangle(img_out, (x1, y1), (x2, y2), (255, 0, 0), 4)
+                cv2.putText(img_out, label, (x1, y1 - 10), 
+                            cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 255), 2)
+                valid_detections += 1
+                detections_list.append({
+                    'class_id': class_id,
+                    'confidence': float(confidence),
+                    'bbox': [x1, y1, x2, y2],
+                    'orig_w': self.orig_w,
+                    'orig_h': self.orig_h
+                })
+        elif self.model_type == 'uav_and_bird':
+            # Netron顺序: [bbox20, cls20, obj20, bbox40, cls40, obj40, bbox80, cls80, obj80]
+            all_boxes, all_scores, all_classes = [], [], []
+            scales = [
+                (0, 1, 2, 20, 20),
+                (3, 4, 5, 40, 40),
+                (6, 7, 8, 80, 80),
+            ]
+            for bbox_idx, cls_idx, obj_idx, H, W in scales:
+                bbox_pred = detections[bbox_idx][0]  # (68, H, W)
+                cls_pred = detections[cls_idx][0]    # (2, H, W)
+                obj_pred = detections[obj_idx][0]    # (1, H, W)
+                bbox_pred = bbox_pred.reshape(68, -1).T  # (H*W, 68)
+                cls_pred = cls_pred.reshape(2, -1).T     # (H*W, 2)
+                obj_pred = obj_pred.reshape(-1)          # (H*W,)
+                scores = obj_pred[:, None] * (1 / (1 + np.exp(-cls_pred)))
+                max_scores = np.max(scores, axis=1)
+                class_ids = np.argmax(scores, axis=1)
+                mask = max_scores > self.confThreshold
+                boxes = bbox_pred[mask]
+                scores = max_scores[mask]
+                classes = class_ids[mask]
+                all_boxes.append(boxes)
+                all_scores.append(scores)
+                all_classes.append(classes)
+            if len(all_boxes) == 0 or all([len(b) == 0 for b in all_boxes]):
+                return 0, img_out, []
+            boxes = np.concatenate(all_boxes, axis=0)
+            scores = np.concatenate(all_scores, axis=0)
+            classes = np.concatenate(all_classes, axis=0)
+            # 坐标缩放
+            ratioh = self.orig_h / self.input_size[1]
+            ratiow = self.orig_w / self.input_size[0]
+            boxes[:, [0, 2]] *= ratiow
+            boxes[:, [1, 3]] *= ratioh
+            # NMS
+            keep = self.nms(boxes, scores[:, None], self.confThreshold, 0.4)
+            for idx in keep:
+                x1, y1, x2, y2 = boxes[idx][:4].astype(int)
+                confidence = scores[idx]
+                class_id = int(classes[idx])
+                bbox_area = (x2 - x1) * (y2 - y1)
+                image_area = self.orig_w * self.orig_h
+                if bbox_area / image_area > self.max_bbox_ratio:
+                    continue
+                roi = image_orig[y1:y2, x1:x2]
+                if roi.size == 0:
+                    continue
+                if self._is_false_positive(roi):
+                    continue
+                target_filename = f"{os.path.splitext(os.path.basename(image_path))[0]}_{valid_detections}.jpg"
+                cv2.imwrite(os.path.join(self.targets_dir, target_filename), roi)
+                label = f'class {class_id} {confidence:.2f}'
+                cv2.rectangle(img_out, (x1, y1), (x2, y2), (255, 0, 0), 4)
+                cv2.putText(img_out, label, (x1, y1 - 10), 
+                            cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 255), 2)
+                valid_detections += 1
+                detections_list.append({
+                    'class_id': class_id,
+                    'confidence': float(confidence),
+                    'bbox': [x1, y1, x2, y2],
+                    'orig_w': self.orig_w,
+                    'orig_h': self.orig_h
+                })
         return valid_detections, img_out, detections_list
 
     def process_image(self, image_path: str) -> int:
@@ -282,6 +395,36 @@ class ONNXDetector:
         self.detection_records.append(record)
         return valid_count
 
+def get_model_type(model_path: str) -> str:
+    """
+    根据ONNX模型的输入输出结构自动判断模型类型。
+    返回值:
+        'Anti_UAV':250411_Anti_UAV.onnx
+        'UAV-250411':UAV-250411.onnx
+        'uav_and_bird':uav_and_bird.onnx
+        'unknown':无法识别
+    """
+    try:
+        model = onnx.load(model_path)
+        input_names = [i.name for i in model.graph.input]
+        output_names = [o.name for o in model.graph.output]
+        # 判断Anti_UAV模型
+        if 'scale_factor' in input_names:
+            if any(name.startswith('multiclass_nms3') for name in output_names):
+                return 'Anti_UAV'
+            elif len(output_names) == 2 and output_names[0].startswith('tmp_'):
+                return 'UAV-250411'
+            else:
+                return 'UAV-250411'
+        else:
+            # 没有scale_factor输入,且输出较多
+            if len(output_names) > 2:
+                return 'uav_and_bird'
+        return 'unknown'
+    except Exception as e:
+        print(f"模型类型识别失败: {e}")
+        return 'unknown'
+
 # 命令行接口
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='UAV Detection using ONNX model')
@@ -294,6 +437,9 @@ if __name__ == '__main__':
                       help='是否保存未检测到目标的图片')
     parser.add_argument('--gui', action='store_true',
                       help='启用图形界面选择输入目录')
+    parser.add_argument('--model-type', type=str, choices=['Anti_UAV', 'UAV-250411', 'uav_and_bird'],
+                      help='指定模型类型,不指定则自动识别')
+    parser.add_argument('--model-path', type=str, help='指定模型路径,不指定则根据模型类型自动选择')
     args = parser.parse_args()
 
     # GUI模式处理
@@ -302,7 +448,9 @@ if __name__ == '__main__':
             threshold=args.threshold,
             max_bbox_ratio=args.max_bbox_ratio,
             output_dir=args.output,
-            save_empty=args.save_empty
+            save_empty=args.save_empty,
+            model_type=args.model_type,
+            model_path=args.model_path
         )
         input_dir = detector.select_input_directory()
         if not input_dir:
@@ -316,7 +464,9 @@ if __name__ == '__main__':
         max_bbox_ratio=args.max_bbox_ratio,
         output_dir=args.output,
         save_empty=args.save_empty,
-        input_dir=args.input if os.path.isdir(args.input) else None
+        input_dir=args.input if os.path.isdir(args.input) else None,
+        model_type=args.model_type,
+        model_path=args.model_path
     )
 
     def process_single(image_path: str):
@@ -333,7 +483,7 @@ if __name__ == '__main__':
         print(f'批量处理完成!共检测到 {total} 个目标')
     
         # 生成CSV报告
-        csv_path = os.path.join(detector.output_dir, 'detection_report.csv')
+        csv_path = os.path.join(detector.output_dir, f'detection_report_{detector.model_type}.csv')
         ReportGenerator(detector).generate_csv(csv_path)
         print(f'CSV报告已生成: {csv_path}')
     else: