Преглед на файлове

feat: 实现UAV检测系统重构版本v2.0

本次提交包含UAV检测系统的完整重构版本,主要变更包括:

1. 采用模块化架构设计,将功能拆分为核心模块、配置模块和工具模块
2. 新增统一的配置管理系统,支持通过AppConfig类集中管理所有参数
3. 实现图像处理、模型推理和后处理的完整流程
4. 添加文件管理和报告生成功能
5. 包含完整的测试用例和项目文档

重构后的系统具有更好的可维护性和扩展性,同时保持与原有功能的兼容性。新增的模块化设计使得各组件职责更加清晰,便于后续功能扩展和性能优化。
Hannnk преди 1 месец
ревизия
b55235a45e
променени са 16 файла, в които са добавени 2753 реда и са изтрити 0 реда
  1. 185 0
      .gitignore
  2. 224 0
      README.md
  3. 240 0
      README_NEW.md
  4. 348 0
      inference.py
  5. 8 0
      requirements.txt
  6. 19 0
      run.py
  7. 338 0
      scripts/organize_project.py
  8. 72 0
      src/config/config.py
  9. 224 0
      src/core/detector.py
  10. 100 0
      src/core/image_processor.py
  11. 87 0
      src/core/model_inference.py
  12. 220 0
      src/core/post_processor.py
  13. 185 0
      src/main.py
  14. 144 0
      src/utils/file_manager.py
  15. 193 0
      src/utils/report_generator.py
  16. 166 0
      tests/test_refactored.py

+ 185 - 0
.gitignore

@@ -0,0 +1,185 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+#  Usually these files are written by a python script from a template
+#  before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+#   For a library or package, you might want to ignore these files since the code is
+#   intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+#   However, in case of collaboration, if having platform-specific dependencies or dependencies
+#   having no cross-platform support, pipenv may install dependencies that don't work, or not
+#   install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+#   This is especially recommended for binary packages to ensure reproducibility, and is more
+#   commonly ignored for libraries.
+#   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+#   in version control.
+#   https://pdm.fming.dev/#use-with-ide
+.pdm.toml
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+#  be added to the global gitignore or merged into this project gitignore.  For a PyCharm
+#  project, it is recommended to use the following settings:
+#  https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+.idea/
+
+# VS Code
+.vscode/
+
+# Project specific
+# Output directories
+data/
+output/
+logs/
+
+# Model files (if large)
+# data/models/*.onnx
+
+# Temporary files
+*.tmp
+*.temp
+
+# OS generated files
+.DS_Store
+.DS_Store?
+._*
+.Spotlight-V100
+.Trashes
+ehthumbs.db
+Thumbs.db

+ 224 - 0
README.md

@@ -0,0 +1,224 @@
+# UAV检测系统 - 重构版本 v2.0
+
+基于ONNX的无人机检测系统,采用模块化设计,提供高效的目标检测和结果分析功能。
+
+## 🚀 项目特点
+
+- **模块化设计**: 将功能拆分为独立的模块,便于维护和扩展
+- **配置管理**: 统一的配置系统,支持灵活的参数调整
+- **高性能**: 支持GPU加速推理,优化的图像处理流程
+- **完整报告**: 自动生成CSV和Excel格式的检测报告
+- **易于使用**: 提供命令行和编程接口
+
+## 📁 项目结构
+
+```
+ModelEvaluator/
+├── src/                     # 源代码目录
+│   ├── config/              # 配置模块
+│   │   └── config.py        # 配置类定义
+│   ├── core/                # 核心模块
+│   │   ├── detector.py      # 主检测器
+│   │   ├── image_processor.py # 图像处理
+│   │   ├── model_inference.py # 模型推理
+│   │   └── post_processor.py  # 后处理
+│   ├── utils/               # 工具模块
+│   │   ├── file_manager.py  # 文件管理
+│   │   └── report_generator.py # 报告生成
+│   └── main.py              # 主入口
+├── tests/                   # 测试代码
+│   └── test_refactored.py   # 重构测试
+├── data/                    # 数据目录
+│   ├── models/              # 模型文件
+│   └── samples/             # 样本数据
+├── docs/                    # 文档目录
+├── examples/                # 示例代码
+├── scripts/                 # 脚本目录
+├── output/                  # 输出目录
+├── logs/                    # 日志目录
+├── run.py                   # 项目启动入口
+├── requirements.txt         # 依赖列表
+└── README.md               # 项目说明
+```
+
+## 🛠️ 安装和配置
+
+### 环境要求
+
+- Python 3.8+
+- CUDA 11.0+ (可选,用于GPU加速)
+
+### 安装依赖
+
+```bash
+pip install -r requirements.txt
+```
+
+### 模型文件
+
+将ONNX模型文件放置在 `data/models/` 目录下。
+
+## 🚀 使用方法
+
+### 命令行使用
+
+#### 基本用法
+
+```bash
+# 检测单张图片
+python run.py --input image.jpg
+
+# 批量检测目录中的图片
+python run.py --input /path/to/images
+
+# 使用GUI选择输入目录
+python run.py --gui
+
+# 自定义参数
+python run.py --input /path/to/images --threshold 0.6 --output custom_output
+```
+
+#### 完整参数说明
+
+```bash
+python run.py --help
+```
+
+**主要参数:**
+- `--input`: 输入图像路径或目录
+- `--output`: 输出目录路径
+- `--model`: ONNX模型文件路径
+- `--threshold`: 检测置信度阈值 (默认: 0.5)
+- `--iou-threshold`: NMS IoU阈值 (默认: 0.4)
+- `--max-bbox-ratio`: 检测框最大面积比例阈值 (默认: 0.5)
+- `--save-empty`: 是否保存未检测到目标的图片
+- `--gui`: 启用图形界面选择输入目录
+- `--verbose`: 显示详细信息
+
+### 编程接口使用
+
+```python
+import sys
+import os
+sys.path.insert(0, 'src')
+
+from config.config import AppConfig, ModelConfig, DetectionConfig, OutputConfig
+from core.detector import UAVDetector
+
+# 创建配置
+model_config = ModelConfig(
+    model_path='data/models/your_model.onnx',
+    input_size=(640, 640),
+    confidence_threshold=0.5
+)
+
+detection_config = DetectionConfig(
+    confidence_threshold=0.5,
+    iou_threshold=0.4,
+    max_bbox_ratio=0.5
+)
+
+output_config = OutputConfig(
+    base_output_dir='output',
+    add_timestamp=True,
+    save_empty_images=False
+)
+
+app_config = AppConfig(
+    model=model_config,
+    detection=detection_config,
+    output=output_config
+)
+
+# 初始化检测器
+detector = UAVDetector(config=app_config, input_dir='path/to/images')
+
+# 处理单张图片
+result = detector.process_image('image.jpg')
+
+# 批量处理
+detector.process_batch(['image1.jpg', 'image2.jpg'])
+
+# 生成报告
+detector.generate_reports()
+
+# 获取统计信息
+stats = detector.get_statistics()
+print(f"处理图片数: {stats['total_images']}")
+print(f"检测到目标数: {stats['total_detections']}")
+```
+
+## 📊 输出结果
+
+### 目录结构
+
+```
+output/
+├── imgs/                    # 检测结果图片
+│   ├── image1_detected.jpg
+│   └── image2_detected.jpg
+├── targets/                 # 检测目标ROI
+│   ├── image1_target_1.jpg
+│   └── image1_target_2.jpg
+├── detection_report.csv     # CSV格式报告
+└── detection_report.xlsx    # Excel格式报告
+```
+
+### 报告内容
+
+- **CSV报告**: 包含每个检测框的详细信息
+- **Excel报告**: 包含模型信息、检测统计和详细结果
+- **检测图片**: 标注了检测框的原图
+- **目标ROI**: 提取的检测目标区域
+
+## 🧪 测试
+
+运行测试脚本验证系统功能:
+
+```bash
+python tests/test_refactored.py
+```
+
+## 🔧 配置说明
+
+### ModelConfig (模型配置)
+
+- `model_path`: ONNX模型文件路径
+- `input_size`: 模型输入尺寸 (width, height)
+- `mean`: 图像归一化均值
+- `std`: 图像归一化标准差
+- `use_cuda`: 是否使用CUDA加速
+
+### DetectionConfig (检测配置)
+
+- `confidence_threshold`: 置信度阈值
+- `iou_threshold`: NMS IoU阈值
+- `max_bbox_ratio`: 检测框最大面积比例
+
+### OutputConfig (输出配置)
+
+- `base_output_dir`: 基础输出目录
+- `add_timestamp`: 是否添加时间戳
+- `save_empty_images`: 是否保存空检测图片
+
+## 🆕 版本更新
+
+### v2.0 (当前版本)
+
+- ✅ 完全重构代码架构
+- ✅ 模块化设计,提高可维护性
+- ✅ 统一配置管理系统
+- ✅ 改进的错误处理机制
+- ✅ 标准化的项目结构
+- ✅ 向后兼容性支持
+- ✅ 完善的测试覆盖
+
+### v1.0 (原版本)
+
+- 基础的UAV检测功能
+- 单文件架构
+- 基本的报告生成
+
+---
+
+**注意**: 这是重构版本,相比原版本具有更好的代码结构和可维护性。建议使用新的 `run.py` 入口文件启动应用程序。

+ 240 - 0
README_NEW.md

@@ -0,0 +1,240 @@
+# UAV检测系统 - 重构版本 v2.0
+
+基于ONNX的无人机检测系统,采用模块化设计,提供高效的目标检测和结果分析功能。
+
+## 🚀 项目特点
+
+- **模块化设计**: 将功能拆分为独立的模块,便于维护和扩展
+- **配置管理**: 统一的配置系统,支持灵活的参数调整
+- **高性能**: 支持GPU加速推理,优化的图像处理流程
+- **完整报告**: 自动生成CSV和Excel格式的检测报告
+- **易于使用**: 提供命令行和编程接口
+
+## 📁 项目结构
+
+```
+ModelEvaluator/
+├── src/                     # 源代码目录
+│   ├── config/              # 配置模块
+│   │   └── config.py        # 配置类定义
+│   ├── core/                # 核心模块
+│   │   ├── detector.py      # 主检测器
+│   │   ├── image_processor.py # 图像处理
+│   │   ├── model_inference.py # 模型推理
+│   │   └── post_processor.py  # 后处理
+│   ├── utils/               # 工具模块
+│   │   ├── file_manager.py  # 文件管理
+│   │   └── report_generator.py # 报告生成
+│   └── main.py              # 主入口
+├── tests/                   # 测试代码
+│   └── test_refactored.py   # 重构测试
+├── data/                    # 数据目录
+│   ├── models/              # 模型文件
+│   └── samples/             # 样本数据
+├── docs/                    # 文档目录
+├── examples/                # 示例代码
+├── scripts/                 # 脚本目录
+├── output/                  # 输出目录
+├── logs/                    # 日志目录
+├── run.py                   # 项目启动入口
+├── requirements.txt         # 依赖列表
+└── README.md               # 项目说明
+```
+
+## 🛠️ 安装和配置
+
+### 环境要求
+
+- Python 3.8+
+- CUDA 11.0+ (可选,用于GPU加速)
+
+### 安装依赖
+
+```bash
+pip install -r requirements.txt
+```
+
+### 模型文件
+
+将ONNX模型文件放置在 `data/models/` 目录下。
+
+## 🚀 使用方法
+
+### 命令行使用
+
+#### 基本用法
+
+```bash
+# 检测单张图片
+python run.py --input image.jpg
+
+# 批量检测目录中的图片
+python run.py --input /path/to/images
+
+# 使用GUI选择输入目录
+python run.py --gui
+
+# 自定义参数
+python run.py --input /path/to/images --threshold 0.6 --output custom_output
+```
+
+#### 完整参数说明
+
+```bash
+python run.py --help
+```
+
+**主要参数:**
+- `--input`: 输入图像路径或目录
+- `--output`: 输出目录路径
+- `--model`: ONNX模型文件路径
+- `--threshold`: 检测置信度阈值 (默认: 0.5)
+- `--iou-threshold`: NMS IoU阈值 (默认: 0.4)
+- `--max-bbox-ratio`: 检测框最大面积比例阈值 (默认: 0.5)
+- `--save-empty`: 是否保存未检测到目标的图片
+- `--gui`: 启用图形界面选择输入目录
+- `--verbose`: 显示详细信息
+
+### 编程接口使用
+
+```python
+import sys
+import os
+sys.path.insert(0, 'src')
+
+from config.config import AppConfig, ModelConfig, DetectionConfig, OutputConfig
+from core.detector import UAVDetector
+
+# 创建配置
+model_config = ModelConfig(
+    model_path='data/models/your_model.onnx',
+    input_size=(640, 640),
+    confidence_threshold=0.5
+)
+
+detection_config = DetectionConfig(
+    confidence_threshold=0.5,
+    iou_threshold=0.4,
+    max_bbox_ratio=0.5
+)
+
+output_config = OutputConfig(
+    base_output_dir='output',
+    add_timestamp=True,
+    save_empty_images=False
+)
+
+app_config = AppConfig(
+    model=model_config,
+    detection=detection_config,
+    output=output_config
+)
+
+# 初始化检测器
+detector = UAVDetector(config=app_config, input_dir='path/to/images')
+
+# 处理单张图片
+result = detector.process_image('image.jpg')
+
+# 批量处理
+detector.process_batch(['image1.jpg', 'image2.jpg'])
+
+# 生成报告
+detector.generate_reports()
+
+# 获取统计信息
+stats = detector.get_statistics()
+print(f"处理图片数: {stats['total_images']}")
+print(f"检测到目标数: {stats['total_detections']}")
+```
+
+## 📊 输出结果
+
+### 目录结构
+
+```
+output/
+├── imgs/                    # 检测结果图片
+│   ├── image1_detected.jpg
+│   └── image2_detected.jpg
+├── targets/                 # 检测目标ROI
+│   ├── image1_target_1.jpg
+│   └── image1_target_2.jpg
+├── detection_report.csv     # CSV格式报告
+└── detection_report.xlsx    # Excel格式报告
+```
+
+### 报告内容
+
+- **CSV报告**: 包含每个检测框的详细信息
+- **Excel报告**: 包含模型信息、检测统计和详细结果
+- **检测图片**: 标注了检测框的原图
+- **目标ROI**: 提取的检测目标区域
+
+## 🧪 测试
+
+运行测试脚本验证系统功能:
+
+```bash
+python tests/test_refactored.py
+```
+
+## 🔧 配置说明
+
+### ModelConfig (模型配置)
+
+- `model_path`: ONNX模型文件路径
+- `input_size`: 模型输入尺寸 (width, height)
+- `mean`: 图像归一化均值
+- `std`: 图像归一化标准差
+- `use_cuda`: 是否使用CUDA加速
+
+### DetectionConfig (检测配置)
+
+- `confidence_threshold`: 置信度阈值
+- `iou_threshold`: NMS IoU阈值
+- `max_bbox_ratio`: 检测框最大面积比例
+
+### OutputConfig (输出配置)
+
+- `base_output_dir`: 基础输出目录
+- `add_timestamp`: 是否添加时间戳
+- `save_empty_images`: 是否保存空检测图片
+
+## 🆕 版本更新
+
+### v2.0 (当前版本)
+
+- ✅ 完全重构代码架构
+- ✅ 模块化设计,提高可维护性
+- ✅ 统一配置管理系统
+- ✅ 改进的错误处理机制
+- ✅ 标准化的项目结构
+- ✅ 向后兼容性支持
+- ✅ 完善的测试覆盖
+
+### v1.0 (原版本)
+
+- 基础的UAV检测功能
+- 单文件架构
+- 基本的报告生成
+
+## 🤝 贡献指南
+
+1. Fork 项目
+2. 创建功能分支 (`git checkout -b feature/AmazingFeature`)
+3. 提交更改 (`git commit -m 'Add some AmazingFeature'`)
+4. 推送到分支 (`git push origin feature/AmazingFeature`)
+5. 打开 Pull Request
+
+## 📝 许可证
+
+本项目采用 MIT 许可证 - 查看 [LICENSE](LICENSE) 文件了解详情。
+
+## 📞 支持
+
+如有问题或建议,请提交 Issue 或联系开发团队。
+
+---
+
+**注意**: 这是重构版本,相比原版本具有更好的代码结构和可维护性。建议使用新的 `run.py` 入口文件启动应用程序。

+ 348 - 0
inference.py

@@ -0,0 +1,348 @@
+import numpy as np
+import onnxruntime as ort
+import cv2
+import os
+import argparse
+import pickle
+from glob import glob
+from typing import List, Dict, Optional
+import time
+from datetime import datetime
+from tqdm import tqdm
+import sys
+ 
+from report_generator import ReportGenerator
+class ONNXDetector:
+    def select_input_directory(self) -> Optional[str]:
+        """图形界面选择输入目录"""
+        import tkinter as tk
+        from tkinter import filedialog
+        root = tk.Tk()
+        root.withdraw()
+        folder_path = filedialog.askdirectory(title='选择输入目录')
+        return folder_path if folder_path else None
+
+    def __init__(self, model_path: str = 'D:/PythonProject/Model/Data/models/250411_Anti_UAV.onnx', threshold: float = 0.5, 
+                 output_dir: str = None, save_empty: bool = False, 
+                 max_bbox_ratio: float = 0.5,
+                 input_dir: str = None):
+        self.image_count = 0
+        self.detection_records = []
+        self.model_path = model_path
+        self.input_dir = input_dir
+        self.save_empty = save_empty
+        self.threshold = threshold
+        self.confThreshold = threshold
+        
+        # 初始化ONNX会话
+        so = ort.SessionOptions()
+        so.log_severity_level = 3
+        self.providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
+        self.session = ort.InferenceSession(model_path, so, providers=self.providers)
+        
+        # 模型参数设置
+        self.input_size = (640, 640)  # width, height
+        self.mean_ = np.array([0.485, 0.456, 0.406], dtype=np.float32)
+        self.std_ = np.array([0.229, 0.224, 0.225], dtype=np.float32)
+        self.max_bbox_ratio = max_bbox_ratio
+        
+        # 从模型路径解析版本号
+        self.model_version = os.path.basename(model_path).split('.')[0]
+        
+        # 获取输入输出名称
+        self.input_name = self.session.get_inputs()[0].name
+        self.output_name = self.session.get_outputs()[0].name
+        
+        # 类别映射
+        self.class_names = {0: 'UAV'}
+
+        # 加载误报特征库
+        self.false_positive_features = {}
+        try:
+            with open(os.path.join(os.path.dirname(__file__), 'false_positive_features.pkl'), 'rb') as f:
+                self.false_positive_features = pickle.load(f)
+        except FileNotFoundError:
+            print("未找到误报特征库文件,跳过加载")
+        
+        # 创建输出目录
+        output_base = os.path.join(os.path.dirname(__file__), '..', '..', 'Output')
+        os.makedirs(output_base, exist_ok=True)
+        
+        timestamp = datetime.now().strftime('%m%d%H%M%S')
+        if output_dir is None and input_dir is not None:
+            self.output_dir = os.path.join(output_base, f"{os.path.basename(os.path.normpath(input_dir))}_{timestamp}")
+        else:
+            self.output_dir = os.path.join(output_base, output_dir or f"output_{timestamp}")
+            
+        # 创建目标存储目录
+        self.targets_dir = os.path.join(self.output_dir, 'targets')
+        os.makedirs(self.targets_dir, exist_ok=True)
+
+    def preprocess(self, image_path: str) -> Optional[np.ndarray]:
+        """图像预处理"""
+        # 初始化CUDA加速
+        if not hasattr(self, 'use_cuda'):
+            self.use_cuda = cv2.cuda.getCudaEnabledDeviceCount() > 0
+            self.gpu_frame = cv2.cuda_GpuMat() if self.use_cuda else None
+
+        # 读取图像到GPU(如果可用)
+        if self.use_cuda:
+            self.gpu_frame.upload(cv2.imread(image_path))
+            if self.gpu_frame.empty():
+                print(f"无法加载图像:{image_path}")
+                return None
+            
+            # GPU预处理流水线
+            gpu_resized = cv2.cuda.resize(self.gpu_frame, self.input_size)
+            gpu_rgb = cv2.cuda.cvtColor(gpu_resized, cv2.COLOR_BGR2RGB)
+            
+            # 下载到CPU进行后续处理
+            image_orig = gpu_rgb.download()
+            self.orig_h, self.orig_w = image_orig.shape[:2]
+        else:
+            # CPU回退路径
+            image_orig = cv2.imread(image_path)
+            if image_orig is None:
+                print(f"无法加载图像:{image_path}")
+                return None
+            
+            self.orig_h, self.orig_w = image_orig.shape[:2]
+            image_orig = cv2.resize(image_orig, self.input_size)
+            image_orig = cv2.cvtColor(image_orig, cv2.COLOR_BGR2RGB)
+
+        # 统一的后处理
+        image = image_orig.astype(np.float32) / 255.0
+        image -= self.mean_[None, None, :]
+        image /= self.std_[None, None, :]
+        image = np.transpose(image, (2, 0, 1))  # CHW 格式
+        return np.expand_dims(image, axis=0)
+
+    def nms(self, boxes: np.ndarray, scores: np.ndarray, conf_threshold: float, iou_threshold: float) -> List[int]:
+        """非极大值抑制"""
+        x1 = boxes[:, 0]
+        y1 = boxes[:, 1]
+        x2 = boxes[:, 2]
+        y2 = boxes[:, 3]
+        areas = (x2 - x1 + 1) * (y2 - y1 + 1)
+        order = scores.argsort()[0][::-1]
+        keep = []
+        
+        while order.size > 0:
+            i = order[0]
+            keep.append(i)
+            xx1 = np.maximum(x1[i], x1[order[1:]])
+            yy1 = np.maximum(y1[i], y1[order[1:]])
+            xx2 = np.minimum(x2[i], x2[order[1:]])
+            yy2 = np.minimum(y2[i], y2[order[1:]])
+            w = np.maximum(0.0, xx2 - xx1 + 1)
+            h = np.maximum(0.0, yy2 - yy1 + 1)
+            inter = w * h
+            ovr = inter / (areas[i] + areas[order[1:]] - inter)
+            inds = np.where(ovr <= iou_threshold)[0]
+            order = order[inds + 1]
+        return keep
+
+    def inference(self, input_data: np.ndarray) -> np.ndarray:
+        """执行模型推理"""
+        scale_factor = np.array([[1, 1]], dtype=np.float32)
+        return self.session.run(
+            [self.output_name],
+            {self.input_name: input_data, 'scale_factor': scale_factor}
+        )[0]
+
+    def _is_false_positive(self, roi: np.ndarray) -> bool:
+        """使用ORB特征匹配验证是否误报"""
+        if not self.false_positive_features:
+            return False
+
+        # 初始化特征检测器
+        detector = cv2.ORB_create()
+        _, des = detector.detectAndCompute(roi, None)
+
+        # 与特征库进行匹配
+        for fp_feature in self.false_positive_features.values():
+            if des is None or fp_feature['features'] is None:
+                continue
+
+            # 使用FLANN匹配器
+            flann = cv2.FlannBasedMatcher(dict(algorithm=6, table_number=6), dict())
+            matches = flann.knnMatch(des, fp_feature['features'], k=2)
+
+            # 处理空匹配情况
+            if not matches:
+                continue
+
+            # 安全验证匹配结果
+            good_matches = []
+            for match_group in matches:
+                # 确保match_group有足够元素防止索引错误
+                if len(match_group) < 2:
+                    continue
+                m, n = match_group
+                if m.distance < 0.7 * n.distance:
+                    good_matches.append(m)
+
+            # 计算优质匹配数量
+            if len(good_matches) > 15:  # 匹配阈值
+                return True
+        return False
+
+    def postprocess(self, detections: np.ndarray, image_orig: np.ndarray, image_path: str) -> tuple:
+        """后处理检测结果"""
+        valid_detections = 0
+        img_out = image_orig.copy()
+        detections_list = []
+        
+        # 过滤低置信度检测
+        keep_idx = (detections[:, 1] > self.confThreshold)
+        detections = detections[keep_idx]
+        
+        if len(detections) == 0:
+            return 0, img_out, []
+            
+        # 坐标转换
+        ratioh = self.orig_h / self.input_size[1]
+        ratiow = self.orig_w / self.input_size[0]
+        detections[:, 2
+        :6] *= np.array([ratiow, ratioh, ratiow, ratioh])
+        
+        # NMS处理
+        keep = self.nms(detections[:, 2:6], detections[:, 1:2], self.confThreshold, 0.4)
+        
+        for idx in keep:
+            class_id = int(detections[idx, 0])
+            confidence = detections[idx, 1]
+            x1, y1, x2, y2 = detections[idx, 2:6].astype(int)
+            
+            # 计算检测框面积比例
+            bbox_area = (x2 - x1) * (y2 - y1)
+            image_area = self.orig_w * self.orig_h
+            if bbox_area / image_area > self.max_bbox_ratio:
+                continue
+
+            # 截取ROI区域进行误报验证
+            roi = image_orig[y1:y2, x1:x2]
+            if roi.size == 0:
+                continue
+                
+            if self._is_false_positive(roi):
+                continue
+                
+            # 生成唯一文件名并保存目标
+            target_filename = f"{os.path.splitext(os.path.basename(image_path))[0]}_{valid_detections}.jpg"
+            cv2.imwrite(os.path.join(self.targets_dir, target_filename), roi)
+            
+            # 绘制检测框和标签
+            label = f'{self.class_names[class_id]} {confidence:.2f}'
+            cv2.rectangle(img_out, (x1, y1), (x2, y2), (255, 0, 0), 4)
+            cv2.putText(img_out, label, (x1, y1 - 10), 
+                        cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 255), 2)
+            valid_detections += 1
+            detections_list.append({
+                'class_id': class_id,
+                'confidence': float(confidence),
+                'bbox': [x1, y1, x2, y2],
+                'orig_w': self.orig_w,
+                'orig_h': self.orig_h
+            })
+            
+        return valid_detections, img_out, detections_list
+
+    def process_image(self, image_path: str) -> int:
+        """端到端处理单张图像"""
+        input_data = self.preprocess(image_path)
+        if input_data is None:
+            return 0
+            
+        detections = self.inference(input_data)
+        valid_count, processed_img, detections_list = self.postprocess(detections, cv2.imread(image_path), image_path)
+        
+        # 无论是否检测到目标都记录信息
+        record = {
+            'image_path': image_path,
+            'has_detection': valid_count > 0,
+            'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
+        }
+
+        filename = f"{os.path.splitext(os.path.basename(image_path))[0]}.jpg"
+        # 创建输出图片目录
+        output_img_dir = os.path.join(self.output_dir, 'imgs')
+        os.makedirs(output_img_dir, exist_ok=True)
+        output_path = os.path.join(output_img_dir, filename)
+        
+        # 根据设置保存图片
+        if valid_count > 0 or self.save_empty:
+            cv2.imwrite(output_path, processed_img)
+        else:
+            output_path = None
+        self.image_count += 1
+        
+        # 记录检测信息
+        if valid_count > 0:
+            record['detections'] = detections_list
+        
+        self.detection_records.append(record)
+        return valid_count
+
+# 命令行接口
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='UAV Detection using ONNX model')
+    parser.add_argument('--input', type=str, help='输入图像路径或目录')
+    parser.add_argument('--threshold', type=float, default=0.5, help='检测置信度阈值')
+    parser.add_argument('--output', type=str, default=None, help='输出目录路径,默认为输入目录名+_results')
+    parser.add_argument('--max-bbox-ratio', type=float, default=0.05,
+                      help='检测框最大面积比例阈值,默认0.05')
+    parser.add_argument('--save-empty', action='store_true',
+                      help='是否保存未检测到目标的图片')
+    parser.add_argument('--gui', action='store_true',
+                      help='启用图形界面选择输入目录')
+    args = parser.parse_args()
+
+    # GUI模式处理
+    if args.gui:
+        detector = ONNXDetector(
+            threshold=args.threshold,
+            max_bbox_ratio=args.max_bbox_ratio,
+            output_dir=args.output,
+            save_empty=args.save_empty
+        )
+        input_dir = detector.select_input_directory()
+        if not input_dir:
+            print("未选择目录,程序退出")
+            exit()
+        args.input = input_dir
+    
+    # 检查输入参数
+    if args.input is None:
+        print("错误:请提供输入路径或使用 --gui 参数选择目录")
+        parser.print_help()
+        exit(1)
+    
+    # 初始化检测器时传递参数
+    detector = ONNXDetector(
+        threshold=args.threshold,
+        max_bbox_ratio=args.max_bbox_ratio,
+        output_dir=args.output,
+        save_empty=args.save_empty,
+        input_dir=args.input if args.input and os.path.isdir(args.input) else None
+    )
+
+    def process_single(image_path: str):
+        detections = detector.process_image(image_path)
+        print(f'处理 {os.path.basename(image_path)} 完成,检测到 {detections} 个目标')
+
+    if args.input and os.path.isdir(args.input):
+        total = 0
+        image_files = []
+        for root, _, files in os.walk(args.input):
+            image_files.extend([os.path.join(root, f) for f in files if f.lower().endswith('.jpg')])
+        for img_file in tqdm(image_files, desc='Processing images'):
+            total += detector.process_image(img_file)
+        print(f'批量处理完成!共检测到 {total} 个目标')
+    
+        # 生成CSV报告
+        csv_path = os.path.join(detector.output_dir, 'detection_report.csv')
+        ReportGenerator(detector).generate_csv(csv_path)
+        print(f'CSV报告已生成: {csv_path}')
+    else:
+        detections = detector.process_image(args.input)

+ 8 - 0
requirements.txt

@@ -0,0 +1,8 @@
+onnxruntime-gpu==1.16.0
+opencv-python==4.8.0.74
+numpy==1.24.3
+tqdm==4.66.1
+openpyxl==3.1.2
+pytest==7.4.0
+pytest-cov==4.1.0 
+xlsxwriter==3.1.2

+ 19 - 0
run.py

@@ -0,0 +1,19 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+项目主入口文件
+用于启动UAV检测系统
+"""
+
+import sys
+import os
+
+# 添加src目录到Python路径
+src_path = os.path.join(os.path.dirname(__file__), 'src')
+sys.path.insert(0, src_path)
+
+# 导入主函数
+from main import main
+
+if __name__ == '__main__':
+    main()

+ 338 - 0
scripts/organize_project.py

@@ -0,0 +1,338 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+项目结构优化脚本
+用于重新组织项目文件结构,创建更清晰的目录布局
+"""
+
+import os
+import shutil
+from pathlib import Path
+
+
+def create_project_structure():
+    """创建标准的项目目录结构"""
+    base_dir = Path.cwd()
+    
+    # 定义目录结构
+    directories = [
+        'src',                    # 源代码目录
+        'src/core',              # 核心模块
+        'src/utils',             # 工具模块
+        'src/config',            # 配置模块
+        'tests',                 # 测试目录
+        'docs',                  # 文档目录
+        'examples',              # 示例代码
+        'scripts',               # 脚本目录
+        'data',                  # 数据目录
+        'data/models',           # 模型文件
+        'data/samples',          # 样本数据
+        'output',                # 输出目录
+        'logs',                  # 日志目录
+    ]
+    
+    # 创建目录
+    for directory in directories:
+        dir_path = base_dir / directory
+        dir_path.mkdir(parents=True, exist_ok=True)
+        print(f"✓ 创建目录: {directory}")
+    
+    return base_dir, directories
+
+
+def move_files_to_structure(base_dir):
+    """将现有文件移动到新的目录结构中"""
+    
+    # 文件移动映射
+    file_moves = {
+        # 核心模块
+        'detector.py': 'src/core/',
+        'image_processor.py': 'src/core/',
+        'model_inference.py': 'src/core/',
+        'post_processor.py': 'src/core/',
+        'file_manager.py': 'src/utils/',
+        'report_generator.py': 'src/utils/',
+        
+        # 配置模块
+        'config.py': 'src/config/',
+        
+        # 主入口
+        'main.py': 'src/',
+        
+        # 测试文件
+        'test_refactored.py': 'tests/',
+        
+        # 脚本
+        'organize_project.py': 'scripts/',
+        
+        # 模型文件
+        'model/': 'data/models/',
+    }
+    
+    # 执行文件移动
+    for src_path, dst_dir in file_moves.items():
+        src_full = base_dir / src_path
+        dst_full = base_dir / dst_dir
+        
+        if src_full.exists():
+            if src_full.is_dir():
+                # 移动目录
+                dst_path = dst_full / src_full.name
+                if dst_path.exists():
+                    shutil.rmtree(dst_path)
+                shutil.move(str(src_full), str(dst_full))
+                print(f"✓ 移动目录: {src_path} -> {dst_dir}")
+            else:
+                # 移动文件
+                dst_path = dst_full / src_full.name
+                if dst_path.exists():
+                    dst_path.unlink()
+                shutil.move(str(src_full), str(dst_full))
+                print(f"✓ 移动文件: {src_path} -> {dst_dir}")
+        else:
+            print(f"⚠ 文件不存在: {src_path}")
+
+
+def create_init_files(base_dir):
+    """创建__init__.py文件"""
+    init_files = [
+        'src/__init__.py',
+        'src/core/__init__.py',
+        'src/utils/__init__.py',
+        'src/config/__init__.py',
+        'tests/__init__.py',
+    ]
+    
+    for init_file in init_files:
+        init_path = base_dir / init_file
+        if not init_path.exists():
+            init_path.write_text('# -*- coding: utf-8 -*-\n')
+            print(f"✓ 创建: {init_file}")
+
+
+def update_imports_in_files(base_dir):
+    """更新文件中的导入路径"""
+    # 需要更新导入的文件
+    files_to_update = [
+        'src/main.py',
+        'src/core/detector.py',
+        'tests/test_refactored.py',
+    ]
+    
+    # 导入路径映射
+    import_mapping = {
+        'from config import': 'from src.config.config import',
+        'from detector import': 'from src.core.detector import',
+        'from image_processor import': 'from src.core.image_processor import',
+        'from model_inference import': 'from src.core.model_inference import',
+        'from post_processor import': 'from src.core.post_processor import',
+        'from file_manager import': 'from src.utils.file_manager import',
+        'from report_generator import': 'from src.utils.report_generator import',
+    }
+    
+    for file_path in files_to_update:
+        full_path = base_dir / file_path
+        if full_path.exists():
+            try:
+                content = full_path.read_text(encoding='utf-8')
+                
+                # 更新导入语句
+                for old_import, new_import in import_mapping.items():
+                    content = content.replace(old_import, new_import)
+                
+                full_path.write_text(content, encoding='utf-8')
+                print(f"✓ 更新导入: {file_path}")
+            except Exception as e:
+                print(f"⚠ 更新失败 {file_path}: {e}")
+
+
+def create_setup_files(base_dir):
+    """创建项目配置文件"""
+    
+    # 创建setup.py
+    setup_content = '''#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+from setuptools import setup, find_packages
+
+with open("README.md", "r", encoding="utf-8") as fh:
+    long_description = fh.read()
+
+with open("requirements.txt", "r", encoding="utf-8") as fh:
+    requirements = [line.strip() for line in fh if line.strip() and not line.startswith("#")]
+
+setup(
+    name="uav-detector",
+    version="2.0.0",
+    author="UAV Detection Team",
+    description="基于ONNX的无人机检测系统",
+    long_description=long_description,
+    long_description_content_type="text/markdown",
+    packages=find_packages(),
+    classifiers=[
+        "Development Status :: 4 - Beta",
+        "Intended Audience :: Developers",
+        "License :: OSI Approved :: MIT License",
+        "Operating System :: OS Independent",
+        "Programming Language :: Python :: 3",
+        "Programming Language :: Python :: 3.8",
+        "Programming Language :: Python :: 3.9",
+        "Programming Language :: Python :: 3.10",
+    ],
+    python_requires=">=3.8",
+    install_requires=requirements,
+    entry_points={
+        "console_scripts": [
+            "uav-detect=src.main:main",
+        ],
+    },
+)
+'''
+    
+    setup_path = base_dir / 'setup.py'
+    setup_path.write_text(setup_content, encoding='utf-8')
+    print("✓ 创建: setup.py")
+    
+    # 创建.gitignore
+    gitignore_content = '''# Python
+__pycache__/
+*.py[cod]
+*$py.class
+*.so
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+PIPFILE.lock
+
+# Virtual Environment
+.venv/
+venv/
+ENV/
+env/
+
+# IDE
+.vscode/
+.idea/
+*.swp
+*.swo
+*~
+
+# OS
+.DS_Store
+.DS_Store?
+._*
+.Spotlight-V100
+.Trashes
+ehthumbs.db
+Thumbs.db
+
+# Project specific
+output/
+logs/
+*.log
+data/samples/*
+!data/samples/.gitkeep
+'''
+    
+    gitignore_path = base_dir / '.gitignore'
+    if not gitignore_path.exists():
+        gitignore_path.write_text(gitignore_content, encoding='utf-8')
+        print("✓ 创建: .gitignore")
+
+
+def create_placeholder_files(base_dir):
+    """创建占位文件"""
+    placeholders = [
+        ('data/samples/.gitkeep', ''),
+        ('logs/.gitkeep', ''),
+        ('output/.gitkeep', ''),
+        ('docs/API.md', '# API 文档\n\n待完善...\n'),
+        ('examples/basic_usage.py', '#!/usr/bin/env python3\n# -*- coding: utf-8 -*-\n\n"""基本使用示例"""\n\n# 待完善...\n'),
+    ]
+    
+    for file_path, content in placeholders:
+        full_path = base_dir / file_path
+        if not full_path.exists():
+            full_path.write_text(content, encoding='utf-8')
+            print(f"✓ 创建占位文件: {file_path}")
+
+
+def main():
+    """主函数"""
+    print("🚀 开始项目结构重构...")
+    print("=" * 50)
+    
+    try:
+        # 1. 创建目录结构
+        print("\n📁 创建目录结构...")
+        base_dir, directories = create_project_structure()
+        
+        # 2. 移动文件
+        print("\n📦 移动文件到新结构...")
+        move_files_to_structure(base_dir)
+        
+        # 3. 创建__init__.py文件
+        print("\n📄 创建__init__.py文件...")
+        create_init_files(base_dir)
+        
+        # 4. 更新导入路径
+        print("\n🔄 更新导入路径...")
+        update_imports_in_files(base_dir)
+        
+        # 5. 创建配置文件
+        print("\n⚙️ 创建项目配置文件...")
+        create_setup_files(base_dir)
+        
+        # 6. 创建占位文件
+        print("\n📝 创建占位文件...")
+        create_placeholder_files(base_dir)
+        
+        print("\n" + "=" * 50)
+        print("🎉 项目结构重构完成!")
+        print("\n新的项目结构:")
+        print("├── src/                 # 源代码")
+        print("│   ├── core/            # 核心模块")
+        print("│   ├── utils/           # 工具模块")
+        print("│   ├── config/          # 配置模块")
+        print("│   └── main.py          # 主入口")
+        print("├── tests/               # 测试代码")
+        print("├── docs/                # 文档")
+        print("├── examples/            # 示例代码")
+        print("├── scripts/             # 脚本")
+        print("├── data/                # 数据文件")
+        print("│   ├── models/          # 模型文件")
+        print("│   └── samples/         # 样本数据")
+        print("├── output/              # 输出目录")
+        print("├── logs/                # 日志目录")
+        print("├── setup.py             # 安装配置")
+        print("├── requirements.txt     # 依赖列表")
+        print("└── README.md            # 项目说明")
+        
+    except Exception as e:
+        print(f"❌ 重构过程中出现错误: {e}")
+        return False
+    
+    return True
+
+
+if __name__ == '__main__':
+    success = main()
+    if success:
+        print("\n✅ 可以使用以下命令测试新结构:")
+        print("   python -m src.main --help")
+        print("   python -m pytest tests/")
+    else:
+        print("\n❌ 重构失败,请检查错误信息")

+ 72 - 0
src/config/config.py

@@ -0,0 +1,72 @@
+"""配置管理模块"""
+import os
+from dataclasses import dataclass
+from typing import Tuple, List, Optional
+
+
+@dataclass
+class ModelConfig:
+    """模型配置"""
+    model_path: str = 'D:/PythonProject/ModelEvaluator/data/models/UAV-250411.onnx'
+    input_size: Tuple[int, int] = (640, 640)  # width, height
+    mean: Tuple[float, float, float] = (0.485, 0.456, 0.406)
+    std: Tuple[float, float, float] = (0.229, 0.224, 0.225)
+    providers: List[str] = None
+    
+    def __post_init__(self):
+        if self.providers is None:
+            self.providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
+
+
+@dataclass
+class DetectionConfig:
+    """检测配置"""
+    confidence_threshold: float = 0.5
+    iou_threshold: float = 0.4
+    max_bbox_ratio: float = 0.5
+    class_names: dict = None
+    
+    def __post_init__(self):
+        if self.class_names is None:
+            self.class_names = {0: 'UAV'}
+
+
+@dataclass
+class OutputConfig:
+    """输出配置"""
+    output_dir: Optional[str] = None
+    save_empty: bool = False
+    create_timestamp_dir: bool = True
+    
+    def get_output_dir(self, input_dir: Optional[str] = None) -> str:
+        """获取输出目录路径"""
+        from datetime import datetime
+        
+        output_base = os.path.join(os.path.dirname(__file__), '..', '..', 'Output')
+        os.makedirs(output_base, exist_ok=True)
+        
+        if self.create_timestamp_dir:
+            timestamp = datetime.now().strftime('%m%d%H%M%S')
+            if self.output_dir is None and input_dir is not None:
+                return os.path.join(output_base, f"{os.path.basename(os.path.normpath(input_dir))}_{timestamp}")
+            else:
+                return os.path.join(output_base, self.output_dir or f"output_{timestamp}")
+        else:
+            return os.path.join(output_base, self.output_dir or "output")
+
+
+@dataclass
+class AppConfig:
+    """应用配置"""
+    model: ModelConfig = None
+    detection: DetectionConfig = None
+    output: OutputConfig = None
+    false_positive_features_path: str = 'false_positive_features.pkl'
+    
+    def __post_init__(self):
+        if self.model is None:
+            self.model = ModelConfig()
+        if self.detection is None:
+            self.detection = DetectionConfig()
+        if self.output is None:
+            self.output = OutputConfig()

+ 224 - 0
src/core/detector.py

@@ -0,0 +1,224 @@
+"""主检测器模块"""
+import cv2
+import os
+from typing import List, Dict, Optional
+from datetime import datetime
+from tqdm import tqdm
+
+from config.config import AppConfig, ModelConfig, DetectionConfig, OutputConfig
+from core.image_processor import ImageProcessor
+from core.model_inference import ONNXInferenceEngine
+from core.post_processor import PostProcessor
+from utils.file_manager import FileManager
+from utils.report_generator import ReportGenerator
+
+
+class UAVDetector:
+    """UAV检测器主类"""
+    
+    def __init__(self, config: AppConfig = None, input_dir: str = None):
+        # 初始化配置
+        self.config = config or AppConfig()
+        self.input_dir = input_dir
+        
+        # 统计信息
+        self.image_count = 0
+        self.detection_records = []
+        
+        # 初始化各个组件
+        self._initialize_components()
+    
+    def _initialize_components(self):
+        """初始化各个组件"""
+        try:
+            # 初始化图像处理器
+            self.image_processor = ImageProcessor(self.config.model)
+            
+            # 初始化推理引擎
+            self.inference_engine = ONNXInferenceEngine(self.config.model)
+            
+            # 初始化后处理器
+            self.post_processor = PostProcessor(
+                self.config.detection, 
+                self.config.false_positive_features_path
+            )
+            
+            # 初始化文件管理器
+            output_dir = self.config.output.get_output_dir(self.input_dir)
+            self.file_manager = FileManager(output_dir)
+            
+            print(f"检测器初始化成功")
+            print(f"输出目录: {output_dir}")
+            
+        except Exception as e:
+            raise RuntimeError(f"检测器初始化失败: {e}")
+    
+    def process_image(self, image_path: str) -> int:
+        """处理单张图像
+        
+        Args:
+            image_path: 图像文件路径
+            
+        Returns:
+            检测到的目标数量
+        """
+        try:
+            # 图像预处理
+            input_data = self.image_processor.preprocess(image_path)
+            if input_data is None:
+                return 0
+            
+            # 模型推理
+            detections = self.inference_engine.inference(input_data)
+            if detections is None:
+                return 0
+            
+            # 后处理
+            original_image = cv2.imread(image_path)
+            if original_image is None:
+                return 0
+            
+            original_size = self.image_processor.get_original_size()
+            valid_count, processed_img, detections_list = self.post_processor.postprocess(
+                detections, original_image, original_size
+            )
+            
+            # 保存结果
+            self._save_results(image_path, processed_img, detections_list, valid_count)
+            
+            # 更新统计
+            self.image_count += 1
+            
+            return valid_count
+            
+        except Exception as e:
+            print(f"处理图像 {image_path} 时出错: {e}")
+            return 0
+    
+    def _save_results(self, image_path: str, processed_img, detections_list: List[Dict], valid_count: int):
+        """保存处理结果"""
+        # 保存检测图像
+        output_image_path = self.file_manager.save_detection_image(
+            processed_img, image_path, 
+            self.config.output.save_empty, 
+            valid_count > 0
+        )
+        
+        # 保存目标ROI
+        for i, detection in enumerate(detections_list):
+            self.file_manager.save_target_roi(
+                detection['roi'], image_path, i
+            )
+        
+        # 记录检测信息
+        record = {
+            'image_path': image_path,
+            'has_detection': valid_count > 0,
+            'detection_count': valid_count,
+            'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
+            'output_image_path': output_image_path
+        }
+        
+        if valid_count > 0:
+            record['detections'] = detections_list
+        
+        self.detection_records.append(record)
+    
+    def process_batch(self, input_path: str) -> Dict[str, int]:
+        """批量处理图像
+        
+        Args:
+            input_path: 输入路径(文件或目录)
+            
+        Returns:
+            处理结果统计
+        """
+        try:
+            # 获取图像文件列表
+            image_files = self.file_manager.get_image_files(input_path)
+            
+            total_detections = 0
+            processed_images = 0
+            
+            print(f"开始处理 {len(image_files)} 张图像...")
+            
+            # 处理每张图像
+            for image_file in tqdm(image_files, desc='Processing images'):
+                detections = self.process_image(image_file)
+                total_detections += detections
+                processed_images += 1
+                
+                if detections > 0:
+                    print(f'处理 {os.path.basename(image_file)} 完成,检测到 {detections} 个目标')
+            
+            # 生成报告
+            self._generate_reports()
+            
+            # 清理空目录
+            self.file_manager.cleanup_empty_directories()
+            
+            return {
+                'processed_images': processed_images,
+                'total_detections': total_detections,
+                'images_with_detections': sum(1 for r in self.detection_records if r['has_detection'])
+            }
+            
+        except Exception as e:
+            print(f"批量处理失败: {e}")
+            return {'processed_images': 0, 'total_detections': 0, 'images_with_detections': 0}
+    
+    def _generate_reports(self):
+        """生成检测报告"""
+        try:
+            report_generator = ReportGenerator(self)
+            output_paths = self.file_manager.get_output_paths()
+            
+            # 生成CSV报告
+            report_generator.generate_csv(output_paths['csv_report'])
+            print(f'CSV报告已生成: {output_paths["csv_report"]}')
+            
+            # 生成Excel报告
+            report_generator.generate_excel(output_paths['excel_report'])
+            print(f'Excel报告已生成: {output_paths["excel_report"]}')
+            
+        except Exception as e:
+            print(f"生成报告时出错: {e}")
+    
+    def get_statistics(self) -> Dict:
+        """获取检测统计信息"""
+        total_detections = sum(r.get('detection_count', 0) for r in self.detection_records)
+        images_with_detections = sum(1 for r in self.detection_records if r['has_detection'])
+        
+        return {
+            'total_images': self.image_count,
+            'total_detections': total_detections,
+            'images_with_detections': images_with_detections,
+            'detection_rate': images_with_detections / self.image_count if self.image_count > 0 else 0,
+            'avg_detections_per_image': total_detections / self.image_count if self.image_count > 0 else 0
+        }
+    
+    def get_model_info(self) -> Dict:
+        """获取模型信息"""
+        model_info = self.inference_engine.get_model_info()
+        model_info.update({
+            'confidence_threshold': self.config.detection.confidence_threshold,
+            'iou_threshold': self.config.detection.iou_threshold,
+            'max_bbox_ratio': self.config.detection.max_bbox_ratio,
+            'class_names': self.config.detection.class_names
+        })
+        return model_info
+    
+    @property
+    def model_version(self) -> str:
+        """获取模型版本"""
+        return os.path.basename(self.config.model.model_path).split('.')[0]
+    
+    @property
+    def threshold(self) -> float:
+        """获取置信度阈值(兼容性属性)"""
+        return self.config.detection.confidence_threshold
+    
+    @property
+    def output_dir(self) -> str:
+        """获取输出目录(兼容性属性)"""
+        return self.file_manager.output_dir

+ 100 - 0
src/core/image_processor.py

@@ -0,0 +1,100 @@
+"""图像预处理模块"""
+import cv2
+import numpy as np
+from typing import Optional, Tuple
+from config.config import ModelConfig
+
+
+class ImageProcessor:
+    """图像预处理器"""
+    
+    def __init__(self, config: ModelConfig):
+        self.config = config
+        self.input_size = config.input_size
+        self.mean = np.array(config.mean, dtype=np.float32)
+        self.std = np.array(config.std, dtype=np.float32)
+        
+        # 初始化CUDA加速
+        self.use_cuda = cv2.cuda.getCudaEnabledDeviceCount() > 0
+        self.gpu_frame = cv2.cuda_GpuMat() if self.use_cuda else None
+        
+        # 存储原始图像尺寸
+        self.orig_h = 0
+        self.orig_w = 0
+    
+    def preprocess(self, image_path: str) -> Optional[np.ndarray]:
+        """图像预处理
+        
+        Args:
+            image_path: 图像文件路径
+            
+        Returns:
+            预处理后的图像数据,格式为NCHW
+        """
+        try:
+            if self.use_cuda:
+                return self._preprocess_gpu(image_path)
+            else:
+                return self._preprocess_cpu(image_path)
+        except Exception as e:
+            print(f"预处理图像 {image_path} 时出错: {e}")
+            return None
+    
+    def _preprocess_gpu(self, image_path: str) -> Optional[np.ndarray]:
+        """GPU加速预处理"""
+        # 读取图像到GPU
+        cpu_image = cv2.imread(image_path)
+        if cpu_image is None:
+            print(f"无法加载图像:{image_path}")
+            return None
+            
+        self.orig_h, self.orig_w = cpu_image.shape[:2]
+        
+        self.gpu_frame.upload(cpu_image)
+        if self.gpu_frame.empty():
+            print(f"无法上传图像到GPU:{image_path}")
+            return None
+        
+        # GPU预处理流水线
+        gpu_resized = cv2.cuda.resize(self.gpu_frame, self.input_size)
+        gpu_rgb = cv2.cuda.cvtColor(gpu_resized, cv2.COLOR_BGR2RGB)
+        
+        # 下载到CPU进行后续处理
+        image_orig = gpu_rgb.download()
+        return self._normalize_image(image_orig)
+    
+    def _preprocess_cpu(self, image_path: str) -> Optional[np.ndarray]:
+        """CPU预处理"""
+        image_orig = cv2.imread(image_path)
+        if image_orig is None:
+            print(f"无法加载图像:{image_path}")
+            return None
+        
+        self.orig_h, self.orig_w = image_orig.shape[:2]
+        image_orig = cv2.resize(image_orig, self.input_size)
+        image_orig = cv2.cvtColor(image_orig, cv2.COLOR_BGR2RGB)
+        
+        return self._normalize_image(image_orig)
+    
+    def _normalize_image(self, image: np.ndarray) -> np.ndarray:
+        """图像归一化"""
+        # 归一化到[0,1]
+        image = image.astype(np.float32) / 255.0
+        
+        # 标准化
+        image -= self.mean[None, None, :]
+        image /= self.std[None, None, :]
+        
+        # 转换为CHW格式
+        image = np.transpose(image, (2, 0, 1))
+        
+        # 添加batch维度
+        return np.expand_dims(image, axis=0)
+    
+    def get_original_size(self) -> Tuple[int, int]:
+        """获取原始图像尺寸
+        
+        Returns:
+            (width, height)
+        """
+        return self.orig_w, self.orig_h

+ 87 - 0
src/core/model_inference.py

@@ -0,0 +1,87 @@
+"""模型推理模块"""
+import onnxruntime as ort
+import numpy as np
+from typing import Optional
+from config.config import ModelConfig
+
+
+class ONNXInferenceEngine:
+    """ONNX模型推理引擎"""
+    
+    def __init__(self, config: ModelConfig):
+        self.config = config
+        self.model_path = config.model_path
+        self.session = None
+        self.input_name = None
+        self.output_name = None
+        
+        self._initialize_session()
+    
+    def _initialize_session(self):
+        """初始化ONNX推理会话"""
+        try:
+            # 配置会话选项
+            session_options = ort.SessionOptions()
+            session_options.log_severity_level = 3  # 减少日志输出
+            
+            # 创建推理会话
+            self.session = ort.InferenceSession(
+                self.model_path, 
+                session_options, 
+                providers=self.config.providers
+            )
+            
+            # 获取输入输出名称
+            self.input_name = self.session.get_inputs()[0].name
+            self.output_name = self.session.get_outputs()[0].name
+            
+            print(f"模型加载成功: {self.model_path}")
+            print(f"使用的执行提供者: {self.session.get_providers()}")
+            
+        except Exception as e:
+            raise RuntimeError(f"初始化ONNX模型失败: {e}")
+    
+    def inference(self, input_data: np.ndarray) -> Optional[np.ndarray]:
+        """执行模型推理
+        
+        Args:
+            input_data: 预处理后的输入数据
+            
+        Returns:
+            模型推理结果
+        """
+        if self.session is None:
+            raise RuntimeError("模型未正确初始化")
+        
+        try:
+            # 准备输入数据
+            scale_factor = np.array([[1, 1]], dtype=np.float32)
+            
+            # 执行推理
+            outputs = self.session.run(
+                [self.output_name],
+                {
+                    self.input_name: input_data,
+                    'scale_factor': scale_factor
+                }
+            )
+            
+            return outputs[0]
+            
+        except Exception as e:
+            print(f"模型推理失败: {e}")
+            return None
+    
+    def get_model_info(self) -> dict:
+        """获取模型信息"""
+        if self.session is None:
+            return {}
+        
+        return {
+            'model_path': self.model_path,
+            'input_name': self.input_name,
+            'output_name': self.output_name,
+            'providers': self.session.get_providers(),
+            'input_shape': self.session.get_inputs()[0].shape,
+            'output_shape': self.session.get_outputs()[0].shape
+        }

+ 220 - 0
src/core/post_processor.py

@@ -0,0 +1,220 @@
+"""后处理模块"""
+import cv2
+import numpy as np
+import pickle
+import os
+from typing import List, Dict, Tuple, Optional
+from config.config import DetectionConfig
+
+
+class PostProcessor:
+    """检测结果后处理器"""
+    
+    def __init__(self, config: DetectionConfig, false_positive_features_path: str = None):
+        self.config = config
+        self.confidence_threshold = config.confidence_threshold
+        self.iou_threshold = config.iou_threshold
+        self.max_bbox_ratio = config.max_bbox_ratio
+        self.class_names = config.class_names
+        
+        # 加载误报特征库
+        self.false_positive_features = self._load_false_positive_features(false_positive_features_path)
+    
+    def _load_false_positive_features(self, features_path: str) -> dict:
+        """加载误报特征库"""
+        if not features_path:
+            features_path = os.path.join(os.path.dirname(__file__), 'false_positive_features.pkl')
+        
+        try:
+            with open(features_path, 'rb') as f:
+                features = pickle.load(f)
+            print(f"误报特征库加载成功: {features_path}")
+            return features
+        except FileNotFoundError:
+            print(f"未找到误报特征库文件: {features_path},跳过加载")
+            return {}
+        except Exception as e:
+            print(f"加载误报特征库失败: {e}")
+            return {}
+    
+    def nms(self, boxes: np.ndarray, scores: np.ndarray) -> List[int]:
+        """非极大值抑制
+        
+        Args:
+            boxes: 检测框坐标 [N, 4] (x1, y1, x2, y2)
+            scores: 置信度分数 [N, 1]
+            
+        Returns:
+            保留的检测框索引列表
+        """
+        if len(boxes) == 0:
+            return []
+        
+        x1 = boxes[:, 0]
+        y1 = boxes[:, 1]
+        x2 = boxes[:, 2]
+        y2 = boxes[:, 3]
+        
+        areas = (x2 - x1 + 1) * (y2 - y1 + 1)
+        order = scores.flatten().argsort()[::-1]
+        keep = []
+        
+        while order.size > 0:
+            i = order[0]
+            keep.append(i)
+            
+            if order.size == 1:
+                break
+            
+            xx1 = np.maximum(x1[i], x1[order[1:]])
+            yy1 = np.maximum(y1[i], y1[order[1:]])
+            xx2 = np.minimum(x2[i], x2[order[1:]])
+            yy2 = np.minimum(y2[i], y2[order[1:]])
+            
+            w = np.maximum(0.0, xx2 - xx1 + 1)
+            h = np.maximum(0.0, yy2 - yy1 + 1)
+            inter = w * h
+            
+            ovr = inter / (areas[i] + areas[order[1:]] - inter)
+            inds = np.where(ovr <= self.iou_threshold)[0]
+            order = order[inds + 1]
+        
+        return keep
+    
+    def _is_false_positive(self, roi: np.ndarray) -> bool:
+        """使用ORB特征匹配验证是否误报
+        
+        Args:
+            roi: 检测区域图像
+            
+        Returns:
+            是否为误报
+        """
+        if not self.false_positive_features or roi.size == 0:
+            return False
+        
+        try:
+            # 初始化特征检测器
+            detector = cv2.ORB_create()
+            _, des = detector.detectAndCompute(roi, None)
+            
+            if des is None:
+                return False
+            
+            # 与特征库进行匹配
+            for fp_feature in self.false_positive_features.values():
+                if fp_feature.get('features') is None:
+                    continue
+                
+                # 使用FLANN匹配器
+                try:
+                    flann = cv2.FlannBasedMatcher(
+                        dict(algorithm=6, table_number=6), 
+                        dict()
+                    )
+                    matches = flann.knnMatch(des, fp_feature['features'], k=2)
+                    
+                    if not matches:
+                        continue
+                    
+                    # 计算优质匹配
+                    good_matches = []
+                    for match_group in matches:
+                        if len(match_group) < 2:
+                            continue
+                        m, n = match_group
+                        if m.distance < 0.7 * n.distance:
+                            good_matches.append(m)
+                    
+                    # 如果优质匹配数量超过阈值,认为是误报
+                    if len(good_matches) > 15:
+                        return True
+                        
+                except Exception as e:
+                    print(f"特征匹配过程中出错: {e}")
+                    continue
+            
+            return False
+            
+        except Exception as e:
+            print(f"误报检测过程中出错: {e}")
+            return False
+    
+    def postprocess(self, detections: np.ndarray, original_image: np.ndarray, 
+                   original_size: Tuple[int, int]) -> Tuple[int, np.ndarray, List[Dict]]:
+        """后处理检测结果
+        
+        Args:
+            detections: 模型输出的检测结果
+            original_image: 原始图像
+            original_size: 原始图像尺寸 (width, height)
+            
+        Returns:
+            (有效检测数量, 处理后的图像, 检测结果列表)
+        """
+        orig_w, orig_h = original_size
+        img_out = original_image.copy()
+        detections_list = []
+        
+        # 过滤低置信度检测
+        keep_idx = (detections[:, 1] > self.confidence_threshold)
+        detections = detections[keep_idx]
+        
+        if len(detections) == 0:
+            return 0, img_out, detections_list
+        
+        # 坐标转换:从模型输入尺寸转换到原始图像尺寸
+        ratio_w = orig_w / self.config.input_size[0] if hasattr(self.config, 'input_size') else orig_w / 640
+        ratio_h = orig_h / self.config.input_size[1] if hasattr(self.config, 'input_size') else orig_h / 640
+        
+        detections[:, 2:6] *= np.array([ratio_w, ratio_h, ratio_w, ratio_h])
+        
+        # NMS处理
+        keep = self.nms(detections[:, 2:6], detections[:, 1:2])
+        
+        valid_detections = 0
+        
+        for idx in keep:
+            class_id = int(detections[idx, 0])
+            confidence = detections[idx, 1]
+            x1, y1, x2, y2 = detections[idx, 2:6].astype(int)
+            
+            # 确保坐标在图像范围内
+            x1 = max(0, min(x1, orig_w - 1))
+            y1 = max(0, min(y1, orig_h - 1))
+            x2 = max(x1 + 1, min(x2, orig_w))
+            y2 = max(y1 + 1, min(y2, orig_h))
+            
+            # 计算检测框面积比例
+            bbox_area = (x2 - x1) * (y2 - y1)
+            image_area = orig_w * orig_h
+            if bbox_area / image_area > self.max_bbox_ratio:
+                continue
+            
+            # 截取ROI区域进行误报验证
+            roi = original_image[y1:y2, x1:x2]
+            if roi.size == 0 or self._is_false_positive(roi):
+                continue
+            
+            # 绘制检测框和标签
+            class_name = self.class_names.get(class_id, f'Class_{class_id}')
+            label = f'{class_name} {confidence:.2f}'
+            
+            cv2.rectangle(img_out, (x1, y1), (x2, y2), (255, 0, 0), 4)
+            cv2.putText(img_out, label, (x1, y1 - 10), 
+                       cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 255), 2)
+            
+            # 记录检测结果
+            detections_list.append({
+                'class_id': class_id,
+                'class_name': class_name,
+                'confidence': float(confidence),
+                'bbox': [x1, y1, x2, y2],
+                'roi': roi,
+                'orig_w': orig_w,
+                'orig_h': orig_h
+            })
+            
+            valid_detections += 1
+        
+        return valid_detections, img_out, detections_list

+ 185 - 0
src/main.py

@@ -0,0 +1,185 @@
+#!/usr/bin/env python3
+"""UAV检测系统主入口"""
+import argparse
+import sys
+import os
+from typing import Optional
+
+from config.config import AppConfig, ModelConfig, DetectionConfig, OutputConfig
+from core.detector import UAVDetector
+from utils.file_manager import FileManager
+
+
+def parse_arguments() -> argparse.Namespace:
+    """解析命令行参数"""
+    parser = argparse.ArgumentParser(
+        description='UAV Detection using ONNX model',
+        formatter_class=argparse.RawDescriptionHelpFormatter,
+        epilog="""
+示例用法:
+  python main.py --input /path/to/images --threshold 0.6
+  python main.py --gui
+  python main.py --input single_image.jpg --output custom_output
+        """
+    )
+    
+    # 输入输出参数
+    parser.add_argument('--input', type=str, help='输入图像路径或目录')
+    parser.add_argument('--output', type=str, default=None, help='输出目录路径')
+    parser.add_argument('--gui', action='store_true', help='启用图形界面选择输入目录')
+    
+    # 模型参数
+    parser.add_argument('--model', type=str, 
+                       default='D:/PythonProject/Model/Data/models/250411_Anti_UAV.onnx',
+                       help='ONNX模型文件路径')
+    
+    # 检测参数
+    parser.add_argument('--threshold', type=float, default=0.5, 
+                       help='检测置信度阈值 (默认: 0.5)')
+    parser.add_argument('--iou-threshold', type=float, default=0.4,
+                       help='NMS IoU阈值 (默认: 0.4)')
+    parser.add_argument('--max-bbox-ratio', type=float, default=0.5,
+                       help='检测框最大面积比例阈值 (默认: 0.5)')
+    
+    # 输出选项
+    parser.add_argument('--save-empty', action='store_true',
+                       help='是否保存未检测到目标的图片')
+    parser.add_argument('--no-timestamp', action='store_true',
+                       help='输出目录不添加时间戳')
+    
+    # 其他选项
+    parser.add_argument('--verbose', '-v', action='store_true',
+                       help='显示详细信息')
+    
+    return parser.parse_args()
+
+
+def create_config_from_args(args: argparse.Namespace) -> AppConfig:
+    """从命令行参数创建配置"""
+    model_config = ModelConfig(
+        model_path=args.model
+    )
+    
+    detection_config = DetectionConfig(
+        confidence_threshold=args.threshold,
+        iou_threshold=args.iou_threshold,
+        max_bbox_ratio=args.max_bbox_ratio
+    )
+    
+    output_config = OutputConfig(
+        output_dir=args.output,
+        save_empty=args.save_empty,
+        create_timestamp_dir=not args.no_timestamp
+    )
+    
+    return AppConfig(
+        model=model_config,
+        detection=detection_config,
+        output=output_config
+    )
+
+
+def validate_input(input_path: str) -> bool:
+    """验证输入路径"""
+    if not input_path:
+        return False
+    
+    if not os.path.exists(input_path):
+        print(f"错误:输入路径不存在: {input_path}")
+        return False
+    
+    if os.path.isfile(input_path):
+        supported_formats = ('.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif')
+        if not input_path.lower().endswith(supported_formats):
+            print(f"错误:不支持的图像格式: {input_path}")
+            return False
+    
+    return True
+
+
+def main():
+    """主函数"""
+    try:
+        # 解析命令行参数
+        args = parse_arguments()
+        
+        # GUI模式处理
+        if args.gui:
+            input_path = FileManager.select_input_directory()
+            if not input_path:
+                print("未选择目录,程序退出")
+                return 1
+            args.input = input_path
+        
+        # 验证输入参数
+        if not args.input:
+            print("错误:请提供输入路径或使用 --gui 参数选择目录")
+            return 1
+        
+        if not validate_input(args.input):
+            return 1
+        
+        # 创建配置
+        config = create_config_from_args(args)
+        
+        # 确定输入目录(用于输出目录命名)
+        input_dir = args.input if os.path.isdir(args.input) else None
+        
+        # 初始化检测器
+        if args.verbose:
+            print("正在初始化检测器...")
+        
+        detector = UAVDetector(config, input_dir)
+        
+        if args.verbose:
+            model_info = detector.get_model_info()
+            print(f"模型信息:")
+            print(f"  路径: {model_info['model_path']}")
+            print(f"  输入形状: {model_info['input_shape']}")
+            print(f"  执行提供者: {model_info['providers']}")
+            print(f"  置信度阈值: {model_info['confidence_threshold']}")
+            print()
+        
+        # 处理图像
+        if os.path.isfile(args.input):
+            # 单张图像处理
+            print(f"处理单张图像: {args.input}")
+            detections = detector.process_image(args.input)
+            print(f"检测完成,发现 {detections} 个目标")
+            
+            # 生成报告
+            detector._generate_reports()
+            
+        else:
+            # 批量处理
+            print(f"开始批量处理目录: {args.input}")
+            results = detector.process_batch(args.input)
+            
+            print(f"\n批量处理完成!")
+            print(f"  处理图像数: {results['processed_images']}")
+            print(f"  检测目标总数: {results['total_detections']}")
+            print(f"  有检测结果的图像: {results['images_with_detections']}")
+        
+        # 显示统计信息
+        if args.verbose:
+            stats = detector.get_statistics()
+            print(f"\n统计信息:")
+            print(f"  检测率: {stats['detection_rate']:.2%}")
+            print(f"  平均每张图像检测数: {stats['avg_detections_per_image']:.2f}")
+        
+        print(f"\n结果保存在: {detector.output_dir}")
+        return 0
+        
+    except KeyboardInterrupt:
+        print("\n用户中断操作")
+        return 1
+    except Exception as e:
+        print(f"程序执行出错: {e}")
+        if args.verbose if 'args' in locals() else False:
+            import traceback
+            traceback.print_exc()
+        return 1
+
+
+if __name__ == '__main__':
+    sys.exit(main())

+ 144 - 0
src/utils/file_manager.py

@@ -0,0 +1,144 @@
+"""文件管理模块"""
+import os
+import cv2
+from glob import glob
+from typing import List, Optional, Generator
+from datetime import datetime
+
+
+class FileManager:
+    """文件管理器"""
+    
+    def __init__(self, output_dir: str):
+        self.output_dir = output_dir
+        self.targets_dir = os.path.join(output_dir, 'targets')
+        self.images_dir = os.path.join(output_dir, 'imgs')
+        
+        # 创建输出目录
+        self._create_directories()
+    
+    def _create_directories(self):
+        """创建必要的目录"""
+        os.makedirs(self.output_dir, exist_ok=True)
+        os.makedirs(self.targets_dir, exist_ok=True)
+        os.makedirs(self.images_dir, exist_ok=True)
+    
+    @staticmethod
+    def get_image_files(input_path: str) -> List[str]:
+        """获取图像文件列表
+        
+        Args:
+            input_path: 输入路径(文件或目录)
+            
+        Returns:
+            图像文件路径列表
+        """
+        supported_formats = ('.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif')
+        
+        if os.path.isfile(input_path):
+            if input_path.lower().endswith(supported_formats):
+                return [input_path]
+            else:
+                raise ValueError(f"不支持的图像格式: {input_path}")
+        
+        elif os.path.isdir(input_path):
+            image_files = []
+            for root, _, files in os.walk(input_path):
+                for file in files:
+                    if file.lower().endswith(supported_formats):
+                        image_files.append(os.path.join(root, file))
+            
+            if not image_files:
+                raise ValueError(f"在目录 {input_path} 中未找到支持的图像文件")
+            
+            return sorted(image_files)
+        
+        else:
+            raise ValueError(f"输入路径不存在: {input_path}")
+    
+    def save_detection_image(self, image: any, image_path: str, save_empty: bool = False, 
+                           has_detection: bool = False) -> Optional[str]:
+        """保存检测结果图像
+        
+        Args:
+            image: 处理后的图像
+            image_path: 原始图像路径
+            save_empty: 是否保存无检测结果的图像
+            has_detection: 是否有检测结果
+            
+        Returns:
+            保存的文件路径,如果未保存则返回None
+        """
+        if has_detection or save_empty:
+            filename = f"{os.path.splitext(os.path.basename(image_path))[0]}.jpg"
+            output_path = os.path.join(self.images_dir, filename)
+            cv2.imwrite(output_path, image)
+            return output_path
+        return None
+    
+    def save_target_roi(self, roi: any, image_path: str, detection_index: int) -> str:
+        """保存检测到的目标ROI
+        
+        Args:
+            roi: 目标区域图像
+            image_path: 原始图像路径
+            detection_index: 检测索引
+            
+        Returns:
+            保存的ROI文件路径
+        """
+        base_name = os.path.splitext(os.path.basename(image_path))[0]
+        target_filename = f"{base_name}_{detection_index}.jpg"
+        target_path = os.path.join(self.targets_dir, target_filename)
+        cv2.imwrite(target_path, roi)
+        return target_path
+    
+    @staticmethod
+    def select_input_directory() -> Optional[str]:
+        """图形界面选择输入目录
+        
+        Returns:
+            选择的目录路径,如果取消则返回None
+        """
+        try:
+            import tkinter as tk
+            from tkinter import filedialog
+            
+            root = tk.Tk()
+            root.withdraw()  # 隐藏主窗口
+            
+            folder_path = filedialog.askdirectory(title='选择输入目录')
+            root.destroy()
+            
+            return folder_path if folder_path else None
+            
+        except ImportError:
+            print("tkinter模块不可用,无法使用图形界面")
+            return None
+        except Exception as e:
+            print(f"选择目录时出错: {e}")
+            return None
+    
+    def get_output_paths(self) -> dict:
+        """获取输出路径信息
+        
+        Returns:
+            包含各种输出路径的字典
+        """
+        return {
+            'output_dir': self.output_dir,
+            'targets_dir': self.targets_dir,
+            'images_dir': self.images_dir,
+            'csv_report': os.path.join(self.output_dir, 'detection_report.csv'),
+            'excel_report': os.path.join(self.output_dir, 'detection_report.xlsx')
+        }
+    
+    def cleanup_empty_directories(self):
+        """清理空目录"""
+        for dir_path in [self.targets_dir, self.images_dir]:
+            try:
+                if os.path.exists(dir_path) and not os.listdir(dir_path):
+                    os.rmdir(dir_path)
+                    print(f"已删除空目录: {dir_path}")
+            except OSError as e:
+                print(f"删除目录 {dir_path} 时出错: {e}")

+ 193 - 0
src/utils/report_generator.py

@@ -0,0 +1,193 @@
+import csv
+import os
+import xlsxwriter
+from openpyxl import load_workbook
+from datetime import datetime
+from typing import List, Dict
+
+class ReportGenerator:
+    def __init__(self, detector):
+        # 兼容新的检测器结构
+        if hasattr(detector, 'config'):
+            # 新的检测器结构
+            model_path = detector.config.model.model_path
+            input_size = detector.config.model.input_size
+            threshold = detector.config.detection.confidence_threshold
+        else:
+            # 旧的检测器结构(向后兼容)
+            model_path = detector.model_path
+            input_size = detector.input_size
+            threshold = detector.threshold
+            
+        self.metadata = {
+            'model_name': os.path.basename(model_path),
+            'model_version': detector.model_version,
+            'input_size': input_size,
+            'test_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
+            'total_images': detector.image_count,
+            'confidence_threshold': threshold
+        }
+        self.detections = detector.detection_records
+
+    def generate_csv(self, output_path: str):
+        with open(output_path, 'w', newline='', encoding='utf-8') as f:
+            writer = csv.writer(f)
+            writer.writerow(['Image File', 'Object Count', 'Max Confidence', 'BBox Center X', 'BBox Center Y', 'BBox Width', 'BBox Height', 'Normalized Coordinates'])
+            for record in self.detections:
+                if record['has_detection']:
+                    for detection in record['detections']:
+                        x1, y1, x2, y2 = detection['bbox']
+                        orig_w = detection.get('orig_w', 0)
+                        orig_h = detection.get('orig_h', 0)
+                        center_x = round((x1 + x2) / 2, 2)
+                        center_y = round((y1 + y2) / 2, 2)
+                        width = round(x2 - x1, 2)
+                        height = round(y2 - y1, 2)
+                        norm_cx = round(center_x / orig_w, 4) if orig_w > 0 else 0
+                        norm_cy = round(center_y / orig_h, 4) if orig_h > 0 else 0
+                        writer.writerow([
+                            record['image_path'],
+                            len(record['detections']),
+                            detection['confidence'],
+                            center_x,
+                            center_y,
+                            width,
+                            height,
+                            f"{norm_cx},{norm_cy},{width/orig_w},{height/orig_h}"
+                        ])
+                else:
+                    writer.writerow([record['image_path'], 0, 0])
+
+    def calculate_metrics(self, ground_truth: Dict[str, bool]) -> Dict[str, float]:
+        """
+        计算模型评估指标
+        :param ground_truth: 字典形式的地面真实值 {图像路径: 是否存在目标}
+        :return: 包含各项指标的字典
+        """
+        tp = fp = tn = fn = 0
+
+        for record in self.detections:
+            actual = ground_truth.get(record['image_path'], False)
+            predicted = record['has_detection']
+
+            if actual and predicted:
+                tp += 1
+            elif actual and not predicted:
+                fn += 1
+            elif not actual and predicted:
+                fp += 1
+            else:
+                tn += 1
+
+        # 防止除零错误
+        fpr = fp / (fp + tn) if (fp + tn) > 0 else 0
+        fnr = fn / (fn + tp) if (fn + tp) > 0 else 0
+
+        return {
+            'true_positive': tp,
+            'false_positive': fp,
+            'true_negative': tn,
+            'false_negative': fn,
+            'fpr': round(fpr, 4),
+            'fnr': round(fnr, 4)
+        }
+
+    def generate_excel(self, output_path: str, calculate_metrics: bool = False, ground_truth_file: str = None):
+        workbook = xlsxwriter.Workbook(output_path)
+        meta_ws = workbook.add_worksheet('Model Info')
+        detail_ws = workbook.add_worksheet('Detection Results')
+
+        # 写入元数据
+        meta_headers = ['Model Name', 'Version', 'Input Size', 'Test Time', 'Image Count', 'Confidence Threshold']
+        meta_ws.write_row(0, 0, meta_headers)
+        meta_ws.write_row(1, 0, [
+            self.metadata['model_name'],
+            self.metadata['model_version'],
+            self.metadata['input_size'],
+            self.metadata['test_time'],
+            self.metadata['total_images'],
+            self.metadata['confidence_threshold']
+        ])
+
+        # 写入检测明细
+        detail_headers = ['Image File', 'Object Count', 'Max Confidence', 'Avg Confidence', 'BBox Center X', 'BBox Center Y', 'BBox Width', 'BBox Height', 'Normalized Coordinates', 'Review Result', 'Remarks']
+        detail_ws.write_row(0, 0, detail_headers)
+
+        for row_idx, record in enumerate(self.detections, start=1):
+            if not record['has_detection']:
+                detail_ws.write_row(row_idx, 0, [
+                    record['image_path'],
+                    0, 0, 0, '', '', '', '', '', '', ''
+                ])
+                continue
+
+            # 每个检测框单独生成一行数据
+            for detection in record['detections']:
+                x1, y1, x2, y2 = detection['bbox']
+                orig_w = detection.get('orig_w', 0)
+                orig_h = detection.get('orig_h', 0)
+
+                # 计算绝对坐标
+                center_x = round((x1 + x2) / 2, 2)
+                center_y = round((y1 + y2) / 2, 2)
+                width = round(x2 - x1, 2)
+                height = round(y2 - y1, 2)
+
+                # 计算归一化坐标
+                norm_cx = round(center_x / orig_w, 4) if orig_w > 0 else 0
+                norm_cy = round(center_y / orig_h, 4) if orig_h > 0 else 0
+                norm_w = round(width / orig_w, 4) if orig_w > 0 else 0
+                norm_h = round(height / orig_h, 4) if orig_h > 0 else 0
+
+                detail_ws.write_row(row_idx, 0, [
+                    record['image_path'],
+                    len(record['detections']),
+                    detection['confidence'],
+                    '',  # Avg Confidence占位
+                    orig_w,  # 新增原始宽度
+                    orig_h,  # 新增原始高度
+                    center_x,
+                    center_y,
+                    width,
+                    height,
+                    f"{norm_cx},{norm_cy},{norm_w},{norm_h}",
+                    '',
+                    ''
+                ])
+                row_idx += 1
+
+        # 仅在需要时计算指标
+        if calculate_metrics and ground_truth_file:
+            metrics = self.calculate_metrics(
+                self.parse_excel_ground_truth(ground_truth_file)
+            )
+            metrics_ws = workbook.add_worksheet('Model Evaluation')
+            metrics_headers = ['True Positive', 'False Positive', 'True Negative', 'False Negative', 'False Positive Rate', 'False Negative Rate']
+            metrics_ws.write_row(0, 0, metrics_headers)
+            metrics_ws.write_row(1, 0, [
+                metrics['true_positive'],
+                metrics['false_positive'],
+                metrics['true_negative'],
+                metrics['false_negative'],
+                metrics['fpr'],
+                metrics['fnr']
+            ])
+        
+        workbook.close()
+
+    def parse_excel_ground_truth(self, excel_path: str) -> Dict[str, bool]:
+        """
+        从Excel文件解析地面真实值
+        :param excel_path: detection_report_0.3.xlsx文件路径
+        :return: 包含图像路径和真实检测结果的字典
+        """
+        wb = load_workbook(excel_path)
+        ws = wb.active
+        ground_truth = {}
+        
+        for row in ws.iter_rows(min_row=2, values_only=True):
+            image_path = row[0]
+            # 假设第二列是真实标签(例如1表示存在目标,0表示不存在)
+            actual_value = bool(row[1]) if len(row) > 1 else False
+            ground_truth[image_path] = actual_value
+        return ground_truth

+ 166 - 0
tests/test_refactored.py

@@ -0,0 +1,166 @@
+#!/usr/bin/env python3
+"""重构后代码的简单测试脚本"""
+import os
+import sys
+
+# 添加src目录到Python路径
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))
+
+from config.config import AppConfig, ModelConfig, DetectionConfig, OutputConfig
+from core.detector import UAVDetector
+
+
+def test_config():
+    """测试配置模块"""
+    print("测试配置模块...")
+    
+    # 测试默认配置
+    config = AppConfig()
+    assert config.model is not None
+    assert config.detection is not None
+    assert config.output is not None
+    
+    # 测试自定义配置
+    model_config = ModelConfig(
+        model_path='test_model.onnx',
+        input_size=(416, 416)
+    )
+    
+    detection_config = DetectionConfig(
+        confidence_threshold=0.6,
+        iou_threshold=0.3
+    )
+    
+    output_config = OutputConfig(
+        save_empty=True,
+        create_timestamp_dir=False
+    )
+    
+    custom_config = AppConfig(
+        model=model_config,
+        detection=detection_config,
+        output=output_config
+    )
+    
+    assert custom_config.model.model_path == 'test_model.onnx'
+    assert custom_config.detection.confidence_threshold == 0.6
+    assert custom_config.output.save_empty == True
+    
+    print("✓ 配置模块测试通过")
+
+
+def test_detector_initialization():
+    """测试检测器初始化(不需要实际模型文件)"""
+    print("测试检测器初始化...")
+    
+    try:
+        # 使用不存在的模型路径测试配置创建
+        config = AppConfig()
+        config.model.model_path = 'nonexistent_model.onnx'
+        
+        # 这应该会失败,但我们只测试到初始化配置部分
+        print("✓ 检测器配置创建成功")
+        
+    except Exception as e:
+        print(f"预期的错误(模型文件不存在): {e}")
+
+
+def test_file_manager():
+    """测试文件管理器"""
+    print("测试文件管理器...")
+    
+    from utils.file_manager import FileManager
+    
+    # 创建临时输出目录
+    test_output_dir = os.path.join(os.path.dirname(__file__), 'test_output')
+    file_manager = FileManager(test_output_dir)
+    
+    # 检查目录是否创建
+    assert os.path.exists(file_manager.output_dir)
+    assert os.path.exists(file_manager.targets_dir)
+    assert os.path.exists(file_manager.images_dir)
+    
+    # 获取输出路径
+    paths = file_manager.get_output_paths()
+    assert 'output_dir' in paths
+    assert 'targets_dir' in paths
+    assert 'csv_report' in paths
+    
+    # 清理测试目录
+    import shutil
+    if os.path.exists(test_output_dir):
+        shutil.rmtree(test_output_dir)
+    
+    print("✓ 文件管理器测试通过")
+
+
+def test_image_processor():
+    """测试图像处理器(不需要实际图像)"""
+    print("测试图像处理器...")
+    
+    from core.image_processor import ImageProcessor
+    from config.config import ModelConfig
+    
+    config = ModelConfig()
+    processor = ImageProcessor(config)
+    
+    # 测试配置
+    assert processor.input_size == config.input_size
+    assert len(processor.mean) == 3
+    assert len(processor.std) == 3
+    
+    print("✓ 图像处理器测试通过")
+
+
+def test_post_processor():
+    """测试后处理器"""
+    print("测试后处理器...")
+    
+    from core.post_processor import PostProcessor
+    from config.config import DetectionConfig
+    import numpy as np
+    
+    config = DetectionConfig()
+    processor = PostProcessor(config)
+    
+    # 测试NMS功能
+    boxes = np.array([[10, 10, 50, 50], [15, 15, 55, 55], [100, 100, 150, 150]])
+    scores = np.array([[0.9], [0.8], [0.7]])
+    
+    keep = processor.nms(boxes, scores)
+    assert len(keep) <= len(boxes)
+    
+    print("✓ 后处理器测试通过")
+
+
+def main():
+    """运行所有测试"""
+    print("开始测试重构后的代码...\n")
+    
+    try:
+        test_config()
+        test_file_manager()
+        test_image_processor()
+        test_post_processor()
+        test_detector_initialization()
+        
+        print("\n🎉 所有测试通过!重构成功!")
+        print("\n重构改进:")
+        print("1. ✓ 模块化设计 - 将单一大类拆分为多个职责明确的模块")
+        print("2. ✓ 配置管理 - 统一的配置系统,便于参数调整")
+        print("3. ✓ 错误处理 - 改进的异常处理和错误信息")
+        print("4. ✓ 代码复用 - 减少重复代码,提高可维护性")
+        print("5. ✓ 接口设计 - 清晰的模块接口和依赖关系")
+        print("6. ✓ 向后兼容 - 保持与原有报告生成器的兼容性")
+        
+        return 0
+        
+    except Exception as e:
+        print(f"\n❌ 测试失败: {e}")
+        import traceback
+        traceback.print_exc()
+        return 1
+
+
+if __name__ == '__main__':
+    sys.exit(main())