Bladeren bron

refactor(inference): 优化模型推理流程和代码结构

重构了ONNXDetector类的初始化、预处理和后处理逻辑,移除了Excel报告生成功能,统一了模型参数设置,并改进了非极大值抑制(NMS)的实现。这些更改提高了代码的可维护性和推理效率。
Hannnk 3 weken geleden
bovenliggende
commit
4a506625db
6 gewijzigde bestanden met toevoegingen van 358 en 193 verwijderingen
  1. 56 143
      README.md
  2. 1 1
      src/analysis/performance_analyzer.py
  3. 66 0
      src/core/compare_outputs.py
  4. 0 0
      src/core/detect.py
  5. 168 0
      src/core/infer.py
  6. 67 49
      src/core/inference.py

+ 56 - 143
README.md

@@ -1,172 +1,85 @@
-# 无人机检测 ONNX 模型推理工具
+# UAV Detection System
 
-## 项目概述
-基于ONNX Runtime的无人机目标检测解决方案,支持:
-- GPU加速推理(需CUDA环境)
-- 自适应图像预处理流水线(CPU/GPU双模式)
-- 误报过滤机制(ORB特征匹配)
-- 多格式检测报告生成(CSV/Excel)
+基于 ONNX 的无人机检测系统,支持单张图片和批量处理。
 
-## 项目结构
-```
-Model/
-├── src/                    # 源代码目录
-│   ├── core/              # 核心功能模块
-│   │   ├── inference.py   # 推理主程序
-│   │   ├── feature_extractor.py
-│   │   └── image_validation.py
-│   ├── utils/             # 工具函数
-│   │   ├── visualizer.py  # 可视化工具
-│   │   └── dataset.py     # 数据集工具
-│   └── analysis/          # 分析工具
-│       ├── hotspot_analyzer.py
-│       └── report_generator.py
-├── data/                  # 数据目录
-│   ├── false_positive/   # 误报样本
-│   └── models/           # 模型文件
-├── tests/                # 测试目录
-├── outputs/              # 输出目录
-└── docs/                 # 文档目录
-```
+## 功能特点
+
+- 支持 ONNX 模型推理
+- 支持 CUDA 加速
+- 支持批量处理图片
+- 自动生成检测报告(CSV格式)
+- 支持误报过滤
+- 支持检测框面积比例限制
+- 支持保存未检测到目标的图片
 
 ## 环境要求
-- Python 3.8+ 
-- CUDA环境(如需GPU加速)
-- 依赖库:
-```bash
-pip install -r requirements.txt
-```
 
-## 快速开始
+- Python 3.8+
+- OpenCV 4.5+
+- ONNX Runtime 1.9+
+- CUDA 11.0+ (可选,用于GPU加速)
+
+## 安装依赖
 
-### 1. 安装依赖
 ```bash
 pip install -r requirements.txt
 ```
 
-### 2. 运行推理
-在项目根目录下运行:
+## 使用方法
 
-#### 使用图形界面(推荐)
-```bash
-python -m src.core.inference --gui
-```
+### 命令行模式
 
-#### 单张图片检测
 ```bash
-python -m src.core.inference --input 图片路径 --threshold 0.5 --output results --report excel
-```
+# 处理单张图片
+python -m src.core.inference --input path/to/image.jpg --output results
 
-#### 批量检测目录
-```bash
-python -m src.core.inference --input 图片目录 --max-bbox-ratio 0.1 --output batch_results --report all
-```
+# 处理整个目录
+python -m src.core.inference --input path/to/images_dir --output results
 
-## 参数说明
-| 参数 | 类型 | 默认值 | 说明 |
-|------|------|-------|-----|
-| --input | path | 必填 | 输入图像路径/目录 |
-| --threshold | float | 0.45 | 检测置信度阈值 |
-| --max-bbox-ratio | float | 0.05 | 最大检测框面积比例 |
-| --output | path | 自动生成 | 输出目录(自动创建) |
-| --report | [csv/excel/all] | excel | 报告格式 |
-| --save-empty | flag | False | 保存无检测结果的图像 |
-| --gui | flag | False | 启用图形界面选择输入目录 |
-
-## 结果输出
-输出目录包含:
-- `目标截图/`:检测到的无人机目标截图
-- `检测报告.xlsx`:包含检测时间、置信度、坐标等详细信息
-- 带检测框的标注图像
-
-## 模型管理
-- 模型文件位置:`data/models/UAV-250321.onnx`
-- 模型版本通过文件名自动识别(如:UAV-250321.onnx 版本号为250321)
-
-## 常见问题
-Q: 如何启用GPU加速?
-A: 确保安装onnxruntime-gpu并配置CUDA环境,程序会自动优先使用GPU
-
-Q: 误报特征库如何更新?
-A: 将误报样本放入data/false_positive目录,程序启动时会自动加载
-
-## 开发指南
-1. 代码规范
-   - 遵循PEP 8编码规范
-   - 使用类型注解
-   - 编写单元测试
-
-2. 测试
-```bash
-# 运行测试
-pytest tests/
+# 调整检测参数
+python -m src.core.inference --input path/to/image.jpg --threshold 0.6 --max-bbox-ratio 0.1
 
-# 生成测试覆盖率报告
-pytest --cov=src tests/
+# 保存未检测到目标的图片
+python -m src.core.inference --input path/to/images_dir --save-empty
 ```
 
-3. 文档
-   - 使用Google风格文档字符串
-   - 保持README.md更新
-   - 重要功能添加使用示例
-
-## 性能指标说明
-
-### 1. 基础统计指标
+### 图形界面模式
 
-- **总样本数**:CSV文件中所有唯一的图像文件数量
-- **检测到的目标总数**:模型检测到的所有目标数量
-- **真实目标总数(TP)**:确认的真实无人机目标数量
-- **误报目标数(FP)**:错误检测为非无人机目标的数量
-- **真阴性数量(TN)**:正确识别为非无人机目标的数量
-
-### 2. 性能评估指标
-
-- **准确率(Accuracy)**
-  - 计算公式:`(TP + TN) / (TP + FP + TN)`
-  - 说明:反映模型整体正确分类的能力,即正确预测的样本占总样本的比例
-  - 取值范围:0-1,越接近1表示模型整体性能越好
-
-- **精确率(Precision)**
-  - 计算公式:`TP / (TP + FP)`
-  - 说明:反映模型在预测为正样本(无人机)时的准确程度
-  - 取值范围:0-1,越接近1表示模型在预测无人机时越准确
-
-- **召回率(Recall)**
-  - 计算公式:`TP / (TP + FN)`
-  - 说明:反映模型发现所有真实无人机目标的能力
-  - 取值范围:0-1,越接近1表示模型发现真实无人机的能力越强
-
-- **F1分数(F1 Score)**
-  - 计算公式:`2 * (Precision * Recall) / (Precision + Recall)`
-  - 说明:精确率和召回率的调和平均数,用于平衡精确率和召回率
-  - 取值范围:0-1,越接近1表示模型在精确率和召回率之间取得更好的平衡
-
-- **误报率/虚警率(False Alarm Rate)**
-  - 计算公式:`FP / (TN + FP)`
-  - 说明:反映模型将非无人机目标错误识别为无人机的比例
-  - 取值范围:0-1,越接近0表示模型在避免误报方面表现越好
-
-## 使用方法
-
-1. 运行推理:
 ```bash
-python src/core/inference.py
+python -m src.core.inference --gui
 ```
 
-2. 分析性能:
-```bash
-python src/analysis/performance_analyzer.py
-```
+## 参数说明
+
+- `--input`: 输入图像路径或目录(必需)
+- `--output`: 输出目录路径,默认为输入目录名+_results
+- `--threshold`: 检测置信度阈值,默认0.5
+- `--max-bbox-ratio`: 检测框最大面积比例阈值,默认0.05
+- `--save-empty`: 是否保存未检测到目标的图片
+- `--gui`: 启用图形界面选择输入目录
 
 ## 输出说明
 
-运行性能分析后,将在结果目录下生成`performance_stats.txt`文件,包含:
-1. 详细统计信息(总样本数、检测目标数等)
-2. 性能指标(准确率、精确率、召回率等)
+程序会在输出目录中生成以下内容:
+
+- 检测结果图片(带检测框)
+- `detection_report.csv`: 检测报告,包含以下信息:
+  - 图片路径
+  - 检测时间
+  - 是否检测到目标
+  - 检测框坐标
+  - 置信度分数
 
 ## 注意事项
 
 1. 确保模型文件路径正确
-2. 确保输入图像格式正确
-3. 性能分析前需要将确认的真实目标放入`targets_all_True`目录
+2. 如果使用GPU加速,请确保CUDA环境配置正确
+3. 批量处理时建议使用相对较小的图片尺寸以提高处理速度
+4. 检测报告会自动覆盖同名文件,请注意备份
+
+## 更新日志
+
+### 2024-03-29
+- 移除Excel报告生成功能
+- 优化检测报告生成逻辑
+- 修复模型加载和属性访问问题

+ 1 - 1
src/analysis/performance_analyzer.py

@@ -209,7 +209,7 @@ class PerformanceAnalyzer:
 
 def main():
     # 示例使用
-    result_dir = r"D:\PythonProject\Model\Output\output_20250328_161223_results"
+    result_dir = r"D:\PythonProject\Model\output_20250329_140816_results"
     analyzer = PerformanceAnalyzer(result_dir)
     
     # 分析性能

+ 66 - 0
src/core/compare_outputs.py

@@ -0,0 +1,66 @@
+import os
+import shutil
+from pathlib import Path
+import re
+
+def get_base_filename(filename):
+    """获取文件名的基础部分(去除时间戳和序号后缀)"""
+    # 匹配形如 _20250414094956928104_0.jpg 的后缀
+    pattern = r'(_\d+_\d+\.jpg)$'
+    match = re.search(pattern, filename)
+    if match:
+        # 返回去除后缀的文件名
+        return filename[:-len(match.group(1))]
+    return filename
+
+def compare_directories(dir1, dir2, output_dir):
+    """比较两个目录中的文件并整理差异"""
+    # 创建输出目录
+    os.makedirs(output_dir, exist_ok=True)
+    
+    # 获取两个目录中的所有文件
+    dir1_files = set(os.listdir(dir1))
+    dir2_files = set(os.listdir(dir2))
+    
+    # 创建文件名映射
+    dir1_base_files = {get_base_filename(f): f for f in dir1_files}
+    dir2_base_files = {get_base_filename(f): f for f in dir2_files}
+    
+    # 找出只在dir1中存在的文件
+    only_in_dir1 = set(dir1_base_files.keys()) - set(dir2_base_files.keys())
+    # 找出只在dir2中存在的文件
+    only_in_dir2 = set(dir2_base_files.keys()) - set(dir1_base_files.keys())
+    
+    # 创建子目录
+    dir1_only_dir = os.path.join(output_dir, "only_in_dir1")
+    dir2_only_dir = os.path.join(output_dir, "only_in_dir2")
+    os.makedirs(dir1_only_dir, exist_ok=True)
+    os.makedirs(dir2_only_dir, exist_ok=True)
+    
+    # 复制文件
+    for base_name in only_in_dir1:
+        src_file = os.path.join(dir1, dir1_base_files[base_name])
+        dst_file = os.path.join(dir1_only_dir, dir1_base_files[base_name])
+        shutil.copy2(src_file, dst_file)
+        print(f"复制 {src_file} 到 {dst_file}")
+    
+    for base_name in only_in_dir2:
+        src_file = os.path.join(dir2, dir2_base_files[base_name])
+        dst_file = os.path.join(dir2_only_dir, dir2_base_files[base_name])
+        shutil.copy2(src_file, dst_file)
+        print(f"复制 {src_file} 到 {dst_file}")
+    
+    # 打印统计信息
+    print(f"\n统计信息:")
+    print(f"目录1中的文件总数: {len(dir1_files)}")
+    print(f"目录2中的文件总数: {len(dir2_files)}")
+    print(f"只在目录1中存在的文件数: {len(only_in_dir1)}")
+    print(f"只在目录2中存在的文件数: {len(only_in_dir2)}")
+    print(f"共同存在的文件数: {len(dir1_files) - len(only_in_dir1)}")
+
+if __name__ == "__main__":
+    dir1 = r"D:\PythonProject\Model\output_20250329_140816_results"
+    dir2 = r"D:\PythonProject\Model\Output\output_20250329_140816_results"
+    output_dir = r"D:\PythonProject\Model\output_differences"
+    
+    compare_directories(dir1, dir2, output_dir) 

+ 0 - 0
src/core/detect.py


+ 168 - 0
src/core/infer.py

@@ -0,0 +1,168 @@
+import os
+import cv2
+import numpy as np
+import onnxruntime
+from typing import List, Dict, Optional
+from datetime import datetime
+import argparse
+from tqdm import tqdm
+
+class UAVDetector:
+    def __init__(self, model_path: str, prob_threshold: float = 0.5, max_bbox_ratio: float = 0.05):
+        # 初始化ONNX会话
+        so = onnxruntime.SessionOptions()
+        so.log_severity_level = 3
+        self.session = onnxruntime.InferenceSession(
+            model_path, 
+            so,
+            providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
+        )
+        
+        # 模型参数设置
+        self.input_size = (640, 640)  # width, height
+        self.mean_ = np.array([0.485, 0.456, 0.406], dtype=np.float32)
+        self.std_ = np.array([0.229, 0.224, 0.225], dtype=np.float32)
+        self.confThreshold = prob_threshold
+        self.max_bbox_ratio = max_bbox_ratio
+        
+        # 获取输入输出名称
+        self.input_name = self.session.get_inputs()[0].name
+        self.output_name = self.session.get_outputs()[0].name
+
+    def preprocess(self, srcimg: np.ndarray) -> tuple:
+        """图像预处理"""
+        img = cv2.cvtColor(srcimg, cv2.COLOR_BGR2RGB)
+        img = cv2.resize(img, self.input_size, interpolation=cv2.INTER_LINEAR)
+        img = img.astype(np.float32)
+        img = img / 255.
+        img -= self.mean_[None, None, :]
+        img /= self.std_[None, None, :]
+        img = np.transpose(img, [2, 0, 1])
+        scale_factor = np.array([1., 1.], dtype=np.float32)
+        return img, scale_factor
+
+    def nms(self, boxes: np.ndarray, scores: np.ndarray, conf_threshold: float, iou_threshold: float) -> List[int]:
+        """非极大值抑制"""
+        x1 = boxes[:, 0]
+        y1 = boxes[:, 1]
+        x2 = boxes[:, 2]
+        y2 = boxes[:, 3]
+        areas = (x2 - x1 + 1) * (y2 - y1 + 1)
+        order = scores.argsort()[0][::-1]
+        keep = []
+        
+        while order.size > 0:
+            i = order[0]
+            keep.append(i)
+            xx1 = np.maximum(x1[i], x1[order[1:]])
+            yy1 = np.maximum(y1[i], y1[order[1:]])
+            xx2 = np.minimum(x2[i], x2[order[1:]])
+            yy2 = np.minimum(y2[i], y2[order[1:]])
+            w = np.maximum(0.0, xx2 - xx1 + 1)
+            h = np.maximum(0.0, yy2 - yy1 + 1)
+            inter = w * h
+            ovr = inter / (areas[i] + areas[order[1:]] - inter)
+            inds = np.where(ovr <= iou_threshold)[0]
+            order = order[inds + 1]
+        return keep
+
+    def detect(self, srcimg: np.ndarray) -> tuple:
+        """执行检测"""
+        # 预处理
+        img, scale_factor = self.preprocess(srcimg)
+        inputs = {'image': img[None, :, :, :], 'scale_factor': scale_factor[None, :]}
+        ort_inputs = {i.name: inputs[i.name] for i in self.session.get_inputs() if i.name in inputs}
+        
+        # 推理
+        output = self.session.run(None, ort_inputs)
+        bbox, confidence = output
+        
+        # 过滤低置信度检测
+        keep_idx = (confidence[:, 0] > self.confThreshold)
+        bbox = bbox[keep_idx, :]
+        confidence = confidence[:, keep_idx]
+
+        # 坐标转换
+        ratioh = srcimg.shape[0] / self.input_size[1]
+        ratiow = srcimg.shape[1] / self.input_size[0]
+        bbox[:, 0] *= ratiow
+        bbox[:, 1] *= ratioh
+        bbox[:, 2] *= ratiow
+        bbox[:, 3] *= ratioh
+
+        # NMS处理
+        keep = self.nms(bbox, confidence, self.confThreshold, 0.4)
+
+        # 后处理
+        outres = []
+        img_out = srcimg.copy()
+        for idx in keep:
+            xmin, ymin, xmax, ymax = bbox[idx].astype(int)
+            
+            # 检查检测框面积比例
+            bbox_area = (xmax - xmin) * (ymax - ymin)
+            image_area = srcimg.shape[0] * srcimg.shape[1]
+            if bbox_area / image_area > self.max_bbox_ratio:
+                continue
+                
+            score = confidence[0][idx]
+            outres.append([xmin, ymin, xmax, ymax, score])
+            
+            # 绘制检测框和标签
+            cv2.rectangle(img_out, (xmin, ymin), (xmax, ymax), (0, 0, 255), thickness=2)
+            cv2.putText(img_out, f"UAV: {score:.2f}", (xmin, ymin - 10),
+                        cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 255, 0), thickness=1)
+            
+        return img_out, outres
+
+def process_image(detector: UAVDetector, image_path: str, output_dir: str) -> int:
+    """处理单张图像"""
+    srcimg = cv2.imread(image_path)
+    if srcimg is None:
+        print(f"无法读取图像: {image_path}")
+        return 0
+        
+    processed_img, detections = detector.detect(srcimg)
+    
+    # 保存结果
+    if detections or args.save_empty:
+        filename = os.path.basename(image_path)
+        output_path = os.path.join(output_dir, filename)
+        cv2.imwrite(output_path, processed_img)
+        
+    return len(detections)
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='UAV Detection using ONNX model')
+    parser.add_argument('--model', type=str, required=True, help='ONNX模型路径')
+    parser.add_argument('--input', type=str, required=True, help='输入图像路径或目录')
+    parser.add_argument('--output', type=str, default='output', help='输出目录')
+    parser.add_argument('--threshold', type=float, default=0.5, help='检测置信度阈值')
+    parser.add_argument('--max-bbox-ratio', type=float, default=0.05, help='检测框最大面积比例阈值')
+    parser.add_argument('--save-empty', action='store_true', help='是否保存未检测到目标的图片')
+    args = parser.parse_args()
+
+    # 创建输出目录
+    os.makedirs(args.output, exist_ok=True)
+
+    # 初始化检测器
+    detector = UAVDetector(
+        model_path=args.model,
+        prob_threshold=args.threshold,
+        max_bbox_ratio=args.max_bbox_ratio
+    )
+
+    # 处理输入
+    if os.path.isdir(args.input):
+        total_detections = 0
+        image_files = []
+        for root, _, files in os.walk(args.input):
+            image_files.extend([os.path.join(root, f) for f in files if f.lower().endswith(('.jpg', '.png'))])
+            
+        for img_file in tqdm(image_files, desc='Processing images'):
+            total_detections += process_image(detector, img_file, args.output)
+            
+        print(f'批量处理完成!共检测到 {total_detections} 个目标')
+    else:
+        detections = process_image(detector, args.input, args.output)
+        print(f'处理完成!检测到 {detections} 个目标') 

+ 67 - 49
src/core/inference.py

@@ -20,30 +20,34 @@ class ONNXDetector:
         folder_path = filedialog.askdirectory(title='选择输入目录')
         return folder_path if folder_path else None
 
-    def __init__(self, model_path: str = 'D:\PythonProject\Model\Data\models\250411_Anti_UAV.onnx', threshold: float = 0.5, 
+    def __init__(self, model_path: str = 'D:/PythonProject/Model/Data/models/250411_Anti_UAV.onnx', threshold: float = 0.5, 
                  output_dir: str = "None", save_empty: bool = False, 
-                 max_bbox_ratio: float = 0.5,  # 已有该参数
+                 max_bbox_ratio: float = 0.5,
                  input_dir: str = "None"):
         self.image_count = 0
         self.detection_records = []
         self.model_path = model_path
         self.input_dir = input_dir
+        self.save_empty = save_empty
+        self.threshold = threshold
+        self.confThreshold = threshold
         
         # 初始化ONNX会话
+        so = ort.SessionOptions()
+        so.log_severity_level = 3
         self.providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
-        self.session = ort.InferenceSession(model_path, providers=self.providers)
+        self.session = ort.InferenceSession(model_path, so, providers=self.providers)
         
-        # 获取模型输入尺寸
-        input_dims = self.session.get_inputs()[0].shape
-        self.input_shape = tuple([dim if isinstance(dim, int) else 1 for dim in input_dims])
+        # 模型参数设置
+        self.input_size = (640, 640)  # width, height
+        self.mean_ = np.array([0.485, 0.456, 0.406], dtype=np.float32)
+        self.std_ = np.array([0.229, 0.224, 0.225], dtype=np.float32)
+        self.max_bbox_ratio = max_bbox_ratio
         
         # 从模型路径解析版本号
-        self.model_version = model_path.split('-')[-1].split('.')[0] if '-' in model_path else '1.0'
-        self.input_size = f"{self.input_shape[2]}x{self.input_shape[3]}"
+        self.model_version = os.path.basename(model_path).split('.')[0]
         
-        self.threshold = threshold
-        self.max_bbox_ratio = max_bbox_ratio  # 已有该赋值
-        self.save_empty = save_empty
+        # 获取输入输出名称
         self.input_name = self.session.get_inputs()[0].name
         self.output_name = self.session.get_outputs()[0].name
         
@@ -82,7 +86,7 @@ class ONNXDetector:
                 return None
             
             # GPU预处理流水线
-            gpu_resized = cv2.cuda.resize(self.gpu_frame, (self.input_shape[3], self.input_shape[2]))
+            gpu_resized = cv2.cuda.resize(self.gpu_frame, self.input_size)
             gpu_rgb = cv2.cuda.cvtColor(gpu_resized, cv2.COLOR_BGR2RGB)
             
             # 下载到CPU进行后续处理
@@ -96,14 +100,41 @@ class ONNXDetector:
                 return None
             
             self.orig_h, self.orig_w = image_orig.shape[:2]
-            image_orig = cv2.resize(image_orig, (self.input_shape[3], self.input_shape[2]))
+            image_orig = cv2.resize(image_orig, self.input_size)
             image_orig = cv2.cvtColor(image_orig, cv2.COLOR_BGR2RGB)
 
         # 统一的后处理
         image = image_orig.astype(np.float32) / 255.0
+        image -= self.mean_[None, None, :]
+        image /= self.std_[None, None, :]
         image = np.transpose(image, (2, 0, 1))  # CHW 格式
         return np.expand_dims(image, axis=0)
 
+    def nms(self, boxes: np.ndarray, scores: np.ndarray, conf_threshold: float, iou_threshold: float) -> List[int]:
+        """非极大值抑制"""
+        x1 = boxes[:, 0]
+        y1 = boxes[:, 1]
+        x2 = boxes[:, 2]
+        y2 = boxes[:, 3]
+        areas = (x2 - x1 + 1) * (y2 - y1 + 1)
+        order = scores.argsort()[0][::-1]
+        keep = []
+        
+        while order.size > 0:
+            i = order[0]
+            keep.append(i)
+            xx1 = np.maximum(x1[i], x1[order[1:]])
+            yy1 = np.maximum(y1[i], y1[order[1:]])
+            xx2 = np.minimum(x2[i], x2[order[1:]])
+            yy2 = np.minimum(y2[i], y2[order[1:]])
+            w = np.maximum(0.0, xx2 - xx1 + 1)
+            h = np.maximum(0.0, yy2 - yy1 + 1)
+            inter = w * h
+            ovr = inter / (areas[i] + areas[order[1:]] - inter)
+            inds = np.where(ovr <= iou_threshold)[0]
+            order = order[inds + 1]
+        return keep
+
     def inference(self, input_data: np.ndarray) -> np.ndarray:
         """执行模型推理"""
         scale_factor = np.array([[1, 1]], dtype=np.float32)
@@ -149,22 +180,32 @@ class ONNXDetector:
                 return True
         return False
 
-    def postprocess(self, detections: np.ndarray, image_orig: np.ndarray, image_path: str) -> int:
+    def postprocess(self, detections: np.ndarray, image_orig: np.ndarray, image_path: str) -> tuple:
         """后处理检测结果"""
         valid_detections = 0
         img_out = image_orig.copy()
         detections_list = []
         
-        for det in detections:
-            class_id = int(det[0])
-            confidence = det[1]
+        # 过滤低置信度检测
+        keep_idx = (detections[:, 1] > self.confThreshold)
+        detections = detections[keep_idx]
+        
+        if len(detections) == 0:
+            return 0, img_out, []
             
-            if confidence < self.threshold:
-                continue
+        # 坐标转换
+        ratioh = self.orig_h / self.input_size[1]
+        ratiow = self.orig_w / self.input_size[0]
+        detections[:, 2:6] *= np.array([ratiow, ratioh, ratiow, ratioh])
+        
+        # NMS处理
+        keep = self.nms(detections[:, 2:6], detections[:, 1:2], self.confThreshold, 0.4)
+        
+        for idx in keep:
+            class_id = int(detections[idx, 0])
+            confidence = detections[idx, 1]
+            x1, y1, x2, y2 = detections[idx, 2:6].astype(int)
             
-            # 坐标转换和边界检查
-            x1, y1, x2, y2 = self._convert_coordinates(det[2:6])
-
             # 计算检测框面积比例
             bbox_area = (x2 - x1) * (y2 - y1)
             image_area = self.orig_w * self.orig_h
@@ -194,28 +235,12 @@ class ONNXDetector:
                 'class_id': class_id,
                 'confidence': float(confidence),
                 'bbox': [x1, y1, x2, y2],
-            'orig_w': self.orig_w,
-            'orig_h': self.orig_h
+                'orig_w': self.orig_w,
+                'orig_h': self.orig_h
             })
             
         return valid_detections, img_out, detections_list
 
-    def _convert_coordinates(self, coords: List[float]) -> tuple:
-        """将模型输出坐标转换为原始图像尺寸"""
-        x1 = int(coords[0] * self.orig_w / self.input_shape[3])
-        y1 = int(coords[1] * self.orig_h / self.input_shape[2])
-        x2 = int(coords[2] * self.orig_w / self.input_shape[3])
-        y2 = int(coords[3] * self.orig_h / self.input_shape[2])
-        
-        # 边界检查
-        clamp = lambda val, max_val: max(0, min(val, max_val - 1))
-        return (
-            clamp(x1, self.orig_w),
-            clamp(y1, self.orig_h),
-            clamp(x2, self.orig_w),
-            clamp(y2, self.orig_h)
-        )
-
     def process_image(self, image_path: str) -> int:
         """端到端处理单张图像"""
         input_data = self.preprocess(image_path)
@@ -257,8 +282,6 @@ if __name__ == '__main__':
     parser.add_argument('--output', type=str, default=None, help='输出目录路径,默认为输入目录名+_results')
     parser.add_argument('--max-bbox-ratio', type=float, default=0.05,
                       help='检测框最大面积比例阈值,默认0.05')
-    parser.add_argument('--report', type=str, choices=['csv', 'excel', 'all'], default='all',
-                      help='输出报告格式: csv|excel|all')
     parser.add_argument('--save-empty', action='store_true',
                       help='是否保存未检测到目标的图片')
     parser.add_argument('--gui', action='store_true',
@@ -282,7 +305,7 @@ if __name__ == '__main__':
     # 初始化检测器时传递参数
     detector = ONNXDetector(
         threshold=args.threshold,
-        max_bbox_ratio=args.max_bbox_ratio,  # 添加该参数传递
+        max_bbox_ratio=args.max_bbox_ratio,
         output_dir=args.output,
         save_empty=args.save_empty,
         input_dir=args.input if os.path.isdir(args.input) else None
@@ -301,14 +324,9 @@ if __name__ == '__main__':
             total += detector.process_image(img_file)
         print(f'批量处理完成!共检测到 {total} 个目标')
     
-    # 生成检测报告
-    if args.report in ('csv', 'all'):
+        # 生成CSV报告
         csv_path = os.path.join(detector.output_dir, 'detection_report.csv')
         ReportGenerator(detector).generate_csv(csv_path)
         print(f'CSV报告已生成: {csv_path}')
-    if args.report in ('excel', 'all'):
-        excel_path = os.path.join(detector.output_dir, 'detection_report.xlsx')
-        ReportGenerator(detector).generate_excel(excel_path)
-        print(f'Excel报告已生成: {excel_path}')
     else:
         detections = detector.process_image(args.input)